Skip to content

Commit 643008f

Browse files
committed
Update haddocks, cabal file, add tests
1 parent 787c316 commit 643008f

13 files changed

+124
-29
lines changed

arrayfire.cabal

-2
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@ description: High-level Haskell bindings to the ArrayFire General-purpos
1515
.
1616
<<https://user-images.githubusercontent.com/875324/59819388-9ff83f00-92f5-11e9-9ac0-51eef200c716.png>>
1717
.
18-
<https://www.youtube.com/watch?v=tI89V1Z8QHw>
19-
.
2018

2119
library
2220
exposed-modules:

src/ArrayFire/Algorithm.hs

+7-1
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ anyTrue
104104
-- ^ Returns if all elements are true
105105
anyTrue x (fromIntegral -> n) = getScalar @Bool @a (x `op1` (\p a -> af_any_true p a n))
106106

107+
-- | Retrieves count of all elements in 'Array' along the specified dimension
107108
count
108109
:: forall a . AFType a
109110
=> Array a
@@ -114,7 +115,8 @@ count
114115
-- ^ Count of all elements along dimension
115116
count x (fromIntegral -> n) = getScalar @Int @a (x `op1` (\p a -> af_count p a n))
116117

117-
-- | Note: imag is always set to 0 when in is real
118+
-- | Sum all elements in 'Array'
119+
-- Note: imag is always set to 0 when in is real
118120
sumAll
119121
:: AFType a
120122
=> Array a
@@ -123,6 +125,7 @@ sumAll
123125
-- ^ imaginary and real part
124126
sumAll = (`infoFromArray2` af_sum_all)
125127

128+
-- | Sum all elements in 'Array', substituting 'NaN' values with a user specified default.
126129
sumNaNAll
127130
:: (AFType a, Fractional a)
128131
=> Array a
@@ -133,6 +136,7 @@ sumNaNAll
133136
-- ^ imaginary and real part
134137
sumNaNAll a d = infoFromArray2 a (\p g x -> af_sum_nan_all p g x d)
135138

139+
-- | Product all elements in 'Array'
136140
productAll
137141
:: AFType a
138142
=> Array a
@@ -141,6 +145,7 @@ productAll
141145
-- ^ imaginary and real part
142146
productAll = (`infoFromArray2` af_product_all)
143147

148+
-- | Product all elements in 'Array', substituting NaN values with a user specified default.
144149
productNaNAll
145150
:: (AFType a, Fractional a)
146151
=> Array a
@@ -151,6 +156,7 @@ productNaNAll
151156
-- ^ imaginary and real part
152157
productNaNAll a d = infoFromArray2 a (\p x y -> af_product_nan_all p x y d)
153158

159+
-- | Finds the minimum value of all elements in the Array
154160
minAll
155161
:: AFType a
156162
=> Array a

src/ArrayFire/BLAS.hs

+19-2
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,31 @@
33
-- |
44
-- Module : ArrayFire.BLAS
55
-- Copyright : David Johnson (c) 2019-2020
6-
-- License : BSD 3
6+
-- License : BSD3
77
-- Maintainer : David Johnson <[email protected]>
88
-- Stability : Experimental
99
-- Portability : GHC
1010
--
1111
-- Basic Linear Algebra Subprograms (BLAS) API
1212
--
13+
-- @
14+
-- main :: IO ()
15+
-- main = 'print' ('matmul' x y xProp yProp)
16+
-- where
17+
-- x,y :: 'Array' 'Double'
18+
-- x = 'matrix' (2,3) [1..]
19+
-- y = 'matrix' (3,2) [1..]
20+
--
21+
-- xProp, yProp :: 'MatProp'
22+
-- xProp = None
23+
-- yProp = None
24+
-- @
25+
-- @
26+
-- ArrayFire Array
27+
-- [2 2 1 1]
28+
-- 22.0000 28.0000
29+
-- 49.0000 64.0000
30+
-- @
1331
--------------------------------------------------------------------------------
1432
module ArrayFire.BLAS where
1533

@@ -30,7 +48,6 @@ import ArrayFire.Types
3048
-- optLhs an only be one of AF_MAT_NONE, AF_MAT_TRANS, AF_MAT_CTRANS.
3149
--
3250
-- optRhs can only be AF_MAT_NONE.
33-
--
3451
matmul
3552
:: Array a
3653
-- ^ 2D matrix of Array a, left-hand side

src/ArrayFire/Backend.hs

+15
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,21 @@
77
-- Stability : Experimental
88
-- Portability : GHC
99
--
10+
-- Set and get available ArrayFire 'Backend's.
11+
--
12+
-- @
13+
-- module Main where
14+
--
15+
-- import ArrayFire
16+
--
17+
-- main :: IO ()
18+
-- main = print =<< getAvailableBackends
19+
-- @
20+
--
21+
-- @
22+
-- [nix-shell:~\/arrayfire]$ .\/main
23+
-- [CPU,OpenCL]
24+
-- @
1025
--------------------------------------------------------------------------------
1126
module ArrayFire.Backend where
1227

src/ArrayFire/Device.hs

+6-10
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@
1010
--
1111
-- Information about ArrayFire API and devices
1212
--
13+
-- @
14+
-- ArrayFire v3.6.4 (OpenCL, 64-bit Mac OSX, build 1b8030c5)
15+
-- [0] APPLE: AMD Radeon Pro 555X Compute Engine, 4096 MB
16+
-- -1- APPLE: Intel(R) UHD Graphics 630, 1536 MB
17+
-- @
18+
--
1319
--------------------------------------------------------------------------------
1420
module ArrayFire.Device where
1521

@@ -18,15 +24,6 @@ import ArrayFire.Internal.Device
1824
import ArrayFire.FFI
1925

2026
-- | Retrieve info from ArrayFire API
21-
--
22-
-- Example below:
23-
--
24-
-- @
25-
--
26-
-- ArrayFire v3.6.4 (OpenCL, 64-bit Mac OSX, build 1b8030c5)
27-
-- [0] APPLE: AMD Radeon Pro 555X Compute Engine, 4096 MB
28-
-- -1- APPLE: Intel(R) UHD Graphics 630, 1536 MB
29-
-- @
3027
info :: IO ()
3128
info = afCall af_info
3229

@@ -45,7 +42,6 @@ getDeviceCount :: IO Int
4542
getDeviceCount = fromIntegral <$> afCall1 af_get_device_count
4643

4744
-- af_err af_get_dbl_support(bool* available, const int device);
48-
4945
-- | Sets a device by 'Int'
5046
setDevice :: Int -> IO ()
5147
setDevice (fromIntegral -> x) = afCall (af_set_device x)

src/ArrayFire/FFI.hs

+6-2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ import ArrayFire.Exception
1717
import ArrayFire.Types
1818
import ArrayFire.Internal.Defines
1919
import ArrayFire.Internal.Features
20+
import ArrayFire.Internal.Array
21+
2022
import Control.Exception
2123
import Control.Monad
2224
import Foreign.ForeignPtr
@@ -325,8 +327,10 @@ featuresToArray (Features fptr1) op =
325327
withForeignPtr fptr1 $ \ptr1 -> do
326328
alloca $ \ptrInput -> do
327329
throwAFError =<< op ptrInput ptr1
328-
Array <$> do
329-
newForeignPtr_ =<< peek ptrInput
330+
alloca $ \retainedArray -> do
331+
throwAFError =<< af_retain_array retainedArray =<< peek ptrInput
332+
fptr <- newForeignPtr af_release_array_finalizer =<< peek retainedArray
333+
pure (Array fptr)
330334

331335
infoFromFeatures
332336
:: Storable a

src/ArrayFire/LAPACK.hs

+3-1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ svd
3535
-- 'vt' is the output array containing V^H
3636
svd = (`op3p` af_svd)
3737

38+
3839
svdInPlace
3940
:: AFType a
4041
=> Array a
@@ -74,7 +75,7 @@ cholesky
7475
cholesky a (fromIntegral . fromEnum -> b) = do
7576
let (x',y') = op1b a (\x y z -> af_cholesky x y z b)
7677
(fromIntegral x', y')
77-
78+
7879
choleskyInplace
7980
:: AFType a
8081
=> Array a
@@ -110,6 +111,7 @@ inverse
110111
inverse a m =
111112
a `op1` (\x y -> af_inverse x y (toMatProp m))
112113

114+
-- | Not implemented in 3.6.4
113115
pinverse
114116
:: AFType a
115117
=> Array a

src/ArrayFire/Orphans.hs

+4-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import Prelude
2121

2222
import qualified ArrayFire.Arith as A
2323
import qualified ArrayFire.Array as A
24+
import qualified ArrayFire.Data as A
2425
import ArrayFire.Types
2526
import ArrayFire.Util
2627

@@ -33,7 +34,9 @@ instance (Num a, AFType a) => Num (Array a) where
3334
x * y = A.mul x y False
3435
abs = A.abs
3536
signum = A.sign
36-
negate = error "TODO: negate"
37+
negate arr = do
38+
let (w,x,y,z) = A.getDims arr
39+
A.cast (A.constant [w,x,y,z] 0.0) `A.sub` arr $ False
3740
x - y = A.sub x y False
3841
fromInteger = A.scalar . fromIntegral
3942

src/ArrayFire/Random.hs

+35-1
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,45 @@
1212
-- |
1313
-- Module : ArrayFire.Random
1414
-- Copyright : David Johnson (c) 2019-2020
15-
-- License : BSD 3
15+
-- License : BSD3
1616
-- Maintainer : David Johnson <[email protected]>
1717
-- Stability : Experimental
1818
-- Portability : GHC
1919
--
20+
-- 'RandomEngine' generation, Random 'Array' generation.
21+
--
22+
-- @
23+
-- {-\# LANGUAGE TypeApplications \#-}
24+
-- module Main where
25+
--
26+
-- import 'ArrayFire'
27+
--
28+
-- main :: IO ()
29+
-- main = do
30+
-- seed <- 'getSeed'
31+
-- -- ^ Retrieves seed
32+
-- engine <- 'createRandomEngine' 'Mersenne' seed
33+
-- -- ^ Creates engine
34+
-- array <- 'randomUniform' [3,3] engine
35+
-- -- ^ Creates random Array from engine with uniformly distributed data
36+
-- print array
37+
--
38+
-- print =<< 'randu' @'Double' [2,2]
39+
-- -- ^ Shorthand for creating random 'Array'
40+
-- @
41+
-- @
42+
-- ArrayFire 'Array'
43+
-- [3 3 1 1]
44+
-- 0.4446 0.1143 0.4283
45+
-- 0.5725 0.1456 0.9182
46+
-- 0.1915 0.1643 0.5997
47+
-- @
48+
-- @
49+
-- ArrayFire 'Array'
50+
-- [2 2 1 1]
51+
-- 0.6010 0.0278
52+
-- 0.9806 0.2126
53+
-- @
2054
--------------------------------------------------------------------------------
2155
module ArrayFire.Random
2256
( createRandomEngine

test/ArrayFire/AlgorithmSpec.hs

+15
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,18 @@ spec =
9797
A.count (A.vector @A.CBool 5 (repeat 1)) 0 `shouldBe` 5
9898
A.count (A.vector @Double 5 (repeat 1)) 0 `shouldBe` 5
9999
A.count (A.vector @Float 5 (repeat 1)) 0 `shouldBe` 5
100+
it "Should get sum all elements" $ do
101+
A.sumAll (A.vector @Int 5 (repeat 2)) `shouldBe` (10,0)
102+
A.sumAll (A.vector @Double 5 (repeat 2)) `shouldBe` (10.0,0)
103+
A.sumAll (A.vector @(A.Complex Double) 5 (repeat (2 A.:+ 0))) `shouldBe` (10.0,0)
104+
it "Should get sum all elements" $ do
105+
A.sumNaNAll (A.vector @Double 2 [10, acos 2]) 1 `shouldBe` (11.0,0)
106+
it "Should product all elements in an Array" $ do
107+
A.productAll (A.vector @Int 5 (repeat 2)) `shouldBe` (32,0)
108+
it "Should product all elements in an Array" $ do
109+
A.productNaNAll (A.vector @Double 2 [10,acos 2]) 10 `shouldBe` (100,0)
110+
it "Should find minimum value of an Array" $ do
111+
A.minAll (A.vector @Int 5 [0..]) `shouldBe` (0,0)
112+
it "Should find maximum value of an Array" $ do
113+
A.maxAll (A.vector @Int 5 [0..]) `shouldBe` (4,0)
114+

test/ArrayFire/ArithSpec.hs

+4
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@ import Foreign.C
99
spec :: Spec
1010
spec =
1111
describe "Arith tests" $ do
12+
it "Should negate scalar value" $ do
13+
negate (scalar @Int 1) `shouldBe` (-1)
14+
it "Should negate a vector" $ do
15+
negate (vector @Int 3 [2,2,2]) `shouldBe` vector @Int 3 [-2,-2,-2]
1216
it "Should add two scalar arrays" $ do
1317
scalar @Int 1 + 2 `shouldBe` 3
1418
it "Should add two scalar bool arrays" $ do

test/ArrayFire/FeaturesSpec.hs

+7-6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
{-# LANGUAGE TypeApplications #-}
22
module ArrayFire.FeaturesSpec where
33

4-
import ArrayFire
4+
import ArrayFire hiding (acos)
5+
import Prelude
56
import Test.Hspec
67

78
spec :: Spec
@@ -10,8 +11,8 @@ spec =
1011
it "Should get features number an array" $ do
1112
let feats = createFeatures 10
1213
getFeaturesNum feats `shouldBe` 10
13-
-- print (getFeaturesSize feats)
14-
-- print (getFeaturesOrientation feats)
15-
-- print (getFeaturesXPos feats)
16-
-- print (getFeaturesYPos feats)
17-
14+
-- let vec = vector @Double 10 $ repeat (acos 2)
15+
-- getFeaturesSize feats `shouldBe` vec
16+
-- getFeaturesOrientation feats `shouldBe` vec
17+
-- getFeaturesXPos feats `shouldBe` vec
18+
-- getFeaturesYPos feats `shouldBe` vec

test/ArrayFire/LAPACKSpec.hs

+3-3
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,6 @@ spec =
3939
it "Should calculate inverse" $ do
4040
let x = flip A.inverse A.None $ A.matrix @Double (2,2) [4,7,2,6]
4141
x `shouldBe` A.matrix @Double (2,2) [0.6,-0.2,-0.7,0.4]
42-
-- it "Should calculate psuedo inverse" $ do
43-
-- let x = A.pinverse (A.matrix @Double (2,2) [4,7,2,6]) 1.0 A.None
44-
-- x `shouldBe` A.matrix @Double (2,2) [0.6,-0.2,-0.7,0.4]
42+
-- it "Should calculate psuedo inverse" $ do
43+
-- let x = A.pinverse (A.matrix @Double (2,2) [4,7,2,6]) 1.0 A.None
44+
-- x `shouldBe` A.matrix @Double (2,2) [0.6,-0.2,-0.7,0.4]

0 commit comments

Comments
 (0)