Skip to content

Commit 30468e9

Browse files
don't detect arch list when it has already been set & fix build-system requirments (#1802)
* fix cuda < 12.8 doesn't support 12.0 * build system doesn't need requirements
1 parent cfc1094 commit 30468e9

File tree

2 files changed

+4
-22
lines changed

2 files changed

+4
-22
lines changed

pyproject.toml

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,3 @@
11
[build-system]
2-
requires = ["setuptools >= 64",
3-
# same with requirements.txt
4-
"accelerate>=1.10.1" ,
5-
"numpy>=2.2.6" ,
6-
"torch>=2.8.0" ,
7-
"safetensors>=0.6.2" ,
8-
"transformers>=4.56.0" ,
9-
"threadpoolctl>=3.6.0" ,
10-
"packaging>=24.2" ,
11-
"device-smi==0.4.1" ,
12-
"protobuf>=6.32.0" ,
13-
"pillow>=11.3.0" ,
14-
"hf_transfer>=0.1.9" ,
15-
"huggingface_hub>=0.34.4" ,
16-
"random_word==1.0.13" ,
17-
"tokenicer>=0.0.5" ,
18-
"logbar==0.0.4" ,
19-
"wheel>=0.45.1" ,
20-
"maturin>=1.9.3" ,
21-
]
2+
requires = ["setuptools >= 64"]
223
build-backend = "setuptools.build_meta:__legacy__"

setup.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,8 +222,9 @@ def get_version_tag() -> str:
222222

223223
TORCH_VERSION = _read_env("TORCH_VERSION")
224224
RELEASE_MODE = _read_env("RELEASE_MODE")
225-
CUDA_VERSION = _read_env("TORCH_VERSION")
225+
CUDA_VERSION = _read_env("CUDA_VERSION")
226226
ROCM_VERSION = _read_env("ROCM_VERSION")
227+
TORCH_CUDA_ARCH_LIST = _read_env("TORCH_CUDA_ARCH_LIST")
227228

228229

229230
# respect user env then detect
@@ -257,7 +258,7 @@ def get_version_tag() -> str:
257258
# Handle CUDA_ARCH_LIST (public) and set TORCH_CUDA_ARCH_LIST for build toolchains
258259
CUDA_ARCH_LIST = _detect_cuda_arch_list() if (BUILD_CUDA_EXT == "1" and not ROCM_VERSION) else None
259260

260-
if CUDA_ARCH_LIST:
261+
if not TORCH_CUDA_ARCH_LIST and CUDA_ARCH_LIST:
261262
archs = _parse_arch_list(CUDA_ARCH_LIST)
262263
kept = []
263264
for arch in archs:

0 commit comments

Comments
 (0)