import os

import tensorflow as tf

from dataiku.base.utils import TmpFolder
from dataiku.core import doctor_constants
from dataiku.doctor.deep_learning.keras_utils import retrieve_func_from_code_dict


def _check_model_output_dimension(model, prediction_type, target_map, modeling_params):
    """Validate that a built Keras model's output shape fits the prediction type.

    As a side effect, writes ``modeling_params["keras"]["oneDimensionalOutput"]``.
    It is True only for a binary-classification model whose output has a single
    column, and False otherwise.

    :param model: built Keras model; only its ``output_shape`` attribute is read.
    :param prediction_type: one of ``doctor_constants.REGRESSION``,
        ``doctor_constants.BINARY_CLASSIFICATION`` or ``doctor_constants.MULTICLASS``.
    :param target_map: class mapping; its length is the expected output width for
        multiclass problems.
    :param modeling_params: modeling parameters dict; mutated in place (see above).
    :raises ValueError: if the output is not 2-dimensional, or if its second
        dimension does not match the prediction type (1 for regression, 1 or 2 for
        binary classification, ``len(target_map)`` for multiclass).
    """
    output_shape = model.output_shape
    modeling_params["keras"]["oneDimensionalOutput"] = False

    if len(output_shape) != 2:
        raise ValueError("Output of Deep Learning must be 2-dimensional. It has currently a "
                         "dimension: {}".format(len(output_shape)))
    output_shape_dimension = output_shape[1]

    if prediction_type == doctor_constants.REGRESSION and output_shape_dimension != 1:
        raise ValueError("For regression problems, output of Deep Learning Architecture must have a "
                         "dimension equal to 1. It is currently: {}".format(output_shape_dimension))

    if prediction_type == doctor_constants.BINARY_CLASSIFICATION:
        # The former check ``not (d != 1 or d != 2)`` could never be True, so any
        # width was accepted. Only 1 (single score) or 2 (one column per class) is valid.
        if output_shape_dimension not in (1, 2):
            raise ValueError("For binary classification problems, output of Deep Learning Architecture must have a "
                             "dimension equal to 1 or 2. It is currently: {}".format(output_shape_dimension))
        modeling_params["keras"]["oneDimensionalOutput"] = (output_shape_dimension == 1)

    if prediction_type == doctor_constants.MULTICLASS and output_shape_dimension != len(target_map):
        raise ValueError("For this multiclass classification problem, output of Deep Learning "
                         "Architecture must have a dimension equal to {} (number of classes). "
                         "It is currently: {}".format(len(target_map), output_shape_dimension))


def _build_and_fit_model_tf1(input_shapes, output_num_labels, prediction_type, target_map, modeling_params,
                             keras_params, dic_build, run_folder_context, validation_sequence, save_model, fit_model,
                             train_sequence, gpu_config):
    """Build, validate, compile and train a Keras model on the TensorFlow 1.x path.

    The user code in ``dic_build`` must define ``build_model`` and
    ``compile_model``. When more than one GPU is in use, the template model is
    built on ``/cpu:0`` and wrapped with ``keras.utils.multi_gpu_model``. The
    wrapped model is what gets validated, compiled and trained, while the template
    is handed to the callbacks as ``base_model``.

    Training runs either through the user-supplied ``fit_model`` (when
    ``keras_params["advancedFitMode"]`` is truthy) or through ``fit_generator``
    with the epoch/step/shuffle options from ``keras_params``.

    ``gpu_config`` is accepted for signature parity with the TF2 path but is not
    read here.

    :raises ValueError: propagated from ``_check_model_output_dimension`` when the
        model output does not fit the prediction type.
    """
    from dataiku.doctor.deep_learning import gpu
    from keras.utils import multi_gpu_model
    from dataiku.doctor.deep_learning.keras_callbacks import get_base_callbacks

    gpu_count = gpu.get_num_gpu_used()
    spread_over_gpus = gpu_count > 1

    user_build_fn = retrieve_func_from_code_dict("build_model", dic_build, "Architecture")
    user_compile_fn = retrieve_func_from_code_dict("compile_model", dic_build, "Architecture")

    template_model = None
    if not spread_over_gpus:
        model = user_build_fn(input_shapes, output_num_labels)
    else:
        # The template lives on the CPU; multi_gpu_model replicates it on each GPU.
        with tf.device('/cpu:0'):
            template_model = user_build_fn(input_shapes, output_num_labels)
        model = multi_gpu_model(template_model, gpu_count)

    _check_model_output_dimension(model, prediction_type, target_map, modeling_params)
    model = user_compile_fn(model)

    tensorboard_context = run_folder_context.get_subfolder_context("tensorboard_logs")
    with tensorboard_context.get_folder_path_to_write(regularly_synchronize=True) as tensorboard_folder_path:
        callbacks = get_base_callbacks(run_folder_context, modeling_params, validation_sequence,
                                       prediction_type, target_map, save_model,
                                       use_multi_gpus=spread_over_gpus,
                                       base_model=template_model,
                                       tensorboard_folder_path=tensorboard_folder_path)

        if keras_params["advancedFitMode"]:
            # The user owns the training loop entirely.
            fit_model(model, train_sequence, validation_sequence, callbacks)
            return

        # Training driven by the options chosen in the UI.
        epoch_count = keras_params["epochs"]
        if keras_params["trainOnAllData"]:
            # None lets fit_generator run through the whole sequence each epoch.
            steps = None
        else:
            steps = keras_params["stepsPerEpoch"]
        model.fit_generator(train_sequence,
                            epochs=epoch_count,
                            steps_per_epoch=steps,
                            callbacks=callbacks,
                            shuffle=keras_params["shuffleData"])


def _build_and_fit_model_tf2(input_shapes, output_num_labels, prediction_type, target_map, modeling_params,
                             keras_params, dic_build, run_folder_context, validation_sequence, save_model, fit_model,
                             train_sequence, gpu_config):
    """Build, validate, compile and train a Keras model on the TensorFlow 2.x path.

    The user code in ``dic_build`` must define ``build_model`` and
    ``compile_model``. Model building, validation, compilation and training all
    happen inside a ``tf.distribute`` scope:

    * a ``MirroredStrategy`` over the devices in ``gpu_config["params"]["gpuList"]``
      when more than one GPU is in use;
    * the default strategy otherwise.

    Training runs either through the user-supplied ``fit_model`` (when
    ``keras_params["advancedFitMode"]`` is truthy) or through ``model.fit``. In the
    ``model.fit`` case, ``train_sequence`` is wrapped in a ``tf.data.Dataset`` and the
    ``cachePreprocessedData``, ``trainOnAllData``/``stepsPerEpoch``, ``shuffleData``
    and ``epochs`` options from ``keras_params`` are applied.

    :raises ValueError: propagated from ``_check_model_output_dimension`` when the
        model output does not fit the prediction type.
    """
    from dataiku.doctor.deep_learning.gpu import get_num_gpu_used
    from dataiku.doctor.deep_learning.keras_callbacks import get_base_callbacks

    build_model = retrieve_func_from_code_dict("build_model", dic_build, "Architecture")
    compile_model = retrieve_func_from_code_dict("compile_model", dic_build, "Architecture")

    def get_distribution_scope():
        # Returns the strategy's scope() context manager, not the strategy itself.
        num_gpus = get_num_gpu_used()
        if num_gpus > 1:
            gpu_params = gpu_config["params"]
            gpus = ["/gpu:" + str(g) for g in gpu_params["gpuList"]]
            return tf.distribute.MirroredStrategy(gpus).scope()
        return tf.distribute.get_strategy().scope()

    with get_distribution_scope():
        model = build_model(input_shapes, output_num_labels)
        _check_model_output_dimension(model, prediction_type, target_map, modeling_params)
        model = compile_model(model)

        with (run_folder_context.get_subfolder_context("tensorboard_logs")
                                .get_folder_path_to_write(regularly_synchronize=True)) as tensorboard_folder_path:
            # Pass the TensorBoard folder by keyword, exactly as the TF1 path does, so
            # it cannot be bound positionally to another optional parameter of
            # get_base_callbacks (the TF1 call shows use_multi_gpus / base_model exist).
            base_callbacks = get_base_callbacks(run_folder_context, modeling_params, validation_sequence,
                                                prediction_type, target_map, save_model,
                                                tensorboard_folder_path=tensorboard_folder_path)
            if keras_params["advancedFitMode"]:
                fit_model(model,
                          train_sequence,
                          validation_sequence,
                          base_callbacks)
            else:
                # Manually call fit on model with parameters from UI
                epochs = keras_params["epochs"]
                # from_generator needs a callable returning an iterable; the
                # sequence object itself is iterated.
                train_dataset = tf.data.Dataset.from_generator(
                    lambda: train_sequence,
                    args=[],
                    output_signature=train_sequence.get_output_signature()
                )

                # NOTE(review): TmpFolder presumably deletes the folder (and the
                # cache file below) on exit — confirm in dataiku.base.utils.
                with TmpFolder(run_folder_context.get_absolute_folder_path()) as tmp_folder_path:
                    if keras_params["cachePreprocessedData"]:
                        # The cache sits before shuffle, so cached data is still
                        # reshuffled on every pass.
                        train_dataset = train_dataset.cache(os.path.join(tmp_folder_path, "preprocess_training_data"))
                    steps_per_epoch = len(train_sequence) if keras_params["trainOnAllData"] else keras_params["stepsPerEpoch"]
                    if keras_params["shuffleData"]:
                        # Buffer holds steps_per_epoch elements of train_sequence
                        # (presumably whole batches — TODO confirm).
                        train_dataset = train_dataset.shuffle(buffer_size=steps_per_epoch, reshuffle_each_iteration=True)
                    # Infinite dataset: epoch length is set by steps_per_epoch below.
                    train_dataset = train_dataset.repeat()
                    model.fit(
                        train_dataset,
                        epochs=epochs,
                        steps_per_epoch=steps_per_epoch,
                        callbacks=base_callbacks,
                    )
