@@ -371,6 +371,26 @@ def _emb_module_forward(
         lengths_or_offsets: torch.Tensor,
         weights: Optional[torch.Tensor],
     ) -> torch.Tensor:
+        # Check if total embedding dimension is 0 (can happen in column-wise sharding)
+        total_D = sum(table.local_cols for table in self._config.embedding_tables)
+
+        if total_D == 0:
+            # For empty shards, return tensor with correct batch size but 0 embedding dimension
+            # Use tensor operations that are FX symbolic tracing compatible
+            if self.lengths_to_tbe:
+                # For lengths format, batch size equals lengths tensor size
+                # Create [B, 0] tensor using zeros_like and slicing
+                dummy = torch.zeros_like(lengths_or_offsets, dtype=torch.float)
+                return dummy.unsqueeze(-1)[:, :0]  # [B, 0] tensor
+            else:
+                # For offsets format, batch size is one less than offset size
+                # Use tensor slicing to create batch dimension
+                batch_tensor = lengths_or_offsets[
+                    :-1
+                ]  # Remove last element to get batch size
+                dummy = torch.zeros_like(batch_tensor, dtype=torch.float)
+                return dummy.unsqueeze(-1)[:, :0]  # [B, 0] tensor
+
         kwargs = {"indices": indices}

         if self.lengths_to_tbe:
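
The `[B, 0]` construction above derives the batch dimension purely from tensor ops, never from a Python int. A minimal standalone sketch of the trick in eager mode (the `lengths` and `offsets` values here are hypothetical examples, not part of this diff):

```python
import torch

# Lengths format: one length per bag, so B == lengths.numel().
lengths = torch.tensor([2, 0, 3, 1])
empty = torch.zeros_like(lengths, dtype=torch.float).unsqueeze(-1)[:, :0]
print(empty.shape)  # torch.Size([4, 0])

# Offsets format: B + 1 boundary values, so drop the last element first.
offsets = torch.tensor([0, 2, 2, 5, 6])
empty = torch.zeros_like(offsets[:-1], dtype=torch.float).unsqueeze(-1)[:, :0]
print(empty.shape)  # torch.Size([4, 0])
```

Because `zeros_like` inherits the input's size and device, the resulting `[B, 0]` tensor stays tied to the actual batch rather than a hard-coded dimension.
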
@@ -600,6 +620,18 @@ def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
         else:
             values, offsets, _ = _unwrap_kjt(features)

+        # Check if total embedding dimension is 0
+        total_D = sum(table.local_cols for table in self._config.embedding_tables)
+
+        if total_D == 0:
+            # For empty shards, return tensor with correct batch size but 0 embedding dimension
+            # Use tensor operations that are FX symbolic tracing compatible
+            # For offsets format, batch size is one less than offset size
+            # Use tensor slicing to create batch dimension
+            batch_tensor = offsets[:-1]  # Remove last element to get batch size
+            dummy = torch.zeros_like(batch_tensor, dtype=torch.float)
+            return dummy.unsqueeze(-1)[:, :0]  # [B, 0] tensor
+
         if self._emb_module_registered:
             return self.emb_module(
                 indices=values,
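
To see why the comments stress FX compatibility, here is a hedged sketch; the module `EmptyShardStub` is illustrative, not from this PR. Since the batch size is obtained by slicing the offsets tensor, `torch.fx.symbolic_trace` records only symbolic tensor ops, and the traced graph generalizes across batch sizes:

```python
import torch
import torch.fx


class EmptyShardStub(torch.nn.Module):
    """Illustrative stand-in: returns a [B, 0] tensor from a (B + 1)-element offsets tensor."""

    def forward(self, offsets: torch.Tensor) -> torch.Tensor:
        # Same pattern as the diff: derive B by slicing, never as a Python int.
        dummy = torch.zeros_like(offsets[:-1], dtype=torch.float)
        return dummy.unsqueeze(-1)[:, :0]


gm = torch.fx.symbolic_trace(EmptyShardStub())
print(gm(torch.tensor([0, 1, 3])).shape)     # torch.Size([2, 0])
print(gm(torch.tensor([0, 2, 2, 5])).shape)  # torch.Size([3, 0])
```

The traced graph contains only slicing, `zeros_like`, and `unsqueeze` nodes, so one `GraphModule` serves any batch size without re-tracing.
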