feat: Add AgentCore Memory Checkpointer #628
@@ -0,0 +1,29 @@
"""
Constants and exceptions for AgentCore Memory Checkpoint Saver.
"""

EMPTY_CHANNEL_VALUE = "_empty"


class AgentCoreMemoryError(Exception):
    """Base exception for AgentCore Memory errors."""

    pass


class EventDecodingError(AgentCoreMemoryError):
    """Raised when event decoding fails."""

    pass


class InvalidConfigError(AgentCoreMemoryError):
    """Raised when configuration is invalid."""

    pass


class EventNotFoundError(AgentCoreMemoryError):
    """Raised when expected event is not found."""

    pass
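
The hierarchy means callers can catch the base class and still cover every specific failure. A minimal usage sketch (illustration only, not part of the diff):

# Illustration only: one handler covers every library-specific error,
# since the three concrete exceptions all inherit from AgentCoreMemoryError.
try:
    raise EventDecodingError("payload blob was not valid JSON")
except AgentCoreMemoryError as err:
    print(f"AgentCore Memory failure: {err}")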
@@ -0,0 +1,307 @@
"""
Helper classes for AgentCore Memory Checkpoint Saver.
"""

from __future__ import annotations

import base64
import json
import logging
from collections import defaultdict
from datetime import UTC, datetime
from typing import Any, Dict, List, Union

import boto3
from langgraph.checkpoint.base import CheckpointTuple, SerializerProtocol

from langgraph_checkpoint_aws.checkpoint.agentcore_memory.constants import (
    EMPTY_CHANNEL_VALUE,
    EventDecodingError,
)
from langgraph_checkpoint_aws.checkpoint.agentcore_memory.models import (
    ChannelDataEvent,
    CheckpointerConfig,
    CheckpointEvent,
    WriteItem,
    WritesEvent,
)

logger = logging.getLogger(__name__)

# Union type for all events
EventType = Union[CheckpointEvent, ChannelDataEvent, WritesEvent]
class EventSerializer:
    """Handles serialization and deserialization of events stored in AgentCore Memory."""

    def __init__(self, serde: SerializerProtocol):
        self.serde = serde

    def serialize_value(self, value: Any) -> Dict[str, Any]:
        """Serialize a value using the serde protocol."""
        type_tag, binary_data = self.serde.dumps_typed(value)
        return {"type": type_tag, "data": base64.b64encode(binary_data).decode("utf-8")}

    def deserialize_value(self, serialized: Dict[str, Any]) -> Any:
        """Deserialize a value using the serde protocol."""
        try:
            type_tag = serialized["type"]
            binary_data = base64.b64decode(serialized["data"])
            return self.serde.loads_typed((type_tag, binary_data))
        except Exception as e:
            raise EventDecodingError(f"Failed to deserialize value: {e}") from e
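
Because `dumps_typed` yields raw bytes, the base64 envelope is what makes a value safe to embed in a JSON blob. A round-trip sketch, assuming langgraph's `JsonPlusSerializer` as the serde (any `SerializerProtocol` implementation works the same way):

# Round-trip sketch; JsonPlusSerializer is assumed as the concrete serde.
from langgraph.checkpoint.serde.jsonplus import JsonPlusSerializer

serializer = EventSerializer(serde=JsonPlusSerializer())
envelope = serializer.serialize_value({"messages": ["hello"]})
# envelope is JSON-safe: {"type": <type tag>, "data": <base64 string>}
assert serializer.deserialize_value(envelope) == {"messages": ["hello"]}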
    def serialize_event(self, event: EventType) -> str:
        """Serialize an event to a JSON string."""

        # Fallback serializer for any nested Pydantic models
        def custom_serializer(obj):
            if hasattr(obj, "model_dump"):
                return obj.model_dump()
            raise TypeError(f"Object of type {type(obj)} is not JSON serializable")

        # Get the base dictionary
        event_dict = event.model_dump(exclude_none=True)

        # Handle special serialization for specific fields
        if isinstance(event, CheckpointEvent):
            event_dict["checkpoint_data"] = self.serialize_value(event.checkpoint_data)
            event_dict["metadata"] = self.serialize_value(event.metadata)

        elif isinstance(event, ChannelDataEvent):
            if event.value != EMPTY_CHANNEL_VALUE:
                event_dict["value"] = self.serialize_value(event.value)

        elif isinstance(event, WritesEvent):
            # The writes field is already serialized by model_dump();
            # only the value field in each write still needs encoding.
            for write in event_dict["writes"]:
                val = write.get("value", EMPTY_CHANNEL_VALUE)
                write["value"] = self.serialize_value(val)

        return json.dumps(event_dict, default=custom_serializer)

    def deserialize_event(self, data: str) -> EventType:
        """Deserialize a JSON string to an event."""
        try:
            event_dict = json.loads(data)
            event_type = event_dict.get("event_type")

            if event_type == "checkpoint":
                # Deserialize checkpoint data and metadata
                event_dict["checkpoint_data"] = self.deserialize_value(
                    event_dict["checkpoint_data"]
                )
                event_dict["metadata"] = self.deserialize_value(event_dict["metadata"])
                return CheckpointEvent(**event_dict)

            elif event_type == "channel_data":
                # Deserialize the channel value unless it is the empty sentinel
                if "value" in event_dict and isinstance(event_dict["value"], dict):
                    event_dict["value"] = self.deserialize_value(event_dict["value"])
                return ChannelDataEvent(**event_dict)

            elif event_type == "writes":
                # Deserialize write values
                for write in event_dict["writes"]:
                    if isinstance(write["value"], dict):
                        write["value"] = self.deserialize_value(write["value"])
                return WritesEvent(**event_dict)

            else:
                raise EventDecodingError(f"Unknown event type: {event_type}")

        except EventDecodingError:
            # Already a domain error; don't wrap it a second time
            raise
        except json.JSONDecodeError as e:
            raise EventDecodingError(f"Failed to parse JSON: {e}") from e
        except Exception as e:
            raise EventDecodingError(f"Failed to deserialize event: {e}") from e
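
On the wire, each event is one JSON document discriminated by `event_type`. A hypothetical checkpoint envelope, reconstructed from the branches above (the authoritative field list lives in models.py, which is not part of this diff):

# Hypothetical shape of a stored blob; ids and type tags are placeholders.
checkpoint_envelope = {
    "event_type": "checkpoint",  # discriminator read by deserialize_event
    "checkpoint_id": "...",      # elided example id
    "checkpoint_data": {"type": "json", "data": "<base64>"},  # serialize_value output
    "metadata": {"type": "json", "data": "<base64>"},
}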
class CheckpointEventClient:
    """Handles low-level event storage and retrieval from AgentCore Memory for checkpoints."""

    def __init__(self, memory_id: str, serializer: EventSerializer, **boto3_kwargs):
        self.memory_id = memory_id
        self.serializer = serializer
        self.client = boto3.client("bedrock-agentcore", **boto3_kwargs)

    def store_event(self, event: EventType, session_id: str, actor_id: str) -> None:
        """Store an event in AgentCore Memory."""
        serialized = self.serializer.serialize_event(event)

        self.client.create_event(
            memoryId=self.memory_id,
            actorId=actor_id,
            sessionId=session_id,
            eventTimestamp=datetime.now(UTC),
            payload=[{"blob": serialized}],
        )

    def store_events_batch(
        self, events: List[EventType], session_id: str, actor_id: str
    ) -> None:
        """Store multiple events in a single API call to AgentCore Memory."""
        # Serialize all events into payload blobs
        payload = []
        timestamp = datetime.now(UTC)

        for event in events:
            serialized = self.serializer.serialize_event(event)
            payload.append({"blob": serialized})

        # Store all events in a single create_event call
        self.client.create_event(
            memoryId=self.memory_id,
            actorId=actor_id,
            sessionId=session_id,
            eventTimestamp=timestamp,
            payload=payload,
        )
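
Batching matters because each create_event call is a network round trip, and it lets a checkpoint plus its channel data land in AgentCore Memory together. A hypothetical caller (memory id, region, and session/actor ids are placeholders):

# Hypothetical usage; all ids and the region are placeholders.
client = CheckpointEventClient(
    memory_id="mem-123",
    serializer=EventSerializer(serde=JsonPlusSerializer()),
    region_name="us-west-2",  # forwarded to boto3.client via **boto3_kwargs
)
# Given a list `events` of CheckpointEvent / ChannelDataEvent / WritesEvent:
# one round trip for the whole batch instead of len(events) calls.
client.store_events_batch(events, session_id="thread-1", actor_id="agent-1")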
    def get_events(
        self, session_id: str, actor_id: str, max_results: int = 100
    ) -> List[EventType]:
        """Retrieve events from AgentCore Memory."""
        all_events = []
        next_token = None

        while len(all_events) < max_results:
            params = {
                "memoryId": self.memory_id,
                "actorId": actor_id,
                "sessionId": session_id,
                "maxResults": min(100, max_results - len(all_events)),
                "includePayloads": True,
            }

            if next_token:
                params["nextToken"] = next_token

            response = self.client.list_events(**params)

            for event in response.get("events", []):
                for payload_item in event.get("payload", []):
                    blob = payload_item.get("blob")
                    if blob:
                        try:
                            parsed_event = self.serializer.deserialize_event(blob)
                            all_events.append(parsed_event)
                        except EventDecodingError as e:
                            logger.warning(f"Failed to decode event: {e}")

            next_token = response.get("nextToken")
            if not next_token or len(all_events) >= max_results:
                break

        return all_events[:max_results]
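
Pagination stays internal to get_events: each list_events page is clamped so the loop never fetches past max_results, and undecodable blobs are logged and skipped rather than failing the whole read. A hypothetical retrieval:

# Hypothetical retrieval; ids are placeholders.
events = client.get_events(session_id="thread-1", actor_id="agent-1", max_results=250)
checkpoint_events = [e for e in events if isinstance(e, CheckpointEvent)]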
    def delete_events(self, session_id: str, actor_id: str) -> None:
        """Delete all events for a session."""
        params = {
            "memoryId": self.memory_id,
            "actorId": actor_id,
            "sessionId": session_id,
            "maxResults": 100,
            "includePayloads": False,
        }

        while True:
            response = self.client.list_events(**params)
            events = response.get("events", [])

            if not events:
                break

            for event in events:
                self.client.delete_event(
                    memoryId=self.memory_id,
                    sessionId=session_id,
                    eventId=event["eventId"],
                    actorId=actor_id,
                )

            next_token = response.get("nextToken")
            if not next_token:
                break
            params["nextToken"] = next_token
class EventProcessor:
    """Processes events into checkpoint data structures."""

    @staticmethod
    def process_events(
        events: List[EventType],
    ) -> tuple[
        Dict[str, CheckpointEvent],
        Dict[str, List[WriteItem]],
        Dict[tuple[str, str], Any],
    ]:
        """Process events into organized data structures."""
        checkpoints = {}
        writes_by_checkpoint = defaultdict(list)
        channel_data_by_version = {}

        for event in events:
            if isinstance(event, CheckpointEvent):
                checkpoints[event.checkpoint_id] = event

            elif isinstance(event, WritesEvent):
                writes_by_checkpoint[event.checkpoint_id].extend(event.writes)

            elif isinstance(event, ChannelDataEvent):
                if event.value != EMPTY_CHANNEL_VALUE:
                    channel_data_by_version[(event.channel, event.version)] = (
                        event.value
                    )

        return checkpoints, writes_by_checkpoint, channel_data_by_version
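
The three return values index the flat event stream the ways the checkpointer needs them: checkpoints keyed by checkpoint_id, pending writes grouped under the checkpoint they belong to, and channel payloads keyed by (channel, version) for per-checkpoint lookup. A hypothetical pass over retrieved events, assuming LangGraph's time-sortable (UUID6) checkpoint ids:

# Hypothetical wiring of retrieval and processing; ids are placeholders.
events = client.get_events(session_id="thread-1", actor_id="agent-1")
checkpoints, writes_by_ckpt, channel_data = EventProcessor.process_events(events)
# Time-sortable ids mean the lexicographic max is the newest checkpoint.
latest_id = max(checkpoints, default=None)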
    @staticmethod
    def build_checkpoint_tuple(
        checkpoint_event: CheckpointEvent,
        writes: List[WriteItem],
        channel_data: Dict[tuple[str, str], Any],
        config: CheckpointerConfig,
    ) -> CheckpointTuple:
        """Build a CheckpointTuple from processed data."""
        # Build pending writes
        pending_writes = [
            (write.task_id, write.channel, write.value) for write in writes
        ]

        # Build parent config
        parent_config = None
        if checkpoint_event.parent_checkpoint_id:
            parent_config = {
                "configurable": {
                    "thread_id": config.thread_id,
                    "checkpoint_ns": config.checkpoint_ns,
                    "checkpoint_id": checkpoint_event.parent_checkpoint_id,
                }
            }

        # Build checkpoint with channel values
        checkpoint = checkpoint_event.checkpoint_data.copy()
        channel_values = {}

        for channel, version in checkpoint.get("channel_versions", {}).items():
            if (channel, version) in channel_data:
                channel_values[channel] = channel_data[(channel, version)]

        checkpoint["channel_values"] = channel_values

        return CheckpointTuple(
            config={
                "configurable": {
                    "thread_id": config.thread_id,
                    "checkpoint_ns": config.checkpoint_ns,
                    "checkpoint_id": checkpoint_event.checkpoint_id,
                }
            },
            checkpoint=checkpoint,
            metadata=checkpoint_event.metadata,
            parent_config=parent_config,
            pending_writes=pending_writes,
        )
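
Continuing the names from the previous sketch, a hypothetical read path that assembles the tuple LangGraph expects (the CheckpointerConfig construction is assumed; that model lives in models.py, outside this diff):

# Hypothetical end-to-end read built from the helpers above.
if latest_id is not None:
    checkpoint_tuple = EventProcessor.build_checkpoint_tuple(
        checkpoint_event=checkpoints[latest_id],
        writes=writes_by_ckpt.get(latest_id, []),
        channel_data=channel_data,
        config=config,  # assumed CheckpointerConfig with thread_id / checkpoint_ns
    )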