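"""
Centralised, opinionated logging helpers.

Every log record gets the fields:
    req_id   – UUID v4 per HTTP/WebSocket request (set by `set_request_id`, or "-")
    user     – authenticated user (or "-")
    method   – HTTP verb (or "-")
    path     – request path (or "-")

The console handler installed by `configure_logging` also redacts values stored
under sensitive keys (see `SENSITIVE_KEYS`) in dict, list and JSON log arguments.
"""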

from __future__ import annotations

import json
import logging
import logging.config
import os
import re
import sys
import uuid
from copy import copy, deepcopy
from importlib import import_module
from typing import Any

from flask import g, has_request_context, request


def _get_cfg() -> dict[str, Any]:
    """
    Import backend.config **lazily** to dodge circular imports.
    Returns {} if import fails (very early bootstrap).
    """
    try:
        cfg_mod = import_module("backend.config")
        return cfg_mod.get_config()  # type: ignore[attr-defined]
    except Exception as e:
        print(f"Exception layzy loading module backend.config {e}")
        return {}


class _RequestContextFilter(logging.Filter):
    """Stamp each log record with the current Flask request's context fields."""

    def filter(self, record: logging.LogRecord) -> bool:  # noqa: D401
        """Set req_id/user/method/path on `record`; "-" for each outside a request.

        Always returns True, so this filter never drops a record.
        """
        if not has_request_context():
            for field in ("req_id", "user", "method", "path"):
                setattr(record, field, "-")
            return True

        record.req_id = getattr(g, "request_id", "-")
        record.user = getattr(g, "authIdentifier", "-")
        record.method, record.path = request.method, request.path
        return True

# Keys are lower-case because dict keys are lower-cased before the lookup,
# which makes matching case-insensitive. Matching is exact, not substring:
# "Token" is redacted, "max_tokens" is not.
SENSITIVE_KEYS: set[str] = {
    "api_key",
    "apikey",
    "authorization",
    "credentials",
    "password",
    "secret",
    "token",
    "dkucallerticket",
}
# Value written in place of anything stored under a sensitive key.
REDACTED_VALUE = "REDACTED"


class _RedactSensitiveDataFilter(logging.Filter):
    """Logging filter that scrubs sensitive values from records in place.

    Applies `redact_sensitive_data` to ``record.msg`` when it is a dict or list,
    and to ``record.args`` (positional tuple or mapping form). Never drops records.
    """

    def filter(self, record: logging.LogRecord) -> bool:
        """Redact `record` in place and return True (the record is always kept).

        If redaction raises (e.g. a self-referencing structure hits the
        recursion limit), the message is replaced by a fixed placeholder and the
        args are cleared. The error neither propagates into the caller's logging
        call nor lets unredacted data through.
        """
        try:
            if isinstance(record.msg, (dict, list)):
                record.msg = redact_sensitive_data(record.msg)

            # record.args is a tuple (positional args) or a dict (LogRecord
            # unwraps a single non-empty mapping argument for %(key)s formatting).
            if record.args:
                if isinstance(record.args, tuple):
                    record.args = tuple(redact_sensitive_data(arg) for arg in record.args)
                elif isinstance(record.args, dict):
                    record.args = redact_sensitive_data(record.args)
        except Exception as exc:  # fail closed: logging must not raise into callers
            record.msg = f"[log message suppressed: redaction failed ({type(exc).__name__})]"
            record.args = ()
        return True


def redact_sensitive_data(data: Any) -> Any:
    """Recursively redact the values of sensitive keys in a data structure.

    - Dicts and lists are walked recursively. Each container on the way is
      shallow-copied (keeping its concrete type, e.g. OrderedDict) and its
      entries are replaced, so the caller's data is never mutated.
    - A string whose text (after leading whitespace) starts with ``{`` or ``[``
      is parsed as JSON, redacted, and re-serialised with ``indent=2`` and
      sorted keys. Any other string, including JSON scalars like ``"1.50"``,
      is returned unchanged.
    - Dict keys are lower-cased and compared against `SENSITIVE_KEYS`.
      Non-string keys are never treated as sensitive.
    - Every other value is returned as-is (same object).
    """
    if isinstance(data, str):
        # Only JSON objects/arrays can hold sensitive keys. Skipping other
        # strings avoids a parse attempt per log argument and keeps
        # numeric-looking strings byte-for-byte intact.
        if not data.lstrip().startswith(("{", "[")):
            return data
        try:
            parsed = json.loads(data)
        except json.JSONDecodeError:
            return data  # Looks like JSON but is not valid JSON.
        return json.dumps(redact_sensitive_data(parsed), indent=2, sort_keys=True)

    if isinstance(data, dict):
        clean = copy(data)
        # Iterate the original and write into the copy: no mutation during iteration.
        for key, value in data.items():
            if isinstance(key, str) and key.lower() in SENSITIVE_KEYS:
                clean[key] = REDACTED_VALUE
            else:
                clean[key] = redact_sensitive_data(value)
        return clean

    if isinstance(data, list):
        clean = copy(data)
        for index, item in enumerate(data):
            clean[index] = redact_sensitive_data(item)
        return clean

    return data


def configure_logging() -> None:
    """Install the application's logging configuration via `dictConfig`.

    The root logger gets a single console `StreamHandler` whose records pass
    through `_RequestContextFilter` (adds req_id/user/method/path) and
    `_RedactSensitiveDataFilter`. Loggers created before this call stay enabled.

    The level comes from the ``logLevel`` config key (case-insensitive).
    ``INFO`` is used when the key is missing, ``None`` or empty. It is also used
    when the value is not a level name known to :mod:`logging`; in that case a
    warning is logged instead of letting `dictConfig` raise at startup.
    """
    cfg = _get_cfg()
    if cfg is None:
        cfg = {}
    raw_level = cfg.get("logLevel")
    requested = "" if raw_level is None else str(raw_level).strip().upper()
    # getLevelName maps a registered level *name* to its int value; anything
    # unknown comes back as the string "Level <name>".
    level_is_valid = isinstance(logging.getLevelName(requested), int)
    log_level = requested if level_is_valid else "INFO"

    fmt = "[%(asctime)s.%(msecs)03d] [%(threadName)s] [%(levelname)s] [%(name)s] user=%(user)s - %(message)s"
    datefmt = "%Y/%m/%d-%H:%M:%S"
    formatter = {"format": fmt, "datefmt": datefmt}
    logging.config.dictConfig(
        {
            "version": 1,
            "disable_existing_loggers": False,
            "filters": {
                "ctx": {"()": _RequestContextFilter},
                "redact": {"()": _RedactSensitiveDataFilter},
            },
            "formatters": {"default": formatter},
            "handlers": {
                "console": {
                    "class": "logging.StreamHandler",
                    "formatter": "default",
                    "filters": ["ctx", "redact"],
                    "level": log_level,
                }
            },
            "root": {"handlers": ["console"], "level": log_level},
        }
    )

    if requested and not level_is_valid:
        # Emitted after dictConfig so it goes through the handler just installed.
        logging.getLogger(__name__).warning(
            "Unknown logLevel %r in config; falling back to %s", raw_level, log_level
        )


def get_logger(name: str) -> logging.Logger:
    """Return the logger registered under `name`.

    Project code should **always** obtain loggers through this helper rather
    than calling `logging.getLogger` directly, so acquisition stays uniform.
    """
    logger = logging.getLogger(name)
    return logger


def set_request_id() -> None:
    """
    Give the current request a UUID4 id stored as `g.request_id`.

    Outside a request context this does nothing. An id that is already set is
    kept, so repeated calls within one request are harmless.
    Call once in *every* Flask/SocketIO entry-point.
    """
    if not has_request_context():
        return
    if hasattr(g, "request_id"):
        return
    g.request_id = str(uuid.uuid4())


def sanitize_messages_for_log(messages: list[dict] | None) -> list[dict] | None:
    """Return a log-safe version of a list of messages.

    - In each message's ``__documents__`` list, replaces the bulky ``text`` and
      ``text_chunks`` fields of every document with ``"[redacted]"`` to keep
      logs small.
    - Redacts sensitive data like passwords, tokens, etc. via `redact_sensitive_data`.

    Only the message and document dicts that change are shallow-copied here
    (no deep copy), so the caller's ``messages`` are never mutated.
    Returns ``None`` when ``messages`` is ``None`` or empty.
    """
    if not messages:
        return None

    def strip_document(doc: Any) -> Any:
        # Copy a document only if it actually carries a bulky field.
        if not isinstance(doc, dict):
            return doc
        bulky = [field for field in ("text", "text_chunks") if field in doc]
        if not bulky:
            return doc
        trimmed = copy(doc)
        for field in bulky:
            trimmed[field] = "[redacted]"
        return trimmed

    def strip_message(message: Any) -> Any:
        # Copy a message only if it has a `__documents__` list to rewrite.
        if not isinstance(message, dict):
            return message
        docs = message.get("__documents__")
        if not isinstance(docs, list):
            return message
        trimmed = copy(message)
        trimmed["__documents__"] = [strip_document(doc) for doc in docs]
        return trimmed

    return redact_sensitive_data([strip_message(m) for m in messages])


# One or more leading "qualified.Name:" tokens (e.g. "pkg.mod.SomeError: "),
# each followed by whitespace or end of string. A colon that is not followed by
# whitespace (URLs, "host:port", "12:30") never ends a prefix.
_QUALIFIED_NAME_PREFIX = re.compile(r"^\s*(?:[A-Za-z_][\w.]*:(?:\s+|$))+")


def extract_error_message(error_str: str) -> str:
    """Extract the meaningful message from an error string, removing package names.

    Strips leading ``dotted.Name:`` prefixes, such as exception class or module
    paths, and keeps the rest intact, including any colons inside the message:

        "pkg.mod.SomeError: bad URL: http://h:80/x"  ->  "bad URL: http://h:80/x"

    If nothing is left after stripping (e.g. ``"ValueError:"``), the last
    non-empty colon-separated segment is returned. If there is none, the input
    is returned unchanged.
    """
    message = _QUALIFIED_NAME_PREFIX.sub("", error_str, count=1).strip()
    if message:
        return message
    parts = [part.strip() for part in error_str.split(":") if part.strip()]
    return parts[-1] if parts else error_str