Skip to content

Commit b182cd3

Browse files
sayakpaul and danieldk
authored
feat: allow get_kernel to log telemetry. (#167)
* feat: allow get_kernel to log telemetry. * Apply suggestions from code review Co-authored-by: Daniël de Kok <[email protected]> * doc --------- Co-authored-by: Daniël de Kok <[email protected]>
1 parent ce77658 commit b182cd3

File tree

2 files changed

+48
-3
lines changed

2 files changed

+48
-3
lines changed

docs/source/faq.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,13 @@ The approach of `forward`-replacement is the least invasive, because
3939
it preserves the original model graph. It is also reversible, since
4040
even though the `forward` of a layer _instance_ might be replaced,
4141
the corresponding class still has the original `forward`.
42+
43+
## Misc
44+
45+
### How can I disable kernel reporting in the user-agent?
46+
47+
By default, we collect telemetry when a call to `get_kernel()` is made.
48+
This only includes the `kernels` version, `torch` version, and the build
49+
information for the kernel being requested.
50+
51+
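You can disable this by setting the `DISABLE_TELEMETRY` environment variable to a truthy value (`1`, `ON`, `YES`, or `TRUE`, case-insensitive), for example `export DISABLE_TELEMETRY=yes`.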

src/kernels/utils.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,16 @@
1111
from importlib.metadata import Distribution
1212
from pathlib import Path
1313
from types import ModuleType
14-
from typing import Dict, List, Optional, Tuple
14+
from typing import Dict, List, Optional, Tuple, Union
1515

1616
from huggingface_hub import file_exists, snapshot_download
1717
from packaging.version import parse
1818

1919
from kernels._versions import select_revision_or_version
2020
from kernels.lockfile import KernelLock, VariantLock
2121

22+
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}
23+
2224

2325
def _get_cache_dir() -> Optional[str]:
2426
"""Returns the kernels cache directory."""
@@ -108,6 +110,7 @@ def install_kernel(
108110
revision: str,
109111
local_files_only: bool = False,
110112
variant_locks: Optional[Dict[str, VariantLock]] = None,
113+
user_agent: Optional[Union[str, dict]] = None,
111114
) -> Tuple[str, Path]:
112115
"""
113116
Download a kernel for the current environment to the cache.
@@ -123,20 +126,24 @@ def install_kernel(
123126
Whether to only use local files and not download from the Hub.
124127
variant_locks (`Dict[str, VariantLock]`, *optional*):
125128
Optional dictionary of variant locks for validation.
129+
user_agent (`Union[str, dict]`, *optional*):
130+
The `user_agent` info to pass to `snapshot_download()` for internal telemetry.
126131
127132
Returns:
128133
`Tuple[str, Path]`: A tuple containing the package name and the path to the variant directory.
129134
"""
130135
package_name = package_name_from_repo_id(repo_id)
131136
variant = build_variant()
132137
universal_variant = universal_build_variant()
138+
user_agent = _get_user_agent(user_agent=user_agent)
133139
repo_path = Path(
134140
snapshot_download(
135141
repo_id,
136142
allow_patterns=[f"build/{variant}/*", f"build/{universal_variant}/*"],
137143
cache_dir=CACHE_DIR,
138144
revision=revision,
139145
local_files_only=local_files_only,
146+
user_agent=user_agent,
140147
)
141148
)
142149

@@ -213,7 +220,10 @@ def install_kernel_all_variants(
213220

214221

215222
def get_kernel(
216-
repo_id: str, revision: Optional[str] = None, version: Optional[str] = None
223+
repo_id: str,
224+
revision: Optional[str] = None,
225+
version: Optional[str] = None,
226+
user_agent: Optional[Union[str, dict]] = None,
217227
) -> ModuleType:
218228
"""
219229
Load a kernel from the kernel hub.
@@ -229,6 +239,8 @@ def get_kernel(
229239
version (`str`, *optional*):
230240
The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
231241
Cannot be used together with `revision`.
242+
user_agent (`Union[str, dict]`, *optional*):
243+
The `user_agent` info to pass to `snapshot_download()` for internal telemetry.
232244
233245
Returns:
234246
`ModuleType`: The imported kernel module.
@@ -245,7 +257,9 @@ def get_kernel(
245257
```
246258
"""
247259
revision = select_revision_or_version(repo_id, revision, version)
248-
package_name, package_path = install_kernel(repo_id, revision=revision)
260+
package_name, package_path = install_kernel(
261+
repo_id, revision=revision, user_agent=user_agent
262+
)
249263
return import_from_path(package_name, package_path / package_name / "__init__.py")
250264

251265

@@ -501,3 +515,24 @@ def git_hash_object(data: bytes, object_type: str = "blob"):
501515

502516
def package_name_from_repo_id(repo_id: str) -> str:
503517
return repo_id.split("/")[-1].replace("-", "_")
518+
519+
520+
def _get_user_agent(
    user_agent: Optional[Union[dict, str]] = None,
) -> Union[None, dict, str]:
    """
    Resolve the `user_agent` value to forward to `snapshot_download()`.

    Args:
        user_agent (`Union[dict, str]`, *optional*):
            Caller-supplied user-agent info. When provided (and telemetry is
            not disabled) it is returned unchanged.

    Returns:
        `Union[None, dict, str]`: `None` when the `DISABLE_TELEMETRY`
        environment variable is set to one of `ENV_VARS_TRUE_VALUES`
        (compared case-insensitively). Otherwise the caller-supplied
        `user_agent`, or, if none was given, a default dict holding the
        `kernels` version, the `torch` version, the build variant, and the
        file type.
    """
    # Honor the opt-out first, so disabling telemetry never depends on the
    # imports below succeeding.
    if os.getenv("DISABLE_TELEMETRY", "false").upper() in ENV_VARS_TRUE_VALUES:
        return None

    # An explicit value from the caller is passed through untouched.
    if user_agent is not None:
        return user_agent

    # Imported lazily: these are only needed to build the default payload.
    import torch

    from . import __version__

    return {
        "kernels": __version__,
        "torch": torch.__version__,
        "build_variant": build_variant(),
        "file_type": "kernel",
    }

0 commit comments

Comments
 (0)