Skip to content

Support complex numbers. #30

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions src/Python/Inline/Literal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ import Foreign.Storable
import Foreign.Marshal.Alloc (alloca,mallocBytes)
import Foreign.Marshal.Utils (copyBytes)
import GHC.Float (float2Double, double2Float)
import Data.Complex (Complex((:+)))

import Language.C.Inline qualified as C
import Language.C.Inline.Unsafe qualified as CU
Expand Down Expand Up @@ -202,6 +203,24 @@ deriving via CDouble instance FromPy Double
instance ToPy Float where basicToPy = basicToPy . float2Double
instance FromPy Float where basicFromPy = fmap double2Float . basicFromPy

instance ToPy (Complex Float) where
basicToPy (x:+y) = basicToPy $ float2Double x :+ float2Double y
instance FromPy (Complex Float) where
basicFromPy xy_py = do
x :+ y <- basicFromPy xy_py
return $ double2Float x :+ double2Float y

instance ToPy (Complex Double) where
basicToPy (x:+y) = Py [CU.exp| PyObject* { PyComplex_FromDoubles($(double x'), $(double y')) } |]
where x' = CDouble x
y' = CDouble y
instance FromPy (Complex Double) where
basicFromPy xy_py = do
CDouble x <- Py [CU.exp| double { PyComplex_RealAsDouble($(PyObject *xy_py)) } |]
checkThrowBadPyType
CDouble y <- Py [CU.exp| double { PyComplex_ImagAsDouble($(PyObject *xy_py)) } |]
checkThrowBadPyType
return $ x :+ y

instance ToPy Int where
basicToPy
Expand Down
8 changes: 8 additions & 0 deletions test/TST/FromPy.hs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import Test.Tasty
import Test.Tasty.HUnit
import Python.Inline
import Python.Inline.QQ
import Data.Complex (Complex((:+)))

tests :: TestTree
tests = testGroup "FromPy"
Expand All @@ -22,6 +23,13 @@ tests = testGroup "FromPy"
, testCase "Double->Double" $ eq @Double (Just 1234.25) [pye| 1234.25 |]
, testCase "None->Double" $ eq @Double Nothing [pye| None |]
]
, testGroup "Complex"
[ testCase "Int->Complex" $ eq @(Complex Double) (Just 1234) [pye| 1234 |]
, testCase "Double->Complex" $ eq @(Complex Double) (Just 1234.25) [pye| 1234.25 |]
, testCase "Complex->Complex" $ eq @(Complex Double) (Just $ 1234.5 :+ 6789)
[pye| 1234.5+6789.0j |]
, testCase "None->Complex" $ eq @(Complex Double) Nothing [pye| None |]
]
, testGroup "Char"
[ testCase "0" $ eq @Char Nothing [pye| "" |]
, testCase "1 1B" $ eq @Char (Just 'a') [pye| "a" |]
Expand Down
5 changes: 5 additions & 0 deletions test/TST/Roundtrip.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import Data.Map.Strict (Map)
import Data.Text qualified as T
import Data.Text.Lazy qualified as TL
import Data.Complex (Complex)
import Foreign.C.Types

import Test.Tasty
Expand All @@ -25,7 +26,7 @@
import Data.ByteString.Short qualified as SBS
import Data.Vector qualified as V
#if MIN_VERSION_vector(0,13,2)
import Data.Vector.Strict qualified as VV

Check warning on line 29 in test/TST/Roundtrip.hs

View workflow job for this annotation

GitHub Actions / ubuntu-latest / ghc 9.2.8

The qualified import of ‘Data.Vector.Strict’ is redundant

Check warning on line 29 in test/TST/Roundtrip.hs

View workflow job for this annotation

GitHub Actions / ubuntu-latest / ghc 9.4.8

The qualified import of ‘Data.Vector.Strict’ is redundant

Check warning on line 29 in test/TST/Roundtrip.hs

View workflow job for this annotation

GitHub Actions / ubuntu-latest / ghc 9.8.2

The qualified import of ‘Data.Vector.Strict’ is redundant

Check warning on line 29 in test/TST/Roundtrip.hs

View workflow job for this annotation

GitHub Actions / ubuntu-latest / ghc 9.6.6

The qualified import of ‘Data.Vector.Strict’ is redundant

Check warning on line 29 in test/TST/Roundtrip.hs

View workflow job for this annotation

GitHub Actions / ubuntu-latest / ghc 9.10.1

The qualified import of ‘Data.Vector.Strict’ is redundant
#endif
import Data.Vector.Storable qualified as VS
import Data.Vector.Primitive qualified as VP
Expand Down Expand Up @@ -61,6 +62,9 @@
-- Floating point
, testRoundtrip @Double
, testRoundtrip @Float
-- Complex
, testRoundtrip @(Complex Double)
, testRoundtrip @(Complex Float)
-- Other scalars
, testRoundtrip @Char
, testRoundtrip @Bool
Expand All @@ -71,6 +75,7 @@
, testRoundtrip @(Int,Int,Int,Char)
, testRoundtrip @[Int]
, testRoundtrip @[[Int]]
, testRoundtrip @[Complex Double]
, testRoundtrip @(Set Int)
, testRoundtrip @(Map Int Int)
-- , testRoundtrip @String -- Trips on zero byte as it should
Expand Down
4 changes: 4 additions & 0 deletions test/TST/ToPy.hs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ module TST.ToPy (tests) where
import Data.ByteString qualified as BS
import Data.Set qualified as Set
import Data.Map.Strict qualified as Map
import Data.Complex (Complex((:+)))
import Test.Tasty
import Test.Tasty.HUnit
import Python.Inline
Expand All @@ -16,6 +17,9 @@ tests :: TestTree
tests = testGroup "ToPy"
[ testCase "Int" $ runPy $ let i = 1234 :: Int in [py_| assert i_hs == 1234 |]
, testCase "Double" $ runPy $ let i = 1234.25 :: Double in [py_| assert i_hs == 1234.25 |]
, testCase "Complex" $ runPy $
let z = 5.5 :+ 7.5 :: Complex Double
in [py_| assert (z_hs.real == 5.5); assert (z_hs.imag == 7.5)|]
, testCase "Char ASCII" $ runPy $ let c = 'a' in [py_| assert c_hs == 'a' |]
, testCase "Char unicode" $ runPy $ let c = 'ы' in [py_| assert c_hs == 'ы' |]
, testCase "String ASCII" $ runPy $ let c = "asdf"::String in [py_| assert c_hs == 'asdf' |]
Expand Down
Loading