import logging
from time import sleep

from dataiku.external_ml.proxy_model.common import BatchConfiguration
from dataiku.external_ml.proxy_model.common import ChunkedAndFormatGuessingProxyModel
from dataiku.external_ml.proxy_model.common import ProxyModelEndpointClient
from dataiku.external_ml.proxy_model.deploy_anywhere.inputformat import DeployAnywhereWriter
from dataiku.external_ml.proxy_model.deploy_anywhere.outputformat import DeployAnywhereReader
from dataiku.external_ml.proxy_model.sagemaker.inputformat import SagemakerCSVWriter
from dataiku.external_ml.proxy_model.sagemaker.inputformat import SagemakerJSONWriter
from dataiku.external_ml.proxy_model.sagemaker.inputformat import SagemakerJSONExtendedWriter
from dataiku.external_ml.proxy_model.sagemaker.inputformat import SagemakerJSONLINESWriter
from dataiku.external_ml.proxy_model.sagemaker.outputformat import SagemakerArrayAsStringReader
from dataiku.external_ml.proxy_model.sagemaker.outputformat import SagemakerCSVReader
from dataiku.external_ml.proxy_model.sagemaker.outputformat import SagemakerJSONLINESReader
from dataiku.external_ml.proxy_model.sagemaker.outputformat import SagemakerJSONReader

logger = logging.getLogger(__name__)


# Maximum batch size given to BatchConfiguration by SagemakerProxyModel.
# NOTE(review): unit assumed to be bytes, per the constant name — confirm against BatchConfiguration.
SAGEMAKER_BYTES_LIMIT = 400000
# Retry policy for endpoint calls failing with the "ModelNotReadyException" error code
# (passed to DKUSagemakerEndpointClient by SagemakerProxyModel.get_client): the delay
# starts at INITIAL_DELAY, doubles on each retry, is capped at MAX_DELAY, and at most
# MAX_RETRIES retries are attempted before giving up.
SAGEMAKER_MODEL_NOT_READY_INITIAL_DELAY = 5   # 5 seconds
SAGEMAKER_MODEL_NOT_READY_MAX_DELAY = 120  # 120 seconds
SAGEMAKER_MODEL_NOT_READY_MAX_RETRIES = 5

# Response readers handed to the base proxy model (ChunkedAndFormatGuessingProxyModel).
# Formats will be tried in the defined order. Order is important, so don't change it unless you really know what you're doing.
OutputFormats = [
    SagemakerArrayAsStringReader,
    SagemakerJSONReader,
    SagemakerCSVReader,
    SagemakerJSONLINESReader,
    DeployAnywhereReader,
]
# Request writers handed to the base proxy model (ChunkedAndFormatGuessingProxyModel).
# NOTE(review): presumably also tried in list order by the base class — confirm before reordering.
InputFormats = [
    SagemakerCSVWriter,
    SagemakerJSONWriter,
    SagemakerJSONExtendedWriter,
    SagemakerJSONLINESWriter,
    DeployAnywhereWriter,
]


class DKUSagemakerModelNotReadyException(Exception):
    """Raised when a SageMaker endpoint keeps answering with the
    ``ModelNotReadyException`` error code after every allowed retry was used."""


class DKUSagemakerEndpointClient(ProxyModelEndpointClient):
    """Endpoint client that forwards prediction requests to a SageMaker SDK client.

    Calls failing with the AWS error code ``ModelNotReadyException`` are retried with
    an exponential backoff (capped at a maximum delay) up to a maximum number of
    retries. "Insufficient memory" inference failures are re-raised with a clearer
    message, because that SageMaker error is often misleading.
    """

    def __init__(self, sdk_client, not_ready_retry_initial_delay, not_ready_retry_max_delay, not_ready_max_retries):
        """
        :param sdk_client: object whose ``predict(data, content_type)`` performs the call
            (a ``sagemaker.Predictor`` when built by ``SagemakerProxyModel.get_client``)
        :param not_ready_retry_initial_delay: delay in seconds before the first retry
        :param not_ready_retry_max_delay: upper bound in seconds for any single retry delay
        :param not_ready_max_retries: maximum number of retries while the model is not ready
        """
        self.sdk_client = sdk_client
        self.not_ready_retry_initial_delay = not_ready_retry_initial_delay
        self.not_ready_retry_max_delay = not_ready_retry_max_delay
        self.not_ready_max_retries = not_ready_max_retries
        super(DKUSagemakerEndpointClient, self).__init__()

    @staticmethod
    def _get_aws_error_code(exception):
        """Return the AWS error code carried by ``exception``, or None if it has none.

        botocore's ``ClientError`` exposes ``response`` as a dict whose ``Error``
        sub-dict holds a ``Code``. Other exceptions may also have a ``response``
        attribute of an unrelated type (an HTTP response object, or None); calling
        ``.get`` on those would raise AttributeError from inside the error handler
        and hide the original exception, so they are treated as "no code".
        """
        response = getattr(exception, "response", None)
        if not isinstance(response, dict):
            return None
        error = response.get("Error")
        if not isinstance(error, dict):
            return None
        return error.get("Code")

    def _get_not_ready_retry_delay(self, attempt):
        """Return the delay in seconds before 0-based retry ``attempt``:
        ``initial_delay * 2 ** attempt``, capped at ``max_delay``."""
        return min(
            self.not_ready_retry_max_delay,
            self.not_ready_retry_initial_delay * (2 ** attempt),
        )

    def call_endpoint(self, string_data, content_type):
        """Send ``string_data`` to the endpoint and return the SDK client's prediction.

        :param string_data: serialized request payload
        :param content_type: content type passed along with ``string_data``
        :raises DKUSagemakerModelNotReadyException: if the model is still not ready
            after ``not_ready_max_retries`` retries
        :raises Exception: with a clarified message when the error text reports an
            "insufficient memory" inference failure
        Any other exception raised by the SDK client is re-raised unchanged.
        """
        current_try = 0
        while True:
            try:
                return self.sdk_client.predict(string_data, content_type)
            except Exception as e:
                if self._get_aws_error_code(e) == "ModelNotReadyException":
                    if current_try < self.not_ready_max_retries:
                        # Exponential backoff before trying again
                        sleep_time = self._get_not_ready_retry_delay(current_try)
                        logger.info(
                            "Model is not ready. Will try again in {} seconds.".format(sleep_time)
                        )
                        sleep(sleep_time)
                        current_try += 1
                        continue
                    logger.error(e)
                    raise DKUSagemakerModelNotReadyException(
                        "SageMaker reported that the model is still not ready after {} retries. Please try again later.".format(
                            current_try
                        )
                    )

                msg = str(e)
                if "Inference failed due to insufficient memory" in msg:
                    # SageMaker's memory error frequently masks input-format problems;
                    # replace it with a hint but still log the original exception.
                    msg = (
                        "The SageMaker endpoint raised a 'insufficient memory' error, be aware that this error message is very often misleading and that"
                        " the root cause could be something totally unrelated to the endpoint memory, like an issue with the input format,"
                        " or columns that the model did not expect."
                    )
                    logger.error(msg + " Logging the exception anyway: ")
                    logger.error(e)
                    raise Exception(msg)
                raise


class SagemakerProxyModel(ChunkedAndFormatGuessingProxyModel):
    """Proxy model that scores data through a deployed SageMaker endpoint.

    The base class receives the candidate ``InputFormats`` / ``OutputFormats`` and a
    ``BatchConfiguration`` built from ``SAGEMAKER_BYTES_LIMIT``; ``get_client`` builds
    the retrying endpoint client used to call SageMaker.
    """

    def __init__(self, endpoint_name, meta, region=None, connection=None, **kwargs):
        """
        :param endpoint_name: name of the SageMaker endpoint to invoke
        :param meta: dict-like model metadata; its "predictionType", "intToLabelMap",
            "inputFormat" and "outputFormat" entries are forwarded to the base class
        :param region: AWS region; when empty, the region of ``connection`` is used
        :param connection: identifier of the DSS SageMaker connection providing the
            region and credentials, or None
        :param kwargs: accepted for call compatibility and ignored
        """
        self.endpoint_name = endpoint_name
        self.region = region
        self.connection = connection

        super(SagemakerProxyModel, self).__init__(
            meta.get("predictionType"),
            meta.get("intToLabelMap"),
            InputFormats,
            OutputFormats,
            meta.get("inputFormat"),
            meta.get("outputFormat"),
            BatchConfiguration(SAGEMAKER_BYTES_LIMIT),
        )

    def _get_connection_settings(self):
        """Return ``(region, access_key, secret_key, session_token)`` from the DSS connection.

        Every item is None when no connection is configured, or when running on a
        real API node (authentication is then left to the environment). The
        credential items are also None when the connection resolves no AWS credential.
        """
        if self.connection is None:
            return None, None, None, None
        if SagemakerProxyModel.runs_on_real_api_node():
            logger.info("NOT getting connection params from connection {}. "
                        "Authentication will be performed from environment".format(self.connection))
            return None, None, None, None

        logger.info("Getting connection params from connection {}".format(self.connection))
        dss_connection = SagemakerProxyModel.get_connection_info(self.connection, None, "SageMaker", "SageMaker connection")
        connection_region = dss_connection.get_params().get("regionOrEndpoint")

        access_key = secret_key = session_token = None
        resolved_aws_credential = dss_connection.get("resolvedAWSCredential")
        if resolved_aws_credential is not None:
            access_key = resolved_aws_credential["accessKey"]
            secret_key = resolved_aws_credential["secretKey"]
            session_token = resolved_aws_credential.get("sessionToken")
            logger.info("Connection using access_key {} from resolved AWS credentials".format(access_key))
        return connection_region, access_key, secret_key, session_token

    def _resolve_region(self, connection_region):
        """Return the region to use: the saved model's, else the connection's, else None."""
        if self.region:
            logger.info("Using region from saved model: {}".format(self.region))
            return self.region
        if connection_region:
            logger.info("Using region from connection: {}".format(connection_region))
            return connection_region
        return None

    def get_client(self):
        """Build a ``DKUSagemakerEndpointClient`` for ``self.endpoint_name``.

        Region and credentials come from the saved model and/or the DSS connection.
        The proxy returned by ``get_proxy()``, when set, is applied to the AWS clients.
        """
        # TODO: check if these can be moved to the top of the file
        import boto3
        from botocore.config import Config
        from sagemaker import Predictor
        from sagemaker import Session

        connection_region, access_key, secret_key, session_token = self._get_connection_settings()
        region = self._resolve_region(connection_region)

        proxy = self.get_proxy()
        if proxy:
            client_config = Config(
                proxies={
                    "http": proxy,
                    "https": proxy
                }
            )
            logger.info("Using proxy config: {}".format(client_config))
        else:
            client_config = None

        # One boto3 session carries the connection credentials and creates both
        # clients, so the SageMaker SDK session (and anything it derives from
        # boto_session) uses the same identity instead of the default credential chain.
        boto_session = boto3.Session(aws_access_key_id=access_key, aws_secret_access_key=secret_key,
                                     aws_session_token=session_token, region_name=region)
        client = boto_session.client('sagemaker', config=client_config)
        runtime_client = boto_session.client('sagemaker-runtime', config=client_config)
        sagemaker_session = Session(boto_session=boto_session, sagemaker_client=client,
                                    sagemaker_runtime_client=runtime_client)

        sdk_client = Predictor(endpoint_name=self.endpoint_name, sagemaker_session=sagemaker_session)
        logger.info("Initialized sagemaker client with endpoint name '{endpoint_name}' "
                    "using profile '{profile_name}' in region '{region_name}'."
                    "".format(endpoint_name=self.endpoint_name,
                              profile_name=sdk_client.sagemaker_session.boto_session.profile_name,
                              region_name=sdk_client.sagemaker_session.boto_session.region_name))
        return DKUSagemakerEndpointClient(sdk_client, SAGEMAKER_MODEL_NOT_READY_INITIAL_DELAY, SAGEMAKER_MODEL_NOT_READY_MAX_DELAY,
                                          SAGEMAKER_MODEL_NOT_READY_MAX_RETRIES)
