@@ -237,6 +237,7 @@ def test_transpose(self, B: int, T: int, E: int) -> None:
237
237
]
238
238
),
239
239
mixed_B = st .booleans (),
240
+ bounds_check_version = st .sampled_from ((1 , 2 )),
240
241
)
241
242
@settings (verbosity = VERBOSITY , max_examples = MAX_EXAMPLES , deadline = None )
242
243
def test_bounds_check ( # noqa C901
@@ -249,6 +250,7 @@ def test_bounds_check( # noqa C901
249
250
weighted : bool ,
250
251
dtype : torch .dtype ,
251
252
mixed_B : bool ,
253
+ bounds_check_version : int ,
252
254
) -> None :
253
255
rows_per_table = torch .tensor (
254
256
np .random .randint (low = 1 , high = 1000 , size = (T ,))
@@ -348,6 +350,7 @@ def test_bounds_check( # noqa C901
348
350
bounds_check_mode ,
349
351
warning ,
350
352
weights ,
353
+ bounds_check_version = bounds_check_version ,
351
354
** vbe_args ,
352
355
)
353
356
# we don't modify when we are in-bounds.
@@ -361,6 +364,7 @@ def test_bounds_check( # noqa C901
361
364
bounds_check_mode ,
362
365
warning ,
363
366
weights ,
367
+ bounds_check_version = bounds_check_version ,
364
368
** vbe_args ,
365
369
)
366
370
torch .testing .assert_close (indices , torch .zeros_like (indices ))
@@ -376,6 +380,7 @@ def test_bounds_check( # noqa C901
376
380
bounds_check_mode ,
377
381
warning ,
378
382
weights ,
383
+ bounds_check_version = bounds_check_version ,
379
384
** vbe_args ,
380
385
)
381
386
# It would be nice to test the CUDA implementation of BoundsCheckMode==FATAL,
@@ -397,6 +402,7 @@ def test_bounds_check( # noqa C901
397
402
bounds_check_mode ,
398
403
warning ,
399
404
weights ,
405
+ bounds_check_version = bounds_check_version ,
400
406
** vbe_args ,
401
407
)
402
408
if offsets .numel () > 0 :
@@ -417,6 +423,7 @@ def test_bounds_check( # noqa C901
417
423
bounds_check_mode ,
418
424
warning ,
419
425
weights ,
426
+ bounds_check_version = bounds_check_version ,
420
427
)
421
428
422
429
# test offsets.size(0) ! = B * T + 1 case. Here we test with T >= 2 case.
@@ -444,6 +451,7 @@ def test_bounds_check( # noqa C901
444
451
bounds_check_mode ,
445
452
warning ,
446
453
weights ,
454
+ bounds_check_version = bounds_check_version ,
447
455
)
448
456
449
457
# test weights.size(0) != indices.size(0) case
@@ -459,6 +467,7 @@ def test_bounds_check( # noqa C901
459
467
warning ,
460
468
weights ,
461
469
** vbe_args ,
470
+ bounds_check_version = bounds_check_version ,
462
471
)
463
472
464
473
@given (
0 commit comments