from datetime import datetime

from solutions.graph.store.graph_metadata_snapshot_store import DataikuGraphMetadataSnapshotStore


def do(payload, config, plugin_config, inputs):
    """Return the choices of a dynamic plugin parameter.

    Supported parameters:

    - ``snapshot_id``: every snapshot stored in the snapshot dataset, most
      recently created first.
    - ``edge_id`` / ``node_id``: the edge / node groups of the snapshot whose id
      is currently selected in ``payload["rootModel"]["snapshot_id"]``.

    /!\\ The first dataset in ``inputs`` is expected to be the snapshot dataset.

    Args:
        payload: request dict; must contain ``parameterName`` and may contain
            ``rootModel`` (the form's current values).
        config: not used here.
        plugin_config: not used here.
        inputs: list of input descriptors; ``inputs[0]["fullName"]`` is used as
            the snapshot dataset name.

    Returns:
        ``{"choices": [{"value": ..., "label": ...}, ...]}``. For ``edge_id`` /
        ``node_id`` the list is empty when no snapshot is selected yet or the
        selected snapshot id is not found in the dataset.

    Raises:
        Exception: if ``parameterName`` is not a supported parameter name.
    """
    parameter_name = payload["parameterName"]
    snapshot_dataset_name = inputs[0]["fullName"]
    snapshot_store = DataikuGraphMetadataSnapshotStore(snapshot_ds_name=snapshot_dataset_name)
    snapshots = snapshot_store.get_all()

    if parameter_name == "snapshot_id":
        return {"choices": _snapshot_choices(snapshots)}
    if parameter_name == "edge_id":
        return {"choices": _group_choices(payload, snapshots, "edges", "edge_id", "edge_group")}
    if parameter_name == "node_id":
        return {"choices": _group_choices(payload, snapshots, "nodes", "node_id", "node_group")}
    raise Exception(f"Unexpected parameter name '{parameter_name}'.")


def _snapshot_choices(snapshots):
    """Build one choice per snapshot, ordered by ``epoch_ms`` descending (newest first)."""
    # sorted() rather than list.sort(): do not mutate the list handed back by the store.
    ordered = sorted(snapshots, key=lambda snapshot: snapshot["epoch_ms"], reverse=True)
    return [{"value": snapshot["id"], "label": _snapshot_label(snapshot)} for snapshot in ordered]


def _snapshot_label(snapshot):
    """Format ``"<name> (created YYYY-MM-DD HH:MM) - <id>"`` for a snapshot."""
    # epoch_ms is in milliseconds; fromtimestamp() expects seconds and renders in
    # the local timezone of the process running this code.
    created = datetime.fromtimestamp(snapshot["epoch_ms"] / 1000).strftime("%Y-%m-%d %H:%M")
    return f"{snapshot['name']} (created {created}) - {snapshot['id']}"


def _group_choices(payload, snapshots, group_key, value_field, label_field):
    """Build choices from the ``group_key`` groups of the snapshot selected in the form.

    Returns an empty list when no snapshot is selected yet, or when the selected
    id matches no snapshot in ``snapshots`` (previously this raised IndexError).
    """
    root_model = payload.get("rootModel") or {}
    snapshot_id = root_model.get("snapshot_id")
    if not snapshot_id:
        return []
    matching_snapshot = next((s for s in snapshots if s["id"] == snapshot_id), None)
    if matching_snapshot is None:
        return []
    return [
        {"value": group[value_field], "label": group[label_field]}
        for group in matching_snapshot[group_key].values()
    ]
