Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions services/api-rs/crates/centaur-api-server/src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ impl IntoResponse for ApiError {
Self::Runtime(SessionRuntimeError::Store(SessionStoreError::HarnessConflict {
..
})) => StatusCode::CONFLICT,
Self::Runtime(SessionRuntimeError::Store(SessionStoreError::PersonaConflict {
..
})) => StatusCode::CONFLICT,
Self::Runtime(_) | Self::Serialize(_) => StatusCode::INTERNAL_SERVER_ERROR,
};
let body = Json(json!({
Expand Down
7 changes: 6 additions & 1 deletion services/api-rs/crates/centaur-api-server/src/routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,12 @@ async fn create_or_get_session(
let thread_key = ThreadKey::try_from(raw_thread_key)?;
let session = state
.runtime
.create_or_get_session(&thread_key, &request.harness_type, request.metadata)
.create_or_get_session(
&thread_key,
&request.harness_type,
request.persona_id.as_deref(),
request.metadata,
)
.await?;
Ok(Json(session))
}
Expand Down
1 change: 1 addition & 0 deletions services/api-rs/crates/centaur-api-server/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use thiserror::Error;
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CreateSessionRequest {
pub harness_type: HarnessType,
pub persona_id: Option<String>,
pub metadata: Option<Value>,
}

Expand Down
1 change: 1 addition & 0 deletions services/api-rs/crates/centaur-session-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ async fn main() -> Result<()> {
&thread_key,
CreateSessionRequest {
harness_type: args.harness_type.into(),
persona_id: None,
metadata: Some(json!({
"source": "centaur-session-cli",
})),
Expand Down
1 change: 1 addition & 0 deletions services/api-rs/crates/centaur-session-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ pub struct Session {
pub sandbox_id: Option<String>,
pub harness_type: HarnessType,
pub harness_thread_id: Option<String>,
pub persona_id: Option<String>,
pub status: SessionStatus,
/// iron-control principal OID this session's egress proxy binds to,
/// captured at registration so a resumed session can recreate its sandbox.
Expand Down
9 changes: 8 additions & 1 deletion services/api-rs/crates/centaur-session-runtime/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ impl SessionRuntime {
&self,
thread_key: &ThreadKey,
harness_type: &HarnessType,
persona_id: Option<&str>,
metadata: Option<Value>,
) -> Result<Session, SessionRuntimeError> {
// Read slack_user_id before `metadata` is consumed below; it keys the
Expand All @@ -170,7 +171,12 @@ impl SessionRuntime {
.map(ToOwned::to_owned);
let session = self
.store
.create_or_get_session(thread_key, harness_type, default_metadata(metadata))
.create_or_get_session(
thread_key,
harness_type,
persona_id,
default_metadata(metadata),
)
.await?;
if let Some(registrar) = &self.iron_control {
// iron-control is the source of truth for the session's egress
Expand Down Expand Up @@ -2580,6 +2586,7 @@ mod tests {
sandbox_id: Some(sandbox_id.to_owned()),
harness_type: HarnessType::Codex,
harness_thread_id: None,
persona_id: None,
status: SessionStatus::Idle,
iron_control_principal: None,
created_at: now,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
alter table sessions
add column if not exists persona_id text;

alter table sessions
add constraint sessions_persona_id_len
check (persona_id is null or octet_length(persona_id) between 1 and 128);
31 changes: 25 additions & 6 deletions services/api-rs/crates/centaur-session-sqlx/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,19 @@ impl PgSessionStore {
&self,
thread_key: &ThreadKey,
harness_type: &HarnessType,
persona_id: Option<&str>,
metadata: Value,
) -> Result<Session, SessionStoreError> {
sqlx::query(
r#"
insert into sessions (thread_key, harness_type, status, metadata)
values ($1, $2, $3, $4)
insert into sessions (thread_key, harness_type, persona_id, status, metadata)
values ($1, $2, $3, $4, $5)
on conflict (thread_key) do nothing
"#,
)
.bind(thread_key.as_str())
.bind(harness_type.as_ref())
.bind(persona_id)
.bind(SessionStatus::Idle.as_ref())
.bind(metadata)
.execute(&self.pool)
Expand All @@ -87,13 +89,20 @@ impl PgSessionStore {
requested: harness_type.as_ref().to_owned(),
});
}
if session.persona_id.as_deref() != persona_id {
return Err(SessionStoreError::PersonaConflict {
thread_key: thread_key.as_str().to_owned(),
existing: session.persona_id,
requested: persona_id.map(str::to_owned),
});
}
Ok(session)
}

pub async fn get_session(&self, thread_key: &ThreadKey) -> Result<Session, SessionStoreError> {
let row = sqlx::query_as::<_, SessionRow>(
r#"
select thread_key, sandbox_id, harness_type, harness_thread_id, status, iron_control_principal, created_at, updated_at
select thread_key, sandbox_id, harness_type, harness_thread_id, persona_id, status, iron_control_principal, created_at, updated_at
from sessions
where thread_key = $1
"#,
Expand Down Expand Up @@ -447,7 +456,7 @@ impl PgSessionStore {
update sessions
set sandbox_id = $2, updated_at = now()
where thread_key = $1
returning thread_key, sandbox_id, harness_type, harness_thread_id, status, iron_control_principal, created_at, updated_at
returning thread_key, sandbox_id, harness_type, harness_thread_id, persona_id, status, iron_control_principal, created_at, updated_at
"#,
)
.bind(thread_key.as_str())
Expand All @@ -468,7 +477,7 @@ impl PgSessionStore {
update sessions
set iron_control_principal = $2, updated_at = now()
where thread_key = $1
returning thread_key, sandbox_id, harness_type, harness_thread_id, status, iron_control_principal, created_at, updated_at
returning thread_key, sandbox_id, harness_type, harness_thread_id, persona_id, status, iron_control_principal, created_at, updated_at
"#,
)
.bind(thread_key.as_str())
Expand Down Expand Up @@ -576,7 +585,7 @@ impl PgSessionStore {
update sessions
set harness_thread_id = $2, updated_at = now()
where thread_key = $1
returning thread_key, sandbox_id, harness_type, harness_thread_id, status, iron_control_principal, created_at, updated_at
returning thread_key, sandbox_id, harness_type, harness_thread_id, persona_id, status, iron_control_principal, created_at, updated_at
"#,
)
.bind(thread_key.as_str())
Expand Down Expand Up @@ -649,6 +658,14 @@ pub enum SessionStoreError {
existing: String,
requested: String,
},
#[error(
"session {thread_key} already exists with persona_id {existing:?}, requested {requested:?}"
)]
PersonaConflict {
thread_key: String,
existing: Option<String>,
requested: Option<String>,
},
#[error("invalid persisted value: {0}")]
InvalidPersistedValue(String),
#[error("invalid notification payload on {channel}: {payload}: {error}")]
Expand All @@ -669,6 +686,7 @@ struct SessionRow {
sandbox_id: Option<String>,
harness_type: String,
harness_thread_id: Option<String>,
persona_id: Option<String>,
status: String,
iron_control_principal: Option<String>,
created_at: OffsetDateTime,
Expand All @@ -684,6 +702,7 @@ impl TryFrom<SessionRow> for Session {
sandbox_id: row.sandbox_id,
harness_type: parse_persisted(row.harness_type)?,
harness_thread_id: row.harness_thread_id,
persona_id: row.persona_id,
status: parse_persisted(row.status)?,
iron_control_principal: row.iron_control_principal,
created_at: row.created_at,
Expand Down
Loading