/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction.algorithms.python;

import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.CategoricalHyperparameterDimension;
import com.dataiku.dip.analysis.model.prediction.PostTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionParameterChecks;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.PyMemoryAlgorithmMeta;
import java.util.Collections;
import java.util.List;

public class ETSMeta
extends PyMemoryAlgorithmMeta {
    @Override
    public String generateName(PreTrainPredictionModelingParams rpmp) {
        return "ETS";
    }

    @Override
    public ModelTrainInfo.PreSearchDescription generatePreTrainDescription(PreTrainPredictionModelingParams rpmp) {
        PredictionModelingParams.ETSSpace space = rpmp.ets_timeseries_grid;
        ModelTrainInfo.PreSearchDescription description = new ModelTrainInfo.PreSearchDescription(rpmp).withGridLength(this.getSearchSize(rpmp.grid_search_params, space)).withMVParam("Trend", space.trend).withMVParam("Damped trend", space.damped_trend).withMVParam("Seasonal", space.seasonal).withMVParam("Error", space.error).withSVParam("Season length", space.seasonal_periods).withSVParam("Seed", space.seed).withSVParam("Include unstable models", space.include_unstable);
        return description;
    }

    @Override
    public ModelTrainInfo.PostSearchDescription generatePostTrainDescription(ModelTrainInfo.PreSearchDescription descBefore, PreTrainPredictionModelingParams before, PostTrainPredictionModelingParams after) {
        PostTrainPredictionModelingParams.ETSParams params = after.ets_params;
        return new ModelTrainInfo.PostSearchDescription().withSVParam("Trend", params.trend).withSVParam("Damped trend", params.damped_trend).withSVParam("Seasonal", params.seasonal).withSVParam("Error", params.error).withSVParam("Seed", params.seed).withSVParam("Season length", params.seasonal_periods).withSVParam("Include unstable models", params.include_unstable);
    }

    @Override
    public void validateParameters(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, PredictionParameterChecks checks) {
        PredictionModelingParams.ETSSpace space = pmp.ets_timeseries;
        if (space == null || !space.enabled) {
            return;
        }
        checks.checkPositive(space.trend.getLength(), "At least one type of trend must be selected (ETS)");
        checks.checkPositive(space.damped_trend.getLength(), "At least one type of damped trend must be selected (ETS)");
        checks.checkPositive(space.seasonal.getLength(), "At least one type of seasonality must be selected (ETS)");
        checks.checkPositive(space.error.getLength(), "At least one type of error must be selected (ETS)");
        checks.checkPositive(space.seasonal_periods - 1, "Season length parameter should be greater than 2 (ETS)");
    }

    @Override
    public List<WorkSet.ModelingSet> expandModeling(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, int gsFolds) {
        PredictionModelingParams.ETSSpace space = pmp.ets_timeseries;
        if (space == null || !space.enabled) {
            return Collections.emptyList();
        }
        PreTrainPredictionModelingParams preTrainParams = new PreTrainPredictionModelingParams(PreTrainPredictionModelingParams.Algorithm.ETS, pmp);
        preTrainParams.ets_timeseries_grid = space;
        this.checkAndUpdateSearchStrategy(pmp, preTrainParams);
        preTrainParams.gridLength = this.getSearchSize(preTrainParams.grid_search_params, space);
        WorkSet.ModelingSet modelingSet = new WorkSet.ModelingSet(preTrainParams);
        if (preTrainParams.gridLength > 1) {
            modelingSet.estimatedTrains = preTrainParams.gridLength * gsFolds + 1;
        }
        return Collections.singletonList(modelingSet);
    }

    @Override
    public PreTrainPredictionModelingParams regridifyToPreTrain(PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PostTrainPredictionModelingParams.ETSParams optimizedParams = optimized.ets_params;
        PreTrainPredictionModelingParams preTrainParams = this.getCopyWithGridStrategy(usedToTrain);
        PredictionModelingParams.ETSSpace space = preTrainParams.ets_timeseries_grid;
        space.trend = CategoricalHyperparameterDimension.create(optimizedParams.trend, "none", "add", "mul");
        space.damped_trend = CategoricalHyperparameterDimension.create(optimizedParams.damped_trend, "true", "false");
        space.seasonal = CategoricalHyperparameterDimension.create(optimizedParams.seasonal, "none", "add", "mul");
        space.error = CategoricalHyperparameterDimension.create(optimizedParams.error, "add", "mul");
        space.seasonal_periods = preTrainParams.ets_timeseries_grid.seasonal_periods;
        space.seed = preTrainParams.ets_timeseries_grid.seed;
        space.include_unstable = preTrainParams.ets_timeseries_grid.include_unstable;
        return preTrainParams;
    }

    @Override
    public void regridifyToMLTask(PredictionModelingParams target, PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams preTrainParams = this.regridifyToPreTrain(optimized, usedToTrain);
        target.ets_timeseries = preTrainParams.ets_timeseries_grid;
        target.ets_timeseries.enabled = true;
    }

    @Override
    public void refreshMLTask(PredictionModelingParams target, PreTrainPredictionModelingParams usedToTrain) {
        target.ets_timeseries = usedToTrain.ets_timeseries_grid;
        target.ets_timeseries.enabled = true;
    }

    @Override
    public boolean hasProbabilities(PreTrainPredictionModelingParams rpmp) {
        return true;
    }

    @Override
    protected int getGridLength(PredictionModelingParams.HyperparametersSpace space) {
        PredictionModelingParams.ETSSpace ETSSpace2 = (PredictionModelingParams.ETSSpace)space;
        return ETSSpace2.trend.getLength() * ETSSpace2.seasonal.getLength() * ETSSpace2.damped_trend.getLength() * ETSSpace2.error.getLength();
    }
}

