# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# -*- coding: utf-8 -*-
import dataiku
import dataiku.insights
import pandas as pd
import numpy as np
import math

from sklearn.cluster import KMeans
from bs_commons.dku_utils.projects.project_commons import get_current_project_and_variables
from bs_commons.dku_utils.projects.datasets.dataset_commons import change_dataset_column_datatype
from bs_commons.dates_handling.type_conversions import from_datetime_to_dss_string_date
from solution.functions import (load_rfm_parameters, compute_segmentation_quantiles_boundaries,
                                score_rfm_with_quantiles, train_kmeans_models_for_rfm,
                                score_rfm_with_k_means, enrich_dataframe_with_rfm_global_scores_and_segments,
                                remove_dataframe_outliers_based_on_quantiles, generate_rfm_box_plots,
                                load_mapping_between_original_and_rfm_columns, load_business_rules,
                                from_value_to_business_score, score_rfm_with_business_rules)
from solution.constants import COLUMN_FOR_RECENCY_COMPUTATION, LOWER_OUTLIERS_QUANTILE_TRESHOLD, HIGHER_OUTLIERS_QUANTILE_TRESHOLD

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Fetch the current project together with its variables; the "standard"
# variables hold the application settings read throughout this recipe.
project, variables = get_current_project_and_variables()
app_variables = variables["standard"]

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Read recipe inputs
# One row per customer is expected here -- TODO confirm against the flow.
customer_rfm_inputs = dataiku.Dataset("customer_rfm_inputs")
customer_rfm_inputs_df = customer_rfm_inputs.get_dataframe()

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Connection type of the application; it decides which copy of the segment
# identification dataset is read below.
main_connection_type = app_variables["connection_type_app"]

# Segments are labelled either from recency and frequency only, or from the
# three RFM axes; the lookup dataset and its join key follow that choice.
segment_labeling_components = app_variables["segment_labeling_components_app"]
uses_recency_and_frequency_only = segment_labeling_components == "recency_and_frequency"

segments_identification_dataset = (
    "rf_segments_identication_synced"
    if uses_recency_and_frequency_only
    else "rfm_segments_identication_synced"
)
segments_identification_join_key = ["recency", "frequency"]
if not uses_recency_and_frequency_only:
    segments_identification_join_key.append("monetary_value")

# NOTE(review): the base names above already end with "_synced", so for these
# connection types the dataset read is named "..._synced_synced" -- confirm
# that this dataset exists in the flow and that the double suffix is intended.
segments_identification_dataset_name = segments_identification_dataset
if main_connection_type in ["Redshift", "Synapse", "BigQuery"]:
    segments_identification_dataset_name += "_synced"

segments_identification = dataiku.Dataset(segments_identification_dataset_name)
segments_identification_df = segments_identification.get_dataframe()

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Scoring settings chosen in the application.
scores_computation_technique = app_variables["scores_computation_technique_app"]
n_segments_per_axis = app_variables["n_segments_per_axis_app"]
recency_policy = app_variables["recency_policy_app"]
monetary_value_policy = app_variables["monetary_value_policy_app"]

# Derive, from the recency / monetary value policies, the input columns to
# score, whether each column's scores are reversed, and the mapping from input
# columns to RFM labels.
(
    rfm_original_columns,
    reverse_scores_in_rfm_columns,
    original_columns_to_rfm_labels_mapping,
    rfm_columns,
) = load_rfm_parameters(
    COLUMN_FOR_RECENCY_COMPUTATION,
    recency_policy,
    monetary_value_policy,
)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Build an outlier-free copy of the inputs, using the lower / higher quantile
# thresholds on the RFM input columns.
customer_rfm_inputs_without_outliers_df = remove_dataframe_outliers_based_on_quantiles(
    customer_rfm_inputs_df,
    LOWER_OUTLIERS_QUANTILE_TRESHOLD,
    HIGHER_OUTLIERS_QUANTILE_TRESHOLD,
    rfm_original_columns,
)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Identifiers of the customers kept after outlier removal.
not_outlier_customers = list(customer_rfm_inputs_without_outliers_df["customer_id"])

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Raw values of every RFM input column, for all customers and for the
# outlier-free subset, keyed by input column name.
axes_data_original = {
    axis_label: list(customer_rfm_inputs_df[axis_label])
    for axis_label in rfm_original_columns
}
axes_data_without_outliers = {
    axis_label: list(customer_rfm_inputs_without_outliers_df[axis_label])
    for axis_label in rfm_original_columns
}

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Score every RFM axis with the technique selected in the application.
# Quantiles / clustering models are fitted on the outlier-free values and then
# applied to the values of all customers.
if "quantiles" in scores_computation_technique:
    print("Computing RFM axes scores with quantiles ...")
    segmentation_quantiles_boundaries = compute_segmentation_quantiles_boundaries(n_segments_per_axis)
    axes_quantiles = {
        axis_label: np.quantile(axes_data_without_outliers[axis_label],
                                q=segmentation_quantiles_boundaries)
        for axis_label in rfm_original_columns
    }

    axes_rfm_scores = score_rfm_with_quantiles(rfm_original_columns, axes_data_original, axes_quantiles,
                                               n_segments_per_axis, reverse_scores_in_rfm_columns)

elif "kmeans_clustering" in scores_computation_technique:
    print("Computing RFM axes scores with KMeans clustering ...")
    axes_kmeans_clustering_models = train_kmeans_models_for_rfm(rfm_original_columns, axes_data_without_outliers,
                                                                n_segments_per_axis)
    axes_rfm_scores = score_rfm_with_k_means(rfm_original_columns, axes_data_original,
                                             axes_kmeans_clustering_models, reverse_scores_in_rfm_columns)

else:
    # Without this guard `axes_rfm_scores` is never bound and the recipe only
    # fails further down with a NameError that hides the real cause.
    raise ValueError(
        "Unsupported scores computation technique {!r}: it must contain "
        "'quantiles' or 'kmeans_clustering'.".format(scores_computation_technique)
    )

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Store each axis' scores in the customers dataframe, in the column named
# after the RFM label mapped to that input column.
for axis_label in rfm_original_columns:
    score_column = original_columns_to_rfm_labels_mapping[axis_label]
    customer_rfm_inputs_df[score_column] = axes_rfm_scores[axis_label]

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Enrich the dataframe with global scores and segments twice: once from the
# three RFM axes (label "rfm"), once with "density" added (label "rfmd").
print("Computing the RFM global scores and segments ...")
global_score_definitions = (
    (["recency", "frequency", "monetary_value"], "rfm"),
    (["recency", "frequency", "monetary_value", "density"], "rfmd"),
)
for score_axes, score_label in global_score_definitions:
    customer_rfm_inputs_df = enrich_dataframe_with_rfm_global_scores_and_segments(
        customer_rfm_inputs_df, n_segments_per_axis, score_axes, score_label
    )

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Save, for every RFM label and score (numbered from 1), the value that
# characterises the score into the project "standard" variables: the quantile
# when scoring with quantiles, the cluster centroid when scoring with KMeans.
# Values are listed in reverse order for the axes whose scores are reversed.
if "quantiles" in scores_computation_technique:
    for axis_label in rfm_original_columns:
        rfm_label = original_columns_to_rfm_labels_mapping[axis_label]
        axis_quantiles = axes_quantiles[axis_label]
        if reverse_scores_in_rfm_columns[axis_label]:
            axis_quantiles = np.flip(axis_quantiles)
        for rfm_score, quantile_value in enumerate(axis_quantiles, start=1):
            variables["standard"]["{}_{}_quantile_app".format(rfm_label, rfm_score)] = quantile_value


elif "kmeans_clustering" in scores_computation_technique:
    for axis_label in rfm_original_columns:
        rfm_label = original_columns_to_rfm_labels_mapping[axis_label]
        clustering_model = axes_kmeans_clustering_models[axis_label]
        # `reshape(-1)` normally returns a view on `cluster_centers_`, so an
        # in-place `.sort()` would reorder the centroids of the fitted model
        # itself. `np.sort` works on a copy and leaves the model untouched.
        cluster_centroids = np.sort(clustering_model.cluster_centers_.reshape(-1))
        if reverse_scores_in_rfm_columns[axis_label]:
            cluster_centroids = np.flip(cluster_centroids)

        for rfm_score, centroid_value in enumerate(cluster_centroids, start=1):
            variables["standard"]["{}_{}_centroid_app".format(rfm_label, rfm_score)] = centroid_value

# Persist the variables on the project, whichever branch ran.
project.set_variables(variables)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# When business rules are part of the selected technique, the scores they
# produce replace the previously computed scores of the axes they cover.
# NOTE(review): the global scores / segments added earlier in this recipe are
# not recomputed after this override -- confirm this is intended.
if "business_rules" in scores_computation_technique:
    rfm_business_rules_scope = app_variables["rfm_business_rules_scope_app"]
    print(f"Appliying business to assess '{rfm_business_rules_scope}' ...")

    axes_business_thresholds = load_business_rules(
        project, variables, COLUMN_FOR_RECENCY_COMPUTATION
    )
    axes_business_rfm_scores = score_rfm_with_business_rules(
        rfm_original_columns,
        axes_data_original,
        axes_business_thresholds,
        reverse_scores_in_rfm_columns,
    )

    # Only the axes returned by the business rules scoring are overwritten.
    for axis_label, business_scores in axes_business_rfm_scores.items():
        score_column = original_columns_to_rfm_labels_mapping[axis_label]
        customer_rfm_inputs_df[score_column] = business_scores

    print(f"Business rules successfully to assess '{rfm_business_rules_scope}'!")

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Formatting the date columns in the DSS non-ambiguous date format:
DATE_COLUMNS = ["rfm_reference_month_start", "rfm_reference_date",
                "first_transaction_date", "last_transaction_date"]
for date_column in DATE_COLUMNS:
    # Each value is converted individually to its DSS string representation.
    customer_rfm_inputs_df[date_column] = [
        from_datetime_to_dss_string_date(raw_date)
        for raw_date in customer_rfm_inputs_df[date_column]
    ]

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Left-join the segment identification table on the score columns, so every
# customer row is kept, then discard the "segment_color" column it brings.
customer_rfm_inputs_df = (
    customer_rfm_inputs_df
    .merge(segments_identification_df, how="left", on=segments_identification_join_key)
    .drop(columns=["segment_color"])
)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Write recipe outputs:
# 1) every customer, outliers included.
customer_rfm_segments = dataiku.Dataset("customer_rfm_segments")
customer_rfm_segments.write_dataframe(
    customer_rfm_inputs_df,
    infer_schema=False,
    dropAndCreate=True,
)
#customer_rfm_segments.write_with_schema(customer_rfm_inputs_df)

# 2) only the customers that were kept by the outlier removal.
is_not_outlier = customer_rfm_inputs_df["customer_id"].isin(not_outlier_customers)
no_outlier_customer_rfm_inputs_df = customer_rfm_inputs_df[is_not_outlier]

no_outliers_customer_rfm_segments = dataiku.Dataset("no_outliers_customer_rfm_segments")
#no_outliers_customer_rfm_segments.write_with_schema(no_outlier_customer_rfm_inputs_df)
no_outliers_customer_rfm_segments.write_dataframe(
    no_outlier_customer_rfm_inputs_df,
    infer_schema=False,
    dropAndCreate=True,
)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Set the column datatypes of the "customer_rfm_segments" dataset:
# customer_id becomes a "string", then every date column becomes a "date".
# NOTE(review): "no_outliers_customer_rfm_segments" gets no such casting --
# confirm whether it needs the same treatment.
output_column_datatypes = [("customer_id", "string")]
output_column_datatypes += [(date_column, "date") for date_column in DATE_COLUMNS]

for column_name, column_datatype in output_column_datatypes:
    change_dataset_column_datatype(project, "customer_rfm_segments", column_name, column_datatype)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Plotting RFM boxplots :
#rfm_box_plots = generate_rfm_box_plots(app_variables, rfm_original_columns, original_columns_to_rfm_labels_mapping,
#                                       customer_rfm_inputs_df, not_outlier_customers, n_segments_per_axis)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Saving custom charts as static insights :
#project_key = dataiku.get_custom_variables()["projectKey"]
#for static_insight_id in rfm_box_plots.keys():
#    plotly_figure = rfm_box_plots[static_insight_id]
#    dataiku.insights.save_plotly(id=static_insight_id, figure=plotly_figure, project_key=project_key)