import logging
import os.path as osp
from os import environ
import tempfile

from dataiku.base.utils import TmpFolder
from dataiku.external_ml.proxy_model.azure_ml.inputformat import AzureMLJSONWriterInputsShaped, AzureMLJSONWriterInputDataShaped
from dataiku.external_ml.proxy_model.azure_ml.inputformat.json_writer_inputdata_columns_data import AzureMLJSONWriterInputDataShapedWithColumnsAndData
from dataiku.external_ml.proxy_model.azure_ml.outputformat import AzureMLJSONArrayReader, AzureMLJSONObjectResultsReader
from dataiku.external_ml.proxy_model.common import ProxyModelEndpointClient, BatchConfiguration, ChunkedAndFormatGuessingProxyModel
from dataiku.external_ml.proxy_model.deploy_anywhere.inputformat import DeployAnywhereWriter
from dataiku.external_ml.proxy_model.deploy_anywhere.outputformat import DeployAnywhereReader

logger = logging.getLogger(__name__)


# Size limit handed to BatchConfiguration by AzureMLProxyModel.__init__. Presumably the maximum
# payload size (in bytes, per the name) of a single request chunk sent to the endpoint — confirm
# against BatchConfiguration.
AZUREML_BYTES_LIMIT = 10000000

# Candidate response readers, passed to ChunkedAndFormatGuessingProxyModel by AzureMLProxyModel.
# Formats will be tried in the defined order. Order is important, so don't change it unless you really know what you're doing.
OutputFormats = [
    AzureMLJSONArrayReader,
    AzureMLJSONObjectResultsReader,
    DeployAnywhereReader,
]
# Candidate request writers, passed to ChunkedAndFormatGuessingProxyModel alongside OutputFormats;
# presumably also tried in list order, so treat the ordering as significant here too.
InputFormats = [
    AzureMLJSONWriterInputDataShapedWithColumnsAndData,
    AzureMLJSONWriterInputDataShaped,
    AzureMLJSONWriterInputsShaped,
    DeployAnywhereWriter,
]


class DKUAzureMLEndpointClient(ProxyModelEndpointClient):
    """Endpoint client that sends scoring requests to an Azure ML online endpoint.

    :param sdk_client: the SDK client whose ``online_endpoints.invoke`` is called
        (an ``azure.ai.ml.MLClient`` when built by ``AzureMLProxyModel.get_client``).
    :param endpoint_name: name of the online endpoint to invoke.
    """

    def __init__(self, sdk_client, endpoint_name):
        self.sdk_client = sdk_client
        self.endpoint_name = endpoint_name
        super(DKUAzureMLEndpointClient, self).__init__()

    def call_endpoint(self, string_data, content_type):
        """Invoke the endpoint with ``string_data`` as request body and return the SDK's response.

        ``online_endpoints.invoke`` is given a request *file*, so the payload is first written
        (UTF-8) to a file inside a temporary folder. ``content_type`` is not used by this
        implementation.

        While the endpoint is being invoked, the ``azure`` logger's level is raised to ERROR
        (presumably to keep per-request SDK logging out of the job log); its previously
        configured level is restored afterwards, even if the call raises.
        """
        azure_logger = logging.getLogger('azure')
        # Save the level actually configured on the logger (possibly NOTSET), not its effective
        # level: restoring the effective level would pin an explicit level on the logger and stop
        # it from inheriting later level changes made on its parents.
        previous_level = azure_logger.level
        try:
            # TmpFolder is a dataiku helper used as a context manager yielding a folder path.
            with TmpFolder(tempfile.gettempdir()) as temp_folder:
                filename = osp.join(temp_folder, "azureml_request_body.tmp")
                with open(filename, 'w', encoding='utf-8') as f:
                    f.write(string_data)
                azure_logger.setLevel(logging.ERROR)
                return self.sdk_client.online_endpoints.invoke(
                    endpoint_name=self.endpoint_name,
                    request_file=filename,
                )
        finally:
            azure_logger.setLevel(previous_level)


class AzureMLProxyModel(ChunkedAndFormatGuessingProxyModel):
    """Proxy model that forwards scoring requests to an Azure ML online endpoint.

    :param subscription_id: Azure subscription id of the ML workspace.
    :param resource_group: Azure resource group of the ML workspace.
    :param workspace: Azure ML workspace name.
    :param endpoint_name: name of the online endpoint to invoke.
    :param meta: dict from which ``predictionType``, ``intToLabelMap``, ``inputFormat`` and
        ``outputFormat`` are read (missing keys become None).
    :param connection: optional name of a DSS "AzureML" connection used for authentication.
    :param kwargs: accepted and ignored.
    """

    def __init__(self, subscription_id, resource_group, workspace, endpoint_name, meta, connection=None, **kwargs):
        self.subscription_id = subscription_id
        self.resource_group = resource_group
        self.workspace = workspace
        self.endpoint_name = endpoint_name
        self.connection = connection
        super(AzureMLProxyModel, self).__init__(meta.get("predictionType"), meta.get("intToLabelMap"), InputFormats, OutputFormats, meta.get("inputFormat"), meta.get("outputFormat"), BatchConfiguration(AZUREML_BYTES_LIMIT))

    def get_client(self):
        """Create a :class:`DKUAzureMLEndpointClient` for ``self.endpoint_name``.

        Builds an ``azure.ai.ml.MLClient`` for the configured subscription, resource group and
        workspace, authenticated with the credential returned by ``_build_azure_credential`` and
        routed through the DSS proxy when one is configured.

        :raises Exception: if the DSS connection uses an unsupported ``authType``.
        """
        # Azure SDK imports are done lazily (presumably so the SDK is only required when a
        # client is actually built).
        from azure.ai.ml import MLClient

        credential = self._build_azure_credential()
        proxies = self._build_azure_proxies()

        sdk_client = MLClient(
            credential,
            self.subscription_id,
            self.resource_group,
            self.workspace,
            proxies=proxies
        )
        logger.info("Initialized AzureML client with endpoint name '{endpoint_name}' "
                    "using credentials associated to subscription id '{subscription_id}', "
                    "resource group '{resource_group}' and workspace '{workspace}'."
                    "".format(endpoint_name=self.endpoint_name,
                              subscription_id=self.subscription_id,
                              resource_group=self.resource_group,
                              workspace=self.workspace))
        return DKUAzureMLEndpointClient(sdk_client, self.endpoint_name)

    def _build_azure_credential(self):
        """Return the Azure token credential used to authenticate the SDK client.

        If a DSS connection is set and this is not a real API node, the connection's
        ``authType`` decides: ``OAUTH2_APP`` uses the connection's OAuth2 access token, and
        ``ENVIRONMENT`` falls through to ``DefaultAzureCredential``. In every other case,
        ``DefaultAzureCredential`` is used. Its managed-identity client id is taken from the
        ``DKU_AZURE_CLIENT_ID_ENV_KEY`` environment variable when that variable is set.

        :raises Exception: if the connection's ``authType`` is not handled.
        """
        from azure.identity import DefaultAzureCredential
        from azure.core.credentials import AccessToken, TokenCredential
        import time

        class DSSCredential(TokenCredential):
            """Static credential always returning the access token it was built with."""

            def __init__(self, access_token):
                # NOTE(review): the expiry is assumed to be one hour from now rather than read
                # from the token itself, and the token is never refreshed — confirm against the
                # lifetime of tokens issued through the DSS connection.
                self.az_access_token = AccessToken(access_token, int(time.time()) + 3600)

            def get_token(self, *args, **kwargs):
                return self.az_access_token

        credential = None
        if self.connection is not None:
            if AzureMLProxyModel.runs_on_real_api_node():
                logger.info("NOT getting connection params from connection {}. Authentication will be performed from environment".format(self.connection))
            else:
                dss_connection = AzureMLProxyModel.get_connection_info(self.connection, None, "AzureML", "Azure Machine Learning connection")
                params = dss_connection.get_params()
                auth_type = params.get("authType")
                if auth_type == "OAUTH2_APP":
                    credential = DSSCredential(dss_connection.get_oauth2_credential()["accessToken"])
                    logger.debug("Using oauth authentication with tenantId: {} app_id: {}".format(params.get("tenantId"), params.get("appId")))
                elif auth_type == "ENVIRONMENT":
                    logger.debug("Using authentication from environment")
                else:
                    raise Exception("Unhandled auth type: {}".format(auth_type))

        if credential is None:
            logger.info("Using DefaultAzureCredential()")
            # Apart from env var DKU_AZURE_CLIENT_ID_ENV_KEY that is used by Fleet Manager, the library also supports
            # AZURE_CLIENT_ID variable to set up the client id for managed identity authentications.
            azure_client_id = environ.get("DKU_AZURE_CLIENT_ID_ENV_KEY")
            if azure_client_id:
                credential = DefaultAzureCredential(managed_identity_client_id=azure_client_id)
            else:
                credential = DefaultAzureCredential()
        return credential

    def _build_azure_proxies(self):
        """Return the ``proxies`` mapping for ``MLClient``, or None when DSS has no proxy set."""
        proxy = self.get_proxy()
        if not proxy:
            return None
        # Both HTTP and HTTPS traffic go through the proxy via an http:// URL
        # (presumably get_proxy() returns a bare "host:port" — confirm).
        proxies = {
            "http": "http://" + proxy,
            "https": "http://" + proxy,
        }
        logger.info("Using proxy: {}".format(proxies))
        return proxies
