diff --git a/cmd/gateway.go b/cmd/gateway.go index b62bca88..82b0c682 100644 --- a/cmd/gateway.go +++ b/cmd/gateway.go @@ -8,6 +8,7 @@ import ( "os/signal" "path/filepath" "syscall" + "time" "github.com/google/uuid" @@ -18,6 +19,7 @@ import ( "github.com/nextlevelbuilder/goclaw/internal/channels" "github.com/nextlevelbuilder/goclaw/internal/channels/discord" "github.com/nextlevelbuilder/goclaw/internal/channels/feishu" + "github.com/nextlevelbuilder/goclaw/internal/channels/googlechat" slackchannel "github.com/nextlevelbuilder/goclaw/internal/channels/slack" "github.com/nextlevelbuilder/goclaw/internal/channels/telegram" "github.com/nextlevelbuilder/goclaw/internal/channels/whatsapp" @@ -309,6 +311,18 @@ func runGateway() { }() } + // Sweep orphan traces left by previous crashes (running > 1h) + if pgStores.Tracing != nil { + go func() { + n, err := pgStores.Tracing.SweepOrphanTraces(context.Background(), time.Hour) + if err != nil { + slog.Warn("orphan trace sweep failed", "error", err) + } else if n > 0 { + slog.Info("orphan trace sweep complete", "swept", n) + } + }() + } + // Redis cache: compiled via build tags. Build with 'go build -tags redis' to enable. redisClient := initRedisClient(cfg) defer shutdownRedis(redisClient) @@ -689,7 +703,7 @@ func runGateway() { if mcpMgr != nil { mcpToolLister = mcpMgr } - agentsH, skillsH, tracesH, mcpH, customToolsH, channelInstancesH, providersH, delegationsH, builtinToolsH, pendingMessagesH := wireHTTP(pgStores, cfg.Gateway.Token, msgBus, toolsReg, providerRegistry, permPE.IsOwner, gatewayAddr, mcpToolLister) + agentsH, skillsH, tracesH, mcpH, customToolsH, channelInstancesH, providersH, delegationsH, builtinToolsH, pendingMessagesH, projectsH := wireHTTP(pgStores, cfg.Gateway.Token, msgBus, toolsReg, providerRegistry, permPE.IsOwner, gatewayAddr, mcpToolLister) if agentsH != nil { server.SetAgentsHandler(agentsH) } @@ -705,6 +719,9 @@ func runGateway() { if mcpH != nil { server.SetMCPHandler(mcpH) } + if projectsH != nil { + server.SetProjectHandler(projectsH) + } if customToolsH != nil { server.SetCustomToolsHandler(customToolsH) } @@ -753,8 +770,10 @@ func runGateway() { // Supports media from any agent workspace (each agent has its own workspace from DB). server.SetFilesHandler(httpapi.NewFilesHandler(cfg.Gateway.Token)) - // Storage file management — browse/delete files under ~/.goclaw/ (excluding skills dirs). - server.SetStorageHandler(httpapi.NewStorageHandler(config.ExpandHome("~/.goclaw"), cfg.Gateway.Token)) + // Storage file management — browse/delete files under the resolved workspace directory. + // Uses GOCLAW_WORKSPACE (or default ~/.goclaw/workspace) so it works correctly + // in Docker deployments where volumes are mounted outside ~/.goclaw/. + server.SetStorageHandler(httpapi.NewStorageHandler(workspace, cfg.Gateway.Token)) // Media upload endpoint — accepts multipart file uploads, returns temp path + MIME type. server.SetMediaUploadHandler(httpapi.NewMediaUploadHandler(cfg.Gateway.Token)) @@ -808,6 +827,7 @@ func runGateway() { instanceLoader.RegisterFactory(channels.TypeZaloPersonal, zalopersonal.FactoryWithPendingStore(pgStores.PendingMessages)) instanceLoader.RegisterFactory(channels.TypeWhatsApp, whatsapp.Factory) instanceLoader.RegisterFactory(channels.TypeSlack, slackchannel.FactoryWithPendingStore(pgStores.PendingMessages)) + instanceLoader.RegisterFactory(channels.TypeGoogleChat, googlechat.FactoryWithPendingStore(pgStores.PendingMessages)) if err := instanceLoader.LoadAll(context.Background()); err != nil { slog.Error("failed to load channel instances from DB", "error", err) } @@ -819,6 +839,11 @@ func runGateway() { // Register channels/instances/links/teams RPC methods wireChannelRPCMethods(server, pgStores, channelMgr, agentRouter, msgBus) + // Register party mode WS RPC methods + if pgStores.Party != nil { + methods.NewPartyMethods(pgStores.Party, pgStores.Agents, providerRegistry, msgBus).Register(server.Router()) + } + // Wire channel event subscribers (cache invalidation, pairing, cascade disable) wireChannelEventSubscribers(msgBus, server, pgStores, channelMgr, instanceLoader, pairingMethods, cfg) @@ -1026,7 +1051,7 @@ func runGateway() { channelMgr.SetContactCollector(contactCollector) // propagate to all channel handlers } - go consumeInboundMessages(ctx, msgBus, agentRouter, cfg, sched, channelMgr, consumerTeamStore, quotaChecker, delegateMgr, pgStores.Sessions, pgStores.Agents, contactCollector) + go consumeInboundMessages(ctx, msgBus, agentRouter, cfg, sched, channelMgr, consumerTeamStore, quotaChecker, delegateMgr, pgStores.Sessions, pgStores.Agents, contactCollector, pgStores.Projects) // Task recovery ticker: re-dispatches stale/pending team tasks on startup and periodically. var taskTicker *tasks.TaskTicker diff --git a/cmd/gateway_channels_setup.go b/cmd/gateway_channels_setup.go index d0825ab5..e09bb89a 100644 --- a/cmd/gateway_channels_setup.go +++ b/cmd/gateway_channels_setup.go @@ -13,6 +13,7 @@ import ( "github.com/nextlevelbuilder/goclaw/internal/channels" "github.com/nextlevelbuilder/goclaw/internal/channels/discord" "github.com/nextlevelbuilder/goclaw/internal/channels/feishu" + "github.com/nextlevelbuilder/goclaw/internal/channels/googlechat" slackchannel "github.com/nextlevelbuilder/goclaw/internal/channels/slack" "github.com/nextlevelbuilder/goclaw/internal/channels/telegram" "github.com/nextlevelbuilder/goclaw/internal/channels/whatsapp" @@ -97,6 +98,16 @@ func registerConfigChannels(cfg *config.Config, channelMgr *channels.Manager, ms slog.Info("feishu/lark channel enabled (config)") } } + + if cfg.Channels.GoogleChat.Enabled && cfg.Channels.GoogleChat.ServiceAccountFile != "" && instanceLoader == nil { + gc, err := googlechat.New(cfg.Channels.GoogleChat, msgBus, nil) + if err != nil { + slog.Error("failed to initialize google chat channel", "error", err) + } else { + channelMgr.RegisterChannel(channels.TypeGoogleChat, gc) + slog.Info("google chat channel enabled (config)") + } + } } // wireChannelRPCMethods registers WS RPC methods for channels, instances, agent links, and teams. diff --git a/cmd/gateway_consumer.go b/cmd/gateway_consumer.go index 9f6cd542..db12ac56 100644 --- a/cmd/gateway_consumer.go +++ b/cmd/gateway_consumer.go @@ -28,7 +28,7 @@ import ( // and routes them through the scheduler/agent loop, then publishes the response back. // Also handles subagent announcements: routes them through the parent agent's session // (matching TS subagent-announce.ts pattern) so the agent can reformulate for the user. -func consumeInboundMessages(ctx context.Context, msgBus *bus.MessageBus, agents *agent.Router, cfg *config.Config, sched *scheduler.Scheduler, channelMgr *channels.Manager, teamStore store.TeamStore, quotaChecker *channels.QuotaChecker, delegateMgr *tools.DelegateManager, sessStore store.SessionStore, agentStore store.AgentStore, contactCollector *store.ContactCollector) { +func consumeInboundMessages(ctx context.Context, msgBus *bus.MessageBus, agents *agent.Router, cfg *config.Config, sched *scheduler.Scheduler, channelMgr *channels.Manager, teamStore store.TeamStore, quotaChecker *channels.QuotaChecker, delegateMgr *tools.DelegateManager, sessStore store.SessionStore, agentStore store.AgentStore, contactCollector *store.ContactCollector, projectStore store.ProjectStore) { slog.Info("inbound message consumer started") // Inbound message deduplication (matching TS src/infra/dedupe.ts + inbound-dedupe.ts). @@ -64,7 +64,11 @@ func consumeInboundMessages(ctx context.Context, msgBus *bus.MessageBus, agents } } - agentLoop, err := agents.Get(agentID) + // Resolve project for this chat (nil projectStore = backward compatible) + channelType := resolveChannelType(channelMgr, msg.Channel) + projectID, projectOverrides := resolveProjectOverrides(ctx, projectStore, channelType, msg.ChatID) + + agentLoop, err := agents.GetForProject(agentID, projectID, projectOverrides) if err != nil { slog.Warn("inbound: agent not found", "agent", agentID, "channel", msg.Channel) return @@ -321,7 +325,7 @@ func consumeInboundMessages(ctx context.Context, msgBus *bus.MessageBus, agents Media: reqMedia, ForwardMedia: fwdMedia, Channel: msg.Channel, - ChannelType: resolveChannelType(channelMgr, msg.Channel), + ChannelType: channelType, ChatID: msg.ChatID, PeerKind: peerKind, LocalKey: msg.Metadata["local_key"], @@ -333,6 +337,8 @@ func consumeInboundMessages(ctx context.Context, msgBus *bus.MessageBus, agents ToolAllow: msg.ToolAllow, ExtraSystemPrompt: extraPrompt, SkillFilter: skillFilter, + ProjectID: projectID, + ProjectOverrides: projectOverrides, }, scheduler.ScheduleOpts{ MaxConcurrent: maxConcurrent, }) @@ -456,6 +462,7 @@ func consumeInboundMessages(ctx context.Context, msgBus *bus.MessageBus, agents origPeerKind := msg.Metadata["origin_peer_kind"] origLocalKey := msg.Metadata["origin_local_key"] origChannelType := resolveChannelType(channelMgr, origChannel) + saProjectID, saProjectOverrides := resolveProjectOverrides(ctx, projectStore, origChannelType, msg.ChatID) parentAgent := msg.Metadata["parent_agent"] if parentAgent == "" { parentAgent = "default" @@ -529,6 +536,8 @@ func consumeInboundMessages(ctx context.Context, msgBus *bus.MessageBus, agents Stream: false, ParentTraceID: parentTraceID, ParentRootSpanID: parentRootSpanID, + ProjectID: saProjectID, + ProjectOverrides: saProjectOverrides, } // Handle announce asynchronously with per-session serialization. // The mutex ensures concurrent announces for the same session wait for @@ -590,6 +599,7 @@ func consumeInboundMessages(ctx context.Context, msgBus *bus.MessageBus, agents origPeerKind := msg.Metadata["origin_peer_kind"] origLocalKey := msg.Metadata["origin_local_key"] origChannelType := resolveChannelType(channelMgr, origChannel) + dlgProjectID, dlgProjectOverrides := resolveProjectOverrides(ctx, projectStore, origChannelType, msg.ChatID) parentAgent := msg.Metadata["parent_agent"] if parentAgent == "" { parentAgent = "default" @@ -660,6 +670,8 @@ func consumeInboundMessages(ctx context.Context, msgBus *bus.MessageBus, agents Stream: false, ParentTraceID: parentTraceID, ParentRootSpanID: parentRootSpanID, + ProjectID: dlgProjectID, + ProjectOverrides: dlgProjectOverrides, } // Same per-session serialization as subagent announce above. @@ -714,6 +726,7 @@ func consumeInboundMessages(ctx context.Context, msgBus *bus.MessageBus, agents origPeerKind := msg.Metadata["origin_peer_kind"] origLocalKey := msg.Metadata["origin_local_key"] origChannelType := resolveChannelType(channelMgr, origChannel) + hoProjectID, hoProjectOverrides := resolveProjectOverrides(ctx, projectStore, origChannelType, msg.ChatID) targetAgent := msg.AgentID if targetAgent == "" { targetAgent = cfg.ResolveDefaultAgentID() @@ -744,16 +757,18 @@ func consumeInboundMessages(ctx context.Context, msgBus *bus.MessageBus, agents outMeta := buildAnnounceOutMeta(origLocalKey) outCh := sched.Schedule(ctx, scheduler.LaneDelegate, agent.RunRequest{ - SessionKey: sessionKey, - Message: msg.Content, - Channel: origChannel, - ChannelType: origChannelType, - ChatID: msg.ChatID, - PeerKind: origPeerKind, - LocalKey: origLocalKey, - UserID: announceUserID, - RunID: fmt.Sprintf("handoff-%s", msg.Metadata["handoff_id"]), - Stream: false, + SessionKey: sessionKey, + Message: msg.Content, + Channel: origChannel, + ChannelType: origChannelType, + ChatID: msg.ChatID, + PeerKind: origPeerKind, + LocalKey: origLocalKey, + UserID: announceUserID, + RunID: fmt.Sprintf("handoff-%s", msg.Metadata["handoff_id"]), + Stream: false, + ProjectID: hoProjectID, + ProjectOverrides: hoProjectOverrides, }) go func(origCh, chatID string, meta map[string]string) { @@ -784,6 +799,7 @@ func consumeInboundMessages(ctx context.Context, msgBus *bus.MessageBus, agents origPeerKind := msg.Metadata["origin_peer_kind"] origLocalKey := msg.Metadata["origin_local_key"] origChannelType := resolveChannelType(channelMgr, origChannel) + tmProjectID, tmProjectOverrides := resolveProjectOverrides(ctx, projectStore, origChannelType, msg.ChatID) targetAgent := msg.AgentID // team_message sets AgentID to the target agent key if targetAgent == "" { targetAgent = cfg.ResolveDefaultAgentID() @@ -820,16 +836,18 @@ func consumeInboundMessages(ctx context.Context, msgBus *bus.MessageBus, agents outMeta := buildAnnounceOutMeta(origLocalKey) outCh := sched.Schedule(ctx, scheduler.LaneDelegate, agent.RunRequest{ - SessionKey: sessionKey, - Message: msg.Content, - Channel: origChannel, - ChannelType: origChannelType, - ChatID: msg.ChatID, - PeerKind: origPeerKind, - LocalKey: origLocalKey, - UserID: announceUserID, - RunID: fmt.Sprintf("teammate-%s-%s", msg.Metadata["from_agent"], msg.Metadata["to_agent"]), - Stream: false, + SessionKey: sessionKey, + Message: msg.Content, + Channel: origChannel, + ChannelType: origChannelType, + ChatID: msg.ChatID, + PeerKind: origPeerKind, + LocalKey: origLocalKey, + UserID: announceUserID, + RunID: fmt.Sprintf("teammate-%s-%s", msg.Metadata["from_agent"], msg.Metadata["to_agent"]), + Stream: false, + ProjectID: tmProjectID, + ProjectOverrides: tmProjectOverrides, }) go func(origCh, chatID, senderID string, meta, inMeta map[string]string) { diff --git a/cmd/gateway_consumer_process.go b/cmd/gateway_consumer_process.go index d699e4bd..cc7fd9f0 100644 --- a/cmd/gateway_consumer_process.go +++ b/cmd/gateway_consumer_process.go @@ -30,7 +30,7 @@ func makeSchedulerRunFunc(agents *agent.Router, cfg *config.Config) scheduler.Ru } } - loop, err := agents.Get(agentID) + loop, err := agents.GetForProject(agentID, req.ProjectID, req.ProjectOverrides) if err != nil { return nil, fmt.Errorf("agent %s not found: %w", agentID, err) } diff --git a/cmd/gateway_consumer_project.go b/cmd/gateway_consumer_project.go new file mode 100644 index 00000000..e772c68f --- /dev/null +++ b/cmd/gateway_consumer_project.go @@ -0,0 +1,30 @@ +package cmd + +import ( + "context" + "log/slog" + + "github.com/nextlevelbuilder/goclaw/internal/store" +) + +// resolveProjectOverrides looks up the project for a chat and returns its ID + MCP env overrides. +// Returns empty values if no project is configured (backward compatible). +func resolveProjectOverrides(ctx context.Context, projectStore store.ProjectStore, channelType, chatID string) (string, map[string]map[string]string) { + if projectStore == nil || channelType == "" || chatID == "" { + return "", nil + } + project, err := projectStore.GetProjectByChatID(ctx, channelType, chatID) + if err != nil { + slog.Warn("project.resolve_failed", "channelType", channelType, "chatID", chatID, "error", err) + return "", nil + } + if project == nil { + return "", nil + } + overrides, err := projectStore.GetMCPOverridesMap(ctx, project.ID) + if err != nil { + slog.Warn("project.overrides_failed", "project", project.Slug, "error", err) + return project.ID.String(), nil + } + return project.ID.String(), overrides +} diff --git a/cmd/gateway_consumer_project_test.go b/cmd/gateway_consumer_project_test.go new file mode 100644 index 00000000..018147fe --- /dev/null +++ b/cmd/gateway_consumer_project_test.go @@ -0,0 +1,168 @@ +package cmd + +import ( + "context" + "errors" + "testing" + + "github.com/google/uuid" + "github.com/nextlevelbuilder/goclaw/internal/store" +) + +// stubProjectStore implements store.ProjectStore for testing. +type stubProjectStore struct { + store.ProjectStore // embed interface — panics on unimplemented methods + + project *store.Project + overrides map[string]map[string]string + chatErr error + overErr error +} + +func (s *stubProjectStore) GetProjectByChatID(_ context.Context, _, _ string) (*store.Project, error) { + return s.project, s.chatErr +} + +func (s *stubProjectStore) GetMCPOverridesMap(_ context.Context, _ uuid.UUID) (map[string]map[string]string, error) { + return s.overrides, s.overErr +} + +func TestResolveProjectOverrides(t *testing.T) { + testProjectID := uuid.New() + + tests := []struct { + name string + store store.ProjectStore + channelType string + chatID string + wantProjectID string + wantOverrides map[string]map[string]string + }{ + { + name: "nil store returns empty (backward compat)", + store: nil, + channelType: "telegram", + chatID: "-100123", + wantProjectID: "", + wantOverrides: nil, + }, + { + name: "empty channelType returns empty", + store: &stubProjectStore{}, + channelType: "", + chatID: "-100123", + wantProjectID: "", + wantOverrides: nil, + }, + { + name: "empty chatID returns empty", + store: &stubProjectStore{}, + channelType: "telegram", + chatID: "", + wantProjectID: "", + wantOverrides: nil, + }, + { + name: "no project found returns empty (not an error)", + store: &stubProjectStore{ + project: nil, + chatErr: nil, + }, + channelType: "telegram", + chatID: "-100999", + wantProjectID: "", + wantOverrides: nil, + }, + { + name: "project found with overrides", + store: &stubProjectStore{ + project: &store.Project{ + BaseModel: store.BaseModel{ID: testProjectID}, + Slug: "xpos", + }, + overrides: map[string]map[string]string{ + "gitlab": {"GITLAB_PROJECT_PATH": "duhd/xpos"}, + "atlassian": {"JIRA_PROJECT_KEY": "XPOS"}, + }, + }, + channelType: "telegram", + chatID: "-100123", + wantProjectID: testProjectID.String(), + wantOverrides: map[string]map[string]string{ + "gitlab": {"GITLAB_PROJECT_PATH": "duhd/xpos"}, + "atlassian": {"JIRA_PROJECT_KEY": "XPOS"}, + }, + }, + { + name: "project found but no overrides configured", + store: &stubProjectStore{ + project: &store.Project{ + BaseModel: store.BaseModel{ID: testProjectID}, + Slug: "empty-proj", + }, + overrides: map[string]map[string]string{}, + }, + channelType: "telegram", + chatID: "-100456", + wantProjectID: testProjectID.String(), + wantOverrides: map[string]map[string]string{}, + }, + { + name: "DB error on project lookup — graceful degradation", + store: &stubProjectStore{ + chatErr: errors.New("connection refused"), + }, + channelType: "telegram", + chatID: "-100123", + wantProjectID: "", + wantOverrides: nil, + }, + { + name: "project found but overrides query fails — returns projectID only", + store: &stubProjectStore{ + project: &store.Project{ + BaseModel: store.BaseModel{ID: testProjectID}, + Slug: "xpos", + }, + overErr: errors.New("timeout"), + }, + channelType: "telegram", + chatID: "-100123", + wantProjectID: testProjectID.String(), + wantOverrides: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + gotID, gotOverrides := resolveProjectOverrides(ctx, tt.store, tt.channelType, tt.chatID) + + if gotID != tt.wantProjectID { + t.Errorf("projectID: got %q, want %q", gotID, tt.wantProjectID) + } + if tt.wantOverrides == nil { + if gotOverrides != nil { + t.Errorf("overrides: got %v, want nil", gotOverrides) + } + return + } + if len(gotOverrides) != len(tt.wantOverrides) { + t.Errorf("overrides len: got %d, want %d", len(gotOverrides), len(tt.wantOverrides)) + return + } + for server, wantEnv := range tt.wantOverrides { + gotEnv, ok := gotOverrides[server] + if !ok { + t.Errorf("missing server %q in overrides", server) + continue + } + for k, wantV := range wantEnv { + if gotV := gotEnv[k]; gotV != wantV { + t.Errorf("overrides[%q][%q]: got %q, want %q", server, k, gotV, wantV) + } + } + } + }) + } +} diff --git a/cmd/gateway_http_handlers.go b/cmd/gateway_http_handlers.go index b73b0baa..c664fa9d 100644 --- a/cmd/gateway_http_handlers.go +++ b/cmd/gateway_http_handlers.go @@ -9,8 +9,8 @@ import ( "github.com/nextlevelbuilder/goclaw/internal/tools" ) -// wireHTTP creates HTTP handlers (agents + skills + traces + MCP + custom tools + channel instances + providers + delegations + builtin tools + pending messages). -func wireHTTP(stores *store.Stores, token string, msgBus *bus.MessageBus, toolsReg *tools.Registry, providerReg *providers.Registry, isOwner func(string) bool, gatewayAddr string, mcpToolLister httpapi.MCPToolLister) (*httpapi.AgentsHandler, *httpapi.SkillsHandler, *httpapi.TracesHandler, *httpapi.MCPHandler, *httpapi.CustomToolsHandler, *httpapi.ChannelInstancesHandler, *httpapi.ProvidersHandler, *httpapi.DelegationsHandler, *httpapi.BuiltinToolsHandler, *httpapi.PendingMessagesHandler) { +// wireHTTP creates HTTP handlers (agents + skills + traces + MCP + custom tools + channel instances + providers + delegations + builtin tools + pending messages + projects). +func wireHTTP(stores *store.Stores, token string, msgBus *bus.MessageBus, toolsReg *tools.Registry, providerReg *providers.Registry, isOwner func(string) bool, gatewayAddr string, mcpToolLister httpapi.MCPToolLister) (*httpapi.AgentsHandler, *httpapi.SkillsHandler, *httpapi.TracesHandler, *httpapi.MCPHandler, *httpapi.CustomToolsHandler, *httpapi.ChannelInstancesHandler, *httpapi.ProvidersHandler, *httpapi.DelegationsHandler, *httpapi.BuiltinToolsHandler, *httpapi.PendingMessagesHandler, *httpapi.ProjectHandler) { var agentsH *httpapi.AgentsHandler var skillsH *httpapi.SkillsHandler var tracesH *httpapi.TracesHandler @@ -21,6 +21,7 @@ func wireHTTP(stores *store.Stores, token string, msgBus *bus.MessageBus, toolsR var delegationsH *httpapi.DelegationsHandler var builtinToolsH *httpapi.BuiltinToolsHandler var pendingMessagesH *httpapi.PendingMessagesHandler + var projectsH *httpapi.ProjectHandler if stores != nil && stores.Agents != nil { var summoner *httpapi.AgentSummoner @@ -75,5 +76,9 @@ func wireHTTP(stores *store.Stores, token string, msgBus *bus.MessageBus, toolsR pendingMessagesH = httpapi.NewPendingMessagesHandler(stores.PendingMessages, stores.Agents, token, providerReg) } - return agentsH, skillsH, tracesH, mcpH, customToolsH, channelInstancesH, providersH, delegationsH, builtinToolsH, pendingMessagesH + if stores != nil && stores.Projects != nil { + projectsH = httpapi.NewProjectHandler(stores.Projects, token) + } + + return agentsH, skillsH, tracesH, mcpH, customToolsH, channelInstancesH, providersH, delegationsH, builtinToolsH, pendingMessagesH, projectsH } diff --git a/cmd/gateway_managed.go b/cmd/gateway_managed.go index 2481b8fd..55bc067e 100644 --- a/cmd/gateway_managed.go +++ b/cmd/gateway_managed.go @@ -385,7 +385,7 @@ func wireExtras( // avoiding import cycle between tools and agent packages. if stores.AgentLinks != nil && stores.Agents != nil { runAgentFn := func(ctx context.Context, agentKey string, req tools.DelegateRunRequest) (*tools.DelegateRunResult, error) { - loop, err := agentRouter.Get(agentKey) + loop, err := agentRouter.GetForProject(agentKey, req.ProjectID, req.ProjectOverrides) if err != nil { return nil, err } @@ -408,6 +408,8 @@ func wireExtras( ParentAgentID: req.ParentAgentID, WorkspaceChannel: req.WorkspaceChannel, WorkspaceChatID: req.WorkspaceChatID, + ProjectID: req.ProjectID, + ProjectOverrides: req.ProjectOverrides, }) if err != nil { return nil, err diff --git a/cmd/gateway_providers.go b/cmd/gateway_providers.go index 395b4259..fb69e102 100644 --- a/cmd/gateway_providers.go +++ b/cmd/gateway_providers.go @@ -319,7 +319,8 @@ func registerProvidersFromDB(registry *providers.Registry, provStore store.Provi prov.WithProviderType(p.ProviderType) registry.Register(prov) default: - prov := providers.NewOpenAIProvider(p.Name, p.APIKey, p.APIBase, "") + defaultModel := extractDefaultModel(p.Settings) + prov := providers.NewOpenAIProvider(p.Name, p.APIKey, p.APIBase, defaultModel) prov.WithProviderType(p.ProviderType) if p.ProviderType == store.ProviderMiniMax { prov.WithChatPath("/text/chatcompletion_v2") @@ -329,3 +330,17 @@ func registerProvidersFromDB(registry *providers.Registry, provStore store.Provi slog.Info("registered provider from DB", "name", p.Name) } } + +// extractDefaultModel reads default_model from a provider's settings JSONB. +func extractDefaultModel(settings json.RawMessage) string { + if len(settings) == 0 { + return "" + } + var s struct { + DefaultModel string `json:"default_model"` + } + if json.Unmarshal(settings, &s) == nil { + return s.DefaultModel + } + return "" +} diff --git a/internal/agent/loop.go b/internal/agent/loop.go index 319e0077..6ae63bf6 100644 --- a/internal/agent/loop.go +++ b/internal/agent/loop.go @@ -86,6 +86,14 @@ func (l *Loop) runLoop(ctx context.Context, req RunRequest) (*RunResult, error) ctx = tools.WithWorkspaceChatID(ctx, req.WorkspaceChatID) } + // Project scope propagation (message arrival → delegation chain). + if req.ProjectID != "" { + ctx = tools.WithToolProjectID(ctx, req.ProjectID) + } + if req.ProjectOverrides != nil { + ctx = tools.WithToolProjectOverrides(ctx, req.ProjectOverrides) + } + // Per-user workspace isolation. // Workspace path comes from user_agent_profiles (includes channel segment // for cross-channel isolation). Cached in userWorkspaces to avoid repeated DB queries. @@ -337,7 +345,7 @@ func (l *Loop) runLoop(ctx context.Context, req RunRequest) (*RunResult, error) reminder := "[System] " + strings.Join(parts, "\n\n") messages = append(messages, providers.Message{Role: "user", Content: reminder}, - providers.Message{Role: "assistant", Content: "I see the task status. Let me handle accordingly."}, + // No assistant prefill — thinking models reject it. ) } } @@ -377,7 +385,8 @@ func (l *Loop) runLoop(ctx context.Context, req RunRequest) (*RunResult, error) // If the LLM creates tasks but forgets to spawn, inject a reminder. var teamTaskCreates int // count of team_tasks action=create calls var teamTaskSpawns int // count of spawn calls with team_task_id - var teamTaskRetried bool // only retry once to prevent infinite loops + var teamTaskRetried bool // only retry once to prevent infinite loops + var interruptRetried bool // only retry interrupted streams once // Inject retry hook so channels can update placeholder on LLM retries. ctx = providers.WithRetryHook(ctx, func(attempt, maxAttempts int, err error) { @@ -442,7 +451,7 @@ func (l *Loop) runLoop(ctx context.Context, req RunRequest) (*RunResult, error) Tools: toolDefs, Model: l.model, Options: map[string]any{ - providers.OptMaxTokens: 8192, + providers.OptMaxTokens: l.maxTokens, providers.OptTemperature: 0.7, providers.OptSessionKey: req.SessionKey, providers.OptAgentID: l.agentUUID.String(), @@ -518,6 +527,50 @@ func (l *Loop) runLoop(ctx context.Context, req RunRequest) (*RunResult, error) } } + // Truncation guard: if response was cut off (max_tokens reached) and has tool calls, + // the tool call arguments are likely incomplete/malformed. Skip execution and ask + // the LLM to re-issue with complete arguments or break into smaller parts. + if resp.FinishReason == "length" && len(resp.ToolCalls) > 0 { + slog.Warn("truncated tool calls detected", + "agent", l.id, "iteration", iteration, + "tool_calls", len(resp.ToolCalls), "max_tokens", l.maxTokens) + messages = append(messages, + providers.Message{Role: "assistant", Content: resp.Content, ToolCalls: resp.ToolCalls, + Thinking: resp.Thinking, RawAssistantContent: resp.RawAssistantContent}, + providers.Message{ + Role: "user", + Content: "[System] Your response was truncated (max_tokens reached). The last tool call had incomplete arguments. Do NOT re-issue the same large tool call. Instead, break your work into smaller steps or respond with text only.", + }, + ) + pendingMsgs = append(pendingMsgs, + providers.Message{Role: "assistant", Content: resp.Content, ToolCalls: resp.ToolCalls}, + providers.Message{Role: "user", Content: "[System] Response truncated — tool call skipped."}, + ) + continue + } + + // Interrupted stream guard: if the provider detected premature SSE termination + // (connection dropped before message_stop/[DONE]), retry once with the same context. + // The partial content is discarded since it may be incomplete/cut off mid-sentence. + if resp.FinishReason == "interrupted" && !interruptRetried { + interruptRetried = true + slog.Warn("interrupted stream detected, retrying", + "agent", l.id, "iteration", iteration, + "content_len", len(resp.Content)) + // Emit retry event so channels can update placeholder. + emitRun(AgentEvent{ + Type: protocol.AgentEventRunRetrying, + AgentID: l.id, + RunID: req.RunID, + Payload: map[string]string{ + "attempt": "1", "maxAttempts": "1", + "error": "SSE stream interrupted before completion", + }, + }) + iteration-- // don't count the failed attempt + continue + } + if resp.Usage != nil { totalUsage.PromptTokens += resp.Usage.PromptTokens totalUsage.CompletionTokens += resp.Usage.CompletionTokens diff --git a/internal/agent/loop_history.go b/internal/agent/loop_history.go index 763ff820..e3de4645 100644 --- a/internal/agent/loop_history.go +++ b/internal/agent/loop_history.go @@ -285,7 +285,7 @@ func limitHistoryTurns(msgs []providers.Message, limit int) []providers.Message // - Orphaned tool messages at start of history (after truncation) // - tool_result without matching tool_use in preceding assistant message // - assistant with tool_calls but missing tool_results -// sanitizeHistory repairs tool_use/tool_result pairing in session history. +// // Returns the cleaned messages and the number of messages that were dropped or synthesized. func sanitizeHistory(msgs []providers.Message) ([]providers.Message, int) { if len(msgs) == 0 { @@ -301,6 +301,7 @@ func sanitizeHistory(msgs []providers.Message) ([]providers.Message, int) { "tool_call_id", msgs[start].ToolCallID) dropped++ start++ + dropped++ } if start >= len(msgs) { diff --git a/internal/agent/loop_types.go b/internal/agent/loop_types.go index 79fbdae0..61204718 100644 --- a/internal/agent/loop_types.go +++ b/internal/agent/loop_types.go @@ -46,6 +46,7 @@ type Loop struct { contextWindow int maxIterations int maxToolCalls int + maxTokens int workspace string workspaceSharing *store.WorkspaceSharingConfig @@ -154,6 +155,7 @@ type LoopConfig struct { ContextWindow int MaxIterations int MaxToolCalls int + MaxTokens int Workspace string WorkspaceSharing *store.WorkspaceSharingConfig @@ -263,6 +265,7 @@ func NewLoop(cfg LoopConfig) *Loop { contextWindow: cfg.ContextWindow, maxIterations: cfg.MaxIterations, maxToolCalls: cfg.MaxToolCalls, + maxTokens: cfg.MaxTokens, workspace: cfg.Workspace, workspaceSharing: cfg.WorkspaceSharing, restrictToWs: cfg.RestrictToWs, @@ -348,6 +351,10 @@ type RunRequest struct { // Workspace scope propagation (set by delegation, read by workspace tools) WorkspaceChannel string WorkspaceChatID string + + // Project-scoped MCP env overrides (resolved at message arrival) + ProjectID string // resolved project UUID (empty = no project) + ProjectOverrides map[string]map[string]string // {serverName: {envKey: envVal}} } // RunResult is the output of a completed agent run. diff --git a/internal/agent/resolver.go b/internal/agent/resolver.go index f10f820c..6a599aaa 100644 --- a/internal/agent/resolver.go +++ b/internal/agent/resolver.go @@ -88,7 +88,7 @@ type ResolverDeps struct { // NewManagedResolver creates a ResolverFunc that builds Loops from DB agent data. // Agents are defined in Postgres, not config.json. func NewManagedResolver(deps ResolverDeps) ResolverFunc { - return func(agentKey string) (Agent, error) { + return func(agentKey string, opts ResolveOpts) (Agent, error) { ctx := context.Background() // Support lookup by UUID (e.g. from cron jobs that store agent_id as UUID) @@ -231,6 +231,10 @@ func NewManagedResolver(deps ResolverDeps) ResolverFunc { if maxIter <= 0 { maxIter = 20 } + maxTokens := ag.ParseMaxTokens() + if maxTokens <= 0 { + maxTokens = 8192 + } // Per-agent config overrides (fallback to global defaults from config.json) compactionCfg := deps.CompactionCfg @@ -292,7 +296,7 @@ func NewManagedResolver(deps ResolverDeps) ResolverFunc { mcpOpts = append(mcpOpts, mcpbridge.WithPool(deps.MCPPool)) } mcpMgr := mcpbridge.NewManager(toolsReg, mcpOpts...) - if err := mcpMgr.LoadForAgent(ctx, ag.ID, ""); err != nil { + if err := mcpMgr.LoadForAgent(ctx, ag.ID, "", opts.ProjectID, opts.ProjectOverrides); err != nil { slog.Warn("failed to load MCP servers for agent", "agent", agentKey, "error", err) } else if mcpMgr.IsSearchMode() { // Search mode: too many tools — register mcp_tool_search meta-tool @@ -357,6 +361,7 @@ func NewManagedResolver(deps ResolverDeps) ResolverFunc { Model: ag.Model, ContextWindow: contextWindow, MaxIterations: maxIter, + MaxTokens: maxTokens, Workspace: workspace, RestrictToWs: &restrictVal, SubagentsCfg: ag.ParseSubagentsConfig(), diff --git a/internal/agent/router.go b/internal/agent/router.go index f3e251b7..14336642 100644 --- a/internal/agent/router.go +++ b/internal/agent/router.go @@ -7,9 +7,18 @@ import ( "time" ) +// ResolveOpts carries per-request context for the resolver (e.g., project scoping). +// When ProjectID is non-empty the resolver creates a Loop whose MCP connections +// are scoped to that project, preventing cross-project contamination in shared +// agents like sdlc-assistant. +type ResolveOpts struct { + ProjectID string + ProjectOverrides map[string]map[string]string +} + // ResolverFunc is called when an agent isn't found in the cache. // Used to lazy-create agents from DB. -type ResolverFunc func(agentKey string) (Agent, error) +type ResolverFunc func(agentKey string, opts ResolveOpts) (Agent, error) const defaultRouterTTL = 10 * time.Minute @@ -83,7 +92,7 @@ func (r *Router) Get(agentID string) (Agent, error) { // Try resolver (create from DB) if resolver != nil { - ag, err := resolver(agentID) + ag, err := resolver(agentID, ResolveOpts{}) if err != nil { return nil, err } @@ -101,6 +110,50 @@ func (r *Router) Get(agentID string) (Agent, error) { return nil, fmt.Errorf("agent not found: %s", agentID) } +// GetForProject returns an agent Loop scoped to a specific project. +// Cache key: "agentID:projectID". If projectID is empty, falls back to Get(agentID). +func (r *Router) GetForProject(agentID, projectID string, projectOverrides map[string]map[string]string) (Agent, error) { + if projectID == "" { + return r.Get(agentID) + } + cacheKey := agentID + ":" + projectID + + r.mu.RLock() + entry, ok := r.agents[cacheKey] + resolver := r.resolver + r.mu.RUnlock() + + if ok && (r.ttl == 0 || time.Since(entry.cachedAt) < r.ttl) { + return entry.agent, nil + } + if ok { + r.mu.Lock() + delete(r.agents, cacheKey) + r.mu.Unlock() + } + if resolver == nil { + return nil, fmt.Errorf("agent not found: %s", agentID) + } + + ag, err := resolver(agentID, ResolveOpts{ + ProjectID: projectID, + ProjectOverrides: projectOverrides, + }) + if err != nil { + return nil, err + } + + r.mu.Lock() + // Double-check: another goroutine might have created it + if existing, ok := r.agents[cacheKey]; ok { + r.mu.Unlock() + return existing.agent, nil + } + r.agents[cacheKey] = &agentEntry{agent: ag, cachedAt: time.Now()} + r.mu.Unlock() + return ag, nil +} + // Remove removes an agent from the router. func (r *Router) Remove(agentID string) { r.mu.Lock() diff --git a/internal/agent/router_project_test.go b/internal/agent/router_project_test.go new file mode 100644 index 00000000..581016ab --- /dev/null +++ b/internal/agent/router_project_test.go @@ -0,0 +1,100 @@ +package agent + +import ( + "context" + "testing" +) + +// mockAgent implements the Agent interface for testing. +type mockAgent struct { + id string +} + +func (m *mockAgent) ID() string { return m.id } +func (m *mockAgent) Run(_ context.Context, _ RunRequest) (*RunResult, error) { return nil, nil } +func (m *mockAgent) IsRunning() bool { return false } +func (m *mockAgent) Model() string { return "test-model" } +func (m *mockAgent) ProviderName() string { return "test-provider" } + +func TestGetForProject_EmptyProjectFallsBackToGet(t *testing.T) { + r := NewRouter() + r.SetResolver(func(agentKey string, opts ResolveOpts) (Agent, error) { + return &mockAgent{id: agentKey}, nil + }) + + ag, err := r.GetForProject("sdlc-assistant", "", nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if ag.ID() != "sdlc-assistant" { + t.Errorf("got agent ID %q, want 'sdlc-assistant'", ag.ID()) + } +} + +func TestGetForProject_DifferentProjectsSeparateCache(t *testing.T) { + callCount := 0 + + r := NewRouter() + r.SetResolver(func(agentKey string, opts ResolveOpts) (Agent, error) { + callCount++ + return &mockAgent{id: agentKey + ":" + opts.ProjectID}, nil + }) + + ag1, err := r.GetForProject("sdlc-assistant", "uuid-xpos", nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + ag2, err := r.GetForProject("sdlc-assistant", "uuid-payment", nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if ag1.ID() == ag2.ID() { + t.Error("different projects should get different cached agents") + } + if callCount != 2 { + t.Errorf("resolver should be called twice (once per project), got %d", callCount) + } +} + +func TestGetForProject_SameProjectUsesCache(t *testing.T) { + callCount := 0 + + r := NewRouter() + r.SetResolver(func(agentKey string, opts ResolveOpts) (Agent, error) { + callCount++ + return &mockAgent{id: agentKey}, nil + }) + + _, _ = r.GetForProject("sdlc-assistant", "uuid-xpos", nil) + _, _ = r.GetForProject("sdlc-assistant", "uuid-xpos", nil) + + if callCount != 1 { + t.Errorf("resolver should be called once (cached), got %d", callCount) + } +} + +func TestGetForProject_NoProjectAndWithProject_SeparateCache(t *testing.T) { + callCount := 0 + + r := NewRouter() + r.SetResolver(func(agentKey string, opts ResolveOpts) (Agent, error) { + callCount++ + suffix := "" + if opts.ProjectID != "" { + suffix = ":" + opts.ProjectID + } + return &mockAgent{id: agentKey + suffix}, nil + }) + + ag1, _ := r.GetForProject("sdlc-assistant", "", nil) + ag2, _ := r.GetForProject("sdlc-assistant", "uuid-xpos", nil) + + if ag1.ID() == ag2.ID() { + t.Error("no-project and with-project should get different agents") + } + if callCount != 2 { + t.Errorf("expected 2 resolver calls, got %d", callCount) + } +} diff --git a/internal/channels/channel.go b/internal/channels/channel.go index def6a188..cd16ec56 100644 --- a/internal/channels/channel.go +++ b/internal/channels/channel.go @@ -58,6 +58,7 @@ const ( TypeWhatsApp = "whatsapp" TypeZaloOA = "zalo_oa" TypeZaloPersonal = "zalo_personal" + TypeGoogleChat = "google_chat" ) // Channel defines the interface that all channel implementations must satisfy. diff --git a/internal/channels/googlechat/auth.go b/internal/channels/googlechat/auth.go new file mode 100644 index 00000000..a3bc50d0 --- /dev/null +++ b/internal/channels/googlechat/auth.go @@ -0,0 +1,170 @@ +package googlechat + +import ( + "context" + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/json" + "encoding/pem" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strings" + "sync" + "time" +) + +type ServiceAccountAuth struct { + email string + privateKey *rsa.PrivateKey + scopes []string + token string + expiresAt time.Time + mu sync.Mutex + tokenEndpoint string + httpClient *http.Client +} + +type serviceAccountFile struct { + Type string `json:"type"` + ClientEmail string `json:"client_email"` + PrivateKey string `json:"private_key"` + TokenURI string `json:"token_uri"` +} + +func NewServiceAccountAuth(saFilePath string, scopes []string) (*ServiceAccountAuth, error) { + data, err := os.ReadFile(saFilePath) + if err != nil { + return nil, fmt.Errorf("read service account file: %w", err) + } + + var sa serviceAccountFile + if err := json.Unmarshal(data, &sa); err != nil { + return nil, fmt.Errorf("parse service account file: %w", err) + } + if sa.ClientEmail == "" { + return nil, fmt.Errorf("service account file missing client_email") + } + if sa.PrivateKey == "" { + return nil, fmt.Errorf("service account file missing private_key") + } + + block, _ := pem.Decode([]byte(sa.PrivateKey)) + if block == nil { + return nil, fmt.Errorf("failed to decode PEM block from private_key") + } + + key, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + rsaKey, err2 := x509.ParsePKCS1PrivateKey(block.Bytes) + if err2 != nil { + return nil, fmt.Errorf("parse private key: %w (pkcs1: %w)", err, err2) + } + key = rsaKey + } + + rsaKey, ok := key.(*rsa.PrivateKey) + if !ok { + return nil, fmt.Errorf("private key is not RSA") + } + + ep := sa.TokenURI + if ep == "" { + ep = tokenEndpoint + } + + return &ServiceAccountAuth{ + email: sa.ClientEmail, + privateKey: rsaKey, + scopes: scopes, + tokenEndpoint: ep, + httpClient: &http.Client{Timeout: 10 * time.Second}, + }, nil +} + +func (a *ServiceAccountAuth) Token(ctx context.Context) (string, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if a.token != "" && time.Now().Add(60*time.Second).Before(a.expiresAt) { + return a.token, nil + } + + now := time.Now() + claims := map[string]any{ + "iss": a.email, + "scope": strings.Join(a.scopes, " "), + "aud": tokenEndpoint, + "iat": now.Unix(), + "exp": now.Add(time.Hour).Unix(), + } + + signedJWT, err := signJWT(a.privateKey, claims) + if err != nil { + return "", fmt.Errorf("sign JWT: %w", err) + } + + form := url.Values{ + "grant_type": {"urn:ietf:params:oauth:grant-type:jwt-bearer"}, + "assertion": {signedJWT}, + } + + req, err := http.NewRequestWithContext(ctx, "POST", a.tokenEndpoint, strings.NewReader(form.Encode())) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + resp, err := a.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("token exchange request: %w", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("token exchange failed (%d): %s", resp.StatusCode, string(body)) + } + + var tokenResp struct { + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + TokenType string `json:"token_type"` + } + if err := json.Unmarshal(body, &tokenResp); err != nil { + return "", fmt.Errorf("parse token response: %w", err) + } + + a.token = tokenResp.AccessToken + a.expiresAt = now.Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + + return a.token, nil +} + +func signJWT(key *rsa.PrivateKey, claims map[string]any) (string, error) { + header := base64URLEncode([]byte(`{"alg":"RS256","typ":"JWT"}`)) + payload, err := json.Marshal(claims) + if err != nil { + return "", err + } + payloadEnc := base64URLEncode(payload) + signingInput := header + "." + payloadEnc + + hash := sha256.Sum256([]byte(signingInput)) + sig, err := rsa.SignPKCS1v15(rand.Reader, key, crypto.SHA256, hash[:]) + if err != nil { + return "", err + } + + return signingInput + "." + base64URLEncode(sig), nil +} + +func base64URLEncode(data []byte) string { + return base64.RawURLEncoding.EncodeToString(data) +} diff --git a/internal/channels/googlechat/auth_test.go b/internal/channels/googlechat/auth_test.go new file mode 100644 index 00000000..5eb555f6 --- /dev/null +++ b/internal/channels/googlechat/auth_test.go @@ -0,0 +1,145 @@ +package googlechat + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/json" + "encoding/pem" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "testing" + "time" +) + +func testServiceAccountJSON(t *testing.T, dir string) (string, *rsa.PrivateKey) { + t.Helper() + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatal(err) + } + pkcs8, err := x509.MarshalPKCS8PrivateKey(key) + if err != nil { + t.Fatal(err) + } + pemBlock := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: pkcs8}) + + sa := map[string]string{ + "type": "service_account", + "client_email": "test@test.iam.gserviceaccount.com", + "private_key": string(pemBlock), + "token_uri": "https://oauth2.googleapis.com/token", + } + data, _ := json.Marshal(sa) + path := filepath.Join(dir, "sa.json") + if err := os.WriteFile(path, data, 0600); err != nil { + t.Fatal(err) + } + return path, key +} + +func TestNewServiceAccountAuth_ValidFile(t *testing.T) { + dir := t.TempDir() + path, _ := testServiceAccountJSON(t, dir) + auth, err := NewServiceAccountAuth(path, []string{scopeChat}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if auth.email != "test@test.iam.gserviceaccount.com" { + t.Errorf("email = %q, want test@test.iam.gserviceaccount.com", auth.email) + } +} + +func TestNewServiceAccountAuth_InvalidFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "bad.json") + os.WriteFile(path, []byte("{bad json"), 0600) + _, err := NewServiceAccountAuth(path, []string{scopeChat}) + if err == nil { + t.Fatal("expected error for invalid JSON") + } +} + +func TestNewServiceAccountAuth_MissingFile(t *testing.T) { + _, err := NewServiceAccountAuth("/nonexistent/sa.json", []string{scopeChat}) + if err == nil { + t.Fatal("expected error for missing file") + } +} + +func TestServiceAccountAuth_Token_CachesWithinTTL(t *testing.T) { + dir := t.TempDir() + path, _ := testServiceAccountJSON(t, dir) + callCount := 0 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + json.NewEncoder(w).Encode(map[string]any{ + "access_token": "tok-123", + "expires_in": 3600, + "token_type": "Bearer", + }) + })) + defer ts.Close() + + auth, err := NewServiceAccountAuth(path, []string{scopeChat}) + if err != nil { + t.Fatal(err) + } + auth.tokenEndpoint = ts.URL + + ctx := context.Background() + tok1, _ := auth.Token(ctx) + tok2, _ := auth.Token(ctx) + if tok1 != tok2 { + t.Errorf("tokens differ") + } + if callCount != 1 { + t.Errorf("callCount = %d, want 1", callCount) + } +} + +func TestServiceAccountAuth_Token_RefreshesExpired(t *testing.T) { + dir := t.TempDir() + path, _ := testServiceAccountJSON(t, dir) + callCount := 0 + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + callCount++ + json.NewEncoder(w).Encode(map[string]any{ + "access_token": "tok", + "expires_in": 1, + "token_type": "Bearer", + }) + })) + defer ts.Close() + + auth, err := NewServiceAccountAuth(path, []string{scopeChat}) + if err != nil { + t.Fatal(err) + } + auth.tokenEndpoint = ts.URL + auth.Token(context.Background()) + auth.expiresAt = time.Now().Add(-1 * time.Minute) + auth.Token(context.Background()) + if callCount != 2 { + t.Errorf("callCount = %d, want 2", callCount) + } +} + +func TestServiceAccountAuth_Token_RefreshFailure(t *testing.T) { + dir := t.TempDir() + path, _ := testServiceAccountJSON(t, dir) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(500) + })) + defer ts.Close() + + auth, _ := NewServiceAccountAuth(path, []string{scopeChat}) + auth.tokenEndpoint = ts.URL + _, err := auth.Token(context.Background()) + if err == nil { + t.Fatal("expected error on 500") + } +} diff --git a/internal/channels/googlechat/constants.go b/internal/channels/googlechat/constants.go new file mode 100644 index 00000000..f2e5f7db --- /dev/null +++ b/internal/channels/googlechat/constants.go @@ -0,0 +1,26 @@ +package googlechat + +import "time" + +const ( + typeGoogleChat = "google_chat" + chatAPIBase = "https://chat.googleapis.com/v1" + pubsubAPIBase = "https://pubsub.googleapis.com/v1" + driveUploadBase = "https://www.googleapis.com/upload/drive/v3" + driveAPIBase = "https://www.googleapis.com/drive/v3" + tokenEndpoint = "https://oauth2.googleapis.com/token" + googleChatMaxMessageBytes = 3900 + longFormThresholdDefault = 6000 + dedupTTL = 5 * time.Minute + defaultPullInterval = 1 * time.Second + defaultPullMaxMessages = 10 + defaultMediaMaxMB = 20 + defaultFileRetentionDays = 7 + shutdownDrainTimeout = 5 * time.Second + scopeChat = "https://www.googleapis.com/auth/chat.bot" + scopePubSub = "https://www.googleapis.com/auth/pubsub" + scopeDrive = "https://www.googleapis.com/auth/drive.file" + retrySendMaxAttempts = 5 + retrySendBaseDelay = 1 * time.Second + retrySendMaxDelay = 30 * time.Second +) diff --git a/internal/channels/googlechat/factory.go b/internal/channels/googlechat/factory.go new file mode 100644 index 00000000..5afbff7d --- /dev/null +++ b/internal/channels/googlechat/factory.go @@ -0,0 +1,118 @@ +package googlechat + +import ( + "encoding/json" + "fmt" + "os" + "path/filepath" + + "github.com/google/uuid" + "github.com/nextlevelbuilder/goclaw/internal/bus" + "github.com/nextlevelbuilder/goclaw/internal/channels" + "github.com/nextlevelbuilder/goclaw/internal/config" + "github.com/nextlevelbuilder/goclaw/internal/store" +) + +// googleChatCreds maps the credentials JSON from the channel_instances table. +type googleChatCreds struct { + ServiceAccountJSON json.RawMessage `json:"service_account_json"` // embedded SA key JSON +} + +// googleChatInstanceConfig maps the non-secret config JSONB from the channel_instances table. +type googleChatInstanceConfig struct { + Mode string `json:"mode,omitempty"` + ProjectID string `json:"project_id,omitempty"` + SubscriptionID string `json:"subscription_id,omitempty"` + PullIntervalMs int `json:"pull_interval_ms,omitempty"` + BotUser string `json:"bot_user,omitempty"` + DMPolicy string `json:"dm_policy,omitempty"` + GroupPolicy string `json:"group_policy,omitempty"` + AllowFrom []string `json:"allow_from,omitempty"` + LongFormThreshold int `json:"long_form_threshold,omitempty"` + LongFormFormat string `json:"long_form_format,omitempty"` + MediaMaxMB int `json:"media_max_mb,omitempty"` + DrivePermission string `json:"drive_permission,omitempty"` + BlockReply *bool `json:"block_reply,omitempty"` +} + +// FactoryWithPendingStore returns a ChannelFactory that includes the pending message store. +func FactoryWithPendingStore(pendingStore store.PendingMessageStore) channels.ChannelFactory { + return func(name string, creds json.RawMessage, cfg json.RawMessage, + msgBus *bus.MessageBus, pairingSvc store.PairingStore) (channels.Channel, error) { + return buildChannel(name, creds, cfg, msgBus, pendingStore) + } +} + +// Factory creates a Google Chat channel from DB instance data (no pending store). +func Factory(name string, creds json.RawMessage, cfg json.RawMessage, + msgBus *bus.MessageBus, _ store.PairingStore) (channels.Channel, error) { + return buildChannel(name, creds, cfg, msgBus, nil) +} + +func buildChannel(name string, creds json.RawMessage, cfg json.RawMessage, + msgBus *bus.MessageBus, pendingStore store.PendingMessageStore) (channels.Channel, error) { + + var c googleChatCreds + if len(creds) > 0 { + if err := json.Unmarshal(creds, &c); err != nil { + return nil, fmt.Errorf("decode googlechat credentials: %w", err) + } + } + + var ic googleChatInstanceConfig + if len(cfg) > 0 { + if err := json.Unmarshal(cfg, &ic); err != nil { + return nil, fmt.Errorf("decode googlechat config: %w", err) + } + } + + if len(c.ServiceAccountJSON) == 0 { + return nil, fmt.Errorf("googlechat: service_account_json is required in credentials") + } + + // Write SA JSON to a temp file for NewServiceAccountAuth (it reads from file path). + saFile, err := writeTempSAFile(c.ServiceAccountJSON) + if err != nil { + return nil, fmt.Errorf("googlechat: write SA temp file: %w", err) + } + + gcCfg := config.GoogleChatConfig{ + Enabled: true, + ServiceAccountFile: saFile, + Mode: ic.Mode, + ProjectID: ic.ProjectID, + SubscriptionID: ic.SubscriptionID, + PullIntervalMs: ic.PullIntervalMs, + BotUser: ic.BotUser, + DMPolicy: ic.DMPolicy, + GroupPolicy: ic.GroupPolicy, + AllowFrom: ic.AllowFrom, + LongFormThreshold: ic.LongFormThreshold, + LongFormFormat: ic.LongFormFormat, + MediaMaxMB: ic.MediaMaxMB, + DrivePermission: ic.DrivePermission, + BlockReply: ic.BlockReply, + } + + // DB instances default to "allowlist" for groups. + if gcCfg.GroupPolicy == "" { + gcCfg.GroupPolicy = "allowlist" + } + + ch, err := New(gcCfg, msgBus, pendingStore) + if err != nil { + return nil, err + } + + ch.SetName(name) + return ch, nil +} + +// writeTempSAFile writes the SA JSON to a temp file and returns the path. +func writeTempSAFile(saJSON json.RawMessage) (string, error) { + tmpPath := filepath.Join(os.TempDir(), "goclaw-sa-"+uuid.New().String()+".json") + if err := os.WriteFile(tmpPath, saJSON, 0600); err != nil { + return "", err + } + return tmpPath, nil +} diff --git a/internal/channels/googlechat/format.go b/internal/channels/googlechat/format.go new file mode 100644 index 00000000..ca1ffd8a --- /dev/null +++ b/internal/channels/googlechat/format.go @@ -0,0 +1,290 @@ +package googlechat + +import ( + "regexp" + "strings" + "unicode/utf8" +) + +const ( + codePlaceholder = "\x00" + boldMarker = "\x01" // temp marker for bold * to avoid italic conversion +) + +var ( + reCodeBlock = regexp.MustCompile("(?s)(```[\\s\\S]*?```)") + reCodeInline = regexp.MustCompile("(`[^`]+`)") + reBoldItalic = regexp.MustCompile(`\*\*\*(.+?)\*\*\*`) + reBold = regexp.MustCompile(`\*\*(.+?)\*\*`) + reStrike = regexp.MustCompile(`~~(.+?)~~`) + reLink = regexp.MustCompile(`\[([^\]]+)\]\(([^)]+)\)`) + reTable = regexp.MustCompile(`(?m)^\|.+\|$\n^\|[-| :]+\|$`) + reLongCodeBlock = regexp.MustCompile("(?s)```[\\w]*\\n(.{500,}?)\\n```") +) + +func markdownToGoogleChat(text string) string { + if text == "" { + return "" + } + + var codeBlocks []string + protected := reCodeBlock.ReplaceAllStringFunc(text, func(match string) string { + codeBlocks = append(codeBlocks, match) + return codePlaceholder + }) + var inlineCodes []string + protected = reCodeInline.ReplaceAllStringFunc(protected, func(match string) string { + inlineCodes = append(inlineCodes, match) + return codePlaceholder + }) + + // Bold+italic: ***text*** → *_text_* (use boldMarker to protect from italic converter) + protected = reBoldItalic.ReplaceAllString(protected, boldMarker+"_${1}_"+boldMarker) + // Bold: **text** → *text* (use boldMarker to protect from italic converter) + protected = reBold.ReplaceAllString(protected, boldMarker+"${1}"+boldMarker) + // Italic: *text* → _text_ (only matches unprotected single *) + protected = convertItalic(protected) + // Restore boldMarker → * + protected = strings.ReplaceAll(protected, boldMarker, "*") + protected = reStrike.ReplaceAllString(protected, "~${1}~") + protected = reLink.ReplaceAllString(protected, "<${2}|${1}>") + + codeIdx := 0 + inlineIdx := 0 + var result strings.Builder + for _, r := range protected { + if string(r) == codePlaceholder { + if codeIdx < len(codeBlocks) { + result.WriteString(codeBlocks[codeIdx]) + codeIdx++ + } else if inlineIdx < len(inlineCodes) { + result.WriteString(inlineCodes[inlineIdx]) + inlineIdx++ + } + } else { + result.WriteRune(r) + } + } + + return result.String() +} + +func convertItalic(s string) string { + var result strings.Builder + runes := []rune(s) + i := 0 + for i < len(runes) { + if runes[i] == '*' { + prevStar := i > 0 && runes[i-1] == '*' + nextStar := i+1 < len(runes) && runes[i+1] == '*' + if !prevStar && !nextStar { + end := -1 + for j := i + 1; j < len(runes); j++ { + if runes[j] == '*' { + nextJ := j+1 < len(runes) && runes[j+1] == '*' + prevJ := j > 0 && runes[j-1] == '*' + if !nextJ && !prevJ { + end = j + break + } + } + } + if end > 0 { + result.WriteRune('_') + result.WriteString(string(runes[i+1 : end])) + result.WriteRune('_') + i = end + 1 + continue + } + } + } + result.WriteRune(runes[i]) + i++ + } + return result.String() +} + +func detectStructuredContent(text string) bool { + return reTable.MatchString(text) || reLongCodeBlock.MatchString(text) +} + +func chunkByBytes(text string, maxBytes int) []string { + if text == "" { + return nil + } + if len([]byte(text)) <= maxBytes { + return []string{text} + } + + var chunks []string + paragraphs := strings.Split(text, "\n\n") + if len(paragraphs) > 1 { + var current strings.Builder + for i, p := range paragraphs { + sep := "" + if i > 0 { + sep = "\n\n" + } + candidate := current.String() + sep + p + if len([]byte(candidate)) > maxBytes && current.Len() > 0 { + chunks = append(chunks, current.String()) + current.Reset() + current.WriteString(p) + } else { + if current.Len() > 0 { + current.WriteString(sep) + } + current.WriteString(p) + } + } + if current.Len() > 0 { + remaining := current.String() + if len([]byte(remaining)) > maxBytes { + chunks = append(chunks, chunkByLines(remaining, maxBytes)...) + } else { + chunks = append(chunks, remaining) + } + } + return chunks + } + + return chunkByLines(text, maxBytes) +} + +func chunkByLines(text string, maxBytes int) []string { + lines := strings.Split(text, "\n") + if len(lines) <= 1 { + return chunkByWords(text, maxBytes) + } + + var chunks []string + var current strings.Builder + for i, line := range lines { + sep := "" + if i > 0 { + sep = "\n" + } + candidate := current.String() + sep + line + if len([]byte(candidate)) > maxBytes && current.Len() > 0 { + chunks = append(chunks, current.String()) + current.Reset() + if len([]byte(line)) > maxBytes { + chunks = append(chunks, chunkByWords(line, maxBytes)...) + } else { + current.WriteString(line) + } + } else { + if current.Len() > 0 { + current.WriteString(sep) + } + current.WriteString(line) + } + } + if current.Len() > 0 { + chunks = append(chunks, current.String()) + } + return chunks +} + +func chunkByWords(text string, maxBytes int) []string { + words := strings.Fields(text) + if len(words) == 0 { + return []string{text} + } + + var chunks []string + var current strings.Builder + for _, word := range words { + sep := "" + if current.Len() > 0 { + sep = " " + } + candidate := current.String() + sep + word + if len([]byte(candidate)) > maxBytes && current.Len() > 0 { + chunks = append(chunks, current.String()) + current.Reset() + if len([]byte(word)) > maxBytes { + chunks = append(chunks, splitAtUTF8Boundary(word, maxBytes)...) + } else { + current.WriteString(word) + } + } else { + if current.Len() > 0 { + current.WriteString(sep) + } + current.WriteString(word) + } + } + if current.Len() > 0 { + chunks = append(chunks, current.String()) + } + return chunks +} + +func splitAtUTF8Boundary(word string, maxBytes int) []string { + var chunks []string + b := []byte(word) + for len(b) > 0 { + end := maxBytes + if end > len(b) { + end = len(b) + } + for end > 0 && !utf8.Valid(b[:end]) { + end-- + } + if end == 0 { + end = 1 + } + chunks = append(chunks, string(b[:end])) + b = b[end:] + } + return chunks +} + +func extractSummary(content string) string { + lines := strings.Split(content, "\n") + if len(lines) == 0 { + return content + } + + var heading string + var bullets []string + var textLines []string + + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if strings.HasPrefix(trimmed, "# ") && heading == "" { + heading = strings.TrimPrefix(trimmed, "# ") + continue + } + if strings.HasPrefix(trimmed, "## ") && heading != "" { + break + } + if strings.HasPrefix(trimmed, "- ") || strings.HasPrefix(trimmed, "* ") { + if len(bullets) < 3 { + bullets = append(bullets, trimmed) + } + continue + } + if trimmed != "" && heading != "" && len(bullets) == 0 { + textLines = append(textLines, trimmed) + } + } + + var parts []string + if heading != "" { + parts = append(parts, heading) + } + if len(textLines) > 0 && len(bullets) == 0 { + parts = append(parts, strings.Join(textLines, "\n")) + } + if len(bullets) > 0 { + parts = append(parts, strings.Join(bullets, "\n")) + } + + result := strings.Join(parts, "\n\n") + if result == "" { + return content + } + return result +} diff --git a/internal/channels/googlechat/format_test.go b/internal/channels/googlechat/format_test.go new file mode 100644 index 00000000..a7497034 --- /dev/null +++ b/internal/channels/googlechat/format_test.go @@ -0,0 +1,118 @@ +package googlechat + +import ( + "strings" + "testing" +) + +func TestMarkdownToGoogleChat(t *testing.T) { + tests := []struct { + name, input, want string + }{ + {"bold", "**hello**", "*hello*"}, + {"italic", "*hello*", "_hello_"}, + {"strikethrough", "~~deleted~~", "~deleted~"}, + {"code inline", "`code`", "`code`"}, + {"code block", "```go\nfunc(){}\n```", "```go\nfunc(){}\n```"}, + {"mixed", "**bold** and *italic*", "*bold* and _italic_"}, + {"link", "[text](https://example.com)", ""}, + {"nested bold+italic", "***both***", "*_both_*"}, + {"empty", "", ""}, + {"plain text", "plain text", "plain text"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := markdownToGoogleChat(tt.input) + if got != tt.want { + t.Errorf("markdownToGoogleChat(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestDetectStructuredContent(t *testing.T) { + tests := []struct { + name string + input string + want bool + }{ + {"has table", "| col1 | col2 |\n|---|---|\n| a | b |", true}, + {"has long code block", "```\n" + string(make([]byte, 600)) + "\n```", true}, + {"short code block", "```\nshort\n```", false}, + {"plain text", "Hello world", false}, + {"inline code only", "`code` in text", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := detectStructuredContent(tt.input); got != tt.want { + t.Errorf("detectStructuredContent() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestChunkByBytes(t *testing.T) { + tests := []struct { + name string + input string + maxBytes int + wantCount int + }{ + {"under limit", "hello", googleChatMaxMessageBytes, 1}, + {"empty", "", googleChatMaxMessageBytes, 0}, + {"over limit paragraph split", "para one\n\npara two\n\npara three", 20, 2}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + chunks := chunkByBytes(tt.input, tt.maxBytes) + if len(chunks) != tt.wantCount { + t.Errorf("chunkByBytes() returned %d chunks, want %d", len(chunks), tt.wantCount) + } + for i, c := range chunks { + if len([]byte(c)) > tt.maxBytes { + t.Errorf("chunk[%d] = %d bytes, exceeds max %d", i, len([]byte(c)), tt.maxBytes) + } + } + }) + } +} + +func TestChunkByBytes_Unicode(t *testing.T) { + vn := "Đây là một đoạn văn bản tiếng Việt dài để kiểm tra việc chia chunk theo byte" + chunks := chunkByBytes(vn, 50) + if len(chunks) < 2 { + t.Fatalf("expected multiple chunks, got %d", len(chunks)) + } + for i, c := range chunks { + if len([]byte(c)) > 50 { + t.Errorf("chunk[%d] = %d bytes, exceeds 50", i, len([]byte(c))) + } + if c == "" { + t.Errorf("chunk[%d] is empty", i) + } + } + // Verify all words are preserved across chunks. + reassembled := strings.Join(chunks, " ") + if reassembled != vn { + t.Errorf("reassembled text doesn't match original:\ngot: %q\nwant: %q", reassembled, vn) + } +} + +func TestExtractSummary(t *testing.T) { + tests := []struct { + name, input, want string + }{ + {"heading + bullets", "# Title\n- A\n- B\n- C\n- D\n- E", "Title\n\n- A\n- B\n- C"}, + {"no heading", "- A\n- B\n- C\n- D", "- A\n- B\n- C"}, + {"very short", "Hello", "Hello"}, + {"only heading", "# Title", "Title"}, + {"multiple headings", "# H1\ntext here\n## H2\nmore text", "H1\n\ntext here"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := extractSummary(tt.input); got != tt.want { + t.Errorf("extractSummary(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} diff --git a/internal/channels/googlechat/googlechat.go b/internal/channels/googlechat/googlechat.go new file mode 100644 index 00000000..806698a4 --- /dev/null +++ b/internal/channels/googlechat/googlechat.go @@ -0,0 +1,222 @@ +package googlechat + +import ( + "context" + "log/slog" + "net/http" + "sync" + "time" + + "github.com/nextlevelbuilder/goclaw/internal/bus" + "github.com/nextlevelbuilder/goclaw/internal/channels" + "github.com/nextlevelbuilder/goclaw/internal/config" + "github.com/nextlevelbuilder/goclaw/internal/store" +) + +// Channel implements channels.Channel, channels.BlockReplyChannel, and +// channels.PendingCompactable for Google Chat via Pub/Sub pull (phase 1). +type Channel struct { + *channels.BaseChannel + + // Auth + auth *ServiceAccountAuth + + // Pub/Sub config + projectID string + subscriptionID string + pullInterval time.Duration + + // Identity + botUser string // bot's own user ID to filter self-messages + + // Policies + dmPolicy string + groupPolicy string + requireMention bool // require @bot mention in groups + + // Outbound + apiBase string // overridable Chat API base (for testing) + longFormThreshold int + longFormFormat string // "md" or "txt" + drivePermission string // "domain" or "anyone" + driveDomain string // domain for "domain" permission + blockReply *bool + + // Media + mediaMaxBytes int64 + fileRetentionDays int + + // HTTP client (shared for all API calls) + httpClient *http.Client + + // State + dedup *dedupCache + threadIDs sync.Map // spaceID:senderID → threadName + placeholders sync.Map // chatID → messageName (placeholder for edit) + groupHistory *channels.PendingHistory + historyLimit int + driveFiles []driveFileRecord + driveFilesMu sync.Mutex + + // Lifecycle + pullCancel context.CancelFunc + pullDone chan struct{} + cleanupCancel context.CancelFunc +} + +// New creates a new Google Chat channel from config. +func New(cfg config.GoogleChatConfig, msgBus *bus.MessageBus, pendingStore store.PendingMessageStore) (*Channel, error) { + auth, err := NewServiceAccountAuth(cfg.ServiceAccountFile, []string{scopeChat, scopePubSub, scopeDrive}) + if err != nil { + return nil, err + } + + pullInterval := defaultPullInterval + if cfg.PullIntervalMs > 0 { + pullInterval = time.Duration(cfg.PullIntervalMs) * time.Millisecond + } + + longFormThreshold := longFormThresholdDefault + if cfg.LongFormThreshold > 0 { + longFormThreshold = cfg.LongFormThreshold + } else if cfg.LongFormThreshold < 0 { + longFormThreshold = 0 // disabled + } + + longFormFormat := "md" + if cfg.LongFormFormat == "txt" { + longFormFormat = "txt" + } + + mediaMaxBytes := int64(defaultMediaMaxMB) * 1024 * 1024 + if cfg.MediaMaxMB > 0 { + mediaMaxBytes = int64(cfg.MediaMaxMB) * 1024 * 1024 + } + + drivePermission := "domain" + if cfg.DrivePermission == "anyone" { + drivePermission = "anyone" + } + driveDomain := "vnpay.vn" + if cfg.DriveDomain != "" { + driveDomain = cfg.DriveDomain + } + + dmPolicy := cfg.DMPolicy + if dmPolicy == "" { + dmPolicy = "open" + } + groupPolicy := cfg.GroupPolicy + if groupPolicy == "" { + groupPolicy = "open" + } + + requireMention := true + if cfg.RequireMention != nil { + requireMention = *cfg.RequireMention + } + + historyLimit := 50 + if cfg.HistoryLimit > 0 { + historyLimit = cfg.HistoryLimit + } + + ch := &Channel{ + BaseChannel: channels.NewBaseChannel(channels.TypeGoogleChat, msgBus, cfg.AllowFrom), + auth: auth, + projectID: cfg.ProjectID, + subscriptionID: cfg.SubscriptionID, + pullInterval: pullInterval, + botUser: cfg.BotUser, + dmPolicy: dmPolicy, + groupPolicy: groupPolicy, + requireMention: requireMention, + apiBase: chatAPIBase, + longFormThreshold: longFormThreshold, + longFormFormat: longFormFormat, + drivePermission: drivePermission, + driveDomain: driveDomain, + blockReply: cfg.BlockReply, + mediaMaxBytes: mediaMaxBytes, + fileRetentionDays: cfg.FileRetentionDays, + httpClient: &http.Client{Timeout: 30 * time.Second}, + dedup: newDedupCache(dedupTTL), + historyLimit: historyLimit, + groupHistory: channels.MakeHistory("google_chat", pendingStore), + } + + ch.BaseChannel.SetType(typeGoogleChat) + ch.BaseChannel.ValidatePolicy(dmPolicy, groupPolicy) + + return ch, nil +} + +// Start begins the Pub/Sub pull loop and optional Drive cleanup goroutine. +func (c *Channel) Start(ctx context.Context) error { + if c.IsRunning() { + return nil + } + + pullCtx, cancel := context.WithCancel(ctx) + c.pullCancel = cancel + c.pullDone = make(chan struct{}) + + go func() { + defer close(c.pullDone) + c.startPullLoop(pullCtx) + }() + + // Start Drive file cleanup goroutine if retention is configured. + if c.fileRetentionDays > 0 { + cleanupCtx, cleanupCancel := context.WithCancel(ctx) + c.cleanupCancel = cleanupCancel + go c.startDriveCleanupLoop(cleanupCtx) + } + + c.SetRunning(true) + slog.Info("googlechat: channel started", + "name", c.Name(), + "project", c.projectID, + "subscription", c.subscriptionID) + return nil +} + +// Stop gracefully shuts down the channel. +func (c *Channel) Stop(ctx context.Context) error { + if !c.IsRunning() { + return nil + } + + c.SetRunning(false) + + if c.cleanupCancel != nil { + c.cleanupCancel() + } + if c.pullCancel != nil { + c.pullCancel() + } + + // Wait for pull loop to drain (with timeout). + if c.pullDone != nil { + select { + case <-c.pullDone: + case <-time.After(shutdownDrainTimeout): + slog.Warn("googlechat: shutdown drain timeout exceeded") + } + } + + slog.Info("googlechat: channel stopped", "name", c.Name()) + return nil +} + +// SetPendingCompaction implements channels.PendingCompactable. +func (c *Channel) SetPendingCompaction(cfg *channels.CompactionConfig) { + if c.groupHistory != nil { + c.groupHistory.SetCompactionConfig(cfg) + } +} + +// BlockReplyEnabled implements channels.BlockReplyChannel. +func (c *Channel) BlockReplyEnabled() *bool { + return c.blockReply +} diff --git a/internal/channels/googlechat/media.go b/internal/channels/googlechat/media.go new file mode 100644 index 00000000..92ccd98c --- /dev/null +++ b/internal/channels/googlechat/media.go @@ -0,0 +1,290 @@ +package googlechat + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "mime/multipart" + "net/http" + "net/textproto" + "os" + "path/filepath" + "strings" + "time" + + "github.com/google/uuid" +) + +// downloadAttachment downloads a Chat API attachment to a temp file. +func (c *Channel) downloadAttachment(ctx context.Context, att chatAttachment) (string, error) { + if att.ResourceName == "" { + return "", fmt.Errorf("attachment has no resourceName") + } + + token, err := c.auth.Token(ctx) + if err != nil { + return "", err + } + + url := fmt.Sprintf("%s/media/%s?alt=media", chatAPIBase, att.ResourceName) + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return "", err + } + req.Header.Set("Authorization", "Bearer "+token) + + resp, err := c.httpClient.Do(req) + if err != nil { + return "", fmt.Errorf("download attachment: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("download attachment %d: %s", resp.StatusCode, string(body)) + } + + // Check size limit. + if c.mediaMaxBytes > 0 && resp.ContentLength > c.mediaMaxBytes { + return "", fmt.Errorf("attachment too large: %d bytes (max %d)", resp.ContentLength, c.mediaMaxBytes) + } + + // Determine extension from content type. + ext := extensionFromMIME(att.ContentType) + tmpPath := filepath.Join(os.TempDir(), uuid.New().String()+ext) + + f, err := os.Create(tmpPath) + if err != nil { + return "", err + } + defer f.Close() + + // Limit read to mediaMaxBytes. + reader := io.Reader(resp.Body) + if c.mediaMaxBytes > 0 { + reader = io.LimitReader(resp.Body, c.mediaMaxBytes+1) + } + n, err := io.Copy(f, reader) + if err != nil { + os.Remove(tmpPath) + return "", err + } + if c.mediaMaxBytes > 0 && n > c.mediaMaxBytes { + os.Remove(tmpPath) + return "", fmt.Errorf("attachment exceeded max size during download") + } + + slog.Debug("googlechat: attachment downloaded", "path", tmpPath, "size", n, "type", att.ContentType) + return tmpPath, nil +} + +// driveFileRecord tracks uploaded Drive files for retention cleanup. +type driveFileRecord struct { + FileID string + CreatedAt time.Time +} + +// startDriveCleanupLoop periodically deletes expired Drive files. +func (c *Channel) startDriveCleanupLoop(ctx context.Context) { + ticker := time.NewTicker(1 * time.Hour) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + c.cleanupExpiredDriveFiles(ctx) + } + } +} + +// cleanupExpiredDriveFiles deletes Drive files older than fileRetentionDays. +func (c *Channel) cleanupExpiredDriveFiles(ctx context.Context) { + c.driveFilesMu.Lock() + defer c.driveFilesMu.Unlock() + + cutoff := time.Now().AddDate(0, 0, -c.fileRetentionDays) + var remaining []driveFileRecord + for _, f := range c.driveFiles { + if f.CreatedAt.Before(cutoff) { + if err := c.deleteDriveFile(ctx, f.FileID); err != nil { + slog.Warn("googlechat: failed to delete expired drive file", "file_id", f.FileID, "error", err) + remaining = append(remaining, f) // retry next cycle + } else { + slog.Debug("googlechat: deleted expired drive file", "file_id", f.FileID) + } + } else { + remaining = append(remaining, f) + } + } + c.driveFiles = remaining +} + +// deleteDriveFile deletes a file from Google Drive. +func (c *Channel) deleteDriveFile(ctx context.Context, fileID string) error { + token, err := c.auth.Token(ctx) + if err != nil { + return err + } + + url := fmt.Sprintf("%s/files/%s", driveAPIBase, fileID) + req, err := http.NewRequestWithContext(ctx, "DELETE", url, nil) + if err != nil { + return err + } + req.Header.Set("Authorization", "Bearer "+token) + + resp, err := c.httpClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusNotFound { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("delete drive file %d: %s", resp.StatusCode, string(body)) + } + return nil +} + +// extensionFromMIME returns a file extension for common MIME types. +func extensionFromMIME(mime string) string { + switch { + case strings.HasPrefix(mime, "image/png"): + return ".png" + case strings.HasPrefix(mime, "image/jpeg"): + return ".jpg" + case strings.HasPrefix(mime, "image/gif"): + return ".gif" + case strings.HasPrefix(mime, "image/webp"): + return ".webp" + case strings.HasPrefix(mime, "application/pdf"): + return ".pdf" + case strings.HasPrefix(mime, "text/plain"): + return ".txt" + case strings.HasPrefix(mime, "text/markdown"): + return ".md" + default: + return "" + } +} + +// uploadToDrive uploads a file to Google Drive and returns the file ID and web link. +func (c *Channel) uploadToDrive(ctx context.Context, localPath string, fileName string, mimeType string) (fileID string, webLink string, err error) { + token, err := c.auth.Token(ctx) + if err != nil { + return "", "", err + } + + f, err := os.Open(localPath) + if err != nil { + return "", "", err + } + defer f.Close() + + // Build multipart upload body. + pr, pw := io.Pipe() + writer := multipart.NewWriter(pw) + + go func() { + defer pw.Close() + defer writer.Close() + + // Part 1: metadata + metaHeader := make(textproto.MIMEHeader) + metaHeader.Set("Content-Type", "application/json; charset=UTF-8") + metaPart, _ := writer.CreatePart(metaHeader) + json.NewEncoder(metaPart).Encode(map[string]string{ + "name": fileName, + "mimeType": mimeType, + }) + + // Part 2: file content + fileHeader := make(textproto.MIMEHeader) + fileHeader.Set("Content-Type", mimeType) + filePart, _ := writer.CreatePart(fileHeader) + io.Copy(filePart, f) + }() + + url := driveUploadBase + "/files?uploadType=multipart&fields=id,webViewLink" + req, err := http.NewRequestWithContext(ctx, "POST", url, pr) + if err != nil { + return "", "", err + } + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "multipart/related; boundary="+writer.Boundary()) + + resp, err := c.httpClient.Do(req) + if err != nil { + return "", "", fmt.Errorf("drive upload: %w", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return "", "", fmt.Errorf("drive upload %d: %s", resp.StatusCode, string(body)) + } + + var result struct { + ID string `json:"id"` + WebViewLink string `json:"webViewLink"` + } + if err := json.Unmarshal(body, &result); err != nil { + return "", "", fmt.Errorf("parse drive response: %w", err) + } + + // Set permissions. + if err := c.setDrivePermission(ctx, result.ID); err != nil { + slog.Warn("googlechat: failed to set drive permission", "file_id", result.ID, "error", err) + } + + // Track for retention cleanup. + if c.fileRetentionDays > 0 { + c.driveFilesMu.Lock() + c.driveFiles = append(c.driveFiles, driveFileRecord{FileID: result.ID, CreatedAt: time.Now()}) + c.driveFilesMu.Unlock() + } + + return result.ID, result.WebViewLink, nil +} + +// setDrivePermission sets the sharing permission on a Drive file. +func (c *Channel) setDrivePermission(ctx context.Context, fileID string) error { + token, err := c.auth.Token(ctx) + if err != nil { + return err + } + + var perm map[string]string + switch c.drivePermission { + case "anyone": + perm = map[string]string{"type": "anyone", "role": "reader"} + default: // "domain" + perm = map[string]string{"type": "domain", "role": "reader", "domain": c.driveDomain} + } + + body, _ := json.Marshal(perm) + url := fmt.Sprintf("%s/files/%s/permissions", driveAPIBase, fileID) + + req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(string(body))) + if err != nil { + return err + } + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(resp.Body) + return fmt.Errorf("set permission %d: %s", resp.StatusCode, string(b)) + } + return nil +} diff --git a/internal/channels/googlechat/pubsub.go b/internal/channels/googlechat/pubsub.go new file mode 100644 index 00000000..d45f623c --- /dev/null +++ b/internal/channels/googlechat/pubsub.go @@ -0,0 +1,376 @@ +package googlechat + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "strings" + "sync" + "time" +) + +// chatEvent is the parsed representation of a Google Chat event from Pub/Sub. +type chatEvent struct { + Type string // MESSAGE, ADDED_TO_SPACE, REMOVED_FROM_SPACE, etc. + SenderID string // users/{userId} + SenderName string // display name + SpaceID string // spaces/{spaceId} + SpaceType string // DM, SPACE, ROOM + PeerKind string // "direct" or "group" + Text string // message text + MessageName string // spaces/{spaceId}/messages/{messageId} + ThreadName string // spaces/{spaceId}/threads/{threadId} + Attachments []chatAttachment // file attachments +} + +// chatAttachment represents a file attachment in a Google Chat message. +type chatAttachment struct { + Name string // attachment resource name + ContentType string + ResourceName string // for download via media API +} + +// parseEvent parses a base64-encoded Pub/Sub message data into a chatEvent. +func parseEvent(encodedData string) (*chatEvent, error) { + data, err := base64.StdEncoding.DecodeString(encodedData) + if err != nil { + return nil, fmt.Errorf("base64 decode: %w", err) + } + if len(data) == 0 { + return nil, fmt.Errorf("empty event data") + } + + var raw struct { + Type string `json:"type"` + Message struct { + Name string `json:"name"` + Text string `json:"text"` + Sender struct { + Name string `json:"name"` + DisplayName string `json:"displayName"` + } `json:"sender"` + Thread struct { + Name string `json:"name"` + } `json:"thread"` + Attachment []struct { + Name string `json:"name"` + ContentType string `json:"contentType"` + AttachmentDataRef struct { + ResourceName string `json:"resourceName"` + } `json:"attachmentDataRef"` + } `json:"attachment"` + } `json:"message"` + Space struct { + Name string `json:"name"` + Type string `json:"type"` + } `json:"space"` + User struct { + Name string `json:"name"` + } `json:"user"` + } + + if err := json.Unmarshal(data, &raw); err != nil { + return nil, fmt.Errorf("parse chat event: %w", err) + } + + evt := &chatEvent{ + Type: raw.Type, + SpaceID: raw.Space.Name, + SpaceType: raw.Space.Type, + } + + // Determine peer kind + switch raw.Space.Type { + case "DM": + evt.PeerKind = "direct" + default: // SPACE, ROOM + evt.PeerKind = "group" + } + + // Extract sender + if raw.Type == "MESSAGE" { + if raw.Message.Sender.Name == "" { + return nil, fmt.Errorf("MESSAGE event missing sender") + } + evt.SenderID = raw.Message.Sender.Name + evt.SenderName = raw.Message.Sender.DisplayName + evt.Text = raw.Message.Text + evt.MessageName = raw.Message.Name + evt.ThreadName = raw.Message.Thread.Name + + // Parse attachments + for _, att := range raw.Message.Attachment { + evt.Attachments = append(evt.Attachments, chatAttachment{ + Name: att.Name, + ContentType: att.ContentType, + ResourceName: att.AttachmentDataRef.ResourceName, + }) + } + } else if raw.Type == "ADDED_TO_SPACE" || raw.Type == "REMOVED_FROM_SPACE" { + evt.SenderID = raw.User.Name + } + + return evt, nil +} + +// dedupCache is a thread-safe cache for Pub/Sub message deduplication. +type dedupCache struct { + mu sync.Mutex + entries map[string]time.Time + ttl time.Duration +} + +func newDedupCache(ttl time.Duration) *dedupCache { + return &dedupCache{ + entries: make(map[string]time.Time), + ttl: ttl, + } +} + +// seen returns true if the messageID was already processed. +func (d *dedupCache) seen(messageID string) bool { + d.mu.Lock() + defer d.mu.Unlock() + + if t, ok := d.entries[messageID]; ok { + if time.Since(t) < d.ttl { + return true + } + delete(d.entries, messageID) + } + return false +} + +// add marks a messageID as processed. +func (d *dedupCache) add(messageID string) { + d.mu.Lock() + defer d.mu.Unlock() + d.entries[messageID] = time.Now() + + // Periodic cleanup of expired entries (every 100 adds). + if len(d.entries)%100 == 0 { + now := time.Now() + for k, t := range d.entries { + if now.Sub(t) > d.ttl { + delete(d.entries, k) + } + } + } +} + +// pubsubPullResponse is the response from Pub/Sub pull API. +type pubsubPullResponse struct { + ReceivedMessages []struct { + AckID string `json:"ackId"` + Message struct { + Data string `json:"data"` + MessageID string `json:"messageId"` + } `json:"message"` + } `json:"receivedMessages"` +} + +// pullMessages performs a single Pub/Sub pull request and returns received messages. +func pullMessages(ctx context.Context, auth *ServiceAccountAuth, httpClient *http.Client, projectID, subscriptionID string, maxMessages int) (*pubsubPullResponse, error) { + token, err := auth.Token(ctx) + if err != nil { + return nil, fmt.Errorf("get token: %w", err) + } + + url := fmt.Sprintf("%s/projects/%s/subscriptions/%s:pull", pubsubAPIBase, projectID, subscriptionID) + body := fmt.Sprintf(`{"maxMessages":%d}`, maxMessages) + + req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(body)) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/json") + + resp, err := httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("pubsub pull: %w", err) + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + + if resp.StatusCode == http.StatusOK && len(respBody) <= 2 { + // Empty response "{}" — no messages + return &pubsubPullResponse{}, nil + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("pubsub pull %d: %s", resp.StatusCode, string(respBody)) + } + + var result pubsubPullResponse + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("parse pull response: %w", err) + } + return &result, nil +} + +// ackMessages acknowledges received Pub/Sub messages. +func ackMessages(ctx context.Context, auth *ServiceAccountAuth, httpClient *http.Client, projectID, subscriptionID string, ackIDs []string) error { + if len(ackIDs) == 0 { + return nil + } + + token, err := auth.Token(ctx) + if err != nil { + return fmt.Errorf("get token: %w", err) + } + + url := fmt.Sprintf("%s/projects/%s/subscriptions/%s:acknowledge", pubsubAPIBase, projectID, subscriptionID) + + ackBody := struct { + AckIDs []string `json:"ackIds"` + }{AckIDs: ackIDs} + bodyBytes, _ := json.Marshal(ackBody) + + req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(string(bodyBytes))) + if err != nil { + return err + } + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/json") + + resp, err := httpClient.Do(req) + if err != nil { + return fmt.Errorf("pubsub ack: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("pubsub ack %d: %s", resp.StatusCode, string(body)) + } + return nil +} + +// startPullLoop runs the Pub/Sub pull loop. Blocks until ctx is cancelled. +func (c *Channel) startPullLoop(ctx context.Context) { + ticker := time.NewTicker(c.pullInterval) + defer ticker.Stop() + + slog.Info("googlechat: pubsub pull loop started", + "project", c.projectID, "subscription", c.subscriptionID, + "interval", c.pullInterval) + + for { + select { + case <-ctx.Done(): + slog.Info("googlechat: pubsub pull loop stopped") + return + case <-ticker.C: + c.doPull(ctx) + } + } +} + +// doPull performs a single pull cycle. +func (c *Channel) doPull(ctx context.Context) { + resp, err := pullMessages(ctx, c.auth, c.httpClient, c.projectID, c.subscriptionID, defaultPullMaxMessages) + if err != nil { + if ctx.Err() != nil { + return // context cancelled, normal shutdown + } + slog.Warn("googlechat: pubsub pull failed", "error", err) + return + } + + if len(resp.ReceivedMessages) == 0 { + return + } + + var ackIDs []string + for _, rm := range resp.ReceivedMessages { + ackIDs = append(ackIDs, rm.AckID) + + // Dedup check + if c.dedup.seen(rm.Message.MessageID) { + slog.Debug("googlechat: duplicate pubsub message, skipping", "message_id", rm.Message.MessageID) + continue + } + c.dedup.add(rm.Message.MessageID) + + evt, err := parseEvent(rm.Message.Data) + if err != nil { + slog.Warn("googlechat: malformed event, acking anyway", "error", err, "message_id", rm.Message.MessageID) + continue + } + + c.handleEvent(ctx, evt) + } + + // Ack all messages (including malformed ones to prevent infinite re-delivery). + if err := ackMessages(ctx, c.auth, c.httpClient, c.projectID, c.subscriptionID, ackIDs); err != nil { + slog.Warn("googlechat: ack failed", "error", err) + } +} + +// handleEvent dispatches a parsed chat event. +func (c *Channel) handleEvent(ctx context.Context, evt *chatEvent) { + // Filter bot self-messages. + if c.botUser != "" && evt.SenderID == c.botUser { + return + } + + switch evt.Type { + case "MESSAGE": + c.handleMessage(ctx, evt) + case "ADDED_TO_SPACE": + slog.Info("googlechat: added to space", "space", evt.SpaceID, "by", evt.SenderID) + case "REMOVED_FROM_SPACE": + slog.Info("googlechat: removed from space", "space", evt.SpaceID) + default: + slog.Debug("googlechat: ignoring event", "type", evt.Type, "space", evt.SpaceID) + } +} + +// handleMessage processes an inbound MESSAGE event. +func (c *Channel) handleMessage(ctx context.Context, evt *chatEvent) { + // Skip whitespace-only messages. + text := strings.TrimSpace(evt.Text) + if text == "" && len(evt.Attachments) == 0 { + return + } + + // Check DM/Group policy. + if !c.BaseChannel.CheckPolicy(evt.PeerKind, c.dmPolicy, c.groupPolicy, evt.SenderID) { + slog.Debug("googlechat: message rejected by policy", + "sender", evt.SenderID, "peer_kind", evt.PeerKind) + return + } + + // Store thread name for outbound routing (groups). + if evt.ThreadName != "" && evt.PeerKind == "group" { + threadKey := evt.SpaceID + ":" + evt.SenderID + c.threadIDs.Store(threadKey, evt.ThreadName) + } + + // Download attachments. + var mediaPaths []string + for _, att := range evt.Attachments { + path, err := c.downloadAttachment(ctx, att) + if err != nil { + slog.Warn("googlechat: attachment download failed", "error", err) + continue + } + mediaPaths = append(mediaPaths, path) + } + + metadata := map[string]string{ + "sender_name": evt.SenderName, + "message_name": evt.MessageName, + } + if evt.ThreadName != "" { + metadata["thread_name"] = evt.ThreadName + } + + c.BaseChannel.HandleMessage(evt.SenderID, evt.SpaceID, text, mediaPaths, metadata, evt.PeerKind) +} diff --git a/internal/channels/googlechat/pubsub_test.go b/internal/channels/googlechat/pubsub_test.go new file mode 100644 index 00000000..ea4dfcd0 --- /dev/null +++ b/internal/channels/googlechat/pubsub_test.go @@ -0,0 +1,205 @@ +package googlechat + +import ( + "encoding/base64" + "encoding/json" + "testing" +) + +func TestParseEvent_Message(t *testing.T) { + chatEvent := map[string]any{ + "type": "MESSAGE", + "message": map[string]any{ + "name": "spaces/AAA/messages/BBB", + "text": "hello bot", + "sender": map[string]any{ + "name": "users/12345", + "displayName": "Test User", + }, + "thread": map[string]any{ + "name": "spaces/AAA/threads/CCC", + }, + }, + "space": map[string]any{ + "name": "spaces/AAA", + "type": "DM", + }, + } + data, _ := json.Marshal(chatEvent) + encoded := base64.StdEncoding.EncodeToString(data) + + evt, err := parseEvent(encoded) + if err != nil { + t.Fatal(err) + } + if evt.Type != "MESSAGE" { + t.Errorf("type = %q, want MESSAGE", evt.Type) + } + if evt.SenderID != "users/12345" { + t.Errorf("senderID = %q, want users/12345", evt.SenderID) + } + if evt.SpaceID != "spaces/AAA" { + t.Errorf("spaceID = %q, want spaces/AAA", evt.SpaceID) + } + if evt.Text != "hello bot" { + t.Errorf("text = %q, want 'hello bot'", evt.Text) + } + if evt.PeerKind != "direct" { + t.Errorf("peerKind = %q, want direct", evt.PeerKind) + } + if evt.ThreadName != "spaces/AAA/threads/CCC" { + t.Errorf("threadName = %q, want spaces/AAA/threads/CCC", evt.ThreadName) + } +} + +func TestParseEvent_GroupSpace(t *testing.T) { + chatEvent := map[string]any{ + "type": "MESSAGE", + "message": map[string]any{ + "text": "hey", + "sender": map[string]any{ + "name": "users/999", + }, + }, + "space": map[string]any{ + "name": "spaces/GGG", + "type": "SPACE", + }, + } + data, _ := json.Marshal(chatEvent) + encoded := base64.StdEncoding.EncodeToString(data) + + evt, err := parseEvent(encoded) + if err != nil { + t.Fatal(err) + } + if evt.PeerKind != "group" { + t.Errorf("peerKind = %q, want group", evt.PeerKind) + } +} + +func TestParseEvent_AddedToSpace(t *testing.T) { + chatEvent := map[string]any{ + "type": "ADDED_TO_SPACE", + "space": map[string]any{ + "name": "spaces/AAA", + "type": "DM", + }, + "user": map[string]any{ + "name": "users/12345", + }, + } + data, _ := json.Marshal(chatEvent) + encoded := base64.StdEncoding.EncodeToString(data) + + evt, err := parseEvent(encoded) + if err != nil { + t.Fatal(err) + } + if evt.Type != "ADDED_TO_SPACE" { + t.Errorf("type = %q, want ADDED_TO_SPACE", evt.Type) + } +} + +func TestParseEvent_MalformedJSON(t *testing.T) { + encoded := base64.StdEncoding.EncodeToString([]byte("{bad json")) + _, err := parseEvent(encoded) + if err == nil { + t.Fatal("expected error for malformed JSON") + } +} + +func TestParseEvent_EmptyData(t *testing.T) { + encoded := base64.StdEncoding.EncodeToString([]byte("")) + _, err := parseEvent(encoded) + if err == nil { + t.Fatal("expected error for empty data") + } +} + +func TestParseEvent_MissingSender(t *testing.T) { + chatEvent := map[string]any{ + "type": "MESSAGE", + "message": map[string]any{"text": "hello"}, + "space": map[string]any{"name": "spaces/AAA", "type": "DM"}, + } + data, _ := json.Marshal(chatEvent) + encoded := base64.StdEncoding.EncodeToString(data) + + _, err := parseEvent(encoded) + if err == nil { + t.Fatal("expected error for missing sender") + } +} + +func TestParseEvent_BotSelfFilter(t *testing.T) { + chatEvent := map[string]any{ + "type": "MESSAGE", + "message": map[string]any{ + "text": "bot reply", + "sender": map[string]any{"name": "users/BOT123"}, + }, + "space": map[string]any{"name": "spaces/AAA", "type": "DM"}, + } + data, _ := json.Marshal(chatEvent) + encoded := base64.StdEncoding.EncodeToString(data) + + evt, err := parseEvent(encoded) + if err != nil { + t.Fatal(err) + } + if evt.SenderID != "users/BOT123" { + t.Errorf("senderID = %q", evt.SenderID) + } +} + +func TestParseEvent_WithAttachment(t *testing.T) { + chatEvent := map[string]any{ + "type": "MESSAGE", + "message": map[string]any{ + "text": "", + "sender": map[string]any{"name": "users/12345"}, + "attachment": []any{ + map[string]any{ + "name": "spaces/AAA/messages/BBB/attachments/CCC", + "contentType": "image/png", + "attachmentDataRef": map[string]any{ + "resourceName": "spaces/AAA/attachments/CCC", + }, + }, + }, + }, + "space": map[string]any{"name": "spaces/AAA", "type": "DM"}, + } + data, _ := json.Marshal(chatEvent) + encoded := base64.StdEncoding.EncodeToString(data) + + evt, err := parseEvent(encoded) + if err != nil { + t.Fatal(err) + } + if len(evt.Attachments) != 1 { + t.Fatalf("attachments = %d, want 1", len(evt.Attachments)) + } + if evt.Attachments[0].ResourceName != "spaces/AAA/attachments/CCC" { + t.Errorf("resourceName = %q", evt.Attachments[0].ResourceName) + } +} + +func TestDedupCache(t *testing.T) { + cache := newDedupCache(dedupTTL) + + if cache.seen("msg1") { + t.Error("msg1 should not be seen yet") + } + cache.add("msg1") + if !cache.seen("msg1") { + t.Error("msg1 should be seen after add") + } + if !cache.seen("msg1") { + t.Error("msg1 should still be seen") + } + if cache.seen("msg2") { + t.Error("msg2 should not be seen") + } +} diff --git a/internal/channels/googlechat/send.go b/internal/channels/googlechat/send.go new file mode 100644 index 00000000..9f5b2af6 --- /dev/null +++ b/internal/channels/googlechat/send.go @@ -0,0 +1,427 @@ +package googlechat + +import ( + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "math" + "net/http" + "os" + "path/filepath" + "strings" + "time" + + "github.com/google/uuid" + "github.com/nextlevelbuilder/goclaw/internal/bus" +) + +// sendTextMessage sends a plain text message to a Google Chat space. +func sendTextMessage(ctx context.Context, apiBase, token string, httpClient *http.Client, msg bus.OutboundMessage, threadName, replyOption string) error { + _, err := sendTextMessageWithResponse(ctx, apiBase, token, httpClient, msg, threadName, replyOption) + return err +} + +// sendTextMessageWithResponse sends a text message and returns the API response (for thread chaining). +func sendTextMessageWithResponse(ctx context.Context, apiBase, token string, httpClient *http.Client, msg bus.OutboundMessage, threadName, replyOption string) (*chatMessageResponse, error) { + text := strings.TrimSpace(msg.Content) + if text == "" { + return nil, nil + } + + body := map[string]any{ + "text": markdownToGoogleChat(text), + } + if threadName != "" { + body["thread"] = map[string]string{"name": threadName} + } + + return postChatMessage(ctx, apiBase, token, httpClient, msg.ChatID, body, replyOption) +} + +// sendCardMessage sends a Card V2 message. +func sendCardMessage(ctx context.Context, apiBase, token string, httpClient *http.Client, chatID string, card map[string]any, threadName, replyOption string) error { + if threadName != "" { + card["thread"] = map[string]string{"name": threadName} + } + _, err := postChatMessage(ctx, apiBase, token, httpClient, chatID, card, replyOption) + return err +} + +// chatMessageResponse is the response from Chat API message operations. +type chatMessageResponse struct { + Name string `json:"name"` // spaces/{space}/messages/{message} + Thread struct { + Name string `json:"name"` // spaces/{space}/threads/{thread} + } `json:"thread"` +} + +// postChatMessage sends a message to the Chat API with retry logic. +func postChatMessage(ctx context.Context, apiBase, token string, httpClient *http.Client, spaceID string, body map[string]any, replyOption string) (*chatMessageResponse, error) { + bodyBytes, err := json.Marshal(body) + if err != nil { + return nil, err + } + + url := fmt.Sprintf("%s/%s/messages", apiBase, spaceID) + if replyOption != "" { + url += "?messageReplyOption=" + replyOption + } + + var result chatMessageResponse + err = retrySend(ctx, httpClient, func() (*http.Response, error) { + req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(string(bodyBytes))) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/json") + return httpClient.Do(req) + }, &result) + if err != nil { + return nil, err + } + return &result, nil +} + +// editMessage edits an existing message. +func editMessage(ctx context.Context, apiBase, token string, httpClient *http.Client, messageName string, text string) error { + body := map[string]any{ + "text": text, + } + bodyBytes, _ := json.Marshal(body) + + url := fmt.Sprintf("%s/%s?updateMask=text", apiBase, messageName) + + return retrySend(ctx, httpClient, func() (*http.Response, error) { + req, err := http.NewRequestWithContext(ctx, "PATCH", url, strings.NewReader(string(bodyBytes))) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/json") + return httpClient.Do(req) + }) +} + +// deleteMessage deletes a message. +func deleteMessage(ctx context.Context, apiBase, token string, httpClient *http.Client, messageName string) error { + url := fmt.Sprintf("%s/%s", apiBase, messageName) + + req, err := http.NewRequestWithContext(ctx, "DELETE", url, nil) + if err != nil { + return err + } + req.Header.Set("Authorization", "Bearer "+token) + + resp, err := httpClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusNoContent { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("delete message %d: %s", resp.StatusCode, string(body)) + } + return nil +} + +// retrySend retries an HTTP request with exponential backoff on 429/5xx. +func retrySend(ctx context.Context, httpClient *http.Client, doReq func() (*http.Response, error), result ...any) error { + delay := retrySendBaseDelay + for attempt := 0; attempt < retrySendMaxAttempts; attempt++ { + resp, err := doReq() + if err != nil { + if ctx.Err() != nil { + return ctx.Err() + } + return err + } + + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + if len(result) > 0 && result[0] != nil { + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + json.Unmarshal(body, result[0]) + } else { + resp.Body.Close() + } + return nil + } + + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + + if resp.StatusCode == 429 || resp.StatusCode >= 500 { + if attempt < retrySendMaxAttempts-1 { + slog.Debug("googlechat: retrying send", "status", resp.StatusCode, "attempt", attempt+1, "delay", delay) + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(delay): + } + delay = time.Duration(math.Min(float64(delay*2), float64(retrySendMaxDelay))) + continue + } + } + + return fmt.Errorf("chat API %d: %s", resp.StatusCode, string(body)) + } + return fmt.Errorf("chat API: max retries exceeded") +} + +// buildCardMessage creates a Cards V2 message from content with tables/code. +func buildCardMessage(content string) map[string]any { + if !detectStructuredContent(content) { + return nil + } + + var sections []map[string]any + lines := strings.Split(content, "\n") + var currentWidgets []map[string]any + var inTable bool + var tableRows []string + + flushTable := func() { + if len(tableRows) > 0 { + tableText := strings.Join(tableRows, "\n") + currentWidgets = append(currentWidgets, map[string]any{ + "textParagraph": map[string]string{ + "text": "
" + tableText + "
", + }, + }) + tableRows = nil + } + } + + for _, line := range lines { + trimmed := strings.TrimSpace(line) + + // Table row detection. + if strings.HasPrefix(trimmed, "|") && strings.HasSuffix(trimmed, "|") { + inTable = true + if isSeparatorRow(trimmed) { + continue + } + tableRows = append(tableRows, trimmed) + continue + } + + if inTable { + flushTable() + inTable = false + } + + if strings.HasPrefix(trimmed, "# ") { + if len(currentWidgets) > 0 { + sections = append(sections, map[string]any{"widgets": currentWidgets}) + currentWidgets = nil + } + continue + } + + if trimmed != "" { + currentWidgets = append(currentWidgets, map[string]any{ + "textParagraph": map[string]string{ + "text": markdownToGoogleChat(trimmed), + }, + }) + } + } + + flushTable() + if len(currentWidgets) > 0 { + sections = append(sections, map[string]any{"widgets": currentWidgets}) + } + + if len(sections) == 0 { + return nil + } + + title := "Response" + for _, line := range lines { + if strings.HasPrefix(strings.TrimSpace(line), "# ") { + title = strings.TrimPrefix(strings.TrimSpace(line), "# ") + break + } + } + + return map[string]any{ + "cardsV2": []map[string]any{{ + "card": map[string]any{ + "header": map[string]string{"title": title}, + "sections": sections, + }, + }}, + } +} + +// isSeparatorRow checks if a table row is a separator (e.g. |---|---|). +func isSeparatorRow(row string) bool { + inner := strings.Trim(row, "|") + for _, ch := range inner { + if ch != '-' && ch != ':' && ch != ' ' && ch != '|' { + return false + } + } + return true +} + +// Send implements the Channel interface for outbound messages. +func (c *Channel) Send(ctx context.Context, msg bus.OutboundMessage) error { + content := strings.TrimSpace(msg.Content) + if content == "" && len(msg.Media) == 0 { + return nil + } + + token, err := c.auth.Token(ctx) + if err != nil { + return err + } + + // Determine thread context. + peerKind := msg.Metadata["peer_kind"] + threadName := "" + replyOption := "" + if peerKind == "group" { + if tn, ok := msg.Metadata["thread_name"]; ok { + threadName = tn + } else { + senderID := msg.Metadata["sender_id"] + threadKey := msg.ChatID + ":" + senderID + if v, ok := c.threadIDs.Load(threadKey); ok { + threadName = v.(string) + } + } + replyOption = "REPLY_MESSAGE_FALLBACK_TO_NEW_THREAD" + } + + // Check for placeholder edit (Thinking... → final response). + if placeholderName, ok := c.placeholders.Load(msg.ChatID); ok { + c.placeholders.Delete(msg.ChatID) + pName := placeholderName.(string) + + if len([]byte(content)) <= googleChatMaxMessageBytes && !detectStructuredContent(content) { + if err := editMessage(ctx, c.apiBase, token, c.httpClient, pName, markdownToGoogleChat(content)); err != nil { + slog.Warn("googlechat: placeholder edit failed, sending new", "error", err) + } else { + return nil + } + } + deleteMessage(ctx, c.apiBase, token, c.httpClient, pName) + } + + // Long-form content → file attachment. + if c.longFormThreshold > 0 && len(content) > c.longFormThreshold { + if err := c.sendLongForm(ctx, token, msg, content, threadName, replyOption); err != nil { + slog.Warn("googlechat: long-form send failed, falling back to chunks", "error", err) + } else { + return nil + } + } + + // Card message for structured content. + if card := buildCardMessage(content); card != nil { + return sendCardMessage(ctx, c.apiBase, token, c.httpClient, msg.ChatID, card, threadName, replyOption) + } + + // Chunked plain text. + chunks := chunkByBytes(content, googleChatMaxMessageBytes) + currentThread := threadName + for i, chunk := range chunks { + chunkMsg := msg + chunkMsg.Content = chunk + resp, err := sendTextMessageWithResponse(ctx, c.apiBase, token, c.httpClient, chunkMsg, currentThread, replyOption) + if err != nil { + return fmt.Errorf("send chunk %d/%d: %w", i+1, len(chunks), err) + } + if resp != nil && resp.Thread.Name != "" { + currentThread = resp.Thread.Name + } + } + + return nil +} + +// sendLongForm uploads content as a file and sends a summary message. +func (c *Channel) sendLongForm(ctx context.Context, token string, msg bus.OutboundMessage, content, threadName, replyOption string) error { + summary := extractSummary(content) + + ext := ".md" + if c.longFormFormat == "txt" { + ext = ".txt" + } + tmpPath := filepath.Join(os.TempDir(), uuid.New().String()+ext) + if err := os.WriteFile(tmpPath, []byte(content), 0644); err != nil { + return err + } + defer os.Remove(tmpPath) + + mimeType := "text/markdown" + if c.longFormFormat == "txt" { + mimeType = "text/plain" + } + _, webLink, err := c.uploadToDrive(ctx, tmpPath, "response"+ext, mimeType) + if err != nil { + return err + } + + summaryText := markdownToGoogleChat(summary) + "\n\n📎 " + webLink + body := map[string]any{ + "text": summaryText, + } + if threadName != "" { + body["thread"] = map[string]string{"name": threadName} + } + + _, err = postChatMessage(ctx, c.apiBase, token, c.httpClient, msg.ChatID, body, replyOption) + return err +} + +// sendPlaceholder sends a "Thinking..." placeholder message and stores its name. +func (c *Channel) sendPlaceholder(ctx context.Context, chatID, threadName, replyOption string) { + token, err := c.auth.Token(ctx) + if err != nil { + slog.Warn("googlechat: placeholder auth failed", "error", err) + return + } + + body := map[string]any{ + "text": "🤔 Thinking...", + } + if threadName != "" { + body["thread"] = map[string]string{"name": threadName} + } + + bodyBytes, _ := json.Marshal(body) + url := fmt.Sprintf("%s/%s/messages", c.apiBase, chatID) + if replyOption != "" { + url += "?messageReplyOption=" + replyOption + } + + req, err := http.NewRequestWithContext(ctx, "POST", url, strings.NewReader(string(bodyBytes))) + if err != nil { + return + } + req.Header.Set("Authorization", "Bearer "+token) + req.Header.Set("Content-Type", "application/json") + + resp, err := c.httpClient.Do(req) + if err != nil { + return + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusOK { + var result struct { + Name string `json:"name"` + } + respBody, _ := io.ReadAll(resp.Body) + if json.Unmarshal(respBody, &result) == nil && result.Name != "" { + c.placeholders.Store(chatID, result.Name) + slog.Debug("googlechat: placeholder sent", "chat_id", chatID, "name", result.Name) + } + } +} diff --git a/internal/channels/googlechat/send_test.go b/internal/channels/googlechat/send_test.go new file mode 100644 index 00000000..097fe160 --- /dev/null +++ b/internal/channels/googlechat/send_test.go @@ -0,0 +1,71 @@ +package googlechat + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/nextlevelbuilder/goclaw/internal/bus" +) + +// mockChatAPI creates an httptest server mimicking Google Chat API. +func mockChatAPI(t *testing.T, handler func(w http.ResponseWriter, r *http.Request)) (*httptest.Server, string) { + t.Helper() + ts := httptest.NewServer(http.HandlerFunc(handler)) + return ts, ts.URL +} + +func TestSendMessage_ShortDM(t *testing.T) { + var sentBody map[string]any + ts, baseURL := mockChatAPI(t, func(w http.ResponseWriter, r *http.Request) { + json.NewDecoder(r.Body).Decode(&sentBody) + json.NewEncoder(w).Encode(map[string]any{ + "name": "spaces/DM1/messages/123", + }) + }) + defer ts.Close() + + msg := bus.OutboundMessage{ + ChatID: "spaces/DM1", + Content: "Hello world", + Metadata: map[string]string{ + "peer_kind": "direct", + }, + } + + err := sendTextMessage(context.Background(), baseURL, "fake-token", &http.Client{}, msg, "", "") + if err != nil { + t.Fatal(err) + } + if sentBody["text"] == nil { + t.Error("expected text in sent body") + } +} + +func TestSendMessage_EmptyContent(t *testing.T) { + msg := bus.OutboundMessage{ + ChatID: "spaces/DM1", + Content: "", + } + + err := sendTextMessage(context.Background(), "http://unused", "fake", &http.Client{}, msg, "", "") + if err != nil { + t.Fatal("empty content should not error") + } +} + +func TestBuildCardMessage_Table(t *testing.T) { + content := "# Results\n\n| Name | Score |\n|---|---|\n| Alice | 95 |\n| Bob | 87 |" + card := buildCardMessage(content) + if card == nil { + t.Fatal("expected card for table content") + } + cardJSON, _ := json.Marshal(card) + s := string(cardJSON) + if !strings.Contains(s, "cardsV2") { + t.Error("card JSON should contain cardsV2") + } +} diff --git a/internal/channels/googlechat/stream.go b/internal/channels/googlechat/stream.go new file mode 100644 index 00000000..46e7010d --- /dev/null +++ b/internal/channels/googlechat/stream.go @@ -0,0 +1,7 @@ +package googlechat + +// Phase 2: StreamingChannel implementation for Google Chat. +// Uses PATCH /v1/{message} to edit messages progressively as LLM chunks arrive, +// similar to the Telegram DraftStream. +// +// Implementation deferred to phase 2. diff --git a/internal/channels/googlechat/webhook.go b/internal/channels/googlechat/webhook.go new file mode 100644 index 00000000..03fb6fe5 --- /dev/null +++ b/internal/channels/googlechat/webhook.go @@ -0,0 +1,7 @@ +package googlechat + +// Phase 2: HTTP webhook handler for Google Chat push events. +// When mode="webhook", the channel will register an HTTP handler via WebhookChannel +// interface instead of using Pub/Sub pull. +// +// Implementation deferred to phase 2. diff --git a/internal/channels/whatsapp/media.go b/internal/channels/whatsapp/media.go new file mode 100644 index 00000000..0f489422 --- /dev/null +++ b/internal/channels/whatsapp/media.go @@ -0,0 +1,237 @@ +package whatsapp + +import ( + "fmt" + "io" + "log/slog" + "net/http" + "os" + "path/filepath" + "strings" + "time" + + "github.com/nextlevelbuilder/goclaw/internal/channels/media" +) + +const ( + // maxMediaBytes is the max download size for WhatsApp media (20 MB). + maxMediaBytes int64 = 20 * 1024 * 1024 + + // downloadTimeout is the HTTP timeout for downloading media from URLs. + downloadTimeout = 30 * time.Second +) + +// resolveMedia processes raw media entries from the bridge. +// Each entry can be a URL (http/https) or a local file path. +// Returns MediaInfo list with local file paths and detected MIME types. +func (c *Channel) resolveMedia(rawMedia []any) []media.MediaInfo { + var results []media.MediaInfo + + for _, m := range rawMedia { + switch v := m.(type) { + case string: + info := c.resolveMediaEntry(v, "") + if info != nil { + results = append(results, *info) + } + + case map[string]any: + // Bridge may send structured media: {"url":"...","filename":"...","mimetype":"..."} + url, _ := v["url"].(string) + path, _ := v["path"].(string) + fileName, _ := v["filename"].(string) + mimeType, _ := v["mimetype"].(string) + + target := url + if target == "" { + target = path + } + if target == "" { + continue + } + + info := c.resolveMediaEntry(target, fileName) + if info != nil { + if mimeType != "" { + info.ContentType = mimeType + info.Type = media.MediaKindFromMime(mimeType) + } + if fileName != "" { + info.FileName = fileName + } + results = append(results, *info) + } + } + } + + return results +} + +// resolveMediaEntry handles a single media entry (URL or local path). +func (c *Channel) resolveMediaEntry(entry, fileName string) *media.MediaInfo { + if strings.HasPrefix(entry, "http://") || strings.HasPrefix(entry, "https://") { + return c.downloadMediaURL(entry, fileName) + } + return c.resolveLocalFile(entry, fileName) +} + +// downloadMediaURL downloads media from a URL and saves to a temp file. +func (c *Channel) downloadMediaURL(url, fileName string) *media.MediaInfo { + client := &http.Client{Timeout: downloadTimeout} + + resp, err := client.Get(url) + if err != nil { + slog.Warn("whatsapp media download failed", "url", url, "error", err) + return nil + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + slog.Warn("whatsapp media download non-200", "url", url, "status", resp.StatusCode) + return nil + } + + // Detect extension from Content-Type or URL + ext := extensionFromContentType(resp.Header.Get("Content-Type")) + if ext == "" && fileName != "" { + ext = filepath.Ext(fileName) + } + if ext == "" { + ext = extensionFromURL(url) + } + + f, err := os.CreateTemp("", "goclaw_wa_*"+ext) + if err != nil { + slog.Warn("whatsapp media temp file failed", "error", err) + return nil + } + defer f.Close() + + n, err := io.Copy(f, io.LimitReader(resp.Body, maxMediaBytes)) + if err != nil { + os.Remove(f.Name()) + slog.Warn("whatsapp media write failed", "error", err) + return nil + } + if n == 0 { + os.Remove(f.Name()) + slog.Warn("whatsapp media empty response", "url", url) + return nil + } + + ct := resp.Header.Get("Content-Type") + if ct == "" { + ct = media.DetectMIMEType(f.Name()) + } + kind := media.MediaKindFromMime(ct) + + if fileName == "" { + fileName = filepath.Base(f.Name()) + } + + slog.Debug("whatsapp media downloaded", "path", f.Name(), "size", n, "type", kind) + + return &media.MediaInfo{ + Type: kind, + FilePath: f.Name(), + ContentType: ct, + FileName: fileName, + FileSize: n, + } +} + +// resolveLocalFile validates a local file path from the bridge. +func (c *Channel) resolveLocalFile(path, fileName string) *media.MediaInfo { + info, err := os.Stat(path) + if err != nil { + slog.Warn("whatsapp media local file not found", "path", path, "error", err) + return nil + } + + if info.IsDir() { + return nil + } + + if info.Size() > maxMediaBytes { + slog.Warn("whatsapp media file too large", "path", path, "size", info.Size()) + return nil + } + + ct := media.DetectMIMEType(path) + kind := media.MediaKindFromMime(ct) + + if fileName == "" { + fileName = filepath.Base(path) + } + + slog.Debug("whatsapp media local file resolved", "path", path, "size", info.Size(), "type", kind) + + return &media.MediaInfo{ + Type: kind, + FilePath: path, + ContentType: ct, + FileName: fileName, + FileSize: info.Size(), + } +} + +// extensionFromContentType maps common Content-Type headers to file extensions. +func extensionFromContentType(ct string) string { + ct = strings.ToLower(ct) + switch { + case strings.Contains(ct, "image/jpeg"): + return ".jpg" + case strings.Contains(ct, "image/png"): + return ".png" + case strings.Contains(ct, "image/gif"): + return ".gif" + case strings.Contains(ct, "image/webp"): + return ".webp" + case strings.Contains(ct, "video/mp4"): + return ".mp4" + case strings.Contains(ct, "audio/ogg"): + return ".ogg" + case strings.Contains(ct, "audio/mpeg"): + return ".mp3" + case strings.Contains(ct, "application/pdf"): + return ".pdf" + case strings.Contains(ct, "application/vnd.openxmlformats"): + return ".docx" + default: + return "" + } +} + +// extensionFromURL extracts the file extension from a URL path. +func extensionFromURL(url string) string { + // Strip query string + if idx := strings.IndexByte(url, '?'); idx > 0 { + url = url[:idx] + } + ext := filepath.Ext(url) + if len(ext) > 6 { // sanity check + return "" + } + return ext +} + +// mediaInfoToPaths converts MediaInfo slice to string paths for HandleMessage. +func mediaInfoToPaths(infos []media.MediaInfo) []string { + paths := make([]string, 0, len(infos)) + for _, info := range infos { + paths = append(paths, info.FilePath) + } + return paths +} + +// mediaInfoToLogAttrs returns a summary string for logging. +func mediaInfoToLogAttrs(infos []media.MediaInfo) string { + if len(infos) == 0 { + return "none" + } + parts := make([]string, 0, len(infos)) + for _, info := range infos { + parts = append(parts, fmt.Sprintf("%s(%s)", info.Type, info.FileName)) + } + return strings.Join(parts, ", ") +} diff --git a/internal/channels/whatsapp/whatsapp.go b/internal/channels/whatsapp/whatsapp.go index 981298e3..14625008 100644 --- a/internal/channels/whatsapp/whatsapp.go +++ b/internal/channels/whatsapp/whatsapp.go @@ -239,20 +239,31 @@ func (c *Channel) handleIncomingMessage(msg map[string]any) { } content, _ := msg["content"].(string) - if content == "" { - content = "[empty message]" - } - var media []string - if mediaData, ok := msg["media"].([]any); ok { - media = make([]string, 0, len(mediaData)) - for _, m := range mediaData { - if path, ok := m.(string); ok { - media = append(media, path) - } + // Resolve media (download URLs, verify local files, detect MIME) + var mediaPaths []string + if mediaData, ok := msg["media"].([]any); ok && len(mediaData) > 0 { + resolved := c.resolveMedia(mediaData) + mediaPaths = mediaInfoToPaths(resolved) + + if len(resolved) > 0 { + slog.Info("whatsapp media resolved", + "sender_id", senderID, + "count", len(resolved), + "items", mediaInfoToLogAttrs(resolved), + ) + } + + // If message has media but no text, use a placeholder + if content == "" && len(resolved) > 0 { + content = "[media]" } } + if content == "" { + content = "[empty message]" + } + metadata := make(map[string]string) if messageID, ok := msg["id"].(string); ok { metadata["message_id"] = messageID @@ -265,9 +276,10 @@ func (c *Channel) handleIncomingMessage(msg map[string]any) { "sender_id", senderID, "chat_id", chatID, "preview", channels.Truncate(content, 50), + "media_count", len(mediaPaths), ) - c.HandleMessage(senderID, chatID, content, media, metadata, peerKind) + c.HandleMessage(senderID, chatID, content, mediaPaths, metadata, peerKind) } // checkGroupPolicy evaluates the group policy for a sender, with pairing support. diff --git a/internal/channels/zalo/zalo.go b/internal/channels/zalo/zalo.go index 2c5ce04a..642c50d0 100644 --- a/internal/channels/zalo/zalo.go +++ b/internal/channels/zalo/zalo.go @@ -89,7 +89,7 @@ func (c *Channel) Start(ctx context.Context) error { if err != nil { return fmt.Errorf("zalo getMe failed: %w", err) } - slog.Info("zalo bot connected", "bot_id", info.ID, "bot_name", info.Name) + slog.Info("zalo bot connected", "bot_id", info.ID, "bot_name", info.Label()) c.SetRunning(true) @@ -425,29 +425,41 @@ type zaloAPIResponse struct { } type zaloBotInfo struct { - ID string `json:"id"` - Name string `json:"name"` + ID string `json:"id"` + Name string `json:"account_name"` + DisplayName string `json:"display_name"` +} + +func (b *zaloBotInfo) Label() string { + if b.DisplayName != "" { + return b.DisplayName + } + return b.Name } type zaloMessage struct { - MessageID string `json:"message_id"` - Text string `json:"text"` - Photo string `json:"photo"` - PhotoURL string `json:"photo_url"` - Caption string `json:"caption"` - From zaloFrom `json:"from"` - Chat zaloChat `json:"chat"` - Date int64 `json:"date"` + MessageID string `json:"message_id"` + MessageType string `json:"message_type"` + Text string `json:"text"` + Photo string `json:"photo"` + PhotoURL string `json:"photo_url"` + Caption string `json:"caption"` + From zaloFrom `json:"from"` + Chat zaloChat `json:"chat"` + Date int64 `json:"date"` } type zaloFrom struct { - ID string `json:"id"` - Username string `json:"username"` + ID string `json:"id"` + Username string `json:"username"` + DisplayName string `json:"display_name"` + IsBot bool `json:"is_bot"` } type zaloChat struct { - ID string `json:"id"` - Type string `json:"type"` + ID string `json:"id"` + Type string `json:"type"` + ChatType string `json:"chat_type"` } type zaloUpdate struct { @@ -521,11 +533,28 @@ func (c *Channel) getUpdates(timeout int) ([]zaloUpdate, error) { return nil, err } + // Try array first var updates []zaloUpdate - if err := json.Unmarshal(result, &updates); err != nil { + if err := json.Unmarshal(result, &updates); err == nil { + return updates, nil + } + + // Try single object (Zalo Bot Platform returns one update at a time) + var single zaloUpdate + if err := json.Unmarshal(result, &single); err == nil && single.EventName != "" { + slog.Info("zalo update received", "event", single.EventName) + return []zaloUpdate{single}, nil + } + + // Try wrapped {"updates": [...]} + var wrapped struct { + Updates []zaloUpdate `json:"updates"` + } + if err := json.Unmarshal(result, &wrapped); err != nil { + slog.Warn("zalo getUpdates unknown format", "raw", string(result[:min(len(result), 500)])) return nil, fmt.Errorf("unmarshal updates: %w", err) } - return updates, nil + return wrapped.Updates, nil } func (c *Channel) sendMessage(chatID, text string) error { diff --git a/internal/config/config_channels.go b/internal/config/config_channels.go index 0d759a42..279f6d1b 100644 --- a/internal/config/config_channels.go +++ b/internal/config/config_channels.go @@ -20,6 +20,7 @@ type ChannelsConfig struct { Zalo ZaloConfig `json:"zalo"` ZaloPersonal ZaloPersonalConfig `json:"zalo_personal"` Feishu FeishuConfig `json:"feishu"` + GoogleChat GoogleChatConfig `json:"google_chat"` PendingCompaction *PendingCompactionConfig `json:"pending_compaction,omitempty"` // global pending message compaction settings } @@ -180,6 +181,28 @@ type FeishuConfig struct { VoiceAgentID string `json:"voice_agent_id,omitempty"` } +type GoogleChatConfig struct { + Enabled bool `json:"enabled"` + ServiceAccountFile string `json:"serviceAccountFile"` + Mode string `json:"mode"` // "pubsub" (phase 1) | "webhook" (phase 2) + ProjectID string `json:"projectId"` + SubscriptionID string `json:"subscriptionId"` + PullIntervalMs int `json:"pullIntervalMs,omitempty"` + BotUser string `json:"botUser,omitempty"` + DMPolicy string `json:"dm_policy,omitempty"` // "open" (default), "allowlist", "disabled" + GroupPolicy string `json:"group_policy,omitempty"` // "open" (default), "allowlist", "disabled" + RequireMention *bool `json:"require_mention,omitempty"` // require @bot mention in groups (default true) + AllowFrom FlexibleStringSlice `json:"allow_from,omitempty"` + HistoryLimit int `json:"history_limit,omitempty"` // max pending group messages (default 50, 0=disabled) + LongFormThreshold int `json:"longFormThreshold,omitempty"` + LongFormFormat string `json:"longFormFormat,omitempty"` // "md" (default) | "txt" + MediaMaxMB int `json:"mediaMaxMb,omitempty"` + FileRetentionDays int `json:"fileRetentionDays,omitempty"` // auto-delete Drive files (0 = keep forever) + DrivePermission string `json:"drivePermission,omitempty"` // "domain" (default) | "anyone" + DriveDomain string `json:"driveDomain,omitempty"` // domain for "domain" permission (default "vnpay.vn") + BlockReply *bool `json:"block_reply,omitempty"` +} + // ProvidersConfig maps provider name to its config. type ProvidersConfig struct { Anthropic ProviderConfig `json:"anthropic"` diff --git a/internal/gateway/methods/party.go b/internal/gateway/methods/party.go new file mode 100644 index 00000000..bac3a497 --- /dev/null +++ b/internal/gateway/methods/party.go @@ -0,0 +1,564 @@ +package methods + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "sort" + + "github.com/google/uuid" + + "github.com/nextlevelbuilder/goclaw/internal/bus" + "github.com/nextlevelbuilder/goclaw/internal/gateway" + "github.com/nextlevelbuilder/goclaw/internal/party" + "github.com/nextlevelbuilder/goclaw/internal/providers" + "github.com/nextlevelbuilder/goclaw/internal/store" + "github.com/nextlevelbuilder/goclaw/pkg/protocol" +) + +// PartyMethods handles party.* WebSocket RPC methods. +type PartyMethods struct { + partyStore store.PartyStore + agentStore store.AgentStore + providerReg *providers.Registry + msgBus *bus.MessageBus +} + +// NewPartyMethods creates a new PartyMethods handler. +func NewPartyMethods(partyStore store.PartyStore, agentStore store.AgentStore, providerReg *providers.Registry, msgBus *bus.MessageBus) *PartyMethods { + return &PartyMethods{ + partyStore: partyStore, + agentStore: agentStore, + providerReg: providerReg, + msgBus: msgBus, + } +} + +// Register registers all party.* methods on the router. +func (m *PartyMethods) Register(router *gateway.MethodRouter) { + router.Register(protocol.MethodPartyStart, m.handleStart) + router.Register(protocol.MethodPartyRound, m.handleRound) + router.Register(protocol.MethodPartyQuestion, m.handleQuestion) + router.Register(protocol.MethodPartyAddContext, m.handleAddContext) + router.Register(protocol.MethodPartySummary, m.handleSummary) + router.Register(protocol.MethodPartyExit, m.handleExit) + router.Register(protocol.MethodPartyList, m.handleList) +} + +// getEngine returns a party engine using the best available provider. +// Prefers providers with a non-empty DefaultModel (DB providers with settings), +// falling back to the first name alphabetically for deterministic selection. +func (m *PartyMethods) getEngine() (*party.Engine, error) { + names := m.providerReg.List() + if len(names) == 0 { + return nil, fmt.Errorf("no LLM providers available") + } + + // Prefer a provider with DefaultModel set (typically DB providers with settings.default_model) + var bestName string + for _, name := range names { + p, err := m.providerReg.Get(name) + if err != nil { + continue + } + if p.DefaultModel() != "" { + if bestName == "" || name < bestName { + bestName = name + } + } + } + // Fallback: pick first alphabetically for determinism + if bestName == "" { + sort.Strings(names) + bestName = names[0] + } + + provider, err := m.providerReg.Get(bestName) + if err != nil { + return nil, fmt.Errorf("provider %s: %w", bestName, err) + } + return party.NewEngine(m.partyStore, m.agentStore, provider), nil +} + +// emitterForClient creates an EventEmitter that broadcasts to all connected WS clients. +func (m *PartyMethods) emitterForClient(client *gateway.Client) party.EventEmitter { + return func(event protocol.EventFrame) { + client.SendEvent(event) + } +} + +type partyStartParams struct { + Topic string `json:"topic"` + TeamPreset string `json:"team_preset,omitempty"` + Personas []string `json:"personas,omitempty"` + ContextURLs []string `json:"context_urls,omitempty"` +} + +func (m *PartyMethods) handleStart(ctx context.Context, client *gateway.Client, req *protocol.RequestFrame) { + var params partyStartParams + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInvalidRequest, "invalid params")) + return + } + + if params.Topic == "" { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInvalidRequest, "topic is required")) + return + } + + // Resolve personas from preset or custom list + personaKeys := params.Personas + if params.TeamPreset != "" { + for _, preset := range party.PresetTeams() { + if preset.Key == params.TeamPreset { + personaKeys = preset.Personas + break + } + } + } + if len(personaKeys) == 0 { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInvalidRequest, "no personas selected")) + return + } + + engine, err := m.getEngine() + if err != nil { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInternal, err.Error())) + return + } + + // Load persona info from DB + personas, err := engine.LoadPersonas(ctx, personaKeys) + if err != nil { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInternal, err.Error())) + return + } + + // Marshal persona keys for storage + personasJSON, _ := json.Marshal(personaKeys) + + // Create session + sess := &store.PartySessionData{ + Topic: params.Topic, + TeamPreset: params.TeamPreset, + Status: store.PartyStatusDiscussing, + Mode: store.PartyModeStandard, + MaxRounds: 10, + UserID: client.UserID(), + Personas: personasJSON, + } + if err := m.partyStore.CreateSession(ctx, sess); err != nil { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInternal, err.Error())) + return + } + + slog.Info("party: session started", "session_id", sess.ID, "topic", params.Topic, "personas", len(personas)) + + // Emit started event + emit := m.emitterForClient(client) + personaInfos := make([]map[string]string, len(personas)) + for i, p := range personas { + personaInfos[i] = map[string]string{ + "agent_key": p.AgentKey, + "display_name": p.DisplayName, + "emoji": p.Emoji, + "movie_ref": p.MovieRef, + } + } + emit(*protocol.NewEvent(protocol.EventPartyStarted, map[string]any{ + "session_id": sess.ID, + "topic": params.Topic, + "personas": personaInfos, + })) + + // Generate introductions for each persona + for _, p := range personas { + intro := fmt.Sprintf("%s %s reporting in. Ready to discuss: %s", p.Emoji, p.DisplayName, params.Topic) + emit(*protocol.NewEvent(protocol.EventPartyPersonaIntro, map[string]any{ + "session_id": sess.ID, + "persona": p.AgentKey, + "emoji": p.Emoji, + "content": intro, + })) + } + + client.SendResponse(protocol.NewOKResponse(req.ID, map[string]any{ + "session_id": sess.ID, + "personas": personaInfos, + "status": sess.Status, + })) +} + +type partyRoundParams struct { + SessionID string `json:"session_id"` + Mode string `json:"mode,omitempty"` +} + +func (m *PartyMethods) handleRound(ctx context.Context, client *gateway.Client, req *protocol.RequestFrame) { + var params partyRoundParams + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInvalidRequest, "invalid params")) + return + } + + sessID, err := uuid.Parse(params.SessionID) + if err != nil { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInvalidRequest, "invalid session_id")) + return + } + + sess, err := m.partyStore.GetSession(ctx, sessID) + if err != nil { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInternal, "session not found")) + return + } + + if sess.Status != store.PartyStatusDiscussing { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInvalidRequest, "session is not in discussing state")) + return + } + + // Increment round + sess.Round++ + mode := params.Mode + if mode == "" { + mode = sess.Mode + } + + engine, err := m.getEngine() + if err != nil { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInternal, err.Error())) + return + } + + // Load personas + var personaKeys []string + json.Unmarshal(sess.Personas, &personaKeys) + personas, err := engine.LoadPersonas(ctx, personaKeys) + if err != nil { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInternal, err.Error())) + return + } + + emit := m.emitterForClient(client) + emit(*protocol.NewEvent(protocol.EventPartyRoundStarted, map[string]any{ + "session_id": sess.ID, + "round": sess.Round, + "mode": mode, + })) + + // Run the round + var result *party.RoundResult + switch mode { + case store.PartyModeDeep: + result, err = engine.RunDeepRound(ctx, sess, personas, emit) + case store.PartyModeTokenRing: + result, err = engine.RunTokenRingRound(ctx, sess, personas, emit) + default: + result, err = engine.RunStandardRound(ctx, sess, personas, emit) + } + if err != nil { + slog.Error("party: round failed", "session", sess.ID, "round", sess.Round, "mode", mode, "error", err) + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInternal, err.Error())) + return + } + + // Append to history + var history []party.RoundResult + json.Unmarshal(sess.History, &history) + history = append(history, *result) + historyJSON, _ := json.Marshal(history) + + // Update session + if err := m.partyStore.UpdateSession(ctx, sess.ID, map[string]any{ + "round": sess.Round, + "mode": mode, + "history": historyJSON, + }); err != nil { + slog.Warn("party: failed to update session", "error", err) + } + + emit(*protocol.NewEvent(protocol.EventPartyRoundComplete, map[string]any{ + "session_id": sess.ID, + "round": sess.Round, + "mode": mode, + })) + + client.SendResponse(protocol.NewOKResponse(req.ID, map[string]any{ + "round": sess.Round, + "mode": mode, + "messages": result.Messages, + })) +} + +type partyQuestionParams struct { + SessionID string `json:"session_id"` + Text string `json:"text"` +} + +func (m *PartyMethods) handleQuestion(ctx context.Context, client *gateway.Client, req *protocol.RequestFrame) { + var params partyQuestionParams + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInvalidRequest, "invalid params")) + return + } + + sessID, err := uuid.Parse(params.SessionID) + if err != nil { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInvalidRequest, "invalid session_id")) + return + } + + sess, err := m.partyStore.GetSession(ctx, sessID) + if err != nil { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInternal, "session not found")) + return + } + + // Temporarily set topic to the question for this round + originalTopic := sess.Topic + sess.Topic = fmt.Sprintf("%s\n\nUser question: %s", originalTopic, params.Text) + sess.Round++ + + engine, err := m.getEngine() + if err != nil { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInternal, err.Error())) + return + } + + var personaKeys []string + json.Unmarshal(sess.Personas, &personaKeys) + personas, err := engine.LoadPersonas(ctx, personaKeys) + if err != nil { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInternal, err.Error())) + return + } + + emit := m.emitterForClient(client) + result, err := engine.RunStandardRound(ctx, sess, personas, emit) + if err != nil { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInternal, err.Error())) + return + } + + // Restore topic and update session + sess.Topic = originalTopic + var history []party.RoundResult + json.Unmarshal(sess.History, &history) + history = append(history, *result) + historyJSON, _ := json.Marshal(history) + + if err := m.partyStore.UpdateSession(ctx, sess.ID, map[string]any{ + "round": sess.Round, + "history": historyJSON, + }); err != nil { + slog.Warn("party: failed to update session", "error", err) + } + + client.SendResponse(protocol.NewOKResponse(req.ID, map[string]any{ + "round": sess.Round, + "messages": result.Messages, + })) +} + +type partyAddContextParams struct { + SessionID string `json:"session_id"` + Type string `json:"type"` + Name string `json:"name,omitempty"` + Content string `json:"content,omitempty"` + URL string `json:"url,omitempty"` +} + +func (m *PartyMethods) handleAddContext(ctx context.Context, client *gateway.Client, req *protocol.RequestFrame) { + var params partyAddContextParams + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInvalidRequest, "invalid params")) + return + } + + sessID, err := uuid.Parse(params.SessionID) + if err != nil { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInvalidRequest, "invalid session_id")) + return + } + + sess, err := m.partyStore.GetSession(ctx, sessID) + if err != nil { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInternal, "session not found")) + return + } + + // Parse existing context + var sessionCtx map[string]any + if json.Unmarshal(sess.Context, &sessionCtx) != nil { + sessionCtx = make(map[string]any) + } + + // Add new context based on type + switch params.Type { + case "document": + docs, _ := sessionCtx["documents"].([]any) + docs = append(docs, map[string]string{"name": params.Name, "content": params.Content, "source": "upload"}) + sessionCtx["documents"] = docs + case "meeting_notes": + sessionCtx["meeting_notes"] = params.Content + case "custom": + sessionCtx["custom"] = params.Content + default: + sessionCtx[params.Type] = params.Content + } + + contextJSON, _ := json.Marshal(sessionCtx) + if err := m.partyStore.UpdateSession(ctx, sess.ID, map[string]any{ + "context": contextJSON, + }); err != nil { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInternal, err.Error())) + return + } + + emit := m.emitterForClient(client) + emit(*protocol.NewEvent(protocol.EventPartyContextAdded, map[string]any{ + "session_id": sess.ID, + "name": params.Name, + "type": params.Type, + })) + + client.SendResponse(protocol.NewOKResponse(req.ID, map[string]any{ + "ok": true, + "context_count": len(sessionCtx), + })) +} + +func (m *PartyMethods) handleSummary(ctx context.Context, client *gateway.Client, req *protocol.RequestFrame) { + var params struct { + SessionID string `json:"session_id"` + } + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInvalidRequest, "invalid params")) + return + } + + sessID, err := uuid.Parse(params.SessionID) + if err != nil { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInvalidRequest, "invalid session_id")) + return + } + + sess, err := m.partyStore.GetSession(ctx, sessID) + if err != nil { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInternal, "session not found")) + return + } + + engine, err := m.getEngine() + if err != nil { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInternal, err.Error())) + return + } + + summary, err := engine.GenerateSummary(ctx, sess) + if err != nil { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInternal, err.Error())) + return + } + + // Store summary in session + summaryJSON, _ := json.Marshal(summary) + m.partyStore.UpdateSession(ctx, sess.ID, map[string]any{ + "summary": summaryJSON, + "status": store.PartyStatusSummarizing, + }) + + emit := m.emitterForClient(client) + emit(*protocol.NewEvent(protocol.EventPartySummaryReady, map[string]any{ + "session_id": sess.ID, + "summary": summary, + })) + + client.SendResponse(protocol.NewOKResponse(req.ID, summary)) +} + +func (m *PartyMethods) handleExit(ctx context.Context, client *gateway.Client, req *protocol.RequestFrame) { + var params struct { + SessionID string `json:"session_id"` + FollowUp string `json:"follow_up,omitempty"` + } + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInvalidRequest, "invalid params")) + return + } + + sessID, err := uuid.Parse(params.SessionID) + if err != nil { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInvalidRequest, "invalid session_id")) + return + } + + sess, err := m.partyStore.GetSession(ctx, sessID) + if err != nil { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInternal, "session not found")) + return + } + + // Generate summary if not yet done + var summary *party.SummaryResult + if sess.Summary == nil || string(sess.Summary) == "null" { + engine, err := m.getEngine() + if err == nil { + summary, _ = engine.GenerateSummary(ctx, sess) + } + } else { + json.Unmarshal(sess.Summary, &summary) + } + + // Close session + updates := map[string]any{"status": store.PartyStatusClosed} + if summary != nil { + summaryJSON, _ := json.Marshal(summary) + updates["summary"] = summaryJSON + } + m.partyStore.UpdateSession(ctx, sess.ID, updates) + + emit := m.emitterForClient(client) + emit(*protocol.NewEvent(protocol.EventPartyClosed, map[string]any{ + "session_id": sess.ID, + })) + + response := map[string]any{ + "session_id": sess.ID, + "status": store.PartyStatusClosed, + } + if summary != nil { + response["summary"] = summary + } + + client.SendResponse(protocol.NewOKResponse(req.ID, response)) +} + +func (m *PartyMethods) handleList(ctx context.Context, client *gateway.Client, req *protocol.RequestFrame) { + var params struct { + Status string `json:"status,omitempty"` + Limit int `json:"limit,omitempty"` + } + if err := json.Unmarshal(req.Params, ¶ms); err != nil { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInvalidRequest, "invalid params")) + return + } + + limit := params.Limit + if limit <= 0 { + limit = 20 + } + + sessions, err := m.partyStore.ListSessions(ctx, client.UserID(), params.Status, limit) + if err != nil { + client.SendResponse(protocol.NewErrorResponse(req.ID, protocol.ErrInternal, err.Error())) + return + } + + client.SendResponse(protocol.NewOKResponse(req.ID, map[string]any{ + "sessions": sessions, + "count": len(sessions), + })) +} diff --git a/internal/gateway/server.go b/internal/gateway/server.go index b0f74fd1..af06fcfa 100644 --- a/internal/gateway/server.go +++ b/internal/gateway/server.go @@ -57,6 +57,7 @@ type Server struct { mediaServeHandler *httpapi.MediaServeHandler // media serve endpoint activityHandler *httpapi.ActivityHandler // activity audit log API usageHandler *httpapi.UsageHandler // usage analytics API + projectHandler *httpapi.ProjectHandler // project CRUD + MCP overrides API agentStore store.AgentStore // for context injection in tools_invoke msgBus *bus.MessageBus // for MCP bridge media delivery @@ -188,6 +189,11 @@ func (s *Server) BuildMux() *http.ServeMux { s.mcpHandler.RegisterRoutes(mux) } + // Project CRUD + MCP overrides API + if s.projectHandler != nil { + s.projectHandler.RegisterRoutes(mux) + } + // Custom tool CRUD API if s.customToolsHandler != nil { s.customToolsHandler.RegisterRoutes(mux) @@ -489,6 +495,9 @@ func (s *Server) SetActivityHandler(h *httpapi.ActivityHandler) { s.activityHand // SetUsageHandler sets the usage analytics handler. func (s *Server) SetUsageHandler(h *httpapi.UsageHandler) { s.usageHandler = h } +// SetProjectHandler sets the project CRUD + MCP overrides handler. +func (s *Server) SetProjectHandler(h *httpapi.ProjectHandler) { s.projectHandler = h } + // SetAgentStore sets the agent store for context injection in tools_invoke. func (s *Server) SetAgentStore(as store.AgentStore) { s.agentStore = as } diff --git a/internal/http/projects.go b/internal/http/projects.go new file mode 100644 index 00000000..a0f55fa7 --- /dev/null +++ b/internal/http/projects.go @@ -0,0 +1,277 @@ +package http + +import ( + "encoding/json" + "log/slog" + "net/http" + + "github.com/google/uuid" + + "github.com/nextlevelbuilder/goclaw/internal/i18n" + "github.com/nextlevelbuilder/goclaw/internal/store" +) + +// ProjectHandler handles project CRUD and MCP override HTTP endpoints. +type ProjectHandler struct { + store store.ProjectStore + token string +} + +// NewProjectHandler creates a handler for project management endpoints. +func NewProjectHandler(store store.ProjectStore, token string) *ProjectHandler { + return &ProjectHandler{store: store, token: token} +} + +// RegisterRoutes registers all project routes on the given mux. +func (h *ProjectHandler) RegisterRoutes(mux *http.ServeMux) { + // Project CRUD + mux.HandleFunc("GET /v1/projects", h.auth(h.handleListProjects)) + mux.HandleFunc("POST /v1/projects", h.auth(h.handleCreateProject)) + mux.HandleFunc("GET /v1/projects/by-chat", h.auth(h.handleGetByChat)) + mux.HandleFunc("GET /v1/projects/{id}", h.auth(h.handleGetProject)) + mux.HandleFunc("PUT /v1/projects/{id}", h.auth(h.handleUpdateProject)) + mux.HandleFunc("DELETE /v1/projects/{id}", h.auth(h.handleDeleteProject)) + + // MCP overrides + mux.HandleFunc("GET /v1/projects/{id}/mcp", h.auth(h.handleListMCPOverrides)) + mux.HandleFunc("PUT /v1/projects/{id}/mcp/{serverName}", h.auth(h.handleSetMCPOverride)) + mux.HandleFunc("DELETE /v1/projects/{id}/mcp/{serverName}", h.auth(h.handleRemoveMCPOverride)) +} + +func (h *ProjectHandler) auth(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if h.token != "" { + if extractBearerToken(r) != h.token { + locale := extractLocale(r) + writeJSON(w, http.StatusUnauthorized, map[string]string{"error": i18n.T(locale, i18n.MsgUnauthorized)}) + return + } + } + userID := extractUserID(r) + ctx := store.WithLocale(r.Context(), extractLocale(r)) + if userID != "" { + ctx = store.WithUserID(ctx, userID) + } + r = r.WithContext(ctx) + next(w, r) + } +} + +// --- Project CRUD --- + +func (h *ProjectHandler) handleListProjects(w http.ResponseWriter, r *http.Request) { + projects, err := h.store.ListProjects(r.Context()) + if err != nil { + slog.Error("projects.list", "error", err) + locale := store.LocaleFromContext(r.Context()) + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": i18n.T(locale, i18n.MsgFailedToList, "projects")}) + return + } + + writeJSON(w, http.StatusOK, map[string]interface{}{"projects": projects}) +} + +func (h *ProjectHandler) handleCreateProject(w http.ResponseWriter, r *http.Request) { + locale := store.LocaleFromContext(r.Context()) + + var payload struct { + Name string `json:"name"` + Slug string `json:"slug"` + ChannelType *string `json:"channel_type"` + ChatID *string `json:"chat_id"` + TeamID *uuid.UUID `json:"team_id"` + Description *string `json:"description"` + Status string `json:"status"` + } + if err := json.NewDecoder(http.MaxBytesReader(w, r.Body, 1<<20)).Decode(&payload); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": i18n.T(locale, i18n.MsgInvalidJSON)}) + return + } + + if payload.Name == "" || payload.Slug == "" { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": i18n.T(locale, i18n.MsgRequired, "name and slug")}) + return + } + if !isValidSlug(payload.Slug) { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": i18n.T(locale, i18n.MsgInvalidSlug, "slug")}) + return + } + + if payload.Status == "" { + payload.Status = "active" + } + + project := store.Project{ + Name: payload.Name, + Slug: payload.Slug, + ChannelType: payload.ChannelType, + ChatID: payload.ChatID, + TeamID: payload.TeamID, + Description: payload.Description, + Status: payload.Status, + CreatedBy: store.UserIDFromContext(r.Context()), + } + + if err := h.store.CreateProject(r.Context(), &project); err != nil { + slog.Error("projects.create", "error", err) + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + + writeJSON(w, http.StatusCreated, project) +} + +func (h *ProjectHandler) handleGetProject(w http.ResponseWriter, r *http.Request) { + locale := store.LocaleFromContext(r.Context()) + id, err := uuid.Parse(r.PathValue("id")) + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": i18n.T(locale, i18n.MsgInvalidID, "project")}) + return + } + + project, err := h.store.GetProject(r.Context(), id) + if err != nil { + writeJSON(w, http.StatusNotFound, map[string]string{"error": i18n.T(locale, i18n.MsgNotFound, "project", id.String())}) + return + } + + writeJSON(w, http.StatusOK, project) +} + +func (h *ProjectHandler) handleGetByChat(w http.ResponseWriter, r *http.Request) { + locale := store.LocaleFromContext(r.Context()) + channelType := r.URL.Query().Get("channel_type") + chatID := r.URL.Query().Get("chat_id") + + if channelType == "" || chatID == "" { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": i18n.T(locale, i18n.MsgRequired, "channel_type and chat_id")}) + return + } + + project, err := h.store.GetProjectByChatID(r.Context(), channelType, chatID) + if err != nil || project == nil { + writeJSON(w, http.StatusNotFound, map[string]string{"error": "project not found for this chat"}) + return + } + + writeJSON(w, http.StatusOK, project) +} + +func (h *ProjectHandler) handleUpdateProject(w http.ResponseWriter, r *http.Request) { + locale := store.LocaleFromContext(r.Context()) + id, err := uuid.Parse(r.PathValue("id")) + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": i18n.T(locale, i18n.MsgInvalidID, "project")}) + return + } + + var updates map[string]any + if err := json.NewDecoder(http.MaxBytesReader(w, r.Body, 1<<20)).Decode(&updates); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": i18n.T(locale, i18n.MsgInvalidJSON)}) + return + } + + if slug, ok := updates["slug"]; ok { + if s, _ := slug.(string); !isValidSlug(s) { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": i18n.T(locale, i18n.MsgInvalidSlug, "slug")}) + return + } + } + + if err := h.store.UpdateProject(r.Context(), id, updates); err != nil { + slog.Error("projects.update", "error", err) + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + + writeJSON(w, http.StatusOK, map[string]string{"status": "updated"}) +} + +func (h *ProjectHandler) handleDeleteProject(w http.ResponseWriter, r *http.Request) { + locale := store.LocaleFromContext(r.Context()) + id, err := uuid.Parse(r.PathValue("id")) + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": i18n.T(locale, i18n.MsgInvalidID, "project")}) + return + } + + if err := h.store.DeleteProject(r.Context(), id); err != nil { + slog.Error("projects.delete", "error", err) + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + + w.WriteHeader(http.StatusNoContent) +} + +// --- MCP Overrides --- + +func (h *ProjectHandler) handleListMCPOverrides(w http.ResponseWriter, r *http.Request) { + locale := store.LocaleFromContext(r.Context()) + id, err := uuid.Parse(r.PathValue("id")) + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": i18n.T(locale, i18n.MsgInvalidID, "project")}) + return + } + + overrides, err := h.store.GetMCPOverrides(r.Context(), id) + if err != nil { + slog.Error("projects.list_mcp_overrides", "error", err) + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": i18n.T(locale, i18n.MsgFailedToList, "MCP overrides")}) + return + } + + writeJSON(w, http.StatusOK, map[string]interface{}{"overrides": overrides}) +} + +func (h *ProjectHandler) handleSetMCPOverride(w http.ResponseWriter, r *http.Request) { + locale := store.LocaleFromContext(r.Context()) + id, err := uuid.Parse(r.PathValue("id")) + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": i18n.T(locale, i18n.MsgInvalidID, "project")}) + return + } + + serverName := r.PathValue("serverName") + if serverName == "" { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": i18n.T(locale, i18n.MsgRequired, "serverName")}) + return + } + + var envOverrides map[string]string + if err := json.NewDecoder(http.MaxBytesReader(w, r.Body, 1<<20)).Decode(&envOverrides); err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": i18n.T(locale, i18n.MsgInvalidJSON)}) + return + } + + if err := h.store.SetMCPOverride(r.Context(), id, serverName, envOverrides); err != nil { + slog.Error("projects.set_mcp_override", "error", err) + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + + writeJSON(w, http.StatusOK, map[string]string{"status": "updated"}) +} + +func (h *ProjectHandler) handleRemoveMCPOverride(w http.ResponseWriter, r *http.Request) { + locale := store.LocaleFromContext(r.Context()) + id, err := uuid.Parse(r.PathValue("id")) + if err != nil { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": i18n.T(locale, i18n.MsgInvalidID, "project")}) + return + } + + serverName := r.PathValue("serverName") + if serverName == "" { + writeJSON(w, http.StatusBadRequest, map[string]string{"error": i18n.T(locale, i18n.MsgRequired, "serverName")}) + return + } + + if err := h.store.RemoveMCPOverride(r.Context(), id, serverName); err != nil { + slog.Error("projects.remove_mcp_override", "error", err) + writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) + return + } + + w.WriteHeader(http.StatusNoContent) +} diff --git a/internal/mcp/manager.go b/internal/mcp/manager.go index c4d64665..f8bcf1f7 100644 --- a/internal/mcp/manager.go +++ b/internal/mcp/manager.go @@ -148,7 +148,9 @@ func (m *Manager) Start(ctx context.Context) error { // LoadForAgent connects MCP servers accessible by a specific agent+user. // Previously registered MCP tools for this manager are cleared and reloaded. -func (m *Manager) LoadForAgent(ctx context.Context, agentID uuid.UUID, userID string) error { +// projectID and projectOverrides enable per-project MCP process isolation: +// each project gets its own pool entry with merged env vars. +func (m *Manager) LoadForAgent(ctx context.Context, agentID uuid.UUID, userID string, projectID string, projectOverrides map[string]map[string]string) error { if m.store == nil { return nil } @@ -173,8 +175,13 @@ func (m *Manager) LoadForAgent(ctx context.Context, agentID uuid.UUID, userID st if m.pool != nil { // Pool mode: acquire shared connection, create per-agent BridgeTools + var serverOverrides map[string]string + if projectOverrides != nil { + serverOverrides = projectOverrides[srv.Name] + } if err := m.connectViaPool(ctx, srv.Name, srv.Transport, srv.Command, - args, env, srv.URL, headers, srv.ToolPrefix, srv.TimeoutSec); err != nil { + args, env, srv.URL, headers, srv.ToolPrefix, srv.TimeoutSec, + projectID, serverOverrides); err != nil { slog.Warn("mcp.server.connect_failed", "server", srv.Name, "error", err) continue } @@ -189,8 +196,13 @@ func (m *Manager) LoadForAgent(ctx context.Context, agentID uuid.UUID, userID st } // Apply tool filtering from grants + // Use poolKey for pool-backed servers so filterTools finds the right entries + poolKey := srv.Name + if m.pool != nil && projectID != "" { + poolKey = srv.Name + ":" + projectID + } if len(info.ToolAllow) > 0 || len(info.ToolDeny) > 0 { - m.filterTools(srv.Name, info.ToolAllow, info.ToolDeny) + m.filterTools(poolKey, info.ToolAllow, info.ToolDeny) } } diff --git a/internal/mcp/manager_connect.go b/internal/mcp/manager_connect.go index 9d4816cf..9efa7968 100644 --- a/internal/mcp/manager_connect.go +++ b/internal/mcp/manager_connect.go @@ -121,10 +121,35 @@ func (m *Manager) registerBridgeTools(ss *serverState, mcpTools []mcpgo.Tool, se return registeredNames } +// mergeEnv merges base env with project overrides. +// Project overrides take priority (add/replace only, never remove base keys). +func mergeEnv(base, overrides map[string]string) map[string]string { + if len(overrides) == 0 { + return base + } + merged := make(map[string]string, len(base)+len(overrides)) + for k, v := range base { + merged[k] = v + } + for k, v := range overrides { + merged[k] = v + } + return merged +} + // connectViaPool acquires a shared connection from the pool and creates // per-agent BridgeTools pointing to the shared client/connected pointers. -func (m *Manager) connectViaPool(ctx context.Context, name, transportType, command string, args []string, env map[string]string, url string, headers map[string]string, toolPrefix string, timeoutSec int) error { - entry, err := m.pool.Acquire(ctx, name, transportType, command, args, env, url, headers, timeoutSec) +func (m *Manager) connectViaPool(ctx context.Context, name, transportType, command string, + args []string, env map[string]string, url string, headers map[string]string, + toolPrefix string, timeoutSec int, projectID string, projectEnvOverrides map[string]string) error { + + mergedEnv := mergeEnv(env, projectEnvOverrides) + poolKey := name + if projectID != "" { + poolKey = name + ":" + projectID + } + + entry, err := m.pool.Acquire(ctx, poolKey, name, transportType, command, args, mergedEnv, url, headers, timeoutSec) if err != nil { return err } @@ -134,24 +159,25 @@ func (m *Manager) connectViaPool(ctx context.Context, name, transportType, comma // Track server state and per-agent tool names m.mu.Lock() - m.servers[name] = entry.state + m.servers[poolKey] = entry.state if m.poolServers == nil { m.poolServers = make(map[string]struct{}) } - m.poolServers[name] = struct{}{} + m.poolServers[poolKey] = struct{}{} if m.poolToolNames == nil { m.poolToolNames = make(map[string][]string) } - m.poolToolNames[name] = registeredNames + m.poolToolNames[poolKey] = registeredNames m.mu.Unlock() if len(registeredNames) > 0 { - tools.RegisterToolGroup("mcp:"+name, registeredNames) + tools.RegisterToolGroup("mcp:"+poolKey, registeredNames) m.updateMCPGroup() } slog.Info("mcp.server.connected_via_pool", "server", name, + "poolKey", poolKey, "transport", transportType, "tools", len(registeredNames), ) diff --git a/internal/mcp/manager_tools.go b/internal/mcp/manager_tools.go index f469abf3..49184c21 100644 --- a/internal/mcp/manager_tools.go +++ b/internal/mcp/manager_tools.go @@ -142,16 +142,17 @@ func DiscoverTools(ctx context.Context, transportType, command string, args []st } // filterTools removes tools from the registry that don't match the allow/deny lists. -func (m *Manager) filterTools(serverName string, allow, deny []string) { +// poolKey is the composite key used in poolServers/poolToolNames (e.g. "name" or "name:projectID"). +func (m *Manager) filterTools(poolKey string, allow, deny []string) { m.mu.Lock() defer m.mu.Unlock() // Get the tool names list (pool-backed or standalone) var toolNames []string - _, isPool := m.poolServers[serverName] + _, isPool := m.poolServers[poolKey] if isPool { - toolNames = m.poolToolNames[serverName] - } else if ss, ok := m.servers[serverName]; ok { + toolNames = m.poolToolNames[poolKey] + } else if ss, ok := m.servers[poolKey]; ok { toolNames = ss.toolNames } else { return @@ -192,8 +193,8 @@ func (m *Manager) filterTools(serverName string, allow, deny []string) { // Update the correct tool names list if isPool { - m.poolToolNames[serverName] = kept + m.poolToolNames[poolKey] = kept } else { - m.servers[serverName].toolNames = kept + m.servers[poolKey].toolNames = kept } } diff --git a/internal/mcp/merge_env_test.go b/internal/mcp/merge_env_test.go new file mode 100644 index 00000000..f2e09697 --- /dev/null +++ b/internal/mcp/merge_env_test.go @@ -0,0 +1,150 @@ +package mcp + +import "testing" + +func TestMergeEnv(t *testing.T) { + tests := []struct { + name string + base map[string]string + overrides map[string]string + want map[string]string + }{ + { + name: "nil overrides returns base unchanged", + base: map[string]string{"A": "1", "B": "2"}, + overrides: nil, + want: map[string]string{"A": "1", "B": "2"}, + }, + { + name: "empty overrides returns base unchanged", + base: map[string]string{"A": "1"}, + overrides: map[string]string{}, + want: map[string]string{"A": "1"}, + }, + { + name: "override replaces base key", + base: map[string]string{"GITLAB_URL": "https://git.example.com", "GITLAB_PROJECT_ID": "1"}, + overrides: map[string]string{"GITLAB_PROJECT_ID": "42"}, + want: map[string]string{"GITLAB_URL": "https://git.example.com", "GITLAB_PROJECT_ID": "42"}, + }, + { + name: "override adds new key without removing base", + base: map[string]string{"GITLAB_URL": "https://git.example.com"}, + overrides: map[string]string{"GITLAB_PROJECT_PATH": "duhd/xpos"}, + want: map[string]string{"GITLAB_URL": "https://git.example.com", "GITLAB_PROJECT_PATH": "duhd/xpos"}, + }, + { + name: "nil base with overrides", + base: nil, + overrides: map[string]string{"KEY": "val"}, + want: map[string]string{"KEY": "val"}, + }, + { + name: "both nil returns nil-like base", + base: nil, + overrides: nil, + want: nil, + }, + { + name: "multiple overrides replace multiple base keys", + base: map[string]string{"A": "1", "B": "2", "C": "3"}, + overrides: map[string]string{"A": "10", "C": "30", "D": "40"}, + want: map[string]string{"A": "10", "B": "2", "C": "30", "D": "40"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := mergeEnv(tt.base, tt.overrides) + if tt.want == nil { + if tt.base == nil && got != nil && len(got) > 0 { + t.Errorf("expected nil-like result, got %v", got) + } + return + } + if len(got) != len(tt.want) { + t.Errorf("len mismatch: got %d, want %d\ngot: %v\nwant: %v", len(got), len(tt.want), got, tt.want) + return + } + for k, wantV := range tt.want { + if gotV, ok := got[k]; !ok || gotV != wantV { + t.Errorf("key %q: got %q, want %q", k, gotV, wantV) + } + } + }) + } +} + +func TestMergeEnv_BaseNotMutated(t *testing.T) { + base := map[string]string{"A": "1", "B": "2"} + overrides := map[string]string{"A": "override", "C": "new"} + + _ = mergeEnv(base, overrides) + + if base["A"] != "1" { + t.Errorf("base was mutated: A=%q, want '1'", base["A"]) + } + if _, ok := base["C"]; ok { + t.Error("base was mutated: unexpected key 'C'") + } +} + +func TestPoolKeyComputation(t *testing.T) { + tests := []struct { + name string + server string + projectID string + wantKey string + }{ + { + name: "no project — backward compat key", + server: "gitlab", + projectID: "", + wantKey: "gitlab", + }, + { + name: "with project — composite key", + server: "gitlab", + projectID: "uuid-xpos", + wantKey: "gitlab:uuid-xpos", + }, + { + name: "different projects same server — different keys", + server: "atlassian", + projectID: "uuid-payment", + wantKey: "atlassian:uuid-payment", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + poolKey := tt.server + if tt.projectID != "" { + poolKey = tt.server + ":" + tt.projectID + } + if poolKey != tt.wantKey { + t.Errorf("poolKey: got %q, want %q", poolKey, tt.wantKey) + } + }) + } +} + +func TestPoolKeyIsolation(t *testing.T) { + server := "gitlab" + projectA := "uuid-xpos" + projectB := "uuid-payment" + + keyA := server + ":" + projectA + keyB := server + ":" + projectB + keyNone := server + + if keyA == keyB { + t.Error("project A and B should have different pool keys") + } + if keyA == keyNone { + t.Error("project A and no-project should have different pool keys") + } + if keyB == keyNone { + t.Error("project B and no-project should have different pool keys") + } +} diff --git a/internal/mcp/pool.go b/internal/mcp/pool.go index 4786051f..fd8a08e6 100644 --- a/internal/mcp/pool.go +++ b/internal/mcp/pool.go @@ -34,25 +34,27 @@ func NewPool() *Pool { // Acquire returns a shared connection for the named server. // If no connection exists, it connects using the provided config. // Increments the reference count. -func (p *Pool) Acquire(ctx context.Context, name, transportType, command string, args []string, env map[string]string, url string, headers map[string]string, timeoutSec int) (*poolEntry, error) { +// poolKey is the composite key (e.g. "name" or "name:projectID") used for +// process isolation; name is the server name used for connectAndDiscover. +func (p *Pool) Acquire(ctx context.Context, poolKey, name, transportType, command string, args []string, env map[string]string, url string, headers map[string]string, timeoutSec int) (*poolEntry, error) { p.mu.Lock() - if entry, ok := p.servers[name]; ok && entry.state.connected.Load() { + if entry, ok := p.servers[poolKey]; ok && entry.state.connected.Load() { entry.refCount++ p.mu.Unlock() - slog.Debug("mcp.pool.reuse", "server", name, "refCount", entry.refCount) + slog.Debug("mcp.pool.reuse", "server", name, "poolKey", poolKey, "refCount", entry.refCount) return entry, nil } // If entry exists but disconnected, close old connection first - if old, ok := p.servers[name]; ok { + if old, ok := p.servers[poolKey]; ok { if old.state.cancel != nil { old.state.cancel() } if old.state.client != nil { _ = old.state.client.Close() } - delete(p.servers, name) + delete(p.servers, poolKey) } p.mu.Unlock() @@ -76,7 +78,7 @@ func (p *Pool) Acquire(ctx context.Context, name, transportType, command string, p.mu.Lock() // Check if another goroutine connected while we were connecting - if existing, ok := p.servers[name]; ok && existing.state.connected.Load() { + if existing, ok := p.servers[poolKey]; ok && existing.state.connected.Load() { // Use existing, close ours p.mu.Unlock() hcancel() @@ -86,26 +88,26 @@ func (p *Pool) Acquire(ctx context.Context, name, transportType, command string, p.mu.Unlock() return existing, nil } - p.servers[name] = entry + p.servers[poolKey] = entry p.mu.Unlock() - slog.Info("mcp.pool.connected", "server", name, "tools", len(mcpTools)) + slog.Info("mcp.pool.connected", "server", name, "poolKey", poolKey, "tools", len(mcpTools)) return entry, nil } // Release decrements the reference count for a server. // The connection is NOT closed when refCount reaches 0 — it stays // alive for future agents. Use Stop() to close all connections. -func (p *Pool) Release(name string) { +func (p *Pool) Release(poolKey string) { p.mu.Lock() defer p.mu.Unlock() - if entry, ok := p.servers[name]; ok { + if entry, ok := p.servers[poolKey]; ok { entry.refCount-- if entry.refCount < 0 { entry.refCount = 0 } - slog.Debug("mcp.pool.release", "server", name, "refCount", entry.refCount) + slog.Debug("mcp.pool.release", "poolKey", poolKey, "refCount", entry.refCount) } } diff --git a/internal/party/engine.go b/internal/party/engine.go new file mode 100644 index 00000000..3362524e --- /dev/null +++ b/internal/party/engine.go @@ -0,0 +1,307 @@ +package party + +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + "strings" + "sync" + + "github.com/nextlevelbuilder/goclaw/internal/providers" + "github.com/nextlevelbuilder/goclaw/internal/store" + "github.com/nextlevelbuilder/goclaw/pkg/protocol" +) + +// EventEmitter sends party events to connected clients. +type EventEmitter func(event protocol.EventFrame) + +// Engine orchestrates party mode discussions. +type Engine struct { + partyStore store.PartyStore + agentStore store.AgentStore + provider providers.Provider +} + +// NewEngine creates a new party engine. +func NewEngine(partyStore store.PartyStore, agentStore store.AgentStore, provider providers.Provider) *Engine { + return &Engine{ + partyStore: partyStore, + agentStore: agentStore, + provider: provider, + } +} + +// LoadPersonas loads persona info from agent DB for the given keys. +func (e *Engine) LoadPersonas(ctx context.Context, keys []string) ([]PersonaInfo, error) { + var personas []PersonaInfo + for _, key := range keys { + agent, err := e.agentStore.GetByKey(ctx, key) + if err != nil { + return nil, fmt.Errorf("persona %s not found: %w", key, err) + } + pi := PersonaInfo{ + AgentKey: key, + DisplayName: agent.DisplayName, + } + // Extract persona metadata from other_config + var cfg map[string]json.RawMessage + if json.Unmarshal(agent.OtherConfig, &cfg) == nil { + if personaJSON, ok := cfg["persona"]; ok { + var pm struct { + Emoji string `json:"emoji"` + MovieRef string `json:"movie_ref"` + SpeakingStyle string `json:"speaking_style"` + ExpertiseWeight map[string]float64 `json:"expertise_weight"` + } + if json.Unmarshal(personaJSON, &pm) == nil { + pi.Emoji = pm.Emoji + pi.MovieRef = pm.MovieRef + pi.SpeakingStyle = pm.SpeakingStyle + pi.ExpertiseWeight = pm.ExpertiseWeight + } + } + } + if pi.Emoji == "" { + pi.Emoji = "🤖" + } + personas = append(personas, pi) + } + return personas, nil +} + +// RunStandardRound executes a standard mode round (1 LLM call, all personas). +func (e *Engine) RunStandardRound(ctx context.Context, session *store.PartySessionData, personas []PersonaInfo, emit EventEmitter) (*RoundResult, error) { + slog.Info("party: standard round", "session", session.ID, "round", session.Round) + + systemPrompt := "You are a party mode facilitator. Generate responses for each persona in character." + userPrompt := BuildStandardRoundPrompt(session, personas) + + resp, err := e.llmCall(ctx, systemPrompt, userPrompt) + if err != nil { + return nil, fmt.Errorf("standard round LLM call: %w", err) + } + + messages := parsePersonaMessages(resp, personas) + for _, m := range messages { + emit(*protocol.NewEvent(protocol.EventPartyPersonaSpoke, map[string]any{ + "session_id": session.ID, "persona": m.PersonaKey, + "emoji": m.Emoji, "content": m.Content, + })) + } + + return &RoundResult{Round: session.Round, Mode: store.PartyModeStandard, Messages: messages}, nil +} + +// RunDeepRound executes Deep Mode: parallel thinking → cross-talk. +func (e *Engine) RunDeepRound(ctx context.Context, session *store.PartySessionData, personas []PersonaInfo, emit EventEmitter) (*RoundResult, error) { + slog.Info("party: deep round (parallel)", "session", session.ID, "round", session.Round, "personas", len(personas)) + + // Step 1: Parallel thinking + thoughts, err := e.runParallelThinking(ctx, session, personas, emit) + if err != nil { + return nil, fmt.Errorf("parallel thinking: %w", err) + } + + // Step 2: Cross-talk (1 LLM call) + systemPrompt := "You are a party mode facilitator generating cross-talk between personas." + userPrompt := BuildCrossTalkPrompt(session, personas, thoughts) + + resp, err := e.llmCall(ctx, systemPrompt, userPrompt) + if err != nil { + return nil, fmt.Errorf("cross-talk LLM call: %w", err) + } + + messages := parsePersonaMessages(resp, personas) + for i := range messages { + // Attach thinking from Step 1 + for _, t := range thoughts { + if t.PersonaKey == messages[i].PersonaKey { + messages[i].Thinking = t.Content + break + } + } + emit(*protocol.NewEvent(protocol.EventPartyPersonaSpoke, map[string]any{ + "session_id": session.ID, "persona": messages[i].PersonaKey, + "emoji": messages[i].Emoji, "content": messages[i].Content, + })) + } + + return &RoundResult{Round: session.Round, Mode: store.PartyModeDeep, Messages: messages}, nil +} + +// RunTokenRingRound executes Token-Ring: parallel thinking → sequential turns. +func (e *Engine) RunTokenRingRound(ctx context.Context, session *store.PartySessionData, personas []PersonaInfo, emit EventEmitter) (*RoundResult, error) { + slog.Info("party: token-ring round", "session", session.ID, "round", session.Round, "personas", len(personas)) + + // Step 1: Parallel thinking + thoughts, err := e.runParallelThinking(ctx, session, personas, emit) + if err != nil { + return nil, fmt.Errorf("parallel thinking: %w", err) + } + + // Step 2: Sequential turns + var messages []PersonaMessage + var priorTurns []PersonaMessage + + for i, persona := range personas { + isLast := i == len(personas)-1 + + soulMD := e.loadPersonaSoulMD(ctx, persona.AgentKey) + systemPrompt := BuildPersonaSystemPrompt(persona, session, soulMD) + userPrompt := BuildTokenRingTurnPrompt(session, persona, thoughts, priorTurns, isLast) + + resp, err := e.llmCall(ctx, systemPrompt, userPrompt) + if err != nil { + slog.Warn("party: token-ring turn failed", "persona", persona.AgentKey, "error", err) + continue + } + + msg := PersonaMessage{ + PersonaKey: persona.AgentKey, + DisplayName: persona.DisplayName, + Emoji: persona.Emoji, + Content: strings.TrimSpace(resp), + } + messages = append(messages, msg) + priorTurns = append(priorTurns, msg) + + // Emit immediately — user sees each persona respond in real-time + emit(*protocol.NewEvent(protocol.EventPartyPersonaSpoke, map[string]any{ + "session_id": session.ID, "persona": msg.PersonaKey, + "emoji": msg.Emoji, "content": msg.Content, + })) + } + + return &RoundResult{Round: session.Round, Mode: store.PartyModeTokenRing, Messages: messages}, nil +} + +// runParallelThinking executes independent thinking for all personas in parallel. +func (e *Engine) runParallelThinking(ctx context.Context, session *store.PartySessionData, personas []PersonaInfo, emit EventEmitter) ([]PersonaThought, error) { + thoughts := make([]PersonaThought, len(personas)) + errs := make([]error, len(personas)) + var wg sync.WaitGroup + + for i, persona := range personas { + wg.Add(1) + go func(idx int, p PersonaInfo) { + defer wg.Done() + + emit(*protocol.NewEvent(protocol.EventPartyPersonaThinking, map[string]any{ + "session_id": session.ID, "persona": p.AgentKey, "emoji": p.Emoji, + })) + + soulMD := e.loadPersonaSoulMD(ctx, p.AgentKey) + systemPrompt := BuildPersonaSystemPrompt(p, session, soulMD) + userPrompt := BuildThinkingPrompt(session, p) + + resp, err := e.llmCall(ctx, systemPrompt, userPrompt) + if err != nil { + errs[idx] = err + return + } + thoughts[idx] = PersonaThought{PersonaKey: p.AgentKey, Emoji: p.Emoji, Content: strings.TrimSpace(resp)} + }(i, persona) + } + wg.Wait() + + for i, err := range errs { + if err != nil { + return nil, fmt.Errorf("persona %s thinking failed: %w", personas[i].AgentKey, err) + } + } + + return thoughts, nil +} + +// GenerateSummary generates a discussion summary. +func (e *Engine) GenerateSummary(ctx context.Context, session *store.PartySessionData) (*SummaryResult, error) { + prompt := BuildSummaryPrompt(session) + resp, err := e.llmCall(ctx, "You are a discussion summarizer. Generate structured markdown summaries.", prompt) + if err != nil { + return nil, fmt.Errorf("summary LLM call: %w", err) + } + + var personaKeys []string + json.Unmarshal(session.Personas, &personaKeys) + + return &SummaryResult{ + Topic: session.Topic, + Rounds: session.Round, + Personas: personaKeys, + Markdown: resp, + }, nil +} + +func (e *Engine) llmCall(ctx context.Context, systemPrompt, userPrompt string) (string, error) { + req := providers.ChatRequest{ + Messages: []providers.Message{ + {Role: "system", Content: systemPrompt}, + {Role: "user", Content: userPrompt}, + }, + Options: map[string]any{ + "max_tokens": 4096, + }, + } + resp, err := e.provider.Chat(ctx, req) + if err != nil { + return "", err + } + return resp.Content, nil +} + +func (e *Engine) loadPersonaSoulMD(ctx context.Context, agentKey string) string { + agent, err := e.agentStore.GetByKey(ctx, agentKey) + if err != nil { + return "" + } + files, _ := e.agentStore.GetAgentContextFiles(ctx, agent.ID) + for _, f := range files { + if f.FileName == "SOUL.md" { + return f.Content + } + } + return "" +} + +// parsePersonaMessages parses LLM output into individual persona messages. +func parsePersonaMessages(resp string, personas []PersonaInfo) []PersonaMessage { + var messages []PersonaMessage + lines := strings.Split(resp, "\n") + + var current *PersonaMessage + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + matched := false + for _, p := range personas { + if strings.HasPrefix(line, p.Emoji) { + if current != nil { + messages = append(messages, *current) + } + content := line + prefix := p.Emoji + " " + p.DisplayName + ":" + if strings.HasPrefix(line, prefix) { + content = strings.TrimSpace(line[len(prefix):]) + } + current = &PersonaMessage{ + PersonaKey: p.AgentKey, + DisplayName: p.DisplayName, + Emoji: p.Emoji, + Content: content, + } + matched = true + break + } + } + if !matched && current != nil { + current.Content += "\n" + line + } + } + if current != nil { + messages = append(messages, *current) + } + return messages +} diff --git a/internal/party/prompt.go b/internal/party/prompt.go new file mode 100644 index 00000000..7217e92b --- /dev/null +++ b/internal/party/prompt.go @@ -0,0 +1,161 @@ +package party + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/nextlevelbuilder/goclaw/internal/store" +) + +// BuildPersonaSystemPrompt builds the system prompt for a persona in a party round. +func BuildPersonaSystemPrompt(persona PersonaInfo, session *store.PartySessionData, soulMD string) string { + var sb strings.Builder + + // Persona identity + sb.WriteString(soulMD) + sb.WriteString("\n\n") + + // Party context + sb.WriteString("\n") + sb.WriteString(fmt.Sprintf("%s\n", session.Topic)) + + var ctx map[string]json.RawMessage + if json.Unmarshal(session.Context, &ctx) == nil { + if docs, ok := ctx["documents"]; ok { + sb.WriteString("\n") + sb.Write(docs) + sb.WriteString("\n\n") + } + if code, ok := ctx["codebase"]; ok { + sb.WriteString("\n") + sb.Write(code) + sb.WriteString("\n\n") + } + if notes, ok := ctx["meeting_notes"]; ok { + sb.WriteString("\n") + sb.Write(notes) + sb.WriteString("\n\n") + } + if custom, ok := ctx["custom"]; ok { + sb.WriteString("\n") + sb.Write(custom) + sb.WriteString("\n\n") + } + } + sb.WriteString("\n\n") + + // Round history (sliding window — last 3 rounds) + var history []RoundResult + if json.Unmarshal(session.History, &history) == nil && len(history) > 0 { + start := 0 + if len(history) > 3 { + start = len(history) - 3 + } + sb.WriteString("\n") + for _, r := range history[start:] { + sb.WriteString(fmt.Sprintf("Round %d [%s]:\n", r.Round, r.Mode)) + for _, m := range r.Messages { + sb.WriteString(fmt.Sprintf(" %s %s: %s\n", m.Emoji, m.DisplayName, truncate(m.Content, 500))) + } + } + sb.WriteString("\n") + } + + return sb.String() +} + +// BuildStandardRoundPrompt builds the user message for a standard round. +func BuildStandardRoundPrompt(session *store.PartySessionData, personas []PersonaInfo) string { + var sb strings.Builder + sb.WriteString(fmt.Sprintf("Round %d discussion about: %s\n\n", session.Round, session.Topic)) + sb.WriteString("Respond as each of these personas IN CHARACTER. Each persona gives their expert analysis.\n") + sb.WriteString("Format each response as: {emoji} {name}: {response}\n") + sb.WriteString("Encourage genuine disagreement where expertise conflicts.\n\n") + sb.WriteString("Personas:\n") + for _, p := range personas { + sb.WriteString(fmt.Sprintf("- %s %s (%s)\n", p.Emoji, p.DisplayName, p.SpeakingStyle)) + } + return sb.String() +} + +// BuildThinkingPrompt builds the user message for Deep Mode Step 1 (independent thinking). +func BuildThinkingPrompt(session *store.PartySessionData, persona PersonaInfo) string { + return fmt.Sprintf( + "Round %d: Think independently about \"%s\".\n"+ + "Share your analysis from your %s expertise.\n"+ + "Be specific, cite relevant standards/principles.\n"+ + "Identify risks, opportunities, and trade-offs.\n"+ + "Stay completely in character as %s.", + session.Round, session.Topic, persona.DisplayName, persona.DisplayName) +} + +// BuildCrossTalkPrompt builds the prompt for Deep Mode Step 2 (cross-talk from collected thoughts). +func BuildCrossTalkPrompt(session *store.PartySessionData, personas []PersonaInfo, thoughts []PersonaThought) string { + var sb strings.Builder + sb.WriteString(fmt.Sprintf("Round %d cross-talk about: %s\n\n", session.Round, session.Topic)) + sb.WriteString("Each persona has shared their independent thinking below.\n") + sb.WriteString("Now generate cross-talk: personas respond to EACH OTHER.\n") + sb.WriteString("Challenge disagreements explicitly. Build on agreements.\n") + sb.WriteString("Stay in character. Format: {emoji} {name}: {response}\n\n") + + sb.WriteString("\n") + for _, t := range thoughts { + sb.WriteString(fmt.Sprintf("<%s>\n%s\n\n", t.PersonaKey, t.Content, t.PersonaKey)) + } + sb.WriteString("\n") + return sb.String() +} + +// BuildTokenRingTurnPrompt builds the prompt for one persona's turn in Token-Ring mode. +func BuildTokenRingTurnPrompt(session *store.PartySessionData, persona PersonaInfo, thoughts []PersonaThought, priorTurns []PersonaMessage, isLast bool) string { + var sb strings.Builder + sb.WriteString(fmt.Sprintf("Round %d, your turn in the discussion about: %s\n\n", session.Round, session.Topic)) + + sb.WriteString("Independent thoughts from all personas:\n") + for _, t := range thoughts { + sb.WriteString(fmt.Sprintf("- %s: %s\n", t.PersonaKey, truncate(t.Content, 300))) + } + + if len(priorTurns) > 0 { + sb.WriteString("\nPrior responses this round:\n") + for _, m := range priorTurns { + sb.WriteString(fmt.Sprintf(" %s %s: %s\n", m.Emoji, m.DisplayName, m.Content)) + } + sb.WriteString("\nRespond to what others have said. Challenge or build on their points.\n") + } + + if isLast { + sb.WriteString("\nYou are the LAST speaker. Synthesize: what does the team agree on? What remains unresolved?\n") + } + + sb.WriteString("\nStay completely in character. Be direct and specific.") + return sb.String() +} + +// BuildSummaryPrompt builds the prompt for generating the discussion summary. +func BuildSummaryPrompt(session *store.PartySessionData) string { + return fmt.Sprintf(`Summarize this party mode discussion. + +Topic: %s +Rounds: %d + +Discussion history: +%s + +Generate a structured summary with: +1. Points of Agreement (unanimous decisions) +2. Points of Disagreement (who disagrees, why) +3. Decisions Made +4. Action Items (action, assignee persona, deadline suggestion, checkpoint link) +5. Compliance Notes (if any security/PCI-DSS/SBV items) + +Format as clean markdown.`, session.Topic, session.Round, string(session.History)) +} + +func truncate(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen] + "..." +} diff --git a/internal/party/types.go b/internal/party/types.go new file mode 100644 index 00000000..933d0f97 --- /dev/null +++ b/internal/party/types.go @@ -0,0 +1,77 @@ +package party + +// PersonaInfo holds runtime persona metadata loaded from agent DB. +type PersonaInfo struct { + AgentKey string `json:"agent_key"` + DisplayName string `json:"display_name"` + Emoji string `json:"emoji"` + MovieRef string `json:"movie_ref"` + SpeakingStyle string `json:"speaking_style"` + ExpertiseWeight map[string]float64 `json:"expertise_weight,omitempty"` +} + +// PersonaThought is one persona's independent thinking output (Deep Mode Step 1). +type PersonaThought struct { + PersonaKey string `json:"persona_key"` + Emoji string `json:"emoji"` + Content string `json:"content"` +} + +// PersonaMessage is a persona's spoken message in a round. +type PersonaMessage struct { + PersonaKey string `json:"persona_key"` + DisplayName string `json:"display_name"` + Emoji string `json:"emoji"` + Content string `json:"content"` + Thinking string `json:"thinking,omitempty"` +} + +// RoundResult contains all persona messages for one round. +type RoundResult struct { + Round int `json:"round"` + Mode string `json:"mode"` + Messages []PersonaMessage `json:"messages"` +} + +// SummaryResult contains the party discussion summary. +type SummaryResult struct { + Topic string `json:"topic"` + Rounds int `json:"rounds"` + Personas []string `json:"personas"` + Agreements []string `json:"agreements"` + Disagreements []string `json:"disagreements"` + Decisions []string `json:"decisions"` + ActionItems []ActionItem `json:"action_items"` + Compliance []string `json:"compliance_notes,omitempty"` + Markdown string `json:"markdown"` +} + +// ActionItem is a follow-up task from the discussion. +type ActionItem struct { + Action string `json:"action"` + Assignee string `json:"assignee"` + Deadline string `json:"deadline,omitempty"` + CPLink string `json:"cp_link,omitempty"` +} + +// PresetTeam defines a preset team composition. +type PresetTeam struct { + Key string `json:"key"` + Name string `json:"name"` + Personas []string `json:"personas"` + UseCase string `json:"use_case"` + Facilitator string `json:"facilitator"` + Mandatory []string `json:"mandatory,omitempty"` +} + +// PresetTeams returns the 6 preset team compositions. +func PresetTeams() []PresetTeam { + return []PresetTeam{ + {Key: "payment_feature", Name: "Payment Feature", Personas: []string{"tony-stark-persona", "neo-persona", "batman-persona", "judge-dredd-persona", "columbo-persona"}, UseCase: "Payment flows, settlement", Facilitator: "gandalf-persona"}, + {Key: "security_review", Name: "Security Review", Personas: []string{"batman-persona", "judge-dredd-persona", "neo-persona", "scotty-persona"}, UseCase: "Threat modeling, pre-CP3", Facilitator: "batman-persona"}, + {Key: "sprint_planning", Name: "Sprint Planning", Personas: []string{"tony-stark-persona", "sherlock-persona", "neo-persona", "gandalf-persona", "columbo-persona"}, UseCase: "Sprint kickoff, PRD review", Facilitator: "gandalf-persona"}, + {Key: "architecture_decision", Name: "Architecture Decision", Personas: []string{"neo-persona", "spock-persona", "scotty-persona", "batman-persona"}, UseCase: "ADR, tech stack eval", Facilitator: "morpheus-persona"}, + {Key: "ux_review", Name: "UX Review", Personas: []string{"edna-mode-persona", "tony-stark-persona", "spider-man-persona", "ethan-hunt-persona", "columbo-persona"}, UseCase: "Design review", Facilitator: "edna-mode-persona"}, + {Key: "incident_response", Name: "Incident Response", Personas: []string{"scotty-persona", "neo-persona", "batman-persona", "nick-fury-persona"}, UseCase: "Production incidents", Facilitator: "nick-fury-persona"}, + } +} diff --git a/internal/providers/anthropic_stream.go b/internal/providers/anthropic_stream.go index 27c6eb8d..fda9175a 100644 --- a/internal/providers/anthropic_stream.go +++ b/internal/providers/anthropic_stream.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "io" + "log/slog" "strings" ) @@ -27,6 +28,7 @@ func (p *AnthropicProvider) ChatStream(ctx context.Context, req ChatRequest, onC defer respBody.Close() result := &ChatResponse{FinishReason: "stop"} + var receivedStop bool // tracks whether message_stop event was received // Accumulate raw JSON fragments for each tool call by index toolCallJSON := make(map[int]string) @@ -148,10 +150,27 @@ func (p *AnthropicProvider) ChatStream(ctx context.Context, req ChatRequest, onC } case "message_stop": - // Stream complete + receivedStop = true } } + // Check for scanner errors (timeout, connection reset, etc.) + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("anthropic stream read error: %w", err) + } + + // Detect premature stream termination: SSE connection closed before + // the message_stop event. This can happen when a proxy (e.g. CLIProxyAPI) + // drops the connection or the network is interrupted. The default + // FinishReason "stop" is misleading in this case. + if !receivedStop && (result.Content != "" || len(result.ToolCalls) > 0) { + slog.Warn("anthropic stream interrupted: no message_stop event", + "content_len", len(result.Content), + "tool_calls", len(result.ToolCalls), + "has_usage", result.Usage != nil) + result.FinishReason = "interrupted" + } + // Parse accumulated tool call JSON arguments for i, rawJSON := range toolCallJSON { if rawJSON != "" { diff --git a/internal/providers/codex.go b/internal/providers/codex.go index 9df7e43e..fca2791d 100644 --- a/internal/providers/codex.go +++ b/internal/providers/codex.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "io" + "log/slog" "net/http" "strings" "time" @@ -98,7 +99,10 @@ func (p *CodexProvider) ChatStream(ctx context.Context, req ChatRequest, onChunk continue } args := make(map[string]any) - _ = json.Unmarshal([]byte(acc.rawArgs), &args) + if err := json.Unmarshal([]byte(acc.rawArgs), &args); err != nil && acc.rawArgs != "" { + slog.Warn("truncated tool call arguments (codex stream)", + "tool", acc.name, "error", err, "raw_len", len(acc.rawArgs)) + } result.ToolCalls = append(result.ToolCalls, ToolCall{ ID: acc.callID, Name: acc.name, diff --git a/internal/providers/openai.go b/internal/providers/openai.go index f9ad5f8f..64f03441 100644 --- a/internal/providers/openai.go +++ b/internal/providers/openai.go @@ -7,7 +7,9 @@ import ( "encoding/json" "fmt" "io" + "log/slog" "net/http" + "sort" "strings" "time" ) @@ -108,6 +110,7 @@ func (p *OpenAIProvider) ChatStream(ctx context.Context, req ChatRequest, onChun defer respBody.Close() result := &ChatResponse{FinishReason: "stop"} + var receivedDone bool // tracks whether [DONE] marker was received accumulators := make(map[int]*toolCallAccumulator) scanner := bufio.NewScanner(respBody) @@ -122,6 +125,7 @@ func (p *OpenAIProvider) ChatStream(ctx context.Context, req ChatRequest, onChun data := strings.TrimPrefix(line, "data:") data = strings.TrimPrefix(data, " ") if data == "[DONE]" { + receivedDone = true break } @@ -194,11 +198,30 @@ func (p *OpenAIProvider) ChatStream(ctx context.Context, req ChatRequest, onChun return nil, fmt.Errorf("%s: stream read error: %w", p.name, err) } - // Parse accumulated tool call arguments - for i := 0; i < len(accumulators); i++ { - acc := accumulators[i] + // Detect premature stream termination: connection closed before [DONE]. + if !receivedDone && (result.Content != "" || len(accumulators) > 0) { + slog.Warn("openai stream interrupted: no [DONE] marker", + "provider", p.name, + "content_len", len(result.Content), + "tool_calls", len(accumulators)) + result.FinishReason = "interrupted" + } + + // Parse accumulated tool call arguments. + // Keys are SSE tool_call indices which may be non-contiguous (e.g. {0, 2}), + // so we sort keys instead of assuming sequential 0..len-1. + indices := make([]int, 0, len(accumulators)) + for idx := range accumulators { + indices = append(indices, idx) + } + sort.Ints(indices) + for _, idx := range indices { + acc := accumulators[idx] args := make(map[string]any) - _ = json.Unmarshal([]byte(acc.rawArgs), &args) + if err := json.Unmarshal([]byte(acc.rawArgs), &args); err != nil && acc.rawArgs != "" { + slog.Warn("truncated tool call arguments (stream)", + "tool", acc.Name, "error", err, "raw_len", len(acc.rawArgs)) + } acc.Arguments = args if acc.thoughtSig != "" { acc.Metadata = map[string]string{"thought_signature": acc.thoughtSig} @@ -379,7 +402,10 @@ func (p *OpenAIProvider) parseResponse(resp *openAIResponse) *ChatResponse { for _, tc := range msg.ToolCalls { args := make(map[string]any) - _ = json.Unmarshal([]byte(tc.Function.Arguments), &args) + if err := json.Unmarshal([]byte(tc.Function.Arguments), &args); err != nil && tc.Function.Arguments != "" { + slog.Warn("truncated tool call arguments (non-stream)", + "tool", tc.Function.Name, "error", err, "raw_len", len(tc.Function.Arguments)) + } call := ToolCall{ ID: tc.ID, Name: strings.TrimSpace(tc.Function.Name), diff --git a/internal/scheduler/lanes.go b/internal/scheduler/lanes.go index 1af97d57..eb2488c0 100644 --- a/internal/scheduler/lanes.go +++ b/internal/scheduler/lanes.go @@ -92,6 +92,9 @@ func (l *Lane) Submit(ctx context.Context, fn func()) error { go func() { defer func() { + if r := recover(); r != nil { + slog.Error("panic in lane worker", "lane", l.name, "panic", r) + } l.active.Add(-1) l.wg.Done() l.sem <- token // return token diff --git a/internal/sessions/manager.go b/internal/sessions/manager.go index 14ef8f19..069e0f38 100644 --- a/internal/sessions/manager.go +++ b/internal/sessions/manager.go @@ -260,6 +260,19 @@ func (m *Manager) GetLastPromptTokens(key string) (int, int) { return 0, 0 } +// SetHistory replaces the full message history for a session. +func (m *Manager) SetHistory(key string, msgs []providers.Message) { + m.mu.Lock() + defer m.mu.Unlock() + + s, ok := m.sessions[key] + if !ok { + return + } + s.Messages = msgs + s.Updated = time.Now() +} + // TruncateHistory keeps only the last N messages. func (m *Manager) TruncateHistory(key string, keepLast int) { m.mu.Lock() @@ -278,16 +291,6 @@ func (m *Manager) TruncateHistory(key string, keepLast int) { s.Updated = time.Now() } -// SetHistory replaces a session's message history with the given slice. -func (m *Manager) SetHistory(key string, msgs []providers.Message) { - m.mu.Lock() - defer m.mu.Unlock() - - if s, ok := m.sessions[key]; ok { - s.Messages = msgs - s.Updated = time.Now() - } -} // Reset clears a session's history and summary. func (m *Manager) Reset(key string) { diff --git a/internal/store/agent_store.go b/internal/store/agent_store.go index 60827c7c..e45e8e37 100644 --- a/internal/store/agent_store.go +++ b/internal/store/agent_store.go @@ -124,6 +124,21 @@ func (a *AgentData) ParseMemoryConfig() *config.MemoryConfig { return &c } +// ParseMaxTokens extracts max_tokens from other_config JSONB. +// Returns 0 if not configured (caller should apply default). +func (a *AgentData) ParseMaxTokens() int { + if len(a.OtherConfig) == 0 { + return 0 + } + var cfg struct { + MaxTokens int `json:"max_tokens"` + } + if json.Unmarshal(a.OtherConfig, &cfg) != nil { + return 0 + } + return cfg.MaxTokens +} + // ParseThinkingLevel extracts thinking_level from other_config JSONB. // Returns "" if not configured (meaning "off"). func (a *AgentData) ParseThinkingLevel() string { diff --git a/internal/store/party_store.go b/internal/store/party_store.go new file mode 100644 index 00000000..31e0b80e --- /dev/null +++ b/internal/store/party_store.go @@ -0,0 +1,56 @@ +package store + +import ( + "context" + "encoding/json" + "time" + + "github.com/google/uuid" +) + +// Party session statuses. +const ( + PartyStatusAssembling = "assembling" + PartyStatusDiscussing = "discussing" + PartyStatusSummarizing = "summarizing" + PartyStatusClosed = "closed" +) + +// Party discussion modes. +const ( + PartyModeStandard = "standard" + PartyModeDeep = "deep" + PartyModeTokenRing = "token_ring" +) + +// PartySessionData represents a party mode session. +type PartySessionData struct { + ID uuid.UUID `json:"id"` + Topic string `json:"topic"` + TeamPreset string `json:"team_preset,omitempty"` + Status string `json:"status"` + Mode string `json:"mode"` + Round int `json:"round"` + MaxRounds int `json:"max_rounds"` + UserID string `json:"user_id"` + Channel string `json:"channel,omitempty"` + ChatID string `json:"chat_id,omitempty"` + Personas json.RawMessage `json:"personas"` + Context json.RawMessage `json:"context"` + History json.RawMessage `json:"history"` + Summary json.RawMessage `json:"summary,omitempty"` + Artifacts json.RawMessage `json:"artifacts"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// PartyStore manages party mode sessions. +type PartyStore interface { + CreateSession(ctx context.Context, session *PartySessionData) error + GetSession(ctx context.Context, id uuid.UUID) (*PartySessionData, error) + UpdateSession(ctx context.Context, id uuid.UUID, updates map[string]any) error + ListSessions(ctx context.Context, userID string, status string, limit int) ([]*PartySessionData, error) + // GetActiveSession returns the active (assembling/discussing) session for a user+channel+chat. + GetActiveSession(ctx context.Context, userID, channel, chatID string) (*PartySessionData, error) + DeleteSession(ctx context.Context, id uuid.UUID) error +} diff --git a/internal/store/pg/factory.go b/internal/store/pg/factory.go index 58b1ea0a..cff0c98a 100644 --- a/internal/store/pg/factory.go +++ b/internal/store/pg/factory.go @@ -44,5 +44,7 @@ func NewPGStores(cfg store.StoreConfig) (*store.Stores, error) { Contacts: NewPGContactStore(db), Activity: NewPGActivityStore(db), Snapshots: NewPGSnapshotStore(db), + Party: NewPGPartyStore(db), + Projects: NewPGProjectStore(db), }, nil } diff --git a/internal/store/pg/helpers.go b/internal/store/pg/helpers.go index 71f88374..1c1a9747 100644 --- a/internal/store/pg/helpers.go +++ b/internal/store/pg/helpers.go @@ -138,6 +138,7 @@ var tablesWithUpdatedAt = map[string]bool{ "agent_context_files": true, "user_context_files": true, "user_agent_overrides": true, "config_secrets": true, "memory_documents": true, "memory_chunks": true, "embedding_cache": true, + "projects": true, "project_mcp_overrides": true, } func tableHasUpdatedAt(table string) bool { diff --git a/internal/store/pg/party.go b/internal/store/pg/party.go new file mode 100644 index 00000000..ecb597db --- /dev/null +++ b/internal/store/pg/party.go @@ -0,0 +1,140 @@ +package pg + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "time" + + "github.com/google/uuid" + "github.com/nextlevelbuilder/goclaw/internal/store" +) + +const partySelectCols = `id, topic, team_preset, status, mode, round, max_rounds, + user_id, channel, chat_id, personas, context, history, + COALESCE(summary, 'null'), artifacts, created_at, updated_at` + +// PGPartyStore implements PartyStore backed by PostgreSQL. +type PGPartyStore struct { + db *sql.DB +} + +// NewPGPartyStore creates a new PGPartyStore. +func NewPGPartyStore(db *sql.DB) *PGPartyStore { + return &PGPartyStore{db: db} +} + +func (s *PGPartyStore) CreateSession(ctx context.Context, sess *store.PartySessionData) error { + if sess.ID == uuid.Nil { + sess.ID = store.GenNewID() + } + now := time.Now() + sess.CreatedAt = now + sess.UpdatedAt = now + if len(sess.Personas) == 0 { + sess.Personas = json.RawMessage("[]") + } + if len(sess.Context) == 0 { + sess.Context = json.RawMessage("{}") + } + if len(sess.History) == 0 { + sess.History = json.RawMessage("[]") + } + if len(sess.Artifacts) == 0 { + sess.Artifacts = json.RawMessage("[]") + } + + _, err := s.db.ExecContext(ctx, + `INSERT INTO party_sessions + (id, topic, team_preset, status, mode, round, max_rounds, + user_id, channel, chat_id, personas, context, history, artifacts, created_at, updated_at) + VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16)`, + sess.ID, sess.Topic, sess.TeamPreset, sess.Status, sess.Mode, + sess.Round, sess.MaxRounds, sess.UserID, sess.Channel, sess.ChatID, + sess.Personas, sess.Context, sess.History, sess.Artifacts, + sess.CreatedAt, sess.UpdatedAt) + return err +} + +func (s *PGPartyStore) GetSession(ctx context.Context, id uuid.UUID) (*store.PartySessionData, error) { + row := s.db.QueryRowContext(ctx, + `SELECT `+partySelectCols+` FROM party_sessions WHERE id = $1`, id) + return scanPartyRow(row) +} + +func (s *PGPartyStore) GetActiveSession(ctx context.Context, userID, channel, chatID string) (*store.PartySessionData, error) { + row := s.db.QueryRowContext(ctx, + `SELECT `+partySelectCols+` FROM party_sessions + WHERE user_id = $1 AND channel = $2 AND chat_id = $3 + AND status IN ('assembling', 'discussing') + ORDER BY created_at DESC LIMIT 1`, userID, channel, chatID) + sess, err := scanPartyRow(row) + if err == sql.ErrNoRows { + return nil, nil + } + return sess, err +} + +func (s *PGPartyStore) UpdateSession(ctx context.Context, id uuid.UUID, updates map[string]any) error { + updates["updated_at"] = time.Now() + return execMapUpdate(ctx, s.db, "party_sessions", id, updates) +} + +func (s *PGPartyStore) ListSessions(ctx context.Context, userID string, status string, limit int) ([]*store.PartySessionData, error) { + query := `SELECT ` + partySelectCols + ` FROM party_sessions WHERE user_id = $1` + args := []any{userID} + if status != "" { + query += ` AND status = $2` + args = append(args, status) + } + query += ` ORDER BY created_at DESC` + if limit > 0 { + query += fmt.Sprintf(` LIMIT %d`, limit) + } + + rows, err := s.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var sessions []*store.PartySessionData + for rows.Next() { + sess, err := scanPartyRows(rows) + if err != nil { + return nil, err + } + sessions = append(sessions, sess) + } + return sessions, rows.Err() +} + +func (s *PGPartyStore) DeleteSession(ctx context.Context, id uuid.UUID) error { + _, err := s.db.ExecContext(ctx, `DELETE FROM party_sessions WHERE id = $1`, id) + return err +} + +func scanPartyRow(row *sql.Row) (*store.PartySessionData, error) { + var s store.PartySessionData + err := row.Scan(&s.ID, &s.Topic, &s.TeamPreset, &s.Status, &s.Mode, + &s.Round, &s.MaxRounds, &s.UserID, &s.Channel, &s.ChatID, + &s.Personas, &s.Context, &s.History, &s.Summary, &s.Artifacts, + &s.CreatedAt, &s.UpdatedAt) + if err != nil { + return nil, err + } + return &s, nil +} + +func scanPartyRows(rows *sql.Rows) (*store.PartySessionData, error) { + var s store.PartySessionData + err := rows.Scan(&s.ID, &s.Topic, &s.TeamPreset, &s.Status, &s.Mode, + &s.Round, &s.MaxRounds, &s.UserID, &s.Channel, &s.ChatID, + &s.Personas, &s.Context, &s.History, &s.Summary, &s.Artifacts, + &s.CreatedAt, &s.UpdatedAt) + if err != nil { + return nil, err + } + return &s, nil +} diff --git a/internal/store/pg/projects.go b/internal/store/pg/projects.go new file mode 100644 index 00000000..eedbade7 --- /dev/null +++ b/internal/store/pg/projects.go @@ -0,0 +1,180 @@ +package pg + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "regexp" + + "github.com/google/uuid" + "github.com/nextlevelbuilder/goclaw/internal/store" +) + +var secretKeyPattern = regexp.MustCompile(`(?i)(^|_)(TOKEN|SECRET|PASSWORD|API_KEY)($|_)`) + +// PGProjectStore implements store.ProjectStore backed by Postgres. +type PGProjectStore struct { + db *sql.DB +} + +func NewPGProjectStore(db *sql.DB) *PGProjectStore { + return &PGProjectStore{db: db} +} + +// --- Project CRUD --- + +func (s *PGProjectStore) CreateProject(ctx context.Context, p *store.Project) error { + query := `INSERT INTO projects (name, slug, channel_type, chat_id, team_id, description, status, created_by) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id, created_at, updated_at` + return s.db.QueryRowContext(ctx, query, + p.Name, p.Slug, p.ChannelType, p.ChatID, nilUUID(p.TeamID), + p.Description, p.Status, p.CreatedBy, + ).Scan(&p.ID, &p.CreatedAt, &p.UpdatedAt) +} + +func (s *PGProjectStore) GetProject(ctx context.Context, id uuid.UUID) (*store.Project, error) { + query := `SELECT id, name, slug, channel_type, chat_id, team_id, description, status, created_by, created_at, updated_at + FROM projects WHERE id = $1` + return s.scanProject(s.db.QueryRowContext(ctx, query, id)) +} + +func (s *PGProjectStore) GetProjectBySlug(ctx context.Context, slug string) (*store.Project, error) { + query := `SELECT id, name, slug, channel_type, chat_id, team_id, description, status, created_by, created_at, updated_at + FROM projects WHERE slug = $1` + return s.scanProject(s.db.QueryRowContext(ctx, query, slug)) +} + +func (s *PGProjectStore) GetProjectByChatID(ctx context.Context, channelType, chatID string) (*store.Project, error) { + query := `SELECT id, name, slug, channel_type, chat_id, team_id, description, status, created_by, created_at, updated_at + FROM projects WHERE channel_type = $1 AND chat_id = $2 AND status = 'active'` + p, err := s.scanProject(s.db.QueryRowContext(ctx, query, channelType, chatID)) + if err == sql.ErrNoRows { + return nil, nil // no project — not an error + } + return p, err +} + +func (s *PGProjectStore) scanProject(row *sql.Row) (*store.Project, error) { + p := &store.Project{} + err := row.Scan( + &p.ID, &p.Name, &p.Slug, &p.ChannelType, &p.ChatID, &p.TeamID, + &p.Description, &p.Status, &p.CreatedBy, &p.CreatedAt, &p.UpdatedAt, + ) + if err != nil { + return nil, err + } + return p, nil +} + +func (s *PGProjectStore) ListProjects(ctx context.Context) ([]store.Project, error) { + query := `SELECT id, name, slug, channel_type, chat_id, team_id, description, status, created_by, created_at, updated_at + FROM projects ORDER BY name` + rows, err := s.db.QueryContext(ctx, query) + if err != nil { + return nil, err + } + defer rows.Close() + + result := make([]store.Project, 0) + for rows.Next() { + var p store.Project + if err := rows.Scan( + &p.ID, &p.Name, &p.Slug, &p.ChannelType, &p.ChatID, &p.TeamID, + &p.Description, &p.Status, &p.CreatedBy, &p.CreatedAt, &p.UpdatedAt, + ); err != nil { + continue + } + result = append(result, p) + } + return result, rows.Err() +} + +func (s *PGProjectStore) UpdateProject(ctx context.Context, id uuid.UUID, updates map[string]any) error { + return execMapUpdate(ctx, s.db, "projects", id, updates) +} + +func (s *PGProjectStore) DeleteProject(ctx context.Context, id uuid.UUID) error { + _, err := s.db.ExecContext(ctx, "DELETE FROM projects WHERE id = $1", id) + return err +} + +// --- MCP overrides --- + +// SetMCPOverride upserts env overrides for a project+server. +// Rejects keys that look like secrets (TOKEN, SECRET, PASSWORD, API_KEY). +func (s *PGProjectStore) SetMCPOverride(ctx context.Context, projectID uuid.UUID, serverName string, envOverrides map[string]string) error { + for key := range envOverrides { + if secretKeyPattern.MatchString(key) { + return fmt.Errorf("env key %q contains secret pattern (TOKEN/SECRET/PASSWORD/API_KEY) — use mcp_servers.env for secrets", key) + } + } + envJSON, err := json.Marshal(envOverrides) + if err != nil { + return err + } + query := `INSERT INTO project_mcp_overrides (project_id, server_name, env_overrides) + VALUES ($1, $2, $3) + ON CONFLICT (project_id, server_name) DO UPDATE SET env_overrides = $3, updated_at = NOW()` + _, err = s.db.ExecContext(ctx, query, projectID, serverName, envJSON) + return err +} + +func (s *PGProjectStore) RemoveMCPOverride(ctx context.Context, projectID uuid.UUID, serverName string) error { + _, err := s.db.ExecContext(ctx, + "DELETE FROM project_mcp_overrides WHERE project_id = $1 AND server_name = $2", + projectID, serverName) + return err +} + +func (s *PGProjectStore) GetMCPOverrides(ctx context.Context, projectID uuid.UUID) ([]store.ProjectMCPOverride, error) { + query := `SELECT id, project_id, server_name, env_overrides, enabled + FROM project_mcp_overrides WHERE project_id = $1 ORDER BY server_name` + rows, err := s.db.QueryContext(ctx, query, projectID) + if err != nil { + return nil, err + } + defer rows.Close() + + result := make([]store.ProjectMCPOverride, 0) + for rows.Next() { + var o store.ProjectMCPOverride + var envJSON []byte + if err := rows.Scan(&o.ID, &o.ProjectID, &o.ServerName, &envJSON, &o.Enabled); err != nil { + continue + } + o.EnvOverrides = make(map[string]string) + if len(envJSON) > 0 { + if err := json.Unmarshal(envJSON, &o.EnvOverrides); err != nil { + continue + } + } + result = append(result, o) + } + return result, rows.Err() +} + +// GetMCPOverridesMap returns {serverName: {envKey: envVal}} for runtime env injection. +func (s *PGProjectStore) GetMCPOverridesMap(ctx context.Context, projectID uuid.UUID) (map[string]map[string]string, error) { + query := `SELECT server_name, env_overrides FROM project_mcp_overrides + WHERE project_id = $1 AND enabled = true` + rows, err := s.db.QueryContext(ctx, query, projectID) + if err != nil { + return nil, err + } + defer rows.Close() + result := make(map[string]map[string]string) + for rows.Next() { + var serverName string + var envJSON []byte + if err := rows.Scan(&serverName, &envJSON); err != nil { + return nil, err + } + env := make(map[string]string) + if err := json.Unmarshal(envJSON, &env); err != nil { + return nil, err + } + result[serverName] = env + } + return result, rows.Err() +} diff --git a/internal/store/pg/projects_test.go b/internal/store/pg/projects_test.go new file mode 100644 index 00000000..1fefe79c --- /dev/null +++ b/internal/store/pg/projects_test.go @@ -0,0 +1,40 @@ +package pg + +import "testing" + +func TestSecretKeyPattern(t *testing.T) { + tests := []struct { + key string + isSecret bool + }{ + {"GITLAB_TOKEN", true}, + {"gitlab_token", true}, + {"GitLab_Token", true}, + {"API_KEY", true}, + {"api_key", true}, + {"MY_SECRET", true}, + {"DB_PASSWORD", true}, + {"password", true}, + {"AUTH_TOKEN_V2", true}, + {"X_API_KEY_EXTRA", true}, + {"SECRET_STUFF", true}, + + {"GITLAB_PROJECT_ID", false}, + {"GITLAB_PROJECT_PATH", false}, + {"JIRA_PROJECT_KEY", false}, + {"CONFLUENCE_SPACE_KEY", false}, + {"PROJECT_PATH", false}, + {"BOARD_ID", false}, + {"WORKSPACE_DIR", false}, + {"TOKENIZER_TYPE", false}, // contains "TOKEN" substring but not at word boundary + } + + for _, tt := range tests { + t.Run(tt.key, func(t *testing.T) { + got := secretKeyPattern.MatchString(tt.key) + if got != tt.isSecret { + t.Errorf("secretKeyPattern.MatchString(%q) = %v, want %v", tt.key, got, tt.isSecret) + } + }) + } +} diff --git a/internal/store/pg/sessions_ops.go b/internal/store/pg/sessions_ops.go index 5033b6de..29e466a4 100644 --- a/internal/store/pg/sessions_ops.go +++ b/internal/store/pg/sessions_ops.go @@ -6,24 +6,24 @@ import ( "github.com/nextlevelbuilder/goclaw/internal/providers" ) -func (s *PGSessionStore) TruncateHistory(key string, keepLast int) { +func (s *PGSessionStore) SetHistory(key string, msgs []providers.Message) { s.mu.Lock() defer s.mu.Unlock() if data, ok := s.cache[key]; ok { - if keepLast <= 0 { - data.Messages = []providers.Message{} - } else if len(data.Messages) > keepLast { - data.Messages = data.Messages[len(data.Messages)-keepLast:] - } + data.Messages = msgs data.Updated = time.Now() } } -func (s *PGSessionStore) SetHistory(key string, msgs []providers.Message) { +func (s *PGSessionStore) TruncateHistory(key string, keepLast int) { s.mu.Lock() defer s.mu.Unlock() if data, ok := s.cache[key]; ok { - data.Messages = msgs + if keepLast <= 0 { + data.Messages = []providers.Message{} + } else if len(data.Messages) > keepLast { + data.Messages = data.Messages[len(data.Messages)-keepLast:] + } data.Updated = time.Now() } } diff --git a/internal/store/pg/tracing.go b/internal/store/pg/tracing.go index 2fa06a3f..34c9bcef 100644 --- a/internal/store/pg/tracing.go +++ b/internal/store/pg/tracing.go @@ -428,6 +428,18 @@ func (s *PGTracingStore) GetMonthlyAgentCost(ctx context.Context, agentID uuid.U return cost, err } +func (s *PGTracingStore) SweepOrphanTraces(ctx context.Context, maxAge time.Duration) (int, error) { + cutoff := time.Now().Add(-maxAge) + res, err := s.db.ExecContext(ctx, + `UPDATE traces SET status = 'error', error = 'orphan: process crashed', end_time = NOW() + WHERE status = 'running' AND created_at < $1`, cutoff) + if err != nil { + return 0, err + } + n, _ := res.RowsAffected() + return int(n), nil +} + func (s *PGTracingStore) GetCostSummary(ctx context.Context, opts store.CostSummaryOpts) ([]store.CostSummaryRow, error) { var conditions []string var args []any diff --git a/internal/store/project_store.go b/internal/store/project_store.go new file mode 100644 index 00000000..6e0904cd --- /dev/null +++ b/internal/store/project_store.go @@ -0,0 +1,47 @@ +package store + +import ( + "context" + + "github.com/google/uuid" +) + +// Project represents a workspace bound to a group chat. +type Project struct { + BaseModel + Name string `json:"name"` + Slug string `json:"slug"` + ChannelType *string `json:"channel_type,omitempty"` + ChatID *string `json:"chat_id,omitempty"` + TeamID *uuid.UUID `json:"team_id,omitempty"` + Description *string `json:"description,omitempty"` + Status string `json:"status"` + CreatedBy string `json:"created_by"` +} + +// ProjectMCPOverride holds per-project env overrides for an MCP server. +type ProjectMCPOverride struct { + ID uuid.UUID `json:"id"` + ProjectID uuid.UUID `json:"project_id"` + ServerName string `json:"server_name"` + EnvOverrides map[string]string `json:"env_overrides"` + Enabled bool `json:"enabled"` +} + +// ProjectStore manages project entities and their MCP overrides. +type ProjectStore interface { + CreateProject(ctx context.Context, p *Project) error + GetProject(ctx context.Context, id uuid.UUID) (*Project, error) + GetProjectBySlug(ctx context.Context, slug string) (*Project, error) + GetProjectByChatID(ctx context.Context, channelType, chatID string) (*Project, error) + ListProjects(ctx context.Context) ([]Project, error) + UpdateProject(ctx context.Context, id uuid.UUID, updates map[string]any) error + DeleteProject(ctx context.Context, id uuid.UUID) error + + // MCP overrides + SetMCPOverride(ctx context.Context, projectID uuid.UUID, serverName string, envOverrides map[string]string) error + RemoveMCPOverride(ctx context.Context, projectID uuid.UUID, serverName string) error + GetMCPOverrides(ctx context.Context, projectID uuid.UUID) ([]ProjectMCPOverride, error) + // GetMCPOverridesMap returns {serverName: {envKey: envVal}} for runtime injection. + GetMCPOverridesMap(ctx context.Context, projectID uuid.UUID) (map[string]map[string]string, error) +} diff --git a/internal/store/session_store.go b/internal/store/session_store.go index 5f4d0435..6d9d4433 100644 --- a/internal/store/session_store.go +++ b/internal/store/session_store.go @@ -102,8 +102,8 @@ type SessionStore interface { GetContextWindow(key string) int SetLastPromptTokens(key string, tokens, msgCount int) GetLastPromptTokens(key string) (tokens, msgCount int) - TruncateHistory(key string, keepLast int) SetHistory(key string, msgs []providers.Message) + TruncateHistory(key string, keepLast int) Reset(key string) Delete(key string) error List(agentID string) []SessionInfo diff --git a/internal/store/stores.go b/internal/store/stores.go index f89abbe9..ad151de8 100644 --- a/internal/store/stores.go +++ b/internal/store/stores.go @@ -25,4 +25,6 @@ type Stores struct { Contacts ContactStore Activity ActivityStore Snapshots SnapshotStore + Party PartyStore + Projects ProjectStore } diff --git a/internal/store/tracing_store.go b/internal/store/tracing_store.go index 1dd800f6..a5903fe2 100644 --- a/internal/store/tracing_store.go +++ b/internal/store/tracing_store.go @@ -141,4 +141,7 @@ type TracingStore interface { // Cost aggregation GetMonthlyAgentCost(ctx context.Context, agentID uuid.UUID, year int, month time.Month) (float64, error) GetCostSummary(ctx context.Context, opts CostSummaryOpts) ([]CostSummaryRow, error) + + // Maintenance + SweepOrphanTraces(ctx context.Context, maxAge time.Duration) (int, error) } diff --git a/internal/tools/context_keys.go b/internal/tools/context_keys.go index 2b9f6e3d..2bd28cac 100644 --- a/internal/tools/context_keys.go +++ b/internal/tools/context_keys.go @@ -221,6 +221,34 @@ func WorkspaceChatIDFromCtx(ctx context.Context) string { return v } +// --- Project context propagation (message arrival → delegation chain) --- + +const ( + ctxProjectID toolContextKey = "tool_project_id" + ctxProjectOverrides toolContextKey = "tool_project_overrides" +) + +// WithToolProjectID injects the resolved project UUID into context. +// Used by delegation tools to propagate project scope through the delegation chain. +func WithToolProjectID(ctx context.Context, projectID string) context.Context { + return context.WithValue(ctx, ctxProjectID, projectID) +} + +func ToolProjectIDFromCtx(ctx context.Context) string { + v, _ := ctx.Value(ctxProjectID).(string) + return v +} + +// WithToolProjectOverrides injects project MCP env overrides into context. +func WithToolProjectOverrides(ctx context.Context, overrides map[string]map[string]string) context.Context { + return context.WithValue(ctx, ctxProjectOverrides, overrides) +} + +func ToolProjectOverridesFromCtx(ctx context.Context) map[string]map[string]string { + v, _ := ctx.Value(ctxProjectOverrides).(map[string]map[string]string) + return v +} + // --- Per-agent sandbox config override --- const ctxSandboxCfg toolContextKey = "tool_sandbox_config" diff --git a/internal/tools/delegate.go b/internal/tools/delegate.go index 273d585f..e99851bf 100644 --- a/internal/tools/delegate.go +++ b/internal/tools/delegate.go @@ -47,6 +47,10 @@ type DelegationTask struct { TeamID uuid.UUID `json:"-"` // from link.TeamID (for delegation history) TeamTaskID uuid.UUID `json:"-"` + // Project scope propagation (message arrival → delegation chain) + ProjectID string `json:"-"` + ProjectOverrides map[string]map[string]string `json:"-"` + // Activity tracking (updated via UpdateActivity on agent.activity events) LastActivity string `json:"-"` // "thinking", "tool_exec", "compacting" LastTool string `json:"-"` // current tool name (when LastActivity == "tool_exec") @@ -117,6 +121,9 @@ type DelegateRunRequest struct { // Workspace scope propagation (set by delegation, read by workspace tools) WorkspaceChannel string WorkspaceChatID string + + ProjectID string `json:"project_id,omitempty"` + ProjectOverrides map[string]map[string]string `json:"project_overrides,omitempty"` } // DelegateRunResult is the result from AgentRunFunc. diff --git a/internal/tools/delegate_prep.go b/internal/tools/delegate_prep.go index 1540e231..bd6e93e0 100644 --- a/internal/tools/delegate_prep.go +++ b/internal/tools/delegate_prep.go @@ -187,6 +187,8 @@ func (dm *DelegateManager) prepareDelegation(ctx context.Context, opts DelegateO OriginTraceID: tracing.TraceIDFromContext(ctx), OriginRootSpanID: tracing.ParentSpanIDFromContext(ctx), TeamTaskID: opts.TeamTaskID, + ProjectID: ToolProjectIDFromCtx(ctx), + ProjectOverrides: ToolProjectOverridesFromCtx(ctx), } // Carry team_id from the link (for delegation history filtering by team) @@ -399,6 +401,10 @@ func (dm *DelegateManager) buildRunRequest(task *DelegationTask, message string) ParentAgentID: task.SourceAgentKey, } + // Propagate project scope so delegate agents use the same MCP env overrides. + req.ProjectID = task.ProjectID + req.ProjectOverrides = task.ProjectOverrides + // Propagate workspace scope to delegate so workspace tools write to the // origin user's workspace, not the "delegate" channel. Scope = userID. req.WorkspaceChannel = "" diff --git a/internal/tools/team_tasks_tool.go b/internal/tools/team_tasks_tool.go index 4807a858..fe7c578c 100644 --- a/internal/tools/team_tasks_tool.go +++ b/internal/tools/team_tasks_tool.go @@ -341,7 +341,16 @@ func (t *TeamTasksTool) executeCreate(ctx context.Context, args map[string]any) status = store.TeamTaskStatusBlocked } + channel := ToolChannelFromCtx(ctx) chatID := ToolChatIDFromCtx(ctx) + meta, _ := args["metadata"].(map[string]any) + if meta == nil { + meta = make(map[string]any) + } + if senderID := store.SenderIDFromContext(ctx); senderID != "" { + meta["sender_id"] = senderID + } + meta["channel"] = channel task := &store.TeamTaskData{ TeamID: team.ID, @@ -350,8 +359,9 @@ func (t *TeamTasksTool) executeCreate(ctx context.Context, args map[string]any) Status: status, BlockedBy: blockedBy, Priority: priority, + Metadata: meta, UserID: store.UserIDFromContext(ctx), - Channel: ToolChannelFromCtx(ctx), + Channel: channel, TaskType: "general", CreatedByAgentID: &agentID, ChatID: chatID, diff --git a/internal/tools/team_tasks_tool_test.go b/internal/tools/team_tasks_tool_test.go new file mode 100644 index 00000000..ff0f2de2 --- /dev/null +++ b/internal/tools/team_tasks_tool_test.go @@ -0,0 +1,300 @@ +package tools + +import ( + "context" + "encoding/json" + "fmt" + "strings" + "testing" + "time" + + "github.com/google/uuid" + + "github.com/nextlevelbuilder/goclaw/internal/store" +) + +// ── Mock stores ──────────────────────────────────────────────────────────────── + +type mockTeamStore struct { + store.TeamStore // embed to satisfy full interface + team *store.TeamData + createdTask *store.TeamTaskData +} + +func (m *mockTeamStore) GetTeamForAgent(_ context.Context, _ uuid.UUID) (*store.TeamData, error) { + if m.team == nil { + return nil, nil + } + return m.team, nil +} + +func (m *mockTeamStore) CreateTask(_ context.Context, task *store.TeamTaskData) error { + task.ID = uuid.New() + task.CreatedAt = time.Now() + m.createdTask = task + return nil +} + +type mockAgentStore struct { + store.AgentStore +} + +func (m *mockAgentStore) GetByKey(_ context.Context, _ string) (*store.AgentData, error) { + return nil, fmt.Errorf("not found") +} +func (m *mockAgentStore) GetByID(_ context.Context, _ uuid.UUID) (*store.AgentData, error) { + return nil, fmt.Errorf("not found") +} + +// ── Helper ───────────────────────────────────────────────────────────────────── + +func makeTestTeam(leadID uuid.UUID, settings json.RawMessage) *store.TeamData { + return &store.TeamData{ + BaseModel: store.BaseModel{ID: uuid.New()}, + Name: "test-team", + LeadAgentID: leadID, + Status: "active", + Settings: settings, + } +} + +func makeCtx(agentID uuid.UUID, userID, senderID, channel string) context.Context { + ctx := context.Background() + ctx = store.WithAgentID(ctx, agentID) + ctx = store.WithUserID(ctx, userID) + if senderID != "" { + ctx = store.WithSenderID(ctx, senderID) + } + ctx = WithToolChannel(ctx, channel) + return ctx +} + +// ── Tests: sender_id tracking ────────────────────────────────────────────────── + +func TestExecuteCreate_SenderIDTracking(t *testing.T) { + leadID := uuid.New() + team := makeTestTeam(leadID, nil) + + ts := &mockTeamStore{team: team} + mgr := NewTeamToolManager(ts, &mockAgentStore{}, nil) + tool := NewTeamTasksTool(mgr) + + ctx := makeCtx(leadID, "group:telegram:chat123", "user-456", "telegram") + args := map[string]any{ + "action": "create", + "subject": "Test task with sender_id", + } + + result := tool.executeCreate(ctx, args) + if result.IsError { + t.Fatalf("expected success, got error: %s", result.ForLLM) + } + + if ts.createdTask == nil { + t.Fatal("expected task to be created") + } + + meta := ts.createdTask.Metadata + if meta == nil { + t.Fatal("expected metadata to be non-nil") + } + + if sid, ok := meta["sender_id"].(string); !ok || sid != "user-456" { + t.Errorf("expected sender_id=user-456, got %v", meta["sender_id"]) + } + if ch, ok := meta["channel"].(string); !ok || ch != "telegram" { + t.Errorf("expected channel=telegram, got %v", meta["channel"]) + } +} + +func TestExecuteCreate_NoSenderID(t *testing.T) { + leadID := uuid.New() + team := makeTestTeam(leadID, nil) + + ts := &mockTeamStore{team: team} + mgr := NewTeamToolManager(ts, &mockAgentStore{}, nil) + tool := NewTeamTasksTool(mgr) + + // No sender ID in context (delegate channel, internal agent-to-agent) + ctx := makeCtx(leadID, "delegate:system", "", "delegate") + args := map[string]any{ + "action": "create", + "subject": "Internal task", + } + + result := tool.executeCreate(ctx, args) + if result.IsError { + t.Fatalf("expected success, got error: %s", result.ForLLM) + } + + meta := ts.createdTask.Metadata + if meta == nil { + t.Fatal("expected metadata to be non-nil") + } + + // sender_id should NOT be present (empty sender) + if _, ok := meta["sender_id"]; ok { + t.Error("expected no sender_id for delegate channel") + } + // channel should still be present + if ch, ok := meta["channel"].(string); !ok || ch != "delegate" { + t.Errorf("expected channel=delegate, got %v", meta["channel"]) + } +} + +// ── Tests: requireLead ───────────────────────────────────────────────────────── + +func TestExecuteCreate_RequireLead_Rejected(t *testing.T) { + leadID := uuid.New() + nonLeadID := uuid.New() + team := makeTestTeam(leadID, nil) + + ts := &mockTeamStore{team: team} + mgr := NewTeamToolManager(ts, &mockAgentStore{}, nil) + tool := NewTeamTasksTool(mgr) + + // Non-lead agent trying to create task via telegram + ctx := makeCtx(nonLeadID, "group:telegram:chat123", "user-789", "telegram") + args := map[string]any{ + "action": "create", + "subject": "Unauthorized task", + } + + result := tool.executeCreate(ctx, args) + if !result.IsError { + t.Fatal("expected error for non-lead agent") + } + if !strings.Contains(result.ForLLM, "only the team lead") { + t.Errorf("expected 'only the team lead' error, got: %s", result.ForLLM) + } +} + +func TestExecuteCreate_RequireLead_DelegateBypass(t *testing.T) { + leadID := uuid.New() + nonLeadID := uuid.New() + team := makeTestTeam(leadID, nil) + + ts := &mockTeamStore{team: team} + mgr := NewTeamToolManager(ts, &mockAgentStore{}, nil) + tool := NewTeamTasksTool(mgr) + + // Non-lead agent via delegate channel (internal agent-to-agent) should bypass + ctx := makeCtx(nonLeadID, "delegate:system", "", "delegate") + args := map[string]any{ + "action": "create", + "subject": "Delegated task", + } + + result := tool.executeCreate(ctx, args) + if result.IsError { + t.Fatalf("delegate channel should bypass requireLead, got: %s", result.ForLLM) + } +} + +// ── Tests: checkTeamAccess ───────────────────────────────────────────────────── + +func TestCheckTeamAccess_AllowChannels(t *testing.T) { + settings := json.RawMessage(`{"allow_channels":["telegram","delegate","system"]}`) + + // Allowed channel + if err := checkTeamAccess(settings, "user1", "telegram"); err != nil { + t.Errorf("telegram should be allowed: %v", err) + } + + // Blocked channel + if err := checkTeamAccess(settings, "user1", "slack"); err == nil { + t.Error("slack should be denied") + } + + // delegate always passes + if err := checkTeamAccess(settings, "user1", "delegate"); err != nil { + t.Errorf("delegate should always pass: %v", err) + } + + // system always passes + if err := checkTeamAccess(settings, "user1", "system"); err != nil { + t.Errorf("system should always pass: %v", err) + } +} + +func TestCheckTeamAccess_DenyOverAllow(t *testing.T) { + settings := json.RawMessage(`{ + "allow_user_ids": ["user-A", "user-B"], + "deny_user_ids": ["user-B"] + }`) + + // user-A allowed + if err := checkTeamAccess(settings, "user-A", "telegram"); err != nil { + t.Errorf("user-A should be allowed: %v", err) + } + + // user-B denied (deny > allow) + if err := checkTeamAccess(settings, "user-B", "telegram"); err == nil { + t.Error("user-B should be denied (deny overrides allow)") + } + + // user-C not in allow list + if err := checkTeamAccess(settings, "user-C", "telegram"); err == nil { + t.Error("user-C should be denied (not in allow list)") + } +} + +func TestCheckTeamAccess_EmptySettings(t *testing.T) { + // Empty settings = open access + if err := checkTeamAccess(nil, "anyone", "any-channel"); err != nil { + t.Errorf("empty settings should allow all: %v", err) + } + if err := checkTeamAccess(json.RawMessage(`{}`), "anyone", "any-channel"); err != nil { + t.Errorf("empty JSON settings should allow all: %v", err) + } +} + +func TestCheckTeamAccess_DenyChannels(t *testing.T) { + settings := json.RawMessage(`{"deny_channels":["whatsapp"]}`) + + if err := checkTeamAccess(settings, "user1", "telegram"); err != nil { + t.Errorf("telegram should be allowed: %v", err) + } + if err := checkTeamAccess(settings, "user1", "whatsapp"); err == nil { + t.Error("whatsapp should be denied") + } +} + +// ── Tests: requireLead unit ──────────────────────────────────────────────────── + +func TestRequireLead_LeadAllowed(t *testing.T) { + leadID := uuid.New() + team := makeTestTeam(leadID, nil) + mgr := NewTeamToolManager(&mockTeamStore{}, &mockAgentStore{}, nil) + + ctx := makeCtx(leadID, "user1", "", "telegram") + if err := mgr.requireLead(ctx, team, leadID); err != nil { + t.Errorf("lead should be allowed: %v", err) + } +} + +func TestRequireLead_NonLeadRejected(t *testing.T) { + leadID := uuid.New() + otherID := uuid.New() + team := makeTestTeam(leadID, nil) + mgr := NewTeamToolManager(&mockTeamStore{}, &mockAgentStore{}, nil) + + ctx := makeCtx(otherID, "user1", "", "telegram") + if err := mgr.requireLead(ctx, team, otherID); err == nil { + t.Error("non-lead should be rejected") + } +} + +func TestRequireLead_SystemBypass(t *testing.T) { + leadID := uuid.New() + otherID := uuid.New() + team := makeTestTeam(leadID, nil) + mgr := NewTeamToolManager(&mockTeamStore{}, &mockAgentStore{}, nil) + + for _, ch := range []string{"delegate", "system"} { + ctx := makeCtx(otherID, "user1", "", ch) + if err := mgr.requireLead(ctx, team, otherID); err != nil { + t.Errorf("channel %q should bypass requireLead: %v", ch, err) + } + } +} diff --git a/internal/upgrade/version.go b/internal/upgrade/version.go index b4e625e5..0e6b69fc 100644 --- a/internal/upgrade/version.go +++ b/internal/upgrade/version.go @@ -2,4 +2,4 @@ package upgrade // RequiredSchemaVersion is the schema migration version this binary requires. // Bump this whenever adding a new SQL migration file. -const RequiredSchemaVersion uint = 18 +const RequiredSchemaVersion uint = 20 diff --git a/migrations/000018_party_sessions.down.sql b/migrations/000018_party_sessions.down.sql new file mode 100644 index 00000000..4adc042c --- /dev/null +++ b/migrations/000018_party_sessions.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS party_sessions; diff --git a/migrations/000018_party_sessions.up.sql b/migrations/000018_party_sessions.up.sql new file mode 100644 index 00000000..17093083 --- /dev/null +++ b/migrations/000018_party_sessions.up.sql @@ -0,0 +1,23 @@ +-- 000014_party_sessions.up.sql +CREATE TABLE IF NOT EXISTS party_sessions ( + id UUID PRIMARY KEY DEFAULT uuid_generate_v7(), + topic TEXT NOT NULL, + team_preset VARCHAR(50), + status VARCHAR(20) NOT NULL DEFAULT 'assembling', + mode VARCHAR(10) NOT NULL DEFAULT 'standard', + round INT NOT NULL DEFAULT 0, + max_rounds INT NOT NULL DEFAULT 10, + user_id VARCHAR(200) NOT NULL, + channel VARCHAR(255), + chat_id VARCHAR(200), + personas JSONB NOT NULL DEFAULT '[]', + context JSONB NOT NULL DEFAULT '{}', + history JSONB NOT NULL DEFAULT '[]', + summary JSONB, + artifacts JSONB NOT NULL DEFAULT '[]', + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now() +); + +CREATE INDEX IF NOT EXISTS idx_party_sessions_user ON party_sessions(user_id, status); +CREATE INDEX IF NOT EXISTS idx_party_sessions_channel ON party_sessions(channel, chat_id, status); diff --git a/migrations/000018_team_tasks_workspace_followup.down.sql b/migrations/000019_team_tasks_workspace_followup.down.sql similarity index 100% rename from migrations/000018_team_tasks_workspace_followup.down.sql rename to migrations/000019_team_tasks_workspace_followup.down.sql diff --git a/migrations/000018_team_tasks_workspace_followup.up.sql b/migrations/000019_team_tasks_workspace_followup.up.sql similarity index 100% rename from migrations/000018_team_tasks_workspace_followup.up.sql rename to migrations/000019_team_tasks_workspace_followup.up.sql diff --git a/migrations/000020_projects.down.sql b/migrations/000020_projects.down.sql new file mode 100644 index 00000000..6b08b89c --- /dev/null +++ b/migrations/000020_projects.down.sql @@ -0,0 +1,6 @@ +-- 000020_projects.down.sql +DROP TRIGGER IF EXISTS set_project_mcp_overrides_updated_at ON project_mcp_overrides; +DROP TRIGGER IF EXISTS set_projects_updated_at ON projects; +DROP FUNCTION IF EXISTS update_projects_updated_at(); +DROP TABLE IF EXISTS project_mcp_overrides; +DROP TABLE IF EXISTS projects; diff --git a/migrations/000020_projects.up.sql b/migrations/000020_projects.up.sql new file mode 100644 index 00000000..eecd8498 --- /dev/null +++ b/migrations/000020_projects.up.sql @@ -0,0 +1,45 @@ +-- 000020_projects.up.sql +-- Project entity for per-group MCP env overrides (Project-as-a-Channel) + +CREATE TABLE projects ( + id UUID PRIMARY KEY DEFAULT uuid_generate_v7(), + name VARCHAR(255) NOT NULL, + slug VARCHAR(100) NOT NULL UNIQUE, + channel_type VARCHAR(50), + chat_id VARCHAR(255), + team_id UUID REFERENCES agent_teams(id) ON DELETE SET NULL, + description TEXT, + status VARCHAR(20) NOT NULL DEFAULT 'active', + created_by VARCHAR(255) NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + UNIQUE(channel_type, chat_id) +); + +CREATE TABLE project_mcp_overrides ( + id UUID PRIMARY KEY DEFAULT uuid_generate_v7(), + project_id UUID NOT NULL REFERENCES projects(id) ON DELETE CASCADE, + server_name VARCHAR(255) NOT NULL, + env_overrides JSONB NOT NULL DEFAULT '{}', + enabled BOOLEAN NOT NULL DEFAULT true, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + UNIQUE(project_id, server_name) +); + +-- Auto-update updated_at on row modification +CREATE OR REPLACE FUNCTION update_projects_updated_at() +RETURNS TRIGGER AS $$ +BEGIN + NEW.updated_at = NOW(); + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +CREATE TRIGGER set_projects_updated_at + BEFORE UPDATE ON projects + FOR EACH ROW EXECUTE FUNCTION update_projects_updated_at(); + +CREATE TRIGGER set_project_mcp_overrides_updated_at + BEFORE UPDATE ON project_mcp_overrides + FOR EACH ROW EXECUTE FUNCTION update_projects_updated_at(); diff --git a/pkg/protocol/party.go b/pkg/protocol/party.go new file mode 100644 index 00000000..b724e35f --- /dev/null +++ b/pkg/protocol/party.go @@ -0,0 +1,26 @@ +package protocol + +// Party Mode methods. +const ( + MethodPartyStart = "party.start" + MethodPartyRound = "party.round" + MethodPartyQuestion = "party.question" + MethodPartyAddContext = "party.add_context" + MethodPartySummary = "party.summary" + MethodPartyExit = "party.exit" + MethodPartyList = "party.list" +) + +// Party Mode events. +const ( + EventPartyStarted = "party.started" + EventPartyPersonaIntro = "party.persona.intro" + EventPartyRoundStarted = "party.round.started" + EventPartyPersonaThinking = "party.persona.thinking" + EventPartyPersonaSpoke = "party.persona.spoke" + EventPartyRoundComplete = "party.round.complete" + EventPartyContextAdded = "party.context.added" + EventPartySummaryReady = "party.summary.ready" + EventPartyArtifact = "party.artifact" + EventPartyClosed = "party.closed" +) diff --git a/ui/web/src/api/protocol.ts b/ui/web/src/api/protocol.ts index 679963dd..d9674158 100644 --- a/ui/web/src/api/protocol.ts +++ b/ui/web/src/api/protocol.ts @@ -153,6 +153,15 @@ export const Methods = { DELEGATIONS_LIST: "delegations.list", DELEGATIONS_GET: "delegations.get", + // Party mode + PARTY_START: "party.start", + PARTY_ROUND: "party.round", + PARTY_QUESTION: "party.question", + PARTY_ADD_CONTEXT: "party.add_context", + PARTY_SUMMARY: "party.summary", + PARTY_EXIT: "party.exit", + PARTY_LIST: "party.list", + // Phase 3+ - NICE TO HAVE LOGS_TAIL: "logs.tail", } as const; @@ -221,6 +230,18 @@ export const Events = { // Trace lifecycle TRACE_UPDATED: "trace.updated", + // Party mode + PARTY_STARTED: "party.started", + PARTY_PERSONA_INTRO: "party.persona.intro", + PARTY_ROUND_STARTED: "party.round.started", + PARTY_PERSONA_THINKING: "party.persona.thinking", + PARTY_PERSONA_SPOKE: "party.persona.spoke", + PARTY_ROUND_COMPLETE: "party.round.complete", + PARTY_CONTEXT_ADDED: "party.context.added", + PARTY_SUMMARY_READY: "party.summary.ready", + PARTY_ARTIFACT: "party.artifact", + PARTY_CLOSED: "party.closed", + // Skill dependency check (realtime progress during startup/rescan) SKILL_DEPS_CHECKED: "skill.deps.checked", SKILL_DEPS_COMPLETE: "skill.deps.complete", diff --git a/ui/web/src/components/layout/connection-status.tsx b/ui/web/src/components/layout/connection-status.tsx index d5447014..bad04ced 100644 --- a/ui/web/src/components/layout/connection-status.tsx +++ b/ui/web/src/components/layout/connection-status.tsx @@ -7,7 +7,7 @@ export function ConnectionStatus() { const connected = useAuthStore((s) => s.connected); return ( -
+
+ @@ -84,6 +87,7 @@ export function Sidebar({ collapsed, onNavItemClick }: SidebarProps) { + diff --git a/ui/web/src/i18n/index.ts b/ui/web/src/i18n/index.ts index 09c342d0..418f9dc4 100644 --- a/ui/web/src/i18n/index.ts +++ b/ui/web/src/i18n/index.ts @@ -30,8 +30,10 @@ import enSetup from "./locales/en/setup.json"; import enMemory from "./locales/en/memory.json"; import enStorage from "./locales/en/storage.json"; import enPendingMessages from "./locales/en/pending-messages.json"; +import enParty from "./locales/en/party.json"; import enContacts from "./locales/en/contacts.json"; import enActivity from "./locales/en/activity.json"; +import enProjects from "./locales/en/projects.json"; // --- VI namespaces --- import viCommon from "./locales/vi/common.json"; @@ -62,8 +64,10 @@ import viSetup from "./locales/vi/setup.json"; import viMemory from "./locales/vi/memory.json"; import viStorage from "./locales/vi/storage.json"; import viPendingMessages from "./locales/vi/pending-messages.json"; +import viParty from "./locales/vi/party.json"; import viContacts from "./locales/vi/contacts.json"; import viActivity from "./locales/vi/activity.json"; +import viProjects from "./locales/vi/projects.json"; // --- ZH namespaces --- import zhCommon from "./locales/zh/common.json"; @@ -94,8 +98,10 @@ import zhSetup from "./locales/zh/setup.json"; import zhMemory from "./locales/zh/memory.json"; import zhStorage from "./locales/zh/storage.json"; import zhPendingMessages from "./locales/zh/pending-messages.json"; +import zhParty from "./locales/zh/party.json"; import zhContacts from "./locales/zh/contacts.json"; import zhActivity from "./locales/zh/activity.json"; +import zhProjects from "./locales/zh/projects.json"; const STORAGE_KEY = "goclaw:language"; @@ -113,7 +119,7 @@ const ns = [ "agents", "teams", "sessions", "skills", "cron", "config", "channels", "providers", "traces", "events", "delegations", "usage", "approvals", "nodes", "logs", "tools", "mcp", "tts", - "setup", "memory", "storage", "pending-messages", "contacts", "activity", + "setup", "memory", "storage", "pending-messages", "party", "contacts", "activity", "projects", ] as const; i18n.use(initReactI18next).init({ @@ -127,7 +133,7 @@ i18n.use(initReactI18next).init({ approvals: enApprovals, nodes: enNodes, logs: enLogs, tools: enTools, mcp: enMcp, tts: enTts, setup: enSetup, memory: enMemory, storage: enStorage, "pending-messages": enPendingMessages, - contacts: enContacts, activity: enActivity, + party: enParty, contacts: enContacts, activity: enActivity, projects: enProjects, }, vi: { common: viCommon, sidebar: viSidebar, topbar: viTopbar, login: viLogin, @@ -138,7 +144,7 @@ i18n.use(initReactI18next).init({ approvals: viApprovals, nodes: viNodes, logs: viLogs, tools: viTools, mcp: viMcp, tts: viTts, setup: viSetup, memory: viMemory, storage: viStorage, "pending-messages": viPendingMessages, - contacts: viContacts, activity: viActivity, + party: viParty, contacts: viContacts, activity: viActivity, projects: viProjects, }, zh: { common: zhCommon, sidebar: zhSidebar, topbar: zhTopbar, login: zhLogin, @@ -149,7 +155,7 @@ i18n.use(initReactI18next).init({ approvals: zhApprovals, nodes: zhNodes, logs: zhLogs, tools: zhTools, mcp: zhMcp, tts: zhTts, setup: zhSetup, memory: zhMemory, storage: zhStorage, "pending-messages": zhPendingMessages, - contacts: zhContacts, activity: zhActivity, + party: zhParty, contacts: zhContacts, activity: zhActivity, projects: zhProjects, }, }, ns: [...ns], diff --git a/ui/web/src/i18n/locales/en/party.json b/ui/web/src/i18n/locales/en/party.json new file mode 100644 index 00000000..b4c6e0ab --- /dev/null +++ b/ui/web/src/i18n/locales/en/party.json @@ -0,0 +1,46 @@ +{ + "title": "Party Mode", + "newParty": "New Party", + "topic": "Discussion Topic", + "selectTeam": "Select Team", + "customTeam": "Custom Team", + "start": "Start Discussion", + "presets": { + "payment_feature": "Payment Feature", + "security_review": "Security Review", + "sprint_planning": "Sprint Planning", + "architecture_decision": "Architecture Decision", + "ux_review": "UX Review", + "incident_response": "Incident Response" + }, + "controls": { + "continue": "Continue", + "deepMode": "Deep Mode [P]", + "tokenRing": "Token Ring [R]", + "question": "Question [Q]", + "summary": "Summary [D]", + "exit": "Exit [E]" + }, + "status": { + "thinking": "Thinking...", + "speaking": "Speaking", + "idle": "Idle" + }, + "round": "Round {{n}}", + "mode": { + "standard": "Standard", + "deep": "Deep", + "token_ring": "Token Ring" + }, + "noSessions": "No party sessions yet", + "description": "Multi-persona AI discussions with structured rounds", + "exitConfirm": "Exit this party session?", + "summary": { + "title": "Discussion Summary", + "agreements": "Points of Agreement", + "disagreements": "Points of Disagreement", + "decisions": "Decisions Made", + "actionItems": "Action Items", + "compliance": "Compliance Notes" + } +} diff --git a/ui/web/src/i18n/locales/en/projects.json b/ui/web/src/i18n/locales/en/projects.json new file mode 100644 index 00000000..53bf6335 --- /dev/null +++ b/ui/web/src/i18n/locales/en/projects.json @@ -0,0 +1,79 @@ +{ + "title": "Projects", + "description": "Manage project workspaces and their MCP server overrides", + "addProject": "New Project", + "searchPlaceholder": "Search projects...", + "emptyTitle": "No projects", + "emptyDescription": "Create your first project to get started.", + "noMatchTitle": "No matching projects", + "noMatchDescription": "Try a different search term.", + "columns": { + "name": "Name", + "slug": "Slug", + "channel": "Channel", + "status": "Status", + "overrides": "MCP Overrides", + "createdBy": "Created By", + "actions": "Actions" + }, + "manageOverrides": "Manage MCP overrides", + "delete": { + "title": "Delete Project", + "description": "Are you sure you want to delete \"{{name}}\"? This will also remove all MCP overrides.", + "confirmLabel": "Delete" + }, + "form": { + "createTitle": "New Project", + "editTitle": "Edit Project", + "name": "Name *", + "slug": "Slug *", + "slugHint": "Lowercase letters, numbers, and hyphens only", + "channelType": "Channel Type", + "channelTypePlaceholder": "Select channel type...", + "chatId": "Chat ID", + "chatIdPlaceholder": "e.g. -1001234567890", + "description": "Description", + "descriptionPlaceholder": "Brief description of this project", + "status": "Status", + "cancel": "Cancel", + "create": "Create", + "update": "Update", + "saving": "Saving...", + "errors": { + "nameRequired": "Name and slug are required", + "slugInvalid": "Slug must be lowercase letters, numbers, and hyphens only" + } + }, + "overrides": { + "title": "MCP Overrides — {{name}}", + "description": "Per-project environment variable overrides for MCP servers. Secret keys (TOKEN, SECRET, PASSWORD, API_KEY) must be set in the global MCP server config.", + "serverName": "MCP Server *", + "serverNamePlaceholder": "e.g. gitlab, atlassian", + "envOverrides": "Environment Variables", + "envKeyPlaceholder": "Variable name", + "envValuePlaceholder": "Value", + "addVariable": "Add Variable", + "addOverride": "Add Override", + "save": "Save", + "saving": "Saving...", + "noOverrides": "No MCP overrides configured", + "noOverridesDescription": "Add per-project environment variables for MCP servers.", + "failedLoad": "Failed to load overrides", + "failedSave": "Failed to save override", + "failedDelete": "Failed to delete override", + "saved": "Override saved", + "deleted": "Override deleted" + }, + "status": { + "active": "Active", + "archived": "Archived" + }, + "toast": { + "created": "Project created", + "updated": "Project updated", + "deleted": "Project deleted", + "failedCreate": "Failed to create project", + "failedUpdate": "Failed to update project", + "failedDelete": "Failed to delete project" + } +} diff --git a/ui/web/src/i18n/locales/en/sidebar.json b/ui/web/src/i18n/locales/en/sidebar.json index 4d4da4f4..ea98c4ea 100644 --- a/ui/web/src/i18n/locales/en/sidebar.json +++ b/ui/web/src/i18n/locales/en/sidebar.json @@ -35,6 +35,8 @@ "approvals": "Approvals", "nodes": "Nodes", "tts": "TTS", - "activity": "Activity" + "party": "Party Mode", + "activity": "Activity", + "projects": "Projects" } } diff --git a/ui/web/src/i18n/locales/vi/party.json b/ui/web/src/i18n/locales/vi/party.json new file mode 100644 index 00000000..29de1f63 --- /dev/null +++ b/ui/web/src/i18n/locales/vi/party.json @@ -0,0 +1,46 @@ +{ + "title": "Ch\u1ebf \u0111\u1ed9 Party", + "newParty": "T\u1ea1o Party m\u1edbi", + "topic": "Ch\u1ee7 \u0111\u1ec1 th\u1ea3o lu\u1eadn", + "selectTeam": "Ch\u1ecdn \u0111\u1ed9i", + "customTeam": "\u0110\u1ed9i tu\u1ef3 ch\u1ec9nh", + "start": "B\u1eaft \u0111\u1ea7u th\u1ea3o lu\u1eadn", + "presets": { + "payment_feature": "T\u00ednh n\u0103ng Thanh to\u00e1n", + "security_review": "\u0110\u00e1nh gi\u00e1 B\u1ea3o m\u1eadt", + "sprint_planning": "L\u1eadp k\u1ebf ho\u1ea1ch Sprint", + "architecture_decision": "Quy\u1ebft \u0111\u1ecbnh Ki\u1ebfn tr\u00fac", + "ux_review": "\u0110\u00e1nh gi\u00e1 UX", + "incident_response": "X\u1eed l\u00fd s\u1ef1 c\u1ed1" + }, + "controls": { + "continue": "Ti\u1ebfp t\u1ee5c", + "deepMode": "Ch\u1ebf \u0111\u1ed9 Deep [P]", + "tokenRing": "Token Ring [R]", + "question": "C\u00e2u h\u1ecfi [Q]", + "summary": "T\u00f3m t\u1eaft [D]", + "exit": "Tho\u00e1t [E]" + }, + "status": { + "thinking": "\u0110ang suy ngh\u0129...", + "speaking": "\u0110ang n\u00f3i", + "idle": "Ch\u1edd" + }, + "round": "V\u00f2ng {{n}}", + "mode": { + "standard": "Ti\u00eau chu\u1ea9n", + "deep": "Deep", + "token_ring": "Token Ring" + }, + "noSessions": "Ch\u01b0a c\u00f3 phi\u00ean party n\u00e0o", + "description": "Th\u1ea3o lu\u1eadn AI \u0111a nh\u00e2n v\u1eadt v\u1edbi c\u00e1c v\u00f2ng c\u00f3 c\u1ea5u tr\u00fac", + "exitConfirm": "Tho\u00e1t phi\u00ean party n\u00e0y?", + "summary": { + "title": "T\u00f3m t\u1eaft th\u1ea3o lu\u1eadn", + "agreements": "\u0110i\u1ec3m \u0111\u1ed3ng thu\u1eadn", + "disagreements": "\u0110i\u1ec3m b\u1ea5t \u0111\u1ed3ng", + "decisions": "Quy\u1ebft \u0111\u1ecbnh", + "actionItems": "H\u1ea1ng m\u1ee5c h\u00e0nh \u0111\u1ed9ng", + "compliance": "Ghi ch\u00fa tu\u00e2n th\u1ee7" + } +} diff --git a/ui/web/src/i18n/locales/vi/projects.json b/ui/web/src/i18n/locales/vi/projects.json new file mode 100644 index 00000000..ca1b212e --- /dev/null +++ b/ui/web/src/i18n/locales/vi/projects.json @@ -0,0 +1,79 @@ +{ + "title": "Dự án", + "description": "Quản lý workspace dự án và cấu hình MCP server riêng", + "addProject": "Dự án mới", + "searchPlaceholder": "Tìm dự án...", + "emptyTitle": "Chưa có dự án", + "emptyDescription": "Tạo dự án đầu tiên để bắt đầu.", + "noMatchTitle": "Không tìm thấy dự án", + "noMatchDescription": "Thử từ khóa khác.", + "columns": { + "name": "Tên", + "slug": "Slug", + "channel": "Kênh", + "status": "Trạng thái", + "overrides": "MCP Override", + "createdBy": "Người tạo", + "actions": "Thao tác" + }, + "manageOverrides": "Quản lý MCP override", + "delete": { + "title": "Xóa dự án", + "description": "Bạn có chắc muốn xóa \"{{name}}\"? Tất cả MCP override cũng sẽ bị xóa.", + "confirmLabel": "Xóa" + }, + "form": { + "createTitle": "Dự án mới", + "editTitle": "Sửa dự án", + "name": "Tên *", + "slug": "Slug *", + "slugHint": "Chỉ chữ thường, số và dấu gạch ngang", + "channelType": "Loại kênh", + "channelTypePlaceholder": "Chọn loại kênh...", + "chatId": "Chat ID", + "chatIdPlaceholder": "VD: -1001234567890", + "description": "Mô tả", + "descriptionPlaceholder": "Mô tả ngắn về dự án", + "status": "Trạng thái", + "cancel": "Hủy", + "create": "Tạo", + "update": "Cập nhật", + "saving": "Đang lưu...", + "errors": { + "nameRequired": "Tên và slug là bắt buộc", + "slugInvalid": "Slug chỉ được chứa chữ thường, số và dấu gạch ngang" + } + }, + "overrides": { + "title": "MCP Override — {{name}}", + "description": "Biến môi trường riêng cho từng dự án. Các key bí mật (TOKEN, SECRET, PASSWORD, API_KEY) phải đặt trong cấu hình MCP server chung.", + "serverName": "MCP Server *", + "serverNamePlaceholder": "VD: gitlab, atlassian", + "envOverrides": "Biến môi trường", + "envKeyPlaceholder": "Tên biến", + "envValuePlaceholder": "Giá trị", + "addVariable": "Thêm biến", + "addOverride": "Thêm Override", + "save": "Lưu", + "saving": "Đang lưu...", + "noOverrides": "Chưa có MCP override", + "noOverridesDescription": "Thêm biến môi trường riêng cho MCP server.", + "failedLoad": "Không thể tải override", + "failedSave": "Không thể lưu override", + "failedDelete": "Không thể xóa override", + "saved": "Đã lưu override", + "deleted": "Đã xóa override" + }, + "status": { + "active": "Hoạt động", + "archived": "Lưu trữ" + }, + "toast": { + "created": "Đã tạo dự án", + "updated": "Đã cập nhật dự án", + "deleted": "Đã xóa dự án", + "failedCreate": "Không thể tạo dự án", + "failedUpdate": "Không thể cập nhật dự án", + "failedDelete": "Không thể xóa dự án" + } +} diff --git a/ui/web/src/i18n/locales/vi/sidebar.json b/ui/web/src/i18n/locales/vi/sidebar.json index 67640d1f..dc82888c 100644 --- a/ui/web/src/i18n/locales/vi/sidebar.json +++ b/ui/web/src/i18n/locales/vi/sidebar.json @@ -35,6 +35,8 @@ "approvals": "Phê duyệt", "nodes": "Nodes", "tts": "TTS", - "activity": "Hoạt động" + "party": "Chế độ Party", + "activity": "Hoạt động", + "projects": "Dự án" } } diff --git a/ui/web/src/i18n/locales/zh/party.json b/ui/web/src/i18n/locales/zh/party.json new file mode 100644 index 00000000..92ff4e51 --- /dev/null +++ b/ui/web/src/i18n/locales/zh/party.json @@ -0,0 +1,46 @@ +{ + "title": "Party 模式", + "newParty": "新建 Party", + "topic": "讨论主题", + "selectTeam": "选择团队", + "customTeam": "自定义团队", + "start": "开始讨论", + "presets": { + "payment_feature": "支付功能", + "security_review": "安全审查", + "sprint_planning": "Sprint 规划", + "architecture_decision": "架构决策", + "ux_review": "UX 审查", + "incident_response": "事件响应" + }, + "controls": { + "continue": "继续", + "deepMode": "深度模式 [P]", + "tokenRing": "令牌环 [R]", + "question": "提问 [Q]", + "summary": "总结 [D]", + "exit": "退出 [E]" + }, + "status": { + "thinking": "思考中...", + "speaking": "发言中", + "idle": "空闲" + }, + "round": "第 {{n}} 轮", + "mode": { + "standard": "标准", + "deep": "深度", + "token_ring": "令牌环" + }, + "noSessions": "暂无 Party 会话", + "description": "多角色 AI 讨论,结构化轮次", + "exitConfirm": "退出此 Party 会话?", + "summary": { + "title": "讨论总结", + "agreements": "共识要点", + "disagreements": "分歧要点", + "decisions": "已做决策", + "actionItems": "行动项", + "compliance": "合规备注" + } +} diff --git a/ui/web/src/i18n/locales/zh/projects.json b/ui/web/src/i18n/locales/zh/projects.json new file mode 100644 index 00000000..42e89384 --- /dev/null +++ b/ui/web/src/i18n/locales/zh/projects.json @@ -0,0 +1,79 @@ +{ + "title": "项目", + "description": "管理项目工作空间及其 MCP 服务器配置", + "addProject": "新建项目", + "searchPlaceholder": "搜索项目...", + "emptyTitle": "暂无项目", + "emptyDescription": "创建第一个项目以开始使用。", + "noMatchTitle": "未找到匹配项目", + "noMatchDescription": "请尝试其他搜索词。", + "columns": { + "name": "名称", + "slug": "Slug", + "channel": "渠道", + "status": "状态", + "overrides": "MCP 覆盖", + "createdBy": "创建者", + "actions": "操作" + }, + "manageOverrides": "管理 MCP 覆盖", + "delete": { + "title": "删除项目", + "description": "确定要删除 \"{{name}}\" 吗?所有 MCP 覆盖也将被删除。", + "confirmLabel": "删除" + }, + "form": { + "createTitle": "新建项目", + "editTitle": "编辑项目", + "name": "名称 *", + "slug": "Slug *", + "slugHint": "仅限小写字母、数字和连字符", + "channelType": "渠道类型", + "channelTypePlaceholder": "选择渠道类型...", + "chatId": "Chat ID", + "chatIdPlaceholder": "例如 -1001234567890", + "description": "描述", + "descriptionPlaceholder": "项目简要描述", + "status": "状态", + "cancel": "取消", + "create": "创建", + "update": "更新", + "saving": "保存中...", + "errors": { + "nameRequired": "名称和 Slug 为必填项", + "slugInvalid": "Slug 仅允许小写字母、数字和连字符" + } + }, + "overrides": { + "title": "MCP 覆盖 — {{name}}", + "description": "每个项目的 MCP 服务器环境变量覆盖。敏感键(TOKEN、SECRET、PASSWORD、API_KEY)须在全局 MCP 服务器配置中设置。", + "serverName": "MCP 服务器 *", + "serverNamePlaceholder": "例如 gitlab、atlassian", + "envOverrides": "环境变量", + "envKeyPlaceholder": "变量名", + "envValuePlaceholder": "值", + "addVariable": "添加变量", + "addOverride": "添加覆盖", + "save": "保存", + "saving": "保存中...", + "noOverrides": "暂无 MCP 覆盖", + "noOverridesDescription": "为 MCP 服务器添加项目专属环境变量。", + "failedLoad": "加载覆盖失败", + "failedSave": "保存覆盖失败", + "failedDelete": "删除覆盖失败", + "saved": "覆盖已保存", + "deleted": "覆盖已删除" + }, + "status": { + "active": "活跃", + "archived": "已归档" + }, + "toast": { + "created": "项目已创建", + "updated": "项目已更新", + "deleted": "项目已删除", + "failedCreate": "创建项目失败", + "failedUpdate": "更新项目失败", + "failedDelete": "删除项目失败" + } +} diff --git a/ui/web/src/i18n/locales/zh/sidebar.json b/ui/web/src/i18n/locales/zh/sidebar.json index f9224930..d754faaa 100644 --- a/ui/web/src/i18n/locales/zh/sidebar.json +++ b/ui/web/src/i18n/locales/zh/sidebar.json @@ -35,6 +35,8 @@ "traces": "追踪", "tts": "TTS", "usage": "用量", - "activity": "活动日志" + "party": "Party 模式", + "activity": "活动日志", + "projects": "项目" } } diff --git a/ui/web/src/lib/constants.ts b/ui/web/src/lib/constants.ts index 001b48ad..05ddc426 100644 --- a/ui/web/src/lib/constants.ts +++ b/ui/web/src/lib/constants.ts @@ -26,6 +26,7 @@ export const ROUTES = { PROVIDERS: "/providers", TEAMS: "/teams", TEAM_DETAIL: "/teams/:id", + PARTY: "/party", CUSTOM_TOOLS: "/custom-tools", BUILTIN_TOOLS: "/builtin-tools", MCP: "/mcp", @@ -35,6 +36,7 @@ export const ROUTES = { MEMORY: "/memory", KNOWLEDGE_GRAPH: "/knowledge-graph", ACTIVITY: "/activity", + PROJECTS: "/projects", SETUP: "/setup", } as const; diff --git a/ui/web/src/lib/query-keys.ts b/ui/web/src/lib/query-keys.ts index bc6aec12..ed78107f 100644 --- a/ui/web/src/lib/query-keys.ts +++ b/ui/web/src/lib/query-keys.ts @@ -64,6 +64,11 @@ export const queryKeys = { all: ["teams"] as const, detail: (id: string) => ["teams", id] as const, }, + projects: { + all: ["projects"] as const, + detail: (id: string) => ["projects", id] as const, + overrides: (id: string) => ["projects", id, "overrides"] as const, + }, memory: { all: ["memory"] as const, list: (params: Record) => ["memory", params] as const, diff --git a/ui/web/src/pages/party/hooks/use-party.ts b/ui/web/src/pages/party/hooks/use-party.ts new file mode 100644 index 00000000..200f72cc --- /dev/null +++ b/ui/web/src/pages/party/hooks/use-party.ts @@ -0,0 +1,453 @@ +import { useState, useCallback, useRef } from "react"; +import { useWs } from "@/hooks/use-ws"; +import { useWsEvent } from "@/hooks/use-ws-event"; +import { useAuthStore } from "@/stores/use-auth-store"; +import { Methods, Events } from "@/api/protocol"; + +// --- Types --- + +export type PartyMode = "standard" | "deep" | "token_ring"; + +export interface PersonaInfo { + key: string; + emoji: string; + name: string; + role: string; + color: string; +} + +export interface PartyMessage { + id: string; + type: "intro" | "spoke" | "thinking" | "round_header" | "context" | "summary" | "artifact"; + personaKey?: string; + personaEmoji?: string; + personaName?: string; + content: string; + round?: number; + mode?: PartyMode; + timestamp: number; +} + +export interface PartySession { + id: string; + topic: string; + status: "active" | "closed"; + personas: PersonaInfo[]; + round: number; + mode: PartyMode; + createdAt: string; + // Preserved from backend for session restore + // eslint-disable-next-line @typescript-eslint/no-explicit-any + _history?: any[]; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + _summary?: any; +} + +export interface PartySummary { + agreements?: string[]; + disagreements?: string[]; + decisions?: string[]; + actionItems?: string[]; + compliance?: string[]; + markdown?: string; +} + +// Persona color palette for left-border styling +const PERSONA_COLORS = [ + "#3b82f6", "#ef4444", "#10b981", "#f59e0b", "#8b5cf6", + "#ec4899", "#06b6d4", "#f97316", "#6366f1", "#14b8a6", + "#e11d48", "#84cc16", "#a855f7", "#0ea5e9", +]; + +// Map backend status to frontend display status +function mapStatus(backendStatus: string): "active" | "closed" { + return backendStatus === "closed" ? "closed" : "active"; +} + +// Transform backend session (snake_case) to frontend PartySession +// eslint-disable-next-line @typescript-eslint/no-explicit-any +function transformSession(raw: any): PartySession { + const personaKeys: string[] = Array.isArray(raw.personas) + ? raw.personas + : []; + return { + id: raw.id, + topic: raw.topic ?? "", + status: mapStatus(raw.status ?? ""), + personas: personaKeys.map((k: string) => ({ + key: k, + emoji: "", + name: k, + role: "", + color: "", + })), + round: raw.round ?? 0, + mode: raw.mode ?? "standard", + createdAt: raw.created_at ?? raw.createdAt ?? "", + _history: Array.isArray(raw.history) ? raw.history : undefined, + _summary: raw.summary ?? undefined, + }; +} + +// Hydrate PartyMessage[] from backend history (RoundResult[]) and summary +function hydrateMessages( + history: any[] | undefined, // eslint-disable-line @typescript-eslint/no-explicit-any + summary: any | undefined, // eslint-disable-line @typescript-eslint/no-explicit-any + startId: number, +): { msgs: PartyMessage[]; nextId: number } { + let id = startId; + const msgs: PartyMessage[] = []; + if (!history) return { msgs, nextId: id }; + + for (const round of history) { + // Round header + msgs.push({ + id: `pm-${++id}`, + type: "round_header", + content: "", + round: round.round, + mode: round.mode as PartyMode, + timestamp: Date.now(), + }); + // Persona messages + for (const m of round.messages ?? []) { + msgs.push({ + id: `pm-${++id}`, + type: "spoke", + personaKey: m.persona_key, + personaEmoji: m.emoji ?? "", + personaName: m.display_name ?? m.persona_key, + content: m.content ?? "", + round: round.round, + timestamp: Date.now(), + }); + } + } + + // Summary (if exists and has markdown) + if (summary && typeof summary === "object" && summary.markdown) { + msgs.push({ + id: `pm-${++id}`, + type: "summary", + content: summary.markdown, + timestamp: Date.now(), + }); + } + + return { msgs, nextId: id }; +} + +export function useParty() { + const ws = useWs(); + const connected = useAuthStore((s) => s.connected); + + const [sessions, setSessions] = useState([]); + const [activeSessionId, setActiveSessionId] = useState(null); + const [messages, setMessages] = useState([]); + const [personas, setPersonas] = useState([]); + const [thinkingPersonas, setThinkingPersonas] = useState>(new Set()); + const [round, setRound] = useState(0); + const [mode, setMode] = useState("standard"); + const [status, setStatus] = useState<"idle" | "active" | "closed">("idle"); + const [summary, setSummary] = useState(null); + const [loading, setLoading] = useState(false); + + const msgIdCounter = useRef(0); + const personaColorMap = useRef>(new Map()); + + const getPersonaColor = useCallback((key: string): string => { + if (personaColorMap.current.has(key)) { + return personaColorMap.current.get(key)!; + } + const idx = personaColorMap.current.size % PERSONA_COLORS.length; + const color = PERSONA_COLORS[idx] ?? "#3b82f6"; + personaColorMap.current.set(key, color); + return color; + }, []); + + const addMessage = useCallback((msg: Omit) => { + const id = `pm-${++msgIdCounter.current}`; + setMessages((prev) => [...prev, { ...msg, id, timestamp: Date.now() }]); + }, []); + + // --- Event handlers --- + // Backend sends snake_case field names (Go json tags) + + const handlePartyStarted = useCallback((payload: unknown) => { + // Backend: { session_id, topic, personas: [{ agent_key, display_name, emoji, movie_ref }] } + const p = payload as { + session_id: string; + topic: string; + personas: Array<{ agent_key: string; display_name: string; emoji: string; movie_ref: string }>; + }; + const sessionId = p.session_id; + setActiveSessionId(sessionId); + const enriched: PersonaInfo[] = (p.personas ?? []).map((pe) => ({ + key: pe.agent_key, + emoji: pe.emoji ?? "", + name: pe.display_name ?? pe.agent_key, + role: pe.movie_ref ?? "", + color: getPersonaColor(pe.agent_key), + })); + setPersonas(enriched); + setRound(0); + setMode("standard"); + setStatus("active"); + setSummary(null); + setMessages([]); + personaColorMap.current.clear(); + enriched.forEach((pe) => personaColorMap.current.set(pe.key, pe.color)); + + // Add new session to sessions list for sidebar + setSessions((prev) => [ + { + id: sessionId, + topic: p.topic ?? "", + status: "active", + personas: enriched, + round: 0, + mode: "standard", + createdAt: new Date().toISOString(), + }, + ...prev, + ]); + }, [getPersonaColor]); + + const handlePersonaIntro = useCallback((payload: unknown) => { + // Backend: { session_id, persona, emoji, content } + const p = payload as { persona: string; emoji: string; content: string }; + addMessage({ + type: "intro", + personaKey: p.persona, + personaEmoji: p.emoji ?? "", + personaName: p.persona, + content: p.content ?? "", + }); + }, [addMessage]); + + const handleRoundStarted = useCallback((payload: unknown) => { + const p = payload as { round: number; mode: string }; + setRound(p.round); + setMode(p.mode as PartyMode); + addMessage({ + type: "round_header", + content: "", + round: p.round, + mode: p.mode as PartyMode, + }); + }, [addMessage]); + + const handlePersonaThinking = useCallback((payload: unknown) => { + // Backend: { session_id, persona, emoji } + const p = payload as { persona: string }; + setThinkingPersonas((prev) => new Set(prev).add(p.persona)); + }, []); + + const handlePersonaSpoke = useCallback((payload: unknown) => { + // Backend: { session_id, persona, emoji, content } + const p = payload as { persona: string; emoji: string; content: string }; + setThinkingPersonas((prev) => { + const next = new Set(prev); + next.delete(p.persona); + return next; + }); + addMessage({ + type: "spoke", + personaKey: p.persona, + personaEmoji: p.emoji ?? "", + personaName: p.persona, + content: p.content ?? "", + }); + }, [addMessage]); + + const handleRoundComplete = useCallback((payload: unknown) => { + const p = payload as { round: number }; + setThinkingPersonas(new Set()); + void p; + }, []); + + const handleContextAdded = useCallback((payload: unknown) => { + const p = payload as { type: string; name?: string }; + addMessage({ + type: "context", + content: `Context added: ${p.type}${p.name ? ` (${p.name})` : ""}`, + }); + }, [addMessage]); + + const handleSummaryReady = useCallback((payload: unknown) => { + // Backend: { session_id, summary: { markdown, ... } } + const raw = payload as { summary?: PartySummary; markdown?: string }; + const s: PartySummary = raw.summary ?? raw as PartySummary; + setSummary(s); + addMessage({ + type: "summary", + content: s.markdown ?? "", + }); + }, [addMessage]); + + const handleArtifact = useCallback((payload: unknown) => { + const p = payload as { name: string; content: string }; + addMessage({ + type: "artifact", + content: `**${p.name}**\n\n${p.content}`, + }); + }, [addMessage]); + + const handlePartyClosed = useCallback((payload: unknown) => { + const p = payload as { session_id?: string }; + setStatus("closed"); + setThinkingPersonas(new Set()); + // Update session status in the list + if (p.session_id) { + setSessions((prev) => + prev.map((s) => (s.id === p.session_id ? { ...s, status: "closed" as const } : s)), + ); + } + }, []); + + // --- Subscribe to events --- + useWsEvent(Events.PARTY_STARTED, handlePartyStarted); + useWsEvent(Events.PARTY_PERSONA_INTRO, handlePersonaIntro); + useWsEvent(Events.PARTY_ROUND_STARTED, handleRoundStarted); + useWsEvent(Events.PARTY_PERSONA_THINKING, handlePersonaThinking); + useWsEvent(Events.PARTY_PERSONA_SPOKE, handlePersonaSpoke); + useWsEvent(Events.PARTY_ROUND_COMPLETE, handleRoundComplete); + useWsEvent(Events.PARTY_CONTEXT_ADDED, handleContextAdded); + useWsEvent(Events.PARTY_SUMMARY_READY, handleSummaryReady); + useWsEvent(Events.PARTY_ARTIFACT, handleArtifact); + useWsEvent(Events.PARTY_CLOSED, handlePartyClosed); + + // --- RPC calls --- + // Backend expects snake_case field names (Go json tags) + + const listSessions = useCallback(async () => { + if (!connected) return; + setLoading(true); + try { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + const res = await ws.call<{ sessions: any[] }>(Methods.PARTY_LIST, {}); + setSessions((res.sessions ?? []).map(transformSession)); + } catch { + // ignore + } finally { + setLoading(false); + } + }, [ws, connected]); + + const startParty = useCallback( + async (topic: string, teamPreset?: string, personaKeys?: string[]) => { + if (!connected) return; + setLoading(true); + try { + await ws.call(Methods.PARTY_START, { + topic, + team_preset: teamPreset ?? undefined, + personas: personaKeys ?? undefined, + }); + } catch { + // ignore + } finally { + setLoading(false); + } + }, + [ws, connected], + ); + + const runRound = useCallback( + async (sessionId: string, roundMode?: PartyMode) => { + await ws.call(Methods.PARTY_ROUND, { + session_id: sessionId, + mode: roundMode ?? undefined, + }); + }, + [ws], + ); + + const askQuestion = useCallback( + async (sessionId: string, text: string) => { + await ws.call(Methods.PARTY_QUESTION, { session_id: sessionId, text }); + }, + [ws], + ); + + const addContext = useCallback( + async (sessionId: string, type: string, name?: string, content?: string) => { + await ws.call(Methods.PARTY_ADD_CONTEXT, { session_id: sessionId, type, name, content }); + }, + [ws], + ); + + const getSummary = useCallback( + async (sessionId: string) => { + await ws.call(Methods.PARTY_SUMMARY, { session_id: sessionId }); + }, + [ws], + ); + + const exitParty = useCallback( + async (sessionId: string) => { + await ws.call(Methods.PARTY_EXIT, { session_id: sessionId }); + }, + [ws], + ); + + // Activate an existing session from the list (hydrates all state including history) + const selectSession = useCallback( + (session: PartySession) => { + setActiveSessionId(session.id); + const enriched = session.personas.map((pe) => ({ + ...pe, + color: getPersonaColor(pe.key), + })); + setPersonas(enriched); + setRound(session.round); + setMode(session.mode); + setStatus(session.status === "closed" ? "closed" : "active"); + personaColorMap.current.clear(); + enriched.forEach((pe) => personaColorMap.current.set(pe.key, pe.color)); + + // Hydrate messages from stored history + const { msgs: restored, nextId } = hydrateMessages(session._history, session._summary, msgIdCounter.current); + if (restored.length > 0) { + msgIdCounter.current = nextId; + setMessages(restored); + // Restore summary state + if (session._summary && typeof session._summary === "object" && session._summary.markdown) { + setSummary(session._summary as PartySummary); + } else { + setSummary(null); + } + } else { + setMessages([]); + setSummary(null); + } + }, + [getPersonaColor], + ); + + return { + // state + sessions, + activeSessionId, + messages, + personas, + thinkingPersonas, + round, + mode, + status, + summary, + loading, + + // actions + listSessions, + startParty, + runRound, + askQuestion, + addContext, + getSummary, + exitParty, + selectSession, + setActiveSessionId, + getPersonaColor, + }; +} diff --git a/ui/web/src/pages/party/party-controls.tsx b/ui/web/src/pages/party/party-controls.tsx new file mode 100644 index 00000000..c1c3fcd2 --- /dev/null +++ b/ui/web/src/pages/party/party-controls.tsx @@ -0,0 +1,197 @@ +import { useState } from "react"; +import { useTranslation } from "react-i18next"; +import { Play, MessageCircleQuestion, FileText, LogOut } from "lucide-react"; +import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import { ConfirmDialog } from "@/components/shared/confirm-dialog"; +import { cn } from "@/lib/utils"; +import type { PartyMode } from "./hooks/use-party"; + +interface PartyControlsProps { + sessionId: string; + currentMode: PartyMode; + status: "idle" | "active" | "closed"; + onRunRound: (sessionId: string, mode?: PartyMode) => Promise; + onAskQuestion: (sessionId: string, text: string) => Promise; + onGetSummary: (sessionId: string) => Promise; + onExit: (sessionId: string) => Promise; +} + +const MODE_OPTIONS: { value: PartyMode; labelKey: string }[] = [ + { value: "standard", labelKey: "mode.standard" }, + { value: "deep", labelKey: "mode.deep" }, + { value: "token_ring", labelKey: "mode.token_ring" }, +]; + +export function PartyControls({ + sessionId, + currentMode, + status, + onRunRound, + onAskQuestion, + onGetSummary, + onExit, +}: PartyControlsProps) { + const { t } = useTranslation("party"); + const [selectedMode, setSelectedMode] = useState(currentMode); + const [questionOpen, setQuestionOpen] = useState(false); + const [questionText, setQuestionText] = useState(""); + const [exitOpen, setExitOpen] = useState(false); + const [runningAction, setRunningAction] = useState(null); + + const isClosed = status === "closed"; + + const handleRunRound = async () => { + setRunningAction("round"); + try { + await onRunRound(sessionId, selectedMode); + } finally { + setRunningAction(null); + } + }; + + const handleQuestion = async () => { + if (!questionText.trim()) return; + setRunningAction("question"); + try { + await onAskQuestion(sessionId, questionText.trim()); + setQuestionText(""); + setQuestionOpen(false); + } finally { + setRunningAction(null); + } + }; + + const handleSummary = async () => { + setRunningAction("summary"); + try { + await onGetSummary(sessionId); + } finally { + setRunningAction(null); + } + }; + + const handleExit = async () => { + setRunningAction("exit"); + try { + await onExit(sessionId); + setExitOpen(false); + } finally { + setRunningAction(null); + } + }; + + return ( +
+
+ {/* Mode toggle */} +
+ {MODE_OPTIONS.map((opt) => ( + + ))} +
+ +
+ + {/* Continue [C] */} + + + {/* Question [Q] */} + {questionOpen ? ( +
+ setQuestionText(e.target.value)} + placeholder="Ask a question..." + className="h-8 w-48 text-xs" + onKeyDown={(e) => { + if (e.key === "Enter") handleQuestion(); + if (e.key === "Escape") setQuestionOpen(false); + }} + autoFocus + /> + +
+ ) : ( + + )} + + {/* Summary [D] */} + + + {/* Spacer */} +
+ + {/* Exit [E] */} + +
+ + +
+ ); +} diff --git a/ui/web/src/pages/party/party-page.tsx b/ui/web/src/pages/party/party-page.tsx new file mode 100644 index 00000000..e51c6266 --- /dev/null +++ b/ui/web/src/pages/party/party-page.tsx @@ -0,0 +1,171 @@ +import { useState, useEffect } from "react"; +import { useTranslation } from "react-i18next"; +import { PartyPopper, Plus } from "lucide-react"; +import { PageHeader } from "@/components/shared/page-header"; +import { EmptyState } from "@/components/shared/empty-state"; +import { CardSkeleton } from "@/components/shared/loading-skeleton"; +import { Button } from "@/components/ui/button"; +import { Badge } from "@/components/ui/badge"; +import { ScrollArea } from "@/components/ui/scroll-area"; +import { useDeferredLoading } from "@/hooks/use-deferred-loading"; +import { cn } from "@/lib/utils"; +import { useParty } from "./hooks/use-party"; +import { PartyStartDialog } from "./party-start-dialog"; +import { PartySession } from "./party-session"; +import { PersonaSidebar } from "./persona-sidebar"; +import { PartyControls } from "./party-controls"; + +export function PartyPage() { + const { t } = useTranslation("party"); + const { + sessions, + activeSessionId, + messages, + personas, + thinkingPersonas, + mode, + status, + loading, + listSessions, + startParty, + runRound, + askQuestion, + getSummary, + exitParty, + selectSession, + getPersonaColor, + } = useParty(); + + const [createOpen, setCreateOpen] = useState(false); + const showSkeleton = useDeferredLoading(loading && sessions.length === 0); + + useEffect(() => { + listSessions(); + }, [listSessions]); + + const hasActiveSession = activeSessionId !== null && status !== "idle"; + + return ( +
+ {/* Header area */} +
+ setCreateOpen(true)} className="gap-1"> + {t("newParty")} + + } + /> +
+ + {/* Main content */} +
+ {/* Session list panel */} +
+
+

+ Sessions +

+
+ +
+ {showSkeleton ? ( + Array.from({ length: 3 }).map((_, i) => ( + + )) + ) : sessions.length === 0 ? ( +

+ {t("noSessions")} +

+ ) : ( + sessions.map((session) => ( + + )) + )} +
+
+
+ + {/* Active session area */} + {hasActiveSession ? ( +
+ {/* Chat + persona sidebar */} +
+ {/* Chat messages */} + + + {/* Right sidebar with persona list */} + +
+ + {/* Bottom controls */} + +
+ ) : ( +
+ setCreateOpen(true)} className="gap-1"> + {t("newParty")} + + } + /> +
+ )} +
+ + {/* Start dialog */} + +
+ ); +} diff --git a/ui/web/src/pages/party/party-session.tsx b/ui/web/src/pages/party/party-session.tsx new file mode 100644 index 00000000..fd8b04f9 --- /dev/null +++ b/ui/web/src/pages/party/party-session.tsx @@ -0,0 +1,129 @@ +import { useEffect, useRef } from "react"; +import { useTranslation } from "react-i18next"; +import { ScrollArea } from "@/components/ui/scroll-area"; +import { MarkdownRenderer } from "@/components/shared/markdown-renderer"; +import { cn } from "@/lib/utils"; +import type { PartyMessage, PartyMode } from "./hooks/use-party"; + +interface PartySessionProps { + messages: PartyMessage[]; + getPersonaColor: (key: string) => string; +} + +function RoundHeader({ round, mode, t }: { round: number; mode?: PartyMode; t: (k: string, opts?: Record) => string }) { + const modeLabel = mode ? t(`mode.${mode}`) : ""; + return ( +
+
+ + {t("round", { n: round })} {modeLabel && `[${modeLabel}]`} + +
+
+ ); +} + +function PersonaMessage({ + message, + borderColor, +}: { + message: PartyMessage; + borderColor: string; +}) { + const isIntro = message.type === "intro"; + + return ( +
+
+ {message.personaEmoji} + {message.personaName} + {isIntro && ( + intro + )} +
+
+ +
+
+ ); +} + +function ContextMessage({ message }: { message: PartyMessage }) { + return ( +
+ + {message.content} + +
+ ); +} + +function SummaryMessage({ message }: { message: PartyMessage }) { + return ( +
+ +
+ ); +} + +function ArtifactMessage({ message }: { message: PartyMessage }) { + return ( +
+ +
+ ); +} + +export function PartySession({ messages, getPersonaColor }: PartySessionProps) { + const { t } = useTranslation("party"); + const bottomRef = useRef(null); + + useEffect(() => { + bottomRef.current?.scrollIntoView({ behavior: "smooth" }); + }, [messages.length]); + + return ( + +
+ {messages.map((msg) => { + switch (msg.type) { + case "round_header": + return ( + + ); + case "intro": + case "spoke": + return ( + + ); + case "context": + return ; + case "summary": + return ; + case "artifact": + return ; + default: + return null; + } + })} +
+
+ + ); +} diff --git a/ui/web/src/pages/party/party-start-dialog.tsx b/ui/web/src/pages/party/party-start-dialog.tsx new file mode 100644 index 00000000..a0164291 --- /dev/null +++ b/ui/web/src/pages/party/party-start-dialog.tsx @@ -0,0 +1,191 @@ +import { useState } from "react"; +import { useTranslation } from "react-i18next"; +import { + Dialog, + DialogContent, + DialogHeader, + DialogTitle, + DialogFooter, +} from "@/components/ui/dialog"; +import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { cn } from "@/lib/utils"; + +const TEAM_PRESETS = [ + "payment_feature", + "security_review", + "sprint_planning", + "architecture_decision", + "ux_review", + "incident_response", +] as const; + +const ALL_PERSONAS = [ + { key: "tony-stark-persona", emoji: "\ud83d\udcca", label: "Product Manager" }, + { key: "morpheus-persona", emoji: "\ud83d\udd27", label: "Tech Lead" }, + { key: "batman-persona", emoji: "\ud83d\udd12", label: "Security Analyst" }, + { key: "columbo-persona", emoji: "\ud83e\uddea", label: "QA Engineer" }, + { key: "scotty-persona", emoji: "\u2699\ufe0f", label: "DevOps Engineer" }, + { key: "edna-mode-persona", emoji: "\ud83c\udfa8", label: "UX Designer" }, + { key: "spider-man-persona", emoji: "\ud83c\udf10", label: "Frontend Dev" }, + { key: "ethan-hunt-persona", emoji: "\ud83d\udcf1", label: "Mobile Dev" }, + { key: "sherlock-persona", emoji: "\ud83d\udcbc", label: "Business Analyst" }, + { key: "judge-dredd-persona", emoji: "\ud83d\udccb", label: "Compliance Officer" }, + { key: "gandalf-persona", emoji: "\ud83c\udfc3", label: "Scrum Master" }, + { key: "neo-persona", emoji: "\ud83c\udfd7\ufe0f", label: "Architect" }, + { key: "spock-persona", emoji: "\ud83d\uddc4\ufe0f", label: "DBA" }, + { key: "nick-fury-persona", emoji: "\ud83d\udc54", label: "Executive" }, +] as const; + +interface PartyStartDialogProps { + open: boolean; + onOpenChange: (open: boolean) => void; + onStart: (topic: string, teamPreset?: string, personaKeys?: string[]) => Promise; +} + +export function PartyStartDialog({ open, onOpenChange, onStart }: PartyStartDialogProps) { + const { t } = useTranslation("party"); + const [topic, setTopic] = useState(""); + const [selectedPreset, setSelectedPreset] = useState(null); + const [isCustom, setIsCustom] = useState(false); + const [selectedPersonas, setSelectedPersonas] = useState>(new Set()); + const [loading, setLoading] = useState(false); + + const togglePersona = (key: string) => { + setSelectedPersonas((prev) => { + const next = new Set(prev); + if (next.has(key)) { + next.delete(key); + } else { + next.add(key); + } + return next; + }); + }; + + const handleSelectPreset = (preset: string) => { + setSelectedPreset(preset); + setIsCustom(false); + setSelectedPersonas(new Set()); + }; + + const handleSelectCustom = () => { + setSelectedPreset(null); + setIsCustom(true); + }; + + const canStart = topic.trim().length > 0 && (selectedPreset || (isCustom && selectedPersonas.size >= 2)); + + const handleStart = async () => { + if (!canStart) return; + setLoading(true); + try { + await onStart( + topic.trim(), + selectedPreset ?? undefined, + isCustom ? Array.from(selectedPersonas) : undefined, + ); + onOpenChange(false); + setTopic(""); + setSelectedPreset(null); + setIsCustom(false); + setSelectedPersonas(new Set()); + } catch { + // error handled upstream + } finally { + setLoading(false); + } + }; + + return ( + + + + {t("newParty")} + + +
+ {/* Topic input */} +
+ + setTopic(e.target.value)} + placeholder="e.g., Design payment reconciliation service..." + /> +
+ + {/* Team presets */} +
+ +
+ {TEAM_PRESETS.map((preset) => ( + + ))} +
+
+ + {/* Custom team option */} +
+ + + {isCustom && ( +
+ {ALL_PERSONAS.map((persona) => ( + + ))} +
+ )} +
+
+ + + + + +
+
+ ); +} diff --git a/ui/web/src/pages/party/persona-sidebar.tsx b/ui/web/src/pages/party/persona-sidebar.tsx new file mode 100644 index 00000000..c3feac4d --- /dev/null +++ b/ui/web/src/pages/party/persona-sidebar.tsx @@ -0,0 +1,64 @@ +import { useTranslation } from "react-i18next"; +import { cn } from "@/lib/utils"; +import type { PersonaInfo } from "./hooks/use-party"; + +interface PersonaSidebarProps { + personas: PersonaInfo[]; + thinkingPersonas: Set; +} + +export function PersonaSidebar({ personas, thinkingPersonas }: PersonaSidebarProps) { + const { t } = useTranslation("party"); + + if (personas.length === 0) return null; + + return ( +
+
+

+ Personas +

+
+
+ {personas.map((persona) => { + const isThinking = thinkingPersonas.has(persona.key); + return ( +
+ {/* Status indicator */} + + {/* Persona info */} +
+
+ {persona.emoji} + + {persona.name} + +
+

+ {persona.role} +

+
+ {/* Status label */} + {isThinking && ( + + {t("status.thinking")} + + )} +
+ ); + })} +
+
+ ); +} diff --git a/ui/web/src/pages/projects/hooks/use-projects.ts b/ui/web/src/pages/projects/hooks/use-projects.ts new file mode 100644 index 00000000..76d09953 --- /dev/null +++ b/ui/web/src/pages/projects/hooks/use-projects.ts @@ -0,0 +1,144 @@ +import { useCallback } from "react"; +import { useQuery, useQueryClient } from "@tanstack/react-query"; +import i18next from "i18next"; +import { useHttp } from "@/hooks/use-ws"; +import { queryKeys } from "@/lib/query-keys"; +import { toast } from "@/stores/use-toast-store"; + +export interface ProjectData { + id: string; + name: string; + slug: string; + channel_type?: string | null; + chat_id?: string | null; + team_id?: string | null; + description?: string | null; + status: string; + created_by: string; + created_at: string; + updated_at: string; +} + +export interface ProjectInput { + name: string; + slug: string; + channel_type?: string | null; + chat_id?: string | null; + description?: string | null; + status?: string; +} + +export interface ProjectMCPOverride { + id: string; + project_id: string; + server_name: string; + env_overrides: Record; + enabled: boolean; +} + +export function useProjects() { + const http = useHttp(); + const queryClient = useQueryClient(); + + const { data: projects = [], isLoading: loading } = useQuery({ + queryKey: queryKeys.projects.all, + queryFn: async () => { + const res = await http.get<{ projects: ProjectData[] }>("/v1/projects"); + return res.projects ?? []; + }, + }); + + const invalidate = useCallback( + () => queryClient.invalidateQueries({ queryKey: queryKeys.projects.all }), + [queryClient], + ); + + const createProject = useCallback( + async (data: ProjectInput) => { + try { + const res = await http.post("/v1/projects", data); + await invalidate(); + toast.success(i18next.t("projects:toast.created")); + return res; + } catch (err) { + toast.error(i18next.t("projects:toast.failedCreate"), err instanceof Error ? err.message : ""); + throw err; + } + }, + [http, invalidate], + ); + + const updateProject = useCallback( + async (id: string, data: Partial) => { + try { + await http.put(`/v1/projects/${id}`, data); + await invalidate(); + toast.success(i18next.t("projects:toast.updated")); + } catch (err) { + toast.error(i18next.t("projects:toast.failedUpdate"), err instanceof Error ? err.message : ""); + throw err; + } + }, + [http, invalidate], + ); + + const deleteProject = useCallback( + async (id: string) => { + try { + await http.delete(`/v1/projects/${id}`); + await invalidate(); + toast.success(i18next.t("projects:toast.deleted")); + } catch (err) { + toast.error(i18next.t("projects:toast.failedDelete"), err instanceof Error ? err.message : ""); + throw err; + } + }, + [http, invalidate], + ); + + const listOverrides = useCallback( + async (projectId: string) => { + const res = await http.get<{ overrides: ProjectMCPOverride[] }>(`/v1/projects/${projectId}/mcp`); + return res.overrides ?? []; + }, + [http], + ); + + const setOverride = useCallback( + async (projectId: string, serverName: string, envOverrides: Record) => { + try { + await http.put(`/v1/projects/${projectId}/mcp/${serverName}`, envOverrides); + toast.success(i18next.t("projects:overrides.saved")); + } catch (err) { + toast.error(i18next.t("projects:overrides.failedSave"), err instanceof Error ? err.message : ""); + throw err; + } + }, + [http], + ); + + const removeOverride = useCallback( + async (projectId: string, serverName: string) => { + try { + await http.delete(`/v1/projects/${projectId}/mcp/${serverName}`); + toast.success(i18next.t("projects:overrides.deleted")); + } catch (err) { + toast.error(i18next.t("projects:overrides.failedDelete"), err instanceof Error ? err.message : ""); + throw err; + } + }, + [http], + ); + + return { + projects, + loading, + refresh: invalidate, + createProject, + updateProject, + deleteProject, + listOverrides, + setOverride, + removeOverride, + }; +} diff --git a/ui/web/src/pages/projects/project-form-dialog.tsx b/ui/web/src/pages/projects/project-form-dialog.tsx new file mode 100644 index 00000000..f7e73124 --- /dev/null +++ b/ui/web/src/pages/projects/project-form-dialog.tsx @@ -0,0 +1,182 @@ +import { useState, useEffect } from "react"; +import { useTranslation } from "react-i18next"; +import { + Dialog, + DialogContent, + DialogFooter, + DialogHeader, + DialogTitle, +} from "@/components/ui/dialog"; +import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { Textarea } from "@/components/ui/textarea"; +import { slugify, isValidSlug } from "@/lib/slug"; +import type { ProjectData, ProjectInput } from "./hooks/use-projects"; + +const CHANNEL_TYPES = [ + { value: "", label: "—" }, + { value: "telegram", label: "Telegram" }, + { value: "zalo_oa", label: "Zalo OA" }, + { value: "discord", label: "Discord" }, + { value: "slack", label: "Slack" }, + { value: "feishu", label: "Feishu/Lark" }, + { value: "whatsapp", label: "WhatsApp" }, + { value: "google_chat", label: "Google Chat" }, +]; + +const STATUS_OPTIONS = [ + { value: "active", label: "Active" }, + { value: "archived", label: "Archived" }, +]; + +interface ProjectFormDialogProps { + open: boolean; + onOpenChange: (open: boolean) => void; + project?: ProjectData | null; + onSubmit: (data: ProjectInput) => Promise; +} + +export function ProjectFormDialog({ open, onOpenChange, project, onSubmit }: ProjectFormDialogProps) { + const { t } = useTranslation("projects"); + const [name, setName] = useState(""); + const [slug, setSlug] = useState(""); + const [channelType, setChannelType] = useState(""); + const [chatId, setChatId] = useState(""); + const [description, setDescription] = useState(""); + const [status, setStatus] = useState("active"); + const [loading, setLoading] = useState(false); + const [error, setError] = useState(""); + const [autoSlug, setAutoSlug] = useState(true); + + useEffect(() => { + if (open) { + setName(project?.name ?? ""); + setSlug(project?.slug ?? ""); + setChannelType(project?.channel_type ?? ""); + setChatId(project?.chat_id ?? ""); + setDescription(project?.description ?? ""); + setStatus(project?.status ?? "active"); + setError(""); + setAutoSlug(!project); + } + }, [open, project]); + + const handleNameChange = (value: string) => { + setName(value); + if (autoSlug) { + setSlug(slugify(value)); + } + }; + + const handleSlugChange = (value: string) => { + setAutoSlug(false); + setSlug(slugify(value)); + }; + + const handleSubmit = async () => { + if (!name.trim() || !slug.trim()) { + setError(t("form.errors.nameRequired")); + return; + } + if (!isValidSlug(slug)) { + setError(t("form.errors.slugInvalid")); + return; + } + + setLoading(true); + setError(""); + try { + await onSubmit({ + name: name.trim(), + slug: slug.trim(), + channel_type: channelType || null, + chat_id: chatId.trim() || null, + description: description.trim() || null, + status, + }); + onOpenChange(false); + } catch (err: unknown) { + setError(err instanceof Error ? err.message : t("form.saving")); + } finally { + setLoading(false); + } + }; + + return ( + !loading && onOpenChange(v)}> + + + {project ? t("form.editTitle") : t("form.createTitle")} + + +
+
+ + handleNameChange(e.target.value)} placeholder="XPOS" className="text-base md:text-sm" /> +
+ +
+ + handleSlugChange(e.target.value)} placeholder="xpos" className="font-mono text-base md:text-sm" /> +

{t("form.slugHint")}

+
+ +
+ + +
+ +
+ + setChatId(e.target.value)} placeholder={t("form.chatIdPlaceholder")} className="font-mono text-base md:text-sm" /> +
+ +
+ +