3
3
#
4
4
# This source code is licensed under the BSD-style license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
-
6
+ import argparse
7
7
import os .path
8
8
import runpy
9
9
import subprocess
10
- from typing import List
10
+ from typing import List , Tuple
11
11
12
12
# required env vars:
13
13
# CU_VERSION: E.g. cu112
23
23
source_root_dir = os .environ ["PWD" ]
24
24
25
25
26
- def version_constraint (version ):
26
+ def version_constraint (version ) -> str :
27
27
"""
28
28
Given version "11.3" returns " >=11.3,<11.4"
29
29
"""
@@ -32,7 +32,7 @@ def version_constraint(version):
32
32
return f" >={ version } ,<{ upper } "
33
33
34
34
35
- def get_cuda_major_minor ():
35
+ def get_cuda_major_minor () -> Tuple [ str , str ] :
36
36
if CU_VERSION == "cpu" :
37
37
raise ValueError ("fn only for cuda builds" )
38
38
if len (CU_VERSION ) != 5 or CU_VERSION [:2 ] != "cu" :
@@ -42,11 +42,17 @@ def get_cuda_major_minor():
42
42
return major , minor
43
43
44
44
45
- def setup_cuda () :
45
+ def setup_cuda (use_conda_cuda : bool ) -> None :
46
46
if CU_VERSION == "cpu" :
47
47
return
48
48
major , minor = get_cuda_major_minor ()
49
- os .environ ["CUDA_HOME" ] = f"/usr/local/cuda-{ major } .{ minor } /"
49
+ if use_conda_cuda :
50
+ os .environ ["CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT1" ] = "- cudatoolkit"
51
+ os .environ ["CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT2" ] = (
52
+ f"- cuda-version={ major } .{ minor } "
53
+ )
54
+ else :
55
+ os .environ ["CUDA_HOME" ] = f"/usr/local/cuda-{ major } .{ minor } /"
50
56
os .environ ["FORCE_CUDA" ] = "1"
51
57
52
58
basic_nvcc_flags = (
@@ -95,7 +101,7 @@ def setup_conda_pytorch_constraint() -> List[str]:
95
101
return ["-c" , "pytorch" , "-c" , "nvidia" ]
96
102
97
103
98
- def setup_conda_cudatoolkit_constraint ():
104
+ def setup_conda_cudatoolkit_constraint () -> None :
99
105
if CU_VERSION == "cpu" :
100
106
os .environ ["CONDA_CPUONLY_FEATURE" ] = "- cpuonly"
101
107
os .environ ["CONDA_CUDATOOLKIT_CONSTRAINT" ] = ""
@@ -116,14 +122,25 @@ def setup_conda_cudatoolkit_constraint():
116
122
os .environ ["CONDA_CUDATOOLKIT_CONSTRAINT" ] = toolkit
117
123
118
124
119
- def do_build (start_args : List [str ]):
125
+ def do_build (start_args : List [str ]) -> None :
120
126
args = start_args .copy ()
121
127
122
128
test_flag = os .environ .get ("TEST_FLAG" )
123
129
if test_flag is not None :
124
130
args .append (test_flag )
125
131
126
- args .extend (["-c" , "bottler" , "-c" , "iopath" , "-c" , "conda-forge" ])
132
+ args .extend (
133
+ [
134
+ "-c" ,
135
+ "bottler" ,
136
+ "-c" ,
137
+ "iopath" ,
138
+ "-c" ,
139
+ "conda-forge" ,
140
+ "-c" ,
141
+ "nvidia/label/cuda-12.1.0" ,
142
+ ]
143
+ )
127
144
args .append ("--no-anaconda-upload" )
128
145
args .extend (["--python" , os .environ ["PYTHON_VERSION" ]])
129
146
args .append ("packaging/pytorch3d" )
@@ -132,8 +149,16 @@ def do_build(start_args: List[str]):
132
149
133
150
134
151
if __name__ == "__main__" :
152
+ parser = argparse .ArgumentParser (description = "Build the conda package." )
153
+ parser .add_argument (
154
+ "--use-conda-cuda" ,
155
+ action = "store_true" ,
156
+ help = "get cuda from conda ignoring local cuda" ,
157
+ )
158
+ our_args = parser .parse_args ()
159
+
135
160
args = ["conda" , "build" ]
136
- setup_cuda ()
161
+ setup_cuda (use_conda_cuda = our_args . use_conda_cuda )
137
162
138
163
init_path = source_root_dir + "/pytorch3d/__init__.py"
139
164
build_version = runpy .run_path (init_path )["__version__" ]
0 commit comments