diff --git a/CHANGELOG.md b/CHANGELOG.md index 2796eb37d9..6c8e246dd5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. uses the hardware support for `f16`, similarly to the CUDA backend. Implemented by Jérôme Wagner. (#2470) +* Vector AD, exposed through the functions `jmp` and `mjp`. + * All opaque values available over the C API can now be decomposed into their constituents. diff --git a/futhark.cabal b/futhark.cabal index e76bd4f53b..44fc90c85c 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -116,7 +116,9 @@ library Futhark.Actions Futhark.AD.Derivatives Futhark.AD.Fwd + Futhark.AD.Shared Futhark.AD.Rev + Futhark.AD.Rev.Acc Futhark.AD.Rev.Loop Futhark.AD.Rev.Hist Futhark.AD.Rev.Map diff --git a/prelude/ad.fut b/prelude/ad.fut index 23f128c55e..c20c93e6a5 100644 --- a/prelude/ad.fut +++ b/prelude/ad.fut @@ -14,104 +14,130 @@ -- -- Futhark's AD support includes the following: -- --- * Differentiation operators for forward-mode (`jvp`) and reverse-mode --- (`vjp`). +-- * Differential operators for forward-mode (`jvp`@term) and reverse-mode +-- (`vjp`@term). -- --- * Arbitrary control flow in differentiable code. +-- * Almost arbitrary control flow in differentiable code (some limitations +-- apply when using GPU backends, see below). -- -- * Higher order derivatives by nesting differentiation operators, including -- arbitrary mixing of forward- and reverse mode (although using multiple -- rounds of reverse mode is rarely useful and often slow). -- --- * Custom derivatives (`with_vjp`). +-- * Custom derivatives (`with_vjp`@term). +-- +-- * Vector AD (`mjp`@term, `jmp`@term), sometimes also known as "batched" or +-- "multi-directional" AD. -- -- * Checkpointing of sequential loops. 
-- -- ## Jacobians -- --- For a differentiable function *f* whose input comprise *n* scalars --- and whose output comprises *m* scalars, the --- [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant) --- for a given input point is an *m* by *n* matrix of scalars that --- each represent a [partial --- derivatives](https://en.wikipedia.org/wiki/Partial_derivative). --- Intuitively, position *(i,j)* of the Jacobian describes how --- sensitive output *i* is to input *j*. The notion of Jacobian --- generalises to functions that accept or produce compound structures --- such as arrays, records, sums, and so on, simply by "flattening --- out" the values and considering only their constituent scalars. --- --- Computing the full Jacobian is usually costly and sometimes not --- necessary, and it is not part of the AD facility provided by --- Futhark. Instead it is possible to parts of the Jacobian. --- --- We can take the product of an an *m* by *n* Jacobian with an --- *n*-element *tangent vector* to produce an *m*-element vector --- (*Jacobian-vector product*). Such a product can be computed in a --- single (augmented) execution of the function *f*, and by choosing --- the tangent vector appropriately we can use this to compute the --- full Jacobian. This is provided by the function `jvp`. +-- For a differentiable function *f* whose input comprises *n* scalars and whose +-- output comprises *m* scalars, the +-- [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant) for +-- a given input point is an *m* by *n* matrix of scalars that each represent a +-- [partial derivative](https://en.wikipedia.org/wiki/Partial_derivative). +-- Intuitively, position *(i,j)* of the Jacobian describes how sensitive output +-- *i* is to input *j*. 
The notion of Jacobian generalises to functions that +-- accept or produce compound structures such as arrays, records, sums, and so +-- on, simply by "flattening out" the values and considering only their +-- constituent scalars. +-- +-- Computing the full Jacobian is usually costly and sometimes not necessary, +-- and it is not part of the AD facility provided by Futhark. Instead it is +-- possible to compute parts of the Jacobian, which semantically (but not +-- operationally) can be seen as multiplying the Jacobian with a vector, +-- producing a vector. However, it is important to understand that the full +-- Jacobian is *not* constructed as an intermediate step. +-- +-- We can take the product of an *m* by *n* Jacobian with an *n*-element +-- *tangent vector* to produce an *m*-element vector (*Jacobian-vector +-- product*). Such a product can be computed in a single (augmented) execution +-- of the function *f*. This is provided by the function `jvp`. -- -- We can also take the product of an *m*-element vector *cotangent -- vector* with the *m* by *n* Jacobian to produce an *n*-element --- vector (*Vector-Jacobian product*). This too can be computed in a +-- vector (*vector-Jacobian product*). This too can be computed in a -- single execution of *f*, with `vjp`. -- --- We can use the `jvp` function to produce a *column* of the full --- Jacobian, and `vjp` to produce a *row*. Which is superior for a --- given situation depends on whether the function has more inputs or --- outputs. +-- A tangent has the same structure as the input and represents a direction in +-- input space. A cotangent has the same structure as the output and represents +-- sensitivities flowing backwards through the computation. +-- +-- Using an elementary (co-)tangent vector, we can use the `jvp` function to +-- produce a *column* of the full Jacobian, and `vjp` to produce a *row*, with +-- the nonzero element of the vector identifying which column or row is +-- extracted. 
Which is superior for a given situation depends on whether the +-- function has more inputs or outputs. -- --- You can freely nest `vjp` and `jvp` to compute higher-order --- derivatives. +-- We can freely nest `vjp` and `jvp` to compute higher-order derivatives. -- -- ## Efficiency -- -- Both `jvp` and `vjp` work by transforming the program to carry -- along extra information associated with each scalar value. -- --- In the case of `jvp`, this extra information takes the form of an --- additional scalar representing the tangent, which is then --- propagated in each scalar computation using essentially the [chain --- rule](https://en.wikipedia.org/wiki/Chain_rule). Therefore, `jvp` --- has a memory overhead of approximately *2x*, and a computational --- overhead of slightly more, but usually less than *4x*. --- --- In the case of `vjp`, since our starting point is a *cotangent*, --- the function is essentially first run forward, then backwards (the --- *return sweep*) to propagate the cotangent. During the return --- sweep, all intermediate results computed during the forward sweep --- must still be available, and must therefore be stored in memory --- during the forward sweep. This means that the memory usage of `vjp` --- is proportional to the number of sequential steps of the original --- function (essentially turning *time* into *space*). The compiler --- does a nontrivial amount of optimisation to ameliorate this --- overhead (see [AD for an Array Language with Nested --- Parallelism](https://futhark-lang.org/publications/sc22-ad.pdf)), --- but it can still be substantial for programs with deep sequential --- loops. +-- In the case of `jvp` ("forward mode", or "tangent mode"), this extra +-- information takes the form of an additional scalar representing the tangent, +-- which is then propagated in each scalar computation using essentially the +-- [chain rule](https://en.wikipedia.org/wiki/Chain_rule). 
Therefore, `jvp` has +-- a memory overhead of approximately *2x*, and a computational overhead of +-- slightly more, but usually less than *4x*. +-- +-- In the case of `vjp` ("reverse mode" or "adjoint mode"), since our starting +-- point is a *cotangent*, the function is essentially first run forward, then +-- backwards (the *return sweep*) to propagate the cotangent. During the return +-- sweep, all intermediate results computed during the forward sweep must still +-- be available, and must therefore be stored in memory during the forward sweep +-- (this is called "the tape"). This means that the memory usage of `vjp` is +-- proportional to the number of sequential steps of the original function +-- (essentially turning *time* into *space*). The compiler does a nontrivial +-- amount of optimisation to ameliorate this overhead (see [AD for an Array +-- Language with Nested +-- Parallelism](https://futhark-lang.org/publications/sc22-ad.pdf)), but it can +-- still be substantial for programs with deep sequential loops. +-- +-- Nesting `vjp`, understood as applying `vjp` to the result of `vjp`, is +-- usually a bad idea, as the code structure produced by `vjp` is fairly +-- complicated, due to the tape management. Passing the output of `jvp` to +-- `vjp`, or the other way, is however fine. As a rule of thumb, whenever you +-- stack multiple differential operators, make sure only one of them is `vjp` or +-- a related operator. +-- +-- When using vector AD (`mjp`@term/`jmp`@term), each scalar is associated with +-- a vector of tangents or cotangents, and the space overhead for storing these +-- is therefore multiplied by the vector size. However, in the case of `vjp`, +-- the intermediate results are only stored once. It varies on a case-by-case +-- basis whether vector AD is faster than using `map` on top of +-- `vjp`@term/`jvp`@term. 
Vector AD essentially converts propagation of +-- (co-)tangents from scalar to array operations, which can have a significant +-- impact on memory accesses, depending on how the compiler manages to optimise +-- the resulting code. It is hard to predict whether this offsets the reduction +-- in primal work. If the vector size is a constant, and the `#[unroll]` +-- attribute is put on the AD operator, then the vectors become unrolled (turned +-- into tuples, essentially), although this should only be done when the vector +-- size is quite small, as the increase in code size is substantial. -- -- ## Differentiable functions -- --- AD only gives meaningful results for differentiable functions. The --- Futhark type system does not distinguish differentiable or --- non-differentiable operations. As a rule of thumb, a function is --- differentiable if its results are computed using a composition of --- primitive floating-point operations, without ever converting to or --- from integers. +-- AD only gives meaningful results for differentiable functions. The Futhark +-- type system does not distinguish differentiable from non-differentiable +-- operations. As a rule of thumb, a function is differentiable if its results +-- are computed using a composition of primitive floating-point operations, +-- without ever converting to or from integers. Most functions will also have +-- discontinuities around values that influence control flow. -- --- Note that a function whose input or output is a sum type with more --- than one constructor is *not* differentiable (or at least the --- sum-typed part is not). This is because the choice of constructor --- is not a continuous quantity. +-- Note that a function whose input or output is a sum type with more than one +-- constructor is *not* differentiable (or at least the sum-typed part is not). +-- This is because the choice of constructor is not a continuous quantity. -- -- ## Limitations -- --- `jvp` is expected to work in all cases. 
`vjp` has limitations when --- using the GPU backends similar to those for irregular flattening. --- Specifically, you should avoid structures with variant sizes, such --- as loops that carry an array that changes size through the --- execution of the loop. +-- `jvp` is expected to work in all cases. `vjp` has limitations when using the +-- GPU backends similar to those for irregular flattening. Specifically, you +-- should avoid structures with variant sizes, such as loops that carry an array +-- that changes size through the execution of the loop. -- | Jacobian-Vector Product ("forward mode"), producing also the -- primal result as the first element of the result tuple. @@ -123,6 +149,20 @@ def jvp2 'a 'b (f: a -> b) (x: a) (x': a) : (b, b) = def vjp2 'a 'b (f: a -> b) (x: a) (y': b) : (b, a) = intrinsics.vjp2 f x y' +-- | Jacobian-Matrix Product, returning also the primal result. As `jvp2`, but +-- accepts an array of seed vectors (hence "matrix", although transposed). +-- Semantically equivalent to mapping, but may be more efficient. If used with +-- `#[unroll]`, tangent calculations are unrolled when possible. +def jmp2 'a 'b [n] (f: a -> b) (x: a) (x': [n]a) : (b, [n]b) = + intrinsics.jmp2 f x x' + +-- | Matrix-Jacobian Product, returning also the primal result. As `vjp2`, but +-- accepts an array of seed vectors (hence "matrix"). Semantically equivalent to +-- mapping, but may be more efficient. If used with `#[unroll]`, adjoint +-- calculations are unrolled when possible. +def mjp2 'a 'b [n] (f: a -> b) (x: a) (y': [n]b) : (b, [n]a) = + intrinsics.mjp2 f x y' + -- | Jacobian-Vector Product ("forward mode"). def jvp 'a 'b (f: a -> b) (x: a) (x': a) : b = (jvp2 f x x').1 @@ -131,6 +171,16 @@ def jvp 'a 'b (f: a -> b) (x: a) (x': a) : b = def vjp 'a 'b (f: a -> b) (x: a) (y': b) : a = (vjp2 f x y').1 +-- | Jacobian-Matrix Product. As `jvp`, but accepts a vector of seed values. +-- Semantically equivalent to mapping, but may be more efficient. 
+def jmp 'a 'b [n] (f: a -> b) (x: a) (x': [n]a) : [n]b = + (jmp2 f x x').1 + +-- | Matrix-Jacobian product. As `vjp`, but accepts a vector of seed values. +-- Semantically equivalent to mapping, but may be more efficient. +def mjp 'a 'b [n] (f: a -> b) (x: a) (y': [n]b) : [n]a = + (mjp2 f x y').1 + -- | Provide custom reverse-mode adjoint code for a given function. This is -- useful when the adjoint synthesised by AD is not as good as one that is known -- analytically. @@ -144,8 +194,7 @@ def vjp 'a 'b (f: a -> b) (x: a) (y': b) : a = -- primal result of `with_vjp`, and some part is only used in `f'`. -- -- **Beware:** if `f` uses any free variables, these will not be taken into --- **account when computing the adjoint. Make these part of the argument --- **instead. +-- account when computing the adjoint. Make these part of the argument instead. def with_vjp 'a 'b (f: a -> b) (f': (res: b) -> (b_adj: b) -> a) (x: a) : b = intrinsics.with_vjp f f' x diff --git a/src/Futhark/AD/Fwd.hs b/src/Futhark/AD/Fwd.hs index a0e611321c..f374f12b76 100644 --- a/src/Futhark/AD/Fwd.hs +++ b/src/Futhark/AD/Fwd.hs @@ -3,21 +3,22 @@ module Futhark.AD.Fwd (fwdJVP) where import Control.Monad -import Control.Monad.RWS.Strict +import Control.Monad.Identity +import Control.Monad.Reader import Control.Monad.State.Strict -import Data.Bifunctor (second) -import Data.List (transpose) +import Data.Bifunctor (bimap, second) +import Data.Foldable +import Data.Functor.Product import Data.List.NonEmpty (NonEmpty (..)) import Data.Map qualified as M +import Data.Tuple (Solo (..), getSolo) import Futhark.AD.Derivatives +import Futhark.AD.Shared import Futhark.Analysis.PrimExp.Convert import Futhark.Builder -import Futhark.Construct import Futhark.IR.SOACS - -zeroTan :: Type -> ADM SubExp -zeroTan (Prim t) = pure $ constant $ blankPrimValue t -zeroTan t = error $ "zeroTan on non-primitive type: " ++ prettyString t +import Futhark.Tools +import Futhark.Util (interleave, splitAt3, unterleave) zeroExp 
:: Type -> Exp SOACS zeroExp (Prim pt) = @@ -26,11 +27,18 @@ zeroExp (Array pt shape _) = BasicOp $ Replicate shape $ Constant $ blankPrimValue pt zeroExp t = error $ "zeroExp: " ++ show t -tanType :: TypeBase s u -> ADM (TypeBase s u) +tanType :: (ArrayShape s, Monoid u) => TypeBase s u -> ADM (TypeBase s u) tanType (Acc acc ispace ts u) = do - ts_tan <- mapM tanType ts - pure $ Acc acc ispace (ts ++ ts_tan) u -tanType t = pure t + acc_tan <- tangent acc + tan_shape <- askShape + pure $ Acc acc_tan (tan_shape <> ispace) ts u +tanType t = do + shape <- askShape + pure $ arrayOf (Prim (elemType t)) (shape `prependShape` arrayShape t) u + where + u = case t of + Array _ _ u' -> u' + _ -> mempty slocal' :: ADM a -> ADM a slocal' = slocal id @@ -48,12 +56,18 @@ data RState = RState stateNameSource :: VNameSource } -newtype ADM a = ADM (BuilderT SOACS (State RState) a) +data FEnv = FEnv + { envTanShape :: Shape, + envAttrs :: Attrs + } + +newtype ADM a = ADM (BuilderT SOACS (ReaderT FEnv (State RState)) a) deriving ( Functor, Applicative, Monad, MonadState RState, + MonadReader FEnv, MonadFreshNames, HasScope SOACS, LocalScope SOACS @@ -72,12 +86,18 @@ instance MonadFreshNames (State RState) where getNameSource = gets stateNameSource putNameSource src = modify (\env -> env {stateNameSource = src}) -runADM :: (MonadFreshNames m) => ADM a -> m a -runADM (ADM m) = +askShape :: ADM Shape +askShape = ADM $ lift $ asks envTanShape + +runADM :: (MonadFreshNames m) => Shape -> Attrs -> ADM a -> m a +runADM shape attrs (ADM m) = modifyNameSource $ \vn -> second stateNameSource $ runState - (fst <$> runBuilderT m mempty) + ( runReaderT + (fst <$> runBuilderT m mempty) + (FEnv shape attrs) + ) (RState mempty vn) tanVName :: VName -> ADM VName @@ -89,27 +109,20 @@ insertTan v v' = class TanBuilder a where newTan :: a -> ADM a - bundleNew :: a -> ADM [a] + bundleNew :: a -> ADM (a, a) bundleNewList :: (TanBuilder a) => [a] -> ADM [a] -bundleNewList = fmap mconcat . 
mapM bundleNew - -instance TanBuilder (PatElem (TypeBase s u)) where - newTan (PatElem p t) - | isAcc t = do - insertTan p p - t' <- tanType t - pure $ PatElem p t' - | otherwise = do - p' <- tanVName p - insertTan p p' - t' <- tanType t - pure $ PatElem p' t' - bundleNew pe@(PatElem _ t) = do +bundleNewList = fmap (uncurry interleave . unzip) . mapM bundleNew + +instance (ArrayShape s, Monoid u) => TanBuilder (PatElem (TypeBase s u)) where + newTan (PatElem p t) = do + p' <- tanVName p + insertTan p p' + t' <- tanType t + pure $ PatElem p' t' + bundleNew pe = do pe' <- newTan pe - if isAcc t - then pure [pe'] - else pure [pe, pe'] + pure (pe, pe') newTanPat :: (TanBuilder (PatElem t)) => Pat t -> ADM (Pat t) newTanPat (Pat pes) = Pat <$> mapM newTan pes @@ -117,41 +130,33 @@ newTanPat (Pat pes) = Pat <$> mapM newTan pes bundleNewPat :: (TanBuilder (PatElem t)) => Pat t -> ADM (Pat t) bundleNewPat (Pat pes) = Pat <$> bundleNewList pes -instance TanBuilder (Param (TypeBase s u)) where +instance (ArrayShape s, Monoid u) => TanBuilder (Param (TypeBase s u)) where newTan (Param _ p t) = do PatElem p' t' <- newTan $ PatElem p t pure $ Param mempty p' t' - bundleNew param@(Param _ _ (Prim Unit)) = - pure [param] - bundleNew param@(Param _ _ t) = do + bundleNew param = do param' <- newTan param - if isAcc t - then pure [param'] - else pure [param, param'] + pure (param, param') -instance (Tangent a) => TanBuilder (Param (TypeBase s u), a) where +instance (TanBuilder a, Tangent b) => TanBuilder (a, b) where newTan (p, x) = (,) <$> newTan p <*> tangent x bundleNew (p, x) = do - b <- bundleNew p + p' <- newTan p x_tan <- tangent x - pure $ zip b [x, x_tan] + pure ((p, x), (p', x_tan)) class Tangent a where tangent :: a -> ADM a - bundleTan :: a -> ADM [a] + bundleTan :: a -> ADM (a, a) -instance Tangent (TypeBase s u) where +instance (ArrayShape s, Monoid u) => Tangent (TypeBase s u) where tangent = tanType - bundleTan t - | isAcc t = do - t' <- tangent t - pure [t'] - | 
otherwise = do - t' <- tangent t - pure [t, t'] + bundleTan t = do + t' <- tangent t + pure (t, t') bundleTangents :: (Tangent a) => [a] -> ADM [a] -bundleTangents = (mconcat <$>) . mapM bundleTan +bundleTangents = fmap (uncurry interleave . unzip) . mapM bundleTan instance Tangent VName where tangent v = do @@ -160,26 +165,120 @@ instance Tangent VName where Just v_tan -> pure v_tan Nothing -> do t <- lookupType v - letExp (baseName v <> "_implicit_tan") $ zeroExp t + when (isAcc t) $ + error $ + "Missing tangent for accumulator " <> prettyString v + tan_shape <- askShape + letExp (baseName v <> "_implicit_tan") $ zeroExp $ t `arrayOfShape` tan_shape bundleTan v = do - t <- lookupType v - if isAcc t - then pure [v] - else do - v_tan <- tangent v - pure [v, v_tan] + v_tan <- tangent v + pure (v, v_tan) instance Tangent SubExp where - tangent (Constant c) = zeroTan $ Prim $ primValueType c + tangent (Constant c) = do + tan_shape <- askShape + if tan_shape == mempty + then pure $ constant $ blankPrimValue pt + else letSubExp "const_implicit_tan" $ zeroExp $ Prim pt `arrayOfShape` tan_shape + where + pt = primValueType c tangent (Var v) = Var <$> tangent v bundleTan c@Constant {} = do c_tan <- tangent c - pure [c, c_tan] - bundleTan (Var v) = fmap Var <$> bundleTan v + pure (c, c_tan) + bundleTan (Var v) = bimap Var Var <$> bundleTan v instance Tangent SubExpRes where tangent (SubExpRes cs se) = SubExpRes cs <$> tangent se - bundleTan (SubExpRes cs se) = map (SubExpRes cs) <$> bundleTan se + bundleTan (SubExpRes cs se) = bimap (SubExpRes cs) (SubExpRes cs) <$> bundleTan se + +withTan :: + SubExp -> + (SubExp -> ADM (Exp SOACS)) -> + ADM (Exp SOACS) +withTan x f = do + shape <- askShape + x_tan <- tangent x + mapNest shape (MkSolo x_tan) (f . 
getSolo) + +withTansI :: + VName -> + [SubExp] -> + ([SubExp] -> VName -> [SubExp] -> ADM (Exp SOACS)) -> + ADM (Exp SOACS) +withTansI x ys f = do + shape <- askShape + x_tan <- tangent x + ys_tan <- mapM tangent ys + if shape == mempty + then f [] x_tan ys_tan + else do + let w = shapeSize 0 shape + ys_tan_vs <- mapM asVName ys_tan + iota_p <- newParam "iota_p" $ Prim int64 + x_tan_p <- newParam "x_tanp" . rowType =<< lookupType x_tan + ys_tan_ps <- mapM (newParam "y_tanp" . rowType <=< lookupType) ys_tan_vs + lam <- mkLambda (iota_p : x_tan_p : ys_tan_ps) $ do + fmap (subExpsRes . pure) . letSubExp "tan" + =<< f + [Var $ paramName iota_p] + (paramName x_tan_p) + (map (Var . paramName) ys_tan_ps) + iota_v <- letExp "iota" $ iota64 w + Op . Screma w (iota_v : x_tan : ys_tan_vs) <$> mapSOAC lam + +withTans :: + PrimType -> + SubExp -> + SubExp -> + (PrimExp VName -> PrimExp VName -> PrimExp VName) -> + ADM (Exp SOACS) +withTans t x y f = do + shape <- askShape + x_tan <- tangent x + y_tan <- tangent y + mapNest shape (Pair (Identity x_tan) (Identity y_tan)) $ \xy -> do + Pair (Identity x_tan_v) (Identity y_tan_v) <- traverse asVName xy + toExp $ f (LeafExp x_tan_v t) (LeafExp y_tan_v t) + +withAnyTans :: + (Traversable f) => + f SubExp -> + ([PrimExp VName] -> PrimExp VName) -> + ADM (Exp SOACS) +withAnyTans xs f = do + shape <- askShape + xs_tan <- traverse tangent xs + mapNest shape xs_tan $ \xs_tan' -> do + xs_tan'' <- forM xs_tan' $ \se -> do + ~(Prim t) <- subExpType se + pure $ primExpFromSubExp t se + toExp $ f $ toList xs_tan'' + +bindTanPat :: Pat Type -> StmAux () -> Exp SOACS -> ADM () +bindTanPat pat_tan aux e = do + attrs <- asks envAttrs + auxing aux . attributing attrs . 
letBind pat_tan $ e + +bindTan :: + Pat Type -> + StmAux () -> + SubExp -> + (SubExp -> ADM (Exp SOACS)) -> + ADM () +bindTan pat_tan aux x f = do + bindTanPat pat_tan aux =<< withTan x f + +bindTans :: + Pat Type -> + StmAux () -> + PrimType -> + SubExp -> + SubExp -> + (PrimExp VName -> PrimExp VName -> PrimExp VName) -> + ADM () +bindTans pat_tan aux t x y f = do + bindTanPat pat_tan aux =<< withTans t x y f basicFwd :: Pat Type -> StmAux () -> BasicOp -> ADM () basicFwd pat aux op = do @@ -192,136 +291,238 @@ basicFwd pat aux op = do se_tan <- tangent se addStm $ Let pat_tan aux $ BasicOp $ Opaque opaqueop se_tan ArrayLit ses t -> do + tan_shape <- askShape ses_tan <- mapM tangent ses - addStm $ Let pat_tan aux $ BasicOp $ ArrayLit ses_tan t + if tan_shape == mempty + then + addStm $ Let pat_tan aux $ BasicOp $ ArrayLit ses_tan t + else do + pat_tan_tr <- letExp "pat_tan_tr" $ BasicOp $ ArrayLit ses_tan $ t `arrayOfShape` tan_shape + pat_tan_tr_t <- lookupType pat_tan_tr + let perm = vecPerm tan_shape pat_tan_tr_t + addStm $ Let pat_tan aux $ BasicOp $ Rearrange pat_tan_tr perm UnOp unop x -> do let t = unOpType unop x_pe = primExpFromSubExp t x dx = pdUnOp unop x_pe - x_tan <- primExpFromSubExp t <$> tangent x - auxing aux $ letBindNames (patNames pat_tan) <=< toExp $ x_tan ~*~ dx + bindTan pat_tan aux x $ \x_tan -> + toExp $ primExpFromSubExp t x_tan ~*~ dx BinOp bop x y -> do let t = binOpType bop - x_tan <- primExpFromSubExp t <$> tangent x - y_tan <- primExpFromSubExp t <$> tangent y - let (wrt_x, wrt_y) = - pdBinOp bop (primExpFromSubExp t x) (primExpFromSubExp t y) - auxing aux $ - letBindNames (patNames pat_tan) <=< toExp $ - x_tan ~*~ wrt_x ~+~ y_tan ~*~ wrt_y - CmpOp {} -> - addStm $ Let pat_tan aux $ zeroExp $ Prim Bool - ConvOp cop x -> do - x_tan <- tangent x - addStm $ Let pat_tan aux $ BasicOp $ ConvOp cop x_tan + bindTans pat_tan aux t x y $ \x_tan y_tan -> + let (wrt_x, wrt_y) = + pdBinOp bop (primExpFromSubExp t x) (primExpFromSubExp t y) + in 
x_tan ~*~ wrt_x ~+~ y_tan ~*~ wrt_y + CmpOp {} -> do + tan_shape <- askShape + addStm $ Let pat_tan aux $ zeroExp $ Prim Bool `arrayOfShape` tan_shape + ConvOp cop x -> + bindTan pat_tan aux x $ \x_tan -> + pure $ BasicOp $ ConvOp cop x_tan Assert {} -> pure () Index arr slice -> do + dims <- shapeDims <$> askShape arr_tan <- tangent arr - addStm $ Let pat_tan aux $ BasicOp $ Index arr_tan slice + let slice' = Slice $ map sliceDim dims <> unSlice slice + addStm $ Let pat_tan aux $ BasicOp $ Index arr_tan slice' Update safety arr slice se -> do + dims <- shapeDims <$> askShape arr_tan <- tangent arr se_tan <- tangent se - addStm $ Let pat_tan aux $ BasicOp $ Update safety arr_tan slice se_tan + let slice' = Slice $ map sliceDim dims <> unSlice slice + addStm $ Let pat_tan aux $ BasicOp $ Update safety arr_tan slice' se_tan Concat d (arr :| arrs) w -> do + r <- shapeRank <$> askShape arr_tan <- tangent arr arrs_tans <- mapM tangent arrs - addStm $ Let pat_tan aux $ BasicOp $ Concat d (arr_tan :| arrs_tans) w + addStm $ Let pat_tan aux $ BasicOp $ Concat (d + r) (arr_tan :| arrs_tans) w Manifest arr ds -> do + r <- shapeRank <$> askShape arr_tan <- tangent arr - addStm $ Let pat_tan aux $ BasicOp $ Manifest arr_tan ds + addStm . Let pat_tan aux . BasicOp $ + Manifest arr_tan ([0 .. r - 1] ++ map (+ r) ds) Iota n _ _ it -> do - addStm $ Let pat_tan aux $ BasicOp $ Replicate (Shape [n]) (intConst it 0) - Replicate n x -> do - x_tan <- tangent x - addStm $ Let pat_tan aux $ BasicOp $ Replicate n x_tan - Scratch t shape -> - addStm $ Let pat_tan aux $ BasicOp $ Scratch t shape + shape <- askShape + addStm . Let pat_tan aux . 
BasicOp $ + Replicate (shape <> Shape [n]) (intConst it 0) + Replicate n x -> + bindTan pat_tan aux x $ \x_tan -> + pure $ BasicOp $ Replicate n x_tan + Scratch t shape -> do + tan_shape <- askShape + addStm $ Let pat_tan aux $ BasicOp $ Scratch t $ shapeDims tan_shape <> shape Reshape arr reshape -> do + shape <- askShape arr_tan <- tangent arr - addStm $ Let pat_tan aux $ BasicOp $ Reshape arr_tan reshape + addStm $ Let pat_tan aux $ BasicOp $ Reshape arr_tan (newshapeInner shape reshape) Rearrange arr perm -> do + r <- shapeRank <$> askShape arr_tan <- tangent arr - addStm $ Let pat_tan aux $ BasicOp $ Rearrange arr_tan perm + addStm . Let pat_tan aux . BasicOp $ + Rearrange arr_tan ([0 .. r - 1] <> map (+ r) perm) _ -> error $ "basicFwd: Unsupported op " ++ prettyString op fwdLambda :: Lambda SOACS -> ADM (Lambda SOACS) -fwdLambda l@(Lambda params ret body) = - Lambda <$> bundleNewList params <*> bundleTangents ret <*> inScopeOf l (fwdBody body) - -fwdStreamLambda :: Lambda SOACS -> ADM (Lambda SOACS) -fwdStreamLambda l@(Lambda params ret body) = - Lambda <$> ((take 1 params ++) <$> bundleNewList (drop 1 params)) <*> bundleTangents ret <*> inScopeOf l (fwdBody body) - -interleave :: [a] -> [a] -> [a] -interleave xs ys = concat $ transpose [xs, ys] - -zeroFromSubExp :: SubExp -> ADM VName -zeroFromSubExp (Constant c) = - letExp "zero" . BasicOp . SubExp . 
Constant $ - blankPrimValue (primValueType c) -zeroFromSubExp (Var v) = do - t <- lookupType v - letExp "zero" $ zeroExp t +fwdLambda (Lambda params _ body) = do + params' <- bundleNewList params + mkLambda params' $ bodyBind =<< fwdBody body + +fwdWithAccLambda :: [WithAccInput SOACS] -> Lambda SOACS -> ADM (Lambda SOACS) +fwdWithAccLambda inputs (Lambda params _ body) = do + let (cert_params, acc_params) = splitAt (length inputs) params + cert_params_tan <- replicateM (length inputs) $ newParam "acc_cert_tan" $ Prim Unit + acc_params_tan <- zipWithM mkAccParam (map paramName cert_params_tan) inputs + + mkLambda (cert_params <> cert_params_tan <> acc_params <> acc_params_tan) $ do + zipWithM_ + insertTan + (map paramName (cert_params <> acc_params)) + (map paramName (cert_params_tan <> acc_params_tan)) + bodyBind =<< fwdBody body + where + mkAccParam c (shape, arrs, _) = do + tan_shape <- askShape + ts <- map (stripArray (shapeRank shape)) <$> mapM lookupType arrs + newParam "acc_p_tan" $ Acc c (tan_shape <> shape) ts NoUniqueness + +fwdStreamLambda :: Int -> Lambda SOACS -> ADM (Lambda SOACS) +fwdStreamLambda num_accs (Lambda params _ body) = do + tan_shape <- askShape + let (chunk_params, acc_params, arr_params) = splitAt3 1 num_accs params + acc_params' <- bundleNewList acc_params + (arr_params', arr_params'_tan) <- mapAndUnzipM onArrParam arr_params + let params' = + chunk_params <> acc_params' <> interleave arr_params' arr_params'_tan + mkLambda params' $ do + zipWithM_ (trArrParamTan tan_shape) arr_params' arr_params'_tan + (acc_res, map_res) <- fmap (splitAt (num_accs * 2)) . bodyBind =<< fwdBody body + let (map_res_primal, map_res_tan) = unterleave map_res + map_res_tan' <- mapM (trMapResTan tan_shape) map_res_tan + pure $ acc_res <> interleave map_res_primal map_res_tan' + where + -- Array parameters need to be treated specially as the chunk parameter + -- must always be outermost. 
+ onArrParam p = do + shape <- askShape + (p', p_tan) <- bundleNew p + let perm = vecPerm shape $ paramType p_tan + pure (p', p_tan {paramDec = rearrangeType perm (paramType p_tan)}) + + -- Put the tangent shape back in the outermost position. + trArrParamTan tan_shape p p_tan = do + let perm = rearrangeInverse $ vecPerm tan_shape $ paramType p_tan + v <- + letExp (baseName (paramName p_tan)) . BasicOp $ + Rearrange (paramName p_tan) perm + insertTan (paramName p) v + + -- Put the chunk size back in the outermost position. + trMapResTan tan_shape (SubExpRes cs ~(Var v)) = do + v_t <- lookupType v + let perm = vecPerm tan_shape v_t + fmap varRes . certifying cs $ letExp (baseName v) . BasicOp $ Rearrange v perm + +pushTanShape :: VName -> ADM VName +pushTanShape v = do + tan_shape <- askShape + v_t <- lookupType v + if tan_shape == mempty || arrayShape v_t == tan_shape || isAcc v_t + then pure v + else do + let perm = vecPerm tan_shape v_t + letExp (baseName v <> "_tr") $ BasicOp $ Rearrange v perm + +soacInputsWithTangents :: [VName] -> ADM [VName] +soacInputsWithTangents xs = do + xs_tans <- mapM (pushTanShape <=< tangent) xs + pure $ interleave xs xs_tans + +soacResPat :: Int -> Int -> Pat Type -> ADM (Pat Type, [(Pat Type, VName)]) +soacResPat scan_res red_res (Pat pes) = do + pes_tan <- mapM newTan pes + bimap (Pat . interleave pes) mconcat . unzip <$> zipWithM tweakPatElem [0 ..] 
pes_tan + where + isRedRes i = i >= scan_res && i < scan_res + red_res + tweakPatElem i pe@(PatElem v v_t) = do + tan_shape <- askShape + if isRedRes i || tan_shape == mempty || arrayShape v_t == tan_shape || isAcc v_t + then pure (pe, []) + else do + let perm = vecPerm tan_shape v_t + v' <- newName v + pure (PatElem v' $ rearrangeType perm v_t, [(Pat [pe], v')]) fwdSOAC :: Pat Type -> StmAux () -> SOAC SOACS -> ADM () fwdSOAC pat aux (Screma size xs (ScremaForm f scs reds post_lam)) = do - pat' <- bundleNewPat pat - xs' <- bundleTangents xs + (pat', to_transpose) <- soacResPat (scanResults scs) (redResults reds) pat + xs' <- soacInputsWithTangents xs f' <- fwdLambda f scs' <- mapM fwdScan scs reds' <- mapM fwdRed reds post_lam' <- fwdLambda post_lam addStm $ Let pat' aux $ Op $ Screma size xs' $ ScremaForm f' scs' reds' post_lam' + tan_shape <- askShape + forM_ to_transpose $ \(rpat, v) -> do + v_t <- lookupType v + let perm = rearrangeInverse $ vecPerm tan_shape v_t + letBind rpat $ BasicOp $ Rearrange v perm where + zeroTans lam = + mapM (letSubExp "zero" . 
zeroExp <=< tanType) $ lambdaReturnType lam + fwdScan :: Scan SOACS -> ADM (Scan SOACS) fwdScan sc = do op' <- fwdLambda $ scanLambda sc - neutral_tans <- mapM zeroFromSubExp $ scanNeutral sc + neutral_tans <- zeroTans $ scanLambda sc pure $ Scan - { scanNeutral = scanNeutral sc `interleave` map Var neutral_tans, + { scanNeutral = scanNeutral sc `interleave` neutral_tans, scanLambda = op' } fwdRed :: Reduce SOACS -> ADM (Reduce SOACS) fwdRed red = do op' <- fwdLambda $ redLambda red - neutral_tans <- mapM zeroFromSubExp $ redNeutral red + neutral_tans <- zeroTans $ redLambda red pure $ Reduce { redComm = redComm red, redLambda = op', - redNeutral = redNeutral red `interleave` map Var neutral_tans + redNeutral = redNeutral red `interleave` neutral_tans } -fwdSOAC pat aux (Stream size xs nes lam) = do +fwdSOAC pat aux (Stream size xs accs lam) = do pat' <- bundleNewPat pat - lam' <- fwdStreamLambda lam - xs' <- bundleTangents xs - nes_tan <- mapM (fmap Var . zeroFromSubExp) nes - let nes' = interleave nes nes_tan - addStm $ Let pat' aux $ Op $ Stream size xs' nes' lam' + lam' <- fwdStreamLambda (length accs) lam + xs' <- soacInputsWithTangents xs + accs_tan <- mapM (letSubExp "zero" . zeroExp <=< tanType <=< subExpType) accs + let accs' = interleave accs accs_tan + addStm $ Let pat' aux $ Op $ Stream size xs' accs' lam' fwdSOAC pat aux (Hist w arrs ops bucket_fun) = do - pat' <- bundleNewPat pat + -- TODO: this is probably not very efficient in the vector case as we end up + -- with a dreadful update operator that involves arrays. 
+ (pat', to_transpose) <- soacResPat 0 0 pat ops' <- mapM fwdHist ops bucket_fun' <- fwdHistBucket bucket_fun - arrs' <- bundleTangents arrs + arrs' <- soacInputsWithTangents arrs addStm $ Let pat' aux $ Op $ Hist w arrs' ops' bucket_fun' + tan_shape <- askShape + forM_ to_transpose $ \(rpat, v) -> do + v_t <- lookupType v + let perm = rearrangeInverse $ vecPerm tan_shape v_t + letBind rpat $ BasicOp $ Rearrange v perm where n_indices = sum $ map (shapeRank . histShape) ops fwdBodyHist (Body _ stms res) = buildBody_ $ do mapM_ fwdStm stms let (res_is, res_vs) = splitAt n_indices res (res_is ++) <$> bundleTangents res_vs - fwdHistBucket l@(Lambda params ret body) = - let (r_is, r_vs) = splitAt n_indices ret - in Lambda - <$> bundleNewList params - <*> ((r_is ++) <$> bundleTangents r_vs) - <*> inScopeOf l (fwdBodyHist body) + fwdHistBucket (Lambda params _ body) = do + params' <- bundleNewList params + mkLambda params' $ bodyBind =<< fwdBodyHist body fwdHist :: HistOp SOACS -> ADM (HistOp SOACS) fwdHist (HistOp shape rf dest nes op) = do - dest' <- bundleTangents dest - nes_tan <- mapM (fmap Var . zeroFromSubExp) nes + dest' <- soacInputsWithTangents dest + nes_tan <- mapM (letSubExp "zero" . zeroExp <=< tanType) $ lambdaReturnType op op' <- fwdLambda op pure $ HistOp @@ -344,20 +545,17 @@ fwdSOAC _ _ VJP {} = fwdStm :: Stm SOACS -> ADM () fwdStm (Let pat aux (BasicOp (UpdateAcc safety acc i x))) = do - pat' <- bundleNewPat pat - x' <- bundleTangents x - acc_tan <- tangent acc - addStm $ Let pat' aux $ BasicOp $ UpdateAcc safety acc_tan i x' + pat_tan <- newTanPat pat + addStm $ Let pat aux $ BasicOp $ UpdateAcc safety acc i x + addStm . Let pat_tan aux <=< withTansI acc x $ \is acc_tan x_tan' -> do + pure $ BasicOp $ UpdateAcc safety acc_tan (is <> i) x_tan' fwdStm stm@(Let pat aux (BasicOp e)) = do -- XXX: this has to be too naive. 
- unless (any isAcc $ patTypes pat) $ - addStm stm + unless (any isAcc $ patTypes pat) $ addStm stm basicFwd pat aux e -fwdStm stm@(Let pat _ (Apply f args _ _)) +fwdStm stm@(Let pat aux (Apply f args _ _)) | Just (ret, argts) <- M.lookup f builtInFunctions = do addStm stm - arg_tans <- - zipWith primExpFromSubExp argts <$> mapM (tangent . fst) args pat_tan <- newTanPat pat let arg_pes = zipWith primExpFromSubExp argts (map fst args) case pdBuiltin f arg_pes of @@ -375,8 +573,10 @@ fwdStm stm@(Let pat _ (Apply f args _ _)) _ -> error $ "fwdStm.convertTo: " ++ prettyString (f, tt, e_t) where e_t = primExpType e - letBindNames (patNames pat_tan) - =<< toExp (foldl1 (~+~) $ zipWith (~*~) (map (convertTo ret) arg_tans) derivs) + + auxing aux . letBind pat_tan <=< withAnyTans (map fst args) $ + \arg_tans' -> + foldl1 (~+~) $ zipWith (~*~) (map (convertTo ret) arg_tans') derivs fwdStm (Let pat aux (Match ses cases defbody (MatchDec ret ifsort))) = do cases' <- slocal' $ mapM (traverse fwdBody) cases defbody' <- slocal' $ fwdBody defbody @@ -387,36 +587,35 @@ fwdStm (Let pat aux (Loop val_pats loop@(WhileLoop v) body)) = do val_pats' <- bundleNewList val_pats pat' <- bundleNewPat pat body' <- - localScope (scopeOfFParams (map fst val_pats) <> scopeOfLoopForm loop) . slocal' $ + localScope (scopeOfFParams (map fst val_pats') <> scopeOfLoopForm loop) . slocal' $ fwdBody body addStm $ Let pat' aux $ Loop val_pats' (WhileLoop v) body' fwdStm (Let pat aux (Loop val_pats loop@(ForLoop i it bound) body)) = do pat' <- bundleNewPat pat val_pats' <- bundleNewList val_pats body' <- - localScope (scopeOfFParams (map fst val_pats) <> scopeOfLoopForm loop) . slocal' $ + localScope (scopeOfFParams (map fst val_pats') <> scopeOfLoopForm loop) . 
slocal' $ fwdBody body addStm $ Let pat' aux $ Loop val_pats' (ForLoop i it bound) body' fwdStm (Let pat aux (WithAcc inputs lam)) = do - inputs' <- forM inputs $ \(shape, arrs, op) -> do + inputs_tan <- forM inputs $ \(shape, arrs, op) -> do arrs_tan <- mapM tangent arrs + tan_shape <- askShape op' <- case op of Nothing -> pure Nothing Just (op_lam, nes) -> do - nes_tan <- mapM (fmap Var . zeroFromSubExp) nes - op_lam' <- fwdLambda op_lam - case op_lam' of - Lambda ps ret body -> do - let op_lam'' = Lambda (removeIndexTans (shapeRank shape) ps) ret body - pure $ Just (op_lam'', interleave nes nes_tan) - pure (shape, arrs <> arrs_tan, op') + -- We assume that op_lam has unit partial derivatives (i.e., is some + -- kind of addition). This is the case for all WithAccs produced by VJP. + lams <- mapM addLambda $ lambdaReturnType op_lam + -- Horizontally fuse the lambdas to produce a single one. + idx_params <- replicateM (shapeRank shape) $ newParam "idx" $ Prim int64 + let (xs, ys) = bimap concat concat $ unzip $ map (splitAt 1 . lambdaParams) lams + op_lam' <- mkLambda (idx_params <> xs <> ys) $ mconcat <$> mapM (bodyBind . lambdaBody) lams + pure $ Just (op_lam', nes) + pure (tan_shape <> shape, arrs_tan, op') pat' <- bundleNewPat pat - lam' <- fwdLambda lam - addStm $ Let pat' aux $ WithAcc inputs' lam' - where - removeIndexTans 0 ps = ps - removeIndexTans i (p : _ : ps) = p : removeIndexTans (i - 1) ps - removeIndexTans _ ps = ps + lam' <- fwdWithAccLambda inputs lam + addStm $ Let pat' aux $ WithAcc (interleave inputs inputs_tan) lam' fwdStm (Let pat aux (Op soac)) = fwdSOAC pat aux soac fwdStm stm = error $ "unhandled forward mode AD for Stm: " ++ prettyString stm ++ "\n" ++ show stm @@ -431,10 +630,22 @@ fwdBodyTansLast (Body _ stms res) = buildBody_ $ do mapM_ fwdStm stms (res <>) <$> mapM tangent res -fwdJVP :: (MonadFreshNames m) => Scope SOACS -> Lambda SOACS -> m (Lambda SOACS) -fwdJVP scope l@(Lambda params ret body) = - runADM . localScope scope . 
inScopeOf l $ do +fwdJVP :: + (MonadFreshNames m) => + Scope SOACS -> + Shape -> + Attrs -> + Lambda SOACS -> + m (Lambda SOACS) +fwdJVP scope shape attrs (Lambda params _ body) = + runADM shape attrs . localScope scope $ do params_tan <- mapM newTan params - body_tan <- fwdBodyTansLast body - ret_tan <- mapM tangent ret - pure $ Lambda (params ++ params_tan) (ret <> ret_tan) body_tan + mkLambda (params <> params_tan) $ + bodyBind =<< fwdBodyTansLast body + +-- Note [Forward-Mode vector AD] +-- +-- A primal variable of type 't' has a tangent of type '[tan_shape]t', where +-- 'tan_shape' is the vector shape (which may be empty in the non-vector case). +-- This requires some care for SOACs, which always map across the outermost +-- dimension: basically we have to transpose the inputs and the outputs. diff --git a/src/Futhark/AD/Rev.hs b/src/Futhark/AD/Rev.hs index b79746fc6c..6da9a5165e 100644 --- a/src/Futhark/AD/Rev.hs +++ b/src/Futhark/AD/Rev.hs @@ -9,21 +9,22 @@ module Futhark.AD.Rev (revVJP) where import Control.Monad -import Control.Monad.Identity -import Data.List ((\\)) import Data.List.NonEmpty (NonEmpty (..)) import Data.Map qualified as M +import Data.Tuple import Futhark.AD.Derivatives +import Futhark.AD.Rev.Acc import Futhark.AD.Rev.Loop import Futhark.AD.Rev.Monad import Futhark.AD.Rev.SOAC +import Futhark.AD.Shared import Futhark.Analysis.PrimExp.Convert import Futhark.Builder import Futhark.IR.SOACS import Futhark.Tools import Futhark.Transform.Rename import Futhark.Transform.Substitute -import Futhark.Util (chunks, takeLast) +import Futhark.Util (takeLast) patName :: Pat Type -> ADM VName patName (Pat [pe]) = pure $ patElemName pe @@ -57,8 +58,12 @@ diffBasicOp pat aux e m = ConvOp op x -> do (_pat_v, pat_adj) <- commonBasicOp pat aux e m returnSweepCode $ do - contrib <- - letExp "contrib" $ BasicOp $ ConvOp (flipConvOp op) $ Var pat_adj + adj_shape <- askShape + + contrib <- letExp "convop_contrib" <=< mapNest adj_shape (MkSolo (Var pat_adj)) $ + 
\(MkSolo pat_adj') -> + pure $ BasicOp $ ConvOp (flipConvOp op) pat_adj' + updateSubExpAdj x contrib -- UnOp op x -> do @@ -66,11 +71,12 @@ diffBasicOp pat aux e m = returnSweepCode $ do let t = unOpType op - contrib <- do - let x_pe = primExpFromSubExp t x - pat_adj' = primExpFromSubExp t (Var pat_adj) - dx = pdUnOp op x_pe - letExp "contrib" <=< toExp $ pat_adj' ~*~ dx + + adj_shape <- askShape + + contrib <- letExp "unop_contrib" <=< mapNest adj_shape (MkSolo (Var pat_adj)) $ + \(MkSolo pat_adj') -> + toExp $ primExpFromSubExp t pat_adj' ~*~ pdUnOp op (primExpFromSubExp t x) updateSubExpAdj x contrib -- @@ -82,10 +88,20 @@ diffBasicOp pat aux e m = (wrt_x, wrt_y) = pdBinOp op (primExpFromSubExp t x) (primExpFromSubExp t y) - pat_adj' = primExpFromSubExp t $ Var pat_adj + adj_shape <- askShape + + adj_x <- letExp "binop_x_adj" + <=< mapNest adj_shape (MkSolo (Var pat_adj)) + $ \(MkSolo pat_adj') -> + let pat_adj'' = primExpFromSubExp t pat_adj' + in toExp $ pat_adj'' ~*~ wrt_x + + adj_y <- letExp "binop_y_adj" + <=< mapNest adj_shape (MkSolo (Var pat_adj)) + $ \(MkSolo pat_adj') -> + let pat_adj'' = primExpFromSubExp t pat_adj' + in toExp $ pat_adj'' ~*~ wrt_y - adj_x <- letExp "binop_x_adj" <=< toExp $ pat_adj' ~*~ wrt_x - adj_y <- letExp "binop_y_adj" <=< toExp $ pat_adj' ~*~ wrt_y updateSubExpAdj x adj_x updateSubExpAdj y adj_y -- @@ -127,10 +143,10 @@ diffBasicOp pat aux e m = -- Rearrange arr perm -> do (_pat_v, pat_adj) <- commonBasicOp pat aux e m + r <- shapeRank <$> askShape returnSweepCode $ - void $ - updateAdj arr <=< letExp "adj_rearrange" . BasicOp $ - Rearrange pat_adj (rearrangeInverse perm) + void . updateAdj arr <=< letExp "adj_rearrange" . BasicOp $ + Rearrange pat_adj ([0 .. 
r - 1] <> map (+ r) (rearrangeInverse perm)) -- Replicate (Shape []) (Var se) -> do (_pat_v, pat_adj) <- commonBasicOp pat aux e m @@ -189,34 +205,18 @@ diffBasicOp pat aux e m = Update safety arr slice v -> do (_pat_v, pat_adj) <- commonBasicOp pat aux e m returnSweepCode $ do - v_adj <- letExp "update_val_adj" $ BasicOp $ Index pat_adj slice + adj_shape <- askShape + let adj_slice = Slice $ map sliceDim (shapeDims adj_shape) ++ unSlice slice + v_adj <- letExp "update_val_adj" $ BasicOp $ Index pat_adj adj_slice v_adj_copy <- copyIfArray v_adj updateSubExpAdj v v_adj_copy - zeroes <- letSubExp "update_zero" . zeroExp =<< subExpType v + v_adj_t <- lookupType v_adj + zeroes <- letSubExp "update_zero" $ zeroExp v_adj_t void $ updateAdj arr - =<< letExp "update_src_adj" (BasicOp $ Update safety pat_adj slice zeroes) - -- See Note [Adjoints of accumulators] - UpdateAcc safety _ is vs -> do - addStm $ Let pat aux $ BasicOp e - m - pat_adjs <- mapM lookupAdjVal (patNames pat) - returnSweepCode $ do - forM_ (zip pat_adjs vs) $ \(adj, v) -> do - adj_t <- lookupType adj - let index_adj = pure $ BasicOp $ Index adj $ fullSlice adj_t $ map DimFix is - adj_i <- - letExp "updateacc_val_adj" =<< case safety of - Unsafe -> - index_adj - Safe -> - -- The primal UpdateAcc may be out-of-bounds, in which case - -- indexing the adjoint is dangerous. - eIf - (eShapeInBounds (arrayShape adj_t) (map eSubExp is)) - (eBody [index_adj]) - (eBody [pure $ zeroExp $ stripArray (length is) adj_t]) - updateSubExpAdj v adj_i + =<< letExp "update_src_adj" (BasicOp $ Update safety pat_adj adj_slice zeroes) + UpdateAcc safety acc is vs -> + diffUpdateAcc pat aux safety acc is vs m -- UserParam {} -> void $ commonBasicOp pat aux e m @@ -225,30 +225,10 @@ vjpOps :: VjpOps vjpOps = VjpOps { vjpLambda = diffLambda, - vjpStm = diffStm + vjpStm = diffStm, + vjpBody = diffBody } --- | Transform updates on accumulators matching the given certificates into --- updates that write provided zero values. 
-zeroOutUpdates :: [(VName, [SubExp])] -> Lambda SOACS -> Lambda SOACS -zeroOutUpdates certs_to_zeroes lam = lam {lambdaBody = onBody $ lambdaBody lam} - where - onExp = runIdentity . mapExpM mapper - where - mapper = - (identityMapper :: (Monad m) => Mapper SOACS SOACS m) - { mapOnOp = traverseSOACStms (\_ stms -> pure $ onStms stms), - mapOnBody = \_ body -> pure $ onBody body - } - onStms = fmap onStm - onStm (Let (Pat [pe]) aux (BasicOp (UpdateAcc safety acc is _))) - | Acc c _ _ _ <- patElemType pe, - Just zero <- lookup c certs_to_zeroes = - Let (Pat [pe]) aux (BasicOp (UpdateAcc safety acc is zero)) - onStm (Let pat aux e) = Let pat aux $ onExp e - - onBody body = body {bodyStms = onStms $ bodyStms body} - diffStm :: Stm SOACS -> ADM () -> ADM () diffStm (Let pat aux (BasicOp e)) m = diffBasicOp pat aux e m @@ -259,7 +239,6 @@ diffStm stm@(Let pat _ (Apply f args _ _)) m pat_adj <- lookupAdjVal =<< patName pat let arg_pes = zipWith primExpFromSubExp argts (map fst args) - pat_adj' = primExpFromSubExp ret (Var pat_adj) convert ft tt | ft == tt = id convert (IntType ft) (IntType tt) = ConvOpExp (SExt ft tt) @@ -268,13 +247,17 @@ diffStm stm@(Let pat _ (Apply f args _ _)) m convert (FloatType ft) Bool = ConvOpExp (FToB ft) convert ft tt = error $ "diffStm.convert: " ++ prettyString (f, ft, tt) + adj_shape <- askShape + contribs <- case pdBuiltin f arg_pes of Nothing -> error $ "No partial derivative defined for builtin function: " ++ prettyString f Just derivs -> forM (zip derivs argts) $ \(deriv, argt) -> - letExp "contrib" <=< toExp . 
convert ret argt $ pat_adj' ~*~ deriv + letExp "apply_contrib" <=< mapNest adj_shape (MkSolo (Var pat_adj)) $ + \(MkSolo pat_adj') -> + toExp $ convert ret argt $ primExpFromSubExp ret pat_adj' ~*~ deriv zipWithM_ updateSubExpAdj (map fst args) contribs diffStm stm@(Let pat _ (Match ses cases defbody _)) m = do @@ -311,35 +294,8 @@ diffStm (Let pat aux (Op soac)) m = diffStm (Let pat aux loop@Loop {}) m = diffLoop diffStms pat aux loop m -- See Note [Adjoints of accumulators] -diffStm stm@(Let pat _aux (WithAcc inputs lam)) m = do - addStm stm - m - returnSweepCode $ do - adjs <- mapM lookupAdj $ patNames pat - lam' <- renameLambda lam - free_vars <- filterM isActive $ namesToList $ freeIn lam' - free_accs <- filterM (fmap isAcc . lookupType) free_vars - let free_vars' = free_vars \\ free_accs - lam'' <- diffLambda' adjs free_vars' lam' - (inputs_zeroes, inputs') <- - unzip <$> zipWithM renameInputLambda (chunks lengths adjs) inputs - let certs = map paramName $ take (length inputs) $ lambdaParams lam'' - free_adjs <- letTupExp "with_acc_contrib" $ WithAcc inputs' $ zeroOutUpdates (zip certs inputs_zeroes) lam'' - zipWithM_ insAdj (arrs <> free_vars') free_adjs - where - lengths = map (\(_, as, _) -> length as) inputs - arrs = concatMap (\(_, as, _) -> as) inputs - renameInputLambda as_adj (shape, as, _) = do - nes_ts <- mapM (fmap (stripArray (shapeRank shape)) . 
lookupType) as - zeroes <- mapM (zeroArray mempty) nes_ts - as' <- mapM adjVal as_adj - pure (map Var zeroes, (shape, as', Nothing)) - diffLambda' res_adjs get_adjs_for (Lambda params ts body) = do - localScope (scopeOfLParams params) $ do - Body () stms res <- diffBody res_adjs get_adjs_for body - let body' = Body () stms $ take (length inputs) res <> takeLast (length get_adjs_for) res - ts' <- mapM lookupType get_adjs_for - pure $ Lambda params (take (length inputs) ts <> ts') body' +diffStm (Let pat aux (WithAcc inputs lam)) m = + diffWithAcc vjpOps pat aux inputs lam m diffStm stm _ = error $ "diffStm unhandled:\n" ++ prettyString stm diffStms :: Stms SOACS -> ADM () @@ -371,17 +327,24 @@ diffBody res_adjs get_adjs_for (Body () stms res) = subAD $ diffLambda :: [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS) diffLambda res_adjs get_adjs_for (Lambda params _ body) = - localScope (scopeOfLParams params) $ do - Body () stms res <- diffBody res_adjs get_adjs_for body - let body' = Body () stms $ takeLast (length get_adjs_for) res - ts' <- mapM lookupType get_adjs_for - pure $ Lambda params ts' body' - -revVJP :: (MonadFreshNames m) => Scope SOACS -> Lambda SOACS -> m (Lambda SOACS) -revVJP scope (Lambda params ts body) = - runADM . localScope (scope <> scopeOfLParams params) $ do + mkLambda params $ do + res <- bodyBind =<< diffBody res_adjs get_adjs_for body + pure $ takeLast (length get_adjs_for) res + +revVJP :: + (MonadFreshNames m) => + Scope SOACS -> + Shape -> + Attrs -> + Lambda SOACS -> + m (Lambda SOACS) +revVJP scope shape attrs (Lambda params ts body) = do + runADM shape attrs . 
localScope (scope <> scopeOfLParams params) $ do + adj_shape <- askShape params_adj <- forM (zip (map resSubExp (bodyResult body)) ts) $ \(se, t) -> - Param mempty <$> maybe (newVName "const_adj") adjVName (subExpVar se) <*> pure t + Param mempty + <$> maybe (newVName "const_res_adj") adjVName (subExpVar se) + <*> pure (t `arrayOfShape` adj_shape) body' <- localScope (scopeOfLParams params_adj) $ @@ -391,143 +354,3 @@ revVJP scope (Lambda params ts body) = body pure $ Lambda (params ++ params_adj) (ts <> map paramType params) body' - --- Note [Adjoints of accumulators] --- --- The general case of taking adjoints of WithAcc is tricky. We make --- some assumptions and lay down a basic design. --- --- First, we assume that any WithAccs that occur in the program are --- come from one of these sources: --- --- - A previous instance of VJP, which means we can rely on the operator having --- a constant adjoint (it's addition as appropriate to the type). --- --- - A scatter, meaning there is no operator. --- --- (These can actually be distinguished by the presence of an operator, although --- we do not currently bother.) --- --- Second, the adjoint of an accumulator is an array of the same type --- as the underlying array. For example, the adjoint type of the --- primal type 'acc(c, [n], {f64})' is '[n]f64'. In principle the --- adjoint of 'acc(c, [n], {f64,f32})' should be two arrays of type --- '[]f64', '[]f32'. Our current design assumes that adjoints are --- single variables. This is fixable. --- --- In the return sweep, when inserting the with_acc, we still compute the --- "original" accumulator result, but modified such that its initial value is --- the adjoint of the result of the accumulator. We also modify the update_accs --- of these accumulators to be with zero values. This means that the array that --- is produced will be equal to the adjoint of the result, except for those --- places that have been updated, where it will be zero. 
This is intuitively --- sensible - values that have been overwritten (and so do not contribute to the --- result) should obviously have zero sensitivity. --- --- # Adjoint of UpdateAcc --- --- Consider primal code --- --- update_acc(acc, i, v) --- --- Interpreted as an imperative statement, this means --- --- acc[i] ⊕= v --- --- for some '⊕'. Normally all the compiler knows of '⊕' is that it --- is associative and commutative, but because we assume that all --- accumulators are the result of previous AD transformations, we --- can assume that '⊕' actually behaves like addition - that is, has --- unit partial derivatives. So the return sweep is --- --- v_adj += acc_adj[i] --- --- Further, we modify the primal code so that it becomes --- --- update_acc(acc, i, 0) --- --- for some appropriate notion of zero. --- --- # Adjoint of Map --- --- Suppose we have primal code --- --- let acc' = --- map (...) acc --- --- where "acc : acc(c, [n], {f64})" and the width of the Map is "w". --- Our normal transformation for Map input arrays is to similarly map --- their adjoint, but clearly this doesn't work here because the --- semantics of mapping an adjoint is an "implicit replicate". So --- when generating the return sweep we actually perform that --- replication: --- --- map (...) (replicate w acc_adj) --- --- But what about the contributions to "acc'"? Those we also have to --- take special care of. The result of the map itself is actually a --- multidimensional array: --- --- let acc_contribs = --- map (...) (replicate w acc'_adj) --- --- which we must then sum to add to the contribution. --- --- acc_adj += sum(acc_contribs) --- --- I'm slightly worried about the asymptotics of this, since my --- intuition of this is that the contributions might be rather sparse. --- (Maybe completely zero? If so it will be simplified away --- entirely.) 
Perhaps a better solution is to treat --- accumulator-inputs in the primal code as we do free variables, and --- create accumulators for them in the return sweep. --- --- # Consumption --- --- A minor problem is that our usual way of handling consumption (Note --- [Consumption]) is not viable, because accumulators are not --- copyable. Fortunately, while the accumulators that are consumed in --- the forward sweep will also be present in the return sweep given --- our current translation rules, they will be dead code. As long as --- we are careful to run dead code elimination after revVJP, we should --- be good. - --- Note [Array Adjoints of Match] --- --- Some unusual, but sadly not completely contrived, contain Match --- expressions that return multiple arrays, and there the arrays --- returned by one branch have overlapping aliases with another --- branch, although in different places. As an example consider this: --- --- let (X,Y) = if c --- then (A, B) --- else (B, A) --- --- Because our aliasing representation cannot express mutually --- exclusive aliases, we will consider X and Y to be aliased to each --- other. In practice, this means it is unlikely for X or Y to be --- consumed, because it would also consume the other (although it's --- possible for carefully written code). --- --- When producing adjoints for this, it will be something like --- --- let (X_adj,Y_adj) = if c --- then (A_adj, B_adj) --- else (B_adj, A_adj) --- --- which completely reflects the primal code. However, while it is --- unlikely that any consumption takes place for the original primal --- variables, it is almost guaranteed that X_adj and Y_adj will be --- consumed (that is the main way we use adjoints after all), and due --- to the conservative aliasing, when one is consumed, so is the --- other! To avoid this tragic fate, we are forced to copy any --- array-typed adjoints returned by a Match. This can be quite costly. 
--- However: --- --- 1) Futhark has pretty OK copy removal, so maybe it can get rid of --- these by using information not available to the AD pass. --- --- 2) In many cases, arrays will have accumulator adjoints, which are --- not subject to this problem. --- --- Issue #2228 was caused by neglecting to do this. diff --git a/src/Futhark/AD/Rev/Acc.hs b/src/Futhark/AD/Rev/Acc.hs new file mode 100644 index 0000000000..ff34bd2433 --- /dev/null +++ b/src/Futhark/AD/Rev/Acc.hs @@ -0,0 +1,396 @@ +-- | Differentiation related to accumulators in the input program. +module Futhark.AD.Rev.Acc + ( diffWithAcc, + diffUpdateAcc, + ) +where + +-- Note [Adjoints of accumulators] +-- +-- The general case of taking adjoints of WithAcc is tricky. We make +-- some assumptions and lay down a basic design. +-- +-- First, we assume that any WithAccs that occur in the program +-- come from one of these sources: +-- +-- - A previous instance of VJP, which means we can rely on the operator having +-- a constant adjoint (it's addition as appropriate to the type). +-- +-- - A scatter, meaning there is no operator. +-- +-- (These can actually be distinguished by the presence of an operator, although +-- we do not currently bother.) +-- +-- Second, the adjoint of an accumulator is an array of the same type +-- as the underlying array. For example, the adjoint type of the +-- primal type 'acc(c, [n], {f64})' is '[n]f64'. In principle the +-- adjoint of 'acc(c, [n], {f64,f32})' should be two arrays of type +-- '[]f64', '[]f32'. Our current design assumes that adjoints are +-- single variables. This is fixable. +-- +-- In the return sweep, when inserting the with_acc, we still compute the +-- "original" accumulator result, but modified such that its initial value is +-- the adjoint of the result of the accumulator. We also modify the update_accs +-- of these accumulators to be with zero values. 
This means that the array that +-- is produced will be equal to the adjoint of the result, except for those +-- places that have been updated, where it will be zero. This is intuitively +-- sensible - values that have been overwritten (and so do not contribute to the +-- result) should obviously have zero sensitivity. +-- +-- # Adjoint of UpdateAcc +-- +-- Consider primal code +-- +-- update_acc(acc, i, v) +-- +-- Interpreted as an imperative statement, this means +-- +-- acc[i] ⊕= v +-- +-- for some '⊕'. Normally all the compiler knows of '⊕' is that it +-- is associative and commutative, but because we assume that all +-- accumulators are the result of previous AD transformations, we +-- can assume that '⊕' actually behaves like addition - that is, has +-- unit partial derivatives. So the return sweep is +-- +-- v_adj += acc_adj[i] +-- +-- Further, we modify the primal code so that it becomes +-- +-- update_acc(acc, i, 0) +-- +-- for some appropriate notion of zero. +-- +-- # Adjoint of Map +-- +-- Suppose we have primal code +-- +-- let acc' = +-- map (...) acc +-- +-- where "acc : acc(c, [n], {f64})" and the width of the Map is "w". +-- Our normal transformation for Map input arrays is to similarly map +-- their adjoint, but clearly this doesn't work here because the +-- semantics of mapping an adjoint is an "implicit replicate". So +-- when generating the return sweep we actually perform that +-- replication: +-- +-- map (...) (replicate w acc_adj) +-- +-- But what about the contributions to "acc'"? Those we also have to +-- take special care of. The result of the map itself is actually a +-- multidimensional array: +-- +-- let acc_contribs = +-- map (...) (replicate w acc'_adj) +-- +-- which we must then sum to add to the contribution. +-- +-- acc_adj += sum(acc_contribs) +-- +-- I'm slightly worried about the asymptotics of this, since my +-- intuition of this is that the contributions might be rather sparse. +-- (Maybe completely zero? 
If so it will be simplified away +-- entirely.) Perhaps a better solution is to treat +-- accumulator-inputs in the primal code as we do free variables, and +-- create accumulators for them in the return sweep. +-- +-- # Vectorised WithAcc +-- +-- When WithAcc occurs in vectorised AD, the accumulator element types gain +-- extra leading "vectorised" dimensions corresponding to the enclosing vector +-- shape. For example, if the primal type inside a map of width @w@ is @acc(c, +-- [n], {f64})@, the adjoint type is @[w][n]f64@ -- but the internal accumulator +-- layout expects shape @[n][w]f64@ (the accumulator shape comes first, then the +-- vectorised dimensions, then element dimensions). +-- +-- This means we must transpose accumulator adjoints when entering and +-- leaving the return-sweep WithAcc: +-- +-- * On entry: transpose result adjoints from @[vec...][shape...]elem@ to +-- @[shape...][vec...]elem@ so they can serve as initial values for the +-- accumulators. +-- +-- * On exit: transpose the produced arrays back from @[shape...][vec...]elem@ +-- to @[vec...][shape...]elem@ to match the expected adjoint layout. +-- +-- This is actually quite similar to how other SOACs must be handled. +-- +-- Additionally, the accumulator parameter types in the lambda (and any +-- Acc-typed pattern elements or inner lambda parameters referring to the same +-- certs) must be updated to reflect the vectorised element types *before* +-- differentiation. This ensures that 'lookupAdj' on accumulator variables +-- inside the lambda produces adjoints with the correct vectorised type. +-- +-- The UpdateAcc case is simpler under vectorisation: because the accumulator +-- adjoint already has the vectorised dimensions folded into its element type, a +-- plain index into the adjoint at the update indices directly yields the +-- correctly-shaped contribution. 
+-- +-- # Consumption +-- +-- A minor problem is that our usual way of handling consumption (Note +-- [Consumption]) is not viable, because accumulators are not +-- copyable. Fortunately, while the accumulators that are consumed in +-- the forward sweep will also be present in the return sweep given +-- our current translation rules, they will be dead code. As long as +-- we are careful to run dead code elimination after revVJP, we should +-- be good. + +-- Note [Array Adjoints of Match] +-- +-- Some unusual, but sadly not completely contrived, programs contain Match +-- expressions that return multiple arrays, and where the arrays +-- returned by one branch have overlapping aliases with another +-- branch, although in different places. As an example consider this: +-- +-- let (X,Y) = if c +-- then (A, B) +-- else (B, A) +-- +-- Because our aliasing representation cannot express mutually +-- exclusive aliases, we will consider X and Y to be aliased to each +-- other. In practice, this means it is unlikely for X or Y to be +-- consumed, because it would also consume the other (although it's +-- possible for carefully written code). +-- +-- When producing adjoints for this, it will be something like +-- +-- let (X_adj,Y_adj) = if c +-- then (A_adj, B_adj) +-- else (B_adj, A_adj) +-- +-- which completely reflects the primal code. However, while it is +-- unlikely that any consumption takes place for the original primal +-- variables, it is almost guaranteed that X_adj and Y_adj will be +-- consumed (that is the main way we use adjoints after all), and due +-- to the conservative aliasing, when one is consumed, so is the +-- other! To avoid this tragic fate, we are forced to copy any +-- array-typed adjoints returned by a Match. This can be quite costly. +-- However: +-- +-- 1) Futhark has pretty OK copy removal, so maybe it can get rid of +-- these by using information not available to the AD pass. 
+-- +-- 2) In many cases, arrays will have accumulator adjoints, which are +-- not subject to this problem. +-- +-- Issue #2228 was caused by neglecting to do this. + +import Control.Monad +import Control.Monad.Identity +import Data.List ((\\)) +import Futhark.AD.Rev.Monad +import Futhark.Builder +import Futhark.IR.SOACS +import Futhark.Tools +import Futhark.Transform.Rename +import Futhark.Util (chunks, takeLast) + +-- | Transform updates on accumulators matching the given certificates into +-- updates that write provided zero values. +zeroOutUpdates :: [(VName, [SubExp])] -> Lambda SOACS -> Lambda SOACS +zeroOutUpdates certs_to_zeroes lam = lam {lambdaBody = onBody $ lambdaBody lam} + where + onExp = runIdentity . mapExpM mapper + where + mapper = + (identityMapper :: (Monad m) => Mapper SOACS SOACS m) + { mapOnOp = traverseSOACStms (\_ stms -> pure $ onStms stms), + mapOnBody = \_ body -> pure $ onBody body + } + onStms = fmap onStm + onStm (Let (Pat [pe]) aux (BasicOp (UpdateAcc safety acc is _))) + | Acc c _ _ _ <- patElemType pe, + Just zero <- lookup c certs_to_zeroes = + Let (Pat [pe]) aux (BasicOp (UpdateAcc safety acc is zero)) + onStm (Let pat aux e) = Let pat aux $ onExp e + + onBody body = body {bodyStms = onStms $ bodyStms body} + +-- Update accumulator parameter types in the lambda to include vectorised +-- element types. Also updates all Acc-typed pattern elements and inner +-- lambda parameters that reference the same accumulator certs. 
+updateAccParamTypes :: Int -> Shape -> Lambda SOACS -> Lambda SOACS +updateAccParamTypes n_inputs adj_sh lam + | adj_sh == mempty = lam + | otherwise = + let (cert_ps, rest_ps) = splitAt n_inputs (lambdaParams lam) + (acc_ps, other_ps) = splitAt n_inputs rest_ps + acc_ps' = map (updateParam cert_names) acc_ps + cert_names = map paramName cert_ps + body' = updateBody cert_names (lambdaBody lam) + ret' = map (updateAccType cert_names) (lambdaReturnType lam) + in lam + { lambdaParams = cert_ps ++ acc_ps' ++ other_ps, + lambdaReturnType = ret', + lambdaBody = body' + } + where + updateParam :: [VName] -> Param Type -> Param Type + updateParam certs p = + p {paramDec = updateAccType certs (paramDec p)} + + updateAccType :: [VName] -> Type -> Type + updateAccType certs (Acc cert acc_shape ts u) + | cert `elem` certs = + Acc cert acc_shape (map (`arrayOfShape` adj_sh) ts) u + updateAccType _ t = t + + updateBody :: [VName] -> Body SOACS -> Body SOACS + updateBody certs body = + body {bodyStms = fmap (updateStm certs) (bodyStms body)} + + updateStm :: [VName] -> Stm SOACS -> Stm SOACS + updateStm certs (Let pat aux e) = + Let (updatePat certs pat) aux (updateExp certs e) + + updatePat :: [VName] -> Pat Type -> Pat Type + updatePat certs (Pat pes) = + Pat $ map (\pe -> pe {patElemDec = updateAccType certs (patElemDec pe)}) pes + + updateExp :: [VName] -> Exp SOACS -> Exp SOACS + updateExp certs = runIdentity . mapExpM mapper + where + mapper = + (identityMapper :: (Monad m) => Mapper SOACS SOACS m) + { mapOnBody = \_ b -> pure $ updateBody certs b, + mapOnOp = pure . updateSOAC certs + } + + updateSOAC :: [VName] -> SOAC SOACS -> SOAC SOACS + updateSOAC certs = runIdentity . mapSOACM mapper + where + mapper = + identitySOACMapper + { mapOnSOACLambda = pure . 
updateLambda certs + } + + updateLambda :: [VName] -> Lambda SOACS -> Lambda SOACS + updateLambda certs l = + l + { lambdaParams = map (updateParam certs) (lambdaParams l), + lambdaReturnType = map (updateAccType certs) (lambdaReturnType l), + lambdaBody = updateBody certs (lambdaBody l) + } + +diffWithAcc :: + VjpOps -> + Pat Type -> + StmAux () -> + [(Shape, [VName], Maybe (Lambda SOACS, [SubExp]))] -> + Lambda SOACS -> + ADM () -> + ADM () +diffWithAcc ops pat aux inputs lam m = do + addStm $ Let pat aux $ WithAcc inputs lam + m + returnSweepCode $ do + adj_shape <- askShape + adjs <- mapM lookupAdj $ patNames pat + -- Transpose the accumulator result adjoints from [vec...][shape...]elem + -- to [shape...][vec...]elem, matching the internal accumulator layout. + adjs' <- transposeAdjs adj_shape adjs + lam' <- renameLambda lam + -- Update the lambda's accumulator parameter types to reflect vectorised + -- element types BEFORE differentiation, so that lookupAdj on Acc variables + -- inside the lambda gives the correct vectorised adjoint type. + let lam'_vec = updateAccParamTypes n_inputs adj_shape lam' + free_vars <- filterM isActive $ namesToList $ freeIn lam'_vec + free_accs <- filterM (fmap isAcc . lookupType) free_vars + let free_vars' = free_vars \\ free_accs + lam'' <- diffLambda' adjs' free_vars' lam'_vec + (inputs_zeroes, inputs') <- + unzip <$> zipWithM (renameInputLambda adj_shape) (chunks lengths adjs) inputs + let certs = map paramName $ take n_inputs $ lambdaParams lam'' + raw_adjs <- + letTupExp "with_acc_contrib" . WithAcc inputs' $ + zeroOutUpdates (zip certs inputs_zeroes) lam'' + -- The accumulator results have shape [shape...][vec...]elem. Transpose + -- back to [vec...][shape...]elem for the adjoint. 
+ let n_arrs = sum lengths + (arr_adjs, free_adjs) = splitAt n_arrs raw_adjs + arr_adjs' <- zipWithM (transposeAccResult adj_shape) (map (\(s, _, _) -> s) inputs) arr_adjs + zipWithM_ insAdj arrs arr_adjs' + zipWithM_ insAdj free_vars' free_adjs + where + n_inputs = length inputs + lengths = map (\(_, as, _) -> length as) inputs + arrs = concatMap (\(_, as, _) -> as) inputs + + -- Transpose the accumulator-related adjoints from [vec...][shape...]elem + -- to [shape...][vec...]elem. Non-accumulator adjs are left unchanged. + transposeAdjs :: Shape -> [Adj] -> ADM [Adj] + transposeAdjs adj_sh adjs + | adj_sh == mempty = pure adjs + | otherwise = do + let n_arrs = sum lengths + (acc_adjs, other_adjs) = splitAt n_arrs adjs + acc_adjs' <- mapM transposeAdj acc_adjs + pure $ acc_adjs' ++ other_adjs + + transposeAdj :: Adj -> ADM Adj + transposeAdj adj = do + v <- adjVal adj + v' <- vecToInner v + pure $ AdjVal $ Var v' + + -- Transpose [shape...][vec...][elem...] to [vec...][shape...][elem...] + transposeAccResult :: Shape -> Shape -> VName -> ADM VName + transposeAccResult adj_sh shape v + | adj_sh == mempty = pure v + | otherwise = do + v_t <- lookupType v + let r = shapeRank adj_sh + s = shapeRank shape + total = arrayRank v_t + perm = [s .. s + r - 1] ++ [0 .. s - 1] ++ [s + r .. total - 1] + letExp (baseName v <> "_tr") $ BasicOp $ Rearrange v perm + + renameInputLambda adj_sh as_adj (shape, as, _) = do + -- Compute element types with vectorised dimensions included. + orig_nes_ts <- mapM (fmap (stripArray (shapeRank shape)) . lookupType) as + let vec_nes_ts = map (`arrayOfShape` adj_sh) orig_nes_ts + zeroes <- mapM (zeroArray mempty) vec_nes_ts + -- Transpose adjoints from [vec...][shape...]elem to [shape...][vec...]elem + -- so they match the accumulator layout. 
+ as' <- mapM adjVal as_adj + as'' <- mapM vecToInner as' + pure (map Var zeroes, (shape, as'', Nothing)) + + diffLambda' res_adjs get_adjs_for (Lambda params ts body) = do + localScope (scopeOfLParams params) $ do + Body () stms res <- vjpBody ops res_adjs get_adjs_for body + let body' = Body () stms $ take n_inputs res <> takeLast (length get_adjs_for) res + ts' <- mapM lookupType get_adjs_for + pure $ Lambda params (take n_inputs ts <> ts') body' + +diffUpdateAcc :: + Pat Type -> + StmAux () -> + Safety -> + VName -> + [SubExp] -> + [SubExp] -> + ADM () -> + ADM () +diffUpdateAcc pat aux safety acc is vs m = do + addStm $ Let pat aux $ BasicOp $ UpdateAcc safety acc is vs + m + pat_adjs <- mapM lookupAdjVal (patNames pat) + returnSweepCode $ do + forM_ (zip pat_adjs vs) $ \(adj, v) -> do + adj_t <- lookupType adj + let index_adj = pure $ BasicOp $ Index adj $ fullSlice adj_t $ map DimFix is + adj_i <- + letExp "updateacc_val_adj" =<< case safety of + Unsafe -> + index_adj + Safe -> + -- The primal UpdateAcc may be out-of-bounds, in which case + -- indexing the adjoint is dangerous. + eIf + (eShapeInBounds (arrayShape adj_t) (map eSubExp is)) + (eBody [index_adj]) + (eBody [pure $ zeroExp $ stripArray (length is) adj_t]) + updateSubExpAdj v adj_i diff --git a/src/Futhark/AD/Rev/Hist.hs b/src/Futhark/AD/Rev/Hist.hs index 01e731df50..f988eaeca5 100644 --- a/src/Futhark/AD/Rev/Hist.hs +++ b/src/Futhark/AD/Rev/Hist.hs @@ -241,69 +241,70 @@ diffMinMaxHist _ops x aux n minmax ne is vs w rf dst m = do m - x_bar <- lookupAdjVal x - - x_ind_dst <- newParam (baseName x <> "_ind_param") $ Prim int64 - x_bar_dst <- newParam (baseName x <> "_bar_param") $ Prim t - dst_lam_inner <- - mkLambda [x_ind_dst, x_bar_dst] $ - fmap varsRes . letTupExp "dst_bar" - =<< eIf - (toExp $ le64 (paramName x_ind_dst) .==. 
-1) - (eBody $ pure $ eParam x_bar_dst) - (eBody $ pure $ eSubExp $ Constant $ blankPrimValue t) - dst_lam <- nestedmap inner_dims [int64, vs_elm_type] dst_lam_inner - - dst_bar <- - letExp (baseName dst <> "_bar") . Op . Screma w [x_inds, x_bar] - =<< mapSOAC dst_lam - - updateAdj dst dst_bar - - vs_bar <- lookupAdjVal vs - - inds' <- traverse (letExp "inds" . BasicOp . Replicate (Shape [w]) . Var) =<< mk_indices inner_dims [] - let inds = x_inds : inds' - - par_x_ind_vs <- replicateM nr_dims $ newParam (baseName x <> "_ind_param") $ Prim int64 - par_x_bar_vs <- newParam (baseName x <> "_bar_param") $ Prim t - vs_lam_inner <- - mkLambda (par_x_bar_vs : par_x_ind_vs) $ - fmap varsRes . letTupExp "res" - =<< eIf - (toExp $ le64 (paramName $ head par_x_ind_vs) .==. -1) - (eBody $ pure $ eSubExp $ Constant $ blankPrimValue t) - ( eBody $ - pure $ do - vs_bar_i <- - letSubExp (baseName vs_bar <> "_el") . BasicOp $ - Index vs_bar . Slice $ - fmap (DimFix . Var . paramName) par_x_ind_vs - eBinOp (getBinOpPlus t) (eParam par_x_bar_vs) (eSubExp vs_bar_i) - ) - vs_lam <- nestedmap inner_dims (vs_elm_type : replicate nr_dims int64) vs_lam_inner - - vs_bar_p <- - letExp (baseName vs <> "_partial") . Op . Screma w (x_bar : inds) - =<< mapSOAC vs_lam - - q <- - letSubExp "q" - =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) dst_dims - - scatter_inps <- do - -- traverse (letExp "flat" . BasicOp . Reshape [DimNew q]) $ inds ++ [vs_bar_p] - -- ToDo: Cosmin asks: is the below the correct translation of the line above? - forM (inds ++ [vs_bar_p]) $ \v -> do - v_t <- lookupType v - letExp "flat" . BasicOp . Reshape v $ - reshapeAll (arrayShape v_t) (Shape [q]) - - vs_bar' <- - fmap head $ - doScatter (baseName vs <> "_bar") nr_dims [vs_bar] scatter_inps $ - pure . map (Var . 
paramName) - insAdj vs vs_bar' + locallyNonvector (x, dst, vs) $ do + x_bar <- lookupAdjVal x + + x_ind_dst <- newParam (baseName x <> "_ind_param") $ Prim int64 + x_bar_dst <- newParam (baseName x <> "_bar_param") $ Prim t + dst_lam_inner <- + mkLambda [x_ind_dst, x_bar_dst] $ + fmap varsRes . letTupExp "dst_bar" + =<< eIf + (toExp $ le64 (paramName x_ind_dst) .==. -1) + (eBody $ pure $ eParam x_bar_dst) + (eBody $ pure $ eSubExp $ Constant $ blankPrimValue t) + dst_lam <- nestedmap inner_dims [int64, vs_elm_type] dst_lam_inner + + dst_bar <- + letExp (baseName dst <> "_bar") . Op . Screma w [x_inds, x_bar] + =<< mapSOAC dst_lam + + updateAdj dst dst_bar + + vs_bar <- lookupAdjVal vs + + inds' <- traverse (letExp "inds" . BasicOp . Replicate (Shape [w]) . Var) =<< mk_indices inner_dims [] + let inds = x_inds : inds' + + par_x_ind_vs <- replicateM nr_dims $ newParam (baseName x <> "_ind_param") $ Prim int64 + par_x_bar_vs <- newParam (baseName x <> "_bar_param") $ Prim t + vs_lam_inner <- + mkLambda (par_x_bar_vs : par_x_ind_vs) $ + fmap varsRes . letTupExp "res" + =<< eIf + (toExp $ le64 (paramName $ head par_x_ind_vs) .==. -1) + (eBody $ pure $ eSubExp $ Constant $ blankPrimValue t) + ( eBody $ + pure $ do + vs_bar_i <- + letSubExp (baseName vs_bar <> "_el") . BasicOp $ + Index vs_bar . Slice $ + fmap (DimFix . Var . paramName) par_x_ind_vs + eBinOp (getBinOpPlus t) (eParam par_x_bar_vs) (eSubExp vs_bar_i) + ) + vs_lam <- nestedmap inner_dims (vs_elm_type : replicate nr_dims int64) vs_lam_inner + + vs_bar_p <- + letExp (baseName vs <> "_partial") . Op . Screma w (x_bar : inds) + =<< mapSOAC vs_lam + + q <- + letSubExp "q" + =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) dst_dims + + scatter_inps <- do + -- traverse (letExp "flat" . BasicOp . Reshape [DimNew q]) $ inds ++ [vs_bar_p] + -- ToDo: Cosmin asks: is the below the correct translation of the line above? + forM (inds ++ [vs_bar_p]) $ \v -> do + v_t <- lookupType v + letExp "flat" . BasicOp . 
Reshape v $ + reshapeAll (arrayShape v_t) (Shape [q]) + + vs_bar' <- + fmap head $ + doScatter (baseName vs <> "_bar") nr_dims [vs_bar] scatter_inps $ + pure . map (Var . paramName) + insAdj vs vs_bar' where mk_indices :: [SubExp] -> [SubExp] -> ADM [VName] mk_indices [] _ = pure [] @@ -403,8 +404,8 @@ diffMulHist _ops x aux n mul ne is vs w rf dst m = do fmap varsRes . letTupExp "h_part" =<< eIf (toExp $ 0 .==. le64 (paramName c_param)) - (eBody $ pure $ eParam p_param) - (eBody $ pure $ eSubExp $ Constant $ blankPrimValue t) + (eBody [eParam p_param]) + (eBody [eSubExp $ Constant $ blankPrimValue t]) lam_h_part <- nestedmap dst_dims [vs_elm_type, int64] lam_h_part_inner h_part_res <- eLambda lam_h_part $ map (eSubExp . Var) [nz_prods, zr_counts] h_part' <- bindSubExpRes "h_part" h_part_res @@ -418,60 +419,61 @@ diffMulHist _ops x aux n mul ne is vs w rf dst m = do m - x_bar <- lookupAdjVal x - - lam_mul'' <- renameLambda lam_mul' - dst_bar_res <- eLambda lam_mul'' $ map (eSubExp . Var) [h_part, x_bar] - dst_bar <- bindSubExpRes (baseName dst <> "_bar") dst_bar_res - updateAdj dst $ head dst_bar - - lam_mul''' <- renameLambda lam_mul' - part_bar_res <- eLambda lam_mul''' $ map (eSubExp . Var) [dst, x_bar] - part_bar' <- bindSubExpRes "part_bar" part_bar_res - let [part_bar] = part_bar' - - inner_params <- zipWithM newParam ["zr_cts", "pr_bar", "nz_prd", "a"] $ map Prim [int64, t, t, t] - let [zr_cts, pr_bar, nz_prd, a_param] = inner_params - lam_vsbar_inner <- - mkLambda inner_params $ - fmap varsRes . 
letTupExp "vs_bar" =<< do - eIf - (eCmpOp (CmpEq int64) (eSubExp $ intConst Int64 0) (eParam zr_cts)) - (eBody $ pure $ eBinOp mul (eParam pr_bar) $ eBinOp (getBinOpDiv t) (eParam nz_prd) $ eParam a_param) - ( eBody $ - pure $ - eIf - ( eBinOp - LogAnd - (eCmpOp (CmpEq int64) (eSubExp $ intConst Int64 1) (eParam zr_cts)) - (eCmpOp (CmpEq t) (eSubExp $ Constant $ blankPrimValue t) $ eParam a_param) - ) - (eBody $ pure $ eBinOp mul (eParam nz_prd) (eParam pr_bar)) - (eBody $ pure $ eSubExp $ Constant $ blankPrimValue t) - ) + locallyNonvector (x, dst, vs) $ do + x_bar <- lookupAdjVal x - lam_vsbar_middle <- nestedmap inner_dims [int64, t, t, t] lam_vsbar_inner + lam_mul'' <- renameLambda lam_mul' + dst_bar_res <- eLambda lam_mul'' $ map (eSubExp . Var) [h_part, x_bar] + dst_bar <- bindSubExpRes (baseName dst <> "_bar") dst_bar_res + updateAdj dst $ head dst_bar - i_param <- newParam "i" $ Prim int64 - a_param' <- newParam "a" $ rowType vs_type - lam_vsbar <- - mkLambda [i_param, a_param'] $ - fmap varsRes . letTupExp "vs_bar" - =<< eIf - (toExp $ withinBounds $ pure (w, paramName i_param)) - ( buildBody_ $ do - let i = fullSlice vs_type [DimFix $ Var $ paramName i_param] - names <- traverse newVName ["zr_cts", "pr_bar", "nz_prd"] - zipWithM_ (\name -> letBindNames [name] . BasicOp . flip Index i) names [zr_counts, part_bar, nz_prods] - eLambda lam_vsbar_middle $ map (eSubExp . Var) names <> [eParam a_param'] - ) - (eBody $ pure $ pure $ zeroExp $ rowType dst_type) + lam_mul''' <- renameLambda lam_mul' + part_bar_res <- eLambda lam_mul''' $ map (eSubExp . Var) [dst, x_bar] + part_bar' <- bindSubExpRes "part_bar" part_bar_res + let [part_bar] = part_bar' + + inner_params <- zipWithM newParam ["zr_cts", "pr_bar", "nz_prd", "a"] $ map Prim [int64, t, t, t] + let [zr_cts, pr_bar, nz_prd, a_param] = inner_params + lam_vsbar_inner <- + mkLambda inner_params $ + fmap varsRes . 
letTupExp "vs_bar" =<< do + eIf + (eCmpOp (CmpEq int64) (eSubExp $ intConst Int64 0) (eParam zr_cts)) + (eBody [eBinOp mul (eParam pr_bar) $ eBinOp (getBinOpDiv t) (eParam nz_prd) $ eParam a_param]) + ( eBody + [ eIf + ( eBinOp + LogAnd + (eCmpOp (CmpEq int64) (eSubExp $ intConst Int64 1) (eParam zr_cts)) + (eCmpOp (CmpEq t) (eSubExp $ Constant $ blankPrimValue t) $ eParam a_param) + ) + (eBody [eBinOp mul (eParam nz_prd) (eParam pr_bar)]) + (eBody [eSubExp $ Constant $ blankPrimValue t]) + ] + ) - vs_bar <- - letExp (baseName vs <> "_bar") . Op . Screma n [is, vs] - =<< mapSOAC lam_vsbar + lam_vsbar_middle <- nestedmap inner_dims [int64, t, t, t] lam_vsbar_inner + + i_param <- newParam "i" $ Prim int64 + a_param' <- newParam "a" $ rowType vs_type + lam_vsbar <- + mkLambda [i_param, a_param'] $ + fmap varsRes . letTupExp "vs_bar" + =<< eIf + (toExp $ withinBounds $ pure (w, paramName i_param)) + ( buildBody_ $ do + let i = fullSlice vs_type [DimFix $ Var $ paramName i_param] + names <- traverse newVName ["zr_cts", "pr_bar", "nz_prd"] + zipWithM_ (\name -> letBindNames [name] . BasicOp . flip Index i) names [zr_counts, part_bar, nz_prods] + eLambda lam_vsbar_middle $ map (eSubExp . Var) names <> [eParam a_param'] + ) + (eBody [pure $ zeroExp $ rowType dst_type]) + + vs_bar <- + letExp (baseName vs <> "_bar") . Op . Screma n [is, vs] + =<< mapSOAC lam_vsbar - updateAdj vs vs_bar + updateAdj vs vs_bar -- -- special case of histogram with add as operator. @@ -502,25 +504,25 @@ diffAddHist _ops x aux n add ne is vs w rf dst m = do m - x_bar <- lookupAdjVal x + locallyNonvector (x, dst, vs) $ do + x_bar <- lookupAdjVal x - updateAdj dst x_bar + updateAdj dst x_bar - x_type <- lookupType x - i_param <- newParam (baseName vs <> "_i") $ Prim int64 - let i = paramName i_param - lam_vsbar <- - mkLambda [i_param] $ - fmap varsRes . 
letTupExp "vs_bar" - =<< eIf - (toExp $ withinBounds $ pure (w, i)) - (eBody $ pure $ pure $ BasicOp $ Index x_bar $ fullSlice x_type [DimFix $ Var i]) - (eBody $ pure $ eSubExp ne) + i_param <- newParam (baseName vs <> "_i") $ Prim int64 + let i = paramName i_param + lam_vsbar <- + mkLambda [i_param] $ + fmap varsRes . letTupExp "vs_bar" + =<< eIf + (toExp $ withinBounds $ pure (w, i)) + (eBody [eIndex x_bar [eVar i]]) + (eBody [eSubExp ne]) - vs_bar <- - letExp (baseName vs <> "_bar") . Op . Screma n [is] - =<< mapSOAC lam_vsbar - updateAdj vs vs_bar + vs_bar <- + letExp (baseName vs <> "_bar") . Op . Screma n [is] + =<< mapSOAC lam_vsbar + updateAdj vs vs_bar -- Special case for vectorised combining operator. Rewrite -- reduce_by_index dst (map2 op) nes is vss @@ -796,146 +798,145 @@ diffHist ops xs aux n lam0 ne as w rf dst m = do m - xs_bar <- traverse lookupAdjVal xs - - (dst_params, hp_params, f') <- mkF' lam0 dst_type $ head w - f'_adj_dst <- vjpLambda ops (map adjFromVar xs_bar) dst_params f' - f'_adj_hp <- vjpLambda ops (map adjFromVar xs_bar) hp_params f' - - dst_bar' <- eLambda f'_adj_dst $ map (eSubExp . Var) $ dst <> h_part - dst_bar <- bindSubExpRes "dst_bar" dst_bar' - zipWithM_ updateAdj dst dst_bar - - h_part_bar' <- eLambda f'_adj_hp $ map (eSubExp . Var) $ dst <> h_part - h_part_bar <- bindSubExpRes "h_part_bar" h_part_bar' - - lam <- renameLambda lam0 - lam' <- renameLambda lam0 - - -- is' <- mapout (head as) n (head w) - -- sorted <- radixSort' (is' : tail as) n $ head w - sorted <- radixSort' as n $ head w - let siota = head sorted - let sis = head $ tail sorted - let sas = drop 2 sorted - - iota_n <- - letExp "iota" $ BasicOp $ Iota n (intConst Int64 0) (intConst Int64 1) Int64 - - par_i <- newParam "i" $ Prim int64 - flag_lam <- mkFlagLam par_i sis - flag <- letExp "flag" . Op . 
Screma n [iota_n] =<< mapSOAC flag_lam - - -- map (\i -> (if flag[i] then (true,ne) else (false,vs[i-1]), if i==0 || flag[n-i] then (true,ne) else (false,vs[n-i]))) (iota n) - par_i' <- newParam "i" $ Prim int64 - let i' = paramName par_i' - g_lam <- - mkLambda [par_i'] $ - fmap subExpsRes . mapM (letSubExp "scan_inps") =<< do - im1 <- letSubExp "i_1" =<< toExp (le64 i' - 1) - nmi <- letSubExp "n_i" =<< toExp (pe64 n - le64 i') - let s1 = [DimFix im1] - let s2 = [DimFix nmi] - - -- flag array for left scan - f1 <- letSubExp "f1" $ BasicOp $ Index flag $ Slice [DimFix $ Var i'] - - -- array for left scan - r1 <- - letTupExp' "r1" - =<< eIf - (eSubExp f1) - (eBody $ fmap eSubExp ne) - (eBody . fmap (eSubExp . Var) =<< multiIndex sas s1) - - -- array for right scan inc flag - r2 <- - letTupExp' "r2" - =<< eIf - (toExp $ le64 i' .==. 0) - (eBody $ fmap eSubExp $ Constant (onePrimValue Bool) : ne) - ( eBody $ - pure $ do - eIf - (pure $ BasicOp $ Index flag $ Slice s2) - (eBody $ fmap eSubExp $ Constant (onePrimValue Bool) : ne) - ( eBody . fmap eSubExp . (Constant (blankPrimValue Bool) :) . fmap Var - =<< multiIndex sas s2 - ) - ) - - traverse eSubExp $ f1 : r1 ++ r2 - - -- scan (\(f1,v1) (f2,v2) -> - -- let f = f1 || f2 - -- let v = if f2 then v2 else g v1 v2 - -- in (f,v) ) (false,ne) (zip flags vals) - scan_lams <- - traverse - ( \l -> do - f1 <- newParam "f1" $ Prim Bool - f2 <- newParam "f2" $ Prim Bool - ps <- lambdaParams <$> renameLambda lam0 - let (p1, p2) = splitAt (length ne) ps - - mkLambda (f1 : p1 ++ f2 : p2) $ - fmap varsRes . letTupExp "scan_res" =<< do - let f = eBinOp LogOr (eParam f1) (eParam f2) - eIf - (eParam f2) - (eBody $ f : fmap eParam p2) - ( eBody . (f :) . fmap (eSubExp . 
Var) - =<< bindSubExpRes "gres" - =<< eLambda l (fmap eParam ps) + locallyNonvector (xs, dst, lam0, as) $ do + xs_bar <- traverse lookupAdjVal xs + + (dst_params, hp_params, f') <- mkF' lam0 dst_type $ head w + f'_adj_dst <- vjpLambda ops (map adjFromVar xs_bar) dst_params f' + f'_adj_hp <- vjpLambda ops (map adjFromVar xs_bar) hp_params f' + + dst_bar' <- eLambda f'_adj_dst $ map (eSubExp . Var) $ dst <> h_part + dst_bar <- bindSubExpRes "dst_bar" dst_bar' + zipWithM_ updateAdj dst dst_bar + + h_part_bar' <- eLambda f'_adj_hp $ map (eSubExp . Var) $ dst <> h_part + h_part_bar <- bindSubExpRes "h_part_bar" h_part_bar' + + lam <- renameLambda lam0 + lam' <- renameLambda lam0 + + -- is' <- mapout (head as) n (head w) + -- sorted <- radixSort' (is' : tail as) n $ head w + sorted <- radixSort' as n $ head w + let siota = head sorted + let sis = head $ tail sorted + let sas = drop 2 sorted + + iota_n <- + letExp "iota" $ BasicOp $ Iota n (intConst Int64 0) (intConst Int64 1) Int64 + + par_i <- newParam "i" $ Prim int64 + flag_lam <- mkFlagLam par_i sis + flag <- letExp "flag" . Op . Screma n [iota_n] =<< mapSOAC flag_lam + + -- map (\i -> (if flag[i] then (true,ne) else (false,vs[i-1]), if i==0 || flag[n-i] then (true,ne) else (false,vs[n-i]))) (iota n) + par_i' <- newParam "i" $ Prim int64 + let i' = paramName par_i' + g_lam <- + mkLambda [par_i'] $ + fmap subExpsRes . mapM (letSubExp "scan_inps") =<< do + im1 <- letSubExp "i_1" =<< toExp (le64 i' - 1) + nmi <- letSubExp "n_i" =<< toExp (pe64 n - le64 i') + let s1 = [DimFix im1] + let s2 = [DimFix nmi] + + -- flag array for left scan + f1 <- letSubExp "f1" $ BasicOp $ Index flag $ Slice [DimFix $ Var i'] + + -- array for left scan + r1 <- + letTupExp' "r1" + =<< eIf + (eSubExp f1) + (eBody $ fmap eSubExp ne) + (eBody . fmap (eSubExp . Var) =<< multiIndex sas s1) + + -- array for right scan inc flag + r2 <- + letTupExp' "r2" + =<< eIf + (toExp $ le64 i' .==. 
0) + (eBody $ fmap eSubExp $ Constant (onePrimValue Bool) : ne) + ( eBody $ + pure $ do + eIf + (pure $ BasicOp $ Index flag $ Slice s2) + (eBody $ fmap eSubExp $ Constant (onePrimValue Bool) : ne) + ( eBody . fmap eSubExp . (Constant (blankPrimValue Bool) :) . fmap Var + =<< multiIndex sas s2 + ) ) - ) - [lam, lam'] - let ne' = Constant (BoolValue False) : ne + traverse eSubExp $ f1 : r1 ++ r2 + + -- scan (\(f1,v1) (f2,v2) -> + -- let f = f1 || f2 + -- let v = if f2 then v2 else g v1 v2 + -- in (f,v) ) (false,ne) (zip flags vals) + scan_lams <- + traverse + ( \l -> do + f1 <- newParam "f1" $ Prim Bool + f2 <- newParam "f2" $ Prim Bool + ps <- lambdaParams <$> renameLambda lam0 + let (p1, p2) = splitAt (length ne) ps + + mkLambda (f1 : p1 ++ f2 : p2) $ + fmap varsRes . letTupExp "scan_res" =<< do + let f = eBinOp LogOr (eParam f1) (eParam f2) + eIf + (eParam f2) + (eBody $ f : fmap eParam p2) + ( eBody . (f :) . fmap (eSubExp . Var) + =<< bindSubExpRes "gres" + =<< eLambda l (fmap eParam ps) + ) + ) + [lam, lam'] - scansres <- - letTupExp "adj_ctrb_scan" . Op . Screma n [iota_n] - =<< scanomapSOAC (map (`Scan` ne') scan_lams) g_lam + let ne' = Constant (BoolValue False) : ne - let (_ : ls_arr, _ : rs_arr_rev) = splitAt (length ne + 1) scansres + scansres <- + letTupExp "adj_ctrb_scan" . Op . Screma n [iota_n] + =<< scanomapSOAC (map (`Scan` ne') scan_lams) g_lam - -- map (\i -> if i < w && -1 < w then (xs_bar[i], dst[i]) else (0,ne)) sis - par_i'' <- newParam "i" $ Prim int64 - let i'' = paramName par_i'' - map_lam <- - mkLambda [par_i''] $ - fmap varsRes . letTupExp "scan_res" - =<< eIf - (toExp $ withinBounds $ pure (head w, i'')) - (eBody . fmap (eSubExp . 
Var) =<< multiIndex h_part_bar [DimFix $ Var i'']) - ( eBody $ do - map (\t -> pure $ BasicOp $ Replicate (Shape $ tail $ arrayDims t) (Constant $ blankPrimValue $ elemType t)) as_type - ) + let (_ : ls_arr, _ : rs_arr_rev) = splitAt (length ne + 1) scansres + + -- map (\i -> if i < w && -1 < w then (xs_bar[i], dst[i]) else (0,ne)) sis + par_i'' <- newParam "i" $ Prim int64 + let i'' = paramName par_i'' + map_lam <- + mkLambda [par_i''] $ + fmap varsRes . letTupExp "scan_res" + =<< eIf + (toExp $ withinBounds $ pure (head w, i'')) + (eBody . fmap (eSubExp . Var) =<< multiIndex h_part_bar [DimFix $ Var i'']) + ( eBody $ do + map (\t -> pure $ BasicOp $ Replicate (Shape $ tail $ arrayDims t) (Constant $ blankPrimValue $ elemType t)) as_type + ) - f_bar <- letTupExp "f_bar" . Op . Screma n [sis] =<< mapSOAC map_lam + f_bar <- letTupExp "f_bar" . Op . Screma n [sis] =<< mapSOAC map_lam - (as_params, f) <- mkF lam0 as_type n - f_adj <- vjpLambda ops (map adjFromVar f_bar) as_params f + (as_params, f) <- mkF lam0 as_type n + f_adj <- vjpLambda ops (map adjFromVar f_bar) as_params f - -- map (\i -> rs_arr_rev[n-i-1]) (iota n) - par_i''' <- newParam "i" $ Prim int64 - let i''' = paramName par_i''' - rev_lam <- mkLambda [par_i'''] $ do - nmim1 <- letSubExp "n_i_1" =<< toExp (pe64 n - le64 i''' - 1) - varsRes <$> multiIndex rs_arr_rev [DimFix nmim1] + -- map (\i -> rs_arr_rev[n-i-1]) (iota n) + par_i''' <- newParam "i" $ Prim int64 + let i''' = paramName par_i''' + rev_lam <- mkLambda [par_i'''] $ do + nmim1 <- letSubExp "n_i_1" =<< toExp (pe64 n - le64 i''' - 1) + varsRes <$> multiIndex rs_arr_rev [DimFix nmim1] - rs_arr <- - letTupExp "rs_arr" . Op . Screma n [iota_n] - =<< mapSOAC rev_lam + rs_arr <- letTupExp "rs_arr" . Op . Screma n [iota_n] =<< mapSOAC rev_lam - sas_bar <- - bindSubExpRes "sas_bar" - =<< eLambda f_adj (map (eSubExp . Var) $ ls_arr <> sas <> rs_arr) + sas_bar <- + bindSubExpRes "sas_bar" + =<< eLambda f_adj (map (eSubExp . 
Var) $ ls_arr <> sas <> rs_arr) - scatter_dst <- traverse (\t -> letExp "scatter_dst" $ BasicOp $ Scratch (elemType t) (arrayDims t)) as_type - as_bar <- multiScatter scatter_dst siota sas_bar + scatter_dst <- traverse (\t -> letExp "scatter_dst" $ BasicOp $ Scratch (elemType t) (arrayDims t)) as_type + as_bar <- multiScatter scatter_dst siota sas_bar - zipWithM_ updateAdj (tail as) as_bar + zipWithM_ updateAdj (tail as) as_bar where -- map (\i -> if i == 0 then true else is[i] != is[i-1]) (iota n) mkFlagLam :: LParam SOACS -> VName -> ADM (Lambda SOACS) diff --git a/src/Futhark/AD/Rev/Loop.hs b/src/Futhark/AD/Rev/Loop.hs index fad1fa8004..a5b372b338 100644 --- a/src/Futhark/AD/Rev/Loop.hs +++ b/src/Futhark/AD/Rev/Loop.hs @@ -267,7 +267,7 @@ reverseIndices loop = do pure (M.singleton i i_rev, i_stms) --- | Pures a substitution which substitutes values in the reverse +-- | Returns a substitution which substitutes values in the reverse -- loop body with values from the tape. restore :: Stms SOACS -> [Param DeclType] -> VName -> ADM Substitutions restore stms_adj loop_params' i' = diff --git a/src/Futhark/AD/Rev/Map.hs b/src/Futhark/AD/Rev/Map.hs index 851e09e710..b39f899b17 100644 --- a/src/Futhark/AD/Rev/Map.hs +++ b/src/Futhark/AD/Rev/Map.hs @@ -75,6 +75,32 @@ withAcc inputs m = do subAD $ mkLambda (cert_params ++ acc_params) $ m $ map paramName acc_params letTupExp "withhacc_res" $ WithAcc inputs acc_lam +vecPerm :: Shape -> Type -> [Int] +vecPerm adj_shape t = + [shapeRank adj_shape] + ++ [0 .. shapeRank adj_shape - 1] + ++ [shapeRank adj_shape + 1 .. 
arrayRank t - 1] + +pushAdjShape :: VName -> ADM VName +pushAdjShape v = do + adj_shape <- askShape + v_t <- lookupType v + if adj_shape == mempty || arrayShape v_t == adj_shape || isAcc v_t + then pure v + else do + let perm = vecPerm adj_shape v_t + letExp (baseName v <> "_tr") $ BasicOp $ Rearrange v perm + +popAdjShape :: VName -> ADM VName +popAdjShape v = do + adj_shape <- askShape + v_t <- lookupType v + if adj_shape == mempty || arrayShape v_t == adj_shape || isAcc v_t + then pure v + else do + let perm = rearrangeInverse $ vecPerm adj_shape v_t + letExp (baseName v <> "_tr") $ BasicOp $ Rearrange v perm + -- | Perform VJP on a Map. The 'Adj' list is the adjoints of the -- result of the map. vjpMap :: VjpOps -> [Adj] -> StmAux () -> SubExp -> Lambda SOACS -> [VName] -> ADM () @@ -150,8 +176,8 @@ vjpMap ops res_adjs _ w map_lam as zipWithM_ forRes [0 ..] res_ivs where - isSparse (AdjSparse (Sparse shape _ ivs)) = do - guard $ shapeDims shape == [w] + isSparse (AdjSparse (Sparse shape _ vd ivs)) = do + guard $ drop vd (shapeDims shape) == [w] Just ivs isSparse _ = Nothing @@ -161,7 +187,8 @@ vjpMap ops pat_adj aux w map_lam as = returnSweepCode $ do pat_adj_vals <- forM (zip pat_adj (lambdaReturnType map_lam)) $ \(adj, t) -> case t of Acc {} -> letExp "acc_adj_rep" . BasicOp . Replicate (Shape [w]) . Var =<< adjVal adj - _ -> adjVal adj + _ -> pushAdjShape =<< adjVal adj + pat_adj_params <- mapM (newParam "map_adj_p" . 
rowType <=< lookupType) pat_adj_vals @@ -195,8 +222,8 @@ vjpMap ops pat_adj aux w map_lam as = returnSweepCode $ do let param_ts = map paramType (lambdaParams map_lam') forM_ (zip3 param_ts as param_contribs) $ \(param_t, a, param_contrib) -> case param_t of - Acc {} -> freeContrib a param_contrib - _ -> updateAdj a param_contrib + Acc {} -> freeContrib a =<< popAdjShape param_contrib -- CHECKME + _ -> updateAdj a =<< popAdjShape param_contrib where addIdxParams n lam = do idxs <- replicateM n $ newParam "idx" $ Prim int64 diff --git a/src/Futhark/AD/Rev/Monad.hs b/src/Futhark/AD/Rev/Monad.hs index 33de2fc397..70a160afae 100644 --- a/src/Futhark/AD/Rev/Monad.hs +++ b/src/Futhark/AD/Rev/Monad.hs @@ -9,10 +9,12 @@ module Futhark.AD.Rev.Monad ( ADM, RState (..), + REnv, runADM, Adj (..), InBounds (..), Sparse (..), + askShape, adjFromParam, adjFromVar, lookupAdj, @@ -43,6 +45,7 @@ module Futhark.AD.Rev.Monad zeroArray, unitAdjOfType, addLambda, + vecOpExp, -- VjpOps (..), -- @@ -50,14 +53,19 @@ module Futhark.AD.Rev.Monad lookupLoopTape, substLoopTape, renameLoopTape, + -- + locallyNonvector, + vecToInner, ) where import Control.Monad +import Control.Monad.Reader import Control.Monad.State.Strict import Data.Bifunctor (second) import Data.Map qualified as M import Data.Maybe +import Futhark.AD.Shared import Futhark.Analysis.Alias qualified as Alias import Futhark.Analysis.PrimExp.Convert import Futhark.Builder @@ -104,10 +112,17 @@ data InBounds -- | A symbolic representation of an array that is all zeroes, except -- at certain indexes. data Sparse = Sparse - { -- | The shape of the array. + { -- | The full shape of the array (including any vector dimensions, which are + -- stored in sparseVecDims). sparseShape :: Shape, -- | Element type of the array. sparseType :: PrimType, + -- | Number of leading dimensions that are \"vector\" dimensions, due to + -- vector AD. These are not indexed by the sparse index, but are present in + -- the values. 
When zero, this is the ordinary non-vector case. This is + -- equivalent to the rank of `askShape`, but it is convenient to store it + -- here as well. + sparseVecDims :: Int, -- | Locations and values of nonzero values. Indexes may be -- negative, in which case the value is ignored (unless -- 'AssumeBounds' is used). @@ -140,14 +155,15 @@ zeroArray shape t Replicate shape zero sparseArray :: (MonadBuilder m, Rep m ~ SOACS) => Sparse -> m VName -sparseArray (Sparse shape t ivs) = do +sparseArray (Sparse shape t vec_dims ivs) = do flip (foldM f) ivs =<< zeroArray shape (Prim t) where arr_t = Prim t `arrayOfShape` shape + vec_slice = map sliceDim $ take vec_dims $ shapeDims shape f arr (check, i, se) = do let stm s = letExp "sparse" . BasicOp $ - Update s arr (fullSlice arr_t [DimFix i]) se + Update s arr (fullSlice arr_t (vec_slice ++ [DimFix i])) se case check of AssumeBounds -> stm Unsafe CheckBounds _ -> stm Safe @@ -170,8 +186,8 @@ unitAdjOfType t = AdjVal <$> letSubExp "adj_unit" (oneExp t) adjRep :: Adj -> ([SubExp], [SubExp] -> Adj) adjRep (AdjVal se) = ([se], \[se'] -> AdjVal se') adjRep (AdjZero shape pt) = ([], \[] -> AdjZero shape pt) -adjRep (AdjSparse (Sparse shape pt ivs)) = - (concatMap ivRep ivs, AdjSparse . Sparse shape pt . repIvs ivs) +adjRep (AdjSparse (Sparse shape pt vd ivs)) = + (concatMap ivRep ivs, AdjSparse . Sparse shape pt vd . 
repIvs ivs) where ivRep (_, i, v) = [i, v] repIvs ((check, _, _) : ivs') (i : v : ses) = @@ -197,12 +213,18 @@ data RState = RState stateNameSource :: VNameSource } -newtype ADM a = ADM (BuilderT SOACS (State RState) a) +data REnv = REnv + { envAdjShape :: Shape, + envAttrs :: Attrs + } + +newtype ADM a = ADM (BuilderT SOACS (ReaderT REnv (State RState)) a) deriving ( Functor, Applicative, Monad, MonadState RState, + MonadReader REnv, MonadFreshNames, HasScope SOACS, LocalScope SOACS @@ -221,16 +243,21 @@ instance MonadFreshNames (State RState) where getNameSource = gets stateNameSource putNameSource src = modify (\env -> env {stateNameSource = src}) -runADM :: (MonadFreshNames m) => ADM a -> m a -runADM (ADM m) = +askShape :: ADM Shape +askShape = ADM $ lift $ asks envAdjShape + +runADM :: (MonadFreshNames m) => Shape -> Attrs -> ADM a -> m a +runADM shape attrs (ADM m) = modifyNameSource $ \vn -> second stateNameSource $ runState - (fst <$> runBuilderT m mempty) + ( runReaderT (fst <$> runBuilderT m mempty) $ + REnv shape attrs + ) (RState mempty mempty mempty vn) adjVal :: Adj -> ADM VName -adjVal (AdjVal se) = letExp "const_adj" $ BasicOp $ SubExp se +adjVal (AdjVal se) = letExp "const_val_adj" $ BasicOp $ SubExp se adjVal (AdjSparse sparse) = sparseArray sparse adjVal (AdjZero shape t) = zeroArray shape $ Prim t @@ -312,13 +339,12 @@ noAdjsFor names m = do where names' = namesToList names -addBinOp :: PrimType -> BinOp -addBinOp (IntType it) = Add it OverflowWrap -addBinOp (FloatType ft) = FAdd ft -addBinOp Bool = LogAnd -addBinOp Unit = LogAnd - -tabNest :: Int -> [VName] -> ([VName] -> [VName] -> ADM [VName]) -> ADM [VName] +tabNest :: + (MonadBuilder m, Rep m ~ SOACS) => + Int -> + [VName] -> + ([VName] -> [VName] -> m [VName]) -> + m [VName] tabNest = tabNest' [] where tabNest' is 0 vs f = f (reverse is) vs @@ -338,13 +364,14 @@ tabNest = tabNest' [] let lam = Lambda (iparam : params) ret (Body () stms res) letTupExp "tab" . Op . 
Screma w (iota : vs) =<< mapSOAC lam --- | Construct a lambda for adding two values of the given type. -addLambda :: Type -> ADM (Lambda SOACS) -addLambda (Prim pt) = binOpLambda (addBinOp pt) pt -addLambda t@Array {} = do +-- | Construct a lambda for binop'ing two values of the given type, +-- which may be arrays. +vecOpLambda :: (PrimType -> BinOp) -> Type -> ADM (Lambda SOACS) +vecOpLambda bop (Prim pt) = binOpLambda (bop pt) pt +vecOpLambda bop t@Array {} = do xs_p <- newParam "xs" t ys_p <- newParam "ys" t - lam <- addLambda $ rowType t + lam <- vecOpLambda bop $ rowType t body <- insertStmsM $ do res <- letSubExp "lam_map" @@ -358,10 +385,10 @@ addLambda t@Array {} = do lambdaReturnType = [t], lambdaBody = body } -addLambda t = - error $ "addLambda: " ++ show t +vecOpLambda _ t = + error $ "vecOpLambda: " ++ show t --- Construct an expression for adding the two variables. +-- | Construct an expression for adding the two variables. addExp :: VName -> VName -> ADM (Exp SOACS) addExp x y = do x_t <- lookupType x @@ -374,9 +401,23 @@ addExp x y = do _ -> error $ "addExp: unexpected type: " ++ prettyString x_t +-- | Construct an expression for performing this binary operation on two variables. +vecOpExp :: (PrimType -> BinOp) -> VName -> VName -> ADM (Exp SOACS) +vecOpExp bop x y = do + x_t <- lookupType x + case x_t of + Prim pt -> + pure $ BasicOp $ BinOp (bop pt) (Var x) (Var y) + Array {} -> do + lam <- vecOpLambda bop $ rowType x_t + Op . Screma (arraySize 0 x_t) [x, y] <$> mapSOAC lam + _ -> + error $ "vecOpExp: unexpected type: " ++ prettyString x_t + lookupAdj :: VName -> ADM Adj lookupAdj v = do maybeAdj <- gets $ M.lookup v . 
stateAdjs + adj_shape <- askShape case maybeAdj of Nothing -> do v_t <- lookupType v @@ -384,7 +425,7 @@ lookupAdj v = do Acc _ shape [Prim t] _ -> pure $ AdjZero shape t Acc _ shape [t] _ -> pure $ AdjZero (shape <> arrayShape t) (elemType t) Acc {} -> error $ "lookupAdj: Non-singleton accumulator adjoint: " <> prettyString v_t - _ -> pure $ AdjZero (arrayShape v_t) (elemType v_t) + _ -> pure $ AdjZero (adj_shape <> arrayShape v_t) (elemType v_t) Just v_adj -> pure v_adj lookupAdjVal :: VName -> ADM VName @@ -394,27 +435,34 @@ updateAdjIndex :: VName -> (InBounds, SubExp) -> SubExp -> ADM () updateAdjIndex v (check, i) se = do maybeAdj <- gets $ M.lookup v . stateAdjs t <- lookupType v + adj_shape <- askShape let iv = (check, i, se) + vec_dims = shapeRank adj_shape + full_shape = adj_shape <> arrayShape t case maybeAdj of - Nothing -> do - setAdj v $ AdjSparse $ Sparse (arrayShape t) (elemType t) [iv] - Just AdjZero {} -> - setAdj v $ AdjSparse $ Sparse (arrayShape t) (elemType t) [iv] - Just (AdjSparse (Sparse shape pt ivs)) -> - setAdj v $ AdjSparse $ Sparse shape pt $ iv : ivs + Nothing -> + setAdj v $ AdjSparse $ Sparse full_shape (elemType t) vec_dims [iv] + Just (AdjZero {}) -> + setAdj v $ AdjSparse $ Sparse full_shape (elemType t) vec_dims [iv] + Just (AdjSparse (Sparse shape pt vd ivs)) -> + setAdj v $ AdjSparse $ Sparse shape pt vd $ iv : ivs Just adj@AdjVal {} -> do v_adj <- adjVal adj v_adj_t <- lookupType v_adj se_v <- letExp "se_v" $ BasicOp $ SubExp se + vec_shape <- askShape insAdj v =<< case v_adj_t of Acc {} -> do let stms s = do + attrs <- asks envAttrs dims <- arrayDims <$> lookupType se_v ~[v_adj'] <- - tabNest (length dims) [se_v, v_adj] $ \is [se_v', v_adj'] -> - letTupExp "acc" . BasicOp $ - UpdateAcc s v_adj' (i : map Var is) [Var se_v'] + attributing attrs $ + tabNest (length dims) [se_v, v_adj] $ \is [se_v', v_adj'] -> do + let (vec_is, val_is) = splitAt (shapeRank vec_shape) $ map Var is + letTupExp "acc" . 
BasicOp $ + UpdateAcc s v_adj' (vec_is ++ i : val_is) [Var se_v'] pure v_adj' case check of CheckBounds _ -> stms Safe @@ -422,13 +470,15 @@ updateAdjIndex v (check, i) se = do OutOfBounds -> pure v_adj _ -> do let stms s = do + let slice = + fullSlice v_adj_t $ + map sliceDim (shapeDims vec_shape) ++ [DimFix i] v_adj_i <- letExp (baseName v_adj <> "_i") . BasicOp $ - Index v_adj $ - fullSlice v_adj_t [DimFix i] + Index v_adj slice se_update <- letSubExp "updated_adj_i" =<< addExp se_v v_adj_i letExp (baseName v_adj) . BasicOp $ - Update s v_adj (fullSlice v_adj_t [DimFix i]) se_update + Update s v_adj slice se_update case check of CheckBounds _ -> stms Safe AssumeBounds -> stms Unsafe @@ -518,7 +568,8 @@ subSubsts m = do data VjpOps = VjpOps { vjpLambda :: [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS), - vjpStm :: Stm SOACS -> ADM () -> ADM () + vjpStm :: Stm SOACS -> ADM () -> ADM (), + vjpBody :: [Adj] -> [VName] -> Body SOACS -> ADM (Body SOACS) } -- | @setLoopTape v vs@ establishes @vs@ as the name of the array @@ -543,6 +594,50 @@ substLoopTape v v' = mapM_ (setLoopTape v') =<< lookupLoopTape v renameLoopTape :: Substitutions -> ADM () renameLoopTape = mapM_ (uncurry substLoopTape) . M.toList +-- | Disable vector AD within the provided action. This results in a map that +-- computes each adjoint explicitly, then assembles the resulting adjoint +-- vectors. This is useful for constructs (such as scans) where vector AD is +-- impractical or inefficient. +locallyNonvector :: + (FreeIn e) => + -- | Something that represents all the free variables used in the action. + -- Usually just an expression or statement. + e -> + ADM () -> + ADM () +locallyNonvector e m = do + adj_shape <- askShape + if adj_shape == mempty + then m + else do + -- We map over all adjoints of free variables in 'e'. To avoid clutter, we + -- only consider those that actually have known nonzero adjoints. 
+ e_adjs <- filterM knownAdjoint e_free + e_adjs_vals <- mapM lookupAdjVal e_adjs + e_free_adjs <- mkMap "nonvec_adj" adj_shape e_adjs_vals $ \e_adjs_vals' -> do + zipWithM_ insAdj e_adjs e_adjs_vals' + local (\env -> env {envAdjShape = mempty}) m + mapM lookupAdjVal e_free + zipWithM_ insAdj e_free e_free_adjs + where + e_free = namesToList $ freeIn e + knownAdjoint v = do + v_adj <- lookupAdj v + pure $ case v_adj of + AdjZero {} -> False + _ -> True + +-- | If we are doing vector AD, apply 'vecPerm' to the array. +vecToInner :: VName -> ADM VName +vecToInner v = do + adj_shape <- askShape + if adj_shape == mempty + then pure v + else do + v_t <- lookupType v + letExp (baseName v <> "_tr") . BasicOp . Rearrange v $ + vecPerm adj_shape v_t + -- Note [Consumption] -- -- Parts of this transformation depends on duplicating computation. diff --git a/src/Futhark/AD/Rev/Reduce.hs b/src/Futhark/AD/Rev/Reduce.hs index 67201d50e4..8d8deca5d5 100644 --- a/src/Futhark/AD/Rev/Reduce.hs +++ b/src/Futhark/AD/Rev/Reduce.hs @@ -9,7 +9,9 @@ module Futhark.AD.Rev.Reduce where import Control.Monad +import Data.Tuple import Futhark.AD.Rev.Monad +import Futhark.AD.Shared import Futhark.Analysis.PrimExp.Convert import Futhark.Builder import Futhark.IR.SOACS @@ -78,10 +80,8 @@ diffReduce _ops [adj] w [a] red | Just [(op, _, _, _)] <- lamIsBinOp $ redLambda red, isAdd op = do adj_rep <- - letExp (baseName adj <> "_rep") $ - BasicOp $ - Replicate (Shape [w]) $ - Var adj + vecToInner <=< letExp (baseName adj <> "_rep") $ + BasicOp (Replicate (Shape [w]) (Var adj)) void $ updateAdj a adj_rep where isAdd FAdd {} = True @@ -118,10 +118,9 @@ diffReduce ops pat_adj w as red = do f_adj <- vjpLambda ops (map adjFromVar pat_adj) as_params f as_adj <- - letTupExp "adjs" . Op . Screma w (ls ++ as ++ rs) - =<< mapSOAC f_adj + letTupExp "red_contribs" . Op . 
Screma w (ls ++ as ++ rs) =<< mapSOAC f_adj - zipWithM_ updateAdj as as_adj + zipWithM_ updateAdj as =<< mapM vecToInner as_adj where renameRed (Reduce comm lam nes) = Reduce comm <$> renameLambda lam <*> pure nes @@ -253,7 +252,8 @@ diffMulReduce :: VjpOps -> VName -> StmAux () -> SubExp -> BinOp -> SubExp -> VName -> ADM () -> ADM () diffMulReduce _ops x aux w mul ne as m = do let t = binOpType mul - let const_zero = eSubExp $ Constant $ blankPrimValue t + let zero = Constant $ blankPrimValue t + const_zero = eSubExp zero a_param <- newParam "a" $ Prim t map_lam <- @@ -292,42 +292,45 @@ diffMulReduce _ops x aux w mul ne as m = do x_adj <- lookupAdjVal x + adj_shape <- askShape + + zero_contrib <- letExp "zero_contrib" $ BasicOp $ Replicate adj_shape zero + a_param_rev <- newParam "a" $ Prim t map_lam_rev <- mkLambda [a_param_rev] $ fmap varsRes . letTupExp "adj_res" =<< eIf (toExp $ 0 .==. le64 zr_count) - ( eBody $ - pure $ - eBinOp mul (eSubExp $ Var x_adj) $ - eBinOp (getDiv t) (eSubExp $ Var nz_prods) $ - eParam a_param_rev + ( eBody + [ mapNest adj_shape (MkSolo (Var x_adj)) $ \(MkSolo x_adj') -> + eBinOp mul (eSubExp x_adj') $ + eBinOp (getDiv t) (eVar nz_prods) $ + eParam a_param_rev + ] ) - ( eBody $ - pure $ - eIf + ( eBody + [ eIf (toExp $ 1 .==. le64 zr_count) - ( eBody $ - pure $ - eIf + ( eBody + [ eIf (eCmpOp (CmpEq t) (eParam a_param_rev) const_zero) - ( eBody $ - pure $ - eBinOp mul (eSubExp $ Var x_adj) $ - eSubExp $ - Var nz_prods + ( eBody + [ mapNest adj_shape (MkSolo (Var x_adj)) $ + \(MkSolo x_adj') -> + eBinOp mul (eSubExp x_adj') $ eVar nz_prods + ] ) - (eBody $ pure const_zero) + (eBody [eVar zero_contrib]) + ] ) - (eBody $ pure const_zero) + (eBody [eVar zero_contrib]) + ] ) - as_adjup <- - letExp "adjs" . Op . Screma w [as] - =<< mapSOAC map_lam_rev + as_adjup <- letExp "prod_contrib" . Op . 
Screma w [as] =<< mapSOAC map_lam_rev - updateAdj as as_adjup + updateAdj as =<< vecToInner as_adjup where getDiv :: PrimType -> BinOp getDiv (IntType t) = SDiv t Unsafe diff --git a/src/Futhark/AD/Rev/SOAC.hs b/src/Futhark/AD/Rev/SOAC.hs index 11ce52e8d1..be7e07bb9a 100644 --- a/src/Futhark/AD/Rev/SOAC.hs +++ b/src/Futhark/AD/Rev/SOAC.hs @@ -177,14 +177,14 @@ vjpSOAC ops pat aux (Hist n [is, vs] [histop] f) m Just [(op, _, _, _)] <- lamIsBinOp lam', isAddOp op = diffAddHist ops x aux n lam ne is vs w rf dst m -vjpSOAC ops pat aux (Hist n as [histop] f) m +vjpSOAC ops pat aux (Hist w as [histop] f) m | isIdentityLambda f, - HistOp (Shape w) rf dst ne lam <- histop = do - diffHist ops (patNames pat) aux n lam ne as w rf dst m -vjpSOAC ops pat _aux (Hist n as histops f) m + HistOp (Shape n) rf dst ne lam <- histop = do + diffHist ops (patNames pat) aux w lam ne as n rf dst m +vjpSOAC ops pat _aux (Hist w as histops f) m | not (isIdentityLambda f) = do (mapstm, redstm) <- - histomapToMapAndHist pat (n, histops, f, as) + histomapToMapAndHist pat (w, histops, f, as) vjpStm ops mapstm $ vjpStm ops redstm m vjpSOAC ops pat aux (Stream w as accs lam) m = do stms <- collectStms_ $ auxing aux $ sequentialStreamWholeArray pat w accs lam as @@ -194,13 +194,14 @@ vjpSOAC _ops pat aux (WithVJP args lam lam_adj) m = do forM_ (zip (patNames pat) lam_res) $ \(v, SubExpRes cs se) -> certifying cs $ letBindNames [v] $ BasicOp $ SubExp se m - pat_adj <- mapM lookupAdjVal $ patNames pat - contribs <- - eLambda lam_adj (map (eSubExp . resSubExp) lam_res ++ map (eSubExp . Var) pat_adj) - forM_ (zip args contribs) $ \(arg, contrib) -> - (updateSubExpAdj arg <=< letExp "contrib") $ - BasicOp . SubExp . resSubExp $ - contrib + locallyNonvector (patNames pat, args) $ do + pat_adj <- mapM lookupAdjVal $ patNames pat + contribs <- + eLambda lam_adj (map (eSubExp . resSubExp) lam_res ++ map (eSubExp . 
Var) pat_adj) + forM_ (zip args contribs) $ \(arg, contrib) -> + (updateSubExpAdj arg <=< letExp "contrib") $ + BasicOp . SubExp . resSubExp $ + contrib vjpSOAC _ _ _ soac _ = error $ "vjpSOAC unhandled:\n" ++ prettyString soac diff --git a/src/Futhark/AD/Rev/Scan.hs b/src/Futhark/AD/Rev/Scan.hs index d2ad0bc3ae..3ec61705d8 100644 --- a/src/Futhark/AD/Rev/Scan.hs +++ b/src/Futhark/AD/Rev/Scan.hs @@ -379,7 +379,7 @@ finalMapPPAD ops as scan = do eLambda op_bar_2 $ toExp . Var . paramName <$> par_y_right ++ par_a diffScan :: VjpOps -> [VName] -> SubExp -> [VName] -> Scan SOACS -> ADM () -diffScan ops ys w as scan = do +diffScan ops ys w as scan = locallyNonvector (ys, scan, as) $ do -- ys ~ results of scan, w ~ size of input array, as ~ (unzipped) -- arrays, scan ~ scan: operator with ne scan_case <- identifyCase ops $ scanLambda scan @@ -409,14 +409,12 @@ diffScan ops ys w as scan = do map1_lam <- mkScanFusedMapLam ops w (scanLambda scan) as ys ys_adj sc d scans_lin_fun_o <- mkScanLinFunO (head as_ts) sc scan_lams <- mkScans (specialScans sc) scans_lin_fun_o - iota <- - letExp "iota" $ BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64 + iota <- letExp "iota" $ iota64 w r_scan <- letTupExp "adj_ctrb_scan" . Op . 
Screma w [iota] =<< scanomapSOAC scan_lams map1_lam mkScanFinalMap ops w (scanLambda scan) as ys (splitScanRes sc r_scan d) -- Goal: calculate as_contribs in new way - -- zipWithM_ updateAdj as as_contribs -- as_bar += new adjoint zipWithM_ updateAdj as as_contribs where mkScans :: Int -> Scan SOACS -> ADM [Scan SOACS] @@ -468,7 +466,7 @@ diffScanVec ops ys aux w lam ne as m = do foldr (vjpStm ops) m stmts diffScanAdd :: VjpOps -> VName -> SubExp -> Lambda SOACS -> SubExp -> VName -> ADM () -diffScanAdd _ops ys n lam' ne as = do +diffScanAdd _ops ys n lam' ne as = locallyNonvector (ys, lam', as) $ do lam <- renameLambda lam' ys_bar <- lookupAdjVal ys diff --git a/src/Futhark/AD/Shared.hs b/src/Futhark/AD/Shared.hs new file mode 100644 index 0000000000..9bcfa47904 --- /dev/null +++ b/src/Futhark/AD/Shared.hs @@ -0,0 +1,62 @@ +-- | Various definitions used for both forward and reverse mode. +module Futhark.AD.Shared + ( vecPerm, + asVName, + mapNest, + mkMap, + ) +where + +import Control.Monad +import Data.Foldable +import Futhark.Construct +import Futhark.IR.SOACS + +-- | A permutation for transposing the vector shape past the next dimension. +-- +-- That is, converts @[vec...][d][elem...]@ to @[d][vec...][elem...]@. +vecPerm :: Shape -> Type -> [Int] +vecPerm vec_shape t = + [shapeRank vec_shape] + ++ [0 .. shapeRank vec_shape - 1] + ++ [shapeRank vec_shape + 1 .. arrayRank t - 1] + +asVName :: (MonadBuilder m) => SubExp -> m VName +asVName (Var v) = pure v +asVName (Constant x) = letExp "asv" $ BasicOp $ SubExp $ Constant x + +mapNest :: + (MonadBuilder m, Rep m ~ SOACS, Traversable f) => + Shape -> + f SubExp -> + (f SubExp -> m (Exp SOACS)) -> + m (Exp SOACS) +mapNest shape x f = do + if shape == mempty + then f x + else do + let w = shapeSize 0 shape + x_v <- traverse asVName x + x_p <- traverse (newParam "xp" . rowType <=< lookupType) x_v + lam <- mkLambda (toList x_p) $ do + fmap (subExpsRes . pure) . letSubExp "mapnest_res" + =<< f (fmap (Var . 
paramName) x_p) + Op . Screma w (toList x_v) <$> mapSOAC lam + +-- | Construct a map over the given arrays, which must have the provided outer +-- shape. The purpose of the 'Shape' argument is to handle the case where no +-- arrays are provided. +mkMap :: + (MonadBuilder m, Rep m ~ SOACS, Traversable f) => + Name -> + Shape -> + f VName -> + -- | Action for building the body, passed names + -- corresponding to elements of the arrays. + (f VName -> m [VName]) -> + m [VName] +mkMap desc shape arrs f = do + let w = shapeSize 0 shape + x_p <- traverse (newParam "xp" . rowType <=< lookupType) arrs + lam <- mkLambda (toList x_p) $ varsRes <$> f (fmap paramName x_p) + letTupExp desc . Op . Screma w (toList arrs) =<< mapSOAC lam diff --git a/src/Futhark/Construct.hs b/src/Futhark/Construct.hs index 055544bba1..bebc33818a 100644 --- a/src/Futhark/Construct.hs +++ b/src/Futhark/Construct.hs @@ -68,6 +68,7 @@ module Futhark.Construct -- * Monadic expression builders eSubExp, + eVar, eParam, eMatch', eMatch, @@ -108,6 +109,7 @@ module Futhark.Construct fullSliceNum, isFullSlice, sliceAt, + iota64, -- * Result types instantiateShapes, @@ -204,6 +206,13 @@ eSubExp :: m (Exp (Rep m)) eSubExp = pure . BasicOp . SubExp +-- | Turn a variable into a monad expression, through 'eSubExp'. +eVar :: + (MonadBuilder m) => + VName -> + m (Exp (Rep m)) +eVar = eSubExp . Var + -- | Treat a parameter as a monadic expression. eParam :: (MonadBuilder m) => @@ -575,6 +584,11 @@ sliceAt :: Type -> Int -> [DimIndex SubExp] -> Slice SubExp sliceAt t n slice = fullSlice t $ map sliceDim (take n $ arrayDims t) ++ slice +-- | Produce a straightforward `Int64` `Iota` of the given length with offset 0 +-- and stride 1. +iota64 :: SubExp -> Exp rep +iota64 n = BasicOp $ Iota n (intConst Int64 0) (intConst Int64 1) Int64 + -- | Like 'fullSlice', but the dimensions are simply numeric. 
fullSliceNum :: (Num d) => [d] -> [DimIndex d] -> Slice d fullSliceNum dims slice = diff --git a/src/Futhark/IR/Parse.hs b/src/Futhark/IR/Parse.hs index ba6104868b..d7c8aabae5 100644 --- a/src/Futhark/IR/Parse.hs +++ b/src/Futhark/IR/Parse.hs @@ -839,7 +839,9 @@ pSOAC pr = pVJP = parens $ SOAC.VJP - <$> braces (pSubExp `sepBy` pComma) + <$> pShape + <* pComma + <*> braces (pSubExp `sepBy` pComma) <* pComma <*> braces (pSubExp `sepBy` pComma) <* pComma @@ -847,7 +849,9 @@ pSOAC pr = pJVP = parens $ SOAC.JVP - <$> braces (pSubExp `sepBy` pComma) + <$> pShape + <* pComma + <*> braces (pSubExp `sepBy` pComma) <* pComma <*> braces (pSubExp `sepBy` pComma) <* pComma diff --git a/src/Futhark/IR/SOACS/SOAC.hs b/src/Futhark/IR/SOACS/SOAC.hs index 7d5cdbd5e0..773682c231 100644 --- a/src/Futhark/IR/SOACS/SOAC.hs +++ b/src/Futhark/IR/SOACS/SOAC.hs @@ -79,9 +79,9 @@ data SOAC rep -- The final lambda produces indexes and values for the 'HistOp's. Hist SubExp [VName] [HistOp rep] (Lambda rep) | -- FIXME: this should not be here - JVP [SubExp] [SubExp] (Lambda rep) + JVP Shape [SubExp] [SubExp] (Lambda rep) | -- FIXME: this should not be here - VJP [SubExp] [SubExp] (Lambda rep) + VJP Shape [SubExp] [SubExp] (Lambda rep) | -- FIXME: this should not be here WithVJP [SubExp] (Lambda rep) (Lambda rep) | -- | A combination of scan, reduction, and map. 
The first @@ -401,14 +401,16 @@ mapSOACM :: SOACMapper frep trep m -> SOAC frep -> m (SOAC trep) -mapSOACM tv (JVP args vec lam) = +mapSOACM tv (JVP shape args vec lam) = JVP - <$> mapM (mapOnSOACSubExp tv) args + <$> mapM (mapOnSOACSubExp tv) shape + <*> mapM (mapOnSOACSubExp tv) args <*> mapM (mapOnSOACSubExp tv) vec <*> mapOnSOACLambda tv lam -mapSOACM tv (VJP args vec lam) = +mapSOACM tv (VJP shape args vec lam) = VJP - <$> mapM (mapOnSOACSubExp tv) args + <$> mapM (mapOnSOACSubExp tv) shape + <*> mapM (mapOnSOACSubExp tv) args <*> mapM (mapOnSOACSubExp tv) vec <*> mapOnSOACLambda tv lam mapSOACM tv (WithVJP args lam0 lam1) = @@ -509,10 +511,10 @@ instance (ASTRep rep) => Rename (SOAC rep) where -- | The type of a SOAC. soacType :: (Typed (LParamInfo rep)) => SOAC rep -> [Type] -soacType (JVP _ _ lam) = - lambdaReturnType lam ++ lambdaReturnType lam -soacType (VJP _ _ lam) = - lambdaReturnType lam ++ map paramType (lambdaParams lam) +soacType (JVP shape _ _ lam) = + lambdaReturnType lam ++ map (`arrayOfShape` shape) (lambdaReturnType lam) +soacType (VJP shape _ _ lam) = + lambdaReturnType lam ++ map ((`arrayOfShape` shape) . 
paramType) (lambdaParams lam) soacType (WithVJP _ lam _) = lambdaReturnType lam soacType (Stream outersize _ accs lam) = @@ -561,10 +563,10 @@ mapHistOp f (HistOp w rf dests nes lam) = HistOp w rf dests nes $ f lam instance CanBeAliased SOAC where - addOpAliases aliases (JVP args vec lam) = - JVP args vec (Alias.analyseLambda aliases lam) - addOpAliases aliases (VJP args vec lam) = - VJP args vec (Alias.analyseLambda aliases lam) + addOpAliases aliases (JVP shape args vec lam) = + JVP shape args vec (Alias.analyseLambda aliases lam) + addOpAliases aliases (VJP shape args vec lam) = + VJP shape args vec (Alias.analyseLambda aliases lam) addOpAliases aliases (WithVJP args lam lam_adj) = WithVJP args @@ -618,12 +620,12 @@ instance IsOp SOAC where concatIndicesToEachValue is vs = let is_flat = mconcat is in map (is_flat <>) vs - opDependencies (JVP args vec lam) = + opDependencies (JVP _ args vec lam) = mconcat $ replicate 2 $ lambdaDependencies mempty lam $ zipWith (<>) (map depsOf' args) (map depsOf' vec) - opDependencies (VJP args vec lam) = + opDependencies (VJP _ args vec lam) = lambdaDependencies mempty lam @@ -708,26 +710,32 @@ instance (RepTypes rep) => ST.IndexOp (SOAC rep) where -- | Type-check a SOAC. typeCheckSOAC :: (TC.Checkable rep) => SOAC (Aliases rep) -> TC.TypeM rep () -typeCheckSOAC (VJP args vec lam) = do +typeCheckSOAC (VJP shape args vec lam) = do + mapM_ (TC.require $ Prim int64) shape args' <- mapM TC.checkArg args TC.checkLambda lam $ map TC.noArgAliases args' vec_ts <- mapM TC.checkSubExp vec - unless (vec_ts == lambdaReturnType lam) $ + unless (vec_ts == map (`arrayOfShape` shape) (lambdaReturnType lam)) $ TC.bad . TC.TypeError . 
docText $ "Return type" PP.indent 2 (pretty (lambdaReturnType lam)) - "does not match type of seed vector" + "inconsistent with type of seed vector" PP.indent 2 (pretty vec_ts) -typeCheckSOAC (JVP args vec lam) = do + "with shape" + PP.indent 2 (pretty shape) +typeCheckSOAC (JVP shape args vec lam) = do + mapM_ (TC.require $ Prim int64) shape args' <- mapM TC.checkArg args TC.checkLambda lam $ map TC.noArgAliases args' vec_ts <- mapM TC.checkSubExp vec - unless (vec_ts == map TC.argType args') $ + unless (vec_ts == map ((`arrayOfShape` shape) . TC.argType) args') $ TC.bad . TC.TypeError . docText $ "Parameter type" PP.indent 2 (pretty $ map TC.argType args') "does not match type of seed vector" PP.indent 2 (pretty vec_ts) + "with shape" + PP.indent 2 (pretty shape) typeCheckSOAC (WithVJP args lam lam_adj) = do args' <- mapM TC.checkArg args TC.checkLambda lam $ map TC.noArgAliases args' @@ -850,10 +858,10 @@ typeCheckReduce (Reduce _ red_lam red_nes) = do pure red_nes' instance RephraseOp SOAC where - rephraseInOp r (VJP args vec lam) = - VJP args vec <$> rephraseLambda r lam - rephraseInOp r (JVP args vec lam) = - JVP args vec <$> rephraseLambda r lam + rephraseInOp r (VJP shape args vec lam) = + VJP shape args vec <$> rephraseLambda r lam + rephraseInOp r (JVP shape args vec lam) = + JVP shape args vec <$> rephraseLambda r lam rephraseInOp r (WithVJP args lam lam_adj) = WithVJP args <$> rephraseLambda r lam <*> rephraseLambda r lam_adj rephraseInOp r (Stream w arrs acc lam) = @@ -881,9 +889,9 @@ rephraseScan r (Scan op nes) = Scan <$> rephraseLambda r op <*> pure nes instance (OpMetrics (Op rep)) => OpMetrics (SOAC rep) where - opMetrics (VJP _ _ lam) = + opMetrics (VJP _ _ _ lam) = inside "VJP" $ lambdaMetrics lam - opMetrics (JVP _ _ lam) = + opMetrics (JVP _ _ _ lam) = inside "JVP" $ lambdaMetrics lam opMetrics (WithVJP _ lam lam_adj) = do inside "WithVJP" $ lambdaMetrics lam @@ -900,19 +908,21 @@ instance (OpMetrics (Op rep)) => OpMetrics (SOAC rep) where 
lambdaMetrics post_lam instance (PrettyRep rep) => PP.Pretty (SOAC rep) where - pretty (VJP args vec lam) = + pretty (VJP shape args vec lam) = "vjp" <> parens ( PP.align $ - PP.braces (commasep $ map pretty args) + pretty shape + <> comma PP.braces (commasep $ map pretty args) <> comma PP.braces (commasep $ map pretty vec) <> comma pretty lam ) - pretty (JVP args vec lam) = + pretty (JVP shape args vec lam) = "jvp" <> parens ( PP.align $ - PP.braces (commasep $ map pretty args) + pretty shape + <> comma PP.braces (commasep $ map pretty args) <> comma PP.braces (commasep $ map pretty vec) <> comma pretty lam ) diff --git a/src/Futhark/IR/SOACS/Simplify.hs b/src/Futhark/IR/SOACS/Simplify.hs index 9fc739c512..c2c471e357 100644 --- a/src/Futhark/IR/SOACS/Simplify.hs +++ b/src/Futhark/IR/SOACS/Simplify.hs @@ -91,16 +91,18 @@ simplifyConsts = simplifySOAC :: (Simplify.SimplifiableRep rep) => Simplify.SimplifyOp rep (SOAC (Wise rep)) -simplifySOAC (VJP arr vec lam) = do - (lam', hoisted) <- Engine.simplifyLambda mempty lam +simplifySOAC (VJP shape arr vec lam) = do + shape' <- traverse Engine.simplify shape arr' <- mapM Engine.simplify arr vec' <- mapM Engine.simplify vec - pure (VJP arr' vec' lam', hoisted) -simplifySOAC (JVP arr vec lam) = do (lam', hoisted) <- Engine.simplifyLambda mempty lam + pure (VJP shape' arr' vec' lam', hoisted) +simplifySOAC (JVP shape arr vec lam) = do + shape' <- traverse Engine.simplify shape arr' <- mapM Engine.simplify arr vec' <- mapM Engine.simplify vec - pure (JVP arr' vec' lam', hoisted) + (lam', hoisted) <- Engine.simplifyLambda mempty lam + pure (JVP shape' arr' vec' lam', hoisted) simplifySOAC (WithVJP args lam lam_adj) = do args' <- mapM Engine.simplify args (lam', hoisted) <- Engine.simplifyLambda mempty lam diff --git a/src/Futhark/IR/TypeCheck.hs b/src/Futhark/IR/TypeCheck.hs index a106e386ae..bf0c94df09 100644 --- a/src/Futhark/IR/TypeCheck.hs +++ b/src/Futhark/IR/TypeCheck.hs @@ -111,7 +111,7 @@ instance (Checkable rep) => 
Show (ErrorCase rep) where show (InvalidPatError pat t desc) = "Pat\n" ++ prettyString pat - ++ "\ncannot match value of type\n" + ++ "\ncannot match expression of type\n" ++ T.unpack (prettyTupleLines t) ++ end where diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 4a3bd9b925..7550402269 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -1869,14 +1869,24 @@ isIntrinsicFunction qname args = do handleAccs _ _ = Nothing handleAD [f, x, v] fname - | fname `elem` ["jvp2", "vjp2"] = Just $ \desc -> do + | fname `elem` ["jvp2", "vjp2", "jmp2", "mjp2"] = Just $ \desc -> do x' <- internaliseExp "ad_x" x v' <- internaliseExp "ad_v" v + x_t <- subExpType $ head x' + v_t <- subExpType $ head v' lam <- internaliseLambdaCoerce f =<< mapM subExpType x' fmap (map I.Var) . letTupExp desc . Op $ case fname of - "jvp2" -> JVP x' v' lam - _ -> VJP x' v' lam + "jvp2" -> JVP mempty x' v' lam + "vjp2" -> VJP mempty x' v' lam + "jmp2" -> + JVP (vecShape x_t v_t) x' v' lam + "mjp2" -> + VJP (vecShape (head (lambdaReturnType lam)) v_t) x' v' lam + _ -> error "handleAD: not supposed to happen." 
+ where + vecShape t1 t2 = + I.Shape $ take (I.arrayRank t2 - I.arrayRank t1) (I.arrayDims t2) handleAD [f, f_adj, x] "with_vjp" = Just $ \desc -> do x' <- internaliseExp "ad_x" x lam <- internaliseLambdaCoerce f =<< mapM subExpType x' diff --git a/src/Futhark/Optimise/Fusion.hs b/src/Futhark/Optimise/Fusion.hs index 7a2db9bb8e..87fc6f73cf 100644 --- a/src/Futhark/Optimise/Fusion.hs +++ b/src/Futhark/Optimise/Fusion.hs @@ -667,12 +667,12 @@ runInnerFusionOnContext c@(incoming, node, nodeT, outgoing) = case nodeT of cases' <- mapM (traverse $ renameBody <=< (`doFusionWithDelayed` to_fuse)) cases defbody' <- doFusionWithDelayed defbody to_fuse pure (incoming, node, MatchNode (Let pat aux (Match cond cases' defbody' dec)) [], outgoing) - StmNode (Let pat aux (Op (Futhark.VJP args vec lam))) -> doFuseScans $ do + StmNode (Let pat aux (Op (Futhark.VJP shape args vec lam))) -> doFuseScans $ do lam' <- fst <$> doFusionInLambda lam - pure (incoming, node, StmNode (Let pat aux (Op (Futhark.VJP args vec lam'))), outgoing) - StmNode (Let pat aux (Op (Futhark.JVP args vec lam))) -> doFuseScans $ do + pure (incoming, node, StmNode (Let pat aux (Op (Futhark.VJP shape args vec lam'))), outgoing) + StmNode (Let pat aux (Op (Futhark.JVP shape args vec lam))) -> doFuseScans $ do lam' <- fst <$> doFusionInLambda lam - pure (incoming, node, StmNode (Let pat aux (Op (Futhark.JVP args vec lam'))), outgoing) + pure (incoming, node, StmNode (Let pat aux (Op (Futhark.JVP shape args vec lam'))), outgoing) StmNode (Let pat aux (WithAcc inputs lam)) -> doFuseScans $ do lam' <- fst <$> doFusionInLambda lam pure (incoming, node, StmNode (Let pat aux (WithAcc inputs lam')), outgoing) diff --git a/src/Futhark/Pass/AD.hs b/src/Futhark/Pass/AD.hs index b1b2f318ad..8d3c51fb65 100644 --- a/src/Futhark/Pass/AD.hs +++ b/src/Futhark/Pass/AD.hs @@ -36,20 +36,22 @@ bindLambda pat aux (Lambda params _ body) args = do certifying cs $ letBindNames [v] $ BasicOp $ SubExp se onStm :: Bool -> Mode -> Scope 
SOACS -> Stm SOACS -> PassM (Stms SOACS) -onStm _ mode scope (Let pat aux (Op (VJP args vec lam))) = do +onStm _ mode scope (Let pat aux (Op (VJP shape args vec lam))) = do lam' <- onLambda True mode scope lam if mode == All || lam == lam' then do - lam'' <- (`runReaderT` scope) . simplifyLambda =<< revVJP scope lam' + lam'' <- + (`runReaderT` scope) . simplifyLambda + =<< revVJP scope shape (stmAuxAttrs aux) lam' runBuilderT_ (bindLambda pat aux lam'' $ args ++ vec) scope - else pure $ oneStm $ Let pat aux $ Op $ VJP args vec lam' -onStm _ mode scope (Let pat aux (Op (JVP args vec lam))) = do + else pure $ oneStm $ Let pat aux $ Op $ VJP shape args vec lam' +onStm _ mode scope (Let pat aux (Op (JVP shape args vec lam))) = do lam' <- onLambda True mode scope lam if mode == All || lam == lam' then do - lam'' <- fwdJVP scope lam' + lam'' <- fwdJVP scope shape (stmAuxAttrs aux) lam' runBuilderT_ (bindLambda pat aux lam'' $ args ++ vec) scope - else pure $ oneStm $ Let pat aux $ Op $ JVP args vec lam' + else pure $ oneStm $ Let pat aux $ Op $ JVP shape args vec lam' -- -- This corresponds to a WithVJP that is not inside of a differential operator. -- FIXME: this assumption will go bad when we don't inline so much. diff --git a/src/Futhark/Tools.hs b/src/Futhark/Tools.hs index de70129dbc..e9583e587c 100644 --- a/src/Futhark/Tools.hs +++ b/src/Futhark/Tools.hs @@ -13,6 +13,8 @@ module Futhark.Tools partitionChunkedFoldParameters, withAcc, doScatter, + addBinOp, + addLambda, -- * Primitive expressions module Futhark.Analysis.PrimExp.Convert, @@ -333,3 +335,39 @@ doScatter desc rank dest arrs mk = do =<< mapSOAC map_lam letTupExp desc $ WithAcc [(acc_shape, [v], Nothing) | v <- dest] withacc_lam + +-- | The most addition-like binary operator for some primitive type. 
+addBinOp :: PrimType -> BinOp +addBinOp (IntType it) = Add it OverflowWrap +addBinOp (FloatType ft) = FAdd ft +addBinOp Bool = LogAnd +addBinOp Unit = LogAnd + +-- | Construct a lambda for adding two values of the given type, using SOACs to handle arrays. +addLambda :: + ( OpC (Rep m) ~ SOAC, + MonadBuilder m, + Buildable (Rep m) + ) => + TypeBase Shape NoUniqueness -> + m (Lambda (Rep m)) +addLambda (Prim pt) = binOpLambda (addBinOp pt) pt +addLambda t@Array {} = do + xs_p <- newParam "xs" t + ys_p <- newParam "ys" t + lam <- addLambda $ rowType t + body <- insertStmsM $ do + res <- + letSubExp "lam_map" + . Op + . Screma (arraySize 0 t) [paramName xs_p, paramName ys_p] + =<< mapSOAC lam + pure $ resultBody [res] + pure + Lambda + { lambdaParams = [xs_p, ys_p], + lambdaReturnType = [t], + lambdaBody = body + } +addLambda t = + error $ "addLambda: " ++ show t diff --git a/src/Futhark/Util.hs b/src/Futhark/Util.hs index 9eb9af04f6..725998a120 100644 --- a/src/Futhark/Util.hs +++ b/src/Futhark/Util.hs @@ -55,6 +55,8 @@ module Futhark.Util topologicalSort, debugTraceM, ensureCacheDirectory, + interleave, + unterleave, ) where @@ -376,6 +378,15 @@ convFloat v | isNaN v = 0 / 0 | otherwise = fromRational $ toRational v +-- | Interleave two lists. +interleave :: [a] -> [a] -> [a] +interleave xs ys = concat $ L.transpose [xs, ys] + +-- | The inverse of interleave. 
+unterleave :: [a] -> ([a], [a]) +unterleave (x : y : xys) = bimap (x :) (y :) $ unterleave xys +unterleave _ = ([], []) + -- Z-encoding from https://ghc.haskell.org/trac/ghc/wiki/Commentary/Compiler/SymbolNames -- -- Slightly simplified as we do not need it to deal with tuples and diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index 5f2ffc8bbf..a06f0bf4b3 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -2164,6 +2164,18 @@ initialCtx = def "manifest" = Just $ fun1 pure def "jvp2" = Just $ fun3 doJVP2 def "vjp2" = Just $ fun3 doVJP2 + def "jmp2" = Just $ fun3 $ \f x seeds -> do + v <- apply noLoc mempty f x + dvs <- + toArray' (valueShape v) . map (project "1") + <$> mapM (doJVP2 f x) (snd (fromArray seeds)) + pure $ toTuple [v, dvs] + def "mjp2" = Just $ fun3 $ \f x seeds -> do + v <- apply noLoc mempty f x + dvs <- + toArray' (valueShape x) . map (project "1") + <$> mapM (doVJP2 f x) (snd (fromArray seeds)) + pure $ toTuple [v, dvs] def "with_vjp" = Just $ fun3 $ \f _ arg -> -- XXX? We simply ignore the custom derivative. This is correct, but makes -- it more of a hassle to test them. 
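Reviewer note on the interpreter hunk above: `jmp2` and `mjp2` are defined there by running the ordinary single-seed operator once per seed and stacking the derivative parts. The following is a minimal, self-contained Haskell sketch of that reading for the forward case. It is only a model under stated assumptions: `Dual`, `jvpModel`, `jmp2Model`, and `primal` are hypothetical names introduced for illustration and are not part of the Futhark code base, and the dual-number arithmetic merely stands in for the interpreter's `doJVP2`.

```haskell
-- Hypothetical dual-number model of forward-mode AD (not Futhark code).
-- A Dual carries a primal value and its tangent.
data Dual = Dual Double Double

instance Num Dual where
  Dual a a' + Dual b b' = Dual (a + b) (a' + b')
  Dual a a' * Dual b b' = Dual (a * b) (a * b' + a' * b)
  negate (Dual a a') = Dual (negate a) (negate a')
  abs (Dual a a') = Dual (abs a) (a' * signum a)
  signum (Dual a _) = Dual (signum a) 0
  fromInteger n = Dual (fromInteger n) 0

-- Single-seed Jacobian-vector product: returns (primal, tangent).
jvpModel :: ([Dual] -> Dual) -> [Double] -> [Double] -> (Double, Double)
jvpModel f xs seed =
  let Dual v dv = f (zipWith Dual xs seed)
   in (v, dv)

-- Multi-seed variant, mirroring the shape of the interpreter's "jmp2":
-- evaluate the primal once, then map the single-seed operator over all seeds.
jmp2Model :: ([Dual] -> Dual) -> [Double] -> [[Double]] -> (Double, [Double])
jmp2Model f xs seeds =
  let (v, _) = jvpModel f xs (map (const 0) xs)
   in (v, map (snd . jvpModel f xs) seeds)

-- Same function as 'primal' in tests/ad/arr0.fut: xs[0] * xs[1].
primal :: [Dual] -> Dual
primal [x, y] = x * y
primal _ = error "primal: expected exactly two elements"

main :: IO ()
main = print (jmp2Model primal [5, 7] [[1, 0], [0, 1]])
-- prints (35.0,[7.0,5.0])
```

With the unit seeds `[1, 0]` and `[0, 1]`, the tangent list is the gradient `[7.0, 5.0]`, which agrees with the expected output of the `fwd_vec` entry in `tests/ad/arr0.fut`. The compiled path in this patch does not map over seeds like this; it threads the seed shape through the `JVP`/`VJP` SOACs instead, so the sketch describes only the reference semantics.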
diff --git a/src/Language/Futhark/Prop.hs b/src/Language/Futhark/Prop.hs index 5829651910..eed67d707c 100644 --- a/src/Language/Futhark/Prop.hs +++ b/src/Language/Futhark/Prop.hs @@ -973,6 +973,34 @@ intrinsics = $ Scalar $ tupleRecord [Scalar $ t_b Nonunique, Scalar $ t_a Nonunique] ), + ( "jmp2", + IntrinsicPolyFun + [tp_a, tp_b, sp_n] + [ Scalar (t_a mempty) `arr` Scalar (t_b Nonunique), + Scalar (t_a Observe), + array_a Observe $ shape [n] + ] + $ RetType [] + $ Scalar + $ tupleRecord + [ Scalar $ t_b Nonunique, + array_b Unique $ shape [n] + ] + ), + ( "mjp2", + IntrinsicPolyFun + [tp_a, tp_b, sp_n] + [ Scalar (t_a mempty) `arr` Scalar (t_b Nonunique), + Scalar (t_a Observe), + array_b Observe $ shape [n] + ] + $ RetType [] + $ Scalar + $ tupleRecord + [ Scalar $ t_b Nonunique, + array_a Unique $ shape [n] + ] + ), ( "with_vjp", IntrinsicPolyFun [tp_a, tp_b] diff --git a/tests/ad/arr0.fut b/tests/ad/arr0.fut index 3ae014a68d..d1d09f0cea 100644 --- a/tests/ad/arr0.fut +++ b/tests/ad/arr0.fut @@ -1,22 +1,25 @@ -- == -- tags { autodiff } -def f (xs: [2]f64) = xs[0] * xs[1] +def primal (xs: [2]f64) = xs[0] * xs[1] -- == --- entry: f_jvp +-- entry: fwd fwd_vec -- input { [5.0, 7.0] } --- output { 7.0 5.0 } +-- output { [7.0, 5.0] } + +entry fwd xs = + [ jvp primal xs [1, 0] + , jvp primal xs [0, 1] + ] -entry f_jvp xs = - ( jvp f xs [1, 0] - , jvp f xs [0, 1] - ) +entry fwd_vec xs = + jmp primal xs [[1, 0], [0, 1]] -- == --- entry: f_vjp +-- entry: rev -- input { [5.0, 7.0] } -- output { [7.0, 5.0] } -entry f_vjp xs = - vjp f xs 1 +entry rev xs = + vjp primal xs 1 diff --git a/tests/ad/arr1.fut b/tests/ad/arr1.fut index 17f01a3b8f..9551a9ce83 100644 --- a/tests/ad/arr1.fut +++ b/tests/ad/arr1.fut @@ -1,17 +1,15 @@ -def f (x, y) : [2]f64 = [x + y, x * y] +def primal (x, y) : [2]f64 = [x + y, x * y] -- == -- tags { autodiff } --- entry: f_vjp f_jvp +-- entry: fwd fwd_vec -- input { 5.0 7.0 } --- output { [1.0,7.0] [1.0, 5.0] } +-- output { [[1.0,7.0], [1.0, 5.0]] } 
-entry f_jvp x y = - ( jvp f (x, y) (1, 0) - , jvp f (x, y) (0, 1) - ) +entry fwd x y = + [ jvp primal (x, y) (1, 0) + , jvp primal (x, y) (0, 1) + ] -entry f_vjp x y = - let (dx1, dx2) = vjp f (x, y) [1, 0] - let (dy1, dy2) = vjp f (x, y) [0, 1] - in ([dx1, dy1], [dx2, dy2]) +entry fwd_vec x y = + jmp primal (x, y) [(1, 0), (0, 1)] diff --git a/tests/ad/concat0.fut b/tests/ad/concat0.fut index 6b6cfe0ba2..f9f652edc6 100644 --- a/tests/ad/concat0.fut +++ b/tests/ad/concat0.fut @@ -1,18 +1,19 @@ -- == -- tags { autodiff } +-- entry: fwd_vec fwd_map +-- input { [1.0, 2.0, 3.0] } +-- output { [[1.0, 0.0, 0.0, 1.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 1.0, 0.0, 0.0, 1.0]] } --- == --- entry: f_jvp --- input { [1,2,3] [4,5,6] } --- output { [1,2,3,4,5,6] } +def f (xs: []f64) = xs ++ xs -entry f_jvp xs ys : []i32 = - jvp (uncurry concat) (xs, ys) (xs, ys) +entry fwd_vec (xs: []f64) = + let seeds = + map (\i -> map (\j -> f64.bool (i == j)) (indices xs)) (indices xs) + in (jmp2 f xs seeds).1 --- == --- entry: f_vjp --- input { [1,2,3] [4,5,6] } --- output { [1,2,3] [4,5,6] } +entry fwd_map (xs: []f64) = + map (\i -> jvp f xs (map (\j -> f64.bool (i == j)) (indices xs))) + (indices xs) -entry f_vjp xs ys : ([]i32, []i32) = +entry rev xs ys : ([]i32, []i32) = vjp (uncurry concat) (xs, ys) (concat xs ys) diff --git a/tests/ad/consume0.fut b/tests/ad/consume0.fut index 44eb00a54a..589bc4c916 100644 --- a/tests/ad/consume0.fut +++ b/tests/ad/consume0.fut @@ -1,7 +1,7 @@ -- == -- tags { autodiff } --- entry: rev fwd +-- entry: rev fwd fwd_vec rev_vec -- input { [1.0,2.0,3.0] } -- output { [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] } @@ -15,3 +15,13 @@ entry fwd [n] (xs: *[n]f64) = entry rev [n] (xs: *[n]f64) = #[unsafe] tabulate n (\i -> vjp f xs (replicate n 0 with [i] = 1)) + +entry fwd_vec [n] (xs: *[n]f64) = + #[unsafe] + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in jmp f xs seeds + +entry rev_vec [n] (xs: *[n]f64) = + #[unsafe] + 
let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in mjp f xs seeds diff --git a/tests/ad/consume1.fut b/tests/ad/consume1.fut index 66e250d0f4..8c88d76483 100644 --- a/tests/ad/consume1.fut +++ b/tests/ad/consume1.fut @@ -1,7 +1,7 @@ -- == -- tags { autodiff } --- entry: rev fwd +-- entry: rev fwd fwd_vec rev_vec -- input { true [1.0,2.0,3.0] } -- output { [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] } @@ -16,3 +16,13 @@ entry fwd [n] b (xs: *[n]f64) = entry rev [n] b (xs: *[n]f64) = #[unsafe] tabulate n (\i -> vjp (f b) xs (replicate n 0 with [i] = 1)) + +entry fwd_vec [n] b (xs: *[n]f64) = + #[unsafe] + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in jmp (f b) xs seeds + +entry rev_vec [n] b (xs: *[n]f64) = + #[unsafe] + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in mjp (f b) xs seeds diff --git a/tests/ad/custom/radixsort.fut b/tests/ad/custom/radixsort.fut index f1961e9ae5..ab10dd7727 100644 --- a/tests/ad/custom/radixsort.fut +++ b/tests/ad/custom/radixsort.fut @@ -1,9 +1,15 @@ -- Custom derivative for radix sort. 
-- == +-- tags { autodiff } -- entry: main_standard main_custom -- input { [4f32,3f32,2f32,1f32] [0.1f32,0.2f32,0.3f32,0.4f32] } -- output { [0.4f32, 0.3f32, 0.2f32, 0.1f32 ] } +-- == +-- entry: main_custom_vec +-- input { [4f32,3f32,2f32,1f32] [[0.1f32,0.2f32,0.3f32,0.4f32],[0.5f32,0.6f32,0.7f32,0.8f32]] } +-- output { [[0.4f32, 0.3f32, 0.2f32, 0.1f32], [0.8f32, 0.7f32, 0.6f32, 0.5f32]] } + def radix_sort_step [n] 't (f: t -> u32) (xs: [n]t) (b: i32) : [n]t = let bits = map (\x -> (i32.u32 (f x >> u32.i32 b)) & 1) xs let bits_neg = map (1 -) bits @@ -20,10 +26,11 @@ def radix_sort [n] 't (f: t -> u32) (xs: [n]t) : [n]t = def differentiable_radix_sort [n] 't (f: t -> u32) (xs: [n]t) = (with_vjp (\xs -> - unzip (radix_sort (f <-< (.0)) (zip xs (iota n)))) - (\(_, perm) (xs_adj, _) -> - scatter (copy xs_adj) perm xs_adj) - xs).0 + unzip (radix_sort (f <-< (.0)) (zip xs (iota n)))) + (\(_, perm) (xs_adj, _) -> + scatter (copy xs_adj) perm xs_adj) + xs).0 entry main_standard = vjp (radix_sort f32.to_bits) entry main_custom = vjp (differentiable_radix_sort f32.to_bits) +entry main_custom_vec = mjp (differentiable_radix_sort f32.to_bits) diff --git a/tests/ad/custom/toplevel.fut b/tests/ad/custom/toplevel.fut index 368e8bcb10..3efd612616 100644 --- a/tests/ad/custom/toplevel.fut +++ b/tests/ad/custom/toplevel.fut @@ -1,6 +1,7 @@ -- | If a custom derivative occurs at top level, just get rid of it. 
-- == +-- tags { autodiff } -- entry: do_primal -- input { 2.5 } output { 5.0 } @@ -10,8 +11,8 @@ def primal (x: f64) = with_vjp (\x -> x * 2) - (\c x_adj -> x_adj + f64.sqrt c) - x + (\c x_adj -> x_adj + f64.sqrt c) + x entry do_vjp x = vjp primal x 1 entry do_primal x = primal x diff --git a/tests/ad/for1.fut b/tests/ad/for1.fut index 799fe893ff..f356cd02c9 100644 --- a/tests/ad/for1.fut +++ b/tests/ad/for1.fut @@ -12,7 +12,7 @@ def pow_list [n] y (xs: [n]i32) = entry prim y xs = pow_list y xs -- == --- entry: f_vjp f_jvp +-- entry: f_vjp f_jvp f_mjp f_jmp -- input { 3 [1,2,3] } -- output { [[3,0,0], -- [0,12,0], @@ -23,3 +23,12 @@ entry f_jvp [n] y (xs: [n]i32) = entry f_vjp [n] y (xs: [n]i32) = tabulate n (\i -> vjp (pow_list y) xs (replicate n 0 with [i] = 1)) + +entry f_jmp [n] y (xs: [n]i32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in jmp (pow_list y) xs seeds + |> transpose + +entry f_mjp [n] y (xs: [n]i32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in mjp (pow_list y) xs seeds diff --git a/tests/ad/for2.fut b/tests/ad/for2.fut index 474fbaae88..bb111cbf4b 100644 --- a/tests/ad/for2.fut +++ b/tests/ad/for2.fut @@ -12,9 +12,16 @@ def mult_list xs = entry prim = mult_list -- == --- entry: f_jvp f_vjp +-- entry: f_jvp f_vjp f_jmp f_mjp -- input { [11,5,13] } output { [0,0,26] } entry f_jvp [n] (xs: [n]i32) = tabulate n (\i -> jvp mult_list xs (replicate n 0 with [i] = 1)) entry f_vjp [n] (xs: [n]i32) = vjp mult_list xs 1 + +entry f_jmp [n] (xs: [n]i32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in jmp mult_list xs seeds + +entry f_mjp [n] (xs: [n]i32) = + (mjp mult_list xs [1])[0] diff --git a/tests/ad/for3.fut b/tests/ad/for3.fut index c7d02db01c..42350aebbd 100644 --- a/tests/ad/for3.fut +++ b/tests/ad/for3.fut @@ -14,7 +14,7 @@ def square [n] (xs: [n]i32) = entry prim [n] (xs: [n]i32) = square xs -- == --- entry: f_jvp f_vjp +-- entry: f_jvp f_vjp f_jmp f_mjp -- input { [1,2,3,4,5] } -- output 
{ [[2,0,0,0,0], -- [0,4,0,0,0], @@ -27,3 +27,12 @@ entry f_jvp [n] (xs: [n]i32) = entry f_vjp [n] (xs: [n]i32) = tabulate n (\i -> vjp square xs (replicate n 0 with [i] = 1)) + +entry f_jmp [n] (xs: [n]i32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in jmp square xs seeds + |> transpose + +entry f_mjp [n] (xs: [n]i32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in mjp square xs seeds diff --git a/tests/ad/fwd/acc0.fut b/tests/ad/fwd/acc0.fut index 1910c46e28..ee21d591a5 100644 --- a/tests/ad/fwd/acc0.fut +++ b/tests/ad/fwd/acc0.fut @@ -5,20 +5,18 @@ import "../../accs/intrinsics" def f (acc: *acc ([]i32)) i = write acc i (i32.i64 i) --- square entries - -- == -- entry: prim -- input { [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] } --- output { [0, 1, 4, 9, 16, 25, 36, 49, 64, 81] } +-- output { [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] } entry prim [n] (xs: [n]i32) = let (xs': *[n]i32) = copy xs - in reduce_by_index_stream xs' (*) 1 f (map i64.i32 (xs :> [n]i32)) + in reduce_by_index_stream xs' (+) 0 f (map i64.i32 (xs :> [n]i32)) -- == -- entry: f_jvp -- input { [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] } --- output { [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] } +-- output { [2, 2, 2, 2, 2, 2, 2, 2, 2, 2] } entry f_jvp (xs: *[]i32) = jvp prim xs (replicate 10 1) diff --git a/tests/ad/gather0.fut b/tests/ad/gather0.fut index ebc7c9905a..76d4b9727c 100644 --- a/tests/ad/gather0.fut +++ b/tests/ad/gather0.fut @@ -1,6 +1,6 @@ -- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_J rev_J fwd_vec_J rev_vec_J -- input { [4.0,3.0,2.0,1.0] [0i64,1i64,2i64,3i64] } -- output { [[1.0, 0.0, 0.0, 0.0], -- [0.0, 1.0, 0.0, 0.0], @@ -21,3 +21,11 @@ entry fwd_J [n] [m] (xs: [n]f64) (is: [m]i64) = entry rev_J [n] [m] (xs: [n]f64) (is: [m]i64) = tabulate m (\j -> vjp (`gather` is) xs (replicate m 0 with [j] = 1)) + +entry fwd_vec_J [n] [m] (xs: [n]f64) (is: [m]i64) = + let seeds = tabulate n (\j -> replicate n 0 with [j] = 1) + in transpose (jmp (`gather` is) xs seeds) 
+ +entry rev_vec_J [n] [m] (xs: [n]f64) (is: [m]i64) = + let seeds = tabulate m (\j -> replicate m 0 with [j] = 1) + in mjp (`gather` is) xs seeds diff --git a/tests/ad/gather1.fut b/tests/ad/gather1.fut index bb9863fed8..3ef72e17d1 100644 --- a/tests/ad/gather1.fut +++ b/tests/ad/gather1.fut @@ -1,6 +1,6 @@ -- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_J rev_J fwd_vec_J rev_vec_J -- input -- { -- [[1.0,2.0],[3.0,4.0]] [1i64, 0i64, 1i64, 1i64] @@ -43,3 +43,16 @@ entry fwd_J [n] [m] [k] (xs: [n][m]f64) (is: [k]i64) = entry rev_J [n] [m] [k] (xs: [n][m]f64) (is: [k]i64) = tabulate_2d n k (\i j -> vjp (`mapgather` is) xs (onehot_2d n k (i, j))) + +entry fwd_vec_J [n] [m] [k] (xs: [n][m]f64) (is: [k]i64) = + let seeds = tabulate (n * m) (\p -> onehot_2d n m (p / m, p % m)) + in jmp (`mapgather` is) xs seeds + |> unflatten + |> map transpose + |> map (map transpose) + |> map transpose + +entry rev_vec_J [n] [m] [k] (xs: [n][m]f64) (is: [k]i64) = + let seeds = tabulate (n * k) (\p -> onehot_2d n k (p / k, p % k)) + in mjp (`mapgather` is) xs seeds + |> unflatten diff --git a/tests/ad/gather2.fut b/tests/ad/gather2.fut index 3bc3efcc24..3920477bc8 100644 --- a/tests/ad/gather2.fut +++ b/tests/ad/gather2.fut @@ -1,6 +1,6 @@ -- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_J rev_J fwd_vec_J rev_vec_J -- input -- { -- [1.0,2.0,3.0,4.0] @@ -32,3 +32,14 @@ def onehot_2d n m p : [n][m]f64 = entry rev_J [k] [n] [m] (xs: [k]f64) (iss: [n][m]i64) = tabulate_2d n m (\i j -> vjp (`mapgather` iss) xs (onehot_2d n m (i, j))) + +entry fwd_vec_J [k] [n] [m] (xs: [k]f64) (iss: [n][m]i64) = + let seeds = tabulate k (\i -> onehot k i) + in jmp (`mapgather` iss) xs seeds + |> transpose + |> map transpose + +entry rev_vec_J [k] [n] [m] (xs: [k]f64) (iss: [n][m]i64) = + let seeds = tabulate (n * m) (\p -> onehot_2d n m (p / m, p % m)) + in mjp (`mapgather` iss) xs seeds + |> unflatten diff --git a/tests/ad/hist_add.fut b/tests/ad/hist_add.fut new file mode 
100644 index 0000000000..26d275f320 --- /dev/null +++ b/tests/ad/hist_add.fut @@ -0,0 +1,55 @@ +-- Addition +-- == +-- tags { autodiff } +-- entry: fwd_map fwd_vec rev_map rev_vec +-- input { +-- [4i64,5i64,2i64,4i64,5i64,2i64,0i64,0i64,4i64,5i64,1i64,4i64,1i64,3i64,3i64,1i64,8i64,-1i64] +-- [11f32,16f32,7f32,0f32,14f32,6f32,2f32,1f32,13f32,0f32,5f32,0f32,3f32,0f32,9f32,4f32,17f32,18f32] +-- [1f32,2f32,3f32,4f32,5f32,6f32,7f32,8f32] +-- } +-- output { +-- [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 1f32, 1f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 2f32, 0f32, 2f32, 0f32, 0f32, 2f32, 0f32, 0f32], +-- [0f32, 0f32, 3f32, 0f32, 0f32, 3f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 4f32, 4f32, 0f32, 0f32, 0f32], +-- [5f32, 0f32, 0f32, 5f32, 0f32, 0f32, 0f32, 0f32, 5f32, 0f32, 0f32, 5f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 6f32, 0f32, 0f32, 6f32, 0f32, 0f32, 0f32, 0f32, 6f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]] +-- [[3f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 12f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 13f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 9f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 24f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 30f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]] +-- } + +def f [n] [m] (is: [n]i64) (vs: [n]f32, c: [m]f32) = + let r = hist (+) 0 m is vs + in map2 (*) r c + +entry fwd_map [n] [m] (is: [n]i64) (vs: [n]f32) (c: [m]f32) = + tabulate m (\i -> 
vjp (f is) (vs, c) (replicate m 0 with [i] = 1)) + |> unzip + +entry fwd_vec [n] [m] (is: [n]i64) (vs: [n]f32) (c: [m]f32) = + let seeds = + tabulate (n + m) (\i -> + ( tabulate n ((i ==) >-> f32.bool) + , tabulate m (((i - n) ==) >-> f32.bool) + )) + in jmp (f is) (vs, c) seeds + |> transpose + |> map split + |> unzip + +entry rev_map [n] [m] (is: [n]i64) (vs: [n]f32) (c: [m]f32) = + tabulate m (\i -> vjp (f is) (vs, c) (replicate m 0 with [i] = 1)) + |> unzip + +entry rev_vec [n] [m] (is: [n]i64) (vs: [n]f32) (c: [m]f32) = + let seeds = tabulate m (\i -> replicate m 0 with [i] = 1) + in mjp (f is) (vs, c) seeds + |> unzip diff --git a/tests/ad/hist_complex.fut b/tests/ad/hist_complex.fut new file mode 100644 index 0000000000..0f3042c90a --- /dev/null +++ b/tests/ad/hist_complex.fut @@ -0,0 +1,39 @@ +-- == +-- tags { autodiff } +-- entry: rev_map rev_vec +-- input { +-- [4i64,5i64,2i64,4i64,5i64,2i64,0i64,0i64,4i64,5i64,1i64,4i64,1i64,3i64,3i64,1i64,8i64,-1i64] +-- [11f32,16f32,7f32,0f32,14f32,6f32,2f32,1f32,13f32,0f32,5f32,0f32,3f32,0f32,9f32,4f32,17f32,18f32] +-- [1f32,2f32,3f32,4f32,5f32,6f32,7f32,8f32] +-- } +-- output { +-- [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 2f32, 4f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 96f32, 0f32, 160f32, 0f32, 0f32, 120f32, 0f32, 0f32], +-- [0f32, 0f32, 36f32, 0f32, 0f32, 42f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 72f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 5376f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 
0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]] +-- [[4f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 240f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 84f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0.5, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0.5]] +-- } + +def f [n] [m] (is: [n]i64) (vs: [n]f32, c: [m]f32) = + let r = hist (\x y -> x * y * 2) 0.5 m is vs + in map2 (*) r c + +entry rev_map [n] [m] (is: [n]i64) (vs: [n]f32) (c: [m]f32) = + tabulate m (\i -> vjp (f is) (vs, c) (replicate m 0 with [i] = 1)) + |> unzip + +entry rev_vec [n] [m] (is: [n]i64) (vs: [n]f32) (c: [m]f32) = + let seeds = tabulate m (\i -> replicate m 0 with [i] = 1) + in mjp (f is) (vs, c) seeds + |> unzip diff --git a/tests/ad/hist_minmax.fut b/tests/ad/hist_minmax.fut new file mode 100644 index 0000000000..a21ecff07f --- /dev/null +++ b/tests/ad/hist_minmax.fut @@ -0,0 +1,35 @@ +-- Maximum +-- == +-- tags { autodiff } +-- entry: fwd_map fwd_vec rev_map rev_vec +-- input { +-- 5i64 +-- [4i64,5i64,2i64,4i64,5i64,2i64,0i64,0i64,4i64,5i64,1i64,4i64,1i64,3i64,3i64,1i64,8i64,-1i64] +-- [11f32,16f32,7f32,0f32,14f32,6f32,2f32,1f32,13f32,0f32,5f32,0f32,3f32,0f32,9f32,4f32,17f32,18f32] +-- } +-- output { +-- [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 1f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 1f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 1f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 1f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 1f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 
0f32, 0f32, 0f32]] +-- } + +def primal [n] (k: i64) (is: [n]i64) (vs: [n]f32) = + hist f32.max f32.lowest k is vs + +entry fwd_map [n] (k: i64) (is: [n]i64) (vs: [n]f32) = + tabulate n (\i -> jvp (primal k is) vs (replicate n 0 with [i] = 1)) + |> transpose + +entry fwd_vec [n] (k: i64) (is: [n]i64) (vs: [n]f32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in jmp (primal k is) vs seeds + |> transpose + +entry rev_map [n] (k: i64) (is: [n]i64) (vs: [n]f32) = + tabulate k (\i -> vjp (primal k is) vs (replicate k 0 with [i] = 1)) + +entry rev_vec [n] (k: i64) (is: [n]i64) (vs: [n]f32) = + let seeds = tabulate k (\i -> replicate k 0 with [i] = 1) + in mjp (primal k is) vs seeds diff --git a/tests/ad/hist_mul.fut b/tests/ad/hist_mul.fut new file mode 100644 index 0000000000..878aaaab47 --- /dev/null +++ b/tests/ad/hist_mul.fut @@ -0,0 +1,40 @@ +-- Multiplication +-- == +-- tags { autodiff } +-- entry: rev_map rev_vec +-- input { +-- [4i64,5i64,2i64,4i64,5i64,2i64,0i64,0i64,4i64,5i64,1i64,4i64,1i64,3i64,3i64,1i64,8i64,-1i64] +-- [11f32,16f32,7f32,0f32,14f32,6f32,2f32,1f32,13f32,0f32,5f32,0f32,3f32,0f32,9f32,4f32,17f32,18f32] +-- [1f32,2f32,3f32,4f32,5f32,6f32,7f32,8f32] +-- } +-- output { +-- [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 1f32, 2f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 24f32, 0f32, 40f32, 0f32, 0f32, 30f32, 0f32, 0f32], +-- [0f32, 0f32, 18f32, 0f32, 0f32, 21f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 36f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 1344f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 
0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]] +-- [[2f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 60f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 42f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 1f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 1f32]] +-- } + +def f [n] [m] (is: [n]i64) (vs: [n]f32, c: [m]f32) = + let r = hist (*) 1 m is vs + in map2 (*) r c + +entry rev_map [n] [m] (is: [n]i64) (vs: [n]f32) (c: [m]f32) = + tabulate m (\i -> vjp (f is) (vs, c) (replicate m 0 with [i] = 1)) + |> unzip + +entry rev_vec [n] [m] (is: [n]i64) (vs: [n]f32) (c: [m]f32) = + let seeds = tabulate m (\i -> replicate m 0 with [i] = 1) + in mjp (f is) (vs, c) seeds + |> unzip diff --git a/tests/ad/index.fut b/tests/ad/index.fut new file mode 100644 index 0000000000..677160d33b --- /dev/null +++ b/tests/ad/index.fut @@ -0,0 +1,16 @@ +-- == +-- tags { autodiff } +-- entry: fwd_vec fwd_map +-- input { 0i32 [1f32, 2f32, 3f32] } +-- output { [1f32, 0f32, 0f32] } + +def f (i: i32) (xs: []f32) = xs[i] + +entry fwd_vec l (xs: []f32) : []f32 = + let seeds = + map (\i -> map (\j -> f32.bool (i == j)) (indices xs)) (indices xs) + in (jmp2 (f l) xs seeds).1 + +entry fwd_map l (xs: []f32) : []f32 = + map (\i -> jvp (f l) xs (map (\j -> f32.bool (i == j)) (indices xs))) + (indices xs) diff --git a/tests/ad/issue2256.fut b/tests/ad/issue2256.fut index 09fe505daa..f3d78219f0 100644 --- a/tests/ad/issue2256.fut +++ b/tests/ad/issue2256.fut @@ -1,6 +1,6 @@ -- == -- tags { autodiff } --- entry: fwd rev +-- entry: fwd rev fwd_vec -- input { [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] } -- output { [5040.0, 2521.0, 1684.0, 1278.0, 1104.0, 1440.0] } @@ -13,3 +13,7 @@ entry rev [m] (x: 
[m]f64) = entry fwd [m] (x: [m]f64) = tabulate m (\i -> jvp (\x' -> primal x') x (replicate m 0 with [i] = 1)) + +entry fwd_vec [m] (x: [m]f64) = + let seeds = tabulate m (\i -> replicate m 0 with [i] = 1) + in jmp (\x' -> primal x') x seeds diff --git a/tests/ad/map0.fut b/tests/ad/map0.fut index a531f6ec11..672c60ddf1 100644 --- a/tests/ad/map0.fut +++ b/tests/ad/map0.fut @@ -1,7 +1,21 @@ -- == -- tags { autodiff } --- entry: rev --- input { [1,2,3] [3,2,1] } --- output { [6,4,2] } +-- entry: fwd_map fwd_vec rev_map rev_vec +-- input { [1.0f32, 2.0, 3.0] [4f32, 5, 6] } +-- output { [[1.0f32, 0.0, 0.0], [0.0f32, 2.0, 0.0], [0.0f32, 0.0, 3.0]] } -entry rev = vjp (map (* 2i32)) +def prim = map2 (f32.*) + +entry fwd_map [n] (xs: [n]f32) (ys: [n]f32) = + tabulate n (\i -> jvp (prim xs) ys (replicate n 0 with [i] = 1)) + +entry fwd_vec [n] (xs: [n]f32) (ys: [n]f32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in jmp (prim xs) ys seeds + +entry rev_map [n] (xs: [n]f32) (ys: [n]f32) = + transpose (tabulate n (\i -> vjp (prim xs) ys (replicate n 0 with [i] = 1))) + +entry rev_vec [n] (xs: [n]f32) (ys: [n]f32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in transpose (mjp (prim xs) ys seeds) diff --git a/tests/ad/map1.fut b/tests/ad/map1.fut index efc35a9ad9..56467b1696 100644 --- a/tests/ad/map1.fut +++ b/tests/ad/map1.fut @@ -1,9 +1,18 @@ --- +-- Like map0, but we do not compute the full Jacobian, so the vector size is not +-- the same as the input size. 
-- == -- tags { autodiff } --- entry: rev --- input { [[1.0,2.0,3.0,4.0],[1.0,2.0,3.0,4.0]] [1.0,2.0] } --- output {[[24.0, 12.0, 8.0, 6.0], --- [48.0, 24.0, 16.0, 12.0]] } +-- entry: fwd fwd_vec +-- input { [1.0f32, 2.0, 3.0] [4f32, 5, 6] } +-- output { [[5.0f32, 0.0, 0.0], [0.0f32, 7.0, 0.0]] } -entry rev = vjp (map f64.product) +def prim = map2 (f32.*) + +def k = 2i64 + +entry fwd [n] (xs: [n]f32) (ys: [n]f32) = + tabulate k (\i -> jvp (uncurry prim) (xs, ys) (replicate n 0 with [i] = 1, replicate n 0 with [i] = 1)) + +entry fwd_vec [n] (xs: [n]f32) (ys: [n]f32) = + let seeds = tabulate k (\i -> (replicate n 0 with [i] = 1, replicate n 0 with [i] = 1)) + in jmp (uncurry prim) (xs, ys) seeds diff --git a/tests/ad/map2.fut b/tests/ad/map2.fut index 6970cce527..59d35c2dab 100644 --- a/tests/ad/map2.fut +++ b/tests/ad/map2.fut @@ -1,15 +1,21 @@ -- Map with free variable. -- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_map rev_map rev_vec -- input { 2.0 [1.0,2.0,3.0] } -- output { [1.0,2.0,3.0] } +def primal xs (c': f64) = map (* c') xs + def onehot n i : [n]f64 = tabulate n (\j -> f64.bool (i == j)) -entry fwd_J [n] (c: f64) (xs: [n]f64) = - jvp (\c' -> map (* c') xs) c 1 +entry fwd_map [n] (c: f64) (xs: [n]f64) = + jvp (primal xs) c 1 + +entry rev_map [n] (c: f64) (xs: [n]f64) = + tabulate n (\i -> vjp (primal xs) c (onehot n i)) -entry rev_J [n] (c: f64) (xs: [n]f64) = - tabulate n (\i -> vjp (\c' -> map (* c') xs) c (onehot n i)) +entry rev_vec [n] (c: f64) (xs: [n]f64) = + let seeds = tabulate n (\i -> onehot n i) + in mjp (primal xs) c seeds diff --git a/tests/ad/map3.fut b/tests/ad/map3.fut index 5b817ac57d..12caf5ac51 100644 --- a/tests/ad/map3.fut +++ b/tests/ad/map3.fut @@ -1,12 +1,18 @@ -- == -- tags { autodiff } --- entry: fwd rev +-- entry: fwd rev_map rev_vec -- input { 1i32 [1i32,2i32,3i32] } -- output { [1i32,2i32,3i32] } +def primal xs (x: i32) = map (* x) xs + entry fwd [n] (x: i32) (xs: [n]i32) = - jvp (\x -> map (* x) xs) x 1 + 
jvp (primal xs) x 1 -entry rev [n] (x: i32) (xs: [n]i32) = +entry rev_map [n] (x: i32) (xs: [n]i32) = tabulate n (\i -> - vjp (\x -> map (* x) xs) x (replicate n 0 with [i] = 1)) + vjp (primal xs) x (replicate n 0 with [i] = 1)) + +entry rev_vec [n] (x: i32) (xs: [n]i32) = + let seeds = tabulate n (\i -> (replicate n 0 with [i] = 1)) + in mjp (primal xs) x seeds diff --git a/tests/ad/map4.fut b/tests/ad/map4.fut index aabbc4f02b..b7b00e250a 100644 --- a/tests/ad/map4.fut +++ b/tests/ad/map4.fut @@ -1,13 +1,13 @@ -- An array is both a 'map' input and a free variable in the lambda. -- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_map fwd_vec rev_map rev_vec -- input { [1,2,3] } -- output { -- [[[2, 0, 0], [1, 1, 0], [1, 0, 1]], [[1, 1, 0], [0, 2, 0], [0, 1, 1]], [[1, 0, 1], [0, 1, 1], [0, 0, 2]]] -- } -def f (xs: []i32) = +def primal (xs: []i32) = map (\x -> map (+ x) xs) xs def onehot n i : [n]i32 = @@ -16,11 +16,22 @@ def onehot n i : [n]i32 = def onehot_2d n m p : [n][m]i32 = tabulate_2d n m (\i j -> i32.bool ((i, j) == p)) -entry fwd_J [n] (xs: [n]i32) = - tabulate n (\i -> jvp f xs (onehot n i)) +entry fwd_map [n] (xs: [n]i32) = + tabulate n (\i -> jvp primal xs (onehot n i)) |> map transpose |> transpose |> map transpose -entry rev_J [n] (xs: [n]i32) = - tabulate_2d n n (\i j -> vjp f xs (onehot_2d n n (i, j))) +entry fwd_vec [n] (xs: [n]i32) = + let seeds = tabulate n (\i -> onehot n i) + in jmp primal xs seeds + |> map transpose + |> transpose + |> map transpose + +entry rev_map [n] (xs: [n]i32) = + tabulate_2d n n (\i j -> vjp primal xs (onehot_2d n n (i, j))) + +entry rev_vec [n] (xs: [n]i32) = + let seeds = tabulate_2d n n (\i j -> onehot_2d n n (i, j)) + in unflatten (mjp primal xs (flatten seeds)) diff --git a/tests/ad/map5.fut b/tests/ad/map5.fut index 15e6e8a28c..3e64552ee9 100644 --- a/tests/ad/map5.fut +++ b/tests/ad/map5.fut @@ -1,7 +1,7 @@ -- Map with free array variable. 
-- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_J rev_J fwd_vec_J rev_vec_J -- input { [[1,2,3],[4,5,6]] [0,0] } -- output { [[1, 0], [0, 1]] } @@ -16,3 +16,11 @@ entry fwd_J [n] [m] (free: [n][m]i32) (is: [n]i32) = entry rev_J [n] [m] (free: [n][m]i32) (is: [n]i32) = tabulate n (\i -> vjp (f free) is (onehot n i)) + +entry fwd_vec_J [n] [m] (free: [n][m]i32) (is: [n]i32) = + let seeds = tabulate n (\i -> onehot n i) + in jmp (f free) is seeds |> transpose + +entry rev_vec_J [n] [m] (free: [n][m]i32) (is: [n]i32) = + let seeds = tabulate n (\i -> onehot n i) + in mjp (f free) is seeds diff --git a/tests/ad/map6.fut b/tests/ad/map6.fut index 8bf30b853e..2748fe972d 100644 --- a/tests/ad/map6.fut +++ b/tests/ad/map6.fut @@ -1,7 +1,7 @@ -- #1878 -- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_J rev_J fwd_vec_J rev_vec_J -- input { [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0] } -- output { [[0.0, 2.0, 3.0, 4.0], -- [0.0, 0.0, 1.0, 1.0], @@ -29,3 +29,11 @@ entry fwd_J (x: [8]f64) = entry rev_J (x: [8]f64) = transpose (tabulate 4 (\i -> vjp obj x (replicate 4 0 with [i] = 1))) + +entry fwd_vec_J (x: [8]f64) = + let seeds = tabulate 8 (\i -> replicate 8 0 with [i] = 1) + in jmp obj x seeds + +entry rev_vec_J (x: [8]f64) = + let seeds = tabulate 4 (\i -> replicate 4 0 with [i] = 1) + in transpose (mjp obj x seeds) diff --git a/tests/ad/map7.fut b/tests/ad/map7.fut index 6c9cf348ba..ed642f0ca2 100644 --- a/tests/ad/map7.fut +++ b/tests/ad/map7.fut @@ -2,9 +2,17 @@ -- has active free variables. 
-- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_map fwd_vec rev_map rev_vec -- input { [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0] } --- output { [0.0, 0.0, 0.0, 0.0, -4.0, 0.0, 0.0, 0.0] } +-- output { [[0.0, 2.0, 3.0, 4.0], +-- [0.0, 0.0, 1.0, 1.0], +-- [0.0, 0.0, 0.0, 1.0], +-- [0.0, 0.0, 0.0, 0.0], +-- [-4.0, -6.0, -7.0, -8.0], +-- [0.0, 0.0, -1.0, -1.0], +-- [0.0, 0.0, 0.0, -1.0], +-- [0.0, 0.0, 0.0, 0.0]] +-- } def obj (x: [8]f64) = #[unsafe] @@ -15,10 +23,16 @@ def obj (x: [8]f64) = map (map f64.sum) col_w_pre_red let col_eq: [4]f64 = map (\w -> w[0] - w[1]) col_w_red - in f64.maximum col_eq + in col_eq -entry fwd_J (x: [8]f64) = +entry fwd_map (x: [8]f64) = tabulate 8 (\i -> jvp obj x (replicate 8 0 with [i] = 1)) -entry rev_J (x: [8]f64) = - vjp obj x 1 +entry fwd_vec (x: [8]f64) = + jmp obj x (tabulate 8 (\i -> replicate 8 0 with [i] = 1)) + +entry rev_map (x: [8]f64) = + transpose (tabulate 4 (\i -> vjp obj x (replicate 4 0 with [i] = 1))) + +entry rev_vec (x: [8]f64) = + transpose (#[unroll] mjp obj x (tabulate 4 (\i -> replicate 4 0 with [i] = 1))) diff --git a/tests/ad/matmul.fut b/tests/ad/matmul.fut index 123eaa41b4..676fbf6dcc 100644 --- a/tests/ad/matmul.fut +++ b/tests/ad/matmul.fut @@ -1,6 +1,6 @@ -- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_J rev_J fwd_vec_J rev_vec_J -- input -- { -- [[1.0,2.0],[3.0,4.0]] [[5.0,6.0],[7.0,8.0]] @@ -32,3 +32,16 @@ entry fwd_J [n] [m] [p] (xss: [n][m]f64) (yss: [m][p]f64) = entry rev_J [n] [m] [p] (xss: [n][m]f64) (yss: [m][p]f64) = tabulate_2d n p (\i j -> vjp (matmul xss) yss (onehot_2d n p (i, j))) + +entry fwd_vec_J [n] [m] [p] (xss: [n][m]f64) (yss: [m][p]f64) = + let seeds = tabulate (m * p) (\k -> onehot_2d m p (k / p, k % p)) + in jmp (matmul xss) yss seeds + |> unflatten + |> transpose + |> map transpose + |> transpose + +entry rev_vec_J [n] [m] [p] (xss: [n][m]f64) (yss: [m][p]f64) = + let seeds = tabulate (n * p) (\k -> onehot_2d n p (k / p, k % p)) + in mjp (matmul xss) yss 
seeds + |> unflatten diff --git a/tests/ad/maximum.fut b/tests/ad/maximum.fut index f605182d5b..1d1389f63d 100644 --- a/tests/ad/maximum.fut +++ b/tests/ad/maximum.fut @@ -1,11 +1,11 @@ -- == -- tags { autodiff } --- entry: rev fwd +-- entry: rev fwd fwd_vec -- input { [1.0, 2.0, 3.0, 4.0, 5.0, -6.0, 5.0] } -- output { [0.0, 0.0, 0.0, 0.0, 0.0, -1.0, 0.0] } -- input { [1.0, 1.0] } -- output { [1.0, 0.0] } --- structure { /Screma 2 } +-- structure { /Screma 3 } def f = map f64.abs >-> f64.maximum @@ -14,3 +14,7 @@ entry rev [n] (xs: [n]f64) = entry fwd [n] (xs: [n]f64) = tabulate n (\i -> jvp f xs (tabulate n ((== i) >-> f64.bool))) + +entry fwd_vec [n] (xs: [n]f64) = + let seeds = tabulate n (\i -> tabulate n ((== i) >-> f64.bool)) + in jmp f xs seeds diff --git a/tests/ad/minimum.fut b/tests/ad/minimum.fut index f8942cf184..a99c01b850 100644 --- a/tests/ad/minimum.fut +++ b/tests/ad/minimum.fut @@ -1,6 +1,6 @@ -- == -- tags { autodiff } --- entry: rev fwd +-- entry: rev fwd fwd_vec -- input { [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 5.0] } -- output { [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] } -- input { [1.0, 1.0] } @@ -11,3 +11,7 @@ entry rev [n] (xs: [n]f64) = entry fwd [n] (xs: [n]f64) = tabulate n (\i -> jvp f64.minimum xs (tabulate n ((== i) >-> f64.bool))) + +entry fwd_vec [n] (xs: [n]f64) = + let seeds = tabulate n (\i -> tabulate n ((== i) >-> f64.bool)) + in jmp f64.minimum xs seeds diff --git a/tests/ad/minmax.fut b/tests/ad/minmax.fut index e71b69734e..54ae27dbe5 100644 --- a/tests/ad/minmax.fut +++ b/tests/ad/minmax.fut @@ -1,11 +1,11 @@ -- == -- tags { autodiff } --- entry: rev fwd +-- entry: rev fwd fwd_vec -- input { [1.0, 2.0, 3.0, 4.0, 5.0, -6.0, 5.0] } -- output { [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] -- [0.0, 0.0, 0.0, 0.0, 0.0, -1.0, 0.0] -- } --- structure { /Screma 2 } +-- structure { /Screma 3 } def f xs = let ys = map f64.abs xs @@ -18,3 +18,7 @@ entry rev [n] (xs: [n]f64) = entry fwd [n] (xs: [n]f64) = unzip (tabulate n (\i -> jvp f xs (tabulate n ((== i) 
>-> f64.bool)))) + +entry fwd_vec [n] (xs: [n]f64) = + let seeds = tabulate n (\i -> tabulate n ((== i) >-> f64.bool)) + in unzip (jmp f xs seeds) diff --git a/tests/ad/reduce-vec-minmax0.fut b/tests/ad/reduce-vec-minmax0.fut index d846d4dec5..336e9c7037 100644 --- a/tests/ad/reduce-vec-minmax0.fut +++ b/tests/ad/reduce-vec-minmax0.fut @@ -14,8 +14,24 @@ def forward [n] [m] (arr: [m][n]f32) : [n][m][n]f32 = def reverse [n] [m] (arr: [m][n]f32) : [n][m][n]f32 = tabulate n (\i -> vjp redmap arr (replicate n 0 with [i] = 1)) +def forward_vec [n] [m] (arr: [m][n]f32) : [n][m][n]f32 = + let seeds = tabulate (m * n) (\p -> + let i = p / n + let j = p % n + in replicate m (replicate n 0) with [i] = (replicate n 0 with [j] = 1)) + in jmp redmap arr seeds + |> unflatten + |> transpose + +def reverse_vec [n] [m] (arr: [m][n]f32) : [n][m][n]f32 = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in mjp redmap arr seeds + def main [n] [m] (arr: [m][n]f32) : bool = let l = n * m * n let fs = forward arr |> flatten_3d :> [l]f32 let rs = reverse arr |> flatten_3d :> [l]f32 - in reduce (&&) true (map2 (\i j -> f32.abs (i - j) < 0.0001f32) fs rs) + let fvs = forward_vec arr |> flatten_3d :> [l]f32 + let rvs = reverse_vec arr |> flatten_3d :> [l]f32 + let close xs ys = reduce (&&) true (map2 (\i j -> f32.abs (i - j) < 0.0001f32) xs ys) + in close fs rs && close fs fvs && close rs rvs diff --git a/tests/ad/reduce0.fut b/tests/ad/reduce0.fut index a0dc3b7e04..7f57f1f992 100644 --- a/tests/ad/reduce0.fut +++ b/tests/ad/reduce0.fut @@ -1,11 +1,22 @@ --- Simple reduce with multiplication -- == -- tags { autodiff } --- entry: rev --- input { [1.0f32, 2.0f32, 3.0f32, 4.0f32] 1.0f32 } output { [24.0f32, 12.0f32, 8.0f32, 6.0f32] 24.0f32 } +-- entry: fwd_vec fwd_map rev_vec +-- input { [1f32, 2f32, 3f32] } +-- output { [6f32, 3f32, 2f32] } -def red_mult [n] (xs: [n]f32, c: f32) : f32 = - reduce (*) 1 xs * c +def f (xs: []f32) = f32.product xs -entry rev [n] (xs: [n]f32) (c: f32) = 
- vjp red_mult (xs, c) 1 +entry fwd_vec (xs: []f32) : []f32 = + let seeds = + map (\i -> map (\j -> f32.bool (i == j)) (indices xs)) (indices xs) + in (jmp2 f xs seeds).1 + +entry fwd_map (xs: []f32) : []f32 = + map (\i -> jvp f xs (map (\j -> f32.bool (i == j)) (indices xs))) + (indices xs) + +-- No rev_map because it would just get optimised away. The rev_vec is pointless +-- enough already. + +entry rev_vec (xs: []f32) : []f32 = + head (mjp f xs [1]) diff --git a/tests/ad/reduce1.fut b/tests/ad/reduce1.fut index 0623eb8aac..c43b35def0 100644 --- a/tests/ad/reduce1.fut +++ b/tests/ad/reduce1.fut @@ -1,24 +1,71 @@ --- Reduce with a fancier operator. +-- Reduce with 2x2 matrix multiplication. -- == -- tags { autodiff } --- entry: rev --- input { [1.0,2.0,3.0] [2.0,3.0,4.0] [3.0,4.0,5.0] [4.0,5.0,6.0] } --- output { [47.0, 28.0, 32.0] --- [83.0, 44.0, 32.0] --- [47.0, 42.0, 42.0] --- [83.0, 66.0, 42.0] } - -def mm2by2 (a1: f64, b1: f64, c1: f64, d1: f64) - (a2: f64, b2: f64, c2: f64, d2: f64) = +-- entry: fwd_map rev_map fwd_vec rev_vec +-- input { [[1f32,2f32,3f32,4f32], [4f32,3f32,2f32,1f32], [1f32,2f32,3f32,4f32], [4f32,3f32,2f32,1f32]] } +-- output { +-- [[[92.0f32, 36.0, 0.0, 0.0], +-- [8.0f32, 20.0, 16.0, 40.0], +-- [32.0f32, 16.0, 20.0, 10.0], +-- [23.0f32, 0.0, 36.0, 0.0]], +-- [[59.0f32, 23.0, 0.0, 0.0], +-- [5.0f32, 13.0, 10.0, 26.0], +-- [24.0f32, 8.0, 15.0, 5.0], +-- [0.0f32, 23.0, 0.0, 36.0]], +-- [[0.0f32, 0.0, 92.0, 36.0], +-- [24.0f32, 60.0, 32.0, 80.0], +-- [80.0f32, 40.0, 52.0, 26.0], +-- [59.0f32, 0.0, 92.0, 0.0]], +-- [[0.0f32, 0.0, 59.0, 23.0], +-- [15.0f32, 39.0, 20.0, 52.0], +-- [60.0f32, 20.0, 39.0, 13.0], +-- [0.0f32, 59.0, 0.0, 92.0]]] +-- } + +def mm2by2 (a1: f32, b1: f32, c1: f32, d1: f32) + (a2: f32, b2: f32, c2: f32, d2: f32) = ( a1 * a2 + b1 * c2 , a1 * b2 + b1 * d2 , c1 * a2 + d1 * c2 , c1 * b2 + d1 * d2 ) -def red_mm [n] (xs: [n](f64, f64, f64, f64)) = +def primal [n] (xs: [n](f32, f32, f32, f32)) = reduce mm2by2 (1, 0, 0, 1) xs 
-entry rev [n] (xs1: [n]f64) (xs2: [n]f64) (xs3: [n]f64) (xs4: [n]f64) = - vjp red_mm (zip4 xs1 xs2 xs3 xs4) (1, 1, 1, 1) - |> unzip4 +def fromarr = \(x: [4]f32) -> (x[0], x[1], x[2], x[3]) + +def fromarrs = map fromarr +def toarrs = map (\(a, b, c, d) -> [a, b, c, d]) + +def onehot_1d n x = + tabulate n (\i -> f32.bool (i == x)) + +def onehot_2d n m x y = + tabulate_2d n m (\i j -> f32.bool ((i, j) == (x, y))) + +entry fwd_map [n] (input: [n][4]f32) : [4][n][4]f32 = + let input = fromarrs input + in tabulate (n * 4) (\i -> jvp primal input (fromarrs (onehot_2d n 4 (i / 4) (i % 4)))) + |> toarrs + |> transpose + |> map unflatten + +entry fwd_vec [n] (input: [n][4]f32) : [4][n][4]f32 = + let input = fromarrs input + let seeds = tabulate (n * 4) (\i -> (fromarrs (onehot_2d n 4 (i / 4) (i % 4)))) + in jmp primal input seeds + |> toarrs + |> transpose + |> map unflatten + +entry rev_map [n] (input: [n][4]f32) : [4][n][4]f32 = + let input = fromarrs input + in tabulate 4 (\i -> vjp primal input (fromarr (onehot_1d 4 i))) + |> map toarrs + +entry rev_vec [n] (input: [n][4]f32) : [4][n][4]f32 = + let input = fromarrs input + let seeds = tabulate 4 (\i -> fromarr (onehot_1d 4 i)) + in mjp primal input seeds + |> map toarrs diff --git a/tests/ad/reduce2.fut b/tests/ad/reduce2.fut index 43d2c85f16..09d6ce980d 100644 --- a/tests/ad/reduce2.fut +++ b/tests/ad/reduce2.fut @@ -1,7 +1,7 @@ -- Result of one reduction is used free in a map. 
-- == -- tags { no_ispc autodiff } --- entry: fwd rev +-- entry: fwd rev fwd_vec -- input { [3f64, 1f64, 5f64] } output { [-1.000000f64, -1.000000f64, -1.000000f64] } def sumBy 'a (f: a -> f64) (xs: []a) : f64 = map f xs |> f64.sum @@ -14,3 +14,6 @@ def f (arr: []f64) = entry fwd x = map (jvp f x) [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] entry rev x = vjp f x 1f64 + +entry fwd_vec x = + jmp f x [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] diff --git a/tests/ad/reduce3.fut b/tests/ad/reduce3.fut new file mode 100644 index 0000000000..4d4ad243f8 --- /dev/null +++ b/tests/ad/reduce3.fut @@ -0,0 +1,18 @@ +-- Reduce with vectorised addition. +-- == +-- tags { autodiff } +-- entry: fwd_map fwd_vec rev_vec +-- input { [[1f32,1f32],[2f32,2f32],[3f32,3f32],[4f32,4f32],[5f32,5f32]] } +-- output { [[1.0f32, 1.0], [1.0f32, 1.0], [1.0f32, 1.0], [1.0f32, 1.0], [1.0f32, 1.0]] } + +def primal [n] [k] (a: [n][k]f32) = + reduce (map2 (+)) (replicate k 0) a + +entry fwd_map [n] [k] (a: [n][k]f32) = + tabulate n (\i -> jvp primal a (replicate n (replicate k 0) with [i] = replicate k 1)) + +entry fwd_vec [n] [k] (a: [n][k]f32) = + jmp primal a (tabulate n (\i -> (replicate n (replicate k 0) with [i] = replicate k 1))) + +entry rev_vec [n] [k] (a: [n][k]f32) = + head (mjp primal a [replicate k 1]) diff --git a/tests/ad/reduce_by_index0.fut b/tests/ad/reduce_by_index0.fut index 2f8471250b..42b617f4bb 100644 --- a/tests/ad/reduce_by_index0.fut +++ b/tests/ad/reduce_by_index0.fut @@ -1,6 +1,6 @@ -- == -- tags { autodiff } --- entry: f_jvp +-- entry: f_jvp f_jmp -- input { [0i64,1i64,2i64,3i64] [1f64,2f64,3f64,4f64] } -- output { [[1f64,0f64,0f64,0f64],[0f64,1f64,0f64,0f64],[0f64,0f64,1f64,0f64],[0f64,0f64,0f64,1f64]] } def f [n] (is: [n]i64) (vs: [n]f64) = @@ -9,3 +9,8 @@ def f [n] (is: [n]i64) (vs: [n]f64) = entry f_jvp [n] (is: [n]i64) (vs: [n]f64) = tabulate n (\i -> jvp (f is) vs (replicate n 0 with [i] = 1)) |> transpose + +entry f_jmp [n] (is: [n]i64) (vs: [n]f64) = 
+ let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in jmp (f is) vs seeds + |> transpose diff --git a/tests/ad/reducebyindex3.fut b/tests/ad/reducebyindex3.fut index 3cc5c18287..15ea98af85 100644 --- a/tests/ad/reducebyindex3.fut +++ b/tests/ad/reducebyindex3.fut @@ -1,6 +1,6 @@ -- == -- tags { autodiff } --- entry: rev +-- entry: rev rev_vec -- input { -- [0i64,1i64,2i64,1i64,0i64,1i64,2i64] -- [1f64,2f64,3f64,4f64,5f64,6f64,7f64] } @@ -18,6 +18,10 @@ entry f [n] (is: [n]i64) (vs: [n]f64) = entry rev [n] (is: [n]i64) (vs: [n]f64) = tabulate 4 (\i -> vjp (f is) vs (replicate 4 0 with [i] = 1)) +entry rev_vec [n] (is: [n]i64) (vs: [n]f64) = + let seeds = tabulate 4 (\i -> replicate 4 0 with [i] = 1) + in mjp (f is) vs seeds + -- entry fwd [n] (is: [n]i64) (vs: [n]f64) = -- tabulate n (\i -> jvp (f is) vs (replicate n 0 with [i] = 1)) -- |> map (.1) |> transpose diff --git a/tests/ad/reducebyindex4.fut b/tests/ad/reducebyindex4.fut index ea7262c188..b61e0273f1 100644 --- a/tests/ad/reducebyindex4.fut +++ b/tests/ad/reducebyindex4.fut @@ -1,6 +1,6 @@ -- == -- tags { autodiff } --- entry: rev +-- entry: rev rev_vec -- input { -- [ 0i64, 1i64, 2i64, 1i64, 0i64, 1i64, 2i64, 1i64, 0i64] -- [ 1f32, 2f32, 3f32, 4f32, 5f32, 6f32, 7f32, 8f32, 9f32] @@ -19,6 +19,10 @@ entry rev [n] (is: [n]i64) (vs0: [n]f32) (vs1: [n]f32) = vjp (f is) (zip vs0 vs1) (replicate 4 (1, 1)) |> unzip +entry rev_vec [n] (is: [n]i64) (vs0: [n]f32) (vs1: [n]f32) = + (mjp (f is) (zip vs0 vs1) [replicate 4 (1, 1)])[0] + |> unzip + entry fwd [n] (is: [n]i64) (vs0: [n]f32) (vs1: [n]f32) = tabulate n (\i -> jvp (f is) (zip vs0 vs1) (replicate n (0, 0) with [i] = (1, 1))) |> transpose diff --git a/tests/ad/reducebyindexminmax3.fut b/tests/ad/reducebyindexminmax3.fut index 9ba08969dc..3f29347471 100644 --- a/tests/ad/reducebyindexminmax3.fut +++ b/tests/ad/reducebyindexminmax3.fut @@ -1,6 +1,6 @@ -- == -- tags { autodiff } --- entry: rev +-- entry: rev rev_vec -- input { [5f32,1f32,2f32] 
[0i64,0i64,0i64] [4f32,3f32,2f32] 3f32 } -- output { [0f32,0f32,0f32] 5f32 } -- input { [5f32,1f32,2f32] [0i64,0i64,0i64] [10f32,3f32,2f32] 3f32 } @@ -14,4 +14,7 @@ def red_max [n] [m] (dst: [m]f32) (is: [n]i64) (vs: [n]f32, c: f32) = entry rev [n] [m] (dst: [m]f32) (is: [n]i64) (vs: [n]f32) (c: f32) = vjp (red_max dst is) (vs, c) (replicate m 0 with [0] = 1) +entry rev_vec [n] [m] (dst: [m]f32) (is: [n]i64) (vs: [n]f32) (c: f32) = + (mjp (red_max dst is) (vs, c) [replicate m 0 with [0] = 1])[0] + --tabulate n (\i -> vjp (red_max dst is) (vs, c) (replicate n 0 with [i] = 1)) diff --git a/tests/ad/reducebyindexminmax4.fut b/tests/ad/reducebyindexminmax4.fut index efd0acb266..3214886b0d 100644 --- a/tests/ad/reducebyindexminmax4.fut +++ b/tests/ad/reducebyindexminmax4.fut @@ -1,6 +1,6 @@ -- == -- tags { autodiff } --- entry: rev +-- entry: rev rev_vec -- input { [5f32,1f32,2f32] [0i64,0i64,0i64] [4f32,3f32,2f32] 3f32 } -- output { [3f32,0f32,0f32] 5f32 } -- input { [5f32,1f32,2f32] [0i64,0i64,0i64] [10f32,3f32,2f32] 3f32 } @@ -14,4 +14,7 @@ def red_max [n] [m] (vs: [n]f32) (is: [n]i64) (dst: [m]f32, c: f32) = entry rev [n] [m] (dst: [m]f32) (is: [n]i64) (vs: [n]f32) (c: f32) = vjp (red_max vs is) (dst, c) (replicate m 0 with [0] = 1) +entry rev_vec [n] [m] (dst: [m]f32) (is: [n]i64) (vs: [n]f32) (c: f32) = + (mjp (red_max vs is) (dst, c) [replicate m 0 with [0] = 1])[0] + --tabulate n (\i -> vjp (red_max dst is) (vs, c) (replicate n 0 with [i] = 1)) diff --git a/tests/ad/reducebyindexminmax7.fut b/tests/ad/reducebyindexminmax7.fut index 94fd6b62d1..92a4739e33 100644 --- a/tests/ad/reducebyindexminmax7.fut +++ b/tests/ad/reducebyindexminmax7.fut @@ -1,19 +1,24 @@ -- == --- tags { autodiff } --- compiled random input { [500]i64 [100][30]f32 [500][30]f32 } output { true } +-- tags { autodiff no_ispc } +-- entry: rev fwd rev_vec fwd_vec +-- compiled input @ reducebyindexminmax7.in output @ reducebyindexminmax7.out.gz -def primal [n] [m] [k] (is: [n]i64) (dst: [m][k]f32) 
(vs: [n][k]f32) = - reduce_by_index (copy dst) (map2 f32.max) (replicate k f32.lowest) is vs +def primal [n] [m] [k] (is': [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = + let is = map (\i -> (i64.abs i) %% m) is' + in reduce_by_index (copy dst) (map2 f32.max) (replicate k f32.lowest) is vs -def rev [n] [m] [k] (is: [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = +entry rev [n] [m] [k] (is: [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = tabulate m (\i -> vjp (primal is dst) vs (replicate m (replicate k 0) with [i] = replicate k 1)) -def fwd [n] [m] [k] (is: [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = +entry fwd [n] [m] [k] (is: [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = tabulate n (\i -> jvp (primal is dst) vs (replicate n (replicate k 0) with [i] = replicate k 1)) |> transpose -def main [n] [m] [k] (is': [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = - let is = map (\i -> (i64.abs i) %% m) is' - let r = rev is dst vs - let f = fwd is dst vs - in map2 (map2 (==)) r f |> map (reduce (&&) true) |> reduce (&&) true +entry rev_vec [n] [m] [k] (is: [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = + let seeds = tabulate m (\i -> replicate m (replicate k 0) with [i] = replicate k 1) + in mjp (primal is dst) vs seeds + +entry fwd_vec [n] [m] [k] (is: [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = + let seeds = tabulate n (\i -> replicate n (replicate k 0) with [i] = replicate k 1) + in jmp (primal is dst) vs seeds + |> transpose diff --git a/tests/ad/reducebyindexminmax7.in b/tests/ad/reducebyindexminmax7.in new file mode 100644 index 0000000000..b5eac437a5 Binary files /dev/null and b/tests/ad/reducebyindexminmax7.in differ diff --git a/tests/ad/reducebyindexminmax7.out.gz b/tests/ad/reducebyindexminmax7.out.gz new file mode 100644 index 0000000000..87f996f4e5 Binary files /dev/null and b/tests/ad/reducebyindexminmax7.out.gz differ diff --git a/tests/ad/reducebyindexminmax8.fut b/tests/ad/reducebyindexminmax8.fut index dc4273fd9b..0c3fc324a1 100644 --- a/tests/ad/reducebyindexminmax8.fut +++ 
b/tests/ad/reducebyindexminmax8.fut @@ -12,8 +12,22 @@ def fwd2 [n] [m] [k] [l] (is: [n]i64) (dst: [m][k][l]f32) (vs: [n][k][l]f32) = tabulate n (\i -> jvp (primal2 is dst) vs (replicate n (replicate k (replicate l 0)) with [i] = replicate k (replicate l 1))) |> transpose +def rev_vec2 [n] [m] [k] [l] (is: [n]i64) (dst: [m][k][l]f32) (vs: [n][k][l]f32) = + let seeds = tabulate m (\i -> replicate m (replicate k (replicate l 0)) with [i] = replicate k (replicate l 1)) + in mjp (primal2 is dst) vs seeds + +def fwd_vec2 [n] [m] [k] [l] (is: [n]i64) (dst: [m][k][l]f32) (vs: [n][k][l]f32) = + let seeds = tabulate n (\i -> replicate n (replicate k (replicate l 0)) with [i] = replicate k (replicate l 1)) + in jmp (primal2 is dst) vs seeds + |> transpose + def main [n] [m] [k] [l] (is': [n]i64) (dst: [m][k][l]f32) (vs: [n][k][l]f32) = let is = map (\i -> (i64.abs i) %% m) is' let r = rev2 is dst vs let f = fwd2 is dst vs - in map2 (map2 (map2 (==))) r f |> map (map (reduce (&&) true)) |> map (reduce (&&) true) |> reduce (&&) true + let rv = rev_vec2 is dst vs + let fv = fwd_vec2 is dst vs + let eq_rf = map2 (map2 (map2 (==))) r f |> map (map (reduce (&&) true)) |> map (reduce (&&) true) |> reduce (&&) true + let eq_rrv = map2 (map2 (map2 (==))) r rv |> map (map (reduce (&&) true)) |> map (reduce (&&) true) |> reduce (&&) true + let eq_ffv = map2 (map2 (map2 (==))) f fv |> map (map (reduce (&&) true)) |> map (reduce (&&) true) |> reduce (&&) true + in eq_rf && eq_rrv && eq_ffv diff --git a/tests/ad/reducemul0.fut b/tests/ad/reducemul0.fut index 406554a5a1..c9b4559d5a 100644 --- a/tests/ad/reducemul0.fut +++ b/tests/ad/reducemul0.fut @@ -1,6 +1,6 @@ -- == -- tags { autodiff } --- entry: rev fwd +-- entry: rev fwd fwd_vec -- input { [0.0f32, 2.0f32, 0.0f32, 4.0f32] } output { [0.0f32, 0.0f32, 0.0f32, 0.0f32] } def red_mult [n] (xs: [n]f32) : f32 = @@ -11,3 +11,7 @@ entry rev [n] (xs: [n]f32) = entry fwd [n] (xs: [n]f32) = tabulate n (\i -> jvp red_mult xs (replicate n 0 with 
[i] = 1)) + +entry fwd_vec [n] (xs: [n]f32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in jmp red_mult xs seeds diff --git a/tests/ad/reducemul4.fut b/tests/ad/reducemul4.fut index db067c6f45..d4a4e1664d 100644 --- a/tests/ad/reducemul4.fut +++ b/tests/ad/reducemul4.fut @@ -1,6 +1,6 @@ -- == -- tags { autodiff } --- entry: fwd rev +-- entry: fwd rev fwd_vec rev_vec -- input { [1f32, 2f32, 3f32, 4f32] } output { [[48f32, 12f32, 8f32, 6f32], [48f32, 48f32, 16f32, 12f32], [72f32, 36f32, 48f32, 18f32], [96f32, 48f32, 32f32, 48f32]] } def fun [n] (as: [n]f32) = @@ -13,3 +13,11 @@ entry fwd [n] (as: [n]f32) = entry rev [n] (as: [n]f32) = tabulate n (\i -> vjp fun as (replicate n 0 with [i] = 1)) + +entry fwd_vec [n] (as: [n]f32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in jmp fun as seeds |> transpose + +entry rev_vec [n] (as: [n]f32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in mjp fun as seeds diff --git a/tests/ad/reducevec0.fut b/tests/ad/reducevec0.fut index d781fb4cd4..24a9dc1fe6 100644 --- a/tests/ad/reducevec0.fut +++ b/tests/ad/reducevec0.fut @@ -1,6 +1,6 @@ -- == -- tags { autodiff } --- entry: rev fwd +-- entry: rev fwd fwd_vec rev_vec -- input { -- [[[0f32,1f32],[2f32,3f32]], -- [[5f32,1f32],[3f32,0f32]], @@ -17,3 +17,21 @@ entry fwd [n] [m] [k] (xs: [n][m][k]f32) : [m][k][n][m][k]f32 = tabulate_3d n m k (\i j l -> jvp f xs (replicate n (replicate m (replicate k 0)) with [i] = (replicate m (replicate k 0) with [j] = (replicate k 0 with [l] = 1)))) |> transpose |> map transpose + +entry fwd_vec [n] [m] [k] (xs: [n][m][k]f32) : [m][k][n][m][k]f32 = + let seeds = tabulate (n * m * k) (\p -> + let i = p / (m * k) + let j = (p % (m * k)) / k + let l = p % k + in replicate n (replicate m (replicate k 0)) with [i] = (replicate m (replicate k 0) with [j] = (replicate k 0 with [l] = 1))) + let res = jmp f xs seeds + in unflatten (sized (n * (m * k)) res) |> map unflatten + |> transpose + |> map 
transpose + +entry rev_vec [n] [m] [k] (xs: [n][m][k]f32) : [m][k][n][m][k]f32 = + let seeds = tabulate (m * k) (\p -> + let i = p / k + let j = p % k + in replicate m (replicate k 0) with [i] = (replicate k 0 with [j] = 1)) + in unflatten (mjp f xs seeds) diff --git a/tests/ad/replicate0.fut b/tests/ad/replicate0.fut index 091d5f3085..e4e7b20713 100644 --- a/tests/ad/replicate0.fut +++ b/tests/ad/replicate0.fut @@ -2,17 +2,25 @@ -- tags { autodiff } -- == --- entry: f_jvp --- input { 3i64 2 } --- output { [1,1,1] } +-- entry: fwd_vec fwd_map +-- input { 2i64 [1.0, 2.0] } +-- output { [[[1.0, 0.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 1.0]]] } -entry f_jvp n x : []i32 = - jvp (replicate n) x 1 +def f (n: i64) (xs: []f64) = replicate n xs + +entry fwd_vec n (xs: []f64) = + let seeds = + map (\i -> map (\j -> f64.bool (i == j)) (indices xs)) (indices xs) + in (jmp2 (f n) xs seeds).1 + +entry fwd_map n (xs: []f64) = + map (\i -> jvp (f n) xs (map (\j -> f64.bool (i == j)) (indices xs))) + (indices xs) -- == --- entry: f_vjp +-- entry: rev -- input { 3i64 2i64 } -- output { 3i64 } -entry f_vjp n x = +entry rev n x = vjp (replicate n) x (iota n) diff --git a/tests/ad/reshape0.fut b/tests/ad/reshape0.fut index f54a9e3737..1b42b89c42 100644 --- a/tests/ad/reshape0.fut +++ b/tests/ad/reshape0.fut @@ -2,12 +2,16 @@ -- tags { autodiff } -- == --- entry: f_jvp +-- entry: fwd_map fwd_vec -- input { 2i64 2i64 [1,2,3,4] } --- output { [[1,2],[3,4]] } +-- output { [[[1, 0], [0, 0]], [[0, 1], [0, 0]]] } -entry f_jvp n m (xs: [n * m]i32) = - jvp unflatten xs xs +entry fwd_map n m (xs: [n * m]i32) = + tabulate 2 (\i -> jvp unflatten xs (replicate (n * m) 0 with [i] = 1)) + +entry fwd_vec n m (xs: [n * m]i32) = + let seeds = tabulate 2 (\i -> replicate (n * m) 0 with [i] = 1) + in jmp unflatten xs seeds -- == -- entry: f_vjp diff --git a/tests/ad/scan0.fut b/tests/ad/scan0.fut index b044a5c8d4..b539a4a32e 100644 --- a/tests/ad/scan0.fut +++ b/tests/ad/scan0.fut @@ -2,7 +2,7 @@ -- generic 
case -- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_J rev_J fwd_vec_J rev_vec_J -- input { [1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32] } -- output { [[1.0f32, 0.0f32, 0.0f32, 0.0f32, 0.0f32], -- [2.0f32, 1.0f32, 0.0f32, 0.0f32, 0.0f32], @@ -17,3 +17,11 @@ entry fwd_J [n] (a: [n]f32) = entry rev_J [n] (a: [n]f32) = tabulate n (\i -> vjp (scan (*) 1) a (replicate n 0 with [i] = 1)) + +entry fwd_vec_J [n] (a: [n]f32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in jmp (scan (*) 1) a seeds |> transpose + +entry rev_vec_J [n] (a: [n]f32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in mjp (scan (*) 1) a seeds diff --git a/tests/ad/scan1.fut b/tests/ad/scan1.fut index 00138c2eea..5debeb1e09 100644 --- a/tests/ad/scan1.fut +++ b/tests/ad/scan1.fut @@ -2,7 +2,7 @@ -- addition special case -- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_J rev_J fwd_vec_J rev_vec_J -- input { [1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32] } -- output { [[1.0f32, 0.0f32, 0.0f32, 0.0f32, 0.0f32], -- [1.0f32, 1.0f32, 0.0f32, 0.0f32, 0.0f32], @@ -17,3 +17,11 @@ entry fwd_J [n] (a: [n]f32) = entry rev_J [n] (a: [n]f32) = tabulate n (\i -> vjp (scan (+) 0) a (replicate n 0 with [i] = 1)) + +entry fwd_vec_J [n] (a: [n]f32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in jmp (scan (+) 0) a seeds |> transpose + +entry rev_vec_J [n] (a: [n]f32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in mjp (scan (+) 0) a seeds diff --git a/tests/ad/scan2.fut b/tests/ad/scan2.fut index b27466c516..ea2d39fc3a 100644 --- a/tests/ad/scan2.fut +++ b/tests/ad/scan2.fut @@ -2,7 +2,7 @@ -- special cases: vectorised and addition -- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_J rev_J fwd_vec_J rev_vec_J -- input { [[1f32,1f32],[2f32,2f32],[3f32,3f32],[4f32,4f32],[5f32,5f32]] } -- output { [[[1.000000f32, 1.000000f32], [0.000000f32, 0.000000f32], [0.000000f32, 0.000000f32], [0.000000f32, 0.000000f32], 
[0.000000f32, 0.000000f32]], [[1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [0.000000f32, 0.000000f32], [0.000000f32, 0.000000f32], [0.000000f32, 0.000000f32]], [[1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [0.000000f32, 0.000000f32], [0.000000f32, 0.000000f32]], [[1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [0.000000f32, 0.000000f32]], [[1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32]]] } @@ -15,3 +15,11 @@ entry fwd_J [n] [k] (a: [n][k]f32) = entry rev_J [n] [k] (a: [n][k]f32) = tabulate n (\i -> vjp primal a (replicate n (replicate k 0) with [i] = replicate k 1)) + +entry fwd_vec_J [n] [k] (a: [n][k]f32) = + let seeds = tabulate n (\i -> replicate n (replicate k 0) with [i] = replicate k 1) + in jmp primal a seeds |> transpose + +entry rev_vec_J [n] [k] (a: [n][k]f32) = + let seeds = tabulate n (\i -> replicate n (replicate k 0) with [i] = replicate k 1) + in mjp primal a seeds diff --git a/tests/ad/scan3.fut b/tests/ad/scan3.fut index 0de395e9d6..35a8890000 100644 --- a/tests/ad/scan3.fut +++ b/tests/ad/scan3.fut @@ -2,7 +2,7 @@ -- MatrixMul case -- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_J rev_J fwd_vec_J rev_vec_J -- input { [[1f32,2f32,3f32,4f32], [4f32,3f32,2f32,1f32], [1f32,2f32,3f32,4f32], [4f32,3f32,2f32,1f32]] } -- output { -- [[[[1f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32]], @@ -53,3 +53,19 @@ entry rev_J [n] (input: [n][4]f32) : [n][4][n][4]f32 = in tabulate (n * 4) (\i -> vjp primal input (fromarrs (onehot_2d n 4 (i / 4) (i % 4)))) |> unflatten |> map (map toarrs) + +entry fwd_vec_J [n] (input: [n][4]f32) : [n][4][n][4]f32 = + let input = fromarrs input + let seeds = tabulate (n * 4) (\i -> fromarrs (onehot_2d n 4 (i / 4) (i % 4))) + in jmp primal input seeds + |> map 
toarrs + |> transpose + |> map transpose + |> map (map unflatten) + +entry rev_vec_J [n] (input: [n][4]f32) : [n][4][n][4]f32 = + let input = fromarrs input + let seeds = tabulate (n * 4) (\i -> fromarrs (onehot_2d n 4 (i / 4) (i % 4))) + in mjp primal input seeds + |> unflatten + |> map (map toarrs) diff --git a/tests/ad/scan4.fut b/tests/ad/scan4.fut index 71343e2d7a..4f414bc333 100644 --- a/tests/ad/scan4.fut +++ b/tests/ad/scan4.fut @@ -2,7 +2,7 @@ -- ZeroQuadrant case -- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_J rev_J fwd_vec_J rev_vec_J -- input { [[1.0f32, 2.0f32, 3.0f32], [4.0f32, 3.0f32, 5.0f32], [3.0f32, 4.0f32, 2.0f32], [4.0f32, 2.0f32, 1.0f32]] } -- output { -- [[[1f32, 1f32, 1f32], [0f32, 0f32, 0f32], @@ -31,3 +31,16 @@ entry rev_J [n] (input: [n][3]f32) = let input = fromarrs input in tabulate n (\i -> vjp primal input (replicate n (0, 0, 0) with [i] = (1, 1, 1))) |> map toarrs + +entry fwd_vec_J [n] (input: [n][3]f32) = + let input = fromarrs input + let seeds = tabulate n (\i -> replicate n (0, 0, 0) with [i] = (1, 1, 1)) + in jmp primal input seeds + |> map toarrs + |> transpose + +entry rev_vec_J [n] (input: [n][3]f32) = + let input = fromarrs input + let seeds = tabulate n (\i -> replicate n (0, 0, 0) with [i] = (1, 1, 1)) + in mjp primal input seeds + |> map toarrs diff --git a/tests/ad/scan5.fut b/tests/ad/scan5.fut index 6b55cd7b13..f9e4aa0130 100644 --- a/tests/ad/scan5.fut +++ b/tests/ad/scan5.fut @@ -2,7 +2,7 @@ -- Vectorised special case + generic case -- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_J rev_J fwd_vec_J rev_vec_J -- input { [[1f32,1f32],[2f32,2f32],[3f32,3f32],[4f32,4f32],[5f32,5f32]] } -- output { -- [[[1f32, 1f32], [0f32, 0f32], [0f32, 0f32], [0f32, 0f32], [0f32, 0f32]], @@ -21,3 +21,11 @@ entry fwd_J [n] [k] (a: [n][k]f32) = entry rev_J [n] [k] (a: [n][k]f32) = tabulate n (\i -> vjp primal a (replicate n (replicate k 0) with [i] = replicate k 1)) + +entry fwd_vec_J [n] [k] (a: 
[n][k]f32) = + let seeds = tabulate n (\i -> replicate n (replicate k 0) with [i] = replicate k 1) + in jmp primal a seeds |> transpose + +entry rev_vec_J [n] [k] (a: [n][k]f32) = + let seeds = tabulate n (\i -> replicate n (replicate k 0) with [i] = replicate k 1) + in mjp primal a seeds diff --git a/tests/ad/scan6.fut b/tests/ad/scan6.fut index ead433518a..fc0fde8be7 100644 --- a/tests/ad/scan6.fut +++ b/tests/ad/scan6.fut @@ -2,7 +2,7 @@ -- MatrixMul case -- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_J rev_J fwd_vec_J rev_vec_J -- input { [[1f32, 2f32], [4f32, 3f32], [3f32, 4f32], [4f32, 2f32]] } -- output { -- [[[[1f32, 0f32], [0f32, 0f32], [0f32, 0f32], [0f32, 0f32]], @@ -39,7 +39,7 @@ entry rev_J [n] (input: [n][2]f32) = |> map (map toarrs) -- == --- entry: fwd_J2 rev_J2 +-- entry: fwd_J2 rev_J2 fwd_vec_J2 rev_vec_J2 -- no_oclgrind input { [[1f32,2f32,3f32,4f32,5f32,6f32],[6f32,5f32,4f32,3f32,2f32,1f32],[4f32,5f32,6f32,1f32,2f32,3f32],[3f32,2f32,1f32,6f32,5f32,4f32]] } -- output { [[[[1f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 1f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 1f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 1f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 1f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 0f32, 1f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]]], [[[4f32, 3f32, 0f32, 0f32, 0f32, 0f32], [1f32, 
0f32, 1f32, 2f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[2f32, 1f32, 0f32, 0f32, 0f32, 0f32], [0f32, 1f32, 0f32, 0f32, 1f32, 2f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 4f32, 0f32, 3f32, 0f32], [0f32, 0f32, 3f32, 5f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 4f32, 0f32, 3f32], [0f32, 0f32, 4f32, 6f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 2f32, 0f32, 1f32, 0f32], [0f32, 0f32, 0f32, 0f32, 3f32, 5f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 2f32, 0f32, 1f32], [0f32, 0f32, 0f32, 0f32, 4f32, 6f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]]], [[[26f32, 19f32, 0f32, 0f32, 0f32, 0f32], [6f32, 1f32, 6f32, 12f32, 1f32, 2f32], [1f32, 0f32, 16f32, 9f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[14f32, 9f32, 0f32, 0f32, 0f32, 0f32], [2f32, 3f32, 2f32, 4f32, 3f32, 6f32], [0f32, 1f32, 0f32, 0f32, 16f32, 9f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 26f32, 0f32, 19f32, 0f32], [0f32, 0f32, 18f32, 30f32, 3f32, 5f32], [0f32, 0f32, 27f32, 11f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 26f32, 0f32, 19f32], [0f32, 0f32, 24f32, 36f32, 4f32, 6f32], [0f32, 0f32, 34f32, 14f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 14f32, 0f32, 9f32, 0f32], [0f32, 0f32, 6f32, 10f32, 9f32, 15f32], [0f32, 0f32, 0f32, 0f32, 27f32, 11f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 14f32, 0f32, 9f32], [0f32, 0f32, 8f32, 12f32, 12f32, 18f32], [0f32, 0f32, 0f32, 0f32, 34f32, 14f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]]], [[[110f32, 73f32, 0f32, 0f32, 0f32, 0f32], [18f32, 19f32, 18f32, 36f32, 19f32, 38f32], [1f32, 6f32, 16f32, 9f32, 96f32, 54f32], [1f32, 0f32, 109f32, 64f32, 0f32, 
0f32]], [[186f32, 131f32, 0f32, 0f32, 0f32, 0f32], [38f32, 17f32, 38f32, 76f32, 17f32, 34f32], [5f32, 4f32, 80f32, 45f32, 64f32, 36f32], [0f32, 1f32, 0f32, 0f32, 109f32, 64f32]], [[0f32, 0f32, 110f32, 0f32, 73f32, 0f32], [0f32, 0f32, 54f32, 90f32, 57f32, 95f32], [0f32, 0f32, 27f32, 11f32, 162f32, 66f32], [0f32, 0f32, 173f32, 87f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 110f32, 0f32, 73f32], [0f32, 0f32, 72f32, 108f32, 76f32, 114f32], [0f32, 0f32, 34f32, 14f32, 204f32, 84f32], [0f32, 0f32, 218f32, 110f32, 0f32, 0f32]], [[0f32, 0f32, 186f32, 0f32, 131f32, 0f32], [0f32, 0f32, 114f32, 190f32, 51f32, 85f32], [0f32, 0f32, 135f32, 55f32, 108f32, 44f32], [0f32, 0f32, 0f32, 0f32, 173f32, 87f32]], [[0f32, 0f32, 0f32, 186f32, 0f32, 131f32], [0f32, 0f32, 152f32, 228f32, 68f32, 102f32], [0f32, 0f32, 170f32, 70f32, 136f32, 56f32], [0f32, 0f32, 0f32, 0f32, 218f32, 110f32]]]] } def mm2by2 (a1, b1, c1, d1) @@ -82,3 +82,35 @@ entry rev_J2 [n] (input: [n][6]f32) : [n][6][n][6]f32 = in tabulate (n * 6) (\i -> vjp primal2 input (fromarrs2 (onehot_2d n 6 (i / 6) (i % 6)))) |> unflatten |> map (map toarrs2) + +entry fwd_vec_J [n] (input: [n][2]f32) = + let input = fromarrs input + let seeds = tabulate (n * 2) (\i -> fromarrs (onehot_2d n 2 (i / 2) (i % 2))) + in jmp primal input seeds + |> map toarrs + |> transpose + |> map transpose + |> map (map unflatten) + +entry rev_vec_J [n] (input: [n][2]f32) = + let input = fromarrs input + let seeds = tabulate (n * 2) (\i -> fromarrs (onehot_2d n 2 (i / 2) (i % 2))) + in mjp primal input seeds + |> unflatten + |> map (map toarrs) + +entry fwd_vec_J2 [n] (input: [n][6]f32) : [n][6][n][6]f32 = + let input = fromarrs2 input + let seeds = tabulate (n * 6) (\i -> fromarrs2 (onehot_2d n 6 (i / 6) (i % 6))) + in jmp primal2 input seeds + |> map toarrs2 + |> transpose + |> map transpose + |> map (map unflatten) + +entry rev_vec_J2 [n] (input: [n][6]f32) : [n][6][n][6]f32 = + let input = fromarrs2 input + let seeds = tabulate (n * 6) (\i -> fromarrs2 
(onehot_2d n 6 (i / 6) (i % 6))) + in mjp primal2 input seeds + |> unflatten + |> map (map toarrs2) diff --git a/tests/ad/scan7.fut b/tests/ad/scan7.fut index 00efbe4f9e..cfb967878b 100644 --- a/tests/ad/scan7.fut +++ b/tests/ad/scan7.fut @@ -5,7 +5,7 @@ -- tags { autodiff } -- == --- entry: fwd_J rev_J +-- entry: fwd_J rev_J fwd_vec_J rev_vec_J -- input { [[[1f32,2f32], [2f32,3f32]], [[4f32,5f32], [3f32,4f32]], -- [[3f32,4f32], [4f32,5f32]], [[4f32,5f32], [2f32,3f32]]] } -- output { @@ -107,3 +107,24 @@ entry test [n] [m] [k] (input: [n][m][k]f32) bar = let a = res |> map (map transpose) |> map (map (map transpose)) |> map (map (map (map transpose))) let a2 = a |> map transpose |> map (map transpose) |> map (map (map transpose)) in a2 |> transpose |> map transpose |> (map (map transpose)) + +entry fwd_vec_J [n] [m] [k] (input: [n][m][k]f32) = + let seeds = tabulate (n * m * k) (\p -> + let i = p / (m * k) + let j = (p % (m * k)) / k + let q = p % k + in replicate n (replicate m (replicate k 0)) with [i] = (replicate m (replicate k 0) with [j] = (replicate k 0 with [q] = 1))) + let res = jmp primal input seeds + in unflatten (sized (n * (m * k)) res) |> map unflatten + +entry rev_vec_J [n] [m] [k] (input: [n][m][k]f32) = + let seeds = tabulate (n * m * k) (\p -> + let i = p / (m * k) + let j = (p % (m * k)) / k + let q = p % k + in replicate n (replicate m (replicate k 0)) with [i] = (replicate m (replicate k 0) with [j] = (replicate k 0 with [q] = 1))) + let res = mjp primal input seeds + |> (\x -> unflatten (sized (n * (m * k)) x)) |> map unflatten + let a = res |> map (map transpose) |> map (map (map transpose)) |> map (map (map (map transpose))) + let a2 = a |> map transpose |> map (map transpose) |> map (map (map transpose)) + in a2 |> transpose |> map transpose |> (map (map transpose)) diff --git a/tests/ad/scan8.fut b/tests/ad/scan8.fut index 127b564914..8a8c14317a 100644 --- a/tests/ad/scan8.fut +++ b/tests/ad/scan8.fut @@ -1,7 +1,7 @@ -- Scan with 3x3 
matrix multiplication. -- == -- tags { autodiff } --- entry: fwd rev +-- entry: fwd rev fwd_vec rev_vec -- input { [[1f32,2f32,3f32,4f32,5f32,6f32,7f32,8f32,9f32], -- [9f32,8f32,7f32,6f32,5f32,4f32,3f32,2f32,1f32], -- [1f32,2f32,3f32,4f32,5f32,6f32,7f32,8f32,9f32], @@ -176,3 +176,19 @@ entry rev [n] (input: [n][9]f32) : [n][9][n][9]f32 = in tabulate (n * 9) (\i -> vjp primal3 input (fromarrs3 (onehot_2d n 9 (i / 9) (i % 9)))) |> unflatten |> map (map toarrs3) + +entry fwd_vec [n] (input: [n][9]f32) : [n][9][n][9]f32 = + let input = fromarrs3 input + let seeds = tabulate (n * 9) (\i -> fromarrs3 (onehot_2d n 9 (i / 9) (i % 9))) + in jmp primal3 input seeds + |> map toarrs3 + |> transpose + |> map transpose + |> map (map unflatten) + +entry rev_vec [n] (input: [n][9]f32) : [n][9][n][9]f32 = + let input = fromarrs3 input + let seeds = tabulate (n * 9) (\i -> fromarrs3 (onehot_2d n 9 (i / 9) (i % 9))) + in mjp primal3 input seeds + |> unflatten + |> map (map toarrs3) diff --git a/tests/ad/scan9.fut b/tests/ad/scan9.fut index cb1ebfb671..37c50ff9e0 100644 --- a/tests/ad/scan9.fut +++ b/tests/ad/scan9.fut @@ -1,7 +1,7 @@ -- Scan with 4x4 matrix multiplication. 
 -- ==
 -- tags { autodiff }
--- entry: fwd rev
+-- entry: fwd rev fwd_vec rev_vec
 -- no_oclgrind input {
 -- [[1f32,2f32,3f32,4f32,5f32,6f32,7f32,8f32,9f32,10f32,11f32,12f32,13f32,14f32,15f32,16f32],
 -- [16f32,15f32,14f32,13f32,12f32,11f32,10f32,9f32,8f32,7f32,6f32,5f32,4f32,3f32,2f32,1f32],
@@ -474,3 +474,19 @@ entry rev [n] (input: [n][16]f32) : [n][16][n][16]f32 =
   in tabulate (n * 16) (\i -> vjp primal2 input (fromarrs2 (onehot_2d n 16 (i / 16) (i % 16))))
      |> unflatten
      |> map (map toarrs2)
+
+entry fwd_vec [n] (input: [n][16]f32) : [n][16][n][16]f32 =
+  let input = fromarrs2 input
+  let seeds = tabulate (n * 16) (\i -> fromarrs2 (onehot_2d n 16 (i / 16) (i % 16)))
+  in jmp primal2 input seeds
+     |> map toarrs2
+     |> transpose
+     |> map transpose
+     |> map (map unflatten)
+
+entry rev_vec [n] (input: [n][16]f32) : [n][16][n][16]f32 =
+  let input = fromarrs2 input
+  let seeds = tabulate (n * 16) (\i -> fromarrs2 (onehot_2d n 16 (i / 16) (i % 16)))
+  in mjp primal2 input seeds
+     |> unflatten
+     |> map (map toarrs2)
diff --git a/tests/ad/scatter0.fut b/tests/ad/scatter0.fut
index 78eff344e3..b9ed86fb0f 100644
--- a/tests/ad/scatter0.fut
+++ b/tests/ad/scatter0.fut
@@ -1,7 +1,7 @@
 -- Simple scatter, differentiating wrt. values.
 -- ==
 -- tags { autodiff }
--- entry: fwd rev
+-- entry: fwd rev fwd_vec rev_vec
 -- input { [0f64, 0f64, 0f64, 0f64] [0i64, 1i64, 2i64, 3i64] [1f64, 2f64, 3f64, 0f64] }
 -- output {
 -- [[1.000000f64, 0.000000f64, 0.000000f64, 0.000000f64],
@@ -27,3 +27,11 @@ entry fwd [n] [k] (xs: [k]f64) (is: [n]i64) (vs: [n]f64) =
 entry rev [n] [k] (xs: [k]f64) (is: [n]i64) (vs: [n]f64) =
   let g i = vjp (\vs -> f xs is vs) vs (replicate k 0 with [i] = 1)
   in tabulate n g
+
+entry fwd_vec [n] [k] (xs: [k]f64) (is: [n]i64) (vs: [n]f64) =
+  let seeds = tabulate n (\i -> replicate n 0 with [i] = 1)
+  in jmp (\vs -> f xs is vs) vs seeds
+
+entry rev_vec [n] [k] (xs: [k]f64) (is: [n]i64) (vs: [n]f64) =
+  let seeds = tabulate n (\i -> replicate k 0 with [i] = 1)
+  in mjp (\vs -> f xs is vs) vs seeds
diff --git a/tests/ad/scatter1.fut b/tests/ad/scatter1.fut
index 0e41101762..54b0c05adf 100644
--- a/tests/ad/scatter1.fut
+++ b/tests/ad/scatter1.fut
@@ -1,7 +1,7 @@
 -- Simple scatter, differentiating wrt. target.
 -- ==
 -- tags { autodiff }
--- entry: fwd rev
+-- entry: fwd rev fwd_vec rev_vec
 -- input { [0f64, 0f64, 0f64, 0f64] [0i64, 1i64] [1f64, 2f64] }
 -- output {
 -- [[0.000000f64, 0.000000f64, 0.000000f64, 0.000000f64],
@@ -20,3 +20,11 @@ entry fwd [n] [k] (xs: [k]f64) (is: [n]i64) (vs: [n]f64) =
 entry rev [n] [k] (xs: [k]f64) (is: [n]i64) (vs: [n]f64) =
   let g i = vjp (\xs -> f xs is vs) xs (replicate k 0 with [i] = 1)
   in tabulate k g
+
+entry fwd_vec [n] [k] (xs: [k]f64) (is: [n]i64) (vs: [n]f64) =
+  let seeds = tabulate k (\i -> replicate k 0 with [i] = 1)
+  in jmp (\xs -> f xs is vs) xs seeds
+
+entry rev_vec [n] [k] (xs: [k]f64) (is: [n]i64) (vs: [n]f64) =
+  let seeds = tabulate k (\i -> replicate k 0 with [i] = 1)
+  in mjp (\xs -> f xs is vs) xs seeds
diff --git a/tests/ad/stripmine1.fut b/tests/ad/stripmine1.fut
index 82e8280234..09309f3b0f 100644
--- a/tests/ad/stripmine1.fut
+++ b/tests/ad/stripmine1.fut
@@ -15,7 +15,7 @@ def square [n] (xs: [n]i32) =
 entry prim [n] (xs: [n]i32) = square xs
 
 -- ==
--- entry: f_jvp f_vjp
+-- entry: f_jvp f_vjp f_jmp f_mjp
 -- input { [1,2,3,4,5] }
 -- output { [[2,0,0,0,0],
 -- [0,4,0,0,0],
@@ -28,3 +28,12 @@ entry f_jvp [n] (xs: [n]i32) =
 
 entry f_vjp [n] (xs: [n]i32) =
   tabulate n (\i -> vjp square xs (replicate n 0 with [i] = 1))
+
+entry f_jmp [n] (xs: [n]i32) =
+  let seeds = tabulate n (\i -> replicate n 0 with [i] = 1)
+  in jmp square xs seeds
+     |> transpose
+
+entry f_mjp [n] (xs: [n]i32) =
+  let seeds = tabulate n (\i -> replicate n 0 with [i] = 1)
+  in mjp square xs seeds
diff --git a/tests/ad/stripmine2.fut b/tests/ad/stripmine2.fut
index d32cff65e0..d256314b12 100644
--- a/tests/ad/stripmine2.fut
+++ b/tests/ad/stripmine2.fut
@@ -13,7 +13,7 @@ def pow_list [n] y (xs: [n]i32) =
 entry prim y xs = pow_list y xs
 
 -- ==
--- entry: f_vjp f_jvp
+-- entry: f_vjp f_jvp f_mjp f_jmp
 -- input { 3 [1,2,3] }
 -- output { [[3,0,0],
 -- [0,12,0],
@@ -24,3 +24,12 @@ entry f_jvp [n] y (xs: [n]i32) =
 
 entry f_vjp [n] y (xs: [n]i32) =
   tabulate n (\i -> vjp (pow_list y) xs (replicate n 0 with [i] = 1))
+
+entry f_jmp [n] y (xs: [n]i32) =
+  let seeds = tabulate n (\i -> replicate n 0 with [i] = 1)
+  in jmp (pow_list y) xs seeds
+     |> transpose
+
+entry f_mjp [n] y (xs: [n]i32) =
+  let seeds = tabulate n (\i -> replicate n 0 with [i] = 1)
+  in mjp (pow_list y) xs seeds
diff --git a/tests/ad/sum.fut b/tests/ad/sum.fut
index 241f9b65fa..0cd00c1b14 100644
--- a/tests/ad/sum.fut
+++ b/tests/ad/sum.fut
@@ -1,7 +1,7 @@
 -- Simple reduce with summation.
 -- ==
 -- tags { autodiff }
--- entry: rev fwd
+-- entry: rev fwd fwd_vec
 -- input { [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] }
 -- output { [1.0, 1.0, 1.0, 1.0, 1.0, 1.0] }
 
@@ -13,3 +13,7 @@ entry rev [n] (xs: [n]f64) =
 
 entry fwd [n] (xs: [n]f64) =
   tabulate n (\i -> jvp sum xs (tabulate n ((== i) >-> f64.bool)))
+
+entry fwd_vec [n] (xs: [n]f64) =
+  let seeds = tabulate n (\i -> tabulate n ((== i) >-> f64.bool))
+  in jmp sum xs seeds
diff --git a/tests/ad/transpose.fut b/tests/ad/transpose.fut
new file mode 100644
index 0000000000..fa6fd0feb8
--- /dev/null
+++ b/tests/ad/transpose.fut
@@ -0,0 +1,16 @@
+-- ==
+-- tags { autodiff }
+-- entry: fwd_vec fwd_map
+-- input { [[1.0,2.0],[3.0,4.0]] }
+-- output { [[[1.0, 0.0],[0.0, 0.0]],[[0.0, 0.0],[1.0, 0.0]],[[0.0, 1.0],[0.0, 0.0]],[[0.0, 0.0],[0.0, 1.0]]] }
+
+def f (xs: [][]f64) = transpose xs
+
+entry fwd_vec [n] [m] (xs: [n][m]f64) =
+  let seeds =
+    tabulate (n * m) (\i -> tabulate (n * m) (\j -> f64.bool (i == j)) |> unflatten)
+  in (jmp2 f xs seeds).1
+
+entry fwd_map [n] [m] (xs: [n][m]f64) =
+  tabulate (n * m)
+           (\i -> jvp f xs (tabulate (n * m) (\j -> f64.bool (i == j)) |> unflatten))
diff --git a/tests/ad/truedep0.fut b/tests/ad/truedep0.fut
index 518091ed19..3680459d41 100644
--- a/tests/ad/truedep0.fut
+++ b/tests/ad/truedep0.fut
@@ -12,7 +12,7 @@ def test [n] (xs: [n]i32) =
 entry prim [n] (xs: [n]i32) = test xs
 
 -- ==
--- entry: f_jvp f_vjp
+-- entry: f_jvp f_vjp f_jmp f_mjp
 -- input { [1,2,3,4,5] }
 -- output { [[1,0,0,0,0],
 -- [2,0,0,0,0],
@@ -25,3 +25,12 @@ entry f_jvp [n] (xs: [n]i32) =
 
 entry f_vjp [n] (xs: [n]i32) =
   tabulate n (\i -> vjp test xs (replicate n 0 with [i] = 1))
+
+entry f_jmp [n] (xs: [n]i32) =
+  let seeds = tabulate n (\i -> replicate n 0 with [i] = 1)
+  in jmp test xs seeds
+     |> transpose
+
+entry f_mjp [n] (xs: [n]i32) =
+  let seeds = tabulate n (\i -> replicate n 0 with [i] = 1)
+  in mjp test xs seeds