package com.customllm.llm;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import org.apache.commons.lang.StringUtils;

import com.customllm.llm.CustomApiImplementation;
import com.customllm.llm.GptApiImplementation;
import com.dataiku.common.rpc.ExternalJSONAPIClient;
import com.dataiku.dip.ApplicationConfigurator;
import com.dataiku.dip.connections.AbstractLLMConnection.HTTPBasedLLMNetworkSettings;
import com.dataiku.dip.connections.OpenAIConnection;
import com.dataiku.dip.custom.PluginSettingsResolver.ResolvedSettings;
import com.dataiku.dip.llm.custom.CustomLLMClient;
import com.dataiku.dip.llm.online.LLMChatMessageUtils;
import com.dataiku.dip.llm.online.LLMClient;
import com.dataiku.dip.llm.online.LLMClient.ChatMessage;
import com.dataiku.dip.llm.online.LLMClient.CompletionQuery;
import com.dataiku.dip.llm.online.LLMClient.EmbeddingQuery;
import com.dataiku.dip.llm.online.LLMClient.SimpleCompletionResponse;
import com.dataiku.dip.llm.online.LLMClient.SimpleEmbeddingResponse;
import com.dataiku.dip.llm.online.LLMClient.StreamedCompletionResponseConsumer;
import com.dataiku.dip.llm.promptstudio.PromptStudio.LLMStructuredRef;
import com.dataiku.dip.llm.utils.OnlineLLMUtils;
import com.dataiku.dip.resourceusage.ComputeResourceUsage;
import com.dataiku.dip.resourceusage.ComputeResourceUsage.InternalLLMUsageData;
import com.dataiku.dip.resourceusage.ComputeResourceUsage.LLMUsageData;
import com.dataiku.dip.resourceusage.ComputeResourceUsage.LLMUsageType;
import com.dataiku.dip.utils.DKULogger;
import com.dataiku.dip.utils.JF;
import com.dataiku.dip.utils.JF.ObjectBuilder;
import com.dataiku.dip.utils.JSON;
import com.dataiku.dss.shadelib.org.apache.http.client.CookieStore;
import com.dataiku.dss.shadelib.org.apache.http.client.config.CookieSpecs;
import com.dataiku.dss.shadelib.org.apache.http.client.config.RequestConfig;
import com.dataiku.dss.shadelib.org.apache.http.client.methods.HttpDelete;
import com.dataiku.dss.shadelib.org.apache.http.client.methods.HttpGet;
import com.dataiku.dss.shadelib.org.apache.http.client.methods.HttpPost;
import com.dataiku.dss.shadelib.org.apache.http.client.methods.HttpPut;
import com.dataiku.dss.shadelib.org.apache.http.impl.client.BasicCookieStore;
import com.dataiku.dss.shadelib.org.apache.http.impl.client.HttpClientBuilder;
import com.dataiku.dss.shadelib.org.apache.http.impl.client.LaxRedirectStrategy;
import com.google.gson.*;


/**
 * Custom LLM client that forwards chat-completion and embedding queries to the HTTP endpoint
 * configured in the plugin settings ({@code endpoint_url}).
 *
 * <p>Chat completions (batch and streamed) are delegated to a {@link GptApiImplementation};
 * embeddings are posted directly as a JSON object with {@code input}, {@code model} and
 * {@code input_type} members. Elapsed time and token counts of batch calls are accumulated
 * in {@link #usageData} and exposed through {@link #getTotalCRU}.
 *
 * <p>NOTE(review): streamed completions are not recorded in {@link #usageData}; confirm
 * whether that is intended.
 */
public class NvidiaNIMPlugin extends CustomLLMClient {
    private static final DKULogger logger = DKULogger.getLogger("dku.llm.customplugin.nvidia.nim");

    /** Value sent as {@code input_type} for embeddings when none is configured or passed. */
    private static final String DEFAULT_INPUT_TYPE = "query";

    private String endpointUrl;
    private String model;
    private ResolvedSettings rs;
    private String inputType;
    private ExternalJSONAPIClient client;
    private final InternalLLMUsageData usageData = new LLMUsageData();
    private final HTTPBasedLLMNetworkSettings networkSettings = new HTTPBasedLLMNetworkSettings();
    private int maxParallel = 1;

    /**
     * Guards every update of {@link #usageData}. {@code completeBatch} is synchronized on
     * the instance but {@code embedBatch} is not, so both take this dedicated lock for the
     * (short) counter update instead of relying on the instance monitor.
     */
    private final Object usageLock = new Object();

    public NvidiaNIMPlugin() {
    }

    /**
     * Builds the {@code Authorization} header value from the {@code apikeys.api_key} setting.
     *
     * @return {@code "Bearer <key>"} with the key trimmed, or {@code null} when the
     *         {@code apikeys} object or its {@code api_key} member is absent, JSON null,
     *         or blank
     */
    private String getAccessToken() {
        JsonElement apiKeysElement = this.rs.config.get("apikeys");
        if (apiKeysElement == null || !apiKeysElement.isJsonObject()) {
            return null;
        }
        JsonElement apiKeyElement = apiKeysElement.getAsJsonObject().get("api_key");
        if (apiKeyElement == null || apiKeyElement.isJsonNull()) {
            return null;
        }
        String apiKey = apiKeyElement.getAsString().trim();
        // A blank key would produce a credential-less "Bearer " header; treat it as "no key".
        return apiKey.isEmpty() ? null : "Bearer " + apiKey;
    }

    /**
     * Reads a mandatory plugin setting.
     *
     * @param settings resolved plugin settings
     * @param key      name of the member to read from {@code settings.config}
     * @return the non-null, non-JSON-null element stored under {@code key}
     * @throws IllegalArgumentException if the setting is missing or JSON null
     */
    private static JsonElement requiredSetting(ResolvedSettings settings, String key) {
        JsonElement value = settings.config.get(key);
        if (value == null || value.isJsonNull()) {
            throw new IllegalArgumentException("Missing required plugin setting: " + key);
        }
        return value;
    }

    // The DTO classes below mirror JSON response shapes. Only the OpenAIEmbedding* pair is
    // referenced by this class; the others are kept as-is.

    private static class RawChatCompletionMessage {
        String role;
        String content;
    }

    private static class RawChatCompletionChoice {
        RawChatCompletionMessage message;
    }

    /** Token counters as found in the {@code usage} member of a response. */
    private static class RawUsageResponse {
        int total_tokens;
        int prompt_tokens;
        int completion_tokens;
    }

    private static class RawChatCompletionResponse {
        List<RawChatCompletionChoice> choices;
        RawUsageResponse usage;
    }

    private static class EmbeddingResponse {
        List<EmbeddingResult> data = new ArrayList<>();
        RawUsageResponse usage;
    }

    private static class EmbeddingResult {
        double[] embedding;
    }

    /** Target type used to deserialize the response of the embedding call in {@link #embed}. */
    private static class OpenAIEmbeddingResponse {
        List<OpenAIEmbeddingResult> data = new ArrayList<>();
        RawUsageResponse usage;
    }

    private static class OpenAIEmbeddingResult {
        double[] embedding;
    }

    /**
     * Reads the plugin settings and creates the HTTP client used for all requests.
     *
     * <p>Required settings: {@code endpoint_url}, {@code model}, {@code maxParallelism},
     * {@code networkTimeout}, {@code maxRetries}, {@code firstRetryDelay},
     * {@code retryDelayScale}. {@code inputType} is optional: when absent it is stored as
     * an empty string, which {@link #embed} replaces with {@code "query"}.
     *
     * @param settings resolved plugin settings
     * @throws IllegalArgumentException if a required setting is missing or JSON null
     */
    @Override
    public void init(ResolvedSettings settings) {
        this.rs = settings;
        this.endpointUrl = requiredSetting(settings, "endpoint_url").getAsString();
        this.model = requiredSetting(settings, "model").getAsString();

        JsonElement configuredInputType = settings.config.get("inputType");
        this.inputType = (configuredInputType == null || configuredInputType.isJsonNull())
                ? ""
                : configuredInputType.getAsString();

        this.maxParallel = requiredSetting(settings, "maxParallelism").getAsNumber().intValue();

        this.networkSettings.queryTimeoutMS = requiredSetting(settings, "networkTimeout").getAsNumber().intValue();
        this.networkSettings.maxRetries = requiredSetting(settings, "maxRetries").getAsNumber().intValue();
        this.networkSettings.initialRetryDelayMS = requiredSetting(settings, "firstRetryDelay").getAsNumber().longValue();
        this.networkSettings.retryDelayScalingFactor = requiredSetting(settings, "retryDelayScale").getAsNumber().doubleValue();

        // TODO: Manage all AuthN/Z
        this.client = new ExternalJSONAPIClient(this.endpointUrl, null, true, ApplicationConfigurator.getProxySettings(),
                OnlineLLMUtils.getLLMResponseRetryStrategy(this.networkSettings),
                (builder) -> OnlineLLMUtils.add429RetryStrategy(builder, networkSettings)) {
            @Override
            protected HttpGet newGet(String path) {
                HttpGet get = new HttpGet(path);
                setAdditionalHeadersInRequest(get);
                get.addHeader("Content-Type", "application/json");
                // The token is resolved per request; no header is sent when no key is configured.
                String token = getAccessToken();
                if (token != null) {
                    get.addHeader("Authorization", token);
                }
                return get;
            }

            @Override
            protected HttpPost newPost(String path) {
                HttpPost post = new HttpPost(path);
                setAdditionalHeadersInRequest(post);
                post.addHeader("Content-Type", "application/json");
                // The token is resolved per request; no header is sent when no key is configured.
                String token = getAccessToken();
                if (token != null) {
                    post.addHeader("Authorization", token);
                }
                return post;
            }
        };
    }

    /** @return the value of the {@code maxParallelism} setting (1 until {@link #init} runs) */
    @Override
    public int getMaxParallelism() {
        return maxParallel;
    }

    /**
     * Runs each completion query sequentially against the configured model and accumulates
     * elapsed time and token counts in the usage data.
     *
     * <p>The method is synchronized, so only one completion batch runs at a time per instance.
     *
     * @param completionQueries queries to run, in order
     * @return one response per query, in the same order
     * @throws IOException if a completion call fails
     */
    @Override
    public synchronized List<SimpleCompletionResponse> completeBatch(List<CompletionQuery> completionQueries)
            throws IOException {
        List<SimpleCompletionResponse> ret = new ArrayList<>();
        for (CompletionQuery query : completionQueries) {
            long before = System.currentTimeMillis();

            logger.info("Chat Complete.");
            SimpleCompletionResponse scr = chatComplete(
                    model,
                    query.messages,
                    query.settings.maxOutputTokens,
                    query.settings.temperature,
                    query.settings.topP,
                    query.settings.topK,
                    query.settings.stopSequences,
                    query.settings.toolChoice,
                    query.settings.tools);

            synchronized (usageLock) {
                usageData.totalComputationTimeMS += (System.currentTimeMillis() - before);
                usageData.totalPromptTokens += scr.promptTokens;
                usageData.totalCompletionTokens += scr.completionTokens;
            }

            ret.add(scr);
        }
        return ret;
    }

    /**
     * Embeds each query text sequentially using the configured model and input type, and
     * accumulates elapsed time and prompt tokens in the usage data.
     *
     * @param queries  texts to embed, in order
     * @param settings embedding settings; not used by this implementation
     * @return one response per query, in the same order
     * @throws IOException if an embedding call fails or returns an invalid payload
     */
    @Override
    public List<SimpleEmbeddingResponse> embedBatch(List<EmbeddingQuery> queries, LLMClient.EmbeddingSettings settings) throws IOException {
        List<SimpleEmbeddingResponse> ret = new ArrayList<>();
        for (EmbeddingQuery query : queries) {
            long before = System.currentTimeMillis();

            logger.info("Chat Embed.");
            SimpleEmbeddingResponse ser = embed(model, query.text, inputType);

            synchronized (usageLock) {
                usageData.totalComputationTimeMS += (System.currentTimeMillis() - before);
                usageData.totalPromptTokens += ser.promptTokens;
            }

            ret.add(ser);
        }
        return ret;
    }

    /**
     * Same as {@link #embedBatch(List, LLMClient.EmbeddingSettings)} with {@code null} settings.
     */
    @Override
    public List<SimpleEmbeddingResponse> embedBatch(List<EmbeddingQuery> queries) throws IOException {
        return embedBatch(queries, null);
    }

    /**
     * Wraps the usage accumulated so far in a {@link ComputeResourceUsage}.
     *
     * @param usageType usage type to report
     * @param llmRef    reference supplying the connection, type and id of the LLM
     * @return a new resource-usage object populated from the internal usage data
     */
    @Override
    public ComputeResourceUsage getTotalCRU(LLMUsageType usageType, LLMStructuredRef llmRef) {
        ComputeResourceUsage cru = new ComputeResourceUsage();
        cru.setupLLMUsage(usageType, llmRef.connection, llmRef.type.toString(), llmRef.id);
        cru.llmUsage.setFromInternal(this.usageData);
        return cru;
    }

    /** Creates the API implementation used to talk to the configured endpoint for {@code model}. */
    private CustomApiImplementation newApiImplementation(String model) throws IOException {
        return new GptApiImplementation(this.endpointUrl, model);
    }

    /** Creates an API implementation and loads the generation settings and messages into it. */
    private CustomApiImplementation prepareChatRequest(String model, List<ChatMessage> messages, Integer maxTokens,
            Double temperature, Double topP, Integer topK, List<String> stopSequences,
            LLMClient.ToolChoice toolChoice, List<LLMClient.AbstractTool> tools) throws IOException {
        CustomApiImplementation apiImplementation = newApiImplementation(model);
        apiImplementation.addSettingsInObject(model, maxTokens, temperature, topP, topK, stopSequences, toolChoice, tools);
        apiImplementation.addMessagesInObject(messages);
        return apiImplementation;
    }

    /**
     * Sends a single, non-streamed chat completion request.
     *
     * @return the completion response produced by the API implementation
     * @throws IOException if the request fails
     */
    public SimpleCompletionResponse chatComplete(String model, List<ChatMessage> messages, Integer maxTokens, Double temperature, Double topP, Integer topK, List<String> stopSequences, LLMClient.ToolChoice toolChoice, List<LLMClient.AbstractTool> tools) throws IOException {
        CustomApiImplementation apiImplementation = prepareChatRequest(model, messages, maxTokens, temperature,
                topP, topK, stopSequences, toolChoice, tools);
        return apiImplementation.sendPostObject(this.client, this.networkSettings);
    }

    /**
     * Requests the embedding of a single text.
     *
     * @param model     model name sent in the request
     * @param text      text to embed
     * @param inputType value for the {@code input_type} member; {@code null} or empty falls
     *                  back to {@code "query"}
     * @return the embedding and, when the response carries usage data, its
     *         {@code total_tokens} as prompt tokens (0 otherwise)
     * @throws IOException if the request fails or the response does not contain exactly one
     *                     embedding
     */
    public SimpleEmbeddingResponse embed(String model, String text, String inputType) throws IOException {
        String effectiveInputType = (inputType == null || inputType.isEmpty()) ? DEFAULT_INPUT_TYPE : inputType;

        ObjectBuilder ob = JF.obj().with("input", text).with("model", model).with("input_type", effectiveInputType);
        logger.info("Raw embedding query");
        OpenAIEmbeddingResponse rcr = this.client.postObjectToJSON(
                this.endpointUrl, this.networkSettings.queryTimeoutMS,
                OpenAIEmbeddingResponse.class, ob.get());
        logger.info("Raw embedding response");

        // Deserialization can leave any of these null (empty body, "data": null, "data": [null], ...).
        boolean valid = rcr != null
                && rcr.data != null
                && rcr.data.size() == 1
                && rcr.data.get(0) != null
                && rcr.data.get(0).embedding != null;
        if (!valid) {
            throw new IOException("Chat did not respond with valid embeddings");
        }

        SimpleEmbeddingResponse ret = new SimpleEmbeddingResponse();
        ret.embedding = rcr.data.get(0).embedding;
        // Usage is optional in the response; count zero tokens rather than failing the call.
        ret.promptTokens = rcr.usage != null ? rcr.usage.total_tokens : 0;
        return ret;
    }

    /** @return always {@code true}: streamed completion is implemented by {@link #streamComplete} */
    public boolean supportsStream() {
        return true;
    }

    /**
     * Streams the completion of {@code query} to {@code consumer} using the configured model.
     *
     * @throws Exception if the streamed request fails
     */
    public void streamComplete(LLMClient.CompletionQuery query, LLMClient.StreamedCompletionResponseConsumer consumer) throws Exception {
        streamChatComplete(consumer, model, query.messages, query.settings.maxOutputTokens, query.settings.temperature, query.settings.topP, query.settings.topK, query.settings.stopSequences, query.settings.toolChoice, query.settings.tools);
    }

    /**
     * Sends a streamed chat completion request; chunks are delivered to {@code consumer} by
     * the API implementation.
     *
     * @throws Exception if the streamed request fails
     */
    public void streamChatComplete(StreamedCompletionResponseConsumer consumer, String model, List<ChatMessage> messages, Integer maxTokens, Double temperature, Double topP, Integer topK, List<String> stopSequences, LLMClient.ToolChoice toolChoice, List<LLMClient.AbstractTool> tools) throws Exception {
        CustomApiImplementation apiImplementation = prepareChatRequest(model, messages, maxTokens, temperature,
                topP, topK, stopSequences, toolChoice, tools);
        apiImplementation.streamChatComplete(client, consumer, networkSettings, toolChoice, tools);
    }

    /**
     * Computes the MD5 digest of {@code input} (encoded as UTF-8) as a lowercase hex string.
     * Currently not called from within this class.
     *
     * @param input text to hash; must not be {@code null}
     * @return 32-character lowercase hexadecimal digest
     * @throws IllegalStateException if the runtime provides no MD5 implementation
     */
    private static String calculateMD5(String input) {
        try {
            MessageDigest md = MessageDigest.getInstance("MD5");
            // Explicit charset: the no-arg getBytes() depends on the platform default.
            byte[] digest = md.digest(input.getBytes(StandardCharsets.UTF_8));
            StringBuilder sb = new StringBuilder(digest.length * 2);
            for (byte b : digest) {
                sb.append(String.format("%02x", b));
            }
            return sb.toString();
        } catch (NoSuchAlgorithmException e) {
            throw new IllegalStateException("MD5 digest algorithm is not available", e);
        }
    }
}