from typing import Any, Dict, List, Optional, TypedDict

from common.backend.constants import ALL_LONGTEXT_COLUMN_NAMES
from common.backend.utils.dataiku_api import dataiku_api
from common.llm_assist.logging import logger
from dataiku import Dataset, SQLExecutor2
from dataikuapi.dss.dataset import DSSDataset, SQLDatasetSettings

# Default maximum number of characters placed in a single TO_CLOB('...') literal
# when a long text value is split into concatenated chunks.
# NOTE(review): 4000 presumably mirrors Oracle's string-literal size limit — confirm.
CLOB_CHUNK_SIZE = 4000


class OracleParams(TypedDict, total=False):
    """Inputs consumed by ``OracleSQLManager``.

    ``total=False`` makes every key optional for type checkers, but
    ``OracleSQLManager.__init__`` reads ``dataset_name``, ``executor``,
    ``dataset`` and ``columns`` with ``[]`` access, so those four keys must
    be present at runtime.
    """

    # Name used to look the dataset up in the default project when refreshing its schema.
    dataset_name: str
    # Executor on which the CREATE TABLE / SELECT statements are run via ``query_to_df``.
    executor: SQLExecutor2
    # Dataset whose ``get_location_info()`` provides the schema and table names.
    dataset: Dataset
    # Column names used to build the CREATE TABLE column definitions.
    columns: List[str]
    # Not read by ``OracleSQLManager`` in the code visible here; the table name is
    # derived from ``dataset`` instead.
    table_name: Optional[str]


class OracleSQLManager:
    """Create the Oracle table behind a Dataiku SQL dataset and build SQL fragments for it.

    The instance methods create/verify the table and refresh the dataset schema;
    the static methods render the value part of INSERT statements and the SET
    part of UPDATE statements, wrapping long-text columns in ``TO_CLOB`` chunks.
    """

    def __init__(self, oracle_params: OracleParams) -> None:
        """Store the parameters and resolve the quoted table name.

        Args:
            oracle_params: Must contain ``dataset_name``, ``executor``,
                ``dataset`` and ``columns``.

        Raises:
            KeyError: If one of the required keys is missing.
            ValueError: If the dataset is not an SQL dataset or exposes no table name.
        """
        self.params = oracle_params
        self.dataset_name = self.params["dataset_name"]
        self.executor = self.params["executor"]
        self.dataset = self.params["dataset"]
        self.columns = self.params["columns"]
        self.table_name = ""
        self._get_params()

    def _get_params(self) -> None:
        """Resolve ``self.table_name`` from the dataset location info.

        Raises:
            Exception: If no table name could be resolved.
        """
        table_name = self.get_oracle_table_from_dataset()
        if not table_name:
            raise Exception("Table name not found")
        self.table_name = table_name

    @staticmethod
    def quote_identifier(string: str) -> str:
        """Return ``string`` wrapped in double quotes, doubling any embedded double quote."""
        QUOTE = '"'
        return QUOTE + string.replace(QUOTE, QUOTE + QUOTE) + QUOTE

    def build_oracle_table_name(self, schema: Optional[str], table: str) -> str:
        """Return the quoted table name, prefixed with the quoted schema when one is given."""
        quoted_table = OracleSQLManager.quote_identifier(table)
        if not schema:
            return quoted_table
        return f"{OracleSQLManager.quote_identifier(schema)}.{quoted_table}"

    def get_oracle_table_from_dataset(self) -> str:
        """Build the quoted ``schema.table`` name from the dataset's location info.

        Raises:
            ValueError: If the dataset is not an SQL dataset, or its location
                info carries no table name.
        """
        loc = self.dataset.get_location_info()
        if loc.get("locationInfoType") != "SQL":
            raise ValueError("Expected SQL dataset")
        # A missing "info" or "table" entry previously surfaced as an accidental
        # AttributeError on None; report it explicitly instead.
        info = loc.get("info") or {}
        table_name = info.get("table")
        if not table_name:
            raise ValueError("Table name not found")
        return self.build_oracle_table_name(schema=info.get("schema"), table=table_name)

    def get_col_definitions(self) -> str:
        """Return the comma-separated column definitions for CREATE TABLE.

        Columns listed in ``ALL_LONGTEXT_COLUMN_NAMES`` are declared ``NCLOB``;
        every other column is declared ``nvarchar2(2000)``.
        """
        col_definitions = [
            "{} {}".format(
                # Use the shared helper so embedded double quotes are escaped.
                OracleSQLManager.quote_identifier(column_name),
                "NCLOB" if column_name in ALL_LONGTEXT_COLUMN_NAMES else "nvarchar2(2000)",
            )
            for column_name in self.params["columns"]
        ]
        return ", ".join(col_definitions)

    def _create_oracle_table(self) -> bool:
        """Try to create the table.

        Returns:
            True when the CREATE TABLE statement succeeded, False when it failed
            for any reason (including the table already existing).
        """
        try:
            col_definitions = self.get_col_definitions()
            q = f"""CREATE TABLE {self.table_name} ({col_definitions})"""
            logger.debug(f"Creating table with query: {q}")
            self.params["executor"].query_to_df(q, post_queries=["COMMIT"])
        except Exception as e:
            # Best effort: failure is expected when the table already exists.
            # Only Exception is caught so KeyboardInterrupt/SystemExit still propagate.
            logger.info("Oracle table already exists or unable to recreate it")
            logger.debug(f"CREATE TABLE failed for {self.table_name}: {e}")
            return False
        logger.info(f"""Oracle table created: {self.table_name}""")
        return True

    def _check_oracle_table_exists(self) -> None:
        """Verify the table can be queried.

        Raises:
            Exception: If selecting from the table fails; the underlying error
                is attached as ``__cause__``.
        """
        try:
            q = f"""SELECT * FROM {self.table_name} FETCH FIRST 1 ROWS ONLY"""
            self.params["executor"].query_to_df(q, post_queries=["COMMIT"])
        except Exception as e:
            raise Exception(f"""Oracle table: {self.table_name} does not exist""") from e
        logger.info(f"""Oracle table exists: {self.table_name}""")

    def _update_oracle_dataset_schema(self, dss_dataset: DSSDataset, ds_settings: SQLDatasetSettings) -> None:
        """Replace the dataset schema columns with the ones returned by ``test_and_detect`` and save."""
        dataset_detected_settings = dss_dataset.test_and_detect()
        dataset_detected_schema = dataset_detected_settings["schemaDetection"]["detectedSchema"]["columns"]
        ds_settings.get_raw()["schema"]["columns"] = dataset_detected_schema
        ds_settings.save()
        logger.debug(f"""Dataset '{self.dataset_name}' schema replaced.""")

    def _refresh_oracle_schema(self) -> None:
        """Re-detect the dataset schema with the dataset temporarily marked unmanaged.

        Raises:
            Exception: If the schema detection or update fails.
        """
        # Note: this is not the same dataset object created by dataiku.Dataset
        dss_dataset: DSSDataset = dataiku_api.default_project.get_dataset(self.dataset_name)
        ds_settings: SQLDatasetSettings = dss_dataset.get_settings()
        ds_settings.settings["managed"] = False
        ds_settings.save()
        try:
            self._update_oracle_dataset_schema(dss_dataset, ds_settings)
        except Exception as e:
            raise Exception(f"Unable to refresh dataset schema: {e}") from e
        finally:
            # Restore the managed flag even when the refresh fails, so an error
            # does not leave the dataset permanently unmanaged.
            ds_settings.settings["managed"] = True
            ds_settings.save()

    def setup_oracle_table(self) -> None:
        """Create the table if possible, verify it exists, and refresh the schema after a creation."""
        needs_refresh = self._create_oracle_table()
        self._check_oracle_table_exists()
        if needs_refresh:
            self._refresh_oracle_schema()

    @staticmethod
    def split_into_chunks(text: str, chunk_size=CLOB_CHUNK_SIZE) -> List[str]:
        """Split ``text`` into consecutive slices of at most ``chunk_size`` characters."""
        return [text[i : i + chunk_size] for i in range(0, len(text), chunk_size)]

    @staticmethod
    def _split_escaped_literal(text: str, chunk_size: int) -> List[str]:
        """Split quote-escaped text into chunks without separating a doubled quote.

        A chunk that would end in an odd number of single quotes is shortened by
        one character so the ``''`` pair stays together in the next chunk.
        """
        if chunk_size < 1:
            raise ValueError("chunk_size must be a positive integer")
        chunks: List[str] = []
        start, total = 0, len(text)
        while start < total:
            end = min(start + chunk_size, total)
            # Only adjust interior boundaries, and never shrink a chunk to nothing.
            if end < total and end - start > 1:
                piece = text[start:end]
                trailing_quotes = len(piece) - len(piece.rstrip("'"))
                if trailing_quotes % 2:
                    end -= 1
            chunks.append(text[start:end])
            start = end
        return chunks

    @staticmethod
    def generate_clob_statement(long_text: str, chunk_size=CLOB_CHUNK_SIZE) -> str:
        """Render ``long_text`` as ``TO_CLOB('...')`` literals joined with ``||``.

        Args:
            long_text: Text whose single quotes are already doubled; it is
                interpolated into the literals as-is.
            chunk_size: Maximum number of characters per literal.

        Returns:
            The SQL expression; ``TO_CLOB('')`` for empty text, so the result
            is always a valid expression.

        Raises:
            ValueError: If ``chunk_size`` is less than 1.
        """
        chunks = OracleSQLManager._split_escaped_literal(long_text, chunk_size)
        if not chunks:
            return "TO_CLOB('')"
        return " || ".join(f"TO_CLOB('{chunk}')" for chunk in chunks)

    @staticmethod
    def escape_quotes(str_value: str) -> str:
        """Double the single quotes of a non-empty string; return any other value unchanged."""
        return str_value.replace("'", "''") if str_value and isinstance(str_value, str) else str_value

    @staticmethod
    def insert_statement(cols: List[str], values: List[List[str]], chunk_size=CLOB_CHUNK_SIZE) -> str:
        """Render the parenthesised value list of an INSERT for one row.

        Falsy values become ``NULL``, long-text columns become ``TO_CLOB``
        chunks, and everything else becomes a quoted literal with single
        quotes escaped.

        Args:
            cols: Column names, paired positionally with the row values.
            values: Rows to insert.
            chunk_size: Maximum number of characters per ``TO_CLOB`` literal.
        """
        set_statements = []
        # NOTE(review): only the first row of `values` is rendered — confirm callers
        # never pass more than one row.
        nested_values = values[0]
        for col, value in zip(cols, nested_values):
            value = OracleSQLManager.escape_quotes(value)
            if not value:
                set_statements.append("NULL")
            elif col in ALL_LONGTEXT_COLUMN_NAMES:
                set_statements.append(f""" {OracleSQLManager.generate_clob_statement(value, chunk_size)}""")
            else:
                set_statements.append(f""" '{value}'""")
        statement = f""" ( {", ".join(set_statements)} ) """
        logger.debug(f"Oracle insert statement: {statement}")
        return statement

    @staticmethod
    def update_statement(sets: Dict[str, Any], chunk_size=CLOB_CHUNK_SIZE) -> str:
        """Render the comma-separated ``"col" = value`` assignments of an UPDATE.

        ``None`` values become ``NULL``, long-text columns become ``TO_CLOB``
        chunks, and everything else becomes a quoted literal.

        Args:
            sets: Mapping of column name to new value.
            chunk_size: Maximum number of characters per ``TO_CLOB`` literal.
        """
        # NOTE(review): unlike insert_statement, values are interpolated without
        # escaping single quotes. Confirm callers pre-escape them; otherwise a value
        # containing "'" produces broken (or injectable) SQL.
        set_statements = []
        for col, value in sets.items():
            quoted_col = OracleSQLManager.quote_identifier(col)
            if value is None:
                # Avoids writing the literal text 'None' and the crash on len(None)
                # for long-text columns.
                set_statements.append(f""" {quoted_col} = NULL""")
            elif col in ALL_LONGTEXT_COLUMN_NAMES:
                set_statements.append(
                    f""" {quoted_col} = {OracleSQLManager.generate_clob_statement(value, chunk_size)}"""
                )
            else:
                set_statements.append(f""" {quoted_col} = '{value}'""")
        statement = ", ".join(set_statements)
        logger.debug(f"Oracle update statement: {statement}")
        return statement
