Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions docs/guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,10 @@ Key variables to understand protocol behavior:
return HTTP `413`.
- `A2A_SESSION_CACHE_TTL_SECONDS` / `A2A_SESSION_CACHE_MAXSIZE`: session cache
behavior for `(identity, contextId) -> session_id`.
- `A2A_PENDING_SESSION_CLAIM_TTL_SECONDS`: lease duration for pending preferred
session claims before they expire and stop blocking other identities. Must be
greater than `0`. Default: `30` seconds.
- `A2A_INTERRUPT_REQUEST_TTL_SECONDS`: active retention window for the
in-memory interrupt request binding cache used by `a2a.interrupt.*`
interrupt request binding registry used by `a2a.interrupt.*`
callback methods. Default: `10800` seconds (`180` minutes).
- `A2A_INTERRUPT_REQUEST_TOMBSTONE_TTL_SECONDS`: retention window for expired
interrupt tombstones after active TTL has elapsed. During this window,
Expand All @@ -76,6 +78,13 @@ Key variables to understand protocol behavior:
- `A2A_CLIENT_BEARER_TOKEN`: optional bearer token attached to outbound peer
calls made by the embedded A2A client and `a2a_call` tool path.
- `A2A_CLIENT_SUPPORTED_TRANSPORTS`: ordered outbound transport preference list.
- `A2A_TASK_STORE_BACKEND`: task store backend. Supported values: `memory`,
`database`. Default: `memory`.
- `A2A_TASK_STORE_DATABASE_URL`: database URL for the task store. Required when
`A2A_TASK_STORE_BACKEND=database`; settings validation fails if it is unset or
empty. For local persistence, prefer `sqlite+aiosqlite:///./opencode-a2a.db`.
- `A2A_TASK_STORE_TABLE_NAME` / `A2A_TASK_STORE_CREATE_TABLE`: name of the
database task store table (default: `tasks`) and whether to auto-create
database tables on startup (default: `true`).
- Runtime authentication is bearer-token only via `A2A_BEARER_TOKEN`.
- The same outbound client flags are also honored by the server-side embedded
A2A client used for peer calls and `a2a_call` tool execution:
Expand Down Expand Up @@ -157,6 +166,26 @@ OPENCODE_WORKSPACE_ROOT=/abs/path/to/workspace \
opencode-a2a
```

To persist A2A task records across restarts, set the task store backend to
`database` and point it at a SQLite database file:

```bash
OPENCODE_BASE_URL=http://127.0.0.1:4096 \
A2A_BEARER_TOKEN=dev-token \
A2A_TASK_STORE_BACKEND=database \
A2A_TASK_STORE_DATABASE_URL=sqlite+aiosqlite:///./opencode-a2a.db \
opencode-a2a
```

When `A2A_TASK_STORE_BACKEND=database`, the service persists the following state
in the database:

- task records
- session binding / ownership state
- interrupt request bindings and tombstones

In-flight asyncio locks, outbound A2A client caches, and stream-local
aggregation buffers remain process-local runtime state.

## Troubleshooting Provider Auth State

If one deployment works while another fails against the same upstream provider,
Expand Down Expand Up @@ -825,8 +854,8 @@ Notes:

- `request_id` must be a live interrupt request observed from stream metadata
(`metadata.shared.interrupt.request_id`).
- The server keeps an in-memory interrupt binding cache; callbacks with unknown
or expired `request_id` are rejected.
- The server keeps an interrupt binding registry; callbacks with unknown or
expired `request_id` are rejected.
- The registry retention windows are controlled by
`A2A_INTERRUPT_REQUEST_TTL_SECONDS` (default: `10800` seconds / `180`
minutes) and `A2A_INTERRUPT_REQUEST_TOMBSTONE_TTL_SECONDS` (default: `600`
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ classifiers = [
]
dependencies = [
"a2a-sdk==0.3.25",
"aiosqlite>=0.20,<1.0",
"fastapi>=0.110,<1.0",
"httpx>=0.27,<1.0",
"pydantic>=2.6,<3.0",
"pydantic-settings>=2.2,<3.0",
"sqlalchemy>=2.0,<3.0",
"sse-starlette>=2.1,<4.0",
"uvicorn>=0.29,<1.0",
]
Expand Down
29 changes: 29 additions & 0 deletions src/opencode_a2a/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
"custom",
]
OutsideWorkspaceAccess = Literal["unknown", "allowed", "disallowed", "custom"]
TaskStoreBackend = Literal["memory", "database"]


def _parse_declared_list(value: Any) -> tuple[str, ...]:
Expand Down Expand Up @@ -135,6 +136,11 @@ class Settings(BaseSettings):
# Session cache settings
a2a_session_cache_ttl_seconds: int = Field(default=3600, alias="A2A_SESSION_CACHE_TTL_SECONDS")
a2a_session_cache_maxsize: int = Field(default=10_000, alias="A2A_SESSION_CACHE_MAXSIZE")
a2a_pending_session_claim_ttl_seconds: float = Field(
default=30.0,
gt=0.0,
alias="A2A_PENDING_SESSION_CLAIM_TTL_SECONDS",
)
a2a_interrupt_request_ttl_seconds: float = Field(
default=10_800.0,
ge=0.0,
Expand Down Expand Up @@ -176,7 +182,30 @@ class Settings(BaseSettings):
alias="A2A_CLIENT_SUPPORTED_TRANSPORTS",
)

# Task store settings
a2a_task_store_backend: TaskStoreBackend = Field(
default="memory",
alias="A2A_TASK_STORE_BACKEND",
)
a2a_task_store_database_url: str | None = Field(
default=None,
alias="A2A_TASK_STORE_DATABASE_URL",
)
a2a_task_store_table_name: str = Field(
default="tasks",
min_length=1,
alias="A2A_TASK_STORE_TABLE_NAME",
)
a2a_task_store_create_table: bool = Field(
default=True,
alias="A2A_TASK_STORE_CREATE_TABLE",
)

# Cross-field validation that runs after the individual fields have been parsed
# (pydantic "after" model validator). Returns the settings instance unchanged.
@model_validator(mode="after")
def _validate_sandbox_policy(self) -> Settings:
# Build the sandbox policy from these settings and let it validate itself.
SandboxPolicy.from_settings(self).validate_configuration()
# The database backend needs a connection URL; reject a missing or empty one
# here so the misconfiguration is reported as a settings validation error.
if self.a2a_task_store_backend == "database" and not self.a2a_task_store_database_url:
raise ValueError(
"A2A_TASK_STORE_DATABASE_URL is required when A2A_TASK_STORE_BACKEND=database"
)
return self
5 changes: 3 additions & 2 deletions src/opencode_a2a/contracts/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,8 +450,9 @@ def build_session_binding_extension_params(
"message to that upstream session."
),
(
"Otherwise, the server will create a new upstream session and cache "
"the (identity, contextId)->session_id mapping in memory with TTL."
"Otherwise, the server will create a new upstream session and retain "
"the (identity, contextId)->session_id mapping according to the "
"configured task/state store backend and TTL policy."
),
],
}
Expand Down
5 changes: 5 additions & 0 deletions src/opencode_a2a/execution/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

if TYPE_CHECKING:
from ..server.application import A2AClientManager
from ..server.state_store import SessionStateRepository

import httpx
from a2a.server.agent_execution import AgentExecutor, RequestContext
Expand Down Expand Up @@ -530,7 +531,9 @@ def __init__(
cancel_abort_timeout_seconds: float = 2.0,
session_cache_ttl_seconds: int = 3600,
session_cache_maxsize: int = 10_000,
pending_session_claim_ttl_seconds: float = 30.0,
a2a_client_manager: A2AClientManager | None = None,
session_state_repository: SessionStateRepository | None = None,
) -> None:
self._client = client
self._streaming_enabled = streaming_enabled
Expand All @@ -544,6 +547,8 @@ def __init__(
client=client,
session_cache_ttl_seconds=session_cache_ttl_seconds,
session_cache_maxsize=session_cache_maxsize,
pending_session_claim_ttl_seconds=pending_session_claim_ttl_seconds,
state_repository=session_state_repository,
)
self._stream_runtime = StreamRuntime(
client=client,
Expand Down
67 changes: 42 additions & 25 deletions src/opencode_a2a/execution/session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import asyncio

from .stream_state import _TTLCache
from ..server.state_store import MemorySessionStateRepository, SessionStateRepository


class SessionManager:
Expand All @@ -12,18 +12,15 @@ def __init__(
client,
session_cache_ttl_seconds: int = 3600,
session_cache_maxsize: int = 10_000,
pending_session_claim_ttl_seconds: float = 30.0,
state_repository: SessionStateRepository | None = None,
) -> None:
self._client = client
self._sessions = _TTLCache(
self._state_repository = state_repository or MemorySessionStateRepository(
ttl_seconds=session_cache_ttl_seconds,
maxsize=session_cache_maxsize,
pending_claim_ttl_seconds=pending_session_claim_ttl_seconds,
)
self._session_owners = _TTLCache(
ttl_seconds=session_cache_ttl_seconds,
maxsize=session_cache_maxsize,
refresh_on_get=True,
)
self._pending_session_claims: dict[str, str] = {}
self._lock = asyncio.Lock()
self._inflight_session_creates: dict[tuple[str, str], asyncio.Task[str]] = {}
self._session_locks: dict[str, asyncio.Lock] = {}
Expand All @@ -43,13 +40,21 @@ async def get_or_create_session(
session_id=preferred_session_id,
)
if not pending_claim:
self._sessions.set((identity, context_id), preferred_session_id)
async with self._lock:
await self._state_repository.set_session(
identity=identity,
context_id=context_id,
session_id=preferred_session_id,
)
return preferred_session_id, pending_claim

task: asyncio.Task[str] | None = None
cache_key = (identity, context_id)
async with self._lock:
existing = self._sessions.get(cache_key)
existing = await self._state_repository.get_session(
identity=cache_key[0],
context_id=cache_key[1],
)
if existing:
return existing, False
task = self._inflight_session_creates.get(cache_key)
Expand All @@ -68,14 +73,18 @@ async def get_or_create_session(
raise

async with self._lock:
owner = self._session_owners.get(session_id)
owner = await self._state_repository.get_owner(session_id=session_id)
if owner and owner != identity:
if self._inflight_session_creates.get(cache_key) is task:
self._inflight_session_creates.pop(cache_key, None)
raise PermissionError(f"Session {session_id} is not owned by you")
self._sessions.set(cache_key, session_id)
await self._state_repository.set_session(
identity=cache_key[0],
context_id=cache_key[1],
session_id=session_id,
)
if not owner:
self._session_owners.set(session_id, identity)
await self._state_repository.set_owner(session_id=session_id, identity=identity)
if self._inflight_session_creates.get(cache_key) is task:
self._inflight_session_creates.pop(cache_key, None)
return session_id, False
Expand All @@ -89,37 +98,45 @@ async def finalize_preferred_session_binding(
) -> None:
await self.finalize_session_claim(identity=identity, session_id=session_id)
async with self._lock:
self._sessions.set((identity, context_id), session_id)
await self._state_repository.set_session(
identity=identity,
context_id=context_id,
session_id=session_id,
)

async def claim_preferred_session(self, *, identity: str, session_id: str) -> bool:
async with self._lock:
owner = self._session_owners.get(session_id)
pending_owner = self._pending_session_claims.get(session_id)
owner = await self._state_repository.get_owner(session_id=session_id)
pending_owner = await self._state_repository.get_pending_claim(session_id=session_id)
if owner and owner != identity:
raise PermissionError(f"Session {session_id} is not owned by you")
if pending_owner and pending_owner != identity:
raise PermissionError(f"Session {session_id} is not owned by you")
if owner == identity:
return False
self._pending_session_claims[session_id] = identity
await self._state_repository.set_pending_claim(session_id=session_id, identity=identity)
return True

async def finalize_session_claim(self, *, identity: str, session_id: str) -> None:
async with self._lock:
owner = self._session_owners.get(session_id)
pending_owner = self._pending_session_claims.get(session_id)
owner = await self._state_repository.get_owner(session_id=session_id)
pending_owner = await self._state_repository.get_pending_claim(session_id=session_id)
if owner and owner != identity:
raise PermissionError(f"Session {session_id} is not owned by you")
if pending_owner and pending_owner != identity:
raise PermissionError(f"Session {session_id} is not owned by you")
self._session_owners.set(session_id, identity)
if self._pending_session_claims.get(session_id) == identity:
self._pending_session_claims.pop(session_id, None)
await self._state_repository.set_owner(session_id=session_id, identity=identity)
await self._state_repository.clear_pending_claim(
session_id=session_id,
identity=identity,
)

async def release_preferred_session_claim(self, *, identity: str, session_id: str) -> None:
async with self._lock:
if self._pending_session_claims.get(session_id) == identity:
self._pending_session_claims.pop(session_id, None)
await self._state_repository.clear_pending_claim(
session_id=session_id,
identity=identity,
)

async def get_session_lock(self, session_id: str) -> asyncio.Lock:
async with self._lock:
Expand All @@ -136,5 +153,5 @@ async def pop_cached_session(
context_id: str,
) -> asyncio.Task[str] | None:
async with self._lock:
self._sessions.pop((identity, context_id))
await self._state_repository.pop_session(identity=identity, context_id=context_id)
return self._inflight_session_creates.pop((identity, context_id), None)
4 changes: 2 additions & 2 deletions src/opencode_a2a/execution/stream_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def _tool_chunks(
None,
)
if callable(remember_request):
remember_request(
await remember_request(
request_id=request_id,
session_id=session_id,
interrupt_type=asked["interrupt_type"],
Expand All @@ -456,7 +456,7 @@ def _tool_chunks(
None,
)
if callable(discard_request):
discard_request(resolved_request_id)
await discard_request(resolved_request_id)
if cleared_pending:
await _emit_interrupt_status(
state=TaskState.working,
Expand Down
8 changes: 4 additions & 4 deletions src/opencode_a2a/jsonrpc/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ async def _handle_interrupt_callback_request(
)
resolve_request = getattr(self._upstream_client, "resolve_interrupt_request", None)
if callable(resolve_request):
status, binding = resolve_request(request_id)
status, binding = await resolve_request(request_id)
if status != "active" or binding is None:
return self._generate_error_response(
base_request.id,
Expand Down Expand Up @@ -818,7 +818,7 @@ async def _handle_interrupt_callback_request(
else:
resolve_session = getattr(self._upstream_client, "resolve_interrupt_session", None)
if callable(resolve_session):
if not resolve_session(request_id):
if not await resolve_session(request_id):
return self._generate_error_response(
base_request.id,
interrupt_not_found_error(
Expand Down Expand Up @@ -869,7 +869,7 @@ async def _handle_interrupt_callback_request(
await self._upstream_client.question_reject(request_id, directory=directory)
discard_request = getattr(self._upstream_client, "discard_interrupt_request", None)
if callable(discard_request):
discard_request(request_id)
await discard_request(request_id)
except ValueError as exc:
return self._generate_error_response(
base_request.id,
Expand All @@ -880,7 +880,7 @@ async def _handle_interrupt_callback_request(
if upstream_status == 404:
discard_request = getattr(self._upstream_client, "discard_interrupt_request", None)
if callable(discard_request):
discard_request(request_id)
await discard_request(request_id)
return self._generate_error_response(
base_request.id,
interrupt_not_found_error(
Expand Down
Loading
Loading