Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Install build deps
# lance-encoding's build.rs invokes protoc to generate prost
# bindings. `protobuf-compiler` provides the `protoc` binary;
# `libprotobuf-dev` provides the well-known types (e.g.
# google/protobuf/empty.proto) under /usr/include/google/protobuf/
# that lance-encoding's .proto files import. Both are needed.
run: |
sudo apt-get update -qq
sudo apt-get install -y --no-install-recommends \
protobuf-compiler libprotobuf-dev
- uses: dtolnay/rust-toolchain@stable
with:
components: clippy, rustfmt
Expand Down
57 changes: 28 additions & 29 deletions omem-server/src/api/handlers/sharing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -512,32 +512,31 @@ pub async fn batch_share(

use futures::stream::{self, StreamExt};

let results: Vec<(String, Result<Memory, OmemError>)> =
stream::iter(body.memory_ids.into_iter())
.map(|mem_id| {
let source_store = source_store.clone();
let target_store = target_store.clone();
let space_store = state.space_store.clone();
let target_space_id = target_space.id.clone();
let user_id = auth.tenant_id.clone();
let agent_id = agent_id.clone();
async move {
let result = share_single(
&source_store,
&target_store,
&space_store,
&mem_id,
&target_space_id,
&user_id,
&agent_id,
)
.await;
(mem_id, result)
}
})
.buffer_unordered(10)
.collect()
.await;
let results: Vec<(String, Result<Memory, OmemError>)> = stream::iter(body.memory_ids)
.map(|mem_id| {
let source_store = source_store.clone();
let target_store = target_store.clone();
let space_store = state.space_store.clone();
let target_space_id = target_space.id.clone();
let user_id = auth.tenant_id.clone();
let agent_id = agent_id.clone();
async move {
let result = share_single(
&source_store,
&target_store,
&space_store,
&mem_id,
&target_space_id,
&user_id,
&agent_id,
)
.await;
(mem_id, result)
}
})
.buffer_unordered(10)
.collect()
.await;

let mut succeeded = Vec::new();
let mut failed = Vec::new();
Expand Down Expand Up @@ -921,7 +920,7 @@ pub async fn share_all(

use futures::stream::{self, StreamExt};

let results: Vec<(bool, bool)> = stream::iter(filtered_ids.into_iter())
let results: Vec<(bool, bool)> = stream::iter(filtered_ids)
.map(|mem_id| {
let source_store = source_store.clone();
let target_store = target_store.clone();
Expand Down Expand Up @@ -1087,7 +1086,7 @@ pub async fn share_all_to_user(

use futures::stream::{self, StreamExt};

let results: Vec<(bool, bool)> = stream::iter(filtered_ids.into_iter())
let results: Vec<(bool, bool)> = stream::iter(filtered_ids)
.map(|mem_id| {
let source_store = source_store.clone();
let target_store = target_store.clone();
Expand Down Expand Up @@ -1260,7 +1259,7 @@ pub async fn org_publish(

use futures::stream::{self, StreamExt};

let results: Vec<(bool, bool)> = stream::iter(memory_ids.clone().into_iter())
let results: Vec<(bool, bool)> = stream::iter(memory_ids.clone())
.map(|mem_id| {
let source_store = source_store.clone();
let target_store = target_store.clone();
Expand Down
6 changes: 3 additions & 3 deletions omem-server/src/api/handlers/stats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ pub async fn get_tags(
.into_iter()
.filter(|(_, c)| *c >= params.min_count)
.collect();
tags.sort_by(|a, b| b.1.cmp(&a.1));
tags.sort_by_key(|t| std::cmp::Reverse(t.1));
let total_unique = tags.len();
tags.truncate(params.limit);

Expand Down Expand Up @@ -554,7 +554,7 @@ pub async fn get_spaces_stats(
}

let mut top_categories: Vec<(String, usize)> = category_counts.into_iter().collect();
top_categories.sort_by(|a, b| b.1.cmp(&a.1));
top_categories.sort_by_key(|c| std::cmp::Reverse(c.1));
top_categories.truncate(3);

space_stats.push(serde_json::json!({
Expand Down Expand Up @@ -758,7 +758,7 @@ pub async fn get_agents_stats(
}

let mut top_categories: Vec<(String, usize)> = category_counts.into_iter().collect();
top_categories.sort_by(|a, b| b.1.cmp(&a.1));
top_categories.sort_by_key(|c| std::cmp::Reverse(c.1));
top_categories.truncate(3);

let share_count = agent_share_counts.get(agent_id).copied().unwrap_or(0);
Expand Down
25 changes: 20 additions & 5 deletions omem-server/src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1327,7 +1327,10 @@ mod tests {
.oneshot(
Request::builder()
.method("POST")
.uri(format!("/v1/spaces/{}/members", utf8_percent_encode(&space_id, NON_ALPHANUMERIC)))
.uri(format!(
"/v1/spaces/{}/members",
utf8_percent_encode(&space_id, NON_ALPHANUMERIC)
))
.header("content-type", "application/json")
.header("x-api-key", &api_key)
.body(Body::from(r#"{"user_id":"bob","role":"member"}"#))
Expand Down Expand Up @@ -1386,7 +1389,10 @@ mod tests {
.oneshot(
Request::builder()
.method("DELETE")
.uri(format!("/v1/spaces/{}/members/alice", utf8_percent_encode(&space_id, NON_ALPHANUMERIC)))
.uri(format!(
"/v1/spaces/{}/members/alice",
utf8_percent_encode(&space_id, NON_ALPHANUMERIC)
))
.header("x-api-key", &api_key)
.body(Body::empty())
.expect("request"),
Expand Down Expand Up @@ -1518,7 +1524,10 @@ mod tests {
.oneshot(
Request::builder()
.method("DELETE")
.uri(format!("/v1/spaces/{}", utf8_percent_encode(&space_id, NON_ALPHANUMERIC)))
.uri(format!(
"/v1/spaces/{}",
utf8_percent_encode(&space_id, NON_ALPHANUMERIC)
))
.header("x-api-key", &api_key)
.body(Body::empty())
.expect("request"),
Expand All @@ -1532,7 +1541,10 @@ mod tests {
.clone()
.oneshot(
Request::builder()
.uri(format!("/v1/spaces/{}", utf8_percent_encode(&space_id, NON_ALPHANUMERIC)))
.uri(format!(
"/v1/spaces/{}",
utf8_percent_encode(&space_id, NON_ALPHANUMERIC)
))
.header("x-api-key", &api_key)
.body(Body::empty())
.expect("request"),
Expand Down Expand Up @@ -1578,7 +1590,10 @@ mod tests {
.oneshot(
Request::builder()
.method("PUT")
.uri(format!("/v1/spaces/{}/members/carol", utf8_percent_encode(&space_id, NON_ALPHANUMERIC)))
.uri(format!(
"/v1/spaces/{}/members/carol",
utf8_percent_encode(&space_id, NON_ALPHANUMERIC)
))
.header("content-type", "application/json")
.header("x-api-key", &api_key)
.body(Body::from(r#"{"role":"admin"}"#))
Expand Down
5 changes: 4 additions & 1 deletion omem-server/src/retrieve/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1184,7 +1184,10 @@ mod tests {
];

let (result, _) = RetrievalPipeline::stage_rrf_normalize(entries);
let top = result.iter().find(|e| e.memory.content == "weak-top").unwrap();
let top = result
.iter()
.find(|e| e.memory.content == "weak-top")
.unwrap();
assert!(
top.rrf_score < 0.25,
"weak top result should stay below 0.25, got {}",
Expand Down
Loading