@@ -744,6 +744,11 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -
744
744
axis = 0
745
745
return torch .index_select (x , axis , indices , ** kwargs )
746
746
747
+
748
+ def take_along_axis (x : array , indices : array , / , * , axis : int = - 1 ) -> array :
749
+ return torch .take_along_dim (x , indices , dim = axis )
750
+
751
+
747
752
def sign (x : array , / ) -> array :
748
753
# torch sign() does not support complex numbers and does not propagate
749
754
# nans. See https://github.com/data-apis/array-api-compat/issues/136
@@ -767,14 +772,14 @@ def sign(x: array, /) -> array:
767
772
'equal' , 'floor_divide' , 'greater' , 'greater_equal' , 'hypot' ,
768
773
'less' , 'less_equal' , 'logaddexp' , 'maximum' , 'minimum' ,
769
774
'multiply' , 'not_equal' , 'pow' , 'remainder' , 'subtract' , 'max' ,
770
- 'min' , 'clip' , 'unstack' , 'cumulative_sum' , 'sort' , 'prod' , 'sum' ,
775
+ 'min' , 'clip' , 'unstack' , 'cumulative_sum' , 'cumulative_prod' , ' sort' , 'prod' , 'sum' ,
771
776
'any' , 'all' , 'mean' , 'std' , 'var' , 'concat' , 'squeeze' ,
772
777
'broadcast_to' , 'flip' , 'roll' , 'nonzero' , 'where' , 'reshape' ,
773
778
'arange' , 'eye' , 'linspace' , 'full' , 'ones' , 'zeros' , 'empty' ,
774
779
'tril' , 'triu' , 'expand_dims' , 'astype' , 'broadcast_arrays' ,
775
780
'UniqueAllResult' , 'UniqueCountsResult' , 'UniqueInverseResult' ,
776
781
'unique_all' , 'unique_counts' , 'unique_inverse' , 'unique_values' ,
777
782
'matmul' , 'matrix_transpose' , 'vecdot' , 'tensordot' , 'isdtype' ,
778
- 'take' , 'sign' ]
783
+ 'take' , 'take_along_axis' , ' sign' ]
779
784
780
785
_all_ignore = ['torch' , 'get_xp' ]
0 commit comments