import logging
import numpy as np
import torch
import torch.distributed as dist
import time

from dataiku.doctor.deephub.deephub_context import get_deephub_context
from dataiku.doctor.utils.model_training_info import ModelTrainingInfoHandler
from dataiku.doctor.utils import unix_time_millis

logger = logging.getLogger(__name__)


class LoggedMeter(object):
    """
    Track the values of a single metric (e.g. loss, lr) recorded while iterating over data.

    Values are accumulated locally with `update`. `synchronize` then computes the "global" values: averaged across all
    workers in distributed mode, or equal to the local ones otherwise.
    """

    def __init__(self):
        # Keep latest value of a meter, averaged across workers if more than one (distributed)
        self._current_local_value = 0.0
        self._current_global_value = 0.0

        # Keep the accumulated sum of a meter's values and the number of values accumulated so far, to be able to
        # return at any time the running averaged value (across all workers if more than one (distributed))
        self._nb_accumulated = 0
        self._local_summed_values = 0.0
        self._global_averaged_value = 0.0

    def update(self, value):
        """
        Record a new local value for this meter.

        :param value: latest value of the meter on this worker
        """
        self._current_local_value = value
        self._nb_accumulated += 1
        self._local_summed_values += value

    def synchronize(self):
        """
        Synchronize all workers (distributed mode) to update the global values according to all the workers' local
        values.

        In distributed mode this performs a `dist.all_reduce` collective call.
        """
        training_context = get_deephub_context()

        # NOTE(review): this early return skips the all_reduce below; in distributed mode it assumes that either all
        # workers or none have recorded a value for this meter - confirm.
        if self._nb_accumulated == 0:
            return  # no value to synchronize

        if not training_context.distributed:
            self._current_global_value = self._current_local_value
            self._global_averaged_value = self._local_summed_values / self._nb_accumulated
            return

        # Synchronize current values, summed values and counts in a single collective call: all_reduce sums each
        # entry with the corresponding entries of the other workers
        to_reduce = torch.tensor([self._current_local_value, self._local_summed_values, self._nb_accumulated],
                                 dtype=torch.float64, device=training_context.get_device())
        dist.all_reduce(to_reduce)
        summed_current_values, summed_values, summed_counts = to_reduce.tolist()

        # Mean of the workers' latest values
        self._current_global_value = summed_current_values / training_context.world_size
        # Mean of every value accumulated by every worker
        self._global_averaged_value = summed_values / summed_counts

    def get_global_current_value(self):
        """
        Return the last synchronized current value of a meter across the running workers, or None if no value was
        recorded yet.

        Warning: must be recomputed using synchronize to stay up to date when new updates are made on the different
        workers.
        """
        if self._nb_accumulated == 0:
            return None
        return self._current_global_value

    def get_global_avg_value(self):
        """
        Return the last synchronized accumulated (averaged) value of a meter across the running workers, or None if
        no value was recorded yet.

        Warning: must be recomputed using synchronize to stay up to date when new updates are made on the different
        workers.
        """
        if self._nb_accumulated == 0:
            return None
        return self._global_averaged_value

    def format_values(self):
        """ Return a string ready to be logged from the meter values, or "-" if no value was recorded yet """
        if self._nb_accumulated == 0:
            return "-"
        return "{:.2e} (avg={:.2e})".format(self.get_global_current_value(), self.get_global_avg_value())


class RedrawIterator(object):
    """
    Iterator over a data loader that replaces empty (None) batches with batches drawn from a backup data loader.

    The backup data loader reuses the dataset, batch size, sampler and collate function of the original one, loads
    data with num_workers=0, and gets a randomly drawn "fake" epoch set on its sampler (the epoch only has an effect
    when the sampler is a DistributedSampler).
    """

    def __init__(self, data_loader, random_state, max_empty_batch_retries=100):
        """
        :param data_loader: initial data loader on which we potentially want to generate backup data
        :type data_loader: torch.utils.data.DataLoader
        :param random_state: random generator used to draw the fake epoch of the backup data loader
        :type random_state: numpy.random.RandomState
        :param max_empty_batch_retries: maximum number of empty backup batches tolerated before aborting (the counter
                                        is never reset, it spans all iterations of this object)
        :type max_empty_batch_retries: int
        """
        self._data_loader = data_loader
        self._iterator = None
        self._backup_iterator = None
        self._retries = 0
        self._random_state = random_state
        self._max_empty_batch_retries = max_empty_batch_retries

    def _init_backup_iterator(self):
        """ Create the backup iterator, with a randomly drawn epoch set on the sampler """
        backup_data_loader = torch.utils.data.DataLoader(self._data_loader.dataset, batch_size=self._data_loader.batch_size,
                                                         num_workers=0, sampler=self._data_loader.sampler,
                                                         collate_fn=self._data_loader.collate_fn)
        fake_epoch = self._random_state.randint(1000)
        # NOTE(review): the sampler object is shared with the original data loader, so this also changes the epoch
        # set on the original loader's sampler.
        set_epoch_on_sampler(backup_data_loader.sampler, fake_epoch)
        self._backup_iterator = iter(backup_data_loader)

    def __next__(self):
        """
        Return the next non-empty batch.

        :raises StopIteration: when the original data loader is exhausted
        :raises Exception: when too many empty backup batches were drawn, or when the backup data loader is exhausted
        """
        if self._iterator is None:  # allow next() to be called without a prior iter()
            self._iterator = iter(self._data_loader)
        batch = next(self._iterator)
        if batch is not None:
            return batch
        if self._backup_iterator is None:
            self._init_backup_iterator()
        for backup_batch in self._backup_iterator:  # Fetch from backup until getting a non-empty batch
            if self._retries >= self._max_empty_batch_retries:
                raise Exception("batch load retried {} times due to empty batches, aborting".format(self._retries))
            if backup_batch is not None:
                return backup_batch
            self._retries += 1
        raise Exception("Exhausted backup data iterator, please check quality of your data")

    def __iter__(self):
        """ Start a new pass over the original data loader (the backup iterator is recreated lazily if needed) """
        self._iterator = iter(self._data_loader)
        self._backup_iterator = None
        return self


class DeephubLogger(object):
    """
    Iterate over data while logging progress and the values of the recorded meters (e.g. lr, loss).

    Logs are only emitted by the main process, but meters are synchronized on every worker, because synchronization
    involves collective calls in distributed mode.
    """

    def __init__(self, name, iteration_callback=None, print_freq=10, max_retries_empty_batch_retries=10):
        """
        :param name: name of the iteration, used in the logs ("Start <name>", "End <name>")
        :param iteration_callback: [optional] callable, called without arguments before each element is yielded
        :param print_freq: meters are synchronized and logged every `print_freq` iterations
        :type print_freq: int
        :param max_retries_empty_batch_retries: maximum number of empty backup batches tolerated when redrawing empty
                                                batches (`iter_over_data` with `redraw_batch_if_empty=True`)
        :type max_retries_empty_batch_retries: int
        """
        self._print_freq = print_freq
        self._name = name
        self._meters = {}  # logged every few batches (typically lr, loss)
        self._iteration_callback = iteration_callback

        self._max_empty_batch_retries = max_retries_empty_batch_retries
        # Fixed seed so that the random draws made when redrawing empty batches are reproducible
        self._random_state = np.random.RandomState(seed=1337)

    def update_meter(self, name, value):
        """ Record `value` for the meter `name`, creating the meter on first use """
        if name not in self._meters:
            self._meters[name] = LoggedMeter()
        self._meters[name].update(value)

    def get_meter(self, meter_name):
        """ :return LoggedMeter
            Raise an exception if the meter asked was not logged beforehand.
        """
        if meter_name not in self._meters:
            raise Exception("Cannot get meter '{}', does not exist".format(meter_name))

        return self._meters[meter_name]

    def _formatted_meters(self):
        """ Return the formatted values of all meters, prefixed with " - ", or "" if there is no meter """
        if len(self._meters) > 0:
            return " - " + ", ".join(["{}: {}".format(n, v.format_values()) for n, v in self._meters.items()])
        else:
            return ""

    def _synchronize(self):
        """ Synchronize every meter; must be called on every worker """
        # NOTE(review): in distributed mode each meter issues its own collective call, so all workers presumably need
        # to hold the same meters in the same insertion order - confirm.
        for sm in self._meters.values():
            sm.synchronize()

    def iter_over_data(self, data_loader, epoch=None, redraw_batch_if_empty=False):
        """
        Yield the elements of `data_loader`, synchronizing and logging the meters every `print_freq` iterations and at
        the end of the iteration.

        :param data_loader: data on which to iterate; must support len() and expose a `sampler`
        :type data_loader: torch.utils.data.DataLoader
        :param epoch: [optional] epoch number that we iterate on. If defined, it is set on the data loader's sampler
                      (only effective for a DistributedSampler) and mentioned in the first and last logs
        :type epoch: int
        :param redraw_batch_if_empty: if True, empty (None) batches are replaced by batches drawn from a backup data
                                      loader (see RedrawIterator)
        :type redraw_batch_if_empty: bool
        """
        epoch_text = "Epoch #{} - ".format(epoch) if epoch is not None else ""
        is_main_process = get_deephub_context().is_main_process()
        if is_main_process:
            logger.info("{}Start {}".format(epoch_text, self._name))
        num_steps = len(data_loader)
        # Pad the batch index to the width of the total number of steps so that the log lines are aligned
        batch_number_format = ":{}d".format(len(str(num_steps)))
        progress_format = "{" + batch_number_format + "}/{}{}"
        start_time = time.time()

        if epoch is not None:
            set_epoch_on_sampler(data_loader.sampler, epoch)

        if redraw_batch_if_empty:
            data_generator = RedrawIterator(data_loader, self._random_state,
                                            max_empty_batch_retries=self._max_empty_batch_retries)
        else:
            data_generator = data_loader

        for index, element in enumerate(data_generator):
            if index % self._print_freq == 0:
                # Called on every worker, not only on the main process, since synchronization may be collective
                self._synchronize()
                if is_main_process:
                    message = progress_format.format(index, num_steps, self._formatted_meters())
                    if get_deephub_context().cuda_based:
                        # Add the peak GPU memory to the message *before* logging it
                        message += " - memory={:.0f}MB".format(torch.cuda.max_memory_allocated() /
                                                               (1024.0 * 1024.0))
                    logger.info(message)
            if self._iteration_callback and callable(self._iteration_callback):
                self._iteration_callback()
            yield element
        iteration_time = time.time() - start_time
        self._synchronize()
        if is_main_process:
            logger.info("{}End {} - {:.1f}s{}".format(epoch_text, self._name, iteration_time,
                                                      self._formatted_meters()))


class DeepHubModelTrainingInfoHandler(ModelTrainingInfoHandler):
    """
    Track the progress of a deephub training (epochs, training/scoring steps, kept model epoch) and persist it through
    ModelTrainingInfoHandler. Only the main process writes the info.
    """
    INFO_FILENAME = "deephub_model_info_handler"

    def __init__(self, num_train_steps, num_validation_steps, num_epochs, metric, folder_context, delay=2):
        """
        :param num_train_steps: number of training steps per epoch
        :param num_validation_steps: number of scoring (validation) steps per epoch
        :param num_epochs: total number of epochs
        :param metric: metric reported as-is under the "metric" key
        :param folder_context: forwarded to ModelTrainingInfoHandler
        :param delay: forwarded to ModelTrainingInfoHandler (presumably the minimum delay between two non-forced
                      updates - TODO confirm)
        """
        super(DeepHubModelTrainingInfoHandler, self).__init__(folder_context, delay)
        self.num_train_steps = num_train_steps
        self.num_validation_steps = num_validation_steps
        self.num_epochs = num_epochs
        self.metric = metric

        self.started_at = None
        self.epochs = []
        self.current_epoch = 0
        self.current_num_steps_training = 0
        self.current_num_steps_scoring = 0
        self._current_epoch_start_time = None
        self.kept_model_epoch = -1  # -1 until an epoch reports a saved model

    def start_epoch(self, epoch):
        """ Start timing `epoch`, reset the step counters and force an info update """
        self._current_epoch_start_time = unix_time_millis()
        self.current_epoch = epoch
        self.current_num_steps_training = 0
        self.current_num_steps_scoring = 0
        self.update_info(force=True)

    def start_train_step(self):
        """ Count one more training step in the current epoch """
        self.current_num_steps_training += 1
        self.update_info()

    def start_validation_step(self):
        """ Count one more scoring (validation) step in the current epoch """
        self.current_num_steps_scoring += 1
        self.update_info()

    def end_epoch(self, meters):
        """
        Record the results of the current epoch (duration + meters) and force an info update.

        :param meters: values to store with the epoch results. A truthy "saved" entry marks the current epoch as the
                       epoch of the kept model; that entry is not stored in the epoch results. The caller's dict is
                       left unmodified.
        :type meters: dict
        """
        new_epoch_result = {
            "epoch": self.current_epoch,
            "time": unix_time_millis() - self._current_epoch_start_time
        }
        epoch_meters = dict(meters)  # copy, so that popping "saved" does not mutate the caller's dict
        if epoch_meters.pop("saved", False):
            self.kept_model_epoch = self.current_epoch
        new_epoch_result.update(epoch_meters)
        self.epochs.append(new_epoch_result)
        self.update_info(force=True)

    def start_train(self):
        """ Record the training start time and force an info update """
        self.started_at = unix_time_millis()
        self.update_info(force=True)

    def update_info(self, force=False):
        """ Update the info through the parent handler, on the main process only """
        if not get_deephub_context().is_main_process():
            return
        super(DeepHubModelTrainingInfoHandler, self).update_info(force)

    def to_dict(self):
        """ Return the training progress as a dict """
        return {
            "epochs": self.epochs,
            "nbStepsTrainingPerEpoch": self.num_train_steps,
            "nbStepsScoringPerEpoch": self.num_validation_steps,
            "nbEpochs": self.num_epochs,
            "metric": self.metric,
            "startedAt": self.started_at,
            "currentEpoch": self.current_epoch,
            "currentNumStepsTraining": self.current_num_steps_training,
            "currentNumStepsScoring": self.current_num_steps_scoring,
            "keptModelEpoch": self.kept_model_epoch
        }


def set_epoch_on_sampler(sampler, epoch):
    """
    Set `epoch` on `sampler` when it is a DistributedSampler; do nothing for any other sampler.

    A DistributedSampler must be told the current epoch before each iteration so that the data is shuffled
    differently from one epoch to the next.

    :param sampler: sampler of a data loader (any type is accepted)
    :param epoch: epoch number to set on the sampler
    :type epoch: int
    """
    distributed_sampler_class = torch.utils.data.distributed.DistributedSampler
    if not isinstance(sampler, distributed_sampler_class):
        return
    sampler.set_epoch(epoch)
