Skip to content

Commit 275b8af

Browse files
committed
Implement ForeignPtr marshalling
1 parent 29bad26 commit 275b8af

File tree

7 files changed

+156
-29
lines changed

7 files changed

+156
-29
lines changed

inline-rust.cabal

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ test-suite spec
8585
, Submodule
8686
, Submodule.Submodule
8787
, ByteStrings
88+
, ForeignPtr
8889
build-depends: base
8990
, inline-rust
9091
, language-rust

src/Language/Rust/Inline.hs

Lines changed: 38 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ module Language.Rust.Inline (
6060
pointers,
6161
prelude,
6262
bytestrings,
63+
foreignPointers,
6364

6465
-- ** Marshalling
6566
with,
@@ -322,6 +323,9 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do
322323
let retTy = showTy haskRet
323324
in fail ("Cannot put unlifted type ‘" ++ retTy ++ "’ in IO")
324325
ByteString -> [t|Ptr (Ptr Word8, Word, FunPtr (Ptr Word8 -> Word -> IO ())) -> IO ()|]
326+
ForeignPtr
327+
| AppT _ haskRet' <- haskRet -> [t|Ptr (Ptr $(pure haskRet'), FunPtr (Ptr $(pure haskRet') -> IO ())) -> IO ()|]
328+
| otherwise -> fail ("Cannot marshal " ++ showTy haskRet ++ " using the ForeignPtr calling convention")
325329
pure (marshalForm, pure ret)
326330

327331
-- Convert the Haskell arguments to marshallable FFI types
@@ -348,6 +352,11 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do
348352
ByteString -> do
349353
rbsT <- [t|Ptr (Ptr Word8, Word)|]
350354
pure (ByteString, rbsT)
355+
ForeignPtr
356+
| AppT _ haskArg' <- haskArg -> do
357+
ptr <- [t|Ptr $(pure haskArg')|]
358+
pure (ForeignPtr, ptr)
359+
| otherwise -> fail ("Cannot marshal " ++ showTy haskRet ++ " using the ForeignPtr calling convention")
351360
_ -> pure (marshalForm, haskArg)
352361

353362
-- Generate the Haskell FFI import declaration and emit it
@@ -378,17 +387,28 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do
378387
finalizer <- newName "finalizer"
379388
[e|
380389
alloca
381-
( \($(varP ret)) ->
382-
do
383-
$(appsE (varE qqName : reverse (varE ret : acc)))
384-
($(varP ptr), $(varP len), $(varP finalizer)) <- peek $(varE ret)
385-
ByteString.unsafePackCStringFinalizer
386-
$(varE ptr)
387-
(fromIntegral $(varE len))
388-
($(varE bsFree) $(varE finalizer) $(varE ptr) $(varE len))
390+
( \($(varP ret)) -> do
391+
$(appsE (varE qqName : reverse (varE ret : acc)))
392+
($(varP ptr), $(varP len), $(varP finalizer)) <- peek $(varE ret)
393+
ByteString.unsafePackCStringFinalizer
394+
$(varE ptr)
395+
(fromIntegral $(varE len))
396+
($(varE bsFree) $(varE finalizer) $(varE ptr) $(varE len))
389397
)
390398
|]
391-
| byValue returnFfi = appsE (varE qqName : reverse acc)
399+
| returnFfi == ForeignPtr = do
400+
finalizer <- newName "finalizer"
401+
ptr <- newName "ptr"
402+
ret <- newName "ret"
403+
[e|
404+
alloca
405+
( \($(varP ret)) -> do
406+
$(appsE (varE qqName : reverse (varE ret : acc)))
407+
($(varP ptr), $(varP finalizer)) <- peek $(varE ret)
408+
newForeignPtr $(varE finalizer) $(varE ptr)
409+
)
410+
|]
411+
| returnByValue returnFfi = appsE (varE qqName : reverse acc)
392412
| otherwise = do
393413
ret <- newName "ret"
394414
[e|
@@ -418,7 +438,12 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do
418438
with ($(varE ptr), $(varE len)) (\($(varP bsp)) -> $(goArgs (varE bsp : acc) args))
419439
)
420440
|]
421-
| byValue marshalForm -> goArgs (varE argName : acc) args
441+
| marshalForm == ForeignPtr -> do
442+
ptr <- newName "ptr"
443+
[e|
444+
withForeignPtr $(varE argName) (\($(varP ptr)) -> $(goArgs (varE ptr : acc) args))
445+
|]
446+
| passByValue marshalForm -> goArgs (varE argName : acc) args
422447
| otherwise -> do
423448
x <- newName "x"
424449
[e|
@@ -444,7 +469,7 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do
444469
mergeArgs t (Just tInter) = (fmap (const mempty) tInter, t)
445470

446471
-- Generate the Rust function.
447-
let retByVal = byValue returnFfi
472+
let retByVal = returnByValue returnFfi
448473
(retArg, retTy, ret)
449474
| retByVal =
450475
( []
@@ -464,15 +489,15 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do
464489
", "
465490
( [ s ++ ": " ++ marshal (renderType t)
466491
| (s, t, v) <- zip3 rustArgNames rustArgs' marshalForms
467-
, let marshal x = if byValue v then x else "*const " ++ x
492+
, let marshal x = if passByValue v then x else "*const " ++ x
468493
]
469494
++ retArg
470495
)
471496
, ") -> " ++ retTy ++ " {"
472497
, unlines
473498
[ " let " ++ s ++ ": " ++ renderType t ++ " = " ++ marshal s ++ ".marshal();"
474499
| (s, t, v) <- zip3 rustArgNames rustConvertedArgs marshalForms
475-
, let marshal x = if byValue v then x else "unsafe { ::std::ptr::read(" ++ x ++ ") }"
500+
, let marshal x = if passByValue v then x else "unsafe { ::std::ptr::read(" ++ x ++ ") }"
476501
]
477502
, " let out: " ++ renderType rustConvertedRet ++ " = (|| {" ++ renderTokens rustBody ++ "})();"
478503
, " " ++ ret

src/Language/Rust/Inline/Context.hs

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
{-# LANGUAGE FlexibleInstances #-}
33
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
44
{-# LANGUAGE MagicHash #-}
5+
{-# LANGUAGE OverloadedStrings #-}
56
{-# LANGUAGE QuasiQuotes #-}
67
{-# LANGUAGE TemplateHaskell #-}
78

@@ -15,7 +16,10 @@ import Language.Rust.Syntax (
1516
Arg (..),
1617
FnDecl (..),
1718
Mutability (..),
18-
Ty (BareFn, Ptr),
19+
Path (..),
20+
PathParameters (..),
21+
PathSegment (..),
22+
Ty (..),
1923
Unsafety (..),
2024
)
2125

@@ -32,6 +36,7 @@ import Data.Word (Word16, Word32, Word64, Word8)
3236
import Foreign.C.Types -- pretty much every type here is used
3337
import Foreign.Ptr (FunPtr, Ptr)
3438

39+
import Foreign.ForeignPtr (ForeignPtr)
3540
import GHC.Exts (
3641
ByteArray#,
3742
Char#,
@@ -308,6 +313,58 @@ pointers = do
308313
, "}"
309314
]
310315

316+
foreignPointers :: Q Context
317+
foreignPointers =
318+
pure $ Context ([rule], [], [foreignPtr, constPtr, mutPtr])
319+
where
320+
rule (Ptr _ t _) context
321+
| First (Just (t', Nothing)) <- lookupRTypeInContext t context = pure ([t|ForeignPtr $t'|], Nothing)
322+
rule (PathTy Nothing (Path False [PathSegment "ForeignPtr" (Just (AngleBracketed [] [t] [] _)) _] _) _) context
323+
| First (Just (t', Nothing)) <- lookupRTypeInContext t context = pure ([t|ForeignPtr $t'|], Nothing)
324+
rule _ _ = mempty
325+
326+
foreignPtr =
327+
unlines
328+
[ "#[repr(C)]"
329+
, "pub struct ForeignPtr<T>(*mut T, extern \"C\" fn (*mut T));"
330+
]
331+
332+
constPtr =
333+
unlines
334+
[ "impl<T> MarshalInto<*const T> for *const T {"
335+
, " fn marshal(self) -> *const T { self }"
336+
, "}"
337+
]
338+
339+
mutPtr =
340+
unlines
341+
[ "impl<T> MarshalInto<*mut T> for ForeignPtr<T> {"
342+
, " fn marshal(self) -> *mut T {"
343+
, " let ForeignPtr(ptr, _) = self;"
344+
, " ptr"
345+
, " }"
346+
, "}"
347+
, "impl<T> MarshalInto<ForeignPtr<T>> for ForeignPtr<T> {"
348+
, " fn marshal(self) -> Self {"
349+
, " self"
350+
, " }"
351+
, "}"
352+
, ""
353+
, "impl<T> From<Box<T>> for ForeignPtr<T> {"
354+
, " fn from(p: Box<T>) -> ForeignPtr<T> {"
355+
, " extern fn free<T> (ptr: *mut T) {"
356+
, " let t = unsafe { Box::from_raw(ptr) };"
357+
, " drop(t);"
358+
, " }"
359+
, " ForeignPtr(std::ptr::from_mut(Box::leak(p)), free)"
360+
, " }"
361+
, "}"
362+
, ""
363+
, "impl<T> MarshalInto<*mut T> for *mut T {"
364+
, " fn marshal(self) -> *mut T { self }"
365+
, "}"
366+
]
367+
311368
{- | This maps a Rust function type into the corresponding 'FunPtr' wrapped
312369
Haskell function type.
313370

src/Language/Rust/Inline/Context/ByteString.hs

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,14 @@ import Data.Word (Word8)
3030
import Foreign.Ptr (Ptr)
3131

3232
bytestrings :: Q Context
33-
bytestrings = do
34-
bytestringT <- [t|ByteString|]
35-
pure $ Context ([rule], [rev bytestringT], [rustByteString, impl])
33+
bytestrings =
34+
pure $ Context ([rule], [], [rustByteString, impl])
3635
where
3736
rule rty _
3837
| rty == void [ty| &[u8] |] = pure ([t|ByteString|], pure . pure $ void [ty| RustByteString |])
39-
| rty == void [ty| Vec<u8> |] = pure ([t|ByteString|], pure . pure $ void [ty| RustMutByteString |])
38+
| rty == void [ty| Vec<u8> |] = pure ([t|ByteString|], pure . pure $ void [ty| RustOwnedByteString |])
4039
rule _ _ = mempty
4140

42-
rev _ _ _ = mempty
43-
4441
rustByteString =
4542
unlines
4643
[ "#[repr(C)]"
@@ -49,7 +46,7 @@ bytestrings = do
4946
, ""
5047
, "#[repr(C)]"
5148
, "#[derive(Copy, Clone)]"
52-
, "pub struct RustMutByteString(*mut u8, usize, extern \"C\" fn (*mut u8, usize) -> ());"
49+
, "pub struct RustOwnedByteString(*mut u8, usize, extern \"C\" fn (*mut u8, usize));"
5350
]
5451

5552
impl =
@@ -61,16 +58,16 @@ bytestrings = do
6158
, " }"
6259
, "}"
6360
, ""
64-
, "impl MarshalInto<RustMutByteString> for Vec<u8> {"
65-
, " fn marshal(self) -> RustMutByteString {"
61+
, "impl MarshalInto<RustOwnedByteString> for Vec<u8> {"
62+
, " fn marshal(self) -> RustOwnedByteString {"
6663
, " let bytes = Box::leak(self.into_boxed_slice());"
6764
, " let len = bytes.len();"
6865
, ""
69-
, " extern fn free_bytestring(ptr: *mut u8, len: usize) -> () {"
66+
, " extern fn free(ptr: *mut u8, len: usize) {"
7067
, " let bytes = unsafe { Box::from_raw(std::ptr::slice_from_raw_parts_mut(ptr, len) ) };"
71-
, " drop(bytes)"
68+
, " drop(bytes);"
7269
, " }"
73-
, " RustMutByteString(bytes.as_mut_ptr(), len, free_bytestring)"
70+
, " RustOwnedByteString(bytes.as_mut_ptr(), len, free)"
7471
, " }"
7572
, "}"
7673
]

src/Language/Rust/Inline/Marshal.hs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import Data.Word
2222
import Data.Int
2323

2424
import Foreign.Ptr ( plusPtr )
25-
import Foreign.ForeignPtr ( withForeignPtr )
25+
import Foreign.ForeignPtr ( withForeignPtr, ForeignPtr )
2626
import Foreign.StablePtr ( StablePtr )
2727
import Foreign.Storable ( Storable )
2828

@@ -37,10 +37,14 @@ data MarshalForm
3737
| BoxedDirect -- ^ value is marshallable and can be passed directly to the FFI
3838
| BoxedIndirect -- ^ value isn't marshallable directly but may be passed indirectly via a 'Ptr'
3939
| ByteString
40+
| ForeignPtr
4041
deriving (Eq)
4142

42-
byValue :: MarshalForm -> Bool
43-
byValue = (`elem` [UnboxedDirect, BoxedDirect])
43+
passByValue :: MarshalForm -> Bool
44+
passByValue = (`elem` [UnboxedDirect, BoxedDirect, ForeignPtr])
45+
46+
returnByValue :: MarshalForm -> Bool
47+
returnByValue = (`elem` [UnboxedDirect, BoxedDirect])
4448

4549
-- | Identify which types can be marshalled by the GHC FFI and which types are
4650
-- unlifted. A negative response to the first of these questions doesn't mean
@@ -56,13 +60,15 @@ ghcMarshallable ty = do
5660
tyconsU <- sequence qTyconsUnboxed
5761
tyconsB <- sequence qTyconsBoxed
5862
bytestring <- [t| ByteString |]
63+
fptrCons <- [t| ForeignPtr |]
5964

6065
case ty of
6166
_ | ty `elem` simpleU -> pure UnboxedDirect
6267
| ty `elem` simpleB -> pure BoxedDirect
6368
| ty == bytestring -> pure ByteString
6469
AppT con _ | con `elem` tyconsU -> pure UnboxedDirect
6570
| con `elem` tyconsB -> pure BoxedDirect
71+
| con == fptrCons -> pure ForeignPtr
6672
_ -> pure BoxedIndirect
6773
where
6874
qSimpleUnboxed = [ [t| Char# |]

tests/ForeignPtr.hs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
{-# LANGUAGE QuasiQuotes #-}
2+
{-# LANGUAGE ScopedTypeVariables #-}
3+
{-# LANGUAGE TemplateHaskell #-}
4+
5+
module ForeignPtr where
6+
7+
import Language.Rust.Inline
8+
9+
import Data.Word (Word64)
10+
import Foreign (Storable (..))
11+
import Foreign.ForeignPtr
12+
import Foreign.Ptr
13+
import Test.Hspec
14+
15+
extendContext foreignPointers
16+
extendContext basic
17+
setCrateModule
18+
19+
foreignPtrTypes :: Spec
20+
foreignPtrTypes = describe "ForeignPtr types" $ do
21+
it "Can marshal ForeignPtr arguments" $ do
22+
p <- mallocForeignPtr
23+
withForeignPtr p (`poke` 42)
24+
let read = [rust| u64 { unsafe { *$(p: *const u64) } } |]
25+
42 `shouldBe` read
26+
27+
it "Can mutate ForeignPtr arguments" $ do
28+
p <- mallocForeignPtr
29+
[rustIO| () {
30+
unsafe { *$(p: *mut u64) = 42; }
31+
} |]
32+
val <- withForeignPtr p peek
33+
val `shouldBe` 42
34+
35+
it "Can marshal ForeignPtr returns" $ do
36+
let p = [rust| ForeignPtr<u64> { Box::new(42).into() }|]
37+
val <- withForeignPtr p peek
38+
val `shouldBe` 42

tests/Main.hs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,13 @@ import AlgebraicDataTypes
1919
import ByteStrings
2020
import Submodule
2121
import Submodule.Submodule
22+
import ForeignPtr
2223
import Data.Word
2324
import Test.Hspec
2425
import Foreign.Storable
2526
import Foreign.Ptr
2627
import Foreign.Marshal.Array
28+
2729
extendContext basic
2830
setCrateRoot []
2931

@@ -39,3 +41,4 @@ main = hspec $
3941
submoduleTest
4042
subsubmoduleTest
4143
bytestringSpec
44+
foreignPtrTypes

0 commit comments

Comments
 (0)