Skip to content

Commit e3fdad9

Browse files
authored
Fix triu and tril for batched input (#14)
Signed-off-by: Akhil Goel <[email protected]>
1 parent 993a889 commit e3fdad9

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

Diff for: tripy/tripy/frontend/ops/tensor_initializers.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def tril(tensor: "tripy.Tensor", diagonal: int = 0) -> "tripy.Tensor":
161161
:linenos:
162162
:caption: Main Diagonal
163163
164-
input = tp.iota((5, 5)) + 1.
164+
input = tp.iota((2, 1, 3, 3), dim=2) + 1.
165165
output = tp.tril(input)
166166
167167
assert np.array_equal(cp.from_dlpack(output).get(), np.tril(cp.from_dlpack(input).get()))
@@ -184,8 +184,8 @@ def tril(tensor: "tripy.Tensor", diagonal: int = 0) -> "tripy.Tensor":
184184
185185
assert np.array_equal(cp.from_dlpack(output).get(), np.tril(cp.from_dlpack(input).get(), -1))
186186
"""
187-
tri_mask = (iota_like(tensor, 0, datatype.int32) + full_like(tensor, diagonal, datatype.int32)) >= iota_like(
188-
tensor, 1, datatype.int32
187+
tri_mask = (iota_like(tensor, -2, datatype.int32) + full_like(tensor, diagonal, datatype.int32)) >= iota_like(
188+
tensor, -1, datatype.int32
189189
)
190190
zeros_tensor = zeros_like(tensor)
191191
return where(tri_mask, tensor, zeros_tensor)
@@ -213,7 +213,7 @@ def triu(tensor: "tripy.Tensor", diagonal: int = 0) -> "tripy.Tensor":
213213
:linenos:
214214
:caption: Main Diagonal
215215
216-
input = tp.iota((5, 5)) + 1.
216+
input = tp.iota((2, 1, 3, 3), dim=2) + 1.
217217
output = tp.triu(input)
218218
219219
assert np.array_equal(cp.from_dlpack(output).get(), np.triu(cp.from_dlpack(input).get()))
@@ -236,8 +236,8 @@ def triu(tensor: "tripy.Tensor", diagonal: int = 0) -> "tripy.Tensor":
236236
237237
assert np.array_equal(cp.from_dlpack(output).get(), np.triu(cp.from_dlpack(input).get(), -1))
238238
"""
239-
tri_mask = (iota_like(tensor, 0, datatype.int32) + full_like(tensor, diagonal, datatype.int32)) <= iota_like(
240-
tensor, 1, datatype.int32
239+
tri_mask = (iota_like(tensor, -2, datatype.int32) + full_like(tensor, diagonal, datatype.int32)) <= iota_like(
240+
tensor, -1, datatype.int32
241241
)
242242
zeros_tensor = zeros_like(tensor)
243243
return where(tri_mask, tensor, zeros_tensor)

0 commit comments

Comments
 (0)