diff --git a/README.md b/README.md
index 34c1d3b..7ba7ab1 100644
--- a/README.md
+++ b/README.md
@@ -60,6 +60,8 @@ FlashRAG is still under development and there are many issues and room for impro
 ## :page_with_curl: Changelog
 
+[24/06/15] We provide a [demo](./examples/quick_start/demo_en.py) to perform the RAG process using our toolkit.
+
 [24/06/11] We have integrated `sentence transformers` into the retriever module. Now it's easier to use the retriever without setting pooling methods.
 
 [24/06/05] We have provided detailed documentation for reproducing existing methods (see [how to reproduce](./docs/reproduce_experiment.md), [baseline details](./docs/baseline_details.md)) and [configuration settings](./docs/configuration.md).
@@ -95,20 +97,41 @@ conda install -c pytorch -c nvidia faiss-gpu=1.8.0
 For beginners, we provide [an introduction to FlashRAG](./docs/introduction_for_beginners_en.md) ([中文版](./docs/introduction_for_beginners_zh.md)) to help you familiarize yourself with our toolkit. Alternatively, you can directly refer to the code below.
 
+#### Demo
+
+We provide a toy demo that implements a simple RAG process. You can freely change the corpus and models you want to use. The English demo uses [general knowledge](https://huggingface.co/datasets/MuskumPillerum/General-Knowledge) as the corpus, `e5-base-v2` as the retriever, and `Llama3-8B-instruct` as the generator. The Chinese demo uses data crawled from the official website of Renmin University of China as the corpus, `bge-large-zh-v1.5` as the retriever, and `Qwen1.5-14B` as the generator. Please fill in the corresponding paths in the files.
+
+To run the demo:
+
+```bash
+cd examples/quick_start
+# run the English demo
+streamlit run demo_en.py
+
+# run the Chinese demo
+streamlit run demo_zh.py
+```
+
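+Before launching, fill in the model, corpus, and index paths in the `config_dict` at the top of each demo script. For reference, the English demo's configuration looks like this (the HuggingFace IDs shown are the defaults it assumes; local paths work as well):
+
+```python
+# Excerpt from examples/quick_start/demo_en.py; edit these paths for your setup.
+config_dict = {"save_note": "demo",
+               'model2path': {'e5': 'intfloat/e5-base-v2', 'llama3-8B-instruct': 'meta-llama/Meta-Llama-3-8B-Instruct'},
+               "retrieval_method": "e5",
+               'generator_model': 'llama3-8B-instruct',
+               "corpus_path": "indexes/general_knowledge.jsonl",
+               "index_path": "indexes/e5_Flat.index"}
+```
+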
+<p align="center"><img src="./asset/demo_en.gif" alt="English demo"></p>
+<p align="center"><img src="./asset/demo_zh.gif" alt="Chinese demo"></p>
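+
+Both demos expect a FAISS index built over a JSONL corpus whose `contents` field starts with the document title on its first line. If you have not built an index yet, the sketch below uses the toolkit's index builder; the flag names are assumed from the toolkit's documentation, so check them against your installed version:
+
+```bash
+python -m flashrag.retriever.index_builder \
+    --retrieval_method e5 \
+    --model_path intfloat/e5-base-v2 \
+    --corpus_path indexes/general_knowledge.jsonl \
+    --save_dir indexes/ \
+    --use_fp16 \
+    --pooling_method mean \
+    --faiss_type Flat
+```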
+
+#### Pipeline
+
+We also provide an example of using our framework for pipeline execution. Run the following code to implement a naive RAG pipeline on the provided toy dataset.
 
-The default retriever is `e5` and default generator is `llama2-7B-chat`. You need to fill in the corresponding model path in the following command. If you wish to use other models, please refer to the detailed instructions below.
+The default retriever is `e5-base-v2` and the default generator is `Llama3-8B-instruct`. You need to fill in the corresponding model paths in the following command. If you wish to use other models, please refer to the detailed instructions below.
 
 ```bash
 cd examples/quick_start
 python simple_pipeline.py \
-    --model_path=<Llama-2-7B-Chat-PATH> \
+    --model_path=<Llama-3-8B-instruct-PATH> \
     --retriever_path=<E5-PATH>
 ```
 
 After the code finishes running, you can view the intermediate results and the final evaluation score in the output folder under the corresponding path.
 
-**Note:** This toy example is just to help test whether the entire process can run normally. Our toy retrieval document only contains 1000 pieces of data, so it may not yield good results.
-
 ### Using the ready-made pipeline
 
 You can use the pipeline classes we have already built (as shown in [pipelines](#pipelines)) to implement the RAG process. In this case, you just need to set up the config and load the corresponding pipeline, as sketched below.
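+
+A minimal sketch of that flow, assuming a filled-in `my_config.yaml` and the toy dataset from above (`SequentialPipeline` is the ready-made retrieve-then-generate pipeline; see the pipelines section for the other classes):
+
+```python
+from flashrag.config import Config
+from flashrag.utils import get_dataset
+from flashrag.pipeline import SequentialPipeline
+
+# Load model, index, and corpus paths plus evaluation metrics from the config file.
+config = Config('my_config.yaml')
+
+# Load the dataset splits declared in the config and pick the test split.
+all_split = get_dataset(config)
+test_data = all_split['test']
+
+# Run retrieval, prompting, and generation, then evaluate with the configured metrics.
+pipeline = SequentialPipeline(config)
+output_dataset = pipeline.run(test_data, do_eval=True)
+print(output_dataset.pred)
+```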
diff --git a/asset/demo_en.gif b/asset/demo_en.gif
new file mode 100644
index 0000000..252aaaf
Binary files /dev/null and b/asset/demo_en.gif differ
diff --git a/asset/demo_zh.gif b/asset/demo_zh.gif
new file mode 100644
index 0000000..977489f
Binary files /dev/null and b/asset/demo_zh.gif differ
diff --git a/examples/quick_start/demo_en.py b/examples/quick_start/demo_en.py
new file mode 100644
index 0000000..65b3bf6
--- /dev/null
+++ b/examples/quick_start/demo_en.py
@@ -0,0 +1,83 @@
+import streamlit as st
+from flashrag.config import Config
+from flashrag.utils import get_retriever, get_generator
+from flashrag.prompt import PromptTemplate
+
+
+config_dict = {"save_note": "demo",
+               'model2path': {'e5': 'intfloat/e5-base-v2', 'llama3-8B-instruct': 'meta-llama/Meta-Llama-3-8B-Instruct'},
+               "retrieval_method": "e5",
+               'generator_model': 'llama3-8B-instruct',
+               "corpus_path": "indexes/general_knowledge.jsonl",
+               "index_path": "indexes/e5_Flat.index"}
+
+@st.cache_resource
+def load_retriever(_config):
+    return get_retriever(_config)
+
+@st.cache_resource
+def load_generator(_config):
+    return get_generator(_config)
+
+custom_theme = {
+    "primaryColor": "#ff6347",
+    "backgroundColor": "#f0f0f0",
+    "secondaryBackgroundColor": "#d3d3d3",
+    "textColor": "#121212",
+    "font": "sans serif"
+}
+st.set_page_config(page_title="FlashRAG Demo", page_icon="⚡")
+
+
+st.sidebar.title("Configuration")
+temperature = st.sidebar.slider("Temperature:", 0.01, 1.0, 0.5)
+topk = st.sidebar.slider("Number of retrieved documents:", 1, 10, 5)
+max_new_tokens = st.sidebar.slider("Max generation tokens:", 1, 2048, 256)
+
+
+st.title("⚡FlashRAG Demo")
+st.write("This demo retrieves documents and generates responses based on user input.")
+
+query = st.text_area("Enter your prompt:")
+
+config = Config('my_config.yaml', config_dict=config_dict)
+retriever = load_retriever(config)
+generator = load_generator(config)
+
+system_prompt_rag = "You are a friendly AI Assistant. " \
+                    "Respond to the input as a friendly AI assistant, generating human-like text, and follow the instructions in the input if applicable." \
+                    "\nThe following are provided references. You can use them to answer the question.\n\n{reference}"
+system_prompt_no_rag = "You are a friendly AI Assistant. " \
+                       "Respond to the input as a friendly AI assistant, generating human-like text, and follow the instructions in the input if applicable.\n"
+base_user_prompt = "{question}"
+
+prompt_template_rag = PromptTemplate(config, system_prompt=system_prompt_rag, user_prompt=base_user_prompt)
+prompt_template_no_rag = PromptTemplate(config, system_prompt=system_prompt_no_rag, user_prompt=base_user_prompt)
+
+
+if st.button("Generate Responses"):
+    with st.spinner("Retrieving and Generating..."):
+        retrieved_docs = retriever.search(query, num=topk)
+
+        st.subheader("References", divider='gray')
+        for i, doc in enumerate(retrieved_docs):
+            doc_title = doc.get('title', 'No Title')
+            doc_text = "\n".join(doc['contents'].split("\n")[1:])
+            expander = st.expander(f"**[{i+1}]: {doc_title}**", expanded=False)
+            with expander:
+                st.markdown(doc_text, unsafe_allow_html=True)
+
+        st.subheader("Generated Responses:", divider='gray')
+
+        input_prompt_with_rag = prompt_template_rag.get_string(question=query, retrieval_result=retrieved_docs)
+        response_with_rag = generator.generate(input_prompt_with_rag,
+                                               temperature=temperature,
+                                               max_new_tokens=max_new_tokens)[0]
+        st.subheader("Response with RAG:")
+        st.write(response_with_rag)
+        input_prompt_without_rag = prompt_template_no_rag.get_string(question=query)
+        response_without_rag = generator.generate(input_prompt_without_rag,
+                                                  temperature=temperature,
+                                                  max_new_tokens=max_new_tokens)[0]
+        st.subheader("Response without RAG:")
+        st.markdown(response_without_rag)
diff --git a/examples/quick_start/demo_zh.py b/examples/quick_start/demo_zh.py
new file mode 100644
index 0000000..387d169
--- /dev/null
+++ b/examples/quick_start/demo_zh.py
@@ -0,0 +1,83 @@
+import streamlit as st
+from flashrag.config import Config
+from flashrag.utils import get_retriever, get_generator
+from flashrag.prompt import PromptTemplate
+
+config_dict = {"save_note": "demo",
+               "generator_model": "qwen-14B",
+               "retrieval_method": "bge-zh",
+               "model2path": {"bge-zh": "BAAI/bge-large-zh-v1.5", "qwen-14B": "Qwen/Qwen1.5-14B-Chat"},
+               "corpus_path": "/data00/jiajie_jin/rd_corpus.jsonl",
+               "index_path": "/data00/jiajie_jin/flashrag_indexes/rd_corpus/bge_Flat.index"}
+
+@st.cache_resource
+def load_retriever(_config):
+    return get_retriever(_config)
+
+@st.cache_resource
+def load_generator(_config):
+    return get_generator(_config)
+
+custom_theme = {
+    "primaryColor": "#ff6347",
+    "backgroundColor": "#f0f0f0",
+    "secondaryBackgroundColor": "#d3d3d3",
+    "textColor": "#121212",
+    "font": "sans serif"
+}
+st.set_page_config(page_title="FlashRAG Demo", page_icon="⚡")
+
+
+st.sidebar.title("Configuration")
+temperature = st.sidebar.slider("Temperature:", 0.01, 1.0, 0.5)
+topk = st.sidebar.slider("Number of retrieved documents:", 1, 10, 5)
+max_new_tokens = st.sidebar.slider("Max generation tokens:", 1, 2048, 256)
+
+
+st.title("⚡FlashRAG Demo")
+st.write("This demo retrieves documents and generates responses based on user input.")
+
+
+query = st.text_area("Enter your prompt:")
+
+config = Config('my_config.yaml', config_dict=config_dict)
+retriever = load_retriever(config)
+generator = load_generator(config)
+
+system_prompt_rag = "你是一个友好的人工智能助手。" \
+                    "请对用户的输入做出高质量的响应,生成类似于人类的内容,并尽量遵循输入中的指令。" \
+                    "\n下面是一些可供参考的文档,你可以使用它们来回答问题。\n\n{reference}"
+system_prompt_no_rag = "你是一个友好的人工智能助手。" \
+                       "请对用户的输入做出高质量的响应,生成类似于人类的内容,并尽量遵循输入中的指令。\n"
+base_user_prompt = "{question}"
+
+prompt_template_rag = PromptTemplate(config, system_prompt=system_prompt_rag, user_prompt=base_user_prompt)
+prompt_template_no_rag = PromptTemplate(config, system_prompt=system_prompt_no_rag, user_prompt=base_user_prompt)
+
+
+if st.button("Generate Responses"):
+    with st.spinner("Retrieving and Generating..."):
+        retrieved_docs = retriever.search(query, num=topk)
+
+        st.subheader("References", divider='gray')
+        for i, doc in enumerate(retrieved_docs):
+            doc_title = doc.get('title', 'No Title')
+            doc_text = "\n".join(doc['contents'].split("\n")[1:])
+            expander = st.expander(f"**[{i+1}]: {doc_title}**", expanded=False)
+            with expander:
+                st.markdown(doc_text, unsafe_allow_html=True)
+
+        st.subheader("Generated Responses:", divider='gray')
+
+        input_prompt_with_rag = prompt_template_rag.get_string(question=query, retrieval_result=retrieved_docs)
+        response_with_rag = generator.generate(input_prompt_with_rag,
+                                               temperature=temperature,
+                                               max_new_tokens=max_new_tokens)[0]
+        st.subheader("Response with RAG:")
+        st.write(response_with_rag)
+        input_prompt_without_rag = prompt_template_no_rag.get_string(question=query)
+        response_without_rag = generator.generate(input_prompt_without_rag,
+                                                  temperature=temperature,
+                                                  max_new_tokens=max_new_tokens)[0]
+        st.subheader("Response without RAG:")
+        st.markdown(response_without_rag)
diff --git a/examples/quick_start/simple_pipeline.py b/examples/quick_start/simple_pipeline.py
index 3624d1f..5a0ee97 100644
--- a/examples/quick_start/simple_pipeline.py
+++ b/examples/quick_start/simple_pipeline.py
@@ -13,8 +13,8 @@
     'data_dir': 'dataset/',
     'index_path': 'indexes/e5_Flat.index',
     'corpus_path': 'indexes/general_knowledge.jsonl',
-    'model2path': {'e5': args.retriever_path, 'llama2-7B-chat': args.model_path},
-    'generator_model': 'llama2-7B-chat',
+    'model2path': {'e5': args.retriever_path, 'llama3-8B-instruct': args.model_path},
+    'generator_model': 'llama3-8B-instruct',
     'retrieval_method': 'e5',
     'metrics': ['em','f1','sub_em'],
     'retrieval_topk': 1,
diff --git a/requirements.txt b/requirements.txt
index 92ec766..3315a59 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -18,3 +18,4 @@ transformers>=4.40.0
 vllm>=0.4.1
 voyageai
 sentence-transformers
+streamlit