/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.mec.drift;

import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.ModelLikeId;
import com.dataiku.dip.analysis.model.core.ModelUserMeta;
import com.dataiku.dip.coremodel.InfoMessage;
import com.dataiku.dip.exceptions.CodedException;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.futures.FuturePayload;
import com.dataiku.dip.futures.FutureProgress;
import com.dataiku.dip.futures.FutureProgressState;
import com.dataiku.dip.futures.FutureResponse;
import com.dataiku.dip.futures.FutureService;
import com.dataiku.dip.mec.AbstractModelEvaluation;
import com.dataiku.dip.mec.FullModelEvaluationId;
import com.dataiku.dip.mec.ModelComparison;
import com.dataiku.dip.mec.ModelComparisonsService;
import com.dataiku.dip.mec.ModelEvaluationCodes;
import com.dataiku.dip.mec.ModelEvaluationStore;
import com.dataiku.dip.mec.ModelEvaluationStoresCRUDService;
import com.dataiku.dip.mec.TabularModelEvaluation;
import com.dataiku.dip.mec.drift.DriftParams;
import com.dataiku.dip.mec.drift.DriftResult;
import com.dataiku.dip.mec.drift.DriftResultCachingService;
import com.dataiku.dip.mec.engine.DriftKernelCachingService;
import com.dataiku.dip.mec.engine.DriftRunner;
import com.dataiku.dip.mec.engine.ModelComparisonCodes;
import com.dataiku.dip.security.DSSAuthCtx;
import com.dataiku.dip.server.services.TaggableObjectsService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.threads.BaseProgressingWorkThread;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.util.DKUExecutors;
import com.dataiku.dip.utils.DKULogger;
import com.google.common.collect.ImmutableList;
import com.google.gson.reflect.TypeToken;
import java.io.IOException;
import java.security.InvalidParameterException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import javax.annotation.PostConstruct;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

/**
 * Computes data drift between a reference and a current model-like item ({@link ModelLikeId}).
 *
 * <p>Drift computations run as futures through {@link FutureService}. Computed results are cached
 * in {@link DriftResultCachingService}, drift runners are obtained from
 * {@link DriftKernelCachingService}, and both caches are cleaned up periodically by a background
 * task scheduled in {@link #postConstruct()}. {@link #clearCaches()} flushes both caches at once.
 */
@Service
public class DriftService {
    private static final DKULogger logger = DKULogger.getLogger("dku.ml.mec.drift");

    /** Initial delay before, and delay between, two periodic cache cleanups, in minutes. */
    private static final int CLEANUP_PERIOD_MN = 10;

    @Autowired
    private FutureService futureService;
    @Autowired
    private ModelComparisonsService modelComparisonsService;
    @Autowired
    private ModelEvaluationStoresCRUDService mesService;
    @Autowired
    private DriftResultCachingService resultCachingService;
    @Autowired
    private DriftKernelCachingService kernelCachingService;
    @Autowired
    private TransactionService transactionService;

    /**
     * Validates a reference/current pair and starts an asynchronous drift computation.
     *
     * @param owner                auth context handed to the work thread; it is also used to obtain the drift runner
     * @param projectKey           project key, used for the future's payload target and to obtain the drift runner
     * @param referenceModelLikeId reference item
     * @param currentModelLikeId   item whose drift is measured against the reference
     * @param params               drift computation parameters
     * @return the response of the started future, whose result is a {@link DriftResult}
     * @throws CodedException with {@code ERR_DRIFT_UNSUPPORTED_INCOMPATIBILITY} when the two items are not
     *         drift-compatible, or {@code ERR_DRIFT_NO_DATA_SAMPLE} when either item has no data to stream
     */
    public FutureResponse<DriftResult> startDriftComputation(DSSAuthCtx owner, final String projectKey, final ModelLikeId referenceModelLikeId, final ModelLikeId currentModelLikeId, final DriftParams params) throws Exception {
        if (!this.checkDriftCompatibility(referenceModelLikeId, currentModelLikeId)) {
            throw new CodedException((InfoMessage.MessageCode) ModelEvaluationCodes.ERR_DRIFT_UNSUPPORTED_INCOMPATIBILITY, referenceModelLikeId + " is not compatible with " + currentModelLikeId);
        }
        if (!currentModelLikeId.hasDataToStreamCSV()) {
            throw new CodedException((InfoMessage.MessageCode) ModelEvaluationCodes.ERR_DRIFT_NO_DATA_SAMPLE, currentModelLikeId + " has no data sample");
        }
        if (!referenceModelLikeId.hasDataToStreamCSV()) {
            throw new CodedException((InfoMessage.MessageCode) ModelEvaluationCodes.ERR_DRIFT_NO_DATA_SAMPLE, referenceModelLikeId + " has no data sample");
        }
        return this.futureService.runFuture(new BaseProgressingWorkThread<DriftResult>(owner) {
            DriftResult result;

            public FuturePayload getPayload() {
                FuturePayload payload = FuturePayload.newSimple("data_drift_compute", "Compute data drift");
                payload.targets.add(new FuturePayload.FuturePayloadTarget(projectKey, params.toString(), "Data drift with " + params, null));
                return payload;
            }

            public double getDangerosity() {
                return 0.0;
            }

            public DriftResult getResult() {
                return this.result;
            }

            public void execute() throws Exception {
                try (FutureProgress.AutocloseableFutureProgressState state = FutureProgress.pushAutoCloseableState("Drift computation", 100.0, FutureProgressState.StateUnit.NONE)) {
                    this.percentageProgressState = state;
                    this.result = DriftService.this.computeDrift_NT(this.owner, projectKey, params, referenceModelLikeId, currentModelLikeId, this.jobId);
                }
            }
        }, 0L, new TypeToken<FutureResponse<DriftResult>>() {});
    }

    /**
     * Computes the drift synchronously, or returns a previously cached result.
     *
     * <p>The result cache is keyed by the params, the string form of both ids and both classifier
     * thresholds (see {@link #getThreshold}). On a miss, a drift runner is borrowed from the kernel
     * cache and closed after use, and the new result is stored in the result cache. Prediction drift
     * is requested only when {@link #checkPredictionDriftCompatibility} allows it.
     */
    private DriftResult computeDrift_NT(DSSAuthCtx owner, String projectKey, DriftParams params, ModelLikeId referenceModelLikeId, ModelLikeId currentModelLikeId, String jobId) throws Exception {
        logger.info("Compute the data drift of " + currentModelLikeId + " with " + referenceModelLikeId + " with " + params);
        double referenceThreshold = this.getThreshold(referenceModelLikeId);
        double currentThreshold = this.getThreshold(currentModelLikeId);
        String referenceKey = referenceModelLikeId.toString();
        String currentKey = currentModelLikeId.toString();

        DriftResult cachedResults = this.resultCachingService.get(params, referenceKey, currentKey, referenceThreshold, currentThreshold);
        if (cachedResults != null) {
            logger.info("Using cached results for data drift computation");
            return cachedResults;
        }

        try (DriftRunner driftRunner = this.kernelCachingService.getCachedDriftRunner(owner, projectKey)) {
            boolean computePredictionDrift = this.checkPredictionDriftCompatibility(referenceModelLikeId, currentModelLikeId);
            DriftResult results = driftRunner.compute(params, referenceModelLikeId, currentModelLikeId, jobId, computePredictionDrift, referenceThreshold, currentThreshold);
            this.resultCachingService.put(params, referenceKey, currentKey, referenceThreshold, currentThreshold, results);
            return results;
        }
    }

    /**
     * Returns the active classifier threshold of an item.
     *
     * <p>For a doctor model, the threshold comes from its user meta. For a tabular model evaluation,
     * it comes from the evaluation itself. In all other cases the result is {@code 0.0}. This covers
     * a doctor model without user meta, a non-tabular evaluation, and any other item type.
     */
    private double getThreshold(ModelLikeId mli) throws IOException {
        switch (mli.getModelLikeType()) {
            case DOCTOR_MODEL: {
                ModelUserMeta mum = ((FullModelId) mli).getUserMetaOpt().orElse(null);
                return mum != null ? mum.activeClassifierThreshold : 0.0;
            }
            case MODEL_EVALUATION: {
                AbstractModelEvaluation me = ((FullModelEvaluationId) mli).getModelEvaluation();
                return me.isTabular() ? ((TabularModelEvaluation) me).activeClassifierThreshold : 0.0;
            }
            default:
                return 0.0;
        }
    }

    /**
     * Tells whether data drift can be computed between the two items.
     *
     * @return {@code true} if the reference can be used to compute data drift against the current item;
     *         {@code false} otherwise. Any failure while resolving the items is logged as an error and
     *         also yields {@code false}.
     */
    public boolean checkDriftCompatibility(ModelLikeId referenceId, ModelLikeId currentId) {
        try {
            ModelComparisonsService.ComparableModelItem reference = this.modelComparisonsService.comparableItemFromId(referenceId);
            ModelComparisonsService.ComparableModelItem current = this.modelComparisonsService.comparableItemFromId(currentId);
            return reference.canBeUsedToComputeDataDriftAgainst(current);
        } catch (Exception e) {
            logger.error("Could not compute drift compatibility between " + referenceId + " and " + currentId, e);
            return false;
        }
    }

    /**
     * Tells whether prediction drift can be computed between the two items.
     *
     * @return {@code true} if the reference can be used to compute prediction drift against the current
     *         item; {@code false} otherwise. Any failure while resolving the items is logged as a warning
     *         and also yields {@code false}.
     */
    public boolean checkPredictionDriftCompatibility(ModelLikeId referenceId, ModelLikeId currentId) {
        try {
            ModelComparisonsService.ComparableModelItem reference = this.modelComparisonsService.comparableItemFromId(referenceId);
            ModelComparisonsService.ComparableModelItem current = this.modelComparisonsService.comparableItemFromId(currentId);
            return reference.canBeUsedToComputePredictionDriftAgainst(current);
        } catch (Exception e) {
            logger.warn("Could not compute prediction drift compatibility between " + referenceId + " and " + currentId, e);
            return false;
        }
    }

    /**
     * Picks a default drift reference for an item.
     *
     * <p>Only model evaluations have a default reference:
     * <ul>
     *   <li>{@code SAVED_MODEL} evaluation: the backing full model id;</li>
     *   <li>{@code EXTERNAL} evaluation: the {@code ref} of the evaluation with the smallest {@code created}
     *       value in the same store. This can be {@code currentId}'s own evaluation.</li>
     * </ul>
     *
     * @return the reference id, or {@code null} when there is none. This includes non-evaluation items,
     *         other model types, and an external store with no evaluations.
     * @throws NullPointerException if the evaluation's {@code modelType} is {@code null}
     */
    public ModelLikeId getReferenceForDrift(ModelLikeId currentId) throws IOException {
        if (currentId.getModelLikeType() != ModelLikeId.ModelLikeType.MODEL_EVALUATION) {
            return null;
        }
        FullModelEvaluationId currentMeId = (FullModelEvaluationId) currentId;
        TabularModelEvaluation modelEvaluation = currentMeId.getTabularModelEvaluation();
        switch (Objects.requireNonNull(modelEvaluation.modelType)) {
            case SAVED_MODEL:
                return modelEvaluation.getBackingFullModelId();
            case EXTERNAL: {
                TaggableObjectsService.TaggableObjectRef mesRef = currentMeId.getUnderlyingStore();
                final ModelEvaluationStore mes;
                // Only the store lookup runs inside the read transaction; the listing below runs after it.
                try (Transaction t = this.transactionService.retrieveOrBeginRead()) {
                    mes = this.mesService.getMandatory(mesRef.projectKey, mesRef.id);
                }
                // NOTE(review): (null, -1) presumably means "no filter, no limit" -- confirm against listEvaluations.
                // Long.compare instead of a subtraction-based comparator, which can overflow.
                return this.mesService.listEvaluations(mes, null, -1).stream()
                        .map(meh -> meh.evaluation)
                        .min((me1, me2) -> Long.compare(me1.created, me2.created))
                        .map(me -> me.ref)
                        .orElse(null);
            }
            default:
                return null;
        }
    }

    /**
     * Lists the candidate drift references for an item, flagging which ones are compatible.
     *
     * <p>Candidates come from model evaluation stores, saved models and analyses with the same
     * prediction task type as {@code currentId}. Each candidate's {@code isCompatibleReference} is set
     * from {@code canBeUsedToComputeDriftAgainst(current)}, and the current item is flagged as well.
     * The returned list holds every candidate except the current item, followed by the current item
     * as the last element.
     *
     * @param maxComparables if strictly positive, the other candidates are truncated to
     *                       {@code maxComparables - 1}, so at most {@code maxComparables} items are returned.
     *                       When truncation happens, a warning is added to {@code infoMessages} if it is non-null.
     * @return a new mutable list as described above
     * @throws InvalidParameterException if {@code currentId} is not among the listed items
     */
    public List<ModelComparisonsService.ComparableModelItem> getCompatibleReferences(String projectKey, ModelLikeId currentId, InfoMessage.InfoMessages infoMessages, int maxComparables) throws IOException, DKUSecurityException {
        List<ModelComparisonsService.MELikesSource> sources = ImmutableList.of(
                ModelComparisonsService.MELikesSource.FROM_MES,
                ModelComparisonsService.MELikesSource.FROM_SAVED_MODELS,
                ModelComparisonsService.MELikesSource.FROM_ANALYSIS);
        List<ModelComparisonsService.ComparableModelItem> availableModels = this.modelComparisonsService.listFilteredComparableItems(projectKey, sources, ModelComparison.ModelTaskType.from(currentId.getPredictionType()), false, null, null, -1);

        ModelComparisonsService.ComparableModelItem current = availableModels.stream()
                .filter(cm -> Objects.equals(cm.mli, currentId))
                .findAny()
                .orElseThrow(() -> new InvalidParameterException("Model evaluation " + currentId + " not found"));

        // Flag compatibility explicitly rather than as a side effect hidden in Stream.peek.
        for (ModelComparisonsService.ComparableModelItem cm : availableModels) {
            cm.isCompatibleReference = cm.canBeUsedToComputeDriftAgainst(current);
        }

        // toCollection(ArrayList::new): the list is appended to below, and Collectors.toList()
        // makes no mutability guarantee.
        List<ModelComparisonsService.ComparableModelItem> res = availableModels.stream()
                .filter(cm -> !Objects.equals(cm.mli, currentId))
                .collect(Collectors.toCollection(ArrayList::new));
        if (maxComparables > 0 && res.size() >= maxComparables) {
            if (infoMessages != null) {
                infoMessages.withWarningV((InfoMessage.MessageCode) ModelComparisonCodes.WARN_TRUNCATED_COMPARABLE_LIST, "Too many comparables (%d) found. The list was truncated to the %d first items.", new Object[]{res.size(), maxComparables});
            }
            // Keep one slot for the current item; copy so no subList view of the larger list is returned.
            res = new ArrayList<>(res.subList(0, maxComparables - 1));
        }
        res.add(current);
        return res;
    }

    /** Immediately invalidates every entry of both the kernel cache and the result cache. */
    public void clearCaches() {
        logger.info("Clearing computation results and kernel caches");
        this.kernelCachingService.invalidateAll();
        this.resultCachingService.invalidateAll();
    }

    /**
     * Schedules the periodic cleanup of both caches on a dedicated daemon thread.
     * The first run happens after {@link #CLEANUP_PERIOD_MN} minutes, and later runs are spaced
     * by the same delay.
     */
    @PostConstruct
    public void postConstruct() {
        logger.infoV("Scheduling periodic cleanup of kernel and result cache every %d minutes, starting in %d minutes", new Object[]{CLEANUP_PERIOD_MN, CLEANUP_PERIOD_MN});
        DKUExecutors.newNamedSingleDaemonThreadExecutor("DriftKernels-cleanup").scheduleWithFixedDelay(this::cleanUpCaches, CLEANUP_PERIOD_MN, CLEANUP_PERIOD_MN, TimeUnit.MINUTES);
    }

    /**
     * Runs one cleanup pass over both caches.
     *
     * <p>Failures are logged rather than propagated, for two reasons. First, with
     * {@code ScheduledExecutorService} semantics (assumed for the DKUExecutors executor), a fixed-delay
     * task that throws has all its later runs silently suppressed. Second, a failing kernel cleanup must
     * not prevent the result cache from being cleaned.
     */
    private void cleanUpCaches() {
        try {
            this.kernelCachingService.cleanUp();
        } catch (RuntimeException e) {
            logger.warn("Failed to clean up drift kernel cache", e);
        }
        try {
            this.resultCachingService.cleanUp();
        } catch (RuntimeException e) {
            logger.warn("Failed to clean up drift result cache", e);
        }
    }
}

