
Commit 4993e06

Update tuple typehint
1 parent 5ee7dad commit 4993e06

2 files changed: +6 -7 lines changed

src/llmcompressor/transformers/compression/helpers.py

Lines changed: 3 additions & 4 deletions
@@ -1,5 +1,4 @@
 from collections import defaultdict
-from typing import Tuple
 
 import torch
 from accelerate.accelerator import get_state_dict_offloaded_model
@@ -104,7 +103,7 @@ def infer_sparse_targets_and_ignores(
     model: torch.nn.Module,
     sparsity_structure: str,
     sparsity_threshold: float,
-) -> Tuple[list[str], list[str]]:
+) -> tuple[list[str], list[str]]:
     """
     Infers the target and ignore layers in the given model
     to be used for sparsity compression
@@ -151,7 +150,7 @@ def is_sparse_compression_target(
 
 def _get_sparse_targets_ignore_dicts(
     module: torch.nn.Module, sparsity_structure: str, sparsity_threshold: float
-) -> Tuple[dict[str, list[str]], dict[str, list[str]]]:
+) -> tuple[dict[str, list[str]], dict[str, list[str]]]:
     """
     Get sparse targets and ignore dictionaries
 
@@ -177,7 +176,7 @@ def _get_sparse_targets_ignore_dicts(
 
 def _reduce_targets_and_ignores_into_lists(
     exhaustive_targets: dict[str, list[str]], exhaustive_ignore: dict[str, list[str]]
-) -> Tuple[list[str], list[str]]:
+) -> tuple[list[str], list[str]]:
     """
     Reduces the targets and ignores dictionaries into lists
 
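Note: the change above swaps the `typing.Tuple` alias for the builtin `tuple` generic, which is accepted in annotations on Python 3.9+ (PEP 585) and removes the need for the `from typing import Tuple` import. A minimal sketch of the annotation style, using a hypothetical `split_names` helper rather than the real functions in this file:

    # Minimal sketch (assumption: Python 3.9+, PEP 585 builtin generics).
    # `split_names` is a hypothetical stand-in for the annotated helpers above.
    def split_names(names: list[str]) -> tuple[list[str], list[str]]:
        """Split module names into (targets, ignores) by a leading underscore."""
        targets = [n for n in names if not n.startswith("_")]
        ignores = [n for n in names if n.startswith("_")]
        return targets, ignores

    print(split_names(["Linear", "_norm", "Embedding"]))
    # (['Linear', 'Embedding'], ['_norm'])

On 3.9+ the builtin generic is usable at runtime without importing anything from typing, which is what allows the import line to be dropped in the hunk above.
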
src/llmcompressor/transformers/finetune/session_mixin.py

Lines changed: 3 additions & 3 deletions
@@ -2,7 +2,7 @@
 import math
 import os
 from dataclasses import asdict
-from typing import TYPE_CHECKING, Any, Tuple
+from typing import TYPE_CHECKING, Any
 
 import torch
 from loguru import logger
@@ -277,7 +277,7 @@ def compute_loss(
         inputs: dict[str, Any],
         return_outputs: bool = False,
         num_items_in_batch: torch.Tensor | None = None,
-    ) -> torch.Tensor | Tuple[torch.Tensor, Any]:
+    ) -> torch.Tensor | tuple[torch.Tensor, Any]:
         """
         Override for the compute_loss to factor trigger callbacks and filter columns
@@ -509,7 +509,7 @@ def _check_super_defined(self, func: str):
                 f"The super class for SessionManagerMixIn must define a {func} function"
            )
 
-    def _calculate_checkpoint_info(self, kwargs) -> Tuple[str | None, float]:
+    def _calculate_checkpoint_info(self, kwargs) -> tuple[str | None, float]:
        """
        If resuming from checkpoint is set, get checkpoint and epoch to resume from
        """
