from __future__ import annotations

import json
import logging
from abc import ABC, abstractmethod
from typing import Dict, List

import pandas as pd
from dataiku import Dataset
from diskcache import Index
from pandas import DataFrame, Series, notna

from solutions.graph.store.utils import hash_dict

from ..models import (
    GraphAlreadyExistsError,
    GraphDoesNotExistError,
    GraphId,
    GraphMetadata,
    VersionedGraphMetadata,
)

logger: logging.Logger = logging.getLogger(__name__)

# Columns of the graph metadata dataset. Order matters: ``init_dataset_schema`` compares the
# dataset schema against this exact list, so it must stay a list in this order.
GRAPH_METADATA_COLUMNS: List[str] = [
    # identity
    "graph_id",
    "graph_name",
    # graph definition (stored as JSON strings)
    "nodes",
    "edges",
    "nodes_version",
    "edges_version",
    # display settings and saved queries (stored as JSON strings)
    "view_configuration",
    "cypher_queries",
]


class GraphMetadataStoreError(Exception):
    """Raised when the graph metadata storage cannot be read, written or validated."""


class ConcurrentUpdateCollisionError(Exception):
    """
    Raised when an update is based on a version of a graph that is no longer the current one.
    """

    # Id of the graph whose update was rejected (may be None).
    graph_id: GraphId | None

    def __init__(self, graph_id: GraphId | None) -> None:
        """
        Args:
            graph_id: Id of the graph whose update was rejected.
        """
        super().__init__(f"Attempted to update graph {graph_id} from an out-of-date configuration.")
        self.graph_id = graph_id


class AbstractGraphMetadataStore(ABC):
    """
    Storage interface for graph metadata, keyed by graph id.

    Reads return ``VersionedGraphMetadata``: the metadata plus a ``version_token``.
    ``update`` takes a reference token and raises ``ConcurrentUpdateCollisionError``
    when it reports a stale write.
    """

    @abstractmethod
    def init_dataset_schema(self) -> None:
        """Make sure the underlying storage is set up to hold graph metadata."""
        ...

    @abstractmethod
    def exists(self, id: GraphId) -> bool:
        """
        Tell whether a graph with this id is stored.

        Raises:
            GraphMetadataStoreError
        """
        ...

    @abstractmethod
    def get_all(self) -> Dict[GraphId, VersionedGraphMetadata]:
        """
        Return every stored graph, keyed by graph id.

        Raises:
            GraphMetadataStoreError
        """
        ...

    @abstractmethod
    def get(self, id: GraphId) -> VersionedGraphMetadata:
        """
        Return the stored metadata of one graph.

        Raises:
            GraphDoesNotExistError
            GraphMetadataStoreError
        """
        ...

    @abstractmethod
    def delete(self, id: GraphId) -> None:
        """
        Remove one graph from the store.

        Raises:
            GraphDoesNotExistError
            GraphMetadataStoreError
        """
        ...

    @abstractmethod
    def update(self, metadata: GraphMetadata, ref_version_token: str) -> VersionedGraphMetadata:
        """
        Replace an existing graph's metadata; ``ref_version_token`` is the token the change was based on.

        Raises:
            GraphDoesNotExistError
            ConcurrentUpdateCollisionError
            GraphMetadataStoreError
        """
        ...

    @abstractmethod
    def create(self, metadata: GraphMetadata) -> VersionedGraphMetadata:
        """
        Store a new graph.

        Raises:
            GraphAlreadyExistsError
            GraphMetadataStoreError
        """
        ...


class GraphMetadataStore(AbstractGraphMetadataStore):
    """
    Graph metadata store that puts a cross-process diskcache in front of another store.

    diskcache serves two purposes:
    - It is a cache of all graph metadata, shared by the processes that open the same
      ``diskcache_path``.
    - Through its transactions, it is a synchronization mechanism across those processes:
      each check-then-write sequence below runs inside one transaction.

    The Dataiku metadata dataset is not necessarily backed by a SQL database (it might be on S3
    or a filesystem), so this synchronization has to be added on top of it.

    Writes go to the wrapped store first. They are mirrored into the cache only once they
    succeed.
    """

    # Sentinel key recording that the cache has been filled from the wrapped store. It lives in
    # the same key space as graph ids, so graph lookups must never treat it as a graph.
    __CACHE_INITIALIZED_KEY = "__cache_initialized__"

    def __init__(self, diskcache_path: str, dataiku_store: AbstractGraphMetadataStore) -> None:
        """
        Args:
            diskcache_path: Directory of the diskcache ``Index`` used as cache and lock.
            dataiku_store: Store holding the persisted metadata; every write is delegated to it.
        """
        logger.debug("Initializing graph store with diskcache at %s.", diskcache_path)
        self.__cache = Index(diskcache_path)
        self.__dataiku_store = dataiku_store

    def __load_cache__(self) -> None:
        """Replace the whole cache content with the wrapped store's graphs, then mark it loaded."""
        graphs_metadata = self.__dataiku_store.get_all()
        self.__cache.clear()
        for key, value in graphs_metadata.items():
            self.__cache[key] = value
        self.__cache[self.__CACHE_INITIALIZED_KEY] = True

    def __ensure_cache_loaded__(self) -> None:
        """Initialize the dataset schema and fill the cache if no process has done it yet."""
        with self.__cache.transact():
            # The sentinel is checked inside the transaction, so concurrent processes do not
            # both perform the initial load.
            if self.__CACHE_INITIALIZED_KEY not in self.__cache:
                self.init_dataset_schema()
                self.__load_cache__()

    def __contains_graph__(self, id: GraphId) -> bool:
        """Whether ``id`` is a cached graph. The sentinel key is never reported as a graph."""
        return id != self.__CACHE_INITIALIZED_KEY and id in self.__cache

    def init_dataset_schema(self) -> None:
        """Delegate schema initialization to the wrapped store."""
        self.__dataiku_store.init_dataset_schema()

    def exists(self, id: GraphId) -> bool:
        """Whether a graph with this id is in the cache (loading the cache if needed)."""
        try:
            self.get(id)
            return True
        except GraphDoesNotExistError:
            return False

    def get_all(self) -> Dict[GraphId, VersionedGraphMetadata]:
        """Return every cached graph, keyed by graph id."""
        self.__ensure_cache_loaded__()
        with self.__cache.transact():
            # Exclude the sentinel key from the returned dictionary.
            return {k: v for k, v in self.__cache.items() if k != self.__CACHE_INITIALIZED_KEY}  # type: ignore

    def get(self, id: GraphId) -> VersionedGraphMetadata:
        """
        Return a shallow copy of one cached graph's metadata.

        Raises:
            GraphDoesNotExistError: if no graph with this id is cached.
        """
        self.__ensure_cache_loaded__()
        with self.__cache.transact():
            if not self.__contains_graph__(id):
                raise GraphDoesNotExistError(id)
            return {**self.__cache[id]}  # type: ignore

    def delete(self, id: GraphId) -> None:
        """
        Delete a graph from the wrapped store, then from the cache.

        Raises:
            GraphDoesNotExistError: if no graph with this id is cached.
            GraphMetadataStoreError: propagated from the wrapped store; the cache is then left untouched.
        """
        self.__ensure_cache_loaded__()
        with self.__cache.transact():
            if not self.__contains_graph__(id):
                raise GraphDoesNotExistError(id)

            self.__dataiku_store.delete(id)
            del self.__cache[id]

    def update(self, metadata: GraphMetadata, ref_version_token: str) -> VersionedGraphMetadata:
        """
        Update a graph if ``ref_version_token`` matches its current version token.

        Raises:
            GraphDoesNotExistError: if no graph with this id is cached.
            ConcurrentUpdateCollisionError: if the graph changed since ``ref_version_token`` was issued.
            GraphMetadataStoreError: propagated from the wrapped store; the cache is then left untouched.
        """
        self.__ensure_cache_loaded__()
        graph_id = metadata["id"]
        metadata = {**metadata}

        with self.__cache.transact():
            if not self.__contains_graph__(graph_id):
                raise GraphDoesNotExistError(graph_id)

            current_version_token: str = self.__cache[graph_id]["version_token"]  # type: ignore

            # Optimistic concurrency check: reject changes built on a stale version of the graph.
            if ref_version_token != current_version_token:
                raise ConcurrentUpdateCollisionError(graph_id)

            updated_metadata = self.__dataiku_store.update(metadata, ref_version_token)
            self.__cache[updated_metadata["id"]] = updated_metadata
            return {**updated_metadata}

    def create(self, metadata: GraphMetadata) -> VersionedGraphMetadata:
        """
        Create a new graph in the wrapped store, then cache it.

        Raises:
            GraphAlreadyExistsError: if the id is already cached (the reserved sentinel id included).
            GraphMetadataStoreError: propagated from the wrapped store; the cache is then left untouched.
        """
        self.__ensure_cache_loaded__()
        graph_id = metadata["id"]
        metadata = {**metadata}

        with self.__cache.transact():
            # Plain membership on purpose: the sentinel id counts as taken, so it is never overwritten.
            if graph_id in self.__cache:
                raise GraphAlreadyExistsError(graph_id)

            created_graph_metadata = self.__dataiku_store.create(metadata)
            self.__cache[created_graph_metadata["id"]] = created_graph_metadata
            return {**created_graph_metadata}


class DataikuGraphMetadataStore(AbstractGraphMetadataStore):
    """
    Graph metadata store persisted in a Dataiku dataset, one row per graph.

    Every read loads the whole dataset. Every write reads it, then rewrites it entirely with
    ``write_with_schema``. No locking happens here, so concurrent writers can overwrite each
    other unless they are serialized externally (e.g. by ``GraphMetadataStore``).
    """

    def __init__(self, config_ds_name: str, project_key: str) -> None:
        """
        Args:
            config_ds_name: Name of the Dataiku dataset holding the graph metadata.
            project_key: Key of the Dataiku project owning that dataset.
        """
        self.__config_ds = Dataset(config_ds_name, project_key)
        self.__config_ds_name = self.__config_ds.name

    def init_dataset_schema(self) -> None:
        """
        Write an empty dataset with the expected columns if it has no schema; otherwise validate its columns.

        Raises:
            GraphMetadataStoreError: if writing the initial schema fails, or if the existing columns
                differ (by name or order) from ``GRAPH_METADATA_COLUMNS``.
        """
        metadata_ds = Dataset(self.__config_ds_name, self.__config_ds.project_key)
        schema = metadata_ds.read_schema(raise_if_empty=False)
        if len(schema) == 0:
            try:
                data: Dict[str, List[str]] = {col: [] for col in GRAPH_METADATA_COLUMNS}
                metadata_ds.write_with_schema(df=pd.DataFrame(data, columns=GRAPH_METADATA_COLUMNS, dtype=str))
            except Exception as ex:
                raise GraphMetadataStoreError(
                    f"Failed to init schema of graph metadata dataset {self.__config_ds_name}."
                ) from ex
        elif [col["name"] for col in schema] != GRAPH_METADATA_COLUMNS:
            # Let the backend crash: this is unexpected, and cannot be resolved programmatically
            # without possibly losing data.
            raise GraphMetadataStoreError(
                f"Unexpected columns for graph metadata dataset {self.__config_ds_name}. Expected {GRAPH_METADATA_COLUMNS}."
            )

    @staticmethod
    def __with_version_token__(metadata: GraphMetadata) -> VersionedGraphMetadata:
        """Return a copy of ``metadata`` carrying a version token hashed from its content."""
        # Any token the input already carries is left out of the hash. The token then depends
        # only on the graph content, the same way it does when the dataset is loaded.
        content = {k: v for k, v in metadata.items() if k != "version_token"}
        return {**content, "version_token": hash_dict(content)}  # type: ignore

    def __load_graphs_metadata__(self) -> Dict[GraphId, VersionedGraphMetadata]:
        """
        Read every dataset row and return the versioned metadata of each graph, keyed by graph id.

        Raises:
            GraphMetadataStoreError: if reading or parsing the dataset fails.
        """
        try:
            graphs_metadata: Dict[GraphId, VersionedGraphMetadata] = {}
            for meta_df in self.__config_ds.iter_dataframes():
                for _, row in meta_df.iterrows():
                    graph_metadata = DataikuGraphMetadataStore.from_df(row)
                    graphs_metadata[row["graph_id"]] = DataikuGraphMetadataStore.__with_version_token__(
                        graph_metadata
                    )

            return graphs_metadata
        except Exception as ex:
            raise GraphMetadataStoreError(
                f"Failed to load graph metadata from dataset {self.__config_ds_name}."
            ) from ex

    def exists(self, id: GraphId) -> bool:
        """Whether the dataset holds a graph with this id (reads the whole dataset)."""
        try:
            self.get(id)
            return True
        except GraphDoesNotExistError:
            return False

    def get_all(self) -> Dict[GraphId, VersionedGraphMetadata]:
        """Return every graph stored in the dataset, keyed by graph id."""
        # The loader builds a fresh dictionary on every call, so no defensive copy is needed.
        return self.__load_graphs_metadata__()

    def get(self, id: GraphId) -> VersionedGraphMetadata:
        """
        Return the metadata of one graph (reads the whole dataset).

        Raises:
            GraphDoesNotExistError: if the dataset holds no graph with this id.
            GraphMetadataStoreError: if the dataset cannot be read.
        """
        all_graphs = self.get_all()
        if id not in all_graphs:
            raise GraphDoesNotExistError(id)

        return {**all_graphs[id]}

    def delete(self, id: GraphId) -> None:
        """
        Remove the row of graph ``id`` by rewriting the dataset without it.

        NOTE(review): unlike the abstract contract, a missing id does not raise
        GraphDoesNotExistError. The dataset is rewritten unchanged, and existence is left to the
        caller to check.

        Raises:
            GraphMetadataStoreError: if reading or writing the dataset fails.
        """
        try:
            meta_df: DataFrame = self.__config_ds.get_dataframe()

            output_df = meta_df.loc[meta_df["graph_id"] != id]

            self.__config_ds.write_with_schema(output_df)
        except Exception as ex:
            raise GraphMetadataStoreError(
                f"Failed to delete metadata for graph {id} in dataset {self.__config_ds_name}."
            ) from ex

    def __upsert__(self, metadata: GraphMetadata) -> VersionedGraphMetadata:
        """
        Write ``metadata`` as the row of its graph, replacing any existing row with the same id.

        Returns:
            A copy of ``metadata`` with its new version token.

        Raises:
            GraphMetadataStoreError: if reading or writing the dataset fails.
        """
        metadata = {**metadata}
        graph_id = metadata["id"]

        try:
            meta_df: DataFrame = self.__config_ds.get_dataframe()

            output_df: DataFrame = meta_df.loc[meta_df["graph_id"] != graph_id]
            output_df = pd.concat([output_df, DataikuGraphMetadataStore.to_df(metadata)])

            self.__config_ds.write_with_schema(output_df)

            return DataikuGraphMetadataStore.__with_version_token__(metadata)
        except Exception as ex:
            raise GraphMetadataStoreError(
                f"Failed to save metadata for graph {graph_id} in dataset {self.__config_ds_name}."
            ) from ex

    def update(self, metadata: GraphMetadata, ref_version_token: str) -> VersionedGraphMetadata:
        """
        Replace the row of the graph with the given metadata.

        NOTE(review): neither existence nor ``ref_version_token`` is checked here; both are left to
        the caller (``GraphMetadataStore`` checks them before delegating).

        Raises:
            GraphMetadataStoreError: if reading or writing the dataset fails.
        """
        return self.__upsert__(metadata)

    def create(self, metadata: GraphMetadata) -> VersionedGraphMetadata:
        """
        Add a row for the graph. An existing row with the same id is replaced, not reported.

        Raises:
            GraphMetadataStoreError: if reading or writing the dataset fails.
        """
        return self.__upsert__(metadata)

    @staticmethod
    def to_df(metadata: GraphMetadata) -> DataFrame:
        """
        Serialize one graph's metadata into a single-row DataFrame with the ``GRAPH_METADATA_COLUMNS`` columns.

        - Structured fields are stored as JSON strings.
        - Sampling settings and node/edge views are folded into ``view_configuration``.
        - ``nodes_version`` and ``edges_version`` are always written as ``{}``.
        """
        return DataFrame(
            {
                "graph_id": [metadata["id"]],
                "graph_name": [metadata["name"]],
                "nodes": [json.dumps(metadata["nodes"])],
                "edges": [json.dumps(metadata["edges"])],
                "nodes_version": [json.dumps({})],
                "edges_version": [json.dumps({})],
                "view_configuration": [
                    json.dumps(
                        {
                            "sampling": metadata["sampling"]["sampling"],
                            "max_rows": metadata["sampling"]["max_rows"],
                            "nodes": metadata["nodes_view"],
                            "edges": metadata["edges_view"],
                        }
                    )
                ],
                "cypher_queries": [json.dumps(metadata["cypher_queries"])],
            }
        )

    @staticmethod
    def from_df(series: Series) -> GraphMetadata:
        """
        Deserialize one dataset row, laid out as produced by ``to_df``, back into graph metadata.

        A missing or null ``cypher_queries`` cell yields an empty list. Presumably such rows
        predate that column; confirm.
        """
        view_configuration = json.loads(series["view_configuration"])
        return {
            "id": series["graph_id"],
            "name": series["graph_name"],
            "nodes": json.loads(series["nodes"]),
            "edges": json.loads(series["edges"]),
            "nodes_view": view_configuration["nodes"],
            "edges_view": view_configuration["edges"],
            "sampling": {"sampling": view_configuration["sampling"], "max_rows": view_configuration["max_rows"]},
            "cypher_queries": json.loads(series["cypher_queries"])
            if "cypher_queries" in series and notna(series["cypher_queries"])
            else [],
        }
