diff --git a/models/etsformer/exponential_smoothing.py b/models/etsformer/exponential_smoothing.py index 96c167d..d48c108 100644 --- a/models/etsformer/exponential_smoothing.py +++ b/models/etsformer/exponential_smoothing.py @@ -20,7 +20,7 @@ def conv1d_fft(f, g, dim=-1): F_fg = F_f * F_g.conj() out = fft.irfft(F_fg, fast_len, dim=dim) out = out.roll((-1,), dims=(dim,)) - idx = torch.as_tensor(range(fast_len - N, fast_len)).to(out.device) + idx = torch.arange(fast_len - N, fast_len, device=out.device) out = out.index_select(dim, idx) return out