# coding: utf-8
from __future__ import unicode_literals

from abc import ABCMeta
import json
import logging
import sys

from six import add_metaclass

from dataiku.base.folder_context import build_folder_context
from dataiku.base.socket_block_link import JavaLink, parse_javalink_args
from dataiku.base.utils import watch_stdin
from dataiku.core import debugging, read_proxy_params
from dataiku.core import doctor_constants as constants
from dataiku.core import schema_handling
from dataiku.doctor.interactive_model.server import AbstractInteractiveScorer
from dataiku.doctor.interactive_model.server import InteractiveModelProtocol
from dataiku.doctor.exploration.exploration import CounterfactualsStrategy
from dataiku.doctor.exploration.exploration import EmuWrapper
from dataiku.doctor.exploration.exploration import OutcomeOptimizationStrategy
from dataiku.doctor.exploration.emu.generators import SpecialTarget
from dataiku.doctor.exploration.predictor_adapter_for_emu import PredictorAdapterForEmu
from dataiku.doctor.utils import dataframe_from_dict_with_dtypes
from dataiku.doctor.individual_explainer import DEFAULT_SHAPLEY_BACKGROUND_SIZE
from dataiku.external_ml.mlflow.predictor import MLflowPredictor
from dataiku.external_ml.mlflow.pyfunc_common import get_configured_threshold, load_evaluation_dataset_sample
from dataiku.external_ml.mlflow.pyfunc_read_meta import read_meta, set_signature, set_formats

from dataikuscoring.mlflow import mlflow_classification_predict_to_scoring_data, mlflow_regression_predict_to_scoring_data
from dataiku.external_ml.mlflow.pyfunc_evaluate_model import evaluate_and_save
from dataiku.external_ml.utils import load_external_model_meta

logger = logging.getLogger(__name__)

##############################
# PREDICTOR ADAPTERS FOR EMU #
##############################


@add_metaclass(ABCMeta)
class MLFlowPredictorAdapterForEmu(PredictorAdapterForEmu):
    """
    Base adapter letting EMU query an MLflow model.

    Stores the loaded MLflow model, the imported-model metadata and the threshold given at
    construction; concrete subclasses implement the prediction calls.
    """

    def __init__(self, mlflow_model, mlflow_imported_model, used_threshold):
        """
        :param mlflow_model: the loaded MLflow model
        :param mlflow_imported_model: metadata of the imported model ("mim")
        :param used_threshold: threshold forwarded by the classifier adapter to classification scoring
        """
        self.mlflow_model, self.mlflow_imported_model = mlflow_model, mlflow_imported_model
        # TODO @exploration: Merging release/10.0 into outcome optimization feature branch resulted in this new `used_threshold`
        #                    attribute that doesn't make sense for OO. A PR should be opened to refactor this.
        self.used_threshold = used_threshold

    def preprocess(self, df):
        """Return *df* with every object-dtype column dropped."""
        object_columns = [name for name, dtype in zip(df.columns, df.dtypes) if dtype == "object"]
        return df.drop(object_columns, axis=1)


class MLFlowClassifierAdapterForEmu(MLFlowPredictorAdapterForEmu):
    """EMU adapter exposing the predictions and class probabilities of an MLflow classifier."""

    def predict(self, df):
        """
        Return the predicted classes for *df*, encoded through the model's ``labelToIntMap``
        (i.e. the "processed" targets: classes replaced by 0, 1, ...).
        """
        classification_result = mlflow_classification_predict_to_scoring_data(
            self.mlflow_model, self.mlflow_imported_model, df, self.used_threshold
        )
        predicted_labels = classification_result.preds_df["prediction"]
        label_encoding = self.mlflow_imported_model["labelToIntMap"]
        return predicted_labels.replace(label_encoding).values

    def predict_proba(self, df):
        """Return the class-probability values computed for *df*."""
        classification_result = mlflow_classification_predict_to_scoring_data(
            self.mlflow_model, self.mlflow_imported_model, df, self.used_threshold
        )
        probabilities_df = classification_result.probas_df
        return probabilities_df.values


class MLFlowRegressorAdapterForEmu(MLFlowPredictorAdapterForEmu):
    """EMU adapter exposing the predictions of an MLflow regression model."""

    def predict(self, df):
        """Return the predicted values for *df*."""
        predictions_df = mlflow_regression_predict_to_scoring_data(
            self.mlflow_model, self.mlflow_imported_model, df
        ).preds_df
        return predictions_df["prediction"].values


##############################
#        EMU WRAPPERS        #
##############################


@add_metaclass(ABCMeta)
class MLFlowEmuWrapper(EmuWrapper):
    """
    Adaptation of EmuWrapper for MLflow which does not have "model_handler".

    This also handles the fact that for MLflow, we can only evaluate counterfactuals plausibility
    if all features are numeric, because there is no explicit preprocessing step that gives
    fully-numeric design matrices that are required for plausibility (which makes a SKLearn KMeans of observations)
    """

    def __init__(self, model_folder_context, mlflow_model, mlflow_imported_model, computer):
        """
        :param model_folder_context: folder context of the imported MLflow model; used to read the
            configured threshold and the evaluation dataset sample
        :param mlflow_model: the loaded MLflow model
        :param mlflow_imported_model: metadata of the imported model ("mim")
        :param computer: EMU strategy; ``_init_model`` calls its ``get_model`` (read back as ``self.computer``)
        """
        # NOTE(review): attributes are set before the parent constructor runs, presumably because it
        # may call the _init_* hooks below that read them — TODO confirm against EmuWrapper.__init__
        self.model_folder_context = model_folder_context
        self.used_threshold = get_configured_threshold(self.model_folder_context)
        self.mlflow_model = mlflow_model
        self.mim = mlflow_imported_model
        self.evaluated_dataset = None  # lazily loaded and cached by _init_dataset
        super(MLFlowEmuWrapper, self).__init__(computer)

    def _get_integer_features_map(self):
        """
        Map each declared feature name to whether its storage type is an integer type.

        :return: dict {feature name: bool}, or None when the model metadata declares no features
        """
        features = self.mim.get("features", None)
        if features is None:
            return None
        integer_types = ("tinyint", "smallint", "int", "bigint")
        return {feature["name"]: feature["type"] in integer_types for feature in features}

    def _get_valid_and_preprocessed_data(self, df):
        """
        Split *df* into model inputs and target.

        :return: tuple (valid_X, valid_df, valid_y):
            - valid_X: *df* without the target column, restricted to and ordered like the declared
              features when the metadata declares some
            - valid_df: a copy of *df*
            - valid_y: the target values as an array; for classification the labels are encoded
              through ``labelToIntMap``, for regression the raw target values are returned
        """
        target_col = self._get_target_col()
        valid_df = df.copy()
        # drop() already returns a new DataFrame, no need to copy df beforehand
        valid_X = df.drop(target_col, axis=1)
        features = self.mim.get("features")
        if features is not None:
            valid_X = valid_X[[feature["name"] for feature in features]]

        target = df[target_col].copy()
        if self.mim.get("predictionType") == constants.REGRESSION:
            # labelToIntMap is a class-label encoding: it only applies to classification models (this
            # wrapper is also the base of the outcome-optimization wrapper used for regression models)
            valid_y = target.values
        else:
            valid_y = target.astype(object).replace(self.mim["labelToIntMap"]).values
        return valid_X, valid_df, valid_y

    def _init_dataset(self):
        """Return a copy of the evaluation dataset sample, loading it on first call."""
        if self.evaluated_dataset is None:
            self.evaluated_dataset = load_evaluation_dataset_sample(self.model_folder_context)
        return self.evaluated_dataset.copy()

    def _init_model(self):
        """Build the EMU predictor adapter through the strategy."""
        self.model = self.computer.get_model(self.mlflow_model, self.mim, self.used_threshold)

    def _get_target_col(self):
        """Return the name of the target column from the model metadata."""
        return self.mim["targetColumnName"]


class MLFlowEmuWrapperForCounterfactuals(MLFlowEmuWrapper):
    """EMU wrapper computing counterfactuals for MLflow classification models."""

    def _get_scoring_data_df(self, df):
        """Score *df* and return the probability columns with an extra "prediction" column."""
        classification_result = mlflow_classification_predict_to_scoring_data(
            self.mlflow_model, self.mim, df, self.used_threshold
        )
        scored_df = classification_result.probas_df
        # The column is added in place on the probas_df of this (local, discarded) scoring result
        scored_df["prediction"] = classification_result.preds_df["prediction"]
        return scored_df


class MLFlowEmuWrapperForOutcomeOptimization(MLFlowEmuWrapper):
    """EMU wrapper running outcome optimization on MLflow regression models."""

    def _get_scoring_data_df(self, df):
        """Score *df* with the regression model and return the predictions DataFrame."""
        regression_result = mlflow_regression_predict_to_scoring_data(self.mlflow_model, self.mim, df)
        predictions_df = regression_result.preds_df
        return predictions_df


#############################
#       EMU COMPUTERS       #
#############################


class MLFlowCounterfactualsStrategy(CounterfactualsStrategy):
    """Counterfactuals strategy whose model is an MLflow classifier adapter."""

    def get_model(self, mlflow_model, mim, used_threshold):
        """Wrap the MLflow model into an MLFlowClassifierAdapterForEmu."""
        adapter = MLFlowClassifierAdapterForEmu(
            mlflow_model=mlflow_model,
            mlflow_imported_model=mim,
            used_threshold=used_threshold,
        )
        return adapter


class MLFlowOutcomeOptimizationStrategy(OutcomeOptimizationStrategy):
    """Outcome-optimization strategy whose model is an MLflow regressor adapter."""

    def get_model(self, mlflow_model, mim, used_threshold):
        """Wrap the MLflow model into an MLFlowRegressorAdapterForEmu."""
        adapter = MLFlowRegressorAdapterForEmu(
            mlflow_model=mlflow_model,
            mlflow_imported_model=mim,
            used_threshold=used_threshold,
        )
        return adapter

############################
#          SERVER          #
############################


class MLFlowInteractiveModelProtocol(InteractiveModelProtocol):
    """
    Command loop of the MLflow interactive model server.

    Interactive commands (SCORING, EXPLANATIONS, COUNTERFACTUALS, OUTCOME_OPTIMIZATION) share a single
    MLFlowInteractiveScorer built on first use, and the loop keeps running after them. Non-interactive
    commands (READ_META, EVALUATE, SET_SIGNATURE_AND_FORMATS) are run once, after which the loop stops.
    """

    def start(self):
        """
        Read and handle commands from ``self.link`` until reading fails or a non-interactive command
        has been processed. Exceptions raised while handling a command are passed to
        ``_handle_command_exception`` and the loop goes on.
        """
        # These commands are not interactive and they write files to the model folder.
        # Breaking out of the loop allows us to send files when we are in a containerized environment.
        # See also exec_mlflow_interactive_model_server.py
        one_shot_handlers = {
            "READ_META": MLFlowInteractiveModelProtocol._read_meta,
            "EVALUATE": MLFlowInteractiveModelProtocol._evaluate,
            "SET_SIGNATURE_AND_FORMATS": MLFlowInteractiveModelProtocol._set_signature_and_formats,
        }
        interactive_commands = ("SCORING", "COUNTERFACTUALS", "OUTCOME_OPTIMIZATION", "EXPLANATIONS")
        scorer = None

        while True:
            try:
                try:
                    command = self.link.read_json()
                except (EOFError, IOError, OSError) as e:
                    logger.info("Error reading command from server. Communication broken ? {}".format(e))
                    break

                params = json.loads(command["params"])
                params["model_folder_context"] = build_folder_context(params["model_folder"])
                read_proxy_params(params)
                command_type = command["type"]

                if command_type in one_shot_handlers:
                    self._attempt_non_interactive_command(one_shot_handlers[command_type], params)
                    break

                if command_type not in interactive_commands:
                    logger.info("Interactive Scoring - Command %s not recognized" % command_type)
                    continue

                # The scorer is built lazily, from the params of the first interactive command
                if scorer is None:
                    scorer = MLFlowInteractiveScorer(params)

                if command_type == "SCORING":
                    self._handle_compute_score(scorer, params["records"])
                elif command_type == "EXPLANATIONS":
                    self._handle_compute_explanation(scorer, params["computation_params"], params["records"])
                elif command_type == "COUNTERFACTUALS":
                    self._handle_compute_counterfactuals(scorer, params["computation_params"], params["records"])
                else:  # "OUTCOME_OPTIMIZATION"
                    self._handle_compute_outcome_optimization(scorer, params["computation_params"],
                                                              params["records"])
            except Exception as e:
                self._handle_command_exception(e)

    @staticmethod
    def _read_meta(params):
        """Call ``read_meta`` on the model folder context carried by *params*."""
        folder_context = params["model_folder_context"]
        read_meta(folder_context)

    @staticmethod
    def _set_signature_and_formats(params):
        """
        Call ``set_formats`` then ``set_signature`` for the model folder in *params*.

        The guessing dataset is ``signatureAndFormatsGuessingDataset``, falling back to ``featuresDataset``;
        the sampling defaults to ``{"samplingMethod": "HEAD_SEQUENTIAL", "maxRecords": 500}``.
        """
        logger.info("MLflow set formats and signature with params %s" % params)
        folder_context = params["model_folder_context"]
        guessing_dataset = params.get("signatureAndFormatsGuessingDataset", params.get("featuresDataset"))
        sampling = params.get("guessingDatasetSamplingParam", {"samplingMethod": "HEAD_SEQUENTIAL", "maxRecords": 500})
        target = params.get("target")

        set_formats(folder_context, guessing_dataset, sampling, target,
                    params.get("inputFormat"), params.get("outputFormat"))
        set_signature(folder_context, guessing_dataset, sampling, target, params.get("features"))
        logger.info("Finished setting formats. Sending back results.")

    @staticmethod
    def _evaluate(params):
        """Call ``evaluate_and_save`` with the evaluation arguments taken from *params*."""
        logger.info("MLflow Evaluate with params %s" % params)
        argument_keys = ("model_folder_context", "dataset_ref", "schema", "selection", "skip_expensive_reports")
        evaluate_and_save(*[params[key] for key in argument_keys])

    def _attempt_non_interactive_command(self, function, params):
        """
        Run ``function(params)`` and send an empty result back; any exception (including one raised
        while sending the result) is passed to ``_handle_command_exception``.
        """
        try:
            function(params)
            self._send_results(None)
        except Exception as error:
            self._handle_command_exception(error)


class MLFlowInteractiveScorer(AbstractInteractiveScorer):
    """
    Interactive scorer for an imported MLflow model: scoring, explanations, counterfactuals and
    outcome optimization on records sent through the interactive model protocol.
    """
    # TODO: Factorize this class with dataiku.doctor.interactive_model.server.InteractiveScorer

    def __init__(self, params):
        """
        Load the model metadata and the MLflow model from ``params["model_folder_context"]``, then build
        the EMU wrapper (outcome optimization for regression models, counterfactuals otherwise) and the
        predictor used for explanations.
        """
        import mlflow

        logger.info("Loading MLflow interactive scorer")
        folder_context = params["model_folder_context"]
        self.model_folder_context = folder_context
        self.mim = load_external_model_meta(folder_context)

        with folder_context.get_folder_path_to_read() as folder_path:
            self.mlflow_model = mlflow.pyfunc.load_model(folder_path)
        logger.info("Done loading MLflow interactive scorer")

        if self.mim.get("predictionType") != constants.REGRESSION:
            strategy = MLFlowCounterfactualsStrategy(self.mim["labelToIntMap"])
            wrapper_class = MLFlowEmuWrapperForCounterfactuals
        else:
            strategy = MLFlowOutcomeOptimizationStrategy()
            wrapper_class = MLFlowEmuWrapperForOutcomeOptimization
        self.emu_wrapper = wrapper_class(self.model_folder_context, self.mlflow_model, self.mim, strategy)

        # Used by explain()
        self.predictor = MLflowPredictor(self.model_folder_context)

    def score(self, records):
        """
        Score *records* with the MLflow model.

        :return: the combined predictions/probabilities DataFrame when the scoring data has one,
            otherwise the predictions DataFrame
        :raises ValueError: if the prediction type is neither binary, multiclass nor regression
        """
        logger.info("MLflow interactive scorer, scoring: %s" % records)
        df = self._get_dataframe(records)
        logger.info("Scoring with df=%s" % df)

        # Threshold is read from the model folder on every call
        used_threshold = get_configured_threshold(self.model_folder_context)

        prediction_type = self.mim.get("predictionType")
        is_classification = prediction_type in [constants.BINARY_CLASSIFICATION, constants.MULTICLASS]
        if not is_classification and constants.REGRESSION != prediction_type:
            raise ValueError("bad prediction type %s" % prediction_type)

        if is_classification:
            scoring_data = mlflow_classification_predict_to_scoring_data(self.mlflow_model, self.mim, df, used_threshold)
        else:
            scoring_data = mlflow_regression_predict_to_scoring_data(self.mlflow_model, self.mim, df)

        logger.info("Got scoring_data=%s" % scoring_data)

        combined_df = scoring_data.pred_and_proba_df
        if combined_df is None:
            return scoring_data.preds_df
        return combined_df

    def explain(self, computation_params, records):
        """
        Return a ``(scores, explanations_df)`` pair for *records*; the "explanations_" prefix is stripped
        from the explanation column names.
        """
        logger.info("MLflow interactive scorer, explain: %s" % records)
        df = self._get_dataframe(records)
        logger.info("Explaining with df=%s" % df)
        # NOTE(review): mc_steps defaults to DEFAULT_SHAPLEY_BACKGROUND_SIZE — confirm this is the intended default
        raw_explanations = self.predictor._compute_explanations(
            df,
            method=computation_params["explanationMethod"],
            n_explanations=computation_params["nExplanations"],
            mc_steps=computation_params.get("mcSteps", DEFAULT_SHAPLEY_BACKGROUND_SIZE)
        )
        explanations_df = raw_explanations.rename(lambda column: column.replace("explanations_", ""), axis=1)
        return self.score(records), explanations_df

    def compute_counterfactuals(self, computation_params, records):
        """Compute counterfactuals for a single reference record; returns a one-element list."""
        return self._run_emu_on_single_record(computation_params, records, None,
                                              "Cannot compute counterfactuals with multiple references")

    def optimize_outcome(self, computation_params, records):
        """Run outcome optimization for a single reference record; returns a one-element list."""
        return self._run_emu_on_single_record(computation_params, records, SpecialTarget.MIN,
                                              "Cannot optimize outcome with multiple references")

    def _run_emu_on_single_record(self, computation_params, records, default_target, multiple_references_message):
        # Shared flow of compute_counterfactuals / optimize_outcome: EMU runs on exactly one reference record
        df = self._get_dataframe(records)
        if df.shape[0] != 1:
            raise ValueError(multiple_references_message)
        self.emu_wrapper.set_constraints(computation_params["featureDomains"], df)
        target = computation_params.get("target", default_target)
        self.emu_wrapper.update_generator(target, compute_plausibility="ONLY_IF_FULLY_NUMERIC")
        # Must return a list, so wrapping results on a single record into a list
        return [self.emu_wrapper.compute(df)]

    def _get_dataframe(self, records):
        """
        Build the scoring DataFrame from a list of record dicts.

        Records are laid out column-wise as {"feature1": [...], "feature2": [...]} like in the API Node
        python server; a feature absent from a record becomes None. When the model metadata declares
        features, their storage types are mapped to pandas dtypes and the columns are restricted to and
        ordered like the declared features.
        """
        logger.info("records=%s" % records)
        # Feature names in order of first appearance across records (preserves order)
        all_features = []
        seen_features = set()
        for rec in records:
            for key, value in rec.items():
                logger.info("k=%s v=%s" % (key, value))
                if key not in seen_features:
                    seen_features.add(key)
                    all_features.append(key)

        logger.info("all_features=%s" % all_features)
        records_as_dict = {}
        for feature in all_features:
            records_as_dict[feature] = [record.get(feature) for record in records]
        logger.info("recordS_as_dict:%s" % records_as_dict)

        logger.info("MIM is %s" % self.mim)
        declared_features = self.mim.get("features")
        dtypes = {}
        if declared_features is not None:
            dtypes = {
                feature["name"]: schema_handling.DKU_PANDAS_TYPES_MAP.get(feature["type"], None)
                for feature in declared_features
            }
        logger.info("Computed dtypes: %s" % dtypes)

        df = dataframe_from_dict_with_dtypes(records_as_dict, dtypes)
        if declared_features is not None:
            df = df[[feature["name"] for feature in declared_features]]
        return df


def serve(port, secret, server_cert=None):
    """
    Connect a JavaLink to the backend on *port* and run the MLflow interactive command loop
    until it returns.
    """
    watch_stdin()
    java_link = JavaLink(port, secret, server_cert=server_cert)
    java_link.connect()
    MLFlowInteractiveModelProtocol(java_link).start()


if __name__ == "__main__":
    # Log at INFO level, tagging each line with timestamp, process/thread, level and logger name
    _log_format = '[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s'
    logging.basicConfig(format=_log_format, level=logging.INFO)
    debugging.install_handler()

    javalink_port, javalink_secret, javalink_cert = parse_javalink_args()
    serve(javalink_port, javalink_secret, server_cert=javalink_cert)
