/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.server.openapi;

import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.prediction.PredictionResultsReader;
import com.dataiku.dip.analysis.model.MLTask;
import com.dataiku.dip.analysis.model.prediction.PredictionMLTask;
import com.dataiku.dip.analysis.model.preprocessing.FeaturePreprocessingParams;
import com.dataiku.dip.coremodel.Schema;
import com.dataiku.dip.coremodel.SchemaColumn;
import com.dataiku.dip.dao.SavedModel;
import com.dataiku.dip.datasets.Type;
import com.dataiku.dip.externalml.mlflow.MLFlowModelVersionInfo;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelib.com.fasterxml.jackson.databind.ObjectMapper;
import com.dataiku.dss.shadelib.io.swagger.models.Info;
import com.dataiku.dss.shadelib.io.swagger.models.Model;
import com.dataiku.dss.shadelib.io.swagger.models.ModelImpl;
import com.dataiku.dss.shadelib.io.swagger.models.Operation;
import com.dataiku.dss.shadelib.io.swagger.models.Path;
import com.dataiku.dss.shadelib.io.swagger.models.RefModel;
import com.dataiku.dss.shadelib.io.swagger.models.Response;
import com.dataiku.dss.shadelib.io.swagger.models.Swagger;
import com.dataiku.dss.shadelib.io.swagger.models.parameters.BodyParameter;
import com.dataiku.dss.shadelib.io.swagger.models.parameters.Parameter;
import com.dataiku.dss.shadelib.io.swagger.models.properties.BooleanProperty;
import com.dataiku.dss.shadelib.io.swagger.models.properties.DoubleProperty;
import com.dataiku.dss.shadelib.io.swagger.models.properties.FloatProperty;
import com.dataiku.dss.shadelib.io.swagger.models.properties.IntegerProperty;
import com.dataiku.dss.shadelib.io.swagger.models.properties.ObjectProperty;
import com.dataiku.dss.shadelib.io.swagger.models.properties.Property;
import com.dataiku.dss.shadelib.io.swagger.models.properties.StringProperty;
import com.dataiku.dss.shadelib.io.swagger.parser.SwaggerParser;
import com.dataiku.lambda.model.serverconfig.LambdaEndpointConfig;
import com.dataiku.lambda.model.studioconfig.DSSLambdaEndpointConfig;
import com.dataiku.lambda.model.studioconfig.OpenAPIDoc;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
import org.apache.commons.lang.StringUtils;

/**
 * Builds and reads Swagger 2.0 (OpenAPI) documentation for API service endpoints.
 *
 * <p>For a single endpoint it registers input / output / error model definitions and a POST path.
 * For a whole service it merges the enabled per-endpoint documents into one {@link Swagger}.
 */
public class OpenAPIHelper {
    public static final String NO_OPEN_API_DOCUMENTATION = "No OpenAPI documentation";
    public static final String ENDPOINT_INPUT_REF = "EndpointInput";
    private static final String ENDPOINT_OUTPUT_REF = "EndpointOutput";
    private static final String ENDPOINT_ERROR_REF = "ErrorResponse";

    // Keys of the reference map returned by enrichSwaggerWithDefinitionsAndGetReferenceMap.
    private static final String REF_KEY_INPUT = "input";
    private static final String REF_KEY_OUTPUT = "output";
    private static final String REF_KEY_ERROR = "error";

    private static final String DEFAULT_ENDPOINT_DESCRIPTION = "Describe what the endpoint call does, what input it takes, what it returns ...";
    private static final String ENDPOINT_INPUT_DESC = "The input payload for executing the real-time machine learning service.";
    private static final String ENDPOINT_OUTPUT_DESC = "The service processed the input correctly and provided a result prediction, if applicable.";
    private static final String ENDPOINT_ERROR_DESC = "The service failed to execute due to an error.";

    // Expensive to create; cached once and only used for readTree.
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
    private static final DKULogger logger = DKULogger.getLogger("dku.features.openAPIHelper");

    /**
     * Sets the document info (title, description, version) and declares JSON as the consumed and
     * produced media type.
     */
    public static void enrichSwaggerWithInfo(Swagger swagger, String serviceId, String version, String shortDesc) {
        swagger.info(new Info().title(serviceId).description(shortDesc).version(version));
        swagger.consumes("application/json");
        swagger.produces("application/json");
    }

    /**
     * Merges the enabled endpoint documents into {@code swagger}.
     *
     * <p>Every definition of every parsable endpoint document is added to {@code swagger}. The union of
     * their paths replaces the paths of {@code swagger}. Disabled, empty or unparsable documents are
     * skipped.
     *
     * @param swagger             the service-level document to enrich
     * @param serviceEndpointOADs the per-endpoint documentation entries
     */
    public static void enrichServiceSwaggerWithEndpointsOpenAPIDoc(Swagger swagger, List<OpenAPIDoc> serviceEndpointOADs) {
        Map<String, Path> endpointPaths = new LinkedHashMap<>();
        for (OpenAPIDoc epOAD : serviceEndpointOADs) {
            if (epOAD == null || !epOAD.enabled || epOAD.content == null) {
                continue;
            }
            // Insert a "swagger": "2.0" field right after the first '{' so the parser reads the content
            // as a Swagger 2.0 document (assumes the stored content is a JSON object lacking that field).
            String oad = epOAD.content.replaceFirst("\\{", "{\"swagger\": \"2.0\",");
            Swagger epSwagger = new SwaggerParser().parse(oad);
            if (epSwagger == null) {
                continue;
            }
            // Either section may legitimately be absent from an endpoint document.
            if (epSwagger.getDefinitions() != null) {
                epSwagger.getDefinitions().forEach(swagger::addDefinition);
            }
            if (epSwagger.getPaths() != null) {
                endpointPaths.putAll(epSwagger.getPaths());
            }
        }
        swagger.setPaths(endpointPaths);
    }

    /**
     * Registers the input, output and error definitions of an endpoint on {@code swagger}.
     *
     * <p>Generic example definitions are used when {@code isManual} is true, when there is no saved model
     * or no active version, or when the saved model is not a prediction model with a prediction type.
     * Otherwise the definitions are derived from the saved model's active version.
     *
     * @return a map from "input" / "output" / "error" to the definition name registered for that role
     *     ({@code <endpointId>_EndpointInput}, {@code <endpointId>_EndpointOutput},
     *     {@code <endpointId>_ErrorResponse})
     * @throws IOException propagated from reading the active model version's details or schema files
     */
    public static Map<String, String> enrichSwaggerWithDefinitionsAndGetReferenceMap(Swagger swagger, DSSLambdaEndpointConfig endpointConfig, SavedModel sm, boolean isManual) throws IOException {
        HashMap<String, String> refMap = createRefMap(endpointConfig.id);
        if (isManual || sm == null || sm.activeVersion == null || !savedModelSupportAutomatedMode(sm)) {
            enrichSwaggerWithDefaultDef(swagger, refMap, endpointConfig.type);
        } else {
            enrichSwaggerWithAutomatedDef(swagger, refMap, sm);
        }
        return refMap;
    }

    /**
     * Adds the POST operation of the endpoint under its call path (see {@link #getEndpointCallName}).
     *
     * @param refMap the map returned by {@link #enrichSwaggerWithDefinitionsAndGetReferenceMap}
     */
    public static void enrichSwaggerWithPath(Swagger swagger, DSSLambdaEndpointConfig endpointConfig, Map<String, String> refMap) {
        Operation operation = createOperation(endpointConfig, refMap);
        swagger.path(getEndpointCallName(endpointConfig.id, endpointConfig.type), new Path().post(operation));
    }

    /** Returns whether definitions can be generated from a model for this endpoint type (standard prediction only). */
    public static boolean endpointTypeSupportAutomatedMode(LambdaEndpointConfig.EndpointType type) {
        return type == LambdaEndpointConfig.EndpointType.STD_PREDICTION;
    }

    // True for prediction saved models whose mini task is a PredictionMLTask with a prediction type.
    private static boolean savedModelSupportAutomatedMode(@Nonnull SavedModel sm) {
        return sm.getType() == MLTask.MLTaskType.PREDICTION
                && sm.miniTask instanceof PredictionMLTask predictionTask
                && predictionTask.predictionType != null;
    }

    // Registers generic example definitions for the three roles.
    private static void enrichSwaggerWithDefaultDef(Swagger swagger, Map<String, String> refMap, LambdaEndpointConfig.EndpointType endpointType) {
        swagger.addDefinition(refMap.get(REF_KEY_INPUT), createDefaultInputModel(endpointType));
        swagger.addDefinition(refMap.get(REF_KEY_OUTPUT), createDefaultOutputModel());
        swagger.addDefinition(refMap.get(REF_KEY_ERROR), createErrorModel());
    }

    // Registers definitions derived from the saved model's active version.
    private static void enrichSwaggerWithAutomatedDef(Swagger swagger, Map<String, String> refMap, SavedModel sm) throws IOException {
        FullModelId fmi = new FullModelId(sm.projectKey, sm.id, sm.activeVersion);
        swagger.addDefinition(refMap.get(REF_KEY_INPUT), createInputModel(fmi, sm.savedModelType));
        swagger.addDefinition(refMap.get(REF_KEY_OUTPUT), createOutputModel(fmi));
        swagger.addDefinition(refMap.get(REF_KEY_ERROR), createErrorModel());
    }

    // Builds the POST operation: a required body parameter, a 200 response and a default error response.
    private static Operation createOperation(DSSLambdaEndpointConfig endpointConfig, Map<String, String> refMap) {
        // The user-provided description is only used in automated (non-manual) mode and when not blank.
        boolean useConfiguredDescription = endpointConfig.openAPI != null
                && !endpointConfig.openAPI.isManual
                && StringUtils.isNotBlank(endpointConfig.openAPI.description);
        String operationDescription = useConfiguredDescription ? endpointConfig.openAPI.description : DEFAULT_ENDPOINT_DESCRIPTION;
        Operation operation = new Operation().summary(endpointConfig.id + " endpoint").description(operationDescription);

        BodyParameter inputParameter = new BodyParameter()
                .name("serviceInputPayload")
                .description(ENDPOINT_INPUT_DESC)
                .schema(new RefModel(refMap.get(REF_KEY_INPUT)));
        inputParameter.setRequired(true);
        operation.addParameter(inputParameter);

        Response validResponse = new Response().description(ENDPOINT_OUTPUT_DESC).responseSchema(new RefModel(refMap.get(REF_KEY_OUTPUT)));
        Response errorResponse = new Response().description(ENDPOINT_ERROR_DESC).responseSchema(new RefModel(refMap.get(REF_KEY_ERROR)));
        operation.addResponse("200", validResponse);
        operation.defaultResponse(errorResponse);
        return operation;
    }

    // Maps each role key to "<endpointId>_<role ref>".
    private static HashMap<String, String> createRefMap(String endpointId) {
        HashMap<String, String> res = new HashMap<>();
        res.put(REF_KEY_INPUT, endpointId + "_" + ENDPOINT_INPUT_REF);
        res.put(REF_KEY_OUTPUT, endpointId + "_" + ENDPOINT_OUTPUT_REF);
        res.put(REF_KEY_ERROR, endpointId + "_" + ENDPOINT_ERROR_REF);
        return res;
    }

    // Error payload: {status_code: integer, message: string}.
    private static ModelImpl createErrorModel() {
        ModelImpl errorModel = new ModelImpl().type("object");
        errorModel.addProperty("status_code", new IntegerProperty());
        errorModel.addProperty("message", new StringProperty());
        return errorModel;
    }

    // Input payload {features: {...}} listing the model's INPUT-role features that appear in its stored input schema.
    private static ModelImpl createInputModel(FullModelId fmi, SavedModel.SavedModelType smType) throws IOException {
        Set<String> inputFeatures = getInputFeatureNames(fmi);
        List<SchemaColumn> inputColumns = readInputSchemaColumns(fmi, smType).stream()
                .filter(col -> col != null && inputFeatures.contains(col.getName()))
                .collect(Collectors.toList());
        ObjectProperty features = new ObjectProperty().name("features");
        features.setProperties(createPropertyMapFromInputSchema(inputColumns));
        return new ModelImpl().type("object").property("features", features);
    }

    // Names of the features whose preprocessing role is INPUT.
    private static Set<String> getInputFeatureNames(FullModelId fmi) throws IOException {
        return PredictionResultsReader.makeModelDetails(fmi).getPreprocessing().per_feature.entrySet().stream()
                .filter(e -> e.getValue() != null && FeaturePreprocessingParams.Role.INPUT.equals(e.getValue().role))
                .map(Map.Entry::getKey)
                .collect(Collectors.toSet());
    }

    // Reads the input schema stored with the model version; empty when the type is unsupported or the file is absent.
    private static List<SchemaColumn> readInputSchemaColumns(FullModelId fmi, SavedModel.SavedModelType smType) throws IOException {
        if (smType == null) {
            return List.of();
        }
        List<SchemaColumn> columns = null;
        switch (smType) {
            case MLFLOW_PYFUNC -> {
                File schemaFile = fmi.getSessionFile("mlflow_imported_model.json");
                if (schemaFile.exists()) {
                    columns = ((MLFlowModelVersionInfo) JSON.parseFile(schemaFile, MLFlowModelVersionInfo.class)).features;
                }
            }
            case DSS_MANAGED -> {
                File schemaFile = fmi.getSessionFile("input_dataset_schema.json");
                if (schemaFile.exists()) {
                    columns = ((Schema) JSON.parseFile(schemaFile, Schema.class)).getColumns();
                }
            }
            default -> {
                // Other saved model types: no schema is read, so no feature is documented.
            }
        }
        return columns != null ? columns : List.of();
    }

    // Output payload {result: {...}}; "prediction" is documented for binary, multiclass and regression models.
    private static ModelImpl createOutputModel(FullModelId fmi) throws IOException {
        ModelImpl outputModel = new ModelImpl().type("object");
        ObjectProperty result = new ObjectProperty();
        switch (fmi.getPredictionType()) {
            case BINARY_CLASSIFICATION, REGRESSION, MULTICLASS -> result.property("prediction", new StringProperty());
            default -> {
                // Other prediction types: result is documented without fields.
            }
        }
        outputModel.addProperty("result", result);
        return outputModel;
    }

    // Maps each column name to a Swagger property, keeping schema order; unsupported or null types are skipped.
    private static Map<String, Property> createPropertyMapFromInputSchema(List<SchemaColumn> schemaColList) {
        Map<String, Property> res = new LinkedHashMap<>();
        for (SchemaColumn col : schemaColList) {
            if (col == null) {
                continue;
            }
            Property prop = propertyForColumnType(col.getType());
            // A null value would produce an invalid schema entry, so such columns are left out.
            if (prop != null) {
                res.put(col.getName(), prop);
            }
        }
        return res;
    }

    // Swagger property for a dataset column type, or null when the type has no mapping.
    private static Property propertyForColumnType(Type type) {
        if (type == null) {
            return null;
        }
        return switch (type) {
            case STRING, DATE, GEOPOINT, GEOMETRY, MAP, OBJECT, ARRAY -> new StringProperty();
            case DOUBLE -> new DoubleProperty();
            case FLOAT -> new FloatProperty();
            case BIGINT, INT, SMALLINT, TINYINT -> new IntegerProperty();
            case BOOLEAN -> new BooleanProperty();
            default -> null;
        };
    }

    /**
     * Returns the call path of an endpoint: {@code /<endpointId>/<verb>}.
     *
     * <p>The verb is "predict", "forecast", "lookup", "run" or "query" depending on the endpoint type.
     */
    public static String getEndpointCallName(String endpointId, LambdaEndpointConfig.EndpointType type) {
        String prefix = "/" + endpointId + "/";
        return switch (type) {
            case STD_PREDICTION, STD_CAUSAL_PREDICTION, CUSTOM_PREDICTION, CUSTOM_R_PREDICTION, STD_CLUSTERING -> prefix + "predict";
            case STD_FORECAST -> prefix + "forecast";
            case DATASETS_LOOKUP -> prefix + "lookup";
            case PY_FUNCTION, R_FUNCTION -> prefix + "run";
            case SQL_QUERY -> prefix + "query";
        };
    }

    // Example input model: examples at top level for function endpoints, under "data" for lookups, under "features" otherwise.
    private static ModelImpl createDefaultInputModel(LambdaEndpointConfig.EndpointType endpointType) {
        Map<String, Property> examples = new HashMap<>();
        examples.put("stringPropExample", new StringProperty());
        examples.put("doublePropExample", new DoubleProperty());
        examples.put("floatPropExample", new FloatProperty());
        examples.put("intPropExample", new IntegerProperty());
        examples.put("boolPropExample", new BooleanProperty());

        ModelImpl model = new ModelImpl().type("object");
        if (endpointType == LambdaEndpointConfig.EndpointType.PY_FUNCTION
                || endpointType == LambdaEndpointConfig.EndpointType.R_FUNCTION) {
            model.setProperties(examples);
        } else {
            ObjectProperty wrapper = new ObjectProperty();
            wrapper.setType("object");
            wrapper.setProperties(examples);
            String wrapperName = endpointType == LambdaEndpointConfig.EndpointType.DATASETS_LOOKUP ? "data" : "features";
            model.property(wrapperName, wrapper);
        }
        return model;
    }

    // Example output model: {result: {OutputPropExample: string}}.
    private static ModelImpl createDefaultOutputModel() {
        ModelImpl outputModel = new ModelImpl().type("object");
        outputModel.addProperty("result", new ObjectProperty().property("OutputPropExample", new StringProperty()));
        return outputModel;
    }

    /**
     * Parses a JSON OpenAPI document.
     *
     * @throws IllegalArgumentException if the content is null or cannot be parsed (the cause is preserved)
     */
    public static Swagger getSwagger(String openAPIContent) {
        if (openAPIContent == null) {
            throw new IllegalArgumentException("The openAPI documentation content is null.");
        }
        try {
            return new SwaggerParser().read(OBJECT_MAPPER.readTree(openAPIContent));
        } catch (Exception e) {
            throw new IllegalArgumentException("Could not parse the openAPI documentation description", e);
        }
    }

    /**
     * Returns the first operation declared on the endpoint's call path.
     *
     * @return the operation, or null (with a warning logged) when the path is missing, has no operation,
     *     or cannot be read
     */
    public static Operation getEndpointOperation(Swagger swagger, String endpointId, LambdaEndpointConfig.EndpointType endpointType) {
        try {
            Path path = swagger.getPath(getEndpointCallName(endpointId, endpointType));
            if (path == null || path.getOperations().isEmpty()) {
                logger.warn("No operation found in the openAPI documentation for endpoint " + endpointId);
                return null;
            }
            return path.getOperations().get(0);
        } catch (Exception e) {
            logger.warn("Could not parse the operation of your openAPI documentation", e);
            return null;
        }
    }
}

