Commit 31cf49f

jd7-tr authored and facebook-github-bot committed
Fix flaky Github test (#3315)
Summary:
Pull Request resolved: #3315

`_maybe_compute_stride_kjt` returns `stride = 0` when `len(keys) == 0` or when `stride_per_key_per_rank` has invalid values.

This change adds a guard condition, `if stride > 0:`, before the problematic `lengths.view(-1, stride)` operation in the `dist_init` method. When `stride` is 0 (e.g. indicating empty keys), the permutation is skipped.

Reviewed By: nipung90

Differential Revision: D80948346

fbshipit-source-id: c881ad4571dbc295ab7517bc9e351872bba3970e
1 parent a29e47a commit 31cf49f

1 file changed: +13 -12 lines changed

torchrec/sparse/jagged_tensor.py

Lines changed: 13 additions & 12 deletions
@@ -2992,18 +2992,19 @@ def dist_init(
                             values.numel(),
                         )
                     elif single_batch_per_rank:
-                        (
-                            lengths,
-                            values,
-                            weights,
-                        ) = torch.ops.fbgemm.permute_2D_sparse_data(
-                            torch.jit._unwrap_optional(recat),
-                            lengths.view(-1, stride),
-                            values,
-                            weights,
-                            values.numel(),
-                        )
-                        lengths = lengths.view(-1)
+                        if stride > 0:
+                            (
+                                lengths,
+                                values,
+                                weights,
+                            ) = torch.ops.fbgemm.permute_2D_sparse_data(
+                                torch.jit._unwrap_optional(recat),
+                                lengths.view(-1, stride),
+                                values,
+                                weights,
+                                values.numel(),
+                            )
+                            lengths = lengths.view(-1)
                     else:  # variable batch size per rank
                         (
                             lengths,
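
The guard in this hunk can be illustrated with a small pure-Python sketch. It is not the TorchRec implementation: `view_rows` and `permute_lengths` are hypothetical stand-ins for `lengths.view(-1, stride)` and the row permutation of `lengths`, and plain lists are assumed in place of tensors. It only shows why a row width of 0 is rejected and how `if stride > 0:` skips the permutation in that case.

```python
from typing import List, Sequence


def view_rows(lengths: Sequence[int], stride: int) -> List[List[int]]:
    """Hypothetical stand-in for `lengths.view(-1, stride)` on a flat list.

    A row width of 0 leaves the row count undefined, so the reshape is
    rejected, which is the failure mode the commit guards against.
    """
    if stride <= 0 or len(lengths) % stride != 0:
        raise ValueError(
            f"cannot view {len(lengths)} elements as (-1, {stride})"
        )
    return [list(lengths[i:i + stride]) for i in range(0, len(lengths), stride)]


def permute_lengths(
    lengths: Sequence[int], recat: Sequence[int], stride: int
) -> List[int]:
    """Reorder the rows of `lengths` by `recat`; skip the work when stride is 0."""
    if stride > 0:  # same guard condition as in the commit
        rows = view_rows(lengths, stride)
        # Output row i is input row recat[i]; flatten again like `.view(-1)`.
        return [x for r in recat for x in rows[r]]
    # stride == 0 (e.g. empty keys): nothing to permute, return lengths as-is.
    return list(lengths)
```

With two rows of width 2, `permute_lengths([1, 2, 3, 4], [1, 0], 2)` swaps the rows, while `permute_lengths([], [], 0)` returns an empty list instead of raising.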
