Skip to content

Commit

Permalink
Merge pull request #447 from llmware-ai/sqlite-kv-range-lookup-fix
Browse files Browse the repository at this point in the history
sqlite key value range retrieval fix
  • Loading branch information
doberst authored Feb 21, 2024
2 parents 27a22b6 + 616b990 commit 3151e11
Showing 1 changed file with 22 additions and 57 deletions.
79 changes: 22 additions & 57 deletions llmware/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,7 +869,7 @@ def unpack(self, results_cursor):
new_dict.update({key: row[counter]})
counter += 1
else:
logging.info("update: pg_retriever - outputs not matching - %s - %s", counter, row[counter])
logging.warning("update: pg_retriever - outputs not matching - %s", counter)

output.append(new_dict)

Expand Down Expand Up @@ -899,7 +899,7 @@ def unpack_search_result(self, results_cursor):
new_dict.update({key: row[counter]})
counter += 1
else:
logging.info ("update: pg_retriever - outputs not matching - %s - %s ", counter, row[counter])
logging.warning ("update: pg_retriever - outputs not matching - %s ", counter)

output.append(new_dict)

Expand Down Expand Up @@ -1020,11 +1020,7 @@ def filter_by_key(self, key, value):
results = self.conn.cursor().execute(sql_query)

if results:

if self.text_retrieval:
output = self.unpack_search_result(results)
else:
output = self.unpack(results)
output = self.unpack(results)

self.conn.close()

Expand Down Expand Up @@ -1159,17 +1155,16 @@ def filter_by_key_dict (self, key_dict):

conditions_clause = " WHERE"
for key, value in key_dict.items():
conditions_clause += f" AND {key} = {value}"
conditions_clause += f" {key} = '{value}' AND "

if conditions_clause.endswith(' AND '):
conditions_clause = conditions_clause[:-5]
if len(conditions_clause) > len(" WHERE"):
sql_query += conditions_clause + ";"

results = self.conn.cursor().execute(sql_query)

if self.text_retrieval:
output = self.unpack_search_result(results)
else:
output = self.unpack(results)
output = self.unpack(results)

self.conn.close()

Expand All @@ -1186,14 +1181,11 @@ def filter_by_key_value_range(self, key, value_range):
value_range_str = value_range_str[:-2]
value_range_str += ")"

sql_query = f"SELECT * from {self.library_name} WHERE '{key}' IN {value_range_str};"
sql_query = f"SELECT * from {self.library_name} WHERE {key} IN {value_range_str};"

results = self.conn.cursor().execute(sql_query)

if self.text_retrieval:
output = self.unpack_search_result(results)
else:
output = self.unpack(results)
output = self.unpack(results)

self.conn.close()

Expand All @@ -1203,14 +1195,11 @@ def filter_by_key_ne_value(self, key, value):

"""Filter by col (key) not equal to specified value"""

sql_query = f"SELECT * from {self.library_name} WHERE NOT '{key}' = {value};"
sql_query = f"SELECT * from {self.library_name} WHERE NOT {key} = {value};"

results = self.conn.cursor().execute(sql_query)

if self.text_retrieval:
output = self.unpack_search_result(results)
else:
output = self.unpack(results)
output = self.unpack(results)

self.conn.close()

Expand Down Expand Up @@ -1884,7 +1873,7 @@ def unpack(self, results_cursor):
counter += 1

else:
logging.info("update: sqlite_retriever - outputs not matching - %s - %s", counter, len(row))
logging.warning("update: sqlite_retriever - outputs not matching - %s ", counter)

output.append(new_dict)

Expand Down Expand Up @@ -1921,8 +1910,9 @@ def unpack_search_result(self, results_cursor):

new_dict.update({key: row[counter]})
counter += 1

else:
logging.info("update: sqlite_retriever - outputs not matching - %s - %s", counter, row[counter])
logging.warning("update: sqlite_retriever - outputs not matching - %s", counter)

output.append(new_dict)

Expand Down Expand Up @@ -2039,26 +2029,7 @@ def filter_by_key(self, key, value):
sql_query = f"SELECT rowid, * FROM {self.library_name} WHERE {key} = {value};"
results = self.conn.cursor().execute(sql_query)

# lib_card = {}

if self.text_retrieval:
output = self.unpack_search_result(results)
else:
output = self.unpack(results)

"""
if self.library_name == "library":
# repackage library card
library_schema = LLMWareTableSchema.get_library_card_schema()
lib_card = {}
counter = 0
results = list(results)
for keys in library_schema:
# print("update: keys / sql - ", keys, results[0][counter])
lib_card.update({keys:results[0][counter]})
counter += 1
"""
output = self.unpack(results)

self.conn.close()

Expand Down Expand Up @@ -2187,17 +2158,17 @@ def filter_by_key_dict (self, key_dict):

conditions_clause = " WHERE"
for key, value in key_dict.items():
conditions_clause += f" AND {key} = {value}"
conditions_clause += f" {key} = {value} AND "

if conditions_clause.endswith(" AND "):
conditions_clause = conditions_clause[:-5]

if len(conditions_clause) > len(" WHERE"):
sql_query += conditions_clause + ";"

results = self.conn.cursor().execute(sql_query)

if self.text_retrieval:
output = self.unpack_search_result(results)
else:
output = self.unpack(results)
output = self.unpack(results)

self.conn.close()

Expand All @@ -2221,10 +2192,7 @@ def filter_by_key_value_range(self, key, value_range):

results = self.conn.cursor().execute(sql_query)

if self.text_retrieval:
output = self.unpack_search_result(results)
else:
output = self.unpack(results)
output = self.unpack(results)

self.conn.close()

Expand All @@ -2238,10 +2206,7 @@ def filter_by_key_ne_value(self, key, value):

results = self.conn.cursor().execute(sql_query)

if self.text_retrieval:
output = self.unpack_search_result(results)
else:
output = self.unpack(results)
output = self.unpack(results)

self.conn.close()

Expand Down

0 comments on commit 3151e11

Please sign in to comment.