# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# -*- coding: utf-8 -*-
import dataiku
import pandas as pd, numpy as np
from dataiku import pandasutils as pdu

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Read recipe inputs
facilities_us_county_demographic_joined = dataiku.Dataset("facilities_us_county_demographic_joined")
facilities_us_county_demographic_joined_df = facilities_us_county_demographic_joined.get_dataframe()

# Metric columns to weight: every input column whose name contains "Percent",
# followed by the county Social Vulnerability Index column.
index_cols = [
    name
    for name in facilities_us_county_demographic_joined_df.columns
    if "Percent" in name
] + ['Social_Vulnerability_Index_county']

# Keep the study id, county FIPS, county population and the metric columns,
# collapsing rows that are identical across all of them.
selected_df = facilities_us_county_demographic_joined_df[
    ['NCTId', 'FIPS', 'Population_county'] + index_cols
].drop_duplicates()

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Calculate total population (and total area) from all counties where the
# sites reside for each study.
#
# County_name and State_name are selected only so that dropna()/drop_duplicates()
# work on (study, county) rows: rows missing any of the five fields are
# discarded and repeated rows for the same county are collapsed before summing.
# They are not aggregated themselves.
# NOTE(review): selected_df is de-duplicated on FIPS while this frame is
# de-duplicated on county/state name; the county weights derived from this sum
# only add up to 1 per study if both identify the same set of counties - confirm.
study_sites_region_population_df = (
    facilities_us_county_demographic_joined_df
    .loc[:, ['NCTId', 'Population_county', 'Area_sqmi_county', 'County_name', 'State_name']]
    .dropna()
    .drop_duplicates()
    .groupby('NCTId')
    # Use the 'sum' string alias instead of the Python builtin `sum`: passing
    # the builtin is deprecated in recent pandas and is slated to stop being
    # mapped onto the optimized group-wise sum.
    .agg({'Population_county': 'sum', 'Area_sqmi_county': 'sum'})
    .rename(columns={
        'Population_county': 'US_counties_population_sum',
        'Area_sqmi_county': 'US_counties_area_sum'})
    .reset_index())

# Population density over the combined counties of each study
# (summed population divided by summed area).
study_sites_region_population_df = study_sites_region_population_df.assign(
    US_population_density=(
        study_sites_region_population_df['US_counties_population_sum']
        / study_sites_region_population_df['US_counties_area_sum']))

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Attach the per-study population/area totals to every (study, county) row.
study_facilities_sdoh_scored_df = selected_df.merge(
    study_sites_region_population_df, on='NCTId', how='left')

# County weight = the county's share of the study's summed county population.
study_facilities_sdoh_scored_df['County_weight'] = (
    study_facilities_sdoh_scored_df['Population_county']
    / study_facilities_sdoh_scored_df['US_counties_population_sum'])

# multiply county weight to sdoh metrics
county_weight = study_facilities_sdoh_scored_df['County_weight']
for col in index_cols:
    study_facilities_sdoh_scored_df[col] = study_facilities_sdoh_scored_df[col].mul(county_weight)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Bare expression with no effect on the recipe output; presumably left so the
# weighted per-county frame is displayed when this cell is run in a notebook.
study_facilities_sdoh_scored_df

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# aggregate SDOH metrics by study
def _sum_propagating_nan(values):
    # Sum of the weighted metric; a single missing value makes the total NaN.
    return values.sum(skipna=False)


def _has_missing(values):
    # True when at least one value in the group is missing.
    return values.isna().any()


# Weighted metrics are summed per study; the per-study totals are constant
# within a study after the merge, so the first value is taken.
agg_functions = {col: _sum_propagating_nan for col in index_cols}
agg_functions.update({
    'US_counties_area_sum': 'first',
    'US_counties_population_sum': 'first',
    'US_population_density': 'first',
})
# check if study recruits any non-us sites (a row without a FIPS code)
agg_functions['FIPS'] = _has_missing

study_weighted_sdoh_score_df = (
    study_facilities_sdoh_scored_df
    .groupby(['NCTId'])
    .agg(agg_functions))

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# remove SDOH metrics if study involves non-US sites
# The aggregated FIPS column is used as a row mask; every column of the
# selected rows (FIPS included) is overwritten with NaN.
study_weighted_sdoh_score_df.loc[study_weighted_sdoh_score_df['FIPS'], :] = np.nan

# After the blanking above, a missing FIPS marks the affected studies:
# flag those rows with True and leave all other rows as NaN.
nonus_flag = study_weighted_sdoh_score_df['FIPS'].isna().replace({False: np.nan})

# rename column: each metric column gets a '_weighted_sum' suffix
rename_cols = {col: col + '_weighted_sum' for col in index_cols}

study_weighted_sdoh_score_df = (
    study_weighted_sdoh_score_df
    .assign(has_nonus_sites=nonus_flag)
    .reset_index()
    .rename(columns=rename_cols))

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Write recipe outputs
# The per-study aggregated frame (study_weighted_sdoh_score_df) is what gets
# written to the "study_facilities_sdoh_scored" dataset, schema included.
study_facilities_sdoh_scored = dataiku.Dataset("study_facilities_sdoh_scored")
study_facilities_sdoh_scored.write_with_schema(study_weighted_sdoh_score_df)