Skip to content

Commit afec712

Browse files
authored
Fix docstring for value_and_pullback_function (#125)
* Fix docstring for value_and_pullback_function * Remove comma * Change names to pff and pbf
1 parent 19e7d88 commit afec712

File tree

2 files changed

+39
-32
lines changed

2 files changed

+39
-32
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "AbstractDifferentiation"
22
uuid = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
33
authors = ["Mohamed Tarek <[email protected]> and contributors"]
4-
version = "0.6.0"
4+
version = "0.6.1"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/AbstractDifferentiation.jl

+38-31
Original file line numberDiff line numberDiff line change
@@ -188,50 +188,54 @@ end
188188
"""
189189
AD.pushforward_function(ab::AD.AbstractBackend, f, xs...)
190190
191-
Return the pushforward function `pf` of the function `f` at the inputs `xs` using backend `ab`.
191+
Return the pushforward function `pff` of the function `f` at the inputs `xs` using backend `ab`.
192192
193-
The pushfoward function `pf` accepts as input a `Tuple` of tangents, one for each element in `xs`.
194-
If `xs` consists of a single element, `pf` can also accept a single tangent instead of a 1-tuple.
193+
The pushfoward function `pff` accepts as input a `Tuple` of tangents, one for each element in `xs`.
194+
If `xs` consists of a single element, `pff` can also accept a single tangent instead of a 1-tuple.
195195
"""
196196
function pushforward_function(ab::AbstractBackend, f, xs...)
197-
return (ds) -> begin
198-
return jacobian(
199-
lowest(ab),
200-
(xds...,) -> begin
201-
if ds isa Tuple
202-
@assert length(xs) == length(ds)
203-
newxs = xs .+ ds .* xds
204-
return f(newxs...)
205-
else
206-
newx = only(xs) + ds * only(xds)
207-
return f(newx)
208-
end
209-
end,
210-
_zero.(xs, ds)...,
211-
)
197+
function pff(ds)
198+
function pff_aux(xds...)
199+
if ds isa Tuple
200+
@assert length(xs) == length(ds)
201+
newxs = xs .+ ds .* xds
202+
return f(newxs...)
203+
else
204+
newx = only(xs) + ds * only(xds)
205+
return f(newx)
206+
end
207+
end
208+
return jacobian(lowest(ab), pff_aux, _zero.(xs, ds)...)
212209
end
210+
return pff
213211
end
214212

215213
"""
216214
AD.value_and_pushforward_function(ab::AD.AbstractBackend, f, xs...)
217215
218-
Return a function that, given tangents `ts`, computes the tuple `(v, p)` of the function value `v = f(xs...)` and the output `p` of the pushforward function `AD.pushforward_function(ab, f, xs...)` applied to `ts`.
216+
Return a single function `vpff` which, given tangents `ts`, computes the tuple `(v, p) = vpff(ts)` composed of
217+
218+
- the function value `v = f(xs...)`
219+
- the pushforward value `p = pff(ts)` given by the pushforward function `pff = AD.pushforward_function(ab, f, xs...)` applied to `ts`.
219220
220221
See also [`AbstractDifferentiation.pushforward_function`](@ref).
222+
223+
!!! warning
224+
This name should be understood as "(value and pushforward) function", and thus is not aligned with the reverse mode counterpart [`AbstractDifferentiation.value_and_pullback_function`](@ref).
221225
"""
222226
function value_and_pushforward_function(ab::AbstractBackend, f, xs...)
223227
n = length(xs)
224228
value = f(xs...)
225-
pf_function = pushforward_function(lowest(ab), f, xs...)
229+
pff = pushforward_function(lowest(ab), f, xs...)
226230

227-
return ds -> begin
231+
function vpff(ds)
228232
if !(ds isa Tuple)
229233
ds = (ds,)
230234
end
231235
@assert length(ds) == n
232-
pf = pf_function(ds)
233-
return value, pf
236+
return value, pff(ds)
234237
end
238+
return vpff
235239
end
236240

237241
_zero(::Number, d::Number) = zero(d)
@@ -253,10 +257,10 @@ end
253257
"""
254258
AD.pullback_function(ab::AD.AbstractBackend, f, xs...)
255259
256-
Return the pullback function `pb` of the function `f` at the inputs `xs` using backend `ab`.
260+
Return the pullback function `pbf` of the function `f` at the inputs `xs` using backend `ab`.
257261
258-
The pullback function `pb` accepts as input a `Tuple` of cotangents, one for each output of `f`.
259-
If `f` has a single output, `pb` can also accept a single input instead of a 1-tuple.
262+
The pullback function `pbf` accepts as input a `Tuple` of cotangents, one for each output of `f`.
263+
If `f` has a single output, `pbf` can also accept a single input instead of a 1-tuple.
260264
"""
261265
function pullback_function(ab::AbstractBackend, f, xs...)
262266
_, pbf = value_and_pullback_function(ab, f, xs...)
@@ -266,14 +270,17 @@ end
266270
"""
267271
AD.value_and_pullback_function(ab::AD.AbstractBackend, f, xs...)
268272
269-
Return a function that, given cotangents `ts`, computes the tuple `(v, p)` of the function value `v = f(xs...)` and the output `p` of the pullback function `AD.pullback_function(ab, f, xs...)` applied to `ts`.
273+
Return a tuple `(v, pbf)` of the function value `v = f(xs...)` and the pullback function `pbf = AD.pullback_function(ab, f, xs...)`.
270274
271275
See also [`AbstractDifferentiation.pullback_function`](@ref).
276+
277+
!!! warning
278+
This name should be understood as "value and (pullback function)", and thus is not aligned with the forward mode counterpart [`AbstractDifferentiation.value_and_pushforward_function`](@ref).
272279
"""
273280
function value_and_pullback_function(ab::AbstractBackend, f, xs...)
274281
value = f(xs...)
275-
function pullback_function(ws)
276-
function pullback_gradient_function(_xs...)
282+
function pbf(ws)
283+
function pbf_aux(_xs...)
277284
vs = f(_xs...)
278285
if ws isa Tuple
279286
@assert length(vs) == length(ws)
@@ -282,9 +289,9 @@ function value_and_pullback_function(ab::AbstractBackend, f, xs...)
282289
return _dot(vs, ws)
283290
end
284291
end
285-
return gradient(lowest(ab), pullback_gradient_function, xs...)
292+
return gradient(lowest(ab), pbf_aux, xs...)
286293
end
287-
return value, pullback_function
294+
return value, pbf
288295
end
289296

290297
struct LazyDerivative{B,F,X}

0 commit comments

Comments
 (0)