From 5777e799bcb17bd101ec5480a009a70258dff5ba Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 5 Dec 2024 20:50:04 +0100 Subject: [PATCH 01/70] Create frontend and IR for vectorised AD. --- prelude/ad.fut | 18 ++++++-- src/Futhark/IR/Parse.hs | 8 +++- src/Futhark/IR/SOACS/SOAC.hs | 72 ++++++++++++++++++-------------- src/Futhark/IR/SOACS/Simplify.hs | 12 +++--- src/Futhark/Internalise/Exps.hs | 4 +- src/Futhark/Optimise/Fusion.hs | 8 ++-- src/Futhark/Pass/AD.hs | 8 ++-- src/Language/Futhark/Prop.hs | 28 +++++++++++++ 8 files changed, 106 insertions(+), 52 deletions(-) diff --git a/prelude/ad.fut b/prelude/ad.fut index 6a512d40b9..518c545b98 100644 --- a/prelude/ad.fut +++ b/prelude/ad.fut @@ -95,18 +95,28 @@ -- | Jacobian-Vector Product ("forward mode"), producing also the -- primal result as the first element of the result tuple. -def jvp2 'a 'b (f: a -> b) (x: a) (x': a): (b, b) = +def jvp2 'a 'b (f: a -> b) (x: a) (x': a) : (b, b) = intrinsics.jvp2 f x x' -- | Vector-Jacobian Product ("reverse mode"), producing also the -- primal result as the first element of the result tuple. -def vjp2 'a 'b (f: a -> b) (x: a) (y': b): (b, a) = +def vjp2 'a 'b (f: a -> b) (x: a) (y': b) : (b, a) = intrinsics.vjp2 f x y' +-- | As `jvp2`, but accepts a vector of seed values. Semantically +-- equivalent to mapping, but may be more efficient. +def jvp2_vec 'a 'b [n] (f: a -> b) (x: a) (x': [n]a) : (b, [n]b) = + intrinsics.jvp2_vec f x x' + +-- | As `vjp2`, but accepts a vector of seed values. Semantically +-- equivalent to mapping, but may be more efficient. +def vjp2_vec 'a 'b [n] (f: a -> b) (x: a) (y': [n]b) : (b, [n]a) = + intrinsics.vjp2_vec f x y' + -- | Jacobian-Vector Product ("forward mode"). -def jvp 'a 'b (f: a -> b) (x: a) (x': a): b = +def jvp 'a 'b (f: a -> b) (x: a) (x': a) : b = (jvp2 f x x').1 -- | Vector-Jacobian Product ("reverse mode"). -def vjp 'a 'b (f: a -> b) (x: a) (y': b): a = +def vjp 'a 'b (f: a -> b) (x: a) (y': b) : a = (vjp2 f x y').1 diff --git a/src/Futhark/IR/Parse.hs b/src/Futhark/IR/Parse.hs index 298e149617..a056b29609 100644 --- a/src/Futhark/IR/Parse.hs +++ b/src/Futhark/IR/Parse.hs @@ -798,7 +798,9 @@ pSOAC pr = pVJP = parens $ SOAC.VJP - <$> braces (pSubExp `sepBy` pComma) + <$> pShape + <* pComma + <*> braces (pSubExp `sepBy` pComma) <* pComma <*> braces (pSubExp `sepBy` pComma) <* pComma @@ -806,7 +808,9 @@ pSOAC pr = pJVP = parens $ SOAC.JVP - <$> braces (pSubExp `sepBy` pComma) + <$> pShape + <* pComma + <*> braces (pSubExp `sepBy` pComma) <* pComma <*> braces (pSubExp `sepBy` pComma) <* pComma diff --git a/src/Futhark/IR/SOACS/SOAC.hs b/src/Futhark/IR/SOACS/SOAC.hs index 265582e361..38a02191d9 100644 --- a/src/Futhark/IR/SOACS/SOAC.hs +++ b/src/Futhark/IR/SOACS/SOAC.hs @@ -125,9 +125,9 @@ data SOAC rep -- The final lambda produces indexes and values for the 'HistOp's. Hist SubExp [VName] [HistOp rep] (Lambda rep) | -- FIXME: this should not be here - JVP [SubExp] [SubExp] (Lambda rep) + JVP Shape [SubExp] [SubExp] (Lambda rep) | -- FIXME: this should not be here - VJP [SubExp] [SubExp] (Lambda rep) + VJP Shape [SubExp] [SubExp] (Lambda rep) | -- | A combination of scan, reduction, and map. The first -- t'SubExp' is the size of the input arrays. Screma SubExp [VName] (ScremaForm rep) @@ -399,14 +399,16 @@ mapSOACM :: SOACMapper frep trep m -> SOAC frep -> m (SOAC trep) -mapSOACM tv (JVP args vec lam) = +mapSOACM tv (JVP shape args vec lam) = JVP - <$> mapM (mapOnSOACSubExp tv) args + <$> mapM (mapOnSOACSubExp tv) shape + <*> mapM (mapOnSOACSubExp tv) args <*> mapM (mapOnSOACSubExp tv) vec <*> mapOnSOACLambda tv lam -mapSOACM tv (VJP args vec lam) = +mapSOACM tv (VJP shape args vec lam) = VJP - <$> mapM (mapOnSOACSubExp tv) args + <$> mapM (mapOnSOACSubExp tv) shape + <*> mapM (mapOnSOACSubExp tv) args <*> mapM (mapOnSOACSubExp tv) vec <*> mapOnSOACLambda tv lam mapSOACM tv (Stream size arrs accs lam) = @@ -514,10 +516,10 @@ instance (ASTRep rep) => Rename (SOAC rep) where -- | The type of a SOAC. soacType :: (Typed (LParamInfo rep)) => SOAC rep -> [Type] -soacType (JVP _ _ lam) = - lambdaReturnType lam ++ lambdaReturnType lam -soacType (VJP _ _ lam) = - lambdaReturnType lam ++ map paramType (lambdaParams lam) +soacType (JVP shape _ _ lam) = + lambdaReturnType lam ++ map (`arrayOfShape` shape) (lambdaReturnType lam) +soacType (VJP shape _ _ lam) = + lambdaReturnType lam ++ map ((`arrayOfShape` shape) . paramType) (lambdaParams lam) soacType (Stream outersize _ accs lam) = map (substNamesInType substs) rtp where @@ -570,10 +572,10 @@ mapHistOp f (HistOp w rf dests nes lam) = HistOp w rf dests nes $ f lam instance CanBeAliased SOAC where - addOpAliases aliases (JVP args vec lam) = - JVP args vec (Alias.analyseLambda aliases lam) - addOpAliases aliases (VJP args vec lam) = - VJP args vec (Alias.analyseLambda aliases lam) + addOpAliases aliases (JVP shape args vec lam) = + JVP shape args vec (Alias.analyseLambda aliases lam) + addOpAliases aliases (VJP shape args vec lam) = + VJP shape args vec (Alias.analyseLambda aliases lam) addOpAliases aliases (Stream size arr accs lam) = Stream size arr accs $ Alias.analyseLambda aliases lam addOpAliases aliases (Scatter len arrs dests lam) = @@ -629,12 +631,12 @@ instance IsOp SOAC where where flattenBlocks (_, arr, ivs) = oneName arr <> mconcat (map (mconcat . fst) ivs) <> mconcat (map snd ivs) - opDependencies (JVP args vec lam) = + opDependencies (JVP _ args vec lam) = mconcat $ replicate 2 $ lambdaDependencies mempty lam $ zipWith (<>) (map depsOf' args) (map depsOf' vec) - opDependencies (VJP args vec lam) = + opDependencies (VJP _ args vec lam) = lambdaDependencies mempty lam @@ -711,26 +713,32 @@ instance (RepTypes rep) => ST.IndexOp (SOAC rep) where -- | Type-check a SOAC. typeCheckSOAC :: (TC.Checkable rep) => SOAC (Aliases rep) -> TC.TypeM rep () -typeCheckSOAC (VJP args vec lam) = do +typeCheckSOAC (VJP shape args vec lam) = do + mapM_ (TC.require [Prim int64]) shape args' <- mapM TC.checkArg args TC.checkLambda lam $ map TC.noArgAliases args' vec_ts <- mapM TC.checkSubExp vec - unless (vec_ts == lambdaReturnType lam) $ + unless (vec_ts == map (`arrayOfShape` shape) (lambdaReturnType lam)) $ TC.bad . TC.TypeError . docText $ "Return type" PP.indent 2 (pretty (lambdaReturnType lam)) - "does not match type of seed vector" + "inconsistent with type of seed vector" PP.indent 2 (pretty vec_ts) -typeCheckSOAC (JVP args vec lam) = do + "with shape" + PP.indent 2 (pretty shape) +typeCheckSOAC (JVP shape args vec lam) = do + mapM_ (TC.require [Prim int64]) shape args' <- mapM TC.checkArg args TC.checkLambda lam $ map TC.noArgAliases args' vec_ts <- mapM TC.checkSubExp vec - unless (vec_ts == map TC.argType args') $ + unless (vec_ts == map ((`arrayOfShape` shape) . TC.argType) args') $ TC.bad . TC.TypeError . docText $ "Parameter type" PP.indent 2 (pretty $ map TC.argType args') "does not match type of seed vector" PP.indent 2 (pretty vec_ts) + "with shape" + PP.indent 2 (pretty shape) typeCheckSOAC (Stream size arrexps accexps lam) = do TC.require [Prim int64] size accargs <- mapM TC.checkArg accexps @@ -891,10 +899,10 @@ typeCheckSOAC (Screma w arrs (ScremaForm map_lam scans reds)) = do <> " wrong for given scan and reduction functions." instance RephraseOp SOAC where - rephraseInOp r (VJP args vec lam) = - VJP args vec <$> rephraseLambda r lam - rephraseInOp r (JVP args vec lam) = - JVP args vec <$> rephraseLambda r lam + rephraseInOp r (VJP shape args vec lam) = + VJP shape args vec <$> rephraseLambda r lam + rephraseInOp r (JVP shape args vec lam) = + JVP shape args vec <$> rephraseLambda r lam rephraseInOp r (Stream w arrs acc lam) = Stream w arrs acc <$> rephraseLambda r lam rephraseInOp r (Scatter w arrs dests lam) = @@ -916,9 +924,9 @@ instance RephraseOp SOAC where onRed (Reduce comm op nes) = Reduce comm <$> rephraseLambda r op <*> pure nes instance (OpMetrics (Op rep)) => OpMetrics (SOAC rep) where - opMetrics (VJP _ _ lam) = + opMetrics (VJP _ _ _ lam) = inside "VJP" $ lambdaMetrics lam - opMetrics (JVP _ _ lam) = + opMetrics (JVP _ _ _ lam) = inside "JVP" $ lambdaMetrics lam opMetrics (Stream _ _ _ lam) = inside "Stream" $ lambdaMetrics lam @@ -933,19 +941,21 @@ instance (OpMetrics (Op rep)) => OpMetrics (SOAC rep) where mapM_ (lambdaMetrics . redLambda) reds instance (PrettyRep rep) => PP.Pretty (SOAC rep) where - pretty (VJP args vec lam) = + pretty (VJP shape args vec lam) = "vjp" <> parens ( PP.align $ - PP.braces (commasep $ map pretty args) + pretty shape + <> comma PP.braces (commasep $ map pretty args) <> comma PP.braces (commasep $ map pretty vec) <> comma pretty lam ) - pretty (JVP args vec lam) = + pretty (JVP shape args vec lam) = "jvp" <> parens ( PP.align $ - PP.braces (commasep $ map pretty args) + pretty shape + <> comma PP.braces (commasep $ map pretty args) <> comma PP.braces (commasep $ map pretty vec) <> comma pretty lam ) diff --git a/src/Futhark/IR/SOACS/Simplify.hs b/src/Futhark/IR/SOACS/Simplify.hs index 38254646ba..04a3cb9870 100644 --- a/src/Futhark/IR/SOACS/Simplify.hs +++ b/src/Futhark/IR/SOACS/Simplify.hs @@ -82,16 +82,18 @@ simplifyConsts = simplifySOAC :: (Simplify.SimplifiableRep rep) => Simplify.SimplifyOp rep (SOAC (Wise rep)) -simplifySOAC (VJP arr vec lam) = do - (lam', hoisted) <- Engine.simplifyLambda mempty lam +simplifySOAC (VJP shape arr vec lam) = do + shape' <- traverse Engine.simplify shape arr' <- mapM Engine.simplify arr vec' <- mapM Engine.simplify vec - pure (VJP arr' vec' lam', hoisted) -simplifySOAC (JVP arr vec lam) = do (lam', hoisted) <- Engine.simplifyLambda mempty lam + pure (VJP shape' arr' vec' lam', hoisted) +simplifySOAC (JVP shape arr vec lam) = do + shape' <- traverse Engine.simplify shape arr' <- mapM Engine.simplify arr vec' <- mapM Engine.simplify vec - pure (JVP arr' vec' lam', hoisted) + (lam', hoisted) <- Engine.simplifyLambda mempty lam + pure (JVP shape' arr' vec' lam', hoisted) simplifySOAC (Stream outerdim arr nes lam) = do outerdim' <- Engine.simplify outerdim nes' <- mapM Engine.simplify nes diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 4c310aa8cf..726cb3bb57 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -1710,8 +1710,8 @@ isIntrinsicFunction qname args loc = do lam <- internaliseLambdaCoerce f =<< mapM subExpType x' fmap (map I.Var) . letTupExp desc . Op $ case fname of - "jvp2" -> JVP x' v' lam - _ -> VJP x' v' lam + "jvp2" -> JVP mempty x' v' lam + _ -> VJP mempty x' v' lam handleAD _ _ = Nothing handleRest [a, si, v] "scatter" = Just $ scatterF 1 a si v diff --git a/src/Futhark/Optimise/Fusion.hs b/src/Futhark/Optimise/Fusion.hs index 703ce2a047..fa5c83311b 100644 --- a/src/Futhark/Optimise/Fusion.hs +++ b/src/Futhark/Optimise/Fusion.hs @@ -541,12 +541,12 @@ runInnerFusionOnContext c@(incoming, node, nodeT, outgoing) = case nodeT of cases' <- mapM (traverse $ renameBody <=< (`doFusionWithDelayed` to_fuse)) cases defbody' <- doFusionWithDelayed defbody to_fuse pure (incoming, node, MatchNode (Let pat aux (Match cond cases' defbody' dec)) [], outgoing) - StmNode (Let pat aux (Op (Futhark.VJP args vec lam))) -> doFuseScans $ do + StmNode (Let pat aux (Op (Futhark.VJP shape args vec lam))) -> doFuseScans $ do lam' <- fst <$> doFusionInLambda lam - pure (incoming, node, StmNode (Let pat aux (Op (Futhark.VJP args vec lam'))), outgoing) - StmNode (Let pat aux (Op (Futhark.JVP args vec lam))) -> doFuseScans $ do + pure (incoming, node, StmNode (Let pat aux (Op (Futhark.VJP shape args vec lam'))), outgoing) + StmNode (Let pat aux (Op (Futhark.JVP shape args vec lam))) -> doFuseScans $ do lam' <- fst <$> doFusionInLambda lam - pure (incoming, node, StmNode (Let pat aux (Op (Futhark.JVP args vec lam'))), outgoing) + pure (incoming, node, StmNode (Let pat aux (Op (Futhark.JVP shape args vec lam'))), outgoing) StmNode (Let pat aux (WithAcc inputs lam)) -> doFuseScans $ do lam' <- fst <$> doFusionInLambda lam pure (incoming, node, StmNode (Let pat aux (WithAcc inputs lam')), outgoing) diff --git a/src/Futhark/Pass/AD.hs b/src/Futhark/Pass/AD.hs index dca9ad088c..2a7588ce21 100644 --- a/src/Futhark/Pass/AD.hs +++ b/src/Futhark/Pass/AD.hs @@ -36,20 +36,20 @@ bindLambda pat aux (Lambda params _ body) args = do certifying cs $ letBindNames [v] $ BasicOp $ SubExp se onStm :: Mode -> Scope SOACS -> Stm SOACS -> PassM (Stms SOACS) -onStm mode scope (Let pat aux (Op (VJP args vec lam))) = do +onStm mode scope (Let pat aux (Op (VJP shape args vec lam))) = do lam' <- onLambda mode scope lam if mode == All || lam == lam' then do lam'' <- (`runReaderT` scope) . simplifyLambda =<< revVJP scope lam' runBuilderT_ (bindLambda pat aux lam'' $ args ++ vec) scope - else pure $ oneStm $ Let pat aux $ Op $ VJP args vec lam' -onStm mode scope (Let pat aux (Op (JVP args vec lam))) = do + else pure $ oneStm $ Let pat aux $ Op $ VJP shape args vec lam' +onStm mode scope (Let pat aux (Op (JVP shape args vec lam))) = do lam' <- onLambda mode scope lam if mode == All || lam == lam' then do lam'' <- fwdJVP scope lam' runBuilderT_ (bindLambda pat aux lam'' $ args ++ vec) scope - else pure $ oneStm $ Let pat aux $ Op $ JVP args vec lam' + else pure $ oneStm $ Let pat aux $ Op $ JVP shape args vec lam' onStm mode scope (Let pat aux e) = oneStm . Let pat aux <$> mapExpM mapper e where mapper = diff --git a/src/Language/Futhark/Prop.hs b/src/Language/Futhark/Prop.hs index 070dfc733f..1dc280efd5 100644 --- a/src/Language/Futhark/Prop.hs +++ b/src/Language/Futhark/Prop.hs @@ -953,6 +953,34 @@ intrinsics = $ RetType [] $ Scalar $ tupleRecord [Scalar $ t_b Nonunique, Scalar $ t_a Nonunique] + ), + ( "jvp2_vec", + IntrinsicPolyFun + [tp_a, tp_b, sp_n] + [ Scalar (t_a mempty) `arr` Scalar (t_b Nonunique), + Scalar (t_a Observe), + array_a Observe $ shape [n] + ] + $ RetType [] + $ Scalar + $ tupleRecord + [ Scalar $ t_b Nonunique, + array_b Unique $ shape [n] + ] + ), + ( "vjp2_vec", + IntrinsicPolyFun + [tp_a, tp_b, sp_n] + [ Scalar (t_a mempty) `arr` Scalar (t_b Nonunique), + Scalar (t_a Observe), + array_b Observe $ shape [n] + ] + $ RetType [] + $ Scalar + $ tupleRecord + [ Scalar $ t_b Nonunique, + array_a Unique $ shape [n] + ] ) ] ++ From a54fa53d1029f859a8e884650370715883ecff3f Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 5 Dec 2024 21:06:07 +0100 Subject: [PATCH 02/70] Hook it up in internalisation, too. --- src/Futhark/Internalise/Exps.hs | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 726cb3bb57..dd734eb11b 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -1704,14 +1704,24 @@ isIntrinsicFunction qname args loc = do handleAccs _ _ = Nothing handleAD [f, x, v] fname - | fname `elem` ["jvp2", "vjp2"] = Just $ \desc -> do + | fname `elem` ["jvp2", "vjp2", "jvp2_vec", "vjp2_vec"] = Just $ \desc -> do x' <- internaliseExp "ad_x" x v' <- internaliseExp "ad_v" v + x_t <- subExpType $ head x' + v_t <- subExpType $ head v' lam <- internaliseLambdaCoerce f =<< mapM subExpType x' fmap (map I.Var) . letTupExp desc . Op $ case fname of "jvp2" -> JVP mempty x' v' lam - _ -> VJP mempty x' v' lam + "vjp2" -> VJP mempty x' v' lam + "jvp2_vec" -> + JVP (vecShape x_t v_t) x' v' lam + "vjp2_vec" -> + VJP (vecShape (head (lambdaReturnType lam)) v_t) x' v' lam + _ -> error "handleAD: not supposed to happen." + where + vecShape t1 t2 = + I.Shape $ take (I.arrayRank t2 - I.arrayRank t1) (I.arrayDims t2) handleAD _ _ = Nothing handleRest [a, si, v] "scatter" = Just $ scatterF 1 a si v From 7ecd907d6840f5d6cc80503d50c13a98c128d3ef Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 7 Dec 2024 20:50:30 +0100 Subject: [PATCH 03/70] Basic support for vectorised forward-mode AD. --- src/Futhark/AD/Fwd.hs | 172 +++++++++++----- src/Futhark/AD/Rev/Monad.hs | 45 ++++- src/Futhark/IR/Prop/Types.hs | 4 +- src/Futhark/IR/Syntax/Core.hs | 7 + src/Futhark/Pass/AD.hs | 2 +- src/Language/Futhark/Interpreter.hs | 298 ++++++++++++++-------------- 6 files changed, 332 insertions(+), 196 deletions(-) diff --git a/src/Futhark/AD/Fwd.hs b/src/Futhark/AD/Fwd.hs index e194ec20d6..cfab15e0ce 100644 --- a/src/Futhark/AD/Fwd.hs +++ b/src/Futhark/AD/Fwd.hs @@ -3,7 +3,7 @@ module Futhark.AD.Fwd (fwdJVP) where import Control.Monad -import Control.Monad.RWS.Strict +import Control.Monad.Reader import Control.Monad.State.Strict import Data.Bifunctor (second) import Data.List (transpose) @@ -26,11 +26,17 @@ zeroExp (Array pt shape _) = BasicOp $ Replicate shape $ Constant $ blankPrimValue pt zeroExp t = error $ "zeroExp: " ++ show t -tanType :: TypeBase s u -> ADM (TypeBase s u) +tanType :: (ArrayShape s, Monoid u) => TypeBase s u -> ADM (TypeBase s u) tanType (Acc acc ispace ts u) = do ts_tan <- mapM tanType ts pure $ Acc acc ispace (ts ++ ts_tan) u -tanType t = pure t +tanType t = do + shape <- askShape + pure $ + arrayOf + (Prim (elemType t)) + (shape `prependShape` arrayShape t) + (uniqueness t) slocal' :: ADM a -> ADM a slocal' = slocal id @@ -48,7 +54,7 @@ data RState = RState stateNameSource :: VNameSource } -newtype ADM a = ADM (BuilderT SOACS (State RState) a) +newtype ADM a = ADM (BuilderT SOACS (ReaderT Shape (State RState)) a) deriving ( Functor, Applicative, @@ -72,12 +78,15 @@ instance MonadFreshNames (State RState) where getNameSource = gets stateNameSource putNameSource src = modify (\env -> env {stateNameSource = src}) -runADM :: (MonadFreshNames m) => ADM a -> m a -runADM (ADM m) = +askShape :: ADM Shape +askShape = ADM $ lift ask + +runADM :: (MonadFreshNames m) => Shape -> ADM a -> m a +runADM shape (ADM m) = modifyNameSource $ \vn -> second stateNameSource $ runState - (fst <$> runBuilderT m mempty) + (runReaderT (fst <$> runBuilderT m mempty) shape) (RState mempty vn) tanVName :: VName -> ADM VName @@ -94,7 +103,7 @@ class TanBuilder a where bundleNewList :: (TanBuilder a) => [a] -> ADM [a] bundleNewList = fmap mconcat . mapM bundleNew -instance TanBuilder (PatElem (TypeBase s u)) where +instance (ArrayShape s, Monoid u) => TanBuilder (PatElem (TypeBase s u)) where newTan (PatElem p t) | isAcc t = do insertTan p p @@ -117,7 +126,7 @@ newTanPat (Pat pes) = Pat <$> mapM newTan pes bundleNewPat :: (TanBuilder (PatElem t)) => Pat t -> ADM (Pat t) bundleNewPat (Pat pes) = Pat <$> bundleNewList pes -instance TanBuilder (Param (TypeBase s u)) where +instance (ArrayShape s, Monoid u) => TanBuilder (Param (TypeBase s u)) where newTan (Param _ p t) = do PatElem p' t' <- newTan $ PatElem p t pure $ Param mempty p' t' @@ -129,7 +138,10 @@ instance TanBuilder (Param (TypeBase s u)) where then pure [param'] else pure [param, param'] -instance (Tangent a) => TanBuilder (Param (TypeBase s u), a) where +instance + (ArrayShape s, Monoid u, Tangent a) => + TanBuilder (Param (TypeBase s u), a) + where newTan (p, x) = (,) <$> newTan p <*> tangent x bundleNew (p, x) = do b <- bundleNew p @@ -140,7 +152,7 @@ class Tangent a where tangent :: a -> ADM a bundleTan :: a -> ADM [a] -instance Tangent (TypeBase s u) where +instance (ArrayShape s, Monoid u) => Tangent (TypeBase s u) where tangent = tanType bundleTan t | isAcc t = do @@ -181,6 +193,51 @@ instance Tangent SubExpRes where tangent (SubExpRes cs se) = SubExpRes cs <$> tangent se bundleTan (SubExpRes cs se) = map (SubExpRes cs) <$> bundleTan se +asVName :: SubExp -> ADM VName +asVName (Var v) = pure v +asVName (Constant x) = letExp "v" $ BasicOp $ SubExp $ Constant x + +withTan :: + SubExp -> + (SubExp -> ADM (Exp SOACS)) -> + ADM (Exp SOACS) +withTan x f = do + shape <- askShape + x_tan <- tangent x + if shape == mempty + then f x_tan + else do + let w = shapeSize 0 shape + x_tan_v <- asVName x_tan + x_tan_p <- newParam "x_tanp" . rowType =<< lookupType x_tan_v + lam <- mkLambda [x_tan_p] $ do + fmap (subExpsRes . pure) . letSubExp "tan" + =<< f (Var (paramName x_tan_p)) + pure $ Op $ Screma w [x_tan_v] (mapSOAC lam) + +withTans :: + PrimType -> + SubExp -> + SubExp -> + (PrimExp VName -> PrimExp VName -> PrimExp VName) -> + ADM (Exp SOACS) +withTans t x y f = do + shape <- askShape + x_tan <- asVName =<< tangent x + y_tan <- asVName =<< tangent y + if shape == mempty + then toExp $ f (LeafExp x_tan t) (LeafExp y_tan t) + else do + let w = shapeSize 0 shape + x_tan_p <- newParam "x_tanp" . rowType =<< lookupType x_tan + y_tan_p <- newParam "y_tanp" . rowType =<< lookupType y_tan + lam <- mkLambda [x_tan_p, y_tan_p] $ do + fmap (subExpsRes . pure) . letSubExp "tan" <=< toExp $ + f + (LeafExp (paramName x_tan_p) t) + (LeafExp (paramName y_tan_p) t) + pure $ Op $ Screma w [x_tan, y_tan] (mapSOAC lam) + basicFwd :: Pat Type -> StmAux () -> BasicOp -> ADM () basicFwd pat aux op = do pat_tan <- newTanPat pat @@ -198,17 +255,15 @@ basicFwd pat aux op = do let t = unOpType unop x_pe = primExpFromSubExp t x dx = pdUnOp unop x_pe - x_tan <- primExpFromSubExp t <$> tangent x - auxing aux $ letBindNames (patNames pat_tan) <=< toExp $ x_tan ~*~ dx + auxing aux $ letBindNames (patNames pat_tan) <=< withTan x $ \x_tan -> + toExp $ primExpFromSubExp t x_tan ~*~ dx BinOp bop x y -> do let t = binOpType bop - x_tan <- primExpFromSubExp t <$> tangent x - y_tan <- primExpFromSubExp t <$> tangent y - let (wrt_x, wrt_y) = - pdBinOp bop (primExpFromSubExp t x) (primExpFromSubExp t y) - auxing aux $ - letBindNames (patNames pat_tan) <=< toExp $ - x_tan ~*~ wrt_x ~+~ y_tan ~*~ wrt_y + auxing aux . letBindNames (patNames pat_tan) <=< withTans t x y $ + \x_tan y_tan -> + let (wrt_x, wrt_y) = + pdBinOp bop (primExpFromSubExp t x) (primExpFromSubExp t y) + in x_tan ~*~ wrt_x ~+~ y_tan ~*~ wrt_y CmpOp {} -> addStm $ Let pat_tan aux $ BasicOp op ConvOp cop x -> do @@ -217,7 +272,9 @@ basicFwd pat aux op = do Assert {} -> pure () Index arr slice -> do arr_tan <- tangent arr - addStm $ Let pat_tan aux $ BasicOp $ Index arr_tan slice + dims <- shapeDims <$> askShape + let slice' = Slice $ map sliceDim dims <> unSlice slice + addStm $ Let pat_tan aux $ BasicOp $ Index arr_tan slice' Update safety arr slice se -> do arr_tan <- tangent arr se_tan <- tangent se @@ -225,32 +282,47 @@ basicFwd pat aux op = do Concat d (arr :| arrs) w -> do arr_tan <- tangent arr arrs_tans <- mapM tangent arrs - addStm $ Let pat_tan aux $ BasicOp $ Concat d (arr_tan :| arrs_tans) w + r <- shapeRank <$> askShape + addStm $ Let pat_tan aux $ BasicOp $ Concat (d + r) (arr_tan :| arrs_tans) w Manifest ds arr -> do arr_tan <- tangent arr - addStm $ Let pat_tan aux $ BasicOp $ Manifest ds arr_tan + r <- shapeRank <$> askShape + addStm . Let pat_tan aux . BasicOp $ + Manifest ([0 .. r - 1] ++ map (+ r) ds) arr_tan Iota n _ _ it -> do - addStm $ Let pat_tan aux $ BasicOp $ Replicate (Shape [n]) (intConst it 0) - Replicate n x -> do - x_tan <- tangent x - addStm $ Let pat_tan aux $ BasicOp $ Replicate n x_tan - Scratch t shape -> - addStm $ Let pat_tan aux $ BasicOp $ Scratch t shape + shape <- askShape + addStm . Let pat_tan aux . BasicOp $ + Replicate (shape <> Shape [n]) (intConst it 0) + Replicate n x -> + auxing aux $ letBind pat_tan <=< withTan x $ \x_tan -> + pure $ BasicOp $ Replicate n x_tan + Scratch t shape -> do + tan_shape <- askShape + addStm $ Let pat_tan aux $ BasicOp $ Scratch t $ shapeDims tan_shape <> shape Reshape k reshape arr -> do arr_tan <- tangent arr - addStm $ Let pat_tan aux $ BasicOp $ Reshape k reshape arr_tan + shape <- askShape + addStm $ Let pat_tan aux $ BasicOp $ Reshape k (shape <> reshape) arr_tan Rearrange perm arr -> do arr_tan <- tangent arr - addStm $ Let pat_tan aux $ BasicOp $ Rearrange perm arr_tan + r <- shapeRank <$> askShape + addStm . Let pat_tan aux . BasicOp $ + Rearrange ([0 .. r - 1] <> map (+ r) perm) arr_tan _ -> error $ "basicFwd: Unsupported op " ++ prettyString op fwdLambda :: Lambda SOACS -> ADM (Lambda SOACS) -fwdLambda l@(Lambda params ret body) = - Lambda <$> bundleNewList params <*> bundleTangents ret <*> inScopeOf l (fwdBody body) +fwdLambda (Lambda params ret body) = do + params' <- bundleNewList params + Lambda params' + <$> bundleTangents ret + <*> localScope (scopeOfLParams params') (fwdBody body) fwdStreamLambda :: Lambda SOACS -> ADM (Lambda SOACS) -fwdStreamLambda l@(Lambda params ret body) = - Lambda <$> ((take 1 params ++) <$> bundleNewList (drop 1 params)) <*> bundleTangents ret <*> inScopeOf l (fwdBody body) +fwdStreamLambda (Lambda params ret body) = do + params' <- (take 1 params ++) <$> bundleNewList (drop 1 params) + Lambda params' + <$> bundleTangents ret + <*> localScope (scopeOfLParams params') (fwdBody body) interleave :: [a] -> [a] -> [a] interleave xs ys = concat $ transpose [xs, ys] @@ -272,24 +344,27 @@ fwdSOAC pat aux (Screma size xs (ScremaForm f scs reds)) = do reds' <- mapM fwdRed reds addStm $ Let pat' aux $ Op $ Screma size xs' $ ScremaForm f' scs' reds' where + zeroTans lam = + mapM (letSubExp "zero" . zeroExp <=< tanType) $ lambdaReturnType lam + fwdScan :: Scan SOACS -> ADM (Scan SOACS) fwdScan sc = do op' <- fwdLambda $ scanLambda sc - neutral_tans <- mapM zeroFromSubExp $ scanNeutral sc + neutral_tans <- zeroTans $ scanLambda sc pure $ Scan - { scanNeutral = scanNeutral sc `interleave` map Var neutral_tans, + { scanNeutral = scanNeutral sc `interleave` neutral_tans, scanLambda = op' } fwdRed :: Reduce SOACS -> ADM (Reduce SOACS) fwdRed red = do op' <- fwdLambda $ redLambda red - neutral_tans <- mapM zeroFromSubExp $ redNeutral red + neutral_tans <- zeroTans $ redLambda red pure $ Reduce { redComm = redComm red, redLambda = op', - redNeutral = redNeutral red `interleave` map Var neutral_tans + redNeutral = redNeutral red `interleave` neutral_tans } fwdSOAC pat aux (Stream size xs nes lam) = do pat' <- bundleNewPat pat @@ -369,8 +444,7 @@ fwdStm (Let pat aux (BasicOp (UpdateAcc safety acc i x))) = do addStm $ Let pat' aux $ BasicOp $ UpdateAcc safety acc_tan i x' fwdStm stm@(Let pat aux (BasicOp e)) = do -- XXX: this has to be too naive. - unless (any isAcc $ patTypes pat) $ - addStm stm + unless (any isAcc $ patTypes pat) $ addStm stm basicFwd pat aux e fwdStm stm@(Let pat _ (Apply f args _ _)) | Just (ret, argts) <- M.lookup f builtInFunctions = do @@ -450,10 +524,14 @@ fwdBodyTansLast (Body _ stms res) = buildBody_ $ do mapM_ fwdStm stms (res <>) <$> mapM tangent res -fwdJVP :: (MonadFreshNames m) => Scope SOACS -> Lambda SOACS -> m (Lambda SOACS) -fwdJVP scope l@(Lambda params ret body) = - runADM . localScope scope . inScopeOf l $ do +fwdJVP :: + (MonadFreshNames m) => + Scope SOACS -> + Shape -> + Lambda SOACS -> + m (Lambda SOACS) +fwdJVP scope shape (Lambda params _ body) = + runADM shape . localScope scope $ do params_tan <- mapM newTan params - body_tan <- fwdBodyTansLast body - ret_tan <- mapM tangent ret - pure $ Lambda (params ++ params_tan) (ret <> ret_tan) body_tan + mkLambda (params <> params_tan) $ + bodyBind =<< fwdBodyTansLast body diff --git a/src/Futhark/AD/Rev/Monad.hs b/src/Futhark/AD/Rev/Monad.hs index 5c2fd4ccb2..e74b82527a 100644 --- a/src/Futhark/AD/Rev/Monad.hs +++ b/src/Futhark/AD/Rev/Monad.hs @@ -42,6 +42,7 @@ module Futhark.AD.Rev.Monad zeroExp, unitAdjOfType, addLambda, + vecOpExp, -- VjpOps (..), -- @@ -311,7 +312,12 @@ addBinOp (FloatType ft) = FAdd ft addBinOp Bool = LogAnd addBinOp Unit = LogAnd -tabNest :: Int -> [VName] -> ([VName] -> [VName] -> ADM [VName]) -> ADM [VName] +tabNest :: + (MonadBuilder m, Rep m ~ SOACS) => + Int -> + [VName] -> + ([VName] -> [VName] -> m [VName]) -> + m [VName] tabNest = tabNest' [] where tabNest' is 0 vs f = f (reverse is) vs @@ -352,7 +358,29 @@ addLambda t@Array {} = do addLambda t = error $ "addLambda: " ++ show t --- Construct an expression for adding the two variables. +-- | Construct a lambda for binop'ing two values of the given type, +-- which may be arrays. +vecOpLambda :: (PrimType -> BinOp) -> Type -> ADM (Lambda SOACS) +vecOpLambda bop (Prim pt) = binOpLambda (bop pt) pt +vecOpLambda bop t@Array {} = do + xs_p <- newParam "xs" t + ys_p <- newParam "ys" t + lam <- vecOpLambda bop $ rowType t + body <- insertStmsM $ do + res <- + letSubExp "lam_map" . Op $ + Screma (arraySize 0 t) [paramName xs_p, paramName ys_p] (mapSOAC lam) + pure $ resultBody [res] + pure + Lambda + { lambdaParams = [xs_p, ys_p], + lambdaReturnType = [t], + lambdaBody = body + } +vecOpLambda _ t = + error $ "vecOpLambda: " ++ show t + +-- | Construct an expression for adding the two variables. addExp :: VName -> VName -> ADM (Exp SOACS) addExp x y = do x_t <- lookupType x @@ -365,6 +393,19 @@ addExp x y = do _ -> error $ "addExp: unexpected type: " ++ prettyString x_t +-- | Construct an expression for performing this binary operation on two variables. +vecOpExp :: (PrimType -> BinOp) -> VName -> VName -> ADM (Exp SOACS) +vecOpExp bop x y = do + x_t <- lookupType x + case x_t of + Prim pt -> + pure $ BasicOp $ BinOp (bop pt) (Var x) (Var y) + Array {} -> do + lam <- vecOpLambda bop $ rowType x_t + pure $ Op $ Screma (arraySize 0 x_t) [x, y] (mapSOAC lam) + _ -> + error $ "vecOpExp: unexpected type: " ++ prettyString x_t + lookupAdj :: VName -> ADM Adj lookupAdj v = do maybeAdj <- gets $ M.lookup v . stateAdjs diff --git a/src/Futhark/IR/Prop/Types.hs b/src/Futhark/IR/Prop/Types.hs index 510cc6e0fc..f95ae6733f 100644 --- a/src/Futhark/IR/Prop/Types.hs +++ b/src/Futhark/IR/Prop/Types.hs @@ -133,10 +133,10 @@ existential = any ext . shapeDims . arrayShape ext (Free _) = False -- | Return the uniqueness of a type. -uniqueness :: TypeBase shape Uniqueness -> Uniqueness +uniqueness :: (Monoid u) => TypeBase shape u -> u uniqueness (Array _ _ u) = u uniqueness (Acc _ _ _ u) = u -uniqueness _ = Nonunique +uniqueness _ = mempty -- | @unique t@ is 'True' if the type of the argument is unique. unique :: TypeBase shape Uniqueness -> Bool diff --git a/src/Futhark/IR/Syntax/Core.hs b/src/Futhark/IR/Syntax/Core.hs index fb9603de50..7b0b1dec78 100644 --- a/src/Futhark/IR/Syntax/Core.hs +++ b/src/Futhark/IR/Syntax/Core.hs @@ -168,9 +168,13 @@ class (Monoid a, Eq a, Ord a) => ArrayShape a where -- | Check whether one shape if a subset of another shape. subShapeOf :: a -> a -> Bool + -- | Prepend the dimensions of a 'Shape'. + prependShape :: Shape -> a -> a + instance ArrayShape (ShapeBase SubExp) where shapeRank (Shape l) = length l subShapeOf = (==) + prependShape = (<>) instance ArrayShape (ShapeBase ExtSize) where shapeRank (Shape l) = length l @@ -193,6 +197,8 @@ instance ArrayShape (ShapeBase ExtSize) where put $ M.insert y x extmap pure True + prependShape shape = (fmap Free shape <>) + instance Semigroup Rank where Rank x <> Rank y = Rank $ x + y @@ -202,6 +208,7 @@ instance Monoid Rank where instance ArrayShape Rank where shapeRank (Rank x) = x subShapeOf = (==) + prependShape shape (Rank x) = Rank $ shapeRank shape + x -- | The memory space of a block. If 'DefaultSpace', this is the "default" -- space, whatever that is. The exact meaning of the 'SpaceId' diff --git a/src/Futhark/Pass/AD.hs b/src/Futhark/Pass/AD.hs index 2a7588ce21..5e53db75ec 100644 --- a/src/Futhark/Pass/AD.hs +++ b/src/Futhark/Pass/AD.hs @@ -47,7 +47,7 @@ onStm mode scope (Let pat aux (Op (JVP shape args vec lam))) = do lam' <- onLambda mode scope lam if mode == All || lam == lam' then do - lam'' <- fwdJVP scope lam' + lam'' <- fwdJVP scope shape lam' runBuilderT_ (bindLambda pat aux lam'' $ args ++ vec) scope else pure $ oneStm $ Let pat aux $ Op $ JVP shape args vec lam' onStm mode scope (Let pat aux e) = oneStm . Let pat aux <$> mapExpM mapper e diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index 009d9ac4e4..751cbb5b60 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -422,6 +422,11 @@ fromArray :: Value -> (ValueShape, [Value]) fromArray (ValueArray shape as) = (shape, elems as) fromArray v = error $ "Expected array value, but found: " <> show v +project :: Name -> Value -> Value +project f (ValueRecord fs) + | Just v' <- M.lookup f fs = v' +project _ _ = error "Value does not have expected field." + apply :: SrcLoc -> Env -> Value -> Value -> EvalM Value apply loc env (ValueFun f) v = stacking loc env (f v) apply _ _ f _ = error $ "Cannot apply non-function: " <> show f @@ -1063,10 +1068,7 @@ eval _ (ProjectSection ks _ _) = | Just v' <- M.lookup f fs = pure v' walk _ _ = error "Value does not have expected field." eval env (Project f e _ _) = do - v <- eval env e - case v of - ValueRecord fs | Just v' <- M.lookup f fs -> pure v' - _ -> error "Value does not have expected field." + project f <$> eval env e eval env (Assert what e (Info s) loc) = do cond <- asBool <$> eval env what unless cond $ bad loc env s @@ -1249,7 +1251,140 @@ breakOnNaN inputs result breakOnNaN _ _ = pure () --- | The initial environment contains definitions of the various intrinsic functions. +getV :: PrimValue -> Maybe P.PrimValue +getV (SignedValue x) = Just $ P.IntValue x +getV (UnsignedValue x) = Just $ P.IntValue x +getV (FloatValue x) = Just $ P.FloatValue x +getV (BoolValue x) = Just $ P.BoolValue x + +putV :: P.PrimValue -> PrimValue +putV (P.IntValue x) = SignedValue x +putV (P.FloatValue x) = FloatValue x +putV (P.BoolValue x) = BoolValue x +putV P.UnitValue = BoolValue True + +getAD :: Value -> Maybe AD.ADValue +getAD (ValuePrim v) = AD.Constant <$> getV v +getAD (ValueAD d v) = Just $ AD.Variable d v +getAD _ = Nothing + +putAD :: AD.ADValue -> Value +putAD (AD.Variable d s) = ValueAD d s +putAD (AD.Constant v) = ValuePrim $ putV v + +modifyValue :: (Num t) => (t -> Value -> Value) -> Value -> Value +modifyValue f v = snd $ valueAccum (\a b -> (a + 1, f a b)) 0 v + +modifyValueM :: + (Num t, Monad m) => + (t -> Value -> m Value) -> + Value -> + m Value +modifyValueM f v = snd <$> valueAccumLM g 0 v + where + g a b = do + b' <- f a b + pure (a + 1, b') + +-- TODO: This could be much better. Currently, it is very inefficient +-- Perhaps creating JVPValues could be abstracted into a function +-- exposed by the AD module? +doJVP2 :: Value -> Value -> Value -> EvalM Value +doJVP2 f v s = do + -- Get the depth + depth <- length <$> stacktrace + + -- Turn the seeds into a list of ADValues + let s' = + fromMaybe (error $ "jvp: invalid seeds " ++ show s) $ + mapM getAD $ + fst $ + valueAccum (\a b -> (b : a, b)) [] s + -- Augment the values + let v' = + fromMaybe (error $ "jvp: invalid values " ++ show v) $ + modifyValueM + ( \i lv -> do + lv' <- getAD lv + pure $ ValueAD depth . AD.JVP . AD.JVPValue lv' $ s' !! (length s' - 1 - i) + ) + v + + -- Run the function, and turn its outputs into a list of Values + o <- apply noLoc mempty f v' + let o' = fst $ valueAccum (\a b -> (b : a, b)) [] o + + -- For each output.. + let m = + fromMaybe (error "jvp: differentiation failed") $ + forM o' $ \on -> case on of + -- If it is a JVP variable of the correct depth, return its primal and derivative + (ValueAD d (AD.JVP (AD.JVPValue pv dv))) | d == depth -> Just (putAD pv, putAD dv) + -- Otherwise, its partial derivatives are all 0 + _ -> (on,) . ValuePrim . putV . P.blankPrimValue . P.primValueType . AD.primitive <$> getAD on + + -- Extract the output values, and the partial derivatives + let ov = modifyValue (\i _ -> fst $ m !! (length m - 1 - i)) o + od = modifyValue (\i _ -> snd $ m !! (length m - 1 - i)) o + + -- Return a tuple of the output values, and partial derivatives + pure $ toTuple [ov, od] + +-- TODO: This could be much better. Currently, it is very inefficient +-- Perhaps creating VJPValues could be abstracted into a function +-- exposed by the AD module? +doVJP2 :: Value -> Value -> Value -> EvalM Value +doVJP2 f v s = do + -- Get the depth + depth <- length <$> stacktrace + + -- Augment the values + let v' = + fromMaybe (error $ "vjp: invalid values " ++ show v) $ + modifyValueM (\i lv -> ValueAD depth . AD.VJP . AD.VJPValue . AD.TapeID i <$> getAD lv) v + -- Turn the seeds into a list of ADValues + let s' = + fromMaybe (error $ "vjp: invalid seeds " ++ show s) $ + mapM getAD $ + fst $ + valueAccum (\a b -> (b : a, b)) [] s + + -- Run the function, and turn its outputs into a list of Values + o <- apply noLoc mempty f v' + let o' = fst $ valueAccum (\a b -> (b : a, b)) [] o + + -- For each output.. + let m = + fromMaybe (error "vjp: differentiation failed") $ + forM (zip o' s') $ \(on, sn) -> case on of + -- If it is a VJP variable of the correct depth, run + -- deriveTape on it- and its corresponding seed + (ValueAD d (AD.VJP (AD.VJPValue t))) + | d == depth -> + (putAD $ AD.tapePrimal t,) <$> AD.deriveTape t sn + -- Otherwise, its partial derivatives are all 0 + _ -> Just (on, M.empty) + + -- Add together every derivative + let drvs = M.map (Just . putAD) $ M.unionsWith add $ map snd m + + -- Extract the output values, and the partial derivatives + let ov = modifyValue (\i _ -> fst $ m !! i) o + let od = + fromMaybe (error "vjp: differentiation failed") $ + modifyValueM (\i vo -> M.findWithDefault (ValuePrim . putV . P.blankPrimValue . P.primValueType . AD.primitive <$> getAD vo) i drvs) v + + -- Return a tuple of the output values, and partial derivatives + pure $ toTuple [ov, od] + where + -- TODO: Perhaps this could be fully abstracted by AD? + -- Making addFor private would be nice.. + add x y = + fromMaybe (error "jvp: illtyped add") $ + AD.doOp (AD.OpBin $ AD.addFor $ P.primValueType $ AD.primitive x) [x, y] + +-- | The initial environment contains definitions of the various +-- intrinsic functions. initialCtx :: Ctx initialCtx = Ctx @@ -1306,15 +1441,6 @@ initialCtx = ] boolCmp f = [(getB, Just . BoolValue, P.doCmpOp f, adBinOp $ AD.OpCmp f)] - getV (SignedValue x) = Just $ P.IntValue x - getV (UnsignedValue x) = Just $ P.IntValue x - getV (FloatValue x) = Just $ P.FloatValue x - getV (BoolValue x) = Just $ P.BoolValue x - putV (P.IntValue x) = SignedValue x - putV (P.FloatValue x) = FloatValue x - putV (P.BoolValue x) = BoolValue x - putV P.UnitValue = BoolValue True - getS (SignedValue x) = Just $ P.IntValue x getS _ = Nothing putS (P.IntValue x) = Just $ SignedValue x @@ -1335,12 +1461,6 @@ initialCtx = putB (P.BoolValue x) = Just $ BoolValue x putB _ = Nothing - getAD (ValuePrim v) = AD.Constant <$> getV v - getAD (ValueAD d v) = Just $ AD.Variable d v - getAD _ = Nothing - putAD (AD.Variable d s) = ValueAD d s - putAD (AD.Constant v) = ValuePrim $ putV v - adToPrim v = putV $ AD.primitive v adBinOp op x y = AD.doOp op [x, y] @@ -1970,130 +2090,20 @@ initialCtx = <> "]" else pure $ toArray shape $ map (toArray rowshape) $ chunk (asInt m) xs' def "manifest" = Just $ fun1 pure - def "vjp2" = Just $ - -- TODO: This could be much better. Currently, it is very inefficient - -- Perhaps creating VJPValues could be abstracted into a function - -- exposed by the AD module? - fun3 $ \f v s -> do - -- Get the depth - depth <- length <$> stacktrace - - -- Augment the values - let v' = - fromMaybe (error $ "vjp: invalid values " ++ show v) $ - modifyValueM (\i lv -> ValueAD depth . AD.VJP . AD.VJPValue . AD.TapeID i <$> getAD lv) v - -- Turn the seeds into a list of ADValues - let s' = - fromMaybe (error $ "vjp: invalid seeds " ++ show s) $ - mapM getAD $ - fst $ - valueAccum (\a b -> (b : a, b)) [] s - - -- Run the function, and turn its outputs into a list of Values - o <- apply noLoc mempty f v' - let o' = fst $ valueAccum (\a b -> (b : a, b)) [] o - - -- For each output.. - let m = - fromMaybe (error "vjp: differentiation failed") $ - zipWithM - ( \on sn -> case on of - -- If it is a VJP variable of the correct depth, run deriveTape on it- and its corresponding seed - (ValueAD d (AD.VJP (AD.VJPValue t))) | d == depth -> (putAD $ AD.tapePrimal t,) <$> AD.deriveTape t sn - -- Otherwise, its partial derivatives are all 0 - _ -> Just (on, M.empty) - ) - o' - s' - - -- Add together every derivative - let drvs = M.map (Just . putAD) $ M.unionsWith add $ map snd m - - -- Extract the output values, and the partial derivatives - let ov = modifyValue (\i _ -> fst $ m !! i) o - let od = - fromMaybe (error "vjp: differentiation failed") $ - modifyValueM (\i vo -> M.findWithDefault (ValuePrim . putV . P.blankPrimValue . P.primValueType . AD.primitive <$> getAD vo) i drvs) v - - -- Return a tuple of the output values, and partial derivatives - pure $ toTuple [ov, od] - where - modifyValue f v = snd $ valueAccum (\a b -> (a + 1, f a b)) 0 v - modifyValueM f v = - snd - <$> valueAccumLM - ( \a b -> do - b' <- f a b - pure (a + 1, b') - ) - 0 - v - - -- TODO: Perhaps this could be fully abstracted by AD? - -- Making addFor private would be nice.. - add x y = - fromMaybe (error "jvp: illtyped add") $ - AD.doOp (AD.OpBin $ AD.addFor $ P.primValueType $ AD.primitive x) [x, y] - def "jvp2" = Just $ - -- TODO: This could be much better. Currently, it is very inefficient - -- Perhaps creating JVPValues could be abstracted into a function - -- exposed by the AD module? - fun3 $ \f v s -> do - -- Get the depth - depth <- length <$> stacktrace - - -- Turn the seeds into a list of ADValues - let s' = - expectJust ("jvp: invalid seeds " ++ show s) $ - mapM getAD $ - fst $ - valueAccum (\a b -> (b : a, b)) [] s - -- Augment the values - let v' = - expectJust ("jvp: invalid values " ++ show v) $ - modifyValueM - ( \i lv -> do - lv' <- getAD lv - pure $ ValueAD depth . AD.JVP . AD.JVPValue lv' $ s' !! (length s' - 1 - i) - ) - v - - -- Run the function, and turn its outputs into a list of Values - o <- apply noLoc mempty f v' - let o' = fst $ valueAccum (\a b -> (b : a, b)) [] o - - -- For each output.. - let m = - expectJust "jvp: differentiation failed" $ - mapM - ( \on -> case on of - -- If it is a JVP variable of the correct depth, return its primal and derivative - (ValueAD d (AD.JVP (AD.JVPValue pv dv))) | d == depth -> Just (putAD pv, putAD dv) - -- Otherwise, its partial derivatives are all 0 - _ -> (on,) . ValuePrim . putV . P.blankPrimValue . P.primValueType . AD.primitive <$> getAD on - ) - o' - - -- Extract the output values, and the partial derivatives - let ov = modifyValue (\i _ -> fst $ m !! (length m - 1 - i)) o - od = modifyValue (\i _ -> snd $ m !! (length m - 1 - i)) o - - -- Return a tuple of the output values, and partial derivatives - pure $ toTuple [ov, od] - where - modifyValue f v = snd $ valueAccum (\a b -> (a + 1, f a b)) 0 v - modifyValueM f v = - snd - <$> valueAccumLM - ( \a b -> do - b' <- f a b - pure (a + 1, b') - ) - 0 - v - - expectJust _ (Just v) = v - expectJust s Nothing = error s + def "jvp2" = Just $ fun3 doJVP2 + def "vjp2" = Just $ fun3 doVJP2 + def "jvp2_vec" = Just $ fun3 $ \f x seeds -> do + v <- apply noLoc mempty f x + dvs <- + toArray' (valueShape v) . map (project "1") + <$> mapM (doJVP2 f x) (snd (fromArray seeds)) + pure $ toTuple [v, dvs] + def "vjp2_vec" = Just $ fun3 $ \f x seeds -> do + v <- apply noLoc mempty f x + dvs <- + toArray' (valueShape x) + <$> mapM (doVJP2 f x) (snd (fromArray seeds)) + pure $ toTuple [v, dvs] def "acc" = Nothing def s | nameFromString s `M.member` namesToPrimTypes = Nothing def s = error $ "Missing intrinsic: " ++ s From a8e616e81da0f44bc98ce73892bbc2449f0315e6 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 22 Dec 2024 21:49:09 +0100 Subject: [PATCH 04/70] Forgot to add tests. --- tests/ad/vec/README.md | 4 ++++ tests/ad/vec/concat.fut | 15 +++++++++++++++ tests/ad/vec/index.fut | 15 +++++++++++++++ tests/ad/vec/reduce.fut | 15 +++++++++++++++ tests/ad/vec/replicate.fut | 15 +++++++++++++++ tests/ad/vec/reshape.fut | 15 +++++++++++++++ tests/ad/vec/transpose.fut | 15 +++++++++++++++ 7 files changed, 94 insertions(+) create mode 100644 tests/ad/vec/README.md create mode 100644 tests/ad/vec/concat.fut create mode 100644 tests/ad/vec/index.fut create mode 100644 tests/ad/vec/reduce.fut create mode 100644 tests/ad/vec/replicate.fut create mode 100644 tests/ad/vec/reshape.fut create mode 100644 tests/ad/vec/transpose.fut diff --git a/tests/ad/vec/README.md b/tests/ad/vec/README.md new file mode 100644 index 0000000000..89ef5bee8c --- /dev/null +++ b/tests/ad/vec/README.md @@ -0,0 +1,4 @@ +# Microbenchmarks for vectorised AD + +This directory contains tests for individual (core language) language +primitives. diff --git a/tests/ad/vec/concat.fut b/tests/ad/vec/concat.fut new file mode 100644 index 0000000000..69d94787df --- /dev/null +++ b/tests/ad/vec/concat.fut @@ -0,0 +1,15 @@ +-- == +-- entry: fwd_vec fwd_map +-- input { [1.0, 2.0, 3.0] } +-- output { [[1.0, 0.0, 0.0, 1.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 1.0, 0.0, 0.0, 1.0]] } + +def f (xs: []f64) = xs ++ xs + +entry fwd_vec (xs: []f64) = + let seeds = + map (\i -> map (\j -> f64.bool (i == j)) (indices xs)) (indices xs) + in (jvp2_vec f xs seeds).1 + +entry fwd_map (xs: []f64) = + map (\i -> jvp f xs (map (\j -> f64.bool (i == j)) (indices xs))) + (indices xs) diff --git a/tests/ad/vec/index.fut b/tests/ad/vec/index.fut new file mode 100644 index 0000000000..5b8c89eada --- /dev/null +++ b/tests/ad/vec/index.fut @@ -0,0 +1,15 @@ +-- == +-- entry: fwd_vec fwd_map +-- input { 0i32 [1f32, 2f32, 3f32] } +-- output { [1f32, 0f32, 0f32] } + +def f (i: i32) (xs: []f32) = xs[i] + +entry fwd_vec l (xs: []f32) : []f32 = + let seeds = + map (\i -> map (\j -> f32.bool (i == j)) (indices xs)) (indices xs) + in (jvp2_vec (f l) xs seeds).1 + +entry fwd_map l (xs: []f32) : []f32 = + map (\i -> jvp (f l) xs (map (\j -> f32.bool (i == j)) (indices xs))) + (indices xs) diff --git a/tests/ad/vec/reduce.fut b/tests/ad/vec/reduce.fut new file mode 100644 index 0000000000..a9b970c3aa --- /dev/null +++ b/tests/ad/vec/reduce.fut @@ -0,0 +1,15 @@ +-- == +-- entry: fwd_vec fwd_map +-- input { [1f32, 2f32, 3f32] } +-- output { [6f32, 3f32, 2f32] } + +def f (xs: []f32) = f32.product xs + +entry fwd_vec (xs: []f32) : []f32 = + let seeds = + map (\i -> map (\j -> f32.bool (i == j)) (indices xs)) (indices xs) + in (jvp2_vec f xs seeds).1 + +entry fwd_map (xs: []f32) : []f32 = + map (\i -> jvp f xs (map (\j -> f32.bool (i == j)) (indices xs))) + (indices xs) diff --git a/tests/ad/vec/replicate.fut b/tests/ad/vec/replicate.fut new file mode 100644 index 0000000000..222fb21ed3 --- /dev/null +++ b/tests/ad/vec/replicate.fut @@ -0,0 +1,15 @@ +-- == +-- entry: fwd_vec fwd_map +-- input { 2i64 [1.0, 2.0] } +-- output { [[[1.0, 0.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 1.0]]] } + +def f (n: i64) (xs: []f64) = replicate n xs + +entry fwd_vec n (xs: []f64) = + let seeds = + map (\i -> map (\j -> f64.bool (i == j)) (indices xs)) (indices xs) + in (jvp2_vec (f n) xs seeds).1 + +entry fwd_map n (xs: []f64) = + map (\i -> jvp (f n) xs (map (\j -> f64.bool (i == j)) (indices xs))) + (indices xs) diff --git a/tests/ad/vec/reshape.fut b/tests/ad/vec/reshape.fut new file mode 100644 index 0000000000..2094e6ff4b --- /dev/null +++ b/tests/ad/vec/reshape.fut @@ -0,0 +1,15 @@ +-- == +-- entry: fwd_vec fwd_map +-- input { [1.0,2.0,3.0,4.0] } +-- output { [[[1.0, 0.0], [0.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]], [[0.0, 0.0], [0.0, 1.0]]] } + +def f (xs: []f64) = unflatten (sized (2 * 2) xs) + +entry fwd_vec (xs: []f64) = + let seeds = + map (\i -> map (\j -> f64.bool (i == j)) (indices xs)) (indices xs) + in (jvp2_vec f xs seeds).1 + +entry fwd_map (xs: []f64) = + map (\i -> jvp f xs (map (\j -> f64.bool (i == j)) (indices xs))) + (indices xs) diff --git a/tests/ad/vec/transpose.fut b/tests/ad/vec/transpose.fut new file mode 100644 index 0000000000..f8c5bf8a32 --- /dev/null +++ b/tests/ad/vec/transpose.fut @@ -0,0 +1,15 @@ +-- == +-- entry: fwd_vec fwd_map +-- input { [[1.0,2.0],[3.0,4.0]] } +-- output { [[[1.0, 0.0],[0.0, 0.0]],[[0.0, 0.0],[1.0, 0.0]],[[0.0, 1.0],[0.0, 0.0]],[[0.0, 0.0],[0.0, 1.0]]] } + +def f (xs: [][]f64) = transpose xs + +entry fwd_vec [n] [m] (xs: [n][m]f64) = + let seeds = + tabulate (n * m) (\i -> tabulate (n * m) (\j -> f64.bool (i == j)) |> unflatten) + in (jvp2_vec f xs seeds).1 + +entry fwd_map [n] [m] (xs: [n][m]f64) = + tabulate (n * m) + (\i -> jvp f xs (tabulate (n * m) (\j -> f64.bool (i == j)) |> unflatten)) From 9b488e150b630b8d425f4b9d0f68d54cd86829e0 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 14 Aug 2025 11:04:24 +0200 Subject: [PATCH 05/70] Scan test. --- tests/ad/vec/scan.fut | 15 +++++++++++++++ 1 file changed, 15 insertions(+) create mode 100644 tests/ad/vec/scan.fut diff --git a/tests/ad/vec/scan.fut b/tests/ad/vec/scan.fut new file mode 100644 index 0000000000..607c7e440b --- /dev/null +++ b/tests/ad/vec/scan.fut @@ -0,0 +1,15 @@ +-- == +-- entry: fwd_vec fwd_map +-- input { [1f32, 2f32, 3f32] } +-- output { [[1f32, 2.0, 6.0], [0f32, 1.0, 3.0], [0f32, 0.0, 2.0]] } + +def f (xs: []f32) = scan (*) 1 xs + +entry fwd_vec (xs: []f32) : [][]f32 = + let seeds = + map (\i -> map (\j -> f32.bool (i == j)) (indices xs)) (indices xs) + in transpose (jvp2_vec f xs seeds).1 + +entry fwd_map (xs: []f32) : [][]f32 = + map (\i -> jvp f xs (map (\j -> f32.bool (i == j)) (indices xs))) + (indices xs) From 9d15ec61956ea428d668514c669c66d4773e6b63 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 14 Aug 2025 11:10:33 +0200 Subject: [PATCH 06/70] Add jvp_vec and vjp_vec. --- prelude/ad.fut | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/prelude/ad.fut b/prelude/ad.fut index 6ec2803c3f..82cb83ddbc 100644 --- a/prelude/ad.fut +++ b/prelude/ad.fut @@ -125,3 +125,13 @@ def jvp 'a 'b (f: a -> b) (x: a) (x': a) : b = -- | Vector-Jacobian Product ("reverse mode"). def vjp 'a 'b (f: a -> b) (x: a) (y': b) : a = (vjp2 f x y').1 + +-- | As `jvp`, but accepts a vector of seed values. Semantically +-- equivalent to mapping, but may be more efficient. +def jvp_vec 'a 'b [n] (f: a -> b) (x: a) (x': [n]a) : [n]b = + (jvp2_vec f x x').1 + +-- | As `vjp`, but accepts a vector of seed values. Semantically +-- equivalent to mapping, but may be more efficient. +def vjp_vec 'a 'b [n] (f: a -> b) (x: a) (y': [n]b) : [n]a = + (vjp2_vec f x y').1 From e75f617088f7ab35f60bcb92e355f701aeafdbf7 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 14 Aug 2025 15:57:19 +0200 Subject: [PATCH 07/70] This should not need modification. --- src/Futhark/AD/Fwd.hs | 10 ++++++---- src/Futhark/IR/Prop/Types.hs | 4 ++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/Futhark/AD/Fwd.hs b/src/Futhark/AD/Fwd.hs index 93cb4766a3..a61387e5bd 100644 --- a/src/Futhark/AD/Fwd.hs +++ b/src/Futhark/AD/Fwd.hs @@ -33,10 +33,12 @@ tanType (Acc acc ispace ts u) = do tanType t = do shape <- askShape pure $ - arrayOf - (Prim (elemType t)) - (shape `prependShape` arrayShape t) - (uniqueness t) + arrayOf (Prim (elemType t)) (shape `prependShape` arrayShape t) u + where + u = case t of + Array _ _ u -> u + Acc _ _ _ u -> u + _ -> mempty slocal' :: ADM a -> ADM a slocal' = slocal id diff --git a/src/Futhark/IR/Prop/Types.hs b/src/Futhark/IR/Prop/Types.hs index 431de474bd..5571ac5056 100644 --- a/src/Futhark/IR/Prop/Types.hs +++ b/src/Futhark/IR/Prop/Types.hs @@ -134,10 +134,10 @@ existential = any ext . shapeDims . arrayShape ext (Free _) = False -- | Return the uniqueness of a type. -uniqueness :: (Monoid u) => TypeBase shape u -> u +uniqueness :: TypeBase shape Uniqueness -> Uniqueness uniqueness (Array _ _ u) = u uniqueness (Acc _ _ _ u) = u -uniqueness _ = mempty +uniqueness _ = Nonunique -- | @unique t@ is 'True' if the type of the argument is unique. unique :: TypeBase shape Uniqueness -> Bool From cbd98fb9fe429ff2395c6be11de81675b3ea7b0c Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 14 Aug 2025 17:56:01 +0200 Subject: [PATCH 08/70] Change how accumulators are handled. --- src/Futhark/AD/Fwd.hs | 151 +++++++++++++++++------------------- src/Futhark/AD/Rev/Monad.hs | 27 ------- src/Futhark/Tools.hs | 36 +++++++++ src/Futhark/Util.hs | 11 +++ tests/ad/fwd/acc0.fut | 8 +- 5 files changed, 120 insertions(+), 113 deletions(-) diff --git a/src/Futhark/AD/Fwd.hs b/src/Futhark/AD/Fwd.hs index a61387e5bd..ebfad19ffb 100644 --- a/src/Futhark/AD/Fwd.hs +++ b/src/Futhark/AD/Fwd.hs @@ -5,15 +5,15 @@ module Futhark.AD.Fwd (fwdJVP) where import Control.Monad import Control.Monad.Reader import Control.Monad.State.Strict -import Data.Bifunctor (second) -import Data.List (transpose) +import Data.Bifunctor (bimap, second) import Data.List.NonEmpty (NonEmpty (..)) import Data.Map qualified as M import Futhark.AD.Derivatives import Futhark.Analysis.PrimExp.Convert import Futhark.Builder -import Futhark.Construct import Futhark.IR.SOACS +import Futhark.Tools +import Futhark.Util (interleave) zeroTan :: Type -> ADM SubExp zeroTan (Prim t) = pure $ constant $ blankPrimValue t @@ -28,16 +28,15 @@ zeroExp t = error $ "zeroExp: " ++ show t tanType :: (ArrayShape s, Monoid u) => TypeBase s u -> ADM (TypeBase s u) tanType (Acc acc ispace ts u) = do - ts_tan <- mapM tanType ts - pure $ Acc acc ispace (ts ++ ts_tan) u + acc_tan <- tangent acc + pure $ Acc acc_tan ispace ts u tanType t = do shape <- askShape pure $ arrayOf (Prim (elemType t)) (shape `prependShape` arrayShape t) u where u = case t of - Array _ _ u -> u - Acc _ _ _ u -> u + Array _ _ u' -> u' _ -> mempty slocal' :: ADM a -> ADM a @@ -100,27 +99,20 @@ insertTan v v' = class TanBuilder a where newTan :: a -> ADM a - bundleNew :: a -> ADM [a] + bundleNew :: a -> ADM (a, a) bundleNewList :: (TanBuilder a) => [a] -> ADM [a] -bundleNewList = fmap mconcat . mapM bundleNew +bundleNewList = fmap (uncurry interleave . unzip) . mapM bundleNew instance (ArrayShape s, Monoid u) => TanBuilder (PatElem (TypeBase s u)) where - newTan (PatElem p t) - | isAcc t = do - insertTan p p - t' <- tanType t - pure $ PatElem p t' - | otherwise = do - p' <- tanVName p - insertTan p p' - t' <- tanType t - pure $ PatElem p' t' - bundleNew pe@(PatElem _ t) = do + newTan (PatElem p t) = do + p' <- tanVName p + insertTan p p' + t' <- tanType t + pure $ PatElem p' t' + bundleNew pe = do pe' <- newTan pe - if isAcc t - then pure [pe'] - else pure [pe, pe'] + pure (pe, pe') newTanPat :: (TanBuilder (PatElem t)) => Pat t -> ADM (Pat t) newTanPat (Pat pes) = Pat <$> mapM newTan pes @@ -132,40 +124,29 @@ instance (ArrayShape s, Monoid u) => TanBuilder (Param (TypeBase s u)) where newTan (Param _ p t) = do PatElem p' t' <- newTan $ PatElem p t pure $ Param mempty p' t' - bundleNew param@(Param _ _ (Prim Unit)) = - pure [param] - bundleNew param@(Param _ _ t) = do + bundleNew param = do param' <- newTan param - if isAcc t - then pure [param'] - else pure [param, param'] + pure (param, param') -instance - (ArrayShape s, Monoid u, Tangent a) => - TanBuilder (Param (TypeBase s u), a) - where +instance (TanBuilder a, Tangent b) => TanBuilder (a, b) where newTan (p, x) = (,) <$> newTan p <*> tangent x bundleNew (p, x) = do - b <- bundleNew p + p' <- newTan p x_tan <- tangent x - pure $ zip b [x, x_tan] + pure ((p, x), (p', x_tan)) class Tangent a where tangent :: a -> ADM a - bundleTan :: a -> ADM [a] + bundleTan :: a -> ADM (a, a) instance (ArrayShape s, Monoid u) => Tangent (TypeBase s u) where tangent = tanType - bundleTan t - | isAcc t = do - t' <- tangent t - pure [t'] - | otherwise = do - t' <- tangent t - pure [t, t'] + bundleTan t = do + t' <- tangent t + pure (t, t') bundleTangents :: (Tangent a) => [a] -> ADM [a] -bundleTangents = (mconcat <$>) . mapM bundleTan +bundleTangents = fmap (uncurry interleave . unzip) . mapM bundleTan instance Tangent VName where tangent v = do @@ -174,26 +155,25 @@ instance Tangent VName where Just v_tan -> pure v_tan Nothing -> do t <- lookupType v + when (isAcc t) $ + error $ + "Missing tangent for accumulator " <> prettyString v letExp (baseString v <> "_implicit_tan") $ zeroExp t bundleTan v = do - t <- lookupType v - if isAcc t - then pure [v] - else do - v_tan <- tangent v - pure [v, v_tan] + v_tan <- tangent v + pure (v, v_tan) instance Tangent SubExp where tangent (Constant c) = zeroTan $ Prim $ primValueType c tangent (Var v) = Var <$> tangent v bundleTan c@Constant {} = do c_tan <- tangent c - pure [c, c_tan] - bundleTan (Var v) = fmap Var <$> bundleTan v + pure (c, c_tan) + bundleTan (Var v) = bimap Var Var <$> bundleTan v instance Tangent SubExpRes where tangent (SubExpRes cs se) = SubExpRes cs <$> tangent se - bundleTan (SubExpRes cs se) = map (SubExpRes cs) <$> bundleTan se + bundleTan (SubExpRes cs se) = bimap (SubExpRes cs) (SubExpRes cs) <$> bundleTan se asVName :: SubExp -> ADM VName asVName (Var v) = pure v @@ -313,21 +293,31 @@ basicFwd pat aux op = do _ -> error $ "basicFwd: Unsupported op " ++ prettyString op fwdLambda :: Lambda SOACS -> ADM (Lambda SOACS) -fwdLambda (Lambda params ret body) = do +fwdLambda (Lambda params _ body) = do params' <- bundleNewList params - Lambda params' - <$> bundleTangents ret - <*> localScope (scopeOfLParams params') (fwdBody body) + mkLambda params' $ bodyBind =<< fwdBody body + +fwdWithAccLambda :: [WithAccInput SOACS] -> Lambda SOACS -> ADM (Lambda SOACS) +fwdWithAccLambda inputs (Lambda params _ body) = do + let (cert_params, acc_params) = splitAt (length inputs) params + cert_params_tan <- replicateM (length inputs) $ newParam "acc_cert_tan" $ Prim Unit + acc_params_tan <- zipWithM mkAccParam (map paramName cert_params_tan) inputs + + mkLambda (cert_params <> cert_params_tan <> acc_params <> acc_params_tan) $ do + zipWithM_ + insertTan + (map paramName (cert_params <> acc_params)) + (map paramName (cert_params_tan <> acc_params_tan)) + bodyBind =<< fwdBody body + where + mkAccParam c (shape, arrs, _) = do + ts <- map (stripArray (shapeRank shape)) <$> mapM lookupType arrs + newParam "acc_p_tan" $ Acc c shape ts NoUniqueness fwdStreamLambda :: Lambda SOACS -> ADM (Lambda SOACS) -fwdStreamLambda (Lambda params ret body) = do +fwdStreamLambda (Lambda params _ body) = do params' <- (take 1 params ++) <$> bundleNewList (drop 1 params) - Lambda params' - <$> bundleTangents ret - <*> localScope (scopeOfLParams params') (fwdBody body) - -interleave :: [a] -> [a] -> [a] -interleave xs ys = concat $ transpose [xs, ys] + mkLambda params' $ bodyBind =<< fwdBody body zeroFromSubExp :: SubExp -> ADM VName zeroFromSubExp (Constant c) = @@ -414,10 +404,11 @@ fwdSOAC _ _ VJP {} = fwdStm :: Stm SOACS -> ADM () fwdStm (Let pat aux (BasicOp (UpdateAcc safety acc i x))) = do - pat' <- bundleNewPat pat - x' <- bundleTangents x + x_tan <- mapM tangent x acc_tan <- tangent acc - addStm $ Let pat' aux $ BasicOp $ UpdateAcc safety acc_tan i x' + addStm $ Let pat aux $ BasicOp $ UpdateAcc safety acc i x + res_tan <- letExp "tan" $ BasicOp $ UpdateAcc safety acc_tan i x_tan + insertTan (head $ patNames pat) res_tan fwdStm stm@(Let pat aux (BasicOp e)) = do -- XXX: this has to be too naive. unless (any isAcc $ patTypes pat) $ addStm stm @@ -467,25 +458,23 @@ fwdStm (Let pat aux (Loop val_pats loop@(ForLoop i it bound) body)) = do fwdBody body addStm $ Let pat' aux $ Loop val_pats' (ForLoop i it bound) body' fwdStm (Let pat aux (WithAcc inputs lam)) = do - inputs' <- forM inputs $ \(shape, arrs, op) -> do + inputs_tan <- forM inputs $ \(shape, arrs, op) -> do arrs_tan <- mapM tangent arrs op' <- case op of Nothing -> pure Nothing Just (op_lam, nes) -> do - nes_tan <- mapM (fmap Var . zeroFromSubExp) nes - op_lam' <- fwdLambda op_lam - case op_lam' of - Lambda ps ret body -> do - let op_lam'' = Lambda (removeIndexTans (shapeRank shape) ps) ret body - pure $ Just (op_lam'', interleave nes nes_tan) - pure (shape, arrs <> arrs_tan, op') + -- We assume that op_lam has unit partial derivatives (i.e., is some + -- kind of addition). This is the case for all WithAccs produced by VJP. + lams <- mapM addLambda $ lambdaReturnType op_lam + -- Horizontally fuse the lambdas to produce a single one. + idx_params <- replicateM (shapeRank shape) $ newParam "idx" $ Prim int64 + let (xs, ys) = bimap concat concat $ unzip $ map (splitAt 1 . lambdaParams) lams + op_lam' <- mkLambda (idx_params <> xs <> ys) $ mconcat <$> mapM (bodyBind . lambdaBody) lams + pure $ Just (op_lam', nes) + pure (shape, arrs_tan, op') pat' <- bundleNewPat pat - lam' <- fwdLambda lam - addStm $ Let pat' aux $ WithAcc inputs' lam' - where - removeIndexTans 0 ps = ps - removeIndexTans i (p : _ : ps) = p : removeIndexTans (i - 1) ps - removeIndexTans _ ps = ps + lam' <- fwdWithAccLambda inputs lam + addStm $ Let pat' aux $ WithAcc (interleave inputs inputs_tan) lam' fwdStm (Let pat aux (Op soac)) = fwdSOAC pat aux soac fwdStm stm = error $ "unhandled forward mode AD for Stm: " ++ prettyString stm ++ "\n" ++ show stm diff --git a/src/Futhark/AD/Rev/Monad.hs b/src/Futhark/AD/Rev/Monad.hs index 1a036dc8db..18e95efe2c 100644 --- a/src/Futhark/AD/Rev/Monad.hs +++ b/src/Futhark/AD/Rev/Monad.hs @@ -314,12 +314,6 @@ noAdjsFor names m = do where names' = namesToList names -addBinOp :: PrimType -> BinOp -addBinOp (IntType it) = Add it OverflowWrap -addBinOp (FloatType ft) = FAdd ft -addBinOp Bool = LogAnd -addBinOp Unit = LogAnd - tabNest :: (MonadBuilder m, Rep m ~ SOACS) => Int -> @@ -345,27 +339,6 @@ tabNest = tabNest' [] let lam = Lambda (iparam : params) ret (Body () stms res) letTupExp "tab" $ Op $ Screma w (iota : vs) (mapSOAC lam) --- | Construct a lambda for adding two values of the given type. -addLambda :: Type -> ADM (Lambda SOACS) -addLambda (Prim pt) = binOpLambda (addBinOp pt) pt -addLambda t@Array {} = do - xs_p <- newParam "xs" t - ys_p <- newParam "ys" t - lam <- addLambda $ rowType t - body <- insertStmsM $ do - res <- - letSubExp "lam_map" . Op $ - Screma (arraySize 0 t) [paramName xs_p, paramName ys_p] (mapSOAC lam) - pure $ resultBody [res] - pure - Lambda - { lambdaParams = [xs_p, ys_p], - lambdaReturnType = [t], - lambdaBody = body - } -addLambda t = - error $ "addLambda: " ++ show t - -- | Construct a lambda for binop'ing two values of the given type, -- which may be arrays. vecOpLambda :: (PrimType -> BinOp) -> Type -> ADM (Lambda SOACS) diff --git a/src/Futhark/Tools.hs b/src/Futhark/Tools.hs index 99669ac0c4..02d24eaf6f 100644 --- a/src/Futhark/Tools.hs +++ b/src/Futhark/Tools.hs @@ -11,6 +11,8 @@ module Futhark.Tools partitionChunkedFoldParameters, withAcc, doScatter, + addBinOp, + addLambda, -- * Primitive expressions module Futhark.Analysis.PrimExp.Convert, @@ -236,3 +238,37 @@ doScatter desc rank dest arrs mk = do Screma w (map paramName acc_ps <> arrs) (mapSOAC map_lam) letTupExp desc $ WithAcc [(acc_shape, [v], Nothing) | v <- dest] withacc_lam + +-- | The most addition-like binary operator for some primitive type. +addBinOp :: PrimType -> BinOp +addBinOp (IntType it) = Add it OverflowWrap +addBinOp (FloatType ft) = FAdd ft +addBinOp Bool = LogAnd +addBinOp Unit = LogAnd + +-- | Construct a lambda for adding two values of the given type, Using SOACs to handle arrays. +addLambda :: + ( OpC (Rep m) ~ SOAC, + MonadBuilder m, + Buildable (Rep m) + ) => + TypeBase Shape NoUniqueness -> + m (Lambda (Rep m)) +addLambda (Prim pt) = binOpLambda (addBinOp pt) pt +addLambda t@Array {} = do + xs_p <- newParam "xs" t + ys_p <- newParam "ys" t + lam <- addLambda $ rowType t + body <- insertStmsM $ do + res <- + letSubExp "lam_map" . Op $ + Screma (arraySize 0 t) [paramName xs_p, paramName ys_p] (mapSOAC lam) + pure $ resultBody [res] + pure + Lambda + { lambdaParams = [xs_p, ys_p], + lambdaReturnType = [t], + lambdaBody = body + } +addLambda t = + error $ "addLambda: " ++ show t diff --git a/src/Futhark/Util.hs b/src/Futhark/Util.hs index b12ecdd9a1..4693db6ea5 100644 --- a/src/Futhark/Util.hs +++ b/src/Futhark/Util.hs @@ -54,6 +54,8 @@ module Futhark.Util topologicalSort, debugTraceM, ensureCacheDirectory, + interleave, + unterleave, ) where @@ -366,6 +368,15 @@ convFloat v | isNaN v = 0 / 0 | otherwise = fromRational $ toRational v +-- | Interleave two lists. +interleave :: [a] -> [a] -> [a] +interleave xs ys = concat $ L.transpose [xs, ys] + +-- | The inverse of interleave. +unterleave :: [a] -> ([a], [a]) +unterleave (x : y : xys) = bimap (x :) (y :) $ unterleave xys +unterleave _ = ([], []) + -- Z-encoding from https://ghc.haskell.org/trac/ghc/wiki/Commentary/Compiler/SymbolNames -- -- Slightly simplified as we do not need it to deal with tuples and diff --git a/tests/ad/fwd/acc0.fut b/tests/ad/fwd/acc0.fut index 17bbce9448..40020fece8 100644 --- a/tests/ad/fwd/acc0.fut +++ b/tests/ad/fwd/acc0.fut @@ -2,20 +2,18 @@ import "../../accs/intrinsics" def f (acc: *acc ([]i32)) i = write acc i (i32.i64 i) --- square entries - -- == -- entry: prim -- input { [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] } --- output { [0, 1, 4, 9, 16, 25, 36, 49, 64, 81] } +-- output { [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] } entry prim [n] (xs: [n]i32) = let (xs': *[n]i32) = copy xs - in reduce_by_index_stream xs' (*) 1 f (map i64.i32 (xs :> [n]i32)) + in reduce_by_index_stream xs' (+) 0 f (map i64.i32 (xs :> [n]i32)) -- == -- entry: f_jvp -- input { [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] } --- output { [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] } +-- output { [2, 2, 2, 2, 2, 2, 2, 2, 2, 2] } entry f_jvp (xs: *[]i32) = jvp prim xs (replicate 10 1) From 8ae043a123c47076e0fad4b9a772ae8a4ea0070a Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 15 Aug 2025 12:41:54 +0200 Subject: [PATCH 09/70] Implement vectorised scatter. --- src/Futhark/AD/Fwd.hs | 49 +++++++++++++++++++++++++++++++-------- src/Futhark/Construct.hs | 6 +++++ tests/ad/vec/scatter0.fut | 22 ++++++++++++++++++ 3 files changed, 67 insertions(+), 10 deletions(-) create mode 100644 tests/ad/vec/scatter0.fut diff --git a/src/Futhark/AD/Fwd.hs b/src/Futhark/AD/Fwd.hs index ebfad19ffb..d380c81b1a 100644 --- a/src/Futhark/AD/Fwd.hs +++ b/src/Futhark/AD/Fwd.hs @@ -29,7 +29,8 @@ zeroExp t = error $ "zeroExp: " ++ show t tanType :: (ArrayShape s, Monoid u) => TypeBase s u -> ADM (TypeBase s u) tanType (Acc acc ispace ts u) = do acc_tan <- tangent acc - pure $ Acc acc_tan ispace ts u + tan_shape <- askShape + pure $ Acc acc_tan (tan_shape <> ispace) ts u tanType t = do shape <- askShape pure $ @@ -158,7 +159,8 @@ instance Tangent VName where when (isAcc t) $ error $ "Missing tangent for accumulator " <> prettyString v - letExp (baseString v <> "_implicit_tan") $ zeroExp t + tan_shape <- askShape + letExp (baseString v <> "_implicit_tan") $ zeroExp $ t `arrayOfShape` tan_shape bundleTan v = do v_tan <- tangent v pure (v, v_tan) @@ -197,6 +199,32 @@ withTan x f = do =<< f (Var (paramName x_tan_p)) pure $ Op $ Screma w [x_tan_v] (mapSOAC lam) +withTansI :: + VName -> + [SubExp] -> + ([SubExp] -> VName -> [SubExp] -> ADM (Exp SOACS)) -> + ADM (Exp SOACS) +withTansI x ys f = do + shape <- askShape + x_tan <- tangent x + ys_tan <- mapM tangent ys + if shape == mempty + then f [] x_tan ys_tan + else do + let w = shapeSize 0 shape + ys_tan_vs <- mapM asVName ys_tan + iota_p <- newParam "iota_p" $ Prim int64 + x_tan_p <- newParam "x_tanp" . rowType =<< lookupType x_tan + ys_tan_ps <- mapM (newParam "y_tanp" . rowType <=< lookupType) ys_tan_vs + lam <- mkLambda (iota_p : x_tan_p : ys_tan_ps) $ do + fmap (subExpsRes . pure) . letSubExp "tan" + =<< f + [Var $ paramName iota_p] + (paramName x_tan_p) + (map (Var . paramName) ys_tan_ps) + iota_v <- letExp "iota" $ iota64 w + pure $ Op $ Screma w (iota_v : x_tan : ys_tan_vs) (mapSOAC lam) + withTans :: PrimType -> SubExp -> @@ -237,11 +265,11 @@ basicFwd pat aux op = do let t = unOpType unop x_pe = primExpFromSubExp t x dx = pdUnOp unop x_pe - auxing aux $ letBindNames (patNames pat_tan) <=< withTan x $ \x_tan -> + auxing aux $ letBind pat_tan <=< withTan x $ \x_tan -> toExp $ primExpFromSubExp t x_tan ~*~ dx BinOp bop x y -> do let t = binOpType bop - auxing aux . letBindNames (patNames pat_tan) <=< withTans t x y $ + auxing aux . letBind pat_tan <=< withTans t x y $ \x_tan y_tan -> let (wrt_x, wrt_y) = pdBinOp bop (primExpFromSubExp t x) (primExpFromSubExp t y) @@ -311,8 +339,9 @@ fwdWithAccLambda inputs (Lambda params _ body) = do bodyBind =<< fwdBody body where mkAccParam c (shape, arrs, _) = do + tan_shape <- askShape ts <- map (stripArray (shapeRank shape)) <$> mapM lookupType arrs - newParam "acc_p_tan" $ Acc c shape ts NoUniqueness + newParam "acc_p_tan" $ Acc c (tan_shape <> shape) ts NoUniqueness fwdStreamLambda :: Lambda SOACS -> ADM (Lambda SOACS) fwdStreamLambda (Lambda params _ body) = do @@ -404,11 +433,10 @@ fwdSOAC _ _ VJP {} = fwdStm :: Stm SOACS -> ADM () fwdStm (Let pat aux (BasicOp (UpdateAcc safety acc i x))) = do - x_tan <- mapM tangent x - acc_tan <- tangent acc + pat_tan <- newTanPat pat addStm $ Let pat aux $ BasicOp $ UpdateAcc safety acc i x - res_tan <- letExp "tan" $ BasicOp $ UpdateAcc safety acc_tan i x_tan - insertTan (head $ patNames pat) res_tan + addStm . Let pat_tan aux <=< withTansI acc x $ \is acc_tan x_tan' -> do + pure $ BasicOp $ UpdateAcc safety acc_tan (is <> i) x_tan' fwdStm stm@(Let pat aux (BasicOp e)) = do -- XXX: this has to be too naive. unless (any isAcc $ patTypes pat) $ addStm stm @@ -460,6 +488,7 @@ fwdStm (Let pat aux (Loop val_pats loop@(ForLoop i it bound) body)) = do fwdStm (Let pat aux (WithAcc inputs lam)) = do inputs_tan <- forM inputs $ \(shape, arrs, op) -> do arrs_tan <- mapM tangent arrs + tan_shape <- askShape op' <- case op of Nothing -> pure Nothing Just (op_lam, nes) -> do @@ -471,7 +500,7 @@ fwdStm (Let pat aux (WithAcc inputs lam)) = do let (xs, ys) = bimap concat concat $ unzip $ map (splitAt 1 . lambdaParams) lams op_lam' <- mkLambda (idx_params <> xs <> ys) $ mconcat <$> mapM (bodyBind . lambdaBody) lams pure $ Just (op_lam', nes) - pure (shape, arrs_tan, op') + pure (tan_shape <> shape, arrs_tan, op') pat' <- bundleNewPat pat lam' <- fwdWithAccLambda inputs lam addStm $ Let pat' aux $ WithAcc (interleave inputs inputs_tan) lam' diff --git a/src/Futhark/Construct.hs b/src/Futhark/Construct.hs index 9cc7924417..109da7d0c5 100644 --- a/src/Futhark/Construct.hs +++ b/src/Futhark/Construct.hs @@ -107,6 +107,7 @@ module Futhark.Construct fullSliceNum, isFullSlice, sliceAt, + iota64, -- * Result types instantiateShapes, @@ -567,6 +568,11 @@ sliceAt :: Type -> Int -> [DimIndex SubExp] -> Slice SubExp sliceAt t n slice = fullSlice t $ map sliceDim (take n $ arrayDims t) ++ slice +-- | Produce a straightforward `Int64` `Iota` of the given length with offset 0 +-- and stride 1. +iota64 :: SubExp -> Exp rep +iota64 n = BasicOp $ Iota n (intConst Int64 0) (intConst Int64 1) Int64 + -- | Like 'fullSlice', but the dimensions are simply numeric. fullSliceNum :: (Num d) => [d] -> [DimIndex d] -> Slice d fullSliceNum dims slice = diff --git a/tests/ad/vec/scatter0.fut b/tests/ad/vec/scatter0.fut new file mode 100644 index 0000000000..1947e12d50 --- /dev/null +++ b/tests/ad/vec/scatter0.fut @@ -0,0 +1,22 @@ +-- Simple scatter, differentiating wrt. values. +-- == +-- entry: fwd fwd_vec +-- input { [0f32, 0f32, 0f32, 0f32] [0i64, 1i64, 2i64, 3i64] [1f32, 2f32, 3f32, 0f32] } +-- output { +-- [[1.000000f32, 0.000000f32, 0.000000f32, 0.000000f32], +-- [0.000000f32, 1.000000f32, 0.000000f32, 0.000000f32], +-- [0.000000f32, 0.000000f32, 1.000000f32, 0.000000f32], +-- [0.000000f32, 0.000000f32, 0.000000f32, 1.000000f32]] +-- } + +def f [n] [k] (xs: [k]f32) (is: [n]i64) (vs: [n]f32) = + scatter (copy xs) is vs + +entry fwd [n] [k] (xs: [k]f32) (is: [n]i64) (vs: [n]f32) = + let g i = jvp (\vs -> f xs is vs) vs (replicate n 0 with [i] = 1) + in tabulate n g + +entry fwd_vec [n] [k] (xs: [k]f32) (is: [n]i64) (vs: [n]f32) = + let seeds = + map (\i -> map (\j -> f32.bool (i == j)) (indices vs)) (indices vs) + in jvp_vec (\vs -> f xs is vs) vs seeds From 676075883eb6f5f1cfad3f422c818bdb5c6ecbfb Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 15 Aug 2025 13:51:36 +0200 Subject: [PATCH 10/70] Add map test. --- tests/ad/vec/map0.fut | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 tests/ad/vec/map0.fut diff --git a/tests/ad/vec/map0.fut b/tests/ad/vec/map0.fut new file mode 100644 index 0000000000..3764a378aa --- /dev/null +++ b/tests/ad/vec/map0.fut @@ -0,0 +1,13 @@ +-- == +-- entry: fwd fwd_vec +-- input { [1.0f32, 2.0, 3.0] [4f32, 5, 6] } +-- output { [[1.0f32, 0.0, 0.0], [0.0f32, 2.0, 0.0], [0.0f32, 0.0, 3.0]] } + +def prim = map2 (f32.*) + +entry fwd [n] (xs: [n]f32) (ys: [n]f32) = + tabulate n (\i -> jvp (prim xs) ys (replicate n 0 with [i] = 1)) + +entry fwd_vec [n] (xs: [n]f32) (ys: [n]f32) = + let seeds = tabulate n (\i -> (replicate n 0 with [i] = 1)) + in jvp_vec (prim xs) ys seeds From 240edfe1ed41e330ae1627cb7e0a8e421ffda6b2 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 15 Aug 2025 14:19:31 +0200 Subject: [PATCH 11/70] More tests, some that fail. --- tests/ad/vec/gather0.fut | 23 +++++++++++++++++++++++ tests/ad/vec/map1.fut | 17 +++++++++++++++++ tests/ad/vec/scatter1.fut | 22 ++++++++++++++++++++++ 3 files changed, 62 insertions(+) create mode 100644 tests/ad/vec/gather0.fut create mode 100644 tests/ad/vec/map1.fut create mode 100644 tests/ad/vec/scatter1.fut diff --git a/tests/ad/vec/gather0.fut b/tests/ad/vec/gather0.fut new file mode 100644 index 0000000000..29c4a29715 --- /dev/null +++ b/tests/ad/vec/gather0.fut @@ -0,0 +1,23 @@ +-- == +-- entry: fwd fwd_vec +-- input { [4.0,3.0,2.0,1.0] [0i64,1i64,2i64,3i64] } +-- output { [[1.0, 0.0, 0.0, 0.0], +-- [0.0, 1.0, 0.0, 0.0], +-- [0.0, 0.0, 1.0, 0.0], +-- [0.0, 0.0, 0.0, 1.0]] +-- } +-- input { [4.0,3.0,2.0,1.0] [0i64,0i64,3i64,3i64] } +-- output { [[1.0, 0.0, 0.0, 0.0], +-- [1.0, 0.0, 0.0, 0.0], +-- [0.0, 0.0, 0.0, 1.0], +-- [0.0, 0.0, 0.0, 1.0]] +-- } + +def gather xs is = map (\(i: i64) -> xs[i]) is + +entry fwd [n] [m] (xs: [n]f64) (is: [m]i64) = + transpose (tabulate n (\j -> jvp (`gather` is) xs (replicate n 0 with [j] = 1))) + +entry fwd_vec [n] [m] (xs: [n]f64) (is: [m]i64) = + let seeds = tabulate n (\j -> replicate n 0 with [j] = 1) + in transpose (jvp_vec (`gather` is) xs seeds) diff --git a/tests/ad/vec/map1.fut b/tests/ad/vec/map1.fut new file mode 100644 index 0000000000..d3ec6b17c4 --- /dev/null +++ b/tests/ad/vec/map1.fut @@ -0,0 +1,17 @@ +-- Like map0, but we do not compute the full Jacobian, so the vector size is not +-- the same as the input size. +-- == +-- entry: fwd fwd_vec +-- input { [1.0f32, 2.0, 3.0] [4f32, 5, 6] } +-- output { [[1.0f32, 0.0, 0.0], [0.0f32, 2.0, 0.0]] } + +def prim = map2 (f32.*) + +def k = 2i64 + +entry fwd [n] (xs: [n]f32) (ys: [n]f32) = + tabulate k (\i -> jvp (prim xs) ys (replicate n 0 with [i] = 1)) + +entry fwd_vec [n] (xs: [n]f32) (ys: [n]f32) = + let seeds = tabulate k (\i -> (replicate n 0 with [i] = 1)) + in jvp_vec (prim xs) ys seeds diff --git a/tests/ad/vec/scatter1.fut b/tests/ad/vec/scatter1.fut new file mode 100644 index 0000000000..c9e6c83a25 --- /dev/null +++ b/tests/ad/vec/scatter1.fut @@ -0,0 +1,22 @@ +-- Simple scatter, differentiating wrt. target. +-- == +-- entry: fwd fwd_vec +-- input { [0f32, 0f32, 0f32, 0f32] [0i64, 1i64] [1f32, 2f32] } +-- output { +-- [[0.000000f32, 0.000000f32, 0.000000f32, 0.000000f32], +-- [0.000000f32, 0.000000f32, 0.000000f32, 0.000000f32], +-- [0.000000f32, 0.000000f32, 1.000000f32, 0.000000f32], +-- [0.000000f32, 0.000000f32, 0.000000f32, 1.000000f32]] +-- } + +def f [n] [k] (xs: [k]f32) (is: [n]i64) (vs: [n]f32) = + scatter (copy xs) is vs + +entry fwd [n] [k] (xs: [k]f32) (is: [n]i64) (vs: [n]f32) = + let g i = jvp (\xs -> f xs is vs) xs (replicate k 0 with [i] = 1) + in tabulate k g + +entry fwd_vec [n] [k] (xs: [k]f32) (is: [n]i64) (vs: [n]f32) = + let seeds = + map (\i -> map (\j -> f32.bool (i == j)) (indices xs)) (indices xs) + in jvp_vec (\xs -> f xs is vs) xs seeds From 644b8c2ef5021eaeae82b230c57f9ee0c99cb215 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 15 Aug 2025 14:54:05 +0200 Subject: [PATCH 12/70] Tweak the tests. --- tests/ad/vec/map1.fut | 8 ++++---- tests/ad/vec/scan.fut | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/ad/vec/map1.fut b/tests/ad/vec/map1.fut index d3ec6b17c4..12dd98fb73 100644 --- a/tests/ad/vec/map1.fut +++ b/tests/ad/vec/map1.fut @@ -3,15 +3,15 @@ -- == -- entry: fwd fwd_vec -- input { [1.0f32, 2.0, 3.0] [4f32, 5, 6] } --- output { [[1.0f32, 0.0, 0.0], [0.0f32, 2.0, 0.0]] } +-- output { [[5.0f32, 0.0, 0.0], [0.0f32, 7.0, 0.0]] } def prim = map2 (f32.*) def k = 2i64 entry fwd [n] (xs: [n]f32) (ys: [n]f32) = - tabulate k (\i -> jvp (prim xs) ys (replicate n 0 with [i] = 1)) + tabulate k (\i -> jvp (uncurry prim) (xs, ys) (replicate n 0 with [i] = 1, replicate n 0 with [i] = 1)) entry fwd_vec [n] (xs: [n]f32) (ys: [n]f32) = - let seeds = tabulate k (\i -> (replicate n 0 with [i] = 1)) - in jvp_vec (prim xs) ys seeds + let seeds = tabulate k (\i -> (replicate n 0 with [i] = 1, replicate n 0 with [i] = 1)) + in jvp_vec (uncurry prim) (xs, ys) seeds diff --git a/tests/ad/vec/scan.fut b/tests/ad/vec/scan.fut index 607c7e440b..b5d7e88d07 100644 --- a/tests/ad/vec/scan.fut +++ b/tests/ad/vec/scan.fut @@ -8,7 +8,7 @@ def f (xs: []f32) = scan (*) 1 xs entry fwd_vec (xs: []f32) : [][]f32 = let seeds = map (\i -> map (\j -> f32.bool (i == j)) (indices xs)) (indices xs) - in transpose (jvp2_vec f xs seeds).1 + in (jvp2_vec f xs seeds).1 entry fwd_map (xs: []f32) : [][]f32 = map (\i -> jvp f xs (map (\j -> f32.bool (i == j)) (indices xs))) From e9eac0ac90e35f7674c80b421bac19beb0f43835 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 15 Aug 2025 15:09:10 +0200 Subject: [PATCH 13/70] Another test. --- tests/ad/vec/reshape0.fut | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 tests/ad/vec/reshape0.fut diff --git a/tests/ad/vec/reshape0.fut b/tests/ad/vec/reshape0.fut new file mode 100644 index 0000000000..abbf753e7e --- /dev/null +++ b/tests/ad/vec/reshape0.fut @@ -0,0 +1,11 @@ +-- == +-- entry: fwd_map fwd_vec +-- input { 2i64 2i64 [1,2,3,4] } +-- output { [[[1, 0], [0, 0]], [[0, 1], [0, 0]]] } + +entry fwd_map n m (xs: [n * m]i32) = + tabulate 2 (\i -> jvp unflatten xs (replicate (n * m) 0 with [i] = 1)) + +entry fwd_vec n m (xs: [n * m]i32) = + let seeds = tabulate 2 (\i -> replicate (n * m) 0 with [i] = 1) + in jvp_vec unflatten xs seeds From d1438bcf9059302a68f6c4d36ea151bd8d532459 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 15 Aug 2025 16:01:27 +0200 Subject: [PATCH 14/70] Some hackyish fixes. --- src/Futhark/AD/Fwd.hs | 44 +++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/src/Futhark/AD/Fwd.hs b/src/Futhark/AD/Fwd.hs index d380c81b1a..c9fbb5f234 100644 --- a/src/Futhark/AD/Fwd.hs +++ b/src/Futhark/AD/Fwd.hs @@ -356,14 +356,54 @@ zeroFromSubExp (Var v) = do t <- lookupType v letExp "zero" $ zeroExp t +vecPerm :: Shape -> Type -> [Int] +vecPerm tan_shape t = + [shapeRank tan_shape] + ++ [0 .. shapeRank tan_shape - 1] + ++ [shapeRank tan_shape + 1 .. arrayRank t - 1] + +pushTanShape :: VName -> ADM VName +pushTanShape v = do + tan_shape <- askShape + v_t <- lookupType v + if tan_shape == mempty || arrayShape v_t == tan_shape || isAcc v_t + then pure v + else do + let perm = vecPerm tan_shape v_t + letExp (baseString v <> "_tr") $ BasicOp $ Rearrange v perm + +soacInputsWithTangents :: [VName] -> ADM [VName] +soacInputsWithTangents xs = do + xs_tans <- mapM (pushTanShape <=< tangent) xs + pure $ interleave xs xs_tans + +soacResPat :: Pat Type -> ADM (Pat Type, [(Pat Type, VName)]) +soacResPat (Pat pes) = do + pes_tan <- mapM newTan pes + bimap (Pat . interleave pes) mconcat . unzip <$> mapM tweakPatElem pes_tan + where + tweakPatElem pe@(PatElem v v_t) = do + tan_shape <- askShape + if tan_shape == mempty || arrayShape v_t == tan_shape || isAcc v_t + then pure (pe, []) + else do + let perm = vecPerm tan_shape v_t + v' <- newName v + pure (PatElem v' $ rearrangeType perm v_t, [(Pat [pe], v')]) + fwdSOAC :: Pat Type -> StmAux () -> SOAC SOACS -> ADM () fwdSOAC pat aux (Screma size xs (ScremaForm f scs reds)) = do - pat' <- bundleNewPat pat - xs' <- bundleTangents xs + (pat', to_transpose) <- soacResPat pat + xs' <- soacInputsWithTangents xs f' <- fwdLambda f scs' <- mapM fwdScan scs reds' <- mapM fwdRed reds addStm $ Let pat' aux $ Op $ Screma size xs' $ ScremaForm f' scs' reds' + tan_shape <- askShape + forM_ to_transpose $ \(rpat, v) -> do + v_t <- lookupType v + let perm = rearrangeInverse $ vecPerm tan_shape v_t + letBind rpat $ BasicOp $ Rearrange v perm where zeroTans lam = mapM (letSubExp "zero" . zeroExp <=< tanType) $ lambdaReturnType lam From e8c0c14295d1cc0bc97a2fa7ffdaa16ff661133c Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 17 Aug 2025 22:13:59 +0200 Subject: [PATCH 15/70] More things work. --- src/Futhark/AD/Fwd.hs | 24 +++++++++++++++++------- tests/ad/vec/arr0.fut | 17 +++++++++++++++++ tests/ad/vec/arr1.fut | 15 +++++++++++++++ 3 files changed, 49 insertions(+), 7 deletions(-) create mode 100644 tests/ad/vec/arr0.fut create mode 100644 tests/ad/vec/arr1.fut diff --git a/src/Futhark/AD/Fwd.hs b/src/Futhark/AD/Fwd.hs index e727bb7dbe..9bf07ebced 100644 --- a/src/Futhark/AD/Fwd.hs +++ b/src/Futhark/AD/Fwd.hs @@ -255,8 +255,16 @@ basicFwd pat aux op = do se_tan <- tangent se addStm $ Let pat_tan aux $ BasicOp $ Opaque opaqueop se_tan ArrayLit ses t -> do + tan_shape <- askShape ses_tan <- mapM tangent ses - addStm $ Let pat_tan aux $ BasicOp $ ArrayLit ses_tan t + if tan_shape == mempty + then + addStm $ Let pat_tan aux $ BasicOp $ ArrayLit ses_tan t + else do + pat_tan_tr <- letExp "pat_tan_tr" $ BasicOp $ ArrayLit ses_tan $ t `arrayOfShape` tan_shape + pat_tan_tr_t <- lookupType pat_tan_tr + let perm = vecPerm tan_shape pat_tan_tr_t + addStm $ Let pat_tan aux $ BasicOp $ Rearrange pat_tan_tr perm UnOp unop x -> do let t = unOpType unop x_pe = primExpFromSubExp t x @@ -278,22 +286,24 @@ basicFwd pat aux op = do pure $ BasicOp $ ConvOp cop x_tan Assert {} -> pure () Index arr slice -> do - arr_tan <- tangent arr dims <- shapeDims <$> askShape + arr_tan <- tangent arr let slice' = Slice $ map sliceDim dims <> unSlice slice addStm $ Let pat_tan aux $ BasicOp $ Index arr_tan slice' Update safety arr slice se -> do + dims <- shapeDims <$> askShape arr_tan <- tangent arr se_tan <- tangent se - addStm $ Let pat_tan aux $ BasicOp $ Update safety arr_tan slice se_tan + let slice' = Slice $ map sliceDim dims <> unSlice slice + addStm $ Let pat_tan aux $ BasicOp $ Update safety arr_tan slice' se_tan Concat d (arr :| arrs) w -> do + r <- shapeRank <$> askShape arr_tan <- tangent arr arrs_tans <- mapM tangent arrs - r <- shapeRank <$> askShape addStm $ Let pat_tan aux $ BasicOp $ Concat (d + r) (arr_tan :| arrs_tans) w Manifest arr ds -> do - arr_tan <- tangent arr r <- shapeRank <$> askShape + arr_tan <- tangent arr addStm . Let pat_tan aux . BasicOp $ Manifest arr_tan ([0 .. r - 1] ++ map (+ r) ds) Iota n _ _ it -> do @@ -307,12 +317,12 @@ basicFwd pat aux op = do tan_shape <- askShape addStm $ Let pat_tan aux $ BasicOp $ Scratch t $ shapeDims tan_shape <> shape Reshape arr reshape -> do - arr_tan <- tangent arr shape <- askShape + arr_tan <- tangent arr addStm $ Let pat_tan aux $ BasicOp $ Reshape arr_tan (newshapeInner shape reshape) Rearrange arr perm -> do - arr_tan <- tangent arr r <- shapeRank <$> askShape + arr_tan <- tangent arr addStm . Let pat_tan aux . BasicOp $ Rearrange arr_tan ([0 .. r - 1] <> map (+ r) perm) _ -> error $ "basicFwd: Unsupported op " ++ prettyString op diff --git a/tests/ad/vec/arr0.fut b/tests/ad/vec/arr0.fut new file mode 100644 index 0000000000..a94cf48aee --- /dev/null +++ b/tests/ad/vec/arr0.fut @@ -0,0 +1,17 @@ +-- == +-- tags { autodiff } + +def primal (xs: [2]f64) = xs[0] * xs[1] + +-- == +-- entry: fwd fwd_vec +-- input { [5.0, 7.0] } +-- output { [7.0, 5.0] } + +entry fwd xs = + [ jvp primal xs [1, 0] + , jvp primal xs [0, 1] + ] + +entry fwd_vec xs = + jvp_vec primal xs [[1, 0], [0, 1]] diff --git a/tests/ad/vec/arr1.fut b/tests/ad/vec/arr1.fut new file mode 100644 index 0000000000..584feaee15 --- /dev/null +++ b/tests/ad/vec/arr1.fut @@ -0,0 +1,15 @@ +def primal (x, y) : [2]f64 = [x + y, x * y] + +-- == +-- tags { autodiff } +-- entry: fwd fwd_vec +-- input { 5.0 7.0 } +-- output { [[1.0,7.0], [1.0, 5.0]] } + +entry fwd x y = + [ jvp primal (x, y) (1, 0) + , jvp primal (x, y) (0, 1) + ] + +entry fwd_vec x y = + jvp_vec primal (x, y) [(1, 0), (0, 1)] From f67a2b045e5097ad7304474c378b8e92563cf798 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 17 Aug 2025 22:17:09 +0200 Subject: [PATCH 16/70] Minor fixes. --- src/Futhark/AD/Fwd.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Futhark/AD/Fwd.hs b/src/Futhark/AD/Fwd.hs index 9bf07ebced..09c0e93bb9 100644 --- a/src/Futhark/AD/Fwd.hs +++ b/src/Futhark/AD/Fwd.hs @@ -522,14 +522,14 @@ fwdStm (Let pat aux (Loop val_pats loop@(WhileLoop v) body)) = do val_pats' <- bundleNewList val_pats pat' <- bundleNewPat pat body' <- - localScope (scopeOfFParams (map fst val_pats) <> scopeOfLoopForm loop) . slocal' $ + localScope (scopeOfFParams (map fst val_pats') <> scopeOfLoopForm loop) . slocal' $ fwdBody body addStm $ Let pat' aux $ Loop val_pats' (WhileLoop v) body' fwdStm (Let pat aux (Loop val_pats loop@(ForLoop i it bound) body)) = do pat' <- bundleNewPat pat val_pats' <- bundleNewList val_pats body' <- - localScope (scopeOfFParams (map fst val_pats) <> scopeOfLoopForm loop) . slocal' $ + localScope (scopeOfFParams (map fst val_pats') <> scopeOfLoopForm loop) . slocal' $ fwdBody body addStm $ Let pat' aux $ Loop val_pats' (ForLoop i it bound) body' fwdStm (Let pat aux (WithAcc inputs lam)) = do From 0a10a5c1ec756c26b26632f1fb7a1c945bea3183 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 19 Aug 2025 12:32:44 +0200 Subject: [PATCH 17/70] Fix vjp2_vec in interpreter. --- src/Language/Futhark/Interpreter.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index 50bb99404b..8f77d05f0b 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -2145,7 +2145,7 @@ initialCtx = def "vjp2_vec" = Just $ fun3 $ \f x seeds -> do v <- apply noLoc mempty f x dvs <- - toArray' (valueShape x) + toArray' (valueShape x) . map (project "1") <$> mapM (doVJP2 f x) (snd (fromArray seeds)) pure $ toTuple [v, dvs] def "acc" = Nothing From b21f5f61c9436dba59434aaa4d2f086e58622ca9 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 19 Aug 2025 15:54:32 +0200 Subject: [PATCH 18/70] Start work on vectorised reverse mode AD. --- futhark.cabal | 1 + src/Futhark/AD/Fwd.hs | 38 ++++++++++--------------------------- src/Futhark/AD/Rev.hs | 29 +++++++++++++++++++++------- src/Futhark/AD/Rev/Map.hs | 20 ++++++++++++++++++- src/Futhark/AD/Rev/Monad.hs | 14 ++++++++++---- src/Futhark/AD/Shared.hs | 29 ++++++++++++++++++++++++++++ src/Futhark/Pass/AD.hs | 2 +- tests/ad/vec/map0.fut | 13 ++++++++++--- 8 files changed, 102 insertions(+), 44 deletions(-) create mode 100644 src/Futhark/AD/Shared.hs diff --git a/futhark.cabal b/futhark.cabal index c54d094da6..14228e543e 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -116,6 +116,7 @@ library Futhark.Actions Futhark.AD.Derivatives Futhark.AD.Fwd + Futhark.AD.Shared Futhark.AD.Rev Futhark.AD.Rev.Loop Futhark.AD.Rev.Hist diff --git a/src/Futhark/AD/Fwd.hs b/src/Futhark/AD/Fwd.hs index 09c0e93bb9..18a4140fa1 100644 --- a/src/Futhark/AD/Fwd.hs +++ b/src/Futhark/AD/Fwd.hs @@ -3,12 +3,16 @@ module Futhark.AD.Fwd (fwdJVP) where import Control.Monad +import Control.Monad.Identity import Control.Monad.Reader import Control.Monad.State.Strict import Data.Bifunctor (bimap, second) +import Data.Functor.Product import Data.List.NonEmpty (NonEmpty (..)) import Data.Map qualified as M +import Data.Tuple (Solo (..), getSolo) import Futhark.AD.Derivatives +import Futhark.AD.Shared import Futhark.Analysis.PrimExp.Convert import Futhark.Builder import Futhark.IR.SOACS @@ -173,10 +177,6 @@ instance Tangent SubExpRes where tangent (SubExpRes cs se) = SubExpRes cs <$> tangent se bundleTan (SubExpRes cs se) = bimap (SubExpRes cs) (SubExpRes cs) <$> bundleTan se -asVName :: SubExp -> ADM VName -asVName (Var v) = pure v -asVName (Constant x) = letExp "v" $ BasicOp $ SubExp $ Constant x - withTan :: SubExp -> (SubExp -> ADM (Exp SOACS)) -> @@ -184,16 +184,7 @@ withTan :: withTan x f = do shape <- askShape x_tan <- tangent x - if shape == mempty - then f x_tan - else do - let w = shapeSize 0 shape - x_tan_v <- asVName x_tan - x_tan_p <- newParam "x_tanp" . rowType =<< lookupType x_tan_v - lam <- mkLambda [x_tan_p] $ do - fmap (subExpsRes . pure) . letSubExp "tan" - =<< f (Var (paramName x_tan_p)) - pure $ Op $ Screma w [x_tan_v] (mapSOAC lam) + mapNest shape (MkSolo x_tan) (f . getSolo) withTansI :: VName -> @@ -229,20 +220,11 @@ withTans :: ADM (Exp SOACS) withTans t x y f = do shape <- askShape - x_tan <- asVName =<< tangent x - y_tan <- asVName =<< tangent y - if shape == mempty - then toExp $ f (LeafExp x_tan t) (LeafExp y_tan t) - else do - let w = shapeSize 0 shape - x_tan_p <- newParam "x_tanp" . rowType =<< lookupType x_tan - y_tan_p <- newParam "y_tanp" . rowType =<< lookupType y_tan - lam <- mkLambda [x_tan_p, y_tan_p] $ do - fmap (subExpsRes . pure) . letSubExp "tan" <=< toExp $ - f - (LeafExp (paramName x_tan_p) t) - (LeafExp (paramName y_tan_p) t) - pure $ Op $ Screma w [x_tan, y_tan] (mapSOAC lam) + x_tan <- tangent x + y_tan <- tangent y + mapNest shape (Pair (Identity x_tan) (Identity y_tan)) $ \xy -> do + Pair (Identity x_tan_v) (Identity y_tan_v) <- traverse asVName xy + toExp $ f (LeafExp x_tan_v t) (LeafExp y_tan_v t) basicFwd :: Pat Type -> StmAux () -> BasicOp -> ADM () basicFwd pat aux op = do diff --git a/src/Futhark/AD/Rev.hs b/src/Futhark/AD/Rev.hs index 5a908c170e..f185b20110 100644 --- a/src/Futhark/AD/Rev.hs +++ b/src/Futhark/AD/Rev.hs @@ -13,10 +13,12 @@ import Control.Monad.Identity import Data.List ((\\)) import Data.List.NonEmpty (NonEmpty (..)) import Data.Map qualified as M +import Data.Tuple import Futhark.AD.Derivatives import Futhark.AD.Rev.Loop import Futhark.AD.Rev.Monad import Futhark.AD.Rev.SOAC +import Futhark.AD.Shared import Futhark.Analysis.PrimExp.Convert import Futhark.Builder import Futhark.IR.SOACS @@ -83,10 +85,20 @@ diffBasicOp pat aux e m = (wrt_x, wrt_y) = pdBinOp op (primExpFromSubExp t x) (primExpFromSubExp t y) - pat_adj' = primExpFromSubExp t $ Var pat_adj + adj_shape <- askShape + + adj_x <- letExp "binop_x_adj" + <=< mapNest adj_shape (MkSolo (Var pat_adj)) + $ \(MkSolo pat_adj') -> + let pat_adj'' = primExpFromSubExp t pat_adj' + in toExp $ pat_adj'' ~*~ wrt_x + + adj_y <- letExp "binop_y_adj" + <=< mapNest adj_shape (MkSolo (Var pat_adj)) + $ \(MkSolo pat_adj') -> + let pat_adj'' = primExpFromSubExp t pat_adj' + in toExp $ pat_adj'' ~*~ wrt_y - adj_x <- letExp "binop_x_adj" <=< toExp $ pat_adj' ~*~ wrt_x - adj_y <- letExp "binop_y_adj" <=< toExp $ pat_adj' ~*~ wrt_y updateSubExpAdj x adj_x updateSubExpAdj y adj_y -- @@ -364,11 +376,14 @@ diffLambda res_adjs get_adjs_for (Lambda params _ body) = ts' <- mapM lookupType get_adjs_for pure $ Lambda params ts' body' -revVJP :: (MonadFreshNames m) => Scope SOACS -> Lambda SOACS -> m (Lambda SOACS) -revVJP scope (Lambda params ts body) = - runADM . localScope (scope <> scopeOfLParams params) $ do +revVJP :: (MonadFreshNames m) => Scope SOACS -> Shape -> Lambda SOACS -> m (Lambda SOACS) +revVJP scope shape (Lambda params ts body) = do + runADM shape . localScope (scope <> scopeOfLParams params) $ do + adj_shape <- askShape params_adj <- forM (zip (map resSubExp (bodyResult body)) ts) $ \(se, t) -> - Param mempty <$> maybe (newVName "const_adj") adjVName (subExpVar se) <*> pure t + Param mempty + <$> maybe (newVName "const_adj") adjVName (subExpVar se) + <*> pure (t `arrayOfShape` adj_shape) body' <- localScope (scopeOfLParams params_adj) $ diff --git a/src/Futhark/AD/Rev/Map.hs b/src/Futhark/AD/Rev/Map.hs index 71bf13c7cf..7b4831d892 100644 --- a/src/Futhark/AD/Rev/Map.hs +++ b/src/Futhark/AD/Rev/Map.hs @@ -6,6 +6,7 @@ module Futhark.AD.Rev.Map (vjpMap) where import Control.Monad import Data.Bifunctor (first) +import Debug.Trace import Futhark.AD.Rev.Monad import Futhark.Analysis.PrimExp.Convert import Futhark.Builder @@ -75,6 +76,23 @@ withAcc inputs m = do subAD $ mkLambda (cert_params ++ acc_params) $ m $ map paramName acc_params letTupExp "withhacc_res" $ WithAcc inputs acc_lam +vecPerm :: Shape -> Type -> [Int] +vecPerm adj_shape t = + [shapeRank adj_shape] + ++ [0 .. shapeRank adj_shape - 1] + ++ [shapeRank adj_shape + 1 .. arrayRank t - 1] + +pushAdjShape :: VName -> ADM VName +pushAdjShape v = do + adj_shape <- askShape + v_t <- lookupType v + traceM $ unlines ["pushAdjShape", prettyString v, prettyString adj_shape, prettyString v_t, show [adj_shape == mempty, arrayShape v_t == adj_shape, isAcc v_t]] + if adj_shape == mempty || arrayShape v_t == adj_shape || isAcc v_t + then pure v + else do + let perm = vecPerm adj_shape v_t + letExp (baseString v <> "_tr") $ BasicOp $ Rearrange v perm + -- | Perform VJP on a Map. The 'Adj' list is the adjoints of the -- result of the map. vjpMap :: VjpOps -> [Adj] -> StmAux () -> SubExp -> Lambda SOACS -> [VName] -> ADM () @@ -161,7 +179,7 @@ vjpMap ops pat_adj aux w map_lam as = returnSweepCode $ do pat_adj_vals <- forM (zip pat_adj (lambdaReturnType map_lam)) $ \(adj, t) -> case t of Acc {} -> letExp "acc_adj_rep" . BasicOp . Replicate (Shape [w]) . Var =<< adjVal adj - _ -> adjVal adj + _ -> pushAdjShape =<< adjVal adj pat_adj_params <- mapM (newParam "map_adj_p" . rowType <=< lookupType) pat_adj_vals diff --git a/src/Futhark/AD/Rev/Monad.hs b/src/Futhark/AD/Rev/Monad.hs index 18e95efe2c..d670338c68 100644 --- a/src/Futhark/AD/Rev/Monad.hs +++ b/src/Futhark/AD/Rev/Monad.hs @@ -13,6 +13,7 @@ module Futhark.AD.Rev.Monad Adj (..), InBounds (..), Sparse (..), + askShape, adjFromParam, adjFromVar, lookupAdj, @@ -55,6 +56,7 @@ module Futhark.AD.Rev.Monad where import Control.Monad +import Control.Monad.Reader import Control.Monad.State.Strict import Data.Bifunctor (second) import Data.List (foldl') @@ -199,12 +201,13 @@ data RState = RState stateNameSource :: VNameSource } -newtype ADM a = ADM (BuilderT SOACS (State RState) a) +newtype ADM a = ADM (BuilderT SOACS (ReaderT Shape (State RState)) a) deriving ( Functor, Applicative, Monad, MonadState RState, + MonadReader Shape, MonadFreshNames, HasScope SOACS, LocalScope SOACS @@ -223,12 +226,15 @@ instance MonadFreshNames (State RState) where getNameSource = gets stateNameSource putNameSource src = modify (\env -> env {stateNameSource = src}) -runADM :: (MonadFreshNames m) => ADM a -> m a -runADM (ADM m) = +askShape :: ADM Shape +askShape = ADM $ lift ask + +runADM :: (MonadFreshNames m) => Shape -> ADM a -> m a +runADM shape (ADM m) = modifyNameSource $ \vn -> second stateNameSource $ runState - (fst <$> runBuilderT m mempty) + (runReaderT (fst <$> runBuilderT m mempty) shape) (RState mempty mempty mempty vn) adjVal :: Adj -> ADM VName diff --git a/src/Futhark/AD/Shared.hs b/src/Futhark/AD/Shared.hs new file mode 100644 index 0000000000..0feafac833 --- /dev/null +++ b/src/Futhark/AD/Shared.hs @@ -0,0 +1,29 @@ +-- | Various definitions used for both forward and reverse mode. +module Futhark.AD.Shared (asVName, mapNest) where + +import Control.Monad +import Data.Foldable +import Futhark.Construct +import Futhark.IR.SOACS + +asVName :: (MonadBuilder m) => SubExp -> m VName +asVName (Var v) = pure v +asVName (Constant x) = letExp "v" $ BasicOp $ SubExp $ Constant x + +mapNest :: + (MonadBuilder m, Rep m ~ SOACS, Traversable f) => + Shape -> + f SubExp -> + (f SubExp -> m (Exp SOACS)) -> + m (Exp SOACS) +mapNest shape x f = do + if shape == mempty + then f x + else do + let w = shapeSize 0 shape + x_v <- traverse asVName x + x_p <- traverse (newParam "xp" . rowType <=< lookupType) x_v + lam <- mkLambda (toList x_p) $ do + fmap (subExpsRes . pure) . letSubExp "tan" + =<< f (fmap (Var . paramName) x_p) + pure $ Op $ Screma w (toList x_v) (mapSOAC lam) diff --git a/src/Futhark/Pass/AD.hs b/src/Futhark/Pass/AD.hs index 5e53db75ec..396e3cd6ba 100644 --- a/src/Futhark/Pass/AD.hs +++ b/src/Futhark/Pass/AD.hs @@ -40,7 +40,7 @@ onStm mode scope (Let pat aux (Op (VJP shape args vec lam))) = do lam' <- onLambda mode scope lam if mode == All || lam == lam' then do - lam'' <- (`runReaderT` scope) . simplifyLambda =<< revVJP scope lam' + lam'' <- (`runReaderT` scope) . simplifyLambda =<< revVJP scope shape lam' runBuilderT_ (bindLambda pat aux lam'' $ args ++ vec) scope else pure $ oneStm $ Let pat aux $ Op $ VJP shape args vec lam' onStm mode scope (Let pat aux (Op (JVP shape args vec lam))) = do diff --git a/tests/ad/vec/map0.fut b/tests/ad/vec/map0.fut index 3764a378aa..8ae207d68e 100644 --- a/tests/ad/vec/map0.fut +++ b/tests/ad/vec/map0.fut @@ -1,13 +1,20 @@ -- == --- entry: fwd fwd_vec +-- entry: fwd_map fwd_vec rev_map rev_vec -- input { [1.0f32, 2.0, 3.0] [4f32, 5, 6] } -- output { [[1.0f32, 0.0, 0.0], [0.0f32, 2.0, 0.0], [0.0f32, 0.0, 3.0]] } def prim = map2 (f32.*) -entry fwd [n] (xs: [n]f32) (ys: [n]f32) = +entry fwd_map [n] (xs: [n]f32) (ys: [n]f32) = tabulate n (\i -> jvp (prim xs) ys (replicate n 0 with [i] = 1)) entry fwd_vec [n] (xs: [n]f32) (ys: [n]f32) = - let seeds = tabulate n (\i -> (replicate n 0 with [i] = 1)) + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) in jvp_vec (prim xs) ys seeds + +entry rev_map [n] (xs: [n]f32) (ys: [n]f32) = + transpose (tabulate n (\i -> vjp (prim xs) ys (replicate n 0 with [i] = 1))) + +entry rev_vec [n] (xs: [n]f32) (ys: [n]f32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in transpose (vjp_vec (prim xs) ys seeds) From ebfae4298c615c7ebce2488b8a8beff21da10ac5 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 20 Aug 2025 13:23:06 +0200 Subject: [PATCH 19/70] Support primitive functions properly. --- src/Futhark/AD/Fwd.hs | 36 ++++++++++++++++++++++++++++++------ src/Futhark/AD/Shared.hs | 2 +- tests/ad/vec/primfun.fut | 12 ++++++++++++ 3 files changed, 43 insertions(+), 7 deletions(-) create mode 100644 tests/ad/vec/primfun.fut diff --git a/src/Futhark/AD/Fwd.hs b/src/Futhark/AD/Fwd.hs index 18a4140fa1..e8a815d735 100644 --- a/src/Futhark/AD/Fwd.hs +++ b/src/Futhark/AD/Fwd.hs @@ -7,6 +7,7 @@ import Control.Monad.Identity import Control.Monad.Reader import Control.Monad.State.Strict import Data.Bifunctor (bimap, second) +import Data.Foldable import Data.Functor.Product import Data.List.NonEmpty (NonEmpty (..)) import Data.Map qualified as M @@ -166,7 +167,13 @@ instance Tangent VName where pure (v, v_tan) instance Tangent SubExp where - tangent (Constant c) = pure $ constant $ blankPrimValue $ primValueType c + tangent (Constant c) = do + tan_shape <- askShape + if tan_shape == mempty + then pure $ constant $ blankPrimValue pt + else letSubExp "const_implicit_tan" $ zeroExp $ Prim pt `arrayOfShape` tan_shape + where + pt = primValueType c tangent (Var v) = Var <$> tangent v bundleTan c@Constant {} = do c_tan <- tangent c @@ -226,6 +233,20 @@ withTans t x y f = do Pair (Identity x_tan_v) (Identity y_tan_v) <- traverse asVName xy toExp $ f (LeafExp x_tan_v t) (LeafExp y_tan_v t) +withAnyTans :: + (Traversable f) => + f SubExp -> + ([PrimExp VName] -> PrimExp VName) -> + ADM (Exp SOACS) +withAnyTans xs f = do + shape <- askShape + xs_tan <- traverse tangent xs + mapNest shape xs_tan $ \xs_tan' -> do + xs_tan'' <- forM xs_tan' $ \se -> do + ~(Prim t) <- subExpType se + pure $ primExpFromSubExp t se + toExp $ f $ toList xs_tan'' + basicFwd :: Pat Type -> StmAux () -> BasicOp -> ADM () basicFwd pat aux op = do pat_tan <- newTanPat pat @@ -470,11 +491,9 @@ fwdStm stm@(Let pat aux (BasicOp e)) = do -- XXX: this has to be too naive. unless (any isAcc $ patTypes pat) $ addStm stm basicFwd pat aux e -fwdStm stm@(Let pat _ (Apply f args _ _)) +fwdStm stm@(Let pat aux (Apply f args _ _)) | Just (ret, argts) <- M.lookup f builtInFunctions = do addStm stm - arg_tans <- - zipWith primExpFromSubExp argts <$> mapM (tangent . fst) args pat_tan <- newTanPat pat let arg_pes = zipWith primExpFromSubExp argts (map fst args) case pdBuiltin f arg_pes of @@ -492,8 +511,13 @@ fwdStm stm@(Let pat _ (Apply f args _ _)) _ -> error $ "fwdStm.convertTo: " ++ prettyString (f, tt, e_t) where e_t = primExpType e - zipWithM_ (letBindNames . pure) (patNames pat_tan) - =<< mapM toExp (zipWith (~*~) (map (convertTo ret) arg_tans) derivs) + + auxing aux . letBind pat_tan <=< withAnyTans (map fst args) $ + \arg_tans' -> + foldl1 (~+~) $ zipWith (~*~) (map (convertTo ret) arg_tans') derivs + +-- zipWithM_ (letBindNames . pure) (patNames pat_tan) +-- =<< mapM toExp (zipWith (~*~) (map (convertTo ret) arg_tans) derivs) fwdStm (Let pat aux (Match ses cases defbody (MatchDec ret ifsort))) = do cases' <- slocal' $ mapM (traverse fwdBody) cases defbody' <- slocal' $ fwdBody defbody diff --git a/src/Futhark/AD/Shared.hs b/src/Futhark/AD/Shared.hs index 0feafac833..cbdec4787b 100644 --- a/src/Futhark/AD/Shared.hs +++ b/src/Futhark/AD/Shared.hs @@ -8,7 +8,7 @@ import Futhark.IR.SOACS asVName :: (MonadBuilder m) => SubExp -> m VName asVName (Var v) = pure v -asVName (Constant x) = letExp "v" $ BasicOp $ SubExp $ Constant x +asVName (Constant x) = letExp "asv" $ BasicOp $ SubExp $ Constant x mapNest :: (MonadBuilder m, Rep m ~ SOACS, Traversable f) => diff --git a/tests/ad/vec/primfun.fut b/tests/ad/vec/primfun.fut new file mode 100644 index 0000000000..300f88469e --- /dev/null +++ b/tests/ad/vec/primfun.fut @@ -0,0 +1,12 @@ +-- == +-- entry: fwd_map fwd_vec +-- input { [1f32, 2f32, 3f32] } +-- output { [[0.5f32, 0.0, 0.0], [0.0f32, 0.35355338, 0.0], [0.0f32, 0.0, 0.28867513]] } + +def primal = map f32.sqrt + +entry fwd_map [n] (xs: [n]f32) = + tabulate n (\i -> jvp primal xs (replicate n 0 with [i] = 1)) + +entry fwd_vec [n] (xs: [n]f32) = + jvp_vec primal xs (tabulate n (\i -> replicate n 0 with [i] = 1)) From 168454160c937c6bb35b64f18077d9cb6f448ceb Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 21 Aug 2025 15:24:33 +0200 Subject: [PATCH 20/70] Make unops and primfuns work. --- src/Futhark/AD/Rev.hs | 18 +++++++++++------- tests/ad/vec/primfun.fut | 8 +++++++- 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/src/Futhark/AD/Rev.hs b/src/Futhark/AD/Rev.hs index f185b20110..c58be90f4e 100644 --- a/src/Futhark/AD/Rev.hs +++ b/src/Futhark/AD/Rev.hs @@ -69,11 +69,12 @@ diffBasicOp pat aux e m = returnSweepCode $ do let t = unOpType op - contrib <- do - let x_pe = primExpFromSubExp t x - pat_adj' = primExpFromSubExp t (Var pat_adj) - dx = pdUnOp op x_pe - letExp "contrib" <=< toExp $ pat_adj' ~*~ dx + + adj_shape <- askShape + + contrib <- letExp "unop_contrib" <=< mapNest adj_shape (MkSolo (Var pat_adj)) $ + \(MkSolo pat_adj') -> + toExp $ primExpFromSubExp t pat_adj' ~*~ pdUnOp op (primExpFromSubExp t x) updateSubExpAdj x contrib -- @@ -258,7 +259,6 @@ diffStm stm@(Let pat _ (Apply f args _ _)) m pat_adj <- lookupAdjVal =<< patName pat let arg_pes = zipWith primExpFromSubExp argts (map fst args) - pat_adj' = primExpFromSubExp ret (Var pat_adj) convert ft tt | ft == tt = id convert (IntType ft) (IntType tt) = ConvOpExp (SExt ft tt) @@ -267,13 +267,17 @@ diffStm stm@(Let pat _ (Apply f args _ _)) m convert (FloatType ft) Bool = ConvOpExp (FToB ft) convert ft tt = error $ "diffStm.convert: " ++ prettyString (f, ft, tt) + adj_shape <- askShape + contribs <- case pdBuiltin f arg_pes of Nothing -> error $ "No partial derivative defined for builtin function: " ++ prettyString f Just derivs -> forM (zip derivs argts) $ \(deriv, argt) -> - letExp "contrib" <=< toExp . convert ret argt $ pat_adj' ~*~ deriv + letExp "apply_contrib" <=< mapNest adj_shape (MkSolo (Var pat_adj)) $ + \(MkSolo pat_adj') -> + toExp $ convert ret argt $ primExpFromSubExp ret pat_adj' ~*~ deriv zipWithM_ updateSubExpAdj (map fst args) contribs diffStm stm@(Let pat _ (Match ses cases defbody _)) m = do diff --git a/tests/ad/vec/primfun.fut b/tests/ad/vec/primfun.fut index 300f88469e..3aecde8b94 100644 --- a/tests/ad/vec/primfun.fut +++ b/tests/ad/vec/primfun.fut @@ -1,5 +1,5 @@ -- == --- entry: fwd_map fwd_vec +-- entry: fwd_map fwd_vec rev_map rev_vec -- input { [1f32, 2f32, 3f32] } -- output { [[0.5f32, 0.0, 0.0], [0.0f32, 0.35355338, 0.0], [0.0f32, 0.0, 0.28867513]] } @@ -10,3 +10,9 @@ entry fwd_map [n] (xs: [n]f32) = entry fwd_vec [n] (xs: [n]f32) = jvp_vec primal xs (tabulate n (\i -> replicate n 0 with [i] = 1)) + +entry rev_map [n] (xs: [n]f32) = + tabulate n (\i -> vjp primal xs (replicate n 0 with [i] = 1)) + +entry rev_vec [n] (xs: [n]f32) = + vjp_vec primal xs (tabulate n (\i -> replicate n 0 with [i] = 1)) From 4819c5c1c2bf3743290bb5f83400eb222a974111 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 22 Aug 2025 16:04:59 +0200 Subject: [PATCH 21/70] Work on vectorised reductions. --- src/Futhark/AD/Fwd.hs | 21 +++---- src/Futhark/AD/Rev.hs | 14 ++--- src/Futhark/AD/Rev/Map.hs | 2 - src/Futhark/AD/Rev/Reduce.hs | 76 ++++++++++++++++-------- src/Futhark/AD/Shared.hs | 8 ++- src/Futhark/Construct.hs | 8 +++ src/Futhark/IR/TypeCheck.hs | 2 +- tests/ad/vec/{reduce.fut => reduce0.fut} | 8 ++- tests/ad/vec/reduce1.fut | 16 +++++ tests/ad/vec/reduce2.fut | 18 ++++++ tests/ad/vec/reduce3.fut | 71 ++++++++++++++++++++++ 11 files changed, 192 insertions(+), 52 deletions(-) rename tests/ad/vec/{reduce.fut => reduce0.fut} (65%) create mode 100644 tests/ad/vec/reduce1.fut create mode 100644 tests/ad/vec/reduce2.fut create mode 100644 tests/ad/vec/reduce3.fut diff --git a/src/Futhark/AD/Fwd.hs b/src/Futhark/AD/Fwd.hs index e8a815d735..7a87ff05e6 100644 --- a/src/Futhark/AD/Fwd.hs +++ b/src/Futhark/AD/Fwd.hs @@ -367,10 +367,7 @@ zeroFromSubExp (Var v) = do letExp "zero" $ zeroExp t vecPerm :: Shape -> Type -> [Int] -vecPerm tan_shape t = - [shapeRank tan_shape] - ++ [0 .. shapeRank tan_shape - 1] - ++ [shapeRank tan_shape + 1 .. arrayRank t - 1] +vecPerm = auxPerm pushTanShape :: VName -> ADM VName pushTanShape v = do @@ -387,14 +384,15 @@ soacInputsWithTangents xs = do xs_tans <- mapM (pushTanShape <=< tangent) xs pure $ interleave xs xs_tans -soacResPat :: Pat Type -> ADM (Pat Type, [(Pat Type, VName)]) -soacResPat (Pat pes) = do +soacResPat :: Int -> Int -> Pat Type -> ADM (Pat Type, [(Pat Type, VName)]) +soacResPat scan_res red_res (Pat pes) = do pes_tan <- mapM newTan pes - bimap (Pat . interleave pes) mconcat . unzip <$> mapM tweakPatElem pes_tan + bimap (Pat . interleave pes) mconcat . unzip <$> zipWithM tweakPatElem [0 ..] pes_tan where - tweakPatElem pe@(PatElem v v_t) = do + isRedRes i = i >= scan_res && i < scan_res + red_res + tweakPatElem i pe@(PatElem v v_t) = do tan_shape <- askShape - if tan_shape == mempty || arrayShape v_t == tan_shape || isAcc v_t + if isRedRes i || tan_shape == mempty || arrayShape v_t == tan_shape || isAcc v_t then pure (pe, []) else do let perm = vecPerm tan_shape v_t @@ -403,7 +401,7 @@ soacResPat (Pat pes) = do fwdSOAC :: Pat Type -> StmAux () -> SOAC SOACS -> ADM () fwdSOAC pat aux (Screma size xs (ScremaForm f scs reds)) = do - (pat', to_transpose) <- soacResPat pat + (pat', to_transpose) <- soacResPat (scanResults scs) (redResults reds) pat xs' <- soacInputsWithTangents xs f' <- fwdLambda f scs' <- mapM fwdScan scs @@ -515,9 +513,6 @@ fwdStm stm@(Let pat aux (Apply f args _ _)) auxing aux . letBind pat_tan <=< withAnyTans (map fst args) $ \arg_tans' -> foldl1 (~+~) $ zipWith (~*~) (map (convertTo ret) arg_tans') derivs - --- zipWithM_ (letBindNames . pure) (patNames pat_tan) --- =<< mapM toExp (zipWith (~*~) (map (convertTo ret) arg_tans) derivs) fwdStm (Let pat aux (Match ses cases defbody (MatchDec ret ifsort))) = do cases' <- slocal' $ mapM (traverse fwdBody) cases defbody' <- slocal' $ fwdBody defbody diff --git a/src/Futhark/AD/Rev.hs b/src/Futhark/AD/Rev.hs index c58be90f4e..4f9e02c828 100644 --- a/src/Futhark/AD/Rev.hs +++ b/src/Futhark/AD/Rev.hs @@ -141,10 +141,10 @@ diffBasicOp pat aux e m = -- Rearrange arr perm -> do (_pat_v, pat_adj) <- commonBasicOp pat aux e m + r <- shapeRank <$> askShape returnSweepCode $ - void $ - updateAdj arr <=< letExp "adj_rearrange" . BasicOp $ - Rearrange pat_adj (rearrangeInverse perm) + void . updateAdj arr <=< letExp "adj_rearrange" . BasicOp $ + Rearrange pat_adj ([0 .. r - 1] <> map (+ r) (rearrangeInverse perm)) -- Replicate (Shape []) (Var se) -> do (_pat_v, pat_adj) <- commonBasicOp pat aux e m @@ -374,11 +374,9 @@ diffBody res_adjs get_adjs_for (Body () stms res) = subAD $ diffLambda :: [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS) diffLambda res_adjs get_adjs_for (Lambda params _ body) = - localScope (scopeOfLParams params) $ do - Body () stms res <- diffBody res_adjs get_adjs_for body - let body' = Body () stms $ takeLast (length get_adjs_for) res - ts' <- mapM lookupType get_adjs_for - pure $ Lambda params ts' body' + mkLambda params $ do + res <- bodyBind =<< diffBody res_adjs get_adjs_for body + pure $ takeLast (length get_adjs_for) res revVJP :: (MonadFreshNames m) => Scope SOACS -> Shape -> Lambda SOACS -> m (Lambda SOACS) revVJP scope shape (Lambda params ts body) = do diff --git a/src/Futhark/AD/Rev/Map.hs b/src/Futhark/AD/Rev/Map.hs index 7b4831d892..78f1db4920 100644 --- a/src/Futhark/AD/Rev/Map.hs +++ b/src/Futhark/AD/Rev/Map.hs @@ -6,7 +6,6 @@ module Futhark.AD.Rev.Map (vjpMap) where import Control.Monad import Data.Bifunctor (first) -import Debug.Trace import Futhark.AD.Rev.Monad import Futhark.Analysis.PrimExp.Convert import Futhark.Builder @@ -86,7 +85,6 @@ pushAdjShape :: VName -> ADM VName pushAdjShape v = do adj_shape <- askShape v_t <- lookupType v - traceM $ unlines ["pushAdjShape", prettyString v, prettyString adj_shape, prettyString v_t, show [adj_shape == mempty, arrayShape v_t == adj_shape, isAcc v_t]] if adj_shape == mempty || arrayShape v_t == adj_shape || isAcc v_t then pure v else do diff --git a/src/Futhark/AD/Rev/Reduce.hs b/src/Futhark/AD/Rev/Reduce.hs index fd745c728e..73e645e910 100644 --- a/src/Futhark/AD/Rev/Reduce.hs +++ b/src/Futhark/AD/Rev/Reduce.hs @@ -9,7 +9,9 @@ module Futhark.AD.Rev.Reduce where import Control.Monad +import Data.Tuple import Futhark.AD.Rev.Monad +import Futhark.AD.Shared import Futhark.Analysis.PrimExp.Convert import Futhark.Builder import Futhark.IR.SOACS @@ -78,12 +80,19 @@ diffReduce _ops [adj] w [a] red | Just [(op, _, _, _)] <- lamIsBinOp $ redLambda red, isAdd op = do adj_rep <- - letExp (baseString adj <> "_rep") $ - BasicOp $ - Replicate (Shape [w]) $ - Var adj + transposeIfNeeded <=< letExp (baseString adj <> "_rep") $ + BasicOp (Replicate (Shape [w]) (Var adj)) void $ updateAdj a adj_rep where + transposeIfNeeded v = do + adj_shape <- askShape + if adj_shape == mempty + then pure v + else do + v_t <- lookupType v + let perm = [1 .. shapeRank adj_shape] ++ [0] ++ [shapeRank adj_shape + 1 .. arrayRank v_t - 1] + letExp (baseString v <> "_tr") $ BasicOp $ Rearrange v perm + isAdd FAdd {} = True isAdd Add {} = True isAdd _ = False @@ -117,10 +126,19 @@ diffReduce ops pat_adj w as red = do f_adj <- vjpLambda ops (map adjFromVar pat_adj) as_params f - as_adj <- letTupExp "adjs" $ Op $ Screma w (ls ++ as ++ rs) (mapSOAC f_adj) + as_adj <- + letTupExp "red_contribs" $ Op $ Screma w (ls ++ as ++ rs) (mapSOAC f_adj) - zipWithM_ updateAdj as as_adj + zipWithM_ updateAdj as =<< mapM transposeIfNeeded as_adj where + transposeIfNeeded v = do + adj_shape <- askShape + if adj_shape == mempty + then pure v + else do + v_t <- lookupType v + letExp (baseString v <> "_tr") $ BasicOp $ Rearrange v (auxPerm adj_shape v_t) + renameRed (Reduce comm lam nes) = Reduce comm <$> renameLambda lam <*> pure nes @@ -250,7 +268,8 @@ diffMulReduce :: VjpOps -> VName -> StmAux () -> SubExp -> BinOp -> SubExp -> VName -> ADM () -> ADM () diffMulReduce _ops x aux w mul ne as m = do let t = binOpType mul - let const_zero = eSubExp $ Constant $ blankPrimValue t + let zero = Constant $ blankPrimValue t + const_zero = eSubExp zero a_param <- newParam "a" $ Prim t map_lam <- @@ -291,38 +310,43 @@ diffMulReduce _ops x aux w mul ne as m = do x_adj <- lookupAdjVal x + adj_shape <- askShape + + zero_contrib <- letExp "zero_contrib" $ BasicOp $ Replicate adj_shape zero + a_param_rev <- newParam "a" $ Prim t map_lam_rev <- mkLambda [a_param_rev] $ fmap varsRes . letTupExp "adj_res" =<< eIf (toExp $ 0 .==. le64 zr_count) - ( eBody $ - pure $ - eBinOp mul (eSubExp $ Var x_adj) $ - eBinOp (getDiv t) (eSubExp $ Var nz_prods) $ - eParam a_param_rev + ( eBody + [ mapNest adj_shape (MkSolo (Var x_adj)) $ \(MkSolo x_adj') -> + eBinOp mul (eSubExp x_adj') $ + eBinOp (getDiv t) (eVar nz_prods) $ + eParam a_param_rev + ] ) - ( eBody $ - pure $ - eIf + ( eBody + [ eIf (toExp $ 1 .==. le64 zr_count) - ( eBody $ - pure $ - eIf + ( eBody + [ eIf (eCmpOp (CmpEq t) (eParam a_param_rev) const_zero) - ( eBody $ - pure $ - eBinOp mul (eSubExp $ Var x_adj) $ - eSubExp $ - Var nz_prods + ( eBody + [ mapNest adj_shape (MkSolo (Var x_adj)) $ + \(MkSolo x_adj') -> + eBinOp mul (eSubExp x_adj') $ eVar nz_prods + ] ) - (eBody $ pure const_zero) + (eBody [eVar zero_contrib]) + ] ) - (eBody $ pure const_zero) + (eBody [eVar zero_contrib]) + ] ) - as_adjup <- letExp "adjs" $ Op $ Screma w [as] $ mapSOAC map_lam_rev + as_adjup <- letExp "prod_contrib" $ Op $ Screma w [as] $ mapSOAC map_lam_rev updateAdj as as_adjup where diff --git a/src/Futhark/AD/Shared.hs b/src/Futhark/AD/Shared.hs index cbdec4787b..6a34253efa 100644 --- a/src/Futhark/AD/Shared.hs +++ b/src/Futhark/AD/Shared.hs @@ -1,11 +1,17 @@ -- | Various definitions used for both forward and reverse mode. -module Futhark.AD.Shared (asVName, mapNest) where +module Futhark.AD.Shared (auxPerm, asVName, mapNest) where import Control.Monad import Data.Foldable import Futhark.Construct import Futhark.IR.SOACS +auxPerm :: Shape -> Type -> [Int] +auxPerm aux_shape t = + [shapeRank aux_shape] + ++ [0 .. shapeRank aux_shape - 1] + ++ [shapeRank aux_shape + 1 .. arrayRank t - 1] + asVName :: (MonadBuilder m) => SubExp -> m VName asVName (Var v) = pure v asVName (Constant x) = letExp "asv" $ BasicOp $ SubExp $ Constant x diff --git a/src/Futhark/Construct.hs b/src/Futhark/Construct.hs index 109da7d0c5..3cd2503a7a 100644 --- a/src/Futhark/Construct.hs +++ b/src/Futhark/Construct.hs @@ -68,6 +68,7 @@ module Futhark.Construct -- * Monadic expression builders eSubExp, + eVar, eParam, eMatch', eMatch, @@ -204,6 +205,13 @@ eSubExp :: m (Exp (Rep m)) eSubExp = pure . BasicOp . SubExp +-- | Turn a variable into a monad expression, through 'eSubExp'. +eVar :: + (MonadBuilder m) => + VName -> + m (Exp (Rep m)) +eVar = eSubExp . Var + -- | Treat a parameter as a monadic expression. eParam :: (MonadBuilder m) => diff --git a/src/Futhark/IR/TypeCheck.hs b/src/Futhark/IR/TypeCheck.hs index 9c7df8c541..50e3acb8bc 100644 --- a/src/Futhark/IR/TypeCheck.hs +++ b/src/Futhark/IR/TypeCheck.hs @@ -115,7 +115,7 @@ instance (Checkable rep) => Show (ErrorCase rep) where show (InvalidPatError pat t desc) = "Pat\n" ++ prettyString pat - ++ "\ncannot match value of type\n" + ++ "\ncannot match expression of type\n" ++ T.unpack (prettyTupleLines t) ++ end where diff --git a/tests/ad/vec/reduce.fut b/tests/ad/vec/reduce0.fut similarity index 65% rename from tests/ad/vec/reduce.fut rename to tests/ad/vec/reduce0.fut index a9b970c3aa..0ad76d8189 100644 --- a/tests/ad/vec/reduce.fut +++ b/tests/ad/vec/reduce0.fut @@ -1,5 +1,5 @@ -- == --- entry: fwd_vec fwd_map +-- entry: fwd_vec fwd_map rev_vec -- input { [1f32, 2f32, 3f32] } -- output { [6f32, 3f32, 2f32] } @@ -13,3 +13,9 @@ entry fwd_vec (xs: []f32) : []f32 = entry fwd_map (xs: []f32) : []f32 = map (\i -> jvp f xs (map (\j -> f32.bool (i == j)) (indices xs))) (indices xs) + +-- No rev_map because it would just get optimised away. The rev_vec is pointless +-- enough already. + +entry rev_vec (xs: []f32) : []f32 = + head (vjp_vec f xs [1]) diff --git a/tests/ad/vec/reduce1.fut b/tests/ad/vec/reduce1.fut new file mode 100644 index 0000000000..fa5921d42f --- /dev/null +++ b/tests/ad/vec/reduce1.fut @@ -0,0 +1,16 @@ +-- Reduce with addition. +-- == +-- tags { autodiff } +-- entry: fwd_map fwd_vec +-- input { [1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32] } +-- output { [1.0f32, 1.0, 1.0, 1.0, 1.0] } + +entry fwd_map [n] (a: [n]f32) = + tabulate n (\i -> jvp (reduce (+) 0) a (replicate n 0 with [i] = 1)) + +entry fwd_vec [n] (a: [n]f32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in jvp_vec (reduce (+) 0) a seeds + +entry rev_vec [n] (a: [n]f32) = + head (vjp_vec (reduce (+) 0) a [1]) diff --git a/tests/ad/vec/reduce2.fut b/tests/ad/vec/reduce2.fut new file mode 100644 index 0000000000..d070b6eb1a --- /dev/null +++ b/tests/ad/vec/reduce2.fut @@ -0,0 +1,18 @@ +-- Reduce with vectorised addition. +-- == +-- tags { autodiff } +-- entry: fwd_map fwd_vec rev_vec +-- input { [[1f32,1f32],[2f32,2f32],[3f32,3f32],[4f32,4f32],[5f32,5f32]] } +-- output { [[1.0f32, 1.0], [1.0f32, 1.0], [1.0f32, 1.0], [1.0f32, 1.0], [1.0f32, 1.0]] } + +def primal [n] [k] (a: [n][k]f32) = + reduce (map2 (+)) (replicate k 0) a + +entry fwd_map [n] [k] (a: [n][k]f32) = + tabulate n (\i -> jvp primal a (replicate n (replicate k 0) with [i] = replicate k 1)) + +entry fwd_vec [n] [k] (a: [n][k]f32) = + jvp_vec primal a (tabulate n (\i -> (replicate n (replicate k 0) with [i] = replicate k 1))) + +entry rev_vec [n] [k] (a: [n][k]f32) = + head (vjp_vec primal a [replicate k 1]) diff --git a/tests/ad/vec/reduce3.fut b/tests/ad/vec/reduce3.fut new file mode 100644 index 0000000000..25286aa9fb --- /dev/null +++ b/tests/ad/vec/reduce3.fut @@ -0,0 +1,71 @@ +-- Reduce with 2x2 matrix multiplication. +-- == +-- tags { autodiff } +-- entry: fwd_map rev_map fwd_vec rev_vec +-- input { [[1f32,2f32,3f32,4f32], [4f32,3f32,2f32,1f32], [1f32,2f32,3f32,4f32], [4f32,3f32,2f32,1f32]] } +-- output { +-- [[[92.0f32, 36.0, 0.0, 0.0], +-- [8.0f32, 20.0, 16.0, 40.0], +-- [32.0f32, 16.0, 20.0, 10.0], +-- [23.0f32, 0.0, 36.0, 0.0]], +-- [[59.0f32, 23.0, 0.0, 0.0], +-- [5.0f32, 13.0, 10.0, 26.0], +-- [24.0f32, 8.0, 15.0, 5.0], +-- [0.0f32, 23.0, 0.0, 36.0]], +-- [[0.0f32, 0.0, 92.0, 36.0], +-- [24.0f32, 60.0, 32.0, 80.0], +-- [80.0f32, 40.0, 52.0, 26.0], +-- [59.0f32, 0.0, 92.0, 0.0]], +-- [[0.0f32, 0.0, 59.0, 23.0], +-- [15.0f32, 39.0, 20.0, 52.0], +-- [60.0f32, 20.0, 39.0, 13.0], +-- [0.0f32, 59.0, 0.0, 92.0]]] +-- } + +def mm2by2 (a1: f32, b1: f32, c1: f32, d1: f32) + (a2: f32, b2: f32, c2: f32, d2: f32) = + ( a1 * a2 + b1 * c2 + , a1 * b2 + b1 * d2 + , c1 * a2 + d1 * c2 + , c1 * b2 + d1 * d2 + ) + +def primal [n] (xs: [n](f32, f32, f32, f32)) = + reduce mm2by2 (1, 0, 0, 1) xs + +def fromarr = \(x: [4]f32) -> (x[0], x[1], x[2], x[3]) + +def fromarrs = map fromarr +def toarrs = map (\(a, b, c, d) -> [a, b, c, d]) + +def onehot_1d n x = + tabulate n (\i -> f32.bool (i == x)) + +def onehot_2d n m x y = + tabulate_2d n m (\i j -> f32.bool ((i, j) == (x, y))) + +entry fwd_map [n] (input: [n][4]f32) : [4][n][4]f32 = + let input = fromarrs input + in tabulate (n * 4) (\i -> jvp primal input (fromarrs (onehot_2d n 4 (i / 4) (i % 4)))) + |> toarrs + |> transpose + |> map unflatten + +entry fwd_vec [n] (input: [n][4]f32) : [4][n][4]f32 = + let input = fromarrs input + let seeds = tabulate (n * 4) (\i -> (fromarrs (onehot_2d n 4 (i / 4) (i % 4)))) + in jvp_vec primal input seeds + |> toarrs + |> transpose + |> map unflatten + +entry rev_map [n] (input: [n][4]f32) : [4][n][4]f32 = + let input = fromarrs input + in tabulate 4 (\i -> vjp primal input (fromarr (onehot_1d 4 i))) + |> map toarrs + +entry rev_vec [n] (input: [n][4]f32) : [4][n][4]f32 = + let input = fromarrs input + let seeds = tabulate 4 (\i -> fromarr (onehot_1d 4 i)) + in vjp_vec primal input seeds + |> map toarrs From 512c0ff0168d2aedd1bbd838e7f71c8a2faedca3 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 26 Aug 2025 12:01:39 +0200 Subject: [PATCH 22/70] Start on scan. --- src/Futhark/AD/Rev.hs | 2 +- src/Futhark/AD/Rev/Monad.hs | 41 ++++++++++++++++++++++++++-- src/Futhark/AD/Rev/Scan.hs | 6 ++-- src/Futhark/AD/Shared.hs | 20 +++++++++++++- tests/ad/vec/{scan.fut => scan0.fut} | 12 +++++++- 5 files changed, 72 insertions(+), 9 deletions(-) rename tests/ad/vec/{scan.fut => scan0.fut} (54%) diff --git a/src/Futhark/AD/Rev.hs b/src/Futhark/AD/Rev.hs index 4f9e02c828..f3995c9db0 100644 --- a/src/Futhark/AD/Rev.hs +++ b/src/Futhark/AD/Rev.hs @@ -384,7 +384,7 @@ revVJP scope shape (Lambda params ts body) = do adj_shape <- askShape params_adj <- forM (zip (map resSubExp (bodyResult body)) ts) $ \(se, t) -> Param mempty - <$> maybe (newVName "const_adj") adjVName (subExpVar se) + <$> maybe (newVName "const_res_adj") adjVName (subExpVar se) <*> pure (t `arrayOfShape` adj_shape) body' <- diff --git a/src/Futhark/AD/Rev/Monad.hs b/src/Futhark/AD/Rev/Monad.hs index d670338c68..882190614c 100644 --- a/src/Futhark/AD/Rev/Monad.hs +++ b/src/Futhark/AD/Rev/Monad.hs @@ -52,6 +52,8 @@ module Futhark.AD.Rev.Monad lookupLoopTape, substLoopTape, renameLoopTape, + -- + locallyNonvectorised, ) where @@ -62,6 +64,7 @@ import Data.Bifunctor (second) import Data.List (foldl') import Data.Map qualified as M import Data.Maybe +import Futhark.AD.Shared import Futhark.Analysis.Alias qualified as Alias import Futhark.Analysis.PrimExp.Convert import Futhark.Builder @@ -238,7 +241,7 @@ runADM shape (ADM m) = (RState mempty mempty mempty vn) adjVal :: Adj -> ADM VName -adjVal (AdjVal se) = letExp "const_adj" $ BasicOp $ SubExp se +adjVal (AdjVal se) = letExp "const_val_adj" $ BasicOp $ SubExp se adjVal (AdjSparse sparse) = sparseArray sparse adjVal (AdjZero shape t) = zeroArray shape $ Prim t @@ -396,6 +399,7 @@ vecOpExp bop x y = do lookupAdj :: VName -> ADM Adj lookupAdj v = do maybeAdj <- gets $ M.lookup v . stateAdjs + adj_shape <- askShape case maybeAdj of Nothing -> do v_t <- lookupType v @@ -403,7 +407,7 @@ lookupAdj v = do Acc _ shape [Prim t] _ -> pure $ AdjZero shape t Acc _ shape [t] _ -> pure $ AdjZero (shape <> arrayShape t) (elemType t) Acc {} -> error $ "lookupAdj: Non-singleton accumulator adjoint: " <> prettyString v_t - _ -> pure $ AdjZero (arrayShape v_t) (elemType v_t) + _ -> pure $ AdjZero (adj_shape <> arrayShape v_t) (elemType v_t) Just v_adj -> pure v_adj lookupAdjVal :: VName -> ADM VName @@ -562,6 +566,39 @@ substLoopTape v v' = mapM_ (setLoopTape v') =<< lookupLoopTape v renameLoopTape :: Substitutions -> ADM () renameLoopTape = mapM_ (uncurry substLoopTape) . M.toList +-- | Disable vectorised AD within the provided action. This results in a map +-- that computes each adjoint explicitly, then assembles the resulting adjoint +-- vectors. This is useful for constructs (such as scans) where vectorised AD is +-- impractical or inefficient. +locallyNonvectorised :: + (FreeIn e) => + -- | Something that represents all the free variables used in the action. + -- Usually just an expression or statement. + e -> + ADM () -> + ADM () +locallyNonvectorised e m = do + adj_shape <- askShape + if adj_shape == mempty + then m + else do + -- We map over all adjoints of free variables in 'e'. To avoid clutter, we + -- only consider those that actually have known nonzero adjoints. + e_adjs <- filterM knownAdjoint e_free + e_adjs_vals <- mapM lookupAdjVal e_adjs + e_free_adjs <- mkMap "nonvec_adj" e_adjs_vals $ \e_adjs_vals' -> do + zipWithM_ insAdj e_adjs e_adjs_vals' + local (const mempty) m + mapM lookupAdjVal e_free + zipWithM_ insAdj e_free e_free_adjs + where + e_free = namesToList $ freeIn e + knownAdjoint v = do + v_adj <- lookupAdj v + pure $ case v_adj of + AdjZero {} -> False + _ -> True + -- Note [Consumption] -- -- Parts of this transformation depends on duplicating computation. diff --git a/src/Futhark/AD/Rev/Scan.hs b/src/Futhark/AD/Rev/Scan.hs index d7f2279664..a7e7071397 100644 --- a/src/Futhark/AD/Rev/Scan.hs +++ b/src/Futhark/AD/Rev/Scan.hs @@ -382,7 +382,7 @@ finalMapPPAD ops as scan = do eLambda op_bar_2 $ toExp . Var . paramName <$> par_y_right ++ par_a diffScan :: VjpOps -> [VName] -> SubExp -> [VName] -> Scan SOACS -> ADM () -diffScan ops ys w as scan = do +diffScan ops ys w as scan = locallyNonvectorised (ys, scan, as) $ do -- ys ~ results of scan, w ~ size of input array, as ~ (unzipped) -- arrays, scan ~ scan: operator with ne scan_case <- identifyCase ops $ scanLambda scan @@ -411,14 +411,12 @@ diffScan ops ys w as scan = do map1_lam <- mkScanFusedMapLam ops w (scanLambda scan) as ys ys_adj sc d scans_lin_fun_o <- mkScanLinFunO (head as_ts) sc scan_lams <- mkScans (specialScans sc) scans_lin_fun_o - iota <- - letExp "iota" $ BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64 + iota <- letExp "iota" $ iota64 w r_scan <- letTupExp "adj_ctrb_scan" . Op . Screma w [iota] $ scanomapSOAC scan_lams map1_lam mkScanFinalMap ops w (scanLambda scan) as ys (splitScanRes sc r_scan d) -- Goal: calculate as_contribs in new way - -- zipWithM_ updateAdj as as_contribs -- as_bar += new adjoint zipWithM_ updateAdj as as_contribs where mkScans :: Int -> Scan SOACS -> ADM [Scan SOACS] diff --git a/src/Futhark/AD/Shared.hs b/src/Futhark/AD/Shared.hs index 6a34253efa..12673bf656 100644 --- a/src/Futhark/AD/Shared.hs +++ b/src/Futhark/AD/Shared.hs @@ -1,5 +1,11 @@ -- | Various definitions used for both forward and reverse mode. -module Futhark.AD.Shared (auxPerm, asVName, mapNest) where +module Futhark.AD.Shared + ( auxPerm, + asVName, + mapNest, + mkMap, + ) +where import Control.Monad import Data.Foldable @@ -33,3 +39,15 @@ mapNest shape x f = do fmap (subExpsRes . pure) . letSubExp "tan" =<< f (fmap (Var . paramName) x_p) pure $ Op $ Screma w (toList x_v) (mapSOAC lam) + +mkMap :: + (MonadBuilder m, Rep m ~ SOACS, Traversable f) => + String -> + f VName -> + (f VName -> m [VName]) -> + m [VName] +mkMap desc arrs f = do + w <- arraySize 0 <$> lookupType (head $ toList arrs) + x_p <- traverse (newParam "xp" . rowType <=< lookupType) arrs + lam <- mkLambda (toList x_p) $ varsRes <$> f (fmap paramName x_p) + letTupExp desc $ Op $ Screma w (toList arrs) (mapSOAC lam) diff --git a/tests/ad/vec/scan.fut b/tests/ad/vec/scan0.fut similarity index 54% rename from tests/ad/vec/scan.fut rename to tests/ad/vec/scan0.fut index b5d7e88d07..97d3c005b3 100644 --- a/tests/ad/vec/scan.fut +++ b/tests/ad/vec/scan0.fut @@ -1,5 +1,5 @@ -- == --- entry: fwd_vec fwd_map +-- entry: fwd_vec fwd_map rev_map rev_vec -- input { [1f32, 2f32, 3f32] } -- output { [[1f32, 2.0, 6.0], [0f32, 1.0, 3.0], [0f32, 0.0, 2.0]] } @@ -13,3 +13,13 @@ entry fwd_vec (xs: []f32) : [][]f32 = entry fwd_map (xs: []f32) : [][]f32 = map (\i -> jvp f xs (map (\j -> f32.bool (i == j)) (indices xs))) (indices xs) + +entry rev_map (xs: []f32) : [][]f32 = + map (\i -> vjp f xs (map (\j -> f32.bool (i == j)) (indices xs))) + (indices xs) + |> transpose + +entry rev_vec (xs: []f32) : [][]f32 = + let seeds = map (\i -> map (\j -> f32.bool (i == j)) (indices xs)) (indices xs) + in vjp_vec f xs seeds + |> transpose From bee9ae3d6c048daef50184b5c7fe528553c7960b Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 26 Aug 2025 14:21:40 +0200 Subject: [PATCH 23/70] More work. --- src/Futhark/AD/Rev/Scan.hs | 4 ++-- tests/ad/vec/scan0.fut | 1 + tests/ad/vec/scan1.fut | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 35 insertions(+), 2 deletions(-) create mode 100644 tests/ad/vec/scan1.fut diff --git a/src/Futhark/AD/Rev/Scan.hs b/src/Futhark/AD/Rev/Scan.hs index 70607efa2d..151dce0be9 100644 --- a/src/Futhark/AD/Rev/Scan.hs +++ b/src/Futhark/AD/Rev/Scan.hs @@ -437,7 +437,7 @@ diffScanVec :: [VName] -> ADM () -> ADM () -diffScanVec ops ys aux w lam ne as m = do +diffScanVec ops ys aux w lam ne as m = locallyNonvectorised (ys, lam, as) $ do stmts <- collectStms_ $ do rank <- arrayRank <$> lookupType (head as) let rear = [1, 0] ++ drop 2 [0 .. rank - 1] @@ -468,7 +468,7 @@ diffScanVec ops ys aux w lam ne as m = do foldr (vjpStm ops) m stmts diffScanAdd :: VjpOps -> VName -> SubExp -> Lambda SOACS -> SubExp -> VName -> ADM () -diffScanAdd _ops ys n lam' ne as = do +diffScanAdd _ops ys n lam' ne as = locallyNonvectorised (ys, lam', as) $ do lam <- renameLambda lam' ys_bar <- lookupAdjVal ys diff --git a/tests/ad/vec/scan0.fut b/tests/ad/vec/scan0.fut index 97d3c005b3..8bf83a5e70 100644 --- a/tests/ad/vec/scan0.fut +++ b/tests/ad/vec/scan0.fut @@ -1,4 +1,5 @@ -- == +-- tags { autodiff } -- entry: fwd_vec fwd_map rev_map rev_vec -- input { [1f32, 2f32, 3f32] } -- output { [[1f32, 2.0, 6.0], [0f32, 1.0, 3.0], [0f32, 0.0, 2.0]] } diff --git a/tests/ad/vec/scan1.fut b/tests/ad/vec/scan1.fut new file mode 100644 index 0000000000..9870b92df0 --- /dev/null +++ b/tests/ad/vec/scan1.fut @@ -0,0 +1,32 @@ +-- Scan with addition. +-- == +-- tags { autodiff } +-- entry: fwd_vec fwd_map rev_map rev_vec +-- input { [1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32] } +-- output { [[1f32, 1f32, 1f32, 1f32, 1f32], +-- [0f32, 1f32, 1f32, 1f32, 1f32], +-- [0f32, 0f32, 1f32, 1f32, 1f32], +-- [0f32, 0f32, 0f32, 1f32, 1f32], +-- [0f32, 0f32, 0f32, 0f32, 1f32]] +-- } + +def f (xs: []f32) = scan (+) 0 xs + +entry fwd_vec (xs: []f32) : [][]f32 = + let seeds = + map (\i -> map (\j -> f32.bool (i == j)) (indices xs)) (indices xs) + in (jvp2_vec f xs seeds).1 + +entry fwd_map (xs: []f32) : [][]f32 = + map (\i -> jvp f xs (map (\j -> f32.bool (i == j)) (indices xs))) + (indices xs) + +entry rev_map (xs: []f32) : [][]f32 = + map (\i -> vjp f xs (map (\j -> f32.bool (i == j)) (indices xs))) + (indices xs) + |> transpose + +entry rev_vec (xs: []f32) : [][]f32 = + let seeds = map (\i -> map (\j -> f32.bool (i == j)) (indices xs)) (indices xs) + in vjp_vec f xs seeds + |> transpose From 7d4f77df6691b87a366358c01e5de29d4dc2b9db Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 26 Aug 2025 15:28:21 +0200 Subject: [PATCH 24/70] Some tests, some of which fail. --- src/Futhark/AD/Shared.hs | 12 +++++++----- tests/ad/vec/map2.fut | 21 +++++++++++++++++++++ tests/ad/vec/map3.fut | 18 ++++++++++++++++++ tests/ad/vec/map4.fut | 37 +++++++++++++++++++++++++++++++++++++ 4 files changed, 83 insertions(+), 5 deletions(-) create mode 100644 tests/ad/vec/map2.fut create mode 100644 tests/ad/vec/map3.fut create mode 100644 tests/ad/vec/map4.fut diff --git a/src/Futhark/AD/Shared.hs b/src/Futhark/AD/Shared.hs index 2d628fcff0..af5d00ce44 100644 --- a/src/Futhark/AD/Shared.hs +++ b/src/Futhark/AD/Shared.hs @@ -46,8 +46,10 @@ mkMap :: f VName -> (f VName -> m [VName]) -> m [VName] -mkMap desc arrs f = do - w <- arraySize 0 <$> lookupType (head $ toList arrs) - x_p <- traverse (newParam "xp" . rowType <=< lookupType) arrs - lam <- mkLambda (toList x_p) $ varsRes <$> f (fmap paramName x_p) - letTupExp desc $ Op $ Screma w (toList arrs) (mapSOAC lam) +mkMap desc arrs f + | null arrs = pure [] + | otherwise = do + w <- arraySize 0 <$> lookupType (head $ toList arrs) + x_p <- traverse (newParam "xp" . rowType <=< lookupType) arrs + lam <- mkLambda (toList x_p) $ varsRes <$> f (fmap paramName x_p) + letTupExp desc $ Op $ Screma w (toList arrs) (mapSOAC lam) diff --git a/tests/ad/vec/map2.fut b/tests/ad/vec/map2.fut new file mode 100644 index 0000000000..8c96c57cae --- /dev/null +++ b/tests/ad/vec/map2.fut @@ -0,0 +1,21 @@ +-- Map with free variable. +-- == +-- tags { autodiff } +-- entry: fwd_map rev_map rev_vec +-- input { 2.0 [1.0,2.0,3.0] } +-- output { [1.0,2.0,3.0] } + +def primal xs (c': f64) = map (* c') xs + +def onehot n i : [n]f64 = + tabulate n (\j -> f64.bool (i == j)) + +entry fwd_map [n] (c: f64) (xs: [n]f64) = + jvp (primal xs) c 1 + +entry rev_map [n] (c: f64) (xs: [n]f64) = + tabulate n (\i -> vjp (primal xs) c (onehot n i)) + +entry rev_vec [n] (c: f64) (xs: [n]f64) = + let seeds = tabulate n (\i -> onehot n i) + in vjp_vec (primal xs) c seeds diff --git a/tests/ad/vec/map3.fut b/tests/ad/vec/map3.fut new file mode 100644 index 0000000000..0a4aa89880 --- /dev/null +++ b/tests/ad/vec/map3.fut @@ -0,0 +1,18 @@ +-- == +-- tags { autodiff } +-- entry: fwd rev_map rev_vec +-- input { 1i32 [1i32,2i32,3i32] } +-- output { [1i32,2i32,3i32] } + +def primal xs (x: i32) = map (* x) xs + +entry fwd [n] (x: i32) (xs: [n]i32) = + jvp (primal xs) x 1 + +entry rev_map [n] (x: i32) (xs: [n]i32) = + tabulate n (\i -> + vjp (primal xs) x (replicate n 0 with [i] = 1)) + +entry rev_vec [n] (x: i32) (xs: [n]i32) = + let seeds = tabulate n (\i -> (replicate n 0 with [i] = 1)) + in vjp_vec (primal xs) x seeds diff --git a/tests/ad/vec/map4.fut b/tests/ad/vec/map4.fut new file mode 100644 index 0000000000..e2660afada --- /dev/null +++ b/tests/ad/vec/map4.fut @@ -0,0 +1,37 @@ +-- An array is both a 'map' input and a free variable in the lambda. +-- == +-- tags { autodiff } +-- entry: fwd_map fwd_vec rev_map rev_vec +-- input { [1,2,3] } +-- output { +-- [[[2, 0, 0], [1, 1, 0], [1, 0, 1]], [[1, 1, 0], [0, 2, 0], [0, 1, 1]], [[1, 0, 1], [0, 1, 1], [0, 0, 2]]] +-- } + +def primal (xs: []i32) = + map (\x -> map (+ x) xs) xs + +def onehot n i : [n]i32 = + tabulate n (\j -> i32.bool (i == j)) + +def onehot_2d n m p : [n][m]i32 = + tabulate_2d n m (\i j -> i32.bool ((i, j) == p)) + +entry fwd_map [n] (xs: [n]i32) = + tabulate n (\i -> jvp primal xs (onehot n i)) + |> map transpose + |> transpose + |> map transpose + +entry fwd_vec [n] (xs: [n]i32) = + let seeds = tabulate n (\i -> onehot n i) + in jvp_vec primal xs seeds + |> map transpose + |> transpose + |> map transpose + +entry rev_map [n] (xs: [n]i32) = + tabulate_2d n n (\i j -> vjp primal xs (onehot_2d n n (i, j))) + +entry rev_vec [n] (xs: [n]i32) = + let seeds = tabulate_2d n n (\i j -> onehot_2d n n (i, j)) + in unflatten (vjp_vec primal xs (flatten seeds)) From 76c52be06039cff8c26a7d69a6de9bb5d52cbc30 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Tue, 26 Aug 2025 16:26:23 +0200 Subject: [PATCH 25/70] More stuff works. --- src/Futhark/AD/Rev.hs | 8 ++++++-- src/Futhark/AD/Rev/Map.hs | 15 +++++++++++++-- tests/ad/vec/map5.fut | 26 ++++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 4 deletions(-) create mode 100644 tests/ad/vec/map5.fut diff --git a/src/Futhark/AD/Rev.hs b/src/Futhark/AD/Rev.hs index 3012416ec9..05e9a83635 100644 --- a/src/Futhark/AD/Rev.hs +++ b/src/Futhark/AD/Rev.hs @@ -59,8 +59,12 @@ diffBasicOp pat aux e m = ConvOp op x -> do (_pat_v, pat_adj) <- commonBasicOp pat aux e m returnSweepCode $ do - contrib <- - letExp "contrib" $ BasicOp $ ConvOp (flipConvOp op) $ Var pat_adj + adj_shape <- askShape + + contrib <- letExp "convop_contrib" <=< mapNest adj_shape (MkSolo (Var pat_adj)) $ + \(MkSolo pat_adj') -> + pure $ BasicOp $ ConvOp (flipConvOp op) pat_adj' + updateSubExpAdj x contrib -- UnOp op x -> do diff --git a/src/Futhark/AD/Rev/Map.hs b/src/Futhark/AD/Rev/Map.hs index d5ee42e01f..fdd4f545d8 100644 --- a/src/Futhark/AD/Rev/Map.hs +++ b/src/Futhark/AD/Rev/Map.hs @@ -91,6 +91,16 @@ pushAdjShape v = do let perm = vecPerm adj_shape v_t letExp (baseName v <> "_tr") $ BasicOp $ Rearrange v perm +popAdjShape :: VName -> ADM VName +popAdjShape v = do + adj_shape <- askShape + v_t <- lookupType v + if adj_shape == mempty || arrayShape v_t == adj_shape || isAcc v_t + then pure v + else do + let perm = rearrangeInverse $ vecPerm adj_shape v_t + letExp (baseName v <> "_tr") $ BasicOp $ Rearrange v perm + -- | Perform VJP on a Map. The 'Adj' list is the adjoints of the -- result of the map. vjpMap :: VjpOps -> [Adj] -> StmAux () -> SubExp -> Lambda SOACS -> [VName] -> ADM () @@ -178,6 +188,7 @@ vjpMap ops pat_adj aux w map_lam as = returnSweepCode $ do case t of Acc {} -> letExp "acc_adj_rep" . BasicOp . Replicate (Shape [w]) . Var =<< adjVal adj _ -> pushAdjShape =<< adjVal adj + pat_adj_params <- mapM (newParam "map_adj_p" . rowType <=< lookupType) pat_adj_vals @@ -208,8 +219,8 @@ vjpMap ops pat_adj aux w map_lam as = returnSweepCode $ do let param_ts = map paramType (lambdaParams map_lam') forM_ (zip3 param_ts as param_contribs) $ \(param_t, a, param_contrib) -> case param_t of - Acc {} -> freeContrib a param_contrib - _ -> updateAdj a param_contrib + Acc {} -> freeContrib a =<< popAdjShape param_contrib -- CHECKME + _ -> updateAdj a =<< popAdjShape param_contrib where addIdxParams n lam = do idxs <- replicateM n $ newParam "idx" $ Prim int64 diff --git a/tests/ad/vec/map5.fut b/tests/ad/vec/map5.fut new file mode 100644 index 0000000000..e16aa52d8e --- /dev/null +++ b/tests/ad/vec/map5.fut @@ -0,0 +1,26 @@ +-- Map with free array variable. +-- == +-- tags { autodiff } +-- entry: fwd_map rev_map +-- input { [[1,2,3],[4,5,6]] [0,0] } +-- output { [[1, 0], [0, 1]] } + +def onehot n i : [n]i32 = + tabulate n (\j -> i32.bool (i == j)) + +def primal [n] [m] (free: [n][m]i32) (is: [n]i32) = + map (\i -> foldl (+) 0 free[i] + i) is + +entry fwd_map [n] [m] (free: [n][m]i32) (is: [n]i32) = + tabulate n (\i -> jvp (primal free) is (onehot n i)) |> transpose + +entry fwd_vec [n] [m] (free: [n][m]i32) (is: [n]i32) = + let seeds = tabulate n (\i -> onehot n i) + in jvp_vec (primal free) is seeds |> transpose + +entry rev_map [n] [m] (free: [n][m]i32) (is: [n]i32) = + tabulate n (\i -> vjp (primal free) is (onehot n i)) + +entry rev_vec [n] [m] (free: [n][m]i32) (is: [n]i32) = + let seeds = tabulate n (\i -> onehot n i) + in vjp_vec (primal free) is seeds From c59aefb6b1062ef7b016c27fd18c14e8ca91f421 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 28 Aug 2025 16:52:29 +0200 Subject: [PATCH 26/70] Work on histograms. --- src/Futhark/AD/Fwd.hs | 17 +- src/Futhark/AD/Rev/Hist.hs | 511 +++++++++++++++++----------------- src/Futhark/AD/Rev/SOAC.hs | 10 +- src/Futhark/AD/Shared.hs | 2 +- tests/ad/vec/hist_add.fut | 55 ++++ tests/ad/vec/hist_complex.fut | 39 +++ tests/ad/vec/hist_minmax.fut | 35 +++ tests/ad/vec/hist_mul.fut | 40 +++ 8 files changed, 444 insertions(+), 265 deletions(-) create mode 100644 tests/ad/vec/hist_add.fut create mode 100644 tests/ad/vec/hist_complex.fut create mode 100644 tests/ad/vec/hist_minmax.fut create mode 100644 tests/ad/vec/hist_mul.fut diff --git a/src/Futhark/AD/Fwd.hs b/src/Futhark/AD/Fwd.hs index 93e4ece29a..48e1531e79 100644 --- a/src/Futhark/AD/Fwd.hs +++ b/src/Futhark/AD/Fwd.hs @@ -438,16 +438,23 @@ fwdSOAC pat aux (Screma size xs (ScremaForm f scs reds)) = do fwdSOAC pat aux (Stream size xs nes lam) = do pat' <- bundleNewPat pat lam' <- fwdStreamLambda lam - xs' <- bundleTangents xs + xs' <- soacInputsWithTangents xs nes_tan <- mapM (fmap Var . zeroFromSubExp) nes let nes' = interleave nes nes_tan addStm $ Let pat' aux $ Op $ Stream size xs' nes' lam' fwdSOAC pat aux (Hist w arrs ops bucket_fun) = do - pat' <- bundleNewPat pat + -- TODO: this is probably not very efficient in the vectorised case as we end + -- up with a dreadful update operator that involves arrays. + (pat', to_transpose) <- soacResPat 0 0 pat ops' <- mapM fwdHist ops bucket_fun' <- fwdHistBucket bucket_fun - arrs' <- bundleTangents arrs + arrs' <- soacInputsWithTangents arrs addStm $ Let pat' aux $ Op $ Hist w arrs' ops' bucket_fun' + tan_shape <- askShape + forM_ to_transpose $ \(rpat, v) -> do + v_t <- lookupType v + let perm = rearrangeInverse $ vecPerm tan_shape v_t + letBind rpat $ BasicOp $ Rearrange v perm where n_indices = sum $ map (shapeRank . histShape) ops fwdBodyHist (Body _ stms res) = buildBody_ $ do @@ -463,8 +470,8 @@ fwdSOAC pat aux (Hist w arrs ops bucket_fun) = do fwdHist :: HistOp SOACS -> ADM (HistOp SOACS) fwdHist (HistOp shape rf dest nes op) = do - dest' <- bundleTangents dest - nes_tan <- mapM (fmap Var . zeroFromSubExp) nes + dest' <- soacInputsWithTangents dest + nes_tan <- mapM (letSubExp "zero" . zeroExp <=< tanType) $ lambdaReturnType op op' <- fwdLambda op pure $ HistOp diff --git a/src/Futhark/AD/Rev/Hist.hs b/src/Futhark/AD/Rev/Hist.hs index 672674f7a2..c2189b9cd0 100644 --- a/src/Futhark/AD/Rev/Hist.hs +++ b/src/Futhark/AD/Rev/Hist.hs @@ -240,69 +240,70 @@ diffMinMaxHist _ops x aux n minmax ne is vs w rf dst m = do m - x_bar <- lookupAdjVal x - - x_ind_dst <- newParam (baseName x <> "_ind_param") $ Prim int64 - x_bar_dst <- newParam (baseName x <> "_bar_param") $ Prim t - dst_lam_inner <- - mkLambda [x_ind_dst, x_bar_dst] $ - fmap varsRes . letTupExp "dst_bar" - =<< eIf - (toExp $ le64 (paramName x_ind_dst) .==. -1) - (eBody $ pure $ eParam x_bar_dst) - (eBody $ pure $ eSubExp $ Constant $ blankPrimValue t) - dst_lam <- nestedmap inner_dims [int64, vs_elm_type] dst_lam_inner - - dst_bar <- - letExp (baseName dst <> "_bar") . Op $ - Screma w [x_inds, x_bar] (mapSOAC dst_lam) - - updateAdj dst dst_bar - - vs_bar <- lookupAdjVal vs - - inds' <- traverse (letExp "inds" . BasicOp . Replicate (Shape [w]) . Var) =<< mk_indices inner_dims [] - let inds = x_inds : inds' - - par_x_ind_vs <- replicateM nr_dims $ newParam (baseName x <> "_ind_param") $ Prim int64 - par_x_bar_vs <- newParam (baseName x <> "_bar_param") $ Prim t - vs_lam_inner <- - mkLambda (par_x_bar_vs : par_x_ind_vs) $ - fmap varsRes . letTupExp "res" - =<< eIf - (toExp $ le64 (paramName $ head par_x_ind_vs) .==. -1) - (eBody $ pure $ eSubExp $ Constant $ blankPrimValue t) - ( eBody $ - pure $ do - vs_bar_i <- - letSubExp (baseName vs_bar <> "_el") . BasicOp $ - Index vs_bar . Slice $ - fmap (DimFix . Var . paramName) par_x_ind_vs - eBinOp (getBinOpPlus t) (eParam par_x_bar_vs) (eSubExp vs_bar_i) - ) - vs_lam <- nestedmap inner_dims (vs_elm_type : replicate nr_dims int64) vs_lam_inner - - vs_bar_p <- - letExp (baseName vs <> "_partial") . Op $ - Screma w (x_bar : inds) (mapSOAC vs_lam) - - q <- - letSubExp "q" - =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) dst_dims - - scatter_inps <- do - -- traverse (letExp "flat" . BasicOp . Reshape [DimNew q]) $ inds ++ [vs_bar_p] - -- ToDo: Cosmin asks: is the below the correct translation of the line above? - forM (inds ++ [vs_bar_p]) $ \v -> do - v_t <- lookupType v - letExp "flat" . BasicOp . Reshape v $ - reshapeAll (arrayShape v_t) (Shape [q]) - - vs_bar' <- - fmap head $ - doScatter (baseName vs <> "_bar") nr_dims [vs_bar] scatter_inps $ - pure . map (Var . paramName) - insAdj vs vs_bar' + locallyNonvectorised (x, dst, vs) $ do + x_bar <- lookupAdjVal x + + x_ind_dst <- newParam (baseName x <> "_ind_param") $ Prim int64 + x_bar_dst <- newParam (baseName x <> "_bar_param") $ Prim t + dst_lam_inner <- + mkLambda [x_ind_dst, x_bar_dst] $ + fmap varsRes . letTupExp "dst_bar" + =<< eIf + (toExp $ le64 (paramName x_ind_dst) .==. -1) + (eBody $ pure $ eParam x_bar_dst) + (eBody $ pure $ eSubExp $ Constant $ blankPrimValue t) + dst_lam <- nestedmap inner_dims [int64, vs_elm_type] dst_lam_inner + + dst_bar <- + letExp (baseName dst <> "_bar") . Op $ + Screma w [x_inds, x_bar] (mapSOAC dst_lam) + + updateAdj dst dst_bar + + vs_bar <- lookupAdjVal vs + + inds' <- traverse (letExp "inds" . BasicOp . Replicate (Shape [w]) . Var) =<< mk_indices inner_dims [] + let inds = x_inds : inds' + + par_x_ind_vs <- replicateM nr_dims $ newParam (baseName x <> "_ind_param") $ Prim int64 + par_x_bar_vs <- newParam (baseName x <> "_bar_param") $ Prim t + vs_lam_inner <- + mkLambda (par_x_bar_vs : par_x_ind_vs) $ + fmap varsRes . letTupExp "res" + =<< eIf + (toExp $ le64 (paramName $ head par_x_ind_vs) .==. -1) + (eBody $ pure $ eSubExp $ Constant $ blankPrimValue t) + ( eBody $ + pure $ do + vs_bar_i <- + letSubExp (baseName vs_bar <> "_el") . BasicOp $ + Index vs_bar . Slice $ + fmap (DimFix . Var . paramName) par_x_ind_vs + eBinOp (getBinOpPlus t) (eParam par_x_bar_vs) (eSubExp vs_bar_i) + ) + vs_lam <- nestedmap inner_dims (vs_elm_type : replicate nr_dims int64) vs_lam_inner + + vs_bar_p <- + letExp (baseName vs <> "_partial") . Op $ + Screma w (x_bar : inds) (mapSOAC vs_lam) + + q <- + letSubExp "q" + =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) dst_dims + + scatter_inps <- do + -- traverse (letExp "flat" . BasicOp . Reshape [DimNew q]) $ inds ++ [vs_bar_p] + -- ToDo: Cosmin asks: is the below the correct translation of the line above? + forM (inds ++ [vs_bar_p]) $ \v -> do + v_t <- lookupType v + letExp "flat" . BasicOp . Reshape v $ + reshapeAll (arrayShape v_t) (Shape [q]) + + vs_bar' <- + fmap head $ + doScatter (baseName vs <> "_bar") nr_dims [vs_bar] scatter_inps $ + pure . map (Var . paramName) + insAdj vs vs_bar' where mk_indices :: [SubExp] -> [SubExp] -> ADM [VName] mk_indices [] _ = pure [] @@ -402,8 +403,8 @@ diffMulHist _ops x aux n mul ne is vs w rf dst m = do fmap varsRes . letTupExp "h_part" =<< eIf (toExp $ 0 .==. le64 (paramName c_param)) - (eBody $ pure $ eParam p_param) - (eBody $ pure $ eSubExp $ Constant $ blankPrimValue t) + (eBody [eParam p_param]) + (eBody [eSubExp $ Constant $ blankPrimValue t]) lam_h_part <- nestedmap dst_dims [vs_elm_type, int64] lam_h_part_inner h_part_res <- eLambda lam_h_part $ map (eSubExp . Var) [nz_prods, zr_counts] h_part' <- bindSubExpRes "h_part" h_part_res @@ -417,59 +418,60 @@ diffMulHist _ops x aux n mul ne is vs w rf dst m = do m - x_bar <- lookupAdjVal x - - lam_mul'' <- renameLambda lam_mul' - dst_bar_res <- eLambda lam_mul'' $ map (eSubExp . Var) [h_part, x_bar] - dst_bar <- bindSubExpRes (baseName dst <> "_bar") dst_bar_res - updateAdj dst $ head dst_bar - - lam_mul''' <- renameLambda lam_mul' - part_bar_res <- eLambda lam_mul''' $ map (eSubExp . Var) [dst, x_bar] - part_bar' <- bindSubExpRes "part_bar" part_bar_res - let [part_bar] = part_bar' - - inner_params <- zipWithM newParam ["zr_cts", "pr_bar", "nz_prd", "a"] $ map Prim [int64, t, t, t] - let [zr_cts, pr_bar, nz_prd, a_param] = inner_params - lam_vsbar_inner <- - mkLambda inner_params $ - fmap varsRes . letTupExp "vs_bar" =<< do - eIf - (eCmpOp (CmpEq int64) (eSubExp $ intConst Int64 0) (eParam zr_cts)) - (eBody $ pure $ eBinOp mul (eParam pr_bar) $ eBinOp (getBinOpDiv t) (eParam nz_prd) $ eParam a_param) - ( eBody $ - pure $ - eIf - ( eBinOp - LogAnd - (eCmpOp (CmpEq int64) (eSubExp $ intConst Int64 1) (eParam zr_cts)) - (eCmpOp (CmpEq t) (eSubExp $ Constant $ blankPrimValue t) $ eParam a_param) - ) - (eBody $ pure $ eBinOp mul (eParam nz_prd) (eParam pr_bar)) - (eBody $ pure $ eSubExp $ Constant $ blankPrimValue t) - ) + locallyNonvectorised (x, dst, vs) $ do + x_bar <- lookupAdjVal x - lam_vsbar_middle <- nestedmap inner_dims [int64, t, t, t] lam_vsbar_inner + lam_mul'' <- renameLambda lam_mul' + dst_bar_res <- eLambda lam_mul'' $ map (eSubExp . Var) [h_part, x_bar] + dst_bar <- bindSubExpRes (baseName dst <> "_bar") dst_bar_res + updateAdj dst $ head dst_bar - i_param <- newParam "i" $ Prim int64 - a_param' <- newParam "a" $ rowType vs_type - lam_vsbar <- - mkLambda [i_param, a_param'] $ - fmap varsRes . letTupExp "vs_bar" - =<< eIf - (toExp $ withinBounds $ pure (w, paramName i_param)) - ( buildBody_ $ do - let i = fullSlice vs_type [DimFix $ Var $ paramName i_param] - names <- traverse newVName ["zr_cts", "pr_bar", "nz_prd"] - zipWithM_ (\name -> letBindNames [name] . BasicOp . flip Index i) names [zr_counts, part_bar, nz_prods] - eLambda lam_vsbar_middle $ map (eSubExp . Var) names <> [eParam a_param'] - ) - (eBody $ pure $ pure $ zeroExp $ rowType dst_type) + lam_mul''' <- renameLambda lam_mul' + part_bar_res <- eLambda lam_mul''' $ map (eSubExp . Var) [dst, x_bar] + part_bar' <- bindSubExpRes "part_bar" part_bar_res + let [part_bar] = part_bar' + + inner_params <- zipWithM newParam ["zr_cts", "pr_bar", "nz_prd", "a"] $ map Prim [int64, t, t, t] + let [zr_cts, pr_bar, nz_prd, a_param] = inner_params + lam_vsbar_inner <- + mkLambda inner_params $ + fmap varsRes . letTupExp "vs_bar" =<< do + eIf + (eCmpOp (CmpEq int64) (eSubExp $ intConst Int64 0) (eParam zr_cts)) + (eBody [eBinOp mul (eParam pr_bar) $ eBinOp (getBinOpDiv t) (eParam nz_prd) $ eParam a_param]) + ( eBody + [ eIf + ( eBinOp + LogAnd + (eCmpOp (CmpEq int64) (eSubExp $ intConst Int64 1) (eParam zr_cts)) + (eCmpOp (CmpEq t) (eSubExp $ Constant $ blankPrimValue t) $ eParam a_param) + ) + (eBody [eBinOp mul (eParam nz_prd) (eParam pr_bar)]) + (eBody [eSubExp $ Constant $ blankPrimValue t]) + ] + ) - vs_bar <- - letExp (baseName vs <> "_bar") $ Op $ Screma n [is, vs] $ mapSOAC lam_vsbar + lam_vsbar_middle <- nestedmap inner_dims [int64, t, t, t] lam_vsbar_inner + + i_param <- newParam "i" $ Prim int64 + a_param' <- newParam "a" $ rowType vs_type + lam_vsbar <- + mkLambda [i_param, a_param'] $ + fmap varsRes . letTupExp "vs_bar" + =<< eIf + (toExp $ withinBounds $ pure (w, paramName i_param)) + ( buildBody_ $ do + let i = fullSlice vs_type [DimFix $ Var $ paramName i_param] + names <- traverse newVName ["zr_cts", "pr_bar", "nz_prd"] + zipWithM_ (\name -> letBindNames [name] . BasicOp . flip Index i) names [zr_counts, part_bar, nz_prods] + eLambda lam_vsbar_middle $ map (eSubExp . Var) names <> [eParam a_param'] + ) + (eBody [pure $ zeroExp $ rowType dst_type]) + + vs_bar <- + letExp (baseName vs <> "_bar") $ Op $ Screma n [is, vs] $ mapSOAC lam_vsbar - updateAdj vs vs_bar + updateAdj vs vs_bar -- -- special case of histogram with add as operator. @@ -500,23 +502,23 @@ diffAddHist _ops x aux n add ne is vs w rf dst m = do m - x_bar <- lookupAdjVal x + locallyNonvectorised (x, dst, vs) $ do + x_bar <- lookupAdjVal x - updateAdj dst x_bar + updateAdj dst x_bar - x_type <- lookupType x - i_param <- newParam (baseName vs <> "_i") $ Prim int64 - let i = paramName i_param - lam_vsbar <- - mkLambda [i_param] $ - fmap varsRes . letTupExp "vs_bar" - =<< eIf - (toExp $ withinBounds $ pure (w, i)) - (eBody $ pure $ pure $ BasicOp $ Index x_bar $ fullSlice x_type [DimFix $ Var i]) - (eBody $ pure $ eSubExp ne) + i_param <- newParam (baseName vs <> "_i") $ Prim int64 + let i = paramName i_param + lam_vsbar <- + mkLambda [i_param] $ + fmap varsRes . letTupExp "vs_bar" + =<< eIf + (toExp $ withinBounds $ pure (w, i)) + (eBody [eIndex x_bar [eVar i]]) + (eBody [eSubExp ne]) - vs_bar <- letExp (baseName vs <> "_bar") $ Op $ Screma n [is] $ mapSOAC lam_vsbar - updateAdj vs vs_bar + vs_bar <- letExp (baseName vs <> "_bar") $ Op $ Screma n [is] $ mapSOAC lam_vsbar + updateAdj vs vs_bar -- Special case for vectorised combining operator. Rewrite -- reduce_by_index dst (map2 op) nes is vss @@ -789,144 +791,145 @@ diffHist ops xs aux n lam0 ne as w rf dst m = do m - xs_bar <- traverse lookupAdjVal xs - - (dst_params, hp_params, f') <- mkF' lam0 dst_type $ head w - f'_adj_dst <- vjpLambda ops (map adjFromVar xs_bar) dst_params f' - f'_adj_hp <- vjpLambda ops (map adjFromVar xs_bar) hp_params f' - - dst_bar' <- eLambda f'_adj_dst $ map (eSubExp . Var) $ dst <> h_part - dst_bar <- bindSubExpRes "dst_bar" dst_bar' - zipWithM_ updateAdj dst dst_bar - - h_part_bar' <- eLambda f'_adj_hp $ map (eSubExp . Var) $ dst <> h_part - h_part_bar <- bindSubExpRes "h_part_bar" h_part_bar' - - lam <- renameLambda lam0 - lam' <- renameLambda lam0 - - -- is' <- mapout (head as) n (head w) - -- sorted <- radixSort' (is' : tail as) n $ head w - sorted <- radixSort' as n $ head w - let siota = head sorted - let sis = head $ tail sorted - let sas = drop 2 sorted - - iota_n <- - letExp "iota" $ BasicOp $ Iota n (intConst Int64 0) (intConst Int64 1) Int64 - - par_i <- newParam "i" $ Prim int64 - flag_lam <- mkFlagLam par_i sis - flag <- letExp "flag" $ Op $ Screma n [iota_n] $ mapSOAC flag_lam - - -- map (\i -> (if flag[i] then (true,ne) else (false,vs[i-1]), if i==0 || flag[n-i] then (true,ne) else (false,vs[n-i]))) (iota n) - par_i' <- newParam "i" $ Prim int64 - let i' = paramName par_i' - g_lam <- - mkLambda [par_i'] $ - fmap subExpsRes . mapM (letSubExp "scan_inps") =<< do - im1 <- letSubExp "i_1" =<< toExp (le64 i' - 1) - nmi <- letSubExp "n_i" =<< toExp (pe64 n - le64 i') - let s1 = [DimFix im1] - let s2 = [DimFix nmi] - - -- flag array for left scan - f1 <- letSubExp "f1" $ BasicOp $ Index flag $ Slice [DimFix $ Var i'] - - -- array for left scan - r1 <- - letTupExp' "r1" - =<< eIf - (eSubExp f1) - (eBody $ fmap eSubExp ne) - (eBody . fmap (eSubExp . Var) =<< multiIndex sas s1) - - -- array for right scan inc flag - r2 <- - letTupExp' "r2" - =<< eIf - (toExp $ le64 i' .==. 0) - (eBody $ fmap eSubExp $ Constant (onePrimValue Bool) : ne) - ( eBody $ - pure $ do - eIf - (pure $ BasicOp $ Index flag $ Slice s2) - (eBody $ fmap eSubExp $ Constant (onePrimValue Bool) : ne) - ( eBody . fmap eSubExp . (Constant (blankPrimValue Bool) :) . fmap Var - =<< multiIndex sas s2 - ) - ) - - traverse eSubExp $ f1 : r1 ++ r2 - - -- scan (\(f1,v1) (f2,v2) -> - -- let f = f1 || f2 - -- let v = if f2 then v2 else g v1 v2 - -- in (f,v) ) (false,ne) (zip flags vals) - scan_lams <- - traverse - ( \l -> do - f1 <- newParam "f1" $ Prim Bool - f2 <- newParam "f2" $ Prim Bool - ps <- lambdaParams <$> renameLambda lam0 - let (p1, p2) = splitAt (length ne) ps - - mkLambda (f1 : p1 ++ f2 : p2) $ - fmap varsRes . letTupExp "scan_res" =<< do - let f = eBinOp LogOr (eParam f1) (eParam f2) - eIf - (eParam f2) - (eBody $ f : fmap eParam p2) - ( eBody . (f :) . fmap (eSubExp . Var) - =<< bindSubExpRes "gres" - =<< eLambda l (fmap eParam ps) + locallyNonvectorised (xs, dst, lam0, as) $ do + xs_bar <- traverse lookupAdjVal xs + + (dst_params, hp_params, f') <- mkF' lam0 dst_type $ head w + f'_adj_dst <- vjpLambda ops (map adjFromVar xs_bar) dst_params f' + f'_adj_hp <- vjpLambda ops (map adjFromVar xs_bar) hp_params f' + + dst_bar' <- eLambda f'_adj_dst $ map (eSubExp . Var) $ dst <> h_part + dst_bar <- bindSubExpRes "dst_bar" dst_bar' + zipWithM_ updateAdj dst dst_bar + + h_part_bar' <- eLambda f'_adj_hp $ map (eSubExp . Var) $ dst <> h_part + h_part_bar <- bindSubExpRes "h_part_bar" h_part_bar' + + lam <- renameLambda lam0 + lam' <- renameLambda lam0 + + -- is' <- mapout (head as) n (head w) + -- sorted <- radixSort' (is' : tail as) n $ head w + sorted <- radixSort' as n $ head w + let siota = head sorted + let sis = head $ tail sorted + let sas = drop 2 sorted + + iota_n <- + letExp "iota" $ BasicOp $ Iota n (intConst Int64 0) (intConst Int64 1) Int64 + + par_i <- newParam "i" $ Prim int64 + flag_lam <- mkFlagLam par_i sis + flag <- letExp "flag" $ Op $ Screma n [iota_n] $ mapSOAC flag_lam + + -- map (\i -> (if flag[i] then (true,ne) else (false,vs[i-1]), if i==0 || flag[n-i] then (true,ne) else (false,vs[n-i]))) (iota n) + par_i' <- newParam "i" $ Prim int64 + let i' = paramName par_i' + g_lam <- + mkLambda [par_i'] $ + fmap subExpsRes . mapM (letSubExp "scan_inps") =<< do + im1 <- letSubExp "i_1" =<< toExp (le64 i' - 1) + nmi <- letSubExp "n_i" =<< toExp (pe64 n - le64 i') + let s1 = [DimFix im1] + let s2 = [DimFix nmi] + + -- flag array for left scan + f1 <- letSubExp "f1" $ BasicOp $ Index flag $ Slice [DimFix $ Var i'] + + -- array for left scan + r1 <- + letTupExp' "r1" + =<< eIf + (eSubExp f1) + (eBody $ fmap eSubExp ne) + (eBody . fmap (eSubExp . Var) =<< multiIndex sas s1) + + -- array for right scan inc flag + r2 <- + letTupExp' "r2" + =<< eIf + (toExp $ le64 i' .==. 0) + (eBody $ fmap eSubExp $ Constant (onePrimValue Bool) : ne) + ( eBody $ + pure $ do + eIf + (pure $ BasicOp $ Index flag $ Slice s2) + (eBody $ fmap eSubExp $ Constant (onePrimValue Bool) : ne) + ( eBody . fmap eSubExp . (Constant (blankPrimValue Bool) :) . fmap Var + =<< multiIndex sas s2 + ) ) - ) - [lam, lam'] - let ne' = Constant (BoolValue False) : ne + traverse eSubExp $ f1 : r1 ++ r2 + + -- scan (\(f1,v1) (f2,v2) -> + -- let f = f1 || f2 + -- let v = if f2 then v2 else g v1 v2 + -- in (f,v) ) (false,ne) (zip flags vals) + scan_lams <- + traverse + ( \l -> do + f1 <- newParam "f1" $ Prim Bool + f2 <- newParam "f2" $ Prim Bool + ps <- lambdaParams <$> renameLambda lam0 + let (p1, p2) = splitAt (length ne) ps + + mkLambda (f1 : p1 ++ f2 : p2) $ + fmap varsRes . letTupExp "scan_res" =<< do + let f = eBinOp LogOr (eParam f1) (eParam f2) + eIf + (eParam f2) + (eBody $ f : fmap eParam p2) + ( eBody . (f :) . fmap (eSubExp . Var) + =<< bindSubExpRes "gres" + =<< eLambda l (fmap eParam ps) + ) + ) + [lam, lam'] - scansres <- - letTupExp "adj_ctrb_scan" . Op $ - Screma n [iota_n] (scanomapSOAC (map (`Scan` ne') scan_lams) g_lam) + let ne' = Constant (BoolValue False) : ne - let (_ : ls_arr, _ : rs_arr_rev) = splitAt (length ne + 1) scansres + scansres <- + letTupExp "adj_ctrb_scan" . Op $ + Screma n [iota_n] (scanomapSOAC (map (`Scan` ne') scan_lams) g_lam) - -- map (\i -> if i < w && -1 < w then (xs_bar[i], dst[i]) else (0,ne)) sis - par_i'' <- newParam "i" $ Prim int64 - let i'' = paramName par_i'' - map_lam <- - mkLambda [par_i''] $ - fmap varsRes . letTupExp "scan_res" - =<< eIf - (toExp $ withinBounds $ pure (head w, i'')) - (eBody . fmap (eSubExp . Var) =<< multiIndex h_part_bar [DimFix $ Var i'']) - ( eBody $ do - map (\t -> pure $ BasicOp $ Replicate (Shape $ tail $ arrayDims t) (Constant $ blankPrimValue $ elemType t)) as_type - ) + let (_ : ls_arr, _ : rs_arr_rev) = splitAt (length ne + 1) scansres + + -- map (\i -> if i < w && -1 < w then (xs_bar[i], dst[i]) else (0,ne)) sis + par_i'' <- newParam "i" $ Prim int64 + let i'' = paramName par_i'' + map_lam <- + mkLambda [par_i''] $ + fmap varsRes . letTupExp "scan_res" + =<< eIf + (toExp $ withinBounds $ pure (head w, i'')) + (eBody . fmap (eSubExp . Var) =<< multiIndex h_part_bar [DimFix $ Var i'']) + ( eBody $ do + map (\t -> pure $ BasicOp $ Replicate (Shape $ tail $ arrayDims t) (Constant $ blankPrimValue $ elemType t)) as_type + ) - f_bar <- letTupExp "f_bar" $ Op $ Screma n [sis] $ mapSOAC map_lam + f_bar <- letTupExp "f_bar" $ Op $ Screma n [sis] $ mapSOAC map_lam - (as_params, f) <- mkF lam0 as_type n - f_adj <- vjpLambda ops (map adjFromVar f_bar) as_params f + (as_params, f) <- mkF lam0 as_type n + f_adj <- vjpLambda ops (map adjFromVar f_bar) as_params f - -- map (\i -> rs_arr_rev[n-i-1]) (iota n) - par_i''' <- newParam "i" $ Prim int64 - let i''' = paramName par_i''' - rev_lam <- mkLambda [par_i'''] $ do - nmim1 <- letSubExp "n_i_1" =<< toExp (pe64 n - le64 i''' - 1) - varsRes <$> multiIndex rs_arr_rev [DimFix nmim1] + -- map (\i -> rs_arr_rev[n-i-1]) (iota n) + par_i''' <- newParam "i" $ Prim int64 + let i''' = paramName par_i''' + rev_lam <- mkLambda [par_i'''] $ do + nmim1 <- letSubExp "n_i_1" =<< toExp (pe64 n - le64 i''' - 1) + varsRes <$> multiIndex rs_arr_rev [DimFix nmim1] - rs_arr <- letTupExp "rs_arr" $ Op $ Screma n [iota_n] $ mapSOAC rev_lam + rs_arr <- letTupExp "rs_arr" $ Op $ Screma n [iota_n] $ mapSOAC rev_lam - sas_bar <- - bindSubExpRes "sas_bar" - =<< eLambda f_adj (map (eSubExp . Var) $ ls_arr <> sas <> rs_arr) + sas_bar <- + bindSubExpRes "sas_bar" + =<< eLambda f_adj (map (eSubExp . Var) $ ls_arr <> sas <> rs_arr) - scatter_dst <- traverse (\t -> letExp "scatter_dst" $ BasicOp $ Scratch (elemType t) (arrayDims t)) as_type - as_bar <- multiScatter scatter_dst siota sas_bar + scatter_dst <- traverse (\t -> letExp "scatter_dst" $ BasicOp $ Scratch (elemType t) (arrayDims t)) as_type + as_bar <- multiScatter scatter_dst siota sas_bar - zipWithM_ updateAdj (tail as) as_bar + zipWithM_ updateAdj (tail as) as_bar where -- map (\i -> if i == 0 then true else is[i] != is[i-1]) (iota n) mkFlagLam :: LParam SOACS -> VName -> ADM (Lambda SOACS) diff --git a/src/Futhark/AD/Rev/SOAC.hs b/src/Futhark/AD/Rev/SOAC.hs index d3a7453582..4295d83926 100644 --- a/src/Futhark/AD/Rev/SOAC.hs +++ b/src/Futhark/AD/Rev/SOAC.hs @@ -171,14 +171,14 @@ vjpSOAC ops pat aux (Hist n [is, vs] [histop] f) m Just [(op, _, _, _)] <- lamIsBinOp lam', isAddOp op = diffAddHist ops x aux n lam ne is vs w rf dst m -vjpSOAC ops pat aux (Hist n as [histop] f) m +vjpSOAC ops pat aux (Hist w as [histop] f) m | isIdentityLambda f, - HistOp (Shape w) rf dst ne lam <- histop = do - diffHist ops (patNames pat) aux n lam ne as w rf dst m -vjpSOAC ops pat _aux (Hist n as histops f) m + HistOp (Shape n) rf dst ne lam <- histop = do + diffHist ops (patNames pat) aux w lam ne as n rf dst m +vjpSOAC ops pat _aux (Hist w as histops f) m | not (isIdentityLambda f) = do (mapstm, redstm) <- - histomapToMapAndHist pat (n, histops, f, as) + histomapToMapAndHist pat (w, histops, f, as) vjpStm ops mapstm $ vjpStm ops redstm m vjpSOAC ops pat aux (Stream w as accs lam) m = do stms <- collectStms_ $ auxing aux $ sequentialStreamWholeArray pat w accs lam as diff --git a/src/Futhark/AD/Shared.hs b/src/Futhark/AD/Shared.hs index af5d00ce44..8582c10422 100644 --- a/src/Futhark/AD/Shared.hs +++ b/src/Futhark/AD/Shared.hs @@ -36,7 +36,7 @@ mapNest shape x f = do x_v <- traverse asVName x x_p <- traverse (newParam "xp" . rowType <=< lookupType) x_v lam <- mkLambda (toList x_p) $ do - fmap (subExpsRes . pure) . letSubExp "tan" + fmap (subExpsRes . pure) . letSubExp "mapnest_res" =<< f (fmap (Var . paramName) x_p) pure $ Op $ Screma w (toList x_v) (mapSOAC lam) diff --git a/tests/ad/vec/hist_add.fut b/tests/ad/vec/hist_add.fut new file mode 100644 index 0000000000..25c5a1ee1b --- /dev/null +++ b/tests/ad/vec/hist_add.fut @@ -0,0 +1,55 @@ +-- Addition +-- == +-- tags { autodiff } +-- entry: fwd_map fwd_vec rev_map rev_vec +-- input { +-- [4i64,5i64,2i64,4i64,5i64,2i64,0i64,0i64,4i64,5i64,1i64,4i64,1i64,3i64,3i64,1i64,8i64,-1i64] +-- [11f32,16f32,7f32,0f32,14f32,6f32,2f32,1f32,13f32,0f32,5f32,0f32,3f32,0f32,9f32,4f32,17f32,18f32] +-- [1f32,2f32,3f32,4f32,5f32,6f32,7f32,8f32] +-- } +-- output { +-- [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 1f32, 1f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 2f32, 0f32, 2f32, 0f32, 0f32, 2f32, 0f32, 0f32], +-- [0f32, 0f32, 3f32, 0f32, 0f32, 3f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 4f32, 4f32, 0f32, 0f32, 0f32], +-- [5f32, 0f32, 0f32, 5f32, 0f32, 0f32, 0f32, 0f32, 5f32, 0f32, 0f32, 5f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 6f32, 0f32, 0f32, 6f32, 0f32, 0f32, 0f32, 0f32, 6f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]] +-- [[3f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 12f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 13f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 9f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 24f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 30f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]] +-- } + +def f [n] [m] (is: [n]i64) (vs: [n]f32, c: [m]f32) = + let r = hist (+) 0 m is vs + in map2 (*) r c + +entry fwd_map [n] [m] (is: [n]i64) (vs: [n]f32) (c: [m]f32) = + tabulate m (\i -> vjp (f is) (vs, c) (replicate m 0 with [i] = 1)) + |> unzip + +entry fwd_vec [n] [m] (is: [n]i64) (vs: [n]f32) (c: [m]f32) = + let seeds = + tabulate (n + m) (\i -> + ( tabulate n ((i ==) >-> f32.bool) + , tabulate m (((i - n) ==) >-> f32.bool) + )) + in jvp_vec (f is) (vs, c) seeds + |> transpose + |> map split + |> unzip + +entry rev_map [n] [m] (is: [n]i64) (vs: [n]f32) (c: [m]f32) = + tabulate m (\i -> vjp (f is) (vs, c) (replicate m 0 with [i] = 1)) + |> unzip + +entry rev_vec [n] [m] (is: [n]i64) (vs: [n]f32) (c: [m]f32) = + let seeds = tabulate m (\i -> replicate m 0 with [i] = 1) + in vjp_vec (f is) (vs, c) seeds + |> unzip diff --git a/tests/ad/vec/hist_complex.fut b/tests/ad/vec/hist_complex.fut new file mode 100644 index 0000000000..8d4af373f9 --- /dev/null +++ b/tests/ad/vec/hist_complex.fut @@ -0,0 +1,39 @@ +-- == +-- tags { autodiff } +-- entry: rev_map rev_vec +-- input { +-- [4i64,5i64,2i64,4i64,5i64,2i64,0i64,0i64,4i64,5i64,1i64,4i64,1i64,3i64,3i64,1i64,8i64,-1i64] +-- [11f32,16f32,7f32,0f32,14f32,6f32,2f32,1f32,13f32,0f32,5f32,0f32,3f32,0f32,9f32,4f32,17f32,18f32] +-- [1f32,2f32,3f32,4f32,5f32,6f32,7f32,8f32] +-- } +-- output { +-- [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 2f32, 4f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 96f32, 0f32, 160f32, 0f32, 0f32, 120f32, 0f32, 0f32], +-- [0f32, 0f32, 36f32, 0f32, 0f32, 42f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 72f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 5376f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]] +-- [[4f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 240f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 84f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0.5, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0.5]] +-- } + +def f [n] [m] (is: [n]i64) (vs: [n]f32, c: [m]f32) = + let r = hist (\x y -> x * y * 2) 0.5 m is vs + in map2 (*) r c + +entry rev_map [n] [m] (is: [n]i64) (vs: [n]f32) (c: [m]f32) = + tabulate m (\i -> vjp (f is) (vs, c) (replicate m 0 with [i] = 1)) + |> unzip + +entry rev_vec [n] [m] (is: [n]i64) (vs: [n]f32) (c: [m]f32) = + let seeds = tabulate m (\i -> replicate m 0 with [i] = 1) + in vjp_vec (f is) (vs, c) seeds + |> unzip diff --git a/tests/ad/vec/hist_minmax.fut b/tests/ad/vec/hist_minmax.fut new file mode 100644 index 0000000000..d6792209df --- /dev/null +++ b/tests/ad/vec/hist_minmax.fut @@ -0,0 +1,35 @@ +-- Maximum +-- == +-- tags { autodiff } +-- entry: fwd_map fwd_vec rev_map rev_vec +-- input { +-- 5i64 +-- [4i64,5i64,2i64,4i64,5i64,2i64,0i64,0i64,4i64,5i64,1i64,4i64,1i64,3i64,3i64,1i64,8i64,-1i64] +-- [11f32,16f32,7f32,0f32,14f32,6f32,2f32,1f32,13f32,0f32,5f32,0f32,3f32,0f32,9f32,4f32,17f32,18f32] +-- } +-- output { +-- [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 1f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 1f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 1f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 1f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 1f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]] +-- } + +def primal [n] (k: i64) (is: [n]i64) (vs: [n]f32) = + hist f32.max f32.lowest k is vs + +entry fwd_map [n] (k: i64) (is: [n]i64) (vs: [n]f32) = + tabulate n (\i -> jvp (primal k is) vs (replicate n 0 with [i] = 1)) + |> transpose + +entry fwd_vec [n] (k: i64) (is: [n]i64) (vs: [n]f32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in jvp_vec (primal k is) vs seeds + |> transpose + +entry rev_map [n] (k: i64) (is: [n]i64) (vs: [n]f32) = + tabulate k (\i -> vjp (primal k is) vs (replicate k 0 with [i] = 1)) + +entry rev_vec [n] (k: i64) (is: [n]i64) (vs: [n]f32) = + let seeds = tabulate k (\i -> replicate k 0 with [i] = 1) + in vjp_vec (primal k is) vs seeds diff --git a/tests/ad/vec/hist_mul.fut b/tests/ad/vec/hist_mul.fut new file mode 100644 index 0000000000..08e08b1b0f --- /dev/null +++ b/tests/ad/vec/hist_mul.fut @@ -0,0 +1,40 @@ +-- Multiplication +-- == +-- tags { autodiff } +-- entry: rev_map rev_vec +-- input { +-- [4i64,5i64,2i64,4i64,5i64,2i64,0i64,0i64,4i64,5i64,1i64,4i64,1i64,3i64,3i64,1i64,8i64,-1i64] +-- [11f32,16f32,7f32,0f32,14f32,6f32,2f32,1f32,13f32,0f32,5f32,0f32,3f32,0f32,9f32,4f32,17f32,18f32] +-- [1f32,2f32,3f32,4f32,5f32,6f32,7f32,8f32] +-- } +-- output { +-- [[0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 1f32, 2f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 24f32, 0f32, 40f32, 0f32, 0f32, 30f32, 0f32, 0f32], +-- [0f32, 0f32, 18f32, 0f32, 0f32, 21f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 36f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 1344f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32]] +-- [[2f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 60f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 42f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 1f32, 0f32], +-- [0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 0f32, 1f32]] +-- } + +def f [n] [m] (is: [n]i64) (vs: [n]f32, c: [m]f32) = + let r = hist (*) 1 m is vs + in map2 (*) r c + +entry rev_map [n] [m] (is: [n]i64) (vs: [n]f32) (c: [m]f32) = + tabulate m (\i -> vjp (f is) (vs, c) (replicate m 0 with [i] = 1)) + |> unzip + +entry rev_vec [n] [m] (is: [n]i64) (vs: [n]f32) (c: [m]f32) = + let seeds = tabulate m (\i -> replicate m 0 with [i] = 1) + in vjp_vec (f is) (vs, c) seeds + |> unzip From a91b91141271480eb1089619587bacd6fc92e78b Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 25 Sep 2025 18:06:06 +0200 Subject: [PATCH 27/70] Add failing test. --- tests/ad/vec/map6.fut | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 tests/ad/vec/map6.fut diff --git a/tests/ad/vec/map6.fut b/tests/ad/vec/map6.fut new file mode 100644 index 0000000000..8a3b010171 --- /dev/null +++ b/tests/ad/vec/map6.fut @@ -0,0 +1,37 @@ +-- #1878 +-- == +-- tags { autodiff } +-- entry: fwd_map fwd_vec rev_map rev_vec +-- input { [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0] } +-- output { [[0.0, 2.0, 3.0, 4.0], +-- [0.0, 0.0, 1.0, 1.0], +-- [0.0, 0.0, 0.0, 1.0], +-- [0.0, 0.0, 0.0, 0.0], +-- [-4.0, -6.0, -7.0, -8.0], +-- [0.0, 0.0, -1.0, -1.0], +-- [0.0, 0.0, 0.0, -1.0], +-- [0.0, 0.0, 0.0, 0.0]] +-- } + +def obj (x: [8]f64) = + #[unsafe] + -- For simplicity of generated code. + let col_w_pre_red = + tabulate_3d 4 2 4 (\k i j -> x[k + j] * x[i + j]) + let col_w_red = + map (map f64.sum) col_w_pre_red + let col_eq: [4]f64 = + map (\w -> w[0] - w[1]) col_w_red + in col_eq + +entry fwd_map (x: [8]f64) = + tabulate 8 (\i -> jvp obj x (replicate 8 0 with [i] = 1)) + +entry fwd_vec (x: [8]f64) = + jvp_vec obj x (tabulate 8 (\i -> replicate 8 0 with [i] = 1)) + +entry rev_map (x: [8]f64) = + transpose (tabulate 4 (\i -> vjp obj x (replicate 4 0 with [i] = 1))) + +entry rev_vec (x: [8]f64) = + transpose (vjp_vec obj x (tabulate 4 (\i -> replicate 4 0 with [i] = 1))) From e6949ab47547cadab40a8d4267475025a417dba1 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 8 Oct 2025 12:52:58 +0200 Subject: [PATCH 28/70] Handle vector adjoints. --- src/Futhark/AD/Rev/Monad.hs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/Futhark/AD/Rev/Monad.hs b/src/Futhark/AD/Rev/Monad.hs index 29c13e5952..8f2aeb0deb 100644 --- a/src/Futhark/AD/Rev/Monad.hs +++ b/src/Futhark/AD/Rev/Monad.hs @@ -433,11 +433,13 @@ updateAdjIndex v (check, i) se = do =<< case v_adj_t of Acc {} -> do let stms s = do + vec_shape <- askShape dims <- arrayDims <$> lookupType se_v ~[v_adj'] <- - tabNest (length dims) [se_v, v_adj] $ \is [se_v', v_adj'] -> + tabNest (length dims) [se_v, v_adj] $ \is [se_v', v_adj'] -> do + let (vec_is, val_is) = splitAt (shapeRank vec_shape) $ map Var is letTupExp "acc" . BasicOp $ - UpdateAcc s v_adj' (i : map Var is) [Var se_v'] + UpdateAcc s v_adj' (vec_is ++ i : val_is) [Var se_v'] pure v_adj' case check of CheckBounds _ -> stms Safe From 5425e7fe3075b8d6f65ccf00f7c5f4d413884b73 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 8 Oct 2025 15:12:29 +0200 Subject: [PATCH 29/70] Support unrolling of maps over accumulators. --- src/Futhark/IR/SOACS/Simplify.hs | 24 +++++++------- src/Futhark/Pass/AD.hs | 4 ++- src/Futhark/Transform/FirstOrderTransform.hs | 34 +++++++++++++------- 3 files changed, 39 insertions(+), 23 deletions(-) diff --git a/src/Futhark/IR/SOACS/Simplify.hs b/src/Futhark/IR/SOACS/Simplify.hs index 33e802b5d9..4a2a6ed2a0 100644 --- a/src/Futhark/IR/SOACS/Simplify.hs +++ b/src/Futhark/IR/SOACS/Simplify.hs @@ -32,6 +32,7 @@ import Data.List.NonEmpty (NonEmpty (..)) import Data.Map.Strict qualified as M import Data.Maybe import Data.Set qualified as S +import Futhark.Analysis.Alias qualified as Alias import Futhark.Analysis.DataDependencies import Futhark.Analysis.SymbolTable qualified as ST import Futhark.Analysis.UsageTable qualified as UT @@ -46,6 +47,7 @@ import Futhark.Optimise.Simplify.Rules import Futhark.Optimise.Simplify.Rules.ClosedForm import Futhark.Pass import Futhark.Tools +import Futhark.Transform.FirstOrderTransform qualified as FOT import Futhark.Transform.Rename import Futhark.Util @@ -619,7 +621,7 @@ simplifyClosedFormReduce _ _ _ _ = Skip -- For now we just remove singleton SOACs and those with unroll attributes. simplifyKnownIterationSOAC :: - (Buildable rep, BuilderOps rep, HasSOAC rep) => + (Buildable rep, BuilderOps rep, HasSOAC rep, Alias.AliasableRep rep) => TopDownRuleOp rep simplifyKnownIterationSOAC _ pat _ op | Just (Screma (Constant k) arrs (ScremaForm map_lam scans reds)) <- asSOAC op, @@ -677,16 +679,16 @@ simplifyKnownIterationSOAC _ pat _ op certifying cs $ letBindNames [v] $ BasicOp $ SubExp se -- simplifyKnownIterationSOAC _ pat aux op - | Just (Screma (Constant (IntValue (Int64Value k))) arrs (ScremaForm map_lam [] [])) <- asSOAC op, - "unroll" `inAttrs` stmAuxAttrs aux = Simplify $ do - arrs_elems <- fmap transpose . forM [0 .. k - 1] $ \i -> do - map_lam' <- renameLambda map_lam - eLambda map_lam' $ map (`eIndex` [eSubExp (constant i)]) arrs - forM_ (zip3 (patNames pat) arrs_elems (lambdaReturnType map_lam)) $ - \(v, arr_elems, t) -> - certifying (mconcat (map resCerts arr_elems)) $ - letBindNames [v] . BasicOp $ - ArrayLit (map resSubExp arr_elems) t + | Just (Screma w arrs form) <- asSOAC op, + Constant (IntValue (Int64Value k)) <- w, + "unroll" `inAttrs` stmAuxAttrs aux = + Simplify $ + auxing aux $ + FOT.transformScrema + pat + (Constant (IntValue (Int64Value k))) + arrs + form -- simplifyKnownIterationSOAC _ _ _ _ = Skip diff --git a/src/Futhark/Pass/AD.hs b/src/Futhark/Pass/AD.hs index 396e3cd6ba..5844fa4cc7 100644 --- a/src/Futhark/Pass/AD.hs +++ b/src/Futhark/Pass/AD.hs @@ -40,7 +40,9 @@ onStm mode scope (Let pat aux (Op (VJP shape args vec lam))) = do lam' <- onLambda mode scope lam if mode == All || lam == lam' then do - lam'' <- (`runReaderT` scope) . simplifyLambda =<< revVJP scope shape lam' + lam'' <- + (`runReaderT` scope) . simplifyLambda + =<< revVJP scope shape (stmAuxAttrs aux) lam' runBuilderT_ (bindLambda pat aux lam'' $ args ++ vec) scope else pure $ oneStm $ Let pat aux $ Op $ VJP shape args vec lam' onStm mode scope (Let pat aux (Op (JVP shape args vec lam))) = do diff --git a/src/Futhark/Transform/FirstOrderTransform.hs b/src/Futhark/Transform/FirstOrderTransform.hs index 656dd538a4..afd7e24017 100644 --- a/src/Futhark/Transform/FirstOrderTransform.hs +++ b/src/Futhark/Transform/FirstOrderTransform.hs @@ -13,6 +13,7 @@ module Futhark.Transform.FirstOrderTransform transformStmRecursively, transformLambda, transformSOAC, + transformScrema, ) where @@ -113,19 +114,15 @@ resultArray arrs ts = do letExp "result" =<< eBlank t mapM oneArray ts --- | Transform a single 'SOAC' into a do-loop. The body of the lambda --- is untouched, and may or may not contain further 'SOAC's depending --- on the given rep. -transformSOAC :: +-- | Sequentialise a single Screma. +transformScrema :: (Transformer m) => - Pat (LetDec (Rep m)) -> - SOAC (Rep m) -> + Pat dec -> + SubExp -> + [VName] -> + ScremaForm (Rep m) -> m () -transformSOAC _ JVP {} = - error "transformSOAC: unhandled JVP" -transformSOAC _ VJP {} = - error "transformSOAC: unhandled VJP" -transformSOAC pat (Screma w arrs form@(ScremaForm map_lam scans reds)) = do +transformScrema pat w arrs form@(ScremaForm map_lam scans reds) = do -- See Note [Translation of Screma]. -- -- Start by combining all the reduction and scan parts into a single @@ -226,6 +223,21 @@ transformSOAC pat (Screma w arrs form@(ScremaForm map_lam scans reds)) = do (++ patNames pat) <$> replicateM (length scanacc_params) (newVName "discard") letBindNames names $ Loop merge loopform loop_body + +-- | Transform a single 'SOAC' into a do-loop. The body of the lambda +-- is untouched, and may or may not contain further 'SOAC's depending +-- on the given rep. +transformSOAC :: + (Transformer m) => + Pat (LetDec (Rep m)) -> + SOAC (Rep m) -> + m () +transformSOAC _ JVP {} = + error "transformSOAC: unhandled JVP" +transformSOAC _ VJP {} = + error "transformSOAC: unhandled VJP" +transformSOAC pat (Screma w arrs form) = + transformScrema pat w arrs form transformSOAC pat (Stream w arrs nes lam) = do -- Create a loop that repeatedly applies the lambda body to a -- chunksize of 1. Hopefully this will lead to this outer loop From 35a0968a24b7eff2f63c264c33a2766b05357c79 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 8 Oct 2025 17:49:29 +0200 Subject: [PATCH 30/70] Support unrolling of vectorised AD. --- prelude/ad.fut | 10 +++--- src/Futhark/AD/Fwd.hs | 64 ++++++++++++++++++++++++++++--------- src/Futhark/AD/Rev.hs | 12 +++++-- src/Futhark/AD/Rev/Monad.hs | 32 ++++++++++++------- src/Futhark/Pass/AD.hs | 2 +- tests/ad/vec/map6.fut | 3 +- 6 files changed, 88 insertions(+), 35 deletions(-) diff --git a/prelude/ad.fut b/prelude/ad.fut index 82cb83ddbc..7acb7b2cbf 100644 --- a/prelude/ad.fut +++ b/prelude/ad.fut @@ -108,13 +108,15 @@ def jvp2 'a 'b (f: a -> b) (x: a) (x': a) : (b, b) = def vjp2 'a 'b (f: a -> b) (x: a) (y': b) : (b, a) = intrinsics.vjp2 f x y' --- | As `jvp2`, but accepts a vector of seed values. Semantically --- equivalent to mapping, but may be more efficient. +-- | As `jvp2`, but accepts a vector of seed values. Semantically equivalent to +-- mapping, but may be more efficient. If used with `#[unroll]`, tangent +-- calculations are unrolled when possible. def jvp2_vec 'a 'b [n] (f: a -> b) (x: a) (x': [n]a) : (b, [n]b) = intrinsics.jvp2_vec f x x' --- | As `vjp2`, but accepts a vector of seed values. Semantically --- equivalent to mapping, but may be more efficient. +-- | As `vjp2`, but accepts a vector of seed values. Semantically equivalent to +-- mapping, but may be more efficient. If used with `#[unroll]`, adjoint +-- calculations are unrolled when possible. def vjp2_vec 'a 'b [n] (f: a -> b) (x: a) (y': [n]b) : (b, [n]a) = intrinsics.vjp2_vec f x y' diff --git a/src/Futhark/AD/Fwd.hs b/src/Futhark/AD/Fwd.hs index 48e1531e79..4667a1a4c7 100644 --- a/src/Futhark/AD/Fwd.hs +++ b/src/Futhark/AD/Fwd.hs @@ -57,12 +57,18 @@ data RState = RState stateNameSource :: VNameSource } -newtype ADM a = ADM (BuilderT SOACS (ReaderT Shape (State RState)) a) +data FEnv = FEnv + { envTanShape :: Shape, + envAttrs :: Attrs + } + +newtype ADM a = ADM (BuilderT SOACS (ReaderT FEnv (State RState)) a) deriving ( Functor, Applicative, Monad, MonadState RState, + MonadReader FEnv, MonadFreshNames, HasScope SOACS, LocalScope SOACS @@ -82,14 +88,17 @@ instance MonadFreshNames (State RState) where putNameSource src = modify (\env -> env {stateNameSource = src}) askShape :: ADM Shape -askShape = ADM $ lift ask +askShape = ADM $ lift $ asks envTanShape -runADM :: (MonadFreshNames m) => Shape -> ADM a -> m a -runADM shape (ADM m) = +runADM :: (MonadFreshNames m) => Shape -> Attrs -> ADM a -> m a +runADM shape attrs (ADM m) = modifyNameSource $ \vn -> second stateNameSource $ runState - (runReaderT (fst <$> runBuilderT m mempty) shape) + ( runReaderT + (fst <$> runBuilderT m mempty) + (FEnv shape attrs) + ) (RState mempty vn) tanVName :: VName -> ADM VName @@ -247,6 +256,31 @@ withAnyTans xs f = do pure $ primExpFromSubExp t se toExp $ f $ toList xs_tan'' +bindTanPat :: Pat Type -> StmAux () -> Exp SOACS -> ADM () +bindTanPat pat_tan aux e = do + attrs <- asks envAttrs + auxing aux . attributing attrs . letBind pat_tan $ e + +bindTan :: + Pat Type -> + StmAux () -> + SubExp -> + (SubExp -> ADM (Exp SOACS)) -> + ADM () +bindTan pat_tan aux x f = do + bindTanPat pat_tan aux =<< withTan x f + +bindTans :: + Pat Type -> + StmAux () -> + PrimType -> + SubExp -> + SubExp -> + (PrimExp VName -> PrimExp VName -> PrimExp VName) -> + ADM () +bindTans pat_tan aux t x y f = do + bindTanPat pat_tan aux =<< withTans t x y f + basicFwd :: Pat Type -> StmAux () -> BasicOp -> ADM () basicFwd pat aux op = do pat_tan <- newTanPat pat @@ -272,20 +306,19 @@ basicFwd pat aux op = do let t = unOpType unop x_pe = primExpFromSubExp t x dx = pdUnOp unop x_pe - auxing aux $ letBind pat_tan <=< withTan x $ \x_tan -> + bindTan pat_tan aux x $ \x_tan -> toExp $ primExpFromSubExp t x_tan ~*~ dx BinOp bop x y -> do let t = binOpType bop - auxing aux . letBind pat_tan <=< withTans t x y $ - \x_tan y_tan -> - let (wrt_x, wrt_y) = - pdBinOp bop (primExpFromSubExp t x) (primExpFromSubExp t y) - in x_tan ~*~ wrt_x ~+~ y_tan ~*~ wrt_y + bindTans pat_tan aux t x y $ \x_tan y_tan -> + let (wrt_x, wrt_y) = + pdBinOp bop (primExpFromSubExp t x) (primExpFromSubExp t y) + in x_tan ~*~ wrt_x ~+~ y_tan ~*~ wrt_y CmpOp {} -> do tan_shape <- askShape addStm $ Let pat_tan aux $ zeroExp $ Prim Bool `arrayOfShape` tan_shape ConvOp cop x -> - auxing aux $ letBind pat_tan <=< withTan x $ \x_tan -> + bindTan pat_tan aux x $ \x_tan -> pure $ BasicOp $ ConvOp cop x_tan Assert {} -> pure () Index arr slice -> do @@ -314,7 +347,7 @@ basicFwd pat aux op = do addStm . Let pat_tan aux . BasicOp $ Replicate (shape <> Shape [n]) (intConst it 0) Replicate n x -> - auxing aux $ letBind pat_tan <=< withTan x $ \x_tan -> + bindTan pat_tan aux x $ \x_tan -> pure $ BasicOp $ Replicate n x_tan Scratch t shape -> do tan_shape <- askShape @@ -577,10 +610,11 @@ fwdJVP :: (MonadFreshNames m) => Scope SOACS -> Shape -> + Attrs -> Lambda SOACS -> m (Lambda SOACS) -fwdJVP scope shape (Lambda params _ body) = - runADM shape . localScope scope $ do +fwdJVP scope shape attrs (Lambda params _ body) = + runADM shape attrs . localScope scope $ do params_tan <- mapM newTan params mkLambda (params <> params_tan) $ bodyBind =<< fwdBodyTansLast body diff --git a/src/Futhark/AD/Rev.hs b/src/Futhark/AD/Rev.hs index 05e9a83635..f5d952ea60 100644 --- a/src/Futhark/AD/Rev.hs +++ b/src/Futhark/AD/Rev.hs @@ -381,9 +381,15 @@ diffLambda res_adjs get_adjs_for (Lambda params _ body) = res <- bodyBind =<< diffBody res_adjs get_adjs_for body pure $ takeLast (length get_adjs_for) res -revVJP :: (MonadFreshNames m) => Scope SOACS -> Shape -> Lambda SOACS -> m (Lambda SOACS) -revVJP scope shape (Lambda params ts body) = do - runADM shape . localScope (scope <> scopeOfLParams params) $ do +revVJP :: + (MonadFreshNames m) => + Scope SOACS -> + Shape -> + Attrs -> + Lambda SOACS -> + m (Lambda SOACS) +revVJP scope shape attrs (Lambda params ts body) = do + runADM shape attrs . localScope (scope <> scopeOfLParams params) $ do adj_shape <- askShape params_adj <- forM (zip (map resSubExp (bodyResult body)) ts) $ \(se, t) -> Param mempty diff --git a/src/Futhark/AD/Rev/Monad.hs b/src/Futhark/AD/Rev/Monad.hs index 8f2aeb0deb..5f3dca1632 100644 --- a/src/Futhark/AD/Rev/Monad.hs +++ b/src/Futhark/AD/Rev/Monad.hs @@ -9,6 +9,7 @@ module Futhark.AD.Rev.Monad ( ADM, RState (..), + REnv, runADM, Adj (..), InBounds (..), @@ -204,13 +205,18 @@ data RState = RState stateNameSource :: VNameSource } -newtype ADM a = ADM (BuilderT SOACS (ReaderT Shape (State RState)) a) +data REnv = REnv + { envAdjShape :: Shape, + envAttrs :: Attrs + } + +newtype ADM a = ADM (BuilderT SOACS (ReaderT REnv (State RState)) a) deriving ( Functor, Applicative, Monad, MonadState RState, - MonadReader Shape, + MonadReader REnv, MonadFreshNames, HasScope SOACS, LocalScope SOACS @@ -230,14 +236,16 @@ instance MonadFreshNames (State RState) where putNameSource src = modify (\env -> env {stateNameSource = src}) askShape :: ADM Shape -askShape = ADM $ lift ask +askShape = ADM $ lift $ asks envAdjShape -runADM :: (MonadFreshNames m) => Shape -> ADM a -> m a -runADM shape (ADM m) = +runADM :: (MonadFreshNames m) => Shape -> Attrs -> ADM a -> m a +runADM shape attrs (ADM m) = modifyNameSource $ \vn -> second stateNameSource $ runState - (runReaderT (fst <$> runBuilderT m mempty) shape) + ( runReaderT (fst <$> runBuilderT m mempty) $ + REnv shape attrs + ) (RState mempty mempty mempty vn) adjVal :: Adj -> ADM VName @@ -434,12 +442,14 @@ updateAdjIndex v (check, i) se = do Acc {} -> do let stms s = do vec_shape <- askShape + attrs <- asks envAttrs dims <- arrayDims <$> lookupType se_v ~[v_adj'] <- - tabNest (length dims) [se_v, v_adj] $ \is [se_v', v_adj'] -> do - let (vec_is, val_is) = splitAt (shapeRank vec_shape) $ map Var is - letTupExp "acc" . BasicOp $ - UpdateAcc s v_adj' (vec_is ++ i : val_is) [Var se_v'] + attributing attrs $ + tabNest (length dims) [se_v, v_adj] $ \is [se_v', v_adj'] -> do + let (vec_is, val_is) = splitAt (shapeRank vec_shape) $ map Var is + letTupExp "acc" . BasicOp $ + UpdateAcc s v_adj' (vec_is ++ i : val_is) [Var se_v'] pure v_adj' case check of CheckBounds _ -> stms Safe @@ -590,7 +600,7 @@ locallyNonvectorised e m = do e_adjs_vals <- mapM lookupAdjVal e_adjs e_free_adjs <- mkMap "nonvec_adj" e_adjs_vals $ \e_adjs_vals' -> do zipWithM_ insAdj e_adjs e_adjs_vals' - local (const mempty) m + local (\env -> env {envAdjShape = mempty}) m mapM lookupAdjVal e_free zipWithM_ insAdj e_free e_free_adjs where diff --git a/src/Futhark/Pass/AD.hs b/src/Futhark/Pass/AD.hs index 5844fa4cc7..0245a012c1 100644 --- a/src/Futhark/Pass/AD.hs +++ b/src/Futhark/Pass/AD.hs @@ -49,7 +49,7 @@ onStm mode scope (Let pat aux (Op (JVP shape args vec lam))) = do lam' <- onLambda mode scope lam if mode == All || lam == lam' then do - lam'' <- fwdJVP scope shape lam' + lam'' <- fwdJVP scope shape (stmAuxAttrs aux) lam' runBuilderT_ (bindLambda pat aux lam'' $ args ++ vec) scope else pure $ oneStm $ Let pat aux $ Op $ JVP shape args vec lam' onStm mode scope (Let pat aux e) = oneStm . Let pat aux <$> mapExpM mapper e diff --git a/tests/ad/vec/map6.fut b/tests/ad/vec/map6.fut index 8a3b010171..fb86d14dd9 100644 --- a/tests/ad/vec/map6.fut +++ b/tests/ad/vec/map6.fut @@ -28,10 +28,11 @@ entry fwd_map (x: [8]f64) = tabulate 8 (\i -> jvp obj x (replicate 8 0 with [i] = 1)) entry fwd_vec (x: [8]f64) = + #[unroll] jvp_vec obj x (tabulate 8 (\i -> replicate 8 0 with [i] = 1)) entry rev_map (x: [8]f64) = transpose (tabulate 4 (\i -> vjp obj x (replicate 4 0 with [i] = 1))) entry rev_vec (x: [8]f64) = - transpose (vjp_vec obj x (tabulate 4 (\i -> replicate 4 0 with [i] = 1))) + transpose (#[unroll] vjp_vec obj x (tabulate 4 (\i -> replicate 4 0 with [i] = 1))) From 9679a74b2834c5a5bc370df776a5295bc3d50b4b Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Fri, 29 May 2026 11:50:24 +0200 Subject: [PATCH 31/70] Extend AD vectorized entry coverage for additional `tests/ad` cases (#2472) * Add jvp_vec/vjp_vec entry points to tests/ad/ test files Add vectorised AD entry points (fwd_vec/rev_vec) to all pertinent test programs in tests/ad/ that currently compare jvp and vjp Jacobians. The new entry points use jvp_vec/vjp_vec to compute the same Jacobians without explicit loops, and are added to the same test directive so they must produce the same expected output. * Fix structure tests. * Add vectorized AD entries for additional autodiff tests --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: Troels Henriksen --- tests/ad/consume0.fut | 12 ++++++++++- tests/ad/consume1.fut | 12 ++++++++++- tests/ad/for1.fut | 11 +++++++++- tests/ad/for2.fut | 9 +++++++- tests/ad/for3.fut | 11 +++++++++- tests/ad/gather0.fut | 10 ++++++++- tests/ad/gather1.fut | 15 ++++++++++++- tests/ad/gather2.fut | 13 ++++++++++- tests/ad/issue2256.fut | 6 +++++- tests/ad/map2.fut | 6 +++++- tests/ad/map3.fut | 6 +++++- tests/ad/map4.fut | 14 +++++++++++- tests/ad/map5.fut | 10 ++++++++- tests/ad/map6.fut | 10 ++++++++- tests/ad/map7.fut | 6 +++++- tests/ad/matmul.fut | 15 ++++++++++++- tests/ad/maximum.fut | 8 +++++-- tests/ad/minimum.fut | 6 +++++- tests/ad/minmax.fut | 8 +++++-- tests/ad/reduce-vec-minmax0.fut | 18 +++++++++++++++- tests/ad/reduce2.fut | 5 ++++- tests/ad/reduce_by_index0.fut | 7 +++++- tests/ad/reducebyindex3.fut | 6 +++++- tests/ad/reducebyindex4.fut | 6 +++++- tests/ad/reducebyindexminmax3.fut | 5 ++++- tests/ad/reducebyindexminmax4.fut | 5 ++++- tests/ad/reducebyindexminmax7.fut | 16 +++++++++++++- tests/ad/reducebyindexminmax8.fut | 16 +++++++++++++- tests/ad/reducemul0.fut | 6 +++++- tests/ad/reducemul4.fut | 10 ++++++++- tests/ad/reducevec0.fut | 20 ++++++++++++++++- tests/ad/scan0.fut | 10 ++++++++- tests/ad/scan1.fut | 10 ++++++++- tests/ad/scan2.fut | 10 ++++++++- tests/ad/scan3.fut | 18 +++++++++++++++- tests/ad/scan4.fut | 15 ++++++++++++- tests/ad/scan5.fut | 10 ++++++++- tests/ad/scan6.fut | 36 +++++++++++++++++++++++++++++-- tests/ad/scan7.fut | 23 +++++++++++++++++++- tests/ad/scan8.fut | 18 +++++++++++++++- tests/ad/scan9.fut | 18 +++++++++++++++- tests/ad/scatter0.fut | 10 ++++++++- tests/ad/scatter1.fut | 10 ++++++++- tests/ad/stripmine1.fut | 11 +++++++++- tests/ad/stripmine2.fut | 11 +++++++++- tests/ad/sum.fut | 6 +++++- tests/ad/truedep0.fut | 11 +++++++++- 47 files changed, 485 insertions(+), 50 deletions(-) diff --git a/tests/ad/consume0.fut b/tests/ad/consume0.fut index 44eb00a54a..4df14e3ed6 100644 --- a/tests/ad/consume0.fut +++ b/tests/ad/consume0.fut @@ -1,7 +1,7 @@ -- == -- tags { autodiff } --- entry: rev fwd +-- entry: rev fwd fwd_vec rev_vec -- input { [1.0,2.0,3.0] } -- output { [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] } @@ -15,3 +15,13 @@ entry fwd [n] (xs: *[n]f64) = entry rev [n] (xs: *[n]f64) = #[unsafe] tabulate n (\i -> vjp f xs (replicate n 0 with [i] = 1)) + +entry fwd_vec [n] (xs: *[n]f64) = + #[unsafe] + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in jvp_vec f xs seeds + +entry rev_vec [n] (xs: *[n]f64) = + #[unsafe] + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in vjp_vec f xs seeds diff --git a/tests/ad/consume1.fut b/tests/ad/consume1.fut index 66e250d0f4..f5cd718edf 100644 --- a/tests/ad/consume1.fut +++ b/tests/ad/consume1.fut @@ -1,7 +1,7 @@ -- == -- tags { autodiff } --- entry: rev fwd +-- entry: rev fwd fwd_vec rev_vec -- input { true [1.0,2.0,3.0] } -- output { [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] } @@ -16,3 +16,13 @@ entry fwd [n] b (xs: *[n]f64) = entry rev [n] b (xs: *[n]f64) = #[unsafe] tabulate n (\i -> vjp (f b) xs (replicate n 0 with [i] = 1)) + +entry fwd_vec [n] b (xs: *[n]f64) = + #[unsafe] + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in jvp_vec (f b) xs seeds + +entry rev_vec [n] b (xs: *[n]f64) = + #[unsafe] + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in vjp_vec (f b) xs seeds diff --git a/tests/ad/for1.fut b/tests/ad/for1.fut index 799fe893ff..098079ef05 100644 --- a/tests/ad/for1.fut +++ b/tests/ad/for1.fut @@ -12,7 +12,7 @@ def pow_list [n] y (xs: [n]i32) = entry prim y xs = pow_list y xs -- == --- entry: f_vjp f_jvp +-- entry: f_vjp f_jvp f_vjp_vec f_jvp_vec -- input { 3 [1,2,3] } -- output { [[3,0,0], -- [0,12,0], @@ -23,3 +23,12 @@ entry f_jvp [n] y (xs: [n]i32) = entry f_vjp [n] y (xs: [n]i32) = tabulate n (\i -> vjp (pow_list y) xs (replicate n 0 with [i] = 1)) + +entry f_jvp_vec [n] y (xs: [n]i32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in jvp_vec (pow_list y) xs seeds + |> transpose + +entry f_vjp_vec [n] y (xs: [n]i32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in vjp_vec (pow_list y) xs seeds diff --git a/tests/ad/for2.fut b/tests/ad/for2.fut index 474fbaae88..6e56a84d8f 100644 --- a/tests/ad/for2.fut +++ b/tests/ad/for2.fut @@ -12,9 +12,16 @@ def mult_list xs = entry prim = mult_list -- == --- entry: f_jvp f_vjp +-- entry: f_jvp f_vjp f_jvp_vec f_vjp_vec -- input { [11,5,13] } output { [0,0,26] } entry f_jvp [n] (xs: [n]i32) = tabulate n (\i -> jvp mult_list xs (replicate n 0 with [i] = 1)) entry f_vjp [n] (xs: [n]i32) = vjp mult_list xs 1 + +entry f_jvp_vec [n] (xs: [n]i32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in jvp_vec mult_list xs seeds + +entry f_vjp_vec [n] (xs: [n]i32) = + (vjp_vec mult_list xs [1])[0] diff --git a/tests/ad/for3.fut b/tests/ad/for3.fut index c7d02db01c..b3755c8178 100644 --- a/tests/ad/for3.fut +++ b/tests/ad/for3.fut @@ -14,7 +14,7 @@ def square [n] (xs: [n]i32) = entry prim [n] (xs: [n]i32) = square xs -- == --- entry: f_jvp f_vjp +-- entry: f_jvp f_vjp f_jvp_vec f_vjp_vec -- input { [1,2,3,4,5] } -- output { [[2,0,0,0,0], -- [0,4,0,0,0], @@ -27,3 +27,12 @@ entry f_jvp [n] (xs: [n]i32) = entry f_vjp [n] (xs: [n]i32) = tabulate n (\i -> vjp square xs (replicate n 0 with [i] = 1)) + +entry f_jvp_vec [n] (xs: [n]i32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in jvp_vec square xs seeds + |> transpose + +entry f_vjp_vec [n] (xs: [n]i32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in vjp_vec square xs seeds diff --git a/tests/ad/gather0.fut b/tests/ad/gather0.fut index ebc7c9905a..226622aa14 100644 --- a/tests/ad/gather0.fut +++ b/tests/ad/gather0.fut @@ -1,6 +1,6 @@ -- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_J rev_J fwd_vec_J rev_vec_J -- input { [4.0,3.0,2.0,1.0] [0i64,1i64,2i64,3i64] } -- output { [[1.0, 0.0, 0.0, 0.0], -- [0.0, 1.0, 0.0, 0.0], @@ -21,3 +21,11 @@ entry fwd_J [n] [m] (xs: [n]f64) (is: [m]i64) = entry rev_J [n] [m] (xs: [n]f64) (is: [m]i64) = tabulate m (\j -> vjp (`gather` is) xs (replicate m 0 with [j] = 1)) + +entry fwd_vec_J [n] [m] (xs: [n]f64) (is: [m]i64) = + let seeds = tabulate n (\j -> replicate n 0 with [j] = 1) + in transpose (jvp_vec (`gather` is) xs seeds) + +entry rev_vec_J [n] [m] (xs: [n]f64) (is: [m]i64) = + let seeds = tabulate m (\j -> replicate m 0 with [j] = 1) + in vjp_vec (`gather` is) xs seeds diff --git a/tests/ad/gather1.fut b/tests/ad/gather1.fut index bb9863fed8..34401c40f8 100644 --- a/tests/ad/gather1.fut +++ b/tests/ad/gather1.fut @@ -1,6 +1,6 @@ -- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_J rev_J fwd_vec_J rev_vec_J -- input -- { -- [[1.0,2.0],[3.0,4.0]] [1i64, 0i64, 1i64, 1i64] @@ -43,3 +43,16 @@ entry fwd_J [n] [m] [k] (xs: [n][m]f64) (is: [k]i64) = entry rev_J [n] [m] [k] (xs: [n][m]f64) (is: [k]i64) = tabulate_2d n k (\i j -> vjp (`mapgather` is) xs (onehot_2d n k (i, j))) + +entry fwd_vec_J [n] [m] [k] (xs: [n][m]f64) (is: [k]i64) = + let seeds = tabulate (n * m) (\p -> onehot_2d n m (p / m, p % m)) + in jvp_vec (`mapgather` is) xs seeds + |> unflatten + |> map transpose + |> map (map transpose) + |> map transpose + +entry rev_vec_J [n] [m] [k] (xs: [n][m]f64) (is: [k]i64) = + let seeds = tabulate (n * k) (\p -> onehot_2d n k (p / k, p % k)) + in vjp_vec (`mapgather` is) xs seeds + |> unflatten diff --git a/tests/ad/gather2.fut b/tests/ad/gather2.fut index 3bc3efcc24..0dd24c82bd 100644 --- a/tests/ad/gather2.fut +++ b/tests/ad/gather2.fut @@ -1,6 +1,6 @@ -- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_J rev_J fwd_vec_J rev_vec_J -- input -- { -- [1.0,2.0,3.0,4.0] @@ -32,3 +32,14 @@ def onehot_2d n m p : [n][m]f64 = entry rev_J [k] [n] [m] (xs: [k]f64) (iss: [n][m]i64) = tabulate_2d n m (\i j -> vjp (`mapgather` iss) xs (onehot_2d n m (i, j))) + +entry fwd_vec_J [k] [n] [m] (xs: [k]f64) (iss: [n][m]i64) = + let seeds = tabulate k (\i -> onehot k i) + in jvp_vec (`mapgather` iss) xs seeds + |> transpose + |> map transpose + +entry rev_vec_J [k] [n] [m] (xs: [k]f64) (iss: [n][m]i64) = + let seeds = tabulate (n * m) (\p -> onehot_2d n m (p / m, p % m)) + in vjp_vec (`mapgather` iss) xs seeds + |> unflatten diff --git a/tests/ad/issue2256.fut b/tests/ad/issue2256.fut index 09fe505daa..512d78af2e 100644 --- a/tests/ad/issue2256.fut +++ b/tests/ad/issue2256.fut @@ -1,6 +1,6 @@ -- == -- tags { autodiff } --- entry: fwd rev +-- entry: fwd rev fwd_vec -- input { [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] } -- output { [5040.0, 2521.0, 1684.0, 1278.0, 1104.0, 1440.0] } @@ -13,3 +13,7 @@ entry rev [m] (x: [m]f64) = entry fwd [m] (x: [m]f64) = tabulate m (\i -> jvp (\x' -> primal x') x (replicate m 0 with [i] = 1)) + +entry fwd_vec [m] (x: [m]f64) = + let seeds = tabulate m (\i -> replicate m 0 with [i] = 1) + in jvp_vec (\x' -> primal x') x seeds diff --git a/tests/ad/map2.fut b/tests/ad/map2.fut index 6970cce527..2e3a5916c8 100644 --- a/tests/ad/map2.fut +++ b/tests/ad/map2.fut @@ -1,7 +1,7 @@ -- Map with free variable. -- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_J rev_J rev_vec_J -- input { 2.0 [1.0,2.0,3.0] } -- output { [1.0,2.0,3.0] } @@ -13,3 +13,7 @@ entry fwd_J [n] (c: f64) (xs: [n]f64) = entry rev_J [n] (c: f64) (xs: [n]f64) = tabulate n (\i -> vjp (\c' -> map (* c') xs) c (onehot n i)) + +entry rev_vec_J [n] (c: f64) (xs: [n]f64) = + let seeds = tabulate n (\i -> onehot n i) + in vjp_vec (\c' -> map (* c') xs) c seeds diff --git a/tests/ad/map3.fut b/tests/ad/map3.fut index 5b817ac57d..36ae1e509a 100644 --- a/tests/ad/map3.fut +++ b/tests/ad/map3.fut @@ -1,6 +1,6 @@ -- == -- tags { autodiff } --- entry: fwd rev +-- entry: fwd rev rev_vec -- input { 1i32 [1i32,2i32,3i32] } -- output { [1i32,2i32,3i32] } @@ -10,3 +10,7 @@ entry fwd [n] (x: i32) (xs: [n]i32) = entry rev [n] (x: i32) (xs: [n]i32) = tabulate n (\i -> vjp (\x -> map (* x) xs) x (replicate n 0 with [i] = 1)) + +entry rev_vec [n] (x: i32) (xs: [n]i32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in vjp_vec (\x -> map (* x) xs) x seeds diff --git a/tests/ad/map4.fut b/tests/ad/map4.fut index aabbc4f02b..e42e5e902e 100644 --- a/tests/ad/map4.fut +++ b/tests/ad/map4.fut @@ -1,7 +1,7 @@ -- An array is both a 'map' input and a free variable in the lambda. -- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_J rev_J fwd_vec_J rev_vec_J -- input { [1,2,3] } -- output { -- [[[2, 0, 0], [1, 1, 0], [1, 0, 1]], [[1, 1, 0], [0, 2, 0], [0, 1, 1]], [[1, 0, 1], [0, 1, 1], [0, 0, 2]]] @@ -24,3 +24,15 @@ entry fwd_J [n] (xs: [n]i32) = entry rev_J [n] (xs: [n]i32) = tabulate_2d n n (\i j -> vjp f xs (onehot_2d n n (i, j))) + +entry fwd_vec_J [n] (xs: [n]i32) = + let seeds = tabulate n (\i -> onehot n i) + in jvp_vec f xs seeds + |> map transpose + |> transpose + |> map transpose + +entry rev_vec_J [n] (xs: [n]i32) = + let seeds = tabulate (n * n) (\k -> onehot_2d n n (k / n, k % n)) + in vjp_vec f xs seeds + |> unflatten diff --git a/tests/ad/map5.fut b/tests/ad/map5.fut index 15e6e8a28c..e90dd46328 100644 --- a/tests/ad/map5.fut +++ b/tests/ad/map5.fut @@ -1,7 +1,7 @@ -- Map with free array variable. -- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_J rev_J fwd_vec_J rev_vec_J -- input { [[1,2,3],[4,5,6]] [0,0] } -- output { [[1, 0], [0, 1]] } @@ -16,3 +16,11 @@ entry fwd_J [n] [m] (free: [n][m]i32) (is: [n]i32) = entry rev_J [n] [m] (free: [n][m]i32) (is: [n]i32) = tabulate n (\i -> vjp (f free) is (onehot n i)) + +entry fwd_vec_J [n] [m] (free: [n][m]i32) (is: [n]i32) = + let seeds = tabulate n (\i -> onehot n i) + in jvp_vec (f free) is seeds |> transpose + +entry rev_vec_J [n] [m] (free: [n][m]i32) (is: [n]i32) = + let seeds = tabulate n (\i -> onehot n i) + in vjp_vec (f free) is seeds diff --git a/tests/ad/map6.fut b/tests/ad/map6.fut index 8bf30b853e..f46bb5a984 100644 --- a/tests/ad/map6.fut +++ b/tests/ad/map6.fut @@ -1,7 +1,7 @@ -- #1878 -- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_J rev_J fwd_vec_J rev_vec_J -- input { [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0] } -- output { [[0.0, 2.0, 3.0, 4.0], -- [0.0, 0.0, 1.0, 1.0], @@ -29,3 +29,11 @@ entry fwd_J (x: [8]f64) = entry rev_J (x: [8]f64) = transpose (tabulate 4 (\i -> vjp obj x (replicate 4 0 with [i] = 1))) + +entry fwd_vec_J (x: [8]f64) = + let seeds = tabulate 8 (\i -> replicate 8 0 with [i] = 1) + in jvp_vec obj x seeds + +entry rev_vec_J (x: [8]f64) = + let seeds = tabulate 4 (\i -> replicate 4 0 with [i] = 1) + in transpose (vjp_vec obj x seeds) diff --git a/tests/ad/map7.fut b/tests/ad/map7.fut index 6c9cf348ba..941d2d15b1 100644 --- a/tests/ad/map7.fut +++ b/tests/ad/map7.fut @@ -2,7 +2,7 @@ -- has active free variables. -- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_J rev_J fwd_vec_J -- input { [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0] } -- output { [0.0, 0.0, 0.0, 0.0, -4.0, 0.0, 0.0, 0.0] } @@ -22,3 +22,7 @@ entry fwd_J (x: [8]f64) = entry rev_J (x: [8]f64) = vjp obj x 1 + +entry fwd_vec_J (x: [8]f64) = + let seeds = tabulate 8 (\i -> replicate 8 0 with [i] = 1) + in jvp_vec obj x seeds diff --git a/tests/ad/matmul.fut b/tests/ad/matmul.fut index 123eaa41b4..309096f811 100644 --- a/tests/ad/matmul.fut +++ b/tests/ad/matmul.fut @@ -1,6 +1,6 @@ -- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_J rev_J fwd_vec_J rev_vec_J -- input -- { -- [[1.0,2.0],[3.0,4.0]] [[5.0,6.0],[7.0,8.0]] @@ -32,3 +32,16 @@ entry fwd_J [n] [m] [p] (xss: [n][m]f64) (yss: [m][p]f64) = entry rev_J [n] [m] [p] (xss: [n][m]f64) (yss: [m][p]f64) = tabulate_2d n p (\i j -> vjp (matmul xss) yss (onehot_2d n p (i, j))) + +entry fwd_vec_J [n] [m] [p] (xss: [n][m]f64) (yss: [m][p]f64) = + let seeds = tabulate (m * p) (\k -> onehot_2d m p (k / p, k % p)) + in jvp_vec (matmul xss) yss seeds + |> unflatten + |> transpose + |> map transpose + |> transpose + +entry rev_vec_J [n] [m] [p] (xss: [n][m]f64) (yss: [m][p]f64) = + let seeds = tabulate (n * p) (\k -> onehot_2d n p (k / p, k % p)) + in vjp_vec (matmul xss) yss seeds + |> unflatten diff --git a/tests/ad/maximum.fut b/tests/ad/maximum.fut index f605182d5b..2e6f149d9e 100644 --- a/tests/ad/maximum.fut +++ b/tests/ad/maximum.fut @@ -1,11 +1,11 @@ -- == -- tags { autodiff } --- entry: rev fwd +-- entry: rev fwd fwd_vec -- input { [1.0, 2.0, 3.0, 4.0, 5.0, -6.0, 5.0] } -- output { [0.0, 0.0, 0.0, 0.0, 0.0, -1.0, 0.0] } -- input { [1.0, 1.0] } -- output { [1.0, 0.0] } --- structure { /Screma 2 } +-- structure { /Screma 3 } def f = map f64.abs >-> f64.maximum @@ -14,3 +14,7 @@ entry rev [n] (xs: [n]f64) = entry fwd [n] (xs: [n]f64) = tabulate n (\i -> jvp f xs (tabulate n ((== i) >-> f64.bool))) + +entry fwd_vec [n] (xs: [n]f64) = + let seeds = tabulate n (\i -> tabulate n ((== i) >-> f64.bool)) + in jvp_vec f xs seeds diff --git a/tests/ad/minimum.fut b/tests/ad/minimum.fut index f8942cf184..a47e645895 100644 --- a/tests/ad/minimum.fut +++ b/tests/ad/minimum.fut @@ -1,6 +1,6 @@ -- == -- tags { autodiff } --- entry: rev fwd +-- entry: rev fwd fwd_vec -- input { [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 5.0] } -- output { [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] } -- input { [1.0, 1.0] } @@ -11,3 +11,7 @@ entry rev [n] (xs: [n]f64) = entry fwd [n] (xs: [n]f64) = tabulate n (\i -> jvp f64.minimum xs (tabulate n ((== i) >-> f64.bool))) + +entry fwd_vec [n] (xs: [n]f64) = + let seeds = tabulate n (\i -> tabulate n ((== i) >-> f64.bool)) + in jvp_vec f64.minimum xs seeds diff --git a/tests/ad/minmax.fut b/tests/ad/minmax.fut index e71b69734e..17dabc103a 100644 --- a/tests/ad/minmax.fut +++ b/tests/ad/minmax.fut @@ -1,11 +1,11 @@ -- == -- tags { autodiff } --- entry: rev fwd +-- entry: rev fwd fwd_vec -- input { [1.0, 2.0, 3.0, 4.0, 5.0, -6.0, 5.0] } -- output { [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] -- [0.0, 0.0, 0.0, 0.0, 0.0, -1.0, 0.0] -- } --- structure { /Screma 2 } +-- structure { /Screma 3 } def f xs = let ys = map f64.abs xs @@ -18,3 +18,7 @@ entry rev [n] (xs: [n]f64) = entry fwd [n] (xs: [n]f64) = unzip (tabulate n (\i -> jvp f xs (tabulate n ((== i) >-> f64.bool)))) + +entry fwd_vec [n] (xs: [n]f64) = + let seeds = tabulate n (\i -> tabulate n ((== i) >-> f64.bool)) + in unzip (jvp_vec f xs seeds) diff --git a/tests/ad/reduce-vec-minmax0.fut b/tests/ad/reduce-vec-minmax0.fut index d846d4dec5..ba6572a06e 100644 --- a/tests/ad/reduce-vec-minmax0.fut +++ b/tests/ad/reduce-vec-minmax0.fut @@ -14,8 +14,24 @@ def forward [n] [m] (arr: [m][n]f32) : [n][m][n]f32 = def reverse [n] [m] (arr: [m][n]f32) : [n][m][n]f32 = tabulate n (\i -> vjp redmap arr (replicate n 0 with [i] = 1)) +def forward_vec [n] [m] (arr: [m][n]f32) : [n][m][n]f32 = + let seeds = tabulate (m * n) (\p -> + let i = p / n + let j = p % n + in replicate m (replicate n 0) with [i] = (replicate n 0 with [j] = 1)) + in jvp_vec redmap arr seeds + |> unflatten + |> transpose + +def reverse_vec [n] [m] (arr: [m][n]f32) : [n][m][n]f32 = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in vjp_vec redmap arr seeds + def main [n] [m] (arr: [m][n]f32) : bool = let l = n * m * n let fs = forward arr |> flatten_3d :> [l]f32 let rs = reverse arr |> flatten_3d :> [l]f32 - in reduce (&&) true (map2 (\i j -> f32.abs (i - j) < 0.0001f32) fs rs) + let fvs = forward_vec arr |> flatten_3d :> [l]f32 + let rvs = reverse_vec arr |> flatten_3d :> [l]f32 + let close xs ys = reduce (&&) true (map2 (\i j -> f32.abs (i - j) < 0.0001f32) xs ys) + in close fs rs && close fs fvs && close rs rvs diff --git a/tests/ad/reduce2.fut b/tests/ad/reduce2.fut index 43d2c85f16..4a68f1761b 100644 --- a/tests/ad/reduce2.fut +++ b/tests/ad/reduce2.fut @@ -1,7 +1,7 @@ -- Result of one reduction is used free in a map. -- == -- tags { no_ispc autodiff } --- entry: fwd rev +-- entry: fwd rev fwd_vec -- input { [3f64, 1f64, 5f64] } output { [-1.000000f64, -1.000000f64, -1.000000f64] } def sumBy 'a (f: a -> f64) (xs: []a) : f64 = map f xs |> f64.sum @@ -14,3 +14,6 @@ def f (arr: []f64) = entry fwd x = map (jvp f x) [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] entry rev x = vjp f x 1f64 + +entry fwd_vec x = + jvp_vec f x [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] diff --git a/tests/ad/reduce_by_index0.fut b/tests/ad/reduce_by_index0.fut index 2f8471250b..7663a48a69 100644 --- a/tests/ad/reduce_by_index0.fut +++ b/tests/ad/reduce_by_index0.fut @@ -1,6 +1,6 @@ -- == -- tags { autodiff } --- entry: f_jvp +-- entry: f_jvp f_jvp_vec -- input { [0i64,1i64,2i64,3i64] [1f64,2f64,3f64,4f64] } -- output { [[1f64,0f64,0f64,0f64],[0f64,1f64,0f64,0f64],[0f64,0f64,1f64,0f64],[0f64,0f64,0f64,1f64]] } def f [n] (is: [n]i64) (vs: [n]f64) = @@ -9,3 +9,8 @@ def f [n] (is: [n]i64) (vs: [n]f64) = entry f_jvp [n] (is: [n]i64) (vs: [n]f64) = tabulate n (\i -> jvp (f is) vs (replicate n 0 with [i] = 1)) |> transpose + +entry f_jvp_vec [n] (is: [n]i64) (vs: [n]f64) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in jvp_vec (f is) vs seeds + |> transpose diff --git a/tests/ad/reducebyindex3.fut b/tests/ad/reducebyindex3.fut index 3cc5c18287..dff88bc4b8 100644 --- a/tests/ad/reducebyindex3.fut +++ b/tests/ad/reducebyindex3.fut @@ -1,6 +1,6 @@ -- == -- tags { autodiff } --- entry: rev +-- entry: rev rev_vec -- input { -- [0i64,1i64,2i64,1i64,0i64,1i64,2i64] -- [1f64,2f64,3f64,4f64,5f64,6f64,7f64] } @@ -18,6 +18,10 @@ entry f [n] (is: [n]i64) (vs: [n]f64) = entry rev [n] (is: [n]i64) (vs: [n]f64) = tabulate 4 (\i -> vjp (f is) vs (replicate 4 0 with [i] = 1)) +entry rev_vec [n] (is: [n]i64) (vs: [n]f64) = + let seeds = tabulate 4 (\i -> replicate 4 0 with [i] = 1) + in vjp_vec (f is) vs seeds + -- entry fwd [n] (is: [n]i64) (vs: [n]f64) = -- tabulate n (\i -> jvp (f is) vs (replicate n 0 with [i] = 1)) -- |> map (.1) |> transpose diff --git a/tests/ad/reducebyindex4.fut b/tests/ad/reducebyindex4.fut index ea7262c188..ba36aac019 100644 --- a/tests/ad/reducebyindex4.fut +++ b/tests/ad/reducebyindex4.fut @@ -1,6 +1,6 @@ -- == -- tags { autodiff } --- entry: rev +-- entry: rev rev_vec -- input { -- [ 0i64, 1i64, 2i64, 1i64, 0i64, 1i64, 2i64, 1i64, 0i64] -- [ 1f32, 2f32, 3f32, 4f32, 5f32, 6f32, 7f32, 8f32, 9f32] @@ -19,6 +19,10 @@ entry rev [n] (is: [n]i64) (vs0: [n]f32) (vs1: [n]f32) = vjp (f is) (zip vs0 vs1) (replicate 4 (1, 1)) |> unzip +entry rev_vec [n] (is: [n]i64) (vs0: [n]f32) (vs1: [n]f32) = + (vjp_vec (f is) (zip vs0 vs1) [replicate 4 (1, 1)])[0] + |> unzip + entry fwd [n] (is: [n]i64) (vs0: [n]f32) (vs1: [n]f32) = tabulate n (\i -> jvp (f is) (zip vs0 vs1) (replicate n (0, 0) with [i] = (1, 1))) |> transpose diff --git a/tests/ad/reducebyindexminmax3.fut b/tests/ad/reducebyindexminmax3.fut index 9ba08969dc..0dfb533ed0 100644 --- a/tests/ad/reducebyindexminmax3.fut +++ b/tests/ad/reducebyindexminmax3.fut @@ -1,6 +1,6 @@ -- == -- tags { autodiff } --- entry: rev +-- entry: rev rev_vec -- input { [5f32,1f32,2f32] [0i64,0i64,0i64] [4f32,3f32,2f32] 3f32 } -- output { [0f32,0f32,0f32] 5f32 } -- input { [5f32,1f32,2f32] [0i64,0i64,0i64] [10f32,3f32,2f32] 3f32 } @@ -14,4 +14,7 @@ def red_max [n] [m] (dst: [m]f32) (is: [n]i64) (vs: [n]f32, c: f32) = entry rev [n] [m] (dst: [m]f32) (is: [n]i64) (vs: [n]f32) (c: f32) = vjp (red_max dst is) (vs, c) (replicate m 0 with [0] = 1) +entry rev_vec [n] [m] (dst: [m]f32) (is: [n]i64) (vs: [n]f32) (c: f32) = + (vjp_vec (red_max dst is) (vs, c) [replicate m 0 with [0] = 1])[0] + --tabulate n (\i -> vjp (red_max dst is) (vs, c) (replicate n 0 with [i] = 1)) diff --git a/tests/ad/reducebyindexminmax4.fut b/tests/ad/reducebyindexminmax4.fut index efd0acb266..3954e2a5cc 100644 --- a/tests/ad/reducebyindexminmax4.fut +++ b/tests/ad/reducebyindexminmax4.fut @@ -1,6 +1,6 @@ -- == -- tags { autodiff } --- entry: rev +-- entry: rev rev_vec -- input { [5f32,1f32,2f32] [0i64,0i64,0i64] [4f32,3f32,2f32] 3f32 } -- output { [3f32,0f32,0f32] 5f32 } -- input { [5f32,1f32,2f32] [0i64,0i64,0i64] [10f32,3f32,2f32] 3f32 } @@ -14,4 +14,7 @@ def red_max [n] [m] (vs: [n]f32) (is: [n]i64) (dst: [m]f32, c: f32) = entry rev [n] [m] (dst: [m]f32) (is: [n]i64) (vs: [n]f32) (c: f32) = vjp (red_max vs is) (dst, c) (replicate m 0 with [0] = 1) +entry rev_vec [n] [m] (dst: [m]f32) (is: [n]i64) (vs: [n]f32) (c: f32) = + (vjp_vec (red_max vs is) (dst, c) [replicate m 0 with [0] = 1])[0] + --tabulate n (\i -> vjp (red_max dst is) (vs, c) (replicate n 0 with [i] = 1)) diff --git a/tests/ad/reducebyindexminmax7.fut b/tests/ad/reducebyindexminmax7.fut index 94fd6b62d1..28495825af 100644 --- a/tests/ad/reducebyindexminmax7.fut +++ b/tests/ad/reducebyindexminmax7.fut @@ -12,8 +12,22 @@ def fwd [n] [m] [k] (is: [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = tabulate n (\i -> jvp (primal is dst) vs (replicate n (replicate k 0) with [i] = replicate k 1)) |> transpose +def rev_vec [n] [m] [k] (is: [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = + let seeds = tabulate m (\i -> replicate m (replicate k 0) with [i] = replicate k 1) + in vjp_vec (primal is dst) vs seeds + +def fwd_vec [n] [m] [k] (is: [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = + let seeds = tabulate n (\i -> replicate n (replicate k 0) with [i] = replicate k 1) + in jvp_vec (primal is dst) vs seeds + |> transpose + def main [n] [m] [k] (is': [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = let is = map (\i -> (i64.abs i) %% m) is' let r = rev is dst vs let f = fwd is dst vs - in map2 (map2 (==)) r f |> map (reduce (&&) true) |> reduce (&&) true + let rv = rev_vec is dst vs + let fv = fwd_vec is dst vs + let eq_rf = map2 (map2 (==)) r f |> map (reduce (&&) true) |> reduce (&&) true + let eq_rrv = map2 (map2 (==)) r rv |> map (reduce (&&) true) |> reduce (&&) true + let eq_ffv = map2 (map2 (==)) f fv |> map (reduce (&&) true) |> reduce (&&) true + in eq_rf && eq_rrv && eq_ffv diff --git a/tests/ad/reducebyindexminmax8.fut b/tests/ad/reducebyindexminmax8.fut index dc4273fd9b..6321dc3295 100644 --- a/tests/ad/reducebyindexminmax8.fut +++ b/tests/ad/reducebyindexminmax8.fut @@ -12,8 +12,22 @@ def fwd2 [n] [m] [k] [l] (is: [n]i64) (dst: [m][k][l]f32) (vs: [n][k][l]f32) = tabulate n (\i -> jvp (primal2 is dst) vs (replicate n (replicate k (replicate l 0)) with [i] = replicate k (replicate l 1))) |> transpose +def rev_vec2 [n] [m] [k] [l] (is: [n]i64) (dst: [m][k][l]f32) (vs: [n][k][l]f32) = + let seeds = tabulate m (\i -> replicate m (replicate k (replicate l 0)) with [i] = replicate k (replicate l 1)) + in vjp_vec (primal2 is dst) vs seeds + +def fwd_vec2 [n] [m] [k] [l] (is: [n]i64) (dst: [m][k][l]f32) (vs: [n][k][l]f32) = + let seeds = tabulate n (\i -> replicate n (replicate k (replicate l 0)) with [i] = replicate k (replicate l 1)) + in jvp_vec (primal2 is dst) vs seeds + |> transpose + def main [n] [m] [k] [l] (is': [n]i64) (dst: [m][k][l]f32) (vs: [n][k][l]f32) = let is = map (\i -> (i64.abs i) %% m) is' let r = rev2 is dst vs let f = fwd2 is dst vs - in map2 (map2 (map2 (==))) r f |> map (map (reduce (&&) true)) |> map (reduce (&&) true) |> reduce (&&) true + let rv = rev_vec2 is dst vs + let fv = fwd_vec2 is dst vs + let eq_rf = map2 (map2 (map2 (==))) r f |> map (map (reduce (&&) true)) |> map (reduce (&&) true) |> reduce (&&) true + let eq_rrv = map2 (map2 (map2 (==))) r rv |> map (map (reduce (&&) true)) |> map (reduce (&&) true) |> reduce (&&) true + let eq_ffv = map2 (map2 (map2 (==))) f fv |> map (map (reduce (&&) true)) |> map (reduce (&&) true) |> reduce (&&) true + in eq_rf && eq_rrv && eq_ffv diff --git a/tests/ad/reducemul0.fut b/tests/ad/reducemul0.fut index 406554a5a1..09dd2f5743 100644 --- a/tests/ad/reducemul0.fut +++ b/tests/ad/reducemul0.fut @@ -1,6 +1,6 @@ -- == -- tags { autodiff } --- entry: rev fwd +-- entry: rev fwd fwd_vec -- input { [0.0f32, 2.0f32, 0.0f32, 4.0f32] } output { [0.0f32, 0.0f32, 0.0f32, 0.0f32] } def red_mult [n] (xs: [n]f32) : f32 = @@ -11,3 +11,7 @@ entry rev [n] (xs: [n]f32) = entry fwd [n] (xs: [n]f32) = tabulate n (\i -> jvp red_mult xs (replicate n 0 with [i] = 1)) + +entry fwd_vec [n] (xs: [n]f32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in jvp_vec red_mult xs seeds diff --git a/tests/ad/reducemul4.fut b/tests/ad/reducemul4.fut index db067c6f45..4408becbc9 100644 --- a/tests/ad/reducemul4.fut +++ b/tests/ad/reducemul4.fut @@ -1,6 +1,6 @@ -- == -- tags { autodiff } --- entry: fwd rev +-- entry: fwd rev fwd_vec rev_vec -- input { [1f32, 2f32, 3f32, 4f32] } output { [[48f32, 12f32, 8f32, 6f32], [48f32, 48f32, 16f32, 12f32], [72f32, 36f32, 48f32, 18f32], [96f32, 48f32, 32f32, 48f32]] } def fun [n] (as: [n]f32) = @@ -13,3 +13,11 @@ entry fwd [n] (as: [n]f32) = entry rev [n] (as: [n]f32) = tabulate n (\i -> vjp fun as (replicate n 0 with [i] = 1)) + +entry fwd_vec [n] (as: [n]f32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in jvp_vec fun as seeds |> transpose + +entry rev_vec [n] (as: [n]f32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in vjp_vec fun as seeds diff --git a/tests/ad/reducevec0.fut b/tests/ad/reducevec0.fut index d781fb4cd4..fc2ad0c0f5 100644 --- a/tests/ad/reducevec0.fut +++ b/tests/ad/reducevec0.fut @@ -1,6 +1,6 @@ -- == -- tags { autodiff } --- entry: rev fwd +-- entry: rev fwd fwd_vec rev_vec -- input { -- [[[0f32,1f32],[2f32,3f32]], -- [[5f32,1f32],[3f32,0f32]], @@ -17,3 +17,21 @@ entry fwd [n] [m] [k] (xs: [n][m][k]f32) : [m][k][n][m][k]f32 = tabulate_3d n m k (\i j l -> jvp f xs (replicate n (replicate m (replicate k 0)) with [i] = (replicate m (replicate k 0) with [j] = (replicate k 0 with [l] = 1)))) |> transpose |> map transpose + +entry fwd_vec [n] [m] [k] (xs: [n][m][k]f32) : [m][k][n][m][k]f32 = + let seeds = tabulate (n * m * k) (\p -> + let i = p / (m * k) + let j = (p % (m * k)) / k + let l = p % k + in replicate n (replicate m (replicate k 0)) with [i] = (replicate m (replicate k 0) with [j] = (replicate k 0 with [l] = 1))) + let res = jvp_vec f xs seeds + in unflatten (sized (n * (m * k)) res) |> map unflatten + |> transpose + |> map transpose + +entry rev_vec [n] [m] [k] (xs: [n][m][k]f32) : [m][k][n][m][k]f32 = + let seeds = tabulate (m * k) (\p -> + let i = p / k + let j = p % k + in replicate m (replicate k 0) with [i] = (replicate k 0 with [j] = 1)) + in unflatten (vjp_vec f xs seeds) diff --git a/tests/ad/scan0.fut b/tests/ad/scan0.fut index b044a5c8d4..404687a08e 100644 --- a/tests/ad/scan0.fut +++ b/tests/ad/scan0.fut @@ -2,7 +2,7 @@ -- generic case -- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_J rev_J fwd_vec_J rev_vec_J -- input { [1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32] } -- output { [[1.0f32, 0.0f32, 0.0f32, 0.0f32, 0.0f32], -- [2.0f32, 1.0f32, 0.0f32, 0.0f32, 0.0f32], @@ -17,3 +17,11 @@ entry fwd_J [n] (a: [n]f32) = entry rev_J [n] (a: [n]f32) = tabulate n (\i -> vjp (scan (*) 1) a (replicate n 0 with [i] = 1)) + +entry fwd_vec_J [n] (a: [n]f32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in jvp_vec (scan (*) 1) a seeds |> transpose + +entry rev_vec_J [n] (a: [n]f32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in vjp_vec (scan (*) 1) a seeds diff --git a/tests/ad/scan1.fut b/tests/ad/scan1.fut index 00138c2eea..73ac81b592 100644 --- a/tests/ad/scan1.fut +++ b/tests/ad/scan1.fut @@ -2,7 +2,7 @@ -- addition special case -- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_J rev_J fwd_vec_J rev_vec_J -- input { [1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32] } -- output { [[1.0f32, 0.0f32, 0.0f32, 0.0f32, 0.0f32], -- [1.0f32, 1.0f32, 0.0f32, 0.0f32, 0.0f32], @@ -17,3 +17,11 @@ entry fwd_J [n] (a: [n]f32) = entry rev_J [n] (a: [n]f32) = tabulate n (\i -> vjp (scan (+) 0) a (replicate n 0 with [i] = 1)) + +entry fwd_vec_J [n] (a: [n]f32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in jvp_vec (scan (+) 0) a seeds |> transpose + +entry rev_vec_J [n] (a: [n]f32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in vjp_vec (scan (+) 0) a seeds diff --git a/tests/ad/scan2.fut b/tests/ad/scan2.fut index b27466c516..1ce55e3ec9 100644 --- a/tests/ad/scan2.fut +++ b/tests/ad/scan2.fut @@ -2,7 +2,7 @@ -- special cases: vectorised and addition -- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_J rev_J fwd_vec_J rev_vec_J -- input { [[1f32,1f32],[2f32,2f32],[3f32,3f32],[4f32,4f32],[5f32,5f32]] } -- output { [[[1.000000f32, 1.000000f32], [0.000000f32, 0.000000f32], [0.000000f32, 0.000000f32], [0.000000f32, 0.000000f32], [0.000000f32, 0.000000f32]], [[1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [0.000000f32, 0.000000f32], [0.000000f32, 0.000000f32], [0.000000f32, 0.000000f32]], [[1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [0.000000f32, 0.000000f32], [0.000000f32, 0.000000f32]], [[1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [0.000000f32, 0.000000f32]], [[1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32], [1.000000f32, 1.000000f32]]] } @@ -15,3 +15,11 @@ entry fwd_J [n] [k] (a: [n][k]f32) = entry rev_J [n] [k] (a: [n][k]f32) = tabulate n (\i -> vjp primal a (replicate n (replicate k 0) with [i] = replicate k 1)) + +entry fwd_vec_J [n] [k] (a: [n][k]f32) = + let seeds = tabulate n (\i -> replicate n (replicate k 0) with [i] = replicate k 1) + in jvp_vec primal a seeds |> transpose + +entry rev_vec_J [n] [k] (a: [n][k]f32) = + let seeds = tabulate n (\i -> replicate n (replicate k 0) with [i] = replicate k 1) + in vjp_vec primal a seeds diff --git a/tests/ad/scan3.fut b/tests/ad/scan3.fut index 0de395e9d6..ac07c3595c 100644 --- a/tests/ad/scan3.fut +++ b/tests/ad/scan3.fut @@ -2,7 +2,7 @@ -- MatrixMul case -- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_J rev_J fwd_vec_J rev_vec_J -- input { [[1f32,2f32,3f32,4f32], [4f32,3f32,2f32,1f32], [1f32,2f32,3f32,4f32], [4f32,3f32,2f32,1f32]] } -- output { -- [[[[1f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32]], @@ -53,3 +53,19 @@ entry rev_J [n] (input: [n][4]f32) : [n][4][n][4]f32 = in tabulate (n * 4) (\i -> vjp primal input (fromarrs (onehot_2d n 4 (i / 4) (i % 4)))) |> unflatten |> map (map toarrs) + +entry fwd_vec_J [n] (input: [n][4]f32) : [n][4][n][4]f32 = + let input = fromarrs input + let seeds = tabulate (n * 4) (\i -> fromarrs (onehot_2d n 4 (i / 4) (i % 4))) + in jvp_vec primal input seeds + |> map toarrs + |> transpose + |> map transpose + |> map (map unflatten) + +entry rev_vec_J [n] (input: [n][4]f32) : [n][4][n][4]f32 = + let input = fromarrs input + let seeds = tabulate (n * 4) (\i -> fromarrs (onehot_2d n 4 (i / 4) (i % 4))) + in vjp_vec primal input seeds + |> unflatten + |> map (map toarrs) diff --git a/tests/ad/scan4.fut b/tests/ad/scan4.fut index 71343e2d7a..8d95b8440c 100644 --- a/tests/ad/scan4.fut +++ b/tests/ad/scan4.fut @@ -2,7 +2,7 @@ -- ZeroQuadrant case -- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_J rev_J fwd_vec_J rev_vec_J -- input { [[1.0f32, 2.0f32, 3.0f32], [4.0f32, 3.0f32, 5.0f32], [3.0f32, 4.0f32, 2.0f32], [4.0f32, 2.0f32, 1.0f32]] } -- output { -- [[[1f32, 1f32, 1f32], [0f32, 0f32, 0f32], @@ -31,3 +31,16 @@ entry rev_J [n] (input: [n][3]f32) = let input = fromarrs input in tabulate n (\i -> vjp primal input (replicate n (0, 0, 0) with [i] = (1, 1, 1))) |> map toarrs + +entry fwd_vec_J [n] (input: [n][3]f32) = + let input = fromarrs input + let seeds = tabulate n (\i -> replicate n (0, 0, 0) with [i] = (1, 1, 1)) + in jvp_vec primal input seeds + |> map toarrs + |> transpose + +entry rev_vec_J [n] (input: [n][3]f32) = + let input = fromarrs input + let seeds = tabulate n (\i -> replicate n (0, 0, 0) with [i] = (1, 1, 1)) + in vjp_vec primal input seeds + |> map toarrs diff --git a/tests/ad/scan5.fut b/tests/ad/scan5.fut index 6b55cd7b13..05295e13b5 100644 --- a/tests/ad/scan5.fut +++ b/tests/ad/scan5.fut @@ -2,7 +2,7 @@ -- Vectorised special case + generic case -- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_J rev_J fwd_vec_J rev_vec_J -- input { [[1f32,1f32],[2f32,2f32],[3f32,3f32],[4f32,4f32],[5f32,5f32]] } -- output { -- [[[1f32, 1f32], [0f32, 0f32], [0f32, 0f32], [0f32, 0f32], [0f32, 0f32]], @@ -21,3 +21,11 @@ entry fwd_J [n] [k] (a: [n][k]f32) = entry rev_J [n] [k] (a: [n][k]f32) = tabulate n (\i -> vjp primal a (replicate n (replicate k 0) with [i] = replicate k 1)) + +entry fwd_vec_J [n] [k] (a: [n][k]f32) = + let seeds = tabulate n (\i -> replicate n (replicate k 0) with [i] = replicate k 1) + in jvp_vec primal a seeds |> transpose + +entry rev_vec_J [n] [k] (a: [n][k]f32) = + let seeds = tabulate n (\i -> replicate n (replicate k 0) with [i] = replicate k 1) + in vjp_vec primal a seeds diff --git a/tests/ad/scan6.fut b/tests/ad/scan6.fut index ead433518a..33be743c1a 100644 --- a/tests/ad/scan6.fut +++ b/tests/ad/scan6.fut @@ -2,7 +2,7 @@ -- MatrixMul case -- == -- tags { autodiff } --- entry: fwd_J rev_J +-- entry: fwd_J rev_J fwd_vec_J rev_vec_J -- input { [[1f32, 2f32], [4f32, 3f32], [3f32, 4f32], [4f32, 2f32]] } -- output { -- [[[[1f32, 0f32], [0f32, 0f32], [0f32, 0f32], [0f32, 0f32]], @@ -39,7 +39,7 @@ entry rev_J [n] (input: [n][2]f32) = |> map (map toarrs) -- == --- entry: fwd_J2 rev_J2 +-- entry: fwd_J2 rev_J2 fwd_vec_J2 rev_vec_J2 -- no_oclgrind input { [[1f32,2f32,3f32,4f32,5f32,6f32],[6f32,5f32,4f32,3f32,2f32,1f32],[4f32,5f32,6f32,1f32,2f32,3f32],[3f32,2f32,1f32,6f32,5f32,4f32]] } -- output { [[[[1f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 1f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 1f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 1f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 1f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 0f32, 0f32, 1f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]]], [[[4f32, 3f32, 0f32, 0f32, 0f32, 0f32], [1f32, 0f32, 1f32, 2f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[2f32, 1f32, 0f32, 0f32, 0f32, 0f32], [0f32, 1f32, 0f32, 0f32, 1f32, 2f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 4f32, 0f32, 3f32, 0f32], [0f32, 0f32, 3f32, 5f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 4f32, 0f32, 3f32], [0f32, 0f32, 4f32, 6f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 2f32, 0f32, 1f32, 0f32], [0f32, 0f32, 0f32, 0f32, 3f32, 5f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 2f32, 0f32, 1f32], [0f32, 0f32, 0f32, 0f32, 4f32, 6f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]]], [[[26f32, 19f32, 0f32, 0f32, 0f32, 0f32], [6f32, 1f32, 6f32, 12f32, 1f32, 2f32], [1f32, 0f32, 16f32, 9f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[14f32, 9f32, 0f32, 0f32, 0f32, 0f32], [2f32, 3f32, 2f32, 4f32, 3f32, 6f32], [0f32, 1f32, 0f32, 0f32, 16f32, 9f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 26f32, 0f32, 19f32, 0f32], [0f32, 0f32, 18f32, 30f32, 3f32, 5f32], [0f32, 0f32, 27f32, 11f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 26f32, 0f32, 19f32], [0f32, 0f32, 24f32, 36f32, 4f32, 6f32], [0f32, 0f32, 34f32, 14f32, 0f32, 0f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 14f32, 0f32, 9f32, 0f32], [0f32, 0f32, 6f32, 10f32, 9f32, 15f32], [0f32, 0f32, 0f32, 0f32, 27f32, 11f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 14f32, 0f32, 9f32], [0f32, 0f32, 8f32, 12f32, 12f32, 18f32], [0f32, 0f32, 0f32, 0f32, 34f32, 14f32], [0f32, 0f32, 0f32, 0f32, 0f32, 0f32]]], [[[110f32, 73f32, 0f32, 0f32, 0f32, 0f32], [18f32, 19f32, 18f32, 36f32, 19f32, 38f32], [1f32, 6f32, 16f32, 9f32, 96f32, 54f32], [1f32, 0f32, 109f32, 64f32, 0f32, 0f32]], [[186f32, 131f32, 0f32, 0f32, 0f32, 0f32], [38f32, 17f32, 38f32, 76f32, 17f32, 34f32], [5f32, 4f32, 80f32, 45f32, 64f32, 36f32], [0f32, 1f32, 0f32, 0f32, 109f32, 64f32]], [[0f32, 0f32, 110f32, 0f32, 73f32, 0f32], [0f32, 0f32, 54f32, 90f32, 57f32, 95f32], [0f32, 0f32, 27f32, 11f32, 162f32, 66f32], [0f32, 0f32, 173f32, 87f32, 0f32, 0f32]], [[0f32, 0f32, 0f32, 110f32, 0f32, 73f32], [0f32, 0f32, 72f32, 108f32, 76f32, 114f32], [0f32, 0f32, 34f32, 14f32, 204f32, 84f32], [0f32, 0f32, 218f32, 110f32, 0f32, 0f32]], [[0f32, 0f32, 186f32, 0f32, 131f32, 0f32], [0f32, 0f32, 114f32, 190f32, 51f32, 85f32], [0f32, 0f32, 135f32, 55f32, 108f32, 44f32], [0f32, 0f32, 0f32, 0f32, 173f32, 87f32]], [[0f32, 0f32, 0f32, 186f32, 0f32, 131f32], [0f32, 0f32, 152f32, 228f32, 68f32, 102f32], [0f32, 0f32, 170f32, 70f32, 136f32, 56f32], [0f32, 0f32, 0f32, 0f32, 218f32, 110f32]]]] } def mm2by2 (a1, b1, c1, d1) @@ -82,3 +82,35 @@ entry rev_J2 [n] (input: [n][6]f32) : [n][6][n][6]f32 = in tabulate (n * 6) (\i -> vjp primal2 input (fromarrs2 (onehot_2d n 6 (i / 6) (i % 6)))) |> unflatten |> map (map toarrs2) + +entry fwd_vec_J [n] (input: [n][2]f32) = + let input = fromarrs input + let seeds = tabulate (n * 2) (\i -> fromarrs (onehot_2d n 2 (i / 2) (i % 2))) + in jvp_vec primal input seeds + |> map toarrs + |> transpose + |> map transpose + |> map (map unflatten) + +entry rev_vec_J [n] (input: [n][2]f32) = + let input = fromarrs input + let seeds = tabulate (n * 2) (\i -> fromarrs (onehot_2d n 2 (i / 2) (i % 2))) + in vjp_vec primal input seeds + |> unflatten + |> map (map toarrs) + +entry fwd_vec_J2 [n] (input: [n][6]f32) : [n][6][n][6]f32 = + let input = fromarrs2 input + let seeds = tabulate (n * 6) (\i -> fromarrs2 (onehot_2d n 6 (i / 6) (i % 6))) + in jvp_vec primal2 input seeds + |> map toarrs2 + |> transpose + |> map transpose + |> map (map unflatten) + +entry rev_vec_J2 [n] (input: [n][6]f32) : [n][6][n][6]f32 = + let input = fromarrs2 input + let seeds = tabulate (n * 6) (\i -> fromarrs2 (onehot_2d n 6 (i / 6) (i % 6))) + in vjp_vec primal2 input seeds + |> unflatten + |> map (map toarrs2) diff --git a/tests/ad/scan7.fut b/tests/ad/scan7.fut index 00efbe4f9e..6b44ed1f14 100644 --- a/tests/ad/scan7.fut +++ b/tests/ad/scan7.fut @@ -5,7 +5,7 @@ -- tags { autodiff } -- == --- entry: fwd_J rev_J +-- entry: fwd_J rev_J fwd_vec_J rev_vec_J -- input { [[[1f32,2f32], [2f32,3f32]], [[4f32,5f32], [3f32,4f32]], -- [[3f32,4f32], [4f32,5f32]], [[4f32,5f32], [2f32,3f32]]] } -- output { @@ -107,3 +107,24 @@ entry test [n] [m] [k] (input: [n][m][k]f32) bar = let a = res |> map (map transpose) |> map (map (map transpose)) |> map (map (map (map transpose))) let a2 = a |> map transpose |> map (map transpose) |> map (map (map transpose)) in a2 |> transpose |> map transpose |> (map (map transpose)) + +entry fwd_vec_J [n] [m] [k] (input: [n][m][k]f32) = + let seeds = tabulate (n * m * k) (\p -> + let i = p / (m * k) + let j = (p % (m * k)) / k + let q = p % k + in replicate n (replicate m (replicate k 0)) with [i] = (replicate m (replicate k 0) with [j] = (replicate k 0 with [q] = 1))) + let res = jvp_vec primal input seeds + in unflatten (sized (n * (m * k)) res) |> map unflatten + +entry rev_vec_J [n] [m] [k] (input: [n][m][k]f32) = + let seeds = tabulate (n * m * k) (\p -> + let i = p / (m * k) + let j = (p % (m * k)) / k + let q = p % k + in replicate n (replicate m (replicate k 0)) with [i] = (replicate m (replicate k 0) with [j] = (replicate k 0 with [q] = 1))) + let res = vjp_vec primal input seeds + |> (\x -> unflatten (sized (n * (m * k)) x)) |> map unflatten + let a = res |> map (map transpose) |> map (map (map transpose)) |> map (map (map (map transpose))) + let a2 = a |> map transpose |> map (map transpose) |> map (map (map transpose)) + in a2 |> transpose |> map transpose |> (map (map transpose)) diff --git a/tests/ad/scan8.fut b/tests/ad/scan8.fut index 127b564914..806a2bdf92 100644 --- a/tests/ad/scan8.fut +++ b/tests/ad/scan8.fut @@ -1,7 +1,7 @@ -- Scan with 3x3 matrix multiplication. -- == -- tags { autodiff } --- entry: fwd rev +-- entry: fwd rev fwd_vec rev_vec -- input { [[1f32,2f32,3f32,4f32,5f32,6f32,7f32,8f32,9f32], -- [9f32,8f32,7f32,6f32,5f32,4f32,3f32,2f32,1f32], -- [1f32,2f32,3f32,4f32,5f32,6f32,7f32,8f32,9f32], @@ -176,3 +176,19 @@ entry rev [n] (input: [n][9]f32) : [n][9][n][9]f32 = in tabulate (n * 9) (\i -> vjp primal3 input (fromarrs3 (onehot_2d n 9 (i / 9) (i % 9)))) |> unflatten |> map (map toarrs3) + +entry fwd_vec [n] (input: [n][9]f32) : [n][9][n][9]f32 = + let input = fromarrs3 input + let seeds = tabulate (n * 9) (\i -> fromarrs3 (onehot_2d n 9 (i / 9) (i % 9))) + in jvp_vec primal3 input seeds + |> map toarrs3 + |> transpose + |> map transpose + |> map (map unflatten) + +entry rev_vec [n] (input: [n][9]f32) : [n][9][n][9]f32 = + let input = fromarrs3 input + let seeds = tabulate (n * 9) (\i -> fromarrs3 (onehot_2d n 9 (i / 9) (i % 9))) + in vjp_vec primal3 input seeds + |> unflatten + |> map (map toarrs3) diff --git a/tests/ad/scan9.fut b/tests/ad/scan9.fut index cb1ebfb671..19de6afd69 100644 --- a/tests/ad/scan9.fut +++ b/tests/ad/scan9.fut @@ -1,7 +1,7 @@ -- Scan with 4x4 matrix multiplication. -- == -- tags { autodiff } --- entry: fwd rev +-- entry: fwd rev fwd_vec rev_vec -- no_oclgrind input { -- [[1f32,2f32,3f32,4f32,5f32,6f32,7f32,8f32,9f32,10f32,11f32,12f32,13f32,14f32,15f32,16f32], -- [16f32,15f32,14f32,13f32,12f32,11f32,10f32,9f32,8f32,7f32,6f32,5f32,4f32,3f32,2f32,1f32], @@ -474,3 +474,19 @@ entry rev [n] (input: [n][16]f32) : [n][16][n][16]f32 = in tabulate (n * 16) (\i -> vjp primal2 input (fromarrs2 (onehot_2d n 16 (i / 16) (i % 16)))) |> unflatten |> map (map toarrs2) + +entry fwd_vec [n] (input: [n][16]f32) : [n][16][n][16]f32 = + let input = fromarrs2 input + let seeds = tabulate (n * 16) (\i -> fromarrs2 (onehot_2d n 16 (i / 16) (i % 16))) + in jvp_vec primal2 input seeds + |> map toarrs2 + |> transpose + |> map transpose + |> map (map unflatten) + +entry rev_vec [n] (input: [n][16]f32) : [n][16][n][16]f32 = + let input = fromarrs2 input + let seeds = tabulate (n * 16) (\i -> fromarrs2 (onehot_2d n 16 (i / 16) (i % 16))) + in vjp_vec primal2 input seeds + |> unflatten + |> map (map toarrs2) diff --git a/tests/ad/scatter0.fut b/tests/ad/scatter0.fut index 78eff344e3..e968b3dd63 100644 --- a/tests/ad/scatter0.fut +++ b/tests/ad/scatter0.fut @@ -1,7 +1,7 @@ -- Simple scatter, differentiating wrt. values. -- == -- tags { autodiff } --- entry: fwd rev +-- entry: fwd rev fwd_vec rev_vec -- input { [0f64, 0f64, 0f64, 0f64] [0i64, 1i64, 2i64, 3i64] [1f64, 2f64, 3f64, 0f64] } -- output { -- [[1.000000f64, 0.000000f64, 0.000000f64, 0.000000f64], @@ -27,3 +27,11 @@ entry fwd [n] [k] (xs: [k]f64) (is: [n]i64) (vs: [n]f64) = entry rev [n] [k] (xs: [k]f64) (is: [n]i64) (vs: [n]f64) = let g i = vjp (\vs -> f xs is vs) vs (replicate k 0 with [i] = 1) in tabulate n g + +entry fwd_vec [n] [k] (xs: [k]f64) (is: [n]i64) (vs: [n]f64) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in jvp_vec (\vs -> f xs is vs) vs seeds + +entry rev_vec [n] [k] (xs: [k]f64) (is: [n]i64) (vs: [n]f64) = + let seeds = tabulate n (\i -> replicate k 0 with [i] = 1) + in vjp_vec (\vs -> f xs is vs) vs seeds diff --git a/tests/ad/scatter1.fut b/tests/ad/scatter1.fut index 0e41101762..8217639cfc 100644 --- a/tests/ad/scatter1.fut +++ b/tests/ad/scatter1.fut @@ -1,7 +1,7 @@ -- Simple scatter, differentiating wrt. target. -- == -- tags { autodiff } --- entry: fwd rev +-- entry: fwd rev fwd_vec rev_vec -- input { [0f64, 0f64, 0f64, 0f64] [0i64, 1i64] [1f64, 2f64] } -- output { -- [[0.000000f64, 0.000000f64, 0.000000f64, 0.000000f64], @@ -20,3 +20,11 @@ entry fwd [n] [k] (xs: [k]f64) (is: [n]i64) (vs: [n]f64) = entry rev [n] [k] (xs: [k]f64) (is: [n]i64) (vs: [n]f64) = let g i = vjp (\xs -> f xs is vs) xs (replicate k 0 with [i] = 1) in tabulate k g + +entry fwd_vec [n] [k] (xs: [k]f64) (is: [n]i64) (vs: [n]f64) = + let seeds = tabulate k (\i -> replicate k 0 with [i] = 1) + in jvp_vec (\xs -> f xs is vs) xs seeds + +entry rev_vec [n] [k] (xs: [k]f64) (is: [n]i64) (vs: [n]f64) = + let seeds = tabulate k (\i -> replicate k 0 with [i] = 1) + in vjp_vec (\xs -> f xs is vs) xs seeds diff --git a/tests/ad/stripmine1.fut b/tests/ad/stripmine1.fut index 82e8280234..31756734d4 100644 --- a/tests/ad/stripmine1.fut +++ b/tests/ad/stripmine1.fut @@ -15,7 +15,7 @@ def square [n] (xs: [n]i32) = entry prim [n] (xs: [n]i32) = square xs -- == --- entry: f_jvp f_vjp +-- entry: f_jvp f_vjp f_jvp_vec f_vjp_vec -- input { [1,2,3,4,5] } -- output { [[2,0,0,0,0], -- [0,4,0,0,0], @@ -28,3 +28,12 @@ entry f_jvp [n] (xs: [n]i32) = entry f_vjp [n] (xs: [n]i32) = tabulate n (\i -> vjp square xs (replicate n 0 with [i] = 1)) + +entry f_jvp_vec [n] (xs: [n]i32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in jvp_vec square xs seeds + |> transpose + +entry f_vjp_vec [n] (xs: [n]i32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in vjp_vec square xs seeds diff --git a/tests/ad/stripmine2.fut b/tests/ad/stripmine2.fut index d32cff65e0..b6516a075b 100644 --- a/tests/ad/stripmine2.fut +++ b/tests/ad/stripmine2.fut @@ -13,7 +13,7 @@ def pow_list [n] y (xs: [n]i32) = entry prim y xs = pow_list y xs -- == --- entry: f_vjp f_jvp +-- entry: f_vjp f_jvp f_vjp_vec f_jvp_vec -- input { 3 [1,2,3] } -- output { [[3,0,0], -- [0,12,0], @@ -24,3 +24,12 @@ entry f_jvp [n] y (xs: [n]i32) = entry f_vjp [n] y (xs: [n]i32) = tabulate n (\i -> vjp (pow_list y) xs (replicate n 0 with [i] = 1)) + +entry f_jvp_vec [n] y (xs: [n]i32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in jvp_vec (pow_list y) xs seeds + |> transpose + +entry f_vjp_vec [n] y (xs: [n]i32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in vjp_vec (pow_list y) xs seeds diff --git a/tests/ad/sum.fut b/tests/ad/sum.fut index 241f9b65fa..8477f66dad 100644 --- a/tests/ad/sum.fut +++ b/tests/ad/sum.fut @@ -1,7 +1,7 @@ -- Simple reduce with summation. -- == -- tags { autodiff } --- entry: rev fwd +-- entry: rev fwd fwd_vec -- input { [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] } -- output { [1.0, 1.0, 1.0, 1.0, 1.0, 1.0] } @@ -13,3 +13,7 @@ entry rev [n] (xs: [n]f64) = entry fwd [n] (xs: [n]f64) = tabulate n (\i -> jvp sum xs (tabulate n ((== i) >-> f64.bool))) + +entry fwd_vec [n] (xs: [n]f64) = + let seeds = tabulate n (\i -> tabulate n ((== i) >-> f64.bool)) + in jvp_vec sum xs seeds diff --git a/tests/ad/truedep0.fut b/tests/ad/truedep0.fut index 518091ed19..813cf66903 100644 --- a/tests/ad/truedep0.fut +++ b/tests/ad/truedep0.fut @@ -12,7 +12,7 @@ def test [n] (xs: [n]i32) = entry prim [n] (xs: [n]i32) = test xs -- == --- entry: f_jvp f_vjp +-- entry: f_jvp f_vjp f_jvp_vec f_vjp_vec -- input { [1,2,3,4,5] } -- output { [[1,0,0,0,0], -- [2,0,0,0,0], @@ -25,3 +25,12 @@ entry f_jvp [n] (xs: [n]i32) = entry f_vjp [n] (xs: [n]i32) = tabulate n (\i -> vjp test xs (replicate n 0 with [i] = 1)) + +entry f_jvp_vec [n] (xs: [n]i32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in jvp_vec test xs seeds + |> transpose + +entry f_vjp_vec [n] (xs: [n]i32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in vjp_vec test xs seeds From 4892703f3b40710423bb8493b83d1f2d98e36604 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 29 May 2026 13:27:26 +0200 Subject: [PATCH 32/70] Need to transpose here. --- src/Futhark/AD/Rev/Reduce.hs | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/Futhark/AD/Rev/Reduce.hs b/src/Futhark/AD/Rev/Reduce.hs index 15b0fadbdf..b6804f4e5a 100644 --- a/src/Futhark/AD/Rev/Reduce.hs +++ b/src/Futhark/AD/Rev/Reduce.hs @@ -347,9 +347,17 @@ diffMulReduce _ops x aux w mul ne as m = do as_adjup <- letExp "prod_contrib" . Op . Screma w [as] =<< mapSOAC map_lam_rev - updateAdj as as_adjup + updateAdj as =<< transposeIfNeeded as_adjup where getDiv :: PrimType -> BinOp getDiv (IntType t) = SDiv t Unsafe getDiv (FloatType t) = FDiv t getDiv _ = error "In getDiv, Reduce.hs: input not supported" + + transposeIfNeeded v = do + adj_shape <- askShape + if adj_shape == mempty + then pure v + else do + v_t <- lookupType v + letExp (baseName v <> "_tr") $ BasicOp $ Rearrange v (auxPerm adj_shape v_t) From 5410a1ad7e6def6f6be72e1fc0beb1210c745782 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 29 May 2026 13:32:42 +0200 Subject: [PATCH 33/70] Share code. --- src/Futhark/AD/Rev/Reduce.hs | 36 +++++++++++------------------------- 1 file changed, 11 insertions(+), 25 deletions(-) diff --git a/src/Futhark/AD/Rev/Reduce.hs b/src/Futhark/AD/Rev/Reduce.hs index b6804f4e5a..e896936355 100644 --- a/src/Futhark/AD/Rev/Reduce.hs +++ b/src/Futhark/AD/Rev/Reduce.hs @@ -75,6 +75,17 @@ mkF lam = do bodyBind $ lambdaBody lam_r pure (map paramName aps, lam') +-- | If we are doing vectorised AD, then transpose the array to bring the vector +-- shape outermost. +transposeIfNeeded :: VName -> ADM VName +transposeIfNeeded v = do + adj_shape <- askShape + if adj_shape == mempty + then pure v + else do + v_t <- lookupType v + letExp (baseName v <> "_tr") $ BasicOp $ Rearrange v (auxPerm adj_shape v_t) + diffReduce :: VjpOps -> [VName] -> SubExp -> [VName] -> Reduce SOACS -> ADM () diffReduce _ops [adj] w [a] red | Just [(op, _, _, _)] <- lamIsBinOp $ redLambda red, @@ -84,15 +95,6 @@ diffReduce _ops [adj] w [a] red BasicOp (Replicate (Shape [w]) (Var adj)) void $ updateAdj a adj_rep where - transposeIfNeeded v = do - adj_shape <- askShape - if adj_shape == mempty - then pure v - else do - v_t <- lookupType v - let perm = [1 .. shapeRank adj_shape] ++ [0] ++ [shapeRank adj_shape + 1 .. arrayRank v_t - 1] - letExp (baseName v <> "_tr") $ BasicOp $ Rearrange v perm - isAdd FAdd {} = True isAdd Add {} = True isAdd _ = False @@ -131,14 +133,6 @@ diffReduce ops pat_adj w as red = do zipWithM_ updateAdj as =<< mapM transposeIfNeeded as_adj where - transposeIfNeeded v = do - adj_shape <- askShape - if adj_shape == mempty - then pure v - else do - v_t <- lookupType v - letExp (baseName v <> "_tr") $ BasicOp $ Rearrange v (auxPerm adj_shape v_t) - renameRed (Reduce comm lam nes) = Reduce comm <$> renameLambda lam <*> pure nes @@ -353,11 +347,3 @@ diffMulReduce _ops x aux w mul ne as m = do getDiv (IntType t) = SDiv t Unsafe getDiv (FloatType t) = FDiv t getDiv _ = error "In getDiv, Reduce.hs: input not supported" - - transposeIfNeeded v = do - adj_shape <- askShape - if adj_shape == mempty - then pure v - else do - v_t <- lookupType v - letExp (baseName v <> "_tr") $ BasicOp $ Rearrange v (auxPerm adj_shape v_t) From 76be3e379ed086ba6a7b9ed44a1b095bd6101224 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 29 May 2026 13:47:06 +0200 Subject: [PATCH 34/70] Fix typo. --- src/Futhark/AD/Rev/Loop.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Futhark/AD/Rev/Loop.hs b/src/Futhark/AD/Rev/Loop.hs index fad1fa8004..a5b372b338 100644 --- a/src/Futhark/AD/Rev/Loop.hs +++ b/src/Futhark/AD/Rev/Loop.hs @@ -267,7 +267,7 @@ reverseIndices loop = do pure (M.singleton i i_rev, i_stms) --- | Pures a substitution which substitutes values in the reverse +-- | Returns a substitution which substitutes values in the reverse -- loop body with values from the tape. restore :: Stms SOACS -> [Param DeclType] -> VName -> ADM Substitutions restore stms_adj loop_params' i' = From 3e3aecb6e7a0afb36af80496078e40c516fb220a Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 29 May 2026 15:04:46 +0200 Subject: [PATCH 35/70] Fix some more things. --- src/Futhark/AD/Rev/Monad.hs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/Futhark/AD/Rev/Monad.hs b/src/Futhark/AD/Rev/Monad.hs index c41ab8f7c1..d26e7ebbb6 100644 --- a/src/Futhark/AD/Rev/Monad.hs +++ b/src/Futhark/AD/Rev/Monad.hs @@ -438,11 +438,11 @@ updateAdjIndex v (check, i) se = do v_adj <- adjVal adj v_adj_t <- lookupType v_adj se_v <- letExp "se_v" $ BasicOp $ SubExp se + vec_shape <- askShape insAdj v =<< case v_adj_t of Acc {} -> do let stms s = do - vec_shape <- askShape attrs <- asks envAttrs dims <- arrayDims <$> lookupType se_v ~[v_adj'] <- @@ -458,13 +458,15 @@ updateAdjIndex v (check, i) se = do OutOfBounds -> pure v_adj _ -> do let stms s = do + let slice = + fullSlice v_adj_t $ + map sliceDim (shapeDims vec_shape) ++ [DimFix i] v_adj_i <- letExp (baseName v_adj <> "_i") . BasicOp $ - Index v_adj $ - fullSlice v_adj_t [DimFix i] + Index v_adj slice se_update <- letSubExp "updated_adj_i" =<< addExp se_v v_adj_i letExp (baseName v_adj) . BasicOp $ - Update s v_adj (fullSlice v_adj_t [DimFix i]) se_update + Update s v_adj slice se_update case check of CheckBounds _ -> stms Safe AssumeBounds -> stms Unsafe From 7046d72820cfa1903026fa44bd2787fcb4b4aff8 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 29 May 2026 15:11:47 +0200 Subject: [PATCH 36/70] Handle vector here. --- src/Futhark/AD/Rev.hs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/Futhark/AD/Rev.hs b/src/Futhark/AD/Rev.hs index ac44767ebe..f6f64ab5df 100644 --- a/src/Futhark/AD/Rev.hs +++ b/src/Futhark/AD/Rev.hs @@ -206,13 +206,16 @@ diffBasicOp pat aux e m = Update safety arr slice v -> do (_pat_v, pat_adj) <- commonBasicOp pat aux e m returnSweepCode $ do - v_adj <- letExp "update_val_adj" $ BasicOp $ Index pat_adj slice + adj_shape <- askShape + let adj_slice = Slice $ map sliceDim (shapeDims adj_shape) ++ unSlice slice + v_adj <- letExp "update_val_adj" $ BasicOp $ Index pat_adj adj_slice v_adj_copy <- copyIfArray v_adj updateSubExpAdj v v_adj_copy - zeroes <- letSubExp "update_zero" . zeroExp =<< subExpType v + v_adj_t <- lookupType v_adj + zeroes <- letSubExp "update_zero" $ zeroExp v_adj_t void $ updateAdj arr - =<< letExp "update_src_adj" (BasicOp $ Update safety pat_adj slice zeroes) + =<< letExp "update_src_adj" (BasicOp $ Update safety pat_adj adj_slice zeroes) -- See Note [Adjoints of accumulators] UpdateAcc safety _ is vs -> do addStm $ Let pat aux $ BasicOp e From 8ff4dc36dec5f5541817aaed686ac21ffc93617f Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 29 May 2026 16:13:19 +0200 Subject: [PATCH 37/70] Generate right tangents for stream. --- src/Futhark/AD/Fwd.hs | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/src/Futhark/AD/Fwd.hs b/src/Futhark/AD/Fwd.hs index d12e6c4d9f..f562b55c9c 100644 --- a/src/Futhark/AD/Fwd.hs +++ b/src/Futhark/AD/Fwd.hs @@ -391,14 +391,6 @@ fwdStreamLambda (Lambda params _ body) = do params' <- (take 1 params ++) <$> bundleNewList (drop 1 params) mkLambda params' $ bodyBind =<< fwdBody body -zeroFromSubExp :: SubExp -> ADM VName -zeroFromSubExp (Constant c) = - letExp "zero" . BasicOp . SubExp . Constant $ - blankPrimValue (primValueType c) -zeroFromSubExp (Var v) = do - t <- lookupType v - letExp "zero" $ zeroExp t - vecPerm :: Shape -> Type -> [Int] vecPerm = auxPerm @@ -473,7 +465,7 @@ fwdSOAC pat aux (Stream size xs nes lam) = do pat' <- bundleNewPat pat lam' <- fwdStreamLambda lam xs' <- soacInputsWithTangents xs - nes_tan <- mapM (fmap Var . zeroFromSubExp) nes + nes_tan <- mapM (letSubExp "zero" . zeroExp <=< tanType <=< subExpType) nes let nes' = interleave nes nes_tan addStm $ Let pat' aux $ Op $ Stream size xs' nes' lam' fwdSOAC pat aux (Hist w arrs ops bucket_fun) = do From e8c06d0de4f4fa07623e1b67522d3e142552863b Mon Sep 17 00:00:00 2001 From: Copilot <198982749+Copilot@users.noreply.github.com> Date: Fri, 29 May 2026 18:08:21 +0200 Subject: [PATCH 38/70] Support Sparse adjoints in vectorised AD (#2473) --- src/Futhark/AD/Rev/Map.hs | 4 ++-- src/Futhark/AD/Rev/Monad.hs | 31 ++++++++++++++++++++----------- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/src/Futhark/AD/Rev/Map.hs b/src/Futhark/AD/Rev/Map.hs index 22487ee7cf..b39f899b17 100644 --- a/src/Futhark/AD/Rev/Map.hs +++ b/src/Futhark/AD/Rev/Map.hs @@ -176,8 +176,8 @@ vjpMap ops res_adjs _ w map_lam as zipWithM_ forRes [0 ..] res_ivs where - isSparse (AdjSparse (Sparse shape _ ivs)) = do - guard $ shapeDims shape == [w] + isSparse (AdjSparse (Sparse shape _ vd ivs)) = do + guard $ drop vd (shapeDims shape) == [w] Just ivs isSparse _ = Nothing diff --git a/src/Futhark/AD/Rev/Monad.hs b/src/Futhark/AD/Rev/Monad.hs index d26e7ebbb6..49534671c4 100644 --- a/src/Futhark/AD/Rev/Monad.hs +++ b/src/Futhark/AD/Rev/Monad.hs @@ -111,10 +111,15 @@ data InBounds -- | A symbolic representation of an array that is all zeroes, except -- at certain indexes. data Sparse = Sparse - { -- | The shape of the array. + { -- | The full shape of the array (including any vector dimensions). sparseShape :: Shape, -- | Element type of the array. sparseType :: PrimType, + -- | Number of leading dimensions that are \"vector\" dimensions, + -- due to vectorised AD. These are not indexed by the sparse + -- index, but are present in the values. When zero, this is the + -- ordinary non-vectorised case. + sparseVecDims :: Int, -- | Locations and values of nonzero values. Indexes may be -- negative, in which case the value is ignored (unless -- 'AssumeBounds' is used). @@ -147,14 +152,15 @@ zeroArray shape t Replicate shape zero sparseArray :: (MonadBuilder m, Rep m ~ SOACS) => Sparse -> m VName -sparseArray (Sparse shape t ivs) = do +sparseArray (Sparse shape t vec_dims ivs) = do flip (foldM f) ivs =<< zeroArray shape (Prim t) where arr_t = Prim t `arrayOfShape` shape + vec_slice = map sliceDim $ take vec_dims $ shapeDims shape f arr (check, i, se) = do let stm s = letExp "sparse" . BasicOp $ - Update s arr (fullSlice arr_t [DimFix i]) se + Update s arr (fullSlice arr_t (vec_slice ++ [DimFix i])) se case check of AssumeBounds -> stm Unsafe CheckBounds _ -> stm Safe @@ -177,8 +183,8 @@ unitAdjOfType t = AdjVal <$> letSubExp "adj_unit" (oneExp t) adjRep :: Adj -> ([SubExp], [SubExp] -> Adj) adjRep (AdjVal se) = ([se], \[se'] -> AdjVal se') adjRep (AdjZero shape pt) = ([], \[] -> AdjZero shape pt) -adjRep (AdjSparse (Sparse shape pt ivs)) = - (concatMap ivRep ivs, AdjSparse . Sparse shape pt . repIvs ivs) +adjRep (AdjSparse (Sparse shape pt vd ivs)) = + (concatMap ivRep ivs, AdjSparse . Sparse shape pt vd . repIvs ivs) where ivRep (_, i, v) = [i, v] repIvs ((check, _, _) : ivs') (i : v : ses) = @@ -426,14 +432,17 @@ updateAdjIndex :: VName -> (InBounds, SubExp) -> SubExp -> ADM () updateAdjIndex v (check, i) se = do maybeAdj <- gets $ M.lookup v . stateAdjs t <- lookupType v + adj_shape <- askShape let iv = (check, i, se) + vec_dims = shapeRank adj_shape + full_shape = adj_shape <> arrayShape t case maybeAdj of - Nothing -> do - setAdj v $ AdjSparse $ Sparse (arrayShape t) (elemType t) [iv] - Just AdjZero {} -> - setAdj v $ AdjSparse $ Sparse (arrayShape t) (elemType t) [iv] - Just (AdjSparse (Sparse shape pt ivs)) -> - setAdj v $ AdjSparse $ Sparse shape pt $ iv : ivs + Nothing -> + setAdj v $ AdjSparse $ Sparse full_shape (elemType t) vec_dims [iv] + Just (AdjZero {}) -> + setAdj v $ AdjSparse $ Sparse full_shape (elemType t) vec_dims [iv] + Just (AdjSparse (Sparse shape pt vd ivs)) -> + setAdj v $ AdjSparse $ Sparse shape pt vd $ iv : ivs Just adj@AdjVal {} -> do v_adj <- adjVal adj v_adj_t <- lookupType v_adj From 26801f4e9f273b0b07ec7aeef1f40a9620a26eae Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 29 May 2026 18:11:17 +0200 Subject: [PATCH 39/70] Elaborate comment. --- src/Futhark/AD/Rev/Monad.hs | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/Futhark/AD/Rev/Monad.hs b/src/Futhark/AD/Rev/Monad.hs index 49534671c4..08cc96f678 100644 --- a/src/Futhark/AD/Rev/Monad.hs +++ b/src/Futhark/AD/Rev/Monad.hs @@ -111,14 +111,16 @@ data InBounds -- | A symbolic representation of an array that is all zeroes, except -- at certain indexes. data Sparse = Sparse - { -- | The full shape of the array (including any vector dimensions). + { -- | The full shape of the array (including any vector dimensions, which are + -- stored in sparseVecDims). sparseShape :: Shape, -- | Element type of the array. sparseType :: PrimType, - -- | Number of leading dimensions that are \"vector\" dimensions, - -- due to vectorised AD. These are not indexed by the sparse - -- index, but are present in the values. When zero, this is the - -- ordinary non-vectorised case. + -- | Number of leading dimensions that are \"vector\" dimensions, due to + -- vectorised AD. These are not indexed by the sparse index, but are present + -- in the values. When zero, this is the ordinary non-vectorised case. This + -- is equivalent to the rank of `askShape`, but it is convenient to store it + -- here as well. sparseVecDims :: Int, -- | Locations and values of nonzero values. Indexes may be -- negative, in which case the value is ignored (unless From be68354e32a7a86723bcde8a712d7352d31e7d0d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 30 May 2026 08:42:17 +0200 Subject: [PATCH 40/70] Fix scatter. --- futhark.cabal | 1 + src/Futhark/AD/Rev.hs | 223 +------------------- src/Futhark/AD/Rev/Acc.hs | 396 +++++++++++++++++++++++++++++++++++ src/Futhark/AD/Rev/Monad.hs | 17 +- src/Futhark/AD/Rev/Reduce.hs | 17 +- 5 files changed, 424 insertions(+), 230 deletions(-) create mode 100644 src/Futhark/AD/Rev/Acc.hs diff --git a/futhark.cabal b/futhark.cabal index f1d8494e20..5f7acc8514 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -118,6 +118,7 @@ library Futhark.AD.Fwd Futhark.AD.Shared Futhark.AD.Rev + Futhark.AD.Rev.Acc Futhark.AD.Rev.Loop Futhark.AD.Rev.Hist Futhark.AD.Rev.Map diff --git a/src/Futhark/AD/Rev.hs b/src/Futhark/AD/Rev.hs index f6f64ab5df..6da9a5165e 100644 --- a/src/Futhark/AD/Rev.hs +++ b/src/Futhark/AD/Rev.hs @@ -9,12 +9,11 @@ module Futhark.AD.Rev (revVJP) where import Control.Monad -import Control.Monad.Identity -import Data.List ((\\)) import Data.List.NonEmpty (NonEmpty (..)) import Data.Map qualified as M import Data.Tuple import Futhark.AD.Derivatives +import Futhark.AD.Rev.Acc import Futhark.AD.Rev.Loop import Futhark.AD.Rev.Monad import Futhark.AD.Rev.SOAC @@ -25,7 +24,7 @@ import Futhark.IR.SOACS import Futhark.Tools import Futhark.Transform.Rename import Futhark.Transform.Substitute -import Futhark.Util (chunks, takeLast) +import Futhark.Util (takeLast) patName :: Pat Type -> ADM VName patName (Pat [pe]) = pure $ patElemName pe @@ -216,27 +215,8 @@ diffBasicOp pat aux e m = void $ updateAdj arr =<< letExp "update_src_adj" (BasicOp $ Update safety pat_adj adj_slice zeroes) - -- See Note [Adjoints of accumulators] - UpdateAcc safety _ is vs -> do - addStm $ Let pat aux $ BasicOp e - m - pat_adjs <- mapM lookupAdjVal (patNames pat) - returnSweepCode $ do - forM_ (zip pat_adjs vs) $ \(adj, v) -> do - adj_t <- lookupType adj - let index_adj = pure $ BasicOp $ Index adj $ fullSlice adj_t $ map DimFix is - adj_i <- - letExp "updateacc_val_adj" =<< case safety of - Unsafe -> - index_adj - Safe -> - -- The primal UpdateAcc may be out-of-bounds, in which case - -- indexing the adjoint is dangerous. - eIf - (eShapeInBounds (arrayShape adj_t) (map eSubExp is)) - (eBody [index_adj]) - (eBody [pure $ zeroExp $ stripArray (length is) adj_t]) - updateSubExpAdj v adj_i + UpdateAcc safety acc is vs -> + diffUpdateAcc pat aux safety acc is vs m -- UserParam {} -> void $ commonBasicOp pat aux e m @@ -245,30 +225,10 @@ vjpOps :: VjpOps vjpOps = VjpOps { vjpLambda = diffLambda, - vjpStm = diffStm + vjpStm = diffStm, + vjpBody = diffBody } --- | Transform updates on accumulators matching the given certificates into --- updates that write provided zero values. -zeroOutUpdates :: [(VName, [SubExp])] -> Lambda SOACS -> Lambda SOACS -zeroOutUpdates certs_to_zeroes lam = lam {lambdaBody = onBody $ lambdaBody lam} - where - onExp = runIdentity . mapExpM mapper - where - mapper = - (identityMapper :: (Monad m) => Mapper SOACS SOACS m) - { mapOnOp = traverseSOACStms (\_ stms -> pure $ onStms stms), - mapOnBody = \_ body -> pure $ onBody body - } - onStms = fmap onStm - onStm (Let (Pat [pe]) aux (BasicOp (UpdateAcc safety acc is _))) - | Acc c _ _ _ <- patElemType pe, - Just zero <- lookup c certs_to_zeroes = - Let (Pat [pe]) aux (BasicOp (UpdateAcc safety acc is zero)) - onStm (Let pat aux e) = Let pat aux $ onExp e - - onBody body = body {bodyStms = onStms $ bodyStms body} - diffStm :: Stm SOACS -> ADM () -> ADM () diffStm (Let pat aux (BasicOp e)) m = diffBasicOp pat aux e m @@ -334,35 +294,8 @@ diffStm (Let pat aux (Op soac)) m = diffStm (Let pat aux loop@Loop {}) m = diffLoop diffStms pat aux loop m -- See Note [Adjoints of accumulators] -diffStm stm@(Let pat _aux (WithAcc inputs lam)) m = do - addStm stm - m - returnSweepCode $ do - adjs <- mapM lookupAdj $ patNames pat - lam' <- renameLambda lam - free_vars <- filterM isActive $ namesToList $ freeIn lam' - free_accs <- filterM (fmap isAcc . lookupType) free_vars - let free_vars' = free_vars \\ free_accs - lam'' <- diffLambda' adjs free_vars' lam' - (inputs_zeroes, inputs') <- - unzip <$> zipWithM renameInputLambda (chunks lengths adjs) inputs - let certs = map paramName $ take (length inputs) $ lambdaParams lam'' - free_adjs <- letTupExp "with_acc_contrib" $ WithAcc inputs' $ zeroOutUpdates (zip certs inputs_zeroes) lam'' - zipWithM_ insAdj (arrs <> free_vars') free_adjs - where - lengths = map (\(_, as, _) -> length as) inputs - arrs = concatMap (\(_, as, _) -> as) inputs - renameInputLambda as_adj (shape, as, _) = do - nes_ts <- mapM (fmap (stripArray (shapeRank shape)) . lookupType) as - zeroes <- mapM (zeroArray mempty) nes_ts - as' <- mapM adjVal as_adj - pure (map Var zeroes, (shape, as', Nothing)) - diffLambda' res_adjs get_adjs_for (Lambda params ts body) = do - localScope (scopeOfLParams params) $ do - Body () stms res <- diffBody res_adjs get_adjs_for body - let body' = Body () stms $ take (length inputs) res <> takeLast (length get_adjs_for) res - ts' <- mapM lookupType get_adjs_for - pure $ Lambda params (take (length inputs) ts <> ts') body' +diffStm (Let pat aux (WithAcc inputs lam)) m = + diffWithAcc vjpOps pat aux inputs lam m diffStm stm _ = error $ "diffStm unhandled:\n" ++ prettyString stm diffStms :: Stms SOACS -> ADM () @@ -421,143 +354,3 @@ revVJP scope shape attrs (Lambda params ts body) = do body pure $ Lambda (params ++ params_adj) (ts <> map paramType params) body' - --- Note [Adjoints of accumulators] --- --- The general case of taking adjoints of WithAcc is tricky. We make --- some assumptions and lay down a basic design. --- --- First, we assume that any WithAccs that occur in the program are --- come from one of these sources: --- --- - A previous instance of VJP, which means we can rely on the operator having --- a constant adjoint (it's addition as appropriate to the type). --- --- - A scatter, meaning there is no operator. --- --- (These can actually be distinguished by the presence of an operator, although --- we do not currently bother.) --- --- Second, the adjoint of an accumulator is an array of the same type --- as the underlying array. For example, the adjoint type of the --- primal type 'acc(c, [n], {f64})' is '[n]f64'. In principle the --- adjoint of 'acc(c, [n], {f64,f32})' should be two arrays of type --- '[]f64', '[]f32'. Our current design assumes that adjoints are --- single variables. This is fixable. --- --- In the return sweep, when inserting the with_acc, we still compute the --- "original" accumulator result, but modified such that its initial value is --- the adjoint of the result of the accumulator. We also modify the update_accs --- of these accumulators to be with zero values. This means that the array that --- is produced will be equal to the adjoint of the result, except for those --- places that have been updated, where it will be zero. This is intuitively --- sensible - values that have been overwritten (and so do not contribute to the --- result) should obviously have zero sensitivity. --- --- # Adjoint of UpdateAcc --- --- Consider primal code --- --- update_acc(acc, i, v) --- --- Interpreted as an imperative statement, this means --- --- acc[i] ⊕= v --- --- for some '⊕'. Normally all the compiler knows of '⊕' is that it --- is associative and commutative, but because we assume that all --- accumulators are the result of previous AD transformations, we --- can assume that '⊕' actually behaves like addition - that is, has --- unit partial derivatives. So the return sweep is --- --- v_adj += acc_adj[i] --- --- Further, we modify the primal code so that it becomes --- --- update_acc(acc, i, 0) --- --- for some appropriate notion of zero. --- --- # Adjoint of Map --- --- Suppose we have primal code --- --- let acc' = --- map (...) acc --- --- where "acc : acc(c, [n], {f64})" and the width of the Map is "w". --- Our normal transformation for Map input arrays is to similarly map --- their adjoint, but clearly this doesn't work here because the --- semantics of mapping an adjoint is an "implicit replicate". So --- when generating the return sweep we actually perform that --- replication: --- --- map (...) (replicate w acc_adj) --- --- But what about the contributions to "acc'"? Those we also have to --- take special care of. The result of the map itself is actually a --- multidimensional array: --- --- let acc_contribs = --- map (...) (replicate w acc'_adj) --- --- which we must then sum to add to the contribution. --- --- acc_adj += sum(acc_contribs) --- --- I'm slightly worried about the asymptotics of this, since my --- intuition of this is that the contributions might be rather sparse. --- (Maybe completely zero? If so it will be simplified away --- entirely.) Perhaps a better solution is to treat --- accumulator-inputs in the primal code as we do free variables, and --- create accumulators for them in the return sweep. --- --- # Consumption --- --- A minor problem is that our usual way of handling consumption (Note --- [Consumption]) is not viable, because accumulators are not --- copyable. Fortunately, while the accumulators that are consumed in --- the forward sweep will also be present in the return sweep given --- our current translation rules, they will be dead code. As long as --- we are careful to run dead code elimination after revVJP, we should --- be good. - --- Note [Array Adjoints of Match] --- --- Some unusual, but sadly not completely contrived, contain Match --- expressions that return multiple arrays, and there the arrays --- returned by one branch have overlapping aliases with another --- branch, although in different places. As an example consider this: --- --- let (X,Y) = if c --- then (A, B) --- else (B, A) --- --- Because our aliasing representation cannot express mutually --- exclusive aliases, we will consider X and Y to be aliased to each --- other. In practice, this means it is unlikely for X or Y to be --- consumed, because it would also consume the other (although it's --- possible for carefully written code). --- --- When producing adjoints for this, it will be something like --- --- let (X_adj,Y_adj) = if c --- then (A_adj, B_adj) --- else (B_adj, A_adj) --- --- which completely reflects the primal code. However, while it is --- unlikely that any consumption takes place for the original primal --- variables, it is almost guaranteed that X_adj and Y_adj will be --- consumed (that is the main way we use adjoints after all), and due --- to the conservative aliasing, when one is consumed, so is the --- other! To avoid this tragic fate, we are forced to copy any --- array-typed adjoints returned by a Match. This can be quite costly. --- However: --- --- 1) Futhark has pretty OK copy removal, so maybe it can get rid of --- these by using information not available to the AD pass. --- --- 2) In many cases, arrays will have accumulator adjoints, which are --- not subject to this problem. --- --- Issue #2228 was caused by neglecting to do this. diff --git a/src/Futhark/AD/Rev/Acc.hs b/src/Futhark/AD/Rev/Acc.hs new file mode 100644 index 0000000000..ff34bd2433 --- /dev/null +++ b/src/Futhark/AD/Rev/Acc.hs @@ -0,0 +1,396 @@ +-- | Differentiation related to accumulators in the input program. +module Futhark.AD.Rev.Acc + ( diffWithAcc, + diffUpdateAcc, + ) +where + +-- Note [Adjoints of accumulators] +-- +-- The general case of taking adjoints of WithAcc is tricky. We make +-- some assumptions and lay down a basic design. +-- +-- First, we assume that any WithAccs that occur in the program are +-- come from one of these sources: +-- +-- - A previous instance of VJP, which means we can rely on the operator having +-- a constant adjoint (it's addition as appropriate to the type). +-- +-- - A scatter, meaning there is no operator. +-- +-- (These can actually be distinguished by the presence of an operator, although +-- we do not currently bother.) +-- +-- Second, the adjoint of an accumulator is an array of the same type +-- as the underlying array. For example, the adjoint type of the +-- primal type 'acc(c, [n], {f64})' is '[n]f64'. In principle the +-- adjoint of 'acc(c, [n], {f64,f32})' should be two arrays of type +-- '[]f64', '[]f32'. Our current design assumes that adjoints are +-- single variables. This is fixable. +-- +-- In the return sweep, when inserting the with_acc, we still compute the +-- "original" accumulator result, but modified such that its initial value is +-- the adjoint of the result of the accumulator. We also modify the update_accs +-- of these accumulators to be with zero values. This means that the array that +-- is produced will be equal to the adjoint of the result, except for those +-- places that have been updated, where it will be zero. This is intuitively +-- sensible - values that have been overwritten (and so do not contribute to the +-- result) should obviously have zero sensitivity. +-- +-- # Adjoint of UpdateAcc +-- +-- Consider primal code +-- +-- update_acc(acc, i, v) +-- +-- Interpreted as an imperative statement, this means +-- +-- acc[i] ⊕= v +-- +-- for some '⊕'. Normally all the compiler knows of '⊕' is that it +-- is associative and commutative, but because we assume that all +-- accumulators are the result of previous AD transformations, we +-- can assume that '⊕' actually behaves like addition - that is, has +-- unit partial derivatives. So the return sweep is +-- +-- v_adj += acc_adj[i] +-- +-- Further, we modify the primal code so that it becomes +-- +-- update_acc(acc, i, 0) +-- +-- for some appropriate notion of zero. +-- +-- # Adjoint of Map +-- +-- Suppose we have primal code +-- +-- let acc' = +-- map (...) acc +-- +-- where "acc : acc(c, [n], {f64})" and the width of the Map is "w". +-- Our normal transformation for Map input arrays is to similarly map +-- their adjoint, but clearly this doesn't work here because the +-- semantics of mapping an adjoint is an "implicit replicate". So +-- when generating the return sweep we actually perform that +-- replication: +-- +-- map (...) (replicate w acc_adj) +-- +-- But what about the contributions to "acc'"? Those we also have to +-- take special care of. The result of the map itself is actually a +-- multidimensional array: +-- +-- let acc_contribs = +-- map (...) (replicate w acc'_adj) +-- +-- which we must then sum to add to the contribution. +-- +-- acc_adj += sum(acc_contribs) +-- +-- I'm slightly worried about the asymptotics of this, since my +-- intuition of this is that the contributions might be rather sparse. +-- (Maybe completely zero? If so it will be simplified away +-- entirely.) Perhaps a better solution is to treat +-- accumulator-inputs in the primal code as we do free variables, and +-- create accumulators for them in the return sweep. +-- +-- # Vectorised WithAcc +-- +-- When WithAcc occurs in vectorised AD, the accumulator element types gain +-- extra leading "vectorised" dimensions corresponding to the enclosing vector +-- shape. For example, if the primal type inside a map of width @w@ is @acc(c, +-- [n], {f64})@, the adjoint type is @[w][n]f64@ -- but the internal accumulator +-- layout expects shape @[n][w]f64@ (the accumulator shape comes first, then the +-- vectorised dimensions, then element dimensions). +-- +-- This means we must transpose accumulator adjoints when entering and +-- leaving the return-sweep WithAcc: +-- +-- * On entry: transpose result adjoints from @[vec...][shape...]elem@ to +-- @[shape...][vec...]elem@ so they can serve as initial values for the +-- accumulators. +-- +-- * On exit: transpose the produced arrays back from @[shape...][vec...]elem@ +-- to @[vec...][shape...]elem@ to match the expected adjoint layout. +-- +-- This is actually quite similar to how other SOACs must be handled. +-- +-- Additionally, the accumulator parameter types in the lambda (and any +-- Acc-typed pattern elements or inner lambda parameters referring to the same +-- certs) must be updated to reflect the vectorised element types *before* +-- differentiation. This ensures that 'lookupAdj' on accumulator variables +-- inside the lambda produces adjoints with the correct vectorised type. +-- +-- The UpdateAcc case is simpler under vectorisation: because the accumulator +-- adjoint already has the vectorised dimensions folded into its element type, a +-- plain index into the adjoint at the update indices directly yields the +-- correctly-shaped contribution. +-- +-- # Consumption +-- +-- A minor problem is that our usual way of handling consumption (Note +-- [Consumption]) is not viable, because accumulators are not +-- copyable. Fortunately, while the accumulators that are consumed in +-- the forward sweep will also be present in the return sweep given +-- our current translation rules, they will be dead code. As long as +-- we are careful to run dead code elimination after revVJP, we should +-- be good. + +-- Note [Array Adjoints of Match] +-- +-- Some unusual, but sadly not completely contrived, contain Match +-- expressions that return multiple arrays, and there the arrays +-- returned by one branch have overlapping aliases with another +-- branch, although in different places. As an example consider this: +-- +-- let (X,Y) = if c +-- then (A, B) +-- else (B, A) +-- +-- Because our aliasing representation cannot express mutually +-- exclusive aliases, we will consider X and Y to be aliased to each +-- other. In practice, this means it is unlikely for X or Y to be +-- consumed, because it would also consume the other (although it's +-- possible for carefully written code). +-- +-- When producing adjoints for this, it will be something like +-- +-- let (X_adj,Y_adj) = if c +-- then (A_adj, B_adj) +-- else (B_adj, A_adj) +-- +-- which completely reflects the primal code. However, while it is +-- unlikely that any consumption takes place for the original primal +-- variables, it is almost guaranteed that X_adj and Y_adj will be +-- consumed (that is the main way we use adjoints after all), and due +-- to the conservative aliasing, when one is consumed, so is the +-- other! To avoid this tragic fate, we are forced to copy any +-- array-typed adjoints returned by a Match. This can be quite costly. +-- However: +-- +-- 1) Futhark has pretty OK copy removal, so maybe it can get rid of +-- these by using information not available to the AD pass. +-- +-- 2) In many cases, arrays will have accumulator adjoints, which are +-- not subject to this problem. +-- +-- Issue #2228 was caused by neglecting to do this. + +import Control.Monad +import Control.Monad.Identity +import Data.List ((\\)) +import Futhark.AD.Rev.Monad +import Futhark.Builder +import Futhark.IR.SOACS +import Futhark.Tools +import Futhark.Transform.Rename +import Futhark.Util (chunks, takeLast) + +-- | Transform updates on accumulators matching the given certificates into +-- updates that write provided zero values. +zeroOutUpdates :: [(VName, [SubExp])] -> Lambda SOACS -> Lambda SOACS +zeroOutUpdates certs_to_zeroes lam = lam {lambdaBody = onBody $ lambdaBody lam} + where + onExp = runIdentity . mapExpM mapper + where + mapper = + (identityMapper :: (Monad m) => Mapper SOACS SOACS m) + { mapOnOp = traverseSOACStms (\_ stms -> pure $ onStms stms), + mapOnBody = \_ body -> pure $ onBody body + } + onStms = fmap onStm + onStm (Let (Pat [pe]) aux (BasicOp (UpdateAcc safety acc is _))) + | Acc c _ _ _ <- patElemType pe, + Just zero <- lookup c certs_to_zeroes = + Let (Pat [pe]) aux (BasicOp (UpdateAcc safety acc is zero)) + onStm (Let pat aux e) = Let pat aux $ onExp e + + onBody body = body {bodyStms = onStms $ bodyStms body} + +-- Update accumulator parameter types in the lambda to include vectorised +-- element types. Also updates all Acc-typed pattern elements and inner +-- lambda parameters that reference the same accumulator certs. +updateAccParamTypes :: Int -> Shape -> Lambda SOACS -> Lambda SOACS +updateAccParamTypes n_inputs adj_sh lam + | adj_sh == mempty = lam + | otherwise = + let (cert_ps, rest_ps) = splitAt n_inputs (lambdaParams lam) + (acc_ps, other_ps) = splitAt n_inputs rest_ps + acc_ps' = map (updateParam cert_names) acc_ps + cert_names = map paramName cert_ps + body' = updateBody cert_names (lambdaBody lam) + ret' = map (updateAccType cert_names) (lambdaReturnType lam) + in lam + { lambdaParams = cert_ps ++ acc_ps' ++ other_ps, + lambdaReturnType = ret', + lambdaBody = body' + } + where + updateParam :: [VName] -> Param Type -> Param Type + updateParam certs p = + p {paramDec = updateAccType certs (paramDec p)} + + updateAccType :: [VName] -> Type -> Type + updateAccType certs (Acc cert acc_shape ts u) + | cert `elem` certs = + Acc cert acc_shape (map (`arrayOfShape` adj_sh) ts) u + updateAccType _ t = t + + updateBody :: [VName] -> Body SOACS -> Body SOACS + updateBody certs body = + body {bodyStms = fmap (updateStm certs) (bodyStms body)} + + updateStm :: [VName] -> Stm SOACS -> Stm SOACS + updateStm certs (Let pat aux e) = + Let (updatePat certs pat) aux (updateExp certs e) + + updatePat :: [VName] -> Pat Type -> Pat Type + updatePat certs (Pat pes) = + Pat $ map (\pe -> pe {patElemDec = updateAccType certs (patElemDec pe)}) pes + + updateExp :: [VName] -> Exp SOACS -> Exp SOACS + updateExp certs = runIdentity . mapExpM mapper + where + mapper = + (identityMapper :: (Monad m) => Mapper SOACS SOACS m) + { mapOnBody = \_ b -> pure $ updateBody certs b, + mapOnOp = pure . updateSOAC certs + } + + updateSOAC :: [VName] -> SOAC SOACS -> SOAC SOACS + updateSOAC certs = runIdentity . mapSOACM mapper + where + mapper = + identitySOACMapper + { mapOnSOACLambda = pure . updateLambda certs + } + + updateLambda :: [VName] -> Lambda SOACS -> Lambda SOACS + updateLambda certs l = + l + { lambdaParams = map (updateParam certs) (lambdaParams l), + lambdaReturnType = map (updateAccType certs) (lambdaReturnType l), + lambdaBody = updateBody certs (lambdaBody l) + } + +diffWithAcc :: + VjpOps -> + Pat Type -> + StmAux () -> + [(Shape, [VName], Maybe (Lambda SOACS, [SubExp]))] -> + Lambda SOACS -> + ADM () -> + ADM () +diffWithAcc ops pat aux inputs lam m = do + addStm $ Let pat aux $ WithAcc inputs lam + m + returnSweepCode $ do + adj_shape <- askShape + adjs <- mapM lookupAdj $ patNames pat + -- Transpose the accumulator result adjoints from [vec...][shape...]elem + -- to [shape...][vec...]elem, matching the internal accumulator layout. + adjs' <- transposeAdjs adj_shape adjs + lam' <- renameLambda lam + -- Update the lambda's accumulator parameter types to reflect vectorised + -- element types BEFORE differentiation, so that lookupAdj on Acc variables + -- inside the lambda gives the correct vectorised adjoint type. + let lam'_vec = updateAccParamTypes n_inputs adj_shape lam' + free_vars <- filterM isActive $ namesToList $ freeIn lam'_vec + free_accs <- filterM (fmap isAcc . lookupType) free_vars + let free_vars' = free_vars \\ free_accs + lam'' <- diffLambda' adjs' free_vars' lam'_vec + (inputs_zeroes, inputs') <- + unzip <$> zipWithM (renameInputLambda adj_shape) (chunks lengths adjs) inputs + let certs = map paramName $ take n_inputs $ lambdaParams lam'' + raw_adjs <- + letTupExp "with_acc_contrib" . WithAcc inputs' $ + zeroOutUpdates (zip certs inputs_zeroes) lam'' + -- The accumulator results have shape [shape...][vec...]elem. Transpose + -- back to [vec...][shape...]elem for the adjoint. + let n_arrs = sum lengths + (arr_adjs, free_adjs) = splitAt n_arrs raw_adjs + arr_adjs' <- zipWithM (transposeAccResult adj_shape) (map (\(s, _, _) -> s) inputs) arr_adjs + zipWithM_ insAdj arrs arr_adjs' + zipWithM_ insAdj free_vars' free_adjs + where + n_inputs = length inputs + lengths = map (\(_, as, _) -> length as) inputs + arrs = concatMap (\(_, as, _) -> as) inputs + + -- Transpose the accumulator-related adjoints from [vec...][shape...]elem + -- to [shape...][vec...]elem. Non-accumulator adjs are left unchanged. + transposeAdjs :: Shape -> [Adj] -> ADM [Adj] + transposeAdjs adj_sh adjs + | adj_sh == mempty = pure adjs + | otherwise = do + let n_arrs = sum lengths + (acc_adjs, other_adjs) = splitAt n_arrs adjs + acc_adjs' <- mapM transposeAdj acc_adjs + pure $ acc_adjs' ++ other_adjs + + transposeAdj :: Adj -> ADM Adj + transposeAdj adj = do + v <- adjVal adj + v' <- vecToInner v + pure $ AdjVal $ Var v' + + -- Transpose [shape...][vec...][elem...] to [vec...][shape...][elem...] + transposeAccResult :: Shape -> Shape -> VName -> ADM VName + transposeAccResult adj_sh shape v + | adj_sh == mempty = pure v + | otherwise = do + v_t <- lookupType v + let r = shapeRank adj_sh + s = shapeRank shape + total = arrayRank v_t + perm = [s .. s + r - 1] ++ [0 .. s - 1] ++ [s + r .. total - 1] + letExp (baseName v <> "_tr") $ BasicOp $ Rearrange v perm + + renameInputLambda adj_sh as_adj (shape, as, _) = do + -- Compute element types with vectorised dimensions included. + orig_nes_ts <- mapM (fmap (stripArray (shapeRank shape)) . lookupType) as + let vec_nes_ts = map (`arrayOfShape` adj_sh) orig_nes_ts + zeroes <- mapM (zeroArray mempty) vec_nes_ts + -- Transpose adjoints from [vec...][shape...]elem to [shape...][vec...]elem + -- so they match the accumulator layout. + as' <- mapM adjVal as_adj + as'' <- mapM vecToInner as' + pure (map Var zeroes, (shape, as'', Nothing)) + + diffLambda' res_adjs get_adjs_for (Lambda params ts body) = do + localScope (scopeOfLParams params) $ do + Body () stms res <- vjpBody ops res_adjs get_adjs_for body + let body' = Body () stms $ take n_inputs res <> takeLast (length get_adjs_for) res + ts' <- mapM lookupType get_adjs_for + pure $ Lambda params (take n_inputs ts <> ts') body' + +diffUpdateAcc :: + Pat Type -> + StmAux () -> + Safety -> + VName -> + [SubExp] -> + [SubExp] -> + ADM () -> + ADM () +diffUpdateAcc pat aux safety acc is vs m = do + addStm $ Let pat aux $ BasicOp $ UpdateAcc safety acc is vs + m + pat_adjs <- mapM lookupAdjVal (patNames pat) + returnSweepCode $ do + forM_ (zip pat_adjs vs) $ \(adj, v) -> do + adj_t <- lookupType adj + let index_adj = pure $ BasicOp $ Index adj $ fullSlice adj_t $ map DimFix is + adj_i <- + letExp "updateacc_val_adj" =<< case safety of + Unsafe -> + index_adj + Safe -> + -- The primal UpdateAcc may be out-of-bounds, in which case + -- indexing the adjoint is dangerous. + eIf + (eShapeInBounds (arrayShape adj_t) (map eSubExp is)) + (eBody [index_adj]) + (eBody [pure $ zeroExp $ stripArray (length is) adj_t]) + updateSubExpAdj v adj_i diff --git a/src/Futhark/AD/Rev/Monad.hs b/src/Futhark/AD/Rev/Monad.hs index 08cc96f678..41eabef7c5 100644 --- a/src/Futhark/AD/Rev/Monad.hs +++ b/src/Futhark/AD/Rev/Monad.hs @@ -55,6 +55,7 @@ module Futhark.AD.Rev.Monad renameLoopTape, -- locallyNonvectorised, + vecToInner, ) where @@ -567,7 +568,8 @@ subSubsts m = do data VjpOps = VjpOps { vjpLambda :: [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS), - vjpStm :: Stm SOACS -> ADM () -> ADM () + vjpStm :: Stm SOACS -> ADM () -> ADM (), + vjpBody :: [Adj] -> [VName] -> Body SOACS -> ADM (Body SOACS) } -- | @setLoopTape v vs@ establishes @vs@ as the name of the array @@ -625,6 +627,19 @@ locallyNonvectorised e m = do AdjZero {} -> False _ -> True +-- | If we are doing vectorised AD, then transpose the array to bring the vector +-- shape outermost. +-- +-- That, convers @[vec...][shape...][elem...]@ to @[shape...][vec...][elem...]@. +vecToInner :: VName -> ADM VName +vecToInner v = do + adj_shape <- askShape + if adj_shape == mempty + then pure v + else do + v_t <- lookupType v + letExp (baseName v <> "_tr") $ BasicOp $ Rearrange v (auxPerm adj_shape v_t) + -- Note [Consumption] -- -- Parts of this transformation depends on duplicating computation. diff --git a/src/Futhark/AD/Rev/Reduce.hs b/src/Futhark/AD/Rev/Reduce.hs index e896936355..8d8deca5d5 100644 --- a/src/Futhark/AD/Rev/Reduce.hs +++ b/src/Futhark/AD/Rev/Reduce.hs @@ -75,23 +75,12 @@ mkF lam = do bodyBind $ lambdaBody lam_r pure (map paramName aps, lam') --- | If we are doing vectorised AD, then transpose the array to bring the vector --- shape outermost. -transposeIfNeeded :: VName -> ADM VName -transposeIfNeeded v = do - adj_shape <- askShape - if adj_shape == mempty - then pure v - else do - v_t <- lookupType v - letExp (baseName v <> "_tr") $ BasicOp $ Rearrange v (auxPerm adj_shape v_t) - diffReduce :: VjpOps -> [VName] -> SubExp -> [VName] -> Reduce SOACS -> ADM () diffReduce _ops [adj] w [a] red | Just [(op, _, _, _)] <- lamIsBinOp $ redLambda red, isAdd op = do adj_rep <- - transposeIfNeeded <=< letExp (baseName adj <> "_rep") $ + vecToInner <=< letExp (baseName adj <> "_rep") $ BasicOp (Replicate (Shape [w]) (Var adj)) void $ updateAdj a adj_rep where @@ -131,7 +120,7 @@ diffReduce ops pat_adj w as red = do as_adj <- letTupExp "red_contribs" . Op . Screma w (ls ++ as ++ rs) =<< mapSOAC f_adj - zipWithM_ updateAdj as =<< mapM transposeIfNeeded as_adj + zipWithM_ updateAdj as =<< mapM vecToInner as_adj where renameRed (Reduce comm lam nes) = Reduce comm <$> renameLambda lam <*> pure nes @@ -341,7 +330,7 @@ diffMulReduce _ops x aux w mul ne as m = do as_adjup <- letExp "prod_contrib" . Op . Screma w [as] =<< mapSOAC map_lam_rev - updateAdj as =<< transposeIfNeeded as_adjup + updateAdj as =<< vecToInner as_adjup where getDiv :: PrimType -> BinOp getDiv (IntType t) = SDiv t Unsafe From a0d0fc2f822df3b698e9d106c453568834e77d97 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 30 May 2026 11:12:37 +0200 Subject: [PATCH 41/70] Fix forward-mode for Stream. --- src/Futhark/AD/Fwd.hs | 63 ++++++++++++++++++++++++++++++++++--------- 1 file changed, 51 insertions(+), 12 deletions(-) diff --git a/src/Futhark/AD/Fwd.hs b/src/Futhark/AD/Fwd.hs index f562b55c9c..c6b4ec70b7 100644 --- a/src/Futhark/AD/Fwd.hs +++ b/src/Futhark/AD/Fwd.hs @@ -18,7 +18,7 @@ import Futhark.Analysis.PrimExp.Convert import Futhark.Builder import Futhark.IR.SOACS import Futhark.Tools -import Futhark.Util (interleave) +import Futhark.Util (interleave, splitAt3, unterleave) zeroExp :: Type -> Exp SOACS zeroExp (Prim pt) = @@ -34,8 +34,7 @@ tanType (Acc acc ispace ts u) = do pure $ Acc acc_tan (tan_shape <> ispace) ts u tanType t = do shape <- askShape - pure $ - arrayOf (Prim (elemType t)) (shape `prependShape` arrayShape t) u + pure $ arrayOf (Prim (elemType t)) (shape `prependShape` arrayShape t) u where u = case t of Array _ _ u' -> u' @@ -386,10 +385,42 @@ fwdWithAccLambda inputs (Lambda params _ body) = do ts <- map (stripArray (shapeRank shape)) <$> mapM lookupType arrs newParam "acc_p_tan" $ Acc c (tan_shape <> shape) ts NoUniqueness -fwdStreamLambda :: Lambda SOACS -> ADM (Lambda SOACS) -fwdStreamLambda (Lambda params _ body) = do - params' <- (take 1 params ++) <$> bundleNewList (drop 1 params) - mkLambda params' $ bodyBind =<< fwdBody body +fwdStreamLambda :: Int -> Lambda SOACS -> ADM (Lambda SOACS) +fwdStreamLambda num_accs (Lambda params _ body) = do + tan_shape <- askShape + let (chunk_params, acc_params, arr_params) = splitAt3 1 num_accs params + acc_params' <- bundleNewList acc_params + (arr_params', arr_params'_tan) <- mapAndUnzipM onArrParam arr_params + let params' = + chunk_params <> acc_params' <> interleave arr_params' arr_params'_tan + mkLambda params' $ do + zipWithM_ (trArrParamTan tan_shape) arr_params' arr_params'_tan + (acc_res, map_res) <- fmap (splitAt (num_accs * 2)) . bodyBind =<< fwdBody body + let (map_res_primal, map_res_tan) = unterleave map_res + map_res_tan' <- mapM (trMapResTan tan_shape) map_res_tan + pure $ acc_res <> interleave map_res_primal map_res_tan' + where + -- Array parameters need to be treated specially as the chunk parameter + -- must always be outermost. + onArrParam p = do + shape <- askShape + (p', p_tan) <- bundleNew p + let perm = auxPerm shape $ paramType p_tan + pure (p', p_tan {paramDec = rearrangeType perm (paramType p_tan)}) + + -- Put the tangent shape back in the outermost position. + trArrParamTan tan_shape p p_tan = do + let perm = rearrangeInverse $ auxPerm tan_shape $ paramType p_tan + v <- + letExp (baseName (paramName p_tan)) . BasicOp $ + Rearrange (paramName p_tan) perm + insertTan (paramName p) v + + -- Put the chunk size back in the outermost position. + trMapResTan tan_shape (SubExpRes cs ~(Var v)) = do + v_t <- lookupType v + let perm = auxPerm tan_shape v_t + fmap varRes . certifying cs $ letExp (baseName v) . BasicOp $ Rearrange v perm vecPerm :: Shape -> Type -> [Int] vecPerm = auxPerm @@ -461,13 +492,13 @@ fwdSOAC pat aux (Screma size xs (ScremaForm f scs reds post_lam)) = do redLambda = op', redNeutral = redNeutral red `interleave` neutral_tans } -fwdSOAC pat aux (Stream size xs nes lam) = do +fwdSOAC pat aux (Stream size xs accs lam) = do pat' <- bundleNewPat pat - lam' <- fwdStreamLambda lam + lam' <- fwdStreamLambda (length accs) lam xs' <- soacInputsWithTangents xs - nes_tan <- mapM (letSubExp "zero" . zeroExp <=< tanType <=< subExpType) nes - let nes' = interleave nes nes_tan - addStm $ Let pat' aux $ Op $ Stream size xs' nes' lam' + accs_tan <- mapM (letSubExp "zero" . zeroExp <=< tanType <=< subExpType) accs + let accs' = interleave accs accs_tan + addStm $ Let pat' aux $ Op $ Stream size xs' accs' lam' fwdSOAC pat aux (Hist w arrs ops bucket_fun) = do -- TODO: this is probably not very efficient in the vectorised case as we end -- up with a dreadful update operator that involves arrays. @@ -617,3 +648,11 @@ fwdJVP scope shape attrs (Lambda params _ body) = params_tan <- mapM newTan params mkLambda (params <> params_tan) $ bodyBind =<< fwdBodyTansLast body + +-- Note [Forward-Mode vectorised AD] +-- +-- An primal variable of type 't' has a tangent of type '[tan_shape]t', where +-- 'tan_shape' is the vector shape (which may be empty in the non-vectorised +-- case). This requires some care for SOACs, which always map across the +-- outermost dimension: basically we have to transpose the inputs and the +-- outputs. From dd07eba7de50b8adfc74a04f2948dada0f8990e7 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 30 May 2026 22:05:40 +0200 Subject: [PATCH 42/70] Fix vectorised scans. --- src/Futhark/AD/Rev/Scan.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Futhark/AD/Rev/Scan.hs b/src/Futhark/AD/Rev/Scan.hs index 2a5030fd8a..b6f8ca01b6 100644 --- a/src/Futhark/AD/Rev/Scan.hs +++ b/src/Futhark/AD/Rev/Scan.hs @@ -435,7 +435,7 @@ diffScanVec :: [VName] -> ADM () -> ADM () -diffScanVec ops ys aux w lam ne as m = locallyNonvectorised (ys, lam, as) $ do +diffScanVec ops ys aux w lam ne as m = do stmts <- collectStms_ $ do rank <- arrayRank <$> lookupType (head as) let rear = [1, 0] ++ drop 2 [0 .. rank - 1] From 4859c04cbd2060b06086f42156ea00ffe8d7db6c Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 30 May 2026 22:33:49 +0200 Subject: [PATCH 43/70] Fix final known vector-AD bug. --- src/Futhark/AD/Fwd.hs | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/Futhark/AD/Fwd.hs b/src/Futhark/AD/Fwd.hs index c6b4ec70b7..39904b0e29 100644 --- a/src/Futhark/AD/Fwd.hs +++ b/src/Futhark/AD/Fwd.hs @@ -518,12 +518,9 @@ fwdSOAC pat aux (Hist w arrs ops bucket_fun) = do mapM_ fwdStm stms let (res_is, res_vs) = splitAt n_indices res (res_is ++) <$> bundleTangents res_vs - fwdHistBucket l@(Lambda params ret body) = - let (r_is, r_vs) = splitAt n_indices ret - in Lambda - <$> bundleNewList params - <*> ((r_is ++) <$> bundleTangents r_vs) - <*> inScopeOf l (fwdBodyHist body) + fwdHistBucket (Lambda params _ body) = do + params' <- bundleNewList params + mkLambda params' $ bodyBind =<< fwdBodyHist body fwdHist :: HistOp SOACS -> ADM (HistOp SOACS) fwdHist (HistOp shape rf dest nes op) = do From 7210cf39534affc21a8ed335fd13fb92689608c2 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 31 May 2026 21:08:46 +0200 Subject: [PATCH 44/70] Unify tests. --- tests/ad/map7.fut | 29 ++++++++++++++++++++--------- tests/ad/vec/map6.fut | 38 -------------------------------------- 2 files changed, 20 insertions(+), 47 deletions(-) delete mode 100644 tests/ad/vec/map6.fut diff --git a/tests/ad/map7.fut b/tests/ad/map7.fut index 941d2d15b1..1f0eabd15c 100644 --- a/tests/ad/map7.fut +++ b/tests/ad/map7.fut @@ -2,9 +2,17 @@ -- has active free variables. -- == -- tags { autodiff } --- entry: fwd_J rev_J fwd_vec_J +-- entry: fwd_map fwd_vec rev_map rev_vec -- input { [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0] } --- output { [0.0, 0.0, 0.0, 0.0, -4.0, 0.0, 0.0, 0.0] } +-- output { [[0.0, 2.0, 3.0, 4.0], +-- [0.0, 0.0, 1.0, 1.0], +-- [0.0, 0.0, 0.0, 1.0], +-- [0.0, 0.0, 0.0, 0.0], +-- [-4.0, -6.0, -7.0, -8.0], +-- [0.0, 0.0, -1.0, -1.0], +-- [0.0, 0.0, 0.0, -1.0], +-- [0.0, 0.0, 0.0, 0.0]] +-- } def obj (x: [8]f64) = #[unsafe] @@ -15,14 +23,17 @@ def obj (x: [8]f64) = map (map f64.sum) col_w_pre_red let col_eq: [4]f64 = map (\w -> w[0] - w[1]) col_w_red - in f64.maximum col_eq + in col_eq -entry fwd_J (x: [8]f64) = +entry fwd_map (x: [8]f64) = tabulate 8 (\i -> jvp obj x (replicate 8 0 with [i] = 1)) -entry rev_J (x: [8]f64) = - vjp obj x 1 +entry fwd_vec (x: [8]f64) = + #[unroll] + jvp_vec obj x (tabulate 8 (\i -> replicate 8 0 with [i] = 1)) -entry fwd_vec_J (x: [8]f64) = - let seeds = tabulate 8 (\i -> replicate 8 0 with [i] = 1) - in jvp_vec obj x seeds +entry rev_map (x: [8]f64) = + transpose (tabulate 4 (\i -> vjp obj x (replicate 4 0 with [i] = 1))) + +entry rev_vec (x: [8]f64) = + transpose (#[unroll] vjp_vec obj x (tabulate 4 (\i -> replicate 4 0 with [i] = 1))) diff --git a/tests/ad/vec/map6.fut b/tests/ad/vec/map6.fut deleted file mode 100644 index fb86d14dd9..0000000000 --- a/tests/ad/vec/map6.fut +++ /dev/null @@ -1,38 +0,0 @@ --- #1878 --- == --- tags { autodiff } --- entry: fwd_map fwd_vec rev_map rev_vec --- input { [1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0] } --- output { [[0.0, 2.0, 3.0, 4.0], --- [0.0, 0.0, 1.0, 1.0], --- [0.0, 0.0, 0.0, 1.0], --- [0.0, 0.0, 0.0, 0.0], --- [-4.0, -6.0, -7.0, -8.0], --- [0.0, 0.0, -1.0, -1.0], --- [0.0, 0.0, 0.0, -1.0], --- [0.0, 0.0, 0.0, 0.0]] --- } - -def obj (x: [8]f64) = - #[unsafe] - -- For simplicity of generated code. - let col_w_pre_red = - tabulate_3d 4 2 4 (\k i j -> x[k + j] * x[i + j]) - let col_w_red = - map (map f64.sum) col_w_pre_red - let col_eq: [4]f64 = - map (\w -> w[0] - w[1]) col_w_red - in col_eq - -entry fwd_map (x: [8]f64) = - tabulate 8 (\i -> jvp obj x (replicate 8 0 with [i] = 1)) - -entry fwd_vec (x: [8]f64) = - #[unroll] - jvp_vec obj x (tabulate 8 (\i -> replicate 8 0 with [i] = 1)) - -entry rev_map (x: [8]f64) = - transpose (tabulate 4 (\i -> vjp obj x (replicate 4 0 with [i] = 1))) - -entry rev_vec (x: [8]f64) = - transpose (#[unroll] vjp_vec obj x (tabulate 4 (\i -> replicate 4 0 with [i] = 1))) From fed699d0270cdc6df217dd387f3ce151b66d870f Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 31 May 2026 22:15:05 +0200 Subject: [PATCH 45/70] Remove this unroll. --- tests/ad/map7.fut | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/ad/map7.fut b/tests/ad/map7.fut index 1f0eabd15c..5d3da75fc7 100644 --- a/tests/ad/map7.fut +++ b/tests/ad/map7.fut @@ -29,7 +29,6 @@ entry fwd_map (x: [8]f64) = tabulate 8 (\i -> jvp obj x (replicate 8 0 with [i] = 1)) entry fwd_vec (x: [8]f64) = - #[unroll] jvp_vec obj x (tabulate 8 (\i -> replicate 8 0 with [i] = 1)) entry rev_map (x: [8]f64) = From e035fca9c3074f299c0ddee2d34d4b85e85e4f99 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 3 Jun 2026 17:04:13 +0200 Subject: [PATCH 46/70] Specialised handling of Hist. --- src/Futhark/AD/Fwd.hs | 83 +++++++++++++++++++++++++++++++++------- src/Futhark/AD/Shared.hs | 2 +- 2 files changed, 71 insertions(+), 14 deletions(-) diff --git a/src/Futhark/AD/Fwd.hs b/src/Futhark/AD/Fwd.hs index 39904b0e29..92ad80bb83 100644 --- a/src/Futhark/AD/Fwd.hs +++ b/src/Futhark/AD/Fwd.hs @@ -18,7 +18,8 @@ import Futhark.Analysis.PrimExp.Convert import Futhark.Builder import Futhark.IR.SOACS import Futhark.Tools -import Futhark.Util (interleave, splitAt3, unterleave) +import Futhark.Transform.Rename +import Futhark.Util (chunks, interleave, splitAt3, unterleave) zeroExp :: Type -> Exp SOACS zeroExp (Prim pt) = @@ -500,18 +501,59 @@ fwdSOAC pat aux (Stream size xs accs lam) = do let accs' = interleave accs accs_tan addStm $ Let pat' aux $ Op $ Stream size xs' accs' lam' fwdSOAC pat aux (Hist w arrs ops bucket_fun) = do - -- TODO: this is probably not very efficient in the vectorised case as we end - -- up with a dreadful update operator that involves arrays. - (pat', to_transpose) <- soacResPat 0 0 pat - ops' <- mapM fwdHist ops - bucket_fun' <- fwdHistBucket bucket_fun - arrs' <- soacInputsWithTangents arrs - addStm $ Let pat' aux $ Op $ Hist w arrs' ops' bucket_fun' tan_shape <- askShape - forM_ to_transpose $ \(rpat, v) -> do - v_t <- lookupType v - let perm = rearrangeInverse $ vecPerm tan_shape v_t - letBind rpat $ BasicOp $ Rearrange v perm + -- See Note [Vector tangent of Hist] + if tan_shape /= mempty + then do + tmp_pat <- + fmap Pat $ forM (lambdaReturnType bucket_fun) $ \t -> + PatElem <$> newVName "hist_tmp" <*> pure (t `arrayOfRow` w) + + fwdSOAC tmp_pat aux . Screma w arrs =<< mapSOAC bucket_fun + + let num_is = sum $ map (shapeRank . histShape) ops + (tmp_pat_is, tmp_pat_vals) = splitAt num_is (patNames tmp_pat) + tmp_pat_tans <- mapM tangent tmp_pat_vals + + let dests = concatMap histDest ops + dests_tans <- mapM tangent dests + dests_copies <- forM dests $ \arr -> + letExp (baseName arr <> "_copy") =<< eCopy (eVar arr) + + letBind pat . Op . Hist w (patNames tmp_pat) ops + =<< mkIdentityLambda (lambdaReturnType bucket_fun) + + prims_and_tans <- local (\env -> env {envTanShape = mempty}) + $ letTupExp "hist_prims_and_tans" + <=< mapNest tan_shape (Pair (map Var tmp_pat_tans) (map Var dests_tans)) + $ \(Pair tmp_pat_tans' dests_tans') -> do + tmp_pat_tans'' <- mapM asVName tmp_pat_tans' + dests_tans'' <- mapM asVName dests_tans' + let ops_dests_tans = chunks (map (length . histDest) ops) dests_tans'' + ops_dests_copies = chunks (map (length . histDest) ops) dests_copies + ops' <- forM (zip3 ops ops_dests_tans ops_dests_copies) $ + \(op, op_dests_tans, op_dests_copies) -> do + op_lam <- renameLambda $ histOp op + zipWithM_ insertTan (histDest op) op_dests_tans + dest_copies <- forM op_dests_copies $ \arr -> + letExp (baseName arr <> "_copy") =<< eCopy (eVar arr) + fwdHist $ op {histOp = op_lam, histDest = dest_copies} + bucket_fun' <- mkIdentityLambda $ Prim int64 : concatMap (lambdaReturnType . histOp) ops' + let hist_arrs = tmp_pat_is ++ interleave tmp_pat_vals tmp_pat_tans'' + pure . Op $ Hist w hist_arrs ops' bucket_fun' + + let (_prims, tans) = unterleave prims_and_tans + zipWithM_ insertTan (patNames pat) tans + else do + (pat', to_transpose) <- soacResPat 0 0 pat + ops' <- mapM fwdHist ops + bucket_fun' <- fwdHistBucket bucket_fun + arrs' <- soacInputsWithTangents arrs + addStm $ Let pat' aux $ Op $ Hist w arrs' ops' bucket_fun' + forM_ to_transpose $ \(rpat, v) -> do + v_t <- lookupType v + let perm = rearrangeInverse $ vecPerm tan_shape v_t + letBind rpat $ BasicOp $ Rearrange v perm where n_indices = sum $ map (shapeRank . histShape) ops fwdBodyHist (Body _ stms res) = buildBody_ $ do @@ -525,7 +567,8 @@ fwdSOAC pat aux (Hist w arrs ops bucket_fun) = do fwdHist :: HistOp SOACS -> ADM (HistOp SOACS) fwdHist (HistOp shape rf dest nes op) = do dest' <- soacInputsWithTangents dest - nes_tan <- mapM (letSubExp "zero" . zeroExp <=< tanType) $ lambdaReturnType op + nes_tan <- + mapM (letSubExp "zero" . zeroExp <=< tanType) $ lambdaReturnType op op' <- fwdLambda op pure $ HistOp @@ -653,3 +696,17 @@ fwdJVP scope shape attrs (Lambda params _ body) = -- case). This requires some care for SOACs, which always map across the -- outermost dimension: basically we have to transpose the inputs and the -- outputs. + +-- Note [Vector tangent of Hist] +-- +-- Naive vectorised tangents for Hist results in an operator that mixes arrays +-- and scalar code. Our code generation for this (particularly for GPU backends) +-- is terrible (partly by necessity; it's just a bad pattern). Hence, we handle +-- it by essentially locally non-vectorising. The idea is to first split apart +-- the map function, in case it does something interesting. Then we compute the +-- primal result using a normal histogram (necessary in case the tangent shape +-- is mempty), and then we map over the tangent vector of each destination and +-- input, computing scalar tangents. This requires us to copy the (primal) +-- destination, because it is repeatedly consumed - in principle this might +-- wrecks the cost model, however, vectorised AD already implicitly copies the +-- tangent, so I think we get away with it. diff --git a/src/Futhark/AD/Shared.hs b/src/Futhark/AD/Shared.hs index 35c69ce24e..2347e8a131 100644 --- a/src/Futhark/AD/Shared.hs +++ b/src/Futhark/AD/Shared.hs @@ -36,7 +36,7 @@ mapNest shape x f = do x_v <- traverse asVName x x_p <- traverse (newParam "xp" . rowType <=< lookupType) x_v lam <- mkLambda (toList x_p) $ do - fmap (subExpsRes . pure) . letSubExp "mapnest_res" + fmap (subExpsRes . map Var) . letTupExp "mapnest_res" =<< f (fmap (Var . paramName) x_p) Op . Screma w (toList x_v) <$> mapSOAC lam From 02b209e5561b8bef0c53350854f2ed1a0a8cf485 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 3 Jun 2026 23:25:04 +0200 Subject: [PATCH 47/70] Revert "Specialised handling of Hist." This reverts commit e035fca9c3074f299c0ddee2d34d4b85e85e4f99. --- src/Futhark/AD/Fwd.hs | 83 +++++++--------------------------------- src/Futhark/AD/Shared.hs | 2 +- 2 files changed, 14 insertions(+), 71 deletions(-) diff --git a/src/Futhark/AD/Fwd.hs b/src/Futhark/AD/Fwd.hs index 92ad80bb83..39904b0e29 100644 --- a/src/Futhark/AD/Fwd.hs +++ b/src/Futhark/AD/Fwd.hs @@ -18,8 +18,7 @@ import Futhark.Analysis.PrimExp.Convert import Futhark.Builder import Futhark.IR.SOACS import Futhark.Tools -import Futhark.Transform.Rename -import Futhark.Util (chunks, interleave, splitAt3, unterleave) +import Futhark.Util (interleave, splitAt3, unterleave) zeroExp :: Type -> Exp SOACS zeroExp (Prim pt) = @@ -501,59 +500,18 @@ fwdSOAC pat aux (Stream size xs accs lam) = do let accs' = interleave accs accs_tan addStm $ Let pat' aux $ Op $ Stream size xs' accs' lam' fwdSOAC pat aux (Hist w arrs ops bucket_fun) = do + -- TODO: this is probably not very efficient in the vectorised case as we end + -- up with a dreadful update operator that involves arrays. + (pat', to_transpose) <- soacResPat 0 0 pat + ops' <- mapM fwdHist ops + bucket_fun' <- fwdHistBucket bucket_fun + arrs' <- soacInputsWithTangents arrs + addStm $ Let pat' aux $ Op $ Hist w arrs' ops' bucket_fun' tan_shape <- askShape - -- See Note [Vector tangent of Hist] - if tan_shape /= mempty - then do - tmp_pat <- - fmap Pat $ forM (lambdaReturnType bucket_fun) $ \t -> - PatElem <$> newVName "hist_tmp" <*> pure (t `arrayOfRow` w) - - fwdSOAC tmp_pat aux . Screma w arrs =<< mapSOAC bucket_fun - - let num_is = sum $ map (shapeRank . histShape) ops - (tmp_pat_is, tmp_pat_vals) = splitAt num_is (patNames tmp_pat) - tmp_pat_tans <- mapM tangent tmp_pat_vals - - let dests = concatMap histDest ops - dests_tans <- mapM tangent dests - dests_copies <- forM dests $ \arr -> - letExp (baseName arr <> "_copy") =<< eCopy (eVar arr) - - letBind pat . Op . Hist w (patNames tmp_pat) ops - =<< mkIdentityLambda (lambdaReturnType bucket_fun) - - prims_and_tans <- local (\env -> env {envTanShape = mempty}) - $ letTupExp "hist_prims_and_tans" - <=< mapNest tan_shape (Pair (map Var tmp_pat_tans) (map Var dests_tans)) - $ \(Pair tmp_pat_tans' dests_tans') -> do - tmp_pat_tans'' <- mapM asVName tmp_pat_tans' - dests_tans'' <- mapM asVName dests_tans' - let ops_dests_tans = chunks (map (length . histDest) ops) dests_tans'' - ops_dests_copies = chunks (map (length . histDest) ops) dests_copies - ops' <- forM (zip3 ops ops_dests_tans ops_dests_copies) $ - \(op, op_dests_tans, op_dests_copies) -> do - op_lam <- renameLambda $ histOp op - zipWithM_ insertTan (histDest op) op_dests_tans - dest_copies <- forM op_dests_copies $ \arr -> - letExp (baseName arr <> "_copy") =<< eCopy (eVar arr) - fwdHist $ op {histOp = op_lam, histDest = dest_copies} - bucket_fun' <- mkIdentityLambda $ Prim int64 : concatMap (lambdaReturnType . histOp) ops' - let hist_arrs = tmp_pat_is ++ interleave tmp_pat_vals tmp_pat_tans'' - pure . Op $ Hist w hist_arrs ops' bucket_fun' - - let (_prims, tans) = unterleave prims_and_tans - zipWithM_ insertTan (patNames pat) tans - else do - (pat', to_transpose) <- soacResPat 0 0 pat - ops' <- mapM fwdHist ops - bucket_fun' <- fwdHistBucket bucket_fun - arrs' <- soacInputsWithTangents arrs - addStm $ Let pat' aux $ Op $ Hist w arrs' ops' bucket_fun' - forM_ to_transpose $ \(rpat, v) -> do - v_t <- lookupType v - let perm = rearrangeInverse $ vecPerm tan_shape v_t - letBind rpat $ BasicOp $ Rearrange v perm + forM_ to_transpose $ \(rpat, v) -> do + v_t <- lookupType v + let perm = rearrangeInverse $ vecPerm tan_shape v_t + letBind rpat $ BasicOp $ Rearrange v perm where n_indices = sum $ map (shapeRank . histShape) ops fwdBodyHist (Body _ stms res) = buildBody_ $ do @@ -567,8 +525,7 @@ fwdSOAC pat aux (Hist w arrs ops bucket_fun) = do fwdHist :: HistOp SOACS -> ADM (HistOp SOACS) fwdHist (HistOp shape rf dest nes op) = do dest' <- soacInputsWithTangents dest - nes_tan <- - mapM (letSubExp "zero" . zeroExp <=< tanType) $ lambdaReturnType op + nes_tan <- mapM (letSubExp "zero" . zeroExp <=< tanType) $ lambdaReturnType op op' <- fwdLambda op pure $ HistOp @@ -696,17 +653,3 @@ fwdJVP scope shape attrs (Lambda params _ body) = -- case). This requires some care for SOACs, which always map across the -- outermost dimension: basically we have to transpose the inputs and the -- outputs. - --- Note [Vector tangent of Hist] --- --- Naive vectorised tangents for Hist results in an operator that mixes arrays --- and scalar code. Our code generation for this (particularly for GPU backends) --- is terrible (partly by necessity; it's just a bad pattern). Hence, we handle --- it by essentially locally non-vectorising. The idea is to first split apart --- the map function, in case it does something interesting. Then we compute the --- primal result using a normal histogram (necessary in case the tangent shape --- is mempty), and then we map over the tangent vector of each destination and --- input, computing scalar tangents. This requires us to copy the (primal) --- destination, because it is repeatedly consumed - in principle this might --- wrecks the cost model, however, vectorised AD already implicitly copies the --- tangent, so I think we get away with it. diff --git a/src/Futhark/AD/Shared.hs b/src/Futhark/AD/Shared.hs index 2347e8a131..35c69ce24e 100644 --- a/src/Futhark/AD/Shared.hs +++ b/src/Futhark/AD/Shared.hs @@ -36,7 +36,7 @@ mapNest shape x f = do x_v <- traverse asVName x x_p <- traverse (newParam "xp" . rowType <=< lookupType) x_v lam <- mkLambda (toList x_p) $ do - fmap (subExpsRes . map Var) . letTupExp "mapnest_res" + fmap (subExpsRes . pure) . letSubExp "mapnest_res" =<< f (fmap (Var . paramName) x_p) Op . Screma w (toList x_v) <$> mapSOAC lam From e33ace413a70d66f72a2cf96f8799deb0818154a Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 5 Jun 2026 11:10:08 +0200 Subject: [PATCH 48/70] Refreshen documentation. --- prelude/ad.fut | 48 ++++++++++++++++++++++++++++-------------------- 1 file changed, 28 insertions(+), 20 deletions(-) diff --git a/prelude/ad.fut b/prelude/ad.fut index c31e2a1999..58c4ed6d6d 100644 --- a/prelude/ad.fut +++ b/prelude/ad.fut @@ -14,8 +14,8 @@ -- -- Futhark's AD support includes the following: -- --- * Differentiation operators for forward-mode (`jvp`) and reverse-mode --- (`vjp`). +-- * Differentiation operators for forward-mode (`jvp`@term) and reverse-mode +-- (`vjp`@term). -- -- * Arbitrary control flow in differentiable code. -- @@ -23,7 +23,9 @@ -- arbitrary mixing of forward- and reverse mode (although using multiple -- rounds of reverse mode is rarely useful and often slow). -- --- * Custom derivatives (`with_vjp`). +-- * Custom derivatives (`with_vjp`@term). +-- +-- * Vectorised AD (`vjp_vec`@term, `vjp_vec`@term). -- -- * Checkpointing of sequential loops. -- @@ -62,8 +64,7 @@ -- given situation depends on whether the function has more inputs or -- outputs. -- --- You can freely nest `vjp` and `jvp` to compute higher-order --- derivatives. +-- We can freely nest `vjp` and `jvp` to compute higher-order derivatives. -- -- ## Efficiency -- @@ -91,27 +92,34 @@ -- but it can still be substantial for programs with deep sequential -- loops. -- +-- It varies on a case-by-case basis whether vectorised AD is faster or not. It +-- essentially converts propagation of (co-)tangents from scalar to array +-- operations, which can have a significant impact on memory accesses, depending +-- on how the compiler manages to optimise the resulting code. It is hard to +-- predict whether this offsets the reduction in primal work. If the vector size +-- is a constant, and the `#[unroll]` attribute is put on the AD operator, then +-- the vectors become unrolled (turned into tuples, essentially), although this +-- should only be done when the vector size is quite small, as the increase in +-- code size is substantial. +-- -- ## Differentiable functions -- --- AD only gives meaningful results for differentiable functions. The --- Futhark type system does not distinguish differentiable or --- non-differentiable operations. As a rule of thumb, a function is --- differentiable if its results are computed using a composition of --- primitive floating-point operations, without ever converting to or --- from integers. +-- AD only gives meaningful results for differentiable functions. The Futhark +-- type system does not distinguish differentiable from non-differentiable +-- operations. As a rule of thumb, a function is differentiable if its results +-- are computed using a composition of primitive floating-point operations, +-- without ever converting to or from integers. -- --- Note that a function whose input or output is a sum type with more --- than one constructor is *not* differentiable (or at least the --- sum-typed part is not). This is because the choice of constructor --- is not a continuous quantity. +-- Note that a function whose input or output is a sum type with more than one +-- constructor is *not* differentiable (or at least the sum-typed part is not). +-- This is because the choice of constructor is not a continuous quantity. -- -- ## Limitations -- --- `jvp` is expected to work in all cases. `vjp` has limitations when --- using the GPU backends similar to those for irregular flattening. --- Specifically, you should avoid structures with variant sizes, such --- as loops that carry an array that changes size through the --- execution of the loop. +-- `jvp` is expected to work in all cases. `vjp` has limitations when using the +-- GPU backends similar to those for irregular flattening. Specifically, you +-- should avoid structures with variant sizes, such as loops that carry an array +-- that changes size through the execution of the loop. -- | Jacobian-Vector Product ("forward mode"), producing also the -- primal result as the first element of the result tuple. From d1f4d318cc4599f02b956f93881c98dc94c087b6 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 10 Jun 2026 10:39:35 +0200 Subject: [PATCH 49/70] More robust equality checking. --- tests/ad/reducebyindexminmax7.fut | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/ad/reducebyindexminmax7.fut b/tests/ad/reducebyindexminmax7.fut index 28495825af..51eb3a3752 100644 --- a/tests/ad/reducebyindexminmax7.fut +++ b/tests/ad/reducebyindexminmax7.fut @@ -19,7 +19,13 @@ def rev_vec [n] [m] [k] (is: [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = def fwd_vec [n] [m] [k] (is: [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = let seeds = tabulate n (\i -> replicate n (replicate k 0) with [i] = replicate k 1) in jvp_vec (primal is dst) vs seeds - |> transpose + |> transpose + +def approx_eql (rel_tol: f32) (a: f32) (b: f32) : bool = + let diff = f32.abs (a - b) + let scale = f32.max (f32.abs a) (f32.abs b) + let abs_tol = 100.0 * f32.epsilon * scale + in diff <= f32.max abs_tol (rel_tol * scale) def main [n] [m] [k] (is': [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = let is = map (\i -> (i64.abs i) %% m) is' @@ -27,7 +33,7 @@ def main [n] [m] [k] (is': [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = let f = fwd is dst vs let rv = rev_vec is dst vs let fv = fwd_vec is dst vs - let eq_rf = map2 (map2 (==)) r f |> map (reduce (&&) true) |> reduce (&&) true - let eq_rrv = map2 (map2 (==)) r rv |> map (reduce (&&) true) |> reduce (&&) true - let eq_ffv = map2 (map2 (==)) f fv |> map (reduce (&&) true) |> reduce (&&) true + let eq_rf = and (map2 (approx_eql 1e-9) (flatten_3d r) (flatten_3d f)) + let eq_rrv = and (map2 (approx_eql 1e-9) (flatten_3d r) (flatten_3d rv)) + let eq_ffv = and (map2 (approx_eql 1e-9) (flatten_3d f) (flatten_3d fv)) in eq_rf && eq_rrv && eq_ffv From eeff1c785fe51c2c967991dc53c097114df6d968 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 10 Jun 2026 11:07:58 +0200 Subject: [PATCH 50/70] Individual tests. --- tests/ad/reducebyindexminmax7.fut | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/ad/reducebyindexminmax7.fut b/tests/ad/reducebyindexminmax7.fut index 51eb3a3752..9cee80a45e 100644 --- a/tests/ad/reducebyindexminmax7.fut +++ b/tests/ad/reducebyindexminmax7.fut @@ -1,22 +1,27 @@ -- == -- tags { autodiff } +-- entry: main -- compiled random input { [500]i64 [100][30]f32 [500][30]f32 } output { true } +-- == +-- entry: rev fwd rev_vec fwd_vec +-- compiled random input { [500]i64 [100][30]f32 [500][30]f32 } auto output + def primal [n] [m] [k] (is: [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = reduce_by_index (copy dst) (map2 f32.max) (replicate k f32.lowest) is vs -def rev [n] [m] [k] (is: [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = +entry rev [n] [m] [k] (is: [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = tabulate m (\i -> vjp (primal is dst) vs (replicate m (replicate k 0) with [i] = replicate k 1)) -def fwd [n] [m] [k] (is: [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = +entry fwd [n] [m] [k] (is: [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = tabulate n (\i -> jvp (primal is dst) vs (replicate n (replicate k 0) with [i] = replicate k 1)) |> transpose -def rev_vec [n] [m] [k] (is: [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = +entry rev_vec [n] [m] [k] (is: [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = let seeds = tabulate m (\i -> replicate m (replicate k 0) with [i] = replicate k 1) in vjp_vec (primal is dst) vs seeds -def fwd_vec [n] [m] [k] (is: [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = +entry fwd_vec [n] [m] [k] (is: [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = let seeds = tabulate n (\i -> replicate n (replicate k 0) with [i] = replicate k 1) in jvp_vec (primal is dst) vs seeds |> transpose From 055da8415aa0cb416a6a89113521094bb7d578ae Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 10 Jun 2026 11:13:59 +0200 Subject: [PATCH 51/70] Lower tolerance. --- tests/ad/reducebyindexminmax7.fut | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/ad/reducebyindexminmax7.fut b/tests/ad/reducebyindexminmax7.fut index 9cee80a45e..6cd57485af 100644 --- a/tests/ad/reducebyindexminmax7.fut +++ b/tests/ad/reducebyindexminmax7.fut @@ -38,7 +38,7 @@ def main [n] [m] [k] (is': [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = let f = fwd is dst vs let rv = rev_vec is dst vs let fv = fwd_vec is dst vs - let eq_rf = and (map2 (approx_eql 1e-9) (flatten_3d r) (flatten_3d f)) - let eq_rrv = and (map2 (approx_eql 1e-9) (flatten_3d r) (flatten_3d rv)) - let eq_ffv = and (map2 (approx_eql 1e-9) (flatten_3d f) (flatten_3d fv)) + let eq_rf = and (map2 (approx_eql 1e-3) (flatten_3d r) (flatten_3d f)) + let eq_rrv = and (map2 (approx_eql 1e-3) (flatten_3d r) (flatten_3d rv)) + let eq_ffv = and (map2 (approx_eql 1e-3) (flatten_3d f) (flatten_3d fv)) in eq_rf && eq_rrv && eq_ffv From 4a94616a7e23a2e8957b6e2f50261c63fe69a238 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Wed, 10 Jun 2026 13:38:01 +0200 Subject: [PATCH 52/70] Rewrite this test to be less crazy. --- tests/ad/reducebyindexminmax7.fut | 23 ++++------------------- tests/ad/reducebyindexminmax7.in | Bin 0 -> 76061 bytes tests/ad/reducebyindexminmax7.out.gz | Bin 0 -> 10414 bytes 3 files changed, 4 insertions(+), 19 deletions(-) create mode 100644 tests/ad/reducebyindexminmax7.in create mode 100644 tests/ad/reducebyindexminmax7.out.gz diff --git a/tests/ad/reducebyindexminmax7.fut b/tests/ad/reducebyindexminmax7.fut index 6cd57485af..574ba3297d 100644 --- a/tests/ad/reducebyindexminmax7.fut +++ b/tests/ad/reducebyindexminmax7.fut @@ -1,14 +1,10 @@ --- == --- tags { autodiff } --- entry: main --- compiled random input { [500]i64 [100][30]f32 [500][30]f32 } output { true } - -- == -- entry: rev fwd rev_vec fwd_vec --- compiled random input { [500]i64 [100][30]f32 [500][30]f32 } auto output +-- compiled input @ reducebyindexminmax7.in output @ reducebyindexminmax7.out.gz -def primal [n] [m] [k] (is: [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = - reduce_by_index (copy dst) (map2 f32.max) (replicate k f32.lowest) is vs +def primal [n] [m] [k] (is': [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = + let is = map (\i -> (i64.abs i) %% m) is' + in reduce_by_index (copy dst) (map2 f32.max) (replicate k f32.lowest) is vs entry rev [n] [m] [k] (is: [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = tabulate m (\i -> vjp (primal is dst) vs (replicate m (replicate k 0) with [i] = replicate k 1)) @@ -31,14 +27,3 @@ def approx_eql (rel_tol: f32) (a: f32) (b: f32) : bool = let scale = f32.max (f32.abs a) (f32.abs b) let abs_tol = 100.0 * f32.epsilon * scale in diff <= f32.max abs_tol (rel_tol * scale) - -def main [n] [m] [k] (is': [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = - let is = map (\i -> (i64.abs i) %% m) is' - let r = rev is dst vs - let f = fwd is dst vs - let rv = rev_vec is dst vs - let fv = fwd_vec is dst vs - let eq_rf = and (map2 (approx_eql 1e-3) (flatten_3d r) (flatten_3d f)) - let eq_rrv = and (map2 (approx_eql 1e-3) (flatten_3d r) (flatten_3d rv)) - let eq_ffv = and (map2 (approx_eql 1e-3) (flatten_3d f) (flatten_3d fv)) - in eq_rf && eq_rrv && eq_ffv diff --git a/tests/ad/reducebyindexminmax7.in b/tests/ad/reducebyindexminmax7.in new file mode 100644 index 0000000000000000000000000000000000000000..b5eac437a545a3f236479d4af144765e4fa3fa4d GIT binary patch literal 76061 zcmeF3_dk_?{Qu2tAt7XE?>)}-eA?OBqpa+`va=Fpg?2(IBN-JTDU_6wdP_*dD48J% zDXQ=Jj?e$_{h{;Ab-SJ0Ik$7&uIs#>ujk|OxL;2%DoUY1b=6f$ivROkyIUQ#R66Nd zEZn;zapCC7r+e=?E~bcXhDf;*hkOnsiQYQDTK#kR^BZ2qw{4nccj>A>OG`GreP8=c zCfi%5*k&tb_2iFJOKIB5-@32%rwj0o=I%cs{Fe7#DQ)rYkfAE}affrGW2(n`&0;w$ zhfj9|eLQwsUB~K$WPG<=!Op4BQ1cV3KNP1=e7M3bcwTu^er-;1cwxND_NZo7f|0Q6 ztv_#mOn>IR`M8scPAP3we_zReIZv`5TgE0gJ?7eWwrb$r?%B?o6L~kibt^4yv2qMs z4K>FNZ5G-nZab?j!#YKnf3>$K30+p|&n+F=)A(TBV~+S-Zot zzh60YXVUCH!#FuP%Ww4=E8=lif*MR2TwIO43_pC}Y!FlxOnd0UmwJ-4rV^{ZoN!Gr zB-KSFeL$Tt`lF?BkK->n<$b=%M6wSRw&ZUX{ywkZ{8tUn#=F+7=rsH+oBF)e${8a?_i9!97LoMw?3MV?xgb7*~(Ftblb`>uJ3T?Aj>qCu{WIeZ@Rp3b5XUWz*kxOBD2}}E4R0l zZ{3!rI`=%N`X!{#>cs-w)F zF={)%02ij5Qo3EjOrKre9j?=OLwlF`o`pn2r>a<_=>)+VRaeeJ6(wJE^2Ja;^Hds# z{i5g4Ig9A;tv2C`d>wPj3Qn2cSvB=`UyYP<%&OrbpKL(o$8%rwIzDW=>OF|c8){x7 z{rF|raY<44esST4&XS^j2JP$4m!I1@hkX5|?%d0rzH&x_hP0#P1$(afOV^^Up!?q$ z2E#pOPWK3ZikYD^8Tzhac1?7XTfZog!&02{@slS~P2PHzOB=r~KIr4+JUtt?Q}Tpt zy0iD40rxlJ9+xVsR|bV$ID{V>AN|X!o7&7D@L!)$C+k=GOD_y_Zgs>&|B0UK8gZ)e zCWxe*jo(LvzKmRuwGd>FWgFX~v#`8HDe-qYD$w`W%3l#K)%d?(Q_cAp7l%uHR+tME zvRZCYT3KscxSQ{JHB#T?c+p(j${0UIN8abk`3cV_?glT#CC2*ey}30S2BH( z*LCBX^coa(TTI%l2J$>?3B4hsEK07a-$g_Ll$jWZYhiM^I zoC(D(k!r7LExG6En-cW45 zyLfj0?+YKKk+KTUCadONG_5(t@t$|8x!4j+b5EVxfM;PHXyZpFz#su^rk$ zM)q;PY#UkS&-D;DdM%1J8in-V1=72)Uin<#We^>qlT(pkIydQ>TQRBsV0T(e{tvzD zuj@~$-8;=_!5Ddmdk39)xxq6hCXQj!RKW9Ox5(3H9I3^W?k-H)CY`?&*g2$4@8f+r z)^9LzkF>k+r)xisX`20>)3M^+cwjo!7LuF#EbTIjNV-(E%M+!h-Dg+*cjc}{Y2-y5 zmwoiE!h3aWDsA3jp)#X5Lm~BA-@5V`-Nc~A?XSK%+aE4BJpc1{T2OXmDuDKGoBOOs zR=~~D@}HD>a}gd@PiD6tF5&ublh4+c7>klW{PxkNTG&i) zur7Aqe50J(&35qp^i7_P?qC|lVjG|Fa*aNNuY*y0x_W$$-QU-pf7L`Zd(VwO*rV5N zH-F-$ZqISGGiJ+O7W8>Huit+oO{?>6HuQboFpbf`0p`DMY0TTqVjZ}*7yL>tH(Ir9 z<#j!9G{{Hf+pXULYP;-;UJg+e_pZ&+b>x^*epnl6GHiDmQ~Xm}dgZ8g2|IJ~>z&`L z*BbUe+*4XL_@CyP_xB#I{qk0;7r6gGokD|#@EF0ifY(eWucg4vL zPv5@ZyfuDjFWhFFFYBvx7eCK5_ocl&ux%zqJ9CPaj=gBhzkZgDzgTDHxoo*M2=-^!s~tZrJgJb?Q20inp4Dw6TsBg{Dl+3QLP|OZGKw zUd3;gLmXYzi*x*X71ZNDEA*$LXg` zYG0dvb*>%{u!?LumHXHFrHY95p|}I9Or~7#V&+1XNJr~NOpLRSM!Jly7N3qC<=48L zb=G9Yh9Qif;_vhA;N7m)h2S<#ud)bh#LOs3_(n6Vrbx2oW5!}zh~=9xiL`Y(-F z>*P6CyiR*RXiL7YV}HU@yt8X}>if%buULe+ulNg&blV8O2g3oe8aNbB)uyFD<7QqxiSpWK}KBx#9N@^E>X&Po3|6d3;y-0sD~Wo4ZA4CUeUF zV}9dtR3%s6h_C<40&VrFh<5h;lDN2+`BSZQ;=;Z<#=AR;4!G2OD89wSMEA(4Y5XbU z;I%`nLU#Mpy0~aV8+kTYKVR9Kva7vQMy^=!*tyWmms(G}N?LtHLISm!#B5LNQnV+X zk_lsGdTj0CYuMy!`EY_Qt4Cgz+DM#HxPSZe8yDLimLD~?9Z#=F{dk#WUB>2&FZ()o z3_4~#Df03>E?U>~)cjHKD@|LgE8k2^Te*_jzBiI;69kG!r1XX(nS3;0!%1;AC8oZON?89_Z8UxEN3)dLMyW~lOLO07 z+rrNMRon`7=G%I4JX}HWSANalO;xcT*5FawE@6U8{?E{Cp7eynIrljJxC{A0ccUc6 z1bMsST+M_xHu~3yvoStLcw|@9ZIWAD=PneKsOyQIrO{~IUG>;$bu}}m_dzpH*fOi= z!2{L77DG?P%FBObdS6UQ`cdDUXBTx^evH%V!6xnX*)HC|K+5%S*GAXd3PSlDek}c) zOnI-FW=E3*hAdBCi)GuP5;%C`LgT4@+Pl+pZ4@0JY%3VE^6Z3+jXOikNKOS7ZsI|pNg{2 z|5fW00{`!5?8Jg&Y9VMS7K+2=`sk@J2kM-S;LPWMe4QM`MtK3@^C=rdqxayF>qq(+JP8XmI@E{tDv+Q zJ^l(ZLbh5N+VG3H-#qp4xAyj>)z~1fK(4VFf+BBt* zP3bJ0f5nF^fx*O%B5VAaQV#mhN%+}K2jjiI!Zp| z6I?j|oDMnP)kBBSemGhbOl((q4N1G6LC3KY_<8Li*wVa*YQJrGfX4tLMtm_$o)$07 zup)^n9ehk#v0Bs~*zUO@ci;kyI35G}-&F82^d@l{FTjxPPGD2zMVcQw@U)XGNOYP2 zCGRDIQ?(sl-E@QX;!&VZegNMZUlVse+T*U6br>lB-=7x-#e=R`q?SSiEU$rkjV_vL z?ZVK|KcLVjgsR!^K|taVjIkTy5!E$t@nS~lIc2o=i6iEJFkxh|89vh21K!X^F!eZ$ z7i%rip)?S0<(`4YhjduoAcF8^8{UhK!pIwbmH)Y_GxE?HFu!CE+7pZVbMjgy%oqBQ}cp@!gOx zMxOZv>8E=k)iVlr^;km;BMY(w?|`A&5%^7RQazT-!0zI}q;%{)+=Q_fOJbEm1Z!s_@Ri0vcxLzss+3Ei zS}-0Y%*pE_b{qK01|gID9C0V^C=l^S;EIGF#!+4&7CLnChBg;+vQ&V)coWqB@Wtsf zta$C2I-E5Z1ljOacq+03r^iyU?d&$ptBiyYj3z0*)CC8g2AFVR0;x(DFeEY|Ixibk zseDnUni{wV-oWU-YowniSAeS46rJ9Uf#*~Lk;kKr-y->N*hm|@9ms2=3vu=CS@>}+ z3I7_s2F5>>`1naGj_|s`r{WrjreO!Z6N#wry9-^_w;}TvUO2`*L3}J+fX!wTObv*~ zjABz*m!n07ib=3tJV8_rMI-OcAJA)+gcr;7QR<^Brm5MWPjxo&xRny~bh1EvMFtgR zxgjQK3@$XVqH>=buC@f?H)ds+Y;1#>r#}hyIva4;bA_&Cd)U6~4j3*<6V^xW!b?pq zOi9$m*}za-F!RR9q!y{K%_F_!DGj^poQ3m zy<@Bxb}kU-Ke%Da4GOH6rot?t2()D2!s6xy5UPkqv6XB>WxqH+rT+nJXX!EHJwRWM zGUQVRgL+XC49G{Jq>en|M<2WhmO$cZfh;>p`EF?iz^52UQCCvLN%62rxdO#j6|21p66X zq^Who&pNVLbc+Z6?7jv%X3iwdWF6!&&L+)B1)%RYQ^+dq1e#m(q>@%9f^w`5wsS<_ z@vV5wi+m0HOpD;jYsAkz$KltI8rbK?p|9N&=#RNgXlxXNfWvh-DvaK?&`#46AZ|Mk;C-rsdtclZI1=o0bs8v$f1(SQdhwD92eGJ;b;8*1`C6DDO= zC`#jubktUm=irTex;7ZvVhXGcs`zAH3Z5zm5w5-!5D@By>(+JfIa>;ojs8NHf+1nz zW(EnzbIALL6a+qdM05u`;Y0o7V0|YTwKU(rNCV=;g;L=A-A<5H^C94(A>MhRiy;At zSij(l3EoqLuXhg6r;_i@c@~U4M2{i~hp|%fGFU5a$Js?Gw7t_uL~R^GCE^L(aVdd( zk0g{_vB1HNBuL+BiBzUXAzoMl^$%Mh^)>RIv)2m6FT~^CKb^qrDhE_k^~8xGP8@Ql z!1eQoP$8X{=wP3MSp!!*Sy2ZJ`F!{%>oY{0*?~m#Bj}92Ky+yyLcXlK&^cp>A*#NZ zts($%C$dNiis*#%}WXtB^rQDK^v)f0YUkk?V``V%5Y#I28spHpNeWX}d#U>_h z^t@V2y!!2jLiLy6{#Q-hL+y_P&Pj0mI2X(mJK(EqOYFTEh702<$eb+?BE5i&^Ze+} zpG8cV%#tE#S#Vu331x3v5<<6lP~8`Cf2TLTdZr5MXQi{|imclQ~=LADSI#Er*f%Y2LK-t+AH5nz)h9wL7 znik=~10%FJISe1>g^Y;+ctSMk+e-GRe>EQZE9iE4^6Lh^N zVen=&-h9c8n|Teu7|R7Fw?eV%#3daTWm{bLu!o1rI;7WkH2`y)VAl#4Fg_5#j<5Ro z>4X~gjH|!``Fe+4ib1FEW~dqvf*;aC@%*`KaKNyhFd^>?63gEpo`u}=4XL4Sp(Lgh zpNBm5Ymj;UF(eUjc$AL@Po1K{i^byjj>Z!G1)(K_zd?BK~#N(@=1f*JAf+-pMmrX1Fu`$ojDT4AVeBP?GE$NO`;k*V@5$g(kk4vint(0>I!yq89n zsw4=IltRTmKm0AT2ub`Gflt{NeGal9Mp)v{GE0(&cOrIwPQu+>QP6c+8v;8ogJG;Y zivJgZSJYE6gj7ga(5vAQtOR?ptI&H{7&KY;qW^dx%HQXK>RJgD-*6zA&pG3V$XJr% z^?q27(Z!fZ0TiJ*04+l77{BI+YE5&5ZU1?wQ>_Hg!Uibwp~vz22jMnN0;*p2hfxg! z4B{oP(^7ss)l7#~J;lTWzn8$%>I-6bIpK7d7pe@#;f5_GYUI=4v&AuDeAXK0{dsZs zgcLq^>4Rs_hDm-L1m5t8L#gBlEGtdIvty@#by5eLS1-f10v6n@Fij}&P~j4VG^}$d zp^keje$-RIp8J67O}$X)c^uZh@M7L7Su-dE;)gL&v@8q7r@I2+VL261uJZ!gr=qdi zSslu>Me(KCJ`i3~!Eq`E=%;rALs>=a)DFXlu5HNe`4Ni8JBYfndtjhxiRmZIQRtE{ z&TX_qngkENxy=R&f@MTyqblyQZU>J#Z#;Rg78a*hpy-S--um;D)MKrIbvGQa?fn6) zJH>-r+Y5=o@2r?wVT*KEh9TFI2aU=)puf={X*L5<4GgkyLF+l|N)5+ShQvnb14Z+*k2hg0zij=+2;l^daFPKhz8k+}- zu@uy2T7z-1TIjTQC9b$lK|S%Apt&ZDi>spObJPd#)g+pLhok+pvaVuPSl89v|Jy1M~2Jia2;-HWgmfw-X0KGK0(It!vRwQ`)qnq%M zvBG_Cyl`ct7&wjwp`^41-r6OOcS84p#Or8K(qKZ)yj;lnv2wE^`c2IZUA4TobLQ|i{J~|i9=<$Q#U(>2ICw)Jn|8Y6mzcdk zp`}gMbfK8AB8d@4%ZdIdReVtP04~!t6Lw$4aF|@%V}0_8E=5`lrnbgbHbxA+ErgHH zhTzZ73E0rf0!?#~*xIjxv%VLhN-CE$%k=?TezIfTz+15O*okMN{BZQ-1f<{p4y8x+ z;OEU?6!f!4A6ahXuZAK<=LC<^gb6095=(8+Ke7#32{qvQ@ONtTgv za)s~)Z7c|>`asZWN0hQ?1j|JV>`K2!y0TCW!FanXDLy zw6E-#!EAvxrYGQgjX#=fPe3a=Z@d_lPRyrBVnsg*J@&6crx@9P7!bsK(^zyRJn(}W zD}Fd2hYkKA$jhUGoML&Tv&9Jz5|~A_6!~Bl*%RrSj=?7KKI+|^2!oq1Ku3a^SWsUl zG6WqFzAk|uts!b2Z35?Ab@1oHQ3(E$0^4erLEL{2{;Cs1^__2_N<$waGc%yIN)Ri^ z`hsne0Jh0HP)wbKhre;+)7LFPdbb5#vo6>nzX|tzWN<%~4>Byig9`;5__g>AF|u8X zyl(`+*7Z~32|J1=0K+YNK#6LGS41~M+N;;}Mj+{aWzL~dVm!L)D>`?*nvxfgV5`shauX|=-W04J)g`lM*IW> zx~QVzvKKN3zlICXJ0VDi4G&obV&R_+;@Qmw=He7}+b?mzb<^`#|f8h;2z0s?5i&<)K*7|zi?hmD*tRNmZ+mk&sy zcPRxHc^-h`z7H@HONAu|zQblM`MAXc7kYk@*q^&YemxI*epiS6^8(m7eVCwbq$NsR zKSPLu8y1ngv0l6y?(p`&)n8xXjK2i<@OHvN2Ni~8J7dY43m`V00cP#Vuua(l6DqSn z%tQ^ntp7r#7lHLV_rgnlL%e(64PINA5VD(>pxNvO9BQ8ht&F`;FPR5xyT@Vq>nXx@ zMvlB&|T&MSoYl~R;n3rEJ(5E95t7d(UfiD_j}34Auo`}r0~y)Xcz|1OeLeN%9x zUI5=whv3GWpP*|KiKGxQoPA-1`ks6cB(WDIgr-3Ck}1ajrwDVS^{~xB7tEGS;OHki zxJqcFqMsUGdlQG;CdZ+be+!s+GQpCm^pk7vpHT;H_~7!7KGv)%*U?y};B;JntdkuZo(YZI!{0GGYrfY$2?__W^@ zcYaFL>8~-sThn8p&5#U0+Jc!kozSEcg{jX;(Aw{SFDer-e%S#l%VP0q5Dj)NG!x=F z+fm}1F|qf^9*F-Kk2jyQ!xp^I)C8Z+L2aff!_s zLyNy^C=ery?6wz(Ltgy26=#4}$IgMjuMqyUSA^`V)x>=k4O}_KgSTC>$XewGP;slH zgN`X^cqd?HXg>6`I6(DnX?#_t2PfL~a6s4sTO~#zCnyMRuf)Kfg~#B>VuGs$ci@>+ zBUF4d!&SynLh{-&IMcdegN{4ySZXC!&0KIXaSZt1%VK##5F9=00IPdBG3Rd}tj7(* zgO9rCOt%Z8!vSv^^5T|*G}^ZW!J%8cIQ)VayL-doS;Icmc67!L>O^!k%m-B>1DKOO zke*Do5WflxvF0f)UjKap{Jf?i$AJ=cvuv?i~COy^%DGre}hbIEQWj)z*AO^$Q(`NOQr-^OnGH zdIXy4cf(SyCIXWRHU-9^P{4U$RX+*xrlzo*B8L@QI+*)A3n;!b`IhJzJxM{ zA`rDJBi(smjCs-OU?-LgB5jetaq$q4G@`I)*KT0<-5^Em-2-p7UxCsz6YQVi!AoQ> zI=8fo?B&;k&cZ(Ao#4ck14j7grU8EKG{LIZMp)=eiRboKLm{6NF<`w!=$zNYIWclS z?mPhTmXxS&Oowe{Cb(_79!OnW;6TRS zeG3jdBGJn>8RujrujjsrwKkT^D5Qd@uk$||A@%ZENMo{q$~<$NND;;$_4kB-W;k*k3dgcs9;9T< z1=rw25a~2U>NOHxae4s~Yn#DUjXGeOsinP2(ou@PT&{7^GHlG-~dOfDBMwa4}6*!v34s2Yo2i8I0X;1 z)(Ya&!D515aFsMe>_kUR51sl3c3ggZ5XZ;(Fiem9I~l*B{j4(S55a=Z$lC6%HWQ5X zFu*{F5|&6!f%yw|=;98+d=4)xaXb(CC&h6WZ5c`7U_2UC|M(Na^5S8J$`}XT1krZ( zGhDB5#^5z{%p#{yW%kgzPKh1|0-;(I@H+^@`rRU#^Qg0mfV z9D7f?<8~K}Y<@$=4r>$*NQ0UYBao)_C6d?W;g{z<;0<*jWd>!&(@= z6#^6UrJz&E4R!e;n4;u`x5iRQqwfH&p6MVUQ-;*iEs0I~iFol16;5iYp~VL$ym0>}LoC1$uMzdhbEfxk^piL+=Ci zGk<|%c#EKtxC~TfZy>}x1oane(E1}aPA|EU(%GbNiGdMVd-o9ZEjn2E{Q&an(qfMz z3lSnI2yVeypikq3_dNEYU)~*1-lmH|hm)~II0>t1tf5t=8%FeH;Nk)oaY{iMgLYV8 z#n%dm{`v<9=RRUyITVkmk#+7FSNuz*hTiu!;mvD)%vYScA>yqR%nn;M5??<-1nIl9*TLwhk#`4yVtB!BEJHOQ-Im?E8s3d z2Yq(&nDitH53#7C?hpxi^pjwr%MT^CC{bX94{9c}V5PwU4mt7$*pdmMuOzmni4#R~C=Hx5Iu?}cl@99Tm)21m`-NM3VG zuxI%Y^iCbXs%;5qoBkPMM2bK=H*L^-P;9W_ewJ`JhuIuOAxt zX^VVK>iNJqSP^v%jlnTr756^!MV;GGXr;)EH{}xW7()|rqa^|x-z4J=`%lomRzd`2 zrXnMwIc``j0Ue7M!PRAh#ut2{{7EZ#uv_Cy1i4RMc7|O^(Ws|CE-(W27Al7j(VxpNEo?+*~WVS7^XnhP7V`L5ZwF7>A zDuCtF$Kd>VPt+_61;60KxI37fA#FAQafb_)p18u(=NkmAo-Eih-zR5wcF?fvJ?X$L z5zNlI14rV!h{0kBEV3{MJMy!z;8FLTII|(6W38S6F{h?sL-jeF`DlbS zj>%|ryOD4jwIl|e%ZYIl17K9OLg(UWbl<^^1;;Z8imaVDwmAl4cDtb9d>2^I%Hk^D zUAU>(4N+ZzST4YiHscAH4SBGWVh9>n+|gb16y(OdhDVy(C}UIxYp3+Ed%z1ddJ~aR z%p18!YQbkS9zRN*gWb<#F`?!sxL-X55nn^_qwp5cOOZ3%qw;9~g$mzHJO!`aOxP_| zO7KdP^ITJH6uGty2iZKa@evJHa`O;G5RTf^Ku3-WMyN@nYLz_R;r$JhwD;hX z90w@>8V0!T37saQcvxQ-6@#VlaFHPxe8~Y<(O)35tb>;m4xt6t0~pO70NN7?*ecRS zc#-pMt%>I_=1Yam-0nzoOAY%crBHVMEZkWX!>7|?c-&|&@iJwE6j4g1U)qCLg{z+i+JO-j1mLSj}il(vb(EEYF3s1d~(S!q8{!rt( zvo}(8l55m~c=X%vj7f6-q;o_aM1SMOE6_Vjil=;K>b-V1vO-9uhet0LHxaTkmm zo#5)yMR2!mfg}qRjK);dnPUOh3^A1Wk%$7a(_nEU4$Chd#yJjq{O_P8UfGI;rlD70 zST2NdW^=@F@;mr?Zx)K~d~7KO&xa@r0{ zwemsS#TqT=FT>CLN_g)i4I5rG*jsd&I2ya^%3>bOyt zgr#Zs$hA8H<24hJqLG}-#rTmK77>gqJO??YcG&VWnxG#S#@;9NSZ6*BpE4f88;vmX z`SGG&%rvAIaKggCCQ-cG0k0m)frXi4a4gCmZLaPEUE6Pvkz)tjl)gdLRVjSNcNl*Z zaD%gh3L3N~;>ieG*fEiSAvU7O^L-Sijgrynf*9uQU?ddFe?fz%JdiwhqQxd9=wLs5 zSl^F5UNb=VQyS?~4ItWGm9SRe!abJ>V(X9+>W2@*qrAN+qM?S;rfv2NfN+2>`ZX>^Ba;zlRtjwBiFE+GtkDNOKht# z(D@*tgP)A=fq6#^R&t0TBU=>8Jt_x9l>n?R-GL`gQ4_Prr9idR6@Tky0+Mrw$qffM z&rJ4A$bPJCP%`25r=K{+#*NoDUxGc!5TiYN2wAcH_`p>k4~xs7llvUZpEN+pU_O)= zeG8pTqdK~gf?zEgiJxBl2PKC!;Ix=IwqgwKuU{b40>?@7S^MzXu2>9ek3&Un2IQ^c z!I)-pMn-cB!esJwGFEScdiZbX2qx={*D4@hN{tfNb0KKI5Mg`O3VWXI16E0I6ggQ& zFv*zVlYK+Pf9!H_pF#jvf9*yFdu41**27N)r{FWQ4NjjB#HfM^5Pp*f^dI^mBx(Q( zneEXc(;k^Z60!3pH%Pyb#cnfNtn?@XF@+phUzJ0@kMzh{wg)d73j#ZtL3!S{rK6SE z2jz`S1e=^N2ws!H`T#wgFJ;9w9|2I)%mN=Z#4kt4e&_{y?5g8Mi8Om;_%Dx`XUo*) z*CgkoNrHH$L;zwE!wH+NQDC|(h~>ZE5R#;IV80xOgBN%(y>O8DazOx3zJCc}?YnWr zN*G7Ff^o8~l#s65hN*L|a3sPLu3R#L;n5|6r;7@Yacw}+OCO+W_#3d;`OsZAD#!Lsi z6z+@C2PQzb{RLQu8=}mx4DvII!=vK+MB;A`+_z?fc4JZa93nxPF8~M0*%C`o6ga*2 z#9JM27k2i7|jc4LUoc{N;tvxV2+3>E`!?UaE$f+23O8Rp@W(s8XdNWaQeM? z>@6J%s#4-dN+bMRdz{4ER|0I;vxwm724auiS-7lihdOUv(9CWYK1DtS-uMtp@G8@w;-}r=a@> zHD-`|9Is7DOim(jK7jncPWHv^w=WWF5fM1MbP&V7e;|6!X+U8?GDMTPhX&p=&^P0c z>ho-2O<1v78w-ty%Dcge+=}$zWhPKS(bZfZj7QEAVC=a%gDrN!xZL@d)8lYE%5G z+(^2)!yb)$^T5Eaico2d!rN1>M1035B28))QhthJ+<_!G%e+Q1tBb;j{CpVDr@|aE zGjJs>fJoo8!?vWOAn2oyAvXo^EoUG~Kl8yxVP2THz5#>Q_9$8Eh)*fyVBMw!`V6)p zgwlv)OlkotYd)m$3PvRbKWy3dkm!xsOZG)7VJ5;AxyjjVtj14b;IJU>cr^gNkCMSa zprS5R-6xQ}Tt?LTPJ@F+oLH84 z2w%vHz`3d*D8EnygY6V(J9)ggfhNXOc;Tm)Z(&UM6IgrEqC$rPRJH12Mhg{|xvT;MS(91OnM2b%DI}TDLf5ho z9=UfIk3EZrh^Y{~d`K3Wl9@1m_7r$gmlKwbFUk6p6RK0SkT6RGeikv9tv!eZl&hrk zjs@^?k%Q{L+y?SLmSK+p z6@Id_#`O5l1l^HgxO?;-d}B#K*`j2W-~9ood-kDsqa&XGV}NfTod%bqHE?KpAACQ- zfQprY7^ca9hh1}leL4*1GnI&y2^T!4mIl9fc;F*Lf6S=$!~H+F(AT#TBqK^n_hC{ZK8l+M0BaQCgfUzqcz*hJY)^eROmi7YaFF6-{WbKOD^mm|~P()?} zNsRP30~fkEuxHOQY3VsHUUSvOlMHrv$MPiU%ZCSWgW&}XyyC#o3kT4z!vrbRHF0Nd z9&iOF<43Oj;H=C@YMU}dZ(nj2W-o^$3Y0Lgu1WZFuz_r82w^{&3pdED#XGWA^&JaE zItp^VZ&M%?x=Tpz18z8BGYm-XLq6rz!7`f)Ob`ViHg^z>3zXp<=M6Y|kQ!^)9>Eim z1bPQZz$JQPtaow5Fmq|(7+8XWULTOYWQVL5DbRUfHy(0l!WH=rkhX3H`HEjeE-ej= zPv=3T{b%65YJzmRroi*i3Y_=);8g2vAcOjN)@g<~_Io?pm^NQH2?;Pm(<^>z7aw6oL(%Q!ta(MzCJkjX|nB*gDCMmK6uVohAOc1 zc)rOZV+s=*_LyPQq!TLdJA_87ccJ97Gd?=(K%yc0a145Tp(A~UG`fA6pcAA;w|8VF zd2tuUW>kT|-5J>PMGez)^zmTyEYRXMB5BdX7Rtgs> zO9;OfYOHK{37S{caD<;8bz@vm(a;KSd^5t9mUG})`iRi`(*t*q2i`jR0}rhPqU8eo zW;6O9?EQb9QvZj&|HIz@VekL2_kY;?KkWS<_Wloh|A)Q*!`}a4@Bgs(f7ts!?EN41 z{ttWqhrR#9-v43m|FHLe*!w^1{U7%J4}1TIz5m1B|6%X{u=ju1`#M z|HIz@VekL2_kY;?KkWS<_Wloh|A)Q*!`}a4@Bgs(f7ts!?EU{I>|KrgxZ){}L87Em z1p`0S!qziGbcP<7df5!U)zUb7NC+c^;_+QrhqlR_C!R9jNyL{w2e#AX(Qp?6K$n9& z`ug!12w1d2nXe(xTE~M=ACgB;cKD&s{ZSAl;&ikYXJE3(3x4vD$9RQZh6!V993_uf zaETCvfIOKBL%6?JD>yZL?s6vv>=c7ykoAwS5 z{+~Q%1Uq7<(;htLB!-vb=fLw}0`_;2M}d<^44ILk zXe=a08X_)&`YB&0bC&8cD1bF&Y6)!xnL*~pT7%Y+per;3KmeWT{ z!8gS4>s1JNkqrF_zu|!UoFL;%I_wJ00rgu1Q2hzTGp6K`>Sw}HYo8Jd6-fo`M{aa!4JK029*3h~4dJ#MjOf zVE(EG8e7^R#vzdWD5^Q6*9{VD_T;gazbD~_bvoS1m&Z7PM!2li3!i=NlX{mw!^{6g z(Rs&X{dRF&8QFxCRAhxDROWr1kDZm36_SLMk&+Rjp{-J>RH!IPODVLZCD9=5RDPNw zB$bw)>-jscUN`r5oa>y=dA~PnV2il{o|<)$t(ZH)wRJI&p$h+_W=QvqV+)ETpfWxN zxq+p$kgrFH8#S=^ty(yC^*z(F8;fDDbJ@?-KO{6;6P^RZu*bcJ`mUQHL}m(NwiJ@L zi99v!7lCxY5JE3_pw`BnWn?H(hfNaPl7`Uh=hC!6-5a9=f>3ivhOAHcA^l-7ow3d! z8=WAEek4whoR*^PfE#ox8c1uiD*B!|%S-)GP zEvyCGvKv%mYY*|m25=p8r@%8FIAk#fGM=|+XPXC_JFD1Uj#)o9g5yxs>}iNmE)C*fIy4SXha|_M&s5k*f}hvesSCSUT#X*AUb+#P zadd6=MOL-RdqHN(|6n%-XeX(yHQfGYis4m3^P|fo^{{|axTay7w?31IO~GGj zQxe|Q%u-c;u%lvW2v`+?p^4FWS|bYkl2BY2lL&(sl5js%M^73WNKBcd9p{#_hZccM zbgLRl`I?@1Fvm!8Rl_Kod@58lN9*|k+Q>0A8Zv|lCU@BWlVP--+k`ZmhGGF%Mu==x z!C%p4_FVK5`Eqnn$v8{Y=Wzt=&kjLRtRIeeIb&_^1(vI0i^amZRL|GAOgVaRmGw}L z>`g)6Wqri$48pDrNtC!OJ(qFXYyak${zNgpsredzF6JMEr zB-rvco!*-@vNn-vFgj|E-%+3smv2!1D#o;u53_*VKe#PY$gJ9AYRc_^q*-WZrq< zl<;oa!{ks(KLpnnRZ>cyDIU7JVewv$BOd;R?0UoTGDerKD**`{3H$SsBYxc(O;ZK( zSaL!fZHL5=Hj#r9RE%d-{B8U$+NUw?=f3aZ%li=ozR!?iB+_wQ`we@IPgjl z;+_WRin&gD#${A@;wSyy=YV9zH1zBB(vnme>>A`Ks^Rs5Q`*PrK$k0~?e~SVS`h9% z^up8~1I)~yBNK*iqNJlX__H&euP-`a+FOo$4%tjGo<>Z3zX0h29(cNVHqIzbz~O1t zbo!zjbI(|g#XAyFxbZxx@7INti7R5ujnT1tfWDodiOnZC8g02T9FCZhN%=IIv{x4g z<^pIko8`!!&s$64pl?J&V)f+ zK8-m}H?$YWpQRH*FghtL8iy41?@aVA{rYt`|3!QVxyFeS2BgJrDW+6$B zH$c)VY4{%)NoC#jR1@iq^7;=nYhx@niE%{d+!L15x2>n^dy}zn_EU~g?h@QiDW>%q zqnTu;Do)8CqtP2CLHi<6&gCFjLY&<0&x6rlTWs=pLz{e5$tu5$rbd6EhgH3_IhUi1 zUpY|avSLy`-^hFfX%xBO3T5k@W49#5P`zX#Y#&ZSg;fI1l-l6bpd-Gf3Mh8ZT51*W zRm_hy{F?brzp8^_;%=x@9TtTO$wXo`X3Pg6rvxBgrjOGaJ++ziEqigiAtdwcy zi?g{;79f26Vp^-U9Dn-f2$qBvQsBl(m_73#HGDk7nn!ojUVA4zD&wp8QkBeZ`YCEo zcVSzT(vea#1Wg@I6dHe&&RWR8;=m%v2aJWVYy|e+D-eusvq1FQ`}`W~p|b)xZ0nhb zq;;t@xHXC+wWaaEaV9!PM&MoGK|1!wh0Mbn*nbB!@G0pZ`<14L1D)frA!$Bae#()f z4Lx!H3rA>H-nKMMl|#-V30#O+g8ACJndF}{Y-MgV_Ajlb?fRd|@umQIm511=x2DLn zs%5Xv%j3stM;KMc;ibW9TCMIx^~I0LXrm0gV(r=9c|&3SkgtTF<|yB7H`p14cCxIw z!ge$tq2Qx`DdD~X5;@xS#<#ch*ZeMNm42nn1ZPB^FXD*YG*}v0;ZJT3%|2?4tJjue z`>7Zts>f2zw=8z;$3FII>~6Y`A8h~UikEz*mcHcrIe z=^BXQI3zo{FgrpD`Pn7&p`e!6 zGKPk`*ubAXG}_|>B}Y77Ixc;rxT(S@9{-Az&19*w%am^BKBW)m?@^152S;Xwqj1fB(l~a50=~%MmnPu0 zcsNReg7Mb>0sU#w!GO+Bim$)Gwq%x)wB~I(C>?`8L3}0r$YJ)BtCKDVkHp)U&lI~e z5o3l}W86GN>>9(hNkfI0K`cimXWphmMp(}a@NYmU&LVHHi>r*+T>&H$C_-6W1uz8 z8;|x4!(Wa`z8GbOBX!rQ|I}u7++`f@JKLjp?P;P?CmOY-nl??I!mzn7q`HBfZ{nb&xV1UzSGmEImQ=%M_})t%#S^j>tHrflEsEpqDQxyx1OZ z-mN7+ov8?WyO};M{y=A3ebG=lKx_9NBEdd>|Jd}>^-%sE<~j|1-^E->aE2~ANFs8M z1T2#?aA!skI&JuSRQohZG=yWvKpJX|_25%~n$rJ_#iKPR+3SUa)EL9CzQzj=vV?H# z)(2+6ad{`U=TO*$*_56(4Qdls(q4`{{j=s7$v@Dc{HWn`QDkvk5am?`~6CrzSz-+wM?>VGAYfO2>oLyG%+h7!ASe1*xOL`Re~sx^O9$ zmF^yoUe|UimnjiEY&C=Kq%>Rzc7XJsR(h%Wnl{!Y<80kZ+PXfH+MlVQroa?h`)we{ zzmLPZIV>gL3mfxi!2FU2Q`IVwsn zeSq$9l!EJ*$vC%!*Q382slUdC3cd)F!^&c68!`lX(OaqHf(pJhTS2WY0PZfEX>ij{ z8cgq@#`Ue_SSz6C8a9v|+)L>kKkGfEmHz9UK>-`@&?cD=^sqPp+xnHzV$?>tOEt0D zyq4@f@_I{ZE{3^nVWVq&k=S&f9p(zSjx!^Irwkl zBeI#RLj}Da*sN`Zovj6Qvc(mLZxX_mOedWP6L$O9Pkv2HV2k4|dR&kR`F4FI?Q!Ee zp2PHOU@ZEcuV>y8O@dAGBe@Q!kd(J*V(`HLwQLo`fA6Dk=ao9rJW|N8xQ}IxQ^KK= z$IQOipX5S63T6rm$mF9sByI{*q>UlFR_FtTS`(_jtB!;7zf$&$7<8{a$6OPa;oD{x zwEl^L`K=_u=I+mAGPQ<2z4}NpJ#%m@&y2ox&A>S~P4s|Uh`f=2wHwd;Gy zw2I@FFQ(Dj2bQ?YTw#zn1Bb^7qkP~Lopc+5d0+fVO2+~+w26+j9HTcwjL;C!PJR!j zAUyE{T}^OBMz8kPMEWrAim5KTUKwT(jXCB|Kx%wzc{F5ZKD&K+2nXd65*-gFxVT&;#L`=Pd^YN zjK?zb+eSElZZFkx&4T(mRR{(o=TW&mbCibU0SB^nS3 zL!Hl5%&QtkKWBh-^K)H&M`i}pXI`k*A5e}wCgeoTSx zl3u29@c{+@HxFmu7cph&xpdC@2gN*@kH2?g@yqExYp9aNgeQKyZmpp$I&<;icPecf zp@O=V!|<@7k1HPY+29gG#DsdoHtz{(rB1*;XHkx?{z?n7onU!g9fg0*k$q@5&Ip;)$o7M2bSE;8ZMjs+>+b(}jrO5T z9W%p*A%FikT-fSPuKnXs_N;^C`u=14Z;Zj-%Qdw7Mm+w7Wg2d9Y>A#8DwE><~GVw)dsxRue9MvnF^v4tPWGs%w_!TW(n}mP0cj;BrRJ1NvL&K~;wEue&YNV>E&-F7ERhpr# zUIT0YiBj?SG$i+KXSFjKmd$>_Ru%4{CY+zm(GaL&Fxodr`T;;ariZ*O?t@=a0I!Z!xEV9;EJ^9A(V1N6ajM# zaJg|gWb;JOH+ct5F8a&j%?x11&-EKQPVu?Z6y&ZPj^({ekxd~}S5 z$-pv`I);hEFH{u=A}q1y-y|&ic7xPsU9@P^?iZ9^5yzsD9kl;>4y89};M!Mv z)U16=8He;xA6!mBvyM>DgLs5UEka;X1e!HOaOk@;T-t@8plpO-mnN1oR+`<(=UTHG z-djDmgd^j`uv#>m-Y%JqHU&|dQhtg~>3yQSS&pzCZ;Hq5ZM2}{CFPBHOU}psvh6)w z%a_#7tSzOm^Ir;#g+Efgk`Q{>of$v;nyyx}Oir#&6iqAnL8^2PimmDpM zDWlr%{j}Rpo~`%VPR-h8@GgE$pU)N2?$Z&l*fkg9Ol(~Ki;08 zh8M;dUHyO6U_AM(O@p$rF{)%zQ9kcF6Hd8EgZ-oM-CPOA9RF%Q7)^eQG_dNqC%r7< z_~3aL>7~O6NYCNVME}i{_pOtLt}vzURRe6hf)_sr$z#PcVT?`Wy0%yL$Ssq=V#g5D z*&u;C`xG%pDHBWU$3x^*4L!LZ1wEA|NN z<1iu)bS&4H7DXH5jXdx7Jo-(^{9Qit@=X@mV~6?2UE#9*4V5d3V8gH_P(89t&^@G) zy^r@|re}P>o^Wkz0#JS>kk0WM?$dxDtxpQZq$$%#x#AJ4d#XlFjV>@&6T!^XM7lO% z90Z4r1lIiv=>#35-`>`Azaa(^Vh?EQo@r1G4a8fCBEi#SM!~Z^F?PXu((Rr_121!_ zhbs<+J}6-OB~R!t6~edp8v=C$A&lD0(ZEj5sQFRB@*ZBGma)rWUm!+jwus{>$6B}g z?_wJ_grVAaCWSwA#)1=Juz3B19__qCt=mMPHz1A~hSnJ0aFl))?;*4N>By+qMpBRa zNf4hyrElVK$tjseshXqojUnEaMB~EI1We1EM1Q=T@E~TWV73c{q^b%cNdyXbw zmVkz4IJPgDPZPg=qmR#~;nxZms6&?H5d(mH{ybC~%igACV7J;o!M$^xl%H2cE+Nq* zy!Q($nk5j-J9vuTmyd(dbW!qYOM;@!8#YVK4~Iqs!S2>)+LA3Gg$u(laYHT@1id5u z`qdN|sfaTR6R6ML3;jbvar;F&ex6B!ZTnYJhz|no>|nDO?I8K~aae6MnW;<)#nQ+8 zyccnu-43jymgNIfthbirJu7Gl$NP^~nT56cxt2>i3i>`-^yCn)?WMJGZI%(9zDj|{ z(m|%LH3o|{0A$p^*ZR$`;QswHh z46f<>^PIk07{K%5WK@>8;)nA{lGC1ogA2}pcqUWz0 zgo^CxMfyh8Xy*ijQR9($b~k;Ror%gXgEX!;22qPPFlEUbEbhiX${Bf`3Wv1fTpE|@;l4SSw&b!biy%j5lu!%I@JY{gf~h}lGT1{JjF{wMl4 zY%YsS(Z)%&Xh`g75=1_;|y62CsbQvUKnnmcwr zx^wrlS3VQ)WX2AG_nV>cer<=6?+akC{xA8OmQvGo8Eg!a7I5&RC6$cqf`^{Uu+s2k}c4ww+y**zuAQX zJ<@%e459XB+D&Wd_>__;NfGzQdTF5RS|`Os%;SvDACHrOo7Uw5{7g zRXq>r*Xy0syFQo{$}8!klritOq*Lz~KF{lICHIfFS^7kMIALl_= zcOgoz^wE*obyRTwIq4pehkYeKr@lQ$ZNHw=szyT8#%LHNts(0+Q&@d%CSB(iGALY5 zGrZ(r{+jQ(znq743h&u0%lT-o98N37`oVdr0nTSvF*dZFHr%czQ}%##J}gDyyHVIB ztAf!D{WLg+W5XMJDKhXKSx+5-n3du?w*)lRarD<;p5PHK?KBKya6G;h>= z#8`c#Pd~QMF5Xv;oT86B@o;2plt9tl9h5&Lf@V+OPdOFi5is^GHGcX_NjyUkmY4!r zgCU4Y_s4~J7ptnv+nBA|4E*4|ouRLzv1e>9S8vFnuyT;lWkI2NGZ5^2gO1mFqd zOh^}juImFDrecHj!MeD7gCi*KiNM8M1_jEiY2ttFbnX5~Xk|!a_X|yA?{~!19|16u zE+tPZ4cMh?(}WLi>E(o#g1?K5AoXN5%?LO_=dFBj^TRKS+HDN!-ILLv&$Yn(8dvwQ zLV4suoGD%-m^9ajy6^vCk*|aBwb6lH+B*WX_;;rMrk$Q&SWgNu6_&OfJ8HUdC|aHt zQpUu4l#}$7JvzISo|p2z!_LE23mvNHP5d9hsm5TEVo1+Z0=Uj|C-wYSL9M!hI4fDkjCLl#X}Sb0sJK91HOFIp zPCS(iexjlc$#}_C1J>!@2!1WhX4`Nbr*0FOt`){wawX{{MlAmLY?MVP(o;SMD2(O3 zQQxn$-fS4`^TK%lA)cJ33`gu%G5S-<&p`@2U-5b|jQ{MXr77+R{WqI6-fO2BGrOrG zCY#!J%hKvSAF1w;D4h8WF67XDX6xlgfleW4csrFW25qr+tUb^0jK!VqW>$M<5zTZY z8YboorHi(>qHcqNd?#4FI?SHybNxnS37rt~hv?ZXQW_6rXi z3JD*AZ@!7x#d|odyS!mHDu+H-s>5W@Ms{R5e_zUNW&@snSiQiU4WxCDk@plFqRtoRDf`hB8f)c(>y2{ zFqj%{C;Cz)h#I2?x52lhRJ)QQrzKHa3o!QSK2{)K$<|0T)7*~dq<-xZMT#V`u<6rL zn$ERGYD&0v-W4&8-niC$l|+xll(>tha++# zr{KSHCWr_dj}7bgQ1D9!1ZoGLOPc%(_R)^UNI1JH zU_r`Vn)c-c8JVnOO6&Q#JlPvg`<&6m(XBQQ$Fs0CF7!B32Tr^OJ{_~2^!cpelh$n- zl#s#mcO|6cYY*#+P_#6*)2NY!q&Lh4Pi4X^eciLDNj{lM{u!ao$QeK6l5umFDphk0 z_OonvI=T3%;BkGf)#fT0*bi|-r%ocBT{N5B<1;e@t~4z0h{EE$HtH%LkN#&dSWf~h zdq0|hex%0QFc(*SWW|ea27Wc#@6(LmESYyPyH`M)0 z41z>eSZaQyjw$U_a=nXu-WVYKygEwftz^f9PtdeVao8_h#jcwUFrB1WO#Np`e&2?{ zarFwiJ@N{%!VVTO--uE^sgaPLCH?uTiREwRQ%CO%bhi34@7)2o=X9Hyn>wN4kRlXh z?$hHgJ&^Sjj30bWo+D0?@{oEK5wj3!=q2T!{JMxy#9EGLmsEvRE`r|4+QDkMIJPaFiuHZR*l@Q!bkbWG7qO8iiCUJaZ5+J1N5M{gaRJjr{>D}2&zcIAaR7ZjxwH0{iZ8978DhtgwLH~ z*gIlAE)|S|x19!#a-D30?@+G8nugk#xiInd!X*zsK7*~KP47+c!+!+L4t2*Pr;VgB z&KNnjYpAi%3ZKMlNU85H6=$Ae>pO$sS(}RVXmJ>aKcJXn-GTs%e)5bR19`0kgjsWC zr`$Ly{QiQHFYt`l@o_kBYe)uWs$4U_i?zFiQ01#B)M}%H%#XcvIY=E_!%FCii9ZIm zbkm9oD;WP+!q&Dtr-7OgkW93}@j)#dIAy|`+ageETTJNzgOt>{ldD3x7KPu7E!h|7 zgm$N(;iDER6JAovo)65~{t`XvUxG&qC*b1#AvkgOCsRo?MSsUO%CGXp)M~Cmk(J}> zj@vZMVG7<^aqaSGp7-K211)u5#1FN@eKQTHs20$(!$KI{9S-tHMy1Rjx?OOVtd|B+ z^?OxZGUv*;*yX%$(#9&R#8G}A6)tbWNn0tG=HoS;OW>MrPkWT}jL@&`E|>)!y!7U| zyZPs68`oRyA2Spm8qIL3a0ITc<4Sr>4Jw-dhWEOivEsuNs&*H{P19zqg~9FNLj+bKNQk|}?%fP{@2-t7Fy4ksI6ZtkTNeip^TroDpQ?N znbk%f3;gi=Wd?EzEU=v`B#T|Vu`h5v`F>tY>z*5<*YO*R*D0W)hvMk@zJqR!@`vG2 z8Qv=p;LRcYSoQ~;XzUg%Ov`8Lij*0MB%AtMM*U4b9 z5)9`Q(5&1p`ftKlnj5`<_1{xR-7-_?k24ckg}e~>FS5lF#nEKZ^nWGsNLq2Po7(s+ zuKUy?#A=Sjj9p(z+@POcxcEb8;!GUL>!nlLEp%H<7KOWA+05Iw=-ZT9w(F+=$ehdP zB^?y^ND*4)nkZa29S_{baV5fPid#&A6W%N7)N?(ywRj+0^$gy+BFFpg=H0uI6EW1rQ-m~$r(vc>G`;*I#B3j)1jpeS@!#q~Sgwu2|XL$;QCpr9>+)X>LKcRE4o|09~ zWv0>NiHrt9)7f1#>~sb^4{fIwoep~RZZ=kGerC?pM*9ER;Lmijfi z)E@|Ddz)N4G6bDJYS|=vL&O;nPHL#&KLwHiMKNN~8K;3dLT?fabO|$Sw4S)wp2% zDlx`}WJmNlZ=vd&scG&dK!E1`I=^!(nA3eE1 z5<>km(cWQ){MX`?Q)G`1%X8jPp=g0YZm@@B@2r-(fcEcXe)RP{Gy$d5;& zSu|CK#Upi9FZQ z9K^+@!n;WrqhptlSV29_)dSAQ$5N?q0m&av!Grn5WKeOEt%+zQ_17X;5xNLrmUdKH z@QRKc7!B1m$pIyI)=)*!uv%4S2@my-#c^EO&ft<%~3F# zyp=3M?$b2=Kr+<6!tx)@z&)NXZ;)P$*gt$1AkqQ%OdRpY9e6q}ixvdgAh&^M9eOU) zlip1NW33WO6%QlvO+0VlVT^=v)$}QB1rr}9!N!N(qs8S76#V!xyT`R}tnN(dC$lOngVnM1sZk=#}nCuj|jX6!dqjV7Tc_r^Dj)72<8hc>s z2Jb_~^w)L*POOPW!#O_J>*}GrYwKwgfA2@+9Ux(6o=bG)eHWq4WZU%Js=vhy@rR#K z>8mk#TxkV$!&kI@+fa6VSsJ|W9A&%N&ieoL5wS+!W$< zZs^YZLh}vZ2Bfj}IC4S^R zZK=5w`ez&ttIXkbmpdYMK49yTdA}=j1sT?x&}h5KxHMmlbpLFmr@T*?d#%)}c7*{B zWGx|!{mN9-`ji=o&c)}4To)drDrj{L$Krud)c(O8GW^uVrb$@ug|f%P4Yg{P1QCJOU=no=-N<}|HQdqbB)2kH2W zRH&bFpm8D#aA#Qtdbn~+FIf}9Ez79?ggb`ccSh^iB>c4BB)IM}nev}cgs({x9U1&U z?x&Jbuxk`DF0P|1T!j}pV;=i&g9=G#o1o8H7N#oBcqIIbHeXuK-d>ar>z`Zk;h4Hg26nINDX3&T(%Kf| zeoi38H+kSuN)}5f8ciF8-LX_+6to?kS?r@M{(ibdVjk6G#k1JT|I{&mJ6E*_3DZQ? z@9ceg5<4)4XL6Q5q7OV560IaaoWCc8*2Y3^jRfYONg(kr3Y7jL1s%07De_b<&8%8S zO&~k=y!u^zB*xEz=5uQ3#N?wiY3mF6>T-|_ zFWn_KDLH(qe9bc7BtdT5G{lUwf^?KFstxCm0x6-y$pSh?i<#fxDEM!zA=QOtlss=H zri8@s8Hyxbov6?D8~r6K(Z_6dUM3PYPb2T#smS&^&vyKpgLSUXczFI3Reaq;_YVG| zJ5!Ue_rz}!5gv}th-y0N6OAc`_WX>;bJMkolu~3tkt-+Ro}mNjY~rfw=!f)fycbUA zJ>=?l-~8VN)qAmfxGV=AfsmmrcKA z0*f7rOuA?|WP7LK5C6Bf9YZm$-UIiJEk@iI#)gNVrNzp(>CyV{#Lh)S&MykG+iVcs z6AhVB`Xnb@OZu`l5dI}g-W@7fW^2avjze&Ag9gp8S%$>oXpD>0LZd_y#J+u@Xs*B( z5fTVY*R~ROAd!aGGV03%uHQ3mL3HjrhJk86#FEA3NILAvsNsxdrk28@lwdu zj>9*bHFVKk2X&L*vAn5?w9!zM_YjZL0i$~?Jj??Mf23e!UCwv*B;Yltlco+;#H44t z=&`gaCRdqK!?i|I{-6x`@#@f7_mAmx22s{$cLdAD9tWfT0>Rf`Mj38%LJ&+{=zcTuh33D2NX#+L_#Rf(RQ4qlYGYI_cWarJo-s57xvM; z?gHk&oj+UMJXpGMGplZJpcP`Kkk`=>+$bA~kW@=S-^?>K)in$^-zUM3yZU5F!bO(sYp#$_e23P+2u5L zP9~mRY+-eLr>|>KJQnQ>MEbQPtS}M9uDesAs-(!~-CIDdd#b2F@g5zXe2|P!$6{NT zA0$MUk;z#G#SAOd^7&8SSOe_ps$mDqVo7md9CrE3p{hQN+#c+wJ6mT#>v}WUC=S8D z&HqvC%=tJV!+T&SBoMq`hOSh&QHr4jG{*R&L@NkQ$1QPXV3gHL^kk%4ebLX zv3}|`sx;WlYj#=0@Lp+qj}AtghS2V|R5HC0fvppgu=acu-FK)K81ddtI~VYZ{As5o zU3q*+p9!0l8F*dhii?*&v%-W3%;N7xUAM`&(IElFp*#$*awZ!G^iWAK+N4L^_R zrcc5$=z3;>+jDKB!5%jw4@#*3_`ZVH&pyZ=JUMUsO_j(0n z@MnaI=RC68ca{c?M?gq51fuUIqu%=hc^%7yZ%!&SR;c5lxhmIqhJy`gFo`Ya=yCE$ zc(|nU8fhf$`Q?Ihy?)p}-vg(zxmG!R2Hb`uv&Wa_(r%+ZKJ$!%Od|h0RL|xVB;oyZ zcdX%=vfRjx>`?U#4EG&ii<6=u^fLpg0Y7NR>9cIyIxXJI5JsEKLiqX4CtIa;q`5PT zYB23V(jjb2D@pz7)3ICrC;0u7~cmDi{N7x=vS zZZai4_9m`_Cz+Fm=oO%w{ls@WZ~0+a>_1A*E}>+t zuM})K8saV1I35*9j}`3ESENSQt{US9pF8C$N#M_U8+>*tqgl86&ue7-cHocGrzaQ$`$3G;hFQAd{QZI$q0Vg}+oib-TX?^jyR zz$2amd^}BUyhC(8XSFa(H@}stY~w@6ADy zTbY0<@AUC@3D<2OeoSX~o~44Ki=?+{0pja?F-$c~VCuaLVlRUE-ro*E-|IPWnK*iI`PXBhlL zlK464FopgyhO+o}n)%#>)z}8$XwYnOc<4iYJ9tfhG=up(v&MnrADR23neghKOuh+; zFnj)!a=u7l=@n0;j~1Z%*j38w3V@?%FzR#+@am@to|j2e>J@!xC-A)KWp~WKy`8cq z8e^7w8O`q;WM-SP*b%-fJv3hjGADkr`DU#Yy7N4RJez|vPliHWBM75kk4Ij}TzIK0 zLT>$X`j{h)iLzCsUlR(?bv*0lA`L4oU-r|7>nBV1)79tQl-_Zj4Bx(@YNJ#P%EZBp z?V>%iEZF7b&lFMZhEqHr+T1-2Ru}51h@Z2+1v-)`|E~WF9)tHE3}9fXfUb&M(!3aj zy{TturGg&G9d@CoTC$WdV-6|3tfNcc7ZQZx1-sossp?JyO)L09CI$R6@H$pOXB{1_ z<2A}TV+^y8gjL!!nssUtg`B-YWyVtYB&P~-C0CqLn+&zG37D{TIMutRSRU##fIpwZ z9{jGvcDOWBn!!D?tdD?sBhPMWdeFr5AN1}{GyQ6ahjtR*8;o6oLX}wD@^(PiwvkY) ztEJkoH!LQ6E~|dVaEYtyugtQBiyGnfo*zV7To-<8A$qO0G7r63*taJGc`f#MSmB6o zF8}zB!Cf|7wwByPw@~C=0coubgR&am6a0FbK994e1D%OzPVu9xP=7dW4RW!kLA z=XLXsQ*40`N_r9n0d5PCG3C3}F6jhpJsF1V2dn7frkSV@6BjIWAAz|gj<9l-ry(MI zUxDv{S3cph$ltb@^~;Q;UTDGOfjn9!dm(i4Y%*Ur1_``BqwDC5$6lpmXg3vEU9yN$ zJIhLs@q8QahqZ-F$D^k%bjjQWNj#_I&SxWEA{$AM8~4VLW#~Q+@`|^TyaN?_fLC=k=6kNd^0=2^`*RzcrXsu6od8({{Q<C#K`4rarQB)v*BSqlW(o3tX-PvBFXy~3Wr`7V3$1&Zc zZ1Ij1Y+6V$V=gYo1Lm{&eb%N5$xo9}UMhyXb$q_B*TOn;`EHalXMPMlM5uhsB)l~6 z_f8k3XCyGx4lQ5sO|yJ zPnp%y*Qb7X8WYR}U%j9vbcb$eIb*A3D$b2r05!j_EWhsyiDsQ=Ic);G&~>0QR$(-b zpa0btykJIs(`a<3BF;`fz$O%2W^XnxhDOF-wk1>w!{1+|_AAzqHd;s_L7V9GfCgFz zW3edP3of(x-^lJa>1X=j@sntNF7=~#oBpG#`qgydaUqp|H-f{?K%9Ta-}yuR1!m)| znb%ltt7BBd`fW~-$mun#)lZiWZ!2MERK1ZG-b1zJ%c1gjDc%~Uz}N2#Eq%)OVzPJ^ zyzU+QzU>qBahAyX>n6CqhRF0Q&!$x;QPJ;36fBbniP5F(@unMWy5>rHsNzX8E(Nh2 zJ{5Ft(>~G)3Bg&O5!iRmjFgO%SZ|gREG;8R>YFXw#rH|acAAj*9t*r&F&yeGwe*4C zz*lD65?qp8LTamdrl@iv#?JU3N9P@u^ZUN>3T>5WYVT;#(0K0aOe)c$y;U?-S|kl5 zA$u!HLUvYmHX%t!GP24_lB7>&zx(_9({a?{d3xUO=Y8GRd7iJ6&+Z?lQl@I&Oc7 zCJM&l^P+X6JGqZ;R@ouH%L$n#G6LD5EhIabNfw!sc-^5uZ@FJxu``~Thl?S4O9V8J z)UwX>2wdIMLM`uSV2D))>JHu};Z_xNUP{N2eT6J#HvcTO8YoxjCy4~}XXj91)VB&l zw@8Pko&=72w~$YkAJ^d`c%6FnRHZ%f~qvzsTlL5T%wUSJ`9G$lwffQw) z|2X}d{tb;{qcl3{-+dLtj?zL;@fmv1@`Tzv+_0p?3*Pf5)4+~!ysO>E+WZ(i9^D|< z)%=`3VF9J{8t&f54)Ag0vxL%A+&II%A5==y?9Wr|?}N0)MiEoq^wVjPq1fcjYnY|> z(|GqYY_WDRQ|!sZl5PB15f_e)C0Ud|f4#u??<^D;S;PFnIm($D2PHBVM1Q+LTlH;F z!?jn-$$9W~uAu*9ZD3ij0E(y6pkh3W?xj`G!q>f&nsSZ%7#}c$t=t!MSOWh1S}81= zh1%3K7FCgok#*6i89fT|7u0E6T$SAq6>StAeoRg?cQ`PG>F^5QLPz^NQEe=T#9NH2 z_uQq!m7#d3QAq>Wo5^C28oGlUXkp?Ka=x4gsp)@dtj;YaWj`8@p8wl%8$#nm24*j+ zqOk1A@C^t>N24z~&YMC^>^l{Q-y+T1x=5%9WQH3T(H8x11pPk8RJR@{y^?V_QKZZy zuf3;3gUNXH?I}r(m;{N1%V`+*BHY&InwyRy^{9zK#(oBNraRK!-A+`$O9x&CqWEUx z2ffx*D2Bgh`e$BHmNaJ&=_cZ&r4vrXgd+095XxOyMw5J;ka5o#Qlp*lAoI6_#m`{$ z7i<)a+RNDty24Z)Z$iGthIsj53M`jCVGqS;Vq;thDaPxWh9S(^SYX<`e~SuR~`uBa*c}gUnH40fu zliFuujjuc1pJRy-DQ75U6LmVDBd?KnSkVPJL^kq#zE~T59C*k=N_8n#HxzZ!=jni5 zA*rqXO+mXSfXCoa($q`y?##f-Fzy{2pGU4155^N{t9nu4>ne3!5$5C0Uj_)I#N&x+IW-_vO9`81X^ z?wV1EsTtJ7Mq<@}+344}NZWf0NGmS_w#R1TnUy_4=3I6#%$Ws+$0I4AD~Qx?Oo!Hv z(U?7+a}=J2L2ZT?YNHNt{=qpi{~UmhtA$jZV1n!7?$BAGj}u|nNVFvek^&t_AAdlG znsa!@pqL7|7Hh?`RqxharH$vOBe8orO+NO8CJV;Xo3HQbpRo>%xfYWhn#!3Ku6R>A z88^>QLE<7Ej2Z3%vGNymjB}Ajm|m2n}9x5uA`qlMK?Nyv2edG3Dvj-q_0&sF%nZYk6&Xls0Cq z3_(OBuP-l*qKlSI^!^Cfo0J1+%I{z7$F%_3dtV)sIA`KqMO7yn~GDcF?u0hnNu8xm{w;lS9aE*4pfg zq#MHprHlMfZ?BF0*KBYz#Tug?$)QMT4whFd;;Fs`3b#+?K5c8R?cOHGZQ>}~sEBg@ zbNhV#pgyma5 ztt4?gmK?umV^3EXY1hhN)FCB^{})QZcl2oIhGZ;|@`lW(U-t3i^l=w;;SF`mBU zwox$G^V%lca3+d8i{tu5Z-*IDErD(3!SwNdE@`{?K*Z4%?Jr(X;CH@rnA<^4e=~5< zOv+)O<{3(xw1%v7OtD8^jVjewv9b4CNV3ii|K26zs`LiBcVP;OU-0`&%LT<1RyaJ; z7yAP)klz+9+}$x5;pwkwh2>dR8t;Wghp#cS2x$sy26{y6>DAd$Sa2g9miy1JW6EkU zaN_T;>I+%V+Ce>szc7*?hiBQzg3AT-@%77Oczk!op6wQh@-f4v_6R}3)+2O5R)E{@ zeDPq-Zfe@!!YZdOr*AV;u%5dUONE9aVn+i0TfCOExPGs>(hWnyoMGN-&u+-t(2t$& zxFM%U4!6!x*algI*vy7@bP#&@Ytd_#hVyqfXn&T+YgG-p6fa9>lWozMA_ehDliUTwOt(&kQj|0G|=6XluKkAlh*vDn)XO)}2?)O&chV3Pk1+PX&yQ$8<*_x`zvRn?}D zsOi*M@r&FySPGtWbK`rcoGjSpY`QGe8HounIR?sSGCA{laz>7v*Gz~IzKek2b?R~6I zJepnXnhWuX@mO}~3dMhjpzffLwB*@7c4|l}?Pk-^r7(llh)9vlfhiF>}AcX z#PMC@82kA-7pO5qi|%rE)KCE(785XjsU@n%Nx=DGJ1uY@igcfJOqXe*9aHvD%6K(q z6d%gx1rsp$8gey0N!Ql;&>_(04c(7aw`HGU`4h&-~3krYd2T`vK#Za(T&+jaNpN|De(N5g@GpS@6J6(Yu9X`<9YM&-{{$BkG((< zPE%oRI0u)c715ME3!AQ&3XWF@;WPi;s+*Ru^QodvS2+sKj29HdJEKTK8~+OBaZ-9U z#VQR!5X5 z#P#G5$kATPjOn0b&J4@?r(@TvcrviHL-9;AxF3wC#npV*@^2n2_dcM9t=4pXkqL$8 zXmh@4AXwP|D~Wt7_$zmZYQ7rddA~Eh%$S3eiTo_PxR2gHILJOcxkCGkY|!DHf`B)P zC~o2VEO}=#5@bXCuoez2Uq@@xL#g7sJ?7gL#4OTZqQP(1i z`9dyK!S$_e{BQG(pOufvigRBwvBW}qxRhGaM7t)sbDMiJO%#!6szi?d4y;sXrr_A> zk)Y_yqSg_ zUC}uBWCGXjdIUoQ&Qq3{GaX1sqn>fG5SpfimBvfxB@w*eL?9(M79C@pvAu49rhXnl zr6ZFOZMBY0Wb$03%U{a383wt%esCxeU@&M4B`QmxAS{#)U;ZG7KeL)y)MwH6j~*Cm zr-8vg@wD`QGIDx<)6P3T*y$DAwa^-a10CFVdYt&PO#=6#x@gmRcbuIckCzwu-$~CB z+g9e#{!cw@*O(xr*G|E3xzn6c^o1U7kV3rpSsEOiOSVYI+gssizqf&&xdL6*pQ!W( zXJ$QCwE0!ZX1aQ8r{t0PcV80AJcWCQRq>S zCTsSDzFS@t=$yGE2%o-|Dtr&p`S35eN3-M7l-~CLZ_O0 zJO6=wS zfaDLY^!}h7)mDqZ{&g6v*K?lKyc|@Wb;e1a$$Go-6dn7i0JE~qY@^v8e*Md%%wHF} z8{d)V&l0+HX$dXrvByoGuR6KQ1E;Q^6R2_S@L`t)C_LNX@NU|AlD(G!)3(K=#B&Zh zm#V1o;}n{vTSoz@a_Hwi(wgR7f|BKhB(+Thi=8&pu=UAUa(y*jU1p7J?w2bw*=dte7x{@iy4|6-_$W;?ih^+3Dw1|qY zfW_3U#aWFbgtUZsXSYdx|qV<&yaSE zHKv~#fz0etu>CHM8u4-1=9Y!A5lJj=EBCu4RZ*r>D6(uVur~0i;D;K&9vg2nN6rCx zu{eX(1RCI5MJt`tdPs`(L$GyL09h-oq_f77xKa!JOLnIH&)zv4%VC(x*$zb~zmUG( zL=08e#09SJ3>Gm2=x4yb=O`=d*+|K)PDnoJgUYFPknX%nwV7t9;H*IQLYQ(y6A=B$ z3fCijXxG9v(h4}w>WqZ3BB_MD+CkYnt_kXTHVOt_UZi^d+&}NEi#z{vuK9~b|@`LJ$8}ajf zCN7V+r>zSg3C=&|+_da?dYYY&)XihjaZ3(i%bMB1vHSFM{46HBuZP5Y)De|rO&_gw zxu$o342>7FzZv6LM7b_BFI3WlR6ST-n}D@p$5?AsIQH_{)-E&y$(%nlj&H|G2HdG; z8N=U_;Z&uigdL}6VTyzc%3p~h!Go~>=u_tWdnQu$slaXS8nQRlp&x^L$s?+Q*?j&MZ(@D{7AUhFsxzo&ghLu`^6j;;xH`jZM@WItkYo%*IQ*A=Fu& zj4K9rsc)DRe2U_zd58d=yFGAo`4mWSRz!E1Fm#Hfuxg1rb~L9`{j6?>R()kC*j}aW zTRhQLO$c7o$nVcf(D(8+ty#$#W1U@8x_BIk?b0AKHV56qUeMKl52+5E zeZtQ#B~Og-jz2FPcdGMzNj9oJkHG@DBpg^A%Cnov&~4P?tWv(y+$au(^9gX-ZweLv za3nw5$$qT^kzZ8+Z&hq|lVuY(Wn(qx_HGG0 zK%ctPF!b9lQrpS>dKZS^_tYua(4~ZzBNrm}C)dfv=EHpES% zdi9I`n#q%p=q$+fFCf|CHB`#EZh-|p$m^a5kDpp;az`xgHE<@`9Ai2e6N%+~rtmJ1g@o8$qGxYt#>bze5MRSy zw^(6skTn+1H(@99zOmL57wDnfagzO@OexE4pwBFUE~0|A%oRhj+7hZ790}3(k;qR7 zCEMy1q#7!Ujmt~O^JE&SFJ4abLuNoZ^ERbA?<4Vx3#e#N2IKWS@Z8!T=lI_2no9va zo)L@%hq+fpJdfSCNkxo=8M$Tiv*SFTA+g}xl+8w@>}ZGSzLD_NyhK(WMr`JqY*bv6 zz!$%HTociRg$B=+hbO|hItp&jc|M_Z6)m_o$U48zLaSChi8KZvx8|_C zwUhK698q?uh@w}IBq5hvAo;AQlrF)v$j=2A*k6!8UDm1iX&KxDZdsA4ovL2ZGAhEdE|Q zn7`(G+NE!dn|=>S%#Z8H(b1fb7EbT#O=zaX7aA%$1T@+N-G6=v#G7Yg>x>A((AnWjk!y`;+mrElQ=Wt3+&=p8GZfAlUn#{{mh1BK zapp}nyuw5gaUmD;dLPs5<&Kc+mBLY%TwHg~<@{MI%oje)EWT^P-_Z+~oH-N0i_u$? zW|~?u3`Z^2(jNa%tQ>lo1Z97i>*h%`{MIU3m_3TRtW>e>j3+i;naO#3c68l-92xGG zQTA}_QXR8G}gDI`2cVkY_`!Fwj<+JFu-{zR=ewX_vw8`wJ z2MYQ|;6Yv#0x!U@4r%Zia4pN*t1-C#0lDLBw%jT?LI=;16SoRXe~n+0mP?m2_W zdoLwxss9)&*N4ZgU3B^o&%1}`lAPy!B(({#6ZxX>t@7a9+z{Ly9*ydIF1Rb&!nXBW zW7_@M=<`fwGt_Okr_q)sMT#-?a&4@t*1|ZsrOY*VG}N{_!PFoY2fm!5Grt^A6B>dq z?e}R`{5@JEsR5OX&#Ccz15xG{>fSC5vFY3o!k;sn+9xBZsh;Ln$Rgj@pC)gwr-xqR zq~|A1vD^zGVR4d9a@Jvzniq_Js6l?rMfz&3j@+0QmP~tD+1jrxW7P%9YmOvm=~vXf zbskB>59c>J&<(A4?2B@x59}l@y>OIfbDhI69Te9yAMF!o;Roj=h6eH-rSWLGlcRus z-!yCr_r=u5S=hSFl=jPSqBVh=NXx_&)$cj~H8_(b$IXVMP&|DuFQ;Covou^R7r*0# zSl7^mur|crnKVKI*J_4y=sic}j=*4rpW&2}X@Aqs~!;G@f0}jhDlGvvk;hw7~z)+BY8ek^9mnray8mTcmiK zstZqWZgd{2{W^q39C=C44_~5~zaQve%RI7|9L{+m8VD5HPu3~-nana58gf(z3MW?5 zti>bYcS8j(0!s%;4P%@(9fS1lB+hL}C!5Z3xE%YBOuJ&?#_zY9%cX4c<{%jSv%zH% zCt5kLncYi~qJW%LOfP>rs@A1DoHdohI%_q~46kDTap{=Tc#3@1j>lT7ENrkk!l;Nd zM!6Tzs)_3{m8v*CeZhv{yI1GE66$11}`gz!QzM*I(*&H?mYwIs@o|yI|Kz2p0gV6KMWK- z>rfNDhw`)U3Jk0y=<7^zDv7#FD^f~n)j$e;{x%JzpS4IWAsSJOZc?&e4z;f^NA4qA zZ0+x-&O_>u=$(P-q1Q>|rZU_ee=tQKW7v(<#4BSfvYg;a^Aat{OoaO>Bj2#sC&chZ zEDxVPJ*1S2%6K$MpL>_q(5D$~6trMDOXS{$7ccBkR5Aet)70s*(o@Qh=KP~BIri6o zIW?$slJU$;%-C>@#G2C3|Im?jg^VE~l_%^fKQ9e^`<}uExNdTE3;im(&cr_KrJcer z1?q|m$$a`y>@bai-aA+KY!qBgdG7|#6c)dgMA@Uwtna7`EHZR(X{ahX zvzIbqhXOja$%R&)(4(?qCxk0aMuJ2RHtn*-@_bV$CX6ScJRP(}^stta80e3ch6>fN zbPFq7E(<|%?+03$?n)7=UhKZf3bOKQrp=xg1@|6qBEyONtZ?3#UyI7*yn?^B1fG?D zlz};_kr*Gghz@0DVfkrg{QQ=W!bPzN7uG}a=?8R9cMig@&A@;E_;W2)4fBj2(2Ms5 z5E(g6pqe%WlcJp|cg!mCn83NhuZ~i;EAOzDkDwoAJ{U7y6pwn$AY&_qgrco1)@3qQ zy`BxlZQn_GSv~t8oQ8y3L(wfX6%sPhFt{)q%ak)Q=FE0xbcbMN;((5M0tfMv+%H!( z7WszyC=S!%8TlyapXYP#`Eb0*P2j9KNv_L`L(Pe!wEOK_dhQqsh1)(f+u|tCjoKnx z!4J(Qme?%pinDnS=fc=``^woy5>DzO# z^v@pZ=^c&Z!^E(q;2B+w{z>Eg;ev-F&o>;U104Zh3Xdw!|6#NGOda!&2u0B zuFhZvW3(|sy@HZ2iojX&4Si7BK>25cV40YTSF%5-XlEz_GK-kSZbuq_UY3GBS5rw- zCC#|~kCp^{CrEU@?SV{c{3nAja*+Vr%K-uS7YY^np2mT96#Clu1oZdk@yI8Z-^+hUK@@v-%`4t+}vnf;^)C>jJ{ zhGHjK`yCOep5%Gwxw9z$J!g{QyTcxSR(y3x9NE+CP$?&k3zFv8^j!q*bxO$M{RJ}= z>uK`QL6VlX#bg^<%<(?U9&dX?TlhVAC-e=yTi?N!&*nWfb9~_wQBD=sTPXkFbjYcu~--&-0R3Ewm9N*qzB&A7 z+bU?+uq4zKTw;&sZ72D@Zl-Bq1dIPvut<6=F4|c@`Vs%W=3e%*?Ni)KTmI;aZdSVZP?cFOtbJWGQ5%n z`O98Ze8e547Mn>U#|`@n#=x+ul;zj{WM1iKX~NPeIJi9yydd+os&8){{z!|#6 zCljHuZYLS5dEZ$PfE}JLBJyZ0zfb|NjoDORWbfcDD)qpC?hj z+CHjJ?x!7p=HSS1Gt><=z|S^oihi+`RMhX2skrh4*<>2hoa$k)$<58jCw~pb#?`YrHu( z%1!}Uy9$_-V*uTqEQBesv+(d-D|O3_LzMXeirF&(1{)GEUS&Qq=RcyGyw{>&j0TQ& z$D?WO7#wN$$Dnft-9chcPIPW@SO(l^!1vxF${m2WhoOAo%x__kQJo?Gwci^FhU{|q56rmtf!XA+;qaxz?;-#5`hg07Fgeu%l-VCj{Q zd9)_LA1g=#FLzAE*<%4TR`md_4VR{nX|~uUzLs9}%|g%xU)5I$StS&n?JD^Cr-gTa z?V`~ebCGd|Yp=ty&~Yye8eI37Svd!3Y!V917!qZZ@og2)Wt{9|awp6oJozDQUt@~b zZvNyE!+&oq_f)O*#(``;q~k`sDI4SE35A~9Bq9Bn ztxGiFtgJ$IKpdQfZix91W4~;mQ$tU% zra9d-*0i1N50l4~g;SAoN(q-&)>B3B6Nf%oVa|rz$YSsOphk0PTwD}G^-+gtNA504 z9=n0+)~b;|=c~qCE`&*75he-NTA5op8n6k9_C-!y6l? zd2xQC8k&z-kkAA#bUoL>%WFnxt?8zzvx^-z@qBV?*$Y}zXh4~NE;2_wX=L-g{w$w4 zxO*pvLV}a|=O|6nV(hTjhxgEYi$VFCIOw&w;H$$lf3F7I9&m;`mzjEGQc?KV0am;(!fB}<&%%ACBZtGtgZq`$mE~dM>CRJ|i9IjBy4|Cd{O}V-L{lqS>7NpH6RF93lT}0os1dhh}y% zT^@KRm^;IP&dpsybBc%4dv6_wx@G1_`}2xku{ug9N{oQr^C5eFdp~`VK1nA1b1~>M z7HbxIaVG6#SR8!F4n~^b>mz=yS4_l}bYm=P6DCjhtwc+W=ugT@_Im3+)@aQ$Vtnu6 zbW;q!zxd)^9-jj%I+^d7`xF*ZEx31pYcoZvFcRl_=4nRj^m5Umn?!b>Z}U#W#iXX3 zf!MEyXhV)PZfLjD*!Qbg-WmxqpX!g&KQGv=AMO~lqK`!6_Y2IDr|@o^ahO*xgux$? zQ2*tIuc!X96D6W-aP2a7d0P_BT+-k^rFoDHHHXriWn}Y!v(J2N@HsI7A|lg~U8Id` zGrK6`LKBH!vq#*{*A&}pht$R4IAm^wRlJW-bl3vO>j%?s{#i9R2OuEj5f!UAW6_W> zmit?q&)0kGx8CJkXC4FT9#SvRx^auN8->_~pzqW=@&j$ZqK^h9g1+vH)VQ^R20Y7| zrItO?CK=-0o8ioJWf5%{eVxtAK1yZodPvO+!>P7bz8@WmAIH|P_|1LvQaT+;ANkDY z^@O!OnF7E4p}4*>jLtpzL&h`7w_TQ2VTmUXPyy z$NC~l7d}ImKWZa|_xK(D#F-P$e5d(Vg%%xlrwK6=VKJ9?8?NSDo~I(vn3{^|Ym|7` z4)^d!@~&GeClZ(*ZcKV1RJU%BP>vR_deVB%xZX+lT<6Sg7w=`@-0^WZ&!PrHE?B%?N z0uK12MW} z)tePoP2=5Eyelfm84FHmVyWJ4+IlJi%ikE`eXKb72U*eiJG0Q3KMkFCZ;`73e?2^3 zJmSqg1M6MTyhjfva(@KR z?RHbbE-##rmtaqdFER!9c6!gb$fSCY7IckAk0tLt_!kbrh7@Xy3}Z1`s;Jy+jf1%= z5S8vB7d;2)2Jia+?B5VR;}1NOWL-n}tbeSIwvZ0OJCZSENFojGxJ_0+)G(qU8<9o! zbgx33HouL9q|y!2xO0pxQXPkqN7dBJJ+G^xz6mzP^6wb$v3MyNjCRhk5xw-74%c#a z)$5Nml=DTl?pLDsWvyfyR!=o~`KYaMX1|pGu&O8v2p>|z$K_+`V66`ZCaI!|&wL(k z10?rYg*qvd0xs%eQbajR_>s$*2KmrbZlLKogOvTp57*{!-s26LrV34`1+ZTGg$62r(}{Ok7~GpDaCs_(SS>}wp4CEOtUS+yPsQWdPpte~ z6l&h_JVp`skLB62x{*oP&i!N?PV+N|btmOCePwU<{3NphTPnLV9-$>CNm;dp=1T`U zZ00P%ExVNv++9y`d*(o2`WdV3nlHFlI3J&dlA%|pjy~?wTq2o*KPO^%$A=3>4lky~ zCyjA1Z7SJt&$~g96V8e%(r&ay;e`A#NQ z`<$j{1R?BHB~7->M$kYaGtc3DK6bAiw&{gIjOL5t&+2 z^rbNhR)w~dFfI+bhON}9JRH$OFVcEnXWFtlk2PG$z{9$38vb*DL^|feJC#8ris#3< zFGZkhNyUakVO64xmiPAf^D~=#e+A*yy&E*QRzu(^l7i103)vEdRC0_vMpc*BklH`4 zzi{7B*UoQrrdI_m|INX`lY0)AITP;g@yqO&sf0i)#2k*PI)df^3$m9G!>t#y+3JlC z>3+j!W)Y-<&(ij=`YnM|TQ<>%#CCF*G{u;^8nEhQSUV&Qe$L!uwf-P&RMlsDMp_}r zDiqH>Kay(ZFg)?sq|nl76wj&*x@hUWG(3|SG63Oflrwe}+QGkWX|=cL-Y4#$JaK6d%-Sz1ufJ3d035S_%i zj8Vxb6miGbq6mzN=5xs0y(H9!(@a%i-y)3P3W~p zV|8sl?{D&@O^XAuDL$P-&8DI9wH=z;AJE)V&ip9ioqSI$A@WKK0Xa!1emx8(i+|C# zCES~0@so-#SV6*a6t+B%$J{A4biXhM6=zH_BFvF9J7(i!pE}-`Z=^59!>;h88rR*bqpWc(1rVyUfw@eP;5BE1%blLc zP9%+_oWIYRk4+6V9g&0hIo>fAVu*rDQ~a2mj@h3}2z#gDqPQk}IKLw8#&+pPWAGAtgko8ZViCa%$>AE6AJ$b;i_g@Oc-LX_i>DLt8RkKV@f^KHHnnoee6+MU&c*Z(+6zMuDSWnZP0{yBJTrjA!L z3k3^PZnI%`<>^;SI?B!`!SqxpK3;Q$j0pGB^88uj-vkuwpM;hiVG8jGM%+l=5A5ZJ z+Ojhgrgw>6TFys$ksRy_@3QMF-q3u5888g@AqNi^8nCNp>5ELz@MS&aUQj~vj}WY= zvZSk~TvHa}?3KIjG~9c3 zn`t!N^IA^VMkt`S%p5gdH|d1A0=`@HP~6U6tT0Uqfpy$xz)DE8;Rh>}wMWZ{Vd&A* zg80`pw7mT%?K?Y!PH2h3DOwj-`V?_u)kr*R@e`C=@eZ{mlB{v?C^bmAQrS^J-dmRM zZHyt4%E?w;>^3XzESEx zFllrNa76YJDOV_>c2FHD9;->aYYW-)p7+nYRq@Yh8bl|G;swv5Jb#vnBdx2mIwzQt^b6X&YxDdRuh zkLfkk2*q7F2$Y+}d9r=HxiPR}^cqu3%2GkuyU z>*al(_bMj6<2+y4-AxTVyVth3gB5eG{5RJR)Oh~_$>ruz^uN)pa|dDayLj@-=iKYz z8j#!&!Jo&Qsh@j}tJ4+mXOabq9got%+1u!oej~fF;2w=Cf5-Anb>RB!8M%Z?lPk{^ zWDL5nm1gJY-WYFeDttjZYNp|6$Yf+LH3L&AA{9$3d_B2~nlHVim3wDFO}mDP&ai}a zr6IQ7$wf-&G?JS$96}#TXbInS^l;XQp`18G>UYtL_KURbc?vFF^2P?aU*x^~8Eq-{ zz*He|XqWAxwXTn-CC&y{AFGlp_ksPB6UMabxv0Crd*%Q8L*iGqQ~%Hyh`&~&ZGV+v zYgEm5s9{j)8=!IxO_aFQlINnAbaR+CvNvnN<)br(+yu$<%*>{#Mi8Ak$lS^%b9T8a zE{OPWriw4F9O0cb61Qka@rWbAUq^R0<+To&yxYLTTIwFKC zFD4;Ru!z!gGH_^d1Ukd!2ztHiseorNoc2cGqq7;(ZIzK%c$cm#ohQ+SR&=1`C)I89 zA&=d8c&zh)dr?%Wmiu)Y&hdPQbQ>K~7>%*!i^wp)gpP2nB&@qw@O_*TELPcKTfQy~ zo;%SFy<-%-F#+GC%rKy43)AFG7_5(k&`=eG>CMMCpD1jI$VSG1BF}8_IjXplgq9d$ z{i}FXU4KU_J+6}NC2KtWIvK+|H0U&+(^v5fb|&{yH%i^4E`BZa_Huu6SuoEinIX}` z8PO7dxz}q56zuJ2QHM5eow!N)O>>dopi4VEX2aQUJW9DIS}Bw>M%VJLmRbSEa8~yu z^)wuxmrrIfoISWG0nhwI+1Zt&;iI^ZEe%nD#pqY8n*WYWjRxlY?WEX24=lCHg|qn) zlKIDZ63@CQZG;sbj7Xpfxxq+njzd+@8XBd|b8KHzbI(kR|*eShm^zsJ&?J} zro9`E;N866yZM+vki z4M!l}UI)J)yJEP3DQz&>PU5G>L%<$W{Uv+o#PD7wZ(Zu%5e3a-Gw|8fALS22aCYQX zb~;;|dWE)8Rdg?%-fw_Re_>o%w2U;R0@0l_26L6TkH$C)>m5QN@oFjyR#QiWa5^f- zoAXZX$=EhZiftLINlDe*TO5=?9oD*Vxxqhs@0(=iQf2SV^{J&ENl?tufWx)X)XBY* z$MhQQ9J>lgJmLrU{CuS`$5*j)O{Zx5OLu(sJV52<$LMO19c8<<&=w6zR)z7nIJ#6& z&ppY5VcZv4Z;kg{7p?5!{!k}Hx~Zy(U9z%RoE?rMGNsgWZw78(n2+jaN%+2Y!$)0ncx%Ar(`_aNAe#O;7K3f;?S-V8s% z+3mU5FE2;srh%;HWEq!)5no$STzBUI9Gb(b1$lqHNstk zZ_Gts2x6W}p~gC@dEg(`Bw{6*MNZ|8>fkEi=RkH3gJE<~@MF)>CkD z4<%-KL(e)2WsA(%bHfID$)7P&#tv8)!!S5p1$UD=>E`w}$~-U@6)W;trm{cUo-JlA zn+>s!pILji@ZN5p8-jl!qFA-x2XV_(a7Nc1XO8;f#xBqlp*}YJq!hXjJ*0Kp67ff# zVO^92<^*uIxXT|_QD%tI<6ZC@wIpe^m9kg#Qf%=&bZF|+gXOW1=lQ%j2W-b*R%@l~{GKIiZGzg{ouy6@{e zKcC}xzwf^g-8V>HV8k?j%pA!51|9gvX)Q`JVt9GxZ=7Tj%V#B1BfBHd&hLtU=K@)# zVZtdBeqs2ymW;FurvBdLBCEp{*p|-1)uDNqu{4U^YEMCPdUrPYxE3)H#R?C}lc-#h zD)+!)IB2QE!^3~zn1;Th+`1>V<=ye-$T;cf(Ih(kk+Z)Y$es2JLyHnvzEOIG=MAEx zZUC?NO8=W|Fm=QV75T!5cLc>oc5d~d7i7Q(x`1G^L%5vVAX;Tii@dx zaNT?``ljv|?&X8zPO~@j##}(JfAU$mYcE}Y4LCSN@+904prZVr~B3aRYQsAF!%`FyvIbC}^D>9DYro%h5ntusC5hl#8xCh-0X8e2A8|wnb zAzRJ>ni-r!L<>2SlX;WCiex%`58$FB(`Y|<7;auJ7k`g#MY7yg_UYb&E1lzccWxkK zte+xa&NSvvJ&UNr8vHH4PsP_Y`0`&G2g~kPytzG6=LMs&WEYm|4Zw{*0vGq4h3#$w zLU9PA)C1A6Z9ENSH`?s(1x40Ge;R#uhjKY^udu?M^r*a}Mc} z^VIRfcIZ1?NA{;M$y8V`*~a!F)%7#(`sfqx-{BJHK(*AV4BPU*yi_yH_?-jm>Cb@e zU74291CCl6%snZe!_J``s@1_|sh*I%zNdJ1)d@F#{1QdI<*cSzQ@$#f&&0`t!uY#3 zEqhqAbLm`kPBfJn!c&-Y#gL0GwW3y1sTlG!gYVtFSSHyeSIOO?>@@dT8 zFasKoKjG@J3mDcOX!|IZc6F8bkUxb3qYT;US#xObl`equ$y723@@mu zw2Kp1qXe@fZMbu2PxO2v`K8qS`E`O}Vkmg6OzswXk$>|}ZGRKoH%E*#s- zn0ZI%VwKD&OpGs=M)UoWF^%1RprE<;l zu9#Ej$i!vK#fzmOEH82p1~by=2YEj`(2DtTpIRa3#*ebT!nmQ{d2f^(r|Wda?PqPV zu*XrpEWBpMlk!0rHNi!gXiH^YzO4qTWuKw+xclbfPzQ-|fs}8Pn0N`V!*uU!dc{SoW&- z<(29ndaaOq{Zchr$(YW%?0)-hatqClv$mTY~yQDP`}z`bNA@?XruveWXJ%!}f#(~n?M)|sJ~+R&h( zh2q>EHR|;6lm3D6F7r+pz<2*cSeUh=qux~JWu3;BMP)EjX-|*&F6>h4iO5;KBx7wD zx7D=a)uEo$^fG7XtT)K~AZO&MnXF0DMy%@)nBI|Ii?Wu?X|99NvhBEC+LrTbUtruE zePpWDi}@Cp;Us-&K82>ZxXc2t7uv`feJL*W+J)2FK`@!uihJvNa;oQMh_}hSq5B-W z>g60KxG#O|FClAF6c>$&fd69!)H;+X4yJWuPubD_Gx{{vAJu35mR5Atz9V#JN-t5z zGs5WUASN%63^dDb)EqY&RlhH@uJZU{o zaY;WP)s`CU;E;)|rfCueau`>PhjBsJ8|!Cve)oANV1e2dibTqiH|s@4aU!PCQ9M zj`U}$808@1xVH$hmw9{X%P_cp5kGVW@JU@5N32fcz=W&Fd0L6k4hsJH7tWyePvQ2^ zfQFshE4u${%}eeVaHn=6hSaVAa%_3veGz=Sx8=T1KAa`@z45d6pvZX}B7cvhhIUKF zW_4tvnY_RDyNB7zsn9A-;aJIV9a<=PwpDk5W$yeT`=gKd*m0j^vM#ff8K5f$l5y+B z&OKr<%RiXIckfeJoEgQ^+}SAfZNrLI(#vcm*_sCqz&OW|HS)fyX4jmHo?CNAbq&78 zo3Y~JT`~Q1E8(0Yd3|(XM71MJw)%^%Yqg8pYH#JXwy{*hg=P2oQq5-Sr?$2So8~e^Y zg3O#v*s~**JN4?&`$k_r-?IZrYt-oTJB~j#{}rR`t8m}%CFYw|i!hmM>$k~*1J{_eo2xgo64HDPRVEH+e_$jsGetjyko zqPW@O!q+!&&XS(ri&}i(wnU6q4v|a?$s5ueffSj6jdL zZNufwM(jRq5}Hn3r#L-3mC>>n-`>!bDGTHoPkR#U?Q31k?u_DW{tywS%GkJIo``uD z#?ZPy=zA`MKXe}>yxTXIUAbep|H2j&)#oDaAXuP(PMqnp7B^kjOU}?jD2;aI(XM*D zoir3nqjO>Q%7X7kN+xIWMocI&Vojnx^TwueLAVxjcE)n$g|QqdeJpV@>-l|TIJ4sG z(Khy;I9b$^Ms6PwZRJjE*Ol)Pcj|tOM3PA4j*W7LKhcXL23f)O_a_)h22e!SL*aAJ z1C8_Bvq1hITnG$f(ee&*##4jQgFoY*_Cw@b&4GWmFsgh8Q?R%vsDds+gtn|en>vjye zT>_sbGZX6!4KRI;vxwJLrSlRc&X(Pv)H_;S87c1;4Snh5 zysVuvu$-UjAlk;jvZeE)M=N2CPTG3N{Ji8b#dW{n*CF#6)I*$MId3!f~ zji*(9(na!8L_F=yUuV`JWrrDDWhd&?dll)>nFVi||I7N;9(%(bIQ4@&Cq`bwnxUh3 zZjf}UTgo2VK)Fk~VZ=@2KB44$xkxCNU5dByh_@NT3m4YAbPbo@ffG$&*{o8GR*%58 zR?djH=|rWya_@Y_U;b>vIqljo;rB6wyJW_pM6%2;rA%S=OHCFW(_zZf71Fu;7`LA` zpz9r1+BL*5_UA`@e&xY6k>eS+y#+_V>`cq%Uy*Gp*>3B|Pdle_e?>Ii<>zLHS}ck; z$n0QOCAP}!PFwjqE=XP@{IAaykCqPSb-g(l{LqY=8wF#XJMpIK2;Ay4TO@5*kJtB( ziP!k*S&d-8p|<(O7e6DT;0yLHSxfuD=OmX_h(4 zx{l)2`FAnv*D5@y3WQlyG|x3l!0(^RtX5WKANw65FFFnjLSp!A#%(O0{27%y7Ast? zddl7ZN_5>QXR#ZL;Qj6eS`T;ROuy3@m2nEjb+@r$=@pTtdH^L>p0ql+5QV#B1~=1) z5$nCg=V>W)+jU#~JlqX)JcsbY4sYp5b(JpZM6_u&hS3QpFtoH6{eN}lbKPLH?sHu- zeg@-p(`a1SDmzDR(!Ht_z~!BSy^3a{8l>dR=aV)NO}H#6-0w|vlWw5 zd$Lh8irKy7y@-Qzbllkc7fbeLGWMMzr{DS}8DWAJQ}kS3 z$2hY@_HnCU6o~p$Gof^&iz2gWOPHss!Y67A;_k@&k?cTKWu;1Xz<#kdG)(e>nowi@ zKhd@BIJU~?ac6T!7bA<&r z%u(j#Uq+nWC_Sx1r!#x+BKR*kgi9wR_oxf0;XR#7@dG$$=~OCm!Z4_foMT4}zg8pXs1lpV z=e*av$D(q+4xJx4@bvvC7|U+^s}<4gzi$J|4h~^%WlQF|UxLzdA6~09r}LSoc(go| zZl;%U(CIRMo^oaL)5cT}a_5ZlkJwPM5pOcv(!AX`)HIF4X_s%(8 zuoo#A@nU_Gj?~oM0w`cx$+$-KpPMfn61 zJU&{Dp#{F2lI?@q^tp<)y2^Cjr^}!#7Tl36(3zRMo~6&fn}$f%Ln5;!=zCC$1DHQHg~_JRF#2K~PfkhV zyNt8wUonVbvd@rHGDs-he1Mn!@}03&nMElrXl!u@F$1LsZ|9z7(6uacap8B;j3 zO6EhBw_{MVw`gg00+xTAdF1aRe0G$1vW?Z~vv4ST{=I?Gk6aa5CU=VDfv zE1xKZ<9^z3j{Fw{�gHot=or)l=ANax8K#N!Du%HD+$~VVsdX&r}s7&3X=UjjY*Z z?HhR89EI0cN9=rT&Xo^!sMnZ?W9nnLwtP4r^zFdoujCzZaAKK7*eVWGB34 zf=C{q4C7)A+QsVf!t8N8n-NB<#!Yx6=iL)GP2_uPKaM}uPqezd7)f>Kk^V^?)Ak+^ zUc+XI#$ier(yJCLZ!8$b;OD1pO)c3@i8p%&&8*zf4KbM9df()Vcm(*eDI(x zgB@O>Te~rw+G_}d{!PQ{Rz3M8CL8DEJ9x_&PhJm6;pNX7+-RV|-I<+f5@y3wO5+$i za4IeDOBQ0u43UuK$AQ;gqh!lnw7fP8SN=*Dn8qAY{_ij@e39qOE4>l?wiSm+-qhaE zR&4t^gWKfXd-R!(IP~!g60;;fwUhJ`KKcmnB-!`=uK{Ztj*Lh7&}>=(LWyJ^Yd%+p3#QW=QqQhbCE1GX;AF`sL#o% zy%?)-hj{io-eRFWE&nc+%tH7+QbH#y>Cj9513iTI8I2%5i{x7<7 z|5QWlx!I1lPplQa4@#$eWmB|I2&3zVSROh=s{2YVl&%H4-?{v{#m&&yL#>{GS;|J;}kT?anN6_GiCtmyuiG$l3ra8ul}iyzW7w%g%F9>+VhC z7z0{)NMCe(bF7xGEcKdicoN~wia~qC>TNkfM>0a56cpp*F3D%U*_vxIm!av;oj7b8 z$DPe0V4N(V>E$3!t$v32i{(DG>J(bbp2l<^*=KY-iE$dukg6|T2a>PbUv~GG8}5f` zVhAH`4fw2O79Nzwqr06NM?aETz81?wKtw;5T-${|F^Q}((qO}O4_=;Bh|)?MIt`l2 ztFew)^Kc*bnq0=*l6Oe%{!ql_spH9^WUkMi2VS@&j`W$%kee$+#J2z*ypV}@)%H}4 zJAf~yJrxzZLfG6tP0mAQ@7?w><|wzvww4>QcvC+fSM%rke0gX5ZO{7~#xwIv7}wv} z0jp=j*zbUJl?FSpYV8s+t&a-xZ!E^3-I5&@@=kQTFonUs6=*-zl5P{@@t@}%MSs5* zDEU|k9eE!6e&d@<;ew;MQ!t4?RXQO)q*f%X?#3~DOK`Z|SUx>4LV6rFV#*l}e(U7T z^?Apz)Oi|x!-sJq?_;WC5-uK={N1J}QTU)4%FLRu!1^kRWfsGwTL;E;ap9krB^cRk zC1z^WiKR|Ap&F{-^M&55eLM@R>ax)o-B#|nO?YMZY|Kl_hNaxCFHjxB&@P=g=e;R! zjTp&63zE3<%rxd~=*xS;3NL5Ji}pYMp>E$%%!@yS_lw72&$@b-^D_fz8)%2`tN-BK zh_N)&9l%ko&q1ZW84t+(UYgMx(H7HaDWBzCdE;pB(46|Y0o?x8l;8BT@wD-c=vO8A z!oMGgF8X66Z_$oQA!E5#^R|dF*X7F+OOCTuXX46pFj4*vj{pN|ZS<49=C9C5iK1g0Gsbrvd4HWP&shE2vh=d_ zX>QwrQ^)-=O7aOCE2U5Sm>PSJjmOK-^SHMyl&>cl@UCt>I-R&|`v**a{wPIMl?77yQP|WIP#|3*vV~SHa+pPbC*cAs5ouw*Td^wFjhfHXf z^as6rjAd3*H-3(j`A>(%ID1U~oySYi;_wjO*>^^a)$btAYpU~oajCes#{dxsVB*$H zXk|C&snjE4aN&4yVt{1i%=BR8CQn9ONa2Wm!}C~eC5HDz)Jqkr7&YV1uQyN>Aopw8$6z>N435j(WZ3eTVqzZ)CWOs{nvx}2 zcazy{!yjVHk!ZN8yo1paS8Ol;g=ie5k3(EHUqSX_99t9u7vT;~9mQI*Vy>%Ya)oe>2wVxtEL!ZHNR`;=q^oH+`;#a3&win*?ad-hA zdwmvU-xNmcyD_5KW8s_fP}IBoFzNkq{9V-+X+7Q8H}nhWg?9r?FSFMfp@Q}rIp^Wzar*Qm#WXa|@N*e6oXG)l*h8*Qfl zM$nb9qH;?#cMV^P*)RM#!2cz_WRIk6y)}#ScZfMrqgX5N8x!7*WM#%BRLZl<)dP1i zYH}aqKX07wS&7YsdI;Jq_wo?|a7zl|{A)|F&HAqJ3h?D|*RSHy{bY>sIgTQ$Zzz|H zr*`EPqIP~Xr}mNlrN(CTaCG96QxP08J`j`Dl*EXrTpaqL#d z7c!X$Bdki6`O`_3Sh`Pe*+?%Ajo2siW=SY)pTIEBW-K#m2fsxF*srC`$xRO9?{x$4 zp!0aPnm!Xdhq<${@)I&Hk3@M$1CBqNi|b~`Fh;K@#j+`IyP6>RY_ccm@j+xC&g8Y* zVH^_`iq)E?d_G#fgZIxsnBz@_ll=D_o@~Q(wQXW!;}K!7HkDns1h7S~beXRrOUkUo zqR>MysNRjW`F0Hdl+LYpw?L8d5KieWIZyHi{@MF6bFPla`ca|KmOZ!kGIKv7HG$9V z5~&%|hdIV!yc{XLN^NUkUB=EG}t*W>a*KxkQBvFXk(*yb6SYhx9T<8zIheI&CHEYwP--a&^FT;Pb|J5V)Ee>vz zzQB=##5fP>FifjN%_)Dz9`azf*RtR2DSd_a+Vf28GO_8bD|Q|;Wr6gQEvj%sfn6C~ z4O-E;#|G#ZpLLmVf3J9*mxh|6Z+O1bLhh^NePZ)gC=U!qn|=H7tO6`tna;0QEf^P_ z$q{Yx@nx5b*|L7 z8ie8Ne9V*oj{X|%MjLWD2MhP&GvG?aKHN-_}eUZ>HcCmJ2^!%Vd)zjb$pI3%GqeO zzB6T1iF5vFvy<}^=*Y~!@y}G0`I)fo*aY+%{m-|lm;OmYSyuW5LPfWdn@F7vC{x^(n zbHAX;s!(=Ke}PUFBbYknD;hUzP-C|)i*@WcHQkTfZ?~m^=~EH&?;R#RRUqa;SJrio z6+V3m74H*|DcVSQ(92vy4IVr5^lqy8!Dq@58O| zuG|q)gPA^u;TNzUBgXH?Pg7lfJ6nUAz3mxn?8khOE#{>bz;K&nTP(kfj2dP3jl3xq zX1H<2^zAUxa==kdnQKdy=8n&_K&Tfm~Y#0I6BZB8;cGhgcbyyDULU*-t6eLRT z&8N3`Um3s)&vs*deh|-8F2MS(CT8!RtC%xVh0R>-xTSZsNK}vG@L8)+TRM$r zCO$?#y&8D^m2xn-lAkcEbzwV~If@1} z<^7^92(Fxg&YAvne7+J6Yx;5Sq>Fet&G9hg}h&tDNS{1H2mGn^No_KGD`UO$F& z7jyi0Hb@b6TN70~PGg7ZH#y&b4D0RNQMb1#7Clzs<6X||k{W^#2kDL&=YjA~((f$y zV1d0YxW94=pL8G0Km#ovm{aa@S8hFTHVjdF%1(-d(gwNM*1}b9$Jl@x4cP^f+24s@rT=+{m8BUG2h=!!4r8=Ooq* zUnSbv#?#oxR&-c+3D-uXaAa{X1NRT))5uONK6MN8nk*8&gDv^~@I#k#3)?ZFMJOvj zp2y@P`O?)TGh^d^!ZK(BbQi|4RP8B_MGWVX-Oh>+i&`RKPAC)3-a~ql8VZ~>S=2Fs z77wM%w|5tAxj7!sVIu`+{@4*Ul^Tv#+?nr=Liyd9X0%)+Z&so1*oTr!N6ySM<4*@` zj-At+$w$;!7|@1W)9>NMxN-bYVaBmLPN3O+UlxyEBHzWn7`99L;1?N4uWALtC67|0 zqwN34Zc)jJzvAi1V4mL68g_-|800>LFS|yIHTD6F2$`tx6aUal&ggwN{lWV9Y{WJT zWy+L)*x7vtx}JE7d!}A&ykbeK=Y|}V@57BrUZP!A2A|z=$AeK0e4gjas-gcOMs@;H zKAb=k$>#4K--FNF%)>D014&+1DW+Wcg;=o+M#sY0;L?(XJ02ju+8xctr*eGAH>}Rw zk2z&sxI|CR{nZntLr!Lbg9fqO{RRSiYNFS#J@B%BdcL}cq|NuSf{4AaFJNNZVLNtlrsdaV>ov(18akiW1Z|TEXe%;{bwJ={o{e$TQyVhSu>R7 zMm2CMNuW&+@OdyD($4}VZ{nq&CLOfiDo#wgg0){~U{J&m_V}SoQ@c^@GGP`TD@%W; z+CDMmPl@=N?!d^?UbuK>y7X@SaEUeymy7`LgOvx$p3IUQiw_vIEtYfcE`Z_eF3?_U zD=c%YurY$0Ze@6EBo7u`NVHDQAPT-k9O_UW5XW*3)JSzK8#_Fm}lZ@K= zpN!biJ6U`gmdrm^UHf9U_{%CC#}VWwjOwWdqoP>vFMeH*2?KPiC|ZX6XZhTdX& z8!H-o)Wnr@mr>!`o{kAd{AVfqZtL!1=iqVFpJL1JaSbTAeHF_twXihvtS#TS{0+j&!+B3APEf4^4*UJs{kh2R(?RT{Wy zQrSVWM&E}}krm3Qvu!!Kz8pQB(ztwKOCF1p{>+6V;MT{P&mUY78xPgu<5vgPt&zUi zQ+CuG(1g{JxlvuK&&4gl9d)+Mk#pqplg6`J?)P?|Yr_9luEI0#p49K#n~?=>a$41s z^KSwBGpa@Gac{Q2CK>5|pJ45w&ZHYE*c0Z-?c?)B{_O+~5BPvSzb`{IX(rCfU9pOf z4vZ#vV4vYZh3+(4ZoW|hJ(adX@1pGe+MI!+lk}G7$o{KQ5L+S*Q7e06fy}Mly|qc~ zKUR*T$2!u>#+7pl7YIL_k&^3hUNT@WBHyVuh8dUO$O|>*{uzgo(XBw|ABb*VCnBs2 z8S`Tkwl6EdmBrdzG*lfs|8}5eU?kU8htND&cAy4#ruC#@V!!;pOJ0%2r^8yop(vJn z??j^YWi^-t%@yB2-4^k&@%$>!u74LMv%xEhC4Y|Mr_y-7|79iynf#Ppb935xj)O<| zSa@u*W^ps=3h$MReVyE>Cwc7l7O9-#CGWlO##8B_JG*P`hf>ql>|z!S&om=#Le2(pZQU^5;%V=|)Q%b6$`;w_)W~*r@2vc~ya&t+5rQ z#ooN9lPBltuHwaBcVM%DSdpJWt*N!>)_4ww^oo(KY9li$KM`4&1)o*h@%fB1=g#kr zUhxZ&IX9G=>Erl*(k7Uz`cP?P4WjP^NuQHBI)u(b`AKy;SH$A-m>S&L?}tv0rcy=M zo&h)CW9*G@@T&H9X|W=Kx7M_w)7T~0ueS}m2P9F)uO(}9yGobLbMZSci23sTbzodQ zoSKL7S!H(?_N|q>T2sX9gz~$8ojC2M&bE@P;b{H@`m$@i`0{m^ttWoqV5|#n{h5Xr zZsTxPcJgv>YtZ1wQ|y+zlk}60T#+g7?k=C95N0^q+lFU0^x@)pIk?m%f^Ets2|KlP z967c@yw;j7&-=TDO7BQax~z(J7f<5g16A6b`XQcwX+)Uplm7OTySAEeJ{y^Wg^!Ez zXzWy>J7T4xvz!GxkM_s%`$^Ji5X{f3R9SF3hKdD+SeBN+{8UM%5LyNEN{RGEb@vfV zhb6IPb|@A|cjVl}cAP0JsTyEP1G#G`s0iXtjZm(fr-{d_Gm-FC54RutitA~SO!I8O z3;PVFXuC5ev^O;ic4KyRI<IUB&%vH{u~U!UhC$}zCDpc`xN8c zlsX*L7$Lbk{?sbGjK^}U=)vRG-dtG+ju<_N3A?SC_(zKq$N8f~H5#6_rfAUVz>8mA zW4TSGV;A2<$o2ko z^4O+`Fs)FS?i|mMH14K6cKNLO6u?Brr>Dp_x|B|&T~J{Ip;agIcO65E};D0 z6g>@05bnYOkG#gLV&5$~H~e;#)qiiXpM4duWW`Nar!`gq-E?o78Duionn1=!)F&&Cv-Z55Z+@r)8=IH1}sYXDz zEzj?adXkCjLAea*!YiF&zrr0* zTmL2lBF-Qz-2Bsdl`c2SuW8ZI!Q;hv!O46aJuT+_1^k))A!PJcie12(1)LJ%=lw7E zXY#uD1JlURc8~f+R7IEJ(EQ-TZMEBmE*w_xAw5h0knUII^Zju7cV`7ifv3{(Fo(Y}cc5+kz0$WM4Kx&=9iKCgpCZON*lI9Zt?RZ2 z=3Fu!wLv;;brp>W{!)vk2PBO*i1wODbD5&={qFLvc?T`k!=yspaIX;T+iDewv`86d zA7g&H!07zx_4=#V*t;Cr?NtZo(WJ2A@Ex`GZ<8b^dD+zgdY2|E^@5CZ*EGaYA;q9w znZ}^yKV?RDMQ1a8{C3Q%0)gDNN@RV7@gF&qo5H=~W`&!1p|CkphZ6(oY^BA?x^`&g z=bXGLJ^&Q(h9V>{c7?LVBSzg~0Kht-{Ixh0Aj3Rx=*QKC%O89JcFNCpcSH`qsn=6& zIZ_s`X4Wa8&Eq(}>kWVu+L`w?nq3}v6@3s=_Yj}sT{=OIuu|?!P)cU=d09PITW8+T z{U%H!`7Niz7?TJwXr;V{OOBz+e@X+hIcZ!-J6E<6!fkvZ zXUUDUwMK7^uhS-KV3N(X3pdS&&Ev+amjEq)$ku=A6XF>juai?%yjWOR z6Kyiu*5Gha&XF;2}iE z<{zkB8#~Uurk=zCALhY`D6hu;x`O!Qih(gCx6$6MRh+3I-87DcQYGKB613-jk;VBg zOiG-ICf{ivENM7jHR&L`nmI;A0sJy3@3JXxZk|mSa6$y~u5xE>Xyg0fB8d)-B;Rl- zCpG6wGOL=h;Jjh-v<{=*&adOd_)K;~!V}Q-x>`0juVxA91LzEi`uXUs ztiyUmK;)75Zs`Bu^5tFVw}2ZNXmbEa6;w>oJZ7HUidi*#POliN^Jw**QB*IafC7Ew zf$1AOcP7mLwE=@%mgNwQTTc$?)d_qPi-LaGyX(s8wHp=~NERLaHL{lM{N^5V!*1-yO}UkY-V4i+R5@_n zw5!KFhD8stea+mn=hd>h>Kamm0NjF|~yqP#jLVKVai zY_PIMRjU~`h&(bE+O8F;7d1%XIf7k}P0Eqma_Z~nBFSb$?9+ADyYXOp@!_mB( z=*8Jtt*7PgN1li{jIy_`SV9r{X@Xr|%6XY+8+t5nw=c^`arjPpif2v>iJY08YFSmo z5FG4!4|o&`b838C>YJsQOvY61i?S}x2~p0fG0Cchl<6>Pbe}wRjzWvo(V*_6`)y4) z2E)PYkQ{BYQFaNj7N@QL$=!BWmUT72b~x$szsC3JiH1>v3g%m6LY>%oJfnZd@cgtqa08>mivB+g|md?2%G zqF6OYd~<2sEoof#I(MyIo$$y8MqA@o=cH|kAM7T3vT-Zg@@=(E*TboP){e}N<0H4u zYE>KhJYrb4()-#<(oXl-;@iPxYKjwYb`#ri)it8ef%Sri%}n!5Egu*gpZENc7qV8N zj$}fe$(qQy7Z9~lvk({=LE`~I;k+P8^Is|{(zxNFaFe=aUm!z<9+huWyD?%*p>_tQ zoifmlCjekr$FyvzdFLBn&x)0YA&G5x+1=u-jq_*?z?h=^VraKM$0Hrgq2O0uO&Fot~8D>wUAsaARX0OW6 zB;hZzj2E4;A73d9tsTY1)Dgj@<@Db`7rHHB;?+fp%B}M1)$;T*PC?8r0BEu}9IC1S zF%}erx&Gt~I}rPVE39@ zp9|>yAQq*bJn$;kRAiSBn64dcQr5fZzy#lL6kT($9LIuYg_y$*r`)48Q;=J9h6>KIWSJi@ERs~=GPZ|}pnCKd!ac0#dEy>SZa|k~FIZPha1WjQz#MTd^ zdu4Ufg=PCXeGBvi8L^$|83@IwgpDxF$uE5)9e1)6{fuubmYUG)S029)CFJDEi3?)G zLH@?WkHHcmm%YmDR#NHF!}_l+pakQo=&3OC!1nI5cS=maNnt)N7H{JITmx1r%8!p| zLa4zlvAsuYnFH49uCf`~Lyd3?B-?$hMl!yIkAEm_LHs9WH30Bun|@HlgB1SQT@#mi zVXOhZ=g%vHAp_F4amldRgc+z}Ou4J45fhc*dxij~x!3SH@Nt)uQ^nzd&31ZXSe{|a z@`Vl2C~@l|0N^d^!G|^x8F}{75=h&EEf`JWJo@s_iSnH(Ha`CT&<|83tO?cn9;@Rh zY}>ug$u0BjdH`P0QI|S0)I=oHxTS~b6{rX?n`xSY%&}?nhiU5c6JOVp(Y;SHE5fGs z;b}#$kW`@0xCQ1{JcbUrHOy}yz1&zo>P)ex;8_Q;6TpUyMxGYRs|C0l2Ut`+LDQ#! z%a%(-F|S}|(0+mp)1rnE@brx26@k=)BRvn^1V(H}?XUy5@>2xOc6L34X$eOUe}yr2 z&)%!g3P6H&Y*>JTaDEGN&)F;LFgK3IzD^ISK3J_8F><&fg=HOOViF z>t_#s(+O>!d2V6@<#Iz`5^MjVfVT_0rqgzfMH`LXLw%xW8y%sNI|59P2Nl{bnQXfh zwuP`prOX<&u;fR!Q4j8*2j=z#EIu_A{ zmy}m>&$RAK^WLC#K4Zt{rD+@9--{C6?!8mz`idD-Gix>LMu{^-<81X<84kBmH3Q`# z!%p^6N1Jo{9Mh5o6NQZKhnk{{fvbR_3H*xTO0qEm^P zT6?Ni=f5_h3|W<*e6Uv1R_foSIWR~`ZvE8Y=mBnOYq)4}-KVjkX-TLmKS_xQ0u@tC z*n2ZTgOj3Fvc1Ea3!#Qj)k1D@X7y*M^tRnjoh3u3x9_xM*fkSpQwFD0hc*e{5WniY z;MDTX#4CerWPOxQxRq+1+%4;rdC=99VpUi+S6B+msrJydtmzGZmeBQ-^TiG?X6@AW z+KD_xIj{amu&vFJ@#DP$Z^LlZYa<5=k&5v({laFJ=0*Wxpt$}*X-kBind(?qX^#5s zlx6LzqX{XqiPX8shmA8;K~q2+nvl$@&Vgr$os}0xjTe*W$aZm>j&6&Mv?-*O?YfDa zt{RI-xFUC%c%IbI{Cr^qy6mdnTLaY(X)`IF19ZI2t2ql2>A(M7de|)9dUcB@6N+IV zGiw@sje#>iiZfg=E|(u?c=~2}$J|lmJ~F!BxF{%dqRKBgHA6L?77*cBBZ%s+%h7uiE!tSgr&Z(HkQ$0eWt>T9!a?-yxG+m-;)frrKu+agDc?aG+$&o>tf78ngk_Xkrf6Ok=@1QoFovB-8nGhmAwb%m?^~} zUp);9lUg6CN|sZz2VNS1%+dPULBTNcWhRwviY~&la4K!G4-=O@se8x2ZpDLN7uKVe zE=Y~#x!@41O53{Nt^GrT9i<8y_lr>8H+eBZB`fkmnPS{+rq)Y=maG!wBkm#5bbNti zsM&vRCyC&Z(QJdcGIqv$(Qsfac=*%e;VzE8a}m_UU*1|ia8#4LwXHCSVZk+NkeI+T z47z9uy1B~Cv#)AlbMSrdajUsTHQgof%4ZODktnmaAs?yI)}n@60z%UZIL=#g{1$~y zMNeH-{|=@0JM>8%DT{WN7$yjSV*O5>p}*itLoDwbd9Nh-O40j-RwSC8#ti&=_dVK zQFp}iY8F)I#|zg(>r*~~8Wei8j{DLemSv@lIIDkc04PsZT*1e14MWkEN3Ks2k{!Kx z=V{^M11c5#g*~~ zs0W+YoQH30FFvQaT&grg{a#R4RKhAsQlH1^p6K9E{~>17X0Y(cz4+LKo!Et?Xr~Hn zoOcuNfJnJ$MpA6+ymseYPic678MvimQtcrBax3H-_#{ui#MZd1ZXstm(}J{!C(wvo znsyFTlRM1TvMtxH^@@XNu&a)|Uk?q!$avGig*al1x335LL3&T(}V1Jc@xLcN4##uf!tq3z1XjDmzcg%JW z%BcJCn1N*-L|C3-}UY}!%-s`G|AB3g}Z<*Td`YDk2pTb ziZvT4nSF`wgJUM19F~aRo#&7GEoyiV)94yDJtw0-iKB_XmKM_LlcN*3p6>ReLlS*e zkrwL$IAW`KIWQ88g$qoT3HgW{MZ8tS4ob45YL>zR_iLbxF=7()ei+=pWM;L9@$`8& zdLV4r2pDd#GgU|$nX(0?L3=!6z%(ECb$uOzEi?sxV`PG3A!TtY-_fqj^;qI>k$$Ie ziEq$D%)8`lHG@(!F*^JD?Ndz@AwjjuQL;u5aN1kG7D<4mmkP=Mmt=hnfA_P=6v4jobvTf!xXJ-vdOCllzhC6;4VGJEq_ba+4=c*z1{yc|{Ih-=?K)p6PV} zDOqL_<7@6?e%X`P3YRF^%k#@KAkJcujqw12wQ^w$vP+1p?yD9jwr>Wok(ujUD4$cN zO)>f`2)@0z8FE$Y+`pDaDEfEq@pps&+~&}1)xH#OR!vumHJ zJ?d|YtLc=f9|{)_hm=evWV=f%jCD#9HCJTo#XqTi=iMn)2ayWY83fGF2kup~2JmT< zKc!Zv)1#7`AZ-zp&b|jsCQoos&*8_U6}1wr2&pv!T!cl-AxK7iqxl^}BU4$oZ@a2o z)6hX9`Vyg@EBe?AXW=a0k8wGFD`pLm1#Xw1VH*`O>i4MhZAOB+7LV}H Y;)`|XfMGZzt68$5_u|_(;t|0A0aCctFaQ7m literal 0 HcmV?d00001 From 2e40acdf21d125b61737fd882080f13c5cde08bf Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 11 Jun 2026 08:06:32 +0200 Subject: [PATCH 53/70] Don't need this. --- tests/ad/reducebyindexminmax7.fut | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/ad/reducebyindexminmax7.fut b/tests/ad/reducebyindexminmax7.fut index 574ba3297d..c3cd646e48 100644 --- a/tests/ad/reducebyindexminmax7.fut +++ b/tests/ad/reducebyindexminmax7.fut @@ -21,9 +21,3 @@ entry fwd_vec [n] [m] [k] (is: [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = let seeds = tabulate n (\i -> replicate n (replicate k 0) with [i] = replicate k 1) in jvp_vec (primal is dst) vs seeds |> transpose - -def approx_eql (rel_tol: f32) (a: f32) (b: f32) : bool = - let diff = f32.abs (a - b) - let scale = f32.max (f32.abs a) (f32.abs b) - let abs_tol = 100.0 * f32.epsilon * scale - in diff <= f32.max abs_tol (rel_tol * scale) From 27bd8a3a7c1b18afe14a1ba663e6f1d78f404fcd Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Thu, 11 Jun 2026 09:25:40 +0200 Subject: [PATCH 54/70] No ISPC for this one. --- tests/ad/reducebyindexminmax7.fut | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/ad/reducebyindexminmax7.fut b/tests/ad/reducebyindexminmax7.fut index c3cd646e48..ece4ac1c5c 100644 --- a/tests/ad/reducebyindexminmax7.fut +++ b/tests/ad/reducebyindexminmax7.fut @@ -1,4 +1,5 @@ -- == +-- tags { no_ispc } -- entry: rev fwd rev_vec fwd_vec -- compiled input @ reducebyindexminmax7.in output @ reducebyindexminmax7.out.gz From 1d441f9ccfe671aee2b5a0ca934ae503a8d7afae Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 12 Jun 2026 13:04:27 +0200 Subject: [PATCH 55/70] Remove duplicate tests. --- tests/ad/arr0.fut | 23 ++++--- tests/ad/arr1.fut | 20 +++--- tests/ad/concat0.fut | 23 +++---- tests/ad/custom/radixsort.fut | 9 +-- tests/ad/custom/toplevel.fut | 5 +- tests/ad/{vec => }/hist_add.fut | 0 tests/ad/{vec => }/hist_complex.fut | 0 tests/ad/{vec => }/hist_minmax.fut | 0 tests/ad/{vec => }/hist_mul.fut | 0 tests/ad/{vec => }/index.fut | 1 + tests/ad/map0.fut | 22 +++++-- tests/ad/map1.fut | 21 +++++-- tests/ad/map2.fut | 16 ++--- tests/ad/map3.fut | 14 +++-- tests/ad/map4.fut | 31 +++++----- tests/ad/reduce0.fut | 25 +++++--- tests/ad/reduce1.fut | 75 ++++++++++++++++++----- tests/ad/{vec/reduce2.fut => reduce3.fut} | 0 tests/ad/reducebyindexminmax7.fut | 2 +- tests/ad/replicate0.fut | 22 ++++--- tests/ad/reshape0.fut | 12 ++-- tests/ad/{vec => }/transpose.fut | 1 + tests/ad/vec/README.md | 4 -- tests/ad/vec/arr0.fut | 17 ----- tests/ad/vec/arr1.fut | 15 ----- tests/ad/vec/concat.fut | 15 ----- tests/ad/vec/gather0.fut | 23 ------- tests/ad/vec/map0.fut | 20 ------ tests/ad/vec/map1.fut | 17 ----- tests/ad/vec/map2.fut | 21 ------- tests/ad/vec/map3.fut | 18 ------ tests/ad/vec/map4.fut | 37 ----------- tests/ad/vec/map5.fut | 26 -------- tests/ad/vec/primfun.fut | 18 ------ tests/ad/vec/reduce0.fut | 21 ------- tests/ad/vec/reduce1.fut | 16 ----- tests/ad/vec/reduce3.fut | 71 --------------------- tests/ad/vec/replicate.fut | 15 ----- tests/ad/vec/reshape.fut | 15 ----- tests/ad/vec/reshape0.fut | 11 ---- tests/ad/vec/scan0.fut | 26 -------- tests/ad/vec/scan1.fut | 32 ---------- tests/ad/vec/scatter0.fut | 22 ------- tests/ad/vec/scatter1.fut | 22 ------- 44 files changed, 212 insertions(+), 592 deletions(-) rename tests/ad/{vec => }/hist_add.fut (100%) rename tests/ad/{vec => }/hist_complex.fut (100%) rename tests/ad/{vec => }/hist_minmax.fut (100%) rename tests/ad/{vec => }/hist_mul.fut (100%) rename tests/ad/{vec => }/index.fut (95%) rename tests/ad/{vec/reduce2.fut => reduce3.fut} (100%) rename tests/ad/{vec => }/transpose.fut (96%) delete mode 100644 tests/ad/vec/README.md delete mode 100644 tests/ad/vec/arr0.fut delete mode 100644 tests/ad/vec/arr1.fut delete mode 100644 tests/ad/vec/concat.fut delete mode 100644 tests/ad/vec/gather0.fut delete mode 100644 tests/ad/vec/map0.fut delete mode 100644 tests/ad/vec/map1.fut delete mode 100644 tests/ad/vec/map2.fut delete mode 100644 tests/ad/vec/map3.fut delete mode 100644 tests/ad/vec/map4.fut delete mode 100644 tests/ad/vec/map5.fut delete mode 100644 tests/ad/vec/primfun.fut delete mode 100644 tests/ad/vec/reduce0.fut delete mode 100644 tests/ad/vec/reduce1.fut delete mode 100644 tests/ad/vec/reduce3.fut delete mode 100644 tests/ad/vec/replicate.fut delete mode 100644 tests/ad/vec/reshape.fut delete mode 100644 tests/ad/vec/reshape0.fut delete mode 100644 tests/ad/vec/scan0.fut delete mode 100644 tests/ad/vec/scan1.fut delete mode 100644 tests/ad/vec/scatter0.fut delete mode 100644 tests/ad/vec/scatter1.fut diff --git a/tests/ad/arr0.fut b/tests/ad/arr0.fut index 3ae014a68d..8db3185287 100644 --- a/tests/ad/arr0.fut +++ b/tests/ad/arr0.fut @@ -1,22 +1,25 @@ -- == -- tags { autodiff } -def f (xs: [2]f64) = xs[0] * xs[1] +def primal (xs: [2]f64) = xs[0] * xs[1] -- == --- entry: f_jvp +-- entry: fwd fwd_vec -- input { [5.0, 7.0] } --- output { 7.0 5.0 } +-- output { [7.0, 5.0] } + +entry fwd xs = + [ jvp primal xs [1, 0] + , jvp primal xs [0, 1] + ] -entry f_jvp xs = - ( jvp f xs [1, 0] - , jvp f xs [0, 1] - ) +entry fwd_vec xs = + jvp_vec primal xs [[1, 0], [0, 1]] -- == --- entry: f_vjp +-- entry: rev -- input { [5.0, 7.0] } -- output { [7.0, 5.0] } -entry f_vjp xs = - vjp f xs 1 +entry rev xs = + vjp primal xs 1 diff --git a/tests/ad/arr1.fut b/tests/ad/arr1.fut index 17f01a3b8f..584feaee15 100644 --- a/tests/ad/arr1.fut +++ b/tests/ad/arr1.fut @@ -1,17 +1,15 @@ -def f (x, y) : [2]f64 = [x + y, x * y] +def primal (x, y) : [2]f64 = [x + y, x * y] -- == -- tags { autodiff } --- entry: f_vjp f_jvp +-- entry: fwd fwd_vec -- input { 5.0 7.0 } --- output { [1.0,7.0] [1.0, 5.0] } +-- output { [[1.0,7.0], [1.0, 5.0]] } -entry f_jvp x y = - ( jvp f (x, y) (1, 0) - , jvp f (x, y) (0, 1) - ) +entry fwd x y = + [ jvp primal (x, y) (1, 0) + , jvp primal (x, y) (0, 1) + ] -entry f_vjp x y = - let (dx1, dx2) = vjp f (x, y) [1, 0] - let (dy1, dy2) = vjp f (x, y) [0, 1] - in ([dx1, dy1], [dx2, dy2]) +entry fwd_vec x y = + jvp_vec primal (x, y) [(1, 0), (0, 1)] diff --git a/tests/ad/concat0.fut b/tests/ad/concat0.fut index 6b6cfe0ba2..a3890dbefc 100644 --- a/tests/ad/concat0.fut +++ b/tests/ad/concat0.fut @@ -1,18 +1,19 @@ -- == -- tags { autodiff } +-- entry: fwd_vec fwd_map +-- input { [1.0, 2.0, 3.0] } +-- output { [[1.0, 0.0, 0.0, 1.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 1.0, 0.0, 0.0, 1.0]] } --- == --- entry: f_jvp --- input { [1,2,3] [4,5,6] } --- output { [1,2,3,4,5,6] } +def f (xs: []f64) = xs ++ xs -entry f_jvp xs ys : []i32 = - jvp (uncurry concat) (xs, ys) (xs, ys) +entry fwd_vec (xs: []f64) = + let seeds = + map (\i -> map (\j -> f64.bool (i == j)) (indices xs)) (indices xs) + in (jvp2_vec f xs seeds).1 --- == --- entry: f_vjp --- input { [1,2,3] [4,5,6] } --- output { [1,2,3] [4,5,6] } +entry fwd_map (xs: []f64) = + map (\i -> jvp f xs (map (\j -> f64.bool (i == j)) (indices xs))) + (indices xs) -entry f_vjp xs ys : ([]i32, []i32) = +entry rev xs ys : ([]i32, []i32) = vjp (uncurry concat) (xs, ys) (concat xs ys) diff --git a/tests/ad/custom/radixsort.fut b/tests/ad/custom/radixsort.fut index f1961e9ae5..8d21661fb0 100644 --- a/tests/ad/custom/radixsort.fut +++ b/tests/ad/custom/radixsort.fut @@ -1,5 +1,6 @@ -- Custom derivative for radix sort. -- == +-- tags { autodiff } -- entry: main_standard main_custom -- input { [4f32,3f32,2f32,1f32] [0.1f32,0.2f32,0.3f32,0.4f32] } -- output { [0.4f32, 0.3f32, 0.2f32, 0.1f32 ] } @@ -20,10 +21,10 @@ def radix_sort [n] 't (f: t -> u32) (xs: [n]t) : [n]t = def differentiable_radix_sort [n] 't (f: t -> u32) (xs: [n]t) = (with_vjp (\xs -> - unzip (radix_sort (f <-< (.0)) (zip xs (iota n)))) - (\(_, perm) (xs_adj, _) -> - scatter (copy xs_adj) perm xs_adj) - xs).0 + unzip (radix_sort (f <-< (.0)) (zip xs (iota n)))) + (\(_, perm) (xs_adj, _) -> + scatter (copy xs_adj) perm xs_adj) + xs).0 entry main_standard = vjp (radix_sort f32.to_bits) entry main_custom = vjp (differentiable_radix_sort f32.to_bits) diff --git a/tests/ad/custom/toplevel.fut b/tests/ad/custom/toplevel.fut index 368e8bcb10..3efd612616 100644 --- a/tests/ad/custom/toplevel.fut +++ b/tests/ad/custom/toplevel.fut @@ -1,6 +1,7 @@ -- | If a custom derivative occurs at top level, just get rid of it. -- == +-- tags { autodiff } -- entry: do_primal -- input { 2.5 } output { 5.0 } @@ -10,8 +11,8 @@ def primal (x: f64) = with_vjp (\x -> x * 2) - (\c x_adj -> x_adj + f64.sqrt c) - x + (\c x_adj -> x_adj + f64.sqrt c) + x entry do_vjp x = vjp primal x 1 entry do_primal x = primal x diff --git a/tests/ad/vec/hist_add.fut b/tests/ad/hist_add.fut similarity index 100% rename from tests/ad/vec/hist_add.fut rename to tests/ad/hist_add.fut diff --git a/tests/ad/vec/hist_complex.fut b/tests/ad/hist_complex.fut similarity index 100% rename from tests/ad/vec/hist_complex.fut rename to tests/ad/hist_complex.fut diff --git a/tests/ad/vec/hist_minmax.fut b/tests/ad/hist_minmax.fut similarity index 100% rename from tests/ad/vec/hist_minmax.fut rename to tests/ad/hist_minmax.fut diff --git a/tests/ad/vec/hist_mul.fut b/tests/ad/hist_mul.fut similarity index 100% rename from tests/ad/vec/hist_mul.fut rename to tests/ad/hist_mul.fut diff --git a/tests/ad/vec/index.fut b/tests/ad/index.fut similarity index 95% rename from tests/ad/vec/index.fut rename to tests/ad/index.fut index 5b8c89eada..c2480985ad 100644 --- a/tests/ad/vec/index.fut +++ b/tests/ad/index.fut @@ -1,4 +1,5 @@ -- == +-- tags { autodiff } -- entry: fwd_vec fwd_map -- input { 0i32 [1f32, 2f32, 3f32] } -- output { [1f32, 0f32, 0f32] } diff --git a/tests/ad/map0.fut b/tests/ad/map0.fut index a531f6ec11..4d5a125194 100644 --- a/tests/ad/map0.fut +++ b/tests/ad/map0.fut @@ -1,7 +1,21 @@ -- == -- tags { autodiff } --- entry: rev --- input { [1,2,3] [3,2,1] } --- output { [6,4,2] } +-- entry: fwd_map fwd_vec rev_map rev_vec +-- input { [1.0f32, 2.0, 3.0] [4f32, 5, 6] } +-- output { [[1.0f32, 0.0, 0.0], [0.0f32, 2.0, 0.0], [0.0f32, 0.0, 3.0]] } -entry rev = vjp (map (* 2i32)) +def prim = map2 (f32.*) + +entry fwd_map [n] (xs: [n]f32) (ys: [n]f32) = + tabulate n (\i -> jvp (prim xs) ys (replicate n 0 with [i] = 1)) + +entry fwd_vec [n] (xs: [n]f32) (ys: [n]f32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in jvp_vec (prim xs) ys seeds + +entry rev_map [n] (xs: [n]f32) (ys: [n]f32) = + transpose (tabulate n (\i -> vjp (prim xs) ys (replicate n 0 with [i] = 1))) + +entry rev_vec [n] (xs: [n]f32) (ys: [n]f32) = + let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) + in transpose (vjp_vec (prim xs) ys seeds) diff --git a/tests/ad/map1.fut b/tests/ad/map1.fut index efc35a9ad9..2acc256758 100644 --- a/tests/ad/map1.fut +++ b/tests/ad/map1.fut @@ -1,9 +1,18 @@ --- +-- Like map0, but we do not compute the full Jacobian, so the vector size is not +-- the same as the input size. -- == -- tags { autodiff } --- entry: rev --- input { [[1.0,2.0,3.0,4.0],[1.0,2.0,3.0,4.0]] [1.0,2.0] } --- output {[[24.0, 12.0, 8.0, 6.0], --- [48.0, 24.0, 16.0, 12.0]] } +-- entry: fwd fwd_vec +-- input { [1.0f32, 2.0, 3.0] [4f32, 5, 6] } +-- output { [[5.0f32, 0.0, 0.0], [0.0f32, 7.0, 0.0]] } -entry rev = vjp (map f64.product) +def prim = map2 (f32.*) + +def k = 2i64 + +entry fwd [n] (xs: [n]f32) (ys: [n]f32) = + tabulate k (\i -> jvp (uncurry prim) (xs, ys) (replicate n 0 with [i] = 1, replicate n 0 with [i] = 1)) + +entry fwd_vec [n] (xs: [n]f32) (ys: [n]f32) = + let seeds = tabulate k (\i -> (replicate n 0 with [i] = 1, replicate n 0 with [i] = 1)) + in jvp_vec (uncurry prim) (xs, ys) seeds diff --git a/tests/ad/map2.fut b/tests/ad/map2.fut index 2e3a5916c8..8c96c57cae 100644 --- a/tests/ad/map2.fut +++ b/tests/ad/map2.fut @@ -1,19 +1,21 @@ -- Map with free variable. -- == -- tags { autodiff } --- entry: fwd_J rev_J rev_vec_J +-- entry: fwd_map rev_map rev_vec -- input { 2.0 [1.0,2.0,3.0] } -- output { [1.0,2.0,3.0] } +def primal xs (c': f64) = map (* c') xs + def onehot n i : [n]f64 = tabulate n (\j -> f64.bool (i == j)) -entry fwd_J [n] (c: f64) (xs: [n]f64) = - jvp (\c' -> map (* c') xs) c 1 +entry fwd_map [n] (c: f64) (xs: [n]f64) = + jvp (primal xs) c 1 -entry rev_J [n] (c: f64) (xs: [n]f64) = - tabulate n (\i -> vjp (\c' -> map (* c') xs) c (onehot n i)) +entry rev_map [n] (c: f64) (xs: [n]f64) = + tabulate n (\i -> vjp (primal xs) c (onehot n i)) -entry rev_vec_J [n] (c: f64) (xs: [n]f64) = +entry rev_vec [n] (c: f64) (xs: [n]f64) = let seeds = tabulate n (\i -> onehot n i) - in vjp_vec (\c' -> map (* c') xs) c seeds + in vjp_vec (primal xs) c seeds diff --git a/tests/ad/map3.fut b/tests/ad/map3.fut index 36ae1e509a..0a4aa89880 100644 --- a/tests/ad/map3.fut +++ b/tests/ad/map3.fut @@ -1,16 +1,18 @@ -- == -- tags { autodiff } --- entry: fwd rev rev_vec +-- entry: fwd rev_map rev_vec -- input { 1i32 [1i32,2i32,3i32] } -- output { [1i32,2i32,3i32] } +def primal xs (x: i32) = map (* x) xs + entry fwd [n] (x: i32) (xs: [n]i32) = - jvp (\x -> map (* x) xs) x 1 + jvp (primal xs) x 1 -entry rev [n] (x: i32) (xs: [n]i32) = +entry rev_map [n] (x: i32) (xs: [n]i32) = tabulate n (\i -> - vjp (\x -> map (* x) xs) x (replicate n 0 with [i] = 1)) + vjp (primal xs) x (replicate n 0 with [i] = 1)) entry rev_vec [n] (x: i32) (xs: [n]i32) = - let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in vjp_vec (\x -> map (* x) xs) x seeds + let seeds = tabulate n (\i -> (replicate n 0 with [i] = 1)) + in vjp_vec (primal xs) x seeds diff --git a/tests/ad/map4.fut b/tests/ad/map4.fut index e42e5e902e..e2660afada 100644 --- a/tests/ad/map4.fut +++ b/tests/ad/map4.fut @@ -1,13 +1,13 @@ -- An array is both a 'map' input and a free variable in the lambda. -- == -- tags { autodiff } --- entry: fwd_J rev_J fwd_vec_J rev_vec_J +-- entry: fwd_map fwd_vec rev_map rev_vec -- input { [1,2,3] } -- output { -- [[[2, 0, 0], [1, 1, 0], [1, 0, 1]], [[1, 1, 0], [0, 2, 0], [0, 1, 1]], [[1, 0, 1], [0, 1, 1], [0, 0, 2]]] -- } -def f (xs: []i32) = +def primal (xs: []i32) = map (\x -> map (+ x) xs) xs def onehot n i : [n]i32 = @@ -16,23 +16,22 @@ def onehot n i : [n]i32 = def onehot_2d n m p : [n][m]i32 = tabulate_2d n m (\i j -> i32.bool ((i, j) == p)) -entry fwd_J [n] (xs: [n]i32) = - tabulate n (\i -> jvp f xs (onehot n i)) +entry fwd_map [n] (xs: [n]i32) = + tabulate n (\i -> jvp primal xs (onehot n i)) |> map transpose |> transpose |> map transpose -entry rev_J [n] (xs: [n]i32) = - tabulate_2d n n (\i j -> vjp f xs (onehot_2d n n (i, j))) - -entry fwd_vec_J [n] (xs: [n]i32) = +entry fwd_vec [n] (xs: [n]i32) = let seeds = tabulate n (\i -> onehot n i) - in jvp_vec f xs seeds - |> map transpose - |> transpose - |> map transpose + in jvp_vec primal xs seeds + |> map transpose + |> transpose + |> map transpose + +entry rev_map [n] (xs: [n]i32) = + tabulate_2d n n (\i j -> vjp primal xs (onehot_2d n n (i, j))) -entry rev_vec_J [n] (xs: [n]i32) = - let seeds = tabulate (n * n) (\k -> onehot_2d n n (k / n, k % n)) - in vjp_vec f xs seeds - |> unflatten +entry rev_vec [n] (xs: [n]i32) = + let seeds = tabulate_2d n n (\i j -> onehot_2d n n (i, j)) + in unflatten (vjp_vec primal xs (flatten seeds)) diff --git a/tests/ad/reduce0.fut b/tests/ad/reduce0.fut index a0dc3b7e04..65822f3bfb 100644 --- a/tests/ad/reduce0.fut +++ b/tests/ad/reduce0.fut @@ -1,11 +1,22 @@ --- Simple reduce with multiplication -- == -- tags { autodiff } --- entry: rev --- input { [1.0f32, 2.0f32, 3.0f32, 4.0f32] 1.0f32 } output { [24.0f32, 12.0f32, 8.0f32, 6.0f32] 24.0f32 } +-- entry: fwd_vec fwd_map rev_vec +-- input { [1f32, 2f32, 3f32] } +-- output { [6f32, 3f32, 2f32] } -def red_mult [n] (xs: [n]f32, c: f32) : f32 = - reduce (*) 1 xs * c +def f (xs: []f32) = f32.product xs -entry rev [n] (xs: [n]f32) (c: f32) = - vjp red_mult (xs, c) 1 +entry fwd_vec (xs: []f32) : []f32 = + let seeds = + map (\i -> map (\j -> f32.bool (i == j)) (indices xs)) (indices xs) + in (jvp2_vec f xs seeds).1 + +entry fwd_map (xs: []f32) : []f32 = + map (\i -> jvp f xs (map (\j -> f32.bool (i == j)) (indices xs))) + (indices xs) + +-- No rev_map because it would just get optimised away. The rev_vec is pointless +-- enough already. + +entry rev_vec (xs: []f32) : []f32 = + head (vjp_vec f xs [1]) diff --git a/tests/ad/reduce1.fut b/tests/ad/reduce1.fut index 0623eb8aac..25286aa9fb 100644 --- a/tests/ad/reduce1.fut +++ b/tests/ad/reduce1.fut @@ -1,24 +1,71 @@ --- Reduce with a fancier operator. +-- Reduce with 2x2 matrix multiplication. -- == -- tags { autodiff } --- entry: rev --- input { [1.0,2.0,3.0] [2.0,3.0,4.0] [3.0,4.0,5.0] [4.0,5.0,6.0] } --- output { [47.0, 28.0, 32.0] --- [83.0, 44.0, 32.0] --- [47.0, 42.0, 42.0] --- [83.0, 66.0, 42.0] } - -def mm2by2 (a1: f64, b1: f64, c1: f64, d1: f64) - (a2: f64, b2: f64, c2: f64, d2: f64) = +-- entry: fwd_map rev_map fwd_vec rev_vec +-- input { [[1f32,2f32,3f32,4f32], [4f32,3f32,2f32,1f32], [1f32,2f32,3f32,4f32], [4f32,3f32,2f32,1f32]] } +-- output { +-- [[[92.0f32, 36.0, 0.0, 0.0], +-- [8.0f32, 20.0, 16.0, 40.0], +-- [32.0f32, 16.0, 20.0, 10.0], +-- [23.0f32, 0.0, 36.0, 0.0]], +-- [[59.0f32, 23.0, 0.0, 0.0], +-- [5.0f32, 13.0, 10.0, 26.0], +-- [24.0f32, 8.0, 15.0, 5.0], +-- [0.0f32, 23.0, 0.0, 36.0]], +-- [[0.0f32, 0.0, 92.0, 36.0], +-- [24.0f32, 60.0, 32.0, 80.0], +-- [80.0f32, 40.0, 52.0, 26.0], +-- [59.0f32, 0.0, 92.0, 0.0]], +-- [[0.0f32, 0.0, 59.0, 23.0], +-- [15.0f32, 39.0, 20.0, 52.0], +-- [60.0f32, 20.0, 39.0, 13.0], +-- [0.0f32, 59.0, 0.0, 92.0]]] +-- } + +def mm2by2 (a1: f32, b1: f32, c1: f32, d1: f32) + (a2: f32, b2: f32, c2: f32, d2: f32) = ( a1 * a2 + b1 * c2 , a1 * b2 + b1 * d2 , c1 * a2 + d1 * c2 , c1 * b2 + d1 * d2 ) -def red_mm [n] (xs: [n](f64, f64, f64, f64)) = +def primal [n] (xs: [n](f32, f32, f32, f32)) = reduce mm2by2 (1, 0, 0, 1) xs -entry rev [n] (xs1: [n]f64) (xs2: [n]f64) (xs3: [n]f64) (xs4: [n]f64) = - vjp red_mm (zip4 xs1 xs2 xs3 xs4) (1, 1, 1, 1) - |> unzip4 +def fromarr = \(x: [4]f32) -> (x[0], x[1], x[2], x[3]) + +def fromarrs = map fromarr +def toarrs = map (\(a, b, c, d) -> [a, b, c, d]) + +def onehot_1d n x = + tabulate n (\i -> f32.bool (i == x)) + +def onehot_2d n m x y = + tabulate_2d n m (\i j -> f32.bool ((i, j) == (x, y))) + +entry fwd_map [n] (input: [n][4]f32) : [4][n][4]f32 = + let input = fromarrs input + in tabulate (n * 4) (\i -> jvp primal input (fromarrs (onehot_2d n 4 (i / 4) (i % 4)))) + |> toarrs + |> transpose + |> map unflatten + +entry fwd_vec [n] (input: [n][4]f32) : [4][n][4]f32 = + let input = fromarrs input + let seeds = tabulate (n * 4) (\i -> (fromarrs (onehot_2d n 4 (i / 4) (i % 4)))) + in jvp_vec primal input seeds + |> toarrs + |> transpose + |> map unflatten + +entry rev_map [n] (input: [n][4]f32) : [4][n][4]f32 = + let input = fromarrs input + in tabulate 4 (\i -> vjp primal input (fromarr (onehot_1d 4 i))) + |> map toarrs + +entry rev_vec [n] (input: [n][4]f32) : [4][n][4]f32 = + let input = fromarrs input + let seeds = tabulate 4 (\i -> fromarr (onehot_1d 4 i)) + in vjp_vec primal input seeds + |> map toarrs diff --git a/tests/ad/vec/reduce2.fut b/tests/ad/reduce3.fut similarity index 100% rename from tests/ad/vec/reduce2.fut rename to tests/ad/reduce3.fut diff --git a/tests/ad/reducebyindexminmax7.fut b/tests/ad/reducebyindexminmax7.fut index ece4ac1c5c..c17991b9ba 100644 --- a/tests/ad/reducebyindexminmax7.fut +++ b/tests/ad/reducebyindexminmax7.fut @@ -1,5 +1,5 @@ -- == --- tags { no_ispc } +-- tags { autodiff no_ispc } -- entry: rev fwd rev_vec fwd_vec -- compiled input @ reducebyindexminmax7.in output @ reducebyindexminmax7.out.gz diff --git a/tests/ad/replicate0.fut b/tests/ad/replicate0.fut index 091d5f3085..d74ee052a6 100644 --- a/tests/ad/replicate0.fut +++ b/tests/ad/replicate0.fut @@ -2,17 +2,25 @@ -- tags { autodiff } -- == --- entry: f_jvp --- input { 3i64 2 } --- output { [1,1,1] } +-- entry: fwd_vec fwd_map +-- input { 2i64 [1.0, 2.0] } +-- output { [[[1.0, 0.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 1.0]]] } -entry f_jvp n x : []i32 = - jvp (replicate n) x 1 +def f (n: i64) (xs: []f64) = replicate n xs + +entry fwd_vec n (xs: []f64) = + let seeds = + map (\i -> map (\j -> f64.bool (i == j)) (indices xs)) (indices xs) + in (jvp2_vec (f n) xs seeds).1 + +entry fwd_map n (xs: []f64) = + map (\i -> jvp (f n) xs (map (\j -> f64.bool (i == j)) (indices xs))) + (indices xs) -- == --- entry: f_vjp +-- entry: rev -- input { 3i64 2i64 } -- output { 3i64 } -entry f_vjp n x = +entry rev n x = vjp (replicate n) x (iota n) diff --git a/tests/ad/reshape0.fut b/tests/ad/reshape0.fut index f54a9e3737..e26b2ca472 100644 --- a/tests/ad/reshape0.fut +++ b/tests/ad/reshape0.fut @@ -2,12 +2,16 @@ -- tags { autodiff } -- == --- entry: f_jvp +-- entry: fwd_map fwd_vec -- input { 2i64 2i64 [1,2,3,4] } --- output { [[1,2],[3,4]] } +-- output { [[[1, 0], [0, 0]], [[0, 1], [0, 0]]] } -entry f_jvp n m (xs: [n * m]i32) = - jvp unflatten xs xs +entry fwd_map n m (xs: [n * m]i32) = + tabulate 2 (\i -> jvp unflatten xs (replicate (n * m) 0 with [i] = 1)) + +entry fwd_vec n m (xs: [n * m]i32) = + let seeds = tabulate 2 (\i -> replicate (n * m) 0 with [i] = 1) + in jvp_vec unflatten xs seeds -- == -- entry: f_vjp diff --git a/tests/ad/vec/transpose.fut b/tests/ad/transpose.fut similarity index 96% rename from tests/ad/vec/transpose.fut rename to tests/ad/transpose.fut index f8c5bf8a32..f085690a69 100644 --- a/tests/ad/vec/transpose.fut +++ b/tests/ad/transpose.fut @@ -1,4 +1,5 @@ -- == +-- tags { autodiff } -- entry: fwd_vec fwd_map -- input { [[1.0,2.0],[3.0,4.0]] } -- output { [[[1.0, 0.0],[0.0, 0.0]],[[0.0, 0.0],[1.0, 0.0]],[[0.0, 1.0],[0.0, 0.0]],[[0.0, 0.0],[0.0, 1.0]]] } diff --git a/tests/ad/vec/README.md b/tests/ad/vec/README.md deleted file mode 100644 index 9a2d916021..0000000000 --- a/tests/ad/vec/README.md +++ /dev/null @@ -1,4 +0,0 @@ -# Microbenchmarks for vectorised AD - -This directory contains tests for the differentiation of individual (core -language) language primitives. diff --git a/tests/ad/vec/arr0.fut b/tests/ad/vec/arr0.fut deleted file mode 100644 index a94cf48aee..0000000000 --- a/tests/ad/vec/arr0.fut +++ /dev/null @@ -1,17 +0,0 @@ --- == --- tags { autodiff } - -def primal (xs: [2]f64) = xs[0] * xs[1] - --- == --- entry: fwd fwd_vec --- input { [5.0, 7.0] } --- output { [7.0, 5.0] } - -entry fwd xs = - [ jvp primal xs [1, 0] - , jvp primal xs [0, 1] - ] - -entry fwd_vec xs = - jvp_vec primal xs [[1, 0], [0, 1]] diff --git a/tests/ad/vec/arr1.fut b/tests/ad/vec/arr1.fut deleted file mode 100644 index 584feaee15..0000000000 --- a/tests/ad/vec/arr1.fut +++ /dev/null @@ -1,15 +0,0 @@ -def primal (x, y) : [2]f64 = [x + y, x * y] - --- == --- tags { autodiff } --- entry: fwd fwd_vec --- input { 5.0 7.0 } --- output { [[1.0,7.0], [1.0, 5.0]] } - -entry fwd x y = - [ jvp primal (x, y) (1, 0) - , jvp primal (x, y) (0, 1) - ] - -entry fwd_vec x y = - jvp_vec primal (x, y) [(1, 0), (0, 1)] diff --git a/tests/ad/vec/concat.fut b/tests/ad/vec/concat.fut deleted file mode 100644 index 69d94787df..0000000000 --- a/tests/ad/vec/concat.fut +++ /dev/null @@ -1,15 +0,0 @@ --- == --- entry: fwd_vec fwd_map --- input { [1.0, 2.0, 3.0] } --- output { [[1.0, 0.0, 0.0, 1.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 1.0, 0.0, 0.0, 1.0]] } - -def f (xs: []f64) = xs ++ xs - -entry fwd_vec (xs: []f64) = - let seeds = - map (\i -> map (\j -> f64.bool (i == j)) (indices xs)) (indices xs) - in (jvp2_vec f xs seeds).1 - -entry fwd_map (xs: []f64) = - map (\i -> jvp f xs (map (\j -> f64.bool (i == j)) (indices xs))) - (indices xs) diff --git a/tests/ad/vec/gather0.fut b/tests/ad/vec/gather0.fut deleted file mode 100644 index 29c4a29715..0000000000 --- a/tests/ad/vec/gather0.fut +++ /dev/null @@ -1,23 +0,0 @@ --- == --- entry: fwd fwd_vec --- input { [4.0,3.0,2.0,1.0] [0i64,1i64,2i64,3i64] } --- output { [[1.0, 0.0, 0.0, 0.0], --- [0.0, 1.0, 0.0, 0.0], --- [0.0, 0.0, 1.0, 0.0], --- [0.0, 0.0, 0.0, 1.0]] --- } --- input { [4.0,3.0,2.0,1.0] [0i64,0i64,3i64,3i64] } --- output { [[1.0, 0.0, 0.0, 0.0], --- [1.0, 0.0, 0.0, 0.0], --- [0.0, 0.0, 0.0, 1.0], --- [0.0, 0.0, 0.0, 1.0]] --- } - -def gather xs is = map (\(i: i64) -> xs[i]) is - -entry fwd [n] [m] (xs: [n]f64) (is: [m]i64) = - transpose (tabulate n (\j -> jvp (`gather` is) xs (replicate n 0 with [j] = 1))) - -entry fwd_vec [n] [m] (xs: [n]f64) (is: [m]i64) = - let seeds = tabulate n (\j -> replicate n 0 with [j] = 1) - in transpose (jvp_vec (`gather` is) xs seeds) diff --git a/tests/ad/vec/map0.fut b/tests/ad/vec/map0.fut deleted file mode 100644 index 8ae207d68e..0000000000 --- a/tests/ad/vec/map0.fut +++ /dev/null @@ -1,20 +0,0 @@ --- == --- entry: fwd_map fwd_vec rev_map rev_vec --- input { [1.0f32, 2.0, 3.0] [4f32, 5, 6] } --- output { [[1.0f32, 0.0, 0.0], [0.0f32, 2.0, 0.0], [0.0f32, 0.0, 3.0]] } - -def prim = map2 (f32.*) - -entry fwd_map [n] (xs: [n]f32) (ys: [n]f32) = - tabulate n (\i -> jvp (prim xs) ys (replicate n 0 with [i] = 1)) - -entry fwd_vec [n] (xs: [n]f32) (ys: [n]f32) = - let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in jvp_vec (prim xs) ys seeds - -entry rev_map [n] (xs: [n]f32) (ys: [n]f32) = - transpose (tabulate n (\i -> vjp (prim xs) ys (replicate n 0 with [i] = 1))) - -entry rev_vec [n] (xs: [n]f32) (ys: [n]f32) = - let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in transpose (vjp_vec (prim xs) ys seeds) diff --git a/tests/ad/vec/map1.fut b/tests/ad/vec/map1.fut deleted file mode 100644 index 12dd98fb73..0000000000 --- a/tests/ad/vec/map1.fut +++ /dev/null @@ -1,17 +0,0 @@ --- Like map0, but we do not compute the full Jacobian, so the vector size is not --- the same as the input size. --- == --- entry: fwd fwd_vec --- input { [1.0f32, 2.0, 3.0] [4f32, 5, 6] } --- output { [[5.0f32, 0.0, 0.0], [0.0f32, 7.0, 0.0]] } - -def prim = map2 (f32.*) - -def k = 2i64 - -entry fwd [n] (xs: [n]f32) (ys: [n]f32) = - tabulate k (\i -> jvp (uncurry prim) (xs, ys) (replicate n 0 with [i] = 1, replicate n 0 with [i] = 1)) - -entry fwd_vec [n] (xs: [n]f32) (ys: [n]f32) = - let seeds = tabulate k (\i -> (replicate n 0 with [i] = 1, replicate n 0 with [i] = 1)) - in jvp_vec (uncurry prim) (xs, ys) seeds diff --git a/tests/ad/vec/map2.fut b/tests/ad/vec/map2.fut deleted file mode 100644 index 8c96c57cae..0000000000 --- a/tests/ad/vec/map2.fut +++ /dev/null @@ -1,21 +0,0 @@ --- Map with free variable. --- == --- tags { autodiff } --- entry: fwd_map rev_map rev_vec --- input { 2.0 [1.0,2.0,3.0] } --- output { [1.0,2.0,3.0] } - -def primal xs (c': f64) = map (* c') xs - -def onehot n i : [n]f64 = - tabulate n (\j -> f64.bool (i == j)) - -entry fwd_map [n] (c: f64) (xs: [n]f64) = - jvp (primal xs) c 1 - -entry rev_map [n] (c: f64) (xs: [n]f64) = - tabulate n (\i -> vjp (primal xs) c (onehot n i)) - -entry rev_vec [n] (c: f64) (xs: [n]f64) = - let seeds = tabulate n (\i -> onehot n i) - in vjp_vec (primal xs) c seeds diff --git a/tests/ad/vec/map3.fut b/tests/ad/vec/map3.fut deleted file mode 100644 index 0a4aa89880..0000000000 --- a/tests/ad/vec/map3.fut +++ /dev/null @@ -1,18 +0,0 @@ --- == --- tags { autodiff } --- entry: fwd rev_map rev_vec --- input { 1i32 [1i32,2i32,3i32] } --- output { [1i32,2i32,3i32] } - -def primal xs (x: i32) = map (* x) xs - -entry fwd [n] (x: i32) (xs: [n]i32) = - jvp (primal xs) x 1 - -entry rev_map [n] (x: i32) (xs: [n]i32) = - tabulate n (\i -> - vjp (primal xs) x (replicate n 0 with [i] = 1)) - -entry rev_vec [n] (x: i32) (xs: [n]i32) = - let seeds = tabulate n (\i -> (replicate n 0 with [i] = 1)) - in vjp_vec (primal xs) x seeds diff --git a/tests/ad/vec/map4.fut b/tests/ad/vec/map4.fut deleted file mode 100644 index e2660afada..0000000000 --- a/tests/ad/vec/map4.fut +++ /dev/null @@ -1,37 +0,0 @@ --- An array is both a 'map' input and a free variable in the lambda. --- == --- tags { autodiff } --- entry: fwd_map fwd_vec rev_map rev_vec --- input { [1,2,3] } --- output { --- [[[2, 0, 0], [1, 1, 0], [1, 0, 1]], [[1, 1, 0], [0, 2, 0], [0, 1, 1]], [[1, 0, 1], [0, 1, 1], [0, 0, 2]]] --- } - -def primal (xs: []i32) = - map (\x -> map (+ x) xs) xs - -def onehot n i : [n]i32 = - tabulate n (\j -> i32.bool (i == j)) - -def onehot_2d n m p : [n][m]i32 = - tabulate_2d n m (\i j -> i32.bool ((i, j) == p)) - -entry fwd_map [n] (xs: [n]i32) = - tabulate n (\i -> jvp primal xs (onehot n i)) - |> map transpose - |> transpose - |> map transpose - -entry fwd_vec [n] (xs: [n]i32) = - let seeds = tabulate n (\i -> onehot n i) - in jvp_vec primal xs seeds - |> map transpose - |> transpose - |> map transpose - -entry rev_map [n] (xs: [n]i32) = - tabulate_2d n n (\i j -> vjp primal xs (onehot_2d n n (i, j))) - -entry rev_vec [n] (xs: [n]i32) = - let seeds = tabulate_2d n n (\i j -> onehot_2d n n (i, j)) - in unflatten (vjp_vec primal xs (flatten seeds)) diff --git a/tests/ad/vec/map5.fut b/tests/ad/vec/map5.fut deleted file mode 100644 index e16aa52d8e..0000000000 --- a/tests/ad/vec/map5.fut +++ /dev/null @@ -1,26 +0,0 @@ --- Map with free array variable. --- == --- tags { autodiff } --- entry: fwd_map rev_map --- input { [[1,2,3],[4,5,6]] [0,0] } --- output { [[1, 0], [0, 1]] } - -def onehot n i : [n]i32 = - tabulate n (\j -> i32.bool (i == j)) - -def primal [n] [m] (free: [n][m]i32) (is: [n]i32) = - map (\i -> foldl (+) 0 free[i] + i) is - -entry fwd_map [n] [m] (free: [n][m]i32) (is: [n]i32) = - tabulate n (\i -> jvp (primal free) is (onehot n i)) |> transpose - -entry fwd_vec [n] [m] (free: [n][m]i32) (is: [n]i32) = - let seeds = tabulate n (\i -> onehot n i) - in jvp_vec (primal free) is seeds |> transpose - -entry rev_map [n] [m] (free: [n][m]i32) (is: [n]i32) = - tabulate n (\i -> vjp (primal free) is (onehot n i)) - -entry rev_vec [n] [m] (free: [n][m]i32) (is: [n]i32) = - let seeds = tabulate n (\i -> onehot n i) - in vjp_vec (primal free) is seeds diff --git a/tests/ad/vec/primfun.fut b/tests/ad/vec/primfun.fut deleted file mode 100644 index 3aecde8b94..0000000000 --- a/tests/ad/vec/primfun.fut +++ /dev/null @@ -1,18 +0,0 @@ --- == --- entry: fwd_map fwd_vec rev_map rev_vec --- input { [1f32, 2f32, 3f32] } --- output { [[0.5f32, 0.0, 0.0], [0.0f32, 0.35355338, 0.0], [0.0f32, 0.0, 0.28867513]] } - -def primal = map f32.sqrt - -entry fwd_map [n] (xs: [n]f32) = - tabulate n (\i -> jvp primal xs (replicate n 0 with [i] = 1)) - -entry fwd_vec [n] (xs: [n]f32) = - jvp_vec primal xs (tabulate n (\i -> replicate n 0 with [i] = 1)) - -entry rev_map [n] (xs: [n]f32) = - tabulate n (\i -> vjp primal xs (replicate n 0 with [i] = 1)) - -entry rev_vec [n] (xs: [n]f32) = - vjp_vec primal xs (tabulate n (\i -> replicate n 0 with [i] = 1)) diff --git a/tests/ad/vec/reduce0.fut b/tests/ad/vec/reduce0.fut deleted file mode 100644 index 0ad76d8189..0000000000 --- a/tests/ad/vec/reduce0.fut +++ /dev/null @@ -1,21 +0,0 @@ --- == --- entry: fwd_vec fwd_map rev_vec --- input { [1f32, 2f32, 3f32] } --- output { [6f32, 3f32, 2f32] } - -def f (xs: []f32) = f32.product xs - -entry fwd_vec (xs: []f32) : []f32 = - let seeds = - map (\i -> map (\j -> f32.bool (i == j)) (indices xs)) (indices xs) - in (jvp2_vec f xs seeds).1 - -entry fwd_map (xs: []f32) : []f32 = - map (\i -> jvp f xs (map (\j -> f32.bool (i == j)) (indices xs))) - (indices xs) - --- No rev_map because it would just get optimised away. The rev_vec is pointless --- enough already. - -entry rev_vec (xs: []f32) : []f32 = - head (vjp_vec f xs [1]) diff --git a/tests/ad/vec/reduce1.fut b/tests/ad/vec/reduce1.fut deleted file mode 100644 index fa5921d42f..0000000000 --- a/tests/ad/vec/reduce1.fut +++ /dev/null @@ -1,16 +0,0 @@ --- Reduce with addition. --- == --- tags { autodiff } --- entry: fwd_map fwd_vec --- input { [1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32] } --- output { [1.0f32, 1.0, 1.0, 1.0, 1.0] } - -entry fwd_map [n] (a: [n]f32) = - tabulate n (\i -> jvp (reduce (+) 0) a (replicate n 0 with [i] = 1)) - -entry fwd_vec [n] (a: [n]f32) = - let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in jvp_vec (reduce (+) 0) a seeds - -entry rev_vec [n] (a: [n]f32) = - head (vjp_vec (reduce (+) 0) a [1]) diff --git a/tests/ad/vec/reduce3.fut b/tests/ad/vec/reduce3.fut deleted file mode 100644 index 25286aa9fb..0000000000 --- a/tests/ad/vec/reduce3.fut +++ /dev/null @@ -1,71 +0,0 @@ --- Reduce with 2x2 matrix multiplication. --- == --- tags { autodiff } --- entry: fwd_map rev_map fwd_vec rev_vec --- input { [[1f32,2f32,3f32,4f32], [4f32,3f32,2f32,1f32], [1f32,2f32,3f32,4f32], [4f32,3f32,2f32,1f32]] } --- output { --- [[[92.0f32, 36.0, 0.0, 0.0], --- [8.0f32, 20.0, 16.0, 40.0], --- [32.0f32, 16.0, 20.0, 10.0], --- [23.0f32, 0.0, 36.0, 0.0]], --- [[59.0f32, 23.0, 0.0, 0.0], --- [5.0f32, 13.0, 10.0, 26.0], --- [24.0f32, 8.0, 15.0, 5.0], --- [0.0f32, 23.0, 0.0, 36.0]], --- [[0.0f32, 0.0, 92.0, 36.0], --- [24.0f32, 60.0, 32.0, 80.0], --- [80.0f32, 40.0, 52.0, 26.0], --- [59.0f32, 0.0, 92.0, 0.0]], --- [[0.0f32, 0.0, 59.0, 23.0], --- [15.0f32, 39.0, 20.0, 52.0], --- [60.0f32, 20.0, 39.0, 13.0], --- [0.0f32, 59.0, 0.0, 92.0]]] --- } - -def mm2by2 (a1: f32, b1: f32, c1: f32, d1: f32) - (a2: f32, b2: f32, c2: f32, d2: f32) = - ( a1 * a2 + b1 * c2 - , a1 * b2 + b1 * d2 - , c1 * a2 + d1 * c2 - , c1 * b2 + d1 * d2 - ) - -def primal [n] (xs: [n](f32, f32, f32, f32)) = - reduce mm2by2 (1, 0, 0, 1) xs - -def fromarr = \(x: [4]f32) -> (x[0], x[1], x[2], x[3]) - -def fromarrs = map fromarr -def toarrs = map (\(a, b, c, d) -> [a, b, c, d]) - -def onehot_1d n x = - tabulate n (\i -> f32.bool (i == x)) - -def onehot_2d n m x y = - tabulate_2d n m (\i j -> f32.bool ((i, j) == (x, y))) - -entry fwd_map [n] (input: [n][4]f32) : [4][n][4]f32 = - let input = fromarrs input - in tabulate (n * 4) (\i -> jvp primal input (fromarrs (onehot_2d n 4 (i / 4) (i % 4)))) - |> toarrs - |> transpose - |> map unflatten - -entry fwd_vec [n] (input: [n][4]f32) : [4][n][4]f32 = - let input = fromarrs input - let seeds = tabulate (n * 4) (\i -> (fromarrs (onehot_2d n 4 (i / 4) (i % 4)))) - in jvp_vec primal input seeds - |> toarrs - |> transpose - |> map unflatten - -entry rev_map [n] (input: [n][4]f32) : [4][n][4]f32 = - let input = fromarrs input - in tabulate 4 (\i -> vjp primal input (fromarr (onehot_1d 4 i))) - |> map toarrs - -entry rev_vec [n] (input: [n][4]f32) : [4][n][4]f32 = - let input = fromarrs input - let seeds = tabulate 4 (\i -> fromarr (onehot_1d 4 i)) - in vjp_vec primal input seeds - |> map toarrs diff --git a/tests/ad/vec/replicate.fut b/tests/ad/vec/replicate.fut deleted file mode 100644 index 222fb21ed3..0000000000 --- a/tests/ad/vec/replicate.fut +++ /dev/null @@ -1,15 +0,0 @@ --- == --- entry: fwd_vec fwd_map --- input { 2i64 [1.0, 2.0] } --- output { [[[1.0, 0.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 1.0]]] } - -def f (n: i64) (xs: []f64) = replicate n xs - -entry fwd_vec n (xs: []f64) = - let seeds = - map (\i -> map (\j -> f64.bool (i == j)) (indices xs)) (indices xs) - in (jvp2_vec (f n) xs seeds).1 - -entry fwd_map n (xs: []f64) = - map (\i -> jvp (f n) xs (map (\j -> f64.bool (i == j)) (indices xs))) - (indices xs) diff --git a/tests/ad/vec/reshape.fut b/tests/ad/vec/reshape.fut deleted file mode 100644 index 2094e6ff4b..0000000000 --- a/tests/ad/vec/reshape.fut +++ /dev/null @@ -1,15 +0,0 @@ --- == --- entry: fwd_vec fwd_map --- input { [1.0,2.0,3.0,4.0] } --- output { [[[1.0, 0.0], [0.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]], [[0.0, 0.0], [0.0, 1.0]]] } - -def f (xs: []f64) = unflatten (sized (2 * 2) xs) - -entry fwd_vec (xs: []f64) = - let seeds = - map (\i -> map (\j -> f64.bool (i == j)) (indices xs)) (indices xs) - in (jvp2_vec f xs seeds).1 - -entry fwd_map (xs: []f64) = - map (\i -> jvp f xs (map (\j -> f64.bool (i == j)) (indices xs))) - (indices xs) diff --git a/tests/ad/vec/reshape0.fut b/tests/ad/vec/reshape0.fut deleted file mode 100644 index abbf753e7e..0000000000 --- a/tests/ad/vec/reshape0.fut +++ /dev/null @@ -1,11 +0,0 @@ --- == --- entry: fwd_map fwd_vec --- input { 2i64 2i64 [1,2,3,4] } --- output { [[[1, 0], [0, 0]], [[0, 1], [0, 0]]] } - -entry fwd_map n m (xs: [n * m]i32) = - tabulate 2 (\i -> jvp unflatten xs (replicate (n * m) 0 with [i] = 1)) - -entry fwd_vec n m (xs: [n * m]i32) = - let seeds = tabulate 2 (\i -> replicate (n * m) 0 with [i] = 1) - in jvp_vec unflatten xs seeds diff --git a/tests/ad/vec/scan0.fut b/tests/ad/vec/scan0.fut deleted file mode 100644 index 8bf83a5e70..0000000000 --- a/tests/ad/vec/scan0.fut +++ /dev/null @@ -1,26 +0,0 @@ --- == --- tags { autodiff } --- entry: fwd_vec fwd_map rev_map rev_vec --- input { [1f32, 2f32, 3f32] } --- output { [[1f32, 2.0, 6.0], [0f32, 1.0, 3.0], [0f32, 0.0, 2.0]] } - -def f (xs: []f32) = scan (*) 1 xs - -entry fwd_vec (xs: []f32) : [][]f32 = - let seeds = - map (\i -> map (\j -> f32.bool (i == j)) (indices xs)) (indices xs) - in (jvp2_vec f xs seeds).1 - -entry fwd_map (xs: []f32) : [][]f32 = - map (\i -> jvp f xs (map (\j -> f32.bool (i == j)) (indices xs))) - (indices xs) - -entry rev_map (xs: []f32) : [][]f32 = - map (\i -> vjp f xs (map (\j -> f32.bool (i == j)) (indices xs))) - (indices xs) - |> transpose - -entry rev_vec (xs: []f32) : [][]f32 = - let seeds = map (\i -> map (\j -> f32.bool (i == j)) (indices xs)) (indices xs) - in vjp_vec f xs seeds - |> transpose diff --git a/tests/ad/vec/scan1.fut b/tests/ad/vec/scan1.fut deleted file mode 100644 index 9870b92df0..0000000000 --- a/tests/ad/vec/scan1.fut +++ /dev/null @@ -1,32 +0,0 @@ --- Scan with addition. --- == --- tags { autodiff } --- entry: fwd_vec fwd_map rev_map rev_vec --- input { [1.0f32, 2.0f32, 3.0f32, 4.0f32, 5.0f32] } --- output { [[1f32, 1f32, 1f32, 1f32, 1f32], --- [0f32, 1f32, 1f32, 1f32, 1f32], --- [0f32, 0f32, 1f32, 1f32, 1f32], --- [0f32, 0f32, 0f32, 1f32, 1f32], --- [0f32, 0f32, 0f32, 0f32, 1f32]] --- } - -def f (xs: []f32) = scan (+) 0 xs - -entry fwd_vec (xs: []f32) : [][]f32 = - let seeds = - map (\i -> map (\j -> f32.bool (i == j)) (indices xs)) (indices xs) - in (jvp2_vec f xs seeds).1 - -entry fwd_map (xs: []f32) : [][]f32 = - map (\i -> jvp f xs (map (\j -> f32.bool (i == j)) (indices xs))) - (indices xs) - -entry rev_map (xs: []f32) : [][]f32 = - map (\i -> vjp f xs (map (\j -> f32.bool (i == j)) (indices xs))) - (indices xs) - |> transpose - -entry rev_vec (xs: []f32) : [][]f32 = - let seeds = map (\i -> map (\j -> f32.bool (i == j)) (indices xs)) (indices xs) - in vjp_vec f xs seeds - |> transpose diff --git a/tests/ad/vec/scatter0.fut b/tests/ad/vec/scatter0.fut deleted file mode 100644 index 1947e12d50..0000000000 --- a/tests/ad/vec/scatter0.fut +++ /dev/null @@ -1,22 +0,0 @@ --- Simple scatter, differentiating wrt. values. --- == --- entry: fwd fwd_vec --- input { [0f32, 0f32, 0f32, 0f32] [0i64, 1i64, 2i64, 3i64] [1f32, 2f32, 3f32, 0f32] } --- output { --- [[1.000000f32, 0.000000f32, 0.000000f32, 0.000000f32], --- [0.000000f32, 1.000000f32, 0.000000f32, 0.000000f32], --- [0.000000f32, 0.000000f32, 1.000000f32, 0.000000f32], --- [0.000000f32, 0.000000f32, 0.000000f32, 1.000000f32]] --- } - -def f [n] [k] (xs: [k]f32) (is: [n]i64) (vs: [n]f32) = - scatter (copy xs) is vs - -entry fwd [n] [k] (xs: [k]f32) (is: [n]i64) (vs: [n]f32) = - let g i = jvp (\vs -> f xs is vs) vs (replicate n 0 with [i] = 1) - in tabulate n g - -entry fwd_vec [n] [k] (xs: [k]f32) (is: [n]i64) (vs: [n]f32) = - let seeds = - map (\i -> map (\j -> f32.bool (i == j)) (indices vs)) (indices vs) - in jvp_vec (\vs -> f xs is vs) vs seeds diff --git a/tests/ad/vec/scatter1.fut b/tests/ad/vec/scatter1.fut deleted file mode 100644 index c9e6c83a25..0000000000 --- a/tests/ad/vec/scatter1.fut +++ /dev/null @@ -1,22 +0,0 @@ --- Simple scatter, differentiating wrt. target. --- == --- entry: fwd fwd_vec --- input { [0f32, 0f32, 0f32, 0f32] [0i64, 1i64] [1f32, 2f32] } --- output { --- [[0.000000f32, 0.000000f32, 0.000000f32, 0.000000f32], --- [0.000000f32, 0.000000f32, 0.000000f32, 0.000000f32], --- [0.000000f32, 0.000000f32, 1.000000f32, 0.000000f32], --- [0.000000f32, 0.000000f32, 0.000000f32, 1.000000f32]] --- } - -def f [n] [k] (xs: [k]f32) (is: [n]i64) (vs: [n]f32) = - scatter (copy xs) is vs - -entry fwd [n] [k] (xs: [k]f32) (is: [n]i64) (vs: [n]f32) = - let g i = jvp (\xs -> f xs is vs) xs (replicate k 0 with [i] = 1) - in tabulate k g - -entry fwd_vec [n] [k] (xs: [k]f32) (is: [n]i64) (vs: [n]f32) = - let seeds = - map (\i -> map (\j -> f32.bool (i == j)) (indices xs)) (indices xs) - in jvp_vec (\xs -> f xs is vs) xs seeds From 0afeea9df4b2d71aa038a7d7c2ef2b21a61acf57 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 12 Jun 2026 14:17:31 +0200 Subject: [PATCH 56/70] Refresh terminology. --- CHANGELOG.md | 2 ++ prelude/ad.fut | 45 ++++++++++++++++--------------- src/Futhark/AD/Fwd.hs | 9 +++---- src/Futhark/AD/Rev/Hist.hs | 8 +++--- src/Futhark/AD/Rev/Monad.hs | 6 ++--- src/Futhark/AD/Rev/Scan.hs | 4 +-- src/Futhark/Internalise/Exps.hs | 6 ++--- src/Language/Futhark/Prop.hs | 4 +-- tests/ad/arr0.fut | 2 +- tests/ad/arr1.fut | 2 +- tests/ad/concat0.fut | 2 +- tests/ad/consume0.fut | 4 +-- tests/ad/consume1.fut | 4 +-- tests/ad/for1.fut | 10 +++---- tests/ad/for2.fut | 10 +++---- tests/ad/for3.fut | 10 +++---- tests/ad/gather0.fut | 4 +-- tests/ad/gather1.fut | 4 +-- tests/ad/gather2.fut | 4 +-- tests/ad/hist_add.fut | 4 +-- tests/ad/hist_complex.fut | 2 +- tests/ad/hist_minmax.fut | 4 +-- tests/ad/hist_mul.fut | 2 +- tests/ad/index.fut | 2 +- tests/ad/issue2256.fut | 2 +- tests/ad/map0.fut | 4 +-- tests/ad/map1.fut | 2 +- tests/ad/map2.fut | 2 +- tests/ad/map3.fut | 2 +- tests/ad/map4.fut | 4 +-- tests/ad/map5.fut | 4 +-- tests/ad/map6.fut | 4 +-- tests/ad/map7.fut | 4 +-- tests/ad/matmul.fut | 4 +-- tests/ad/maximum.fut | 2 +- tests/ad/minimum.fut | 2 +- tests/ad/minmax.fut | 2 +- tests/ad/reduce-vec-minmax0.fut | 4 +-- tests/ad/reduce0.fut | 4 +-- tests/ad/reduce1.fut | 4 +-- tests/ad/reduce2.fut | 2 +- tests/ad/reduce3.fut | 4 +-- tests/ad/reduce_by_index0.fut | 6 ++--- tests/ad/reducebyindex3.fut | 2 +- tests/ad/reducebyindex4.fut | 2 +- tests/ad/reducebyindexminmax3.fut | 2 +- tests/ad/reducebyindexminmax4.fut | 2 +- tests/ad/reducebyindexminmax7.fut | 4 +-- tests/ad/reducebyindexminmax8.fut | 4 +-- tests/ad/reducemul0.fut | 2 +- tests/ad/reducemul4.fut | 4 +-- tests/ad/reducevec0.fut | 4 +-- tests/ad/replicate0.fut | 2 +- tests/ad/reshape0.fut | 2 +- tests/ad/scan0.fut | 4 +-- tests/ad/scan1.fut | 4 +-- tests/ad/scan2.fut | 4 +-- tests/ad/scan3.fut | 4 +-- tests/ad/scan4.fut | 4 +-- tests/ad/scan5.fut | 4 +-- tests/ad/scan6.fut | 8 +++--- tests/ad/scan7.fut | 4 +-- tests/ad/scan8.fut | 4 +-- tests/ad/scan9.fut | 4 +-- tests/ad/scatter0.fut | 4 +-- tests/ad/scatter1.fut | 4 +-- tests/ad/stripmine1.fut | 10 +++---- tests/ad/stripmine2.fut | 10 +++---- tests/ad/sum.fut | 2 +- tests/ad/transpose.fut | 2 +- tests/ad/truedep0.fut | 10 +++---- 71 files changed, 168 insertions(+), 164 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 778e3874d8..daf32b125b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. uses the hardware support for `f16`, similarly to the CUDA backend. Implemented by Jérôme Wagner. (#2470) +* Vector AD, exposed through the functions `jmp` and `mjp`. + ### Removed ### Changed diff --git a/prelude/ad.fut b/prelude/ad.fut index 58c4ed6d6d..8810d5463a 100644 --- a/prelude/ad.fut +++ b/prelude/ad.fut @@ -25,7 +25,8 @@ -- -- * Custom derivatives (`with_vjp`@term). -- --- * Vectorised AD (`vjp_vec`@term, `vjp_vec`@term). +-- * Vector AD (`mjp`@term, `jmp`@term), sometimes also known as "batched" or +-- "multi-directional" AD. -- -- * Checkpointing of sequential loops. -- @@ -92,8 +93,8 @@ -- but it can still be substantial for programs with deep sequential -- loops. -- --- It varies on a case-by-case basis whether vectorised AD is faster or not. It --- essentially converts propagation of (co-)tangents from scalar to array +-- It varies on a case-by-case basis whether vector AD is faster or not. Vector +-- AD essentially converts propagation of (co-)tangents from scalar to array -- operations, which can have a significant impact on memory accesses, depending -- on how the compiler manages to optimise the resulting code. It is hard to -- predict whether this offsets the reduction in primal work. If the vector size @@ -131,17 +132,19 @@ def jvp2 'a 'b (f: a -> b) (x: a) (x': a) : (b, b) = def vjp2 'a 'b (f: a -> b) (x: a) (y': b) : (b, a) = intrinsics.vjp2 f x y' --- | As `jvp2`, but accepts a vector of seed values. Semantically equivalent to --- mapping, but may be more efficient. If used with `#[unroll]`, tangent --- calculations are unrolled when possible. -def jvp2_vec 'a 'b [n] (f: a -> b) (x: a) (x': [n]a) : (b, [n]b) = - intrinsics.jvp2_vec f x x' +-- | Jacobian-Matrix Product, returning also the primal result. As `jvp2`, but +-- accepts a vector of seed values. Semantically equivalent to mapping, but may +-- be more efficient. If used with `#[unroll]`, tangent calculations are +-- unrolled when possible. +def jmp2 'a 'b [n] (f: a -> b) (x: a) (x': [n]a) : (b, [n]b) = + intrinsics.jmp2 f x x' --- | As `vjp2`, but accepts a vector of seed values. Semantically equivalent to --- mapping, but may be more efficient. If used with `#[unroll]`, adjoint --- calculations are unrolled when possible. -def vjp2_vec 'a 'b [n] (f: a -> b) (x: a) (y': [n]b) : (b, [n]a) = - intrinsics.vjp2_vec f x y' +-- | Matrix-Jacobian Product, returning also the primal result. As `vjp2`, but +-- accepts a vector of seed values. Semantically equivalent to mapping, but may +-- be more efficient. If used with `#[unroll]`, adjoint calculations are +-- unrolled when possible. +def mjp2 'a 'b [n] (f: a -> b) (x: a) (y': [n]b) : (b, [n]a) = + intrinsics.mjp2 f x y' -- | Jacobian-Vector Product ("forward mode"). def jvp 'a 'b (f: a -> b) (x: a) (x': a) : b = @@ -151,15 +154,15 @@ def jvp 'a 'b (f: a -> b) (x: a) (x': a) : b = def vjp 'a 'b (f: a -> b) (x: a) (y': b) : a = (vjp2 f x y').1 --- | As `jvp`, but accepts a vector of seed values. Semantically --- equivalent to mapping, but may be more efficient. -def jvp_vec 'a 'b [n] (f: a -> b) (x: a) (x': [n]a) : [n]b = - (jvp2_vec f x x').1 +-- | Jacobian-Matrix Product. As `jvp`, but accepts a vector of seed values. +-- Semantically equivalent to mapping, but may be more efficient. +def jmp 'a 'b [n] (f: a -> b) (x: a) (x': [n]a) : [n]b = + (jmp2 f x x').1 --- | As `vjp`, but accepts a vector of seed values. Semantically --- equivalent to mapping, but may be more efficient. -def vjp_vec 'a 'b [n] (f: a -> b) (x: a) (y': [n]b) : [n]a = - (vjp2_vec f x y').1 +-- | Matrix-Jacobian product. As `vjp`, but accepts a vector of seed values. +-- Semantically equivalent to mapping, but may be more efficient. +def mjp 'a 'b [n] (f: a -> b) (x: a) (y': [n]b) : [n]a = + (mjp2 f x y').1 -- | Provide custom reverse-mode adjoint code for a given function. This is -- useful when the adjoint synthesised by AD is not as good as one that is known diff --git a/src/Futhark/AD/Fwd.hs b/src/Futhark/AD/Fwd.hs index 39904b0e29..73f500b3eb 100644 --- a/src/Futhark/AD/Fwd.hs +++ b/src/Futhark/AD/Fwd.hs @@ -646,10 +646,9 @@ fwdJVP scope shape attrs (Lambda params _ body) = mkLambda (params <> params_tan) $ bodyBind =<< fwdBodyTansLast body --- Note [Forward-Mode vectorised AD] +-- Note [Forward-Mode vector AD] -- -- An primal variable of type 't' has a tangent of type '[tan_shape]t', where --- 'tan_shape' is the vector shape (which may be empty in the non-vectorised --- case). This requires some care for SOACs, which always map across the --- outermost dimension: basically we have to transpose the inputs and the --- outputs. +-- 'tan_shape' is the vector shape (which may be empty in the non-vector case). +-- This requires some care for SOACs, which always map across the outermost +-- dimension: basically we have to transpose the inputs and the outputs. diff --git a/src/Futhark/AD/Rev/Hist.hs b/src/Futhark/AD/Rev/Hist.hs index 1acb5d6772..f988eaeca5 100644 --- a/src/Futhark/AD/Rev/Hist.hs +++ b/src/Futhark/AD/Rev/Hist.hs @@ -241,7 +241,7 @@ diffMinMaxHist _ops x aux n minmax ne is vs w rf dst m = do m - locallyNonvectorised (x, dst, vs) $ do + locallyNonvector (x, dst, vs) $ do x_bar <- lookupAdjVal x x_ind_dst <- newParam (baseName x <> "_ind_param") $ Prim int64 @@ -419,7 +419,7 @@ diffMulHist _ops x aux n mul ne is vs w rf dst m = do m - locallyNonvectorised (x, dst, vs) $ do + locallyNonvector (x, dst, vs) $ do x_bar <- lookupAdjVal x lam_mul'' <- renameLambda lam_mul' @@ -504,7 +504,7 @@ diffAddHist _ops x aux n add ne is vs w rf dst m = do m - locallyNonvectorised (x, dst, vs) $ do + locallyNonvector (x, dst, vs) $ do x_bar <- lookupAdjVal x updateAdj dst x_bar @@ -798,7 +798,7 @@ diffHist ops xs aux n lam0 ne as w rf dst m = do m - locallyNonvectorised (xs, dst, lam0, as) $ do + locallyNonvector (xs, dst, lam0, as) $ do xs_bar <- traverse lookupAdjVal xs (dst_params, hp_params, f') <- mkF' lam0 dst_type $ head w diff --git a/src/Futhark/AD/Rev/Monad.hs b/src/Futhark/AD/Rev/Monad.hs index 41eabef7c5..173e872608 100644 --- a/src/Futhark/AD/Rev/Monad.hs +++ b/src/Futhark/AD/Rev/Monad.hs @@ -54,7 +54,7 @@ module Futhark.AD.Rev.Monad substLoopTape, renameLoopTape, -- - locallyNonvectorised, + locallyNonvector, vecToInner, ) where @@ -598,14 +598,14 @@ renameLoopTape = mapM_ (uncurry substLoopTape) . M.toList -- that computes each adjoint explicitly, then assembles the resulting adjoint -- vectors. This is useful for constructs (such as scans) where vectorised AD is -- impractical or inefficient. -locallyNonvectorised :: +locallyNonvector :: (FreeIn e) => -- | Something that represents all the free variables used in the action. -- Usually just an expression or statement. e -> ADM () -> ADM () -locallyNonvectorised e m = do +locallyNonvector e m = do adj_shape <- askShape if adj_shape == mempty then m diff --git a/src/Futhark/AD/Rev/Scan.hs b/src/Futhark/AD/Rev/Scan.hs index b6f8ca01b6..3ec61705d8 100644 --- a/src/Futhark/AD/Rev/Scan.hs +++ b/src/Futhark/AD/Rev/Scan.hs @@ -379,7 +379,7 @@ finalMapPPAD ops as scan = do eLambda op_bar_2 $ toExp . Var . paramName <$> par_y_right ++ par_a diffScan :: VjpOps -> [VName] -> SubExp -> [VName] -> Scan SOACS -> ADM () -diffScan ops ys w as scan = locallyNonvectorised (ys, scan, as) $ do +diffScan ops ys w as scan = locallyNonvector (ys, scan, as) $ do -- ys ~ results of scan, w ~ size of input array, as ~ (unzipped) -- arrays, scan ~ scan: operator with ne scan_case <- identifyCase ops $ scanLambda scan @@ -466,7 +466,7 @@ diffScanVec ops ys aux w lam ne as m = do foldr (vjpStm ops) m stmts diffScanAdd :: VjpOps -> VName -> SubExp -> Lambda SOACS -> SubExp -> VName -> ADM () -diffScanAdd _ops ys n lam' ne as = locallyNonvectorised (ys, lam', as) $ do +diffScanAdd _ops ys n lam' ne as = locallyNonvector (ys, lam', as) $ do lam <- renameLambda lam' ys_bar <- lookupAdjVal ys diff --git a/src/Futhark/Internalise/Exps.hs b/src/Futhark/Internalise/Exps.hs index 7c12d00267..7550402269 100644 --- a/src/Futhark/Internalise/Exps.hs +++ b/src/Futhark/Internalise/Exps.hs @@ -1869,7 +1869,7 @@ isIntrinsicFunction qname args = do handleAccs _ _ = Nothing handleAD [f, x, v] fname - | fname `elem` ["jvp2", "vjp2", "jvp2_vec", "vjp2_vec"] = Just $ \desc -> do + | fname `elem` ["jvp2", "vjp2", "jmp2", "mjp2"] = Just $ \desc -> do x' <- internaliseExp "ad_x" x v' <- internaliseExp "ad_v" v x_t <- subExpType $ head x' @@ -1879,9 +1879,9 @@ isIntrinsicFunction qname args = do case fname of "jvp2" -> JVP mempty x' v' lam "vjp2" -> VJP mempty x' v' lam - "jvp2_vec" -> + "jmp2" -> JVP (vecShape x_t v_t) x' v' lam - "vjp2_vec" -> + "mjp2" -> VJP (vecShape (head (lambdaReturnType lam)) v_t) x' v' lam _ -> error "handleAD: not supposed to happen." where diff --git a/src/Language/Futhark/Prop.hs b/src/Language/Futhark/Prop.hs index 7f79a005c7..eed67d707c 100644 --- a/src/Language/Futhark/Prop.hs +++ b/src/Language/Futhark/Prop.hs @@ -973,7 +973,7 @@ intrinsics = $ Scalar $ tupleRecord [Scalar $ t_b Nonunique, Scalar $ t_a Nonunique] ), - ( "jvp2_vec", + ( "jmp2", IntrinsicPolyFun [tp_a, tp_b, sp_n] [ Scalar (t_a mempty) `arr` Scalar (t_b Nonunique), @@ -987,7 +987,7 @@ intrinsics = array_b Unique $ shape [n] ] ), - ( "vjp2_vec", + ( "mjp2", IntrinsicPolyFun [tp_a, tp_b, sp_n] [ Scalar (t_a mempty) `arr` Scalar (t_b Nonunique), diff --git a/tests/ad/arr0.fut b/tests/ad/arr0.fut index 8db3185287..d1d09f0cea 100644 --- a/tests/ad/arr0.fut +++ b/tests/ad/arr0.fut @@ -14,7 +14,7 @@ entry fwd xs = ] entry fwd_vec xs = - jvp_vec primal xs [[1, 0], [0, 1]] + jmp primal xs [[1, 0], [0, 1]] -- == -- entry: rev diff --git a/tests/ad/arr1.fut b/tests/ad/arr1.fut index 584feaee15..9551a9ce83 100644 --- a/tests/ad/arr1.fut +++ b/tests/ad/arr1.fut @@ -12,4 +12,4 @@ entry fwd x y = ] entry fwd_vec x y = - jvp_vec primal (x, y) [(1, 0), (0, 1)] + jmp primal (x, y) [(1, 0), (0, 1)] diff --git a/tests/ad/concat0.fut b/tests/ad/concat0.fut index a3890dbefc..f9f652edc6 100644 --- a/tests/ad/concat0.fut +++ b/tests/ad/concat0.fut @@ -9,7 +9,7 @@ def f (xs: []f64) = xs ++ xs entry fwd_vec (xs: []f64) = let seeds = map (\i -> map (\j -> f64.bool (i == j)) (indices xs)) (indices xs) - in (jvp2_vec f xs seeds).1 + in (jmp2 f xs seeds).1 entry fwd_map (xs: []f64) = map (\i -> jvp f xs (map (\j -> f64.bool (i == j)) (indices xs))) diff --git a/tests/ad/consume0.fut b/tests/ad/consume0.fut index 4df14e3ed6..589bc4c916 100644 --- a/tests/ad/consume0.fut +++ b/tests/ad/consume0.fut @@ -19,9 +19,9 @@ entry rev [n] (xs: *[n]f64) = entry fwd_vec [n] (xs: *[n]f64) = #[unsafe] let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in jvp_vec f xs seeds + in jmp f xs seeds entry rev_vec [n] (xs: *[n]f64) = #[unsafe] let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in vjp_vec f xs seeds + in mjp f xs seeds diff --git a/tests/ad/consume1.fut b/tests/ad/consume1.fut index f5cd718edf..8c88d76483 100644 --- a/tests/ad/consume1.fut +++ b/tests/ad/consume1.fut @@ -20,9 +20,9 @@ entry rev [n] b (xs: *[n]f64) = entry fwd_vec [n] b (xs: *[n]f64) = #[unsafe] let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in jvp_vec (f b) xs seeds + in jmp (f b) xs seeds entry rev_vec [n] b (xs: *[n]f64) = #[unsafe] let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in vjp_vec (f b) xs seeds + in mjp (f b) xs seeds diff --git a/tests/ad/for1.fut b/tests/ad/for1.fut index 098079ef05..f356cd02c9 100644 --- a/tests/ad/for1.fut +++ b/tests/ad/for1.fut @@ -12,7 +12,7 @@ def pow_list [n] y (xs: [n]i32) = entry prim y xs = pow_list y xs -- == --- entry: f_vjp f_jvp f_vjp_vec f_jvp_vec +-- entry: f_vjp f_jvp f_mjp f_jmp -- input { 3 [1,2,3] } -- output { [[3,0,0], -- [0,12,0], @@ -24,11 +24,11 @@ entry f_jvp [n] y (xs: [n]i32) = entry f_vjp [n] y (xs: [n]i32) = tabulate n (\i -> vjp (pow_list y) xs (replicate n 0 with [i] = 1)) -entry f_jvp_vec [n] y (xs: [n]i32) = +entry f_jmp [n] y (xs: [n]i32) = let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in jvp_vec (pow_list y) xs seeds + in jmp (pow_list y) xs seeds |> transpose -entry f_vjp_vec [n] y (xs: [n]i32) = +entry f_mjp [n] y (xs: [n]i32) = let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in vjp_vec (pow_list y) xs seeds + in mjp (pow_list y) xs seeds diff --git a/tests/ad/for2.fut b/tests/ad/for2.fut index 6e56a84d8f..bb111cbf4b 100644 --- a/tests/ad/for2.fut +++ b/tests/ad/for2.fut @@ -12,16 +12,16 @@ def mult_list xs = entry prim = mult_list -- == --- entry: f_jvp f_vjp f_jvp_vec f_vjp_vec +-- entry: f_jvp f_vjp f_jmp f_mjp -- input { [11,5,13] } output { [0,0,26] } entry f_jvp [n] (xs: [n]i32) = tabulate n (\i -> jvp mult_list xs (replicate n 0 with [i] = 1)) entry f_vjp [n] (xs: [n]i32) = vjp mult_list xs 1 -entry f_jvp_vec [n] (xs: [n]i32) = +entry f_jmp [n] (xs: [n]i32) = let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in jvp_vec mult_list xs seeds + in jmp mult_list xs seeds -entry f_vjp_vec [n] (xs: [n]i32) = - (vjp_vec mult_list xs [1])[0] +entry f_mjp [n] (xs: [n]i32) = + (mjp mult_list xs [1])[0] diff --git a/tests/ad/for3.fut b/tests/ad/for3.fut index b3755c8178..42350aebbd 100644 --- a/tests/ad/for3.fut +++ b/tests/ad/for3.fut @@ -14,7 +14,7 @@ def square [n] (xs: [n]i32) = entry prim [n] (xs: [n]i32) = square xs -- == --- entry: f_jvp f_vjp f_jvp_vec f_vjp_vec +-- entry: f_jvp f_vjp f_jmp f_mjp -- input { [1,2,3,4,5] } -- output { [[2,0,0,0,0], -- [0,4,0,0,0], @@ -28,11 +28,11 @@ entry f_jvp [n] (xs: [n]i32) = entry f_vjp [n] (xs: [n]i32) = tabulate n (\i -> vjp square xs (replicate n 0 with [i] = 1)) -entry f_jvp_vec [n] (xs: [n]i32) = +entry f_jmp [n] (xs: [n]i32) = let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in jvp_vec square xs seeds + in jmp square xs seeds |> transpose -entry f_vjp_vec [n] (xs: [n]i32) = +entry f_mjp [n] (xs: [n]i32) = let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in vjp_vec square xs seeds + in mjp square xs seeds diff --git a/tests/ad/gather0.fut b/tests/ad/gather0.fut index 226622aa14..76d4b9727c 100644 --- a/tests/ad/gather0.fut +++ b/tests/ad/gather0.fut @@ -24,8 +24,8 @@ entry rev_J [n] [m] (xs: [n]f64) (is: [m]i64) = entry fwd_vec_J [n] [m] (xs: [n]f64) (is: [m]i64) = let seeds = tabulate n (\j -> replicate n 0 with [j] = 1) - in transpose (jvp_vec (`gather` is) xs seeds) + in transpose (jmp (`gather` is) xs seeds) entry rev_vec_J [n] [m] (xs: [n]f64) (is: [m]i64) = let seeds = tabulate m (\j -> replicate m 0 with [j] = 1) - in vjp_vec (`gather` is) xs seeds + in mjp (`gather` is) xs seeds diff --git a/tests/ad/gather1.fut b/tests/ad/gather1.fut index 34401c40f8..3ef72e17d1 100644 --- a/tests/ad/gather1.fut +++ b/tests/ad/gather1.fut @@ -46,7 +46,7 @@ entry rev_J [n] [m] [k] (xs: [n][m]f64) (is: [k]i64) = entry fwd_vec_J [n] [m] [k] (xs: [n][m]f64) (is: [k]i64) = let seeds = tabulate (n * m) (\p -> onehot_2d n m (p / m, p % m)) - in jvp_vec (`mapgather` is) xs seeds + in jmp (`mapgather` is) xs seeds |> unflatten |> map transpose |> map (map transpose) @@ -54,5 +54,5 @@ entry fwd_vec_J [n] [m] [k] (xs: [n][m]f64) (is: [k]i64) = entry rev_vec_J [n] [m] [k] (xs: [n][m]f64) (is: [k]i64) = let seeds = tabulate (n * k) (\p -> onehot_2d n k (p / k, p % k)) - in vjp_vec (`mapgather` is) xs seeds + in mjp (`mapgather` is) xs seeds |> unflatten diff --git a/tests/ad/gather2.fut b/tests/ad/gather2.fut index 0dd24c82bd..3920477bc8 100644 --- a/tests/ad/gather2.fut +++ b/tests/ad/gather2.fut @@ -35,11 +35,11 @@ entry rev_J [k] [n] [m] (xs: [k]f64) (iss: [n][m]i64) = entry fwd_vec_J [k] [n] [m] (xs: [k]f64) (iss: [n][m]i64) = let seeds = tabulate k (\i -> onehot k i) - in jvp_vec (`mapgather` iss) xs seeds + in jmp (`mapgather` iss) xs seeds |> transpose |> map transpose entry rev_vec_J [k] [n] [m] (xs: [k]f64) (iss: [n][m]i64) = let seeds = tabulate (n * m) (\p -> onehot_2d n m (p / m, p % m)) - in vjp_vec (`mapgather` iss) xs seeds + in mjp (`mapgather` iss) xs seeds |> unflatten diff --git a/tests/ad/hist_add.fut b/tests/ad/hist_add.fut index 25c5a1ee1b..26d275f320 100644 --- a/tests/ad/hist_add.fut +++ b/tests/ad/hist_add.fut @@ -40,7 +40,7 @@ entry fwd_vec [n] [m] (is: [n]i64) (vs: [n]f32) (c: [m]f32) = ( tabulate n ((i ==) >-> f32.bool) , tabulate m (((i - n) ==) >-> f32.bool) )) - in jvp_vec (f is) (vs, c) seeds + in jmp (f is) (vs, c) seeds |> transpose |> map split |> unzip @@ -51,5 +51,5 @@ entry rev_map [n] [m] (is: [n]i64) (vs: [n]f32) (c: [m]f32) = entry rev_vec [n] [m] (is: [n]i64) (vs: [n]f32) (c: [m]f32) = let seeds = tabulate m (\i -> replicate m 0 with [i] = 1) - in vjp_vec (f is) (vs, c) seeds + in mjp (f is) (vs, c) seeds |> unzip diff --git a/tests/ad/hist_complex.fut b/tests/ad/hist_complex.fut index 8d4af373f9..0f3042c90a 100644 --- a/tests/ad/hist_complex.fut +++ b/tests/ad/hist_complex.fut @@ -35,5 +35,5 @@ entry rev_map [n] [m] (is: [n]i64) (vs: [n]f32) (c: [m]f32) = entry rev_vec [n] [m] (is: [n]i64) (vs: [n]f32) (c: [m]f32) = let seeds = tabulate m (\i -> replicate m 0 with [i] = 1) - in vjp_vec (f is) (vs, c) seeds + in mjp (f is) (vs, c) seeds |> unzip diff --git a/tests/ad/hist_minmax.fut b/tests/ad/hist_minmax.fut index d6792209df..a21ecff07f 100644 --- a/tests/ad/hist_minmax.fut +++ b/tests/ad/hist_minmax.fut @@ -24,7 +24,7 @@ entry fwd_map [n] (k: i64) (is: [n]i64) (vs: [n]f32) = entry fwd_vec [n] (k: i64) (is: [n]i64) (vs: [n]f32) = let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in jvp_vec (primal k is) vs seeds + in jmp (primal k is) vs seeds |> transpose entry rev_map [n] (k: i64) (is: [n]i64) (vs: [n]f32) = @@ -32,4 +32,4 @@ entry rev_map [n] (k: i64) (is: [n]i64) (vs: [n]f32) = entry rev_vec [n] (k: i64) (is: [n]i64) (vs: [n]f32) = let seeds = tabulate k (\i -> replicate k 0 with [i] = 1) - in vjp_vec (primal k is) vs seeds + in mjp (primal k is) vs seeds diff --git a/tests/ad/hist_mul.fut b/tests/ad/hist_mul.fut index 08e08b1b0f..878aaaab47 100644 --- a/tests/ad/hist_mul.fut +++ b/tests/ad/hist_mul.fut @@ -36,5 +36,5 @@ entry rev_map [n] [m] (is: [n]i64) (vs: [n]f32) (c: [m]f32) = entry rev_vec [n] [m] (is: [n]i64) (vs: [n]f32) (c: [m]f32) = let seeds = tabulate m (\i -> replicate m 0 with [i] = 1) - in vjp_vec (f is) (vs, c) seeds + in mjp (f is) (vs, c) seeds |> unzip diff --git a/tests/ad/index.fut b/tests/ad/index.fut index c2480985ad..677160d33b 100644 --- a/tests/ad/index.fut +++ b/tests/ad/index.fut @@ -9,7 +9,7 @@ def f (i: i32) (xs: []f32) = xs[i] entry fwd_vec l (xs: []f32) : []f32 = let seeds = map (\i -> map (\j -> f32.bool (i == j)) (indices xs)) (indices xs) - in (jvp2_vec (f l) xs seeds).1 + in (jmp2 (f l) xs seeds).1 entry fwd_map l (xs: []f32) : []f32 = map (\i -> jvp (f l) xs (map (\j -> f32.bool (i == j)) (indices xs))) diff --git a/tests/ad/issue2256.fut b/tests/ad/issue2256.fut index 512d78af2e..f3d78219f0 100644 --- a/tests/ad/issue2256.fut +++ b/tests/ad/issue2256.fut @@ -16,4 +16,4 @@ entry fwd [m] (x: [m]f64) = entry fwd_vec [m] (x: [m]f64) = let seeds = tabulate m (\i -> replicate m 0 with [i] = 1) - in jvp_vec (\x' -> primal x') x seeds + in jmp (\x' -> primal x') x seeds diff --git a/tests/ad/map0.fut b/tests/ad/map0.fut index 4d5a125194..672c60ddf1 100644 --- a/tests/ad/map0.fut +++ b/tests/ad/map0.fut @@ -11,11 +11,11 @@ entry fwd_map [n] (xs: [n]f32) (ys: [n]f32) = entry fwd_vec [n] (xs: [n]f32) (ys: [n]f32) = let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in jvp_vec (prim xs) ys seeds + in jmp (prim xs) ys seeds entry rev_map [n] (xs: [n]f32) (ys: [n]f32) = transpose (tabulate n (\i -> vjp (prim xs) ys (replicate n 0 with [i] = 1))) entry rev_vec [n] (xs: [n]f32) (ys: [n]f32) = let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in transpose (vjp_vec (prim xs) ys seeds) + in transpose (mjp (prim xs) ys seeds) diff --git a/tests/ad/map1.fut b/tests/ad/map1.fut index 2acc256758..56467b1696 100644 --- a/tests/ad/map1.fut +++ b/tests/ad/map1.fut @@ -15,4 +15,4 @@ entry fwd [n] (xs: [n]f32) (ys: [n]f32) = entry fwd_vec [n] (xs: [n]f32) (ys: [n]f32) = let seeds = tabulate k (\i -> (replicate n 0 with [i] = 1, replicate n 0 with [i] = 1)) - in jvp_vec (uncurry prim) (xs, ys) seeds + in jmp (uncurry prim) (xs, ys) seeds diff --git a/tests/ad/map2.fut b/tests/ad/map2.fut index 8c96c57cae..59d35c2dab 100644 --- a/tests/ad/map2.fut +++ b/tests/ad/map2.fut @@ -18,4 +18,4 @@ entry rev_map [n] (c: f64) (xs: [n]f64) = entry rev_vec [n] (c: f64) (xs: [n]f64) = let seeds = tabulate n (\i -> onehot n i) - in vjp_vec (primal xs) c seeds + in mjp (primal xs) c seeds diff --git a/tests/ad/map3.fut b/tests/ad/map3.fut index 0a4aa89880..12caf5ac51 100644 --- a/tests/ad/map3.fut +++ b/tests/ad/map3.fut @@ -15,4 +15,4 @@ entry rev_map [n] (x: i32) (xs: [n]i32) = entry rev_vec [n] (x: i32) (xs: [n]i32) = let seeds = tabulate n (\i -> (replicate n 0 with [i] = 1)) - in vjp_vec (primal xs) x seeds + in mjp (primal xs) x seeds diff --git a/tests/ad/map4.fut b/tests/ad/map4.fut index e2660afada..b7b00e250a 100644 --- a/tests/ad/map4.fut +++ b/tests/ad/map4.fut @@ -24,7 +24,7 @@ entry fwd_map [n] (xs: [n]i32) = entry fwd_vec [n] (xs: [n]i32) = let seeds = tabulate n (\i -> onehot n i) - in jvp_vec primal xs seeds + in jmp primal xs seeds |> map transpose |> transpose |> map transpose @@ -34,4 +34,4 @@ entry rev_map [n] (xs: [n]i32) = entry rev_vec [n] (xs: [n]i32) = let seeds = tabulate_2d n n (\i j -> onehot_2d n n (i, j)) - in unflatten (vjp_vec primal xs (flatten seeds)) + in unflatten (mjp primal xs (flatten seeds)) diff --git a/tests/ad/map5.fut b/tests/ad/map5.fut index e90dd46328..3e64552ee9 100644 --- a/tests/ad/map5.fut +++ b/tests/ad/map5.fut @@ -19,8 +19,8 @@ entry rev_J [n] [m] (free: [n][m]i32) (is: [n]i32) = entry fwd_vec_J [n] [m] (free: [n][m]i32) (is: [n]i32) = let seeds = tabulate n (\i -> onehot n i) - in jvp_vec (f free) is seeds |> transpose + in jmp (f free) is seeds |> transpose entry rev_vec_J [n] [m] (free: [n][m]i32) (is: [n]i32) = let seeds = tabulate n (\i -> onehot n i) - in vjp_vec (f free) is seeds + in mjp (f free) is seeds diff --git a/tests/ad/map6.fut b/tests/ad/map6.fut index f46bb5a984..2748fe972d 100644 --- a/tests/ad/map6.fut +++ b/tests/ad/map6.fut @@ -32,8 +32,8 @@ entry rev_J (x: [8]f64) = entry fwd_vec_J (x: [8]f64) = let seeds = tabulate 8 (\i -> replicate 8 0 with [i] = 1) - in jvp_vec obj x seeds + in jmp obj x seeds entry rev_vec_J (x: [8]f64) = let seeds = tabulate 4 (\i -> replicate 4 0 with [i] = 1) - in transpose (vjp_vec obj x seeds) + in transpose (mjp obj x seeds) diff --git a/tests/ad/map7.fut b/tests/ad/map7.fut index 5d3da75fc7..ed642f0ca2 100644 --- a/tests/ad/map7.fut +++ b/tests/ad/map7.fut @@ -29,10 +29,10 @@ entry fwd_map (x: [8]f64) = tabulate 8 (\i -> jvp obj x (replicate 8 0 with [i] = 1)) entry fwd_vec (x: [8]f64) = - jvp_vec obj x (tabulate 8 (\i -> replicate 8 0 with [i] = 1)) + jmp obj x (tabulate 8 (\i -> replicate 8 0 with [i] = 1)) entry rev_map (x: [8]f64) = transpose (tabulate 4 (\i -> vjp obj x (replicate 4 0 with [i] = 1))) entry rev_vec (x: [8]f64) = - transpose (#[unroll] vjp_vec obj x (tabulate 4 (\i -> replicate 4 0 with [i] = 1))) + transpose (#[unroll] mjp obj x (tabulate 4 (\i -> replicate 4 0 with [i] = 1))) diff --git a/tests/ad/matmul.fut b/tests/ad/matmul.fut index 309096f811..676fbf6dcc 100644 --- a/tests/ad/matmul.fut +++ b/tests/ad/matmul.fut @@ -35,7 +35,7 @@ entry rev_J [n] [m] [p] (xss: [n][m]f64) (yss: [m][p]f64) = entry fwd_vec_J [n] [m] [p] (xss: [n][m]f64) (yss: [m][p]f64) = let seeds = tabulate (m * p) (\k -> onehot_2d m p (k / p, k % p)) - in jvp_vec (matmul xss) yss seeds + in jmp (matmul xss) yss seeds |> unflatten |> transpose |> map transpose @@ -43,5 +43,5 @@ entry fwd_vec_J [n] [m] [p] (xss: [n][m]f64) (yss: [m][p]f64) = entry rev_vec_J [n] [m] [p] (xss: [n][m]f64) (yss: [m][p]f64) = let seeds = tabulate (n * p) (\k -> onehot_2d n p (k / p, k % p)) - in vjp_vec (matmul xss) yss seeds + in mjp (matmul xss) yss seeds |> unflatten diff --git a/tests/ad/maximum.fut b/tests/ad/maximum.fut index 2e6f149d9e..1d1389f63d 100644 --- a/tests/ad/maximum.fut +++ b/tests/ad/maximum.fut @@ -17,4 +17,4 @@ entry fwd [n] (xs: [n]f64) = entry fwd_vec [n] (xs: [n]f64) = let seeds = tabulate n (\i -> tabulate n ((== i) >-> f64.bool)) - in jvp_vec f xs seeds + in jmp f xs seeds diff --git a/tests/ad/minimum.fut b/tests/ad/minimum.fut index a47e645895..a99c01b850 100644 --- a/tests/ad/minimum.fut +++ b/tests/ad/minimum.fut @@ -14,4 +14,4 @@ entry fwd [n] (xs: [n]f64) = entry fwd_vec [n] (xs: [n]f64) = let seeds = tabulate n (\i -> tabulate n ((== i) >-> f64.bool)) - in jvp_vec f64.minimum xs seeds + in jmp f64.minimum xs seeds diff --git a/tests/ad/minmax.fut b/tests/ad/minmax.fut index 17dabc103a..54ae27dbe5 100644 --- a/tests/ad/minmax.fut +++ b/tests/ad/minmax.fut @@ -21,4 +21,4 @@ entry fwd [n] (xs: [n]f64) = entry fwd_vec [n] (xs: [n]f64) = let seeds = tabulate n (\i -> tabulate n ((== i) >-> f64.bool)) - in unzip (jvp_vec f xs seeds) + in unzip (jmp f xs seeds) diff --git a/tests/ad/reduce-vec-minmax0.fut b/tests/ad/reduce-vec-minmax0.fut index ba6572a06e..336e9c7037 100644 --- a/tests/ad/reduce-vec-minmax0.fut +++ b/tests/ad/reduce-vec-minmax0.fut @@ -19,13 +19,13 @@ def forward_vec [n] [m] (arr: [m][n]f32) : [n][m][n]f32 = let i = p / n let j = p % n in replicate m (replicate n 0) with [i] = (replicate n 0 with [j] = 1)) - in jvp_vec redmap arr seeds + in jmp redmap arr seeds |> unflatten |> transpose def reverse_vec [n] [m] (arr: [m][n]f32) : [n][m][n]f32 = let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in vjp_vec redmap arr seeds + in mjp redmap arr seeds def main [n] [m] (arr: [m][n]f32) : bool = let l = n * m * n diff --git a/tests/ad/reduce0.fut b/tests/ad/reduce0.fut index 65822f3bfb..7f57f1f992 100644 --- a/tests/ad/reduce0.fut +++ b/tests/ad/reduce0.fut @@ -9,7 +9,7 @@ def f (xs: []f32) = f32.product xs entry fwd_vec (xs: []f32) : []f32 = let seeds = map (\i -> map (\j -> f32.bool (i == j)) (indices xs)) (indices xs) - in (jvp2_vec f xs seeds).1 + in (jmp2 f xs seeds).1 entry fwd_map (xs: []f32) : []f32 = map (\i -> jvp f xs (map (\j -> f32.bool (i == j)) (indices xs))) @@ -19,4 +19,4 @@ entry fwd_map (xs: []f32) : []f32 = -- enough already. entry rev_vec (xs: []f32) : []f32 = - head (vjp_vec f xs [1]) + head (mjp f xs [1]) diff --git a/tests/ad/reduce1.fut b/tests/ad/reduce1.fut index 25286aa9fb..c43b35def0 100644 --- a/tests/ad/reduce1.fut +++ b/tests/ad/reduce1.fut @@ -54,7 +54,7 @@ entry fwd_map [n] (input: [n][4]f32) : [4][n][4]f32 = entry fwd_vec [n] (input: [n][4]f32) : [4][n][4]f32 = let input = fromarrs input let seeds = tabulate (n * 4) (\i -> (fromarrs (onehot_2d n 4 (i / 4) (i % 4)))) - in jvp_vec primal input seeds + in jmp primal input seeds |> toarrs |> transpose |> map unflatten @@ -67,5 +67,5 @@ entry rev_map [n] (input: [n][4]f32) : [4][n][4]f32 = entry rev_vec [n] (input: [n][4]f32) : [4][n][4]f32 = let input = fromarrs input let seeds = tabulate 4 (\i -> fromarr (onehot_1d 4 i)) - in vjp_vec primal input seeds + in mjp primal input seeds |> map toarrs diff --git a/tests/ad/reduce2.fut b/tests/ad/reduce2.fut index 4a68f1761b..09d6ce980d 100644 --- a/tests/ad/reduce2.fut +++ b/tests/ad/reduce2.fut @@ -16,4 +16,4 @@ entry fwd x = map (jvp f x) [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] entry rev x = vjp f x 1f64 entry fwd_vec x = - jvp_vec f x [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] + jmp f x [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] diff --git a/tests/ad/reduce3.fut b/tests/ad/reduce3.fut index d070b6eb1a..4d4ad243f8 100644 --- a/tests/ad/reduce3.fut +++ b/tests/ad/reduce3.fut @@ -12,7 +12,7 @@ entry fwd_map [n] [k] (a: [n][k]f32) = tabulate n (\i -> jvp primal a (replicate n (replicate k 0) with [i] = replicate k 1)) entry fwd_vec [n] [k] (a: [n][k]f32) = - jvp_vec primal a (tabulate n (\i -> (replicate n (replicate k 0) with [i] = replicate k 1))) + jmp primal a (tabulate n (\i -> (replicate n (replicate k 0) with [i] = replicate k 1))) entry rev_vec [n] [k] (a: [n][k]f32) = - head (vjp_vec primal a [replicate k 1]) + head (mjp primal a [replicate k 1]) diff --git a/tests/ad/reduce_by_index0.fut b/tests/ad/reduce_by_index0.fut index 7663a48a69..42b617f4bb 100644 --- a/tests/ad/reduce_by_index0.fut +++ b/tests/ad/reduce_by_index0.fut @@ -1,6 +1,6 @@ -- == -- tags { autodiff } --- entry: f_jvp f_jvp_vec +-- entry: f_jvp f_jmp -- input { [0i64,1i64,2i64,3i64] [1f64,2f64,3f64,4f64] } -- output { [[1f64,0f64,0f64,0f64],[0f64,1f64,0f64,0f64],[0f64,0f64,1f64,0f64],[0f64,0f64,0f64,1f64]] } def f [n] (is: [n]i64) (vs: [n]f64) = @@ -10,7 +10,7 @@ entry f_jvp [n] (is: [n]i64) (vs: [n]f64) = tabulate n (\i -> jvp (f is) vs (replicate n 0 with [i] = 1)) |> transpose -entry f_jvp_vec [n] (is: [n]i64) (vs: [n]f64) = +entry f_jmp [n] (is: [n]i64) (vs: [n]f64) = let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in jvp_vec (f is) vs seeds + in jmp (f is) vs seeds |> transpose diff --git a/tests/ad/reducebyindex3.fut b/tests/ad/reducebyindex3.fut index dff88bc4b8..15ea98af85 100644 --- a/tests/ad/reducebyindex3.fut +++ b/tests/ad/reducebyindex3.fut @@ -20,7 +20,7 @@ entry rev [n] (is: [n]i64) (vs: [n]f64) = entry rev_vec [n] (is: [n]i64) (vs: [n]f64) = let seeds = tabulate 4 (\i -> replicate 4 0 with [i] = 1) - in vjp_vec (f is) vs seeds + in mjp (f is) vs seeds -- entry fwd [n] (is: [n]i64) (vs: [n]f64) = -- tabulate n (\i -> jvp (f is) vs (replicate n 0 with [i] = 1)) diff --git a/tests/ad/reducebyindex4.fut b/tests/ad/reducebyindex4.fut index ba36aac019..b61e0273f1 100644 --- a/tests/ad/reducebyindex4.fut +++ b/tests/ad/reducebyindex4.fut @@ -20,7 +20,7 @@ entry rev [n] (is: [n]i64) (vs0: [n]f32) (vs1: [n]f32) = |> unzip entry rev_vec [n] (is: [n]i64) (vs0: [n]f32) (vs1: [n]f32) = - (vjp_vec (f is) (zip vs0 vs1) [replicate 4 (1, 1)])[0] + (mjp (f is) (zip vs0 vs1) [replicate 4 (1, 1)])[0] |> unzip entry fwd [n] (is: [n]i64) (vs0: [n]f32) (vs1: [n]f32) = diff --git a/tests/ad/reducebyindexminmax3.fut b/tests/ad/reducebyindexminmax3.fut index 0dfb533ed0..3f29347471 100644 --- a/tests/ad/reducebyindexminmax3.fut +++ b/tests/ad/reducebyindexminmax3.fut @@ -15,6 +15,6 @@ entry rev [n] [m] (dst: [m]f32) (is: [n]i64) (vs: [n]f32) (c: f32) = vjp (red_max dst is) (vs, c) (replicate m 0 with [0] = 1) entry rev_vec [n] [m] (dst: [m]f32) (is: [n]i64) (vs: [n]f32) (c: f32) = - (vjp_vec (red_max dst is) (vs, c) [replicate m 0 with [0] = 1])[0] + (mjp (red_max dst is) (vs, c) [replicate m 0 with [0] = 1])[0] --tabulate n (\i -> vjp (red_max dst is) (vs, c) (replicate n 0 with [i] = 1)) diff --git a/tests/ad/reducebyindexminmax4.fut b/tests/ad/reducebyindexminmax4.fut index 3954e2a5cc..3214886b0d 100644 --- a/tests/ad/reducebyindexminmax4.fut +++ b/tests/ad/reducebyindexminmax4.fut @@ -15,6 +15,6 @@ entry rev [n] [m] (dst: [m]f32) (is: [n]i64) (vs: [n]f32) (c: f32) = vjp (red_max vs is) (dst, c) (replicate m 0 with [0] = 1) entry rev_vec [n] [m] (dst: [m]f32) (is: [n]i64) (vs: [n]f32) (c: f32) = - (vjp_vec (red_max vs is) (dst, c) [replicate m 0 with [0] = 1])[0] + (mjp (red_max vs is) (dst, c) [replicate m 0 with [0] = 1])[0] --tabulate n (\i -> vjp (red_max dst is) (vs, c) (replicate n 0 with [i] = 1)) diff --git a/tests/ad/reducebyindexminmax7.fut b/tests/ad/reducebyindexminmax7.fut index c17991b9ba..92a4739e33 100644 --- a/tests/ad/reducebyindexminmax7.fut +++ b/tests/ad/reducebyindexminmax7.fut @@ -16,9 +16,9 @@ entry fwd [n] [m] [k] (is: [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = entry rev_vec [n] [m] [k] (is: [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = let seeds = tabulate m (\i -> replicate m (replicate k 0) with [i] = replicate k 1) - in vjp_vec (primal is dst) vs seeds + in mjp (primal is dst) vs seeds entry fwd_vec [n] [m] [k] (is: [n]i64) (dst: [m][k]f32) (vs: [n][k]f32) = let seeds = tabulate n (\i -> replicate n (replicate k 0) with [i] = replicate k 1) - in jvp_vec (primal is dst) vs seeds + in jmp (primal is dst) vs seeds |> transpose diff --git a/tests/ad/reducebyindexminmax8.fut b/tests/ad/reducebyindexminmax8.fut index 6321dc3295..0c3fc324a1 100644 --- a/tests/ad/reducebyindexminmax8.fut +++ b/tests/ad/reducebyindexminmax8.fut @@ -14,11 +14,11 @@ def fwd2 [n] [m] [k] [l] (is: [n]i64) (dst: [m][k][l]f32) (vs: [n][k][l]f32) = def rev_vec2 [n] [m] [k] [l] (is: [n]i64) (dst: [m][k][l]f32) (vs: [n][k][l]f32) = let seeds = tabulate m (\i -> replicate m (replicate k (replicate l 0)) with [i] = replicate k (replicate l 1)) - in vjp_vec (primal2 is dst) vs seeds + in mjp (primal2 is dst) vs seeds def fwd_vec2 [n] [m] [k] [l] (is: [n]i64) (dst: [m][k][l]f32) (vs: [n][k][l]f32) = let seeds = tabulate n (\i -> replicate n (replicate k (replicate l 0)) with [i] = replicate k (replicate l 1)) - in jvp_vec (primal2 is dst) vs seeds + in jmp (primal2 is dst) vs seeds |> transpose def main [n] [m] [k] [l] (is': [n]i64) (dst: [m][k][l]f32) (vs: [n][k][l]f32) = diff --git a/tests/ad/reducemul0.fut b/tests/ad/reducemul0.fut index 09dd2f5743..c9b4559d5a 100644 --- a/tests/ad/reducemul0.fut +++ b/tests/ad/reducemul0.fut @@ -14,4 +14,4 @@ entry fwd [n] (xs: [n]f32) = entry fwd_vec [n] (xs: [n]f32) = let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in jvp_vec red_mult xs seeds + in jmp red_mult xs seeds diff --git a/tests/ad/reducemul4.fut b/tests/ad/reducemul4.fut index 4408becbc9..d4a4e1664d 100644 --- a/tests/ad/reducemul4.fut +++ b/tests/ad/reducemul4.fut @@ -16,8 +16,8 @@ entry rev [n] (as: [n]f32) = entry fwd_vec [n] (as: [n]f32) = let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in jvp_vec fun as seeds |> transpose + in jmp fun as seeds |> transpose entry rev_vec [n] (as: [n]f32) = let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in vjp_vec fun as seeds + in mjp fun as seeds diff --git a/tests/ad/reducevec0.fut b/tests/ad/reducevec0.fut index fc2ad0c0f5..24a9dc1fe6 100644 --- a/tests/ad/reducevec0.fut +++ b/tests/ad/reducevec0.fut @@ -24,7 +24,7 @@ entry fwd_vec [n] [m] [k] (xs: [n][m][k]f32) : [m][k][n][m][k]f32 = let j = (p % (m * k)) / k let l = p % k in replicate n (replicate m (replicate k 0)) with [i] = (replicate m (replicate k 0) with [j] = (replicate k 0 with [l] = 1))) - let res = jvp_vec f xs seeds + let res = jmp f xs seeds in unflatten (sized (n * (m * k)) res) |> map unflatten |> transpose |> map transpose @@ -34,4 +34,4 @@ entry rev_vec [n] [m] [k] (xs: [n][m][k]f32) : [m][k][n][m][k]f32 = let i = p / k let j = p % k in replicate m (replicate k 0) with [i] = (replicate k 0 with [j] = 1)) - in unflatten (vjp_vec f xs seeds) + in unflatten (mjp f xs seeds) diff --git a/tests/ad/replicate0.fut b/tests/ad/replicate0.fut index d74ee052a6..e4e7b20713 100644 --- a/tests/ad/replicate0.fut +++ b/tests/ad/replicate0.fut @@ -11,7 +11,7 @@ def f (n: i64) (xs: []f64) = replicate n xs entry fwd_vec n (xs: []f64) = let seeds = map (\i -> map (\j -> f64.bool (i == j)) (indices xs)) (indices xs) - in (jvp2_vec (f n) xs seeds).1 + in (jmp2 (f n) xs seeds).1 entry fwd_map n (xs: []f64) = map (\i -> jvp (f n) xs (map (\j -> f64.bool (i == j)) (indices xs))) diff --git a/tests/ad/reshape0.fut b/tests/ad/reshape0.fut index e26b2ca472..1b42b89c42 100644 --- a/tests/ad/reshape0.fut +++ b/tests/ad/reshape0.fut @@ -11,7 +11,7 @@ entry fwd_map n m (xs: [n * m]i32) = entry fwd_vec n m (xs: [n * m]i32) = let seeds = tabulate 2 (\i -> replicate (n * m) 0 with [i] = 1) - in jvp_vec unflatten xs seeds + in jmp unflatten xs seeds -- == -- entry: f_vjp diff --git a/tests/ad/scan0.fut b/tests/ad/scan0.fut index 404687a08e..b539a4a32e 100644 --- a/tests/ad/scan0.fut +++ b/tests/ad/scan0.fut @@ -20,8 +20,8 @@ entry rev_J [n] (a: [n]f32) = entry fwd_vec_J [n] (a: [n]f32) = let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in jvp_vec (scan (*) 1) a seeds |> transpose + in jmp (scan (*) 1) a seeds |> transpose entry rev_vec_J [n] (a: [n]f32) = let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in vjp_vec (scan (*) 1) a seeds + in mjp (scan (*) 1) a seeds diff --git a/tests/ad/scan1.fut b/tests/ad/scan1.fut index 73ac81b592..5debeb1e09 100644 --- a/tests/ad/scan1.fut +++ b/tests/ad/scan1.fut @@ -20,8 +20,8 @@ entry rev_J [n] (a: [n]f32) = entry fwd_vec_J [n] (a: [n]f32) = let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in jvp_vec (scan (+) 0) a seeds |> transpose + in jmp (scan (+) 0) a seeds |> transpose entry rev_vec_J [n] (a: [n]f32) = let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in vjp_vec (scan (+) 0) a seeds + in mjp (scan (+) 0) a seeds diff --git a/tests/ad/scan2.fut b/tests/ad/scan2.fut index 1ce55e3ec9..ea2d39fc3a 100644 --- a/tests/ad/scan2.fut +++ b/tests/ad/scan2.fut @@ -18,8 +18,8 @@ entry rev_J [n] [k] (a: [n][k]f32) = entry fwd_vec_J [n] [k] (a: [n][k]f32) = let seeds = tabulate n (\i -> replicate n (replicate k 0) with [i] = replicate k 1) - in jvp_vec primal a seeds |> transpose + in jmp primal a seeds |> transpose entry rev_vec_J [n] [k] (a: [n][k]f32) = let seeds = tabulate n (\i -> replicate n (replicate k 0) with [i] = replicate k 1) - in vjp_vec primal a seeds + in mjp primal a seeds diff --git a/tests/ad/scan3.fut b/tests/ad/scan3.fut index ac07c3595c..35a8890000 100644 --- a/tests/ad/scan3.fut +++ b/tests/ad/scan3.fut @@ -57,7 +57,7 @@ entry rev_J [n] (input: [n][4]f32) : [n][4][n][4]f32 = entry fwd_vec_J [n] (input: [n][4]f32) : [n][4][n][4]f32 = let input = fromarrs input let seeds = tabulate (n * 4) (\i -> fromarrs (onehot_2d n 4 (i / 4) (i % 4))) - in jvp_vec primal input seeds + in jmp primal input seeds |> map toarrs |> transpose |> map transpose @@ -66,6 +66,6 @@ entry fwd_vec_J [n] (input: [n][4]f32) : [n][4][n][4]f32 = entry rev_vec_J [n] (input: [n][4]f32) : [n][4][n][4]f32 = let input = fromarrs input let seeds = tabulate (n * 4) (\i -> fromarrs (onehot_2d n 4 (i / 4) (i % 4))) - in vjp_vec primal input seeds + in mjp primal input seeds |> unflatten |> map (map toarrs) diff --git a/tests/ad/scan4.fut b/tests/ad/scan4.fut index 8d95b8440c..4f414bc333 100644 --- a/tests/ad/scan4.fut +++ b/tests/ad/scan4.fut @@ -35,12 +35,12 @@ entry rev_J [n] (input: [n][3]f32) = entry fwd_vec_J [n] (input: [n][3]f32) = let input = fromarrs input let seeds = tabulate n (\i -> replicate n (0, 0, 0) with [i] = (1, 1, 1)) - in jvp_vec primal input seeds + in jmp primal input seeds |> map toarrs |> transpose entry rev_vec_J [n] (input: [n][3]f32) = let input = fromarrs input let seeds = tabulate n (\i -> replicate n (0, 0, 0) with [i] = (1, 1, 1)) - in vjp_vec primal input seeds + in mjp primal input seeds |> map toarrs diff --git a/tests/ad/scan5.fut b/tests/ad/scan5.fut index 05295e13b5..f9e4aa0130 100644 --- a/tests/ad/scan5.fut +++ b/tests/ad/scan5.fut @@ -24,8 +24,8 @@ entry rev_J [n] [k] (a: [n][k]f32) = entry fwd_vec_J [n] [k] (a: [n][k]f32) = let seeds = tabulate n (\i -> replicate n (replicate k 0) with [i] = replicate k 1) - in jvp_vec primal a seeds |> transpose + in jmp primal a seeds |> transpose entry rev_vec_J [n] [k] (a: [n][k]f32) = let seeds = tabulate n (\i -> replicate n (replicate k 0) with [i] = replicate k 1) - in vjp_vec primal a seeds + in mjp primal a seeds diff --git a/tests/ad/scan6.fut b/tests/ad/scan6.fut index 33be743c1a..fc0fde8be7 100644 --- a/tests/ad/scan6.fut +++ b/tests/ad/scan6.fut @@ -86,7 +86,7 @@ entry rev_J2 [n] (input: [n][6]f32) : [n][6][n][6]f32 = entry fwd_vec_J [n] (input: [n][2]f32) = let input = fromarrs input let seeds = tabulate (n * 2) (\i -> fromarrs (onehot_2d n 2 (i / 2) (i % 2))) - in jvp_vec primal input seeds + in jmp primal input seeds |> map toarrs |> transpose |> map transpose @@ -95,14 +95,14 @@ entry fwd_vec_J [n] (input: [n][2]f32) = entry rev_vec_J [n] (input: [n][2]f32) = let input = fromarrs input let seeds = tabulate (n * 2) (\i -> fromarrs (onehot_2d n 2 (i / 2) (i % 2))) - in vjp_vec primal input seeds + in mjp primal input seeds |> unflatten |> map (map toarrs) entry fwd_vec_J2 [n] (input: [n][6]f32) : [n][6][n][6]f32 = let input = fromarrs2 input let seeds = tabulate (n * 6) (\i -> fromarrs2 (onehot_2d n 6 (i / 6) (i % 6))) - in jvp_vec primal2 input seeds + in jmp primal2 input seeds |> map toarrs2 |> transpose |> map transpose @@ -111,6 +111,6 @@ entry fwd_vec_J2 [n] (input: [n][6]f32) : [n][6][n][6]f32 = entry rev_vec_J2 [n] (input: [n][6]f32) : [n][6][n][6]f32 = let input = fromarrs2 input let seeds = tabulate (n * 6) (\i -> fromarrs2 (onehot_2d n 6 (i / 6) (i % 6))) - in vjp_vec primal2 input seeds + in mjp primal2 input seeds |> unflatten |> map (map toarrs2) diff --git a/tests/ad/scan7.fut b/tests/ad/scan7.fut index 6b44ed1f14..cfb967878b 100644 --- a/tests/ad/scan7.fut +++ b/tests/ad/scan7.fut @@ -114,7 +114,7 @@ entry fwd_vec_J [n] [m] [k] (input: [n][m][k]f32) = let j = (p % (m * k)) / k let q = p % k in replicate n (replicate m (replicate k 0)) with [i] = (replicate m (replicate k 0) with [j] = (replicate k 0 with [q] = 1))) - let res = jvp_vec primal input seeds + let res = jmp primal input seeds in unflatten (sized (n * (m * k)) res) |> map unflatten entry rev_vec_J [n] [m] [k] (input: [n][m][k]f32) = @@ -123,7 +123,7 @@ entry rev_vec_J [n] [m] [k] (input: [n][m][k]f32) = let j = (p % (m * k)) / k let q = p % k in replicate n (replicate m (replicate k 0)) with [i] = (replicate m (replicate k 0) with [j] = (replicate k 0 with [q] = 1))) - let res = vjp_vec primal input seeds + let res = mjp primal input seeds |> (\x -> unflatten (sized (n * (m * k)) x)) |> map unflatten let a = res |> map (map transpose) |> map (map (map transpose)) |> map (map (map (map transpose))) let a2 = a |> map transpose |> map (map transpose) |> map (map (map transpose)) diff --git a/tests/ad/scan8.fut b/tests/ad/scan8.fut index 806a2bdf92..8a8c14317a 100644 --- a/tests/ad/scan8.fut +++ b/tests/ad/scan8.fut @@ -180,7 +180,7 @@ entry rev [n] (input: [n][9]f32) : [n][9][n][9]f32 = entry fwd_vec [n] (input: [n][9]f32) : [n][9][n][9]f32 = let input = fromarrs3 input let seeds = tabulate (n * 9) (\i -> fromarrs3 (onehot_2d n 9 (i / 9) (i % 9))) - in jvp_vec primal3 input seeds + in jmp primal3 input seeds |> map toarrs3 |> transpose |> map transpose @@ -189,6 +189,6 @@ entry fwd_vec [n] (input: [n][9]f32) : [n][9][n][9]f32 = entry rev_vec [n] (input: [n][9]f32) : [n][9][n][9]f32 = let input = fromarrs3 input let seeds = tabulate (n * 9) (\i -> fromarrs3 (onehot_2d n 9 (i / 9) (i % 9))) - in vjp_vec primal3 input seeds + in mjp primal3 input seeds |> unflatten |> map (map toarrs3) diff --git a/tests/ad/scan9.fut b/tests/ad/scan9.fut index 19de6afd69..37c50ff9e0 100644 --- a/tests/ad/scan9.fut +++ b/tests/ad/scan9.fut @@ -478,7 +478,7 @@ entry rev [n] (input: [n][16]f32) : [n][16][n][16]f32 = entry fwd_vec [n] (input: [n][16]f32) : [n][16][n][16]f32 = let input = fromarrs2 input let seeds = tabulate (n * 16) (\i -> fromarrs2 (onehot_2d n 16 (i / 16) (i % 16))) - in jvp_vec primal2 input seeds + in jmp primal2 input seeds |> map toarrs2 |> transpose |> map transpose @@ -487,6 +487,6 @@ entry fwd_vec [n] (input: [n][16]f32) : [n][16][n][16]f32 = entry rev_vec [n] (input: [n][16]f32) : [n][16][n][16]f32 = let input = fromarrs2 input let seeds = tabulate (n * 16) (\i -> fromarrs2 (onehot_2d n 16 (i / 16) (i % 16))) - in vjp_vec primal2 input seeds + in mjp primal2 input seeds |> unflatten |> map (map toarrs2) diff --git a/tests/ad/scatter0.fut b/tests/ad/scatter0.fut index e968b3dd63..b9ed86fb0f 100644 --- a/tests/ad/scatter0.fut +++ b/tests/ad/scatter0.fut @@ -30,8 +30,8 @@ entry rev [n] [k] (xs: [k]f64) (is: [n]i64) (vs: [n]f64) = entry fwd_vec [n] [k] (xs: [k]f64) (is: [n]i64) (vs: [n]f64) = let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in jvp_vec (\vs -> f xs is vs) vs seeds + in jmp (\vs -> f xs is vs) vs seeds entry rev_vec [n] [k] (xs: [k]f64) (is: [n]i64) (vs: [n]f64) = let seeds = tabulate n (\i -> replicate k 0 with [i] = 1) - in vjp_vec (\vs -> f xs is vs) vs seeds + in mjp (\vs -> f xs is vs) vs seeds diff --git a/tests/ad/scatter1.fut b/tests/ad/scatter1.fut index 8217639cfc..54b0c05adf 100644 --- a/tests/ad/scatter1.fut +++ b/tests/ad/scatter1.fut @@ -23,8 +23,8 @@ entry rev [n] [k] (xs: [k]f64) (is: [n]i64) (vs: [n]f64) = entry fwd_vec [n] [k] (xs: [k]f64) (is: [n]i64) (vs: [n]f64) = let seeds = tabulate k (\i -> replicate k 0 with [i] = 1) - in jvp_vec (\xs -> f xs is vs) xs seeds + in jmp (\xs -> f xs is vs) xs seeds entry rev_vec [n] [k] (xs: [k]f64) (is: [n]i64) (vs: [n]f64) = let seeds = tabulate k (\i -> replicate k 0 with [i] = 1) - in vjp_vec (\xs -> f xs is vs) xs seeds + in mjp (\xs -> f xs is vs) xs seeds diff --git a/tests/ad/stripmine1.fut b/tests/ad/stripmine1.fut index 31756734d4..09309f3b0f 100644 --- a/tests/ad/stripmine1.fut +++ b/tests/ad/stripmine1.fut @@ -15,7 +15,7 @@ def square [n] (xs: [n]i32) = entry prim [n] (xs: [n]i32) = square xs -- == --- entry: f_jvp f_vjp f_jvp_vec f_vjp_vec +-- entry: f_jvp f_vjp f_jmp f_mjp -- input { [1,2,3,4,5] } -- output { [[2,0,0,0,0], -- [0,4,0,0,0], @@ -29,11 +29,11 @@ entry f_jvp [n] (xs: [n]i32) = entry f_vjp [n] (xs: [n]i32) = tabulate n (\i -> vjp square xs (replicate n 0 with [i] = 1)) -entry f_jvp_vec [n] (xs: [n]i32) = +entry f_jmp [n] (xs: [n]i32) = let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in jvp_vec square xs seeds + in jmp square xs seeds |> transpose -entry f_vjp_vec [n] (xs: [n]i32) = +entry f_mjp [n] (xs: [n]i32) = let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in vjp_vec square xs seeds + in mjp square xs seeds diff --git a/tests/ad/stripmine2.fut b/tests/ad/stripmine2.fut index b6516a075b..d256314b12 100644 --- a/tests/ad/stripmine2.fut +++ b/tests/ad/stripmine2.fut @@ -13,7 +13,7 @@ def pow_list [n] y (xs: [n]i32) = entry prim y xs = pow_list y xs -- == --- entry: f_vjp f_jvp f_vjp_vec f_jvp_vec +-- entry: f_vjp f_jvp f_mjp f_jmp -- input { 3 [1,2,3] } -- output { [[3,0,0], -- [0,12,0], @@ -25,11 +25,11 @@ entry f_jvp [n] y (xs: [n]i32) = entry f_vjp [n] y (xs: [n]i32) = tabulate n (\i -> vjp (pow_list y) xs (replicate n 0 with [i] = 1)) -entry f_jvp_vec [n] y (xs: [n]i32) = +entry f_jmp [n] y (xs: [n]i32) = let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in jvp_vec (pow_list y) xs seeds + in jmp (pow_list y) xs seeds |> transpose -entry f_vjp_vec [n] y (xs: [n]i32) = +entry f_mjp [n] y (xs: [n]i32) = let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in vjp_vec (pow_list y) xs seeds + in mjp (pow_list y) xs seeds diff --git a/tests/ad/sum.fut b/tests/ad/sum.fut index 8477f66dad..0cd00c1b14 100644 --- a/tests/ad/sum.fut +++ b/tests/ad/sum.fut @@ -16,4 +16,4 @@ entry fwd [n] (xs: [n]f64) = entry fwd_vec [n] (xs: [n]f64) = let seeds = tabulate n (\i -> tabulate n ((== i) >-> f64.bool)) - in jvp_vec sum xs seeds + in jmp sum xs seeds diff --git a/tests/ad/transpose.fut b/tests/ad/transpose.fut index f085690a69..fa6fd0feb8 100644 --- a/tests/ad/transpose.fut +++ b/tests/ad/transpose.fut @@ -9,7 +9,7 @@ def f (xs: [][]f64) = transpose xs entry fwd_vec [n] [m] (xs: [n][m]f64) = let seeds = tabulate (n * m) (\i -> tabulate (n * m) (\j -> f64.bool (i == j)) |> unflatten) - in (jvp2_vec f xs seeds).1 + in (jmp2 f xs seeds).1 entry fwd_map [n] [m] (xs: [n][m]f64) = tabulate (n * m) diff --git a/tests/ad/truedep0.fut b/tests/ad/truedep0.fut index 813cf66903..3680459d41 100644 --- a/tests/ad/truedep0.fut +++ b/tests/ad/truedep0.fut @@ -12,7 +12,7 @@ def test [n] (xs: [n]i32) = entry prim [n] (xs: [n]i32) = test xs -- == --- entry: f_jvp f_vjp f_jvp_vec f_vjp_vec +-- entry: f_jvp f_vjp f_jmp f_mjp -- input { [1,2,3,4,5] } -- output { [[1,0,0,0,0], -- [2,0,0,0,0], @@ -26,11 +26,11 @@ entry f_jvp [n] (xs: [n]i32) = entry f_vjp [n] (xs: [n]i32) = tabulate n (\i -> vjp test xs (replicate n 0 with [i] = 1)) -entry f_jvp_vec [n] (xs: [n]i32) = +entry f_jmp [n] (xs: [n]i32) = let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in jvp_vec test xs seeds + in jmp test xs seeds |> transpose -entry f_vjp_vec [n] (xs: [n]i32) = +entry f_mjp [n] (xs: [n]i32) = let seeds = tabulate n (\i -> replicate n 0 with [i] = 1) - in vjp_vec test xs seeds + in mjp test xs seeds From a187ad43623e2785d094384424f2166357b6ace4 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 12 Jun 2026 14:24:29 +0200 Subject: [PATCH 57/70] Also update interpreter. --- src/Language/Futhark/Interpreter.hs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index d0811c3e19..7e5fc09cf9 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -2154,13 +2154,13 @@ initialCtx = def "manifest" = Just $ fun1 pure def "jvp2" = Just $ fun3 doJVP2 def "vjp2" = Just $ fun3 doVJP2 - def "jvp2_vec" = Just $ fun3 $ \f x seeds -> do + def "jmp2" = Just $ fun3 $ \f x seeds -> do v <- apply noLoc mempty f x dvs <- toArray' (valueShape v) . map (project "1") <$> mapM (doJVP2 f x) (snd (fromArray seeds)) pure $ toTuple [v, dvs] - def "vjp2_vec" = Just $ fun3 $ \f x seeds -> do + def "mjp2" = Just $ fun3 $ \f x seeds -> do v <- apply noLoc mempty f x dvs <- toArray' (valueShape x) . map (project "1") From a21e7ff804118fec13ce22635d4f0fdbf0db2cb9 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 12 Jun 2026 14:31:44 +0200 Subject: [PATCH 58/70] Add failing test. --- tests/ad/custom/radixsort.fut | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/ad/custom/radixsort.fut b/tests/ad/custom/radixsort.fut index 8d21661fb0..ab10dd7727 100644 --- a/tests/ad/custom/radixsort.fut +++ b/tests/ad/custom/radixsort.fut @@ -5,6 +5,11 @@ -- input { [4f32,3f32,2f32,1f32] [0.1f32,0.2f32,0.3f32,0.4f32] } -- output { [0.4f32, 0.3f32, 0.2f32, 0.1f32 ] } +-- == +-- entry: main_custom_vec +-- input { [4f32,3f32,2f32,1f32] [[0.1f32,0.2f32,0.3f32,0.4f32],[0.5f32,0.6f32,0.7f32,0.8f32]] } +-- output { [[0.4f32, 0.3f32, 0.2f32, 0.1f32], [0.8f32, 0.7f32, 0.6f32, 0.5f32]] } + def radix_sort_step [n] 't (f: t -> u32) (xs: [n]t) (b: i32) : [n]t = let bits = map (\x -> (i32.u32 (f x >> u32.i32 b)) & 1) xs let bits_neg = map (1 -) bits @@ -28,3 +33,4 @@ def differentiable_radix_sort [n] 't (f: t -> u32) (xs: [n]t) = entry main_standard = vjp (radix_sort f32.to_bits) entry main_custom = vjp (differentiable_radix_sort f32.to_bits) +entry main_custom_vec = mjp (differentiable_radix_sort f32.to_bits) From 85ecf8a3fb9a0a90e0280014e70a376459909456 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 12 Jun 2026 14:39:18 +0200 Subject: [PATCH 59/70] Fix typo in comment. --- src/Futhark/AD/Rev/Monad.hs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Futhark/AD/Rev/Monad.hs b/src/Futhark/AD/Rev/Monad.hs index 173e872608..23ae80def6 100644 --- a/src/Futhark/AD/Rev/Monad.hs +++ b/src/Futhark/AD/Rev/Monad.hs @@ -630,7 +630,7 @@ locallyNonvector e m = do -- | If we are doing vectorised AD, then transpose the array to bring the vector -- shape outermost. -- --- That, convers @[vec...][shape...][elem...]@ to @[shape...][vec...][elem...]@. +-- That is, convers @[vec...][shape...][elem...]@ to @[shape...][vec...][elem...]@. vecToInner :: VName -> ADM VName vecToInner v = do adj_shape <- askShape From 21260c51ae9b7d626554d48196c55a4284e29eca Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 12 Jun 2026 15:41:13 +0200 Subject: [PATCH 60/70] Handle with_vjp in vector mode. --- src/Futhark/AD/Rev/Monad.hs | 2 +- src/Futhark/AD/Rev/SOAC.hs | 15 ++++++++------- src/Futhark/AD/Shared.hs | 18 +++++++++++------- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/src/Futhark/AD/Rev/Monad.hs b/src/Futhark/AD/Rev/Monad.hs index 23ae80def6..0849d611c8 100644 --- a/src/Futhark/AD/Rev/Monad.hs +++ b/src/Futhark/AD/Rev/Monad.hs @@ -614,7 +614,7 @@ locallyNonvector e m = do -- only consider those that actually have known nonzero adjoints. e_adjs <- filterM knownAdjoint e_free e_adjs_vals <- mapM lookupAdjVal e_adjs - e_free_adjs <- mkMap "nonvec_adj" e_adjs_vals $ \e_adjs_vals' -> do + e_free_adjs <- mkMap "nonvec_adj" adj_shape e_adjs_vals $ \e_adjs_vals' -> do zipWithM_ insAdj e_adjs e_adjs_vals' local (\env -> env {envAdjShape = mempty}) m mapM lookupAdjVal e_free diff --git a/src/Futhark/AD/Rev/SOAC.hs b/src/Futhark/AD/Rev/SOAC.hs index d0e487161e..be7e07bb9a 100644 --- a/src/Futhark/AD/Rev/SOAC.hs +++ b/src/Futhark/AD/Rev/SOAC.hs @@ -194,13 +194,14 @@ vjpSOAC _ops pat aux (WithVJP args lam lam_adj) m = do forM_ (zip (patNames pat) lam_res) $ \(v, SubExpRes cs se) -> certifying cs $ letBindNames [v] $ BasicOp $ SubExp se m - pat_adj <- mapM lookupAdjVal $ patNames pat - contribs <- - eLambda lam_adj (map (eSubExp . resSubExp) lam_res ++ map (eSubExp . Var) pat_adj) - forM_ (zip args contribs) $ \(arg, contrib) -> - (updateSubExpAdj arg <=< letExp "contrib") $ - BasicOp . SubExp . resSubExp $ - contrib + locallyNonvector (patNames pat, args) $ do + pat_adj <- mapM lookupAdjVal $ patNames pat + contribs <- + eLambda lam_adj (map (eSubExp . resSubExp) lam_res ++ map (eSubExp . Var) pat_adj) + forM_ (zip args contribs) $ \(arg, contrib) -> + (updateSubExpAdj arg <=< letExp "contrib") $ + BasicOp . SubExp . resSubExp $ + contrib vjpSOAC _ _ _ soac _ = error $ "vjpSOAC unhandled:\n" ++ prettyString soac diff --git a/src/Futhark/AD/Shared.hs b/src/Futhark/AD/Shared.hs index 35c69ce24e..dc0c6b96ed 100644 --- a/src/Futhark/AD/Shared.hs +++ b/src/Futhark/AD/Shared.hs @@ -40,16 +40,20 @@ mapNest shape x f = do =<< f (fmap (Var . paramName) x_p) Op . Screma w (toList x_v) <$> mapSOAC lam +-- | Construct a map over the given arrays, which must have the provided outer +-- shape. The purpose of the 'Shape' argument is to handle the case where no +-- arrays are provided. mkMap :: (MonadBuilder m, Rep m ~ SOACS, Traversable f) => Name -> + Shape -> f VName -> + -- | Action for building the body, passed names + -- corresponding to elements of the arrays. (f VName -> m [VName]) -> m [VName] -mkMap desc arrs f - | null arrs = pure [] - | otherwise = do - w <- arraySize 0 <$> lookupType (head $ toList arrs) - x_p <- traverse (newParam "xp" . rowType <=< lookupType) arrs - lam <- mkLambda (toList x_p) $ varsRes <$> f (fmap paramName x_p) - letTupExp desc . Op . Screma w (toList arrs) =<< mapSOAC lam +mkMap desc shape arrs f = do + let w = shapeSize 0 shape + x_p <- traverse (newParam "xp" . rowType <=< lookupType) arrs + lam <- mkLambda (toList x_p) $ varsRes <$> f (fmap paramName x_p) + letTupExp desc . Op . Screma w (toList arrs) =<< mapSOAC lam From 36ca87aa2ac67c88868727ed8a0e65f3ef8f11e9 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 12 Jun 2026 21:40:48 +0200 Subject: [PATCH 61/70] Nomenclature fixes. --- src/Futhark/AD/Fwd.hs | 4 ++-- src/Futhark/AD/Rev/Monad.hs | 14 +++++++------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/Futhark/AD/Fwd.hs b/src/Futhark/AD/Fwd.hs index 73f500b3eb..38656cefa2 100644 --- a/src/Futhark/AD/Fwd.hs +++ b/src/Futhark/AD/Fwd.hs @@ -500,8 +500,8 @@ fwdSOAC pat aux (Stream size xs accs lam) = do let accs' = interleave accs accs_tan addStm $ Let pat' aux $ Op $ Stream size xs' accs' lam' fwdSOAC pat aux (Hist w arrs ops bucket_fun) = do - -- TODO: this is probably not very efficient in the vectorised case as we end - -- up with a dreadful update operator that involves arrays. + -- TODO: this is probably not very efficient in the vector case as we end up + -- with a dreadful update operator that involves arrays. (pat', to_transpose) <- soacResPat 0 0 pat ops' <- mapM fwdHist ops bucket_fun' <- fwdHistBucket bucket_fun diff --git a/src/Futhark/AD/Rev/Monad.hs b/src/Futhark/AD/Rev/Monad.hs index 0849d611c8..62d27895c6 100644 --- a/src/Futhark/AD/Rev/Monad.hs +++ b/src/Futhark/AD/Rev/Monad.hs @@ -118,9 +118,9 @@ data Sparse = Sparse -- | Element type of the array. sparseType :: PrimType, -- | Number of leading dimensions that are \"vector\" dimensions, due to - -- vectorised AD. These are not indexed by the sparse index, but are present - -- in the values. When zero, this is the ordinary non-vectorised case. This - -- is equivalent to the rank of `askShape`, but it is convenient to store it + -- vector AD. These are not indexed by the sparse index, but are present in + -- the values. When zero, this is the ordinary non-vector case. This is + -- equivalent to the rank of `askShape`, but it is convenient to store it -- here as well. sparseVecDims :: Int, -- | Locations and values of nonzero values. Indexes may be @@ -594,9 +594,9 @@ substLoopTape v v' = mapM_ (setLoopTape v') =<< lookupLoopTape v renameLoopTape :: Substitutions -> ADM () renameLoopTape = mapM_ (uncurry substLoopTape) . M.toList --- | Disable vectorised AD within the provided action. This results in a map --- that computes each adjoint explicitly, then assembles the resulting adjoint --- vectors. This is useful for constructs (such as scans) where vectorised AD is +-- | Disable vector AD within the provided action. This results in a map that +-- computes each adjoint explicitly, then assembles the resulting adjoint +-- vectors. This is useful for constructs (such as scans) where vector AD is -- impractical or inefficient. locallyNonvector :: (FreeIn e) => @@ -627,7 +627,7 @@ locallyNonvector e m = do AdjZero {} -> False _ -> True --- | If we are doing vectorised AD, then transpose the array to bring the vector +-- | If we are doing vector AD, then transpose the array to bring the vector -- shape outermost. -- -- That is, convers @[vec...][shape...][elem...]@ to @[shape...][vec...][elem...]@. From 4c0a48469d7734201ad89c8e47bc2f90a9c62738 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 12 Jun 2026 21:43:50 +0200 Subject: [PATCH 62/70] Fix markup. --- prelude/ad.fut | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/prelude/ad.fut b/prelude/ad.fut index 8810d5463a..2653beb45e 100644 --- a/prelude/ad.fut +++ b/prelude/ad.fut @@ -177,8 +177,7 @@ def mjp 'a 'b [n] (f: a -> b) (x: a) (y': [n]b) : [n]a = -- primal result of `with_vjp`, and some part is only used in `f'`. -- -- **Beware:** if `f` uses any free variables, these will not be taken into --- **account when computing the adjoint. Make these part of the argument --- **instead. +-- account when computing the adjoint. Make these part of the argument instead. def with_vjp 'a 'b (f: a -> b) (f': (res: b) -> (b_adj: b) -> a) (x: a) : b = intrinsics.with_vjp f f' x From eb74e141f73005fd70662458f0aaaaca7ba90b2f Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 12 Jun 2026 21:45:41 +0200 Subject: [PATCH 63/70] Better reference. --- prelude/ad.fut | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/prelude/ad.fut b/prelude/ad.fut index 2653beb45e..97ff6e3332 100644 --- a/prelude/ad.fut +++ b/prelude/ad.fut @@ -93,8 +93,9 @@ -- but it can still be substantial for programs with deep sequential -- loops. -- --- It varies on a case-by-case basis whether vector AD is faster or not. Vector --- AD essentially converts propagation of (co-)tangents from scalar to array +-- It varies on a case-by-case basis whether vector AD (`mjp`@term/`jmp`@term) +-- is faster than using `map` on top of `vjp`@term/`jvp`@term. Vector AD +-- essentially converts propagation of (co-)tangents from scalar to array -- operations, which can have a significant impact on memory accesses, depending -- on how the compiler manages to optimise the resulting code. It is hard to -- predict whether this offsets the reduction in primal work. If the vector size From 349e5bc627284121931b97ed062a288235413c64 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 12 Jun 2026 21:59:43 +0200 Subject: [PATCH 64/70] Improve documentation. --- prelude/ad.fut | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/prelude/ad.fut b/prelude/ad.fut index 97ff6e3332..1bc2c64cbe 100644 --- a/prelude/ad.fut +++ b/prelude/ad.fut @@ -14,10 +14,11 @@ -- -- Futhark's AD support includes the following: -- --- * Differentiation operators for forward-mode (`jvp`@term) and reverse-mode +-- * Differential operators for forward-mode (`jvp`@term) and reverse-mode -- (`vjp`@term). -- --- * Arbitrary control flow in differentiable code. +-- * Almost arbitrary control flow in differentiable code (some limitations +-- apply when using GPU backends, see below). -- -- * Higher order derivatives by nesting differentiation operators, including -- arbitrary mixing of forward- and reverse mode (although using multiple @@ -134,16 +135,16 @@ def vjp2 'a 'b (f: a -> b) (x: a) (y': b) : (b, a) = intrinsics.vjp2 f x y' -- | Jacobian-Matrix Product, returning also the primal result. As `jvp2`, but --- accepts a vector of seed values. Semantically equivalent to mapping, but may --- be more efficient. If used with `#[unroll]`, tangent calculations are --- unrolled when possible. +-- accepts an array of seed vectors (hence "matrix"). Semantically equivalent to +-- mapping, but may be more efficient. If used with `#[unroll]`, tangent +-- calculations are unrolled when possible. def jmp2 'a 'b [n] (f: a -> b) (x: a) (x': [n]a) : (b, [n]b) = intrinsics.jmp2 f x x' -- | Matrix-Jacobian Product, returning also the primal result. As `vjp2`, but --- accepts a vector of seed values. Semantically equivalent to mapping, but may --- be more efficient. If used with `#[unroll]`, adjoint calculations are --- unrolled when possible. +-- accepts an array of seed vectors (hence "matrix"). Semantically equivalent to +-- mapping, but may be more efficient. If used with `#[unroll]`, adjoint +-- calculations are unrolled when possible. def mjp2 'a 'b [n] (f: a -> b) (x: a) (y': [n]b) : (b, [n]a) = intrinsics.mjp2 f x y' From 788b443be1707b27f3cc6aaad94f087b6a5241d3 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 12 Jun 2026 22:02:01 +0200 Subject: [PATCH 65/70] Further elaboration. --- prelude/ad.fut | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/prelude/ad.fut b/prelude/ad.fut index 1bc2c64cbe..4c1abca3c7 100644 --- a/prelude/ad.fut +++ b/prelude/ad.fut @@ -94,16 +94,19 @@ -- but it can still be substantial for programs with deep sequential -- loops. -- --- It varies on a case-by-case basis whether vector AD (`mjp`@term/`jmp`@term) --- is faster than using `map` on top of `vjp`@term/`jvp`@term. Vector AD --- essentially converts propagation of (co-)tangents from scalar to array --- operations, which can have a significant impact on memory accesses, depending --- on how the compiler manages to optimise the resulting code. It is hard to --- predict whether this offsets the reduction in primal work. If the vector size --- is a constant, and the `#[unroll]` attribute is put on the AD operator, then --- the vectors become unrolled (turned into tuples, essentially), although this --- should only be done when the vector size is quite small, as the increase in --- code size is substantial. +-- When using vector AD (`mjp`@term/`jmp`@term), each scalar is associated with +-- a vector of tangents or cotangents, and the space overhead for storing these +-- is therefore multiplied with the vector size. However, in the case of `vjp`, +-- the intermediate results are only stored once. It varies on a case-by-case +-- basis whether vector AD is faster than using `map` on top of +-- `vjp`@term/`jvp`@term. Vector AD essentially converts propagation of +-- (co-)tangents from scalar to array operations, which can have a significant +-- impact on memory accesses, depending on how the compiler manages to optimise +-- the resulting code. It is hard to predict whether this offsets the reduction +-- in primal work. If the vector size is a constant, and the `#[unroll]` +-- attribute is put on the AD operator, then the vectors become unrolled (turned +-- into tuples, essentially), although this should only be done when the vector +-- size is quite small, as the increase in code size is substantial. -- -- ## Differentiable functions -- From 04f3bb9f8bc8643c27ba7af85285748e503db1b9 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 12 Jun 2026 22:05:22 +0200 Subject: [PATCH 66/70] Clarify. --- prelude/ad.fut | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/prelude/ad.fut b/prelude/ad.fut index 4c1abca3c7..6aace59f35 100644 --- a/prelude/ad.fut +++ b/prelude/ad.fut @@ -138,9 +138,9 @@ def vjp2 'a 'b (f: a -> b) (x: a) (y': b) : (b, a) = intrinsics.vjp2 f x y' -- | Jacobian-Matrix Product, returning also the primal result. As `jvp2`, but --- accepts an array of seed vectors (hence "matrix"). Semantically equivalent to --- mapping, but may be more efficient. If used with `#[unroll]`, tangent --- calculations are unrolled when possible. +-- accepts an array of seed vectors (hence "matrix", although transposed). +-- Semantically equivalent to mapping, but may be more efficient. If used with +-- `#[unroll]`, tangent calculations are unrolled when possible. def jmp2 'a 'b [n] (f: a -> b) (x: a) (x': [n]a) : (b, [n]b) = intrinsics.jmp2 f x x' From 94cb3ab1f19b6d5b7fe5bf25c7563ebc11994279 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Fri, 12 Jun 2026 22:08:26 +0200 Subject: [PATCH 67/70] More. --- prelude/ad.fut | 46 ++++++++++++++++++++++++++-------------------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/prelude/ad.fut b/prelude/ad.fut index 6aace59f35..6e5020992f 100644 --- a/prelude/ad.fut +++ b/prelude/ad.fut @@ -73,26 +73,32 @@ -- Both `jvp` and `vjp` work by transforming the program to carry -- along extra information associated with each scalar value. -- --- In the case of `jvp`, this extra information takes the form of an --- additional scalar representing the tangent, which is then --- propagated in each scalar computation using essentially the [chain --- rule](https://en.wikipedia.org/wiki/Chain_rule). Therefore, `jvp` --- has a memory overhead of approximately *2x*, and a computational --- overhead of slightly more, but usually less than *4x*. --- --- In the case of `vjp`, since our starting point is a *cotangent*, --- the function is essentially first run forward, then backwards (the --- *return sweep*) to propagate the cotangent. During the return --- sweep, all intermediate results computed during the forward sweep --- must still be available, and must therefore be stored in memory --- during the forward sweep. This means that the memory usage of `vjp` --- is proportional to the number of sequential steps of the original --- function (essentially turning *time* into *space*). The compiler --- does a nontrivial amount of optimisation to ameliorate this --- overhead (see [AD for an Array Language with Nested --- Parallelism](https://futhark-lang.org/publications/sc22-ad.pdf)), --- but it can still be substantial for programs with deep sequential --- loops. +-- In the case of `jvp` ("forward mode", or "tangent mode"), this extra +-- information takes the form of an additional scalar representing the tangent, +-- which is then propagated in each scalar computation using essentially the +-- [chain rule](https://en.wikipedia.org/wiki/Chain_rule). Therefore, `jvp` has +-- a memory overhead of approximately *2x*, and a computational overhead of +-- slightly more, but usually less than *4x*. +-- +-- In the case of `vjp` ("reverse mode" or "adjoint mode"), since our starting +-- point is a *cotangent*, the function is essentially first run forward, then +-- backwards (the *return sweep*) to propagate the cotangent. During the return +-- sweep, all intermediate results computed during the forward sweep must still +-- be available, and must therefore be stored in memory during the forward sweep +-- - this is called "the tape". This means that the memory usage of `vjp` is +-- proportional to the number of sequential steps of the original function +-- (essentially turning *time* into *space*). The compiler does a nontrivial +-- amount of optimisation to ameliorate this overhead (see [AD for an Array +-- Language with Nested +-- Parallelism](https://futhark-lang.org/publications/sc22-ad.pdf)), but it can +-- still be substantial for programs with deep sequential loops. +-- +-- Nesting `vjp`, understood as applying `vjp` to the result of `vjp`, is +-- usually a bad idea, as the code structure produced by `vjp` is fairly +-- complicated, due to the tape management. Passing the output of `jvp` to +-- `vjp`, or the other way, is however fine. As a rule of thumb, whenever you +-- stack multiple differential operators, make sure only one of them is `vjp` or +-- related ones. -- -- When using vector AD (`mjp`@term/`jmp`@term), each scalar is associated with -- a vector of tangents or cotangents, and the space overhead for storing these From 383471967d5d0bf405deafe69b8b1783d8bfda5d Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sat, 13 Jun 2026 09:55:10 +0200 Subject: [PATCH 68/70] Better naming. --- src/Futhark/AD/Fwd.hs | 9 +++------ src/Futhark/AD/Rev/Monad.hs | 8 +++----- src/Futhark/AD/Shared.hs | 15 +++++++++------ 3 files changed, 15 insertions(+), 17 deletions(-) diff --git a/src/Futhark/AD/Fwd.hs b/src/Futhark/AD/Fwd.hs index 38656cefa2..f374f12b76 100644 --- a/src/Futhark/AD/Fwd.hs +++ b/src/Futhark/AD/Fwd.hs @@ -405,12 +405,12 @@ fwdStreamLambda num_accs (Lambda params _ body) = do onArrParam p = do shape <- askShape (p', p_tan) <- bundleNew p - let perm = auxPerm shape $ paramType p_tan + let perm = vecPerm shape $ paramType p_tan pure (p', p_tan {paramDec = rearrangeType perm (paramType p_tan)}) -- Put the tangent shape back in the outermost position. trArrParamTan tan_shape p p_tan = do - let perm = rearrangeInverse $ auxPerm tan_shape $ paramType p_tan + let perm = rearrangeInverse $ vecPerm tan_shape $ paramType p_tan v <- letExp (baseName (paramName p_tan)) . BasicOp $ Rearrange (paramName p_tan) perm @@ -419,12 +419,9 @@ fwdStreamLambda num_accs (Lambda params _ body) = do -- Put the chunk size back in the outermost position. trMapResTan tan_shape (SubExpRes cs ~(Var v)) = do v_t <- lookupType v - let perm = auxPerm tan_shape v_t + let perm = vecPerm tan_shape v_t fmap varRes . certifying cs $ letExp (baseName v) . BasicOp $ Rearrange v perm -vecPerm :: Shape -> Type -> [Int] -vecPerm = auxPerm - pushTanShape :: VName -> ADM VName pushTanShape v = do tan_shape <- askShape diff --git a/src/Futhark/AD/Rev/Monad.hs b/src/Futhark/AD/Rev/Monad.hs index 62d27895c6..70a160afae 100644 --- a/src/Futhark/AD/Rev/Monad.hs +++ b/src/Futhark/AD/Rev/Monad.hs @@ -627,10 +627,7 @@ locallyNonvector e m = do AdjZero {} -> False _ -> True --- | If we are doing vector AD, then transpose the array to bring the vector --- shape outermost. --- --- That is, convers @[vec...][shape...][elem...]@ to @[shape...][vec...][elem...]@. +-- | If we are doing vector AD, apply 'vecPerm' to the array. vecToInner :: VName -> ADM VName vecToInner v = do adj_shape <- askShape @@ -638,7 +635,8 @@ vecToInner v = do then pure v else do v_t <- lookupType v - letExp (baseName v <> "_tr") $ BasicOp $ Rearrange v (auxPerm adj_shape v_t) + letExp (baseName v <> "_tr") . BasicOp . Rearrange v $ + vecPerm adj_shape v_t -- Note [Consumption] -- diff --git a/src/Futhark/AD/Shared.hs b/src/Futhark/AD/Shared.hs index dc0c6b96ed..9bcfa47904 100644 --- a/src/Futhark/AD/Shared.hs +++ b/src/Futhark/AD/Shared.hs @@ -1,6 +1,6 @@ -- | Various definitions used for both forward and reverse mode. module Futhark.AD.Shared - ( auxPerm, + ( vecPerm, asVName, mapNest, mkMap, @@ -12,11 +12,14 @@ import Data.Foldable import Futhark.Construct import Futhark.IR.SOACS -auxPerm :: Shape -> Type -> [Int] -auxPerm aux_shape t = - [shapeRank aux_shape] - ++ [0 .. shapeRank aux_shape - 1] - ++ [shapeRank aux_shape + 1 .. arrayRank t - 1] +-- | A permutation for transposing the vector shape past the next dimension. +-- +-- That is, converts @[vec...][d][elem...]@ to @[d][vec...][elem...]@. +vecPerm :: Shape -> Type -> [Int] +vecPerm vec_shape t = + [shapeRank vec_shape] + ++ [0 .. shapeRank vec_shape - 1] + ++ [shapeRank vec_shape + 1 .. arrayRank t - 1] asVName :: (MonadBuilder m) => SubExp -> m VName asVName (Var v) = pure v From f1df3dcdda2d0d81c238b93b92bcc870a83970b0 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Sun, 14 Jun 2026 10:59:05 +0200 Subject: [PATCH 69/70] More docs. --- prelude/ad.fut | 30 ++++++++++++++++-------------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/prelude/ad.fut b/prelude/ad.fut index 6e5020992f..004f657320 100644 --- a/prelude/ad.fut +++ b/prelude/ad.fut @@ -45,26 +45,28 @@ -- such as arrays, records, sums, and so on, simply by "flattening -- out" the values and considering only their constituent scalars. -- --- Computing the full Jacobian is usually costly and sometimes not --- necessary, and it is not part of the AD facility provided by --- Futhark. Instead it is possible to parts of the Jacobian. --- --- We can take the product of an an *m* by *n* Jacobian with an --- *n*-element *tangent vector* to produce an *m*-element vector --- (*Jacobian-vector product*). Such a product can be computed in a --- single (augmented) execution of the function *f*, and by choosing --- the tangent vector appropriately we can use this to compute the --- full Jacobian. This is provided by the function `jvp`. +-- Computing the full Jacobian is usually costly and sometimes not necessary, +-- and it is not part of the AD facility provided by Futhark. Instead it is +-- possible to compute parts of the Jacobian, which semantically (but not +-- operationally) can be seen as multiplying the Jacobian with a vector, +-- producing a vector. However, it is important to understand that the full +-- Jacobian is *not* constructed as an intermediate step. +-- +-- We can take the product of an an *m* by *n* Jacobian with an *n*-element +-- *tangent vector* to produce an *m*-element vector (*Jacobian-vector +-- product*). Such a product can be computed in a single (augmented) execution +-- of the function *f*. This is provided by the function `jvp`. -- -- We can also take the product of an *m*-element vector *cotangent -- vector* with the *m* by *n* Jacobian to produce an *n*-element -- vector (*Vector-Jacobian product*). This too can be computed in a -- single execution of *f*, with `vjp`. -- --- We can use the `jvp` function to produce a *column* of the full --- Jacobian, and `vjp` to produce a *row*. Which is superior for a --- given situation depends on whether the function has more inputs or --- outputs. +-- Using an elementary vector, we can use the `jvp` function to produce a +-- *column* of the full Jacobian, and `vjp` to produce a *row*, with the nonzero +-- element of the vector identifying which column or row is extracted. Which is +-- superior for a given situation depends on whether the function has more +-- inputs or outputs. -- -- We can freely nest `vjp` and `jvp` to compute higher-order derivatives. -- From 100de3def2bdbca030e6bdde7906b558779a13b7 Mon Sep 17 00:00:00 2001 From: Troels Henriksen Date: Mon, 15 Jun 2026 10:05:22 +0200 Subject: [PATCH 70/70] Minor fices. --- prelude/ad.fut | 42 +++++++++++++++++++++++------------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/prelude/ad.fut b/prelude/ad.fut index 004f657320..c20c93e6a5 100644 --- a/prelude/ad.fut +++ b/prelude/ad.fut @@ -33,17 +33,16 @@ -- -- ## Jacobians -- --- For a differentiable function *f* whose input comprise *n* scalars --- and whose output comprises *m* scalars, the --- [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant) --- for a given input point is an *m* by *n* matrix of scalars that --- each represent a [partial --- derivatives](https://en.wikipedia.org/wiki/Partial_derivative). --- Intuitively, position *(i,j)* of the Jacobian describes how --- sensitive output *i* is to input *j*. The notion of Jacobian --- generalises to functions that accept or produce compound structures --- such as arrays, records, sums, and so on, simply by "flattening --- out" the values and considering only their constituent scalars. +-- For a differentiable function *f* whose input comprise *n* scalars and whose +-- output comprises *m* scalars, the +-- [Jacobian](https://en.wikipedia.org/wiki/Jacobian_matrix_and_determinant) for +-- a given input point is an *m* by *n* matrix of scalars that each represent a +-- [partial derivative](https://en.wikipedia.org/wiki/Partial_derivative). +-- Intuitively, position *(i,j)* of the Jacobian describes how sensitive output +-- *i* is to input *j*. The notion of Jacobian generalises to functions that +-- accept or produce compound structures such as arrays, records, sums, and so +-- on, simply by "flattening out" the values and considering only their +-- constituent scalars. -- -- Computing the full Jacobian is usually costly and sometimes not necessary, -- and it is not part of the AD facility provided by Futhark. Instead it is @@ -52,21 +51,25 @@ -- producing a vector. However, it is important to understand that the full -- Jacobian is *not* constructed as an intermediate step. -- --- We can take the product of an an *m* by *n* Jacobian with an *n*-element +-- We can take the product of an *m* by *n* Jacobian with an *n*-element -- *tangent vector* to produce an *m*-element vector (*Jacobian-vector -- product*). Such a product can be computed in a single (augmented) execution -- of the function *f*. This is provided by the function `jvp`. -- -- We can also take the product of an *m*-element vector *cotangent -- vector* with the *m* by *n* Jacobian to produce an *n*-element --- vector (*Vector-Jacobian product*). This too can be computed in a +-- vector (*vector-Jacobian product*). This too can be computed in a -- single execution of *f*, with `vjp`. -- --- Using an elementary vector, we can use the `jvp` function to produce a --- *column* of the full Jacobian, and `vjp` to produce a *row*, with the nonzero --- element of the vector identifying which column or row is extracted. Which is --- superior for a given situation depends on whether the function has more --- inputs or outputs. +-- A tangent has the same structure as the input and represents a direction in +-- input space. A cotangent has the same structure as the output and represents +-- sensitivities flowing backwards through the computation. +-- +-- Using an elementary (co-)tangent vector, we can use the `jvp` function to +-- produce a *column* of the full Jacobian, and `vjp` to produce a *row*, with +-- the nonzero element of the vector identifying which column or row is +-- extracted. Which is superior for a given situation depends on whether the +-- function has more inputs or outputs. -- -- We can freely nest `vjp` and `jvp` to compute higher-order derivatives. -- @@ -122,7 +125,8 @@ -- type system does not distinguish differentiable from non-differentiable -- operations. As a rule of thumb, a function is differentiable if its results -- are computed using a composition of primitive floating-point operations, --- without ever converting to or from integers. +-- without ever converting to or from integers. Most functions will also have +-- discontinuities around values that influence control flow. -- -- Note that a function whose input or output is a sum type with more than one -- constructor is *not* differentiable (or at least the sum-typed part is not).