@@ -117,7 +117,7 @@ function _prepare_hvp_aux(
117
117
rewrap = Rewrap (contexts... )
118
118
# Outer pushforward
119
119
new_contexts = (
120
- FunctionContext (f), BackendContext (inner (backend)), Constant (rewrap), contexts...
120
+ FunctionContext (f), Constant (inner (backend)), Constant (rewrap), contexts...
121
121
)
122
122
outer_pushforward_prep = prepare_pushforward_nokwarg (
123
123
strict, shuffled_gradient, outer (backend), x, tx, new_contexts...
@@ -161,15 +161,15 @@ function _prepare_hvp_aux(
161
161
# Outer pushforward
162
162
new_contexts = (
163
163
FunctionContext (f),
164
- PrepContext (inner_gradient_prep),
165
- BackendContext (inner (backend)),
164
+ ConstantOrCache (inner_gradient_prep),
165
+ Constant (inner (backend)),
166
166
Constant (rewrap),
167
167
contexts... ,
168
168
)
169
169
new_contexts_in = (
170
170
FunctionContext (f),
171
- PrepContext (inner_gradient_in_prep),
172
- BackendContext (inner (backend)),
171
+ ConstantOrCache (inner_gradient_in_prep),
172
+ Constant (inner (backend)),
173
173
Constant (rewrap),
174
174
contexts... ,
175
175
)
@@ -228,15 +228,15 @@ function _prepare_hvp_aux(
228
228
# Outer pushforward
229
229
new_contexts = (
230
230
FunctionContext (f),
231
- PrepContext (inner_gradient_prep),
232
- BackendContext (inner (backend)),
231
+ ConstantOrCache (inner_gradient_prep),
232
+ Constant (inner (backend)),
233
233
Constant (rewrap),
234
234
contexts... ,
235
235
)
236
236
new_contexts_in = (
237
237
FunctionContext (f),
238
- PrepContext (inner_gradient_in_prep),
239
- BackendContext (inner (backend)),
238
+ ConstantOrCache (inner_gradient_in_prep),
239
+ Constant (inner (backend)),
240
240
Constant (rewrap),
241
241
contexts... ,
242
242
)
@@ -279,8 +279,8 @@ function hvp(
279
279
rewrap = Rewrap (contexts... )
280
280
new_contexts = (
281
281
FunctionContext (f),
282
- map (PrepContext , maybe_inner_gradient_prep)... ,
283
- BackendContext (inner (backend)),
282
+ map (ConstantOrCache , maybe_inner_gradient_prep)... ,
283
+ Constant (inner (backend)),
284
284
Constant (rewrap),
285
285
contexts... ,
286
286
)
@@ -318,8 +318,8 @@ function _hvp_aux!(
318
318
rewrap = Rewrap (contexts... )
319
319
new_contexts = (
320
320
FunctionContext (f),
321
- map (PrepContext , maybe_inner_gradient_in_prep)... ,
322
- BackendContext (inner (backend)),
321
+ map (ConstantOrCache , maybe_inner_gradient_in_prep)... ,
322
+ Constant (inner (backend)),
323
323
Constant (rewrap),
324
324
contexts... ,
325
325
)
@@ -349,8 +349,8 @@ function _hvp_aux!(
349
349
rewrap = Rewrap (contexts... )
350
350
new_contexts = (
351
351
FunctionContext (f),
352
- map (PrepContext , maybe_inner_gradient_prep)... ,
353
- BackendContext (inner (backend)),
352
+ map (ConstantOrCache , maybe_inner_gradient_prep)... ,
353
+ Constant (inner (backend)),
354
354
Constant (rewrap),
355
355
contexts... ,
356
356
)
@@ -378,8 +378,8 @@ function gradient_and_hvp(
378
378
rewrap = Rewrap (contexts... )
379
379
new_contexts = (
380
380
FunctionContext (f),
381
- map (PrepContext , maybe_inner_gradient_prep)... ,
382
- BackendContext (inner (backend)),
381
+ map (ConstantOrCache , maybe_inner_gradient_prep)... ,
382
+ Constant (inner (backend)),
383
383
Constant (rewrap),
384
384
contexts... ,
385
385
)
@@ -419,8 +419,8 @@ function _gradient_and_hvp_aux!(
419
419
rewrap = Rewrap (contexts... )
420
420
new_contexts = (
421
421
FunctionContext (f),
422
- map (PrepContext , maybe_inner_gradient_in_prep)... ,
423
- BackendContext (inner (backend)),
422
+ map (ConstantOrCache , maybe_inner_gradient_in_prep)... ,
423
+ Constant (inner (backend)),
424
424
Constant (rewrap),
425
425
contexts... ,
426
426
)
@@ -452,8 +452,8 @@ function _gradient_and_hvp_aux!(
452
452
rewrap = Rewrap (contexts... )
453
453
new_contexts = (
454
454
FunctionContext (f),
455
- map (PrepContext , maybe_inner_gradient_prep)... ,
456
- BackendContext (inner (backend)),
455
+ map (ConstantOrCache , maybe_inner_gradient_prep)... ,
456
+ Constant (inner (backend)),
457
457
Constant (rewrap),
458
458
contexts... ,
459
459
)
@@ -492,7 +492,7 @@ function _prepare_hvp_aux(
492
492
rewrap = Rewrap (contexts... )
493
493
new_contexts = (
494
494
FunctionContext (f),
495
- BackendContext (inner (backend)),
495
+ Constant (inner (backend)),
496
496
Constant (first (tx)),
497
497
Constant (rewrap),
498
498
contexts... ,
@@ -522,7 +522,7 @@ function hvp(
522
522
outer (backend),
523
523
x,
524
524
FunctionContext (f),
525
- BackendContext (inner (backend)),
525
+ Constant (inner (backend)),
526
526
Constant (dx),
527
527
Constant (rewrap),
528
528
contexts... ,
@@ -551,7 +551,7 @@ function hvp!(
551
551
outer (backend),
552
552
x,
553
553
FunctionContext (f),
554
- BackendContext (inner (backend)),
554
+ Constant (inner (backend)),
555
555
Constant (tx[b]),
556
556
Constant (rewrap),
557
557
contexts... ,
@@ -613,7 +613,7 @@ function _prepare_hvp_aux(
613
613
_sig = signature (f, backend, x, tx, contexts... ; strict)
614
614
rewrap = Rewrap (contexts... )
615
615
new_contexts = (
616
- FunctionContext (f), BackendContext (inner (backend)), Constant (rewrap), contexts...
616
+ FunctionContext (f), Constant (inner (backend)), Constant (rewrap), contexts...
617
617
)
618
618
grad_buffer = similar (x)
619
619
outer_pullback_prep = prepare_pullback_nokwarg (
@@ -649,7 +649,7 @@ function hvp(
649
649
(; outer_pullback_prep) = prep
650
650
rewrap = Rewrap (contexts... )
651
651
new_contexts = (
652
- FunctionContext (f), BackendContext (inner (backend)), Constant (rewrap), contexts...
652
+ FunctionContext (f), Constant (inner (backend)), Constant (rewrap), contexts...
653
653
)
654
654
return pullback (
655
655
shuffled_gradient, outer_pullback_prep, outer (backend), x, tx, new_contexts...
@@ -684,7 +684,7 @@ function _hvp_aux!(
684
684
(; grad_buffer, outer_pullback_in_prep) = prep
685
685
rewrap = Rewrap (contexts... )
686
686
new_contexts = (
687
- FunctionContext (f), BackendContext (inner (backend)), Constant (rewrap), contexts...
687
+ FunctionContext (f), Constant (inner (backend)), Constant (rewrap), contexts...
688
688
)
689
689
return pullback! (
690
690
shuffled_gradient!,
@@ -711,7 +711,7 @@ function _hvp_aux!(
711
711
(; outer_pullback_prep) = prep
712
712
rewrap = Rewrap (contexts... )
713
713
new_contexts = (
714
- FunctionContext (f), BackendContext (inner (backend)), Constant (rewrap), contexts...
714
+ FunctionContext (f), Constant (inner (backend)), Constant (rewrap), contexts...
715
715
)
716
716
return pullback! (
717
717
shuffled_gradient, tg, outer_pullback_prep, outer (backend), x, tx, new_contexts...
@@ -730,7 +730,7 @@ function gradient_and_hvp(
730
730
(; outer_pullback_prep) = prep
731
731
rewrap = Rewrap (contexts... )
732
732
new_contexts = (
733
- FunctionContext (f), BackendContext (inner (backend)), Constant (rewrap), contexts...
733
+ FunctionContext (f), Constant (inner (backend)), Constant (rewrap), contexts...
734
734
)
735
735
return value_and_pullback (
736
736
shuffled_gradient, outer_pullback_prep, outer (backend), x, tx, new_contexts...
@@ -767,7 +767,7 @@ function _gradient_and_hvp_aux!(
767
767
(; outer_pullback_in_prep) = prep
768
768
rewrap = Rewrap (contexts... )
769
769
new_contexts = (
770
- FunctionContext (f), BackendContext (inner (backend)), Constant (rewrap), contexts...
770
+ FunctionContext (f), Constant (inner (backend)), Constant (rewrap), contexts...
771
771
)
772
772
new_grad, _ = value_and_pullback! (
773
773
shuffled_gradient!,
@@ -796,7 +796,7 @@ function _gradient_and_hvp_aux!(
796
796
(; outer_pullback_prep) = prep
797
797
rewrap = Rewrap (contexts... )
798
798
new_contexts = (
799
- FunctionContext (f), BackendContext (inner (backend)), Constant (rewrap), contexts...
799
+ FunctionContext (f), Constant (inner (backend)), Constant (rewrap), contexts...
800
800
)
801
801
new_grad, _ = value_and_pullback! (
802
802
shuffled_gradient, tg, outer_pullback_prep, outer (backend), x, tx, new_contexts...
0 commit comments