/*
 * Decompiled with CFR 0.152.
 */
package com.dataiku.dip.llm.online.huggingface;

import com.dataiku.common.rpc.ExternalJSONAPIClient;
import com.dataiku.dip.ProxySettings;
import com.dataiku.dip.connections.AbstractLLMConnection;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.marshall.CoreCompletionSettings;
import com.dataiku.dip.llm.online.sagemakergeneric.GenericLLMHandling;
import com.dataiku.dip.llm.online.sagemakergeneric.GenericTextCompletionLLMMarshall;
import com.dataiku.dip.llm.utils.OnlineLLMUtils;
import com.dataiku.dip.utils.DKULogger;
import com.google.gson.JsonArray;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import java.io.IOException;
import java.util.List;

/**
 * Minimal client for the Hugging Face hosted Inference API, used for raw text completion.
 *
 * <p>Requests go through an {@link ExternalJSONAPIClient} created by
 * {@link OnlineLLMUtils#getExternalJSONClientWithRetryStrategy} against the base URL
 * {@code https://api-inference.huggingface.co/models}, with a bearer-token
 * {@code Authorization} header. Each instance holds its own underlying client; call
 * {@link #close()} (or use try-with-resources) to release it.
 */
public class RawHuggingFaceInferenceAPIClient implements AutoCloseable {
    private static final String DEFAULT_ENDPOINT_BASE = "https://api-inference.huggingface.co/models";

    // Currently unused in this class; kept for diagnostics. Made final: it is never reassigned.
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.huggingface.client");

    /**
     * Underlying JSON HTTP client. Left package-private and non-final as in the original,
     * since other code in this package may access it directly.
     */
    ExternalJSONAPIClient client;

    /**
     * Creates a client bound to the Hugging Face Inference API base URL.
     *
     * @param apiKey Hugging Face API token, sent as {@code "Bearer " + apiKey}. It is not
     *     validated here. NOTE(review): a null key would produce the header value
     *     {@code "Bearer null"}. Callers should confirm they never pass null.
     * @param proxySettings proxy configuration forwarded unchanged to the client factory
     * @param forceContentLength when {@code true}, sets the underlying client's
     *     {@code forceContentLength} flag. When {@code false}, the flag keeps whatever value
     *     the client factory gave it.
     */
    public RawHuggingFaceInferenceAPIClient(String apiKey, ProxySettings proxySettings, boolean forceContentLength) {
        // Default-constructed settings: nothing is customised by this client.
        AbstractLLMConnection.HTTPBasedLLMNetworkSettings networkSettings = new AbstractLLMConnection.HTTPBasedLLMNetworkSettings();
        this.client = OnlineLLMUtils.getExternalJSONClientWithRetryStrategy(DEFAULT_ENDPOINT_BASE, null, false, proxySettings, networkSettings);
        this.client.addHeader("Authorization", "Bearer " + apiKey);
        // Only ever switch the flag on; the factory's default is preserved otherwise.
        if (forceContentLength) {
            this.client.forceContentLength = true;
        }
    }

    /** Closes the underlying {@link ExternalJSONAPIClient}. */
    @Override
    public void close() {
        this.client.close();
    }

    /**
     * Runs a text completion of {@code prompt} against the given model.
     *
     * <p>The request body is built, and the response parsed, by the Hugging Face flavour of
     * {@link GenericTextCompletionLLMMarshall}. The response body is deserialized as a JSON
     * array. A non-array response (for instance an error object) will presumably fail inside
     * the client. TODO confirm how {@code postObjectToJSON} reports that.
     *
     * @param model model identifier, posted as the request path. Presumably it is resolved
     *     relative to the base URL. Verify against {@code ExternalJSONAPIClient}.
     * @param prompt prompt text to complete
     * @param ccs completion settings used to build the request
     * @return the completion parsed from the API response
     * @throws IOException if the HTTP request or response handling fails
     */
    public LLMClient.SimpleCompletionResponse complete(String model, String prompt, CoreCompletionSettings ccs) throws IOException {
        // The null second argument is forwarded as-is; its meaning is defined by the marshall factory.
        GenericTextCompletionLLMMarshall marshall = GenericTextCompletionLLMMarshall.get(GenericLLMHandling.HUGGING_FACE, null);
        JsonObject query = marshall.prepareTextCompletionQuery(prompt, ccs);
        // Casts are kept explicit: they may select specific overloads of the called methods.
        JsonArray response = (JsonArray)this.client.postObjectToJSON(model, JsonArray.class, (Object)query);
        return marshall.parseTextCompletionResponse((JsonElement)response, prompt);
    }

    /**
     * A single generation entry. The field name is snake_case, presumably so that it mirrors
     * the JSON key {@code generated_text} during reflective deserialization.
     */
    public static class RawGenerationGeneration {
        public String generated_text;
    }
}

