# encoding: utf-8
"""
Execute a clustering scoring recipe in PyRegular mode.
Must be called in a Flow environment.
"""
import logging
import numpy as np
import pandas as pd
import sys

import dataiku
from dataiku.base.folder_context import build_folder_context
from dataiku.base.utils import ErrorMonitoringWrapper
from dataiku.core import default_project_key
from dataiku.base.remoterun import read_dku_env_and_set
from dataiku.core import debugging
from dataiku.core import dkujson
from dataiku.core import doctor_constants
from dataiku.doctor.utils.model_io import load_model_from_folder
from dataiku.doctor.clustering.clustering_fit import clustering_predict
from dataiku.doctor.preprocessing_handler import ClusteringPreprocessingHandler
from dataiku.doctor.utils.scoring_recipe_utils import add_output_model_metadata
from dataiku.doctor.utils.scoring_recipe_utils import dataframe_iterator
from dataiku.doctor.utils.scoring_recipe_utils import get_dataframe_dtypes_info

logger = logging.getLogger(__name__)


def _load_cluster_name_map(user_meta):
    """Return the user-defined cluster renaming map found in ``user_meta``.

    :param dict user_meta: content of the model's ``user_meta.json``.
    :return: dict mapping each key of ``user_meta["clusterMetas"]`` to the
        ``name`` stored under it; empty when ``clusterMetas`` is absent.
    """
    if "clusterMetas" not in user_meta:
        return {}
    cluster_metas = user_meta["clusterMetas"]
    logger.info("Cluster metas: %s", cluster_metas)
    return {cluster_id: cluster_data["name"] for cluster_id, cluster_data in cluster_metas.items()}


def _post_process_model(clf, user_meta):
    """Call ``clf.post_process(user_meta)`` when the model defines that method.

    Models that cannot be post-processed do not define ``post_process``; for
    them this is a no-op. The method is looked up with ``getattr`` so that an
    ``AttributeError`` raised *inside* ``post_process`` propagates instead of
    being mistaken for a missing method.
    """
    post_process = getattr(clf, "post_process", None)
    if post_process is None:
        return
    logger.info("Post-processing model")
    post_process(user_meta)


def _build_cluster_naming(clf, cluster_name_map):
    """Return a function mapping a predicted cluster index to its output label.

    If the model defines ``get_cluster_labels()``, index ``i`` maps to
    ``custom_labels[i]``; otherwise it maps to ``"cluster_<i>"``. That name is
    then replaced by its user-defined name from ``cluster_name_map`` when one
    exists.

    :param clf: the loaded clustering model.
    :param dict cluster_name_map: user renaming map (see ``_load_cluster_name_map``).
    """
    # getattr (rather than try/except AttributeError around the call) so that an
    # error raised inside get_cluster_labels is not silently turned into the
    # default "cluster_<i>" naming.
    get_cluster_labels = getattr(clf, "get_cluster_labels", None)
    if get_cluster_labels is not None:
        custom_labels = get_cluster_labels()

        def naming(i):
            name = custom_labels[i]
            return cluster_name_map.get(name, name)
    else:
        def naming(i):
            name = "cluster_%i" % i
            return cluster_name_map.get(name, name)
    return naming


def main(model_folder, input_dataset_smartname, output_dataset_smartname, recipe_desc, script, preparation_output_schema, fmi=None):
    """Score a dataset with a saved clustering model and write the results.

    The input dataset, with the preparation steps of ``script`` applied, is
    read in batches. Each batch goes through the model's preprocessing pipeline
    and ``clustering_predict``. The (optionally filtered) input columns are
    written to the output dataset together with a ``cluster_labels`` column,
    plus an ``anomaly_score`` column when the model returns anomaly scores.
    Rows dropped by preprocessing get a missing label, or the outliers cluster
    name when the outliers method is ``"CLUSTER"``.

    :param model_folder: model folder location, passed to ``build_folder_context``.
        It must hold ``rpreprocessing_params.json``, ``rmodeling_params.json``,
        ``collector_data.json``, ``user_meta.json`` and the saved model.
    :param input_dataset_smartname: name of the dataset to score.
    :param output_dataset_smartname: name of the dataset to write.
    :param dict recipe_desc: recipe settings. Reads ``pythonBatchSize`` (default
        100000), ``filterInputColumns`` / ``keptInputColumns`` and
        ``outputModelMetadata``.
    :param dict script: preparation script; its ``steps`` are applied to the input.
    :param preparation_output_schema: schema of the prepared input dataset.
    :param fmi: forwarded to ``add_output_model_metadata`` when
        ``outputModelMetadata`` is enabled.
    """
    input_dataset = dataiku.Dataset(input_dataset_smartname)
    logger.info("Will do preparation, output schema: %s", preparation_output_schema)
    input_dataset.set_preparation_steps(script["steps"], preparation_output_schema,
                                        context_project_key=default_project_key())

    model_folder_context = build_folder_context(model_folder)
    preprocessing_params = model_folder_context.read_json("rpreprocessing_params.json")
    modeling_params = model_folder_context.read_json("rmodeling_params.json")
    collector_data = model_folder_context.read_json("collector_data.json")

    # Name remapping
    user_meta = model_folder_context.read_json("user_meta.json")
    cluster_name_map = _load_cluster_name_map(user_meta)

    preprocessing_handler = ClusteringPreprocessingHandler({}, preprocessing_params, model_folder_context)
    preprocessing_handler.collector_data = collector_data
    # For plugins, the clf can define preprocessors, and to deserialize those, we need to load the plugin files in the path for unpickling.
    # Hence why we load the model before building the pipeline
    clf = load_model_from_folder(model_folder_context, is_prediction=False)
    pipeline = preprocessing_handler.build_preprocessing_pipeline()

    batch_size = recipe_desc.get("pythonBatchSize", 100000)
    logger.info("Scoring with batch size: %s", batch_size)

    # Post-processing runs before the naming is built, as in the original flow,
    # in case get_cluster_labels depends on it.
    _post_process_model(clf, user_meta)
    naming = _build_cluster_naming(clf, cluster_name_map)

    per_feature = preprocessing_params["per_feature"]
    filter_input_columns = recipe_desc.get("filterInputColumns", False)
    output_model_metadata = recipe_desc.get("outputModelMetadata", False)

    def output_generator():
        logger.info("Start output generator ...")

        names, dtypes, parse_date_columns = get_dataframe_dtypes_info(preparation_output_schema, per_feature)

        for input_df, input_df_copy_unnormalized in dataframe_iterator(
            input_dataset, names, dtypes, parse_date_columns, per_feature,
            batch_size=batch_size
        ):
            if filter_input_columns:
                input_df_copy_unnormalized = input_df_copy_unnormalized[recipe_desc["keptInputColumns"]]

            logger.info("Processing it")
            transformed = pipeline.process(input_df)

            if transformed["TRAIN"].shape()[0] == 0:
                logger.info("Batch of size {} were dropped by preprocessing".format(input_df_copy_unnormalized.shape[0]))
                final_df = input_df_copy_unnormalized.copy()
                final_df["cluster_labels"] = np.nan
            else:
                logger.info("Applying it")
                (labels_arr, anomaly_scores) = clustering_predict(modeling_params, clf, transformed)

                train_index = transformed["TRAIN"].index
                cluster_labels = pd.Series(labels_arr, name="cluster_labels").map(naming)
                cluster_labels.index = train_index
                # Left join keeps rows dropped by preprocessing, with a NaN label.
                final_df = input_df_copy_unnormalized.join(cluster_labels, how='left')
                if anomaly_scores is not None:
                    final_df = pd.concat([final_df, pd.Series(anomaly_scores, index=train_index, name="anomaly_score")], axis=1)

            if preprocessing_params["outliers"]["method"] == "CLUSTER":
                outliers_cluster_name = cluster_name_map.get(doctor_constants.CLUSTER_OUTLIERS, doctor_constants.CLUSTER_OUTLIERS)
                # Assign back instead of Series.fillna(inplace=True) on a column
                # selection: under pandas Copy-on-Write that chained in-place call
                # does not update final_df.
                final_df["cluster_labels"] = final_df["cluster_labels"].fillna(outliers_cluster_name)

            logger.info("Done predicting it")

            if output_model_metadata:
                add_output_model_metadata(final_df, fmi)

            yield final_df

    output_dataset = dataiku.Dataset(output_dataset_smartname)
    logger.info("Starting writer")
    with output_dataset.get_writer() as writer:
        logger.info("Starting to iterate")
        for output_df in output_generator():
            logger.info("Generator generated a df %s", output_df.shape)
            writer.write_dataframe(output_df)
            logger.info("Output df written")


if __name__ == "__main__":
    debugging.install_handler()
    logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
    read_dku_env_and_set()

    with ErrorMonitoringWrapper():
        # Command-line arguments, in order:
        #   1. model folder
        #   2. input dataset smartname
        #   3. output dataset smartname
        #   4. path to the recipe desc JSON file
        #   5. path to the preparation script JSON file
        #   6. path to the preparation output schema JSON file
        #   7. (optional) fmi, forwarded to main(); defaults to None like main's own default
        fmi = sys.argv[7] if len(sys.argv) > 7 else None
        main(sys.argv[1], sys.argv[2], sys.argv[3],
             dkujson.load_from_filepath(sys.argv[4]),
             dkujson.load_from_filepath(sys.argv[5]),
             dkujson.load_from_filepath(sys.argv[6]),
             fmi)
