4
4
# This source code is licensed under the BSD-style license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
7
+ import argparse
7
8
import os .path
8
9
import runpy
9
10
import subprocess
10
- from typing import List
11
+ from typing import List , Tuple
11
12
12
13
# required env vars:
13
14
# CU_VERSION: E.g. cu112
23
24
source_root_dir = os .environ ["PWD" ]
24
25
25
26
26
- def version_constraint (version ):
27
+ def version_constraint (version ) -> str :
27
28
"""
28
29
Given version "11.3" returns " >=11.3,<11.4"
29
30
"""
@@ -32,7 +33,7 @@ def version_constraint(version):
32
33
return f" >={ version } ,<{ upper } "
33
34
34
35
35
- def get_cuda_major_minor ():
36
+ def get_cuda_major_minor () -> Tuple [ str , str ] :
36
37
if CU_VERSION == "cpu" :
37
38
raise ValueError ("fn only for cuda builds" )
38
39
if len (CU_VERSION ) != 5 or CU_VERSION [:2 ] != "cu" :
@@ -42,11 +43,10 @@ def get_cuda_major_minor():
42
43
return major , minor
43
44
44
45
45
- def setup_cuda () :
46
+ def setup_cuda (use_conda_cuda : bool ) -> List [ str ] :
46
47
if CU_VERSION == "cpu" :
47
- return
48
+ return []
48
49
major , minor = get_cuda_major_minor ()
49
- os .environ ["CUDA_HOME" ] = f"/usr/local/cuda-{ major } .{ minor } /"
50
50
os .environ ["FORCE_CUDA" ] = "1"
51
51
52
52
basic_nvcc_flags = (
@@ -75,6 +75,15 @@ def setup_cuda():
75
75
76
76
if os .environ .get ("JUST_TESTRUN" , "0" ) != "1" :
77
77
os .environ ["NVCC_FLAGS" ] = nvcc_flags
78
+ if use_conda_cuda :
79
+ os .environ ["CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT1" ] = "- cuda-toolkit"
80
+ os .environ ["CONDA_CUDA_TOOLKIT_BUILD_CONSTRAINT2" ] = (
81
+ f"- cuda-version={ major } .{ minor } "
82
+ )
83
+ return ["-c" , f"nvidia/label/cuda-{ major } .{ minor } .0" ]
84
+ else :
85
+ os .environ ["CUDA_HOME" ] = f"/usr/local/cuda-{ major } .{ minor } /"
86
+ return []
78
87
79
88
80
89
def setup_conda_pytorch_constraint () -> List [str ]:
@@ -95,7 +104,7 @@ def setup_conda_pytorch_constraint() -> List[str]:
95
104
return ["-c" , "pytorch" , "-c" , "nvidia" ]
96
105
97
106
98
- def setup_conda_cudatoolkit_constraint ():
107
+ def setup_conda_cudatoolkit_constraint () -> None :
99
108
if CU_VERSION == "cpu" :
100
109
os .environ ["CONDA_CPUONLY_FEATURE" ] = "- cpuonly"
101
110
os .environ ["CONDA_CUDATOOLKIT_CONSTRAINT" ] = ""
@@ -116,7 +125,7 @@ def setup_conda_cudatoolkit_constraint():
116
125
os .environ ["CONDA_CUDATOOLKIT_CONSTRAINT" ] = toolkit
117
126
118
127
119
- def do_build (start_args : List [str ]):
128
+ def do_build (start_args : List [str ]) -> None :
120
129
args = start_args .copy ()
121
130
122
131
test_flag = os .environ .get ("TEST_FLAG" )
@@ -132,8 +141,16 @@ def do_build(start_args: List[str]):
132
141
133
142
134
143
if __name__ == "__main__" :
144
+ parser = argparse .ArgumentParser (description = "Build the conda package." )
145
+ parser .add_argument (
146
+ "--use-conda-cuda" ,
147
+ action = "store_true" ,
148
+ help = "get cuda from conda ignoring local cuda" ,
149
+ )
150
+ our_args = parser .parse_args ()
151
+
135
152
args = ["conda" , "build" ]
136
- setup_cuda ()
153
+ args += setup_cuda (use_conda_cuda = our_args . use_conda_cuda )
137
154
138
155
init_path = source_root_dir + "/pytorch3d/__init__.py"
139
156
build_version = runpy .run_path (init_path )["__version__" ]
0 commit comments