# encoding: utf-8
"""
Execute a clustering training recipe in PyRegular mode
Must be called in a Flow environment
"""
import logging
import sys

from dataiku.base.folder_context import build_folder_context
from dataiku.base.remoterun import read_dku_env_and_set
from dataiku.base.utils import ErrorMonitoringWrapper
from dataiku.doctor import step_constants
from dataiku.doctor import utils
from dataiku.doctor.diagnostics import diagnostics, default_diagnostics
from dataiku.doctor.clustering_entrypoints import clustering_train_score_save
from dataiku.doctor.preprocessing_collector import ClusteringPreprocessingDataCollector
from dataiku.doctor.preprocessing_handler import ClusteringPreprocessingHandler
from dataiku.doctor.utils import unix_time_millis
from dataiku.doctor.utils.listener import ModelStatusContext
from dataiku.doctor.utils.listener import ProgressListener
from dataiku.doctor.utils.split import df_from_split_desc

logger = logging.getLogger(__name__)


def main(exec_folder):
    """
    Train, score and save a clustering model from a prepared execution folder.

    Files read from the execution folder:
      - ``split/split.json``: split description; its "full" split is loaded
      - ``rpreprocessing_params.json``: preprocessing parameters
      - ``rmodeling_params.json``: modeling parameters
      - ``core_params.json``: core parameters (used to register diagnostics callbacks)

    Steps: load the dataset, collect preprocessing data, fit and apply the
    preprocessing pipeline, then delegate training/scoring/saving to
    ``clustering_train_score_save``. Progress is tracked with a
    ``ProgressListener``; timings are written at the end with
    ``utils.write_done_traininfo``.

    :param exec_folder: location of the execution folder, passed to ``build_folder_context``
    """
    start = unix_time_millis()

    exec_folder_context = build_folder_context(exec_folder)
    split_folder_context = exec_folder_context.get_subfolder_context("split")
    split_desc = split_folder_context.read_json("split.json")

    listener = ProgressListener(ModelStatusContext(exec_folder_context, start))
    preprocessing_params = exec_folder_context.read_json("rpreprocessing_params.json")
    modeling_params = exec_folder_context.read_json("rmodeling_params.json")
    core_params = exec_folder_context.read_json("core_params.json")

    default_diagnostics.register_clustering_callbacks(core_params)

    with listener.push_step(step_constants.ProcessingStep.STEP_LOADING_SRC):
        # No target variable for clustering: the whole "full" split is loaded
        train_df = df_from_split_desc(split_desc, "full", split_folder_context, preprocessing_params["per_feature"])
        diagnostics.on_load_train_dataset_end(df=train_df, target_variable=None)
        # Lazy %-args: the message is only formatted if the INFO record is emitted
        logger.info("Loaded full df: shape=(%d,%d)", *train_df.shape)

    with listener.push_step(step_constants.ProcessingStep.STEP_COLLECTING_PREPROCESSING_DATA):
        collector = ClusteringPreprocessingDataCollector(train_df, preprocessing_params)
        collector_data = collector.build()

    preproc_handler = ClusteringPreprocessingHandler({}, preprocessing_params, exec_folder_context)
    preproc_handler.collector_data = collector_data
    pipeline = preproc_handler.build_preprocessing_pipeline()

    with listener.push_step(step_constants.ProcessingStep.STEP_PREPROCESS_SRC):
        # Snapshot the index before preprocessing; it is handed to
        # clustering_train_score_save alongside the transformed data
        # (presumably because the pipeline may drop rows — confirm).
        orig_index = train_df.index.copy()
        transformed_train = pipeline.fit_and_process(train_df)
        preproc_handler.save_data()
        preproc_handler.report(pipeline)
        diagnostics.on_preprocess_train_dataset_end(multiframe=transformed_train["TRAIN"])

    # Marks the end of loading/preprocessing and the start of model training
    start_train = unix_time_millis()

    clustering_train_score_save(
        transformed_train, orig_index,
        preprocessing_params, modeling_params, exec_folder_context, listener, pipeline,
    )

    end = unix_time_millis()

    utils.write_done_traininfo(exec_folder_context, start, start_train, end, listener.to_jsonifiable())


if __name__ == "__main__":
    # Script entry point: set up root logging, call read_dku_env_and_set(),
    # then run main() on the folder given as the first command-line argument.
    logging.basicConfig(format='%(asctime)s %(levelname)s %(message)s', level=logging.INFO)
    read_dku_env_and_set()
    exec_folder_arg = sys.argv[1]

    # main() is executed inside ErrorMonitoringWrapper (presumably to report
    # failures of the recipe — confirm against its implementation).
    with ErrorMonitoringWrapper():
        main(exec_folder_arg)
