import pandas as pd
import numpy as np

from dataiku.modelevaluation.drift.drift_univariate import DriftUnivariate
from dataiku.core.doctor_constants import BINARY_CLASSIFICATION, MULTICLASS, REGRESSION

class PredictionDrift(DriftUnivariate):
    """
    Prediction drift computation.

    Compares the prediction column of the current dataset with the reference predictions
    (the statistics persisted during train-time model evaluation) and reports the drift
    metrics computed for that single column.
    """
    def __init__(self, ref_prediction_series, cur_prediction_series, prediction_column=None):
        """
        :param ref_prediction_series: reference predictions
        :param cur_prediction_series: current predictions
        :param prediction_column: name of the prediction column; "prediction" when None
        """
        if prediction_column is None:
            prediction_column = "prediction"
        self.prediction_column = prediction_column
        # NOTE(review): 20, False and None are positional DriftUnivariate.__init__ arguments whose
        # meaning is defined in the parent class — confirm there before changing them.
        super(PredictionDrift, self).__init__(ref_prediction_series, cur_prediction_series, 20, False, None)

    def compute_drift(self):
        """
        Compute the drift metrics of the prediction column.

        :return: dict with keys chiSquareTestPvalue, chiSquareTestStatistic, ksTestPvalue,
            ksTestStatistic and populationStabilityIndex; any metric absent from the
            column result is reported as None
        """
        reported_metrics = (
            "chiSquareTestPvalue",
            "chiSquareTestStatistic",
            "ksTestPvalue",
            "ksTestStatistic",
            "populationStabilityIndex",
        )
        column_result = super(PredictionDrift, self)._compute_column(
            self.prediction_column, self.ref_df_prepared, self.cur_df_prepared
        )
        return {metric: column_result.get(metric) for metric in reported_metrics}


def create_prediction_series_from_statistics(prediction_statistics, prediction_type, threshold=None):
    """
    Rebuild a series of predictions from persisted prediction statistics.

    :param prediction_statistics: dict of statistics. Keys read here:
        - "predictions" (REGRESSION): the prediction values, used as-is;
        - "predictedClassCount" (BINARY_CLASSIFICATION / MULTICLASS): mapping
          predicted value -> number of times it was predicted;
        - "cuts" and "predictedClassCountPerCut" (BINARY_CLASSIFICATION, optional): numeric
          cut values and, at the same positions, the predicted-value -> count mapping for that cut.
    :param prediction_type: BINARY_CLASSIFICATION, MULTICLASS or REGRESSION
    :param threshold: for binary classification, the counts of the cut nearest to this value are
        used when per-cut counts are available; otherwise "predictedClassCount" is used
    :return: pd.Series of predictions; empty for any other prediction type
    """
    cuts = prediction_statistics.get("cuts", None)
    counts_per_cut = prediction_statistics.get("predictedClassCountPerCut", None)

    if prediction_type == REGRESSION:
        return pd.Series(prediction_statistics['predictions'])

    if prediction_type not in {BINARY_CLASSIFICATION, MULTICLASS}:
        # Unsupported prediction type: nothing to rebuild.
        return pd.Series([])

    predicted_class_count = None
    # Explicit None/len checks instead of truthiness so numpy arrays are accepted as well
    # (bool() of a multi-element ndarray raises ValueError); lists behave as before.
    has_cuts = cuts is not None and len(cuts) > 0
    has_counts_per_cut = counts_per_cut is not None and len(counts_per_cut) > 0
    if prediction_type == BINARY_CLASSIFICATION and has_cuts and has_counts_per_cut and threshold is not None:
        # Pick the recorded cut closest to the requested threshold (argmin keeps the first on ties).
        # NOTE(review): assumes cuts and counts_per_cut have the same length — confirm with the producer.
        cut_index = np.abs(np.asarray(cuts) - threshold).argmin()
        predicted_class_count = counts_per_cut[cut_index]

    if predicted_class_count is None:
        predicted_class_count = prediction_statistics['predictedClassCount']

    # Repeat each predicted value as many times as it was counted.
    return pd.Series([
        prediction
        for prediction, count in predicted_class_count.items()
        for _ in range(count)
    ])
