Skip to content

Commit d9ebb87

Browse files
jd7-trfacebook-github-bot
authored andcommitted
fix test (#3361)
Summary: Rollback Plan: Differential Revision: D81996413
1 parent 6e93e77 commit d9ebb87

File tree

1 file changed

+38
-17
lines changed

1 file changed

+38
-17
lines changed

torchrec/distributed/tests/test_init_parameters.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,25 @@ def initialize_and_test_parameters(
6060
local_size: Optional[int] = None,
6161
) -> None:
6262
with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
63+
# Set seed again in each process to ensure consistency
64+
torch.manual_seed(42)
65+
if torch.cuda.is_available():
66+
torch.cuda.manual_seed(42)
67+
68+
key = (
69+
f"embeddings.{table_name}.weight"
70+
if isinstance(embedding_tables, EmbeddingCollection)
71+
else f"embedding_bags.{table_name}.weight"
72+
)
73+
74+
# Create the same fixed tensor in each process
75+
fixed_tensor = torch.randn(10, 64, generator=torch.Generator().manual_seed(42))
76+
77+
# Load the fixed tensor into the embedding_tables to ensure consistency
78+
embedding_tables.load_state_dict({key: fixed_tensor})
79+
80+
# Store the original tensor on CPU for comparison BEFORE creating the model
81+
original_tensor = embedding_tables.state_dict()[key].clone().cpu()
6382

6483
module_sharding_plan = construct_module_sharding_plan(
6584
embedding_tables,
@@ -81,12 +100,6 @@ def initialize_and_test_parameters(
81100
device=ctx.device,
82101
)
83102

84-
key = (
85-
f"embeddings.{table_name}.weight"
86-
if isinstance(embedding_tables, EmbeddingCollection)
87-
else f"embedding_bags.{table_name}.weight"
88-
)
89-
90103
if isinstance(model.state_dict()[key], DTensor):
91104
if ctx.rank == 0:
92105
gathered_tensor = torch.empty(model.state_dict()[key].size())
@@ -96,28 +109,26 @@ def initialize_and_test_parameters(
96109
gathered_tensor = model.state_dict()[key].full_tensor()
97110
if ctx.rank == 0:
98111
torch.testing.assert_close(
99-
gathered_tensor,
100-
embedding_tables.state_dict()[key],
112+
gathered_tensor.cpu(), original_tensor, rtol=1e-5, atol=1e-6
101113
)
102114
elif isinstance(model.state_dict()[key], ShardedTensor):
103115
if ctx.rank == 0:
104-
gathered_tensor = torch.empty_like(
105-
embedding_tables.state_dict()[key], device=ctx.device
106-
)
116+
gathered_tensor = torch.empty_like(original_tensor, device=ctx.device)
107117
else:
108118
gathered_tensor = None
109119

110120
model.state_dict()[key].gather(dst=0, out=gathered_tensor)
111121

112122
if ctx.rank == 0:
113123
torch.testing.assert_close(
114-
none_throws(gathered_tensor).to("cpu"),
115-
embedding_tables.state_dict()[key],
124+
none_throws(gathered_tensor).cpu(),
125+
original_tensor,
126+
rtol=1e-5,
127+
atol=1e-6,
116128
)
117129
elif isinstance(model.state_dict()[key], torch.Tensor):
118130
torch.testing.assert_close(
119-
embedding_tables.state_dict()[key].cpu(),
120-
model.state_dict()[key].cpu(),
131+
model.state_dict()[key].cpu(), original_tensor, rtol=1e-5, atol=1e-6
121132
)
122133
else:
123134
raise AssertionError(
@@ -161,6 +172,9 @@ def test_initialize_parameters_ec(self, sharding_type: str) -> None:
161172
backend = "nccl"
162173
table_name = "free_parameters"
163174

175+
# Set seed for deterministic tensor generation
176+
torch.manual_seed(42)
177+
164178
# Initialize embedding table on non-meta device, in this case cuda:0
165179
embedding_tables = EmbeddingCollection(
166180
tables=[
@@ -173,8 +187,10 @@ def test_initialize_parameters_ec(self, sharding_type: str) -> None:
173187
],
174188
)
175189

190+
# Use a fixed tensor with explicit seeding for consistent testing
191+
fixed_tensor = torch.randn(10, 64, generator=torch.Generator().manual_seed(42))
176192
embedding_tables.load_state_dict(
177-
{f"embeddings.{table_name}.weight": torch.randn(10, 64)}
193+
{f"embeddings.{table_name}.weight": fixed_tensor}
178194
)
179195

180196
self._run_multi_process_test(
@@ -210,6 +226,9 @@ def test_initialize_parameters_ebc(self, sharding_type: str) -> None:
210226
backend = "nccl"
211227
table_name = "free_parameters"
212228

229+
# Set seed for deterministic tensor generation
230+
torch.manual_seed(42)
231+
213232
# Initialize embedding bag on non-meta device, in this case cuda:0
214233
embedding_tables = EmbeddingBagCollection(
215234
tables=[
@@ -222,8 +241,10 @@ def test_initialize_parameters_ebc(self, sharding_type: str) -> None:
222241
],
223242
)
224243

244+
# Use a fixed tensor with explicit seeding for consistent testing
245+
fixed_tensor = torch.randn(10, 64, generator=torch.Generator().manual_seed(42))
225246
embedding_tables.load_state_dict(
226-
{f"embedding_bags.{table_name}.weight": torch.randn(10, 64)}
247+
{f"embedding_bags.{table_name}.weight": fixed_tensor}
227248
)
228249

229250
self._run_multi_process_test(

0 commit comments

Comments
 (0)