# encoding: utf-8
"""
Evaluates an MLflow saved model version in place, filling in perf.json
"""
import logging
import pandas as pd

from dataiku.doctor.utils import datetime_to_epoch
from dataiku.core import doctor_constants
from dataiku.core.dataset import Dataset
from dataiku.doctor import step_constants
from dataiku.doctor.diagnostics import default_diagnostics, diagnostics
from dataiku.doctor.preprocessing_collector import PredictionPreprocessingDataCollector
from dataiku.doctor.utils.listener import DiagOnlyContext, ProgressListener
from dataiku.external_ml.mlflow.pyfunc_common import get_mlflow_model_params, check_user_declared_classes_consistency
from dataiku.external_ml.mlflow.pyfunc_evaluate_common import process_input_df

from dataiku.doctor.evaluation.base import load_input_dataframe, sample_and_store_dataframe


logger = logging.getLogger(__name__)


def run_model_evaluation(mlflow_model, df, model_params, out_folder_context):
    """
    Evaluate ``df`` with the given MLflow model by delegating to ``process_input_df``.

    MLflow models cannot be partitioned, but the dataframe is still handed over through a one-item
    "partition" generator, for more uniformity with DSS models.

    :param mlflow_model: loaded MLflow model, yielded together with ``model_params.model_folder``
    :param df: dataframe to evaluate
    :param model_params: object exposing ``model_folder``, ``model_meta`` and ``modeling_params``
    :param out_folder_context: folder context used both for the diagnostics listener and as the
        evaluation store folder of ``process_input_df``
    :return: the value returned by ``process_input_df``
    """

    def single_partition():
        # The one and only "partition": the full dataframe with the model folder and the loaded model
        yield df, (model_params.model_folder, mlflow_model)

    # Diagnostics: register the evaluation callbacks, then report the already loaded dataframe
    # within the "loading test" step
    default_diagnostics.register_evaluation_callbacks()
    progress = ProgressListener(context=DiagOnlyContext(out_folder_context))
    with progress.push_step(step_constants.ProcessingStep.STEP_LOADING_TEST):
        diagnostics.on_load_evaluation_dataset_end(df=df)

    evaluation_recipe_desc = dict(
        outputProbabilities=False,
        outputProbaPercentiles=False,
        # When evaluating the model, just use the model threshold
        overrideModelSpecifiedThreshold=False,
        outputs=[]
    )

    return process_input_df(
        part_df_and_model_params_generator=single_partition(),
        mlflow_imported_model=model_params.model_meta,
        modeling_params=model_params.modeling_params,
        recipe_desc=evaluation_recipe_desc,
        cond_outputs=[],
        partition_dispatch=False,
        with_sample_weight=False,
        dont_compute_performance=False,
        evaluation_store_folder_context=out_folder_context
    )


def _check_classes_consistency_if_classification(input_df, model_params):
    # Check the target column against the user-declared class labels; only done for classification models
    prediction_type = model_params.core_params["prediction_type"]
    if prediction_type in ("BINARY_CLASSIFICATION", "MULTICLASS"):
        target = model_params.core_params["target_variable"]
        user_declared_classes = model_params.model_meta["classLabels"]
        check_user_declared_classes_consistency(input_df[target], user_declared_classes)
    else:
        logger.info("Not dealing with a classification model, will not perform classes consistency check.")


def _convert_datetime_features_to_epoch(input_df, features):
    # Replace, in place, each datetime-typed column of a 'date' feature by the result of datetime_to_epoch
    for feature in features:
        if feature['type'] != 'date':
            continue
        name = feature['name']
        column = input_df[name]
        if pd.api.types.is_datetime64_any_dtype(column.dtype):
            input_df[name] = datetime_to_epoch(column)


def evaluate_and_save(model_folder_context, input_dataset_ref_name, schema, selection=None, skip_expensive_reports=False):
    """
    Evaluate the MLflow model stored in ``model_folder_context`` on a dataset, writing results in that folder.

    Steps, in order:

    * load the input dataframe from the dataset, according to ``schema`` and ``selection``
    * for classification models, check the target values against the user-declared classes
    * convert datetime columns of 'date' features with ``datetime_to_epoch``
    * load the MLflow pyfunc model, then sample and store the dataframe in the model folder
    * run the evaluation on a copy of the sample, using the model folder as evaluation store folder
    * write ``collector_data.json`` (collector data for interactive scoring)
    * unless skipped or not applicable, run the explanations (see ``run_explanations``)

    :param model_folder_context: folder context of the saved model version; read from and written to
    :param input_dataset_ref_name: name of the dataset to evaluate on, passed to ``Dataset``
    :param schema: dataset schema, a dict with a "columns" entry
    :param selection: sampling settings for loading the dataset; ``None`` means the full dataset
    :param skip_expensive_reports: when true, explanations are not computed
    """
    import mlflow

    input_df = _retrieve_input_dataframe(Dataset(input_dataset_ref_name), schema, selection)

    model_params = get_mlflow_model_params(model_folder_context)
    _check_classes_consistency_if_classification(input_df, model_params)

    # Date features need to be converted to correspond to what the doctor expects.
    # However, we can't feed directly the model with this, so we need to do the opposite operation
    # when feeding the dataframe to the model (see dataikuscoring.mlflow.common.convert_date_features and its
    # usages)
    _convert_datetime_features_to_epoch(input_df, model_params.model_meta['features'])
    # Logger arguments are passed separately so that formatting is deferred to the logging framework
    logger.info("Loaded MLflow model_params from %s", model_folder_context)

    with model_folder_context.get_folder_path_to_read() as model_folder_path:
        mlflow_model = mlflow.pyfunc.load_model(model_folder_path)

    # Work on a copy of the stored sample
    sample = sample_and_store_dataframe(model_folder_context, input_df, schema).copy()

    logger.info("Evaluating MLflow model with dataframe shape: %s", sample.shape)

    run_model_evaluation(mlflow_model, sample, model_params, model_folder_context)

    # Compute collector_data for interactive scoring
    collector = PredictionPreprocessingDataCollector(sample, model_params.preprocessing_params)
    model_folder_context.write_json("collector_data.json", collector.build())

    # NOTE(review): iperf.json is presumably produced by the evaluation above - confirm
    iperf = model_folder_context.read_json("iperf.json")
    prediction_type = model_params.model_meta.get("predictionType")

    if skip_expensive_reports:
        logger.info("Skipping background rows drawing, feature and column importance computation")
    elif not iperf.get("probaAware", False) and prediction_type != doctor_constants.REGRESSION:
        # Explanations are only run for regression models or models flagged as "probaAware"
        logger.info("Cannot draw background rows, compute feature and column importance: model is not probabilistic")
    else:
        run_explanations(model_folder_context, sample)
    logger.info("Done evaluating")


def run_explanations(model_folder_context, sample):
    """
    Make the explainer of the model stored in ``model_folder_context`` ready and, when it has background
    rows, compute the global explanations on ``sample``.

    :param model_folder_context: folder context used to build the ``MLflowModelInformationHandler``
    :param sample: dataframe passed as ``trainset_override`` to the explainer and to the global
        explanations computation
    """
    # Imported at call time rather than at module load
    from dataiku.external_ml.mlflow.model_information_handler import MLflowModelInformationHandler

    handler = MLflowModelInformationHandler(model_folder_context)
    handler.get_explainer().make_ready(trainset_override=sample, save=True)

    # Without background rows there is nothing more to compute
    if handler.get_explainer().background_rows is None:
        return
    handler.compute_global_explanations(sample)


def _retrieve_input_dataframe(input_dataset, schema, selection):
    """
    Load ``input_dataset`` as a dataframe, with column names, dtypes and date columns derived from ``schema``.

    :param input_dataset: dataset to load, forwarded to ``load_input_dataframe``
    :param schema: dict whose "columns" entry is given to ``Dataset.get_dataframe_schema_st``
    :param selection: sampling settings; when ``None``, the full dataset is requested
    :return: the value returned by ``load_input_dataframe``
    """
    if selection is None:
        sampling = {"samplingMethod": "FULL"}
    else:
        sampling = selection

    column_names, column_dtypes, date_columns = Dataset.get_dataframe_schema_st(
        schema["columns"], parse_dates=True, infer_with_pandas=False
    )

    return load_input_dataframe(
        input_dataset=input_dataset,
        sampling=sampling,
        columns=column_names,
        dtypes=column_dtypes,
        parse_date_columns=date_columns
    )
