Skip to content

Commit 49deff4

Browse files
committed
Added hybrid search example [skip ci]
1 parent 06e56df commit 49deff4

File tree

3 files changed

+109
-0
lines changed

3 files changed

+109
-0
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ Or check out some examples:
3737

3838
- [Embeddings](examples/openai/example.cpp) with OpenAI
3939
- [Binary embeddings](examples/cohere/example.cpp) with Cohere
40+
- [Hybrid search](examples/hybrid/example.cpp) with llama.cpp (Reciprocal Rank Fusion)
4041
- [Sparse search](examples/sparse/example.cpp) with Text Embeddings Inference
4142
- [Morgan fingerprints](examples/rdkit/example.cpp) with RDKit
4243
- [Recommendations](examples/disco/example.cpp) with Disco

examples/hybrid/CMakeLists.txt

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Build script for the hybrid search example.
# Fetches cpr (HTTP client), nlohmann/json and libpqxx, then builds example.cpp
# against the pgvector headers located two directories up in include/.
cmake_minimum_required(VERSION 3.18)

project(example)

set(CMAKE_CXX_STANDARD 17)
# Fail at configure time instead of silently falling back to an older standard
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# for libpqxx
# Append so that flags supplied by the user or toolchain are not discarded
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unknown-attributes")
set(SKIP_BUILD_TEST ON)

include(FetchContent)

FetchContent_Declare(cpr GIT_REPOSITORY https://github.com/libcpr/cpr.git GIT_TAG 1.11.1)
FetchContent_Declare(json GIT_REPOSITORY https://github.com/nlohmann/json.git GIT_TAG v3.11.3)
FetchContent_Declare(libpqxx GIT_REPOSITORY https://github.com/jtv/libpqxx.git GIT_TAG 7.10.0)
FetchContent_MakeAvailable(cpr json libpqxx)

add_executable(example example.cpp)
target_include_directories(example PRIVATE ${CMAKE_SOURCE_DIR}/../../include)
target_link_libraries(example PRIVATE cpr::cpr libpqxx::pqxx nlohmann_json::nlohmann_json)

examples/hybrid/example.cpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
// run with
2+
// llama-server -hf nomic-ai/nomic-embed-text-v1.5-GGUF --embedding --pooling mean
3+
4+
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

#include <cpr/cpr.h>
#include <nlohmann/json.hpp>
#include <pgvector/pqxx.hpp>
#include <pqxx/pqxx>
11+
12+
using json = nlohmann::json;
13+
14+
std::vector<std::vector<float>> fetch_embeddings(const std::vector<std::string>& input) {
15+
std::string url = "http://localhost:8080/v1/embeddings";
16+
json data = {
17+
{"input", input}
18+
};
19+
20+
cpr::Response r = cpr::Post(
21+
cpr::Url{url},
22+
cpr::Body{data.dump()},
23+
cpr::Header{{"Content-Type", "application/json"}}
24+
);
25+
json response = json::parse(r.text);
26+
27+
std::vector<std::vector<float>> embeddings;
28+
for (auto& v: response["data"]) {
29+
embeddings.emplace_back(v["embedding"]);
30+
}
31+
return embeddings;
32+
}
33+
34+
int main() {
35+
pqxx::connection conn("dbname=pgvector_example");
36+
37+
pqxx::work tx(conn);
38+
tx.exec("CREATE EXTENSION IF NOT EXISTS vector");
39+
tx.exec("DROP TABLE IF EXISTS documents");
40+
tx.exec("CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding vector(768))");
41+
tx.exec("CREATE INDEX ON documents USING GIN (to_tsvector('english', content))");
42+
tx.commit();
43+
44+
std::vector<std::string> input = {
45+
"The dog is barking",
46+
"The cat is purring",
47+
"The bear is growling"
48+
};
49+
auto embeddings = fetch_embeddings(input);
50+
51+
for (size_t i = 0; i < input.size(); i++) {
52+
tx.exec("INSERT INTO documents (content, embedding) VALUES ($1, $2)", pqxx::params{input[i], pgvector::Vector(embeddings[i])});
53+
}
54+
tx.commit();
55+
56+
std::string sql = R"(
57+
WITH semantic_search AS (
58+
SELECT id, RANK () OVER (ORDER BY embedding <=> $2) AS rank
59+
FROM documents
60+
ORDER BY embedding <=> $2
61+
LIMIT 20
62+
),
63+
keyword_search AS (
64+
SELECT id, RANK () OVER (ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC)
65+
FROM documents, plainto_tsquery('english', $1) query
66+
WHERE to_tsvector('english', content) @@ query
67+
ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC
68+
LIMIT 20
69+
)
70+
SELECT
71+
COALESCE(semantic_search.id, keyword_search.id) AS id,
72+
COALESCE(1.0 / ($3 + semantic_search.rank), 0.0) +
73+
COALESCE(1.0 / ($3 + keyword_search.rank), 0.0) AS score
74+
FROM semantic_search
75+
FULL OUTER JOIN keyword_search ON semantic_search.id = keyword_search.id
76+
ORDER BY score DESC
77+
LIMIT 5
78+
)";
79+
std::string query = "growling bear";
80+
auto query_embedding = fetch_embeddings({query})[0];
81+
double k = 60;
82+
pqxx::result result = tx.exec(sql, pqxx::params{query, pgvector::Vector(query_embedding), k});
83+
for (const auto& row : result) {
84+
std::cout << "document: " << row[0].as<std::string>() << ", RRF score: " << row[1].as<double>() << std::endl;
85+
}
86+
87+
return 0;
88+
}

0 commit comments

Comments
 (0)