/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.resourceusage;

import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.DKUApp;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.connections.ConnectionsDAO;
import com.dataiku.dip.connections.DSSConnection;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.resourceusage.ComputeResourceUsageContext;
import com.dataiku.dip.server.services.PubSubService;
import com.dataiku.dip.server.services.TransactionService;
import com.dataiku.dip.transactions.fs.RelFile;
import com.dataiku.dip.transactions.fs.ifaces.ReadWriteFS;
import com.dataiku.dip.transactions.fs.utils.NativeCache;
import com.dataiku.dip.transactions.ifaces.Transaction;
import com.dataiku.dip.utils.DKUDateUtils;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelib.org.apache.commons.codec.digest.DigestUtils;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import javax.annotation.PostConstruct;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

/**
 * Aggregates compute-resource-usage (CRU) events received on the
 * {@code compute-resource-usage-to-aggregate} pub/sub topic into per-day, per-type counters
 * (SQL connections, SQL queries and LLM completions), and periodically writes the aggregation
 * to {@code caches/cru-aggregation.json} through a {@link NativeCache} built on the DSS base folder.
 *
 * <p>{@link #data} and everything reachable from it are only read or mutated while holding this
 * instance's monitor. Logins and project keys are stored as MD5 hex digests, never in clear.
 */
@Service
public class BackendComputeResourceUsageAggregationService {
    /** Durations at or above 7 days are considered bogus and excluded from time totals. */
    private static final long MAX_PLAUSIBLE_DURATION_MS = 7L * 24L * 3600L * 1000L;
    private static final long MS_PER_DAY = 86400L * 1000L;
    private static final RelFile dataFile = new RelFile(RelFile.global("caches"), new String[]{"cru-aggregation.json"});
    private static final DKULogger logger = DKULogger.getLogger("dku.cru.aggregation");

    @Autowired
    PubSubService pubSubService;
    @Autowired
    private ConnectionsDAO connectionsDAO;
    @Autowired
    private TransactionService transactionsService;

    private ReadWriteFS cache;
    /** Guarded by {@code this}. Stays null when aggregation is disabled. */
    private ComputeResourceUsageAggregation data;

    /**
     * Loads the persisted aggregation, schedules the periodic flusher and subscribes to CRU events.
     *
     * <p>Does nothing when {@code dku.usageReporting.cruAggregation.enabled} is false.
     *
     * @throws IOException if the native cache cannot be built
     */
    @PostConstruct
    public void init() throws IOException {
        if (!DKUApp.getParams().getBoolParam("dku.usageReporting.cruAggregation.enabled", true)) {
            return;
        }
        this.cache = NativeCache.build(ApplicationConfigurator.getBaseFolderF());
        this.data = this.loadPersistedData();

        // A non-positive period would make scheduleAtFixedRate throw and fail bean initialization.
        // NOTE(review): this executor is never shut down; consider a @PreDestroy hook.
        long flushIntervalMinutes = Math.max(1, DKUApp.getParams().getIntParam("dku.usageReporting.cruAggregation.flushIntervalMinutes", Integer.valueOf(1)));
        Executors.newSingleThreadScheduledExecutor(new ThreadFactoryBuilder().setNameFormat("cru-aggregation-flusher-%d").build())
                .scheduleAtFixedRate(this::flushSafely, 1L, flushIntervalMinutes, TimeUnit.MINUTES);

        this.pubSubService.subscribe("compute-resource-usage-to-aggregate", evt -> {
            logger.trace(() -> "Aggregate: " + JSON.json(evt.cru));
            ComputeResourceUsageContext safeCtx = evt.cru.context;
            if (safeCtx == null) {
                safeCtx = new ComputeResourceUsageContext();
                safeCtx.type = ComputeResourceUsageContext.ComputeResourceUsageContextType.UNKNOWN;
            }
            if (evt.cru.type == null) {
                return;
            }
            switch (evt.cru.type) {
                case SQL_QUERY:
                    this.aggregateSQLQuery(evt.cru, safeCtx);
                    break;
                case SQL_CONNECTION:
                    this.aggregateSQLConnection(evt.cru, safeCtx);
                    break;
                case LLM_USAGE:
                    this.aggregateLLMUsage(evt.cru, safeCtx);
                    break;
                default:
                    // Other CRU types are not aggregated.
                    break;
            }
        });
    }

    /** Reads the persisted aggregation, falling back to an empty one if it is missing or unreadable. */
    private ComputeResourceUsageAggregation loadPersistedData() {
        try {
            ComputeResourceUsageAggregation loaded = (ComputeResourceUsageAggregation) this.cache.readObjectDefault(dataFile, ComputeResourceUsageAggregation.class);
            if (loaded != null) {
                return loaded;
            }
        } catch (Exception e) {
            logger.warn("Failed to read persisted CRU aggregation data, starting from an empty aggregation", e);
        }
        return new ComputeResourceUsageAggregation();
    }

    /**
     * Scheduled-task entry point around {@link #flush()}.
     *
     * <p>Any exception escaping a {@code scheduleAtFixedRate} task silently cancels all of its
     * later executions, so runtime failures are logged here instead of being propagated.
     */
    private void flushSafely() {
        try {
            this.flush();
        } catch (RuntimeException e) {
            logger.warn("Unexpected failure while flushing CRU aggregation data", e);
        }
    }

    /**
     * Drops days older than {@code dku.usageReporting.cruAggregation.maxDays} (default 180), then
     * writes the aggregation to the cache file. Write failures are logged, not propagated.
     */
    private void flush() {
        logger.debug("Flushing CRU");
        synchronized (this) {
            long maxDays = DKUApp.getParams().getLongParam("dku.usageReporting.cruAggregation.maxDays", 180L);
            long minDate = System.currentTimeMillis() - maxDays * MS_PER_DAY;
            this.data.days.keySet().removeIf(day -> shouldPruneDay(day, minDate));
            try {
                this.cache.writeObject(dataFile, this.data);
            } catch (IOException e) {
                logger.warn("Failed to flush CRU aggregation data", e);
            }
        }
    }

    /**
     * Tells whether the aggregation for {@code day} falls before {@code minDate}.
     *
     * @param day     a day key, expected to be {@code yyyy-MM-dd}; it is parsed as UTC midnight
     * @param minDate epoch-millis retention cutoff
     * @return true if the day should be removed
     */
    private static boolean shouldPruneDay(String day, long minDate) {
        long dayDate;
        try {
            dayDate = DKUDateUtils.parseISOUTC(day + "T00:00:00.000Z");
        } catch (RuntimeException e) {
            // A key that cannot be dated can never age out; drop it instead of failing every flush.
            logger.warn("Pruning aggregation for unparseable day: " + day, e);
            return true;
        }
        if (dayDate >= minDate) {
            return false;
        }
        logger.debug("Pruning aggregation for day: " + day);
        return true;
    }

    /**
     * Adds one SQL query to today's counters for the query's connection type.
     * Does nothing if the CRU has no SQL query payload.
     */
    private void aggregateSQLQuery(ComputeResourceUsage cru, ComputeResourceUsageContext safeContext) throws IOException {
        ComputeResourceUsage.SQLQueryUsageData squd = cru.sqlQuery;
        if (squd == null) {
            return;
        }
        String type = this.getSQLType(squd.connection);
        long durationMS = cru.endTime - cru.startTime;
        synchronized (this) {
            SQLTypeResourceConsumption typeConsumption = this.getSQLByType(type);
            this.registerLoginAndProject(typeConsumption, safeContext);
            typeConsumption.totalQueries++;
            inc(typeConsumption.queriesPerContextType, safeContext.type, 1L);
            if (isPlausibleDuration(durationMS)) {
                typeConsumption.totalQueriesTimeMS += durationMS;
                inc(typeConsumption.queriesTimePerContextType, safeContext.type, durationMS);
            }
            if (squd.fetchedRowCount != null && squd.fetchedRowCount > 0L) {
                typeConsumption.totalFetchedRows += squd.fetchedRowCount;
                inc(typeConsumption.fetchedRowsPerContextType, safeContext.type, squd.fetchedRowCount);
            }
        }
    }

    /**
     * Adds one SQL connection to today's counters for the connection's type.
     * Does nothing if the CRU has no SQL connection payload.
     */
    private void aggregateSQLConnection(ComputeResourceUsage cru, ComputeResourceUsageContext safeContext) throws IOException {
        ComputeResourceUsage.SQLConnectionUsageData scud = cru.sqlConnection;
        if (scud == null) {
            return;
        }
        String type = this.getSQLType(scud.connection);
        long durationMS = cru.endTime - cru.startTime;
        synchronized (this) {
            SQLTypeResourceConsumption typeConsumption = this.getSQLByType(type);
            this.registerLoginAndProject(typeConsumption, safeContext);
            typeConsumption.totalConnections++;
            inc(typeConsumption.connectionsPerContextType, safeContext.type, 1L);
            if (isPlausibleDuration(durationMS)) {
                typeConsumption.totalConnectionsTimeMS += durationMS;
                inc(typeConsumption.connectionTimePerContextType, safeContext.type, durationMS);
            }
        }
    }

    /**
     * Adds one LLM usage report to today's counters, keyed by the connection's detailed model type.
     * Does nothing if there is no LLM payload or the connection is not a known LLM connection.
     */
    private void aggregateLLMUsage(ComputeResourceUsage cru, ComputeResourceUsageContext safeContext) throws IOException {
        ComputeResourceUsage.LLMUsageData lud = cru.llmUsage;
        if (lud == null) {
            return;
        }
        DSSConnection conn = this.getConnection(lud.connection);
        // instanceof is false for null, which covers unknown connections too.
        if (!(conn instanceof AbstractLLMConnection)) {
            return;
        }
        String detailedTypeForAggregation = ((AbstractLLMConnection) conn).getModelDetailedTypeForCRUAggregation(lud.llmId);
        synchronized (this) {
            LLMTypeResourceConsumption typeConsumption = this.getLLMByType(detailedTypeForAggregation);
            this.registerLoginAndProject(typeConsumption, safeContext);
            typeConsumption.totalCompletionReports++;
            typeConsumption.totalCompletionQueries += lud.totalQueries;
            inc(typeConsumption.completionQueriesPerContextType, safeContext.type, lud.totalQueries);
            typeConsumption.totalCompletionCacheHits += lud.cacheHitQueries;
            inc(typeConsumption.completionCacheHitsPerContextType, safeContext.type, lud.cacheHitQueries);
            typeConsumption.totalCompletionCacheMisses += lud.cacheMissQueries;
            inc(typeConsumption.completionCacheMissesPerContextType, safeContext.type, lud.cacheMissQueries);
            typeConsumption.totalCompletionPromptTokens += lud.getTotalPromptTokens();
            inc(typeConsumption.completionsPromptTokensPerContextType, safeContext.type, lud.getTotalPromptTokens());
            typeConsumption.totalCompletionCompletionTokens += lud.getTotalCompletionTokens();
            inc(typeConsumption.completionsCompletionTokensPerContextType, safeContext.type, lud.getTotalCompletionTokens());
            typeConsumption.totalCompletionCostUSD += lud.getEstimatedCostUSD();
            inc(typeConsumption.completionCostUSDPerContextType, safeContext.type, lud.getEstimatedCostUSD());
            typeConsumption.totalCompletionComputationTimeMS += lud.getTotalComputationTimeMS();
            inc(typeConsumption.completionComputationTimeMSPerContextType, safeContext.type, lud.getTotalComputationTimeMS());
        }
    }

    /**
     * Returns a {@link JSON#deepCopy} snapshot of the current aggregation, taken under the lock.
     *
     * <p>NOTE(review): when aggregation is disabled {@code data} is null; the result then depends
     * on how {@code JSON.deepCopy} handles null.
     */
    public synchronized ComputeResourceUsageAggregation getData() {
        return (ComputeResourceUsageAggregation) JSON.deepCopy(this.data);
    }

    /**
     * Records the hashed login and project of {@code safeContext} in {@code ret}.
     * The login is lower-cased (English locale) before hashing. Missing values are recorded as
     * {@code unknown_user} / {@code unknown_project}.
     */
    private void registerLoginAndProject(WithDistinctLoginsAndProjects ret, ComputeResourceUsageContext safeContext) {
        String loginH = safeContext.authIdentifier != null ? DigestUtils.md5Hex(safeContext.authIdentifier.toLowerCase(Locale.ENGLISH)) : "unknown_user";
        String projectKeyH = safeContext.projectKey != null ? DigestUtils.md5Hex(safeContext.projectKey) : "unknown_project";
        ret.distinctLoginHs.add(loginH);
        ret.distinctProjectHs.add(projectKeyH);
    }

    /** Returns the type of the named SQL connection, or {@code unknown_sql} if it cannot be resolved. */
    private String getSQLType(String connectionName) {
        DSSConnection conn = this.getConnection(connectionName);
        return conn == null ? "unknown_sql" : conn.type;
    }

    /**
     * Looks up a connection by name inside a read transaction.
     *
     * <p>This is best-effort: it returns null when the name is null, the connection does not
     * exist, or the connections cannot be listed (that last case is logged).
     */
    private DSSConnection getConnection(String connectionName) {
        if (connectionName == null) {
            return null;
        }
        Map<?, ?> conns;
        try (Transaction t = this.transactionsService.beginRead()) {
            conns = this.connectionsDAO.listUnsafe();
        } catch (IOException e) {
            logger.warn("Failed to list connections for CRU aggregation, connection=" + connectionName, e);
            return null;
        }
        return (DSSConnection) conns.get(connectionName);
    }

    /** Returns today's SQL consumption entry for {@code type}, creating it if absent. Caller must hold the lock. */
    private SQLTypeResourceConsumption getSQLByType(String type) {
        assert (Thread.holdsLock(this));
        return this.getCurrentDay().sqlByType.computeIfAbsent(type, k -> new SQLTypeResourceConsumption());
    }

    /** Returns today's LLM consumption entry for {@code type}, creating it if absent. Caller must hold the lock. */
    private LLMTypeResourceConsumption getLLMByType(String type) {
        assert (Thread.holdsLock(this));
        return this.getCurrentDay().llmByType.computeIfAbsent(type, k -> new LLMTypeResourceConsumption());
    }

    /**
     * Returns today's aggregation bucket, creating it if absent. Caller must hold the lock.
     *
     * <p>The key is the first 10 characters of {@code DKUDateUtils.isoFormatLocalNow()}, presumably
     * {@code yyyy-MM-dd} in local time. NOTE(review): pruning parses keys as UTC midnight.
     */
    private ComputeResourceUsageAggregationDay getCurrentDay() {
        assert (Thread.holdsLock(this));
        String day = DKUDateUtils.isoFormatLocalNow().substring(0, 10);
        return this.data.days.computeIfAbsent(day, k -> new ComputeResourceUsageAggregationDay());
    }

    /** Returns true for durations that are positive and shorter than 7 days. */
    private static boolean isPlausibleDuration(long durationMS) {
        return durationMS > 0L && durationMS < MAX_PLAUSIBLE_DURATION_MS;
    }

    /** Adds {@code incValue} to the counter stored under {@code key.toString()}, starting from 0. */
    private static void inc(Map<String, Long> map, Enum<?> key, Long incValue) {
        map.merge(key.toString(), incValue, Long::sum);
    }

    /** Adds {@code incValue} to the sum stored under {@code key.toString()}, starting from 0.0. */
    private static void inc(Map<String, Double> map, Enum<?> key, Double incValue) {
        map.merge(key.toString(), incValue, Double::sum);
    }

    /** Root of the persisted aggregation: day key to that day's consumption. */
    public static class ComputeResourceUsageAggregation {
        String version = "1";
        Map<String, ComputeResourceUsageAggregationDay> days = new HashMap<String, ComputeResourceUsageAggregationDay>();
    }

    /** Per-day, per-SQL-connection-type counters; {@code *PerContextType} maps are keyed by context type name. */
    static class SQLTypeResourceConsumption
    extends WithDistinctLoginsAndProjects {
        public long totalConnections;
        public Map<String, Long> connectionsPerContextType = new HashMap<String, Long>();
        public long totalQueries;
        public Map<String, Long> queriesPerContextType = new HashMap<String, Long>();
        public long totalConnectionsTimeMS;
        public Map<String, Long> connectionTimePerContextType = new HashMap<String, Long>();
        public long totalQueriesTimeMS;
        public Map<String, Long> queriesTimePerContextType = new HashMap<String, Long>();
        public long totalFetchedRows;
        public Map<String, Long> fetchedRowsPerContextType = new HashMap<String, Long>();
    }

    /** Distinct MD5-hashed logins and project keys seen for an aggregation entry. */
    static class WithDistinctLoginsAndProjects {
        Set<String> distinctLoginHs = new HashSet<String>();
        Set<String> distinctProjectHs = new HashSet<String>();
    }

    /** Per-day, per-LLM-model-type counters; {@code *PerContextType} maps are keyed by context type name. */
    static class LLMTypeResourceConsumption
    extends WithDistinctLoginsAndProjects {
        public long totalCompletionReports;
        public long totalCompletionQueries;
        public Map<String, Long> completionQueriesPerContextType = new HashMap<String, Long>();
        public long totalCompletionCacheHits;
        public Map<String, Long> completionCacheHitsPerContextType = new HashMap<String, Long>();
        public long totalCompletionCacheMisses;
        public Map<String, Long> completionCacheMissesPerContextType = new HashMap<String, Long>();
        public long totalCompletionPromptTokens;
        public Map<String, Long> completionsPromptTokensPerContextType = new HashMap<String, Long>();
        public long totalCompletionCompletionTokens;
        public Map<String, Long> completionsCompletionTokensPerContextType = new HashMap<String, Long>();
        public double totalCompletionCostUSD;
        public Map<String, Double> completionCostUSDPerContextType = new HashMap<String, Double>();
        public long totalCompletionComputationTimeMS;
        public Map<String, Long> completionComputationTimeMSPerContextType = new HashMap<String, Long>();
    }

    /** All consumption recorded for a single day, split by SQL connection type and LLM model type. */
    static class ComputeResourceUsageAggregationDay {
        Map<String, SQLTypeResourceConsumption> sqlByType = new HashMap<String, SQLTypeResourceConsumption>();
        Map<String, LLMTypeResourceConsumption> llmByType = new HashMap<String, LLMTypeResourceConsumption>();
    }
}

