import logging

from dataiku.core.percentage_progress import PercentageProgress
from dataiku.doctor.exception import EmptyDatasetException
from dataiku.modelevaluation.drift.drift_model import DriftModel
from dataiku.modelevaluation.drift.drift_preparator import DriftPreparator
from dataiku.modelevaluation.drift.drift_univariate import DriftUnivariate
from dataiku.modelevaluation.drift.prediction_drift import PredictionDrift, create_prediction_series_from_statistics
from dataiku.modelevaluation.drift.surrogate_model import SurrogateModel

logger = logging.getLogger(__name__)


class DataDriftComputer(object):
    """Compare a reference model-like evaluation with a current one and compute drift.

    Produces univariate (per-column) drift, a DriftModel result and, when requested
    and both sides expose prediction statistics, a prediction drift result.
    """

    def __init__(self, me1, me2, data_drift_params, job_id, compute_prediction_drift, reference_threshold, current_threshold):
        """
        :param me1: reference side of the comparison
        :param me2: current side of the comparison
        :param data_drift_params: settings read here: nb_bins, compute_histograms, confidence_level
                                  (also forwarded to DriftPreparator)
        :param job_id: identifier used to report progress through PercentageProgress
        :param compute_prediction_drift: when truthy, prediction drift is computed if both sides
                                         have prediction statistics
        :param reference_threshold: threshold forwarded to create_prediction_series_from_statistics
                                    for the reference side and echoed in the result
                                    (presumably a classification decision threshold — TODO confirm)
        :param current_threshold: same as reference_threshold, for the current side
        """
        self.ref_me = me1
        """:type : dataiku.modelevaluation.server.ModelLikeInfo"""
        self.cur_me = me2
        """:type : dataiku.modelevaluation.server.ModelLikeInfo"""
        self.data_drift_params = data_drift_params
        """:type : dataiku.modelevaluation.data_types.DataDriftParams"""
        self.progress = PercentageProgress(job_id)
        """:type : dataiku.core.percentage_progress.PercentageProgress"""
        self.compute_prediction_drift = compute_prediction_drift
        """:type : boolean"""
        self.reference_threshold = reference_threshold
        """:type : float"""
        self.current_threshold = current_threshold
        """:type : float"""

    def compute(self):
        """Run the whole drift computation.

        Progress is reported through ``self.progress``: 10% once column importance is known,
        20% once the data is prepared (DriftUnivariate then reports through the same object),
        100% at the end.

        :return: dict with keys ``univariateDriftResult``, ``driftModelResult`` (None when the
                 prepared data has no column), ``perColumnSettings``, ``referenceSampleSize``,
                 ``currentSampleSize``, ``referenceThreshold``, ``currentThreshold`` and
                 ``predictionDriftResult`` (None when not computed)
        :raises EmptyDatasetException: if the prepared reference or current data has no row
        """
        column_importance = self._get_or_compute_column_importance()
        self.progress.set_percentage(10)

        # Preparation ensures the compared dataframes have exactly the same schema
        preparator = DriftPreparator(self.ref_me, self.cur_me, self.data_drift_params,
                                     can_text_drift=False, should_text_drift=False)
        ref_df_prepared, cur_df_prepared, _, _, per_column_settings, _ = preparator.prepare()
        self.progress.set_percentage(20)

        # Fail fast, before spending time on drift computations whose results would be discarded.
        # Check the row axis explicitly: DataFrame.empty is also True for a frame with rows but
        # no column, which is handled below by skipping the drift model instead of failing.
        if len(ref_df_prepared.index) == 0:
            raise EmptyDatasetException("Reference can not be empty")
        if len(cur_df_prepared.index) == 0:
            raise EmptyDatasetException("Current value can not be empty")

        univariate_drift = DriftUnivariate(
            ref_df_prepared, cur_df_prepared,
            self.data_drift_params.nb_bins,
            self.data_drift_params.compute_histograms,
            self.progress,
            handle_drift_failure_as_error=False
        ).compute_drift()

        prediction_drift = self._compute_prediction_drift()

        if len(ref_df_prepared.columns) > 0:
            drift_model = DriftModel(ref_df_prepared, cur_df_prepared, column_importance,
                                     self.data_drift_params.confidence_level,
                                     handle_drift_failure_as_error=False).compute_drift()
        else:
            logger.error("Cannot train a drift model (no data or no column importance)")
            drift_model = None
        self.progress.set_percentage(100)

        return {
            "univariateDriftResult": univariate_drift,
            "driftModelResult": drift_model,
            "perColumnSettings": per_column_settings,
            "referenceSampleSize": len(self.ref_me.sample_df),
            "currentSampleSize": len(self.cur_me.sample_df),
            "referenceThreshold": self.reference_threshold,
            "currentThreshold": self.current_threshold,
            "predictionDriftResult": prediction_drift
        }

    def _compute_prediction_drift(self):
        """Return the prediction drift result, or None when disabled or statistics are missing."""
        if not (self.compute_prediction_drift
                and self.ref_me.prediction_statistics
                and self.cur_me.prediction_statistics):
            return None
        return PredictionDrift(
            create_prediction_series_from_statistics(self.ref_me.prediction_statistics, self.ref_me.prediction_type, self.reference_threshold),
            create_prediction_series_from_statistics(self.cur_me.prediction_statistics, self.cur_me.prediction_type, self.current_threshold),
            self._get_prediction_column_name()
        ).compute_drift()

    def _get_prediction_column_name(self):
        """Return the first column of the current predictions not prefixed by 'proba_', or None.

        None is returned when there is no prediction dataframe or when every column is a
        'proba_' column (the previous ``[...][0]`` raised IndexError in that case).
        """
        prediction_df = self.cur_me.prediction_df
        if prediction_df is None:
            return None
        return next((col for col in prediction_df if not col.startswith('proba_')), None)

    def _get_or_compute_column_importance(self):
        """Return a {column: importance} mapping for the reference side, or None.

        Uses the pre-computed column importance when present; otherwise, when predictions are
        available, estimates it with a SurrogateModel (any failure there is logged and yields None).
        """
        ref_me = self.ref_me

        if ref_me.column_importance is not None:
            # Use already known column importance
            logger.info("Re-using pre-computed column importance for %s", ref_me.ref)
            return dict(zip(
                ref_me.column_importance["columns"],
                ref_me.column_importance["importances"]
            ))

        if ref_me.prediction_df is None:
            # Column importances and predictions are not available
            # There is no way to compute column importance
            logger.error("Cannot obtain column importance for %s", ref_me.ref)
            return None

        # Fallback on a surrogate model to approximate column importance when it was not
        # pre-computed but predictions are available
        logger.info("Estimating column importance with a surrogate model for %s", ref_me.ref)
        try:
            surrogate_model = SurrogateModel(ref_me.sample_df, ref_me.prediction_df,
                                             ref_me.prediction_type, ref_me.preprocessing_params)
            column_importance = surrogate_model.compute_column_importance()
        except Exception as e:
            # Best-effort: drift can still be computed without column importance,
            # but keep the traceback so the failure can be diagnosed
            logger.warning("Error computing column importance: %s", e, exc_info=True)
            return None
        return column_importance
