from __future__ import annotations

import json
import logging
from abc import ABC, abstractmethod
from typing import Dict, List

import pandas as pd
from dataiku import Dataset
from diskcache import Index
from pandas import DataFrame, Series

from ..models import GraphId, GraphMetadataSnapshot, SnapshotId

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Ordered column names of the dataset that stores graph metadata snapshots.
# Kept as a list: the schema check compares it to a list of column names.
GRAPH_METADATA_SNAPSHOT_COLUMNS = ["id", "graph_id", "epoch_ms", "snapshot"]


class GraphMetadataSnapshotStoreException(Exception):
    """Error raised by the graph metadata snapshot stores.

    Accepts the same positional arguments as ``Exception``; no constructor is
    defined here because it would only forward them unchanged.
    """


class AbstractGraphMetadataSnapshotStore(ABC):
    """Interface shared by the graph metadata snapshot stores."""

    @abstractmethod
    def init_dataset_schema(self) -> None:
        """Prepare the schema of the underlying dataset."""

    @abstractmethod
    def get_all(self) -> List[GraphMetadataSnapshot]:
        """Return every stored snapshot."""

    @abstractmethod
    def get_all_by_graph_id(self, id: GraphId) -> List[GraphMetadataSnapshot]:
        """Return the snapshots whose graph id equals ``id``."""

    @abstractmethod
    def get_by_id(self, id: SnapshotId) -> GraphMetadataSnapshot | None:
        """Return the snapshot identified by ``id``, or ``None``."""

    @abstractmethod
    def delete(self, id: SnapshotId) -> None:
        """Delete the snapshot identified by ``id``."""

    @abstractmethod
    def delete_all(self, graph_id: GraphId) -> None:
        """Delete every snapshot attached to ``graph_id``."""

    @abstractmethod
    def save(self, snapshot: GraphMetadataSnapshot) -> None:
        """Store ``snapshot``."""


class GraphMetadataSnapshotStore(AbstractGraphMetadataSnapshotStore):
    """Snapshot store that keeps a diskcache ``Index`` in front of a Dataiku-backed store.

    Reads are answered from the cache. Mutations are forwarded to the wrapped
    store first and then mirrored into the cache, inside one cache transaction.
    Every public read/write method first makes sure the cache has been filled
    from the wrapped store.
    """

    # Reserved cache key flagging that the cache has been filled from the
    # wrapped store. It lives next to the snapshots, so it must be filtered out
    # whenever snapshots are enumerated.
    __CACHE_INITIALIZED_KEY = "__cache_initialized__"

    def __init__(self, diskcache_path: str, dataiku_store: DataikuGraphMetadataSnapshotStore) -> None:
        """
        Args:
            diskcache_path: Path passed to ``diskcache.Index``.
            dataiku_store: Store mirrored by the cache; it receives every write.
        """
        logger.debug("Initializing saved configuration store with diskcache at %s.", diskcache_path)
        self.__cache = Index(diskcache_path)
        self.__dataiku_store = dataiku_store

    def __load_cache__(self) -> None:
        """Replace the cache content with the snapshots of the wrapped store and set the marker."""
        # Fetch before clearing so a failing read leaves the cache untouched.
        snapshots = self.__dataiku_store.get_all()
        self.__cache.clear()
        for snapshot in snapshots:
            self.__cache[snapshot["id"]] = snapshot
        self.__cache[self.__CACHE_INITIALIZED_KEY] = True

    def __ensure_cache_loaded__(self) -> None:
        """Initialize the dataset schema and fill the cache if the marker key is absent."""
        with self.__cache.transact():
            if self.__CACHE_INITIALIZED_KEY in self.__cache:
                return
            self.init_dataset_schema()
            self.__load_cache__()

    def __snapshot_items(self) -> List[tuple]:
        """Return the cached ``(key, snapshot)`` pairs, leaving out the initialization marker."""
        marker = self.__CACHE_INITIALIZED_KEY
        return [(key, value) for key, value in self.__cache.items() if key != marker]

    def init_dataset_schema(self) -> None:
        """Delegate schema initialization to the wrapped store."""
        self.__dataiku_store.init_dataset_schema()

    def get_all(self) -> List[GraphMetadataSnapshot]:
        """Return a shallow copy of every cached snapshot."""
        self.__ensure_cache_loaded__()
        with self.__cache.transact():
            return [{**snapshot} for _, snapshot in self.__snapshot_items()]  # type: ignore

    def get_all_by_graph_id(self, id: GraphId) -> List[GraphMetadataSnapshot]:
        """Return a shallow copy of every cached snapshot whose ``graph_id`` equals ``id``."""
        self.__ensure_cache_loaded__()
        with self.__cache.transact():
            return [{**snapshot} for _, snapshot in self.__snapshot_items() if snapshot["graph_id"] == id]  # type: ignore

    def get_by_id(self, id: SnapshotId) -> GraphMetadataSnapshot | None:
        """Return a shallow copy of the cached snapshot ``id``, or ``None`` if there is none."""
        self.__ensure_cache_loaded__()
        # The marker entry holds a bool, not a snapshot: never expose it (and
        # never try to unpack it as a mapping).
        if id == self.__CACHE_INITIALIZED_KEY:
            return None
        with self.__cache.transact():
            cached = self.__cache.get(id)
            return {**cached} if cached else None  # type: ignore

    def delete(self, id: SnapshotId) -> None:
        """Delete snapshot ``id`` from the wrapped store, then from the cache.

        Deleting an id that is not cached is a no-op for the cache, in line
        with the wrapped store, which simply filters rows by id.
        """
        self.__ensure_cache_loaded__()
        with self.__cache.transact():
            self.__dataiku_store.delete(id)
            # A default keeps an unknown id from raising KeyError after the
            # wrapped store has already been rewritten.
            self.__cache.pop(id, None)

    def delete_all(self, graph_id: GraphId) -> None:
        """Delete every snapshot of ``graph_id`` from the wrapped store, then from the cache."""
        self.__ensure_cache_loaded__()
        with self.__cache.transact():
            self.__dataiku_store.delete_all(graph_id)
            # Collect the keys first so the cache is not mutated while it is enumerated.
            doomed = [key for key, snapshot in self.__snapshot_items() if snapshot["graph_id"] == graph_id]
            for key in doomed:
                self.__cache.pop(key)

    def save(self, snapshot: GraphMetadataSnapshot) -> None:
        """Persist a new snapshot in the wrapped store and mirror it in the cache.

        Raises:
            GraphMetadataSnapshotStoreException: if a snapshot with the same id is already cached.
        """
        self.__ensure_cache_loaded__()
        with self.__cache.transact():
            if snapshot["id"] in self.__cache:
                raise GraphMetadataSnapshotStoreException(f"Graph metadata snapshot with id {snapshot['id']} already exists.")
            self.__dataiku_store.save(snapshot)
            self.__cache[snapshot["id"]] = snapshot


class DataikuGraphMetadataSnapshotStore(AbstractGraphMetadataSnapshotStore):
    """Snapshot store persisted in a Dataiku dataset, one row per snapshot.

    Rows use the columns listed in ``GRAPH_METADATA_SNAPSHOT_COLUMNS``; the
    ``snapshot`` column holds the whole snapshot serialized as JSON. Every
    mutation reads the full dataset into a DataFrame and writes the modified
    frame back, and every read iterates over the whole dataset.
    """

    def __init__(self, snapshot_ds_name: str, project_key: str | None = None) -> None:
        """
        Args:
            snapshot_ds_name: Name of the Dataiku dataset holding the snapshots.
            project_key: Optional project key forwarded to ``Dataset``.
        """
        self.__snapshot_ds = Dataset(snapshot_ds_name, project_key)
        # Keep the name as reported by the Dataset object; it is used in error
        # messages and to re-open the dataset in init_dataset_schema.
        self.__snapshot_ds_name = self.__snapshot_ds.name

    def init_dataset_schema(self) -> None:
        """Create the dataset schema when it is empty, otherwise check its column names.

        Raises:
            GraphMetadataSnapshotStoreException: if writing the initial schema fails, or if the
                existing column names differ from ``GRAPH_METADATA_SNAPSHOT_COLUMNS``.
        """
        snapshot_ds = Dataset(self.__snapshot_ds_name, self.__snapshot_ds.project_key)
        schema = snapshot_ds.read_schema(raise_if_empty=False)

        if len(schema) > 0:
            if [col["name"] for col in schema] != GRAPH_METADATA_SNAPSHOT_COLUMNS:
                # Let the backend crash, unexpected and unresolvable programmatically without possibly losing data.
                raise GraphMetadataSnapshotStoreException(
                    f"Unexpected columns for graph metadata saved configurations dataset {self.__snapshot_ds_name}. Expected {GRAPH_METADATA_SNAPSHOT_COLUMNS}."
                )
            return

        try:
            # Write an empty, correctly typed frame so the dataset gets its schema.
            empty_columns: Dict[str, List[str]] = {col: [] for col in GRAPH_METADATA_SNAPSHOT_COLUMNS}
            empty_df = pd.DataFrame(empty_columns, columns=GRAPH_METADATA_SNAPSHOT_COLUMNS, dtype=str)
            empty_df = empty_df.astype(
                {
                    "id": str,
                    "graph_id": str,
                    "epoch_ms": int,
                    "snapshot": str,
                }
            )
            snapshot_ds.write_with_schema(empty_df)
        except Exception as ex:
            raise GraphMetadataSnapshotStoreException(
                f"Failed to init schema of graph metadata saved configurations dataset {self.__snapshot_ds_name}."
            ) from ex

    def __load_graph_metadata_snapshots__(self) -> Dict[SnapshotId, GraphMetadataSnapshot]:
        """
        Read the whole dataset and return its snapshots keyed by the ``id`` column.
        When several rows share an id, the last one read wins.

        Raises:
            GraphMetadataSnapshotStoreException
        """
        try:
            snapshots = {}
            for chunk_df in self.__snapshot_ds.iter_dataframes():
                for _, row in chunk_df.iterrows():
                    snapshots[row["id"]] = DataikuGraphMetadataSnapshotStore.from_df(row)

            return snapshots
        except Exception as ex:
            raise GraphMetadataSnapshotStoreException(
                f"Failed to load graph metadata saved configurations from dataset {self.__snapshot_ds_name}."
            ) from ex

    def get_all(self) -> List[GraphMetadataSnapshot]:
        """
        Return a shallow copy of every snapshot stored in the dataset.

        Raises:
            GraphMetadataSnapshotStoreException
        """
        return [{**s} for s in self.__load_graph_metadata_snapshots__().values()]

    def get_all_by_graph_id(self, id: GraphId) -> List[GraphMetadataSnapshot]:
        """
        Return the snapshots whose ``graph_id`` equals ``id``.

        Raises:
            GraphMetadataSnapshotStoreException
        """
        # get_all already hands out fresh copies, so no second copy is needed.
        return [s for s in self.get_all() if s["graph_id"] == id]

    def get_by_id(self, id: SnapshotId) -> GraphMetadataSnapshot | None:
        """
        Returns None if id does not exist.
        Raises:
            GraphMetadataSnapshotStoreException
        """
        matching_snapshots: List[GraphMetadataSnapshot] = [s for s in self.get_all() if s["id"] == id]
        # Raised explicitly rather than asserted: assert statements are removed
        # when Python runs with -O, which would silently drop this check.
        if len(matching_snapshots) > 1:
            raise GraphMetadataSnapshotStoreException(f"Several saved configurations have the same id '{id}'.")

        return matching_snapshots[0] if matching_snapshots else None

    def delete(self, id: SnapshotId) -> None:
        """
        Rewrite the dataset without the row whose ``id`` column equals ``id``.

        Raises:
            GraphMetadataSnapshotStoreException
        """
        try:
            snapshot_df: DataFrame = self.__snapshot_ds.get_dataframe()

            self.__snapshot_ds.write_with_schema(snapshot_df.loc[snapshot_df["id"] != id])
        except Exception as ex:
            # ``id`` is a snapshot id here, so the message must not call it a graph.
            raise GraphMetadataSnapshotStoreException(
                f"Failed to delete metadata saved configuration {id} in dataset {self.__snapshot_ds_name}."
            ) from ex

    def delete_all(self, graph_id: GraphId) -> None:
        """
        Rewrite the dataset without the rows whose ``graph_id`` column equals ``graph_id``.

        Raises:
            GraphMetadataSnapshotStoreException
        """
        try:
            snapshot_df: DataFrame = self.__snapshot_ds.get_dataframe()

            self.__snapshot_ds.write_with_schema(snapshot_df.loc[snapshot_df["graph_id"] != graph_id])
        except Exception as ex:
            # Interpolate graph_id: a bare ``id`` here would be the builtin function.
            raise GraphMetadataSnapshotStoreException(
                f"Failed to delete metadata saved configuration for graph {graph_id} in dataset {self.__snapshot_ds_name}."
            ) from ex

    def save(self, snapshot: GraphMetadataSnapshot) -> None:
        """
        Append ``snapshot`` as a new row of the dataset. No check is made here
        for an existing row with the same id.

        Raises:
            GraphMetadataSnapshotStoreException
        """
        snapshot = {**snapshot}
        snapshot_id = snapshot["id"]

        try:
            original_df: DataFrame = self.__snapshot_ds.get_dataframe()

            updated_df = pd.concat([original_df, DataikuGraphMetadataSnapshotStore.to_df(snapshot)])

            self.__snapshot_ds.write_dataframe(updated_df)
        except Exception as ex:
            raise GraphMetadataSnapshotStoreException(
                f"Failed to save metadata saved configuration {snapshot_id} for graph {snapshot['graph_id']} in dataset {self.__snapshot_ds_name}."
            ) from ex

    @staticmethod
    def to_df(snapshot: GraphMetadataSnapshot) -> DataFrame:
        """Build a one-row DataFrame for ``snapshot``; the ``snapshot`` column is its JSON dump."""
        return DataFrame(
            {
                "id": snapshot["id"],
                "graph_id": snapshot["graph_id"],
                # The single list value gives the frame its one-row index; the
                # scalar values are broadcast to it.
                "epoch_ms": [snapshot["epoch_ms"]],
                "snapshot": json.dumps(snapshot),
            }
        )

    @staticmethod
    def from_df(series: Series) -> GraphMetadataSnapshot:
        """Rebuild a snapshot from a dataset row by parsing its ``snapshot`` JSON column."""
        return json.loads(series["snapshot"])  # type: ignore
