# -*- coding: utf-8 -*-
import logging

from dataiku import Dataset
from dataiku.customrecipe import get_input_names_for_role, get_output_names_for_role, get_recipe_config

from solutions.graph.dataiku.batch_generators import generate_edges_batches
from solutions.graph.models import GraphMetadata, Sampling, to_edge_group_definitions
from solutions.graph.store.graph_metadata_snapshot_store import DataikuGraphMetadataSnapshotStore

logger = logging.getLogger(__name__)

# Recipe script: load a saved graph configuration (snapshot), select one of its
# edge groups, and write the edges of that group to the "main" output dataset,
# enriched with the `_dku_*` source / target / label columns.

# Read the recipe configuration once and extract the two identifiers we need.
recipe_config = get_recipe_config()
snapshot_id = recipe_config["snapshot_id"]
edge_group_id = recipe_config["edge_id"]

logger.info("Collecting edges %s of saved configuration %s.", edge_group_id, snapshot_id)

# The snapshots store is backed by the first dataset bound to the "snapshots_ds" role.
snapshots_store = DataikuGraphMetadataSnapshotStore(get_input_names_for_role("snapshots_ds")[0])

snapshot = snapshots_store.get_by_id(snapshot_id)
if not snapshot:
    raise ValueError(f"Cannot load saved configuration, no saved configuration {snapshot_id} is available.")

edges = snapshot["edges"]

# Single lookup: a missing (or empty) entry for the requested edge group is an error.
# An f-string is used so a non-string id still yields the intended ValueError
# rather than a TypeError from string concatenation.
edges_group_meta = edges.get(edge_group_id)
if not edges_group_meta:
    raise ValueError(f"Cannot load edge group, no edge id: {edge_group_id} is available")

edges_dataset_output = Dataset(get_output_names_for_role("main")[0])

# Rebuild the graph metadata from the snapshot. The mappings and query entries
# are shallow-copied so the snapshot's own containers are not shared.
graph_metadata = GraphMetadata(
    id=snapshot["graph_id"],
    name=snapshot["name"],
    nodes={**snapshot["nodes"]},
    edges={**snapshot["edges"]},
    nodes_view={**snapshot["nodes_view"]},
    edges_view={**snapshot["edges_view"]},
    # NOTE(review): presumably "all" with max_rows=0 means "no sampling" — confirm
    # against the Sampling model.
    sampling=Sampling(sampling="all", max_rows=0),
    cypher_queries=[{**q} for q in snapshot["cypher_queries"]],
)
group_definitions = to_edge_group_definitions(edges_group_meta, graph_metadata)

logger.info("Starting to process edges...")

# The output schema is derived from the first batch only; later batches are
# appended through the same writer.
# NOTE(review): the schema is written while the writer is already open (before
# the first row is written) — confirm this ordering is supported by the Dataiku writer.
first_batch = True
with edges_dataset_output.get_writer() as output_writer:
    for batch in generate_edges_batches(group_definitions, sampling=None):
        definition = batch["definition"]

        df = batch["df"]
        # id and label of the source node
        df["_dku_src_id"] = df[definition["source_column"]]
        df["_dku_src_label"] = definition["source_node_group"]
        # id and label of the target node
        df["_dku_tgt_id"] = df[definition["target_column"]]
        df["_dku_tgt_label"] = definition["target_node_group"]
        # _dku_label, label of the edge group.
        df["_dku_label"] = definition["edge_group"]

        if first_batch:
            edges_dataset_output.write_schema_from_dataframe(df, drop_and_create=True)
            first_batch = False

        output_writer.write_dataframe(df)

logger.info("Done processing edges.")
