from __future__ import annotations

import logging

from flask import Blueprint, abort, request

from explorer.backend.utils.webapp_config import webapp_config
from solutions.backend.utils import return_ok
from solutions.graph.explorer_builder import ExplorerMetadataManager
from solutions.graph.queries.graph_data import compute_adjacent_node_groups_info, compute_total_counts
from solutions.graph.queries.params import ComputeAdjacentNodeGroupInfoParams

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Blueprint grouping the graph query endpoints; each route declared on it is
# mounted under the "/queries" URL prefix.
graph_queries = Blueprint("graph_queries", __name__, url_prefix="/queries")


@graph_queries.route("/<snapshot_id>/computeAdjacentNodeGroupsInfo", methods=["POST"])
def get_adjacent_node_groups(snapshot_id: str):
    """Compute adjacent node group information for a snapshot's graph database.

    The JSON request body is unpacked as keyword arguments into
    ``ComputeAdjacentNodeGroupInfoParams`` and passed, together with the
    database instance opened for ``snapshot_id``, to
    ``compute_adjacent_node_groups_info``.

    Args:
        snapshot_id: Identifier of the snapshot whose database is queried
            (taken from the URL path).

    Returns:
        The ``return_ok`` response wrapping the computed result.

    Raises:
        werkzeug.exceptions.BadRequest: (via ``abort(400)``) if the request
            body is not a JSON object or cannot be turned into
            ``ComputeAdjacentNodeGroupInfoParams``.
    """
    payload = request.get_json()
    # `**payload` requires a mapping: a JSON null, list or scalar body would
    # otherwise escape as an unhandled TypeError (HTTP 500) instead of a
    # client error.
    if not isinstance(payload, dict):
        abort(400, description="Request body must be a JSON object.")

    try:
        params = ComputeAdjacentNodeGroupInfoParams(**payload)
    except (TypeError, ValueError) as err:
        # Unexpected/missing keyword arguments raise TypeError; validation
        # failures typically raise ValueError (pydantic's ValidationError
        # subclasses it) — NOTE(review): confirm against the params class.
        logger.warning("Invalid computeAdjacentNodeGroupsInfo parameters: %s", err)
        abort(400, description="Invalid query parameters.")

    graph_manager = ExplorerMetadataManager(webapp_config.db_explorer_folders)

    with graph_manager.db_instance_factory.get_db_instance_from_snapshot_id(snapshot_id) as db_instance:
        return return_ok(compute_adjacent_node_groups_info(db_instance, params))

@graph_queries.route("/<snapshot_id>/getTotalCounts", methods=["GET"])
def get_total_counts(snapshot_id: str):
    """Return the total counts computed over a snapshot's graph database.

    Args:
        snapshot_id: Identifier of the snapshot whose database is queried
            (taken from the URL path).

    Returns:
        The ``return_ok`` response wrapping ``compute_total_counts``'s result.
    """
    instance_factory = ExplorerMetadataManager(webapp_config.db_explorer_folders).db_instance_factory
    with instance_factory.get_db_instance_from_snapshot_id(snapshot_id) as db_instance:
        # Build the response while the database instance is still open.
        totals = compute_total_counts(db_instance)
        return return_ok(totals)
