import logging
import pandas as pd

from dataiku.core import schema_handling
from dataiku.core.saved_model import ModelParams
from dataiku.external_ml.mlflow.pyfunc_read_meta import read_user_meta
from dataiku.external_ml.utils import load_external_model_meta

logger = logging.getLogger(__name__)


def get_configured_threshold(model_folder_context):
    """Return the classifier threshold the user activated for this model.

    The value is looked up under the ``activeClassifierThreshold`` key of the user metadata read from
    ``model_folder_context``; ``None`` is returned when that key is absent.
    """
    user_meta = read_user_meta(model_folder_context)
    threshold = user_meta.get("activeClassifierThreshold")
    return threshold


# Storage types whose features are declared NUMERIC for MLflow imported models (dates included);
# every other storage type is declared CATEGORY.
_MLFLOW_NUMERIC_STORAGE_TYPES = frozenset(
    ['tinyint', 'smallint', 'int', 'bigint', 'double', 'float', 'date', 'dateonly', 'datetimenotz']
)


def _mlflow_per_feature_params(features):
    """Build the ``per_feature`` preprocessing dict: every feature is an INPUT, typed NUMERIC or CATEGORY."""
    return {
        feature["name"]: {
            "role": "INPUT",
            "type": "NUMERIC" if feature["type"] in _MLFLOW_NUMERIC_STORAGE_TYPES else "CATEGORY",
        }
        for feature in features
    }


def get_mlflow_model_params(model_folder_context):
    """Return a dataiku.core.saved_model.ModelParams for MLflow imported models.

    The parameters are built from the external model metadata (``load_external_model_meta``) and the
    user metadata (``read_user_meta``) found in ``model_folder_context``:

    - modeling: metrics from ``metricParams``, threshold auto-optimization enabled, and the forced
      classifier threshold taken from the user's ``activeClassifierThreshold`` (``None`` when absent);
    - core: ``predictionType`` (defaulting to ``"OTHER"``) and ``targetColumnName``;
    - preprocessing: one INPUT entry per feature (NUMERIC for integer/floating/date storage types,
      CATEGORY otherwise) and the target remapping built from ``labelToIntMap``;
    - split: the schema read from ``sample_schema.json`` when that file exists, else ``None``.

    The returned object also carries the raw external metadata as its ``model_meta`` attribute.
    """
    model_meta = load_external_model_meta(model_folder_context)
    # Read the user metadata once: it feeds both the forced threshold and ModelParams.user_meta,
    # so both are guaranteed to come from the same read.
    user_meta = read_user_meta(model_folder_context)

    modeling_params = {
        "metrics": model_meta["metricParams"],
        "autoOptimizeThreshold": True,
        "forcedClassifierThreshold": user_meta.get("activeClassifierThreshold"),
    }

    core_params = {
        "prediction_type": model_meta.get("predictionType", "OTHER"),
        "target_variable": model_meta.get("targetColumnName"),
    }

    preprocessing_params = {
        "per_feature": _mlflow_per_feature_params(model_meta["features"]),
        "target_remapping": [
            {"sourceValue": label, "mappedValue": value}
            for label, value in model_meta["labelToIntMap"].items()
        ],
    }

    split_desc = {"schema": None}
    if model_folder_context.isfile("sample_schema.json"):
        split_desc = {"schema": model_folder_context.read_json("sample_schema.json")}

    model_params = ModelParams(
        model_type="PREDICTION",
        modeling_params=modeling_params,
        preprocessing_params=preprocessing_params,
        core_params=core_params,
        split_desc=split_desc,
        user_meta=user_meta,
        model_perf=None,
        conditional_outputs=None,
        cluster_name_map=None,
        preprocessing_folder_context=model_folder_context,
        model_folder_context=model_folder_context,
        split_folder_context=model_folder_context,
        resolved_params=None,
        train_split_desc=None
    )
    model_params.model_meta = model_meta
    return model_params


def load_evaluation_dataset_sample(model_folder_context):
    """Load the tab-separated ``sample.csv.gz`` evaluation sample of the model folder as a DataFrame.

    Column names are taken, in order, from the ``columns`` entry of ``sample_schema.json``. Explicit
    pandas dtypes are applied only to the columns listed under that schema's optional ``features``
    entry, mapped through ``schema_handling.DKU_RICHER_PANDAS_TYPES_MAP`` with ``object`` as the
    fallback; pandas infers the types of every other column.

    NOTE(review): names come from "columns" while dtypes come from "features" (defaulting to an empty
    list) -- confirm the sample schema really provides a "features" entry, otherwise no dtype is forced.
    """
    sample_schema = model_folder_context.read_json("sample_schema.json")
    column_names = [col["name"] for col in sample_schema["columns"]]
    column_dtypes = {
        feature["name"]: schema_handling.DKU_RICHER_PANDAS_TYPES_MAP.get(feature["type"], object)
        for feature in sample_schema.get("features", [])
    }
    with model_folder_context.get_file_path_to_read("sample.csv.gz") as sample_path:
        sample_df = pd.read_csv(sample_path, dtype=column_dtypes, names=column_names, sep="\t")
    return sample_df


def check_user_declared_classes_consistency(target_df_column, user_declared_classes):
    """
    Check that the classes declared by the user are consistent with the classes present in the target column of the
    evaluation dataset. Should only be called for classification models.

    :param target_df_column: target column of the evaluation dataset (anything exposing ``unique()``, e.g. a
        pandas Series)
    :param user_declared_classes: iterable of dicts, each holding a declared class label (stored as a string)
        under the ``"label"`` key
    :raises ValueError: if the target column contains values whose ``str()`` form is not a declared label. The
        offending values are listed in the message with their original (non-stringified) representation.
        Because the comparison goes through ``str()``, missing values such as ``None``/NaN are reported too
        unless a matching ``"None"``/``"nan"`` label was declared.
    """
    logger.info(
        "Checking consistency of user-declared classes %s"
        " with classes present in the target column of the evaluation dataset",
        user_declared_classes,
    )

    # A set gives O(1) membership tests for each distinct target value
    declared_labels = {label["label"] for label in user_declared_classes}
    # NOTE: the cast to str is required since the user-declared class labels are stored as strings. We didn't simply
    # call astype(str) on the df column so that the error message uses the original target labels
    missing_labels = [label for label in target_df_column.unique() if str(label) not in declared_labels]
    if missing_labels:
        try:
            missing_labels = sorted(missing_labels)
        except TypeError:
            # Values of mixed, non-comparable types (e.g. strings next to a None/NaN missing value) cannot be
            # ordered natively; order them by their string form so the user still gets the ValueError below
            missing_labels = sorted(missing_labels, key=str)
        raise ValueError(
            "The following classes are present in the target column of the evaluation dataset, "
            "but were not declared: {}".format(missing_labels)
        )
