Skip to content

Commit adbb4aa

Browse files
committed
[JOIN] Refactor _as_stateful to accept sdf
1 parent 6a6bb89 commit adbb4aa

File tree

1 file changed

+8
-24
lines changed

1 file changed

+8
-24
lines changed

quixstreams/dataframe/dataframe.py

+8-24
Original file line numberDiff line numberDiff line change
@@ -279,11 +279,7 @@ def func(d: dict, state: State):
279279
cast(ApplyCallbackStateful, func)
280280
)
281281

282-
stateful_func = _as_stateful(
283-
func=with_metadata_func,
284-
processing_context=self._processing_context,
285-
stream_id=self.stream_id,
286-
)
282+
stateful_func = _as_stateful(with_metadata_func, self)
287283
stream = self.stream.add_apply(stateful_func, expand=expand, metadata=True) # type: ignore[call-overload]
288284
else:
289285
stream = self.stream.add_apply(
@@ -388,11 +384,7 @@ def func(values: list, state: State):
388384
cast(UpdateCallbackStateful, func)
389385
)
390386

391-
stateful_func = _as_stateful(
392-
func=with_metadata_func,
393-
processing_context=self._processing_context,
394-
stream_id=self.stream_id,
395-
)
387+
stateful_func = _as_stateful(with_metadata_func, self)
396388
return self._add_update(stateful_func, metadata=True)
397389
else:
398390
return self._add_update(
@@ -490,11 +482,7 @@ def func(d: dict, state: State):
490482
cast(FilterCallbackStateful, func)
491483
)
492484

493-
stateful_func = _as_stateful(
494-
func=with_metadata_func,
495-
processing_context=self._processing_context,
496-
stream_id=self.stream_id,
497-
)
485+
stateful_func = _as_stateful(with_metadata_func, self)
498486
stream = self.stream.add_filter(stateful_func, metadata=True)
499487
else:
500488
stream = self.stream.add_filter( # type: ignore[call-overload]
@@ -1805,24 +1793,20 @@ def wrapper(
18051793

18061794
def _as_stateful(
18071795
func: Callable[[Any, Any, int, Any, State], T],
1808-
processing_context: ProcessingContext,
1809-
stream_id: str,
1796+
sdf: StreamingDataFrame,
18101797
) -> Callable[[Any, Any, int, Any], T]:
18111798
@functools.wraps(func)
18121799
def wrapper(value: Any, key: Any, timestamp: int, headers: Any) -> Any:
18131800
# Pass a State object with an interface limited to the key updates only
18141801
# and prefix all the state keys by the message key
1815-
transaction = _get_transaction(processing_context, stream_id)
1816-
state = transaction.as_state(prefix=key)
1802+
state = _get_transaction(sdf).as_state(prefix=key)
18171803
return func(value, key, timestamp, headers, state)
18181804

18191805
return wrapper
18201806

18211807

1822-
def _get_transaction(
1823-
processing_context: ProcessingContext, stream_id: str
1824-
) -> PartitionTransaction:
1825-
return processing_context.checkpoint.get_store_transaction(
1826-
stream_id=stream_id,
1808+
def _get_transaction(sdf: StreamingDataFrame) -> PartitionTransaction:
1809+
return sdf.processing_context.checkpoint.get_store_transaction(
1810+
stream_id=sdf.stream_id,
18271811
partition=message_context().partition,
18281812
)

0 commit comments

Comments
 (0)