# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# -*- coding: utf-8 -*-
import dataiku
import pandas as pd
from clv_forecast.utils.generic import reformat_group
from clv_forecast.constants import inference_column_mapping
from sklearn.metrics import classification_report

from dku_utils.projects.project_commons import get_current_project_and_variables

# Load the scored test set that holds the true and predicted cluster columns.
test_output = dataiku.Dataset("test_output")
test_output_df = test_output.get_dataframe()

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Project handle and its variables; the variables are updated and saved
# further down once the best model is known.
project, variables = get_current_project_and_variables()
n_cluster = variables["standard"]["number_value_cluster_app"]

# Column renames applied to the classification reports: one entry per cluster
# (1-based in the display name) plus space-free names for the averaged columns.
# NOTE(review): reformat_group presumably returns the cluster label as it
# appears in the report columns -- confirm against clv_forecast.utils.generic.
rename_dict = {
    reformat_group(cluster): f"Metrics for CLV-group-{cluster+1}"
    for cluster in range(n_cluster)
}
rename_dict.update(
    {
        "micro avg": "micro_avg",
        "macro avg": "macro_avg",
        "weighted avg": "weighted_avg",
    }
)

# Ground truth followed by the predictions of each candidate approach.
true_label = test_output_df["future_clv_cluster"].tolist()
ml_regression_cluster = test_output_df["ml_regression_future_clv_cluster"].tolist()
ml_classification_cluster = test_output_df["prediction_future_clv_cluster"].tolist()
lifetime_regression_cluster = test_output_df["lifetime_regression_future_clv_cluster"].tolist()
naive_cluster = test_output_df["current_clv_cluster"].tolist()

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
def _report_frame(expected_labels, predicted_labels, model_name):
    """Return the classification report of one model as a tagged DataFrame.

    The report is requested as a dict and loaded into a DataFrame whose index
    holds the metric names and whose columns are the report entries. A
    constant ``model`` column is added so the frames can be stacked later.
    """
    report = classification_report(
        y_true=expected_labels,
        y_pred=predicted_labels,
        output_dict=True,
    )
    frame = pd.DataFrame(report)
    frame["model"] = model_name
    return frame


# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# One report per candidate approach, each scored against the same true labels.
dict_regression = _report_frame(true_label, ml_regression_cluster, "regression")
dict_classification = _report_frame(true_label, ml_classification_cluster, "classification")
dict_lifetime_regression = _report_frame(true_label, lifetime_regression_cluster, "lifetime")
dict_naive = _report_frame(true_label, naive_cluster, "naive")

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Stack the per-model reports; the former index (the report's metric names)
# becomes a regular "metric" column.
manual_df = (
    pd.concat([dict_classification, dict_regression, dict_lifetime_regression, dict_naive])
    .reset_index()
    .rename(columns={"index": "metric"})
)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
manual_df = manual_df.rename(columns=rename_dict)

# Keep only the rows of the metric selected in the project variables and find
# the highest weighted average among the models.
metric_to_optimize = variables["standard"]["classification_metric_to_optimize_app"]
metric_rows = manual_df[manual_df["metric"] == metric_to_optimize]
max_metric = metric_rows["weighted_avg"].max()

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# If the configured metric matches no row (or every value is NaN), max_metric
# is NaN and nothing compares equal to it; fail with an explicit message
# instead of an IndexError on an empty selection.
best_rows = metric_rows[metric_rows["weighted_avg"] == max_metric]
if best_rows.empty:
    raise ValueError(
        f"No usable 'weighted_avg' value found for metric {metric_to_optimize!r}; "
        f"available metrics: {manual_df['metric'].unique().tolist()}"
    )
# On ties the first row wins, i.e. the concatenation order above.
best_model = best_rows["model"].values[0]

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Save the winning model and its inference column in the project variables.
# NOTE(review): assumes inference_column_mapping has an entry for every model
# tag, including "naive" -- confirm in clv_forecast.constants.
inference_column = inference_column_mapping[best_model]
variables["standard"]["inference_result_column"] = inference_column
variables["standard"]["best_model"] = best_model
project.set_variables(variables)

# -------------------------------------------------------------------------------- NOTEBOOK-CELL: CODE
# Recipe outputs: one dataset per reported model, plus the combined table.
classification_manual_metrics = dataiku.Dataset("classification_manual_metrics")
regression_classification_manual_metrics = dataiku.Dataset("regression_classification_manual_metrics")
lifetime_classification_manual_metrics = dataiku.Dataset("lifetime_classification_manual_metrics")
manual_metrics = dataiku.Dataset("manual_metrics")

# NOTE(review): unlike manual_df, the per-model frames never go through
# reset_index, so the metric names only live in their index -- confirm whether
# write_with_schema persists the index, otherwise the rows are unlabeled.
for output_dataset, metrics_frame in (
    (classification_manual_metrics, dict_classification),
    (regression_classification_manual_metrics, dict_regression),
    (lifetime_classification_manual_metrics, dict_lifetime_regression),
):
    output_dataset.write_with_schema(metrics_frame.rename(columns=rename_dict))

# The combined table was already renamed when it was built.
manual_metrics.write_with_schema(manual_df)
