diff --git a/README.md b/README.md
index 9c62924..b7ed422 100644
--- a/README.md
+++ b/README.md
@@ -232,10 +232,10 @@ lightmem = LightMemory.from_config(config_dict)
 session = {
     "timestamp": "2025-01-10",
     "turns": [
-        [
-            {"role": "user", "content": "My favorite ice cream flavor is pistachio, and my dog's name is Rex."},
-            {"role": "assistant", "content": "Got it. Pistachio is a great choice."}],
-    ]
+        [
+            {"role": "user", "content": "My favorite ice cream flavor is pistachio, and my dog's name is Rex.", "speaker_name": "John", "speaker_id": "speaker_a"},
+            {"role": "assistant", "content": "Got it. Pistachio is a great choice.", "speaker_name": "Assistant", "speaker_id": "speaker_b"}],
+    ]
 }
@@ -377,10 +377,10 @@ We welcome contributions from the community! If you'd like to contribute, please
-        Memos
+        MemOS
-        Memos
+        MemOS
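A quick illustration of how the new per-message speaker fields could be used end to end. This is a minimal sketch, not part of the diff: the `config_dict` contents and the exact `add_memory` argument shape are assumptions (the PR only shows the session structure and that `LightMemory` exposes `add_memory` and `retrieve`); check the repository README for the real entry point.

```python
from lightmem.memory.lightmem import LightMemory

# Placeholder configuration; real keys come from the project's README/config docs.
config_dict = {}
lightmem = LightMemory.from_config(config_dict)

session = {
    "timestamp": "2025-01-10",
    "turns": [
        [
            {"role": "user",
             "content": "My favorite ice cream flavor is pistachio, and my dog's name is Rex.",
             "speaker_name": "John", "speaker_id": "speaker_a"},
            {"role": "assistant",
             "content": "Got it. Pistachio is a great choice.",
             "speaker_name": "Assistant", "speaker_id": "speaker_b"},
        ],
    ],
}

# Assumed call shape: ingest one session's turns; the exact parameter name
# is not shown in this diff.
lightmem.add_memory(session)

# Retrieval signature is unchanged by this PR.
print(lightmem.retrieve("What is the user's dog called?", limit=5))
```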
diff --git a/src/lightmem/configs/memory_manager/base_config.py b/src/lightmem/configs/memory_manager/base_config.py
index e8855cc..c57a8d9 100644
--- a/src/lightmem/configs/memory_manager/base_config.py
+++ b/src/lightmem/configs/memory_manager/base_config.py
@@ -13,6 +13,7 @@ def __init__(
         max_tokens: int = 2000,
         top_p: float = 0.1,
         top_k: int = 1,
+        include_topic_summary: bool = False,
         enable_vision: bool = False,
         vision_details: Optional[str] = "auto",
         # Openai specific
@@ -31,6 +32,7 @@
         self.max_tokens = max_tokens
         self.top_p = top_p
         self.top_k = top_k
+        self.include_topic_summary = include_topic_summary
         self.enable_vision = enable_vision
         self.vision_details = vision_details
         # Openai specific
diff --git a/src/lightmem/configs/text_embedder/base.py b/src/lightmem/configs/text_embedder/base.py
index 844c8ba..b047d27 100644
--- a/src/lightmem/configs/text_embedder/base.py
+++ b/src/lightmem/configs/text_embedder/base.py
@@ -8,7 +8,7 @@ class TextEmbedderConfig(BaseModel):
         description="The embedding model or Deployment platform (e.g., 'openai', 'huggingface')"
     )
-    _model_list: ClassVar[List[str]] = ["huggingface"]
+    _model_list: ClassVar[List[str]] = ["huggingface", "openai"]
     configs: Optional[Union[BaseTextEmbedderConfig, Dict[str, Any]]] = Field(
         default=None,
diff --git a/src/lightmem/factory/memory_buffer/short_term_memory.py b/src/lightmem/factory/memory_buffer/short_term_memory.py
index 2a75d67..2bfe201 100644
--- a/src/lightmem/factory/memory_buffer/short_term_memory.py
+++ b/src/lightmem/factory/memory_buffer/short_term_memory.py
@@ -7,7 +7,7 @@ def __init__(self, max_tokens: int = 2000, tokenizer: Optional[Any] = None):
         self.tokenizer = resolve_tokenizer(tokenizer)
         self.buffer: List[List[Dict[str, Any]]] = []
         self.token_count: int = 0
-
+        print(f"ShortMemBufferManager initialized with max_tokens={self.max_tokens}")
     def _count_tokens(self, messages: List[Dict[str, Any]], messages_use: str) -> int:
         role_map = {
             "user_only": ["user"],
diff --git a/src/lightmem/factory/memory_manager/openai.py b/src/lightmem/factory/memory_manager/openai.py
index 73f4aad..a61b9cb 100644
--- a/src/lightmem/factory/memory_manager/openai.py
+++ b/src/lightmem/factory/memory_manager/openai.py
@@ -1,6 +1,7 @@
 import concurrent
+from collections import defaultdict
 from openai import OpenAI
-from typing import List, Dict, Optional, Literal
+from typing import List, Dict, Optional, Literal, Any
 import json, os, warnings
 import httpx
 from lightmem.configs.memory_manager.base_config import BaseMemoryManagerConfig
@@ -120,28 +121,41 @@ def generate_response(
             params["tool_choice"] = tool_choice

         response = self.client.chat.completions.create(**params)
-        return self._parse_response(response, tools)
-
+        usage_info = {
+            "prompt_tokens": response.usage.prompt_tokens,
+            "completion_tokens": response.usage.completion_tokens,
+            "total_tokens": response.usage.total_tokens,
+        }
+        parsed_response = self._parse_response(response, tools)
+
+        return parsed_response, usage_info
+
     def meta_text_extract(
         self,
         system_prompt: str,
         extract_list: List[List[List[Dict]]],
-        messages_use: Literal["user_only", "assistant_only", "hybrid"] = "user_only"
+        timestamps_list: Optional[List[List[List[str]]]] = None,
+        weekday_list: Optional[List[List[List[str]]]] = None,
+        messages_use: Literal["user_only", "assistant_only", "hybrid"] = "user_only",
+        topic_id_mapping: Optional[List[List[int]]] = None
     ) -> List[Optional[Dict]]:
         """
         Extract metadata from text segments using parallel processing.
Args: system_prompt: The system prompt for metadata generation - all_segments: List of message segments to process + extract_list: List of message segments to process + timestamps_list: Optional list of timestamps (reserved for future use) + weekday_list: Optional list of weekdays (reserved for future use) messages_use: Strategy for which messages to use + topic_id_mapping: For each API call, the global topic IDs Returns: List of extracted metadata results, None for failed segments """ if not extract_list: return [] - + def concatenate_messages(segment: List[Dict], messages_use: str) -> str: """Concatenate messages based on usage strategy""" role_filter = { @@ -161,43 +175,69 @@ def concatenate_messages(segment: List[Dict], messages_use: str) -> str: sequence_id = mes["sequence_number"] role = mes["role"] content = mes.get("content", "") - message_lines.append(f"{sequence_id}.{role}: {content}") - + speaker_name = mes.get("speaker_name", "") + time_stamp = mes.get("time_stamp", "") + weekday = mes.get("weekday", "") + + time_prefix = "" + if time_stamp and weekday: + time_prefix = f"[{time_stamp}, {weekday}] " + + if speaker_name != 'Unknown': + message_lines.append(f"{time_prefix}{sequence_id//2}.{speaker_name}: {content}") + else: + message_lines.append(f"{time_prefix}{sequence_id//2}.{role}: {content}") + return "\n".join(message_lines) - + max_workers = min(len(extract_list), 5) - def process_segment_wrapper(api_call_segments: List[List[Dict]]): - """Process one API call (multiple topic segments inside)""" + def process_segment_wrapper(args): + api_call_idx, api_call_segments = args try: - user_prompt_parts = [] - for idx, topic_segment in enumerate(api_call_segments, start=1): + user_prompt_parts: List[str] = [] + + global_topic_ids: List[int] = [] + if topic_id_mapping and api_call_idx < len(topic_id_mapping): + global_topic_ids = topic_id_mapping[api_call_idx] + + for topic_idx, topic_segment in enumerate(api_call_segments): + if topic_idx < len(global_topic_ids): + global_topic_id = global_topic_ids[topic_idx] + else: + global_topic_id = topic_idx + 1 + topic_text = concatenate_messages(topic_segment, messages_use) - user_prompt_parts.append(f"--- Topic {idx} ---\n{topic_text}") + user_prompt_parts.append(f"--- Topic {global_topic_id} ---\n{topic_text}") + print(f"User prompt for API call {api_call_idx}:\n" + "\n".join(user_prompt_parts)) user_prompt = "\n".join(user_prompt_parts) - - messages = [ + + metadata_messages = [ {"role": "system", "content": system_prompt}, - {"role": "user", "content": user_prompt} + {"role": "user", "content": user_prompt}, ] - raw_response = self.generate_response( - messages=messages, - response_format={"type": "json_object"} + + raw_response, usage_info = self.generate_response( + messages=metadata_messages, + response_format={"type": "json_object"}, ) - cleaned_result = clean_response(raw_response) + metadata_facts = clean_response(raw_response) + return { - "input_prompt": messages, + "input_prompt": metadata_messages, "output_prompt": raw_response, - "cleaned_result": cleaned_result + "cleaned_result": metadata_facts, + "usage": usage_info, } + except Exception as e: - print(f"Error processing API call: {e}") + print(f"Error processing API call {api_call_idx}: {e}") return None with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: try: - results = list(executor.map(process_segment_wrapper, extract_list)) + results = list(executor.map(process_segment_wrapper, enumerate(extract_list))) except Exception as e: print(f"Error 
in parallel processing: {e}") results = [None] * len(extract_list) @@ -218,15 +258,16 @@ def _call_update_llm(self, system_prompt, target_entry, candidate_sources): {"role": "user", "content": user_prompt} ] - response_text = self.generate_response( + response_text, usage_info = self.generate_response( messages=messages, response_format={"type": "json_object"} ) - + try: result = json.loads(response_text) if "action" not in result: - return {"action": "ignore"} + result = {"action": "ignore"} + result["usage"] = usage_info return result except Exception: - return {"action": "ignore"} \ No newline at end of file + return {"action": "ignore", "usage": usage_info if 'usage_info' in locals() else None} \ No newline at end of file diff --git a/src/lightmem/factory/text_embedder/factory.py b/src/lightmem/factory/text_embedder/factory.py index af6d2fe..e9e6106 100644 --- a/src/lightmem/factory/text_embedder/factory.py +++ b/src/lightmem/factory/text_embedder/factory.py @@ -5,6 +5,7 @@ class TextEmbedderFactory: _MODEL_MAPPING: Dict[str, str] = { "huggingface": "lightmem.factory.text_embedder.huggingface.TextEmbedderHuggingface", + "openai": "lightmem.factory.text_embedder.openai.TextEmbedderOpenAI", } @classmethod diff --git a/src/lightmem/factory/text_embedder/huggingface.py b/src/lightmem/factory/text_embedder/huggingface.py index 2166993..67257a9 100644 --- a/src/lightmem/factory/text_embedder/huggingface.py +++ b/src/lightmem/factory/text_embedder/huggingface.py @@ -7,12 +7,16 @@ class TextEmbedderHuggingface: def __init__(self, config: Optional[BaseTextEmbedderConfig] = None): self.config = config + self.total_calls = 0 + self.total_tokens = 0 if config.huggingface_base_url: self.client = OpenAI(base_url=config.huggingface_base_url) + self.use_api = True else: self.config.model = config.model or "all-MiniLM-L6-v2" self.model = SentenceTransformer(config.model, **config.model_kwargs) self.config.embedding_dims = config.embedding_dims or self.model.get_sentence_embedding_dimension() + self.use_api = False @classmethod def from_config(cls, config): @@ -39,11 +43,20 @@ def embed(self, text): Returns: list: The embedding vector. 
""" + self.total_calls += 1 if self.config.huggingface_base_url: - return self.client.embeddings.create(input=text, model="tei").data[0].embedding + response = self.client.embeddings.create(input=text, model="tei") + self.total_tokens += getattr(response.usage, 'total_tokens', 0) + return response.data[0].embedding else: result = self.model.encode(text, convert_to_numpy=True) if isinstance(result, np.ndarray): return result.tolist() else: - return result \ No newline at end of file + return result + + def get_stats(self): + return { + "total_calls": self.total_calls, + "total_tokens": self.total_tokens, + } \ No newline at end of file diff --git a/src/lightmem/factory/text_embedder/openai.py b/src/lightmem/factory/text_embedder/openai.py new file mode 100644 index 0000000..05be917 --- /dev/null +++ b/src/lightmem/factory/text_embedder/openai.py @@ -0,0 +1,53 @@ +from openai import OpenAI +from typing import Optional, List, Union +import os +import httpx +from lightmem.configs.text_embedder.base_config import BaseTextEmbedderConfig + + +class TextEmbedderOpenAI: + def __init__(self, config: Optional[BaseTextEmbedderConfig] = None): + self.config = config + self.model = getattr(config, "model", None) or "text-embedding-3-small" + http_client = httpx.Client(verify=False) + api_key = self.config.api_key + base_url = self.config.openai_base_url + self.client = OpenAI( + api_key=api_key, + base_url=base_url, + http_client=http_client + ) + self.total_calls = 0 + self.total_tokens = 0 + + @classmethod + def from_config(cls, config: BaseTextEmbedderConfig): + return cls(config) + + def embed(self, text: Union[str, List[str]]) -> Union[List[float], List[List[float]]]: + def preprocess(t): + return str(t).replace("\n", " ") + + api_params = {"model": self.config.model} + api_params["dimensions"] = self.config.embedding_dims + + if isinstance(text, list): + if len(text) == 0: + return [] + inputs = [preprocess(x) for x in text] + resp = self.client.embeddings.create(input=inputs, **api_params) + self.total_calls += 1 + self.total_tokens += resp.usage.total_tokens + return [item.embedding for item in resp.data] + else: + preprocessed = preprocess(text) + resp = self.client.embeddings.create(input=[preprocessed], **api_params) + self.total_calls += 1 + self.total_tokens += resp.usage.total_tokens + return resp.data[0].embedding + + def get_stats(self): + return { + "total_calls": self.total_calls, + "total_tokens": self.total_tokens, + } \ No newline at end of file diff --git a/src/lightmem/memory/lightmem.py b/src/lightmem/memory/lightmem.py index b98297f..a82e1ef 100644 --- a/src/lightmem/memory/lightmem.py +++ b/src/lightmem/memory/lightmem.py @@ -17,10 +17,12 @@ from lightmem.factory.retriever.embeddingretriever.factory import EmbeddingRetrieverFactory from lightmem.factory.memory_buffer.sensory_memory import SenMemBufferManager from lightmem.factory.memory_buffer.short_term_memory import ShortMemBufferManager -from lightmem.memory.utils import MemoryEntry, assign_sequence_numbers_with_timestamps, save_memory_entries -from lightmem.memory.prompts import METADATA_GENERATE_PROMPT, UPDATE_PROMPT +from lightmem.memory.utils import MemoryEntry, assign_sequence_numbers_with_timestamps, save_memory_entries, convert_extraction_results_to_memory_entries +from lightmem.memory.prompts import METADATA_GENERATE_PROMPT, UPDATE_PROMPT, METADATA_GENERATE_PROMPT_locomo from lightmem.configs.logging.utils import get_logger +GLOBAL_TOPIC_IDX = 0 + class MessageNormalizer: @@ -136,6 +138,19 @@ def __init__(self, 
config: BaseMemoryConfigs = BaseMemoryConfigs()): self.logger = get_logger("LightMemory") self.logger.info("Initializing LightMemory with provided configuration") + self.token_stats = { + "add_memory_calls": 0, + "add_memory_prompt_tokens": 0, + "add_memory_completion_tokens": 0, + "add_memory_total_tokens": 0, + "update_calls": 0, + "update_prompt_tokens": 0, + "update_completion_tokens": 0, + "update_total_tokens": 0, + "embedding_calls": 0, + "embedding_total_tokens": 0, + } + self.logger.info("Token statistics tracking initialized") self.config = config if self.config.pre_compress: @@ -147,7 +162,7 @@ def __init__(self, config: BaseMemoryConfigs = BaseMemoryConfigs()): self.senmem_buffer_manager = SenMemBufferManager(max_tokens=self.segmenter.buffer_len, tokenizer=self.segmenter.tokenizer) self.logger.info("Initializing memory manager") self.manager = MemoryManagerFactory.from_config(self.config.memory_manager) - self.shortmem_buffer_manager = ShortMemBufferManager(max_tokens = 1024, tokenizer=getattr(self.manager, "tokenizer", self.manager.config.model)) + self.shortmem_buffer_manager = ShortMemBufferManager(max_tokens = 512, tokenizer=getattr(self.manager, "tokenizer", self.manager.config.model)) if self.config.index_strategy == 'embedding' or self.config.index_strategy == 'hybrid': self.logger.info("Initializing text embedder") self.text_embedder = TextEmbedderFactory.from_config(self.config.text_embedder) @@ -283,60 +298,75 @@ def add_memory( self.logger.debug(f"[{call_id}] Extraction not triggered, returning result") return result # TODO + global GLOBAL_TOPIC_IDX + topic_id_mapping = [] + for api_call_segments in extract_list: + api_call_topic_ids = [] + for topic_segment in api_call_segments: + api_call_topic_ids.append(GLOBAL_TOPIC_IDX) + GLOBAL_TOPIC_IDX += 1 + topic_id_mapping.append(api_call_topic_ids) + self.logger.debug(f"topic_id_mapping: {topic_id_mapping}") + self.logger.info(f"[{call_id}] Assigned global topic IDs: total={sum(len(x) for x in topic_id_mapping)}, mapping={topic_id_mapping}") self.logger.info(f"[{call_id}] Extraction triggered {extract_trigger_num} times, extract_list length: {len(extract_list)}") - extract_list, timestamps_list, weekday_list = assign_sequence_numbers_with_timestamps(extract_list) - self.logger.info(f"[{call_id}] Assigned timestamps to {len(extract_list)} items") - self.logger.debug(f"[{call_id}] Timestamps sample: {timestamps_list}") - self.logger.debug(f"[{call_id}] Weekdays sample: {weekday_list}") + extract_list, timestamps_list, weekday_list, speaker_list, topic_id_map = assign_sequence_numbers_with_timestamps(extract_list, offset_ms=500, topic_id_mapping=topic_id_mapping) self.logger.debug(f"[{call_id}] Extract list sample: {json.dumps(extract_list)}") if self.config.metadata_generate and self.config.text_summary: self.logger.info(f"[{call_id}] Starting metadata generation") - extracted_results = self.manager.meta_text_extract(METADATA_GENERATE_PROMPT, extract_list, self.config.messages_use) - for item in extracted_results: - if item is not None: - result["add_input_prompt"].append(item["input_prompt"]) - result["add_output_prompt"].append(item["output_prompt"]) - result["api_call_nums"] += 1 + extracted_results = self.manager.meta_text_extract(METADATA_GENERATE_PROMPT_locomo, extract_list, self.config.messages_use, topic_id_mapping) + + # =============API Consumption====================== + for idx, item in enumerate(extracted_results): + if item is None: + continue + + if "usage" in item: + usage = item["usage"] + 
self.token_stats["add_memory_calls"] += 1 + self.token_stats["add_memory_prompt_tokens"] += usage.get("prompt_tokens", 0) + self.token_stats["add_memory_completion_tokens"] += usage.get("completion_tokens", 0) + self.token_stats["add_memory_total_tokens"] += usage.get("total_tokens", 0) + + self.logger.info( + f"[{call_id}] API Call {idx} tokens - " + f"Prompt: {usage.get('prompt_tokens', 0)}, " + f"Completion: {usage.get('completion_tokens', 0)}, " + f"Total: {usage.get('total_tokens', 0)}" + ) + + self.logger.debug(f"[{call_id}] API Call {idx} raw output: {item['output_prompt']}") + self.logger.debug(f"[{call_id}] API Call {idx} cleaned result: {item['cleaned_result']}") + result["add_input_prompt"].append(item["input_prompt"]) + result["add_output_prompt"].append(item["output_prompt"]) + result["api_call_nums"] += 1 + + # ======================================= + self.logger.info(f"[{call_id}] Metadata generation completed with {result['api_call_nums']} API calls") - extracted_memory_entry = [item["cleaned_result"] for item in extracted_results if item] - self.logger.info(f"[{call_id}] Extracted {len(extracted_memory_entry)} memory entries") - self.logger.debug(f"[{call_id}] Extracted memory entry sample: {json.dumps(extracted_memory_entry)}") - memory_entries = [] - for topic_memory in extracted_memory_entry: - if not topic_memory: - continue - for entry in topic_memory: - sequence_n = entry.get("source_id") - try: - time_stamp = timestamps_list[sequence_n] - if not isinstance(time_stamp, float): - float_time_stamp = datetime.fromisoformat(time_stamp).timestamp() - weekday = weekday_list[sequence_n] - except (IndexError, TypeError) as e: - self.logger.warning(f"[{call_id}] Error getting timestamp for sequence {sequence_n}: {e}") - time_stamp = None - float_time_stamp = None - weekday = None - mem_obj = MemoryEntry( - time_stamp=time_stamp, - float_time_stamp=float_time_stamp, - weekday=weekday, - memory=entry.get("fact", ""), - # original_memory=entry.get("original_fact", ""), # TODO - # compressed_memory="" # TODO - ) - memory_entries.append(mem_obj) + memory_entries = convert_extraction_results_to_memory_entries( + extracted_results=extracted_results, + timestamps_list=timestamps_list, + weekday_list=weekday_list, + speaker_list=speaker_list, + topic_id_map=topic_id_map, + logger=self.logger + ) self.logger.info(f"[{call_id}] Created {len(memory_entries)} MemoryEntry objects") for i, mem in enumerate(memory_entries): - self.logger.debug(f"[{call_id}] MemoryEntry[{i}]: time={mem.time_stamp}, weekday={mem.weekday}, memory={mem.memory}") + self.logger.debug(f"[{call_id}] MemoryEntry[{i}]: time={mem.time_stamp}, weekday={mem.weekday}, speaker_id={mem.speaker_id}, speaker_name={mem.speaker_name}, topic_id={mem.topic_id}, memory={mem.memory}") if self.config.update == "online": self.online_update(memory_entries) elif self.config.update == "offline": self.offline_update(memory_entries) + self.logger.info( + f"[{call_id}] Cumulative token stats - " + f"Total API calls: {self.token_stats['add_memory_calls']}, " + f"Total tokens: {self.token_stats['add_memory_total_tokens']}" + ) return result def online_update(self, memory_list: List): @@ -365,12 +395,16 @@ def offline_update(self, memory_list: List, construct_update_queue_trigger: bool "time_stamp": mem_obj.time_stamp, "float_time_stamp": mem_obj.float_time_stamp, "weekday": mem_obj.weekday, + "topic_id": mem_obj.topic_id, + "topic_summary": mem_obj.topic_summary, "category": mem_obj.category, "subcategory": mem_obj.subcategory, 
"memory_class": mem_obj.memory_class, "memory": mem_obj.memory, "original_memory": mem_obj.original_memory, "compressed_memory": mem_obj.compressed_memory, + "speaker_id": mem_obj.speaker_id, + "speaker_name": mem_obj.speaker_name, } self.embedding_retriever.insert( vectors = [embedding_vector], @@ -499,7 +533,13 @@ def offline_update_all_entries(self, score_threshold: float = 0.5, max_workers: skipped_count = 0 lock = threading.Lock() write_lock = threading.Lock() - + update_token_stats = { + "calls": 0, + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0 + } + token_lock = threading.Lock() def update_entry(entry): nonlocal processed_count, updated_count, deleted_count, skipped_count @@ -526,7 +566,19 @@ def update_entry(entry): if updated_entry is None: return - + # ====== token consumption ====== + usage = updated_entry["usage"] + with token_lock: + update_token_stats["calls"] += 1 + update_token_stats["prompt_tokens"] += usage.get("prompt_tokens", 0) + update_token_stats["completion_tokens"] += usage.get("completion_tokens", 0) + update_token_stats["total_tokens"] += usage.get("total_tokens", 0) + + self.logger.debug( + f"[{call_id}] Update LLM call for {eid} - " + f"Tokens: {usage.get('total_tokens', 0)}" + ) + # ==================== token consumption ==================== action = updated_entry.get("action") if action == "delete": with write_lock: @@ -546,11 +598,20 @@ def update_entry(entry): self.logger.info(f"[{call_id}] Starting parallel offline update with {max_workers} workers") with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: executor.map(update_entry, all_entries) + with lock: + self.token_stats["update_calls"] += update_token_stats["calls"] + self.token_stats["update_prompt_tokens"] += update_token_stats["prompt_tokens"] + self.token_stats["update_completion_tokens"] += update_token_stats["completion_tokens"] + self.token_stats["update_total_tokens"] += update_token_stats["total_tokens"] self.logger.info(f"[{call_id}] Offline update completed:") self.logger.info(f"[{call_id}] - Processed: {processed_count} entries") self.logger.info(f"[{call_id}] - Updated: {updated_count} entries") self.logger.info(f"[{call_id}] - Deleted: {deleted_count} entries") self.logger.info(f"[{call_id}] - Skipped (no candidates): {skipped_count} entries") + self.logger.info( + f"[{call_id}] - Update API calls: {update_token_stats['calls']}, " + f"Total tokens: {update_token_stats['total_tokens']}" + ) self.logger.info(f"========== END {call_id} ==========") def retrieve(self, query: str, limit: int = 10, filters: dict = None) -> list[str]: @@ -595,3 +656,37 @@ def retrieve(self, query: str, limit: int = 10, filters: dict = None) -> list[st self.logger.info(f"========== END {call_id} ==========") return result_string + def get_token_statistics(self): + embedder_stats = {"total_calls": 0, "total_tokens": None} + if hasattr(self, 'text_embedder') and hasattr(self.text_embedder, 'get_stats'): + embedder_stats = self.text_embedder.get_stats() + + stats = { + "summary": { + "total_llm_calls": self.token_stats["add_memory_calls"] + self.token_stats["update_calls"], + "total_llm_tokens": self.token_stats["add_memory_total_tokens"] + self.token_stats["update_total_tokens"], + "total_embedding_calls": embedder_stats["total_calls"], + "total_embedding_tokens": embedder_stats["total_tokens"], + }, + "llm": { + "add_memory": { + "calls": self.token_stats["add_memory_calls"], + "prompt_tokens": self.token_stats["add_memory_prompt_tokens"], + "completion_tokens": 
self.token_stats["add_memory_completion_tokens"], + "total_tokens": self.token_stats["add_memory_total_tokens"], + }, + "update": { + "calls": self.token_stats["update_calls"], + "prompt_tokens": self.token_stats["update_prompt_tokens"], + "completion_tokens": self.token_stats["update_completion_tokens"], + "total_tokens": self.token_stats["update_total_tokens"], + }, + }, + "embedding": { + "total_calls": embedder_stats["total_calls"], + "total_tokens": embedder_stats["total_tokens"], + "note": "Includes topic segmentation + memory indexing. Local models show None for tokens." + } + } + + return stats \ No newline at end of file diff --git a/src/lightmem/memory/prompts.py b/src/lightmem/memory/prompts.py index 512d388..8852ec2 100644 --- a/src/lightmem/memory/prompts.py +++ b/src/lightmem/memory/prompts.py @@ -53,6 +53,85 @@ Reminder: Be exhaustive. Unless a message is purely meaningless, extract and output it as a fact. """ +# 73.90 +METADATA_GENERATE_PROMPT_locomo = """ +You are a Personal Information Extractor. +Your task is to extract **all possible facts or information** about the speakers from a conversation, +where the dialogue is organized into topic segments separated by markers like: + +--- Topic X --- +[timestamp, weekday] .: +... + +Important Instructions: +0. You MUST process messages **strictly in ascending source_id order** (lowest → highest). + For each message, stop and **carefully** evaluate its content before moving to the next. + Do NOT reorder, batch-skip, or skip ahead — treat messages one-by-one. +1. You MUST process every user message in order, one by one. + For each message, decide whether it contains any factual information. + - If yes → extract it and rephrase into a standalone sentence. + - Do NOT skip just because the information looks minor, trivial, or unimportant. + Extract ALL meaningful information including: + * Past events and current states + * Future plans and intentions + * Thoughts, opinions, and attitudes + * Wants, hopes, desires, and preferences +2. **CRITICAL - Preserve All Specific Details**: + When extracting facts, you MUST include ALL specific entities and details mentioned: + - **Full names with context**: "The Name of the Wind" by Patrick Rothfuss (not just "a book") + - **Complete location names**: Galway, Ireland; The Cliffs of Moher; Barcelona (not just "a city") + - **Specific event names**: benefit basketball game, study abroad program (not just "an event") + - **Product/item details**: vintage camera, brand new fire truck (not just "a camera") + - **Numbers and quantities**: 4 years ago, next month, last week + - **Company/organization names**: beverage company, fire-fighting brigade + Additionally, **infer implied information** when clearly supported: + - If multiple related items mentioned → may infer general pattern + - Keep BOTH specific facts AND inferred insights as separate entries +3. Perform light contextual completion so that each fact is a clear standalone statement. +4. **Time Handling**: + Note: Distinguish mention time (when said) vs event time (when happened). + - For events with relative time (yesterday, last week, X ago, next month): + Preserve the relative time and reference the message timestamp (YYYY-MM-DD). + Format: " ." + - For ongoing/timeless facts: No time annotation needed. +5. Output format: + Always return a JSON object with key `"data"`, which is a list of items: + { + "source_id": "", + "fact": "" + } + +Examples: +--- Topic 1 --- +[2024-01-07T17:24:00.000, Sun] 0.Tim: Hey John! 
Next month I'm off to Ireland for a semester in Galway +[2024-01-07T17:24:01.000, Sun] 1.John: That's awesome! Where will you stay? +[2024-01-07T17:24:02.000, Sun] 2.Tim: In Galway. I also want to visit The Cliffs of Moher +[2024-01-07T17:24:03.000, Sun] 3.John: Nice! By the way, I held a benefit basketball game last week +[2024-01-07T17:24:04.000, Sun] 4.Tim: Cool! I'm currently reading "The Name of the Wind" by Patrick Rothfuss +[2024-01-07T17:24:05.000, Sun] 5.John: That sounds interesting! +--- Topic 2 --- +[2024-01-12T13:41:00.000, Fri] 6.John: Got great news! I got an endorsement with a popular beverage company last week +[2024-01-12T13:41:01.000, Fri] 7.Tim: Congrats! That's amazing +[2024-01-12T13:41:02.000, Fri] 8.John: Thanks! By the way, Barcelona is a must-visit city +[2024-01-12T13:41:03.000, Fri] 9.Tim: I'll add it to my list! + +{"data": [ + {"source_id": 0, "fact": "Tim is going to Ireland for a semester in Galway the month after 2024-01-07."}, + {"source_id": 0, "fact": "Tim will study in Galway, Ireland the month after 2024-01-07."}, + {"source_id": 2, "fact": "Tim will stay in Galway."}, + {"source_id": 2, "fact": "Tim wants to visit The Cliffs of Moher."}, + {"source_id": 3, "fact": "John held a benefit basketball game the week before 2024-01-07."}, + {"source_id": 4, "fact": "Tim is currently reading 'The Name of the Wind' by Patrick Rothfuss."}, + {"source_id": 4, "fact": "Tim is reading a fantasy novel."}, + {"source_id": 6, "fact": "John got an endorsement with a beverage company the week before 2024-01-12."}, + {"source_id": 8, "fact": "John recommends Barcelona as a must-visit city."}, + {"source_id": 9, "fact": "Tim has a travel list and plans to add Barcelona to it."} +]} + +Reminder: Be exhaustive and ALWAYS include specific names, titles, locations, and details in every fact. +""" + + # METADATA_GENERATE_PROMPT = """ # You are a Personal Information Extractor. 
# Your task is to extract meaningful facts about the user from a conversation, diff --git a/src/lightmem/memory/utils.py b/src/lightmem/memory/utils.py index d098afb..0d6253f 100644 --- a/src/lightmem/memory/utils.py +++ b/src/lightmem/memory/utils.py @@ -20,6 +20,10 @@ class MemoryEntry: memory: str = "" original_memory: str = "" compressed_memory: str = "" + topic_id: Optional[int] = None + topic_summary: str = "" + speaker_id: str = "" + speaker_name: str = "" hit_time: int = 0 update_queue: List = field(default_factory=list) @@ -47,10 +51,32 @@ def clean_response(response: str) -> List[Dict[str, Any]]: return [] -def assign_sequence_numbers_with_timestamps(extract_list): +def assign_sequence_numbers_with_timestamps(extract_list, offset_ms: int = 500, topic_id_mapping: List[List[int]] = None): + from datetime import datetime, timedelta + from collections import defaultdict + current_index = 0 timestamps_list = [] weekday_list = [] + speaker_list = [] + message_refs = [] + + for segments in extract_list: + for seg in segments: + for message in seg: + session_time = message.get('session_time', '') + message_refs.append((message, session_time)) + + session_groups = defaultdict(list) + for msg, sess_time in message_refs: + session_groups[sess_time].append(msg) + + for sess_time, messages in session_groups.items(): + base_dt = datetime.strptime(sess_time, "%Y-%m-%d %H:%M:%S") + for i, msg in enumerate(messages): + offset = timedelta(milliseconds=offset_ms * i) + new_dt = base_dt + offset + msg['time_stamp'] = new_dt.isoformat(timespec='milliseconds') for segments in extract_list: for seg in segments: @@ -58,9 +84,23 @@ def assign_sequence_numbers_with_timestamps(extract_list): message["sequence_number"] = current_index timestamps_list.append(message["time_stamp"]) weekday_list.append(message["weekday"]) + speaker_info = { + 'speaker_id': message.get('speaker_id', 'Unknown'), + 'speaker_name': message.get('speaker_name', 'Unknown') + } + speaker_list.append(speaker_info) current_index += 1 - - return extract_list, timestamps_list, weekday_list + + sequence_to_topic = {} + if topic_id_mapping: + for api_idx, api_call_segments in enumerate(extract_list): + for topic_idx, topic_segment in enumerate(api_call_segments): + tid = topic_id_mapping[api_idx][topic_idx] + for msg in topic_segment: + seq = msg.get("sequence_number") + sequence_to_topic[seq] = tid + + return extract_list, timestamps_list, weekday_list, speaker_list, sequence_to_topic # TODO:merge into context retriever def save_memory_entries(memory_entries, file_path="memory_entries.json"): @@ -68,6 +108,8 @@ def entry_to_dict(entry): return { "id": entry.id, "time_stamp": entry.time_stamp, + "topic_id": entry.topic_id, + "topic_summary": entry.topic_summary, "category": entry.category, "subcategory": entry.subcategory, "memory_class": entry.memory_class, @@ -114,3 +156,127 @@ def resolve_tokenizer(tokenizer_or_name: Union[str, Any]): raise ValueError(f"Unknown model_name '{tokenizer_or_name}'") raise TypeError(f"Unsupported tokenizer type: {type(tokenizer_or_name)}") + + +def convert_extraction_results_to_memory_entries( + extracted_results: List[Optional[Dict]], + timestamps_list: List, + weekday_list: List, + speaker_list: List = None, + topic_id_map: Dict[int, int] = None, + logger = None +) -> List[MemoryEntry]: + """ + Convert extraction results to MemoryEntry objects. 
+ + Args: + extracted_results: Results from meta_text_extract, each containing cleaned_result + timestamps_list: List of timestamps indexed by sequence_number + weekday_list: List of weekdays indexed by sequence_number + speaker_list: List of speaker information + topic_id_map: Optional mapping of sequence_number -> topic_id (preferred) + logger: Optional logger for debug info + + Returns: + List of MemoryEntry objects with assigned topic_id and timestamps + """ + memory_entries = [] + + extracted_memory_entry = [ + item["cleaned_result"] + for item in extracted_results + if item and item.get("cleaned_result") + ] + + for topic_memory in extracted_memory_entry: + if not topic_memory: + continue + + for topic_idx, fact_list in enumerate(topic_memory): + if not isinstance(fact_list, list): + fact_list = [fact_list] + + for fact_entry in fact_list: + sid = int(fact_entry.get("source_id")) + seq_candidate = sid * 2 + resolved_topic_id = topic_id_map[seq_candidate] + + mem_obj = _create_memory_entry_from_fact( + fact_entry, + timestamps_list, + weekday_list, + speaker_list, + topic_id=resolved_topic_id, + topic_summary="", + logger=logger, + ) + + if mem_obj: + memory_entries.append(mem_obj) + + return memory_entries + + +def _create_memory_entry_from_fact( + fact_entry: Dict, + timestamps_list: List, + weekday_list: List, + speaker_list: List = None, + topic_id: int = None, + topic_summary: str = "", + logger = None +) -> Optional[MemoryEntry]: + """ + Helper function to create a MemoryEntry from a fact entry. + + Args: + fact_entry: Dict containing source_id and fact + timestamps_list: List of timestamps indexed by sequence_number + weekday_list: List of weekdays indexed by sequence_number + speaker_list: List of speaker information + topic_id: Topic ID for this memory entry + topic_summary: Topic summary for this memory entry (reserved for future use) + logger: Optional logger for warnings + + Returns: + MemoryEntry object or None if creation fails + """ + sequence_n = fact_entry.get("source_id") * 2 + + try: + time_stamp = timestamps_list[sequence_n] + + if not isinstance(time_stamp, float): + from datetime import datetime + float_time_stamp = datetime.fromisoformat(time_stamp).timestamp() + else: + float_time_stamp = time_stamp + + weekday = weekday_list[sequence_n] + speaker_info = speaker_list[sequence_n] + speaker_id = speaker_info.get('speaker_id', 'unknown') + speaker_name = speaker_info.get('speaker_name', 'Unknown') + + except (IndexError, TypeError, ValueError) as e: + if logger: + logger.warning( + f"Error getting timestamp for sequence {sequence_n}: {e}" + ) + time_stamp = None + float_time_stamp = None + weekday = None + speaker_id = 'unknown' + speaker_name = 'Unknown' + + mem_obj = MemoryEntry( + time_stamp=time_stamp, + float_time_stamp=float_time_stamp, + weekday=weekday, + memory=fact_entry.get("fact", ""), + speaker_id=speaker_id, + speaker_name=speaker_name, + topic_id=topic_id, + topic_summary=topic_summary, + ) + + return mem_obj \ No newline at end of file
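To close the review, here is a minimal sketch of how the new accounting surfaces can be read after an ingestion run. The dictionary keys mirror `get_token_statistics()` in `lightmem.py` above; the pre-existing `lightmem` instance and the prior `add_memory(...)` calls are assumed.

```python
# Inspect LLM and embedding usage accumulated by this PR's token_stats counters.
# Assumes `lightmem` was built via LightMemory.from_config(...) and that at least
# one add_memory(...) call has already run.
stats = lightmem.get_token_statistics()

print("LLM calls:       ", stats["summary"]["total_llm_calls"])
print("LLM tokens:      ", stats["summary"]["total_llm_tokens"])
print("Embedding calls: ", stats["summary"]["total_embedding_calls"])
# Local SentenceTransformer embedders report None for token counts.
print("Embedding tokens:", stats["summary"]["total_embedding_tokens"])

# Per-phase breakdown introduced by this change.
add_stats = stats["llm"]["add_memory"]
update_stats = stats["llm"]["update"]
print(f"add_memory: {add_stats['calls']} calls, {add_stats['total_tokens']} tokens")
print(f"update:     {update_stats['calls']} calls, {update_stats['total_tokens']} tokens")
```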