-
Notifications
You must be signed in to change notification settings - Fork 409
feat: Persist RequestList state #1274
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
65ea2fc
52caeec
fa701f8
8a1c928
f601ce3
c825b18
3a7312d
68c4c31
7d1c36a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,46 +1,83 @@ | ||
from __future__ import annotations | ||
|
||
import asyncio | ||
import contextlib | ||
from collections.abc import AsyncIterable, AsyncIterator, Iterable | ||
from typing import TYPE_CHECKING | ||
from logging import getLogger | ||
from typing import Annotated | ||
|
||
from pydantic import BaseModel, ConfigDict, Field | ||
from typing_extensions import override | ||
|
||
from crawlee._request import Request | ||
from crawlee._utils.docs import docs_group | ||
from crawlee.request_loaders._request_loader import RequestLoader | ||
|
||
if TYPE_CHECKING: | ||
from crawlee._request import Request | ||
logger = getLogger(__name__) | ||
|
||
|
||
class RequestListState(BaseModel): | ||
model_config = ConfigDict(populate_by_name=True) | ||
|
||
next_index: Annotated[int, Field(alias='nextIndex')] = 0 | ||
Pijukatel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
next_unique_key: Annotated[str | None, Field(alias='nextUniqueKey')] = None | ||
in_progress: Annotated[set[str], Field(alias='inProgress')] = set() | ||
Comment on lines
+22
to
+24
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Probably not. Sessions and Statistics (other instances of recoverable state) use them too. I have no strong opinion here, if you do, say the word and I'll remove them. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well, so currently somewhere we use them, and somewhere we don't - up to you then. |
||
|
||
|
||
class RequestListData(BaseModel): | ||
requests: Annotated[Iterable[Request], Field()] | ||
|
||
|
||
@docs_group('Classes') | ||
class RequestList(RequestLoader): | ||
"""Represents a (potentially very large) list of URLs to crawl. | ||
|
||
Disclaimer: The `RequestList` class is in its early version and is not fully implemented. It is currently | ||
intended mainly for testing purposes and small-scale projects. The current implementation is only in-memory | ||
storage and is very limited. It will be (re)implemented in the future. For more details, see the GitHub issue: | ||
https://github.com/apify/crawlee-python/issues/99. For production usage we recommend to use the `RequestQueue`. | ||
""" | ||
"""Represents a (potentially very large) list of URLs to crawl.""" | ||
|
||
def __init__( | ||
self, | ||
requests: Iterable[str | Request] | AsyncIterable[str | Request] | None = None, | ||
name: str | None = None, | ||
persist_state_key: str | None = None, | ||
persist_requests_key: str | None = None, | ||
) -> None: | ||
"""Initialize a new instance. | ||
|
||
Args: | ||
requests: The request objects (or their string representations) to be added to the provider. | ||
name: A name of the request list. | ||
persist_state_key: A key for persisting the progress information of the RequestList. | ||
If you do not pass a key but pass a `name`, a key will be derived using the name. | ||
Otherwise, state will not be persisted. | ||
persist_requests_key: A key for persisting the request data loaded from the `requests` iterator. | ||
If specified, the request data will be stored in the KeyValueStore to make sure that they don't change | ||
over time. This is useful if the `requests` iterator pulls the data dynamically. | ||
""" | ||
from crawlee._utils.recoverable_state import RecoverableState | ||
|
||
self._name = name | ||
self._handled_count = 0 | ||
self._assumed_total_count = 0 | ||
|
||
self._in_progress = set[str]() | ||
self._next: Request | None = None | ||
|
||
if persist_state_key is None and name is not None: | ||
persist_state_key = f'SDK_REQUEST_LIST_STATE-{name}' | ||
|
||
self._state = RecoverableState( | ||
default_state=RequestListState(), | ||
persistence_enabled=bool(persist_state_key), | ||
persist_state_key=persist_state_key or '', | ||
logger=logger, | ||
) | ||
|
||
self._persist_request_data = bool(persist_requests_key) | ||
|
||
self._requests_data = RecoverableState( | ||
default_state=RequestListData(requests=[]), | ||
persistence_enabled='explicit_only' if self._persist_request_data else False, | ||
Pijukatel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
persist_state_key=persist_requests_key or '', | ||
logger=logger, | ||
) | ||
|
||
if isinstance(requests, AsyncIterable): | ||
self._requests = requests.__aiter__() | ||
elif requests is None: | ||
|
@@ -50,6 +87,38 @@ def __init__( | |
|
||
self._requests_lock: asyncio.Lock | None = None | ||
|
||
async def _get_state(self) -> RequestListState: | ||
Pijukatel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if self._state.is_initialized: | ||
return self._state.current_value | ||
|
||
await self._state.initialize() | ||
await self._requests_data.initialize() | ||
|
||
if self._persist_request_data: | ||
if self._requests_lock is None: | ||
self._requests_lock = asyncio.Lock() | ||
|
||
async with self._requests_lock: | ||
if not await self._requests_data.has_persisted_state(): | ||
self._requests_data.current_value.requests = [ | ||
request if isinstance(request, Request) else Request.from_url(request) | ||
async for request in self._requests | ||
] | ||
await self._requests_data.persist_state() | ||
|
||
self._requests = self._iterate_in_threadpool(self._requests_data.current_value.requests) | ||
|
||
for _ in range(0, self._state.current_value.next_index): | ||
with contextlib.suppress(StopAsyncIteration): | ||
await self._requests.__anext__() | ||
Pijukatel marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if (unique_key_to_check := self._state.current_value.next_unique_key) is not None: | ||
await self._ensure_next_request() | ||
if self._next is None or self._next.unique_key != unique_key_to_check: | ||
raise RuntimeError() | ||
|
||
return self._state.current_value | ||
|
||
@property | ||
def name(self) -> str | None: | ||
return self._name | ||
|
@@ -65,7 +134,8 @@ async def is_empty(self) -> bool: | |
|
||
@override | ||
async def is_finished(self) -> bool: | ||
return len(self._in_progress) == 0 and await self.is_empty() | ||
state = await self._get_state() | ||
return len(state.in_progress) == 0 and await self.is_empty() | ||
|
||
@override | ||
async def fetch_next_request(self) -> Request | None: | ||
|
@@ -74,7 +144,8 @@ async def fetch_next_request(self) -> Request | None: | |
if self._next is None: | ||
return None | ||
|
||
self._in_progress.add(self._next.id) | ||
state = await self._get_state() | ||
state.in_progress.add(self._next.id) | ||
self._assumed_total_count += 1 | ||
|
||
next_request = self._next | ||
|
@@ -85,13 +156,16 @@ async def fetch_next_request(self) -> Request | None: | |
@override | ||
async def mark_request_as_handled(self, request: Request) -> None: | ||
self._handled_count += 1 | ||
self._in_progress.remove(request.id) | ||
state = await self._get_state() | ||
state.in_progress.remove(request.id) | ||
|
||
@override | ||
async def get_handled_count(self) -> int: | ||
return self._handled_count | ||
|
||
async def _ensure_next_request(self) -> None: | ||
state = await self._get_state() | ||
|
||
if self._requests_lock is None: | ||
self._requests_lock = asyncio.Lock() | ||
|
||
|
@@ -101,6 +175,9 @@ async def _ensure_next_request(self) -> None: | |
self._next = self._transform_request(await self._requests.__anext__()) | ||
except StopAsyncIteration: | ||
self._next = None | ||
else: | ||
state.next_index += 1 | ||
state.next_unique_key = self._next.unique_key | ||
|
||
async def _iterate_in_threadpool(self, iterable: Iterable[str | Request]) -> AsyncIterator[str | Request]: | ||
"""Inspired by a function of the same name from encode/starlette.""" | ||
|
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Uh oh!
There was an error while loading. Please reload this page.