/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.bedrock;

import com.dataiku.common.rpc.ExternalJSONAPIClient;
import com.dataiku.dip.analysis.ml.FullModelId;
import com.dataiku.dip.analysis.ml.llm.LLMSavedModelInfo;
import com.dataiku.dip.analysis.ml.llm.LLMStepwiseTrainingMetrics;
import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.online.DatasetRowJsonifier;
import com.dataiku.dip.llm.online.RemoteFineTuningClient;
import com.dataiku.dip.llm.online.bedrock.BedrockClient;
import com.dataiku.dip.llm.online.bedrock.Processor;
import com.dataiku.dip.llm.online.bedrock.RawBedrockClient;
import com.dataiku.dip.recipes.nlp.finetuning.FineTuningRecipePayloadParams;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.legacy.aws.com.amazonaws.services.s3.AmazonS3Client;
import com.dataiku.dss.legacy.aws.com.amazonaws.services.s3.AmazonS3URI;
import com.dataiku.dss.legacy.aws.com.amazonaws.services.s3.model.BucketLifecycleConfiguration;
import com.dataiku.dss.legacy.aws.com.amazonaws.services.s3.model.PutObjectRequest;
import com.dataiku.dss.legacy.aws.com.amazonaws.services.s3.model.PutObjectResult;
import com.dataiku.dss.legacy.aws.com.amazonaws.services.s3.model.S3Object;
import com.dataiku.dss.legacy.aws.com.amazonaws.services.s3.model.lifecycle.LifecycleFilter;
import com.dataiku.dss.legacy.aws.com.amazonaws.services.s3.model.lifecycle.LifecycleFilterPredicate;
import com.dataiku.dss.legacy.aws.com.amazonaws.services.s3.model.lifecycle.LifecyclePrefixPredicate;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonPrimitive;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.util.Collections;
import java.util.HashMap;
import java.util.Optional;

public class BedrockFineTuningClient
implements RemoteFineTuningClient {
    // Bedrock client this class delegates to for the raw API client, the S3 client and model information.
    BedrockClient llmClient;
    // Prefix whose length is accounted for when sizing custom model names in getCustomModelName.
    private static final String MODEL_NAME_PREFIX = "finetuned";
    // S3 key prefix used as the filter of the bucket lifecycle (expiration) rule.
    private static final String BUCKET_BASE_FOLDER = "dataiku-finetuning";
    // Same path as the literal used by uploadFile for training/validation files; not referenced by name in this file.
    private static final String BUCKET_TRAINING_FOLDER = "dataiku-finetuning/training";
    // Same path as the literal used by startFineTuning for the job output URI; not referenced by name in this file.
    private static final String BUCKET_OUTPUT_FOLDER = "dataiku-finetuning/output";
    // Shared logger for this class, registered under the "dku.llm.bedrock" name.
    private static DKULogger logger = DKULogger.getLogger((String)"dku.llm.bedrock");

    /**
     * Creates a fine-tuning client backed by the given Bedrock client.
     *
     * @param bedrockClient client stored in {@code llmClient} and used by every operation of this class
     */
    public BedrockFineTuningClient(BedrockClient bedrockClient) {
        llmClient = bedrockClient;
    }

    /**
     * Returns the raw Bedrock client exposed by {@code llmClient}.
     *
     * @return the result of {@code llmClient.getRaw()}
     */
    private RawBedrockClient getRawClient() {
        RawBedrockClient raw = llmClient.getRaw();
        return raw;
    }

    /**
     * Creates the fine-tuning bucket when it does not exist yet and attaches a
     * lifecycle rule that expires objects under the {@code BUCKET_BASE_FOLDER}
     * prefix after 7 days. Does nothing when the bucket already exists.
     *
     * @param s3client   S3 client used for the existence check, creation and lifecycle setup
     * @param bucketName name of the bucket to check and, if needed, create
     * @throws Exception if any of the S3 calls fails
     */
    private void createFineTuningBucketIfNotExist(AmazonS3Client s3client, String bucketName) throws Exception {
        if (s3client.doesBucketExistV2(bucketName)) {
            // Existing buckets are left untouched: no lifecycle rule is added or changed.
            return;
        }
        logger.debug((Object)("createFineTuningBucketIfNotExist: " + bucketName));
        s3client.createBucket(bucketName);

        LifecycleFilter prefixFilter = new LifecycleFilter((LifecycleFilterPredicate)new LifecyclePrefixPredicate(BUCKET_BASE_FOLDER));
        BucketLifecycleConfiguration.Rule expirationRule = new BucketLifecycleConfiguration.Rule()
                .withId("ExpireAfterOneWeek")
                .withFilter(prefixFilter)
                .withExpirationInDays(7)
                .withStatus("Enabled");
        BucketLifecycleConfiguration lifecycle = new BucketLifecycleConfiguration()
                .withRules(Collections.singletonList(expirationRule));
        s3client.setBucketLifecycleConfiguration(bucketName, lifecycle);
    }

    /**
     * Tells whether the configured model is handled as a chat model.
     *
     * @return {@code false} for the four model ids listed below, {@code true} for any other id
     */
    @Override
    public boolean isChatModel() {
        String modelId = this.llmClient.getModel().getId();
        boolean listedAsNonChat = modelId.equals("amazon.titan-text-lite-v1")
                || modelId.equals("amazon.titan-text-express-v1")
                || modelId.equals("meta.llama3-1-8b-instruct-v1:0")
                || modelId.equals("meta.llama3-1-70b-instruct-v1:0");
        return !listedAsNonChat;
    }

    /**
     * Tells whether message content must be wrapped in an array of {@code {"text": ...}}
     * objects (see {@code getContentForMessage}).
     *
     * @return {@code true} when the model id is {@code null} or does not start with
     *         {@code anthropic.claude-3-haiku}; {@code false} otherwise
     */
    private boolean requiresTextProperty() {
        String modelId = this.llmClient.getModel().getId();
        if (modelId == null) {
            return true;
        }
        boolean isClaude3Haiku = modelId.startsWith("anthropic.claude-3-haiku");
        return !isClaude3Haiku;
    }

    /**
     * Checks that the training file contains at least 32 lines.
     *
     * @param f training file to inspect
     * @throws Exception if the file has fewer than 32 lines or cannot be read
     */
    @Override
    public void assertTrainingFileIsValid(File f) throws Exception {
        final long minimumLines = 32L;
        try (BufferedReader reader = new BufferedReader(new FileReader(f))) {
            // Reading stops after the minimum is reached; the full file is never scanned.
            long linesSeen = reader.lines().limit(minimumLines).count();
            boolean longEnough = linesSeen >= minimumLines;
            if (!longEnough) {
                String message = String.format("Training file %s is too short, it should be at least 32 lines long. Got %d", f.getAbsolutePath(), linesSeen);
                throw new Exception(message);
            }
        }
    }

    /**
     * Uploads a training or validation file to the fine-tuning bucket, creating
     * the bucket first when it does not exist.
     *
     * @param f              local file to upload
     * @param runId          run identifier used as a folder name in the object key
     * @param isTrainingFile {@code true} to store the file as {@code training.jsonl},
     *                       {@code false} to store it as {@code validation.jsonl}
     * @return the {@code s3://} URI the file was uploaded to
     * @throws Exception if bucket creation or the upload fails
     */
    @Override
    public String uploadFile(File f, String runId, boolean isTrainingFile) throws Exception {
        AmazonS3Client s3client = this.llmClient.getS3Client();
        String bucketName = this.getRawClient().getFineTuningBucketName();
        this.createFineTuningBucketIfNotExist(s3client, bucketName);

        String objectName;
        if (isTrainingFile) {
            objectName = "training.jsonl";
        } else {
            objectName = "validation.jsonl";
        }
        String fileURI = "s3://" + bucketName + "/dataiku-finetuning/training/" + runId + "/" + objectName;
        logger.debug((Object)("uploadFile: " + fileURI));

        // Bucket and key are re-derived from the URI so the upload target matches the returned value.
        AmazonS3URI target = new AmazonS3URI(fileURI);
        PutObjectRequest request = new PutObjectRequest(target.getBucket(), target.getKey(), f);
        PutObjectResult uploadResult = s3client.putObject(request);
        logger.debug((Object)("raw upload response: " + JSON.json((Object)uploadResult)));
        return fileURI;
    }

    /**
     * Deletes the S3 object designated by the given URI.
     *
     * @param fileURI {@code s3://} URI of the object to delete
     * @throws Exception if the URI cannot be parsed or the deletion fails
     */
    @Override
    public void deleteFile(String fileURI) throws Exception {
        logger.debug((Object)("deleteFile: " + fileURI));
        AmazonS3Client s3client = this.llmClient.getS3Client();
        AmazonS3URI target = new AmazonS3URI(fileURI);
        String bucket = target.getBucket();
        String key = target.getKey();
        s3client.deleteObject(bucket, key);
    }

    /**
     * Builds the custom model name as {@code finetuned-<base model>-<run id>}.
     * The base model part is what precedes the first {@code ':'}, with every
     * character outside {@code [0-9a-zA-Z_-]} replaced by {@code '_'}. It is
     * truncated so that the prefix, the two separators, the base model part and
     * the run id fit in 63 characters whenever the run id leaves room for it.
     *
     * @param baseModel       base model identifier
     * @param fineTuningRunID run identifier appended to the name
     * @return the custom model name
     */
    private String getCustomModelName(String baseModel, String fineTuningRunID) {
        // 63 minus prefix, run id and the two '-' separators; never negative.
        int budget = 63 - MODEL_NAME_PREFIX.length() - fineTuningRunID.length() - 2;
        int availableCharacters = Math.max(0, budget);

        String withoutVersion = baseModel.split(":")[0];
        String sanitized = withoutVersion.replaceAll("[^0-9a-zA-Z_-]", "_");
        String truncated = sanitized.substring(0, Math.min(sanitized.length(), availableCharacters));

        return MODEL_NAME_PREFIX + "-" + truncated + "-" + fineTuningRunID;
    }

    /**
     * Starts a fine-tuning job through the raw Bedrock client.
     * The job is named {@code dataiku-finetuning-<runId>} and its output URI is
     * the {@code dataiku-finetuning/output/} folder of the fine-tuning bucket.
     *
     * @param baseModel               base model identifier
     * @param trainingFileURI         URI of the uploaded training file
     * @param remoteValidationFileURI optional URI of the uploaded validation file
     * @param runId                   run identifier used in the job and custom model names
     * @param hyperparameters         hyperparameters forwarded to the raw client
     * @param useDefaults             flag forwarded to the raw client
     * @return the value returned by {@code fineTuneStart}
     * @throws Exception if the raw client call fails
     */
    @Override
    public String startFineTuning(String baseModel, String trainingFileURI, Optional<String> remoteValidationFileURI, String runId, FineTuningRecipePayloadParams.FineTuningHyperparameters hyperparameters, boolean useDefaults) throws Exception {
        String customModelName = this.getCustomModelName(baseModel, runId);
        String jobName = "dataiku-finetuning-" + runId;
        String bucketName = this.getRawClient().getFineTuningBucketName();
        String outputURI = "s3://" + bucketName + "/dataiku-finetuning/output/";
        RawBedrockClient rawClient = this.getRawClient();
        return rawClient.fineTuneStart(baseModel, customModelName, jobName, trainingFileURI, remoteValidationFileURI, outputURI, hyperparameters, useDefaults);
    }

    /**
     * Blocks until the fine-tuning job reaches a terminal status.
     * Waits 5 seconds, then polls the job; after each poll that does not end
     * the wait, sleeps 30 seconds before polling again.
     *
     * <ul>
     *   <li>{@code Completed}: returns normally.</li>
     *   <li>{@code Failed}: throws {@code FinetuningFailedException} with the job's
     *       {@code failureMessage} when present.</li>
     *   <li>{@code Stopping} / {@code Stopped}: throws {@code FinetuningCanceledException}.</li>
     *   <li>Any other status: keeps polling.</li>
     * </ul>
     *
     * {@code JSONAPIClientException} is rethrown as is. Other {@code IOException}s are
     * treated as transient: they are logged and the poll is retried.
     *
     * @param jobId identifier of the job to poll
     * @throws Exception if the job failed or was canceled, if the API client reports
     *                   an error, or if the thread is interrupted while sleeping
     */
    @Override
    public void waitForJobCompletion(String jobId) throws Exception {
        RawBedrockClient rawClient = this.getRawClient();
        Thread.sleep(5000L);
        while (true) {
            try {
                JsonObject ftJob = rawClient.fineTuneGet(jobId);
                String ftJobStatus = ftJob.get("status").getAsString();
                logger.debug((Object)("fine-tuning job status: " + ftJobStatus));
                if ("Completed".equals(ftJobStatus)) {
                    return;
                }
                if ("Failed".equals(ftJobStatus)) {
                    // The failure message may be missing; a failed job must still be reported as failed, not as an NPE.
                    JsonElement failureMessage = ftJob.get("failureMessage");
                    String reason = failureMessage == null || failureMessage.isJsonNull() ? "unknown reason" : failureMessage.getAsString();
                    throw new RemoteFineTuningClient.FinetuningFailedException("fine-tuning failed: " + reason);
                }
                if ("Stopping".equals(ftJobStatus) || "Stopped".equals(ftJobStatus)) {
                    throw new RemoteFineTuningClient.FinetuningCanceledException("fine-tuning canceled");
                }
            }
            catch (ExternalJSONAPIClient.JSONAPIClientException e) {
                throw e;
            }
            catch (IOException e) {
                // IOException messages can be null, so the check must not dereference it blindly.
                String message = e.getMessage();
                if (message != null && message.contains("Connection reset")) {
                    logger.warn((Object)"Connection reset, retrying");
                } else {
                    // Still retried, as before, but no longer silently.
                    logger.warn((Object)("I/O error while polling fine-tuning job " + jobId + ", retrying: " + e));
                }
            }
            Thread.sleep(30000L);
        }
    }

    /**
     * Fetches the fine-tuning job description from the raw Bedrock client.
     *
     * @param jobId identifier of the job to fetch
     * @return the JSON object returned by {@code fineTuneGet}
     * @throws Exception if the raw client call fails
     */
    @Override
    public JsonObject getJob(String jobId) throws Exception {
        RawBedrockClient rawClient = this.getRawClient();
        JsonObject job = rawClient.fineTuneGet(jobId);
        return job;
    }

    /**
     * Fills the saved model info from the fine-tuning job response and writes
     * the step-wise training metrics file.
     *
     * Sets the LLM type, the remote job name, model id and model name, then
     * copies {@code learningRate}, {@code epochCount} and {@code batchSize} from
     * the response's {@code hyperParameters} object when they are present.
     * Metrics are loaded from the S3 output location on a best-effort basis: a
     * failure there is logged and the (possibly empty) metrics object is still
     * written to disk.
     *
     * @param jobId               job identifier (not read by this implementation)
     * @param enrichedLLMRef      LLM reference (not read by this implementation)
     * @param fineTuningResponse  job description; must contain {@code jobName},
     *                            {@code outputModelArn} and {@code outputModelName}
     * @param llmSmi              saved model info to enrich; its {@code remoteJobId} is used
     *                            to locate the metrics files
     * @param fmi                 model id giving the destination of the metrics file
     * @throws Exception if a mandatory response field is missing or the metrics file cannot be written
     */
    @Override
    public void enrichSavedModelPostFinetuning(String jobId, EnrichedLLMStructuredRef enrichedLLMRef, JsonObject fineTuningResponse, LLMSavedModelInfo llmSmi, FullModelId fmi) throws Exception {
        llmSmi.llmType = LLMStructuredRef.LLMType.SAVED_MODEL_FINETUNED_BEDROCK;
        llmSmi.remoteJobName = fineTuningResponse.get("jobName").getAsString();
        llmSmi.remoteModelId = fineTuningResponse.get("outputModelArn").getAsString();
        llmSmi.remoteModelName = fineTuningResponse.get("outputModelName").getAsString();
        LLMStepwiseTrainingMetrics llmStepwiseTrainingMetrics = new LLMStepwiseTrainingMetrics();
        // getAsJsonObject returns null when the member is absent; the hyperparameters are optional extras,
        // so their absence must not prevent the metrics file from being written.
        JsonObject hyperparameters = fineTuningResponse.getAsJsonObject("hyperParameters");
        if (hyperparameters != null) {
            if (hyperparameters.get("learningRate") != null) {
                llmSmi.learningRateMultiplier = Float.valueOf(hyperparameters.get("learningRate").getAsFloat());
            }
            if (hyperparameters.get("epochCount") != null) {
                llmSmi.nbEpochs = hyperparameters.get("epochCount").getAsInt();
            }
            if (hyperparameters.get("batchSize") != null) {
                llmSmi.batchSize = hyperparameters.get("batchSize").getAsInt();
            }
        }
        try {
            String s3OutputURI = fineTuningResponse.getAsJsonObject("outputDataConfig").get("s3Uri").getAsString();
            this.fillLLMTrainingMetrics(llmSmi.remoteJobId, s3OutputURI, llmStepwiseTrainingMetrics, llmSmi);
            if (llmSmi.validationDataset != null) {
                this.fillLLMValidationMetrics(llmSmi.remoteJobId, s3OutputURI, llmStepwiseTrainingMetrics);
            }
        }
        catch (Exception e) {
            // Best effort: report the job whose metrics could not be loaded and carry on.
            logger.error((Object)String.format("Error when filling metrics for job %s: %s", llmSmi.remoteJobId, e));
        }
        JSON.prettyToFile((Object)llmStepwiseTrainingMetrics, (File)fmi.getLLMStepwiseTrainingMetricsFile());
    }

    /**
     * Wraps a text in the content representation expected for the configured model.
     *
     * @param content text to wrap
     * @return a one-element array {@code [{"text": content}]} when
     *         {@code requiresTextProperty()} is true, a plain JSON string otherwise
     */
    private JsonElement getContentForMessage(String content) {
        if (!this.requiresTextProperty()) {
            return new JsonPrimitive(content);
        }
        JsonObject textPart = new JsonObject();
        textPart.addProperty("text", content);
        JsonArray parts = new JsonArray();
        parts.add(textPart);
        return parts;
    }

    /**
     * Builds a message object with a {@code role} and a {@code content} member.
     *
     * @param role    value of the {@code role} property
     * @param content text converted with {@code getContentForMessage}
     * @return the message as a JSON object
     */
    private JsonObject createMessage(String role, String content) {
        JsonElement wrappedContent = this.getContentForMessage(content);
        JsonObject result = new JsonObject();
        result.addProperty("role", role);
        result.add("content", wrappedContent);
        return result;
    }

    /**
     * Returns the function converting a dataset row into a training record.
     * Non-chat models use {@code getJsonifyForPromptModelFunction()}. For chat
     * models, each record holds an optional {@code system} member, a
     * {@code messages} array with one user and one assistant message, and a
     * {@code schemaVersion} of {@code bedrock-conversation-2024}.
     *
     * @return the row-to-JSON conversion function
     */
    @Override
    public DatasetRowJsonifier getJsonifyFunction() {
        if (this.isChatModel()) {
            return (prompt, completion, systemMessage) -> {
                JsonObject record = new JsonObject();
                JsonArray conversation = new JsonArray();
                // Members are added in the order system, messages, schemaVersion.
                if (systemMessage.isPresent()) {
                    record.add("system", this.getContentForMessage((String)systemMessage.get()));
                }
                record.add("messages", conversation);
                conversation.add(this.createMessage("user", prompt));
                conversation.add(this.createMessage("assistant", completion));
                record.addProperty("schemaVersion", "bedrock-conversation-2024");
                return record;
            };
        }
        return this.getJsonifyForPromptModelFunction();
    }

    /**
     * Finds the position of a column in a header row, ignoring case.
     *
     * @param headers header names to search
     * @param value   header name to look for
     * @return the index of the first matching header, or an empty Optional when none matches
     */
    private Optional<Integer> findIdx(String[] headers, String value) {
        int position = 0;
        while (position < headers.length) {
            if (headers[position].equalsIgnoreCase(value)) {
                return Optional.of(position);
            }
            position++;
        }
        return Optional.empty();
    }

    /**
     * Downloads a CSV artifact of a fine-tuning job from S3, feeds each data row
     * to the processor, then deletes the artifact from S3.
     * The artifact URI is {@code s3outputURI + "model-customization-job-" + <short id> + fileName},
     * where the short id is what follows the last {@code '/'} of the job ARN.
     * The first line is treated as the header row; rows and headers are split on commas.
     *
     * @param jobARN                     job identifier whose last path segment names the artifact folder
     * @param s3outputURI                base output URI the artifact path is appended to
     * @param fileName                   artifact path relative to the job folder
     * @param llmStepwiseTrainingMetrics metrics object handed to the processor for the first row
     * @param processor                  callback invoked once per data row; its return value is
     *                                   passed to the next invocation
     * @throws Exception if the download, the read or the deletion fails
     */
    private void processPostTrainObject(String jobARN, String s3outputURI, String fileName, LLMStepwiseTrainingMetrics llmStepwiseTrainingMetrics, Processor processor) throws Exception {
        int lastSlash = jobARN.lastIndexOf('/');
        String remoteJobShortId = jobARN.substring(lastSlash + 1);
        String artifactURI = s3outputURI + "model-customization-job-" + remoteJobShortId + fileName;
        AmazonS3URI artifactLocation = new AmazonS3URI(artifactURI);
        AmazonS3Client s3client = this.llmClient.getS3Client();
        logger.info((Object)("Downloading fine-tuning metrics file from S3: " + artifactURI));
        S3Object metricsFile = s3client.getObject(artifactLocation.getBucket(), artifactLocation.getKey());

        LLMStepwiseTrainingMetrics current = llmStepwiseTrainingMetrics;
        try (BufferedReader reader = new BufferedReader(new InputStreamReader((InputStream)metricsFile.getObjectContent()))) {
            String headerLine = reader.readLine();
            String[] headers = headerLine != null ? headerLine.split(",") : null;
            for (String row = reader.readLine(); row != null; row = reader.readLine()) {
                current = processor.apply(current, row.split(","), headers);
            }
        }
        // Only reached when the file was read without error.
        s3client.deleteObject(artifactLocation.getBucket(), artifactLocation.getKey());
    }

    /**
     * Loads the step-wise training metrics CSV of a job and stores one metric per step.
     * Resets {@code llmStepwiseTrainingMetrics.metrics} to a fresh map and
     * {@code llmSavedModelInfo.totalSteps} to 0, then for each CSV row records the
     * training loss (and the perplexity when that column exists) under its step
     * number and raises {@code totalSteps} to the highest step seen.
     * A row that cannot be parsed is logged and skipped.
     *
     * @param jobARN                     job identifier used to locate the artifact
     * @param s3outputURI                base S3 output URI of the job
     * @param llmStepwiseTrainingMetrics metrics object to fill
     * @param llmSavedModelInfo          saved model info whose {@code totalSteps} is updated
     * @throws Exception if the artifact cannot be downloaded, read or deleted
     */
    private void fillLLMTrainingMetrics(String jobARN, String s3outputURI, LLMStepwiseTrainingMetrics llmStepwiseTrainingMetrics, LLMSavedModelInfo llmSavedModelInfo) throws Exception {
        HashMap<Integer, LLMStepwiseTrainingMetrics.FineTuningJobMetric> metrics = new HashMap<Integer, LLMStepwiseTrainingMetrics.FineTuningJobMetric>();
        llmStepwiseTrainingMetrics.metrics = metrics;
        llmSavedModelInfo.totalSteps = 0;
        Processor processor = (llmStm, line, headers) -> {
            try {
                Optional<Integer> stepIdx = this.findIdx(headers, "step_number");
                Optional<Integer> trainingLossIdx = this.findIdx(headers, "training_loss");
                Optional<Integer> trainingPerplexityIdx = this.findIdx(headers, "training_perplexity");
                Integer step = Integer.parseInt(line[stepIdx.get()]);
                Optional<Float> perplexityOpt = trainingPerplexityIdx.map(idx -> Float.valueOf(Float.parseFloat(line[idx])));
                LLMStepwiseTrainingMetrics.MetricValue trainingMetricValue = new LLMStepwiseTrainingMetrics.MetricValue(LLMStepwiseTrainingMetrics.MetricType.TRAINING, Float.parseFloat(line[trainingLossIdx.get()]), perplexityOpt);
                LLMStepwiseTrainingMetrics.FineTuningJobMetric metric = new LLMStepwiseTrainingMetrics.FineTuningJobMetric(step, trainingMetricValue);
                metrics.put(step, metric);
                llmSavedModelInfo.totalSteps = Math.max(llmSavedModelInfo.totalSteps, step);
            }
            catch (Exception e) {
                // The format needs a third placeholder: without it the exception argument is ignored
                // and the log gives no reason for the failure.
                logger.warn((Object)String.format("Error processing line `%s` with headers `%s`: %s", String.join((CharSequence)", ", line), String.join((CharSequence)", ", headers), e));
            }
            return llmStm;
        };
        this.processPostTrainObject(jobARN, s3outputURI, "/training_artifacts/step_wise_training_metrics.csv", llmStepwiseTrainingMetrics, processor);
    }

    /**
     * Loads the validation metrics CSV of a job and adds one validation metric per row.
     * Each row contributes its step number, its validation loss and, when that
     * column exists, its validation perplexity.
     * A row that cannot be parsed is logged and skipped.
     *
     * @param jobARN                     job identifier used to locate the artifact
     * @param s3outputURI                base S3 output URI of the job
     * @param llmStepwiseTrainingMetrics metrics object receiving the validation metrics
     * @throws Exception if the artifact cannot be downloaded, read or deleted
     */
    private void fillLLMValidationMetrics(String jobARN, String s3outputURI, LLMStepwiseTrainingMetrics llmStepwiseTrainingMetrics) throws Exception {
        Processor processor = (llmStm, line, headers) -> {
            try {
                Optional<Integer> stepIdx = this.findIdx(headers, "step_number");
                Optional<Integer> validationLossIdx = this.findIdx(headers, "validation_loss");
                Optional<Integer> validationPerplexityIdx = this.findIdx(headers, "validation_perplexity");
                Optional<Float> perplexityOpt = validationPerplexityIdx.map(idx -> Float.valueOf(Float.parseFloat(line[idx])));
                llmStepwiseTrainingMetrics.addValidationMetric(Integer.parseInt(line[stepIdx.get()]), Float.valueOf(Float.parseFloat(line[validationLossIdx.get()])), perplexityOpt);
            }
            catch (Exception e) {
                // The format needs a third placeholder: without it the exception argument is ignored
                // and the log gives no reason for the failure.
                logger.warn((Object)String.format("Error processing line `%s` with headers `%s`: %s", String.join((CharSequence)", ", line), String.join((CharSequence)", ", headers), e));
            }
            return llmStm;
        };
        this.processPostTrainObject(jobARN, s3outputURI, "/validation_artifacts/post_fine_tuning_validation/validation/validation_metrics.csv", llmStepwiseTrainingMetrics, processor);
    }

    /**
     * Asks the raw Bedrock client to cancel a fine-tuning job.
     *
     * @param jobId identifier of the job to cancel
     * @throws Exception if the raw client call fails
     */
    @Override
    public void cancelJob(String jobId) throws Exception {
        RawBedrockClient rawClient = this.getRawClient();
        rawClient.fineTuneCancel(jobId);
    }

    /**
     * Asks the raw Bedrock client to delete a fine-tuned model.
     *
     * @param modelId identifier of the model to delete
     * @param jobId   job identifier (not used by this implementation)
     * @throws Exception if the raw client call fails
     */
    @Override
    public void deleteFinetunedModel(String modelId, String jobId) throws Exception {
        RawBedrockClient rawClient = this.getRawClient();
        rawClient.deleteFinetunedModel(modelId);
    }

    /**
     * Resolves the model id to use as the fine-tuning base model.
     * Looks through the foundation models listed for fine-tuning and returns the
     * first one whose id starts with the configured model id.
     *
     * @return the first matching listed model id, or the configured model id when none matches
     */
    @Override
    public String getBaseModel() {
        String configuredModelId = this.llmClient.getModelId();
        JsonArray candidates = this.getRawClient().listFoundationModelsForFineTuning();
        for (JsonElement candidate : candidates) {
            String candidateId = candidate.getAsJsonObject().get("modelId").getAsString();
            if (candidateId.startsWith(configuredModelId)) {
                return candidateId;
            }
        }
        return configuredModelId;
    }
}

