from __future__ import annotations

import logging
import time
from typing import Any, Dict, List, Optional

from solutions.graph.graph_db_instance_manager import AbstractDbInstance, EditorWebAppDbInstance
from solutions.graph.kuzu.queries.node_adjacency import AdjacentNodeGroupInfo, get_adjacent_node_groups_info
from solutions.graph.kuzu.queries.total_counts import (
    GraphTotalCounts,
    get_edge_counts_by_group,
    get_node_counts_by_group,
)
from solutions.graph.kuzu.schema.schema_extractor import GraphSchema, extract_schema
from solutions.graph.models import (
    GraphDBDoesNotExistError,
    GraphDBWriteInProgressError,
    to_edge_group_definitions,
    to_node_group_definitions,
)
from solutions.graph.queries.params import ComputeAdjacentNodeGroupInfoParams, RunCypherParams, RunLlmCypherParams
from solutions.graph.queries.response import RunCypherFailure, RunCypherSuccess

from ..store.graph_metadata_store import GraphMetadataStore
from ..views.cypher_graph_view import CypherGraphView, CypherQueryExecutionException

logger = logging.getLogger(__name__)


def run_cypher(
    store: GraphMetadataStore, params: RunCypherParams | RunLlmCypherParams, timeout_seconds: int | None = None
) -> RunCypherSuccess | RunCypherFailure:
    """
    Run a Cypher query against a graph's database (opened read-only) and shape the result.

    The graph's metadata is loaded from ``store``. It is converted into node/edge group
    definitions, which configure a ``CypherGraphView`` over the database. The query is executed
    through that view. On success, the matched nodes and edges (per group) are returned along
    with a tabular view of the result, with missing cells filled with ``""``.

    Args:
        store: Metadata store used to look up the graph's node/edge group metadata.
        params: Query description. Only ``RunCypherParams`` carries bound query parameters
            (``params.params``). A ``RunLlmCypherParams`` query runs without any.
        timeout_seconds: Optional timeout forwarded to ``CypherGraphView.execute``.

    Returns:
        On success, ``{"success": True, "nodes": [...], "edges": [...], "table": {...}}``.
        When query execution fails, ``{"success": False, "errorCode": ..., "error": ...}``.
        ``errorCode`` is one of ``INVALID_CYPHER_ERROR``, ``BROKEN_DB_STATE_ERROR`` or
        ``UNEXPECTED_ERROR``.

    Raises:
        GraphDoesNotExistError
        GraphMetadataStoreError
        GraphDBWriteInProgressError
        GraphDBDoesNotExistError
    """
    perf_t0 = time.perf_counter()
    query_params: Optional[Dict[str, Any]] = None
    if isinstance(params, RunCypherParams):
        query_params = params.params
    graph_id = params.graph_id
    graph_metadata = store.get(graph_id)

    node_group_definitions = [
        definition
        for node_group_metadata in graph_metadata["nodes"].values()
        for definition in to_node_group_definitions(node_group_metadata)
    ]
    edge_group_definitions = [
        definition
        for edge_group_metadata in graph_metadata["edges"].values()
        for definition in to_edge_group_definitions(edge_group_metadata, graph_metadata)
    ]

    with EditorWebAppDbInstance(graph_id, readonly=True) as db_instance:
        graph_view = CypherGraphView(db_instance, node_group_definitions, edge_group_definitions)

        # Only query execution is guarded. Errors raised during post-processing below
        # (get_nodes / get_edges / get_as_df) propagate to the caller unchanged.
        try:
            perf_t1 = time.perf_counter()
            graph_view.execute(params.query, query_params, timeout_seconds)
            perf_t2 = time.perf_counter()
        except CypherQueryExecutionException as ex:
            # An invalid user query is an expected outcome, so log it at info level.
            logger.info("User query raised an error on graph %s, '%s'.", graph_id, params.query, exc_info=True)
            return {"success": False, "errorCode": "INVALID_CYPHER_ERROR", "error": str(ex)}
        except (GraphDBWriteInProgressError, GraphDBDoesNotExistError):
            # These must reach the caller as-is. Without this clause, the generic
            # RuntimeError / Exception handlers below would turn them into failure payloads.
            raise
        except RuntimeError:
            logger.exception("Runtime error on graph %s, '%s'.", graph_id, params.query)
            return {
                "success": False,
                "errorCode": "BROKEN_DB_STATE_ERROR",
                "error": "The underlying graph database seems to be in a broken state. You may consider rebuilding it.",
            }
        except Exception:
            logger.exception("An unexpected error occurred on graph %s, '%s'.", graph_id, params.query)
            return {"success": False, "errorCode": "UNEXPECTED_ERROR", "error": "An unexpected error occured."}

        nodes = []
        for node_id in graph_metadata["nodes"]:
            nodes.extend(graph_view.get_nodes(node_id))

        edges = []
        for edge_id in graph_metadata["edges"]:
            edges.extend(graph_view.get_edges(edge_id))

        df = graph_view.get_as_df().fillna("")
        perf_t3 = time.perf_counter()

        # Lazy %-style arguments: the messages are only formatted when DEBUG is enabled.
        logger.debug("Cypher query timing: %.2fs total", perf_t3 - perf_t0)
        logger.debug(".........query prep: %.2fs", perf_t1 - perf_t0)
        logger.debug("....query execution: %.2fs", perf_t2 - perf_t1)
        logger.debug("(configured timeout: %ss)", timeout_seconds)
        logger.debug("....post-processing: %.2fs", perf_t3 - perf_t2)

        return {
            "success": True,
            "nodes": nodes,
            "edges": edges,
            "table": {"columns": [{"name": col} for col in df.columns], "rows": df.to_dict("records")},
        }


def compute_adjacent_node_groups_info(
    db_instance: AbstractDbInstance, params: ComputeAdjacentNodeGroupInfoParams
) -> List[AdjacentNodeGroupInfo]:
    """
    Return information about the node groups adjacent to a given node.

    The database schema is extracted from ``db_instance`` and the schema of
    ``params.node_group`` is looked up in it. The actual computation is delegated to
    ``get_adjacent_node_groups_info``, using a fresh connection and ``params.node_id``.

    Args:
        db_instance: Graph database instance to query.
        params: Identifies the node group (``params.node_group``) and the node
            (``params.node_id``) whose neighbourhood is inspected.

    Returns:
        The list produced by ``get_adjacent_node_groups_info``.

    Raises:
        ValueError: If ``params.node_group`` has no (or an empty) schema entry in this graph
            database. ``ValueError`` subclasses ``Exception``, so existing
            ``except Exception`` handlers keep working.
    """
    schema: GraphSchema = extract_schema(db_instance)
    node_group_schema = schema["node_groups"].get(params.node_group)
    if not node_group_schema:
        raise ValueError(f'Node group "{params.node_group}" does not exist in this graph database.')

    with db_instance.get_new_conn() as context_manager:
        return get_adjacent_node_groups_info(context_manager.connection, node_group_schema, params.node_id)


def compute_total_counts(db_instance: AbstractDbInstance) -> GraphTotalCounts:
    """
    Return per-group node and edge counts for a graph database.

    If the database does not exist, this logs at info level and returns empty count
    mappings instead of raising.

    Returns:
        ``{"nodes": <node counts by group>, "edges": <edge counts by group>}``.
    """
    if not db_instance.exists():
        logger.info("Graph database doesn't exist, ignore node count request")
        return {"nodes": {}, "edges": {}}

    # Both counts are computed on the same fresh connection.
    with db_instance.get_new_conn() as conn_ctx:
        connection = conn_ctx.connection
        per_node_group = get_node_counts_by_group(connection)
        per_edge_group = get_edge_counts_by_group(connection)

    return {"nodes": per_node_group, "edges": per_edge_group}
