import logging
import os
from abc import ABCMeta

import six
import torch

from dataiku.doctor.deephub.data_augmentation.image_transformer import build_transforms_lists
from dataiku.doctor.deephub.deephub_context import get_deephub_context
from dataiku.doctor.deephub.utils.deephub_registry import DeepHubRegistry

logger = logging.getLogger(__name__)


class DeepHubModelHandler(object):
    """
    Task agnostic handler of a deep hub Neural Network. It:
        - builds the Neural Network from the deephub model, loading locally trained weights if requested
        - puts the Neural Network on the deephub context device
        - wraps it in a DistributedDataParallel module when the deephub context is distributed
        - saves the Neural Network weights into the model folder

    All the task specificities (Object detection, image classification) are handled by the deephub_model itself
    when getting the Neural Network architecture.
    """
    DEEP_HUB_MODEL_FILENAME = "deep_hub_model.pkl"

    def __init__(self, model_folder_context, deephub_model, pretrained_locally):
        """
        Build the Neural Network from `deephub_model` and prepare it for the current deephub context.

        :param model_folder_context: folder context to folder where data on the model is stored
        :type model_folder_context: dataiku.base.folder_context.FolderContext
        :param deephub_model: Task dependant object holding the information to build the Neural Network architecture.
        :type deephub_model: DeepHubModel
        :param pretrained_locally: boolean to indicate whether model has been pretrained locally or if "default"
                                   pretrained weights should be used (provided by the library itself).
        :type pretrained_locally: bool
        :raises IOError: if `pretrained_locally` is True but the weights file is missing from the model folder
        """
        self.model_folder_context = model_folder_context
        self.deephub_model = deephub_model

        # Library-provided pretrained weights are only requested when no locally trained weights will be loaded.
        self.nn_base_model = deephub_model.get_model(pretrained=not pretrained_locally)
        self.pretrained_locally = pretrained_locally
        self.device = get_deephub_context().get_device()

        self.nn_model = self._build_model()

    # Constructors

    @staticmethod
    def build_for_pretrained_model(model_folder_context, base_model):
        """ Build a handler using the library-provided pretrained weights (e.g. before training). """
        return DeepHubModelHandler(model_folder_context, base_model, pretrained_locally=False)

    @staticmethod
    def build_for_scoring(model_folder_context, base_model):
        """ Build a handler loading the locally trained weights from the model folder. """
        return DeepHubModelHandler(model_folder_context, base_model, pretrained_locally=True)

    def _load_local_weights(self):
        """
        Load the locally trained weights (state dict) from the model folder into nn_base_model.

        :raises IOError: if DEEP_HUB_MODEL_FILENAME does not exist in the model folder
        """
        logger.info("Loading local weights for model")
        weights_filename = DeepHubModelHandler.DEEP_HUB_MODEL_FILENAME
        if not self.model_folder_context.isfile(weights_filename):
            raise IOError("Cannot build a locally trained model without local "
                          "weights ({})".format(weights_filename))
        with self.model_folder_context.get_file_path_to_read(weights_filename) as model_weights_path:
            # torch.load unpickles the file: only ever point it at the (trusted) model folder.
            self.nn_base_model.load_state_dict(torch.load(model_weights_path, map_location=self.device))

    def _build_model(self):
        """
        Return a Neural network model from nn_base_model by:
            - loading the NN weights (only if pretrained_locally)
            - moving the NN onto the deephub context device (cpu/gpus)
            - creating a DistributedDataParallel NN if training context is distributed.

        :return: the model to train/score with, or None when there is no base model (dummy training)
        :rtype: torch.nn.Module or None
        :raises IOError: if the model is pretrained locally but its weights file is missing
        """
        if self.nn_base_model is None:  # only use case is dummy training
            logger.info("Not building model because base model is not defined")
            return None

        if self.pretrained_locally:
            self._load_local_weights()

        if "TORCH_HOME" in os.environ:
            logger.info("Using pretrained model from TORCH_HOME: %s", os.environ["TORCH_HOME"])

        training_context = get_deephub_context()
        self.nn_base_model.to(self.device)
        if not training_context.distributed:
            return self.nn_base_model

        if training_context.cuda_based:
            device_id = training_context.get_device_id()
            device_ids, output_device = [device_id], device_id
        else:
            # For CPU modules DistributedDataParallel requires both device_ids and output_device to be None;
            # passing a device id as output_device may raise a ValueError.
            device_ids, output_device = None, None
        return torch.nn.parallel.DistributedDataParallel(self.nn_base_model,
                                                         device_ids=device_ids,
                                                         output_device=output_device)

    def save(self):
        """
        Save the weights (state dict) of nn_base_model into the model folder under DEEP_HUB_MODEL_FILENAME.
        Does nothing when there is no base model (dummy training).
        """
        if self.nn_base_model is None:  # only use case is dummy training
            logger.info("Not saving model because not defined")
            return
        with self.model_folder_context.get_file_path_to_write(DeepHubModelHandler.DEEP_HUB_MODEL_FILENAME) as model_weights_path:
            logger.info("Saving model to %s", model_weights_path)
            # The unwrapped base model is saved (not a possible DistributedDataParallel wrapper), matching
            # how the weights are loaded back into nn_base_model in _load_local_weights.
            torch.save(self.nn_base_model.state_dict(), model_weights_path)


@six.add_metaclass(ABCMeta)
class DeepHubModel(object):
    """
    Task dependant object holding the information needed to build a deep hub Neural Network architecture.

    Concrete subclasses are registered with `DeepHubModel.define` under their `TYPE` and `DUMMY` class
    attributes, and are instantiated with `DeepHubModel.build`.
    """
    TYPE = "DEEP_HUB_MODEL"
    REGISTRY = DeepHubRegistry()
    DUMMY = False

    def __init__(self, target_remapping, modeling_params):
        """
        :param target_remapping: target remapping of the prediction task, stored as-is on the instance
        :param modeling_params: modeling parameters (dict-like; `build` reads its optional "dummy" key)
        """
        self.target_remapping = target_remapping
        self.modeling_params = modeling_params

    @staticmethod
    def define(training_class):
        """
        Register `training_class` in the registry under its (TYPE, DUMMY) class attributes.

        :param training_class: DeepHubModel subclass to register
        :return: `training_class` itself, so that `define` can also be used as a class decorator
        """
        DeepHubModel.REGISTRY.register(training_class.TYPE, training_class.DUMMY, training_class)
        return training_class

    @staticmethod
    def build(prediction_type, target_remapping, modeling_params):
        """
        Instantiate the registered DeepHubModel subclass matching `prediction_type` and the "dummy"
        modeling param (False when absent).

        :raises ValueError: if no class is registered for (prediction_type, dummy)
        :rtype: DeepHubModel
        """

        dummy = modeling_params.get("dummy", False)
        try:
            model_class = DeepHubModel.REGISTRY.get(prediction_type, dummy)
        except KeyError as e:
            # Keep the registry lookup error as the explicit cause (py2/py3 compatible `raise ... from`).
            six.raise_from(ValueError("Unknown training engine: {} (dummy={})".format(prediction_type, dummy)), e)
        return model_class(target_remapping, modeling_params)

    def get_model(self, pretrained):
        """
        Create the Neural Network Module holding the model architecture

        :param pretrained: boolean to indicate whether if "default" pretrained weights should be used (provided by the
                           library itself)
        :type pretrained: bool
        :rtype: torch.nn.Module

        """
        raise NotImplementedError()

    def get_resolved_params(self):
        """
        :return: param_name: param_value for every param that was resolved when building the model
        :rtype: dict
        """
        return {}

    @property
    def model_name(self):
        """ Return model name from constants.py depending on modeling params and type of prediction """
        raise NotImplementedError()


@six.add_metaclass(ABCMeta)
class ComputerVisionDeepHubModel(DeepHubModel):
    """ Base class for the deep hub models working on images (e.g. object detection, image classification). """

    def get_number_of_retrained_layers(self):
        """
        :return: the "nbFinetunedLayers" modeling param, i.e. how many layers get fine-tuned
        """
        params = self.modeling_params
        return params["nbFinetunedLayers"]

    def build_image_transforms_lists(self, augmentation_params=None):
        """
        Build the image transforms lists, delegating to `build_transforms_lists`.

        :param augmentation_params: optional data augmentation params, forwarded unchanged
        :rtype: ImageTransformsLists
        """
        transforms_lists = build_transforms_lists(augmentation_params)
        return transforms_lists
