import asyncio
import curses
import json
from datetime import datetime
from typing import Any, Dict, List, Optional, Set

import pandas as pd
from dataikuapi.dss.project import DSSAgentTool, DSSAgentToolListItem
from eval_factory import EvalFactory
from resources.mini_dev_postgresql import QUERIES
from ui.selectors import display_query_group_selector, display_tool_selector
from utils.messages import RESULT_MESSAGE, TOOL_DECLINED
from utils.queries import get_golden_query, get_queries_by_db_id, group_queries_by_db_id


async def main() -> None:
    """Interactively configure and run the SQL tool evaluation, then save results.

    Steps:
      1. Load the evaluation config and ask the user to confirm the test run ID.
      2. Let the user pick SQL question-answering agent tools (curses selector).
      3. Let the user pick query groups built from ``QUERIES`` (curses selector).
      4. Run every selected query against every selected tool, with at most
         10 tool calls in flight at once.
      5. Print a status summary and write all rows (one per tool x query) to
         ``query_results_<test_run_id>.json``.

    Returns early, after printing a message, if the user declines the run ID or
    selects no tools / query groups.
    """
    eval_facto = EvalFactory()
    eval_facto.get_config()

    print(f"\n🔍 Test Run ID: {eval_facto.test_run_id}")
    print("\nPlease verify and validate this test run ID before proceeding.")

    confirmation = (
        input("Do you want to proceed with this test run ID? (y/N): ").strip().lower()
    )

    if confirmation not in ["y", "yes"]:
        print("Test run cancelled by user.")
        return
    print(f"✅ Proceeding with test run ID: {eval_facto.test_run_id}\n")

    tools: List[DSSAgentToolListItem] = eval_facto.project.list_agent_tools()
    # Only the custom SQL question-answering tool type is evaluated here.
    sql_tools: List[DSSAgentToolListItem] = [
        tool
        for tool in tools
        if tool.type
        == "Custom_agent_tool_sql-question-answering-tool_sql-question-answering"
    ]

    selected_tools = curses.wrapper(display_tool_selector, sql_tools)

    if not selected_tools:
        print("\nNo tools selected. Exiting...")
        return

    db_ids: Set[str] = {query["db_id"] for query in QUERIES}

    available_queries: List[Dict[str, Any]] = []
    for db_id in db_ids:
        available_queries.extend(get_queries_by_db_id(QUERIES, db_id))

    if not available_queries:
        print("\nNo queries available for the selected tools. Exiting...")
        return

    query_groups: List[Dict[str, Any]] = group_queries_by_db_id(available_queries)
    selected_groups: List[Dict[str, Any]] = curses.wrapper(
        display_query_group_selector, query_groups
    )

    if not selected_groups:
        print("\nNo query groups selected. Exiting...")
        return

    selected_queries: List[Dict[str, Any]] = []
    for group in selected_groups:
        selected_queries.extend(group["queries"])

    print("⚙️ Test will be run with the following configuration:")

    for tool in selected_tools:
        print(f"- {tool['name']} (ID: {tool['id']})")

    for group in selected_groups:
        print(f"- {group['db_id']} ({group['count']} queries)")

    # Shared across all tools: at most 10 tool calls run concurrently.
    semaphore = asyncio.Semaphore(10)
    results: List[Dict[str, Any]] = []

    for tool in selected_tools:
        print(f"\n🚨 Running test for tool: {tool['name']}")

        curr_tool: DSSAgentTool = DSSAgentTool(
            eval_facto.client, eval_facto.target_project_key, tool["id"]
        )

        tasks = [
            semaphore_run(semaphore, eval_facto, curr_tool, query)
            for query in selected_queries
        ]
        tool_results = await asyncio.gather(*tasks)

        # Accumulate across tools (not reassign) so every tool's rows are kept,
        # and tag each row with its tool so rows from different tools remain
        # distinguishable in the saved file.
        results.extend(
            {**row, "tool_id": tool["id"], "tool_name": tool["name"]}
            for row in tool_results
            if row is not None
        )

    if not results:
        print("\n❌ No queries to create DataFrame")
        return

    df = pd.DataFrame(results)
    status_counts = df["status"].value_counts()
    print(f"\n📊 Results DataFrame created with {len(df)} queries")
    print(f"✅ Successful queries: {status_counts.get('success', 0)}")
    print(f"🚩 Declined queries: {status_counts.get('declined', 0)}")
    print(f"❌ Failed queries: {status_counts.get('error', 0)}")
    print(df.head())

    filename = f"query_results_{eval_facto.test_run_id}.json"
    df.to_json(filename, orient="records")
    print(f"💾 Results saved to: {filename}")


async def semaphore_run(semaphore: asyncio.Semaphore, eval_facto, tool, query):
    """Run ``run_query_test`` in a worker thread while holding *semaphore*.

    The blocking tool call is offloaded with ``asyncio.to_thread`` so the event
    loop stays free; the semaphore caps how many such calls overlap.

    Returns:
        Whatever ``run_query_test`` returns for this query.
    """
    await semaphore.acquire()
    try:
        outcome = await asyncio.to_thread(run_query_test, eval_facto, tool, query)
    finally:
        # Always give the slot back, even if the test raised or was cancelled.
        semaphore.release()
    return outcome


def run_query_test(
    eval_facto: EvalFactory,
    currTool: DSSAgentTool,
    query: Dict[str, Any],
) -> Dict[str, Any]:
    """Run one evaluation query against an agent tool and build its result row.

    Args:
        eval_facto: Supplies ``test_run_id`` and ``golden_queries``.
        currTool: The agent tool invoked via ``run``.
        query: A query record; ``question``, ``question_id`` and ``db_id`` are read.

    Returns:
        A flat dict with the tool's answer (info snippet, generated SQL,
        records, raw output, JSON-encoded trace), the golden SQL/records for
        comparison, timestamps, execution time, and a ``status`` of
        ``"success"``, ``"declined"`` or ``"error"``. An exception raised by the
        tool call is recorded as ``"error"`` instead of propagating, so one
        failing query does not abort the whole ``asyncio.gather`` batch.
    """
    user_query: str = query.get("question", "")
    query_id: Any = query.get("question_id", 0)
    start_time = datetime.now()

    result_dict: Dict[str, Any] = {
        "query_id": query_id,
        "test_run_id": eval_facto.test_run_id,
        "db_id": query.get("db_id", ""),
        "query": user_query,
        "start_time": start_time.strftime("%Y-%m-%d %H:%M:%S"),
        "end_time": "",
        "info": "",
        "trace": "",
        "generated_sql": "",
        "generated_records": "",
        "tool_output": "",
        "execution_time": "",
        "status": "success",
        "error_message": "",
        # Pre-populated so every row has the same columns even if the golden
        # lookup finds nothing.
        "golden_sql": "",
        "golden_records": "",
    }

    print(f"[🧠 START] {str(query_id)} - {user_query}")

    res: Dict[str, Any] = {}
    try:
        res = currTool.run({"question": user_query}, {"context": {}}) or {}
    except Exception as exc:
        # Per-query boundary: record the failure in the row rather than
        # letting it cancel every other query in the batch.
        result_dict["status"] = "error"
        result_dict["error_message"] = f"{type(exc).__name__}: {exc}"
    else:
        _apply_tool_response(result_dict, res)

    golden_query = get_golden_query(eval_facto.golden_queries, query_id)
    if golden_query is not None:
        result_dict["golden_sql"] = golden_query["golden_query"]
        result_dict["golden_records"] = golden_query["golden_results"]

    end_time = datetime.now()
    result_dict["end_time"] = end_time.strftime("%Y-%m-%d %H:%M:%S")
    result_dict["execution_time"] = str(end_time - start_time)
    result_dict["tool_output"] = res.get("output", "")
    if trace := res.get("trace", None):
        result_dict["trace"] = json.dumps(trace)

    _print_result(result_dict, query_id)
    return result_dict


def _apply_tool_response(result_dict: Dict[str, Any], res: Dict[str, Any]) -> None:
    """Set the status and parsed answer fields of *result_dict* from a tool response."""
    output = res.get("output")
    if "error" in res:
        result_dict["status"] = "error"
        result_dict["error_message"] = res["error"]
    elif isinstance(output, str) and output.startswith(TOOL_DECLINED):
        result_dict["status"] = "declined"
    else:
        # Only the first source is inspected. A missing or empty "sources"
        # leaves the answer fields at their empty defaults instead of raising.
        sources = res.get("sources") or []
        items = (sources[0].get("items") or []) if sources else []
        for item in items:
            item_type = item.get("type")
            if item_type == "INFO":
                result_dict["info"] = item.get("textSnippet", "")
            elif item_type == "GENERATED_SQL_QUERY":
                result_dict["generated_sql"] = item.get("performedQuery", "")
            elif item_type == "RECORDS":
                result_dict["generated_records"] = item.get("records", "")


def _print_result(result_dict: Dict[str, Any], query_id: Any) -> None:
    """Print the one-line outcome of a query run using ``RESULT_MESSAGE``."""
    status = result_dict["status"]
    if status == "error":
        label, message = "🚨 ERROR", result_dict["error_message"]
    elif status == "declined":
        label, message = "🚩 DECLINED", result_dict["tool_output"]
    else:
        label, message = "✅ RESULT", result_dict["tool_output"]
    print(
        RESULT_MESSAGE.format(
            status=label,
            exec_time=result_dict["execution_time"],
            query_id=str(query_id),
            message=message,
        )
    )


if __name__ == "__main__":
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        # Ctrl-C at the interactive prompt or curses selectors: exit cleanly
        # without a traceback, keeping the conventional 128 + SIGINT status.
        print("\nTest run interrupted by user.")
        raise SystemExit(130)
