import logging
from typing import Any, Generator, Set

import kuzu
import pandas as pd
from pydantic import BaseModel, Field, PositiveFloat, PositiveInt

logger = logging.getLogger(__name__)


class ProjectedGraph:
    """Description of a named projected graph built from node and edge groups.

    An instance only stores the projection's name and the group names it
    covers; nothing is sent to the database until ``create`` is called.
    """

    def __init__(self, name: str, node_groups: Set[str], edge_groups: Set[str]):
        """Store the projection name and the node/edge group names it includes."""
        self.name = name
        self.node_groups = node_groups
        self.edge_groups = edge_groups

    @staticmethod
    def _quoted_list(groups: Set[str]) -> str:
        """Render group names as a comma-separated list of single-quoted literals.

        Names are sorted because set iteration order is not stable across
        interpreter runs; sorting keeps the generated query (and its log line)
        reproducible.
        """
        # NOTE(review): names are interpolated verbatim; a name containing a
        # single quote would break the statement — confirm callers only pass
        # trusted table names.
        return ", ".join(f"'{group}'" for group in sorted(groups))

    def create(self, conn: kuzu.Connection) -> None:
        """Issue ``CALL project_graph`` on ``conn`` for this projection.

        Args:
            conn: connection on which the projected graph is created.
        """
        node_groups_str = self._quoted_list(self.node_groups)
        edge_groups_str = self._quoted_list(self.edge_groups)

        query = f"""
        CALL project_graph(
            '{self.name}',
            [{node_groups_str}],
            [{edge_groups_str}]
        )
        """
        logger.info("Creating projected graph with query: %s.", query)
        conn.execute(query)


def install_algo_extension(conn: kuzu.Connection) -> None:
    """Install and load the ``ALGO`` extension on the given connection.

    Both statements are sent in one ``execute`` call; a confirmation is
    logged once that call has returned.
    """
    statements = "INSTALL ALGO; LOAD ALGO;"
    conn.execute(statements)
    logger.info("Kuzu algorithm extension installed successfully.")


class PageRankParams(BaseModel):
    # Validated settings that compute_page_rank interpolates into its
    # `CALL page_rank(...)` statement. All four fields are required.

    # Must lie in the closed interval [0.0, 1.0].
    damping_factor: float = Field(default=..., le=1.0, ge=0.0)
    # Strictly positive integer.
    max_iterations: PositiveInt
    # Strictly positive float.
    tolerance: PositiveFloat
    # Passed through as the `normalizeInitial` argument of the call.
    normalizeInitial: bool


def __get_node_groups_info__(conn: kuzu.Connection) -> dict[str, str]:
    """Map every node table name to the name of its primary-key column.

    Node tables are listed with ``show_tables()``; each one is then inspected
    with ``TABLE_INFO`` and the first column flagged as primary key is kept.
    """

    def _primary_key_of(table_name: str) -> str:
        # One TABLE_INFO call per table; keep the first row marked as primary key.
        columns_df = conn.execute(f'CALL TABLE_INFO("{table_name}") RETURN *;').get_as_df()  # type: ignore
        is_primary_key = columns_df["primary key"] == True  # noqa: E712
        return columns_df[is_primary_key].iloc[0]["name"]

    tables_df: pd.DataFrame = conn.execute("CALL show_tables() WHERE type = 'NODE' RETURN name;").get_as_df()  # type: ignore

    return {
        record["name"]: _primary_key_of(record["name"])  # type: ignore
        for record in tables_df.to_dict("records")
    }


def compute_page_rank(
    conn: kuzu.Connection, projected_graph: ProjectedGraph, params: PageRankParams, batch_size: int = 10000
) -> Generator[pd.DataFrame, Any, None]:
    """
    Run PageRank on a projected graph and stream the scores in batches.

    The algorithm extension is installed and loaded, the projected graph is created, and the
    ``page_rank`` output is then paged through with ``SKIP``/``LIMIT``, one query per batch.

    Args:
        conn: connection used for every statement issued by this function.
        projected_graph: projection to create and to run ``page_rank`` on.
        params: algorithm parameters interpolated into the ``page_rank`` call.
        batch_size: maximum number of rows per yielded DataFrame.

    Returns:
        a generator that yields DataFrames containing the columns: "rank", "node_group", and "node_id".

    Raises:
        RuntimeError: if ``conn.execute`` does not return a single ``kuzu.QueryResult``.
    """
    install_algo_extension(conn)
    projected_graph.create(conn)

    node_groups_info = __get_node_groups_info__(conn)

    def _group_of(node: Any) -> Any:
        # Nodes arrive as dicts carrying their table name under "_label".
        return node.get("_label") if isinstance(node, dict) else pd.NA

    def _id_of(node: Any) -> Any:
        # Look up the value of the primary-key property registered for the node's table.
        if not isinstance(node, dict):
            return pd.NA
        return node.get(node_groups_info.get(node.get("_label")))  # type: ignore

    # NOTE(review): every batch re-issues the CALL, so page_rank is presumably recomputed
    # for each page — consider streaming a single result if this becomes a bottleneck.
    offset = 0
    while True:
        # Kuzu internal IDs are (table, offset) pairs, so the offset alone is only unique
        # within one node table. The label is added as a tiebreaker so that the ordering
        # is total and SKIP/LIMIT pages neither overlap nor miss rows when the projection
        # spans several node groups.
        query = f"""
        CALL page_rank(
            '{projected_graph.name}',
            dampingFactor := {params.damping_factor},
            maxIterations := {params.max_iterations},
            tolerance := {params.tolerance},
            normalizeInitial := {params.normalizeInitial}
        )
        RETURN node, rank, OFFSET(ID(node)) AS internal_offset, LABEL(node) AS internal_label
        ORDER BY internal_offset DESC, internal_label
        SKIP {offset}
        LIMIT {batch_size};
        """
        logger.debug("Executing query %s.", query)
        result = conn.execute(query)
        try:
            if not isinstance(result, kuzu.QueryResult):
                raise RuntimeError("Unexpected kuzu result type. This is likely a bug.")

            df = result.get_as_df()
            if df.empty:
                logger.info("Batch is empty, not more results to iterate over.")
                return

            logger.debug("Processing results of length %d.", len(df))

            df["node_group"] = df["node"].apply(_group_of)
            df["node_id"] = df["node"].apply(_id_of)

            # Only "rank", "node_group" and "node_id" are part of the public result.
            df.drop(columns=["node", "internal_offset", "internal_label"], inplace=True)

            yield df

            if len(df) < batch_size:
                logger.info("Batch is below max batch size, no more results to iterate over.")
                return
            offset += batch_size
        finally:
            # Origin still unclear: if the QueryResult is not explicitly closed here, the run
            # fails with an error regarding semaphores. It might be due to the order in which
            # the connection and the QueryResult are closed when the variables holding them
            # go out of scope.
            _close_query_results(result)


def _close_query_results(result: Any) -> None:
    """Close a single query result, or each one of a list, logging failures instead of raising."""
    if not result:
        return
    pending = [result] if isinstance(result, kuzu.QueryResult) else result
    for query_result in pending:
        try:
            query_result.close()
        except Exception:
            # Best effort: a failed close must not mask the outcome of the batch itself.
            logger.exception("Failed to close the query result.")
