Skip to content

Commit b5d248b

Browse files
sreyakumardbirman
andauthored
feat: Embed chat for generating complex queries (#31) (#39)
In chat_query.py: - Calls LLM from metadata chatbot asynchronously (requires credentials AWS Bedrock access to Claude Sonnet 3-7) - Generates loading icon when query is loading - Queries are kept track of in self.queries In query.py: - Added complex_query_builder = ComplexQueryBuilder() Ran linters --------- Co-authored-by: Dan Birman <[email protected]>
1 parent 5ee6352 commit b5d248b

File tree

15 files changed

+4715
-452
lines changed

15 files changed

+4715
-452
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,6 @@ dmypy.json
137137

138138
# MacOs
139139
**/.DS_Store
140+
141+
# Local UI testing
142+
panel_chat_query.py

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ dependencies = [
2525
'aind-data-access-api[rds]',
2626
'aind-metadata-validator>=0.8.3',
2727
'flask',
28+
'langchain',
29+
'langchain_aws',
2830
]
2931

3032
[project.optional-dependencies]

src/aind_metadata_viz/app.py

Lines changed: 50 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,21 @@
22
import panel as pn
33
from aind_data_schema import __version__ as ads_version
44
from aind_metadata_viz import database
5-
from aind_metadata_viz.utils import AIND_COLORS, COLOR_OPTIONS, hd_style, outer_style
5+
from aind_metadata_viz.utils import (
6+
AIND_COLORS,
7+
COLOR_OPTIONS,
8+
hd_style,
9+
outer_style,
10+
)
611
from aind_metadata_viz.charts import file_present_chart, modality_present_chart
712

813
pn.extension("vega", design="material")
914
alt.themes.enable("ggplot2")
1015

1116
# Define CSS to set the background color and add to panel
12-
background_param = pn.state.location.query_params.get("background", "dark_blue")
17+
background_param = pn.state.location.query_params.get(
18+
"background", "dark_blue"
19+
)
1320
background_color = AIND_COLORS.get(background_param, AIND_COLORS["dark_blue"])
1421

1522
css = f"""
@@ -22,8 +29,8 @@
2229
pn.config.raw_css.append(css)
2330

2431
# Get the active color list
25-
colors = (
26-
COLOR_OPTIONS.get(pn.state.location.query_params.get("colors"), COLOR_OPTIONS["default"])
32+
colors = COLOR_OPTIONS.get(
33+
pn.state.location.query_params.get("colors"), COLOR_OPTIONS["default"]
2734
)
2835
color_list = list(colors.values())
2936

@@ -39,7 +46,9 @@
3946
name="Filter by core file:", options=database.ALL_FILES
4047
)
4148

42-
field_selector = pn.widgets.Select(name="Filter download by field:", options=[])
49+
field_selector = pn.widgets.Select(
50+
name="Filter download by field:", options=[]
51+
)
4352

4453
missing_selector = pn.widgets.Select(
4554
name="Filter download by state", options=["Missing", "Valid/Present"]
@@ -118,7 +127,12 @@ def field_present_chart(selected_file, derived_filter, **args):
118127

119128
sum_longform_df = db.get_file_field_presence()
120129

121-
field_selection = alt.selection_point(fields=['field'], empty='none', name='field', value=field_selector.value)
130+
field_selection = alt.selection_point(
131+
fields=["field"],
132+
empty="none",
133+
name="field",
134+
value=field_selector.value,
135+
)
122136

123137
chart = (
124138
alt.Chart(sum_longform_df)
@@ -151,12 +165,11 @@ def field_present_chart(selected_file, derived_filter, **args):
151165

152166
def update_selection(event):
153167
if len(event.new) > 0:
154-
field_selector.value = event.new[0]['field']
155-
pane.selection.param.watch(update_selection, 'field')
156-
157-
return pane
168+
field_selector.value = event.new[0]["field"]
158169

170+
pane.selection.param.watch(update_selection, "field")
159171

172+
return pane
160173

161174

162175
header = (
@@ -175,7 +188,6 @@ def update_selection(event):
175188
"""
176189

177190

178-
179191
header_pane = pn.pane.Markdown(header, styles=outer_style, width=420)
180192

181193
total_md = f"<p style=\"text-align:center\"><b>{db.get_overall_valid():1.2f}%</b> of all metadata records are fully {hd_style('valid', colors)}</p>"
@@ -209,7 +221,10 @@ def build_row(selected_modality, derived_filter):
209221
db.modality_filter = selected_modality
210222
db.derived_filter = derived_filter
211223

212-
return pn.Row(file_present_chart(db, colors, top_selector), modality_present_chart(db, colors, color_list, modality_selector))
224+
return pn.Row(
225+
file_present_chart(db, colors, top_selector),
226+
modality_present_chart(db, colors, color_list, modality_selector),
227+
)
213228

214229

215230
top_row = pn.bind(
@@ -228,24 +243,42 @@ def build_row(selected_modality, derived_filter):
228243
# Put everything in a column and buffer it
229244
main_col = pn.Column(top_row, mid_plot, styles=outer_style, width=515)
230245

231-
main_row = pn.Row(pn.HSpacer(), left_col, pn.Spacer(width=20), main_col, pn.HSpacer(), margin=20)
246+
main_row = pn.Row(
247+
pn.HSpacer(),
248+
left_col,
249+
pn.Spacer(width=20),
250+
main_col,
251+
pn.HSpacer(),
252+
margin=20,
253+
)
232254

233255
# Add the validator search section
234-
validator_name_selector = pn.widgets.TextInput(name="Enter asset name to validate:", value="", placeholder="Asset name", width=800)
256+
validator_name_selector = pn.widgets.TextInput(
257+
name="Enter asset name to validate:",
258+
value="",
259+
placeholder="Asset name",
260+
width=800,
261+
)
235262
pn.state.location.sync(validator_name_selector, {"value": "validator_name"})
236263

237264
validator = database.RecordValidator(validator_name_selector.value, colors)
238265

239266

240267
def build_validator(validator_name):
241268
validator.update(validator_name)
242-
col = pn.Column(validator_name_selector, validator.panel(), width=(515+20+420), styles=outer_style)
269+
col = pn.Column(
270+
validator_name_selector,
271+
validator.panel(),
272+
width=(515 + 20 + 420),
273+
styles=outer_style,
274+
)
243275
row = pn.Row(pn.HSpacer(), col, pn.HSpacer())
244276
return row
245277

246278

247-
validator_row = pn.bind(build_validator,
248-
validator_name=validator_name_selector)
279+
validator_row = pn.bind(
280+
build_validator, validator_name=validator_name_selector
281+
)
249282

250283
pn.Column(main_row, validator_row).servable(
251284
title="Metadata Portal",

src/aind_metadata_viz/charts.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ def file_present_chart(db, colors, selector):
1111
local_states = sum_longform_df["state"].unique()
1212
local_color_list = [colors[state] for state in local_states]
1313

14-
file_selection = alt.selection_point(fields=['file'], empty='none', name='file', value=selector.value)
14+
file_selection = alt.selection_point(
15+
fields=["file"], empty="none", name="file", value=selector.value
16+
)
1517

1618
chart = (
1719
alt.Chart(sum_longform_df)
@@ -40,8 +42,9 @@ def file_present_chart(db, colors, selector):
4042

4143
def update_selection(event):
4244
if len(event.new) > 0:
43-
selector.value = event.new[0]['file']
44-
pane.selection.param.watch(update_selection, 'file')
45+
selector.value = event.new[0]["file"]
46+
47+
pane.selection.param.watch(update_selection, "file")
4548

4649
return pane
4750

@@ -55,7 +58,12 @@ def modality_present_chart(db, colors, color_list, selector):
5558
df_list.append(sum_longform_df)
5659
df = pd.concat(df_list)
5760

58-
modality_selection = alt.selection_point(fields=['modality'], empty='all', name='modality', value=(selector.value if selector.value != "all" else None))
61+
modality_selection = alt.selection_point(
62+
fields=["modality"],
63+
empty="all",
64+
name="modality",
65+
value=(selector.value if selector.value != "all" else None),
66+
)
5967

6068
chart = (
6169
alt.Chart(df)
@@ -93,9 +101,10 @@ def modality_present_chart(db, colors, color_list, selector):
93101

94102
def update_selection(event):
95103
if len(event.new) > 0:
96-
selector.value = event.new[0]['modality']
104+
selector.value = event.new[0]["modality"]
97105
else:
98106
selector.value = "all"
99-
pane.selection.param.watch(update_selection, 'modality')
100107

101-
return pane
108+
pane.selection.param.watch(update_selection, "modality")
109+
110+
return pane

0 commit comments

Comments
 (0)