We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 61bf3cf · commit 0ddad13 — Copy full SHA for 0ddad13
torchrec/distributed/embedding_kernel.py
@@ -105,7 +105,9 @@ def create_virtual_table_global_metadata(
105
# The param size only has the information for my_rank. In order to
106
# correctly calculate the size for other ranks, we need to use the current
107
# rank's shard size compared to the shard size of my_rank.
108
- curr_rank_rows = (param.size()[0] * metadata.shards_metadata[rank].shard_sizes[0]) // my_rank_shard_size # pyre-ignore[16]
+ curr_rank_rows = (
109
+ param.size()[0] * metadata.shards_metadata[rank].shard_sizes[0]
110
+ ) // my_rank_shard_size
111
else:
112
curr_rank_rows = (
113
weight_count_per_rank[rank] if weight_count_per_rank is not None else 1
0 commit comments