-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhybrid.zig
115 lines (99 loc) · 4.27 KB
/
hybrid.zig
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
const pg = @import("pg");
const std = @import("std");
/// Owns the decoded JSON body of an embedding API response.
///
/// Every vector handed out by `get` points into memory owned by `parsed`;
/// call `deinit` exactly once when those vectors are no longer needed.
const Embeddings = struct {
    parsed: std.json.Parsed(ApiResponse),

    /// Response body as decoded; it declares only the `embeddings` array
    /// (one vector of f32 components per entry).
    const ApiResponse = struct {
        embeddings: []const []const f32,
    };

    /// Look up the vector at position `index`.
    /// Yields null instead of trapping when `index` is past the end.
    pub fn get(self: *Embeddings, index: usize) ?[]const f32 {
        const vectors = self.parsed.value.embeddings;
        if (index >= vectors.len) return null;
        return vectors[index];
    }

    /// Release the parsed response. Slices obtained from `get` are invalid afterwards.
    pub fn deinit(self: *Embeddings) void {
        self.parsed.deinit();
    }
};
/// Request one embedding per element of `input` from the local Ollama server
/// (`http://localhost:11434/api/embed`, model `nomic-embed-text`).
///
/// `task_type` is the nomic-embed-text task instruction (e.g. "search_document" or
/// "search_query"). When it is non-empty, each text is sent as "<task_type>: <text>".
/// Pass "" to send the texts unchanged.
///
/// The returned `Embeddings` holds the vectors in input order. The caller owns it
/// and must call `deinit`.
/// Errors: `error.EmbeddingRequestFailed` if the server does not answer 200 OK,
/// `error.EmbeddingCountMismatch` if it returns a different number of vectors than
/// inputs, plus any HTTP, JSON or allocation error.
fn embed(allocator: std.mem.Allocator, input: []const []const u8, task_type: []const u8) !Embeddings {
    // Scratch memory for the prefixed request texts; released when this function returns.
    var arena = std.heap.ArenaAllocator.init(allocator);
    defer arena.deinit();

    // nomic-embed-text uses a task prefix
    // https://huggingface.co/nomic-ai/nomic-embed-text-v1.5
    const texts: []const []const u8 = if (task_type.len == 0) input else blk: {
        const scratch = arena.allocator();
        const prefixed = try scratch.alloc([]const u8, input.len);
        for (input, prefixed) |text, *out| {
            out.* = try std.mem.concat(scratch, u8, &.{ task_type, ": ", text });
        }
        break :blk prefixed;
    };

    var client = std.http.Client{ .allocator = allocator };
    defer client.deinit();

    const uri = try std.Uri.parse("http://localhost:11434/api/embed");
    const data = .{
        .input = texts,
        .model = "nomic-embed-text",
    };

    var buf: [16 * 1024]u8 = undefined;
    var req = try client.open(.POST, uri, .{ .server_header_buffer = &buf });
    defer req.deinit();
    req.headers = .{
        .content_type = .{ .override = "application/json" },
    };
    req.transfer_encoding = .chunked;
    try req.send();
    try std.json.stringify(data, .{}, req.writer());
    try req.finish();
    try req.wait();

    // The status comes from an external server, so it must be checked at runtime.
    // std.debug.assert is compiled out (and becomes UB) in release builds.
    if (req.response.status != .ok) return error.EmbeddingRequestFailed;

    var rdr = std.json.reader(allocator, req.reader());
    defer rdr.deinit();
    // `.alloc_always` keeps the result from borrowing the reader's buffer, so it outlives `rdr`.
    const parsed = try std.json.parseFromTokenSource(Embeddings.ApiResponse, allocator, &rdr, .{ .allocate = .alloc_always, .ignore_unknown_fields = true });
    errdefer parsed.deinit();

    // Callers index the result by input position, so the counts must line up.
    if (parsed.value.embeddings.len != input.len) return error.EmbeddingCountMismatch;

    return Embeddings{ .parsed = parsed };
}
/// Hybrid-search demo: store a few documents with their embeddings in Postgres
/// (pgvector + a GIN full-text index), then rank them for a query by fusing a
/// vector-similarity search and a keyword search with Reciprocal Rank Fusion (RRF).
pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}).init;
    defer std.debug.assert(gpa.deinit() == .ok);
    const allocator = gpa.allocator();

    // Connect as the current OS user. A missing $USER is reported as an error
    // instead of a panic from `.?`.
    const username = std.posix.getenv("USER") orelse return error.UserNotSet;
    var pool = try pg.Pool.init(allocator, .{ .auth = .{
        .username = username,
        .database = "pgvector_example",
    } });
    defer pool.deinit();

    const conn = try pool.acquire();
    defer pool.release(conn);

    _ = try conn.exec("CREATE EXTENSION IF NOT EXISTS vector", .{});
    _ = try conn.exec("DROP TABLE IF EXISTS documents", .{});
    _ = try conn.exec("CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding vector(768))", .{});
    _ = try conn.exec("CREATE INDEX ON documents USING GIN (to_tsvector('english', content))", .{});

    const documents = [_][]const u8{ "The dog is barking", "The cat is purring", "The bear is growling" };
    var document_embeddings = try embed(allocator, &documents, "search_document");
    defer document_embeddings.deinit();
    for (documents, 0..) |content, i| {
        // A missing vector would otherwise be bound as SQL NULL and stored silently.
        const embedding = document_embeddings.get(i) orelse return error.MissingEmbedding;
        _ = try conn.exec("INSERT INTO documents (content, embedding) VALUES ($1, $2::float4[])", .{ content, embedding });
    }

    // $1 = query text, $2 = query embedding, $3 = RRF smoothing constant k.
    const sql =
        \\WITH semantic_search AS (
        \\ SELECT id, RANK () OVER (ORDER BY embedding <=> $2::float4[]::vector) AS rank
        \\ FROM documents
        \\ ORDER BY embedding <=> $2::float4[]::vector
        \\ LIMIT 20
        \\),
        \\keyword_search AS (
        \\ SELECT id, RANK () OVER (ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC)
        \\ FROM documents, plainto_tsquery('english', $1) query
        \\ WHERE to_tsvector('english', content) @@ query
        \\ ORDER BY ts_rank_cd(to_tsvector('english', content), query) DESC
        \\ LIMIT 20
        \\)
        \\SELECT
        \\ COALESCE(semantic_search.id, keyword_search.id) AS id,
        \\ COALESCE(1.0 / ($3 + semantic_search.rank), 0.0) +
        \\ COALESCE(1.0 / ($3 + keyword_search.rank), 0.0) AS score
        \\FROM semantic_search
        \\FULL OUTER JOIN keyword_search ON semantic_search.id = keyword_search.id
        \\ORDER BY score DESC
        \\LIMIT 5
    ;

    const query = "growling bear";
    var query_embeddings = try embed(allocator, &[_][]const u8{query}, "search_query");
    defer query_embeddings.deinit();
    // Binding a null vector would make the semantic half of the search match nothing.
    const query_embedding = query_embeddings.get(0) orelse return error.MissingEmbedding;
    const k = 60;

    var result = try conn.query(sql, .{ query, query_embedding, k });
    defer result.deinit();
    while (try result.next()) |row| {
        const id = row.get(i64, 0);
        const score = row.get(f64, 1);
        std.debug.print("document: {d} | RRF score: {d}\n", .{ id, score });
    }
}