import ast
import json
import logging
from urllib.parse import quote

import dataiku
import requests
from dataiku.llm.agent_tools import BaseAgentTool


# Overview fields that are always requested for a study.
_BASE_FIELDS = (
    "NCTId", "BriefTitle", "Organization", "OverallStatus", "Phase", "Condition",
    "InterventionType", "InterventionName",
    "StartDateStruct", "PrimaryCompletionDateStruct", "CompletionDateStruct", "LastUpdatePostDateStruct",
    "LeadSponsorName", "CollaboratorName",
    "BriefSummary",
)

# Optional field groups a caller can ask for by name, mapped to API field names.
_ADVANCED_FIELDS = {
    "studyDetails": ("DetailedDescription",),
    "patientEligibility": (
        "EnrollmentInfo", "EligibilityCriteria", "HealthyVolunteers", "Sex",
        "MinimumAge", "MaximumAge",
    ),
    "trialDesign": (
        "DesignPrimaryPurpose", "DesignAllocation", "DesignInterventionModel",
        "DesignMasking", "ArmGroup", "NumArmGroups",
    ),
    "contactsLocations": (
        "CentralContact", "OverallOfficial", "Location",
    ),
}


def fetch_study_data(nctid: str, additional_fields: "list[str] | None" = None, timeout: float = 30.0):
    """Fetch a single study record from the ClinicalTrials.gov v2 API.

    Args:
        nctid: Study identifier, e.g. ``"NCT04848928"``. Surrounding whitespace
            is ignored. The value is URL-encoded so it always addresses exactly
            one study and cannot inject extra path segments or query parameters.
        additional_fields: Optional names of extra field groups to request on
            top of the overview fields: ``"studyDetails"``,
            ``"patientEligibility"``, ``"trialDesign"``, ``"contactsLocations"``.
            Unknown names are logged and ignored. Anything other than a list or
            tuple is treated as "no extra fields".
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The decoded JSON response on HTTP 200. Otherwise ``None``: empty ID,
        non-200 status, or network failure. Details are logged.
    """
    logger = logging.getLogger(__name__)

    study_id = nctid.strip() if isinstance(nctid, str) else ""
    if not study_id:
        # An empty ID would request ".../studies/", which is the *search*
        # endpoint, and would silently return unrelated studies.
        logger.error("fetch_study_data called without a study ID.")
        return None

    if not isinstance(additional_fields, (list, tuple)):
        additional_fields = []

    requested = list(_BASE_FIELDS)
    for group in additional_fields:
        # The isinstance guard keeps unhashable items (e.g. dicts) away from .get().
        group_fields = _ADVANCED_FIELDS.get(group) if isinstance(group, str) else None
        if group_fields is None:
            logger.warning("Ignoring unknown additional field group: %r", group)
            continue
        requested.extend(group_fields)

    # dict.fromkeys drops duplicates (a group may be listed twice) but keeps order.
    params = {
        "format": "json",
        "fields": "|".join(dict.fromkeys(requested)),
    }
    encoded_id = quote(study_id, safe="")
    url = f"https://clinicaltrials.gov/api/v2/studies/{encoded_id}"

    try:
        response = requests.get(url, params=params, timeout=timeout)
    except requests.RequestException:
        logger.exception("Request to ClinicalTrials.gov failed for study %s", study_id)
        return None

    if response.status_code == 200:
        return response.json()
    logger.error("Error: %s, %s", response.status_code, response.text)
    return None


class CTgovAPItool(BaseAgentTool):
    """Dataiku agent tool that looks up one study on ClinicalTrials.gov by NCT ID.

    ``get_descriptor`` advertises the tool and its input schema to the agent.
    ``invoke`` validates and normalizes the agent's arguments, delegates the
    HTTP call to ``fetch_study_data``, and returns the result as JSON text.
    """

    # Field-group names advertised to the LLM. They must match the group keys
    # that fetch_study_data maps to API field names.
    ALLOWED_FIELD_GROUPS = ("studyDetails", "patientEligibility", "trialDesign", "contactsLocations")

    def set_config(self, config, plugin_config):
        """Store the tool and plugin configuration and set up the logger."""
        self.logger = logging.getLogger(__name__)
        # Configure the level once here rather than on every invoke() call.
        self.logger.setLevel(logging.INFO)  # Use INFO or DEBUG as needed
        self.config = config
        self.plugin_config = plugin_config

    def get_descriptor(self, tool):
        """
        Returns the descriptor of the tool, as a dict containing:
           - description (str)
           - inputSchema (dict, a JSON Schema representation)
        """
        allowed = ", ".join(f"'{name}'" for name in self.ALLOWED_FIELD_GROUPS)
        return {
            "description": (
                "Query the clinicaltrials.gov API by a study ID (NCTId) to retrieve the study overview "
                "(title, sponsors, status, phase, conditions, treatments, dates, and brief summary). "
                "Add additional fields to retrieve more study information."
            ),
            "inputSchema": {
                "$id": "ctgov_query_schema",
                "title": "ClinicalTrials.gov API Query",
                "type": "object",
                "properties": {
                    "nctid": {
                        "type": "string",
                        "description": "A study ID (e.g., 'NCT04848928')",
                    },
                    "additional_fields": {
                        "type": "array",
                        "description": f"A list of additional fields for query. Allowed fields: {allowed}",
                        # The enum constrains the LLM to group names the fetcher understands.
                        "items": {
                            "type": "string",
                            "enum": list(self.ALLOWED_FIELD_GROUPS),
                        },
                    },
                },
                "required": ["nctid"],
            },
        }

    def invoke(self, input, trace):
        """
        Invokes the tool.

        The arguments of the tool invocation are in input["input"], a dict with
        a required "nctid" and an optional "additional_fields" list (see
        get_descriptor).

        Returns a dict with "output" and "sources". On success "output" holds
        the study record serialized as JSON. On failure it holds a JSON object
        with an "error" key, so the agent can recover instead of crashing.
        """
        self.logger.info("Received input: %s", input)

        args = input.get("input") or {}
        nctid = args.get("nctid")
        if not isinstance(nctid, str) or not nctid.strip():
            self.logger.error("Missing or empty 'nctid' in tool input.")
            return self._error_result("A study ID ('nctid') is required, e.g. 'NCT04848928'.")
        nctid = nctid.strip()

        additional_fields = args.get("additional_fields", [])
        if isinstance(additional_fields, str):
            self.logger.warning("Received additional_fields as a string: %s. Attempting to parse.", additional_fields)
            additional_fields = self._parse_field_list(additional_fields)
            if not additional_fields:
                self.logger.error("Failed to parse additional_fields string. Proceeding without them.")

        query_results = fetch_study_data(nctid, additional_fields)
        if query_results is None:
            # fetch_study_data returns None when the lookup fails. Report the
            # failure instead of passing None on (len(None) would raise here).
            self.logger.error("No data returned for study %s.", nctid)
            return self._error_result(
                f"Could not retrieve study '{nctid}' from ClinicalTrials.gov "
                "(the ID may not exist, or the service returned an error)."
            )

        # The endpoint returns one study record, so log the ID. len() would
        # count the record's top-level keys, not studies.
        self.logger.info("Successfully fetched data for study %s.", nctid)
        return {
            "output": json.dumps(query_results),
            "sources": [],
        }

    @staticmethod
    def _error_result(message):
        """Build a tool result that reports *message* to the agent as a JSON error object."""
        return {
            "output": json.dumps({"error": message}),
            "sources": [],
        }

    @staticmethod
    def _parse_field_list(raw):
        """Coerce a stringified list of field-group names into a list of str.

        Handles forms such as '["studyDetails", "trialDesign"]',
        "['trialDesign']", "studyDetails" and "studyDetails, trialDesign".
        Returns an empty list if nothing usable is found.
        """
        # Stray slashes/backslashes (e.g. escaped quotes) never occur in valid
        # group names, so strip them before parsing.
        cleaned = raw.replace("/", "").replace("\\", "").strip()
        if not cleaned:
            return []
        try:
            # literal_eval evaluates only Python literals, never arbitrary code.
            parsed = ast.literal_eval(cleaned)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            # Not a literal: treat it as a comma-separated list of bare names.
            parsed = [part.strip(" '\"[]") for part in cleaned.split(",")]
        if isinstance(parsed, str):
            parsed = [parsed]
        if not isinstance(parsed, (list, tuple)):
            return []
        return [item for item in parsed if isinstance(item, str) and item]
