Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 66 additions & 9 deletions cmd/entire/cli/agent/factoryaidroid/lifecycle.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package factoryaidroid

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"os"
"strings"
"time"

"github.com/entireio/cli/cmd/entire/cli/agent"
Expand Down Expand Up @@ -50,7 +52,7 @@ func (f *FactoryAIDroidAgent) HookNames() []string {

// ParseHookEvent translates a Factory AI Droid hook into a normalized lifecycle Event.
// Returns nil if the hook has no lifecycle significance.
func (f *FactoryAIDroidAgent) ParseHookEvent(_ context.Context, hookName string, stdin io.Reader) (*agent.Event, error) {
func (f *FactoryAIDroidAgent) ParseHookEvent(ctx context.Context, hookName string, stdin io.Reader) (*agent.Event, error) {
switch hookName {
case HookNameSessionStart:
return f.parseSessionStart(stdin)
Expand All @@ -61,9 +63,9 @@ func (f *FactoryAIDroidAgent) ParseHookEvent(_ context.Context, hookName string,
case HookNameSessionEnd:
return f.parseSessionEnd(stdin)
case HookNamePreToolUse:
return f.parseSubagentStart(stdin)
return f.parseSubagentStart(ctx, stdin)
case HookNamePostToolUse:
return f.parseSubagentEnd(stdin)
return f.parseSubagentEnd(ctx, stdin)
case HookNamePreCompact:
return f.parseCompaction(stdin)
case HookNameSubagentStop, HookNameNotification:
Expand Down Expand Up @@ -226,36 +228,50 @@ func (f *FactoryAIDroidAgent) parseSessionEnd(stdin io.Reader) (*agent.Event, er
}, nil
}

func (f *FactoryAIDroidAgent) parseSubagentStart(stdin io.Reader) (*agent.Event, error) {
func (f *FactoryAIDroidAgent) parseSubagentStart(ctx context.Context, stdin io.Reader) (*agent.Event, error) {
raw, err := agent.ReadAndParseHookInput[taskHookInputRaw](stdin)
if err != nil {
return nil, err
}
toolUseID := raw.ToolUseID
if toolUseID == "" {
toolUseID, err = registerFallbackToolUseID(ctx, raw.SessionID, raw.ToolName, raw.ToolInput)
if err != nil {
return nil, err
}
}
return &agent.Event{
Type: agent.SubagentStart,
SessionID: raw.SessionID,
SessionRef: raw.TranscriptPath,
ToolUseID: raw.ToolUseID,
ToolUseID: toolUseID,
ToolInput: raw.ToolInput,
Timestamp: time.Now(),
}, nil
}

func (f *FactoryAIDroidAgent) parseSubagentEnd(stdin io.Reader) (*agent.Event, error) {
func (f *FactoryAIDroidAgent) parseSubagentEnd(ctx context.Context, stdin io.Reader) (*agent.Event, error) {
raw, err := agent.ReadAndParseHookInput[postToolHookInputRaw](stdin)
if err != nil {
return nil, err
}
toolUseID := raw.ToolUseID
if toolUseID == "" {
toolUseID, err = resolveFallbackToolUseID(ctx, raw.SessionID, raw.ToolName, raw.ToolInput)
if err != nil {
return nil, err
}
}
event := &agent.Event{
Type: agent.SubagentEnd,
SessionID: raw.SessionID,
SessionRef: raw.TranscriptPath,
ToolUseID: raw.ToolUseID,
ToolUseID: toolUseID,
ToolInput: raw.ToolInput,
Timestamp: time.Now(),
}
if raw.ToolResponse.AgentID != "" {
event.SubagentID = raw.ToolResponse.AgentID
if agentID := parseHookToolResponseAgentID(raw.ToolResponse); agentID != "" {
event.SubagentID = agentID
}
return event, nil
}
Expand All @@ -272,3 +288,44 @@ func (f *FactoryAIDroidAgent) parseCompaction(stdin io.Reader) (*agent.Event, er
Timestamp: time.Now(),
}, nil
}

func parseHookToolResponseAgentID(raw json.RawMessage) string {
if trimmed := bytes.TrimSpace(raw); len(trimmed) == 0 || bytes.Equal(trimmed, []byte("null")) {
return ""
}

var obj struct {
AgentID string `json:"agentId"`
}
if err := json.Unmarshal(raw, &obj); err == nil && obj.AgentID != "" {
return obj.AgentID
}

var text string
if err := json.Unmarshal(raw, &text); err == nil {
return extractHookToolResponseAgentID(text)
}

return ""
}

func extractHookToolResponseAgentID(text string) string {
const prefix = "agentId: "
_, after, found := strings.Cut(text, prefix)
if !found {
return ""
}

after = strings.TrimSpace(after)
end := 0
for end < len(after) {
ch := after[end]
if ch >= 'a' && ch <= 'z' || ch >= 'A' && ch <= 'Z' || ch >= '0' && ch <= '9' || ch == '-' || ch == '_' {
end++
continue
}
break
}

return after[:end]
}
88 changes: 88 additions & 0 deletions cmd/entire/cli/agent/factoryaidroid/lifecycle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@ package factoryaidroid

import (
"context"
"path/filepath"
"strings"
"testing"

"github.com/entireio/cli/cmd/entire/cli/agent"
"github.com/entireio/cli/cmd/entire/cli/paths"
git "github.com/go-git/go-git/v6"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -176,6 +179,91 @@ func TestParseHookEvent_SubagentEnd(t *testing.T) {
}
}

func TestParseHookEvent_SubagentStart_MissingToolUseID(t *testing.T) {
t.Parallel()

ag := &FactoryAIDroidAgent{}
input := `{"session_id": "sess-4", "transcript_path": "/tmp/t.jsonl", "tool_name": "Task", "tool_input": {"prompt": "do something"}}`

event, err := ag.ParseHookEvent(context.Background(), HookNamePreToolUse, strings.NewReader(input))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if event.ToolUseID == "" {
t.Fatal("expected fallback tool_use_id, got empty string")
}
}

func TestParseHookEvent_SubagentEnd_StringToolResponse(t *testing.T) {
t.Parallel()

ag := &FactoryAIDroidAgent{}
input := `{"session_id": "sess-5", "transcript_path": "/tmp/t.jsonl", "tool_name": "Task", "tool_input": {}, "tool_response": "agentId: agent-789"}`

event, err := ag.ParseHookEvent(context.Background(), HookNamePostToolUse, strings.NewReader(input))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if event.Type != agent.SubagentEnd {
t.Errorf("expected SubagentEnd, got %v", event.Type)
}
if event.SubagentID != "agent-789" {
t.Errorf("expected SubagentID 'agent-789', got %q", event.SubagentID)
}
if event.ToolUseID == "" {
t.Fatal("expected fallback tool_use_id, got empty string")
}
}

func TestParseHookEvent_MissingToolUseID_RepeatedInputsStayUniqueAndCorrelate(t *testing.T) {
repoDir := t.TempDir()
if _, err := git.PlainInit(repoDir, false); err != nil {
t.Fatalf("git init: %v", err)
}
t.Chdir(repoDir)
paths.ClearWorktreeRootCache()
t.Cleanup(paths.ClearWorktreeRootCache)

ag := &FactoryAIDroidAgent{}
input := `{"session_id": "sess-repeat", "transcript_path": "/tmp/t.jsonl", "tool_name": "Task", "tool_input": {"prompt": "do something"}}`

startOne, err := ag.ParseHookEvent(context.Background(), HookNamePreToolUse, strings.NewReader(input))
if err != nil {
t.Fatalf("unexpected error on first start: %v", err)
}
startTwo, err := ag.ParseHookEvent(context.Background(), HookNamePreToolUse, strings.NewReader(input))
if err != nil {
t.Fatalf("unexpected error on second start: %v", err)
}
if startOne.ToolUseID == startTwo.ToolUseID {
t.Fatalf("expected unique fallback tool_use_id values, got %q", startOne.ToolUseID)
}

endTwo, err := ag.ParseHookEvent(context.Background(), HookNamePostToolUse, strings.NewReader(input))
if err != nil {
t.Fatalf("unexpected error on second end: %v", err)
}
if endTwo.ToolUseID != startTwo.ToolUseID {
t.Fatalf("expected most recent fallback tool_use_id %q, got %q", startTwo.ToolUseID, endTwo.ToolUseID)
}

endOne, err := ag.ParseHookEvent(context.Background(), HookNamePostToolUse, strings.NewReader(input))
if err != nil {
t.Fatalf("unexpected error on first end: %v", err)
}
if endOne.ToolUseID != startOne.ToolUseID {
t.Fatalf("expected earlier fallback tool_use_id %q, got %q", startOne.ToolUseID, endOne.ToolUseID)
}

matches, err := filepath.Glob(filepath.Join(repoDir, paths.EntireTmpDir, fallbackToolUseStatePrefix+"*.json"))
if err != nil {
t.Fatalf("glob fallback state files: %v", err)
}
if len(matches) != 0 {
t.Fatalf("expected fallback state cleanup, found %v", matches)
}
}

func TestParseHookEvent_Compaction(t *testing.T) {
t.Parallel()

Expand Down
147 changes: 147 additions & 0 deletions cmd/entire/cli/agent/factoryaidroid/tool_use_fallback.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package factoryaidroid

import (
"context"
"crypto/rand"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"

"github.com/entireio/cli/cmd/entire/cli/paths"
)

const fallbackToolUseStatePrefix = "factory-task-tool-use-"

type fallbackToolUseState struct {
Entries []fallbackToolUseEntry `json:"entries"`
}

type fallbackToolUseEntry struct {
Fingerprint string `json:"fingerprint"`
ToolUseID string `json:"tool_use_id"`
}

func registerFallbackToolUseID(
ctx context.Context,
sessionID, toolName string,
toolInput json.RawMessage,
) (string, error) {
statePath, err := fallbackToolUseStatePath(ctx, sessionID)
if err != nil {
return "", err
}

state, err := loadFallbackToolUseState(statePath)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
state = &fallbackToolUseState{}
} else {
return "", err
}
}

toolUseID, err := newFallbackToolUseID()
if err != nil {
return "", err
}

state.Entries = append(state.Entries, fallbackToolUseEntry{
Fingerprint: fallbackToolFingerprint(toolName, toolInput),
ToolUseID: toolUseID,
})
if err := saveFallbackToolUseState(statePath, state); err != nil {
return "", err
}

return toolUseID, nil
}

func resolveFallbackToolUseID(
ctx context.Context,
sessionID, toolName string,
toolInput json.RawMessage,
) (string, error) {
statePath, err := fallbackToolUseStatePath(ctx, sessionID)
if err != nil {
return "", err
}

state, err := loadFallbackToolUseState(statePath)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return fallbackToolUseID(sessionID, toolName, toolInput), nil
}
return "", err
}

fingerprint := fallbackToolFingerprint(toolName, toolInput)
for i := len(state.Entries) - 1; i >= 0; i-- {
if state.Entries[i].Fingerprint != fingerprint {
continue
}

toolUseID := state.Entries[i].ToolUseID
state.Entries = append(state.Entries[:i], state.Entries[i+1:]...)
if err := saveFallbackToolUseState(statePath, state); err != nil {
return "", err
}
return toolUseID, nil
}

return fallbackToolUseID(sessionID, toolName, toolInput), nil
}

func newFallbackToolUseID() (string, error) {
var suffix [8]byte
if _, err := rand.Read(suffix[:]); err != nil {
return "", fmt.Errorf("generate fallback tool_use_id: %w", err)
}
return "factorytask_" + hex.EncodeToString(suffix[:]), nil
}

func fallbackToolUseStatePath(ctx context.Context, sessionID string) (string, error) {
tmpDir, err := paths.AbsPath(ctx, paths.EntireTmpDir)
if err != nil {
return "", fmt.Errorf("resolve fallback tool_use_id tmp dir: %w", err)
}
if err := os.MkdirAll(tmpDir, 0o750); err != nil {
return "", fmt.Errorf("create fallback tool_use_id tmp dir: %w", err)
}

sessionHash := fallbackToolUseID(sessionID, "", nil)
return filepath.Join(tmpDir, fallbackToolUseStatePrefix+sessionHash+".json"), nil
}

func loadFallbackToolUseState(path string) (*fallbackToolUseState, error) {
data, err := os.ReadFile(filepath.Clean(path))
if err != nil {
return nil, fmt.Errorf("read fallback tool_use_id state: %w", err)
}

var state fallbackToolUseState
if err := json.Unmarshal(data, &state); err != nil {
return nil, fmt.Errorf("unmarshal fallback tool_use_id state: %w", err)
}
return &state, nil
}

func saveFallbackToolUseState(path string, state *fallbackToolUseState) error {
if len(state.Entries) == 0 {
if err := os.Remove(path); err != nil && !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("remove empty fallback tool_use_id state: %w", err)
}
return nil
}

data, err := json.Marshal(state)
if err != nil {
return fmt.Errorf("marshal fallback tool_use_id state: %w", err)
}
if err := os.WriteFile(path, data, 0o600); err != nil {
return fmt.Errorf("write fallback tool_use_id state: %w", err)
}
return nil
}
Loading
Loading