import logging
import pandas as pd
from pandas.api.types import is_numeric_dtype

from dataiku.core.doctor_constants import PREDICTION_LENGTH
from dataiku.core.doctor_constants import TARGET_VARIABLE
from dataiku.core.doctor_constants import TIMESERIES_IDENTIFIER_COLUMNS
from dataiku.core.doctor_constants import TIMESERIES_SAMPLING
from dataiku.core.doctor_constants import TIME_VARIABLE
from dataiku.doctor import step_constants
from dataiku.doctor.diagnostics import diagnostics
from dataiku.doctor.diagnostics.diagnostics import DiagnosticType
from dataiku.doctor.exception import TimeseriesResamplingException
from dataiku.doctor.timeseries.preparation.preprocessing import TimeseriesPreprocessing
from dataiku.doctor.timeseries.models import TimeseriesForecastingAlgorithm
from dataiku.doctor.timeseries.preparation.resampling.utils import get_frequency
from dataiku.doctor.timeseries.preparation.preprocessing import get_external_features
from dataiku.doctor.timeseries.preparation.preprocessing import get_filtered_features
from dataiku.doctor.timeseries.preparation.preprocessing import resample_timeseries
from dataiku.doctor.timeseries.utils import add_timeseries_identifiers_columns
from dataiku.doctor.timeseries.utils import build_quantile_column_name
from dataiku.doctor.timeseries.utils import encode_timeseries_identifier
from dataiku.doctor.timeseries.utils import get_dataframe_of_timeseries_identifier
from dataiku.doctor.timeseries.utils import add_ignored_timeseries_diagnostics_and_logs
from dataiku.doctor.timeseries.utils import ignored_timeseries_warning_message
from dataiku.doctor.timeseries.utils import ignored_timeseries_diagnostic_message
from dataiku.doctor.timeseries.utils import timeseries_iterator
from dataiku.doctor.timeseries.utils import FORECAST_COLUMN
from dataiku.doctor.timeseries.utils import ModelForecast
from dataiku.doctor.utils.listener import ProgressListener
from dataiku.doctor.utils.listener import DiagOnlyContext


logger = logging.getLogger(__name__)


class TimeseriesScoringHandler(object):
    """
    Score a trained time series forecasting model on an input dataframe.

    The handler resamples the input, splits every time series into its past rows (up to the last non-null target)
    and its future rows, preprocesses both when the algorithm uses external features, calls ``clf.predict`` and
    assembles the per-series forecasts into a single dataframe.
    """

    def __init__(self, core_params, preprocessing_params, modeling_params, clf, preprocessing_folder_context,
                 diagnostics_folder_context=None, scoring_prediction_length=None):
        """
        :param core_params: provides the time variable, target, identifier columns, prediction length and frequency
        :param preprocessing_params: preprocessing settings, including the resampling settings
        :param modeling_params: modeling settings; ``modeling_params["algorithm"]`` selects the algorithm
        :param clf: trained model; its ``predict`` method produces the forecasts
        :param preprocessing_folder_context: passed to TimeseriesPreprocessing; saved resources are loaded from it
            only when the algorithm uses external features
        :param diagnostics_folder_context: optional folder context; when given, progress is reported through a
            diagnostics-only listener context
        :param scoring_prediction_length: optional forecast horizon override; defaults to the training prediction length
        """
        self.core_params = core_params
        self.preprocessing_params = preprocessing_params
        self.modeling_params = modeling_params
        self.clf = clf

        self.resampling_params = self.preprocessing_params[TIMESERIES_SAMPLING]
        self.training_prediction_length = core_params[PREDICTION_LENGTH]

        if scoring_prediction_length is not None:
            self.scoring_prediction_length = scoring_prediction_length
        else:
            self.scoring_prediction_length = self.training_prediction_length
        self.time_variable = core_params[TIME_VARIABLE]
        self.target_variable = core_params[TARGET_VARIABLE]
        self.timeseries_identifier_columns = core_params[TIMESERIES_IDENTIFIER_COLUMNS]
        self.frequency = get_frequency(core_params)

        self.algorithm = TimeseriesForecastingAlgorithm.build(self.modeling_params["algorithm"])
        # the minimum size is derived from the *training* prediction length, not the scoring override
        self.min_timeseries_size = self.algorithm.get_min_size_for_scoring(self.modeling_params, self.training_prediction_length)

        self.external_features = (
            get_external_features(self.preprocessing_params) if self.algorithm.SUPPORTS_EXTERNAL_FEATURES else None
        )

        context = DiagOnlyContext(diagnostics_folder_context) if diagnostics_folder_context else None
        self.listener = ProgressListener(context=context)

        self.timeseries_preprocessing = TimeseriesPreprocessing(preprocessing_folder_context, core_params, preprocessing_params, self.listener)
        if self.external_features:
            # preprocessing resources are only needed when external features have to be preprocessed
            self.timeseries_preprocessing.load_resources()

    def score(self, df, schema, quantiles, past_time_steps_to_include=0, partition_columns=None, refit=False):
        """
        Forecast every time series of ``df`` and return all forecasts as a single dataframe.

        For each forecasted time series, the output contains one column per quantile, the forecast column, the time
        column, the identifier columns, the other input columns taken from the resampled future rows, and optionally
        the last past time steps of the resampled data.

        :param df: input dataframe; rows after the last non-null target of each time series are treated as future rows
        :param schema: input dataset schema, forwarded to the resampling
        :param quantiles: quantiles to forecast; one column per quantile is added
        :param past_time_steps_to_include: number of past resampled time steps prepended to each forecast;
            0 for none, -1 for all of them
        :param partition_columns: if set, these columns are filled with their value from the first row of ``df``
            (partition dispatch mode)
        :param refit: forwarded as ``force`` to ``algorithm.should_fit_before_predict``
        :return: forecasts dataframe with a fresh 0-based index
        """
        logger.info("Running scoring recipe with params: {}".format(self.modeling_params))

        with self.listener.push_step(step_constants.ProcessingStep.STEP_SCORING):
            past_preprocessed_df, future_preprocessed_df, past_resampled_df, future_resampled_df = self._prepare_data_for_scoring(df, schema)

            fit_before_predict = self.algorithm.should_fit_before_predict(force=refit)
            forecasts_by_timeseries = self.clf.predict(
                past_preprocessed_df, future_preprocessed_df, quantiles, fit_before_predict=fit_before_predict, prediction_length_override=self.scoring_prediction_length
            )

        # Columns already filled from the forecast itself; computed once instead of for every column of every series
        non_copied_columns = (self.timeseries_identifier_columns or []) + [self.time_variable, self.target_variable]

        forecasts_dfs = []
        # we output the resampled data and not the preprocessed data
        for timeseries_identifier, timeseries_forecast in forecasts_by_timeseries.items():
            # Fill up forecast dataframe for timeseries identifier
            # 1. Forecast quantiles
            forecast_df_of_timeseries_identifier = pd.DataFrame(
                timeseries_forecast[ModelForecast.QUANTILES_FORECASTS].T,
                columns=[build_quantile_column_name(q) for q in quantiles],
            )
            # 2. Forecast values
            forecast_df_of_timeseries_identifier[FORECAST_COLUMN] = timeseries_forecast[ModelForecast.FORECAST_VALUES]

            # 3. Time variable
            forecast_df_of_timeseries_identifier[self.time_variable] = pd.to_datetime(timeseries_forecast[ModelForecast.TIMESTAMPS])

            # 4. Time series identifiers if any
            add_timeseries_identifiers_columns(forecast_df_of_timeseries_identifier, timeseries_identifier)

            # 5. Retrieve other features from the resampled data and not the preprocessed data
            if future_resampled_df is not None:
                future_resampled_df_of_timeseries_identifier = get_dataframe_of_timeseries_identifier(
                    future_resampled_df, timeseries_identifier
                )
                # NOTE(review): this assignment aligns on the index, so it relies on
                # get_dataframe_of_timeseries_identifier returning a 0-based index — confirm
                for column in df.columns:
                    if column not in non_copied_columns:
                        forecast_df_of_timeseries_identifier[column] = future_resampled_df_of_timeseries_identifier[column]

            # 6. Append past time steps of current time series
            past_resampled_df_of_timeseries_identifier = get_dataframe_of_timeseries_identifier(
                past_resampled_df, timeseries_identifier
            )
            forecast_df_of_timeseries_identifier = self._append_past_time_steps_to_forecast(
                past_resampled_df_of_timeseries_identifier,
                forecast_df_of_timeseries_identifier,
                past_time_steps_to_include,
            )

            # Append to the global forecasts Dataframe
            forecasts_dfs.append(forecast_df_of_timeseries_identifier)

        forecasts_df = pd.concat(forecasts_dfs, ignore_index=True)

        if partition_columns:
            # If we are in partition dispatch mode, retrieve partition column values from input dataframe
            # (presumably constant across df in that mode, hence the first row is used)
            forecasts_df = self._add_partition_columns(forecasts_df, df, partition_columns)

        return forecasts_df

    @staticmethod
    def _add_partition_columns(forecasts_df, df, partition_columns):
        """
        Fill each partition column of ``forecasts_df`` (in place) with its value in the first row of ``df``.

        The value is read column-first (``df[col].iloc[0]``) rather than row-first (``df.iloc[0][col]``): selecting
        a whole row coerces every column to a common dtype, which would turn an int partition value into a float
        when the other columns are floats.

        :return: ``forecasts_df``
        """
        for partition_column in partition_columns or []:
            forecasts_df[partition_column] = df[partition_column].iloc[0]
        return forecasts_df

    @staticmethod
    def _append_past_time_steps_to_forecast(past_time_steps_df, forecast_df, past_time_steps_to_include):
        """
        Prepend past time steps to a forecast dataframe.

        :param past_time_steps_df: past resampled rows of one time series
        :param forecast_df: forecast rows of the same time series
        :param past_time_steps_to_include: 0 for none, -1 for all, otherwise the number of last past rows to keep
        :return: concatenated dataframe with a fresh 0-based index
        """
        if past_time_steps_to_include == 0:
            # Make sure forecast_df contains past_time_steps_df's columns (sc-91842)
            return pd.concat([past_time_steps_df.iloc[:0], forecast_df], ignore_index=True)
        elif past_time_steps_to_include == -1:
            return pd.concat([past_time_steps_df, forecast_df], ignore_index=True)
        else:
            return pd.concat([past_time_steps_df.iloc[-past_time_steps_to_include:], forecast_df], ignore_index=True)

    def _prepare_data_for_scoring(self, df, schema):
        """
        Prepare past and future data for scoring: resampling and preprocessing.

        With external features, time series lacking ``scoring_prediction_length`` future values are ignored, and both
        the past and the future parts are preprocessed using the saved preprocessing resources. Without external
        features, the past resampled data is used as is and no future preprocessed data is produced.

        Time series unknown to the trained models (non-GluonTS algorithms only) or shorter than
        ``min_timeseries_size`` are then dropped from the past preprocessed data and reported.

        :return: past_preprocessed_df, future_preprocessed_df, past_resampled_df, future_resampled_df
        """
        with self.listener.push_step(step_constants.ProcessingStep.STEP_TIMESERIES_RESAMPLING):
            resampled_df = resample_for_scoring(
                df,
                schema,
                self.resampling_params,
                self.core_params,
                self.preprocessing_params,
                self.external_features,
                separate_external_features=True,
            )

        if self.external_features:
            # split past and future df and check that there are enough future external features values
            past_resampled_df, future_resampled_df = self._split_past_and_future_df(resampled_df, check_enough_future_values=True)

            preprocess_on_full_df = self.algorithm.USE_GLUON_TS

            # we could either pass past_resampled_df or future_resampled_df because it will only look at which timeseries identifiers are in the df
            # to retrieve the right resources
            self.timeseries_preprocessing.create_timeseries_preprocessing_handlers(
                past_resampled_df, preprocess_on_full_df, use_saved_resources=True
            )

            past_preprocessed_df = self.timeseries_preprocessing.process(
                past_resampled_df,
                step_constants.ProcessingStep.STEP_PREPROCESS_TEST,
                preprocess_on_full_df,
            )

            future_preprocessed_df = self.timeseries_preprocessing.process(
                future_resampled_df,
                step_constants.ProcessingStep.STEP_PREPROCESS_TEST,
                preprocess_on_full_df,
            )

        else:
            # even without external features, there can be future values of extra columns that need be added to the output dataset
            past_resampled_df, future_resampled_df = self._split_past_and_future_df(resampled_df, check_enough_future_values=False)

            # when there are no external features, there are no preprocessing, only past resampling is used for prediction
            past_preprocessed_df = past_resampled_df
            future_preprocessed_df = None

        unseen_timeseries_identifiers = []
        too_short_timeseries_identifiers = []
        past_preprocessed_dfs = []
        for timeseries_identifier, past_preprocessed_df_of_timeseries_identifier in timeseries_iterator(
            past_preprocessed_df, self.timeseries_identifier_columns
        ):
            if not self.algorithm.USE_GLUON_TS and timeseries_identifier not in self.clf.trained_models:
                unseen_timeseries_identifiers.append(timeseries_identifier)
            elif len(past_preprocessed_df_of_timeseries_identifier.index) < self.min_timeseries_size:
                too_short_timeseries_identifiers.append(timeseries_identifier)
            else:
                past_preprocessed_dfs.append(past_preprocessed_df_of_timeseries_identifier)

        add_ignored_timeseries_diagnostics_and_logs(
            self.timeseries_identifier_columns,
            unseen_timeseries_identifiers,
            too_short_timeseries_identifiers,
            all_timeseries_ignored=len(past_preprocessed_dfs)==0,
            min_required_length=self.min_timeseries_size,
            recipe_type="scoring",
            diagnostic_type=DiagnosticType.ML_DIAGNOSTICS_SCORING_DATASET_SANITY_CHECKS,
        )

        # only past_preprocessed_df is filtered because we iterate through all time series of past_preprocessed_df when scoring
        # no need to filter future_preprocessed_df, past_resampled_df or future_resampled_df because we don't iterate over them,
        # instead we only call get_dataframe_of_timeseries_identifier on them to retrieve the needed time series
        past_preprocessed_df = pd.concat(past_preprocessed_dfs, ignore_index=True)

        return past_preprocessed_df, future_preprocessed_df, past_resampled_df, future_resampled_df

    def _split_past_and_future_df(self, df, check_enough_future_values=False):
        """
        Split each time series of ``df`` into past rows (up to and including its last non-null target) and future
        rows (at most ``scoring_prediction_length`` rows after that).

        :param df: resampled dataframe
        :param check_enough_future_values: if True, time series with fewer than ``scoring_prediction_length`` rows
            after their last non-null target are ignored (logged and reported as a diagnostic when there are
            identifier columns)
        :return: (past_df, future_df), both with a fresh 0-based index; the target column of future_df is set to 0
        :raises ValueError: if every time series was ignored
        """
        past_dfs = []
        future_dfs = []
        ignored_timeseries_identifiers = []
        for timeseries_identifier, df_of_timeseries_identifier in timeseries_iterator(
            df, self.timeseries_identifier_columns
        ):
            df_of_timeseries_identifier.reset_index(drop=True, inplace=True)
            last_valid_target_index = df_of_timeseries_identifier[self.target_variable].last_valid_index()

            # this check is only performed if future values of external features are required
            n_actual_values = len(df_of_timeseries_identifier.index) - last_valid_target_index - 1
            if check_enough_future_values and self.scoring_prediction_length > n_actual_values:
                ignored_timeseries_identifiers.append(timeseries_identifier)
            else:
                past_dfs.append(df_of_timeseries_identifier.loc[:last_valid_target_index])
                # .loc slicing is inclusive on both ends: this keeps at most scoring_prediction_length rows
                future_dfs.append(
                    df_of_timeseries_identifier.loc[last_valid_target_index + 1 : last_valid_target_index + self.scoring_prediction_length]
                )

        if self.timeseries_identifier_columns and ignored_timeseries_identifiers:
            explanation_message = "because {} less than {} future values of external features".format(
                "they have" if len(ignored_timeseries_identifiers) > 1 else "it has",
                self.scoring_prediction_length
            )
            logger.warning(ignored_timeseries_warning_message(ignored_timeseries_identifiers, explanation_message))
            diagnostics.add_or_update(
                DiagnosticType.ML_DIAGNOSTICS_SCORING_DATASET_SANITY_CHECKS,
                ignored_timeseries_diagnostic_message(ignored_timeseries_identifiers, explanation_message)
            )

        if len(past_dfs) == 0:
            if self.timeseries_identifier_columns:
                error_message = "No timeseries has enough future values of external features for prediction"
            else:
                error_message = "Not enough future values of external features were provided for prediction"
            error_message += " ({n_required_values} future value{plural} {verb} required)".format(
                n_required_values=self.scoring_prediction_length,
                plural='' if self.scoring_prediction_length == 1 else 's',
                verb='is' if self.scoring_prediction_length == 1 else 'are'
            )
            raise ValueError(error_message)

        past_df = pd.concat(past_dfs, ignore_index=True)
        future_df = pd.concat(future_dfs, ignore_index=True)

        # preprocessing fail with only NaN in target column (future target values won't be used anyway)
        future_df[self.target_variable] = 0

        return past_df, future_df

    def _get_long_enough_timeseries_identifiers(self, timeseries_sizes):
        """
        Return time series identifiers for time series whose size is bigger than the min required size for scoring,
        and log time series identifiers of time series that are smaller and hence ignored.

        :param timeseries_sizes: Series of time series sizes, presumably indexed by identifier values — confirm
        :return: array of the identifier values whose size is at least ``min_timeseries_size``
        """
        timeseries_sizes_mask = timeseries_sizes >= self.min_timeseries_size

        kept_timeseries_identifiers_values = timeseries_sizes_mask[timeseries_sizes_mask].index.values
        ignored_timeseries_identifiers_values = timeseries_sizes_mask[~timeseries_sizes_mask].index.values
        if len(ignored_timeseries_identifiers_values) > 0:
            ignored_timeseries_identifiers = [
                encode_timeseries_identifier(ignored_timeseries_identifier, self.timeseries_identifier_columns)
                for ignored_timeseries_identifier in ignored_timeseries_identifiers_values
            ]
            explanation_message = "because {} less than {} time steps".format(
                "they have" if len(ignored_timeseries_identifiers) > 1 else "it has",
                self.min_timeseries_size
            )
            logger.warning(ignored_timeseries_warning_message(ignored_timeseries_identifiers, explanation_message))
            diagnostics.add_or_update(
                DiagnosticType.ML_DIAGNOSTICS_SCORING_DATASET_SANITY_CHECKS,
                ignored_timeseries_diagnostic_message(ignored_timeseries_identifiers, explanation_message)
            )

        return kept_timeseries_identifiers_values


def _remove_end_missing_target(df, time_variable, target_variable, timeseries_identifier_columns):
    """
    Build the target dataframe to score by dropping, in each time series, the trailing rows whose target is missing.

    Only the identifier columns, the time column and the target column are kept. Each time series is sorted by
    time first, since the scoring input is not guaranteed to be sorted.

    :return: concatenation of the truncated time series, with a fresh 0-based index
    """
    kept_columns = (timeseries_identifier_columns or []) + [time_variable] + [target_variable]
    truncated_dfs = []
    for _, series_df in timeseries_iterator(df, timeseries_identifier_columns):
        ordered_df = series_df.sort_values(by=[time_variable]).reset_index(drop=True)

        # rows after the last non-null target are the time steps to forecast: cut them off
        last_known_index = ordered_df[target_variable].last_valid_index()
        truncated_dfs.append(ordered_df[kept_columns].loc[:last_known_index])

    return pd.concat(truncated_dfs, ignore_index=True)


def resample_for_scoring(df, schema, resampling_params, core_params, preprocessing_params, external_features, separate_external_features=False):
    """Resampling for the scoring/evaluation recipes.

    1. Time series whose target has fewer than 2 valid (non-NaN) values cannot be resampled: they are dropped,
       logged and reported as a resampling diagnostic.
    2. The target and the external features are resampled either together (evaluation) or separately (scoring),
       depending on ``separate_external_features``.
    3. The extra columns of the input dataset (neither identifiers, time, target nor external features) are resampled.
    4. The resampled dataframes are merged on the identifier and time columns.

    :param df: input dataframe
    :param schema: input dataset schema, forwarded to ``resample_timeseries``
    :param resampling_params: resampling settings, forwarded to ``resample_timeseries``
    :param core_params: provides the time variable, target and time series identifier columns
    :param preprocessing_params: used to classify external features and extra columns as numerical or categorical
    :param external_features: external feature columns used by the model, or a falsy value when there are none
    :param separate_external_features: if True (scoring), resample the target and the external features separately
        and keep every time step of the extra columns (outer merge); if False (evaluation), resample them together
        and keep only the time steps of the target/external features (left merge)
    :return: resampled dataframe
    :raises TimeseriesResamplingException: if no time series has at least 2 valid target values
    """
    time_variable = core_params[TIME_VARIABLE]
    target_variable = core_params[TARGET_VARIABLE]
    timeseries_identifier_columns = core_params[TIMESERIES_IDENTIFIER_COLUMNS]

    ignored_timeseries_identifiers = []
    dfs = []
    # Ignore time series where the target column contains less than 2 valid (non-NaN) values because they cannot be resampled
    for timeseries_identifier, df_of_timeseries_identifier in timeseries_iterator(df, timeseries_identifier_columns):
        if df_of_timeseries_identifier[target_variable].count() < 2:
            ignored_timeseries_identifiers.append(timeseries_identifier)
        else:
            dfs.append(df_of_timeseries_identifier)

    if timeseries_identifier_columns and ignored_timeseries_identifiers:
        several_ignored = len(ignored_timeseries_identifiers) > 1
        # subject and verb must agree in number ("their ... columns contain" / "its ... column contains")
        explanation_message = "because {} less than 2 valid values, {} cannot be resampled.".format(
            "their target columns contain" if several_ignored else "its target column contains",
            "they" if several_ignored else "it",
        )
        logger.warning(ignored_timeseries_warning_message(ignored_timeseries_identifiers, explanation_message))
        diagnostics.add_or_update(
            DiagnosticType.ML_DIAGNOSTICS_TIMESERIES_RESAMPLING_CHECKS,
            ignored_timeseries_diagnostic_message(ignored_timeseries_identifiers, explanation_message)
        )

    if len(dfs) == 0:
        if timeseries_identifier_columns:
            error_message = "No time series can be resampled because their target columns contain less than 2 valid values."
        else:
            error_message = "Input time series cannot be resampled because its target column contains less than 2 valid values."
        raise TimeseriesResamplingException(error_message)

    # now df only contains time series with enough target values to be resampled
    df = pd.concat(dfs, ignore_index=True)

    base_columns = (timeseries_identifier_columns or []) + [time_variable]
    external_features_columns = (external_features or [])
    extra_columns = [col for col in df.columns if col not in base_columns + [target_variable] + external_features_columns]

    numerical_external_features = get_filtered_features(preprocessing_params, include_types=["NUMERIC"], include_roles=["INPUT"])
    categorical_external_features = get_filtered_features(preprocessing_params, exclude_types=["NUMERIC"], include_roles=["INPUT"])

    if separate_external_features:
        # first we resample the target column only
        if external_features:  # when we have external features, we need to remove every ending rows with missing values
            target_df = _remove_end_missing_target(df, time_variable, target_variable, timeseries_identifier_columns)
        else:
            target_df = df[base_columns + [target_variable]]

        logger.info("Resampling target column {}".format(target_variable))
        resampled_df = resample_timeseries(target_df, schema, resampling_params, core_params, numerical_columns=[target_variable], categorical_columns=[])

        # then we resample external features separately
        if external_features:
            external_features_df = df[base_columns + external_features_columns]
            logger.info("Resampling external features {}".format(external_features_columns))
            resampled_external_features_df = resample_timeseries(external_features_df, schema, resampling_params, core_params, numerical_external_features, categorical_external_features)
            # right merge: keep the future time steps, which only exist on the external features side
            resampled_df = pd.merge(resampled_df, resampled_external_features_df, on=base_columns, how="right")
    else:
        # we resample the target and the external features together
        # we only add the external features columns if the model does use external features (otherwise they are considered as extra columns)
        numerical_columns = [target_variable] + numerical_external_features if external_features else [target_variable]
        categorical_columns = categorical_external_features if external_features else []

        target_external_features_df = df[base_columns + [target_variable] + external_features_columns]
        message = "Resampling target column {}".format(target_variable)
        if external_features:
            message += " and external features {}".format(external_features_columns)
        logger.info(message)
        resampled_df = resample_timeseries(target_external_features_df, schema, resampling_params, core_params, numerical_columns, categorical_columns)

    # finally we resample the other extra columns
    if extra_columns:
        extra_df = df[base_columns + extra_columns]

        # extra columns can be REJECT or INPUT in preprocessing_params or not in preprocessing_params (in this case, we use is_numeric_dtype to get the type)
        numerical_preprocessing_columns = get_filtered_features(preprocessing_params, include_types=["NUMERIC"], include_roles=["INPUT", "REJECT"])
        categorical_preprocessing_columns = get_filtered_features(preprocessing_params, exclude_types=["NUMERIC"], include_roles=["INPUT", "REJECT"])

        numerical_columns = []
        categorical_columns = []
        for col in extra_columns:
            if col in numerical_preprocessing_columns:
                numerical_columns.append(col)
            elif col in categorical_preprocessing_columns:
                categorical_columns.append(col)
            elif is_numeric_dtype(extra_df[col]):
                numerical_columns.append(col)
            else:
                categorical_columns.append(col)

        logger.info("Resampling additional columns {}".format(extra_columns))
        resampled_extra_df = resample_timeseries(extra_df, schema, resampling_params, core_params, numerical_columns, categorical_columns)

        # in evaluation we merge left because we don't care about time steps of extra columns after the target
        how = "outer" if separate_external_features else "left"
        resampled_df = pd.merge(resampled_df, resampled_extra_df, on=base_columns, how=how)

    return resampled_df


def _filter_timeseries_dataframe(df, timeseries_identifier_columns, timeseries_identifiers_values):
    """
    Filter input dataframe, keeping only rows corresponding to timeseries whose identifiers are in
    timeseries_identifiers_values
    """
    if len(timeseries_identifier_columns) > 1:
        row_indices = df[timeseries_identifier_columns].apply(tuple, 1).isin(timeseries_identifiers_values)
    else:
        row_indices = df[timeseries_identifier_columns[0]].isin(timeseries_identifiers_values)
    return df.loc[row_indices].reset_index(drop=True)
