/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.externalml.mlflow;

import com.dataiku.dip.coremodel.InfoMessage;
import com.dataiku.dip.coremodel.SchemaColumn;
import com.dataiku.dip.dataflow.exec.CodeBasedRecipeDatasetInfoHelper;
import com.dataiku.dip.exceptions.CodedException;
import com.dataiku.dip.externalinfras.databricks.datamodel.DatabricksRegisteredModel;
import com.dataiku.dip.externalinfras.databricks.datamodel.DatabricksRegisteredModelVersion;
import com.dataiku.dip.externalml.mlflow.DatabricksImportCodes;
import com.dataiku.dip.io.JavaBlockLink;
import com.dataiku.dip.io.SimplePythonKernel;
import com.dataiku.dip.io.SocketBlockLinkInteraction;
import com.dataiku.dip.io.SocketBlockLinkKernelException;
import com.dataiku.dip.threads.BaseKernelProtocol;
import com.dataiku.dip.utils.polyjson.Mapping;
import com.dataiku.dip.utils.polyjson.PolyJSON;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

public class DatabricksUtilsKernelProtocol
extends BaseKernelProtocol
implements AutoCloseable {
    public static final String PYTHON_PACKAGE = "dataiku.external_ml.databricks.utils_server";

    public DatabricksUtilsKernelProtocol(SimplePythonKernel simplePythonKernel) {
        super(simplePythonKernel);
    }

    protected void checkResponse(Response r) throws Exception {
        if (null != r && null != r.error) {
            throw new SocketBlockLinkKernelException("Databricks utils error", r.error);
        }
    }

    public File downloadModelFromDatabricks(CodeBasedRecipeDatasetInfoHelper.ConnectionLocationInfo connectionInfo, boolean useUnityCatalog, String modelName, String modelVersion, String targetDirectory) throws CodedException {
        JavaBlockLink link;
        try {
            link = this.simplePythonKernel.getLink();
        }
        catch (Exception e) {
            throw new CodedException((InfoMessage.MessageCode)DatabricksImportCodes.ERR_MODEL_DOWNLOAD_FROM_REGISTRY_FAILURE, "Could not get link to Python kernel", (Throwable)e);
        }
        try {
            link.sendRequest((Object)new RequestDownload(connectionInfo, useUnityCatalog, modelName, modelVersion, targetDirectory));
            this.checkResponse((Response)link.receiveJsonResponse(RequestDownloadResponse.class));
            File f = new File(targetDirectory);
            if (!f.exists()) {
                throw new IOException("Download successful but download dir " + targetDirectory + "does not exist");
            }
            if (!f.isDirectory()) {
                throw new IOException(targetDirectory + " exists but is not a directory");
            }
            return f;
        }
        catch (Exception e) {
            throw new CodedException((InfoMessage.MessageCode)DatabricksImportCodes.ERR_MODEL_DOWNLOAD_FROM_REGISTRY_FAILURE, "Error when downloading model", (Throwable)e);
        }
    }

    public RequestModelRegistrationResponse registerModelInDatabricks(CodeBasedRecipeDatasetInfoHelper.ConnectionLocationInfo connectionInfo, boolean useUnityCatalog, String modelName, String modelDirectory, String experimentName, @Nullable List<SchemaColumn> columns, @Nonnull String targetName) throws CodedException {
        JavaBlockLink link;
        try {
            link = this.simplePythonKernel.getLink();
        }
        catch (Exception e) {
            throw new CodedException((InfoMessage.MessageCode)DatabricksImportCodes.ERR_MODEL_REGISTRATION_FAILURE, "Could not get link to Python kernel", (Throwable)e);
        }
        try {
            link.sendRequest((Object)new RequestModelRegistration(modelDirectory, connectionInfo, useUnityCatalog, modelName, experimentName, columns, targetName));
            RequestModelRegistrationResponse response = (RequestModelRegistrationResponse)link.receiveJsonResponse(RequestModelRegistrationResponse.class);
            this.checkResponse(response);
            return response;
        }
        catch (Exception e) {
            throw new CodedException((InfoMessage.MessageCode)DatabricksImportCodes.ERR_MODEL_REGISTRATION_FAILURE, "Error when registering model", (Throwable)e);
        }
    }

    public List<DatabricksRegisteredModel> listRegisteredModels(CodeBasedRecipeDatasetInfoHelper.ConnectionLocationInfo connectionInfo, boolean useUnityCatalog) throws CodedException {
        JavaBlockLink link;
        try {
            link = this.simplePythonKernel.getLink();
        }
        catch (Exception e) {
            throw new CodedException((InfoMessage.MessageCode)DatabricksImportCodes.ERR_LISTING_REGISTERED_MODELS_FAILURE, "Could not get link to Python kernel", (Throwable)e);
        }
        try {
            link.sendRequest((Object)new ListRegisteredModels(connectionInfo, useUnityCatalog));
            ListRegisteredModelsResponse listRegisteredModelsResponse = (ListRegisteredModelsResponse)link.receiveJsonResponse(ListRegisteredModelsResponse.class);
            this.checkResponse(listRegisteredModelsResponse);
            return listRegisteredModelsResponse.result;
        }
        catch (Exception e) {
            throw new CodedException((InfoMessage.MessageCode)DatabricksImportCodes.ERR_LISTING_REGISTERED_MODELS_FAILURE, "Error when listing registered models", (Throwable)e);
        }
    }

    public List<DatabricksRegisteredModelVersion> listRegisteredModelVersions(CodeBasedRecipeDatasetInfoHelper.ConnectionLocationInfo connectionInfo, String modelName, boolean useUnityCatalog) throws CodedException {
        JavaBlockLink link;
        try {
            link = this.simplePythonKernel.getLink();
        }
        catch (Exception e) {
            throw new CodedException((InfoMessage.MessageCode)DatabricksImportCodes.ERR_LISTING_REGISTERED_MODEL_VERSIONS_FAILURE, "Could not get link to Python kernel", (Throwable)e);
        }
        try {
            link.sendRequest((Object)new ListRegisteredModelVersions(connectionInfo, modelName, useUnityCatalog));
            ListRegisteredModelVersionsResponse listRegisteredModelVersionsResponse = (ListRegisteredModelVersionsResponse)link.receiveJsonResponse(ListRegisteredModelVersionsResponse.class);
            this.checkResponse(listRegisteredModelVersionsResponse);
            return listRegisteredModelVersionsResponse.result;
        }
        catch (Exception e) {
            throw new CodedException((InfoMessage.MessageCode)DatabricksImportCodes.ERR_LISTING_REGISTERED_MODEL_VERSIONS_FAILURE, "Error when listing versions of a registered models", (Throwable)e);
        }
    }

    @PolyJSON(value={@Mapping(value=RequestDownloadResponse.class, type="RequestDownloadResponse"), @Mapping(value=ListRegisteredModelsResponse.class, type="ListRegisteredModelsResponse"), @Mapping(value=ListRegisteredModelVersionsResponse.class, type="ListRegisteredModelVersionsResponse"), @Mapping(value=RequestModelRegistrationResponse.class, type="RequestModelRegistrationResponse")})
    static abstract class Response {
        public SocketBlockLinkInteraction.SocketBlockLinkKernelError error;

        Response() {
        }
    }

    static class RequestDownload
    extends Command {
        public String modelName;
        public String modelVersion;
        public String targetDirectory;

        RequestDownload(CodeBasedRecipeDatasetInfoHelper.ConnectionLocationInfo connectionInfo, boolean useUnityCatalog, String modelName, String modelVersion, String targetDirectory) {
            super(connectionInfo, useUnityCatalog);
            this.connectionInfo = connectionInfo;
            this.modelName = modelName;
            this.modelVersion = modelVersion;
            this.targetDirectory = targetDirectory;
        }

        private RequestDownload() {
        }
    }

    static class RequestDownloadResponse
    extends Response {
        RequestDownloadResponse() {
        }
    }

    static class RequestModelRegistration
    extends Command {
        public String modelName;
        public String modelDirectory;
        public String experimentName;
        public List<SchemaColumn> columns;
        public String targetName;

        RequestModelRegistration(String modelDirectory, CodeBasedRecipeDatasetInfoHelper.ConnectionLocationInfo connectionInfo, boolean useUnityCatalog, String modelName, String experimentName, List<SchemaColumn> columns, String targetName) {
            super(connectionInfo, useUnityCatalog);
            this.modelDirectory = modelDirectory;
            this.modelName = modelName;
            this.experimentName = experimentName;
            this.columns = columns;
            this.targetName = targetName;
        }

        private RequestModelRegistration() {
        }
    }

    public static class RequestModelRegistrationResponse
    extends Response {
        public String name;
        public String description;
        public String status;
        public String statusMessage;
        public String version;
        public String runId;
        public String experimentId;
    }

    static class ListRegisteredModels
    extends Command {
        ListRegisteredModels(CodeBasedRecipeDatasetInfoHelper.ConnectionLocationInfo connectionInfo, boolean useUnityCatalog) {
            super(connectionInfo, useUnityCatalog);
        }

        private ListRegisteredModels() {
        }
    }

    static class ListRegisteredModelsResponse
    extends Response {
        public List<DatabricksRegisteredModel> result = new ArrayList<DatabricksRegisteredModel>();

        ListRegisteredModelsResponse() {
        }
    }

    static class ListRegisteredModelVersions
    extends Command {
        public String modelName;

        ListRegisteredModelVersions(CodeBasedRecipeDatasetInfoHelper.ConnectionLocationInfo connectionInfo, String modelName, boolean useUnityCatalog) {
            super(connectionInfo, useUnityCatalog);
            this.modelName = modelName;
        }

        private ListRegisteredModelVersions() {
        }
    }

    static class ListRegisteredModelVersionsResponse
    extends Response {
        public List<DatabricksRegisteredModelVersion> result = new ArrayList<DatabricksRegisteredModelVersion>();

        ListRegisteredModelVersionsResponse() {
        }
    }

    @PolyJSON(value={@Mapping(value=RequestDownload.class, type="RequestDownload"), @Mapping(value=ListRegisteredModels.class, type="ListRegisteredModels"), @Mapping(value=ListRegisteredModelVersions.class, type="ListRegisteredModelVersions"), @Mapping(value=RequestModelRegistration.class, type="RequestModelRegistration")})
    static abstract class Command {
        public CodeBasedRecipeDatasetInfoHelper.ConnectionLocationInfo connectionInfo;
        boolean useUnityCatalog;

        Command(CodeBasedRecipeDatasetInfoHelper.ConnectionLocationInfo connectionInfo, boolean useUnityCatalog) {
            this.connectionInfo = connectionInfo;
            this.useUnityCatalog = useUnityCatalog;
        }

        private Command() {
        }
    }
}

