Skip to content

Commit 0bff2f1

Browse files
committed
BUG: fix dtype of include_initial in cumulative_sum
In `concat([zeros(...), x])` zeros must have the same dtype as `x`.
1 parent 52468fb commit 0bff2f1

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

Diff for: array_api_strict/_statistical_functions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def cumulative_sum(
4444
if include_initial:
4545
if axis < 0:
4646
axis += x.ndim
47-
x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=dt), x], axis=axis)
47+
x = concat([zeros(x.shape[:axis] + (1,) + x.shape[axis + 1:], dtype=x.dtype), x], axis=axis)
4848
return Array._new(np.cumsum(x._array, axis=axis, dtype=dtype), device=x.device)
4949

5050

0 commit comments

Comments
 (0)