Skip to content

Commit 1ab88b9

Browse files
committed
Fix release wheel CUDA index calculation
1 parent 3fffa55 commit 1ab88b9

1 file changed

Lines changed: 2 additions & 6 deletions

File tree

.github/actions/build-pytorch-wheel/Dockerfile

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,5 @@ RUN CUDA_MAJOR_VERSION=$(echo $CUDA_VERSION | awk -F \. {'print $1'}) && \
4141
# Install PyTorch
4242
RUN export MATRIX_CUDA_VERSION=$(echo $CUDA_VERSION | awk -F \. {'print $1 $2'}) && \
4343
export MATRIX_TORCH_VERSION=$(echo $TORCH_VERSION | awk -F \. {'print $1 "." $2'}) && \
44-
export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
45-
minv = {'2.5': 118, '2.6': 118, '2.7': 118, '2.8': 126, '2.9': 126}[env['MATRIX_TORCH_VERSION']]; \
46-
maxv = {'2.5': 124, '2.6': 126, '2.7': 128, '2.8': 129, '2.9': 130}[env['MATRIX_TORCH_VERSION']]; \
47-
print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)" \
48-
) && \
49-
pip install --no-cache-dir torch==${TORCH_VERSION} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
44+
export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; versions = {'2.5': (118, 124), '2.6': (118, 126), '2.7': (118, 128), '2.8': (126, 129), '2.9': (126, 130)}; minv, maxv = versions[env['MATRIX_TORCH_VERSION']]; print(minv if int(env['MATRIX_CUDA_VERSION']) < 120 else maxv)") && \
45+
pip install --no-cache-dir torch==${TORCH_VERSION} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}

0 commit comments

Comments
 (0)