diff --git a/README.md b/README.md
index 9c62924..b7ed422 100644
--- a/README.md
+++ b/README.md
@@ -232,10 +232,10 @@ lightmem = LightMemory.from_config(config_dict)
session = {
"timestamp": "2025-01-10",
"turns": [
- [
- {"role": "user", "content": "My favorite ice cream flavor is pistachio, and my dog's name is Rex."},
- {"role": "assistant", "content": "Got it. Pistachio is a great choice."}],
- ]
+ [
+ {"role": "user", "content": "My favorite ice cream flavor is pistachio, and my dog's name is Rex.", "speaker_name": "John","speaker_id": "speaker_a"},
+ {"role": "assistant", "content": "Got it. Pistachio is a great choice.", "speaker_name": "Assistant", "speaker_id": "speaker_b"}],
+ ]
}
@@ -377,10 +377,10 @@ We welcome contributions from the community! If you'd like to contribute, please
-
-
+
+
- Memos
+ MemOS
|
diff --git a/src/lightmem/configs/memory_manager/base_config.py b/src/lightmem/configs/memory_manager/base_config.py
index e8855cc..c57a8d9 100644
--- a/src/lightmem/configs/memory_manager/base_config.py
+++ b/src/lightmem/configs/memory_manager/base_config.py
@@ -13,6 +13,7 @@ def __init__(
max_tokens: int = 2000,
top_p: float = 0.1,
top_k: int = 1,
+ include_topic_summary: bool = False,
enable_vision: bool = False,
vision_details: Optional[str] = "auto",
# Openai specific
@@ -31,6 +32,7 @@ def __init__(
self.max_tokens = max_tokens
self.top_p = top_p
self.top_k = top_k
+ self.include_topic_summary = include_topic_summary
self.enable_vision = enable_vision
self.vision_details = vision_details
# Openai specific
diff --git a/src/lightmem/configs/text_embedder/base.py b/src/lightmem/configs/text_embedder/base.py
index 844c8ba..b047d27 100644
--- a/src/lightmem/configs/text_embedder/base.py
+++ b/src/lightmem/configs/text_embedder/base.py
@@ -8,7 +8,7 @@ class TextEmbedderConfig(BaseModel):
description="The embedding model or Deployment platform (e.g., 'openai', 'huggingface')"
)
- _model_list: ClassVar[List[str]] = ["huggingface"]
+ _model_list: ClassVar[List[str]] = ["huggingface", "openai"]
configs: Optional[Union[BaseTextEmbedderConfig, Dict[str, Any]]] = Field(
default=None,
diff --git a/src/lightmem/factory/memory_buffer/short_term_memory.py b/src/lightmem/factory/memory_buffer/short_term_memory.py
index 2a75d67..2bfe201 100644
--- a/src/lightmem/factory/memory_buffer/short_term_memory.py
+++ b/src/lightmem/factory/memory_buffer/short_term_memory.py
@@ -7,7 +7,7 @@ def __init__(self, max_tokens: int = 2000, tokenizer: Optional[Any] = None):
self.tokenizer = resolve_tokenizer(tokenizer)
self.buffer: List[List[Dict[str, Any]]] = []
self.token_count: int = 0
-
+ print(f"ShortMemBufferManager initialized with max_tokens={self.max_tokens}")
def _count_tokens(self, messages: List[Dict[str, Any]], messages_use: str) -> int:
role_map = {
"user_only": ["user"],
diff --git a/src/lightmem/factory/memory_manager/openai.py b/src/lightmem/factory/memory_manager/openai.py
index 73f4aad..a61b9cb 100644
--- a/src/lightmem/factory/memory_manager/openai.py
+++ b/src/lightmem/factory/memory_manager/openai.py
@@ -1,6 +1,7 @@
import concurrent
+from collections import defaultdict
from openai import OpenAI
-from typing import List, Dict, Optional, Literal
+from typing import List, Dict, Optional, Literal, Any
import json, os, warnings
import httpx
from lightmem.configs.memory_manager.base_config import BaseMemoryManagerConfig
@@ -120,28 +121,41 @@ def generate_response(
params["tool_choice"] = tool_choice
response = self.client.chat.completions.create(**params)
- return self._parse_response(response, tools)
-
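+        # Return token usage alongside the parsed response so callers can aggregate API consumption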
+ usage_info = {
+ "prompt_tokens": response.usage.prompt_tokens,
+ "completion_tokens": response.usage.completion_tokens,
+ "total_tokens": response.usage.total_tokens,
+ }
+ parsed_response = self._parse_response(response, tools)
+
+ return parsed_response, usage_info
+
def meta_text_extract(
self,
system_prompt: str,
extract_list: List[List[List[Dict]]],
- messages_use: Literal["user_only", "assistant_only", "hybrid"] = "user_only"
+ timestamps_list: Optional[List[List[List[str]]]] = None,
+ weekday_list: Optional[List[List[List[str]]]] = None,
+ messages_use: Literal["user_only", "assistant_only", "hybrid"] = "user_only",
+ topic_id_mapping: Optional[List[List[int]]] = None
) -> List[Optional[Dict]]:
"""
Extract metadata from text segments using parallel processing.
Args:
system_prompt: The system prompt for metadata generation
- all_segments: List of message segments to process
+ extract_list: List of message segments to process
+ timestamps_list: Optional list of timestamps (reserved for future use)
+ weekday_list: Optional list of weekdays (reserved for future use)
messages_use: Strategy for which messages to use
+            topic_id_mapping: For each API call, the list of global topic IDs assigned to its topic segments
Returns:
List of extracted metadata results, None for failed segments
"""
if not extract_list:
return []
-
+
def concatenate_messages(segment: List[Dict], messages_use: str) -> str:
"""Concatenate messages based on usage strategy"""
role_filter = {
@@ -161,43 +175,69 @@ def concatenate_messages(segment: List[Dict], messages_use: str) -> str:
sequence_id = mes["sequence_number"]
role = mes["role"]
content = mes.get("content", "")
- message_lines.append(f"{sequence_id}.{role}: {content}")
-
+ speaker_name = mes.get("speaker_name", "")
+ time_stamp = mes.get("time_stamp", "")
+ weekday = mes.get("weekday", "")
+
+ time_prefix = ""
+ if time_stamp and weekday:
+ time_prefix = f"[{time_stamp}, {weekday}] "
+
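+            # Lines are numbered with sequence_number // 2; convert_extraction_results_to_memory_entries maps facts back via source_id * 2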
+ if speaker_name != 'Unknown':
+ message_lines.append(f"{time_prefix}{sequence_id//2}.{speaker_name}: {content}")
+ else:
+ message_lines.append(f"{time_prefix}{sequence_id//2}.{role}: {content}")
+
return "\n".join(message_lines)
-
+
max_workers = min(len(extract_list), 5)
- def process_segment_wrapper(api_call_segments: List[List[Dict]]):
- """Process one API call (multiple topic segments inside)"""
+ def process_segment_wrapper(args):
+ api_call_idx, api_call_segments = args
try:
- user_prompt_parts = []
- for idx, topic_segment in enumerate(api_call_segments, start=1):
+ user_prompt_parts: List[str] = []
+
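+            # Use the global topic IDs assigned to this API call; fall back to a local 1-based index if no mapping is provided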
+ global_topic_ids: List[int] = []
+ if topic_id_mapping and api_call_idx < len(topic_id_mapping):
+ global_topic_ids = topic_id_mapping[api_call_idx]
+
+ for topic_idx, topic_segment in enumerate(api_call_segments):
+ if topic_idx < len(global_topic_ids):
+ global_topic_id = global_topic_ids[topic_idx]
+ else:
+ global_topic_id = topic_idx + 1
+
topic_text = concatenate_messages(topic_segment, messages_use)
- user_prompt_parts.append(f"--- Topic {idx} ---\n{topic_text}")
+ user_prompt_parts.append(f"--- Topic {global_topic_id} ---\n{topic_text}")
+ print(f"User prompt for API call {api_call_idx}:\n" + "\n".join(user_prompt_parts))
user_prompt = "\n".join(user_prompt_parts)
-
- messages = [
+
+ metadata_messages = [
{"role": "system", "content": system_prompt},
- {"role": "user", "content": user_prompt}
+ {"role": "user", "content": user_prompt},
]
- raw_response = self.generate_response(
- messages=messages,
- response_format={"type": "json_object"}
+
+ raw_response, usage_info = self.generate_response(
+ messages=metadata_messages,
+ response_format={"type": "json_object"},
)
- cleaned_result = clean_response(raw_response)
+ metadata_facts = clean_response(raw_response)
+
return {
- "input_prompt": messages,
+ "input_prompt": metadata_messages,
"output_prompt": raw_response,
- "cleaned_result": cleaned_result
+ "cleaned_result": metadata_facts,
+ "usage": usage_info,
}
+
except Exception as e:
- print(f"Error processing API call: {e}")
+ print(f"Error processing API call {api_call_idx}: {e}")
return None
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
try:
- results = list(executor.map(process_segment_wrapper, extract_list))
+ results = list(executor.map(process_segment_wrapper, enumerate(extract_list)))
except Exception as e:
print(f"Error in parallel processing: {e}")
results = [None] * len(extract_list)
@@ -218,15 +258,16 @@ def _call_update_llm(self, system_prompt, target_entry, candidate_sources):
{"role": "user", "content": user_prompt}
]
- response_text = self.generate_response(
+ response_text, usage_info = self.generate_response(
messages=messages,
response_format={"type": "json_object"}
)
-
+
try:
result = json.loads(response_text)
if "action" not in result:
- return {"action": "ignore"}
+ result = {"action": "ignore"}
+ result["usage"] = usage_info
return result
except Exception:
- return {"action": "ignore"}
\ No newline at end of file
+ return {"action": "ignore", "usage": usage_info if 'usage_info' in locals() else None}
\ No newline at end of file
diff --git a/src/lightmem/factory/text_embedder/factory.py b/src/lightmem/factory/text_embedder/factory.py
index af6d2fe..e9e6106 100644
--- a/src/lightmem/factory/text_embedder/factory.py
+++ b/src/lightmem/factory/text_embedder/factory.py
@@ -5,6 +5,7 @@
class TextEmbedderFactory:
_MODEL_MAPPING: Dict[str, str] = {
"huggingface": "lightmem.factory.text_embedder.huggingface.TextEmbedderHuggingface",
+ "openai": "lightmem.factory.text_embedder.openai.TextEmbedderOpenAI",
}
@classmethod
diff --git a/src/lightmem/factory/text_embedder/huggingface.py b/src/lightmem/factory/text_embedder/huggingface.py
index 2166993..67257a9 100644
--- a/src/lightmem/factory/text_embedder/huggingface.py
+++ b/src/lightmem/factory/text_embedder/huggingface.py
@@ -7,12 +7,16 @@
class TextEmbedderHuggingface:
def __init__(self, config: Optional[BaseTextEmbedderConfig] = None):
self.config = config
+ self.total_calls = 0
+ self.total_tokens = 0
if config.huggingface_base_url:
self.client = OpenAI(base_url=config.huggingface_base_url)
+ self.use_api = True
else:
self.config.model = config.model or "all-MiniLM-L6-v2"
self.model = SentenceTransformer(config.model, **config.model_kwargs)
self.config.embedding_dims = config.embedding_dims or self.model.get_sentence_embedding_dimension()
+ self.use_api = False
@classmethod
def from_config(cls, config):
@@ -39,11 +43,20 @@ def embed(self, text):
Returns:
list: The embedding vector.
"""
+ self.total_calls += 1
if self.config.huggingface_base_url:
- return self.client.embeddings.create(input=text, model="tei").data[0].embedding
+ response = self.client.embeddings.create(input=text, model="tei")
+ self.total_tokens += getattr(response.usage, 'total_tokens', 0)
+ return response.data[0].embedding
else:
result = self.model.encode(text, convert_to_numpy=True)
if isinstance(result, np.ndarray):
return result.tolist()
else:
- return result
\ No newline at end of file
+ return result
+
+ def get_stats(self):
+ return {
+ "total_calls": self.total_calls,
+ "total_tokens": self.total_tokens,
+ }
\ No newline at end of file
diff --git a/src/lightmem/factory/text_embedder/openai.py b/src/lightmem/factory/text_embedder/openai.py
new file mode 100644
index 0000000..05be917
--- /dev/null
+++ b/src/lightmem/factory/text_embedder/openai.py
@@ -0,0 +1,53 @@
+from openai import OpenAI
+from typing import Optional, List, Union
+import os
+import httpx
+from lightmem.configs.text_embedder.base_config import BaseTextEmbedderConfig
+
+
+class TextEmbedderOpenAI:
+ def __init__(self, config: Optional[BaseTextEmbedderConfig] = None):
+ self.config = config
+ self.model = getattr(config, "model", None) or "text-embedding-3-small"
+ http_client = httpx.Client(verify=False)
+ api_key = self.config.api_key
+ base_url = self.config.openai_base_url
+ self.client = OpenAI(
+ api_key=api_key,
+ base_url=base_url,
+ http_client=http_client
+ )
+ self.total_calls = 0
+ self.total_tokens = 0
+
+ @classmethod
+ def from_config(cls, config: BaseTextEmbedderConfig):
+ return cls(config)
+
+ def embed(self, text: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
+ def preprocess(t):
+ return str(t).replace("\n", " ")
+
+ api_params = {"model": self.config.model}
+ api_params["dimensions"] = self.config.embedding_dims
+
+ if isinstance(text, list):
+ if len(text) == 0:
+ return []
+ inputs = [preprocess(x) for x in text]
+ resp = self.client.embeddings.create(input=inputs, **api_params)
+ self.total_calls += 1
+ self.total_tokens += resp.usage.total_tokens
+ return [item.embedding for item in resp.data]
+ else:
+ preprocessed = preprocess(text)
+ resp = self.client.embeddings.create(input=[preprocessed], **api_params)
+ self.total_calls += 1
+ self.total_tokens += resp.usage.total_tokens
+ return resp.data[0].embedding
+
+ def get_stats(self):
+ return {
+ "total_calls": self.total_calls,
+ "total_tokens": self.total_tokens,
+ }
\ No newline at end of file
diff --git a/src/lightmem/memory/lightmem.py b/src/lightmem/memory/lightmem.py
index b98297f..a82e1ef 100644
--- a/src/lightmem/memory/lightmem.py
+++ b/src/lightmem/memory/lightmem.py
@@ -17,10 +17,12 @@
from lightmem.factory.retriever.embeddingretriever.factory import EmbeddingRetrieverFactory
from lightmem.factory.memory_buffer.sensory_memory import SenMemBufferManager
from lightmem.factory.memory_buffer.short_term_memory import ShortMemBufferManager
-from lightmem.memory.utils import MemoryEntry, assign_sequence_numbers_with_timestamps, save_memory_entries
-from lightmem.memory.prompts import METADATA_GENERATE_PROMPT, UPDATE_PROMPT
+from lightmem.memory.utils import MemoryEntry, assign_sequence_numbers_with_timestamps, save_memory_entries, convert_extraction_results_to_memory_entries
+from lightmem.memory.prompts import METADATA_GENERATE_PROMPT, UPDATE_PROMPT, METADATA_GENERATE_PROMPT_locomo
from lightmem.configs.logging.utils import get_logger
+GLOBAL_TOPIC_IDX = 0
+
class MessageNormalizer:
@@ -136,6 +138,19 @@ def __init__(self, config: BaseMemoryConfigs = BaseMemoryConfigs()):
self.logger = get_logger("LightMemory")
self.logger.info("Initializing LightMemory with provided configuration")
+ self.token_stats = {
+ "add_memory_calls": 0,
+ "add_memory_prompt_tokens": 0,
+ "add_memory_completion_tokens": 0,
+ "add_memory_total_tokens": 0,
+ "update_calls": 0,
+ "update_prompt_tokens": 0,
+ "update_completion_tokens": 0,
+ "update_total_tokens": 0,
+ "embedding_calls": 0,
+ "embedding_total_tokens": 0,
+ }
+ self.logger.info("Token statistics tracking initialized")
self.config = config
if self.config.pre_compress:
@@ -147,7 +162,7 @@ def __init__(self, config: BaseMemoryConfigs = BaseMemoryConfigs()):
self.senmem_buffer_manager = SenMemBufferManager(max_tokens=self.segmenter.buffer_len, tokenizer=self.segmenter.tokenizer)
self.logger.info("Initializing memory manager")
self.manager = MemoryManagerFactory.from_config(self.config.memory_manager)
- self.shortmem_buffer_manager = ShortMemBufferManager(max_tokens = 1024, tokenizer=getattr(self.manager, "tokenizer", self.manager.config.model))
+ self.shortmem_buffer_manager = ShortMemBufferManager(max_tokens = 512, tokenizer=getattr(self.manager, "tokenizer", self.manager.config.model))
if self.config.index_strategy == 'embedding' or self.config.index_strategy == 'hybrid':
self.logger.info("Initializing text embedder")
self.text_embedder = TextEmbedderFactory.from_config(self.config.text_embedder)
@@ -283,60 +298,75 @@ def add_memory(
self.logger.debug(f"[{call_id}] Extraction not triggered, returning result")
return result # TODO
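+        # Assign each topic segment a monotonically increasing global topic ID (GLOBAL_TOPIC_IDX persists across add_memory calls)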
+ global GLOBAL_TOPIC_IDX
+ topic_id_mapping = []
+ for api_call_segments in extract_list:
+ api_call_topic_ids = []
+ for topic_segment in api_call_segments:
+ api_call_topic_ids.append(GLOBAL_TOPIC_IDX)
+ GLOBAL_TOPIC_IDX += 1
+ topic_id_mapping.append(api_call_topic_ids)
+ self.logger.debug(f"topic_id_mapping: {topic_id_mapping}")
+ self.logger.info(f"[{call_id}] Assigned global topic IDs: total={sum(len(x) for x in topic_id_mapping)}, mapping={topic_id_mapping}")
self.logger.info(f"[{call_id}] Extraction triggered {extract_trigger_num} times, extract_list length: {len(extract_list)}")
- extract_list, timestamps_list, weekday_list = assign_sequence_numbers_with_timestamps(extract_list)
- self.logger.info(f"[{call_id}] Assigned timestamps to {len(extract_list)} items")
- self.logger.debug(f"[{call_id}] Timestamps sample: {timestamps_list}")
- self.logger.debug(f"[{call_id}] Weekdays sample: {weekday_list}")
+ extract_list, timestamps_list, weekday_list, speaker_list, topic_id_map = assign_sequence_numbers_with_timestamps(extract_list, offset_ms=500, topic_id_mapping=topic_id_mapping)
self.logger.debug(f"[{call_id}] Extract list sample: {json.dumps(extract_list)}")
if self.config.metadata_generate and self.config.text_summary:
self.logger.info(f"[{call_id}] Starting metadata generation")
- extracted_results = self.manager.meta_text_extract(METADATA_GENERATE_PROMPT, extract_list, self.config.messages_use)
- for item in extracted_results:
- if item is not None:
- result["add_input_prompt"].append(item["input_prompt"])
- result["add_output_prompt"].append(item["output_prompt"])
- result["api_call_nums"] += 1
+            extracted_results = self.manager.meta_text_extract(METADATA_GENERATE_PROMPT_locomo, extract_list, messages_use=self.config.messages_use, topic_id_mapping=topic_id_mapping)
+
+ # =============API Consumption======================
+ for idx, item in enumerate(extracted_results):
+ if item is None:
+ continue
+
+ if "usage" in item:
+ usage = item["usage"]
+ self.token_stats["add_memory_calls"] += 1
+ self.token_stats["add_memory_prompt_tokens"] += usage.get("prompt_tokens", 0)
+ self.token_stats["add_memory_completion_tokens"] += usage.get("completion_tokens", 0)
+ self.token_stats["add_memory_total_tokens"] += usage.get("total_tokens", 0)
+
+ self.logger.info(
+ f"[{call_id}] API Call {idx} tokens - "
+ f"Prompt: {usage.get('prompt_tokens', 0)}, "
+ f"Completion: {usage.get('completion_tokens', 0)}, "
+ f"Total: {usage.get('total_tokens', 0)}"
+ )
+
+ self.logger.debug(f"[{call_id}] API Call {idx} raw output: {item['output_prompt']}")
+ self.logger.debug(f"[{call_id}] API Call {idx} cleaned result: {item['cleaned_result']}")
+ result["add_input_prompt"].append(item["input_prompt"])
+ result["add_output_prompt"].append(item["output_prompt"])
+ result["api_call_nums"] += 1
+
+ # =======================================
+
self.logger.info(f"[{call_id}] Metadata generation completed with {result['api_call_nums']} API calls")
- extracted_memory_entry = [item["cleaned_result"] for item in extracted_results if item]
- self.logger.info(f"[{call_id}] Extracted {len(extracted_memory_entry)} memory entries")
- self.logger.debug(f"[{call_id}] Extracted memory entry sample: {json.dumps(extracted_memory_entry)}")
- memory_entries = []
- for topic_memory in extracted_memory_entry:
- if not topic_memory:
- continue
- for entry in topic_memory:
- sequence_n = entry.get("source_id")
- try:
- time_stamp = timestamps_list[sequence_n]
- if not isinstance(time_stamp, float):
- float_time_stamp = datetime.fromisoformat(time_stamp).timestamp()
- weekday = weekday_list[sequence_n]
- except (IndexError, TypeError) as e:
- self.logger.warning(f"[{call_id}] Error getting timestamp for sequence {sequence_n}: {e}")
- time_stamp = None
- float_time_stamp = None
- weekday = None
- mem_obj = MemoryEntry(
- time_stamp=time_stamp,
- float_time_stamp=float_time_stamp,
- weekday=weekday,
- memory=entry.get("fact", ""),
- # original_memory=entry.get("original_fact", ""), # TODO
- # compressed_memory="" # TODO
- )
- memory_entries.append(mem_obj)
+ memory_entries = convert_extraction_results_to_memory_entries(
+ extracted_results=extracted_results,
+ timestamps_list=timestamps_list,
+ weekday_list=weekday_list,
+ speaker_list=speaker_list,
+ topic_id_map=topic_id_map,
+ logger=self.logger
+ )
self.logger.info(f"[{call_id}] Created {len(memory_entries)} MemoryEntry objects")
for i, mem in enumerate(memory_entries):
- self.logger.debug(f"[{call_id}] MemoryEntry[{i}]: time={mem.time_stamp}, weekday={mem.weekday}, memory={mem.memory}")
+ self.logger.debug(f"[{call_id}] MemoryEntry[{i}]: time={mem.time_stamp}, weekday={mem.weekday}, speaker_id={mem.speaker_id}, speaker_name={mem.speaker_name}, topic_id={mem.topic_id}, memory={mem.memory}")
if self.config.update == "online":
self.online_update(memory_entries)
elif self.config.update == "offline":
self.offline_update(memory_entries)
+ self.logger.info(
+ f"[{call_id}] Cumulative token stats - "
+ f"Total API calls: {self.token_stats['add_memory_calls']}, "
+ f"Total tokens: {self.token_stats['add_memory_total_tokens']}"
+ )
return result
def online_update(self, memory_list: List):
@@ -365,12 +395,16 @@ def offline_update(self, memory_list: List, construct_update_queue_trigger: bool
"time_stamp": mem_obj.time_stamp,
"float_time_stamp": mem_obj.float_time_stamp,
"weekday": mem_obj.weekday,
+ "topic_id": mem_obj.topic_id,
+ "topic_summary": mem_obj.topic_summary,
"category": mem_obj.category,
"subcategory": mem_obj.subcategory,
"memory_class": mem_obj.memory_class,
"memory": mem_obj.memory,
"original_memory": mem_obj.original_memory,
"compressed_memory": mem_obj.compressed_memory,
+ "speaker_id": mem_obj.speaker_id,
+ "speaker_name": mem_obj.speaker_name,
}
self.embedding_retriever.insert(
vectors = [embedding_vector],
@@ -499,7 +533,13 @@ def offline_update_all_entries(self, score_threshold: float = 0.5, max_workers:
skipped_count = 0
lock = threading.Lock()
write_lock = threading.Lock()
-
+ update_token_stats = {
+ "calls": 0,
+ "prompt_tokens": 0,
+ "completion_tokens": 0,
+ "total_tokens": 0
+ }
+ token_lock = threading.Lock()
def update_entry(entry):
nonlocal processed_count, updated_count, deleted_count, skipped_count
@@ -526,7 +566,19 @@ def update_entry(entry):
if updated_entry is None:
return
-
+ # ====== token consumption ======
+ usage = updated_entry["usage"]
+ with token_lock:
+ update_token_stats["calls"] += 1
+ update_token_stats["prompt_tokens"] += usage.get("prompt_tokens", 0)
+ update_token_stats["completion_tokens"] += usage.get("completion_tokens", 0)
+ update_token_stats["total_tokens"] += usage.get("total_tokens", 0)
+
+ self.logger.debug(
+ f"[{call_id}] Update LLM call for {eid} - "
+ f"Tokens: {usage.get('total_tokens', 0)}"
+ )
+ # ==================== token consumption ====================
action = updated_entry.get("action")
if action == "delete":
with write_lock:
@@ -546,11 +598,20 @@ def update_entry(entry):
self.logger.info(f"[{call_id}] Starting parallel offline update with {max_workers} workers")
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
executor.map(update_entry, all_entries)
+ with lock:
+ self.token_stats["update_calls"] += update_token_stats["calls"]
+ self.token_stats["update_prompt_tokens"] += update_token_stats["prompt_tokens"]
+ self.token_stats["update_completion_tokens"] += update_token_stats["completion_tokens"]
+ self.token_stats["update_total_tokens"] += update_token_stats["total_tokens"]
self.logger.info(f"[{call_id}] Offline update completed:")
self.logger.info(f"[{call_id}] - Processed: {processed_count} entries")
self.logger.info(f"[{call_id}] - Updated: {updated_count} entries")
self.logger.info(f"[{call_id}] - Deleted: {deleted_count} entries")
self.logger.info(f"[{call_id}] - Skipped (no candidates): {skipped_count} entries")
+ self.logger.info(
+ f"[{call_id}] - Update API calls: {update_token_stats['calls']}, "
+ f"Total tokens: {update_token_stats['total_tokens']}"
+ )
self.logger.info(f"========== END {call_id} ==========")
def retrieve(self, query: str, limit: int = 10, filters: dict = None) -> list[str]:
@@ -595,3 +656,37 @@ def retrieve(self, query: str, limit: int = 10, filters: dict = None) -> list[st
self.logger.info(f"========== END {call_id} ==========")
return result_string
+ def get_token_statistics(self):
+ embedder_stats = {"total_calls": 0, "total_tokens": None}
+ if hasattr(self, 'text_embedder') and hasattr(self.text_embedder, 'get_stats'):
+ embedder_stats = self.text_embedder.get_stats()
+
+ stats = {
+ "summary": {
+ "total_llm_calls": self.token_stats["add_memory_calls"] + self.token_stats["update_calls"],
+ "total_llm_tokens": self.token_stats["add_memory_total_tokens"] + self.token_stats["update_total_tokens"],
+ "total_embedding_calls": embedder_stats["total_calls"],
+ "total_embedding_tokens": embedder_stats["total_tokens"],
+ },
+ "llm": {
+ "add_memory": {
+ "calls": self.token_stats["add_memory_calls"],
+ "prompt_tokens": self.token_stats["add_memory_prompt_tokens"],
+ "completion_tokens": self.token_stats["add_memory_completion_tokens"],
+ "total_tokens": self.token_stats["add_memory_total_tokens"],
+ },
+ "update": {
+ "calls": self.token_stats["update_calls"],
+ "prompt_tokens": self.token_stats["update_prompt_tokens"],
+ "completion_tokens": self.token_stats["update_completion_tokens"],
+ "total_tokens": self.token_stats["update_total_tokens"],
+ },
+ },
+ "embedding": {
+ "total_calls": embedder_stats["total_calls"],
+ "total_tokens": embedder_stats["total_tokens"],
+ "note": "Includes topic segmentation + memory indexing. Local models show None for tokens."
+ }
+ }
+
+ return stats
\ No newline at end of file
diff --git a/src/lightmem/memory/prompts.py b/src/lightmem/memory/prompts.py
index 512d388..8852ec2 100644
--- a/src/lightmem/memory/prompts.py
+++ b/src/lightmem/memory/prompts.py
@@ -53,6 +53,85 @@
Reminder: Be exhaustive. Unless a message is purely meaningless, extract and output it as a fact.
"""
+# 73.90
+METADATA_GENERATE_PROMPT_locomo = """
+You are a Personal Information Extractor.
+Your task is to extract **all possible facts or information** about the speakers from a conversation,
+where the dialogue is organized into topic segments separated by markers like:
+
+--- Topic X ---
+[timestamp, weekday] .:
+...
+
+Important Instructions:
+0. You MUST process messages **strictly in ascending source_id order** (lowest → highest).
+ For each message, stop and **carefully** evaluate its content before moving to the next.
+ Do NOT reorder, batch-skip, or skip ahead — treat messages one-by-one.
+1. You MUST process every user message in order, one by one.
+ For each message, decide whether it contains any factual information.
+ - If yes → extract it and rephrase into a standalone sentence.
+ - Do NOT skip just because the information looks minor, trivial, or unimportant.
+ Extract ALL meaningful information including:
+ * Past events and current states
+ * Future plans and intentions
+ * Thoughts, opinions, and attitudes
+ * Wants, hopes, desires, and preferences
+2. **CRITICAL - Preserve All Specific Details**:
+ When extracting facts, you MUST include ALL specific entities and details mentioned:
+ - **Full names with context**: "The Name of the Wind" by Patrick Rothfuss (not just "a book")
+ - **Complete location names**: Galway, Ireland; The Cliffs of Moher; Barcelona (not just "a city")
+ - **Specific event names**: benefit basketball game, study abroad program (not just "an event")
+ - **Product/item details**: vintage camera, brand new fire truck (not just "a camera")
+ - **Numbers and quantities**: 4 years ago, next month, last week
+ - **Company/organization names**: beverage company, fire-fighting brigade
+ Additionally, **infer implied information** when clearly supported:
+ - If multiple related items mentioned → may infer general pattern
+ - Keep BOTH specific facts AND inferred insights as separate entries
+3. Perform light contextual completion so that each fact is a clear standalone statement.
+4. **Time Handling**:
+ Note: Distinguish mention time (when said) vs event time (when happened).
+ - For events with relative time (yesterday, last week, X ago, next month):
+ Preserve the relative time and reference the message timestamp (YYYY-MM-DD).
+ Format: " ."
+ - For ongoing/timeless facts: No time annotation needed.
+5. Output format:
+ Always return a JSON object with key `"data"`, which is a list of items:
+ {
+ "source_id": "",
+ "fact": ""
+ }
+
+Examples:
+--- Topic 1 ---
+[2024-01-07T17:24:00.000, Sun] 0.Tim: Hey John! Next month I'm off to Ireland for a semester in Galway
+[2024-01-07T17:24:01.000, Sun] 1.John: That's awesome! Where will you stay?
+[2024-01-07T17:24:02.000, Sun] 2.Tim: In Galway. I also want to visit The Cliffs of Moher
+[2024-01-07T17:24:03.000, Sun] 3.John: Nice! By the way, I held a benefit basketball game last week
+[2024-01-07T17:24:04.000, Sun] 4.Tim: Cool! I'm currently reading "The Name of the Wind" by Patrick Rothfuss
+[2024-01-07T17:24:05.000, Sun] 5.John: That sounds interesting!
+--- Topic 2 ---
+[2024-01-12T13:41:00.000, Fri] 6.John: Got great news! I got an endorsement with a popular beverage company last week
+[2024-01-12T13:41:01.000, Fri] 7.Tim: Congrats! That's amazing
+[2024-01-12T13:41:02.000, Fri] 8.John: Thanks! By the way, Barcelona is a must-visit city
+[2024-01-12T13:41:03.000, Fri] 9.Tim: I'll add it to my list!
+
+{"data": [
+ {"source_id": 0, "fact": "Tim is going to Ireland for a semester in Galway the month after 2024-01-07."},
+ {"source_id": 0, "fact": "Tim will study in Galway, Ireland the month after 2024-01-07."},
+ {"source_id": 2, "fact": "Tim will stay in Galway."},
+ {"source_id": 2, "fact": "Tim wants to visit The Cliffs of Moher."},
+ {"source_id": 3, "fact": "John held a benefit basketball game the week before 2024-01-07."},
+ {"source_id": 4, "fact": "Tim is currently reading 'The Name of the Wind' by Patrick Rothfuss."},
+ {"source_id": 4, "fact": "Tim is reading a fantasy novel."},
+ {"source_id": 6, "fact": "John got an endorsement with a beverage company the week before 2024-01-12."},
+ {"source_id": 8, "fact": "John recommends Barcelona as a must-visit city."},
+ {"source_id": 9, "fact": "Tim has a travel list and plans to add Barcelona to it."}
+]}
+
+Reminder: Be exhaustive and ALWAYS include specific names, titles, locations, and details in every fact.
+"""
+
+
# METADATA_GENERATE_PROMPT = """
# You are a Personal Information Extractor.
# Your task is to extract meaningful facts about the user from a conversation,
diff --git a/src/lightmem/memory/utils.py b/src/lightmem/memory/utils.py
index d098afb..0d6253f 100644
--- a/src/lightmem/memory/utils.py
+++ b/src/lightmem/memory/utils.py
@@ -20,6 +20,10 @@ class MemoryEntry:
memory: str = ""
original_memory: str = ""
compressed_memory: str = ""
+ topic_id: Optional[int] = None
+ topic_summary: str = ""
+ speaker_id: str = ""
+ speaker_name: str = ""
hit_time: int = 0
update_queue: List = field(default_factory=list)
@@ -47,10 +51,32 @@ def clean_response(response: str) -> List[Dict[str, Any]]:
return []
-def assign_sequence_numbers_with_timestamps(extract_list):
+def assign_sequence_numbers_with_timestamps(extract_list, offset_ms: int = 500, topic_id_mapping: List[List[int]] = None):
+ from datetime import datetime, timedelta
+ from collections import defaultdict
+
current_index = 0
timestamps_list = []
weekday_list = []
+ speaker_list = []
+ message_refs = []
+
+ for segments in extract_list:
+ for seg in segments:
+ for message in seg:
+ session_time = message.get('session_time', '')
+ message_refs.append((message, session_time))
+
+ session_groups = defaultdict(list)
+ for msg, sess_time in message_refs:
+ session_groups[sess_time].append(msg)
+
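+    # Messages sharing a session_time get timestamps staggered by offset_ms so ordering within the session is preserved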
+    for sess_time, messages in session_groups.items():
+        if not sess_time:
+            continue
+        base_dt = datetime.strptime(sess_time, "%Y-%m-%d %H:%M:%S")
+ for i, msg in enumerate(messages):
+ offset = timedelta(milliseconds=offset_ms * i)
+ new_dt = base_dt + offset
+ msg['time_stamp'] = new_dt.isoformat(timespec='milliseconds')
for segments in extract_list:
for seg in segments:
@@ -58,9 +84,23 @@ def assign_sequence_numbers_with_timestamps(extract_list):
message["sequence_number"] = current_index
timestamps_list.append(message["time_stamp"])
weekday_list.append(message["weekday"])
+ speaker_info = {
+ 'speaker_id': message.get('speaker_id', 'Unknown'),
+ 'speaker_name': message.get('speaker_name', 'Unknown')
+ }
+ speaker_list.append(speaker_info)
current_index += 1
-
- return extract_list, timestamps_list, weekday_list
+
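+    # Map each message's sequence_number to its global topic ID so extracted facts can later be tagged with their topic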
+ sequence_to_topic = {}
+ if topic_id_mapping:
+ for api_idx, api_call_segments in enumerate(extract_list):
+ for topic_idx, topic_segment in enumerate(api_call_segments):
+ tid = topic_id_mapping[api_idx][topic_idx]
+ for msg in topic_segment:
+ seq = msg.get("sequence_number")
+ sequence_to_topic[seq] = tid
+
+ return extract_list, timestamps_list, weekday_list, speaker_list, sequence_to_topic
# TODO:merge into context retriever
def save_memory_entries(memory_entries, file_path="memory_entries.json"):
@@ -68,6 +108,8 @@ def entry_to_dict(entry):
return {
"id": entry.id,
"time_stamp": entry.time_stamp,
+ "topic_id": entry.topic_id,
+ "topic_summary": entry.topic_summary,
"category": entry.category,
"subcategory": entry.subcategory,
"memory_class": entry.memory_class,
@@ -114,3 +156,127 @@ def resolve_tokenizer(tokenizer_or_name: Union[str, Any]):
raise ValueError(f"Unknown model_name '{tokenizer_or_name}'")
raise TypeError(f"Unsupported tokenizer type: {type(tokenizer_or_name)}")
+
+
+def convert_extraction_results_to_memory_entries(
+ extracted_results: List[Optional[Dict]],
+ timestamps_list: List,
+ weekday_list: List,
+ speaker_list: List = None,
+ topic_id_map: Dict[int, int] = None,
+ logger = None
+) -> List[MemoryEntry]:
+ """
+ Convert extraction results to MemoryEntry objects.
+
+ Args:
+ extracted_results: Results from meta_text_extract, each containing cleaned_result
+ timestamps_list: List of timestamps indexed by sequence_number
+ weekday_list: List of weekdays indexed by sequence_number
+ speaker_list: List of speaker information
+ topic_id_map: Optional mapping of sequence_number -> topic_id (preferred)
+ logger: Optional logger for debug info
+
+ Returns:
+ List of MemoryEntry objects with assigned topic_id and timestamps
+ """
+ memory_entries = []
+
+ extracted_memory_entry = [
+ item["cleaned_result"]
+ for item in extracted_results
+ if item and item.get("cleaned_result")
+ ]
+
+ for topic_memory in extracted_memory_entry:
+ if not topic_memory:
+ continue
+
+ for topic_idx, fact_list in enumerate(topic_memory):
+ if not isinstance(fact_list, list):
+ fact_list = [fact_list]
+
+ for fact_entry in fact_list:
+ sid = int(fact_entry.get("source_id"))
+ seq_candidate = sid * 2
+                resolved_topic_id = topic_id_map.get(seq_candidate) if topic_id_map else None
+
+ mem_obj = _create_memory_entry_from_fact(
+ fact_entry,
+ timestamps_list,
+ weekday_list,
+ speaker_list,
+ topic_id=resolved_topic_id,
+ topic_summary="",
+ logger=logger,
+ )
+
+ if mem_obj:
+ memory_entries.append(mem_obj)
+
+ return memory_entries
+
+
+def _create_memory_entry_from_fact(
+ fact_entry: Dict,
+ timestamps_list: List,
+ weekday_list: List,
+ speaker_list: List = None,
+ topic_id: int = None,
+ topic_summary: str = "",
+ logger = None
+) -> Optional[MemoryEntry]:
+ """
+ Helper function to create a MemoryEntry from a fact entry.
+
+ Args:
+ fact_entry: Dict containing source_id and fact
+ timestamps_list: List of timestamps indexed by sequence_number
+ weekday_list: List of weekdays indexed by sequence_number
+ speaker_list: List of speaker information
+ topic_id: Topic ID for this memory entry
+ topic_summary: Topic summary for this memory entry (reserved for future use)
+ logger: Optional logger for warnings
+
+ Returns:
+ MemoryEntry object or None if creation fails
+ """
+    sequence_n = int(fact_entry.get("source_id")) * 2
+
+ try:
+ time_stamp = timestamps_list[sequence_n]
+
+ if not isinstance(time_stamp, float):
+ from datetime import datetime
+ float_time_stamp = datetime.fromisoformat(time_stamp).timestamp()
+ else:
+ float_time_stamp = time_stamp
+
+ weekday = weekday_list[sequence_n]
+ speaker_info = speaker_list[sequence_n]
+ speaker_id = speaker_info.get('speaker_id', 'unknown')
+ speaker_name = speaker_info.get('speaker_name', 'Unknown')
+
+ except (IndexError, TypeError, ValueError) as e:
+ if logger:
+ logger.warning(
+ f"Error getting timestamp for sequence {sequence_n}: {e}"
+ )
+ time_stamp = None
+ float_time_stamp = None
+ weekday = None
+ speaker_id = 'unknown'
+ speaker_name = 'Unknown'
+
+ mem_obj = MemoryEntry(
+ time_stamp=time_stamp,
+ float_time_stamp=float_time_stamp,
+ weekday=weekday,
+ memory=fact_entry.get("fact", ""),
+ speaker_id=speaker_id,
+ speaker_name=speaker_name,
+ topic_id=topic_id,
+ topic_summary=topic_summary,
+ )
+
+ return mem_obj
\ No newline at end of file
|