From d81c34eb834fbd7e72a3bf7f5675578f92be5c41 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 10 Nov 2025 19:00:27 +0000 Subject: [PATCH] add support for frozen dataclasses Signed-off-by: Kyle Sayers --- src/llmcompressor/pipelines/cache.py | 29 +++++++++++---------- tests/llmcompressor/pipelines/test_cache.py | 11 +++++--- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/src/llmcompressor/pipelines/cache.py b/src/llmcompressor/pipelines/cache.py index ea0d5f254..84dacd294 100644 --- a/src/llmcompressor/pipelines/cache.py +++ b/src/llmcompressor/pipelines/cache.py @@ -1,7 +1,7 @@ import sys import warnings from collections import defaultdict -from dataclasses import dataclass, fields, is_dataclass +from dataclasses import dataclass, is_dataclass, fields from typing import Any, Dict, Generator, List, Optional, Union import torch @@ -166,8 +166,8 @@ def _size_helper(intermediate: IntermediateValue) -> int: for v in value.values(): _size_helper(v) case _ if is_dataclass(value): - for field in fields(value): - _size_helper(getattr(value, field.name)) + for f in fields(value): + _size_helper(getattr(value, f.name)) case _: # this handles primitive values that don't match any other cases sizes[torch.device("cpu")] += sys.getsizeof(value, 0) @@ -211,10 +211,10 @@ def _onload_value(cls, intermediate: IntermediateValue) -> Any: case dict(): return {k: cls._onload_value(v) for k, v in value.items()} case _ if is_dataclass(value): - for field in fields(value): - v = getattr(value, field.name) - setattr(value, field.name, cls._onload_value(v)) - return value + return type(value)(**{ + f.name: cls._onload_value(getattr(value, f.name)) + for f in fields(value) + }) case _: # handles primitive values that should be returned as is. # without this, a MatchError would be raised for unhandled types. @@ -255,16 +255,17 @@ def _offload_value( ) case dict(): return IntermediateValue( - value={ - k: cls._offload_value(v, **kwargs) for k, v in value.items() - }, + value={k: cls._offload_value(v, **kwargs) for k, v in value.items()}, device=None, ) case _ if is_dataclass(value): - for field in fields(value): - v = getattr(value, field.name) - setattr(value, field.name, cls._offload_value(v, **kwargs)) - return IntermediateValue(value=value, device=None) + return IntermediateValue( + value=type(value)(**{ + f.name: cls._offload_value(getattr(value, f.name), **kwargs) + for f in fields(value) + }), + device=None + ) case _: # handles primitive values and provides a warning for unsupported types. # without this, values trigger a MatchError exception. diff --git a/tests/llmcompressor/pipelines/test_cache.py b/tests/llmcompressor/pipelines/test_cache.py index 2c5315e18..92fb36466 100644 --- a/tests/llmcompressor/pipelines/test_cache.py +++ b/tests/llmcompressor/pipelines/test_cache.py @@ -1,3 +1,5 @@ +from typing import Optional + from dataclasses import dataclass, fields, is_dataclass import pytest @@ -7,10 +9,11 @@ from llmcompressor.pipelines.cache import IntermediatesCache -@dataclass +@dataclass(frozen=True) class SampleDataclass: - a: torch.Tensor - b: int + a: int + b: Optional[torch.Tensor] = None + c: Optional["SampleDataclass"] = None @pytest.fixture @@ -36,7 +39,7 @@ def sample_cache(sample_dataloader): values_to_test = [ torch.randn(2, 3).to("cpu"), - SampleDataclass(a=torch.randn(2, 3), b=42), + SampleDataclass(a=42, b=torch.randn(2, 3), c=SampleDataclass(a=64)), torch.float32, [1, 2, 3], ]