Skip to content

Commit 9cdc649

Browse files
Only disable cudnn on newer AMD GPUs. (#10437)
1 parent: 560b1bd; commit: 9cdc649

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

comfy/model_management.py

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -330,15 +330,21 @@ def amd_min_version(device=None, min_rdna_version=0):
330330

331331

332332
SUPPORT_FP8_OPS = args.supports_fp8_compute
333+
334+
AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]
335+
333336
try:
334337
if is_amd():
335-
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
336-
logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
338+
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
339+
if not (any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH)):
340+
torch.backends.cudnn.enabled = False # Seems to improve things a lot on AMD
341+
logging.info("Set: torch.backends.cudnn.enabled = False for better AMD performance.")
342+
337343
try:
338344
rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
339345
except:
340346
rocm_version = (6, -1)
341-
arch = torch.cuda.get_device_properties(get_torch_device()).gcnArchName
347+
342348
logging.info("AMD arch: {}".format(arch))
343349
logging.info("ROCm version: {}".format(rocm_version))
344350
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
@@ -1331,7 +1337,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
13311337

13321338
if is_amd():
13331339
arch = torch.cuda.get_device_properties(device).gcnArchName
1334-
if any((a in arch) for a in ["gfx1030", "gfx1031", "gfx1010", "gfx1011", "gfx1012", "gfx906", "gfx900", "gfx803"]): # RDNA2 and older don't support bf16
1340+
if any((a in arch) for a in AMD_RDNA2_AND_OLDER_ARCH): # RDNA2 and older don't support bf16
13351341
if manual_cast:
13361342
return True
13371343
return False

0 commit comments

Comments
 (0)