3
3
from torch import Tensor
4
4
from .utils import alphabet , create_root_partition , compute_distribution_info
5
5
6
+ import copy
6
7
import distdl
7
8
import distdl .nn as dnn
8
9
import numpy as np
@@ -24,7 +25,7 @@ def __init__(self, P_x, in_features, out_features, dim=-1, bias=True, device=tor
24
25
self .out_features = out_features
25
26
self .scale = 1 / np .sqrt (in_features * out_features )
26
27
self .bias = bias
27
-
28
+
28
29
self .b_shape = [1 ]* P_x .dim
29
30
self .b_shape [dim ] = out_features
30
31
@@ -51,7 +52,7 @@ def __init__(self, P_x, in_features, out_features, dim=-1, bias=True, device=tor
51
52
52
53
def forward (self , x : Tensor ) -> Tensor :
53
54
self .dt_comm = 0
54
-
55
+
55
56
t0 = time .time ()
56
57
W = self .W_bcast (self .W )
57
58
b = self .b_bcast (self .b )
@@ -100,6 +101,15 @@ def __init__(self, P_x, in_shape, modes, device=torch.device('cpu'), dtype=torch
100
101
self .R3 = dnn .Repartition (self .P_y , self .P_m )
101
102
self .R4 = dnn .Repartition (self .P_m , self .P_x )
102
103
104
+ # Setup FFT restrictions
105
+ self .restrict_prefixes = {}
106
+ self .restrict_suffixes = {}
107
+ for dim in [* self .dim_m , * self .dim_y ]:
108
+ mode = modes [dim - 2 ]
109
+ self .restrict_prefixes [dim ] = mode
110
+ if dim != self .dim_m [- 1 ]:
111
+ self .restrict_suffixes [dim ] = mode
112
+
103
113
# Setup weights
104
114
self .scale = 1 / (self .width * self .width )
105
115
@@ -117,11 +127,15 @@ def make_slice(bounds):
117
127
# dimensions and low modes in the time dimension
118
128
self .weights = nn .ParameterList ([])
119
129
self .slices = []
120
-
130
+
121
131
fft_shape = [* in_shape [:- 1 ], in_shape [- 1 ]// 2 ]
132
+ for dim , restriction in self .restrict_prefixes .items ():
133
+ fft_shape [dim ] = restriction
134
+ for dim , restriction in self .restrict_suffixes .items ():
135
+ fft_shape [dim ] += restriction
122
136
info = compute_distribution_info (self .P_y , fft_shape )
123
137
for i in range (2 ** (self .n - 1 )):
124
-
138
+
125
139
s = bin (i )[2 :].zfill (self .n )
126
140
bounds = []
127
141
@@ -135,7 +149,7 @@ def make_slice(bounds):
135
149
bounds .append ((max (0 , start )- start , min (mode , stop )- start ))
136
150
else :
137
151
bounds .append ((max (dim_size - mode , start )- start , min (dim_size , stop )- start ))
138
-
152
+
139
153
bounds = list (reversed (bounds ))
140
154
valid = True
141
155
for a , b in bounds :
@@ -161,34 +175,114 @@ def make_slice(bounds):
161
175
162
176
self .dt_comm = 0
163
177
178
def restrict(self, x: Tensor, dim: int) -> Tensor:
    """Discard unused higher-frequency elements along ``dim``.

    Keeps the leading ``self.restrict_prefixes[dim]`` entries and/or the
    trailing ``self.restrict_suffixes[dim]`` entries of ``x`` along ``dim``
    and drops everything in between.

    Args:
        x: Tensor to restrict.
        dim: Dimension of ``x`` to restrict. It is looked up as a key in
            ``self.restrict_prefixes`` and ``self.restrict_suffixes``.

    Returns:
        ``x`` itself when no restriction is registered for ``dim``; the
        single kept slice when only a prefix or only a suffix is
        registered; otherwise the prefix block followed by the suffix
        block, concatenated along ``dim``.
    """
    has_prefix = dim in self.restrict_prefixes
    has_suffix = dim in self.restrict_suffixes
    if not (has_prefix or has_suffix):
        # Nothing to restrict: hand the input back unchanged.
        return x

    # Start from a "select everything" index and narrow only `dim`.
    index = [slice(None)] * len(x.shape)
    pieces = []

    if has_prefix:
        # Leading block of retained modes.
        index[dim] = slice(None, self.restrict_prefixes[dim])
        # Index with a tuple: a list of slices is ambiguous with
        # advanced (fancy) indexing.
        pieces.append(x[tuple(index)])

    if has_suffix:
        # Trailing block of retained modes.
        index[dim] = slice(-self.restrict_suffixes[dim], None)
        pieces.append(x[tuple(index)])

    if len(pieces) == 1:
        # Only one block is kept, so no concatenation is needed.
        return pieces[0]

    # Prefix first, then suffix, so the layout matches what `zeropad`
    # expects when it re-inserts the discarded middle.
    return torch.cat(pieces, dim=dim)
204
+
205
def zeropad(self, y: Tensor, dim: int, target_shape: list) -> Tensor:
    """Fill in zeroes for higher-frequency elements along ``dim``.

    Re-inserts a block of zeros between the prefix and suffix blocks of
    ``y`` so that its extent along ``dim`` grows to ``target_shape[dim]``.

    Args:
        y: Tensor to pad.
        dim: Dimension of ``y`` to pad. It is looked up as a key in
            ``self.restrict_prefixes`` and ``self.restrict_suffixes``.
        target_shape: Full shape to pad up to. Any sequence of ints is
            accepted (``list``, ``tuple``, ``torch.Size``); it is never
            modified. Entries other than ``target_shape[dim]`` give the
            zero block's extents and must match ``y`` for the
            concatenation to succeed.

    Returns:
        ``y`` itself when no restriction is registered for ``dim`` or when
        the zero block would be empty; otherwise a new tensor made of the
        prefix block (if any), the zero block, and the suffix block (if
        any), concatenated along ``dim``.
    """
    has_prefix = dim in self.restrict_prefixes
    has_suffix = dim in self.restrict_suffixes
    if not (has_prefix or has_suffix):
        # Nothing was restricted, so there is nothing to zero-pad.
        return y

    # Shape of the zero block: the target shape with `dim` reduced by what
    # `y` already provides. `list(...)` makes a private, mutable copy so
    # tuples and torch.Size work and the caller's shape is left untouched.
    pad_shape = list(target_shape)
    pad_shape[dim] -= y.shape[dim]
    if any(extent < 1 for extent in pad_shape):
        # The zero block would be empty.
        return y

    # Start from a "select everything" index and narrow only `dim`.
    index = [slice(None)] * len(y.shape)
    pieces = []

    if has_prefix:
        # Leading block of retained modes.
        index[dim] = slice(None, self.restrict_prefixes[dim])
        # Index with a tuple: a list of slices is ambiguous with
        # advanced (fancy) indexing.
        pieces.append(y[tuple(index)])

    # Zeros standing in for the discarded modes; matched to `y` so the
    # concatenation needs no dtype or device conversion.
    pieces.append(torch.zeros(pad_shape, dtype=y.dtype, layout=y.layout, device=y.device))

    if has_suffix:
        # Trailing block of retained modes.
        index[dim] = slice(-self.restrict_suffixes[dim], None)
        pieces.append(y[tuple(index)])

    return torch.cat(pieces, dim=dim)
240
+
164
241
def forward (self , x : Tensor ) -> Tensor :
165
242
self .dt_comm = 0
166
243
167
244
y0 = self .linear (x )
168
-
245
+
169
246
t0 = time .time ()
170
247
x = self .R1 (x )
171
248
self .dt_comm += (time .time ()- t0 )
172
249
173
- x = torch .fft .rfftn (x , dim = tuple (self .dim_m ))
250
+ saved_shapes = {}
251
+ outermost_dim = self .dim_m [- 1 ]
252
+ x = torch .fft .rfft (x , dim = outermost_dim )
253
+ saved_shapes [outermost_dim ] = list (x .shape )
254
+ x = self .restrict (x , outermost_dim )
255
+ for dim in reversed (self .dim_m [:- 1 ]):
256
+ x = torch .fft .fft (x , dim = dim )
257
+ saved_shapes [dim ] = list (x .shape )
258
+ x = self .restrict (x , dim )
174
259
175
260
t0 = time .time ()
176
261
x = self .R2 (x )
177
262
self .dt_comm += (time .time ()- t0 )
178
263
179
- x = torch .fft .fftn (x , dim = tuple (self .dim_y ))
264
+ for dim in reversed (self .dim_y ):
265
+ x = torch .fft .fft (x , dim = dim )
266
+ saved_shapes [dim ] = list (x .shape )
267
+ x = self .restrict (x , dim )
180
268
181
269
y = 0 * x .clone ()
182
270
for w , sl in zip (self .weights , self .slices ):
183
271
y [sl ] = torch .einsum (self .eqn , x [sl ], w )
184
272
185
- y = torch .fft .ifftn (y , dim = tuple (self .dim_y ))
273
+ for dim in self .dim_y :
274
+ y = self .zeropad (y , dim , saved_shapes [dim ])
275
+ y = torch .fft .ifft (y , dim = dim )
186
276
187
277
t0 = time .time ()
188
278
y = self .R3 (y )
189
279
self .dt_comm += (time .time ()- t0 )
190
280
191
- y = torch .fft .irfftn (y , dim = tuple (self .dim_m ))
281
+ for dim in self .dim_m [:- 1 ]:
282
+ y = self .zeropad (y , dim , saved_shapes [dim ])
283
+ y = torch .fft .ifft (y , dim = dim )
284
+ y = self .zeropad (y , outermost_dim , saved_shapes [outermost_dim ])
285
+ y = torch .fft .irfft (y , dim = outermost_dim )
192
286
193
287
t0 = time .time ()
194
288
y = self .R4 (y )
@@ -222,7 +316,7 @@ def __init__(self, P_x, in_shape, out_timesteps, width, modes, num_blocks=4, dev
222
316
DistributedFNOBlock (
223
317
self .P_x ,
224
318
self .block_in_shape ,
225
- self .modes ,
319
+ self .modes ,
226
320
device = device ,
227
321
dtype = dtype
228
322
) for _ in range (num_blocks )
@@ -261,7 +355,7 @@ def forward(self, x: Tensor) -> Tensor:
261
355
if __name__ == '__main__' :
262
356
263
357
from utils import get_env , create_standard_partitions
264
-
358
+
265
359
P_shape = (1 , 1 , 2 , 2 , 1 , 1 )
266
360
P_world , P_x , P_root = create_standard_partitions (P_shape )
267
361
num_gpus = 1
@@ -279,7 +373,7 @@ def forward(self, x: Tensor) -> Tensor:
279
373
network = DistributedFNO (P_x , in_shape , nt , width , modes , num_blocks = 4 , device = x .device , dtype = x .dtype )
280
374
criterion = dnn .DistributedMSELoss (P_x )
281
375
y = network (x )
282
-
376
+
283
377
for i in range (10 ):
284
378
t0 = time .time ()
285
379
y = network (x )
@@ -288,7 +382,7 @@ def forward(self, x: Tensor) -> Tensor:
288
382
289
383
loss = criterion (y , torch .rand_like (y ))
290
384
P_x ._comm .Barrier ()
291
-
385
+
292
386
t0 = time .time ()
293
387
loss .backward ()
294
388
t1 = time .time ()
0 commit comments