from langchain_core.tools import tool
from langchain_core.messages import HumanMessage, ToolMessage
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain import hub
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder

# Dataiku Agent Imports
from dataiku.langchain.dku_llm import DKUChatLLM
from dataiku.llm.python import BaseLLM, GenericLangChainAgentWrapper
import dataiku

import io
import pickle
import re
from typing import Union
from urllib import parse

import pandas as pd
import requests

from dku_utils.projects.project_commons import get_current_project_and_variables

# Resolve the current Dataiku project and its variables once, at import time.
# The LLM id used by the agent is read from the "standard" variables.
project, variables = get_current_project_and_variables()
llm_connection = variables["standard"]["light_llm"]


# Open Payments research-payment dataset identifiers, one per program year
# (2018 through 2024, judging by the variable suffixes).
research_dataset18 = 'e947198c-7106-5af5-b932-1a037bfffdea'
research_dataset19 = '2a708ddc-cc05-546e-89e4-2953cb905103'
research_dataset20 = 'b6cadb27-a130-5b5f-a71f-a895a98fc6c0'
research_dataset21 = '28db09b7-30fa-5b7c-b3b3-8b6f563de10e'
research_dataset22 = '38413cb6-ae5e-5c7e-8684-38243340bb6c'
research_dataset23 = '67285c54-b340-552f-9a53-b26a8c00b4d3'
research_dataset24 = '9d522605-3f74-54d2-8f08-a8c07e8045cf'

# Iterated in this order (oldest first) by the payment query functions.
OPEN_PAYMENT_RESEARCH_DATASETS = [
    research_dataset18,
    research_dataset19,
    research_dataset20,
    research_dataset21,
    research_dataset22,
    research_dataset23,
    research_dataset24,
]
# Columns kept from every Open Payments research-payment record, in output order.
OP_COLUMNS = [
    # Record and recipient
    'Change_Type',
    'Covered_Recipient_Type',
    'Noncovered_Recipient_Entity_Name',
    'Teaching_Hospital_CCN',
    'Teaching_Hospital_ID',
    'Teaching_Hospital_Name',
    'Recipient_Primary_Business_Street_Address_Line1',
    'Recipient_Primary_Business_Street_Address_Line2',
    'Recipient_City',
    'Recipient_State',
    'Recipient_Zip_Code',
    'Recipient_Country',
    'Recipient_Province',
    'Recipient_Postal_Code',
    # Principal investigator 1
    'Principal_Investigator_1_Covered_Recipient_Type',
    'Principal_Investigator_1_Profile_ID',
    'Principal_Investigator_1_NPI',
    'Principal_Investigator_1_First_Name',
    'Principal_Investigator_1_Middle_Name',
    'Principal_Investigator_1_Last_Name',
    'Principal_Investigator_1_Name_Suffix',
    'Principal_Investigator_1_Business_Street_Address_Line1',
    'Principal_Investigator_1_Business_Street_Address_Line2',
    'Principal_Investigator_1_City',
    'Principal_Investigator_1_State',
    'Principal_Investigator_1_Zip_Code',
    'Principal_Investigator_1_Country',
    'Principal_Investigator_1_Province',
    'Principal_Investigator_1_Postal_Code',
    'Principal_Investigator_1_Primary_Type_1',
    'Principal_Investigator_1_Specialty_1',
    # Paying manufacturer / GPO
    'Submitting_Applicable_Manufacturer_or_Applicable_GPO_Name',
    'Applicable_Manufacturer_or_Applicable_GPO_Making_Payment_ID',
    'Applicable_Manufacturer_or_Applicable_GPO_Making_Payment_Name',
    'Applicable_Manufacturer_or_Applicable_GPO_Making_Payment_State',
    'Applicable_Manufacturer_or_Applicable_GPO_Making_Payment_Country',
    # Related product
    'Related_Product_Indicator',
    'Covered_or_Noncovered_Indicator_1',
    'Indicate_Drug_or_Biological_or_Device_or_Medical_Supply_1',
    'Product_Category_or_Therapeutic_Area_1',
    'Name_of_Drug_or_Biological_or_Device_or_Medical_Supply_1',
    'Associated_Drug_or_Biological_NDC_1',
    'Associated_Device_or_Medical_Supply_PDI_1',
    # Payment
    'Total_Amount_of_Payment_USDollars',
    'Date_of_Payment',
    'Form_of_Payment_or_Transfer_of_Value',
    'Record_ID',
    'Program_Year',
    'Payment_Publication_Date',
    'ClinicalTrials_Gov_Identifier',
]


# Pickled dict caches (lookup key -> query result string), persisted in a Dataiku managed folder.
NCTID_CACHE_FILE, NPI_CACHE_FILE = "nctid_cache.pkl", "npi_cache.pkl"
cache_folder = dataiku.Folder("F4z6F5y2")

# load cache from folder
def load_pickle_from_folder(folder, filename):
    if filename in folder.list_paths_in_partition():
        with folder.get_download_stream(filename) as f:
            data = f.read()
            return pickle.loads(data)
    return {}

# Warm both in-memory caches from the managed folder at import time (NCT id cache first).
nctid_cache, npi_cache = (
    load_pickle_from_folder(cache_folder, cache_file)
    for cache_file in (NCTID_CACHE_FILE, NPI_CACHE_FILE)
)



def query_payments_by_nctid(nctid: str) -> str:
    """Return the research payments for a given study made by sponsors to healthcare providers between 2018 and 2024.

    Results are memoized in ``nctid_cache`` and persisted to ``NCTID_CACHE_FILE`` in
    ``cache_folder``. A result is cached only when every research dataset was fetched
    and parsed successfully. A transient API or parsing failure is therefore retried
    on the next call instead of being remembered forever as "no records".

    Args:
        nctid: Unique Protocol Identification Number. A unique identifier assigned to the study protocol.

    Returns:
        JSON text (``orient="records"``) with one object per payment record,
        restricted to ``OP_COLUMNS``. When nothing could be retrieved, a message
        starting with "No payment records" is returned instead.
    """
    if nctid in nctid_cache:
        print(f"Fetch query result from cache for study {nctid}")
        return nctid_cache[nctid]

    total_dfs = []
    complete = True  # set to False when any dataset fails to fetch or parse
    for dataset in OPEN_PAYMENT_RESEARCH_DATASETS:
        query_by_study = payment_by_nctid_syntax(nctid, dataset)
        try:
            data = fetch_data(query_by_study)
        except (OSError, ValueError) as e:
            # requests.RequestException derives from OSError; ValueError covers a non-JSON body.
            print(f"Request error for dataset {dataset}: {e}")
            complete = False
            continue
        print(f'Table {dataset} queried')
        if data is None:
            # fetch_data returns None on a non-200 response: a failure, not an empty result.
            complete = False
            continue
        if not data:
            continue
        try:
            df = pd.DataFrame(data)[OP_COLUMNS]
            df['Total_Amount_of_Payment_USDollars'] = df['Total_Amount_of_Payment_USDollars'].astype('float')
            df['Date_of_Payment'] = pd.to_datetime(df['Date_of_Payment'], format='%m/%d/%Y')
            total_dfs.append(df)
            print(f'Query saved from {dataset}')
        except Exception as e:
            complete = False
            print(f"Data parsing error in dataset {dataset}: {e}")

    if total_dfs:
        result = pd.concat(total_dfs, ignore_index=True).to_json(orient="records")
    elif complete:
        result = f"No payment records for study {nctid} between 2018 to 2024."
    else:
        # Keep the "No payment records" prefix: the aggregation tool short-circuits on it.
        result = (
            f"No payment records retrieved for study {nctid}: "
            "some Open Payments datasets could not be queried. Please retry later."
        )

    if complete:
        nctid_cache[nctid] = result
        # Persist the whole cache dict so later sessions can reuse it.
        with cache_folder.get_writer(NCTID_CACHE_FILE) as w:
            w.write(pickle.dumps(nctid_cache))
    else:
        print(f"Incomplete results for study {nctid}; not cached")
    return result
 
    
@tool
def aggregate_payments_by_nctid(
    nctid: str,
    groupby_columns: list[str],
    aggregation_dict: dict[str, Union[str, tuple[str, str], list[str]]],
    sort_by: str = None,
    ascending: bool = False
) -> str:
    """Dynamically aggregate payment records for a given study ID based on grouping and aggregation specs.

    Args:
        nctid: Unique Protocol Identification Number. A unique identifier assigned to the study protocol.
        groupby_columns: List of columns to group by. Records with a missing value in these columns form their own group.
        aggregation_dict: Dictionary mapping each output column name to a two-item list [source_column, aggregation_function]
        sort_by: Optional column to sort by after aggregation
        ascending: Sort order

    Returns:
        Aggregated data in JSON format (dates as ISO-8601 strings) or error message.

    Examples:
    groupby_columns = [
        'Principal_Investigator_1_Profile_ID',
        'Principal_Investigator_1_NPI',
        'Form_of_Payment_or_Transfer_of_Value'
    ]
    aggregation_dict = {
        'PI_First_Name': ['Principal_Investigator_1_First_Name', 'first'],
        'PI_Middle_Name': ['Principal_Investigator_1_Middle_Name', 'first'],
        'PI_Last_Name': ['Principal_Investigator_1_Last_Name', 'first'],
        'PI_Specialty': ['Principal_Investigator_1_Specialty_1', 'last'],
        'PI_Street_1': ['Principal_Investigator_1_Business_Street_Address_Line1', 'last'],
        'PI_Street_2': ['Principal_Investigator_1_Business_Street_Address_Line2', 'last'],
        'PI_City': ['Principal_Investigator_1_City', 'last'],
        'PI_State': ['Principal_Investigator_1_State', 'last'],
        'PI_Zip': ['Principal_Investigator_1_Zip_Code', 'last'],
        'PI_Country': ['Principal_Investigator_1_Country', 'last'],
        'PI_Primary_Type': ['Principal_Investigator_1_Primary_Type_1', 'last'],
        'Total_Payment_Sum': ['Total_Amount_of_Payment_USDollars', 'sum'],
        'Record_Count': ['Record_ID', 'count'],
        'Start_Date': ['Date_of_Payment', 'first'],
        'End_Date': ['Date_of_Payment', 'last'],
        'Recipient_Name': ['Noncovered_Recipient_Entity_Name', 'last'],
        'Teaching_Hospital': ['Teaching_Hospital_Name', 'last'],
        'Recipient_Street_1': ['Recipient_Primary_Business_Street_Address_Line1', 'last'],
        'Recipient_Street_2': ['Recipient_Primary_Business_Street_Address_Line2', 'last'],
        'Recipient_City': ['Recipient_City', 'last'],
        'Recipient_State': ['Recipient_State', 'last'],
        'Recipient_Zip': ['Recipient_Zip_Code', 'last'],
        'Recipient_Country': ['Recipient_Country', 'last']
    }
    *aggregation functions: standard pandas names ('first', 'last', 'sum', 'count', 'min', 'max', 'mean', ...), plus:
    - 'nunique': return distinct value count
    - 'unique_list': return the list of distinct non-null values
    """

    result = query_payments_by_nctid(nctid)
    if result.startswith("No payment records"):
        return result

    try:
        # StringIO: passing literal JSON text to read_json is deprecated (pandas >= 2.1).
        # dtype=False / convert_dates=False: keep NPI, ZIP and ID columns as the strings they
        # were stored as; read_json's type inference would turn "02115" into 2115.
        df = pd.read_json(io.StringIO(result), orient="records", dtype=False, convert_dates=False)
        payment_dates = df["Date_of_Payment"]
        if pd.api.types.is_numeric_dtype(payment_dates):
            # DataFrame.to_json's default 'epoch' format stores datetimes as epoch milliseconds.
            df["Date_of_Payment"] = pd.to_datetime(payment_dates, unit="ms")
        else:
            df["Date_of_Payment"] = pd.to_datetime(payment_dates)
    except (ValueError, TypeError, KeyError) as e:
        return f"Failed to parse payment data for study {nctid}: {e}"

    try:
        # Normalise [column, function] specs and map the custom aggregation names.
        cleaned_aggregation_dict = build_agg_dict(aggregation_dict)
        agg_df = (
            df
            # Stable sort so 'first'/'last' are deterministic among same-day payments.
            .sort_values(by=["Date_of_Payment"], kind="mergesort")
            # dropna=False: records with a missing group key still count towards the totals.
            .groupby(groupby_columns, dropna=False)
            .agg(**cleaned_aggregation_dict)
            .reset_index()
        )

        if sort_by:
            agg_df = agg_df.sort_values(by=sort_by, ascending=ascending, kind="mergesort")

        # ISO dates are readable by the LLM, unlike the default epoch milliseconds.
        return agg_df.to_json(orient="records", date_format="iso")

    except Exception as e:
        return f"Aggregation error: {e}"
    


def query_payments_by_npi(npi: str) -> str:
    """Return the research payments made by sponsors to a given principal investigator between 2018 and 2024.

    Results are memoized in ``npi_cache`` and persisted to ``NPI_CACHE_FILE`` in
    ``cache_folder``. A result is cached only when every research dataset was fetched
    and parsed successfully. A transient API or parsing failure is therefore retried
    on the next call instead of being remembered forever as "no records".

    Args:
        npi: National Provider Identifier, a unique identification number for covered health care providers registered in the US.

    Returns:
        JSON text (``orient="records"``) with one object per payment record,
        restricted to ``OP_COLUMNS``. When nothing could be retrieved, a message
        starting with "No payment records" is returned instead.
    """
    if npi in npi_cache:
        print(f"Fetch query result from cache for NPI {npi}")
        return npi_cache[npi]

    total_dfs = []
    complete = True  # set to False when any dataset fails to fetch or parse
    for dataset in OPEN_PAYMENT_RESEARCH_DATASETS:
        query_by_npi = payment_by_npi_syntax(npi, dataset)
        try:
            data = fetch_data(query_by_npi)
        except (OSError, ValueError) as e:
            # requests.RequestException derives from OSError; ValueError covers a non-JSON body.
            print(f"Request error for dataset {dataset}: {e}")
            complete = False
            continue
        print(f'Table {dataset} queried')
        if data is None:
            # fetch_data returns None on a non-200 response: a failure, not an empty result.
            complete = False
            continue
        if not data:
            continue
        try:
            df = pd.DataFrame(data)[OP_COLUMNS]
            df['Total_Amount_of_Payment_USDollars'] = df['Total_Amount_of_Payment_USDollars'].astype('float')
            df['Date_of_Payment'] = pd.to_datetime(df['Date_of_Payment'], format='%m/%d/%Y')
            total_dfs.append(df)
            print(f'Query saved from {dataset}')
        except Exception as e:
            complete = False
            print(f"Data parsing error in dataset {dataset}: {e}")

    if total_dfs:
        result = pd.concat(total_dfs, ignore_index=True).to_json(orient="records")
    elif complete:
        result = f"No payment records for covered health provider {npi} between 2018 to 2024."
    else:
        # Keep the "No payment records" prefix: the aggregation tool short-circuits on it.
        result = (
            f"No payment records retrieved for covered health provider {npi}: "
            "some Open Payments datasets could not be queried. Please retry later."
        )

    if complete:
        npi_cache[npi] = result
        # Persist the whole cache dict so later sessions can reuse it.
        with cache_folder.get_writer(NPI_CACHE_FILE) as w:
            w.write(pickle.dumps(npi_cache))
    else:
        print(f"Incomplete results for NPI {npi}; not cached")
    return result


@tool
def aggregate_payments_by_npi(
    npi: str,
    groupby_columns: list[str],
    aggregation_dict: dict[str, Union[str, tuple[str, str], list[str]]],
    sort_by: str = None,
    ascending: bool = False
) -> str:
    """Dynamically aggregate payment records for a given principal investigator NPI based on grouping and aggregation specs.

    Args:
        npi: National Provider Identifier, a unique identification number for covered health care providers registered in the US.
        groupby_columns: List of columns to group by. Records with a missing value in these columns form their own group.
        aggregation_dict: Dictionary mapping each output column name to a two-item list [source_column, aggregation_function]
        sort_by: Optional column to sort by after aggregation
        ascending: Sort order

    Returns:
        Aggregated data in JSON format (dates as ISO-8601 strings) or error message.

    Example:
    groupby_columns = [
        'Applicable_Manufacturer_or_Applicable_GPO_Making_Payment_ID',
        'Applicable_Manufacturer_or_Applicable_GPO_Making_Payment_Name',
        'Form_of_Payment_or_Transfer_of_Value'
        ]
    aggregation_dict = {
        "PI_First_Name": ["Principal_Investigator_1_First_Name", "first"],
        "PI_Middle_Name": ["Principal_Investigator_1_Middle_Name", "first"],
        "PI_Last_Name": ["Principal_Investigator_1_Last_Name", "first"],
        "PI_Specialty": ["Principal_Investigator_1_Specialty_1", "last"],
        "PI_Street_1": ["Principal_Investigator_1_Business_Street_Address_Line1", "last"],
        "PI_Street_2": ["Principal_Investigator_1_Business_Street_Address_Line2", "last"],
        "PI_City": ["Principal_Investigator_1_City", "last"],
        "PI_State": ["Principal_Investigator_1_State", "last"],
        "PI_Zip": ["Principal_Investigator_1_Zip_Code", "last"],
        "PI_Country": ["Principal_Investigator_1_Country", "last"],
        "PI_Primary_Type": ["Principal_Investigator_1_Primary_Type_1", "last"],
        "Total_Payment_Sum": ["Total_Amount_of_Payment_USDollars", "sum"],
        "Record_Count": ["Record_ID", "count"],
        "Study_Count": ["ClinicalTrials_Gov_Identifier", "nunique"],  # Custom aggregation
        "Study_Ids": ["ClinicalTrials_Gov_Identifier", "unique_list"],  # Custom aggregation
        "Start_Date": ["Date_of_Payment", "first"],
        "End_Date": ["Date_of_Payment", "last"],
        "Recipient_Name": ["Noncovered_Recipient_Entity_Name", "last"],
        "Teaching_Hospital": ["Teaching_Hospital_Name", "last"],
        "Recipient_Street_1": ["Recipient_Primary_Business_Street_Address_Line1", "last"],
        "Recipient_Street_2": ["Recipient_Primary_Business_Street_Address_Line2", "last"],
        "Recipient_City": ["Recipient_City", "last"],
        "Recipient_State": ["Recipient_State", "last"],
        "Recipient_Zip": ["Recipient_Zip_Code", "last"],
        "Recipient_Country": ["Recipient_Country", "last"]
    }
    *aggregation functions: standard pandas names ('first', 'last', 'sum', 'count', 'min', 'max', 'mean', ...), plus:
    - 'nunique': return distinct value count
    - 'unique_list': return the list of distinct non-null values
    """

    result = query_payments_by_npi(npi)
    if result.startswith("No payment records"):
        return result

    try:
        # StringIO: passing literal JSON text to read_json is deprecated (pandas >= 2.1).
        # dtype=False / convert_dates=False: keep NPI, ZIP and ID columns as the strings they
        # were stored as; read_json's type inference would turn "02115" into 2115.
        df = pd.read_json(io.StringIO(result), orient="records", dtype=False, convert_dates=False)
        payment_dates = df["Date_of_Payment"]
        if pd.api.types.is_numeric_dtype(payment_dates):
            # DataFrame.to_json's default 'epoch' format stores datetimes as epoch milliseconds.
            df["Date_of_Payment"] = pd.to_datetime(payment_dates, unit="ms")
        else:
            df["Date_of_Payment"] = pd.to_datetime(payment_dates)
    except (ValueError, TypeError, KeyError) as e:
        return f"Failed to parse payment data for investigator {npi}: {e}"

    try:
        # Normalise [column, function] specs and map the custom aggregation names.
        cleaned_aggregation_dict = build_agg_dict(aggregation_dict)
        agg_df = (
            df
            # Stable sort so 'first'/'last' are deterministic among same-day payments.
            .sort_values(by=["Date_of_Payment"], kind="mergesort")
            # dropna=False: records with a missing group key still count towards the totals.
            .groupby(groupby_columns, dropna=False)
            .agg(**cleaned_aggregation_dict)
            .reset_index()
        )

        if sort_by:
            agg_df = agg_df.sort_values(by=sort_by, ascending=ascending, kind="mergesort")

        # ISO dates are readable by the LLM, unlike the default epoch milliseconds.
        return agg_df.to_json(orient="records", date_format="iso")

    except Exception as e:
        return f"Aggregation error: {e}"
    
    
    
# Tools exposed to the agent and the Dataiku-hosted chat model that drives it.
tools = [aggregate_payments_by_nctid, aggregate_payments_by_npi]
chat_model = DKUChatLLM(llm_id=llm_connection)

# System instructions, optional prior conversation, the user input and the agent's
# tool-calling scratchpad, in that order.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system","""

        # Role
        As a clinical operations analyst, you are tasked with generating competitiveness analysis on principal investigators involved in recent clinical trial operations.

        # About the available source
        Open Payments collects and publishes information about financial relationships between drug and medical device companies (referred to as "reporting entities") and certain health care providers (referred to as "covered recipients"). You have access to the research payments from 2018 to 2024.

        # Guidelines
        - Break down a complex question into separate tasks for individual tools and then execute these tasks in sequence, if necessary, to get the final answer. For example, if you need detailed information about the investigators participating in a given study but don't have their NPI numbers, break the task down into two steps: first, search for the investigators and their NPI in the study; then, search for more detailed information on each individual investigator.
        - When listing principal investigators involved in a study, always display their NPI number, specialty, location, and affiliations.
        - When analyzing individual principal investigators, group the grant amount by sponsor and list all associated clinical trials (NCTId).
        - NCTId and NPI are critical foreign keys. Always include them in your answers.
        """),
        MessagesPlaceholder("chat_history", optional=True),
        ("human", "{input}"),
        MessagesPlaceholder("agent_scratchpad"),
    ]
)

agent = create_openai_tools_agent(chat_model, tools, prompt=prompt)
# Module-level executor; this is the object handed to MyLLM.
agent_executor = AgentExecutor(agent=agent, tools=tools)


class MyLLM(GenericLangChainAgentWrapper):
    """Wrap the module-level ``agent_executor`` in Dataiku's ``GenericLangChainAgentWrapper``."""

    def __init__(self):
        # The wrapper takes the LangChain AgentExecutor built at import time.
        super().__init__(agent_executor)


def payment_by_nctid_syntax(nctid, dataset):
    """Build the datastore SQL query selecting every record of ``dataset`` for one study.

    The identifier usually comes from the LLM and is interpolated into a
    double-quoted literal. Surrounding whitespace is therefore stripped, and so
    are the characters ``"``, ``[``, ``]`` and ``\\``, which could terminate the
    literal or the bracketed clause and cannot appear in a valid NCT identifier.

    Args:
        nctid: Study identifier (NCT number) to filter on.
        dataset: Datastore identifier of one research-payment dataset.

    Returns:
        The bracketed query string, e.g.
        '[SELECT * FROM <dataset>][WHERE ClinicalTrials_Gov_Identifier = "NCT..."]'.
    """
    safe_nctid = re.sub(r'["\[\]\\]', "", str(nctid)).strip()
    query_by_study = f'[SELECT * FROM {dataset}][WHERE ClinicalTrials_Gov_Identifier = "{safe_nctid}"]'
    return query_by_study

def payment_by_npi_syntax(npi, dataset):
    """Build the datastore SQL query selecting every record of ``dataset`` for one investigator.

    The NPI usually comes from the LLM and is interpolated into a double-quoted
    literal. Surrounding whitespace is therefore stripped, and so are the
    characters ``"``, ``[``, ``]`` and ``\\``, which could terminate the literal or
    the bracketed clause and cannot appear in a numeric NPI.

    Args:
        npi: National Provider Identifier of principal investigator 1 (str or int).
        dataset: Datastore identifier of one research-payment dataset.

    Returns:
        The bracketed query string, e.g.
        '[SELECT * FROM <dataset>][WHERE Principal_Investigator_1_NPI = "<npi>"]'.
    """
    safe_npi = re.sub(r'["\[\]\\]', "", str(npi)).strip()
    query_by_npi = f'[SELECT * FROM {dataset}][WHERE Principal_Investigator_1_NPI = "{safe_npi}"]'
    return query_by_npi


def fetch_data(sql_query, timeout=60):
    """Run a query against the CMS Open Payments datastore SQL endpoint.

    Args:
        sql_query: Query in the bracketed syntax built by the ``payment_by_*_syntax``
            helpers, e.g. '[SELECT * FROM <dataset>][WHERE <column> = "<value>"]'.
        timeout: Seconds ``requests`` waits for the connection and for each read
            before raising ``requests.Timeout`` (default 60). Without a timeout, a
            stalled connection would block the agent indefinitely.

    Returns:
        The decoded JSON body when the server answers HTTP 200. Otherwise the
        status code and body are printed and None is returned.

    Raises:
        requests.RequestException: on connection errors or timeout.
        ValueError: if a 200 response body is not valid JSON.

    NOTE(review): the endpoint may cap the number of rows returned per query;
    confirm whether large studies need paginated queries.
    """
    base_url = "https://openpaymentsdata.cms.gov/api/1/datastore/sql"
    encoded_query = parse.quote(sql_query)
    full_url = f"{base_url}?query={encoded_query}"
    headers = {"accept": "application/json"}

    response = requests.get(full_url, headers=headers, timeout=timeout)

    if response.status_code == 200:
        return response.json()
    print(f"Error: {response.status_code}, {response.text}")
    return None
    
def build_agg_dict(aggregation_dict):
    """Translate an LLM-supplied aggregation spec into pandas named-aggregation arguments.

    Args:
        aggregation_dict: Mapping ``output_column -> [source_column, function]``
            (a list or tuple of exactly two items). ``function`` is passed to pandas
            unchanged, except for these custom names:
            - ``"nunique"``: number of distinct non-null values (``pd.Series.nunique``).
            - ``"unique_list"``: list of distinct non-null values, in order of first appearance.

    Returns:
        Dict ``output_column -> (source_column, function)``, ready for
        ``DataFrameGroupBy.agg(**result)``.

    Raises:
        ValueError: if a spec is not a two-item list or tuple.
    """
    cleaned_aggregation_dict = {}
    for output_col, spec in aggregation_dict.items():
        if not isinstance(spec, (list, tuple)) or len(spec) != 2:
            raise ValueError(f"Invalid aggregation format for {output_col}: {spec}")
        source_col, agg_func = spec

        if agg_func == "nunique":
            cleaned_aggregation_dict[output_col] = (source_col, pd.Series.nunique)
        elif agg_func == "unique_list":
            # dict.fromkeys de-duplicates while keeping first-appearance order, so the output is
            # reproducible across runs (set() order varies with string hash randomisation).
            cleaned_aggregation_dict[output_col] = (source_col, lambda x: list(dict.fromkeys(x.dropna())))
        else:
            cleaned_aggregation_dict[output_col] = (source_col, agg_func)
    return cleaned_aggregation_dict
        