from datetime import datetime
import logging
import os
import sys
import base64
import re
import json
from urllib.parse import urlparse, parse_qsl

import dataiku
from dataiku.base.spark_like import SparkLike
from dataiku.base.sql_dialect import SparkLikeDialect
from dataiku.base.utils import package_version_compat

try:
    # Cleanup pyspark stuff that was added by env-spark.sh: drop from sys.path every entry that
    # contains one of the "chunks" below (plain substring match). Presumably this is so that the
    # pyspark shipped with databricks-connect in the code-env is the one imported — TODO confirm.
    #
    # The chunks come, by order of precedence, from DKU_PYSPARK_PYTHONPATH (colon-separated),
    # DKU_SPARK_HOME, then SPARK_HOME. Empty chunks are discarded: an empty string is a substring
    # of every path and would otherwise wipe sys.path entirely (e.g. DKU_PYSPARK_PYTHONPATH="/a:").
    dku_pyspark_pythonpath = os.environ.get("DKU_PYSPARK_PYTHONPATH")
    dku_spark_home = os.environ.get("DKU_SPARK_HOME")
    spark_home = os.environ.get("SPARK_HOME")
    if dku_pyspark_pythonpath:
        dku_pyspark_pythonpath_chunks = [c for c in dku_pyspark_pythonpath.split(":") if c]
    elif dku_spark_home:
        dku_pyspark_pythonpath_chunks = [dku_spark_home]
    elif spark_home:
        dku_pyspark_pythonpath_chunks = [spark_home]
    else:
        dku_pyspark_pythonpath_chunks = []

    def comes_from_dku_pyspark_pythonpath(p):
        """Return True when sys.path entry ``p`` contains a non-empty chunk to strip."""
        # the `c and` guard keeps an empty chunk from matching every path
        return any(c and c in p for c in dku_pyspark_pythonpath_chunks)

    clean_sys_path = [p for p in sys.path if not comes_from_dku_pyspark_pythonpath(p)]
    if len(sys.path) > len(clean_sys_path):
        logging.warning("Cleaned %s chunks from PYTHONPATH that came from DKU_PYSPARK_PYTHONPATH" % (len(sys.path) - len(clean_sys_path)))
    sys.path = clean_sys_path
except Exception as e:
    # best-effort cleanup: never prevent the module from loading
    logging.warning("Error while restoring a pyspark-free PYTHONPATH: %s" % str(e))


try:
    from databricks.connect import DatabricksSession
    from databricks.sdk.core import Config
    from databricks.sdk.oauth import Token, Refreshable
    import databricks.connect
except ImportError as e:
    raise Exception("Unable to import DBConnectV2 libraries. Make sure you are using a code-env where databricks-connect>=13.0.* is installed. Cause: " + str(e))

def _databricks_version_check(min_version):
    """
    Ensure the installed databricks-connect is at least ``min_version``.

    The version is read from ``databricks.connect.__version__`` if that attribute exists,
    otherwise from ``databricks.connect.__dbconnect_version__``.

    :param str min_version: minimal required version, e.g. "15.1.0"
    :raises Exception: when no version can be detected, or when it is lower than ``min_version``
    """
    detected_version = None
    for version_attr in ("__version__", "__dbconnect_version__"):
        if hasattr(databricks.connect, version_attr):
            # first attribute present wins, even if its value is None
            detected_version = getattr(databricks.connect, version_attr)
            break

    if detected_version is None:
        raise Exception("you need to install databricks-connect>=%s (no version detected)" % min_version)
    if package_version_compat(detected_version) < package_version_compat(min_version):
        raise Exception("you need to install databricks-connect>=%s (current version: %s)" % (min_version, detected_version))

from pyspark.sql.functions import col, lit, base64, date_format, to_json, unhex

class DkuDBConnectDialect(SparkLikeDialect):
    """
    SparkLike dialect for Databricks Connect: maps pyspark DataType class names to DSS types
    and builds pyspark column / literal expressions.
    """

    def __init__(self):
        SparkLikeDialect.__init__(self)

    def _get_to_dss_types_map(self):
        """Return the pyspark DataType class name -> DSS type map, building and caching it on first use."""
        if self._to_dss_types_map is not None:
            return self._to_dss_types_map
        self._to_dss_types_map = {
            'ArrayType': 'string',
            'BinaryType': 'string',
            'BooleanType': 'boolean',
            'ByteType': 'tinyint',
            'DateType': 'dateonly',
            'DayTimeIntervalType': 'string',
            'DecimalType': 'double',
            'DoubleType': 'double',
            'FloatType': 'float',
            'IntegerType': 'int',
            'LongType': 'bigint',
            'MapType': 'string',
            'NullType': 'string',
            'ShortType': 'smallint',
            'StringType': 'string',
            'StructType': 'string',
            'TimestampNTZType': 'datetimenotz',
            'TimestampType': 'date',
            'UserDefinedType': 'string',
        }
        return self._to_dss_types_map

    def allow_empty_schema_after_catalog(self):
        """Whether specifying a table as (catalog, table) is possible: it is not here."""
        return False

    def identifier_quote_char(self):
        """Character used to quote identifiers (backtick)."""
        return '`'

    def _column_name_to_sql_column(self, identifier):
        """Return a pyspark column for ``identifier``, quoted with the dialect's quote char."""
        quoted_identifier = self.quote_identifier(identifier)
        return col(quoted_identifier)

    def _python_literal_to_sql_literal(self, value, column_type, original_type=None):
        """
        Turn a python value into a pyspark literal column.

        ``original_type`` is a SQL type name (not a pyspark type name). For 'binary' columns the
        DSS value is a string holding the hexadecimal form of the bytes (obtained through a
        getString() on the jdbc driver), so it is unhex()-ed to match the actual binary values.
        """
        is_binary_column = original_type is not None and original_type.lower() == 'binary'
        literal = lit(value)
        return unhex(literal) if is_binary_column else literal

    def _get_components_from_df_schema(self, df_schema):
        """
        Split a dataframe schema into (raw column names, fields by unquoted name).

        Each field entry is ``{"name": <unquoted name>, "datatype": <pyspark DataType>}``.
        """
        unquoted_fields = ((self.unquote_identifier(f.name), f.dataType) for f in df_schema.fields)
        fields_by_name = {name: {"name": name, "datatype": dtype} for name, dtype in unquoted_fields}
        return (df_schema.names, fields_by_name)

    def _get_datatype_name_from_df_datatype(self, datatype):
        """Name of the pyspark DataType class of ``datatype`` (e.g. 'StringType')."""
        datatype_class = datatype.__class__
        return datatype_class.__name__


def get_databricks_token_from_resolved_credentials(resolved_credentials):
    """
    Convert a DSS resolved OAuth2 credential into a databricks-sdk ``Token``.

    :param dict resolved_credentials: needs "accessToken" and "expiry"; "expiry" is divided by
                                      1000 before ``datetime.fromtimestamp``, i.e. it is in milliseconds
    :returns: a Bearer ``Token`` whose expiry is a naive local-time ISO string
    """
    access_token = resolved_credentials["accessToken"]
    expiry_seconds = resolved_credentials["expiry"] / 1000
    token_fields = {
        "access_token": access_token,
        "token_type": "Bearer",
        "expiry": datetime.fromtimestamp(expiry_seconds).isoformat(),
    }
    return Token.from_dict(token_fields)


# Create our own SessionCredentials that overrides refresh() to get a new access token from the backend
# See SessionCredentials in https://github.com/databricks/databricks-sdk-py/blob/main/databricks/sdk/oauth.py
# We extend Refreshable since it is more stable than SessionCredentials.
# Most recent version when DkuSessionCredentials was added was databricks-sdk==0.38.0
class DkuSessionCredentials(Refreshable):
    """
    databricks-sdk credentials provider whose refresh() asks DSS for a fresh OAuth2 access token.

    :param token: initial databricks-sdk ``Token``
    :param str connection_name: DSS connection to re-resolve on refresh
    :param str project_key: project key passed to the connection info lookup
    """

    def __init__(self,
                 token,
                 connection_name,
                 project_key):
        # attributes are set before the parent constructor runs, as in the original implementation
        self.connection_name = connection_name
        self.project_key = project_key
        super().__init__(token=token)

    # The next two methods mirror SessionCredentials
    def auth_type(self):
        """Implementing CredentialsProvider protocol"""
        return 'oauth'

    def __call__(self, *args, **kwargs):
        """Implementing CredentialsProvider protocol: return a factory producing the Authorization header."""

        def header_factory():
            current_access_token = self.token().access_token
            return {'Authorization': f"Bearer {current_access_token}"}

        return header_factory

    def refresh(self):
        """Fetch the connection info from DSS again and build a new Token from its resolved OAuth2 credential."""
        logging.info("Refreshing the Databricks OAuth2 access token")

        dss_client = dataiku.api_client()
        dss_connection = dss_client.get_connection(self.connection_name)
        connection_info = dss_connection.get_info(self.project_key)  # this will grab a new access token if needed
        refreshed_credential = connection_info["resolvedOAuth2Credential"]
        return get_databricks_token_from_resolved_credentials(refreshed_credential)


# noinspection PyPep8Naming
class DkuDBConnect(SparkLike):
    """
    Handle to create Databricks Connect sessions from DSS datasets or connections
    """

    def __init__(self, serverless=False):
        """
        Create a wrapper for using Databricks Connect V2 on DSS datasets

        :param bool serverless: if True, ignore the HTTP path of the DSS connection and use serverless Databricks compute instead
        :raises Exception: in serverless mode, when databricks-connect is older than 15.1.0
        """
        SparkLike.__init__(self)
        self._dialect = DkuDBConnectDialect()
        self._connection_type = "Databricks"
        self._serverless = serverless
        if not serverless:
            return
        # serverless compute requires databricks-connect>=15.1.0: fail fast otherwise
        _databricks_version_check("15.1.0")

    def _get_config_from_compute(self, cluster_id, **kwargs):
        """
        Build a databricks-sdk ``Config`` targeting this wrapper's compute.

        In serverless mode ``serverless_compute_id='auto'`` is used and ``cluster_id`` is ignored;
        otherwise ``cluster_id`` is used. Extra keyword arguments are passed through to ``Config``.
        """
        compute_target = {'serverless_compute_id': 'auto'} if self._serverless else {'cluster_id': cluster_id}
        return Config(**compute_target, **kwargs)

    def _get_config_for_personal_access_token(self, connection_info, connection_parameters, host, cluster_id):
        """
        Build the SDK ``Config`` for a personal-access-token connection.

        The token is the password of ``connection_info["resolvedBasicCredential"]`` when that entry
        exists, otherwise ``connection_parameters["pwd"]``.

        :raises Exception: when the token is missing or empty
        """
        basic_credential = connection_info.get("resolvedBasicCredential")
        if basic_credential is None:
            token = connection_parameters.get("pwd")
        else:
            token = basic_credential["password"]

        if token is None or len(token) == 0:
            raise Exception("Cannot find access token in connection settings")
        return self._get_config_from_compute(cluster_id, host=host, token=token)

    def _get_config_for_oauth(self, connection_name, connection_info, project_key, host, cluster_id):
        """
        Build the SDK ``Config`` for an OAuth2 connection, using credentials that refresh the token through DSS.

        :param str connection_name: DSS connection name (used for refreshes and in error messages)
        :param connection_info: resolved connection info, expected to contain "resolvedOAuth2Credential"
        :param str project_key: project key used when refreshing the token
        :param str host: Databricks workspace host
        :param str cluster_id: cluster id (ignored in serverless mode)
        :raises ValueError: when no OAuth2 access token was resolved for the connection
        """
        # a missing or null credential is reported the same way as a credential without access token
        resolved_credentials = connection_info.get("resolvedOAuth2Credential") or {}
        if "accessToken" not in resolved_credentials:
            # %-formatting: the previous str.format() call left the %s placeholder un-substituted
            raise ValueError("No OAuth2 access token found in %s connection. Please refer to DSS OAuth2 credentials documentation." % connection_name)

        logging.info("Using the resolved Databricks credential to build a Token and SessionCredentials")

        session_credentials = DkuSessionCredentials(
            token=get_databricks_token_from_resolved_credentials(resolved_credentials),
            connection_name=connection_name,
            project_key=project_key
        )
        return self._get_config_from_compute(
            cluster_id,
            host=host,
            credentials_strategy=session_credentials
        )

    def _create_session(self, connection_name, connection_info, project_key=None):
        """
        Open a Databricks Connect session for a DSS Databricks connection.

        :param str connection_name: DSS connection name; attached to the session as ``dss_connection_name``
        :param connection_info: resolved connection info; ``connection_info["resolvedParams"]`` holds the
                                connection parameters (host, httpPath, authType, properties, ...)
        :param str project_key: project key forwarded to the OAuth2 token refresher (optional)
        :returns: the session from ``DatabricksSession.builder...getOrCreate()``, after the connection's
                  post-connect statements have been executed on it
        :raises Exception: when the httpPath designates a SQL warehouse or contains no cluster id
                           (non-serverless mode only), or when the auth type is not supported
        """
        connection_params = connection_info["resolvedParams"]

        # extract needed info from connection. Advanced properties are keyed by lower-cased name and
        # can override the httpPath / password coming from the dedicated connection fields.
        host = connection_params["host"]
        connection_parameters = {"httppath": connection_params.get("httpPath"), "pwd": connection_params.get("password")}
        for prop in connection_params.get("properties") or []:
            connection_parameters[prop["name"].lower()] = prop["value"]

        # deduce the cluster id. The "httppath" key is always present but its value may be None:
        # normalize to '' so an unset httpPath gives the explicit error below, not a TypeError in re.match
        http_path = connection_parameters.get("httppath") or ''
        if self._serverless:
            cluster_id = 'not-relevant'  # _get_config_from_compute ignores it in serverless mode
        else:
            sql_warehouse_pattern = "^/?sql/[^/]+/warehouses/.*$"
            if re.match(sql_warehouse_pattern, http_path) is not None:
                raise Exception("Databricks Connect doesn't support SQL warehouses")
            cluster_pattern = "^/?sql/protocol[^/]+/o/[^/]+/(.*)$"
            m = re.match(cluster_pattern, http_path)
            if m is None:
                raise Exception("Cannot find cluster id in httpPath")
            cluster_id = m.group(1)

        # build the SDK config matching the configured authentication
        auth_type = connection_params["authType"]
        if auth_type == 'PERSONAL_ACCESS_TOKEN':
            config = self._get_config_for_personal_access_token(connection_info, connection_parameters, host, cluster_id)
        elif auth_type == 'OAUTH2_APP':
            _databricks_version_check("13.3.0")
            config = self._get_config_for_oauth(connection_name, connection_info, project_key, host, cluster_id)
        else:
            raise Exception("Auth type not supported: %s" % auth_type)

        logging.info("Connecting to Databricks cluster %s in workspace %s ", cluster_id, host)

        session_builder = DatabricksSession.builder.sdkConfig(config)
        # userAgent() seems to have been introduced in databricks-connect 13.1.0: only call it when available,
        # on the builder that already carries the SDK config
        user_agent = getattr(session_builder, 'userAgent', None)
        if callable(user_agent):
            session_builder = user_agent("dataiku/dss")

        session = session_builder.getOrCreate()

        # Execute post connect statements if any
        for statement in connection_params.get("postConnectStatementsExpandedAndSplit") or []:
            logging.info("Executing statement: %s", statement)
            session.sql(statement).show()
            logging.info("Statement done")

        session.dss_connection_name = connection_name  # Add a dynamic attribute to the session to recognize its DSS connection later on
        return session

    def _split_jdbc_url(self, url):
        """
        Parse a Databricks JDBC URL into connection parameters.

        Accepted shape: ``jdbc:databricks://host[:port][/schema][;name=value]*``.

        :param str url: the JDBC URL
        :returns: dict with "properties" (list of ``{"name", "value"}`` in URL order), "host",
                  optionally "port" (int) and "defaultSchema", and "authType" deduced from AuthMech
        :raises ValueError: when the URL has the wrong prefix or cannot be parsed
        :raises Exception: when AuthMech is neither 3 nor 11
        """
        if not url.startswith("jdbc:databricks://"):
            raise ValueError("Invalid JDBC URL. It must start with jdbc:databricks://")
        m = re.match("jdbc:databricks://+([^;/]+)(/[^;]+)?((;[^=]+=[^;]+)*)$", url)
        if m is None:
            raise ValueError("Cannot parse JDBC URL")
        # group(3) is '' (not None) when the URL has no properties
        host_port, schema_part, props_part = m.group(1), m.group(2), m.group(3) or ''

        parsed = {"properties": []}

        # get the host:port first
        hp = re.match("^([^:]+):([0-9]+)$", host_port)
        if hp is None:
            parsed["host"] = host_port
        else:
            parsed["host"] = hp.group(1)
            parsed["port"] = int(hp.group(2))
        if schema_part:
            parsed["defaultSchema"] = schema_part[1:]  # drop the leading '/'

        # then the jdbc properties, ";name=value" pairs (value may itself contain '=')
        if props_part:
            for url_prop in props_part[1:].split(';'):  # drop leading ';'
                name, sep, value = url_prop.partition('=')
                if not sep or not name:
                    # possible because the URL regex lets a property name span a ';'
                    raise ValueError("Cannot parse JDBC URL property: %s" % url_prop)
                parsed["properties"].append({"name": name, "value": value})

        # check AuthMech because it says what auth is configured
        auth_mech = 3  # default is 3 in driver doc
        for prop in parsed["properties"]:
            if prop.get("name", '').lower() == "authmech":
                auth_mech = int(prop.get("value", "3").strip())
        auth_types = {3: 'PERSONAL_ACCESS_TOKEN', 11: 'OAUTH2_APP'}
        if auth_mech not in auth_types:
            raise Exception("Unknown auth mechanism %s" % auth_mech)
        parsed["authType"] = auth_types[auth_mech]

        return parsed
        

    def _check_dataframe_type(self, df):
        """Raise ValueError unless ``df``'s class is defined in a ``pyspark.*`` module."""
        defining_module = df.__class__.__module__
        if defining_module.startswith("pyspark."):
            return
        raise ValueError("Dataframe is not a DatabricksConnect dataframe. Use dataset.write_dataframe() instead.")

    def _do_with_column(self, df, column_name, column_value):
        """Return ``df`` with column ``column_name`` added or replaced by the expression ``column_value``."""
        updated_df = df.withColumn(column_name, column_value)
        return updated_df
    
    def _cast_to_target_types(self, df, dss_schema, qualified_table_id):
        """
        Convert the columns of ``df`` to fit the column types of the existing table ``qualified_table_id``.

        This is based on the assumption that DSS manages the schema and may have done some type
        erasure (e.g. binary / complex types stored as strings). Each column whose
        (source type, target type) pair is listed below gets converted; columns absent from the
        target table and unlisted pairs are left as is. If the target schema cannot be read, a
        warning is logged and the dataframe is returned as converted so far.

        :param df: pyspark dataframe about to be inserted
        :param dss_schema: not used by this implementation
        :param str qualified_table_id: table identifier, interpolated as-is into ``select * from ...``
        :returns: the converted dataframe
        """
        column_names, column_fields = self._dialect._get_components_from_df_schema(df.schema)

        def _cast_to(sql_type):
            return lambda c: c.cast(sql_type)

        def _timestamp_to_string(c):
            return date_format(c, lit('yyyy-MM-dd HH:mm:ss.SSS'))

        # (source DataType name, target DataType name) -> conversion of the source column.
        # NB: `base64` is pyspark.sql.functions.base64 (the module-level pyspark import shadows the stdlib module).
        conversions = {
            ('BinaryType', 'StringType'): base64,
            ('TimestampType', 'StringType'): _timestamp_to_string,
            ('TimestampNTZType', 'StringType'): _timestamp_to_string,
            ('DateType', 'StringType'): lambda c: date_format(c, lit('yyyy-MM-dd')),
            ('DayTimeIntervalType', 'StringType'): _cast_to("string"),
            ('ArrayType', 'StringType'): to_json,
            ('MapType', 'StringType'): to_json,
            ('StructType', 'StringType'): to_json,
            ('DecimalType', 'DoubleType'): _cast_to('double'),
            # more vicious cases, caused by literals to not always be of the right type
            ('ByteType', 'LongType'): _cast_to('bigint'),
            ('ShortType', 'LongType'): _cast_to('bigint'),
            ('IntegerType', 'LongType'): _cast_to('bigint'),
            ('ByteType', 'IntegerType'): _cast_to('int'),
            ('ShortType', 'IntegerType'): _cast_to('int'),
            ('ByteType', 'ShortType'): _cast_to('short'),
            # we passed an iso8601 timestamp string, and it needs to go in a timestamp column
            ('StringType', 'TimestampType'): _cast_to('timestamp'),
        }

        try:
            tdf = df.sparkSession.sql("select * from %s" % qualified_table_id)
            _, target_column_fields = self._dialect._get_components_from_df_schema(tdf.schema)
            for column_name in column_names:
                field = column_fields[column_name]
                output_field = target_column_fields.get(column_name)
                if output_field is None:
                    continue  # not a good sign, we're inserting but the output column isn't there...
                type_pair = (
                    self._dialect._get_datatype_name_from_df_datatype(field["datatype"]),
                    self._dialect._get_datatype_name_from_df_datatype(output_field["datatype"]),
                )
                convert = conversions.get(type_pair)
                if convert is not None:
                    df = df.withColumn(column_name, convert(col(column_name)))
        except Exception as e:
            # best effort: insertion proceeds with whatever types the dataframe has
            logging.warning("Unable to check output schema, inserting as is : %s", e)
        return df

    def _get_table_schema(self, schema, connection_params):
        """
        Return ``schema`` when it is non-blank; otherwise fall back to
        ``self._get_connection_param(connection_params, "db", "db")`` (presumably the connection's default database).
        """
        if not (schema and schema.strip()):
            return self._get_connection_param(connection_params, "db", "db")
        return schema

    def _get_table_catalog(self, catalog, connection_params):
        """
        Return ``catalog`` when it is non-blank; otherwise fall back to
        ``self._get_connection_param(connection_params, "defaultCatalog", "catalog")``
        (presumably the connection's default catalog).
        """
        if not (catalog and catalog.strip()):
            return self._get_connection_param(connection_params, "defaultCatalog", "catalog")
        return catalog
    