Skip to content

Commit

Permalink
Update: add a demo
Browse files Browse the repository at this point in the history
  • Loading branch information
ignorejjj committed Jun 15, 2024
1 parent 3d20a75 commit 267f0c4
Show file tree
Hide file tree
Showing 7 changed files with 196 additions and 6 deletions.
31 changes: 27 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ FlashRAG is still under development and there are many issues and room for impro

## :page_with_curl: Changelog

[24/06/15] We provide a [<u>demo</u>](./examples/quick_start/demo_en.py) to perform the RAG process using our toolkit.

[24/06/11] We have integrated `sentence transformers` in the retriever module. Now it's easier to use the retriever without setting pooling methods.

[24/06/05] We have provided detailed document for reproducing existing methods (see [how to reproduce](./docs/reproduce_experiment.md), [baseline details](./docs/baseline_details.md)), and [<u>configurations settings</u>](./docs/configuration.md).
Expand Down Expand Up @@ -95,20 +97,41 @@ conda install -c pytorch -c nvidia faiss-gpu=1.8.0

For beginners, we provide a [<u>an introduction to flashrag</u>](./docs/introduction_for_beginners_en.md) ([<u>中文版</u>](./docs/introduction_for_beginners_zh.md)) to help you familiarize yourself with our toolkit. Alternatively, you can directly refer to the code below.

#### Demo

We provide a toy demo to implement a simple RAG process. You can freely change the corpus and model you want to use. The English demo uses [general knowledge](https://huggingface.co/datasets/MuskumPillerum/General-Knowledge) as the corpus, `e5-base-v2` as the retriever, and `Llama3-8B-instruct` as generator. The Chinese demo uses data crawled from the official website of Remin University of China as the corpus, `bge-large-zh-v1.5` as the retriever, and qwen1.5-14B as the generator. Please fill in the corresponding path in the file.

To run the demo:

```bash
cd examples/quick_start
# run english demo
streamlit run demo_en.py

# run chinese demo
streamlit run demo_zh.py
```

<figure class="half">
<img src="./asset/demo_en.gif" >
<img src="./asset/demo_zh.gif" >
</figure>

#### Pipeline

We also provide an example to use our framework for pipeline execution.
Run the following code to implement a naive RAG pipeline using provided toy datasets.
The default retriever is `e5` and default generator is `llama2-7B-chat`. You need to fill in the corresponding model path in the following command. If you wish to use other models, please refer to the detailed instructions below.
The default retriever is `e5-base-v2` and default generator is `Llama3-8B-instruct`. You need to fill in the corresponding model path in the following command. If you wish to use other models, please refer to the detailed instructions below.

```bash
cd examples/quick_start
python simple_pipeline.py \
--model_path=<LLAMA2-7B-Chat-PATH> \
--model_path=<Llama-3-8B-instruct-PATH> \
--retriever_path=<E5-PATH>
```

After the code is completed, you can view the intermediate results of the run and the final evaluation score in the output folder under the corresponding path.

**Note:** This toy example is just to help test whether the entire process can run normally. Our toy retrieval document only contains 1000 pieces of data, so it may not yield good results.

### Using the ready-made pipeline

You can use the pipeline class we have already built (as shown in [<u>pipelines</u>](#pipelines)) to implement the RAG process inside. In this case, you just need to configure the config and load the corresponding pipeline.
Expand Down
Binary file added asset/demo_en.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added asset/demo_zh.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
83 changes: 83 additions & 0 deletions examples/quick_start/demo_en.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import streamlit as st
from flashrag.config import Config
from flashrag.utils import get_retriever, get_generator
from flashrag.prompt import PromptTemplate


config_dict = {"save_note":"demo",
'model2path': {'e5': 'intfloat/e5-base-v2', 'llama3-8B-instruct': 'meta-llama/Meta-Llama-3-8B-Instruct'},
"retrieval_method":"e5",
'generator_model': 'llama3-8B-instruct',
"corpus_path":"indexses/general_knowledge.jsonl",
"index_path":"indexses/e5_Flat.index"}

@st.cache_resource
def load_retriever(_config):
return get_retriever(_config)

@st.cache_resource
def load_generator(_config):
return get_generator(_config)

custom_theme = {
"primaryColor": "#ff6347",
"backgroundColor": "#f0f0f0",
"secondaryBackgroundColor": "#d3d3d3",
"textColor": "#121212",
"font": "sans serif"
}
st.set_page_config(page_title="FlashRAG Demo", page_icon="⚡")


st.sidebar.title("Configuration")
temperature = st.sidebar.slider("Temperature:", 0.01, 1.0, 0.5)
topk = st.sidebar.slider("Number of retrieved documents:", 1, 10, 5)
max_new_tokens = st.sidebar.slider("Max generation tokens:", 1, 2048, 256)


st.title("⚡FlashRAG Demo")
st.write("This demo retrieves documents and generates responses based on user input.")

query = st.text_area("Enter your prompt:")

config = Config('my_config.yaml', config_dict=config_dict)
retriever = load_retriever(config)
generator = load_generator(config)

system_prompt_rag = "You are a friendly AI Assistant." \
"Respond to the input as a friendly AI assistant, generating human-like text, and follow the instructions in the input if applicable." \
"\nThe following are provided references. You can use them for answering question.\n\n{reference}"
system_prompt_no_rag = "You are a friendly AI Assistant." \
"Respond to the input as a friendly AI assistant, generating human-like text, and follow the instructions in the input if applicable.\n"
base_user_prompt = "{question}"

prompt_template_rag = PromptTemplate(config, system_prompt=system_prompt_rag, user_prompt=base_user_prompt)
prompt_template_no_rag = PromptTemplate(config, system_prompt=system_prompt_no_rag, user_prompt=base_user_prompt)


if st.button("Generate Responses"):
with st.spinner("Retrieving and Generating..."):
retrieved_docs = retriever.search(query,num=topk)

st.subheader("References",divider='gray')
for i, doc in enumerate(retrieved_docs):
doc_title = doc.get('title','No Title')
doc_text = "\n".join(doc['contents'].split("\n")[1:])
expander = st.expander(f"**[{i+1}]: {doc_title}**", expanded=False)
with expander:
st.markdown(doc_text, unsafe_allow_html=True)

st.subheader("Generated Responses:",divider='gray')

input_prompt_with_rag = prompt_template_rag.get_string(question=query, retrieval_result=retrieved_docs)
response_with_rag = generator.generate(input_prompt_with_rag,
temperature=temperature,
max_new_tokens=max_new_tokens)[0]
st.subheader("Response with RAG:")
st.write(response_with_rag)
input_prompt_without_rag = prompt_template_no_rag.get_string(question=query)
response_without_rag = generator.generate(input_prompt_without_rag,
temperature=temperature,
max_new_tokens=max_new_tokens)[0]
st.subheader("Response without RAG:")
st.markdown(response_without_rag)
83 changes: 83 additions & 0 deletions examples/quick_start/demo_zh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import streamlit as st
from flashrag.config import Config
from flashrag.utils import get_retriever, get_generator
from flashrag.prompt import PromptTemplate

config_dict = {"save_note":"demo",
"generator_model":"qwen-14B",
"retrieval_method":"bge-zh",
"model2path": {"bge-zh":"BAAI/bge-large-zh-v1.5", "qwen-14B":"Qwen/Qwen1.5-14B-Chat"},
"corpus_path":"/data00/jiajie_jin/rd_corpus.jsonl",
"index_path":"/data00/jiajie_jin/flashrag_indexes/rd_corpus/bge_Flat.index"}

@st.cache_resource
def load_retriever(_config):
return get_retriever(_config)

@st.cache_resource
def load_generator(_config):
return get_generator(_config)

custom_theme = {
"primaryColor": "#ff6347",
"backgroundColor": "#f0f0f0",
"secondaryBackgroundColor": "#d3d3d3",
"textColor": "#121212",
"font": "sans serif"
}
st.set_page_config(page_title="FlashRAG Demo", page_icon="⚡")


st.sidebar.title("Configuration")
temperature = st.sidebar.slider("Temperature:", 0.01, 1.0, 0.5)
topk = st.sidebar.slider("Number of retrieved documents:", 1, 10, 5)
max_new_tokens = st.sidebar.slider("Max generation tokens:", 1, 2048, 256)


st.title("⚡FlashRAG Demo")
st.write("This demo retrieves documents and generates responses based on user input.")


query = st.text_area("Enter your prompt:")

config = Config('my_config.yaml', config_dict=config_dict)
retriever = load_retriever(config)
generator = load_generator(config)

system_prompt_rag = "你是一个友好的人工智能助手。" \
"请对用户的输出做出高质量的响应,生成类似于人类的内容,并尽量遵循输入中的指令。" \
"\n下面是一些可供参考的文档,你可以使用它们来回答问题。\n\n{reference}"
system_prompt_no_rag = "你是一个友好的人工智能助手。" \
"请对用户的输出做出高质量的响应,生成类似于人类的内容,并尽量遵循输入中的指令。\n"
base_user_prompt = "{question}"

prompt_template_rag = PromptTemplate(config, system_prompt=system_prompt_rag, user_prompt=base_user_prompt)
prompt_template_no_rag = PromptTemplate(config, system_prompt=system_prompt_no_rag, user_prompt=base_user_prompt)


if st.button("Generate Responses"):
with st.spinner("Retrieving and Generating..."):
retrieved_docs = retriever.search(query,num=topk)

st.subheader("References",divider='gray')
for i, doc in enumerate(retrieved_docs):
doc_title = doc.get('title','No Title')
doc_text = "\n".join(doc['contents'].split("\n")[1:])
expander = st.expander(f"**[{i+1}]: {doc_title}**", expanded=False)
with expander:
st.markdown(doc_text, unsafe_allow_html=True)

st.subheader("Generated Responses:",divider='gray')

input_prompt_with_rag = prompt_template_rag.get_string(question=query, retrieval_result=retrieved_docs)
response_with_rag = generator.generate(input_prompt_with_rag,
temperature=temperature,
max_new_tokens=max_new_tokens)[0]
st.subheader("Response with RAG:")
st.write(response_with_rag)
input_prompt_without_rag = prompt_template_no_rag.get_string(question=query)
response_without_rag = generator.generate(input_prompt_without_rag,
temperature=temperature,
max_new_tokens=max_new_tokens)[0]
st.subheader("Response without RAG:")
st.markdown(response_without_rag)
4 changes: 2 additions & 2 deletions examples/quick_start/simple_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
'data_dir': 'dataset/',
'index_path': 'indexes/e5_Flat.index',
'corpus_path': 'indexes/general_knowledge.jsonl',
'model2path': {'e5': args.retriever_path, 'llama2-7B-chat': args.model_path},
'generator_model': 'llama2-7B-chat',
'model2path': {'e5': args.retriever_path, 'llama3-8B-instruct': args.model_path},
'generator_model': 'llama3-8B-instruct',
'retrieval_method': 'e5',
'metrics': ['em','f1','sub_em'],
'retrieval_topk': 1,
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ transformers>=4.40.0
vllm>=0.4.1
voyageai
sentence-transformers
streamlit

0 comments on commit 267f0c4

Please sign in to comment.