Skip to content

Commit 1b64ace

Browse files
committed
Should zero out af_array before using.
Adds zeroing out function to cbits/wrapper.c for zeroing out bytes before use.
1 parent 643008f commit 1b64ace

File tree

5 files changed

+25
-5
lines changed

5 files changed

+25
-5
lines changed

cbits/wrapper.c

+4
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,7 @@ void test_window () {
3535
af_create_window(&window, 100, 100, "foo");
3636
af_show(window);
3737
}
38+
39+
void zeroOutArray (af_array * arr) {
40+
(*arr) = 0;
41+
}

src/ArrayFire/Array.hs

+1
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ mkArray dims xs =
105105
dataPtr <- castPtr <$> newArray (Prelude.take size xs)
106106
let ndims = fromIntegral (Prelude.length dims)
107107
alloca $ \arrayPtr -> do
108+
zeroOutArray arrayPtr
108109
dimsPtr <- newArray (DimT . fromIntegral <$> dims)
109110
throwAFError =<< af_create_array arrayPtr dataPtr ndims dimsPtr dType
110111
free dataPtr >> free dimsPtr

src/ArrayFire/Data.hs

+14-5
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ constant
5151
-> Array Double
5252
constant dims val =
5353
unsafePerformIO . mask_ $ do
54-
ptr <- alloca $ \ptrPtr ->
54+
ptr <- alloca $ \ptrPtr -> do
55+
zeroOutArray ptrPtr
5556
withArray (fromIntegral <$> dims) $ \dimArray -> do
5657
throwAFError =<< af_constant ptrPtr val n dimArray typ
5758
peek ptrPtr
@@ -76,7 +77,8 @@ constantComplex
7677
-- ^ Scalar value
7778
-> Array (Complex Double)
7879
constantComplex dims val = unsafePerformIO . mask_ $ do
79-
ptr <- alloca $ \ptrPtr -> mask_ $ do
80+
ptr <- alloca $ \ptrPtr -> do
81+
zeroOutArray ptrPtr
8082
withArray (fromIntegral <$> dims) $ \dimArray -> do
8183
throwAFError =<< af_constant_complex ptrPtr (realPart val) (imagPart val) n dimArray typ
8284
peek ptrPtr
@@ -101,7 +103,8 @@ constantLong
101103
-- ^ Scalar value
102104
-> Array Int
103105
constantLong dims val = unsafePerformIO . mask_ $ do
104-
ptr <- alloca $ \ptrPtr ->
106+
ptr <- alloca $ \ptrPtr -> do
107+
zeroOutArray ptrPtr
105108
withArray (fromIntegral <$> dims) $ \dimArray -> do
106109
throwAFError =<< af_constant_long ptrPtr (fromIntegral val) n dimArray
107110
peek ptrPtr
@@ -123,7 +126,8 @@ constantULong
123126
-> Word64
124127
-> Array Word64
125128
constantULong dims val = unsafePerformIO . mask_ $ do
126-
ptr <- alloca $ \ptrPtr -> mask_ $ do
129+
ptr <- alloca $ \ptrPtr -> do
130+
zeroOutArray ptrPtr
127131
withArray (fromIntegral <$> dims) $ \dimArray -> do
128132
throwAFError =<< af_constant_ulong ptrPtr (fromIntegral val) n dimArray
129133
peek ptrPtr
@@ -142,6 +146,7 @@ range
142146
-> IO (Array a)
143147
range dims (fromIntegral -> k) = do
144148
ptr <- alloca $ \ptrPtr -> mask_ $ do
149+
zeroOutArray ptrPtr
145150
withArray (fromIntegral <$> dims) $ \dimArray -> do
146151
throwAFError =<< af_range ptrPtr n dimArray k typ
147152
peek ptrPtr
@@ -157,7 +162,8 @@ iota
157162
:: forall a . AFType a
158163
=> [Int] -> [Int] -> IO (Array a)
159164
iota dims tdims = do
160-
ptr <- alloca $ \ptrPtr -> mask_ $
165+
ptr <- alloca $ \ptrPtr -> mask_ $ do
166+
zeroOutArray ptrPtr
161167
withArray (fromIntegral <$> dims) $ \dimArray ->
162168
withArray (fromIntegral <$> tdims) $ \tdimArray -> do
163169
throwAFError =<< af_iota ptrPtr n dimArray tn tdimArray typ
@@ -184,6 +190,7 @@ identity
184190
-> Array a
185191
identity dims = unsafePerformIO . mask_ $ do
186192
ptr <- alloca $ \ptrPtr -> mask_ $ do
193+
zeroOutArray ptrPtr
187194
withArray (fromIntegral <$> dims) $ \dimArray -> do
188195
throwAFError =<< af_identity ptrPtr n dimArray typ
189196
peek ptrPtr
@@ -229,6 +236,7 @@ joinMany (fromIntegral -> n) arrays = unsafePerformIO . mask_ $ do
229236
forM_ fptrs $ \fptr ->
230237
withForeignPtr fptr (poke fPtrsPtr)
231238
alloca $ \aPtr -> do
239+
zeroOutArray aPtr
232240
throwAFError =<< af_join_many aPtr n nArrays fPtrsPtr
233241
peek aPtr
234242
Array <$>
@@ -274,6 +282,7 @@ moddims
274282
moddims dims (Array fptr) =
275283
unsafePerformIO . mask_ . withForeignPtr fptr $ \ptr -> do
276284
newPtr <- alloca $ \aPtr -> do
285+
zeroOutArray aPtr
277286
withArray (fromIntegral <$> dims) $ \dimsPtr -> do
278287
throwAFError =<< af_moddims aPtr ptr n dimsPtr
279288
peek aPtr

src/ArrayFire/FFI.hs

+5
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ createArray' op =
148148
mask_ $ do
149149
ptr <-
150150
alloca $ \ptrInput -> do
151+
zeroOutArray ptrInput
151152
throwAFError =<< op ptrInput
152153
peek ptrInput
153154
fptr <- newForeignPtr af_release_array_finalizer ptr
@@ -161,6 +162,7 @@ createArray op =
161162
unsafePerformIO . mask_ $ do
162163
ptr <-
163164
alloca $ \ptrInput -> do
165+
zeroOutArray ptrInput
164166
throwAFError =<< op ptrInput
165167
peek ptrInput
166168
fptr <- newForeignPtr af_release_array_finalizer ptr
@@ -445,3 +447,6 @@ infoFromArray4 (Array fptr1) op =
445447
<*> peek ptrInput2
446448
<*> peek ptrInput3
447449
<*> peek ptrInput4
450+
451+
foreign import ccall unsafe "zeroOutArray"
452+
zeroOutArray :: Ptr AFArray -> IO ()

src/ArrayFire/Random.hs

+1
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,7 @@ rand
211211
-> IO (Array a)
212212
rand dims f = mask_ $ do
213213
ptr <- alloca $ \ptrPtr -> do
214+
zeroOutArray ptrPtr
214215
withArray (fromIntegral <$> dims) $ \dimArray -> do
215216
throwAFError =<< f ptrPtr n dimArray typ
216217
peek ptrPtr

0 commit comments

Comments
 (0)