
Vectorised automatic differentiation #2471

Open
athas wants to merge 99 commits into master from ad-vec

athas commented May 29, 2026

I have been sitting on this for a while, but maybe it's time to get it finished and merged. It adds facilities for vectorised AD, exposed as the following two functions:

-- | As `jvp`, but accepts a vector of seed values. Semantically
-- equivalent to mapping, but may be more efficient.
def jvp_vec 'a 'b [n] (f: a -> b) (x: a) (x': [n]a) : [n]b =
  ...

-- | As `vjp`, but accepts a vector of seed values. Semantically
-- equivalent to mapping, but may be more efficient.
def vjp_vec 'a 'b [n] (f: a -> b) (x: a) (y': [n]b) : [n]a =
  ...

The main advantage is that it allows the primal computation to be amortised over multiple tangent/adjoint computations. I have forgotten what state I left it in, but all tests pass. In principle it is not so difficult to support: a core trick is that the transformation can always fall back to explicit looping.
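As a point of reference, the "semantically equivalent to mapping" part of the doc comments can be sketched with the existing scalar-seed `jvp` and `vjp`. This is only a minimal sketch of the intended semantics, not how the PR implements the functions; the `_ref` names and the `jvp_vec_ref_example` entry point are hypothetical and exist only for illustration.

```futhark
-- Reference semantics only, written with the existing `jvp` and `vjp`.
-- The `_ref` names are hypothetical and not part of this PR.

-- One `jvp` per tangent seed. Written this way, the primal computation
-- of `f x` is conceptually repeated for every seed.
def jvp_vec_ref 'a 'b [n] (f: a -> b) (x: a) (x': [n]a) : [n]b =
  map (jvp f x) x'

-- One `vjp` per adjoint seed, with the same caveat.
def vjp_vec_ref 'a 'b [n] (f: a -> b) (x: a) (y': [n]b) : [n]a =
  map (vjp f x) y'

-- Hypothetical usage: two tangent seeds for the squaring function.
-- The derivative of v * v at x is 2 * x, so the result is
-- [2 * x, 4 * x].
entry jvp_vec_ref_example (x: f64) : [2]f64 =
  jvp_vec_ref (\v -> v * v) x [1, 2]
```

The vectorised versions are meant to return the same values while sharing the primal computation across all `n` seeds, which is where the amortisation comes from.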

athas commented Jun 10, 2026

The last failing test is unrelated to AD, but must of course be fixed before this can be merged. I have not finished diagnosing or fixing it, but the failure occurs in the "subhistogram" case of code generation for SegHist in the multicore backend. I believe it may be related to multi-versioning in the kernel body, but I'm not sure. It took a while to reproduce because that case is only hit with the right combination of thread count and input size.

athas commented Jun 11, 2026

The program from this blog post does benefit from vectorised forward-mode AD (although reverse mode is still faster):

-- The function to approximate.
def f (x: f32) =
  if x == 0 then 0 else f32.exp (-1 / (x * x))

def poly_eval [d] (P: [d + 1]f32) (x: f32) =
  f32.sum (map2 (*) P (map (x **) (map f32.i64 (indices P))))

def N : i64 = 1000

def START : f32 = (-1)

def END : f32 = 1

def riemann_integral [d] (P: [d + 1]f32) =
  let step_size = (END - START) / f32.i64 N
  let g j =
    let x = START + f32.i64 j * step_size
    let delta = poly_eval P x - f x
    in delta * delta * step_size
  in f32.sum (tabulate N g)

def poly_init (d: i64) =
  tabulate (d + 1) (\i -> f32.i64 (i + 1) * (1 / (f32.i64 d + 1)))

entry fwd d =
  tabulate_2d (d + 1) (d + 1) (\i j -> f32.bool (j == i))
  |> map (jvp riemann_integral (poly_init d))

entry fwd_vec (gradlen: i64) d =
  let num_grads = (d + 1 + gradlen - 1) / gradlen
  let seeds gradstart =
    tabulate_2d gradlen
                (d + 1)
                (\i j -> f32.bool (j == i + gradstart))
  in map (seeds <-< (gradlen *)) (iota num_grads)
     |> map (#[unroll] jvp_vec riemann_integral (poly_init d))
     |> flatten
     |> take (d + 1)

entry rev d =
  vjp riemann_integral (poly_init d) 1

-- ==
-- entry: fwd rev
-- random input { 8i64 }
-- random input { 128i64 }

-- ==
-- entry: fwd_vec
-- random input { 1i64 8i64 }
-- random input { 1i64 128i64 }
-- random input { 2i64 128i64 }
-- random input { 4i64 128i64 }
-- random input { 8i64 128i64 }
-- random input { 16i64 128i64 }
-- random input { 32i64 128i64 }
-- random input { 64i64 128i64 }
-- random input { 128i64 128i64 }

This means that I just need to decide on the best nomenclature for this feature, and then it is ready.

athas marked this pull request as ready for review June 11, 2026 18:38

athas commented Jun 12, 2026

I had a great idea, inspired by the Jax documentation: the surface-level functions should be mjp and jmp, for matrix-Jacobian-product and Jacobian-matrix-product, respectively. For that is exactly what it is! I still need a good term for the overall concept - currently I'm stuck on "vector AD".
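To spell out why the names fit, here is a short sketch of the underlying linear algebra. It assumes a function $f : \mathbb{R}^n \to \mathbb{R}^m$ with Jacobian $J$ at the point $x$; the symbols $V$, $W$ and $k$ are introduced here only for illustration.

```latex
% Forward mode: k tangent seeds v_1, ..., v_k stacked as the columns of V.
J \in \mathbb{R}^{m \times n}, \quad
V = \begin{bmatrix} v_1 & \cdots & v_k \end{bmatrix} \in \mathbb{R}^{n \times k}
\;\Longrightarrow\;
J V = \begin{bmatrix} J v_1 & \cdots & J v_k \end{bmatrix} \in \mathbb{R}^{m \times k}

% Reverse mode: k adjoint seeds w_1, ..., w_k stacked as the rows of W.
W = \begin{bmatrix} w_1^{\top} \\ \vdots \\ w_k^{\top} \end{bmatrix} \in \mathbb{R}^{k \times m}
\;\Longrightarrow\;
W J = \begin{bmatrix} w_1^{\top} J \\ \vdots \\ w_k^{\top} J \end{bmatrix} \in \mathbb{R}^{k \times n}
```

Each column of $JV$ is an ordinary `jvp` result and each row of $WJ$ is an ordinary `vjp` result. In the signatures at the top of this PR the seeds form the outermost array dimension in both cases, so the forward-mode result is laid out as the transpose of $JV$.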

athas requested a review from zfnmxt June 13, 2026 07:01