import dataiku
from dataiku.llm.python import BaseLLM
from dataiku.langchain.dku_llm import DKUChatLLM
import time
import json

from text2sql_agent import execute_sql, get_table_columns
from langchain.tools import StructuredTool
from langchain_core.messages import AIMessage, ToolMessage
from langgraph.prebuilt import create_react_agent

# Tables listed in the system prompt, and the LangChain tools the agent can call.
DATASETS = ["match_results", "match_goalscorers"]
tools = [
    StructuredTool.from_function(get_table_columns),
    StructuredTool.from_function(execute_sql),
]
tool_names = [t.name for t in tools]

# Chat model served through Dataiku; its id is read from the "LLM_id" custom
# variable. Tools are bound with parallel tool calls turned off.
LLM_ID = dataiku.get_custom_variables()["LLM_id"]
llm = DKUChatLLM(llm_id=LLM_ID, temperature=0).bind_tools(
    tools, parallel_tool_calls=False
)

# System prompt (lists the available tables) and the ReAct agent that uses it.
SYSTEM_PROMPT = f"""You are an expert in writing SQL queries. When writing a SQL query, you are very careful and precise to ensure you do not miss out any data.
You ignore cases for text filters, use only available columns in the table & compute new columns when required to answer the request if the information is available.

Here is the list of tables available for use:
{DATASETS}
"""
agent = create_react_agent(llm, tools=tools, state_modifier=SYSTEM_PROMPT)

class MyLLM(BaseLLM):
    """Custom LLM that answers a question by running the module-level agent.

    The agent's message history is condensed into a JSON payload holding the
    final reply, a numbered trace of the tool calls, the SQL query that was
    requested (if any) and the elapsed time.
    """

    # Placeholder shown in the trace when a tool call has no recorded result.
    _NO_RESULT = "(no result)"

    def __init__(self):
        pass

    def process(self, query, settings, trace):
        """Invoke the agent on the first message of ``query`` and summarise the run.

        Args:
            query: Mapping with a ``"messages"`` list; only the ``"content"``
                of its first entry is sent to the agent.
            settings: Not used by this implementation.
            trace: Not used by this implementation.

        Returns:
            A dict with a single ``"text"`` key whose value is a JSON string
            with the keys ``"reply"``, ``"answer_details"``, ``"sql_query"``
            and ``"processing_time"`` (seconds, one decimal, as a string).
        """
        start = time.time()

        # Invoke the agent for a question
        result = agent.invoke({"messages": query["messages"][0]["content"]})

        reply, tool_calls = self._collect_steps(result["messages"][1:])

        # One numbered line per tool call, in the order the calls were made.
        intermediate_steps = "\n".join(
            f"{i}. {call['name']}({call['arguments']})"
            f" --> {call.get('result', self._NO_RESULT)}"
            for i, call in enumerate(tool_calls.values(), start=1)
        )

        # Report the most recent tool call that carried a "sql_query" argument,
        # so a trailing call to another tool does not hide the query.
        sql_query = ""
        for call in reversed(list(tool_calls.values())):
            candidate = self._extract_sql_query(call["arguments"])
            if candidate is not None:
                sql_query = candidate
                break

        return {
            "text": json.dumps(
                {
                    "reply": reply,
                    "answer_details": intermediate_steps,
                    "sql_query": sql_query,
                    "processing_time": f"{(time.time() - start):.1f}",
                }
            )
        }

    @staticmethod
    def _collect_steps(messages):
        """Return ``(reply, tool_calls)`` gathered from the agent messages.

        ``tool_calls`` maps each tool call id to a dict with ``"name"``,
        ``"arguments"`` and, once the matching tool message is seen,
        ``"result"``. ``reply`` is the content of the last AI message that
        requested no tool call, or ``""`` when there is none.
        """
        reply = ""
        tool_calls = {}
        for message in messages:
            if isinstance(message, AIMessage):
                requested = message.additional_kwargs.get("tool_calls")
                if requested:
                    for tool_call in requested:
                        tool_calls[tool_call["id"]] = {
                            "name": tool_call["function"]["name"],
                            "arguments": tool_call["function"]["arguments"],
                        }
                else:
                    # No (or an empty) tool call list: this is a plain answer.
                    reply = message.content
            elif isinstance(message, ToolMessage):
                # A result whose id was never announced still gets an entry
                # instead of raising KeyError.
                entry = tool_calls.setdefault(
                    message.tool_call_id,
                    {
                        "name": getattr(message, "name", None) or "unknown_tool",
                        "arguments": "",
                    },
                )
                entry["result"] = message.content
        return reply, tool_calls

    @staticmethod
    def _extract_sql_query(arguments):
        """Return the ``"sql_query"`` argument of a tool call, or ``None``.

        ``arguments`` is either a JSON-encoded string or an already decoded
        mapping; anything else, or undecodable JSON, yields ``None``.
        """
        if isinstance(arguments, str):
            try:
                arguments = json.loads(arguments)
            except json.JSONDecodeError:
                return None
        if isinstance(arguments, dict):
            return arguments.get("sql_query")
        return None