import dataiku
import pandas as pd, numpy as np
from pandas.io.json import json_normalize
import requests
import json
from collections import OrderedDict
from multiprocessing import pool, Lock
from urllib.parse import quote


def get_query_url(base_url, study_status, conditions, interventions, location, lead_sponsor, fields, essie_expression_syntax):
    """Append the supplied search filters to ``base_url`` as query parameters.

    Every falsy argument (``None``, ``""``, ``[]``) is skipped.  The remaining
    ones are appended in a fixed order, each as ``&<parameter>=<value>`` with
    the value percent-encoded by ``urllib.parse.quote``.

    Args:
        base_url: URL to extend; parameters are appended with a leading ``&``.
        study_status: list of status strings, comma-joined into
            ``filter.overallStatus``.
        conditions: string for ``query.cond``.
        interventions: string for ``query.intr``.
        location: string for ``query.locn``.
        lead_sponsor: string for ``query.lead``.
        fields: list of field names, comma-joined into ``fields``.
        essie_expression_syntax: string for ``filter.advanced``.

    Returns:
        The extended URL.

    Raises:
        AssertionError: if a non-empty argument does not have the type listed
            above.
    """
    # (parameter prefix, raw value, required type), in the order they are appended.
    segments = (
        ("&filter.overallStatus=", study_status, list),
        ("&query.cond=", conditions, str),
        ("&query.intr=", interventions, str),
        ("&query.locn=", location, str),
        ("&query.lead=", lead_sponsor, str),
        ("&fields=", fields, list),
        ("&filter.advanced=", essie_expression_syntax, str),
    )

    url = base_url
    for prefix, value, required_type in segments:
        if not value:
            continue
        assert isinstance(value, required_type)
        # List-valued filters are sent as one comma-separated value.
        raw_value = ','.join(value) if required_type is list else value
        url += prefix + quote(raw_value)

    return url


def get_all_queries(base_url, study_status, conditions, interventions, locations, lead_sponsors, fields, essie_expression_syntax):
    """Build one query URL per (lead sponsor, location) combination.

    The remaining filters are shared by every generated URL.  When only one of
    ``locations`` / ``lead_sponsors`` is given, one URL is built per entry of
    that argument; when neither is given, a single URL is returned.

    Args:
        base_url: URL handed to ``get_query_url`` for every query.
        study_status: forwarded unchanged to ``get_query_url``.
        conditions: forwarded unchanged to ``get_query_url``.
        interventions: forwarded unchanged to ``get_query_url``.
        locations: list of location strings, or ``None`` / empty for no
            location filter.
        lead_sponsors: comma-separated sponsor names in one string, or
            ``None`` / empty for no sponsor filter.  Names are stripped and
            blank entries (e.g. from a trailing comma) are dropped.
        fields: forwarded unchanged to ``get_query_url``.
        essie_expression_syntax: forwarded unchanged to ``get_query_url``.

    Returns:
        A list of URLs, iterating sponsors in the outer loop and locations in
        the inner loop.

    Raises:
        AssertionError: if ``locations`` is non-empty but not a list, or
            ``lead_sponsors`` is non-empty but not a string.
    """
    if locations:
        assert isinstance(locations, list)
    if lead_sponsors:
        assert isinstance(lead_sponsors, str)
        # A blank name would make get_query_url omit the sponsor filter
        # entirely, so stray commas must not turn into extra queries.
        stripped_names = (part.strip() for part in lead_sponsors.split(','))
        lead_sponsors = [name for name in stripped_names if name]

    # The previous test `len(a) != 0 & len(b) != 0` parsed as the chained
    # comparison `len(a) != (0 & len(b)) != 0`, which is always False and
    # raises TypeError on None.  Plain truthiness covers None and empty
    # inputs alike; None stands for "no filter on this dimension".
    sponsor_choices = lead_sponsors if lead_sponsors else [None]
    location_choices = locations if locations else [None]

    return [
        get_query_url(
            base_url=base_url, study_status=study_status, conditions=conditions,
            interventions=interventions, location=locn, lead_sponsor=sponsor, fields=fields,
            essie_expression_syntax=essie_expression_syntax)
        for sponsor in sponsor_choices
        for locn in location_choices
    ]


def get_all_studies(col_name, study_status=None, conditions=None, interventions=None, locations=None,
                    lead_sponsors=None, fields=None, essie_expression_syntax=None):
    """Fetch the studies matching the filters and write them to the output dataset.

    Query URLs are obtained from ``get_all_queries``; the schema of the
    ``clinicaltrialgov_dataset`` Dataiku dataset is set from ``col_name`` via
    ``get_schema``; then each URL is fetched in turn by ``get_studies`` using
    one shared dataset writer.  Whatever ``get_studies`` returns is ignored.

    Args:
        col_name: mapping whose values are the output column names (also
            passed on to ``get_studies``).
        study_status, conditions, interventions, locations, lead_sponsors,
        fields, essie_expression_syntax: search filters, forwarded to
            ``get_all_queries``.
    """
    query_urls = get_all_queries(
        "https://beta.clinicaltrials.gov/api/v2/studies?pageSize=1000",
        study_status, conditions, interventions, locations, lead_sponsors, fields, essie_expression_syntax)

    output_dataset = dataiku.Dataset("clinicaltrialgov_dataset")
    output_dataset.write_schema(get_schema(col_name))

    # A single writer is kept open across all queries.
    with output_dataset.get_writer() as writer:
        for query_url in query_urls:
            get_studies(writer, query_url, col_name)
#     writer_lock = Lock()
#     with dataset.get_writer() as writer:
#         with pool.ThreadPool() as executor:
#             futures = []
#             for url in base_urls:
#                 futures.append(executor.apply_async(get_studies, args=(writer, writer_lock, url, col_name)))
#             for future in futures:
#                 print(future.get())


# def get_studies(writer, writer_lock, base_url, col_name):
def get_studies(writer, base_url, col_name, timeout=(10, 120)):
    """Download every result page of one query and write each page to ``writer``.

    Pages are requested one after another; the ``nextPageToken`` of a response
    is appended to ``base_url`` as ``&pageToken=...`` for the following
    request.  Each page's ``studies`` list is passed to ``write_dataset`` as
    soon as it arrives, so pages written before a failure stay written.

    Args:
        writer: dataset writer handed to ``write_dataset``.
        base_url: fully built query URL (without page token).
        col_name: column mapping handed to ``write_dataset``.
        timeout: value for the ``timeout`` argument of ``requests.get``
            (seconds, or a ``(connect, read)`` tuple).  Without a timeout a
            stalled connection would block the caller indefinitely.

    Returns:
        ``None`` when all pages were fetched, or the error message string when
        a ``requests`` exception occurred.

    Raises:
        ValueError: if a response's JSON has no ``studies`` key.
    """
    next_page_token = None
    studies_count = 0

    while True:
        url = base_url
        if next_page_token:
            url += f"&pageToken={next_page_token}"

        try:
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()  # Raise an HTTPError for bad responses (status codes other than 2xx)

            data = response.json()

            if 'studies' not in data:
                raise ValueError(f"Error: 'studies' key not found in the response JSON. URL: {url}")

            studies = data['studies']
            studies_count += len(studies)
            write_dataset(writer, studies, col_name)
            print(f"Fetched studies: {studies_count}")

            next_page_token = data.get('nextPageToken')
        except requests.exceptions.RequestException as e:
            error_message = f"Error: Unable to retrieve studies. {e}. URL: {url}"
            print(error_message)
            return error_message

        # Stop on a missing *or empty* token: carrying on with an empty token
        # would request base_url again and loop over the first page forever.
        if not next_page_token:
            break

    print(f"Total studies fetched: {studies_count}")


def create_normalized_columns_name(json_data, parent_key='', separator='.', result=None):
    """Flatten a tree of field descriptors into ``{joined path: piece}``.

    A descriptor is a dict with a ``name`` and a ``piece``; it may also carry
    ``children`` (nested descriptors) and a ``type``.  A descriptor is
    descended into when it has ``children`` and its ``type`` does not contain
    ``[]``; otherwise it becomes an entry whose key is the names along the
    path joined by ``separator`` and whose value is its ``piece``.

    Args:
        json_data: a descriptor dict or a list of them; any other value is
            ignored.
        parent_key: path prefix of the enclosing descriptor ('' at the root).
        separator: string placed between path components.
        result: mapping to fill in place; a new ``OrderedDict`` is created
            when omitted.

    Returns:
        The filled mapping (``result`` itself when one was supplied).
    """
    # Create the accumulator per call: a default built in the signature is
    # shared by all calls and would retain entries from previous invocations.
    if result is None:
        result = OrderedDict()

    if isinstance(json_data, list):
        for item in json_data:
            # Forward the accumulator so nested entries land in the caller's mapping.
            create_normalized_columns_name(item, parent_key, separator, result)
    elif isinstance(json_data, dict):
        key = json_data['name']
        new_key = parent_key + separator + key if parent_key else key
        if "children" in json_data and '[]' not in json_data["type"]:
            create_normalized_columns_name(json_data["children"], new_key, separator, result)
        else:
            result[new_key] = json_data['piece']

    return result


def get_schema(col_name):
    """Return a schema with one ``string`` column per value of ``col_name``.

    Args:
        col_name: mapping whose values are the output column names.

    Returns:
        A list of ``{'name': <column>, 'type': 'string'}`` dicts, in the
        mapping's value order.
    """
    return [{'name': column, 'type': 'string'} for column in col_name.values()]


def get_dataframe(json_data, col_name):
    """Flatten ``json_data`` and project it onto the columns named by ``col_name``.

    The records are flattened with ``json_normalize``.  Then, for each
    ``flattened name -> output name`` pair of ``col_name`` in order, the
    flattened column is renamed when present; otherwise an all-NaN column with
    the output name is added.

    Args:
        json_data: records accepted by ``json_normalize``.
        col_name: mapping from flattened column name to output column name.

    Returns:
        A DataFrame holding exactly the output columns, in the order of
        ``col_name``'s values.
    """
    frame = json_normalize(json_data)

    for source_col, target_col in col_name.items():
        if source_col not in frame.columns:
            # Field absent from every record: keep the column, filled with NaN.
            frame[target_col] = np.nan
            continue
        frame = frame.rename(columns={source_col: target_col})

    return frame[list(col_name.values())]


# Write recipe outputs
def write_dataset(writer, json_data, col_name):
    """Convert ``json_data`` with ``get_dataframe`` and write the result to ``writer``.

    Args:
        writer: object exposing ``write_dataframe``.
        json_data: records to convert.
        col_name: column mapping forwarded to ``get_dataframe``.
    """
    writer.write_dataframe(get_dataframe(json_data, col_name))