import logging
import time
from abc import ABCMeta
import pandas as pd
import numpy as np
import torch
import six

from dataiku.doctor.deephub.deephub_torch_datasets import ObjectDetectionWithTargetDataset, ObjectDetectionDataset
from dataiku.doctor.deephub.utils.constants import ObjectDetectionModels
from dataiku.doctor.deephub.deephub_training import DeepHubTrainingEngine
from dataiku.doctor.deephub.object_detection_performance import ObjectDetectionPerformanceComputer

logger = logging.getLogger(__name__)


@six.add_metaclass(ABCMeta)
class AbstractObjectDetectionDeepHubTrainingEngine(DeepHubTrainingEngine):
    """
    Shared behaviour of the object detection training engines: building the torch dataset, resolving the
    pretrained model, accumulating targets & detections image per image, and computing the performance from the
    accumulated values.
    """
    TYPE = "DEEP_HUB_IMAGE_OBJECT_DETECTION"

    def build_dataset(self, df, files_reader, model, for_eval=False):
        """
        Build the dataset (with targets) that will be fed to the model.

        :param for_eval: is the dataset used for evaluation. When True, no augmentation params are handed to the
                         model when building the image transforms.
        :type model: dataiku.doctor.deephub.deephub_model.ComputerVisionDeepHubModel
        :type for_eval: bool
        :type df: pd.DataFrame
        :type files_reader: dataiku.doctor.deephub.utils.file_utils.FilesReader
        :rtype: dataiku.doctor.deephub.deephub_torch_datasets.DeepHubDataset
        """
        if for_eval:
            augmentation_params = None
        else:
            augmentation_params = self.modeling_params["augmentationParams"]

        image_transforms = model.build_image_transforms_lists(augmentation_params)
        return ObjectDetectionWithTargetDataset(files_reader, df, self.target_remapping, self.file_path_col,
                                                self.target_variable, image_transforms)

    @property
    def model_name(self):
        """
        Pretrained model selected in the modeling params.

        :raises RuntimeError: if the "pretrainedModel" param is not the name of a member of ObjectDetectionModels
        :rtype: ObjectDetectionModels
        """
        requested_model = self.modeling_params["pretrainedModel"]
        if not any(member.name == requested_model for member in ObjectDetectionModels):
            raise RuntimeError("Unsupported object detection model {}".format(requested_model))

        return ObjectDetectionModels[requested_model]

    @staticmethod
    def accumulate(data_accumulator, targets, batch_outputs):
        """
        Store the image ids, targets and detections of one batch into the accumulator.

        :param data_accumulator: object receiving the values through `accumulate_value(name, value)`
        :param targets: one dict per image of the batch, each holding (at least) an "image_id" tensor
        :param batch_outputs: model outputs for the batch, in the same order as `targets`
        """
        post_process = ObjectDetectionDataset.post_process_model_data

        # Tensors are detached and converted to plain python / numpy objects before being accumulated
        batch_image_ids = [target["image_id"].detach().cpu().numpy().item() for target in targets]
        per_image_values = zip(batch_image_ids, post_process(targets), post_process(batch_outputs))

        # Every image can hold a different number of objects, hence targets & outputs of a batch do not share a
        # common size: values are accumulated image per image rather than flattening heterogeneous batches.
        for image_id, image_target, image_detections in per_image_values:
            data_accumulator.accumulate_value("images_ids", image_id)
            data_accumulator.accumulate_value("targets", image_target)
            data_accumulator.accumulate_value("detections", image_detections)

    def compute_performance(self, data_accumulator, origin_index, metric_params):
        """
        Compute the performance from the values accumulated with `accumulate`.

        :param data_accumulator: accumulator holding the "images_ids", "targets" and "detections" values
        :param origin_index: forwarded as is to ObjectDetectionPerformanceComputer
        :param metric_params: forwarded as is to ObjectDetectionPerformanceComputer
        :return: the result of ObjectDetectionPerformanceComputer.compute_performance()
        """
        # Fetch everything that was accumulated during the last epoch (empty list when nothing was accumulated)
        accumulated = {}
        for key in ("images_ids", "targets", "detections"):
            accumulated[key] = data_accumulator.get_accumulated_value(key, default=[])

        ground_truth_df = pd.DataFrame({"image_id": accumulated["images_ids"], "target": accumulated["targets"]})
        detections_df = pd.DataFrame({"image_id": accumulated["images_ids"],
                                      "prediction": accumulated["detections"]})

        # Some sampling strategies (e.g. distributed) can duplicate images to fill incomplete batches: only the
        # first occurrence of each image is kept.
        ground_truth_df = ground_truth_df.drop_duplicates(subset=['image_id'])
        detections_df = detections_df.drop_duplicates(subset=['image_id'])

        return ObjectDetectionPerformanceComputer(self.target_remapping, origin_index, ground_truth_df,
                                                  detections_df, metric_params).compute_performance()


class DummyObjectDetectionDeepHubTrainingEngine(AbstractObjectDetectionDeepHubTrainingEngine):
    """
    Engine that does not train any model: training steps only sleep & log constant meters, and predictions are
    pseudo-random detections derived from the ground truth. It allows testing the rest of the feature (performance
    computation, logging, ...) without running a real model.
    """
    DUMMY = True

    def __init__(self, target_remapping, modeling_params, target_variable, file_path_col):
        super(DummyObjectDetectionDeepHubTrainingEngine, self).__init__(target_remapping, modeling_params,
                                                                        target_variable, file_path_col)
        # Numpy & torch generators are seeded with a fixed value, which is logged, so that the generated
        # predictions can be reproduced when investigating a bug (if any).
        seed = 1337
        self.random_state = np.random.RandomState(seed=seed)
        self.torch_gen = torch.Generator().manual_seed(seed)

        logger.info("Producing random predictions from dummy model with random seed {}".format(seed))

    def init_training_params(self, model):
        """ Nothing to initialize for the dummy model, only logs. """
        logger.info("Init dummy model training params")

    def on_train_start(self, num_batches_per_epoch):
        """ Nothing to prepare for the dummy model, only logs. """
        logger.info("on train start dummy model")

    def train_one_epoch(self, epoch, model, device, train_data_loader, deephub_logger):
        """
        Simulate the training of one epoch: iterate over the batches, sleep 10ms per batch in place of the actual
        training and report constant "lr" & "loss" meters.
        """
        logger.info("Training one epoch with dummy model")
        if epoch == 0:
            deephub_logger.update_meter("lr", 0.001)

        for batch in deephub_logger.iter_over_data(train_data_loader, epoch=epoch, redraw_batch_if_empty=True):
            assert batch is not None, "received an empty batch"

            time.sleep(0.01)  # training part
            deephub_logger.update_meter("lr", 0.001)
            deephub_logger.update_meter("loss", 0.2)

    def on_epoch_end(self, epoch_index, val_metric):
        """ No LR scheduler for the dummy model, only logs. """
        logger.info("Dummy model, skipping LR scheduler step")

    def _simulate_error(self, correct_label):
        """
            Simulate a real prediction by introducing random error in the predicted label with a given probability
            Returns: a label id, either randomly chosen from the target_remapping (probability 0.3) or the correct
            label i.e. the ground truth (probability 0.7).
        """
        # As we want to simulate a real PyTorch model, adding 1 here to the 0-based DSS categories: PyTorch expects
        # "labels" to be 1-based for object detection (0 being for the background).
        random_label = self.random_state.choice(len(self.target_remapping)) + 1
        return self.random_state.choice([random_label, correct_label], p=[0.3, 0.7])

    def _draw_noise(self, std):
        """ Draw one zero-centered gaussian value per element of `std`, using the seeded torch generator. """
        # zeros_like keeps the mean on the same dtype/device as the std
        return torch.normal(mean=torch.zeros_like(std), std=std, generator=self.torch_gen)

    def _perturb_boxes(self, boxes):
        """
        Shift in place (and return) the coordinates of `boxes` by a gaussian noise whose standard deviation is 10% of
        the amplitude of the box in the corresponding direction. Columns 0 & 2 are handled as x coordinates and
        columns 1 & 3 as y coordinates.
        """
        if boxes.size()[0] == 0:  # no prediction, nothing to perturb
            return boxes

        # Amplitudes are computed once, from the unperturbed boxes
        x_std = 0.1 * torch.abs(boxes[:, 0] - boxes[:, 2])
        y_std = 0.1 * torch.abs(boxes[:, 1] - boxes[:, 3])

        # The order of the draws (dx1, dx2, dy1, dy2) must not change: it determines the values produced by the
        # seeded generator.
        dx1 = self._draw_noise(x_std)
        dx2 = self._draw_noise(x_std)
        dy1 = self._draw_noise(y_std)
        dy2 = self._draw_noise(y_std)

        boxes[:, 0] += dx1
        boxes[:, 2] += (dx1 + dx2)
        boxes[:, 1] += dy1
        boxes[:, 3] += (dy1 + dy2)
        return boxes

    def _simulate_detections(self, target):
        """
        Build the simulated output of the model for one image, from its ground truth.

        :param target: dict holding the "labels" and "boxes" tensors of the image
        :return: dict with the "labels", "boxes" and "scores" tensors, mimicking the PyTorch predictions format
        """
        # Selecting how many detections there will be for each ground truth
        # The most probable is 0 (undetected box) or 1, sometimes we might have more.
        # The dtype is explicit because, for an image without any ground truth, the list is empty and torch would
        # otherwise create a float tensor, which is not a valid `repeats` argument for repeat_interleave.
        nb_ground_truths = target["labels"].size()[0]
        nb_repetitions_per_gt = torch.tensor(
            [self.random_state.choice([0, 1, 2, 3], p=[0.5, 0.2, 0.15, 0.15]) for _ in range(nb_ground_truths)],
            dtype=torch.long)
        repeated_labels = target["labels"].detach().clone().repeat_interleave(nb_repetitions_per_gt)

        # At this point all the predictions have the correct label. Introduce some randomness in the
        # predicted labels to simulate errors (the dtype is explicit so that labels remain integers even when
        # there is no prediction at all):
        predicted_labels = torch.tensor([self._simulate_error(label.item()) for label in repeated_labels],
                                        dtype=torch.long)
        predicted_scores = torch.rand(predicted_labels.size(), generator=self.torch_gen)

        # For boxes, apply perturbation on them of 10% of the amplitude of the box in each direction
        predicted_boxes = self._perturb_boxes(
            target["boxes"].detach().clone().repeat_interleave(nb_repetitions_per_gt, dim=0))

        # Clip results to ensure predicted boxes coordinates are >= 0
        predicted_boxes = torch.where(predicted_boxes > 0., predicted_boxes, torch.zeros_like(predicted_boxes))

        # Simulate Pytorch format for predictions:
        return {"labels": predicted_labels, "boxes": predicted_boxes, "scores": predicted_scores}

    def predict_and_accumulate(self, model, device, data_loader, deephub_logger, data_accumulator, epoch=None):
        """
        Runs a simulated prediction to be able to compute performance afterwards and test the rest of the feature.
        It works as follows:

        For each ground truth box:
        * we first decide how many corresponding detections there will be (between 0 and 3)
        * then for each detection box, we apply a perturbation so that it does not completely overlap the ground truth.
        * With a given probability it predicts a random category from the target_remapping (to simulate errors in predicted
          classes) - the rest of the time it predicts the same category than in the ground truth.

        Empty batches (None) are skipped. A constant "loss" meter is reported for every predicted batch.
        """

        with torch.no_grad():
            for batch in deephub_logger.iter_over_data(data_loader, epoch=epoch):
                if batch is None:
                    continue

                targets, _ = batch
                outputs = [self._simulate_detections(target) for target in targets]

                self.accumulate(data_accumulator, targets, outputs)
                deephub_logger.update_meter("loss", 0.1)


class ObjectDetectionDeepHubTrainingEngine(AbstractObjectDetectionDeepHubTrainingEngine):
    """
    Engine training an actual object detection model with `self.optimizer` and `self.lr_scheduler_strategy`.
    """

    def __init__(self, target_remapping, modeling_params, target_variable, file_path_col):
        super(ObjectDetectionDeepHubTrainingEngine, self).__init__(target_remapping, modeling_params,
                                                                   target_variable, file_path_col)

    def train_one_epoch(self, epoch, model, device, train_data_loader, deephub_logger):
        """
        Switch the model to train mode, report the current learning rate, then train on every batch of the epoch.
        """
        model.train()
        deephub_logger.update_meter("lr", self.optimizer.param_groups[0]["lr"])

        batches = deephub_logger.iter_over_data(train_data_loader, epoch=epoch, redraw_batch_if_empty=True)
        for batch in batches:
            assert batch is not None, "received an empty batch"
            targets, images = batch
            # The batch is trained in a dedicated method so that its intermediate objects go out of scope, hence
            # can be freed, once the batch is done.
            self.train_batch(model, device, images, targets, deephub_logger)

    def train_batch(self, model, device, images, targets, deephub_logger):
        """
        Run one optimization step on a batch and report the "lr" & "loss" meters.

        :raises RuntimeError: if the global average of the "loss" meter is not finite after this batch
        """
        device_images = [image.to(device) for image in images]
        device_targets = [{key: tensor.to(device) for key, tensor in target.items()} for target in targets]

        loss_dict = model(device_images, device_targets)
        total_loss = sum(loss_dict.values())

        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()
        self.lr_scheduler_strategy.on_batch_end()

        deephub_logger.update_meter("lr", self.optimizer.param_groups[0]["lr"])
        # The reported loss is the sum of the values of loss_dict divided by its number of entries
        deephub_logger.update_meter("loss", total_loss.item() / len(loss_dict))

        global_avg_loss = deephub_logger.get_meter("loss").get_global_avg_value()
        if np.isfinite(global_avg_loss):
            return
        raise RuntimeError("Loss is not finite (value={}), stopping training".format(global_avg_loss))

    def _predict_batch(self, model, device, batch, deephub_logger, data_accumulator):
        """ Predict one non-empty batch, accumulate its targets & detections and report its loss. """
        targets, images = batch
        device_images = [image.to(device) for image in images]
        device_targets = [{key: tensor.to(device) for key, tensor in target.items()} for target in targets]

        model.eval()
        self.accumulate(data_accumulator, device_targets, model(device_images))

        # PyTorch Object Detection models don't return the loss when in eval() mode, thus the model is switched
        # to train() mode to retrieve the eval loss.
        # This additional forward pass on the eval set has a time cost, but it allows detecting overfitting by
        # comparing both training & eval losses.
        # NOTE: the model is left in train() mode once the batch has been handled.
        model.train()
        loss_dict = model(device_images, device_targets)

        # Same reporting as for training: sum of the values of loss_dict divided by its number of entries
        total_loss = sum(loss_dict.values())
        deephub_logger.update_meter("loss", total_loss.item() / len(loss_dict))

    def predict_and_accumulate(self, model, device, data_loader, deephub_logger, data_accumulator, epoch=None):
        """
        Predict every batch of the data loader without tracking gradients, accumulating targets & detections in
        `data_accumulator` and reporting the "loss" meter. Empty batches (None) are logged and skipped.
        """
        with torch.no_grad():
            for batch in deephub_logger.iter_over_data(data_loader, epoch=epoch, redraw_batch_if_empty=False):
                if batch is None:
                    logger.info("Got an empty batch, not predicting it")
                    continue

                self._predict_batch(model, device, batch, deephub_logger, data_accumulator)
