Skip to content

Commit 752c01b

Browse files
committed
Refactor query logic and databackend
1 parent bc9ea39 commit 752c01b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

76 files changed

+1943
-3569
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1717
#### New Features & Functionality
1818

1919
- Add type annotations as a way to declare schema
20+
- Deprecate "free" queries
21+
- Create simpler developer contract for databackend
2022

2123
#### Bug Fixes
2224

Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
DIRECTORIES ?= superduper test
1+
DIRECTORIES ?= superduper test plugins
22
SUPERDUPER_CONFIG ?= test/configs/default.yaml
33
PYTEST_ARGUMENTS ?=
44
PLUGIN_NAME ?=

plugins/ibis/plugin_test/test_databackend.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,18 @@
22

33
import pytest
44
from superduper import CFG
5+
from superduper.misc.importing import load_plugin
56

67
from superduper_ibis.data_backend import IbisDataBackend
78

89

910
@pytest.fixture
1011
def databackend():
11-
backend = IbisDataBackend(CFG.data_backend)
12+
plugin = load_plugin('ibis')
13+
backend = IbisDataBackend(CFG.data_backend, plugin=plugin)
1214
yield backend
1315
backend.drop(True)
1416

1517

16-
def test_output_dest(databackend):
17-
db_utils.test_output_dest(databackend)
18-
19-
20-
def test_query_builder(databackend):
21-
db_utils.test_query_builder(databackend)
22-
23-
2418
def test_list_tables_or_collections(databackend):
2519
db_utils.test_list_tables_or_collections(databackend)

plugins/ibis/plugin_test/test_end_2_end.py

-39
Original file line numberDiff line numberDiff line change
@@ -141,42 +141,3 @@ def postprocess(x):
141141
# Get the results
142142
result = list(db.execute(q))
143143
assert listener2.outputs in result[0].unpack()
144-
145-
146-
def test_nested_query(db):
147-
memory_table = False
148-
if CFG.data_backend.endswith("csv"):
149-
memory_table = True
150-
schema = Schema(
151-
identifier="my_table",
152-
fields={
153-
"id": FieldType(identifier="int64"),
154-
"health": FieldType(identifier="int32"),
155-
"age": FieldType(identifier="int32"),
156-
},
157-
)
158-
159-
from superduper.components.table import Table
160-
161-
t = Table(identifier="my_table", schema=schema)
162-
163-
db.apply(t)
164-
165-
t = db["my_table"]
166-
q = t.filter(t.age >= 10)
167-
168-
expr_ = q.compile(db)
169-
170-
if not memory_table:
171-
assert 'WHERE "t0"."age" >=' in str(expr_)
172-
else:
173-
pass
174-
# TODO this doesn't test anything useful and
175-
# is sensitive to version changes
176-
# TODO refactor/ remove
177-
# assert 'Selection[r0]\n predicates:\n r0.age >= 10' in str(expr_)
178-
# assert (
179-
# 'my_table\n _fold string\n id '
180-
# 'int64\n health int32\n age '
181-
# 'int32\n image binary' in str(expr_)
182-
# )

plugins/ibis/plugin_test/test_query.py

+15-52
Original file line numberDiff line numberDiff line change
@@ -51,32 +51,16 @@ def test_renamings(db):
5151
add_listeners(db)
5252
t = db["documents"]
5353
listener_uuid = [db.load('listener', k).outputs for k in db.show("listener")][0]
54-
q = t.select("id", "x", "y").outputs(listener_uuid)
55-
data = list(db.execute(q))
54+
q = t.select("id", "x", "y").outputs(listener_uuid.split('__', 1)[-1])
55+
data = q.execute()
5656
assert isinstance(data[0].unpack()[listener_uuid], np.ndarray)
5757

5858

5959
def test_serialize_query(db):
60-
from superduper_ibis.query import IbisQuery
60+
t = db['documents']
61+
q = t.filter(t['id'] == 1).select('id', 'x')
6162

62-
t = IbisQuery(db=db, table="documents", parts=[("select", ("id",), {})])
63-
64-
q = t.filter(t.id == 1).select(t.id, t.x)
65-
66-
print(Document.decode(q.encode()).unpack())
67-
68-
69-
def test_add_fold(db):
70-
add_random_data(db, n=10)
71-
table = db["documents"]
72-
select_train = table.select("id", "x", "_fold").add_fold("train")
73-
result_train = db.execute(select_train)
74-
75-
select_valid = table.select("id", "x", "_fold").add_fold("valid")
76-
result_valid = db.execute(select_valid)
77-
result_train = list(result_train)
78-
result_valid = list(result_valid)
79-
assert len(result_train) + len(result_valid) == 10
63+
print(Document.decode(q.encode(), db=db).unpack())
8064

8165

8266
def test_get_data(db):
@@ -88,7 +72,7 @@ def test_get_data(db):
8872
def test_insert_select(db):
8973
add_random_data(db, n=5)
9074
q = db["documents"].select("id", "x", "y").limit(2)
91-
r = list(db.execute(q))
75+
r = q.execute()
9276

9377
assert len(r) == 2
9478
assert all(all([k in ["id", "x", "y"] for k in x.unpack().keys()]) for x in r)
@@ -98,43 +82,25 @@ def test_filter(db):
9882
add_random_data(db, n=5)
9983
t = db["documents"]
10084
q = t.select("id", "y")
101-
r = list(db.execute(q))
85+
r = q.execute()
10286
ys = [x["y"] for x in r]
10387
uq = np.unique(ys, return_counts=True)
10488

105-
q = t.select("id", "y").filter(t.y == uq[0][0])
106-
r = list(db.execute(q))
89+
q = t.select("id", "y").filter(t['y'] == uq[0][0])
90+
r = q.execute()
10791
assert len(r) == uq[1][0]
10892

10993

110-
def test_execute_complex_query_sqldb_auto_schema(db):
111-
import ibis
112-
113-
db.cfg.auto_schema = True
114-
115-
table = db["documents"]
116-
table.insert(
117-
[Document({"this": f"is a test {i}", "id": str(i)}) for i in range(100)]
118-
).execute()
119-
120-
cur = table.select("this").order_by(ibis.desc("this")).limit(10).execute(db)
121-
expected = [f"is a test {i}" for i in range(99, 89, -1)]
122-
cur_this = [r["this"] for r in cur]
123-
assert sorted(cur_this) == sorted(expected)
124-
125-
12694
def test_select_using_ids(db):
12795
db.cfg.auto_schema = True
12896

12997
table = db["documents"]
130-
table.insert(
131-
[Document({"this": f"is a test {i}", "id": str(i)}) for i in range(4)]
132-
).execute()
98+
table.insert([{"this": f"is a test {i}", "id": str(i)} for i in range(4)])
13399

134100
basic_select = db['documents'].select()
135101

136-
assert len(basic_select.tolist()) == 4
137-
assert len(basic_select.select_using_ids(['1', '2']).tolist()) == 2
102+
assert len(basic_select.execute()) == 4
103+
assert len(basic_select.subset(['1', '2'])) == 2
138104

139105

140106
def test_select_using_ids_of_outputs(db):
@@ -147,21 +113,18 @@ def my_func(x):
147113
db.cfg.auto_schema = True
148114

149115
table = db["documents"]
150-
table.insert(
151-
[Document({"this": f"is a test {i}", "id": str(i)}) for i in range(4)]
152-
).execute()
116+
table.insert([{"this": f"is a test {i}", "id": str(i)} for i in range(4)])
153117

154118
listener = my_func.to_listener(key='this', select=db['documents'].select())
155119
db.apply(listener)
156120

157121
q1 = db[listener.outputs].select()
158-
r1 = q1.tolist()
122+
r1 = q1.execute()
159123

160124
assert len(r1) == 4
161125

162126
ids = [x['id'] for x in r1]
163127

164-
q2 = q1.select_using_ids(ids[:2])
165-
r2 = q2.tolist()
128+
r2 = q1.subset(ids[:2])
166129

167130
assert len(r2) == 2
+1-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from .data_backend import IbisDataBackend as DataBackend
2-
from .query import IbisQuery
32

43
__version__ = "0.5.1"
54

6-
__all__ = ["IbisQuery", "DataBackend"]
5+
__all__ = ["DataBackend"]

0 commit comments

Comments
 (0)