|
| 1 | +import numpy as np |
| 2 | +import os |
| 3 | +from pgvector import HalfVector, SparseVector, Vector |
| 4 | +from pgvector.pg8000 import register_vector |
| 5 | +from pg8000.native import Connection |
| 6 | + |
| 7 | +conn = Connection(os.environ["USER"], database='pgvector_python_test') |
| 8 | + |
| 9 | +conn.run('CREATE EXTENSION IF NOT EXISTS vector') |
| 10 | +conn.run('DROP TABLE IF EXISTS pg8000_items') |
| 11 | +conn.run('CREATE TABLE pg8000_items (id bigserial PRIMARY KEY, embedding vector(3), half_embedding halfvec(3), binary_embedding bit(3), sparse_embedding sparsevec(3), embeddings vector[], half_embeddings halfvec[], sparse_embeddings sparsevec[])') |
| 12 | + |
| 13 | +register_vector(conn) |
| 14 | + |
| 15 | + |
| 16 | +class TestPg8000: |
| 17 | + def setup_method(self): |
| 18 | + conn.run('DELETE FROM pg8000_items') |
| 19 | + |
| 20 | + def test_vector(self): |
| 21 | + embedding = np.array([1.5, 2, 3]) |
| 22 | + conn.run('INSERT INTO pg8000_items (embedding) VALUES (:embedding), (NULL)', embedding=embedding) |
| 23 | + |
| 24 | + res = conn.run('SELECT embedding FROM pg8000_items ORDER BY id') |
| 25 | + assert np.array_equal(res[0][0], embedding) |
| 26 | + assert res[0][0].dtype == np.float32 |
| 27 | + assert res[1][0] is None |
| 28 | + |
| 29 | + def test_vector_class(self): |
| 30 | + embedding = Vector([1.5, 2, 3]) |
| 31 | + conn.run('INSERT INTO pg8000_items (embedding) VALUES (:embedding), (NULL)', embedding=embedding) |
| 32 | + |
| 33 | + res = conn.run('SELECT embedding FROM pg8000_items ORDER BY id') |
| 34 | + assert np.array_equal(res[0][0], embedding.to_numpy()) |
| 35 | + assert res[0][0].dtype == np.float32 |
| 36 | + assert res[1][0] is None |
| 37 | + |
| 38 | + def test_halfvec(self): |
| 39 | + embedding = HalfVector([1.5, 2, 3]) |
| 40 | + conn.run('INSERT INTO pg8000_items (half_embedding) VALUES (:embedding), (NULL)', embedding=embedding) |
| 41 | + |
| 42 | + res = conn.run('SELECT half_embedding FROM pg8000_items ORDER BY id') |
| 43 | + assert res[0][0] == embedding |
| 44 | + assert res[1][0] is None |
| 45 | + |
| 46 | + def test_bit(self): |
| 47 | + embedding = '101' |
| 48 | + conn.run('INSERT INTO pg8000_items (binary_embedding) VALUES (:embedding), (NULL)', embedding=embedding) |
| 49 | + |
| 50 | + res = conn.run('SELECT binary_embedding FROM pg8000_items ORDER BY id') |
| 51 | + assert res[0][0] == '101' |
| 52 | + assert res[1][0] is None |
| 53 | + |
| 54 | + def test_sparsevec(self): |
| 55 | + embedding = SparseVector([1.5, 2, 3]) |
| 56 | + conn.run('INSERT INTO pg8000_items (sparse_embedding) VALUES (:embedding), (NULL)', embedding=embedding) |
| 57 | + |
| 58 | + res = conn.run('SELECT sparse_embedding FROM pg8000_items ORDER BY id') |
| 59 | + assert res[0][0] == embedding |
| 60 | + assert res[1][0] is None |
0 commit comments