Skip to content

Commit b5d607c

Browse files
committed
Added vector distance support to postgres compiler
1 parent 2257d60 commit b5d607c

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

src/mako/database/query/compilers/Postgres.php

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,14 @@
77

88
namespace mako\database\query\compilers;
99

10+
use mako\database\query\VectorMetric;
1011
use Override;
1112

1213
use function array_pop;
1314
use function implode;
15+
use function is_array;
1416
use function is_numeric;
17+
use function json_encode;
1518
use function str_replace;
1619

1720
/**
@@ -55,6 +58,22 @@ protected function buildJsonSet(string $column, array $segments, string $param):
5558
return $column . " = JSONB_SET({$column}, '{" . str_replace("'", "''", implode(',', $segments)) . "}', '{$param}')";
5659
}
5760

61+
/**
62+
* {@inheritDoc}
63+
*/
64+
#[Override]
65+
public function whereVectorDistance(array $where): string
66+
{
67+
$vector = is_array($where['vector']) ? json_encode($where['vector']) : $where['vector'];
68+
69+
$operator = match ($where['metric']) {
70+
VectorMetric::COSINE => '<=>',
71+
VectorMetric::EUCLIDEAN => '<->',
72+
};
73+
74+
return "{$this->column($where['column'], false)} {$operator} {$this->param($vector)} <= {$this->param($where['distance'])}";
75+
}
76+
5877
/**
5978
* {@inheritDoc}
6079
*/

tests/unit/database/query/compilers/PostgresCompilerTest.php

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
use mako\database\query\compilers\Postgres as PostgresCompiler;
1212
use mako\database\query\helpers\HelperInterface;
1313
use mako\database\query\Query;
14+
use mako\database\query\VectorMetric;
1415
use mako\tests\TestCase;
1516
use Mockery;
1617
use Mockery\MockInterface;
@@ -204,6 +205,36 @@ public function testUpdateWithJSONColumn(): void
204205
$this->assertEquals([1], $query['params']);
205206
}
206207

208+
/**
209+
*
210+
*/
211+
public function testBasicCosineWhereVectorDistance(): void
212+
{
213+
$query = $this->getBuilder();
214+
215+
$query = $query->table('foobar')
216+
->whereVectorDistance('embedding', [1, 2, 3, 4, 5], maxDistance: 0.5)
217+
->getCompiler()->select();
218+
219+
$this->assertEquals('SELECT * FROM "foobar" WHERE "embedding" <=> ? <= ?', $query['sql']);
220+
$this->assertEquals(['[1,2,3,4,5]', 0.5], $query['params']);
221+
}
222+
223+
/**
224+
*
225+
*/
226+
public function testBasicEuclidianWhereVectorDistance(): void
227+
{
228+
$query = $this->getBuilder();
229+
230+
$query = $query->table('foobar')
231+
->whereVectorDistance('embedding', [1, 2, 3, 4, 5], maxDistance: 0.5, vectorMetric: VectorMetric::EUCLIDEAN)
232+
->getCompiler()->select();
233+
234+
$this->assertEquals('SELECT * FROM "foobar" WHERE "embedding" <-> ? <= ?', $query['sql']);
235+
$this->assertEquals(['[1,2,3,4,5]', 0.5], $query['params']);
236+
}
237+
207238
/**
208239
*
209240
*/

0 commit comments

Comments
 (0)