@@ -330,15 +330,21 @@ def amd_min_version(device=None, min_rdna_version=0):
330330
331331
332332SUPPORT_FP8_OPS = args .supports_fp8_compute
333+
334+ AMD_RDNA2_AND_OLDER_ARCH = ["gfx1030" , "gfx1031" , "gfx1010" , "gfx1011" , "gfx1012" , "gfx906" , "gfx900" , "gfx803" ]
335+
333336try :
334337 if is_amd ():
335- torch .backends .cudnn .enabled = False # Seems to improve things a lot on AMD
336- logging .info ("Set: torch.backends.cudnn.enabled = False for better AMD performance." )
338+ arch = torch .cuda .get_device_properties (get_torch_device ()).gcnArchName
339+ if not (any ((a in arch ) for a in AMD_RDNA2_AND_OLDER_ARCH )):
340+ torch .backends .cudnn .enabled = False # Seems to improve things a lot on AMD
341+ logging .info ("Set: torch.backends.cudnn.enabled = False for better AMD performance." )
342+
337343 try :
338344 rocm_version = tuple (map (int , str (torch .version .hip ).split ("." )[:2 ]))
339345 except :
340346 rocm_version = (6 , - 1 )
341- arch = torch . cuda . get_device_properties ( get_torch_device ()). gcnArchName
347+
342348 logging .info ("AMD arch: {}" .format (arch ))
343349 logging .info ("ROCm version: {}" .format (rocm_version ))
344350 if args .use_split_cross_attention == False and args .use_quad_cross_attention == False :
@@ -1331,7 +1337,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
13311337
13321338 if is_amd ():
13331339 arch = torch .cuda .get_device_properties (device ).gcnArchName
1334- if any ((a in arch ) for a in [ "gfx1030" , "gfx1031" , "gfx1010" , "gfx1011" , "gfx1012" , "gfx906" , "gfx900" , "gfx803" ] ): # RDNA2 and older don't support bf16
1340+ if any ((a in arch ) for a in AMD_RDNA2_AND_OLDER_ARCH ): # RDNA2 and older don't support bf16
13351341 if manual_cast :
13361342 return True
13371343 return False
0 commit comments