# -*- coding: utf-8 -*-

"""
Import an MLflow model into a Dataiku Saved Model using the Dataiku API.
This script supports versioning and reusability with minimal modifications.
"""

import dataiku
import dataikuapi
import os
from dataikuapi.dss.ml import DSSPredictionMLTaskSettings


# ---------------------- CONFIGURATION ----------------------
CONFIG = {
    "version_id": "v32",  # ID given to the new model version; the import is refused if it already exists
    "saved_model_name": "catboost-uci-bank",  # Name of the Saved Model in DSS (created when not found)
    "input_folder_id": "M8tpkIby",  # ID of the managed folder that holds the MLflow model
    "path_model_folder": "/mlflow-model-import/dist/",  # Subfolder inside the input folder (leading "/" is stripped)
    "catboost_model_dir": "catboost-uci-bank-20211213-162747",  # Model directory inside the subfolder above
    "target_column": "y",  # Name of the target column
    "class_labels": ["no", "yes"],  # Class labels found in the target column
    "evaluation_dataset": "eval_data",  # Dataset used for feature discovery and evaluation
}
# -----------------------------------------------------------


# --- Utility Functions ---
def get_code_env_name():
    """Look up the ``code_env`` entry in the Dataiku custom variables.

    Returns:
        The value stored under the ``"code_env"`` key, or whatever the
        variables mapping's ``.get`` yields when the key is absent
        (``None`` for a plain dict).
    """
    custom_variables = dataiku.get_custom_variables()
    return custom_variables.get("code_env")


def get_or_create_saved_model(project, model_name, prediction_type=None):
    """Return a SavedModel handle named ``model_name``, creating it if needed.

    Args:
        project: Project handle exposing ``list_saved_models``,
            ``get_saved_model`` and ``create_mlflow_pyfunc_model``.
        model_name: Name of the Saved Model to look up or create.
        prediction_type: Prediction type passed to
            ``create_mlflow_pyfunc_model`` when a new Saved Model has to be
            created. ``None`` (the default) keeps the historical behaviour of
            creating a binary-classification model. Ignored when a Saved
            Model with this name already exists.

    Returns:
        The handle of the existing Saved Model (first name match), or the
        handle of the newly created one.
    """
    for sm in project.list_saved_models():
        if sm["name"] == model_name:
            print(f"Found SavedModel '{model_name}' with ID: {sm['id']}")
            return project.get_saved_model(sm["id"])

    # Resolve the default lazily so the constant is only touched when a model
    # actually has to be created with the default type.
    if prediction_type is None:
        prediction_type = DSSPredictionMLTaskSettings.PredictionTypes.BINARY

    sm = project.create_mlflow_pyfunc_model(
        name=model_name,
        prediction_type=prediction_type,
    )
    print(f"Created new SavedModel '{model_name}' with ID: {sm.id}")
    return sm


def import_mlflow_model(saved_model, version_id, model_path, code_env_name):
    """Create a new Saved Model version from an MLflow model on disk.

    Args:
        saved_model: Saved Model handle exposing ``list_versions`` and
            ``import_mlflow_version_from_path``.
        version_id: ID to give to the new version.
        model_path: Path of the MLflow model directory to import.
        code_env_name: Code environment name forwarded to the import call.

    Returns:
        Whatever ``import_mlflow_version_from_path`` returns for the new
        version.

    Raises:
        ValueError: If a version with ``version_id`` already exists.
    """
    # Refuse to reuse an existing version ID.
    for existing_version in saved_model.list_versions():
        if existing_version["id"] == version_id:
            raise ValueError(f"Version '{version_id}' already exists. Choose a different version ID.")

    new_version = saved_model.import_mlflow_version_from_path(
        version_id=version_id,
        path=model_path,
        code_env_name=code_env_name,
    )
    return new_version


def evaluate_model_version(sm_version, target_col, class_labels, eval_dataset):
    """Set the core metadata of a model version, then evaluate it.

    Args:
        sm_version: Model version handle exposing ``set_core_metadata`` and
            ``evaluate``.
        target_col: Name of the target column.
        class_labels: Class labels of the target column.
        eval_dataset: Dataset passed both as the feature source for the
            metadata and as the evaluation dataset.
    """
    core_metadata = {
        "target_column_name": target_col,
        "class_labels": class_labels,
        "get_features_from_dataset": eval_dataset,
    }
    # Metadata is set first, then the evaluation is run on the same dataset.
    sm_version.set_core_metadata(**core_metadata)
    sm_version.evaluate(eval_dataset)
    

def main(config):
    """Import an MLflow model as a new Saved Model version and evaluate it.

    Args:
        config: Mapping with the keys ``input_folder_id``,
            ``path_model_folder``, ``catboost_model_dir``,
            ``saved_model_name``, ``version_id``, ``target_column``,
            ``class_labels`` and ``evaluation_dataset``.
    """
    # Handle on the default project of the Dataiku API client.
    project = dataiku.api_client().get_default_project()

    # Build the on-disk location of the MLflow model inside the input folder.
    # The leading "/" is stripped from the subfolder so that os.path.join
    # does not treat it as an absolute path and drop the folder root.
    folder_root = dataiku.Folder(config["input_folder_id"]).get_path()
    relative_dir = config["path_model_folder"].lstrip("/")
    model_path = os.path.join(folder_root, relative_dir, config["catboost_model_dir"])

    # Look up the Saved Model by name, creating it when missing.
    saved_model = get_or_create_saved_model(project, config["saved_model_name"])

    # Import the model as a new version...
    sm_version = import_mlflow_model(
        saved_model,
        version_id=config["version_id"],
        model_path=model_path,
        code_env_name=get_code_env_name(),
    )

    # ...then configure and evaluate that version.
    evaluate_model_version(
        sm_version,
        target_col=config["target_column"],
        class_labels=config["class_labels"],
        eval_dataset=config["evaluation_dataset"],
    )
    
# Run the import with the module-level CONFIG only when this file is executed
# as a script, not when it is imported as a module.
if __name__ == "__main__":
    main(CONFIG)