from __future__ import annotations

import logging
from typing import Any, Dict, List

import kuzu
from networkx import MultiDiGraph
from pandas import DataFrame

from ..graph_builder import GROUP_DEFINITION_ID_PROP_NAME
from ..graph_db_instance_manager import AbstractDbInstance
from ..kuzu.query_result_extensions import get_as_networkx
from ..models import (
    EdgeGroupDefinition,
    EdgeGroupId,
    GraphDBDoesNotExistError,
    NodeGroupDefinition,
    NodeGroupId,
)

logger = logging.getLogger(__name__)


class CypherQueryExecutionException(Exception):
    """
    Raised when a Cypher query fails to execute.

    Exists so that callers can tell failures caused by invalid user queries apart from
    genuine misbehaviors of VGE itself.
    """


class CypherGraphView:
    """
    Runs Cypher queries against a graph DB instance and exposes the last result both as a
    pandas DataFrame and as a MultiDiGraph whose nodes/edges are tagged with the node/edge
    groups described by the given group definitions.
    """

    def __init__(
        self,
        db_instance: AbstractDbInstance,
        node_definitions: List[NodeGroupDefinition],
        edge_definitions: List[EdgeGroupDefinition],
    ) -> None:
        self.__db_instance = db_instance
        self.__node_definitions = node_definitions
        self.__edge_definitions = edge_definitions

        # Rebuilt by `execute`; empty until a query has run.
        self.__graph: MultiDiGraph = MultiDiGraph()
        self.__df_result: DataFrame = DataFrame()

    def get_nodes(self, node_id: NodeGroupId) -> List[Dict[str, Any]]:
        """Return the attribute dicts of every node of the last result whose `group` equals `node_id`."""
        return [attrs for _, attrs in self.__graph.nodes(data=True) if attrs.get("group") == node_id]

    def get_edges(self, edge_id: EdgeGroupId) -> List[Dict[str, Any]]:
        """
        Return every edge of the last result whose `group` equals `edge_id`.

        Each edge is returned as a new dict holding its multigraph `key` plus its attributes.
        """
        return [
            {"key": key, **data}
            for _, _, key, data in self.__graph.edges(data=True, keys=True)  # type: ignore
            if data.get("group") == edge_id
        ]

    def get_as_df(self) -> DataFrame:
        """Return the DataFrame of the last executed query (empty if none ran or the last one failed)."""
        return self.__df_result

    def execute(self, query: str, params: Dict[str, Any] | None = None, timeout_seconds: int | None = None) -> None:
        """
        Run a Cypher query and rebuild this view (graph and DataFrame) from its result.

        Args:
            query: Cypher query text.
            params: Optional query parameters, forwarded to the kuzu connection's `execute`.
            timeout_seconds: Forwarded to `get_new_conn` when opening the connection.

        Raises:
            GraphDBDoesNotExistError: if the database instance does not exist.
            GraphDBWriteInProgressError
            GraphDBReadInProgressError
            CypherQueryExecutionException: if executing the query or converting its result fails.
            ValueError: if the result references a group definition id that is not configured.
        """
        # Reset both views before anything can fail, so a failed query never leaves
        # results from a previous call visible through `get_as_df` / `get_nodes`.
        self.__graph = MultiDiGraph()
        self.__df_result = DataFrame()

        if not self.__db_instance.exists():
            raise GraphDBDoesNotExistError

        with self.__db_instance.get_new_conn(timeout_seconds) as conn_context_manager:
            result: kuzu.QueryResult | List[kuzu.QueryResult] | None = None
            try:
                result = conn_context_manager.connection.execute(query, params)
                if not isinstance(result, kuzu.QueryResult):
                    # A list of results (presumably from a multi-statement query) is not supported.
                    raise Exception("Unexpected kuzu result type.")
                graph: MultiDiGraph = get_as_networkx(result)
                self.__df_result = result.get_as_df()
            except Exception as ex:
                logger.info("Failed to execute query %s.", query)
                raise CypherQueryExecutionException(str(ex)) from ex
            finally:
                # Origin still unclear: if the QueryResult is not closed explicitly here, a
                # semaphore-related error occurs. Possibly due to the order in which the
                # connection and the QueryResult are released when they go out of scope.
                self._close_query_results(result)

        self.__load_node_groups_from_graph__(graph, self.__node_definitions)
        self.__load_edge_groups_from_graph__(graph, self.__edge_definitions)

    def __load_node_groups_from_graph__(
        self, graph: MultiDiGraph, group_definitions: List[NodeGroupDefinition]
    ) -> None:
        """
        Copy every node of `graph` into the view graph, tagged with its node group.

        Each node must carry GROUP_DEFINITION_ID_PROP_NAME naming one of `group_definitions`.
        Its remaining attributes (minus `_label`) are stored under `properties`. Nodes
        already present in the view graph are left untouched.

        Raises:
            ValueError: if a node references a definition id absent from `group_definitions`.
        """
        definitions_by_id = self._index_by_definition_id(group_definitions)
        for node_instance_id, data in graph.nodes(data=True):
            node_definition_id = data[GROUP_DEFINITION_ID_PROP_NAME]
            node_definition = definitions_by_id.get(node_definition_id)
            if node_definition is None:
                raise ValueError(f"Unknown node group definition id: {node_definition_id!r}")

            data.pop("_label", None)
            if self.__graph.has_node(node_instance_id):
                continue

            self.__graph.add_node(
                node_instance_id,
                label=data[node_definition["label_col"]],
                id=node_instance_id,
                name=node_definition["node_group"],
                group=node_definition["node_id"],
                properties=data,
            )

    def __load_edge_groups_from_graph__(
        self, graph: MultiDiGraph, group_definitions: List[EdgeGroupDefinition]
    ) -> None:
        """
        Copy every edge of `graph` whose endpoints are both in the view graph, tagged with its edge group.

        Only the attributes listed in the definition's `property_list` are kept, under `properties`.
        Edges with a missing endpoint are skipped.

        Raises:
            ValueError: if an edge references a definition id absent from `group_definitions`.
        """
        definitions_by_id = self._index_by_definition_id(group_definitions)
        for source_node_instance_id, target_node_instance_id, data in graph.edges(data=True):
            edge_definition_id = data[GROUP_DEFINITION_ID_PROP_NAME]
            edge_definition = definitions_by_id.get(edge_definition_id)
            if edge_definition is None:
                raise ValueError(f"Unknown edge group definition id: {edge_definition_id!r}")

            if not (
                self.__graph.has_node(source_node_instance_id) and self.__graph.has_node(target_node_instance_id)
            ):
                continue

            allowed_properties = set(edge_definition["property_list"])
            edge_instance = {k: v for k, v in data.items() if k in allowed_properties}
            self.__graph.add_edge(
                source_node_instance_id,
                target_node_instance_id,
                # key=None lets networkx generate a fresh key, so parallel edges are all kept.
                key=None,
                # NOTE(review): this id is shared by parallel edges between the same pair and
                # may collide across pairs (e.g. "a"+"bc" == "ab"+"c"); confirm callers don't need uniqueness.
                id=source_node_instance_id + target_node_instance_id,
                group=edge_definition["edge_id"],
                to=target_node_instance_id,
                # `from` is a Python keyword, so it can only be passed via dict unpacking.
                **{"from": source_node_instance_id},
                properties=edge_instance,
            )

    @staticmethod
    def _index_by_definition_id(group_definitions: List[Any]) -> Dict[Any, Any]:
        """Map each definition's `definition_id` to the definition; the first occurrence wins on duplicates."""
        # Iterating in reverse lets earlier definitions overwrite later ones, matching a first-match scan.
        return {definition["definition_id"]: definition for definition in reversed(group_definitions)}

    @staticmethod
    def _close_query_results(result: Any) -> None:
        """Best-effort close of a kuzu QueryResult or list of them; failures are logged, never raised."""
        if result is None:
            return
        if isinstance(result, kuzu.QueryResult):
            try:
                result.close()
            except Exception:
                logger.exception("Failed to close the query result.")
            return
        for query_result in result:
            try:
                query_result.close()
            except Exception:
                logger.exception("Failed to close one of the query result.")
