@@ -60,6 +60,25 @@ def initialize_and_test_parameters(
60
60
local_size : Optional [int ] = None ,
61
61
) -> None :
62
62
with MultiProcessContext (rank , world_size , backend , local_size ) as ctx :
63
+ # Set seed again in each process to ensure consistency
64
+ torch .manual_seed (42 )
65
+ if torch .cuda .is_available ():
66
+ torch .cuda .manual_seed (42 )
67
+
68
+ key = (
69
+ f"embeddings.{ table_name } .weight"
70
+ if isinstance (embedding_tables , EmbeddingCollection )
71
+ else f"embedding_bags.{ table_name } .weight"
72
+ )
73
+
74
+ # Create the same fixed tensor in each process
75
+ fixed_tensor = torch .randn (10 , 64 , generator = torch .Generator ().manual_seed (42 ))
76
+
77
+ # Load the fixed tensor into the embedding_tables to ensure consistency
78
+ embedding_tables .load_state_dict ({key : fixed_tensor })
79
+
80
+ # Store the original tensor on CPU for comparison BEFORE creating the model
81
+ original_tensor = embedding_tables .state_dict ()[key ].clone ().cpu ()
63
82
64
83
module_sharding_plan = construct_module_sharding_plan (
65
84
embedding_tables ,
@@ -79,12 +98,8 @@ def initialize_and_test_parameters(
79
98
env = ShardingEnv .from_process_group (ctx .pg ),
80
99
sharders = sharders ,
81
100
device = ctx .device ,
82
- )
83
-
84
- key = (
85
- f"embeddings.{ table_name } .weight"
86
- if isinstance (embedding_tables , EmbeddingCollection )
87
- else f"embedding_bags.{ table_name } .weight"
101
+ init_data_parallel = False ,
102
+ init_parameters = False ,
88
103
)
89
104
90
105
if isinstance (model .state_dict ()[key ], DTensor ):
@@ -96,28 +111,26 @@ def initialize_and_test_parameters(
96
111
gathered_tensor = model .state_dict ()[key ].full_tensor ()
97
112
if ctx .rank == 0 :
98
113
torch .testing .assert_close (
99
- gathered_tensor ,
100
- embedding_tables .state_dict ()[key ],
114
+ gathered_tensor .cpu (), original_tensor , rtol = 1e-5 , atol = 1e-6
101
115
)
102
116
elif isinstance (model .state_dict ()[key ], ShardedTensor ):
103
117
if ctx .rank == 0 :
104
- gathered_tensor = torch .empty_like (
105
- embedding_tables .state_dict ()[key ], device = ctx .device
106
- )
118
+ gathered_tensor = torch .empty_like (original_tensor , device = ctx .device )
107
119
else :
108
120
gathered_tensor = None
109
121
110
122
model .state_dict ()[key ].gather (dst = 0 , out = gathered_tensor )
111
123
112
124
if ctx .rank == 0 :
113
125
torch .testing .assert_close (
114
- none_throws (gathered_tensor ).to ("cpu" ),
115
- embedding_tables .state_dict ()[key ],
126
+ none_throws (gathered_tensor ).cpu (),
127
+ original_tensor ,
128
+ rtol = 1e-5 ,
129
+ atol = 1e-6 ,
116
130
)
117
131
elif isinstance (model .state_dict ()[key ], torch .Tensor ):
118
132
torch .testing .assert_close (
119
- embedding_tables .state_dict ()[key ].cpu (),
120
- model .state_dict ()[key ].cpu (),
133
+ model .state_dict ()[key ].cpu (), original_tensor , rtol = 1e-5 , atol = 1e-6
121
134
)
122
135
else :
123
136
raise AssertionError (
@@ -161,6 +174,9 @@ def test_initialize_parameters_ec(self, sharding_type: str) -> None:
161
174
backend = "nccl"
162
175
table_name = "free_parameters"
163
176
177
+ # Set seed for deterministic tensor generation
178
+ torch .manual_seed (42 )
179
+
164
180
# Initialize embedding table on non-meta device, in this case cuda:0
165
181
embedding_tables = EmbeddingCollection (
166
182
tables = [
@@ -173,8 +189,10 @@ def test_initialize_parameters_ec(self, sharding_type: str) -> None:
173
189
],
174
190
)
175
191
192
+ # Use a fixed tensor with explicit seeding for consistent testing
193
+ fixed_tensor = torch .randn (10 , 64 , generator = torch .Generator ().manual_seed (42 ))
176
194
embedding_tables .load_state_dict (
177
- {f"embeddings.{ table_name } .weight" : torch . randn ( 10 , 64 ) }
195
+ {f"embeddings.{ table_name } .weight" : fixed_tensor }
178
196
)
179
197
180
198
self ._run_multi_process_test (
@@ -210,6 +228,9 @@ def test_initialize_parameters_ebc(self, sharding_type: str) -> None:
210
228
backend = "nccl"
211
229
table_name = "free_parameters"
212
230
231
+ # Set seed for deterministic tensor generation
232
+ torch .manual_seed (42 )
233
+
213
234
# Initialize embedding bag on non-meta device, in this case cuda:0
214
235
embedding_tables = EmbeddingBagCollection (
215
236
tables = [
@@ -222,8 +243,10 @@ def test_initialize_parameters_ebc(self, sharding_type: str) -> None:
222
243
],
223
244
)
224
245
246
+ # Use a fixed tensor with explicit seeding for consistent testing
247
+ fixed_tensor = torch .randn (10 , 64 , generator = torch .Generator ().manual_seed (42 ))
225
248
embedding_tables .load_state_dict (
226
- {f"embedding_bags.{ table_name } .weight" : torch . randn ( 10 , 64 ) }
249
+ {f"embedding_bags.{ table_name } .weight" : fixed_tensor }
227
250
)
228
251
229
252
self ._run_multi_process_test (
0 commit comments