Skip to content

Commit af20f47

Browse files
committed
drop unneeded type checking
1 parent 6201ec6 commit af20f47

File tree

1 file changed

+9
-35
lines changed

1 file changed

+9
-35
lines changed

numcu/lib.py

Lines changed: 9 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -24,26 +24,6 @@ def get_namespace(*xs, default=cu):
2424
return default # backwards compatibility
2525

2626

27-
def check_cuvec(a, shape, dtype, xp=cu):
28-
"""Asserts that CuVec `a` is of `shape` & `dtype`"""
29-
if not isinstance(a, xp.CuVec):
30-
raise TypeError(f"must be a {xp.CuVec}")
31-
elif np.dtype(a.dtype) != np.dtype(dtype):
32-
raise TypeError(f"dtype must be {dtype}: got {a.dtype}")
33-
elif a.shape != shape:
34-
raise IndexError(f"shape must be {shape}: got {a.shape}")
35-
36-
37-
def check_similar(*arrays, allow_none=True):
38-
"""Asserts that all arrays are `CuVec`s of the same `shape` & `dtype`"""
39-
arrs = tuple(filter(lambda x: x is not None, arrays))
40-
if not allow_none and len(arrays) != len(arrs):
41-
raise TypeError("must not be None")
42-
shape, dtype, xp = arrs[0].shape, arrs[0].dtype, get_namespace(*arrs)
43-
for a in arrs:
44-
check_cuvec(a, shape, dtype, xp)
45-
46-
4727
def div(numerator, divisor, default=FLOAT_MAX, output=None, dev_id=0, sync=True):
4828
"""
4929
Elementwise `output = numerator / divisor if divisor else default`
@@ -60,11 +40,9 @@ def div(numerator, divisor, default=FLOAT_MAX, output=None, dev_id=0, sync=True)
6040
res[np.isnan(res)] = default
6141
return res
6242
cu.dev_set(dev_id)
63-
xp = get_namespace(numerator, divisor, output)
64-
numerator = xp.asarray(numerator, 'float32')
65-
divisor = xp.asarray(divisor, 'float32')
66-
output = xp.zeros_like(numerator) if output is None else xp.asarray(output, 'float32')
67-
check_similar(numerator, divisor, output)
43+
if output is None:
44+
output = get_namespace(numerator, divisor, output).zeros_like(numerator)
45+
assert numerator.size == divisor.size == output.size
6846
ext.div(numerator, divisor, output, default=default)
6947
if sync: cu.dev_sync()
7048
return output
@@ -82,11 +60,9 @@ def mul(a, b, output=None, dev_id=0, sync=True):
8260
"""
8361
if dev_id is False: return np.multiply(a, b, out=output)
8462
cu.dev_set(dev_id)
85-
xp = get_namespace(a, b, output)
86-
a = xp.asarray(a, 'float32')
87-
b = xp.asarray(b, 'float32')
88-
output = xp.zeros_like(a) if output is None else xp.asarray(output, 'float32')
89-
check_similar(a, b, output)
63+
if output is None:
64+
output = get_namespace(a, b, output).zeros_like(a)
65+
assert a.size == b.size == output.size
9066
ext.mul(a, b, output)
9167
if sync: cu.dev_sync()
9268
return output
@@ -104,11 +80,9 @@ def add(a, b, output=None, dev_id=0, sync=True):
10480
"""
10581
if dev_id is False: return np.add(a, b, out=output)
10682
cu.dev_set(dev_id)
107-
xp = get_namespace(a, b, output)
108-
a = xp.asarray(a, 'float32')
109-
b = xp.asarray(b, 'float32')
110-
output = xp.zeros_like(a) if output is None else xp.asarray(output, 'float32')
111-
check_similar(a, b, output)
83+
if output is None:
84+
output = get_namespace(a, b, output).zeros_like(a)
85+
assert a.size == b.size == output.size
11286
ext.add(a, b, output)
11387
if sync: cu.dev_sync()
11488
return output

0 commit comments

Comments
 (0)