from abc import ABCMeta

import six

from dataiku.doctor.deephub.image_classification_performance import ImageClassificationPerformanceComputer
from dataiku.doctor.deephub.utils.constants import ImageClassificationModels
from dataiku.doctor.deephub.deephub_torch_datasets import ImageClassificationWithTargetDataset
from dataiku.doctor.deephub.deephub_torch_datasets import ImageClassificationDataset
from dataiku.doctor.deephub.deephub_training import DeepHubTrainingEngine
import logging
import time
import numpy as np
import torch

logger = logging.getLogger(__name__)


@six.add_metaclass(ABCMeta)
class AbstractImageClassificationDeepHubTrainingEngine(DeepHubTrainingEngine):
    """
    Logic shared by the image classification training engines: dataset building, resolution of the pretrained
    model name, per-batch accumulation of predictions and final performance computation.
    """
    TYPE = "DEEP_HUB_IMAGE_CLASSIFICATION"

    def build_dataset(self, df, files_reader, model, for_eval=False):
        """
        Build the torch dataset (images with their target) used for training or evaluation.

        Augmentation params (``modeling_params["augmentationParams"]``) are only handed to the model's transforms
        builder when ``for_eval`` is False; ``None`` is passed instead for evaluation datasets.

        :param for_eval: is the dataset used for evaluation
        :type model: dataiku.doctor.deephub.deephub_model.ComputerVisionDeepHubModel
        :type for_eval: bool
        :type df: pd.DataFrame
        :type files_reader: dataiku.doctor.deephub.utils.file_utils.FilesReader
        :rtype: dataiku.doctor.deephub.deephub_torch_datasets.DeepHubDataset
        """
        augmentation_params = None if for_eval else self.modeling_params["augmentationParams"]
        transforms_lists = model.build_image_transforms_lists(augmentation_params)
        return ImageClassificationWithTargetDataset(files_reader, df, self.target_remapping, self.file_path_col,
                                                    self.target_variable, transforms_lists)

    @property
    def model_name(self):
        """
        Resolve the ``pretrainedModel`` modeling param into its ``ImageClassificationModels`` member.

        :rtype: ImageClassificationModels
        :raises RuntimeError: if ``pretrainedModel`` is not the name of an ``ImageClassificationModels`` member
        """
        pretrained_model = self.modeling_params["pretrainedModel"]
        if pretrained_model not in [e.name for e in ImageClassificationModels]:
            raise RuntimeError("Unsupported image classification "
                               "model {}".format(pretrained_model))

        return ImageClassificationModels[pretrained_model]

    @staticmethod
    def accumulate(data_accumulator, targets, batch_outputs):
        """
        Store, for one batch, the image ids, true categories, predicted probabilities and predicted categories.

        Values are stored under the keys "images_ids", "targets", "probas" and "predictions", which
        ``compute_performance`` reads back.

        :param data_accumulator: accumulator receiving one value per key for each batch
        :param targets: batch targets; its "image_id" and "category" entries are tensors, detached and moved to CPU
                        here before being converted to numpy
        :param batch_outputs: raw model outputs, post-processed by ``ImageClassificationDataset.post_process_model_data``
                              into (predicted categories, predicted probabilities)
        """
        predicted_categories, predicted_probas = ImageClassificationDataset.post_process_model_data(batch_outputs)

        data_accumulator.accumulate_value("images_ids", targets["image_id"].detach().cpu().numpy())
        data_accumulator.accumulate_value("targets", targets["category"].detach().cpu().numpy())
        data_accumulator.accumulate_value("probas", predicted_probas)
        data_accumulator.accumulate_value("predictions", predicted_categories)

    @staticmethod
    def _concatenate_accumulated(data_accumulator, key):
        """Concatenate all the per-batch values accumulated under ``key``; raise ValueError if there are none."""
        batches_values = data_accumulator.get_accumulated_value(key, default=[])
        # np.concatenate([]) fails with an unhelpful "need at least one array to concatenate": report it clearly.
        # len() rather than truthiness, in case the accumulator ever returns an array instead of a list.
        if len(batches_values) == 0:
            raise ValueError("Cannot compute performance: no '{}' value was accumulated".format(key))
        return np.concatenate(batches_values)

    def compute_performance(self, data_accumulator, origin_index, metric_params):
        """
        Compute the performance from the values stored by ``accumulate`` over all batches.

        :param data_accumulator: accumulator filled by successive calls to ``accumulate``
        :param origin_index: forwarded to ``ImageClassificationPerformanceComputer``
        :param metric_params: not used by this implementation
        :return: the result of ``ImageClassificationPerformanceComputer.compute_performance()``
        :raises ValueError: if no batch values were accumulated
        """
        # retrieve the lists of all the batches' accumulated values then concatenate each list into arrays :
        images_ids = self._concatenate_accumulated(data_accumulator, "images_ids")
        targets = self._concatenate_accumulated(data_accumulator, "targets")
        probas = self._concatenate_accumulated(data_accumulator, "probas")
        predictions = self._concatenate_accumulated(data_accumulator, "predictions")

        # Some sampling strategy (e.g. distributed) can duplicate images to fill incomplete batches. Keep only 1
        # occurrence of prediction, target and probas for each image. np.unique returns the ids sorted, together
        # with the index of their first occurrence, which is used to realign the other arrays on that order.
        images_ids, first_occurrence_indices = np.unique(images_ids, return_index=True)
        targets = targets[first_occurrence_indices]
        probas = probas[first_occurrence_indices]
        predictions = predictions[first_occurrence_indices]

        performance_computer = ImageClassificationPerformanceComputer(self.target_remapping, origin_index,
                                                                      images_ids, targets, probas, predictions)
        return performance_computer.compute_performance()


class DummyImageClassificationDeepHubTrainingEngine(AbstractImageClassificationDeepHubTrainingEngine):
    """
    Fake image classification engine: it never updates the model and produces random scores instead of predictions,
    so that the rest of the feature (logging, performance computation...) can be exercised.
    """
    DUMMY = True

    def __init__(self, target_remapping, modeling_params, target_variable, file_path_col):
        super(DummyImageClassificationDeepHubTrainingEngine, self).__init__(target_remapping, modeling_params,
                                                                            target_variable, file_path_col)
        # The seed is fixed (not random) so that the fake predictions are reproducible; it is also logged so that a
        # bug (if any) can be reproduced.
        fixed_seed = 1337
        self.random_state = np.random.RandomState(seed=fixed_seed)
        self.torch_gen = torch.Generator().manual_seed(fixed_seed)

        logger.info("Producing random predictions from dummy model with random seed {}".format(fixed_seed))

    def init_training_params(self, model):
        """Nothing to initialize for the dummy model: only log the call."""
        logger.info("Init dummy model training params")

    def on_train_start(self, num_batches_per_epoch):
        """Nothing to prepare for the dummy model: only log the call."""
        logger.info("on train start dummy model")

    def train_one_epoch(self, epoch, model, device, train_data_loader, deephub_logger):
        """
        Simulate one training epoch: iterate over the training batches without using the model, briefly sleeping
        per batch and reporting constant learning rate and loss values.
        """
        logger.info("Training one epoch with dummy model")

        fake_lr = 0.001
        if epoch == 0:
            deephub_logger.update_meter("lr", fake_lr)

        batches = deephub_logger.iter_over_data(train_data_loader, epoch=epoch, redraw_batch_if_empty=True)
        for batch in batches:
            assert batch is not None, "received an empty batch"

            # Unpacked like the real engine's (targets, images) batches; the content itself is not needed here.
            _targets, _images = batch

            time.sleep(0.01)  # stands in for the actual training computation
            deephub_logger.update_meter("lr", fake_lr)
            deephub_logger.update_meter("loss", 0.2)

    def on_epoch_end(self, epoch_index, val_metric):
        """The dummy model has no LR scheduler: only log that the step is skipped."""
        logger.info("Dummy model, skipping LR scheduler step")

    def predict_and_accumulate(self, model, device, data_loader, deephub_logger, data_accumulator, epoch=None):
        """
        Simulate predictions so that performance can be computed afterwards and the rest of the feature tested.

        For every non-empty batch, a tensor of shape (nb_images_in_batch, nb_categories) is drawn uniformly from
        ``self.torch_gen`` and accumulated as if it were the model's scores; a constant loss is reported.
        """
        with torch.no_grad():
            for batch in deephub_logger.iter_over_data(data_loader, epoch=epoch):
                if batch is not None:
                    targets, _images = batch

                    nb_images = len(targets["image_id"])
                    nb_categories = len(self.target_remapping)
                    fake_scores = torch.rand(nb_images, nb_categories, generator=self.torch_gen)

                    self.accumulate(data_accumulator, targets, fake_scores)
                    deephub_logger.update_meter("loss", 0.2)


class ImageClassificationDeepHubTrainingEngine(AbstractImageClassificationDeepHubTrainingEngine):
    """
    Image classification engine training the model with a cross-entropy loss.

    NOTE(review): ``self.optimizer`` and ``self.lr_scheduler_strategy`` are used but not set in this class;
    presumably they are created by the parent engine before training - to confirm.
    """

    def __init__(self, target_remapping, modeling_params, target_variable, file_path_col):
        super(ImageClassificationDeepHubTrainingEngine, self).__init__(target_remapping, modeling_params,
                                                                       target_variable, file_path_col)
        self.loss_function = torch.nn.CrossEntropyLoss()

    def train_one_epoch(self, epoch, model, device, train_data_loader, deephub_logger):
        """
        Train ``model`` for one epoch, one optimization step per batch, logging the learning rate at the start.
        """
        model.train()
        deephub_logger.update_meter("lr", self.optimizer.param_groups[0]["lr"])

        batches = deephub_logger.iter_over_data(train_data_loader, epoch=epoch, redraw_batch_if_empty=True)
        for batch in batches:
            assert batch is not None, "received an empty batch"
            targets, images = batch
            # Each batch is trained in a separate method call so that the tensors local to that call are released
            # when it returns, freeing memory every batch.
            self.train_batch(model, device, images, targets, deephub_logger)

    def train_batch(self, model, device, images, targets, deephub_logger):
        """
        Run one optimization step on a batch, then log the learning rate and the loss.

        Note: ``targets["category"]`` is replaced in place by its copy on ``device``.

        :raises RuntimeError: if the global current value of the "loss" meter is not finite
        """
        images = images.to(device)
        category_targets = targets["category"].to(device)
        targets["category"] = category_targets

        scores = model(images)
        loss = self.loss_function(scores, category_targets)

        self.optimizer.zero_grad()
        loss.backward()

        self.optimizer.step()
        self.lr_scheduler_strategy.on_batch_end()

        deephub_logger.update_meter("lr", self.optimizer.param_groups[0]["lr"])
        deephub_logger.update_meter("loss", loss.item())  # averaged loss per sample from the batch.

        # Open question: should workers synchronize first? The loss of the current batch is not taken into account
        # here; the check uses the last loss averaged between workers.
        global_loss = deephub_logger.get_meter("loss").get_global_current_value()
        if not np.isfinite(global_loss):
            raise RuntimeError("Loss is not finite (value={}), stopping training".format(global_loss))

    def predict_and_accumulate(self, model, device, data_loader, deephub_logger, data_accumulator, epoch=None):
        """
        Predict every non-empty batch of ``data_loader`` in eval mode without gradient tracking, logging the loss
        and accumulating the predictions for ``compute_performance``. Empty batches are logged and skipped.
        """
        model.eval()
        with torch.no_grad():
            batches = deephub_logger.iter_over_data(data_loader, epoch=epoch, redraw_batch_if_empty=False)
            for batch in batches:

                if batch is None:
                    logger.info("Got an empty batch, not predicting it")
                    continue

                targets, images = batch
                scores = model(images.to(device))

                loss = self.loss_function(scores, targets["category"].to(device))
                deephub_logger.update_meter("loss", loss.item())

                self.accumulate(data_accumulator, targets, scores)
