@@ -60,6 +60,25 @@ def initialize_and_test_parameters(
     local_size: Optional[int] = None,
 ) -> None:
     with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
+        # Set seed again in each process to ensure consistency
+        torch.manual_seed(42)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed(42)
+
+        key = (
+            f"embeddings.{table_name}.weight"
+            if isinstance(embedding_tables, EmbeddingCollection)
+            else f"embedding_bags.{table_name}.weight"
+        )
+
+        # Create the same fixed tensor in each process
+        fixed_tensor = torch.randn(10, 64, generator=torch.Generator().manual_seed(42))
+
+        # Load the fixed tensor into the embedding_tables to ensure consistency
+        embedding_tables.load_state_dict({key: fixed_tensor})
+
+        # Store the original tensor on CPU for comparison BEFORE creating the model
+        original_tensor = embedding_tables.state_dict()[key].clone().cpu()
 
         module_sharding_plan = construct_module_sharding_plan(
             embedding_tables,
@@ -81,12 +100,6 @@ def initialize_and_test_parameters(
             device=ctx.device,
         )
 
-        key = (
-            f"embeddings.{table_name}.weight"
-            if isinstance(embedding_tables, EmbeddingCollection)
-            else f"embedding_bags.{table_name}.weight"
-        )
-
         if isinstance(model.state_dict()[key], DTensor):
             if ctx.rank == 0:
                 gathered_tensor = torch.empty(model.state_dict()[key].size())
@@ -96,28 +109,26 @@ def initialize_and_test_parameters(
             gathered_tensor = model.state_dict()[key].full_tensor()
             if ctx.rank == 0:
                 torch.testing.assert_close(
-                    gathered_tensor,
-                    embedding_tables.state_dict()[key],
+                    gathered_tensor.cpu(), original_tensor, rtol=1e-5, atol=1e-6
                 )
         elif isinstance(model.state_dict()[key], ShardedTensor):
             if ctx.rank == 0:
-                gathered_tensor = torch.empty_like(
-                    embedding_tables.state_dict()[key], device=ctx.device
-                )
+                gathered_tensor = torch.empty_like(original_tensor, device=ctx.device)
             else:
                 gathered_tensor = None
 
             model.state_dict()[key].gather(dst=0, out=gathered_tensor)
 
             if ctx.rank == 0:
                 torch.testing.assert_close(
-                    none_throws(gathered_tensor).to("cpu"),
-                    embedding_tables.state_dict()[key],
+                    none_throws(gathered_tensor).cpu(),
+                    original_tensor,
+                    rtol=1e-5,
+                    atol=1e-6,
                 )
         elif isinstance(model.state_dict()[key], torch.Tensor):
             torch.testing.assert_close(
-                embedding_tables.state_dict()[key].cpu(),
-                model.state_dict()[key].cpu(),
+                model.state_dict()[key].cpu(), original_tensor, rtol=1e-5, atol=1e-6
             )
         else:
             raise AssertionError(
@@ -161,6 +172,9 @@ def test_initialize_parameters_ec(self, sharding_type: str) -> None:
         backend = "nccl"
         table_name = "free_parameters"
 
+        # Set seed for deterministic tensor generation
+        torch.manual_seed(42)
+
         # Initialize embedding table on non-meta device, in this case cuda:0
         embedding_tables = EmbeddingCollection(
             tables=[
@@ -173,8 +187,10 @@ def test_initialize_parameters_ec(self, sharding_type: str) -> None:
             ],
         )
 
+        # Use a fixed tensor with explicit seeding for consistent testing
+        fixed_tensor = torch.randn(10, 64, generator=torch.Generator().manual_seed(42))
         embedding_tables.load_state_dict(
-            {f"embeddings.{table_name}.weight": torch.randn(10, 64)}
+            {f"embeddings.{table_name}.weight": fixed_tensor}
         )
 
         self._run_multi_process_test(
@@ -210,6 +226,9 @@ def test_initialize_parameters_ebc(self, sharding_type: str) -> None:
         backend = "nccl"
         table_name = "free_parameters"
 
+        # Set seed for deterministic tensor generation
+        torch.manual_seed(42)
+
         # Initialize embedding bag on non-meta device, in this case cuda:0
         embedding_tables = EmbeddingBagCollection(
             tables=[
@@ -222,8 +241,10 @@ def test_initialize_parameters_ebc(self, sharding_type: str) -> None:
             ],
         )
 
+        # Use a fixed tensor with explicit seeding for consistent testing
+        fixed_tensor = torch.randn(10, 64, generator=torch.Generator().manual_seed(42))
         embedding_tables.load_state_dict(
-            {f"embedding_bags.{table_name}.weight": torch.randn(10, 64)}
+            {f"embedding_bags.{table_name}.weight": fixed_tensor}
        )
 
         self._run_multi_process_test(
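
The diff hinges on one property: a freshly constructed `torch.Generator` seeded with the same value yields the same tensor in every process, so each rank loads identical weights and rank 0 can compare the gathered shards against its local CPU copy within `rtol=1e-5, atol=1e-6`. A minimal standalone sketch of that assumption (the helper name is illustrative, not part of the PR):

```python
import torch

def make_fixed_tensor() -> torch.Tensor:
    # Each process would call this independently; a fresh generator with a
    # fixed seed makes the result deterministic across ranks on the same build.
    return torch.randn(10, 64, generator=torch.Generator().manual_seed(42))

t1 = make_fixed_tensor()
t2 = make_fixed_tensor()  # stands in for another rank
assert torch.equal(t1, t2)  # identical values, so cross-rank comparisons are stable
```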