from typing import Any, Dict, List, Tuple

from answers.backend.utils.db.sql_query_utils import replace_tables_in_ast
from answers.llm_assist.llm_tooling.tool_utils import get_dataset_descriptions_from_table_names, to_select_query
from common.backend.utils.dataiku_api import dataiku_api
from common.llm_assist.logging import logger
from dataiku import SQLExecutor2
from dataiku.core.dataset import Dataset
from dataiku.sql import toSQL
from langchain.tools import BaseTool


class SqlRetrieverTool(BaseTool):
    """LangChain tool that runs a structured query against the configured SQL connection.

    The connection name, the table list and the row limit are read from the
    Dataiku webapp configuration (``sql_retrieval_connection``,
    ``sql_retrieval_table_list`` and ``hard_sql_limit``). Only synchronous
    execution is supported.
    """

    def __init__(self):
        name = "SQL retriever"
        description = "The SqlRetrieverTool is designed for retrieving data from SQL databases. It leverages filtering criteria to fetch specific records. It utilizes SQL queries to extract the first matching record based on the given filters."
        super().__init__(name=name, description=description)

    def _run(self, db_query: dict) -> Tuple[str, Any, List[Dict[str, Any]], List[str]]:
        """Build, execute and summarise a SELECT query described by ``db_query``.

        Args:
            db_query: Structured query description. It is converted into a
                SELECT statement by ``to_select_query``; its exact schema is
                defined there.

        Returns:
            A 4-tuple ``(context, sql_query, records, tables_used)``:
            ``context`` is a text block for the LLM containing the executed SQL,
            the descriptions of the datasets used and all fetched rows;
            ``sql_query`` is the SQL that was executed; ``records`` holds at
            most the first 10 rows as dicts, with missing values replaced by
            empty strings; ``tables_used`` lists the tables the query touches.

        Raises:
            Exception: if no SQL connection or no SQL table is configured.
        """
        config = dataiku_api.webapp_config

        connection_name = config.get("sql_retrieval_connection")
        # An empty value (e.g. a cleared selector) is as unusable as a missing one.
        if not connection_name:
            raise Exception("A SQL connection selection is required to run a query.")

        sql_retrieval_table_list = config.get("sql_retrieval_table_list", [])
        if not sql_retrieval_table_list:
            raise Exception("At least one SQL table is required to run a query.")

        # The first dataset is used by toSQL to get the dialect information for the connection
        first_dataset = Dataset(sql_retrieval_table_list[0])

        # Use the default limit when the setting is absent *or* blank:
        # int(None) / int("") would otherwise raise.
        raw_limit = config.get("hard_sql_limit")
        hard_sql_limit = 200 if raw_limit is None or raw_limit == "" else int(raw_limit)
        logger.debug(f"hard_sql_limit is set to {hard_sql_limit}")

        executor = SQLExecutor2(connection=connection_name)
        select_query = to_select_query(db_query, hard_sql_limit)

        # NOTE(review): replace_tables_in_ast appears to rewrite the table
        # references of select_query in place (the AST is logged afterwards) and
        # to return the names of the tables involved — confirm in sql_query_utils.
        tables_used: List[str] = replace_tables_in_ast(select_query)
        logger.debug(f"Replaced ast: {select_query}")

        sql_query = toSQL(select_query, dataset=first_dataset)
        logger.debug(f"Running SQL Query: {sql_query}")
        df = executor.query_to_df(query=sql_query)
        # Replace missing values so every cell renders as a plain value in the prompt.
        df = df.fillna("")
        records = df.to_dict("records")

        used_dataset_descriptions = get_dataset_descriptions_from_table_names(tables_used)
        context = f"""
        SQL query executed
        {sql_query}
        Dataset Descriptions
        {used_dataset_descriptions}
        
        >response
        {records}
        """
        # The LLM context carries every fetched row; callers only get a preview.
        return context, sql_query, records[:10], tables_used

    def _arun(self, *args: Any, **kwargs: Any) -> dict:  # type: ignore
        """Reject asynchronous execution, which this tool does not support.

        Any arguments are accepted so that the intended NotImplementedError is
        raised however the tool input is passed. LangChain may pass it as
        keyword arguments named after ``_run``'s parameters (e.g.
        ``db_query=...``); a fixed signature would turn that into a TypeError.
        """
        raise NotImplementedError("This tool does not support async")
