-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpipeline_test_memreader.py
More file actions
2219 lines (1910 loc) · 109 KB
/
pipeline_test_memreader.py
File metadata and controls
2219 lines (1910 loc) · 109 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import os
import time
import uuid
import json
import numpy as np
from typing import List, Dict, Optional, Any, Union
from dataclasses import dataclass
from dotenv import load_dotenv
from openai import OpenAI
from utils import (get_embedding, parse_messages, LME_JUDGE_MODEL_TEMPLATE,
LME_ANSWER_PROMPT, remove_code_blocks, extract_json)
from lme_eval import lme_grader
from datetime import datetime, timezone
import pytz
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
from vector_db import VectorDBConfig, VectorDBFactory
# ==========================================
# 0. Setup & Prompts
# ==========================================
# Load variables from a local .env file (if present) before any os.getenv() call below.
load_dotenv()
# Required environment: OPENAI_API_KEY (OPENAI_BASE_URL may point at a compatible endpoint)
# and MILVUS_URI. For local testing, make sure Milvus is already running in Docker.
llm_client = OpenAI(
    base_url=os.getenv("OPENAI_BASE_URL"),
    api_key=os.getenv("OPENAI_API_KEY"),
)
# System prompt for the fact extractor (sent as the system message in
# MemoryPipeline._extract_single_turn). It is an f-string evaluated once at import time:
# `{{` / `}}` render as literal braces in the JSON examples, and "Today's date" below is
# frozen to the wall-clock date on which the module was imported.
# NOTE(review): relative dates ("yesterday", "next Friday") are therefore resolved against the
# import date, not the conversation's own timestamp (step_extract receives a timestamp but does
# not pass it to the extractor) — confirm this is acceptable for historical datasets.
MEMREADER_PROMPT = f"""You are a Personal Information Organizer, specialized in accurately storing facts, user memories, and preferences. Your primary role is to extract relevant pieces of information from conversations and organize them into distinct, manageable facts. This allows for easy retrieval and personalization in future interactions. Below are the types of information you need to focus on and the detailed instructions on how to handle the input data.
Types of Information to Remember:
1. Store Personal Preferences: Keep track of likes, dislikes, and specific preferences in various categories such as food, products, activities, and entertainment.
2. Maintain Important Personal Details: Remember significant personal information like names, relationships, and important dates.
3. Track Plans and Intentions: Note upcoming events, trips, goals, and any plans the user has shared.
4. Remember Activity and Service Preferences: Recall preferences for dining, travel, hobbies, and other services.
5. Monitor Health and Wellness Preferences: Keep a record of dietary restrictions, fitness routines, and other wellness-related information.
6. Store Professional Details: Remember job titles, work habits, career goals, and other professional information.
7. Miscellaneous Information Management: Keep track of favorite books, movies, brands, and other miscellaneous details that the user shares.
Here are some few shot examples:
Input: Hi.
Output: {{"facts" : []}}
Input: There are branches in trees.
Output: {{"facts" : []}}
Input: Hi, I am looking for a restaurant in San Francisco.
Output: {{"facts" : [
{{"fact": "Looking for a restaurant", "details": ["Location: San Francisco", "Intent: Dining"]}}
]}}
Input: Hi, my name is John. I am a software engineer.
Output: {{"facts" : [
{{"fact": "Name is John", "details": []}},
{{"fact": "Is a Software engineer", "details": []}}
]}}
Input: I'm at the downtown library using their free wifi. I managed to download that large dataset.
Output: {{"facts" : [
{{"fact": "Downloaded a large dataset", "details": ["Location: Downtown library", "Connection: Free wifi"]}}
]}}
Input: My favourite movies are Inception and Interstellar.
Output: {{"facts" : [
{{"fact": "Favourite movies are Inception and Interstellar", "details": []}}
]}}
Return the facts and preferences in a json format as shown above.
Remember the following:
- Today's date is {datetime.now().strftime("%Y-%m-%d")}.
- **Supplementary Details**: The `details` list must act as **METADATA** to supplement the fact (e.g., Time, Location, Price, Platform, Reason), **NOT** just splitting the fact's words. (e.g., If fact is "Bought apple", details should be ["Price: $1", "Store: Aldi"], NOT ["Action: Buy", "Object: Apple"]).
- **Context Propagation**: Ensure every extracted fact is **self-contained**. If a shared context (e.g., location, platform, activity, or timeframe) is established anywhere in the input chunk, explicitly include it in the `details` of all relevant facts, even if not repeated in every sentence.
- ALWAYS resolve relative time expressions (e.g., "yesterday", "next Friday") into absolute ISO dates (YYYY-MM-DD) based on Today's date in the details.
- Do not return anything from the custom few shot example prompts provided above.
- If you do not find anything relevant in the below conversation, you can return an empty list corresponding to the "facts" key.
- Create the facts based on the user and assistant messages only. Do not pick anything from the system messages.
- Make sure to return the response in the format mentioned in the examples. The response should be in json with a key as "facts", where each item has a "fact" string and a "details" list of strings.
Following is a conversation between the user and the assistant. You have to extract the relevant facts and preferences about the user, if any, from the conversation and return them in the json format as shown above.
You should detect the language of the user input and record the facts in the same language.
"""
# System prompt for the memory-manager LLM, sent by MemoryPipeline.step_decide together with
# MEMORY_TOOLS. Existing memories are shown to the model under short integer IDs ("0", "1", ...)
# which step_decide maps back to real memory_ids. This is a plain (non-f) string, so braces and
# quotes inside it need no escaping.
# NOTE(review): the prompt asks for a written "thinking process" before each tool call while
# step_decide requests tool_choice="required" — confirm the model still emits that text.
MEMORY_MANAGER_PROMPT = """You are a specialized Memory Manager Agent.
Your role is to maintain the consistency and growth of a memory graph using the provided tools.
[INPUTS]
You will receive:
1. "New Facts": A list of atomic facts extracted from the latest user input.
2. "Existing Memories": A list of retrieved memory items, each with a simplified Integer ID (e.g., "0", "1", "2").
- These memories include those directly related to the new facts, as well as other related facts connected with these memories.
- They form a connected graph of information relevant to the new facts.
[MANDATORY OUTPUT FORMAT]
For every new fact you process, you MUST:
1. First generate a detailed thinking process
2. Then call the appropriate tool
[THINKING PROCESS REQUIREMENTS]
Your thinking process MUST include:
- The specific new fact you're analyzing
- Which existing memories are relevant (with their IDs)
- How memories are connected through related facts
- Your comparison and reasoning
- Which operation you've decided to perform and why
[OPERATIONS & GUIDELINES]
Compare New Facts with Existing Memories and perform the following operations using the available tools.
DO NOT output raw JSON text. You MUST use the provided function tools.
1. **ADD (create_memory)**
- **Condition**: If a fact contains completely NEW information not present in Existing Memories.
- **Action**: Call `create_memory` with a concise summary of the facts, not just a simple concatenation.
- **Important**: Memory content should be a meaningful and concise summary.
2. **UPDATE (update_memory)**
- **Condition**: If a fact adds detail, corrects, or updates a specific Existing Memory.
- **Constraint**: You MUST use the Integer ID (e.g., "0") provided in the input as the `target_memory_id`.
- **Logic**: Merge the old content and new fact into a comprehensive statement, not just a simple concatenation.
- **Example**:
- Old (ID="0"): "User likes generic pizza."
- New Fact: "User loves pepperoni pizza."
- Action: `update_memory(target_memory_id="0", new_content="User loves pepperoni pizza", ...)`
3. **DELETE (delete_memory)**
- **Condition**: If a fact explicitly contradicts an Existing Memory (and the new fact is trusted), or if the memory is no longer valid.
- **Constraint**: Use the Integer ID (e.g., "1") as `target_memory_id`.
4. **INFER (infer_memory) [CRITICAL]**
- **Condition**: Look for higher-level insights. If combining "Memory A" and "Memory B" reveals a hidden connection or causality.
- **Action**: Call `infer_memory`.
- **Example**:
- Memory A (ID="2"): "User moved to Singapore."
- Memory B (ID="3"): "User bought a Type G power adapter."
- Inference: "User is preparing electronics for Singapore power standards."
- Action: `infer_memory(source_memory_ids=["2", "3"], inference_content="...")`
5. **NOOP (no_operation)**
- **Condition**: If the fact is redundant (already exactly covered by memory), similar to existing facts associated with the retrieved memories, or trivial.
[STRICT ID RULES]
- When calling `update_memory` or `delete_memory`, **ONLY** use the string integer IDs (e.g., "0", "1", "2") found in the [EXISTING MEMORIES] list.
- **NEVER** invent a UUID or use an ID that is not in the provided list.
"""
# --- TOOLS ---
# OpenAI function-calling schemas offered to the memory-manager LLM (passed as `tools=` in
# MemoryPipeline.step_decide). step_decide translates each tool call into a decision dict:
# create_memory -> "ADD", update_memory -> "UPDATE", infer_memory -> "INFER",
# delete_memory -> "DELETE", no_operation -> "NOOP". Every `*_memory_id(s)` argument is one of
# the short integer IDs shown in the prompt, not a real memory UUID.
MEMORY_TOOLS = [
# create_memory -> ADD: stored text comes from `content`.
{
"type": "function",
"function": {
"name": "create_memory",
"description": "Create a NEW independent memory node with a concise summary of the facts.",
"parameters": {
"type": "object",
"properties": {
"content": {"type": "string", "description": "The concise summary content of the new memory, not just a list of facts."},
"evidence_facts": {"type": "array", "items": {"type": "string"}, "description": "Facts supporting this memory."}
},
"required": ["content", "evidence_facts"]
}
}
},
# update_memory -> UPDATE: replaces the target memory's text with `new_content`.
{
"type": "function",
"function": {
"name": "update_memory",
"description": "Update an existing memory by merging the old content and new fact into a comprehensive, concise statement.",
"parameters": {
"type": "object",
"properties": {
"target_memory_id": {"type": "string", "description": "The simplified Integer ID (e.g., '0') of the memory to update, found in the [EXISTING MEMORIES] list."},
"new_content": {"type": "string", "description": "The merged/updated comprehensive statement."},
"evidence_facts": {"type": "array", "items": {"type": "string"}, "description": "Facts supporting this update."}
},
"required": ["target_memory_id", "new_content", "evidence_facts"]
}
}
},
# infer_memory -> INFER: new memory derived from several existing ones.
{
"type": "function",
"function": {
"name": "infer_memory",
"description": "Look for higher-level insights. If combining multiple existing memories reveals a hidden connection or causality, create an inferred memory.",
"parameters": {
"type": "object",
"properties": {
"source_memory_ids": {"type": "array", "items": {"type": "string"}, "description": "List of simplified Integer IDs (e.g., ['0', '1']) acting as premises, found in the [EXISTING MEMORIES] list."},
"inference_content": {"type": "string", "description": "The higher-level insight or inference derived from combining the source memories."},
"evidence_facts": {"type": "array", "items": {"type": "string"}, "description": "Facts supporting this inference."}
},
"required": ["source_memory_ids", "inference_content", "evidence_facts"]
}
}
},
# delete_memory -> DELETE (described to the model as an archive / soft delete).
{
"type": "function",
"function": {
"name": "delete_memory",
"description": "Archive/Soft-delete a memory if it explicitly contradicts a new fact (and the new fact is trusted), or if the memory is no longer valid.",
"parameters": {
"type": "object",
"properties": {
"target_memory_id": {"type": "string", "description": "The simplified Integer ID (e.g., '1') of the memory to delete, found in the [EXISTING MEMORIES] list."},
"evidence_facts": {"type": "array", "items": {"type": "string"}, "description": "Facts supporting this deletion."}
},
"required": ["target_memory_id", "evidence_facts"]
}
}
},
# no_operation -> NOOP: only a reason is recorded.
{
"type": "function",
"function": {
"name": "no_operation",
"description": "No action needed if the fact is redundant (already exactly covered by memory or its associated facts).",
"parameters": {
"type": "object",
"properties": {"reason": {"type": "string", "description": "The reason for no operation."}},
"required": ["reason"]
}
}
}
]
# --- UTILS ---
def get_embedding(text: str) -> List[float]:
    """Return the OpenAI ``text-embedding-3-small`` embedding of ``text``.

    Newlines are replaced with spaces before the request is sent through the module-level
    ``llm_client``. Note that this definition shadows the ``get_embedding`` imported from
    ``utils`` at the top of the file.
    """
    single_line = " ".join(text.split("\n"))
    response = llm_client.embeddings.create(input=[single_line], model="text-embedding-3-small")
    return response.data[0].embedding
@dataclass
class MilvusConfig:
    """Milvus connection settings, kept for compatibility with older code.

    Field defaults are read from the environment once, when this class is defined.
    ``to_vector_db_config`` converts the settings into a backend-neutral ``VectorDBConfig``.
    """
    uri: Optional[str] = os.getenv("MILVUS_URI")
    user_name: Optional[str] = os.getenv("MILVUS_USER_NAME")
    # The password is not stored on the instance; to_vector_db_config() reads
    # MILVUS_PASSWORD from the environment each time it runs.
    db_name: str = os.getenv("MILVUS_DB_NAME", "default")
    dimension: int = 1536

    def to_vector_db_config(self, vector_db_type: str = "milvus") -> VectorDBConfig:
        """Build a ``VectorDBConfig`` for the requested backend.

        "qdrant" takes its URL and API key from QDRANT_URL / QDRANT_API_KEY and uses empty
        credentials; any other string uses this object's Milvus URI/user plus
        MILVUS_PASSWORD. A non-string ``vector_db_type`` falls back to "milvus".
        ``db_name`` and ``dimension`` always come from this object.
        """
        backend = vector_db_type if isinstance(vector_db_type, str) else "milvus"
        if backend == "qdrant":
            connection = {
                "uri": os.getenv("QDRANT_URL"),
                "api_key": os.getenv("QDRANT_API_KEY"),
                "user_name": "",
                "password": "",
            }
        else:
            connection = {
                "uri": self.uri,
                "api_key": "",
                "user_name": self.user_name,
                "password": os.getenv("MILVUS_PASSWORD"),
            }
        return VectorDBConfig(
            db_name=self.db_name,
            dimension=self.dimension,
            vector_db_type=backend,
            **connection,
        )
# ==========================================
# 1. Pipeline Class
# ==========================================
class MemoryPipeline:
def __init__(self, config=None, vector_db_type="milvus", clear_db=False, mode='eval', dataset_name=""):
    """Connect to the vector database and make sure the pipeline's collections exist.

    Args:
        config: A MilvusConfig (converted with ``to_vector_db_config``) or an object that is
            used directly as the VectorDBConfig. ``None`` means ``MilvusConfig()``.
        vector_db_type: Backend name handed to ``to_vector_db_config``: "milvus" or "qdrant".
        clear_db: When True, the collections are dropped before being recreated.
        mode: ``'test'`` appends ``_test`` to the collection names; any other value appends
            nothing.
        dataset_name: Optional ``_<dataset_name>`` suffix for the collection names.
    """
    self.config = MilvusConfig() if config is None else config

    # Configs that know how to convert themselves do so; anything else is used as-is.
    if hasattr(self.config, 'to_vector_db_config'):
        db_config = self.config.to_vector_db_config(vector_db_type=vector_db_type)
    else:
        db_config = self.config

    self.client = VectorDBFactory.create_db(db_config)

    # Collection names follow "<kind>[_test][_<dataset>]_v1".
    suffix = ("_test" if mode == 'test' else "") + (f"_{dataset_name}" if dataset_name else "")
    self.semantic_col = "memories" + suffix + "_v1"
    self.fact_col = "facts" + suffix + "_v1"
    self.chunk_col = "chunks" + suffix + "_v1"
    self.dim = db_config.dimension

    # Per-action counters, all starting at zero.
    self.operation_counts = dict.fromkeys(("ADD", "UPDATE", "DELETE", "INFER", "NOOP"), 0)

    self._init_collections(clear_db=clear_db)
def _init_collections(self, clear_db: bool = False):
    """Make sure the memory, fact and chunk collections exist, then load them.

    Args:
        clear_db: When True, all three collections are dropped first (without checking
            that they exist) so they get recreated empty.

    On a Milvus client (detected by ``self.client`` exposing ``DataType``) every missing
    collection is created with an explicit schema plus an IVF_FLAT / COSINE index on
    ``embedding``; existing collections are left untouched. Other clients only receive
    ``create_collection(name)``. Collections are loaded when the client has
    ``load_collection``.
    """
    dim = self.config.dimension
    all_cols = (self.semantic_col, self.fact_col, self.chunk_col)

    if clear_db:
        print("正在清空数据库...")
        for name in all_cols:
            self.client.drop_collection(name)
        print("数据库清空完成.")

    # Schema per collection: (field name, DataType member name, add_field kwargs).
    # The first entry of each list is the primary key.
    collection_specs = [
        (self.semantic_col, [
            ("memory_id", "VARCHAR", {"max_length": 64, "is_primary": True}),
            ("embedding", "FLOAT_VECTOR", {"dim": dim}),
            ("content", "VARCHAR", {"max_length": 65535}),
            ("user_id", "VARCHAR", {"max_length": 64}),
            ("status", "VARCHAR", {"max_length": 16}),
            ("created_at", "INT64", {}),
            ("updated_at", "INT64", {}),
            ("relations", "JSON", {}),
        ]),
        (self.fact_col, [
            ("fact_id", "VARCHAR", {"max_length": 64, "is_primary": True}),
            ("linked_memory_ids", "JSON", {}),
            ("linked_chunk_id", "VARCHAR", {"max_length": 64}),
            ("text", "VARCHAR", {"max_length": 65535}),
            ("details", "JSON", {}),
            ("timestamp", "INT64", {}),
            ("user_id", "VARCHAR", {"max_length": 64}),
            ("embedding", "FLOAT_VECTOR", {"dim": dim}),
        ]),
        (self.chunk_col, [
            ("chunk_id", "VARCHAR", {"max_length": 64, "is_primary": True}),
            ("text", "VARCHAR", {"max_length": 65535}),
            ("timestamp", "INT64", {}),
            ("embedding", "FLOAT_VECTOR", {"dim": dim}),
        ]),
    ]
    for name, fields in collection_specs:
        self._ensure_collection(name, fields)

    print("Loading collections into memory...")
    # Only clients exposing load_collection (Milvus) need an explicit load.
    if hasattr(self.client, 'load_collection'):
        for name in all_cols:
            print(f"加载集合 '{name}'...")
            self.client.load_collection(name)
        print("All collections loaded successfully.")

def _ensure_collection(self, name, fields):
    """Create collection ``name`` if needed; on Milvus also build its schema and index.

    Args:
        name: Collection name.
        fields: ``(field_name, DataType member name, add_field kwargs)`` tuples, used only
            for Milvus clients.
    """
    if not hasattr(self.client, 'DataType'):
        # Non-Milvus client: it is only asked to create the collection by name.
        self.client.create_collection(name)
        print(f"Collection '{name}' created or exists.")
        return
    if self.client.has_collection(name):
        print(f"Collection '{name}' already exists, skipping creation.")
        return
    schema = self.client.create_schema(auto_id=False, enable_dynamic_field=True)
    for field_name, type_name, kwargs in fields:
        schema.add_field(field_name, getattr(self.client.DataType, type_name), **kwargs)
    self.client.create_collection(name, schema=schema)
    print(f"Collection '{name}' created.")
    self._create_embedding_index(name)

def _create_embedding_index(self, name):
    """Create the IVF_FLAT / COSINE (nlist=128) index on ``embedding`` for ``name``.

    Failures are printed rather than raised, so setup continues best-effort.
    """
    try:
        print(f"为集合 '{name}' 创建索引...")
        idx_params = self.client.prepare_index_params()
        idx_params.add_index(field_name="embedding", index_type="IVF_FLAT", metric_type="COSINE", params={"nlist": 128})
        self.client.create_index(name, index_params=idx_params)
        print(f"集合 '{name}' 的索引创建成功或已存在")
    except Exception as e:
        print(f"创建索引失败: {e}")
# --- Step 1: Extract ---
def step_extract(self, chunk_text: str, extract_mode: str = "whole", timestamp: Optional[int] = None) -> Dict:
    """Extract facts from a conversation chunk.

    Args:
        chunk_text: Conversation text. "turn" mode assumes it contains ``user: `` /
            ``assistant: `` speaker markers (NOTE(review): confirm against the code that
            formats chunk_text).
        extract_mode: "whole" sends the entire chunk to the extractor in one call.
            "turn" splits the chunk into rounds, each running from one ``user: `` marker up
            to the next (so the assistant reply stays with its user message), and extracts
            every round separately. It falls back to "whole" when no ``user: `` marker is
            found or splitting raises.
        timestamp: Unix timestamp (seconds) stored with the result; defaults to now.

    Returns:
        ``{"chunk_id", "chunk_text", "new_facts", "timestamp"}``. ``chunk_id`` is a fresh
        UUID4 string and ``new_facts`` a list of ``{"text": str, "details": list}`` dicts.
    """
    import re

    if timestamp is None:
        timestamp = int(time.time())

    if extract_mode == "turn":
        try:
            # Split at "user: " boundaries. Stopping at "assistant: " (as before) dropped
            # every assistant reply, contradicting the user+assistant-round contract.
            turns = re.findall(r"user: .*?(?=user: |\Z)", chunk_text, re.DOTALL)
            turn_texts = [turn.strip() for turn in turns if turn.strip()]
            if turn_texts:
                all_facts = []
                for turn_text in turn_texts:
                    all_facts.extend(self._extract_single_turn(turn_text))
                return {"chunk_id": str(uuid.uuid4()), "chunk_text": chunk_text, "new_facts": all_facts, "timestamp": timestamp}
        except Exception as e:
            print(f"解析对话轮次失败,回退到whole模式: {e}")

    # Default ("whole") mode, also the fallback for "turn".
    facts = self._extract_single_turn(chunk_text)
    return {"chunk_id": str(uuid.uuid4()), "chunk_text": chunk_text, "new_facts": facts, "timestamp": timestamp}
def _extract_single_turn(self, text: str) -> List[Dict]:
    """Ask the extractor LLM for the facts contained in ``text``.

    Args:
        text: Conversation text (a whole chunk or a single round).

    Returns:
        A list of ``{"text": str, "details": list}`` dicts. If the request or parsing
        fails, the error is printed and the raw ``text`` is returned as a single fact
        (best-effort fallback).
    """
    try:
        response = llm_client.chat.completions.create(
            model="gpt-4.1",
            messages=[
                {"role": "system", "content": MEMREADER_PROMPT},
                {"role": "user", "content": text}],
            response_format={"type": "json_object"}, temperature=0
        )
        return self._parse_extracted_facts(response.choices[0].message.content)
    except Exception as e:
        print(f"Extraction failed: {e}")
        return [{"text": text, "details": []}]

def _parse_extracted_facts(self, content: str) -> List[Dict]:
    """Convert the extractor's JSON reply into ``[{"text": ..., "details": [...]}, ...]``.

    Accepts ``{"facts": [...]}`` or a bare list. Items may be ``{"fact", "details"}`` dicts
    or plain strings. Items without fact text, or of any other type, are skipped instead of
    discarding the whole batch. A missing or null ``details`` becomes ``[]``.

    Raises:
        ValueError: (including json.JSONDecodeError) when ``content`` is not JSON of a
            supported shape.
    """
    payload = json.loads(content)
    if isinstance(payload, dict):
        items = payload.get("facts") or []
    else:
        items = payload
    if isinstance(items, dict):
        items = [items]
    if not isinstance(items, list):
        raise ValueError(f"unexpected extractor payload: {type(items).__name__}")

    facts = []
    for item in items:
        if isinstance(item, str):
            fact_text, details = item, []
        elif isinstance(item, dict):
            fact_text, details = item.get("fact"), item.get("details") or []
        else:
            continue
        if fact_text:
            facts.append({"text": fact_text, "details": details})
    return facts
# --- Step 2: Retrieve ---
def step_retrieve(self, extract_result: Dict, limit: int = 3, user_id: str = 'default', similarity_threshold: Optional[float] = None) -> List[Dict]:
    """For each extracted fact, find similar active memories of ``user_id`` and their facts.

    Args:
        extract_result: Output of ``step_extract``; only ``new_facts`` is read.
        limit: Maximum number of memories returned per fact by the vector search.
        user_id: Only memories with this ``user_id`` (and ``status == 'active'``) match.
        similarity_threshold: Forwarded unchanged to ``self.client.search``.

    Returns:
        One ``{"new_fact": fact, "candidates": [memory, ...]}`` bundle per fact. Each
        candidate memory gains a ``related_facts`` list holding the stored facts whose
        ``linked_memory_ids`` contain it. Empty list when there are no facts.
    """
    facts = extract_result['new_facts']
    if not facts:
        return []
    print(f"🔍 [2. Retrieve] Searching Memories for {len(facts)} facts...")

    bundles = []
    for fact in facts:
        hits = self.client.search(
            self.semantic_col,
            [get_embedding(fact['text'])],
            filter=f"status == 'active' and user_id == '{user_id}'",
            limit=limit,
            output_fields=["content", "memory_id", "created_at"],
            similarity_threshold=similarity_threshold,
        )
        candidates = [hit['entity'] for hit in hits[0]] if hits and hits[0] else []

        if candidates:
            memory_ids = [mem['memory_id'] for mem in candidates]
            # A single query fetches every fact linked to any candidate memory.
            linked_filter = " || ".join(f'array_contains(linked_memory_ids, "{mid}")' for mid in memory_ids)
            linked_facts = []
            try:
                linked_facts = self.client.query(
                    collection_name=self.fact_col,
                    filter=linked_filter,
                    output_fields=["fact_id", "linked_memory_ids", "text", "linked_chunk_id", "timestamp", "details"],
                )
            except Exception as e:
                print(f" ⚠️ Error retrieving related facts: {e}")
            # Give each memory only the facts that link to it.
            for mem, mid in zip(candidates, memory_ids):
                mem['related_facts'] = [f for f in linked_facts if mid in f.get('linked_memory_ids', [])]

        bundles.append({"new_fact": fact, "candidates": candidates})
    return bundles
# --- Step 3: Decide (With ID Mapping) ---
def step_decide(self, extract_result: Dict, context_bundles: List[Dict], user_id: str = 'default', training_mode: bool = False) -> List[Dict]:
    """Ask the memory-manager LLM which operations to apply for the extracted facts.

    Candidate memories from all ``context_bundles`` are de-duplicated and shown to the LLM
    under short integer IDs ("0", "1", ...). The tool calls it returns are translated into
    decision dicts that reference the real ``memory_id`` values.

    Args:
        extract_result: Output of ``step_extract``; ``new_facts`` is read.
        context_bundles: Output of ``step_retrieve``.
        user_id: Copied into every decision.
        training_mode: When True, console output is suppressed.

    Returns:
        Decision dicts whose ``action`` is ADD / UPDATE / DELETE / INFER / NOOP. Tool calls
        that reference only unknown IDs are dropped. If the request or response handling
        raises, the error is printed (outside training mode) and the decisions gathered so
        far are returned.
    """
    all_new_facts = extract_result['new_facts']

    # 1. Merge candidates from every bundle, de-duplicated by memory_id.
    temp_mem_storage = {}
    for bundle in context_bundles:
        for mem in bundle['candidates']:
            temp_mem_storage[mem['memory_id']] = mem
    unique_memories_list = list(temp_mem_storage.values())
    if not training_mode:
        print(f"🧠 [3. Manager] Global Decide: {len(all_new_facts)} facts vs {len(unique_memories_list)} memories.")

    # 2. Short integer IDs spare the LLM from copying UUIDs; uuid_mapping translates back.
    candidates_str, uuid_mapping = self._render_memory_candidates(unique_memories_list)

    # Only fact texts are sent (no details), so the LLM does not treat details as separate facts.
    fact_texts = [fact['text'] for fact in all_new_facts]
    user_content = f"""
[New Facts Stream]
{json.dumps(fact_texts, ensure_ascii=False)}
[EXISTING MEMORIES]
{candidates_str}
"""

    def resolve_id(simple_id):
        """Map a short ID chosen by the LLM to the real memory_id (None if unknown)."""
        real = uuid_mapping.get(str(simple_id))
        if not real and not training_mode:
            print(f" ⚠️ Warning: LLM hallucinated ID '{simple_id}', ignoring.")
        return real

    all_decisions = []
    try:
        # One non-streaming request yields both the optional reasoning text and the tool
        # calls, so the printed reasoning belongs to the decisions that are executed and
        # the LLM is billed once.
        response = llm_client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": MEMORY_MANAGER_PROMPT},
                {"role": "user", "content": user_content}
            ],
            tools=MEMORY_TOOLS,
            tool_choice="required",
            temperature=0
        )
        if not response.choices:
            if not training_mode:
                print(f" ⚠️ Warning: No choices in response")
            return []
        message = getattr(response.choices[0], 'message', None)
        if not message:
            if not training_mode:
                print(f" ⚠️ Warning: No message in choice")
            return []

        thinking_process = message.content
        if thinking_process and not training_mode:
            print(f"\n 🧠 LLM思考过程:")
            print(f" {thinking_process}")

        tool_calls = message.tool_calls
        if not tool_calls:
            return []

        for tool_call in tool_calls:
            try:
                func_name = tool_call.function.name
                args = json.loads(tool_call.function.arguments)
                if not training_mode:
                    print(f" 🤖 Raw Action: {func_name} | Args: {args}")
                decision = {"action": "NOOP"}
                if func_name == "create_memory":
                    decision.update({
                        "action": "ADD",
                        "summary": args.get("content", ""),
                        "facts_to_link": args.get("evidence_facts", []),
                        "user_id": user_id
                    })
                elif func_name == "update_memory":
                    real_tid = resolve_id(args["target_memory_id"]) if "target_memory_id" in args else None
                    if real_tid:
                        orig_created = temp_mem_storage.get(real_tid, {}).get('created_at', int(time.time()))
                        decision.update({
                            "action": "UPDATE",
                            "target_id": real_tid,
                            "new_content": args.get("new_content", ""),
                            "facts_to_link": args.get("evidence_facts", []),
                            "orig_created": orig_created,
                            "user_id": user_id
                        })
                elif func_name == "delete_memory":
                    real_tid = resolve_id(args["target_memory_id"]) if "target_memory_id" in args else None
                    if real_tid:
                        orig_created = temp_mem_storage.get(real_tid, {}).get('created_at', int(time.time()))
                        decision.update({
                            "action": "DELETE",
                            "target_id": real_tid,
                            "facts_to_link": args.get("evidence_facts", []),
                            "orig_created": orig_created,
                            "user_id": user_id
                        })
                elif func_name == "infer_memory":
                    if "source_memory_ids" in args:
                        source_simples = args["source_memory_ids"]
                        if not isinstance(source_simples, list):
                            source_simples = [source_simples]
                        # Resolve each ID exactly once so an unknown ID is reported once.
                        real_source_ids = [rid for rid in map(resolve_id, source_simples) if rid]
                        if real_source_ids:
                            decision.update({
                                "action": "INFER",
                                "source_ids": real_source_ids,
                                "summary": args.get("inference_content", ""),
                                "facts_to_link": args.get("evidence_facts", []),
                                "user_id": user_id
                            })
                elif func_name == "no_operation":
                    decision.update({"reason": args.get("reason", "No reason provided"), "user_id": user_id})
                # A plain NOOP without a reason means the call referenced unknown IDs: drop it.
                if decision["action"] != "NOOP" or "reason" in decision:
                    all_decisions.append(decision)
            except Exception as e:
                if not training_mode:
                    print(f" ⚠️ Error processing tool call: {e}")
                continue
    except Exception as e:
        if not training_mode:
            print(f" ⚠️ Decision Error: {e}")
    return all_decisions

def _render_memory_candidates(self, memories: List[Dict]):
    """Render candidate memories and their related facts as the [EXISTING MEMORIES] text.

    Args:
        memories: Memory dicts with ``memory_id``, ``content`` and optionally
            ``related_facts`` (as attached by ``step_retrieve``).

    Returns:
        ``(text, uuid_mapping)``; ``uuid_mapping`` maps the short ID shown to the LLM
        ("0", "1", ...) to the memory's real ``memory_id``.
    """
    if not memories:
        return "(No relevant memories found. Treat as new topic.)", {}
    uuid_mapping = {}
    parts = []
    for idx, mem in enumerate(memories):
        simple_id = str(idx)
        uuid_mapping[simple_id] = mem['memory_id']
        parts.append(f"[Memory Item ID: {simple_id}]\n- Content: {mem['content']}\n")
        related_facts = mem.get('related_facts', [])
        if related_facts:
            parts.append("- Related Facts:\n")
            for fact_idx, fact in enumerate(related_facts, start=1):
                parts.append(f" - Fact {fact_idx}: {fact['text']}\n")
                details = fact.get('details') or []
                # details may be a list (of strings or dicts) or a single dict.
                if isinstance(details, dict):
                    details = [details]
                if isinstance(details, list):
                    for detail in details:
                        if isinstance(detail, dict):
                            detail = ", ".join(f"{k}: {v}" for k, v in detail.items())
                        parts.append(f" Detail: {detail}\n")
        parts.append("\n")
    return "".join(parts), uuid_mapping
# --- Batch Processing for Training with GRPO Support ---
def batch_process(self, batch_data: List[Dict], user_id: str = 'default', grpo_compatible: bool = True) -> List[Dict]:
    """
    Run the full extract -> retrieve -> decide -> execute pipeline over a batch.

    Each item is processed independently and sequentially.
    ``step_execute`` is invoked for every item, so memory operations are
    applied (not just simulated) as part of this call.

    Args:
        batch_data (List[Dict]): Input items; each must contain a ``'text'``
            key and may contain an ``'expected_operation'`` key.
        user_id (str, optional): User ID for memory operations. Defaults to 'default'.
        grpo_compatible (bool, optional): When True, add GRPO-specific fields
            (``memory_operations``, ``memory_contents``, ``expected_operation``)
            to each result. Defaults to True.

    Returns:
        List[Dict]: One result dict per input item, in input order.

    Raises:
        KeyError: If an item in ``batch_data`` has no ``'text'`` key.
    """
    results = []
    for data in batch_data:
        text = data['text']
        # Extract facts from the whole input text.
        extract_result = self.step_extract(text, extract_mode='whole')
        # Retrieve memories relevant to the extracted facts.
        context_bundles = self.step_retrieve(extract_result, limit=3, user_id=user_id)
        # Decide memory operations; training_mode suppresses diagnostic printing.
        decisions = self.step_decide(extract_result, context_bundles, user_id=user_id, training_mode=True)
        # Apply the decisions to the stores.
        self.step_execute(decisions, extract_result, user_id=user_id)

        result = {
            'input': text,
            'extract_result': extract_result,
            'decisions': decisions,
        }
        if grpo_compatible:
            effective = [d for d in decisions if d['action'] != 'NOOP']
            result['memory_operations'] = [d['action'] for d in effective]
            # ADD/INFER decisions store their text under 'summary', but UPDATE
            # stores it under 'new_content'. Fall back to it so UPDATE entries
            # are not reported as empty. DELETE has neither and yields ''.
            # The list stays index-aligned with memory_operations.
            result['memory_contents'] = [
                d['summary'] if 'summary' in d else d.get('new_content', '')
                for d in effective
            ]
            result['expected_operation'] = data.get('expected_operation', '')
        results.append(result)
    return results
# ==========================================
# Step 4: Execute (Modified for Fact Inheritance)
# ==========================================
def step_execute(self, decisions: List[Dict], extract_result: Dict, user_id: str = 'default'):
# 使用extract_result中的timestamp,而不是当前时间
ts = extract_result['timestamp']
chunk_id = extract_result['chunk_id']
all_new_facts = extract_result['new_facts']
# 1. 保存原始 Chunk
self.client.insert(self.chunk_col, [{"chunk_id": chunk_id, "text": extract_result["chunk_text"], "timestamp": ts, "embedding": get_embedding(extract_result["chunk_text"])}])
# 2. 收集所有要链接的事实文本
all_facts_to_link = set()
for decision in decisions:
action = decision.get("action")
facts_to_link = decision.get('facts_to_link', [])
for fact_text in facts_to_link:
all_facts_to_link.add(fact_text)
# 3. 对所有要处理的事实进行最终去重
# 收集所有新事实
all_facts = []
for fact in all_new_facts:
# 只处理在all_facts_to_link中的事实
if fact['text'] in all_facts_to_link:
all_facts.append(fact)
# 对所有事实进行去重
unique_all_facts = []
seen_fact_keys = set()
for fact in all_facts:
fact_key = f"{fact['text']}::{json.dumps(fact['details'], sort_keys=True)}"
# 也考虑去掉"User"前缀的情况
stripped_fact_key = f"{fact['text'].lower().replace('user ', '')}::{json.dumps(fact['details'], sort_keys=True)}"
if fact_key not in seen_fact_keys and stripped_fact_key not in seen_fact_keys:
seen_fact_keys.add(fact_key)
seen_fact_keys.add(stripped_fact_key)
unique_all_facts.append(fact)
if len(unique_all_facts) < len(all_facts):
print(f" ✅ 最终去重 {len(all_facts) - len(unique_all_facts)} 个重复事实")
# 更新all_facts_to_link为去重后的事实文本集合
all_facts_to_link = {fact['text'] for fact in unique_all_facts}
# 3. 处理每个决策
has_non_noop_action = False
# 收集所有要链接的事实,确保去重
all_matched_facts = []
seen_fact_keys = set()
for decision in decisions:
action = decision.get("action")
if action == "NOOP":
self.operation_counts["NOOP"] += 1
print(f" 🚫 No operation: {decision.get('reason', 'No reason provided')}")
continue
has_non_noop_action = True
target_mem_id = None
relations = []
# --- CASE 1: ADD ---
if action == "ADD":
self.operation_counts["ADD"] += 1
target_mem_id = str(uuid.uuid4())
self._upsert_mem(target_mem_id, decision['summary'], ts, ts, "active", [], decision.get('user_id', 'default'))
print(f" ✅ Created Mem: {target_mem_id[:8]}... | Content: {decision['summary']}")
# --- CASE 2: UPDATE ---
elif action == "UPDATE":
self.operation_counts["UPDATE"] += 1
target_mem_id = decision['target_id']
# 查询旧的memory内容
old_memories = self.client.query(
collection_name=self.semantic_col,
filter=f"memory_id == '{target_mem_id}'",
output_fields=["content", "created_at"]
)
old_content = "" if not old_memories else old_memories[0].get("content", "")
new_content = decision['new_content']
# 记录update前后的内容
print(f" 🔄 Updating Mem: {target_mem_id[:8]}...")
print(f" Before: {old_content[:]}...")
print(f" After: {new_content[:]}...")
self._upsert_mem(target_mem_id, new_content, decision['orig_created'], ts, "active", [], decision.get('user_id', 'default'))
# --- CASE 3: DELETE ---
elif action == "DELETE":
self.operation_counts["DELETE"] += 1
target_mem_id = decision['target_id']
self._upsert_mem(target_mem_id, "(Archived)", decision['orig_created'], ts, "archived", [], decision.get('user_id', 'default'))
print(f" ❌ Deleted Mem: {target_mem_id[:8]}...")
# --- CASE 4: INFER (With Fact Inheritance) ---
elif action == "INFER":
self.operation_counts["INFER"] += 1
target_mem_id = str(uuid.uuid4()) # 这是 Memory C
source_ids = decision.get('source_ids', []) # 这是 [A, B]
#############################################################
# 查询source_ids对应的memory内容,用于打印
source_mems = []
if source_ids:
quoted_source_ids = [f'"{sid}"' for sid in source_ids]
mem_filter = f"status == 'active' and memory_id in [{','.join(quoted_source_ids)}]"
try:
source_mems = self.client.query(
collection_name=self.semantic_col,
filter=mem_filter,
output_fields=["content", "memory_id", "created_at", "user_id"]
)
except Exception as e:
print(f" ⚠️ 查询source memory失败: {e}")
#############################################################
# 4.1 创建新记忆 C,并记录血缘关系 (inferred_from)
relations = [{"type": "inferred_from", "target_id": sid} for sid in source_ids]
self._upsert_mem(target_mem_id, decision['summary'], ts, ts, "active", relations, decision.get('user_id', 'default'))
####################################################################################
# 将infer前后的memory内容拼在同一个字符串里输出
infer_output = f" 💡 Inferred Mem: {target_mem_id[:8]}... | From: {[s[:8] for s in source_ids]}\n"
infer_output += f" ┌─────────────────────────────────────────────────────────────────────────────────\n"
# 拼接infer前的memory内容
if source_mems:
infer_output += f" │ 📋 Infer 前的 Memory ({len(source_mems)}个):\n"
for mem in source_mems:
mem_id = mem.get("memory_id", "unknown")
content = mem.get("content", "")
infer_output += f" │ 📌 ID: {mem_id[:8]}... | 内容: {content[:]}...\n"
# 拼接infer生成的memory内容
infer_output += f" │ 📝 Infer生成的 Memory:\n"
infer_output += f" │ 📌 ID: {target_mem_id[:8]}... | 内容: {decision['summary'][:]}...\n"
infer_output += f" └─────────────────────────────────────────────────────────────────────────────────"
# 一次性输出整个字符串
print(infer_output)
#################################################################################
# 4.2 🌟 核心修改:继承旧 Facts
# 逻辑:找出所有支持 A 或 B 的 Fact,把 C 也加到它们的支持列表里
if source_ids:
# 构建查询表达式:array_contains(linked_memory_ids, 'A') or ...
expr_parts = [f'array_contains(linked_memory_ids, "{sid}")' for sid in source_ids]
filter_expr = " || ".join(expr_parts)
try:
# 查出旧 Facts
old_related_facts = self.client.query(
collection_name=self.fact_col,
filter=filter_expr,
output_fields=["fact_id", "linked_memory_ids", "text", "linked_chunk_id", "timestamp", "details", "embedding"]
)
if old_related_facts:
updated_rows = []
for fact in old_related_facts:
current_links = fact.get("linked_memory_ids", [])