Skip to content

Commit b532554

Browse files
authored
Merge pull request #3 from Infinoid/restrict
Restrict FFTs in DistributedFNOBlock
2 parents 3a480c0 + 4f0f149 commit b532554

File tree

1 file changed

+108
-14
lines changed

1 file changed

+108
-14
lines changed

dfno/dfno.py

+108-14
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from torch import Tensor
44
from .utils import alphabet, create_root_partition, compute_distribution_info
55

6+
import copy
67
import distdl
78
import distdl.nn as dnn
89
import numpy as np
@@ -24,7 +25,7 @@ def __init__(self, P_x, in_features, out_features, dim=-1, bias=True, device=tor
2425
self.out_features = out_features
2526
self.scale = 1/np.sqrt(in_features*out_features)
2627
self.bias = bias
27-
28+
2829
self.b_shape = [1]*P_x.dim
2930
self.b_shape[dim] = out_features
3031

@@ -51,7 +52,7 @@ def __init__(self, P_x, in_features, out_features, dim=-1, bias=True, device=tor
5152

5253
def forward(self, x: Tensor) -> Tensor:
5354
self.dt_comm = 0
54-
55+
5556
t0 = time.time()
5657
W = self.W_bcast(self.W)
5758
b = self.b_bcast(self.b)
@@ -100,6 +101,15 @@ def __init__(self, P_x, in_shape, modes, device=torch.device('cpu'), dtype=torch
100101
self.R3 = dnn.Repartition(self.P_y, self.P_m)
101102
self.R4 = dnn.Repartition(self.P_m, self.P_x)
102103

104+
# Setup FFT restrictions
105+
self.restrict_prefixes = {}
106+
self.restrict_suffixes = {}
107+
for dim in [*self.dim_m, *self.dim_y]:
108+
mode = modes[dim-2]
109+
self.restrict_prefixes[dim] = mode
110+
if dim != self.dim_m[-1]:
111+
self.restrict_suffixes[dim] = mode
112+
103113
# Setup weights
104114
self.scale = 1/(self.width*self.width)
105115

@@ -117,11 +127,15 @@ def make_slice(bounds):
117127
# dimensions and low modes in the time dimension
118128
self.weights = nn.ParameterList([])
119129
self.slices = []
120-
130+
121131
fft_shape = [*in_shape[:-1], in_shape[-1]//2]
132+
for dim, restriction in self.restrict_prefixes.items():
133+
fft_shape[dim] = restriction
134+
for dim, restriction in self.restrict_suffixes.items():
135+
fft_shape[dim] += restriction
122136
info = compute_distribution_info(self.P_y, fft_shape)
123137
for i in range(2**(self.n-1)):
124-
138+
125139
s = bin(i)[2:].zfill(self.n)
126140
bounds = []
127141

@@ -135,7 +149,7 @@ def make_slice(bounds):
135149
bounds.append((max(0, start)-start, min(mode, stop)-start))
136150
else:
137151
bounds.append((max(dim_size-mode, start)-start, min(dim_size, stop)-start))
138-
152+
139153
bounds = list(reversed(bounds))
140154
valid = True
141155
for a, b in bounds:
@@ -161,34 +175,114 @@ def make_slice(bounds):
161175

162176
self.dt_comm = 0
163177

178+
def restrict(self, x: Tensor, dim: int) -> Tensor:
    '''Discard unused higher-frequency elements along one dimension.

    Keeps the first ``self.restrict_prefixes[dim]`` entries and/or the last
    ``self.restrict_suffixes[dim]`` entries of ``x`` along ``dim`` and drops
    everything in between.

    Args:
        x: Tensor to restrict.
        dim: Dimension to restrict. It is used both as the key into the
            restriction tables and as the axis of ``x`` that is sliced.

    Returns:
        ``x`` itself when no restriction is registered for ``dim``; the
        single kept slice when only a prefix or only a suffix is registered;
        otherwise the prefix and suffix concatenated along ``dim``.
    '''
    has_prefix = dim in self.restrict_prefixes
    has_suffix = dim in self.restrict_suffixes

    if not has_prefix and not has_suffix:
        # nothing to restrict: hand the input back untouched.
        return x

    def take(block: slice) -> Tensor:
        # Full-range slices everywhere except along `dim`. Index with a
        # tuple: indexing with a list of slices is deprecated.
        index = [slice(None, None, 1)] * len(x.shape)
        index[dim] = block
        return x[tuple(index)]

    pieces = []
    if has_prefix:
        # leading block of low frequencies
        pieces.append(take(slice(None, self.restrict_prefixes[dim], 1)))
    if has_suffix:
        # trailing block (the negative-frequency counterpart)
        pieces.append(take(slice(-self.restrict_suffixes[dim], None, 1)))

    if len(pieces) == 1:
        # only keeping a single piece; no concatenation needed
        return pieces[0]

    # multiple pieces, concatenate them
    return torch.cat(pieces, dim=dim)
204+
205+
def zeropad(self, y: Tensor, dim: int, target_shape: list) -> Tensor:
    '''Fill in zeroes for higher-frequency elements along one dimension.

    Rebuilds a tensor of shape ``target_shape`` from ``y`` by inserting a
    block of zeros along ``dim`` between the kept prefix and suffix blocks.

    Args:
        y: Tensor to pad.
        dim: Dimension to pad. It is used both as the key into the
            restriction tables and as the axis of ``y``.
        target_shape: Desired shape after padding. Any sequence of ints is
            accepted (list, tuple, ``torch.Size``); it is not modified.

    Returns:
        ``y`` itself when ``dim`` has no restriction registered or when the
        computed pad block would be empty; otherwise a new tensor made of
        the prefix (if any), zeros, and the suffix (if any) concatenated
        along ``dim``.
    '''
    has_prefix = dim in self.restrict_prefixes
    has_suffix = dim in self.restrict_suffixes

    if not has_prefix and not has_suffix:
        # nothing was restricted; nothing to zero-pad.
        return y

    # Shape of the zero block: the target shape, minus what `y` already
    # provides along `dim`. list() copies the caller's shape and also
    # accepts immutable sequences such as tuples and torch.Size.
    pad_shape = list(target_shape)
    pad_shape[dim] -= y.shape[dim]
    if any(extent < 1 for extent in pad_shape):
        # pad is empty
        return y

    def take(block: slice) -> Tensor:
        # Full-range slices everywhere except along `dim`. Index with a
        # tuple: indexing with a list of slices is deprecated.
        index = [slice(None, None, 1)] * len(y.shape)
        index[dim] = block
        return y[tuple(index)]

    # build an array of pieces: the prefix if any, then zeroes, then suffix if any
    pieces = []
    if has_prefix:
        pieces.append(take(slice(None, self.restrict_prefixes[dim], 1)))

    pieces.append(torch.zeros(pad_shape, dtype=y.dtype, layout=y.layout, device=y.device))

    if has_suffix:
        pieces.append(take(slice(-self.restrict_suffixes[dim], None, 1)))

    # assemble the pieces
    return torch.cat(pieces, dim=dim)
240+
164241
def forward(self, x: Tensor) -> Tensor:
165242
self.dt_comm = 0
166243

167244
y0 = self.linear(x)
168-
245+
169246
t0 = time.time()
170247
x = self.R1(x)
171248
self.dt_comm += (time.time()-t0)
172249

173-
x = torch.fft.rfftn(x, dim=tuple(self.dim_m))
250+
saved_shapes = {}
251+
outermost_dim = self.dim_m[-1]
252+
x = torch.fft.rfft(x, dim=outermost_dim)
253+
saved_shapes[outermost_dim] = list(x.shape)
254+
x = self.restrict(x, outermost_dim)
255+
for dim in reversed(self.dim_m[:-1]):
256+
x = torch.fft.fft(x, dim=dim)
257+
saved_shapes[dim] = list(x.shape)
258+
x = self.restrict(x, dim)
174259

175260
t0 = time.time()
176261
x = self.R2(x)
177262
self.dt_comm += (time.time()-t0)
178263

179-
x = torch.fft.fftn(x, dim=tuple(self.dim_y))
264+
for dim in reversed(self.dim_y):
265+
x = torch.fft.fft(x, dim=dim)
266+
saved_shapes[dim] = list(x.shape)
267+
x = self.restrict(x, dim)
180268

181269
y = 0*x.clone()
182270
for w, sl in zip(self.weights, self.slices):
183271
y[sl] = torch.einsum(self.eqn, x[sl], w)
184272

185-
y = torch.fft.ifftn(y, dim=tuple(self.dim_y))
273+
for dim in self.dim_y:
274+
y = self.zeropad(y, dim, saved_shapes[dim])
275+
y = torch.fft.ifft(y, dim=dim)
186276

187277
t0 = time.time()
188278
y = self.R3(y)
189279
self.dt_comm += (time.time()-t0)
190280

191-
y = torch.fft.irfftn(y, dim=tuple(self.dim_m))
281+
for dim in self.dim_m[:-1]:
282+
y = self.zeropad(y, dim, saved_shapes[dim])
283+
y = torch.fft.ifft(y, dim=dim)
284+
y = self.zeropad(y, outermost_dim, saved_shapes[outermost_dim])
285+
y = torch.fft.irfft(y, dim=outermost_dim)
192286

193287
t0 = time.time()
194288
y = self.R4(y)
@@ -222,7 +316,7 @@ def __init__(self, P_x, in_shape, out_timesteps, width, modes, num_blocks=4, dev
222316
DistributedFNOBlock(
223317
self.P_x,
224318
self.block_in_shape,
225-
self.modes,
319+
self.modes,
226320
device=device,
227321
dtype=dtype
228322
) for _ in range(num_blocks)
@@ -261,7 +355,7 @@ def forward(self, x: Tensor) -> Tensor:
261355
if __name__ == '__main__':
262356

263357
from utils import get_env, create_standard_partitions
264-
358+
265359
P_shape = (1, 1, 2, 2, 1, 1)
266360
P_world, P_x, P_root = create_standard_partitions(P_shape)
267361
num_gpus = 1
@@ -279,7 +373,7 @@ def forward(self, x: Tensor) -> Tensor:
279373
network = DistributedFNO(P_x, in_shape, nt, width, modes, num_blocks=4, device=x.device, dtype=x.dtype)
280374
criterion = dnn.DistributedMSELoss(P_x)
281375
y = network(x)
282-
376+
283377
for i in range(10):
284378
t0 = time.time()
285379
y = network(x)
@@ -288,7 +382,7 @@ def forward(self, x: Tensor) -> Tensor:
288382

289383
loss = criterion(y, torch.rand_like(y))
290384
P_x._comm.Barrier()
291-
385+
292386
t0 = time.time()
293387
loss.backward()
294388
t1 = time.time()

0 commit comments

Comments
 (0)