from __future__ import annotations

import logging
import os
import tempfile
from typing import Dict, List

import dataiku
from dataiku.customwebapp import get_webapp_config
from dataikuapi.dss.llm import DSSLLM

# Module-level logger, named after this module via __name__.
logger: logging.Logger = logging.getLogger(name=__name__)


class WebAppConfig:
    """Configuration of the graph webapp, populated once by :meth:`setup`.

    Until ``setup()`` has run, every property returns the default assigned in ``__init__``.
    """

    def __init__(self) -> None:
        self.__is_setup = False
        self.__webapp_config: Dict = {}
        self.__default_project_key: str = ""
        self.__db_local_folder_path: str = ""
        self.__db_query_timeout_seconds: int = 60
        self.__nodes_datasets: List[str] = []
        self.__edges_datasets: List[str] = []
        self.__metadata_ds: str = ""
        self.__snapshots_ds: str | None = None
        self.__build_graph_recipe_output_connection: str | None = None
        self.__llm_id: str | None = None
        self.__logging_ds: str | None = None

    def setup(self, webapp_config: Dict | None = None, default_project_key: str | None = None) -> None:
        """Populate the configuration; only the first call has an effect.

        Args:
            webapp_config: Explicit webapp configuration. When falsy, it is read with
                ``get_webapp_config()``.
            default_project_key: Project key used for dataset and LLM lookups. When falsy,
                it is read from the ``projectKey`` custom variable.

        Raises:
            AssertionError: If the resolved webapp configuration or project key is empty.
        """
        if self.__is_setup:
            return

        self.__webapp_config = webapp_config if webapp_config else get_webapp_config()
        assert self.__webapp_config

        self.__default_project_key = (
            default_project_key if default_project_key else dataiku.get_custom_variables()["projectKey"]
        )
        assert self.__default_project_key

        logger.info("Webapp config is %s.", self.__webapp_config)

        # When debugging locally, use a path passed as an environment variable to store files instead of the default /tmp.
        debug_run_folder_path = os.getenv("DEBUG_RUN_FOLDER")
        # mkdtemp() returns a directory that stays on disk (it is never removed automatically).
        # `tempfile.TemporaryDirectory().name` must not be used here: the unreferenced
        # TemporaryDirectory object is finalized right away, deleting the directory it created.
        parent_path = debug_run_folder_path if debug_run_folder_path else tempfile.mkdtemp()
        self.__db_local_folder_path = self.__create_db_folder__(parent_path)

        logger.info("Kuzu files are persisted locally at '%s'.", self.__db_local_folder_path)

        config = self.__webapp_config

        # Each setting below only overrides its __init__ default when the configured value is truthy.
        db_query_timeout_seconds = config.get("db_query_timeout_seconds")
        if db_query_timeout_seconds:
            self.__db_query_timeout_seconds = int(db_query_timeout_seconds)

        nodes_datasets = config.get("nodes_datasets")
        if nodes_datasets:
            self.__nodes_datasets = nodes_datasets

        edges_datasets = config.get("edges_datasets")
        if edges_datasets:
            self.__edges_datasets = edges_datasets

        metadata_ds = config.get("metadata_ds")
        if metadata_ds:
            self.__metadata_ds = metadata_ds
            # NOTE(review): this assertion sits inside the truthiness guard, so it can never fail and a
            # missing metadata dataset goes undetected here. If the dataset really is mandatory, move the
            # check out of this `if` (left in place to avoid changing when setup() raises).
            assert self.__metadata_ds, "It is mandatory to configure the metadata dataset."

        snapshots_ds = config.get("snapshots_ds")
        if snapshots_ds:
            self.__snapshots_ds = snapshots_ds

        build_graph_recipe_output_connection = config.get("build_graph_recipe_output_connection")
        if build_graph_recipe_output_connection:
            self.__build_graph_recipe_output_connection = build_graph_recipe_output_connection

        llm_id = config.get("llm_id")
        if llm_id:
            self.__llm_id = llm_id

        logging_ds = config.get("logging_ds")
        if logging_ds:
            self.__logging_ds = logging_ds

        self.__is_setup = True

    @property
    def default_project_key(self) -> str:
        """Project key used for dataset and LLM lookups."""
        return self.__default_project_key

    @property
    def db_folder_path(self) -> str:
        """Local folder holding the Kuzu database files ("" before setup)."""
        return self.__db_local_folder_path

    @property
    def db_query_timeout_seconds(self) -> int:
        """Configured query timeout in seconds (60 unless overridden)."""
        return self.__db_query_timeout_seconds

    @property
    def nodes_datasets(self) -> List[str]:
        """Configured node dataset names."""
        return self.__nodes_datasets

    @property
    def edges_datasets(self) -> List[str]:
        """Configured edge dataset names."""
        return self.__edges_datasets

    @property
    def metadata_ds(self) -> str:
        """Configured metadata dataset name ("" when not configured)."""
        return self.__metadata_ds

    @property
    def snapshots_ds(self) -> str | None:
        """Configured snapshots dataset name, if any."""
        return self.__snapshots_ds

    @property
    def build_graph_recipe_output_connection(self) -> str | None:
        """Configured output connection for the build-graph recipe, if any."""
        return self.__build_graph_recipe_output_connection

    @property
    def llm_id(self) -> str | None:
        """Configured LLM identifier, if any."""
        return self.__llm_id

    @property
    def logging_ds(self) -> str | None:
        """Configured logging dataset name, if any."""
        return self.__logging_ds

    def get_columns(self, dataset_name: str):
        """Return the ``columns`` entry of *dataset_name*'s schema in the default project.

        Raises:
            AttributeError: If the dataset config has no ``schema`` entry.
        """
        dataset = dataiku.Dataset(name=dataset_name, project_key=self.default_project_key)
        return dataset.get_config().get("schema").get("columns")

    def __create_db_folder__(self, parent_path: str) -> str:
        """Create (if needed) and return the ``graph-instances`` folder under *parent_path*.

        Missing intermediate directories are created too.

        Args:
            parent_path: Directory that should contain the database folder.

        Returns:
            ``os.path.join(parent_path, "graph-instances")``, guaranteed to exist as a directory.

        Raises:
            FileExistsError: If that path already exists but is not a directory.
        """
        db_folder_path = os.path.join(parent_path, "graph-instances")
        # Create the folder at the joined path (not relative to the CWD); exist_ok keeps repeated calls harmless.
        os.makedirs(db_folder_path, exist_ok=True)
        return db_folder_path

    def get_llm(self) -> DSSLLM | None:
        """Return the configured LLM from the default project, or None when no ``llm_id`` is set."""
        if not self.llm_id:
            return None

        client = dataiku.api_client()
        project = client.get_project(self.default_project_key)
        return project.get_llm(self.llm_id)


# Shared configuration instance; its values stay at their defaults until `setup()` is called on it.
webapp_config: WebAppConfig = WebAppConfig()
