Skip to content

Commit 7802049

Browse files
Add support for DuckDB (#27)
1 parent b68ec48 commit 7802049

File tree

6 files changed

+529
-98
lines changed

6 files changed

+529
-98
lines changed

setup.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from setuptools import setup
22
import os
3-
4-
VERSION = "0.1.0a10"
3+
from tsellm import __version__
54

65

76
def get_long_description():
@@ -14,7 +13,7 @@ def get_long_description():
1413

1514
setup(
1615
name="tsellm",
17-
description="Interactive SQLite shell with LLM support",
16+
description=__version__.__description__,
1817
long_description=get_long_description(),
1918
long_description_content_type="text/markdown",
2019
author="Florents Tselai",
@@ -29,9 +28,9 @@ def get_long_description():
2928
"Changelog": "https://github.com/Florents-Tselai/tsellm/releases",
3029
},
3130
license="BSD License",
32-
version=VERSION,
31+
version=__version__.__version__,
3332
packages=["tsellm"],
34-
install_requires=["llm", "setuptools", "pip"],
33+
install_requires=["llm", "setuptools", "pip", "duckdb"],
3534
extras_require={
3635
"test": [
3736
"pytest",

tests/test_tsellm.py

+218-28
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,46 @@
1-
import llm.cli
2-
from sqlite_utils import Database
3-
from tsellm.cli import cli
1+
import sqlite3
2+
import tempfile
43
import unittest
5-
from test.support import captured_stdout, captured_stderr, captured_stdin, os_helper
4+
from pathlib import Path
5+
from test.support import captured_stdout, captured_stderr, captured_stdin
66
from test.support.os_helper import TESTFN, unlink
7-
from llm import models
8-
import sqlite3
7+
8+
import duckdb
9+
import llm.cli
910
from llm import cli as llm_cli
1011

12+
from tsellm.__version__ import __version__
13+
from tsellm.cli import (
14+
cli,
15+
TsellmConsole,
16+
SQLiteConsole,
17+
TsellmConsoleMixin,
18+
)
19+
20+
21+
def new_tempfile():
    """Return a path named ``test`` inside a freshly created temporary directory.

    Only the directory is created (via ``tempfile.mkdtemp``); nothing is
    written at the returned path, so the caller decides what creates the file.

    Returns:
        pathlib.Path: ``<new temporary directory>/test``.
    """
    scratch_dir = tempfile.mkdtemp()
    return Path(scratch_dir) / "test"
23+
24+
25+
def new_sqlite_file():
    """Open a fresh temporary path with ``sqlite3``, run ``SELECT 1``, return the path.

    Returns:
        pathlib.Path: the path that was handed to ``sqlite3.connect``.
    """
    from contextlib import closing

    f = new_tempfile()
    # A sqlite3 connection used directly as a context manager only commits or
    # rolls back on exit; it does not close. closing() releases the handle.
    with closing(sqlite3.connect(f)) as db:
        db.execute("SELECT 1")
    return f
30+
1131

12-
class CommandLineInterface(unittest.TestCase):
32+
def new_duckdb_file():
    """Open a fresh temporary path with ``duckdb``, issue ``SELECT 1``, return the path.

    Returns:
        pathlib.Path: the path that was handed to ``duckdb.connect``.
    """
    f = new_tempfile()
    con = duckdb.connect(str(f))
    try:
        con.sql("SELECT 1")
    finally:
        # Close explicitly so the database handle is released before the
        # path is handed to the code under test.
        con.close()
    return f
37+
38+
39+
class TsellmConsoleTest(unittest.TestCase):
40+
def setUp(self):
41+
super().setUp()
42+
llm_cli.set_default_model("markov")
43+
llm_cli.set_default_embedding_model("hazo")
1344

1445
def _do_test(self, *args, expect_success=True):
1546
with (
@@ -38,25 +69,132 @@ def expect_failure(self, *args):
3869
self.assertEqual(out, "")
3970
return err
4071

72+
def test_sniff_sqlite(self):
73+
self.assertTrue(TsellmConsoleMixin().is_sqlite(new_sqlite_file()))
74+
75+
def test_sniff_duckdb(self):
76+
self.assertTrue(TsellmConsoleMixin().is_duckdb(new_duckdb_file()))
77+
78+
def test_console_factory_sqlite(self):
79+
s = new_sqlite_file()
80+
self.assertTrue(TsellmConsoleMixin().is_sqlite(s))
81+
obj = TsellmConsole.create_console(s)
82+
self.assertIsInstance(obj, SQLiteConsole)
83+
84+
# def test_console_factory_duckdb(self):
85+
# s = new_duckdb_file()
86+
# self.assertTrue(TsellmConsole.is_duckdb(s))
87+
# obj = TsellmConsole.create_console(s)
88+
# self.assertIsInstance(obj, DuckDBConsole)
89+
4190
def test_cli_help(self):
4291
out = self.expect_success("-h")
4392
self.assertIn("usage: python -m tsellm", out)
4493

4594
def test_cli_version(self):
4695
out = self.expect_success("-v")
96+
self.assertIn(__version__, out)
97+
98+
def test_choose_db(self):
99+
self.expect_failure("--sqlite", "--duckdb")
100+
101+
def test_deault_sqlite(self):
    """A path given without ``--sqlite``/``--duckdb`` must pass the SQLite sniff afterwards."""
    # NOTE(review): "deault" looks like a typo for "default". The name is
    # kept so that selecting this test by name keeps working.
    f = new_tempfile()
    self.expect_success(str(f), "select 1")
    self.assertTrue(TsellmConsoleMixin().is_sqlite(f))
105+
106+
MEMORY_DB_MSG = "Connected to :memory:"
107+
PS1 = "tsellm> "
108+
PS2 = "... "
109+
110+
def run_cli(self, *args, commands=()):
    """Run ``cli(args)`` with ``commands`` fed on stdin and return its output.

    Args:
        *args: command-line arguments, passed to ``cli`` as a tuple.
        commands: lines to write to the captured stdin, one per entry; a
            newline is appended to each.

    Returns:
        tuple[str, str]: captured ``(stdout, stderr)`` text.

    The call is expected to end with ``SystemExit``; the test fails unless
    the exit code is 0.
    """
    with (
        captured_stdin() as stdin,
        captured_stdout() as stdout,
        captured_stderr() as stderr,
        self.assertRaises(SystemExit) as cm,
    ):
        for cmd in commands:
            stdin.write(cmd + "\n")
        # Rewind so the code under test reads the commands from the start.
        stdin.seek(0)
        cli(args)

    out = stdout.getvalue()
    err = stderr.getvalue()
    self.assertEqual(
        cm.exception.code, 0, f"Unexpected failure: {args=}\n{out}\n{err}"
    )
    return out, err
127+
128+
def test_interact(self):
    """With no arguments and empty stdin: banner on stderr, one primary prompt on stdout."""
    out, err = self.run_cli()
    # The original asserted the banner twice in a row; once is enough.
    self.assertIn(self.MEMORY_DB_MSG, err)
    self.assertTrue(out.endswith(self.PS1))
    self.assertEqual(out.count(self.PS1), 1)
    self.assertEqual(out.count(self.PS2), 0)
135+
136+
def test_interact_quit(self):
137+
out, err = self.run_cli(commands=(".quit",))
138+
self.assertIn(self.MEMORY_DB_MSG, err)
139+
self.assertTrue(out.endswith(self.PS1))
140+
self.assertEqual(out.count(self.PS1), 1)
141+
self.assertEqual(out.count(self.PS2), 0)
142+
143+
def test_interact_version(self):
    """``.version`` prints the SQLite library version on its own line, then a new prompt."""
    out, err = self.run_cli(commands=(".version",))
    self.assertIn(self.MEMORY_DB_MSG, err)
    # This check (version followed by a newline) already implies the bare
    # "version in out" check that used to trail this method.
    self.assertIn(sqlite3.sqlite_version + "\n", out)
    self.assertTrue(out.endswith(self.PS1))
    self.assertEqual(out.count(self.PS1), 2)
    self.assertEqual(out.count(self.PS2), 0)
48151

152+
def test_interact_valid_sql(self):
153+
out, err = self.run_cli(commands=("SELECT 1;",))
154+
self.assertIn(self.MEMORY_DB_MSG, err)
155+
self.assertIn("(1,)\n", out)
156+
self.assertTrue(out.endswith(self.PS1))
157+
self.assertEqual(out.count(self.PS1), 2)
158+
self.assertEqual(out.count(self.PS2), 0)
159+
160+
def test_interact_incomplete_multiline_sql(self):
161+
out, err = self.run_cli(commands=("SELECT 1",))
162+
self.assertIn(self.MEMORY_DB_MSG, err)
163+
self.assertTrue(out.endswith(self.PS2))
164+
self.assertEqual(out.count(self.PS1), 1)
165+
self.assertEqual(out.count(self.PS2), 1)
166+
167+
def test_interact_valid_multiline_sql(self):
168+
out, err = self.run_cli(commands=("SELECT 1\n;",))
169+
self.assertIn(self.MEMORY_DB_MSG, err)
170+
self.assertIn(self.PS2, out)
171+
self.assertIn("(1,)\n", out)
172+
self.assertTrue(out.endswith(self.PS1))
173+
self.assertEqual(out.count(self.PS1), 2)
174+
self.assertEqual(out.count(self.PS2), 1)
175+
176+
177+
class InMemorySQLiteTest(TsellmConsoleTest):
178+
path_args = None
179+
180+
def setUp(self):
181+
super().setUp()
182+
self.path_args = (
183+
"--sqlite",
184+
":memory:",
185+
)
186+
49187
def test_cli_execute_sql(self):
50-
out = self.expect_success(":memory:", "select 1")
188+
out = self.expect_success(*self.path_args, "select 1")
51189
self.assertIn("(1,)", out)
52190

53191
def test_cli_execute_too_much_sql(self):
54-
stderr = self.expect_failure(":memory:", "select 1; select 2")
192+
stderr = self.expect_failure(*self.path_args, "select 1; select 2")
55193
err = "ProgrammingError: You can only execute one statement at a time"
56194
self.assertIn(err, stderr)
57195

58196
def test_cli_execute_incomplete_sql(self):
59-
stderr = self.expect_failure(":memory:", "sel")
197+
stderr = self.expect_failure(*self.path_args, "sel")
60198
self.assertIn("OperationalError (SQLITE_ERROR)", stderr)
61199

62200
def test_cli_on_disk_db(self):
@@ -66,47 +204,99 @@ def test_cli_on_disk_db(self):
66204
out = self.expect_success(TESTFN, "select count(t) from t")
67205
self.assertIn("(0,)", out)
68206

69-
70-
class SQLiteLLMFunction(CommandLineInterface):
71-
72-
def setUp(self):
73-
super().setUp()
74-
llm_cli.set_default_model("markov")
75-
llm_cli.set_default_embedding_model("hazo")
76-
77207
def assertMarkovResult(self, prompt, generated):
78208
# Every word should be one of the original prompt (see https://github.com/simonw/llm-markov/blob/657ca504bcf9f0bfc1c6ee5fe838cde9a8976381/tests/test_llm_markov.py#L20)
79209
for w in prompt.split(" "):
80210
self.assertIn(w, generated)
81211

82212
def test_prompt_markov(self):
83-
out = self.expect_success(":memory:", "select prompt('hello world', 'markov')")
213+
out = self.expect_success(
214+
*self.path_args, "select prompt('hello world', 'markov')"
215+
)
84216
self.assertMarkovResult("hello world", out)
85217

86218
def test_prompt_default_markov(self):
87219
self.assertEqual(llm_cli.get_default_model(), "markov")
88-
out = self.expect_success(":memory:", "select prompt('hello world')")
220+
out = self.expect_success(*self.path_args, "select prompt('hello world')")
89221
self.assertMarkovResult("hello world", out)
90222

91223
def test_embed_hazo(self):
92-
out = self.expect_success(":memory:", "select embed('hello world', 'hazo')")
224+
out = self.expect_success(
225+
*self.path_args, "select embed('hello world', 'hazo')"
226+
)
93227
self.assertEqual(
94228
"('[5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]',)\n",
95229
out,
96230
)
97231

98232
def test_embed_hazo_binary(self):
99233
self.assertTrue(llm.get_embedding_model("hazo").supports_binary)
100-
self.expect_success(":memory:", "select embed(randomblob(16), 'hazo')")
234+
self.expect_success(*self.path_args, "select embed(randomblob(16), 'hazo')")
235+
236+
def test_embed_default_hazo(self):
237+
self.assertEqual(llm_cli.get_default_embedding_model(), "hazo")
238+
out = self.expect_success(*self.path_args, "select embed('hello world')")
239+
self.assertEqual(
240+
"('[5.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]',)\n",
241+
out,
242+
)
243+
244+
245+
class DefaultInMemorySQLiteTest(InMemorySQLiteTest):
    """--sqlite is omitted and should be the default, so all test cases remain the same"""

    def setUp(self):
        super().setUp()
        # Same in-memory target as the parent class, but without the
        # explicit "--sqlite" flag.
        self.path_args = (":memory:",)
101251

102252

253+
class DiskSQLiteTest(InMemorySQLiteTest):
    """Re-run the inherited test cases against an on-disk path passed with ``--sqlite``."""

    # Path of the per-test database file, as a string; set in setUp().
    db_fp = None
    # CLI arguments that select the database; set in setUp().
    path_args = ()

    def setUp(self):
        import shutil

        super().setUp()
        db_path = new_tempfile()
        # new_tempfile() creates a new directory on every call; remove it once
        # the test is done so repeated runs do not pile up temp directories.
        self.addCleanup(shutil.rmtree, db_path.parent, ignore_errors=True)
        self.db_fp = str(db_path)
        self.path_args = (
            "--sqlite",
            self.db_fp,
        )

    def test_embed_default_hazo_leaves_valid_db_behind(self):
        # This should probably be called for all test cases
        super().test_embed_default_hazo()
        self.assertTrue(TsellmConsoleMixin().is_sqlite(self.db_fp))
269+
270+
271+
class InMemoryDuckDBTest(InMemorySQLiteTest):
    """Re-run the inherited test cases with ``--duckdb :memory:``.

    Inherited cases that do not apply to DuckDB yet are reported as skipped
    rather than overridden with ``pass``, which would count them as passing.
    """

    def setUp(self):
        super().setUp()
        self.path_args = (
            "--duckdb",
            ":memory:",
        )

    def test_duckdb_execute(self):
        out = self.expect_success(*self.path_args, "select 'Hello World!'")
        self.assertIn("('Hello World!',)", out)

    def test_cli_execute_incomplete_sql(self):
        # The inherited version asserts on "OperationalError (SQLITE_ERROR)".
        self.skipTest("inherited assertion checks SQLite-specific error text")

    def test_cli_execute_too_much_sql(self):
        # The inherited version asserts on a sqlite3 ProgrammingError message.
        self.skipTest("inherited assertion checks SQLite-specific error text")

    def test_embed_default_hazo(self):
        self.skipTest("See https://github.com/Florents-Tselai/tsellm/issues/24")

    def test_prompt_default_markov(self):
        self.skipTest("See https://github.com/Florents-Tselai/tsellm/issues/24")

    def test_embed_hazo_binary(self):
        self.skipTest("See https://github.com/Florents-Tselai/tsellm/issues/25")
110300

111301

112302
if __name__ == "__main__":

tsellm/__main__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import sys
2+
23
from .cli import cli
34

45
if __name__ == "__main__":

tsellm/__version__.py

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
__title__ = "tsellm"
2+
__description__ = "Use LLMs in SQLite and DuckDB"
3+
__version__ = "0.1.0a10"

0 commit comments

Comments
 (0)