@@ -462,6 +462,20 @@ def _reduce_multiple_axes(f, x, axis, keepdims=False, **kwargs):
462462 out = paddle .unsqueeze (out , a )
463463 return out
464464
465+ _NP_2_PADDLE_DTYPE = {
466+ "BOOL" : 'bool' ,
467+ "UINT8" : 'uint8' ,
468+ "INT8" : 'int8' ,
469+ "INT16" : 'int16' ,
470+ "INT32" : 'int32' ,
471+ "INT64" : 'int64' ,
472+ "FLOAT16" : 'float16' ,
473+ "BFLOAT16" : 'bfloat16' ,
474+ "FLOAT32" : 'float32' ,
475+ "FLOAT64" : 'float64' ,
476+ "COMPLEX128" : 'complex128' ,
477+ "COMPLEX64" : 'complex64' ,
478+ }
465479
466480def prod (
467481 x : array ,
@@ -476,7 +490,36 @@ def prod(
476490 x = paddle .to_tensor (x )
477491 ndim = x .ndim
478492
479- # below because it still needs to upcast.
493+ # fix reducing on the zero dimension
494+ if x .numel () == 0 :
495+ if dtype is not None :
496+ output_dtype = _NP_2_PADDLE_DTYPE [dtype .name ]
497+ else :
498+ if x .dtype == paddle .bool :
499+ output_dtype = paddle .int64
500+ else :
501+ output_dtype = x .dtype
502+
503+ if axis is None :
504+ return paddle .to_tensor (1 , dtype = output_dtype )
505+
506+ if keepdims :
507+ output_shape = list (x .shape )
508+ if isinstance (axis , int ):
509+ axis = (axis ,)
510+ for ax in axis :
511+ output_shape [ax ] = 1
512+ else :
513+ output_shape = [dim for i , dim in enumerate (x .shape ) if i not in (axis if isinstance (axis , tuple ) else [axis ])]
514+ if not output_shape :
515+ return paddle .to_tensor (1 , dtype = output_dtype )
516+
517+ return paddle .ones (output_shape , dtype = output_dtype )
518+
519+
520+ if dtype is not None :
521+ dtype = _NP_2_PADDLE_DTYPE [dtype .name ]
522+
480523 if axis == ():
481524 if dtype is None :
482525 # We can't upcast uint8 according to the spec because there is no
@@ -492,13 +535,17 @@ def prod(
492535 return _reduce_multiple_axes (
493536 paddle .prod , x , axis , keepdim = keepdims , dtype = dtype , ** kwargs
494537 )
538+
539+
495540 if axis is None :
496541 # paddle doesn't support keepdims with axis=None
542+ if dtype is None and x .dtype == paddle .int32 :
543+ dtype = 'int64'
497544 res = paddle .prod (x , dtype = dtype , ** kwargs )
498545 res = _axis_none_keepdims (res , ndim , keepdims )
499546 return res
500-
501- return paddle .prod (x , axis , dtype = dtype , keepdim = keepdims , ** kwargs )
547+
548+ return paddle .prod (x , axis = axis , keepdims = keepdims , dtype = dtype , ** kwargs )
502549
503550
504551def sum (
@@ -747,7 +794,17 @@ def roll(
def nonzero(x: array, /, **kwargs) -> Tuple[array, ...]:
    """Return the indices of the nonzero elements of *x*, one array per axis.

    Raises:
        ValueError: if *x* is zero-dimensional.
    """
    if x.ndim == 0:
        raise ValueError("nonzero() does not support zero-dimensional arrays")

    # Integer/bool inputs: exact nonzero test.
    if not (paddle.is_floating_point(x) or paddle.is_complex(x)):
        return paddle.nonzero(x, as_tuple=True, **kwargs)

    # Float/complex inputs: elements merely "close enough" to zero are
    # treated as zero via paddle.isclose().
    # NOTE(review): isclose uses its default tolerances, so tiny-but-nonzero
    # values (magnitude <= ~atol) are dropped, which differs from an exact
    # nonzero test — confirm this loosening is intentional.
    reference_zeros = paddle.zeros(shape=x.shape, dtype=x.dtype)
    keep_mask = paddle.logical_not(paddle.isclose(x, reference_zeros))
    return paddle.nonzero(keep_mask, as_tuple=True, **kwargs)
751808
752809
753810def where (condition : array , x1 : array , x2 : array , / ) -> array :
@@ -832,7 +889,7 @@ def eye(
832889 if n_cols is None :
833890 n_cols = n_rows
834891 z = paddle .zeros ([n_rows , n_cols ], dtype = dtype , ** kwargs ).to (device )
835- if abs (k ) <= n_rows + n_cols :
892+ if n_rows > 0 and n_cols > 0 and abs (k ) <= n_rows + n_cols :
836893 z .diagonal (k ).fill_ (1 )
837894 return z
838895
@@ -867,7 +924,11 @@ def full(
867924) -> array :
868925 if isinstance (shape , int ):
869926 shape = (shape ,)
870-
927+ if dtype is None :
928+ if isinstance (fill_value , (bool )):
929+ dtype = "bool"
930+ elif isinstance (fill_value , int ):
931+ dtype = 'int64'
871932 return paddle .full (shape , fill_value , dtype = dtype , ** kwargs ).to (device )
872933
873934
@@ -914,6 +975,8 @@ def triu(x: array, /, *, k: int = 0) -> array:
914975
915976
def expand_dims(x: array, /, *, axis: int = 0) -> array:
    """Insert a length-1 dimension into *x* at position *axis*.

    Valid axes span ``[-x.ndim - 1, x.ndim]`` (inclusive), matching the
    range of insertion points for a new dimension.

    Raises:
        IndexError: if *axis* falls outside that range.
    """
    in_bounds = -x.ndim - 1 <= axis <= x.ndim
    if not in_bounds:
        raise IndexError(f"Axis {axis} is out of bounds for array of dimension {x.ndim}")
    return paddle.unsqueeze(x, axis)
918981
919982
@@ -973,6 +1036,22 @@ def unique_values(x: array) -> array:
9731036
def matmul(x1: array, x2: array, /, **kwargs) -> array:
    """Matrix product of *x1* and *x2* with up-front shape validation.

    Raises:
        ValueError: if either input is 0-D, or the contracted dimensions
            (last of *x1*, second-to-last of *x2* — or its only dimension
            when *x2* is 1-D) do not match.
    """
    # paddle.matmul doesn't type promote (but differently from _fix_promotion)
    ndim1, ndim2 = x1.ndim, x2.ndim
    if ndim1 == 0 or ndim2 == 0:
        raise ValueError("matmul does not support 0-D (scalar) inputs.")

    contract1 = x1.shape[-1]
    contract2 = x2.shape[0] if ndim2 == 1 else x2.shape[-2]

    if contract1 != contract2:
        raise ValueError(
            f"Shapes {x1.shape} and {x2.shape} are not aligned for matmul: "
            f"{contract1} (dim -1 of x1) != {contract2} (dim -2 of x2)"
        )
    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
    return paddle.matmul(x1, x2, **kwargs)
9781057
@@ -988,7 +1067,36 @@ def meshgrid(*arrays: array, indexing: str = "xy") -> List[array]:
9881067
9891068
def vecdot(x1: array, x2: array, /, *, axis: int = -1) -> array:
    """Vector dot product of *x1* and *x2* along *axis*.

    Validates ranks, axis bounds, and that both inputs have the same size
    along the reduction axis before delegating to ``paddle.linalg.vecdot``.

    Raises:
        ValueError: for scalar inputs, an out-of-bounds axis, or a size
            mismatch along the reduction axis.
    """
    shape1, shape2 = x1.shape, x2.shape
    rank1, rank2 = len(shape1), len(shape2)
    if rank1 == 0 or rank2 == 0:
        raise ValueError(
            f"Vector dot product requires non-scalar inputs (rank > 0). "
            f"Got ranks {rank1} and {rank2} for shapes {shape1} and {shape2}."
        )

    # Normalize negative axes per input; flattened from the original
    # raise-IndexError-then-catch construction, which was needlessly indirect.
    norm_axis1 = axis if axis >= 0 else rank1 + axis
    norm_axis2 = axis if axis >= 0 else rank2 + axis
    if not (0 <= norm_axis1 < rank1) or not (0 <= norm_axis2 < rank2):
        raise ValueError(
            f"Axis {axis} is out of bounds for input shapes {shape1} (rank {rank1}) "
            f"and/or {shape2} (rank {rank2})."
        )

    size1 = shape1[norm_axis1]
    size2 = shape2[norm_axis2]
    if size1 != size2:
        raise ValueError(
            f"Inputs must have the same dimension size along the reduction axis ({axis}). "
            f"Got shapes {shape1} and {shape2}, with sizes {size1} and {size2} "
            f"along the normalized axis {norm_axis1} and {norm_axis2} respectively."
        )

    # Bug fix: the validation rewrite dropped the dtype-promotion step that
    # the rest of this module applies before binary ops; restore it so
    # mixed-dtype inputs promote consistently (see matmul above).
    x1, x2 = _fix_promotion(x1, x2, only_scalar=False)
    return paddle.linalg.vecdot(x1, x2, axis=axis)
9931101
9941102
@@ -1063,21 +1171,39 @@ def is_complex(dtype):
10631171
10641172
def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) -> array:
    """Take elements from *x* along *axis* at the positions in *indices*.

    *indices* may be a paddle Tensor, a Python int, or any sequence
    convertible to a 1-D int64 tensor.

    Raises:
        ValueError: if ``axis is None`` with a multi-dimensional *x*, or if
            *indices* has more than one dimension.
        TypeError: if *axis* is not an integer.
        IndexError: if *axis* is out of bounds for *x*.
    """
    _axis = axis
    if _axis is None:
        if x.ndim != 1:
            raise ValueError("axis must be specified when x.ndim > 1")
        _axis = 0
    elif not isinstance(_axis, int):
        raise TypeError(f"axis must be an integer, but received {type(_axis)}")

    if not (-x.ndim <= _axis < x.ndim):
        raise IndexError(f"axis {_axis} is out of bounds for tensor of dimension {x.ndim}")

    if isinstance(indices, paddle.Tensor):
        indices_tensor = indices
    elif isinstance(indices, int):
        # Wrap a bare int so the result stays 1-D.
        indices_tensor = paddle.to_tensor([indices], dtype='int64')
    else:
        # Otherwise (e.g., list, tuple), convert directly
        indices_tensor = paddle.to_tensor(indices, dtype='int64')

    # Ensure indices is a 1D tensor
    if indices_tensor.ndim == 0:
        indices_tensor = indices_tensor.reshape([1])
    elif indices_tensor.ndim > 1:
        raise ValueError(f"indices must be a 1D tensor, but received a {indices_tensor.ndim}D tensor")

    # Bug fix: **kwargs was accepted by the signature but silently dropped;
    # forward it to paddle.index_select as the original implementation did.
    return paddle.index_select(x, index=indices_tensor, axis=_axis, **kwargs)
10711199
10721200
10731201def sign (x : array , / ) -> array :
10741202 # paddle sign() does not support complex numbers and does not propagate
10751203 # nans. See https://github.com/data-apis/array-api-compat/issues/136
1076- if paddle .is_complex (x ):
1077- out = x / paddle .abs (x )
1078- # sign(0) = 0 but the above formula would give nan
1079- out [x == 0 + 0j ] = 0 + 0j
1080- return out
1204+ if paddle .is_complex (x ) and x .ndim == 0 and x .item () == 0j :
1205+ # Handle 0-D complex zero explicitly
1206+ return paddle .zeros_like (x , dtype = x .dtype )
10811207 else :
10821208 out = paddle .sign (x )
10831209 if paddle .is_floating_point (x ):
0 commit comments