Skip to content

Commit 9964483

Browse files
committed
Added dpctl.tensor.stack feature and tests
stack() function joins a sequence of arrays along a new axis and follows array API spec. https://data-apis.org/array-api/latest/API_specification/generated/signatures.manipulation_functions.stack.html#signatures.manipulation_functions.stack
1 parent c09ac88 commit 9964483

File tree

3 files changed

+175
-15
lines changed

3 files changed

+175
-15
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
permute_dims,
4646
roll,
4747
squeeze,
48+
stack,
4849
)
4950
from dpctl.tensor._reshape import reshape
5051
from dpctl.tensor._usmarray import usm_ndarray
@@ -68,6 +69,7 @@
6869
"reshape",
6970
"roll",
7071
"concat",
72+
"stack",
7173
"broadcast_arrays",
7274
"broadcast_to",
7375
"expand_dims",

dpctl/tensor/_manipulation_functions.py

Lines changed: 63 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -288,12 +288,7 @@ def roll(X, shift, axes=None):
288288
return res
289289

290290

291-
def concat(arrays, axis=0):
292-
"""
293-
concat(arrays: tuple or list of usm_ndarrays, axis: int) -> usm_ndarray
294-
295-
Joins a sequence of arrays along an existing axis.
296-
"""
291+
def arrays_validation(arrays):
297292
n = len(arrays)
298293
if n == 0:
299294
raise TypeError("Missing 1 required positional argument: 'arrays'")
@@ -324,11 +319,23 @@ def concat(arrays, axis=0):
324319
for i in range(1, n):
325320
if X0.ndim != arrays[i].ndim:
326321
raise ValueError(
327-
"All the input arrays must have same number of "
328-
"dimensions, but the array at index 0 has "
329-
f"{X0.ndim} dimension(s) and the array at index "
330-
f"{i} has {arrays[i].ndim} dimension(s)"
322+
"All the input arrays must have same number of dimensions, "
323+
f"but the array at index 0 has {X0.ndim} dimension(s) and the "
324+
f"array at index {i} has {arrays[i].ndim} dimension(s)"
331325
)
326+
return res_dtype, res_usm_type, exec_q
327+
328+
329+
def concat(arrays, axis=0):
330+
"""
331+
concat(arrays: tuple or list of usm_ndarrays, axis: int) -> usm_ndarray
332+
333+
Joins a sequence of arrays along an existing axis.
334+
"""
335+
res_dtype, res_usm_type, exec_q = arrays_validation(arrays)
336+
337+
n = len(arrays)
338+
X0 = arrays[0]
332339

333340
axis = normalize_axis_index(axis, X0.ndim)
334341
X0_shape = X0.shape
@@ -337,11 +344,10 @@ def concat(arrays, axis=0):
337344
for j in range(X0.ndim):
338345
if X0_shape[j] != Xi_shape[j] and j != axis:
339346
raise ValueError(
340-
"All the input array dimensions for the "
341-
"concatenation axis must match exactly, but "
342-
f"along dimension {j}, the array at index 0 "
343-
f"has size {X0_shape[j]} and the array at "
344-
f"index {i} has size {Xi_shape[j]}"
347+
"All the input array dimensions for the concatenation "
348+
f"axis must match exactly, but along dimension {j}, the "
349+
f"array at index 0 has size {X0_shape[j]} and the array "
350+
f"at index {i} has size {Xi_shape[j]}"
345351
)
346352

347353
res_shape_axis = 0
@@ -373,3 +379,45 @@ def concat(arrays, axis=0):
373379
dpctl.SyclEvent.wait_for(hev_list)
374380

375381
return res
382+
383+
384+
def stack(arrays, axis=0):
385+
"""
386+
stack(arrays: tuple or list of usm_ndarrays, axis: int) -> usm_ndarray
387+
388+
Joins a sequence of arrays along a new axis.
389+
"""
390+
res_dtype, res_usm_type, exec_q = arrays_validation(arrays)
391+
392+
n = len(arrays)
393+
X0 = arrays[0]
394+
res_ndim = X0.ndim + 1
395+
axis = normalize_axis_index(axis, res_ndim)
396+
X0_shape = X0.shape
397+
398+
for i in range(1, n):
399+
if X0_shape != arrays[i].shape:
400+
raise ValueError("All input arrays must have the same shape")
401+
402+
res_shape = tuple(
403+
X0_shape[i - 1 * (i >= axis)] if i != axis else n
404+
for i in range(res_ndim)
405+
)
406+
407+
res = dpt.empty(
408+
res_shape, dtype=res_dtype, usm_type=res_usm_type, sycl_queue=exec_q
409+
)
410+
411+
hev_list = []
412+
for i in range(n):
413+
c_shapes_copy = tuple(
414+
i if j == axis else np.s_[:] for j in range(res_ndim)
415+
)
416+
hev, _ = ti._copy_usm_ndarray_into_usm_ndarray(
417+
src=arrays[i], dst=res[c_shapes_copy], sycl_queue=exec_q
418+
)
419+
hev_list.append(hev)
420+
421+
dpctl.SyclEvent.wait_for(hev_list)
422+
423+
return res

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -889,3 +889,113 @@ def test_concat_3arrays(data):
889889
R = dpt.concat([X, Y, Z], axis=axis)
890890

891891
assert_array_equal(Rnp, dpt.asnumpy(R))
892+
893+
894+
def test_stack_incorrect_shape():
895+
try:
896+
q = dpctl.SyclQueue()
897+
except dpctl.SyclQueueCreationError:
898+
pytest.skip("Queue could not be created")
899+
900+
X = dpt.ones((1,), sycl_queue=q)
901+
Y = dpt.ones((2,), sycl_queue=q)
902+
903+
pytest.raises(ValueError, dpt.stack, [X, Y], 0)
904+
905+
906+
@pytest.mark.parametrize(
907+
"data",
908+
[
909+
[(6,), 0],
910+
[(2, 3), 1],
911+
[(3, 2), -1],
912+
[(1, 6), 2],
913+
[(2, 1, 3), 2],
914+
],
915+
)
916+
def test_stack_1array(data):
917+
try:
918+
q = dpctl.SyclQueue()
919+
except dpctl.SyclQueueCreationError:
920+
pytest.skip("Queue could not be created")
921+
922+
shape, axis = data
923+
924+
Xnp = np.arange(6).reshape(shape)
925+
X = dpt.asarray(Xnp, sycl_queue=q)
926+
927+
Ynp = np.stack([Xnp], axis=axis)
928+
Y = dpt.stack([X], axis=axis)
929+
930+
assert_array_equal(Ynp, dpt.asnumpy(Y))
931+
932+
Ynp = np.stack((Xnp,), axis=axis)
933+
Y = dpt.stack((X,), axis=axis)
934+
935+
assert_array_equal(Ynp, dpt.asnumpy(Y))
936+
937+
938+
@pytest.mark.parametrize(
939+
"data",
940+
[
941+
[(1,), 0],
942+
[(0, 2), 0],
943+
[(2, 0), 0],
944+
[(2, 3), 0],
945+
[(2, 3), 1],
946+
[(2, 3), 2],
947+
[(2, 3), -1],
948+
[(2, 3), -2],
949+
[(2, 2, 2), 1],
950+
],
951+
)
952+
def test_stack_2arrays(data):
953+
try:
954+
q = dpctl.SyclQueue()
955+
except dpctl.SyclQueueCreationError:
956+
pytest.skip("Queue could not be created")
957+
958+
shape, axis = data
959+
960+
Xnp = np.ones(shape)
961+
X = dpt.asarray(Xnp, sycl_queue=q)
962+
963+
Ynp = np.zeros(shape)
964+
Y = dpt.asarray(Ynp, sycl_queue=q)
965+
966+
Znp = np.stack([Xnp, Ynp], axis=axis)
967+
print(Znp.shape)
968+
Z = dpt.stack([X, Y], axis=axis)
969+
970+
assert_array_equal(Znp, dpt.asnumpy(Z))
971+
972+
973+
@pytest.mark.parametrize(
974+
"data",
975+
[
976+
[(1,), 0],
977+
[(0, 2), 0],
978+
[(2, 1, 2), 1],
979+
],
980+
)
981+
def test_stack_3arrays(data):
982+
try:
983+
q = dpctl.SyclQueue()
984+
except dpctl.SyclQueueCreationError:
985+
pytest.skip("Queue could not be created")
986+
987+
shape, axis = data
988+
989+
Xnp = np.ones(shape)
990+
X = dpt.asarray(Xnp, sycl_queue=q)
991+
992+
Ynp = np.zeros(shape)
993+
Y = dpt.asarray(Ynp, sycl_queue=q)
994+
995+
Znp = np.full(shape, 2.0)
996+
Z = dpt.asarray(Znp, sycl_queue=q)
997+
998+
Rnp = np.stack([Xnp, Ynp, Znp], axis=axis)
999+
R = dpt.stack([X, Y, Z], axis=axis)
1000+
1001+
assert_array_equal(Rnp, dpt.asnumpy(R))

0 commit comments

Comments
 (0)