import logging

from dataiku.doctor.timeseries.models.statistical.base_statsforecast_estimator import DkuStatsforecastBaseEstimator

logger = logging.getLogger(__name__)

class DkuCrostonEstimator(DkuStatsforecastBaseEstimator):
    """Croston-family estimator backed by statsforecast models.

    ``variant`` selects the statsforecast model that is built:
    ``"CLASSIC"`` -> ``CrostonClassic``, ``"TSB"`` -> ``TSB`` (the only
    variant using ``alpha_d`` / ``alpha_p``), and any other value ->
    ``CrostonSBA``.
    """

    # Smoothing parameters used by the TSB variant when none are configured.
    # They are also the constructor defaults, so both paths agree.
    _DEFAULT_ALPHA_D = 0.1
    _DEFAULT_ALPHA_P = 0.1

    def __init__(
            self,
            frequency,
            time_variable,
            prediction_length,
            target_variable,
            timeseries_identifier_columns,
            monthly_day_alignment=None,
            variant='SBA',  # CLASSIC, SBA, TSB
            alpha_d=_DEFAULT_ALPHA_D,
            alpha_p=_DEFAULT_ALPHA_P,
    ):
        """Store the common time-series settings and the Croston-specific ones.

        :param variant: one of ``"CLASSIC"``, ``"SBA"`` or ``"TSB"``; other
            values are treated as ``"SBA"`` when the model is built.
        :param alpha_d: TSB smoothing parameter (ignored by other variants).
        :param alpha_p: TSB smoothing parameter (ignored by other variants).
        """
        super(DkuCrostonEstimator, self).__init__(
            frequency=frequency,
            time_variable=time_variable,
            prediction_length=prediction_length,
            target_variable=target_variable,
            timeseries_identifier_columns=timeseries_identifier_columns,
            monthly_day_alignment=monthly_day_alignment,
        )

        self.variant = variant
        # TSB specific parameters
        self.alpha_d = alpha_d
        self.alpha_p = alpha_p

    def initialize(self, core_params, modeling_params):
        """Load the Croston settings from ``modeling_params``.

        Reads ``modeling_params["croston_timeseries_params"]``. ``"variant"``
        is required. ``"alpha_d"`` / ``"alpha_p"`` are optional and stored
        as ``None`` when absent.

        :raises KeyError: if ``"croston_timeseries_params"`` or its
            ``"variant"`` entry is missing.
        """
        super(DkuCrostonEstimator, self).initialize(core_params, modeling_params)
        algo_params = modeling_params["croston_timeseries_params"]

        self.variant = algo_params["variant"]
        self.alpha_d = algo_params.get("alpha_d")
        self.alpha_p = algo_params.get("alpha_p")

    def _build_statsforecast_model(self):
        """Instantiate the statsforecast model matching ``self.variant``.

        statsforecast is imported lazily inside each branch, so only the
        selected model class is imported.
        """
        if self.variant == "CLASSIC":
            from statsforecast.models import CrostonClassic
            return CrostonClassic()
        elif self.variant == "TSB":
            from statsforecast.models import TSB
            # initialize() stores None for smoothing parameters missing from the
            # modeling params; TSB requires numeric values, so fall back to the
            # defaults instead of passing None through (which fails at fit time).
            alpha_d = self._DEFAULT_ALPHA_D if self.alpha_d is None else self.alpha_d
            alpha_p = self._DEFAULT_ALPHA_P if self.alpha_p is None else self.alpha_p
            return TSB(alpha_d=alpha_d, alpha_p=alpha_p)
        else:
            if self.variant != "SBA":
                # Keep the historical SBA fallback, but make it visible.
                logger.warning("Unknown Croston variant %r, falling back to SBA", self.variant)
            from statsforecast.models import CrostonSBA
            return CrostonSBA()

    def _get_statsforecast_model_name(self):
        """Return the statsforecast model name for ``self.variant``.

        Uses the same dispatch as ``_build_statsforecast_model``: unknown
        variants map to ``"CrostonSBA"``.
        """
        if self.variant == "CLASSIC":
            return "CrostonClassic"
        elif self.variant == "TSB":
            return "TSB"
        else:
            return "CrostonSBA"
