# encoding: utf-8

"""
Execute a prediction scoring recipe in Keras mode
Must be called in a Flow environment
"""
import logging
import sys

import dataiku
from dataiku.base.folder_context import build_folder_context
from dataiku.base.utils import ErrorMonitoringWrapper
from dataiku.core import dkujson
from dataiku.doctor.deep_learning.keras_support import scored_dataset_generator
from dataiku.base.remoterun import read_dku_env_and_set
from dataiku.doctor.utils.gpu_execution import log_nvidia_smi_if_use_gpu

# Configure root logging for this script: INFO level and above, each line
# prefixed with a timestamp and the level name.
logging.basicConfig(
    format='%(asctime)s %(levelname)s %(message)s',
    level=logging.INFO,
)

# Module-level logger used by main().
logger = logging.getLogger(__name__)

def main(model_folder, input_dataset_smartname, output_dataset_smartname, recipe_desc, script,
         preparation_output_schema, cond_outputs=None, fmi=None):
    """Score an input dataset with a Keras model and write the results to an output dataset.

    Scored dataframes are produced batch by batch by ``scored_dataset_generator``
    and each one is appended to the output dataset's writer as it arrives.

    :param model_folder: location of the model folder, passed to ``build_folder_context``.
    :param input_dataset_smartname: smart name of the dataset to score.
    :param output_dataset_smartname: smart name of the dataset the scored rows are written to.
    :param recipe_desc: recipe descriptor, passed to ``log_nvidia_smi_if_use_gpu`` and
        forwarded to ``scored_dataset_generator``.
    :param script: forwarded as-is to ``scored_dataset_generator``.
    :param preparation_output_schema: forwarded as-is to ``scored_dataset_generator``.
    :param cond_outputs: optional, forwarded as-is to ``scored_dataset_generator``.
    :param fmi: optional, forwarded as-is to ``scored_dataset_generator``
        (presumably a full model id — TODO confirm).
    """
    log_nvidia_smi_if_use_gpu(recipe_desc=recipe_desc)
    model_folder_context = build_folder_context(model_folder)
    # The two positional False values are flags of scored_dataset_generator that are
    # disabled for this recipe — NOTE(review): check its signature for their meaning.
    output_generator = scored_dataset_generator(model_folder_context, dataiku.Dataset(input_dataset_smartname),
                                                recipe_desc, script, preparation_output_schema, cond_outputs,
                                                False, False, fmi)

    output_dataset = dataiku.Dataset(output_dataset_smartname)
    logger.info("Starting writer")
    with output_dataset.get_writer() as writer:
        logger.info("Starting to iterate")
        for output_dict in output_generator:
            output_df = output_dict["scored"]
            # Lazy %-style arguments: the message is only formatted if INFO is enabled.
            logger.info("Generator generated a df %s", output_df.shape)
            writer.write_dataframe(output_df)
            logger.info("Output df written")


if __name__ == "__main__":
    read_dku_env_and_set()

    with ErrorMonitoringWrapper():
        # Not using global "managedFolderId" (argv[3]) for legacy reason
        main(sys.argv[1], sys.argv[2], sys.argv[4],
             dkujson.load_from_filepath(sys.argv[5]),
             dkujson.load_from_filepath(sys.argv[6]),
             dkujson.load_from_filepath(sys.argv[7]),
             dkujson.load_from_filepath(sys.argv[8]),
             sys.argv[9])
