1111from importlib .metadata import Distribution
1212from pathlib import Path
1313from types import ModuleType
14- from typing import Dict , List , Optional , Tuple
14+ from typing import Dict , List , Optional , Tuple , Union
1515
1616from huggingface_hub import file_exists , snapshot_download
1717from packaging .version import parse
1818
1919from kernels ._versions import select_revision_or_version
2020from kernels .lockfile import KernelLock , VariantLock
2121
22+ ENV_VARS_TRUE_VALUES = {"1" , "ON" , "YES" , "TRUE" }
23+
2224
2325def _get_cache_dir () -> Optional [str ]:
2426 """Returns the kernels cache directory."""
@@ -108,6 +110,7 @@ def install_kernel(
108110 revision : str ,
109111 local_files_only : bool = False ,
110112 variant_locks : Optional [Dict [str , VariantLock ]] = None ,
113+ user_agent : Optional [Union [str , dict ]] = None ,
111114) -> Tuple [str , Path ]:
112115 """
113116 Download a kernel for the current environment to the cache.
@@ -123,20 +126,24 @@ def install_kernel(
123126 Whether to only use local files and not download from the Hub.
124127 variant_locks (`Dict[str, VariantLock]`, *optional*):
125128 Optional dictionary of variant locks for validation.
129+ user_agent (`Union[str, dict]`, *optional*):
130+ The `user_agent` info to pass to `snapshot_download()` for internal telemetry.
126131
127132 Returns:
128133 `Tuple[str, Path]`: A tuple containing the package name and the path to the variant directory.
129134 """
130135 package_name = package_name_from_repo_id (repo_id )
131136 variant = build_variant ()
132137 universal_variant = universal_build_variant ()
138+ user_agent = _get_user_agent (user_agent = user_agent )
133139 repo_path = Path (
134140 snapshot_download (
135141 repo_id ,
136142 allow_patterns = [f"build/{ variant } /*" , f"build/{ universal_variant } /*" ],
137143 cache_dir = CACHE_DIR ,
138144 revision = revision ,
139145 local_files_only = local_files_only ,
146+ user_agent = user_agent ,
140147 )
141148 )
142149
@@ -213,7 +220,10 @@ def install_kernel_all_variants(
213220
214221
215222def get_kernel (
216- repo_id : str , revision : Optional [str ] = None , version : Optional [str ] = None
223+ repo_id : str ,
224+ revision : Optional [str ] = None ,
225+ version : Optional [str ] = None ,
226+ user_agent : Optional [Union [str , dict ]] = None ,
217227) -> ModuleType :
218228 """
219229 Load a kernel from the kernel hub.
@@ -229,6 +239,8 @@ def get_kernel(
229239 version (`str`, *optional*):
230240 The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`.
231241 Cannot be used together with `revision`.
242+ user_agent (`Union[str, dict]`, *optional*):
243+ The `user_agent` info to pass to `snapshot_download()` for internal telemetry.
232244
233245 Returns:
234246 `ModuleType`: The imported kernel module.
@@ -245,7 +257,9 @@ def get_kernel(
245257 ```
246258 """
247259 revision = select_revision_or_version (repo_id , revision , version )
248- package_name , package_path = install_kernel (repo_id , revision = revision )
260+ package_name , package_path = install_kernel (
261+ repo_id , revision = revision , user_agent = user_agent
262+ )
249263 return import_from_path (package_name , package_path / package_name / "__init__.py" )
250264
251265
@@ -501,3 +515,24 @@ def git_hash_object(data: bytes, object_type: str = "blob"):
501515
502516def package_name_from_repo_id (repo_id : str ) -> str :
503517 return repo_id .split ("/" )[- 1 ].replace ("-" , "_" )
518+
519+
520+ def _get_user_agent (
521+ user_agent : Optional [Union [dict , str ]] = None ,
522+ ) -> Union [None , dict , str ]:
523+ import torch
524+
525+ from . import __version__
526+
527+ if os .getenv ("DISABLE_TELEMETRY" , "false" ).upper () in ENV_VARS_TRUE_VALUES :
528+ return None
529+
530+ if user_agent is None :
531+ user_agent = {
532+ "kernels" : __version__ ,
533+ "torch" : torch .__version__ ,
534+ "build_variant" : build_variant (),
535+ "file_type" : "kernel" ,
536+ }
537+
538+ return user_agent
0 commit comments