diff --git a/src/Python/Inline/Literal.hs b/src/Python/Inline/Literal.hs index 28fbbf5..8606179 100644 --- a/src/Python/Inline/Literal.hs +++ b/src/Python/Inline/Literal.hs @@ -46,6 +46,7 @@ import Foreign.Storable import Foreign.Marshal.Alloc (alloca,mallocBytes) import Foreign.Marshal.Utils (copyBytes) import GHC.Float (float2Double, double2Float) +import Data.Complex (Complex((:+))) import Language.C.Inline qualified as C import Language.C.Inline.Unsafe qualified as CU @@ -202,6 +203,24 @@ deriving via CDouble instance FromPy Double instance ToPy Float where basicToPy = basicToPy . float2Double instance FromPy Float where basicFromPy = fmap double2Float . basicFromPy +instance ToPy (Complex Float) where + basicToPy (x:+y) = basicToPy $ float2Double x :+ float2Double y +instance FromPy (Complex Float) where + basicFromPy xy_py = do + x :+ y <- basicFromPy xy_py + return $ double2Float x :+ double2Float y + +instance ToPy (Complex Double) where + basicToPy (x:+y) = Py [CU.exp| PyObject* { PyComplex_FromDoubles($(double x'), $(double y')) } |] + where x' = CDouble x + y' = CDouble y +instance FromPy (Complex Double) where + basicFromPy xy_py = do + CDouble x <- Py [CU.exp| double { PyComplex_RealAsDouble($(PyObject *xy_py)) } |] + checkThrowBadPyType + CDouble y <- Py [CU.exp| double { PyComplex_ImagAsDouble($(PyObject *xy_py)) } |] + checkThrowBadPyType + return $ x :+ y instance ToPy Int where basicToPy diff --git a/test/TST/FromPy.hs b/test/TST/FromPy.hs index 6e39005..0727640 100644 --- a/test/TST/FromPy.hs +++ b/test/TST/FromPy.hs @@ -9,6 +9,7 @@ import Test.Tasty import Test.Tasty.HUnit import Python.Inline import Python.Inline.QQ +import Data.Complex (Complex((:+))) tests :: TestTree tests = testGroup "FromPy" @@ -22,6 +23,13 @@ tests = testGroup "FromPy" , testCase "Double->Double" $ eq @Double (Just 1234.25) [pye| 1234.25 |] , testCase "None->Double" $ eq @Double Nothing [pye| None |] ] + , testGroup "Complex" + [ testCase "Int->Complex" $ eq @(Complex Double) (Just 1234) [pye| 1234 |] + , testCase "Double->Complex" $ eq @(Complex Double) (Just 1234.25) [pye| 1234.25 |] + , testCase "Complex->Complex" $ eq @(Complex Double) (Just $ 1234.5 :+ 6789) + [pye| 1234.5+6789.0j |] + , testCase "None->Complex" $ eq @(Complex Double) Nothing [pye| None |] + ] , testGroup "Char" [ testCase "0" $ eq @Char Nothing [pye| "" |] , testCase "1 1B" $ eq @Char (Just 'a') [pye| "a" |] diff --git a/test/TST/Roundtrip.hs b/test/TST/Roundtrip.hs index 232c4e7..4bdc29a 100644 --- a/test/TST/Roundtrip.hs +++ b/test/TST/Roundtrip.hs @@ -10,6 +10,7 @@ import Data.Set (Set) import Data.Map.Strict (Map) import Data.Text qualified as T import Data.Text.Lazy qualified as TL +import Data.Complex (Complex) import Foreign.C.Types import Test.Tasty @@ -61,6 +62,9 @@ tests = testGroup "Roundtrip" -- Floating point , testRoundtrip @Double , testRoundtrip @Float + -- Complex + , testRoundtrip @(Complex Double) + , testRoundtrip @(Complex Float) -- Other scalars , testRoundtrip @Char , testRoundtrip @Bool @@ -71,6 +75,7 @@ tests = testGroup "Roundtrip" , testRoundtrip @(Int,Int,Int,Char) , testRoundtrip @[Int] , testRoundtrip @[[Int]] + , testRoundtrip @[Complex Double] , testRoundtrip @(Set Int) , testRoundtrip @(Map Int Int) -- , testRoundtrip @String -- Trips on zero byte as it should diff --git a/test/TST/ToPy.hs b/test/TST/ToPy.hs index 7cd5611..52224f0 100644 --- a/test/TST/ToPy.hs +++ b/test/TST/ToPy.hs @@ -5,6 +5,7 @@ module TST.ToPy (tests) where import Data.ByteString qualified as BS import Data.Set qualified as Set import Data.Map.Strict qualified as Map +import Data.Complex (Complex((:+))) import Test.Tasty import Test.Tasty.HUnit import Python.Inline @@ -16,6 +17,9 @@ tests :: TestTree tests = testGroup "ToPy" [ testCase "Int" $ runPy $ let i = 1234 :: Int in [py_| assert i_hs == 1234 |] , testCase "Double" $ runPy $ let i = 1234.25 :: Double in [py_| assert i_hs == 1234.25 |] + , testCase "Complex" $ runPy $ + let z = 5.5 :+ 7.5 :: Complex Double + in [py_| assert (z_hs.real == 5.5); assert (z_hs.imag == 7.5)|] , testCase "Char ASCII" $ runPy $ let c = 'a' in [py_| assert c_hs == 'a' |] , testCase "Char unicode" $ runPy $ let c = 'ы' in [py_| assert c_hs == 'ы' |] , testCase "String ASCII" $ runPy $ let c = "asdf"::String in [py_| assert c_hs == 'asdf' |]