from __future__ import annotations

import logging
import tempfile
from collections import defaultdict
from typing import Dict, Generator, List, NamedTuple, Set, TypedDict

import kuzu
from dataiku import Dataset

from editor.backend.utils.webapp_config import webapp_config
from solutions.graph.graph_db_instance_manager import AbstractDbInstance, KuzuConnectionContextManager

from .dataiku.batch_generators import (
    calculate_estimated_total_batches,
    generate_edges_batches,
    generate_nodes_batches,
)
from .kuzu.batch_insert import batch_insert_edges, batch_insert_nodes
from .kuzu.schema.data_types import TYPE_MAPPING_DATAIKU_KUZU, DataikuDataType, KuzuDataType
from .kuzu.schema.table_builder import EdgeTableBuilder, NodeTableBuilder
from .models import (
    EdgeGroupDefinition,
    GraphMetadata,
    NodeGroupDefinition,
    Sampling,
    to_edge_group_definitions,
    to_node_group_definitions,
)

logger = logging.getLogger(__name__)


class PrimaryKeyParams(NamedTuple):
    """A primary-key column declared by one node definition, together with its Kuzu data type."""

    # Name of the dataset column the node definition uses as primary key.
    primary_key: str
    # Kuzu type obtained by mapping the column's Dataiku type.
    primary_key_type: KuzuDataType


class PropertyParams(NamedTuple):
    """A property column of a node or edge table, together with its Kuzu data type."""

    # Name of the dataset column stored as a property.
    property_name: str
    # Kuzu type obtained by mapping the column's Dataiku type.
    property_type: KuzuDataType


class NodeTableCreateParams:
    """Parameters for creating one Kuzu node table, accumulated over all definitions of a node group."""

    # Name of the node table; not set by __init__, assigned by the code that fills this object.
    table: str
    # Primary key columns collected from every definition of the group.
    possible_primary_keys: set[PrimaryKeyParams]
    # Property columns collected from every definition of the group.
    properties: set[PropertyParams]

    def __init__(self) -> None:
        # Each instance gets its own, initially empty, sets.
        self.properties = set()
        self.possible_primary_keys = set()


class EdgeTableCreateParams:
    """Parameters for creating one Kuzu edge table, accumulated over all definitions of an edge group."""

    # Node group (node table name) the edges start from; assigned by the code that fills this object.
    source_group: str
    # Node group (node table name) the edges point to; assigned by the code that fills this object.
    target_group: str
    # Name of the edge table; assigned by the code that fills this object.
    table: str
    # Property columns collected from every definition of the group.
    properties: set[PropertyParams]

    def __init__(self) -> None:
        # Each instance gets its own, initially empty, set.
        self.properties = set()


# Reserved property names added by GraphBuilder to every created table.

# Prefix of the internal primary-key property of a node table; the full name is "<prefix>_<table name>".
NODE_PRIMARY_KEY_PROP_NAME_PREFIX: str = "_dku_id"
# Property holding the id of the node/edge group a row comes from.
GROUP_ID_PROP_NAME: str = "_dku_grp_id"
# Property holding the id of the group definition a row comes from.
GROUP_DEFINITION_ID_PROP_NAME: str = "_dku_grp_def_id"
# Node property holding the *name* of the dataset column used as primary key.
USER_DEFINED_PRIMARY_KEY_PROP_NAME: str = "_dku_reserved_user_defined_pk_prop"


class ProgressInfo(TypedDict):
    """Progress report yielded once per inserted batch while building the graph."""

    # Insertion stage of the batch: "nodes" or "edges".
    stage: str
    # 1-based index of the batch that was just inserted.
    current_step: int
    # Expected number of batches (0 when unknown at the level that produced the report).
    total_steps: int
    # Node or edge group the batch belongs to.
    current_definition: str
    # Completion percentage (0 when unknown at the level that produced the report).
    percentage: float


class GraphBuilder:
    """
    Build a Kuzu graph database from the node and edge definitions of a graph configuration.

    Must be used as a context manager: entering opens a Kuzu connection and a temporary folder used to stage
    CSV files before insertion; exiting releases both.
    """

    def __init__(self, db_instance: AbstractDbInstance) -> None:
        """
        Args:
            db_instance: database instance the graph is written to; a new connection is requested from it
                when entering the context manager.
        """
        self.__db_instance = db_instance
        # Explicitly set to None: a bare annotation does not create the attribute, and using the builder
        # outside a `with` block must raise the documented RuntimeError, not an AttributeError.
        self.__kuzu_connection_context_manager: KuzuConnectionContextManager | None = None
        self.__connection: kuzu.Connection | None = None
        self.__temporary_folder: tempfile.TemporaryDirectory[str] | None = None

        logger.debug(f"Building the graph at {db_instance.get_db_path()}.")

    def __enter__(self) -> GraphBuilder:
        """Open the Kuzu connection and create the temporary CSV staging folder."""
        self.__kuzu_connection_context_manager = self.__db_instance.get_new_conn()
        try:
            self.__connection = self.__kuzu_connection_context_manager.open()

            # Temporary folder to store temporary CSV files before inserting them in the Kuzu database.
            self.__temporary_folder = tempfile.TemporaryDirectory()
        except BaseException:
            # Python does not call __exit__ when __enter__ raises: release what was acquired so far.
            self.__exit__(None, None, None)
            raise
        return self

    def __exit__(self, exception_type, exception_value, exception_traceback) -> None:
        """
        Close the connection, the connection context manager and delete the temporary folder.

        Each release step runs even if a previous one raised. Handles are reset to None so that any later use
        of the builder raises RuntimeError. Exceptions raised in the `with` body are not suppressed.
        """
        connection, self.__connection = self.__connection, None
        context_manager, self.__kuzu_connection_context_manager = self.__kuzu_connection_context_manager, None
        temporary_folder, self.__temporary_folder = self.__temporary_folder, None
        try:
            if connection is not None:
                connection.close()
        finally:
            try:
                if context_manager is not None:
                    context_manager.close()
            finally:
                if temporary_folder is not None:
                    temporary_folder.cleanup()

    def insert_all(self, graph_metadata: GraphMetadata, track_progress: bool) -> Generator[ProgressInfo, None, None]:
        """
        Insert all nodes, then all edges, described by `graph_metadata`.

        This is a generator: nothing happens until it is iterated. When `track_progress` is True and dataset
        metrics are available, one ProgressInfo is yielded per inserted batch, relative to the estimated total
        number of batches (percentage capped at 100 since the total is an estimate). Otherwise everything is
        inserted without yielding anything.

        Raises:
            RuntimeError: if the builder is not used inside a `with` block.
            AssertionError: if definitions of a same group are inconsistent (see insert_nodes / insert_edges).
        """
        nodes_definitions: List[NodeGroupDefinition] = []
        for node in graph_metadata["nodes"].values():
            nodes_definitions.extend(to_node_group_definitions(node))

        edges_definitions: List[EdgeGroupDefinition] = []
        for edge in graph_metadata["edges"].values():
            edges_definitions.extend(to_edge_group_definitions(edge, graph_metadata))

        sampling = graph_metadata["sampling"]
        # Lazy generator: creating it has no side effect, insertion starts when it is iterated.
        batches = self.__insert_nodes_then_edges__(nodes_definitions, edges_definitions, sampling)

        # Only calculate estimated total batches if we need to track progress
        if not track_progress:
            logger.info("Progress tracking disabled. Processing without progress updates.")
            for _ in batches:
                pass
            return

        total_overall_batches = calculate_estimated_total_batches(nodes_definitions, edges_definitions, sampling)

        # If total_overall_batches is -1, metrics are not available and we triggered computation.
        # Don't show progress in this case, just process normally without yielding progress.
        if total_overall_batches == -1:
            logger.info(
                "Dataset metrics not available, triggered metrics computation. Processing without progress updates."
            )
            for _ in batches:
                pass
            return

        for current_overall_batch, progress in enumerate(batches, start=1):
            yield self.__to_overall_progress__(progress, current_overall_batch, total_overall_batches)

    def __insert_nodes_then_edges__(
        self,
        nodes_definitions: List[NodeGroupDefinition],
        edges_definitions: List[EdgeGroupDefinition],
        sampling: Sampling | None,
    ) -> Generator[ProgressInfo, None, None]:
        """Yield per-batch progress of node insertion, then of edge insertion (edge tables reference node tables)."""
        yield from self.insert_nodes(nodes_definitions, sampling)
        yield from self.insert_edges(edges_definitions, sampling)

    def __to_overall_progress__(self, progress: ProgressInfo, current_step: int, total_steps: int) -> ProgressInfo:
        """Re-express a per-stage progress report relative to the overall (nodes + edges) batch count."""
        if total_steps > 0:
            # The total is an estimate: never report more than 100%.
            percentage = min((current_step / total_steps) * 100.0, 100.0)
        else:
            percentage = 100.0
        return {
            "stage": progress["stage"],
            "current_step": current_step,
            "total_steps": total_steps,
            "current_definition": progress["current_definition"],
            "percentage": percentage,
        }

    def insert_nodes(
        self, definitions: List[NodeGroupDefinition], sampling: Sampling | None
    ) -> Generator[ProgressInfo, None, None]:
        """
        Create one Kuzu node table per node group, then insert nodes based on each node definition passed as
        parameter.

        Generator: yields one ProgressInfo per inserted batch; `current_step` counts batches within this stage
        while `total_steps` and `percentage` are left at 0.

        Raises:
            RuntimeError: if the builder is not used inside a `with` block.
            AssertionError: if definitions of a same node group disagree on the primary key type or on the
                type of a property.
        """
        if not self.__connection or not self.__temporary_folder:
            raise RuntimeError("GraphBuilder should be initialized in a context manager.")

        # Create kuzu tables based on node definitions passed as argument.
        tables_to_create: List[NodeTableCreateParams] = self.__infer_node_tables_to_create__(definitions)

        for params in tables_to_create:
            # Validate primary keys defined in definitions have identical data type.
            # Explicit raise rather than `assert` so the check survives `python -O`.
            pk_unique_types = {pk.primary_key_type for pk in params.possible_primary_keys}
            if len(pk_unique_types) != 1:
                raise AssertionError(f"Conflicting types for primary key {params}.")
            (primary_key_type,) = pk_unique_types

            self.__validate_consistent_property_type__(params.properties)

            # Sorted so the table's column order is deterministic: set iteration order of strings depends on
            # hash randomization and therefore changes between runs.
            primary_key_columns = sorted(params.possible_primary_keys, key=lambda pk: pk.primary_key)
            property_columns = sorted(params.properties, key=lambda prop: prop.property_name)

            (
                NodeTableBuilder(self.__connection)
                .with_table_name(params.table)
                .with_primary_key(self.__get_table_id_prop_name__(params.table), primary_key_type)
                .with_property(USER_DEFINED_PRIMARY_KEY_PROP_NAME, "STRING")
                .with_property(GROUP_ID_PROP_NAME, "STRING")
                .with_property(GROUP_DEFINITION_ID_PROP_NAME, "STRING")
                .with_properties([(pk.primary_key, pk.primary_key_type) for pk in primary_key_columns])
                .with_properties([(prop.property_name, prop.property_type) for prop in property_columns])
                .create()
            )

        # Insert nodes in kuzu by iterating through all relevant datasets.
        logger.info("Start inserting nodes...")

        for current_batch, batch in enumerate(generate_nodes_batches(definitions, sampling), start=1):
            definition = batch["definition"]
            node_group = definition["node_group"]

            df = batch["df"]
            primary_col = definition["primary_col"]
            # The table-wide primary key is a copy of the definition's own primary key column.
            df[self.__get_table_id_prop_name__(node_group)] = df[primary_col]
            df[GROUP_ID_PROP_NAME] = definition["node_id"]
            df[GROUP_DEFINITION_ID_PROP_NAME] = definition["definition_id"]
            # Stores the *name* of the primary key column (same value on every row).
            df[USER_DEFINED_PRIMARY_KEY_PROP_NAME] = primary_col

            batch_insert_nodes(self.__temporary_folder.name, self.__connection, node_group, df)

            yield {
                "stage": "nodes",
                "current_step": current_batch,
                "total_steps": 0,
                "current_definition": node_group,
                "percentage": 0,
            }

        logger.info("Done inserting nodes.")

    def insert_edges(
        self, definitions: List[EdgeGroupDefinition], sampling: Sampling | None
    ) -> Generator[ProgressInfo, None, None]:
        """
        Create one Kuzu edge table per edge group, then insert edges based on each edge definition passed as
        parameter. Node tables referenced by the edge groups must already exist.

        Generator: yields one ProgressInfo per inserted batch; `current_step` counts batches within this stage
        while `total_steps` and `percentage` are left at 0.

        Raises:
            RuntimeError: if the builder is not used inside a `with` block.
            AssertionError: if definitions of a same edge group disagree on the type of a property or on their
                source/target node groups.
        """
        if not self.__connection or not self.__temporary_folder:
            raise RuntimeError("GraphBuilder should be initialized in a context manager.")

        # Create kuzu tables based on edge definitions passed as argument.
        tables_to_create: List[EdgeTableCreateParams] = self.__infer_edge_tables_to_create__(definitions)

        for params in tables_to_create:
            self.__validate_consistent_property_type__(params.properties)

            # Sorted for a deterministic column order (see insert_nodes).
            property_columns = sorted(params.properties, key=lambda prop: prop.property_name)

            (
                EdgeTableBuilder(self.__connection)
                .with_table_name(params.table)
                .from_node(params.source_group)
                .to_node(params.target_group)
                .with_property(GROUP_ID_PROP_NAME, "STRING")
                .with_property(GROUP_DEFINITION_ID_PROP_NAME, "STRING")
                .with_properties([(prop.property_name, prop.property_type) for prop in property_columns])
                .create()
            )

        # Insert edges in kuzu by iterating through all relevant datasets.
        logger.info("Start inserting edges...")

        for current_batch, batch in enumerate(generate_edges_batches(definitions, sampling), start=1):
            definition = batch["definition"]
            edge_group = definition["edge_group"]

            df = batch["df"]
            df[GROUP_ID_PROP_NAME] = definition["edge_id"]
            df[GROUP_DEFINITION_ID_PROP_NAME] = definition["definition_id"]

            # Columns passed to batch_insert_edges as the edge property columns.
            filters_props = [f["column"] for f in definition["filters_stored"]]
            properties = (
                [GROUP_ID_PROP_NAME, GROUP_DEFINITION_ID_PROP_NAME] + definition["property_list"] + filters_props
            )

            batch_insert_edges(self.__temporary_folder.name, self.__connection, edge_group, properties, df)

            yield {
                "stage": "edges",
                "current_step": current_batch,
                "total_steps": 0,
                "current_definition": edge_group,
                "percentage": 0,
            }

        logger.info("Done inserting edges.")

    def delete_group(self, group: str) -> None:
        """
        Drop the Kuzu table named `group`, if it exists.

        Raises:
            RuntimeError: if the builder is not used inside a `with` block.
            ValueError: if `group` contains a backtick, which would break out of the quoted identifier of the
                query.
        """
        if not self.__connection:
            raise RuntimeError("GraphBuilder should be initialized in a context manager.")
        # The name is interpolated into the query text: refuse anything that could terminate the quoting.
        if "`" in group:
            raise ValueError(f"Invalid group name: {group!r}.")

        self.__connection.execute(f"""DROP TABLE IF EXISTS `{group}`""")

    def __validate_consistent_property_type__(self, properties: Set[PropertyParams]) -> None:
        """
        Validate that properties sharing a name also share a data type.

        Raises:
            AssertionError: if a property name is associated with more than one type.
        """
        types_by_name: Dict[str, Set[KuzuDataType]] = defaultdict(set)
        for prop in properties:
            types_by_name[prop.property_name].add(prop.property_type)
        # Explicit raise rather than `assert` so the check survives `python -O`.
        if any(len(types) != 1 for types in types_by_name.values()):
            raise AssertionError(f"Conflicting types for property: {types_by_name}")

    def __read_dataset_schemas__(self, datasets: Set[str]) -> Dict[str, Dict[str, DataikuDataType]]:
        """Return, for each dataset name, a mapping from column name to Dataiku column type."""
        project_key = webapp_config.default_project_key
        return {
            dataset: {column["name"]: column["type"] for column in Dataset(dataset, project_key).read_schema()}
            for dataset in datasets
        }

    def __infer_node_tables_to_create__(self, definitions: List[NodeGroupDefinition]) -> List[NodeTableCreateParams]:
        """
        Merge node definitions into one NodeTableCreateParams per node group.

        Column types come from the schema of each definition's source dataset, mapped to Kuzu types. Every
        primary key column is collected; the label column, the declared properties and the stored filter
        columns become properties.
        """
        ds_schemas = self.__read_dataset_schemas__({d["source_dataset"] for d in definitions})

        tables_to_create: Dict[str, NodeTableCreateParams] = {}
        for definition in definitions:
            node_group = definition["node_group"]
            ds_schema = ds_schemas[definition["source_dataset"]]

            params = tables_to_create.get(node_group)
            if params is None:
                params = NodeTableCreateParams()
                params.table = node_group
                tables_to_create[node_group] = params

            pk_column = definition["primary_col"]
            params.possible_primary_keys.add(
                PrimaryKeyParams(pk_column, TYPE_MAPPING_DATAIKU_KUZU[ds_schema[pk_column]])
            )

            property_columns = [
                definition["label_col"],
                *definition["property_list"],
                *(f["column"] for f in definition["filters_stored"]),
            ]
            params.properties.update(
                PropertyParams(column, TYPE_MAPPING_DATAIKU_KUZU[ds_schema[column]]) for column in property_columns
            )

        return list(tables_to_create.values())

    def __infer_edge_tables_to_create__(self, definitions: List[EdgeGroupDefinition]) -> List[EdgeTableCreateParams]:
        """
        Merge edge definitions into one EdgeTableCreateParams per edge group.

        Column types come from the schema of each definition's edge dataset, mapped to Kuzu types. The declared
        properties and the stored filter columns become properties.

        Raises:
            AssertionError: if definitions of a same edge group connect different source/target node groups;
                the created table can only link one pair.
        """
        ds_schemas = self.__read_dataset_schemas__({d["edge_dataset"] for d in definitions})

        tables_to_create: Dict[str, EdgeTableCreateParams] = {}
        for definition in definitions:
            edge_group = definition["edge_group"]
            ds_schema = ds_schemas[definition["edge_dataset"]]
            source_group = definition["source_node_group"]
            target_group = definition["target_node_group"]

            params = tables_to_create.get(edge_group)
            if params is None:
                params = EdgeTableCreateParams()
                params.table = edge_group
                params.source_group = source_group
                params.target_group = target_group
                tables_to_create[edge_group] = params
            elif (params.source_group, params.target_group) != (source_group, target_group):
                raise AssertionError(
                    f"Conflicting node groups for edge group {edge_group}: "
                    f"{params.source_group} -> {params.target_group} vs {source_group} -> {target_group}."
                )

            property_columns = [
                *definition["property_list"],
                *(f["column"] for f in definition["filters_stored"]),
            ]
            params.properties.update(
                PropertyParams(column, TYPE_MAPPING_DATAIKU_KUZU[ds_schema[column]]) for column in property_columns
            )

        return list(tables_to_create.values())

    def __get_table_id_prop_name__(self, table_name: str) -> str:
        """Return the name of the internal primary key property of node table `table_name`."""
        return f"{NODE_PRIMARY_KEY_PROP_NAME_PREFIX}_{table_name}"
