from __future__ import annotations

import json
import logging
import os
import shutil
import tempfile
import time
from pathlib import Path
from typing import Optional, Tuple

from dataiku import Folder
from dataiku.core.base import is_container_exec
from dataiku.customrecipe import get_input_names_for_role, get_output_names_for_role, get_recipe_config

from solutions.graph.graph_builder import GraphBuilder
from solutions.graph.graph_db_instance_manager import LocalDbInstance
from solutions.graph.models import GraphMetadata, GraphMetadataSnapshot, Sampling
from solutions.graph.store.graph_metadata_snapshot_store import DataikuGraphMetadataSnapshotStore

# Module-level logger for this recipe's progress messages (used by run()).
logger = logging.getLogger(__name__)


def get_local_working_dir(folder: Folder) -> Tuple[str, Optional[tempfile.TemporaryDirectory]]:
    """
    Choose the local directory in which output for ``folder`` should be built.

    The folder's own path is used in place only when the recipe is NOT running in a
    container and the folder type reported by ``get_info()`` is ``"Filesystem"``.
    In every other case (container execution, or any non-Filesystem storage such as
    cloud storage) the folder content is not reachable on local disk, so a new, empty
    temporary directory is created instead. Nothing is downloaded into it; the caller
    is responsible for uploading results and calling ``cleanup()`` on it.

    Args:
        folder (Folder): The Dataiku managed folder that will receive the output.

    Returns:
        Tuple[str, Optional[tempfile.TemporaryDirectory]]:
            - The local directory path to work in.
            - The TemporaryDirectory that owns that path, or None when the folder's
              own path is used directly.
    """
    storage_type = folder.get_info()["type"]
    needs_scratch_space = is_container_exec() or storage_type != "Filesystem"

    if not needs_scratch_space:
        return folder.get_path(), None

    scratch = tempfile.TemporaryDirectory()
    return scratch.name, scratch


def retrieve_snapshot(snapshot_id: str) -> GraphMetadataSnapshot:
    """
    Load a saved graph configuration from the recipe's snapshot dataset.

    The dataset used is the first input bound to the ``snapshots_ds`` role.

    Args:
        snapshot_id (str): Identifier of the saved configuration to load.

    Returns:
        GraphMetadataSnapshot: The stored snapshot matching ``snapshot_id``.

    Raises:
        Exception: If the store returns nothing (a falsy value) for ``snapshot_id``.
    """
    snapshots_dataset = get_input_names_for_role("snapshots_ds")[0]
    found = DataikuGraphMetadataSnapshotStore(snapshots_dataset).get_by_id(snapshot_id)
    if found:
        return found
    raise Exception(f"Saved configuration '{snapshot_id}' not found.")


def build_graph(graph_meta: GraphMetadata, output_path: Path):
    """
    Populate a writable local graph database at ``output_path`` from ``graph_meta``.

    Args:
        graph_meta (GraphMetadata): Description of the nodes, edges, views and
            queries to load into the graph.
        output_path (Path): Location of the graph database file to create/write.
    """
    with LocalDbInstance(output_path, readonly=False) as db, GraphBuilder(db) as builder:
        # insert_all returns an iterable that must be exhausted for the insertion to
        # run to completion (presumably a generator yielding progress steps). The
        # yielded values are not needed here, so each one is discarded.
        progress = builder.insert_all(graph_meta, track_progress=False)
        for _step in progress:
            continue


def write_metadata_files(snapshot: GraphMetadataSnapshot, build_dir: str):
    """
    Write the JSON metadata that accompanies a built graph into ``build_dir``.

    Two files are produced, in this order:
    1. ``configuration.json``: the ``snapshot`` serialized as JSON.
    2. ``buildInfo.json``: ``{"epoch_ms": <current epoch time in milliseconds>}``.
    Both are UTF-8 encoded and indented with 4 spaces.

    Args:
        snapshot (GraphMetadataSnapshot): The snapshot to serialize into
            ``configuration.json``.
        build_dir (str): Existing directory that receives both files.

    Raises:
        OSError: If a file cannot be created or written.
        TypeError: If ``snapshot`` is not JSON-serializable.
    """

    def _dump_json(file_name, payload):
        # Each file is opened, written and closed before the next one is started.
        with open(os.path.join(build_dir, file_name), "w", encoding="utf-8") as handle:
            json.dump(payload, handle, indent=4)

    _dump_json("configuration.json", snapshot)
    _dump_json("buildInfo.json", {"epoch_ms": int(time.time() * 1000)})


def upload_outputs(folder: Folder, base: str, prefix: str):
    """
    Upload every file found under ``base/prefix`` into ``folder``.

    The destination path of each file is its path relative to ``base``, so the
    ``prefix`` sub-tree is reproduced inside the managed folder. Files are uploaded
    in ``os.walk`` order; a missing ``base/prefix`` directory uploads nothing.

    Args:
        folder (Folder): Destination managed folder (anything with ``upload_file``).
        base (str): Local root that destination paths are computed relative to.
        prefix (str): Sub-path of ``base`` whose files are uploaded.
    """

    def _local_files(top):
        # Lazily yield each file path so uploads interleave with the directory walk.
        for directory, _subdirs, names in os.walk(top):
            for name in names:
                yield os.path.join(directory, name)

    for local_file in _local_files(os.path.join(base, prefix)):
        folder.upload_file(os.path.relpath(local_file, base), local_file)


def _graph_metadata_from_snapshot(snapshot: GraphMetadataSnapshot) -> GraphMetadata:
    """Build the full (unsampled: ``sampling="all"``) GraphMetadata described by ``snapshot``."""
    return GraphMetadata(
        id=snapshot["graph_id"],
        name=snapshot["name"],
        nodes=snapshot["nodes"],
        edges=snapshot["edges"],
        nodes_view=snapshot["nodes_view"],
        edges_view=snapshot["edges_view"],
        sampling=Sampling(sampling="all", max_rows=0),
        cypher_queries=snapshot["cypher_queries"],
    )


def run(snapshot_id: str, output_folder: Folder):
    """
    Build the graph for a saved configuration and store it in ``output_folder``.

    The graph database (``db.kz``) and its metadata files are written under
    ``built-graphs/<snapshot_id>``. Any local directory already at that location is
    deleted first. When ``get_local_working_dir`` returns a temporary directory
    (container execution or non-Filesystem folder), the result is built there and
    then uploaded to the folder. Otherwise it is built directly in the folder's path.

    Args:
        snapshot_id (str): Identifier of the saved configuration to build.
        output_folder (Folder): Managed folder that receives the built graph.

    Raises:
        ValueError: If ``snapshot_id`` is empty.
        Exception: If the snapshot cannot be found (raised by ``retrieve_snapshot``).
    """
    if not snapshot_id:
        # An empty id would make the build directory "built-graphs/" itself, and the
        # cleanup below would then delete every previously built graph.
        raise ValueError("A snapshot id is required to build a graph.")

    # Resolve the configuration before touching the output, so that a missing or
    # invalid snapshot leaves any previously built graph in place.
    snapshot = retrieve_snapshot(snapshot_id)
    graph_meta = _graph_metadata_from_snapshot(snapshot)

    prefix = f"built-graphs/{snapshot_id}"
    base, temp_dir = get_local_working_dir(output_folder)

    try:
        build_dir = os.path.join(base, prefix)

        if os.path.isdir(build_dir):
            logger.info("Graph already exists at %s. Deleting it.", build_dir)
            shutil.rmtree(build_dir)

        os.makedirs(build_dir, exist_ok=True)

        db_file = Path(os.path.join(build_dir, "db.kz"))
        logger.info("Building graph at %s...", db_file)
        build_graph(graph_meta, db_file)

        write_metadata_files(snapshot, build_dir)

        # get_local_working_dir only returns a temporary directory when the folder
        # cannot be written in place (container execution or non-Filesystem
        # storage), so its presence is exactly the "upload needed" condition.
        # NOTE(review): in the upload case, stale files from an earlier build that are
        # absent from this one are not removed from the remote folder.
        if temp_dir is not None:
            logger.info("Uploading outputs to Dataiku folder...")
            upload_outputs(output_folder, base, prefix)
        else:
            logger.info("Filesystem non-container: no upload needed.")
    finally:
        if temp_dir is not None:
            temp_dir.cleanup()


def main():
    """Recipe entry point: read the recipe settings and build the selected graph."""
    settings = get_recipe_config()
    requested_snapshot = settings["snapshot_id"]
    # The built graph goes to the first output bound to the "main" role.
    destination = Folder(get_output_names_for_role("main")[0])
    run(requested_snapshot, destination)


# Execute the recipe as soon as this script runs.
# NOTE(review): this call is unguarded, so importing the module also runs the
# recipe; confirm how DSS launches the script before adding an
# `if __name__ == "__main__":` guard.
main()
