Skip to content

Commit

Permalink
add report validation on group narratives
Browse files Browse the repository at this point in the history
  • Loading branch information
dayesouza committed Apr 16, 2024
1 parent 86b9b5b commit 81025bd
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
2 changes: 2 additions & 0 deletions app/workflows/group_narratives/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,7 @@ def __init__(self, prefix):
self.narrative_top_groups = SessionVariable(0, prefix)
self.narrative_top_attributes = SessionVariable(0, prefix)
self.narrative_report = SessionVariable('', prefix)
self.narrative_report_validation_messages = SessionVariable('', prefix)
self.narrative_report_validation = SessionVariable({}, prefix)
self.narrative_system_prompt = SessionVariable(prompts.list_prompts, prefix)
self.narrative_subject_identifier = SessionVariable('', prefix)
20 changes: 20 additions & 0 deletions app/workflows/group_narratives/workflow.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
# Copyright (c) 2024 Microsoft Corporation. All rights reserved.
import json
import streamlit as st
import pandas as pd

import workflows.group_narratives.prompts as prompts
import workflows.group_narratives.config as config
import workflows.group_narratives.variables as vars
from util.session_variables import SessionVariables

import util.AI_API
import util.ui_components
Expand All @@ -13,6 +15,7 @@
def create():
workflow = 'group_narratives'
sv = vars.SessionVariables('group_narratives')
sv_home = SessionVariables('home')

intro_tab, prepare_tab, summarize_tab, generate_tab = st.tabs(['Group narratives workflow:', 'Upload data to narrate', 'Prepare data summary', 'Generate AI group reports',])

Expand Down Expand Up @@ -199,6 +202,7 @@ def create():

narrative_placeholder = st.empty()
gen_placeholder = st.empty()
get_current_time = pd.Timestamp.now().strftime('%Y%m%d%H%M%S')
if generate:
sv.narrative_selected_groups.value = selected_groups
sv.narrative_top_groups.value = top_group_ranks
Expand All @@ -208,10 +212,26 @@ def create():
prefix=''
)
sv.narrative_report.value = result

validation, messages_to_llm = util.ui_components.validate_ai_report(messages, result)
sv.narrative_report_validation.value = json.loads(validation)
sv.narrative_report_validation_messages.value = messages_to_llm
st.rerun()
else:
if sv.narrative_report.value == '':
gen_placeholder.warning('Press the Generate button to create an AI report for the selected groups.')
narrative_placeholder.markdown(sv.narrative_report.value)

util.ui_components.report_download_ui(sv.narrative_report, 'group_report')
if sv.narrative_report_validation.value != {}:
validation_status = st.status(label=f"LLM faithfulness score: {sv.narrative_report_validation.value['score']}/5", state='complete')
with validation_status:
st.write(sv.narrative_report_validation.value['explanation'])

if sv_home.mode.value == 'dev':
obj = json.dumps({
"message": sv.narrative_report_validation_messages.value,
"result": sv.narrative_report_validation.value,
"report": sv.narrative_report.value
}, indent=4)
st.download_button('Download validation prompt', use_container_width=True, data=str(obj), file_name=f'attr_pattern_{get_current_time}_messages.json', mime='text/json')

0 comments on commit 81025bd

Please sign in to comment.