from ..utils import (
    include_sdoh,
    filter_dataset_by_site_id,
    from_datetime_to_dss_string_date,
    df_to_dict,
)
from .config import CsiConfig
import pandas as pd
from ..models import (
    StudiesSiteMechCond,
    SiteStudiesTimeline,
    SiteStudiesQuarterData,
    SiteSdoh,
    SiteScoreCardModel,
    SponsorSiteData,
)


class SiteScoreCard:
    """Assemble the score-card payload for a single site.

    Every operation is a static method; the class keeps no instance state.
    """

    def __init__(self):
        """No-op constructor, kept so ``SiteScoreCard()`` remains callable."""

    @staticmethod
    def _build_models(model_cls, records):
        """Instantiate ``model_cls`` once per record.

        Args:
            model_cls: Callable invoked as ``model_cls(**record)``.
            records: Iterable of keyword-argument mappings, or ``None``.

        Returns:
            A list with one ``model_cls`` result per record, or an empty list
            when ``records`` is ``None``.
        """
        if records is None:
            return []
        return [model_cls(**record) for record in records]

    @staticmethod
    def get_site_scorecard_json(site_id: str) -> "SiteScoreCardModel":
        """Build the ``SiteScoreCardModel`` for ``site_id``.

        Each section is loaded as a DataFrame, converted with ``df_to_dict``
        and wrapped in its model class. A section whose ``df_to_dict`` result
        is ``None`` becomes an empty list (or ``None`` for the SDOH section).

        Args:
            site_id: Identifier forwarded to the dataset filters.

        Returns:
            A ``SiteScoreCardModel`` holding the six score-card sections.
        """
        study_feature_df = SiteScoreCard.get_site_timeline(
            datasetName=CsiConfig.STUDIES_FEATURES_W_SITE_TIMELINE, site_id=site_id
        )
        active_study_period_df = SiteScoreCard.explode_study_period(study_feature_df)

        studies_mesh_conditions_on_site = SiteScoreCard._build_models(
            StudiesSiteMechCond,
            df_to_dict(SiteScoreCard.get_site_conditions(site_id), keep_first=False),
        )

        ongoing_studies_timeline_on_site = SiteScoreCard._build_models(
            SiteStudiesTimeline,
            df_to_dict(
                SiteScoreCard.get_site_timeline(
                    datasetName=CsiConfig.SITE_TIMELINE_ON_ONGOING_STUDIES,
                    site_id=site_id,
                ),
                keep_first=False,
            ),
        )

        historical_studies_timeline_on_site = SiteScoreCard._build_models(
            SiteStudiesTimeline,
            df_to_dict(
                SiteScoreCard.get_site_timeline(
                    datasetName=CsiConfig.SITE_TIMELINE_ON_HISTORICAL_STUDIES,
                    site_id=site_id,
                ),
                keep_first=False,
            ),
        )

        # keep_first=True is used as a single keyword mapping below
        # (presumably the first matching row -- confirm against df_to_dict).
        sdoh_on_site = df_to_dict(
            filter_dataset_by_site_id(
                datasetName=CsiConfig.STUDIES_FEATURES_W_SITE_ID_SDOH,
                site_id=site_id,
            ),
            keep_first=True,
        )
        studies_site_sdoh_on_site = None
        if sdoh_on_site is not None and include_sdoh():
            studies_site_sdoh_on_site = SiteSdoh(**sdoh_on_site)

        active_studies_by_quarter_on_site = SiteScoreCard._build_models(
            SiteStudiesQuarterData,
            df_to_dict(active_study_period_df, keep_first=False),
        )

        sponsor_data = SiteScoreCard._build_models(
            SponsorSiteData,
            df_to_dict(
                SiteScoreCard.get_site_used_by_sponsor(study_feature_df),
                keep_first=False,
            ),
        )

        return SiteScoreCardModel(
            studies_site_sdoh_on_site=studies_site_sdoh_on_site,
            studies_mesh_conditions_on_site=studies_mesh_conditions_on_site,
            ongoing_studies_timeline_on_site=ongoing_studies_timeline_on_site,
            historical_studies_timeline_on_site=historical_studies_timeline_on_site,
            active_studies_by_quarter_on_site=active_studies_by_quarter_on_site,
            site_used_by_sponsor=sponsor_data,
        )

    @staticmethod
    def get_site_timeline(datasetName: str, site_id: str) -> pd.DataFrame:
        """Load a dataset filtered to ``site_id`` with converted date columns.

        ``StartDate`` and ``CompletionDate`` are each passed element-wise
        through ``from_datetime_to_dss_string_date``.

        Args:
            datasetName: Name of the dataset handed to the site filter.
            site_id: Identifier of the site to keep.

        Returns:
            A new DataFrame; the frame returned by the filter is not modified.
        """
        site_df = filter_dataset_by_site_id(datasetName=datasetName, site_id=site_id)
        # assign() builds a new frame instead of writing columns in place, so
        # a frame that is a slice of (or shared with) the source data is safe.
        return site_df.assign(
            StartDate=site_df["StartDate"].apply(from_datetime_to_dss_string_date),
            CompletionDate=site_df["CompletionDate"].apply(
                from_datetime_to_dss_string_date
            ),
        )

    @staticmethod
    def explode_study_period(df: pd.DataFrame) -> pd.DataFrame:
        """Repeat each study row once per quarter it spans.

        Args:
            df: Frame with ``StartDate`` and ``CompletionDate`` columns.

        Returns:
            A frame with a fresh positional index and an extra ``Period``
            column holding each quarter start formatted as ``YYYY-MM-DD``.
            An empty input gives an empty frame that still has ``Period``.
        """
        if len(df.index) == 0:
            # A row-wise apply on an empty frame hands back a DataFrame rather
            # than a Series, which cannot be stored in one column.
            periods = pd.Series(index=df.index, dtype=object)
        else:
            periods = df.apply(SiteScoreCard.explode_date_range, axis=1)

        exploded_df = df.assign(Period=periods).explode("Period").reset_index(drop=True)
        # explode() produces an object column; convert explicitly so the .dt
        # accessor does not rely on dtype inference.
        exploded_df["Period"] = pd.to_datetime(exploded_df["Period"]).dt.strftime(
            "%Y-%m-%d"
        )
        return exploded_df

    @staticmethod
    def explode_date_range(row: pd.Series) -> pd.DatetimeIndex:
        """Return the quarter starts covered by one study row.

        Args:
            row: Row exposing ``StartDate`` and ``CompletionDate``.

        Returns:
            Quarter-start timestamps from the start date's quarter through
            the completion date's quarter, inclusive.
        """
        # NOTE(review): a missing date parses to NaT, which date_quarter does
        # not handle -- confirm upstream guarantees both dates are present.
        start_date = pd.to_datetime(row["StartDate"])
        end_date = pd.to_datetime(row["CompletionDate"])
        quarter_starts = pd.date_range(
            SiteScoreCard.date_quarter(start_date),
            SiteScoreCard.date_quarter(end_date),
            freq="QS",
        )  # QS is the Quarter Start frequency
        return quarter_starts

    @staticmethod
    def date_quarter(date: pd.Timestamp) -> pd.Timestamp:
        """Return the first day of the calendar quarter containing ``date``."""
        # Months 1-3 -> 1, 4-6 -> 4, 7-9 -> 7, 10-12 -> 10.
        return pd.Timestamp(date.year, (date.month - 1) // 3 * 3 + 1, 1)

    @staticmethod
    def get_site_conditions(site_id: str) -> pd.DataFrame:
        """Compute how prevalent each mesh-term group is among a site's studies.

        Args:
            site_id: Identifier of the site to keep.

        Returns:
            A frame with columns ``group_mesh_term`` and ``prevalence``, where
            prevalence is the number of distinct ``NCTId`` values in the group
            divided by the number of distinct ``NCTId`` values for the site.
        """
        selected_site_conditions_df = filter_dataset_by_site_id(
            datasetName=CsiConfig.STDUIES_MESH_CONDITIONS_JOINED, site_id=site_id
        )

        # dropna=False matches len(Series.unique()), which counts NaN once.
        unique_studies = selected_site_conditions_df["NCTId"].nunique(dropna=False)

        return (
            selected_site_conditions_df.groupby("group_mesh_term")["NCTId"]
            .nunique()
            .div(unique_studies)
            .rename("prevalence")
            .reset_index()
        )

    @staticmethod
    def get_site_used_by_sponsor(df: pd.DataFrame) -> pd.DataFrame:
        """Summarise study counts for the ten sponsors with the most studies.

        Args:
            df: Frame with ``LeadSponsorName``, ``Recruiting`` and ``NCTId``.

        Returns:
            One row per (sponsor, ``Recruiting``) pair for the top ten
            sponsors. ``NCTId`` holds the sponsor's total count and ``count``
            holds the count for that ``Recruiting`` value.
        """
        top_sponsors = (
            df.groupby("LeadSponsorName")
            .agg({"NCTId": "count"})
            .reset_index()
            .sort_values(by="NCTId", ascending=False)
            .head(10)
        )

        per_status = (
            df.groupby(["LeadSponsorName", "Recruiting"])
            .agg({"NCTId": "count"})
            .reset_index()
            .rename(columns={"NCTId": "count"})
        )

        return top_sponsors.merge(per_status, on="LeadSponsorName")
