Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion omem-server/src/api/handlers/imports.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,9 @@ pub async fn cross_reconcile(
};

// limit=6 to account for self appearing in results
let similar = store.vector_search(query_vec, 6, 0.85, None, None).await?;
let similar = store
.vector_search(query_vec, 6, 0.85, None, None, false)
.await?;

for (candidate, score) in &similar {
if candidate.id == memory.id {
Expand Down
24 changes: 23 additions & 1 deletion omem-server/src/api/handlers/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ pub struct CreateMemoryBody {
pub tags: Option<Vec<String>>,
pub source: Option<String>,
pub memory_type: Option<String>,
/// IDs of memories to mark superseded by this one. Used when consolidating
/// fragmented memories (e.g., chunked old-embedder content) into a single
/// new memory in one atomic call.
#[serde(default)]
pub replaces: Option<Vec<String>>,
}

#[derive(Deserialize)]
Expand All @@ -61,6 +66,8 @@ pub struct SearchQuery {
pub agent_id: Option<String>,
#[serde(default)]
pub check_stale: bool,
#[serde(default)]
pub include_superseded: bool,
}

fn default_limit() -> usize {
Expand All @@ -82,6 +89,8 @@ pub struct ListQuery {
pub sort: String,
#[serde(default = "default_order")]
pub order: String,
#[serde(default)]
pub include_superseded: bool,
}

fn default_sort() -> String {
Expand Down Expand Up @@ -228,7 +237,16 @@ pub async fn create_memory(
.map_err(|e| OmemError::Embedding(format!("failed to embed content: {e}")))?;
let vector = vectors.into_iter().next();

store.create(&memory, vector.as_deref()).await?;
match body.replaces.as_deref() {
Some(ids) if !ids.is_empty() => {
store
.supersede_batch(&memory, vector.as_deref(), ids)
.await?;
}
_ => {
store.create(&memory, vector.as_deref()).await?;
}
}

// Fire-and-forget: check auto-share rules for the newly created memory
{
Expand Down Expand Up @@ -303,6 +321,7 @@ pub async fn search_memories(
.map(|t| t.split(',').map(|s| s.trim().to_string()).collect()),
source_filter: params.source.clone(),
agent_id_filter: params.agent_id.clone(),
include_superseded: params.include_superseded,
};

let retrieval_pipeline = RetrievalPipeline::new(store);
Expand Down Expand Up @@ -364,6 +383,7 @@ pub async fn search_memories(
});
let source_filter = params.source.clone();
let agent_id_filter = params.agent_id.clone();
let include_superseded = params.include_superseded;
let store = acc.store.clone();
let space_id = acc.space_id.clone();
let weight = acc.weight;
Expand All @@ -380,6 +400,7 @@ pub async fn search_memories(
tags_filter,
source_filter,
agent_id_filter,
include_superseded,
};
let pipeline = RetrievalPipeline::new(store);
let result = pipeline.search(&request).await;
Expand Down Expand Up @@ -623,6 +644,7 @@ pub async fn list_memories(
.map(|t| t.split(',').map(|s| s.trim().to_string()).collect()),
memory_type: params.memory_type,
state: params.state,
include_superseded: params.include_superseded,
sort: params.sort,
order: params.order,
};
Expand Down
6 changes: 3 additions & 3 deletions omem-server/src/api/handlers/sharing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1591,7 +1591,7 @@ mod tests {
target_store.create(&copy, None).await.expect("batch share");
}

let team_list = target_store.list(100, 0).await.expect("list");
let team_list = target_store.list(100, 0, false).await.expect("list");
assert_eq!(team_list.len(), 3);
}

Expand Down Expand Up @@ -1638,7 +1638,7 @@ mod tests {
.get_store("team:backend")
.await
.expect("team store");
let team_list = team_store.list(100, 0).await.expect("list");
let team_list = team_store.list(100, 0, false).await.expect("list");
assert_eq!(team_list.len(), 1);
assert_eq!(team_list[0].content, "prefers vim keybindings");
}
Expand Down Expand Up @@ -1907,7 +1907,7 @@ mod tests {
.expect("share single");
}

let team_list = target_store.list(100, 0).await.expect("list");
let team_list = target_store.list(100, 0, false).await.expect("list");
assert_eq!(team_list.len(), 1);
assert!(team_list[0].content.contains("dark mode"));
}
Expand Down
230 changes: 225 additions & 5 deletions omem-server/src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,211 @@ mod tests {
assert_eq!(update_resp.status(), StatusCode::BAD_REQUEST);
}

async fn create_test_memory(app: &axum::Router, api_key: &str, content: &str) -> String {
let body = format!(r#"{{"content":"{content}"}}"#);
let resp = app
.clone()
.oneshot(
Request::builder()
.method("POST")
.uri("/v1/memories")
.header("content-type", "application/json")
.header("x-api-key", api_key)
.body(Body::from(body))
.expect("request"),
)
.await
.expect("response");
let bytes = resp.into_body().collect().await.expect("body").to_bytes();
let v: serde_json::Value = serde_json::from_slice(&bytes).expect("json");
v["id"].as_str().expect("id").to_string()
}

#[tokio::test]
async fn test_create_with_replaces_supersedes_old() {
let (app, _dir) = setup_app().await;
let api_key = create_test_tenant(&app).await;

let old1 = create_test_memory(&app, &api_key, "fragment one").await;
let old2 = create_test_memory(&app, &api_key, "fragment two").await;

let body = format!(r#"{{"content":"consolidated","replaces":["{old1}","{old2}"]}}"#);
let create_resp = app
.clone()
.oneshot(
Request::builder()
.method("POST")
.uri("/v1/memories")
.header("content-type", "application/json")
.header("x-api-key", &api_key)
.body(Body::from(body))
.expect("request"),
)
.await
.expect("response");
assert_eq!(create_resp.status(), StatusCode::CREATED);

let bytes = create_resp
.into_body()
.collect()
.await
.expect("body")
.to_bytes();
let new_mem: serde_json::Value = serde_json::from_slice(&bytes).expect("json");
let new_id = new_mem["id"].as_str().expect("id");

// Old memories should now be in superseded state via direct fetch.
for old_id in [&old1, &old2] {
let get_resp = app
.clone()
.oneshot(
Request::builder()
.uri(format!("/v1/memories/{old_id}"))
.header("x-api-key", &api_key)
.body(Body::empty())
.expect("request"),
)
.await
.expect("response");
assert_eq!(get_resp.status(), StatusCode::OK);
let bytes = get_resp
.into_body()
.collect()
.await
.expect("body")
.to_bytes();
let old: serde_json::Value = serde_json::from_slice(&bytes).expect("json");
assert_eq!(old["state"], "superseded");
assert_eq!(old["superseded_by"], new_id);
}
}

#[tokio::test]
async fn test_create_with_missing_replaces_returns_400() {
let (app, _dir) = setup_app().await;
let api_key = create_test_tenant(&app).await;

let real_id = create_test_memory(&app, &api_key, "real one").await;

// One real id, one ghost id.
let body =
format!(r#"{{"content":"consolidated","replaces":["{real_id}","ghost-id-nope"]}}"#);
let resp = app
.clone()
.oneshot(
Request::builder()
.method("POST")
.uri("/v1/memories")
.header("content-type", "application/json")
.header("x-api-key", &api_key)
.body(Body::from(body))
.expect("request"),
)
.await
.expect("response");
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);

let bytes = resp.into_body().collect().await.expect("body").to_bytes();
let err: serde_json::Value = serde_json::from_slice(&bytes).expect("json");
let msg = err["error"]["message"].as_str().unwrap_or_default();
assert!(
msg.contains("ghost-id-nope"),
"error should name the missing id, got: {msg}"
);

// The real memory must NOT have been touched.
let get_resp = app
.clone()
.oneshot(
Request::builder()
.uri(format!("/v1/memories/{real_id}"))
.header("x-api-key", &api_key)
.body(Body::empty())
.expect("request"),
)
.await
.expect("response");
let bytes = get_resp
.into_body()
.collect()
.await
.expect("body")
.to_bytes();
let v: serde_json::Value = serde_json::from_slice(&bytes).expect("json");
assert_eq!(v["state"], "active");
}

#[tokio::test]
async fn test_search_excludes_superseded_by_default() {
let (app, _dir) = setup_app().await;
let api_key = create_test_tenant(&app).await;

let old = create_test_memory(&app, &api_key, "fragment about rust programming").await;
let body =
format!(r#"{{"content":"unified rust programming notes","replaces":["{old}"]}}"#);
app.clone()
.oneshot(
Request::builder()
.method("POST")
.uri("/v1/memories")
.header("content-type", "application/json")
.header("x-api-key", &api_key)
.body(Body::from(body))
.expect("request"),
)
.await
.expect("response");

// Default search should NOT surface the superseded old fragment.
let resp = app
.clone()
.oneshot(
Request::builder()
.uri("/v1/memories/search?q=rust&limit=20")
.header("x-api-key", &api_key)
.body(Body::empty())
.expect("request"),
)
.await
.expect("response");
let bytes = resp.into_body().collect().await.expect("body").to_bytes();
let r: serde_json::Value = serde_json::from_slice(&bytes).expect("json");
let results = r["results"].as_array().expect("results");
for res in results {
let id = res["memory"]["id"].as_str().unwrap_or_default();
assert_ne!(id, old.as_str(), "superseded memory should not appear");
}

// With include_superseded=true, the old should resurface.
let resp_inc = app
.clone()
.oneshot(
Request::builder()
.uri("/v1/memories/search?q=rust&limit=20&include_superseded=true")
.header("x-api-key", &api_key)
.body(Body::empty())
.expect("request"),
)
.await
.expect("response");
let bytes = resp_inc
.into_body()
.collect()
.await
.expect("body")
.to_bytes();
let r: serde_json::Value = serde_json::from_slice(&bytes).expect("json");
let results = r["results"].as_array().expect("results");
let ids: Vec<&str> = results
.iter()
.map(|res| res["memory"]["id"].as_str().unwrap_or_default())
.collect();
assert!(
ids.contains(&old.as_str()),
"old should resurface with include_superseded=true, got ids: {ids:?}"
);
}

#[tokio::test]
async fn test_search_memories() {
let (app, _dir) = setup_app().await;
Expand Down Expand Up @@ -1327,7 +1532,10 @@ mod tests {
.oneshot(
Request::builder()
.method("POST")
.uri(format!("/v1/spaces/{}/members", utf8_percent_encode(&space_id, NON_ALPHANUMERIC)))
.uri(format!(
"/v1/spaces/{}/members",
utf8_percent_encode(&space_id, NON_ALPHANUMERIC)
))
.header("content-type", "application/json")
.header("x-api-key", &api_key)
.body(Body::from(r#"{"user_id":"bob","role":"member"}"#))
Expand Down Expand Up @@ -1386,7 +1594,10 @@ mod tests {
.oneshot(
Request::builder()
.method("DELETE")
.uri(format!("/v1/spaces/{}/members/alice", utf8_percent_encode(&space_id, NON_ALPHANUMERIC)))
.uri(format!(
"/v1/spaces/{}/members/alice",
utf8_percent_encode(&space_id, NON_ALPHANUMERIC)
))
.header("x-api-key", &api_key)
.body(Body::empty())
.expect("request"),
Expand Down Expand Up @@ -1518,7 +1729,10 @@ mod tests {
.oneshot(
Request::builder()
.method("DELETE")
.uri(format!("/v1/spaces/{}", utf8_percent_encode(&space_id, NON_ALPHANUMERIC)))
.uri(format!(
"/v1/spaces/{}",
utf8_percent_encode(&space_id, NON_ALPHANUMERIC)
))
.header("x-api-key", &api_key)
.body(Body::empty())
.expect("request"),
Expand All @@ -1532,7 +1746,10 @@ mod tests {
.clone()
.oneshot(
Request::builder()
.uri(format!("/v1/spaces/{}", utf8_percent_encode(&space_id, NON_ALPHANUMERIC)))
.uri(format!(
"/v1/spaces/{}",
utf8_percent_encode(&space_id, NON_ALPHANUMERIC)
))
.header("x-api-key", &api_key)
.body(Body::empty())
.expect("request"),
Expand Down Expand Up @@ -1578,7 +1795,10 @@ mod tests {
.oneshot(
Request::builder()
.method("PUT")
.uri(format!("/v1/spaces/{}/members/carol", utf8_percent_encode(&space_id, NON_ALPHANUMERIC)))
.uri(format!(
"/v1/spaces/{}/members/carol",
utf8_percent_encode(&space_id, NON_ALPHANUMERIC)
))
.header("content-type", "application/json")
.header("x-api-key", &api_key)
.body(Body::from(r#"{"role":"admin"}"#))
Expand Down
Loading
Loading