Commit fd95242

Fixing docstring linter (#2163)
1 parent 97e857f commit fd95242

17 files changed, +81 -59 lines

.pre-commit-config.yaml (+5 -5)

@@ -5,7 +5,7 @@ default_language_version:
 
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: 6306a48f7dae5861702d573c9c247e4e9498e867
+    rev: v5.0.0
     hooks:
       - id: trailing-whitespace
       - id: check-ast
@@ -18,7 +18,7 @@ repos:
         exclude: '^(.*\.svg)$'
 
   - repo: https://github.com/Lucas-C/pre-commit-hooks
-    rev: v1.5.4
+    rev: v1.5.5
     hooks:
       - id: insert-license
         files: \.py$|\.sh$
@@ -27,7 +27,7 @@ repos:
           - docs/license_header.txt
 
   - repo: https://github.com/pycqa/flake8
-    rev: 34cbf8ef3950f43d09b85e2e45c15ae5717dc37b
+    rev: 7.1.1
     hooks:
       - id: flake8
         additional_dependencies:
@@ -37,15 +37,15 @@ repos:
         args: ['--config=.flake8']
 
   - repo: https://github.com/omnilib/ufmt
-    rev: v2.3.0
+    rev: v2.8.0
     hooks:
       - id: ufmt
         additional_dependencies:
          - black == 22.12.0
          - usort == 1.0.5
 
   - repo: https://github.com/jsh9/pydoclint
-    rev: 94efc5f989adbea30f3534b476b2931a02c1af90
+    rev: 0.5.12
     hooks:
       - id: pydoclint
         args: [--config=pyproject.toml]

pyproject.toml (+1 -1)

@@ -87,7 +87,7 @@ target-version = ["py38"]
 [tool.pydoclint]
 style = 'google'
 check-return-types = 'False'
-exclude = 'tests/torchtune/models/(\w+)/scripts/'
+exclude = 'tests/torchtune/models/(\w+)/scripts/|recipes/|torchtune/modules/_export'
 
 [tool.pytest.ini_options]
 addopts = ["--showlocals", "--import-mode=prepend", "--without-integration", "--without-slow-integration"]
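
Every docstring change in the files below follows the same pattern: instead of repeating the same exception class under ``Raises:``, the conditions are merged under a single entry and chained with "**or**", apparently so that the pydoclint hook configured above is satisfied. A minimal sketch of the resulting Google-style shape — the function and its arguments are invented purely for illustration:

def clamp(value: float, low: float, high: float) -> float:
    """Clamp ``value`` to the closed interval ``[low, high]``.

    Args:
        value (float): Value to clamp.
        low (float): Lower bound.
        high (float): Upper bound.

    Returns:
        float: The clamped value.

    Raises:
        ValueError:
            If ``low`` is greater than ``high``, **or**
            if any argument is NaN.
    """
    if low > high:
        raise ValueError("low must not exceed high")
    if any(x != x for x in (value, low, high)):  # NaN never equals itself
        raise ValueError("arguments must not be NaN")
    return max(low, min(value, high))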

torchtune/data/_collate.py (+12 -10)

@@ -81,13 +81,14 @@ def padded_collate(
             padding values.
 
     Returns:
-        torch.Tensor: The padded tensor of input ids with shape [batch_size, max_seq_len].
+        torch.Tensor: The padded tensor of input ids with shape ``[batch_size, max_seq_len]``.
 
     Raises:
-        ValueError: if ``pad_direction`` is not one of "left" or "right".
-        ValueError: if ``keys_to_pad`` is empty, or is not a list, or is not a subset of keys in the batch.
-        ValueError: if ``padding_idx`` is provided as a dictionary, but the keys are not identical to
-            ``keys_to_pad``.
+        ValueError:
+            If ``pad_direction`` is not one of "left" or "right", **or**
+            if ``keys_to_pad`` is empty, or is not a list, **or**
+            if ``keys_to_pad`` is not a subset of keys in the batch, **or**
+            if ``padding_idx`` is provided as a dictionary, but the keys are not identical to ``keys_to_pad``
 
     Example:
         >>> a = [1, 2, 3]
@@ -149,9 +150,9 @@ def padded_collate(
         output_dict[k] = pad_fn(
             [torch.tensor(x[k]) for x in batch],
             batch_first=True,
-            padding_value=padding_idx[k]
-            if isinstance(padding_idx, dict)
-            else padding_idx,
+            padding_value=(
+                padding_idx[k] if isinstance(padding_idx, dict) else padding_idx
+            ),
         )
     return output_dict
 
@@ -274,8 +275,9 @@ def padded_collate_tiled_images_and_mask(
             - aspect_ratio: Tensor of shape (bsz, max_num_images, 2)
 
     Raises:
-        ValueError: if ``pad_direction`` is not one of "left" or "right".
-        ValueError: if pad_max_tiles is set to a value less than the largest number of tiles in an image.
+        ValueError:
+            If ``pad_direction`` is not one of "left" or "right", **or**
+            if pad_max_tiles is set to a value less than the largest number of tiles in an image.
 
     Example:
         >>> image_id = 1
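
The second hunk only re-parenthesizes the conditional, but it shows how a single ``padding_idx`` (an int or a per-key dict) is routed into the padding call. A stand-alone sketch of that pattern using ``torch.nn.utils.rnn.pad_sequence`` — ``_pad_batch_keys`` and its arguments are hypothetical, not torchtune's API:

import torch
from torch.nn.utils.rnn import pad_sequence

def _pad_batch_keys(batch, keys_to_pad, padding_idx):
    # padding_idx may be a single int or a per-key dict; the value is chosen inline,
    # mirroring the parenthesized conditional in the hunk above.
    output = {}
    for k in keys_to_pad:
        output[k] = pad_sequence(
            [torch.tensor(x[k]) for x in batch],
            batch_first=True,
            padding_value=(
                padding_idx[k] if isinstance(padding_idx, dict) else padding_idx
            ),
        )
    return output

batch = [{"tokens": [1, 2, 3], "labels": [4, 5]}, {"tokens": [6], "labels": [7, 8, 9]}]
out = _pad_batch_keys(batch, ["tokens", "labels"], {"tokens": 0, "labels": -100})
print(out["tokens"].shape)  # torch.Size([2, 3]), padded on the right with 0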

torchtune/data/_messages.py (+4 -3)

@@ -168,9 +168,10 @@ class InputOutputToMessages(Transform):
             on a remote url. For text-only, leave as None. Default is None.
 
     Raises:
-        ValueError: If ``column_map`` is provided and ``input`` not in ``column_map``, or
-            ``output`` not in ``column_map``.
-        ValueError: If ``image_dir`` is provided but ``image`` not in ``column_map``.
+        ValueError:
+            If ``column_map`` is provided and ``input`` not in ``column_map``, or
+            ``output`` not in ``column_map``, **or**
+            if ``image_dir`` is provided but ``image`` not in ``column_map``.
     """
 
     def __init__(

torchtune/data/_utils.py (+3 -2)

@@ -57,8 +57,9 @@ def load_image(image_loc: Union[Path, str]) -> "PIL.Image.Image":
             to start with "http" or "https" e.g. "https://www.wikipedia.org/en/bird.jpg".
 
     Raises:
-        ValueError: If the image cannot be loaded from remote source.
-        ValueError: If the image cannot be opened as a :class:`~PIL.Image.Image`.
+        ValueError:
+            If the image cannot be loaded from remote source, **or**
+            if the image cannot be opened as a :class:`~PIL.Image.Image`.
 
     Examples:
         >>> # Load from remote source

torchtune/models/clip/_position_embeddings.py (+11 -9)

@@ -126,12 +126,13 @@ def _load_state_dict_hook(
             **kwargs (Dict[str, Any]): Additional keyword arguments.
 
         Raises:
-            ValueError: if loaded local or global embedding n_tokens_per_tile is not derived
-                from a squared grid.
-            ValueError: if after interpolation, the shape of the loaded local embedding
-                is not compatible with the current embedding.
-            ValueError: if after interpolation, the shape of the loaded global embedding
-                is not compatible with the current embedding.
+            ValueError:
+                If loaded local or global embedding n_tokens_per_tile is not derived
+                from a squared grid, **or**
+                if after interpolation, the shape of the loaded local embedding
+                is not compatible with the current embedding, **or**
+                if after interpolation, the shape of the loaded global embedding
+                is not compatible with the current embedding.
         """
 
         # process local_token_positional_embedding
@@ -530,9 +531,10 @@ def _load_state_dict_hook(
             **kwargs (Dict[str, Any]): Additional keyword arguments.
 
         Raises:
-            ValueError: if the shape of the loaded embedding is not compatible with the current embedding.
-            ValueError: if max_num_tiles_x, max_num_tiles_y are not equal.
-            ValueError: if after interpolation, the shape of the loaded embedding is not compatible with the current embedding.
+            ValueError:
+                If the shape of the loaded embedding is not compatible with the current embedding, **or**
+                if ``max_num_tiles_x``, ``max_num_tiles_y`` are not equal, **or**
+                if after interpolation, the shape of the loaded embedding is not compatible with the current embedding.
         """
 
         embedding = state_dict.get(prefix + "embedding")
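
The merged ``Raises:`` entries above describe shape checks around 2D interpolation of positional embeddings. A rough, self-contained sketch of what such checks can look like — the helper name, the assumption of one leading CLS token, and the bilinear mode are guesses for illustration, not torchtune's implementation:

import math
import torch
import torch.nn.functional as F

def _resize_pos_embed(loaded: torch.Tensor, target_n_tokens: int) -> torch.Tensor:
    n_tokens, dim = loaded.shape
    grid = math.isqrt(n_tokens - 1)  # assumes one leading CLS token
    if grid * grid != n_tokens - 1:
        raise ValueError("n_tokens_per_tile is not derived from a squared grid")
    tgt_grid = math.isqrt(target_n_tokens - 1)
    cls_tok, patch_tok = loaded[:1], loaded[1:]
    patch_tok = patch_tok.reshape(1, grid, grid, dim).permute(0, 3, 1, 2)
    patch_tok = F.interpolate(
        patch_tok, size=(tgt_grid, tgt_grid), mode="bilinear", align_corners=False
    )
    patch_tok = patch_tok.permute(0, 2, 3, 1).reshape(tgt_grid * tgt_grid, dim)
    out = torch.cat([cls_tok, patch_tok], dim=0)
    if out.shape[0] != target_n_tokens:
        raise ValueError("interpolated embedding is not compatible with the current embedding")
    return out

print(_resize_pos_embed(torch.randn(17, 32), target_n_tokens=26).shape)  # torch.Size([26, 32])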

torchtune/models/gemma2/_attention.py (+11 -5)

@@ -50,10 +50,11 @@ class Gemma2Attention(nn.Module):
         softcapping (Optional[float]): capping value used for soft caping, if None no capping is performed
         query_pre_attn_scalar (Optional[int]): value used for pre attention normalisation, if None head_dim is used instead
     Raises:
-        ValueError: If ``num_heads % num_kv_heads != 0``
-        ValueError: If ``embed_dim % num_heads != 0``
-        ValueError: If ``attn_dropout < 0`` or ``attn_dropout > 1``
-        ValueError: if q_norm is defined without k_norm or vice versa
+        ValueError:
+            If ``num_heads % num_kv_heads != 0``, **or**
+            if ``embed_dim % num_heads != 0``, **or**
+            if ``attn_dropout < 0`` or ``attn_dropout > 1``, **or**
+            if ``q_norm`` is defined without k_norm or vice versa
     """
 
     def __init__(
@@ -156,7 +157,11 @@ def setup_cache(
         self.cache_enabled = True
 
     def reset_cache(self):
-        """Reset the key value caches."""
+        """Reset the key value caches.
+
+        Raises:
+            RuntimeError: if key value caches are not already setup.
+        """
         if self.kv_cache is None:
             raise RuntimeError(
                 "Key value caches are not setup. Call ``setup_caches()`` first."
@@ -196,6 +201,7 @@ def forward(
             If none, assume the index of the token is its position id. Default is None.
 
         Raises:
+            NotImplementedError: If ``mask`` is provided, but mask is not an instance of ``torch.Tensor``.
             ValueError: If no ``y`` input and ``kv_cache`` is not enabled.
 
         Returns:
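
Among the documented arguments above, ``softcapping`` refers to Gemma 2's attention-logit soft-capping. As an aside, the commonly described formulation looks like the snippet below — an illustration of the general technique, not a copy of torchtune's module:

from typing import Optional

import torch

def soft_cap(attn_logits: torch.Tensor, softcapping: Optional[float]) -> torch.Tensor:
    # Squash logits smoothly into (-softcapping, +softcapping); None disables capping,
    # matching the "if None no capping is performed" wording above.
    if softcapping is None:
        return attn_logits
    return softcapping * torch.tanh(attn_logits / softcapping)

scores = torch.tensor([[-80.0, 3.0, 120.0]])
print(soft_cap(scores, 50.0))  # approximately tensor([[-46.1, 3.0, 49.2]])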

torchtune/models/phi3/_tokenizer.py (+1)

@@ -157,6 +157,7 @@ def tokenize_messages(
 
         Raises:
             ValueError: If the role is not "user", "assistant", or "system".
+            RuntimeError: If ``message["type"] != "text``.
 
         Returns:
             Tuple[List[int], List[bool]]: The tokenized messages

torchtune/modules/attention.py (+5 -4)

@@ -73,10 +73,11 @@ class MultiHeadAttention(nn.Module):
             Default value is 0.0.
 
     Raises:
-        ValueError: If ``num_heads % num_kv_heads != 0``
-        ValueError: If ``embed_dim % num_heads != 0``
-        ValueError: If ``attn_dropout < 0`` or ``attn_dropout > 1``
-        ValueError: if q_norm is defined without k_norm or vice versa
+        ValueError:
+            If ``num_heads % num_kv_heads != 0``, **or**
+            if ``embed_dim % num_heads != 0``, **or**
+            if ``attn_dropout < 0`` or ``attn_dropout > 1``, **or**
+            if q_norm is defined without k_norm or vice versa
     """
 
     def __init__(

torchtune/modules/kv_cache.py (+5 -1)

@@ -84,9 +84,13 @@ def update(
             Tuple[torch.Tensor, torch.Tensor]: Updated key and value cache tensors, respectively.
 
         Raises:
-            AssertionError: if the sequence length of ``k_val`` is longer than the maximum cache sequence length.
             ValueError: if the batch size of the new key (or value) tensor is greater than the batch size
                 used during cache setup.
+
+        Note:
+            This function will raise an ``AssertionError`` if the sequence length of ``k_val``
+            is longer than the maximum cache sequence length.
+
         """
         bsz, _, seq_len, _ = k_val.shape
         if bsz > self.k_cache.shape[0]:
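
The reshuffled docstring above keeps the explicit ``ValueError`` for a batch-size mismatch and moves the sequence-length ``AssertionError`` into a note. A stand-alone sketch of that checking order, with an assumed ``[batch, num_heads, seq_len, head_dim]`` cache layout (the real ``KVCache`` may differ):

import torch

def _check_cache_update(k_cache: torch.Tensor, k_val: torch.Tensor) -> None:
    bsz, _, seq_len, _ = k_val.shape
    if bsz > k_cache.shape[0]:
        raise ValueError(
            f"cache was setup with batch size {k_cache.shape[0]}, but got {bsz}"
        )
    # overflow of the cache length is only asserted, hence the Note: above
    assert seq_len <= k_cache.shape[2], "k_val is longer than the maximum cache sequence length"

cache = torch.zeros(2, 8, 16, 64)
_check_cache_update(cache, torch.randn(2, 8, 4, 64))    # passes
# _check_cache_update(cache, torch.randn(4, 8, 4, 64))  # would raise ValueError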

torchtune/modules/peft/_utils.py (+5 -4)

@@ -285,10 +285,11 @@ def validate_missing_and_unexpected_for_lora(
         None
 
     Raises:
-        AssertionError: if base_missing contains any base model keys.
-        AssertionError: if base_unexpected is nonempty.
-        AssertionError: if lora_missing contains any LoRA keys.
-        AssertionError: if lora_unexpected is nonempty.
+        AssertionError:
+            If base_missing contains any base model keys, **or**
+            if base_unexpected is nonempty, **or**
+            if lora_missing contains any LoRA keys, **or**
+            if lora_unexpected is nonempty.
     """
     lora_modules = get_lora_module_names(
         lora_attn_modules, apply_lora_to_mlp, apply_lora_to_output

torchtune/modules/transformer.py (+8 -6)

@@ -344,8 +344,9 @@ class TransformerDecoder(nn.Module):
         output_hidden_states (Optional[List[int]]): List of layers (indices) to include in the output
 
     Raises:
-        AssertionError: num_layers is set and layer is a list
-        AssertionError: num_layers is not set and layer is an nn.Module
+        AssertionError:
+            If ``num_layers`` is set and layer is a list, **or**
+            ``num_layers`` is not set and layer is an ``nn.Module``.
 
     Note:
         Arg values are checked for correctness (eg: ``attn_dropout`` belongs to [0,1])
@@ -519,10 +520,11 @@ def _validate_inputs(
             input_pos (Optional[torch.Tensor]): Input tensor position IDs.
 
         Raises:
-            ValueError: if seq_len of x is bigger than max_seq_len
-            ValueError: if the model has caches which have been setup with self-attention layers and ``mask`` is not provided.
-            ValueError: if the model has caches which have been setup with encoder layers and ``encoder_mask`` is not provided.
-            ValueError: if the model has caches which have been setup ``input_pos`` is not provided.
+            ValueError:
+                If seq_len of x is bigger than max_seq_len, **or**
+                if the model has caches which have been setup with self-attention layers and ``mask`` is not provided, **or**
+                if the model has caches which have been setup with encoder layers and ``encoder_mask`` is not provided, **or**
+                if the model has caches which have been setup ``input_pos`` is not provided.
         """
 
         if seq_len > self.max_seq_len:
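
Only the first of the four conditions is visible in this context window (``if seq_len > self.max_seq_len:``). A minimal stand-alone version of that kind of input validation — the function name, message text, and the ``[batch_size, seq_len]`` layout are assumptions for the example:

import torch

def validate_seq_len(tokens: torch.Tensor, max_seq_len: int) -> None:
    seq_len = tokens.shape[1]
    if seq_len > max_seq_len:
        raise ValueError(
            f"seq_len ({seq_len}) of input tensor should be smaller than max_seq_len ({max_seq_len})"
        )

validate_seq_len(torch.zeros(2, 512, dtype=torch.long), max_seq_len=2048)     # passes
# validate_seq_len(torch.zeros(2, 4096, dtype=torch.long), max_seq_len=2048)  # raises ValueError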

torchtune/modules/vision_transformer.py (+4 -3)

@@ -190,9 +190,10 @@ class VisionTransformer(nn.Module):
             Default is False, which adds CLS token to the beginning of the sequence.
 
     Raises:
-        ValueError: If `tile_size` is not greater than 0.
-        ValueError: If `patch_size` is not greater than 0.
-        ValueError: If `len(out_indices)` is greater than `num_layers`.
+        ValueError:
+            If `tile_size` is not greater than 0, **or**
+            if `patch_size` is not greater than 0, **or**
+            if `len(out_indices)` is greater than `num_layers`.
     """
 
     def __init__(

torchtune/training/_activation_offloading.py (-1)

@@ -55,7 +55,6 @@ class OffloadActivations(saved_tensors_hooks):
 
     Raises:
         ValueError: if max_fwd_stash_size is not at least 1.
-        RuntimeError: if use_streams but torch installation is earlier than torch-2.5.0.dev20240907
 
     Example:
         >>> with OffloadActivations():

torchtune/training/checkpointing/_checkpointer.py (-1)

@@ -943,7 +943,6 @@ class FullModelMetaCheckpointer(_CheckpointerInterface):
 
     Raises:
         ValueError: If ``checkpoint_files`` is not a list of length 1
-        ValueError: If ``should_load_recipe_state`` is True but ``recipe_checkpoint`` is None
     """
 
     def __init__(

torchtune/training/checkpointing/_utils.py (+3 -2)

@@ -284,8 +284,9 @@ def update_state_dict_for_classifier(
         if ``output.weight != model.output.weight``.
 
     Raises:
-        AssertionError: if ``state_dict`` does not contain ``output.weight``.
-        AssertionError: if ``model_named_parameters`` does not contain ``output.weight``.
+        AssertionError:
+            If ``state_dict`` does not contain ``output.weight``, **or**
+            if ``model_named_parameters`` does not contain ``output.weight``.
 
     """
     output_weight = dict(model_named_parameters).get("output.weight", None)

torchtune/utils/_device.py (+3 -2)

@@ -177,10 +177,11 @@ def batch_to_device(batch: dict, device: torch.device) -> None:
 
     Args:
         batch (dict): dict of Tensors or more nested dicts of tensors.
-        device (torch.device): torch device to move the tensor's too
+        device (torch.device): torch device to move the tensors to.
 
     Raises:
-        AttributeError: if batch dict contains anything other than tensors
+        ValueError: if batch dict contains anything other than ``torch.Tensor``.
+
     """
     for k, v in batch.items():
         if isinstance(v, dict):