import logging
import json

from dataiku.external_ml.proxy_model.common import BatchConfiguration
from dataiku.external_ml.proxy_model.common import ChunkedAndFormatGuessingProxyModel
from dataiku.external_ml.proxy_model.common import ProxyModelEndpointClient
from dataiku.external_ml.proxy_model.databricks.inputformat import RecordOrientedJSONWriter, SplitOrientedJSONWriter, \
    TFInputsJSONWriter, TFInstancesJSONWriter, DatabricksCSVWriter
from dataiku.external_ml.proxy_model.databricks.output.json_reader import DatabricksJSONResultsReader

logger = logging.getLogger(__name__)

# Byte limit handed to BatchConfiguration when the proxy model is built (see DatabricksProxyModel.__init__).
DATABRICKS_BYTES_LIMIT = 1400000  # Value copied from the Vertex AI proxy model.

# Formats will be tried in the defined order. Order is important, so don't change it unless you really know what
# you're doing.
# Readers tried on the endpoint response.
OutputFormats = [
    DatabricksJSONResultsReader
]
# Writers tried to serialize the input sent to the endpoint.
InputFormats = [
    RecordOrientedJSONWriter,
    SplitOrientedJSONWriter,
    TFInputsJSONWriter,
    TFInstancesJSONWriter,
    DatabricksCSVWriter
]


class DatabricksEndpointClient(ProxyModelEndpointClient):
    """
    HTTP client posting payloads to a Databricks serving endpoint.

    :param url: full invocation URL of the serving endpoint
    :param proxies: proxies mapping forwarded to ``requests.post``, or None
    :param token: bearer token; when falsy, no ``Authorization`` header is sent
    """

    def __init__(self, url, proxies, token):
        self.url = url
        self.proxies = proxies
        self.token = token
        super(DatabricksEndpointClient, self).__init__()

    def call_endpoint(self, data, headers):
        """
        POST ``data`` to the endpoint and return the decoded JSON response.

        :param data: already-serialized request body
        :param headers: HTTP headers to send; this mapping is not modified
        :return: the response body parsed as JSON
        :raises requests.HTTPError: on a 4xx/5xx response whose body is not JSON
        :raises Exception: on any other non-200 response, or when a 200 response
                           body cannot be parsed as JSON
        """
        import requests
        # Work on a copy so that the bearer token never leaks into the caller's dict.
        headers = dict(headers) if headers else {}
        if self.token:
            headers["Authorization"] = "Bearer {}".format(self.token)
        # JSON payloads are also sent through data= rather than the dedicated json= param:
        # the json= encoder forbids NaN, whereas DBX AutoML models typically accept NaNs.
        response = requests.post(
            self.url,
            proxies=self.proxies,
            headers=headers,
            data=data
        )
        if response.status_code != requests.codes.ok:
            try:
                parsed_response = json.loads(response.text)
            except Exception as e:
                logger.error("Databricks endpoint returned an error and we failed to read the error message from the response. Logging the exception.")
                logger.error(e)
                response.raise_for_status()
                # raise_for_status() only raises for 4xx/5xx. Any other non-200 code (202, 204, 3xx...)
                # would otherwise fall through with no parsed body to report.
                raise Exception("Endpoint returned an unexpected status code: {}".format(response.status_code))
            # The error body is not guaranteed to be an object carrying both keys: fall back on the
            # HTTP status and the raw body instead of failing with a KeyError/TypeError.
            if isinstance(parsed_response, dict):
                error_code = parsed_response.get("error_code", response.status_code)
                message = parsed_response.get("message", response.text[:1000])
            else:
                error_code = response.status_code
                message = response.text[:1000]
            raise Exception("Endpoint returned an error. Code: {}, Message: {}".format(error_code, message))
        try:
            return response.json()
        except Exception as e:
            logger.error("Could not parse the endpoint response as json. The first 10 000 characters of the response were: {}".format(response.text[:10000]))
            # Not using raise ... from to make this file parseable in Python 2.7
            exception_with_cause = Exception("Could not parse the endpoint response as json. The first 1000 characters of the response were: {}".format(response.text[:1000]))
            exception_with_cause.__cause__ = e
            raise exception_with_cause


class DatabricksProxyModel(ChunkedAndFormatGuessingProxyModel):
    """
    Proxy model forwarding predictions to a Databricks model serving endpoint.

    :param endpoint_name: name of the serving endpoint, used to build the invocation URL
    :param meta: dict read for "predictionType", "intToLabelMap", "inputFormat" and "outputFormat"
    :param connection: name of the connection used to authenticate
    :param kwargs: accepted and ignored
    """

    def __init__(self, endpoint_name, meta, connection, **kwargs):
        self.endpoint_name = endpoint_name
        self.connection = connection
        super(DatabricksProxyModel, self).__init__(meta.get("predictionType"), meta.get("intToLabelMap"), InputFormats,
                                                   OutputFormats, meta.get("inputFormat"), meta.get("outputFormat"),
                                                   BatchConfiguration(DATABRICKS_BYTES_LIMIT))

    def get_client(self):
        """
        Build the client used to call the serving endpoint.

        Resolves the proxy settings, the authentication token and the invocation URL from the
        configured connection.

        :return: a DatabricksEndpointClient for this endpoint
        :raises Exception: if no connection is configured, the auth type is not handled, the
                           personal access token is missing, or the host is empty
        """
        proxy = self.get_proxy()
        if proxy:
            # The proxy itself is reached over plain http for both target schemes.
            proxies = {
                "http": "http://" + proxy,
                "https": "http://" + proxy,
            }
            logger.debug("Using proxies: {}".format(proxies))
        else:
            logger.debug("No applicative proxy configuration. Proxies may still be defined "
                         "with HTTP_PROXY and HTTPS_PROXY")
            proxies = None

        if self.connection is None:
            raise Exception("No connection configured.")
        logger.info("Using connection {} to authenticate".format(self.connection))
        dss_connection = DatabricksProxyModel.get_connection_info(self.connection, None, "DatabricksModelDeployment", "Databricks Model Deployment connection")
        params = dss_connection.get_params()
        auth_type = params.get("authType")
        if auth_type == "PERSONAL_ACCESS_TOKEN":
            logger.debug("Using personal access token configured in connection")
            token = params.get("personalAccessToken")
            if not token:
                raise Exception("Personal Access Token authentication configured, but no token available")
        elif auth_type == "OAUTH2_APP":
            logger.debug("Using oauth per user authentication")
            token = dss_connection.get_oauth2_credential()["accessToken"]
        else:
            raise Exception("Unhandled auth type: {}".format(auth_type))

        # Normalize the host: drop surrounding whitespace, any scheme (case-insensitively) and
        # trailing slashes, since the URL below always starts with https://.
        host = (params.get("host") or "").strip()
        lowered_host = host.lower()
        if lowered_host.startswith('http://'):
            host = host[7:]
        elif lowered_host.startswith('https://'):
            host = host[8:]
        host = host.rstrip("/")
        if not host:
            raise Exception("Empty host in parameters of connection {}".format(self.connection))

        url = "https://" + host
        port = params.get("port")
        # Only append the port when one is configured, to avoid building "host:None".
        if port is not None and port != "":
            url += ":" + str(port)
        url += "/serving-endpoints/" + self.endpoint_name + "/invocations"
        logger.info("Initializing Databricks client for uri '{url}'."
                    "".format(url=url))
        return DatabricksEndpointClient(url, proxies, token)
