Skip to content

Commit 541d30d

Browse files
committed
Fix VectorDistance enum
1 parent dcd3b93 commit 541d30d

9 files changed

Lines changed: 38 additions & 29 deletions

File tree

src/mako/database/query/Query.php

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,7 @@ public function orWhereColumn(array|string $column1, string $operator, array|str
647647
*
648648
* @return $this
649649
*/
650-
public function whereVectorDistance(string $column, array|string|Subquery $vector, float $maxDistance, VectorDistance $vectorDistance = VectorDistance::COSINE, string $separator = 'AND'): static
650+
public function whereVectorDistance(string $column, array|string|Subquery $vector, float $maxDistance, VectorDistance $vectorDistance = VectorDistance::Cosine, string $separator = 'AND'): static
651651
{
652652
$this->wheres[] = [
653653
'type' => 'whereVectorDistance',
@@ -666,7 +666,7 @@ public function whereVectorDistance(string $column, array|string|Subquery $vecto
666666
*
667667
* @return $this
668668
*/
669-
public function orWhereVectorDistance(string $column, array|string|Subquery $vector, float $maxDistance, VectorDistance $vectorDistance = VectorDistance::COSINE): static
669+
public function orWhereVectorDistance(string $column, array|string|Subquery $vector, float $maxDistance, VectorDistance $vectorDistance = VectorDistance::Cosine): static
670670
{
671671
return $this->whereVectorDistance($column, $vector, $maxDistance, $vectorDistance, 'OR');
672672
}
@@ -1130,7 +1130,7 @@ public function descendingRaw(string $raw, array $parameters = []): static
11301130
/**
11311131
* Adds a vector ORDER BY clause.
11321132
*/
1133-
public function orderByVectorDistance(string $column, array|string|Subquery $vector, VectorDistance $vectorDistance = VectorDistance::COSINE, string $order = 'ASC'): static
1133+
public function orderByVectorDistance(string $column, array|string|Subquery $vector, VectorDistance $vectorDistance = VectorDistance::Cosine, string $order = 'ASC'): static
11341134
{
11351135
$this->orderings[] = [
11361136
'type' => 'vectorDistanceOrdering',
@@ -1146,15 +1146,15 @@ public function orderByVectorDistance(string $column, array|string|Subquery $vec
11461146
/**
11471147
* Adds a ascending vector ORDER BY clause.
11481148
*/
1149-
public function ascendingVectorDistance(string $column, array|string|Subquery $vector, VectorDistance $vectorDistance = VectorDistance::COSINE): static
1149+
public function ascendingVectorDistance(string $column, array|string|Subquery $vector, VectorDistance $vectorDistance = VectorDistance::Cosine): static
11501150
{
11511151
return $this->orderByVectorDistance($column, $vector, $vectorDistance, 'ASC');
11521152
}
11531153

11541154
/**
11551155
* Adds a descending vector ORDER BY clause.
11561156
*/
1157-
public function descendingVectorDistance(string $column, array|string|Subquery $vector, VectorDistance $vectorDistance = VectorDistance::COSINE): static
1157+
public function descendingVectorDistance(string $column, array|string|Subquery $vector, VectorDistance $vectorDistance = VectorDistance::Cosine): static
11581158
{
11591159
return $this->orderByVectorDistance($column, $vector, $vectorDistance, 'DESC');
11601160
}

src/mako/database/query/VectorDistance.php

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,20 @@
77

88
namespace mako\database\query;
99

10+
use Deprecated;
11+
1012
/**
1113
* Vector distance.
1214
*/
1315
enum VectorDistance
1416
{
15-
case COSINE;
16-
case EUCLIDEAN;
17+
/* Start compatibility */
18+
#[Deprecated('use VectorDistance::Cosine instead', 'Mako 12.2.0')]
19+
public const COSINE = self::Cosine;
20+
#[Deprecated('use VectorDistance::Euclidean instead', 'Mako 12.2.0')]
21+
public const EUCLIDEAN = self::Euclidean;
22+
/* End compatibility */
23+
24+
case Cosine;
25+
case Euclidean;
1726
}

src/mako/database/query/compilers/MariaDB.php

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ protected function vectorDistance(array $vectorDistance): string
3838
}
3939

4040
$function = match ($vectorDistance['vectorDistance']) {
41-
VectorDistance::COSINE => 'VEC_DISTANCE_COSINE',
42-
VectorDistance::EUCLIDEAN => 'VEC_DISTANCE_EUCLIDEAN',
41+
VectorDistance::Cosine => 'VEC_DISTANCE_COSINE',
42+
VectorDistance::Euclidean => 'VEC_DISTANCE_EUCLIDEAN',
4343
};
4444

4545
return "{$function}({$this->column($vectorDistance['column'], false)}, {$vector})";

src/mako/database/query/compilers/MySQL.php

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ protected function vectorDistance(array $vectorDistance): string
7575
}
7676

7777
$function = match ($vectorDistance['vectorDistance']) {
78-
VectorDistance::COSINE => 'COSINE',
79-
VectorDistance::EUCLIDEAN => 'EUCLIDEAN',
78+
VectorDistance::Cosine => 'COSINE',
79+
VectorDistance::Euclidean => 'EUCLIDEAN',
8080
};
8181

8282
return "DISTANCE({$this->column($vectorDistance['column'], false)}, {$vector}, '{$function}')";

src/mako/database/query/compilers/Postgres.php

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ protected function vectorDistance(array $vectorDistance): string
7878
}
7979

8080
$function = match ($vectorDistance['vectorDistance']) {
81-
VectorDistance::COSINE => '<=>',
82-
VectorDistance::EUCLIDEAN => '<->',
81+
VectorDistance::Cosine => '<=>',
82+
VectorDistance::Euclidean => '<->',
8383
};
8484

8585
return "{$this->column($vectorDistance['column'], false)} {$function} {$vector}";

src/mako/database/query/values/out/VectorDistance.php

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class VectorDistance extends Value
3030
public function __construct(
3131
protected string $column,
3232
protected array|string $vector,
33-
protected VectorDistanceType $vectorDistance = VectorDistanceType::COSINE,
33+
protected VectorDistanceType $vectorDistance = VectorDistanceType::Cosine,
3434
protected ?string $alias = null
3535
) {
3636
}
@@ -41,8 +41,8 @@ public function __construct(
4141
protected function getMariaDbDistance(Compiler $compiler): string
4242
{
4343
$function = match ($this->vectorDistance) {
44-
VectorDistanceType::COSINE => 'VEC_DISTANCE_COSINE',
45-
VectorDistanceType::EUCLIDEAN => 'VEC_DISTANCE_EUCLIDEAN',
44+
VectorDistanceType::Cosine => 'VEC_DISTANCE_COSINE',
45+
VectorDistanceType::Euclidean => 'VEC_DISTANCE_EUCLIDEAN',
4646
};
4747

4848
return "{$function}({$compiler->escapeColumnName($this->column)}, VEC_FromText(?))";
@@ -54,8 +54,8 @@ protected function getMariaDbDistance(Compiler $compiler): string
5454
protected function getMySqlDistance(Compiler $compiler): string
5555
{
5656
$function = match ($this->vectorDistance) {
57-
VectorDistanceType::COSINE => 'COSINE',
58-
VectorDistanceType::EUCLIDEAN => 'EUCLIDEAN',
57+
VectorDistanceType::Cosine => 'COSINE',
58+
VectorDistanceType::Euclidean => 'EUCLIDEAN',
5959
};
6060

6161
return "DISTANCE({$compiler->escapeColumnName($this->column)}, STRING_TO_VECTOR(?), '{$function}')";
@@ -67,8 +67,8 @@ protected function getMySqlDistance(Compiler $compiler): string
6767
protected function getPostgresDistance(Compiler $compiler): string
6868
{
6969
$function = match ($this->vectorDistance) {
70-
VectorDistanceType::COSINE => '<=>',
71-
VectorDistanceType::EUCLIDEAN => '<->',
70+
VectorDistanceType::Cosine => '<=>',
71+
VectorDistanceType::Euclidean => '<->',
7272
};
7373

7474
return "{$compiler->columnName($this->column)} {$function} ?";

tests/unit/database/query/compilers/MariaDBCompilerTest.php

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ public function testBasicEuclidianWhereVectorDistance(): void
7171
$query = $this->getBuilder();
7272

7373
$query = $query->table('foobar')
74-
->whereVectorDistance('embedding', [1, 2, 3, 4, 5], maxDistance: 0.5, vectorDistance: VectorDistance::EUCLIDEAN)
74+
->whereVectorDistance('embedding', [1, 2, 3, 4, 5], maxDistance: 0.5, vectorDistance: VectorDistance::Euclidean)
7575
->getCompiler()->select();
7676

7777
$this->assertEquals('SELECT * FROM `foobar` WHERE VEC_DISTANCE_EUCLIDEAN(`embedding`, VEC_FromText(?)) <= ?', $query['sql']);
@@ -165,7 +165,7 @@ public function testOrderByVectorDistanceEuclidean(): void
165165
$query = $this->getBuilder();
166166

167167
$query = $query->table('foobar')
168-
->orderByVectorDistance('embedding', [1, 2, 3, 4, 5], VectorDistance::EUCLIDEAN)
168+
->orderByVectorDistance('embedding', [1, 2, 3, 4, 5], VectorDistance::Euclidean)
169169
->getCompiler()->select();
170170

171171
$this->assertEquals('SELECT * FROM `foobar` ORDER BY VEC_DISTANCE_EUCLIDEAN(`embedding`, VEC_FromText(?)) ASC', $query['sql']);
@@ -255,7 +255,7 @@ public function testVectorEuclideanDistanceSelectValue(): void
255255
$query = $this->getBuilder();
256256

257257
$query = $query->table('foobar')
258-
->select([new OutVectorDistance('embedding', [1, 2, 3, 4], VectorDistance::EUCLIDEAN)])
258+
->select([new OutVectorDistance('embedding', [1, 2, 3, 4], VectorDistance::Euclidean)])
259259
->getCompiler()->select();
260260

261261
$this->assertEquals('SELECT VEC_DISTANCE_EUCLIDEAN(`embedding`, VEC_FromText(?)) FROM `foobar`', $query['sql']);

tests/unit/database/query/compilers/MySQLCompilerTest.php

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ public function testBasicEuclidianWhereVectorDistance(): void
260260
$query = $this->getBuilder();
261261

262262
$query = $query->table('foobar')
263-
->whereVectorDistance('embedding', [1, 2, 3, 4, 5], maxDistance: 0.5, vectorDistance: VectorDistance::EUCLIDEAN)
263+
->whereVectorDistance('embedding', [1, 2, 3, 4, 5], maxDistance: 0.5, vectorDistance: VectorDistance::Euclidean)
264264
->getCompiler()->select();
265265

266266
$this->assertEquals("SELECT * FROM `foobar` WHERE DISTANCE(`embedding`, STRING_TO_VECTOR(?), 'EUCLIDEAN') <= ?", $query['sql']);
@@ -354,7 +354,7 @@ public function testOrderByVectorDistanceEuclidean(): void
354354
$query = $this->getBuilder();
355355

356356
$query = $query->table('foobar')
357-
->orderByVectorDistance('embedding', [1, 2, 3, 4, 5], VectorDistance::EUCLIDEAN)
357+
->orderByVectorDistance('embedding', [1, 2, 3, 4, 5], VectorDistance::Euclidean)
358358
->getCompiler()->select();
359359

360360
$this->assertEquals("SELECT * FROM `foobar` ORDER BY DISTANCE(`embedding`, STRING_TO_VECTOR(?), 'EUCLIDEAN') ASC", $query['sql']);
@@ -444,7 +444,7 @@ public function testVectorEuclideanDistanceSelectValue(): void
444444
$query = $this->getBuilder();
445445

446446
$query = $query->table('foobar')
447-
->select([new OutVectorDistance('embedding', [1, 2, 3, 4], VectorDistance::EUCLIDEAN)])
447+
->select([new OutVectorDistance('embedding', [1, 2, 3, 4], VectorDistance::Euclidean)])
448448
->getCompiler()->select();
449449

450450
$this->assertEquals("SELECT DISTANCE(`embedding`, STRING_TO_VECTOR(?), 'EUCLIDEAN') FROM `foobar`", $query['sql']);

tests/unit/database/query/compilers/PostgresCompilerTest.php

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ public function testBasicEuclidianWhereVectorDistance(): void
232232
$query = $this->getBuilder();
233233

234234
$query = $query->table('foobar')
235-
->whereVectorDistance('embedding', [1, 2, 3, 4, 5], maxDistance: 0.5, vectorDistance: VectorDistance::EUCLIDEAN)
235+
->whereVectorDistance('embedding', [1, 2, 3, 4, 5], maxDistance: 0.5, vectorDistance: VectorDistance::Euclidean)
236236
->getCompiler()->select();
237237

238238
$this->assertEquals('SELECT * FROM "foobar" WHERE "embedding" <-> ? <= ?', $query['sql']);
@@ -326,7 +326,7 @@ public function testOrderByVectorDistanceEuclidean(): void
326326
$query = $this->getBuilder();
327327

328328
$query = $query->table('foobar')
329-
->orderByVectorDistance('embedding', [1, 2, 3, 4, 5], VectorDistance::EUCLIDEAN)
329+
->orderByVectorDistance('embedding', [1, 2, 3, 4, 5], VectorDistance::Euclidean)
330330
->getCompiler()->select();
331331

332332
$this->assertEquals('SELECT * FROM "foobar" ORDER BY "embedding" <-> ? ASC', $query['sql']);
@@ -416,7 +416,7 @@ public function testVectorEuclideanDistanceSelectValue(): void
416416
$query = $this->getBuilder();
417417

418418
$query = $query->table('foobar')
419-
->select([new OutVectorDistance('embedding', [1, 2, 3, 4], VectorDistance::EUCLIDEAN)])
419+
->select([new OutVectorDistance('embedding', [1, 2, 3, 4], VectorDistance::Euclidean)])
420420
->getCompiler()->select();
421421

422422
$this->assertEquals('SELECT "embedding" <-> ? FROM "foobar"', $query['sql']);

0 commit comments

Comments
 (0)