import logging
import os
import sys
import pandas as pd

from dataiku.base.folder_context import build_folder_context
from dataiku.base.utils import ErrorMonitoringWrapper
from dataiku.core.managed_folder import Folder
from dataiku import Dataset
from dataiku import default_project_key
from dataiku.base.remoterun import read_dku_env_and_set
from dataiku.core import dkujson
from dataiku.doctor.deephub import builtins
from dataiku.doctor.deephub.deephub_params import DeepHubScoringParams
from dataiku.doctor.deephub.deephub_scoring import DeepHubScoringHandler
from dataiku.doctor.deephub.utils.file_utils import ManagedFolderFilesReader
from dataiku.doctor.utils.gpu_execution import log_nvidia_smi_if_use_gpu
from dataiku.doctor.utils.scoring_recipe_utils import add_output_model_metadata
from dataiku.doctor.utils.scoring_recipe_utils import dataframe_iterator
from dataiku.doctor.utils.scoring_recipe_utils import get_dataframe_dtypes_info

# Module-level logger, registered under a fixed name
logger = logging.getLogger("deephub_scoring_recipe")
# Number of input rows read per chunk. The chunk size is enforced (not configurable) because the data is basically
# a list of file paths, so it has a low impact on the scoring.
# What is exposed to the user is the batch size used to score the data.
DATA_CHUNKSIZE = 100000


def main(model_folder, input_dataset_smartname, managed_folder_smart_id, output_dataset_smartname, recipe_desc, script,
         preparation_output_schema, fmi=None):
    """
    Score a prepared input dataset with a deephub model and write the result to the output dataset.

    The input dataset is streamed through the preparation steps in chunks of DATA_CHUNKSIZE rows. Each chunk is
    scored, the kept input columns are concatenated with the prediction columns, and the result is written to the
    output dataset.

    :param model_folder: location of the model folder; "core_params.json", "rpreprocessing_params.json",
                         "rmodeling_params.json" and "user_meta.json" are read from it
    :param input_dataset_smartname: name used to build the input Dataset
    :param managed_folder_smart_id: identifier used to build the managed Folder the scored files are read from
    :param output_dataset_smartname: name used to build the output Dataset
    :param recipe_desc: recipe description; the keys "filterInputColumns", "keptInputColumns" and
                        "outputModelMetadata" are read here, and the whole object is also handed to the scoring
                        params builder
    :param script: preparation script; its "steps" entry is applied to the input dataset
    :param preparation_output_schema: schema of the preparation output
    :param fmi: model identifier passed to add_output_model_metadata when "outputModelMetadata" is enabled
                (presumably the full model id -- TODO confirm)
    """
    builtins.load()
    logger.info("Starting deephub scoring")

    output_dataset = Dataset(output_dataset_smartname)

    model_folder_context = build_folder_context(model_folder)
    core_params = model_folder_context.read_json("core_params.json")
    preprocessing_params = model_folder_context.read_json("rpreprocessing_params.json")
    modeling_params = model_folder_context.read_json("rmodeling_params.json")
    user_meta = model_folder_context.read_json("user_meta.json")

    deephub_params = DeepHubScoringParams.build_for_scoring_recipe(modeling_params,
                                                                   preprocessing_params,
                                                                   core_params, model_folder_context,
                                                                   os.getcwd(), user_meta, recipe_desc)
    log_nvidia_smi_if_use_gpu(gpu_config=deephub_params.gpu_config)
    deephub_params.init_deephub_context()

    # No need to close the files_reader, as this one has no side effect
    files_reader = ManagedFolderFilesReader(Folder(managed_folder_smart_id))

    scoring_handler = DeepHubScoringHandler(deephub_params)

    # Obtain a streamed result of the preparation
    input_dataset = Dataset(input_dataset_smartname)
    # Lazy %-style arguments: the message is only formatted if the record is actually emitted
    logger.info("Will do preparation, output schema: %s", preparation_output_schema)
    input_dataset.set_preparation_steps(script["steps"], preparation_output_schema,
                                        context_project_key=default_project_key())

    def output_generator():
        """
        Score the prepared input chunk by chunk.

        :return: one dataframe per input chunk, holding the kept input columns followed by the prediction columns
        :rtype: Iterator[pd.DataFrame]
        """
        logger.info("Start output generator ...")

        per_feature = preprocessing_params["per_feature"]

        # TODO @deephub: this is an approximate copy-paste of what we do in other scoring recipes
        # (e.g reg_training_recipe.py), but for now we do not really enforce dtypes for deephub, to be
        # modified once we do
        names, dtypes, parse_date_columns = get_dataframe_dtypes_info(
            preparation_output_schema, per_feature,
            prediction_type=core_params["prediction_type"]
        )

        # These recipe options do not change between chunks, so read them once
        filter_input_columns = recipe_desc.get("filterInputColumns", False)
        output_model_metadata = recipe_desc.get("outputModelMetadata", False)

        # normalize=False because normalization is only useful for date columns
        for input_df, input_df_copy_unnormalized in dataframe_iterator(
            input_dataset, names, dtypes, parse_date_columns, per_feature,
            batch_size=DATA_CHUNKSIZE, normalize=False,
        ):
            pred_df = scoring_handler.score(input_df, files_reader)
            # NOTE(review): the return value is ignored, so this presumably updates pred_df in place -- confirm
            scoring_handler.serialize_prediction_df(pred_df)

            # Input columns that clash with a prediction column are dropped: the prediction wins
            if filter_input_columns:
                candidate_columns = recipe_desc["keptInputColumns"]
            else:
                candidate_columns = input_df_copy_unnormalized.columns
            clean_kept_columns = [c for c in candidate_columns if c not in pred_df.columns]

            if output_model_metadata:
                add_output_model_metadata(pred_df, fmi)

            yield pd.concat([input_df_copy_unnormalized[clean_kept_columns], pred_df], axis=1)

    logger.info("Starting writer")
    with output_dataset.get_writer() as writer:
        logger.info("Starting to iterate")
        for output_df in output_generator():
            logger.info("Generator generated a df %s", output_df.shape)
            writer.write_dataframe(output_df)
            logger.info("Output df written")


if __name__ == "__main__":
    # Script entry point: configure logging, load the DSS environment, then run the recipe
    logging.basicConfig(
        level=logging.INFO,
        format='[%(asctime)s] [%(process)s/%(threadName)s] [%(levelname)s] [%(name)s] %(message)s',
    )
    read_dku_env_and_set()

    cli_args = sys.argv[1:]
    received = len(cli_args)
    if received != 8:
        raise Exception("Expected 8 arguments, received {}".format(received))

    # Positional command-line arguments, in the order expected by main()
    (model_folder_arg, input_dataset_arg, managed_folder_arg, output_dataset_arg,
     recipe_desc_path, script_path, preparation_schema_path, fmi_arg) = cli_args

    with ErrorMonitoringWrapper():
        # The three JSON arguments are file paths, loaded here before being handed to main()
        loaded_recipe_desc = dkujson.load_from_filepath(recipe_desc_path)
        loaded_script = dkujson.load_from_filepath(script_path)
        loaded_schema = dkujson.load_from_filepath(preparation_schema_path)
        main(model_folder_arg, input_dataset_arg, managed_folder_arg, output_dataset_arg,
             loaded_recipe_desc, loaded_script, loaded_schema, fmi_arg)
