Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
.PHONY: install format lint test all check
.PHONY: install format lint test all check proto

# Define variables
PYTHON = python3
Expand Down Expand Up @@ -30,3 +30,14 @@ test:

# Run format, lint, and test
check: format lint test

# Regenerate protobuf gencode for mirix/queue/*.proto.
# Uses grpcio-tools pinned in pyproject.toml (>=1.66.0,<1.67.0) so the
# checked-in *_pb2.py / *_pb2.pyi / *_pb2_grpc.py files are reproducible.
proto:
$(POETRY) run python -m grpc_tools.protoc \
-I. \
--python_out=. \
--pyi_out=. \
--grpc_python_out=. \
mirix/queue/message.proto
388 changes: 225 additions & 163 deletions mirix/agent/agent.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions mirix/helpers/message_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,5 @@ def prepare_input_message_create(message: MessageCreate, agent_id: str, **kwargs
otid=message.otid,
sender_id=message.sender_id,
group_id=message.group_id,
session_id=message.session_id,
)
32 changes: 29 additions & 3 deletions mirix/orm/message.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, List, Optional

from sqlalchemy import JSON, ForeignKey, Index
from sqlalchemy import JSON, CheckConstraint, ForeignKey, Index, String
from sqlalchemy.orm import Mapped, declared_attr, mapped_column, relationship

from mirix.orm.custom_columns import (
Expand All @@ -10,8 +10,12 @@
)
from mirix.orm.mixins import AgentMixin, OrganizationMixin, UserMixin
from mirix.orm.sqlalchemy_base import SqlalchemyBase
from mirix.schemas.message import Message as PydanticMessage
from mirix.schemas.message import ToolReturn
from mirix.schemas.message import (
SESSION_ID_MAX_LEN,
SESSION_ID_SQL_PATTERN,
Message as PydanticMessage,
ToolReturn,
)
from mirix.schemas.mirix_message_content import MessageContent
from mirix.schemas.mirix_message_content import TextContent as PydanticTextContent
from mirix.schemas.openai.openai import ToolCall as OpenAIToolCall
Expand All @@ -32,6 +36,21 @@ class Message(SqlalchemyBase, OrganizationMixin, UserMixin, AgentMixin):
Index("ix_messages_created_at", "created_at", "id"),
Index("ix_messages_client_user", "client_id", "user_id"),
Index("ix_messages_agent_client_user", "agent_id", "client_id", "user_id"),
# Accelerates "list messages of an agent in a session, newest first".
Index(
"ix_messages_agent_session_created_at",
"agent_id",
"session_id",
"created_at",
),
# Backstop the app-level validator: DB must never store an invalid session_id.
# Uses the Postgres `~` operator, so emit the constraint only on Postgres
# (SQLite, used for some local/test setups, has no POSIX regex operator).
# Pattern derived from mirix.schemas.message so there's one source of truth.
CheckConstraint(
f"session_id IS NULL OR session_id ~ '{SESSION_ID_SQL_PATTERN}'",
name="ck_messages_session_id_format",
).ddl_if(dialect="postgresql"),
)
__pydantic_model__ = PydanticMessage

Expand Down Expand Up @@ -78,6 +97,13 @@ class Message(SqlalchemyBase, OrganizationMixin, UserMixin, AgentMixin):
nullable=True,
doc="The id of the sender of the message, can be an identity id or agent id",
)
session_id: Mapped[Optional[str]] = mapped_column(
String(SESSION_ID_MAX_LEN),
nullable=True,
doc="Top-level conversation/session identifier for grouping messages. "
"Enforced by app validator and DB CHECK constraint "
f"(pattern {SESSION_ID_SQL_PATTERN}).",
)

# Relationships
agent: Mapped["Agent"] = relationship("Agent", back_populates="messages", lazy="selectin")
Expand Down
3 changes: 3 additions & 0 deletions mirix/queue/message.proto
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ message MessageCreate {

// Optional multi-agent group id
optional string group_id = 7;

// Optional top-level session identifier (see mirix.schemas.message.MessageCreate.session_id).
optional string session_id = 8;
}

// List of message content parts (for structured content)
Expand Down
32 changes: 16 additions & 16 deletions mirix/queue/message_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions mirix/queue/message_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ class User(_message.Message):
def __init__(self, id: _Optional[str] = ..., organization_id: _Optional[str] = ..., name: _Optional[str] = ..., status: _Optional[str] = ..., timezone: _Optional[str] = ..., created_at: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., updated_at: _Optional[_Union[_timestamp_pb2.Timestamp, _Mapping]] = ..., is_deleted: bool = ...) -> None: ...

class MessageCreate(_message.Message):
__slots__ = ("role", "text_content", "structured_content", "name", "otid", "sender_id", "group_id")
__slots__ = ("role", "text_content", "structured_content", "name", "otid", "sender_id", "group_id", "session_id")
class Role(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
__slots__ = ()
ROLE_UNSPECIFIED: _ClassVar[MessageCreate.Role]
Expand All @@ -79,14 +79,16 @@ class MessageCreate(_message.Message):
OTID_FIELD_NUMBER: _ClassVar[int]
SENDER_ID_FIELD_NUMBER: _ClassVar[int]
GROUP_ID_FIELD_NUMBER: _ClassVar[int]
SESSION_ID_FIELD_NUMBER: _ClassVar[int]
role: MessageCreate.Role
text_content: str
structured_content: MessageContentList
name: str
otid: str
sender_id: str
group_id: str
def __init__(self, role: _Optional[_Union[MessageCreate.Role, str]] = ..., text_content: _Optional[str] = ..., structured_content: _Optional[_Union[MessageContentList, _Mapping]] = ..., name: _Optional[str] = ..., otid: _Optional[str] = ..., sender_id: _Optional[str] = ..., group_id: _Optional[str] = ...) -> None: ...
session_id: str
def __init__(self, role: _Optional[_Union[MessageCreate.Role, str]] = ..., text_content: _Optional[str] = ..., structured_content: _Optional[_Union[MessageContentList, _Mapping]] = ..., name: _Optional[str] = ..., otid: _Optional[str] = ..., sender_id: _Optional[str] = ..., group_id: _Optional[str] = ..., session_id: _Optional[str] = ...) -> None: ...

class MessageContentList(_message.Message):
__slots__ = ("parts",)
Expand Down
2 changes: 2 additions & 0 deletions mirix/queue/queue_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,8 @@ async def put_messages(
proto_msg.sender_id = msg.sender_id
if msg.group_id:
proto_msg.group_id = msg.group_id
if msg.session_id:
proto_msg.session_id = msg.session_id

proto_input_messages.append(proto_msg)

Expand Down
1 change: 1 addition & 0 deletions mirix/queue/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def _convert_proto_message_to_pydantic(self, proto_msg) -> "MessageCreate":
otid=proto_msg.otid if proto_msg.HasField("otid") else None,
sender_id=proto_msg.sender_id if proto_msg.HasField("sender_id") else None,
group_id=proto_msg.group_id if proto_msg.HasField("group_id") else None,
session_id=proto_msg.session_id if proto_msg.HasField("session_id") else None,
filter_tags=None,
)

Expand Down
Loading
Loading