@@ -60,6 +60,7 @@ module Language.Rust.Inline (
6060 pointers ,
6161 prelude ,
6262 bytestrings ,
63+ foreignPointers ,
6364
6465 -- ** Marshalling
6566 with ,
@@ -322,6 +323,9 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do
322323 let retTy = showTy haskRet
323324 in fail (" Cannot put unlifted type ‘" ++ retTy ++ " ’ in IO" )
324325 ByteString -> [t |Ptr (Ptr Word8, Word, FunPtr (Ptr Word8 -> Word -> IO ())) -> IO ()|]
326+ ForeignPtr
327+ | AppT _ haskRet' <- haskRet -> [t |Ptr (Ptr $(pure haskRet'), FunPtr (Ptr $(pure haskRet') -> IO ())) -> IO ()|]
328+ | otherwise -> fail (" Cannot marshal " ++ showTy haskRet ++ " using the ForeignPtr calling convention" )
325329 pure (marshalForm, pure ret)
326330
327331 -- Convert the Haskell arguments to marshallable FFI types
@@ -348,6 +352,11 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do
348352 ByteString -> do
349353 rbsT <- [t |Ptr (Ptr Word8, Word)|]
350354 pure (ByteString , rbsT)
355+ ForeignPtr
356+ | AppT _ haskArg' <- haskArg -> do
357+ ptr <- [t |Ptr $(pure haskArg')|]
358+ pure (ForeignPtr , ptr)
359+ | otherwise -> fail (" Cannot marshal " ++ showTy haskRet ++ " using the ForeignPtr calling convention" )
351360 _ -> pure (marshalForm, haskArg)
352361
353362 -- Generate the Haskell FFI import declaration and emit it
@@ -378,17 +387,28 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do
378387 finalizer <- newName " finalizer"
379388 [e |
380389 alloca
381- ( \($(varP ret)) ->
382- do
383- $(appsE (varE qqName : reverse (varE ret : acc)))
384- ($(varP ptr), $(varP len), $(varP finalizer)) <- peek $(varE ret)
385- ByteString.unsafePackCStringFinalizer
386- $(varE ptr)
387- (fromIntegral $(varE len))
388- ($(varE bsFree) $(varE finalizer) $(varE ptr) $(varE len))
390+ ( \($(varP ret)) -> do
391+ $(appsE (varE qqName : reverse (varE ret : acc)))
392+ ($(varP ptr), $(varP len), $(varP finalizer)) <- peek $(varE ret)
393+ ByteString.unsafePackCStringFinalizer
394+ $(varE ptr)
395+ (fromIntegral $(varE len))
396+ ($(varE bsFree) $(varE finalizer) $(varE ptr) $(varE len))
389397 )
390398 |]
391- | byValue returnFfi = appsE (varE qqName : reverse acc)
399+ | returnFfi == ForeignPtr = do
400+ finalizer <- newName " finalizer"
401+ ptr <- newName " ptr"
402+ ret <- newName " ret"
403+ [e |
404+ alloca
405+ ( \($(varP ret)) -> do
406+ $(appsE (varE qqName : reverse (varE ret : acc)))
407+ ($(varP ptr), $(varP finalizer)) <- peek $(varE ret)
408+ newForeignPtr $(varE finalizer) $(varE ptr)
409+ )
410+ |]
411+ | returnByValue returnFfi = appsE (varE qqName : reverse acc)
392412 | otherwise = do
393413 ret <- newName " ret"
394414 [e |
@@ -418,7 +438,12 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do
418438 with ($(varE ptr), $(varE len)) (\($(varP bsp)) -> $(goArgs (varE bsp : acc) args))
419439 )
420440 |]
421- | byValue marshalForm -> goArgs (varE argName : acc) args
441+ | marshalForm == ForeignPtr -> do
442+ ptr <- newName " ptr"
443+ [e |
444+ withForeignPtr $(varE argName) (\($(varP ptr)) -> $(goArgs (varE ptr : acc) args))
445+ |]
446+ | passByValue marshalForm -> goArgs (varE argName : acc) args
422447 | otherwise -> do
423448 x <- newName " x"
424449 [e |
@@ -444,7 +469,7 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do
444469 mergeArgs t (Just tInter) = (fmap (const mempty ) tInter, t)
445470
446471 -- Generate the Rust function.
447- let retByVal = byValue returnFfi
472+ let retByVal = returnByValue returnFfi
448473 (retArg, retTy, ret)
449474 | retByVal =
450475 ( []
@@ -464,15 +489,15 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do
464489 " , "
465490 ( [ s ++ " : " ++ marshal (renderType t)
466491 | (s, t, v) <- zip3 rustArgNames rustArgs' marshalForms
467- , let marshal x = if byValue v then x else " *const " ++ x
492+ , let marshal x = if passByValue v then x else " *const " ++ x
468493 ]
469494 ++ retArg
470495 )
471496 , " ) -> " ++ retTy ++ " {"
472497 , unlines
473498 [ " let " ++ s ++ " : " ++ renderType t ++ " = " ++ marshal s ++ " .marshal();"
474499 | (s, t, v) <- zip3 rustArgNames rustConvertedArgs marshalForms
475- , let marshal x = if byValue v then x else " unsafe { ::std::ptr::read(" ++ x ++ " ) }"
500+ , let marshal x = if passByValue v then x else " unsafe { ::std::ptr::read(" ++ x ++ " ) }"
476501 ]
477502 , " let out: " ++ renderType rustConvertedRet ++ " = (|| {" ++ renderTokens rustBody ++ " })();"
478503 , " " ++ ret
0 commit comments