# encoding: utf-8
"""
Execute a clustering training recipe in PyRegular mode
Must be called in a Flow environment
"""
import json
import logging
import numpy as np
import pandas as pd
import sys


import dataiku
from dataiku.base.folder_context import build_folder_context
from dataiku.base.remoterun import read_dku_env_and_set
from dataiku.base.utils import ErrorMonitoringWrapper
from dataiku.doctor import step_constants
from dataiku.doctor import utils
from dataiku.doctor.clustering.clustering_fit import clustering_fit
from dataiku.doctor.preprocessing_collector import ClusteringPreprocessingDataCollector
from dataiku.doctor.preprocessing_handler import ClusteringPreprocessingHandler
from dataiku.doctor.utils import doctor_constants
from dataiku.doctor.utils import unix_time_millis
from dataiku.doctor.utils.listener import ProgressListener
from dataiku.doctor.utils.split import df_from_split_desc_no_normalization

logger = logging.getLogger(__name__)


def main(exec_folder, output_dataset, keptInputColumns):
    """
    Fit a clustering model on the "full" split and write the labelled rows out.

    The function reads ``split/split.json``, ``rpreprocessing_params.json`` and
    ``rmodeling_params.json`` from the execution folder, preprocesses the input,
    fits the model, attaches a ``cluster_labels`` column (and an
    ``anomaly_score`` column when the fit returns anomaly scores) to the
    un-normalized input rows, writes the result to the output dataset and
    finally records the training info in the execution folder.

    :param exec_folder: location of the execution folder, handed to
        ``build_folder_context``
    :param output_dataset: name of the dataset to write, handed to
        ``dataiku.Dataset``
    :param keptInputColumns: list of input column names to keep in the output
        (``cluster_labels`` is always appended), or None to keep every column
    """
    start = unix_time_millis()
    listener = ProgressListener()

    exec_folder_context = build_folder_context(exec_folder)
    split_folder_context = exec_folder_context.get_subfolder_context("split")
    split_desc = split_folder_context.read_json("split.json")

    preprocessing_params = exec_folder_context.read_json("rpreprocessing_params.json")
    modeling_params = exec_folder_context.read_json("rmodeling_params.json")
    per_feature = preprocessing_params["per_feature"]

    with listener.push_step(step_constants.ProcessingStep.STEP_LOADING_SRC):
        input_df = df_from_split_desc_no_normalization(split_desc, "full", split_folder_context, per_feature)
        logger.info("Loaded full df: shape=(%d,%d)", *input_df.shape)
        # The output dataset is built from the raw values, so keep a copy
        # before the frame is normalized for training.
        input_df_copy_unnormalized = input_df.copy()
        input_df = utils.normalize_dataframe(input_df, per_feature)

    with listener.push_step(step_constants.ProcessingStep.STEP_COLLECTING_PREPROCESSING_DATA):
        collector = ClusteringPreprocessingDataCollector(input_df, preprocessing_params)
        collector_data = collector.build()

    preproc_handler = ClusteringPreprocessingHandler({}, preprocessing_params, exec_folder_context)
    preproc_handler.collector_data = collector_data
    pipeline = preproc_handler.build_preprocessing_pipeline()

    with listener.push_step(step_constants.ProcessingStep.STEP_PREPROCESS_SRC):
        transformed_train = pipeline.fit_and_process(input_df)

    start_train = unix_time_millis()

    (clf, actual_params, cluster_labels, anomaly_scores, _) = clustering_fit(modeling_params, transformed_train)

    # Index of the rows that actually went through the model.
    train_index = transformed_train["TRAIN"].index

    # if model has custom labels, use them
    try:
        cluster_names = clf.get_cluster_labels()
    except AttributeError:
        # NOTE(review): names are looked up positionally below, so this
        # fallback assumes the labels are exactly 0..k-1. A negative or
        # non-contiguous label would pick the wrong name or raise IndexError;
        # confirm that clustering_fit guarantees this.
        cluster_names = ["cluster_%s" % i for i in range(len(np.unique(cluster_labels)))]
    cl = pd.Series(data=cluster_labels, name="cluster_labels").map(lambda i: cluster_names[i])
    cl.index = train_index

    # Left join: input rows that are absent from the training index keep a
    # missing (NaN) cluster label.
    final_df = input_df_copy_unnormalized.join(cl, how='left')

    if anomaly_scores is not None:
        scores = pd.Series(anomaly_scores, name="anomaly_score", index=train_index)
        final_df = pd.concat([final_df, scores], axis=1)

    if keptInputColumns is not None:
        # NOTE(review): this selection does not retain the "anomaly_score"
        # column added above - confirm that this is intended.
        final_df = final_df[keptInputColumns + ['cluster_labels']]

    if preprocessing_params["outliers"]["method"] == "CLUSTER":
        # Rows without a label are put in the dedicated outliers cluster.
        # Fill through the DataFrame rather than with an in-place fillna on
        # the selected column: that chained in-place form is not guaranteed to
        # modify final_df (it has no effect under pandas copy-on-write).
        final_df = final_df.fillna({'cluster_labels': doctor_constants.CLUSTER_OUTLIERS})

    dataiku.Dataset(output_dataset).write_from_dataframe(final_df)

    end = unix_time_millis()

    utils.write_done_traininfo(exec_folder_context, start, start_train, end, listener.to_jsonifiable())


if __name__ == "__main__":
    # Script entry point. Command-line arguments:
    #   sys.argv[1]: execution folder
    #   sys.argv[2]: output dataset
    #   sys.argv[3]: (optional) JSON-encoded list of the input columns to keep
    logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')

    read_dku_env_and_set()

    keptInputColumns = None
    if len(sys.argv) > 3 and sys.argv[3]:
        logger.info("Kept input columns: %s", sys.argv[3])
        try:
            keptInputColumns = json.loads(sys.argv[3])
        except ValueError:
            # logger.exception keeps the traceback, which logger.error(e) dropped
            logger.exception("Could not decode the kept input columns as JSON")
            raise Exception("Failed to parse columns to keep, check the logs")
        # main() concatenates this value with a list, so anything other than a
        # list (or null) would only fail after the model has been trained.
        if keptInputColumns is not None and not isinstance(keptInputColumns, list):
            logger.error("Kept input columns must be a JSON list, got %s", type(keptInputColumns).__name__)
            raise Exception("Failed to parse columns to keep, check the logs")

    with ErrorMonitoringWrapper():
        main(sys.argv[1], sys.argv[2], keptInputColumns)
