Skip to content

Commit c234638

Browse files
mgoin, dbogunowicz, and Sara Adkins
authored
Allow torch 2.3 and remove torch ceiling version restriction (#2259)
* Allow torch 2.3 and remove torch ceiling version restriction
* Update base.py
* transformers quant test fix

---------

Co-authored-by: dbogunowicz <[email protected]>
Co-authored-by: Sara Adkins <[email protected]>
1 parent fcf3c77 commit c234638

File tree

4 files changed

+10
-10
lines changed

4 files changed

+10
-10
lines changed

Diff for: docs/source/installation.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ limitations under the License.
1818

1919
This repository is tested on Python 3.8-3.11, and Linux/Debian systems.
2020
It is recommended to install in a [virtual environment](https://docs.python.org/3/library/venv.html) to keep your system in order.
21-
Currently supported ML Frameworks are the following: `torch>=1.1.0,<1.14`, `tensorflow>=1.8.0,<=2.0.0`, `tensorflow.keras >= 2.2.0`.
21+
Currently supported ML Frameworks are the following: `torch>=1.7.0`, `tensorflow>=1.8.0,<=2.0.0`, `tensorflow.keras >= 2.2.0`.
2222

2323
Install with pip using:
2424

Diff for: setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464

6565
_onnxruntime_deps = ["onnxruntime>=1.0.0"]
6666
_clip_deps = ["open_clip_torch==2.20.0"]
67-
supported_torch_version = "torch>=1.7.0,<2.3"
67+
supported_torch_version = "torch>=1.7.0"
6868
_pytorch_deps = [
6969
supported_torch_version,
7070
"gputils",

Diff for: src/sparseml/pytorch/base.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414

1515
import functools
16-
import os
1716
from typing import Optional
1817

1918
from sparseml.base import check_version
@@ -49,13 +48,11 @@
4948

5049

5150
_TORCH_MIN_VERSION = "1.0.0"
52-
# set max to 2.2.99 to account for bugfix versions with 2.2
53-
_TORCH_MAX_VERSION = os.environ.get("MAX_TORCH", "2.2.99")
5451

5552

5653
def check_torch_install(
5754
min_version: Optional[str] = _TORCH_MIN_VERSION,
58-
max_version: Optional[str] = _TORCH_MAX_VERSION,
55+
max_version: Optional[str] = None,
5956
raise_on_error: bool = True,
6057
) -> bool:
6158
"""
@@ -121,7 +118,7 @@ def check_torchvision_install(
121118

122119
def require_torch(
123120
min_version: Optional[str] = _TORCH_MIN_VERSION,
124-
max_version: Optional[str] = _TORCH_MAX_VERSION,
121+
max_version: Optional[str] = None,
125122
):
126123
"""
127124
Decorator function to require use of torch.

Diff for: tests/sparseml/transformers/compression/test_quantization.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@
5353
class TestQuantizationMatches(unittest.TestCase):
5454
old_recipe = None
5555
new_recipe = None
56-
model_stub = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
56+
# TODO: use "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T" for nightly
57+
# or weekly runs, but this smaller model is better for commit testing
58+
model_stub = "Xenova/llama2.c-stories15M"
5759
dataset = "open_platypus"
5860
old_output = "tiny_llama_old"
5961
new_output = "tiny_llama_new"
@@ -75,7 +77,7 @@ def setUpClass(cls):
7577
)
7678

7779
cls.model_new = SparseAutoModelForCausalLM.from_pretrained(
78-
cls.model_stub, device_map="cuda:1"
80+
cls.model_stub, device_map="cuda:0"
7981
)
8082
cls._run_oneshot(
8183
cls.model_new,
@@ -106,6 +108,7 @@ def _run_oneshot(model, recipe, dataset, output_dir):
106108
num_calibration_samples=num_calibration_samples,
107109
recipe=recipe,
108110
pad_to_max_length=pad_to_max_length,
111+
clear_sparse_session=True,
109112
)
110113

111114
def _get_quant_info_old(self, model):
@@ -219,7 +222,7 @@ def test_perplexity(self):
219222
for idx, sample in enumerate(dataloader):
220223
if idx >= self.num_comparisons:
221224
break
222-
output_new = self.model_new(**tensors_to_device(sample, "cuda:1"))
225+
output_new = self.model_new(**tensors_to_device(sample, "cuda:0"))
223226
output_old = self.model_old(**tensors_to_device(sample, "cuda:0"))
224227
if torch.isnan(output_old.loss) and torch.isnan(output_new.loss):
225228
continue

0 commit comments

Comments (0)