Commit 02f44d9

jd7-tr authored and facebook-github-bot committed
fix test
Differential Revision: D81996413
1 parent 6e93e77 commit 02f44d9

File tree

1 file changed: +22 -11 lines changed

torchrec/distributed/tests/test_init_parameters.py

Lines changed: 22 additions & 11 deletions
@@ -87,6 +87,9 @@ def initialize_and_test_parameters(
         else f"embedding_bags.{table_name}.weight"
     )
 
+    # Store the original tensor on CPU for comparison
+    original_tensor = embedding_tables.state_dict()[key].clone().cpu()
+
     if isinstance(model.state_dict()[key], DTensor):
         if ctx.rank == 0:
             gathered_tensor = torch.empty(model.state_dict()[key].size())
@@ -96,28 +99,26 @@
         gathered_tensor = model.state_dict()[key].full_tensor()
         if ctx.rank == 0:
             torch.testing.assert_close(
-                gathered_tensor,
-                embedding_tables.state_dict()[key],
+                gathered_tensor.cpu(), original_tensor, rtol=1e-5, atol=1e-6
             )
     elif isinstance(model.state_dict()[key], ShardedTensor):
         if ctx.rank == 0:
-            gathered_tensor = torch.empty_like(
-                embedding_tables.state_dict()[key], device=ctx.device
-            )
+            gathered_tensor = torch.empty_like(original_tensor, device=ctx.device)
         else:
             gathered_tensor = None
 
         model.state_dict()[key].gather(dst=0, out=gathered_tensor)
 
         if ctx.rank == 0:
             torch.testing.assert_close(
-                none_throws(gathered_tensor).to("cpu"),
-                embedding_tables.state_dict()[key],
+                none_throws(gathered_tensor).cpu(),
+                original_tensor,
+                rtol=1e-5,
+                atol=1e-6,
             )
     elif isinstance(model.state_dict()[key], torch.Tensor):
         torch.testing.assert_close(
-            embedding_tables.state_dict()[key].cpu(),
-            model.state_dict()[key].cpu(),
+            model.state_dict()[key].cpu(), original_tensor, rtol=1e-5, atol=1e-6
         )
     else:
         raise AssertionError(
@@ -161,6 +162,9 @@ def test_initialize_parameters_ec(self, sharding_type: str) -> None:
         backend = "nccl"
         table_name = "free_parameters"
 
+        # Set seed for deterministic tensor generation
+        torch.manual_seed(42)
+
         # Initialize embedding table on non-meta device, in this case cuda:0
         embedding_tables = EmbeddingCollection(
             tables=[
@@ -173,8 +177,10 @@ def test_initialize_parameters_ec(self, sharding_type: str) -> None:
             ],
         )
 
+        # Use a fixed tensor with explicit seeding for consistent testing
+        fixed_tensor = torch.randn(10, 64, generator=torch.Generator().manual_seed(42))
         embedding_tables.load_state_dict(
-            {f"embeddings.{table_name}.weight": torch.randn(10, 64)}
+            {f"embeddings.{table_name}.weight": fixed_tensor}
         )
 
         self._run_multi_process_test(
@@ -210,6 +216,9 @@ def test_initialize_parameters_ebc(self, sharding_type: str) -> None:
         backend = "nccl"
         table_name = "free_parameters"
 
+        # Set seed for deterministic tensor generation
+        torch.manual_seed(42)
+
         # Initialize embedding bag on non-meta device, in this case cuda:0
         embedding_tables = EmbeddingBagCollection(
             tables=[
@@ -222,8 +231,10 @@ def test_initialize_parameters_ebc(self, sharding_type: str) -> None:
             ],
        )
 
+        # Use a fixed tensor with explicit seeding for consistent testing
+        fixed_tensor = torch.randn(10, 64, generator=torch.Generator().manual_seed(42))
         embedding_tables.load_state_dict(
-            {f"embedding_bags.{table_name}.weight": torch.randn(10, 64)}
+            {f"embedding_bags.{table_name}.weight": fixed_tensor}
         )
 
         self._run_multi_process_test(
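
The fix combines three pieces: a deterministically seeded source tensor, a CPU copy of the original weights taken before sharding, and tolerance-based comparison of the gathered weights against that copy. Below is a minimal self-contained sketch of that comparison pattern outside the multi-process harness; it is not part of the commit, the shapes and tolerances simply mirror the test, and the device round trip is a hypothetical stand-in for the DTensor/ShardedTensor gather.

import torch

# Deterministic weights, mirroring the seeded torch.randn(10, 64) in the test.
fixed_tensor = torch.randn(10, 64, generator=torch.Generator().manual_seed(42))

# Keep a CPU copy before the weights are moved or sharded elsewhere,
# as the new `original_tensor = ...clone().cpu()` line in the diff does.
original_tensor = fixed_tensor.clone().cpu()

# Stand-in for a tensor gathered back from another device or shard
# (hypothetical round trip; the real test gathers across ranks).
device = "cuda" if torch.cuda.is_available() else "cpu"
gathered_tensor = fixed_tensor.to(device)

# Compare on CPU with explicit tolerances, as the updated assertions do.
torch.testing.assert_close(gathered_tensor.cpu(), original_tensor, rtol=1e-5, atol=1e-6)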
