Skip to content

Commit 0012b62

Browse files
committed
Added from_f32_slice function to HalfVector
1 parent 84103ca commit 0012b62

File tree

6 files changed

+35
-87
lines changed

6 files changed

+35
-87
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
## 0.4.1 (unreleased)
2+
3+
- Added `from_f32_slice` function to `HalfVector`
4+
15
## 0.4.0 (2024-07-28)
26

37
- Added support for SQLx 0.8

README.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ let slice = vec.as_slice();
287287

288288
Note: Use the `halfvec` feature to enable half vectors
289289

290-
Create a half vector
290+
Create a half vector from a `Vec<f16>`
291291

292292
```rust
293293
use half::f16;
@@ -296,6 +296,12 @@ use pgvector::HalfVector;
296296
let vec = HalfVector::from(vec![f16::from_f32(1.0), f16::from_f32(2.0), f16::from_f32(3.0)]);
297297
```
298298

299+
or a `f32` slice [unreleased]
300+
301+
```rust
302+
let vec = HalfVector::from_f32_slice(&[1.0, 2.0, 3.0]);
303+
```
304+
299305
Convert to a `Vec<f16>`
300306

301307
```rust

src/diesel_ext/halfvec.rs

Lines changed: 11 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@ impl FromSql<HalfVectorType, Pg> for HalfVector {
3636
mod tests {
3737
use crate::{HalfVector, VectorExpressionMethods};
3838
use diesel::prelude::*;
39-
use half::f16;
4039

4140
table! {
4241
use diesel::sql_types::*;
@@ -74,25 +73,13 @@ mod tests {
7473

7574
let new_items = vec![
7675
NewItem {
77-
embedding: Some(HalfVector::from(vec![
78-
f16::from_f32(1.0),
79-
f16::from_f32(1.0),
80-
f16::from_f32(1.0),
81-
])),
76+
embedding: Some(HalfVector::from_f32_slice(&[1.0, 1.0, 1.0])),
8277
},
8378
NewItem {
84-
embedding: Some(HalfVector::from(vec![
85-
f16::from_f32(2.0),
86-
f16::from_f32(2.0),
87-
f16::from_f32(2.0),
88-
])),
79+
embedding: Some(HalfVector::from_f32_slice(&[2.0, 2.0, 2.0])),
8980
},
9081
NewItem {
91-
embedding: Some(HalfVector::from(vec![
92-
f16::from_f32(1.0),
93-
f16::from_f32(1.0),
94-
f16::from_f32(2.0),
95-
])),
82+
embedding: Some(HalfVector::from_f32_slice(&[1.0, 1.0, 2.0])),
9683
},
9784
NewItem { embedding: None },
9885
];
@@ -105,32 +92,20 @@ mod tests {
10592
assert_eq!(4, all.len());
10693

10794
let neighbors = items::table
108-
.order(items::embedding.l2_distance(HalfVector::from(vec![
109-
f16::from_f32(1.0),
110-
f16::from_f32(1.0),
111-
f16::from_f32(1.0),
112-
])))
95+
.order(items::embedding.l2_distance(HalfVector::from_f32_slice(&[1.0, 1.0, 1.0])))
11396
.limit(5)
11497
.load::<Item>(&mut conn)?;
11598
assert_eq!(
11699
vec![1, 3, 2, 4],
117100
neighbors.iter().map(|v| v.id).collect::<Vec<i32>>()
118101
);
119102
assert_eq!(
120-
Some(HalfVector::from(vec![
121-
f16::from_f32(1.0),
122-
f16::from_f32(1.0),
123-
f16::from_f32(1.0)
124-
])),
103+
Some(HalfVector::from_f32_slice(&[1.0, 1.0, 1.0])),
125104
neighbors.first().unwrap().embedding
126105
);
127106

128107
let neighbors = items::table
129-
.order(items::embedding.max_inner_product(HalfVector::from(vec![
130-
f16::from_f32(1.0),
131-
f16::from_f32(1.0),
132-
f16::from_f32(1.0),
133-
])))
108+
.order(items::embedding.max_inner_product(HalfVector::from_f32_slice(&[1.0, 1.0, 1.0])))
134109
.limit(5)
135110
.load::<Item>(&mut conn)?;
136111
assert_eq!(
@@ -139,11 +114,7 @@ mod tests {
139114
);
140115

141116
let neighbors = items::table
142-
.order(items::embedding.cosine_distance(HalfVector::from(vec![
143-
f16::from_f32(1.0),
144-
f16::from_f32(1.0),
145-
f16::from_f32(1.0),
146-
])))
117+
.order(items::embedding.cosine_distance(HalfVector::from_f32_slice(&[1.0, 1.0, 1.0])))
147118
.limit(5)
148119
.load::<Item>(&mut conn)?;
149120
assert_eq!(
@@ -152,11 +123,7 @@ mod tests {
152123
);
153124

154125
let neighbors = items::table
155-
.order(items::embedding.l1_distance(HalfVector::from(vec![
156-
f16::from_f32(1.0),
157-
f16::from_f32(1.0),
158-
f16::from_f32(1.0),
159-
])))
126+
.order(items::embedding.l1_distance(HalfVector::from_f32_slice(&[1.0, 1.0, 1.0])))
160127
.limit(5)
161128
.load::<Item>(&mut conn)?;
162129
assert_eq!(
@@ -165,11 +132,9 @@ mod tests {
165132
);
166133

167134
let distances = items::table
168-
.select(items::embedding.max_inner_product(HalfVector::from(vec![
169-
f16::from_f32(1.0),
170-
f16::from_f32(1.0),
171-
f16::from_f32(1.0),
172-
])))
135+
.select(
136+
items::embedding.max_inner_product(HalfVector::from_f32_slice(&[1.0, 1.0, 1.0])),
137+
)
173138
.order(items::id)
174139
.load::<Option<f64>>(&mut conn)?;
175140
assert_eq!(vec![Some(-3.0), Some(-6.0), Some(-4.0), None], distances);

src/halfvec.rs

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@ impl From<HalfVector> for Vec<f16> {
2525
}
2626

2727
impl HalfVector {
28+
/// Creates a half vector from a `f32` slice.
29+
pub fn from_f32_slice(slice: &[f32]) -> HalfVector {
30+
HalfVector(slice.iter().map(|v| f16::from_f32(*v)).collect())
31+
}
32+
2833
/// Returns a copy of the half vector as a `Vec<f16>`.
2934
pub fn to_vec(&self) -> Vec<f16> {
3035
self.0.clone()
@@ -76,11 +81,7 @@ mod tests {
7681

7782
#[test]
7883
fn test_to_vec() {
79-
let vec = HalfVector::from(vec![
80-
f16::from_f32(1.0),
81-
f16::from_f32(2.0),
82-
f16::from_f32(3.0),
83-
]);
84+
let vec = HalfVector::from_f32_slice(&[1.0, 2.0, 3.0]);
8485
assert_eq!(
8586
vec.to_vec(),
8687
vec![f16::from_f32(1.0), f16::from_f32(2.0), f16::from_f32(3.0)]
@@ -89,11 +90,7 @@ mod tests {
8990

9091
#[test]
9192
fn test_as_slice() {
92-
let vec = HalfVector::from(vec![
93-
f16::from_f32(1.0),
94-
f16::from_f32(2.0),
95-
f16::from_f32(3.0),
96-
]);
93+
let vec = HalfVector::from_f32_slice(&[1.0, 2.0, 3.0]);
9794
assert_eq!(
9895
vec.as_slice(),
9996
&[f16::from_f32(1.0), f16::from_f32(2.0), f16::from_f32(3.0)]

src/postgres_ext/halfvec.rs

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -59,26 +59,14 @@ mod tests {
5959
&[],
6060
)?;
6161

62-
let vec = HalfVector::from(vec![
63-
f16::from_f32(1.0),
64-
f16::from_f32(2.0),
65-
f16::from_f32(3.0),
66-
]);
67-
let vec2 = HalfVector::from(vec![
68-
f16::from_f32(4.0),
69-
f16::from_f32(5.0),
70-
f16::from_f32(6.0),
71-
]);
62+
let vec = HalfVector::from_f32_slice(&[1.0, 2.0, 3.0]);
63+
let vec2 = HalfVector::from_f32_slice(&[4.0, 5.0, 6.0]);
7264
client.execute(
7365
"INSERT INTO postgres_half_items (embedding) VALUES ($1), ($2), (NULL)",
7466
&[&vec, &vec2],
7567
)?;
7668

77-
let query_vec = HalfVector::from(vec![
78-
f16::from_f32(3.0),
79-
f16::from_f32(1.0),
80-
f16::from_f32(2.0),
81-
]);
69+
let query_vec = HalfVector::from_f32_slice(&[3.0, 1.0, 2.0]);
8270
let row = client.query_one(
8371
"SELECT embedding FROM postgres_half_items ORDER BY embedding <-> $1 LIMIT 1",
8472
&[&query_vec],

src/sqlx_ext/halfvec.rs

Lines changed: 3 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -65,27 +65,15 @@ mod tests {
6565
.execute(&pool)
6666
.await?;
6767

68-
let vec = HalfVector::from(vec![
69-
f16::from_f32(1.0),
70-
f16::from_f32(2.0),
71-
f16::from_f32(3.0),
72-
]);
73-
let vec2 = HalfVector::from(vec![
74-
f16::from_f32(4.0),
75-
f16::from_f32(5.0),
76-
f16::from_f32(6.0),
77-
]);
68+
let vec = HalfVector::from_f32_slice(&[1.0, 2.0, 3.0]);
69+
let vec2 = HalfVector::from_f32_slice(&[4.0, 5.0, 6.0]);
7870
sqlx::query("INSERT INTO sqlx_half_items (embedding) VALUES ($1), ($2), (NULL)")
7971
.bind(&vec)
8072
.bind(&vec2)
8173
.execute(&pool)
8274
.await?;
8375

84-
let query_vec = HalfVector::from(vec![
85-
f16::from_f32(3.0),
86-
f16::from_f32(1.0),
87-
f16::from_f32(2.0),
88-
]);
76+
let query_vec = HalfVector::from_f32_slice(&[3.0, 1.0, 2.0]);
8977
let row =
9078
sqlx::query("SELECT embedding FROM sqlx_half_items ORDER BY embedding <-> $1 LIMIT 1")
9179
.bind(query_vec)

0 commit comments

Comments
 (0)