import logging

from dataiku.core import doctor_constants
from dataiku.core.doctor_constants import IMAGE, PREPROC_ONEVALUE
from dataiku.doctor.posttraining.model_information_handler import ModelInformationHandlerBase
from dataiku.doctor.prediction.classification_scoring import BinaryClassificationModelScorer
from dataiku.doctor.prediction.decisions_and_cuts import DecisionsAndCuts
from dataikuscoring.utils.prediction_result import ClassificationPredictionResult
from dataikuscoring.utils.prediction_result import PredictionResult
from dataiku.doctor.prediction.regression_scoring import RegressionModelScorer
from dataiku.doctor.preprocessing_handler import PredictionPreprocessingHandler
from dataiku.doctor.utils.split import load_df_with_normalization
from dataikuscoring.utils.scoring_data import ScoringData

logger = logging.getLogger(__name__)


class ModelLessModelInformationHandler(ModelInformationHandlerBase):
    """Model information handler for a model evaluation that has no model attached.

    Everything is derived from evaluation artefacts: the evaluation params
    (``model_evaluation``), the resolved preprocessing params and the files stored in
    the model evaluation folder (sample, schema, collector data). Targets, predictions
    and probabilities are extracted from the evaluated data through the preprocessing
    pipeline instead of being computed by a model, hence ``get_clf`` and ``get_model``
    raise ``NotImplementedError``.
    """

    def __init__(self, model_evaluation, features, iperf, resolved_preprocessing_params, modelevaluation_folder_context, model_folder_context):
        """
        :param model_evaluation: evaluation params (dict). Keys read by this class include
            "predictionType", "targetVariable", "predictionVariable", "probaColumns",
            "weightsVariable", "metricParams", "activeClassifierThreshold" and "managedFolderSmartId".
        :param features: per-feature information, indexed by column name (each entry is
            expected to carry "type" and "role")
        :param iperf: dict; only its "probaAware" flag is read here
        :param resolved_preprocessing_params: resolved preprocessing params (dict)
        :param modelevaluation_folder_context: folder context of the model evaluation
        :param model_folder_context: folder context of the evaluated model, may be falsy
        :raises ValueError: propagated from ``get_managed_folder_smart_id`` when image
            features are present but no managed folder is configured.
        """
        self._model_evaluation = model_evaluation
        self._features = features
        self._iperf = iperf
        self._preprocessing_params = resolved_preprocessing_params
        self._modelevaluation_folder_context = modelevaluation_folder_context
        self._model_folder_context = model_folder_context
        self._full_df = None
        # Both helpers below read the attributes assigned above, so keep this ordering.
        self._modeling_params = self._prepare_modeling_params()
        self._core_params = self._prepare_core_params()
        self._schema = None  # lazily loaded by get_schema()
        self._postcompute_folder_context = self._modelevaluation_folder_context.get_subfolder_context("postcomputation")

    def _prepare_modeling_params(self):
        """Build the minimal modeling params expected by the scorers.

        When the evaluation is proba-aware, threshold optimization is enabled and the
        evaluation's active threshold (0.5 if absent) is forced; otherwise 0.5 is forced.
        """
        proba_aware = self._iperf.get('probaAware', False)
        return {
            'algorithm': 'EVALUATED',
            'metrics': self._model_evaluation.get('metricParams', {}),
            'autoOptimizeThreshold': True if proba_aware else False,
            'forcedClassifierThreshold': (self._model_evaluation.get('activeClassifierThreshold', 0.5)
                                          if proba_aware else 0.5),
        }

    def _prepare_core_params(self):
        """Build the core params dict passed to ``PredictionPreprocessingHandler.build``."""
        sample_weight_variable = self.get_sample_weight_variable()
        return {
            "weight": {
                "sampleWeightVariable": sample_weight_variable,
                # NOTE(review): key previously had a trailing space ("sampleWeightMethod "), making it
                # unreachable. Confirm against the consumer whether the expected key is
                # "sampleWeightMethod" or "weightMethod".
                "sampleWeightMethod": self.get_weight_method() if sample_weight_variable else None
            },
            doctor_constants.PREDICTION_TYPE: self.get_prediction_type(),
            doctor_constants.TARGET_VARIABLE: self.get_target_variable(),
            doctor_constants.PREDICTION_VARIABLE: self.get_prediction_variable(),
            doctor_constants.PROBA_COLUMNS: self.get_probas_col_names(),
            "managedFolderSmartId": self.get_managed_folder_smart_id()
        }

    def get_clf(self):
        """Not supported: there is no model behind a model-less evaluation."""
        raise NotImplementedError("ModelLessModelInformationHandler doesn't have a model")

    def get_algorithm(self):
        """Return the pseudo-algorithm name ("EVALUATED")."""
        return self._modeling_params.get("algorithm", None)

    def get_model(self):
        """Not supported: there is no model behind a model-less evaluation."""
        raise NotImplementedError("ModelLessModelInformationHandler doesn't have a model")

    def get_weight_method(self):
        """Return the weighting method used when a sample weight variable is set."""
        return "SAMPLE_WEIGHT"

    def get_sample_weight_variable(self):
        """Return the sample weight column name, or None when the evaluation is unweighted."""
        weights_variable = self._model_evaluation.get("weightsVariable", None)
        # None is the "no weight" flag everywhere else, so normalize the empty string to it.
        return None if weights_variable == '' else weights_variable

    def with_sample_weights(self):
        """Return the sample weight column name (truthy) or None (falsy) — used as a boolean flag."""
        return self.get_sample_weight_variable()

    def get_output_folder_context(self):
        """Return the "postcomputation" subfolder context of the model evaluation folder."""
        return self._postcompute_folder_context

    def get_prediction_type(self):
        """Return the evaluation's prediction type, or None if not set."""
        return self._model_evaluation.get("predictionType")

    def get_model_type(self):
        return "PREDICTION"

    def get_prediction_variable(self):
        """Return the name of the column holding the predictions, or None if not set."""
        return self._model_evaluation.get("predictionVariable")

    def get_target_variable(self):
        """Return the target column name (KeyError if the evaluation does not define one)."""
        return self._model_evaluation["targetVariable"]

    def get_features(self):
        return self._features

    def use_full_df(self):
        return True

    def get_feature_col(self, col_name):
        """Return the feature info for ``col_name``.

        :raises ValueError: if ``col_name`` is not a known feature.
        """
        if col_name not in self._features:
            raise ValueError("Column '{}' not found".format(col_name))
        return self._features[col_name]

    def get_per_feature_col(self, col_name):
        """Not implemented for model-less evaluations; always returns None."""
        return None

    def get_type_of_column(self, col_name):
        return self.get_feature_col(col_name)["type"]

    def get_role_of_column(self, col_name):
        return self.get_feature_col(col_name)["role"]

    def get_probas_col_names(self):
        """Return the names of the probability columns declared by the evaluation (possibly empty)."""
        return [proba["column"] for proba in self._model_evaluation.get("probaColumns", [])]

    def get_managed_folder_smart_id(self):
        """
        :return: The smart ID of the managed folder.
        :rtype: str
        :raises ValueError: If the model has image features but no image folder is provided.
        """
        has_image_features = any(feature.get("type") == IMAGE
                                 for feature in self._preprocessing_params[doctor_constants.PER_FEATURE].values())
        managed_folder_smart_id = self._model_evaluation.get("managedFolderSmartId")

        if has_image_features and not managed_folder_smart_id:
            raise ValueError(
                "Configuration Error: Older evaluations do not support image features for interactive computation. "
                "Please create a new evaluation."
            )
        return managed_folder_smart_id

    def get_target_mapping(self):
        """Return ``{source value: int mapped value}`` from the preprocessing "target_remapping"."""
        return {
            c['sourceValue']: int(c['mappedValue'])
            for c in self._preprocessing_params.get('target_remapping', [])
        }

    def get_classes_mapping(self):
        """Return the int mapped values from "target_remapping", in declaration order."""
        return [int(c['mappedValue']) for c in self._preprocessing_params.get('target_remapping', [])]

    def get_collector_data(self):
        """Read "collector_data.json" from the model evaluation folder.

        :raises Exception: if the file does not exist.
        """
        collector_data_filename = "collector_data.json"
        if not self._modelevaluation_folder_context.isfile(collector_data_filename):
            raise Exception("Collector data not found %s" % collector_data_filename)
        return self._modelevaluation_folder_context.read_json(collector_data_filename)

    def _build_preprocessing_handler(self):
        """Build the preprocessing handler, with collector data attached, used to extract targets/predictions."""
        preprocessing_handler = PredictionPreprocessingHandler.build(self._core_params, self._preprocessing_params,
                                                                     self._modelevaluation_folder_context)
        collector_data = self.get_collector_data()
        preprocessing_handler.collector_data = collector_data

        # Temporary fix for sc-101984, will be fixed with sc-101989: when the pipeline contains a
        # custom preprocessing step and a model folder is available, rebuild the handler against the
        # model folder context instead (presumably where the custom step's resources live — TODO confirm).
        step_names = {step.__class__.__name__ for step in preprocessing_handler.preprocessing_steps()}
        if self._model_folder_context and 'CustomPreprocessingStep' in step_names:
            preprocessing_handler = PredictionPreprocessingHandler.build(self._core_params, self._preprocessing_params,
                                                                         self._model_folder_context)
            preprocessing_handler.collector_data = collector_data
        return preprocessing_handler

    def prepare_for_scoring_full(self, df):
        """Run the preprocessing pipeline on ``df`` and package targets/predictions for scoring.

        :param df: evaluated data, containing targets, predictions and (optionally) probabilities
        :return: a ``ScoringData``; empty with reason ``PREPROC_DROPPED`` when preprocessing
            leaves no rows.
        """
        preprocessing_handler = self._build_preprocessing_handler()
        pipeline = preprocessing_handler.build_preprocessing_pipeline(with_target=True,
                                                                      allow_empty_mf=True,
                                                                      with_prediction=True)
        transformed = pipeline.process(df)
        targets = transformed["target"]
        if targets.index.empty:
            return ScoringData(is_empty=True, reason=doctor_constants.PREPROC_DROPPED)

        preds = transformed["prediction"]
        try:
            probas = transformed[doctor_constants.PROBA_COLUMNS].values
        except KeyError:
            # The pipeline output has no probability columns.
            probas = None

        prediction_type = self.get_prediction_type()
        decisions_and_cuts = None
        prediction_result = None
        if prediction_type == doctor_constants.REGRESSION:
            prediction_result = PredictionResult(preds)
        elif prediction_type in {doctor_constants.BINARY_CLASSIFICATION, doctor_constants.MULTICLASS}:
            if prediction_type == doctor_constants.BINARY_CLASSIFICATION:
                decisions_and_cuts = DecisionsAndCuts.from_probas_or_unmapped_preds(probas, preds,
                                                                                    self.get_target_mapping())
            prediction_result = ClassificationPredictionResult(self.get_target_mapping(), probas=probas,
                                                               unmapped_preds=preds)

        valid_sample_weights = transformed["weight"] if self.get_sample_weight_variable() else None
        return ScoringData(prediction_result=prediction_result, valid_y=targets,
                           valid_sample_weights=valid_sample_weights,
                           decisions_and_cuts=decisions_and_cuts,
                           valid_unprocessed=transformed['UNPROCESSED'])

    def get_full_df(self):
        """Load the evaluated sample, preferring "sample_scored.csv.gz" over "sample.csv.gz".

        :return: tuple ``(df, True)``. NOTE(review): the second element is always True here —
            presumably a flag required by the base-class contract; confirm its meaning.
        """
        file_name = ("sample_scored.csv.gz" if self._modelevaluation_folder_context.isfile("sample_scored.csv.gz")
                     else "sample.csv.gz")
        return load_df_with_normalization(file_name, self._modelevaluation_folder_context,
                                          self.get_schema(), self._features, self.get_prediction_type()), True

    def get_schema(self):
        """Return the sample schema (cached), preferring "sample_scored_schema.json" over "sample_schema.json"."""
        if self._schema is None:
            schema_file = ("sample_scored_schema.json"
                           if self._modelevaluation_folder_context.isfile("sample_scored_schema.json")
                           else "sample_schema.json")
            self._schema = self._modelevaluation_folder_context.read_json(schema_file)
        return self._schema

    def _prediction_type_label(self):
        """Lower-cased prediction type for error messages; tolerates a missing prediction type."""
        return (self.get_prediction_type() or "unknown").lower()

    def run_binary_scoring(self, df):
        """Compute binary classification performance on ``df``.

        :return: ``(True, None, perf)`` on success, ``(False, reason, None)`` when scoring is not possible.
        :raises ValueError: if the evaluation is not a binary classification.
        """
        if self.get_prediction_type() != doctor_constants.BINARY_CLASSIFICATION:
            raise ValueError("Cannot compute binary scoring on '{}' model".format(self._prediction_type_label()))

        # warning: scoring_data.prediction_result.preds are computed with the default threshold, but are
        # not used anyway
        scoring_data = self.prepare_for_scoring_full(df)
        if scoring_data.is_empty:
            return False, scoring_data.reason, None

        target_mapping = self.get_target_mapping()
        logger.info("target_mapping: %s", target_mapping)

        binary_classif_scorer = BinaryClassificationModelScorer(
            self._modeling_params,
            None,
            scoring_data.decisions_and_cuts,
            scoring_data.valid_y,
            target_mapping,
            test_unprocessed=scoring_data.valid_unprocessed,
            test_X=None,  # Not dumping on disk predicted_df
            test_df_index=None,  # Not dumping on disk predicted_df
            test_sample_weight=scoring_data.valid_sample_weights)

        can_score, reason = binary_classif_scorer.can_score()
        if not can_score:
            return False, reason, None

        return True, None, binary_classif_scorer.score()

    def run_regression_scoring(self, df):
        """Compute regression performance on ``df``.

        :return: ``(True, None, perf)`` on success, ``(False, reason, None)`` when preprocessing
            leaves no rows or a single row.
        :raises ValueError: if the evaluation is not a regression.
        """
        if self.get_prediction_type() != doctor_constants.REGRESSION:
            raise ValueError(
                "Cannot compute regression scoring on '{}' model".format(self._prediction_type_label()))

        scoring_data = self.prepare_for_scoring_full(df)
        if scoring_data.is_empty:
            return False, scoring_data.reason, None

        # Check the rows that survived preprocessing, not the raw input: preprocessing may drop
        # rows and leave a single value even when df has several.
        if len(scoring_data.valid_y) == 1:
            return False, PREPROC_ONEVALUE, None

        regression_scorer = RegressionModelScorer(self._modeling_params,
                                                  scoring_data.prediction_result,
                                                  scoring_data.valid_y,
                                                  None,
                                                  test_unprocessed=scoring_data.valid_unprocessed,
                                                  test_X=None,  # Not dumping on disk predicted_df
                                                  test_df_index=None,  # Not dumping on disk predicted_df
                                                  test_sample_weight=scoring_data.valid_sample_weights)
        return True, None, regression_scorer.score()

    def get_preprocessing_params(self):
        return self._preprocessing_params

    def get_model_folder_context(self):
        return self._model_folder_context
