import pandas as pd
import logging

from dataiku.core.doctor_constants import END_OF_WEEK_DAY, UNIT_ALIGNMENT, MONTHLY_ALIGNMENT
from dataiku.core.doctor_constants import NUMBER_OF_TIME_UNITS
from dataiku.core.doctor_constants import TARGET_VARIABLE
from dataiku.core.doctor_constants import TIMESERIES_IDENTIFIER_COLUMNS
from dataiku.core.doctor_constants import TIME_STEP_PARAMS
from dataiku.core.doctor_constants import TIME_VARIABLE
from dataiku.core.doctor_constants import TIME_UNIT
from dataiku.doctor.diagnostics.timeseries import check_zero_target_ratio
from dataiku.doctor.utils import get_filtered_features
from dataiku.doctor.timeseries.preparation.resampling import Resampler
from dataiku.doctor.timeseries.preparation.resampling.utils import get_time_unit_end_of_week, get_monthly_day_alignment
from dataiku.doctor import step_constants
from dataiku.doctor.timeseries.utils import FULL_TIMESERIES_DF_IDENTIFIER
from dataiku.doctor.timeseries.utils import timeseries_iterator
from dataiku.doctor.timeseries.utils import add_timeseries_identifiers_columns
from dataiku.doctor.preprocessing_collector import PredictionPreprocessingDataCollector
from dataiku.doctor.preprocessing_handler import PredictionPreprocessingHandler
from dataiku.doctor.preprocessing_handler import read_resource
from dataiku.doctor.preprocessing_handler import write_resource
from dataiku.doctor.preprocessing.dataframe_preprocessing import OutputRawColumns
from dataiku.doctor.preprocessing.dataframe_preprocessing import RemapValueToOutput


logger = logging.getLogger(__name__)


def get_external_features(preprocessing_params):
    """Return the features of ``preprocessing_params`` that have the "INPUT" role.

    In this module the result is only used for its truthiness, to decide whether there
    are any external features to preprocess at all.
    """
    external_feature_roles = ["INPUT"]
    return get_filtered_features(preprocessing_params, include_roles=external_feature_roles)


def resample_timeseries(
    df, schema, resampling_params, core_params, numerical_columns, categorical_columns,
    compute_zero_target_ratio_diagnostic=False
):
    """Resample ``df`` with a Resampler configured from ``resampling_params`` and ``core_params``.

    :param df: input dataframe
    :param schema: dataset schema; ``schema["columns"]`` is a list of dicts with "name" and "type" keys
    :param resampling_params: resampling settings (interpolation/extrapolation/imputation methods,
        start/end date modes, duplicate timestamps handling method)
    :param core_params: core params providing the time step params, the time variable,
        the timeseries identifier columns and the target variable
    :param numerical_columns: columns resampled as numerical values (may be None or empty)
    :param categorical_columns: columns resampled as categorical values (may be None or empty)
    :param compute_zero_target_ratio_diagnostic: if True, run check_zero_target_ratio on the result
    :return: the resampled dataframe, with integer-typed numerical columns rounded
    """
    time_step_params = core_params[TIME_STEP_PARAMS]
    number_of_time_units = time_step_params[NUMBER_OF_TIME_UNITS]
    time_unit = time_step_params[TIME_UNIT]
    time_unit_end_of_week = get_time_unit_end_of_week(time_step_params[END_OF_WEEK_DAY])
    unit_alignment = time_step_params.get(UNIT_ALIGNMENT)
    monthly_alignment = get_monthly_day_alignment(core_params)
    timeseries_identifier_columns = core_params[TIMESERIES_IDENTIFIER_COLUMNS]
    resampler = Resampler(
        interpolation_method=resampling_params["numericalInterpolateMethod"],
        extrapolation_method=resampling_params["numericalExtrapolateMethod"],
        interpolation_constant_value=resampling_params.get("numericalInterpolateConstantValue"),
        extrapolation_constant_value=resampling_params.get("numericalExtrapolateConstantValue"),
        category_imputation_method=resampling_params["categoricalImputeMethod"],
        category_constant_value=resampling_params.get("categoricalConstantValue"),
        start_date_mode=resampling_params.get("startDateMode"),
        custom_start_date=resampling_params.get("customStartDate"),
        end_date_mode=resampling_params.get("endDateMode"),
        custom_end_date=resampling_params.get("customEndDate"),
        time_step=number_of_time_units,
        time_unit=time_unit,
        time_unit_end_of_week=time_unit_end_of_week,
        duplicate_timestamps_handling_method=resampling_params["duplicateTimestampsHandlingMethod"],
        unit_alignment=unit_alignment,
        monthly_alignment=monthly_alignment,
    )

    # Log column counts rather than column lists to keep the log line short
    log_resampling_params = resampling_params.copy()
    log_resampling_params["numericalColumns"] = len(numerical_columns) if numerical_columns else 0
    log_resampling_params["categoricalColumns"] = len(categorical_columns) if categorical_columns else 0
    log_resampling_params["identifiersColumns"] = len(timeseries_identifier_columns) if timeseries_identifier_columns else 0
    log_resampling_params["numberOfTimeunits"] = number_of_time_units
    log_resampling_params["timeUnit"] = time_unit
    log_resampling_params["endOfWeekDay"] = time_unit_end_of_week
    log_resampling_params["unitAlignment"] = unit_alignment
    log_resampling_params["monthlyAlignment"] = monthly_alignment
    logger.info("Resampling with params: %s", log_resampling_params)

    resampled_df = resampler.transform(
        df,
        datetime_column=core_params[TIME_VARIABLE],
        timeseries_identifier_columns=timeseries_identifier_columns,
        numerical_columns=numerical_columns,
        categorical_columns=categorical_columns,
    )

    # Int columns must be resampled into int values. They can still contain NaN values,
    # so round() is used instead of an integer cast.
    # numerical_columns may be None (see the logging above), so normalize it before the membership test.
    integer_types = {"tinyint", "smallint", "int", "bigint"}
    numerical_column_names = set(numerical_columns or [])
    columns_to_round = [
        column["name"]
        for column in schema["columns"]
        if column["type"] in integer_types and column["name"] in numerical_column_names
    ]
    if columns_to_round:
        resampled_df[columns_to_round] = resampled_df[columns_to_round].round()

    if compute_zero_target_ratio_diagnostic:
        check_zero_target_ratio(resampled_df, timeseries_identifier_columns, core_params[TARGET_VARIABLE])

    return resampled_df


class TimeseriesPreprocessing:
    """Fit and apply preprocessing pipelines on the external (INPUT-role) features of timeseries.

    Depending on ``on_full_df``, either a single pipeline is built for the whole dataframe
    (keyed by FULL_TIMESERIES_DF_IDENTIFIER) or one pipeline per timeseries identifier.
    When no external feature is configured, the public methods return the dataframe unchanged.
    """

    def __init__(self, data_folder_context, core_params, preprocessing_params, listener):
        self.data_folder_context = data_folder_context
        self.core_params = core_params
        self.preprocessing_params = preprocessing_params
        self.listener = listener

        # The dicts below are keyed by timeseries identifier (or FULL_TIMESERIES_DF_IDENTIFIER)
        self.preproc_handler_by_timeseries = {}
        self.pipeline_by_timeseries = {}
        # resource name -> (resource by timeseries identifier, resource type), filled by load_resources
        self.resources = {}
        self.collector_data = {}
        # timeseries identifier -> columns of the "TRAIN" multiframe output by its pipeline
        self.external_features = {}

    def create_timeseries_preprocessing_handlers(self, df, on_full_df, use_saved_resources=False):
        """Build a preprocessing handler and pipeline for the full df or for each timeseries of ``df``.

        :param use_saved_resources: if True, reuse the collector data and resources previously
            loaded by load_resources instead of collecting them from ``df``
        """
        if not get_external_features(self.preprocessing_params):
            return

        with self.listener.push_step(step_constants.ProcessingStep.STEP_COLLECTING):
            if on_full_df:
                self._create_single_timeseries_preprocessing_handler(
                    FULL_TIMESERIES_DF_IDENTIFIER,
                    df,
                    timeseries_identifiers_to_output=True,
                    use_saved_resources=use_saved_resources,
                )
            else:
                for timeseries_identifier, df_of_timeseries_identifier in timeseries_iterator(
                    df, self.core_params[TIMESERIES_IDENTIFIER_COLUMNS]
                ):
                    self._create_single_timeseries_preprocessing_handler(
                        timeseries_identifier,
                        df_of_timeseries_identifier,
                        timeseries_identifiers_to_output=False,
                        use_saved_resources=use_saved_resources,
                    )

    def fit_and_process(self, df, step_name, on_full_df, save_data=False):
        """Create fresh handlers, then fit them and transform ``df``.

        :param save_data: if True, write collector data, resources and the preprocessing report
        :return: the transformed dataframe, or ``df`` unchanged when there is no external feature
        """
        if not get_external_features(self.preprocessing_params):
            if save_data:
                self._save_data()
                self._report()
            return df

        self.create_timeseries_preprocessing_handlers(df, on_full_df, use_saved_resources=False)

        with self.listener.push_step(step_name):
            transformed_df = self._apply_per_timeseries(df, on_full_df, self._fit_and_process_single_timeseries)

        if save_data:
            self._save_data()
            self._report()

        self.listener.save_status()
        return transformed_df

    def process(self, df, step_name, on_full_df):
        """Transform ``df`` with the already created (and fitted or loaded) pipelines.

        :return: the transformed dataframe, or ``df`` unchanged when there is no external feature
        """
        if not get_external_features(self.preprocessing_params):
            return df

        with self.listener.push_step(step_name):
            transformed_df = self._apply_per_timeseries(df, on_full_df, self._process_single_timeseries)

        self.listener.save_status()
        return transformed_df

    def load_resources(self):
        """Load resource files listed in resource_types.json
        and store them and their file type by name in self.resources.
        Load collector data in self.collector_data.
        """
        resource_types = read_resource(self.data_folder_context, "resource_types", "json")
        for resource_name, resource_type in resource_types.items():
            resource = read_resource(self.data_folder_context, resource_name, resource_type)
            # Falsy (e.g. empty) resources are not registered
            if resource:
                self.resources[resource_name] = (resource, resource_type)

        self.collector_data = read_resource(self.data_folder_context, "collector_data", "json")

    def _apply_per_timeseries(self, df, on_full_df, process_single_timeseries):
        """Apply ``process_single_timeseries(identifier, df)`` on the full df, or on each timeseries
        of ``df`` and concatenate the results (empty DataFrame when there is no timeseries)."""
        if on_full_df:
            return process_single_timeseries(FULL_TIMESERIES_DF_IDENTIFIER, df)

        # Collect the frames and concatenate once: concatenating inside the loop re-copies
        # the accumulated frame at every iteration (quadratic in the number of timeseries)
        transformed_dfs = [
            process_single_timeseries(timeseries_identifier, df_of_timeseries_identifier)
            for timeseries_identifier, df_of_timeseries_identifier in timeseries_iterator(
                df, self.core_params[TIMESERIES_IDENTIFIER_COLUMNS]
            )
        ]
        if not transformed_dfs:
            return pd.DataFrame()
        return pd.concat(transformed_dfs, ignore_index=True)

    def _create_single_timeseries_preprocessing_handler(
        self,
        timeseries_identifier,
        df_of_timeseries_identifier,
        timeseries_identifiers_to_output=False,
        use_saved_resources=False,
    ):
        if use_saved_resources:
            collector_data = self.collector_data[timeseries_identifier]
        else:
            collector_data = PredictionPreprocessingDataCollector(
                df_of_timeseries_identifier, self.preprocessing_params
            ).build()

        # Single-timeseries handlers do not write anything themselves: their collector data and
        # resources are gathered and written for all timeseries at once by _save_data
        preproc_handler = SingleTimeseriesPreprocessingHandler(
            self.core_params, self.preprocessing_params, self.data_folder_context, collector_data, timeseries_identifier
        )

        if use_saved_resources:
            preproc_handler.set_resources(self.resources, timeseries_identifier)

        pipeline = preproc_handler.build_preprocessing_pipeline(
            with_timeseries_identifiers=timeseries_identifiers_to_output
        )

        self.preproc_handler_by_timeseries[timeseries_identifier] = preproc_handler
        self.pipeline_by_timeseries[timeseries_identifier] = pipeline

    def _fit_and_process_single_timeseries(self, timeseries_identifier, df_of_timeseries_identifier):
        transformed_multiframe = self.pipeline_by_timeseries[timeseries_identifier].fit_and_process(
            df_of_timeseries_identifier
        )

        # "TRAIN" because of the last preprocessing step in PredictionPreprocessingHandler: EmitCurrentMFAsResult("TRAIN")
        self.external_features[timeseries_identifier] = transformed_multiframe["TRAIN"].columns()

        return self._multiframe_to_df(transformed_multiframe, timeseries_identifier)

    def _process_single_timeseries(self, timeseries_identifier, df_of_timeseries_identifier):
        transformed_multiframe = self.pipeline_by_timeseries[timeseries_identifier].process(
            df_of_timeseries_identifier
        )

        # "TRAIN" because of the last preprocessing step in PredictionPreprocessingHandler: EmitCurrentMFAsResult("TRAIN")
        # The output columns must match those recorded at fit time (or at the first process call)
        if timeseries_identifier in self.external_features:
            assert (
                self.external_features[timeseries_identifier] == transformed_multiframe["TRAIN"].columns()
            ), "External features columns mismatch"
        else:
            self.external_features[timeseries_identifier] = transformed_multiframe["TRAIN"].columns()

        return self._multiframe_to_df(transformed_multiframe, timeseries_identifier)

    def _multiframe_to_df(self, multiframe, timeseries_identifier):
        """Rebuild a dataframe from the pipeline output: processed features plus the target,
        time and timeseries identifier columns that the pipeline moved away."""
        processed_df = multiframe["TRAIN"].as_dataframe()

        # TODO @timeseries make "target" and "time" constants (needs to edit multiple steps and preprocessing handler)
        # Indexes are reset so that assignment aligns positionally with processed_df
        # (assumes processed_df has a default RangeIndex — TODO confirm)
        processed_df[self.core_params[TARGET_VARIABLE]] = multiframe["target"].reset_index(drop=True)
        processed_df[self.core_params[TIME_VARIABLE]] = multiframe["time"].reset_index(drop=True)

        if timeseries_identifier == FULL_TIMESERIES_DF_IDENTIFIER:
            if self.core_params[TIMESERIES_IDENTIFIER_COLUMNS]:
                # In this case we have used OutputRawColumns to move time series identifiers away from the input dataframe, directly to the output
                processed_df[self.core_params[TIMESERIES_IDENTIFIER_COLUMNS]] = multiframe[
                    TIMESERIES_IDENTIFIER_COLUMNS
                ].reset_index(drop=True)
        else:
            add_timeseries_identifiers_columns(processed_df, timeseries_identifier)

        return processed_df

    def _report(self):
        """Write the fit report of every pipeline, keyed by timeseries identifier."""
        report_by_timeseries = {}
        for timeseries_identifier, pipeline in self.pipeline_by_timeseries.items():
            report = {}
            # core_params is always set by __init__, so no fallback is needed
            pipeline.report_fit(report, self.core_params)
            report_by_timeseries[timeseries_identifier] = report
        write_resource(self.data_folder_context, "preprocessing_report", "json", report_by_timeseries)

    def _save_data(self):
        """Write collector data, resources (one file per resource name, keyed by timeseries identifier)
        and the resource name -> file type mapping read back by load_resources."""
        collector_data_by_timeseries = {}
        resources = {}
        resource_types = {}
        for timeseries_identifier, preproc_handler in self.preproc_handler_by_timeseries.items():
            collector_data_by_timeseries[timeseries_identifier] = preproc_handler.get_collector_data()
            for resource_name, resource, resource_type in preproc_handler.list_resources():
                # The first type seen for a resource name wins
                resource_types.setdefault(resource_name, resource_type)
                resources.setdefault(resource_name, {})[timeseries_identifier] = resource

        write_resource(self.data_folder_context, "collector_data", "json", collector_data_by_timeseries)

        for resource_name, resource_by_timeseries in resources.items():
            write_resource(self.data_folder_context, resource_name, resource_types[resource_name], resource_by_timeseries)

        write_resource(self.data_folder_context, "resource_types", "json", resource_types)


class SingleTimeseriesPreprocessingHandler(PredictionPreprocessingHandler):
    """Prediction preprocessing handler for one timeseries (or for the full timeseries dataframe).

    It prepends steps that move the time column, and optionally the timeseries identifier columns,
    directly to the output. It then reuses the parent handler's preprocessing steps with the target.
    """

    def __init__(self, core_params, preprocessing_params, data_folder_context, collector_data=None, timeseries_identifier=None):
        super(SingleTimeseriesPreprocessingHandler, self).__init__(core_params, preprocessing_params, data_folder_context)
        self.collector_data = collector_data
        self.timeseries_identifier = timeseries_identifier

    @property
    def time_variable(self):
        """Name of the time column, or None if absent from core_params."""
        return self.core_params.get(TIME_VARIABLE)

    @property
    def target_map(self):
        """Always None: no target value mapping is used for timeseries preprocessing.

        (The previous ``with_target`` parameter was removed: a property getter can never receive it.)
        """
        return None

    def get_collector_data(self):
        return self.collector_data

    def preprocessing_steps(self, with_timeseries_identifiers=False):
        """Yield the preprocessing steps.

        :param with_timeseries_identifiers: if True (and identifier columns are configured),
            move the timeseries identifier columns to the output as well
        """
        # Move time column away
        yield RemapValueToOutput(self.time_variable, "time", None)

        # Move time series identifiers away
        if with_timeseries_identifiers and self.core_params[TIMESERIES_IDENTIFIER_COLUMNS]:
            yield OutputRawColumns(self.core_params[TIMESERIES_IDENTIFIER_COLUMNS], TIMESERIES_IDENTIFIER_COLUMNS)

        for step in super(SingleTimeseriesPreprocessingHandler, self).preprocessing_steps(with_target=True):
            yield step
        # TODO @timeseries check if RealignTarget can be used ? if it's the case, needs to add a RealignTime step and RealignTimeseriesIdentifiers step

    def set_resources(self, resources, timeseries_identifier):
        """Register, via set_resource, this timeseries' slice of every previously loaded resource.

        :param resources: resource name -> (resource keyed by timeseries identifier, resource type)
        :param timeseries_identifier: identifier of the timeseries handled by this handler
        Resources without an entry for this timeseries are registered as an empty dict.
        """
        for resource_name, (resource, resource_type) in resources.items():
            self.set_resource(resource_name, resource.get(timeseries_identifier, {}) if timeseries_identifier in resource else {}, resource_type)
