
Commit 5ee7dad

Modernize transformers module with type hints and generic types
1 parent 6fea888 commit 5ee7dad

10 files changed: +67 −71 lines
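The change is one mechanical pattern applied across all ten files: `typing.Optional`/`typing.Union` wrappers become PEP 604 unions (`X | None`), and `typing.Dict`/`typing.List` become the PEP 585 builtin generics `dict`/`list`. A minimal before/after sketch of the pattern (the `load_items` functions here are hypothetical, not from the diff):

```python
# Before: typing-module generics and wrappers
from typing import Dict, List, Optional


def load_items_old(path: str, cache_dir: Optional[str] = None) -> Optional[Dict[str, List[int]]]:
    return None


# After: builtin generics (PEP 585, Python 3.9+) and union syntax (PEP 604, Python 3.10+)
def load_items_new(path: str, cache_dir: str | None = None) -> dict[str, list[int]] | None:
    return None
```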

src/llmcompressor/transformers/compression/compressed_tensors_utils.py

Lines changed: 4 additions & 5 deletions

```diff
@@ -2,7 +2,6 @@
 import weakref
 from collections.abc import Generator
 from functools import wraps
-from typing import Optional
 
 import torch
 from accelerate.accelerator import get_state_dict_offloaded_model
@@ -54,8 +53,8 @@ def save_pretrained_compressed(save_pretrained_method):
     @wraps(original_save_pretrained)
     def save_pretrained_wrapper(
         save_directory: str,
-        sparsity_config: Optional[SparsityCompressionConfig] = None,
-        quantization_format: Optional[str] = None,
+        sparsity_config: SparsityCompressionConfig | None = None,
+        quantization_format: str | None = None,
         save_compressed: bool = True,
         safe_serialization: bool = True,
         skip_sparsity_compression_stats: bool = True,
@@ -233,8 +232,8 @@ def untie_if_target_shared_embedding(
 
 def get_model_compressor(
     model: torch.nn.Module,
-    sparsity_config: Optional[SparsityCompressionConfig] = None,
-    quantization_format: Optional[str] = None,
+    sparsity_config: SparsityCompressionConfig | None = None,
+    quantization_format: str | None = None,
     save_compressed: bool = True,
     skip_sparsity_compression_stats: bool = True,
     disable_sparse_compression: bool = False,
```
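For context on the first hunk: `save_pretrained_wrapper` wraps an existing `save_pretrained` method with `functools.wraps` so the wrapper keeps the original's metadata. A rough, self-contained sketch of that shape (how the real module obtains `original_save_pretrained` and what compression work it performs are not shown in this diff):

```python
from functools import wraps


def save_pretrained_compressed(save_pretrained_method):
    # Sketch only: hold on to the original callable, then return a wrapper
    # that functools.wraps makes look like the original (name, docstring).
    original_save_pretrained = save_pretrained_method

    @wraps(original_save_pretrained)
    def save_pretrained_wrapper(save_directory: str, **kwargs):
        # The real module would configure a compressor here before delegating.
        return original_save_pretrained(save_directory, **kwargs)

    return save_pretrained_wrapper
```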

src/llmcompressor/transformers/compression/helpers.py

Lines changed: 8 additions & 8 deletions

```diff
@@ -1,5 +1,5 @@
 from collections import defaultdict
-from typing import Dict, List, Optional, Tuple
+from typing import Tuple
 
 import torch
 from accelerate.accelerator import get_state_dict_offloaded_model
@@ -51,8 +51,8 @@ def tensor_follows_mask_structure(tensor: torch.Tensor, mask: str = "2:4") -> bool:
 
 
 def infer_sparsity_structure_from_modifiers(
-    modifiers: List[Modifier],  # noqa E501
-) -> Optional[str]:
+    modifiers: list[Modifier],  # noqa E501
+) -> str | None:
     """
     Determines the sparsity structure, if any exists, given the list of modifiers.
@@ -65,7 +65,7 @@ def infer_sparsity_structure_from_modifiers(
     return None
 
 
-def infer_sparsity_structure_from_model(model: torch.nn.Module) -> Optional[str]:
+def infer_sparsity_structure_from_model(model: torch.nn.Module) -> str | None:
     """
     Determines the sparsity structure, if any exists, given the model
@@ -104,7 +104,7 @@ def infer_sparse_targets_and_ignores(
     model: torch.nn.Module,
     sparsity_structure: str,
     sparsity_threshold: float,
-) -> Tuple[List[str], List[str]]:
+) -> Tuple[list[str], list[str]]:
     """
     Infers the target and ignore layers in the given model
     to be used for sparsity compression
@@ -151,7 +151,7 @@ def is_sparse_compression_target(
 
 def _get_sparse_targets_ignore_dicts(
     module: torch.nn.Module, sparsity_structure: str, sparsity_threshold: float
-) -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]:
+) -> Tuple[dict[str, list[str]], dict[str, list[str]]]:
     """
     Get sparse targets and ignore dictionaries
@@ -176,8 +176,8 @@ def _get_sparse_targets_ignore_dicts(
 
 
 def _reduce_targets_and_ignores_into_lists(
-    exhaustive_targets: Dict[str, List[str]], exhaustive_ignore: Dict[str, List[str]]
-) -> Tuple[List[str], List[str]]:
+    exhaustive_targets: dict[str, list[str]], exhaustive_ignore: dict[str, list[str]]
+) -> Tuple[list[str], list[str]]:
     """
     Reduces the targets and ignores dictionaries into lists
```
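Note that `from typing import Tuple` survives even though `Dict`, `List`, and `Optional` were dropped; the builtin `tuple` has been subscriptable since Python 3.9, so a follow-up pass could migrate those too. A sketch with a hypothetical `split_targets` helper:

```python
# tuple[...] is subscriptable on Python 3.9+, so the remaining typing.Tuple
# annotations could be migrated the same way as Dict and List:
def split_targets(names: list[str]) -> tuple[list[str], list[str]]:
    targets = [n for n in names if n.endswith(".weight")]
    ignores = [n for n in names if n not in targets]
    return targets, ignores
```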

src/llmcompressor/transformers/compression/sparsity_metadata_config.py

Lines changed: 10 additions & 12 deletions

```diff
@@ -1,5 +1,3 @@
-from typing import Dict, List, Optional
-
 from compressed_tensors import CompressionFormat, SparsityCompressionConfig
 from compressed_tensors.config import SparsityStructure
 from compressed_tensors.quantization import QuantizationType
@@ -30,7 +28,7 @@ class SparsityConfigMetadata:
 
     @staticmethod
     def infer_global_sparsity(
-        model: Module, state_dict: Optional[Dict[str, Tensor]] = None
+        model: Module, state_dict: dict[str, Tensor] | None = None
     ) -> float:
         """
         Calculates the global percentage of sparse zero weights in the model
@@ -47,12 +45,12 @@ def infer_global_sparsity(
 
     @staticmethod
     def infer_sparsity_structure(
-        model: Optional[Module] = None, check_only_modifiers: Optional[bool] = False
+        model: Module | None = None, check_only_modifiers: bool | None = False
     ) -> str:
         """
         Determines what sparsity structure, if any, was applied.
 
-        First, there is an attempt to dedue the sparsity structure
+        First, there is an attempt to deduce the sparsity structure
         from the currently active sparse session.
 
         If that fails, the sparsity structure is inferred from the
@@ -83,12 +81,12 @@ def infer_sparsity_structure(
     @staticmethod
     def from_pretrained(
         model: Module,
-        state_dict: Optional[Dict[str, Tensor]] = None,
+        state_dict: dict[str, Tensor] | None = None,
         compress: bool = False,
-        quantization_format: Optional[CompressionFormat] = None,
+        quantization_format: CompressionFormat | None = None,
         disable_sparse_compression: bool = False,
-        sparsity_structure: Optional[str] = None,
-    ) -> Optional["SparsityCompressionConfig"]:
+        sparsity_structure: str | None = None,
+    ) -> "SparsityCompressionConfig" | None:
         """
         Determines compression type and informational parameters for a given model
@@ -155,7 +153,7 @@ def from_pretrained(
     def fill_config_details(
         config: SparsityCompressionConfig,
         model: Module,
-        state_dict: Optional[Dict[str, Tensor]] = None,
+        state_dict: dict[str, Tensor] | None = None,
     ):
         """
         Fills in informational sparsity parameters from a given model
@@ -173,7 +171,7 @@ def fill_config_details(
     @staticmethod
     def is_sparse24_bitmask_supported(
         model: Module,
-        sparsity_structure: Optional[str] = None,
+        sparsity_structure: str | None = None,
     ) -> bool:
         """
         Determines if sparse 24 bitmask sparse compressor is supported for a given model
@@ -202,7 +200,7 @@ def is_sparse24_bitmask_supported(
 
         # when model is quantized, and has 2:4 sparsity
 
-        supported_scheme_types: List[str] = [
+        supported_scheme_types: list[str] = [
             QuantizationType.INT.value,
             QuantizationType.FLOAT.value,
         ]
```
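One hunk deserves a caveat: `Optional["SparsityCompressionConfig"]` becomes `"SparsityCompressionConfig" | None`, which applies `|` to a string literal. That raises `TypeError` when annotations are evaluated eagerly, so the new spelling is only safe under deferred annotation evaluation (PEP 563) or with the whole annotation quoted; whether this module uses `from __future__ import annotations` is not visible in the diff. A self-contained illustration with a hypothetical `Config` class:

```python
from __future__ import annotations  # PEP 563: defer evaluation of all annotations


class Config:  # hypothetical stand-in for SparsityCompressionConfig
    pass


# With deferred evaluation the union below is never executed at runtime,
# so no quoting is needed even for names defined later in the module:
def from_config() -> Config | None:
    return None


# Without the __future__ import, a partially quoted union such as
#     def f() -> "Config" | None: ...
# raises TypeError at definition time: str does not support the | operator.
```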

src/llmcompressor/transformers/finetune/data/base.py

Lines changed: 8 additions & 8 deletions

```diff
@@ -10,7 +10,7 @@
 import inspect
 from functools import cached_property
 from inspect import _ParameterKind as Kind
-from typing import Any, Callable, Dict, List, Union
+from typing import Any, Callable
 
 from compressed_tensors.registry import RegistryMixin
 from datasets import Dataset, IterableDataset
@@ -202,7 +202,7 @@ def load_dataset(self):
         )
 
     @cached_property
-    def preprocess(self) -> Union[Callable[[LazyRow], Any], None]:
+    def preprocess(self) -> Callable[[LazyRow], Any] | None:
         """
         The function must return keys which correspond to processor/tokenizer kwargs,
         optionally including PROMPT_KEY
@@ -225,7 +225,7 @@ def preprocess(self) -> Union[Callable[[LazyRow], Any], None]:
         return self.dataset_template
 
     @property
-    def dataset_template(self) -> Union[Callable[[Any], Any], None]:
+    def dataset_template(self) -> Callable[[Any], Any] | None:
         return None
 
     def rename_columns(self, dataset: DatasetType) -> DatasetType:
@@ -254,7 +254,7 @@ def filter_tokenizer_args(self, dataset: DatasetType) -> DatasetType:
             list(set(column_names) - set(tokenizer_args) - set([self.PROMPT_KEY]))
         )
 
-    def tokenize(self, data: LazyRow) -> Dict[str, Any]:
+    def tokenize(self, data: LazyRow) -> dict[str, Any]:
         # separate prompt
         prompt = data.pop(self.PROMPT_KEY, None)
 
@@ -276,7 +276,7 @@ def tokenize(self, data: LazyRow) -> Dict[str, Any]:
 
         return data
 
-    def group_text(self, data: LazyRow) -> Dict[str, Any]:
+    def group_text(self, data: LazyRow) -> dict[str, Any]:
         concatenated_data = {k: sum(data[k], []) for k in data.keys()}
         total_length = len(concatenated_data[list(data.keys())[0]])
         total_length = (total_length // self.max_seq_length) * self.max_seq_length
@@ -311,10 +311,10 @@ def add_labels(self, data: LazyRow) -> LazyRow:
 
     def map(
         self,
-        dataset: Union[Dataset, IterableDataset],
+        dataset: Dataset | IterableDataset,
         function: Callable[[Any], Any],
         **kwargs,
-    ) -> Union[Dataset, IterableDataset]:
+    ) -> Dataset | IterableDataset:
         """
         Wrapper function around Dataset.map and IterableDataset.map.
 
@@ -336,7 +336,7 @@ def map(
         return dataset
 
 
-def get_columns(dataset: DatasetType) -> List[str]:
+def get_columns(dataset: DatasetType) -> list[str]:
     column_names = dataset.column_names
     if isinstance(column_names, dict):
         column_names = sum(column_names.values(), [])
```
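For orientation, `preprocess` is a `functools.cached_property` whose value may be a callable or `None`; the flat union reads naturally on both the cached and the plain property. A reduced sketch (the class name and bodies here are hypothetical):

```python
from functools import cached_property
from typing import Any, Callable


class DatasetSketch:
    @cached_property
    def preprocess(self) -> Callable[[dict], Any] | None:
        # Computed once on first access, then cached on the instance.
        return self.dataset_template

    @property
    def dataset_template(self) -> Callable[[Any], Any] | None:
        return None
```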

src/llmcompressor/transformers/finetune/data/data_helpers.py

Lines changed: 5 additions & 5 deletions

```diff
@@ -1,6 +1,6 @@
 import logging
 import os
-from typing import Any, Dict, Optional
+from typing import Any
 
 from datasets import Dataset, load_dataset
 
@@ -15,8 +15,8 @@
 
 def get_raw_dataset(
     dataset_args,
-    cache_dir: Optional[str] = None,
-    streaming: Optional[bool] = False,
+    cache_dir: str | None = None,
+    streaming: bool | None = False,
     **kwargs,
 ) -> Dataset:
     """
@@ -37,7 +37,7 @@ def get_raw_dataset(
     return raw_datasets
 
 
-def get_custom_datasets_from_path(path: str, ext: str = "json") -> Dict[str, str]:
+def get_custom_datasets_from_path(path: str, ext: str = "json") -> dict[str, str]:
     """
     Get a dictionary of custom datasets from a directory path. Support HF's load_dataset
     for local folder datasets https://huggingface.co/docs/datasets/loading
@@ -105,7 +105,7 @@ def get_custom_datasets_from_path(path: str, ext: str = "json") -> Dict[str, str]:
     return transform_dataset_keys(data_files)
 
 
-def transform_dataset_keys(data_files: Dict[str, Any]):
+def transform_dataset_keys(data_files: dict[str, Any]):
     """
     Transform dict keys to `train`, `val` or `test` for the given input dict
     if matches exist with the existing keys. Note that there can only be one
```
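One mechanical artifact worth noting: `streaming: Optional[bool] = False` becomes `streaming: bool | None = False`, faithfully preserving the odd combination of an optional type with a non-`None` default. A hedged sketch of what a later tightening could look like (this is not what the commit does):

```python
from typing import Any


def get_raw_dataset_sketch(
    dataset_args: Any,
    cache_dir: str | None = None,
    streaming: bool = False,  # hypothetical tightening of `bool | None = False`
    **kwargs: Any,
) -> None:
    ...
```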

src/llmcompressor/transformers/finetune/data/peoples_speech.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -1,5 +1,5 @@
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict
+from typing import TYPE_CHECKING, Any
 
 from datasets.formatting.formatting import LazyRow
 from loguru import logger
@@ -68,7 +68,7 @@ def filter_tokenizer_args(self, dataset: DatasetType) -> DatasetType:
         else:
             return super().filter_tokenizer_args(dataset)
 
-    def tokenize(self, data: LazyRow) -> Dict[str, Any]:
+    def tokenize(self, data: LazyRow) -> dict[str, Any]:
         if self.processor_type == "WhisperProcessor":
             inputs = self.processor(
                 audio=data["audio"],
```

src/llmcompressor/transformers/finetune/session_mixin.py

Lines changed: 18 additions & 18 deletions

```diff
@@ -2,7 +2,7 @@
 import math
 import os
 from dataclasses import asdict
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Tuple
 
 import torch
 from loguru import logger
@@ -56,9 +56,9 @@ def __init__(
         self,
         recipe: str,
         model_args: "ModelArguments",
-        dataset_args: Optional["DatasetArguments"] = None,
-        teacher: Optional[Union[Module, str]] = None,
-        recipe_args: Optional[Union[Dict[str, Any], str]] = None,
+        dataset_args: "DatasetArguments" | None = None,
+        teacher: Module | str | None = None,
+        recipe_args: dict[str, Any] | str | None = None,
         **kwargs,
     ):
         self.recipe = recipe
@@ -125,8 +125,8 @@ def __init__(
     def initialize_session(
         self,
         epoch: float,
-        checkpoint: Optional[str] = None,
-        stage: Optional[str] = None,
+        checkpoint: str | None = None,
+        stage: str | None = None,
     ):
         """
         Initialize the CompressionSession from the specified epoch, evaluates the recipe
@@ -251,8 +251,8 @@ def create_scheduler(
     def training_step(
         self,
         model: torch.nn.Module,
-        inputs: Dict[str, Union[torch.Tensor, Any]],
-        num_items_in_batch: Optional[int] = None,
+        inputs: dict[str, torch.Tensor | Any],
+        num_items_in_batch: int | None = None,
     ) -> torch.Tensor:
         """
         Overrides the Trainer's training step to trigger the batch_start callback to
@@ -274,10 +274,10 @@ def training_step(
     def compute_loss(
         self,
         model: Module,
-        inputs: Dict[str, Any],
+        inputs: dict[str, Any],
         return_outputs: bool = False,
-        num_items_in_batch: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Any]]:
+        num_items_in_batch: torch.Tensor | None = None,
+    ) -> torch.Tensor | Tuple[torch.Tensor, Any]:
         """
         Override for the compute_loss to factor trigger callbacks and filter columns
 
@@ -326,7 +326,7 @@ def compute_loss(
 
         return loss
 
-    def train(self, *args, stage: Optional[str] = None, **kwargs):
+    def train(self, *args, stage: str | None = None, **kwargs):
         """
         Run a sparsification training cycle. Runs initialization for the sparse session
         before calling super().train() and finalization of the session after.
@@ -370,7 +370,7 @@ def save_model(
         self,
         output_dir: str,
         _internal_call: bool = False,
-        skip_sparsity_compression_stats: Optional[bool] = True,
+        skip_sparsity_compression_stats: bool | None = True,
     ):
         """
         Override of the save_model function and expects it to exist in the parent.
@@ -478,10 +478,10 @@ def _prepare_model_for_fsdp(self):
 
     def _extract_metadata(
         self,
-        metadata_args: List[str],
-        training_args_dict: Dict[str, Any],
-        dataset_args_dict: Dict[str, Any],
-    ) -> Dict[str, Any]:
+        metadata_args: list[str],
+        training_args_dict: dict[str, Any],
+        dataset_args_dict: dict[str, Any],
+    ) -> dict[str, Any]:
         metadata = {}
         if not training_args_dict.keys().isdisjoint(dataset_args_dict.keys()):
             raise ValueError(
@@ -509,7 +509,7 @@ def _check_super_defined(self, func: str):
             f"The super class for SessionManagerMixIn must define a {func} function"
         )
 
-    def _calculate_checkpoint_info(self, kwargs) -> Tuple[Optional[str], float]:
+    def _calculate_checkpoint_info(self, kwargs) -> Tuple[str | None, float]:
         """
         If resuming from checkpoint is set, get checkpoint and epoch to resume from
         """
```
