/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.model.prediction.algorithms.python;

import com.dataiku.dip.analysis.model.ModelTrainInfo;
import com.dataiku.dip.analysis.model.core.WorkSet;
import com.dataiku.dip.analysis.model.prediction.CategoricalHyperparameterDimension;
import com.dataiku.dip.analysis.model.prediction.IntegerHyperparameterDimension;
import com.dataiku.dip.analysis.model.prediction.NumericalHyperparameterDimension;
import com.dataiku.dip.analysis.model.prediction.PostTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionParameterChecks;
import com.dataiku.dip.analysis.model.prediction.algorithms.python.PyMemoryAlgorithmMeta;
import java.util.Collections;
import java.util.List;
import java.util.Map;

/**
 * Algorithm metadata for AutoARIMA, a Python-side algorithm ({@link PyMemoryAlgorithmMeta}) whose
 * search space is read from {@code autoarima_timeseries} on the ML task and from
 * {@code autoarima_timeseries_grid} on the pre-train parameters.
 *
 * <p>Only the season length {@code m}, the information criteria, the unit root tests, the seasonal
 * unit root tests and the solvers form the outer search grid (see {@link #getGridLength}). The
 * ARIMA orders (p, d, q) and seasonal orders (P, D, Q) are instead given as start/max bounds that
 * are validated and described here.
 */
public class AutoArimaMeta
extends PyMemoryAlgorithmMeta {
    @Override
    public String generateName(PreTrainPredictionModelingParams rpmp) {
        return "AutoARIMA";
    }

    /**
     * Tells whether the search may visit a season length {@code m} greater than 1, in which case
     * the seasonal orders (P, D, Q) are relevant.
     *
     * <p>Only the value mode of the dimension used by the active strategy is inspected:
     * {@code gridMode} for a GRID search and {@code randomMode} for any other strategy. Previously
     * {@code randomMode} was also consulted during a grid search because of {@code &&}/{@code ||}
     * precedence. When the active mode lists explicit values, their maximum decides. In every
     * other case (including an unknown maximum) the answer is conservatively {@code true}.
     *
     * @param gsParams search parameters; their {@code strategy} selects which mode is inspected
     * @param m the season-length dimension
     * @return {@code false} only when the explicit season lengths are known to all be {@code <= 1}
     */
    private boolean mightHaveSeasonalSearchPoints(PredictionModelingParams.GridSearchParams gsParams, IntegerHyperparameterDimension m) {
        boolean isGridSearch = gsParams.strategy == PredictionModelingParams.GridSearchParams.Strategy.GRID;
        boolean usesExplicitValues = isGridSearch
                ? m.gridMode == NumericalHyperparameterDimension.ValueMode.EXPLICIT
                : m.randomMode == NumericalHyperparameterDimension.ValueMode.EXPLICIT;
        if (!usesExplicitValues) {
            // Range-based values: we cannot cheaply rule out m > 1, so stay conservative.
            return true;
        }
        Long maxSeasonLength = (Long)m.getMaxValue();
        // Guard against an unknown maximum instead of failing on unboxing.
        return maxSeasonLength == null || maxSeasonLength > 1L;
    }

    /**
     * Builds the human-readable description of the search space before training.
     *
     * <p>Differencing parameters are listed only for non-stationary models. When a differencing
     * order is fixed, only its value is listed. Otherwise its maximum and the unit root test(s)
     * are listed. Seasonal parameters are listed only when a season length above 1 may be searched.
     */
    @Override
    public ModelTrainInfo.PreSearchDescription generatePreTrainDescription(PreTrainPredictionModelingParams rpmp) {
        PredictionModelingParams.AutoArimaSpace space = rpmp.autoarima_timeseries_grid;
        ModelTrainInfo.PreSearchDescription description = new ModelTrainInfo.PreSearchDescription(rpmp)
                .withGridLength(this.getSearchSize(rpmp.grid_search_params, space))
                .withMVParam("Season length", space.m)
                .withMVParam("Information criterion", space.information_criterion)
                .withSVParam("p (start)", space.start_p)
                .withSVParam("p (max)", space.max_p)
                .withSVParam("q (start)", space.start_q)
                .withSVParam("q (max)", space.max_q);
        if (!space.stationary) {
            if (space.d == null) {
                description = description.withSVParam("d (max)", space.max_d).withMVParam("Unit root test", space.test);
            } else {
                description = description.withSVParam("d", space.d);
            }
        }
        if (this.mightHaveSeasonalSearchPoints(rpmp.grid_search_params, space.m)) {
            description = description
                    .withSVParam("P (start)", space.start_P)
                    .withSVParam("P (max)", space.max_P)
                    .withSVParam("Q (start)", space.start_Q)
                    .withSVParam("Q (max)", space.max_Q);
            if (!space.stationary) {
                if (space.D == null) {
                    description = description.withSVParam("D (max)", space.max_D).withMVParam("Seasonal unit root test", space.seasonal_test);
                } else {
                    description = description.withSVParam("D", space.D);
                }
            }
        }
        return description
                .withSVParam("Maximum iterations", space.maxiter)
                .withSVParam("Max. sum of orders", space.max_order)
                .withMVParam("Solver", space.method);
    }

    /**
     * Builds the human-readable description of the parameters retained after training.
     *
     * <p>Each order (p, q, d, P, Q, D) is stored as a map. Every entry is reported with its key
     * passed as the extra argument of {@code withSVParam}. Presumably there is one entry per
     * forecasted series; this should be confirmed against the Python side. Differencing orders are
     * skipped for stationary models, and seasonal orders are skipped when {@code m <= 1}.
     */
    @Override
    public ModelTrainInfo.PostSearchDescription generatePostTrainDescription(ModelTrainInfo.PreSearchDescription descBefore, PreTrainPredictionModelingParams before, PostTrainPredictionModelingParams after) {
        PostTrainPredictionModelingParams.AutoArimaParams params = after.auto_arima_timeseries_params;
        ModelTrainInfo.PostSearchDescription description = new ModelTrainInfo.PostSearchDescription()
                .withSVParam("Season length", params.m);
        for (Map.Entry<String, Integer> order : params.p.entrySet()) {
            description = description.withSVParam("Auto-regressive model order (p)", order.getValue(), order.getKey());
        }
        for (Map.Entry<String, Integer> order : params.q.entrySet()) {
            description = description.withSVParam("Moving-average model order (q)", order.getValue(), order.getKey());
        }
        if (!params.stationary) {
            for (Map.Entry<String, Integer> order : params.d.entrySet()) {
                description = description.withSVParam("Differencing term (d)", order.getValue(), order.getKey());
            }
        }
        if (params.m > 1L) {
            for (Map.Entry<String, Integer> order : params.P.entrySet()) {
                description = description.withSVParam("Seasonal auto-regressive model order (P)", order.getValue(), order.getKey());
            }
            for (Map.Entry<String, Integer> order : params.Q.entrySet()) {
                description = description.withSVParam("Seasonal moving-average model order (Q)", order.getValue(), order.getKey());
            }
            if (!params.stationary) {
                for (Map.Entry<String, Integer> order : params.D.entrySet()) {
                    description = description.withSVParam("Seasonal differencing term (D)", order.getValue(), order.getKey());
                }
            }
        }
        return description
                .withSVParam("Information criterion", params.information_criterion)
                .withSVParam("Solver", params.method);
    }

    /**
     * Validates the AutoARIMA search space of an ML task. This is a no-op when the space is
     * missing or disabled.
     *
     * <p>Every categorical dimension must have at least one selected value, because each one is a
     * factor of the grid length. Starting orders must lie between 0 and their maximum. Seasonal
     * bounds are checked only when a season length above 1 may be searched. Finally, the maximum
     * sum of orders must strictly exceed the sum of the starting orders.
     */
    @Override
    public void validateParameters(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, PredictionParameterChecks checks) {
        PredictionModelingParams.AutoArimaSpace space = pmp.autoarima_timeseries;
        if (space == null || !space.enabled) {
            return;
        }
        checks.checkPositive(space.information_criterion.getLength(), "At least one information criterion must be selected (AutoARIMA)");
        checks.checkPositive(space.test.getLength(), "At least one statistical test must be selected (AutoARIMA)");
        checks.checkPositive(space.seasonal_test.getLength(), "At least one seasonal statistical test must be selected (AutoARIMA)");
        checks.checkPositive(space.method.getLength(), "At least one solver method must be selected (AutoARIMA)");
        checks.checkNumericalDimension(space.m, "Season length (m) for seasonal differencing (AutoARIMA)");
        checks.checkBetween(space.start_p, 0.0, space.max_p, "The starting value of the auto-regressive model order (p) must be between 0 and its maximum value (AutoARIMA)");
        checks.checkBetween(space.start_q, 0.0, space.max_q, "The starting value of the moving-average model order (q) must be between 0 and its maximum value (AutoARIMA)");
        boolean seasonal = this.mightHaveSeasonalSearchPoints(pmp.gridSearchParams, space.m);
        if (seasonal) {
            checks.checkBetween(space.start_P, 0.0, space.max_P, "The starting value of the auto-regressive seasonal model order (P) must be between 0 and its maximum value (AutoARIMA)");
            checks.checkBetween(space.start_Q, 0.0, space.max_Q, "The starting value of the moving-average seasonal model order (Q) must be between 0 and its maximum value (AutoARIMA)");
        }
        if (space.d != null) {
            checks.checkNonNegative(space.d.intValue(), "First-differencing order (d) must be a non-negative integer (AutoARIMA)");
        }
        if (space.D != null) {
            checks.checkNonNegative(space.D.intValue(), "Seasonal differencing term (D) must be a non-negative integer (AutoARIMA)");
        }
        // Sum of the starting orders: max_order must be strictly above it (checkPositive on the difference).
        int sumOfStartingOrders = space.start_p + space.start_q + (seasonal ? space.start_P + space.start_Q : 0);
        checks.checkPositive(space.max_order - sumOfStartingOrders, "The upper bounds for the sum of orders must be strictly higher than the sum of the starting values of each order (AutoARIMA)");
    }

    /**
     * Expands the AutoARIMA space into at most one modeling set.
     *
     * @return an empty list when the space is missing or disabled, otherwise a single set whose
     *     {@code estimatedTrains} is {@code gridLength * gsFolds + 1} when more than one grid
     *     point is searched (left at its default otherwise)
     */
    @Override
    public List<WorkSet.ModelingSet> expandModeling(PredictionModelingParams pmp, PredictionMLTask.TabularPredictionMLTask task, int gsFolds) {
        PredictionModelingParams.AutoArimaSpace space = pmp.autoarima_timeseries;
        if (space == null || !space.enabled) {
            return Collections.emptyList();
        }
        PreTrainPredictionModelingParams preTrainParams = new PreTrainPredictionModelingParams(PreTrainPredictionModelingParams.Algorithm.AUTO_ARIMA, pmp);
        preTrainParams.autoarima_timeseries_grid = space;
        preTrainParams.gridLength = this.getSearchSize(preTrainParams.grid_search_params, space);
        WorkSet.ModelingSet modelingSet = new WorkSet.ModelingSet(preTrainParams);
        if (preTrainParams.gridLength > 1) {
            // One fit per grid point and fold, plus a final refit (presumably on the full data).
            modelingSet.estimatedTrains = preTrainParams.gridLength * gsFolds + 1;
        }
        return Collections.singletonList(modelingSet);
    }

    /**
     * Turns the optimized parameters back into a single-point search space.
     *
     * <p>Only the outer grid dimensions ({@code m} and the categorical choices) are pinned to the
     * optimized values. The order bounds keep whatever {@link #getCopyWithGridStrategy} copied from
     * {@code usedToTrain}.
     */
    @Override
    public PreTrainPredictionModelingParams regridifyToPreTrain(PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PostTrainPredictionModelingParams.AutoArimaParams best = optimized.auto_arima_timeseries_params;
        PreTrainPredictionModelingParams regridded = this.getCopyWithGridStrategy(usedToTrain);
        PredictionModelingParams.AutoArimaSpace space = regridded.autoarima_timeseries_grid;
        space.m.setToSingleValueGrid(best.m);
        // The trailing arguments are the full sets of allowed values for each categorical dimension.
        space.information_criterion = CategoricalHyperparameterDimension.create(best.information_criterion, "aic", "aicc", "bic", "hqic", "oob");
        space.test = CategoricalHyperparameterDimension.create(best.test, "kpss", "adf", "pp");
        space.seasonal_test = CategoricalHyperparameterDimension.create(best.seasonal_test, "ocsb", "ch");
        space.method = CategoricalHyperparameterDimension.create(best.method, "lbfgs", "newton", "nm", "bfgs", "powell", "cg", "ncg", "basinhopping");
        return regridded;
    }

    /** Writes the regridded single-point space into {@code target} and enables it. */
    @Override
    public void regridifyToMLTask(PredictionModelingParams target, PostTrainPredictionModelingParams optimized, PreTrainPredictionModelingParams usedToTrain) {
        PreTrainPredictionModelingParams regridded = this.regridifyToPreTrain(optimized, usedToTrain);
        target.autoarima_timeseries = regridded.autoarima_timeseries_grid;
        target.autoarima_timeseries.enabled = true;
    }

    /**
     * Restores the search space used for training into {@code target} and enables it.
     *
     * <p>NOTE(review): this assigns the very same space object held by {@code usedToTrain}, so
     * setting {@code enabled} also mutates {@code usedToTrain}. Confirm that sharing is intended.
     */
    @Override
    public void refreshMLTask(PredictionModelingParams target, PreTrainPredictionModelingParams usedToTrain) {
        target.autoarima_timeseries = usedToTrain.autoarima_timeseries_grid;
        target.autoarima_timeseries.enabled = true;
    }

    @Override
    public boolean hasProbabilities(PreTrainPredictionModelingParams rpmp) {
        return true;
    }

    /**
     * Grid length is the product of the sizes of the outer search dimensions: season length,
     * information criterion, unit root test, seasonal unit root test and solver.
     */
    @Override
    protected int getGridLength(PredictionModelingParams.HyperparametersSpace space) {
        PredictionModelingParams.AutoArimaSpace arima = (PredictionModelingParams.AutoArimaSpace)space;
        return arima.m.getLength()
                * arima.information_criterion.getLength()
                * arima.test.getLength()
                * arima.seasonal_test.getLength()
                * arima.method.getLength();
    }
}

