# -*- coding: utf-8 -*-

"""
Train a Python custom model, track the experiments using MLflow capabilities and save it into a Dataiku Saved Model using the Dataiku API.
This script supports versioning and reusability with minimal modifications.
"""

import dataiku
import pandas as pd
import mlflow
from datetime import datetime
from sklearn.model_selection import cross_validate, StratifiedKFold, ParameterGrid
from sklearn.pipeline import make_pipeline
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OrdinalEncoder
from sklearn.ensemble import RandomForestClassifier
from dataikuapi.dss.ml import DSSPredictionMLTaskSettings

# ---------------------- CONFIGURATION ----------------------
CONFIG = {
    # -- Experiment tracking -------------------------------------------
    # Managed folder that backs MLflow tracking (looked up by name, created
    # on the given connection when missing).
    "experiment_tracking_folder_name": "Tabular Classification Experiments",
    "experiment_tracking_folder_connection": "filesystem_folders",
    # MLflow experiment that receives every run of this script.
    "experiment_name": "Tabular-Experiments",

    # -- Dataiku objects -----------------------------------------------
    "saved_model_name": "Tabular Classification Model",
    "dataset_training": "input-tabular_classification",
    "dataset_evaluation": "evaluate-tabular_classification",

    # -- Deployment ----------------------------------------------------
    # When truthy, main() deploys the best run into the Saved Model.
    "auto_deploy": True,
    # Logged metric key used to rank runs; log_run() names metrics
    # "mean_test_<scorer>" / "std_test_<scorer>".
    "deployment_metric": "mean_test_roc_auc",

    # -- Cross-validation ----------------------------------------------
    "cv_folds": 5,
    # Scorer names handed to sklearn's cross_validate.
    "metrics": ["f1_macro", "roc_auc"],

    # -- Features ------------------------------------------------------
    # Columns that go through the OrdinalEncoder; every other column is
    # passed through untouched.
    "categorical_cols": [
        "job",
        "marital",
        "education",
        "default",
        "housing",
        "loan",
        "contact",
        "month",
        "poutcome",
    ],

    # -- Hyper-parameter grid ------------------------------------------
    # Expanded with ParameterGrid; each combination is passed as keyword
    # arguments to RandomForestClassifier.
    "hparams_dict": {
        "n_estimators": [50, 100, 70],
        "criterion": ["gini"],
        "max_depth": [12, 20],
        "min_samples_split": [3],
        "random_state": [42],
    },

    # -- Target --------------------------------------------------------
    "target_column": "y",
    "class_labels": ["no", "yes"],
}
# -----------------------------------------------------------

# --- Utility Functions ---
def get_code_env_name():
    """Return the value of the ``code_env`` Dataiku custom variable.

    The lookup uses ``.get``, so the result is whatever that call yields
    when the variable is not defined (``None`` for a plain dict).
    """
    custom_variables = dataiku.get_custom_variables()
    return custom_variables.get("code_env")


def now_str() -> str:
    """Return the current local time as a compact ``YYYYMMDDHHMMSS`` string."""
    # Format-spec form of strftime: datetime.__format__ delegates to strftime.
    return f"{datetime.now():%Y%m%d%H%M%S}"


def get_or_create_folder(project):
    """Return the experiment-tracking managed folder, creating it if needed.

    The folder is matched by name against
    ``CONFIG["experiment_tracking_folder_name"]``; the first match wins.
    When no folder matches, a new one is created on the connection named by
    ``CONFIG["experiment_tracking_folder_connection"]``.

    Args:
        project: Dataiku project handle exposing ``list_managed_folders``,
            ``get_managed_folder`` and ``create_managed_folder``.

    Returns:
        The managed folder handle returned by the project API.
    """
    existing = next(
        (
            candidate
            for candidate in project.list_managed_folders()
            if candidate["name"] == CONFIG["experiment_tracking_folder_name"]
        ),
        None,
    )
    if existing is not None:
        print(f"Found folder: {existing['name']}")
        return project.get_managed_folder(existing["id"])

    print("Creating experiment folder...")
    return project.create_managed_folder(
        CONFIG["experiment_tracking_folder_name"],
        connection_name=CONFIG["experiment_tracking_folder_connection"],
    )


def get_or_create_experiment(mlflow, mlflow_ext):
    """Ensure the configured MLflow experiment exists and make it active.

    Args:
        mlflow: MLflow module/handle used to create, look up and activate
            the experiment.
        mlflow_ext: Dataiku MLflow extension used to list the experiments
            that already exist.

    Returns:
        The ``experiment_id`` of ``CONFIG["experiment_name"]``.
    """
    listing = mlflow_ext.list_experiments()

    # The experiment is "known" only when the listing is non-empty and one
    # of its entries carries the configured name.
    already_known = (
        len(listing.get("experiments", [])) >= 1
        and CONFIG["experiment_name"] in [entry["name"] for entry in listing["experiments"]]
    )

    if already_known:
        print(f"Experiment already existed with name: {CONFIG['experiment_name']}")
    else:
        mlflow.create_experiment(CONFIG["experiment_name"])
        print(f"Experiment created with name: {CONFIG['experiment_name']}")

    experiment_id = mlflow.get_experiment_by_name(CONFIG["experiment_name"]).experiment_id
    print(f"Experiment name is: {CONFIG['experiment_name']} and ID is: {experiment_id}")

    # Subsequent mlflow.start_run() calls are recorded under this experiment.
    mlflow.set_experiment(CONFIG["experiment_name"])
    return experiment_id


def build_pipeline(hparams):
    """Build the scikit-learn pipeline: ordinal encoding + random forest.

    Columns listed in ``CONFIG["categorical_cols"]`` are ordinal-encoded
    (categories unseen at fit time are encoded as ``-1``); every other
    column is passed through unchanged.

    Args:
        hparams: Keyword arguments forwarded to ``RandomForestClassifier``.

    Returns:
        An unfitted pipeline made of the column transformer followed by the
        classifier.
    """
    categorical_encoder = OrdinalEncoder(
        handle_unknown="use_encoded_value",
        unknown_value=-1,
    )
    transformers = [("categorical", categorical_encoder, CONFIG["categorical_cols"])]
    column_transformer = ColumnTransformer(transformers, remainder="passthrough")
    classifier = RandomForestClassifier(**hparams)
    return make_pipeline(column_transformer, classifier)


def log_run(pipeline, hparams, mlflow, data, target):
    """Cross-validate, fit and log one pipeline as a single MLflow run.

    The run records the hyper-parameters (plus algorithm name, categorical
    columns and the fitted class labels), the mean and standard deviation of
    every scorer in ``CONFIG["metrics"]``, and the fitted pipeline itself.

    Args:
        pipeline: Unfitted scikit-learn pipeline; it is fitted in place on
            the full data after cross-validation.
        hparams: Hyper-parameters used to build ``pipeline`` (logged as-is).
        mlflow: MLflow module/handle used for run creation and logging.
        data: Feature table.
        target: Target values aligned with ``data``.

    Returns:
        Tuple ``(artifact_uri, run_id, class_labels)`` for the created run.
    """
    run_name = f"run-{now_str()}"
    with mlflow.start_run(run_name=run_name) as active_run:
        run_id = active_run.info.run_id

        print(f"Running CV for {run_name}...")
        splitter = StratifiedKFold(n_splits=CONFIG["cv_folds"])
        cv_scores = cross_validate(
            pipeline,
            data,
            target,
            cv=splitter,
            scoring=CONFIG["metrics"],
        )

        # All means first, then all standard deviations.
        run_metrics = {
            f"mean_test_{metric}": cv_scores[f"test_{metric}"].mean()
            for metric in CONFIG["metrics"]
        }
        for metric in CONFIG["metrics"]:
            run_metrics[f"std_test_{metric}"] = cv_scores[f"test_{metric}"].std()

        # Cross-validation works on clones, so fit the given pipeline on the
        # whole dataset before logging it.
        pipeline.fit(data, target)
        class_labels = pipeline.classes_.tolist()

        run_params = {
            "model_algo": "RandomForestClassifier",
            "categorical_cols": CONFIG["categorical_cols"],
            **hparams,
            "class_labels": class_labels,
        }

        mlflow.log_params(run_params)
        mlflow.log_metrics(run_metrics)
        mlflow.sklearn.log_model(pipeline, artifact_path=f"RandomForestClassifier-{run_id}")

        return active_run.info.artifact_uri, run_id, class_labels


def get_best_run(mlflow, experiment_id):
    """Return the run with the highest ``CONFIG["deployment_metric"]``.

    Runs that did not log the metric, or logged NaN for it, are ignored
    instead of being scored as 0: a defaulted 0 would let such a run win
    whenever it is the only candidate or whenever the metric takes negative
    values. On ties the run encountered first is kept.

    Args:
        mlflow: MLflow module/handle providing ``search_runs`` and
            ``get_run``.
        experiment_id: Identifier of the experiment whose runs are compared.

    Returns:
        The best MLflow run, or ``None`` when no run has a usable value for
        the deployment metric.
    """
    import math

    metric_name = CONFIG["deployment_metric"]
    best_run = None
    best_score = None
    for _, run_info in mlflow.search_runs(experiment_id).iterrows():
        run = mlflow.get_run(run_info["run_id"])
        score = run.data.metrics.get(metric_name)
        # Skip runs that cannot be ranked (metric missing or NaN).
        if score is None or math.isnan(score):
            continue
        if best_score is None or score > best_score:
            best_run, best_score = run, score
    return best_run


def get_or_create_saved_model(project):
    """Return the configured Saved Model, creating it when it does not exist.

    The model is matched by name against ``CONFIG["saved_model_name"]``; the
    first match wins. A missing model is created as an MLflow pyfunc Saved
    Model with the binary prediction type.

    Args:
        project: Dataiku project handle exposing ``list_saved_models``,
            ``get_saved_model`` and ``create_mlflow_pyfunc_model``.

    Returns:
        The Saved Model handle returned by the project API.
    """
    existing = next(
        (
            candidate
            for candidate in project.list_saved_models()
            if candidate["name"] == CONFIG["saved_model_name"]
        ),
        None,
    )
    if existing is not None:
        print(f"Found Saved Model: {existing['name']}")
        return project.get_saved_model(existing["id"])

    print("Creating Saved Model...")
    binary_type = DSSPredictionMLTaskSettings.PredictionTypes.BINARY
    return project.create_mlflow_pyfunc_model(
        name=CONFIG["saved_model_name"],
        prediction_type=binary_type,
    )


def deploy_model(sm, mlflow, folder, model_path, run_id, class_labels):
    """Import a logged MLflow model as a Saved Model version and evaluate it.

    The version is imported from the managed folder, made the active
    version, given its core metadata and finally evaluated on
    ``CONFIG["dataset_evaluation"]``.

    Args:
        sm: Saved Model handle that receives the new version.
        mlflow: Unused here; kept so existing callers keep working.
        folder: Managed folder holding the MLflow artifacts.
        model_path: Path of the logged model inside ``folder``.
        run_id: MLflow run id, reused as the Saved Model version id.
        class_labels: Class labels recorded in the version metadata.
    """
    imported_version = sm.import_mlflow_version_from_managed_folder(
        version_id=run_id,
        managed_folder=folder,
        path=model_path,
        code_env_name=get_code_env_name(),
    )
    sm.set_active_version(imported_version.version_id)

    print("Evaluate Saved Model...")
    core_metadata = {
        "target_column_name": CONFIG["target_column"],
        "class_labels": class_labels,
        "get_features_from_dataset": CONFIG["dataset_evaluation"],
    }
    imported_version.set_core_metadata(**core_metadata)
    imported_version.evaluate(CONFIG["dataset_evaluation"])


def main():
    """Train one model per hyper-parameter combination and optionally deploy.

    Steps:
      1. Load the training dataset and split it into features and target.
      2. Get or create the tracking folder and the MLflow experiment.
      3. For every combination in ``CONFIG["hparams_dict"]``, build a
         pipeline, log it as an MLflow run and attach the inference info
         to that run.
      4. When ``CONFIG["auto_deploy"]`` is truthy, import the best run into
         the Saved Model and evaluate it.
    """
    client = dataiku.api_client()
    project = client.get_default_project()
    df = dataiku.Dataset(CONFIG["dataset_training"]).get_dataframe()
    target = df[CONFIG["target_column"]]
    data = df.drop(columns=[CONFIG["target_column"]])

    folder = get_or_create_folder(project)
    mlflow_ext = project.get_mlflow_extension()

    param_grid = ParameterGrid(CONFIG["hparams_dict"])

    # Same code environment lookup as deploy_model(); resolved once because
    # it does not change between runs.
    code_env_name = get_code_env_name()

    with project.setup_mlflow(managed_folder=folder) as mlflow:
        experiment_id = get_or_create_experiment(mlflow, mlflow_ext)

        for hparams in param_grid:
            pipeline = build_pipeline(hparams)
            _, run_id, class_labels = log_run(pipeline, hparams, mlflow, data, target)
            mlflow_ext.set_run_inference_info(
                run_id=run_id,
                prediction_type="BINARY_CLASSIFICATION",
                classes=class_labels,
                # Must be a code environment name; the experiment name was
                # previously passed here by mistake.
                code_env_name=code_env_name,
                target=CONFIG["target_column"]
            )

    if CONFIG["auto_deploy"]:
        with project.setup_mlflow(managed_folder=folder) as mlflow:
            best_run = get_best_run(mlflow, experiment_id)
            if best_run:
                run_id = best_run.info.run_id
                # Location of the logged model inside the tracking folder.
                model_path = f"{experiment_id}/{run_id}/artifacts/RandomForestClassifier-{run_id}"
                print(f"Deploying best model: {run_id}")
                sm = get_or_create_saved_model(project)
                deploy_model(sm, mlflow, folder, model_path, run_id, CONFIG["class_labels"])


# Run the training/deployment workflow only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()