@@ -56,28 +56,51 @@ def install_jax(cuda_version=DEFAULT_CUDA_VERSION):
5656def install_fbgemm (genai = True ):
5757 cmd = ["pip" , "install" , "-r" , "requirements.txt" ]
5858 subprocess .check_call (cmd , cwd = str (FBGEMM_PATH .resolve ()))
59- # Build target A100(8.0) or H100(9.0, 9.0a)
59+ # Build target H100(9.0, 9.0a) and blackwell (10.0, 12.0)
60+ extra_envs = os .environ .copy ()
6061 if genai :
61- cmd = [
62- sys .executable ,
63- "setup.py" ,
64- "install" ,
65- "--build-target=genai" ,
66- "-DTORCH_CUDA_ARCH_LIST=8.0;9.0;9.0a" ,
67- ]
62+ if not is_hip ():
63+ cmd = [
64+ sys .executable ,
65+ "setup.py" ,
66+ "install" ,
67+ "--build-target=genai" ,
68+ "-DTORCH_CUDA_ARCH_LIST=9.0;9.0a;10.0;12.0" ,
69+ ]
70+ elif is_hip ():
71+ # build for MI300(gfx942) and MI350(gfx950)
72+ current_conda_env = os .environ .get ("CONDA_DEFAULT_ENV" )
73+ cmd = [
74+ "bash" ,
75+ "-c" ,
76+ f". .github/scripts/setup_env.bash; test_fbgemm_gpu_build_and_install { current_conda_env } genai/rocm" ,
77+ ]
78+ extra_envs ["BUILD_ROCM_VERSION" ] = "7.0"
79+ subprocess .check_call (
80+ cmd , cwd = str (FBGEMM_PATH .parent .resolve ()), env = extra_envs
81+ )
82+ return
6883 else :
6984 cmd = [
7085 sys .executable ,
7186 "setup.py" ,
7287 "install" ,
7388 "--build-target=cuda" ,
74- "-DTORCH_CUDA_ARCH_LIST=8 .0;9.0;9.0a " ,
89+ "-DTORCH_CUDA_ARCH_LIST=9 .0;9.0a;10.0;12.0 " ,
7590 ]
76- subprocess .check_call (cmd , cwd = str (FBGEMM_PATH .resolve ()))
91+ subprocess .check_call (cmd , cwd = str (FBGEMM_PATH .resolve ()), env = extra_envs )
7792
7893
7994def test_fbgemm ():
8095 print ("Checking fbgemm_gpu installation..." , end = "" )
96+ # test triton
97+ cmd = [
98+ sys .executable ,
99+ "-c" ,
100+ "import fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm" ,
101+ ]
102+ subprocess .check_call (cmd )
103+ # test genai (cutlass or ck)
81104 cmd = [sys .executable , "-c" , "import fbgemm_gpu.experimental.gen_ai" ]
82105 subprocess .check_call (cmd )
83106 print ("OK" )
@@ -118,6 +141,8 @@ def setup_hip(args: argparse.Namespace):
118141 # We have to disable all third-parties that donot support hip/rocm
119142 args .all = False
120143 args .liger = True
144+ args .aiter = True
145+ args .fbgemm = True
121146
122147
123148if __name__ == "__main__" :
0 commit comments