import copy
import logging
from typing import Dict, List

from dataiku.llm.python.blocks_graph import NextBlock
from dataiku.llm.python.blocks_graph.prompts import DEFAULT_MANDATORY_CALL_SINGLE_TOOL_PROMPT
from dataiku.llm.python.blocks_graph.utils import _validate_and_parse_tool_call, default_if_blank, interpolate_cel
from dataiku.llm.python.tools_using_2 import  LLMTool, _tool_calls_from_chunks
from dataiku.llm.python.types import ChatMessage
from dataikuapi.dss.llm_tracing import SpanBuilder
from .single_tool_call import SingleToolCallHandler

# Shared logger for the blocks-graph agent handlers in this module.
logger = logging.getLogger(name="dku.agents.blocks_graph")

class MandatoryToolCallBlockHandler(SingleToolCallHandler):
    """Block handler that forces the LLM to call the block's single configured tool.

    The LLM is asked (with ``toolChoice`` pinned to the block's only tool) to
    produce exactly one tool call. That call is then executed, and the block
    transitions to the block configured under ``nextBlock``.

    Block config keys read here: ``completionSettings``, ``systemPrompt``,
    ``tool``, ``outputMode`` and ``nextBlock``.
    """

    def __init__(self, turn, sequence_context, block_config):
        super().__init__(turn, sequence_context, block_config)

    def process_stream(self, trace: SpanBuilder):
        """Run the block, yielding stream chunks and finally a ``NextBlock``.

        Args:
            trace: Span under which the LLM call and the tool call sub-spans are
                recorded.

        Yields:
            An ``AGENT_THINKING`` event chunk first. Then a
            ``{"chunk": {"artifacts": ...}}`` dict for each LLM chunk carrying
            artifacts, followed by whatever ``_call_tool`` yields for the tool
            execution. Last comes a ``NextBlock`` whose id is
            ``block_config["nextBlock"]`` (or ``None`` when unset).

        Raises:
            ValueError: if no tool is configured, the configured tool cannot be
                loaded, or it does not expose exactly one executable tool.
            Exception: if the LLM emits no tool call, or more than one.
        """
        logger.info("Mandatory single tool block starting with config %s", self.block_config)

        yield {"chunk": {"type": "event", "eventKind": "AGENT_THINKING", "eventData": {}}}

        cel_engine = self.standard_cel_engine()

        # 1. Prepare and execute the LLM request that produces the tool arguments.
        with trace.subspan("DKU_AGENT_LLM_CALL") as llm_trace:
            completion = self.new_completion()

            # Deep copy so that later changes to the completion settings cannot
            # alter the block configuration itself.
            completion._settings = copy.deepcopy(self.block_config.get("completionSettings", {}))

            # Provide context and the (CEL-interpolated) system prompt.
            completion.with_context(self.turn.current_merged_context)

            system_prompt = default_if_blank(self.block_config.get("systemPrompt"), DEFAULT_MANDATORY_CALL_SINGLE_TOOL_PROMPT)
            system_prompt = interpolate_cel(cel_engine, system_prompt)
            completion.with_message(system_prompt, "system")

            # Load and validate the tool. A missing "tool" key used to surface as a
            # bare KeyError; report it as the configuration error it is.
            tool_id = self.block_config.get("tool")
            if not tool_id:
                raise ValueError("Mandatory single tool block requires a configured tool.")
            tool_def = self.turn.agent.load_or_get_tool(tool_id)
            if tool_def is None:
                raise ValueError(f"Mandatory single tool block requires a configured tool. Tool '{tool_id}' not found.")

            tools_by_name = self._get_tools_by_name(tool_def)
            if not tools_by_name:
                raise ValueError(f"Tool definition '{tool_id}' yielded no executable tools.")
            if len(tools_by_name) != 1:
                raise ValueError(f"Mandatory single tool block requires exactly one tool. Found {len(tools_by_name)} in '{tool_id}'.")

            # Force the LLM to call this one tool.
            target_tool_name = next(iter(tools_by_name))
            completion.settings["tools"] = [llm_tool.llm_descriptor for llm_tool in tools_by_name.values()]
            completion.settings["toolChoice"] = {"type": "tool_name", "name": target_tool_name}

            # Append the conversation history: the turn's initial messages, then
            # the messages generated so far in this sequence.
            completion.cq["messages"].extend(self.turn.initial_messages)
            completion.cq["messages"].extend(self.sequence_context.generated_messages)

            logger.info("About to run completion with forced tool: %s", target_tool_name)

            accumulated_tool_call_chunks = []

            for ichunk in self._run_completion(completion, llm_trace):
                # ichunk.text is deliberately ignored: in a mandatory tool block only
                # the structured tool call matters, not conversational chatter.

                if ichunk.memory_fragment and self.block_config.get("outputMode") == "ADD_TO_MESSAGES":
                    memory_fragment_msg: ChatMessage = {
                        "role": "memoryFragment",
                        "memoryFragment": ichunk.memory_fragment,
                    }
                    self.sequence_context.generated_messages.append(memory_fragment_msg)

                if ichunk.artifacts:
                    artifacts = self._tag_with_agent_hierarchy(ichunk.artifacts)
                    yield {"chunk": {"artifacts": artifacts}}

                if ichunk.sources:
                    self.sequence_context.sources.extend(self._tag_with_agent_hierarchy(ichunk.sources))

                if ichunk.tool_call_chunks:
                    accumulated_tool_call_chunks.extend(ichunk.tool_call_chunks)

        # Reassemble the streamed tool-call chunks into complete tool calls.
        dku_tool_calls = _tool_calls_from_chunks(accumulated_tool_call_chunks)
        logger.info("Gathered tool calls: %s", dku_tool_calls)

        if not dku_tool_calls:
            raise Exception("Mandatory single tool block failed: LLM did not emit any tool call.")

        if len(dku_tool_calls) > 1:
            raise Exception(f"Mandatory single tool block failed: LLM emitted more than 1 call (emitted {len(dku_tool_calls)}).")

        # 2. Parse the tool call.
        tool_call = _validate_and_parse_tool_call(dku_tool_calls[0])

        # 3. Execute the tool.
        with trace.subspan("DKU_AGENT_TOOL_CALLS") as tools_trace:
            # NOTE(review): if the LLM named a tool other than target_tool_name, this
            # passes None to _call_tool; how that is handled depends on the parent
            # class. Confirm whether it should be rejected here instead.
            yield from self._call_tool(tool_call, tools_by_name.get(tool_call.name), tools_trace)

        # 4. Transition to the configured next block.
        yield NextBlock(id=self.block_config.get("nextBlock", None))

    def _tag_with_agent_hierarchy(self, items: List[dict]) -> List[dict]:
        """Insert an AGENT entry at the front of each item's ``hierarchy`` list.

        Each item is mutated in place; a missing ``hierarchy`` list is created
        first. A fresh entry dict is built per item so that items never share one.
        Returns the same ``items`` list for convenience.
        """
        for item in items:
            hierarchy: List = item.setdefault("hierarchy", [])
            # TODO @lavish-agents: possibly also record "agentId" / "agentName" here.
            hierarchy.insert(0, {"type": "AGENT", "agentLoopIteration": self.iteration_number})
        return items

    def _get_tools_by_name(self, ptool) -> Dict[str, LLMTool]:
        """Map each LLM-facing tool name exposed by ``ptool`` to its ``LLMTool``."""
        return {llm_tool.llm_tool_name: llm_tool for llm_tool in ptool.llm_tools}
