@@ -2,7 +2,6 @@
 import weakref
 from collections.abc import Generator
 from functools import wraps
-from typing import Optional
 
 import torch
 from accelerate.accelerator import get_state_dict_offloaded_model
@@ -54,8 +53,8 @@ def save_pretrained_compressed(save_pretrained_method):
     @wraps(original_save_pretrained)
     def save_pretrained_wrapper(
         save_directory: str,
-        sparsity_config: Optional[SparsityCompressionConfig] = None,
-        quantization_format: Optional[str] = None,
+        sparsity_config: SparsityCompressionConfig | None = None,
+        quantization_format: str | None = None,
         save_compressed: bool = True,
         safe_serialization: bool = True,
         skip_sparsity_compression_stats: bool = True,
@@ -233,8 +232,8 @@ def untie_if_target_shared_embedding(
 
 def get_model_compressor(
     model: torch.nn.Module,
-    sparsity_config: Optional[SparsityCompressionConfig] = None,
-    quantization_format: Optional[str] = None,
+    sparsity_config: SparsityCompressionConfig | None = None,
+    quantization_format: str | None = None,
     save_compressed: bool = True,
     skip_sparsity_compression_stats: bool = True,
     disable_sparse_compression: bool = False,
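Reviewer note: the changes above swap `typing.Optional[X]` for the PEP 604 spelling `X | None`. On Python 3.10+ the two are interchangeable at runtime; a minimal sketch (not from this PR) illustrating the equivalence:

from typing import Optional

def old_style(x: Optional[str] = None) -> Optional[int]:
    # Optional[str] is shorthand for Union[str, None]
    return len(x) if x is not None else None

def new_style(x: str | None = None) -> int | None:
    # PEP 604 spelling; builds a types.UnionType on Python 3.10+
    return len(x) if x is not None else None

# Both spellings compare equal and describe the same union of types
assert (str | None) == Optional[str]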
15 changes: 7 additions & 8 deletions src/llmcompressor/transformers/compression/helpers.py
@@ -1,5 +1,4 @@
 from collections import defaultdict
-from typing import Dict, List, Optional, Tuple
 
 import torch
 from accelerate.accelerator import get_state_dict_offloaded_model
@@ -51,8 +50,8 @@ def tensor_follows_mask_structure(tensor: torch.Tensor, mask: str = "2:4") -> bool:
 
 
 def infer_sparsity_structure_from_modifiers(
-    modifiers: List[Modifier],  # noqa E501
-) -> Optional[str]:
+    modifiers: list[Modifier],  # noqa E501
+) -> str | None:
     """
     Determines the sparsity structure, if any exists, given the list of modifiers.
 
@@ -65,7 +64,7 @@ def infer_sparsity_structure_from_modifiers(
     return None
 
 
-def infer_sparsity_structure_from_model(model: torch.nn.Module) -> Optional[str]:
+def infer_sparsity_structure_from_model(model: torch.nn.Module) -> str | None:
     """
     Determines the sparsity structure, if any exists, given the model
 
@@ -104,7 +103,7 @@ def infer_sparse_targets_and_ignores(
     model: torch.nn.Module,
     sparsity_structure: str,
     sparsity_threshold: float,
-) -> Tuple[List[str], List[str]]:
+) -> tuple[list[str], list[str]]:
     """
     Infers the target and ignore layers in the given model
     to be used for sparsity compression
@@ -151,7 +150,7 @@ def is_sparse_compression_target(
 
 def _get_sparse_targets_ignore_dicts(
     module: torch.nn.Module, sparsity_structure: str, sparsity_threshold: float
-) -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]:
+) -> tuple[dict[str, list[str]], dict[str, list[str]]]:
     """
     Get sparse targets and ignore dictionaries
 
@@ -176,8 +175,8 @@ def _get_sparse_targets_ignore_dicts(
 
 
 def _reduce_targets_and_ignores_into_lists(
-    exhaustive_targets: Dict[str, List[str]], exhaustive_ignore: Dict[str, List[str]]
-) -> Tuple[List[str], List[str]]:
+    exhaustive_targets: dict[str, list[str]], exhaustive_ignore: dict[str, list[str]]
+) -> tuple[list[str], list[str]]:
     """
     Reduces the targets and ignores dictionaries into lists
 
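Reviewer note: this file also drops `Dict`/`List`/`Tuple` in favor of the builtin generics standardized by PEP 585 (Python 3.9+). A small self-contained sketch (hypothetical helper, not from this PR) using the new spellings:

def split_targets(groups: dict[str, list[str]]) -> tuple[list[str], list[str]]:
    # Builtin containers are subscriptable on 3.9+, so no typing imports are needed.
    targets: list[str] = []
    ignores: list[str] = []
    for name, layers in groups.items():
        (ignores if name == "ignore" else targets).extend(layers)
    return targets, ignores

assert split_targets({"target": ["re:.*q_proj"], "ignore": ["lm_head"]}) == (
    ["re:.*q_proj"],
    ["lm_head"],
)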
@@ -1,5 +1,3 @@
-from typing import Dict, List, Optional
-
 from compressed_tensors import CompressionFormat, SparsityCompressionConfig
 from compressed_tensors.config import SparsityStructure
 from compressed_tensors.quantization import QuantizationType
@@ -30,7 +28,7 @@ class SparsityConfigMetadata:
 
     @staticmethod
     def infer_global_sparsity(
-        model: Module, state_dict: Optional[Dict[str, Tensor]] = None
+        model: Module, state_dict: dict[str, Tensor] | None = None
     ) -> float:
         """
         Calculates the global percentage of sparse zero weights in the model
@@ -47,12 +45,12 @@ def infer_global_sparsity(
 
     @staticmethod
     def infer_sparsity_structure(
-        model: Optional[Module] = None, check_only_modifiers: Optional[bool] = False
+        model: Module | None = None, check_only_modifiers: bool | None = False
     ) -> str:
         """
         Determines what sparsity structure, if any, was applied.
 
-        First, there is an attempt to dedue the sparsity structure
+        First, there is an attempt to deduce the sparsity structure
         from the currently active sparse session.
 
         If that fails, the sparsity structure is inferred from the
@@ -83,12 +81,12 @@ def infer_sparsity_structure(
     @staticmethod
     def from_pretrained(
         model: Module,
-        state_dict: Optional[Dict[str, Tensor]] = None,
+        state_dict: dict[str, Tensor] | None = None,
         compress: bool = False,
-        quantization_format: Optional[CompressionFormat] = None,
+        quantization_format: CompressionFormat | None = None,
         disable_sparse_compression: bool = False,
-        sparsity_structure: Optional[str] = None,
-    ) -> Optional["SparsityCompressionConfig"]:
+        sparsity_structure: str | None = None,
+    ) -> SparsityCompressionConfig | None:
         """
         Determines compression type and informational parameters for a given model
 
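Reviewer note: one subtlety with PEP 604 and forward references is worth flagging for `from_pretrained` above: a quoted name cannot be an operand of `|`, because return annotations are evaluated at definition time and `"SomeClass" | None` raises TypeError. Since `SparsityCompressionConfig` is imported at the top of this module, the plain unquoted union works; otherwise the whole annotation must be quoted. A minimal sketch of the pitfall (illustrative names, not from this PR):

class Config:  # stand-in for an imported or forward-referenced class
    pass

# Broken: evaluated eagerly, and a str instance does not support `|`.
#     def broken() -> "Config" | None: ...
# TypeError: unsupported operand type(s) for |: 'str' and 'NoneType'

def fixed() -> Config | None:  # fine when the name is already defined
    return None

def deferred() -> "Config | None":  # or quote the entire annotation
    return None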
@@ -155,7 +153,7 @@ def from_pretrained(
     def fill_config_details(
         config: SparsityCompressionConfig,
         model: Module,
-        state_dict: Optional[Dict[str, Tensor]] = None,
+        state_dict: dict[str, Tensor] | None = None,
     ):
         """
         Fills in informational sparsity parameters from a given model
@@ -173,7 +171,7 @@ def fill_config_details(
     @staticmethod
     def is_sparse24_bitmask_supported(
         model: Module,
-        sparsity_structure: Optional[str] = None,
+        sparsity_structure: str | None = None,
     ) -> bool:
         """
         Determines if sparse 24 bitmask sparse compressor is supported for a given model
@@ -202,7 +200,7 @@ def is_sparse24_bitmask_supported(
 
         # when model is quantized, and has 2:4 sparsity
 
-        supported_scheme_types: List[str] = [
+        supported_scheme_types: list[str] = [
             QuantizationType.INT.value,
             QuantizationType.FLOAT.value,
         ]
16 changes: 8 additions & 8 deletions src/llmcompressor/transformers/data/base.py
@@ -10,7 +10,7 @@
 import inspect
 from functools import cached_property
 from inspect import _ParameterKind as Kind
-from typing import Any, Callable, Dict, List, Union
+from typing import Any, Callable
 
 from compressed_tensors.registry import RegistryMixin
 from datasets import Dataset, IterableDataset
@@ -202,7 +202,7 @@ def load_dataset(self):
         )
 
     @cached_property
-    def preprocess(self) -> Union[Callable[[LazyRow], Any], None]:
+    def preprocess(self) -> Callable[[LazyRow], Any] | None:
         """
         The function must return keys which correspond to processor/tokenizer kwargs,
         optionally including PROMPT_KEY
@@ -225,7 +225,7 @@ def preprocess(self) -> Union[Callable[[LazyRow], Any], None]:
         return self.dataset_template
 
     @property
-    def dataset_template(self) -> Union[Callable[[Any], Any], None]:
+    def dataset_template(self) -> Callable[[Any], Any] | None:
         return None
 
     def rename_columns(self, dataset: DatasetType) -> DatasetType:
@@ -254,7 +254,7 @@ def filter_tokenizer_args(self, dataset: DatasetType) -> DatasetType:
             list(set(column_names) - set(tokenizer_args) - set([self.PROMPT_KEY]))
         )
 
-    def tokenize(self, data: LazyRow) -> Dict[str, Any]:
+    def tokenize(self, data: LazyRow) -> dict[str, Any]:
         # separate prompt
         prompt = data.pop(self.PROMPT_KEY, None)
 
@@ -276,7 +276,7 @@ def tokenize(self, data: LazyRow) -> Dict[str, Any]:
 
         return data
 
-    def group_text(self, data: LazyRow) -> Dict[str, Any]:
+    def group_text(self, data: LazyRow) -> dict[str, Any]:
         concatenated_data = {k: sum(data[k], []) for k in data.keys()}
         total_length = len(concatenated_data[list(data.keys())[0]])
         total_length = (total_length // self.max_seq_length) * self.max_seq_length
@@ -311,10 +311,10 @@ def add_labels(self, data: LazyRow) -> LazyRow:
 
     def map(
         self,
-        dataset: Union[Dataset, IterableDataset],
+        dataset: Dataset | IterableDataset,
         function: Callable[[Any], Any],
         **kwargs,
-    ) -> Union[Dataset, IterableDataset]:
+    ) -> Dataset | IterableDataset:
         """
         Wrapper function around Dataset.map and IterableDataset.map.
 
@@ -336,7 +336,7 @@ def map(
         return dataset
 
 
-def get_columns(dataset: DatasetType) -> List[str]:
+def get_columns(dataset: DatasetType) -> list[str]:
     column_names = dataset.column_names
     if isinstance(column_names, dict):
         column_names = sum(column_names.values(), [])
10 changes: 5 additions & 5 deletions src/llmcompressor/transformers/data/data_helpers.py
@@ -1,6 +1,6 @@
 import logging
 import os
-from typing import Any, Dict, Optional
+from typing import Any
 
 from datasets import Dataset, load_dataset
 
@@ -15,8 +15,8 @@
 
 def get_raw_dataset(
     dataset_args,
-    cache_dir: Optional[str] = None,
-    streaming: Optional[bool] = False,
+    cache_dir: str | None = None,
+    streaming: bool | None = False,
     **kwargs,
 ) -> Dataset:
     """
@@ -37,7 +37,7 @@ def get_raw_dataset(
     return raw_datasets
 
 
-def get_custom_datasets_from_path(path: str, ext: str = "json") -> Dict[str, str]:
+def get_custom_datasets_from_path(path: str, ext: str = "json") -> dict[str, str]:
     """
     Get a dictionary of custom datasets from a directory path. Support HF's load_dataset
     for local folder datasets https://huggingface.co/docs/datasets/loading
@@ -105,7 +105,7 @@ def get_custom_datasets_from_path(path: str, ext: str = "json") -> Dict[str, str]:
     return transform_dataset_keys(data_files)
 
 
-def transform_dataset_keys(data_files: Dict[str, Any]):
+def transform_dataset_keys(data_files: dict[str, Any]):
     """
     Transform dict keys to `train`, `val` or `test` for the given input dict
     if matches exist with the existing keys. Note that there can only be one
4 changes: 2 additions & 2 deletions src/llmcompressor/transformers/data/peoples_speech.py
@@ -1,5 +1,5 @@
 from copy import deepcopy
-from typing import TYPE_CHECKING, Any, Dict
+from typing import TYPE_CHECKING, Any
 
 from datasets.formatting.formatting import LazyRow
 from loguru import logger
@@ -68,7 +68,7 @@ def filter_tokenizer_args(self, dataset: DatasetType) -> DatasetType:
         else:
             return super().filter_tokenizer_args(dataset)
 
-    def tokenize(self, data: LazyRow) -> Dict[str, Any]:
+    def tokenize(self, data: LazyRow) -> dict[str, Any]:
         if self.processor_type == "WhisperProcessor":
             inputs = self.processor(
                 audio=data["audio"],
14 changes: 7 additions & 7 deletions src/llmcompressor/transformers/tracing/debug.py
@@ -1,4 +1,4 @@
-from typing import List, Type, Union, Optional, Dict, Tuple, Any
+from typing import Type, Tuple, Any
 
 import argparse
 from contextlib import nullcontext
@@ -33,13 +33,13 @@ def parse_args():
 def trace(
     model_id: str,
     model_class: Type[PreTrainedModel],
-    sequential_targets: Optional[Union[List[str], str]] = None,
-    ignore: Union[List[str], str] = DatasetArguments().tracing_ignore,
+    sequential_targets: list[str] | str | None = None,
+    ignore: list[str] | str = DatasetArguments().tracing_ignore,
     modality: str = "text",
     trust_remote_code: bool = True,
     skip_weights: bool = True,
-    device_map: Union[str, Dict] = "cpu",
-) -> Tuple[PreTrainedModel, List[Subgraph], Dict[str, torch.Tensor]]:
+    device_map: str | dict = "cpu",
+) -> Tuple[PreTrainedModel, list[Subgraph], dict[str, torch.Tensor]]:
     """
     Debug traceability by tracing a pre-trained model into subgraphs
 
@@ -110,7 +110,7 @@ def trace(
     return model, subgraphs, sample
 
 
-def get_dataset_kwargs(modality: str, ignore: List[str]) -> Dict[str, str]:
+def get_dataset_kwargs(modality: str, ignore: list[str]) -> dict[str, str]:
     dataset_kwargs = {
         "text": {
             "dataset": "ultrachat-200k",
@@ -139,7 +139,7 @@ def get_dataset_kwargs(modality: str, ignore: list[str]) -> dict[str, str]:
     return dataset_kwargs[modality] | common_kwargs
 
 
-def collate_sample(sample: Dict[str, Any], device: str) -> Dict[str, torch.Tensor]:
+def collate_sample(sample: dict[str, Any], device: str) -> dict[str, torch.Tensor]:
     for name, value in sample.items():
         if name in ("input_ids", "attention_mask") and torch.tensor(value).ndim == 1:
             sample[name] = torch.tensor([value], device=device)
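Reviewer note: PEP 604 unions such as the `str | dict` used for `device_map` above are also valid as the second argument to isinstance() on Python 3.10+, so the same spelling works in annotations and runtime checks. A tiny sketch (not from this PR):

device_map: str | dict = "cpu"

# Equivalent to isinstance(device_map, (str, dict)) on Python 3.10+
assert isinstance(device_map, str | dict)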
6 changes: 3 additions & 3 deletions src/llmcompressor/transformers/utils/helpers.py
@@ -5,7 +5,7 @@
 
 import os
 from pathlib import Path
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING
 
 import requests
 from huggingface_hub import (
@@ -47,7 +47,7 @@ def is_model_ct_quantized_from_path(path: str) -> bool:
     return False
 
 
-def infer_recipe_from_model_path(model_path: Union[str, Path]) -> Optional[str]:
+def infer_recipe_from_model_path(model_path: str | Path) -> str | None:
     """
     Infer the recipe from the model_path.
 
@@ -95,7 +95,7 @@ def infer_recipe_from_model_path(model_path: Union[str, Path]) -> Optional[str]:
 
 def recipe_from_huggingface_model_id(
     hf_stub: str, recipe_file_name: str = RECIPE_FILE_NAME
-) -> Optional[str]:
+) -> str | None:
     """
     Attempts to download the recipe from the Hugging Face model ID.
 
@@ -7,7 +7,7 @@
 popular training datasets.
 """
 
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING
 
 from compressed_tensors.registry import RegistryMixin
 
@@ -20,7 +20,7 @@ class PreprocessingFunctionRegistry(RegistryMixin):
 
 
 @PreprocessingFunctionRegistry.register()
-def custom_evolved_codealpaca_dataset(self: "TextGenerationDataset", data: Dict):
+def custom_evolved_codealpaca_dataset(self: "TextGenerationDataset", data: dict):
     PROMPT_DICT = """[Instruction]:\n{instruction}\n\n[Response]:"""
     data["prompt"] = PROMPT_DICT.format_map(data)
     data["text"] = data["prompt"] + data["output"]