diff --git a/app/utils/logging_utils.py b/app/utils/logging_utils.py index ec6ff903..a2b4d205 100644 --- a/app/utils/logging_utils.py +++ b/app/utils/logging_utils.py @@ -35,9 +35,14 @@ def format(self, record: logging.LogRecord) -> str: Formatted log message including ANSI color codes. """ + original_levelname = record.levelname level_color = self.COLORS.get(record.levelno, "") - record.levelname = f"{level_color}{record.levelname}{self.RESET}" - return super().format(record) + if level_color: + record.levelname = f"{level_color}{record.levelname}{self.RESET}" + try: + return super().format(record) + finally: + record.levelname = original_levelname class RateLimitFilter(logging.Filter): diff --git a/tests/test_color_formatter.py b/tests/test_color_formatter.py index a491d291..f9bcb0f8 100644 --- a/tests/test_color_formatter.py +++ b/tests/test_color_formatter.py @@ -24,6 +24,37 @@ def test_format_includes_ansi_codes(self): self.assertIn("\033[", output) self.assertIn("boom", output) + def test_plain_handler_is_uncolored_with_colored_sibling(self): + logger = logging.getLogger("color_test_plain") + colored_stream = io.StringIO() + plain_stream = io.StringIO() + colored_handler = logging.StreamHandler(colored_stream) + colored_handler.setFormatter(ColorFormatter("%(levelname)s:%(message)s")) + plain_handler = logging.StreamHandler(plain_stream) + plain_handler.setFormatter(logging.Formatter("%(levelname)s:%(message)s")) + logger.setLevel(logging.INFO) + original_propagate = logger.propagate + logger.propagate = False + logger.addHandler(colored_handler) + logger.addHandler(plain_handler) + + try: + logger.warning("mixed output") + finally: + logger.removeHandler(colored_handler) + logger.removeHandler(plain_handler) + colored_handler.close() + plain_handler.close() + logger.propagate = original_propagate + + colored_output = colored_stream.getvalue().strip() + plain_output = plain_stream.getvalue().strip() + + self.assertIn("\033[", colored_output) + self.assertIn("mixed output", colored_output) + self.assertNotIn("\033[", plain_output) + self.assertIn("mixed output", plain_output) + if __name__ == "__main__": unittest.main()