 from jax.sharding import PartitionSpec as P
 from utils import HasCache, TrainState
 
-from flax import linen as nn
 from flax import nnx
 from flax.training import checkpoints, common_utils
 
@@ -115,7 +114,7 @@ def compute_weighted_cross_entropy(
       targets, vocab_size, on_value=confidence, off_value=low_confidence
   )
 
-  loss = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1)
+  loss = -jnp.sum(soft_targets * nnx.log_softmax(logits), axis=-1)
   loss = loss - normalizing_constant
 
   normalizing_factor = np.prod(targets.shape)
@@ -191,9 +190,9 @@ def train_step(
 
   dropout_rng = jax.random.fold_in(dropout_rng, state.step)
 
-  def loss_fn(params):
+  def loss_fn(params, other_variables):
     """loss function used for training."""
-    module = nnx.merge(state.graphdef, params)
+    module = nnx.merge(state.graphdef, params, other_variables)
     module.set_attributes(deterministic=False, decode=False)
     logits = module(
         inputs,
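
This is the core NNX pattern the change relies on: split a module into a static graphdef plus multiple state partitions, differentiate only the `nnx.Param` partition, and merge everything back inside the loss. A minimal sketch of that round trip, using a toy `nnx.Linear` rather than the example's `TransformerLM` (names here are illustrative):

```python
import jax
import jax.numpy as jnp
from flax import nnx

model = nnx.Linear(4, 4, rngs=nnx.Rngs(0))

# Split trainable parameters from everything else; the trailing `...`
# filter collects whatever the earlier filters did not match.
graphdef, params, other_variables = nnx.split(model, nnx.Param, ...)

def loss_fn(params, other_variables):
  # Rebuild a fully functional module from the static graph definition
  # plus both state partitions, as train_step's loss_fn does above.
  module = nnx.merge(graphdef, params, other_variables)
  out = module(jnp.ones((2, 4)))
  return (out**2).mean()

# Differentiating with respect to the first argument leaves
# other_variables untouched, which is why it is threaded through separately.
grads = jax.grad(loss_fn)(params, other_variables)
```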
@@ -211,7 +210,7 @@ def loss_fn(params):
   step = state.step
   lr = learning_rate_fn(step)
   grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
-  (_, logits), grads = grad_fn(state.params)
+  (_, logits), grads = grad_fn(state.params, state.other_variables)
   new_state = state.apply_gradients(grads=grads)
   metrics = compute_metrics(logits, inputs, weights)
   metrics['learning_rate'] = lr
@@ -221,14 +220,15 @@ def loss_fn(params):
 
 def eval_step(
     params: nnx.State,
+    other_variables: nnx.State,
     batch,
     graphdef: nnx.GraphDef[models.TransformerLM],
     label_smoothing=0.0,
 ):
   """Calculate evaluation metrics on a batch."""
   inputs = batch['inputs']
   weights = jnp.where(inputs > 0, 1.0, 0.0)
-  module = nnx.merge(graphdef, params)
+  module = nnx.merge(graphdef, params, other_variables)
   module.set_attributes(deterministic=True, decode=False)
   logits = module(inputs)
 
@@ -238,6 +238,7 @@ def eval_step(
 def predict_step(
     inputs,
     params: nnx.State,
+    other_variables: nnx.State,
     rngkey: jax.Array,
     graphdef: nnx.GraphDef[models.TransformerLM],
     eos_id: int,
@@ -247,20 +248,20 @@ def predict_step(
     top_k: int,
 ):
   """Predict language model on a batch."""
-  module = nnx.merge(graphdef, params)
+  module = nnx.merge(graphdef, params, other_variables)
 
   # TODO(cgarciae): check how pytorch does this.
   for _path, m in module.iter_modules():
     if isinstance(m, HasCache):
       input_shape = (inputs.shape[0], max_decode_len, config.emb_dim)
       m.init_cache(input_shape, dtype=config.dtype)
 
-  graphdef, params, cache = nnx.split(module, nnx.Param, nnx.Cache)
+  graphdef, params, cache, other_variables = nnx.split(module, nnx.Param, nnx.Cache, ...)
 
   def tokens_ids_to_logits(flat_ids, cache: nnx.State):
     """Token slice to logits from decoder model."""
     # --> [batch * beam, 1, vocab]
-    module = nnx.merge(graphdef, params, cache)
+    module = nnx.merge(graphdef, params, cache, other_variables)
     module.set_attributes(deterministic=True, decode=True)
     logits = module(flat_ids)
     cache = nnx.state(module, nnx.Cache)
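
After `init_cache` has created `nnx.Cache` variables, the module is re-split with three filters plus `...`, so the decode loop can carry the cache forward while params and any leftover variables stay fixed. A toy illustration of what each of the four results holds (the module here is hypothetical, not the example's decoder):

```python
import jax.numpy as jnp
from flax import nnx

class ToyDecoder(nnx.Module):
  def __init__(self):
    self.w = nnx.Param(jnp.ones((2, 2)))    # trainable weight
    self.kv = nnx.Cache(jnp.zeros((2, 2)))  # decode-time cache

m = ToyDecoder()

# graphdef: static structure; params: nnx.Param leaves; cache: nnx.Cache
# leaves; rest: anything else (e.g. batch stats or RNG state).
graphdef, params, cache, rest = nnx.split(m, nnx.Param, nnx.Cache, ...)

# tokens_ids_to_logits above re-merges all four pieces each step and then
# pulls the updated cache back out with nnx.state(module, nnx.Cache).
m2 = nnx.merge(graphdef, params, cache, rest)
```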
@@ -313,7 +314,7 @@ def evaluate(
   eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
   for _, eval_batch in zip(range(num_eval_steps), eval_iter):
     eval_batch = jax.tree.map(lambda x: x._numpy(), eval_batch)  # pylint: disable=protected-access
-    metrics = jit_eval_step(state.params, eval_batch, state.graphdef)
+    metrics = jit_eval_step(state.params, state.other_variables, eval_batch, state.graphdef)
     eval_metrics.append(metrics)
   eval_metrics = common_utils.stack_forest(eval_metrics)
   eval_metrics_sums = jax.tree.map(jnp.sum, eval_metrics)
@@ -330,6 +331,7 @@ def generate_prediction(
     jit_pred_step,
     graphdef: nnx.GraphDef[models.TransformerLM],
     params: nnx.State,
+    other_variables: nnx.State,
     tokenized_prompts,
     eos_id,
     inference_rng,
@@ -359,6 +361,7 @@ def generate_prediction(
       predicted = jit_pred_step(
           pred_batch,
           params,
+          other_variables,
           inference_rngs,
           graphdef,
           eos_id,
@@ -389,6 +392,7 @@ def train_and_evaluate(config: default.Config, workdir: str):
     workdir: Working directory for checkpoints and TF summaries. If this
       contains checkpoint training will be resumed from the latest checkpoint.
   """
+  workdir = os.path.abspath(workdir)
   tf.io.gfile.makedirs(workdir)
 
   vocab_path = config.vocab_path
@@ -440,18 +444,15 @@ def encode_strings(strs, max_len):
       max_len=max(config.max_target_length, config.max_eval_target_length),
       dropout_rate=config.dropout_rate,
       attention_dropout_rate=config.attention_dropout_rate,
-      kernel_init=nn.initializers.xavier_uniform(),
-      bias_init=nn.initializers.normal(stddev=1e-6),
+      kernel_init=nnx.initializers.xavier_uniform(),
+      bias_init=nnx.initializers.normal(stddev=1e-6),
       axis_rules=config.axis_rules,
   )
 
   # Mesh definition
   devices_array = utils.create_device_mesh(config)
   mesh = Mesh(devices_array, config.mesh_axes)
 
-  # print(mesh.shape)
-  # exit()
-
   start_step = 0
   rng = jax.random.PRNGKey(config.seed)
   rng, init_rng = jax.random.split(rng)
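
The initializer swap is mechanical: `nnx.initializers` mirrors linen's `nn.initializers`, and each factory returns a function with the usual `init(key, shape, dtype)` signature. A quick sanity check, with shapes chosen only for illustration:

```python
import jax
import jax.numpy as jnp
from flax import nnx

kernel_init = nnx.initializers.xavier_uniform()
bias_init = nnx.initializers.normal(stddev=1e-6)

key = jax.random.PRNGKey(0)
w = kernel_init(key, (512, 2048), jnp.float32)  # fan-in/fan-out scaled weights
b = bias_init(key, (2048,), jnp.float32)        # near-zero biases
```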
@@ -498,18 +499,19 @@ def constructor(config: models.TransformerConfig, key: jax.Array):
           None,
       ),  # type: ignore
       out_shardings=(state_sharding, None),  # type: ignore
-      static_argnums=(2, 3),
+      static_argnames=("learning_rate_fn", "label_smoothing"),
       donate_argnums=0,
   )
 
   jit_eval_step = jax.jit(
       eval_step,
       in_shardings=(
           state_sharding.params,
+          state_sharding.other_variables,
           data_sharding,
       ),  # type: ignore
       out_shardings=None,  # type: ignore
-      static_argnums=(2, 3),
+      static_argnames=("graphdef", "label_smoothing"),
   )
 
   # Since the inputs and rngkey args for predict_step will be batched,
520
522
in_axes = (
521
523
0 ,
522
524
jax .tree .map (lambda x : None , state .params ),
525
+ jax .tree .map (lambda x : None , state .other_variables ),
523
526
0 ,
524
527
None ,
525
528
None ,
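
`jax.vmap` accepts a pytree of axis specifications per argument, and mapping `None` over every leaf of `state.params` (and now `state.other_variables`) broadcasts the whole state to each batch element instead of slicing it along an axis. A small sketch with a hypothetical two-leaf parameter pytree:

```python
import jax
import jax.numpy as jnp

params = {'w': jnp.ones((4, 4)), 'b': jnp.zeros((4,))}

def apply(x, params):
  return x @ params['w'] + params['b']

# Batch over x (axis 0) while broadcasting params unchanged: the in_axes
# entry for a pytree argument may itself be a pytree of ints / None.
batched_apply = jax.vmap(
    apply, in_axes=(0, jax.tree.map(lambda _: None, params))
)
ys = batched_apply(jnp.ones((8, 4)), params)  # ys.shape == (8, 4)
```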
@@ -532,10 +535,11 @@ def constructor(config: models.TransformerConfig, key: jax.Array):
       in_shardings=(
           data_sharding,
           state_sharding.params,
+          state_sharding.other_variables,
           data_sharding,
       ),  # type: ignore
       out_shardings=data_sharding,  # type: ignore
-      static_argnums=tuple(range(3, 9)),
+      static_argnums=tuple(range(4, 10)),
   )
 
   # Main Train Loop
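
Unlike the train/eval steps, `jit_pred_step` keeps positional `static_argnums`, so inserting `other_variables` as the third positional parameter pushes every static argument up by one: `range(3, 9)` (graphdef through top_k in the old signature) becomes `range(4, 10)`. A toy stand-in showing the same renumbering (the trailing parameters here are placeholders, not the example's full signature):

```python
import jax

def predict(inputs, params, other_variables, rngkey, graphdef, eos_id):
  # Only the first four arguments are traced; graphdef and eos_id
  # (positions 4 and 5) are treated as compile-time constants.
  del rngkey, graphdef, eos_id
  return inputs * params + other_variables

# Before other_variables existed these would have been positions 3 and 4.
jit_predict = jax.jit(predict, static_argnums=(4, 5))
```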
@@ -575,7 +579,7 @@ def constructor(config: models.TransformerConfig, key: jax.Array):
         h(step)
 
       # Periodic metric handling.
-      if step % config.eval_every_steps == 0 or is_last_step:
+      if (step > 0 and step % config.eval_every_steps == 0) or is_last_step:
         with report_progress.timed('training_metrics'):
           logging.info('Gathering training metrics.')
           train_metrics = common_utils.stack_forest(train_metrics)
@@ -609,6 +613,7 @@ def constructor(config: models.TransformerConfig, key: jax.Array):
             jit_pred_step=jit_pred_step,
             graphdef=state.graphdef,
             params=state.params,
+            other_variables=state.other_variables,
             tokenized_prompts=tokenized_prompts,
             eos_id=eos_id,
             inference_rng=inference_rng,