from dataiku.langchain.dku_tracer import LangchainToDKUTracer
from dataiku.langchain import DKUChatLLM
from dataiku.llm.python import BaseLLM
from typing import Annotated
import dataiku
from typing_extensions import TypedDict
from langchain_core.messages import ToolMessage
from pydantic import BaseModel
from langgraph.store.base import BaseStore
from langgraph.errors import GraphBubbleUp
#from langgraph.graph import StateGraph, START, END
#from langgraph.graph.message import add_messages
from langgraph.types import Command
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain.tools import tool, BaseTool
from langchain import hub
import logging
import json

from langgraph.prebuilt import ToolNode
from langgraph.prebuilt.tool_node import  msg_content_output

from langchain_core.messages import (
    AIMessage,
    AnyMessage,
    ToolCall,
    ToolMessage,
    convert_to_messages,
)
from typing import (
    Any,
    Callable,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
    cast,
    get_type_hints,
)
from langchain_core.runnables import RunnableConfig
from langchain_core.runnables.config import (
    get_config_list,
    get_executor_for_config,
)

from dataikuapi.dss.utils import AnyLoc

logger = logging.getLogger("dku.agents.tools_using")

import asyncio

class ToolException(Exception):
    """Error a tool may raise to report that its execution failed.

    The class adds no behaviour to ``Exception``; it only gives tool failures
    a dedicated type. ``_handle_tool_error`` in this module turns such an
    error into the text returned to the model.

    NOTE(review): this is a local class, distinct from
    ``langchain_core.tools.ToolException``; presumably LangChain's own
    tool-level ``handle_tool_error`` does not recognize it — confirm before
    relying on that mechanism.
    """

def _handle_tool_error(
    e: ToolException,
    *,
    flag: Optional[Union[Literal[True], str, Callable[[ToolException], str]]],
) -> str:
    if isinstance(flag, bool):
        if e.args:
            content = e.args[0]
        else:
            content = "Tool execution error"
    elif isinstance(flag, str):
        content = flag
    elif callable(flag):
        content = flag(e)
    else:
        raise ValueError(
            f"Got unexpected type of `handle_tool_error`. Expected bool, str "
            f"or callable. Received: {flag}"
        )
    return content

class DKUToolNode(ToolNode):
    """LangGraph ``ToolNode`` that traces tool calls and gathers sources/artifacts.

    Every tool call runs inside a ``tu_toolcall`` DKU sub-span. When a tool's
    ``ToolMessage.artifact`` is a dict, its non-None ``"sources"`` and
    ``"artifacts"`` entries are appended to ``self._sources`` and
    ``self._artifacts``. These lists accumulate across all calls made through
    this node instance, so the agent can report them once the graph finishes.
    """

    name: str = "DKUToolNode"

    def __init__(
        self,
        tools: Sequence[Union[BaseTool, Callable]],
        *,
        name: str = "tools",
        tags: Optional[list[str]] = None,
        handle_tool_errors: Union[
            bool, str, Callable[..., str], tuple[type[Exception], ...]
        ] = True,
        messages_key: str = "messages",
    ) -> None:
        """Create the node.

        Args:
            tools: LangChain tools (or callables) the node may invoke.
            name: Node name, forwarded to ``ToolNode``.
            tags: Optional tags, forwarded to ``ToolNode``.
            handle_tool_errors: Which tool exceptions become error
                ``ToolMessage``s and how they are rendered (see ``_run_one``);
                forwarded to ``ToolNode``.
            messages_key: State key holding the messages, forwarded to
                ``ToolNode``.
        """
        # Forward every option so ToolNode actually honors handle_tool_errors
        # and messages_key instead of silently falling back to its defaults.
        super().__init__(
            tools=tools,
            name=name,
            tags=tags,
            handle_tool_errors=handle_tool_errors,
            messages_key=messages_key,
        )

        # Accumulated over every tool call handled by this node instance.
        self._sources: list = []
        self._artifacts: list = []

    async def _arun_one(
        self,
        call: ToolCall,
        input_type: Literal["list", "dict"],
        config: RunnableConfig,
    ) -> Union[ToolMessage, Command]:
        """Async entry point: run the synchronous ``_run_one`` in the default executor."""
        logger.info("Start _arun_one on tool_call: %s", call)
        # get_running_loop() is the recommended way to reach the loop from a coroutine.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._run_one, call, input_type, config)

    def _run_one(
        self,
        call: ToolCall,
        input_type: Literal["list", "dict"],
        config: RunnableConfig,
    ) -> Union[ToolMessage, Command]:
        """Invoke one tool call, record its sources/artifacts and normalize the result.

        Args:
            call: The tool call emitted by the model.
            input_type: Shape of the node input, forwarded to Command validation.
            config: Runnable config; its ``"callbacks"`` entry is used to open
                the DKU trace sub-span.

        Returns:
            An invalid-call message from ``_validate_tool_call``; the tool's
            ``ToolMessage`` with content normalized by ``msg_content_output``;
            the validated ``Command`` if the tool returned one; or an error
            ``ToolMessage`` when the failure is covered by ``handle_tool_errors``.

        Raises:
            GraphBubbleUp: always re-raised (graph interrupts).
            Exception: tool errors not covered by ``handle_tool_errors``.
            TypeError: if the tool returned neither a ToolMessage nor a Command.
        """
        invalid_tool_message = self._validate_tool_call(call)
        if invalid_tool_message:
            return invalid_tool_message

        from dataiku.langchain.dku_tracer import dku_span_builder_for_callbacks

        tool_name = call["name"]
        try:
            tool_input = {**call, "type": "tool_call"}

            with dku_span_builder_for_callbacks(config["callbacks"]).subspan("tu_toolcall") as span:
                logger.info("Invoking tool %s", tool_name)
                response = self.tools_by_name[tool_name].invoke(tool_input, config)
                logger.info("Got response from tool %s", tool_name)

                if isinstance(response, ToolMessage):
                    self._collect_outputs(response, span)

        # GraphInterrupt is a special exception that will always be raised.
        # It can be triggered in the following scenarios:
        # (1) a NodeInterrupt is raised inside a tool
        # (2) a NodeInterrupt is raised inside a graph node for a graph called as a tool
        # (3) a GraphInterrupt is raised when a subgraph is interrupted inside a graph called as a tool
        # (2 and 3 can happen in a "supervisor w/ tools" multi-agent architecture)
        except GraphBubbleUp:
            raise
        except Exception as e:
            logger.exception("Failure")
            # Unhandled
            if not self.handle_tool_errors or not isinstance(e, self._handled_error_types()):
                raise
            # Handled. A tuple only selects which exceptions are handled; the
            # message is rendered as for ``True``.
            flag = True if isinstance(self.handle_tool_errors, tuple) else self.handle_tool_errors
            return ToolMessage(
                content=_handle_tool_error(e, flag=flag),
                name=tool_name,
                tool_call_id=call["id"],
                status="error",
            )

        # Result normalization runs outside the try block so that only
        # invocation failures are turned into error messages.
        if isinstance(response, Command):
            return self._validate_tool_command(response, call, input_type)
        if isinstance(response, ToolMessage):
            response.content = cast(
                Union[str, list], msg_content_output(response.content)
            )
            return response
        raise TypeError(
            f"Tool {tool_name} returned unexpected type: {type(response)}"
        )

    def _collect_outputs(self, tool_message: ToolMessage, span) -> None:
        """Record the ``sources``/``artifacts`` carried by a tool's artifact dict."""
        artifact = tool_message.artifact
        logger.info(" type of content: %s", type(tool_message.content))
        logger.info(" type of artifact: %s", type(artifact))

        # The artifact may be absent (e.g. None) for tools that produce none;
        # that is not an error, there is simply nothing to collect.
        if not isinstance(artifact, dict):
            return

        logger.info("output dict has keys: %s", list(artifact.keys()))

        sources = artifact.get("sources")
        if sources is not None:
            span.attributes["tu_toolcall_sources"] = sources
            self._sources.extend(sources)

        artifacts = artifact.get("artifacts")
        if artifacts is not None:
            self._artifacts.extend(artifacts)

    def _handled_error_types(self) -> tuple:
        """Return the exception types ``handle_tool_errors`` turns into error messages."""
        if isinstance(self.handle_tool_errors, tuple):
            return self.handle_tool_errors
        if callable(self.handle_tool_errors):
            try:
                # Private LangGraph helper that reads the handler's annotated
                # exception types; imported lazily in case it is missing.
                from langgraph.prebuilt.tool_node import _infer_handled_types
            except ImportError:
                return (Exception,)
            return _infer_handled_types(self.handle_tool_errors)
        # default behavior is catching all exceptions
        return (Exception,)


def format_multipart_messages(multipart_messages):
    """Flatten text-only multipart messages into plain ``content`` messages.

    The LangGraph react agent doesn't support message parts; it needs all
    the text in the ``content`` field. For every message carrying ``parts``:

    * each part must have ``type == "TEXT"``;
    * the message is copied via a JSON round-trip, so the caller's dict is
      left untouched;
    * the parts' ``text`` values (missing ones count as ``""``) are joined
      with newlines into ``content``, and ``parts`` is removed.

    Messages without ``parts`` are passed through as the same objects.

    Raises:
        Exception: if any part is not of type ``"TEXT"``.
    """
    flattened = []
    for msg in multipart_messages:
        if "parts" not in msg:
            flattened.append(msg)
            continue

        if any(p.get("type") != "TEXT" for p in msg["parts"]):
            raise Exception("Visual agents only support text inputs")

        copied = json.loads(json.dumps(msg))
        copied["content"] = "\n".join(p.get("text", "") for p in copied.pop("parts"))
        flattened.append(copied)

    return flattened


class ToolsUsingAgent(BaseLLM):
    """Agent letting a DSS LLM call the DSS agent tools listed in its config.

    ``set_config`` must be called before ``aprocess_stream``. Config keys read
    here: ``tools`` (entries with ``toolRef`` and optional
    ``additionalDescription``), ``llmId``, and the optional
    ``completionSettings`` and ``systemPromptAppend``.
    """

    def __init__(self):
        self.client = dataiku.api_client()
        self.project = self.client.get_default_project()

    def set_config(self, config, unused):
        """Store the agent configuration dict; the second argument is ignored."""
        self.config = config

    def load_tools(self, query) -> DKUToolNode:
        """Build a ``DKUToolNode`` wrapping every tool referenced in the config.

        Args:
            query: The incoming query; its optional ``"context"`` is passed to
                each tool's ``as_langchain_structured_tool``.

        Returns:
            A new ``DKUToolNode`` with empty source/artifact accumulators.
        """
        context = query.get("context", None)
        lctools = []

        for used_tool in self.config["tools"]:
            tool_ref = used_tool["toolRef"]
            logger.info("Will use tool %s", tool_ref)

            # The current project key is the default; presumably the ref may
            # name another project, whose key is then used for the lookup.
            tool_loc = AnyLoc.from_ref(self.project.project_key, tool_ref)
            dku_api_tool = self.client.get_project(tool_loc.project_key).get_agent_tool(tool_loc.object_id)
            lctool = dku_api_tool.as_langchain_structured_tool(context=context)

            # A single DSS tool may come back as several LangChain tools.
            sub_lctools = lctool if isinstance(lctool, list) else [lctool]
            for sub_lctool in sub_lctools:
                if "additionalDescription" in used_tool:
                    sub_lctool.description = sub_lctool.description + "\n\n" + used_tool["additionalDescription"]
                lctools.append(sub_lctool)

        return DKUToolNode(tools=lctools)

    async def aprocess_stream(self, query, settings, trace):
        """Run the react agent on ``query`` and stream DSS response chunks.

        Args:
            query: Holds ``messages`` and optionally ``context``.
            settings: Unused; completion settings come from the agent config.
            trace: DKU trace that the LangChain tracer reports into.

        Yields:
            ``{"chunk": ...}`` dicts: progress events and streamed text, then
            the gathered tool artifacts (if any). Finally, one
            ``{"footer": ...}`` dict listing the sources collected from the tools.
        """
        yield {"chunk": {"type": "event", "eventKind": "AGENT_GETTING_READY"}}

        inputs = {"messages": format_multipart_messages(query["messages"])}
        tracer = LangchainToDKUTracer(dku_trace=trace)

        # load_tools makes synchronous DSS API client calls; run it in the
        # default executor so the event loop is not blocked meanwhile.
        tool_node = await asyncio.get_running_loop().run_in_executor(None, self.load_tools, query)
        logger.info("tool_node has %d artifacts at start ", len(tool_node._artifacts))

        from langgraph.prebuilt import create_react_agent

        # ignore query settings, use config settings instead
        used_completion_settings = self.config.get("completionSettings", {})
        lcllm = self.project.get_llm(self.config["llmId"]).as_langchain_chat_model(completion_settings=used_completion_settings)
        context = query.get("context")
        if context:
            lcllm = lcllm.bind(context=context)

        # A missing, None or empty systemPromptAppend means "no extra prompt".
        prompt = self.config.get("systemPromptAppend") or None

        graph = create_react_agent(lcllm, tool_node, prompt=prompt, debug=False)
        config = {"configurable": {"thread_id": "thread-1"}, "callbacks": [tracer]}

        async for event in graph.astream_events(inputs, config, stream_mode="messages", version="v2"):
            chunk = self._event_to_chunk(event)
            if chunk is not None:
                yield {"chunk": chunk}

        logger.info("At end, gathered %d artifacts", len(tool_node._artifacts))
        if tool_node._artifacts:
            yield {"chunk": {"type": "content", "artifacts": tool_node._artifacts}}

        logger.info("At end, here are all my sources: %s", tool_node._sources)
        yield {
            "footer": {
                "additionalInformation": {
                    "sources": tool_node._sources or []
                }
            }
        }

    @staticmethod
    def _event_to_chunk(event):
        """Map one ``astream_events`` (v2) event to a chunk payload, or None to drop it."""
        kind = event["event"]

        if kind == "on_chat_model_stream":
            content = event["data"]["chunk"].content
            # Empty content is a streaming event without user-visible text,
            # probably a tool call: don't stream it to the end user.
            return {"text": content} if content else None

        if kind == "on_tool_start":
            return {"type": "event", "eventKind": "AGENT_TOOL_START", "eventData": {"toolName": event["name"]}}

        if kind == "on_chain_start":
            # Lazy %-args: the event is only stringified if debug is enabled.
            logger.debug("Chain start: %s", event)
            if event.get("name") == "agent":
                return {"type": "event", "eventKind": "AGENT_THINKING", "eventData": {}}
            return None  # flood, drop

        if kind in ("on_chain_stream", "on_chain_end"):
            return None  # Flood events, not useful

        # Unknown event: send it to the frontend, but only if it is
        # JSON-serializable (best effort, never fail the stream over it).
        try:
            json.dumps(event)
        except Exception:
            logger.debug("Dropping non-JSON-serializable event of kind %s", kind)
            return None
        return {"type": "event", "eventKind": "other", "eventData": {"event": event}}
