/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.aigenerations;

import com.dataiku.common.rpc.InternalAPIClient;
import com.dataiku.common.server.APIError;
import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.dao.GeneralSettingsDAO;
import com.dataiku.dip.exceptions.DKUSecurityException;
import com.dataiku.dip.futures.FuturePayload;
import com.dataiku.dip.futures.FutureResponse;
import com.dataiku.dip.futures.FutureService;
import com.dataiku.dip.futures.SimpleFutureThread;
import com.dataiku.dip.license.LicenseStatusService;
import com.dataiku.dip.security.AuthCtx;
import com.dataiku.dip.semanticsearch.SemanticSearchFacets;
import com.dataiku.dip.semanticsearch.SemanticSearchMessage;
import com.dataiku.dip.semanticsearch.SemanticSearchService;
import com.dataiku.dip.util.AIFeaturesUtil;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.j2ts.annotations.UIModel;
import com.google.common.base.Stopwatch;
import com.google.gson.reflect.TypeToken;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
import org.apache.log4j.NDC;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

/**
 * Augments a semantic-search query by calling the AI Server endpoint
 * {@code /semantic-search/augment-query} from an asynchronous future.
 *
 * <p>The AI Server answer ({@link AugmentQueryResponse}) may carry a reformulated query and
 * inclusive/exclusive keyword filters extracted from the user's input.
 */
@Service
public class AISemanticSearchQueryAugmentationService {
    @Autowired
    private FutureService futureService;
    @Autowired
    private LicenseStatusService licenseStatusService;
    private static final DKULogger logger = DKULogger.getLogger(AISemanticSearchQueryAugmentationService.class);

    /**
     * Checks the current licensing status, then starts a future that asks the AI Server to
     * augment {@code query}.
     *
     * @param authCtx            caller's security context; passed to the future thread, which also
     *                           uses it to build the AI Server client
     * @param query              raw user query to augment
     * @param allAvailableFacets facets forwarded to the AI Server along with the query
     * @param previousMessages   earlier messages of the conversation, forwarded to the AI Server
     * @param conversationId     conversation identifier, forwarded to the AI Server and used in logs
     * @return the {@link FutureResponse} returned by {@code FutureService.runFuture}, whose result
     *         is an {@link AugmentQueryResponse}
     * @throws Exception whatever {@code SemanticSearchService.checkAllowedToUse} or
     *                   {@code FutureService.runFuture} throw (presumably a licensing error when
     *                   semantic search is not allowed — confirm)
     */
    public FutureResponse<AugmentQueryResponse> augmentQuery(AuthCtx authCtx, String query, SemanticSearchFacets allAvailableFacets, List<SemanticSearchMessage> previousMessages, String conversationId) throws Exception {
        LicenseStatusService.LicensingStatus licensingStatus = this.licenseStatusService.getLicensingStatus();
        SemanticSearchService.checkAllowedToUse(licensingStatus);
        QueryAugmentationFutureThread futureThread = new QueryAugmentationFutureThread(authCtx, query, allAvailableFacets, licensingStatus, previousMessages, conversationId);
        return this.futureService.runFuture(futureThread, 0L, new TypeToken<FutureResponse<AugmentQueryResponse>>(){});
    }

    /**
     * Future body that performs the single AI Server call for one augmentation request.
     *
     * <p>General settings are snapshotted in the constructor, i.e. on the thread that creates
     * the future, not on the thread that later runs {@link #compute()}.
     */
    private static class QueryAugmentationFutureThread
    extends SimpleFutureThread<AugmentQueryResponse> {
        private final String query;
        private final LicenseStatusService.LicensingStatus licensingStatus;
        private final GeneralSettingsDAO.GeneralSettings generalSettings;
        protected final AuthCtx authCtx;
        private final List<SemanticSearchMessage> previousMessages;
        private final String conversationId;
        private final SemanticSearchFacets allAvailableFacets;

        public QueryAugmentationFutureThread(AuthCtx auth, String query, SemanticSearchFacets allAvailableFacets, LicenseStatusService.LicensingStatus licensingStatus, List<SemanticSearchMessage> previousMessages, String conversationId) {
            super(auth);
            this.authCtx = auth;
            this.query = query;
            this.licensingStatus = licensingStatus;
            this.generalSettings = ApplicationConfigurator.getGeneralSettingsUnsafeAutoTXN();
            this.previousMessages = previousMessages;
            this.conversationId = conversationId;
            this.allAvailableFacets = allAvailableFacets;
        }

        /** Describes this future (id and human-readable label) to the future framework. */
        public FuturePayload getPayload() {
            return FuturePayload.newSimple("ai_semantic_search_query_augmentation", "AI Search Query Augmentation");
        }

        /**
         * Sends the augmentation request to the AI Server and returns its answer.
         *
         * @return the AI Server response (possibly {@code null} if the server returned none), or
         *         a response with {@code ok == false} when the server reported an API error
         * @throws InterruptedException if the HTTP call failed because this thread was interrupted
         * @throws RuntimeException     wrapping any other security or I/O failure of the call
         */
        @Override
        protected AugmentQueryResponse compute() throws InterruptedException {
            NDC.push("ai-semantic-search-augment-query");
            // Everything after the push is inside try/finally so the NDC stack of this (pooled)
            // thread is always restored, even if building the request or logging throws.
            try {
                Stopwatch stopwatch = Stopwatch.createStarted();
                AugmentQueryRequest backendQuery = this.buildRequest();
                logger.infoV("Calling AI Server to augment query. ConvId: [%s], Facets provided: [%s]", new Object[]{this.conversationId, this.allAvailableFacets});
                return this.callAiServer(backendQuery, stopwatch);
            }
            finally {
                NDC.pop();
            }
        }

        /** Builds the request payload sent to the AI Server from this thread's captured state. */
        private AugmentQueryRequest buildRequest() {
            AugmentQueryRequest request = new AugmentQueryRequest();
            request.query = this.query;
            request.previousMessages = this.previousMessages;
            request.licenseId = this.licensingStatus != null && this.licensingStatus.licenseContent != null ? this.licensingStatus.licenseContent.licenseId : null;
            request.conversationId = this.conversationId;
            request.telemetryEnabled = this.generalSettings.aiDrivenAnalyticsSettings.telemetryEnabled;
            request.allAvailableFacets = this.allAvailableFacets;
            return request;
        }

        /**
         * Performs the HTTP call. The client is closed on every path by try-with-resources.
         * Failures while creating or closing the client are handled like failures of the call.
         */
        private AugmentQueryResponse callAiServer(AugmentQueryRequest backendQuery, Stopwatch stopwatch) throws InterruptedException {
            try (InternalAPIClient apiClient = AIFeaturesUtil.getAiServerAPIClient(this.authCtx, this.generalSettings, AIFeaturesUtil.CONNECTION_TIMEOUT, AIFeaturesUtil.SOCKET_TIMEOUT)) {
                AugmentQueryResponse response = (AugmentQueryResponse) apiClient.postObject("/semantic-search/augment-query", AugmentQueryResponse.class, backendQuery);
                this.logResponseOutcome(response, stopwatch);
                return response;
            }
            catch (APIError.APIErrorException e) {
                // API-level errors are reported to the caller as a failed response, not thrown.
                AugmentQueryResponse failure = new AugmentQueryResponse();
                failure.ok = false;
                failure.error = "Failed to augment query via AI Server: " + e.getMessage();
                logger.errorV("AI Server augmentation failed. Error: %s. AI-server call took %d milliseconds", new Object[]{e.getMessage(), stopwatch.elapsed(TimeUnit.MILLISECONDS)});
                return failure;
            }
            catch (DKUSecurityException | IOException e) {
                // An I/O failure caused by interruption (e.g. future aborted) is surfaced as an
                // InterruptedException; the interrupt flag is left set (isInterrupted does not clear it).
                if (Thread.currentThread().isInterrupted()) {
                    logger.warnV("Query augmentation interrupted for ConvId: [%s]. AI-server call took %d milliseconds", new Object[]{this.conversationId, stopwatch.elapsed(TimeUnit.MILLISECONDS)});
                    InterruptedException interrupted = new InterruptedException("Thread: " + Thread.currentThread().getName() + " got interrupted when executing: " + this.query);
                    interrupted.initCause(e);
                    throw interrupted;
                }
                logger.errorV(e, "AI Server request failed for ConvId: [%s]. AI-server call took %d milliseconds", new Object[]{this.conversationId, stopwatch.elapsed(TimeUnit.MILLISECONDS)});
                throw new RuntimeException("Failed to augment query via AI Server", e);
            }
        }

        /** Null-safe list size: {@code null} counts as empty. */
        private static int size(List<?> list) {
            return list == null ? 0 : list.size();
        }

        /** Sum of {@link #size(List)} over all given lists. */
        private static int totalSize(List<?>... lists) {
            int total = 0;
            for (List<?> list : lists) {
                total += size(list);
            }
            return total;
        }

        /**
         * Logs a one-line summary of the AI Server answer: null response (warn), server-side
         * failure (error), or reformulation plus inclusive/exclusive keyword counts (info).
         */
        private void logResponseOutcome(AugmentQueryResponse resp, Stopwatch stopwatch) {
            if (resp == null) {
                logger.warnV("AI Server returned null response for ConvId: [%s]. AI-server call took %d milliseconds", new Object[]{this.conversationId, stopwatch.elapsed(TimeUnit.MILLISECONDS)});
                return;
            }
            if (!resp.ok) {
                logger.errorV("AI Server augmentation failed. Error: %s. AI-server call took %d milliseconds", new Object[]{resp.error, stopwatch.elapsed(TimeUnit.MILLISECONDS)});
                return;
            }
            String reform = resp.reformulation != null ? String.format("Valid: %b, NewQuery: %s", resp.reformulation.isValidSearch, resp.reformulation.reformulatedQuery) : "N/A";
            int inclusiveKeywordCount = 0;
            int exclusiveKeywordCount = 0;
            KeywordExtractionResponse kw = resp.extractedKeywords;
            if (kw != null) {
                // Note: the free-form `type` list is not part of either count.
                inclusiveKeywordCount = totalSize(kw.datasetNames, kw.columnNames, kw.tags, kw.projectKeys, kw.descriptionKeywords, kw.dataStewards, kw.datasetTypes, kw.connectionTypes, kw.connectionNames, kw.catalogNames, kw.schemaNames);
                exclusiveKeywordCount = totalSize(kw.excludedCatalogNames, kw.excludedConnectionNames, kw.excludedConnectionTypes, kw.excludedDatasetNames, kw.excludedDatasetTypes, kw.excludedDataStewards, kw.excludedProjectKeys, kw.excludedSchemaNames, kw.excludedTags);
            }
            logger.infoV("Augmentation successful. Reformulation: [%s], Total inclusive Keywords: [%d], Total exclusive Keywords [%d]. AI-server call took %d milliseconds", new Object[]{reform, inclusiveKeywordCount, exclusiveKeywordCount, stopwatch.elapsed(TimeUnit.MILLISECONDS)});
        }
    }

    /**
     * Result of an augmentation call. When {@code ok} is false, {@code error} holds a message.
     * When {@code ok} is true, {@code reformulation} and/or {@code extractedKeywords} may be set.
     */
    @UIModel
    public static class AugmentQueryResponse {
        public boolean ok;
        public String error;
        public QueryReformulationResponse reformulation;
        public KeywordExtractionResponse extractedKeywords;
    }

    /**
     * Keyword filters extracted by the AI Server. Any list may be {@code null} (the logging code
     * treats {@code null} as empty). The {@code excluded*} lists are negative filters.
     */
    public static class KeywordExtractionResponse {
        public List<LuceneKeyword> datasetNames;
        public List<LuceneKeyword> columnNames;
        public List<LuceneKeyword> tags;
        public List<LuceneKeyword> projectKeys;
        public List<LuceneKeyword> dataStewards;
        public List<LuceneKeyword> descriptionKeywords;
        @Nullable
        public Boolean isInDataCollection;
        @Nullable
        public Boolean isPartitioned;
        @Nullable
        public List<String> type;
        public List<String> datasetTypes;
        public List<String> catalogNames;
        public List<String> connectionNames;
        public List<String> connectionTypes;
        public List<String> schemaNames;
        public List<String> excludedDatasetNames;
        public List<String> excludedDatasetTypes;
        public List<String> excludedCatalogNames;
        public List<String> excludedConnectionNames;
        public List<String> excludedConnectionTypes;
        public List<String> excludedSchemaNames;
        public List<String> excludedTags;
        public List<String> excludedProjectKeys;
        public List<String> excludedDataStewards;
    }

    /** AI Server verdict on whether the query is a valid search, plus its reformulated text. */
    public static class QueryReformulationResponse {
        public boolean isValidSearch;
        public String reformulatedQuery;
    }

    /** A keyword value with a weight (presumably used as a Lucene query boost — confirm). */
    public static class LuceneKeyword {
        public String value;
        public float boost;
    }

    /** JSON body posted to {@code /semantic-search/augment-query}. */
    private static class AugmentQueryRequest {
        String query;
        SemanticSearchFacets allAvailableFacets;
        List<SemanticSearchMessage> previousMessages;
        String licenseId;
        String conversationId;
        boolean telemetryEnabled;

        private AugmentQueryRequest() {
        }
    }
}

