import logging

import numpy as np
import pandas as pd

from dataiku import Dataset
from dataiku.core import doctor_constants
from dataiku.doctor.exception import ApiLogsException
from dataiku.doctor.utils import normalize_dataframe, logger
from dataiku.external_ml.proxy_model.common.logs_decoder import decode_sagemaker_input_logs, decode_sagemaker_output_logs

# Rebinds the ``logger`` name imported from dataiku.doctor.utils to this module's own logger.
logger = logging.getLogger(__name__)

CLASSICAL_EVALUATION_DATASET_TYPE = "CLASSIC"

# API NODE LOGS
# Every field of an API node log lives under a root prefix ("clientEvent." for API node logs,
# "message." for cloud API node logs); features and results are sub-objects of that root.
API_NODE_EVALUATION_DATASET_TYPE = "API_NODE_LOGS"
API_NODE_LOGS_PREFIX = "clientEvent."
API_NODE_LOGS_FEATURE_PREFIX = API_NODE_LOGS_PREFIX + "features."
API_NODE_LOGS_RESULT_PREFIX = API_NODE_LOGS_PREFIX + "result."

CLOUD_API_NODE_EVALUATION_DATASET_TYPE = "CLOUD_API_NODE_LOGS"
CLOUD_API_NODE_LOGS_PREFIX = "message."
CLOUD_API_NODE_LOGS_FEATURE_PREFIX = CLOUD_API_NODE_LOGS_PREFIX + "features."
CLOUD_API_NODE_LOGS_RESULT_PREFIX = CLOUD_API_NODE_LOGS_PREFIX + "result."

# SAGEMAKER
# Columns of a SageMaker data capture log holding the captured endpoint input / output payloads
SAGEMAKER_EVALUATION_DATASET_TYPE = "SAGEMAKER_LOGS"
SAGEMAKER_INPUT_DATA_COL = "captureData.endpointInput.data"
SAGEMAKER_OUTPUT_DATA_COL = "captureData.endpointOutput.data"


def _rename_or_drop_column(df, column, new_name):
    """
    Rename ``column`` of ``df`` to ``new_name`` in place, or drop ``column`` if ``new_name`` already exists.

    This avoids creating duplicate column labels: a column already carrying ``new_name`` takes precedence.
    """
    if new_name in df.columns:
        df.drop(column, axis=1, inplace=True)
    else:
        df.rename(columns={column: new_name}, inplace=True)


def normalize_api_node_logs_dataset(input_df, feature_preproc, evaluation_dataset_type):
    """
    Reshape a raw (cloud) API node logs dataframe into the columns expected for an evaluation.

    Result columns (``<result prefix>X``) become ``X``, with ``probas.`` turned into ``proba_``.
    Feature columns (``<feature prefix>X``) become ``X``. The target column, ``prediction`` and
    ``proba_*`` columns are kept, and every other column is dropped. The cleaned dataframe is then
    passed to ``normalize_dataframe`` with ``missing_columns="CREATE"``.

    :param input_df: logs dataframe; it is modified in place (columns renamed / dropped)
    :param feature_preproc: mapping feature name -> preprocessing params; the first feature whose
        ``"role"`` is ``"TARGET"`` is kept as the target column
    :param evaluation_dataset_type: API_NODE_EVALUATION_DATASET_TYPE or CLOUD_API_NODE_EVALUATION_DATASET_TYPE
    :return: the result of ``normalize_dataframe`` on the cleaned dataframe
    :raises ApiLogsException: if ``evaluation_dataset_type`` is not an API node logs type
    """
    if evaluation_dataset_type == API_NODE_EVALUATION_DATASET_TYPE:
        feature_prefix = API_NODE_LOGS_FEATURE_PREFIX
        result_prefix = API_NODE_LOGS_RESULT_PREFIX
    elif evaluation_dataset_type == CLOUD_API_NODE_EVALUATION_DATASET_TYPE:
        feature_prefix = CLOUD_API_NODE_LOGS_FEATURE_PREFIX
        result_prefix = CLOUD_API_NODE_LOGS_RESULT_PREFIX
    else:
        raise ApiLogsException("The evaluation dataset type %s is not a api node logs format" % evaluation_dataset_type)

    target = next((feature for feature, params in feature_preproc.items() if params["role"] == "TARGET"), None)

    # Since https://app.shortcut.com/dataiku/story/205258/api-logs-er-handle-cases-where-there-is-message-feature-proba-x-and-message-result-proba-x,
    # result columns are handled first: once parsed, FEATURE_PREFIX.feature.proba_X would collide with
    # RESULT_PREFIX.result.probas.X, and only the one coming from the results must be kept.
    # Prefixes are matched at the start of the name only, so a prefix-like substring elsewhere in a
    # column name is neither matched nor used to truncate the name.
    for column in list(input_df.columns):
        if column.startswith(result_prefix):
            normalized_column_name = column[len(result_prefix):].replace("probas.", "proba_")
            _rename_or_drop_column(input_df, column, normalized_column_name)

    for column in list(input_df.columns):
        if column == target or column == "prediction" or column.startswith("proba_"):
            # Target, prediction and probability columns are already in their final form
            continue
        if column.startswith(feature_prefix):
            _rename_or_drop_column(input_df, column, column[len(feature_prefix):])
        else:
            # Every other column (other log fields, unprefixed columns) is not kept
            input_df.drop(column, axis=1, inplace=True)

    return normalize_dataframe(input_df, feature_preproc, missing_columns="CREATE")


def _filter_on_value_step(column, value):
    """Build a new "FilterOnValue" preparation step keeping only the rows whose ``column`` is exactly ``value``."""
    return {
        "type": "FilterOnValue",
        "preview": False,
        "disabled": False,
        "metaType": "PROCESSOR",
        "alwaysShowComment": False,
        "params": {
            "normalizationMode": "EXACT",
            "booleanMode": "AND",
            "columns": [column],
            "values": [value],
            "matchingMode": "FULL_STRING",
            "action": "KEEP_ROW",
            "appliesTo": "SINGLE_COLUMN",
        },
    }


def create_filter_on_smvid_and_deployment_step(input_dataset, saved_model_fmi, deployment_to_filter_on, evaluation_dataset_type):
    """
    Append preparation steps to ``input_dataset`` that restrict API node logs to one saved model.

    The steps are appended, in this order, to ``input_dataset.preparation_steps``, which is created
    if it is ``None``:
      1. keep rows whose ``<prefix>auditTopic`` is "apinode-query";
      2. only if ``deployment_to_filter_on`` is not None, keep rows whose
         ``<prefix>apiDeployerDeployment.deploymentId`` equals it;
      3. keep rows whose ``<prefix>savedModel.fullModelId`` is ``saved_model_fmi``;
      4. a "RemoveRowsOnEmpty" step on ``<prefix>error``, meant to discard rows where inference failed.

    :raises ApiLogsException: if ``evaluation_dataset_type`` is not an API node / cloud API node logs type
    """
    if evaluation_dataset_type == API_NODE_EVALUATION_DATASET_TYPE:
        prefix = API_NODE_LOGS_PREFIX
    elif evaluation_dataset_type == CLOUD_API_NODE_EVALUATION_DATASET_TYPE:
        prefix = CLOUD_API_NODE_LOGS_PREFIX
    else:
        raise ApiLogsException("Filter api logs is not implemented for {}".format(evaluation_dataset_type))

    new_steps = [_filter_on_value_step(prefix + "auditTopic", "apinode-query")]
    if deployment_to_filter_on is not None:
        new_steps.append(_filter_on_value_step(prefix + "apiDeployerDeployment.deploymentId", deployment_to_filter_on))
    new_steps.append(_filter_on_value_step(prefix + "savedModel.fullModelId", saved_model_fmi))
    # NOTE(review): with "keep": True this presumably keeps only the rows whose error column is empty,
    # i.e. successful inferences — confirm against the RemoveRowsOnEmpty processor semantics.
    new_steps.append({
        "type": "RemoveRowsOnEmpty",
        "params": {
            "keep": True,
            "appliesTo": "SINGLE_COLUMN",
            "columns": [prefix + "error"],
        },
    })

    if input_dataset.preparation_steps is None:
        input_dataset.preparation_steps = []
    for step in new_steps:
        input_dataset.preparation_steps.append(step)

# SAGEMAKER LOGS
def decode_and_normalize_sagemaker_logs(
    input_df_orig, mlflow_imported_model, prediction_type, target_column
):
    """
    Decode a SageMaker data capture logs dataframe into features, plus predictions when captured.

    :param input_df_orig: raw SageMaker data capture logs; must contain SAGEMAKER_INPUT_DATA_COL and,
        when output capture is enabled, SAGEMAKER_OUTPUT_DATA_COL
    :param mlflow_imported_model: imported model description; reads "proxyModelEndpointInfo",
        "features", "inputFormat" and, for output decoding, "intToLabelMap" and "outputFormat"
    :param prediction_type: doctor prediction type; binary/multiclass targets are typed "string",
        any other target "double"
    :param target_column: name of the target column appended to the expected feature schema
    :return: the decoded input dataframe. When output capture is enabled, it is concatenated column-wise
        with the decoded predictions, keeping only input rows whose index also has a prediction.
    :raises ApiLogsException: if the endpoint info is missing, data capture or input capture is disabled,
        or an enabled capture column is absent from ``input_df_orig``
    """
    logger.info("Starts to decode the input dataset as sagemaker logs.")

    endpoint_info = mlflow_imported_model.get("proxyModelEndpointInfo")
    if endpoint_info is None:
        raise ApiLogsException(
            (
                "Endpoint data capture information cannot be retrieved. Prediction logs cannot be "
                "decoded or evaluated. The proxification of the endpoint might have gone wrong."
            )
        )

    # A missing capture config is reported like a disabled one instead of failing with a KeyError
    data_capture_config = endpoint_info.get("dataCaptureConfig") or {}
    if not data_capture_config.get("enableCapture", False):
        raise ApiLogsException("Your endpoint does not allow data capture. Prediction logs cannot be decoded nor evaluated.")

    data_capture_options = [capture_option.get("captureMode")
                            for capture_option in data_capture_config.get("captureOptions") or []]
    logger.info("The data capture option of your model is : %s", data_capture_options)

    is_classification = prediction_type in {doctor_constants.BINARY_CLASSIFICATION, doctor_constants.MULTICLASS}
    # The schema "type" must be a type-name string (like "string"), not a Python type object
    target_type = "string" if is_classification else "double"
    input_features = mlflow_imported_model["features"] + [{"name": target_column, "type": target_type}]

    _, feature_dtypes, _ = Dataset.get_dataframe_schema_st(input_features)

    if "Input" not in data_capture_options:
        raise ApiLogsException("The data capture of input data is not enabled for this endpoint.")
    if SAGEMAKER_INPUT_DATA_COL not in input_df_orig:
        raise ApiLogsException("Your model has the Input data capture option enabled but %s column cannot be found" % SAGEMAKER_INPUT_DATA_COL)
    input_df = decode_sagemaker_input_logs(input_df_orig[SAGEMAKER_INPUT_DATA_COL], feature_dtypes, mlflow_imported_model["inputFormat"])

    if "Output" not in data_capture_options:
        logger.info("No predictions found in the sagemaker logs, decoding and formatting only the input data.")
        return input_df

    if SAGEMAKER_OUTPUT_DATA_COL not in input_df_orig:
        raise ApiLogsException("Your model has the Output data capture option enabled but %s column cannot be found" % SAGEMAKER_OUTPUT_DATA_COL)
    pred_df = decode_sagemaker_output_logs(input_df_orig[SAGEMAKER_OUTPUT_DATA_COL], prediction_type, mlflow_imported_model["intToLabelMap"], mlflow_imported_model["outputFormat"])
    # Keep only the input rows that received a decoded prediction, then align both frames on the index
    return pd.concat([input_df[input_df.index.isin(pred_df.index)], pred_df], axis=1)
