@@ -87,6 +87,9 @@ def initialize_and_test_parameters(
87
87
else f"embedding_bags.{ table_name } .weight"
88
88
)
89
89
90
+ # Store the original tensor on CPU for comparison
91
+ original_tensor = embedding_tables .state_dict ()[key ].clone ().cpu ()
92
+
90
93
if isinstance (model .state_dict ()[key ], DTensor ):
91
94
if ctx .rank == 0 :
92
95
gathered_tensor = torch .empty (model .state_dict ()[key ].size ())
@@ -96,28 +99,26 @@ def initialize_and_test_parameters(
96
99
gathered_tensor = model .state_dict ()[key ].full_tensor ()
97
100
if ctx .rank == 0 :
98
101
torch .testing .assert_close (
99
- gathered_tensor ,
100
- embedding_tables .state_dict ()[key ],
102
+ gathered_tensor .cpu (), original_tensor , rtol = 1e-5 , atol = 1e-6
101
103
)
102
104
elif isinstance (model .state_dict ()[key ], ShardedTensor ):
103
105
if ctx .rank == 0 :
104
- gathered_tensor = torch .empty_like (
105
- embedding_tables .state_dict ()[key ], device = ctx .device
106
- )
106
+ gathered_tensor = torch .empty_like (original_tensor , device = ctx .device )
107
107
else :
108
108
gathered_tensor = None
109
109
110
110
model .state_dict ()[key ].gather (dst = 0 , out = gathered_tensor )
111
111
112
112
if ctx .rank == 0 :
113
113
torch .testing .assert_close (
114
- none_throws (gathered_tensor ).to ("cpu" ),
115
- embedding_tables .state_dict ()[key ],
114
+ none_throws (gathered_tensor ).cpu (),
115
+ original_tensor ,
116
+ rtol = 1e-5 ,
117
+ atol = 1e-6 ,
116
118
)
117
119
elif isinstance (model .state_dict ()[key ], torch .Tensor ):
118
120
torch .testing .assert_close (
119
- embedding_tables .state_dict ()[key ].cpu (),
120
- model .state_dict ()[key ].cpu (),
121
+ model .state_dict ()[key ].cpu (), original_tensor , rtol = 1e-5 , atol = 1e-6
121
122
)
122
123
else :
123
124
raise AssertionError (
@@ -161,6 +162,9 @@ def test_initialize_parameters_ec(self, sharding_type: str) -> None:
161
162
backend = "nccl"
162
163
table_name = "free_parameters"
163
164
165
+ # Seed the global (default) RNG so anything drawn from the default generator is reproducible
166
+ torch .manual_seed (42 )
167
+
164
168
# Initialize embedding table on non-meta device, in this case cuda:0
165
169
embedding_tables = EmbeddingCollection (
166
170
tables = [
@@ -173,8 +177,10 @@ def test_initialize_parameters_ec(self, sharding_type: str) -> None:
173
177
],
174
178
)
175
179
180
+ # Use a fixed tensor with explicit seeding for consistent testing
181
+ fixed_tensor = torch .randn (10 , 64 , generator = torch .Generator ().manual_seed (42 ))
176
182
embedding_tables .load_state_dict (
177
- {f"embeddings.{ table_name } .weight" : torch . randn ( 10 , 64 ) }
183
+ {f"embeddings.{ table_name } .weight" : fixed_tensor }
178
184
)
179
185
180
186
self ._run_multi_process_test (
@@ -210,6 +216,9 @@ def test_initialize_parameters_ebc(self, sharding_type: str) -> None:
210
216
backend = "nccl"
211
217
table_name = "free_parameters"
212
218
219
+ # Seed the global (default) RNG so anything drawn from the default generator is reproducible
220
+ torch .manual_seed (42 )
221
+
213
222
# Initialize embedding bag on non-meta device, in this case cuda:0
214
223
embedding_tables = EmbeddingBagCollection (
215
224
tables = [
@@ -222,8 +231,10 @@ def test_initialize_parameters_ebc(self, sharding_type: str) -> None:
222
231
],
223
232
)
224
233
234
+ # Use a fixed tensor with explicit seeding for consistent testing
235
+ fixed_tensor = torch .randn (10 , 64 , generator = torch .Generator ().manual_seed (42 ))
225
236
embedding_tables .load_state_dict (
226
- {f"embedding_bags.{ table_name } .weight" : torch . randn ( 10 , 64 ) }
237
+ {f"embedding_bags.{ table_name } .weight" : fixed_tensor }
227
238
)
228
239
229
240
self ._run_multi_process_test (
0 commit comments