# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# -*- coding: utf-8 -*-
import dataiku
import dataiku.insights
import pandas as pd
import numpy as np
# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE

from solution.constants import N_SIMULATIONS

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Current number of customers per (raw) segment label.
last_rfm_segment_counts_df = dataiku.Dataset("last_customer_rfm_segments_counts").get_dataframe()

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Observed historical transitions between (raw) segment labels.
segment_transitions_df = dataiku.Dataset("segment_label_transitions").get_dataframe()

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Per-segment historical aggregates (average basket, occurrences, ...).
historical_segment_aggregations_df = dataiku.Dataset("historical_segment_aggregations").get_dataframe()
# Force the INACTIVE segment's average basket to 0. Use .loc: the former chained
# assignment (df["col"][mask] = ...) is silently ignored under pandas Copy-on-Write.
historical_segment_aggregations_df.loc[
    historical_segment_aggregations_df["previous_segment_label"] == "INACTIVE",
    "previous_avg_total_basket",
] = 0.0
historical_segment_aggregations_df

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Rank segments by ascending average basket; the rank is the row position after sorting.
historical_segment_aggregations_df = (
    historical_segment_aggregations_df
    .sort_values(by="previous_avg_total_basket")
    .reset_index(drop=True)
)
historical_segment_aggregations_df["avg_total_basket_rank"] = historical_segment_aggregations_df.index

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
original_segments = list(historical_segment_aggregations_df["previous_segment_label"])
# Build the corrected label "<rank - 1>. <label>".
# NOTE(review): the -1 offset gives the lowest-basket segment the prefix "-1.", and
# string prefixes only order numerically for up to 10 segments -- confirm intended.
rank_prefixes = (historical_segment_aggregations_df["avg_total_basket_rank"] - 1).astype(str)
historical_segment_aggregations_df["label_correction"] = rank_prefixes
historical_segment_aggregations_df["previous_segment_label_corrected"] = (
    rank_prefixes + ". " + historical_segment_aggregations_df["previous_segment_label"]
)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Raw label -> corrected label (the last occurrence wins when a raw label repeats).
segment_labels_mapping_df = historical_segment_aggregations_df[
    ["previous_segment_label", "previous_segment_label_corrected"]
].drop_duplicates()
old_to_new_segment_labels_mapping = dict(zip(
    segment_labels_mapping_df["previous_segment_label"],
    segment_labels_mapping_df["previous_segment_label_corrected"],
))

# Unknown labels raise KeyError on purpose: every label must have a correction.
historical_segment_aggregations_df["previous_segment_label"] = (
    historical_segment_aggregations_df["previous_segment_label"]
    .apply(lambda raw_label: old_to_new_segment_labels_mapping[raw_label])
)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
last_rfm_segment_counts_df["segment_label"] = (
    last_rfm_segment_counts_df["segment_label"]
    .apply(lambda raw_label: old_to_new_segment_labels_mapping[raw_label])
)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# One row per historical segment: its average basket and how many customers are
# currently in it (0 when the segment has no current customers).
last_segment_occurences_and_baskets_df = (
    historical_segment_aggregations_df[["previous_segment_label", "previous_avg_total_basket"]]
    .merge(last_rfm_segment_counts_df,
           how="left",
           left_on="previous_segment_label",
           right_on="segment_label")
    .drop(columns="segment_label")
    .rename(columns={"previous_segment_label": "segment_label",
                     "previous_avg_total_basket": "avg_total_basket"})
)
last_segment_occurences_and_baskets_df["currently_in_segment_label"] = (
    last_segment_occurences_and_baskets_df["currently_in_segment_label"].fillna(0).astype(int)
)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Lookup tables keyed by corrected segment label (the last row wins on duplicates).
all_segment_baskets = dict(zip(
    last_segment_occurences_and_baskets_df["segment_label"],
    last_segment_occurences_and_baskets_df["avg_total_basket"],
))
all_segment_occurences = dict(zip(
    last_segment_occurences_and_baskets_df["segment_label"],
    last_segment_occurences_and_baskets_df["currently_in_segment_label"],
))

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# ## All transitions dataframe computation
all_segments = list(historical_segment_aggregations_df["previous_segment_label"].unique())

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# "same_segment_label" marks rows whose next segment equals the previous one:
# substitute the actual previous label. Series.mask replaces the former chained
# assignment, which is silently ignored under pandas Copy-on-Write.
same_segment_mask = segment_transitions_df["next_segment_label"] == "same_segment_label"
segment_transitions_df["next_segment_label"] = segment_transitions_df["next_segment_label"].mask(
    same_segment_mask, segment_transitions_df["previous_segment_label"])
for column_name in ["previous_segment_label", "next_segment_label"]:
    segment_transitions_df[column_name] =\
    segment_transitions_df[column_name].apply(lambda x: old_to_new_segment_labels_mapping[x])

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Cross join all segments with themselves so every (previous, next) pair exists,
# then attach observed transition counts; unobserved pairs get 0.
transitions_df = pd.DataFrame({"previous_segment_label": all_segments})
transitions_df["cross_join_key"] = 1
tmp_df = transitions_df.copy()
tmp_df.rename({"previous_segment_label": "next_segment_label"}, axis=1, inplace=True)
transitions_df = transitions_df.merge(tmp_df, how="left", on="cross_join_key")
transitions_df = transitions_df.merge(segment_transitions_df,
                                      how="left",
                                      on=["previous_segment_label", "next_segment_label"])
transitions_df.drop("cross_join_key", axis=1, inplace=True)
transitions_df.fillna(0, inplace=True)
transitions_df.reset_index(drop=True, inplace=True)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Attach the previous segment's occurrence count and corrected label, exposed as
# "previous_segment_occurences" / "previous_segment".
transitions_df = (
    transitions_df
    .merge(historical_segment_aggregations_df[["previous_segment_label",
                                               "previous_segment_label_occurences",
                                               "previous_segment_label_corrected"]],
           how="left",
           on="previous_segment_label")
    .rename(columns={"previous_segment_label_corrected": "previous_segment",
                     "previous_segment_label_occurences": "previous_segment_occurences"})
    .drop(columns="previous_segment_label")
)
# Attach the next segment's corrected label as "next_segment", then order the pairs.
transitions_df = (
    transitions_df
    .merge(historical_segment_aggregations_df[["previous_segment_label",
                                               "previous_segment_label_corrected"]],
           how="left",
           left_on="next_segment_label",
           right_on="previous_segment_label")
    .drop(columns=["next_segment_label", "previous_segment_label"])
    .rename(columns={"previous_segment_label_corrected": "next_segment"})
    .sort_values(by=["previous_segment", "next_segment"], ascending=True)
)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Keep only the output columns. The explicit copy makes the column updates below
# modify this frame instead of a view of the previous one.
transitions_df = transitions_df[["previous_segment", "previous_segment_occurences",
                                 "n_transitions", "next_segment"]].copy()

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Assign the filled column back: fillna(inplace=True) on a selected column is
# chained assignment and is silently ignored under pandas Copy-on-Write.
for column_name in ["n_transitions", "previous_segment_occurences"]:
    transitions_df[column_name] = transitions_df[column_name].fillna(0)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
def get_probable_segment_deltas_kpis_and_labels(probable_segment_deltas_df):
    """Add revenue KPIs and human-readable delta labels to a per-segment deltas frame.

    The frame is modified in place and also returned.

    Reads the columns ``simulated_segment_occurences``,
    ``current_segment_avg_total_basket``, ``without_transitions_revenue``,
    ``occurences_delta_with_simulated_transitions``,
    ``revenue_delta_with_simulated_current_segment_evolution`` and ``current_segment``.

    Adds:
        all_simulated_segments_evolution_revenue: simulated occurrences x basket.
        revenue_delta_with_all_simulated_segments_evolution: that revenue minus
            ``without_transitions_revenue``.
        occurences_delta_type, current_segment_evolution_revenue_delta_type,
        all_segments_evolution_revenue_delta_type: "Negative delta" when the
            matching delta is <= 0, otherwise "Positive delta" (NaN -> positive).
        transitioning_segment_label: "'<current_segment>' transitions".
    """
    df = probable_segment_deltas_df
    df["all_simulated_segments_evolution_revenue"] = (
        df["simulated_segment_occurences"] * df["current_segment_avg_total_basket"]
    )
    df["revenue_delta_with_all_simulated_segments_evolution"] = (
        df["all_simulated_segments_evolution_revenue"] - df["without_transitions_revenue"]
    )

    # label column -> delta column it classifies. .loc is used because the former
    # chained assignment (df[col][mask] = ...) is ignored under Copy-on-Write.
    delta_label_sources = {
        "occurences_delta_type": "occurences_delta_with_simulated_transitions",
        "current_segment_evolution_revenue_delta_type":
            "revenue_delta_with_simulated_current_segment_evolution",
        "all_segments_evolution_revenue_delta_type":
            "revenue_delta_with_all_simulated_segments_evolution",
    }
    for label_column, delta_column in delta_label_sources.items():
        df[label_column] = "Positive delta"
        df.loc[df[delta_column] <= 0, label_column] = "Negative delta"

    df["transitioning_segment_label"] = df["current_segment"].apply(lambda x: f"'{x}' transitions")
    return df

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Empirical probability of moving from previous_segment to next_segment
# (NaN/inf when previous_segment_occurences is 0).
transitions_df["transition_probability"] = transitions_df["n_transitions"].div(
    transitions_df["previous_segment_occurences"])

def get_simulated_dataframes(segments_possible_transitions, segments_transition_probabilities):
    """Run one random simulation of next-segment transitions and return per-segment deltas.

    Each customer currently in a segment (counts from the module-level
    ``all_segment_occurences``) gets a next segment drawn from that segment's
    transition probabilities. The simulated population is compared with the
    current one in number of customers and in revenue. Revenue uses the
    module-level ``all_segment_baskets`` average baskets.

    Args:
        segments_possible_transitions: dict mapping a segment label to the array of
            next-segment labels it can move to.
        segments_transition_probabilities: dict mapping a segment label to the array
            of probabilities aligned with ``segments_possible_transitions[label]``.

    Returns:
        A DataFrame with one row per current segment, enriched by
        ``get_probable_segment_deltas_kpis_and_labels``.

    Note:
        Reads the module-level ``all_segments``, ``all_segment_occurences``,
        ``all_segment_baskets``, ``transitions_df`` and
        ``last_segment_occurences_and_baskets_df``. It (re)writes the
        ``without_transitions_revenue`` column of the latter.
    """
    # ## Simulation of the transition to the next segment:
    simulated_segment_transitions_data = []
    for segment in all_segments:
        segment_occurences = int(all_segment_occurences[segment])
        if segment_occurences <= 0:
            # Nothing to draw. This also avoids validating probabilities, which can
            # be NaN when a segment had no historical occurrences.
            continue
        possible_transitions = segments_possible_transitions[segment]
        transition_probabilities = segments_transition_probabilities[segment]
        n_possible_transitions = len(transition_probabilities)
        # Draw all customers of the segment in one vectorised call, then count how
        # many land in each possible next segment.
        simulated_transition_indexes = np.random.choice(
            n_possible_transitions, size=segment_occurences, p=transition_probabilities)
        transition_counts = np.bincount(simulated_transition_indexes,
                                        minlength=n_possible_transitions)
        for simulated_transition, n_customers in zip(possible_transitions, transition_counts):
            if n_customers > 0:
                simulated_segment_transitions_data.append({
                    "previous_segment": segment,
                    "simulated_segment": simulated_transition,
                    "simulated_segment_occurences": int(n_customers),
                })

    # Explicit columns keep the frame usable even when nothing was drawn.
    simulated_segment_transitions_df = pd.DataFrame(
        simulated_segment_transitions_data,
        columns=["previous_segment", "simulated_segment", "simulated_segment_occurences"])

    # ## Simulated revenue delta computation
    aggregation_key = ["previous_segment", "simulated_segment"]
    simulated_segment_revenue_transitions_df = (
        simulated_segment_transitions_df[aggregation_key + ["simulated_segment_occurences"]]
        .groupby(aggregation_key).sum().reset_index()
    )

    # Expand to every (previous, next) pair so undrawn transitions appear with 0 occurrences.
    simulated_segment_revenue_transitions_df = transitions_df[["previous_segment", "next_segment"]].merge(
        simulated_segment_revenue_transitions_df,
        how="left",
        left_on=["previous_segment", "next_segment"],
        right_on=["previous_segment", "simulated_segment"])
    simulated_segment_revenue_transitions_df["simulated_segment"] = (
        simulated_segment_revenue_transitions_df["simulated_segment"]
        .fillna(simulated_segment_revenue_transitions_df["next_segment"])
    )
    simulated_segment_revenue_transitions_df = simulated_segment_revenue_transitions_df.drop(
        "next_segment", axis=1)
    simulated_segment_revenue_transitions_df["simulated_segment_occurences"] = (
        simulated_segment_revenue_transitions_df["simulated_segment_occurences"].fillna(0)
    )

    simulated_segment_revenue_transitions_df["previous_segment_avg_total_basket"] = (
        simulated_segment_revenue_transitions_df["previous_segment"].apply(lambda x: all_segment_baskets[x])
    )
    # Customers are valued at the average basket of the segment they land in.
    simulated_segment_revenue_transitions_df["simulated_segment_avg_total_basket"] = (
        simulated_segment_revenue_transitions_df["simulated_segment"].apply(lambda x: all_segment_baskets[x])
    )

    simulated_segment_revenue_transitions_df["simulated_transitions_revenue"] = (
        simulated_segment_revenue_transitions_df["simulated_segment_occurences"]
        * simulated_segment_revenue_transitions_df["simulated_segment_avg_total_basket"]
    )

    simulated_segment_revenue_df = (
        simulated_segment_revenue_transitions_df[["previous_segment", "simulated_transitions_revenue"]]
        .groupby("previous_segment").sum().reset_index()
        .rename({"simulated_transitions_revenue": "simulated_current_segment_evolution_revenue"}, axis=1)
    )

    last_segment_occurences_and_baskets_df["without_transitions_revenue"] = (
        last_segment_occurences_and_baskets_df["currently_in_segment_label"]
        * last_segment_occurences_and_baskets_df["avg_total_basket"]
    )

    probable_segment_revenue_deltas_df = (
        last_segment_occurences_and_baskets_df[["segment_label", "currently_in_segment_label",
                                                "avg_total_basket", "without_transitions_revenue"]]
        .rename({"segment_label": "current_segment",
                 "avg_total_basket": "current_segment_avg_total_basket"}, axis=1)
        .merge(simulated_segment_revenue_df,
               how="left",
               left_on=["current_segment"],
               right_on=["previous_segment"])
        .drop("previous_segment", axis=1)
    )
    # Assign back: fillna(inplace=True) on a column is ignored under Copy-on-Write.
    for column_name in ["without_transitions_revenue", "simulated_current_segment_evolution_revenue"]:
        probable_segment_revenue_deltas_df[column_name] = (
            probable_segment_revenue_deltas_df[column_name].fillna(0)
        )

    probable_segment_revenue_deltas_df["revenue_delta_with_simulated_current_segment_evolution"] = (
        probable_segment_revenue_deltas_df["simulated_current_segment_evolution_revenue"]
        - probable_segment_revenue_deltas_df["without_transitions_revenue"]
    )

    # ## Simulated occurences delta computation
    simulated_segment_occurences_df = (
        simulated_segment_transitions_df[["simulated_segment", "simulated_segment_occurences"]]
        .groupby(["simulated_segment"]).sum().reset_index()
    )

    probable_segment_occurences_deltas_df = (
        last_segment_occurences_and_baskets_df[["segment_label", "currently_in_segment_label"]]
        .rename({"segment_label": "current_segment",
                 "currently_in_segment_label": "current_segment_occurences"}, axis=1)
        .merge(simulated_segment_occurences_df,
               how="left",
               left_on=["current_segment"],
               right_on=["simulated_segment"])
    )

    # fillna must take effect before astype(int), otherwise NaN makes the cast raise.
    for column_name in ["current_segment_occurences", "simulated_segment_occurences"]:
        probable_segment_occurences_deltas_df[column_name] = (
            probable_segment_occurences_deltas_df[column_name].fillna(0).astype(int)
        )

    probable_segment_occurences_deltas_df["simulated_segment"] = (
        probable_segment_occurences_deltas_df["simulated_segment"]
        .fillna(probable_segment_occurences_deltas_df["current_segment"])
    )

    probable_segment_occurences_deltas_df["occurences_delta_with_simulated_transitions"] = (
        probable_segment_occurences_deltas_df["simulated_segment_occurences"]
        - probable_segment_occurences_deltas_df["current_segment_occurences"]
    )

    probable_segment_revenue_deltas_df = probable_segment_revenue_deltas_df.drop(
        ["currently_in_segment_label"], axis=1)
    probable_segment_deltas_df = probable_segment_occurences_deltas_df.merge(
        probable_segment_revenue_deltas_df,
        how="left",
        on="current_segment"
    ).drop(["simulated_segment"], axis=1)
    probable_segment_deltas_df = get_probable_segment_deltas_kpis_and_labels(probable_segment_deltas_df)
    return probable_segment_deltas_df

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# For each current segment, the candidate next segments and their probabilities as
# aligned arrays, in transitions_df row order, ready for sampling.
_transition_rows_by_segment = {
    segment_label: transitions_df.loc[transitions_df["previous_segment"] == segment_label]
    for segment_label in all_segments
}
segments_possible_transitions = {
    segment_label: rows_df["next_segment"].to_numpy()
    for segment_label, rows_df in _transition_rows_by_segment.items()
}
segments_transition_probabilities = {
    segment_label: rows_df["transition_probability"].to_numpy()
    for segment_label, rows_df in _transition_rows_by_segment.items()
}

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE


# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Run N_SIMULATIONS independent simulations, tagging each result with its id.
simulated_segment_delta_dataframes = []
for simulation_number in range(1, N_SIMULATIONS + 1):
    print(f"Simulating segment deltas ({simulation_number}/{N_SIMULATIONS}).")
    simulation_df = get_simulated_dataframes(segments_possible_transitions,
                                             segments_transition_probabilities)
    simulation_df["simulation_id"] = f"simulation_{simulation_number}"
    simulated_segment_delta_dataframes.append(simulation_df)

all_simulation_dataframes_df = pd.concat(simulated_segment_delta_dataframes)
# Total revenue delta per simulation: the distribution of simulated scenarios.
segment_deltas_scenarios_distribution_df = (
    all_simulation_dataframes_df
    .groupby(by="simulation_id")[["revenue_delta_with_all_simulated_segments_evolution"]]
    .sum()
    .reset_index()
)
segment_deltas_scenarios_distribution_df


# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
focus_columns = [
    "current_segment",
    "without_transitions_revenue",
    "simulated_segment_occurences",
    "current_segment_avg_total_basket",
    "revenue_delta_with_simulated_current_segment_evolution",
    "revenue_delta_with_all_simulated_segments_evolution",
    "occurences_delta_with_simulated_transitions",
]

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Summarise all simulations with the per-segment median of each KPI, then recompute
# the derived revenue columns and labels on those medians.
median_deltas_df = (
    all_simulation_dataframes_df[focus_columns]
    .groupby("current_segment")
    .median()
    .reset_index()
)
probable_segment_deltas_df = get_probable_segment_deltas_kpis_and_labels(median_deltas_df)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
probable_segment_deltas_df

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
transitions_df = transitions_df.rename(columns={"previous_segment": "current_segment"})

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Text before the first "." of the next segment's label, with the "." re-appended
# (labels are built as "<n>. <label>", so this yields "<n>.").
transitions_df["next_segment_index"] = transitions_df["next_segment"].apply(
    lambda segment_label: segment_label.split(".")[0] + ".")
transitions_df

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Write recipe outputs
# Write recipe outputs: transition probabilities first, then the median per-segment deltas.
transition_probabilities = dataiku.Dataset("segment_transition_probabilities")
probable_segment_revenue_deltas = dataiku.Dataset("probable_segment_deltas")
for output_dataset, output_df in ((transition_probabilities, transitions_df),
                                  (probable_segment_revenue_deltas, probable_segment_deltas_df)):
    output_dataset.write_with_schema(output_df)