import dataiku

from clv_forecast.constants import (
    MONTHLY_FEATURES_KEEP_COLUMNS,
    RFM_KEEP_COLUMNS,
)
from dku_utils.projects.project_commons import get_current_project_and_variables
from dku_utils.projects.recipes.join_recipe import programmaticJoinHandler
from dku_utils.projects.recipes.recipe_commons import (
    update_recipe_ouput_schema,
    adapt_recipe_engine_to_priority_and_availability,
    get_recipe_settings_and_dictionary,
    switch_visual_recipe_input,
)


def _schema_column_names(dataset_name, excluded_columns):
    """Return the schema column names of a Dataiku dataset, minus some columns.

    Column order follows the order returned by ``dataiku.Dataset.read_schema()``.

    :param dataset_name: name of the dataset whose schema is read.
    :param excluded_columns: collection of column names to leave out
        (the callers pass the join keys of the other side).
    :returns: list of the remaining column names.
    """
    schema = dataiku.Dataset(dataset_name).read_schema()
    return [column["name"] for column in schema if column["name"] not in excluded_columns]


def data_enrichment():
    """Optionally enrich the monthly granularity dataset, then rewire the prepare recipe.

    When the project variables enable RFM segmentation and/or customer metadata,
    the join recipe ``compute_monthly_granularity_with_rfm`` is reconfigured:
    ``monthly_granularity_full`` is the main input, and each enabled dataset is
    LEFT-joined onto it. The recipe's output schema and engine are then updated,
    and ``monthly_granularity_with_rfm`` is built.

    Finally, the recipe ``compute_monthly_granularity_with_rfm_prepared`` has its
    main input switched to whichever dataset is the result of this step:
    ``monthly_granularity_with_rfm`` when a join happened, otherwise
    ``monthly_granularity_full``. Its output schema and engine are then updated.
    """
    project, variables = get_current_project_and_variables()
    standard_variables = variables["standard"]
    recipe_name = "compute_monthly_granularity_with_rfm"

    use_customer_metadata = standard_variables["leverage_customer_metadata_app"]
    use_rfm_segmentation = standard_variables["leverage_rfm_segmentation_app"]

    if use_customer_metadata or use_rfm_segmentation:
        join_handler = programmaticJoinHandler(
            project=project,
            recipe_name=recipe_name,
            main_dataset_name="monthly_granularity_full",
            main_dataset_columns_to_select=MONTHLY_FEATURES_KEEP_COLUMNS,
            main_dataset_columns_to_select_alias={},
            main_dataset_computed_columns=[],
        )

        if use_rfm_segmentation:
            rfm_dataset_name = standard_variables["rfm_dataset_app"]
            # Select every RFM column except the join keys (presumably so the
            # key columns are not duplicated in the join output).
            rfm_columns = _schema_column_names(
                rfm_dataset_name, {"customer_id", "rfm_reference_date"}
            )
            join_handler.add_one_join_on_main_dataset(
                dataset_to_join_name=rfm_dataset_name,
                dataset_to_join_columns_to_select=rfm_columns,
                join_type="LEFT",
                columns_prefix="",
                # The RFM snapshot date is matched against the monthly reference date.
                left_join_key=["ref_date", "customer_id"],
                right_join_key=["rfm_reference_date", "customer_id"],
                columns_to_select_alias={},
                dataset_computed_columns=[],
            )

        if use_customer_metadata:
            metadata_dataset_name = standard_variables["customer_metadata_dataset_app"]
            # Select every metadata column except the join key.
            metadata_columns = _schema_column_names(
                metadata_dataset_name, {"customer_id"}
            )
            join_handler.add_one_join_on_main_dataset(
                dataset_to_join_name=metadata_dataset_name,
                dataset_to_join_columns_to_select=metadata_columns,
                join_type="LEFT",
                columns_prefix="",
                left_join_key=["customer_id"],
                right_join_key=["customer_id"],
                columns_to_select_alias={},
                dataset_computed_columns=[],
            )

        update_recipe_ouput_schema(project, recipe_name)
        adapt_recipe_engine_to_priority_and_availability(project, recipe_name)

        output_dataset = "monthly_granularity_with_rfm"
        project.get_dataset(output_dataset).build()
    else:
        # No enrichment requested: the prepare recipe reads the raw monthly dataset.
        output_dataset = "monthly_granularity_full"

    prepare_recipe_name = "compute_monthly_granularity_with_rfm_prepared"
    _, recipe_dictionary = get_recipe_settings_and_dictionary(
        project, prepare_recipe_name, True
    )
    # Replace whatever dataset is currently the first main input of the prepare recipe.
    current_input_name = recipe_dictionary["inputs"]["main"]["items"][0]["ref"]
    switch_visual_recipe_input(
        project,
        recipe_name=prepare_recipe_name,
        current_input_dataset_name=current_input_name,
        new_input_dataset_name=output_dataset,
    )
    update_recipe_ouput_schema(project, prepare_recipe_name)
    adapt_recipe_engine_to_priority_and_availability(project, prepare_recipe_name)
