diff --git a/pypots/nn/modules/reformer/local_attention.py b/pypots/nn/modules/reformer/local_attention.py
index a617b9ba..3bcfd981 100644
--- a/pypots/nn/modules/reformer/local_attention.py
+++ b/pypots/nn/modules/reformer/local_attention.py
@@ -13,11 +13,24 @@
 from einops import rearrange
 from einops import repeat, pack, unpack
 from torch import nn, einsum
-from torch.cuda.amp import autocast
 
 TOKEN_SELF_ATTN_VALUE = -5e4
 
 
+# shim autocast so the module is compatible with both torch >= 2.4 and < 2.4
+def autocast(**kwargs):
+    # compare numeric version parts: a plain string comparison misorders "2.10" vs "2.4"
+    if tuple(map(int, torch.__version__.split(".")[:2])) >= (2, 4):
+        # torch.cuda.amp.autocast is deprecated since torch 2.4
+        from torch.amp import autocast
+
+        return autocast("cuda", **kwargs)
+    else:
+        from torch.cuda.amp import autocast
+
+        return autocast(**kwargs)
+
+
 def exists(val):
     return val is not None
 
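
For reviewers, a minimal sketch of how the shim can be exercised (`attend` is a hypothetical toy function, not code from this repo; the decorator form matches how torch's `autocast` objects are documented to work, and the tensor shapes are illustrative):

```python
import torch

# import path as defined by this patch
from pypots.nn.modules.reformer.local_attention import autocast


# torch.amp.autocast and torch.cuda.amp.autocast instances both work as
# decorators, so the shim is a drop-in replacement on either torch version
@autocast(enabled=False)
def attend(q, k):
    # toy similarity computation, run with mixed precision disabled
    return torch.einsum("b i d, b j d -> b i j", q, k)


q = torch.randn(1, 4, 8)
k = torch.randn(1, 4, 8)
print(attend(q, k).shape)  # torch.Size([1, 4, 4])
```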