Skip to content

feat: Persist RequestList state #1274

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jul 23, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 33 additions & 10 deletions src/crawlee/_utils/recoverable_state.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Generic, TypeVar
from typing import TYPE_CHECKING, Generic, Literal, TypeVar

from pydantic import BaseModel

Expand Down Expand Up @@ -34,7 +34,7 @@ def __init__(
*,
default_state: TStateModel,
persist_state_key: str,
persistence_enabled: bool = False,
persistence_enabled: Literal[True, False, 'explicit_only'] = False,
persist_state_kvs_name: str | None = None,
persist_state_kvs_id: str | None = None,
logger: logging.Logger,
Expand Down Expand Up @@ -71,7 +71,7 @@ async def initialize(self) -> TStateModel:
Returns:
The loaded state model
"""
if not self._persistence_enabled:
if self._persistence_enabled is False:
self._state = self._default_state.model_copy(deep=True)
return self.current_value

Expand All @@ -81,8 +81,9 @@ async def initialize(self) -> TStateModel:

await self._load_saved_state()

event_manager = service_locator.get_event_manager()
event_manager.on(event=Event.PERSIST_STATE, listener=self.persist_state)
if self._persistence_enabled is True:
event_manager = service_locator.get_event_manager()
event_manager.on(event=Event.PERSIST_STATE, listener=self.persist_state)

return self.current_value

Expand All @@ -95,9 +96,10 @@ async def teardown(self) -> None:
if not self._persistence_enabled:
return

event_manager = service_locator.get_event_manager()
event_manager.off(event=Event.PERSIST_STATE, listener=self.persist_state)
await self.persist_state()
if self._persistence_enabled is True:
event_manager = service_locator.get_event_manager()
event_manager.off(event=Event.PERSIST_STATE, listener=self.persist_state)
await self.persist_state()

@property
def current_value(self) -> TStateModel:
Expand All @@ -107,6 +109,23 @@ def current_value(self) -> TStateModel:

return self._state

@property
def is_initialized(self) -> bool:
"""Check if the state has already been initialized."""
return self._state is not None

async def has_persisted_state(self) -> bool:
"""Check if there is any persisted state in the key-value store."""
if not self._persistence_enabled:
return False

if self._key_value_store is None:
raise RuntimeError('Recoverable state has not yet been initialized')

return (
await self._key_value_store.get_value(self._persist_state_key) is not None
) # TODO do not fetch the whole record

async def reset(self) -> None:
"""Reset the state to the default values and clear any persisted state.

Expand All @@ -130,17 +149,21 @@ async def persist_state(self, event_data: EventPersistStateData | None = None) -
Args:
event_data: Optional data associated with a PERSIST_STATE event
"""
self._log.debug(f'Persisting state of the Statistics (event_data={event_data}).')
self._log.debug(
f'Persisting RecoverableState (model={self._default_state.__class__.__name__}, event_data={event_data}).'
)

if self._key_value_store is None or self._state is None:
raise RuntimeError('Recoverable state has not yet been initialized')

if self._persistence_enabled:
if self._persistence_enabled is True or self._persistence_enabled == 'explicit_only':
await self._key_value_store.set_value(
self._persist_state_key,
self._state.model_dump(mode='json', by_alias=True),
'application/json',
)
else:
self._log.debug('Persistence is not enabled - not doing anything')

async def _load_saved_state(self) -> None:
if self._key_value_store is None:
Expand Down
105 changes: 91 additions & 14 deletions src/crawlee/request_loaders/_request_list.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,83 @@
from __future__ import annotations

import asyncio
import contextlib
from collections.abc import AsyncIterable, AsyncIterator, Iterable
from typing import TYPE_CHECKING
from logging import getLogger
from typing import Annotated

from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import override

from crawlee._request import Request
from crawlee._utils.docs import docs_group
from crawlee.request_loaders._request_loader import RequestLoader

if TYPE_CHECKING:
from crawlee._request import Request
logger = getLogger(__name__)


class RequestListState(BaseModel):
model_config = ConfigDict(populate_by_name=True)

next_index: Annotated[int, Field(alias='nextIndex')] = 0
next_unique_key: Annotated[str | None, Field(alias='nextUniqueKey')] = None
in_progress: Annotated[set[str], Field(alias='inProgress')] = set()
Comment on lines +22 to +24
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the camelCase aliases necessary? AFAIK I also did not use them in FS storage clients.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably not. Sessions and Statistics (other instances of recoverable state) use them too. I have no strong opinion here, if you do, say the word and I'll remove them.

Copy link
Collaborator

@vdusek vdusek Jul 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, so currently somewhere we use them, and somewhere we don't - up to you then.



class RequestListData(BaseModel):
requests: Annotated[Iterable[Request], Field()]


@docs_group('Classes')
class RequestList(RequestLoader):
"""Represents a (potentially very large) list of URLs to crawl.

Disclaimer: The `RequestList` class is in its early version and is not fully implemented. It is currently
intended mainly for testing purposes and small-scale projects. The current implementation is only in-memory
storage and is very limited. It will be (re)implemented in the future. For more details, see the GitHub issue:
https://github.com/apify/crawlee-python/issues/99. For production usage we recommend to use the `RequestQueue`.
"""
"""Represents a (potentially very large) list of URLs to crawl."""

def __init__(
self,
requests: Iterable[str | Request] | AsyncIterable[str | Request] | None = None,
name: str | None = None,
persist_state_key: str | None = None,
persist_requests_key: str | None = None,
) -> None:
"""Initialize a new instance.

Args:
requests: The request objects (or their string representations) to be added to the provider.
name: A name of the request list.
persist_state_key: A key for persisting the progress information of the RequestList.
If you do not pass a key but pass a `name`, a key will be derived using the name.
Otherwise, state will not be persisted.
persist_requests_key: A key for persisting the request data loaded from the `requests` iterator.
If specified, the request data will be stored in the KeyValueStore to make sure that they don't change
over time. This is useful if the `requests` iterator pulls the data dynamically.
"""
from crawlee._utils.recoverable_state import RecoverableState

self._name = name
self._handled_count = 0
self._assumed_total_count = 0

self._in_progress = set[str]()
self._next: Request | None = None

if persist_state_key is None and name is not None:
persist_state_key = f'SDK_REQUEST_LIST_STATE-{name}'

self._state = RecoverableState(
default_state=RequestListState(),
persistence_enabled=bool(persist_state_key),
persist_state_key=persist_state_key or '',
logger=logger,
)

self._persist_request_data = bool(persist_requests_key)

self._requests_data = RecoverableState(
default_state=RequestListData(requests=[]),
persistence_enabled='explicit_only' if self._persist_request_data else False,
persist_state_key=persist_requests_key or '',
logger=logger,
)

if isinstance(requests, AsyncIterable):
self._requests = requests.__aiter__()
elif requests is None:
Expand All @@ -50,6 +87,38 @@ def __init__(

self._requests_lock: asyncio.Lock | None = None

async def _get_state(self) -> RequestListState:
if self._state.is_initialized:
return self._state.current_value

await self._state.initialize()
await self._requests_data.initialize()

if self._persist_request_data:
if self._requests_lock is None:
self._requests_lock = asyncio.Lock()

async with self._requests_lock:
if not await self._requests_data.has_persisted_state():
self._requests_data.current_value.requests = [
request if isinstance(request, Request) else Request.from_url(request)
async for request in self._requests
]
await self._requests_data.persist_state()

self._requests = self._iterate_in_threadpool(self._requests_data.current_value.requests)

for _ in range(0, self._state.current_value.next_index):
with contextlib.suppress(StopAsyncIteration):
await self._requests.__anext__()

if (unique_key_to_check := self._state.current_value.next_unique_key) is not None:
await self._ensure_next_request()
if self._next is None or self._next.unique_key != unique_key_to_check:
raise RuntimeError()

return self._state.current_value

@property
def name(self) -> str | None:
return self._name
Expand All @@ -65,7 +134,8 @@ async def is_empty(self) -> bool:

@override
async def is_finished(self) -> bool:
return len(self._in_progress) == 0 and await self.is_empty()
state = await self._get_state()
return len(state.in_progress) == 0 and await self.is_empty()

@override
async def fetch_next_request(self) -> Request | None:
Expand All @@ -74,7 +144,8 @@ async def fetch_next_request(self) -> Request | None:
if self._next is None:
return None

self._in_progress.add(self._next.id)
state = await self._get_state()
state.in_progress.add(self._next.id)
self._assumed_total_count += 1

next_request = self._next
Expand All @@ -85,13 +156,16 @@ async def fetch_next_request(self) -> Request | None:
@override
async def mark_request_as_handled(self, request: Request) -> None:
self._handled_count += 1
self._in_progress.remove(request.id)
state = await self._get_state()
state.in_progress.remove(request.id)

@override
async def get_handled_count(self) -> int:
return self._handled_count

async def _ensure_next_request(self) -> None:
state = await self._get_state()

if self._requests_lock is None:
self._requests_lock = asyncio.Lock()

Expand All @@ -101,6 +175,9 @@ async def _ensure_next_request(self) -> None:
self._next = self._transform_request(await self._requests.__anext__())
except StopAsyncIteration:
self._next = None
else:
state.next_index += 1
state.next_unique_key = self._next.unique_key

async def _iterate_in_threadpool(self, iterable: Iterable[str | Request]) -> AsyncIterator[str | Request]:
"""Inspired by a function of the same name from encode/starlette."""
Expand Down
6 changes: 3 additions & 3 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading