import unittest

from common.backend.utils.json_utils import (
    REPLACEMENT_VALUE,
    mask_keys,
)


class TestMaskKeysInJson(unittest.TestCase):
    """Tests for ``mask_keys``.

    Each test passes a payload plus a collection of key names and checks that
    the value stored under every listed key is replaced by
    ``REPLACEMENT_VALUE`` while the rest of the payload is left as-is.
    """

    def test_simple_dict_with_sensitive_keys(self):
        """A top-level key is masked wholesale, regardless of what it contains."""
        data = {"text": "Hello world", "traces": {"trace1": ["log1", "log2"], "trace2": {"log3": "log4"}}}
        keys = {"traces"}
        expected = {"text": "Hello world", "traces": REPLACEMENT_VALUE}
        self.assertEqual(mask_keys(data, keys), expected)

    def test_embedded_stringified_dicts_in_list(self):
        """A key inside a stringified dict nested in a list is masked too."""
        import ast

        input_data = {
            "usedRetrieval": {
                "sources": [
                    {
                        "excerpt": "Some text",
                        "images": [
                            # This element is a string holding a dict repr, not a dict.
                            "{'full_folder_id': 'PORTAL.abc', 'file_data': 'base64data', 'index': 0}"
                        ],
                        "metadata": {"source_url": "https://example.com"},
                    }
                ]
            }
        }

        keys_to_mask = ["file_data"]
        result = mask_keys(input_data, keys_to_mask)

        masked_image = result["usedRetrieval"]["sources"][0]["images"][0]
        # Either representation is accepted: the element may come back still
        # stringified or already parsed into a dict.
        self.assertIsInstance(masked_image, (str, dict))

        if isinstance(masked_image, dict):
            # literal_eval only accepts strings/AST nodes; a dict needs no parsing.
            parsed_image = masked_image
        else:
            try:
                parsed_image = ast.literal_eval(masked_image)
            except (ValueError, TypeError, SyntaxError) as exc:
                self.fail(f"Image string was not parsable: {exc}")
            self.assertIsInstance(parsed_image, dict)

        self.assertEqual(parsed_image.get("file_data"), REPLACEMENT_VALUE)

    def test_nested_dict_with_sources_and_images(self):
        """Keys are masked at any nesting depth; sibling keys are untouched."""
        data = {
            "answer": {
                "content": "Here’s a summary",
                "sources": ["url1", "url2"],
                "meta": {"images": ["img1.png", "img2.png"], "author": "AI Bot"},
            }
        }
        keys = {"sources", "images"}
        expected = {
            "answer": {
                "content": "Here’s a summary",
                "sources": REPLACEMENT_VALUE,
                "meta": {"images": REPLACEMENT_VALUE, "author": "AI Bot"},
            }
        }
        self.assertEqual(mask_keys(data, keys), expected)

    def test_list_of_logs_with_traces(self):
        """A top-level list is traversed and every dict element is masked."""
        data = [{"message": "ok", "traces": ["stack1", "stack2"]}, {"message": "error", "traces": ["trace1"]}]
        keys = {"traces"}
        expected = [{"message": "ok", "traces": REPLACEMENT_VALUE}, {"message": "error", "traces": REPLACEMENT_VALUE}]
        self.assertEqual(mask_keys(data, keys), expected)

    def test_mask_keys_with_images(self):
        """A valid JSON string input is parsed and returned as a masked dict."""
        json_str = """
        {
            "result": {
                "summary": "done",
                "images": ["img1.jpg", "img2.jpg"]
            }
        }
        """
        keys_to_mask = ["images"]
        masked = mask_keys(json_str, keys_to_mask)
        expected = {"result": {"summary": "done", "images": REPLACEMENT_VALUE}}
        self.assertEqual(masked, expected)

    def test_invalid_json_string(self):
        """A string that is not valid JSON is returned unchanged."""
        invalid_json = '{"result": {"images": ["img1.jpg", "img2.jpg"]'  # Missing closing brace
        result = mask_keys(invalid_json, ["images"])
        self.assertEqual(result, invalid_json)

    def test_empty_and_null_data(self):
        """Empty containers and ``None`` pass through unchanged."""
        self.assertEqual(mask_keys({}, {"traces"}), {})
        self.assertEqual(mask_keys([], {"sources"}), [])
        self.assertIsNone(mask_keys(None, {"images"}))