import logging
from abc import ABCMeta, abstractmethod

import six
import torch
from sklearn.model_selection import train_test_split

from dataiku.base.folder_context import build_folder_context
from dataiku.core import doctor_constants
from dataiku.doctor import step_constants
from dataiku.doctor.deephub.computer_vision import LearningRateStrategy
from dataiku.doctor.deephub.computer_vision import init_optimizer
from dataiku.doctor.diagnostics import diagnostics
from dataiku.doctor import utils
from dataiku.doctor.deephub.deephub_model import DeepHubModel
from dataiku.doctor.deephub.deephub_model import DeepHubModelHandler
from dataiku.doctor.deephub.utils.data_accumulator import PredictedDataAccumulator
from dataiku.doctor.deephub.deephub_logger import DeepHubModelTrainingInfoHandler
from dataiku.doctor.deephub.deephub_logger import DeephubLogger
from dataiku.doctor.deephub.deephub_context import get_deephub_context, with_enforced_not_distributed_context
from dataiku.doctor.deephub.utils.deephub_registry import DeepHubRegistry
from dataiku.doctor.diagnostics.diagnostics import DiagnosticsScoringResults
from dataiku.doctor.utils import unix_time_millis
from dataiku.doctor.utils.listener import NoOpContext
from dataiku.doctor.utils.listener import ModelStatusContext
from dataiku.doctor.utils.listener import ProgressListener
from dataiku.doctor.utils.split import df_from_split_desc_no_normalization

logger = logging.getLogger(__name__)


class DeepHubTrainingHandler(object):
    """
    Responsible for the full training of a Deephub model:
        * initializing required components, in particular:
            * the training context
            * the engine used for the training
        * leveraging the engine to run the training/evaluation loop
        * saving all relevant information for real time (model information) and future use (performance, model)
    """

    def __init__(self, params, files_reader):
        """
        Read the training settings from `params` and build the base model and the training engine.

        :type params: dataiku.doctor.deephub.deephub_params.DeepHubTrainingParams
        :type files_reader: dataiku.doctor.deephub.utils.file_utils.FilesReader
        """
        self._params = params
        self.split_desc = params.split_desc

        self.model_folder_context = params.model_folder_context
        self.preprocessing_folder_context = build_folder_context(params.preprocessing_folder)
        self.split_folder_context = build_folder_context(params.split_folder)
        self.modeling_params = params.modeling_params
        self.metric_params = params.modeling_params["metrics"]

        self.evaluation_metric = self.metric_params["evaluationMetric"]
        self.num_epochs = self.modeling_params["epochs"]
        self.batch_size = params.batch_size
        # No data loading worker unless parallel data loading is explicitly enabled
        self.num_workers = 0
        if self.modeling_params["enableParallelDataLoading"]:
            self.num_workers = self.modeling_params["numWorkers"]

        self.files_reader = files_reader

        # Actual params of the model after model optimisation (contain final value of LR, completed epochs..).
        # Will be dumped into a json file then parsed back in Java code with the corresponding class among the
        # children of DeepHubPostTrainModelingParams - depending on the 'type' field
        self.resolved_params = {
            "type": "",
            "completedEpochs": 0,
            "finalLR": None,
            "keptModelEpoch": -1,
            "trainValidationSplitRatio": params.get_model_optimization_split_params()["trainSplitRatio"]
        }
        self.base_model = DeepHubModel.build(params.core_params["prediction_type"], params.target_remapping,
                                             self.modeling_params)
        self.engine = DeepHubTrainingEngine.build(params.core_params["prediction_type"],
                                                  params.target_remapping, self.modeling_params,
                                                  params.core_params["target_variable"],
                                                  params.file_path_column)

    def get_resolved_params(self):
        """
        Retrieve parameters that have been resolved during model or engine building.

        The handler's own resolved params are copied first, then overridden by the ones of the engine, then by the
        ones of the base model. `self.resolved_params` itself is left untouched.

        :rtype: dict
        """
        all_resolved_params = self.resolved_params.copy()
        all_resolved_params.update(self.engine.get_resolved_params())
        all_resolved_params.update(self.base_model.get_resolved_params())
        return all_resolved_params

    def train(self):
        """
        Entry point of the training: run the train/validation loop and the final test evaluation, then (main process
        only) write the resolved params, an empty preprocessing report and the "done" train info.
        In a distributed context, the process group is destroyed once everything has completed.
        """
        training_context = get_deephub_context()
        start = unix_time_millis()
        if training_context.is_main_process():
            listener = ProgressListener(context=ModelStatusContext(self.model_folder_context, start))
        else:
            # Other processes track steps silently, with a no-op context
            listener = ProgressListener(context=NoOpContext(None), verbose=False)
        listener.add_future_steps(step_constants.DEEPHUB_TRAIN_STEPS)

        self._train_evaluate_and_save(listener)

        if training_context.is_main_process():
            end = unix_time_millis()
            # Resolved params hold the real nb of epochs, real nb of trainable layers and all
            self.model_folder_context.write_json("actual_params.json", {"resolved": self.get_resolved_params()})
            self.preprocessing_folder_context.write_json("preprocessing_report.json", {})
            utils.write_done_traininfo(self.model_folder_context, start, start, end, listener.to_jsonifiable(),
                                       end_preprocessing_time=start)
        if training_context.distributed:
            training_context.destroy_process_group()

    def _is_best_perf(self, new_value_gib, best_perf):
        """
        :param new_value_gib: new value of the evaluation metric, in its "greater is better" form
        :type new_value_gib: float
        :type best_perf: dataiku.doctor.deephub.deephub_evaluation.PerformanceResults
        :return: True if we have no best_perf stored or the new value is strictly better than the stored one
        :rtype: bool
        """
        if best_perf is None:
            return True
        best_value_gib = best_perf.get_metric(self.evaluation_metric).value_gib
        return new_value_gib > best_value_gib

    @staticmethod
    def _set_sampler_epoch(data_loader, epoch):
        """
        Tell the sampler of `data_loader` which epoch is starting, when it is a `DistributedSampler`.

        A `DistributedSampler` derives its shuffling from its seed and its current epoch: without this call, the
        very same ordering would be used at every epoch. Setting the same epoch several times is harmless.
        Nothing is done for any other kind of sampler.

        :type data_loader: torch.utils.data.DataLoader
        :type epoch: int
        """
        sampler = data_loader.sampler
        if isinstance(sampler, torch.utils.data.distributed.DistributedSampler):
            sampler.set_epoch(epoch)

    def _train_evaluate_and_save(self, listener):
        """
        Run the training loop, then the evaluation on the test set.

        For each epoch, on all processes: train for one epoch, predict on the validation set and gather the
        accumulated predictions. Then, on the main process only: compute the validation performance, save the model
        if it is the best so far and update the early stopping status & resolved params.
        The loop is exited before `self.num_epochs` epochs if the early stopping handler says so.

        :type listener: dataiku.doctor.utils.listener.ProgressListener
        """
        model_handler = DeepHubModelHandler.build_for_pretrained_model(self.model_folder_context, self.base_model)

        with listener.push_step(step_constants.ProcessingStep.STEP_LOADING_TRAIN):
            train_data_loader, validation_data_loader = self._build_train_data_loaders()

        with listener.push_step(step_constants.ProcessingStep.DEEPHUB_TRAINING_LOOP):
            train_info_handler = DeepHubModelTrainingInfoHandler(len(train_data_loader), len(validation_data_loader),
                                                                 self.num_epochs,
                                                                 self.evaluation_metric,
                                                                 self.model_folder_context)

            training_context = get_deephub_context()
            early_stopping_handler = DeephubEarlyStoppingHandler(self.modeling_params["earlyStopping"],
                                                                 self.model_folder_context)
            self.engine.init_training_params(model_handler.nn_model)

            best_validation_perf = None
            train_info_handler.start_train()
            self.engine.on_train_start(len(train_data_loader))

            for epoch in range(self.num_epochs):
                # DeephubLogger will keep a running average of meters across training, make sure to have independent
                # epoch meters by reinitializing the logger at each epoch
                train_logger = DeephubLogger("Train", iteration_callback=train_info_handler.start_train_step)
                validation_logger = DeephubLogger("Validation", iteration_callback=train_info_handler.start_validation_step)
                train_info_handler.start_epoch(epoch)

                # In a distributed context, make the training sampler reshuffle the data for this epoch
                self._set_sampler_epoch(train_data_loader, epoch)
                self.engine.train_one_epoch(epoch, model_handler.nn_model, model_handler.device,
                                            train_data_loader, train_logger)

                validation_data_accumulator = PredictedDataAccumulator()
                self.engine.predict_and_accumulate(model_handler.nn_model, model_handler.device, validation_data_loader,
                                                   validation_logger, validation_data_accumulator, epoch=epoch)
                validation_data_accumulator.gather()

                # Stays empty on processes that are not the main one
                epoch_meters = {}
                if training_context.is_main_process():

                    epoch_meters["trainLoss"] = train_logger.get_meter("loss").get_global_avg_value()
                    if not validation_data_accumulator.has_accumulated_values():
                        raise Exception("Validation did not accumulate any value, "
                                        "validation set empty or without any target ?")

                    logger.info("Start computing validation performance")
                    validation_perf = self.engine.compute_performance(validation_data_accumulator,
                                                                      validation_data_loader.dataset.original_df.index,
                                                                      self.metric_params)
                    validation_score = validation_perf.get_metric(self.evaluation_metric)
                    # "validation" is improperly called "test" in the result file
                    epoch_meters["testScore"] = validation_score.value_gib
                    epoch_meters["testLoss"] = validation_logger.get_meter("loss").get_global_avg_value()
                    logger.info("Finished computing evaluation performance: {}: {}"
                                .format(self.evaluation_metric, validation_score.value))

                    # NOTE(review): only the main process notifies the engine of the end of the epoch. If the LR
                    # strategy modifies the optimizer at this point, confirm that the other processes of a
                    # distributed training are kept consistent elsewhere.
                    self.engine.on_epoch_end(epoch, validation_score.value_gib)
                    # Decide on whether to save model
                    if self._is_best_perf(validation_score.value_gib, best_validation_perf):
                        logger.info("Best model so far for epoch #{}, saving it".format(epoch))
                        best_validation_perf = validation_perf
                        model_handler.save()
                        epoch_meters["saved"] = True
                        self.resolved_params["keptModelEpoch"] = epoch

                    early_stopping_handler.update_early_stop_status(validation_score)
                    self.resolved_params["completedEpochs"] = epoch + 1
                    self.resolved_params["finalLR"] = train_logger.get_meter("lr").get_global_current_value()

                train_info_handler.end_epoch(epoch_meters)
                if early_stopping_handler.synchronized_early_stop():
                    # at least one process signaled the need to early stop, stopping training for all
                    break

        # The evaluation on the test set is only run by the main process, outside of the distributed context
        if training_context.is_main_process():
            with with_enforced_not_distributed_context():
                self._compute_test_performance(listener)

    def _compute_test_performance(self, listener):
        """
        Load the test set, score it with the model saved on disk, then save the performance results and run the
        scoring diagnostics. Must not be called in a distributed context.

        :type listener: dataiku.doctor.utils.listener.ProgressListener
        :raises Exception: if no value was accumulated while predicting on the test set
        """
        assert not get_deephub_context().distributed

        with listener.push_step(step_constants.ProcessingStep.STEP_LOADING_TEST):
            test_df = self._load_df("test")
            diagnostics.on_load_test_dataset_end(prediction_type=self._params.core_params["prediction_type"],
                                                 df=test_df,
                                                 target_variable=self._params.core_params["target_variable"])
            test_dataset = self._build_dataset("test", test_df, for_eval=True)
            test_sampler = torch.utils.data.SequentialSampler(test_dataset)
            test_data_loader = self._build_data_loader(test_dataset, test_sampler)

        with listener.push_step(step_constants.ProcessingStep.STEP_SCORING):
            # Needs to reload best model from disk, as it might not be the last one
            best_model_handler = DeepHubModelHandler.build_for_scoring(self.model_folder_context, self.base_model)
            test_logger = DeephubLogger("Test")
            test_data_accumulator = PredictedDataAccumulator()
            self.engine.predict_and_accumulate(best_model_handler.nn_model, best_model_handler.device, test_data_loader,
                                               test_logger, test_data_accumulator)
            test_data_accumulator.gather()

            if not test_data_accumulator.has_accumulated_values():
                raise Exception("Test did not accumulate any value, test set empty or without any target ?")

            logger.info("Start computing test performance")
            test_perf = self.engine.compute_performance(test_data_accumulator,
                                                        test_data_loader.dataset.original_df.index,
                                                        self.metric_params)
            self.save_perf(test_perf)

            scoring_results = DiagnosticsScoringResults(prediction_type=self.engine.TYPE,
                                                        metrics_params=self.metric_params,
                                                        perf_data=test_perf.perf,
                                                        test_predictions=test_perf.raw_predictions)
            diagnostics.on_scoring_end(scoring_results=scoring_results)

    def _load_df(self, dataset_name):
        """
        :param dataset_name: name of the dataset of the split to load ("train" or "test" in this class)
        :type dataset_name: str
        :return: the result of `df_from_split_desc_no_normalization` for this dataset
        """
        return df_from_split_desc_no_normalization(self.split_desc, dataset_name, self.split_folder_context,
                                                   self._params.preprocessing_params["per_feature"],
                                                   self._params.core_params["prediction_type"])

    def _build_dataset(self, dataset_name, df, for_eval):
        """
        Build the dataset for `df` with the engine, and check that it is not empty.

        :param dataset_name: only used in the error message
        :type dataset_name: str
        :type for_eval: bool
        :raises Exception: if the built dataset is empty
        """
        dataset = self.engine.build_dataset(df, self.files_reader, self.base_model, for_eval=for_eval)
        if len(dataset) == 0:
            raise Exception("All rows of {} dataset are dropped, cannot complete training".format(dataset_name))
        return dataset

    def _build_data_loader(self, dataset, sampler):
        """
        Build a data loader over `dataset`, using the batch size and number of workers of this handler and the
        collate function of the dataset.

        :type sampler: torch.utils.data.Sampler
        :rtype: torch.utils.data.DataLoader
        """
        # A multiprocessing context is only requested when worker processes are actually used
        multiprocessing_context = None
        if self.num_workers > 0:
            multiprocessing_context = get_deephub_context().get_dataloader_multiprocessing_context()
        logger.info("Using {} worker(s) for data loading".format(self.num_workers))
        return torch.utils.data.DataLoader(dataset, batch_size=self.batch_size,
                                           num_workers=self.num_workers, sampler=sampler,
                                           collate_fn=dataset.data_loader_collate_fn,
                                           multiprocessing_context=multiprocessing_context)

    def _build_train_data_loaders(self):
        """
        Load the "train" dataset, split it into a train and a validation part, and build one data loader for each.

        :return: (train data loader, validation data loader)
        :rtype: tuple
        """
        # Deephub only accepts train/test scheme, so we are sure to have "train" and "test" datasets
        full_train_df = self._load_df("train")

        if get_deephub_context().is_main_process():
            diagnostics.on_load_train_dataset_end(prediction_type=self._params.core_params["prediction_type"],
                                                  df=full_train_df,
                                                  target_variable=self._params.core_params["target_variable"])

        model_optim_split_params = self._params.get_model_optimization_split_params()
        # Need to use a random state to get exact same split on distributed context
        train_df, validation_df = train_test_split(full_train_df,
                                                   train_size=model_optim_split_params["trainSplitRatio"],
                                                   random_state=model_optim_split_params["seed"])

        train_dataset = self._build_dataset("train", train_df, for_eval=False)
        validation_dataset = self._build_dataset("validation", validation_df, for_eval=True)

        training_context = get_deephub_context()

        if training_context.distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
            validation_sampler = torch.utils.data.distributed.DistributedSampler(validation_dataset)
        else:
            train_sampler = torch.utils.data.RandomSampler(train_dataset)
            validation_sampler = torch.utils.data.SequentialSampler(validation_dataset)

        train_data_loader = self._build_data_loader(train_dataset, train_sampler)
        validation_data_loader = self._build_data_loader(validation_dataset, validation_sampler)

        return train_data_loader, validation_data_loader

    def save_perf(self, test_results):
        """
        Write in the model folder: the performance ("perf.json"), the prediction statistics
        ("prediction_statistics.json") and the predicted data ("predicted.csv", tab separated, with header).

        :type test_results: dataiku.doctor.deephub.deephub_evaluation.PerformanceResults
        """
        perf = test_results.to_dict()
        self.model_folder_context.write_json("perf.json", perf)
        self.model_folder_context.write_json("prediction_statistics.json",
                                             test_results.prediction_statistics)
        with self.model_folder_context.get_file_path_to_write("predicted.csv") as predicted_file_path:
            pred_df = test_results.get_predicted_data()
            pred_df.to_csv(predicted_file_path, sep="\t", header=True,
                           index=False, encoding='utf-8')


@six.add_metaclass(ABCMeta)
class DeepHubTrainingEngine(object):
    """
    Abstract description of how a training is carried out. Concrete engines provide:
        * the construction (and preprocessing) of the datasets
        * the training loop of one epoch and the prediction loop
        * the computation of the performance metrics

    Concrete engines are made available to `build` by registering them with `define`.
    """
    TYPE = "DEEP_HUB_ENGINE"
    REGISTRY = DeepHubRegistry()
    DUMMY = False

    def __init__(self, target_remapping, modeling_params, target_variable, file_path_col):
        # Both are created by `init_training_params`, once the model is known, and then used by the training loop
        self.optimizer = None
        self.lr_scheduler_strategy = None

        self.file_path_col = file_path_col
        self.target_variable = target_variable
        self.modeling_params = modeling_params
        self.target_remapping = target_remapping

    @staticmethod
    def build(prediction_type, target_remapping, modeling_params, target_variable, file_path_col):
        """
        Instantiate the engine registered for `prediction_type` (and for the "dummy" flag of the modeling params,
        False when absent).

        :raises ValueError: if the registry lookup fails with a KeyError
        :rtype: DeepHubTrainingEngine
        """
        is_dummy = modeling_params.get("dummy", False)
        try:
            engine_class = DeepHubTrainingEngine.REGISTRY.get(prediction_type, is_dummy)
        except KeyError:
            raise ValueError("Unknown training engine: {} (dummy={})".format(prediction_type, is_dummy))
        else:
            return engine_class(target_remapping, modeling_params, target_variable, file_path_col)

    @staticmethod
    def define(training_class):
        """ Register `training_class` in the registry, under its `TYPE` and `DUMMY` class attributes """
        engine_type = training_class.TYPE
        is_dummy = training_class.DUMMY
        DeepHubTrainingEngine.REGISTRY.register(engine_type, is_dummy, training_class)

    @abstractmethod
    def build_dataset(self, df, files_reader, model, for_eval=False):
        """
        Build the torch dataset on top of `df`.

        :type df: pd.DataFrame
        :type files_reader: dataiku.doctor.deephub.utils.file_utils.FilesReader
        :type model: DeepHubModel
        :param for_eval: whether the dataset is meant to be used for evaluation
        :type for_eval: bool
        :rtype: dataiku.doctor.deephub.deephub_torch_datasets.DeepHubDataset
        """
        raise NotImplementedError()

    def init_training_params(self, model):
        """
        Create the optimizer and the learning rate strategy from the modeling params.

        :param model: torch model
        """
        logger.info("Initializing training params with modeling params: {}".format(self.modeling_params))
        optimizer = init_optimizer(model, self.modeling_params)
        self.optimizer = optimizer
        self.lr_scheduler_strategy = LearningRateStrategy(self.modeling_params, optimizer)

    def on_train_start(self, num_batches_per_epoch):
        """
        Forward the start of the training to the learning rate strategy.

        :type num_batches_per_epoch: int
        """
        self.lr_scheduler_strategy.on_train_start(num_batches_per_epoch)

    @abstractmethod
    def train_one_epoch(self, epoch, model, device, train_data_loader, deephub_logger):
        """
        Train `model` for one epoch.

        :type epoch: int
        :param model: torch model
        :param device: device on which model is hosted
        :type train_data_loader: torch.utils.data.DataLoader
        :type deephub_logger: dataiku.doctor.deephub.deephub_logger.DeephubLogger
        """
        raise NotImplementedError()

    def on_epoch_end(self, epoch_index, val_metric):
        """
        Everything that must happen once an epoch is over: here, notifying the learning rate strategy.

        :type epoch_index: int
        :type val_metric: float
        """
        self.lr_scheduler_strategy.on_epoch_end(epoch_index, val_metric)

    @abstractmethod
    def predict_and_accumulate(self, model, device, data_loader, deephub_logger, data_accumulator, epoch=None):
        """
        Go through the data loader, run the predictions and accumulate the data required to compute the
        performance afterwards.

        :param model: torch model
        :param device: device on which model is hosted
        :type data_loader: torch.utils.data.DataLoader
        :type deephub_logger: dataiku.doctor.deephub.deephub_logger.DeephubLogger
        :type data_accumulator: dataiku.doctor.deephub.utils.data_accumulator.PredictedDataAccumulator
        :type epoch: int | None
        """
        raise NotImplementedError()

    @abstractmethod
    def compute_performance(self, data_accumulator, origin_index, metric_params):
        """
        Compute the performance from the accumulated data.

        :type data_accumulator: dataiku.doctor.deephub.utils.data_accumulator.PredictedDataAccumulator
        :type origin_index: pandas.core.indexes.range.RangeIndex
        :type metric_params: dict
        :rtype: dataiku.doctor.deephub.deephub_evaluation.PerformanceResults
        """
        raise NotImplementedError()

    def get_resolved_params(self):
        """
        :return: {param_name : param_value} for each param resolved while building the engine
        :rtype: dict
        """
        resolved_params = {}
        resolved_params['type'] = self.TYPE
        return resolved_params


class DeephubEarlyStoppingHandler(object):
    """
    Decides, at the end of each epoch, whether the training loop must be exited. Two signals are handled:
        * a manual one: the `doctor_constants.STOP_SEARCH_FILENAME` file is present in the run folder
        * an automatic one (when enabled): the score did not significantly improve for `patience` epochs

    In a distributed context, the decision is shared between all processes, so that they all leave the training
    loop at the end of the same epoch.
    """

    def __init__(self, params, run_folder_context):
        """
        :param params: early stopping params, with keys "enabled", "minDelta" and "patience"
        :type params: dict
        :type run_folder_context: dataiku.base.folder_context.FolderContext
        """
        self.auto_stop_enabled = params["enabled"]  # whether to apply automatic early stopping. if set to false, training can only be early-interrupted manually (suspend option).
        self.min_delta = params["minDelta"]  # minimal value for a score change to be considered significant
        self.patience = params["patience"]  # number of epochs to wait without significant improvement

        self.epochs_since_last_significant_improvement = 0
        self.best_significant_score = None

        self.training_context = get_deephub_context()
        device = self.training_context.get_device()
        # Sentinel tensors, hosted on the training device so that they can take part in collective operations
        self.stop_status = torch.tensor([0]).to(device)
        self.continue_status = torch.tensor([1]).to(device)

        # When used in distributed training, main process can change the early_stop_status to STOP_STATUS, which will be
        # used during synchronisation step to stop all the other processes
        self.early_stop_status = self.continue_status

        self._run_folder_context = run_folder_context

    def update_early_stop_status(self, current_score):
        """
            Compute whether processes must stop at the end of current epoch depending on current score &
            past values of this score.
            Note that this function is called only from main process when used in distributed training.

            updates self.early_stop_status to self.stop_status if the training loop should stop (if too many epochs were spent without significant
            improvement to score) - self.early_stop_status stays equal to self.continue_status otherwise.
            Does nothing when automatic early stopping is disabled.

            :type current_score: dataiku.doctor.deephub.deephub_evaluation.SignedMetric
        """
        if self.auto_stop_enabled:
            # An improvement is significant when it beats the best significant score by more than min_delta
            if self.best_significant_score is None or current_score.value_gib > self.best_significant_score + self.min_delta:
                self.best_significant_score = current_score.value_gib
                self.epochs_since_last_significant_improvement = 0
            else:
                self.epochs_since_last_significant_improvement += 1
                logger.info("Score has not improved this epoch, got {:.4f}. Best eval score so far: {:.4f}. "
                            "last improvement was {} epoch(s) ago"
                            .format(current_score.value,
                                    self.best_significant_score * current_score.sign,
                                    self.epochs_since_last_significant_improvement
                                    )
                            )
            if self.epochs_since_last_significant_improvement >= self.patience:
                logger.info("{} epoch(s) since last significant improvement in score, early stopping training."
                            .format(self.epochs_since_last_significant_improvement))
                self.early_stop_status = self.stop_status

    def synchronized_early_stop(self):
        """ This function is called by all processes, whether they are main processes or not (in distributed context)
            return True if current process should exit training loop, False if it can continue.

            In a distributed context, every process always takes part in exactly one collective operation per call,
            and they all get the same answer: stop as soon as at least one process wants to stop.
        """
        manual_stop = self._run_folder_context.isfile(doctor_constants.STOP_SEARCH_FILENAME, allow_cached=False)
        if manual_stop:
            # There was a manual action to early stop the training, all the workers must stop.
            logger.info("Manual interruption signal received, early stopping training.")

        auto_stop = self.auto_stop_enabled and torch.equal(self.early_stop_status, self.stop_status)
        local_stop = bool(manual_stop or auto_stop)

        if not self.training_context.distributed:
            return local_stop

        # Distributed context: processes do not reach this point at the same time, hence they may disagree on the
        # presence of the manual stop file, and only the main process knows about the automatic stop. A process
        # deciding on its own would leave the others blocked in their next collective operation, so the local
        # decisions are combined instead. Stop is encoded as 0 and continue as 1: the minimum over all processes
        # is "stop" as soon as one of them wants to stop.
        # Work on a copy: the collective operation writes its result in place.
        decision = (self.stop_status if local_stop else self.continue_status).clone()
        torch.distributed.all_reduce(decision, op=torch.distributed.ReduceOp.MIN)
        return torch.equal(decision, self.stop_status)
