import logging
import os
import uuid
import sys

import torch.distributed.run
import torch.distributed.elastic.utils.distributed

from dataiku.core.managed_folder import Folder
from dataiku.doctor.deephub.deephub_context import get_deephub_context
from dataiku.doctor.deephub.deephub_params import DeepHubTrainingParams
from dataiku.doctor.deephub.deephub_training import DeepHubTrainingHandler
import dataiku.doctor.deephub.builtins as builtins
from dataiku.doctor.deephub.utils.file_utils import CachedManagedFolderFilesReader
from dataiku.doctor.deephub.utils.process_monitor import watch_parent_process
from dataiku.doctor.diagnostics import default_diagnostics
from dataiku.doctor.utils.gpu_execution import log_nvidia_smi_if_use_gpu

logger = logging.getLogger(__name__)


def launch_distributed_training(deephub_params):
    """
    Start a single-node distributed training through the torch elastic launcher.

    The launcher is asked to run the `dataiku.doctor.deephub.launch_training` module with the string form of
    `deephub_params` appended as trailing arguments.

    :param deephub_params: training params, must provide `get_process_count_per_node()` and `to_str_params()`
    """
    # Use a free port and a new UUID for the rendezvous, to allow multiple trainings on the same server
    free_port = torch.distributed.elastic.utils.distributed.get_free_port()
    rdvz_endpoint = "localhost:{}".format(free_port)
    logger.info("Launch distributed training with c10d backend on {}".format(rdvz_endpoint))

    launcher_args = [
        "--rdzv_backend=c10d",
        "--rdzv_endpoint={}".format(rdvz_endpoint),
        "--rdzv_id={}".format(str(uuid.uuid4())),
        "--nnodes={}".format(1),  # would be more than one when distributed on k8s
        "--nproc_per_node={}".format(deephub_params.get_process_count_per_node()),
        "-m",
        "dataiku.doctor.deephub.launch_training",
    ]
    # Everything after the module name is forwarded as arguments of the launched module
    launcher_args.extend(deephub_params.to_str_params())
    torch.distributed.run.main(launcher_args)


def launch_local_training(deephub_params):
    """
    Run the training in the current process.

    Initialises the deephub context from `deephub_params`, registers the diagnostics callbacks when running in the
    main process, then trains with files read from the managed folder through a local cache.

    :param deephub_params: training params (gpu config, core params, managed folder id, tmp folder...)
    """
    log_nvidia_smi_if_use_gpu(gpu_config=deephub_params.gpu_config)
    builtins.load()
    deephub_params.init_deephub_context()
    context = get_deephub_context()

    if context.is_main_process():
        default_diagnostics.register_deephub_callbacks(deephub_params.core_params)

    folder = Folder(deephub_params.managed_folder_id)
    cache_dir = os.path.join(deephub_params.tmp_folder, "tmp_data")
    # The folder used for caching only has to be initialised and cleaned once per node (currently there is only
    # 1 node), so only the local main process is put in charge of it
    owns_cache_dir = context.is_local_main_process()
    cached_reader = CachedManagedFolderFilesReader(folder,
                                                   cache_dir,
                                                   init_and_clean_base_folder=owns_cache_dir)
    with cached_reader as files_reader:
        DeepHubTrainingHandler(deephub_params, files_reader).train()


def launch_training(deephub_params):
    """
    Run the training described by `deephub_params`, distributed or local depending on
    `deephub_params.is_distributed()`.

    :param deephub_params: training params, forwarded untouched to the selected launcher
    """
    launcher = launch_distributed_training if deephub_params.is_distributed() else launch_local_training
    launcher(deephub_params)


if __name__ == "__main__":
    # Script entry point: rebuilds the training params from the command line arguments and always runs a local
    # training in this process (this module is the one `launch_distributed_training` asks the torch launcher to run).
    logging.basicConfig(level=logging.INFO,
                        format='[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s')
    deephub_process_pid = os.getpid()
    logger.info("Deephub Process pid={}".format(deephub_process_pid))
    args = sys.argv[1:]
    # Single source of truth for the check and its message: the message used to announce 7 while 8 was enforced.
    # NOTE(review): presumably matches the number of values produced by `DeepHubTrainingParams.to_str_params()`
    # - confirm against that class.
    expected_args_count = 8
    if len(args) != expected_args_count:
        raise Exception("Expected {} arguments, received {}".format(expected_args_count, len(args)))
    # NOTE(review): presumably ties the lifetime of this process to its parent's - confirm in process_monitor
    watch_parent_process()
    params = DeepHubTrainingParams.from_str_params(*args)
    launch_local_training(params)
