import numpy as np

from dataiku.doctor.prediction.common import prepare_multiframe
from dataiku.doctor.prediction.common import train_test_split
from dataiku.doctor.utils.model_io import from_pkl
from dataiku.doctor.prediction.scorable_model import SerializableMixin


class PredictionIntervalsModel(SerializableMixin):
    """
    Auxiliary regressor used to size prediction intervals.

    ``fit`` trains ``clf`` to predict the error values it is given and calibrates
    a scaling factor ``q``; ``predict`` returns ``abs(clf.predict(X) * q)``.
    """

    # Pickled estimator, read back in `load_or_none`.
    PKL_FILENAME = "prediction_intervals_clf.pkl"
    # Not referenced in this class; presumably consumed by SerializableMixin -- confirm.
    GZ_FILENAME = "dss_prediction_intervals_model.gz"
    # JSON file holding the calibrated factor under the "q" key, read back in `load_or_none`.
    PARAMS_FILENAME = "prediction_intervals_params.json"
    # Coverage applied when "predictionIntervalsCoverage" is absent from the uncertainty params.
    DEFAULT_COVERAGE = 0.95

    def __init__(self, clf=None, q=None, algorithm=None, coverage=DEFAULT_COVERAGE, split=0.5):
        """
        :param clf: estimator exposing ``fit`` and ``predict``
        :param q: calibrated scaling factor; stays None until `fit` is called or a saved value is loaded
        :type algorithm: str
        :param coverage: quantile level handed to ``np.quantile`` in `fit`
        :type coverage: float
        :param split: fraction of the rows used to size the calibration part in `fit`
        :type split: float
        """
        self.q = q
        self.coverage = coverage
        self.split = split
        # The mixin receives the estimator and algorithm name; `fit`/`predict` read it back as self.clf.
        super(PredictionIntervalsModel, self).__init__(clf, algorithm)

    @staticmethod
    def for_training(uncertainty_params):
        """
        Build an unfitted model for the algorithm named in the uncertainty params.

        :type uncertainty_params: dict
        :rtype: PredictionIntervalsModel
        :raises ValueError: if the algorithm is not supported
        """
        algorithm = uncertainty_params["algorithm"]
        if algorithm == "LIGHTGBM_REGRESSION":
            # Imported in this branch only, so lightgbm is needed only when this algorithm is selected.
            from lightgbm import LGBMRegressor
            clf = LGBMRegressor(boosting_type='gbdt', class_weight=None, colsample_bytree=1.0,
                                importance_type='split', learning_rate=0.1, max_depth=-1,
                                min_child_samples=20, min_child_weight=0.001, min_split_gain=0.0,
                                n_estimators=100, n_jobs=-1, num_leaves=31, objective=None,
                                random_state=None, reg_alpha=0.0, reg_lambda=0.0, silent=True,
                                subsample=1.0, subsample_for_bin=200000, subsample_freq=0)
        else:
            raise ValueError("Algo not supported for prediction intervals model: '%s'" % algorithm)

        coverage = uncertainty_params.get("predictionIntervalsCoverage",
                                          PredictionIntervalsModel.DEFAULT_COVERAGE)
        return PredictionIntervalsModel(clf=clf, algorithm=algorithm, coverage=coverage)

    @staticmethod
    def load_or_none(model_folder_context, core_params):
        """
        Load a previously saved model, or return None when prediction intervals are disabled.

        :type model_folder_context: dataiku.base.folder_context.FolderContext
        :type core_params: dict
        :rtype: PredictionIntervalsModel or None
        """
        uncertainty_params = core_params.get("uncertainty", {})
        if not uncertainty_params or not uncertainty_params.get("predictionIntervalsEnabled", False):
            return None

        clf = from_pkl(model_folder_context, PredictionIntervalsModel.PKL_FILENAME)
        q = model_folder_context.read_json(PredictionIntervalsModel.PARAMS_FILENAME)["q"]
        # Same fallback as `for_training`: params that were accepted at training time without an
        # explicit coverage must not fail with a KeyError when the model is loaded back.
        coverage = uncertainty_params.get("predictionIntervalsCoverage",
                                          PredictionIntervalsModel.DEFAULT_COVERAGE)
        return PredictionIntervalsModel(clf=clf, q=q, algorithm=uncertainty_params["algorithm"],
                                        coverage=coverage)

    def fit(self, X, error):
        """
        Train the error estimator on one part of the rows and calibrate ``q`` on the other.

        ``q`` is set to the ``coverage`` quantile of ``error / predicted error`` computed on
        the calibration part.

        :type X: np.ndarray or scipy.sparse.csr_matrix
        :type error: np.ndarray
        """
        # Integer row count (truncated) passed as `test_size`; assuming scikit-learn semantics,
        # that many rows form the calibration part (X_q, error_q) -- confirm against the helper.
        split_point = int(self.split * X.shape[0])
        # Fixed random_state keeps the partition reproducible across runs.
        X_model, X_q, error_model, error_q = train_test_split(X, error, test_size=split_point, random_state=42)
        self.clf.fit(X_model, error_model)
        # NOTE(review): a zero predicted error makes this ratio inf/nan and is not guarded here --
        # confirm whether that can occur in practice.
        norm_residuals = error_q / self.clf.predict(X_q)
        self.q = np.quantile(norm_residuals, self.coverage)

    def predict(self, X):
        """
        Return ``abs(clf.predict(X) * q)`` for each row of X.

        :type X: np.ndarray or scipy.sparse.csr_matrix
        :rtype: np.ndarray
        """
        if X.shape[0] <= 0:
            # No rows: skip the estimator and return an empty float array of shape (0, 1).
            # NOTE(review): the non-empty branch returns whatever shape clf.predict yields --
            # confirm callers accept both.
            return np.empty([0, 1]).astype(float)
        return abs(self.clf.predict(X) * self.q)


def train_prediction_interval_or_none(clf, core_params, modeling_params, transformed):
    """
    Train a prediction-intervals model on the absolute errors of ``clf``.

    Returns None without training when ``core_params["uncertainty"]`` is missing/empty or
    its "predictionIntervalsEnabled" flag is falsy.

    :param clf: fitted estimator exposing ``predict``
    :type core_params: dict
    :param modeling_params: forwarded to ``prepare_multiframe``
    :param transformed: mapping providing the "TRAIN" multiframe and the "target" values
    :rtype: PredictionIntervalsModel or None
    """
    params = core_params.get("uncertainty", {})
    if not (params and params.get("predictionIntervalsEnabled", False)):
        return None

    features, _ = prepare_multiframe(transformed["TRAIN"], modeling_params)
    target = transformed["target"].astype(float)

    intervals_model = PredictionIntervalsModel.for_training(params)
    # The intervals model learns from the absolute error of the main model on these rows.
    abs_error = np.abs(target - clf.predict(features))
    intervals_model.fit(features, abs_error)
    return intervals_model
