import logging
from concurrent.futures import ThreadPoolExecutor, as_completed

from dataikuapi.dss.llm import DSSLLMStreamedCompletionFooter, DSSLLMStreamedCompletionChunk
from dataikuapi.dss.llm_tracing import SpanBuilder
from dataiku.llm.python.blocks_graph import NextBlock, BlockHandler


# Module-level logger used by the parallel-block handler in this file to report
# sub-chain progress and unexpected stream chunks.
logger = logging.getLogger("dku.agents.blocks_graph")


class ParallelBlockHandler(BlockHandler):
    """Block that runs several sub-block chains concurrently and gathers their outputs.

    Each id in ``block_config["blockIds"]`` is the entry block of a chain. A chain keeps
    following the ``NextBlock`` ids emitted by its blocks until a block emits none.
    Chains run on a thread pool of at most ``block_config["maxThreads"]`` workers.

    Once every chain has finished:

    - its generated messages are appended to this block's sequence context;
    - its non-None ``last_text_output`` values are stored, as a list in ``blockIds``
      order, under ``targetOutputKey`` in the scratchpad or the turn state, depending
      on ``generatedOutputStorageLocation``.

    The block then yields ``NextBlock`` pointing at ``block_config.get("nextBlock")``.
    """

    def __init__(self, turn, sequence_context, block_config):
        super().__init__(turn, sequence_context, block_config)

    @staticmethod
    def _consume_sub_block_stream(block_handler, block_trace):
        """Drain one sub-block's stream and return the id of the block it chains to.

        Streamed text deltas are not forwarded by the parallel block. Results are read
        from the chain's sequence context once the chain is over.

        :param block_handler: handler whose ``process_stream`` is consumed.
        :param block_trace: span passed to ``process_stream``.
        :returns: the id from the last ``NextBlock`` chunk, or None if none was emitted.
        :raises Exception: if the sub-block finishes with ``tool_validation_requests``.
        """
        next_block_id = None
        for chunk in block_handler.process_stream(block_trace):
            if isinstance(chunk, DSSLLMStreamedCompletionChunk):
                pass  # text delta: intentionally not forwarded
            elif isinstance(chunk, DSSLLMStreamedCompletionFooter):
                # Guard against a missing or None ``data`` payload on the footer.
                finish_reason = (getattr(chunk, "data", None) or {}).get("finishReason")
                if finish_reason == "tool_validation_requests":
                    # TODO: @structured-visual-agents: implement HITL in parallel blocks
                    raise Exception("Tool call requires human validation. Currently not supported in parallel blocks.")
            elif isinstance(chunk, dict) and (chunk.get("chunk") or {}).get("text") is not None:
                pass  # raw-dict text delta: intentionally not forwarded
            elif isinstance(chunk, NextBlock):
                next_block_id = chunk.id
            else:
                logger.warning("Unexpected chunk type in parallel block: %s", chunk)
        return next_block_id

    def run_chain(self, trace, block_id, base_sequence_context):
        """Run the chain of sub-blocks starting at ``block_id``; return its sequence context.

        The chain gets ONE copy of ``base_sequence_context``, shared by every block in
        it, so a chained block sees what earlier blocks in the same chain produced.
        This method only calls ``.copy()`` on ``base_sequence_context``.

        NOTE(review): there is no cycle detection. If the blocks keep routing back to
        each other, the chain keeps running.

        :param trace: span under which one ``DKU_AGENT_PARALLEL_SUB_BLOCK`` subspan is
            opened per block.
        :param block_id: id of the first block of the chain.
        :param base_sequence_context: context the chain's private copy is made from.
        :returns: the chain's sequence context after its last block has run.
        """
        # Copy once per chain, not once per block. Copying inside the loop would drop
        # everything produced by earlier blocks of the chain.
        sc = base_sequence_context.copy()
        while True:
            block_handler = self.turn.build_block_handler(block_id, sc)
            logger.info("Parallel Block running subblock %s, block handler %s", block_id, block_handler)

            with trace.subspan("DKU_AGENT_PARALLEL_SUB_BLOCK") as block_trace:
                block_trace.attributes["block_id"] = block_id
                next_block_id = self._consume_sub_block_stream(block_handler, block_trace)

            if next_block_id is None:
                logger.info("No next block from subblock %s", block_id)
                return sc
            logger.info("Chaining to next block %s from subblock %s", next_block_id, block_id)
            block_id = next_block_id

    def process_stream(self, trace: SpanBuilder):
        """Run all configured chains in parallel, store their outputs, and yield the next block.

        :param trace: span under which one ``DKU_AGENT_PARALLEL_SUB_CHAIN`` subspan is
            opened per chain.
        :yields: a single ``NextBlock`` for ``block_config.get("nextBlock")``.
        :raises Exception: if the block lists its own id among ``blockIds``. Also
            re-raises the first failing chain's exception, in ``blockIds`` order.
        """
        logger.info("Parallel block starting with config %s", self.block_config)

        block_ids = self.block_config["blockIds"]
        if not block_ids:
            yield NextBlock(id=self.block_config.get("nextBlock", None))
            return
        if self.block_config.get("id") in block_ids:
            raise Exception("Parallel block cannot execute itself")

        def run_subchain(block_id):
            logger.info("Parallel Block running subchain %s", block_id)
            with trace.subspan("DKU_AGENT_PARALLEL_SUB_CHAIN") as block_trace:
                return self.run_chain(block_trace, block_id, self.sequence_context)

        # ThreadPoolExecutor rejects max_workers <= 0, so always allow at least one worker.
        max_workers = max(1, min(len(block_ids), int(self.block_config["maxThreads"])))
        with ThreadPoolExecutor(max_workers=max_workers) as executor:
            futures = [executor.submit(run_subchain, block_id) for block_id in block_ids]
            # Collect in submission order so outputs follow blockIds deterministically.
            sub_contexts = [future.result() for future in futures]

        # Merge only after every worker is done. Workers copy self.sequence_context, so
        # mutating it while they run would race with those copies.
        outputs = []
        for sc in sub_contexts:
            if sc is None:
                continue
            self.sequence_context.generated_messages.extend(sc.generated_messages)
            if sc.last_text_output is not None:
                outputs.append(sc.last_text_output)

        output_location = self.block_config["generatedOutputStorageLocation"]
        output_key = self.block_config["targetOutputKey"]
        if output_location == "SCRATCHPAD":
            self.sequence_context.scratchpad[output_key] = outputs
        elif output_location == "STATE":
            self.turn.state_set(output_key, outputs)
        # Any other location: outputs are not stored.

        logger.info("Parallel block is over")
        yield NextBlock(id=self.block_config.get("nextBlock", None))
