Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion omem-server/src/api/handlers/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ pub struct CreateMemoryBody {
#[serde(default)]
pub tags: Option<Vec<String>>,
pub source: Option<String>,
pub memory_type: Option<String>,
}

#[derive(Deserialize)]
Expand Down Expand Up @@ -95,6 +96,7 @@ pub struct UpdateMemoryBody {
pub content: Option<String>,
pub tags: Option<Vec<String>>,
pub state: Option<String>,
pub memory_type: Option<String>,
}

#[derive(Serialize)]
Expand Down Expand Up @@ -201,10 +203,15 @@ pub async fn create_memory(
return Err(OmemError::Validation("content cannot be empty".to_string()));
}

let memory_type = match body.memory_type {
Some(s) => s.parse().map_err(OmemError::Validation)?,
None => MemoryType::Pinned,
};

let mut memory = Memory::new(
&content,
Category::Preferences,
MemoryType::Pinned,
memory_type,
&auth.tenant_id,
);
memory.tags = body.tags.unwrap_or_default();
Expand Down Expand Up @@ -525,6 +532,12 @@ pub async fn update_memory(
.map_err(|e: String| OmemError::Validation(e))?;
}

if let Some(memory_type_str) = body.memory_type {
memory.memory_type = memory_type_str
.parse()
.map_err(|e: String| OmemError::Validation(e))?;
}

memory.updated_at = chrono::Utc::now().to_rfc3339();

let vector = if need_reembed {
Expand Down
139 changes: 139 additions & 0 deletions omem-server/src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,145 @@ mod tests {
assert_eq!(json["tags"][0], "new-tag");
}

#[tokio::test]
async fn test_create_memory_with_type() {
let (app, _dir) = setup_app().await;
let api_key = create_test_tenant(&app).await;

// Create with explicit memory_type=insight.
let create_resp = app
.clone()
.oneshot(
Request::builder()
.method("POST")
.uri("/v1/memories")
.header("content-type", "application/json")
.header("x-api-key", &api_key)
.body(Body::from(r#"{"content":"an insight","memory_type":"insight"}"#))
.expect("request"),
)
.await
.expect("response");
let bytes = create_resp.into_body().collect().await.expect("body").to_bytes();
let created: serde_json::Value = serde_json::from_slice(&bytes).expect("json");
assert_eq!(created["memory_type"], "insight");

// Default (no memory_type) still becomes pinned for backwards compat.
let default_resp = app
.clone()
.oneshot(
Request::builder()
.method("POST")
.uri("/v1/memories")
.header("content-type", "application/json")
.header("x-api-key", &api_key)
.body(Body::from(r#"{"content":"default"}"#))
.expect("request"),
)
.await
.expect("response");
let bytes = default_resp.into_body().collect().await.expect("body").to_bytes();
let default_created: serde_json::Value = serde_json::from_slice(&bytes).expect("json");
assert_eq!(default_created["memory_type"], "pinned");
}

#[tokio::test]
async fn test_update_memory_type() {
let (app, _dir) = setup_app().await;
let api_key = create_test_tenant(&app).await;

// Create a pinned memory (default).
let create_resp = app
.clone()
.oneshot(
Request::builder()
.method("POST")
.uri("/v1/memories")
.header("content-type", "application/json")
.header("x-api-key", &api_key)
.body(Body::from(r#"{"content":"originally pinned"}"#))
.expect("request"),
)
.await
.expect("response");

let bytes = create_resp
.into_body()
.collect()
.await
.expect("body")
.to_bytes();
let created: serde_json::Value = serde_json::from_slice(&bytes).expect("json");
let memory_id = created["id"].as_str().expect("id");
assert_eq!(created["memory_type"], "pinned");

// Demote to insight.
let update_resp = app
.clone()
.oneshot(
Request::builder()
.method("PUT")
.uri(format!("/v1/memories/{memory_id}"))
.header("content-type", "application/json")
.header("x-api-key", &api_key)
.body(Body::from(r#"{"memory_type":"insight"}"#))
.expect("request"),
)
.await
.expect("response");

assert_eq!(update_resp.status(), StatusCode::OK);

let bytes = update_resp
.into_body()
.collect()
.await
.expect("body")
.to_bytes();
let json: serde_json::Value = serde_json::from_slice(&bytes).expect("json");
assert_eq!(json["memory_type"], "insight");
}

#[tokio::test]
async fn test_update_memory_type_invalid() {
let (app, _dir) = setup_app().await;
let api_key = create_test_tenant(&app).await;

let create_resp = app
.clone()
.oneshot(
Request::builder()
.method("POST")
.uri("/v1/memories")
.header("content-type", "application/json")
.header("x-api-key", &api_key)
.body(Body::from(r#"{"content":"test"}"#))
.expect("request"),
)
.await
.expect("response");
let bytes = create_resp.into_body().collect().await.expect("body").to_bytes();
let created: serde_json::Value = serde_json::from_slice(&bytes).expect("json");
let memory_id = created["id"].as_str().expect("id");

// Unknown type should fail validation, not silently accept.
let update_resp = app
.clone()
.oneshot(
Request::builder()
.method("PUT")
.uri(format!("/v1/memories/{memory_id}"))
.header("content-type", "application/json")
.header("x-api-key", &api_key)
.body(Body::from(r#"{"memory_type":"bogus"}"#))
.expect("request"),
)
.await
.expect("response");

assert_eq!(update_resp.status(), StatusCode::BAD_REQUEST);
}

#[tokio::test]
async fn test_search_memories() {
let (app, _dir) = setup_app().await;
Expand Down
Loading