diff --git a/services/api-rs/crates/centaur-api-server/src/error.rs b/services/api-rs/crates/centaur-api-server/src/error.rs index e0e28b69..3df4754c 100644 --- a/services/api-rs/crates/centaur-api-server/src/error.rs +++ b/services/api-rs/crates/centaur-api-server/src/error.rs @@ -36,6 +36,9 @@ impl IntoResponse for ApiError { Self::Runtime(SessionRuntimeError::Store(SessionStoreError::HarnessConflict { .. })) => StatusCode::CONFLICT, + Self::Runtime(SessionRuntimeError::Store(SessionStoreError::PersonaConflict { + .. + })) => StatusCode::CONFLICT, Self::Runtime(_) | Self::Serialize(_) => StatusCode::INTERNAL_SERVER_ERROR, }; let body = Json(json!({ diff --git a/services/api-rs/crates/centaur-api-server/src/routes.rs b/services/api-rs/crates/centaur-api-server/src/routes.rs index 4f5158ec..56276657 100644 --- a/services/api-rs/crates/centaur-api-server/src/routes.rs +++ b/services/api-rs/crates/centaur-api-server/src/routes.rs @@ -55,7 +55,12 @@ async fn create_or_get_session( let thread_key = ThreadKey::try_from(raw_thread_key)?; let session = state .runtime - .create_or_get_session(&thread_key, &request.harness_type, request.metadata) + .create_or_get_session( + &thread_key, + &request.harness_type, + request.persona_id.as_deref(), + request.metadata, + ) .await?; Ok(Json(session)) } diff --git a/services/api-rs/crates/centaur-api-server/src/types.rs b/services/api-rs/crates/centaur-api-server/src/types.rs index 03dc7731..bac70f3f 100644 --- a/services/api-rs/crates/centaur-api-server/src/types.rs +++ b/services/api-rs/crates/centaur-api-server/src/types.rs @@ -8,6 +8,7 @@ use thiserror::Error; #[derive(Clone, Debug, Deserialize, Serialize)] pub struct CreateSessionRequest { pub harness_type: HarnessType, + pub persona_id: Option, pub metadata: Option, } diff --git a/services/api-rs/crates/centaur-session-cli/src/main.rs b/services/api-rs/crates/centaur-session-cli/src/main.rs index 1642f796..f74ec34b 100644 --- a/services/api-rs/crates/centaur-session-cli/src/main.rs +++ b/services/api-rs/crates/centaur-session-cli/src/main.rs @@ -103,6 +103,7 @@ async fn main() -> Result<()> { &thread_key, CreateSessionRequest { harness_type: args.harness_type.into(), + persona_id: None, metadata: Some(json!({ "source": "centaur-session-cli", })), diff --git a/services/api-rs/crates/centaur-session-core/src/lib.rs b/services/api-rs/crates/centaur-session-core/src/lib.rs index 29a73c2a..d7033162 100644 --- a/services/api-rs/crates/centaur-session-core/src/lib.rs +++ b/services/api-rs/crates/centaur-session-core/src/lib.rs @@ -143,6 +143,7 @@ pub struct Session { pub sandbox_id: Option, pub harness_type: HarnessType, pub harness_thread_id: Option, + pub persona_id: Option, pub status: SessionStatus, /// iron-control principal OID this session's egress proxy binds to, /// captured at registration so a resumed session can recreate its sandbox. diff --git a/services/api-rs/crates/centaur-session-runtime/src/lib.rs b/services/api-rs/crates/centaur-session-runtime/src/lib.rs index e9f98e65..0a42553c 100644 --- a/services/api-rs/crates/centaur-session-runtime/src/lib.rs +++ b/services/api-rs/crates/centaur-session-runtime/src/lib.rs @@ -159,6 +159,7 @@ impl SessionRuntime { &self, thread_key: &ThreadKey, harness_type: &HarnessType, + persona_id: Option<&str>, metadata: Option, ) -> Result { // Read slack_user_id before `metadata` is consumed below; it keys the @@ -170,7 +171,12 @@ impl SessionRuntime { .map(ToOwned::to_owned); let session = self .store - .create_or_get_session(thread_key, harness_type, default_metadata(metadata)) + .create_or_get_session( + thread_key, + harness_type, + persona_id, + default_metadata(metadata), + ) .await?; if let Some(registrar) = &self.iron_control { // iron-control is the source of truth for the session's egress @@ -2580,6 +2586,7 @@ mod tests { sandbox_id: Some(sandbox_id.to_owned()), harness_type: HarnessType::Codex, harness_thread_id: None, + persona_id: None, status: SessionStatus::Idle, iron_control_principal: None, created_at: now, diff --git a/services/api-rs/crates/centaur-session-sqlx/migrations/0006_session_persona_id.sql b/services/api-rs/crates/centaur-session-sqlx/migrations/0006_session_persona_id.sql new file mode 100644 index 00000000..7b57bd8a --- /dev/null +++ b/services/api-rs/crates/centaur-session-sqlx/migrations/0006_session_persona_id.sql @@ -0,0 +1,6 @@ +alter table sessions + add column if not exists persona_id text; + +alter table sessions + add constraint sessions_persona_id_len + check (persona_id is null or octet_length(persona_id) between 1 and 128); diff --git a/services/api-rs/crates/centaur-session-sqlx/src/lib.rs b/services/api-rs/crates/centaur-session-sqlx/src/lib.rs index a46773c7..81cb2efe 100644 --- a/services/api-rs/crates/centaur-session-sqlx/src/lib.rs +++ b/services/api-rs/crates/centaur-session-sqlx/src/lib.rs @@ -63,17 +63,19 @@ impl PgSessionStore { &self, thread_key: &ThreadKey, harness_type: &HarnessType, + persona_id: Option<&str>, metadata: Value, ) -> Result { sqlx::query( r#" - insert into sessions (thread_key, harness_type, status, metadata) - values ($1, $2, $3, $4) + insert into sessions (thread_key, harness_type, persona_id, status, metadata) + values ($1, $2, $3, $4, $5) on conflict (thread_key) do nothing "#, ) .bind(thread_key.as_str()) .bind(harness_type.as_ref()) + .bind(persona_id) .bind(SessionStatus::Idle.as_ref()) .bind(metadata) .execute(&self.pool) @@ -87,13 +89,20 @@ impl PgSessionStore { requested: harness_type.as_ref().to_owned(), }); } + if session.persona_id.as_deref() != persona_id { + return Err(SessionStoreError::PersonaConflict { + thread_key: thread_key.as_str().to_owned(), + existing: session.persona_id, + requested: persona_id.map(str::to_owned), + }); + } Ok(session) } pub async fn get_session(&self, thread_key: &ThreadKey) -> Result { let row = sqlx::query_as::<_, SessionRow>( r#" - select thread_key, sandbox_id, harness_type, harness_thread_id, status, iron_control_principal, created_at, updated_at + select thread_key, sandbox_id, harness_type, harness_thread_id, persona_id, status, iron_control_principal, created_at, updated_at from sessions where thread_key = $1 "#, @@ -447,7 +456,7 @@ impl PgSessionStore { update sessions set sandbox_id = $2, updated_at = now() where thread_key = $1 - returning thread_key, sandbox_id, harness_type, harness_thread_id, status, iron_control_principal, created_at, updated_at + returning thread_key, sandbox_id, harness_type, harness_thread_id, persona_id, status, iron_control_principal, created_at, updated_at "#, ) .bind(thread_key.as_str()) @@ -468,7 +477,7 @@ impl PgSessionStore { update sessions set iron_control_principal = $2, updated_at = now() where thread_key = $1 - returning thread_key, sandbox_id, harness_type, harness_thread_id, status, iron_control_principal, created_at, updated_at + returning thread_key, sandbox_id, harness_type, harness_thread_id, persona_id, status, iron_control_principal, created_at, updated_at "#, ) .bind(thread_key.as_str()) @@ -576,7 +585,7 @@ impl PgSessionStore { update sessions set harness_thread_id = $2, updated_at = now() where thread_key = $1 - returning thread_key, sandbox_id, harness_type, harness_thread_id, status, iron_control_principal, created_at, updated_at + returning thread_key, sandbox_id, harness_type, harness_thread_id, persona_id, status, iron_control_principal, created_at, updated_at "#, ) .bind(thread_key.as_str()) @@ -649,6 +658,14 @@ pub enum SessionStoreError { existing: String, requested: String, }, + #[error( + "session {thread_key} already exists with persona_id {existing:?}, requested {requested:?}" + )] + PersonaConflict { + thread_key: String, + existing: Option, + requested: Option, + }, #[error("invalid persisted value: {0}")] InvalidPersistedValue(String), #[error("invalid notification payload on {channel}: {payload}: {error}")] @@ -669,6 +686,7 @@ struct SessionRow { sandbox_id: Option, harness_type: String, harness_thread_id: Option, + persona_id: Option, status: String, iron_control_principal: Option, created_at: OffsetDateTime, @@ -684,6 +702,7 @@ impl TryFrom for Session { sandbox_id: row.sandbox_id, harness_type: parse_persisted(row.harness_type)?, harness_thread_id: row.harness_thread_id, + persona_id: row.persona_id, status: parse_persisted(row.status)?, iron_control_principal: row.iron_control_principal, created_at: row.created_at,