@@ -528,6 +528,22 @@ def test_binary_quantize(self, engine):
528
528
items = session .query (Item ).order_by (distance ).all ()
529
529
assert [v .id for v in items ] == [2 , 3 , 1 ]
530
530
531
+ def test_binary_quantize_reranking (self , engine ):
532
+ # recreate index (could also vacuum table)
533
+ binary_quantize_index .drop (setup_engine )
534
+ binary_quantize_index .create (setup_engine )
535
+
536
+ with Session (engine ) as session :
537
+ session .add (Item (id = 1 , embedding = [- 1 , - 2 , - 3 ]))
538
+ session .add (Item (id = 2 , embedding = [1 , - 2 , 3 ]))
539
+ session .add (Item (id = 3 , embedding = [1 , 2 , 3 ]))
540
+ session .commit ()
541
+
542
+ distance = func .cast (func .binary_quantize (Item .embedding ), BIT (3 )).hamming_distance (func .binary_quantize (func .cast ([3 , - 1 , 2 ], VECTOR (3 ))))
543
+ subquery = session .query (Item ).order_by (distance ).limit (20 ).subquery ()
544
+ items = session .query (subquery ).order_by (subquery .c .embedding .cosine_distance ([3 , - 1 , 2 ])).limit (5 ).all ()
545
+ assert [v .id for v in items ] == [2 , 3 , 1 ]
546
+
531
547
532
548
@pytest .mark .parametrize ('engine' , array_engines )
533
549
class TestSqlalchemyArray :
0 commit comments