@@ -24,26 +24,6 @@ def get_namespace(*xs, default=cu):
2424 return default # backwards compatibility
2525
2626
27- def check_cuvec (a , shape , dtype , xp = cu ):
28- """Asserts that CuVec `a` is of `shape` & `dtype`"""
29- if not isinstance (a , xp .CuVec ):
30- raise TypeError (f"must be a { xp .CuVec } " )
31- elif np .dtype (a .dtype ) != np .dtype (dtype ):
32- raise TypeError (f"dtype must be { dtype } : got { a .dtype } " )
33- elif a .shape != shape :
34- raise IndexError (f"shape must be { shape } : got { a .shape } " )
35-
36-
37- def check_similar (* arrays , allow_none = True ):
38- """Asserts that all arrays are `CuVec`s of the same `shape` & `dtype`"""
39- arrs = tuple (filter (lambda x : x is not None , arrays ))
40- if not allow_none and len (arrays ) != len (arrs ):
41- raise TypeError ("must not be None" )
42- shape , dtype , xp = arrs [0 ].shape , arrs [0 ].dtype , get_namespace (* arrs )
43- for a in arrs :
44- check_cuvec (a , shape , dtype , xp )
45-
46-
4727def div (numerator , divisor , default = FLOAT_MAX , output = None , dev_id = 0 , sync = True ):
4828 """
4929 Elementwise `output = numerator / divisor if divisor else default`
@@ -60,11 +40,9 @@ def div(numerator, divisor, default=FLOAT_MAX, output=None, dev_id=0, sync=True)
6040 res [np .isnan (res )] = default
6141 return res
6242 cu .dev_set (dev_id )
63- xp = get_namespace (numerator , divisor , output )
64- numerator = xp .asarray (numerator , 'float32' )
65- divisor = xp .asarray (divisor , 'float32' )
66- output = xp .zeros_like (numerator ) if output is None else xp .asarray (output , 'float32' )
67- check_similar (numerator , divisor , output )
43+ if output is None :
44+ output = get_namespace (numerator , divisor , output ).zeros_like (numerator )
45+ assert numerator .size == divisor .size == output .size
6846 ext .div (numerator , divisor , output , default = default )
6947 if sync : cu .dev_sync ()
7048 return output
@@ -82,11 +60,9 @@ def mul(a, b, output=None, dev_id=0, sync=True):
8260 """
8361 if dev_id is False : return np .multiply (a , b , out = output )
8462 cu .dev_set (dev_id )
85- xp = get_namespace (a , b , output )
86- a = xp .asarray (a , 'float32' )
87- b = xp .asarray (b , 'float32' )
88- output = xp .zeros_like (a ) if output is None else xp .asarray (output , 'float32' )
89- check_similar (a , b , output )
63+ if output is None :
64+ output = get_namespace (a , b , output ).zeros_like (a )
65+ assert a .size == b .size == output .size
9066 ext .mul (a , b , output )
9167 if sync : cu .dev_sync ()
9268 return output
@@ -104,11 +80,9 @@ def add(a, b, output=None, dev_id=0, sync=True):
10480 """
10581 if dev_id is False : return np .add (a , b , out = output )
10682 cu .dev_set (dev_id )
107- xp = get_namespace (a , b , output )
108- a = xp .asarray (a , 'float32' )
109- b = xp .asarray (b , 'float32' )
110- output = xp .zeros_like (a ) if output is None else xp .asarray (output , 'float32' )
111- check_similar (a , b , output )
83+ if output is None :
84+ output = get_namespace (a , b , output ).zeros_like (a )
85+ assert a .size == b .size == output .size
11286 ext .add (a , b , output )
11387 if sync : cu .dev_sync ()
11488 return output
0 commit comments