import functools
from typing import Dict, List, Tuple

import dataiku


@functools.lru_cache(maxsize=128)
def _list_augmented_llms_by_project(selected_projects: Tuple[str, ...]):
    """List the retrieval-augmented LLMs of each project, grouped by project key.

    The result is memoized by ``functools.lru_cache`` per ``selected_projects``
    tuple, so repeated calls with the same tuple return the same dict object.

    Args:
        selected_projects: Project keys to inspect, as a hashable tuple.

    Returns:
        A dict mapping each project key to a list of entries of the form
        ``{"value": "<project>:<llm id>", "label": <friendlyName>, "description": ""}``.
        Projects without any LLM of type ``"RETRIEVAL_AUGMENTED"`` are omitted.
    """
    rag_type = "RETRIEVAL_AUGMENTED"
    llms_per_project: Dict[str, List[Dict[str, str]]] = {}
    api = dataiku.api_client()
    for project_key in selected_projects:
        # TODO (DSS 14.1): also filter out the models that are not in smart mode.
        # Planned approach: load the saved model referenced by "savedModelSmartId",
        # read its raw settings, find the entry of "inlineVersions" whose
        # "versionId" equals the id of the active version, and keep the model only
        # if its ragllmSettings.searchInputStrategySettings.strategy equals
        # RA_MODEL_REWRITE_MODE.
        entries = [
            {
                "value": project_key + ":" + llm.get("id"),
                "label": llm.get("friendlyName"),
                "description": "",  # description will be filled manually by admin
            }
            for llm in api.get_project(project_key).list_llms()
            if llm.get("type") == rag_type
        ]
        # Only projects that expose at least one augmented LLM appear in the result.
        if entries:
            llms_per_project[project_key] = entries
    return llms_per_project


def list_augmented_llms_by_project(selected_projects):
    """Return the retrieval-augmented LLMs of the given projects, grouped by project.

    Thin public wrapper around the memoized ``_list_augmented_llms_by_project``.

    Args:
        selected_projects: Iterable of project keys. It is converted to a tuple
            so it can be used as a hashable cache key.

    Returns:
        A dict mapping project key to a list of ``{"value", "label",
        "description"}`` entries. The returned dict, its lists and its entry
        dicts are fresh copies, so callers may modify them freely.
    """
    cached = _list_augmented_llms_by_project(tuple(selected_projects))
    # The memoized object is shared by every caller using the same project tuple.
    # Hand out a copy so that a caller editing an entry (e.g. filling in its
    # "description") cannot alter what later cache hits return.
    return {project_key: [dict(entry) for entry in entries] for project_key, entries in cached.items()}


@functools.lru_cache(maxsize=128)
def _map_aug_llms_id_name(selected_projects, with_project_key: bool = False) -> Dict[str, str]:
    """Build a lookup from augmented-LLM value to its display label.

    The result is memoized by ``functools.lru_cache`` on both arguments.

    Args:
        selected_projects: Project keys to inspect; must be hashable because it
            is part of the cache key.
        with_project_key: When true, each label is prefixed with
            ``"[<project key>] "``.

    Returns:
        A dict keyed by each entry's ``"value"`` and holding the (optionally
        prefixed) ``"label"`` of that entry.
    """

    def display_name(project_key, entry):
        # Plain label by default; project-qualified label on request.
        if with_project_key:
            return f"[{project_key}] {entry['label']}"
        return entry["label"]

    return {
        entry["value"]: display_name(project_key, entry)
        for project_key, entries in list_augmented_llms_by_project(selected_projects).items()
        for entry in entries
    }


def map_aug_llms_id_name(selected_projects, with_project_key: bool = False) -> Dict[str, str]:
    """Return a mapping from augmented-LLM value to its display label.

    Thin public wrapper around the memoized ``_map_aug_llms_id_name``.

    Args:
        selected_projects: Iterable of project keys. It is converted to a tuple
            so it can be used as a hashable cache key.
        with_project_key: Forwarded unchanged; when true the labels are prefixed
            with ``"[<project key>] "``.

    Returns:
        A new dict on every call, so callers may modify it freely.
    """
    cached = _map_aug_llms_id_name(selected_projects=tuple(selected_projects), with_project_key=with_project_key)
    # The memoized dict is shared between callers; a shallow copy is enough to
    # protect it because the copy only isolates the mapping itself.
    return dict(cached)


def filter_agents_per_user(user: str, agents: List[str]) -> List[str]:
    """Keep only the agents for which the permission check reports access.

    Each agent identifier is split on ``":"``; the first segment is used as the
    project key and the third segment as the saved-model id passed to the
    permission check. The check runs inside
    ``dataiku.WebappImpersonationContext()``.
    NOTE(review): presumably that context evaluates permissions as the end user
    of the current webapp request -- confirm against the Dataiku documentation.

    Args:
        user: Not read by this function; kept so existing callers keep working.
        agents: Agent identifiers with at least three ``":"``-separated segments.

    Returns:
        The identifiers from ``agents``, in their original order, whose first
        permission-check item is reported as permitted. Identifiers with fewer
        than three segments are skipped, i.e. treated as not accessible.
    """
    accessible_agents = []
    for agent in agents:
        segments = agent.split(":")
        if len(segments) < 3:
            # Malformed identifier: fail closed instead of aborting the whole
            # filter with an IndexError.
            continue
        project_key, agent_id = segments[0], segments[2]

        with dataiku.WebappImpersonationContext():
            client = dataiku.api_client()
            permissions_check = client.new_permission_check()
            permissions_check.with_item("savedmodels", project_key, agent_id)
            checks = permissions_check.execute_raw()

        items = checks["items"]
        # Exactly one item was submitted, so only the first result is relevant.
        if items and items[0]["permitted"]:
            accessible_agents.append(agent)

    return accessible_agents
