import logging
import sys

from dataiku.base.folder_context import build_folder_context
from dataiku.base.utils import ErrorMonitoringWrapper
from dataiku.doctor.causal.train.launch_training import launch_training
from dataiku.doctor.utils.gpu_execution import log_nvidia_smi_if_use_gpu


def main(exec_folder, operation_mode):
    """Retrain a saved causal model from the parameter files in ``exec_folder``.

    Configures root logging at INFO level and builds a folder context on
    ``exec_folder``. It then loads, in order:

    - ``split/split.json`` (the split description)
    - ``core_params.json``
    - ``rpreprocessing_params.json``
    - ``rmodeling_params.json``

    It logs ``nvidia-smi`` output when GPU use is configured and finally
    delegates to ``launch_training`` with a single modeling set.

    :param exec_folder: path of the execution folder. It is passed to
        ``build_folder_context`` and also used as the modeling set's
        ``run_folder``.
    :param operation_mode: forwarded unchanged to ``launch_training``.
        The accepted values are defined by that function (not visible here).
    """
    log_format = '[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s'
    logging.basicConfig(format=log_format, level=logging.INFO)

    folder_ctx = build_folder_context(exec_folder)
    split_ctx = folder_ctx.get_subfolder_context("split")
    split_desc = split_ctx.read_json("split.json")

    core_params = folder_ctx.read_json("core_params.json")
    preprocessing_params = folder_ctx.read_json("rpreprocessing_params.json")
    modeling_params = folder_ctx.read_json("rmodeling_params.json")

    # Exactly one modeling set: the saved model being retrained. It carries
    # the folder context, the run folder and the saved modeling params.
    saved_model_set = dict(
        model_folder_context=folder_ctx,
        run_folder=exec_folder,
        modelingParams=modeling_params,
    )

    log_nvidia_smi_if_use_gpu(core_params=core_params)
    launch_training(
        core_params,
        [saved_model_set],
        preprocessing_params,
        folder_ctx,
        split_ctx,
        split_desc,
        operation_mode,
    )


if __name__ == "__main__":
    # Invocation: <script> <exec_folder> <operation_mode>
    # The argv lookups happen inside the wrapper, so a missing argument
    # surfaces as an error within ErrorMonitoringWrapper's context.
    # NOTE(review): what the wrapper does with errors is defined in dataiku.base.utils.
    with ErrorMonitoringWrapper():
        folder_arg = sys.argv[1]
        mode_arg = sys.argv[2]
        main(folder_arg, mode_arg)
