import inspect
import logging
import os
import time
from threading import Thread
import numpy as np
from sklearn.model_selection import check_cv
from sklearn.utils import indexable

from dataiku.base.utils import detect_usable_cpu_count
from dataiku.core import doctor_constants
from dataiku.doctor.crossval.result_store import NoopResultStore
from dataiku.doctor.crossval.result_store import OnDiskResultStore
from dataiku.doctor.crossval.search_context import ClassicalSearchContext
from dataiku.doctor.crossval.search_context import CausalSearchContext
from dataiku.doctor.crossval.search_context import TimeseriesForecastingSearchContext
from dataiku.doctor.crossval.search_context import Split
from dataiku.doctor.crossval.search_evaluation_monitor import SearchEvaluationMonitor
from dataiku.doctor.crossval.search_evaluator import SearchEvaluator
from dataiku.doctor.distributed.local_worker import LocalWorker
from dataiku.doctor.distributed.remote_worker_client import RemoteWorkerClient
from dataiku.doctor.distributed.remote_worker_client import WorkersStartupMonitor
from dataiku.doctor.distributed.work_scheduler import WorkScheduler
from dataiku.doctor.distributed.worker_splitter import WorkerSplitter
from dataiku.doctor.utils import unix_time_millis

logger = logging.getLogger(__name__)


class TabularSearchRunner(object):
    """Base class orchestrating a hyperparameter search for a trainable model.

    Subclasses implement :meth:`initialize_search_context`, which is expected to set
    ``self.search_context`` before :meth:`get_best_estimator` is called (unless the
    search is skipped, see :meth:`search_skipped`).

    After a search, ``self.aggregated_results`` holds the results returned by the search
    strategy and ``self.best_result`` the retained best point.
    """

    def __init__(self, trainable_model, search_settings, model_folder_context, evaluation_metric):
        """
        :param trainable_model: model wrapper used to clone estimators and get fit parameters
        :param SearchSettings search_settings: search strategy, parallelism and limits
        :param model_folder_context: folder context used to persist search results, or None
                                     to use a no-op result store
        :param evaluation_metric: forwarded to the result store when the result file is initialized
        """
        self._trainable_model = trainable_model

        self.model_folder_context = model_folder_context
        self.evaluation_metric = evaluation_metric

        self.search_settings = search_settings

        # Without a model folder, results are handed to a no-op store
        if self.model_folder_context is None:
            self.result_store = NoopResultStore()
        else:
            self.result_store = OnDiskResultStore(self.model_folder_context)

        self.monitor = SearchEvaluationMonitor(search_settings.distributed)

        # Needs to be initialized before calling get_best_estimator
        self.search_context = None

        # Defined after search
        self.aggregated_results = None
        self.best_result = None

    def get_final_fit_parameters(self, sample_weight=None):
        """Return the trainable model's fit parameters for the final fit.

        :param sample_weight: optional sample weights forwarded to the trainable model
        """
        return self._trainable_model.get_fit_parameters(sample_weight, is_final_fit=True)

    def get_experiments_count(self):
        """Return the number of hyperparameter points to explore, or None when unknown.

        The count reported by the search strategy is capped by ``search_settings.n_iter``
        when the latter is not None. If the strategy reports None, ``n_iter`` is used as is.
        """
        nb_experiments = self.search_settings.search_strategy.get_experiments_count()
        n_iter = self.search_settings.n_iter
        if n_iter is None:
            return nb_experiments
        if nb_experiments is None:
            return n_iter
        return min(n_iter, nb_experiments)

    def search_skipped(self):
        """Tell whether the search is skipped.

        Never skipped when the trainable model has ``must_search`` set. Otherwise skipped
        only when the experiments count is known and is at most 1.
        """
        if self._trainable_model.must_search:
            return False

        nb_experiments = self.get_experiments_count()
        return nb_experiments is not None and nb_experiments <= 1

    def _build_work_scheduler(self, search_context):
        """Create the WorkScheduler and its workers for the given search context.

        The master always gets ``n_threads`` local workers. In distributed mode (enabled in
        the settings AND a remote worker pool id is available in the environment),
        ``n_containers - 1`` remote workers are added, each split into ``n_threads`` workers.

        :param search_context: search context handed to the scheduler
        :return: the scheduler (used as a context manager by get_best_estimator)
        """
        # REMOTE_WORKER_POOL_ID is set by DSS/JEK whenever it makes sense
        # (ie. we are running in K8S)
        remote_worker_pool_id = os.getenv("REMOTE_WORKER_POOL_ID")

        # A non-positive 'n_threads' setting means "use every usable CPU"
        n_usable_cpus = detect_usable_cpu_count()
        n_threads = self.search_settings.n_threads if self.search_settings.n_threads > 0 else n_usable_cpus

        # Create 'n_threads' threads in the master (needed in both modes)
        workers = [LocalWorker() for _ in range(n_threads)]

        # Distributed mode
        if self.search_settings.distributed and remote_worker_pool_id:
            n_remote_containers = max(0, self.search_settings.n_containers - 1)

            # Monitor the remote workers startup (for diagnostics)
            workers_startup_monitor = WorkersStartupMonitor()

            # Create 'n_containers' additional containers, each of them split into 'n_threads' threads
            for _ in range(n_remote_containers):
                remote_worker = RemoteWorkerClient(remote_worker_pool_id, workers_startup_monitor)

                # Note: split_worker() is no-op unless 'n_threads > 1'
                workers += WorkerSplitter.split_worker(remote_worker, n_threads)

            logger.info(
                "Distribute hyperparameter search using up to %s K8S container(s) with %s thread(s) per container",
                self.search_settings.n_containers, n_threads)

            scheduler = WorkScheduler(workers, search_context)
            scheduler.register_interrupt_callback(workers_startup_monitor.suspend)

        # Threaded mode
        else:
            logger.info("Execute hyperparameter search locally on %s threads", n_threads)

            scheduler = WorkScheduler(workers, search_context)

        return scheduler

    def initialize_search_context(self, *args, **kwargs):
        """Build ``self.search_context``; must be implemented by subclasses."""
        raise NotImplementedError("initialize_search_context is not implemented")

    @staticmethod
    def _select_best_result(aggregated_results, n_splits):
        """Return the aggregated result with the highest ``testScoreGibMean``, or None.

        Points whose ``testScoreGibMean`` is None are logged and ignored. On ties, the
        first point encountered is kept.

        :param aggregated_results: iterable of aggregated result dicts
        :param n_splits: number of folds, only used in the warning message
        """
        best_result = None
        for aggregated_result in aggregated_results:
            score = aggregated_result["testScoreGibMean"]
            if score is None:  # aggregated_result has already been through dku_nonaninf, so we check for None values
                logger.warning("Hyperparameter point {} has an invalid metric value on all {} folds".format(aggregated_result["parameters"], n_splits))
            elif best_result is None or score > best_result["testScoreGibMean"]:
                best_result = aggregated_result
        return best_result

    def get_best_estimator(self):
        """Run the hyperparameter search and return an estimator cloned with the best parameters.

        When the search is skipped, the estimator is cloned with the strategy's default
        parameters and no search is performed. Otherwise ``self.search_context`` must have
        been set by :meth:`initialize_search_context` beforehand.

        :raises ValueError: if the search produced no result, or if every explored point
                            has an invalid (None) metric value
        """
        if self.search_skipped():
            logger.info("Got single-point space, not performing hyperparameter search")
            default_parameters = self.search_settings.search_strategy.get_default_parameters()
            return self._trainable_model.clone_estimator(default_parameters)

        n_splits = self.search_context.n_splits
        nb_experiments = self.get_experiments_count()
        if nb_experiments is None:
            logger.info("Fitting {} folds for each candidate, for {}min".format(n_splits, self.search_settings.timeout))
        else:
            logger.info("Fitting {0} folds for each of {1} candidates, totalling"
                        " {2} fits".format(n_splits, nb_experiments, nb_experiments * n_splits))

        with self._build_work_scheduler(self.search_context) as scheduler:
            self.result_store.init_result_file(nb_experiments, scheduler.get_workers_count(), n_splits,
                                               self.evaluation_metric, self.search_settings.timeout)

            evaluator = SearchEvaluator(scheduler, self.search_context, self.result_store, self.search_settings.n_iter, self.monitor)

            with InterruptThread(scheduler, self.model_folder_context, self.search_settings.timeout, self.result_store):
                self.aggregated_results = self.search_settings.search_strategy.explore(evaluator)

                if len(self.aggregated_results) == 0:
                    raise ValueError("No results found during hyperparameter search, probably because all generated points were invalid")

                # Recomputed from scratch so that a best point left over from a previous
                # call can neither win nor hide the "all points invalid" situation
                self.best_result = self._select_best_result(self.aggregated_results, n_splits)

                if self.best_result is None:
                    raise ValueError("Searched hyperparameter points all have an invalid metric value (NaN)")

            self.result_store.save_current_progress()

        best_parameters = self.best_result["parameters"]
        logger.info('Hyperparameter search done, best_parameters being : {}'.format(best_parameters))
        return self._trainable_model.clone_estimator(best_parameters)

    def get_score_info(self):
        """Summarize the search as a dict (grid size, best score and per-point details).

        Safe to call when no search was run (skipped or not yet executed): the grid is
        then reported as empty and the best score as None.

        Fit/score times are divided by 1000 (presumably milliseconds to seconds; confirm
        against the producer of the aggregated results).
        """
        aggregated_results = self.aggregated_results if self.aggregated_results is not None else []
        best_score = self.best_result["testScoreMean"] if self.best_result is not None else None
        return {
            "usedGridSearch": not self.search_skipped(),
            "gridSize": len(aggregated_results),
            "gridBestScore": best_score,
            "gridCells": [{'params': er["parameters"],
                           'score': er["testScoreMean"], 'scoreStd': er["testScoreStd"],
                           'fitTime': er["fitTimeMean"] / 1000, 'fitTimeStd': er["fitTimeStd"] / 1000,
                           'scoreTime': er["scoreTimeMean"] / 1000, 'scoreTimeStd': er["scoreTimeStd"] / 1000}
                          for er in aggregated_results],
        }


class ClassicalSearchRunner(TabularSearchRunner):
    """Search runner that builds a ClassicalSearchContext from cross-validation splits."""

    def __init__(self, trainable_model, scoring, cv, search_settings, model_folder_context, evaluation_metric,
                 custom_evaluation_metric_gib):
        """
        :param scoring: scoring object or function, handed to the trainable model's get_scorer
        :param cv: cross-validation specification, resolved through check_cv
        :param custom_evaluation_metric_gib: decides the metric sign when `scoring` is a plain
                                             function without a `_sign` attribute
        """
        super(ClassicalSearchRunner, self).__init__(
            trainable_model, search_settings, model_folder_context, evaluation_metric)

        self.cv = cv
        self.scoring = scoring
        self.metric_sign = self.get_metric_sign(scoring, custom_evaluation_metric_gib)

    def initialize_search_context(self, X, y, groups=None, sample_weight=None, class_weight=None, monotonic_cst=None):
        """
        Configure the trainable model and, unless the search is skipped, build the search context.

        :param np.ndarray X: The train/test dataframe as a 2D numpy array
        :param pd.Series y: The target column
        :param pd.Series groups: The groups column, if group k-fold is being used, or None.
        :param pd.Series sample_weight: The sample weights, if these are being used, or None.
        :param dict class_weight: The class weights, if these are being used, or None.
        :param monotonic_cst: Forwarded as is to the trainable model's set_monotonic_cst.
        """
        model = self._trainable_model
        model.set_class_weight(class_weight)
        model.set_monotonic_cst(monotonic_cst)

        # Nothing to prepare when no search will take place
        if self.search_skipped():
            return

        X, y, groups, sample_weight = indexable(X, y, groups, sample_weight)
        splitter = check_cv(self.cv, y, classifier=model.is_classifier)
        folds = [Split(train_idx, test_idx) for train_idx, test_idx in splitter.split(X, y, groups)]

        self.search_context = ClassicalSearchContext(
            X, y, folds, sample_weight, model.get_scorer(self.scoring), self.metric_sign, model)

    @staticmethod
    def get_metric_sign(scoring, custom_evaluation_metric_gib):
        """Return the sign of the metric: the scorer's `_sign` when present, else a fallback.

        The fallback is 1, except for a plain function without `_sign`, where it is
        1 if `custom_evaluation_metric_gib` is truthy and -1 otherwise.
        """
        if hasattr(scoring, "_sign"):
            return getattr(scoring, "_sign")
        if inspect.isfunction(scoring):
            # custom scoring func, the scorer is wrapped, so no access to the sign
            return 1 if custom_evaluation_metric_gib else -1
        return 1


class TimeseriesForecastingSearchRunner(TabularSearchRunner):
    """Search runner that builds a TimeseriesForecastingSearchContext.

    The model scorer and the preprocessed external features are not constructor
    arguments: they default to None and are provided through the dedicated setters.
    """

    def __init__(
            self, trainable_model, search_settings, model_folder_context, split_handler, fit_before_predict,
            min_timeseries_size_for_training, evaluation_metric, metric_sign, skip_too_short_timeseries,
            custom_train_test_intervals):
        super(TimeseriesForecastingSearchRunner, self).__init__(
            trainable_model, search_settings, model_folder_context, evaluation_metric)

        # Values forwarded untouched to the search context
        self.split_handler = split_handler
        self.fit_before_predict = fit_before_predict
        self.min_timeseries_size_for_training = min_timeseries_size_for_training
        self.metric_sign = metric_sign
        self.skip_too_short_timeseries = skip_too_short_timeseries
        self.custom_train_test_intervals = custom_train_test_intervals

        # Filled in later through the setters below
        self.model_scorer = None
        self.preprocessed_external_features = None
        self.shift_map = None

    def set_preprocessed_external_features(self, preprocessed_external_features, shift_map):
        """Store the preprocessed external features and the shift map passed to the search context."""
        self.shift_map = shift_map
        self.preprocessed_external_features = preprocessed_external_features

    def set_model_scorer(self, model_scorer):
        """Store the model scorer passed to the search context."""
        self.model_scorer = model_scorer

    def initialize_search_context(self, df):
        """Build the search context from `df`, unless the search is skipped.

        :param df: data handed as is to TimeseriesForecastingSearchContext
        """
        if self.search_skipped():
            return

        self.search_context = TimeseriesForecastingSearchContext(
            df,
            self.split_handler,
            self.model_scorer,
            self.preprocessed_external_features,
            self.shift_map,
            self.fit_before_predict,
            self.min_timeseries_size_for_training,
            self.evaluation_metric,
            self.metric_sign,
            self._trainable_model,
            self.skip_too_short_timeseries,
            self.custom_train_test_intervals,
        )


class CausalSearchRunner(TabularSearchRunner):
    """Search runner that builds a CausalSearchContext, optionally with per-fold propensities."""

    def __init__(self, trainable_model, cv, search_settings, model_folder_context, causal_scorer,
                 causal_learning, evaluation_metric, propensity_settings, treatment_map=None):
        """
        :param cv: cross-validation specification, resolved through check_cv
        :param causal_scorer: forwarded to the search context
        :param causal_learning: forwarded to the search context
        :param dict propensity_settings: read for "calibrateProbabilities" and "calibrationDataRatio"
        :param treatment_map: optional mapping of treatments; its number of items sets the
                              number of propensity columns (2 columns when None)
        """
        super(CausalSearchRunner, self).__init__(
            trainable_model, search_settings, model_folder_context, evaluation_metric)

        self.cv = cv
        self.causal_scorer = causal_scorer
        self.causal_learning = causal_learning
        self.propensity_settings = propensity_settings
        self.treatment_map = treatment_map

    def initialize_search_context(self, X, y, t, groups=None, compute_propensity=False):
        """
        Build the causal search context, unless the search is skipped.

        :param np.ndarray X: The train/test dataframe as a 2D numpy array
        :param pd.Series y: The target column
        :param pd.Series t: The treatment column
        :param pd.Series groups: The groups column, if group k-fold is being used, or None.
        :param bool compute_propensity: Whether to compute propensities on each fold
                                        (None is passed to the context otherwise)
        """
        if self.search_skipped():
            return

        X, y, t = indexable(X, y, t)
        splitter = check_cv(self.cv, y, classifier=self._trainable_model.is_classifier)
        folds = [Split(train_idx, test_idx) for train_idx, test_idx in splitter.split(X, y, groups=groups)]

        propensities = None
        if compute_propensity:
            logger.info("Computing propensity scores on the train set for hyperparameter search")
            propensities = self._build_propensities(X, t, folds)

        # Current causal metrics are all greater-is-better
        metric_sign = 1
        self.search_context = CausalSearchContext(
            X, y, t, propensities, folds, self.causal_scorer, self.causal_learning,
            metric_sign, self._trainable_model, self.treatment_map)

    def _build_propensities(self, X, t, splits):
        """Return an array of propensities with one row per entry of `t`.

        For each split, a propensity model is trained on the train indices and its
        predicted probabilities fill the rows of the test indices. Rows that belong
        to no test fold keep their initial value of 1.
        """
        settings = self.propensity_settings
        calibrate_proba = settings["calibrateProbabilities"]
        calibration_data_ratio = settings["calibrationDataRatio"]

        if self.treatment_map is None:
            n_treatments = 2
        else:
            n_treatments = len(self.treatment_map.items())

        propensities = np.ones((t.shape[0], n_treatments))
        from dataiku.doctor.causal.utils.models import train_propensity_model
        for fold in splits:
            fold_model = train_propensity_model(
                X[fold.train], t[fold.train], calibrate_proba, calibration_data_ratio)
            propensities[fold.test, :] = fold_model.predict_proba(X[fold.test])
        return propensities


class InterruptThread(Thread):
    """Watchdog thread that softly interrupts the scheduler when the search must stop.

    Once started, it polls about once per second and stops watching on the first of:
    - ``watching`` being set to False (done by ``__exit__``),
    - the optional timeout being exceeded,
    - the stop-search file being present in the model folder.
    In all cases ``scheduler.interrupt_soft()`` is then called.

    Intended to be used as a context manager around the search, but it can also be
    started directly with ``start()``.
    """

    def __init__(self, scheduler, model_folder_context, timeout, result_store):
        """
        :type scheduler: dataiku.doctor.distributed.work_scheduler.WorkScheduler
        :type model_folder_context: dataiku.base.folder_context.FolderContext or None
        :param timeout: optional timeout in minutes
        :type timeout: int or None
        :type result_store: dataiku.doctor.crossval.result_store.AbstractResultStore
        """
        super(InterruptThread, self).__init__()
        self.scheduler = scheduler
        self._model_folder_context = model_folder_context
        self.timeout = timeout
        self.result_store = result_store
        self.watching = True
        # Armed by __enter__ (or lazily by run()). Defined here so that run() never
        # fails on a missing attribute when the thread is started without 'with'.
        self.planned_end_time_ms = None

    def run(self):
        # When started directly (no __enter__), the timeout counts from now.
        # No-op after __enter__: the deadline is already set, or there is no timeout.
        if self.planned_end_time_ms is None:
            self.planned_end_time_ms = self._get_planned_end_time_ms()

        while True:
            if not self.watching:
                logger.info('Completed search for hyperparameters')
                self.result_store.update_final_grid_size()
                break

            if self.planned_end_time_ms is not None and unix_time_millis() > self.planned_end_time_ms:
                logger.info('Aborting search for hyperparameters (timeout)')
                break

            # Note: make sure must_interrupt() isn't called too often, because it might be slow in container mode
            if (
                    self._model_folder_context is not None
                    and self._model_folder_context.isfile(doctor_constants.STOP_SEARCH_FILENAME, allow_cached=False)
            ):
                logger.info('Aborting search for hyperparameters (user)')
                break

            time.sleep(1)

        # Shutdown the scheduler to smoothly interrupt the search
        # (wait for current work to complete, reject new work)
        self.scheduler.interrupt_soft()

    def _get_planned_end_time_ms(self):
        """Return the deadline (current time + timeout minutes, in ms), or None without timeout."""
        if self.timeout is None:
            return None
        return unix_time_millis() + self.timeout * 60 * 1000

    def __enter__(self):
        """Arm the deadline and start the watchdog thread."""
        self.planned_end_time_ms = self._get_planned_end_time_ms()
        self.start()
        return self

    def __exit__(self, *_):
        """Ask the watchdog to stop and wait for it to terminate."""
        self.watching = False
        self.join()


class SearchSettings(object):
    """Plain holder for the hyperparameter search parameters consumed by the search runners.

    Args:
        search_strategy (AbstractSearchStrategy): One of GridSearchStrategy/RandomSearchStrategy/BayesianSearchStrategy.
        n_threads (int): Number of threads; a non-positive value (e.g. -1) means auto.
        distributed (boolean): Distribute search over Kubernetes containers.
        n_containers (int): Number of Kubernetes containers.
        n_iter (int): Maximum number of hyperparameter combinations to explore.
        timeout (int): Time limit in minutes.

    NOTE(review): the runners in this module treat ``None`` (not 0) as "no limit" for
    both ``n_iter`` and ``timeout`` -- confirm that callers normalize 0 to None.
    """

    def __init__(self, search_strategy, n_threads, distributed, n_containers, n_iter, timeout):
        # What to explore
        self.search_strategy = search_strategy

        # Search limits
        self.n_iter = n_iter
        self.timeout = timeout

        # Parallelism
        self.n_threads = n_threads
        self.distributed = distributed
        self.n_containers = n_containers
