diff --git a/pyproject.toml b/pyproject.toml
index 91956ce2..7d1f73a2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -114,7 +114,7 @@
 Documentation = "https://reme.agentscope.io/"
 Repository = "https://github.com/agentscope-ai/ReMe"
 
 [project.scripts]
-reme = "reme_ai.main:main"
+reme = "remecli.reme:main"
 reme2 = "reme.reme:main"
 remecli = "reme.reme_cli:main"
diff --git a/reme_cli/__init__.py b/reme_cli/__init__.py
new file mode 100644
index 00000000..f721a453
--- /dev/null
+++ b/reme_cli/__init__.py
@@ -0,0 +1,9 @@
+"""ReMe CLI package."""
+
+from reme_cli.application import Application
+from reme_cli.component import BaseComponent
+
+__all__ = [
+    "BaseComponent",
+    "Application",
+]
diff --git a/reme_cli/application.py b/reme_cli/application.py
new file mode 100644
index 00000000..f43f223d
--- /dev/null
+++ b/reme_cli/application.py
@@ -0,0 +1,137 @@
+"""Application module for managing the main application lifecycle."""
+
+import asyncio
+from pathlib import Path
+from typing import AsyncGenerator
+
+from .enumeration import ComponentEnum
+from .component import BaseComponent, ApplicationContext
+from .schema import Response, StreamChunk
+from .utils import execute_stream_task, print_logo, get_logger
+
+
+class Application(BaseComponent):
+    """Application component for managing the main application."""
+
+    def __init__(self, **kwargs) -> None:
+        super().__init__()
+        self.context = ApplicationContext(**kwargs)
+
+        working_path = Path(self.config.working_dir).absolute()
+        working_path.mkdir(parents=True, exist_ok=True)
+        memory_path = working_path / "memory"
+        memory_path.mkdir(parents=True, exist_ok=True)
+
+        if self.config.enable_logo:
+            print_logo(self.config)
+
+        logger = get_logger(
+            log_to_console=self.config.log_to_console,
+            log_to_file=self.config.log_to_file,
+            force_init=True,
+        )
+        logger.info(f"Initializing {self.config.app_name} Application")
+
+        from .component import R
+
+        # Initialize the service
+        service_config = self.config.service
+        if 
not service_config.backend: + raise ValueError("Service configuration is missing the required 'backend' field") + service_cls = R.get(ComponentEnum.SERVICE, service_config.backend) + if not service_cls: + raise ValueError( + f"Service references an unregistered backend '{service_config.backend}' " + f"of type '{ComponentEnum.SERVICE}'", + ) + self.context.service = service_cls(**service_config.model_dump(exclude={"backend"})) + + # Initialize all components grouped by type and name + for component_type, component_configs in self.config.components.items(): + self.context.components[component_type] = {} + for name, config in component_configs.items(): + if not config.backend: + raise ValueError(f"Component '{name}' is missing the required 'backend' field") + backend_cls = R.get(component_type, config.backend) + if not backend_cls: + raise ValueError( + f"Component '{name}' references an unregistered backend '{config.backend}' " + f"of type '{component_type}'", + ) + self.context.components[component_type][name] = backend_cls(**config.model_dump(exclude={"backend"})) + + # Initialize all jobs + for job_config in self.config.jobs: + if not job_config.backend: + raise ValueError(f"Job '{job_config.name}' is missing the required 'backend' field") + + job_cls = R.get(ComponentEnum.JOB, job_config.backend) + if not job_cls: + raise ValueError( + f"Job '{job_config.name}' references an unregistered backend '{job_config.backend}' " + f"of type '{ComponentEnum.JOB}'", + ) + self.context.jobs[job_config.name] = job_cls(**job_config.model_dump(exclude={"backend"})) + + @property + def config(self): + """Get application configuration.""" + return self.context.app_config + + async def _start(self, app_context=None) -> None: + """Start the application.""" + for components in self.context.components.values(): + for component in components.values(): + try: + await component.start(self.context) + except Exception as e: + self.logger.exception(f"Failed to start component 
{component.__class__.__name__}: {e}") + + for name, job in self.context.jobs.items(): + try: + await job.start(self.context) + except Exception as e: + self.logger.exception(f"Failed to start job '{name}': {e}") + + async def _close(self) -> None: + """Close the application.""" + for name, job in self.context.jobs.items(): + try: + await job.close() + except Exception as e: + self.logger.exception(f"Failed to close job '{name}': {e}") + + for components in reversed(list(self.context.components.values())): + for component in reversed(list(components.values())): + try: + await component.close() + except Exception as e: + self.logger.exception(f"Failed to close component {component.__class__.__name__}: {e}") + + async def run_job(self, name: str, **kwargs) -> Response: + """Execute a registered job by name.""" + if name not in self.context.jobs: + raise KeyError(f"Job '{name}' not found") + job = self.context.jobs[name] + return await job(app_context=self.context, **kwargs) + + async def run_stream_job(self, name: str, **kwargs) -> AsyncGenerator[StreamChunk, None]: + """Execute a streaming job and yield chunks.""" + if name not in self.context.jobs: + raise KeyError(f"Job '{name}' not found") + job = self.context.jobs[name] + stream_queue = asyncio.Queue() + task = asyncio.create_task(job(stream_queue=stream_queue, app_context=self.context, **kwargs)) + async for chunk in execute_stream_task( + stream_queue=stream_queue, + task=task, + task_name=name, + output_format="chunk", + ): + assert isinstance(chunk, StreamChunk) + yield chunk + + def run_app(self): + """Run the application as a service.""" + service = self.context.service + service.run_app(app=self) diff --git a/reme_cli/component/__init__.py b/reme_cli/component/__init__.py new file mode 100644 index 00000000..dcd07fb6 --- /dev/null +++ b/reme_cli/component/__init__.py @@ -0,0 +1,36 @@ +"""Components""" + +from .application_context import ApplicationContext +from .base_component import BaseComponent +from 
.base_step import BaseStep +from .component_registry import ComponentRegistry, R +from .prompt_handler import PromptHandler +from .runtime_context import RuntimeContext + +from . import as_llm +from . import as_llm_formatter +from . import client +from . import embedding +from . import file_store +from . import file_watcher +from . import job +from . import service + +__all__ = [ + "ApplicationContext", + "BaseComponent", + "BaseStep", + "ComponentRegistry", + "R", + "PromptHandler", + "RuntimeContext", + # base components + "as_llm", + "as_llm_formatter", + "client", + "embedding", + "file_store", + "file_watcher", + "job", + "service", +] diff --git a/reme_cli/component/application_context.py b/reme_cli/component/application_context.py new file mode 100644 index 00000000..4b372353 --- /dev/null +++ b/reme_cli/component/application_context.py @@ -0,0 +1,33 @@ +"""Application context for initializing and managing all configured components.""" + +from ..enumeration import ComponentEnum +from ..schema import ApplicationConfig + + +class ApplicationContext: + """Application context that initializes and manages all configured components. + + This class is responsible for parsing the application configuration, + resolving backend implementations for each component via the registry, + and instantiating services, components, and jobs. + """ + + def __init__(self, **kwargs): + """Initialize the application context from configuration kwargs. + + Args: + **kwargs: Keyword arguments that form the ApplicationConfig, + including app_name, service, components, and jobs. + + Raises: + ValueError: If a required backend is missing for the service, any component, or any job, + or if a service, component, or job references an unregistered backend type. 
+ """ + self.app_config: ApplicationConfig = ApplicationConfig(**kwargs) + + from .base_component import BaseComponent + from .job.base_job import BaseJob + + self.service = None + self.components: dict[ComponentEnum, dict[str, BaseComponent]] = {} + self.jobs: dict[str, BaseJob] = {} diff --git a/reme_cli/component/as_llm/__init__.py b/reme_cli/component/as_llm/__init__.py new file mode 100644 index 00000000..d3f8625c --- /dev/null +++ b/reme_cli/component/as_llm/__init__.py @@ -0,0 +1,57 @@ +"""AgentScope LLM model wrappers.""" + +import asyncio + +from agentscope.model import OpenAIChatModel, ChatModelBase + +from ..base_component import BaseComponent +from ..component_registry import R +from ...enumeration import ComponentEnum + + +class BaseAsLLM(BaseComponent): + """Base wrapper for AgentScope LLM models. + + Subclasses should implement _start() to initialize self.model. + """ + + component_type = ComponentEnum.AS_LLM + + def __init__(self, **kwargs) -> None: + """Initialize with model configuration kwargs.""" + super().__init__(**kwargs) + self.model: ChatModelBase | None = None + + async def _start(self, app_context=None) -> None: + """Initialize the AgentScope model. 
Override in subclasses.""" + + async def _close(self) -> None: + """Release model resources.""" + self.model = None + + +@R.register("openai") +class OpenAIAsLLM(BaseAsLLM): + """OpenAI chat model wrapper.""" + + async def _start(self, app_context=None) -> None: + """Initialize the OpenAI chat model.""" + self.model = OpenAIChatModel(**self.kwargs) + + async def _close(self) -> None: + """Close the HTTP client and release resources.""" + if self.model is not None: + client = getattr(self.model, "client", None) + if client is not None and hasattr(client, "close"): + close_method = client.close + if asyncio.iscoroutinefunction(close_method): + await close_method() + else: + close_method() + self.model = None + + +__all__ = [ + "BaseAsLLM", + "OpenAIAsLLM", +] diff --git a/reme_cli/component/as_llm_formatter/__init__.py b/reme_cli/component/as_llm_formatter/__init__.py new file mode 100644 index 00000000..0dc3894d --- /dev/null +++ b/reme_cli/component/as_llm_formatter/__init__.py @@ -0,0 +1,38 @@ +"""Module for AgentScope LLM formatter components.""" + +from agentscope.formatter import FormatterBase + +from .reme_openai_chat_formatter import ReMeOpenAIChatFormatter +from ..base_component import BaseComponent +from ..component_registry import R +from ...enumeration import ComponentEnum + + +class BaseAsLLMFormatter(BaseComponent): + """Base wrapper for AgentScope LLM formatters.""" + + component_type = ComponentEnum.AS_LLM_FORMATTER + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.formatter: FormatterBase | None = None + + async def _start(self, app_context=None) -> None: + """Initialize the formatter instance.""" + + async def _close(self) -> None: + self.formatter = None + + +@R.register("openai") +class AsOpenAIChatFormatter(BaseAsLLMFormatter): + """Wrapper for OpenAI chat completion formatter.""" + + async def _start(self, app_context=None) -> None: + self.formatter = ReMeOpenAIChatFormatter(**self.kwargs) + + +__all__ = [ + 
"BaseAsLLMFormatter", + "AsOpenAIChatFormatter", +] diff --git a/reme_cli/component/as_llm_formatter/reme_openai_chat_formatter.py b/reme_cli/component/as_llm_formatter/reme_openai_chat_formatter.py new file mode 100644 index 00000000..a501da26 --- /dev/null +++ b/reme_cli/component/as_llm_formatter/reme_openai_chat_formatter.py @@ -0,0 +1,163 @@ +"""ReMe OpenAI chat message formatter.""" + +import json +from typing import Any + +from agentscope.formatter import OpenAIChatFormatter +from agentscope.formatter._openai_formatter import ( + _format_openai_image_block, + _to_openai_audio_data, +) +from agentscope.message import Msg, TextBlock, ImageBlock, URLSource + + +def _format_openai_video_block(video_block: dict) -> dict[str, Any]: + """Format a video block for OpenAI API. + + Args: + video_block: The video block containing a URL or base64 source. + + Returns: + A dict with OpenAI-compatible video_url content. + """ + source = video_block["source"] + if source["type"] == "url": + url = source["url"] + elif source["type"] == "base64": + url = f"data:{source['media_type']};base64,{source['data']}" + else: + raise ValueError(f"Unsupported video source type: {source['type']}") + + return {"type": "video_url", "video_url": {"url": url}} + + +class ReMeOpenAIChatFormatter(OpenAIChatFormatter): + """Extends OpenAIChatFormatter with tool result image promotion and reasoning content support.""" + + async def _format( + self, + msgs: list[Msg], + ) -> list[dict[str, Any]]: + """Format messages into OpenAI API format. + + Handles text, thinking (reasoning_content), tool_use, tool_result, + image, audio, and video content blocks. + + Args: + msgs: List of Msg objects. + + Returns: + List of dicts with "role", "name", "content" and optional "tool_calls" or "reasoning_content". 
+ """ + self.assert_list_of_msgs(msgs) + + messages: list[dict] = [] + i = 0 + while i < len(msgs): + msg = msgs[i] + content_blocks = [] + tool_calls = [] + reasoning_content_blocks = [] + + for block in msg.get_content_blocks(): + typ = block.get("type") + + if typ == "text": + content_blocks.append({**block}) + + elif typ == "thinking": + reasoning_content_blocks.append({**block}) + + elif typ == "tool_use": + tool_calls.append( + { + "id": block.get("id"), + "type": "function", + "function": { + "name": block.get("name"), + "arguments": json.dumps(block.get("input", {}), ensure_ascii=False), + }, + }, + ) + + elif typ == "tool_result": + textual_output, multimodal_data = self.convert_tool_result_to_string(block["output"]) + + messages.append( + { + "role": "tool", + "tool_call_id": block.get("id"), + "content": textual_output, + "name": block.get("name"), + }, + ) + + # Promote tool result images into a follow-up user message + promoted_blocks = [] + for url, multimodal_block in multimodal_data: + if multimodal_block["type"] == "image" and self.promote_tool_result_images: + promoted_blocks.extend( + [ + TextBlock(type="text", text=f"\n- The image from '{url}': "), + ImageBlock(type="image", source=URLSource(type="url", url=url)), + ], + ) + + if promoted_blocks: + promoted_blocks = [ + TextBlock( + type="text", + text="The following are the image contents from the tool " + f"result of '{block['name']}':", + ), + *promoted_blocks, + TextBlock(type="text", text=""), + ] + msgs.insert( + i + 1, + Msg(name="user", content=promoted_blocks, role="user"), + ) + + elif typ == "image": + content_blocks.append(_format_openai_image_block(block)) + + elif typ == "audio": + # Skip assistant audio output + if msg.role == "assistant": + continue + content_blocks.append( + { + "type": "input_audio", + "input_audio": _to_openai_audio_data(block["source"]), + }, + ) + + elif typ == "video": + # Skip assistant video output + if msg.role == "assistant": + continue + 
content_blocks.append(_format_openai_video_block(block)) + + else: + pass # Unsupported block type, skip + + msg_openai = { + "role": msg.role, + "name": msg.name, + "content": content_blocks or None, + } + + if tool_calls: + msg_openai["tool_calls"] = tool_calls + + if reasoning_content_blocks: + reasoning_msg = "\n".join(r.get("thinking", "") for r in reasoning_content_blocks) + if reasoning_msg: + msg_openai["reasoning_content"] = reasoning_msg + + if msg_openai["content"] or msg_openai.get("tool_calls"): + messages.append(msg_openai) + + i += 1 + + return messages diff --git a/reme_cli/component/base_component.py b/reme_cli/component/base_component.py new file mode 100644 index 00000000..61518fc2 --- /dev/null +++ b/reme_cli/component/base_component.py @@ -0,0 +1,217 @@ +"""Base class for components.""" + +from abc import ABC, abstractmethod + +from ..enumeration import ComponentEnum +from ..utils.logger_utils import get_logger + + +class BaseComponent(ABC): + """Base class for all application components. + + Provides an asynchronous lifecycle with start/close operations and + async context manager support. State tracking prevents duplicate + start or close calls. + + Subclasses must implement ``_start`` and ``_close`` to define their + specific initialization and teardown logic. + + Examples: + Direct usage:: + + comp = MyComponent() + await comp.start() + # ... use component ... + await comp.close() + + Context manager usage:: + + async with MyComponent() as comp: + # ... use component ... + + Attributes: + component_type: The type identifier for this component, used during + registry lookup. Defaults to ``ComponentEnum.BASE``. + _is_started: Internal flag indicating whether the component has been + started and not yet closed. + """ + + from .application_context import ApplicationContext + + component_type = ComponentEnum.BASE + + def __init__(self, **kwargs) -> None: + """Initialize a component instance. 
+ + Sets up the component's internal state, binds a structured logger + with the component's class name, and stores any additional keyword + arguments for downstream use by subclasses. + + Args: + **kwargs: Arbitrary keyword arguments forwarded to the component + subclass. Typically provided by the registry when the + component is instantiated from configuration. + """ + self.kwargs: dict = dict(kwargs) + self.logger = get_logger() + if hasattr(self.logger, "bind"): + self.logger = self.logger.bind(component=self.__class__.__name__) + self._is_started: bool = False + + @abstractmethod + async def _start(self, app_context: ApplicationContext | None = None) -> None: + """Perform the actual initialization logic for this component. + + Subclasses must implement this method to set up resources such as + connections, caches, or background tasks. This method is called + internally by ``start()`` after verifying the component is not + already started. + + Args: + app_context: The shared application context that provides access + to other initialized components and the application service. + May be ``None`` if the component does not require cross-component + references. + + Raises: + ValueError: If required configuration or dependencies are missing + or invalid. + Exception: Any exception raised during resource acquisition will + propagate to the caller of ``start()``. + """ + + @abstractmethod + async def _close(self) -> None: + """Perform the actual teardown logic for this component. + + Subclasses must implement this method to release resources such as + closing connections, flushing buffers, or cancelling background tasks. + This method is called internally by ``close()`` after verifying the + component is in a started state. + + Raises: + ValueError: If the component is in an unexpected state during shutdown. + Exception: Any exception raised during resource cleanup will + propagate to the caller of ``close()``. 
+ """ + + async def start(self, app_context: ApplicationContext | None = None) -> None: + """Start the component and transition it to an active state. + + This is the public entry point for component initialization. It guards + against duplicate starts by returning immediately if the component is + already running, then delegates to ``_start`` for the subclass-specific + setup. + + Args: + app_context: The shared application context to pass to ``_start``. + + Raises: + ValueError: If the component configuration is invalid or required + dependencies are unavailable (raised by the subclass ``_start``). + """ + if self._is_started: + return + await self._start(app_context) + self._is_started = True + + async def close(self) -> None: + """Close the component and release its resources. + + This is the public entry point for component teardown. It guards + against redundant closes by returning immediately if the component + has not been started or is already closed, then delegates to ``_close`` + for the subclass-specific cleanup. + + Raises: + ValueError: If the component is in an inconsistent state that + prevents safe shutdown (raised by the subclass ``_close``). + """ + if not self._is_started: + return + try: + await self._close() + finally: + self._is_started = False + + async def restart(self, app_context: ApplicationContext | None = None) -> None: + """Restart the component by closing and then starting it again. + + This method safely tears down the component if it is currently running + and reinitialized it. If the component is not started, it will simply + be started. + + If either the close or start operation fails, the exception propagates + immediately and the component will be left in a non-started state. + + Args: + app_context: The shared application context to pass during startup. + + Raises: + ValueError: If the component cannot be cleanly shut down or + reinitialized due to invalid state or configuration. 
+ """ + await self.close() + await self.start(app_context) + + @property + def is_started(self) -> bool: + """Return whether the component is currently in a started state. + + Returns: + ``True`` if ``start()`` has been called and ``close()`` has not + been called since; ``False`` otherwise. + """ + return self._is_started + + async def __call__(self, **kwargs): + """Call the component instance as a function.""" + + async def __aenter__(self) -> "BaseComponent": + """Enter the async context manager by starting the component. + + Returns: + The component instance, allowing it to be bound in an ``async with`` + statement. + + Raises: + ValueError: If the component fails to start due to invalid + configuration or missing dependencies. + """ + await self.start() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb, + ) -> bool: + """Exit the async context manager by closing the component. + + Any exception raised within the context block is not suppressed. + If both the context block and ``close()`` raise exceptions, the + original exception from the context block is preserved and the + close exception is attached as its ``__cause__`` to maintain + the full error chain. + + Args: + exc_type: The exception type if an exception was raised in the + context block, otherwise ``None``. + exc_val: The exception value if an exception was raised. + exc_tb: The traceback if an exception was raised. + + Returns: + ``False`` to indicate that exceptions should not be suppressed. 
+ """ + if self._is_started: + if exc_val is not None: + try: + await self._close() + except BaseException as close_exc: + raise close_exc from exc_val + finally: + self._is_started = False + else: + await self.close() + return False diff --git a/reme_cli/component/base_step.py b/reme_cli/component/base_step.py new file mode 100644 index 00000000..2c6a08f8 --- /dev/null +++ b/reme_cli/component/base_step.py @@ -0,0 +1,143 @@ +"""Base step class for LLM workflow execution.""" + +import copy +from abc import abstractmethod + +from .application_context import ApplicationContext +from .as_llm import BaseAsLLM +from .as_llm_formatter import BaseAsLLMFormatter +from .base_component import BaseComponent +from .embedding import BaseEmbeddingModel +from .file_store import BaseFileStore +from .prompt_handler import PromptHandler +from .runtime_context import RuntimeContext +from ..enumeration import ComponentEnum +from ..schema import ApplicationConfig +from ..utils import camel_to_snake + + +class BaseStep(BaseComponent): + """Base step for LLM workflow execution and composition.""" + + component_type = ComponentEnum.STEP + + def __new__(cls, *args, **kwargs): + """Capture init args for object cloning.""" + instance = super().__new__(cls) + instance._init_args = copy.copy(args) + instance._init_kwargs = copy.copy(kwargs) + return instance + + def __init__( + self, + name: str = "", + language: str = "", + prompt_dict: dict[str, str] | None = None, + input_mapping: dict[str, str] | None = None, + output_mapping: dict[str, str] | None = None, + **kwargs, + ): + """Initialize step configurations.""" + super().__init__(**kwargs) + self.name = name or camel_to_snake(self.__class__.__name__) + self.language = language + self.prompt = PromptHandler(language=self.language) + self.prompt.load_prompt_by_class(self.__class__).load_prompt_dict(prompt_dict) + self.input_mapping = input_mapping + self.output_mapping = output_mapping + self.context: RuntimeContext | None = None + + async 
def _start(self, app_context=None) -> None: + """Apply input mapping before execution.""" + if self.input_mapping and self.context: + self.context.apply_mapping(self.input_mapping) + + async def _close(self) -> None: + """Apply output mapping after execution.""" + if self.output_mapping and self.context: + self.context.apply_mapping(self.output_mapping) + + @abstractmethod + async def execute(self): + """Execute the step logic.""" + + async def __call__(self, context: RuntimeContext | None = None, **kwargs): + """Execute the step with lifecycle management.""" + self.context = RuntimeContext.from_context(context, **kwargs) + await self.start() + try: + response = await self.execute() + return response + finally: + await self.close() + + @property + def application_context(self) -> ApplicationContext: + """Get the application context from runtime context.""" + assert self.context is not None, "Runtime context not set." + return self.context.application_context + + @property + def app_config(self) -> ApplicationConfig: + """Get the application configuration.""" + return self.application_context.app_config + + @property + def as_llm(self) -> BaseAsLLM: + """Get the AsLLM instance by name.""" + name: str = self.kwargs.get("as_llm", "default") + llms = self.application_context.components[ComponentEnum.AS_LLM] + if name not in llms: + raise ValueError(f"AsLLM {name} not found.") + llm = llms[name] + if not isinstance(llm, BaseAsLLM): + raise TypeError(f"{name} is not a BaseAsLLM instance.") + return llm + + @property + def as_llm_formatter(self) -> BaseAsLLMFormatter: + """Get the AsLLMFormatter instance by name.""" + name: str = self.kwargs.get("as_llm_formatter", "default") + formatters = self.application_context.components[ComponentEnum.AS_LLM_FORMATTER] + if name not in formatters: + raise ValueError(f"AsLLMFormatter {name} not found.") + formatter = formatters[name] + if not isinstance(formatter, BaseAsLLMFormatter): + raise TypeError(f"{name} is not a 
BaseAsLLMFormatter instance.") + return formatter + + @property + def file_store(self) -> BaseFileStore: + """Get the FileStore instance by name.""" + name: str = self.kwargs.get("file_store", "default") + stores = self.application_context.components[ComponentEnum.FILE_STORE] + if name not in stores: + raise ValueError(f"FileStore {name} not found.") + store = stores[name] + if not isinstance(store, BaseFileStore): + raise TypeError(f"{name} is not a BaseFileStore instance.") + return store + + @property + def embedding(self) -> BaseEmbeddingModel: + """Get the EmbeddingModel instance by name.""" + name: str = self.kwargs.get("embedding", "default") + models = self.application_context.components[ComponentEnum.EMBEDDING_MODEL] + if name not in models: + raise ValueError(f"EmbeddingModel {name} not found.") + model = models[name] + if not isinstance(model, BaseEmbeddingModel): + raise TypeError(f"{name} is not a BaseEmbeddingModel instance.") + return model + + def prompt_format(self, prompt_name: str, **kwargs) -> str: + """Format a prompt template.""" + return self.prompt.prompt_format(prompt_name=prompt_name, **kwargs) + + def get_prompt(self, prompt_name: str) -> str: + """Get a prompt template by name.""" + return self.prompt.get_prompt(prompt_name=prompt_name) + + def copy(self, **kwargs) -> "BaseStep": + """Create a copy with optional parameter overrides.""" + return self.__class__(*self._init_args, **{**self._init_kwargs, **kwargs}) diff --git a/reme_cli/component/client/__init__.py b/reme_cli/component/client/__init__.py new file mode 100644 index 00000000..f549b4f1 --- /dev/null +++ b/reme_cli/component/client/__init__.py @@ -0,0 +1,9 @@ +"""Client""" + +from .base_client import BaseClient +from .http_client import HttpClient + +__all__ = [ + "BaseClient", + "HttpClient", +] diff --git a/reme_cli/component/client/base_client.py b/reme_cli/component/client/base_client.py new file mode 100644 index 00000000..d5da961b --- /dev/null +++ 
b/reme_cli/component/client/base_client.py @@ -0,0 +1,26 @@ +"""Abstract base class for client implementations.""" + +from abc import abstractmethod + +from ..base_component import BaseComponent +from ...enumeration import ComponentEnum + + +class BaseClient(BaseComponent): + """Abstract base class for clients that communicate with ReMe services.""" + + component_type = ComponentEnum.CLIENT + + def __init__(self, **kwargs) -> None: + super().__init__(**kwargs) + self.client = None + + async def _start(self, app_context=None) -> None: + """Initialize the client.""" + + async def _close(self) -> None: + """Close the client.""" + + @abstractmethod + async def __call__(self, action: str, **kwargs) -> dict: + """Invoke an action with the given configuration.""" diff --git a/reme_cli/component/client/http_client.py b/reme_cli/component/client/http_client.py new file mode 100644 index 00000000..c93870eb --- /dev/null +++ b/reme_cli/component/client/http_client.py @@ -0,0 +1,60 @@ +"""HTTP client for ReMe services.""" + +import json +import os + +import httpx + +from .base_client import BaseClient +from ..component_registry import R +from ...constants import REME_SERVICE_INFO, REME_DEFAULT_HOST, REME_DEFAULT_PORT + + +@R.register("http") +class HttpClient(BaseClient): + """HTTP client for ReMe service.""" + + def __init__( + self, + action: str, + host: str | None = None, + port: int | None = None, + timeout: float = 30.0, + **kwargs, + ): + super().__init__(**kwargs) + + if host and port: + pass + elif service_info := os.environ.get(REME_SERVICE_INFO): + try: + data = json.loads(service_info) + host = data.get("host", host) + port = data.get("port", port) + except Exception: + pass + else: + host = REME_DEFAULT_HOST + port = REME_DEFAULT_PORT + + self.action = action + self.base_url = f"http://{host}:{port}" + self.timeout = timeout + + async def _start(self, app_context=None) -> None: + """Initialize the HTTP client.""" + if self.client is None: + self.client = 
httpx.AsyncClient(
+                base_url=self.base_url,
+                timeout=self.timeout,
+            )
+
+    async def __call__(self, **_kwargs) -> dict:
+        response = await self.client.post(f"/{self.action}", json=self.kwargs)
+        response.raise_for_status()
+        return response.json()
+
+    async def _close(self) -> None:
+        if self.client is not None:
+            await self.client.aclose()
+            self.client = None
diff --git a/reme_cli/component/component_registry.py b/reme_cli/component/component_registry.py
new file mode 100644
index 00000000..007a5ebe
--- /dev/null
+++ b/reme_cli/component/component_registry.py
@@ -0,0 +1,78 @@
+"""
+Component registry module.
+
+Provides a global registry for managing component class registration and lookup.
+Supports two registration methods:
+1. Direct registration: R.register(MyClass, "name")
+2. Decorator registration: @R.register("name")
+"""
+
+from typing import Callable, TypeVar, cast
+
+from .base_component import BaseComponent
+from ..enumeration import ComponentEnum
+from ..utils import get_logger
+
+T = TypeVar("T", bound=BaseComponent)
+
+
+class ComponentRegistry:
+    """Registry for managing component class registration and lookup."""
+
+    def __init__(self) -> None:
+        self._registry: dict[ComponentEnum, dict[str, type[BaseComponent]]] = {}
+        self.logger = get_logger()
+
+    def _do_register(self, cls: type[T], name: str) -> type[T]:
+        """Register a component class with the given name."""
+        if not hasattr(cls, "component_type"):
+            raise TypeError(f"{cls.__name__} must have 'component_type' attribute")
+        if not name:
+            raise ValueError("Component name cannot be empty")
+
+        component_type = cls.component_type
+        if name in self._registry.setdefault(component_type, {}):
+            self.logger.warning(f"Component '{name}' already registered for {component_type}, overwriting")
+
+        self._registry[component_type][name] = cls
+        return cls
+
+    def register(
+        self,
+        cls_or_name: type[T] | str,
+        name: str | None = None,
+    ) -> Callable[[type[T]], type[T]] | type[T]:
+        """Register a component class. 
Supports direct and decorator modes.""" + # Direct registration: R.register(MyClass, "name") + if isinstance(cls_or_name, type): + return self._do_register(cast(type[T], cls_or_name), name or cls_or_name.__name__) + + # Decorator mode: @R.register("name") + decorator_name = cls_or_name + + def decorator(decorated_cls: type[T]) -> type[T]: + return self._do_register(decorated_cls, decorator_name or decorated_cls.__name__) + + return decorator + + def get(self, component_type: ComponentEnum, name: str) -> type[BaseComponent] | None: + """Get a registered component class by type and name.""" + return self._registry.get(component_type, {}).get(name) + + def get_all(self, component_type: ComponentEnum) -> dict[str, type[BaseComponent]]: + """Get all registered components of a given type.""" + return dict(self._registry.get(component_type, {})) + + def unregister(self, component_type: ComponentEnum, name: str) -> bool: + """Remove a component from the registry. Returns True if found.""" + if name in self._registry.get(component_type, {}): + del self._registry[component_type][name] + return True + return False + + def clear(self) -> None: + """Clear all registered components.""" + self._registry.clear() + + +R = ComponentRegistry() diff --git a/reme_cli/component/embedding/__init__.py b/reme_cli/component/embedding/__init__.py new file mode 100644 index 00000000..58a00a52 --- /dev/null +++ b/reme_cli/component/embedding/__init__.py @@ -0,0 +1,9 @@ +"""Embedding model implementations.""" + +from .base_embedding_model import BaseEmbeddingModel +from .openai_embedding_model import OpenAIEmbeddingModel + +__all__ = [ + "BaseEmbeddingModel", + "OpenAIEmbeddingModel", +] diff --git a/reme_cli/component/embedding/base_embedding_model.py b/reme_cli/component/embedding/base_embedding_model.py new file mode 100644 index 00000000..69526b40 --- /dev/null +++ b/reme_cli/component/embedding/base_embedding_model.py @@ -0,0 +1,346 @@ +"""Base embedding model with caching, batching, and 
class BaseEmbeddingModel(BaseComponent):
    """Abstract base class for embedding models with LRU cache and retry logic.

    Provides:
    - LRU in-memory cache with disk persistence (JSONL)
    - Automatic text truncation to max_input_length
    - Retry logic with linear backoff (sleeps retry+1 seconds per attempt)
    - Batch embedding support
    """

    component_type = ComponentEnum.EMBEDDING_MODEL

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        model_name: str = "",
        dimensions: int = 1024,
        use_dimensions: bool = False,
        max_batch_size: int = 10,
        max_retries: int = 3,
        raise_exception: bool = True,
        max_input_length: int = 8192,
        cache_dir: str | Path = ".reme",
        max_cache_size: int = 2000,
        enable_cache: bool = True,
        encoding: str = "utf-8",
        **kwargs,
    ):
        """Initialize embedding model configuration.

        Args:
            api_key: API key for the embedding service.
            base_url: Base URL for the embedding service.
            model_name: Name of the embedding model.
            dimensions: Vector dimensions.
            use_dimensions: Whether to pass dimensions parameter to API.
            max_batch_size: Maximum batch size for embedding requests.
            max_retries: Maximum retry attempts on failure.
            raise_exception: Whether to raise exceptions on failure; when
                False, failed embeddings are returned as empty lists.
            max_input_length: Maximum input text length (characters).
            cache_dir: Directory for cache storage.
            max_cache_size: Maximum LRU cache size.
            enable_cache: Whether to enable caching.
            encoding: Text encoding for cache file operations.
        """
        super().__init__(**kwargs)
        self.api_key: str | None = api_key
        self.base_url: str | None = base_url
        self.model_name = model_name
        self.dimensions = dimensions
        self.use_dimensions = use_dimensions
        self.max_batch_size = max_batch_size
        self.max_retries = max_retries
        self.raise_exception = raise_exception
        self.max_input_length = max_input_length
        self.cache_dir = cache_dir
        self.max_cache_size = max_cache_size
        self.enable_cache = enable_cache
        self.encoding = encoding

        # OrderedDict gives LRU semantics: move_to_end on hit, popitem(last=False)
        # evicts the least recently used entry.
        self._embedding_cache: OrderedDict[str, list[float]] = OrderedDict()
        self._cache_hits = 0
        self._cache_misses = 0
        self.cache_path: Path = Path(self.cache_dir)

    def _truncate_text(self, text: str) -> str:
        """Truncate text to max_input_length characters (not tokens)."""
        return text[: self.max_input_length] if len(text) > self.max_input_length else text

    def _validate_and_adjust_embedding(self, embedding: list[float]) -> list[float]:
        """Adjust embedding dimensions to match expected dimensions.

        Shorter vectors are zero-padded, longer ones truncated; either case
        logs a warning.
        """
        actual_len = len(embedding)
        if actual_len == self.dimensions:
            return embedding

        if actual_len < self.dimensions:
            self.logger.warning(
                f"[ACTUAL_EMB_LENGTH] Embedding {actual_len} < expected {self.dimensions}, padding with zeros",
            )
            return embedding + [0.0] * (self.dimensions - actual_len)

        self.logger.warning(
            f"[ACTUAL_EMB_LENGTH] Embedding {actual_len} > expected {self.dimensions}, truncating",
        )
        return embedding[: self.dimensions]

    def _get_cache_key(self, text: str) -> str:
        """Generate cache key from text + model_name + dimensions.

        Including model and dimensions keeps entries from different
        model configurations from colliding.
        """
        cache_string = f"{text}|{self.model_name}|{self.dimensions}"
        return hashlib.sha256(cache_string.encode(self.encoding)).hexdigest()

    def _get_cache_file_path(self) -> Path:
        """Return path to the cache JSONL file."""
        return self.cache_path / "embedding_cache.jsonl"

    def _load_cache(self) -> None:
        """Load embedding cache from disk (JSONL format, one entry per line).

        On any unrecoverable read error the cache file is deleted so the next
        run starts from a clean state.
        """
        if not self.enable_cache:
            return

        self.cache_path.mkdir(parents=True, exist_ok=True)
        cache_file = self._get_cache_file_path()
        if not cache_file.exists():
            self.logger.info(f"No cache file at {cache_file}, starting empty")
            return

        try:
            load_start = time.time()
            with open(cache_file, "r", encoding=self.encoding) as f:
                lines = f.readlines()

            loaded_count = 0
            # Iterate newest-first so the most recently appended entry for a
            # key wins and the freshest entries fill the bounded cache first.
            for line in reversed(lines):
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                except json.JSONDecodeError as e:
                    self.logger.warning(f"Failed to parse cache line: {e}")
                    continue

                if not data:
                    continue

                # Each JSONL line is a single-entry {cache_key: embedding} object.
                cache_key, embedding = next(iter(data.items()))
                if cache_key and embedding and isinstance(embedding, list):
                    if cache_key in self._embedding_cache:
                        continue
                    if len(embedding) != self.dimensions:
                        self.logger.warning(
                            f"Cache dimension mismatch for {cache_key}: "
                            f"expected {self.dimensions}, got {len(embedding)}",
                        )
                        continue
                    if len(self._embedding_cache) >= self.max_cache_size:
                        self.logger.info(f"Cache limit reached ({self.max_cache_size}), loaded {loaded_count}")
                        break
                    self._embedding_cache[cache_key] = embedding
                    loaded_count += 1

            self.logger.info(f"Loaded {loaded_count} embeddings from {cache_file} in {time.time() - load_start:.2f}s")
        except Exception as e:
            self.logger.error(f"Failed to load cache from {cache_file}: {e}, deleting file")
            try:
                cache_file.unlink()
            except Exception as del_e:
                self.logger.error(f"Failed to delete cache file: {del_e}")

    def _save_cache(self) -> None:
        """Save embedding cache to disk (JSONL format), rewriting the file."""
        if not self.enable_cache or not self._embedding_cache:
            return

        cache_file = self._get_cache_file_path()
        try:
            with open(cache_file, "w", encoding=self.encoding) as f:
                for cache_key, embedding in self._embedding_cache.items():
                    if len(embedding) != self.dimensions:
                        self.logger.warning(f"Cache dimension mismatch for {cache_key}")
                        continue
                    f.write(json.dumps({cache_key: embedding}, ensure_ascii=False) + "\n")
            self.logger.info(f"Saved {len(self._embedding_cache)} embeddings to {cache_file}")
        except Exception as e:
            self.logger.error(f"Failed to save cache to {cache_file}: {e}")

    def _get_from_cache(self, text: str) -> list[float] | None:
        """Retrieve embedding from cache if available; updates hit/miss stats."""
        if not self.enable_cache:
            return None

        cache_key = self._get_cache_key(text)
        if cache_key not in self._embedding_cache:
            self._cache_misses += 1
            return None

        embeddings = self._embedding_cache[cache_key]
        # Defensive: evict stale entries whose dimensions no longer match.
        if len(embeddings) != self.dimensions:
            self.logger.warning("Cached embedding dimension mismatch, removing entry")
            del self._embedding_cache[cache_key]
            self._cache_misses += 1
            return None

        # Mark as most recently used for LRU ordering.
        self._embedding_cache.move_to_end(cache_key)
        self._cache_hits += 1
        preview = text[:50] + "..." if len(text) > 50 else text
        self.logger.info(f"Cache hit: {preview} (hits: {self._cache_hits}, misses: {self._cache_misses})")
        return embeddings

    def _put_to_cache(self, text: str, embedding: list[float]) -> None:
        """Store embedding in cache with LRU eviction."""
        if not self.enable_cache or self.max_cache_size <= 0:
            return

        cache_key = self._get_cache_key(text)
        if len(embedding) != self.dimensions:
            self.logger.warning(f"[PUT_TO_CACHE] Dimension mismatch for {cache_key}")
            return

        # Evict the LRU entry only when inserting a genuinely new key.
        if len(self._embedding_cache) >= self.max_cache_size and cache_key not in self._embedding_cache:
            self._embedding_cache.popitem(last=False)

        self._embedding_cache[cache_key] = embedding
        self._embedding_cache.move_to_end(cache_key)

    def get_cache_stats(self) -> dict[str, int | float]:
        """Return cache statistics: size, hits, misses, hit_rate."""
        total = self._cache_hits + self._cache_misses
        hit_rate = self._cache_hits / total if total > 0 else 0.0
        return {
            "cache_size": len(self._embedding_cache),
            "max_cache_size": self.max_cache_size,
            "cache_hits": self._cache_hits,
            "cache_misses": self._cache_misses,
            "hit_rate": hit_rate,
        }

    def clear_cache(self) -> None:
        """Clear in-memory cache and reset statistics."""
        self._embedding_cache.clear()
        self._cache_hits = 0
        self._cache_misses = 0

    @abstractmethod
    async def _get_embeddings(self, input_text: list[str], **kwargs) -> list[list[float]]:
        """Fetch embeddings for a batch of texts. Override in subclasses."""

    async def get_embedding(self, input_text: str, **kwargs) -> list[float]:
        """Get embedding for a single text with cache and retry.

        Returns an empty list on failure when raise_exception is False.
        """
        truncated_text = self._truncate_text(input_text)
        cached = self._get_from_cache(truncated_text)
        if cached is not None:
            return cached

        for retry in range(self.max_retries):
            try:
                result = await self._get_embeddings([truncated_text], **kwargs)
                if result and len(result) == 1:
                    embedding = self._validate_and_adjust_embedding(result[0])
                    self._put_to_cache(truncated_text, embedding)
                    return embedding
                self.logger.warning(
                    f"Model {self.model_name} returned {len(result) if result else 0} results, expected 1",
                )
                if retry == self.max_retries - 1:
                    if self.raise_exception:
                        raise RuntimeError("Embedding API returned empty result")
                    return []
                # Linear backoff: 1s, 2s, 3s, ...
                await asyncio.sleep(retry + 1)
            except Exception as e:
                self.logger.error(f"Model {self.model_name} failed: {e}")
                if retry == self.max_retries - 1:
                    if self.raise_exception:
                        raise
                    return []
                await asyncio.sleep(retry + 1)
        return []

    async def get_embeddings(self, input_text: list[str], **kwargs) -> list[list[float]]:
        """Get embeddings for multiple texts with cache and batching.

        Results preserve input order; uncached texts are fetched in batches
        of max_batch_size, each batch retried independently. Failed entries
        are returned as empty lists when raise_exception is False.
        """
        truncated_texts = [self._truncate_text(t) for t in input_text]
        # results holds per-input slots; None marks "not yet computed".
        results: list[list[float] | None] = [None] * len(truncated_texts)
        texts_to_compute: list[tuple[int, str]] = []

        for idx, text in enumerate(truncated_texts):
            cached = self._get_from_cache(text)
            if cached is not None:
                results[idx] = cached
            else:
                texts_to_compute.append((idx, text))

        if texts_to_compute:
            uncached_texts = [text for _, text in texts_to_compute]
            for i in range(0, len(uncached_texts), self.max_batch_size):
                batch_texts = uncached_texts[i : i + self.max_batch_size]
                # Original input positions for this batch, kept in lockstep
                # with batch_texts so results land in the right slots.
                batch_indices = [idx for idx, _ in texts_to_compute[i : i + self.max_batch_size]]

                for retry in range(self.max_retries):
                    try:
                        batch_embeddings = await self._get_embeddings(batch_texts, **kwargs)
                        if batch_embeddings and len(batch_embeddings) == len(batch_texts):
                            for orig_idx, text, embedding in zip(batch_indices, batch_texts, batch_embeddings):
                                adjusted = self._validate_and_adjust_embedding(embedding)
                                results[orig_idx] = adjusted
                                self._put_to_cache(text, adjusted)
                            break
                        self.logger.warning(
                            f"Batch returned {len(batch_embeddings) if batch_embeddings else 0} "
                            f"results for {len(batch_texts)} inputs",
                        )
                        if retry == self.max_retries - 1:
                            if self.raise_exception:
                                raise RuntimeError(f"Batch embedding failed after {self.max_retries} retries")
                            # Best-effort mode: mark this batch's slots as failed.
                            for orig_idx in batch_indices:
                                if results[orig_idx] is None:
                                    results[orig_idx] = []
                        else:
                            await asyncio.sleep(retry + 1)
                    except Exception as e:
                        self.logger.error(f"Model {self.model_name} batch failed: {e}")
                        if retry == self.max_retries - 1:
                            if self.raise_exception:
                                raise
                            for orig_idx in batch_indices:
                                if results[orig_idx] is None:
                                    results[orig_idx] = []
                        else:
                            await asyncio.sleep(retry + 1)

        return [r if r is not None else [] for r in results]

    async def get_node_embeddings(self, nodes: list[BaseNode], **kwargs) -> list[BaseNode]:
        """Get embeddings for a list of nodes and assign to node.embedding."""
        texts = [node.text for node in nodes]
        embeddings = await self.get_embeddings(texts, **kwargs)

        if len(embeddings) == len(nodes):
            for node, vec in zip(nodes, embeddings):
                node.embedding = vec
        else:
            self.logger.warning(f"Mismatch: {len(embeddings)} vectors for {len(nodes)} nodes, skipping assignment")
        return nodes

    async def _start(self, app_context=None) -> None:
        """Load cache on start."""
        self._load_cache()

    async def _close(self) -> None:
        """Save cache on close."""
        self._save_cache()
@R.register("openai")
class OpenAIEmbeddingModel(BaseEmbeddingModel):
    """Async embedding model compatible with OpenAI-style APIs."""

    def __init__(self, **kwargs):
        """Initialize OpenAI embedding model."""
        super().__init__(**kwargs)
        self._client: AsyncOpenAI | None = None

    async def _start(self, app_context=None) -> None:
        """Create the AsyncOpenAI client, then run the base start logic."""
        self._client = AsyncOpenAI(api_key=self.api_key, base_url=self.base_url, **self.kwargs)
        await super()._start(app_context)

    async def _close(self) -> None:
        """Tear down the AsyncOpenAI client, then run the base close logic."""
        if self._client is not None:
            await self._client.close()
            self._client = None
        await super()._close()

    async def _get_embeddings(self, input_text: list[str], **kwargs) -> list[list[float]]:
        """Fetch embeddings for a batch of texts via the embeddings endpoint.

        Results are re-ordered by the response's per-item index; missing or
        invalid entries come back as empty lists.
        """
        if self._client is None:
            raise RuntimeError("Client not initialized. Call _start() first.")

        request: dict = {"model": self.model_name, "input": input_text}
        request.update(kwargs)
        if self.use_dimensions:
            request["dimensions"] = self.dimensions

        completion = await self._client.embeddings.create(**request)

        n = len(input_text)
        by_index: dict[int, list[float]] = {}
        for item in completion.data:
            idx = item.index
            if not 0 <= idx < n:
                self.logger.warning(f"Invalid index {idx} for input length {n}")
                continue
            # Some OpenAI-compatible servers expose the vector under a
            # different attribute name.
            vector = getattr(item, "embedding", None) or getattr(item, "dense_embedding", None)
            if vector is None:
                self.logger.warning(f"Empty embedding for index {idx}")
                continue
            by_index[idx] = list(vector)

        return [by_index.get(pos, []) for pos in range(n)]
class BaseFileStore(BaseComponent):
    """Abstract base class for file storage backends.

    Provides embedding resolution, validation, and safe embedding retrieval
    with automatic fallback on failure (vector search is disabled on the
    first embedding error rather than failing the whole request).
    """

    component_type = ComponentEnum.FILE_STORE

    def __init__(
        self,
        store_name: str,
        db_path: str | Path,
        embedding_model: str = "default",
        fts_enabled: bool = True,
        **kwargs,
    ):
        """Initialize the store.

        Args:
            store_name: Identifier for this store; alphanumeric/underscore only.
            db_path: On-disk location for persisted data (created if missing).
            embedding_model: Name of the embedding-model component to resolve
                at start; empty string disables vector search.
            fts_enabled: Whether full-text/keyword search is enabled.

        Raises:
            ValueError: On an invalid store name, or when both vector and
                full-text search are disabled.
        """
        super().__init__(**kwargs)
        self._embedding_model_name: str = embedding_model
        self.embedding_model: BaseEmbeddingModel | None = None
        self.store_name: str = store_name
        self.db_path: Path = Path(db_path)
        self.db_path.mkdir(parents=True, exist_ok=True)
        self.vector_enabled: bool = bool(embedding_model)
        self.fts_enabled: bool = fts_enabled

        if not re.match(r"^[a-zA-Z0-9_]+$", store_name):
            raise ValueError(
                f"Invalid store name '{store_name}'. Only alphanumeric characters and underscores are allowed.",
            )
        if not self.vector_enabled and not self.fts_enabled:
            raise ValueError("At least one of embedding_model or fts_enabled must be set.")

    async def _start(self, app_context=None):
        """Resolve embedding model from app_context.

        Raises:
            ValueError: If app_context is missing or the named model is not
                registered.
            TypeError: If the resolved component is not a BaseEmbeddingModel.
        """
        if not self._embedding_model_name:
            return
        # Explicit raise instead of assert: asserts are stripped under -O.
        if app_context is None:
            raise ValueError("app_context must be provided")
        models = app_context.components.get(ComponentEnum.EMBEDDING_MODEL, {})
        if self._embedding_model_name not in models:
            raise ValueError(f"Embedding model '{self._embedding_model_name}' not found.")
        model = models[self._embedding_model_name]
        if not isinstance(model, BaseEmbeddingModel):
            raise TypeError(f"Expected BaseEmbeddingModel, got {type(model).__name__}")
        self.embedding_model = model

    async def _close(self):
        """Release embedding model reference."""
        self.embedding_model = None

    @property
    def embedding_dim(self) -> int:
        """Return the embedding dimensionality (default 1024 when unresolved)."""
        return self.embedding_model.dimensions if self.embedding_model else 1024

    def _disable_vector_search(self, reason: str = "embedding API error") -> None:
        """Disable vector search and log a warning (one-way degradation)."""
        if self.vector_enabled:
            self.logger.warning(f"[{self.store_name}] Disabling vector search: {reason}")
            self.vector_enabled = False

    async def _get_embeddings_safe(self, texts: list[str], **kwargs) -> list[list[float]] | None:
        """Get embeddings, returning None if vector search is disabled or an error occurs."""
        if not self.vector_enabled:
            return None
        try:
            # Raise (not assert) so the failure is caught below and degrades
            # to keyword-only search, even under python -O.
            if self.embedding_model is None:
                raise RuntimeError("Embedding model not initialized")
            return await self.embedding_model.get_embeddings(texts, **kwargs)
        except Exception as e:
            self._disable_vector_search(str(e))
            return None

    async def get_embedding(self, query: str, **kwargs) -> list[float] | None:
        """Get embedding for a single query string, or None on failure."""
        result = await self._get_embeddings_safe([query], **kwargs)
        return result[0] if result else None

    async def get_embeddings(self, queries: list[str], **kwargs) -> list[list[float]] | None:
        """Get embeddings for a batch of query strings, or None on failure."""
        return await self._get_embeddings_safe(queries, **kwargs)

    async def get_chunk_embedding(self, chunk: FileChunk, **kwargs) -> FileChunk:
        """Attach embedding to a single FileChunk (None on failure)."""
        chunk.embedding = await self.get_embedding(chunk.text, **kwargs)
        return chunk

    async def get_chunk_embeddings(self, chunks: list[FileChunk], **kwargs) -> list[FileChunk]:
        """Attach embeddings to a batch of FileChunk.

        On a partial/failed batch all embeddings are reset to None so no
        chunk carries a stale vector.
        """
        if not chunks:
            return chunks
        embeddings = await self.get_embeddings([c.text for c in chunks], **kwargs)
        if embeddings and len(embeddings) == len(chunks):
            for chunk, emb in zip(chunks, embeddings):
                chunk.embedding = emb
        else:
            for chunk in chunks:
                chunk.embedding = None
        return chunks

    @abstractmethod
    async def clear_all(self):
        """Clear all indexed data."""

    @abstractmethod
    async def upsert_file(self, file_meta: FileMetadata, chunks: list[FileChunk]):
        """Insert or update a file and its chunks."""

    @abstractmethod
    async def delete_file(self, path: str):
        """Delete a file and all its chunks."""

    @abstractmethod
    async def delete_file_chunks(self, path: str, chunk_ids: list[str]):
        """Delete specific chunks for a file."""

    @abstractmethod
    async def upsert_chunks(self, chunks: list[FileChunk]):
        """Insert or update specific chunks without affecting others."""

    @abstractmethod
    async def list_files(self) -> list[str]:
        """List all indexed file paths."""

    @abstractmethod
    async def get_file_metadata(self, path: str) -> FileMetadata | None:
        """Get file metadata."""

    @abstractmethod
    async def update_file_metadata(self, file_meta: FileMetadata) -> None:
        """Update file metadata without affecting chunks."""

    @abstractmethod
    async def get_file_chunks(self, path: str) -> list[FileChunk]:
        """Get all chunks for a file."""

    @abstractmethod
    async def vector_search(self, query: str, limit: int) -> list[FileChunk]:
        """Perform vector similarity search."""

    @abstractmethod
    async def keyword_search(self, query: str, limit: int) -> list[FileChunk]:
        """Perform full-text/keyword search."""

    @abstractmethod
    async def hybrid_search(
        self,
        query: str,
        limit: int,
        vector_weight: float = 0.7,
        candidate_multiplier: float = 3.0,
    ) -> list[FileChunk]:
        """Perform hybrid search combining vector and keyword results.

        Args:
            query: Search query text.
            limit: Maximum number of results.
            vector_weight: Weight for vector scores (0.0-1.0).
            candidate_multiplier: Multiplier for candidate pool size.

        Returns:
            FileChunk list with score populated, sorted by relevance.
        """
@R.register("local")
class LocalFileStore(BaseFileStore):
    """In-memory file storage with JSONL disk persistence.

    No external database required. All data lives in Python dicts;
    writes are flushed to JSONL files on disk and survive restarts.
    """

    def __init__(self, encoding: str = "utf-8", **kwargs):
        """Initialize the local store.

        Args:
            encoding: Text encoding for the persistence files.
        """
        super().__init__(**kwargs)
        self._encoding: str = encoding
        # chunk id -> FileChunk; file path -> FileMetadata
        self._chunks: dict[str, FileChunk] = {}
        self._files: dict[str, FileMetadata] = {}
        self._chunks_file: Path = self.db_path / f"{self.store_name}_chunks.jsonl"
        self._metadata_file: Path = self.db_path / f"{self.store_name}_file_metadata.json"

    # -- Persistence helpers ------------------------------------------------

    async def _load_chunks(self) -> None:
        """Load chunks from JSONL file into memory (best-effort)."""
        if not self._chunks_file.exists():
            return
        try:
            data = self._chunks_file.read_text(encoding=self._encoding)
            self._chunks = {}
            for line in data.strip().split("\n"):
                if not line:
                    continue
                chunk = FileChunk.model_validate(json.loads(line))
                self._chunks[chunk.id] = chunk
        except Exception as e:
            self.logger.warning(f"Failed to load chunks: {e}")

    async def _save_chunks(self) -> None:
        """Persist chunks to JSONL file with atomic write (write temp, replace)."""
        lines = [json.dumps(c.model_dump(mode="json"), ensure_ascii=False) for c in self._chunks.values()]
        content = "\n".join(lines)
        temp_path = self._chunks_file.with_suffix(".tmp")
        try:
            temp_path.write_text(content, encoding=self._encoding)
            temp_path.replace(self._chunks_file)
        except Exception as e:
            self.logger.error(f"Failed to save chunks: {e}")
            raise
        finally:
            # Only present if the replace did not happen (partial write).
            if temp_path.exists():
                temp_path.unlink()

    async def _load_metadata(self) -> None:
        """Load file metadata from JSON file into memory (best-effort)."""
        if not self._metadata_file.exists():
            return
        try:
            data = self._metadata_file.read_text(encoding=self._encoding)
            raw: dict = json.loads(data)
            self._files = {path: FileMetadata(**meta) for path, meta in raw.items()}
        except Exception as e:
            self.logger.warning(f"Failed to load metadata: {e}")

    async def _save_metadata(self) -> None:
        """Persist file metadata to JSON file with atomic write."""
        raw = {
            path: meta.model_dump(exclude={"content", "metadata"}, mode="json") for path, meta in self._files.items()
        }
        content = json.dumps(raw, indent=2, ensure_ascii=False)
        temp_path = self._metadata_file.with_suffix(".tmp")
        try:
            temp_path.write_text(content, encoding=self._encoding)
            temp_path.replace(self._metadata_file)
        except Exception as e:
            self.logger.error(f"Failed to save metadata: {e}")
            raise
        finally:
            if temp_path.exists():
                temp_path.unlink()

    # -- Lifecycle ----------------------------------------------------------

    async def _start(self, app_context=None) -> None:
        """Load persisted data into memory and resolve the embedding model."""
        await self._load_metadata()
        await self._load_chunks()
        self.logger.info(
            f"LocalFileStore '{self.store_name}' ready: "
            f"{len(self._chunks)} chunks, metadata at {self._metadata_file}",
        )
        # BUGFIX: app_context must be forwarded — the base class resolves the
        # embedding model from it, and calling super()._start() with no
        # argument made that resolution fail on every start.
        await super()._start(app_context)

    async def _close(self) -> None:
        """Flush state to disk and clear memory."""
        await self._save_metadata()
        await self._save_chunks()
        self._chunks.clear()
        self._files.clear()
        await super()._close()

    # -- Write operations ---------------------------------------------------

    async def upsert_file(self, file_meta: FileMetadata, chunks: list[FileChunk]) -> None:
        """Insert or update a file and its chunks (replaces existing chunks)."""
        # NOTE(review): an empty chunk list is a no-op, so existing data for
        # the file is kept rather than deleted — confirm this is intended.
        if not chunks:
            return

        await self.delete_file(file_meta.path)
        chunks = await self.get_chunk_embeddings(chunks)

        for chunk in chunks:
            self._chunks[chunk.id] = chunk

        self._files[file_meta.path] = FileMetadata(
            hash=file_meta.hash,
            mtime_ms=file_meta.mtime_ms,
            size=file_meta.size,
            path=file_meta.path,
            chunk_count=len(chunks),
        )

    async def delete_file(self, path: str) -> None:
        """Delete a file and all its chunks."""
        to_delete = [cid for cid, chunk in self._chunks.items() if chunk.path == path]
        for cid in to_delete:
            del self._chunks[cid]
        self._files.pop(path, None)

    async def delete_file_chunks(self, path: str, chunk_ids: list[str]) -> None:
        """Delete specific chunks for a file and refresh its chunk count."""
        if not chunk_ids:
            return
        for cid in chunk_ids:
            self._chunks.pop(cid, None)
        if path in self._files:
            self._files[path].chunk_count = sum(1 for chunk in self._chunks.values() if chunk.path == path)

    async def upsert_chunks(self, chunks: list[FileChunk]) -> None:
        """Insert or update specific chunks without affecting others."""
        if not chunks:
            return
        chunks = await self.get_chunk_embeddings(chunks)
        for chunk in chunks:
            self._chunks[chunk.id] = chunk

    # -- Read operations ----------------------------------------------------

    async def list_files(self) -> list[str]:
        """List all indexed file paths."""
        return list(self._files.keys())

    async def get_file_metadata(self, path: str) -> FileMetadata | None:
        """Get file metadata."""
        return self._files.get(path)

    async def update_file_metadata(self, file_meta: FileMetadata) -> None:
        """Update file metadata without affecting chunks."""
        self._files[file_meta.path] = FileMetadata(
            hash=file_meta.hash,
            mtime_ms=file_meta.mtime_ms,
            size=file_meta.size,
            path=file_meta.path,
            chunk_count=file_meta.chunk_count,
        )

    async def get_file_chunks(self, path: str) -> list[FileChunk]:
        """Get all chunks for a file, sorted by start_line."""
        chunks = [chunk for chunk in self._chunks.values() if chunk.path == path]
        chunks.sort(key=lambda c: c.start_line)
        return chunks

    # -- Search -------------------------------------------------------------

    async def vector_search(self, query: str, limit: int) -> list[FileChunk]:
        """Cosine-similarity vector search over in-memory embeddings."""
        if not self.vector_enabled or not query:
            return []

        query_embedding = await self.get_embedding(query)
        if not query_embedding:
            return []

        candidates = [c for c in self._chunks.values() if c.embedding]
        if not candidates:
            return []

        expected_dim = self.embedding_dim

        # Validate and align embedding dimensions (pad short, trim long) so
        # the similarity matrix is well-formed.
        valid_embeddings = []
        for chunk in candidates:
            emb = chunk.embedding
            emb_len = len(emb)
            if emb_len != expected_dim:
                emb = (emb + [0.0] * (expected_dim - emb_len)) if emb_len < expected_dim else emb[:expected_dim]
            valid_embeddings.append(emb)

        query_array = np.array([query_embedding])
        chunk_embeddings = np.array(valid_embeddings)
        similarities = batch_cosine_similarity(query_array, chunk_embeddings)[0]

        # Return score-carrying copies so cached chunks are never mutated.
        results = []
        for chunk, sim in zip(candidates, similarities):
            results.append(
                FileChunk(
                    id=chunk.id,
                    path=chunk.path,
                    start_line=chunk.start_line,
                    end_line=chunk.end_line,
                    hash=chunk.hash,
                    text=chunk.text,
                    embedding=chunk.embedding,
                    scores={"vector": float(sim), "score": float(sim)},
                ),
            )

        results.sort(key=lambda r: r.score, reverse=True)
        return results[:limit]

    async def keyword_search(self, query: str, limit: int) -> list[FileChunk]:
        """Keyword search via substring matching.

        Score = fraction of query words present, plus a 0.2 bonus when the
        whole phrase occurs verbatim; capped at 1.0.
        """
        if not self.fts_enabled or not query:
            return []

        words = query.split()
        if not words:
            return []

        query_lower = query.lower()
        words_lower = [w.lower() for w in words]
        n_words = len(words)

        results = []
        for chunk in self._chunks.values():
            text_lower = chunk.text.lower()
            match_count = sum(1 for w in words_lower if w in text_lower)
            if match_count == 0:
                continue

            base_score = match_count / n_words
            phrase_bonus = 0.2 if n_words > 1 and query_lower in text_lower else 0.0
            score = min(1.0, base_score + phrase_bonus)

            results.append(
                FileChunk(
                    id=chunk.id,
                    path=chunk.path,
                    start_line=chunk.start_line,
                    end_line=chunk.end_line,
                    hash=chunk.hash,
                    text=chunk.text,
                    scores={"keyword": score, "score": score},
                ),
            )

        results.sort(key=lambda r: r.score, reverse=True)
        return results[:limit]

    async def hybrid_search(
        self,
        query: str,
        limit: int,
        vector_weight: float = 0.7,
        candidate_multiplier: float = 3.0,
    ) -> list[FileChunk]:
        """Hybrid search combining vector and keyword results.

        Falls back to whichever single mode is enabled; returns [] when both
        are disabled.

        Raises:
            ValueError: If vector_weight is outside [0.0, 1.0].
        """
        # Explicit raise instead of assert: asserts are stripped under -O.
        if not 0.0 <= vector_weight <= 1.0:
            raise ValueError(f"vector_weight must be in [0.0, 1.0], got {vector_weight}")

        # Candidate pool is larger than `limit` so merging has material to
        # work with, capped at 200 to bound work.
        candidates = min(200, max(1, int(limit * candidate_multiplier)))
        text_weight = 1.0 - vector_weight

        if self.vector_enabled and self.fts_enabled:
            keyword_results = await self.keyword_search(query, candidates)
            vector_results = await self.vector_search(query, candidates)

            if not keyword_results:
                return vector_results[:limit]
            if not vector_results:
                return keyword_results[:limit]

            merged = self._merge_hybrid_results(
                vector=vector_results,
                keyword=keyword_results,
                vector_weight=vector_weight,
                text_weight=text_weight,
            )
            return merged[:limit]
        elif self.vector_enabled:
            return await self.vector_search(query, limit)
        elif self.fts_enabled:
            return await self.keyword_search(query, limit)
        return []

    @staticmethod
    def _merge_hybrid_results(
        vector: list[FileChunk],
        keyword: list[FileChunk],
        vector_weight: float,
        text_weight: float,
    ) -> list[FileChunk]:
        """Merge vector and keyword results with weighted scoring.

        A chunk appearing in both lists gets the sum of its weighted scores.
        """
        merged: dict[str, FileChunk] = {}

        for result in vector:
            v_score = result.scores.get("vector", 0)
            result.scores["score"] = v_score * vector_weight
            merged[result.merge_key] = result

        for result in keyword:
            key = result.merge_key
            k_score = result.scores.get("keyword", 0)
            if key in merged:
                merged[key].scores["score"] += k_score * text_weight
            else:
                result.scores["score"] = k_score * text_weight
                merged[key] = result

        results = list(merged.values())
        results.sort(key=lambda r: r.score, reverse=True)
        return results

    async def clear_all(self) -> None:
        """Clear all indexed data from memory and disk."""
        self._chunks.clear()
        self._files.clear()
        await self._save_chunks()
        await self._save_metadata()
        self.logger.info(f"Cleared all data from LocalFileStore '{self.store_name}'")
    Provides file monitoring with:
    - watchfiles integration for efficient change detection
    - Suffix-based filtering
    - Auto-restart on failure
    - Optional index rebuild on start
    """

    component_type = ComponentEnum.FILE_WATCHER

    def __init__(
        self,
        watch_paths: list[str] | str,
        suffix_filters: list[str] | None = None,
        recursive: bool = False,
        debounce: int = 2000,
        chunk_tokens: int = 400,
        chunk_overlap: int = 80,
        file_store: str = "default",
        rebuild_index_on_start: bool = True,
        poll_delay_ms: int = 2000,
        **kwargs,
    ):
        """Initialize file watcher configuration.

        Args:
            watch_paths: Paths to watch for changes; a single string is
                normalized to a one-element list.
            suffix_filters: File suffix filters (e.g., ['.py', '.txt']);
                empty/None means "accept every file".
            recursive: Whether to watch directories recursively.
            debounce: Debounce time in milliseconds.
            chunk_tokens: Token size for chunking.
            chunk_overlap: Overlap size for chunks.
            file_store: Name of the file store component to resolve on start.
            rebuild_index_on_start: Clear index and rescan files on start.
            poll_delay_ms: Polling delay in milliseconds.
        """
        super().__init__(**kwargs)
        # The store is resolved from the application context in _start();
        # until then only its name is kept.
        self._file_store_name: str = file_store
        self.file_store: BaseFileStore | None = None
        self.watch_paths: list[str] = [watch_paths] if isinstance(watch_paths, str) else watch_paths
        self.suffix_filters: list[str] = suffix_filters or []
        self.recursive: bool = recursive
        self.debounce: int = debounce
        self.chunk_tokens: int = chunk_tokens
        self.chunk_overlap: int = chunk_overlap
        self.rebuild_index_on_start: bool = rebuild_index_on_start
        self.poll_delay_ms: int = poll_delay_ms

        # Cooperative shutdown: set in _close(), observed by the watch loop.
        self._stop_event = asyncio.Event()
        self._watch_task: asyncio.Task | None = None

    async def _start(self, app_context=None):
        """Resolve file_store and start watching task."""
        if self._file_store_name:
            assert app_context is not None, "app_context must be provided"
            stores = app_context.components.get(ComponentEnum.FILE_STORE, {})
            if self._file_store_name not in stores:
                raise ValueError(f"File store '{self._file_store_name}' not found.")
            store = stores[self._file_store_name]
            if not isinstance(store, BaseFileStore):
                raise TypeError(f"Expected BaseFileStore, got {type(store).__name__}")
            self.file_store = store

        async def _initialize_and_watch():
            # Optionally start from a clean index, then pick up files that
            # already exist before entering the change-watch loop.
            if self.rebuild_index_on_start and self.file_store:
                await self.file_store.clear_all()
                self.logger.info("Cleared all indexed data on start")
            await self._scan_existing_files()
            await self._watch_loop()

        self._stop_event.clear()
        # NOTE(review): exceptions raised before _watch_loop (e.g. during the
        # initial scan) surface only on the task object — confirm they are
        # observed somewhere.
        self._watch_task = asyncio.create_task(_initialize_and_watch())
        self.logger.info(f"Started watching: {self.watch_paths}")

    async def _close(self):
        """Stop watching and release resources."""
        self._stop_event.set()
        if self._watch_task and not self._watch_task.done():
            self._watch_task.cancel()
            try:
                await self._watch_task
            except asyncio.CancelledError:
                # Expected: we cancelled the task ourselves.
                pass

        self._watch_task = None
        self._stop_event.clear()
        self.file_store = None
        self.logger.info("Stopped watching")

    def watch_filter(self, _change: Change, path: str) -> bool:
        """Filter files by suffix. Returns True if no filters configured."""
        if not self.suffix_filters:
            return True

        # strip(".") lets filters be written as either "md" or ".md".
        for suffix in self.suffix_filters:
            if path.endswith("." + suffix.strip(".")):
                return True
        return False

    async def _scan_existing_files(self):
        """Scan existing files and add them as Change.added."""
        if not self.file_store:
            return

        existing_files: set[tuple[Change, str]] = set()

        for watch_path_str in self.watch_paths:
            watch_path = Path(watch_path_str)

            if not watch_path.exists():
                self.logger.warning(f"Watch path does not exist: {watch_path}")
                continue

            if watch_path.is_file():
                if self.watch_filter(Change.added, str(watch_path)):
                    existing_files.add((Change.added, str(watch_path)))
            elif watch_path.is_dir():
                if self.recursive:
                    for file_path in watch_path.rglob("*"):
                        if file_path.is_file() and self.watch_filter(Change.added, str(file_path)):
                            existing_files.add((Change.added, str(file_path)))
                else:
                    for file_path in watch_path.iterdir():
                        if file_path.is_file() and self.watch_filter(Change.added, str(file_path)):
                            existing_files.add((Change.added, str(file_path)))

        if existing_files:
            self.logger.info(f"[SCAN_ON_START] Found {len(existing_files)} existing files")
            await self.on_changes(existing_files)
            self.logger.info(f"[SCAN_ON_START] Added {len(existing_files)} files to memory store")
        else:
            self.logger.info("[SCAN_ON_START] No existing files found")

        # Diagnostic pass: log what the store now believes it holds.
        files: list[str] = await self.file_store.list_files()
        for file_path in files:
            chunks = await self.file_store.get_file_chunks(file_path)
            self.logger.info(f"Found existing file: {file_path}, {len(chunks)} chunks")

    async def _interruptible_sleep(self, seconds: float):
        """Sleep that can be interrupted by stop_event."""
        try:
            await asyncio.wait_for(self._stop_event.wait(), timeout=seconds)
        except asyncio.TimeoutError:
            # Timeout means the full sleep elapsed without a stop request.
            pass

    async def _watch_loop(self):
        """Core monitoring loop with auto-restart on failure."""
        if not self.watch_paths:
            self.logger.warning("No watch paths specified")
            return

        while not self._stop_event.is_set():
            # Paths may appear/disappear at runtime; re-validate each pass.
            valid_paths = [p for p in self.watch_paths if Path(p).exists()]

            if not valid_paths:
                self.logger.warning("No valid watch paths exist, waiting 10 seconds...")
                await self._interruptible_sleep(10)
                continue

            invalid_paths = set(self.watch_paths) - set(valid_paths)
            if invalid_paths:
                self.logger.warning(f"Skipping non-existent paths: {invalid_paths}")

            try:
                self.logger.info(f"Starting watch on: {valid_paths}")
                async for changes in awatch(
                    *valid_paths,
                    watch_filter=self.watch_filter,
                    recursive=self.recursive,
                    debounce=self.debounce,
                    poll_delay_ms=self.poll_delay_ms,
                    stop_event=self._stop_event,
                ):
                    if self._stop_event.is_set():
                        break
                    await self.on_changes(changes)

            except FileNotFoundError as e:
                self.logger.error(f"Watch path no longer exists: {e}, restarting in 10 seconds...")
                if not self._stop_event.is_set():
                    await self._interruptible_sleep(10)

            except Exception as e:
                # Broad catch is deliberate: the watcher must survive transient
                # failures and retry rather than die silently.
                self.logger.error(f"Error in watch loop: {e}, restarting in 10 seconds...", exc_info=True)
                if not self._stop_event.is_set():
                    await self._interruptible_sleep(10)

    async def on_changes(self, changes: set[tuple[Change, str]]):
        """Hook method for handling file changes."""
        await self._on_changes(changes)
        # NOTE(review): this log line runs after the handler, so it reports
        # completion, not receipt — confirm that is the intent.
        self.logger.info(f"[{self.__class__.__name__}] on_changes: {changes}")

    @abstractmethod
    async def _on_changes(self, changes: set[tuple[Change, str]]):
        """Handle file changes. Override in subclasses."""
diff --git a/reme_cli/component/file_watcher/md_file_watcher.py b/reme_cli/component/file_watcher/md_file_watcher.py
new file mode 100644
index 00000000..e339d975
--- /dev/null
+++ b/reme_cli/component/file_watcher/md_file_watcher.py
@@ -0,0 +1,86 @@
"""Markdown file watcher for synchronization."""

import asyncio
from pathlib import Path

from watchfiles import Change

from .base_file_watcher import BaseFileWatcher
from ..component_registry import R
from ...schema import FileMetadata
from ...utils import hash_text, chunk_markdown


@R.register("md")
class MdFileWatcher(BaseFileWatcher):
    """Markdown file watcher that syncs .md files to memory store."""

    def __init__(self, encoding: str = "utf-8", **kwargs):
        """Initialize Markdown file watcher.

        Args:
            encoding: File encoding used when reading watched files.
            **kwargs: Forwarded to BaseFileWatcher.
        """
        super().__init__(**kwargs)
        self.encoding = encoding

    async def _on_changes(self, changes: set[tuple[Change, str]]):
        """Handle file changes with full synchronization.

        added/modified files are re-read, re-chunked, re-embedded and
        upserted; deleted files are removed from the store. Per-file errors
        are logged and skipped so one bad file does not abort the batch.
        """
        if not self.file_store:
            self.logger.warning("File store not initialized, skipping changes")
            return

        for change_type, path in changes:
            try:
                if change_type in [Change.added, Change.modified]:
                    file_meta = await self._build_file_metadata(path)
                    chunks = (
                        chunk_markdown(
                            file_meta.content,
                            file_meta.path,
                            self.chunk_tokens,
                            self.chunk_overlap,
                        )
                        or []
                    )
                    if chunks:
                        chunks = await self.file_store.get_chunk_embeddings(chunks)
                    file_meta.chunk_count = len(chunks)

                    # Delete-then-upsert guarantees stale chunks are gone even
                    # when the new version has fewer chunks than the old one.
                    await self.file_store.delete_file(file_meta.path)
                    self.logger.info(f"delete_file {file_meta.path}")

                    await self.file_store.upsert_file(file_meta, chunks)
                    self.logger.info(f"Upserted {file_meta.chunk_count} chunks for {file_meta.path}")

                elif change_type == Change.deleted:
                    await self.file_store.delete_file(path)
                    self.logger.info(f"Deleted {path}")

                else:
                    self.logger.warning(f"Unknown change type: {change_type}")

                self.logger.info(f"File {change_type} changed: {path}")

            except FileNotFoundError:
                # File vanished between the event and our read; safe to skip.
                self.logger.warning(f"File not found: {path}, skipping")
            except PermissionError:
                self.logger.warning(f"Permission denied: {path}, skipping")
            except Exception as e:
                self.logger.error(f"Error processing {path}: {e}", exc_info=True)

    async def _build_file_metadata(self, path: str) -> FileMetadata:
        """Build FileMetadata from file path.

        Stat and read happen in one worker thread so the event loop is never
        blocked on disk I/O.
        """
        file_path = Path(path)

        def _read_file_sync():
            return file_path.stat(), file_path.read_text(encoding=self.encoding)

        stat, content = await asyncio.to_thread(_read_file_sync)
        return FileMetadata(
            hash=hash_text(content),
            mtime_ms=stat.st_mtime * 1000,
            size=stat.st_size,
            path=str(file_path.absolute()),
            content=content,
        )
diff --git a/reme_cli/component/job/__init__.py b/reme_cli/component/job/__init__.py
new file mode 100644
index 00000000..8f580f03
--- /dev/null
+++ b/reme_cli/component/job/__init__.py
@@ -0,0 +1,9 @@
"""Job components for executing workflows."""

from .base_job import BaseJob
from .stream_job import StreamJob

__all__ = [
    "BaseJob",
    "StreamJob",
]
diff --git a/reme_cli/component/job/base_job.py b/reme_cli/component/job/base_job.py
new file mode 100644
index 00000000..8cf2776f
--- /dev/null
+++ b/reme_cli/component/job/base_job.py
@@ -0,0 +1,86 @@
"""Base job component for sequential step execution."""

from ..base_component import BaseComponent
from ..component_registry import R
from ..runtime_context import RuntimeContext
from ...enumeration import ComponentEnum
from ...schema import Response, ComponentConfig


@R.register("base")
class BaseJob(BaseComponent):
    """Base job that executes a sequence of steps.

    A job orchestrates multiple steps in sequence, passing a runtime context
    through each step. Steps are configured via ComponentConfig and instantiated
    lazily when the job starts.
+ """ + + component_type = ComponentEnum.JOB + + def __init__( + self, + name: str = "", + description: str = "", + parameters: dict | None = None, + steps: list[ComponentConfig] | None = None, + **kwargs, + ): + """Initialize the job. + + Args: + name: Job name identifier. + description: Human-readable description. + parameters: Default parameters passed to steps. + steps: List of step configurations to execute. + **kwargs: Additional arguments passed to BaseComponent. + """ + super().__init__(**kwargs) + + self.name: str = name + self.description: str = description + self.parameters: dict = parameters or {} + self.step_configs: list[ComponentConfig] = steps or [] + self.steps: list = [] + + async def _start(self, app_context=None) -> None: + """Instantiate all configured steps. + + Args: + app_context: Application context for dependency injection. + + Raises: + ValueError: If a step backend is not specified or not registered. + """ + for step_config in self.step_configs: + if not step_config.backend: + raise ValueError(f"{step_config.backend} backend is not specified.") + + backend_cls = R.get(ComponentEnum.STEP, step_config.backend) + if not backend_cls: + raise ValueError(f"{step_config.backend} is not registered.") + + step = backend_cls( + language=app_context.app_config.language, + **step_config.model_dump(exclude={"backend"}), + ) + self.steps.append(step) + + async def _close(self) -> None: + """Clear all instantiated steps.""" + self.steps.clear() + + async def __call__(self, **kwargs) -> Response: + """Execute all steps sequentially. + + Args: + **kwargs: Parameters passed to the runtime context. + + Returns: + The final response from the runtime context. 
+ """ + context = RuntimeContext(**kwargs) + for step in self.steps: + await step(context) + + return context.response diff --git a/reme_cli/component/job/stream_job.py b/reme_cli/component/job/stream_job.py new file mode 100644 index 00000000..ae362310 --- /dev/null +++ b/reme_cli/component/job/stream_job.py @@ -0,0 +1,37 @@ +"""Streaming job for real-time output delivery.""" + +import asyncio + +from .base_job import BaseJob +from ..component_registry import R +from ..runtime_context import RuntimeContext +from ...enumeration import ChunkEnum + + +@R.register("stream") +class StreamJob(BaseJob): + """Job that streams execution results in real-time. + + Unlike BaseJob which returns a final response, StreamJob pushes + intermediate results to a queue as they are produced, allowing + clients to receive updates incrementally. + """ + + async def __call__(self, **kwargs) -> asyncio.Queue: + """Execute all steps with streaming enabled. + + Args: + **kwargs: Parameters passed to the runtime context. + + Returns: + An asyncio.Queue containing streamed chunks. 
+ """ + context = RuntimeContext(stream=True, **kwargs) + try: + for step in self.steps: + await step(context) + except Exception as e: + await context.add_stream_string(str(e), ChunkEnum.ERROR) + + await context.add_stream_done() + return context.stream_queue diff --git a/reme_cli/component/prompt_handler.py b/reme_cli/component/prompt_handler.py new file mode 100644 index 00000000..6ebd727c --- /dev/null +++ b/reme_cli/component/prompt_handler.py @@ -0,0 +1,113 @@ +"""Module for managing and formatting prompt templates.""" + +import inspect +import json +from pathlib import Path +from string import Formatter + +import yaml + + +class PromptHandler: + """A handler for loading, retrieving, and formatting prompt templates.""" + + _SUPPORTED_EXTENSIONS = {".yaml", ".yml", ".json"} + + def __init__(self, language: str = "", **kwargs): + self.data: dict[str, str] = {k: v for k, v in kwargs.items() if isinstance(v, str)} + self.language: str = language.strip() + + def load_prompt_by_file( + self, + prompt_file_path: str | Path | None = None, + overwrite: bool = True, + ) -> "PromptHandler": + """Load prompts from a YAML or JSON file.""" + if prompt_file_path is None: + return self + + path = Path(prompt_file_path) + if not path.exists() or path.suffix.lower() not in self._SUPPORTED_EXTENSIONS: + return self + + try: + with path.open(encoding="utf-8") as f: + prompt_dict = yaml.safe_load(f) if path.suffix in (".yaml", ".yml") else json.load(f) + except (json.JSONDecodeError, yaml.YAMLError, OSError): + return self + + return self.load_prompt_dict(prompt_dict, overwrite) + + def load_prompt_by_class(self, cls: type, overwrite: bool = True) -> "PromptHandler": + """Load prompts from a YAML file named after the class.""" + try: + base_path = Path(inspect.getfile(cls)).with_suffix("") + except (TypeError, OSError): + return self + + for ext in (".yaml", ".yml"): + if (prompt_path := base_path.with_suffix(ext)).exists(): + return self.load_prompt_by_file(prompt_path, 
overwrite) + + return self + + def load_prompt_dict(self, prompt_dict: dict | None = None, overwrite: bool = True) -> "PromptHandler": + """Merge prompts from a dictionary.""" + if not prompt_dict: + return self + + for key, value in prompt_dict.items(): + if isinstance(value, str) and (overwrite or key not in self.data): + self.data[key] = value + + return self + + def get_prompt(self, prompt_name: str) -> str: + """Retrieve a prompt by name with language suffix fallback.""" + for key in (f"{prompt_name}_{self.language}", prompt_name) if self.language else (prompt_name,): + if key in self.data: + return self.data[key].strip() + + raise KeyError(f"Prompt '{prompt_name}' not found. Available: {list(self.data.keys())[:10]}") + + def has_prompt(self, prompt_name: str) -> bool: + """Check if a prompt exists.""" + return prompt_name in self.data or f"{prompt_name}_{self.language}" in self.data + + def list_prompts(self, language_filter: str | None = None) -> list[str]: + """List all available prompt names.""" + if not language_filter: + return list(self.data.keys()) + suffix = f"_{language_filter.strip()}" + return [k for k in self.data if k.endswith(suffix)] + + def prompt_format(self, prompt_name: str, validate: bool = True, **kwargs) -> str: + """Format a prompt with conditional line filtering and variable substitution.""" + prompt = self.get_prompt(prompt_name) + flags = {k: v for k, v in kwargs.items() if isinstance(v, bool)} + formats = {k: v for k, v in kwargs.items() if not isinstance(v, bool)} + + if flags: + lines = [] + for line in prompt.split("\n"): + remaining = line + should_include = False + for flag, enabled in flags.items(): + prefix = f"[{flag}]" + while remaining.startswith(prefix): + remaining = remaining[len(prefix) :] + if enabled: + should_include = True + if should_include or not any(line.startswith(f"[{f}]") for f in flags): + lines.append(remaining) + prompt = "\n".join(lines) + + if validate: + required = {f for _, f, _, _ in 
Formatter().parse(prompt) if f is not None} + if missing := required - set(formats.keys()): + raise ValueError(f"Missing format variables for '{prompt_name}': {sorted(missing)}") + + return prompt.format(**formats).strip() if formats else prompt.strip() + + def __repr__(self) -> str: + return f"PromptHandler(language='{self.language}', num_prompts={len(self.data)})" diff --git a/reme_cli/component/runtime_context.py b/reme_cli/component/runtime_context.py new file mode 100644 index 00000000..b56fb232 --- /dev/null +++ b/reme_cli/component/runtime_context.py @@ -0,0 +1,90 @@ +"""Runtime context for managing response states and asynchronous data streaming.""" + +import asyncio + +from .application_context import ApplicationContext +from ..enumeration import ChunkEnum +from ..schema import Response, StreamChunk + + +class RuntimeContext: + """Context for execution state, response metadata, and stream queues.""" + + def __init__(self, **kwargs): + """Initialize the context with all keyword arguments stored in data.""" + self.data: dict = kwargs + + @property + def response(self) -> Response: + """Get or create the response object.""" + return self.data.setdefault("response", Response()) + + @property + def stream_queue(self) -> asyncio.Queue: + """Get the stream queue.""" + return self.data["stream_queue"] + + @property + def application_context(self) -> ApplicationContext: + """Get the application context.""" + return self.data["application_context"] + + @classmethod + def from_context(cls, context: "RuntimeContext | None" = None, **kwargs) -> "RuntimeContext": + """Create a new context from an existing instance or keywords.""" + if context is None: + return cls(**kwargs) + context.data.update(kwargs) + return context + + async def _enqueue(self, chunk: StreamChunk) -> None: + """Internal helper to put a chunk into the queue if it exists.""" + if self.stream_queue: + await self.stream_queue.put(chunk) + + async def add_stream_string(self, chunk: str, chunk_type: 
ChunkEnum) -> "RuntimeContext": + """Enqueue a stream chunk from a raw string and type.""" + await self._enqueue(StreamChunk(chunk_type=chunk_type, chunk=chunk)) + return self + + async def add_stream_chunk(self, stream_chunk: StreamChunk) -> "RuntimeContext": + """Enqueue an existing stream chunk.""" + await self._enqueue(stream_chunk) + return self + + async def add_stream_done(self) -> "RuntimeContext": + """Enqueue a termination chunk to signal the end of the stream.""" + await self._enqueue(StreamChunk(chunk_type=ChunkEnum.DONE, chunk="", done=True)) + return self + + def add_response_error(self, e: Exception) -> "RuntimeContext": + """Record an exception into the response object.""" + self.response.success = False + self.response.answer = str(e) + return self + + def apply_mapping(self, mapping: dict[str, str]) -> "RuntimeContext": + """Copy internal values based on a source-to-target key map.""" + if not mapping: + return self + + for source, target in mapping.items(): + if source in self.data: + self.data[target] = self.data[source] + return self + + def validate_required_keys( + self, + required_keys: dict[str, bool], + context_name: str = "context", + ) -> "RuntimeContext": + """Ensure all required keys are present in the context. 
+ + Args: + required_keys: Dictionary mapping key names to boolean indicating if required + context_name: Name of the context for error messages (e.g., operator name) + """ + for key, is_required in required_keys.items(): + if is_required and key not in self.data: + raise ValueError(f"{context_name}: missing required input '{key}'") + return self diff --git a/reme_cli/component/service/__init__.py b/reme_cli/component/service/__init__.py new file mode 100644 index 00000000..13715368 --- /dev/null +++ b/reme_cli/component/service/__init__.py @@ -0,0 +1,9 @@ +"""Service components for exposing jobs via different protocols.""" + +from .base_service import BaseService +from .http_service import HttpService + +__all__ = [ + "BaseService", + "HttpService", +] diff --git a/reme_cli/component/service/base_service.py b/reme_cli/component/service/base_service.py new file mode 100644 index 00000000..897efd39 --- /dev/null +++ b/reme_cli/component/service/base_service.py @@ -0,0 +1,73 @@ +"""Abstract base class for service implementations.""" + +from abc import abstractmethod + +from ..base_component import BaseComponent +from ..job.base_job import BaseJob +from ...enumeration import ComponentEnum + + +class BaseService(BaseComponent): + """Abstract base class for services that expose jobs. + + Services provide different ways to invoke jobs (HTTP, CLI, MCP, etc.). + Subclasses must implement add_job to register jobs with the service. + """ + + component_type = ComponentEnum.SERVICE + + from ...application import Application + + def __init__(self, **kwargs): + """Initialize the service. + + Args: + **kwargs: Additional service-specific configuration. 
+ """ + super().__init__(**kwargs) + self.service = None + + async def _start(self, app_context=None) -> None: + """Default empty implementation for sync services.""" + + async def _close(self) -> None: + """Default empty implementation for sync services.""" + + @abstractmethod + def add_job(self, job: BaseJob) -> None: + """Register a job with the service. + + Args: + job: The job to register. + """ + + @abstractmethod + def build_service(self, app: "Application") -> None: + """Build the service. + + Args: + app: The application instance. + """ + + @abstractmethod + def start_service(self, app: "Application") -> None: + """Start the service.""" + + def add_jobs(self, app: "Application") -> None: + """Register all jobs from the application context.""" + for name, job in app.context.jobs.values(): + try: + self.add_job(job) + self.logger.info(f"Added job {name}") + except Exception as e: + self.logger.error(f"Failed to add job {name}: {e}") + + def run_app(self, app: "Application") -> None: + """Register all jobs from the application and start the service. + + Args: + app: The application containing jobs to register. 
+ """ + self.build_service(app) + self.add_jobs(app) + self.start_service(app) diff --git a/reme_cli/component/service/http_service.py b/reme_cli/component/service/http_service.py new file mode 100644 index 00000000..4b8f67b1 --- /dev/null +++ b/reme_cli/component/service/http_service.py @@ -0,0 +1,138 @@ +"""HTTP service implementation using FastAPI and uvicorn.""" + +import asyncio +import json +import os +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager + +import uvicorn +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import StreamingResponse + +from .base_service import BaseService +from ..component_registry import R +from ..job import BaseJob, StreamJob +from ...constants import REME_DEFAULT_HOST, REME_DEFAULT_PORT, REME_SERVICE_INFO +from ...schema import Request, Response +from ...utils import execute_stream_task + + +@R.register("http") +class HttpService(BaseService): + """HTTP service that exposes jobs as REST endpoints. + + Regular jobs return JSON responses, while StreamJobs return + server-sent events (SSE) for real-time streaming. + """ + + from ...application import Application + + def __init__(self, host: str = REME_DEFAULT_HOST, port: int = REME_DEFAULT_PORT, **kwargs): + """Initialize the HTTP service. + + Args: + host: Bind address for the server. + port: Port number for the server. + **kwargs: Additional arguments passed to uvicorn. + """ + super().__init__(**kwargs) + self.host: str = host + self.port: int = port + + def _add_job(self, job: BaseJob) -> None: + """Register a regular job as a POST endpoint. + + Args: + job: The job to register. 
+ """ + + async def execute_endpoint(request: Request) -> Response: + return await job(**request.model_dump(exclude_none=True)) + + self.service.post( + path=f"/{job.name}", + response_model=Response, + description=job.description, + )(execute_endpoint) + + def _add_stream_job(self, job: StreamJob) -> None: + """Register a stream job as an SSE endpoint. + + Args: + job: The stream job to register. + """ + + async def execute_stream_endpoint(request: Request) -> StreamingResponse: + stream_queue = asyncio.Queue() + task = asyncio.create_task( + job(stream_queue=stream_queue, **request.model_dump(exclude_none=True)), + ) + + async def generate_stream() -> AsyncGenerator[bytes, None]: + async for chunk in execute_stream_task( + stream_queue=stream_queue, + task=task, + task_name=job.name, + output_format="bytes", + ): + assert isinstance(chunk, bytes) + yield chunk + + return StreamingResponse(generate_stream(), media_type="text/event-stream") + + self.service.post(f"/{job.name}")(execute_stream_endpoint) + + def add_job(self, job: BaseJob) -> None: + """Register a job with the HTTP service. + + StreamJobs are registered as SSE endpoints, regular jobs as JSON endpoints. + + Args: + job: The job to register. + """ + if isinstance(job, StreamJob): + self._add_stream_job(job) + else: + self._add_job(job) + + def build_service(self, app: Application) -> None: + """Build the FastAPI application with CORS middleware. + + Args: + app: The application instance. 
+ """ + + @asynccontextmanager + async def lifespan(_: FastAPI): + await app.start() + service_info = json.dumps( + { + "host": self.host, + "port": self.port, + }, + ) + os.environ[REME_SERVICE_INFO] = service_info + self.logger.info(f"ReMe Service started: {REME_SERVICE_INFO}={service_info}") + yield + await app.close() + + self.service = FastAPI(title=app.config.app_name, lifespan=lifespan) + + self.service.add_middleware( + CORSMiddleware, # type: ignore[arg-type] + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + self.service.post("/health")(lambda: {"status": "healthy"}) + + def start_service(self, app: Application) -> None: + """Start the HTTP server. + + Args: + app: The application instance. + """ + uvicorn.run(self.service, host=self.host, port=self.port, **self.kwargs) diff --git a/reme_cli/config/__init__.py b/reme_cli/config/__init__.py new file mode 100644 index 00000000..2e0ce2d1 --- /dev/null +++ b/reme_cli/config/__init__.py @@ -0,0 +1,7 @@ +"""Config""" + +from .config_parser import parse_args + +__all__ = [ + "parse_args", +] diff --git a/reme_cli/config/config_parser.py b/reme_cli/config/config_parser.py new file mode 100644 index 00000000..c4a0d7d5 --- /dev/null +++ b/reme_cli/config/config_parser.py @@ -0,0 +1,146 @@ +"""Parser for YAML config with CLI argument overrides.""" + +import json +from pathlib import Path +from typing import Any + +import yaml + +# Config files are looked up relative to this module's directory +_CONFIG_DIR = Path(__file__).parent +_SUPPORTED_EXTS = (".yaml", ".yml", ".json") + + +# Pre-scan config directory: maps basename(without ext) -> Path +def _discover_configs() -> dict[str, Path]: + discovered: dict[str, Path] = {} + if _CONFIG_DIR.is_dir(): + for p in _CONFIG_DIR.iterdir(): + if p.is_file() and p.suffix in _SUPPORTED_EXTS: + discovered.setdefault(p.stem, p) + return discovered + + +_CONFIG_REGISTRY = _discover_configs() + + +def parse_dot_notation(dot_list: 
list[str]) -> dict: + """Parse "key.subkey=value" strings into nested dict.""" + result: dict = {} + for item in dot_list: + if "=" not in item: + raise ValueError(f"Invalid dot notation format (missing '='): {item}") + key_path, value_str = item.split("=", 1) + keys = key_path.split(".") + current = result + for key in keys[:-1]: + if key in current and not isinstance(current[key], dict): + raise ValueError(f"Cannot set nested key '{key_path}': '{key}' is already a value") + current = current.setdefault(key, {}) + current[keys[-1]] = _convert_value(value_str) + return result + + +def _convert_value(value_str: str) -> Any: + """Convert string to appropriate Python type. + + Only converts "true"/"false" (case-insensitive) to boolean. + Use JSON format (e.g., '"yes"', '"no"') to preserve these as strings. + """ + s = value_str.strip() + lower = s.lower() + + # Handle special values (null, bool) + if lower in ("none", "null"): + return None + if lower == "true": + return True + if lower == "false": + return False + + # Try numeric and JSON conversions + for converter in (int, float, json.loads): + try: + return converter(s) + except (ValueError, json.JSONDecodeError): + continue + + # Fallback to string + return s + + +def _load_yaml(name_or_path: str, encoding: str = "utf-8") -> dict: + """Load YAML or JSON file. + + First check if name_or_path matches a pre-discovered config (key in _CONFIG_REGISTRY). + If not, treat as a file path and load directly. + """ + # 1. Try pre-discovered configs first + if name_or_path in _CONFIG_REGISTRY: + return _read_config_file(_CONFIG_REGISTRY[name_or_path], encoding) + + # 2. Treat as file path + p = Path(name_or_path) + if p.suffix in _SUPPORTED_EXTS: + if not p.exists(): + raise FileNotFoundError(f"Config file not found: {p}") + return _read_config_file(p, encoding) + + known = ", ".join(sorted(_CONFIG_REGISTRY)) if _CONFIG_REGISTRY else "none" + raise FileNotFoundError(f"Config file not found: {name_or_path}. 
Available: {known}") + + +def _read_config_file(path: Path, encoding: str = "utf-8") -> dict: + """Read YAML or JSON file based on extension.""" + with path.open(encoding=encoding) as f: + if path.suffix == ".json": + result = json.load(f) + return result if result is not None else {} + else: + result = yaml.safe_load(f) + return result if result is not None else {} + + +def _deep_merge(base: dict, update: dict) -> dict: + """Recursively merge dicts.""" + result = base.copy() + for k, v in update.items(): + if k in result and isinstance(result[k], dict) and isinstance(v, dict): + result[k] = _deep_merge(result[k], v) + else: + result[k] = v + return result + + +def parse_args(*args, **kwargs) -> tuple[str, dict]: + """Parse CLI args: first arg is action, rest are config overrides. + + Usage: reme app config=paw.yaml service.name=test + Returns: (action, merged_config_dict) + """ + if not args: + raise ValueError("No arguments provided") + + first = args[0].lstrip("-") + if "=" in first: + raise ValueError(f"First argument must be action, got: {args[0]}") + + action = first + configs: list[dict] = [] + + for arg in args[1:]: + arg = arg.lstrip("-") + if arg.startswith("config="): + path = arg.split("=", 1)[1].strip() + if path: + configs.append(_load_yaml(path)) + elif "=" in arg: + configs.append(parse_dot_notation([arg])) + + configs.append(kwargs) + + merged: dict = {} + for cfg in configs: + merged = _deep_merge(merged, cfg) + + return action, merged diff --git a/reme_cli/config/paw.yaml b/reme_cli/config/paw.yaml new file mode 100644 index 00000000..20d5c392 --- /dev/null +++ b/reme_cli/config/paw.yaml @@ -0,0 +1,55 @@ +enable_logo: false +log_to_console: true +log_to_file: false + +service: + backend: cmd + +jobs: + - backend: base + name: test + description: "test job" + parameters: + type: object + properties: + name: + type: string + description: "name of the user" + steps: + - name: "test1" + backend: xxx + - name: "test2" + backend: xxx + +components: + 
as_llms: + default: + backend: openai + model_name: qwen3.6-plus + + as_llm_formatters: + default: + backend: openai + + embedding_models: + default: + backend: openai + dimensions: 1024 + use_dimensions: false + enable_cache: true + max_batch_size: 10 + max_cache_size: 2000 + max_input_length: 8192 + + file_stores: + default: + backend: chroma + embedding_model: default + store_name: "reme" + + file_watchers: + default: + backend: full + file_store: default + suffix_filters: [ ".md" ] + recursive: false diff --git a/reme_cli/constants.py b/reme_cli/constants.py new file mode 100644 index 00000000..fce66e64 --- /dev/null +++ b/reme_cli/constants.py @@ -0,0 +1,7 @@ +"""Constants""" + +REME_SERVICE_INFO = "REME_SERVICE_INFO" + +REME_DEFAULT_HOST = "127.0.0.1" + +REME_DEFAULT_PORT = 2333 diff --git a/reme_cli/enumeration/__init__.py b/reme_cli/enumeration/__init__.py new file mode 100644 index 00000000..9887ddee --- /dev/null +++ b/reme_cli/enumeration/__init__.py @@ -0,0 +1,9 @@ +"""Enumeration""" + +from .chunk_enum import ChunkEnum +from .component_enum import ComponentEnum + +__all__ = [ + "ChunkEnum", + "ComponentEnum", +] diff --git a/reme_cli/enumeration/chunk_enum.py b/reme_cli/enumeration/chunk_enum.py new file mode 100644 index 00000000..c64a5468 --- /dev/null +++ b/reme_cli/enumeration/chunk_enum.py @@ -0,0 +1,28 @@ +"""Chunk enumeration module. + +Defines the types of data chunks used in streaming responses. +""" + +from enum import Enum + + +class ChunkEnum(str, Enum): + """Enumeration of possible chunk categories for stream processing. + + This enum defines the various types of chunks that can be transmitted + during a streaming response from an LLM or agent system. 
+ """ + + THINK = "think" + + CONTENT = "content" + + TOOL_CALL = "tool_call" + + TOOL_RESULT = "tool_result" + + USAGE = "usage" + + ERROR = "error" + + DONE = "done" diff --git a/reme_cli/enumeration/component_enum.py b/reme_cli/enumeration/component_enum.py new file mode 100644 index 00000000..bce70d7b --- /dev/null +++ b/reme_cli/enumeration/component_enum.py @@ -0,0 +1,35 @@ +"""Component enumeration module. + +Defines the types of components that can be registered and used in the application. +""" + +from enum import Enum + + +class ComponentEnum(str, Enum): + """Enumeration of component types for dependency injection and registration. + + This enum defines the various component categories that can be registered + in the application's component container. Each component type represents + a specific role or functionality within the system. + """ + + BASE = "base" + + AS_LLM = "as_llm" + + AS_LLM_FORMATTER = "as_llm_formatter" + + EMBEDDING_MODEL = "embedding_model" + + FILE_STORE = "file_store" + + FILE_WATCHER = "file_watcher" + + SERVICE = "service" + + CLIENT = "client" + + STEP = "step" + + JOB = "job" diff --git a/reme_cli/jobs/__init__.py b/reme_cli/jobs/__init__.py new file mode 100644 index 00000000..c4d57960 --- /dev/null +++ b/reme_cli/jobs/__init__.py @@ -0,0 +1 @@ +"""Jobs""" diff --git a/reme_cli/reme.py b/reme_cli/reme.py new file mode 100644 index 00000000..03fd34fd --- /dev/null +++ b/reme_cli/reme.py @@ -0,0 +1,77 @@ +"""ReMe CLI application entry point.""" + +import asyncio +import sys + +from agentscope.formatter import FormatterBase +from agentscope.message import Msg +from agentscope.model import ChatModelBase +from agentscope.token import TokenCounterBase +from agentscope.tool import Toolkit, ToolResponse + +from .application import Application +from .component import R +from .config import parse_args +from .enumeration import ComponentEnum + + +class ReMe(Application): + """ReMe memory management application.""" + + async def 
summary_memory( + self, + messages: list[Msg], + as_llm: str | ChatModelBase = "default", + as_llm_formatter: str | FormatterBase = "default", + as_token_counter: str | TokenCounterBase = "default", + toolkit: Toolkit | None = None, + language: str = "zh", + max_input_length: float = 128 * 1024, + compact_ratio: float = 0.7, + timezone: str | None = None, + add_thinking_block: bool = True, + ) -> str: + """Summarize and compact memory messages.""" + + async def memory_search(self, query: str, max_results: int = 5, min_score: float = 0.1) -> ToolResponse: + """Search memory for relevant entries.""" + + async def dream( + self, + as_llm: str | ChatModelBase = "default", + as_llm_formatter: str | FormatterBase = "default", + as_token_counter: str | TokenCounterBase = "default", + toolkit: Toolkit | None = None, + language: str = "zh", + timezone: str | None = None, + ) -> str: + """Process and consolidate memories in background.""" + + async def proactive( + self, + as_llm: str | ChatModelBase = "default", + as_llm_formatter: str | FormatterBase = "default", + as_token_counter: str | TokenCounterBase = "default", + toolkit: Toolkit | None = None, + language: str = "zh", + timezone: str | None = None, + ) -> str: + """Generate proactive memory insights.""" + + +def main(): + """Entry point for ReMe CLI.""" + action, config = parse_args(sys.argv[1:]) + if action == "app": + reme = ReMe(**config) + reme.run_app() + + else: + backend: str = config.pop("backend", "http") + client_cls = R.get(ComponentEnum.CLIENT, backend) + client = client_cls(action=action, **config) + asyncio.run(client()) + + +if __name__ == "__main__": + main() diff --git a/reme_cli/schema/__init__.py b/reme_cli/schema/__init__.py new file mode 100644 index 00000000..89b7fe3d --- /dev/null +++ b/reme_cli/schema/__init__.py @@ -0,0 +1,21 @@ +"""Schema""" + +from .application_config import ApplicationConfig, ComponentConfig, JobConfig +from .base_node import BaseNode +from .file_chunk import FileChunk 
+from .file_metadata import FileMetadata +from .request import Request +from .response import Response +from .stream_chunk import StreamChunk + +__all__ = [ + "ApplicationConfig", + "ComponentConfig", + "JobConfig", + "BaseNode", + "FileChunk", + "FileMetadata", + "Request", + "Response", + "StreamChunk", +] diff --git a/reme_cli/schema/application_config.py b/reme_cli/schema/application_config.py new file mode 100644 index 00000000..1cf36c31 --- /dev/null +++ b/reme_cli/schema/application_config.py @@ -0,0 +1,80 @@ +"""Application configuration schema module. + +This module defines the configuration models for the ReMe CLI application, +including application-level settings, service configuration, and job definitions. +""" + +import os + +from pydantic import BaseModel, ConfigDict, Field + +from ..enumeration import ComponentEnum + + +class ComponentConfig(BaseModel): + """Base configuration for a component. + + This serves as the base class for all component configurations, + allowing extra fields to be defined dynamically. + + Attributes: + backend: The backend implementation class name for this component. + """ + + model_config = ConfigDict(extra="allow") + + backend: str = Field(default="", description="Backend implementation class name") + + +class JobConfig(ComponentConfig): + """Configuration for a job definition. + + A job represents a sequence of steps that can be executed + as part of the application workflow. + + Attributes: + name: Unique identifier name for the job. + description: Human-readable description of what the job does. + parameters: Job-level parameters passed to all steps. + steps: Ordered list of step configurations to execute. 
+ """ + + name: str = Field(default="", description="Unique job identifier") + description: str = Field(default="", description="Human-readable job description") + parameters: dict = Field(default_factory=dict, description="Job-level parameters") + steps: list[ComponentConfig] = Field(default_factory=list, description="Ordered list of step configs") + + +class ApplicationConfig(BaseModel): + """Root configuration for the ReMe CLI application. + + This model contains all configuration settings needed to initialize + and run the application, including service endpoints, job definitions, + and component registry. + + Attributes: + app_name: Display name of the application. + working_dir: Working directory for runtime files and logs. + enable_logo: Whether to display the ASCII logo on startup. + language: Default language for LLM interactions. + log_to_console: Whether to output logs to console. + log_to_file: Whether to write logs to a file. + mcp_servers: MCP server configurations indexed by name. + service: Service endpoint configuration. + jobs: List of job definitions. + components: Component registry indexed by component type and name. 
+ """ + + app_name: str = Field(default=os.getenv("APP_NAME", "ReMe"), description="Application display name") + working_dir: str = Field(default=".reme", description="Working directory for runtime files") + enable_logo: bool = Field(default=False, description="Whether to show ASCII logo on startup") + language: str = Field(default="", description="Default language for LLM interactions") + log_to_console: bool = Field(default=True, description="Whether to log to console") + log_to_file: bool = Field(default=True, description="Whether to log to file") + mcp_servers: dict[str, dict] = Field(default_factory=dict, description="MCP server configurations") + service: ComponentConfig = Field(default_factory=ComponentConfig, description="Service endpoint config") + jobs: list[JobConfig] = Field(default_factory=list, description="Job definitions") + components: dict[ComponentEnum, dict[str, ComponentConfig]] = Field( + default_factory=dict, + description="Component registry by type", + ) diff --git a/reme_cli/schema/base_node.py b/reme_cli/schema/base_node.py new file mode 100644 index 00000000..064061a7 --- /dev/null +++ b/reme_cli/schema/base_node.py @@ -0,0 +1,29 @@ +"""Base node schema module. + +This module defines the BaseNode model, which serves as the foundational +data structure for nodes in the knowledge graph or document processing pipeline. +""" + +from uuid import uuid4 + +from pydantic import BaseModel, Field + + +class BaseNode(BaseModel): + """Base node model for graph and document structures. + + This model represents a single node in the knowledge graph or + a chunk in the document processing pipeline. It contains text content, + optional embeddings, and associated metadata. + + Attributes: + id: Unique identifier for the node, auto-generated if not provided. + text: Text content of the node. + embedding: Optional vector embedding of the text content. + metadata: Additional metadata associated with the node. 
+ """ + + id: str = Field(default_factory=lambda: uuid4().hex, description="Unique node identifier") + text: str = Field(default="", description="Text content of the node") + embedding: list[float] | None = Field(default=None, description="Vector embedding of text") + metadata: dict = Field(default_factory=dict, description="Additional metadata") diff --git a/reme_cli/schema/file_chunk.py b/reme_cli/schema/file_chunk.py new file mode 100644 index 00000000..b08a95f3 --- /dev/null +++ b/reme_cli/schema/file_chunk.py @@ -0,0 +1,53 @@ +"""File chunk schema module. + +This module defines the FileChunk model for representing chunks of file content +in the document processing and retrieval pipeline. +""" + +from pydantic import Field + +from .base_node import BaseNode + + +class FileChunk(BaseNode): + """A chunk of file content with positional and scoring metadata. + + Represents a contiguous section of a file that has been extracted + for processing, embedding, or retrieval. Inherits text and embedding + capabilities from BaseNode. + + Attributes: + path: File path relative to workspace root. + start_line: Starting line number (1-indexed) in the source file. + end_line: Ending line number (1-indexed) in the source file. + hash: Hash of the chunk content for deduplication. + scores: Search relevance scores indexed by score type. + + Properties: + score: Final combined score for search result ranking. + merge_key: Unique key for merging duplicate search results. + """ + + path: str = Field(..., description="File path relative to workspace") + start_line: int = Field(..., description="Starting line number (1-indexed)") + end_line: int = Field(..., description="Ending line number (1-indexed)") + hash: str = Field(..., description="Hash of chunk content for deduplication") + scores: dict[str, float] = Field(default_factory=dict, description="Search scores by type") + + @property + def score(self) -> float: + """Get the final score for search result ranking. 
+ + Returns: + The combined score, or 0.0 if not set. + """ + return self.scores.get("score", 0.0) + + @property + def merge_key(self) -> str: + """Generate a unique key for merging search results. + + Returns: + A string key in format "path:start_line:end_line". + """ + return f"{self.path}:{self.start_line}:{self.end_line}" diff --git a/reme_cli/schema/file_metadata.py b/reme_cli/schema/file_metadata.py new file mode 100644 index 00000000..d385224c --- /dev/null +++ b/reme_cli/schema/file_metadata.py @@ -0,0 +1,33 @@ +"""File metadata schema module. + +This module defines the FileMetadata model for tracking file state and +content information in the document processing pipeline. +""" + +from pydantic import BaseModel, Field + + +class FileMetadata(BaseModel): + """File metadata with optional extended fields. + + Stores essential file information for tracking changes and + managing the document processing pipeline. Optional fields allow + for different usage patterns (e.g., just tracking vs. full content). + + Attributes: + hash: Hash of the file content for change detection. + mtime_ms: Last modification time in milliseconds since epoch. + size: File size in bytes. + path: Relative path to the file within the workspace. + content: Parsed content from the file (optional, memory-intensive). + chunk_count: Number of chunks extracted from this file. + metadata: Additional file-specific metadata. 
+ """ + + hash: str = Field(..., description="Hash of file content for change detection") + mtime_ms: float = Field(..., description="Last modification time in milliseconds") + size: int = Field(..., description="File size in bytes") + path: str | None = Field(default=None, description="Relative path within workspace") + content: str | None = Field(default=None, description="Parsed content (optional)") + chunk_count: int | None = Field(default=None, description="Number of extracted chunks") + metadata: dict = Field(default_factory=dict, description="Additional file metadata") diff --git a/reme_cli/schema/request.py b/reme_cli/schema/request.py new file mode 100644 index 00000000..ce8f3cb0 --- /dev/null +++ b/reme_cli/schema/request.py @@ -0,0 +1,25 @@ +"""Request schema module. + +This module defines the Request model for handling incoming requests +in the application service layer. +""" + +from pydantic import BaseModel, ConfigDict, Field + + +class Request(BaseModel): + """Request model for service endpoints. + + Represents an incoming request with optional metadata and + extensible fields for various request types. + + The model uses ConfigDict with extra="allow" to support + dynamically added fields while maintaining type safety. + + Attributes: + metadata: Request-level metadata for tracking and context. + """ + + model_config = ConfigDict(extra="allow") + + metadata: dict = Field(default_factory=dict, description="Request metadata for context") diff --git a/reme_cli/schema/response.py b/reme_cli/schema/response.py new file mode 100644 index 00000000..20e90027 --- /dev/null +++ b/reme_cli/schema/response.py @@ -0,0 +1,28 @@ +"""Response schema module. + +This module defines the standardized data structure for model output +responses used throughout the application. +""" + +from typing import Any + +from pydantic import BaseModel, Field, ConfigDict + + +class Response(BaseModel): + """Represents a structured response with result, status, and metadata. 
+ + This model provides a consistent interface for returning results + from operations, LLM calls, and service endpoints. + + Attributes: + answer: The main response content, typically a string or structured data. + success: Whether the operation completed successfully. + metadata: Additional context and diagnostic information. + """ + + model_config = ConfigDict(extra="allow") + + answer: str | Any = Field(default="", description="Response content or result data") + success: bool = Field(default=True, description="Operation success status") + metadata: dict = Field(default_factory=dict, description="Additional context and diagnostics") diff --git a/reme_cli/schema/stream_chunk.py b/reme_cli/schema/stream_chunk.py new file mode 100644 index 00000000..4fc8e687 --- /dev/null +++ b/reme_cli/schema/stream_chunk.py @@ -0,0 +1,28 @@ +"""Stream chunk schema module. + +This module defines the StreamChunk model for handling streaming +responses in the application, particularly for LLM outputs. +""" + +from pydantic import BaseModel, Field + +from ..enumeration import ChunkEnum + + +class StreamChunk(BaseModel): + """A chunk of streaming response data. + + Represents a single chunk in a streaming response sequence, + commonly used for LLM outputs that are delivered incrementally. + + Attributes: + chunk_type: Type identifier for the chunk content. + chunk: The actual chunk data (string, dict, or list). + done: Whether this is the final chunk in the stream. + metadata: Additional metadata about this chunk. 
+ """ + + chunk_type: ChunkEnum = Field(default=ChunkEnum.CONTENT, description="Type of chunk content") + chunk: str | dict | list = Field(default="", description="Chunk payload data") + done: bool = Field(default=False, description="Whether stream is complete") + metadata: dict = Field(default_factory=dict, description="Chunk metadata") diff --git a/reme_cli/utils/__init__.py b/reme_cli/utils/__init__.py new file mode 100644 index 00000000..f0031822 --- /dev/null +++ b/reme_cli/utils/__init__.py @@ -0,0 +1,22 @@ +"""Utility modules""" + +from .case_converter import camel_to_snake, snake_to_camel +from .chunking_utils import chunk_markdown +from .common_utils import hash_text, execute_stream_task +from .logger_utils import get_logger +from .logo_utils import print_logo +from .similarity_utils import cosine_similarity, batch_cosine_similarity +from .singleton import singleton + +__all__ = [ + "camel_to_snake", + "snake_to_camel", + "chunk_markdown", + "hash_text", + "execute_stream_task", + "get_logger", + "print_logo", + "cosine_similarity", + "batch_cosine_similarity", + "singleton", +] diff --git a/reme_cli/utils/case_converter.py b/reme_cli/utils/case_converter.py new file mode 100644 index 00000000..a6a1d23e --- /dev/null +++ b/reme_cli/utils/case_converter.py @@ -0,0 +1,46 @@ +"""Case conversion utility for PascalCase, camelCase, and snake_case. + +Provides bidirectional conversion between snake_case and PascalCase, +with special handling for common acronyms (LLM, API, URL, etc.). +""" + +import re + +# Acronyms that should remain uppercase in Pascal/camelCase +_ACRONYMS = frozenset({"LLM", "API", "URL", "HTTP", "JSON", "XML", "AI", "MCP"}) +_ACRONYM_MAP = {word.lower(): word for word in _ACRONYMS} + + +def camel_to_snake(content: str) -> str: + """Convert PascalCase or camelCase to snake_case. + + Handles acronyms correctly by normalizing them before conversion. + For example, "OpenAILLMClient" becomes "open_ai_llm_client". 
+ + Args: + content: A string in PascalCase or camelCase format. + + Returns: + The converted snake_case string. + """ + # Normalize acronyms to title case (e.g., LLM -> Llm) to assist regex splitting + for word in _ACRONYMS: + content = content.replace(word, word.capitalize()) + + # Insert underscores between case transitions and convert to lowercase + return re.sub(r"(? str: + """Convert snake_case to PascalCase. + + Preserves defined acronyms in uppercase form. + For example, "open_ai_llm_client" becomes "OpenAILLMClient". + + Args: + content: A string in snake_case format. + + Returns: + The converted PascalCase string with acronyms preserved. + """ + return "".join(_ACRONYM_MAP.get(part.lower(), part.capitalize()) for part in content.split("_") if part) diff --git a/reme_cli/utils/chunking_utils.py b/reme_cli/utils/chunking_utils.py new file mode 100644 index 00000000..87ee0df7 --- /dev/null +++ b/reme_cli/utils/chunking_utils.py @@ -0,0 +1,144 @@ +"""Markdown file chunking utilities. + +Provides functionality to split Markdown documents into smaller chunks +while maintaining overlap between consecutive chunks for context preservation. +""" + +from .common_utils import hash_text +from ..schema import FileChunk + + +def chunk_markdown( + text: str, + path: str, + chunk_tokens: int, + overlap: int, +) -> list[FileChunk]: + """Split Markdown text into chunks with configurable size and overlap. + + Implements a sliding window approach to chunk Markdown content while + preserving context through overlap between consecutive chunks. Token + counts are approximated using a 1:4 ratio (1 token ≈ 4 characters). + + Args: + text: Input Markdown text to be chunked. + path: File path identifier for the source document. + chunk_tokens: Maximum number of tokens per chunk. Will be converted + to characters using the 1:4 ratio, with a minimum of 32 characters. + overlap: Number of overlapping tokens between consecutive chunks. + Helps maintain context across chunk boundaries. 
+ + Returns: + A list of FileChunk objects, each containing: + - id: Unique identifier based on path, line numbers, and hash + - path: The source file path + - start_line: Starting line number (1-indexed) + - end_line: Ending line number (1-indexed) + - text: The chunk content + - hash: SHA-256 hash of the chunk content + + Examples: + >>> text = "# Header\\nParagraph content here.\\n\\n## Subheader" + >>> chunks = chunk_markdown(text, "doc.md", chunk_tokens=100, overlap=20) + >>> len(chunks) + 1 + >>> chunks[0].path + 'doc.md' + """ + if not text.strip(): + return [] + + lines = text.split("\n") + + # Convert tokens to characters (~1 token = 4 chars) + max_chars = max(32, chunk_tokens * 4) + overlap_chars = max(0, overlap * 4) + + chunks: list[FileChunk] = [] + + # Currently building chunk + current: list[dict] = [] # [{'line': str, 'line_no': int}] + current_chars = 0 + + def flush() -> None: + """Add current chunk to results list.""" + if not current: + return + + first_entry = current[0] + last_entry = current[-1] + + if not first_entry or not last_entry: + return + + chunk_text = "\n".join([entry["line"] for entry in current]) + start_line = first_entry["line_no"] + end_line = last_entry["line_no"] + + chunk_hash = hash_text(chunk_text) + + chunks.append( + FileChunk( + id=hash_text(f"{path}:{start_line}:{end_line}:{chunk_hash}:{len(chunks)}"), + path=path, + start_line=start_line, + end_line=end_line, + text=chunk_text, + hash=chunk_hash, + ), + ) + + def carry_overlap() -> None: + """Keep overlapping part and clear the rest.""" + nonlocal current, current_chars + + if overlap_chars <= 0 or not current: + current = [] + current_chars = 0 + return + + acc = 0 + kept = [] + + # Collect lines from the end until reaching overlap size + for j in range(len(current) - 1, -1, -1): + entry = current[j] + if not entry: + continue + + acc += len(entry["line"]) + 1 # +1 for newline + kept.insert(0, entry) # Insert at the beginning to maintain order + + if acc >= 
overlap_chars: + break + + current = kept + current_chars = sum(len(entry["line"]) + 1 for entry in kept) + + for i, line in enumerate(lines): + line_no = i + 1 + + # Split long lines into multiple segments + segments = [] + if not line: # Empty line + segments.append("") + else: + # If line is too long, split by maximum character count + for start in range(0, len(line), max_chars): + segments.append(line[start : start + max_chars]) + + for segment in segments: + line_size = len(segment) + 1 # +1 for newline + + # If adding current segment would exceed the limit, flush current chunk + if current_chars + line_size > max_chars and current: + flush() + carry_overlap() + + current.append({"line": segment, "line_no": line_no}) + current_chars += line_size + + # Process the final chunk + flush() + + return [c for c in chunks if c.text.strip()] diff --git a/reme_cli/utils/common_utils.py b/reme_cli/utils/common_utils.py new file mode 100644 index 00000000..6c86d487 --- /dev/null +++ b/reme_cli/utils/common_utils.py @@ -0,0 +1,152 @@ +"""Common utility functions for the application. + +Provides general-purpose utilities including text hashing and async +stream processing for task execution. 
+""" + +import asyncio +import hashlib +from collections.abc import AsyncGenerator, Coroutine +from typing import Any, Literal + +from .logger_utils import get_logger +from ..enumeration import ChunkEnum +from ..schema import StreamChunk + + +def run_coro_safely(coro: Coroutine[Any, Any, Any]) -> Any | asyncio.Task[Any]: + """Run a coroutine in the current event loop or a new one if none exists.""" + try: + # Attempt to retrieve the event loop associated with the current thread + loop = asyncio.get_running_loop() + + except RuntimeError: + # Start a new event loop to run the coroutine to completion + return asyncio.run(coro) + + else: + # Schedule the coroutine as a background task in the active loop + return loop.create_task(coro) + + +def hash_text(text: str, encoding: str = "utf-8") -> str: + """Generate SHA-256 hash of text content. + + Creates a cryptographic hash suitable for content identification + and deduplication purposes. + + Args: + text: Input text to hash. + encoding: Character encoding for the text. Defaults to "utf-8". + + Returns: + Hexadecimal string representation of the SHA-256 hash + (64 characters). + """ + return hashlib.sha256(text.encode(encoding)).hexdigest() + + +async def execute_stream_task( + stream_queue: asyncio.Queue, + task: asyncio.Task, + task_name: str | None = None, + output_format: Literal["str", "bytes", "chunk"] = "str", +) -> AsyncGenerator[str | bytes | StreamChunk, None]: + """Core stream flow execution logic. + + Handles streaming from a queue while monitoring the task completion. + Properly manages errors and resource cleanup. + + This async generator yields streaming data from a background task, + handling the coordination between queue-based communication and + task lifecycle management. It ensures proper cleanup even when + exceptions occur. + + Args: + stream_queue: Queue to receive StreamChunk objects from the + background task. + task: Background asyncio Task executing the flow. 
This task + will be monitored for completion and exceptions. + task_name: Optional flow name for logging purposes. Used in + error messages for debugging. + output_format: Output format control: + - "str": SSE-formatted string (e.g., "data:{json}\\n\\n") + - "bytes": SSE-formatted bytes for HTTP responses + - "chunk": Raw StreamChunk objects for further processing + + Yields: + - str: SSE-formatted data when output_format="str" + - bytes: SSE-formatted data when output_format="bytes" + - StreamChunk: Raw chunk objects when output_format="chunk" + + Raises: + Exception: Re-raises any exception from the background task. + """ + logger = get_logger() + try: + while True: + # Wait for next chunk or check if task failed + get_chunk = asyncio.create_task(stream_queue.get()) + done, _pending = await asyncio.wait( + {get_chunk, task}, + return_when=asyncio.FIRST_COMPLETED, + ) + + # Priority 1: Check if main task finished (may have exception) + if task in done: + # Task finished - check for exceptions first + exc = task.exception() + if exc: + log_msg = f"Task error in {task_name}: {exc}" if task_name else f"Task error: {exc}" + logger.exception(log_msg) + raise exc + + # Task completed successfully - drain remaining chunks if any + if get_chunk in done: + chunk: StreamChunk = get_chunk.result() + if output_format == "chunk": + yield chunk + if chunk.done: + break + else: + if chunk.done: + yield b"data:[DONE]\n\n" if output_format == "bytes" else "data:[DONE]\n\n" + break + data = f"data:{chunk.model_dump_json()}\n\n" + yield data.encode() if output_format == "bytes" else data + else: + # No more chunks, task completed + get_chunk.cancel() + if output_format == "chunk": + yield StreamChunk(chunk_type=ChunkEnum.DONE, chunk="", done=True) + else: + yield b"data:[DONE]\n\n" if output_format == "bytes" else "data:[DONE]\n\n" + break + + elif get_chunk in done: + # Got a chunk from the queue (task still running) + chunk: StreamChunk = get_chunk.result() + + # Handle raw chunk 
mode + if output_format == "chunk": + yield chunk + if chunk.done: + break + continue + + # Handle SSE format mode (str or bytes) + if chunk.done: + yield b"data:[DONE]\n\n" if output_format == "bytes" else "data:[DONE]\n\n" + break + + data = f"data:{chunk.model_dump_json()}\n\n" + yield data.encode() if output_format == "bytes" else data + + finally: + # Ensure task is canceled if still running to avoid resource leaks + if not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass diff --git a/reme_cli/utils/logger_utils.py b/reme_cli/utils/logger_utils.py new file mode 100644 index 00000000..0c2eeb70 --- /dev/null +++ b/reme_cli/utils/logger_utils.py @@ -0,0 +1,86 @@ +"""Logging configuration module for application-wide tracing. + +Provides a centralized logging facility using Loguru with support for +both console and file-based output, automatic rotation, and retention. +""" + +import os +import sys +from datetime import datetime + +from loguru import logger + +_initialized = False + + +def get_logger( + log_dir: str = "logs", + level: str = "INFO", + log_to_console: bool = True, + log_to_file: bool = True, + force_init: bool = False, +): + """Get a configured logger instance. + + Automatically initializes on first call. Subsequent calls return + the same logger without re-initializing unless force_init=True. + + This function configures the global Loguru logger with: + - Colorized console output (optional) + - File output with daily rotation (optional) + - 7-day retention with ZIP compression + - Consistent timestamp and location formatting + + Args: + log_dir: Directory path for log files. Created if it doesn't exist. + Defaults to "logs" in the current working directory. + level: Logging level threshold. One of: DEBUG, INFO, WARNING, + ERROR, CRITICAL. Defaults to "INFO". + log_to_console: Whether to print logs to stdout. Defaults to True. + log_to_file: Whether to write logs to file. Defaults to True. 
+ force_init: Force re-initialization even if already initialized. + Useful for changing log configuration mid-application. + Defaults to False. + + Returns: + The configured Loguru logger instance. + """ + global _initialized + + if _initialized and not force_init: + return logger + + # Remove default handler to avoid duplicate logs + logger.remove() + + # Configure colorized console logging if enabled + if log_to_console: + logger.add( + sink=sys.stdout, + level=level, + format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {file}:{line} | {function} | {message}", + colorize=True, + ) + + # Configure file-based logging if enabled + if log_to_file: + try: + os.makedirs(log_dir, exist_ok=True) + + current_ts = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + log_filepath = os.path.join(log_dir, f"{current_ts}.log") + + logger.add( + log_filepath, + level=level, + rotation="00:00", # Rotate at midnight + retention="7 days", # Keep logs for 7 days + compression="zip", # Compress rotated logs + encoding="utf-8", + format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {file}:{line} | {function} | {message}", + ) + except Exception as e: + logger.error(f"Error configuring file logging: {e}") + + _initialized = True + return logger diff --git a/reme_cli/utils/logo_utils.py b/reme_cli/utils/logo_utils.py new file mode 100644 index 00000000..7c1475b9 --- /dev/null +++ b/reme_cli/utils/logo_utils.py @@ -0,0 +1,91 @@ +"""Terminal branding and configuration display utilities.""" + +import importlib.metadata +from typing import TYPE_CHECKING + +from rich.console import Console, Group +from rich.panel import Panel +from rich.table import Table +from rich.text import Text + +if TYPE_CHECKING: + from ..schema import ApplicationConfig + + +def get_version(package_name: str) -> str: + """Return the installed version of a package or 'unknown'.""" + try: + return importlib.metadata.version(package_name) + except importlib.metadata.PackageNotFoundError: + return "" + + +def print_logo(app_config: 
"ApplicationConfig"): + """Print a stylized ASCII logo and service metadata to the console.""" + ascii_art = [ + r" ██████╗ ███████╗ ███╗ ███╗ ███████╗ ", + r" ██╔══██╗ ██╔════╝ ████╗ ████║ ██╔════╝ ", + r" ██████╔╝ █████╗ ██╔████╔██║ █████╗ ", + r" ██╔══██╗ ██╔══╝ ██║╚██╔╝██║ ██╔══╝ ", + r" ██║ ██║ ███████╗ ██║ ╚═╝ ██║ ███████╗ ", + r" ╚═╝ ╚═╝ ╚══════╝ ╚═╝ ╚═╝ ╚══════╝ ", + ] + + start_color = (85, 239, 196) + end_color = (162, 155, 254) + + logo_text = Text() + for line in ascii_art: + line_len = max(1, len(line) - 1) + for i, char in enumerate(line): + # Calculate gradient shift per character + ratio = i / line_len + rgb = tuple(int(s + (e - s) * ratio) for s, e in zip(start_color, end_color)) + logo_text.append(char, style=f"bold rgb({rgb[0]},{rgb[1]},{rgb[2]})") + logo_text.append("\n") + + # Layout configuration info + info_table = Table.grid(padding=(0, 1)) + info_table.add_column(style="bold", justify="center") + info_table.add_column(style="bold cyan", justify="left") + info_table.add_column(style="white", justify="left") + + # Get service config (ComponentConfig with extra="allow") + service = app_config.service + backend = service.backend + + # Add core service info + info_table.add_row("📦", "Backend:", backend) + + match backend: + case "http": + host = service.model_extra.get("host", "localhost") if service.model_extra else "localhost" + port = service.model_extra.get("port", 8000) if service.model_extra else 8000 + info_table.add_row("🔗", "URL:", f"http://{host}:{port}") + info_table.add_row("📚", "FastAPI:", Text(get_version("fastapi"), style="dim")) + case "mcp": + transport = service.model_extra.get("transport", "stdio") if service.model_extra else "stdio" + info_table.add_row("🚌", "Transport:", transport) + if transport != "stdio": + host = service.model_extra.get("host", "localhost") if service.model_extra else "localhost" + port = service.model_extra.get("port", 8000) if service.model_extra else 8000 + url = f"http://{host}:{port}" + if transport 
== "sse": + url += "/sse" + info_table.add_row("🔗", "URL:", url) + info_table.add_row("📚", "FastMCP:", Text(get_version("fastmcp"), style="dim")) + + info_table.add_row("🚀", "ReMe:", Text(get_version("reme-ai"), style="dim")) + + # Render layout within a panel + panel = Panel( + Group(logo_text, info_table), + title=app_config.app_name, + title_align="left", + border_style="dim", + padding=(1, 4), + expand=False, + ) + + # use justify="center" to adjust position + Console().print(Group("\n", panel, "\n")) diff --git a/reme_cli/utils/similarity_utils.py b/reme_cli/utils/similarity_utils.py new file mode 100644 index 00000000..8c7031a0 --- /dev/null +++ b/reme_cli/utils/similarity_utils.py @@ -0,0 +1,96 @@ +"""Vector similarity computation utilities. + +Provides functions for calculating cosine similarity between vectors, +with support for both single vectors and batch operations using NumPy. +""" + +import numpy as np + + +def cosine_similarity(vec1: list[float], vec2: list[float]) -> float: + """Calculate the cosine similarity between two numeric vectors. + + Cosine similarity measures the cosine of the angle between two vectors, + returning a value between -1 (opposite) and 1 (identical direction). + + Args: + vec1: First vector as a list of floats. + vec2: Second vector as a list of floats. + + Returns: + Cosine similarity value in range [-1.0, 1.0]. + Returns 0.0 if either vector has zero magnitude. + + Raises: + ValueError: If vectors have different lengths. 
+ + Examples: + >>> cosine_similarity([1.0, 0.0], [1.0, 0.0]) + 1.0 + >>> cosine_similarity([1.0, 0.0], [0.0, 1.0]) + 0.0 + >>> cosine_similarity([1.0, 1.0], [-1.0, -1.0]) + -1.0 + """ + if len(vec1) != len(vec2): + raise ValueError(f"Vectors must have same length: {len(vec1)} != {len(vec2)}") + + dot_product = sum(a * b for a, b in zip(vec1, vec2)) + magnitude1 = sum(a * a for a in vec1) ** 0.5 + magnitude2 = sum(b * b for b in vec2) ** 0.5 + + if magnitude1 == 0 or magnitude2 == 0: + return 0.0 + + return dot_product / (magnitude1 * magnitude2) + + +def batch_cosine_similarity(nd_array1: np.ndarray, nd_array2: np.ndarray) -> np.ndarray: + """Calculate cosine similarity matrix between two batches of vectors. + + Efficiently computes pairwise cosine similarities using matrix operations. + + Args: + nd_array1: Matrix of shape (batch_size1, emb_size) representing + the first batch of embedding vectors. + nd_array2: Matrix of shape (batch_size2, emb_size) representing + the second batch of embedding vectors. + + Returns: + Similarity matrix of shape (batch_size1, batch_size2) where + result[i, j] is the cosine similarity between nd_array1[i] and + nd_array2[j]. Values are in range [-1.0, 1.0]. + + Raises: + ValueError: If embedding dimensions don't match between arrays. + + Examples: + >>> import numpy as np + >>> arr1 = np.array([[1.0, 0.0], [0.0, 1.0]]) + >>> arr2 = np.array([[1.0, 0.0], [1.0, 1.0]]) + >>> batch_cosine_similarity(arr1, arr2) + array([[1. , 0.70710678], + [0. 
, 0.70710678]]) + """ + if nd_array1.shape[1] != nd_array2.shape[1]: + raise ValueError( + f"Embedding dimensions must match: {nd_array1.shape[1]} != {nd_array2.shape[1]}", + ) + + # Compute dot products: (batch_size1, emb_size) @ (emb_size, batch_size2) + # Result shape: (batch_size1, batch_size2) + dot_products = np.dot(nd_array1, nd_array2.T) + + # Compute L2 norms for each vector + norms1 = np.linalg.norm(nd_array1, axis=1) # Shape: (batch_size1,) + norms2 = np.linalg.norm(nd_array2, axis=1) # Shape: (batch_size2,) + + # Compute outer product of norms: (batch_size1, 1) @ (1, batch_size2) + # Result shape: (batch_size1, batch_size2) + norm_products = np.outer(norms1, norms2) + + # Avoid division by zero + norm_products = np.where(norm_products == 0, 1e-10, norm_products) + + # Compute cosine similarities + return dot_products / norm_products diff --git a/reme_cli/utils/singleton.py b/reme_cli/utils/singleton.py new file mode 100644 index 00000000..7503c6b3 --- /dev/null +++ b/reme_cli/utils/singleton.py @@ -0,0 +1,43 @@ +"""Singleton pattern implementation using a class decorator. + +Provides a thread-safe singleton decorator that ensures only one instance +of a decorated class exists throughout the application lifecycle. +""" + +import threading +from typing import Any, Callable, TypeVar + +T = TypeVar("T") + + +def singleton(cls: type[T]) -> Callable[..., T]: + """A class decorator that ensures only one instance of a class exists. + + Thread-safe implementation using a lock to prevent race conditions + during instance creation. + + Args: + cls: The class to decorate with singleton behavior. + + Returns: + A wrapper function that returns the single instance. + """ + _instance: dict[type[T], T] = {} + _lock = threading.Lock() + + def _singleton(*args: Any, **kwargs: Any) -> T: + """Return the existing instance or create a new one if it doesn't exist. + + Args: + *args: Positional arguments passed to the class constructor. 
+ **kwargs: Keyword arguments passed to the class constructor. + + Returns: + The single instance of the decorated class. + """ + with _lock: + if cls not in _instance: + _instance[cls] = cls(*args, **kwargs) + return _instance[cls] + + return _singleton