from typing import Any, Dict

from kuzu import QueryResult, Type
from networkx import MultiDiGraph

from ..graph_builder import GROUP_ID_PROP_NAME


def get_as_networkx(query_result: QueryResult) -> MultiDiGraph:
    """
    Convert a Kùzu query result into a NetworkX ``MultiDiGraph``.

    This code is heavily inspired by the ``kuzu.QueryResult.get_as_networkx`` method and changed
    to fit our own usage. In particular, graph node ids are encoded as
    ``"<group id>~<primary key value>"`` (see ``encode_node_id``).

    Every column of type NODE, REL or RECURSIVE_REL is collected from every row. Nodes and
    relationships are de-duplicated by their internal ``(table, offset)`` id, so an item that
    appears in several rows (or both as a plain column and inside a path) is added only once.

    Note: the node/relationship dicts returned by ``query_result`` are mutated in place
    (``_id`` is removed from nodes; ``None``-valued properties are removed from relationships
    that come from recursive paths).

    Args:
        query_result: An open query result. Its iterator is reset, then consumed to the end.

    Returns:
        A ``MultiDiGraph`` with one node per distinct node (attributes: the node's fields
        except ``_id``) and one edge per distinct relationship (attributes: all of the
        relationship's fields).

    Raises:
        Exception: If a row is not returned as a list.
        KeyError: If a relationship's source or destination node is not itself part of the
            query result, or if no primary key property is found for a node label.
    """
    query_result.check_for_query_result_close()
    properties_to_extract = query_result._get_properties_to_extract()

    query_result.reset_iterator()

    # Both maps are keyed by the internal (table, offset) id to de-duplicate across rows.
    nodes: dict[tuple[int, int], dict[str, Any]] = {}
    rels: dict[tuple[int, int], dict[str, Any]] = {}
    # Node label -> name of that table's primary key property; filled lazily while adding nodes.
    table_primary_key_dict: Dict[str, str] = {}

    def internal_id(item_id: dict[str, Any]) -> tuple[int, int]:
        """Return the ``(table, offset)`` pair of an ``_id``/``_src``/``_dst`` value."""
        return item_id["table"], item_id["offset"]

    def encode_node_id(node: dict[str, Any]) -> str:
        """Build the graph node id ``"<group id>~<primary key value>"``."""
        node_label = node["_label"]
        group_id = node[GROUP_ID_PROP_NAME]
        return f"{group_id}~{node[table_primary_key_dict[node_label]]!s}"

    def encode_rel_id(rel: dict[str, Any]) -> tuple[int, int]:
        """Return the internal id used to de-duplicate a relationship."""
        return internal_id(rel["_id"])

    def endpoint_node(rel: dict[str, Any], side: str) -> dict[str, Any]:
        """Look up the collected node referenced by ``rel[side]`` (``"_src"`` or ``"_dst"``)."""
        key = internal_id(rel[side])
        try:
            return nodes[key]
        except KeyError:
            # Happens when the query returns a relationship without its endpoint nodes.
            raise KeyError(
                f"Relationship {encode_rel_id(rel)} references node {key} ({side}) which is not "
                "part of the query result; return the endpoint nodes as well."
            ) from None

    # De-duplicate nodes and rels
    while query_result.has_next():
        row = query_result.get_next()
        if not isinstance(row, list):
            raise Exception("Unexpected type when turning result in a networkx graph.")

        for i in properties_to_extract:
            value = row[i]
            # Skip empty nodes and rels, which may be returned by OPTIONAL MATCH
            if value is None or value == {}:
                continue
            column_type, _ = properties_to_extract[i]
            if column_type == Type.NODE.value:
                nodes[internal_id(value["_id"])] = value

            elif column_type == Type.REL.value:
                rels[encode_rel_id(value)] = value

            elif column_type == Type.RECURSIVE_REL.value:
                for node in value["_nodes"]:
                    nodes[internal_id(node["_id"])] = node
                for rel in value["_rels"]:
                    # Drop None-valued properties so they don't become None edge attributes.
                    for key in [k for k, v in rel.items() if v is None]:
                        del rel[key]
                    rels[encode_rel_id(rel)] = rel

    nx_graph = MultiDiGraph()

    # Add nodes
    for node in nodes.values():
        node_label = node["_label"]
        if node_label not in table_primary_key_dict:
            props = query_result.connection._get_node_property_names(node_label)  # type: ignore
            for prop_name in props:
                if props[prop_name]["is_primary_key"]:
                    table_primary_key_dict[node_label] = prop_name
                    break
        node_id = encode_node_id(node)
        # The internal id was only needed for de-duplication; don't expose it as an attribute.
        del node["_id"]
        nx_graph.add_node(node_id, **node)

    # Add rels
    for rel in rels.values():
        src_id = encode_node_id(endpoint_node(rel, "_src"))
        dst_id = encode_node_id(endpoint_node(rel, "_dst"))
        nx_graph.add_edge(src_id, dst_id, **rel)
    return nx_graph
