@@ -161,7 +161,7 @@ def tril(tensor: "tripy.Tensor", diagonal: int = 0) -> "tripy.Tensor":
161
161
:linenos:
162
162
:caption: Main Diagonal
163
163
164
- input = tp.iota((5, 5) ) + 1.
164
+ input = tp.iota((2, 1, 3, 3), dim=2 ) + 1.
165
165
output = tp.tril(input)
166
166
167
167
assert np.array_equal(cp.from_dlpack(output).get(), np.tril(cp.from_dlpack(input).get()))
@@ -184,8 +184,8 @@ def tril(tensor: "tripy.Tensor", diagonal: int = 0) -> "tripy.Tensor":
184
184
185
185
assert np.array_equal(cp.from_dlpack(output).get(), np.tril(cp.from_dlpack(input).get(), -1))
186
186
"""
187
- tri_mask = (iota_like (tensor , 0 , datatype .int32 ) + full_like (tensor , diagonal , datatype .int32 )) >= iota_like (
188
- tensor , 1 , datatype .int32
187
+ tri_mask = (iota_like (tensor , - 2 , datatype .int32 ) + full_like (tensor , diagonal , datatype .int32 )) >= iota_like (
188
+ tensor , - 1 , datatype .int32
189
189
)
190
190
zeros_tensor = zeros_like (tensor )
191
191
return where (tri_mask , tensor , zeros_tensor )
@@ -213,7 +213,7 @@ def triu(tensor: "tripy.Tensor", diagonal: int = 0) -> "tripy.Tensor":
213
213
:linenos:
214
214
:caption: Main Diagonal
215
215
216
- input = tp.iota((5, 5) ) + 1.
216
+ input = tp.iota((2, 1, 3, 3), dim=2 ) + 1.
217
217
output = tp.triu(input)
218
218
219
219
assert np.array_equal(cp.from_dlpack(output).get(), np.triu(cp.from_dlpack(input).get()))
@@ -236,8 +236,8 @@ def triu(tensor: "tripy.Tensor", diagonal: int = 0) -> "tripy.Tensor":
236
236
237
237
assert np.array_equal(cp.from_dlpack(output).get(), np.triu(cp.from_dlpack(input).get(), -1))
238
238
"""
239
- tri_mask = (iota_like (tensor , 0 , datatype .int32 ) + full_like (tensor , diagonal , datatype .int32 )) <= iota_like (
240
- tensor , 1 , datatype .int32
239
+ tri_mask = (iota_like (tensor , - 2 , datatype .int32 ) + full_like (tensor , diagonal , datatype .int32 )) <= iota_like (
240
+ tensor , - 1 , datatype .int32
241
241
)
242
242
zeros_tensor = zeros_like (tensor )
243
243
return where (tri_mask , tensor , zeros_tensor )
0 commit comments