@@ -338,6 +338,7 @@ def forward(
        raise NotImplementedError(f"Invalid backend: {backend}")

    input_numel = x.numel()
+   ctx.input_numel = input_numel
    if input_numel == 0:
        backend = "bmm"

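Note: stashing the scalar on `ctx` in `forward` is the standard way to make a forward-pass value available in `backward` for a `torch.autograd.Function`. A minimal, self-contained sketch of the pattern (the `Scale` class and its names are illustrative, not from this PR):

```python
import torch

class Scale(torch.autograd.Function):
    """Toy Function showing the ctx-stashing pattern used in the hunk above."""

    @staticmethod
    def forward(ctx, x, factor):
        # Plain Python scalars can be stored as attributes on ctx;
        # tensors should go through ctx.save_for_backward instead.
        ctx.factor = factor
        return x * factor

    @staticmethod
    def backward(ctx, grad_output):
        # The stashed scalar is available here; the second input
        # (factor) gets no gradient, hence None.
        return grad_output * ctx.factor, None

x = torch.randn(3, requires_grad=True)
Scale.apply(x, 2.0).sum().backward()
assert torch.allclose(x.grad, torch.full((3,), 2.0))
```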
@@ -357,9 +358,11 @@ def forward(
    if input_numel == 0:
        # if inp is empty, reshape to make grad flow.
        # inp shape: (0, hdim)
-       weight = weight.view(x.shape[-1], -1)
+       output = torch.matmul(x, weight.view(x.shape[-1], -1))
+   else:
+       output = torch.matmul(x, weight)

-   output = torch.matmul(x, weight)
+   assert len(output.shape) == len(x.shape)

    assert len(output.shape) == len(x.shape)

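Why the reshape is confined to the empty-input branch: with a `(0, hdim)` input and a grouped weight, flattening the weight to `(hdim, -1)` still yields a valid empty output and keeps the autograd graph connected, without rebinding `weight` for the rest of the function. A hedged sketch of that behavior, assuming an illustrative grouped weight shape `(groups, hdim, out)` that is not taken from the PR:

```python
import torch

hdim, out, groups = 4, 5, 3
x = torch.empty(0, hdim, requires_grad=True)           # empty input batch
weight = torch.randn(groups, hdim, out, requires_grad=True)

# Mirroring the patched forward: reshape the grouped weight so the
# empty input still produces a 2-D output of matching rank.
output = torch.matmul(x, weight.view(x.shape[-1], -1))
assert output.shape == (0, groups * out)
assert len(output.shape) == len(x.shape)

output.sum().backward()                                 # numerically a no-op...
assert x.grad is not None and weight.grad is not None   # ...but grads still flow
```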
@@ -387,14 +390,20 @@ def backward(ctx, grad_output):
    if backend == "gmm":
        grad_input, grad_weight = gmm_backward_op(x, grad_output, batch_sizes, input_weight=weight)
    else:
-       grad_weight = torch.matmul(x.transpose(-1, -2), grad_output)
+       if ctx.input_numel == 0:
+           grad_weight = torch.zeros_like(weight)
+       else:
+           grad_weight = torch.matmul(x.transpose(-1, -2), grad_output)

    if ctx.needs_input_grad[0]:
        if backend == "gmm":
            if grad_input is None:
                grad_input, _ = gmm_backward_op(grad_output, weight, batch_sizes, is_grad_input=True)
        else:
-           grad_input = torch.matmul(grad_output, weight.transpose(-1, -2))
+           if ctx.input_numel == 0:
+               grad_input = torch.zeros_like(x)
+           else:
+               grad_input = torch.matmul(grad_output, weight.transpose(-1, -2))

    return grad_input, grad_weight, None, None, None, None, None

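The `zeros_like` guard matters because an empty batch must still return gradients whose shapes match the leaf tensors exactly; computing the matmul path against the flattened weight view would produce a 2-D `grad_weight` that no longer matches the original parameter. A small check of that contract, using the hypothetical helper name `bmm_backward` and the same illustrative shapes as above:

```python
import torch

def bmm_backward(input_numel, x, weight, grad_output):
    """Sketch of the patched bmm fallback; not the PR's exact helper."""
    if input_numel == 0:
        # An empty batch contributes nothing, but the returned grads
        # must still match the leaf tensors' shapes exactly.
        grad_weight = torch.zeros_like(weight)
        grad_input = torch.zeros_like(x)
    else:
        grad_weight = torch.matmul(x.transpose(-1, -2), grad_output)
        grad_input = torch.matmul(grad_output, weight.transpose(-1, -2))
    return grad_input, grad_weight

x = torch.empty(0, 4)
weight = torch.randn(3, 4, 5)
grad_output = torch.empty(0, 15)
gi, gw = bmm_backward(x.numel(), x, weight, grad_output)
assert gi.shape == x.shape and gw.shape == weight.shape
```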