The last failing test is unrelated to AD, but must of course be fixed before this can be merged. I have not finished diagnosing or fixing it, but the failure occurs in the "subhistogram" case of code generation for SegHist in the multicore backend. I believe it may be related to multi-versioning in the kernel body, but I'm not sure. It took a while to reproduce because that case is only hit with the right combination of thread count and input size. |
|
The program from this blog post does benefit from vectorized forward-mode AD (although reverse mode is still faster):

-- The function to approximate.
def f (x: f32) =
  if x == 0 then 0 else f32.exp (-1 / (x * x))

def poly_eval [d] (P: [d + 1]f32) (x: f32) =
  f32.sum (map2 (*) P (map (x **) (map f32.i64 (indices P))))

def N : i64 = 1000
def START : f32 = (-1)
def END : f32 = 1

def riemann_integral [d] (P: [d + 1]f32) =
  let step_size = (END - START) / f32.i64 N
  let g j =
    let x = START + f32.i64 j * step_size
    let delta = poly_eval P x - f x
    in delta * delta * step_size
  in f32.sum (tabulate N g)

def poly_init (d: i64) =
  tabulate (d + 1) (\i -> f32.i64 (i + 1) * (1 / (f32.i64 d + 1)))

entry fwd d =
  tabulate_2d (d + 1) (d + 1) (\i j -> f32.bool (j == i))
  |> map (jvp riemann_integral (poly_init d))

entry fwd_vec (gradlen: i64) d =
  -- Number of batches of gradlen seeds needed to cover d + 1 inputs (rounded up).
  let num_grads = (d + 1 + gradlen - 1) / gradlen
  let seeds gradstart =
    tabulate_2d gradlen
                (d + 1)
                (\i j -> f32.bool (j == i + gradstart))
  in map (seeds <-< (gradlen *)) (iota num_grads)
     |> map (#[unroll] jmp riemann_integral (poly_init d))
     |> flatten
     |> take d

entry rev d =
  vjp riemann_integral (poly_init d) 1

-- ==
-- entry: fwd rev
-- random input { 8i64 }
-- random input { 128i64 }

-- ==
-- entry: fwd_vec
-- random input { 1i64 8i64 }
-- random input { 1i64 128i64 }
-- random input { 2i64 128i64 }
-- random input { 4i64 128i64 }
-- random input { 8i64 128i64 }
-- random input { 16i64 128i64 }
-- random input { 32i64 128i64 }
-- random input { 64i64 128i64 }
-- random input { 128i64 128i64 }

This means that I just need to decide on the best nomenclature for this feature, and then it is ready. |
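For readers who want to see what the seed matrices in `fwd_vec` look like, here is a rough plain-Python sketch of the batching arithmetic. It is only an illustration, not part of the PR: `seed_batches` is a made-up name, and it assumes the intent is to cover all `d + 1` basis directions in batches of `gradlen` rows, with all-zero rows as padding in the last batch.

```python
def seed_batches(d, gradlen):
    """Split the (d+1) x (d+1) identity matrix into batches of gradlen seed rows.

    Hypothetical helper mirroring the `seeds` function in `fwd_vec`.
    """
    n = d + 1                                 # number of inputs (coefficients)
    num_grads = (n + gradlen - 1) // gradlen  # ceil(n / gradlen)
    batches = []
    for b in range(num_grads):
        gradstart = b * gradlen
        # Row i of this batch is the basis vector for input i + gradstart.
        # Rows that would run past the last input are all zeros (padding).
        batch = [[1.0 if j == i + gradstart else 0.0 for j in range(n)]
                 for i in range(gradlen)]
        batches.append(batch)
    return batches
```

With `d = 4` and `gradlen = 2` this gives three batches of two rows each, where the last row of the last batch is padding. Flattening the batches and dropping the padding recovers the identity matrix that `fwd` uses one row at a time.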
|
I had a great idea, inspired by the Jax documentation: the surface-level functions should be |
I have been sitting on this for a while, but maybe it's time to get it finished and merged. It adds facilities for vectorised AD, exposed as the following two functions:
The main advantage is that it allows amortisation of the primal computation over multiple tangent/adjoint computations. I have forgotten exactly what state I left it in, but all tests pass. It is actually not so difficult to support in principle: a core trick is that the transformation can always fall back to explicit looping.
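To make the amortisation point concrete, here is a minimal plain-Python sketch of forward mode where each value carries a whole batch of tangents. It is an illustration under my own assumptions, not how the compiler transformation is implemented: `Dual`, `dexp`, `jvp_batch` and the call counter are all hypothetical names. The expensive primal (`math.exp`) is evaluated once per operation, and the tangents are handled by an explicit loop, which is the fallback mentioned above.

```python
import math

COUNTS = {"exp": 0}  # counts primal evaluations of exp, for illustration only


class Dual:
    """A primal value paired with a batch of tangents, one per seed."""

    def __init__(self, primal, tangents):
        self.primal = primal
        self.tangents = list(tangents)

    def __add__(self, other):
        return Dual(self.primal + other.primal,
                    [a + b for a, b in zip(self.tangents, other.tangents)])

    def __mul__(self, other):
        # Product rule, applied to every tangent in the batch.
        return Dual(self.primal * other.primal,
                    [a * other.primal + self.primal * b
                     for a, b in zip(self.tangents, other.tangents)])


def dexp(x):
    COUNTS["exp"] += 1
    e = math.exp(x.primal)                        # primal computed once
    return Dual(e, [e * t for t in x.tangents])   # explicit loop over tangents


def f(x, y):
    return dexp(x * y) + x


def jvp_batch(x, y, seeds):
    """Directional derivatives of f at (x, y) for every (dx, dy) in seeds."""
    out = f(Dual(x, [s[0] for s in seeds]), Dual(y, [s[1] for s in seeds]))
    return out.tangents
```

Calling `jvp_batch(1.0, 2.0, [(1.0, 0.0), (0.0, 1.0)])` yields both partial derivatives with a single `exp` evaluation, whereas two separate single-seed calls evaluate it twice.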