@@ -355,7 +355,7 @@ def test_sign_shapes_invalid(invdtypes: dtype.Dtype) -> None:
355
355
)
356
356
@pytest .mark .parametrize ("dtype_name" , util .get_real_types ())
357
357
def test_trunc_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
358
- """Test truncating operation between two arrays of the same shape"""
358
+ """Test truncating operation for an array with varying shape"""
359
359
util .check_type_supported (dtype_name )
360
360
out = wrapper .randu (shape , dtype_name )
361
361
@@ -366,7 +366,7 @@ def test_trunc_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
366
366
367
367
@pytest .mark .parametrize ("invdtypes" , util .get_complex_types ())
368
368
def test_trunc_shapes_invalid (invdtypes : dtype .Dtype ) -> None :
369
- """Test trunc operation between two arrays of the same shape"""
369
+ """Test trunc operation for an array with varrying shape and invalid dtypes """
370
370
with pytest .raises (RuntimeError ):
371
371
shape = (3 , 3 )
372
372
out = wrapper .randu (shape , invdtypes )
@@ -408,4 +408,26 @@ def test_hypot_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
408
408
shape = (5 , 5 )
409
409
lhs = wrapper .randu (shape , invdtypes )
410
410
rhs = wrapper .randu (shape , invdtypes )
411
- wrapper .hypot (rhs , lhs , True )
411
+ wrapper .hypot (rhs , lhs , True )
412
+ @pytest .mark .parametrize (
413
+ "shape" ,
414
+ [
415
+ (),
416
+ (random .randint (1 , 10 ),),
417
+ (random .randint (1 , 10 ), random .randint (1 , 10 )),
418
+ (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
419
+ (random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 ), random .randint (1 , 10 )),
420
+ ],
421
+ )
422
+ @pytest .mark .parametrize ("dtype_name" , util .get_real_types ())
423
+ def test_clamp_shape_dtypes (shape : tuple , dtype_name : dtype .Dtype ) -> None :
424
+ """Test clamp operation between two arrays of the same shape"""
425
+ util .check_type_supported (dtype_name )
426
+ og = wrapper .randu (shape , dtype_name )
427
+ low = wrapper .randu (shape , dtype_name )
428
+ high = wrapper .randu (shape , dtype_name )
429
+ # talked to stefan about this, testing broadcasting is unnecessary
430
+ result = wrapper .clamp (og , low , high , False )
431
+ assert (
432
+ wrapper .get_dims (result )[0 : len (shape )] == shape # noqa
433
+ ), f"failed for shape: { shape } and dtype { dtype_name } "
0 commit comments