Skip to content

Commit 5abe203

Browse files
fix the testsuit
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent b9f908b commit 5abe203

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tests/test_draft_modeling_tp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,7 @@ def run_attention_tp_test(rank, world_size, temp_dir_name):
234234
else:
235235
sharded_param = param
236236
sharded_state_dict[name] = sharded_param
237-
attn_tp2.load_state_dict(sharded_state_dict, strict=False)
237+
attn_tp2.load_state_dict(sharded_state_dict)
238238
attn_tp2.eval()
239239

240240
input_tensor = torch.load(os.path.join(temp_dir_name, "attn_input.pth")).cuda(rank)

0 commit comments

Comments
 (0)