import logging

import numpy as np
import os
import pandas as pd

from dataiku.core import dkujson
from dataiku.doctor.timeseries.models.base_estimator import BaseTimeseriesEstimator
from dataiku.doctor.timeseries.utils import ModelForecast, timeseries_iterator, log_df, SINGLE_TIMESERIES_IDENTIFIER, \
    future_date_range, DUMMY_IDENTIFIER_DELIMITER, NEURALFORECAST_N_WINDOWS
from dataiku.doctor.utils import get_platform_info
from dataiku.doctor.utils.gpu_execution import NeuralforecastGPUCapability, get_default_gpu_config

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# NOTE: this check runs at import time, as a module-level side effect.
if get_platform_info()['machine'] == 'arm64':
    # Sets torch env variables to be compatible with mps devices (arm).
    # neuralforecast is only imported lazily inside `fit`, i.e. after this variable has been set.
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = '1'

class DKUNeuralforecastEstimator(BaseTimeseriesEstimator):

    def __init__(
            self,
            frequency,
            prediction_length,
            time_variable,
            target_variable,
            timeseries_identifiers,
            monthly_day_alignment=None,
            quantiles=None,
            random_state=1337,
            gpu_config=get_default_gpu_config()
    ):
        """
        Store the forecasting settings and resolve the device used by the trainer.

        Args:
            frequency: forwarded to the base estimator; `fit` also passes `self.frequency` as the neuralforecast `freq`.
            prediction_length: forwarded to the base estimator; number of values forecasted per time series.
            time_variable (str): name of the time column (renamed to 'ds' by `prepare_df`).
            target_variable (str): name of the target column (renamed to 'y' by `prepare_df`).
            timeseries_identifiers: names of the identifier columns, combined into 'unique_id' by `prepare_df`.
            monthly_day_alignment: forwarded to the base estimator.
            quantiles: quantiles computed by `compute_quantiles`.
            random_state: passed as `random_seed` when predicting.
            gpu_config: configuration handed to `NeuralforecastGPUCapability` to choose between GPU and CPU.
                NOTE: the default value is evaluated once, when the class body is executed.
        """
        super(DKUNeuralforecastEstimator, self).__init__(
            frequency,
            prediction_length,
            time_variable,
            target_variable,
            timeseries_identifiers,
            monthly_day_alignment,
        )
        # Settings kept on the instance.
        self.time_variable = time_variable
        self.target_variable = target_variable
        self.timeseries_identifiers = timeseries_identifiers
        self.quantiles = quantiles
        self.random_state = random_state
        self.gpu_config = gpu_config

        # Fitted state, filled in by `fit`.
        self.predictor = None
        self.nf = None
        self.fitted_values = None

        # Device selection, exposed through `get_trainer_kwargs`.
        use_gpu = NeuralforecastGPUCapability.should_use_gpu(gpu_config)
        self.accelerator = "gpu" if use_gpu else "cpu"
        self.devices = NeuralforecastGPUCapability.get_lightning_devices(gpu_config) if use_gpu else "auto"

    @staticmethod
    def get_name():
        """
        Return the name of the underlying model. This base implementation always raises.

        Within this class, the returned value is used as the name of the column holding the model output
        in the forecast, in-sample and conformal score dataframes.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    def get_model(self):
        """
        Build the model to train. This base implementation always raises.

        `fit` stores the returned object in `self.predictor` and passes it to `NeuralForecast(models=[...])`.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError

    def get_trainer_kwargs(self):
        """
        Enable/disable GPU depending on the gpu_config.
        """
        return {
            "accelerator": self.accelerator,
            "devices": self.devices,
        }

    def initialize(self, core_params, modeling_params):
        """
        No-op in this class: both arguments are ignored.
        """
        pass

    def _build_single_forecasts_dict(self, prediction, prediction_length_override, last_past_date):
        """
        Assemble the forecast dictionary of a single time series.

        Args:
            prediction (pd.DataFrame): forecast rows of one time series, holding the model column
                and one column per quantile.
            prediction_length_override (int): when truthy, number of forecasted timestamps to generate; it is
                also used as-is (including None) as the upper bound when slicing the forecast and quantile values.
            last_past_date: last known date of the series, from which the future timestamps are generated.
        Returns:
            dict: timestamps, mean forecast values and quantile forecasts, keyed by the ModelForecast constants.
        """
        horizon = self.prediction_length
        if prediction_length_override:
            horizon = prediction_length_override
        timestamps = future_date_range(last_past_date, horizon, self.frequency, self.monthly_day_alignment)

        model_column = self.get_name()
        forecast_values = prediction[model_column].to_numpy()[:prediction_length_override]
        quantile_forecasts = DKUNeuralforecastEstimator.extract_quantile_forecasts(
            prediction, self.quantiles, model_column, prediction_length_override
        )
        return {
            ModelForecast.TIMESTAMPS: timestamps,
            ModelForecast.FORECAST_VALUES: forecast_values,
            ModelForecast.QUANTILES_FORECASTS: quantile_forecasts,
        }

    def fit(self, train_df, external_features=None, shift_map=None):
        """
        Fit one model for all timeseries.

        Args:
            train_df (pd.DataFrame): training data, converted with `prepare_df` before fitting.
            external_features: when not None, replaces `self.external_features` before the data is prepared.
            shift_map: not read by this implementation.
        Returns:
            The estimator itself, with `predictor`, `nf` and `fitted_values` populated.
        """
        # must be imported after the torch env variables are set
        from neuralforecast.utils import PredictionIntervals
        from neuralforecast import NeuralForecast

        if external_features is not None:
            self.external_features = external_features

        prepared_train_df = self.prepare_df(train_df)

        model = self.get_model()
        self.predictor = model

        forecaster = NeuralForecast(models=[model], freq=self.frequency)
        self.nf = forecaster
        # The prediction intervals requested here produce the conformal scores read by `compute_quantiles`.
        conformal_intervals = PredictionIntervals(n_windows=NEURALFORECAST_N_WINDOWS)
        forecaster.fit(prepared_train_df, prediction_intervals=conformal_intervals)
        self.fitted_values = forecaster.predict_insample(step_size=self.prediction_length)
        return self

    def prepare_df(self, df):
        if df is None: # future_df can be None
            return None
        prepared_df = df.copy()
        if self.timeseries_identifiers:
            prepared_df['unique_id'] = self.build_unique_id_column(prepared_df, self.timeseries_identifiers)
        else:
            prepared_df['unique_id'] = SINGLE_TIMESERIES_IDENTIFIER
        unique_external_features = DKUNeuralforecastEstimator.get_unique_external_features(self.external_features) if self.external_features else {}
        known_in_advance_features = unique_external_features.get("INPUT", [])
        past_only_features = unique_external_features.get("INPUT_PAST_ONLY", [])
        columns_to_keep = known_in_advance_features + past_only_features + [self.time_variable, self.target_variable, 'unique_id']
        prepared_df = prepared_df.drop(columns=[column for column in prepared_df.columns if column not in columns_to_keep])
        return prepared_df.rename(columns={self.time_variable: 'ds', self.target_variable: 'y'})

    def prepare_future_df(self, future_df):
        """
        Extends if needed the external features to reach a length of self.prediction_length.
        Returns a df compatible with neuralforecasts models.
        """
        if future_df is None:
            return None

        if not self.timeseries_identifiers:
            return self.prepare_df(self._extend_series(future_df))

        extended_dfs = []
        for _, group_df in future_df.groupby(self.timeseries_identifiers):
            extended_group = self._extend_series(group_df)
            extended_dfs.append(extended_group)

        future_df_extended = pd.concat(extended_dfs, ignore_index=True)
        return self.prepare_df(future_df_extended)

    def _extend_series(self, df):
        """
        Extends (if needed) the single identifier input df to reach the minimum size of self.prediction_length.
        For the new rows added :
        - external features are filled with 0s
        - identifier columns are filled with the current identifier value
        - dates are extended
        """
        if len(df) >= self.prediction_length:
            return df

        last_date = df[self.time_variable].iloc[-1]
        new_dates = future_date_range(last_past_date=last_date, prediction_length=self.prediction_length - len(df), frequency=self.frequency, format_as_string=False)
        new_rows = pd.DataFrame({self.time_variable: new_dates})

        for col in self.timeseries_identifiers:
            new_rows[col] = df[col].iloc[0]

        for col in df.columns:
            if col not in [self.time_variable] + self.timeseries_identifiers:
                new_rows[col] = 0

        return pd.concat([df, new_rows], ignore_index=True)

    @staticmethod
    def get_unique_external_features(external_features_per_identifier):
        res = {}
        for feature_role_list in external_features_per_identifier.values():
            for feature_role in feature_role_list.keys():
                if feature_role not in res:
                    res[feature_role] = set()
                for feature in feature_role_list[feature_role]:
                    res[feature_role].add(feature)
        for feature_role in res:
            res[feature_role] = list(res[feature_role])
            # turning a set into a list doesn't provide a stable ordering
            # to enable the reproducibility of models, the list is sorted
            res[feature_role].sort()
        return res

    @staticmethod
    def to_neural_forecast_identifiers(identifiers, identifier_columns):
        """
        Generates a string from the identifier dict. Usage of dummy values in the generated identifiers to prevent collisions.

        Args:
            identifiers: serialized identifier, parsed with `dkujson.loads` into a mapping column -> value.
            identifier_columns (list): identifier column names, in the order their values must be joined.
        Returns:
            str: the values of `identifier_columns`, converted to strings and joined with DUMMY_IDENTIFIER_DELIMITER.
        """
        identifier_values = dkujson.loads(identifiers)
        ordered_values = [str(identifier_values[column]) for column in identifier_columns]
        return DUMMY_IDENTIFIER_DELIMITER.join(ordered_values)

    @staticmethod
    def build_unique_id_column(df, identifiers):
        """
        Aggregates the identifier columns. Usage of dummy values in the generated identifiers to prevent collisions.
        This method exists partly to be unit tested.

        Args:
            df (pd.DataFrame): dataframe holding the identifier columns.
            identifiers (list): names of the identifier columns to combine.
        Returns:
            The row-wise join of the identifier values (as strings) with DUMMY_IDENTIFIER_DELIMITER.
        """
        identifier_values_as_text = df[identifiers].astype(str)
        return identifier_values_as_text.agg(DUMMY_IDENTIFIER_DELIMITER.join, axis=1)

    def _prepare_predict(self):
        """
        Basic check before running the predict method.
        """
        if self.predictor is None:
            raise ValueError("Trying to predict an estimator that has not been trained")

    def _predict(self, past_df, future_df):
        """
        Forecast every time series of `past_df` and attach the quantile forecasts.

        Args:
            past_df (pd.DataFrame): past values of the time series.
            future_df (pd.DataFrame): future values of the external features; may be None.
        Returns:
            pd.DataFrame: the neuralforecast forecasts with one extra column per quantile, both parts having
            their index reset before being placed side by side.
        Raises:
            ValueError: if the estimator has not been trained.
        """
        self._prepare_predict()
        prepared_past = self.prepare_df(past_df)
        prepared_future = self.prepare_future_df(future_df)
        # neuralforecast models can only predict the same number of values than the prediction_length used at fit time
        # moreover, for algos supporting past only external features, recursive scoring is not possible
        # combining those 2 conditions, we always predict a length of prediction_length
        # if we want to score less than prediction_length values, the forecasts are trimmed in _build_single_forecasts_dict
        forecasts = self.nf.predict(prepared_past, futr_df=prepared_future, random_seed=self.random_state)

        quantiles_df = self.compute_quantiles(prepared_past, forecasts)

        # Both frames are re-indexed from 0 so that the column-wise concat aligns them row by row.
        aligned_frames = [frame.reset_index(drop=True) for frame in (forecasts, quantiles_df)]
        return pd.concat(aligned_frames, axis=1)


    def compute_quantiles(self, past_df_prepared, forecasts):
        """
        adapted from neuralforecast.utils.add_conformal_distribution_intervals method.
        Computes the quantiles by taking advantage of the conformal scoring computed by the neuralforecast model at fit time.
        Unfortunately, we cannot use the native quantiles computation since it doesn't support missing identifiers.

        Args:
            past_df_prepared (pd.DataFrame): prepared past data; only its 'unique_id' column is read.
            forecasts (pd.DataFrame): forecasts holding the model column.
                Assumes its rows follow the same series order as the conformal scores - TODO confirm.
        Returns:
            pd.DataFrame: one row per forecasted value and one column per quantile of `self.quantiles`,
            named after the string form of the quantile.
        """
        # conformal prediction score
        conformal_scores_df = self.nf._cs_df

        # filter out missing identifiers, the step that make neuralforecast native method break
        requested_ids = past_df_prepared["unique_id"]
        conformal_scores_df = conformal_scores_df[conformal_scores_df["unique_id"].isin(requested_ids)]

        n_series = requested_ids.nunique()
        model_column = self.get_name()

        # (n_series, n_windows, horizon) -> (n_windows, n_series, horizon)
        errors = conformal_scores_df[model_column].to_numpy()
        errors = errors.reshape(n_series, NEURALFORECAST_N_WINDOWS, self.prediction_length).transpose(1, 0, 2)
        errors = errors[:, :, :self.prediction_length]

        # (1, n_series, horizon): broadcast against every window of errors
        point_forecasts = forecasts[model_column].to_numpy().reshape(1, n_series, -1)
        # stack the lower and upper candidates along the first axis: (2 * n_windows, n_series, horizon)
        candidates = np.vstack([point_forecasts - errors, point_forecasts + errors])

        quantile_values = np.quantile(candidates, self.quantiles, axis=0)
        # (n_quantiles, n_series, horizon) -> (n_series * horizon, n_quantiles)
        quantile_values = quantile_values.reshape(len(self.quantiles), -1).T

        column_names = [str(quantile) for quantile in self.quantiles]
        return pd.DataFrame(quantile_values, columns=column_names)

    def predict_single(
            self,
            past_df_of_timeseries_identifier,
            future_df_of_timeseries_identifier,
            quantiles,
            timeseries_identifier,
            fit_before_predict=False,
            prediction_length_override=None
    ):
        """
        Produces the forecast values for a single time series, with identifier timeseries_identifier

        Args:
            past_df_of_timeseries_identifier (pd.DataFrame): past values of the time series.
            future_df_of_timeseries_identifier (pd.DataFrame): future values of the external features.
            quantiles: not read by this implementation.
            timeseries_identifier: not read by this implementation.
            fit_before_predict: not read by this implementation.
            prediction_length_override (int): when set, the forecasts are trimmed to that many values.
        Returns:
            dict: timestamps, mean forecast values and quantile forecasts of the time series.
        """
        past_dates = past_df_of_timeseries_identifier[self.time_variable]
        last_past_date = past_dates.iloc[-1]
        forecast_df = self._predict(past_df_of_timeseries_identifier, future_df_of_timeseries_identifier)
        return self._build_single_forecasts_dict(forecast_df, prediction_length_override, last_past_date)

    def predict(self, past_df, future_df, quantiles, fit_before_predict=False, prediction_length_override=None):
        """
        Produces the forecast values for all time series

        Args:
            past_df (pd.DataFrame): past values of every time series to forecast.
            future_df (pd.DataFrame): future values of the external features; may be None.
            quantiles: not read by this implementation, `self.quantiles` is used instead.
            fit_before_predict: not read by this implementation.
            prediction_length_override (int): when set, the forecasts are trimmed to that many values.
        Return:
            Dictionary where keys are time series identifiers and values are the forecast values for the time series.
            Each forecast contains the time stamps, the mean forecast values, and the quantile forecasts (2D-array)
        Raises:
            ValueError: if the estimator has not been trained.
        """
        self._prepare_predict()

        log_df(logger, past_df, self.time_variable, None, "\t - Past")
        # `external_features` maps each identifier to its features by role (this is how `prepare_df` reads it),
        # so the roles are merged across identifiers before looking up the known-in-advance ("INPUT") features.
        # An empty or None `external_features` simply means that there is nothing to log.
        unique_external_features = (
            DKUNeuralforecastEstimator.get_unique_external_features(self.external_features)
            if self.external_features
            else {}
        )
        if unique_external_features.get("INPUT") and future_df is not None:
            log_df(logger, future_df, self.time_variable, None, "\t - External features future")

        forecast_df = self._predict(past_df, future_df)

        forecasts_by_timeseries = {}
        for identifier, past_df_of_timeseries_identifier in timeseries_iterator(
                past_df, self.timeseries_identifiers
        ):
            # Translate the identifier into the 'unique_id' value used in the forecast dataframe.
            if identifier == SINGLE_TIMESERIES_IDENTIFIER:
                neuralforecast_identifier = SINGLE_TIMESERIES_IDENTIFIER
            else:
                neuralforecast_identifier = self.to_neural_forecast_identifiers(identifier, self.timeseries_identifiers)
            last_past_date = past_df_of_timeseries_identifier[self.time_variable].iloc[-1]
            forecasts_by_timeseries[identifier] = self._build_single_forecasts_dict(
                forecast_df[forecast_df['unique_id'] == neuralforecast_identifier],
                prediction_length_override,
                last_past_date
            )

        return forecasts_by_timeseries

    def get_fitted_values_and_residuals(self, identifier, df_of_identifier, min_scoring_size):
        """
        Return the in-sample predictions of one time series and the matching residuals.

        Args:
            identifier: identifier of the time series; translated to its 'unique_id' value unless it is
                the single time series identifier.
            df_of_identifier (pd.DataFrame): data of the time series, holding the target column.
            min_scoring_size: not read by this implementation.
        Returns:
            tuple: (fitted_values, residuals). `fitted_values` is a list left-padded with NaN up to the
            length of the target column; `residuals` is the target column minus those fitted values.
        """
        if identifier != SINGLE_TIMESERIES_IDENTIFIER:
            identifier = self.to_neural_forecast_identifiers(identifier, self.timeseries_identifiers)

        series_rows = self.fitted_values[self.fitted_values['unique_id'] == identifier]
        series_rows = series_rows.drop(columns=['unique_id'])

        target = df_of_identifier[self.target_variable]
        in_sample_predictions = series_rows[self.get_name()]
        # The in-sample predictions may be shorter than the series: pad the beginning with NaN.
        padding = [np.nan] * (len(target) - len(in_sample_predictions))
        fitted_values = padding + in_sample_predictions.to_list()

        residuals = (target - fitted_values)
        return fitted_values, residuals


    @staticmethod
    def extract_quantile_forecasts(result_df, quantiles, model_name, prediction_length_override=None):
        """
        Extracts forecast values for a given list of quantiles from a
        statsforecast result DataFrame.
        Args:
            result_df (pd.DataFrame): The DataFrame from a statsforecast model.predict() call.
            quantiles (list[float]): A list of quantiles to extract (e.g., [0.1, 0.5, 0.9]).
            model_name (str): Model name from _get_statsforecast_model_name.
            prediction_length_override (int): Maximum number of values to return.
        Returns:
            list[list[float]]: A list of lists, where each inner list contains the
                               forecast values for a corresponding quantile.
        Raises:
            KeyError: If a column required for a quantile is not in the DataFrame.
        """
        forecasted_quantiles = []

        for q in sorted(quantiles):
            if str(q) not in result_df.columns:
                raise Exception("The prediction result from neuralforecast model is missing quantiles data.")
            forecasted_quantiles.append(result_df[str(q)][:prediction_length_override].values)
        return np.array(forecasted_quantiles)
