# encoding: utf-8

"""
Execute a prediction scoring recipe for an MLFLOW_PYFUNC model.
Must be called in a Flow environment.
"""

import sys
import logging
from os import environ
import pandas as pd

from dataiku.base.folder_context import build_folder_context
from dataiku.core import debugging
from dataiku.core import dkujson
from dataiku.core import doctor_constants as constants
from dataiku.core.dataset import Dataset
from dataiku.base.remoterun import read_dku_env_and_set
from dataiku.base.utils import ErrorMonitoringWrapper, safe_unicode_str
from dataiku.doctor.utils import scoring_recipe_utils
from dataikuscoring.mlflow import (mlflow_classification_predict_to_scoring_data,
                                   mlflow_regression_predict_to_scoring_data,
                                   mlflow_raw_predict)
from dataiku.external_ml.utils import load_external_model_meta
from dataiku.external_ml.mlflow.pyfunc_read_meta import read_user_meta

from dataiku.external_ml.mlflow.model_information_handler import MLflowModelInformationHandler
from dataiku.doctor.individual_explainer import DEFAULT_NB_EXPLANATIONS
from dataiku.doctor.individual_explainer import DEFAULT_SHAPLEY_BACKGROUND_SIZE
from dataiku.doctor.individual_explainer import DEFAULT_SUB_CHUNK_SIZE


def main(model_folder, input_dataset_smartname, output_dataset_smartname, output_style, forced_classifier_threshold, recipe_desc, fmi=None, proxy=None):
    """
    Score a dataset with an MLflow pyfunc model and write the scored rows to the output dataset.

    The input dataset is read batch by batch. Each batch is predicted, optionally
    completed with probability percentiles, row-level explanations and model
    metadata, and then written to the output dataset together with the input
    columns that do not collide with prediction columns.

    :param model_folder: model folder location, handed to ``build_folder_context``
    :param input_dataset_smartname: name of the dataset to score
    :param output_dataset_smartname: name of the dataset receiving the scored rows
    :param output_style: ``"RAW"`` outputs the raw MLflow predictions; any other value uses the
                         parsed output, which requires ``predictionType`` in the model metadata
    :param forced_classifier_threshold: value converted to float and used instead of the
                                        ``activeClassifierThreshold`` of the user meta;
                                        ``""`` or ``None`` means "no override"
    :param recipe_desc: recipe description dict; the keys read here are ``outputProbabilities``,
                        ``outputProbaPercentiles``, ``outputExplanations`` and
                        ``individualExplanationParams``
    :param fmi: when truthy, passed with each prediction batch to
                ``scoring_recipe_utils.add_output_model_metadata``
    :param proxy: when truthy, exported as the ``_PROXY_MODEL_PROXY`` environment variable
    :raises Exception: if parsed output is requested while the prediction type is not set,
                       or if the prediction type is not handled
    :raises ValueError: with an explanatory message when a ValueError mentioning
                        "has NA values in column" occurs while producing the batches
    """
    input_dataset = Dataset(input_dataset_smartname)
    batch_size = 1000  # TODO
    output_dataset = Dataset(output_dataset_smartname)

    if proxy:
        environ["_PROXY_MODEL_PROXY"] = proxy

    environ["_DKU_PROXY_MODELS_INFERENCE_TYPE"] = "dku_scoring"

    logging.info("Scoring with batch size: {}".format(batch_size))
    model_folder_context = build_folder_context(model_folder)
    # Load model
    import mlflow
    with model_folder_context.get_folder_path_to_read() as model_folder_path:
        mlflow_model = mlflow.pyfunc.load_model(model_folder_path)
    imported_model_metadata = load_external_model_meta(model_folder_context)

    # None is treated like "" (no override) instead of crashing in float(None)
    if forced_classifier_threshold is not None and forced_classifier_threshold != "":
        logging.info("Binary threshold overridden : using {}".format(forced_classifier_threshold))
        used_threshold = float(forced_classifier_threshold)
    else:
        used_threshold = read_user_meta(model_folder_context).get("activeClassifierThreshold")

    def output_generator():
        """Yield one output dataframe (kept input columns + predictions) per input batch."""
        # Built on the first batch that needs explanations, then reused for the following ones
        individual_explainer = None
        # perf.json is read at most once, the first time a batch needs it, not on every batch
        model_perf = None
        model_perf_loaded = False
        logging.info("Start output generator ...")
        try:
            for input_df in input_dataset.iter_dataframes(infer_with_pandas=False, chunksize=batch_size, float_precision="round_trip"):
                input_df.index = range(input_df.shape[0])
                # Snapshot taken before the target column is dropped: it provides the input columns of the output
                input_df_copy_unnormalized = input_df.copy()
                logging.info("Got a dataframe to score: %s" % str(input_df.shape))
                # NOTE(review): this logs a rendering of the batch itself; confirm this is acceptable for sensitive data
                logging.info("Predicting it with %s", input_df)

                # The model must not receive the target column as a feature
                if "targetColumnName" in imported_model_metadata:
                    target_name = imported_model_metadata["targetColumnName"]
                    if target_name in input_df:
                        input_df = input_df.drop(target_name, axis=1)

                prediction_type = imported_model_metadata.get("predictionType", None)
                if output_style == "RAW":
                    pred_df = mlflow_raw_predict(mlflow_model, imported_model_metadata, input_df)
                else:
                    if not prediction_type:
                        raise Exception("Prediction type is not set on the MLflow model version, cannot use parsed output")

                    if prediction_type in [constants.BINARY_CLASSIFICATION, constants.MULTICLASS]:
                        scoring_data = mlflow_classification_predict_to_scoring_data(mlflow_model, imported_model_metadata, input_df, used_threshold)
                        pred_df = scoring_data.pred_and_proba_df
                        target_map = imported_model_metadata["labelToIntMap"]
                        if not recipe_desc["outputProbabilities"]:  # was only for conditional outputs
                            # Probabilities were not requested: remove the per-class proba_<label> columns
                            classes = [class_label for (class_label, _) in sorted(target_map.items())]
                            proba_cols = [u"proba_{}".format(safe_unicode_str(c)) for c in classes]
                            pred_df.drop(proba_cols, axis=1, inplace=True)
                        elif prediction_type == constants.BINARY_CLASSIFICATION:
                            if not model_perf_loaded:
                                perf_file = "perf.json"
                                if model_folder_context.isfile(perf_file):
                                    model_perf = model_folder_context.read_json(perf_file)
                                model_perf_loaded = True
                            if model_perf is not None:
                                if recipe_desc["outputProbaPercentiles"] and "probaPercentiles" in model_perf and model_perf["probaPercentiles"]:
                                    percentile = pd.Series(model_perf["probaPercentiles"])
                                    # Probability column of the class mapped to 1 in labelToIntMap
                                    proba_1 = u"proba_{}".format(
                                        safe_unicode_str(next(k for k, v in target_map.items()
                                                              if v == 1)))
                                    # 1 + number of stored percentile bounds that are <= the row probability
                                    pred_df["proba_percentile"] = pred_df[proba_1].apply(
                                        lambda p: percentile.where(percentile <= p).count() + 1)

                    elif prediction_type == constants.REGRESSION:
                        scoring_data = mlflow_regression_predict_to_scoring_data(mlflow_model, imported_model_metadata, input_df)
                        pred_df = scoring_data.preds_df
                    else:
                        raise Exception("{} not yet implemented".format(prediction_type))

                # On a name collision the prediction column wins over the input column
                clean_kept_columns = [c for c in input_df_copy_unnormalized.columns if c not in pred_df.columns]

                if recipe_desc.get("outputExplanations"):

                    individual_explanation_params = recipe_desc.get("individualExplanationParams", {})
                    method = individual_explanation_params.get("method")
                    if individual_explainer is None:
                        model_info_handler = MLflowModelInformationHandler(model_folder_context)
                        individual_explainer = model_info_handler.get_explainer()
                        # Only the batch being processed when the explainer is created is handed to make_ready
                        individual_explainer.make_ready(
                            input_df if individual_explanation_params.get("drawInScoredSet", False) else None)
                    logging.info("Starting row level explanations for this batch using {} method".format(
                        method
                    ))
                    nb_explanation = individual_explanation_params.get("nbExplanations", DEFAULT_NB_EXPLANATIONS)
                    shapley_background_size = individual_explanation_params.get("shapleyBackgroundSize",
                                                                                DEFAULT_SHAPLEY_BACKGROUND_SIZE)
                    sub_chunk_size = individual_explanation_params.get("subChunkSize", DEFAULT_SUB_CHUNK_SIZE)
                    explanations = individual_explainer.explain(
                        input_df, nb_explanation, method=method, sub_chunk_size=sub_chunk_size,
                        shapley_background_size=shapley_background_size
                    )
                    pred_df["explanations"] = individual_explainer.format_explanations(explanations, nb_explanation,
                                                                                       with_json=True)
                    logging.info("Done row level explanations for this batch")

                if fmi:
                    scoring_recipe_utils.add_output_model_metadata(pred_df, fmi)

                yield pd.concat([input_df_copy_unnormalized[clean_kept_columns], pred_df], axis=1)
        except ValueError as e:
            if "has NA values in column" in str(e):
                raise ValueError('It seems that you are trying to score on a dataset containing empty values in a boolean or integer typed column. ' +
                                 'This is not supported: Try to filter data or adjust your dataset column\'s types.')
            else:
                # Bare raise: propagate the original error untouched
                raise

    logging.info("Starting writer")

    # The writer is opened lazily: the output schema is derived from the first scored batch
    writer = None
    logging.info("Starting to iterate")
    try:
        for output_df in output_generator():
            logging.info("Generator generated a df %s" % str(output_df.shape))
            if writer is None:
                output_dataset.write_schema_from_dataframe(output_df)
                writer = output_dataset.get_writer()
            writer.write_dataframe(output_df)
            logging.info("Output df written")
    finally:
        # Close the writer even when scoring or writing fails, so it is never left open
        if writer is not None:
            writer.close()


if __name__ == "__main__":
    # Process bootstrap: debugging handler, log format, then the DKU environment
    debugging.install_handler()
    logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', level=logging.INFO)
    read_dku_env_and_set()

    # Positional command-line arguments:
    #   1: model folder, 2: input dataset name, 3: output dataset name, 4: output style,
    #   5: forced classifier threshold, 6: path of the recipe description file, 7: fmi, 8: proxy
    with ErrorMonitoringWrapper():
        cli_args = sys.argv
        recipe_desc = dkujson.load_from_filepath(cli_args[6])
        main(
            model_folder=cli_args[1],
            input_dataset_smartname=cli_args[2],
            output_dataset_smartname=cli_args[3],
            output_style=cli_args[4],
            forced_classifier_threshold=cli_args[5],
            recipe_desc=recipe_desc,
            fmi=cli_args[7],
            proxy=cli_args[8],
        )
