diff --git a/torchrec/distributed/embedding_kernel.py b/torchrec/distributed/embedding_kernel.py
index e444f59c8..09a2c0375 100644
--- a/torchrec/distributed/embedding_kernel.py
+++ b/torchrec/distributed/embedding_kernel.py
@@ -105,7 +105,10 @@ def create_virtual_table_global_metadata(
             # The param size only has the information for my_rank. In order to
             # correctly calculate the size for other ranks, we need to use the current
             # rank's shard size compared to the shard size of my_rank.
-            curr_rank_rows = (param.size()[0] * metadata.shards_metadata[rank].shard_sizes[0]) // my_rank_shard_size  # pyre-ignore[16]
+            curr_rank_rows = (
+                param.size()[0]  # pyre-ignore[16]
+                * metadata.shards_metadata[rank].shard_sizes[0]
+            ) // my_rank_shard_size
         else:
             curr_rank_rows = (
                 weight_count_per_rank[rank] if weight_count_per_rank is not None else 1
diff --git a/torchrec/distributed/tests/test_dynamic_sharding.py b/torchrec/distributed/tests/test_dynamic_sharding.py
index 9b06db0b6..855c8cc20 100644
--- a/torchrec/distributed/tests/test_dynamic_sharding.py
+++ b/torchrec/distributed/tests/test_dynamic_sharding.py
@@ -206,7 +206,7 @@ def are_sharded_ebc_modules_identical(
         val2 = getattr(module2, attr)
 
         assert type(val1) is type(val2)
-        if type(val1) is torch.Tensor:
+        if isinstance(val1, torch.Tensor):
             torch.testing.assert_close(val1, val2)
         else:
             assert val1 == val2
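
Note on the two hunks above. In the first hunk, the formula scales my_rank's actual row count to each rank in proportion to that rank's shard metadata; with illustrative numbers (not from the diff), param.size()[0] == 100, my_rank_shard_size == 40, and a remote shard of 20 rows gives curr_rank_rows == (100 * 20) // 40 == 50. In the second hunk, isinstance(val1, torch.Tensor) also matches Tensor subclasses such as torch.nn.Parameter, whereas the exact type() comparison sent subclass instances to the plain == branch. A minimal standalone sketch of that difference, with illustrative variable names:

    import torch

    # nn.Parameter subclasses torch.Tensor, so the exact-type check fails
    # for it while isinstance succeeds.
    param = torch.nn.Parameter(torch.randn(2, 3))

    print(type(param) is torch.Tensor)      # False: Parameter is a subclass
    print(isinstance(param, torch.Tensor))  # True: matches any Tensor subclass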