from ..models import StudyType, NovelStudy, CohortAge, SimilarStudy, CandidateSite
from typing import Union, Tuple, TypedDict
from datetime import datetime
from ..utils import (
    k_neighbors,
    include_sdoh,
    filter_dataset,
    df_to_dict,
    filter_dataset_by_nctid,
    transform_string,
)
import pandas as pd
import dataiku
from .config import CsiConfig, SELF_DEFINED_NTCID
from webaiku.apis.dataiku.api import dataiku_api
import numpy as np
import numpy.typing as npt
from typing import List, Optional
import pickle
from .study_similarity import study_similarity
from webaiku.apis.dataiku.formula import DataikuFormula
import json
import torch
import torch.nn as nn
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel


# Per-age-group inclusion flags derived from a study's cohort ages. Each value
# is the string "1" (group included) or "0" (group excluded).
AgeVariables = TypedDict(
    "AgeVariables",
    {
        "child": str,
        "adult": str,
        "older_adult": str,
    },
)


class StudyNeighbors:
    """Find clinical studies similar to a given study, plus candidate sites.

    Similarity search runs against a faiss index downloaded from a Dataiku
    managed folder. A query is either an existing study's NCT id (its vector is
    reconstructed from the index) or a ``NovelStudy`` protocol, which is
    embedded on the fly with a pretrained language model plus seeded
    categorical embeddings.
    """

    # Column renames applied to the studies-with-sites dataset.
    _SITE_COLUMN_RENAMES = {
        "NCTId": "Study_id",
        "Location_facility_normalized": "Site_name",
        "Site_City": "City",
        "Site_Region": "Region/State",
        "Site_Country": "Country",
        "Site_zip": "Zip",
    }

    # Columns of the frame produced by ``get_candidate_sites_data``; used to
    # return a correctly-shaped frame when there are no candidate sites.
    _CANDIDATE_SITE_COLUMNS = (
        "Facility_ID",
        "Facility_preferred_name",
        "Similarity",
        "City",
        "Region_State",
        "Country",
        "Zip",
        "GeoZip",
        "Study_info",
        "sort_val",
        "Has_competing_studies",
    )

    def __init__(self):
        # Pretrained language model used to embed free-text protocol fields.
        self.tokenizer = AutoTokenizer.from_pretrained(CsiConfig.LLM_DIRECTORY)
        self.model = AutoModel.from_pretrained(CsiConfig.LLM_DIRECTORY)

        # Similarity index; loaded lazily on first use by ``get_index``.
        self.index = None

        # Scores dataset variant selected by ``include_sdoh()``.
        self.studies_w_scores_dataset_name = (
            CsiConfig.STUDIES_W_SCORES_SDOH_DATASET_NAME
            if include_sdoh()
            else CsiConfig.STUDIES_W_SCORES_DATASET_NAME
        )

    @staticmethod
    def load_similarity_index():
        """Download and deserialize the faiss similarity index.

        Returns:
            The faiss index stored (as a pickled serialized index) under
            ``CsiConfig.SIMILARITY_INDEX_PKL_FILE`` in the similarity-index
            managed folder.
        """
        # Local import: faiss is only needed when the index is actually loaded.
        import faiss

        index_folder = dataiku.Folder(
            CsiConfig.SIMILARITY_INDEX_FOLDER_ID, project_key=dataiku_api.project_key
        )
        with index_folder.get_download_stream(CsiConfig.SIMILARITY_INDEX_PKL_FILE) as f:
            data = f.read()
        # NOTE: pickle.loads can execute arbitrary code; this relies on the
        # managed folder containing only trusted, internally produced files.
        return faiss.deserialize_index(pickle.loads(data))

    def get_index(self):
        """Return the similarity index, loading and caching it on first use."""
        if self.index is None:
            self.index = StudyNeighbors.load_similarity_index()
        return self.index

    def get_study_neighbors_ids_df(
        self, data: Union[NovelStudy, str]
    ) -> Optional[pd.DataFrame]:
        """Find the nearest neighbours of ``data``.

        Args:
            data: An NCT id (``str``) of an indexed study, or a ``NovelStudy``.

        Returns:
            A frame with ``neighbor_id``, ``Similarity`` and ``input_id``
            columns, or ``None`` when ``data`` is an NCT id not in the index.
        """
        if isinstance(data, str):
            return self.get_study_neighbors_ids_from_study(data)
        return self.get_study_neighbors_ids_from_protocol(data)

    @staticmethod
    def _build_nctid_filter(neighbors_ids_df: pd.DataFrame):
        """Build a Dataiku filter keeping rows whose NCTId is in ``Study_id``."""
        neighbors_ids = list(pd.unique(neighbors_ids_df["Study_id"]))
        filter_formula = DataikuFormula()
        filter_formula.filter_column_by_values("NCTId", neighbors_ids)
        return filter_formula.execute()

    def get_nearest_studies_and_sites_w_scores(self, data: Union[NovelStudy, str]):
        """Return the studies nearest to ``data`` with scores, and candidate sites.

        Args:
            data: An existing study's NCT id, or a ``NovelStudy`` protocol.

        Returns:
            ``{"studies": [SimilarStudy, ...], "sites": [CandidateSite, ...]}``.
            Both lists are empty when ``data`` is an NCT id that is not in the
            similarity index.
        """
        neighbors_ids_df = self.get_study_neighbors_ids_df(data)
        if neighbors_ids_df is None:
            # Unknown NCT id: nothing to search from.
            return {"studies": [], "sites": []}

        # Drop the query id by name rather than by column position.
        neighbors_ids_df = neighbors_ids_df.rename(
            columns={"neighbor_id": "Study_id"}
        ).drop(columns="input_id")

        filter_expression = StudyNeighbors._build_nctid_filter(neighbors_ids_df)
        filtered_scores_dataset = filter_dataset(
            self.studies_w_scores_dataset_name, filters=filter_expression
        ).rename(columns={"NCTId": "Study_id"})

        merged_studies = neighbors_ids_df.merge(
            filtered_scores_dataset, on="Study_id"
        ).sort_values(by="Similarity", ascending=False)

        # 1-based rank by descending Similarity; rank() averages ties and
        # astype(int) truncates, so no missing values remain to format.
        ranks = (
            merged_studies["Similarity"]
            .rank(ascending=False, na_option="bottom")
            .astype(int)
        )
        merged_studies["Rank"] = ranks.map(lambda rank: f"RANK {rank}")
        merged_studies = merged_studies.fillna("NA")

        sites_results = df_to_dict(
            self.get_candidate_sites_data(
                neighbors_ids_df=neighbors_ids_df,
                nctid=data if isinstance(data, str) else None,
            ),
            keep_first=False,
        )
        studies_results = df_to_dict(merged_studies, keep_first=False)

        return {
            "studies": [SimilarStudy(**result) for result in studies_results],
            "sites": [CandidateSite(**result) for result in sites_results],
        }

    def get_study_neighbors_ids_from_protocol(self, novel_study: NovelStudy):
        """Embed a novel protocol and return its nearest indexed studies."""
        input_study_protocol_df = StudyNeighbors.get_input_protocol_df(
            novel_study=novel_study
        )
        input_embeded_array = self.get_normalized_protocol_array(
            input_study_protocol_df
        )
        return self.search_nearest_studies(input_embeded_array, SELF_DEFINED_NTCID)

    def get_study_neighbors_ids_from_study(self, ntcid: str):
        """Return nearest studies for an indexed study, or ``None`` if unknown."""
        query_array = self.get_existing_study_embedded_array(ntcid=ntcid)
        if query_array is None:
            return None
        return self.search_nearest_studies(query_array, ntcid=ntcid)

    def get_existing_study_embedded_array(self, ntcid: str):
        """Reconstruct an indexed study's vector as a ``(1, d)`` array.

        Returns ``None`` when ``ntcid`` is not in ``study_similarity.nctids``.
        """
        try:
            study_index = study_similarity.nctids.index(ntcid)
        except ValueError:
            return None
        return np.expand_dims(self.get_index().reconstruct(study_index), axis=0)

    def search_nearest_studies(self, query_array: npt.NDArray, ntcid: str):
        """Search the index and return neighbours other than ``ntcid`` itself.

        Args:
            query_array: Query vector(s); only the first query row's results
                are used.
            ntcid: Id of the query study, used to drop self-matches.

        Returns:
            A frame with ``neighbor_id``, ``Similarity`` (the faiss search
            score) and ``input_id`` columns.
        """
        distance, knearest_index = self.get_index().search(query_array, k_neighbors())
        labels = knearest_index[0]
        # faiss pads with label -1 when fewer than k results exist; without
        # this mask, -1 would index the *last* study in nctids.
        valid = labels >= 0
        nearest_study_ids = [study_similarity.nctids[i] for i in labels[valid]]

        nearest_studies = pd.DataFrame(
            {"neighbor_id": nearest_study_ids, "Similarity": distance[0][valid]}
        )
        nearest_studies["input_id"] = ntcid
        return nearest_studies.query("neighbor_id!=input_id")

    @staticmethod
    def get_input_protocol_df(novel_study: NovelStudy):
        """Build a one-row frame of protocol fields used for embedding.

        ``age_group_label`` is the child/adult/older_adult flags joined by
        "-" (e.g. "1-0-1"); ``mesh_conditions`` is the conditions joined by
        ". ".
        """
        log_time = datetime.now()
        age_vars = StudyNeighbors.get_age_vars(novel_study["cohortAge"])
        age_label = "-".join(
            [age_vars["child"], age_vars["adult"], age_vars["older_adult"]]
        )
        inclusion = novel_study["inclusionCriteria"].strip()
        exclusion = novel_study["exclusionCriteria"].strip()
        mesh_conditions = ". ".join(novel_study["meshConditions"])
        return pd.DataFrame(
            {
                "log_time": [log_time],
                "title": [novel_study["title"]],
                "brief_summary": [novel_study["briefSummary"]],
                "age_group_label": [age_label],
                "healthy_volunteers": [novel_study["healthyVolunteers"]],
                "sex": [novel_study["cohortSex"]],
                "inclusion_criteria1": [inclusion],
                "exclusion_criteria1": [exclusion],
                "mesh_conditions": [mesh_conditions],
            }
        )

    def get_normalized_protocol_array(self, input_protocol_df: pd.DataFrame):
        """Embed the protocol frame and return a C-contiguous float32 array.

        This is the form passed to the faiss index search.
        """
        embeded_features_tensor = self.embed_features(input_protocol_df)
        numpy_array = embeded_features_tensor.cpu().detach().numpy()
        return np.ascontiguousarray(numpy_array, dtype=np.float32)

    def embed_features(self, input_protocol_df: pd.DataFrame):
        """Concatenate per-column embeddings and L2-normalize each row.

        Free-text columns are embedded with the language model; the
        ``mesh_conditions`` vector is scaled by 2 to weight it more heavily in
        the final normalized vector. Categorical columns use seeded
        embeddings. ``title`` and ``log_time`` are not embedded.
        """
        unstructured_cols = [
            "brief_summary", "inclusion_criteria1", "exclusion_criteria1", "mesh_conditions"]
        categorical_cols = ["age_group_label", "sex", "healthy_volunteers"]

        embeddings = []
        for column in unstructured_cols:
            embedded_vector = self.embed_text(input_protocol_df, column)
            if column == "mesh_conditions":
                embedded_vector = embedded_vector * 2
            embeddings.append(embedded_vector)

        for column in categorical_cols:
            embeddings.append(
                self.embed_categorical_variable(input_protocol_df, column)
            )

        concat_embedding = torch.cat(embeddings, dim=1)
        return self.normalize_vector(concat_embedding)

    def embed_text(self, input_protocol_df: pd.DataFrame, column: str):
        """Embed a text column: one L2-normalized vector per input row.

        Missing values become "NA"; text is truncated to 128 tokens.
        """
        batch_size = 32

        tokenized_texts = self.tokenizer(
            list(input_protocol_df[column].fillna("NA")),
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=128)

        column_embeddings = []
        num_rows = len(tokenized_texts["input_ids"])
        num_batches = (num_rows + batch_size - 1) // batch_size
        for i in tqdm(range(num_batches), desc=f"Embedding {column}"):
            start_idx = i * batch_size
            end_idx = min((i + 1) * batch_size, num_rows)
            batch_tokenized_texts = {
                key: value[start_idx:end_idx] for key, value in tokenized_texts.items()}
            with torch.no_grad():
                outputs = self.model(**batch_tokenized_texts)
                # Mean pooling over token positions. NOTE(review): the
                # attention mask is not applied, so padding positions count
                # when a batch mixes lengths; kept as-is because the stored
                # index vectors were presumably built the same way — confirm
                # before changing.
                batch_embeddings = outputs.last_hidden_state.mean(dim=1)
            column_embeddings.append(self.normalize_vector(batch_embeddings))

        return torch.cat(column_embeddings, dim=0)

    def load_label_encoder(self, column: str):
        """Download the pickled label encoder for a categorical ``column``.

        Raises:
            KeyError: If ``column`` has no configured encoder.
        """
        encoder_dict = {
            "age_group_label": CsiConfig.AGE_LABEL_ENCODER_PK_FILE,
            "healthy_volunteers": CsiConfig.HEALTHY_VOLUNTEERS_LABEL_ENCODER_PK_FILE,
            "sex": CsiConfig.SEX_LABEL_ENCODER_PK_FILE
        }
        encoder_dir = encoder_dict[column]
        index_folder = dataiku.Folder(
            CsiConfig.SIMILARITY_INDEX_FOLDER_ID, project_key=dataiku_api.project_key
        )
        with index_folder.get_download_stream(encoder_dir) as f:
            data = f.read()
        # NOTE: pickle.loads on managed-folder content; trusted artifacts only.
        return pickle.loads(data)

    def embed_categorical_variable(self, input_protocol_df: pd.DataFrame, column: str):
        """Embed a categorical column with a seeded random embedding table.

        The table is ``num_categories x num_categories``. Its weights come from
        torch's RNG seeded with 42, so they are identical on every call. They
        presumably must match the weights used when the index was built;
        confirm before changing the seed or construction order. Rows are
        L2-normalized.
        """
        categorical_variable = input_protocol_df[column].astype(str).values

        label_encoder = self.load_label_encoder(column)
        tensor_encoded = torch.tensor(label_encoder.transform(categorical_variable))

        num_categories = len(label_encoder.classes_)
        # fork_rng restores the CPU RNG state on exit, so seeding here no
        # longer resets randomness for the rest of the process; the weight
        # draw (seed 42, then nn.Embedding init) is unchanged.
        with torch.random.fork_rng(devices=[]):
            torch.manual_seed(42)
            embedding_layer = nn.Embedding(num_categories, num_categories)

        with torch.no_grad():
            embedded_data = embedding_layer(tensor_encoded)
        return nn.functional.normalize(embedded_data, p=2, dim=-1)

    def normalize_vector(self, vector: torch.Tensor):
        """L2-normalize along the last dimension.

        Same result as dividing by the norm, but zero vectors stay zero
        instead of becoming NaN.
        """
        return nn.functional.normalize(vector, p=2, dim=-1)

    def normalise_feature_array(self, data: npt.NDArray, cols: List[str]):
        """L2-normalize the selected feature columns of ``data`` row-wise.

        NOTE(review): reads ``self.features`` and
        ``self.meshterm_conditions_cols``, which this class never sets, so
        calling it raises AttributeError unless a caller assigns them first.
        """
        index = [self.features.index(item) for item in cols]
        if cols == self.meshterm_conditions_cols:
            normalized_array = (
                data[:, index]
                / np.linalg.norm(data[:, index], axis=1, keepdims=True)
                * 2
            )
        else:
            normalized_array = data[:, index] / np.linalg.norm(
                data[:, index], axis=1, keepdims=True
            )
        return normalized_array

    @staticmethod
    def get_age_vars(cohort_ages: List[str]) -> AgeVariables:
        """Map cohort age labels (case/whitespace-insensitive) to "1"/"0" flags."""
        cohort_ages = [cohort_age.lower().strip() for cohort_age in cohort_ages]
        return AgeVariables(
            child="1" if CohortAge.CHILD.value.lower() in cohort_ages else "0",
            adult="1" if CohortAge.ADULT.value.lower() in cohort_ages else "0",
            older_adult="1"
            if CohortAge.OLDER_ADULT.value.lower() in cohort_ages
            else "0")

    def get_candidate_sites_data(
        self, neighbors_ids_df: pd.DataFrame, nctid: Optional[str]
    ):
        """Aggregate the sites of neighbouring studies per facility.

        Args:
            neighbors_ids_df: Frame with ``Study_id`` and ``Similarity`` columns.
            nctid: When given, facilities already used by this study are
                excluded.

        Returns:
            One row per facility, sorted by the sum of its studies'
            similarities (descending). An empty frame with the same columns is
            returned when no sites remain.
        """
        filter_expression = StudyNeighbors._build_nctid_filter(neighbors_ids_df)
        studies_w_sites_df = filter_dataset(
            CsiConfig.STUDIES_W_SITES_JOINED_DATASET_NAME, filters=filter_expression
        ).rename(columns=StudyNeighbors._SITE_COLUMN_RENAMES)

        sites_df = neighbors_ids_df.merge(studies_w_sites_df, on="Study_id").fillna("NA")

        if nctid:
            study_facilities = StudyNeighbors.get_study_facilities(nctid=nctid)
            sites_df = sites_df[~sites_df.Facility_ID.isin(study_facilities)]

        # .copy() so the column assignment below writes to an owned frame.
        sites_df = sites_df.drop_duplicates(
            subset=[
                "Facility_ID",
                "Facility_preferred_name",
                "Study_id",
                "Study_status",
                "Site_status",
            ]
        ).copy()

        if sites_df.empty:
            # DataFrame.apply(axis=1) on an empty frame returns a DataFrame,
            # which cannot be assigned to one column; return early instead.
            return pd.DataFrame(columns=list(StudyNeighbors._CANDIDATE_SITE_COLUMNS))

        sites_df["Study_info"] = [
            [study_id, study_status, site_status]
            for study_id, study_status, site_status in zip(
                sites_df["Study_id"], sites_df["Study_status"], sites_df["Site_status"]
            )
        ]

        sites_df = (
            sites_df.groupby(["Facility_ID", "Facility_preferred_name"])
            .agg(
                Similarity=("Similarity", lambda x: list(x)),
                City=("City", "first"),
                Region_State=("Region/State", "first"),
                Country=("Country", "first"),
                Zip=("Zip", "first"),
                GeoZip=("GeoZip", "first"),
                Study_info=("Study_info", lambda x: list(x)),
            )
            .reset_index()
        )

        sites_df["sort_val"] = sites_df["Similarity"].apply(lambda x: sum(x))
        # Flag facilities where any neighbouring study is (not yet) recruiting.
        sites_df["Has_competing_studies"] = sites_df["Study_info"].apply(
            lambda study_info_list: StudyNeighbors.has_intersection(
                [status for _, status, _ in study_info_list]
            )
        )
        return sites_df.sort_values(by="sort_val", ascending=False)

    @staticmethod
    def get_study_facilities(nctid: str):
        """Return the unique facility ids associated with study ``nctid``."""
        facilities_df = filter_dataset_by_nctid(
            CsiConfig.FACILITIES_W_PREFERRED_NAME_DATASET_NAME, nctid=nctid
        )
        return list(pd.unique(facilities_df["Facility_ID"]))

    @staticmethod
    def has_intersection(status_list: List[str]):
        """Return True if any status is "Recruiting" or "Not Yet Recruiting".

        Both sides are compared after ``transform_string``.
        """
        target_status = (
            transform_string("Recruiting"),
            transform_string("Not Yet Recruiting"),
        )
        return any(transform_string(status) in target_status for status in status_list)

# Module-level shared instance. Constructing it loads the tokenizer and language
# model at import time; the faiss index itself is loaded lazily on first search.
study_neighbors: StudyNeighbors = StudyNeighbors()
