@@ -3,7 +3,7 @@ use std::fmt;
33use std:: fmt:: Formatter ;
44use std:: panic:: UnwindSafe ;
55use std:: ptr:: { self , NonNull } ;
6- use std:: sync:: atomic:: Ordering ;
6+ use std:: sync:: atomic:: { AtomicU8 , Ordering } ;
77use std:: sync:: Arc ;
88
99use rustc_hash:: FxHashMap ;
@@ -47,21 +47,32 @@ pub struct ZalsaLocal {
4747
4848/// A cancellation token that can be used to cancel a query computation for a specific local `Database`.
4949#[ derive( Default , Clone , Debug ) ]
50- pub struct CancellationToken ( Arc < AtomicBool > ) ;
50+ pub struct CancellationToken ( Arc < AtomicU8 > ) ;
5151
5252impl CancellationToken {
53+ const CANCELLED_MASK : u8 = 0b01 ;
54+ const DISABLED_MASK : u8 = 0b10 ;
55+
5356 /// Inform the database to cancel the current query computation.
5457 pub fn cancel ( & self ) {
55- self . 0 . store ( true , Ordering :: Relaxed ) ;
58+ self . 0 . fetch_or ( Self :: CANCELLED_MASK , Ordering :: Relaxed ) ;
5659 }
5760
5861 /// Check if the query computation has been requested to be cancelled.
5962 pub fn is_cancelled ( & self ) -> bool {
60- self . 0 . load ( Ordering :: Relaxed )
63+ self . 0 . load ( Ordering :: Relaxed ) & Self :: CANCELLED_MASK != 0
6164 }
6265
63- pub ( crate ) fn uncancel ( & self ) {
64- self . 0 . store ( false , Ordering :: Relaxed ) ;
66+ fn set_cancellation_disabled ( & self , disabled : bool ) -> bool {
67+ self . 0 . fetch_or ( ( disabled as u8 ) << 1 , Ordering :: Relaxed ) & Self :: DISABLED_MASK != 0
68+ }
69+
70+ fn should_trigger_local_cancellation ( & self ) -> bool {
71+ self . 0 . load ( Ordering :: Relaxed ) == Self :: CANCELLED_MASK
72+ }
73+
74+ fn reset ( & self ) {
75+ self . 0 . store ( 0 , Ordering :: Relaxed ) ;
6576 }
6677}
6778
@@ -433,12 +444,12 @@ impl ZalsaLocal {
433444
434445 #[ inline]
435446 pub ( crate ) fn uncancel ( & self ) {
436- self . cancelled . uncancel ( ) ;
447+ self . cancelled . reset ( ) ;
437448 }
438449
439450 #[ inline]
440- pub fn is_cancelled ( & self ) -> bool {
441- self . cancelled . 0 . load ( Ordering :: Relaxed )
451+ pub fn should_trigger_local_cancellation ( & self ) -> bool {
452+ self . cancelled . should_trigger_local_cancellation ( )
442453 }
443454
444455 #[ cold]
@@ -450,6 +461,10 @@ impl ZalsaLocal {
450461 pub ( crate ) fn unwind_cancelled ( & self ) {
451462 Cancelled :: Local . throw ( ) ;
452463 }
464+
465+ pub ( crate ) fn set_cancellation_disabled ( & self , was_disabled : bool ) -> bool {
466+ self . cancelled . set_cancellation_disabled ( was_disabled)
467+ }
453468}
454469
455470// Okay to implement as `ZalsaLocal`` is !Sync
0 commit comments