import copy
import json
import logging
import threading
from collections import defaultdict, deque
from typing import Collection

import dataiku
from dataiku.base.utils import redact_sensitive
from dataiku.llm.python import BaseLLM
from dataiku.llm.python.types import FunctionToolCall, MemoryFragment, ToolValidationResponse, ToolValidationRequest
from dataikuapi.dss.llm import DSSLLMStreamedCompletionFooter, DSSLLMStreamedCompletionChunk, DSSLLM

from dataikuapi.dss.llm_tracing import SpanBuilder
from dataiku.generated_sources.com.dataiku.dip.dao.saved_model.next_turn_behaviour import NextTurnBehaviour
from dataiku.generated_sources.com.dataiku.dip.llm.online.llm_client.tool_output import ToolOutput

from . import SequenceContext, NextBlock
from .blocks.standard_react import ReactBlockHandler
from .prompts import DEFAULT_NEXT_TURN_SMART_MODE_PROMPT
from .types import ToolCallWithPotentialValidation, ToolCallValidationInfo
from .utils import _validate_and_parse_tool_call

from dataiku.llm.python.tools_using_2 import PreparedTool, LLMTool

# Register a custom "TRACE" log level (numerically below DEBUG) and expose it as `Logger.trace(...)`,
# so that the blocks-graph agent can emit very fine-grained logs
TRACE_LEVEL_NUM = 5
logging.addLevelName(TRACE_LEVEL_NUM, "TRACE")


def trace(self, message, *args, **kws):
    """Logs `message` with level TRACE on this logger, only when that level is enabled for it."""
    if not self.isEnabledFor(TRACE_LEVEL_NUM):
        return
    self._log(TRACE_LEVEL_NUM, message, args, **kws)


# Patched onto the Logger class itself, hence available on every logger of the process
logging.Logger.trace = trace

logger = logging.getLogger("dku.agents.blocks_graph")

# Key, within the agent state, under which the id of the block currently being run is stored
CURRENT_BLOCK_ID_STATE_KEY = "_currentBlockId"
# Key, within the sequence context scratchpad, exposing the block id found in the state when the turn started
PREVIOUS_TURN_BLOCK_ID_SCRATCHPAD_KEY = "_previousTurnBlockId"


class BlocksGraphAgent(BaseLLM):
    """
    Entry point of the blocks-graph agent.

    A single instance serves several requests at the same time, so it only holds what is shared between
    requests: the default project handle, the agent configuration and a lock-protected cache of prepared
    tools. Everything specific to one request lives in a BlocksGraphAgentTurn.
    """

    def __init__(self):
        super().__init__()
        self.project = dataiku.api_client().get_default_project()

        # Prepared tools, indexed by the JSON form of their "used tool" settings
        self.tools_cache = {}
        self.tools_cache_lock = threading.Lock()

    def set_config(self, config, unused):
        """Stores the agent configuration. The second argument is ignored."""
        self.config = config

    def load_or_get_tool(self, used_tool) -> PreparedTool:
        """
        Returns the prepared tool matching the `used_tool` settings, creating and caching it on first use.

        :param used_tool: dict of tool usage settings; "toolRef" is required, "subtoolName" optionally
                          restricts the prepared tool to the subtools bearing that name
        :raises ValueError: if "subtoolName" is set but no subtool of the tool has that name
        """
        logger.info("Loading tool for used_tool: %s", used_tool)
        # sort_keys: the same settings must map to the same cache entry, whatever the order of their keys
        cache_key = json.dumps(used_tool, sort_keys=True)

        # The lock is held while the tool is created, so a given tool is only ever prepared once
        with self.tools_cache_lock:
            if cache_key in self.tools_cache:
                return self.tools_cache[cache_key]

            tool_ref = used_tool["toolRef"]
            logger.info("Creating tool %s", tool_ref)

            dku_api_tool = self.project.get_agent_tool(tool_ref)
            ptool = PreparedTool(dku_api_tool, used_tool)

            # specific subtool filtering (a missing, None or empty name means "no filtering")
            subtool_name = used_tool.get("subtoolName")
            if subtool_name:
                matching_tools = [t for t in ptool.llm_tools if t.dku_subtool_name == subtool_name]
                if not matching_tools:
                    raise ValueError(f"No enabled subtool named {subtool_name} within tool {tool_ref}")
                ptool.llm_tools = matching_tools

            self.tools_cache[cache_key] = ptool
            return ptool

    def process_stream(self, query, settings, trace: SpanBuilder):
        """
        Streams the chunks produced by one agent turn for `query`.

        `settings` is not used by this method.
        """
        # The BlocksGraphBasedAgent class must serve multiple requests at the same time. Thus, we cannot store
        # anything on it. We therefore use a class representing a single request/turn (even if multiple
        # blocks and iterations can happen during the turn).
        turn = BlocksGraphAgentTurn(self, query)
        yield from turn.process_stream(trace)

    def redact_context(self, context):
        """ Returns a copy of the context object with redacted values for keys that look like sensitive data
        """
        redaction_config = self.config["dkuRedactionConfig"]
        sensitive_keys = redaction_config["redactedKeys"]
        sensitive_patterns = [p["pattern"] for p in redaction_config["redactedKeyPatterns"]]
        return redact_sensitive(context, sensitive_keys, sensitive_patterns)


class BlocksGraphAgentTurn(object):
    """
    This class holds everything that happens during a single turn of the agent (i.e. a single call to the agent through the LLM Mesh API).
    The turn may itself represent several block transitions (and each block may itself represent several iterations of LLM calls and tool calls).
    """

    def __init__(self, agent, query):
        """
        Initialises the state of a single agent turn.

        :param agent: the agent serving the request; its `config` and `redact_context` are used here
        :param query: the request; "messages" is required, "context" is optional
        """
        self.agent = agent
        self.initial_messages = query["messages"]
        self.initial_context = query.get("context")

        # Logged through the module logger (not the root logger), like the rest of the agent
        logger.info("Context at turn start: %s", self.agent.redact_context(self.initial_context))

        # Keys set during the turn; returned in the footer at the end of the turn
        self.context_upsert = {}
        # Working copy of the context, so that the caller's object is never mutated
        self.current_merged_context = copy.deepcopy(self.initial_context)
        if self.current_merged_context is None:
            self.current_merged_context = {}

        # Append a hash of the agent config to the state context key to avoid collisions if multiple agents are used in the same project
        import hashlib
        agent_config_hash = hashlib.sha256(json.dumps(self.agent.config, sort_keys=True).encode("utf-8")).hexdigest()[:6]
        self.state_context_key = "_blocksGraphState_" + agent_config_hash

    def build_block_handler(self, block_id, sequence_context):
        """Builds the handler of block `block_id` for this turn, from the "blocks" section of the agent configuration."""
        # Imported at call time rather than at module load (presumably to avoid a circular import - to be confirmed)
        from .block_builder import build_block_handler as _build_handler
        blocks_config = self.agent.config["blocks"]
        return _build_handler(self, blocks_config, block_id, sequence_context)

    def context_get(self, key, default=None):
        return self.current_merged_context.get(key, default)

    def context_set(self, key, value):
        """Sets a key in the context, and mark it for upsert at the end of the turn"""
        self.current_merged_context[key] = value
        self.context_upsert[key] = value

    def context_has(self, key):
        return key in self.current_merged_context

    def state_get(self, key, default=None):
        return self.current_merged_context.get(self.state_context_key, {}).get(key, default)

    def state(self):
        return self.current_merged_context.get(self.state_context_key, {})

    def state_set(self, key, value):
        """Sets a key in the state (which will mark the whole state for upsert at the end of the turn)"""
        state = self.context_get(self.state_context_key, {})
        state[key] = value
        self.context_set(self.state_context_key, state)

    def _decide_starting_block_id(self, trace):
        """
        Chooses which block should be used to start this agent turn.

        - when the last message holds tool validation responses, the block stored in the state is resumed
        - on the first turn (no block stored in the state), the configured starting block is used
        - otherwise the "nextTurnBehaviour" setting decides: "LAST_BLOCK", "STARTING_BLOCK", or "SMART"
          (an LLM call classifies the conversation; only the answer "CONTINUE" resumes the previous block)

        :param trace: span under which the smart mode LLM call, if any, is traced
        :raises ValueError: if no starting block is configured or if the next turn behaviour is unknown
        """
        next_turn_behaviour: NextTurnBehaviour = self.agent.config["nextTurnBehaviour"]
        starting_block_id = self.agent.config.get("startingBlockId")
        if starting_block_id is None or starting_block_id == "":
            raise ValueError("Agent configuration is missing. Please set a starting block for this agent.")
        previous_turn_block_id = self.state_get(CURRENT_BLOCK_ID_STATE_KEY)

        # HITL overrides agent's settings
        if self.initial_messages and self.initial_messages[-1]["role"] == "toolValidationResponses":
            return previous_turn_block_id

        if previous_turn_block_id is None:
            logger.info("First turn of the conversation, beginning with starting block: %s", starting_block_id)
            return starting_block_id

        if next_turn_behaviour == "LAST_BLOCK":
            logger.info("Continuing from previous block: %s", previous_turn_block_id)
            return previous_turn_block_id

        if next_turn_behaviour == "STARTING_BLOCK":
            logger.info("Returning to starting block: %s", starting_block_id)
            return starting_block_id

        if next_turn_behaviour == "SMART":
            smart_mode_classification = self._starting_block_smart_mode_classification(trace)
            # Logged through the module logger (not the root logger), like the rest of the agent
            logger.debug("Smart mode classification: %s", smart_mode_classification)
            # Tolerate whitespace around the LLM answer (e.g. a trailing newline); any other answer restarts
            if isinstance(smart_mode_classification, str) and smart_mode_classification.strip() == "CONTINUE":
                logger.info("Smart mode next turn behaviour selected previous block: %s", previous_turn_block_id)
                return previous_turn_block_id
            logger.info("Smart mode next turn behaviour selected starting block: %s", starting_block_id)
            return starting_block_id

        raise ValueError(f"Unknown next turn behaviour mode: {next_turn_behaviour}")

    def process_stream(self, trace : SpanBuilder):
        """
        Runs the whole agent turn and streams its output.

        The generator goes through these steps, in order:
        - emit the AGENT_TURN_START event and choose the block to start from
        - if the last message holds tool validation responses, resume the interrupted iteration by playing
          the pending tool calls; if some tools still require validation, the generator ends right there
        - expand and filter the initial messages, then run the pre-chain
        - run blocks one after the other, until a block does not designate a next block
        - unless tool validation requests were emitted, run the post-chain and, when short-term memory is
          enabled and messages were generated, emit a memory fragment
        - emit the footer, holding the context upsert and the sources

        :param trace: span under which the sub-spans of this turn are created
        :raises ValueError: if tool validation responses target a block that cannot handle them, or
                            tool calls that do not belong to the resumed block
        """
        logger.info("Starting to process blocks graph turn")

        self.iteration_number = 1 # TODO @lavish-agents

        yield {"chunk":  {"type": "event", "eventKind": "AGENT_TURN_START"}}

        current_block_id = self._decide_starting_block_id(trace)
        sc = SequenceContext()

        # Logged through the module logger (not the root logger), like the rest of the agent
        logger.debug("BEFORE VALIDATION INITIAL MESSAGES: %s", self.initial_messages)

        if self.initial_messages and self.initial_messages[-1]["role"] == "toolValidationResponses":
            logger.info("Starting agent turn with pending tool calls")

            validation_responses_block_handler = self.build_block_handler(current_block_id, sc)
            if not isinstance(validation_responses_block_handler, ReactBlockHandler):
                raise ValueError(f"Tool validation responses received by a block that does not support them: {validation_responses_block_handler.__class__}")

            with trace.subspan("DKU_AGENT_BLOCK_RUN") as block_trace:
                with block_trace.subspan("DKU_AGENT_REACT_ITERATION") as iteration_trace:
                    with iteration_trace.subspan("DKU_AGENT_TOOL_CALLS") as tools_trace:
                        current_tools = validation_responses_block_handler.load_tools()

                        # restore agentic loop state and retrieve pending tool calls and tool call validation infos
                        pending_tool_calls, validation_infos_map, partial_tool_outputs_map = self._initialise_from_tool_validations(current_tools, tools_trace, sc)

                        logger.info("Resuming interrupted agent iteration: %s", self.iteration_number)
                        iteration_trace.attributes["iterationNumber"] = self.iteration_number
                        # FIXME @lavish-agents-hitl We set self.iteration number correctly here, from the memory fragment, but when we create the block handler in the loop below, it is reset to 1

                        # execute pending tool calls
                        if not all(pending_tool_call.block_id == current_block_id for pending_tool_call in pending_tool_calls):
                            raise ValueError(f"Received validated tool calls that are not for the current block ID ({current_block_id})")
                        tools_require_validation = yield from validation_responses_block_handler.play_pending_tool_calls(pending_tool_calls, tools_trace, validation_infos_map, partial_tool_outputs_map)
                        if tools_require_validation:
                            # Still waiting for validations: the turn stops here, nothing below is run
                            return
                        sc.generated_messages.extend(validation_responses_block_handler.block_turn_generated_messages)

        self.expand_and_filter_initial_messages()

        # Pre chain
        yield from self.run_pre_chain(sc, trace)

        # Main turn
        turn_stopped_for_hitl = False

        # Remember the last block from the previous turn, so custom code can use it
        if previous_turn_block_id := self.state_get(CURRENT_BLOCK_ID_STATE_KEY):
            sc.scratchpad[PREVIOUS_TURN_BLOCK_ID_SCRATCHPAD_KEY] = previous_turn_block_id

        logger.info("Blocks graph turn start, state is %s", self.state())

        while True:
            with trace.subspan("DKU_AGENT_BLOCK_RUN") as block_trace:
                self.state_set(CURRENT_BLOCK_ID_STATE_KEY, current_block_id)
                block_trace.attributes["block_id"] = current_block_id
                # Snapshot of the state: later state changes must not alter what was traced for this block
                block_trace.attributes["block_state"] = copy.deepcopy(self.state())

                yield {
                    "chunk": {
                        "type": "event",
                        "eventKind": "AGENT_BLOCK_START",
                        "eventData": {
                            "blockId": current_block_id,
                            "context": self.agent.redact_context(self.current_merged_context),
                        },
                    }
                }

                block_handler = self.build_block_handler(current_block_id, sc)

                logger.info("Blocks graph turn block, block handler %s", block_handler)

                next_block = None
                for chunk in block_handler.process_stream(block_trace):
                    if isinstance(chunk, DSSLLMStreamedCompletionChunk):
                        # Only the text of LLM chunks is forwarded
                        if chunk.text is not None:
                            yield {"chunk":  {"text": chunk.text}}
                    elif isinstance(chunk, DSSLLMStreamedCompletionFooter):
                        yield chunk
                    elif isinstance(chunk, NextBlock):
                        # Not streamed: only tells which block to run next
                        next_block = chunk.id
                    elif isinstance(chunk, dict) and chunk.get("chunk") is not None:
                        if "toolValidationRequests" in chunk["chunk"]:
                            turn_stopped_for_hitl = True
                        yield chunk
                    else:
                        raise Exception("Unknown chunk type: %s" % chunk)

                yield {
                    "chunk": {
                        "type": "event",
                        "eventKind": "AGENT_BLOCK_DONE",
                        "eventData": {
                            "blockId": current_block_id,
                            "nextBlockId": next_block,
                            "context": self.agent.redact_context(self.current_merged_context),
                        },
                    }
                }

                logger.info("Blocks graph turn block done, next_block=%s", next_block)

                if next_block is None:
                    logger.info("No next block, Turn is done")
                    break

                self.state_set(CURRENT_BLOCK_ID_STATE_KEY, next_block)
                current_block_id = next_block

        logger.info("Turn done, stopped for Human-in-the-Loop: %s", turn_stopped_for_hitl)
        if not turn_stopped_for_hitl:
            # Post chain
            yield from self.run_post_chain(sc, trace)

            if self.agent.config.get("shortTermMemoryEnabled") and sc.generated_messages:
                # Make a memory fragment from the current sequence context messages
                memory_fragment_chunk = {
                    "chunk": {
                        "type": "content",
                        "memoryFragment":  {
                            "agentLoopIteration" : self.iteration_number, # TODO @lavish-agents
                            "messages" : sc.generated_messages
                        }
                    }
                }
                yield memory_fragment_chunk

                yield {
                    "chunk": {
                        "type": "event",
                        "eventKind": "AGENT_EMIT_MEMORY_FRAGMENT",
                        "eventData": memory_fragment_chunk,
                    }
                }

        yield {
            "footer": {
                "contextUpsert": self.context_upsert,
                "additionalInformation": {
                    "sources": sc.sources
                }
            }
        }

    def _starting_block_smart_mode_classification(self, trace: SpanBuilder):
        """
        Asks the LLM configured for the next turn smart mode to classify the conversation, and returns the text of its answer.

        The query is made of the initial messages followed by a system message holding the default smart mode
        prompt, to which the optional "nextTurnSmartModeInstructionsAppend" setting is appended.

        :raises ValueError: if no LLM is configured for the smart mode
        """
        with trace.subspan("DKU_AGENT_NEXT_TURN_SMART_MODE_LLM_CALL") as llm_trace:
            agent_config = self.agent.config
            smart_mode_llm_id = agent_config.get("nextTurnSmartModeLLMId")
            if not smart_mode_llm_id:
                raise ValueError("Agent configuration is missing. Please set an LLM for next turn smart mode.")

            smart_mode_llm: DSSLLM = self.agent.project.get_llm(smart_mode_llm_id)
            classification_query = smart_mode_llm.new_completion()
            classification_query.with_context(self.current_merged_context)
            # Deep copy, so that the agent configuration can never be altered through the completion
            classification_query._settings = copy.deepcopy(agent_config.get("nextTurnSmartModeCompletionSettings", {}))
            classification_query.cq["messages"].extend(self.initial_messages)

            extra_instructions = agent_config.get("nextTurnSmartModeInstructionsAppend")
            if extra_instructions:
                system_prompt = DEFAULT_NEXT_TURN_SMART_MODE_PROMPT + ("\n\n" + extra_instructions)
            else:
                system_prompt = DEFAULT_NEXT_TURN_SMART_MODE_PROMPT
            classification_query.with_message(system_prompt, role="system")

            answer = classification_query.execute()
            if answer.trace:
                llm_trace.append_trace(answer.trace)

            # Artifacts and sources returned by the LLM are deliberately ignored
            return answer.text

    def expand_and_filter_initial_messages(self):
        """
        Prepare the initial messages list by:
        - filtering out irrelevant messages used for tool call validations in past turns
        - filtering out short-term memory fragments beyond the memory horizon configured on the agent
        - expanding the messages stored within the short-term memory fragments that fall within the memory horizon
        """
        memory_horizon = self.agent.config.get("shortTermMemoryHorizon")
        if not self.agent.config.get("shortTermMemoryEnabled"):
            memory_horizon = 0

        expanded_memory_fragments_count = 0

        expanded_messages = deque()
        next_message_role = None
        for m in reversed(self.initial_messages):
            if m["role"] in ["toolValidationResponses", "toolValidationRequests"]:
                # Remove all messages relative to tool call validations in past turns, as they are not processable by the underlying LLM
                pass

            elif m["role"] == "memoryFragment":
                if next_message_role == "toolValidationRequests":
                    # Skip partial memory fragments that were used for HITL, as they are irrelevant now
                    pass
                elif next_message_role != "assistant":
                    raise ValueError(f"Memory fragment message before message with role {next_message_role} is not supported")
                elif memory_horizon is not None and expanded_memory_fragments_count >= memory_horizon:
                    # Skip memory fragments beyond the memory horizon
                    pass
                elif not (memory_fragment_messages := (m.get("memoryFragment") or {}).get("messages")):
                    raise ValueError("Memory fragment message received but the memory fragment is missing or empty")
                else:
                    # It is the "ADD_TO_MESSAGE" setting which determines which blocks should have their output remembered across blocks and across turns, not the "stream" option which only decides what should be streamed to the user.
                    # The latest assistant message was built by the client based on whatever text output has been streamed to them, so it is irrelevant here.
                    #
                    # Besides, the content of this externally-built assistant message includes all the aggregated text output across all blocks over the whole turn.
                    # Among all this text, the individual pieces of text output emitted by the different blocks that are worth remembering should have been included in the generated_messages already
                    # and those that haven't been added to the generated messages are not meant to be remembered.
                    #
                    # ==> If we rebuild the history from the memory fragment, then we remove the latest assistant message to avoid duplicating and misplacing intermediate messages
                    # For safety, we only do it if the last message of the memory fragment is itself an assistant message
                    if memory_fragment_messages[-1]["role"] == "assistant":
                        expanded_messages.popleft()

                    # Expand the short-term memory fragment
                    expanded_messages.extendleft(reversed(memory_fragment_messages))
                    expanded_memory_fragments_count += 1

            else:
                # Keep regular messages that are part of the user-visible chat history
                expanded_messages.appendleft(m)

            next_message_role = m["role"]

        self.initial_messages = list(expanded_messages)

    def run_pre_chain(self, sc: SequenceContext, trace: SpanBuilder):
        """
        Runs the chain of blocks starting at "preChainHeadBlockId", if that setting is configured.

        Every chunk emitted by a block is yielded as is, except NextBlock chunks which are only used to find
        the next block of the chain. The chain ends when a block does not designate a next block.

        :raises Exception: if more than 50 blocks are traversed, or if a block footer reports pending tool validation requests
        """
        head_block_id = self.agent.config.get("preChainHeadBlockId")
        if not head_block_id:
            return

        logger.info("Pre-chain head block configured, running it first")
        max_chain_length = 50
        traversed_count = 0
        block_id = head_block_id

        while block_id:
            # Guard against chains that loop back on themselves
            traversed_count += 1
            if traversed_count > max_chain_length:
                raise Exception(f"Pre-chain iteration exceeded maximum block chain length of {max_chain_length} (possible loop)")

            handler = self.build_block_handler(block_id, sc)
            logger.info("Pre-chain running block %s", block_id)
            following_block_id = None
            validation_requested = False

            with trace.subspan("DKU_AGENT_PRE_BLOCK") as block_trace:
                block_trace.attributes["block_id"] = block_id
                for chunk in handler.process_stream(block_trace):
                    if isinstance(chunk, DSSLLMStreamedCompletionChunk):
                        pass
                    elif isinstance(chunk, DSSLLMStreamedCompletionFooter):
                        if hasattr(chunk, "data") and chunk.data.get("finishReason") == "tool_validation_requests":
                            validation_requested = True
                    elif isinstance(chunk, NextBlock):
                        # Routing information only: never streamed
                        following_block_id = chunk.id
                        continue
                    yield chunk

            # Only checked once the block stream is exhausted and its span is closed
            if validation_requested:
                raise Exception("Tool validation requests are not supported in pre-chain blocks")

            block_id = following_block_id


    def run_post_chain(self, sc: SequenceContext, trace: SpanBuilder):
        """
        Runs the chain of blocks starting at "postChainHeadBlockId", if that setting is configured.

        Every chunk emitted by a block is yielded as is, except NextBlock chunks which are only used to find
        the next block of the chain. The chain ends when a block does not designate a next block.

        :raises Exception: if more than 50 blocks are traversed, or if a block footer reports pending tool validation requests
        """
        head_block_id = self.agent.config.get("postChainHeadBlockId")
        if not head_block_id:
            return

        logger.info("Post-chain head block configured, running it first")
        max_chain_length = 50
        traversed_count = 0
        block_id = head_block_id

        while block_id:
            # Guard against chains that loop back on themselves
            traversed_count += 1
            if traversed_count > max_chain_length:
                raise Exception(f"Post-chain iteration exceeded maximum block chain length of {max_chain_length} (possible loop)")

            handler = self.build_block_handler(block_id, sc)
            logger.info("Post-chain running block %s", block_id)
            following_block_id = None
            validation_requested = False

            with trace.subspan("DKU_AGENT_POST_BLOCK") as block_trace:
                block_trace.attributes["block_id"] = block_id
                for chunk in handler.process_stream(block_trace):
                    if isinstance(chunk, DSSLLMStreamedCompletionChunk):
                        pass
                    elif isinstance(chunk, DSSLLMStreamedCompletionFooter):
                        if hasattr(chunk, "data") and chunk.data.get("finishReason") == "tool_validation_requests":
                            validation_requested = True
                    elif isinstance(chunk, NextBlock):
                        # Routing information only: never streamed
                        following_block_id = chunk.id
                        continue
                    yield chunk

            # Only checked once the block stream is exhausted and its span is closed
            if validation_requested:
                raise Exception("Tool validation requests are not supported in post-chain blocks")

            block_id = following_block_id

    def _collect_tool_validation_responses(self, validation_responses_map=None) -> dict[str, ToolValidationResponse]:
        """
        Recursively pop and process consecutive messages with role = "toolValidationResponses" to collect all validation responses.
        Returns the collected validation responses, in a map indexed by validationRequestId.
        Optionally takes a pre-existing mapping to update, for recursive calling.
        """
        if validation_responses_map is None:
            validation_responses_map = {}

        if not (self.initial_messages and self.initial_messages[-1]["role"] == "toolValidationResponses"):
            raise ValueError("Tool validation response not found in the chat history")

        validation_responses_message = self.initial_messages.pop()

        validation_responses = validation_responses_message.get("toolValidationResponses")
        if not validation_responses:
            raise ValueError("Invalid tool validation response was received")
        for tc in validation_responses:
            validation_responses_map[tc["validationRequestId"]] = tc

        if self.initial_messages and self.initial_messages[-1]["role"] == "toolValidationResponses":
            return self._collect_tool_validation_responses(validation_responses_map)
        else:
            return validation_responses_map

    def _collect_validation_requests(self) -> list[ToolValidationRequest]:
        """
        Pop and process a single, required, role = "toolValidationRequests" message to collect all validation requests.
        Returns the list of all collected validation requests.
        """
        if not (self.initial_messages and self.initial_messages[-1]["role"] == "toolValidationRequests"):
            raise ValueError("Tool validation response was received, but the original tool validation request wasn't provided in the chat history")

        validation_requests_message = self.initial_messages.pop()
        validation_requests = validation_requests_message.get("toolValidationRequests")
        if not validation_requests:
            raise ValueError("Tool validation response was received, but the tool validation requests in the chat history are empty")

        return validation_requests

    def _triage_validation_requests(self, validation_requests: list[ToolValidationRequest]) -> tuple[list[ToolValidationRequest], dict[str, list[ToolValidationRequest]]]:
        """
        Process the list of all validation requests previously collected, to separate those relative to this agent from those relative to nested tool calls.
        Returns the collected validation requests, in two separate structures:
        - own_validation_requests: a list of the validation requests that originated in this agent
        - deferred_validation_requests: a map of the validation requests that originated in nested tool calls, indexed by the tool call id.
        """
        own_validation_requests = []
        deferred_validation_requests = defaultdict(lambda: [])
        for tvr in validation_requests:
            hierarchy = tvr["hierarchy"]
            if not hierarchy:
                raise ValueError("No hierarchy found in tool validation request")

            agent_level = hierarchy.pop(0)
            if agent_level["type"] != "AGENT":
                raise ValueError("Invalid hierarchy in tool validation request")

            # sanity check: agentLoopIteration must match the restored iteration number from the memory fragment
            if "agentLoopIteration" not in agent_level or agent_level["agentLoopIteration"] is None:
                raise ValueError("Missing iteration number in tool validation request")
            if self.iteration_number == 0:
                raise ValueError("Agentic loop iteration number wasn't properly restored")
            if self.iteration_number != agent_level["agentLoopIteration"]:
                raise ValueError("Invalid iteration number in tool validation request")

            if len(hierarchy) > 0:
                tool_level = hierarchy.pop(0)
                if tool_level["type"] != "TOOL":
                    raise ValueError("Invalid hierarchy in tool validation request")
                deferred_validation_requests[tool_level["toolCallId"]].append(tvr)
            else:
                own_validation_requests.append(tvr)

        return own_validation_requests, deferred_validation_requests

    @staticmethod
    def _aggregate_own_validation_infos(raw_validation_responses_map: dict[str, ToolValidationResponse], validation_requests: list[ToolValidationRequest]) -> dict[str, ToolCallValidationInfo]:
        """
        Reconcile the validation responses and validation requests to return an aggregated map of tool call validation infos indexed by tool call id.
        """
        validation_infos_map: dict[str, ToolCallValidationInfo] = {}

        validation_request_ids = {vr["id"] for vr in validation_requests}
        for response_validation_request_id in raw_validation_responses_map.keys():
            if response_validation_request_id not in validation_request_ids:
                # validation response without a matching validation request
                raise ValueError(f"Tool validation response was received, but the corresponding tool validation request {response_validation_request_id} is missing")

        for validation_request in validation_requests:
            validation_request_id = validation_request["id"]
            tool_call_id = validation_request["toolCall"]["id"]
            if validation_request_id not in raw_validation_responses_map:
                # validation request without a matching validation response
                raise ValueError(f"No response provided for tool validation request {validation_request_id}")
            else:
                validation_response = raw_validation_responses_map[validation_request_id]
                validation_infos_map[tool_call_id] = ToolCallValidationInfo(
                    validated=validation_response["validated"],
                    allow_editing_inputs=validation_request["allowEditingInputs"],
                    edited_arguments=validation_response.get("arguments")
                )

        return validation_infos_map

    def _collect_and_triage_memory_fragments(self) -> tuple[MemoryFragment, dict[str, MemoryFragment]]:
        """
        Pop a single, required, role = "memoryFragment" message, and extract the memory fragments from its nested structure.
        Returns the collected memory fragments, in two separate structures:
        - own_memory_fragment: the (single) memory fragment encoding the past state of this agentic loop
        - deferred_memory_fragments: a map of the memory fragments that originated in nested tool calls, indexed by toolCallId
        """
        if not (self.initial_messages and self.initial_messages[-1]["role"] == "memoryFragment"):
            raise ValueError("Tool validation response was received but the memory fragment wasn't provided")

        own_memory_fragment_message = self.initial_messages.pop()
        own_memory_fragment = own_memory_fragment_message.get("memoryFragment")
        if not own_memory_fragment:
            raise ValueError("Tool validation response was received but the memory fragment wasn't provided")

        # extract nested memory fragments wrapped in messages
        own_memory_fragment_messages = own_memory_fragment.get("messages")
        if not own_memory_fragment_messages:
            raise ValueError("Tool validation response was received but the memory fragment was empty")

        deferred_memory_fragments: dict[str, MemoryFragment] = {}
        while own_memory_fragment_messages:
            if own_memory_fragment_messages[-1]["role"] != "memoryFragment":
                break
            memory_fragment_message = own_memory_fragment_messages.pop()
            memory_fragment = memory_fragment_message.get("memoryFragment")
            memory_fragment_target = memory_fragment_message.get("memoryFragmentTarget")
            if not memory_fragment:
                raise ValueError("Nested memory fragment message was empty")
            if not memory_fragment_target:
                raise ValueError("Missing target for nested memory fragment")
            if memory_fragment_target["type"] != "TOOL":
                raise ValueError("Invalid target type for nested memory fragment")
            if memory_fragment_target["toolCallId"] in deferred_memory_fragments:
                raise ValueError("There should only be one memory fragment per nested tool call")
            deferred_memory_fragments[memory_fragment_target["toolCallId"]] = memory_fragment

        return own_memory_fragment, deferred_memory_fragments

    def _restore_agentic_loop_state(self, memory_fragment: MemoryFragment, sc: SequenceContext) -> tuple[list[FunctionToolCall], dict[str, ToolOutput]]:
        """
        Restore the agentic loop state from a memory fragment, namely:
        - iteration_number
        - generated_messages
        - all_sources

        Also extract the curated list of tool calls still pending from the memory fragment, and returns this list.
        """

        # restore iteration number
        if "agentLoopIteration" not in memory_fragment or memory_fragment["agentLoopIteration"] is None:
            raise ValueError("Missing iteration number in memory fragment")
        self.iteration_number = memory_fragment["agentLoopIteration"]

        # process stashed messages
        memory_fragment_messages = memory_fragment.get("messages")
        if not memory_fragment_messages:
            raise ValueError("Tool validation response was received but the memory fragment was empty")

        # the last message should always be a partial tool outputs message
        partial_tool_outputs_message = memory_fragment_messages.pop()
        if partial_tool_outputs_message["role"] != "tool":
            raise ValueError("Tool validation response was received but the memory fragment is missing the partial tool outputs")

        # extract pending tool calls, which should be in the last message
        if not memory_fragment_messages:
            raise ValueError("Tool validation response was received but the memory fragment was incomplete")

        raw_pending_tool_calls: list[FunctionToolCall] = memory_fragment_messages[-1].get("toolCalls")
        if not raw_pending_tool_calls:
            raise ValueError("Tool validation response was received but the memory fragment did not contain any tool calls in its last message")

        # restore state of the agentic loop before the interruption
        sc.generated_messages.extend(memory_fragment_messages)

        memory_fragment_sources = memory_fragment.get("stashedSources") or []
        sc.sources.extend(memory_fragment_sources)

        # tool calls that already have an output are not pending anymore
        partial_tool_outputs = partial_tool_outputs_message.get("toolOutputs") or []
        partial_tool_outputs_map = {tool_output["callId"]: tool_output for tool_output in partial_tool_outputs}
        if partial_tool_outputs_map:
            still_pending = []
            for ptc in raw_pending_tool_calls:
                if ptc.get("id") not in partial_tool_outputs_map.keys():
                    still_pending.append(ptc)
            raw_pending_tool_calls = still_pending

        return raw_pending_tool_calls, partial_tool_outputs_map

    def _initialise_from_tool_validations(self, current_tools: dict[str, LLMTool], parent_trace: SpanBuilder, sc: SequenceContext) -> tuple[Collection[ToolCallWithPotentialValidation], dict[str, ToolCallValidationInfo], dict[str, ToolOutput]]:
        """
        Initialise the agentic loop so it can resume execution after an interruption for validating tool calls.
        - restore the agentic loop state to what it was before the interruption
        - collect the pending tool calls of the interrupted loop iteration so the agentic loop can start by running them
        - gather all the validation data relative to the pending tool calls

        The whole procedure runs inside a "DKU_AGENT_TOOL_CALLS_VALIDATIONS_CHECK" subspan of parent_trace, on which
        the number of accepted/rejected validation responses and of received validation requests are recorded.

        Returns:
        - the pending tool calls (a view over the values of the map built here, with edited inputs applied when allowed)
        - a map of the validation data relative to each pending tool call, indexed by tool call id
        - a map of the tool outputs already available in the memory fragment, indexed by tool call id

        Raises:
        - ValueError if a pending tool call targets a tool that is not in current_tools
        - ValueError if a tool call has both a deferred (nested) validation request and a validation of its own
        - ValueError if a pending tool call has no validation data, or if validation data targets a tool call
          that is not pending
        - ValueError if inputs were edited while editing is not (or no longer) allowed for the tool
        - whatever the collection/restoration helpers called here raise
        """
        with parent_trace.subspan("DKU_AGENT_TOOL_CALLS_VALIDATIONS_CHECK") as validations_check_trace:
            # -------------------------------------------------
            # 1) collect validation responses
            raw_validation_responses_map = self._collect_tool_validation_responses()

            n_validation_responses = len(raw_validation_responses_map)
            n_accepted = sum(vr.get("validated") or 0 for vr in raw_validation_responses_map.values())
            n_rejected = n_validation_responses - n_accepted
            validations_check_trace.attributes["nbAccepted"] = n_accepted
            validations_check_trace.attributes["nbRejected"] = n_rejected
            # -------------------------------------------------

            # -------------------------------------------------
            # 2) collect validation requests
            all_validation_requests = self._collect_validation_requests()

            validations_check_trace.attributes["nbReceivedValidationRequests"] = len(all_validation_requests)
            # -------------------------------------------------

            # -------------------------------------------------
            # 3) collect memory fragments to
            # - extract the pending tool calls
            # - restore the state of the agentic loop before the interruption
            own_memory_fragment, deferred_memory_fragments = self._collect_and_triage_memory_fragments()

            raw_pending_tool_calls, partial_tool_outputs_map = self._restore_agentic_loop_state(own_memory_fragment, sc)
            # -------------------------------------------------

            # -------------------------------------------------
            # 4) validate and parse pending tool calls
            # and check that the pending tool calls target tools that are still supported by this agent
            pending_tool_calls_map: dict[str, ToolCallWithPotentialValidation] = {}
            for tool_call in raw_pending_tool_calls:
                parsed_tool_call = _validate_and_parse_tool_call(tool_call)
                if parsed_tool_call.name not in current_tools:
                    raise ValueError(f"Entering agentic loop with tool calls that are not supported by this agent: {parsed_tool_call.name}")

                pending_tool_calls_map[parsed_tool_call.id] = parsed_tool_call
            # -------------------------------------------------

            # -------------------------------------------------
            # 5) split self-owned from nested validation requests
            own_validation_requests, deferred_validation_requests = self._triage_validation_requests(all_validation_requests)
            # -------------------------------------------------

            # -------------------------------------------------
            # 6) restore memory fragments and validation requests/responses on nested tool calls
            for tool_call in pending_tool_calls_map.values():
                if memory_fragment := deferred_memory_fragments.get(tool_call.id):
                    tool_call.memory_fragment = memory_fragment
                if tool_validation_requests := deferred_validation_requests.get(tool_call.id):
                    tool_call.tool_validation_requests = tool_validation_requests
                    tool_call.tool_validation_responses = []
                    for tvr in tool_validation_requests:
                        # we pop the validation response here to be sure to leave only own validation responses in the map at the end of this block
                        if tool_validation_response := raw_validation_responses_map.pop(tvr["id"], None):
                            tool_call.tool_validation_responses.append(tool_validation_response)

            # 6.5) blocks agent - Set the block id on the pending tool calls, so that we know which block to send it to.
            for validation_request in own_validation_requests:
                tool_call_id = validation_request["toolCall"]["id"]
                if tool_call_id in pending_tool_calls_map:
                    pending_tool_calls_map[tool_call_id].block_id = validation_request["blockId"]
            for tool_call_id, deferred_requests in deferred_validation_requests.items():
                # same guard as for own requests: a deferred request that targets a tool call which is not pending
                # must not fail here with a bare KeyError, it is reported with a descriptive ValueError in step 8
                if deferred_requests and tool_call_id in pending_tool_calls_map:
                    pending_tool_calls_map[tool_call_id].block_id = deferred_requests[0]["blockId"]
            # -------------------------------------------------

            # -------------------------------------------------
            # 7) aggregate validation infos
            validation_infos_map: dict[str, ToolCallValidationInfo] = {}

            # interrupted tool calls (nested validations)
            for tool_call_id in deferred_validation_requests:
                # if this tool call has a deferred validation request, then we must validate the tool call at this agent level
                # these tool calls have already been started on the previous turn, so they either did not require validation or they have already been validated on the previous turn
                validation_infos_map[tool_call_id] = ToolCallValidationInfo(validated=True, allow_editing_inputs=False)

            # not yet started tool calls (own validations)
            own_validation_infos_map = self._aggregate_own_validation_infos(raw_validation_responses_map, own_validation_requests)
            for tool_call_id, validation_info in own_validation_infos_map.items():
                if tool_call_id in validation_infos_map:
                    raise ValueError(f"Incompatible hierarchies of the tool validation requests, tool call {tool_call_id} received a deferred validation but also received a validation for itself at the same time")
                validation_infos_map[tool_call_id] = validation_info

            # all tool calls still pending at this point should have been covered by the above
            for tool_call_id in pending_tool_calls_map:
                if tool_call_id not in validation_infos_map:
                    raise ValueError(f"Pending tool call {tool_call_id} is missing validation request data")
            # -------------------------------------------------

            # -------------------------------------------------
            # 8) fix tool calls with edited inputs
            for tool_call_id, validation_info in validation_infos_map.items():
                # check that validated tool calls are all about pending tool calls
                if tool_call_id not in pending_tool_calls_map:
                    # validated tool call (request + response) but no matching pending tool call
                    raise ValueError(f"Tool validation response for tool call {tool_call_id} does not match any of the pending tool calls in the memory fragment")

                # if allowEditingInputs, update the function arguments with those from the validation response
                # we should do a proper update here, not a full replacement
                if validation_info.edited_arguments is not None:
                    pending_tool_call = pending_tool_calls_map[tool_call_id]
                    dku_tool_call = pending_tool_call.dku_tool_call
                    # editing is only honoured when it was allowed at request time AND is still allowed by the tool now
                    allow_editing_inputs = False
                    if validation_info.allow_editing_inputs:
                        tool_currently_allow_editing_inputs = current_tools[pending_tool_call.name].allow_editing_inputs
                        if tool_currently_allow_editing_inputs:
                            allow_editing_inputs = tool_currently_allow_editing_inputs
                            dku_tool_call["function"]["arguments"] = validation_info.edited_arguments
                            # re-parse the edited call and carry over what was attached to the original parsed call
                            new_parsed_tool_call = _validate_and_parse_tool_call(dku_tool_call)
                            new_parsed_tool_call.block_id = pending_tool_call.block_id
                            new_parsed_tool_call.memory_fragment = pending_tool_call.memory_fragment
                            new_parsed_tool_call.tool_validation_requests = pending_tool_call.tool_validation_requests
                            new_parsed_tool_call.tool_validation_responses = pending_tool_call.tool_validation_responses
                            pending_tool_calls_map[tool_call_id] = new_parsed_tool_call
                    if not allow_editing_inputs:
                        # arguments coming back unchanged are fine, compare the parsed JSON rather than the raw strings
                        original_arguments = dku_tool_call["function"]["arguments"] or "{}"
                        input_was_edited = json.loads(original_arguments, strict=False) != json.loads(validation_info.edited_arguments, strict=False)
                        if input_was_edited:
                            raise ValueError(f"Editing this tool's inputs is not allowed{' anymore' if validation_info.allow_editing_inputs else ''}")
            # -------------------------------------------------

            return pending_tool_calls_map.values(), validation_infos_map, partial_tool_outputs_map
