Commit 3499116
committed
Allow per-variable optimizer, add DispatchOptimizer.
- Adds a property `variable.optimizer` that defaults to `None`
- Adds a `DispatchOptimizer` that scans the list of trainable variables during build,
collects all unique per-variable optimizers, then dispatches the apply/stateless_apply
function to the correct optimizer if applicable.
- Modifies `trainer` so that during the optimizer build stage, checks if any variables
have a custom optimizer attached, and if so inserts a `DispatchOptimizer` to properly
handle them. This allows usage to be hidden from the user.
Context: for large embedding tables, we need special optimizers to be used so that
the tables can be updated in-place, rather than returning large gradients. The layer
will handle setting of the custom optimizers, but we need the trainer to be aware
of them and dispatch the embedding tables to different optimizers appropriately.1 parent 8a6e83b commit 3499116
File tree
9 files changed
+548
-7
lines changed- keras/src
- backend
- common
- tensorflow
- optimizers
- trainers
9 files changed
+548
-7
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
154 | 154 | | |
155 | 155 | | |
156 | 156 | | |
| 157 | + | |
| 158 | + | |
157 | 159 | | |
158 | 160 | | |
159 | 161 | | |
| |||
372 | 374 | | |
373 | 375 | | |
374 | 376 | | |
| 377 | + | |
| 378 | + | |
| 379 | + | |
| 380 | + | |
| 381 | + | |
| 382 | + | |
| 383 | + | |
| 384 | + | |
| 385 | + | |
| 386 | + | |
| 387 | + | |
| 388 | + | |
| 389 | + | |
| 390 | + | |
| 391 | + | |
| 392 | + | |
| 393 | + | |
| 394 | + | |
| 395 | + | |
| 396 | + | |
375 | 397 | | |
376 | 398 | | |
377 | 399 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
8 | 8 | | |
9 | 9 | | |
10 | 10 | | |
| 11 | + | |
11 | 12 | | |
12 | 13 | | |
13 | 14 | | |
| |||
419 | 420 | | |
420 | 421 | | |
421 | 422 | | |
| 423 | + | |
| 424 | + | |
| 425 | + | |
| 426 | + | |
| 427 | + | |
| 428 | + | |
| 429 | + | |
| 430 | + | |
| 431 | + | |
| 432 | + | |
| 433 | + | |
| 434 | + | |
422 | 435 | | |
423 | 436 | | |
424 | 437 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
71 | 71 | | |
72 | 72 | | |
73 | 73 | | |
74 | | - | |
| 74 | + | |
75 | 75 | | |
76 | | - | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
77 | 93 | | |
78 | 94 | | |
79 | 95 | | |
| |||
98 | 114 | | |
99 | 115 | | |
100 | 116 | | |
101 | | - | |
102 | | - | |
| 117 | + | |
103 | 118 | | |
104 | 119 | | |
105 | 120 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
5 | 5 | | |
6 | 6 | | |
7 | 7 | | |
| 8 | + | |
8 | 9 | | |
9 | 10 | | |
10 | 11 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
204 | 204 | | |
205 | 205 | | |
206 | 206 | | |
| 207 | + | |
| 208 | + | |
| 209 | + | |
| 210 | + | |
| 211 | + | |
| 212 | + | |
| 213 | + | |
207 | 214 | | |
208 | 215 | | |
209 | 216 | | |
210 | 217 | | |
211 | 218 | | |
212 | 219 | | |
213 | 220 | | |
| 221 | + | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
| 226 | + | |
| 227 | + | |
214 | 228 | | |
215 | 229 | | |
216 | 230 | | |
| |||
568 | 582 | | |
569 | 583 | | |
570 | 584 | | |
571 | | - | |
| 585 | + | |
572 | 586 | | |
573 | 587 | | |
574 | 588 | | |
| |||
0 commit comments