Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
99 commits
Select commit Hold shift + click to select a range
5777e79
Create frontend and IR for vectorised AD.
athas Dec 5, 2024
a54fa53
Hook it up in internalisation, too.
athas Dec 5, 2024
7ecd907
Basic support for vectorised forward-mode AD.
athas Dec 7, 2024
a8e616e
Forgot to add tests.
athas Dec 22, 2024
91465c3
Merge branch 'master' into ad-vec
athas Aug 14, 2025
9b488e1
Scan test.
athas Aug 14, 2025
9d15ec6
Add jvp_vec and vjp_vec.
athas Aug 14, 2025
9a1eb4b
Merge branch 'master' into ad-vec
athas Aug 14, 2025
bc01f7d
Merge branch 'master' into ad-vec
athas Aug 14, 2025
e75f617
This should not need modification.
athas Aug 14, 2025
cbd98fb
Change how accumulators are handled.
athas Aug 14, 2025
8ae043a
Implement vectorised scatter.
athas Aug 15, 2025
6760758
Add map test.
athas Aug 15, 2025
240edfe
More tests, some that fail.
athas Aug 15, 2025
644b8c2
Tweak the tests.
athas Aug 15, 2025
e9eac0a
Another test.
athas Aug 15, 2025
d1438bc
Some hackyish fixes.
athas Aug 15, 2025
5a4128c
Merge branch 'master' into ad-vec
athas Aug 15, 2025
41e2738
Fix a handful of things.
athas Aug 17, 2025
e8c0c14
More things work.
athas Aug 17, 2025
f67a2b0
Minor fixes.
athas Aug 17, 2025
0a10a5c
Fix vjp2_vec in interpreter.
athas Aug 19, 2025
b21f5f6
Start work on vectorised reverse mode AD.
athas Aug 19, 2025
ebfae42
Support primitive functions properly.
athas Aug 20, 2025
1684541
Make unops and primfuns work.
athas Aug 21, 2025
4819c5c
Work on vectorised reductions.
athas Aug 22, 2025
512c0ff
Start on scan.
athas Aug 26, 2025
47386d2
Merge branch 'master' into ad-vec
athas Aug 26, 2025
bee9ae3
More work.
athas Aug 26, 2025
7d4f77d
Some tests, some of which fail.
athas Aug 26, 2025
76c52be
More stuff works.
athas Aug 26, 2025
bc6df46
Merge branch 'master' into ad-vec
athas Aug 26, 2025
7fb9eee
Merge branch 'master' into ad-vec
athas Aug 28, 2025
c59aefb
Work on histograms.
athas Aug 28, 2025
1a9e8ca
Merge branch 'master' into ad-vec
athas Sep 3, 2025
6fa2747
Merge branch 'master' into ad-vec
athas Sep 7, 2025
4d717f1
Merge branch 'master' into ad-vec
athas Sep 8, 2025
eccd025
Merge branch 'master' into ad-vec
athas Sep 25, 2025
a91b911
Add failing test.
athas Sep 25, 2025
868d970
Merge branch 'master' into ad-vec
athas Oct 2, 2025
e6949ab
Handle vector adjoints.
athas Oct 8, 2025
5425e7f
Support unrolling of maps over accumulators.
athas Oct 8, 2025
35a0968
Support unrolling of vectorised AD.
athas Oct 8, 2025
de1e700
Merge branch 'master' into ad-vec
athas Dec 15, 2025
7f6a862
Merge branch 'master' into ad-vec
athas May 29, 2026
22d6070
Merge branch 'master' into ad-vec
athas May 29, 2026
9679a74
Extend AD vectorized entry coverage for additional `tests/ad` cases (…
Copilot May 29, 2026
4892703
Need to transpose here.
athas May 29, 2026
5410a1a
Share code.
athas May 29, 2026
76be3e3
Fix typo.
athas May 29, 2026
3e3aecb
Fix some more things.
athas May 29, 2026
7046d72
Handle vector here.
athas May 29, 2026
8ff4dc3
Generate right tangents for stream.
athas May 29, 2026
e8c06d0
Support Sparse adjoints in vectorised AD (#2473)
Copilot May 29, 2026
26801f4
Elaborate comment.
athas May 29, 2026
be68354
Fix scatter.
athas May 30, 2026
b65ff8e
Merge branch 'master' into ad-vec
athas May 30, 2026
a0d0fc2
Fix forward-mode for Stream.
athas May 30, 2026
e917e6b
Merge branch 'master' into ad-vec
athas May 30, 2026
dd07eba
Fix vectorised scans.
athas May 30, 2026
4859c04
Fix final known vector-AD bug.
athas May 30, 2026
1756d1b
Merge branch 'master' into ad-vec
athas May 31, 2026
7210cf3
Unify tests.
athas May 31, 2026
fed699d
Remove this unroll.
athas May 31, 2026
9675ba5
Merge branch 'master' into ad-vec
athas May 31, 2026
15cbb7b
Merge branch 'master' into ad-vec
athas Jun 1, 2026
b78dfc0
Merge branch 'master' into ad-vec
athas Jun 3, 2026
e035fca
Specialised handling of Hist.
athas Jun 3, 2026
02b209e
Revert "Specialised handling of Hist."
athas Jun 3, 2026
1453a54
Merge branch 'master' into ad-vec
athas Jun 3, 2026
59e92b0
Merge branch 'master' into ad-vec
athas Jun 4, 2026
e41f1e3
Merge branch 'master' into ad-vec
athas Jun 5, 2026
e33ace4
Refreshen documentation.
athas Jun 5, 2026
a3c8fb8
Merge branch 'master' into ad-vec
athas Jun 10, 2026
d1f4d31
More robust equality checking.
athas Jun 10, 2026
eeff1c7
Individual tests.
athas Jun 10, 2026
055da84
Lower tolerance.
athas Jun 10, 2026
4a94616
Rewrite this test to be less crazy.
athas Jun 10, 2026
2e40acd
Don't need this.
athas Jun 11, 2026
27bd8a3
No ISPC for this one.
athas Jun 11, 2026
7d8c5d6
Merge branch 'master' into ad-vec
athas Jun 11, 2026
3b7698f
Merge branch 'master' into ad-vec
athas Jun 12, 2026
1d441f9
Remove duplicate tests.
athas Jun 12, 2026
0afeea9
Refresh terminology.
athas Jun 12, 2026
a187ad4
Also update interpreter.
athas Jun 12, 2026
a21e7ff
Add failing test.
athas Jun 12, 2026
85ecf8a
Fix typo in comment.
athas Jun 12, 2026
21260c5
Handle with_vjp in vector mode.
athas Jun 12, 2026
36ca87a
Nomenclature fixes.
athas Jun 12, 2026
4c0a484
Fix markup.
athas Jun 12, 2026
eb74e14
Better reference.
athas Jun 12, 2026
349e5bc
Improve documentation.
athas Jun 12, 2026
788b443
Further elaboration.
athas Jun 12, 2026
04f3bb9
Clarify.
athas Jun 12, 2026
94cb3ab
More.
athas Jun 12, 2026
3834719
Better naming.
athas Jun 13, 2026
f1df3dc
More docs.
athas Jun 14, 2026
100de3d
Minor fices.
athas Jun 15, 2026
9c46df2
Merge branch 'master' into ad-vec
athas Jun 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
uses the hardware support for `f16`, similarly to the CUDA backend.
Implemented by Jérôme Wagner. (#2470)

* Vector AD, exposed through the functions `jmp` and `mjp`.

* All opaque values available over the C API can now be decomposed into their
constituents.

Expand Down
2 changes: 2 additions & 0 deletions futhark.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ library
Futhark.Actions
Futhark.AD.Derivatives
Futhark.AD.Fwd
Futhark.AD.Shared
Futhark.AD.Rev
Futhark.AD.Rev.Acc
Futhark.AD.Rev.Loop
Futhark.AD.Rev.Hist
Futhark.AD.Rev.Map
Expand Down
189 changes: 119 additions & 70 deletions prelude/ad.fut
Original file line number Diff line number Diff line change
Expand Up @@ -14,104 +14,130 @@
--
-- Futhark's AD support includes the following:
--
-- * Differentiation operators for forward-mode (`jvp`) and reverse-mode
-- (`vjp`).
-- * Differential operators for forward-mode (`jvp`@term) and reverse-mode
-- (`vjp`@term).
--
-- * Arbitrary control flow in differentiable code.
-- * Almost arbitrary control flow in differentiable code (some limitations
-- apply when using GPU backends, see below).
--
-- * Higher order derivatives by nesting differentiation operators, including
-- arbitrary mixing of forward- and reverse mode (although using multiple
-- rounds of reverse mode is rarely useful and often slow).
--
-- * Custom derivatives (`with_vjp`).
-- * Custom derivatives (`with_vjp`@term).
--
-- * Vector AD (`mjp`@term, `jmp`@term), sometimes also known as "batched" or
-- "multi-directional" AD.
--
-- * Checkpointing of sequential loops.
--
-- ## Jacobians
--
-- For a differentiable function *f* whose input comprise *n* scalars
-- and whose output comprises *m* scalars, the
-- [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant)
-- for a given input point is an *m* by *n* matrix of scalars that
-- each represent a [partial
-- derivatives](https://en.wikipedia.org/wiki/Partial_derivative).
-- Intuitively, position *(i,j)* of the Jacobian describes how
-- sensitive output *i* is to input *j*. The notion of Jacobian
-- generalises to functions that accept or produce compound structures
-- such as arrays, records, sums, and so on, simply by "flattening
-- out" the values and considering only their constituent scalars.
--
-- Computing the full Jacobian is usually costly and sometimes not
-- necessary, and it is not part of the AD facility provided by
-- Futhark. Instead it is possible to parts of the Jacobian.
--
-- We can take the product of an an *m* by *n* Jacobian with an
-- *n*-element *tangent vector* to produce an *m*-element vector
-- (*Jacobian-vector product*). Such a product can be computed in a
-- single (augmented) execution of the function *f*, and by choosing
-- the tangent vector appropriately we can use this to compute the
-- full Jacobian. This is provided by the function `jvp`.
-- For a differentiable function *f* whose input comprise *n* scalars and whose
-- output comprises *m* scalars, the
-- [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant) for
-- a given input point is an *m* by *n* matrix of scalars that each represent a
-- [partial derivative](https://en.wikipedia.org/wiki/Partial_derivative).
-- Intuitively, position *(i,j)* of the Jacobian describes how sensitive output
-- *i* is to input *j*. The notion of Jacobian generalises to functions that
-- accept or produce compound structures such as arrays, records, sums, and so
-- on, simply by "flattening out" the values and considering only their
-- constituent scalars.
--
-- Computing the full Jacobian is usually costly and sometimes not necessary,
-- and it is not part of the AD facility provided by Futhark. Instead it is
-- possible to compute parts of the Jacobian, which semantically (but not
-- operationally) can be seen as multiplying the Jacobian with a vector,
-- producing a vector. However, it is important to understand that the full
-- Jacobian is *not* constructed as an intermediate step.
--
-- We can take the product of an *m* by *n* Jacobian with an *n*-element
-- *tangent vector* to produce an *m*-element vector (*Jacobian-vector
-- product*). Such a product can be computed in a single (augmented) execution
-- of the function *f*. This is provided by the function `jvp`.
--
-- We can also take the product of an *m*-element vector *cotangent
-- vector* with the *m* by *n* Jacobian to produce an *n*-element
-- vector (*Vector-Jacobian product*). This too can be computed in a
-- vector (*vector-Jacobian product*). This too can be computed in a
-- single execution of *f*, with `vjp`.
--
-- We can use the `jvp` function to produce a *column* of the full
-- Jacobian, and `vjp` to produce a *row*. Which is superior for a
-- given situation depends on whether the function has more inputs or
-- outputs.
-- A tangent has the same structure as the input and represents a direction in
-- input space. A cotangent has the same structure as the output and represents
-- sensitivities flowing backwards through the computation.
--
-- Using an elementary (co-)tangent vector, we can use the `jvp` function to
-- produce a *column* of the full Jacobian, and `vjp` to produce a *row*, with
-- the nonzero element of the vector identifying which column or row is
-- extracted. Which is superior for a given situation depends on whether the
-- function has more inputs or outputs.
--
-- You can freely nest `vjp` and `jvp` to compute higher-order
-- derivatives.
-- We can freely nest `vjp` and `jvp` to compute higher-order derivatives.
--
-- ## Efficiency
--
-- Both `jvp` and `vjp` work by transforming the program to carry
-- along extra information associated with each scalar value.
--
-- In the case of `jvp`, this extra information takes the form of an
-- additional scalar representing the tangent, which is then
-- propagated in each scalar computation using essentially the [chain
-- rule](https://en.wikipedia.org/wiki/Chain_rule). Therefore, `jvp`
-- has a memory overhead of approximately *2x*, and a computational
-- overhead of slightly more, but usually less than *4x*.
--
-- In the case of `vjp`, since our starting point is a *cotangent*,
-- the function is essentially first run forward, then backwards (the
-- *return sweep*) to propagate the cotangent. During the return
-- sweep, all intermediate results computed during the forward sweep
-- must still be available, and must therefore be stored in memory
-- during the forward sweep. This means that the memory usage of `vjp`
-- is proportional to the number of sequential steps of the original
-- function (essentially turning *time* into *space*). The compiler
-- does a nontrivial amount of optimisation to ameliorate this
-- overhead (see [AD for an Array Language with Nested
-- Parallelism](https://futhark-lang.org/publications/sc22-ad.pdf)),
-- but it can still be substantial for programs with deep sequential
-- loops.
-- In the case of `jvp` ("forward mode", or "tangent mode"), this extra
-- information takes the form of an additional scalar representing the tangent,
-- which is then propagated in each scalar computation using essentially the
-- [chain rule](https://en.wikipedia.org/wiki/Chain_rule). Therefore, `jvp` has
-- a memory overhead of approximately *2x*, and a computational overhead of
-- slightly more, but usually less than *4x*.
--
-- In the case of `vjp` ("reverse mode" or "adjoint mode"), since our starting
-- point is a *cotangent*, the function is essentially first run forward, then
-- backwards (the *return sweep*) to propagate the cotangent. During the return
-- sweep, all intermediate results computed during the forward sweep must still
-- be available, and must therefore be stored in memory during the forward sweep
-- - this is called "the tape". This means that the memory usage of `vjp` is
-- proportional to the number of sequential steps of the original function
-- (essentially turning *time* into *space*). The compiler does a nontrivial
-- amount of optimisation to ameliorate this overhead (see [AD for an Array
-- Language with Nested
-- Parallelism](https://futhark-lang.org/publications/sc22-ad.pdf)), but it can
-- still be substantial for programs with deep sequential loops.
--
-- Nesting `vjp`, understood as applying `vjp` to the result of `vjp`, is
-- usually a bad idea, as the code structure produced by `vjp` is fairly
-- complicated, due to the tape management. Passing the output of `jvp` to
-- `vjp`, or the other way, is however fine. As a rule of thumb, whenever you
-- stack multiple differential operators, make sure only one of them is `vjp` or
-- related ones.
--
-- When using vector AD (`mjp`@term/`jmp`@term), each scalar is associated with
-- a vector of tangents or cotangents, and the space overhead for storing these
-- is therefore multiplied with the vector size. However, in the case of `vjp`,
-- the intermediate results are only stored once. It varies on a case-by-case
-- basis whether vector AD is faster than using `map` on top of
-- `vjp`@term/`jvp`@term. Vector AD essentially converts propagation of
-- (co-)tangents from scalar to array operations, which can have a significant
-- impact on memory accesses, depending on how the compiler manages to optimise
-- the resulting code. It is hard to predict whether this offsets the reduction
-- in primal work. If the vector size is a constant, and the `#[unroll]`
-- attribute is put on the AD operator, then the vectors become unrolled (turned
-- into tuples, essentially), although this should only be done when the vector
-- size is quite small, as the increase in code size is substantial.
--
-- ## Differentiable functions
--
-- AD only gives meaningful results for differentiable functions. The
-- Futhark type system does not distinguish differentiable or
-- non-differentiable operations. As a rule of thumb, a function is
-- differentiable if its results are computed using a composition of
-- primitive floating-point operations, without ever converting to or
-- from integers.
-- AD only gives meaningful results for differentiable functions. The Futhark
-- type system does not distinguish differentiable from non-differentiable
-- operations. As a rule of thumb, a function is differentiable if its results
-- are computed using a composition of primitive floating-point operations,
-- without ever converting to or from integers. Most functions will also have
-- discontinuities around values that influence control flow.
--
-- Note that a function whose input or output is a sum type with more
-- than one constructor is *not* differentiable (or at least the
-- sum-typed part is not). This is because the choice of constructor
-- is not a continuous quantity.
-- Note that a function whose input or output is a sum type with more than one
-- constructor is *not* differentiable (or at least the sum-typed part is not).
-- This is because the choice of constructor is not a continuous quantity.
--
-- ## Limitations
--
-- `jvp` is expected to work in all cases. `vjp` has limitations when
-- using the GPU backends similar to those for irregular flattening.
-- Specifically, you should avoid structures with variant sizes, such
-- as loops that carry an array that changes size through the
-- execution of the loop.
-- `jvp` is expected to work in all cases. `vjp` has limitations when using the
-- GPU backends similar to those for irregular flattening. Specifically, you
-- should avoid structures with variant sizes, such as loops that carry an array
-- that changes size through the execution of the loop.

-- | Jacobian-Vector Product ("forward mode"), producing also the
-- primal result as the first element of the result tuple.
Expand All @@ -123,6 +149,20 @@ def jvp2 'a 'b (f: a -> b) (x: a) (x': a) : (b, b) =
def vjp2 'a 'b (f: a -> b) (x: a) (y': b) : (b, a) =
intrinsics.vjp2 f x y'

-- | Jacobian-Matrix Product, returning also the primal result. As `jvp2`, but
-- accepts an array of seed vectors (hence "matrix", although transposed).
-- Semantically equivalent to mapping, but may be more efficient. If used with
-- `#[unroll]`, tangent calculations are unrolled when possible.
def jmp2 'a 'b [n] (f: a -> b) (x: a) (x': [n]a) : (b, [n]b) =
intrinsics.jmp2 f x x'

-- | Matrix-Jacobian Product, returning also the primal result. As `vjp2`, but
-- accepts an array of seed vectors (hence "matrix"). Semantically equivalent to
-- mapping, but may be more efficient. If used with `#[unroll]`, adjoint
-- calculations are unrolled when possible.
def mjp2 'a 'b [n] (f: a -> b) (x: a) (y': [n]b) : (b, [n]a) =
intrinsics.mjp2 f x y'

-- | Jacobian-Vector Product ("forward mode").
def jvp 'a 'b (f: a -> b) (x: a) (x': a) : b =
(jvp2 f x x').1
Expand All @@ -131,6 +171,16 @@ def jvp 'a 'b (f: a -> b) (x: a) (x': a) : b =
def vjp 'a 'b (f: a -> b) (x: a) (y': b) : a =
(vjp2 f x y').1

-- | Jacobian-Matrix Product. As `jvp`, but accepts a vector of seed values.
-- Semantically equivalent to mapping, but may be more efficient.
def jmp 'a 'b [n] (f: a -> b) (x: a) (x': [n]a) : [n]b =
(jmp2 f x x').1

-- | Matrix-Jacobian product. As `vjp`, but accepts a vector of seed values.
-- Semantically equivalent to mapping, but may be more efficient.
def mjp 'a 'b [n] (f: a -> b) (x: a) (y': [n]b) : [n]a =
(mjp2 f x y').1

-- | Provide custom reverse-mode adjoint code for a given function. This is
-- useful when the adjoint synthesised by AD is not as good as one that is known
-- analytically.
Expand All @@ -144,8 +194,7 @@ def vjp 'a 'b (f: a -> b) (x: a) (y': b) : a =
-- primal result of `with_vjp`, and some part is only used in `f'`.
--
-- **Beware:** if `f` uses any free variables, these will not be taken into
-- **account when computing the adjoint. Make these part of the argument
-- **instead.
-- account when computing the adjoint. Make these part of the argument instead.
def with_vjp 'a 'b (f: a -> b) (f': (res: b) -> (b_adj: b) -> a) (x: a) : b =
intrinsics.with_vjp f f' x

Expand Down
Loading
Loading