
feat: Persist RequestList state #1274


Merged
9 commits merged on Jul 23, 2025
58 changes: 40 additions & 18 deletions src/crawlee/_utils/recoverable_state.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Generic, TypeVar
from typing import TYPE_CHECKING, Generic, Literal, TypeVar

from pydantic import BaseModel

@@ -34,7 +34,7 @@ def __init__(
*,
default_state: TStateModel,
persist_state_key: str,
persistence_enabled: bool = False,
persistence_enabled: Literal[True, False, 'explicit_only'] = False,
persist_state_kvs_name: str | None = None,
persist_state_kvs_id: str | None = None,
logger: logging.Logger,
@@ -43,13 +43,14 @@

Args:
default_state: The default state model instance to use when no persisted state is found.
A deep copy is made each time the state is used.
A deep copy is made each time the state is used.
persist_state_key: The key under which the state is stored in the KeyValueStore
persistence_enabled: Flag to enable or disable state persistence
persistence_enabled: Flag to enable or disable state persistence. Use 'explicit_only' if you want to be able
to save the state manually, but without any automatic persistence.
persist_state_kvs_name: The name of the KeyValueStore to use for persistence.
If neither a name nor and id are supplied, the default store will be used.
If neither a name nor and id are supplied, the default store will be used.
persist_state_kvs_id: The identifier of the KeyValueStore to use for persistence.
If neither a name nor and id are supplied, the default store will be used.
If neither a name nor and id are supplied, the default store will be used.
logger: A logger instance for logging operations related to state persistence
"""
self._default_state = default_state
@@ -71,7 +72,7 @@ async def initialize(self) -> TStateModel:
Returns:
The loaded state model
"""
if not self._persistence_enabled:
if self._persistence_enabled is False:
self._state = self._default_state.model_copy(deep=True)
return self.current_value

@@ -84,11 +85,12 @@

await self._load_saved_state()

# Import here to avoid circular imports.
from crawlee import service_locator # noqa: PLC0415
if self._persistence_enabled is True:
# Import here to avoid circular imports.
from crawlee import service_locator # noqa: PLC0415

event_manager = service_locator.get_event_manager()
event_manager.on(event=Event.PERSIST_STATE, listener=self.persist_state)
event_manager = service_locator.get_event_manager()
event_manager.on(event=Event.PERSIST_STATE, listener=self.persist_state)

return self.current_value

@@ -101,12 +103,13 @@ async def teardown(self) -> None:
if not self._persistence_enabled:
return

# Import here to avoid circular imports.
from crawlee import service_locator # noqa: PLC0415
if self._persistence_enabled is True:
# Import here to avoid circular imports.
from crawlee import service_locator # noqa: PLC0415

event_manager = service_locator.get_event_manager()
event_manager.off(event=Event.PERSIST_STATE, listener=self.persist_state)
await self.persist_state()
event_manager = service_locator.get_event_manager()
event_manager.off(event=Event.PERSIST_STATE, listener=self.persist_state)
await self.persist_state()

@property
def current_value(self) -> TStateModel:
@@ -116,6 +119,21 @@ def current_value(self) -> TStateModel:

return self._state

@property
def is_initialized(self) -> bool:
"""Check if the state has already been initialized."""
return self._state is not None

async def has_persisted_state(self) -> bool:
"""Check if there is any persisted state in the key-value store."""
if not self._persistence_enabled:
return False

if self._key_value_store is None:
raise RuntimeError('Recoverable state has not yet been initialized')

return await self._key_value_store.record_exists(self._persist_state_key)

async def reset(self) -> None:
"""Reset the state to the default values and clear any persisted state.

@@ -139,17 +157,21 @@ async def persist_state(self, event_data: EventPersistStateData | None = None) -
Args:
event_data: Optional data associated with a PERSIST_STATE event
"""
self._log.debug(f'Persisting state of the Statistics (event_data={event_data}).')
self._log.debug(
f'Persisting RecoverableState (model={self._default_state.__class__.__name__}, event_data={event_data}).'
)

if self._key_value_store is None or self._state is None:
raise RuntimeError('Recoverable state has not yet been initialized')

if self._persistence_enabled:
if self._persistence_enabled is True or self._persistence_enabled == 'explicit_only':
await self._key_value_store.set_value(
self._persist_state_key,
self._state.model_dump(mode='json', by_alias=True),
'application/json',
)
else:
self._log.debug('Persistence is not enabled - not doing anything')

async def _load_saved_state(self) -> None:
if self._key_value_store is None:
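For context, a minimal sketch (not part of the PR) of how the new 'explicit_only' mode could be exercised. `RecoverableState` is an internal utility; the `CrawlProgress` model and the `CRAWL_PROGRESS` key below are made up for illustration, and the exact behaviour is defined by the diff above.

```python
import asyncio
import logging

from pydantic import BaseModel

from crawlee._utils.recoverable_state import RecoverableState


class CrawlProgress(BaseModel):
    # Hypothetical state model, just to have something to persist.
    pages_done: int = 0


async def main() -> None:
    state = RecoverableState(
        default_state=CrawlProgress(),
        persist_state_key='CRAWL_PROGRESS',
        # 'explicit_only': no PERSIST_STATE listener is registered, but
        # persist_state() still writes to the key-value store when called.
        persistence_enabled='explicit_only',
        logger=logging.getLogger(__name__),
    )

    await state.initialize()
    state.current_value.pages_done += 1
    await state.persist_state()  # the only moment the state is written
    await state.teardown()


asyncio.run(main())
```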
157 changes: 131 additions & 26 deletions src/crawlee/request_loaders/_request_list.py
@@ -1,45 +1,83 @@
from __future__ import annotations

import asyncio
from collections.abc import AsyncIterable, AsyncIterator, Iterable
from typing import TYPE_CHECKING
import contextlib
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Iterable
from logging import getLogger
from typing import Annotated

from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import override

from crawlee._request import Request
from crawlee._utils.docs import docs_group
from crawlee.request_loaders._request_loader import RequestLoader

if TYPE_CHECKING:
from crawlee._request import Request
logger = getLogger(__name__)


class RequestListState(BaseModel):
model_config = ConfigDict(populate_by_name=True)

next_index: Annotated[int, Field(alias='nextIndex')] = 0
next_unique_key: Annotated[str | None, Field(alias='nextUniqueKey')] = None
in_progress: Annotated[set[str], Field(alias='inProgress')] = set()
Comment on lines +22 to +24

Collaborator: Are the camelCase aliases necessary? AFAIK I also did not use them in FS storage clients.

Collaborator (Author): Probably not. Sessions and Statistics (other instances of recoverable state) use them too. I have no strong opinion here, if you do, say the word and I'll remove them.

Collaborator (@vdusek, Jul 22, 2025): Well, so currently somewhere we use them, and somewhere we don't - up to you then.

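As a side note on the thread above, a tiny illustrative snippet (not from the PR, trimmed to a single field) of what `populate_by_name` plus the camelCase aliases allow:

```python
from typing import Annotated

from pydantic import BaseModel, ConfigDict, Field


class RequestListState(BaseModel):
    model_config = ConfigDict(populate_by_name=True)

    next_index: Annotated[int, Field(alias='nextIndex')] = 0


# Both the alias and the Python field name are accepted on input...
assert RequestListState(nextIndex=3) == RequestListState(next_index=3)

# ...and dumping with by_alias=True (as RecoverableState.persist_state does)
# produces the camelCase form.
print(RequestListState(next_index=3).model_dump(by_alias=True))  # {'nextIndex': 3}
```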
class RequestListData(BaseModel):
requests: Annotated[list[Request], Field()]


@docs_group('Classes')
class RequestList(RequestLoader):
"""Represents a (potentially very large) list of URLs to crawl.

Disclaimer: The `RequestList` class is in its early version and is not fully implemented. It is currently
intended mainly for testing purposes and small-scale projects. The current implementation is only in-memory
storage and is very limited. It will be (re)implemented in the future. For more details, see the GitHub issue:
https://github.com/apify/crawlee-python/issues/99. For production usage we recommend to use the `RequestQueue`.
"""
"""Represents a (potentially very large) list of URLs to crawl."""

def __init__(
self,
requests: Iterable[str | Request] | AsyncIterable[str | Request] | None = None,
name: str | None = None,
persist_state_key: str | None = None,
persist_requests_key: str | None = None,
) -> None:
"""Initialize a new instance.

Args:
requests: The request objects (or their string representations) to be added to the provider.
name: A name of the request list.
persist_state_key: A key for persisting the progress information of the RequestList.
If you do not pass a key but pass a `name`, a key will be derived using the name.
Otherwise, state will not be persisted.
persist_requests_key: A key for persisting the request data loaded from the `requests` iterator.
If specified, the request data will be stored in the KeyValueStore to make sure that they don't change
over time. This is useful if the `requests` iterator pulls the data dynamically.
"""
from crawlee._utils.recoverable_state import RecoverableState # noqa: PLC0415

self._name = name
self._handled_count = 0
self._assumed_total_count = 0

self._in_progress = set[str]()
self._next: Request | None = None
self._next: tuple[Request | None, Request | None] = (None, None)

if persist_state_key is None and name is not None:
persist_state_key = f'SDK_REQUEST_LIST_STATE-{name}'

self._state = RecoverableState(
default_state=RequestListState(),
persistence_enabled=bool(persist_state_key),
persist_state_key=persist_state_key or '',
logger=logger,
)

self._persist_request_data = bool(persist_requests_key)

self._requests_data = RecoverableState(
default_state=RequestListData(requests=[]),
# With request data persistence enabled, a snapshot of the requests will be done on initialization
persistence_enabled='explicit_only' if self._persist_request_data else False,
persist_state_key=persist_requests_key or '',
logger=logger,
)

if isinstance(requests, AsyncIterable):
self._requests = requests.__aiter__()
@@ -50,6 +88,53 @@ def __init__(

self._requests_lock: asyncio.Lock | None = None

async def _get_state(self) -> RequestListState:
# If state is already initialized, we are done
if self._state.is_initialized:
return self._state.current_value

# Initialize recoverable state
await self._state.initialize()
await self._requests_data.initialize()

# Initialize lock if necessary
if self._requests_lock is None:
self._requests_lock = asyncio.Lock()

# If the RequestList is configured to persist request data, ensure that a copy of request data is used
if self._persist_request_data:
async with self._requests_lock:
if not await self._requests_data.has_persisted_state():
self._requests_data.current_value.requests = [
request if isinstance(request, Request) else Request.from_url(request)
async for request in self._requests
]
await self._requests_data.persist_state()

self._requests = self._iterate_in_threadpool(
self._requests_data.current_value.requests[self._state.current_value.next_index :]
)
# If not using persistent request data, advance the request iterator
else:
async with self._requests_lock:
for _ in range(self._state.current_value.next_index):
with contextlib.suppress(StopAsyncIteration):
await self._requests.__anext__()

# Check consistency of the stored state and the request iterator
if (unique_key_to_check := self._state.current_value.next_unique_key) is not None:
await self._ensure_next_request()

next_unique_key = self._next[0].unique_key if self._next[0] is not None else None
if next_unique_key != unique_key_to_check:
raise RuntimeError(
f"""Mismatch at index {
self._state.current_value.next_index
} in persisted requests - Expected unique key `{unique_key_to_check}`, got `{next_unique_key}`"""
)

return self._state.current_value

@property
def name(self) -> str | None:
return self._name
@@ -65,42 +150,62 @@ async def get_total_count(self) -> int:
@override
async def is_empty(self) -> bool:
await self._ensure_next_request()
return self._next is None
return self._next[0] is None

@override
async def is_finished(self) -> bool:
return len(self._in_progress) == 0 and await self.is_empty()
state = await self._get_state()
return len(state.in_progress) == 0 and await self.is_empty()

@override
async def fetch_next_request(self) -> Request | None:
await self._get_state()
await self._ensure_next_request()

if self._next is None:
if self._next[0] is None:
return None

self._in_progress.add(self._next.id)
state = await self._get_state()
state.in_progress.add(self._next[0].id)
self._assumed_total_count += 1

next_request = self._next
self._next = None
next_request = self._next[0]
if next_request is not None:
state.next_index += 1
state.next_unique_key = self._next[1].unique_key if self._next[1] is not None else None

self._next = (self._next[1], None)
await self._ensure_next_request()

return next_request

@override
async def mark_request_as_handled(self, request: Request) -> None:
self._handled_count += 1
self._in_progress.remove(request.id)
state = await self._get_state()
state.in_progress.remove(request.id)

async def _ensure_next_request(self) -> None:
await self._get_state()

if self._requests_lock is None:
self._requests_lock = asyncio.Lock()

try:
async with self._requests_lock:
if self._next is None:
self._next = self._transform_request(await self._requests.__anext__())
except StopAsyncIteration:
self._next = None
async with self._requests_lock:
if None in self._next:
if self._next[0] is None:
to_enqueue = [item async for item in self._dequeue_requests(2)]
self._next = (to_enqueue[0], to_enqueue[1])
else:
to_enqueue = [item async for item in self._dequeue_requests(1)]
self._next = (self._next[0], to_enqueue[0])

async def _dequeue_requests(self, count: int) -> AsyncGenerator[Request | None]:
for _ in range(count):
try:
yield self._transform_request(await self._requests.__anext__())
except StopAsyncIteration: # noqa: PERF203
yield None

async def _iterate_in_threadpool(self, iterable: Iterable[str | Request]) -> AsyncIterator[str | Request]:
"""Inspired by a function of the same name from encode/starlette."""
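To round off this file, a hedged sketch (not part of the diff) of how the persisted `RequestList` might be used. The URLs, the list name, and the `persist_requests_key` value are placeholders; the derived state key follows the logic shown in `__init__` above.

```python
import asyncio

from crawlee.request_loaders import RequestList


async def main() -> None:
    request_list = RequestList(
        ['https://example.com/a', 'https://example.com/b'],
        # With a name but no persist_state_key, the key is derived as
        # 'SDK_REQUEST_LIST_STATE-my-list'.
        name='my-list',
        # Snapshot the input requests into the key-value store on first use,
        # so a resumed run sees exactly the same list.
        persist_requests_key='MY_LIST_REQUESTS',
    )

    while not await request_list.is_finished():
        request = await request_list.fetch_next_request()
        if request is None:
            break
        # ... process the request here ...
        await request_list.mark_request_as_handled(request)


asyncio.run(main())
```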
2 changes: 1 addition & 1 deletion src/crawlee/statistics/_statistics.py
@@ -67,7 +67,7 @@ class Statistics(Generic[TStatisticsState]):
def __init__(
self,
*,
persistence_enabled: bool = False,
persistence_enabled: bool | Literal['explicit_only'] = False,
persist_state_kvs_name: str | None = None,
persist_state_key: str | None = None,
log_message: str = 'Statistics',