/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction.algorithms.python;

import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.CategoricalHyperparameterDimension;
import com.dataiku.dip.analysis.model.prediction.PostTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionParameterChecks;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.PyMemoryAlgorithmMeta;
import com.dataiku.dip.analysis.model.preprocessing.FeaturePreprocessingParams;
import java.util.Collections;
import java.util.List;

public class ProphetMeta
extends PyMemoryAlgorithmMeta {
    @Override
    public String generateName(PreTrainPredictionModelingParams rpmp) {
        return "Prophet";
    }

    @Override
    public ModelTrainInfo.PreSearchDescription generatePreTrainDescription(PreTrainPredictionModelingParams rpmp) {
        PredictionModelingParams.ProphetSpace space = rpmp.prophet_timeseries_grid;
        ModelTrainInfo.PreSearchDescription description = new ModelTrainInfo.PreSearchDescription(rpmp).withGridLength(this.getSearchSize(rpmp.grid_search_params, space)).withMVParam("Seasonality mode", space.seasonality_mode).withMVParam("Seasonality prior scale", space.seasonality_prior_scale).withMVParam("Changepoint prior scale", space.changepoint_prior_scale);
        if (space._use_external_features) {
            description = description.withMVParam("External features prior scale", space.holidays_prior_scale);
        }
        return description;
    }

    @Override
    public ModelTrainInfo.PostSearchDescription generatePostTrainDescription(ModelTrainInfo.PreSearchDescription descBefore, PreTrainPredictionModelingParams before, PostTrainPredictionModelingParams after) {
        PredictionModelingParams.ProphetSpace space = before.prophet_timeseries_grid;
        PostTrainPredictionModelingParams.ProphetParams params = after.prophet_timeseries_params;
        ModelTrainInfo.PostSearchDescription description = new ModelTrainInfo.PostSearchDescription().withSVParam("Seasonality mode", params.seasonality_mode).withSVParam("Seasonality prior scale", params.seasonality_prior_scale).withSVParam("Changepoint prior scale", params.changepoint_prior_scale).withSVParam("Growth", params.growth).withSVParam("Number of changepoints", params.n_changepoints).withSVParam("Changepoint range", params.changepoint_range).withSVParam("Yearly seasonality", params.yearly_seasonality).withSVParam("Weekly seasonality", params.weekly_seasonality).withSVParam("Daily seasonality", params.daily_seasonality).withSVParam("Seed", params.seed);
        if (space._use_external_features) {
            description = description.withSVParam("External features prior scale", params.holidays_prior_scale);
        }
        if (space.growth == PredictionModelingParams.ProphetGrowth.LOGISTIC) {
            description = description.withSVParam("Floor", params.floor).withSVParam("Capacity", params.cap);
        }
        return description;
    }

    @Override
    public void validateParameters(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, PredictionParameterChecks checks) {
        PredictionModelingParams.ProphetSpace space = pmp.prophet_timeseries;
        if (space == null || !space.enabled) {
            return;
        }
        checks.checkNumericalDimension(space.seasonality_prior_scale, "Seasonality prior scale (Prophet)");
        checks.checkNumericalDimension(space.changepoint_prior_scale, "Changepoint prior scale (Prophet)");
        if (space._use_external_features) {
            checks.checkNumericalDimension(space.holidays_prior_scale, "External features prior scale (Prophet)");
        }
        checks.checkPositive(space.seasonality_mode.getLength(), "At least one seasonality mode must be selected (Prophet)");
        checks.checkBetween(space.changepoint_range, 0.0, 1.0, "Changepoint range must be between 0 and 1 (Prophet)");
        checks.checkNonNegative(space.n_changepoints, "Number of changepoints must be >= 0 (Prophet)");
        if (space.growth == PredictionModelingParams.ProphetGrowth.LOGISTIC) {
            checks.check(space.cap != null, "Capacity must be defined for logistic growth (Prophet)");
            checks.checkPositive(space.cap - space.floor, "Capacity must be strictly greater than Floor for logistic growth (Prophet)");
        }
    }

    @Override
    public List<WorkSet.ModelingSet> expandModeling(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, int gsFolds) {
        PredictionModelingParams.ProphetSpace space = pmp.prophet_timeseries;
        if (space == null || !space.enabled) {
            return Collections.emptyList();
        }
        boolean useExternalFeatures = false;
        for (FeaturePreprocessingParams params : task.getPreprocessingParams().per_feature.values()) {
            if (params.role != FeaturePreprocessingParams.Role.INPUT) continue;
            useExternalFeatures = true;
            break;
        }
        space._use_external_features = useExternalFeatures;
        PreTrainPredictionModelingParams preTrainParams = new PreTrainPredictionModelingParams(PreTrainPredictionModelingParams.Algorithm.PROPHET, pmp);
        preTrainParams.prophet_timeseries_grid = space;
        this.checkAndUpdateSearchStrategy(pmp, preTrainParams);
        preTrainParams.gridLength = this.getSearchSize(preTrainParams.grid_search_params, space);
        WorkSet.ModelingSet modelingSet = new WorkSet.ModelingSet(preTrainParams);
        if (preTrainParams.gridLength > 1) {
            modelingSet.estimatedTrains = preTrainParams.gridLength * gsFolds + 1;
        }
        return Collections.singletonList(modelingSet);
    }

    @Override
    public PreTrainPredictionModelingParams regridifyToPreTrain(PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PostTrainPredictionModelingParams.ProphetParams optimizedParams = optimized.prophet_timeseries_params;
        PreTrainPredictionModelingParams preTrainParams = this.getCopyWithGridStrategy(usedToTrain);
        PredictionModelingParams.ProphetSpace space = preTrainParams.prophet_timeseries_grid;
        space.seasonality_mode = CategoricalHyperparameterDimension.create(optimizedParams.seasonality_mode, "additive", "multiplicative");
        space.seasonality_prior_scale.setToSingleValueGrid(optimizedParams.seasonality_prior_scale);
        space.changepoint_prior_scale.setToSingleValueGrid(optimizedParams.changepoint_prior_scale);
        if (space._use_external_features) {
            space.holidays_prior_scale.setToSingleValueGrid(optimizedParams.holidays_prior_scale);
        }
        return preTrainParams;
    }

    @Override
    public void regridifyToMLTask(PredictionModelingParams target, PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams preTrainParams = this.regridifyToPreTrain(optimized, usedToTrain);
        target.prophet_timeseries = preTrainParams.prophet_timeseries_grid;
        target.prophet_timeseries.enabled = true;
    }

    @Override
    public void refreshMLTask(PredictionModelingParams target, PreTrainPredictionModelingParams usedToTrain) {
        target.prophet_timeseries = usedToTrain.prophet_timeseries_grid;
        target.prophet_timeseries.enabled = true;
    }

    @Override
    public boolean hasProbabilities(PreTrainPredictionModelingParams rpmp) {
        return true;
    }

    @Override
    protected int getGridLength(PredictionModelingParams.HyperparametersSpace space) {
        PredictionModelingParams.ProphetSpace prophetSpace = (PredictionModelingParams.ProphetSpace)space;
        int gridLength = prophetSpace.seasonality_mode.getLength() * prophetSpace.seasonality_prior_scale.getLength() * prophetSpace.changepoint_prior_scale.getLength();
        if (prophetSpace._use_external_features) {
            gridLength *= prophetSpace.holidays_prior_scale.getLength();
        }
        return gridLength;
    }
}

