import inspect
import logging
import math
import numbers
from datetime import timedelta

import numpy as np
import pandas as pd
import scipy.sparse as sp
from sklearn.utils.validation import _is_arraylike
from sklearn.utils.validation import _num_samples

from dataiku.base.block_link import register_as_serializable
from dataiku.base.utils import get_argspec
from dataiku.doctor.distributed.work_scheduler import AbstractContext
from dataiku.doctor.timeseries.utils import log_df
from dataiku.doctor.utils import dku_indexing
from dataiku.doctor.utils import unix_time_millis
from dataiku.doctor.utils import dku_nonaninf
from dataiku.doctor.utils.skcompat import dku_fit

logger = logging.getLogger(__name__)


@register_as_serializable
class Split(object):
    """
    Train/test indices describing one split of the hyperparameter search
    """

    def __init__(self, train, test):
        # Both values are stored as received: no copy and no validation
        self.train, self.test = train, test


@register_as_serializable
class ClassicalSearchContext(AbstractContext):
    """
    Hold everything needed to fit & score the estimator during the hyperparameter search,
    for a classical prediction ML task.

    The context is either used in-process or pickled/streamed to a remote worker.

    Neither this object nor any of its properties should be modified (it can be shared between threads)
    """

    def __init__(self, X, y, splits, sample_weight, scorer,
                 metric_sign, trainable_model):
        # Data
        self._X, self._y = X, y
        self._sample_weight = sample_weight
        # Search definition
        self._splits = splits
        self._trainable_model = trainable_model
        # Scoring
        self._scorer = scorer
        self._metric_sign = metric_sign

    @property
    def splits(self):
        """The splits given at construction time"""
        return self._splits

    @property
    def n_splits(self):
        """Number of splits given at construction time"""
        return len(self._splits)

    @property
    def trainable_model(self):
        """The trainable model given at construction time"""
        return self._trainable_model

    def execute_work(self, split_id, parameters):
        """
        Fit and score one hyper-parameter combination on one split

        :param split_id: index of the split in the list of splits
        :param parameters: hyper-parameters to evaluate
        :return: the search result built by `_dku_fit_and_score`
        """
        split = self._splits[split_id]

        return _dku_fit_and_score(
            trainable_model=self._trainable_model,
            parameters=parameters,
            split_id=split_id,
            train=split.train,
            test=split.test,
            X=self._X,
            y=self._y,
            sample_weight=self._sample_weight,
            scorer=self._scorer,
            metric_sign=self._metric_sign
        )


@register_as_serializable
class TimeseriesForecastingSearchContext(AbstractContext):
    """
    Store all the context required to fit & score the estimator during the hyperparameter search,
    for a timeseries forecasting ML task.

    At construction time, the context df is cut into one dataframe per contiguous search interval
    (see `get_search_intervals`). Each of these dataframes is then split independently by the split handler.

    This context is either used directly or pickled/streamed to a remote worker.

    This object (and any of its properties) should NOT be modified (can be shared between threads)
    """

    def __init__(
        self, df, split_handler, model_scorer, preprocessed_external_features, shift_map, fit_before_predict,
        min_timeseries_size_for_training, eval_metrics, metric_sign, trainable_model,
        skip_too_short_timeseries, custom_train_test_intervals,
    ):
        """
        :raises: :class:`ValueError`: When the search intervals do not contain enough data for the search
        """
        self._df = df
        self._split_handler = split_handler
        self._model_scorer = model_scorer
        self._preprocessed_external_features = preprocessed_external_features
        self._shift_map = shift_map
        self._fit_before_predict = fit_before_predict
        self._min_timeseries_size_for_training = min_timeseries_size_for_training
        self._eval_metrics = eval_metrics
        self._metric_sign = metric_sign
        self._trainable_model = trainable_model
        self._skip_too_short_timeseries = skip_too_short_timeseries
        self._custom_train_test_intervals = custom_train_test_intervals

        # Sets self._search_intervals_df (list of prepared dataframes, one per usable search interval)
        self._prepare_and_validate_search_intervals_df()

    @property
    def n_splits(self):
        # Splitting and search logic will be applied to each distinct contiguous search interval within the context df independently.
        # So for example if kfold = 3 is used with 2 search intervals, each search interval will be split 3 times,
        # resulting in 6 total splits.
        return self._split_handler.n_splits * max(len(self._search_intervals_df), 1)

    @property
    def trainable_model(self):
        return self._trainable_model

    def execute_work(self, split_id, parameters):
        """
        Evaluate an hyper-parameter for one split

        :param split_id: global split id between 0 and n_splits
        :param parameters: hyper-parameters to evaluate
        :return: the search result built by `_dku_timeseries_fit_and_score`
        """
        search_interval_df = self._get_search_interval_df(split_id)
        search_interval_split_id = self._get_search_interval_split_id(split_id)

        # Only the first item yielded by the split handler for this interval-level split id is used
        train_df, test_df, historical_df = next(self._split_handler.split(
            search_interval_df,
            split_id=search_interval_split_id,
        ))

        return _dku_timeseries_fit_and_score(
            trainable_model=self._trainable_model,
            train_df=train_df,
            test_df=test_df,
            historical_df=historical_df,
            model_scorer=self._model_scorer,
            preprocessed_external_features=self._preprocessed_external_features,
            shift_map=self._shift_map,
            fit_before_predict=self._fit_before_predict,
            parameters=parameters,
            eval_metrics=self._eval_metrics,
            metric_sign=self._metric_sign,
            split_id=split_id,
        )

    def _get_search_interval_df(self, split_id):
        """
        Returns a subset of the context df containing only data for the search interval in which split_id belongs

        :param split_id: global kfold split id between 0 and n_splits
        :return: Subset dataframe of df, to use in `AbstractTimeseriesSplitHandler._single_split`
        :rtype: pd.DataFrame
        """
        # Integer floor division: exact, and does not go through a float like `floor(a / b)` does
        interval_index = int(split_id // self._split_handler.n_splits)
        return self._search_intervals_df[interval_index]

    def _get_search_interval_split_id(self, split_id):
        """
        Maps the global split_id to the interval_split_id.

        For example if there are 3 splits per intervals, and 2 search intervals,
        there will be 6 global split_id ranging from 0 to 5.
        Each search interval will have interval_split_id ranging from 0 to 2.
        In that case split_id=4 would be mapped to interval_split_id=1 inside second search interval.

        :param split_id: global split id between 0 and n_splits
        :return: kfold split id within interval
        :rtype: int
        """
        return split_id % self._split_handler.n_splits

    def _prepare_and_validate_search_intervals_df(self):
        """
        For each search interval, build and prepare a subset of the df containing only data for the search interval.
        The prepared dataframes are stored in `self._search_intervals_df`.

        :raises: :class:`ValueError`: When the search intervals are invalid
        """
        intervals_dfs = self._build_intervals_dfs()
        self._validate_intervals_dfs(intervals_dfs)
        self._search_intervals_df = self._prepare_intervals_dfs(intervals_dfs)

    def _build_intervals_dfs(self):
        """
        For each search interval, build a df from the master search df, bounded between the interval start (included)
        and the interval end (excluded).
        :return: list of pd.DataFrame
        """
        time_column = self._split_handler.time_column
        search_intervals = self.get_search_intervals(self._df, time_column, self._custom_train_test_intervals)
        # The time column does not depend on the interval: read it once
        time_values = self._df[time_column]
        return [
            self._df[(time_values >= start_date) & (time_values < end_date)]
            for start_date, end_date in search_intervals
        ]

    def _validate_intervals_dfs(self, intervals_dfs):
        """
        Validate that intervals contain enough data to execute the HP search.
        :raises: :class:`ValueError`: When validation fails
        :param intervals_dfs: list of pd.DataFrame to validate
        """
        if len(intervals_dfs) == 1:
            # When a single search interval is defined, we simply prepare with the split_handler
            # and expect it to raise if the interval is invalid
            self._split_handler.prepare_split_dataframe(intervals_dfs[0],
                                                        self._min_timeseries_size_for_training,
                                                        skip_too_short_timeseries=self._skip_too_short_timeseries)
            return

        # When multiple search intervals, we want at least one interval to be valid,
        # and we accept that some intervals might fail the validation.
        # any() stops at the first valid interval.
        at_least_one_valid_interval = any(
            self._is_valid_hp_search_interval(interval_df, self._skip_too_short_timeseries)
            for interval_df in intervals_dfs
        )
        if not at_least_one_valid_interval:
            raise ValueError(
                "Hyperparameter optimization failed: none of the training subsets are long enough for k-fold cross-validation. You can fix this by increasing the training set size, reducing the number of folds, or enabling Skip too short time series in the general settings.")

    def _is_valid_hp_search_interval(self, interval_df, skip_too_short_timeseries):
        """
        Return true if the interval can be fully split for HP search

        :param interval_df: pd.DataFrame restricted to one search interval
        :param skip_too_short_timeseries: when True, one valid identifier is enough;
                                          when False, no identifier may be invalid
        :rtype: bool
        """
        min_size_required_for_split = self._split_handler.get_min_size_required_for_split(self._min_timeseries_size_for_training)
        valid_identifiers, invalid_identifiers = self._split_handler.retrieve_short_timeseries_identifiers_with_reasons(interval_df, min_size_required_for_split)
        if skip_too_short_timeseries:
            # At least one identifier must be valid
            return len(valid_identifiers) > 0
        # All identifiers must be valid
        return len(valid_identifiers) > 0 and len(invalid_identifiers) == 0

    def _prepare_intervals_dfs(self, intervals_dfs):
        """
        Prepare each search interval df, by applying any filtering needed by the _split_handler.
        Intervals that the split handler rejects are dropped (and logged).

        :param intervals_dfs: list of pd.DataFrame to prepare
        :return: list of pd.DataFrame prepared and usable for HP search
        """
        prepared_intervals_df = []
        for interval_index, interval_df in enumerate(intervals_dfs):
            try:
                prepared_interval_df = self._split_handler.prepare_split_dataframe(
                    interval_df, self._min_timeseries_size_for_training,
                    # Force to True to avoid validation failure, as the intervals are expected to be already validated
                    skip_too_short_timeseries=True
                )
            except ValueError as e:
                # If interval is invalid even with skip_too_short_timeseries=True we ignore the interval.
                # This is a deliberate best-effort, but it is logged so that dropped intervals can be diagnosed.
                logger.warning("Ignoring hyperparameter search interval #%s: %s", interval_index, e)
            else:
                prepared_intervals_df.append(prepared_interval_df)

        return prepared_intervals_df

    @staticmethod
    def get_search_intervals(df, time_column, custom_train_test_intervals):
        """
        Return a list of intervals to use during HP search.

        When custom train/test intervals are given, the result is computed by `get_contiguous_search_intervals`.
        Otherwise a single interval covering the whole time column of df is returned.

        :raises: :class:`ValueError`: When no custom interval is given and df has no rows
        :return:
        List[Tuple[pd.Timestamp, pd.Timestamp]]
           A list of non-overlapping intervals (as tuples of Timestamps).
           The result contains only contiguous intervals usable for HP search.
        """
        if custom_train_test_intervals is not None and len(custom_train_test_intervals) > 0:
            return TimeseriesForecastingSearchContext.get_contiguous_search_intervals(custom_train_test_intervals)

        time_values = df[time_column]
        if time_values.empty:
            raise ValueError("Cannot compute hyperparameter search intervals: column '{}' has no data".format(time_column))
        # Series.min()/max() are vectorized and skip missing values, unlike the builtin min()/max()
        # The search interval is used as [start, end) so we need to add a ms to the interval end to ensure we get all df.
        return [(time_values.min(), time_values.max() + pd.Timedelta(milliseconds=1))]

    @staticmethod
    def get_contiguous_search_intervals(intervals):
        """
        Computes contiguous search intervals by merging overlapping or adjacent train intervals
        and removing any parts that overlap with evaluation (test) intervals.

        Parameters:
        ----------
        intervals : List[Dict[str, Tuple[str, str]]]
           A list of dictionaries, each containing:
               - 'train': Tuple of (start_datetime, end_datetime) for the training interval.
               - 'test' : Tuple of (start_datetime, end_datetime) for the evaluation/test interval.
           All dates should be strings parsable by `pd.Timestamp`.

        Returns:
        -------
        List[Tuple[pd.Timestamp, pd.Timestamp]]
           A list of non-overlapping training intervals (as tuples of Timestamps) with
           all evaluation overlaps removed. The result contains only the valid, contiguous
           training intervals after filtering.
        """

        # Step 1: Collect all intervals
        train_intervals = [(pd.Timestamp(i["train"][0]), pd.Timestamp(i["train"][1])) for i in intervals]
        test_intervals = [(pd.Timestamp(i["test"][0]), pd.Timestamp(i["test"][1])) for i in intervals]

        # Step 2: Sort and merge contiguous train intervals
        train_intervals.sort()
        train_intervals_union = []
        for start, end in train_intervals:
            if train_intervals_union and start <= train_intervals_union[-1][1]:
                # Overlap or contiguous with the last merged interval: extend it
                train_intervals_union[-1][1] = max(train_intervals_union[-1][1], end)
            else:
                train_intervals_union.append([start, end])

        # Step 3: Remove test/eval intervals from train intervals union
        train_without_test_intervals = []
        for train_start, train_end in train_intervals_union:
            # Each test interval can cut every remaining part in (at most) two pieces
            current_parts = [(train_start, train_end)]
            for test_start, test_end in test_intervals:
                next_parts = []
                for part_start, part_end in current_parts:
                    # No overlap
                    if part_end <= test_start or part_start >= test_end:
                        next_parts.append((part_start, part_end))
                    else:
                        # Keep left side
                        if part_start < test_start:
                            next_parts.append((part_start, min(part_end, test_start)))
                        # Keep right side
                        if part_end > test_end:
                            next_parts.append((max(part_start, test_end), part_end))
                current_parts = next_parts
            train_without_test_intervals.extend(current_parts)

        return train_without_test_intervals


def _dku_timeseries_fit_and_score(
    trainable_model, train_df, test_df, historical_df, model_scorer, preprocessed_external_features, shift_map, fit_before_predict, parameters,
    eval_metrics, metric_sign, split_id
):
    """
    Fit a clone of the estimator on `train_df` with the given hyper-parameters, then score its forecasts on `test_df`.

    :param parameters: dict of hyper-parameters to evaluate (may be None)
    :param eval_metrics: key of the metric to read in the aggregated metrics returned by `model_scorer.score`
    :param metric_sign: multiplier applied to the metric to build the "greater is better" score
    :param split_id: id of the evaluated split (used for logging and reported in the result)
    :return: dict describing the search result for this split and these hyper-parameters
    """
    time_variable = model_scorer.time_variable
    for logged_df, log_prefix in ((train_df, "Hyperparameter search train"), (test_df, "Hyperparameter search test")):
        log_df(logger, logged_df, time_variable, fold_id=split_id, prefix=log_prefix)

    if parameters is None:
        params_repr = ""
    else:
        params_repr = ", ".join("{}={}".format(k, v) for k, v in parameters.items())

    # Dots pad the line up to 64 characters (no padding when the parameters are longer than that)
    padding = (64 - len(params_repr)) * "."
    logger.info("Fit s={}: {} {}".format(split_id, params_repr, padding))

    # This function may run concurrently, so the shared estimator is never touched:
    # a clone carrying the parameters under test is fitted instead.
    estimator = trainable_model.clone_estimator(parameters)

    started_at = unix_time_millis()
    estimator.fit(train_df, external_features=preprocessed_external_features, shift_map=shift_map)
    fit_time = unix_time_millis() - started_at

    forecasts_by_timeseries = model_scorer.predict_all_test_timesteps(estimator, historical_df, test_df, fit_before_predict)
    _, aggregated_metrics = model_scorer.score(historical_df, test_df, forecasts_by_timeseries, split_id)
    eval_score = aggregated_metrics[eval_metrics]

    # Everything that happened after the fit counts as scoring time
    score_time = unix_time_millis() - started_at - fit_time

    logger.info("Done s={}: {}".format(
        split_id,
        "{} (ft={:.1f} st={:.1f} sc={}, sg={})".format(params_repr, fit_time / 1000, score_time / 1000, eval_score, metric_sign)
    ))

    return {
        # 'test_score_gib' is the value used to pick the best estimator (always "greater is better")
        "test_score_gib": dku_nonaninf(metric_sign * eval_score),

        # No train score in timeseries: past data is needed to forecast future data
        "train_score": None,
        "test_score": dku_nonaninf(eval_score),

        "num_samples": 1,  # TODO @timeseries instead of using 1, we should allow this num_samples param to be missing
        "fit_time": fit_time,
        "score_time": score_time,
        "time": fit_time + score_time,
        "parameters": parameters,
        "done_at": unix_time_millis(),
        "split_id": split_id
    }


@register_as_serializable
class CausalSearchContext(AbstractContext):
    """
    Hold everything needed to fit & score the estimator during the hyperparameter search,
    for a causal ML task.

    The context is either used in-process or pickled/streamed to a remote worker.

    Neither this object nor any of its properties should be modified (it can be shared between threads)
    """

    def __init__(self, X, y, t, propensities, splits, causal_scorer, causal_learning, metric_sign, causal_trainable_model, treatment_map=None):
        # Data
        self._X, self._y, self._t = X, y, t
        self._propensities = propensities
        # Search definition
        self._splits = splits
        self._causal_trainable_model = causal_trainable_model
        self._causal_learning = causal_learning
        self._treatment_map = treatment_map
        # Scoring
        self._causal_scorer = causal_scorer
        self._metric_sign = metric_sign

    @property
    def splits(self):
        """The splits given at construction time"""
        return self._splits

    @property
    def n_splits(self):
        """Number of splits given at construction time"""
        return len(self._splits)

    @property
    def trainable_model(self):
        """The causal trainable model given at construction time"""
        return self._causal_trainable_model

    def execute_work(self, split_id, parameters):
        """
        Fit and score one hyper-parameter combination on one split

        :param split_id: index of the split in the list of splits
        :param parameters: hyper-parameters to evaluate
        :return: the search result built by `_dku_causal_fit_and_score`
        """
        split = self._splits[split_id]

        return _dku_causal_fit_and_score(
            causal_trainable_model=self._causal_trainable_model,
            causal_learning=self._causal_learning,
            parameters=parameters,
            split_id=split_id,
            train=split.train,
            test=split.test,
            X=self._X,
            y=self._y,
            t=self._t,
            propensities=self._propensities,
            treatment_map=self._treatment_map,
            causal_scorer=self._causal_scorer,
            metric_sign=self._metric_sign
        )


def _dku_causal_score(dku_causal_model, X, y, t_binary, sample_weights, causal_scorer):
    """
    Call `causal_scorer` on the model and data, forwarding `sample_weights` as a keyword argument.

    :return: whatever the scorer returns
    """
    score = causal_scorer(
        dku_causal_model, X, y, t_binary,
        sample_weights=sample_weights,
    )
    return score


def _dku_causal_inverse_propensity_weights(n_samples, t, propensities):
    """
    Build per-row sample weights for the binary treatment case.

    Weights are all 1 when `propensities` is None. Otherwise each control row (t == 0) gets the inverse of
    column 0 of its propensities, and each treated row (t == 1) the inverse of column 1.
    """
    sample_weights = np.ones(n_samples)
    if propensities is not None:
        for treatment_index in (0, 1):
            treatment_mask = t == treatment_index
            sample_weights[treatment_mask] = 1 / propensities[treatment_mask][:, treatment_index]
    return sample_weights


def _dku_causal_fit_and_score(causal_trainable_model, X, y, t, propensities,
                              causal_scorer, causal_learning,
                              train, test, parameters, metric_sign,
                              split_id, treatment_map=None):
    """
    Fit a causal model with the given hyper-parameters on the train fold, then score it on both folds.

    :param propensities: per-row propensities, one column per treatment index (may be None)
    :param train: indices of the train fold
    :param test: indices of the test fold
    :param parameters: dict of hyper-parameters to evaluate (may be None)
    :param metric_sign: multiplier used to recover the initial metric value from the scorer output
    :param split_id: id of the evaluated split (used for logging and reported in the result)
    :param treatment_map: when not None, the multi-treatment path is used (one score per non-control treatment)
    :raises ValueError: when the control or a treatment has no data in the train or test fold, or,
                        for classifiers, when an outcome class is missing in the train fold for one of them
    :return: dict describing the search result for this split and these hyper-parameters
    """
    formatted_parameters = ""
    if parameters is not None:
        formatted_parameters = ", ".join("%s=%s" % (k, v) for k, v in parameters.items())

    logger.info("Fit s={split_id}: {formatted_parameters} {dots}".format(
        split_id=split_id,
        formatted_parameters=formatted_parameters,
        dots=(64 - len(formatted_parameters)) * "."
    ))

    start_time = unix_time_millis()

    X_train = dku_indexing(X, train)
    y_train = dku_indexing(y, train)
    t_train = dku_indexing(t, train)
    unique_preprocessed_treatment_values_train = np.unique(t_train)

    X_eval = dku_indexing(X, test)
    y_eval = dku_indexing(y, test)
    t_eval = dku_indexing(t, test)
    unique_preprocessed_treatment_values_eval = np.unique(t_eval)

    if propensities is not None:
        propensities_train = dku_indexing(propensities, train)
        propensities_eval = dku_indexing(propensities, test)
    else:
        propensities_train = None
        propensities_eval = None

    # The control group (treatment index 0) must be present in both folds
    _check_causal_trainable(0, "control", unique_preprocessed_treatment_values_train, is_control=True)
    _check_causal_trainable(0, "control", unique_preprocessed_treatment_values_eval, is_control=True, is_eval=True)
    if causal_trainable_model.is_classifier:
        _check_causal_classif_trainable(0, "control", y_train, t_train, is_control=True)

    if treatment_map is not None:
        logger.info("Multiple treatment values")
        for treatment_value, treatment_index in treatment_map.items_except_control():
            _check_causal_trainable(treatment_index, treatment_value, unique_preprocessed_treatment_values_train)
            _check_causal_trainable(treatment_index, treatment_value, unique_preprocessed_treatment_values_eval, is_eval=True)
            if causal_trainable_model.is_classifier:
                _check_causal_classif_trainable(treatment_index, treatment_value, y_train, t_train)
        # Always work with an independent copy of the estimator, as this function
        # may be called in a concurrent context. Hence the need to clone it before
        # assigning the parameters we want to test.
        estimator = causal_trainable_model.clone_estimator(parameters)
        dku_causal_model_multi = causal_learning.get_dku_causal_model(estimator, causal_trainable_model.is_classifier, treatment_map=treatment_map)
        dku_causal_model_multi.fit(X_train, y_train, t_train)
        fit_time = unix_time_millis() - start_time

        train_score = _dku_causal_score_multi_treatment(dku_causal_model_multi, X_train, y_train, t_train, propensities_train,
                                                        treatment_map, causal_scorer)
        eval_score = _dku_causal_score_multi_treatment(dku_causal_model_multi, X_eval, y_eval, t_eval, propensities_eval,
                                                       treatment_map, causal_scorer)
    else:
        _check_causal_trainable(1, "treated", unique_preprocessed_treatment_values_train)
        _check_causal_trainable(1, "treated", unique_preprocessed_treatment_values_eval, is_eval=True)
        if causal_trainable_model.is_classifier:
            _check_causal_classif_trainable(1, "treated", y_train, t_train)
        # Always work with an independent copy of the estimator, as this function
        # may be called in a concurrent context. Hence the need to clone it before
        # assigning the parameters we want to test.
        estimator = causal_trainable_model.clone_estimator(parameters)
        dku_causal_model = causal_learning.get_dku_causal_model(estimator, causal_trainable_model.is_classifier)
        dku_causal_model.fit(X_train, y_train, t_train)

        fit_time = unix_time_millis() - start_time

        # TODO @causal: external sample weights (i.e. not inverse propensity) support for hyperparameter search
        sample_weights_train = _dku_causal_inverse_propensity_weights(X_train.shape[0], t_train, propensities_train)
        sample_weights_eval = _dku_causal_inverse_propensity_weights(X_eval.shape[0], t_eval, propensities_eval)

        train_score = _dku_causal_score(dku_causal_model=dku_causal_model,
                                        X=X_train, y=y_train, t_binary=t_train,
                                        sample_weights=sample_weights_train,
                                        causal_scorer=causal_scorer)

        eval_score = _dku_causal_score(dku_causal_model=dku_causal_model,
                                       X=X_eval, y=y_eval, t_binary=t_eval,
                                       sample_weights=sample_weights_eval,
                                       causal_scorer=causal_scorer)

    # Everything that happened after the fit counts as scoring time
    score_time = unix_time_millis() - start_time - fit_time

    # Completion is logged for both the binary and the multi-treatment paths
    end_msg = "{} (ft={:.1f} st={:.1f} sc={}, sg={})".format(formatted_parameters, fit_time / 1000, score_time / 1000, eval_score, metric_sign)
    logger.info("Done s={}: {}".format(split_id, end_msg))

    num_samples = _num_samples(X_eval)
    search_result = {
        # 'test_score_gib' is aimed to be used for picking the best estimator (always "greater is better")
        "test_score_gib": dku_nonaninf(eval_score),

        # Here, 'metric_sign' is used here to get the initial metric's value since 'train_score'
        # and 'test_score' are forced to be 'greater is better' (via make_scorer())
        "train_score": dku_nonaninf(metric_sign * train_score),
        "test_score": dku_nonaninf(metric_sign * eval_score),

        "num_samples": num_samples,
        "fit_time": fit_time,
        "score_time": score_time,
        "time": fit_time + score_time,
        "parameters": parameters,
        "done_at": unix_time_millis(),
        "split_id": split_id
    }

    return search_result


def _dku_causal_score_multi_treatment(dku_causal_model_multi, X, y, t, propensities, treatment_map, causal_scorer):
    """
    Aggregate the per-treatment scores of a multi-treatment causal model into a single score.

    For each non-control treatment, the matching single-treatment model is scored on the rows that are
    either control (t == 0) or that treatment, and the score is multiplied by the absolute difference
    between the mean outcome of that treatment and the mean outcome of the control.
    The result is the average of these products divided by the average of the absolute differences,
    both averages being weighted by the number of rows of each treatment.

    :param propensities: per-row propensities, one column per treatment index (may be None, in which
                         case the scorer receives no sample weights)
    :return: the aggregated score
    """
    scaled_scores, abs_ates, treated_counts = [], [], []
    is_control = t == 0

    for treatment, i in treatment_map.items_except_control():
        is_treated = t == i
        # Restrict the data to the control rows and the rows of the current treatment
        pair_mask = is_control | is_treated
        t_pair = t[pair_mask]
        pair_is_control = t_pair == 0
        pair_is_treated = t_pair == i

        sample_weights = None
        if propensities is not None:
            # Propensity of the current treatment, renormalized against the control only
            treated_propensity = propensities[pair_mask, i]
            normalized_propensity = treated_propensity / (treated_propensity + propensities[pair_mask, 0])

            sample_weights = np.ones(normalized_propensity.shape[0])
            sample_weights[pair_is_control] = 1 / (1-normalized_propensity[pair_is_control])
            sample_weights[pair_is_treated] = 1 / normalized_propensity[pair_is_treated]

        single_score = _dku_causal_score(dku_causal_model=dku_causal_model_multi._models[treatment],
                                         X=X[pair_mask, :], y=y[pair_mask], t_binary=pair_is_treated,
                                         sample_weights=sample_weights,
                                         causal_scorer=causal_scorer)

        abs_ate = abs(np.mean(y[is_treated]) - np.mean(y[is_control]))
        abs_ates.append(abs_ate)
        scaled_scores.append(single_score * abs_ate)
        treated_counts.append(np.sum(is_treated))

    return np.average(scaled_scores, weights=treated_counts) / np.average(abs_ates, weights=treated_counts)


def _check_causal_trainable(treatment_index, treatment_value, all_treatment_index, is_control=False, is_eval=False):
    """
    Make sure that `treatment_index` is present in `all_treatment_index`.

    :param treatment_value: treatment label, only used in the error message (for non-control treatments)
    :param is_control: when True, the error message refers to the "control"
    :param is_eval: when True, the error message refers to the "test" fold, otherwise to the "train" fold
    :raises ValueError: when `treatment_index` is not in `all_treatment_index`
    """
    if treatment_index in all_treatment_index:
        return

    if is_control:
        suffix = "control"
    else:
        suffix = "treatment = \"{}\"".format(treatment_value)

    if is_eval:
        fold_name = "test"
    else:
        fold_name = "train"

    raise ValueError("No data in {} fold for ".format(fold_name) + suffix)


def _check_causal_classif_trainable(treatment_index, treatment_value, y, t, is_control=False):
    """
    Make sure that both outcome classes (0 and 1) are present in `y` for the rows where `t == treatment_index`.

    :param treatment_value: treatment label, only used in the error message (for non-control treatments)
    :param is_control: when True, the error message refers to the "control"
    :raises ValueError: when outcome 0 or outcome 1 is missing for this treatment (0 is checked first)
    """
    outcomes_for_treatment = np.unique(y[t == treatment_index])
    if is_control:
        suffix = "control"
    else:
        suffix = "treatment = \"{}\"".format(treatment_value)

    required_outcomes = (
        (0, "No negative class outcome in train fold for "),
        (1, "No positive class outcome in train fold for "),
    )
    for outcome, message_prefix in required_outcomes:
        if outcome not in outcomes_for_treatment:
            raise ValueError(message_prefix + suffix)


def _dku_score(estimator, X, y, scorer, sample_weight=None, indices=None):
    """
    Score `estimator` on (X, y) with `scorer` and return the score as a number.

    `indices` is forwarded to the scorer only when the scorer is a plain function declaring an
    'indices' argument; `sample_weight` is always forwarded.

    :raises ValueError: when the scorer does not return a number
    """
    # Only plain functions are introspected: other callables (e.g. scorer objects)
    # are called without 'indices'
    regular_args = get_argspec(scorer)[0] if inspect.isfunction(scorer) else []

    scorer_kwargs = {"sample_weight": sample_weight}
    if 'indices' in regular_args:
        scorer_kwargs["indices"] = indices
    score = scorer(estimator, X, y, **scorer_kwargs)

    if hasattr(score, 'item'):
        # e.g. unwrap memmapped scalars
        try:
            score = score.item()
        except ValueError:
            # Probably not a scalar: left as-is, and rejected by the check below
            pass

    if isinstance(score, numbers.Number):
        return score
    raise ValueError("scoring must return a number, got %s (%s) instead."
                     % (str(score), type(score)))

def _dku_fit_and_score(trainable_model, X, y, scorer, train, test, parameters,
                       metric_sign, split_id, sample_weight):
    """
    Fit a clone of the estimator with the given hyper-parameters on the train fold, then score it on both folds.

    :param train: indices of the train fold
    :param test: indices of the test fold
    :param parameters: dict of hyper-parameters to evaluate (may be None)
    :param metric_sign: multiplier used to recover the initial metric value from the scorer output
    :param split_id: id of the evaluated split (used for logging and reported in the result)
    :param sample_weight: per-row weights (may be None); when given, they are always used for scoring
    :return: dict describing the search result for this split and these hyper-parameters, completed with
             the extra per-split attributes provided by the trainable model
    """
    formatted_parameters = ""
    if parameters is not None:
        formatted_parameters = ", ".join("%s=%s" % (k, v) for k, v in parameters.items())

    logger.info("Fit s={split_id}: {formatted_parameters} {dots}".format(
        split_id=split_id,
        formatted_parameters=formatted_parameters,
        dots=(64 - len(formatted_parameters)) * "."
    ))

    # Always work with an independent copy of the estimator, as this function
    # may be called in a concurrent context. Hence the need to clone it before
    # assigning the parameters we want to test.
    estimator = trainable_model.clone_estimator(parameters)

    start_time = unix_time_millis()

    X_train = dku_indexing(X, train)
    y_train = dku_indexing(y, train)

    w_train = None
    if sample_weight is not None:
        w_train = dku_indexing(sample_weight, train)

    X_eval = y_eval = None
    if trainable_model.requires_evaluation_set:
        X_eval = dku_indexing(X, test)
        y_eval = dku_indexing(y, test)

    # Warning: this is a small leak from the split data, as it is also used
    # for the evaluation.
    fit_params = trainable_model.get_fit_parameters(sample_weight=w_train, X_eval=X_eval, y_eval=y_eval)
    # Fit parameters that have one value per row of X are restricted to the train fold
    fit_params = {k: _dku_index_param_value(X, v, train) for k, v in fit_params.items()}

    # Some fold may not have one of the classes, leading to a failure: keep only the weights of
    # the classes present in the train fold. Only mapping-like values can be filtered: any other
    # value (e.g. scikit-learn's "balanced" string) is left untouched.
    class_weight = estimator.get_params().get("class_weight", None)
    if class_weight is not None and hasattr(class_weight, "keys"):
        classes = np.unique(y_train.values)
        estimator.set_params(class_weight={key: class_weight[key] for key in class_weight.keys() if key in classes})

    dku_fit(estimator, X_train, y_train, **fit_params)

    fit_time = unix_time_millis() - start_time
    # score with sample weights whenever they are enabled, regardless of the support by the algorithm
    train_score = _dku_score(estimator, X_train, y_train, scorer, sample_weight=w_train, indices=train)

    # For memory usage, load the evaluation set here as we don't need X_train anymore
    if X_eval is None:
        X_eval = dku_indexing(X, test)
        y_eval = dku_indexing(y, test)

    w_eval = None
    if sample_weight is not None:
        w_eval = dku_indexing(sample_weight, test)
    # score with sample weights whenever they are enabled, regardless of the support by the algorithm
    eval_score = _dku_score(estimator, X_eval, y_eval, scorer, sample_weight=w_eval, indices=test)

    # Everything that happened after the fit counts as scoring time
    score_time = unix_time_millis() - start_time - fit_time

    end_msg = "{} (ft={:.1f} st={:.1f} sc={}, sg={})".format(formatted_parameters, fit_time / 1000, score_time / 1000, eval_score, metric_sign)
    logger.info("Done s={}: {}".format(split_id, end_msg))

    num_samples = _num_samples(X_eval)
    search_result = {
        # 'test_score_gib' is aimed to be used for picking the best estimator (always "greater is better")
        "test_score_gib": dku_nonaninf(eval_score),

        # Here, 'metric_sign' is used here to get the initial metric's value since 'train_score'
        # and 'test_score' are forced to be 'greater is better' (via make_scorer())
        "train_score": dku_nonaninf(metric_sign * train_score),
        "test_score": dku_nonaninf(metric_sign * eval_score),

        "num_samples": num_samples,
        "fit_time": fit_time,
        "score_time": score_time,
        "time": fit_time + score_time,
        "parameters": parameters,
        "done_at": unix_time_millis(),
        "split_id": split_id
    }

    extra_attributes = trainable_model.get_extra_per_split_search_result_attributes(estimator)
    search_result.update(extra_attributes)
    return search_result


def _dku_index_param_value(X, v, indices):
    """
    Restrict a fit parameter value to `indices` when it holds one entry per sample of X.

    Values that are not array-like, or whose number of samples differs from the one of X,
    are returned untouched. Sparse values are converted to CSR before being indexed.
    """
    has_one_entry_per_sample = _is_arraylike(v) and _num_samples(v) == _num_samples(X)
    if not has_one_entry_per_sample:
        # pass through: skip indexing
        return v

    indexable = v.tocsr() if sp.issparse(v) else v
    return dku_indexing(indexable, indices)
