# encoding: utf-8

"""
Execute a prediction scoring recipe in PyRegular mode.
Must be called in a Flow environment.
"""
import logging
import sys

import numpy as np
import pandas as pd

from dataiku.base.folder_context import build_folder_context
from dataiku.base.folder_context import get_partitions_fmi_folder_contexts
from dataiku.base.remoterun import read_dku_env_and_set
from dataiku.base.utils import ErrorMonitoringWrapper
from dataiku.core import debugging
from dataiku.core import dkujson
from dataiku.core.dataset import Dataset
from dataiku.doctor.individual_explainer import DEFAULT_NB_EXPLANATIONS
from dataiku.doctor.individual_explainer import DEFAULT_SHAPLEY_BACKGROUND_SIZE
from dataiku.doctor.individual_explainer import DEFAULT_SUB_CHUNK_SIZE
from dataiku.doctor.posttraining.model_information_handler import PredictionModelInformationHandler
from dataiku.doctor.prediction.classification_scoring import binary_classif_scoring_add_percentile_and_cond_outputs
from dataiku.doctor.prediction.classification_scoring import binary_classification_predict
from dataiku.doctor.prediction.classification_scoring import multiclass_predict
from dataiku.doctor.prediction.common import check_classical_prediction_type
from dataiku.doctor.prediction.common import PredictionAlgorithmNaNSupport
from dataiku.doctor.prediction.overrides.ml_overrides_params import OVERRIDE_INFO_COL
from dataiku.doctor.prediction.overrides.ml_overrides_params import ml_overrides_params_from_model_folder
from dataiku.doctor.prediction.overrides.ml_overrides_results import OverridesResultsMixin
from dataiku.doctor.prediction.prediction_interval_model import PredictionIntervalsModel
from dataiku.doctor.prediction.regression_scoring import regression_predict
from dataiku.doctor.prediction.scorable_model import ScorableModel
from dataiku.doctor.preprocessing_handler import PredictionPreprocessingHandler
from dataiku.doctor.utils import doctor_constants
from dataiku.doctor.utils.gpu_execution import GpuSupportingCapability, log_nvidia_smi_if_use_gpu
from dataiku.doctor.utils.gpu_execution import get_gpu_config_from_recipe_desc
from dataiku.doctor.utils.model_io import load_model_from_folder
from dataiku.doctor.utils.scoring_recipe_utils import add_fmi_metadata
from dataiku.doctor.utils.scoring_recipe_utils import add_prediction_time_metadata
from dataiku.doctor.utils.scoring_recipe_utils import dataframe_iterator
from dataiku.doctor.utils.scoring_recipe_utils import generate_part_df_and_model_params
from dataiku.doctor.utils.scoring_recipe_utils import get_empty_pred_df
from dataiku.doctor.utils.scoring_recipe_utils import get_input_parameters
from dataiku.doctor.utils.scoring_recipe_utils import is_partition_dispatch
from dataiku.doctor.utils.scoring_recipe_utils import smmd_colnames
from dataikuscoring.utils.prediction_result import AbstractPredictionResult

logger = logging.getLogger(__name__)


def load_model(model_folder_context, core_params, recipe_gpu_config, for_eval=False, global_model_assertions_params_list=None, global_model_overrides_params_list=None):
    """Load one saved model folder and build everything needed to score with it.

    Reads ``rmodeling_params.json``, ``collector_data.json`` and ``rpreprocessing_params.json``
    from the folder. It then builds the preprocessing handler and pipeline, loads the trained
    estimator plus an optional prediction-intervals model, and wraps them in a ``ScorableModel``.

    :param model_folder_context: folder context of the model to load
    :param core_params: core params of the ML task (``taskType`` and ``prediction_type`` are read here)
    :param recipe_gpu_config: recipe GPU configuration; ``["params"]["gpuList"]`` selects the visible devices
    :param for_eval: when true, the pipeline is built with the target, and assertions are read from
        the folder if none were supplied
    :param global_model_assertions_params_list: assertions to use instead of reading them from the folder
    :param global_model_overrides_params_list: overrides to use instead of reading them from the folder
    :return: tuple ``(model_folder_context, preprocessing_params, scorable_model, pipeline,
        modeling_params, preprocessing_handler)``
    """
    modeling_params = model_folder_context.read_json("rmodeling_params.json")
    collector_data = model_folder_context.read_json("collector_data.json")
    preprocessing_params = model_folder_context.read_json("rpreprocessing_params.json")
    nan_support = PredictionAlgorithmNaNSupport(modeling_params, preprocessing_params)

    # Assertions are only read from this folder when evaluating and no global list was provided.
    if global_model_assertions_params_list is None and for_eval:
        assertions = _load_assertions_params(model_folder_context)
    else:
        assertions = global_model_assertions_params_list

    # Overrides fall back to this folder's own definition when no global list was provided.
    overrides = (ml_overrides_params_from_model_folder(model_folder_context)
                 if global_model_overrides_params_list is None
                 else global_model_overrides_params_list)

    handler = PredictionPreprocessingHandler.build(
        core_params, preprocessing_params, model_folder_context,
        assertions=assertions, active_gpu_config=recipe_gpu_config, nan_support=nan_support,
    )
    handler.collector_data = collector_data

    # Restrict visible GPUs before the estimator itself is loaded.
    GpuSupportingCapability.init_cuda_visible_devices(recipe_gpu_config["params"]["gpuList"])

    preprocessing_pipeline = handler.build_preprocessing_pipeline(with_target=for_eval)

    estimator = load_model_from_folder(model_folder_context)
    intervals_model = PredictionIntervalsModel.load_or_none(model_folder_context, core_params)

    scorable = ScorableModel.build(
        estimator,
        core_params["taskType"],
        core_params["prediction_type"],
        modeling_params["algorithm"],
        preprocessing_params,
        overrides,
        intervals_model,
    )
    return model_folder_context, preprocessing_params, scorable, preprocessing_pipeline, modeling_params, handler


def _load_assertions_params(model_folder_context):
    """Return the ``assertions`` list stored in ``rassertions.json`` of the folder, if any.

    :param model_folder_context: folder context exposing ``isfile`` and ``read_json``
    :return: the ``"assertions"`` entry of ``rassertions.json``, or ``None`` when the file is
        absent or has no such key
    """
    filename = "rassertions.json"
    if not model_folder_context.isfile(filename):
        return None
    return model_folder_context.read_json(filename).get("assertions", None)


def load_model_partitions(model_folder_context, core_params, recipe_gpu_config, fmi, for_eval=False):
    """Load either the single model or every partition model found for ``fmi``.

    :param model_folder_context: folder context of the (base) model
    :param core_params: core params of the ML task, forwarded to ``load_model``
    :param recipe_gpu_config: recipe GPU configuration, forwarded to ``load_model``
    :param fmi: model identifier used to enumerate partition folders; also reported for the
        non-partitioned case
    :param for_eval: forwarded to ``load_model``
    :return: ``(partition_dispatch, partitions, partitions_fmis)``. Both dicts are keyed by
        partition name, or by the single key ``"NP"`` when not in partition dispatch mode.
    """
    partition_dispatch = is_partition_dispatch(model_folder_context)

    if not partition_dispatch:
        single_model = load_model(model_folder_context, core_params, recipe_gpu_config, for_eval=for_eval)
        return partition_dispatch, {"NP": single_model}, {"NP": fmi}

    # In partition dispatch mode the folder is the base model: every partition model is loaded
    # with the base model's assertions and overrides instead of its own.
    shared_assertions = _load_assertions_params(model_folder_context)
    shared_overrides = ml_overrides_params_from_model_folder(model_folder_context)

    partitions = {}
    partitions_fmis = {}
    for name, contexts in get_partitions_fmi_folder_contexts(fmi).items():
        partitions[name] = load_model(contexts.model_folder_context, core_params, recipe_gpu_config,
                                      for_eval=for_eval,
                                      global_model_assertions_params_list=shared_assertions,
                                      global_model_overrides_params_list=shared_overrides)
        partitions_fmis[name] = contexts.fmi
    return partition_dispatch, partitions, partitions_fmis


def main(model_folder_in, input_dataset_smartname, managed_folder_smart_id, output_dataset_smartname, recipe_desc, script,
         preparation_output_schema, cond_outputs=None, fmi=None):
    """Score the input dataset with the saved model(s) and write the scored rows to the output dataset.

    The input is processed in batches of ``recipe_desc["pythonBatchSize"]`` rows (default 100000).
    For each batch, predictions are computed per partition (or once for a non-partitioned model).
    Optional override info, row-level explanations and model metadata are added. The selected
    input columns are then concatenated with the prediction columns and written out.

    :param model_folder_in: model folder, resolved with ``build_folder_context``
    :param input_dataset_smartname: smart name of the dataset to score
    :param managed_folder_smart_id: forwarded to ``get_input_parameters``
    :param output_dataset_smartname: smart name of the dataset receiving the scored rows
    :param recipe_desc: scoring recipe settings (threshold, probabilities, explanations, kept columns, ...)
    :param script: forwarded to ``get_input_parameters``
    :param preparation_output_schema: forwarded to ``get_input_parameters``
    :param cond_outputs: conditional outputs, only used for binary classification
    :param fmi: identifier of the (base) model, used to locate partition models and for metadata columns
    :raises ValueError: if the prediction type is not binary, multiclass or regression
    """
    model_folder_in_context = build_folder_context(model_folder_in)
    input_dataset, core_params, feature_preproc, names, dtypes, parse_date_columns = \
        get_input_parameters(model_folder_in_context, input_dataset_smartname, preparation_output_schema, script, managed_folder_smart_id)

    prediction_type = core_params["prediction_type"]
    check_classical_prediction_type(prediction_type)
    batch_size = recipe_desc.get("pythonBatchSize", 100000)
    logger.info("Scoring with batch size: {}".format(batch_size))

    recipe_gpu_config = get_gpu_config_from_recipe_desc(recipe_desc)
    log_nvidia_smi_if_use_gpu(gpu_config=recipe_gpu_config)
    partition_dispatch, partitions, partitions_fmis = load_model_partitions(model_folder_in_context, core_params, recipe_gpu_config, fmi)
    output_dataset = Dataset(output_dataset_smartname)

    def output_generator():
        """Yield one output dataframe (kept input columns + prediction columns) per input batch."""
        logger.info("Start output generator ...")

        # Built lazily for the first batch that needs explanations, then reused for later batches.
        individual_explainer = None
        for input_df, input_df_copy_unnormalized in dataframe_iterator(
            input_dataset, names, dtypes, parse_date_columns, feature_preproc,
            batch_size=batch_size, float_precision="round_trip",
        ):
            part_dfs = []
            prediction_results = []
            for part_df, part_params, partition_id in generate_part_df_and_model_params(input_df, partition_dispatch, core_params,
                                                                                        partitions, raise_if_not_found=False):

                model_folder_context, preprocessing_params, model, pipeline, modeling_params, preprocessing_handler = part_params

                logger.info("Predicting it")
                if prediction_type == doctor_constants.BINARY_CLASSIFICATION:

                    # Computing threshold
                    if recipe_desc["overrideModelSpecifiedThreshold"]:
                        used_threshold = recipe_desc.get("forcedClassifierThreshold")
                    else:
                        used_threshold = model_folder_context.read_json("user_meta.json").get("activeClassifierThreshold")
                    scoring_data = binary_classification_predict(model, pipeline, modeling_params,
                                                                 preprocessing_handler.target_map, used_threshold,
                                                                 part_df,
                                                                 output_probas=recipe_desc["outputProbabilities"],
                                                                 for_all_cuts=False)
                    pred_df = scoring_data.pred_and_proba_df
                    prediction_result = scoring_data.prediction_result

                    # Probability percentile & Conditional outputs
                    pred_df = binary_classif_scoring_add_percentile_and_cond_outputs(pred_df,
                                                                                     recipe_desc,
                                                                                     model_folder_context,
                                                                                     cond_outputs,
                                                                                     preprocessing_handler.target_map)
                elif prediction_type == doctor_constants.MULTICLASS:
                    scoring_data = multiclass_predict(model, pipeline, modeling_params,
                                                      preprocessing_handler.target_map, part_df,
                                                      output_probas=recipe_desc["outputProbabilities"])
                    pred_df = scoring_data.pred_and_proba_df
                    prediction_result = scoring_data.prediction_result
                elif prediction_type == doctor_constants.REGRESSION:
                    scoring_data = regression_predict(model, pipeline, modeling_params, part_df)
                    pred_df = scoring_data.preds_df
                    prediction_result = scoring_data.prediction_result
                else:
                    raise ValueError("bad prediction type %s" % prediction_type)

                if recipe_desc.get("outputModelMetadata", False):
                    if partition_id is not None:
                        add_fmi_metadata(pred_df, partitions_fmis[partition_id])
                    else:
                        add_fmi_metadata(pred_df, fmi)

                part_dfs.append(pred_df)
                prediction_results.append(prediction_result)

            prediction_result = None
            if partition_dispatch:
                if len(part_dfs) > 0:
                    pred_df = pd.concat(part_dfs, axis=0)
                    prediction_result = AbstractPredictionResult.concat(prediction_results)
                else:
                    logger.warning("All partitions found in dataset are unknown to "
                                   "the model, all predictions will be empty for this batch")
                    pred_df = get_empty_pred_df(input_df_copy_unnormalized.columns, output_dataset.read_schema())
            else:
                pred_df = part_dfs[0]
                prediction_result = prediction_results[0]

            logger.info("Done predicting it")

            if isinstance(prediction_result, OverridesResultsMixin):
                pred_df[OVERRIDE_INFO_COL] = prediction_result.compute_and_return_info_column()

            # Row level explanations
            if recipe_desc.get("outputExplanations"):
                if partition_dispatch:
                    # Checked first: when no partition matched yet, `model` was never bound and
                    # must not be touched.
                    logger.warning("Could not compute explanations with partition redispatch")
                    pred_df["explanations"] = np.nan
                elif not model.is_proba_aware() and prediction_type != doctor_constants.REGRESSION:
                    # Non-partitioned: `model` is the one bound by the prediction loop above.
                    logger.warning("Could not compute explanations with a non-probabilistic model")
                    pred_df["explanations"] = np.nan
                else:
                    individual_explanation_params = recipe_desc.get("individualExplanationParams", {})
                    method = individual_explanation_params.get("method")
                    if individual_explainer is None:
                        split_folder_context = model_folder_in_context.get_subfolder_context("split")
                        split_desc = split_folder_context.read_json("split.json")
                        model_info_handler = PredictionModelInformationHandler(
                            split_desc, core_params, model_folder_in_context, model_folder_in_context, split_folder_context, fmi=fmi
                        )
                        individual_explainer = model_info_handler.get_explainer()
                        # Only the batch seen at initialisation time may be used as background data.
                        individual_explainer.make_ready(
                            input_df if individual_explanation_params.get("drawInScoredSet", False) else None)
                    logger.info("Starting row level explanations for this batch using {} method".format(
                        method
                    ))
                    nb_explanation = individual_explanation_params.get("nbExplanations", DEFAULT_NB_EXPLANATIONS)
                    shapley_background_size = individual_explanation_params.get("shapleyBackgroundSize",
                                                                                DEFAULT_SHAPLEY_BACKGROUND_SIZE)
                    sub_chunk_size = individual_explanation_params.get("subChunkSize", DEFAULT_SUB_CHUNK_SIZE)
                    explanations = individual_explainer.explain(
                        input_df, nb_explanation, method=method, sub_chunk_size=sub_chunk_size,
                        shapley_background_size=shapley_background_size
                    )
                    pred_df["explanations"] = individual_explainer.format_explanations(explanations, nb_explanation,
                                                                                       with_json=True)
                    logger.info("Done row level explanations for this batch")

            # Input columns whose name clashes with a prediction column are dropped.
            if recipe_desc.get("filterInputColumns", False):
                clean_kept_columns = [c for c in recipe_desc["keptInputColumns"] if c not in pred_df.columns]
            else:
                clean_kept_columns = [c for c in input_df_copy_unnormalized.columns if c not in pred_df.columns]

            if recipe_desc.get("outputModelMetadata", False):
                # add the "prediction time" and reorder the output model metadata Columns
                add_prediction_time_metadata(pred_df)
                ordered_columns = [col for col in pred_df.columns if col not in smmd_colnames] + smmd_colnames
                pred_df = pred_df[ordered_columns]

            yield pd.concat([input_df_copy_unnormalized[clean_kept_columns], pred_df], axis=1)

    logger.info("Starting writer")
    with output_dataset.get_writer() as writer:
        logger.info("Starting to iterate")
        for output_df in output_generator():
            logger.info("Generator generated a df %s", output_df.shape)
            writer.write_dataframe(output_df)
            logger.info("Output df written")


if __name__ == "__main__":
    # Process-level setup: debug signal handler, INFO logging, and environment from the backend.
    debugging.install_handler()
    logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
    read_dku_env_and_set()

    # argv[5..8] are paths to JSON files; all other arguments are passed through as plain strings.
    with ErrorMonitoringWrapper():
        main(
            model_folder_in=sys.argv[1],
            input_dataset_smartname=sys.argv[2],
            managed_folder_smart_id=sys.argv[3],
            output_dataset_smartname=sys.argv[4],
            recipe_desc=dkujson.load_from_filepath(sys.argv[5]),
            script=dkujson.load_from_filepath(sys.argv[6]),
            preparation_output_schema=dkujson.load_from_filepath(sys.argv[7]),
            cond_outputs=dkujson.load_from_filepath(sys.argv[8]),
            fmi=sys.argv[9],
        )
