diff --git a/docs/conf.py b/docs/conf.py index 4d3b3f80d4..34b4cfdc1d 100755 --- a/docs/conf.py +++ b/docs/conf.py @@ -541,6 +541,13 @@ class FutharkLexer(RegexLexer): [], 1, ), + ( + "man/futhark-eval", + "futhark-eval", + "execute Futhark expression", + [], + 1, + ), ( "man/futhark-script", "futhark-script", diff --git a/docs/index.rst b/docs/index.rst index 4e3f6a70ec..7c3f9c4343 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -51,12 +51,12 @@ the `development blog `_. man/futhark-cuda.rst man/futhark-dataset.rst man/futhark-doc.rst + man/futhark-eval.rst man/futhark-fmt.rst man/futhark-hip.rst man/futhark-ispc.rst man/futhark-literate.rst man/futhark-lsp.rst - man/futhark-script.rst man/futhark-multicore.rst man/futhark-opencl.rst man/futhark-pkg.rst @@ -65,6 +65,7 @@ the `development blog `_. man/futhark-python.rst man/futhark-repl.rst man/futhark-run.rst + man/futhark-script.rst man/futhark-test.rst man/futhark-wasm-multicore.rst man/futhark-wasm.rst diff --git a/docs/man/futhark-eval.rst b/docs/man/futhark-eval.rst new file mode 100644 index 0000000000..9f1cd93cea --- /dev/null +++ b/docs/man/futhark-eval.rst @@ -0,0 +1,69 @@ +.. role:: ref(emphasis) + +.. _futhark-eval(1): + +================ +futhark-eval +================ + +SYNOPSIS +======== + +futhark eval [options...] expression + +DESCRIPTION +=========== + +This command executes Futhark expressions through the interpreter and prints +their results to stdout. You can provide a ``.fut`` file with definitions that +are then accessible in the expression. + +Further, if you pass ``--backend``, the file will be compiled with the given +compiler backend, and any references to entry points in the evaluation +expression will be handled by compiled code. This is the main purpose of +``futhark eval``: it allows convenient invocation of Futhark entry points +without having to recompile. + +The expression itself is not compiled, so it is best to put as much of the heavy +work as possible into the entry points provided in the ``.fut`` file. Also, +beware: while the expression can access any definition in the file, only the +ones explicitly declared as entry points will be compiled and run fast. All +other definitions will be executed in interpreted mode. + +OPTIONS +======= + +--backend=name + + The backend used when compiling Futhark programs (without leading ``futhark``, + e.g. just ``opencl``). Defaults to no backend, meaning purely interpreted. + +-f, --file=FILE + + Evaluate expressions in the context of this file. + +--futhark=program + + The program used to perform operations (eg. compilation). Defaults + to the binary running ``futhark eval`` itself. + +--pass-option=opt + + Pass an option to benchmark programs that are being run. For + example, we might want to run OpenCL programs on a specific device:: + + futhark eval prog.fut --backend=opencl --pass-option=-dHawaii + +--pass-compiler-option=opt + + Pass an extra option to the compiler when compiling the programs. + +--skip-compilation + + Do not run the compiler, and instead assume that the program has + already been compiled. Use with caution. + +SEE ALSO +======== + +:ref:`futhark-literate(1)`, :ref:`futhark-test(1)`, :ref:`futhark-bench(1)` diff --git a/docs/man/futhark.rst b/docs/man/futhark.rst index 37b14188e6..aedd17809e 100644 --- a/docs/man/futhark.rst +++ b/docs/man/futhark.rst @@ -73,12 +73,6 @@ A Futhark compiler development command, intentionally undocumented and intended for use in developing the Futhark compiler, not for programmers writing in Futhark. -futhark eval [-f FILE] [-w] --------------------------------------- - -Evaluates expressions given as command-line arguments. Optionally -allows a file import using ``-f``. - futhark hash PROGRAM -------------------- diff --git a/futhark.cabal b/futhark.cabal index 3fd3fb3e93..53717a8940 100644 --- a/futhark.cabal +++ b/futhark.cabal @@ -415,6 +415,17 @@ library Language.Futhark.Core Language.Futhark.Interpreter Language.Futhark.Interpreter.AD + Language.Futhark.Interpreter.FFI + Language.Futhark.Interpreter.FFI.Server + Language.Futhark.Interpreter.FFI.Server.Explorer + Language.Futhark.Interpreter.FFI.Server.Interface + Language.Futhark.Interpreter.FFI.Server.Packer + Language.Futhark.Interpreter.FFI.Server.TypeLayout + Language.Futhark.Interpreter.FFI.UIDs + Language.Futhark.Interpreter.FFI.Util.BiMap + Language.Futhark.Interpreter.FFI.Util.NDArray + Language.Futhark.Interpreter.FFI.Util.UID + Language.Futhark.Interpreter.FFI.Values Language.Futhark.Interpreter.Values Language.Futhark.FreeVars Language.Futhark.Parser @@ -477,7 +488,7 @@ library , file-embed >=0.0.14.0 , filepath >=1.4.1.1 , free >=5.1.10 - , futhark-data >= 1.1.3.0 + , futhark-data >= 1.1.4.0 , futhark-server >= 1.3.3.0 , futhark-manifest == 1.7.0.0 , githash >=0.1.6.1 diff --git a/nix/futhark-data.nix b/nix/futhark-data.nix index 3fd57a2757..121d07a740 100644 --- a/nix/futhark-data.nix +++ b/nix/futhark-data.nix @@ -5,8 +5,8 @@ }: mkDerivation { pname = "futhark-data"; - version = "1.1.3.0"; - sha256 = "a3a274bfa9f2bf9df30e8f5a23a9243726c1c82502bcaee53e2a347e0697c9be"; + version = "1.1.4.0"; + sha256 = "137733709a6e360e6f8f5b376b68bc13d026c4c9737b4efc56ed77a50e18f0fa"; libraryHaskellDepends = [ base binary bytestring bytestring-to-vector containers half megaparsec mtl scientific text vector vector-binary-instances diff --git a/src/Futhark/CLI/Eval.hs b/src/Futhark/CLI/Eval.hs index 7c63df920b..ccb723fdc4 100644 --- a/src/Futhark/CLI/Eval.hs +++ b/src/Futhark/CLI/Eval.hs @@ -7,8 +7,8 @@ import Control.Monad import Data.Map qualified as M import Data.Text qualified as T import Futhark.Eval - ( InterpreterConfig (..), - interpreterConfig, + ( EvalConfig (..), + evalConfig, newFutharkiState, runExpr, ) @@ -27,15 +27,14 @@ import System.IO (stderr) -- | Run @futhark eval@. main :: String -> [String] -> IO () -main = mainWithOptions interpreterConfig options "options... " run +main = mainWithOptions evalConfig options "options... " run where run [] _ = Nothing run exprs config = Just $ runExprs exprs config -runExprs :: [String] -> InterpreterConfig -> IO () +runExprs :: [String] -> EvalConfig -> IO () runExprs exprs cfg = do - let InterpreterConfig _ file = cfg - maybe_new_state <- newFutharkiState cfg file M.empty + maybe_new_state <- newFutharkiState cfg M.empty interpreter_state <- case maybe_new_state of Left reason -> do hPutDocLn stderr reason @@ -43,21 +42,54 @@ runExprs exprs cfg = do Right s -> pure s forM_ exprs $ \expr -> putDocLn =<< runExpr interpreter_state (T.pack expr) -options :: [FunOptDescr InterpreterConfig] +options :: [FunOptDescr EvalConfig] options = [ Option "f" ["file"] ( ReqArg - ( \entry -> Right $ \config -> - config {interpreterFile = Just entry} - ) + (\entry -> Right $ \config -> config {evalFile = Just entry}) "NAME" ) "The file to load before evaluating expressions.", Option "w" ["no-warnings"] - (NoArg $ Right $ \config -> config {interpreterPrintWarnings = False}) - "Do not print warnings." + (NoArg $ Right $ \config -> config {evalPrintWarnings = False}) + "Do not print warnings.", + Option + "p" + ["pass-option"] + ( ReqArg + ( \opt -> + Right $ \config -> + config {evalExtraOptions = opt : evalExtraOptions config} + ) + "OPT" + ) + "Pass this option to programs being run.", + Option + [] + ["pass-compiler-option"] + ( ReqArg + ( \opt -> + Right $ \config -> + config {evalCompilerOptions = opt : evalCompilerOptions config} + ) + "OPT" + ) + "Pass this option to the compiler.", + Option + "" + ["skip-compilation"] + (NoArg $ Right $ \config -> config {evalSkipCompilation = True}) + "Use already compiled server-mode program.", + Option + [] + ["backend"] + ( ReqArg + (\backend -> Right $ \config -> config {evalBackend = Just backend}) + "BACKEND" + ) + "The compiler backend used (defaults to interpreted)." ] diff --git a/src/Futhark/CLI/REPL.hs b/src/Futhark/CLI/REPL.hs index eb482a692b..8c2c493642 100644 --- a/src/Futhark/CLI/REPL.hs +++ b/src/Futhark/CLI/REPL.hs @@ -312,6 +312,8 @@ breakForReason s top _ = runInterpreter :: F I.ExtOp a -> FutharkiM (Either I.InterpreterError a) runInterpreter m = runF m (pure . Right) intOp where + intOp (I.ExtOpCall {}) = error "Unexpected ExtOpCall" + intOp (I.ExtOpRealize {}) = error "Unexpected ExtOpRealize" intOp (I.ExtOpError err) = pure $ Left err intOp (I.ExtOpTrace w v c) = do @@ -365,6 +367,8 @@ runInterpreter m = runF m (pure . Right) intOp runInterpreterNoBreak :: (MonadIO m) => F I.ExtOp a -> m (Either I.InterpreterError a) runInterpreterNoBreak m = runF m (pure . Right) intOp where + intOp (I.ExtOpCall {}) = error "Unexpected ExtOpCall" + intOp (I.ExtOpRealize {}) = error "Unexpected ExtOpRealize" intOp (I.ExtOpError err) = pure $ Left err intOp (I.ExtOpTrace w v c) = do liftIO $ putDocLn $ pretty w <> ":" <+> align (unAnnotate v) diff --git a/src/Futhark/CLI/Run.hs b/src/Futhark/CLI/Run.hs index d723ef55b4..ea475363a5 100644 --- a/src/Futhark/CLI/Run.hs +++ b/src/Futhark/CLI/Run.hs @@ -143,6 +143,8 @@ newFutharkiState cfg file = runExceptT $ do runInterpreter' :: (MonadIO m) => F I.ExtOp a -> m (Either I.InterpreterError a) runInterpreter' m = runF m (pure . Right) intOp where + intOp (I.ExtOpCall {}) = error "Unexpected ExtOpCall" + intOp (I.ExtOpRealize {}) = error "Unexpected ExtOpRealize" intOp (I.ExtOpError err) = pure $ Left err intOp (I.ExtOpTrace w v c) = do liftIO $ hPutDocLn stderr $ pretty w <> ":" <+> align (unAnnotate v) diff --git a/src/Futhark/Eval.hs b/src/Futhark/Eval.hs index 2e7a5686c8..a8067eda41 100644 --- a/src/Futhark/Eval.hs +++ b/src/Futhark/Eval.hs @@ -1,7 +1,7 @@ module Futhark.Eval - ( InterpreterConfig (..), + ( EvalConfig (..), runExpr, - interpreterConfig, + evalConfig, newFutharkiState, Evaluation (..), EvalRecordRef (), @@ -9,34 +9,47 @@ module Futhark.Eval ) where +import Control.Arrow (Arrow (second)) import Control.Exception (IOException, catch) -import Control.Monad (foldM, when, (<=<)) +import Control.Monad (foldM, unless, void, when, (<=<)) import Control.Monad.Except (ExceptT, runExceptT, throwError) import Control.Monad.Free.Church (F, runF) import Control.Monad.IO.Class (MonadIO, liftIO) import Control.Monad.Trans.Class (lift) import Control.Monad.Trans.Reader (ReaderT (runReaderT), ask) -import Data.IORef (IORef, modifyIORef') +import Data.Array qualified as A +import Data.Bifunctor (first) +import Data.IORef (IORef, modifyIORef', newIORef, readIORef, writeIORef) import Data.Map qualified as M -import Data.Maybe (maybeToList) +import Data.Maybe (isJust, maybeToList) import Data.Sequence (Seq, (|>)) import Data.Text qualified as T +import Data.Text.IO qualified as T import Futhark.Compiler (prettyWarnings, readProgramFilesExceptKnown) import Futhark.Compiler.Program (VFS, fileProg, fileScope) import Futhark.Error (externalErrorS, prettyCompilerError) import Futhark.FreshNames (VNameSource) +import Futhark.Test (FutharkExe (..), compileProgram) import Futhark.Util.Pretty (commasep, hPutDoc, hPutDocLn, hardline, putDocLn) import Language.Futhark.Interpreter qualified as I +import Language.Futhark.Interpreter.FFI qualified as S +import Language.Futhark.Interpreter.FFI.Server (FutharkServer) +import Language.Futhark.Interpreter.FFI.Server qualified as S +import Language.Futhark.Interpreter.FFI.Server.Packer qualified as SP +import Language.Futhark.Interpreter.FFI.Values (Location) +import Language.Futhark.Interpreter.Values qualified as IV import Language.Futhark.Parser (parseExp) import Language.Futhark.Parser.Monad (SyntaxError (SyntaxError)) import Language.Futhark.Pretty (toName) import Language.Futhark.Prop (typeOf) import Language.Futhark.Semantic qualified as T -import Language.Futhark.Syntax (nameToText, typeParamName) +import Language.Futhark.Syntax (DecBase (ValDec), ProgBase (progDecs), VName (VName), ValBindBase (..), nameToText, typeParamName) import Language.Futhark.TypeChecker qualified as T import Prettyprinter (Doc, align, pretty, unAnnotate, vcat, (<+>)) import Prettyprinter.Render.Terminal (AnsiStyle) -import System.Exit (ExitCode (ExitFailure), exitWith) +import System.Environment (getExecutablePath) +import System.Exit (ExitCode (ExitFailure), exitFailure, exitWith) +import System.FilePath (dropExtension, ()) import System.IO (stderr) -- | The class of monads that can perform expression evaluation. @@ -78,17 +91,57 @@ runEvalRecordRef :: runEvalRecordRef msgRef (EvalRecordRef action) = flip runReaderT msgRef $ runExceptT action -newtype InterpreterState = InterpreterState (VNameSource, T.Env, I.Ctx) +newtype InterpreterState = InterpreterState (VNameSource, T.Env, I.Ctx, Maybe FutharkServer) + +-- TODO: Should NOT be IORef. This is temporary, for testing +call :: IORef (Maybe FutharkServer) -> VName -> [I.Value] -> IO I.Value +call s (VName n _) p = do + let p' = map S.fromInterpreterValue p + (Just s') <- readIORef s + (r, s'') <- first S.toInterpreterValue <$> S.runFutharkServerM (SP.call (nameToText n) p') s' + writeIORef s $ Just s'' + pure r + +-- TODO: Should NOT be IORef. This is temporary, for testing +realize :: IORef (Maybe FutharkServer) -> Location -> IO I.Value +realize s l = do + (Just s') <- readIORef s + (r, s'') <- first S.toInterpreterValue <$> S.runFutharkServerM (SP.realize' l) s' + writeIORef s $ Just s'' + pure r + +secondM :: (Monad m) => (b -> m c) -> (a, b) -> m (a, c) +secondM f (x, y) = do + y' <- f y + pure (x, y') + +fullyRealize :: IORef (Maybe FutharkServer) -> I.Value -> IO I.Value +fullyRealize _ (IV.ValuePrim p) = pure $ IV.ValuePrim p +fullyRealize s (IV.ValueArray sh a) = do + let (l, u) = A.bounds a + idxs = A.range (l, u) + bs <- mapM (fullyRealize s . (a A.!)) idxs + pure $ IV.ValueArray sh $ A.listArray (l, u) bs +fullyRealize s (IV.ValueRecord m) = IV.ValueRecord . M.fromList <$> (mapM $ secondM $ fullyRealize s) (M.toList m) +fullyRealize _ (IV.ValueFun f) = pure $ IV.ValueFun f +fullyRealize s (IV.ValueAcc sh f a) = do + let (l, u) = A.bounds a + idxs = A.range (l, u) + bs <- mapM (fullyRealize s . (a A.!)) idxs + pure $ IV.ValueAcc sh f $ A.listArray (l, u) bs +fullyRealize s (IV.ValueSum sh n vs) = IV.ValueSum sh n <$> mapM (fullyRealize s) vs +fullyRealize _ (IV.ValueAD d v) = pure $ IV.ValueAD d v +fullyRealize s (IV.ValueExt l _) = realize s l -- | Run an expression in the given interpreter state. The expression is parsed, -- type checked, and then run. Returns a prettyprinted result. Must be run in a -- monad that supports aborting and traces. runExpr :: - (Evaluation m) => + (Evaluation m, MonadIO m) => InterpreterState -> T.Text -> m (Doc AnsiStyle) -runExpr (InterpreterState (src, env, ctx)) str = do +runExpr (InterpreterState (src, env, ctx, s)) str = do uexp <- case parseExp "" str of Left (SyntaxError _ serr) -> abort $ pretty serr Right e -> pure e @@ -103,27 +156,63 @@ runExpr (InterpreterState (src, env, ctx)) str = do "The following types are ambiguous: " <> commasep (map (pretty . nameToText . toName . typeParamName) tparams) ] - pval <- runInterpreterNoBreak $ I.interpretExp ctx fexp + is <- liftIO $ newIORef s + pval <- runInterpreterNoBreak call realize is $ I.interpretExp ctx fexp case pval of Left err -> do abort $ I.prettyInterpreterError err - Right val -> pure $ I.prettyValue val <> hardline + Right val -> do + val' <- liftIO $ fullyRealize is val + pure $ I.prettyValue val' <> hardline -data InterpreterConfig = InterpreterConfig - { interpreterPrintWarnings :: Bool, - interpreterFile :: Maybe String +data EvalConfig = EvalConfig + { evalPrintWarnings :: Bool, + evalFile :: Maybe String, + -- | If @Just@, compile the file using this backend. + evalBackend :: Maybe String, + evalSkipCompilation :: Bool, + evalExtraOptions :: [String], + evalCompilerOptions :: [String], + evalFuthark :: Maybe FilePath } -interpreterConfig :: InterpreterConfig -interpreterConfig = InterpreterConfig True Nothing +evalConfig :: EvalConfig +evalConfig = + EvalConfig + { evalPrintWarnings = True, + evalFile = Nothing, + evalBackend = Nothing, + evalSkipCompilation = False, + evalExtraOptions = [], + evalCompilerOptions = [], + evalFuthark = Nothing + } + +prepareServer :: EvalConfig -> FilePath -> String -> IO FutharkServer +prepareServer cfg file backend = do + futhark <- maybe getExecutablePath pure $ evalFuthark cfg + + unless (evalSkipCompilation cfg) $ do + let compile_options = "--server" : evalCompilerOptions cfg + + let onError err = do + T.hPutStrLn stderr err + exitFailure + + void $ + either onError pure <=< runExceptT $ + compileProgram compile_options (FutharkExe futhark) backend file + + let prog = "." dropExtension file + S.startServer prog newFutharkiState :: (MonadIO m, Evaluation m) => - InterpreterConfig -> - Maybe FilePath -> + EvalConfig -> VFS -> m (Either (Doc AnsiStyle) InterpreterState) -newFutharkiState cfg maybe_file vfs = runExceptT $ do +newFutharkiState cfg vfs = runExceptT $ do + let maybe_file = evalFile cfg (ws, imports, src) <- badOnLeft prettyCompilerError =<< liftIO @@ -131,39 +220,65 @@ newFutharkiState cfg maybe_file vfs = runExceptT $ do `catch` \(err :: IOException) -> pure (externalErrorS (show err)) ) - when (interpreterPrintWarnings cfg) $ - liftIO $ - hPutDoc stderr $ - prettyWarnings ws + when (evalPrintWarnings cfg) $ + liftIO . hPutDoc stderr $ + prettyWarnings ws + + let modifyLast _ [] = [] + modifyLast f [x] = [f x] + modifyLast f (x : xs) = x : modifyLast f xs + + (imports', s) <- case (maybe_file, evalBackend cfg) of + (Just file, Just backend) -> liftIO $ do + let mdec (ValDec vb) + | isJust $ valBindEntryPoint vb = + ValDec $ vb {valBindAttrs = "$external" : valBindAttrs vb} + mdec dec = dec + (_, m) = last imports + m' = m {fileProg = (fileProg m) {progDecs = map mdec $ progDecs $ fileProg m}} + (modifyLast (second $ const m') imports,) . Just + <$> prepareServer cfg file backend + _ -> pure (imports, Nothing) + is <- liftIO $ newIORef s ictx <- let foldFile ctx = badOnLeft I.prettyInterpreterError - <=< runInterpreterNoBreak + <=< runInterpreterNoBreak call realize is . I.interpretImport ctx in foldM foldFile I.initialCtx $ - map (fmap fileProg) imports + map (fmap fileProg) imports' + s' <- liftIO $ readIORef is let (tenv, ienv) = - let (iname, fm) = last imports + let (iname, fm) = last imports' in ( fileScope fm, ictx {I.ctxEnv = I.ctxImports ictx M.! iname} ) - pure $ InterpreterState (src, tenv, ienv) + pure $ InterpreterState (src, tenv, ienv, s') where badOnLeft :: (Monad m) => (err -> err') -> Either err a -> ExceptT err' m a badOnLeft _ (Right x) = pure x badOnLeft p (Left err) = throwError $ p err runInterpreterNoBreak :: - (Evaluation m) => + (Evaluation m, MonadIO m) => + (IORef (Maybe FutharkServer) -> VName -> [I.Value] -> IO I.Value) -> + (IORef (Maybe FutharkServer) -> Location -> IO I.Value) -> + IORef (Maybe FutharkServer) -> F I.ExtOp a -> m (Either I.InterpreterError a) -runInterpreterNoBreak m = runF m (pure . Right) intOp +runInterpreterNoBreak call' realize' s m = runF m (pure . Right) intOp where intOp (I.ExtOpError err) = pure $ Left err intOp (I.ExtOpTrace w v c) = do trace $ pretty w <> ":" <+> align (unAnnotate v) c intOp (I.ExtOpBreak _ _ _ c) = c + intOp (I.ExtOpCall n p c) = do + r <- liftIO $ call' s n p + c r + intOp (I.ExtOpRealize l c) = do + r <- liftIO $ realize' s l + c r diff --git a/src/Futhark/LSP/CodeLens.hs b/src/Futhark/LSP/CodeLens.hs index cf40075d54..c902b859ec 100644 --- a/src/Futhark/LSP/CodeLens.hs +++ b/src/Futhark/LSP/CodeLens.hs @@ -26,7 +26,7 @@ import Data.Text (Text) import Data.Text qualified as T import Data.Text.Mixed.Rope qualified as R import Futhark.Compiler.Program (VFS) -import Futhark.Eval (Evaluation (abort), InterpreterConfig (InterpreterConfig), newFutharkiState, runEvalRecordRef, runExpr) +import Futhark.Eval (EvalConfig (..), Evaluation (abort), evalConfig, newFutharkiState, runEvalRecordRef, runExpr) import Futhark.LSP.CommandType qualified as CommandType import Futhark.LSP.Tool (Execute, transformVFS) import Futhark.Util (showText) @@ -286,17 +286,20 @@ executeEvalLens (EvalLensData docUri line) = do IORef (Seq (Doc AnsiStyle)) -> IO (Either (Doc AnsiStyle) (Doc AnsiStyle)) evaluationAction traceRef = interpret $ do - -- do not print warnings, no file - let interpreterConfig = InterpreterConfig False Nothing + -- do not print warnings let filePath = toNormalizedUri docUri & uriToNormalizedFilePath & fmap fromNormalizedFilePath + cfg = + evalConfig + { evalPrintWarnings = False, + evalFile = filePath + } -- load the file the expression is located in interpreterState <- - newFutharkiState interpreterConfig filePath currentVFS - >>= either abort pure + newFutharkiState cfg currentVFS >>= either abort pure liftIO setupLimits diff --git a/src/Language/Futhark/Interpreter.hs b/src/Language/Futhark/Interpreter.hs index bd4dafee5d..66530239c0 100644 --- a/src/Language/Futhark/Interpreter.hs +++ b/src/Language/Futhark/Interpreter.hs @@ -59,6 +59,7 @@ import Futhark.Util.Pretty hiding (apply) import Language.Futhark hiding (Shape, matchDims) import Language.Futhark qualified as F import Language.Futhark.Interpreter.AD qualified as AD +import Language.Futhark.Interpreter.FFI.Values (Location, indexLocation, projectLocation) import Language.Futhark.Interpreter.Values hiding (Value) import Language.Futhark.Interpreter.Values qualified import Language.Futhark.Primitive (floatValue, intValue) @@ -86,11 +87,15 @@ data ExtOp a = ExtOpTrace T.Text (Doc ()) a | ExtOpBreak Loc BreakReason (NE.NonEmpty StackFrame) a | ExtOpError InterpreterError + | ExtOpCall VName [Value] (Value -> a) + | ExtOpRealize Location (Value -> a) instance Functor ExtOp where fmap f (ExtOpTrace w s x) = ExtOpTrace w s $ f x fmap f (ExtOpBreak w why backtrace x) = ExtOpBreak w why backtrace $ f x fmap _ (ExtOpError err) = ExtOpError err + fmap f (ExtOpCall n p c) = ExtOpCall n p $ f . c + fmap f (ExtOpRealize l c) = ExtOpRealize l $ f . c type Stack = [StackFrame] @@ -135,6 +140,11 @@ stacktrace = asks $ map stackFrameLoc . fst adDepth :: EvalM AD.Depth adDepth = AD.Depth . length <$> stacktrace +realize :: Value -> EvalM Value +realize (ValueExt _ (Just v)) = pure v +realize (ValueExt l Nothing) = liftF $ ExtOpRealize l id +realize v = pure v + lookupImport :: ImportName -> EvalM (Maybe Env) lookupImport f = asks $ M.lookup f . snd @@ -467,6 +477,8 @@ fromArray v = error $ "Expected array value, but found: " <> show v project :: Name -> Value -> Value project f (ValueRecord fs) | Just v' <- M.lookup f fs = v' +project f (ValueExt _ (Just v)) = project f v +project f (ValueExt l Nothing) = ValueExt (projectLocation (nameToText f) l) Nothing project _ _ = error "Value does not have expected field." apply :: SrcLoc -> Env -> Value -> Value -> EvalM Value @@ -597,6 +609,12 @@ indexArray (IndexingFix i : is) (ValueArray _ arr) indexArray (IndexingSlice start end stride : is) (ValueArray (ShapeDim _ rowshape) arr) = do js <- indexesFor start end stride $ arrayLength arr toArray' (indexShape is rowshape) <$> mapM (indexArray is . (arr !)) js +indexArray is (ValueExt l Nothing) = Just $ ValueExt (indexLocation (map fromIntegral $ convertIs is) l) Nothing + where + convertIs (IndexingFix i : is') = i : convertIs is' + convertIs [] = [] + convertIs _ = error "TODO (89r12quiowdjl)" +indexArray is (ValueExt _ (Just v)) = indexArray is v indexArray _ v = Just v writeArray :: [Indexing] -> Value -> Value -> Maybe Value @@ -1233,7 +1251,26 @@ evalModExp env (ModApply f e (Info psubst) (Info rsubst) _) = do pure (f_env <> e_env <> res_env <> env_substs, res_mod) _ -> error "Expected ModuleFun." +extFun :: VName -> Int -> [Value] -> EvalM Value +extFun n i _ | i < 1 = liftF $ ExtOpCall n [] id -- Special case: Functions with 0 parameters - i.e. values +extFun n i vs | i == 1 = pure . ValueFun $ \v -> liftF $ ExtOpCall n (reverse $ v : vs) id +extFun n i vs = pure . ValueFun $ \v -> extFun n (i - 1) (v : vs) + evalDec :: Env -> Dec -> EvalM Env +evalDec env (ValDec vb@(ValBind (Just _) (VName vn vi) _ _ (Info ret) tparams ps fbody _ _ _)) + | "$external" `elem` valBindAttrs vb = localExts $ do + let n = VName (nameFromText $ nameToText vn) vi + binding <- evalValBinding env tparams ps ret fbody + case binding of + (TermValue (Just t) _) -> do + sizes <- extEnv + f <- extFun n (length ps) [] + pure $ mempty {envTerm = M.singleton n $ TermValue (Just t) f} <> sizes + (TermPoly (Just t) _) -> do + sizes <- extEnv + f <- extFun n (length ps) [] + pure $ mempty {envTerm = M.singleton n $ TermValue (Just t) f} <> sizes + _ -> error "TODO: Impossible? (e2huqidjnk)" evalDec env (ValDec (ValBind _ v _ _ (Info ret) tparams ps fbody _ _ _)) = localExts $ do binding <- evalValBinding env tparams ps ret fbody sizes <- extEnv @@ -1590,7 +1627,9 @@ initialCtx = bopDef fs = fun2 $ \x y -> do i <- getCounter - case (x, y) of + x''' <- realize x + y''' <- realize y + case (x''', y''') of (ValuePrim x', ValuePrim y') | Just z <- msum $ map (`bopDef'` (x', y')) fs -> do breakOnNaN [x', y'] z diff --git a/src/Language/Futhark/Interpreter/FFI.hs b/src/Language/Futhark/Interpreter/FFI.hs new file mode 100644 index 0000000000..90406f81a4 --- /dev/null +++ b/src/Language/Futhark/Interpreter/FFI.hs @@ -0,0 +1,81 @@ +module Language.Futhark.Interpreter.FFI + ( ExTypeAtom, + ExType, + ExValueAtom, + ExValue, + Function (..), + Interface (..), + InFunction, + InInterface, + ExFunction, + ExInterface, + fromInterpreterValue, + toInterpreterValue, + ) +where + +import Data.Array qualified as A +import Data.Map qualified as M +import Language.Futhark.Core (Name, nameFromText, nameToText) +import Language.Futhark.Interpreter.FFI.UIDs +import Language.Futhark.Interpreter.FFI.Util.NDArray qualified as ND +import Language.Futhark.Interpreter.FFI.Values +import Language.Futhark.Interpreter.Values qualified as I +import Language.Futhark.Interpreter.Values qualified as S + +data Function a + = Function [Type a] (Type a) + deriving (Show, Eq) + +newtype Interface a + = Interface (M.Map Name (Function a)) + deriving (Show) + +type InFunction = Function PrimitiveType + +type InInterface = Interface PrimitiveType + +type ExTypeAtom = Either TypeUID PrimitiveType + +type ExValueAtom = Either Location PrimitiveValue + +type ExType = Type ExTypeAtom + +type ExValue = Value ExValueAtom + +type ExFunction = Function ExTypeAtom + +type ExInterface = Interface ExTypeAtom + +fromInterpreterValue :: I.Value m -> ExValue +fromInterpreterValue (I.ValuePrim v) = Atom $ Right $ fromPrimValue v +fromInterpreterValue iv@(I.ValueArray _ _) = Array $ fmap fromInterpreterValue $ ND.fromList (dims iv) $ flatten iv + where + flatten :: I.Value m -> [I.Value m] + flatten (I.ValueArray _ a) = concatMap flatten $ A.elems a + flatten v = [v] + + dims :: I.Value m -> [Int] + dims (I.ValueArray _ a) = let (l, u) = A.bounds a in u - l + 1 : dims (a A.! 0) + dims _ = [] +fromInterpreterValue (I.ValueRecord m) = Record $ M.map fromInterpreterValue $ M.mapKeys nameToText m +fromInterpreterValue (I.ValueSum _ n v) = Sum (nameToText n) $ map fromInterpreterValue v +fromInterpreterValue (I.ValueExt l _) = Atom $ Left l +fromInterpreterValue _ = error "TODO (qu9wdaoijlm)" + +toInterpreterValue :: ExValue -> I.Value m +toInterpreterValue (Atom (Right v)) = I.ValuePrim $ toPrimValue v +-- TODO: Add shape +toInterpreterValue (Array nd) = unflatten [] + where + unflatten :: [Int] -> I.Value m + unflatten idx = + if length idx == ND.rank nd + then toInterpreterValue $ nd ND.! reverse idx + else + let ni = ND.shape nd !! length idx + in I.ValueArray S.ShapeLeaf $ A.listArray (0, ni - 1) $ map (unflatten . (: idx)) [0 .. ni - 1] +toInterpreterValue (Record m) = I.ValueRecord $ M.map toInterpreterValue $ M.mapKeys nameFromText m +-- TODO: Add shape +toInterpreterValue (Sum n v) = I.ValueSum (I.ShapeSum M.empty) (nameFromText n) $ map toInterpreterValue v +toInterpreterValue (Atom (Left vid)) = I.ValueExt vid Nothing diff --git a/src/Language/Futhark/Interpreter/FFI/Server.hs b/src/Language/Futhark/Interpreter/FFI/Server.hs new file mode 100644 index 0000000000..56ffb686ed --- /dev/null +++ b/src/Language/Futhark/Interpreter/FFI/Server.hs @@ -0,0 +1,113 @@ +module Language.Futhark.Interpreter.FFI.Server + ( FutharkServer (..), + startServer, + FutharkServerM, + server, + interface, + getValueUID, + getValueUIDs, + runFutharkServerM, + typeUIDOf, + typeLayoutOf, + getChild, + putChild, + ) +where + +import Control.Arrow (Arrow (second)) +import Control.Monad.RWS (MonadState (get, put), gets) +import Control.Monad.Reader (MonadIO, MonadReader, MonadTrans (lift), ReaderT (runReaderT), asks) +import Control.Monad.State (StateT (runStateT)) +import Data.Map qualified as M +import Futhark.Server qualified as S +import Futhark.Util (isEnvVarAtLeast) +import Language.Futhark.Interpreter.FFI.Server.Explorer (exploreProgram) +import Language.Futhark.Interpreter.FFI.Server.Interface (ServerInterface (..)) +import Language.Futhark.Interpreter.FFI.Server.TypeLayout (TypeLayout (..)) +import Language.Futhark.Interpreter.FFI.UIDs (TypeUID, UIDSource, UIDSourceT, ValueUID, getUID, getUIDs, runUIDSourceT) +import Language.Futhark.Interpreter.FFI.Values (Direction) +import Prelude hiding (init) + +-- Server and function calling +data FutharkServer = FutharkServer + { fsInfo :: FutharkServerInfo, + fsUIDSource :: UIDSource, + fsState :: FutharkServerState + } + +data FutharkServerInfo = FutharkServerInfo + { fsiServer :: S.Server, + fsiInterface :: ServerInterface + } + +newtype FutharkServerState = FutharkServerState + { fssValues :: M.Map ValueUID (TypeUID, M.Map Direction ValueUID) + } + +init :: S.Server -> IO FutharkServer +init s = do + info <- FutharkServerInfo s <$> exploreProgram s + pure $ FutharkServer info mempty $ FutharkServerState mempty + +futharkServerCfg :: FilePath -> [String] -> S.ServerCfg +futharkServerCfg prog opts = + (S.newServerCfg prog opts) + { S.cfgDebug = isEnvVarAtLeast "FUTHARK_COMPILER_DEBUGGING" 1 + } + +startServer :: FilePath -> IO FutharkServer +startServer prog = S.startServer (futharkServerCfg prog []) >>= init + +newtype FutharkServerM a = FutharkServerM (ReaderT FutharkServerInfo (StateT FutharkServerState (UIDSourceT IO)) a) + deriving (Functor, Applicative, Monad, MonadIO, MonadReader FutharkServerInfo, MonadState FutharkServerState) + +runFutharkServerM :: FutharkServerM a -> FutharkServer -> IO (a, FutharkServer) +runFutharkServerM (FutharkServerM m) s = do + ((o, state'), src) <- runUIDSourceT (runStateT (runReaderT m $ fsInfo s) $ fsState s) $ fsUIDSource s + pure + ( o, + s + { fsUIDSource = src, + fsState = state' + } + ) + +server :: FutharkServerM S.Server +server = asks fsiServer + +typeUIDOf :: ValueUID -> FutharkServerM (Maybe TypeUID) +typeUIDOf vid = fmap fst . M.lookup vid <$> gets fssValues + +typeLayout :: TypeUID -> FutharkServerM (Maybe TypeLayout) +typeLayout t = M.lookup t . siTypeLayout <$> interface + +typeLayoutOf :: ValueUID -> FutharkServerM (Maybe TypeLayout) +typeLayoutOf vid = typeUIDOf vid >>= maybe (pure Nothing) typeLayout + +-- typeOf :: ValueUID -> FutharkServerM (Maybe (Type PrimitiveType)) +-- typeOf = typeLayoutOf vid >>= maybe (pure Nothing) toType +-- where +-- toType :: TypeLayout -> FutharkServerM (Type PrimitiveType) +-- toType (TLPrimitive t) = pure $ TAtom t +-- toType (TLArray t) = fmap TArray <$> typeOf t +-- toType _ = undefined + +getChild :: ValueUID -> Direction -> FutharkServerM (Maybe ValueUID) +getChild vid d = do + s <- gets fssValues + pure $ M.lookup vid s >>= M.lookup d . snd + +putChild :: ValueUID -> Direction -> ValueUID -> FutharkServerM () +putChild pvid d cvid = do + s <- get + let children = maybe mempty snd $ M.lookup pvid (fssValues s) + put s {fssValues = M.adjust (second $ const $ M.insert d cvid children) pvid $ fssValues s} + +interface :: FutharkServerM ServerInterface +interface = asks fsiInterface + +getValueUID :: FutharkServerM ValueUID +getValueUID = FutharkServerM $ lift . lift $ getUID + +getValueUIDs :: Word -> FutharkServerM [ValueUID] +getValueUIDs n = FutharkServerM $ lift . lift $ getUIDs n diff --git a/src/Language/Futhark/Interpreter/FFI/Server/Explorer.hs b/src/Language/Futhark/Interpreter/FFI/Server/Explorer.hs new file mode 100644 index 0000000000..8a8f77d507 --- /dev/null +++ b/src/Language/Futhark/Interpreter/FFI/Server/Explorer.hs @@ -0,0 +1,124 @@ +module Language.Futhark.Interpreter.FFI.Server.Explorer + ( exploreProgram, + ) +where + +import Control.Monad (forM) +import Control.Monad.State (MonadIO (liftIO), MonadState, StateT (runStateT), gets, modify) +import Data.Map qualified as M +import Futhark.Server qualified as S +import Language.Futhark.Interpreter.FFI.Server.Interface (Entry (..), ServerInterface (..)) +import Language.Futhark.Interpreter.FFI.Server.TypeLayout (TypeLayout (..)) +import Language.Futhark.Interpreter.FFI.UIDs +import Language.Futhark.Interpreter.FFI.Util.BiMap qualified as BM +import Language.Futhark.Interpreter.FFI.Values (PrimitiveType (..)) + +-- The explorer monad +newtype ServerExplorer a = ServerExplorer (UIDSourceT (StateT ServerInterface IO) a) + deriving (Functor, Applicative, Monad, MonadIO, MonadState ServerInterface) + +runServerExplorer :: ServerExplorer a -> UIDSource -> IO (ServerInterface, UIDSource) +runServerExplorer (ServerExplorer m) s = do + ((_, s'), o) <- runStateT (runUIDSourceT m s) mempty + pure (o, s') + +-- Utility functions +lookupTypeName :: S.TypeName -> ServerExplorer (Maybe TypeUID) +lookupTypeName n = ServerExplorer $ gets $ BM.lookupRight n . siType + +putEntryPoint :: S.EntryName -> [TypeUID] -> [TypeUID] -> ServerExplorer EntryUID +putEntryPoint n i o = ServerExplorer $ do + eid <- getUID + modify + ( \s -> + s + { siEntryPoint = BM.insert n eid $ siEntryPoint s, + siEntryPointInfo = M.insert eid (Entry i o) $ siEntryPointInfo s + } + ) + pure eid + +putType :: S.TypeName -> TypeLayout -> ServerExplorer TypeUID +putType n l = ServerExplorer $ do + tid <- getUID + modify + ( \s -> + s + { siType = BM.insert n tid $ siType s, + siTypeLayout = M.insert tid l $ siTypeLayout s + } + ) + pure tid + +-- Exploration logic +exploreType :: S.Server -> S.TypeName -> ServerExplorer TypeUID +exploreType s n = do + tid <- lookupTypeName n + case tid of + Just tid' -> pure tid' + Nothing -> do + k <- liftIO $ S.cmdKind s n + case k of + Right S.Primitive -> handlePrimitive + Right S.Array -> handleArray + Right S.Record -> handleRecord + Right S.Sum -> handleSum + Right S.Opaque -> handleOpaque + Left _ -> error "TODO (0u2qeiowjdkslm)" + where + handlePrimitive = putType n $ + TLPrimitive $ + case n of + "i8" -> TInt8 + "i16" -> TInt16 + "i32" -> TInt32 + "i64" -> TInt64 + "u8" -> TUInt8 + "u16" -> TUInt16 + "u32" -> TUInt32 + "u64" -> TUInt64 + "f16" -> TFloat16 + "f32" -> TFloat32 + "f64" -> TFloat64 + "bool" -> TBool + _ -> error "TODO (89urijqowdklmacs)" + handleArray = do + e <- liftIO $ S.cmdElemtype s n + case e of + Right e' -> exploreType s e' >>= putType n . TLArray + _ -> error "TODO (u890wqfioajscklm)" + handleRecord = do + fs <- liftIO $ S.cmdFields s n + case fs of + Right fs' -> + forM fs' (\f -> (S.fieldName f,) <$> exploreType s (S.fieldType f)) + >>= putType n . TLRecord + Left _ -> error "TODO (aq0iwpoak)" + handleSum = do + vs <- liftIO $ S.cmdVariants s n + case vs of + Right vs' -> + forM vs' (\v -> M.singleton (S.variantName v) <$> mapM (exploreType s) (S.variantTypes v)) + >>= putType n . TLSum . M.unions + Left _ -> error "TODO (r928quwfijoasckl)" + handleOpaque = putType n TLOpaque + +exploreEntryPoint :: S.Server -> S.EntryName -> ServerExplorer EntryUID +exploreEntryPoint s n = do + is <- liftIO $ S.cmdInputs s n + os <- liftIO $ S.cmdOutputs s n + case (is, os) of + (Right is', Right os') -> do + is'' <- forM is' $ exploreType s . S.inputType + os'' <- forM os' $ exploreType s . S.outputType + putEntryPoint n is'' os'' + _ -> error "TODO (98urqoijwdlansc)" + +exploreProgram :: S.Server -> IO ServerInterface +exploreProgram s = fst <$> runServerExplorer exploreProgram' mempty + where + exploreProgram' = do + es <- liftIO $ S.cmdEntryPoints s + case es of + Right es' -> mapM_ (exploreEntryPoint s) es' + _ -> error "TODO (8u2eiojqdlkm)" diff --git a/src/Language/Futhark/Interpreter/FFI/Server/Interface.hs b/src/Language/Futhark/Interpreter/FFI/Server/Interface.hs new file mode 100644 index 0000000000..ab32f374cd --- /dev/null +++ b/src/Language/Futhark/Interpreter/FFI/Server/Interface.hs @@ -0,0 +1,35 @@ +module Language.Futhark.Interpreter.FFI.Server.Interface + ( Entry (..), + ServerInterface (..), + ) +where + +import Data.Map qualified as M +import Futhark.Server qualified as S +import Language.Futhark.Interpreter.FFI.Server.TypeLayout (TypeLayout) +import Language.Futhark.Interpreter.FFI.UIDs (EntryUID, TypeUID) +import Language.Futhark.Interpreter.FFI.Util.BiMap qualified as BM + +data Entry = Entry [TypeUID] [TypeUID] + deriving (Eq, Ord, Show) + +data ServerInterface = ServerInterface + { siEntryPoint :: BM.BiMap S.EntryName EntryUID, + siEntryPointInfo :: M.Map EntryUID Entry, + siType :: BM.BiMap S.TypeName TypeUID, + siTypeLayout :: M.Map TypeUID TypeLayout + } + deriving (Show) + +instance Monoid ServerInterface where + mempty = ServerInterface mempty mempty mempty mempty + +instance Semigroup ServerInterface where + (<>) + (ServerInterface en1 e1 tn1 t1) + (ServerInterface en2 e2 tn2 t2) = + ServerInterface + (en1 <> en2) + (e1 <> e2) + (tn1 <> tn2) + (t1 <> t2) diff --git a/src/Language/Futhark/Interpreter/FFI/Server/Packer.hs b/src/Language/Futhark/Interpreter/FFI/Server/Packer.hs new file mode 100644 index 0000000000..d2532e01af --- /dev/null +++ b/src/Language/Futhark/Interpreter/FFI/Server/Packer.hs @@ -0,0 +1,302 @@ +module Language.Futhark.Interpreter.FFI.Server.Packer + ( call, + realize, + realize', + ) +where + +import Control.Arrow (Arrow (second)) +import Control.Monad (forM, forM_, replicateM, void, zipWithM, zipWithM_) +import Control.Monad.Reader (MonadIO (liftIO), MonadReader (ask), ReaderT (runReaderT)) +import Control.Monad.State (MonadTrans (lift), StateT (runStateT), gets, modify) +import Data.Binary qualified as B +import Data.Binary.Get qualified as B +import Data.ByteString.Lazy qualified as BL +import Data.Functor.Identity (Identity (runIdentity)) +import Data.Map qualified as M +import Data.Maybe (fromJust, fromMaybe) +import Data.Text qualified as T +import Futhark.Server qualified as S +import Futhark.Test.Values qualified as V +import GHC.IO.Handle (hClose) +import Language.Futhark.Interpreter.FFI (ExValue, ExValueAtom) +import Language.Futhark.Interpreter.FFI.Server (FutharkServerM) +import Language.Futhark.Interpreter.FFI.Server qualified as FS +import Language.Futhark.Interpreter.FFI.Server.Interface (Entry (Entry), ServerInterface (..)) +import Language.Futhark.Interpreter.FFI.Server.TypeLayout (TypeLayout (..)) +import Language.Futhark.Interpreter.FFI.UIDs +import Language.Futhark.Interpreter.FFI.Util.BiMap qualified as BM +import Language.Futhark.Interpreter.FFI.Util.NDArray qualified as ND +import Language.Futhark.Interpreter.FFI.Values +import System.IO.Temp (withSystemTempFile) +import Prelude hiding (init) + +varName :: ValueUID -> S.VarName +varName v = "v" <> T.show (uid v) + +getType :: ValueUID -> FutharkServerM TypeUID +getType v = do + s <- FS.server + t <- either (error "TODO (rqwy8dauisoj)") id <$> liftIO (S.cmdType s $ varName v) + si <- FS.interface + case BM.lookupRight t $ siType si of + Just tid -> pure tid + Nothing -> error "TODO (ru938wojisdlcmkzx)" + +realize' :: Location -> FutharkServerM ExValue +realize' (Location vid ds) = do + o <- realize $ Location vid $ reverse ds + s <- FS.server + _ <- liftIO $ S.cmdType s (varName o) + t <- getType o + si <- FS.interface + (ovs, oids) <- runPackerT (pack fEx t $ Atom o) si + ovs''' <- mapM (mapM (pure . ooga2)) [ovs] + head <$> unload si oids (zip [t] ovs''') + +-- | Fully unpacks a single primitive value +realize :: Location -> FutharkServerM ValueUID +realize (Location vid []) = pure vid +realize (Location vid (d : ds)) = do + c <- FS.getChild vid d + cvid <- maybe (unpack' d) pure c + realize $ Location cvid ds + where + unpack' :: Direction -> FutharkServerM ValueUID + unpack' (Index is) = do + s <- FS.server + vid' <- FS.getValueUID + void $ liftIO $ S.cmdIndex s (varName vid') (varName vid) is + FS.putChild vid d vid' + pure vid' + unpack' (Field f) = do + s <- FS.server + vid' <- FS.getValueUID + void $ liftIO $ S.cmdProject s (varName vid') (varName vid) f + FS.putChild vid d vid' + pure vid' + unpack' (VariantValue v i) = do + s <- FS.server + t <- FS.typeLayoutOf vid + let ts = case t of + Just (TLSum m) -> m M.! v + _ -> error "TODO (98uroijwdl)" + vid's <- FS.getValueUIDs $ fromIntegral $ length ts + void $ liftIO $ S.cmdDestruct s v (map varName vid's) + zipWithM_ (FS.putChild vid . VariantValue v) [0 ..] vid's + pure $ vid's !! i + +newtype PackerT v m a = PackerT (ReaderT ServerInterface (StateT [v] m) a) + deriving (Functor, Applicative, Monad, MonadIO) + +instance MonadTrans (PackerT v) where + lift = PackerT . lift . lift + +type PackerM v = PackerT v Identity + +runPackerT :: (Monad m) => PackerT v m a -> ServerInterface -> m (a, [v]) +runPackerT (PackerT m) i = second reverse <$> runStateT (runReaderT m i) mempty + +runPackerM :: PackerM v a -> ServerInterface -> (a, [v]) +runPackerM m = runIdentity . runPackerT m + +addValue :: (Monad m) => v -> PackerT v m Int +addValue v = PackerT $ do + modify (v :) + gets $ (+ (-1)) . length + +interface :: (Monad m) => PackerT v m ServerInterface +interface = PackerT ask + +pack :: (Monad m) => (TypeLayout -> a -> PackerT v m (Value b)) -> TypeUID -> Value a -> PackerT v m (Value b) +pack f tid v = do + i <- interface + case M.lookup tid $ siTypeLayout i of + Just l -> pack' l v + Nothing -> error "TODO (ru98qwojialskcm)" + where + pack' l (Atom a) = f l a + pack' (TLArray t) (Array a) = Array . ND.fromList (ND.shape a) <$> mapM (pack f t) (ND.elems a) + pack' (TLRecord fs) (Record m) = do + ms <- mapM (\(n, t) -> pack f t $ fromMaybe (error "TODO (r983uwiofhjklna,)") $ M.lookup n m) fs + let m' = M.fromList $ zip (map fst fs) ms + pure $ Record m' + pack' (TLSum m) (Sum svn svs) = do + let ts = fromMaybe (error "TODO (r893uqoijwdln)") $ M.lookup svn m + svs' <- zipWithM (pack f) ts svs + pure $ Sum svn svs' + pack' _ _ = error "TODO: (9r8uqowfijlas)" + +fIn :: TypeLayout -> ExValueAtom -> PackerM PrimitiveValue (Value (Either Location Int)) +fIn (TLPrimitive _) (Right p) = Atom . Right <$> addValue p +fIn TLOpaque (Left l) = pure $ Atom $ Left l +fIn _ _ = error "TODO (u8roqjiwlfa)" + +fEx :: TypeLayout -> ValueUID -> PackerT (ValueUID, PrimitiveType) FutharkServerM (Value (Either ValueUID Int)) +fEx (TLPrimitive t) l = Atom . Right <$> addValue (l, t) +fEx TLOpaque l = pure $ Atom $ Left l +fEx (TLArray t) vid = pure $ Atom $ Left vid -- do +-- s <- lift $ FS.server +-- shape <- either (error "TODO (98ueiwe)") id <$> liftIO (S.cmdShape s $ varName vid) +-- ids <- mapM (const <$> lift $ FS.getValueUID) [1..foldl (*) 1 shape] +-- let nd = ND.fromList shape ids +-- ND.mapMWithIndex_ (\i v -> liftIO $ S.cmdIndex s (varName v) (varName vid) i) nd +-- Array <$> mapM ((pack fEx t) . Atom) nd +fEx (TLRecord f) vid = do + s <- lift FS.server + vids <- forM f $ \(n, _) -> do + fvid <- lift FS.getValueUID + void $ liftIO $ S.cmdProject s (varName fvid) (varName vid) n + pure fvid + Record . M.fromList . zip (map fst f) <$> zipWithM (pack fEx) (map snd f) (map Atom vids) +fEx (TLSum m) vid = do + s <- lift FS.server + void $ liftIO $ S.cmdType s $ varName vid + vn <- either (error . ("TODO (uojdqlamk) " ++) . show) id <$> liftIO (S.cmdVariant s (varName vid)) + let ts = fromJust $ M.lookup vn m + vids <- forM ts $ const $ lift FS.getValueUID + void $ liftIO $ S.cmdDestruct s (varName vid) $ map varName vids + Sum vn <$> zipWithM (pack fEx) ts (map Atom vids) + +packAll :: (Monad m) => (TypeLayout -> a -> PackerT v m (Value b)) -> [(TypeUID, Value a)] -> PackerT v m [Value b] +packAll f vs = forM vs $ uncurry $ pack f + +load :: (S.Server, ServerInterface) -> [PrimitiveValue] -> [(TypeUID, Value (Either ValueUID Int))] -> FutharkServerM [ValueUID] +load (s, i) ps vs = do + vids <- replicateM (length ps) FS.getValueUID + liftIO $ withSystemTempFile "futhark-call-load" $ \tmpf tmpf_h -> do + forM_ ps $ BL.hPutStr tmpf_h . encodePrimitive + hClose tmpf_h + void $ S.cmdRestore s tmpf $ zip (map varName vids) (map (T.pack . primitiveTypeName . primitiveType) ps) + forM (map (\(tid, v) -> (tid, look tid, v)) vs) (load' vids) + where + encodePrimitive :: PrimitiveValue -> BL.ByteString + encodePrimitive (Int8 v) = B.encode $ fromJust $ V.putValue v + encodePrimitive (Int16 v) = B.encode $ fromJust $ V.putValue v + encodePrimitive (Int32 v) = B.encode $ fromJust $ V.putValue v + encodePrimitive (Int64 v) = B.encode $ fromJust $ V.putValue v + encodePrimitive (UInt8 v) = B.encode $ fromJust $ V.putValue v + encodePrimitive (UInt16 v) = B.encode $ fromJust $ V.putValue v + encodePrimitive (UInt32 v) = B.encode $ fromJust $ V.putValue v + encodePrimitive (UInt64 v) = B.encode $ fromJust $ V.putValue v + encodePrimitive (Float16 v) = B.encode $ fromJust $ V.putValue v + encodePrimitive (Float32 v) = B.encode $ fromJust $ V.putValue v + encodePrimitive (Float64 v) = B.encode $ fromJust $ V.putValue v + encodePrimitive (Bool v) = B.encode $ fromJust $ V.putValue v + + load' :: [ValueUID] -> (TypeUID, TypeLayout, Value (Either ValueUID Int)) -> FutharkServerM ValueUID + load' _ (_, _, Atom (Left vid)) = pure vid + load' vids (_, _, Atom (Right idx)) = pure $ vids !! idx + load' vids (tid, TLArray t, Array a) = do + values <- mapM (load' vids . (t,look t,)) $ ND.elems a + o <- FS.getValueUID + void $ liftIO $ S.cmdNewArray s (varName o) (fromJust $ BM.lookupLeft tid $ siType i) (ND.shape a) $ map varName values + pure o + load' vids (tid, TLRecord r, Record m) = do + k <- forM r $ load' vids . \(n, t) -> (t, look t, fromJust $ M.lookup n m) + o <- FS.getValueUID + void $ liftIO $ S.cmdNew s (varName o) (fromJust $ BM.lookupLeft tid $ siType i) $ map varName k + pure o + load' vids (tid, TLSum m, Sum vn vvs) = do + k <- forM (zip (fromJust $ M.lookup vn m) vvs) $ load' vids . \(t, v) -> (t, look t, v) + o <- FS.getValueUID + void $ liftIO $ S.cmdConstruct s (varName o) (fromJust $ BM.lookupLeft tid $ siType i) vn $ map varName k + pure o + load' _ _ = error "TODO (y8euiqdhjkanx)" + + look :: TypeUID -> TypeLayout + look tid = fromJust $ M.lookup tid $ siTypeLayout i + +unload :: ServerInterface -> [(ValueUID, PrimitiveType)] -> [(TypeUID, Value (Either Location Int))] -> FutharkServerM [ExValue] +unload i vs k = do + s <- FS.server + liftIO $ withSystemTempFile "futhark-call-unload" $ \tmpf tmpf_h -> do + hClose tmpf_h + void $ S.cmdStore s tmpf $ map (varName . fst) vs + bs <- BL.readFile tmpf + let vs' = case B.runGetOrFail (mapM (getPrimitive . snd) vs) bs of + Left v -> error $ "TODO (u89riqojkms) " ++ show v + Right (_, _, v) -> v + pure $ map (\(tid, v) -> unload' vs' (tid, look tid, v)) k + where + getPrimitive :: PrimitiveType -> B.Get PrimitiveValue + getPrimitive TInt8 = Int8 . fromJust . V.getValue <$> B.get + getPrimitive TInt16 = Int16 . fromJust . V.getValue <$> B.get + getPrimitive TInt32 = Int32 . fromJust . V.getValue <$> B.get + getPrimitive TInt64 = Int64 . fromJust . V.getValue <$> B.get + getPrimitive TUInt8 = UInt8 . fromJust . V.getValue <$> B.get + getPrimitive TUInt16 = UInt16 . fromJust . V.getValue <$> B.get + getPrimitive TUInt32 = UInt32 . fromJust . V.getValue <$> B.get + getPrimitive TUInt64 = UInt64 . fromJust . V.getValue <$> B.get + getPrimitive TFloat16 = Float16 . fromJust . V.getValue <$> B.get + getPrimitive TFloat32 = Float32 . fromJust . V.getValue <$> B.get + getPrimitive TFloat64 = Float64 . fromJust . V.getValue <$> B.get + getPrimitive TBool = Bool . fromJust . V.getValue <$> B.get + + look :: TypeUID -> TypeLayout + look tid = fromJust $ M.lookup tid $ siTypeLayout i + + unload' :: [PrimitiveValue] -> (TypeUID, TypeLayout, Value (Either Location Int)) -> ExValue + unload' pvs (_, TLPrimitive _, Atom (Right idx)) = Atom $ Right $ pvs !! idx + unload' _ (_, _, Atom (Left vid)) = Atom $ Left vid + unload' pvs (_, TLArray t, Array nd) = Array $ fmap (unload' pvs . (t,look t,)) nd + unload' pvs (_, TLRecord f, Record m) = + Record $ M.fromList $ zip (map fst f) $ map (\(n, t) -> unload' pvs (t, look t, fromJust $ M.lookup n m)) f + unload' pvs (_, TLSum m, Sum vn vvs) = + Sum vn $ zipWith (\t v -> unload' pvs (t, look t, v)) (fromJust $ M.lookup vn m) vvs + unload' _ _ = error "TODO (u8rqowijdalkcm)" + +ooga :: Either Location Int -> FutharkServerM (Either ValueUID Int) +ooga (Left l) = Left <$> realize l +ooga (Right i) = pure $ Right i + +ooga2 :: Either ValueUID Int -> Either Location Int +ooga2 (Left l) = Left $ Location l [] +ooga2 (Right i) = Right i + +call :: S.EntryName -> [ExValue] -> FutharkServerM ExValue +call n vs = do + s <- FS.server + si <- FS.interface + + -- Get entry info + (Entry i o) <- getEntryPointID n >>= getEntryPoint + + -- Send inputs + let (ivs, ps) = runPackerM (packAll fIn $ zip i vs) si + ivs''' <- mapM (mapM ooga) ivs + ivs' <- load (s, si) ps $ zip i ivs''' + + -- Call + o' <- replicateM (length o) FS.getValueUID + let o'' = map varName o' + void $ liftIO $ S.cmdCall s n o'' $ map varName ivs' + + -- Get outputs + (ovs, oids) <- runPackerT (packAll fEx $ zip o $ map Atom o') si + ovs''' <- mapM (mapM (pure . ooga2)) ovs + tuple' <$> unload si oids (zip o ovs''') + where + getEntryPointID n' = do + si <- FS.interface + case BM.lookupRight n' $ siEntryPoint si of + Just eid -> pure eid + Nothing -> error $ "Entry point \"" ++ T.unpack n' ++ "\" not found" + + getEntryPoint eid = do + si <- FS.interface + case M.lookup eid $ siEntryPointInfo si of + Just e -> pure e + Nothing -> error "Impossible (3urq8wfijoalskm)" -- TODO + tuple' :: [ExValue] -> ExValue + tuple' [v] = v + tuple' vs'' = toTuple vs'' + +-- TODO +-- realize :: Location -> FutharkServerM ExValue +-- realize vid = do +-- st <- FS.state +-- si <- FS.interface +-- let t = st M.! vid +-- (ov, oids) <- runPackerT (pack fEx t $ Atom vid) si +-- head <$> unload si oids [(t, ov)] diff --git a/src/Language/Futhark/Interpreter/FFI/Server/TypeLayout.hs b/src/Language/Futhark/Interpreter/FFI/Server/TypeLayout.hs new file mode 100644 index 0000000000..3e6c2751f2 --- /dev/null +++ b/src/Language/Futhark/Interpreter/FFI/Server/TypeLayout.hs @@ -0,0 +1,18 @@ +module Language.Futhark.Interpreter.FFI.Server.TypeLayout + ( TypeLayout (..), + ) +where + +import Data.Map qualified as M +import Futhark.Server qualified as S +import Language.Futhark.Interpreter.FFI.UIDs (TypeUID) +import Language.Futhark.Interpreter.FFI.Values (PrimitiveType) +import Prelude hiding (init) + +data TypeLayout + = TLPrimitive PrimitiveType + | TLArray TypeUID + | TLRecord [(S.FieldName, TypeUID)] + | TLSum (M.Map S.VariantName [TypeUID]) + | TLOpaque + deriving (Show, Eq, Ord) diff --git a/src/Language/Futhark/Interpreter/FFI/UIDs.hs b/src/Language/Futhark/Interpreter/FFI/UIDs.hs new file mode 100644 index 0000000000..fe360fc7c2 --- /dev/null +++ b/src/Language/Futhark/Interpreter/FFI/UIDs.hs @@ -0,0 +1,34 @@ +module Language.Futhark.Interpreter.FFI.UIDs + ( EntryUID, + TypeUID, + ValueUID, + UID.uid, + UIDSource, + UIDSourceT, + UID.runUIDSourceT, + UIDSourceM, + UID.runUIDSourceM, + UID.getUID, + UID.getUIDs, + ) +where + +import Language.Futhark.Interpreter.FFI.Util.UID qualified as UID + +data Entry + +data Type + +data Value + +type EntryUID = UID.UID Entry Word + +type TypeUID = UID.UID Type Word + +type ValueUID = UID.UID Value Word + +type UIDSource = UID.UIDSource Word + +type UIDSourceT = UID.UIDSourceT Word + +type UIDSourceM = UID.UIDSourceM Word diff --git a/src/Language/Futhark/Interpreter/FFI/Util/BiMap.hs b/src/Language/Futhark/Interpreter/FFI/Util/BiMap.hs new file mode 100644 index 0000000000..707ea718e6 --- /dev/null +++ b/src/Language/Futhark/Interpreter/FFI/Util/BiMap.hs @@ -0,0 +1,27 @@ +module Language.Futhark.Interpreter.FFI.Util.BiMap + ( BiMap, + insert, + lookupRight, + lookupLeft, + ) +where + +import Data.Map qualified as M + +data BiMap a b = BiMap (M.Map a b) (M.Map b a) + deriving (Eq, Ord, Show) + +instance (Ord a, Ord b) => Monoid (BiMap a b) where + mempty = BiMap mempty mempty + +instance (Ord a, Ord b) => Semigroup (BiMap a b) where + BiMap r1 l1 <> BiMap r2 l2 = BiMap (r1 <> r2) (l1 <> l2) + +insert :: (Ord a, Ord b) => a -> b -> BiMap a b -> BiMap a b +insert l r (BiMap mr ml) = BiMap (M.insert l r mr) (M.insert r l ml) + +lookupRight :: (Ord a) => a -> BiMap a b -> Maybe b +lookupRight l (BiMap mr _) = M.lookup l mr + +lookupLeft :: (Ord b) => b -> BiMap a b -> Maybe a +lookupLeft r (BiMap _ ml) = M.lookup r ml diff --git a/src/Language/Futhark/Interpreter/FFI/Util/NDArray.hs b/src/Language/Futhark/Interpreter/FFI/Util/NDArray.hs new file mode 100644 index 0000000000..a24f6f7204 --- /dev/null +++ b/src/Language/Futhark/Interpreter/FFI/Util/NDArray.hs @@ -0,0 +1,59 @@ +module Language.Futhark.Interpreter.FFI.Util.NDArray + ( NDArray, + fromList, + shape, + size, + rank, + (!), + elems, + flatten, + mapWithIndex, + mapMWithIndex, + mapMWithIndex_, + ) +where + +import Control.Monad (zipWithM, zipWithM_) +import Data.Array qualified as A + +data NDArray a = NDArray [Int] (A.Array Int a) + deriving (Eq, Ord, Show, Functor, Foldable, Traversable) + +fromList :: [Int] -> [a] -> NDArray a +fromList s l = NDArray s $ A.array (0, length l - 1) (zip [0 ..] l) + +shape :: NDArray a -> [Int] +shape (NDArray s _) = s + +size :: NDArray a -> Int +size = product . shape + +rank :: NDArray a -> Int +rank = length . shape + +(!) :: NDArray a -> [Int] -> a +(!) (NDArray s a) idx = + let i = sum $ zipWith (*) (reverse idx) $ scanl (*) 1 s + in a A.! i + +elems :: NDArray a -> [a] +elems (NDArray _ a) = A.elems a + +flatten :: NDArray a -> A.Array Int a +flatten (NDArray _ a) = a + +indexOf :: [Int] -> Int -> [Int] +indexOf (d : ds) i = (i `mod` d) : indexOf ds (i `div` d) +indexOf [] _ = [] + +mapWithIndex :: ([Int] -> a -> b) -> NDArray a -> NDArray b +mapWithIndex f nd = + fromList (shape nd) $ zipWith f (map (indexOf $ shape nd) [0 .. size nd]) $ elems nd + +mapMWithIndex :: (Monad m) => ([Int] -> a -> m b) -> NDArray a -> m (NDArray b) +mapMWithIndex f nd = + fromList (shape nd) <$> zipWithM f (map (indexOf $ shape nd) [0 .. size nd]) (elems nd) + +mapMWithIndex_ :: (Monad m) => ([Int] -> a -> m b) -> NDArray a -> m () +mapMWithIndex_ f nd = + zipWithM_ f (map (indexOf $ shape nd) [0 .. size nd]) (elems nd) diff --git a/src/Language/Futhark/Interpreter/FFI/Util/UID.hs b/src/Language/Futhark/Interpreter/FFI/Util/UID.hs new file mode 100644 index 0000000000..61746c323e --- /dev/null +++ b/src/Language/Futhark/Interpreter/FFI/Util/UID.hs @@ -0,0 +1,65 @@ +{-# LANGUAGE UndecidableInstances #-} + +module Language.Futhark.Interpreter.FFI.Util.UID + ( -- Unique IDs + UID (uid), + -- Unique ID source + UIDSource, + nextUID, + -- Unique ID source monad transformer + UIDSourceT, + runUIDSourceT, + UIDSourceM, + runUIDSourceM, + getUID, + getUIDs, + ) +where + +import Control.Monad.RWS (MonadReader (ask, local)) +import Control.Monad.State (MonadIO, MonadState (get, put, state), StateT (runStateT)) +import Control.Monad.Trans.Class +import Data.Functor.Identity (Identity (runIdentity)) + +-- External IDs +newtype UID p r = UID {uid :: r} + deriving (Show, Eq, Ord, Functor) + +-- External ID source +newtype UIDSource r = UIDSource r + +instance (Ord r) => Semigroup (UIDSource r) where + UIDSource i1 <> UIDSource i2 = UIDSource $ max i1 i2 + +instance (Ord r, Bounded r) => Monoid (UIDSource r) where + mempty = UIDSource minBound + +nextUID :: (Enum r) => UIDSource r -> (UID p r, UIDSource r) +nextUID (UIDSource i) = (UID i, UIDSource $ succ i) + +-- External ID source monad transformer +newtype UIDSourceT r m a = UIDSourceT (StateT (UIDSource r) m a) + deriving (Functor, Applicative, Monad, MonadTrans, MonadIO) + +type UIDSourceM r = UIDSourceT r Identity + +instance (MonadState s m) => MonadState s (UIDSourceT r m) where + get = lift get + put = lift . put + state = lift . state + +instance (MonadReader r m) => MonadReader r (UIDSourceT r m) where + ask = lift ask + local f (UIDSourceT m) = UIDSourceT (local f m) + +runUIDSourceT :: UIDSourceT r m a -> UIDSource r -> m (a, UIDSource r) +runUIDSourceT (UIDSourceT m) = runStateT m + +runUIDSourceM :: UIDSourceT r Identity a -> UIDSource r -> (a, UIDSource r) +runUIDSourceM m = runIdentity . runUIDSourceT m + +getUID :: (Monad m, Enum r) => UIDSourceT r m (UID p r) +getUID = UIDSourceT $ state nextUID + +getUIDs :: (Monad m, Bounded r, Enum r) => r -> UIDSourceT r m [UID p r] +getUIDs n = mapM (const getUID) [minBound .. n] diff --git a/src/Language/Futhark/Interpreter/FFI/Values.hs b/src/Language/Futhark/Interpreter/FFI/Values.hs new file mode 100644 index 0000000000..179ea8cdc8 --- /dev/null +++ b/src/Language/Futhark/Interpreter/FFI/Values.hs @@ -0,0 +1,153 @@ +module Language.Futhark.Interpreter.FFI.Values + ( PrimitiveType (..), + PrimitiveValue (..), + Type (..), + Value (..), + Direction (..), + Location (..), + InType, + InValue, + primitiveType, + primitiveTypeName, + toTuple, + toTupleType, + toPrimValue, + fromPrimValue, + indexLocation, + projectLocation, + ) +where + +import Data.Map qualified as M +import Data.Text qualified as T +import Language.Futhark.Core (Half, Int16, Int32, Int64, Int8, Word16, Word32, Word64, Word8) +import Language.Futhark.Interpreter.FFI.UIDs (ValueUID) +import Language.Futhark.Interpreter.FFI.Util.NDArray (NDArray) +import Language.Futhark.Syntax qualified as I + +data PrimitiveType + = TInt8 + | TInt16 + | TInt32 + | TInt64 + | TUInt8 + | TUInt16 + | TUInt32 + | TUInt64 + | TFloat16 + | TFloat32 + | TFloat64 + | TBool + deriving (Show, Eq, Ord) + +data PrimitiveValue + = Int8 Int8 + | Int16 Int16 + | Int32 Int32 + | Int64 Int64 + | UInt8 Word8 + | UInt16 Word16 + | UInt32 Word32 + | UInt64 Word64 + | Float16 Half + | Float32 Float + | Float64 Double + | Bool Bool + deriving (Show, Eq, Ord) + +data Type a + = TAtom a + | TArray (Type a) + | TRecord (M.Map T.Text (Type a)) + | TSum (M.Map T.Text [Type a]) + deriving (Show, Eq, Ord, Functor, Foldable, Traversable) + +data Value a + = Atom a + | Array (NDArray (Value a)) + | Record (M.Map T.Text (Value a)) + | Sum T.Text [Value a] + deriving (Show, Eq, Ord, Functor, Foldable, Traversable) + +data Direction + = Index [Int] + | Field T.Text + | VariantValue T.Text Int + deriving (Show, Eq, Ord) + +type Directions = [Direction] + +data Location = Location ValueUID Directions + deriving (Show, Eq, Ord) + +indexLocation :: [Int] -> Location -> Location +indexLocation i (Location vid ds) = Location vid $ Index i : ds + +projectLocation :: T.Text -> Location -> Location +projectLocation f (Location vid ds) = Location vid $ Field f : ds + +type InType = Type PrimitiveType + +type InValue = Value PrimitiveValue + +primitiveType :: PrimitiveValue -> PrimitiveType +primitiveType (Int8 _) = TInt8 +primitiveType (Int16 _) = TInt16 +primitiveType (Int32 _) = TInt32 +primitiveType (Int64 _) = TInt64 +primitiveType (UInt8 _) = TUInt8 +primitiveType (UInt16 _) = TUInt16 +primitiveType (UInt32 _) = TUInt32 +primitiveType (UInt64 _) = TUInt64 +primitiveType (Float16 _) = TFloat16 +primitiveType (Float32 _) = TFloat32 +primitiveType (Float64 _) = TFloat64 +primitiveType (Bool _) = TBool + +primitiveTypeName :: PrimitiveType -> String +primitiveTypeName TInt8 = "i8" +primitiveTypeName TInt16 = "i16" +primitiveTypeName TInt32 = "i32" +primitiveTypeName TInt64 = "i64" +primitiveTypeName TUInt8 = "u8" +primitiveTypeName TUInt16 = "u16" +primitiveTypeName TUInt32 = "u32" +primitiveTypeName TUInt64 = "u64" +primitiveTypeName TFloat16 = "f16" +primitiveTypeName TFloat32 = "f32" +primitiveTypeName TFloat64 = "f64" +primitiveTypeName TBool = "bool" + +toTuple :: [Value a] -> Value a +toTuple vs = Record $ M.fromList $ zip (map T.show ([0 ..] :: [Int])) vs + +toTupleType :: [Type a] -> Type a +toTupleType ts = TRecord $ M.fromList $ zip (map T.show ([0 ..] :: [Int])) ts + +toPrimValue :: PrimitiveValue -> I.PrimValue +toPrimValue (Int8 v) = I.SignedValue $ I.Int8Value v +toPrimValue (Int16 v) = I.SignedValue $ I.Int16Value v +toPrimValue (Int32 v) = I.SignedValue $ I.Int32Value v +toPrimValue (Int64 v) = I.SignedValue $ I.Int64Value v +toPrimValue (UInt8 v) = I.UnsignedValue $ I.Int8Value $ fromIntegral v +toPrimValue (UInt16 v) = I.UnsignedValue $ I.Int16Value $ fromIntegral v +toPrimValue (UInt32 v) = I.UnsignedValue $ I.Int32Value $ fromIntegral v +toPrimValue (UInt64 v) = I.UnsignedValue $ I.Int64Value $ fromIntegral v +toPrimValue (Float16 v) = I.FloatValue $ I.Float16Value v +toPrimValue (Float32 v) = I.FloatValue $ I.Float32Value v +toPrimValue (Float64 v) = I.FloatValue $ I.Float64Value v +toPrimValue (Bool v) = I.BoolValue v + +fromPrimValue :: I.PrimValue -> PrimitiveValue +fromPrimValue (I.SignedValue (I.Int8Value v)) = Int8 v +fromPrimValue (I.SignedValue (I.Int16Value v)) = Int16 v +fromPrimValue (I.SignedValue (I.Int32Value v)) = Int32 v +fromPrimValue (I.SignedValue (I.Int64Value v)) = Int64 v +fromPrimValue (I.UnsignedValue (I.Int8Value v)) = UInt8 $ fromIntegral v +fromPrimValue (I.UnsignedValue (I.Int16Value v)) = UInt16 $ fromIntegral v +fromPrimValue (I.UnsignedValue (I.Int32Value v)) = UInt32 $ fromIntegral v +fromPrimValue (I.UnsignedValue (I.Int64Value v)) = UInt64 $ fromIntegral v +fromPrimValue (I.FloatValue (I.Float16Value v)) = Float16 v +fromPrimValue (I.FloatValue (I.Float32Value v)) = Float32 v +fromPrimValue (I.FloatValue (I.Float64Value v)) = Float64 v +fromPrimValue (I.BoolValue v) = Bool v diff --git a/src/Language/Futhark/Interpreter/Values.hs b/src/Language/Futhark/Interpreter/Values.hs index 4fd6fab264..e2be58916f 100644 --- a/src/Language/Futhark/Interpreter/Values.hs +++ b/src/Language/Futhark/Interpreter/Values.hs @@ -43,6 +43,7 @@ import Futhark.Util (chunk, mapAccumLM) import Futhark.Util.Pretty import Language.Futhark hiding (Shape, matchDims) import Language.Futhark.Interpreter.AD qualified as AD +import Language.Futhark.Interpreter.FFI.Values (Location) import Language.Futhark.Primitive qualified as P import Prelude hiding (break, mod) @@ -113,6 +114,8 @@ data Value m ValueAcc ValueShape (Value m -> Value m -> m (Value m)) !(Array Int (Value m)) | -- A primitive value with added information used in automatic differentiation ValueAD AD.Depth AD.ADVariable + | -- An external value + ValueExt Location (Maybe (Value m)) instance Show (Value m) where show (ValuePrim v) = "ValuePrim " <> show v <> "" @@ -122,6 +125,7 @@ instance Show (Value m) where show ValueFun {} = "ValueFun _" show ValueAcc {} = "ValueAcc _" show (ValueAD d v) = unwords ["ValueAD", show d, show v] + show (ValueExt l v) = unwords ["ValueExt", show l, show v] instance Eq (Value m) where ValuePrim (SignedValue x) == ValuePrim (SignedValue y) = @@ -154,6 +158,9 @@ prettyValueWith pprPrim = pprPrec 0 pprPrec p (ValueSum _ n vs) = parensIf (p > (0 :: Int)) $ "#" <> sep (pretty n : map (pprPrec 1) vs) pprPrec _ (ValueAD _ v) = pprPrim $ putV $ AD.varPrimal v + -- TODO: Show location? + pprPrec _ (ValueExt _ Nothing) = "ex" + pprPrec p (ValueExt _ (Just v)) = "ex" <> "(" <> pprPrec p v <> ")" pprElem v@ValueArray {} = pprPrec 0 v pprElem v = group $ pprPrec 0 v diff --git a/src/Language/Futhark/Syntax.hs b/src/Language/Futhark/Syntax.hs index 3a3cdbe0a4..f918f72cac 100644 --- a/src/Language/Futhark/Syntax.hs +++ b/src/Language/Futhark/Syntax.hs @@ -111,6 +111,7 @@ import Data.List.NonEmpty qualified as NE import Data.Map.Strict qualified as M import Data.Monoid hiding (Sum) import Data.Ord +import Data.String (IsString (..)) import Data.Text qualified as T import Data.Traversable import Futhark.Util.Loc @@ -221,6 +222,12 @@ data AttrInfo vn | AttrComp Name [AttrInfo vn] SrcLoc deriving (Eq, Ord, Show) +instance IsString (AttrAtom vn) where + fromString = AtomName . fromString + +instance IsString (AttrInfo vn) where + fromString s = AttrAtom (fromString s) mempty + -- | The elaborated size of a dimension is just an expression. type Size = ExpBase Info VName diff --git a/tests_adhoc/ffi/test.fut b/tests_adhoc/ffi/test.fut new file mode 100644 index 0000000000..1ccf1b8d3e --- /dev/null +++ b/tests_adhoc/ffi/test.fut @@ -0,0 +1,37 @@ +type r = { + x: i32, + y: i32 +} +type s = #a i32 | #b i32 | #c i32 + +entry p1: i32 = 1i32 +entry p2: i32 = 2i32 +entry p3: i32 = 3i32 + +entry r1: r = {x = 1i32, y = 2i32} +entry r2: r = {x = 3i32, y = 4i32} +entry r3: r = {x = 5i32, y = 6i32} + +entry s1: s = #a 2i32 +entry s2: s = #b 4i32 +entry s3: s = #c 6i32 + +entry pa1: [3]i32 = [p1, p2, p3] +entry pa2: [2][3]i32 = [[p1, p2, p3], [p3, p2, p1]] + +entry ra1: [3]r = [r1, r2, r3] +entry ra2: [2][3]r = [[r1, r2, r3], [r3, r2, r1]] + +entry sa1: [3]s = [s1, s2, s3] +entry sa2: [2][3]s = [[s1, s2, s3], [s3, s2, s1]] + +entry pf (x: i32): i32 = x ** 2 +entry rf (x: r): r = { x = x.x ** 2, y = x.y + 2 } +entry sf (x: s): s = + match x + case #a v -> #c (v + 1) + case #b v -> #b (v + 2) + case #c v -> #a (v + 3) + +entry pa1f (x: []i32): []i32 = map (**2) x +entry pa2f (x: [][]i32): [][]i32 = let v = map2 (**) x[0,:] x[1,:] in [v,v] diff --git a/tests_adhoc/ffi/test.sh b/tests_adhoc/ffi/test.sh new file mode 100755 index 0000000000..34a3cbd7a5 --- /dev/null +++ b/tests_adhoc/ffi/test.sh @@ -0,0 +1,84 @@ +#!/usr/bin/env bash + +FILE="test.fut" +BACKEND=c + +# Compile the server executable +futhark ${BACKEND} --server $FILE + +# Expression +tests=( +# Preloaded primitives + "p1" "1" + "p2" "2" + "p3" "3" + +# Preloaded records + "r1" "{x = 1, y = 2}" + "r2" "{x = 3, y = 4}" + "r3" "{x = 5, y = 6}" + +# Preloaded sums + "s1" "#a 2" + "s2" "#b 4" + "s3" "#c 6" + +# Preloaded primitive arrays + "pa1[0] + 0" "1" + "pa1[1] + 0" "2" + "pa1[2] + 0" "3" + "pa2[1,0] + 0" "3" + "pa2[1,1] + 0" "2" + "pa2[1,2] + 0" "1" + +# Preloaded record arrays + "ra1[0].x + 0" "1" + "ra1[1].x + 0" "3" + "ra1[2].x + 0" "5" + "ra2[1,0].y + 0" "6" + "ra2[1,1].y + 0" "4" + "ra2[1,2].y + 0" "2" + +# Preloaded sum arrays + "sa1[0]" "ex" # TODO: I will add this when I auto-realize output values + "sa1[1]" "ex" # TODO: I will add this when I auto-realize output values + "sa1[2]" "ex" # TODO: I will add this when I auto-realize output values + "sa2[1,0]" "ex" # TODO: I will add this when I auto-realize output values + "sa2[1,1]" "ex" # TODO: I will add this when I auto-realize output values + "sa2[1,2]" "ex" # TODO: I will add this when I auto-realize output values + +# Primitive functions + "pf 2" "4" + "pf 3" "9" + +# Record functions + "rf {x = 1, y = 2}" "{x = 1, y = 4}" + "rf {x = 2, y = 1}" "{x = 4, y = 3}" + +# Sum functions + "sf (#a 2)" "#c 3" + "sf (#b 2)" "#b 4" + +# Primitive array functions + "(pa1f [1,2,3])[0] + 0" "1" + "(pa1f [1,2,3])[1] + 0" "4" + "(pa1f [1,2,3])[2] + 0" "9" + "(pa2f [[1,2,3], [3,2,1]])[0,0] + 0" "1" + "(pa2f [[1,2,3], [3,2,1]])[0,1] + 0" "4" + "(pa2f [[1,2,3], [3,2,1]])[0,2] + 0" "3" + +# TODO: Record and sum array functions +) + +for ((i=0; i<${#tests[@]}; i+=2)); do + exp="${tests[i]}" + expected="${tests[i+1]}" + + output=$(futhark eval --backend=${BACKEND} --skip-compilation -f "test.fut" "$exp" | tr '\n' ' ' | xargs) + + if [[ "$output" == "$expected" ]]; then + echo "PASS: $exp" + else + echo "FAIL: $exp (expected '$expected', got '$output')" + fi +done