import logging
import os
import threading
import time

from dataiku.base.utils import is_os_windows

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Delay, in seconds, between two checks of the parent pid by the monitoring thread.
POLL_PARENT_PID_TIME_SECONDS = 5


def _has_parent_process_died():
    """ Check if parent process is 'init' (1), meaning it has died

        An orphaned process is expected to be re-attached to 'init', so the parent
        is considered dead as soon as `os.getppid()` reports 1.
        DO NOT use this method if the parent is already 'init' (1): it would report
        a dead parent right away.

        Only works on Unix systems, not Windows

        Returns True if the parent pid is 1, or if the parent pid could not be
        fetched (the failure is logged and treated as a dead parent), False otherwise.
    """
    try:
        parent_pid = os.getppid()
    except Exception:
        # `Exception` rather than a bare `except:` so that SystemExit and
        # KeyboardInterrupt are never swallowed and mistaken for a dead parent.
        logger.exception("could not fetch parent process information")
        return True

    return parent_pid == 1  # Parent process (kernel) has been replaced by init


def watch_parent_process():
    """
    Start a daemon thread that polls the parent of the current process and terminates
    the current process as soon as that parent is detected as dead.

    Distributed training runs several processes: when one of them dies, the others must
    go away too so that resources are freed.

    Failure modes:
    - DSS crashes: excluded, not handled here
    - Kernel dies (also the pytorch process when distributed): handled here
    - Subprocess dies: covered by the timeout parameter of torch.distributed.init_process_group,
                       which defaults to 30 minutes
    - Abort in the UI: same as the kernel dying

    Detection is done by polling: every `POLL_PARENT_PID_TIME_SECONDS` seconds the thread
    checks whether the parent became init (1), in which case the parent died.

    This serves the same purpose as `watch_stdin`, which exits when the pipe shared with the
    parent process breaks, but that approach wasn't reliable enough for pytorch 1.9 and
    `torch.distributed.run`.

    In container execution, a `subprocess` is spawned by `runner.py`.
    `runner.py` is taking the place of `init` inside the container and has process id 1.

    Only works on Unix systems, not Windows: on Windows this logs a message and returns
    without starting the thread.

    NOTE(review): the check relies on orphans being re-attached to pid 1; if a child
    subreaper adopts them instead, the death would presumably go unnoticed - confirm
    against the supported deployments.
    """
    monitor_name = "parent-process-monitor"

    if is_os_windows():
        logger.info("'{}' thread is not supported on windows".format(monitor_name))
        return

    parent_pid = os.getppid()
    logger.info("watching parent ({}) process".format(parent_pid))

    def _exit_once_orphaned():
        # Always wait a full period before the first check, then keep polling
        # for as long as the parent is still there.
        time.sleep(POLL_PARENT_PID_TIME_SECONDS)
        while not _has_parent_process_died():
            time.sleep(POLL_PARENT_PID_TIME_SECONDS)

        logger.info("parent process has been killed, current process is attached to init, exiting")

        # With a dead parent we want to leave immediately: os._exit is used instead of
        # sys.exit so as not to go through a clean interpreter finalization that would wait
        # for non-daemon threads (here the training running in the main thread).
        os._exit(0)

    threading.Thread(target=_exit_once_orphaned, name=monitor_name, daemon=True).start()
