/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.analysis.ml.prediction;

import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.MLTaskLoc;
import com.dataiku.dip.analysis.ml.ModelLikeId;
import com.dataiku.dip.analysis.ml.SavedModelCodes;
import com.dataiku.dip.analysis.ml.prediction.PredictionResultsReader;
import com.dataiku.dip.analysis.ml.prediction.PythonPostTrainComputationHandler;
import com.dataiku.dip.analysis.ml.prediction.PythonPostTrainRetrainingComputationHandler;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.core.AnalysisCoreParams;
import com.dataiku.dip.analysis.model.prediction.PreTrainPredictionModelingParams;
import com.dataiku.dip.analysis.model.prediction.PredictionGlobalExplanations;
import com.dataiku.dip.analysis.model.prediction.PredictionIndividualExplanations;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.TimeseriesEvaluationForecasts;
import com.dataiku.dip.analysis.model.prediction.TimeseriesForecastingModelDetails;
import com.dataiku.dip.analysis.model.prediction.TimeseriesForecastingModelIntrinsicPerf;
import com.dataiku.dip.analysis.model.prediction.TimeseriesInteractiveScoringScenarios;
import com.dataiku.dip.analysis.model.prediction.TimeseriesResiduals;
import com.dataiku.dip.analysis.model.preprocessing.FeaturePreprocessingParams;
import com.dataiku.dip.coremodel.InfoMessage;
import com.dataiku.dip.exceptions.CodedException;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.futures.FutureAborter;
import com.dataiku.dip.futures.FuturePayload;
import com.dataiku.dip.futures.FutureProgress;
import com.dataiku.dip.futures.FutureProgressState;
import com.dataiku.dip.futures.FutureResponse;
import com.dataiku.dip.mec.FullModelEvaluationId;
import com.dataiku.dip.mec.PythonPostModelEvaluationComputationHandler;
import com.dataiku.dip.resourceusage.ComputeResourceUsageContext;
import com.dataiku.dip.resourceusage.CurrentComputeResourceUsageContext;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.DSSAuthCtx;
import com.dataiku.dip.security.impersonation.FilesystemACLUtils;
import com.dataiku.dip.threads.BaseProgressingWorkThread;
import com.dataiku.dip.utils.AutoCloseableLock;
import com.dataiku.dip.utils.NamedLock;
import com.dataiku.dip.utils.NotImplementedException;
import com.google.gson.JsonObject;
import com.google.gson.reflect.TypeToken;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.log4j.Logger;

public abstract class PredictionPostComputationHandler {
    public String jobId;
    private static Logger logger = Logger.getLogger((String)"dku.analysis.ml.prediction.posttraining");

    static PredictionPostComputationHandler from(AuthCtx authCtx, String jobId, ModelLikeId mle, PostComputationCommand computationCommand, JsonObject computationParameters) throws IOException {
        if (PostComputationCommand.MODELLESS_SUBPOPULATION == computationCommand) {
            return new PythonPostModelEvaluationComputationHandler(authCtx, jobId, mle, computationCommand, computationParameters);
        }
        if (PostComputationCommand.RETRAINING == computationCommand) {
            return new PythonPostTrainRetrainingComputationHandler(authCtx, jobId, mle, computationCommand, computationParameters);
        }
        if (PostComputationCommand.LEARNING_CURVE == computationCommand) {
            return new PythonPostTrainComputationHandler(authCtx, jobId, mle, computationCommand, computationParameters);
        }
        switch (mle.getModelLikeType()) {
            case MODEL_EVALUATION: 
            case DOCTOR_MODEL: {
                MLTask.BackendType backend = mle.getBackendType();
                switch (backend) {
                    case PY_MEMORY: 
                    case KERAS: {
                        return new PythonPostTrainComputationHandler(authCtx, jobId, mle, computationCommand, computationParameters);
                    }
                }
                throw new IllegalArgumentException("Post Training computation not available on backend " + String.valueOf((Object)backend));
            }
        }
        throw new NotImplementedException("Unknown model like type");
    }

    protected abstract void compute() throws Exception;

    protected abstract void abort();

    public static enum PostComputationCommand {
        MODELLESS_SUBPOPULATION(Collections.singletonList(MLTask.BackendType.PY_MEMORY), Arrays.asList(PredictionMLTask.PredictionType.BINARY_CLASSIFICATION, PredictionMLTask.PredictionType.REGRESSION), Arrays.asList(FeaturePreprocessingParams.Role.INPUT, FeaturePreprocessingParams.Role.REJECT), Arrays.asList(FeaturePreprocessingParams.FeatureType.NUMERIC, FeaturePreprocessingParams.FeatureType.CATEGORY)),
        PDP(Arrays.asList(MLTask.BackendType.PY_MEMORY, MLTask.BackendType.KERAS), Arrays.asList(PredictionMLTask.PredictionType.BINARY_CLASSIFICATION, PredictionMLTask.PredictionType.MULTICLASS, PredictionMLTask.PredictionType.REGRESSION), Collections.singletonList(FeaturePreprocessingParams.Role.INPUT), Arrays.asList(FeaturePreprocessingParams.FeatureType.NUMERIC, FeaturePreprocessingParams.FeatureType.CATEGORY)),
        LEARNING_CURVE(Collections.singletonList(MLTask.BackendType.PY_MEMORY), Arrays.asList(PredictionMLTask.PredictionType.BINARY_CLASSIFICATION, PredictionMLTask.PredictionType.MULTICLASS, PredictionMLTask.PredictionType.REGRESSION), null, null),
        SUBPOPULATION(Arrays.asList(MLTask.BackendType.PY_MEMORY, MLTask.BackendType.KERAS), Arrays.asList(PredictionMLTask.PredictionType.BINARY_CLASSIFICATION, PredictionMLTask.PredictionType.REGRESSION), Arrays.asList(FeaturePreprocessingParams.Role.INPUT, FeaturePreprocessingParams.Role.REJECT), Arrays.asList(FeaturePreprocessingParams.FeatureType.NUMERIC, FeaturePreprocessingParams.FeatureType.CATEGORY)),
        GLOBAL_EXPLANATIONS(Collections.singletonList(MLTask.BackendType.PY_MEMORY), Arrays.asList(PredictionMLTask.PredictionType.BINARY_CLASSIFICATION, PredictionMLTask.PredictionType.MULTICLASS, PredictionMLTask.PredictionType.REGRESSION), null, null),
        INDIVIDUAL_EXPLANATIONS(Collections.singletonList(MLTask.BackendType.PY_MEMORY), Arrays.asList(PredictionMLTask.PredictionType.BINARY_CLASSIFICATION, PredictionMLTask.PredictionType.MULTICLASS, PredictionMLTask.PredictionType.REGRESSION), null, null),
        RETRAINING(Collections.singletonList(MLTask.BackendType.PY_MEMORY), Arrays.asList(PredictionMLTask.PredictionType.BINARY_CLASSIFICATION, PredictionMLTask.PredictionType.MULTICLASS, PredictionMLTask.PredictionType.REGRESSION), null, null),
        INFORMATION_CRITERIA(Collections.singletonList(MLTask.BackendType.PY_MEMORY), Arrays.asList(PredictionMLTask.PredictionType.TIMESERIES_FORECAST), null, null),
        TIMESERIES_RESIDUALS(Collections.singletonList(MLTask.BackendType.PY_MEMORY), Arrays.asList(PredictionMLTask.PredictionType.TIMESERIES_FORECAST), null, null),
        TIMESERIES_INTERACTIVE_SCORING_CREATION(Collections.singletonList(MLTask.BackendType.PY_MEMORY), Arrays.asList(PredictionMLTask.PredictionType.TIMESERIES_FORECAST), null, null),
        TIMESERIES_INTERACTIVE_SCORING_COMPUTATION(Collections.singletonList(MLTask.BackendType.PY_MEMORY), Arrays.asList(PredictionMLTask.PredictionType.TIMESERIES_FORECAST), null, null);

        private final List<MLTask.BackendType> supportedBackends;
        private final List<PredictionMLTask.PredictionType> supportedPredictionTypes;
        private final List<FeaturePreprocessingParams.Role> authorizedFeatureRoles;
        private final List<FeaturePreprocessingParams.FeatureType> authorizedFeatureTypes;

        private PostComputationCommand(List<MLTask.BackendType> supportedBackends, List<PredictionMLTask.PredictionType> supportedPredictionTypes, List<FeaturePreprocessingParams.Role> authorizedFeatureRoles, List<FeaturePreprocessingParams.FeatureType> authorizedFeatureTypes) {
            this.supportedBackends = supportedBackends;
            this.supportedPredictionTypes = supportedPredictionTypes;
            this.authorizedFeatureRoles = authorizedFeatureRoles;
            this.authorizedFeatureTypes = authorizedFeatureTypes;
        }

        public boolean containsAuthorizedFeatureRoles(FeaturePreprocessingParams.Role role) {
            return this.authorizedFeatureRoles.contains((Object)role);
        }

        public boolean containsAuthorizedFeatureTypes(FeaturePreprocessingParams.FeatureType featureType) {
            return this.authorizedFeatureTypes.contains((Object)featureType);
        }

        public void checkCompatible(ModelLikeId mle) throws CodedException, IOException {
            if (!this.supportedBackends.contains((Object)mle.getBackendType())) {
                throw new CodedException((InfoMessage.MessageCode)SavedModelCodes.ERR_ML_INVALID_POSTTRAIN_CONFIG, "Cannot run '" + this.name() + "' computation with backend '" + String.valueOf((Object)mle.getBackendType()) + "'");
            }
            if (!this.supportedPredictionTypes.contains((Object)mle.getPredictionType())) {
                throw new CodedException((InfoMessage.MessageCode)SavedModelCodes.ERR_ML_INVALID_POSTTRAIN_CONFIG, "Cannot run '" + this.name() + "' computation with prediction type '" + String.valueOf((Object)mle.getPredictionType()) + "'");
            }
        }

        String toCmd() {
            return "compute_" + this.name().toLowerCase();
        }
    }

    public static class TimeseriesInteractiveScoringComputationWorkThread
    extends PostComputationWorkThread<Map<String, TimeseriesEvaluationForecasts.TimeseriesScenariosForecasts>> {
        public TimeseriesInteractiveScoringComputationWorkThread(DSSAuthCtx user, ModelLikeId mle, JsonObject computationParameters) {
            super(user, mle, PostComputationCommand.TIMESERIES_INTERACTIVE_SCORING_COMPUTATION, computationParameters);
        }

        @Override
        public Map<String, TimeseriesEvaluationForecasts.TimeseriesScenariosForecasts> getActualResults() throws IOException {
            return this.mle.getInteractiveScoringTimeseriesEvaluationForecasts().get();
        }
    }

    public static class TimeseriesInteractiveScoringCreationWorkThread
    extends PostComputationWorkThread<Map<String, TimeseriesInteractiveScoringScenarios>> {
        public TimeseriesInteractiveScoringCreationWorkThread(DSSAuthCtx user, ModelLikeId mle, JsonObject computationParameters) {
            super(user, mle, PostComputationCommand.TIMESERIES_INTERACTIVE_SCORING_CREATION, computationParameters);
        }

        @Override
        public Map<String, TimeseriesInteractiveScoringScenarios> getActualResults() throws IOException {
            return this.mle.parseTimeseriesInteractiveScoringScenario().get();
        }
    }

    public static class TimeseriesResidualsWorkThread
    extends PostComputationWorkThread<Map<String, TimeseriesResiduals>> {
        public TimeseriesResidualsWorkThread(DSSAuthCtx user, ModelLikeId mle) {
            super(user, mle, PostComputationCommand.TIMESERIES_RESIDUALS, new JsonObject());
        }

        @Override
        public Map<String, TimeseriesResiduals> getActualResults() throws IOException {
            return this.mle.parseTimeseriesResiduals(new ArrayList<String>()).get();
        }
    }

    public static class TimeseriesInformationCriteriaWorkThread
    extends PostComputationWorkThread<List<TimeseriesForecastingModelIntrinsicPerf.InformationCriterion>> {
        public TimeseriesInformationCriteriaWorkThread(DSSAuthCtx user, ModelLikeId mle) {
            super(user, mle, PostComputationCommand.INFORMATION_CRITERIA, new JsonObject());
        }

        @Override
        public List<TimeseriesForecastingModelIntrinsicPerf.InformationCriterion> getActualResults() throws IOException {
            TimeseriesForecastingModelDetails timeseriesForecastingModelDetails = (TimeseriesForecastingModelDetails)PredictionResultsReader.makeModelDetails(this.mle.getUnderlyingModel());
            if (timeseriesForecastingModelDetails.iperf.informationCriteria != null) {
                logger.warn((Object)"Information criteria already retrieved for model. Returning it.");
                return timeseriesForecastingModelDetails.iperf.informationCriteria;
            }
            if (!Arrays.asList(PreTrainPredictionModelingParams.Algorithm.AUTO_ARIMA, PreTrainPredictionModelingParams.Algorithm.SEASONAL_LOESS).contains((Object)timeseriesForecastingModelDetails.modeling.algorithm)) {
                throw new IllegalArgumentException("Information criteria not available for this model");
            }
            return this.mle.getUnderlyingModel().getTimeseriesForecastingIPerf().informationCriteria;
        }
    }

    public static class SubpopulationMEWorkThread
    extends PostComputationWorkThread<PredictionResultsReader.PredictionSubpopulationResults> {
        private final List<String> features;

        public SubpopulationMEWorkThread(DSSAuthCtx user, FullModelEvaluationId fme, List<String> features, JsonObject computationParameters) {
            super(user, fme, PostComputationCommand.MODELLESS_SUBPOPULATION, computationParameters);
            this.features = features;
        }

        @Override
        public PredictionResultsReader.PredictionSubpopulationResults getActualResults() throws IOException {
            return PredictionResultsReader.makeSubpopulationResults(this.mle, this.features, false);
        }
    }

    public static class IndividualExplanationsPTCWorkThread
    extends PostComputationWorkThread<PredictionIndividualExplanations> {
        public IndividualExplanationsPTCWorkThread(DSSAuthCtx user, ModelLikeId mle, JsonObject computationParameters) {
            super(user, mle, PostComputationCommand.INDIVIDUAL_EXPLANATIONS, computationParameters);
        }

        @Override
        public PredictionIndividualExplanations getActualResults() throws IOException {
            return PredictionResultsReader.getIndividualExplanations(this.mle);
        }
    }

    public static class GlobalExplanationsPTCWorkThread
    extends PostComputationWorkThread<PredictionGlobalExplanations> {
        public GlobalExplanationsPTCWorkThread(DSSAuthCtx user, ModelLikeId mle) {
            super(user, mle, PostComputationCommand.GLOBAL_EXPLANATIONS, new JsonObject());
        }

        @Override
        public PredictionGlobalExplanations getActualResults() throws IOException {
            return this.mle.getGlobalExplanations().orElse(null);
        }
    }

    public static class LearningCurvePTCWorkThread
    extends PostComputationWorkThread<PredictionResultsReader.LearningCurveResults<?>> {
        public LearningCurvePTCWorkThread(DSSAuthCtx user, ModelLikeId mle, JsonObject computationParameters) {
            super(user, mle, PostComputationCommand.LEARNING_CURVE, computationParameters);
        }

        @Override
        public PredictionResultsReader.LearningCurveResults<?> getActualResults() throws IOException {
            return PredictionResultsReader.makeLearningCurveResults(this.mle);
        }
    }

    public static class PartialDependenciesPTCworkThread
    extends PostComputationWorkThread<PredictionResultsReader.PartialDependenciesResult> {
        public PartialDependenciesPTCworkThread(DSSAuthCtx user, ModelLikeId mle, JsonObject computationParameters) {
            super(user, mle, PostComputationCommand.PDP, computationParameters);
        }

        @Override
        public PredictionResultsReader.PartialDependenciesResult getActualResults() throws IOException {
            return PredictionResultsReader.makePartialDependenceResults(this.mle.getMainFolder());
        }
    }

    public static class RetrainingPTCWorkThread
    extends PostComputationWorkThread<String> {
        public RetrainingPTCWorkThread(DSSAuthCtx user, FullModelId mle) {
            super(user, mle, PostComputationCommand.RETRAINING, null);
        }

        @Override
        public String getActualResults() {
            return this.mle.toString();
        }
    }

    public static class SubpopulationPTCWorkThread
    extends PostComputationWorkThread<PredictionResultsReader.PredictionSubpopulationResults> {
        private final List<String> features;
        private boolean computePerformanceMetrics;

        public SubpopulationPTCWorkThread(DSSAuthCtx user, FullModelId fmi, List<String> features, JsonObject computationParameters, boolean computePerformanceMetrics) {
            super(user, fmi, PostComputationCommand.SUBPOPULATION, computationParameters);
            this.features = features;
            this.computePerformanceMetrics = computePerformanceMetrics;
        }

        @Override
        public PredictionResultsReader.PredictionSubpopulationResults getActualResults() throws IOException {
            return PredictionResultsReader.makeSubpopulationResults(this.mle, this.features, this.computePerformanceMetrics);
        }
    }

    public static abstract class PostComputationWorkThread<T>
    extends BaseProgressingWorkThread<T> {
        final ModelLikeId mle;
        private final FuturePayload futurePayload;
        private final PostComputationCommand computationCommand;
        private final JsonObject computationParameters;
        private boolean done = false;
        protected MLTaskLoc loc;
        protected AnalysisCoreParams cp;

        public PostComputationWorkThread(DSSAuthCtx user, ModelLikeId mle, PostComputationCommand computationCommand, JsonObject computationParameters) {
            super(user);
            this.mle = mle;
            this.futurePayload = mle.buildPostComputationFuturePayload();
            this.computationCommand = computationCommand;
            this.computationParameters = computationParameters;
        }

        public TypeToken<FutureResponse<T>> getTypeToken() {
            return new TypeToken<FutureResponse<T>>(){};
        }

        public FuturePayload getPayload() {
            return this.futurePayload;
        }

        public double getDangerosity() {
            return 0.0;
        }

        T getActualResults() throws Exception {
            return null;
        }

        public T getResult() {
            if (!this.done || this.isAborted()) {
                return null;
            }
            try {
                return this.getActualResults();
            }
            catch (Exception e) {
                logger.warn((Object)"Failed to read posttrain computation results", (Throwable)e);
                return null;
            }
        }

        /*
         * WARNING - Removed try catching itself - possible behaviour change.
         */
        public void execute() throws Exception {
            AutoCloseableLock lock;
            try (FutureProgress.AutocloseableFutureProgressState waiting = FutureProgress.pushAutoCloseableState((String)"Waiting completion of other post training computation on this model");){
                String lockName = "analysis.ml.prediction.posttraining." + this.mle.toString();
                lock = NamedLock.acquireInterruptibly((String)lockName);
            }
            try (FutureProgress.AutocloseableFutureProgressState state = FutureProgress.pushAutoCloseableState((String)"post-train-computation", (double)100.0, (FutureProgressState.StateUnit)FutureProgressState.StateUnit.NONE);){
                this.percentageProgressState = state;
                final PredictionPostComputationHandler handler = PredictionPostComputationHandler.from(this.owner, this.jobId, this.mle, this.computationCommand, this.computationParameters);
                try (FutureAborter.AutoCloseableAbortHook aborter = FutureAborter.pushAutoCloseableHook((Runnable)new Runnable(){

                    @Override
                    public void run() {
                        try {
                            handler.abort();
                        }
                        catch (Exception e) {
                            logger.warn((Object)"Error white aborting the post training task", (Throwable)e);
                        }
                    }
                });){
                    this.grantRequiredACLs();
                    ComputeResourceUsageContext cruContext = this.mle.getComputeResourceUsageContext(this.owner);
                    CurrentComputeResourceUsageContext.setInCurrentThread((ComputeResourceUsageContext)cruContext);
                    handler.compute();
                    CurrentComputeResourceUsageContext.clear();
                    this.done = true;
                }
            }
            finally {
                lock.close();
            }
        }

        protected void grantRequiredACLs() throws IOException, DKUSecurityException, InterruptedException {
            FullModelId underlyingModel = this.mle.getUnderlyingModel();
            if (underlyingModel != null && underlyingModel.exists() && underlyingModel != this.mle) {
                FilesystemACLUtils.grantFSReadACLs(this.owner, underlyingModel.getProjectKey(), underlyingModel.getFolderEnsuringSecurity());
            }
            FilesystemACLUtils.grantFSReadACLs(this.owner, this.mle.getProjectKey(), this.mle.getFolderEnsuringSecurity());
            FilesystemACLUtils.grantFSFullACLs((AuthCtx)this.owner, this.mle.getProjectKey(), this.mle.getSessionFolder());
            if (this.mle.getPostOperationsFolder().exists()) {
                FilesystemACLUtils.grantFSFullACLs((AuthCtx)this.owner, this.mle.getProjectKey(), this.mle.getPostOperationsFolder());
            }
        }
    }
}

