[Update] add example & evaluation code & date of ANAH-v2 (#7)
* add anahv2 example & evaluation & update readme

* Add ANAH-v2 Data & Track large file with Git LFS
Liqu1d-G authored Dec 11, 2024
1 parent 9691b55 commit 7d2bf01
Showing 12 changed files with 750 additions and 9 deletions.
1 change: 1 addition & 0 deletions .gitattributes
@@ -0,0 +1 @@
eval/anah_v2/question_document.jsonl filter=lfs diff=lfs merge=lfs -text
43 changes: 34 additions & 9 deletions README.md
@@ -8,7 +8,7 @@ The repo contains:

+ The [data](#huggingface-dataset) for training and evaluating the LLMs, which consists of sentence-level hallucination annotations.
+ The [model](#huggingface-model) for annotating the hallucination.
+ The [code](#evaluation) for evaluating the LLMs' ability to annotate hallucination.
+ The [code](#evaluation) for evaluating the hallucination level of LLM-generated content and the LLMs' ability to annotate hallucination.


## 🚀 What's New
@@ -60,7 +60,8 @@ The ANAH dataset is available on Huggingface dataset hub.
| Dataset | Huggingface Repo |
|---------|------------------|
| ANAH | [Dataset Link](https://huggingface.co/datasets/opencompass/anah) |
| ANAH-v2 | [Will Open Source] |

We also release the topics, questions, and reference documents of ANAH-v2, which you can find [here](https://github.com/open-compass/ANAH/blob/main/eval/anah_v2/question_document.jsonl).

<a name="huggingface-model"></a>
### Model
@@ -82,14 +83,38 @@ You have to follow the prompt in our paper to annotate the hallucination. Note t

We recommend using the more advanced annotator ANAH-v2; its prompt can be found [here](https://github.com/open-compass/ANAH/blob/main/prompt_v2.py).

The models follow the conversation format of InternLM2-chat, with the template protocol as:
We also provide some [examples](https://github.com/open-compass/ANAH/blob/main/example) of using the ANAH-v2 annotator, which you can refer to when annotating your own content; a minimal usage sketch follows the template block below.

```python
dict(role='user', begin='<|im_start|>user\n', end='<|im_end|>\n'),
dict(role='assistant', begin='<|im_start|>assistant\n', end='<|im_end|>\n'),
```
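
If you want to call the annotator directly, here is a minimal sketch of the first annotation step (fact checking). It assumes the `opencompass/anah-v2` checkpoint is loaded via `transformers` with `trust_remote_code=True` and exposes the InternLM2-style `chat` call used in `eval/anah_v2/eval.py`; the question and sentence are hypothetical placeholders.

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Prompt builders shipped in eval/anah_v2/anahv2_prompt.py
# (run this from that directory or add it to PYTHONPATH).
from anahv2_prompt import fact_check_prompt

# Load the annotator the same way eval/anah_v2/eval.py does.
path = "opencompass/anah-v2"
tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16, trust_remote_code=True).cuda()
model.eval()

# Hypothetical inputs: a question and one sentence of the response you want to annotate.
question = "Who proposed the theory of general relativity?"
sentence = "The theory of general relativity was proposed by Albert Einstein in 1915."

# Step 1 of the pipeline (fact checking); reference checking and hallucination checking chain on in the same way.
messages = [{"role": "user", "content": fact_check_prompt(question, sentence, "en")}]
response, _ = model.chat(tokenizer, messages)
print(response)  # expected: "<Facts Present>" or "<No Facts>"
```
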
<a name="evaluation"></a>
## 🏗️ ️Evaluation

## 🏗️ Evaluation for the hallucination level of LLM-generated content

ANAH-v2 is a hallucination annotator that can be used to assess the level of hallucination in LLM-generated content.

### 1. Response Generation

For the models you want to evaluate, collect their responses to a set of questions. We recommend using the questions from the [ANAH-v2 dataset](https://github.com/open-compass/ANAH/blob/main/eval/anah_v2/question_document.jsonl), but you can also use your own custom questions. Then construct your model response file in the following format:

```json
{"question": "...", "response": "..."}
{"question": "...", "response": "..."}
```
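
For reference, here is a minimal sketch that builds such a file from the released question list; `generate_answer` and the output path `model_responses.jsonl` are hypothetical placeholders for your own generation code.

```python
import json

def generate_answer(question: str) -> str:
    """Hypothetical stand-in for querying the model under evaluation."""
    return "..."  # replace with a real call to your model

# Read the released questions (one JSON object per line) and collect model responses.
with open("eval/anah_v2/question_document.jsonl", encoding="utf-8") as fin, \
     open("model_responses.jsonl", "w", encoding="utf-8") as fout:
    for line in fin:
        question = json.loads(line)["question"]
        record = {"question": question, "response": generate_answer(question)}
        fout.write(json.dumps(record, ensure_ascii=False) + "\n")
```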


### 2. Hallucination Score Evaluation

Put the path of the model response file you just built into `{your_model_response_path}` and run the following command. The hallucination annotation results are written to `{your_annotation_result_path}` and the factuality score (a higher score means a lower level of hallucination) to `{your_evaluation_result_path}`.

```bash
python -u ./eval/anah_v2/eval.py \
--json_path {your_model_response_path} \
--annotation_path {your_annotation_result_path} \
--eval_path {your_evaluation_result_path}
```
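
For concreteness, the factuality score in the output CSV is derived as in `evaluate_output` of `eval/anah_v2/eval.py`: each annotated sentence is counted as `ok`, `nofact`, `contradictory`, or `unverifiable`, and the score is the percentage of sentences that are not hallucinated. The counts below are purely illustrative:

```python
# Purely illustrative counts -- the real ones are tallied from the annotation file.
ok, nofact, contradictory, unverifiable = 70, 10, 8, 12

total = ok + nofact + contradictory + unverifiable  # 100 annotated sentences
hallu = contradictory + unverifiable                # 20 hallucinated sentences
score = (1 - hallu / total) * 100                   # factuality score: 80.00
print(f"{score:.2f}")
```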

Note that if you are using customized questions, you will need to prepare your own `question_document` file and pass it via `--document_path`. You can refer to the format of [this file](https://github.com/open-compass/ANAH/blob/main/eval/anah_v2/question_document.jsonl) to organize your file, as sketched below.
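
Based on the fields that `eval/anah_v2/eval.py` reads from this file (`question`, `document`, and `language`, where `language` is `"en"` or `"zh"`), a minimal sketch for writing a custom file could look as follows; the entry itself is a hypothetical placeholder:

```python
import json

# Hypothetical entry: each question is paired with its reference document and language ("en" or "zh").
entries = [
    {
        "question": "Who proposed the theory of general relativity?",
        "document": "General relativity was published by Albert Einstein in 1915.",
        "language": "en",
    },
]

# One JSON object per line, the layout eval/anah_v2/eval.py expects for --document_path.
with open("custom_question_document.jsonl", "w", encoding="utf-8") as f:
    for entry in entries:
        f.write(json.dumps(entry, ensure_ascii=False) + "\n")
```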

## 🏗️ Evaluation for LLMs' ability to generate fine-grained hallucination annotation

ANAH can be used to evaluate current open-source and closed-source LLMs' ability to generate fine-grained hallucination annotations.

@@ -114,7 +139,7 @@ We recommend you download the huggingface model to your local path and replace t
Our evaluations are conducted on NVIDIA A100 GPUs, and OOM may occur on other types of machines.

```bash
python -u ./eval/eval.py \
python -u ./eval/anah_v1/eval.py \
--model_type {your_model_type} \
--server_addr {your_hf_model_path} \
--json_path {test_set_path} \
File renamed without changes.
File renamed without changes.
161 changes: 161 additions & 0 deletions eval/anah_v2/anahv2_prompt.py
@@ -0,0 +1,161 @@
def fact_check_prompt(question, annotation, language):
    cn_user_prompt = f"""
你将作为一个事实判断器,我会给你提供一个问题和一个针对该问题的部分回答,你的任务是判断回答中的内容是否存在可以判断的事实。
## 判断标准:
- **可以判断的事实:** 具体的、客观的信息点,这些信息可以通过数据、研究结果或其他可靠来源进行验证。例如,统计数据、历史事件、科学定律、具体案例等。
- **非事实描述:** 个人意见、主观判断或无法验证的声明。
## 任务流程:
### 1. **仔细阅读问题,问题如下:**
{question}
### 2. **仔细阅读回答,部分回答如下:**
{annotation}
### 3. **进行分析:** 根据上述判断标准,判断回答中是否包含可以判断的事实。
- 如果回答中不存在可以判断的事实,则输出“<无事实>”。
- 如果回答中存在可以判断的事实,则输出“<有事实>”。
"""

    en_user_prompt = f"""
You will act as a fact checker, and I will provide you with a question and a corresponding partial answer. Your task is to determine whether the content of the answer contains verifiable facts.
## Judgment Criteria:
- **Verifiable Facts:** Specific, objective points of information that can be verified through data, research results, or other reliable sources. Examples include statistical data, historical events, scientific laws, and specific case studies.
- **Non-factual Descriptions:** Personal opinions, subjective judgments, or unverifiable statements.
## Task Process:
### 1. **Carefully read the question, which is as follows:**
{question}
### 2. **Carefully read the partial answer, which is as follows:**
{annotation}
### 3. **Conduct the Analysis:** Based on the above judgment criteria, determine if the answer contains verifiable facts.
- If there are no verifiable facts in the answer, output “<No Facts>”.
- If there are verifiable facts in the answer, output “<Facts Present>”.
"""
    return cn_user_prompt if language == "zh" else en_user_prompt


def reference_check_prompt(question, reference, annotation, language):
    cn_user_prompt = f"""
你将作为一个信息提取器,我将给你提供一个问题、一份相关的参考文档,以及一个针对该问题的部分回答,你的任务是从参考文档中提炼出与问题和回答相关的信息。
## 操作步骤:
### 1. **仔细阅读问题,问题如下:**
{question}
### 2. **仔细阅读回答,部分回答如下:**
{annotation}
### 3. **分析参考文档:** 找出与问题和回答最相关的信息,这些信息可能与回答内容完全相同、部分相同,或存在冲突。
**参考文档如下:**
{reference}
### 4. **列出相关信息:** 按顺序列出所有发现的相关信息,如果有多条信息的话以 <SEP> 作为分隔。
### 5. **无相关信息时输出:** 如果没有找到相关信息,请输出<无参考信息>。
"""

    en_user_prompt = f"""
You will act as an information extractor. I will provide you with a question, a related reference document, and a partial answer to that question. Your task is to extract information from the reference document that is relevant to the question and answer.
## Operational Steps:
### 1. **Carefully read the question, which is as follows:**
{question}
### 2. **Carefully read the partial answer, which is as follows:**
{annotation}
### 3. **Analyze the Reference Document:** Identify information most relevant to the question and answer. This information may be completely the same, partially similar, or conflicting with the content of the answer.
**The reference document is as follows:**
{reference}
### 4. **List the Relevant Information:** List all the relevant information found in order, separated by <SEP> if there are multiple pieces of information.
### 5. **Output When No Information Is Found:** If no relevant information is found, output <No Reference Information>.
"""

    return cn_user_prompt if language == "zh" else en_user_prompt


def hallucination_check_prompt(question, reference, annotation, language):
    cn_user_prompt = f"""
你将作为一个‘幻觉’标注器,我将会给你提供一个一个问题,一个针对该问题的部分回答和相关的参考要点。你需要判断提供的回答中是否含有幻觉性内容,并标注幻觉类型。
‘幻觉’指的是与参考要点相矛盾或在参考要点中没有依据的内容。
## 判断准则:
1. **无幻觉:** 如果回答与参考要点完全一致,且没有引入与参考要点相矛盾的信息,请输出:<无幻觉>。
2. **矛盾:** 如果回答内容与参考要点存在明显矛盾,请输出:<矛盾>。
3. **无法验证:** 如果回答包含的信息在参考要点中没有提及,且无法从参考要点中得到支持或验证,请输出:<无法验证>。
## 任务流程:
### 1. **仔细阅读问题,问题如下:**
{question}
### 2. **仔细阅读回答,部分回答如下:**
{annotation}
### 3. **仔细阅读参考要点,参考要点如下:**
{reference}
### 4. **进行分析:** 根据上述判断标准,判断回答中是否包含幻觉,并输出幻觉类型。
"""

    en_user_prompt = f"""
You will act as a 'Hallucination' annotator. I will provide you with a question, a partial answer to that question, and related reference points. You need to determine whether the provided answer contains any hallucinatory content and annotate the type of hallucination.
'Hallucination' refers to content that contradicts the reference points or is unsupported by them.
## Judgment Criteria:
1. **No Hallucination:** If the answer is completely consistent with the reference points and does not introduce any contradictory information, output: <No Hallucination>.
2. **Contradiction:** If the answer clearly contradicts the reference points, output: <Contradictory>.
3. **Unverifiable:** If the answer contains information not mentioned in the reference points and cannot be supported or verified by them, output: <Unverifiable>.
## Task Process:
### 1. **Carefully read the question, which is as follows:**
{question}
### 2. **Carefully read the partial answer, which is as follows:**
{annotation}
### 3. **Carefully read the reference points, which are as follows:**
{reference}
### 4. **Conduct the analysis:** Based on the above judgment criteria, determine if the answer contains hallucinations and output the type of hallucination.
"""

    return cn_user_prompt if language == "zh" else en_user_prompt
122 changes: 122 additions & 0 deletions eval/anah_v2/eval.py
@@ -0,0 +1,122 @@
import os
import subprocess
import argparse
import csv
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from utils import get_lines_from_path, write_lines_to_path, sentence_tokenize
from anahv2_prompt import fact_check_prompt, reference_check_prompt, hallucination_check_prompt

def process_question(model, tokenizer, question: str, sentence: str, document: str, language: str) -> str:
    """Process the question through multiple steps: fact-checking, reference-checking, hallucination-checking."""
    # Step 1: Fact-checking
    fact_check = fact_check_prompt(question, sentence, language)
    messages = [{"role": "user", "content": fact_check}]
    response, _ = model.chat(tokenizer, messages)

    if response == "<No Facts>" or response == "<无事实>":
        return "nofact"

    messages.append({"role": "assistant", "content": response})

    # Step 2: Reference-checking
    reference_check = reference_check_prompt(question, document, sentence, language)
    messages.append({"role": "user", "content": reference_check})
    response, _ = model.chat(tokenizer, messages)

    response_tmp = response.strip().replace(" ", "").lower()
    if "noreferenceinformation" in response_tmp or "无参考信息" in response_tmp:
        return "unverifiable"

    reference = response
    messages.append({"role": "assistant", "content": reference})

    # Step 3: Hallucination-checking
    hallucination_check = hallucination_check_prompt(question, reference, sentence, language)
    messages.append({"role": "user", "content": hallucination_check})
    response, _ = model.chat(tokenizer, messages)

    hallucination_type = response.strip().replace(" ", "").lower()
    if "nohallucination" in hallucination_type or "无幻觉" in hallucination_type:
        return "ok"
    elif "contradictory" in hallucination_type or "矛盾" in hallucination_type:
        return "contradictory"
    elif "unverifiable" in hallucination_type or "无法验证" in hallucination_type:
        return "unverifiable"


def evaluate_output(args):
    lines = get_lines_from_path(args.annotation_path)

    total, contradictory, unverifiable, ok, nofact = 0, 0, 0, 0, 0
    for line in lines:
        total += 1
        annotation = line["annotation"]
        if annotation == "contradictory":
            contradictory += 1
        if annotation == "unverifiable":
            unverifiable += 1
        if annotation == "ok":
            ok += 1
        if annotation == "nofact":
            nofact += 1

    hallu = contradictory + unverifiable

    with open(args.eval_path, "w") as f:
        writer = csv.writer(f)
        writer.writerow(["ok", "nofact", "contradictory", "unverifiable", "hallucination", "total", "score"])
        writer.writerow([
            f"{(ok/total)*100:.2f}%",
            f"{(nofact/total)*100:.2f}%",
            f"{(contradictory/total)*100:.2f}%",
            f"{(unverifiable/total)*100:.2f}%",
            hallu,
            total,
            f"{(1 - hallu/total)*100:.2f}",
        ])


def run(args):
    # Load the ANAH-v2 annotator.
    path = 'opencompass/anah-v2'
    tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.float16, trust_remote_code=True).cuda()
    model.eval()

    # Build a question -> (reference document, language) lookup from --document_path
    # (defaults to the released eval/anah_v2/question_document.jsonl).
    old_lines = get_lines_from_path(args.document_path)
    ref_lines = dict()
    for line in old_lines:
        ref_lines[line["question"]] = dict(document=line["document"], language=line["language"])

    # Annotate every sentence of every model response.
    lines = get_lines_from_path(args.json_path)
    new_lines = []
    for line in lines:
        response = line["response"]
        question = line["question"]
        language = ref_lines[question]["language"]
        document = ref_lines[question]["document"]
        sentences = sentence_tokenize(response, language, keep_end=False)
        for sent in sentences:
            annotation = process_question(model, tokenizer, question, sent, document, language)
            new_lines.append({"question": question, "response": response, "sentence": sent, "annotation": annotation, "language": language})

    write_lines_to_path(args.annotation_path, new_lines)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--json_path", type=str)
    parser.add_argument("--document_path", default="eval/anah_v2/question_document.jsonl", type=str)
    parser.add_argument("--annotation_path", type=str)
    parser.add_argument("--eval_path", type=str)
    args = parser.parse_args()

    if os.path.exists(args.eval_path):
        print("Already evaluated")
        exit(0)

    if not os.path.exists(args.annotation_path):
        run(args)

    evaluate_output(args)

3 changes: 3 additions & 0 deletions eval/anah_v2/question_document.jsonl
Git LFS file not shown