File tree: 1 file changed, +2 −5 lines changed
@@ -332,6 +332,7 @@ def amd_min_version(device=None, min_rdna_version=0):
 SUPPORT_FP8_OPS = args.supports_fp8_compute
 try:
     if is_amd():
+        torch.backends.cudnn.enabled = False  # Seems to improve things a lot on AMD
         try:
             rocm_version = tuple(map(int, str(torch.version.hip).split(".")[:2]))
         except:
@@ -925,11 +926,7 @@ def vae_dtype(device=None, allowed_dtypes=[]):
         if d == torch.float16 and should_use_fp16(device):
             return d

-        # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
-        # slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
-        # also a problem on RDNA4 except fp32 is also slow there.
-        # This is due to large bf16 convolutions being extremely slow.
-        if d == torch.bfloat16 and ((not is_amd()) or amd_min_version(device, min_rdna_version=4)) and should_use_bf16(device):
+        if d == torch.bfloat16 and should_use_bf16(device):
             return d

     return torch.float32
You can’t perform that action at this time.
0 commit comments