Skip to content

Commit

Permalink
bleh
Browse files Browse the repository at this point in the history
  • Loading branch information
dlwh committed Jan 22, 2025
1 parent 8b88696 commit bc32f43
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/test_eval_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def test_eval_lm():
Vocab = haliax.Axis("vocab", len(tok))
model = Gpt2LMHeadModel.init(Vocab, model_config, key=jax.random.PRNGKey(0))

state = TrainerState(0, model, model, jax.random.PRNGKey(0), True, None, None)
state = TrainerState(0, model, model, jax.random.PRNGKey(0), None, True, None, None)

save_checkpoint(state, 0, f"{f}/ckpt")

Expand Down Expand Up @@ -79,7 +79,7 @@ def test_eval_lm_from_hf():
Vocab = haliax.Axis("vocab", len(tok))
model = Gpt2LMHeadModel.init(Vocab, model_config, key=jax.random.PRNGKey(0))

state = TrainerState(0, model, model, jax.random.PRNGKey(0), True, None, None)
state = TrainerState(0, model, model, jax.random.PRNGKey(0), None, True, None, None)

save_checkpoint(state, 0, f"{f}/ckpt")

Expand Down

0 comments on commit bc32f43

Please sign in to comment.