from dataiku.doctor.deephub.deephub_model import ComputerVisionDeepHubModel
from dataiku.doctor.deephub.utils.constants import ObjectDetectionModels

import logging
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# Module-level logger, named after this module; used by the model classes below.
logger = logging.getLogger(__name__)


class ObjectDetectionDeepHubModel(ComputerVisionDeepHubModel):
    """
    DeepHub model wrapper for image object detection.

    Resolves the pretrained architecture named in ``modeling_params["pretrainedModel"]`` and builds the
    matching torchvision detection network, whose box predictor head is sized for
    ``len(self.target_remapping) + 1`` classes (the extra one being the background class).
    """

    TYPE = "DEEP_HUB_IMAGE_OBJECT_DETECTION"

    def __init__(self, target_remapping, modeling_params):
        """
        Both arguments are forwarded unchanged to the parent constructor.

        :param target_remapping: target classes description (this class only reads its length)
        :param modeling_params: modeling settings (this class only reads the "pretrainedModel" entry)
        """
        super(ObjectDetectionDeepHubModel, self).__init__(target_remapping, modeling_params)

    @property
    def model_name(self):
        """
        Member of ``ObjectDetectionModels`` matching the configured "pretrainedModel" name.

        :rtype: ObjectDetectionModels
        :raises RuntimeError: if the configured name is not the name of any ``ObjectDetectionModels`` member
        """
        requested_name = self.modeling_params["pretrainedModel"]
        supported_names = [member.name for member in ObjectDetectionModels]

        if requested_name in supported_names:
            return ObjectDetectionModels[requested_name]

        raise RuntimeError("Unsupported object detection model {}".format(requested_name))

    def get_model(self, pretrained):
        """
        Build the detection network for the configured pretrained model.

        :param pretrained: not read by this implementation, the default pretrained weights are always loaded
                           (see the comments in the body)
        :return: torchvision Faster R-CNN model whose box predictor was replaced by a freshly initialized head
        :raises RuntimeError: if the configured model is unsupported (raised by ``model_name``) or is not
                              Faster R-CNN
        """
        # The model considers the first class as background, hence one more class than the number of targets
        class_count = len(self.target_remapping) + 1

        is_faster_rcnn = self.model_name == ObjectDetectionModels.FASTERRCNN
        if not is_faster_rcnn:
            raise RuntimeError("Unknown pretrained model for "
                               "object detection: '{}'".format(self.modeling_params["pretrainedModel"]))

        retrained_layers = self.get_number_of_retrained_layers()
        logger.info("Number of retrained layers: {}".format(retrained_layers))

        # Arguments are spelled out explicitly so that every other argument of the factory keeps its default value.
        #
        # pretrained=True: due to
        # https://app.shortcut.com/dataiku/story/143472/bump-some-more-deephub-packages-pytorch we always load the
        # model with default pretrained weights, even when they are overridden by local weights (when scoring).
        #
        # DEPRECATION: starting with torchvision 0.13, pretrained=True is deprecated and will be removed in 0.15.
        # It can be replaced with weights=FasterRCNN_ResNet50_FPN_Weights.COCO_V1, but that change is not backward
        # compatible: it won't work for users who do not update their code env with a new torch version. It will
        # probably require bumping a new version of the deephub code env, which is not transparent to users.
        #
        # pretrained_backbone=False: ensure we never download unwanted backbone weights.
        model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
            pretrained=True,
            pretrained_backbone=False,
            trainable_backbone_layers=retrained_layers
        )

        # The pre-trained classifier (91 COCO classes) probably does not have the same number of classes as
        # our targets, so swap it for a new head that keeps the same number of input features.
        head_input_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(head_input_features, num_classes=class_count)

        # Note: at this point the model contains pretrained weights from the initial training (from torch),
        # except for the head which contains randomly initialized weights.
        return model

    def get_resolved_params(self):
        """
        :return: dict with a single "retrainedLayers" entry holding the number of retrained layers
        :rtype: dict
        """
        retrained_layers = self.get_number_of_retrained_layers()
        return {"retrainedLayers": retrained_layers}


class DummyObjectDetectionDeepHubModel(ObjectDetectionDeepHubModel):
    """
    Dummy variant of the object detection model: it never builds a network, reports a fixed display
    name and declares zero retrained layers.
    """

    DUMMY = True

    def get_number_of_retrained_layers(self):
        """
        :return: always 0
        :rtype: int
        """
        return 0

    def get_model(self, pretrained):
        """
        Only log that the dummy model was requested; no network is built.

        :param pretrained: ignored
        :return: None
        """
        logger.info("Getting dummy model")
        return None

    @property
    def model_name(self):
        """
        :return: fixed display name of the dummy model (a plain string, not an ``ObjectDetectionModels`` member)
        :rtype: str
        """
        return "Dummy Object detection Model"
