import logging
import os
from contextlib import contextmanager

import torch

from dataiku.doctor.utils.gpu_execution import DeepHubGpuCapability

logger = logging.getLogger(__name__)


class DeepHubContext(object):
    """Execution context of a DeepHub process.

    Holds the distributed-training topology (world size, global rank, local rank)
    and whether computation is cuda based, and derives from them the torch device,
    the dataloader multiprocessing context and the default process group setup.
    """

    def __init__(self):
        # Defaults describe a single, not-distributed, cpu-based process
        self.distributed = False
        self.cuda_based = False
        self.local_rank = 0
        self.rank = 0
        self.world_size = 1

    def set_distributed(self, distributed):
        self.distributed = distributed

    def set_world_size(self, world_size):
        self.world_size = world_size

    def set_local_rank(self, local_rank):
        self.local_rank = local_rank

    def set_rank(self, rank):
        self.rank = rank

    def set_cuda_based(self, cuda_based):
        self.cuda_based = cuda_based

    def is_main_process(self):
        """ True when this process has (global) rank 0 """
        return self.rank == 0

    def is_local_main_process(self):
        """ True when this process has local rank 0 """
        return self.local_rank == 0

    def get_device_id(self):
        """ Index of the device to use: the local rank when cuda based, 0 otherwise """
        # In order to simplify pytorch code, we always reason with gpu device_id == process local_rank
        if not self.cuda_based:
            return 0

        return self.local_rank

    def get_device(self):
        """
        :return: the cpu device when not cuda based, else the cuda device matching `get_device_id()`
        :rtype: torch.device
        """
        if not self.cuda_based:
            return torch.device("cpu")
        else:
            return torch.device("cuda:{}".format(self.get_device_id()))

    def get_dataloader_multiprocessing_context(self):
        """ Multiprocessing context to use for dataloader workers, always the "spawn" one """
        # Starting macOS 10.13, Apple has changed how fork() works,
        # now using multiprocessing with fork and some threads easily crashes the process
        # The new default for macOS in Python 3.8 is multiprocessing="spawn" (see https://bugs.python.org/issue33725)
        # Some threading code is called while downloading images deeply in requests (https://bugs.python.org/issue31818)
        # So default to multiprocessing="spawn" for DeepHub to prevent crashes
        #
        # In the case of Distributed training, with the use of `DistributedDataParallel`,
        # nccl or gloo backend are not fork safe (see the list of warnings in
        # https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html)
        # So default to "spawn" method to have a consistent behavior everywhere
        return torch.multiprocessing.get_context("spawn")

    def init_process_group(self, run_folder):
        """ init process group mostly for distributed processing
            MUST be called after setting the whole context (rank, local_rank, cuda_based, world_size)

            Uses the "nccl" backend when cuda based, "gloo" otherwise, and a file store located
            in `run_folder` for the processes to find each other.

            :param run_folder: folder in which the "pytorch-file-store" file is created
        """
        backend = "nccl" if self.cuda_based else "gloo"
        store_path = os.path.join(run_folder, "pytorch-file-store")
        logger.info("using '%s' for default process group file store", store_path)
        file_store = torch.distributed.FileStore(store_path, self.world_size)
        # The process group is indexed by the global rank, which must be unique in [0, world_size).
        # The local rank only identifies a process within its node (it is repeated on every node),
        # so it must not be used here. Both are equal when training on a single node.
        torch.distributed.init_process_group(backend=backend, rank=self.rank, world_size=self.world_size, store=file_store)
        if self.cuda_based:
            # Device selection, on the other hand, is per node: device_id == local_rank
            torch.cuda.set_device(self.local_rank)

    def destroy_process_group(self):
        """ Destroy the default process group """
        torch.distributed.destroy_process_group()

    def __str__(self):
        return "DeepHubTrainingContext(distributed={}, world_size={}, " \
               "rank={}, local_rank={}, cuda_based={})".format(self.distributed, self.world_size, self.rank,
                                                               self.local_rank, self.cuda_based)


# Module-wide DeepHub context, None until `set_deephub_context` has been called
__ctx = None


def get_deephub_context():
    """
    Retrieve the module-wide DeepHub context.

    Raises an Exception when no context has been set yet.

    :rtype: DeepHubContext
    """
    current = __ctx
    if current is not None:
        return current
    raise Exception("Deephub not initialized, cannot retrieve it")


def set_deephub_context(ctx):
    """
    Install `ctx` as the module-wide DeepHub context, replacing any previous one.

    :param ctx: the context to install
    """
    # Rebinds the module-level `__ctx` variable
    globals()["__ctx"] = ctx


def init_deephub_context(per_node_gpu_list, run_folder):
    """
    Build a DeepHubContext for the current process and install it as the module-wide context.

    The gpu related part of the context is delegated to `DeepHubGpuCapability`. When the
    `WORLD_SIZE` env var is greater than 1, the context is marked as distributed, filled with
    the ranks read from the environment, and the default process group is initialized.

    :param per_node_gpu_list: forwarded to `DeepHubGpuCapability.set_deephub_ctx_gpu_behaviour`
    :param run_folder: forwarded to `DeepHubContext.init_process_group` (distributed case only)
    """
    context = DeepHubContext()
    DeepHubGpuCapability.set_deephub_ctx_gpu_behaviour(per_node_gpu_list, context)

    # env vars `WORLD_SIZE`, `LOCAL_RANK` & `RANK` are set by `torch.distributed.launch`
    nb_processes = int(os.environ.get("WORLD_SIZE", 1))
    is_distributed = nb_processes > 1

    if is_distributed:
        # Both ranks are mandatory in distributed mode: read them before touching the context
        local_rank, rank = [int(os.environ[name]) for name in ("LOCAL_RANK", "RANK")]

        context.set_distributed(True)
        context.set_world_size(nb_processes)
        context.set_local_rank(local_rank)
        context.set_rank(rank)
        context.init_process_group(run_folder)

    set_deephub_context(context)
    logger.info("initialized context: {}".format(context))


@contextmanager
def with_enforced_not_distributed_context():
    """
    Enforces a simulated not-distributed context.
    Mandatory for scoring model, because pytorch does not accept using distributed with model without gradients

    While the `with` block runs, the module-wide context is replaced by a fresh, not-distributed
    one that only keeps the `cuda_based` flag of the original. The original context is put back
    when the block exits, including when it exits with an exception.
    Does nothing if the current context is already not distributed.

    NB: this method is not threadsafe
    """
    ctx = get_deephub_context()
    if not ctx.distributed:
        yield
        return

    assert ctx.is_main_process(), "Can only enforce not-distributed context on main process"
    new_ctx = DeepHubContext()
    new_ctx.set_cuda_based(ctx.cuda_based)
    set_deephub_context(new_ctx)
    try:
        yield
    finally:
        # Put back original context, even if the body of the `with` block raised:
        # otherwise the process would silently stay in the simulated not-distributed context
        set_deephub_context(ctx)
