/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online;

import com.dataiku.dip.llm.EnrichedLLMStructuredRef;
import com.dataiku.dip.llm.LLMStructuredRef;
import com.dataiku.dip.llm.online.LLMChatMessageUtils;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.utils.AgentTrajectoryService;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.utils.JF;
import com.dataiku.dip.utils.JSON;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Objects;

public class LLMTracingUtils {
    public static void addLLMIdentifiers(LLMClient.LLMMeshTraceObservation observation, String llmId, LLMClient llmClient) {
        observation.attributes.addProperty("llmId", llmId);
        try {
            EnrichedLLMStructuredRef ref;
            if (llmClient != null && (ref = llmClient.getEnrichedRef()) != null && ref.type != null) {
                observation.attributes.addProperty("llmProvider", ref.type.toString());
                observation.attributes.addProperty("llmModel", ref.getModelNameForAudit());
            }
        }
        catch (Exception exception) {
            // empty catch block
        }
    }

    public static void setCompletionInput(LLMClient.LLMMeshTraceSpan span, LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings) {
        span.setCompletionLLMInput(query);
        span.attributes.add("completionQuery", (JsonElement)LLMChatMessageUtils.completionQueryToLightJsonObject(query));
        span.attributes.add("completionQuerySettings", (JsonElement)JSON.toJsonObject((Object)settings, (String[])new String[0]));
    }

    public static void addIdentifiersAndSetCompletionInput(LLMClient.LLMMeshTraceSpan span, String llmId, LLMClient client, LLMClient.SingleCompletionQuery query, LLMClient.CompletionSettings settings) {
        LLMTracingUtils.addLLMIdentifiers(span, llmId, client);
        LLMTracingUtils.setCompletionInput(span, query, settings);
    }

    public static void setEmbeddingInput(LLMClient.LLMMeshTraceSpan span, LLMClient.EmbeddingQuery query, LLMClient.EmbeddingSettings settings) {
        span.setEmbeddingLLMInput(query);
        span.attributes.add("embeddingQuery", (JsonElement)LLMChatMessageUtils.embeddingQueryToLightJsonObject(query));
        span.attributes.add("embeddingQuerySettings", (JsonElement)JSON.toJsonObject((Object)settings, (String[])new String[0]));
    }

    public static void setRerankingInput(LLMClient.LLMMeshTraceSpan span, LLMClient.RerankingQuery query, LLMClient.RerankingSettings settings) {
        span.setRerankingLLMInput(query);
        span.attributes.add("rerankingQuery", (JsonElement)LLMChatMessageUtils.rerankingQueryToLightJsonObject(query));
        span.attributes.add("rerankingQuerySettings", (JsonElement)JSON.toJsonObject((Object)settings, (String[])new String[0]));
    }

    public static void setCompletionOutput(LLMClient.LLMMeshTraceSpan span, LLMClient.SimpleCompletionResponseOrError r) {
        span.setLLMOutput(r);
        JsonObject completionResponseJO = JSON.toJsonObject((Object)r, (String[])new String[0]);
        if (completionResponseJO.has("trace")) {
            completionResponseJO.remove("trace");
        }
        span.attributes.add("completionResponse", (JsonElement)completionResponseJO);
    }

    public static void addUsageMetadataAndSetCompletionOutput(LLMClient.LLMMeshTraceSpan span, LLMClient.SimpleCompletionResponseOrError r) {
        LLMTracingUtils.setCompletionOutput(span, r);
        span.usageMetadata = new LLMClient.UsageMetadata(r);
    }

    public static void setEmbeddingOutput(LLMClient.LLMMeshTraceSpan span, LLMClient.SimpleEmbeddingResponseOrError result) {
        span.attributes.add("embeddingResponse", (JsonElement)JF.obj().with("ok", Boolean.valueOf(result.ok)).with("promptTokens", (Number)result.promptTokens).with("tokenCountsAreEstimated", result.tokenCountsAreEstimated).with("estimatedCost", (Number)result.estimatedCost).with("additionalInformation", (JsonElement)result.additionalInformation).get());
    }

    public static void setRerankingOutput(LLMClient.LLMMeshTraceSpan span, LLMClient.SingleRerankingResponseOrError result) {
        span.attributes.add("rerankingResponse", (JsonElement)JF.obj().with("ok", Boolean.valueOf(result.ok)).with("estimatedCost", (Number)result.estimatedCost).get());
    }

    public static void addUsageMetadataAndSetEmbeddingOutput(LLMClient.LLMMeshTraceSpan span, LLMClient.SimpleEmbeddingResponseOrError r) {
        LLMTracingUtils.setEmbeddingOutput(span, r);
        span.usageMetadata = new LLMClient.UsageMetadata(r);
    }

    public static void addUsageMetadataAndSetRerankingOutput(LLMClient.LLMMeshTraceSpan span, LLMClient.SingleRerankingResponseOrError r) {
        LLMTracingUtils.setRerankingOutput(span, r);
        span.usageMetadata = new LLMClient.UsageMetadata(r);
    }

    public static void addUsageMetadataFromError(LLMClient.LLMMeshTraceSpan span, Exception e) {
        if (e instanceof LLMClient.LLMException) {
            LLMClient.LLMException llmException = (LLMClient.LLMException)e;
            span.usageMetadata = new LLMClient.UsageMetadata(llmException);
        }
    }

    public static LLMClient.TotalUsage getTotalUsage(LLMClient.LLMMeshTraceObservation observation) {
        return LLMTracingUtils.getTotalUsage(observation, null);
    }

    private static LLMClient.TotalUsage getTotalUsage(LLMClient.LLMMeshTraceObservation observation, LLMClient.TotalUsageSpanType parentType) {
        LLMClient.LLMMeshTraceEvent event;
        LLMClient.TotalUsageSpanType aggregationType = LLMTracingUtils.getAggregationType(observation, parentType);
        LLMClient.TotalUsage totalUsage = new LLMClient.TotalUsage();
        if (observation instanceof LLMClient.LLMMeshTraceSpan) {
            LLMClient.LLMMeshTraceSpan span = (LLMClient.LLMMeshTraceSpan)observation;
            boolean updateIsNeeded = false;
            if (LLMTracingUtils.isLLMMeshCall(span)) {
                totalUsage.llmMeshCalls = 1;
                updateIsNeeded = true;
            }
            if (span.usageMetadata != null) {
                totalUsage.promptTokens = span.usageMetadata.promptTokens;
                totalUsage.completionTokens = span.usageMetadata.completionTokens;
                totalUsage.totalTokens = span.usageMetadata.totalTokens;
                totalUsage.estimatedCost = span.usageMetadata.estimatedCost;
                totalUsage.images = span.usageMetadata.images;
                Object[] variables = new Object[]{span.usageMetadata.promptTokens, span.usageMetadata.completionTokens, span.usageMetadata.totalTokens, span.usageMetadata.estimatedCost, span.usageMetadata.images};
                if (Arrays.stream(variables).anyMatch(Objects::nonNull)) {
                    updateIsNeeded = true;
                }
            }
            LLMTracingUtils.updateDetailsIfNeeded(aggregationType, totalUsage, updateIsNeeded);
            for (LLMClient.LLMMeshTraceObservation child : span.children) {
                totalUsage.aggregate(LLMTracingUtils.getTotalUsage(child, aggregationType));
            }
        } else if (observation instanceof LLMClient.LLMMeshTraceEvent && LLMTracingUtils.isCacheHit(event = (LLMClient.LLMMeshTraceEvent)observation)) {
            totalUsage.llmCacheHits = LLMClient.TotalUsage.safeAdd(totalUsage.llmCacheHits, 1);
            LLMTracingUtils.updateDetailsIfNeeded(aggregationType, totalUsage, true);
        }
        return totalUsage;
    }

    private static void updateDetailsIfNeeded(LLMClient.TotalUsageSpanType aggregationType, LLMClient.TotalUsage totalUsage, boolean updateIsNeeded) {
        if (aggregationType != null && updateIsNeeded) {
            totalUsage.details = new HashMap<LLMClient.TotalUsageSpanType, LLMClient.TotalUsage>();
            totalUsage.details.put(aggregationType, new LLMClient.TotalUsage(totalUsage));
        }
    }

    private static boolean isLLMMeshCall(LLMClient.LLMMeshTraceSpan span) {
        return LLMClient.TotalUsageSpanType.CALL_SPAN_NAMES.contains(span.name);
    }

    private static boolean isCacheHit(LLMClient.LLMMeshTraceObservation observation) {
        return "DKU_LLM_MESH_CACHE_HIT".equals(observation.name);
    }

    private static LLMClient.TotalUsageSpanType getAggregationType(LLMClient.LLMMeshTraceObservation observation, LLMClient.TotalUsageSpanType parentType) {
        LLMClient.TotalUsageSpanType currentType = LLMClient.TotalUsageSpanType.from(observation);
        if (currentType == null || parentType != null && parentType.priority.ordinal() > currentType.priority.ordinal()) {
            return parentType;
        }
        return currentType;
    }

    public static LLMClient.SimpleCompletionResponseOrError enrichResponseFromTraceData(AgentTrajectoryService agentTrajectoryService, AuthCtx authCtx, String projectKey, LLMClient.SimpleCompletionResponseOrError response, LLMClient.LLMMeshTraceSpan trace, LLMStructuredRef llmRef, LLMClient.CompletionSettings completionSettings) {
        response.totalUsage = LLMTracingUtils.getTotalUsage(trace);
        if (completionSettings != null && completionSettings.outputTrajectory.booleanValue()) {
            response = agentTrajectoryService.enrichResponseWithTrajectoryIfNeeded(authCtx, projectKey, response, trace, llmRef);
        }
        return response;
    }

    public static void enrichFooterFromTraceData(AgentTrajectoryService agentTrajectoryService, AuthCtx authCtx, String projectKey, LLMClient.StreamedCompletionResponseFooter footer, LLMStructuredRef llmRef, LLMClient.CompletionSettings completionSettings) {
        footer.totalUsage = LLMTracingUtils.getTotalUsage(footer.trace);
        if (completionSettings != null && completionSettings.outputTrajectory.booleanValue()) {
            agentTrajectoryService.enrichFooterWithTrajectoryIfNeeded(authCtx, projectKey, footer, footer.trace, llmRef);
        }
    }
}

