Text2sql tool - e2e evals and fine-tuning #967
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,30 +1,16 @@ | ||
| ## Text2SQL: Natural Language to SQL Interface | ||
| ## Text2SQL: Eval and Fine-tuning Tools and Quick Start Notebook | ||
|
|
||
| This project provides a set of scripts to convert natural language queries into SQL statements using Meta's Llama model. The goal is to enable users to interact with databases using natural language inputs, making it easier for non-technical users to access and analyze data. | ||
| This folder contains the `tool` subfolder, which has e2e scripts for evaluating Llama (original and fine-tuned) models on the Text2SQL task using the popular [BIRD](https://bird-bench.github.io) dataset, and e2e scripts for generating fine-tuning datasets and fine-tuning Llama 3.1 8B with the datasets. | ||
|
|
||
| For detailed instructions on setting up the environment, creating a database, and executing natural language queries using the Text2SQL interface, please refer to the quickstart.ipynb notebook. | ||
| Before looking into the `tool` folder, you may want to start with the scripts and notebook in this folder to get familiar with how to interact with a database using natural language inputs by asking Llama to convert natural language queries into SQL queries. | ||
|
|
||
| For detailed instructions on setting up the environment, creating a database, and executing natural language queries using the Text2SQL interface, please refer to the [quickstart.ipynb](quickstart.ipynb) notebook. | ||
|
|
||
| ### Structure: | ||
|
|
||
| - tool: A folder containing scripts for evaluating and fine-tuning Llama models on the Text2SQL task. | ||
| - quickstart.ipynb: A Quick Demo of Text2SQL Using Llama 3.3. This Jupyter Notebook includes examples of how to use the interface to execute natural language queries on the sample data. It uses Llama 3.3 to answer questions about a SQLite database using LangChain and the Llama cloud provider Together.ai (see the sketch after this list). | ||
| - nba.txt: A text file containing NBA roster information, which is used as sample data for demonstration purposes. | ||
| - txt2csv.py: A script that converts text data into a CSV format. This script is used to preprocess the input data before it is fed into csv2db.py. | ||
| - csv2db.py: A script that imports data from a CSV file into a SQLite database. This script is used to populate the database with sample data. | ||
| - nba_roster.db: A SQLite database file created from the nba.txt data, used to test the Text2SQL interface. | ||
|
|
||
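For orientation, below is a rough, hedged sketch of the kind of LangChain flow the quickstart notebook demonstrates: asking a Together-hosted Llama model to turn a question about `nba_roster.db` into SQL and then running that SQL. The package names (`langchain-community`, `langchain-together`), the model ID, and the example question are assumptions for illustration; the notebook itself is the reference.

```python
# Hedged sketch of the quickstart flow -- not the notebook's exact code.
# Assumes: pip install langchain langchain-community langchain-together
# and TOGETHER_API_KEY set in the environment.
from langchain_community.utilities import SQLDatabase
from langchain_together import ChatTogether
from langchain.chains import create_sql_query_chain

# The SQLite database built from nba.txt via txt2csv.py and csv2db.py.
db = SQLDatabase.from_uri("sqlite:///nba_roster.db")

# A Llama model served by Together.ai (model ID is an assumption).
llm = ChatTogether(model="meta-llama/Llama-3.3-70B-Instruct-Turbo", temperature=0)

# Chain that converts a natural language question into a SQL query for this DB.
chain = create_sql_query_chain(llm, db)
sql = chain.invoke({"question": "Which team does Stephen Curry play for?"})
print(sql)          # the generated SQL
print(db.run(sql))  # execute it against nba_roster.db
```

Depending on library versions, the chain's output may need light post-processing (for example, stripping a leading `SQLQuery:` prefix) before execution.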
| ### Detailed steps on running the notebook: | ||
|
|
||
| - Before getting started, please make sure to set up Together.ai and get an API key from [here](https://www.together.ai/). | ||
|
|
||
| - First, install the requirements from [here](https://github.com/meta-llama/llama-cookbook/blob/main/end-to-end-use-cases/coding/text2sql/requirements.txt) by running the following commands: | ||
|
|
||
| ``` | ||
| git clone https://github.com/meta-llama/llama-cookbook.git | ||
| cd llama-cookbook/end-to-end-use-cases/coding/text2sql/ | ||
| pip install -r requirements.txt | ||
| ``` | ||
|
|
||
| ### Contributing | ||
| Contributions are welcome! If you'd like to add new features or improve existing ones, please submit a pull request. We encourage contributions in the following areas: | ||
| - Adding support for additional databases | ||
| - Developing new interfaces or applications that use the Text2SQL interface |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,223 @@ | ||
| # Text2SQL Evaluation and Fine-Tuning Tools for Llama Models | ||
|
|
||
| ## Overview | ||
|
|
||
| This folder contains scripts for evaluating Llama (original and fine-tuned) models on the Text2SQL task using the popular [BIRD](https://bird-bench.github.io) dataset, and scripts for generating fine-tuning datasets and fine-tuning Llama 3.1 8B with the datasets. | ||
|
|
||
| We have updated and significantly simplified the original eval scripts from the BIRD [repo](https://github.com/AlibabaResearch/DAMO-ConvAI/tree/main/bird) for Llama 3 & 4 models hosted via Meta's [Llama API](https://llama.developer.meta.com) or [Together.ai](https://together.ai), as well as the fine-tuned Llama 3.1 model, so you can quickly evaluate, in three simple steps, how well different Llama models perform on the Text2SQL task. | ||
|
|
||
| We have also provided end-to-end scripts for generating datasets (with and without reasoning steps) and fine-tuning the quantized Llama 3.1 8B model to gain a **165% (with no reasoning) and 209% (with reasoning) accuracy improvement** over the original model. | ||
|
|
||
| ## Llama Text2SQL Evaluation Results | ||
|
||
|
|
||
| Below are the results of the Llama models we have evaluated on the BIRD DEV dataset: | ||
|
|
||
| | Model | Llama API Accuracy | Together Accuracy | | ||
| |------------------------|--------------------|-------------------| | ||
| | Llama 3.1 8b | - | 35.66% | | ||
| | Llama 3.3 70b | 54.11% | 54.63% | | ||
| | Llama-3.1-405B | - | 55.80% | | ||
| | Llama 4 Scout | 44.39% | 43.94% | | ||
| | Llama 4 Maverick | 44.00% | 41.46% | | ||
|
|
||
| - Llama 3.1 8b quantized model: 14.02% (original) | ||
| - Fine-tuned with no reasoning dataset: 37.16% | ||
| - Fine-tuned with reasoning dataset: 43.37% | ||
|
|
||
| ## Quick Start on Evaluating Llama on Text2SQL | ||
|
||
|
|
||
| First, run the commands below to create a new Conda environment and install all the required packages for Text2SQL evaluation and fine-tuning: | ||
|
|
||
| ``` | ||
| git clone https://github.com/meta-llama/llama-cookbook | ||
| cd llama-cookbook/end-to-end-use-cases/coding/text2sql/tool | ||
| conda create -n llama-text2sql python=3.10 | ||
| conda activate llama-text2sql | ||
| pip install -r requirements.txt | ||
| ``` | ||
|
|
||
| Then, follow the steps below to evaluate Llama 3 & 4 models on Text2SQL using the BIRD benchmark: | ||
|
|
||
| 1. Get the DEV dataset: | ||
| ``` | ||
| cd data | ||
| sh download_dev_unzip.sh | ||
| ``` | ||
|
|
||
| 2. Open `llama_eval.sh` and set `YOUR_API_KEY` to your [Llama API](https://llama.developer.meta.com/) key or [Together](https://api.together.ai/) API key, then uncomment a line that starts with `model=` to specify the Llama model to use for the text2sql eval. | ||
|
|
||
| 3. Run the evaluation script `sh llama_eval.sh`, which will use the BIRD DEV dataset (1534 examples in total) with external knowledge turned on to run the Llama model on each text question and compare the generated SQL with the gold SQL. | ||
|
|
||
| *Note:* If your API key or model name is incorrect, the script will exit with an authentication error or a "model not supported" error. | ||
|
|
||
| After the script completes, you'll see the accuracy of the Llama model on the BIRD DEV text2sql. For example, the total accuracy is about 54.24% with `YOUR_API_KEY` set to your Llama API key and `model='Llama-3.3-70B-Instruct'`, or about 35.07% with `YOUR_API_KEY` set to your Together API key and `model=meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo`. | ||
|
|
||
| *Note:* To compare the accuracy of your selected Llama model with other results on the BIRD Dev leaderboard, click [here](https://bird-bench.github.io/). | ||
|
|
||
| ## Evaluation Process | ||
|
|
||
| 1. **SQL Generation**: `llama_text2sql.py` sends natural language questions to the specified Llama model and collects the generated SQL queries. | ||
|
|
||
| 2. **SQL Execution**: `text2sql_eval.py` executes both the generated SQL and ground truth SQL against the corresponding databases, then continues with steps 3 and 4 below. | ||
|
|
||
| 3. **Result Comparison**: The results from executing the generated SQL are compared ([source code](text2sql_eval.py#L30)) with the results from the ground truth SQL to determine correctness. | ||
|
|
||
| 4. **Accuracy Calculation**: Accuracy scores are calculated overall and broken down by difficulty levels (simple, moderate, challenging); a minimal sketch of this execution-match check follows the list. | ||
|
|
||
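To make the comparison concrete, here is a minimal, hedged sketch of the execution-match idea: run the generated and gold SQL against the example's SQLite database and count the prediction as correct only if the two result sets match. Function and variable names are illustrative; `text2sql_eval.py` is the actual implementation.

```python
# Illustrative sketch of execution-based accuracy -- see text2sql_eval.py for the real logic.
import sqlite3

def execute_sql(db_path: str, sql: str):
    """Run a SQL statement against a SQLite database and return all rows."""
    conn = sqlite3.connect(db_path)
    try:
        return conn.execute(sql).fetchall()
    finally:
        conn.close()

def is_correct(db_path: str, predicted_sql: str, gold_sql: str) -> bool:
    """A prediction is correct if its result set matches the gold SQL's result set."""
    try:
        predicted_rows = execute_sql(db_path, predicted_sql)
    except Exception:
        return False  # SQL that fails to execute counts as wrong
    gold_rows = execute_sql(db_path, gold_sql)
    return set(predicted_rows) == set(gold_rows)

def accuracy(examples) -> float:
    """examples: iterable of (db_path, predicted_sql, gold_sql) tuples."""
    examples = list(examples)
    correct = sum(is_correct(*ex) for ex in examples)
    return 100.0 * correct / len(examples)
```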
| ## Supported Models | ||
|
|
||
| ### Together AI Models | ||
| - meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo | ||
| - meta-llama/Llama-3.3-70B-Instruct-Turbo | ||
| - meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8 | ||
| - meta-llama/Llama-4-Scout-17B-16E-Instruct | ||
| - other Llama models hosted on Together AI | ||
|
|
||
| ### Llama API Models | ||
| - Llama-3.3-8B-Instruct | ||
| - Llama-3.3-70B-Instruct | ||
| - Llama-4-Maverick-17B-128E-Instruct-FP8 | ||
| - Llama-4-Scout-17B-16E-Instruct-FP8 | ||
| - other Llama models hosted on Llama API | ||
|
|
||
| ## Fine-tuning with the BIRD TRAIN dataset (No Reasoning) | ||
|
|
||
| We'll first use the BIRD TRAIN dataset to prepare for supervised fine-tuning with no reasoning info in the dataset. | ||
|
|
||
| ### Using the TRAIN dataset to prepare for supervised fine-tuning | ||
|
|
||
| 1. Get the TRAIN dataset: | ||
| ``` | ||
| cd data | ||
| sh download_train_unzip.sh | ||
| ``` | ||
|
|
||
| 2. Create the dataset | ||
|
|
||
| ``` | ||
| cd fine_tuning | ||
| python create_sft_dataset.py --input_json ../data/train/train.json --db_root_path ../data/train/train_databases | ||
| ``` | ||
|
|
||
| This will create `train_text2sql_sft_dataset.json` and `test_text2sql_sft_dataset.json` using the TRAIN set. Each line in the json files is in the conversation format ready for fine-tuning: | ||
|
|
||
| ``` | ||
| {"messages":[{"content":"You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, translate the following text question into a SQLite SQL select statement.","role":"system"},{"content":"-- DB Schema: <DB_SCHEMA>\n\n-- External Knowledge: <KNOWLEDGE_FROM_TRAIN>\n\n-- Question: <TEXT_QUESTION>","role":"user"},{"content":"<GOLD_SQL>","role":"assistant"}]} | ||
| ``` | ||
|
|
||
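For illustration, one record in that format could be assembled from a single BIRD TRAIN example roughly as follows. The field names (`question`, `evidence`, `SQL`, `db_id`) follow BIRD's `train.json`; `get_db_schema` is a hypothetical helper standing in for the schema extraction that `create_sft_dataset.py` performs.

```python
# Hedged sketch of building one SFT record -- create_sft_dataset.py is the reference.
import json

SYSTEM_PROMPT = (
    "You are a text to SQL query translator. Using the SQLite DB Schema and the External "
    "Knowledge, translate the following text question into a SQLite SQL select statement."
)

def build_sft_record(example: dict, db_schema: str) -> dict:
    """example: one entry from BIRD's train.json; db_schema: the schema dump of its database."""
    user_content = (
        f"-- DB Schema: {db_schema}\n\n"
        f"-- External Knowledge: {example.get('evidence', '')}\n\n"
        f"-- Question: {example['question']}"
    )
    return {
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": example["SQL"]},
        ]
    }

# One JSON object per line, matching the train/test files described above
# (get_db_schema is hypothetical -- it would dump the CREATE TABLE statements for example["db_id"]).
# with open("train_text2sql_sft_dataset.json", "w") as f:
#     for ex in train_examples:
#         f.write(json.dumps(build_sft_record(ex, get_db_schema(ex["db_id"]))) + "\n")
```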
| ### Supervised Fine-tuning (No Reasoning) | ||
|
|
||
| First, log in to Hugging Face (run `huggingface-cli login` and enter your [HF token](https://huggingface.co/settings/tokens)) and make sure you have been granted access to the [Llama 3.1 8B Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct) model. | ||
|
|
||
| Then run `python trl_sft.py`. After the fine-tuning completes, you'll see the fine-tuned model saved to `llama31-8b-text2sql-fine-tuned`, specified in `output_dir="llama31-8b-text2sql-fine-tuned"` of `TrainingArguments` in `trl_sft.py`. | ||
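For orientation, the core of a TRL-based QLoRA fine-tuning setup along these lines might look roughly like the sketch below. This is not `trl_sft.py` itself: the exact arguments (LoRA rank, batch sizes, `SFTConfig` vs. `TrainingArguments`) depend on the TRL version and on the script, so treat every value here as an assumption.

```python
# Hedged sketch of QLoRA SFT with TRL -- trl_sft.py is the reference implementation.
import torch
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer

model_id = "meta-llama/Llama-3.1-8B-Instruct"

# 4-bit quantization so the 8B model fits a single-GPU memory budget (see the note further below).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    model_id, quantization_config=bnb_config, device_map="auto"
)

# LoRA adapter trained on top of the frozen, quantized base model (rank/alpha are illustrative).
peft_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05, task_type="CAUSAL_LM")

# The dataset created above: one {"messages": [...]} conversation per line.
dataset = load_dataset("json", data_files="train_text2sql_sft_dataset.json", split="train")

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    args=SFTConfig(
        output_dir="llama31-8b-text2sql-fine-tuned",
        num_train_epochs=1,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        logging_steps=10,
        report_to="tensorboard",
    ),
)
trainer.train()
trainer.save_model()  # writes the adapter to output_dir
```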
|
|
||
| After running `tensorboard --logdir ./llama31-8b-text2sql-fine_tuning`, you can open `http://localhost:6006` to see the train loss chart, etc.: | ||
|
|
||
|  | ||
|
|
||
|
|
||
| ### Evaluating the fine-tuned model (No Reasoning) | ||
|
|
||
| First, modify `llama_eval.sh` to use the fine-tuned model: | ||
|
|
||
| ``` | ||
| YOUR_API_KEY='finetuned' | ||
| model='fine_tuning/llama31-8b-text2sql-fine-tuned' | ||
| ``` | ||
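With `YOUR_API_KEY='finetuned'`, the eval script runs the model locally rather than calling a hosted API. Below is a hedged sketch of what that local inference could look like, assuming the checkpoint is a PEFT/LoRA adapter saved on top of the 4-bit quantized base; if `trl_sft.py` saved a fully merged model instead, a plain `AutoModelForCausalLM.from_pretrained(model_path)` would suffice. `llama_text2sql.py` is the reference for how the script actually does it.

```python
# Hypothetical sketch of local inference with the fine-tuned adapter.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

base_id = "meta-llama/Llama-3.1-8B-Instruct"
adapter_path = "fine_tuning/llama31-8b-text2sql-fine-tuned"  # the local model path set in llama_eval.sh

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
base = AutoModelForCausalLM.from_pretrained(base_id, quantization_config=bnb_config, device_map="auto")
model = PeftModel.from_pretrained(base, adapter_path)
tokenizer = AutoTokenizer.from_pretrained(base_id)

messages = [
    {"role": "system", "content": "You are a text to SQL query translator. ..."},
    {"role": "user", "content": "-- DB Schema: ...\n\n-- External Knowledge: ...\n\n-- Question: ..."},
]
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
output = model.generate(input_ids, max_new_tokens=512, do_sample=False)
print(tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True))
```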
|
|
||
| Then run `sh llama_eval.sh` to evaluate the fine-tuned model. The accuracy on the BIRD DEV dataset is about 37.16%, a 165% improvement over the model before fine-tuning, which scores about 14.02% on the same dataset. You can confirm this by evaluating the original model yourself: modify `llama_eval.sh` to use the original model: | ||
|
|
||
| ``` | ||
| YOUR_API_KEY='huggingface' | ||
| model='meta-llama/Llama-3.1-8B-Instruct' | ||
| ``` | ||
|
|
||
| Then run `sh llama_eval.sh` to evaluate the original model. | ||
|
|
||
| *Note:* We are using the 4-bit quantized Llama 3.1 8b model to reduce the memory footprint and improve efficiency (as shown in the code snippet of `llama_text2sql.py` below), hence the accuracy of the quantized version (14.02%) is significantly lower than the accuracy of the original Llama 3.1 8b (35.66%). | ||
|
|
||
| ``` | ||
| bnb_config = BitsAndBytesConfig( | ||
| load_in_4bit=True, | ||
| bnb_4bit_use_double_quant=True, | ||
| bnb_4bit_quant_type="nf4", | ||
| bnb_4bit_compute_dtype=torch.bfloat16, | ||
| ) | ||
| ``` | ||
|
|
||
| ## Fine-tuning with the BIRD TRAIN dataset (With Reasoning) | ||
|
|
||
| Next, we'll use the BIRD TRAIN dataset to prepare for supervised fine-tuning with reasoning info included in the dataset. The goal is to see whether adding the reasoning info improves the accuracy of the fine-tuned model. | ||
|
|
||
| ### Creating a reasoning dataset from the TRAIN dataset | ||
|
||
|
|
||
| The script `create_reasoning_dataset.py` is used to create a reasoning dataset from the TRAIN dataset by asking Llama 3.3 70B to generate the reasoning for each text question and its corresponding gold SQL. The intent is to use the reasoning dataset to fine-tune the Llama model to improve the accuracy of the generated SQL. | ||
|
|
||
| To run the script, use the following commands: | ||
| ``` | ||
| cd fine_tuning | ||
| python create_reasoning_dataset.py --input_json ../data/train/train.json --db_root_path ../data/train/train_databases | ||
| ``` | ||
|
|
||
| This will create a `text2sql_cot_dataset` dataset and `train_text2sql_cot_dataset.json` in the conversation format ready for fine-tuning. Each example in the dataset is generated from the code snippet below: | ||
|
|
||
| ``` | ||
| prompt = f""" | ||
| -- DB Schema: {db_schema} | ||
| -- External Knowledge: {external_knowledge} | ||
| -- Text Question: {question} | ||
| """ | ||
| cot = { | ||
| "messages": [ | ||
| { | ||
| "role": "system", | ||
| "content": "You are a text to SQL query translator. Using the SQLite DB Schema and the External Knowledge, generate the step-by-step reasoning and the final SQLite SQL select statement from the text question.", | ||
| }, | ||
| {"role": "user", "content": prompt}, | ||
| {"role": "assistant", "content": reasoning}, | ||
| ] | ||
| } | ||
| ``` | ||
|
|
||
| The prompt for Llama 3.3 70B to generate the `reasoning` above is: | ||
| ``` | ||
| You are a text to SQL query translator. Based on the DB Schema and External Knowledge, given the Text Question Input and its Gold SQL Output below, generate the step-by-step reasoning to infer the Gold SQL Output from the Text Question Input. | ||
|
|
||
| -- DB Schema: {db_schema} | ||
| -- External Knowledge: {external_knowledge} | ||
| -- Text Question Input: {question} | ||
| -- Gold SQL Output: {gold_SQL} | ||
|
|
||
| Your response should be as follows:\n\n | ||
| Let me think through this step by step:\n\n1. First, I need to consider...\n2. Then...\n3. Next...\n...\n\nFinally, the SQL statement for the text question is: | ||
| ```sql ...```\n | ||
|
|
||
| """ | ||
| ``` | ||
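One hedged way to obtain `reasoning` with that prompt is an OpenAI-style chat completion call to a hosted Llama 3.3 70B, for example via the Together SDK; the client setup, model ID, and helper name below are assumptions, and `create_reasoning_dataset.py` is the actual script.

```python
# Illustrative reasoning generation -- see create_reasoning_dataset.py for the actual logic.
import os
from together import Together  # assumes `pip install together` and TOGETHER_API_KEY being set

client = Together(api_key=os.environ["TOGETHER_API_KEY"])

def generate_reasoning(db_schema: str, external_knowledge: str, question: str, gold_sql: str) -> str:
    """Ask Llama 3.3 70B to explain, step by step, how the gold SQL follows from the question."""
    # The full prompt shown above also includes the response-format instructions, elided here.
    prompt = f"""You are a text to SQL query translator. Based on the DB Schema and External Knowledge, given the Text Question Input and its Gold SQL Output below, generate the step-by-step reasoning to infer the Gold SQL Output from the Text Question Input.

-- DB Schema: {db_schema}
-- External Knowledge: {external_knowledge}
-- Text Question Input: {question}
-- Gold SQL Output: {gold_sql}
"""
    response = client.chat.completions.create(
        model="meta-llama/Llama-3.3-70B-Instruct-Turbo",
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
    )
    return response.choices[0].message.content
```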
|
|
||
| ### Supervised Fine-tuning (With Reasoning) | ||
|
|
||
| Uncomment the line `# FT_DATASET = "train_text2sql_cot_dataset.json"` in `trl_sft.py` to use the reasoning dataset for fine-tuning, then run `python trl_sft.py`. After the fine-tuning completes, the fine-tuned model is saved to `llama31-8b-text2sql-fine-tuned`, as specified in `output_dir="llama31-8b-text2sql-fine-tuned"` of `TrainingArguments` in `trl_sft.py`. You may want to rename the `output_dir` folder beforehand to avoid overwriting the previously fine-tuned model. | ||
|
|
||
| The train loss chart will look like this: | ||
|  | ||
|
|
||
| ### Evaluating the fine-tuned model (With Reasoning) | ||
|
|
||
| First, modify `llama_eval.sh` to use the fine-tuned model, which should match the `output_dir` in `TrainingArguments` in `trl_sft.py`: | ||
|
|
||
| ``` | ||
| YOUR_API_KEY='finetuned' | ||
| model='fine_tuning/llama31-8b-text2sql-fine-tuned' | ||
| ``` | ||
|
|
||
| Then uncomment the line `SYSTEM_PROMPT` [here](https://github.com/meta-llama/llama-cookbook/blob/text2sql/end-to-end-use-cases/coding/text2sql/tool/llama_text2sql.py#L31) in `llama_text2sql.py` to use it with the model fine-tuned on the reasoning dataset. | ||
|
|
||
| Now run `sh llama_eval.sh`, which will take longer because the reasoning needs to be generated before the SQL. The accuracy this time is 43.37%, compared with 37.16% without reasoning - another 16% relative improvement over the model fine-tuned without reasoning. | ||
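Because the reasoning-tuned model emits its step-by-step reasoning followed by a fenced `sql` code block, the final SQL has to be extracted from the response before it can be executed against the database. A minimal, hypothetical extraction helper is sketched below; the actual parsing lives in the eval scripts and may differ.

```python
# Hypothetical helper for pulling the final SQL out of a reasoning-style response.
import re

def extract_sql(response: str) -> str:
    """Return the last ```sql ... ``` block if present, otherwise the raw response."""
    matches = re.findall(r"```sql\s*(.*?)```", response, flags=re.DOTALL | re.IGNORECASE)
    if matches:
        return matches[-1].strip()  # the final SQL statement after the reasoning steps
    return response.strip()

example = ("Let me think through this step by step:\n1. ...\n"
           "Finally, the SQL statement for the text question is:\n"
           "```sql\nSELECT name FROM players WHERE team = 'GSW'\n```")
print(extract_sql(example))  # SELECT name FROM players WHERE team = 'GSW'
```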
|
|
||
| ## Next Steps | ||
| 1. Add a Colab notebook for fine-tuning and evaluation. | ||
| 2. Try reinforcement fine-tuning to improve the accuracy further with reasoning. | ||
| 3. Use torchtune for full and non-quantized fine-tuning of Llama 3.3 70b and Llama 4 models. | ||
| 4. Introduce an agentic workflow to try to improve the accuracy further. | ||
| 5. Expand the tool to support other databases. | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| wget https://bird-bench.oss-cn-beijing.aliyuncs.com/dev.zip | ||
| unzip dev.zip | ||
| rm dev.zip | ||
| rm -rf __MACOSX | ||
| cd dev_20240627 | ||
| unzip dev_databases.zip | ||
| rm dev_databases.zip | ||
| rm -rf __MACOSX | ||
| cd .. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| wget https://bird-bench.oss-cn-beijing.aliyuncs.com/train.zip | ||
| UNZIP_DISABLE_ZIPBOMB_DETECTION=TRUE unzip train.zip | ||
| rm train.zip | ||
| rm -rf __MACOSX | ||
| cd train | ||
| unzip train_databases.zip | ||
| rm train_databases.zip | ||
| rm -rf __MACOSX | ||
| cd .. |
@jeffxtang are we positioning it as eval tool? then we need to find the narrative on delta between this tool and other leaderboards. Also, if the messaging is around a recipe on FT model for text2sql then benchmarks comes after that, we need to re-organize the narrative.
It's positioned as an eval tool, based on a popular and very active (latest submission in the leaderboard is May 30 2025) text2sql benchmark, on Llama 3 & 4 models including fine-tuned ones, but the end goal is to figure out how to improve the accuracy as much as possible using fine-tuning etc (in Next Steps), with the help of the eval tool.
In short, "no eval, no success" + "eval only != success".