import logging
import math
import pandas as pd
import json

from dataiku.core import doctor_constants
from dataiku.doctor import utils, step_constants
from dataiku.core import intercom
from dataiku.doctor.diagnostics import diagnostics
from dataiku.doctor.diagnostics.diagnostics import DiagnosticType
from dataiku.doctor.prediction.custom_scoring import get_custom_evaluation_metric
from dataiku.doctor.timeseries.preparation.preprocessing import TimeseriesPreprocessing
from dataiku.doctor.timeseries.models import TimeseriesForecastingAlgorithm
from dataiku.doctor.timeseries.perf.model_perf import TimeseriesModelScorer, PER_TIMESERIES_METRICS, TIMESERIES_AGGREGATED_METRICS
from dataiku.doctor.timeseries.preparation.resampling.utils import get_frequency
from dataiku.doctor.timeseries.preparation.preprocessing import get_external_features
from dataiku.doctor.timeseries.score.scoring_handler import resample_for_scoring
from dataiku.doctor.timeseries.utils import add_timeseries_identifiers_columns, prefix_custom_metric_name
from dataiku.doctor.timeseries.utils import build_quantile_column_name
from dataiku.doctor.timeseries.utils import add_ignored_timeseries_diagnostics_and_logs
from dataiku.doctor.timeseries.utils import timeseries_iterator
from dataiku.doctor.timeseries.utils import FORECAST_COLUMN
from dataiku.doctor.timeseries.utils import ModelForecast
from dataiku.doctor.utils.listener import DiagOnlyContext
from dataiku.doctor.utils.listener import ProgressListener


logger = logging.getLogger(__name__)

# Prefix of the aggregated metric keys holding the worst value of a metric over all evaluated time series
# (e.g. "worstMase"); the keys are built in _compute_and_append_worst_metrics.
WORST_PREFIX = 'worst'


class TimeseriesEvaluationHandler(object):
    """Evaluate a time series forecasting model on a dataset, for the evaluation recipe.

    The handler resamples the input data (and preprocesses it when the algorithm uses external features),
    forecasts the last horizons of each time series, and optionally computes metrics and saves scores and
    forecasts into a model evaluation store folder.
    """

    def __init__(self, core_params, preprocessing_params, modeling_params, metrics_params, clf, preprocessing_folder_context,
                 model_evaluation_store_folder_context, diagnostics_folder_context):
        """
        Args:
            core_params (dict): Core params (prediction length, time/target variables, time series identifier
                columns, evaluation params such as the gap size, frequency settings).
            preprocessing_params (dict): Preprocessing params, including the time series resampling params.
            modeling_params (dict): Modeling params; its "algorithm" entry selects the forecasting algorithm.
            metrics_params (dict): Metrics params; its "customMetrics" entry lists the custom metrics.
            clf: Model to evaluate.
            preprocessing_folder_context: Folder context given to the time series preprocessing (saved resources).
            model_evaluation_store_folder_context: Folder context where scores and forecasts are written, or None
                when the recipe does not output to a model evaluation store.
            diagnostics_folder_context: Folder context of the diagnostics-only progress listener.
        """
        self.core_params = core_params
        self.preprocessing_params = preprocessing_params
        self.modeling_params = modeling_params
        self.clf = clf

        self.resampling_params = self.preprocessing_params[doctor_constants.TIMESERIES_SAMPLING]

        self.prediction_length = core_params[doctor_constants.PREDICTION_LENGTH]
        self.time_variable = core_params[doctor_constants.TIME_VARIABLE]
        self.target_variable = core_params[doctor_constants.TARGET_VARIABLE]
        self.timeseries_identifier_columns = core_params[doctor_constants.TIMESERIES_IDENTIFIER_COLUMNS]
        self.frequency = get_frequency(core_params)

        evaluation_params = core_params[doctor_constants.EVALUATION_PARAMS]
        self.gap_size = evaluation_params[doctor_constants.GAP_SIZE]

        self.algorithm = TimeseriesForecastingAlgorithm.build(self.modeling_params["algorithm"])
        self.min_timeseries_size = self.algorithm.get_min_size_for_scoring(self.modeling_params, self.prediction_length)

        self.custom_metrics = metrics_params["customMetrics"]

        # External features are only considered for algorithms that support them
        self.external_features = (
            get_external_features(self.preprocessing_params) if self.algorithm.SUPPORTS_EXTERNAL_FEATURES else None
        )

        context = DiagOnlyContext(diagnostics_folder_context)
        self.listener = ProgressListener(context=context)

        self.timeseries_preprocessing = TimeseriesPreprocessing(
            preprocessing_folder_context, core_params, preprocessing_params, self.listener
        )
        self.has_model_evaluation_store = model_evaluation_store_folder_context is not None
        self.model_evaluation_store_folder_context = model_evaluation_store_folder_context
        if self.external_features:
            self.timeseries_preprocessing.load_resources()

    def evaluate(self, df, schema, quantiles, max_nb_forecast_timesteps, output_metrics, partition_columns=None, compute_metrics=False, compute_per_timeseries_metrics=False, refit=False):
        """Compute forecasts and metrics for an evaluation recipe.

        Args:
            df (DataFrame): Input dataframe on which to evaluate.
            schema (dict): Schema of the input dataframe.
            quantiles (list[float]): Quantiles used in forecast.
            max_nb_forecast_timesteps (int): Maximum number of forecast timesteps (starting from the end) to include
                in the output; rounded up to a whole number of horizons. Values <= 0 mean no limit.
            output_metrics (list[str]): List of metrics to output.
            partition_columns (list, optional): Columns from which to extract partitioning values, in partition
                dispatch mode. Defaults to None.
            compute_metrics (bool, optional): Whether to compute the metrics dataframe. Defaults to False.
            compute_per_timeseries_metrics (bool, optional): Whether to compute per-timeseries metrics rather than
                global aggregated metrics. Defaults to False.
            refit (bool, optional): Whether to refit the model on input data (only valid for statistical,
                non-gluonts models).

        Returns:
            (DataFrame, DataFrame): The forecasts computed over the last horizons of each time series, merged with
                the ground truth and other columns of the resampled input data; and the metrics of each time series
                or aggregated over all time series (empty when compute_metrics is False).
        """
        logger.info("Running evaluation recipe with params: %s", self.modeling_params)

        resampled_df, preprocessed_df = self._resample_and_preprocess(df, schema)

        model_scorer = TimeseriesModelScorer(
            self.target_variable,
            self.time_variable,
            self.timeseries_identifier_columns,
            self.prediction_length,
            self.gap_size,
            quantiles,
            bool(self.timeseries_preprocessing.external_features),
            self.frequency,
            self.custom_metrics,
            max_nb_forecast_timesteps=max_nb_forecast_timesteps
        )

        with self.listener.push_step(step_constants.ProcessingStep.STEP_SCORING):
            train_df, test_df = self._split_train_test(preprocessed_df, max_nb_forecast_timesteps)

            fit_before_predict = self.algorithm.should_fit_before_predict(force=refit)
            forecasts_by_timeseries = model_scorer.predict_all_test_timesteps(
                self.clf, train_df, test_df, fit_before_predict
            )

        metrics_df = pd.DataFrame()
        per_timeseries_metrics, aggregated_metrics = {}, {}

        if self.has_model_evaluation_store or compute_metrics:
            per_timeseries_metrics, aggregated_metrics = model_scorer.score(
                train_df, test_df, forecasts_by_timeseries,
                append_forecasts=True,
                # the model evaluation store always needs the aggregated metrics
                compute_aggregated_metrics=(self.has_model_evaluation_store or not compute_per_timeseries_metrics),
            )

        if self.has_model_evaluation_store:
            # Use the shared key constants: _compute_and_append_worst_metrics and _save_score index the score by them
            score = {
                PER_TIMESERIES_METRICS: per_timeseries_metrics,
                TIMESERIES_AGGREGATED_METRICS: aggregated_metrics,
            }
            _compute_and_append_worst_metrics(score, self.custom_metrics)
            _save_score(self.timeseries_identifier_columns, model_scorer, score, self.model_evaluation_store_folder_context)
            _save_forecasts(model_scorer, self.model_evaluation_store_folder_context)
            _update_intrinsic_perf(model_scorer, self.model_evaluation_store_folder_context)

        if compute_metrics:
            if compute_per_timeseries_metrics:
                metrics_df = _build_per_timeseries_metrics_df(per_timeseries_metrics, output_metrics)
            else:
                metrics_df = _build_aggregated_metrics_df(aggregated_metrics, output_metrics)

        forecasts_df = self._build_forecasts_df(forecasts_by_timeseries, quantiles, resampled_df)

        # Fill up partition_columns if needed
        if partition_columns:
            # In partition dispatch mode, retrieve partition column values from the input dataframe
            for partition_column in partition_columns:
                forecasts_df[partition_column] = df.iloc[0][partition_column]

        return forecasts_df, metrics_df

    def _resample_and_preprocess(self, df, schema):
        """Resample the input dataframe, then preprocess it when external features are used.

        Returns:
            (DataFrame, DataFrame): The resampled dataframe, and the dataframe to score: the preprocessed resampled
                dataframe, or the resampled dataframe itself when there are no external features.
        """
        with self.listener.push_step(step_constants.ProcessingStep.STEP_TIMESERIES_RESAMPLING):
            # Resample/preprocess full df. Resampling sorts the dataframe by time.
            resampled_df = resample_for_scoring(
                df,
                schema,
                self.resampling_params,
                self.core_params,
                self.preprocessing_params,
                self.external_features,
                separate_external_features=False
            )

        if not self.external_features:
            return resampled_df, resampled_df

        preprocess_on_full_df = self.algorithm.USE_GLUON_TS

        self.timeseries_preprocessing.create_timeseries_preprocessing_handlers(
            resampled_df, preprocess_on_full_df, use_saved_resources=True
        )

        preprocessed_df = self.timeseries_preprocessing.process(
            resampled_df,
            step_constants.ProcessingStep.STEP_PREPROCESS_TEST,
            preprocess_on_full_df,
        )
        return resampled_df, preprocessed_df

    def _split_train_test(self, preprocessed_df, max_nb_forecast_timesteps):
        """Split every evaluable time series into a train part and a test part (its last horizons).

        Time series unknown to a non-GluonTS model, or too short to be evaluated, are skipped and reported
        through diagnostics.

        Returns:
            (DataFrame, DataFrame): Concatenated train parts and concatenated test parts of all evaluated series.
        """
        # To evaluate a time series, it must be bigger than the minimum required time series size for scoring + horizon
        min_required_length_for_evaluation = self.min_timeseries_size + self.prediction_length

        unseen_timeseries_identifiers = []
        too_short_timeseries_identifiers = []
        train_dfs, test_dfs = [], []
        for timeseries_identifier, df_of_timeseries_identifier in timeseries_iterator(preprocessed_df, self.timeseries_identifier_columns):
            timeseries_length = len(df_of_timeseries_identifier.index)

            if not self.algorithm.USE_GLUON_TS and timeseries_identifier not in self.clf.trained_models:
                unseen_timeseries_identifiers.append(timeseries_identifier)
            elif timeseries_length < min_required_length_for_evaluation:
                too_short_timeseries_identifiers.append(timeseries_identifier)
            else:
                # Maximum number of horizons we can evaluate: number of forecast horizons (prediction_length) that
                # fit in the time steps after the min required length (timeseries_length - self.min_timeseries_size)
                nb_horizons_to_evaluate = math.floor((timeseries_length - self.min_timeseries_size) / self.prediction_length)
                # Enforce number of timesteps set by user (rounded up to whole horizons)
                if max_nb_forecast_timesteps > 0:
                    timesteps_as_horizons = math.ceil(max_nb_forecast_timesteps / self.prediction_length)
                    nb_horizons_to_evaluate = min(nb_horizons_to_evaluate, timesteps_as_horizons)

                # Test part = last nb_horizons_to_evaluate * horizon time steps. nb_horizons_to_evaluate >= 1 here
                # since timeseries_length >= min_required_length_for_evaluation. Use positional slicing explicitly.
                nb_test_timesteps = nb_horizons_to_evaluate * self.prediction_length
                train_dfs.append(df_of_timeseries_identifier.iloc[:-nb_test_timesteps])
                test_dfs.append(df_of_timeseries_identifier.iloc[-nb_test_timesteps:])

        add_ignored_timeseries_diagnostics_and_logs(
            self.timeseries_identifier_columns,
            unseen_timeseries_identifiers,
            too_short_timeseries_identifiers,
            all_timeseries_ignored=not train_dfs,
            min_required_length=min_required_length_for_evaluation,
            recipe_type="evaluation",
            diagnostic_type=DiagnosticType.ML_DIAGNOSTICS_EVALUATION_DATASET_SANITY_CHECKS,
        )

        train_df = pd.concat(train_dfs, ignore_index=True)
        test_df = pd.concat(test_dfs, ignore_index=True)
        return train_df, test_df

    def _build_forecasts_df(self, forecasts_by_timeseries, quantiles, resampled_df):
        """Build the output forecasts dataframe.

        One dataframe per time series is built (quantile columns, forecast column, time variable and time series
        identifier columns), they are concatenated, then left-merged on identifiers + time with the resampled data,
        so that ground truth and other features come from the resampled data and not the preprocessed data.
        """
        forecasts_dfs = []
        for timeseries_id, forecast in forecasts_by_timeseries.items():
            # 1. Forecast quantiles
            forecast_df_of_timeseries_id = pd.DataFrame(
                forecast[ModelForecast.QUANTILES_FORECASTS].T,
                columns=[build_quantile_column_name(q) for q in quantiles],
            )
            # 2. Forecast values
            forecast_df_of_timeseries_id[FORECAST_COLUMN] = forecast[ModelForecast.FORECAST_VALUES]

            # 3. Time variable
            forecast_df_of_timeseries_id[self.time_variable] = pd.to_datetime(forecast[ModelForecast.TIMESTAMPS])

            # 4. Time series identifiers if any
            add_timeseries_identifiers_columns(forecast_df_of_timeseries_id, timeseries_id)

            forecasts_dfs.append(forecast_df_of_timeseries_id)

        forecasts_df = pd.concat(forecasts_dfs, ignore_index=True)

        return forecasts_df.merge(
            resampled_df, on=(self.timeseries_identifier_columns or []) + [self.time_variable], how="left"
        )


def _save_forecasts(model_scorer, folder_context):
    """Write the evaluated forecasts into the model evaluation folder.

    "forecasts.json.gz" receives the forecasts of at most N time series, N being returned by the backend call
    "ml/prediction/get-max-nb-timeseries-in-forecast-charts"; when more time series were evaluated, only the first
    N are written and a diagnostic is emitted. When more than one time series was evaluated, the forecasts of the
    first one are also written to "one_forecast.json" (used by the model evaluation snippets).
    """
    all_forecasts = model_scorer.forecasts
    nb_timeseries = len(all_forecasts)
    nb_timeseries_limit = intercom.jek_or_backend_get_call("ml/prediction/get-max-nb-timeseries-in-forecast-charts")
    is_truncated = nb_timeseries > nb_timeseries_limit

    if is_truncated:
        # keep only the first nb_timeseries_limit time series, in iteration order
        kept_identifiers = list(all_forecasts)[:nb_timeseries_limit]
        forecasts_to_save = {identifier: all_forecasts[identifier] for identifier in kept_identifiers}
    else:
        forecasts_to_save = all_forecasts

    folder_context.write_json("forecasts.json.gz",
                              {"perTimeseries": model_scorer.remove_naninf_in_forecasts(forecasts_to_save)})

    if is_truncated:
        warning = "Only the first {} out of the total {} evaluated time series will be displayed in the model report forecast charts.".format(nb_timeseries_limit, nb_timeseries)
        diagnostics.add_or_update(DiagnosticType.ML_DIAGNOSTICS_DATASET_SANITY_CHECKS, warning)

    if nb_timeseries <= 1:
        return

    # the model evaluation snippets only need the forecasts of the first time series
    first_identifier, first_forecasts = next(iter(all_forecasts.items()))
    folder_context.write_json("one_forecast.json",
                              {"perTimeseries": {first_identifier: model_scorer.remove_nanif_in_one_forecast(first_forecasts)}})


def _get_per_timeseries_perf_file_name(timeseries_identifier_columns):
    """Save compressed perf json file only if there are multiple time series"""
    if timeseries_identifier_columns:
        return "per_timeseries_perf.json.gz"
    else:
        return "per_timeseries_perf.json"


def _save_score(timeseries_identifier_columns, model_scorer, score, folder_context):
    """Save the evaluation score, with NaN/inf cleaned by the scorer, as two perf files.

    'perf.json' gets only the aggregated metrics and the per-timeseries perf file gets only the per-timeseries
    metrics; each file keeps both keys, with an empty dict for the part it does not hold.
    """
    clean_score = model_scorer.remove_naninf(score)
    aggregated_part = clean_score[TIMESERIES_AGGREGATED_METRICS]
    per_timeseries_part = clean_score[PER_TIMESERIES_METRICS]

    # Two files are written because model evaluations are often read without needing the per-timeseries metrics
    # (e.g. when listing the model evaluations of a store with their aggregated metrics); keeping them apart avoids
    # making the backend parse everything every time.
    perf_files = [
        ('perf.json', {
            TIMESERIES_AGGREGATED_METRICS: aggregated_part,
            PER_TIMESERIES_METRICS: {},
        }),
        (_get_per_timeseries_perf_file_name(timeseries_identifier_columns), {
            TIMESERIES_AGGREGATED_METRICS: {},
            PER_TIMESERIES_METRICS: per_timeseries_part,
        }),
    ]
    for file_name, content in perf_files:
        folder_context.write_json(file_name, content)


def _update_intrinsic_perf(model_scorer, folder_context):
    """Record the number of evaluated time series ("totalNbOfTimeseries") in iperf.json, read-modify-write."""
    iperf_file_name = "iperf.json"
    intrinsic_perf = folder_context.read_json(iperf_file_name)
    intrinsic_perf["totalNbOfTimeseries"] = len(model_scorer.forecasts)
    folder_context.write_json(iperf_file_name, intrinsic_perf)


def _build_aggregated_metrics_df(aggregated_metrics, output_metrics):
    """Build a one-row metrics dataframe from the aggregated metrics.

    Only the requested output metrics that were actually computed are kept (in output_metrics order), and a
    "date" column holds the timestamp of the evaluation session.
    """
    selected_metrics = {}
    for metric_name in output_metrics:
        if metric_name not in aggregated_metrics:
            continue
        selected_metrics[metric_name] = aggregated_metrics[metric_name]

    metrics_df = pd.DataFrame([selected_metrics])
    metrics_df["date"] = utils.get_datetime_now_utc()
    return metrics_df


def _is_undefined_metric_value(metric_value):
    """Return True when a metric value cannot be ranked (None or NaN)."""
    if metric_value is None:
        return True
    try:
        return math.isnan(metric_value)
    except TypeError:
        # non-numeric values are not NaN; they go through the regular comparison as before
        return False


def _compute_and_append_worst_metrics(score, custom_metrics):
    """Add to the aggregated metrics of `score`, in place, the worst value of each metric over all time series.

    For a metric key "mase" the result is stored under "worstMase" (WORST_PREFIX + capitalized key). Default
    time series metrics are errors (lower is better); custom metrics use their "greaterIsBetter" flag.

    Undefined values (None / NaN) are ignored when a defined value exists for the same metric, so the result does
    not depend on the iteration order of the time series; if every value of a metric is undefined, the worst key
    is still set (to an undefined value) as before.

    Args:
        score (dict): Score with per-timeseries metrics (under PER_TIMESERIES_METRICS) and aggregated metrics
            (under TIMESERIES_AGGREGATED_METRICS); the aggregated metrics dict is mutated.
        custom_metrics (list[dict] or None): Custom metric definitions with "name" and "greaterIsBetter".
    """
    per_timeseries_metrics = score[PER_TIMESERIES_METRICS]
    aggregated_metrics = score[TIMESERIES_AGGREGATED_METRICS]

    custom_metric_signs = {}
    for custom_metric in custom_metrics or []:
        metric_sign = 1 if custom_metric["greaterIsBetter"] else -1
        custom_metric_signs[prefix_custom_metric_name(custom_metric['name'])] = metric_sign

    for timeseries_score in per_timeseries_metrics.values():
        for metric_key, metric_value in timeseries_score.items():
            # timeseries default metrics are only errors (i.e. lower is better)
            metric_sign = custom_metric_signs.get(metric_key, -1)

            # metric names will be like: worstMase, worstMape, worstSmape, worstMse, worstMsis
            # NOTE(review): str.capitalize() also lowercases the rest of the key (e.g. camelCase custom metric
            # names); kept as-is because these keys are persisted and read elsewhere.
            worst_metric_key = "{}{}".format(WORST_PREFIX, metric_key.capitalize())

            if worst_metric_key not in aggregated_metrics \
                    or _is_undefined_metric_value(aggregated_metrics[worst_metric_key]):
                # First value seen, or replacing an undefined placeholder (possibly by another undefined value)
                aggregated_metrics[worst_metric_key] = metric_value
            elif not _is_undefined_metric_value(metric_value) \
                    and metric_value * metric_sign < aggregated_metrics[worst_metric_key] * metric_sign:
                aggregated_metrics[worst_metric_key] = metric_value


def _build_per_timeseries_metrics_df(per_timeseries_metrics, output_metrics):
    """Build a metrics dataframe with one row per time series.

    Each row holds the requested output metrics available for that time series, followed by one column per time
    series identifier (decoded from the JSON-encoded identifier key), plus a "date" column with the timestamp of
    the evaluation session.
    """
    rows = []
    for encoded_identifier, metrics in per_timeseries_metrics.items():
        row = {name: metrics[name] for name in output_metrics if name in metrics}
        # the identifier key is a JSON object mapping identifier column -> value
        row.update(json.loads(encoded_identifier).items())
        rows.append(row)

    metrics_df = pd.DataFrame(rows)
    metrics_df["date"] = utils.get_datetime_now_utc()
    return metrics_df
