
Commit d46842f

jd7-tr authored and facebook-github-bot committed
fix test (#3361)
Summary:

Rollback Plan:

Differential Revision: D81996413
1 parent 6e93e77 commit d46842f

File tree: 1 file changed (+40, -17 lines)

torchrec/distributed/tests/test_init_parameters.py

Lines changed: 40 additions & 17 deletions
--- a/torchrec/distributed/tests/test_init_parameters.py
+++ b/torchrec/distributed/tests/test_init_parameters.py
@@ -60,6 +60,25 @@ def initialize_and_test_parameters(
     local_size: Optional[int] = None,
 ) -> None:
     with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
+        # Set seed again in each process to ensure consistency
+        torch.manual_seed(42)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed(42)
+
+        key = (
+            f"embeddings.{table_name}.weight"
+            if isinstance(embedding_tables, EmbeddingCollection)
+            else f"embedding_bags.{table_name}.weight"
+        )
+
+        # Create the same fixed tensor in each process
+        fixed_tensor = torch.randn(10, 64, generator=torch.Generator().manual_seed(42))
+
+        # Load the fixed tensor into the embedding_tables to ensure consistency
+        embedding_tables.load_state_dict({key: fixed_tensor})
+
+        # Store the original tensor on CPU for comparison BEFORE creating the model
+        original_tensor = embedding_tables.state_dict()[key].clone().cpu()
 
         module_sharding_plan = construct_module_sharding_plan(
             embedding_tables,
@@ -79,12 +98,8 @@ def initialize_and_test_parameters(
             env=ShardingEnv.from_process_group(ctx.pg),
             sharders=sharders,
             device=ctx.device,
-        )
-
-        key = (
-            f"embeddings.{table_name}.weight"
-            if isinstance(embedding_tables, EmbeddingCollection)
-            else f"embedding_bags.{table_name}.weight"
+            init_data_parallel=False,
+            init_parameters=False,
         )
 
         if isinstance(model.state_dict()[key], DTensor):
@@ -96,28 +111,26 @@ def initialize_and_test_parameters(
             gathered_tensor = model.state_dict()[key].full_tensor()
             if ctx.rank == 0:
                 torch.testing.assert_close(
-                    gathered_tensor,
-                    embedding_tables.state_dict()[key],
+                    gathered_tensor.cpu(), original_tensor, rtol=1e-5, atol=1e-6
                 )
         elif isinstance(model.state_dict()[key], ShardedTensor):
             if ctx.rank == 0:
-                gathered_tensor = torch.empty_like(
-                    embedding_tables.state_dict()[key], device=ctx.device
-                )
+                gathered_tensor = torch.empty_like(original_tensor, device=ctx.device)
             else:
                 gathered_tensor = None
 
             model.state_dict()[key].gather(dst=0, out=gathered_tensor)
 
             if ctx.rank == 0:
                 torch.testing.assert_close(
-                    none_throws(gathered_tensor).to("cpu"),
-                    embedding_tables.state_dict()[key],
+                    none_throws(gathered_tensor).cpu(),
+                    original_tensor,
+                    rtol=1e-5,
+                    atol=1e-6,
                 )
         elif isinstance(model.state_dict()[key], torch.Tensor):
             torch.testing.assert_close(
-                embedding_tables.state_dict()[key].cpu(),
-                model.state_dict()[key].cpu(),
+                model.state_dict()[key].cpu(), original_tensor, rtol=1e-5, atol=1e-6
             )
         else:
             raise AssertionError(
@@ -161,6 +174,9 @@ def test_initialize_parameters_ec(self, sharding_type: str) -> None:
         backend = "nccl"
         table_name = "free_parameters"
 
+        # Set seed for deterministic tensor generation
+        torch.manual_seed(42)
+
         # Initialize embedding table on non-meta device, in this case cuda:0
         embedding_tables = EmbeddingCollection(
             tables=[
@@ -173,8 +189,10 @@ def test_initialize_parameters_ec(self, sharding_type: str) -> None:
             ],
         )
 
+        # Use a fixed tensor with explicit seeding for consistent testing
+        fixed_tensor = torch.randn(10, 64, generator=torch.Generator().manual_seed(42))
         embedding_tables.load_state_dict(
-            {f"embeddings.{table_name}.weight": torch.randn(10, 64)}
+            {f"embeddings.{table_name}.weight": fixed_tensor}
         )
 
         self._run_multi_process_test(
@@ -210,6 +228,9 @@ def test_initialize_parameters_ebc(self, sharding_type: str) -> None:
         backend = "nccl"
         table_name = "free_parameters"
 
+        # Set seed for deterministic tensor generation
+        torch.manual_seed(42)
+
        # Initialize embedding bag on non-meta device, in this case cuda:0
         embedding_tables = EmbeddingBagCollection(
             tables=[
@@ -222,8 +243,10 @@ def test_initialize_parameters_ebc(self, sharding_type: str) -> None:
             ],
         )
 
+        # Use a fixed tensor with explicit seeding for consistent testing
+        fixed_tensor = torch.randn(10, 64, generator=torch.Generator().manual_seed(42))
         embedding_tables.load_state_dict(
-            {f"embedding_bags.{table_name}.weight": torch.randn(10, 64)}
+            {f"embedding_bags.{table_name}.weight": fixed_tensor}
         )
 
         self._run_multi_process_test(
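
Note (not part of the commit): the fix relies on the fact that a torch.Generator seeded with the same value produces bit-identical tensors in every process, so each rank can rebuild the expected embedding weights locally and the gathered sharded weights can be compared against a copy stored before sharding. A minimal standalone sketch of that determinism property follows; the run helper, world_size = 2, and the use of torch.multiprocessing.spawn are illustrative assumptions, not code from the test file.

# Hypothetical sketch (not from the commit): identically seeded generators
# yield identical tensors across processes, which is what lets the test load
# a known weight on every rank and later compare against it on rank 0.
import torch
import torch.multiprocessing as mp


def run(rank: int, world_size: int) -> None:
    # Each process seeds its own generator the same way ...
    fixed = torch.randn(10, 64, generator=torch.Generator().manual_seed(42))

    # ... so the tensor it builds matches the one every other rank built.
    reference = torch.randn(10, 64, generator=torch.Generator().manual_seed(42))
    torch.testing.assert_close(fixed, reference, rtol=1e-5, atol=1e-6)
    print(f"rank {rank}/{world_size}: deterministic tensor matches reference")


if __name__ == "__main__":
    world_size = 2  # assumed for illustration; the real test uses its own process group
    mp.spawn(run, args=(world_size,), nprocs=world_size)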
