/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.connections;

import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.AbstractSQLConnection;
import com.dataiku.dip.connections.ConnectionWithAWSAuthCredentials;
import com.dataiku.dip.connections.ConnectionWithBasicCredential;
import com.dataiku.dip.connections.ConnectionsDAO;
import com.dataiku.dip.connections.SageMakerConnection;
import com.dataiku.dip.coremodel.SimpleKeyValue;
import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.online.sagemakergeneric.GenericLLMHandling;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.security.PasswordEncryptionService;
import com.dataiku.dip.server.SpringUtils;
import com.dataiku.dip.server.services.ConnectionsTestService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.ifaces.IsolationLevel;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.variables.VariablesContext;
import java.io.IOException;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.lang.StringUtils;

public class SageMakerGenericLLMConnection
extends AbstractLLMConnection<SageMakerModel, AbstractLLMConnection.IHardcodedConnectionModel<SageMakerModel>, CustomSageMakerModel> {
    public static final String SAGEMAKER_GENERIC_LLM_CONNECTION_TYPE = "SageMaker-GenericLLM";
    public SageMakerGenericLLMConnectionParams params = new SageMakerGenericLLMConnectionParams();
    private static final DKULogger logger = DKULogger.getLogger((String)"dip.connections.sagemaker-genericllm");

    @Override
    protected DKULogger getLogger() {
        return logger;
    }

    @Override
    public AbstractLLMConnection.AbstractLLMConnectionParams getLLMConnectionParams() {
        return this.params;
    }

    @Override
    protected List<CustomSageMakerModel> listRawCustomModels() {
        return Collections.singletonList(this.params.sageMakerModel);
    }

    @Override
    protected SageMakerModel loadRawCustomModel(CustomSageMakerModel rawCustomModel) {
        SageMakerModel model = rawCustomModel.toModel();
        this.loadDefaultCustomModelSettings(model);
        return model;
    }

    @Override
    public void encryptFields(PasswordEncryptionService cryptoService, GeneralSettingsDAO.SecuritySettings securitySettings) {
    }

    @Override
    public void decryptFields(PasswordEncryptionService cryptoService) {
    }

    @Override
    public String getType() {
        return SAGEMAKER_GENERIC_LLM_CONNECTION_TYPE;
    }

    @Override
    public void expandParametersInPlaceAtDAOLevelUsingGlobalContextOnly(VariablesContext vc) {
        this.params.region = vc.expand(this.params.region);
    }

    @Override
    public List<AbstractSQLConnection.CustomDatabaseProperty> getDkuProperties() {
        return new ArrayList<AbstractSQLConnection.CustomDatabaseProperty>();
    }

    @Override
    protected <T> T getFullyResolvedCredentials_internal(ConnectionWithBasicCredential.CredentialResolutionContext ctx, Class<T> clazz) throws DKUSecurityException, IOException, SQLException {
        SageMakerConnection smConn;
        assert (clazz.isAssignableFrom(ConnectionWithAWSAuthCredentials.SerializableAWSCredential.class));
        TransactionService ts = (TransactionService)SpringUtils.getBean(TransactionService.class);
        try (Transaction t = ts.retrieveOrBeginRead(IsolationLevel.YOLO);){
            smConn = ConnectionsDAO.get().getMandatoryConnectionAs(ctx.authCtx, this.params.sageMakerConnection, SageMakerConnection.class);
        }
        ConnectionWithAWSAuthCredentials.SerializableAWSCredential creds = smConn.getFullyResolvedCredentials_internal(ctx, ConnectionWithAWSAuthCredentials.SerializableAWSCredential.class);
        return clazz.cast(creds);
    }

    @Override
    public Map<String, Object> getConsistencyCheckables() {
        Map<String, Object> consistencyCheckables = super.getConsistencyCheckables();
        consistencyCheckables.put(ENABLED_MODELS, Collections.singleton(this.params.endpointName));
        return consistencyCheckables;
    }

    @Override
    public ConnectionsTestService.ConnectionTestResult testConnection(AuthCtx authCtx, ConnectionsTestService connectionsTestService) throws Exception {
        return connectionsTestService.testSageMakerGenericLLM(this, authCtx);
    }

    public static class SageMakerGenericLLMConnectionParams
    extends AbstractLLMConnection.AbstractLLMConnectionParams {
        public AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings = new AbstractLLMConnection.HTTPBasedLLMNetworkSettings();
        public String sageMakerConnection;
        public String region;
        public int maxParallelism = 8;
        public String endpointName;
        public boolean acceptEULA = true;
        @Nullable
        public String customQuery;
        @Nullable
        public String responseJsonPath;
        @Nullable
        public String finishReasonJsonPath;
        @Nullable
        public String promptTokensJsonPath;
        @Nullable
        public String completionTokensJsonPath;
        public CustomSageMakerModel sageMakerModel = new CustomSageMakerModel();
    }

    public static class CustomSageMakerModel
    extends AbstractLLMConnection.CustomModel<SageMakerModel> {
        public CustomSageMakerModelType modelType = CustomSageMakerModelType.TEXT_COMPLETION;
        public GenericLLMHandling handling = GenericLLMHandling.HUGGING_FACE;
        public String friendlyNameShort;
        public List<SimpleKeyValue> customHeaders = new ArrayList<SimpleKeyValue>();

        @Override
        public SageMakerModel toModel() {
            SageMakerModel model = new SageMakerModel();
            model.loadFromCustomModel(this);
            model.id = this.friendlyNameShort;
            model.displayName = "";
            model.modelType = this.modelType;
            model.handling = this.handling;
            model.customHeaders = this.customHeaders;
            return model;
        }
    }

    public static class SageMakerModel
    extends AbstractLLMConnection.BaseModel {
        public CustomSageMakerModelType modelType;
        public GenericLLMHandling handling;
        public List<SimpleKeyValue> customHeaders = new ArrayList<SimpleKeyValue>();

        @Override
        public LLMStructuredRef asStructuredRef(String connection) {
            return LLMStructuredRef.forGenericSageMakerConnection(connection, this.modelType);
        }

        @Override
        public Optional<String> getInvalidityReason() {
            if (StringUtils.isBlank((String)this.getId())) {
                return Optional.of("Missing model friendlyName");
            }
            if (this.modelType == null) {
                return Optional.of("Missing model type");
            }
            if (this.handling == null) {
                return Optional.of("Missing model handling");
            }
            return Optional.empty();
        }

        @Override
        public boolean canBeUsedForPurpose(@Nonnull AbstractLLMConnection.LLMUsagePurpose purpose) {
            if (this.modelType == null) {
                return false;
            }
            return this.modelType.matchesPurpose(purpose);
        }

        @Override
        public AbstractLLMConnection.ModelCapabilities getModelCapabilities() {
            AbstractLLMConnection.ModelCapabilities capabilities = new AbstractLLMConnection.ModelCapabilities();
            capabilities.canGenerateCrossLanguageOutput = true;
            capabilities.handlesSystemMessage = true;
            capabilities.promptDriven = this.modelType.matchesPurpose(AbstractLLMConnection.LLMUsagePurpose.GENERIC_COMPLETION);
            if (this.handling != null) {
                capabilities.temperatureRange = this.handling.temperatureRange;
                capabilities.topKRange = this.handling.topKRange;
            }
            return capabilities;
        }
    }

    public static enum CustomSageMakerModelType {
        TEXT_COMPLETION(Arrays.asList(AbstractLLMConnection.LLMUsagePurpose.SUMMARIZATION, AbstractLLMConnection.LLMUsagePurpose.CLASSIFICATION_WITH_USER_PROVIDED_CLASSES, AbstractLLMConnection.LLMUsagePurpose.EMOTION_ANALYSIS, AbstractLLMConnection.LLMUsagePurpose.SENTIMENT_ANALYSIS, AbstractLLMConnection.LLMUsagePurpose.GENERIC_COMPLETION)),
        SUMMARIZATION(Collections.singletonList(AbstractLLMConnection.LLMUsagePurpose.SUMMARIZATION)),
        TEXT_EMBEDDING(Collections.singletonList(AbstractLLMConnection.LLMUsagePurpose.TEXT_EMBEDDING_EXTRACTION)),
        IMAGE_EMBEDDING(Collections.singletonList(AbstractLLMConnection.LLMUsagePurpose.IMAGE_EMBEDDING_EXTRACTION));

        public final List<AbstractLLMConnection.LLMUsagePurpose> matchingPurposes;

        private CustomSageMakerModelType(List<AbstractLLMConnection.LLMUsagePurpose> purposes) {
            this.matchingPurposes = purposes;
        }

        public boolean matchesPurpose(AbstractLLMConnection.LLMUsagePurpose purpose) {
            return this.matchingPurposes.contains((Object)purpose);
        }
    }
}

