import numpy as np

from dataikuscoring.utils.prediction_result import AbstractPredictionResult
from dataikuscoring.utils.prediction_result import ClassificationPredictionResult


class DecisionsAndCuts(object):
    """
    Predictions of a binary classifier evaluated at several probability thresholds ("cuts").

    This class is used for binary classification.

    In binary classification, when probas are available, preds usually depend on both
    the probas and a threshold:
      - if the proba for class 1 is greater than the threshold, then we predict class 1
      - we predict class 0 otherwise
    Thus, since predictions depend on the threshold, some features of DSS need to compute
    predictions for multiple cuts.

    Note: probas are threshold independent, so they are the same for all prediction results. However, because the
    results can be overridden and contain different values depending on the prediction, it is much simpler to keep
    the full prediction result object rather than just the predictions.
    """
    def __init__(self, cuts, prediction_results):
        """
        :param np.ndarray cuts: list of cuts (thresholds); stored as given, not copied
        :param list[ClassificationPredictionResult] prediction_results: one prediction result per cut, in the same
            order as ``cuts``
        """
        assert len(cuts) == len(prediction_results), "Params don't have the same length"
        self._cuts = cuts
        self._prediction_results = prediction_results

    def align_with_not_declined(self, array):
        """
        Align ``array`` with the rows that were not declined.

        Since the alignment doesn't depend on the predictions we can simply take the first result to use it for
        aligning. When this object is empty (see ``is_empty``), ``array`` is returned unchanged.

        :type array: np.ndarray
        :rtype: np.ndarray
        """
        if self.is_empty():
            return array
        return self._prediction_results[0].align_with_not_declined(array)

    def assert_not_all_declined(self):
        """
        Delegate ``assert_not_all_declined`` to the prediction result of every cut.
        """
        for pred_result in self._prediction_results:
            pred_result.assert_not_all_declined()

    @staticmethod
    def from_probas_or_unmapped_preds(probas, preds, target_map, cuts=None):
        """
        Build from ``probas`` when they are available, otherwise from the unmapped ``preds``.

        :param np.ndarray or None probas: probabilities; when None, ``preds`` are used instead
        :param np.ndarray preds: unmapped predictions, only used when ``probas`` is None
        :param dict target_map: target_mapping of the classes in the model
        :param np.ndarray or None cuts: list of cuts, default values will be used if None
        :rtype: DecisionsAndCuts
        """
        if probas is None:
            return DecisionsAndCuts.from_unmapped_preds(preds, target_map, cuts)
        return DecisionsAndCuts.from_probas(probas, target_map, cuts)

    @staticmethod
    def from_probas(probas, target_map, cuts=None):
        """
        Build one prediction result per cut by thresholding the class-1 column of ``probas``.

        For each cut, rows whose ``probas[:, 1]`` is strictly greater than the cut are predicted as class 1,
        all others as class 0.

        :param np.ndarray probas: from which we'll infer the predictions for all cuts (indexed as ``probas[:, 1]``,
            so presumably one row per sample and one column per class)
        :param dict target_map: target_mapping of the classes in the model
        :param np.ndarray or None cuts: list of cuts, default values will be used if None
        :rtype: DecisionsAndCuts
        """
        if cuts is None:
            cuts = DecisionsAndCuts.get_default_cuts()
        prediction_results = []
        for cut in cuts:
            unmapped_preds = (probas[:, 1] > cut).astype(int)
            # The same `probas` array is handed to every prediction result (no copy, to save memory), so any
            # in-place modification of it would be visible across all cuts, unless
            # ClassificationPredictionResult copies it internally (not verifiable from here).
            prediction_results.append(ClassificationPredictionResult(target_map, probas=probas,
                                                                     unmapped_preds=unmapped_preds))
        return DecisionsAndCuts(cuts, prediction_results)

    @staticmethod
    def from_unmapped_preds(unmapped_preds, target_map, cuts=None):
        """
        Build one prediction result per cut, all holding the same ``unmapped_preds`` (no threshold is applied).

        :param np.ndarray unmapped_preds: predictions used as-is for every cut
        :param dict target_map: target_mapping of the classes in the model
        :param np.ndarray or None cuts: list of cuts, default values will be used if None
        :rtype: DecisionsAndCuts
        """
        if cuts is None:
            cuts = DecisionsAndCuts.get_default_cuts()
        prediction_results = [ClassificationPredictionResult(target_map, unmapped_preds=unmapped_preds)
                              for _ in cuts]
        return DecisionsAndCuts(cuts, prediction_results)

    @staticmethod
    def get_default_cuts():
        """
        :return: the 40 cuts 0.0, 0.025, ..., 0.975
        :rtype: np.ndarray
        """
        # round the cut because we want "precise" values for cuts in perf.json
        # (np.arange accumulates floating-point error in its steps)
        return np.around(np.arange(0, 1., 0.025), 4)

    @staticmethod
    def concat(decisions_and_cuts_list):
        """
        Concatenate several ``DecisionsAndCuts``, cut by cut.

        :param list[DecisionsAndCuts] decisions_and_cuts_list: having the same cuts
        :rtype: DecisionsAndCuts
        :raises ValueError: if the list is empty, if the cuts differ (in length or values), or if only some of the
            elements have probas
        """
        if not decisions_and_cuts_list:
            raise ValueError("No `DecisionsAndCuts` to concat")

        first_cuts = decisions_and_cuts_list[0].get_cuts()
        for decisions_and_cuts in decisions_and_cuts_list[1:]:
            other_cuts = decisions_and_cuts.get_cuts()
            # Lengths must be compared explicitly: np.allclose broadcasts, so a length-1 cuts array would otherwise
            # be considered "close" to a longer one, and other mismatches would fail with a broadcasting error.
            if len(other_cuts) != len(first_cuts) or not np.allclose(first_cuts, other_cuts):
                raise ValueError("All cuts should be equal")

        has_probas_list = [dc.has_probas() for dc in decisions_and_cuts_list]
        if any(has_probas_list) and not all(has_probas_list):
            raise ValueError("Some probas are not defined")

        # zip(*...) groups, for each cut index, the prediction results of every DecisionsAndCuts at that cut
        per_cut_results = zip(*[dc.get_prediction_results() for dc in decisions_and_cuts_list])
        concat_prediction_result_list = []
        for prediction_results_to_concat in per_cut_results:
            concat_prediction_result = AbstractPredictionResult.concat(list(prediction_results_to_concat))
            assert isinstance(concat_prediction_result, ClassificationPredictionResult), "Wrong concatenated type"
            concat_prediction_result_list.append(concat_prediction_result)

        return DecisionsAndCuts(first_cuts, concat_prediction_result_list)

    def __iter__(self):
        """
        Iterate over ``(prediction_result, cut)`` pairs, in cut order.

        :rtype: collections.Iterable[Tuple[ClassificationPredictionResult, float]]
        """
        for prediction_result, cut in zip(self._prediction_results, self._cuts):
            yield prediction_result, cut

    def __len__(self):
        """
        :return: number of cuts
        :rtype: int
        """
        return len(self._cuts)

    def get_cuts(self):
        """
        :rtype: np.ndarray
        """
        return self._cuts

    def get_prediction_results(self):
        """
        :rtype: list[ClassificationPredictionResult]
        """
        return self._prediction_results

    def get_unmapped_preds_not_declined_list(self):
        """
        :return: the ``unmapped_preds_not_declined`` of each prediction result, in cut order
        :rtype: list[np.ndarray]
        """
        return [prediction_result.unmapped_preds_not_declined for prediction_result in self._prediction_results]

    def get_prediction_result_for_nearest_cut(self, cut):
        """
        Return the prediction result whose cut is closest to ``cut``.

        Ties resolve to the first matching cut.

        :type cut: float
        :rtype: ClassificationPredictionResult
        :raises ValueError: if there are no cuts
        """
        # np.asarray so that cuts given as a plain list (not only np.ndarray) support the subtraction
        cut_index = np.abs(np.asarray(self._cuts) - cut).argmin()
        return self._prediction_results[cut_index]

    def has_probas(self):
        """
        :return: False when there are no cuts, otherwise whether the first prediction result has probas
        :rtype: bool
        """
        if len(self._prediction_results) == 0:
            return False
        return self._prediction_results[0].has_probas()

    def is_empty(self):
        """
        :return: True when there are no cuts, otherwise whether the first prediction result is empty
        :rtype: bool
        """
        if len(self._prediction_results) == 0:
            return True
        return self._prediction_results[0].is_empty()
