"""Dataiku utilities"""
import multiprocessing
import math
import dateutil.relativedelta
import inspect
import json
import logging
import os
import os.path as osp
import select
import shutil
import tempfile
import threading
import traceback
import random
import string
import sys
import tarfile

from contextlib import contextmanager
from six import reraise

from dataiku.base.compat import ImpCompat

logger = logging.getLogger(__name__)


def get_clazz_in_code(code, parent_clazz, strict_module=False):
    """
        Returns the single class inheriting from parent_clazz that is defined in the given Python source.

        The source is written (UTF-8 encoded) to a throw-away file inside a temporary folder, loaded as a
        module named "dku_nomatter_module_name", then scanned with get_clazz_in_module.

        :param code: Python source code, as a string
        :param parent_clazz: base class the wanted class must inherit from
        :param strict_module: if True, only classes defined in the loaded code itself are considered
        :return: the matching class
    """
    system_tmp_dir = tempfile.gettempdir()
    with TmpFolder(system_tmp_dir) as work_dir:
        source_path = osp.join(work_dir, "dku_code.py")
        with open(source_path, "wb") as source_file:
            source_file.write(encode_utf8(code))
        loaded_module = ImpCompat.load_source("dku_nomatter_module_name", source_path)
        return get_clazz_in_module(loaded_module, parent_clazz, strict_module)


def get_clazz_in_module(module, parent_clazz, strict_module=False):
    """
    Returns the single class exposed by `module` that strictly inherits from parent_clazz.

    :param module: loaded module (any object whose attributes are listed by dir())
    :param parent_clazz: base class to look for subclasses of (parent_clazz itself is ignored)
    :param strict_module: if True, skip classes whose __module__ is not "dku_nomatter_module_name"
    :return: the matching class
    :raises Exception: if zero or more than one matching class is found
    """
    found = None
    for attr_name in dir(module):
        candidate = getattr(module, attr_name)
        if not inspect.isclass(candidate):
            continue
        if strict_module and candidate.__module__ != "dku_nomatter_module_name":
            continue
        if not issubclass(candidate, parent_clazz) or candidate is parent_clazz:
            continue
        if found is not None:
            raise safe_exception(Exception, u"Multiple classes inheriting {} defined, already had {} and found {}".format(parent_clazz, found, candidate))
        found = candidate
    if found is None:
        raise safe_exception(Exception, u"No class inherits {}".format(parent_clazz))
    return found

def get_argspec(f):
    """
    Return the argument specification of callable `f`.

    Uses inspect.getfullargspec on Python 3 and the legacy inspect.getargspec on Python 2.
    """
    if sys.version_info <= (3, 0):
        return inspect.getargspec(f)
    return inspect.getfullargspec(f)

class ErrorMonitoringWrapper:
    """
        Context manager that monitors the execution of arbitrary code to catch potential errors, format them
        and dump them to a file on disk, so that the backend can retrieve them and display them in the UI.

        Typically wraps free user code executed (with the exec statement) at the top level of a script
        (See https://analytics.dataiku.com/projects/RDWIKI/wiki/About%20exec%20in%20python for more info on exec)

        :param exit_if_fail: exit code to use when the wrapped code fails, or None to not exit (default is 1)
        :param final_callback: callable run after the wrapped code ends or fails. Always run, except when
                               catch_sysexit is False and the wrapped code raised SystemExit (default None)
        :param error_file: path of the file where error information is dumped as JSON (default is "error.json")
        :param catch_sysexit: whether a SystemExit raised by the wrapped code is handled like any other error (True)
                              or propagated immediately (False). Often False, because DSS has another mechanism
                              to catch those kinds of errors (default is False)
    """

    def __init__(self, exit_if_fail=1, final_callback=None, error_file="error.json",
                 catch_sysexit=False):
        self.exit_if_fail = exit_if_fail
        self.final_callback = final_callback
        self.error_file = error_file
        self.catch_sysexit = catch_sysexit

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_val, exc_tb):
        failed = exc_type is not None

        if failed:
            if exc_type is SystemExit and not self.catch_sysexit:
                # Let the exit propagate: no report, no final callback
                return False
            self._report_failure(exc_type, exc_val, exc_tb)

        if callable(self.final_callback):
            self.final_callback()

        if failed and self.exit_if_fail is not None:
            sys.exit(self.exit_if_fail)

        # Swallow the exception (if any) when we did not exit
        return True

    def _report_failure(self, exc_type, exc_val, exc_tb):
        """Print the stack to stderr between smart-log markers and dump a JSON summary to error_file."""
        sys.stderr.write("*************** Recipe code failed **************\n")
        sys.stderr.write("Begin Python stack\n")  # Smart log marker
        # Called from within __exit__, so sys.exc_info() still refers to the wrapped code's exception
        traceback.print_exc()
        sys.stderr.write("End Python stack\n")  # Smart log marker

        line_prefix = self._user_code_line_prefix(exc_tb)

        with open(self.error_file, "w") as error_out:
            summary = {
                "detailedMessage": u"{}{}: {}".format(line_prefix, safe_unicode_str(exc_type), safe_unicode_str(exc_val)),
                "errorType": safe_unicode_str(exc_type),
                "message": safe_unicode_str(exc_val)
            }
            error_out.write(json.dumps(summary))

    def _user_code_line_prefix(self, tb):
        """Return "At line N: " for the first frame of exec'd top-level code ("<string>"/"<module>"), else ""."""
        while tb is not None:
            frame = tb.tb_frame
            code = frame.f_code if frame is not None else None
            if code is not None and code.co_filename == "<string>" and code.co_name == "<module>":
                return u"At line {}: ".format(tb.tb_lineno)
            tb = tb.tb_next
        return u""

class RaiseWithTraceback:
    """
        A context manager that, when the code executed within the with statement raises an exception,
        raises instead a new generic Exception that carries the original traceback.

        :param fail_message: message that the new Exception will contain (default '')
        :param add_err_in_message: whether to append the original exception message to the new Exception
                                   message (default True)

        Works both for python 2 and 3.

        Example (run in python 2):

          with RaiseWithTraceback("Bad error"):
            1 / 0

          => raises an: Exception: Bad error, Error: integer division or modulo by zero
             and displays the traceback
    """

    def __init__(self, fail_message='', add_err_in_message=True):
        self.fail_message = fail_message
        self.add_prev_err_in_message = add_err_in_message

    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_val, exc_tb):
        has_failed = exc_type is not None

        if has_failed:
            if self.add_prev_err_in_message:
                # "<fail_message>, Error: <original message>" or just "Error: <original message>"
                new_error_message = (self.fail_message + ", ") if self.fail_message else ''
                new_error_message += u"Error: {}".format(safe_unicode_str(exc_val))
            else:
                new_error_message = self.fail_message

            # six.reraise raises the new exception with the ORIGINAL traceback attached (py2 and py3)
            reraise(Exception,
                    safe_exception(Exception, new_error_message),
                    exc_tb)

def watch_stdin():
    """
    Start a daemon thread that terminates the process as soon as stdin is closed (or errors),
    so that this process does not outlive the parent backend.

    Disabled on Windows, where select() on stdin is not supported. Disabling it there doesn't
    seem to break anything, but further testing might find flawed behavior.
    """
    if is_os_windows():
        return

    def _exit_when_stdin_closes():
        try:
            while True:
                # Block in select instead of read so as not to hang sys.exit() on Suse
                readable, _, _ = select.select([sys.stdin], [], [])
                if sys.stdin in readable:
                    # An empty read means EOF: the parent closed our stdin
                    if not sys.stdin.read(1):
                        logging.warning("Standard input closed, exiting")
                        os._exit(0)
        except IOError:
            logging.warning("Error reading standard input, exiting", exc_info=True)
            os._exit(1)

    watcher = threading.Thread(name="stdin-watcher", target=_exit_when_stdin_closes)
    watcher.daemon = True
    watcher.start()

def package_version_compat(package_version_str):
    """
    Parse a version string into a comparable version object.

    Prefers packaging.version.parse, which avoids distutils deprecation warnings and is future-proof.
    Falls back to distutils' LooseVersion when the packaging library is not installed.
    """
    try:
        from packaging.version import parse as parse_version
    except ImportError:
        from distutils.version import LooseVersion
        return LooseVersion(package_version_str)
    return parse_version(package_version_str)

def check_base_package_version(p, name, min_version, max_version, error_details):
    """
    Check that an imported base package's version lies within [min_version, max_version].

    When running inside DSS (the DKU_API_TICKET env var is present), an out-of-range version raises.
    Otherwise a warning is emitted.

    :param p: imported package module; must expose __version__
    :param name: display name of the package, used in messages
    :param min_version: minimum allowed version string, or None for no lower bound
    :param max_version: maximum allowed version string, or None for no upper bound
    :param error_details: extra text included in the error/warning message
    :raises Exception: inside DSS, when the version is out of range
    """
    import warnings

    from dataiku.base import remoterun
    is_in_dss = remoterun.has_env_var("DKU_API_TICKET")

    # Compare through package_version_compat instead of importing distutils' LooseVersion directly:
    # distutils was removed in Python 3.12, and LooseVersion misorders pre-releases (e.g. "1.0rc1" > "1.0").
    current_version = package_version_compat(p.__version__)

    if max_version is not None and current_version > package_version_compat(max_version):
        if is_in_dss:
            raise safe_exception(Exception, u"Base package {} is too recent: version {} was found. {}. You should not install overriding versions of DSS base packages.".format(name, p.__version__, error_details))
        else:
            warnings.warn(u"Package {} is too recent: version {} was found. {}. Some features may malfunction.".format(name, p.__version__, error_details), Warning)
    if min_version is not None and current_version < package_version_compat(min_version):
        if is_in_dss:
            raise safe_exception(Exception, u"Base package {} is too old: version {} was found. {}. You should not install overriding versions of DSS base packages.".format(name, p.__version__, error_details))
        else:
            warnings.warn(u"Package {} is too old: version {} was found. {}. Some features may malfunction.".format(name, p.__version__, error_details), Warning)

def package_is_at_least(p, min_version):
    """Return True if the imported package ``p`` has a ``__version__`` >= ``min_version`` (parsed comparison)."""
    installed = package_version_compat(p.__version__)
    required = package_version_compat(min_version)
    return installed >= required


def package_is_exactly(p, exact_version):
    """Return True if the imported package ``p`` has a ``__version__`` equal to ``exact_version`` (parsed comparison)."""
    installed = package_version_compat(p.__version__)
    expected = package_version_compat(exact_version)
    return installed == expected


def package_is_at_least_no_import(package_name, min_version):
    """
    Check an installed package's version against a minimum, without importing the package.

    :type package_name: str
    :type min_version: str
    :return: False if no version is found for the package, else whether its version >= min_version
    """
    installed_version = build_package_versions([package_name]).get(package_name)
    if installed_version:
        return package_version_compat(installed_version) >= package_version_compat(min_version)
    return False


def build_package_versions(package_names):
    """
    Build a dictionary mapping each installed package name to its version string.

    importlib.metadata (Python 3.8+) is used when available. Otherwise, falls back to the slower,
    deprecated pkg_resources. Packages that are not installed are simply left out. An empty dict is
    returned when neither API is available.

    :param package_names: list of packages as str
    :return: dict
    """
    versions = {}
    try:
        from importlib.metadata import PackageNotFoundError, version as get_version
        for pkg in package_names:
            try:
                versions[pkg] = get_version(pkg)
            except PackageNotFoundError:
                continue
    except ImportError:
        try:
            import pkg_resources
            from pkg_resources import DistributionNotFound
            for pkg in package_names:
                try:
                    versions[pkg] = pkg_resources.get_distribution(pkg).version
                except (NameError, DistributionNotFound):
                    continue
        except ImportError:
            pass
    return versions


def get_json_friendly_error(additional_fields=None):
    """
    Describe the exception currently being handled as a JSON-serializable dict.

    Meant to be called from within an ``except`` block, since it reads ``sys.exc_info()``.

    :param additional_fields: optional dict of extra entries merged into the result
                              (they override the default keys on collision)
    :return: dict with 'errorType', 'message' and 'traceback' (a list of
             (filename, lineno, function name, source line) tuples), plus additional_fields
    """
    ex_type, ex, tb = sys.exc_info()
    frames = traceback.extract_tb(tb)

    def friendlify(f):
        # Python 2 yields plain tuples; Python 3 yields FrameSummary objects, which json cannot serialize
        if isinstance(f, tuple):
            return f
        return (f.filename, f.lineno, f.name, f.line)

    result = {
        'errorType': safe_unicode_str(ex_type),
        'message': safe_unicode_str(ex),
        'traceback': [friendlify(f) for f in frames]
    }
    # None instead of a shared mutable `dict()` default argument
    if additional_fields is not None:
        result.update(additional_fields)
    return result


def safe_unicode_str(o):
    """
    Best-effort conversion of any object (notably exceptions) to a unicode string, on Python 2 and 3.

    Exceptions are rendered as follows, in order of precedence:
      - objects with a ``desc`` attribute (e.g. Spark's AnalysisException): ``desc``
      - EnvironmentError with errno/strerror set: "[Errno N] strerror", plus ": 'filename'"
        when the exception has a ``filename`` attribute
      - exceptions without args: ``str(o)`` / ``unicode(o)``, or '<No details>' if that fails
      - otherwise: the first constructor argument, converted recursively
    Byte strings are decoded with smart_decode_str. Any other object goes through
    ``str()`` (Python 3) or ``unicode()`` (Python 2).

    :param o: object to convert
    :return: ``str`` on Python 3, ``unicode`` on Python 2
    """
    if (isinstance(o, Exception)):
        if (hasattr(o, "desc")):
            # Special case for Spark's AnalysisException which has a "desc" field
            # (but its __str__ is badly formatted so we dont want it)
            return safe_unicode_str(o.desc)
        elif isinstance(o, EnvironmentError) and hasattr(o, "errno") and hasattr(o, "strerror") and not (o.errno is None and o.strerror is None):
            # Special handling for EnvironmentError because has multiple attributes ('errno', 'strerror')
            # Most common is IOError that has an additional 'filename' attribute
            error_message = u"[Errno {}] {}".format(o.errno, safe_unicode_str(o.strerror))
            # NOTE(review): on Python 3, OSError always defines 'filename' (None when unset),
            # so ": 'None'" may be appended here -- confirm this is intended.
            if hasattr(o, "filename"):
                error_message += u": '{}'".format(safe_unicode_str(o.filename))
            return error_message
        elif (o.args is None) or (len(o.args) == 0):
            # Exception has no args, try to convert directly the exception to Unicode
            try:
                if sys.version_info > (3,0):
                    return str(o)
                else:
                    return unicode(o)
            except Exception as e:
                return safe_unicode_str('<No details>')
        else:
            # Only the first constructor argument is used; any further args are dropped
            return safe_unicode_str(o.args[0])
    else:
        if sys.version_info > (3, 0):
            # Python 3 special handling
            if (isinstance(o, str)):
                return o
            elif (isinstance(o, bytes)):
                try:
                    return smart_decode_str(o)
                except UnicodeDecodeError:
                    # Defensive only: smart_decode_str falls back to latin-1, which accepts any byte sequence
                    return str(o)
            else:
                return str(o)
        else:
            # Python 2 special handling
            if (isinstance(o, unicode)):
                return o
            elif (isinstance(o, str)):
                return smart_decode_str(o)
            else:
                try:
                    return unicode(o)
                except UnicodeDecodeError:
                    # There will be no infinite loop as Python guarantees that 'str(o)' will produce a string
                    return safe_unicode_str(str(o))


def smart_decode_str(o):
    """
    Decode a byte string to text, trying UTF-8 first (most common), then ISO-8859-1 (latin-1).

    Latin-1 maps every byte to a character, so for byte input the second attempt always succeeds.
    The final lossy UTF-8 decode (dropping undecodable bytes) is only a defensive fallback.

    :param o: byte string (``bytes``/``bytearray`` on Python 3, ``str`` on Python 2)
    :return: decoded text
    """
    for encoding in ('utf-8', 'iso-8859-1'):
        try:
            return o.decode(encoding, 'strict')
        except UnicodeDecodeError:
            pass
    return o.decode('utf-8', 'ignore')


def random_string(length):
    """Return `length` random ASCII letters (a-z, A-Z), drawn with the non-cryptographic `random` module."""
    alphabet = string.ascii_letters
    picked = [random.choice(alphabet) for _ in range(length)]
    return ''.join(picked)


class TmpFolder:
    """
        Helper that creates a uniquely named temporary folder ("tmp_folder_" + 8 random letters)
        inside another folder.

        Note that the folder is created as soon as the object is constructed, not on __enter__.

        Meant to be used in a with statement:
          - __enter__ returns the path of the new folder
          - the folder and all its content are deleted when leaving the with statement

        Example:
            import os
            with TmpFolder("/path/to/parent/folder") as tmp_folder_path:
                file_in_folder_path = os.path.join(tmp_folder_path, "new-file.txt")
                with open(file_in_folder_path, 'w') as f:
                    f.write("this is a new file")
                os.rename(file_in_folder_path, "/new/path")

        Args:
            parent_folder (str): path of the folder in which the temporary folder will be created
    """

    def __init__(self, parent_folder):
        folder_name = "tmp_folder_{}".format(random_string(8))
        self._folder_path = osp.join(parent_folder, folder_name)
        os.makedirs(self._folder_path)

    def __enter__(self):
        return self._folder_path

    def __exit__(self, exc_type, exc_val, exc_tb):
        # The folder may already have been removed or moved away by the wrapped code
        if not osp.isdir(self._folder_path):
            return
        shutil.rmtree(self._folder_path)


def safe_exception(cls, msg):
    """
    Instantiate exception class `cls` with a message of the type that displays correctly:
    unicode messages are UTF-8 encoded on Python 2, and messages are passed through unchanged on Python 3.
    """
    encode_for_py2 = sys.version_info[0] == 2 and isinstance(msg, unicode)
    return cls(msg.encode("utf-8") if encode_for_py2 else msg)


def encode_utf8(s):
    """
    Encode text to UTF-8 bytes: ``unicode`` on Python 2, ``str`` on Python 3.
    Any other value (bytes, None, numbers...) is returned unchanged.
    """
    if sys.version_info[0] == 2:
        text_type = unicode
    else:
        text_type = str
    if isinstance(s, text_type):
        return s.encode("utf-8")
    return s


def safe_convert_to_string(series):
    """
    Convert a pandas Series to object dtype for string handling.

    On Python 2 only, a series whose first value is ``unicode`` is UTF-8 encoded to byte strings instead.

    :param series: pandas Series (may be empty)
    :return: new pandas Series
    """
    if sys.version_info > (3, 0):
        return series.astype(object)
    # Python 2 only: peek at the first value to decide whether to encode. The length guard
    # avoids an IndexError on empty series (previously raised even on Python 3).
    if len(series) > 0 and isinstance(series.iloc[0], unicode):
        return series.str.encode("utf-8")
    return series.astype(object)


@contextmanager
def contextualized_thread_name(suffix):
    """
    Temporarily rename the current thread to "<current name>:<suffix>" for the duration of the with block.
    The original name is restored on exit, even if the block raises.
    """
    thread = threading.current_thread()
    saved_name = thread.name
    thread.name = "%s:%s" % (saved_name, suffix)
    try:
        yield
    finally:
        thread.name = saved_name


def duration_HHMMSS(total_seconds):
    """
    Convert a number of seconds to a compact "Xh Ym Zs" string.

    Zero components are omitted (so 0 gives ""), and fractional seconds are truncated.
    Hours are not capped at 23, so durations of a day or more are expressed in hours (90000 -> "25h").
    Negative durations give negative components (-90 -> "-1m -30s").

    :param total_seconds: duration in seconds (int or float)
    :return: formatted duration string
    """
    # Split the magnitude and re-apply the sign to each component, so negative durations
    # render the same way dateutil's relativedelta normalization did. relativedelta folded
    # hours > 23 into a `days` field that was ignored, which dropped whole days.
    sign = -1 if total_seconds < 0 else 1
    minutes, seconds = divmod(abs(total_seconds), 60)
    hours, minutes = divmod(minutes, 60)

    strings = []
    for amount, unit in ((hours, "h"), (minutes, "m"), (seconds, "s")):
        whole = int(sign * amount)
        if whole:
            strings.append(str(whole) + unit)
    return " ".join(strings)


def is_os_windows():
    """Return True when running on Windows, i.e. when ``sys.platform`` starts with "win32"."""
    return sys.platform[:5] == "win32"


def _cpu_count_from_quota(quota, period, total_count):
    """Convert a CFS quota/period pair to a CPU count capped at total_count; None if non-positive (no limit)."""
    if quota > 0 and period > 0:
        # Round up, as the cgroups v1 path always did. The cgroups v2 path used to truncate,
        # turning a quota below one full CPU (e.g. 50000/100000) into 0 usable CPUs.
        return int(min(total_count, math.ceil(quota / period)))
    return None


def detect_usable_cpu_count():
    """
    Estimate the nb. of usable CPUs after taking cgroups quotas into account
    (from http://blog.tabanpour.info/projects/2018/09/07/tf-docker-kube.html)

    Reads the CFS quota from cgroups v1 (cpu.cfs_quota_us / cpu.cfs_period_us) or, failing that,
    from cgroups v2 (cpu.max). The result is capped at multiprocessing.cpu_count(). It falls back to
    that value when no quota applies or the quota files cannot be read (the error is logged).

    :return: number of usable CPUs
    """
    total_count = multiprocessing.cpu_count()
    usable_count = None

    # Cgroups v1
    cfs_period_path = "/sys/fs/cgroup/cpu/cpu.cfs_period_us"
    cfs_quota_path = "/sys/fs/cgroup/cpu/cpu.cfs_quota_us"

    # Cgroups v2
    cgroup_v2_quota = "/sys/fs/cgroup/cpu.max"

    try:
        if osp.exists(cfs_quota_path) and osp.exists(cfs_period_path):
            # we are in a linux container with cpu quotas!
            # Non-positive values (cgroups v1 uses -1 for "no limit") are ignored
            with open(cfs_period_path, 'rb') as period_file, open(cfs_quota_path, 'rb') as quota_file:
                period = float(period_file.read())
                quota = float(quota_file.read())
            usable_count = _cpu_count_from_quota(quota, period, total_count)

        elif osp.exists(cgroup_v2_quota):
            with open(cgroup_v2_quota) as f:
                quota_str, period_str = f.read().split()
            # The quota is the literal "max" when unlimited. isdigit() (unlike isnumeric())
            # also exists on Python 2 byte strings.
            if quota_str.isdigit() and period_str.isdigit():
                usable_count = _cpu_count_from_quota(float(quota_str), float(period_str), total_count)
    except Exception:
        logger.exception("Error occurred while reading cgroup quotas")

    logger.info("Detected %s/%s usable CPUs", "?" if usable_count is None else usable_count, total_count)
    return total_count if usable_count is None else usable_count

def tar_extractall(tar, *args, **kwargs):
    """
    Extract all members of an open TarFile, forwarding *args/**kwargs to TarFile.extractall.

    The "tar" extraction filter is selected explicitly when the running Python provides one.
    This hides the deprecation warning some interpreters emit when no filter is chosen.
    """
    # https://app.shortcut.com/dataiku/story/171363/hide-scary-warning-occurring-when-extracting-a-tar-archive-on-a-rhel-alma-python
    chosen_filter = getattr(tarfile, "tar_filter", None)
    tar.extraction_filter = chosen_filter
    tar.extractall(*args, **kwargs)
