import logging
import os
import threading

from dataiku.customwebapp import get_webapp_config
from flask import g, has_request_context


class _RequestContextFilter(logging.Filter):
    """
    Logging filter that stamps every record with the identity of the requesting user.

    While a Flask request context is active, ``record.user`` is set to ``g.authIdentifier``
    (or to an empty string when that attribute is not present on ``g``). Outside a request
    context, ``record.user`` is set to an empty string. Records are never rejected.
    """
    def filter(self, record: logging.LogRecord) -> bool:  # noqa: D401
        """Set `record.user` in place and always let the record through."""
        identity = getattr(g, "authIdentifier", "") if has_request_context() else ""
        record.user = identity
        return True


class ConditionalUserFormatter(logging.Formatter):
    """
    Formatter that renders the user field only when the record carries a non-empty user.

    Records whose ``user`` attribute is missing or falsy are rendered with a fixed fallback
    layout that omits the user. All other records use the format this formatter was
    constructed with. This keeps log lines concise when no user information is available
    (e.g. outside a request).
    """

    # Fallback layout for records without user information.
    _NO_USER_FORMAT = "%(asctime)s - %(name)s - [%(threadName)s (%(thread)d)] - %(levelname)s - %(message)s"

    def __init__(self, *args, **kwargs):
        """Accept the same arguments as `logging.Formatter`."""
        super().__init__(*args, **kwargs)
        # Built once instead of per record. It reuses this formatter's date format so that
        # timestamps look identical whether or not a user is present.
        self._no_user_formatter = logging.Formatter(self._NO_USER_FORMAT, datefmt=self.datefmt)

    def format(self, record: logging.LogRecord) -> str:
        """Format `record`, using the user-less fallback layout when `record.user` is falsy."""
        if not getattr(record, "user", None):
            return self._no_user_formatter.format(record)
        return super().format(record)


class LazyLogger:
    """
    Logger facade whose underlying stdlib logger is configured on first use.

    Configuration is deferred until the first logging call. At that point:

    - The level is read from the webapp config key ``log_level``. It falls back to ``INFO``
      when the config cannot be read, when the value is missing or empty, or when the name
      is not an attribute of ``logging``.
    - If the logger has no handlers yet, a console handler is attached. When the
      ``DEBUG_RUN_FOLDER`` environment variable is set, a file handler writing to
      ``<DEBUG_RUN_FOLDER>/logs/answers.log`` is attached as well. Every handler gets the
      request-context filter, and propagation to ancestor loggers is disabled.

    Configuration state lives on the class, so all instances share one underlying logger.

    Caveat: every logging method takes ``log_conv_id`` as its SECOND positional parameter.
    Lazy %-style arguments must therefore come after it, or ``log_conv_id`` must be passed
    by keyword. For example, ``logger.info("x=%s", x)`` binds ``x`` to ``log_conv_id``,
    not to the format string.
    """

    # Underlying stdlib logger; created by `_initialize_logger`.
    _logger = None
    # True once level and handlers have been configured.
    _initialized = False
    # Serializes first-time configuration. Without it, concurrent request threads could each
    # see an unconfigured logger and attach duplicate handlers, duplicating every log line.
    _init_lock = threading.Lock()

    @staticmethod
    def _resolve_level() -> int:
        """Return the numeric level named by the webapp config's ``log_level`` (default INFO).

        Raises:
            ValueError: if the name matches a ``logging`` attribute that is not an int level.
        """
        try:
            log_level = get_webapp_config().get('log_level', 'INFO')
        except Exception:
            # Deliberate best effort: fall back to INFO whenever the config cannot be read.
            log_level = 'INFO'
        # An explicit null/empty value would otherwise crash on `.upper()` on every log call.
        log_level = str(log_level or 'INFO')

        level = getattr(logging, log_level.upper(), logging.INFO)
        if not isinstance(level, int):
            raise ValueError(f'Invalid log level: {log_level}')
        return level

    @staticmethod
    def _attach_handlers(target: logging.Logger) -> None:
        """Attach the console handler, plus the debug-run file handler if enabled, to `target`."""
        # Includes user information when the record carries it.
        formatter = ConditionalUserFormatter(
            "%(asctime)s - %(name)s - [%(threadName)s (%(thread)d)] - %(levelname)s - user=%(user)s - %(message)s"
        )
        # Injects the current request's user into each record.
        user_filter = _RequestContextFilter()

        handlers = []
        debug_run_folder_path = os.getenv("DEBUG_RUN_FOLDER")
        if debug_run_folder_path:
            local_logs_path = os.path.join(debug_run_folder_path, "logs", "answers.log")
            os.makedirs(os.path.dirname(local_logs_path), exist_ok=True)
            # Explicit UTF-8 so non-ASCII messages don't depend on the platform locale.
            handlers.append(logging.FileHandler(local_logs_path, encoding="utf-8"))
        handlers.append(logging.StreamHandler())

        for handler in handlers:
            handler.setFormatter(formatter)
            # Every handler needs the filter. Handlers run in order, so a handler without it
            # would format the record before `record.user` has been set.
            handler.addFilter(user_filter)
            target.addHandler(handler)
        target.propagate = False

    @classmethod
    def _initialize_logger(cls):
        """Configure the underlying logger once; subsequent calls return immediately.

        Raises:
            ValueError: propagated from `_resolve_level`. The logger then stays
                uninitialized, so the next logging call retries.
        """
        if cls._initialized:
            return
        with cls._init_lock:
            # Another thread may have finished configuration while we waited for the lock.
            if cls._initialized:
                return
            level = cls._resolve_level()

            if cls._logger is None:
                cls._logger = logging.getLogger(__name__)
            cls._logger.setLevel(level)

            if not cls._logger.handlers:
                cls._attach_handlers(cls._logger)

            cls._initialized = True

    def _log_conv_id(self, msg: str) -> str:
        """Return `msg` suffixed with the current conversation id, or unchanged if it is falsy."""
        from common.backend.utils.context_utils import get_conv_id
        conv_id = get_conv_id()
        if not conv_id:
            return msg
        else:
            return f"{msg} (conv_id:'{conv_id}')"

    def _prepare_message(self, msg, log_conv_id: bool):
        """Ensure the logger is configured; append the conversation id to `msg` if requested."""
        self._initialize_logger()
        return self._log_conv_id(msg) if log_conv_id else msg

    # Each public method prepares the message BEFORE touching `self._logger`, because the
    # underlying logger does not exist until `_initialize_logger` has run.

    def debug(self, msg, log_conv_id: bool=False, *args, **kwargs):
        """Log `msg` at DEBUG level (see the class docstring for the `log_conv_id` caveat)."""
        msg = self._prepare_message(msg, log_conv_id)
        self._logger.debug(msg, *args, **kwargs) # type: ignore

    def info(self, msg, log_conv_id: bool=False, *args, **kwargs):
        """Log `msg` at INFO level (see the class docstring for the `log_conv_id` caveat)."""
        msg = self._prepare_message(msg, log_conv_id)
        self._logger.info(msg, *args, **kwargs) # type: ignore

    def warn(self, msg, log_conv_id: bool=False, *args, **kwargs):
        """Log `msg` at WARNING level (see the class docstring for the `log_conv_id` caveat)."""
        msg = self._prepare_message(msg, log_conv_id)
        self._logger.warning(msg, *args, **kwargs) # type: ignore

    # Stdlib-style name, so code written against `logging.Logger` works unchanged.
    warning = warn

    def error(self, msg, log_conv_id: bool=False, *args, **kwargs):
        """Log `msg` at ERROR level (see the class docstring for the `log_conv_id` caveat)."""
        msg = self._prepare_message(msg, log_conv_id)
        self._logger.error(msg, *args, **kwargs) # type: ignore

    def critical(self, msg, log_conv_id: bool=False, *args, **kwargs):
        """Log `msg` at CRITICAL level (see the class docstring for the `log_conv_id` caveat)."""
        msg = self._prepare_message(msg, log_conv_id)
        self._logger.critical(msg, *args, **kwargs) # type: ignore

    def exception(self, msg, log_conv_id: bool=False, *args, **kwargs):
        """Log `msg` at ERROR level with exception info; call from within an `except` block."""
        msg = self._prepare_message(msg, log_conv_id)
        self._logger.exception(msg, *args, **kwargs) # type: ignore


# Shared module-level logger; handlers and level are configured on its first logging call.
logger: LazyLogger = LazyLogger()
