
Commit c033313

Add sqlite as a destination
Parent: f515063

File tree: 6 files changed (+246, −5 lines)

llmstack/data/apis.py

Lines changed: 7 additions & 0 deletions

@@ -116,6 +116,7 @@ def destinations(self, request):
             PandasStore,
             Pinecone,
             SingleStore,
+            SqliteDatabase,
             Weaviate,
         )
 
@@ -145,6 +146,12 @@ def destinations(self, request):
                     "schema": PandasStore.get_schema(),
                     "ui_schema": PandasStore.get_ui_schema(),
                 },
+                {
+                    "slug": SqliteDatabase.slug(),
+                    "provider_slug": SqliteDatabase.provider_slug(),
+                    "schema": SqliteDatabase.get_schema(),
+                    "ui_schema": SqliteDatabase.get_ui_schema(),
+                },
             ]
         )
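With this change the destinations listing advertises the sqlite store alongside the existing ones. A minimal sketch of how a client might pick it out of that listing; the response shape here is assumed from the diff above, not taken from API documentation:

```python
# Hypothetical client-side filtering of the destinations listing; only the
# "slug"/"provider_slug" keys are taken from the diff above.
destinations = [
    {"slug": "pandas", "provider_slug": "promptly"},
    {"slug": "sqlite", "provider_slug": "promptly"},  # new in this commit
]

sqlite_entry = next(
    (d for d in destinations if d["slug"] == "sqlite" and d["provider_slug"] == "promptly"),
    None,
)
print(sqlite_entry)  # {'slug': 'sqlite', 'provider_slug': 'promptly'}
```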
llmstack/data/destinations/__init__.py

Lines changed: 13 additions & 2 deletions

@@ -3,6 +3,7 @@
 from llmstack.data.destinations.stores.pandas import PandasStore
 from llmstack.data.destinations.stores.postgres import PostgresDatabase
 from llmstack.data.destinations.stores.singlestore import SingleStore
+from llmstack.data.destinations.stores.sqlite import SqliteDatabase
 from llmstack.data.destinations.vector_stores.chromadb import ChromaDB
 from llmstack.data.destinations.vector_stores.pinecone import Pinecone
 from llmstack.data.destinations.vector_stores.qdrant import Qdrant

@@ -12,10 +13,20 @@
 
 @cache
 def get_destination_cls(slug, provider_slug):
-    for cls in [ChromaDB, Weaviate, SingleStore, Pinecone, Qdrant, PromptlyVectorStore, PandasStore, PostgresDatabase]:
+    for cls in [
+        ChromaDB,
+        Weaviate,
+        SingleStore,
+        Pinecone,
+        Qdrant,
+        PromptlyVectorStore,
+        PandasStore,
+        PostgresDatabase,
+        SqliteDatabase,
+    ]:
         if cls.slug() == slug and cls.provider_slug() == provider_slug:
             return cls
     return None
 
 
-__all__ = ["SingleStore", "Pinecone", "Weaviate", "PandasStore", "PostgresDatabase"]
+__all__ = ["SingleStore", "Pinecone", "Weaviate", "PandasStore", "PostgresDatabase", "SqliteDatabase"]

llmstack/data/destinations/stores/sql.py

Whitespace-only changes.
llmstack/data/destinations/stores/sqlite.py (new file)

Lines changed: 223 additions & 0 deletions

@@ -0,0 +1,223 @@
+import base64
+import io
+import json
+import logging
+import os
+import sqlite3
+import uuid
+from typing import List, Literal, Optional, Union
+
+from llama_index.core.schema import TextNode
+from llama_index.core.vector_stores.types import VectorStoreQueryResult
+from pydantic import BaseModel, Field, PrivateAttr
+
+from llmstack.data.destinations.base import BaseDestination
+from llmstack.data.sources.base import DataDocument
+
+logger = logging.getLogger(__name__)
+
+
+def create_empty_sqlite_db():
+    filename = f"sqlite_{str(uuid.uuid4())[:4]}.db"
+    return f"data:application/octet-stream;name={filename};base64,{base64.b64encode(b'').decode('utf-8')}"
+
+
+def create_destination_document_asset(file, document_id, datasource_uuid):
+    from llmstack.data.models import DataSourceEntryFiles
+
+    if not file:
+        return None
+
+    file_obj = DataSourceEntryFiles.create_from_data_uri(
+        file, ref_id=document_id, metadata={"datasource_uuid": datasource_uuid}
+    )
+    return file_obj
+
+
+def get_destination_document_asset_by_document_id(document_id):
+    from llmstack.data.models import DataSourceEntryFiles
+
+    file = DataSourceEntryFiles.objects.filter(ref_id=document_id).first()
+    return file
+
+
+def create_temp_file_from_asset(asset):
+    import tempfile
+
+    temp_file = tempfile.NamedTemporaryFile(delete=False)
+    temp_file.write(asset.file.read())
+    temp_file.flush()
+    temp_file.seek(0)
+    return temp_file.name
+
+
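The helpers above round-trip the database through an asset store: an empty SQLite file is represented as a base64 data URI, materialized to a temp file for sqlite3 to open, then read back into bytes for storage. A standalone sketch of that round trip, without the Django asset model:

```python
# Standalone approximation of create_empty_sqlite_db(),
# create_temp_file_from_asset(), and update_asset_from_database().
import base64
import os
import sqlite3
import tempfile
import uuid

# An empty database as a data URI wrapping zero bytes.
filename = f"sqlite_{str(uuid.uuid4())[:4]}.db"
data_uri = f"data:application/octet-stream;name={filename};base64,{base64.b64encode(b'').decode('utf-8')}"

# Materialize the payload to a temp file on disk.
payload = base64.b64decode(data_uri.split("base64,", 1)[1])
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(payload)
    local_db = tmp.name

# sqlite3 opens a zero-byte file and initializes it on first write.
conn = sqlite3.connect(local_db)
conn.execute("CREATE TABLE IF NOT EXISTS data (id TEXT)")
conn.commit()
conn.close()

# Read the updated file back into bytes and remove the temp file.
with open(local_db, "rb") as f:
    updated_bytes = f.read()
os.remove(local_db)
```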
+def get_sqlite_data_type(_type: str):
+    if _type == "string":
+        return "TEXT"
+    elif _type == "number":
+        return "REAL"
+    elif _type == "boolean":
+        return "BOOLEAN"
+    return "TEXT"
+
+
+class SchemaEntry(BaseModel):
+    name: str
+    type: Union[Literal["string"], Literal["number"], Literal["boolean"]] = "string"
+
+
+class MappingEntry(BaseModel):
+    source: str
+    target: str
+
+
+class FullTextSearchPlugin(BaseModel):
+    type: Literal["fts5"] = "fts5"
+
+
+def load_database_from_asset(asset):
+    local_db = create_temp_file_from_asset(asset)
+    conn = sqlite3.connect(local_db)
+    return conn, local_db
+
+
+def update_asset_from_database(asset, database):
+    # Read the database content
+    buffer = io.BytesIO()
+    with open(database, "rb") as f:
+        buffer.write(f.read())
+    buffer.seek(0)
+    asset.update_file(buffer.getvalue(), asset.metadata.get("file_name"))
+    # Delete the temporary file
+    os.remove(database)
+
+
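A quick sketch of how the schema entries translate into a CREATE TABLE statement, mirroring get_sqlite_data_type() and the query string built in add() below (the sample schema here is illustrative):

```python
# Condensed dict-based equivalent of get_sqlite_data_type().
def get_sqlite_data_type(_type: str):
    return {"string": "TEXT", "number": "REAL", "boolean": "BOOLEAN"}.get(_type, "TEXT")


schema = [("id", "string"), ("text", "string"), ("score", "number")]
columns = ",".join(f"{name} {get_sqlite_data_type(t)}" for name, t in schema)
print(f"CREATE TABLE IF NOT EXISTS data ({columns})")
# -> CREATE TABLE IF NOT EXISTS data (id TEXT,text TEXT,score REAL)
```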
+class SqliteDatabase(BaseDestination):
+    schema: List[SchemaEntry] = Field(
+        description="Schema of the table",
+        default=[
+            SchemaEntry(name="id", type="string"),
+            SchemaEntry(name="text", type="string"),
+            SchemaEntry(name="metadata_json", type="string"),
+        ],
+    )
+    table_name: str = Field(description="Name of the table", default="data")
+    search_plugin: Optional[Union[FullTextSearchPlugin]] = Field(
+        description="Search plugin to use",
+        default=None,
+    )
+
+    _asset = PrivateAttr(default=None)
+    _name = PrivateAttr(default="sqlite")
+
+    @classmethod
+    def slug(cls):
+        return "sqlite"
+
+    @classmethod
+    def provider_slug(cls):
+        return "promptly"
+
+    def initialize_client(self, *args, **kwargs):
+        datasource = kwargs.get("datasource")
+        self._name = datasource.name
+        document_id = str(datasource.uuid)
+        asset = get_destination_document_asset_by_document_id(document_id)
+
+        if asset is None:
+            file = create_empty_sqlite_db()
+            self._asset = create_destination_document_asset(file, document_id, str(datasource.uuid))
+        else:
+            self._asset = asset
+
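The configuration surface of the destination in isolation; a minimal sketch with a stand-in class name (SqliteConfig is hypothetical, and the condensed Literal form is equivalent to the Union in the commit) so it runs without BaseDestination:

```python
from typing import List, Literal, Optional

from pydantic import BaseModel, Field


class SchemaEntry(BaseModel):
    name: str
    type: Literal["string", "number", "boolean"] = "string"


class FullTextSearchPlugin(BaseModel):
    type: Literal["fts5"] = "fts5"


class SqliteConfig(BaseModel):  # hypothetical stand-in for SqliteDatabase
    schema_entries: List[SchemaEntry] = Field(
        default=[
            SchemaEntry(name="id"),
            SchemaEntry(name="text"),
            SchemaEntry(name="metadata_json"),
        ]
    )
    table_name: str = "data"
    search_plugin: Optional[FullTextSearchPlugin] = None


cfg = SqliteConfig(search_plugin=FullTextSearchPlugin())
print(cfg.table_name, cfg.search_plugin.type)  # data fts5
```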
+    def add(self, document):
+        conn, local_db = load_database_from_asset(self._asset)
+        c = conn.cursor()
+
+        create_table_query = f"CREATE TABLE IF NOT EXISTS {self.table_name} ({','.join([f'{item.name} {get_sqlite_data_type(item.type)}' for item in self.schema])})"
+        if self.search_plugin:
+            if self.search_plugin.type == "fts5":
+                create_table_query = f"CREATE VIRTUAL TABLE IF NOT EXISTS {self.table_name} USING fts5({','.join([f'{item.name}' for item in self.schema])})"
+            elif self.search_plugin.type == "semantic":
+                import sqlite_vec
+
+                conn.enable_load_extension(True)
+                sqlite_vec.load(conn)
+                conn.enable_load_extension(False)
+
+                create_table_query = f"CREATE VIRTUAL TABLE IF NOT EXISTS {self.table_name} USING vec0({','.join([f'{item.name} {get_sqlite_data_type(item.type)}' for item in self.schema])}, embedding float[1536])"
+
+        c.execute(create_table_query)
+
+        try:
+            for node in document.nodes:
+                document_dict = {"text": node.text, "metadata_json": json.dumps(node.metadata)}
+                for schema_entry in self.schema:
+                    if schema_entry.name == "id" or schema_entry.name == "text" or schema_entry.name == "metadata_json":
+                        continue
+                    if schema_entry.name in node.metadata:
+                        document_dict[schema_entry.name] = node.metadata[schema_entry.name]
+                if self.search_plugin and self.search_plugin.type == "semantic":
+                    document_dict["embedding"] = node.embedding
+
+                entry_dict = {"id": node.id_, **document_dict}
+                c.execute(
+                    f"INSERT INTO {self.table_name} ({','.join(entry_dict.keys())}) VALUES ({','.join(['?'] * len(entry_dict))})",
+                    list(entry_dict.values()),
+                )
+            conn.commit()
+            conn.close()
+        except Exception as e:
+            logger.exception(f"Error adding nodes to sqlite store {e}")
+            raise e
+
+        update_asset_from_database(self._asset, local_db)
+        ids = [r.node_id for r in document.nodes]
+        return ids
+
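A standalone approximation of what add() executes on the fts5 path, mirroring the query strings built above; it uses only the stdlib sqlite3 module (FTS5 is compiled into most Python builds):

```python
import json
import sqlite3

conn = sqlite3.connect(":memory:")
c = conn.cursor()

# fts5 columns are untyped, which is why the type mapping is skipped on
# this branch of add().
c.execute("CREATE VIRTUAL TABLE IF NOT EXISTS data USING fts5(id,text,metadata_json)")

# Parameterized insert with one "?" per column, as in add().
entry = {
    "id": "node-1",
    "text": "hello sqlite",
    "metadata_json": json.dumps({"source": "test.csv"}),
}
c.execute(
    f"INSERT INTO data ({','.join(entry.keys())}) VALUES ({','.join(['?'] * len(entry))})",
    list(entry.values()),
)
conn.commit()
conn.close()
```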
+    def delete(self, document: DataDocument):
+        conn, local_db = load_database_from_asset(self._asset)
+        c = conn.cursor()
+        for node_id in document.node_ids:
+            c.execute(f"DELETE FROM {self.table_name} WHERE id = ?", (node_id,))
+        conn.commit()
+        conn.close()
+        update_asset_from_database(self._asset, local_db)
+
+    def search(self, query: str, **kwargs):
+        conn, _ = load_database_from_asset(self._asset)
+        c = conn.cursor()
+        result = c.execute(query).fetchall()
+        conn.close()
+        nodes = list(
+            map(lambda x: TextNode(text=json.dumps(x), metadata={"query": query, "source": self._name}), result)
+        )
+        node_ids = [str(i) for i, _ in enumerate(result)]
+        return VectorStoreQueryResult(nodes=nodes, ids=node_ids, similarities=[])
+
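Note that search() runs the caller's SQL string verbatim, so on the fts5 path a caller would issue a MATCH query. A self-contained sketch:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE VIRTUAL TABLE data USING fts5(id,text,metadata_json)")
conn.execute("INSERT INTO data VALUES (?,?,?)", ("node-1", "hello sqlite", "{}"))

# An unqualified MATCH searches across all fts5 columns of the table.
rows = conn.execute("SELECT id, text FROM data WHERE data MATCH ?", ("hello",)).fetchall()
print(rows)  # [('node-1', 'hello sqlite')]
conn.close()
```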
+    def create_collection(self):
+        pass
+
+    def delete_collection(self):
+        if self._asset:
+            self._asset.file.delete()
+            self._asset.delete()
+
+    def get_nodes(self, node_ids=None, filters=None):
+        conn, _ = load_database_from_asset(self._asset)
+        column_names = [schema_entry.name for schema_entry in self.schema]
+        c = conn.cursor()
+        if node_ids:
+            query = f"SELECT {','.join(column_names)} FROM {self.table_name} WHERE id IN ({','.join(['?'] * len(node_ids))})"
+            rows = c.execute(query, node_ids).fetchall()
+        else:
+            rows = c.execute(f"SELECT * FROM {self.table_name}").fetchall()
+        conn.close()
+        if rows:
+            return list(
+                map(
+                    lambda x: TextNode(id_=x[0], text=json.dumps(x), metadata={"source": self._name}),
+                    rows,
+                )
+            )
+        return []
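get_nodes() builds a parameterized IN clause with one placeholder per requested id; a minimal demonstration of that pattern:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE data (id TEXT, text TEXT)")
conn.executemany("INSERT INTO data VALUES (?,?)", [("a", "x"), ("b", "y"), ("c", "z")])

# One "?" per id keeps the query safe regardless of how many ids are passed.
node_ids = ["a", "c"]
query = f"SELECT id,text FROM data WHERE id IN ({','.join(['?'] * len(node_ids))})"
rows = conn.execute(query, node_ids).fetchall()
print(rows)  # [('a', 'x'), ('c', 'z')]
conn.close()
```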

llmstack/data/transformations/splitters.py

Lines changed: 3 additions & 2 deletions

@@ -39,7 +39,7 @@ def get_default_data(cls):
 class CSVTextSplitter(TextSplitter, PromptlyTransformers):
     exclude_columns: Optional[List[str]] = Field(
         default=None,
-        description="Columns to exclude from the text",
+        description="Columns to drop from the csv row",
     )
     text_columns: Optional[List[str]] = Field(
         default=None,

@@ -92,6 +92,7 @@ def _parse_nodes(self, nodes, show_progress: bool = False, **kwargs):
             row_text = json.dumps(text_parts)
             all_nodes.extend(build_nodes_from_splits([row_text], node, id_func=self.id_func))
             for column_name, value in content.items():
-                all_nodes[-1].metadata[f"{self.metadata_prefix}{column_name}"] = value
+                metadata_key = f"{self.metadata_prefix}{column_name}".replace(" ", "_")
+                all_nodes[-1].metadata[metadata_key] = value
 
         return all_nodes
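This sanitization matters for the new destination: metadata keys derived from CSV headers are matched against the SQLite schema entries, which become column names, and spaces would presumably break that mapping. A sketch of the added behavior (the metadata_prefix value here is hypothetical):

```python
# Illustrative stand-ins; metadata_prefix and the row content are assumptions.
metadata_prefix = "csv_"
content = {"First Name": "Ada", "score": 10}

metadata = {}
for column_name, value in content.items():
    # Spaces in CSV headers are replaced so keys stay usable as identifiers.
    metadata_key = f"{metadata_prefix}{column_name}".replace(" ", "_")
    metadata[metadata_key] = value

print(metadata)  # {'csv_First_Name': 'Ada', 'csv_score': 10}
```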

pyproject.toml

Lines changed: 0 additions & 1 deletion

@@ -42,7 +42,6 @@ django-picklefield = "^3.2"
 django-redis = "^5.4.0"
 djangorestframework = "^3.15.2"
 django-flags = "^5.0.13"
-django-jsonform = {version = "^2.17.4"}
 django-ratelimit = {version = "^4.1.0"}
 croniter = {version ="^2.0.1"}
 pykka = "^4.0.2"
