# encoding: utf-8
"""
Single-thread hoster for a MLflow Pyfunc predictor
"""

import sys
import time
import logging

import numpy as np
import pandas as pd

from dataiku.core import debugging
from dataiku.base.folder_context import build_noop_folder_context
from dataiku.external_ml.mlflow.predictor import MLflowPredictor
from dataiku.base.utils import watch_stdin, get_json_friendly_error
from dataiku.base.socket_block_link import JavaLink, parse_javalink_args
from dataiku.doctor.individual_explainer import DEFAULT_SHAPLEY_BACKGROUND_SIZE

# Configure the root logger once at import time: INFO and above, each line
# prefixed with the timestamp and the level name.
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
# NOTE(review): presumably installs a process-level debugging hook provided by
# dataiku.core.debugging — confirm what it registers before relying on it.
debugging.install_handler()


def raw_mlflow_pred_with_prediction_field_to_dicts_list(pred_df, nb_records, explanations_cols=None):
    """
    Convert a raw prediction frame that has a "prediction" column into one dict per input record.

    The frame is aligned on the positions 0..nb_records-1 through an outer merge on the index,
    so every input record for which the frame has no row is reported as ignored.

    :param pred_df: prediction DataFrame indexed by input record position; must contain a
        "prediction" column. It is not modified.
    :param nb_records: number of records that were sent to the model
    :param explanations_cols: unused, kept so existing callers keep working
    :return: list of dicts, either {"prediction": ..., "ignored": False} or
        {"ignored": True, "ignoreReason": "IGNORED_BY_MODEL"}
    """
    # Tag the rows on a copy: writing the flag with .loc would add an "ignored" column to the
    # caller's frame as a side effect.
    flagged_df = pred_df.assign(ignored=False)
    aligned_df = pd.DataFrame(index=range(0, nb_records)).merge(
        flagged_df, how='outer', left_index=True, right_index=True)
    # Rows created by the outer merge have no flag (NaN): those are the ignored records.
    # Assigning the column back avoids the chained in-place fillna, which does not write
    # through to the frame when pandas Copy-on-Write is active.
    aligned_df["ignored"] = aligned_df["ignored"].isna()
    logging.info('pre to_dict: %s', aligned_df)

    results_columns = ["prediction", "ignored"]

    dicts = []
    for record in aligned_df[results_columns].to_dict(orient="records"):
        if record['ignored']:
            dicts.append({'ignored': True, 'ignoreReason': "IGNORED_BY_MODEL"})
        else:
            dicts.append(record)
    return dicts


def raw_mlflow_pred_without_prediction_field_to_dicts_list(pred_df, nb_records):
    """
    Convert a raw prediction frame without a "prediction" column into one dict per input record.

    The frame is aligned on the positions 0..nb_records-1 through an outer merge on the index,
    so every input record for which the frame has no row is reported as ignored. For the other
    records, the whole row (all columns, plus the "ignored" flag set to False) is wrapped under
    the "prediction" key.

    :param pred_df: prediction DataFrame indexed by input record position. It is not modified.
    :param nb_records: number of records that were sent to the model
    :return: list of dicts, either {"prediction": {<column>: <value>, ..., "ignored": False}} or
        {"ignored": True, "ignoreReason": "IGNORED_BY_MODEL"}
    """
    # Tag the rows on a copy: writing the flag with .loc would add an "ignored" column to the
    # caller's frame as a side effect.
    flagged_df = pred_df.assign(ignored=False)
    aligned_df = pd.DataFrame(index=range(0, nb_records)).merge(
        flagged_df, how='outer', left_index=True, right_index=True)
    # Rows created by the outer merge have no flag (NaN): those are the ignored records.
    # Assigning the column back avoids the chained in-place fillna, which does not write
    # through to the frame when pandas Copy-on-Write is active.
    aligned_df["ignored"] = aligned_df["ignored"].isna()
    logging.info('pre to_dict: %s', aligned_df)

    dicts = []
    for record in aligned_df.to_dict(orient="records"):
        if record['ignored']:
            dicts.append({'ignored': True, 'ignoreReason': "IGNORED_BY_MODEL"})
        else:
            dicts.append({"prediction": record})
    return dicts


def parsed_pred_to_dict(pred_df, nb_records, explanations_df=None):
    """
    Convert a parsed prediction frame into one dict per input record, optionally with explanations.

    The frame is aligned on the positions 0..nb_records-1 through an outer merge on the index,
    so every input record for which the frame has no row is reported as ignored.

    :param pred_df: prediction DataFrame indexed by input record position; must contain a
        "prediction" column. It is not modified.
    :param nb_records: number of records that were sent to the model
    :param explanations_df: optional DataFrame of per-feature explanation values. Its rows are
        read by position, the i-th row being attached to the i-th output record; NaN values
        are left out.
    :return: list of dicts, either {"prediction": ..., "ignored": False[, "explanations": {...}]}
        or {"ignored": True, "ignoreReason": "IGNORED_BY_MODEL"}
    """
    # Tag the rows on a copy: writing the flag with .loc would add an "ignored" column to the
    # caller's frame as a side effect.
    flagged_df = pred_df.assign(ignored=False)
    aligned_df = pd.DataFrame(index=range(0, nb_records)).merge(
        flagged_df, how='outer', left_index=True, right_index=True)
    # Rows created by the outer merge have no flag (NaN): those are the ignored records.
    # Assigning the column back avoids the chained in-place fillna, which does not write
    # through to the frame when pandas Copy-on-Write is active.
    aligned_df["ignored"] = aligned_df["ignored"].isna()
    logging.info('pre to_dict: %s', aligned_df)

    explanations = None if explanations_df is None else explanations_df.to_dict(orient="records")
    results_columns = ["prediction", "ignored"]

    dicts = []
    for i, record in enumerate(aligned_df[results_columns].to_dict(orient="records")):
        if record['ignored']:
            dicts.append({'ignored': True, 'ignoreReason': "IGNORED_BY_MODEL"})
            continue
        if explanations:
            # Only keep the features that actually received an explanation value
            record["explanations"] = {
                feature: value for feature, value in explanations[i].items() if not np.isnan(value)
            }
        dicts.append(record)
    return dicts


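# request looks like: {
#  "columns" : {
#    "col1" : [values],
#    "col2" : [values]
#  },
#  "mlFlowOutputStyle": "RAW" or "PARSED",
#  "explanations": {            (optional, only read when the output style is not "RAW")
#    "enabled": true or false,
#    "method": ...,
#    "nExplanations": ...,
#    "mcSteps": ...             (optional)
#  }
# }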
def handle_predict(predictor, request):
    """
    Score one batch of records with the predictor and build the response dict.

    :param predictor: object exposing predict_raw(), predict(), _compute_explanations() and
        params.model_meta, as used below
    :param request: dict with a "columns" entry (column name -> list of values), a
        "mlFlowOutputStyle" entry and, optionally, an "explanations" entry
        ("enabled", "method", "nExplanations" and optionally "mcSteps")
    :return: dict holding "execTimeUS" plus:
        - for the "RAW" style: "arbitraryArray", "arbitraryObject" or "arbitraryPrimitive",
          chosen from the type of the first prediction. None of them is set when the model
          returns a "prediction" column with no rows.
        - for any other style: "classification" (with "probas" on each scored record when the
          predictor returned proba_* columns) and, when explanations are enabled,
          "explanationsTimeUS"
    """
    ret = {}
    logging.info("Input records %s", request)

    # build the dataframe to predict
    records_df = pd.DataFrame(request["columns"])

    logging.info("Input DF: %s", records_df)
    logging.info("Input style: %s", request["mlFlowOutputStyle"])

    nb_records = records_df.shape[0]

    before = time.time()
    if request["mlFlowOutputStyle"] == "RAW":

        pred_df = predictor.predict_raw(records_df, force_json_tensors_output=False)

        logging.info("Pred_df is %s", pred_df)

        after = time.time()
        ret["execTimeUS"] = int(1000000 * (after - before))
        logging.info("Done predicting, shape=%s", str(pred_df.shape))

        if "prediction" in pred_df:
            predictions = pred_df["prediction"]
            logging.info(predictions)
            logging.info(predictions.dtype)

            if pred_df.shape[0] > 0:
                # Read the first row by position: the label 0 is missing from the index
                # whenever the model dropped the first input record.
                first_prediction = predictions.iloc[0]
                logging.info(first_prediction)
                logging.info(type(first_prediction))

                if isinstance(first_prediction, (list, set)):
                    ret["arbitraryArray"] = raw_mlflow_pred_with_prediction_field_to_dicts_list(pred_df, nb_records)
                elif isinstance(first_prediction, dict):
                    ret["arbitraryObject"] = raw_mlflow_pred_with_prediction_field_to_dicts_list(pred_df, nb_records)
                else:
                    ret["arbitraryPrimitive"] = raw_mlflow_pred_with_prediction_field_to_dicts_list(pred_df, nb_records)
        else:
            ret["arbitraryObject"] = raw_mlflow_pred_without_prediction_field_to_dicts_list(pred_df, nb_records)
        logging.info("ret=%s", ret)

    else:
        logging.info("doing PARSED prediction")
        pred_df = predictor.predict(records_df)
        after = time.time()
        ret["execTimeUS"] = int(1000000 * (after - before))
        logging.info("Scoring_data pred_df=\n%s", pred_df)

        explanations_df = None
        explanations_params = request.get("explanations", {})
        if explanations_params.get("enabled"):
            before = time.time()
            explanations_df = predictor._compute_explanations(
                records_df,
                method=explanations_params["method"],
                n_explanations=explanations_params["nExplanations"],
                mc_steps=explanations_params.get("mcSteps", DEFAULT_SHAPLEY_BACKGROUND_SIZE)
            ).rename(lambda x: x.replace("explanations_", ""), axis=1)

            after = time.time()
            ret["explanationsTimeUS"] = int(1000000 * (after - before))

        has_probas = any(c.startswith("proba_") for c in pred_df.columns)
        pred_idx = pred_df.index
        ret["classification"] = parsed_pred_to_dict(pred_df, nb_records, explanations_df)
        if has_probas:
            # The class labels do not depend on the record: resolve them once
            class_labels = [label["label"] for label in predictor.params.model_meta["classLabels"]]
            record_dicts = pred_df.to_dict(orient='records')
            # pred_idx holds the input position of each scored row, which is also its position
            # in ret["classification"]
            for (record, i) in zip(record_dicts, pred_idx):
                ret["classification"][i]["probas"] = {c: record["proba_%s" % c] for c in class_labels}

    return ret


# socket-based connection to backend
def serve(port, secret, server_cert=None):
    """
    Connect to the backend over a JavaLink and answer prediction requests until told to stop.

    Protocol, as implemented below:
      1. read one JSON command and load the MLflow model from its 'modelFolder' entry
      2. acknowledge with {"ok": True}
      3. answer every JSON request with the result of handle_predict(), until a None
         request is read
      4. send an empty string to mark the end of the stream

    On any exception, an empty string followed by the JSON-friendly error is sent instead.
    The link is closed in all cases.

    :param port: port handed to JavaLink
    :param secret: secret handed to JavaLink
    :param server_cert: optional server certificate handed to JavaLink
    """
    channel = JavaLink(port, secret, server_cert=server_cert)
    # initiate connection
    channel.connect()

    try:
        # the first message carries the initialization info
        init_command = channel.read_json()

        logging.info("Loading MLflow model")
        model_context = build_noop_folder_context(init_command.get('modelFolder'))
        mlflow_predictor = MLflowPredictor(model_context)
        logging.info("MLflow model loaded")

        channel.send_json({"ok": True})

        # keep answering requests until the peer sends None
        pending = channel.read_json()
        while pending is not None:
            channel.send_json(handle_predict(mlflow_predictor, pending))
            pending = channel.read_json()

        # send end of stream
        logging.info("Work done")
        channel.send_string('')
    except Exception:
        logging.exception("Prediction user code failed")
        # send null to mark failure, then the error details
        channel.send_string('')
        channel.send_json(get_json_friendly_error())
    finally:
        channel.close()


if __name__ == "__main__":
    # NOTE(review): presumably ties this process's lifetime to its stdin — confirm in
    # dataiku.base.utils.watch_stdin
    watch_stdin()
    # Connection settings come from the command line arguments
    link_port, link_secret, link_server_cert = parse_javalink_args()
    serve(link_port, link_secret, server_cert=link_server_cert)
