/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.mec.engine;

import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.ModelLikeId;
import com.dataiku.dip.analysis.model.prediction.ClassificationModelPredictionInfos;
import com.dataiku.dip.analysis.model.prediction.ColumnImportance;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionModelPredictionInfos;
import com.dataiku.dip.analysis.model.prediction.RegressionModelPredictionInfos;
import com.dataiku.dip.analysis.model.prediction.ResolvedClassicalPredictionPreprocessingParams;
import com.dataiku.dip.coremodel.InfoMessage;
import com.dataiku.dip.coremodel.Schema;
import com.dataiku.dip.exceptions.CodedException;
import com.dataiku.dip.io.JavaBlockLink;
import com.dataiku.dip.io.SimplePythonKernel;
import com.dataiku.dip.mec.AbstractModelEvaluation;
import com.dataiku.dip.mec.FullModelEvaluationId;
import com.dataiku.dip.mec.ModelEvaluationCodes;
import com.dataiku.dip.mec.drift.DriftParams;
import com.dataiku.dip.mec.drift.DriftResult;
import com.dataiku.dip.mec.engine.CSVSchemaAdapter;
import com.dataiku.dip.threads.BaseKernelProtocol;
import com.dataiku.dip.transactions.TransactionContext;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.NotImplementedException;
import com.dataiku.dip.utils.polyjson.Mapping;
import com.dataiku.dip.utils.polyjson.PolyJSON;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Stopwatch;
import java.io.File;
import java.io.IOException;
import java.io.OutputStream;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;

/**
 * Java side of the request/response protocol spoken with a Python kernel for drift computation.
 *
 * <p>Model-likes (model evaluations or doctor models) are sent to the kernel with a
 * {@code LoadDriftParam} command followed by CSV streams of their data (and predictions when a
 * predicted schema exists). A {@code ComputeDrift} command then asks the kernel to compare two
 * refs. NOTE(review): the ref strings passed to {@link #computeDrift} presumably must match the ids
 * used when loading — confirm against the Python side.
 */
public class DriftKernelProtocol
extends BaseKernelProtocol
implements AutoCloseable {
    /** Name of the Python package implementing the kernel side of this protocol. */
    public static final String PYTHON_PACKAGE = "dataiku.modelevaluation.server";
    private static final DKULogger logger = DKULogger.getLogger((String)"dku.modelevaluation.postcomputation");

    public DriftKernelProtocol(SimplePythonKernel simplePythonKernel) {
        super(simplePythonKernel);
    }

    /**
     * Resolves the prediction type of a model-like.
     *
     * <p>For model evaluations it comes from the tabular model evaluation. For doctor models it comes
     * from the imported MLflow metadata when the version is an external MLflow model, otherwise from
     * the head ML task.
     *
     * @throws NotImplementedException for any other model-like type
     */
    private PredictionMLTask.PredictionType getPredictionType(ModelLikeId mle) throws IOException {
        switch (mle.getModelLikeType()) {
            case MODEL_EVALUATION: {
                return ((FullModelEvaluationId)mle).getTabularModelEvaluation().predictionType;
            }
            case DOCTOR_MODEL: {
                FullModelId fmi = (FullModelId)mle;
                if (fmi.isExternalMLflowModelVersion()) {
                    return fmi.getMLflowImportedModelMetadata().predictionType;
                }
                PredictionMLTask mlTask = (PredictionMLTask)fmi.getHeadMLTask();
                return mlTask.predictionType;
            }
        }
        throw new NotImplementedException("Unknown model like type");
    }

    /**
     * Loads a model evaluation into the kernel, using its drift reference data and reference
     * prediction statistics instead of its regular data.
     *
     * <p>The kernel-side id is {@code String.valueOf(fmle) + "-reference"}.
     */
    public void loadReferenceModelEvaluationWithDriftReferenceAndStatistics(final FullModelEvaluationId fmle) throws IOException, CodedException {
        AbstractModelEvaluation me = fmle.getModelEvaluation();
        Schema dataSchema = me.getDriftReferenceSchema();
        CSVSchemaAdapter.StreamPusher csvData = new CSVSchemaAdapter.StreamPusher(){

            @Override
            public void push(OutputStream out) throws Exception {
                fmle.streamDriftReferenceCSV(out);
            }
        };
        this.loadModelLike(fmle, csvData, dataSchema, String.valueOf(fmle) + "-reference", fmle.getReferencePredictionStatisticsFile());
    }

    /**
     * Loads a model-like into the kernel with its regular data schema, data CSV and prediction
     * statistics. The kernel-side id is {@code mle.toString()}.
     */
    public void loadModelLike(final ModelLikeId mle) throws IOException, CodedException {
        Schema dataSchema = mle.getDataSchema();
        CSVSchemaAdapter.StreamPusher csvData = new CSVSchemaAdapter.StreamPusher(){

            @Override
            public void push(OutputStream out) throws Exception {
                mle.streamDataCSV(out);
            }
        };
        this.loadModelLike(mle, csvData, dataSchema, mle.toString(), mle.getPredictionStatisticsFile());
    }

    /**
     * Gathers the model-like metadata (preprocessing, predicted schema, prediction type, column
     * importance) and delegates to the low-level loader with the given data source and id.
     */
    private void loadModelLike(final ModelLikeId mle, CSVSchemaAdapter.StreamPusher csvData, Schema dataSchema, String id, File predictionStatisticsFile) throws IOException, CodedException {
        ResolvedClassicalPredictionPreprocessingParams preprocessingParams = mle.getResolvedPredictionPreprocessingParams();
        CSVSchemaAdapter.StreamPusher csvPrediction = new CSVSchemaAdapter.StreamPusher(){

            @Override
            public void push(OutputStream out) throws Exception {
                mle.streamPredictedCSV(out);
            }
        };
        this.loadModelLike(id, dataSchema, mle.getPredictedSchema(), preprocessingParams, this.getPredictionType(mle), csvData, csvPrediction, mle.getColumnImportance(), predictionStatisticsFile);
    }

    /**
     * Sends a {@code LoadDriftParam} command to the kernel, then streams the data CSV and, when
     * {@code predictedSchema} is non-null, the prediction CSV.
     *
     * <p>Protocol: send command, wait for {@code WaitingForData}, push the data stream(s), then wait
     * for {@code DatasetReceived}.
     *
     * @param predictionStatisticsFile optional JSON file with prediction statistics. It is parsed as
     *     regression or classification infos depending on {@code predictionType}, and is ignored when
     *     null or missing.
     * @throws CodedException with {@code ERR_SAMPLE_STREAM_TO_KERNEL_FAILURE} if a stream fails, or
     *     {@code ERR_SEND_PARAM_TO_KERNEL_FAILURE} for any other failure (including a prediction type
     *     incompatible with drift when statistics are present)
     */
    @VisibleForTesting
    void loadModelLike(String mleId, Schema dataSchema, Schema predictedSchema, ResolvedClassicalPredictionPreprocessingParams preprocessingParams, PredictionMLTask.PredictionType predictionType, CSVSchemaAdapter.StreamPusher csvData, CSVSchemaAdapter.StreamPusher predictionData, ColumnImportance columnImportance, File predictionStatisticsFile) throws CodedException {
        try {
            Stopwatch stopwatch = Stopwatch.createStarted();
            logger.info((Object)"Started streaming sample to Python engine...");
            JavaBlockLink link = this.simplePythonKernel.getLink();
            PredictionModelPredictionInfos predictionInfos = null;
            if (predictionStatisticsFile != null && predictionStatisticsFile.exists()) {
                switch (predictionType) {
                    case REGRESSION: {
                        predictionInfos = (PredictionModelPredictionInfos)JSON.parseFile((File)predictionStatisticsFile, RegressionModelPredictionInfos.class);
                        break;
                    }
                    case BINARY_CLASSIFICATION: 
                    case MULTICLASS: {
                        predictionInfos = (PredictionModelPredictionInfos)JSON.parseFile((File)predictionStatisticsFile, ClassificationModelPredictionInfos.class);
                        break;
                    }
                    default: {
                        throw new IllegalArgumentException(String.format("The prediction type (%s) of the model %s is not compatible with drift computation", predictionType, mleId));
                    }
                }
            }
            link.sendRequest((Object)new LoadModelLikeCommand(mleId, dataSchema, predictedSchema, preprocessingParams, predictionType, columnImportance, predictionInfos));
            link.receiveJsonResponse(WaitingForDataResponse.class);
            streamToDriftKernel(link, csvData, "Could not stream sample to Python kernel");
            if (predictedSchema != null) {
                streamToDriftKernel(link, predictionData, "Could not stream prediction to Python kernel");
            }
            link.receiveJsonResponse(DatasetReceivedResponse.class);
            logger.info((Object)("Parameter sent in " + stopwatch.elapsed(TimeUnit.MILLISECONDS) + "ms"));
        }
        catch (CodedException e) {
            // Already carries a specific code (e.g. stream failure); re-wrapping would hide it.
            throw e;
        }
        catch (Exception e) {
            throw new CodedException((InfoMessage.MessageCode)ModelEvaluationCodes.ERR_SEND_PARAM_TO_KERNEL_FAILURE, "Could not send parameter to Python kernel", (Throwable)e);
        }
    }

    /**
     * Opens an async stream on {@code link}, lets {@code pusher} write into it, and always closes it.
     *
     * @param failureMessage message of the CodedException raised on failure
     * @throws CodedException with {@code ERR_SAMPLE_STREAM_TO_KERNEL_FAILURE} if opening, writing or
     *     closing the stream fails. A close failure after a write failure is attached as suppressed.
     */
    private static void streamToDriftKernel(JavaBlockLink link, CSVSchemaAdapter.StreamPusher pusher, String failureMessage) throws CodedException {
        // NOTE(review): the meaning of 10000 (timeout? buffer size?) is not visible here — confirm in JavaBlockLink.
        try (JavaBlockLink.WrappedSocketBlockLinkOutputStream os = link.sendStreamAsync(10000)) {
            pusher.push((OutputStream)os);
        }
        catch (Exception e) {
            throw new CodedException((InfoMessage.MessageCode)ModelEvaluationCodes.ERR_SAMPLE_STREAM_TO_KERNEL_FAILURE, failureMessage, (Throwable)e);
        }
    }

    /**
     * Asks the kernel to compute drift between two previously loaded model-likes.
     *
     * <p>Must not be called while a transaction is attached to the current thread.
     *
     * @param referenceId kernel-side id of the reference model-like
     * @param currentId kernel-side id of the current model-like
     * @return the drift result returned by the kernel
     * @throws IllegalArgumentException carrying the kernel's message if the kernel reports an error
     */
    public DriftResult computeDrift(DriftParams params, String referenceId, String currentId, String jobId, boolean computePredictionDrift, Double referenceThreshold, Double currentThreshold) throws IOException {
        logger.info((Object)("Starting computation of drift between " + referenceId + " (reference) and " + currentId + " (current)"));
        if (logger.isDebugEnabled()) {
            // Was logged at INFO despite the debug guard.
            logger.debug((Object)("Parameters: " + JSON.pretty((Object)params)));
        }
        TransactionContext.assertNoAttachedTransaction();
        JavaBlockLink link = this.simplePythonKernel.getLink();
        link.sendRequest((Object)new ComputeDriftCommand(params, referenceId, currentId, jobId, computePredictionDrift, referenceThreshold, currentThreshold));
        DriftResultResponse response = (DriftResultResponse)link.receiveJsonResponse(DriftResultResponse.class);
        if (response.error != null) {
            throw new IllegalArgumentException(response.error.message);
        }
        return response.result;
    }

    /** Payload of the {@code LoadDriftParam} command: metadata describing one model-like. */
    static class LoadModelLikeCommand
    extends Command {
        private ColumnImportance columnImportance;
        private Schema predictedSchema;
        // Kernel-side id of the model-like.
        public String ref;
        public Schema dataSchema;
        public ResolvedClassicalPredictionPreprocessingParams preprocessingParams;
        public PredictionMLTask.PredictionType predictionType;
        // Null when no prediction statistics file was available.
        public PredictionModelPredictionInfos predictionStatistics;

        LoadModelLikeCommand(String mleId, Schema dataSchema, Schema predictedSchema, ResolvedClassicalPredictionPreprocessingParams preprocessingParams, PredictionMLTask.PredictionType predictionType, @Nullable ColumnImportance columnImportance, PredictionModelPredictionInfos predictionStatistics) {
            this.ref = mleId;
            this.dataSchema = dataSchema;
            this.predictedSchema = predictedSchema;
            this.preprocessingParams = preprocessingParams;
            this.predictionType = predictionType;
            this.columnImportance = columnImportance;
            this.predictionStatistics = predictionStatistics;
        }

        // NOTE(review): presumably kept for reflective JSON deserialization — confirm.
        private LoadModelLikeCommand() {
        }
    }

    /** Kernel acknowledgement that it is ready to receive the data stream(s). */
    static class WaitingForDataResponse
    extends Response {
        WaitingForDataResponse() {
        }
    }

    /** Kernel acknowledgement that all data streams were received. */
    static class DatasetReceivedResponse
    extends Response {
        DatasetReceivedResponse() {
        }
    }

    /** Payload of the {@code ComputeDrift} command. */
    static class ComputeDriftCommand
    extends Command {
        public DriftParams params;
        // Reference model-like id.
        public String ref1;
        // Current model-like id.
        public String ref2;
        public String jobId;
        public boolean computePredictionDrift;
        public Double referenceThreshold;
        public Double currentThreshold;

        ComputeDriftCommand(DriftParams params, String ref1, String ref2, String jobId, boolean computePredictionDrift, Double referenceThreshold, Double currentThreshold) {
            this.params = params;
            this.ref1 = ref1;
            this.ref2 = ref2;
            this.jobId = jobId;
            this.computePredictionDrift = computePredictionDrift;
            this.referenceThreshold = referenceThreshold;
            this.currentThreshold = currentThreshold;
        }

        // NOTE(review): presumably kept for reflective JSON deserialization — confirm.
        private ComputeDriftCommand() {
        }
    }

    /** Kernel reply to {@code ComputeDrift}: either a result or an error. */
    static class DriftResultResponse
    extends Response {
        DriftResult result;
        // Non-null when the kernel reported a failure.
        DataDriftError error;

        DriftResultResponse() {
        }

        static class DataDriftError {
            public String message;

            DataDriftError() {
            }
        }
    }

    // Polymorphic JSON type tags for kernel -> Java responses.
    @PolyJSON(value={@Mapping(value=DatasetReceivedResponse.class, type="DatasetReceived"), @Mapping(value=WaitingForDataResponse.class, type="WaitingForData"), @Mapping(value=DriftResultResponse.class, type="DriftResult")})
    static abstract class Response {
        Response() {
        }
    }

    // Polymorphic JSON type tags for Java -> kernel commands.
    @PolyJSON(value={@Mapping(value=LoadModelLikeCommand.class, type="LoadDriftParam"), @Mapping(value=ComputeDriftCommand.class, type="ComputeDrift")})
    static abstract class Command {
        Command() {
        }
    }
}

