import logging
import base64

def get_snowflake_account_from_host_field(sf_account):
    """Derive the Snowflake account identifier from a host field value.

    A single trailing ``/`` is dropped first. Then a trailing
    ``.snowflakecomputing.com`` domain suffix, if present, is removed.
    Values without that suffix are returned unchanged, apart from the
    trailing-slash removal.
    """
    domain_suffix = ".snowflakecomputing.com"
    trimmed = sf_account[:-1] if sf_account.endswith("/") else sf_account
    if not trimmed.endswith(domain_suffix):
        return trimmed
    return trimmed[:-len(domain_suffix)]


def _keypair_connection(connection_params, connection_parameters):
    """Fill *connection_parameters* in place with key-pair authentication settings.

    Steps:

    - Copy ``user`` when it is present and not None.
    - Base64-decode ``privateKeyB64`` and load it as a PEM private key. The
      key is decrypted with ``privateKeyPassword`` when that value is truthy.
    - Re-serialize the key as unencrypted DER / PKCS#8 bytes and store them
      under ``private_key``.

    Nothing is returned.
    """
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization

    user = connection_params.get("user")
    if user is not None:
        connection_parameters["user"] = user

    pem_bytes = base64.b64decode(connection_params["privateKeyB64"])
    passphrase = connection_params.get("privateKeyPassword")
    # A missing or empty passphrase means the PEM key is not encrypted.
    private_key = serialization.load_pem_private_key(
        pem_bytes,
        password=passphrase.encode() if passphrase else None,
        backend=default_backend(),
    )
    connection_parameters["private_key"] = private_key.private_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )


def get_snowflake_connection_params(connection_name, connection_info):
    """Build the keyword arguments for ``snowflake.connector.connect``.

    :param connection_name: name of the connection; used only in error messages.
    :param connection_info: connection description. Reads ``resolvedParams``
        and ``params["authType"]``. Depending on the auth type it also reads
        ``resolvedBasicCredential`` (PASSWORD) or ``resolvedOAuth2Credential``
        (OAUTH2_APP).
    :returns: dict containing ``account`` and the credentials for the auth
        type. It also contains ``role``, ``warehouse``, ``database`` and
        ``schema`` when set, plus ``application``.
    :raises ValueError: if an OAUTH2_APP credential has no ``accessToken``,
        or the auth type is not PASSWORD, OAUTH2_APP or KEY_PAIR.
    """
    connection_params = connection_info["resolvedParams"]

    sf_host = connection_params["host"]
    https_prefix = "https://"
    if sf_host.startswith(https_prefix):
        sf_host = sf_host[len(https_prefix):]
    connection_parameters = {"account": get_snowflake_account_from_host_field(sf_host)}

    # Credentials
    auth_type = connection_info["params"]["authType"]
    if auth_type == "PASSWORD":
        sf_credentials = connection_info["resolvedBasicCredential"]
        if sf_credentials.get("user") is not None:
            connection_parameters["user"] = sf_credentials["user"]
        if sf_credentials.get("password") is not None:
            connection_parameters["password"] = sf_credentials["password"]
        if connection_params.get("privateKey") is not None:
            connection_parameters["private_key"] = base64.b64decode(connection_params["privateKey"])
    elif auth_type == "OAUTH2_APP":
        sf_credentials = connection_info["resolvedOAuth2Credential"]
        if "accessToken" not in sf_credentials:
            raise ValueError(
                "No accessToken found in {} connection. Please refer to DSS OAuth2 credentials documentation.".format(connection_name)
            )
        connection_parameters["authenticator"] = "oauth"
        connection_parameters["token"] = sf_credentials["accessToken"]
    elif auth_type == "KEY_PAIR":
        _keypair_connection(connection_params, connection_parameters)
    else:
        raise ValueError("Unsupported authentication type '{}'.".format(auth_type))

    # Connection params. Each entry is (DSS param name, JDBC property name,
    # connector keyword). Insertion order is kept stable so the log output is too.
    for param_name, property_name, target_key in (
        ("role", "role", "role"),
        ("warehouse", "warehouse", "warehouse"),
        ("db", "db", "database"),
        ("defaultSchema", "schema", "schema"),
    ):
        value = _get_connection_param(connection_params, param_name, property_name)
        if value is not None:
            connection_parameters[target_key] = value

    connection_parameters["application"] = connection_params.get("application", "dataiku")

    # Make sure we don't leak sensitive information in the logs. These 2 fields
    # must have been converted into "private_key" if they were present. This
    # function never sets them itself, so this is a defensive cleanup.
    connection_parameters.pop("private_key_file", None)
    connection_parameters.pop("private_key_file_pwd", None)

    to_be_redacted = {"password", "appSecret", "private_key", "token"}
    redacted_connection_parameters = {
        k: ("****" if k in to_be_redacted else v) for k, v in connection_parameters.items()
    }
    logging.info("Using Snowflake connection params %s ", redacted_connection_parameters)

    return connection_parameters

def _get_connection_param(connection_params, param_name, property_name):
        if param_name in connection_params and connection_params[param_name].strip():
            return connection_params[param_name]
        for prop in connection_params["properties"]:
            if prop["name"] is not None and prop["name"].lower() == property_name.lower():
                if "value" in prop and prop["value"].strip():
                    logging.info("Connection %s not found in the parameters but found in the JDBC properties. Using value %s" % (param_name, prop["value"]))
                    return prop["value"]
                break
        return None


def get_snowflake_connection(connection_parameters):
    """Open and return a Snowflake connector connection.

    *connection_parameters* is unpacked as keyword arguments into
    ``snowflake.connector.connect``. The connection object is returned open;
    closing it is left to the caller.
    """
    import snowflake.connector

    logging.info("Connecting to Snowflake")
    connection = snowflake.connector.connect(**connection_parameters)
    logging.info("Connected to Snowflake")

    # TODO: post-connect statements ("postConnectStatementsExpandedAndSplit")
    # are not executed here yet. An earlier draft of that loop was left
    # commented out and referenced names not available in this function.
    return connection


def get_quoted_columns_list(columns):
    """Return *columns* as a comma-separated list of double-quoted SQL identifiers.

    Embedded double quotes are escaped by doubling them, so ``a"b`` becomes
    ``"a""b"``.
    """
    def quote_identifier(name):
        return '"{}"'.format(name.replace('"', '""'))

    return ",".join(map(quote_identifier, columns))

#def get_snowflake_select_query(dataset, dataset_info):
