File tree Expand file tree Collapse file tree 2 files changed +11
-1
lines changed Expand file tree Collapse file tree 2 files changed +11
-1
lines changed Original file line number Diff line number Diff line change @@ -574,7 +574,8 @@ def update_for_spec_dec(self):
574574
575575
576576@maybe_compile (dynamic = True )
577- def _scale (weights , q_scale , s ):
577+ def _scale (weights : torch .Tensor , q_scale : torch .Tensor ,
578+ s : float ) -> torch .Tensor :
578579 return weights * q_scale .squeeze (- 1 ) * s
579580
580581
Original file line number Diff line number Diff line change @@ -328,6 +328,15 @@ def get_device_uuid(device_idx: int) -> str:
328328
329329
330330def maybe_compile (func = None , ** compile_kwargs ):
331+ """
332+ Conditionally compile a function with torch.compile.
333+ If is_piecewise_running() is True, the function will be compiled with torch.compile.
334+ Args:
335+ func: The function to decorate (optional, for direct decoration).
336+ **compile_kwargs: Keyword arguments for torch.compile.
337+ Returns:
338+ The conditionally compiled function..
339+ """
331340
332341 def decorator (f ):
333342
You can’t perform that action at this time.
0 commit comments