Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/tabpfn_common_utils/telemetry/core/service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import logging
import os
import sys

from datetime import datetime
from functools import lru_cache
from posthog import Posthog
from .config import download_config
from .events import BaseTelemetryEvent
Expand Down Expand Up @@ -56,6 +58,10 @@ def telemetry_enabled(cls) -> bool:
Returns:
bool: True if telemetry is enabled, False otherwise.
"""
# Overwrite any telemetry if running in tests

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: typo, overwrite -> override

if cls._runs_in_test():
return False

# Disable telemetry by default in CI environments, but allow override
runtime = get_runtime()
default_disable = "1" if runtime.ci else "0"
Expand All @@ -73,6 +79,32 @@ def telemetry_enabled(cls) -> bool:

return True

@classmethod
@lru_cache(maxsize=1)
def _runs_in_test(cls) -> bool:
"""Auto-detect if the code is running in a test environment.

Returns:
bool: True if the code is running in a test environment, False otherwise.
"""
# Detect automatically set PyTest environment variables
default_env_vars = {"PYTEST_CURRENT_TEST", "PYTEST_XDIST_WORKER"}
for name in default_env_vars:
if os.getenv(name):
return True

# Detect widely-used testing modules
modules = {"pytest", "unittest", "nose"}
if any(name in sys.modules for name in modules):
return True

# Inspect launch args
argv0 = (sys.argv[0] if sys.argv else "").lower()
if "pytest" in argv0 or "py.test" in argv0:
return True

return False

def capture(
self,
event: BaseTelemetryEvent,
Expand Down
104 changes: 104 additions & 0 deletions tests/telemetry/core/test_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from __future__ import annotations

import os
import sys
from unittest.mock import patch

import pytest

from tabpfn_common_utils.telemetry.core.service import ProductTelemetry


class TestRunsInTest:
"""Test the _runs_in_test method for detecting test environments."""

@pytest.fixture(autouse=True)
def setup(self) -> None:
"""Set up test fixtures and clear cache before each test."""
self.original_class = ProductTelemetry() # type: ignore
self.original_class._runs_in_test.cache_clear()

def test_detects_pytest_current_test_env_var(self) -> None:
"""Test detection via PYTEST_CURRENT_TEST environment variable."""
with patch.dict(
os.environ, {"PYTEST_CURRENT_TEST": "tests/test_service.py::test_method"}
):
assert self.original_class._runs_in_test() is True

def test_detects_pytest_xdist_worker_env_var(self) -> None:
"""Test detection via PYTEST_XDIST_WORKER environment variable."""
with patch.dict(os.environ, {"PYTEST_XDIST_WORKER": "gw0"}):
assert self.original_class._runs_in_test() is True

def test_detects_pytest_in_sys_modules(self) -> None:
"""Test detection via pytest in sys.modules."""
# pytest should already be in sys.modules when running tests
assert "pytest" in sys.modules
assert self.original_class._runs_in_test() is True

def test_detects_unittest_in_sys_modules(self) -> None:
"""Test detection via unittest in sys.modules."""
with patch.dict(sys.modules, {"unittest": object()}):
assert self.original_class._runs_in_test() is True

def test_detects_nose_in_sys_modules(self) -> None:
"""Test detection via nose in sys.modules."""
with patch.dict(sys.modules, {"nose": object()}):
assert self.original_class._runs_in_test() is True

def test_detects_pytest_in_argv(self) -> None:
"""Test detection via pytest in sys.argv[0]."""
# Clear env vars and modules to ensure only sys.argv detection is tested
with (
patch.dict(os.environ, {}, clear=True),
patch("tabpfn_common_utils.telemetry.core.service.sys.modules", {}),
patch.object(sys, "argv", ["/usr/bin/pytest", "tests/"]),
):
self.original_class._runs_in_test.cache_clear()
assert self.original_class._runs_in_test() is True

def test_detects_py_test_in_argv(self) -> None:
"""Test detection via py.test in sys.argv[0]."""
# Clear env vars and modules to ensure only sys.argv detection is tested
with (
patch.dict(os.environ, {}, clear=True),
patch("tabpfn_common_utils.telemetry.core.service.sys.modules", {}),
patch.object(sys, "argv", ["/usr/local/bin/py.test", "tests/"]),
):
self.original_class._runs_in_test.cache_clear()
assert self.original_class._runs_in_test() is True

def test_detects_pytest_uppercase_in_argv(self) -> None:
"""Test detection is case-insensitive for sys.argv[0]."""
# Clear env vars and modules to ensure only sys.argv detection is tested
with (
patch.dict(os.environ, {}, clear=True),
patch("tabpfn_common_utils.telemetry.core.service.sys.modules", {}),
patch.object(sys, "argv", ["/path/to/PYTEST", "tests/"]),
):
self.original_class._runs_in_test.cache_clear()
assert self.original_class._runs_in_test() is True

def test_caching_behavior(self) -> None:
"""Test that the function caches its result."""
# First call should compute the result
result1 = self.original_class._runs_in_test()

# Check cache info
cache_info = self.original_class._runs_in_test.cache_info() # type: ignore
assert cache_info.hits == 0
assert cache_info.misses == 1

# Second call should use cache
result2 = self.original_class._runs_in_test()
assert result1 == result2

# Check cache was used
cache_info = self.original_class._runs_in_test.cache_info() # type: ignore
assert cache_info.hits == 1
assert cache_info.misses == 1

def test_returns_bool(self) -> None:
"""Test that the function always returns a boolean."""
result = self.original_class._runs_in_test()
assert isinstance(result, bool)