
Commit 64ffc7e

[amd] Fix amd build and run the benchmark (#665)
1 parent d442b85 commit 64ffc7e

File tree

10 files changed: +54 additions, −31 deletions


benchmarks/nightly/autogen.yaml

Lines changed: 4 additions & 0 deletions
@@ -155,3 +155,7 @@ vector_exp_bwd:
 welford_fwd:
   args: --op welford --baseline eager_layer_norm --metrics latency,speedup --only
     test_no_welford,triton_welford,eager_layer_norm
+bf16_flex_attention_fwd:
+  args: --op flex_attention --metrics latency,tflops --only compiled
+bf16_flex_attention_bwd:
+  args: --op flex_attention --metrics latency,tflops --only compiled

benchmarks/nightly/manual.yaml

Lines changed: 4 additions & 0 deletions
@@ -22,3 +22,7 @@ extra_args:
 # flash_attention triton_tutorial_flash_v2 impl only supports causal in backward
 bf16_flash_attention_bwd:
   args: --op flash_attention --baseline flash_v3 --metrics latency,tflops,speedup --bwd --only triton_tutorial_flash_v2,flash_v3 --causal
+bf16_flex_attention_fwd:
+  args: --op flex_attention --metrics latency,tflops --only compiled
+bf16_flex_attention_bwd:
+  args: --op flex_attention --metrics latency,tflops --only compiled
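Both nightly configs gain the same two entries: a forward and a backward flex_attention benchmark, each restricted to the compiled implementation. As a rough sketch of how such an entry turns into a benchmark run (the run.py entry point and the parsed shape are assumptions for illustration, not part of this commit):

import shlex
import sys

entry_name = "bf16_flex_attention_fwd"
entry_args = "--op flex_attention --metrics latency,tflops --only compiled"

# shlex.split preserves the CLI semantics of the args string.
cmd = [sys.executable, "run.py", *shlex.split(entry_args)]
print(f"[{entry_name}]", " ".join(cmd))
# subprocess.check_call(cmd) would launch the benchmark for real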

benchmarks/tagging/run.py

Lines changed: 0 additions & 10 deletions
@@ -223,13 +223,6 @@ def trace_op(op):
     return op_with_tags


-UNSUPPORTED_OPS = [
-    "fp8_fused_quant_gemm_rowwise",
-    "fp32_to_mx4",
-    "flex_attention",
-    "mx4_to_fp32",
-]
-
 if __name__ == "__main__":
     parser = get_parser()
     args = parser.parse_args()
@@ -240,9 +233,6 @@ def trace_op(op):
     print(f"Running tagging test on ops: {ops}...")
     results = {}
     for op in ops:
-        # deadloop on flex_attention
-        if op in UNSUPPORTED_OPS:
-            continue
         results.update(trace_op(op))
     if not args.output:
         print(results)

benchmarks/tritonparse_sweep/run.py

Lines changed: 5 additions & 5 deletions
@@ -41,10 +41,10 @@ def setup_tritonbench_cwd():
 setup_tritonbench_cwd()

 import tritonparse
-from tritonparse.reproducer.orchestrator import reproduce as tritonparse_reproduce
-from tritonparse.reproducer.types import KernelImportMode
 from tritonbench.operators_collection import list_operators_by_collection
 from tritonbench.utils.run_utils import run_in_task, setup_output_dir
+from tritonparse.reproducer.orchestrator import reproduce as tritonparse_reproduce
+from tritonparse.reproducer.types import KernelImportMode

 NOT_WORKING_OPS = ["tritonparse_softmax_triton_softmax"]

@@ -93,12 +93,12 @@ def find_ndjson_files(log_dir):


 def find_reproducer_script(output: str):
-    output_line: list[str] = [ x for x in output.splitlines() if "repro_script" in x ]
+    output_line: list[str] = [x for x in output.splitlines() if "repro_script" in x]
     if len(output_line) == 0:
         return None
-    output_line = output_line[0][output_line[0].find("{"):].strip()
+    output_line = output_line[0][output_line[0].find("{") :].strip()
     output_dict = eval(output_line)
-    return output_dict['repro_script']
+    return output_dict["repro_script"]


 def run_repro_script(repro_script):
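find_reproducer_script scans a task's captured output for the line carrying a repro_script dict and returns the script path. A self-contained check of that behavior (the sample log text below is made up for illustration; the real lines come from tritonparse's reproducer output):

def find_reproducer_script(output: str):
    # Same logic as the formatted function in the diff above.
    lines = [x for x in output.splitlines() if "repro_script" in x]
    if len(lines) == 0:
        return None
    payload = lines[0][lines[0].find("{") :].strip()
    return eval(payload)["repro_script"]  # the sweep trusts its own log output here

sample = "INFO sweep done\n{'repro_script': '/tmp/repro_kernel.py', 'status': 'ok'}\n"
assert find_reproducer_script(sample) == "/tmp/repro_kernel.py"
assert find_reproducer_script("no reproducer emitted") is None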

install.py

Lines changed: 35 additions & 10 deletions
@@ -56,28 +56,51 @@ def install_jax(cuda_version=DEFAULT_CUDA_VERSION):
 def install_fbgemm(genai=True):
     cmd = ["pip", "install", "-r", "requirements.txt"]
     subprocess.check_call(cmd, cwd=str(FBGEMM_PATH.resolve()))
-    # Build target A100(8.0) or H100(9.0, 9.0a)
+    # Build target H100(9.0, 9.0a) and blackwell (10.0, 12.0)
+    extra_envs = os.environ.copy()
     if genai:
-        cmd = [
-            sys.executable,
-            "setup.py",
-            "install",
-            "--build-target=genai",
-            "-DTORCH_CUDA_ARCH_LIST=8.0;9.0;9.0a",
-        ]
+        if not is_hip():
+            cmd = [
+                sys.executable,
+                "setup.py",
+                "install",
+                "--build-target=genai",
+                "-DTORCH_CUDA_ARCH_LIST=9.0;9.0a;10.0;12.0",
+            ]
+        elif is_hip():
+            # build for MI300(gfx942) and MI350(gfx950)
+            current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
+            cmd = [
+                "bash",
+                "-c",
+                f". .github/scripts/setup_env.bash; test_fbgemm_gpu_build_and_install {current_conda_env} genai/rocm",
+            ]
+            extra_envs["BUILD_ROCM_VERSION"] = "7.0"
+            subprocess.check_call(
+                cmd, cwd=str(FBGEMM_PATH.parent.resolve()), env=extra_envs
+            )
+            return
     else:
         cmd = [
             sys.executable,
             "setup.py",
             "install",
             "--build-target=cuda",
-            "-DTORCH_CUDA_ARCH_LIST=8.0;9.0;9.0a",
+            "-DTORCH_CUDA_ARCH_LIST=9.0;9.0a;10.0;12.0",
         ]
-    subprocess.check_call(cmd, cwd=str(FBGEMM_PATH.resolve()))
+    subprocess.check_call(cmd, cwd=str(FBGEMM_PATH.resolve()), env=extra_envs)


 def test_fbgemm():
     print("Checking fbgemm_gpu installation...", end="")
+    # test triton
+    cmd = [
+        sys.executable,
+        "-c",
+        "import fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm",
+    ]
+    subprocess.check_call(cmd)
+    # test genai (cutlass or ck)
     cmd = [sys.executable, "-c", "import fbgemm_gpu.experimental.gen_ai"]
     subprocess.check_call(cmd)
     print("OK")
@@ -118,6 +141,8 @@ def setup_hip(args: argparse.Namespace):
     # We have to disable all third-parties that donot support hip/rocm
     args.all = False
     args.liger = True
+    args.aiter = True
+    args.fbgemm = True


 if __name__ == "__main__":
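The ROCm path builds fbgemm through its own setup script and needs BUILD_ROCM_VERSION visible to the child process, hence the copied-and-augmented environment passed via env=. A minimal, self-contained illustration of the pattern:

import os
import subprocess
import sys

# Copy the parent environment, then layer the build-time override on top.
extra_envs = os.environ.copy()
extra_envs["BUILD_ROCM_VERSION"] = "7.0"

# The child sees the override; the parent process environment is untouched.
subprocess.check_call(
    [sys.executable, "-c", "import os; print(os.environ['BUILD_ROCM_VERSION'])"],
    env=extra_envs,
)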

submodules/aiter

Submodule aiter updated 3104 files

tools/aiter/install.py

Lines changed: 1 addition & 1 deletion
@@ -16,5 +16,5 @@ def pip_install_requirements():

 def install_aiter():
     pip_install_requirements()
-    cmd = ["python", "setup.py", "develop"]
+    cmd = ["pip", "install", "-e", "."]
     subprocess.check_call(cmd, cwd=AITER_PATH)
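Replacing python setup.py develop with pip install -e . keeps the editable install but routes it through pip's build frontend, since invoking setup.py directly is deprecated in modern setuptools. A quick post-install sanity check (the top-level module name aiter is an assumption here):

import importlib.util

# An editable install should make the package resolvable without importing it.
print("aiter importable:", importlib.util.find_spec("aiter") is not None)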

tritonbench/operators/fp8_gemm_rowwise/operator.py

Lines changed: 2 additions & 2 deletions
@@ -100,7 +100,7 @@ def parse_args(args: List[str]) -> argparse.Namespace:
     HAS_CUTLASS_OR_CK = is_hip() or (
         is_cuda() and get_nvidia_gpu_model() != "NVIDIA B200"
     )
-except (ImportError, AttributeError, FileNotFoundError):
+except (ImportError, AttributeError, FileNotFoundError, OSError):
     HAS_CUTLASS_OR_CK = False

 try:
@@ -111,7 +111,7 @@ def parse_args(args: List[str]) -> argparse.Namespace:

     # TODO: remove these b200 hacks.
     HAS_CUBLAS = is_cuda() and get_nvidia_gpu_model() != "NVIDIA B200"
-except (ImportError, IOError, AttributeError, FileNotFoundError):
+except (ImportError, IOError, AttributeError, FileNotFoundError, OSError):
     HAS_CUBLAS = False
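OSError joins the except tuples because a broken native extension does not always fail with ImportError: when the dynamic loader cannot resolve a shared library (easy to hit on a misconfigured ROCm or CUDA host), the import surfaces an OSError instead. A standalone demonstration of that failure mode:

import ctypes

try:
    # Loading a shared library that does not exist raises OSError, the same
    # class seen when an extension's native dependencies are missing.
    ctypes.CDLL("libdefinitely_not_installed.so")
except OSError as e:
    print("caught OSError:", e)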
tritonbench/operators/fp8_gemm_rowwise_grouped/operator.py

Lines changed: 1 addition & 1 deletion
@@ -180,7 +180,7 @@ def parse_args(args: List[str]) -> argparse.Namespace:
     cutlass_or_ck_fp8_grouped_mm = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_stacked
     # Set HAS_CUTLASS_OR_CK to True if import succeeds
     HAS_CUTLASS_OR_CK = True
-except (ImportError, AttributeError):
+except (ImportError, AttributeError, OSError):
     # Set HAS_CUTLASS_OR_CK to False if import fails
     HAS_CUTLASS_OR_CK = False

tritonbench/utils/python_utils.py

Lines changed: 1 addition & 1 deletion
@@ -9,5 +9,5 @@ def try_import(cond_name: str):
     try:
         yield
         _caller_globals[cond_name] = True
-    except (ImportError, ModuleNotFoundError) as e:
+    except (ImportError, ModuleNotFoundError, OSError) as e:
         _caller_globals[cond_name] = False
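try_import is a context manager that wraps an optional import and records the outcome as a boolean in the caller's globals, so the OSError hardening above now protects every call site at once. A usage sketch (the flag name and the fancy_lib module are hypothetical):

from tritonbench.utils.python_utils import try_import

with try_import("HAS_FANCY_LIB"):
    import fancy_lib  # hypothetical optional dependency with native pieces

# try_import wrote HAS_FANCY_LIB into this module's globals: True on success,
# False on ImportError, ModuleNotFoundError, or now OSError.
if HAS_FANCY_LIB:
    print("fancy_lib ready")
else:
    print("fancy_lib unavailable; using fallback")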
