@@ -279,11 +279,7 @@ def func(d: dict, state: State):
279
279
cast (ApplyCallbackStateful , func )
280
280
)
281
281
282
- stateful_func = _as_stateful (
283
- func = with_metadata_func ,
284
- processing_context = self ._processing_context ,
285
- stream_id = self .stream_id ,
286
- )
282
+ stateful_func = _as_stateful (with_metadata_func , self )
287
283
stream = self .stream .add_apply (stateful_func , expand = expand , metadata = True ) # type: ignore[call-overload]
288
284
else :
289
285
stream = self .stream .add_apply (
@@ -388,11 +384,7 @@ def func(values: list, state: State):
388
384
cast (UpdateCallbackStateful , func )
389
385
)
390
386
391
- stateful_func = _as_stateful (
392
- func = with_metadata_func ,
393
- processing_context = self ._processing_context ,
394
- stream_id = self .stream_id ,
395
- )
387
+ stateful_func = _as_stateful (with_metadata_func , self )
396
388
return self ._add_update (stateful_func , metadata = True )
397
389
else :
398
390
return self ._add_update (
@@ -490,11 +482,7 @@ def func(d: dict, state: State):
490
482
cast (FilterCallbackStateful , func )
491
483
)
492
484
493
- stateful_func = _as_stateful (
494
- func = with_metadata_func ,
495
- processing_context = self ._processing_context ,
496
- stream_id = self .stream_id ,
497
- )
485
+ stateful_func = _as_stateful (with_metadata_func , self )
498
486
stream = self .stream .add_filter (stateful_func , metadata = True )
499
487
else :
500
488
stream = self .stream .add_filter ( # type: ignore[call-overload]
@@ -1805,24 +1793,20 @@ def wrapper(
1805
1793
1806
1794
def _as_stateful (
1807
1795
func : Callable [[Any , Any , int , Any , State ], T ],
1808
- processing_context : ProcessingContext ,
1809
- stream_id : str ,
1796
+ sdf : StreamingDataFrame ,
1810
1797
) -> Callable [[Any , Any , int , Any ], T ]:
1811
1798
@functools .wraps (func )
1812
1799
def wrapper (value : Any , key : Any , timestamp : int , headers : Any ) -> Any :
1813
1800
# Pass a State object with an interface limited to the key updates only
1814
1801
# and prefix all the state keys by the message key
1815
- transaction = _get_transaction (processing_context , stream_id )
1816
- state = transaction .as_state (prefix = key )
1802
+ state = _get_transaction (sdf ).as_state (prefix = key )
1817
1803
return func (value , key , timestamp , headers , state )
1818
1804
1819
1805
return wrapper
1820
1806
1821
1807
1822
- def _get_transaction (
1823
- processing_context : ProcessingContext , stream_id : str
1824
- ) -> PartitionTransaction :
1825
- return processing_context .checkpoint .get_store_transaction (
1826
- stream_id = stream_id ,
1808
+ def _get_transaction (sdf : StreamingDataFrame ) -> PartitionTransaction :
1809
+ return sdf .processing_context .checkpoint .get_store_transaction (
1810
+ stream_id = sdf .stream_id ,
1827
1811
partition = message_context ().partition ,
1828
1812
)
0 commit comments