Skip to content

Commit be89fef

Browse files
authored
Merge pull request #14 from DataArcTech/main
merge main
2 parents ec609ea + 23f5a0b commit be89fef

File tree

5 files changed

+27
-64
lines changed

5 files changed

+27
-64
lines changed

examples/TCL_rag/config.yaml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,33 @@
11
llm:
22
name: openai
3-
base_url: "https://api.gptsapi.net/v1"
4-
api_key: "sk-REDACTED"  <!-- secret redacted; key exposed in commit history must be rotated -->
3+
base_url: "xxx"
4+
api_key: "xxx"
55
model: "gpt-4.1-mini"
66

77
embedding:
88
name: huggingface
9-
model_name: "/finance_ML/dataarc_syn_database/model/Qwen/qwen_embedding_0.6B"
9+
model_name: "xxx"
1010
model_kwargs:
1111
device: "cuda:0"
1212

1313

1414

1515
store:
1616
name: faiss
17-
folder_path: /data/FinAi_Mapping_Knowledge/chenmingzhen/test_faiss_store
17+
folder_path: xxx
1818

1919

2020
bm25:
2121
name: bm25
2222
k: 10
23-
data_path: /data/FinAi_Mapping_Knowledge/chenmingzhen/tog3_backend/TCL/syn_table_data/data_all_clearn_short_chunk_with_caption_desc.json
23+
data_path: xxx
2424

2525
retriever:
2626
name: vectorstore
2727

2828
reranker:
2929
name: qwen3
30-
model_name_or_path: "/finance_ML/dataarc_syn_database/model/Qwen/qwen_reranker_0.6B"
30+
model_name_or_path: "xxx"
3131
device_id: "cuda:0"
3232

3333
dataset:

examples/TCL_rag/test.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,8 @@
2424
vector_store_config=vector_store_config,
2525
bm25_retriever_config=bm25_retriever_config)
2626

27-
result = rag.invoke("毛细管设计规范按照什么标准",k=20)
27+
result = rag.invoke("模块机传感器端子不防呆的改善方案是什么?由哪个部门负责?",k=20)
2828

29-
answer = rag.answer("毛细管设计规范按照什么标准",result)
30-
31-
32-
print(answer)
29+
for i in result:
30+
print(i)
31+
print("-"*100)

rag_factory/Retrieval/Retriever/Retriever_BM25.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from concurrent.futures import ThreadPoolExecutor
66
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence
77
from dataclasses import dataclass, field
8-
8+
import uuid
99
from pydantic import ConfigDict, Field, model_validator
1010

1111
logger = logging.getLogger(__name__)
@@ -207,7 +207,7 @@ def from_texts(
207207
f"与 texts 长度 ({len(texts_list)}) 不匹配"
208208
)
209209
else:
210-
ids_list = [None for _ in texts_list]
210+
ids_list = [str(uuid.uuid4()) for _ in texts_list]
211211

212212
# 预处理文本
213213
logger.info(f"正在预处理 {len(texts_list)} 个文本...")

rag_factory/Retrieval/Retriever/Retriever_MultiPath.py

Lines changed: 10 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ def _get_relevant_documents(self, query: str, **kwargs: Any) -> List[Document]:
5050
5151
Note:
5252
- 每个检索器的结果会被转换为RetrievalResult格式
53-
- 支持多种输入格式:Document对象、字典格式、字符串等
54-
- 融合后的结果会将score和rank信息保存在Document的metadata中
53+
- 输入只会是Document对象
54+
- 融合后的结果只返回排序好的Document对象
5555
"""
5656
top_k = kwargs.get('top_k', 10)
5757

@@ -65,43 +65,12 @@ def _get_relevant_documents(self, query: str, **kwargs: Any) -> List[Document]:
6565
# 转换为RetrievalResult格式
6666
formatted_results = []
6767
for i, doc in enumerate(documents):
68-
if isinstance(doc, Document):
69-
# 如果是Document对象
70-
retrieval_result = RetrievalResult(
71-
document=doc,
72-
score=getattr(doc, 'score', 1.0),
73-
rank=i + 1
74-
)
75-
elif isinstance(doc, dict):
76-
# 如果返回的是字典格式,需要转换为Document对象
77-
content = doc.get('content', '')
78-
metadata = doc.get('metadata', {})
79-
doc_id = doc.get('id')
80-
81-
document = Document(
82-
content=content,
83-
metadata=metadata,
84-
id=doc_id
85-
)
86-
87-
retrieval_result = RetrievalResult(
88-
document=document,
89-
score=doc.get('score', 1.0),
90-
rank=i + 1
91-
)
92-
else:
93-
# 如果是字符串或其他格式,转换为Document对象
94-
document = Document(
95-
content=str(doc),
96-
metadata={},
97-
id=None
98-
)
99-
100-
retrieval_result = RetrievalResult(
101-
document=document,
102-
score=1.0,
103-
rank=i + 1
104-
)
68+
# 输入只会是Document对象
69+
retrieval_result = RetrievalResult(
70+
document=doc,
71+
score=getattr(doc, 'score', 1.0),
72+
rank=i + 1
73+
)
10574
formatted_results.append(retrieval_result)
10675

10776
all_results.append(formatted_results)
@@ -116,16 +85,10 @@ def _get_relevant_documents(self, query: str, **kwargs: Any) -> List[Document]:
11685

11786
fused_results = self.fusion_method.fuse(all_results, top_k)
11887

119-
# 转换回Document格式
88+
# 转换回Document格式,只返回排序好的Document对象
12089
documents = []
12190
for result in fused_results:
122-
doc = result.document
123-
# 将score和rank添加到metadata中以便保留
124-
if doc.metadata is None:
125-
doc.metadata = {}
126-
doc.metadata['score'] = result.score
127-
doc.metadata['rank'] = result.rank
128-
documents.append(doc)
91+
documents.append(result.document)
12992

13093
return documents
13194

requirements.txt

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,13 @@ llama-index
1414
llama-index-core
1515
peewee
1616

17-
mineru[core]
17+
1818
rank_bm25
1919
faiss_gpu
2020

2121

22-
23-
# streamlit
22+
# for ocr
2423
PyMuPDF
25-
openai
2624
qwen_vl_utils
2725
transformers==4.51.3
2826
huggingface_hub
@@ -31,3 +29,6 @@ flash-attn==2.8.0.post2
3129
# for GLIBC 2.31, please use flash-attn==2.7.4.post1 instead of flash-attn==2.8.0.post2
3230
accelerate
3331
dashscope
32+
torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cu128
33+
34+
mineru[core]

0 commit comments

Comments
 (0)