/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.scoring.exports;

import com.dataiku.dip.DKUApp;
import com.dataiku.dip.DSSTempUtils;
import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.prediction.PredictionResultsReader;
import com.dataiku.dip.analysis.model.CompatibilityWithReason;
import com.dataiku.dip.analysis.model.prediction.ClassicalPredictionModelDetails;
import com.dataiku.dip.analysis.model.prediction.PartitionedModelExtract;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.prediction.ResolvedClassicalPredictionCoreParams;
import com.dataiku.dip.coremodel.Schema;
import com.dataiku.dip.coremodel.SchemaColumn;
import com.dataiku.dip.datalayer.Column;
import com.dataiku.dip.datalayer.Row;
import com.dataiku.dip.datalayer.streamimpl.StreamRow;
import com.dataiku.dip.datalayer.streamimpl.StreamRowFactory;
import com.dataiku.dip.datasets.Type;
import com.dataiku.dip.export.ZipUnzipDir;
import com.dataiku.dip.input.formats.ExtractionLimit;
import com.dataiku.dip.input.formats.csv.CSVFormatConfig;
import com.dataiku.dip.mec.engine.CSVSchemaAdapter;
import com.dataiku.dip.output.CSVOutputFormatter;
import com.dataiku.dip.output.OutputFormatter;
import com.dataiku.dip.output.PythonOutputFormatter;
import com.dataiku.dip.scoring.exports.ScoringExporter;
import com.dataiku.dip.util.AutoDelete;
import com.dataiku.dip.utils.DKUFileUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dip.utils.PathUtils;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang.StringUtils;

/**
 * {@link ScoringExporter.ScoringWriter} that exports a model as a zip archive for Python scoring.
 *
 * <p>The export directory built by {@link #writeToDir(File)} contains three files:
 * <ul>
 *   <li>{@code model.zip}: the zipped model resources copied from the model;</li>
 *   <li>{@code sample.py}: a sample script whose {@code SAMPLE_DATA} placeholder is replaced by
 *       a few rows of sample data (or by a generic default sample when none can be produced);</li>
 *   <li>{@code requirements.txt}: the template requirements file with its
 *       {@code DATAIKUSCORING_VERSION} placeholder replaced by the DSS version.</li>
 * </ul>
 */
public class PythonScoring
implements ScoringExporter.ScoringWriter {
    protected static final DKULogger logger = DKULogger.getLogger("dku.scoring.exports.python");
    public static final String CONTENT_TYPE = "application/zip";
    /** Generic sample used whenever real sample rows cannot be produced for the model. */
    private static final String DEFAULT_SAMPLE = "[\n    {\n        \"field1\": \"value1-1\",\n        \"field2\": \"value1-2\" \n    },{\n        \"field1\": \"value2-1\",\n        \"field2\": \"value2-2\"\n    }\n]";
    /** Maximum number of rows extracted from the model's CSV data to build the sample. */
    private static final int ROWS_IN_SAMPLE = 2;
    public static final String MODEL_RESOURCES_ZIP = "model.zip";
    protected FullModelId fullModelId;

    /**
     * @param fullModelId identifier of the model to export
     */
    public PythonScoring(FullModelId fullModelId) {
        this.fullModelId = fullModelId;
    }

    /**
     * Returns the {@code pythonCompatibility} value held by the model details.
     *
     * @throws IOException if the model details cannot be read
     */
    @Override
    public CompatibilityWithReason getCompatibility() throws IOException {
        return PredictionResultsReader.makeDetails(this.fullModelId).pythonCompatibility;
    }

    /**
     * Returns the Python version stored in the model's train info.
     *
     * @return the recorded Python version, or {@code null} when the model details carry no train info
     * @throws IOException if the model details cannot be read
     */
    @Nullable
    public String getPythonVersion() throws IOException {
        ClassicalPredictionModelDetails details = PredictionResultsReader.makeDetails(this.fullModelId);
        return details.trainInfo == null ? null : details.trainInfo.pythonVersion;
    }

    /**
     * Builds the archive file name: the slugified name of the head ML task followed by
     * {@code -weights.zip}, without doubling the dash when the slug already ends with one.
     */
    @Override
    public String makeFileName() throws IOException {
        String slug = PathUtils.slugify(this.fullModelId.getHeadMLTask().name);
        String separator = StringUtils.endsWith(slug, "-") ? "" : "-";
        return slug + separator + "weights.zip";
    }

    @Override
    public String getContentType() {
        return CONTENT_TYPE;
    }

    /**
     * Builds the export in a temporary folder and streams it, zipped, to {@code os}.
     * The temporary folder is closed by the try-with-resources block once the stream is written.
     */
    @Override
    public void writeTo(OutputStream os) throws Exception {
        try (AutoDelete tmpExportDir = DSSTempUtils.getTempFolder("model-export", "model")) {
            this.writeToDir(tmpExportDir);
            ZipUnzipDir.zipDirectoryToStream(tmpExportDir, os);
        }
    }

    /** Returns the product version of the running DSS instance. */
    protected static String getDssVersion() {
        return DKUApp.getDSSVersion().product_version;
    }

    /** Returns the path, under the install folder, of the Python scoring sample resources. */
    protected static String getSamplesPath() {
        return DKUApp.getInstallFolder() + "/resources/python-samples/dataiku-scoring/";
    }

    /**
     * Returns the template {@code requirements.txt} located in the samples folder.
     *
     * @param modelRequirementsContent requirements shipped with the model (possibly empty); not
     *        used by this implementation, but kept in the signature because the method is
     *        protected and may be overridden
     */
    protected File getRequirementsFileFromSamples(@Nonnull String modelRequirementsContent) {
        return new File(PythonScoring.getSamplesPath(), "requirements.txt");
    }

    /**
     * Writes the three export files ({@code model.zip}, {@code sample.py},
     * {@code requirements.txt}) into {@code tmpExportDir}.
     *
     * @param tmpExportDir directory receiving the export files
     * @throws IOException if a resource cannot be read or an export file cannot be written
     */
    public void writeToDir(File tmpExportDir) throws IOException {
        try (AutoDelete tmpResources = DSSTempUtils.getTempFolder("model-export", "resources")) {
            String dssVersion = PythonScoring.getDssVersion();
            String samplePath = PythonScoring.getSamplesPath();
            String baseExportDir = tmpExportDir.getAbsolutePath();

            // model.zip: model resources are first copied to a scratch folder, then zipped.
            ScoringExporter.copyResourcesFromFMI(tmpResources, this.fullModelId);
            ZipUnzipDir.zipDirectory(tmpResources, new File(baseExportDir, MODEL_RESOURCES_ZIP));

            // sample.py: literal replacement, so the sample data is never interpreted as a
            // regex replacement pattern.
            String sampleTemplate = this.getSampleFileContent(samplePath, this.fullModelId);
            String sampleData = this.generateSampleData(this.fullModelId);
            String sampleFileContent = sampleTemplate.replace("SAMPLE_DATA", sampleData);
            DKUFileUtils.writeFileUTF8(new File(baseExportDir, "sample.py"), sampleFileContent);

            // requirements.txt: for external MLflow model versions, the model's own requirements
            // (when readable) are appended after the template.
            boolean externalMlflow = this.fullModelId.isExternalMLflowModelVersion();
            String modelRequirementsContent = "";
            if (externalMlflow) {
                File modelRequirements = new File(tmpResources, "requirements.txt");
                if (modelRequirements.canRead()) {
                    modelRequirementsContent = DKUFileUtils.readFileToStringUTF8(modelRequirements);
                }
            }
            File templateRequirements = this.getRequirementsFileFromSamples(modelRequirementsContent);
            String requirementsContent = DKUFileUtils.readFileToStringUTF8(templateRequirements);
            if (externalMlflow) {
                requirementsContent = requirementsContent + "\n" + modelRequirementsContent;
            }
            // Literal replacement: a '$' or '\' in the version string must not be treated as a
            // regex group reference or escape.
            requirementsContent = requirementsContent.replace("DATAIKUSCORING_VERSION", dssVersion);
            DKUFileUtils.writeFileUTF8(new File(baseExportDir, "requirements.txt"), requirementsContent);
        }
    }

    /**
     * Reads the sample script template matching the model's prediction type:
     * {@code sample_classification.py} for multiclass and binary classification,
     * {@code sample_other.py} otherwise.
     *
     * @param samplePath root folder of the sample resources
     * @param fullModelId model whose prediction type selects the template
     * @throws IOException if the template cannot be read
     */
    protected String getSampleFileContent(String samplePath, FullModelId fullModelId) throws IOException {
        File samplesDir = new File(samplePath, "samples-python");
        boolean classification = PredictionMLTask.PredictionType.MULTICLASS.equals(fullModelId.getPredictionType())
                || PredictionMLTask.PredictionType.BINARY_CLASSIFICATION.equals(fullModelId.getPredictionType());
        String sampleFileName = classification ? "sample_classification.py" : "sample_other.py";
        return DKUFileUtils.readFileToStringUTF8(new File(samplesDir, sampleFileName));
    }

    /**
     * Produces the text substituted for {@code SAMPLE_DATA} in the sample script: up to
     * {@link #ROWS_IN_SAMPLE} rows of the model's CSV data, formatted by a
     * {@link PythonOutputFormatter}.
     *
     * <p>For a partitioned base model, the data of the first usable partition is used, and the
     * partition dimensions missing from that partition's schema are added as leading string
     * columns holding the partition's dimension values.
     *
     * <p>This method never throws: when no data is available, or on any error, it logs and
     * returns {@link #DEFAULT_SAMPLE}.
     */
    private String generateSampleData(FullModelId fullModelId) {
        try {
            Schema inputSchema;
            Schema outputSchema;
            InputStream sampleDataInputStream;
            if (fullModelId.isPartitionedBaseModel()) {
                FullModelId firstUsableFmi = this.getFirstUsableModelPartition(fullModelId);
                if (firstUsableFmi == null) {
                    logger.warn("No usable model partition found, returning default code sample.");
                    return DEFAULT_SAMPLE;
                }
                if (!firstUsableFmi.hasDataToStreamCSV()) {
                    logger.warn("No sample data found, returning default sample.");
                    return DEFAULT_SAMPLE;
                }
                LinkedHashMap<String, String> dimensionNameToFirstUsableValue = this.buildDimensionsSamples(firstUsableFmi);
                Schema modelPartitionSchema = this.getInputSchema(firstUsableFmi);
                // Dimensions that are not already columns of the partition's data.
                ArrayList<SchemaColumn> additionalColumns = new ArrayList<SchemaColumn>();
                for (String dimensionName : dimensionNameToFirstUsableValue.keySet()) {
                    if (!modelPartitionSchema.hasColumn(dimensionName)) {
                        additionalColumns.add(new SchemaColumn(dimensionName, Type.STRING));
                    }
                }
                if (!additionalColumns.isEmpty()) {
                    Schema inputSchemaWithDimensions = new Schema(new ArrayList<>(additionalColumns), false);
                    inputSchemaWithDimensions.columns.addAll(modelPartitionSchema.columns);
                    Schema outputSchemaWithDimensions = new Schema(new ArrayList<>(additionalColumns), false);
                    outputSchemaWithDimensions.columns.addAll(this.getOutputSchema(firstUsableFmi).columns);
                    // Re-encode the partition data as CSV with the dimension columns filled in;
                    // the partition stream is closed as soon as the rewrite is done.
                    String csvWithDimensions;
                    try (InputStream firstFmiInputStream = firstUsableFmi.openDataInputStreamCSV()) {
                        csvWithDimensions = this.transformCSVWithForcedValuesForMissingColumns(firstFmiInputStream, modelPartitionSchema, inputSchemaWithDimensions, dimensionNameToFirstUsableValue);
                    }
                    sampleDataInputStream = new ByteArrayInputStream(csvWithDimensions.getBytes("utf8"));
                    inputSchema = inputSchemaWithDimensions;
                    outputSchema = outputSchemaWithDimensions;
                } else {
                    inputSchema = modelPartitionSchema;
                    outputSchema = this.getOutputSchema(firstUsableFmi);
                    // Opened last so nothing can fail between opening and the closing block below.
                    sampleDataInputStream = firstUsableFmi.openDataInputStreamCSV();
                }
            } else {
                if (!fullModelId.hasDataToStreamCSV()) {
                    logger.warn("No sample data found, returning default sample.");
                    return DEFAULT_SAMPLE;
                }
                inputSchema = this.getInputSchema(fullModelId);
                outputSchema = this.getOutputSchema(fullModelId);
                // Opened last so nothing can fail between opening and the closing block below.
                sampleDataInputStream = fullModelId.openDataInputStreamCSV();
            }
            try (InputStream sampleData = sampleDataInputStream;
                    ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
                PythonOutputFormatter outputFormatter = new PythonOutputFormatter(true, 4);
                outputFormatter.setOutputSchema(outputSchema);
                CSVSchemaAdapter.transformCSV(sampleData, inputSchema, false, baos, outputFormatter, new StreamRowFactory(), new ExtractionLimit((long) ROWS_IN_SAMPLE));
                return baos.toString("utf8");
            }
        }
        catch (Exception e) {
            // Best effort: a failure to build the sample must not fail the whole export.
            logger.error("Error while generating sample data, returning default sample.", e);
            return DEFAULT_SAMPLE;
        }
    }

    /**
     * Returns the input schema of the model: the content of the session file
     * {@code input_dataset_schema.json} when it exists, otherwise the model's non-scored schema.
     */
    private Schema getInputSchema(FullModelId fullModelId) throws IOException {
        File datasetSchemaFile = fullModelId.getSessionFile("input_dataset_schema.json");
        if (datasetSchemaFile.exists()) {
            return (Schema)JSON.parseFile(datasetSchemaFile, Schema.class);
        }
        return fullModelId.getNonScoredSchema();
    }

    /**
     * Returns the input schema with the target column removed.
     *
     * <p>NOTE(review): the column is removed in place from the object returned by
     * {@link #getInputSchema(FullModelId)}; confirm that {@code getNonScoredSchema()} does not
     * hand out a shared instance.
     */
    private Schema getOutputSchema(FullModelId fullModelId) throws IOException {
        Schema schema = this.getInputSchema(fullModelId);
        schema.removeColumn(fullModelId.getResolvedPredictionPreprocessingParams().getTarget());
        return schema;
    }

    /**
     * Rewrites up to {@link #ROWS_IN_SAMPLE} rows of a CSV stream to a new CSV string that
     * follows {@code outputSchema}, taking the value of every column named in
     * {@code forcedValuesForMissingColumns} from that map instead of from the input rows.
     *
     * @param csvInputStream CSV data described by {@code inputSchema}; not closed by this method
     * @param inputSchema schema of the incoming data
     * @param outputSchema schema of the produced CSV
     * @param forcedValuesForMissingColumns column name to forced value
     * @return the rewritten CSV, written without a header row
     */
    private String transformCSVWithForcedValuesForMissingColumns(InputStream csvInputStream, Schema inputSchema, Schema outputSchema, LinkedHashMap<String, String> forcedValuesForMissingColumns) throws Exception {
        CSVFormatConfig outCsvf = CSVFormatConfig.getStandardTabExcelFormat();
        outCsvf.parseHeaderRow = false;
        CSVOutputFormatter csvOutputFormatter = new CSVOutputFormatter(outCsvf);
        csvOutputFormatter.setOutputSchema(outputSchema);
        StreamRowFactoryWithForcedValuesForSomeColumns rowFactory = new StreamRowFactoryWithForcedValuesForSomeColumns(forcedValuesForMissingColumns);
        try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
            CSVSchemaAdapter.transformCSV(csvInputStream, inputSchema, false, baos, csvOutputFormatter, rowFactory, new ExtractionLimit((long) ROWS_IN_SAMPLE));
            return baos.toString("utf8");
        }
    }

    /**
     * Returns the model partition of the first summary whose state is usable.
     *
     * @return the partition's model id, or {@code null} when the model has no summaries, none is
     *         usable, or the partitions cannot be retrieved (the error is logged)
     */
    private FullModelId getFirstUsableModelPartition(FullModelId fullModelId) {
        try {
            PartitionedModelExtract extract = fullModelId.getPartitionedModelExtract();
            Collection<PartitionedModelExtract.PartitionedModelSummary> summaries = extract.summaries.values();
            if (CollectionUtils.isEmpty(summaries)) {
                return null;
            }
            // orElse(null), not orElseGet(null): the latter throws a NullPointerException when
            // no summary is usable, which would be reported as a retrieval error below.
            PartitionedModelExtract.PartitionedModelSummary firstUsableSummary = summaries.stream()
                    .filter(summary -> summary.state.isUsable())
                    .findFirst()
                    .orElse(null);
            if (firstUsableSummary == null) {
                return null;
            }
            return fullModelId.getModelPartition(firstUsableSummary.snippet.partitionName);
        }
        catch (Exception e) {
            logger.error("Error while retrieving partitions of model: " + fullModelId.toString(), e);
            return null;
        }
    }

    /**
     * Maps each partitioning dimension name to its value in the given partition.
     *
     * <p>The partition name is split on {@code |}; the i-th piece is the value of the i-th
     * dimension name declared in the resolved core params of the model being exported.
     *
     * @param modelPartition the partition whose name carries the dimension values
     * @return dimension name to value, in dimension declaration order
     * @throws IllegalStateException if the partition name holds fewer values than there are
     *         dimensions
     */
    private LinkedHashMap<String, String> buildDimensionsSamples(FullModelId modelPartition) throws IOException {
        LinkedHashMap<String, String> result = new LinkedHashMap<String, String>();
        String partitionName = modelPartition.getPartitionName();
        String[] parts = partitionName.split(Pattern.quote("|"));
        ResolvedClassicalPredictionCoreParams coreParams = (ResolvedClassicalPredictionCoreParams)this.fullModelId.getResolvedCoreParams();
        int dimensionCount = coreParams.partitionedModel.dimensionNames.size();
        if (parts.length < dimensionCount) {
            throw new IllegalStateException("Partition '" + partitionName + "' has " + parts.length + " dimension value(s) but the model declares " + dimensionCount + " dimension(s)");
        }
        for (int i = 0; i < dimensionCount; ++i) {
            result.put(coreParams.partitionedModel.dimensionNames.get(i), parts[i]);
        }
        return result;
    }

    /**
     * Row factory whose rows answer {@code get(Column)} with a fixed value for every column
     * present in the forced-values map, and defer to the regular row content otherwise.
     * Static because it does not use any state of the enclosing {@link PythonScoring}.
     */
    private static final class StreamRowFactoryWithForcedValuesForSomeColumns
    extends StreamRowFactory {
        private final Map<String, String> forcedValues;

        public StreamRowFactoryWithForcedValuesForSomeColumns(Map<String, String> forcedValues) {
            this.forcedValues = forcedValues;
        }

        public Row row() {
            return new StreamRow(){

                public String get(Column cd) {
                    String forcedValue = StreamRowFactoryWithForcedValuesForSomeColumns.this.forcedValues.get(cd.getName());
                    // A column without a forced value falls back to the row's own content.
                    return forcedValue != null ? forcedValue : super.get(cd);
                }
            };
        }
    }
}

