import logging

import numpy as np

from dataiku.core import doctor_constants
from dataiku.core.doctor_constants import TIMESERIES_IDENTIFIER_COLUMNS, TARGET_VARIABLE, PREDICTION_LENGTH, TIME_VARIABLE
from dataiku.doctor import step_constants
from dataiku.doctor.crossval.search_context import TimeseriesForecastingSearchContext
from dataiku.doctor.diagnostics import default_diagnostics
from dataiku.doctor.diagnostics import diagnostics
from dataiku.doctor.diagnostics.diagnostics import DiagnosticType
from dataiku.doctor.prediction.common import needs_hyperparameter_search
from dataiku.doctor.prediction.common import regridify_optimized_params
from dataiku.doctor.prediction.metric import MAPE
from dataiku.doctor.timeseries.preparation.auto_shifts_generation import TimeseriesAutoShiftsGenerator
from dataiku.doctor.utils import unix_time_millis
from dataiku.doctor.utils import write_done_traininfo
from dataiku.doctor.utils import get_hyperparams_search_time_traininfo
from dataiku.doctor.utils import write_hyperparam_search_time_traininfo
from dataiku.doctor.utils.gpu_execution import log_nvidia_smi_if_use_gpu
from dataiku.doctor.utils.listener import ModelStatusContext
from dataiku.doctor.utils.listener import ProgressListener
from dataiku.doctor.utils.split import load_train_set
from dataiku.doctor.timeseries.models import TimeseriesForecastingAlgorithm
from dataiku.doctor.timeseries.perf.model_perf import TimeseriesModelScorer
from dataiku.doctor.timeseries.preparation.preprocessing import add_rolling_windows_for_training, get_shift_map, get_windows_list, \
    has_external_features_or_windows, should_compute_auto_shifts, get_auto_shifts_past_only_range, \
    get_auto_shifts_known_in_advance_range, get_auto_shift_columns, is_shift_compatible
from dataiku.doctor.timeseries.preparation.preprocessing import TimeseriesPreprocessing
from dataiku.doctor.timeseries.train.split_handler import KFoldTimeseriesSplitHandler, CustomTrainTestTimeseriesSplitHandler
from dataiku.doctor.timeseries.train.training_handler import TimeseriesTrainingHandler
from dataiku.doctor.timeseries.train.training_handler import resample_for_training


logger = logging.getLogger(__name__)


def launch_training(core_params, modeling_sets, preprocessing_params, resampling_params,
                    run_folder_context, split_folder_context, split_desc):
    """
    Train, evaluate and save a time series forecasting model for a single modeling set.

    Overall flow:
      1. load the full train set, resample it and add rolling windows (windows are only
         computed when the model is shift/windows compatible);
      2. when enabled, generate auto shifts and write the selected shifts back into
         `preprocessing_params`;
      3. build the evaluation and hyperparameter search split handlers and prepare the full
         dataframe for splitting (possibly skipping time series that are too short);
      4. fit the preprocessing (external features) on the full dataframe;
      5. run the hyperparameter search and the evaluation, over several folds when
         `n_evaluation_splits > 1`, else over a single train/test split;
      6. retrain on the full data and save the model, compute residuals unless expensive
         reports are skipped, then save scores/forecasts/predicted data and the train info.

    :param core_params: core time series params (read keys include the custom train/test split
        settings, prediction length, time series identifiers, evaluation params and the
        "skip too short time series" flag)
    :param modeling_sets: list of modeling sets; exactly one is expected. Its "modelingParams" and
        "model_folder_context" are read, and its "listener" key is set at the end
    :param preprocessing_params: preprocessing params; mutated in place by auto shifts population
        when auto shifts feature generation is enabled
    :param resampling_params: resampling settings passed to `resample_for_training`
    :param run_folder_context: folder context used for preprocessing resources (rolling windows,
        auto shifts, full dataframe preprocessing)
    :param split_folder_context: folder context the "full" train set is loaded from
    :param split_desc: split description; its "params" and "schema" are read
    :raises ValueError: if external features are expected but were not produced by preprocessing,
        or if the algorithm requires external features and the full shift map is empty
    """

    log_nvidia_smi_if_use_gpu(core_params=core_params)

    split_params = split_desc["params"]
    schema = split_desc["schema"]

    # Number of evaluation splits: one per custom interval, nFolds for k-fold, otherwise a single train/test split
    if core_params.get(doctor_constants.CUSTOM_TRAIN_TEST_SPLIT):
        n_evaluation_splits = len(core_params.get(doctor_constants.CUSTOM_TRAIN_TEST_INTERVALS))
    elif split_params["kfold"]:
        n_evaluation_splits = split_params["nFolds"]
    else:
        n_evaluation_splits = 1

    assert len(modeling_sets) == 1, "There cannot be more than one modeling set for time series forecasting training"
    modeling_set = modeling_sets[0]
    modeling_params = modeling_set["modelingParams"]
    has_external_features = has_external_features_or_windows(preprocessing_params, modeling_params.get("isShiftWindowsCompatible", False))
    has_auto_shifts_feature_generation = should_compute_auto_shifts(preprocessing_params, modeling_params)
    metrics_params = modeling_params["metrics"]
    model_folder_context = modeling_set["model_folder_context"]

    default_diagnostics.register_forecasting_callbacks(core_params)

    start = unix_time_millis()

    # Fill all the listeners ASAP to have correct progress data
    # The steps declared here must mirror the steps pushed later in this function (same conditions)
    preprocessing_listener = ProgressListener()
    preprocessing_listener.add_future_steps(step_constants.TIMESERIES_LOAD_AND_RESAMPLE_STEPS)
    if has_auto_shifts_feature_generation:
        preprocessing_listener.add_future_step(step_constants.ProcessingStep.STEP_TIMESERIES_FEATURE_GENERATION)
    preprocessing_listener.add_future_step(step_constants.ProcessingStep.STEP_PREPARE_SPLITS)
    if has_external_features:
        preprocessing_listener.add_future_steps(step_constants.TIMESERIES_PREPROCESSING_STEPS)

    # Modeling steps are tracked on a child listener bound to the model folder
    listener = preprocessing_listener.new_child(ModelStatusContext(model_folder_context, start))
    if needs_hyperparameter_search(modeling_set.get('modelingParams', {})):
        listener.add_future_step(step_constants.ProcessingStep.STEP_HYPERPARAMETER_SEARCHING)
    if n_evaluation_splits > 1:
        listener.add_future_steps(step_constants.TIMESERIES_KFOLD_TRAIN_STEPS)
    else:
        listener.add_future_steps(step_constants.TIMESERIES_REGULAR_TRAIN_STEPS)

    if not modeling_params.get("skipExpensiveReports"):
        listener.add_future_step(step_constants.ProcessingStep.STEP_TIMESERIES_RESIDUALS)

    listener.add_future_steps(step_constants.TIMESERIES_SCORING_SAVING_STEPS)
    # Load dataset
    with preprocessing_listener.push_step(step_constants.ProcessingStep.STEP_LOADING_SRC):
        full_df = load_train_set(core_params, preprocessing_params, split_desc, "full",
                                 split_folder_context, use_diagnostics=False)

    # Beginning of preprocessing
    # The zero target ratio diagnostic is only requested when the evaluation metric is MAPE
    # (presumably because MAPE divides by the target values)
    compute_zero_target_ratio_diagnostic = metrics_params["evaluationMetric"] == MAPE.name
    with preprocessing_listener.push_step(step_constants.ProcessingStep.STEP_TIMESERIES_RESAMPLING):
        full_df = resample_for_training(
            full_df,
            schema,
            resampling_params,
            core_params,
            preprocessing_params,
            modeling_params.get("isShiftWindowsCompatible", False),
            compute_zero_target_ratio_diagnostic,
        )

        # Windowing is performed once and for all on the full dataframe
        windows_list = get_windows_list(preprocessing_params) if modeling_params.get("isShiftWindowsCompatible", False) else []
        full_df = add_rolling_windows_for_training(full_df, core_params, windows_list, preprocessing_params, run_folder_context)

    # Generate auto shifts once and for all
    # populate_auto_shifts writes the selected shifts back into preprocessing_params (in place)
    if has_auto_shifts_feature_generation:
        with preprocessing_listener.push_step(step_constants.ProcessingStep.STEP_TIMESERIES_FEATURE_GENERATION):
            generate_auto_shifts_resources(full_df, core_params, preprocessing_params, run_folder_context)
            populate_auto_shifts(preprocessing_params, run_folder_context)

    algorithm = TimeseriesForecastingAlgorithm.build(modeling_params["algorithm"])
    grid_search_params = modeling_params["grid_search_params"]
    # Hyperparameter search uses several folds only in TIME_SERIES_KFOLD mode
    n_search_splits = grid_search_params["nFolds"] if grid_search_params["mode"] == "TIME_SERIES_KFOLD" else 1

    # The fold offset is applied only when both splits use a k-fold strategy on the same base data (no custom intervals).
    # Its purpose is to prevent the test sets of evaluation and hp search from overlapping.
    is_kfold_evaluation = n_evaluation_splits > 1 and not core_params.get(doctor_constants.CUSTOM_TRAIN_TEST_SPLIT)
    is_kfold_hp_search = n_search_splits > 1
    is_fold_offset_applicable = is_kfold_evaluation and is_kfold_hp_search
    fold_offset = grid_search_params["foldOffset"] if is_fold_offset_applicable else False
    if fold_offset:
        logger.info("Time series split: using a fold offset for hyperparameters search and evaluation")

    equal_duration_folds = (n_evaluation_splits > 1 or n_search_splits > 1) and grid_search_params.get("equalDurationFolds", False)
    if equal_duration_folds:
        logger.info("Time series split: using equal duration train set folds")

    # Check if some time series will be skipped because they are too short
    with (preprocessing_listener.push_step(step_constants.ProcessingStep.STEP_PREPARE_SPLITS)):
        if core_params.get(doctor_constants.CUSTOM_TRAIN_TEST_SPLIT):
            evaluation_split_handler = CustomTrainTestTimeseriesSplitHandler(core_params)
        else:
            evaluation_split_handler = KFoldTimeseriesSplitHandler(n_evaluation_splits, core_params, fold_offset, equal_duration_folds)
        search_split_handler = KFoldTimeseriesSplitHandler(n_search_splits, core_params, fold_offset, equal_duration_folds)

        search_runner = algorithm.get_search_runner(search_split_handler, core_params, modeling_params, preprocessing_params,
                                                    model_folder_context=model_folder_context)
        search_needed = not search_runner.search_skipped()

        min_timeseries_size_for_training = algorithm.get_min_size_for_training(modeling_params, preprocessing_params, core_params[doctor_constants.PREDICTION_LENGTH])
        timeseries_identifier_columns = core_params[doctor_constants.TIMESERIES_IDENTIFIER_COLUMNS]
        # Skipping too short time series does not make sense for single time series
        skip_too_short_timeseries = len(timeseries_identifier_columns) > 0 and core_params[doctor_constants.SKIP_TOO_SHORT_TIMESERIES_FOR_TRAINING]

        # The smallest required time series size can either come from evaluation splitting or hyperparameter search splitting
        # We take the most restrictive condition to skip time series that are too short
        if (
            n_evaluation_splits > n_search_splits
            or not search_needed
            or core_params.get(doctor_constants.CUSTOM_TRAIN_TEST_SPLIT) # If we have custom interval actually we don't care about test size and only want to split and make sure the train part is ok within `prepare_split_df`
        ):
            full_df = evaluation_split_handler.prepare_split_dataframe(
                full_df, min_timeseries_size_for_training, skip_too_short_timeseries=skip_too_short_timeseries
            )
        else:
            # We add the test_size to min_timeseries_size_for_training because we always remove one test set before running hyperparam search
            test_size = core_params[doctor_constants.EVALUATION_PARAMS][doctor_constants.TEST_SIZE]
            full_df = search_split_handler.prepare_split_dataframe(
                full_df, min_timeseries_size_for_training + test_size, skip_too_short_timeseries=skip_too_short_timeseries
            )

    # We preprocess the full dataframe after potentially skipping time series that are too short for splitting
    full_timeseries_preprocessing = TimeseriesPreprocessing(run_folder_context, core_params, preprocessing_params, modeling_params, preprocessing_listener)

    # If model is gluon based, we run the preprocessing on the whole dataframe for the external features,
    # because the model is trained on the full dataframe.
    # In the statistical models case, we perform the preprocessing on a per time series basis, because
    # we train one model per time series.
    preprocess_on_full_df = algorithm.USE_GLUON_TS

    # No preprocessing is done when there are no external features
    # save_data=True: unlike the per-split preprocessing, the full preprocessing results are persisted
    transformed_full_df = full_timeseries_preprocessing.fit_and_process(
        full_df,
        step_constants.ProcessingStep.STEP_PREPROCESS_FULL,
        preprocess_on_full_df,
        save_data=True,
    )
    preprocessing_end = unix_time_millis()

    if has_external_features and not full_timeseries_preprocessing.external_features:
        raise ValueError("Missing preprocessed external features.")

    # The external features are used only if the algorithm supports them
    full_preprocessed_external_features = full_timeseries_preprocessing.external_features if algorithm.SUPPORTS_EXTERNAL_FEATURES else None
    use_external_features = bool(full_preprocessed_external_features)
    # Per time series id: the generated features mapping of its fitted preprocessing pipeline
    full_generated_features_mappings = {ts_id: pipeline.generated_features_mapping for ts_id, pipeline in full_timeseries_preprocessing.pipeline_by_timeseries.items()}
    full_shift_map = get_shift_map(preprocessing_params, full_generated_features_mappings)

    # All shifts and windows algos require at least one generated feature
    if algorithm.REQUIRES_EXTERNAL_FEATURES and full_shift_map.is_empty():
        raise ValueError("This algorithm requires external features, configure this under the 'Feature generation' tab.")

    # We initialize the training handler and search runner model scorer **after** full dataframe preprocessing,
    # so that we know for sure external features are used.
    model_scorer = TimeseriesModelScorer.build(core_params, metrics_params, use_external_features)
    training_handler = TimeseriesTrainingHandler(core_params, model_scorer, use_external_features, algorithm, model_folder_context, listener)
    search_runner.set_model_scorer(model_scorer)

    # Beginning of modeling
    model_start = unix_time_millis()
    if n_evaluation_splits > 1:
        # K-fold (or multi-interval custom) evaluation: search once on a dedicated search dataframe,
        # then train and score one model per evaluation fold
        search_df = None  # if no search, it can just be None
        if search_needed:
            search_df = get_search_df(full_df, core_params)
            # For preprocessing during HP search, we call 'fit_and_process' only once on search_df
            search_df, _, _, search_preprocessed_external_features, train_features_mappings = preprocess_train_test_external_features(
                use_external_features,
                search_df,
                None,
                None,
                core_params,
                preprocessing_params,
                modeling_params,
                run_folder_context,
                model_folder_context,
                preprocess_on_full_df,
                ProgressListener(),  # dummy listener not to crowd the model snippet, still logged
            )
            search_shift_map = get_shift_map(preprocessing_params, train_features_mappings)
            search_runner.set_preprocessed_external_features(search_preprocessed_external_features, search_shift_map)

        # hyperparameter search on the transformed search dataframe
        estimator, optimized_modeling_params = search_best_hyperparameters(search_df, search_runner, modeling_params, training_handler, model_folder_context, listener)

        # Not a context manager: this step is popped explicitly after the fold loop below
        listener.push_step(step_constants.ProcessingStep.KFOLD_STEP_PROCESSING_FOLD)
        listener.save_status()

        for fold_id, (train_df, test_df, historical_df) in enumerate(evaluation_split_handler.split(full_df)):
            # One "[i/n]" sub-step per fold, popped at the end of the iteration
            listener.push_step("[{}/{}]".format(fold_id + 1, n_evaluation_splits))
            listener.save_status()

            # Preprocessing is refitted on each fold's train part, then applied to its test and historical parts
            train_df, test_df, historical_df, train_fold_preprocessed_external_features, train_fold_features_mappings = preprocess_train_test_external_features(
                use_external_features,
                train_df,
                test_df,
                historical_df,
                core_params,
                preprocessing_params,
                modeling_params,
                run_folder_context,
                model_folder_context,
                preprocess_on_full_df,
                listener
            )
            fold_shift_map = get_shift_map(preprocessing_params, train_fold_features_mappings)
            training_handler.train(estimator, optimized_modeling_params, train_df, test_df=test_df,
                                   historical_df=historical_df,
                                   preprocessed_external_features=train_fold_preprocessed_external_features,
                                   shift_map=fold_shift_map, score_model=True, fold_id=fold_id)

            listener.pop_step()

        listener.pop_step()

    else:  # simple train/test split for the final evaluation

        train_df, test_df, historical_df = next(evaluation_split_handler.split(full_df)) # Only one split
        train_df, test_df, historical_df, train_preprocessed_external_features, train_generated_features_mappings = preprocess_train_test_external_features(
            use_external_features,
            train_df,
            test_df,
            historical_df,
            core_params,
            preprocessing_params,
            modeling_params,
            run_folder_context,
            model_folder_context,
            preprocess_on_full_df,
            ProgressListener(),  # dummy listener not to crowd the model snippet, still logged
        )
        train_shift_map = get_shift_map(preprocessing_params, train_generated_features_mappings)
        search_runner.set_preprocessed_external_features(train_preprocessed_external_features, train_shift_map)
        # The search reuses the preprocessed train part of the single evaluation split
        search_df = train_df

        # hyperparameter search on the train transformed dataframe
        estimator, optimized_modeling_params = search_best_hyperparameters(search_df, search_runner, modeling_params, training_handler, model_folder_context, listener)

        training_handler.train(estimator, optimized_modeling_params, train_df, test_df=test_df, historical_df=historical_df,
                               preprocessed_external_features=train_preprocessed_external_features, shift_map=train_shift_map,
                               score_model=True, fold_id=0, step_name=step_constants.ProcessingStep.STEP_FITTING_FOR_EVAL)

    # retrain model on full data and save it alongside its parameters (also predict future values if no external features)
    training_handler.train(estimator, optimized_modeling_params, transformed_full_df,
                           preprocessed_external_features=full_preprocessed_external_features, shift_map=full_shift_map,
                           save_model=True, step_name=step_constants.ProcessingStep.STEP_TIMESERIES_FITTING_GLOBAL)

    if not modeling_params.get("skipExpensiveReports"):
        with listener.push_step(step_constants.ProcessingStep.STEP_TIMESERIES_RESIDUALS):
            # Minimum scoring size is derived from the resolved params of the final estimator
            min_size_for_scoring = algorithm.get_min_size_for_scoring(
                algorithm.get_actual_params(optimized_modeling_params, estimator, fit_params=None)['resolved'],
                preprocessing_params,
                core_params[doctor_constants.PREDICTION_LENGTH])
            residuals_model_folder_context = model_folder_context.get_subfolder_context("residuals")
            residuals_model_folder_context.create_if_not_exist()
            estimator.compute_residuals(transformed_full_df, min_size_for_scoring, residuals_model_folder_context, None)

    with listener.push_step(step_constants.ProcessingStep.STEP_SAVING_DATA):
        training_handler.save_intrinsic_scores_and_forecasts(full_df, optimized_modeling_params, preprocessing_params, estimator)
        training_handler.save_scores()
        training_handler.save_predicted_data(full_df, preprocessing_params, schema)

    end = unix_time_millis()
    # Expose the modeling listener on the modeling set, then merge both listeners' timings into the train info
    modeling_set["listener"] = listener
    listeners_json = preprocessing_listener.merge(modeling_set["listener"])
    write_done_traininfo(model_folder_context, start, model_start, end, listeners_json, end_preprocessing_time=preprocessing_end)


def search_best_hyperparameters(df, search_runner, modeling_params, training_handler, model_folder_context, listener):
    """
    Run the hyperparameter search (unless the search runner skips it) and resolve the optimized params.

    When a search is actually performed:
      - a STEP_HYPERPARAMETER_SEARCHING step is pushed on `listener`, with the duration of the
        previous search (read from the train info) passed as `previous_duration`;
      - a reproducibility diagnostic is raised when the search uses more than one thread;
      - once done, the search duration is written to the train info and the search score info
        is added to the handler's intrinsic perf data.

    :param df: dataframe the search context is initialized on (the k-fold caller passes None
        when no search is needed)
    :param search_runner: search runner providing the best estimator
    :param modeling_params: modeling params used to resolve and regridify the best params
    :param training_handler: training handler whose algorithm resolves the actual params
    :param model_folder_context: model folder where the search time train info is read/written
    :param listener: progress listener for the search step
    :return: tuple (best_estimator, optimized_modeling_params), the latter being the resolved
        params of the best estimator regridified to a unary grid
    """
    last_search_duration = get_hyperparams_search_time_traininfo(model_folder_context)
    is_searching = not search_runner.search_skipped()

    if is_searching:
        listener.push_step(step_constants.ProcessingStep.STEP_HYPERPARAMETER_SEARCHING,
                           previous_duration=last_search_duration)
        listener.save_status()
        uses_several_threads = search_runner.search_settings.n_threads != 1
        if uses_several_threads:
            # Warn against unreproducible hyperparameter search (sc-133647)
            warning_text = "Hyperparameter search might not be reproducible for time series forecasting models when using more than one thread."
            diagnostics.add_or_update(DiagnosticType.ML_DIAGNOSTICS_REPRODUCIBILITY, warning_text)
            logger.warning(warning_text)

    # The search context is initialized (and the best estimator fetched) even when the search is skipped
    search_runner.initialize_search_context(df)
    best_estimator = search_runner.get_best_estimator()

    if is_searching:
        finished_step = listener.pop_step()
        write_hyperparam_search_time_traininfo(model_folder_context, finished_step["time"])
        # Expose the search results alongside the intrinsic performance data
        training_handler.update_intrinsic_perf_data(search_runner.get_score_info())

    resolved_params = training_handler.algorithm.get_actual_params(
        modeling_params, best_estimator, fit_params=None
    )["resolved"]
    # Turn the optimized params back into a unary grid so they can be reused as modeling params
    return best_estimator, regridify_optimized_params(resolved_params, modeling_params)


def preprocess_train_test_external_features(
        use_external_features,
        train_df, test_df, historical_df,
        core_params, preprocessing_params, modeling_params,
        preprocessing_folder_context, model_folder_context,
        preprocess_on_full_df, listener
):
    """
    Preprocess the external features of one train/test split.

    When `use_external_features` is falsy, the three dataframes are returned untouched, together
    with None for both the external features and the features mappings.

    Otherwise, a fresh `TimeseriesPreprocessing` is fitted on `train_df` with 'fit_and_process'
    (nothing saved), then `test_df` and `historical_df` are transformed with 'process' when they
    are not None (None is passed through).

    :return: tuple (train_df, test_df, historical_df, preprocessed_external_features,
        train_features_mappings) where train_features_mappings maps each time series id to the
        `generated_features_mapping` of its fitted pipeline
    """
    if not use_external_features:
        return train_df, test_df, historical_df, None, None

    # The use of `model_folder_context` instead of the apparently more suitable `preprocessing_folder_context` is intentional !
    # As a consequence of reusing much of the classical ML preprocessing logic, `TimeseriesPreprocessing._data_folder_context`
    # is being scanned during the `init_resources` call of each preprocessing step, and, if some resource is found,
    # this resource is loaded in the step `resource` attribute.
    # Here, resources have already been saved in the `preprocessing_folder_context` by the previous `full_timeseries_preprocessing.fit_and_process` call.
    # However, since we're ignoring these resources here (they are only used in subsequent scoring), we do not want `init_resources` to
    # load anything so we pass a "decoy" context, namely `model_folder_context`, where no resource is to be found.
    split_preprocessing = TimeseriesPreprocessing(
        model_folder_context,
        core_params,
        preprocessing_params,
        modeling_params,
        listener,
        windows_resources_folder_context=preprocessing_folder_context,
    )

    processed_train_df = split_preprocessing.fit_and_process(
        train_df,
        step_constants.ProcessingStep.STEP_PREPROCESS_TRAIN,
        preprocess_on_full_df,
        save_data=False,
    )

    # Captured right after fitting, before any test-like dataframe is processed
    external_features = split_preprocessing.external_features
    features_mappings = {
        timeseries_id: ts_pipeline.generated_features_mapping
        for timeseries_id, ts_pipeline in split_preprocessing.pipeline_by_timeseries.items()
    }

    def _process_test_like(frame):
        # Test and historical dataframes are only transformed with the pipelines fitted on train
        if frame is None:
            return None
        return split_preprocessing.process(
            frame,
            step_constants.ProcessingStep.STEP_PREPROCESS_TEST,
            preprocess_on_full_df,
        )

    processed_test_df = _process_test_like(test_df)
    processed_historical_df = _process_test_like(historical_df)

    return processed_train_df, processed_test_df, processed_historical_df, external_features, features_mappings


def get_search_df(full_df, core_params):
    """
    Build the dataframe the hyperparameter search runs on.

    When several search intervals are returned by `get_search_intervals` (custom train/test
    intervals are forwarded only if the custom split is enabled), keep the rows whose time value
    falls in at least one [start, end) interval. Otherwise, remove one test set from `full_df`
    by taking the train part of a single-fold k-fold split.

    :param full_df: full (resampled) dataframe
    :param core_params: core params; the time variable and custom split settings are read
    :return: the search dataframe
    """
    is_custom_split = core_params.get(doctor_constants.CUSTOM_TRAIN_TEST_SPLIT)
    custom_intervals = core_params.get(doctor_constants.CUSTOM_TRAIN_TEST_INTERVALS) if is_custom_split else None
    time_column_name = core_params[TIME_VARIABLE]
    search_intervals = TimeseriesForecastingSearchContext.get_search_intervals(full_df, time_column_name, custom_intervals)

    if len(search_intervals) <= 1:
        # remove 1 test set from full_df to get search_df
        single_fold_handler = KFoldTimeseriesSplitHandler(1, core_params)
        search_df, _, _ = next(single_fold_handler.split(full_df))
        return search_df

    timestamps = full_df[time_column_name]
    # A row is kept as soon as it lies within any of the half-open search intervals
    in_any_interval = np.logical_or.reduce([
        (timestamps >= interval_start) & (timestamps < interval_end)
        for interval_start, interval_end in search_intervals
    ])
    return full_df[in_any_interval]


def generate_auto_shifts_resources(full_df, core_params, preprocessing_params, run_folder_context):
    """
    Compute the auto shifts resources on the full dataframe with a `TimeseriesAutoShiftsGenerator`.

    The generator is configured with the past-only and known-in-advance auto shift columns and
    ranges, the time series identifiers, target and prediction length from `core_params`, and the
    "max_selected_horizon_shifts" auto shifts param (0 when absent). The resources are handled by
    the generator within `run_folder_context`.
    """
    shift_columns = get_auto_shift_columns(preprocessing_params)
    past_only_range = get_auto_shifts_past_only_range(preprocessing_params)
    known_in_advance_range = get_auto_shifts_known_in_advance_range(preprocessing_params)
    auto_shifts_params = preprocessing_params.get("feature_generation", {}).get("auto_shifts_params", {})
    max_selected_horizon_shifts = auto_shifts_params.get("max_selected_horizon_shifts", 0)

    generator = TimeseriesAutoShiftsGenerator(
        run_folder_context,
        core_params[TIMESERIES_IDENTIFIER_COLUMNS],
        core_params[TARGET_VARIABLE],
        shift_columns["past_only"],
        shift_columns["known_in_advance"],
        past_only_range,
        known_in_advance_range,
        core_params[PREDICTION_LENGTH],
        max_selected_horizon_shifts,
    )
    generator.process(full_df)


def populate_auto_shifts(preprocessing_params, run_folder_context):
    """
    Copy the auto-selected shifts into the preprocessing params and persist them.

    For every shift-compatible column configured with `from_horizon_mode == 'AUTO'`, the
    "selected_shifts" loaded from the auto shifts resource are stored (in place) as
    "from_horizon_auto". The updated params are then written to "rpreprocessing_params.json"
    in `run_folder_context`. A modeling-parameters diagnostic is raised when no column got any
    selected shift.
    """
    auto_shifts = TimeseriesAutoShiftsGenerator.load_auto_shifts_resource(run_folder_context)
    shifts_config = preprocessing_params["feature_generation"]["shifts"]
    found_selected_shift = False

    for column in list(shifts_config.keys()):
        column_role = preprocessing_params["per_feature"][column]["role"]
        if not is_shift_compatible(role=column_role) if False else not is_shift_compatible(column_role):
            continue
        column_shift_params = shifts_config[column]
        if column_shift_params.get("from_horizon_mode", 'FIXED') != 'AUTO':
            continue
        selected_shifts = auto_shifts["aggregated"][column]["selected_shifts"]
        column_shift_params["from_horizon_auto"] = selected_shifts
        if len(selected_shifts) > 0:
            found_selected_shift = True

    run_folder_context.write_json("rpreprocessing_params.json", preprocessing_params)

    if not found_selected_shift:
        message = "Auto-shifts feature generation did not detect any correlated values."
        diagnostics.add_or_update(DiagnosticType.ML_DIAGNOSTICS_MODELING_PARAMETERS, message)
        logger.warning(message)
