Skip to content

Commit cbee13b

Browse files
Support complex numbers. (#30)
Complex numbers are built-in in Python and ship in Haskell-base, so this packages should be able to convert between them. I have implemented the necessary `ToPy` and `FromPy` instances and tests for them.
1 parent 1f9b46d commit cbee13b

File tree

4 files changed

+36
-0
lines changed

4 files changed

+36
-0
lines changed

src/Python/Inline/Literal.hs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ import Foreign.Storable
4646
import Foreign.Marshal.Alloc (alloca,mallocBytes)
4747
import Foreign.Marshal.Utils (copyBytes)
4848
import GHC.Float (float2Double, double2Float)
49+
import Data.Complex (Complex((:+)))
4950

5051
import Language.C.Inline qualified as C
5152
import Language.C.Inline.Unsafe qualified as CU
@@ -202,6 +203,24 @@ deriving via CDouble instance FromPy Double
202203
instance ToPy Float where basicToPy = basicToPy . float2Double
203204
instance FromPy Float where basicFromPy = fmap double2Float . basicFromPy
204205

206+
instance ToPy (Complex Float) where
207+
basicToPy (x:+y) = basicToPy $ float2Double x :+ float2Double y
208+
instance FromPy (Complex Float) where
209+
basicFromPy xy_py = do
210+
x :+ y <- basicFromPy xy_py
211+
return $ double2Float x :+ double2Float y
212+
213+
instance ToPy (Complex Double) where
214+
basicToPy (x:+y) = Py [CU.exp| PyObject* { PyComplex_FromDoubles($(double x'), $(double y')) } |]
215+
where x' = CDouble x
216+
y' = CDouble y
217+
instance FromPy (Complex Double) where
218+
basicFromPy xy_py = do
219+
CDouble x <- Py [CU.exp| double { PyComplex_RealAsDouble($(PyObject *xy_py)) } |]
220+
checkThrowBadPyType
221+
CDouble y <- Py [CU.exp| double { PyComplex_ImagAsDouble($(PyObject *xy_py)) } |]
222+
checkThrowBadPyType
223+
return $ x :+ y
205224

206225
instance ToPy Int where
207226
basicToPy

test/TST/FromPy.hs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import Test.Tasty
99
import Test.Tasty.HUnit
1010
import Python.Inline
1111
import Python.Inline.QQ
12+
import Data.Complex (Complex((:+)))
1213

1314
tests :: TestTree
1415
tests = testGroup "FromPy"
@@ -22,6 +23,13 @@ tests = testGroup "FromPy"
2223
, testCase "Double->Double" $ eq @Double (Just 1234.25) [pye| 1234.25 |]
2324
, testCase "None->Double" $ eq @Double Nothing [pye| None |]
2425
]
26+
, testGroup "Complex"
27+
[ testCase "Int->Complex" $ eq @(Complex Double) (Just 1234) [pye| 1234 |]
28+
, testCase "Double->Complex" $ eq @(Complex Double) (Just 1234.25) [pye| 1234.25 |]
29+
, testCase "Complex->Complex" $ eq @(Complex Double) (Just $ 1234.5 :+ 6789)
30+
[pye| 1234.5+6789.0j |]
31+
, testCase "None->Complex" $ eq @(Complex Double) Nothing [pye| None |]
32+
]
2533
, testGroup "Char"
2634
[ testCase "0" $ eq @Char Nothing [pye| "" |]
2735
, testCase "1 1B" $ eq @Char (Just 'a') [pye| "a" |]

test/TST/Roundtrip.hs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import Data.Set (Set)
1010
import Data.Map.Strict (Map)
1111
import Data.Text qualified as T
1212
import Data.Text.Lazy qualified as TL
13+
import Data.Complex (Complex)
1314
import Foreign.C.Types
1415

1516
import Test.Tasty
@@ -61,6 +62,9 @@ tests = testGroup "Roundtrip"
6162
-- Floating point
6263
, testRoundtrip @Double
6364
, testRoundtrip @Float
65+
-- Complex
66+
, testRoundtrip @(Complex Double)
67+
, testRoundtrip @(Complex Float)
6468
-- Other scalars
6569
, testRoundtrip @Char
6670
, testRoundtrip @Bool
@@ -71,6 +75,7 @@ tests = testGroup "Roundtrip"
7175
, testRoundtrip @(Int,Int,Int,Char)
7276
, testRoundtrip @[Int]
7377
, testRoundtrip @[[Int]]
78+
, testRoundtrip @[Complex Double]
7479
, testRoundtrip @(Set Int)
7580
, testRoundtrip @(Map Int Int)
7681
-- , testRoundtrip @String -- Trips on zero byte as it should

test/TST/ToPy.hs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ module TST.ToPy (tests) where
55
import Data.ByteString qualified as BS
66
import Data.Set qualified as Set
77
import Data.Map.Strict qualified as Map
8+
import Data.Complex (Complex((:+)))
89
import Test.Tasty
910
import Test.Tasty.HUnit
1011
import Python.Inline
@@ -16,6 +17,9 @@ tests :: TestTree
1617
tests = testGroup "ToPy"
1718
[ testCase "Int" $ runPy $ let i = 1234 :: Int in [py_| assert i_hs == 1234 |]
1819
, testCase "Double" $ runPy $ let i = 1234.25 :: Double in [py_| assert i_hs == 1234.25 |]
20+
, testCase "Complex" $ runPy $
21+
let z = 5.5 :+ 7.5 :: Complex Double
22+
in [py_| assert (z_hs.real == 5.5); assert (z_hs.imag == 7.5)|]
1923
, testCase "Char ASCII" $ runPy $ let c = 'a' in [py_| assert c_hs == 'a' |]
2024
, testCase "Char unicode" $ runPy $ let c = 'ы' in [py_| assert c_hs == 'ы' |]
2125
, testCase "String ASCII" $ runPy $ let c = "asdf"::String in [py_| assert c_hs == 'asdf' |]

0 commit comments

Comments
 (0)