5 changes: 4 additions & 1 deletion torchrec/distributed/embedding_kernel.py
@@ -105,7 +105,10 @@ def create_virtual_table_global_metadata(
             # The param size only has the information for my_rank. In order to
             # correctly calculate the size for other ranks, we need to use the current
             # rank's shard size compared to the shard size of my_rank.
-            curr_rank_rows = (param.size()[0] * metadata.shards_metadata[rank].shard_sizes[0]) // my_rank_shard_size  # pyre-ignore[16]
+            curr_rank_rows = (
+                param.size()[0]  # pyre-ignore[16]
+                * metadata.shards_metadata[rank].shard_sizes[0]
+            ) // my_rank_shard_size
         else:
             curr_rank_rows = (
                 weight_count_per_rank[rank] if weight_count_per_rank is not None else 1
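As a side note on the expression reformatted above, here is a minimal standalone sketch (hypothetical numbers, not taken from the PR) of the row-scaling arithmetic: my_rank's local row count is scaled by the ratio of the target rank's shard size to my_rank's shard size.

# Hypothetical values illustrating the curr_rank_rows computation above.
param_rows_on_my_rank = 100   # param.size()[0] on my_rank
my_rank_shard_size = 100      # shard_sizes[0] of my_rank's shard
other_rank_shard_size = 50    # metadata.shards_metadata[rank].shard_sizes[0]

# Scale my_rank's rows by the ratio of the other rank's shard size to my_rank's.
curr_rank_rows = (param_rows_on_my_rank * other_rank_shard_size) // my_rank_shard_size
assert curr_rank_rows == 50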
2 changes: 1 addition & 1 deletion torchrec/distributed/tests/test_dynamic_sharding.py
@@ -206,7 +206,7 @@ def are_sharded_ebc_modules_identical(
         val2 = getattr(module2, attr)

         assert type(val1) is type(val2)
-        if type(val1) is torch.Tensor:
+        if isinstance(val1, torch.Tensor):
             torch.testing.assert_close(val1, val2)
         else:
             assert val1 == val2
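For context on the test change above, a small sketch (independent of the test file) of why isinstance is the safer check here: torch.nn.Parameter subclasses torch.Tensor, so an exact type() comparison would skip the tensor-comparison path for parameters, while isinstance still matches.

import torch

val = torch.nn.Parameter(torch.zeros(2, 2))

print(type(val) is torch.Tensor)      # False: exact-type check rejects the subclass
print(isinstance(val, torch.Tensor))  # True: isinstance accepts Tensor subclasses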