diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index 334482ba54..8ee2252299 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -177,3 +177,33 @@ docker run --rm \ --net=host \ localhost:5001/kebab:latest ``` + +### File uploads (artifacts) + +The chat UI supports uploading files to agents running on the **Go ADK runtime**. +The artifact round trip (persistence, the `save_artifact`/`load_artifacts` tools, +and surfacing agent-produced files back to the UI) is Go ADK only; the Python +runtime mirrors just the text-extraction path so uploaded documents still reach +non-multimodal models. Files travel inline (base64) over the existing A2A message +channel as `FilePart`s — there is no separate artifact HTTP API or CRD field. + +- **Storage:** Uploaded files are persisted via the ADK in-memory artifact + service (`artifact.InMemoryService()`), so artifacts live for the lifetime of + the agent process and are not durable across restarts. +- **Access from agents:** The `load_artifacts` tool is registered for every + agent, letting the LLM list/load uploaded files. Inbound uploads are saved + automatically (`SaveInputBlobsAsArtifacts`). +- **Reaching the model:** Images are forwarded inline to the model. Rich + documents (PDF, DOCX, XLSX, PPTX, HTML, EPUB) are extracted to text/markdown + so the model can read them; text-like files (txt/markdown/CSV/JSON) are + passed through as-is. The Go runtime uses + [tabula](https://github.com/tsawler/tabula) (pure Go, no CGO for native + text); the Python runtime uses + [markitdown](https://github.com/microsoft/markitdown). The supported document + set is kept common across both runtimes and the UI. Scanned/image-only PDFs + yield no text without OCR. +- **Allowlist (enforced in the UI):** images, PDF, plain text, Markdown, CSV, + JSON, XML, YAML, HTML, and DOCX, XLSX, PPTX, EPUB documents. +- **Size limit:** 10 MB per file by default, enforced both client-side and on + the server. Override the server limit with the `KAGENT_MAX_ARTIFACT_BYTES` + environment variable (value in bytes); oversized inbound files fail the task. diff --git a/design/EP-89405-file-upload-artifacts.md b/design/EP-89405-file-upload-artifacts.md new file mode 100644 index 0000000000..7a8e26b805 --- /dev/null +++ b/design/EP-89405-file-upload-artifacts.md @@ -0,0 +1,174 @@ +# EP-89405: File Upload & Artifact Support in Chat + +* Issue: [MSFN-89405] (internal) — feat: file upload feature +* Status: `implemented` (branch `feature/file-upload-artifacts`, commit `ccfada50`) +* Related: EP-2046 (Chat UI for MCP UI widgets) — shares the chat files but + explicitly excludes the file-upload backend, which this EP owns. + +## Background + +kagent's chat could only exchange text. Users frequently need to hand an agent a +file (a log, a PDF, a CSV, a screenshot) and to receive files an agent/tool +produces (a generated report, an extracted table). The Go ADK runtime ships an +artifact subsystem (`artifact.Service`, `loadartifactstool`, the +`SaveInputBlobsAsArtifacts` run option, and the per-event `ArtifactDelta` +signal), but kagent serves agents purely over A2A (`adka2a`) and never wired the +`ArtifactService` into the runner — so none of it was reachable. + +This EP enables an **end-to-end file upload / download round trip** in chat: + +1. Users attach files in the chat UI; they travel inline (base64) over the + existing A2A message/SSE channel — **no new HTTP API, no CRD field**. +2. The Go ADK executor persists inbound uploads as artifacts and surfaces + agent-produced artifacts back to the UI as downloadable A2A file parts. +3. Agents get a `save_artifact` tool (produce files) plus the built-in + `load_artifacts` tool (reference uploaded/produced files across turns). +4. Both runtimes extract text from uploaded files (notably PDF) so models that + cannot natively read a document still receive its content. + +## Motivation + +- Let users give agents real working material instead of pasting text. +- Let agents return generated files (reports, transformed data) as first-class, + downloadable chat attachments rather than dumping content into the message. +- Reuse the ADK's existing, battle-tested artifact contract rather than inventing + a parallel storage/transport. + +### Goals + +- Wire the ADK in-memory `ArtifactService` into the Go runner (process-lifetime + persistence, versioned, user/session scoped). +- Accept inbound A2A file parts, persist them via `SaveInputBlobsAsArtifacts`, + and pass them inline to the model. +- Emit agent-saved artifacts back to the UI as A2A `FilePart` events, driven by + the `ArtifactDelta` event signal (event-driven, no store diffing). +- Register `loadartifactstool` (load) and a new `save_artifact` tool (produce) + for agents. +- Extract text from uploaded documents (PDF first) in both the Go ADK + (`fileextract`) and Python (`_file_extract`) model paths so non-multimodal + models still receive document content. +- Chat UI: attach multiple files with type/size validation; render image + thumbnails and downloadable file chips for both user and agent bubbles. +- Raise the nginx/proxy request-body limit so uploads aren't rejected at the edge. + +### Non-Goals + +- Durable / shared artifact storage (GCS, DB) and cross-replica access. Artifacts + live in process memory and are lost on pod restart. +- A standalone artifact browser UI (list / delete / version history) beyond the + per-message chips. +- A dedicated artifact HTTP/storage API for the UI (everything rides A2A). +- The MCP-app / minimap chat features that share the same chat files (EP-2046). + +## Implementation Details + +### Transport & data model + +- **Upload:** `@a2a-js/sdk` `FilePart` `{ kind: "file", file: { name, mimeType, + bytes /*base64*/ } }` appended alongside the text part on `message/stream`. +- **Download:** A2A artifact-update events carrying a `FilePart`. +- **ADK artifact value:** `*genai.Part` with `InlineData {Data, MIMEType, + DisplayName}`; key `(AppName, UserID, SessionID, FileName, Version)`; versions + auto-increment; `ArtifactDelta map[filename]version` set automatically on save. + +### Go ADK runtime + +- **`go/adk/pkg/runner/adapter.go`** — set + `ArtifactService: adkartifact.InMemoryService()` on the `runner.Config` in + `CreateRunnerConfig` (single instance, reused for process-lifetime persistence). +- **`go/adk/pkg/agent/agent.go`** — register `loadartifactstool.New()` and the + new `save_artifact` tool in the agent's local tool set. +- **`go/adk/pkg/tools/save_artifact_tool.go`** — new tool letting the LLM produce + a downloadable file from chat; the executor surfaces it as an A2A file part. +- **`go/adk/pkg/a2a/executor.go`** — enable + `runConfig.SaveInputBlobsAsArtifacts = true`; on each event with + `Actions.ArtifactDelta`, `Load` each `(name, version)` from the store, set + `InlineData.DisplayName`, convert via `ToA2APart`, and emit an A2A + artifact-update `FilePart` event (`LastChunk=true`). Load/convert errors are + logged and skipped so the turn continues. +- **`go/adk/pkg/a2a/artifacts.go`** — helpers for building/emitting artifact + events from saved ADK parts. +- **`go/adk/pkg/fileextract/`** (`fileextract.go`, `pdf.go`) — extract text from + uploaded documents (PDF and other supported types) so the content is injected + for models that can't read the raw file. +- **`go/adk/pkg/models/openai_adk.go`** — inject extracted file text into the + OpenAI request path. +- New Go deps in `go/go.mod` / `go/go.sum` for PDF extraction. + +### Python runtime + +- **`python/packages/kagent-adk/src/kagent/adk/models/_file_extract.py`** — text + extraction (PDF, etc.) mirroring the Go path. +- **`python/packages/kagent-adk/src/kagent/adk/models/_openai.py`** — inject + extracted file content into the OpenAI request. +- **`python/packages/kagent-adk/pyproject.toml`** — add the extraction dependency. + +> Note: the original design scoped Python as a follow-up; the shipped +> implementation includes the Python extraction path as well. + +### UI (Next.js, `ui/src`) + +- **`lib/fileUpload.ts`** — `FILE_ACCEPT`, `MAX_FILE_BYTES` (10 MB), `isAllowedFile`, + `fileToFilePart` (read file → base64 `FilePart`). Allowlist: images, PDF, + text/markdown, CSV, JSON. +- **`components/chat/ChatInterface.tsx`** — attach button + hidden multi-file + ``; `pendingFiles` state with removable chips; type/size validation with + toasts; build `FilePart`s and append to the outgoing message; session naming + falls back to the first file name for file-only messages. +- **`components/chat/FileAttachment.tsx`** (new) — image thumbnail (object URL) + or a download chip (icon, filename, human size, download button). +- **`components/chat/ChatMessage.tsx`** — render file parts in user and agent + bubbles. +- **`lib/messageHandlers.ts`** — preserve `file` parts from messages and from + `artifact-update` events (`extractMessagesFromTasks`). + +### Edge / deployment + +- **`helm/kagent/files/nginx.conf`** — add + `client_max_body_size {{ .Values.ui.nginx.clientMaxBodySize }};`. +- **`helm/kagent/values.yaml`** — `ui.nginx.clientMaxBodySize: 50m` (default). +- **`helm/kagent/tests/ui-nginx-configmap_test.yaml`** — assert the default and + an override render into the config. + +### Acceptance criteria (AC1–AC8) + +AC1 artifact service wired · AC2 upload passed inline · AC3 upload persisted · +AC4 agent-saved artifact emitted as A2A `FilePart` · AC5 `loadartifactstool` +registered · AC6 UI multi-file attach with validation · AC7 thumbnail/chip +rendering in both bubbles · AC8 E2E upload round trip on kind. + +### Test Plan + +- **Go unit:** `adapter_test.go` (non-nil `ArtifactService`), `agent_test.go` + (tool list includes load/save tools), `executor_test.go` (ArtifactDelta → + emitted `FilePart`; oversized inbound → failed status), `artifacts_test.go`, + `save_artifact_tool_test.go`, `fileextract` tests (`fileextract_test.go`, + `fixture_test.go`), `models/openai_adk_test.go`. +- **Go e2e:** `go/core/test/e2e/file_upload_test.go` — upload an inline A2A file + part to a Go ADK agent and assert it is processed (uses the current + `a2aproject/a2a-go/v2` API). +- **Python unit:** `tests/unittests/models/test_file_extract.py`, + `test_openai.py`. +- **UI unit:** `lib/__tests__/fileUpload.test.ts`, + `lib/__tests__/messageHandlers.test.ts`, + `chat/__tests__/FileAttachment.test.tsx`, `chat/__tests__/ChatMessage.test.tsx`. +- **Helm:** `helm/kagent/tests/ui-nginx-configmap_test.yaml`. + +## Alternatives + +- **Core-server artifact endpoints / mounting `adkrest`:** more moving parts, a + second transport, and `adkrest` has no upload route — rejected in favor of + A2A-only. +- **Diff the artifact store per turn:** racy and fragile versus the idiomatic + `ArtifactDelta` event signal. +- **Inline-only download (no store surfacing):** wouldn't connect the artifact + store to the UI. +- **Send raw files to every model:** token bloat and many models can't read + PDFs — hence server-side text extraction with the 10 MB cap. + +## Open Questions + +- Durable/shared artifact storage for multi-replica deployments and restarts. +- Per-file size limit configurability beyond the current 10 MB UI cap + + `client_max_body_size`. +- Whether to expose an artifact browser (list/version/delete) in the UI. diff --git a/go/adk/pkg/a2a/artifacts.go b/go/adk/pkg/a2a/artifacts.go new file mode 100644 index 0000000000..552c0abada --- /dev/null +++ b/go/adk/pkg/a2a/artifacts.go @@ -0,0 +1,154 @@ +package a2a + +import ( + "context" + "encoding/base64" + "fmt" + "maps" + "os" + "strconv" + "strings" + + a2atype "github.com/a2aproject/a2a-go/a2a" + "github.com/a2aproject/a2a-go/a2asrv" + "github.com/a2aproject/a2a-go/a2asrv/eventqueue" + adkartifact "google.golang.org/adk/artifact" + "google.golang.org/adk/server/adka2a" //nolint:staticcheck // kagent still uses a2a-go v1; this ADK package is the compatibility adapter. +) + +const ( + // defaultMaxArtifactBytes is the default per-file size limit for inbound + // uploads (10 MB). + defaultMaxArtifactBytes = 10 * 1024 * 1024 + // envMaxArtifactBytes overrides the inbound file size limit (in bytes). + envMaxArtifactBytes = "KAGENT_MAX_ARTIFACT_BYTES" +) + +// MaxArtifactBytes returns the artifact size limit, honoring the +// KAGENT_MAX_ARTIFACT_BYTES env var and falling back to the default. It bounds +// both inbound uploads and agent-saved artifacts. +func MaxArtifactBytes() int { + if v := os.Getenv(envMaxArtifactBytes); v != "" { + if n, err := strconv.Atoi(v); err == nil && n > 0 { + return n + } + } + return defaultMaxArtifactBytes +} + +// checkInboundFileSizes returns an error if any inbound FilePart's decoded +// content exceeds the limit. Only inline base64 (FileBytes) parts are checked; +// URI-referenced files are out of scope. The decoded size is derived from the +// base64 length so a ~10 MB upload is not fully decoded onto the heap just to be +// measured; base64 validity is enforced downstream when the payload is decoded +// for persistence. +func checkInboundFileSizes(msg *a2atype.Message, limit int) error { + if msg == nil { + return nil + } + for _, part := range msg.Parts { + fp := asFilePart(part) + if fp == nil { + continue + } + fb, ok := fp.File.(a2atype.FileBytes) + if !ok { + continue + } + if n := base64DecodedLen(fb.Bytes); n > limit { + return fmt.Errorf("file %q exceeds maximum allowed size: %d bytes > %d bytes", fb.Name, n, limit) + } + } + return nil +} + +// base64DecodedLen returns the number of bytes that standard padded base64 input +// decodes to, derived from its length and trailing padding without allocating +// the payload. +// +// ponytail: assumes clean, padded StdEncoding base64 (what the UI emits via +// FileReader); embedded whitespace/newlines would inflate the estimate. The +// upgrade path is a streaming decode counter if MIME-wrapped input ever shows up. +func base64DecodedLen(s string) int { + n := base64.StdEncoding.DecodedLen(len(s)) + switch { + case strings.HasSuffix(s, "=="): + return n - 2 + case strings.HasSuffix(s, "="): + return n - 1 + } + return n +} + +// asFilePart extracts a *FilePart from an A2A Part, handling both value and +// pointer types. +func asFilePart(part a2atype.Part) *a2atype.FilePart { + switch p := part.(type) { + case *a2atype.FilePart: + return p + case a2atype.FilePart: + return &p + } + return nil +} + +// emitArtifacts loads each artifact named in delta from the artifact service +// and emits it as an A2A artifact event carrying a FilePart. Load/convert +// failures are logged and skipped so the turn continues (AC4). +func (e *KAgentExecutor) emitArtifacts( + ctx context.Context, + reqCtx *a2asrv.RequestContext, + queue eventqueue.Queue, + userID string, + sessionID string, + delta map[string]int64, + eventMeta map[string]any, +) { + svc := e.runnerConfig.ArtifactService + if svc == nil { + return + } + + for name, version := range delta { + resp, err := svc.Load(ctx, &adkartifact.LoadRequest{ + AppName: e.appName, + UserID: userID, + SessionID: sessionID, + FileName: name, + Version: version, + }) + if err != nil { + e.logger.Error(err, "failed to load saved artifact", "name", name, "version", version) + continue + } + if resp == nil || resp.Part == nil { + e.logger.V(1).Info("artifact load returned no part", "name", name, "version", version) + continue + } + + part := resp.Part + // Carry the filename so the converted FilePart has a Name. + if part.InlineData != nil && part.InlineData.DisplayName == "" { + part.InlineData.DisplayName = name + } + + a2aPart, err := adka2a.ToA2APart(part, nil) + if err != nil { + e.logger.Error(err, "failed to convert artifact to A2A part", "name", name, "version", version) + continue + } + + artifactEvent := a2atype.NewArtifactEvent(reqCtx, a2aPart) + artifactEvent.LastChunk = true + artifactEvent.Metadata = maps.Clone(eventMeta) + artifactEvent.Metadata[adka2a.ToA2AMetaKey("artifact_name")] = name + artifactEvent.Metadata[adka2a.ToA2AMetaKey("artifact_version")] = version + if part.InlineData != nil { + artifactEvent.Metadata[adka2a.ToA2AMetaKey("mime_type")] = part.InlineData.MIMEType + } + + if err := queue.Write(ctx, artifactEvent); err != nil { + e.logger.Error(err, "failed to write artifact event", "name", name, "version", version) + } + } +} diff --git a/go/adk/pkg/a2a/artifacts_test.go b/go/adk/pkg/a2a/artifacts_test.go new file mode 100644 index 0000000000..1d169e974b --- /dev/null +++ b/go/adk/pkg/a2a/artifacts_test.go @@ -0,0 +1,220 @@ +package a2a + +import ( + "context" + "encoding/base64" + "testing" + + a2atype "github.com/a2aproject/a2a-go/a2a" + "github.com/a2aproject/a2a-go/a2asrv" + "github.com/go-logr/logr" + adkartifact "google.golang.org/adk/artifact" + "google.golang.org/adk/runner" + "google.golang.org/genai" +) + +// fakeQueue is a minimal eventqueue.Queue that records written events. +type fakeQueue struct { + events []a2atype.Event +} + +func (q *fakeQueue) Write(_ context.Context, event a2atype.Event) error { + q.events = append(q.events, event) + return nil +} + +func (q *fakeQueue) WriteVersioned(_ context.Context, event a2atype.Event, _ a2atype.TaskVersion) error { + q.events = append(q.events, event) + return nil +} + +func (q *fakeQueue) Read(_ context.Context) (a2atype.Event, a2atype.TaskVersion, error) { + return nil, a2atype.TaskVersionMissing, nil +} + +func (q *fakeQueue) Close() error { return nil } + +// --------------------------------------------------------------------------- +// maxArtifactBytes +// --------------------------------------------------------------------------- + +func TestMaxArtifactBytes(t *testing.T) { + tests := []struct { + name string + env string + want int + }{ + {name: "default", env: "", want: defaultMaxArtifactBytes}, + {name: "override", env: "1024", want: 1024}, + {name: "invalid falls back", env: "not-a-number", want: defaultMaxArtifactBytes}, + {name: "zero falls back", env: "0", want: defaultMaxArtifactBytes}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.env != "" { + t.Setenv(envMaxArtifactBytes, tt.env) + } + if got := MaxArtifactBytes(); got != tt.want { + t.Errorf("MaxArtifactBytes() = %d, want %d", got, tt.want) + } + }) + } +} + +// --------------------------------------------------------------------------- +// checkInboundFileSizes (AC2) +// --------------------------------------------------------------------------- + +func fileBytesMessage(name string, data []byte) *a2atype.Message { + return a2atype.NewMessage(a2atype.MessageRoleUser, a2atype.FilePart{ + File: a2atype.FileBytes{ + FileMeta: a2atype.FileMeta{Name: name, MimeType: "text/plain"}, + Bytes: base64.StdEncoding.EncodeToString(data), + }, + }) +} + +func TestCheckInboundFileSizes(t *testing.T) { + tests := []struct { + name string + msg *a2atype.Message + limit int + wantErr bool + }{ + {name: "nil message", msg: nil, limit: 10, wantErr: false}, + {name: "under limit", msg: fileBytesMessage("a.txt", []byte("hello")), limit: 10, wantErr: false}, + {name: "at limit", msg: fileBytesMessage("a.txt", []byte("12345")), limit: 5, wantErr: false}, + {name: "over limit", msg: fileBytesMessage("a.txt", []byte("123456")), limit: 5, wantErr: true}, + {name: "text only ignored", msg: a2atype.NewMessage(a2atype.MessageRoleUser, a2atype.TextPart{Text: "hi"}), limit: 1, wantErr: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := checkInboundFileSizes(tt.msg, tt.limit) + if (err != nil) != tt.wantErr { + t.Errorf("checkInboundFileSizes() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestBase64DecodedLen(t *testing.T) { + // Covers each padding case so the allocation-free size check stays exact + // for clean StdEncoding input (the format the UI emits). + for _, n := range []int{0, 1, 2, 3, 5, 10, 64, 1023, 1 << 20} { + data := make([]byte, n) + encoded := base64.StdEncoding.EncodeToString(data) + if got := base64DecodedLen(encoded); got != n { + t.Errorf("base64DecodedLen(%d-byte payload) = %d, want %d", n, got, n) + } + } +} + +func TestCheckInboundFileSizes_PointerPart(t *testing.T) { + msg := a2atype.NewMessage(a2atype.MessageRoleUser, &a2atype.FilePart{ + File: a2atype.FileBytes{ + FileMeta: a2atype.FileMeta{Name: "big.bin"}, + Bytes: base64.StdEncoding.EncodeToString([]byte("0123456789")), + }, + }) + if err := checkInboundFileSizes(msg, 5); err == nil { + t.Error("expected error for oversized pointer FilePart, got nil") + } +} + +// --------------------------------------------------------------------------- +// emitArtifacts (AC4) +// --------------------------------------------------------------------------- + +func TestEmitArtifacts_EmitsFilePart(t *testing.T) { + ctx := context.Background() + const ( + appName = "test-app" + userID = "user-1" + sessionID = "session-1" + fileName = "report.csv" + mimeType = "text/csv" + ) + wantBytes := []byte("a,b,c\n1,2,3\n") + + svc := adkartifact.InMemoryService() + saveResp, err := svc.Save(ctx, &adkartifact.SaveRequest{ + AppName: appName, + UserID: userID, + SessionID: sessionID, + FileName: fileName, + Part: &genai.Part{InlineData: &genai.Blob{Data: wantBytes, MIMEType: mimeType}}, + }) + if err != nil { + t.Fatalf("Save() error = %v", err) + } + + e := &KAgentExecutor{ + runnerConfig: runner.Config{ArtifactService: svc}, + appName: appName, + logger: logr.Discard(), + } + + reqCtx := &a2asrv.RequestContext{ + TaskID: a2atype.NewTaskID(), + ContextID: sessionID, + } + queue := &fakeQueue{} + + e.emitArtifacts(ctx, reqCtx, queue, userID, sessionID, + map[string]int64{fileName: saveResp.Version}, map[string]any{}) + + if len(queue.events) != 1 { + t.Fatalf("expected 1 artifact event, got %d", len(queue.events)) + } + artifactEvent, ok := queue.events[0].(*a2atype.TaskArtifactUpdateEvent) + if !ok { + t.Fatalf("event type = %T, want *a2atype.TaskArtifactUpdateEvent", queue.events[0]) + } + if !artifactEvent.LastChunk { + t.Error("expected LastChunk = true") + } + if len(artifactEvent.Artifact.Parts) != 1 { + t.Fatalf("expected 1 part, got %d", len(artifactEvent.Artifact.Parts)) + } + fp, ok := artifactEvent.Artifact.Parts[0].(a2atype.FilePart) + if !ok { + t.Fatalf("part type = %T, want a2atype.FilePart", artifactEvent.Artifact.Parts[0]) + } + fb, ok := fp.File.(a2atype.FileBytes) + if !ok { + t.Fatalf("file type = %T, want a2atype.FileBytes", fp.File) + } + if fb.Name != fileName { + t.Errorf("file name = %q, want %q", fb.Name, fileName) + } + if fb.MimeType != mimeType { + t.Errorf("mime type = %q, want %q", fb.MimeType, mimeType) + } + gotBytes, err := base64.StdEncoding.DecodeString(fb.Bytes) + if err != nil { + t.Fatalf("decode bytes: %v", err) + } + if string(gotBytes) != string(wantBytes) { + t.Errorf("bytes = %q, want %q", gotBytes, wantBytes) + } +} + +func TestEmitArtifacts_SkipsMissingArtifact(t *testing.T) { + ctx := context.Background() + svc := adkartifact.InMemoryService() + e := &KAgentExecutor{ + runnerConfig: runner.Config{ArtifactService: svc}, + appName: "test-app", + logger: logr.Discard(), + } + reqCtx := &a2asrv.RequestContext{TaskID: a2atype.NewTaskID(), ContextID: "s1"} + queue := &fakeQueue{} + + // Reference an artifact that was never saved → load fails → skip, no panic. + e.emitArtifacts(ctx, reqCtx, queue, "user-1", "s1", + map[string]int64{"missing.txt": 1}, map[string]any{}) + + if len(queue.events) != 0 { + t.Errorf("expected 0 events for missing artifact, got %d", len(queue.events)) + } +} diff --git a/go/adk/pkg/a2a/executor.go b/go/adk/pkg/a2a/executor.go index b9879988bb..0c4cfcee6f 100644 --- a/go/adk/pkg/a2a/executor.go +++ b/go/adk/pkg/a2a/executor.go @@ -183,7 +183,14 @@ func (e *KAgentExecutor) Execute(ctx context.Context, reqCtx *a2asrv.RequestCont inboundMessage = resumeMessage } - // 6. Convert inbound message to *genai.Content using kagent a2aPartConverter. + // 6. Guard inbound file size (defense in depth; the UI also enforces this), + // then convert inbound message to *genai.Content using kagent a2aPartConverter. + if err := checkInboundFileSizes(inboundMessage, MaxArtifactBytes()); err != nil { + errMsg := a2atype.NewMessage(a2atype.MessageRoleAgent, a2atype.TextPart{Text: err.Error()}) + failed := a2atype.NewStatusUpdateEvent(reqCtx, a2atype.TaskStateFailed, errMsg) + failed.Final = true + return queue.Write(ctx, failed) + } content, err := messageToGenAIContent(ctx, inboundMessage) if err != nil { return fmt.Errorf("inbound message conversion failed: %w", err) @@ -234,6 +241,9 @@ func (e *KAgentExecutor) Execute(ctx context.Context, reqCtx *a2asrv.RequestCont if e.stream { runConfig.StreamingMode = adkagent.StreamingModeSSE } + // Persist inbound user uploads as artifacts so tools/agents can reference + // them later (in addition to passing them inline to the model). + runConfig.SaveInputBlobsAsArtifacts = true // State tracked across the event loop. var ( @@ -261,6 +271,12 @@ func (e *KAgentExecutor) Execute(ctx context.Context, reqCtx *a2asrv.RequestCont // Build per-event metadata (inherits baseMeta + adds invocation_id, usage etc.). eventMeta := buildEventMeta(baseMeta, adkEvent) + // Surface artifacts the agent/tools saved during this event as A2A + // artifact events (FileParts). Errors are logged and skipped. + if len(adkEvent.Actions.ArtifactDelta) > 0 { + e.emitArtifacts(ctx, reqCtx, queue, userID, sessionID, adkEvent.Actions.ArtifactDelta, eventMeta) + } + // Convert GenAI parts → A2A parts (with kagent stamping). if adkEvent.Content == nil || len(adkEvent.Content.Parts) == 0 { // Events with no content carry metadata only; still track invocationID/usage. diff --git a/go/adk/pkg/a2a/executor_test.go b/go/adk/pkg/a2a/executor_test.go new file mode 100644 index 0000000000..d0bcf1f4b9 --- /dev/null +++ b/go/adk/pkg/a2a/executor_test.go @@ -0,0 +1,146 @@ +package a2a + +import ( + "context" + "encoding/base64" + "iter" + "testing" + + a2atype "github.com/a2aproject/a2a-go/a2a" + "github.com/a2aproject/a2a-go/a2asrv" + "github.com/go-logr/logr" + adkagent "google.golang.org/adk/agent" + adkartifact "google.golang.org/adk/artifact" + "google.golang.org/adk/runner" + adksession "google.golang.org/adk/session" +) + +// noopAgent returns an agent that emits no events (logic before agent run is +// what we exercise — e.g. SaveInputBlobsAsArtifacts). +func noopAgent(t *testing.T, name string) adkagent.Agent { + t.Helper() + a, err := adkagent.New(adkagent.Config{ + Name: name, + Run: func(_ adkagent.InvocationContext) iter.Seq2[*adksession.Event, error] { + return func(yield func(*adksession.Event, error) bool) {} + }, + }) + if err != nil { + t.Fatalf("agent.New() error = %v", err) + } + return a +} + +// TestExecute_PersistsInboundUploads verifies that an inbound file upload is +// persisted to the artifact service via SaveInputBlobsAsArtifacts (AC3). +func TestExecute_PersistsInboundUploads(t *testing.T) { + ctx := context.Background() + const ( + appName = "test-app" + contextID = "ctx-1" + ) + userID := "A2A_USER_" + contextID + sessionID := contextID + + sessionSvc := adksession.InMemoryService() + if _, err := sessionSvc.Create(ctx, &adksession.CreateRequest{ + AppName: appName, + UserID: userID, + SessionID: sessionID, + }); err != nil { + t.Fatalf("session create error = %v", err) + } + + artifactSvc := adkartifact.InMemoryService() + + e := NewKAgentExecutor(KAgentExecutorConfig{ + RunnerConfig: runner.Config{ + AppName: appName, + Agent: noopAgent(t, "test_agent"), + SessionService: sessionSvc, + ArtifactService: artifactSvc, + }, + AppName: appName, + Logger: logr.Discard(), + }) + + msg := a2atype.NewMessage(a2atype.MessageRoleUser, + a2atype.TextPart{Text: "here is a file"}, + a2atype.FilePart{File: a2atype.FileBytes{ + FileMeta: a2atype.FileMeta{Name: "note.txt", MimeType: "text/plain"}, + Bytes: base64.StdEncoding.EncodeToString([]byte("hello world")), + }}, + ) + reqCtx := &a2asrv.RequestContext{ + Message: msg, + TaskID: a2atype.NewTaskID(), + ContextID: contextID, + } + + if err := e.Execute(ctx, reqCtx, &fakeQueue{}); err != nil { + t.Fatalf("Execute() error = %v", err) + } + + listResp, err := artifactSvc.List(ctx, &adkartifact.ListRequest{ + AppName: appName, + UserID: userID, + SessionID: sessionID, + }) + if err != nil { + t.Fatalf("artifact List() error = %v", err) + } + if len(listResp.FileNames) != 1 { + t.Fatalf("expected 1 persisted artifact, got %d (%v)", len(listResp.FileNames), listResp.FileNames) + } +} + +// TestExecute_RejectsOversizedUpload verifies the server-side size guard fails +// the task for an oversized inbound file (AC2). +func TestExecute_RejectsOversizedUpload(t *testing.T) { + ctx := context.Background() + const ( + appName = "test-app" + contextID = "ctx-2" + ) + t.Setenv(envMaxArtifactBytes, "8") + + e := NewKAgentExecutor(KAgentExecutorConfig{ + RunnerConfig: runner.Config{ + AppName: appName, + Agent: noopAgent(t, "test_agent"), + SessionService: adksession.InMemoryService(), + ArtifactService: adkartifact.InMemoryService(), + }, + AppName: appName, + Logger: logr.Discard(), + }) + + msg := a2atype.NewMessage(a2atype.MessageRoleUser, + a2atype.FilePart{File: a2atype.FileBytes{ + FileMeta: a2atype.FileMeta{Name: "big.txt"}, + Bytes: base64.StdEncoding.EncodeToString([]byte("way too many bytes")), + }}, + ) + reqCtx := &a2asrv.RequestContext{ + Message: msg, + TaskID: a2atype.NewTaskID(), + ContextID: contextID, + } + + queue := &fakeQueue{} + if err := e.Execute(ctx, reqCtx, queue); err != nil { + t.Fatalf("Execute() unexpected error = %v", err) + } + + // The guard should emit a single failed status update. + if len(queue.events) != 1 { + t.Fatalf("expected 1 event, got %d", len(queue.events)) + } + statusEvent, ok := queue.events[0].(*a2atype.TaskStatusUpdateEvent) + if !ok { + t.Fatalf("event type = %T, want *a2atype.TaskStatusUpdateEvent", queue.events[0]) + } + if statusEvent.Status.State != a2atype.TaskStateFailed { + t.Errorf("state = %q, want %q", statusEvent.Status.State, a2atype.TaskStateFailed) + } +} diff --git a/go/adk/pkg/agent/agent.go b/go/adk/pkg/agent/agent.go index fa9d633d14..e9dc43804b 100644 --- a/go/adk/pkg/agent/agent.go +++ b/go/adk/pkg/agent/agent.go @@ -18,6 +18,7 @@ import ( adkmodel "google.golang.org/adk/model" adkgemini "google.golang.org/adk/model/gemini" "google.golang.org/adk/tool" + "google.golang.org/adk/tool/loadartifactstool" "google.golang.org/adk/tool/loadmemorytool" "google.golang.org/adk/tool/preloadmemorytool" "google.golang.org/genai" @@ -167,6 +168,18 @@ func buildAgentTools(agentConfig *adk.AgentConfig, remoteAgentTools, extraTools localTools = append(localTools, remoteAgentTools...) localTools = append(localTools, extraTools...) + // Register the built-in load_artifacts tool so the LLM can list/load stored + // artifacts (uploaded files and agent-produced files) across turns. + localTools = append(localTools, loadartifactstool.New()) + + // Register the save_artifact tool so the LLM can produce downloadable files + // from chat; the executor surfaces saved artifacts as A2A file parts. + saveArtifactTool, err := tools.NewSaveArtifactTool() + if err != nil { + return nil, fmt.Errorf("failed to create save_artifact tool: %w", err) + } + localTools = append(localTools, saveArtifactTool) + skillsDirectory := strings.TrimSpace(os.Getenv("KAGENT_SKILLS_FOLDER")) if skillsDirectory != "" { skillsTools, err := tools.NewSkillsTools(skillsDirectory) diff --git a/go/adk/pkg/agent/agent_test.go b/go/adk/pkg/agent/agent_test.go index 4f4c5fb45f..c01f994ff1 100644 --- a/go/adk/pkg/agent/agent_test.go +++ b/go/adk/pkg/agent/agent_test.go @@ -322,6 +322,62 @@ Use the script in scripts/convert.py. } } +// TestBuildAgentTools_RegistersLoadArtifactsTool verifies that every agent's +// tool list includes the built-in load_artifacts tool (AC5). +func TestBuildAgentTools_RegistersLoadArtifactsTool(t *testing.T) { + tests := []struct { + name string + config *adk.AgentConfig + }{ + {name: "minimal", config: &adk.AgentConfig{}}, + {name: "with memory", config: &adk.AgentConfig{Memory: &adk.MemoryConfig{TTLDays: 1}}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Setenv("KAGENT_SRT_SETTINGS_PATH", filepath.Join(t.TempDir(), "srt-settings.json")) + + tools, err := buildAgentTools(tt.config, nil, nil, logr.Discard()) + if err != nil { + t.Fatalf("buildAgentTools() error = %v", err) + } + + found := false + for _, tool := range tools { + if tool.Name() == "load_artifacts" { + found = true + break + } + } + if !found { + t.Error("expected load_artifacts tool to be registered") + } + }) + } +} + +// TestBuildAgentTools_RegistersSaveArtifactTool verifies that every agent's tool +// list includes the save_artifact tool so agents can produce files from chat. +func TestBuildAgentTools_RegistersSaveArtifactTool(t *testing.T) { + t.Setenv("KAGENT_SRT_SETTINGS_PATH", filepath.Join(t.TempDir(), "srt-settings.json")) + + tools, err := buildAgentTools(&adk.AgentConfig{}, nil, nil, logr.Discard()) + if err != nil { + t.Fatalf("buildAgentTools() error = %v", err) + } + + found := false + for _, tool := range tools { + if tool.Name() == "save_artifact" { + found = true + break + } + } + if !found { + t.Error("expected save_artifact tool to be registered") + } +} + // TestAgentConfigFieldUsage is a smoke test that ensures AgentConfig structures // used by agents exercise all relevant fields. This test acts as a canary: if a // new field is added to AgentConfig but not reflected in this test configuration, diff --git a/go/adk/pkg/agent/createllm_test.go b/go/adk/pkg/agent/createllm_test.go index 9330aa1c71..fb50194cc1 100644 --- a/go/adk/pkg/agent/createllm_test.go +++ b/go/adk/pkg/agent/createllm_test.go @@ -13,6 +13,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" adkagent "google.golang.org/adk/agent" + adkartifact "google.golang.org/adk/artifact" "google.golang.org/adk/runner" adksession "google.golang.org/adk/session" "google.golang.org/genai" @@ -57,9 +58,10 @@ func runAgent(t *testing.T, agentCfg *adk.AgentConfig, prompt string) string { sessionService := adksession.InMemoryService() r, err := runner.New(runner.Config{ - AppName: "test", - Agent: adkAgent, - SessionService: sessionService, + AppName: "test", + Agent: adkAgent, + SessionService: sessionService, + ArtifactService: adkartifact.InMemoryService(), }) require.NoError(t, err) diff --git a/go/adk/pkg/fileextract/fileextract.go b/go/adk/pkg/fileextract/fileextract.go new file mode 100644 index 0000000000..7b701e09d4 --- /dev/null +++ b/go/adk/pkg/fileextract/fileextract.go @@ -0,0 +1,152 @@ +// Package fileextract turns uploaded file blobs into text the model can read. +// +// Rich documents (PDF, DOCX, XLSX, PPTX, EPUB, HTML) are extracted to +// text/markdown via tabula; PDFs additionally use a ToUnicode-aware fallback +// for Type3 fonts and malformed streams (see pdf.go). Text-like files are +// returned as-is. This mirrors the Python ADK runtime behaviour so non-image +// uploads reach the model instead of being dropped. +package fileextract + +import ( + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/tsawler/tabula" + "google.golang.org/genai" +) + +// docMIMEToExt maps rich document MIME types to the file extension used for +// format detection. These are extracted to text/markdown so the model can read +// their contents. The set is kept in sync with the Python runtime +// (_file_extract.py): the common formats both tabula (Go) and markitdown +// (Python) support — PDF, DOCX, XLSX, PPTX, HTML, EPUB. +var docMIMEToExt = map[string]string{ + "application/pdf": ".pdf", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx", + "application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx", + "application/epub+zip": ".epub", + "text/html": ".html", +} + +// docExtractExts is the set of filename extensions that can be extracted. +var docExtractExts = map[string]bool{ + ".pdf": true, + ".docx": true, + ".xlsx": true, + ".pptx": true, + ".epub": true, + ".html": true, + ".htm": true, +} + +// normalizeMIME strips parameters (e.g. "; charset=utf-8") and lowercases. +func normalizeMIME(mimeType string) string { + mimeType = strings.ToLower(strings.TrimSpace(mimeType)) + if i := strings.IndexByte(mimeType, ';'); i >= 0 { + mimeType = strings.TrimSpace(mimeType[:i]) + } + return mimeType +} + +// isTextLikeMIME reports whether a MIME type can be inlined as raw text. +// text/html is excluded so it is routed through tabula extraction instead. +func isTextLikeMIME(mimeType string) bool { + mimeType = normalizeMIME(mimeType) + if mimeType == "text/html" { + return false + } + if strings.HasPrefix(mimeType, "text/") { + return true + } + switch mimeType { + case "application/json", "application/xml", "application/x-ndjson", + "application/yaml", "application/x-yaml": + return true + } + return false +} + +// docExtractExt returns the tabula extension to use for a rich document, derived +// from the filename first and the MIME type as a fallback. Empty if the file is +// not a tabula-extractable document. +func docExtractExt(mimeType, name string) string { + if ext := strings.ToLower(filepath.Ext(name)); docExtractExts[ext] { + if ext == ".htm" { + return ".html" + } + return ext + } + if ext, ok := docMIMEToExt[normalizeMIME(mimeType)]; ok { + return ext + } + return "" +} + +// extractFileText turns an uploaded file's bytes into text the model can read. +// Rich documents (PDF, Office, EPUB, HTML) are extracted to markdown via +// tabula; text-like files are returned as-is. Returns an error for formats that +// cannot be represented as text (e.g. arbitrary binary). +func extractFileText(data []byte, mimeType, name string) (string, error) { + if ext := docExtractExt(mimeType, name); ext != "" { + return extractDocText(data, ext) + } + if isTextLikeMIME(mimeType) { + return string(data), nil + } + return "", fmt.Errorf("unsupported file type for text extraction: mime=%q name=%q", mimeType, name) +} + +// extractDocText writes the bytes to a temp file (tabula detects format by +// extension) and extracts markdown. PDFs are routed through extractPDF, which +// handles Type3 fonts and malformed streams that tabula's markdown path can't. +func extractDocText(data []byte, ext string) (string, error) { + tmp, err := os.CreateTemp("", "kagent-artifact-*"+ext) + if err != nil { + return "", fmt.Errorf("failed to create temp file for extraction: %w", err) + } + tmpName := tmp.Name() + defer os.Remove(tmpName) + + if _, err := tmp.Write(data); err != nil { + tmp.Close() + return "", fmt.Errorf("failed to write temp file for extraction: %w", err) + } + if err := tmp.Close(); err != nil { + return "", fmt.Errorf("failed to close temp file for extraction: %w", err) + } + + if ext == ".pdf" { + return extractPDF(tmpName) + } + + text, _, err := tabula.Open(tmpName).ToMarkdown() + if err != nil { + return "", fmt.Errorf("failed to extract text from %s document: %w", ext, err) + } + return text, nil +} + +// InlineFileToText converts a non-image inline file blob into text suitable for +// inclusion in a chat message. On extraction failure it returns a short note so +// the model can tell the user the file could not be read, instead of silently +// dropping it. Returns "" for a nil/empty blob. +func InlineFileToText(blob *genai.Blob) string { + if blob == nil { + return "" + } + name := blob.DisplayName + if name == "" { + name = "file" + } + text, err := extractFileText(blob.Data, blob.MIMEType, name) + if err != nil { + return fmt.Sprintf("[Uploaded file %q (%s) could not be read as text.]", name, blob.MIMEType) + } + if strings.TrimSpace(text) == "" { + return fmt.Sprintf("[Uploaded file %q (%s) contained no extractable text.]", name, blob.MIMEType) + } + return fmt.Sprintf("Contents of uploaded file %q:\n\n%s", name, text) +} diff --git a/go/adk/pkg/fileextract/fileextract_test.go b/go/adk/pkg/fileextract/fileextract_test.go new file mode 100644 index 0000000000..f4bc37e5f9 --- /dev/null +++ b/go/adk/pkg/fileextract/fileextract_test.go @@ -0,0 +1,143 @@ +package fileextract + +import ( + "strings" + "testing" + + "google.golang.org/genai" +) + +func TestIsTextLikeMIME(t *testing.T) { + tests := []struct { + name string + mime string + want bool + }{ + {name: "plain text", mime: "text/plain", want: true}, + {name: "markdown", mime: "text/markdown", want: true}, + {name: "csv", mime: "text/csv", want: true}, + {name: "json", mime: "application/json", want: true}, + {name: "text with charset", mime: "text/plain; charset=utf-8", want: true}, + {name: "html excluded", mime: "text/html", want: false}, + {name: "pdf not text", mime: "application/pdf", want: false}, + {name: "image not text", mime: "image/png", want: false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isTextLikeMIME(tt.mime); got != tt.want { + t.Errorf("isTextLikeMIME(%q) = %v, want %v", tt.mime, got, tt.want) + } + }) + } +} + +func TestDocExtractExt(t *testing.T) { + tests := []struct { + name string + mime string + file string + want string + }{ + {name: "pdf by mime", mime: "application/pdf", file: "report", want: ".pdf"}, + {name: "pdf by extension", mime: "application/octet-stream", file: "report.pdf", want: ".pdf"}, + {name: "docx by mime", mime: "application/vnd.openxmlformats-officedocument.wordprocessingml.document", file: "x", want: ".docx"}, + {name: "htm normalized to html", mime: "", file: "page.htm", want: ".html"}, + {name: "text not a doc", mime: "text/plain", file: "a.txt", want: ""}, + {name: "unknown", mime: "application/zip", file: "a.zip", want: ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := docExtractExt(tt.mime, tt.file); got != tt.want { + t.Errorf("docExtractExt(%q, %q) = %q, want %q", tt.mime, tt.file, got, tt.want) + } + }) + } +} + +func TestExtractFileText(t *testing.T) { + tests := []struct { + name string + data []byte + mime string + file string + wantText string + wantErr bool + }{ + {name: "plain text returned as-is", data: []byte("hello world"), mime: "text/plain", file: "a.txt", wantText: "hello world"}, + {name: "json returned as-is", data: []byte(`{"k":"v"}`), mime: "application/json", file: "a.json", wantText: `{"k":"v"}`}, + {name: "csv returned as-is", data: []byte("a,b\n1,2"), mime: "text/csv", file: "a.csv", wantText: "a,b\n1,2"}, + {name: "unsupported binary errors", data: []byte{0x00, 0x01, 0x02}, mime: "application/zip", file: "a.zip", wantErr: true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := extractFileText(tt.data, tt.mime, tt.file) + if (err != nil) != tt.wantErr { + t.Fatalf("extractFileText() error = %v, wantErr %v", err, tt.wantErr) + } + if !tt.wantErr && got != tt.wantText { + t.Errorf("extractFileText() = %q, want %q", got, tt.wantText) + } + }) + } +} + +// TestExtractFileText_HTML exercises the real tabula extraction path with a +// lightweight HTML document (no binary fixture needed). +func TestExtractFileText_HTML(t *testing.T) { + html := []byte(`

Invoice

Total: 42 USD

`) + got, err := extractFileText(html, "text/html", "invoice.html") + if err != nil { + t.Fatalf("extractFileText() error = %v", err) + } + if !strings.Contains(got, "Invoice") || !strings.Contains(got, "42 USD") { + t.Errorf("extracted text missing expected content: %q", got) + } +} + +// TestExtractFileText_Type3PDF verifies the ToUnicode-based decoding for a PDF +// whose text is drawn with a Type3 font (which tabula's markdown path decodes to +// raw character codes). The fixture maps code 0x01->'H' and 0x02->'i' via a +// ToUnicode CMap and draws <0102>, so correct decoding yields "Hi". +func TestExtractFileText_Type3PDF(t *testing.T) { + got, err := extractFileText(type3PDFFixture(), "application/pdf", "type3.pdf") + if err != nil { + t.Fatalf("extractFileText() error = %v", err) + } + if !strings.Contains(got, "Hi") { + t.Errorf("Type3 PDF did not decode via ToUnicode; got %q", got) + } +} + +func TestInlineFileToText(t *testing.T) { + tests := []struct { + name string + blob *genai.Blob + contains string + }{ + {name: "nil blob", blob: nil, contains: ""}, + { + name: "text file labeled with name", + blob: &genai.Blob{Data: []byte("line1\nline2"), MIMEType: "text/plain", DisplayName: "notes.txt"}, + contains: `Contents of uploaded file "notes.txt"`, + }, + { + name: "unsupported binary returns note", + blob: &genai.Blob{Data: []byte{0x00, 0x01}, MIMEType: "application/zip", DisplayName: "a.zip"}, + contains: "could not be read as text", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := InlineFileToText(tt.blob) + if tt.contains == "" { + if got != "" { + t.Errorf("InlineFileToText() = %q, want empty", got) + } + return + } + if !strings.Contains(got, tt.contains) { + t.Errorf("InlineFileToText() = %q, want containing %q", got, tt.contains) + } + }) + } +} diff --git a/go/adk/pkg/fileextract/fixture_test.go b/go/adk/pkg/fileextract/fixture_test.go new file mode 100644 index 0000000000..05aa34d8b8 --- /dev/null +++ b/go/adk/pkg/fileextract/fixture_test.go @@ -0,0 +1,71 @@ +package fileextract + +import ( + "bytes" + "fmt" +) + +// type3PDFFixture builds a minimal, valid PDF whose single page draws text with +// a Type3 font. The font has no embedded glyphs that map to Unicode by encoding; +// the only correct source of text is its ToUnicode CMap (code 0x01->'H', +// 0x02->'i'). The content stream shows <0102>, so a ToUnicode-aware extractor +// yields "Hi" while a naive one yields the raw code bytes. +func type3PDFFixture() []byte { + objects := []string{ + // 1: Catalog + "<< /Type /Catalog /Pages 2 0 R >>", + // 2: Pages + "<< /Type /Pages /Kids [3 0 R] /Count 1 >>", + // 3: Page + "<< /Type /Page /Parent 2 0 R /MediaBox [0 0 300 200] " + + "/Resources << /Font << /F1 4 0 R >> >> /Contents 5 0 R >>", + // 4: Type3 font + "<< /Type /Font /Subtype /Type3 /FontBBox [0 0 750 750] " + + "/FontMatrix [0.001 0 0 0.001 0 0] /FirstChar 1 /LastChar 2 " + + "/Widths [500 500] /CharProcs << /a1 6 0 R /a2 7 0 R >> " + + "/Encoding << /Type /Encoding /Differences [1 /a1 /a2] >> " + + "/ToUnicode 8 0 R >>", + // 5: page content stream + streamObject("BT\n/F1 12 Tf\n10 100 Td\n<0102> Tj\nET\n"), + // 6: CharProc for code 1 + streamObject("500 0 0 0 500 500 d1\n"), + // 7: CharProc for code 2 + streamObject("500 0 0 0 500 500 d1\n"), + // 8: ToUnicode CMap + streamObject(toUnicodeCMap()), + } + + var buf bytes.Buffer + buf.WriteString("%PDF-1.4\n%\xe2\xe3\xcf\xd3\n") + + offsets := make([]int, len(objects)+1) + for i, body := range objects { + offsets[i+1] = buf.Len() + fmt.Fprintf(&buf, "%d 0 obj\n%s\nendobj\n", i+1, body) + } + + xrefStart := buf.Len() + fmt.Fprintf(&buf, "xref\n0 %d\n", len(objects)+1) + buf.WriteString("0000000000 65535 f \n") + for i := 1; i <= len(objects); i++ { + fmt.Fprintf(&buf, "%010d 00000 n \n", offsets[i]) + } + fmt.Fprintf(&buf, "trailer\n<< /Size %d /Root 1 0 R >>\nstartxref\n%d\n%%%%EOF\n", + len(objects)+1, xrefStart) + + return buf.Bytes() +} + +func streamObject(content string) string { + return fmt.Sprintf("<< /Length %d >>\nstream\n%s\nendstream", len(content), content) +} + +func toUnicodeCMap() string { + return "/CIDInit /ProcSet findresource begin\n" + + "12 dict begin\nbegincmap\n" + + "/CMapName /Adobe-Identity-UCS def\n" + + "/CMapType 2 def\n" + + "1 begincodespacerange\n<00> \nendcodespacerange\n" + + "2 beginbfchar\n<01> <0048>\n<02> <0069>\nendbfchar\n" + + "endcmap\nCMapName currentdict /CMap defineresource pop\nend\nend\n" +} diff --git a/go/adk/pkg/fileextract/pdf.go b/go/adk/pkg/fileextract/pdf.go new file mode 100644 index 0000000000..5be3dede3b --- /dev/null +++ b/go/adk/pkg/fileextract/pdf.go @@ -0,0 +1,196 @@ +package fileextract + +import ( + "fmt" + "strings" + + "github.com/tsawler/tabula" + "github.com/tsawler/tabula/core" + "github.com/tsawler/tabula/font" + "github.com/tsawler/tabula/reader" + "github.com/tsawler/tabula/text" +) + +// extractPDF extracts text from a PDF file. +// +// Tabula's markdown path produces the richest output for well-behaved PDFs, but +// it has two gaps: (1) it never applies a font's ToUnicode CMap for Type3 fonts, +// so text drawn with Type3 fonts decodes to raw character codes — garbled output +// like "4; HE. HKI J" instead of "Zero Trust"; and (2) it can fail outright on +// some malformed content streams. For PDFs that use Type3 fonts, and as a +// fallback when tabula's markdown extraction fails, a tolerant per-page +// extractor (extractPDFText) is used instead. +func extractPDF(path string) (string, error) { + if !pdfUsesType3Fonts(path) { + if md, _, err := tabula.Open(path).ToMarkdown(); err == nil && strings.TrimSpace(md) != "" { + return md, nil + } + // tabula failed or returned nothing — fall back to the per-page + // extractor, which tolerates streams tabula can't parse. + } + + text, err := extractPDFText(path) + if err != nil { + return "", fmt.Errorf("failed to extract text from pdf document: %w", err) + } + if strings.TrimSpace(text) == "" { + return "", fmt.Errorf("failed to extract text from pdf document: no extractable text") + } + return text, nil +} + +// extractPDFText extracts text page by page using tabula's lower-level reader +// and text extractor, additionally registering Type3 fonts with their ToUnicode +// CMaps (which tabula's high-level path skips). Pages or content streams that +// fail to parse are skipped rather than aborting the whole document. +func extractPDFText(path string) (string, error) { + r, err := reader.Open(path) + if err != nil { + return "", fmt.Errorf("failed to open PDF: %w", err) + } + defer r.Close() + + pageCount, err := r.PageCount() + if err != nil { + return "", fmt.Errorf("failed to read PDF page count: %w", err) + } + + resolver := func(ref core.IndirectRef) (core.Object, error) { + return r.ResolveReference(ref) + } + + var sb strings.Builder + for i := range pageCount { + page, err := r.GetPage(i) + if err != nil || page == nil { + continue + } + + ex := text.NewExtractor() + resources, _ := page.Resources() + if resources != nil { + // Register the font subtypes tabula handles natively, then fill + // the Type3 gap so those fonts decode via their ToUnicode CMaps. + _ = ex.RegisterFontsFromResources(resources, resolver) + registerType3Fonts(ex, resources, resolver) + ex.SetResourceContext(resources, resolver) + } + + contents, _ := page.Contents() + var data []byte + for _, c := range contents { + obj, _ := resolveObject(c, resolver) + stream, ok := obj.(*core.Stream) + if !ok { + continue + } + decoded, err := stream.Decode() + if err != nil { + continue + } + data = append(data, decoded...) + data = append(data, '\n') + } + if len(data) == 0 { + continue + } + if _, err := ex.ExtractFromBytes(data); err != nil { + continue + } + sb.WriteString(ex.GetText()) + sb.WriteString("\n\n") + } + + return sb.String(), nil +} + +// registerType3Fonts registers every Type3 font in a resources dictionary with +// its ToUnicode CMap so it decodes to the correct Unicode text. +func registerType3Fonts(ex *text.Extractor, resources core.Dict, resolver func(core.IndirectRef) (core.Object, error)) { + fontDictObj, _ := resolveObject(resources.Get("Font"), resolver) + fonts, ok := fontDictObj.(core.Dict) + if !ok { + return + } + + for name, fontObj := range fonts { + resolved, _ := resolveObject(fontObj, resolver) + fontDict, ok := resolved.(core.Dict) + if !ok { + continue + } + if subtype, _ := fontDict.GetName("Subtype"); string(subtype) != "Type3" { + continue + } + + f := font.NewFont(name, "", "Type3") + if tu := fontDict.Get("ToUnicode"); tu != nil { + if obj, _ := resolveObject(tu, resolver); obj != nil { + if stream, ok := obj.(*core.Stream); ok { + if cmap, err := font.ParseToUnicodeCMap(stream); err == nil { + f.ToUnicodeCMap = cmap + } + } + } + } + + ex.RegisterParsedFont(name, f) + if !strings.HasPrefix(name, "/") { + ex.RegisterParsedFont("/"+name, f) + } + } +} + +// pdfUsesType3Fonts reports whether any page resource references a Type3 font. +// It scans font dictionaries only (no text extraction), so it is cheap. +func pdfUsesType3Fonts(path string) bool { + r, err := reader.Open(path) + if err != nil { + return false + } + defer r.Close() + + pageCount, err := r.PageCount() + if err != nil { + return false + } + resolver := func(ref core.IndirectRef) (core.Object, error) { + return r.ResolveReference(ref) + } + + for i := range pageCount { + page, err := r.GetPage(i) + if err != nil || page == nil { + continue + } + resources, _ := page.Resources() + if resources == nil { + continue + } + fontDictObj, _ := resolveObject(resources.Get("Font"), resolver) + fonts, ok := fontDictObj.(core.Dict) + if !ok { + continue + } + for _, fontObj := range fonts { + resolved, _ := resolveObject(fontObj, resolver) + fontDict, ok := resolved.(core.Dict) + if !ok { + continue + } + if subtype, _ := fontDict.GetName("Subtype"); string(subtype) == "Type3" { + return true + } + } + } + return false +} + +// resolveObject dereferences an indirect reference, returning other objects +// unchanged. +func resolveObject(obj core.Object, resolver func(core.IndirectRef) (core.Object, error)) (core.Object, error) { + if ref, ok := obj.(core.IndirectRef); ok { + return resolver(ref) + } + return obj, nil +} diff --git a/go/adk/pkg/models/openai_adk.go b/go/adk/pkg/models/openai_adk.go index d1c0ed539c..98430629ba 100644 --- a/go/adk/pkg/models/openai_adk.go +++ b/go/adk/pkg/models/openai_adk.go @@ -11,6 +11,7 @@ import ( "slices" "strings" + "github.com/kagent-dev/kagent/go/adk/pkg/fileextract" "github.com/kagent-dev/kagent/go/adk/pkg/telemetry" "github.com/openai/openai-go/v3" "github.com/openai/openai-go/v3/packages/param" @@ -237,10 +238,14 @@ func genaiContentsToOpenAIMessages(contents []*genai.Content, config *genai.Gene textParts = append(textParts, part.Text) } else if part.FunctionCall != nil { functionCalls = append(functionCalls, part.FunctionCall) - } else if part.InlineData != nil && strings.HasPrefix(part.InlineData.MIMEType, "image/") { - imageParts = append(imageParts, openai.ChatCompletionContentPartImageImageURLParam{ - URL: fmt.Sprintf("data:%s;base64,%s", part.InlineData.MIMEType, base64.StdEncoding.EncodeToString(part.InlineData.Data)), - }) + } else if part.InlineData != nil { + if strings.HasPrefix(part.InlineData.MIMEType, "image/") { + imageParts = append(imageParts, openai.ChatCompletionContentPartImageImageURLParam{ + URL: fmt.Sprintf("data:%s;base64,%s", part.InlineData.MIMEType, base64.StdEncoding.EncodeToString(part.InlineData.Data)), + }) + } else if text := fileextract.InlineFileToText(part.InlineData); text != "" { + textParts = append(textParts, text) + } } } diff --git a/go/adk/pkg/models/openai_adk_test.go b/go/adk/pkg/models/openai_adk_test.go index 05de50f733..450766447a 100644 --- a/go/adk/pkg/models/openai_adk_test.go +++ b/go/adk/pkg/models/openai_adk_test.go @@ -3,6 +3,7 @@ package models import ( "encoding/base64" "encoding/json" + "strings" "testing" "github.com/openai/openai-go/v3" @@ -222,6 +223,35 @@ func TestGenaiContentsToOpenAIMessages(t *testing.T) { }) } +func TestGenaiContentsToOpenAIMessages_NonImageFileBecomesText(t *testing.T) { + contents := []*genai.Content{{ + Role: string(genai.RoleUser), + Parts: []*genai.Part{ + {Text: "explain attached invoice"}, + {InlineData: &genai.Blob{ + Data: []byte("vendor,amount\nAcme,42"), + MIMEType: "text/csv", + DisplayName: "invoice.csv", + }}, + }, + }} + + msgs, _ := genaiContentsToOpenAIMessages(contents, nil) + if len(msgs) != 1 { + t.Fatalf("len(messages) = %d, want 1", len(msgs)) + } + if msgs[0].OfUser == nil { + t.Fatalf("expected a user message, got %+v", msgs[0]) + } + content := msgs[0].OfUser.Content.OfString.Value + if !strings.Contains(content, "explain attached invoice") { + t.Errorf("message missing user text, got %q", content) + } + if !strings.Contains(content, "invoice.csv") || !strings.Contains(content, "Acme,42") { + t.Errorf("message missing extracted file contents, got %q", content) + } +} + func TestApplyOpenAIConfig(t *testing.T) { t.Run("nil config no panic", func(t *testing.T) { var params openai.ChatCompletionNewParams diff --git a/go/adk/pkg/runner/adapter.go b/go/adk/pkg/runner/adapter.go index 0441f778c0..11b9381a67 100644 --- a/go/adk/pkg/runner/adapter.go +++ b/go/adk/pkg/runner/adapter.go @@ -12,6 +12,7 @@ import ( "github.com/kagent-dev/kagent/go/adk/pkg/session" "github.com/kagent-dev/kagent/go/adk/pkg/sts" "github.com/kagent-dev/kagent/go/api/adk" + adkartifact "google.golang.org/adk/artifact" adkmemory "google.golang.org/adk/memory" adkplugin "google.golang.org/adk/plugin" "google.golang.org/adk/runner" @@ -84,10 +85,11 @@ func CreateRunnerConfig( } cfg := runner.Config{ - AppName: appName, - Agent: adkAgent, - SessionService: adkSessionService, - MemoryService: runnerMemory, + AppName: appName, + Agent: adkAgent, + SessionService: adkSessionService, + MemoryService: runnerMemory, + ArtifactService: adkartifact.InMemoryService(), PluginConfig: runner.PluginConfig{ Plugins: adkPlugins, }, diff --git a/go/adk/pkg/runner/adapter_test.go b/go/adk/pkg/runner/adapter_test.go new file mode 100644 index 0000000000..30e51d1a8b --- /dev/null +++ b/go/adk/pkg/runner/adapter_test.go @@ -0,0 +1,44 @@ +package runner + +import ( + "testing" + + "github.com/go-logr/logr" + "github.com/kagent-dev/kagent/go/api/adk" +) + +// TestCreateRunnerConfig_WiresArtifactService verifies that CreateRunnerConfig +// sets a non-nil ArtifactService so agents/tools get a working ctx.Artifacts() +// (AC1). +func TestCreateRunnerConfig_WiresArtifactService(t *testing.T) { + tests := []struct { + name string + appName string + }{ + {name: "named app", appName: "my-app"}, + {name: "default app", appName: ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Setenv("OPENAI_API_KEY", "test-key") + ctx := logr.NewContext(t.Context(), logr.Discard()) + agentConfig := &adk.AgentConfig{ + Model: &adk.OpenAI{ + BaseModel: adk.BaseModel{Type: "openai", Model: "gpt-4o-mini"}, + BaseUrl: "https://api.openai.com/v1", + }, + Description: "test agent", + Instruction: "you are helpful", + } + + cfg, _, err := CreateRunnerConfig(ctx, agentConfig, nil, tt.appName, nil) + if err != nil { + t.Fatalf("CreateRunnerConfig() error = %v", err) + } + if cfg.ArtifactService == nil { + t.Fatal("CreateRunnerConfig() ArtifactService = nil, want non-nil") + } + }) + } +} diff --git a/go/adk/pkg/tools/save_artifact_tool.go b/go/adk/pkg/tools/save_artifact_tool.go new file mode 100644 index 0000000000..2bb2873387 --- /dev/null +++ b/go/adk/pkg/tools/save_artifact_tool.go @@ -0,0 +1,111 @@ +package tools + +import ( + "context" + "encoding/base64" + "fmt" + "strings" + + "github.com/kagent-dev/kagent/go/adk/pkg/a2a" + "google.golang.org/adk/agent" + "google.golang.org/adk/tool" + "google.golang.org/adk/tool/functiontool" + "google.golang.org/genai" +) + +// saveArtifactInput is the LLM-facing argument schema for the save_artifact tool. +type saveArtifactInput struct { + Name string `json:"name"` + Content string `json:"content"` + MimeType string `json:"mime_type,omitempty"` + Base64 bool `json:"base64,omitempty"` +} + +// NewSaveArtifactTool creates the save_artifact tool, letting an agent persist +// content as a downloadable file artifact in the current session. The artifact +// is auto-surfaced to the client as an A2A file part by the executor. +func NewSaveArtifactTool() (tool.Tool, error) { + return functiontool.New(functiontool.Config{ + Name: "save_artifact", + Description: "Saves content as a downloadable file artifact in the current session so the " + + "user receives it as a file attachment. Provide a file name (e.g. \"report.csv\"), the " + + "file content, and optionally a MIME type (defaults to text/plain). For binary content, " + + "base64-encode it and set base64=true.", + }, func(toolCtx agent.ToolContext, in saveArtifactInput) (map[string]any, error) { + return saveArtifact(toolCtx, toolCtx.Artifacts(), in, a2a.MaxArtifactBytes()) + }) +} + +// saveArtifact holds the testable core of the save_artifact tool: it validates +// the input, decodes the content, enforces the size limit, and stores the +// artifact as inline data so it round-trips to the UI as a file part. +func saveArtifact(ctx context.Context, artifacts agent.Artifacts, in saveArtifactInput, limit int) (map[string]any, error) { + if artifacts == nil { + return nil, fmt.Errorf("artifact service is not available") + } + + name := strings.TrimSpace(in.Name) + if name == "" { + return nil, fmt.Errorf("missing required parameter: name") + } + if strings.ContainsAny(name, `/\`) { + return nil, fmt.Errorf("invalid name %q: must not contain path separators", name) + } + + var data []byte + if in.Base64 { + decoded, err := decodeFlexibleBase64(in.Content) + if err != nil { + return nil, fmt.Errorf("invalid base64 content for artifact %q: %w", name, err) + } + data = decoded + } else { + data = []byte(in.Content) + } + + if len(data) > limit { + return nil, fmt.Errorf("artifact %q exceeds maximum allowed size: %d bytes > %d bytes", name, len(data), limit) + } + + mimeType := strings.TrimSpace(in.MimeType) + if mimeType == "" { + mimeType = "text/plain" + } + + part := &genai.Part{InlineData: &genai.Blob{ + Data: data, + MIMEType: mimeType, + DisplayName: name, + }} + + resp, err := artifacts.Save(ctx, name, part) + if err != nil { + return nil, fmt.Errorf("failed to save artifact %q: %w", name, err) + } + + return map[string]any{ + "status": "saved", + "name": name, + "version": resp.Version, + "mime_type": mimeType, + "size_bytes": len(data), + }, nil +} + +// decodeFlexibleBase64 decodes base64 using the standard or URL-safe alphabet, +// with or without padding. LLMs sometimes emit url-safe (-_) or unpadded base64; +// trying the common encodings avoids failing the tool call with a confusing +// "illegal base64 data" error for content that is in fact decodable. +func decodeFlexibleBase64(s string) ([]byte, error) { + for _, enc := range []*base64.Encoding{ + base64.StdEncoding, + base64.RawStdEncoding, + base64.URLEncoding, + base64.RawURLEncoding, + } { + if b, err := enc.DecodeString(s); err == nil { + return b, nil + } + } + return nil, fmt.Errorf("content is not valid base64 (tried standard and url-safe, padded and unpadded)") +} diff --git a/go/adk/pkg/tools/save_artifact_tool_test.go b/go/adk/pkg/tools/save_artifact_tool_test.go new file mode 100644 index 0000000000..d94eb46043 --- /dev/null +++ b/go/adk/pkg/tools/save_artifact_tool_test.go @@ -0,0 +1,194 @@ +package tools + +import ( + "context" + "encoding/base64" + "fmt" + "testing" + + "google.golang.org/adk/agent" + "google.golang.org/adk/artifact" + "google.golang.org/genai" +) + +// fakeArtifacts is a minimal agent.Artifacts implementation that records saves. +type fakeArtifacts struct { + saved map[string]*genai.Part + versions map[string]int64 + saveErr error +} + +func newFakeArtifacts() *fakeArtifacts { + return &fakeArtifacts{saved: map[string]*genai.Part{}, versions: map[string]int64{}} +} + +func (f *fakeArtifacts) Save(_ context.Context, name string, data *genai.Part) (*artifact.SaveResponse, error) { + if f.saveErr != nil { + return nil, f.saveErr + } + f.versions[name]++ + f.saved[name] = data + return &artifact.SaveResponse{Version: f.versions[name]}, nil +} + +func (f *fakeArtifacts) List(context.Context) (*artifact.ListResponse, error) { + names := make([]string, 0, len(f.saved)) + for n := range f.saved { + names = append(names, n) + } + return &artifact.ListResponse{FileNames: names}, nil +} + +func (f *fakeArtifacts) Load(_ context.Context, name string) (*artifact.LoadResponse, error) { + return &artifact.LoadResponse{Part: f.saved[name]}, nil +} + +func (f *fakeArtifacts) LoadVersion(ctx context.Context, name string, _ int) (*artifact.LoadResponse, error) { + return f.Load(ctx, name) +} + +func TestSaveArtifact(t *testing.T) { + tests := []struct { + name string + artifacts agent.Artifacts + input saveArtifactInput + limit int + wantErr bool + wantBytes []byte + wantMime string + wantVersion int64 + }{ + { + name: "text content stored as inline data", + artifacts: newFakeArtifacts(), + input: saveArtifactInput{Name: "note.txt", Content: "hello", MimeType: "text/plain"}, + limit: 1024, + wantBytes: []byte("hello"), + wantMime: "text/plain", + wantVersion: 1, + }, + { + name: "missing mime defaults to text/plain", + artifacts: newFakeArtifacts(), + input: saveArtifactInput{Name: "a.csv", Content: "a,b\n1,2"}, + limit: 1024, + wantBytes: []byte("a,b\n1,2"), + wantMime: "text/plain", + wantVersion: 1, + }, + { + name: "base64 content decoded", + artifacts: newFakeArtifacts(), + input: saveArtifactInput{Name: "img.bin", Content: base64.StdEncoding.EncodeToString([]byte{0x01, 0x02, 0x03}), MimeType: "application/octet-stream", Base64: true}, + limit: 1024, + wantBytes: []byte{0x01, 0x02, 0x03}, + wantMime: "application/octet-stream", + wantVersion: 1, + }, + { + // "____" is url-safe base64 for 0xff,0xff,0xff; standard decoding rejects '_'. + name: "url-safe base64 decoded", + artifacts: newFakeArtifacts(), + input: saveArtifactInput{Name: "u.bin", Content: "____", MimeType: "application/octet-stream", Base64: true}, + limit: 1024, + wantBytes: []byte{0xff, 0xff, 0xff}, + wantMime: "application/octet-stream", + wantVersion: 1, + }, + { + // "AQI" is unpadded base64 for 0x01,0x02; standard padded decoding rejects it. + name: "unpadded base64 decoded", + artifacts: newFakeArtifacts(), + input: saveArtifactInput{Name: "p.bin", Content: "AQI", MimeType: "application/octet-stream", Base64: true}, + limit: 1024, + wantBytes: []byte{0x01, 0x02}, + wantMime: "application/octet-stream", + wantVersion: 1, + }, + { + name: "empty name rejected", + artifacts: newFakeArtifacts(), + input: saveArtifactInput{Name: " ", Content: "x"}, + limit: 1024, + wantErr: true, + }, + { + name: "path separator rejected", + artifacts: newFakeArtifacts(), + input: saveArtifactInput{Name: "dir/note.txt", Content: "x"}, + limit: 1024, + wantErr: true, + }, + { + name: "invalid base64 rejected", + artifacts: newFakeArtifacts(), + input: saveArtifactInput{Name: "x.bin", Content: "not base64!!!", Base64: true}, + limit: 1024, + wantErr: true, + }, + { + name: "oversized content rejected", + artifacts: newFakeArtifacts(), + input: saveArtifactInput{Name: "big.txt", Content: "0123456789"}, + limit: 5, + wantErr: true, + }, + { + name: "nil artifact service rejected", + artifacts: nil, + input: saveArtifactInput{Name: "x.txt", Content: "x"}, + limit: 1024, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := saveArtifact(context.Background(), tt.artifacts, tt.input, tt.limit) + if (err != nil) != tt.wantErr { + t.Fatalf("saveArtifact() error = %v, wantErr %v", err, tt.wantErr) + } + if tt.wantErr { + return + } + + if got["version"] != tt.wantVersion { + t.Errorf("version = %v, want %d", got["version"], tt.wantVersion) + } + if got["mime_type"] != tt.wantMime { + t.Errorf("mime_type = %v, want %q", got["mime_type"], tt.wantMime) + } + + fa := tt.artifacts.(*fakeArtifacts) + part := fa.saved[tt.input.Name] + if part == nil || part.InlineData == nil { + t.Fatalf("expected artifact %q saved as inline data", tt.input.Name) + } + if string(part.InlineData.Data) != string(tt.wantBytes) { + t.Errorf("stored bytes = %v, want %v", part.InlineData.Data, tt.wantBytes) + } + if part.InlineData.DisplayName != tt.input.Name { + t.Errorf("display name = %q, want %q", part.InlineData.DisplayName, tt.input.Name) + } + }) + } +} + +func TestSaveArtifact_PropagatesSaveError(t *testing.T) { + fa := newFakeArtifacts() + fa.saveErr = fmt.Errorf("store unavailable") + _, err := saveArtifact(context.Background(), fa, saveArtifactInput{Name: "x.txt", Content: "x"}, 1024) + if err == nil { + t.Fatal("expected error when underlying Save fails") + } +} + +func TestNewSaveArtifactTool(t *testing.T) { + tl, err := NewSaveArtifactTool() + if err != nil { + t.Fatalf("NewSaveArtifactTool() error = %v", err) + } + if tl.Name() != "save_artifact" { + t.Errorf("tool name = %q, want %q", tl.Name(), "save_artifact") + } +} diff --git a/go/core/test/e2e/file_upload_test.go b/go/core/test/e2e/file_upload_test.go new file mode 100644 index 0000000000..047f3f765d --- /dev/null +++ b/go/core/test/e2e/file_upload_test.go @@ -0,0 +1,64 @@ +package e2e_test + +import ( + "context" + "encoding/json" + "testing" + "time" + + a2atype "github.com/a2aproject/a2a-go/v2/a2a" + "github.com/kagent-dev/kagent/go/api/v1alpha2" + "github.com/kagent-dev/kagent/go/core/internal/a2a" + "github.com/stretchr/testify/require" + "k8s.io/client-go/util/retry" +) + +// TestE2EFileUploadGoADKAgent verifies the end-to-end file upload round trip on +// a Go ADK agent: a user uploads a file inline (A2A file part) and the agent +// processes the request (file persisted as an artifact via +// SaveInputBlobsAsArtifacts) and responds successfully (AC8 upload path). +func TestE2EFileUploadGoADKAgent(t *testing.T) { + baseURL, stopServer := setupMockServer(t, "mocks/invoke_file_upload_agent.json") + defer stopServer() + + cli := setupK8sClient(t, false) + modelCfg := setupModelConfig(t, cli, baseURL) + + goRuntime := v1alpha2.DeclarativeRuntime_Go + agent := setupAgentWithOptions(t, cli, modelCfg.Name, nil, AgentOptions{ + Name: "file-upload-go-adk-test", + SystemMessage: "You are a helpful test agent that handles uploaded files.", + Runtime: &goRuntime, + }) + + a2aClient := setupA2AClient(t, agent) + + fileContent := []byte("hello from an uploaded file") + filePart := a2atype.NewRawPart(fileContent) + filePart.Filename = "note.txt" + filePart.MediaType = "text/plain" + textPart := a2atype.NewTextPart("Please confirm you received the uploaded file.") + + msg := a2atype.NewMessage(a2atype.MessageRoleUser, textPart, filePart) + + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + var result a2atype.SendMessageResult + err := retry.OnError(defaultRetry, func(err error) bool { return err != nil }, func() error { + reqCtx, reqCancel := context.WithTimeout(ctx, 15*time.Second) + defer reqCancel() + var sendErr error + result, sendErr = a2aClient.SendMessage(reqCtx, &a2atype.SendMessageRequest{Message: msg}) + return sendErr + }) + require.NoError(t, err) + + taskResult, ok := result.(*a2atype.Task) + require.True(t, ok) + + text := a2a.ExtractText(taskResult.History[len(taskResult.History)-1]) + jsn, marshalErr := json.Marshal(taskResult) + require.NoError(t, marshalErr) + require.Contains(t, text, "received your uploaded file", string(jsn)) +} diff --git a/go/core/test/e2e/mocks/invoke_file_upload_agent.json b/go/core/test/e2e/mocks/invoke_file_upload_agent.json new file mode 100644 index 0000000000..9bf9d0c729 --- /dev/null +++ b/go/core/test/e2e/mocks/invoke_file_upload_agent.json @@ -0,0 +1,35 @@ +{ + "openai": [ + { + "name": "file_upload_request", + "match": { + "match_type": "contains", + "message": { + "content": "confirm you received the uploaded file", + "role": "user" + } + }, + "response": { + "id": "chatcmpl-file-upload-1", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-4.1-mini", + "choices": [ + { + "index": 0, + "message": { + "content": "I received your uploaded file successfully.", + "role": "assistant" + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 6, + "total_tokens": 16 + } + } + } + ] +} diff --git a/go/go.mod b/go/go.mod index f146e9bd89..d7f7e9defe 100644 --- a/go/go.mod +++ b/go/go.mod @@ -73,6 +73,7 @@ require ( github.com/pgvector/pgvector-go/pgx v0.4.0 github.com/testcontainers/testcontainers-go v0.43.0 github.com/testcontainers/testcontainers-go/modules/postgres v0.43.0 + github.com/tsawler/tabula v1.6.6 go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc v0.20.0 go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploghttp v0.20.0 go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.44.0 @@ -324,6 +325,7 @@ require ( github.com/nunnatsa/ginkgolinter v0.23.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect + github.com/otiai10/gosseract/v2 v2.4.1 // indirect github.com/pb33f/ordered-map/v2 v2.3.1 // indirect github.com/pelletier/go-toml/v2 v2.3.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect @@ -415,6 +417,7 @@ require ( golang.org/x/crypto v0.53.0 // indirect golang.org/x/exp v0.0.0-20251219203646-944ab1f22d93 // indirect golang.org/x/exp/typeparams v0.0.0-20260209203927-2842357ff358 // indirect + golang.org/x/image v0.22.0 // indirect golang.org/x/mod v0.37.0 // indirect golang.org/x/net v0.56.0 // indirect golang.org/x/oauth2 v0.36.0 // indirect diff --git a/go/go.sum b/go/go.sum index df0bc0687b..e99f143584 100644 --- a/go/go.sum +++ b/go/go.sum @@ -655,8 +655,12 @@ github.com/otiai10/copy v1.14.0 h1:dCI/t1iTdYGtkvCuBG2BgR6KZa83PTclw4U5n2wAllU= github.com/otiai10/copy v1.14.0/go.mod h1:ECfuL02W+/FkTWZWgQqXPWZgW9oeKCSQ5qVfSc4qc4w= github.com/otiai10/curr v0.0.0-20150429015615-9b4961190c95/go.mod h1:9qAhocn7zKJG+0mI8eUu6xqkFDYS2kb2saOteoSB3cE= github.com/otiai10/curr v1.0.0/go.mod h1:LskTG5wDwr8Rs+nNQ+1LlxRjAtTZZjtJW4rMXl6j4vs= +github.com/otiai10/gosseract/v2 v2.4.1 h1:G8AyBpXEeSlcq8TI85LH/pM5SXk8Djy2GEXisgyblRw= +github.com/otiai10/gosseract/v2 v2.4.1/go.mod h1:1gNWP4Hgr2o7yqWfs6r5bZxAatjOIdqWxJLWsTsembk= github.com/otiai10/mint v1.3.0/go.mod h1:F5AjcsTsWUqX+Na9fpHb52P8pcRX2CI6A3ctIT91xUo= github.com/otiai10/mint v1.3.1/go.mod h1:/yxELlJQ0ufhjUwhshSj+wFjZ78CnZ48/1wtmBH1OTc= +github.com/otiai10/mint v1.6.3 h1:87qsV/aw1F5as1eH1zS/yqHY85ANKVMgkDrf9rcxbQs= +github.com/otiai10/mint v1.6.3/go.mod h1:MJm72SBthJjz8qhefc4z1PYEieWmy8Bku7CjcAqyUSM= github.com/pb33f/ordered-map/v2 v2.3.1 h1:5319HDO0aw4DA4gzi+zv4FXU9UlSs3xGZ40wcP1nBjY= github.com/pb33f/ordered-map/v2 v2.3.1/go.mod h1:qxFQgd0PkVUtOMCkTapqotNgzRhMPL7VvaHKbd1HnmQ= github.com/pelletier/go-toml/v2 v2.3.1 h1:MYEvvGnQjeNkRF1qUuGolNtNExTDwct51yp7olPtrEc= @@ -810,6 +814,8 @@ github.com/tomarrell/wrapcheck/v2 v2.12.0 h1:H/qQ1aNWz/eeIhxKAFvkfIA+N7YDvq6TWVF github.com/tomarrell/wrapcheck/v2 v2.12.0/go.mod h1:AQhQuZd0p7b6rfW+vUwHm5OMCGgp63moQ9Qr/0BpIWo= github.com/tommy-muehle/go-mnd/v2 v2.5.1 h1:NowYhSdyE/1zwK9QCLeRb6USWdoif80Ie+v+yU8u1Zw= github.com/tommy-muehle/go-mnd/v2 v2.5.1/go.mod h1:WsUAkMJMYww6l/ufffCD3m+P7LEvr8TnZn9lwVDlgzw= +github.com/tsawler/tabula v1.6.6 h1:B2W1Iindrg58/VTLp30LPA2NCzxgxAE6Ne6RG693GHY= +github.com/tsawler/tabula v1.6.6/go.mod h1:CzvlQnJQLM2C6Cq0gRhP5z9kq9u9iM6s494QEHCmCbw= github.com/ultraware/funlen v0.2.0 h1:gCHmCn+d2/1SemTdYMiKLAHFYxTYz7z9VIDRaTGyLkI= github.com/ultraware/funlen v0.2.0/go.mod h1:ZE0q4TsJ8T1SQcjmkhN/w+MceuatI6pBFSxxyteHIJA= github.com/ultraware/whitespace v0.2.0 h1:TYowo2m9Nfj1baEQBjuHzvMRbp19i+RCcRYrSWoFa+g= @@ -931,6 +937,8 @@ golang.org/x/exp/typeparams v0.0.0-20220428152302-39d4317da171/go.mod h1:AbB0pIl golang.org/x/exp/typeparams v0.0.0-20230203172020-98cc5a0785f9/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk= golang.org/x/exp/typeparams v0.0.0-20260209203927-2842357ff358 h1:qWFG1Dj7TBjOjOvhEOkmyGPVoquqUKnIU0lEVLp8xyk= golang.org/x/exp/typeparams v0.0.0-20260209203927-2842357ff358/go.mod h1:4Mzdyp/6jzw9auFDJ3OMF5qksa7UvPnzKqTVGcb04ms= +golang.org/x/image v0.22.0 h1:UtK5yLUzilVrkjMAZAZ34DXGpASN8i8pj8g+O+yd10g= +golang.org/x/image v0.22.0/go.mod h1:9hPFhljd4zZ1GNSIZJ49sqbp45GKK9t6w+iXvGqZUz4= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.1/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= diff --git a/helm/kagent/files/nginx.conf b/helm/kagent/files/nginx.conf index 1af906b3eb..e346db0822 100644 --- a/helm/kagent/files/nginx.conf +++ b/helm/kagent/files/nginx.conf @@ -15,6 +15,8 @@ http { access_log /dev/stdout; + client_max_body_size {{ .Values.ui.nginx.clientMaxBodySize }}; + log_format main '[$time_local] $remote_addr - $remote_user - $request $status $body_bytes_sent $http_referer $http_user_agent $http_x_forwarded_for'; log_format upstreamlog '[$time_local] $remote_addr - $remote_user - $server_name $host to: $upstream_addr: $request $status upstream_response_time $upstream_response_time msec $msec request_time $request_time'; diff --git a/helm/kagent/tests/ui-nginx-configmap_test.yaml b/helm/kagent/tests/ui-nginx-configmap_test.yaml index bfc03cf09c..22d9d28627 100644 --- a/helm/kagent/tests/ui-nginx-configmap_test.yaml +++ b/helm/kagent/tests/ui-nginx-configmap_test.yaml @@ -69,3 +69,21 @@ tests: - matchRegex: path: data["nginx.conf"] pattern: 'server RELEASE-NAME-controller\.NAMESPACE\.svc\.cluster\.local:8083;' + + - it: should set client_max_body_size by default + template: ui-nginx-configmap.yaml + asserts: + - matchRegex: + path: data["nginx.conf"] + pattern: "client_max_body_size 50m;" + + - it: should allow overriding client_max_body_size + template: ui-nginx-configmap.yaml + set: + ui: + nginx: + clientMaxBodySize: 200m + asserts: + - matchRegex: + path: data["nginx.conf"] + pattern: "client_max_body_size 200m;" diff --git a/helm/kagent/values.yaml b/helm/kagent/values.yaml index eb17c95f35..3ddfc3a041 100644 --- a/helm/kagent/values.yaml +++ b/helm/kagent/values.yaml @@ -351,9 +351,12 @@ ui: # streaming response if no event is received within this window. Should be >= # ui.nginx.proxyReadTimeout so nginx isn't the silent limit. Default 1800 (30m). streamTimeoutSeconds: 1800 - # -- Nginx proxy timeout configuration for the UI sidecar (values are passed - # directly to the corresponding nginx directives, e.g. "1800s"). + # -- Nginx configuration for the UI sidecar (values are passed directly to the + # corresponding nginx directives, e.g. "1800s"). nginx: + # -- client_max_body_size: max allowed size of the client request body. Increase + # to allow large chat messages/attachments (set to 0 to disable the limit). + clientMaxBodySize: 50m # -- proxy_read_timeout: max time between two successive reads from the upstream. proxyReadTimeout: 1800s # -- proxy_send_timeout: max time between two successive writes to the upstream. diff --git a/python/packages/kagent-adk/pyproject.toml b/python/packages/kagent-adk/pyproject.toml index 55ef9047e1..a92d0817f9 100644 --- a/python/packages/kagent-adk/pyproject.toml +++ b/python/packages/kagent-adk/pyproject.toml @@ -38,6 +38,7 @@ dependencies = [ "boto3>=1.28.57", "ollama >=0.3.6", # Ollama SDK "numpy>=2.2.6", + "markitdown[docx,pptx,xlsx,pdf]>=0.1.1", # extract text from uploaded PDF/Office docs for the model ] [tool.uv.sources] diff --git a/python/packages/kagent-adk/src/kagent/adk/models/_file_extract.py b/python/packages/kagent-adk/src/kagent/adk/models/_file_extract.py new file mode 100644 index 0000000000..9b43d947cf --- /dev/null +++ b/python/packages/kagent-adk/src/kagent/adk/models/_file_extract.py @@ -0,0 +1,120 @@ +"""Helpers to turn uploaded file blobs into text the model can read. + +Rich documents (PDF, DOCX, XLSX, PPTX, HTML) are extracted to markdown via +``markitdown``; text-like files are returned as-is. This mirrors the Go ADK +runtime behaviour so non-image uploads reach the model instead of being dropped. +""" + +from __future__ import annotations + +import os +import tempfile +from typing import Optional + +from google.genai import types + +# Rich document MIME types mapped to the extension markitdown uses for format +# detection. Kept in sync with the Go runtime (fileextract.go): the common +# formats both markitdown (Python) and tabula (Go) support — PDF, DOCX, XLSX, +# PPTX, HTML, EPUB. +_DOC_MIME_TO_EXT: dict[str, str] = { + "application/pdf": ".pdf", + "application/vnd.openxmlformats-officedocument.wordprocessingml.document": ".docx", + "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet": ".xlsx", + "application/vnd.openxmlformats-officedocument.presentationml.presentation": ".pptx", + "application/epub+zip": ".epub", + "text/html": ".html", +} + +# Filename extensions markitdown can extract as rich documents. +_DOC_EXTS: dict[str, str] = { + ".pdf": ".pdf", + ".docx": ".docx", + ".xlsx": ".xlsx", + ".pptx": ".pptx", + ".epub": ".epub", + ".html": ".html", + ".htm": ".html", +} + + +def _normalize_mime(mime_type: Optional[str]) -> str: + if not mime_type: + return "" + mime_type = mime_type.strip().lower() + if ";" in mime_type: + mime_type = mime_type.split(";", 1)[0].strip() + return mime_type + + +def _is_text_like(mime_type: Optional[str]) -> bool: + """Whether a MIME type can be inlined as raw text (text/html excluded).""" + mime_type = _normalize_mime(mime_type) + if mime_type == "text/html": + return False + if mime_type.startswith("text/"): + return True + return mime_type in { + "application/json", + "application/xml", + "application/x-ndjson", + "application/yaml", + "application/x-yaml", + } + + +def _doc_ext(mime_type: Optional[str], name: Optional[str]) -> str: + """Extension to use for markitdown, derived from filename then MIME type.""" + if name: + _, ext = os.path.splitext(name.lower()) + if ext in _DOC_EXTS: + return _DOC_EXTS[ext] + return _DOC_MIME_TO_EXT.get(_normalize_mime(mime_type), "") + + +def _extract_doc_text(data: bytes, ext: str) -> str: + """Write bytes to a temp file (markitdown detects by extension) and extract.""" + # Lazy import so the package still imports if markitdown is unavailable. + from markitdown import MarkItDown + + fd, tmp_path = tempfile.mkstemp(suffix=ext) + try: + with os.fdopen(fd, "wb") as f: + f.write(data) + result = MarkItDown().convert(tmp_path) + return result.text_content or "" + finally: + os.remove(tmp_path) + + +def extract_file_text(data: bytes, mime_type: Optional[str], name: Optional[str]) -> str: + """Extract readable text from an uploaded file's bytes. + + Raises ValueError for formats that cannot be represented as text. + """ + ext = _doc_ext(mime_type, name) + if ext: + return _extract_doc_text(data, ext) + if _is_text_like(mime_type): + return data.decode("utf-8", errors="replace") + raise ValueError(f"unsupported file type for text extraction: mime={mime_type!r} name={name!r}") + + +def inline_file_to_text(blob: types.Blob) -> Optional[str]: + """Convert a non-image inline file blob into text for a chat message. + + Returns ``None`` for an empty blob. On extraction failure returns a short + note so the model can tell the user the file could not be read, instead of + silently dropping it. + """ + if blob is None or not blob.data: + return None + name = blob.display_name or "file" + mime_type = blob.mime_type or "" + try: + text = extract_file_text(blob.data, mime_type, name) + except Exception: + return f'[Uploaded file "{name}" ({mime_type}) could not be read as text.]' + if not text.strip(): + return f'[Uploaded file "{name}" ({mime_type}) contained no extractable text.]' + return f'Contents of uploaded file "{name}":\n\n{text}' diff --git a/python/packages/kagent-adk/src/kagent/adk/models/_openai.py b/python/packages/kagent-adk/src/kagent/adk/models/_openai.py index ae6fd6df46..64dbd208ba 100644 --- a/python/packages/kagent-adk/src/kagent/adk/models/_openai.py +++ b/python/packages/kagent-adk/src/kagent/adk/models/_openai.py @@ -32,6 +32,7 @@ from openai.types.shared_params import FunctionDefinition, FunctionParameters from pydantic import Field +from ._file_extract import inline_file_to_text from ._ssl import KAgentTLSMixin from ._token_source import GDCHTokenSource @@ -152,14 +153,20 @@ def _convert_content_to_openai_messages( function_calls.append(part.function_call) elif part.function_response: function_responses.append(part.function_response) - elif part.inline_data and part.inline_data.mime_type and part.inline_data.mime_type.startswith("image"): - if part.inline_data.data: + elif part.inline_data and part.inline_data.data: + if part.inline_data.mime_type and part.inline_data.mime_type.startswith("image"): image_data = base64.b64encode(part.inline_data.data).decode() image_part: ChatCompletionContentPartImageParam = { "type": "image_url", "image_url": {"url": f"data:{part.inline_data.mime_type};base64,{image_data}"}, } image_parts.append(image_part) + else: + # Non-image files (PDF, Office docs, text) are extracted to + # text so the model can read them instead of dropping them. + file_text = inline_file_to_text(part.inline_data) + if file_text: + text_parts.append(file_text) # Function responses are now handled together with function calls # This ensures proper pairing and prevents orphaned tool messages diff --git a/python/packages/kagent-adk/tests/unittests/models/test_file_extract.py b/python/packages/kagent-adk/tests/unittests/models/test_file_extract.py new file mode 100644 index 0000000000..cfd579e0bf --- /dev/null +++ b/python/packages/kagent-adk/tests/unittests/models/test_file_extract.py @@ -0,0 +1,90 @@ +"""Tests for uploaded-file text extraction (_file_extract).""" + +import pytest +from google.genai import types + +from kagent.adk.models._file_extract import ( + _doc_ext, + _is_text_like, + extract_file_text, + inline_file_to_text, +) + + +@pytest.mark.parametrize( + "mime,want", + [ + ("text/plain", True), + ("text/markdown", True), + ("text/csv", True), + ("application/json", True), + ("text/plain; charset=utf-8", True), + ("text/html", False), # routed through extraction instead + ("application/pdf", False), + ("image/png", False), + ], +) +def test_is_text_like(mime, want): + assert _is_text_like(mime) is want + + +@pytest.mark.parametrize( + "mime,name,want", + [ + ("application/pdf", "report", ".pdf"), + ("application/octet-stream", "report.pdf", ".pdf"), + ( + "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + "x", + ".docx", + ), + ("", "page.htm", ".html"), + ("text/plain", "a.txt", ""), + ("application/zip", "a.zip", ""), + ], +) +def test_doc_ext(mime, name, want): + assert _doc_ext(mime, name) == want + + +@pytest.mark.parametrize( + "data,mime,name,want", + [ + (b"hello world", "text/plain", "a.txt", "hello world"), + (b'{"k":"v"}', "application/json", "a.json", '{"k":"v"}'), + (b"a,b\n1,2", "text/csv", "a.csv", "a,b\n1,2"), + ], +) +def test_extract_file_text_text_like(data, mime, name, want): + assert extract_file_text(data, mime, name) == want + + +def test_extract_file_text_unsupported_raises(): + with pytest.raises(ValueError): + extract_file_text(b"\x00\x01\x02", "application/zip", "a.zip") + + +def test_inline_file_to_text_none_for_empty(): + assert inline_file_to_text(types.Blob(data=b"", mime_type="text/plain")) is None + + +def test_inline_file_to_text_labels_text_file(): + out = inline_file_to_text(types.Blob(data=b"line1\nline2", mime_type="text/plain", display_name="notes.txt")) + assert out is not None + assert 'Contents of uploaded file "notes.txt"' in out + assert "line1" in out + + +def test_inline_file_to_text_note_on_unsupported(): + out = inline_file_to_text(types.Blob(data=b"\x00\x01", mime_type="application/zip", display_name="a.zip")) + assert out is not None + assert "could not be read as text" in out + + +def test_extract_file_text_html_via_markitdown(): + """Exercises the real markitdown extraction path (skipped if not installed).""" + pytest.importorskip("markitdown") + html = b"

Invoice

Total: 42 USD

" + out = extract_file_text(html, "text/html", "invoice.html") + assert "Invoice" in out + assert "42 USD" in out diff --git a/python/packages/kagent-adk/tests/unittests/models/test_openai.py b/python/packages/kagent-adk/tests/unittests/models/test_openai.py index 5dba0e52ac..f37670e214 100644 --- a/python/packages/kagent-adk/tests/unittests/models/test_openai.py +++ b/python/packages/kagent-adk/tests/unittests/models/test_openai.py @@ -995,3 +995,56 @@ def test_round_trip_preserves_thought_signature_for_follow_up_tool_result(self): tool_messages = [m for m in messages if m["role"] == "tool"] assert len(tool_messages) == 1 assert tool_messages[0]["extra_content"] == {"google": {"thought_signature": "YWJj"}} + + +class TestConvertContentNonImageFiles: + """Non-image uploads must reach the model as extracted text, not be dropped.""" + + def test_text_file_inline_data_becomes_user_text(self): + contents = [ + Content( + role="user", + parts=[ + Part(text="explain attached invoice"), + Part( + inline_data=types.Blob( + data=b"vendor,amount\nAcme,42", + mime_type="text/csv", + display_name="invoice.csv", + ) + ), + ], + ) + ] + + messages = _convert_content_to_openai_messages(contents) + + user_messages = [m for m in messages if m["role"] == "user"] + assert len(user_messages) == 1 + content = user_messages[0]["content"] + assert isinstance(content, str) + assert "explain attached invoice" in content + assert "invoice.csv" in content + assert "Acme,42" in content + + def test_unsupported_binary_inline_data_adds_note(self): + contents = [ + Content( + role="user", + parts=[ + Part( + inline_data=types.Blob( + data=b"\x00\x01\x02", + mime_type="application/zip", + display_name="archive.zip", + ) + ), + ], + ) + ] + + messages = _convert_content_to_openai_messages(contents) + + user_messages = [m for m in messages if m["role"] == "user"] + assert len(user_messages) == 1 + assert "could not be read as text" in user_messages[0]["content"] diff --git a/ui/src/components/chat/ChatInterface.tsx b/ui/src/components/chat/ChatInterface.tsx index 98b828a54d..80fb274b72 100644 --- a/ui/src/components/chat/ChatInterface.tsx +++ b/ui/src/components/chat/ChatInterface.tsx @@ -2,7 +2,7 @@ import type React from "react"; import { useState, useRef, useEffect, useMemo } from "react"; -import { ArrowBigUp, X, Loader2, Mic, Square } from "lucide-react"; +import { ArrowBigUp, X, Loader2, Mic, Square, Paperclip } from "lucide-react"; import { Button } from "@/components/ui/button"; import { Tooltip, @@ -32,7 +32,8 @@ import { formatA2AClientError } from "@/lib/a2aErrors"; import { useChatRunInSandbox, useChatSubstrateSandbox } from "@/components/chat/ChatAgentContext"; import { v4 as uuidv4 } from "uuid"; import { getStatusPlaceholder, mapA2AStateToStatus } from "@/lib/statusUtils"; -import { Message, DataPart, Task, TaskState } from "@a2a-js/sdk"; +import { Message, DataPart, FilePart, Task, TaskState } from "@a2a-js/sdk"; +import { FILE_ACCEPT, MAX_FILE_BYTES, fileToFilePart, isAllowedFile } from "@/lib/fileUpload"; // Task states where the agent is actively processing — resubscribe to live stream. const RESUBSCRIBE_TASK_STATES: TaskState[] = ["submitted", "working"]; @@ -52,6 +53,8 @@ export default function ChatInterface({ selectedAgentName, selectedNamespace, se const substrateSandbox = useChatSubstrateSandbox(); const router = useRouter(); const containerRef = useRef(null); + const fileInputRef = useRef(null); + const [pendingFiles, setPendingFiles] = useState([]); const [currentInputMessage, setCurrentInputMessage] = useState(""); const [chatStatus, setChatStatus] = useState("ready"); @@ -225,9 +228,38 @@ export default function ChatInterface({ selectedAgentName, selectedNamespace, se + const handleFilesSelected = (files: FileList | null) => { + if (!files || files.length === 0) return; + const accepted: File[] = []; + for (const file of Array.from(files)) { + if (!isAllowedFile(file)) { + toast.error(`"${file.name}" is not an allowed file type`); + continue; + } + if (file.size > MAX_FILE_BYTES) { + toast.error(`"${file.name}" exceeds the 10 MB limit`); + continue; + } + accepted.push(file); + } + if (accepted.length > 0) { + setPendingFiles(prev => [...prev, ...accepted]); + } + }; + + const handleFileInputChange = (e: React.ChangeEvent) => { + handleFilesSelected(e.target.files); + // Reset so selecting the same file again re-triggers onChange. + e.target.value = ""; + }; + + const removePendingFile = (index: number) => { + setPendingFiles(prev => prev.filter((_, i) => i !== index)); + }; + const handleSendMessage = async (e: React.FormEvent) => { e.preventDefault(); - if (!currentInputMessage.trim() || !selectedAgentName || !selectedNamespace) { + if ((!currentInputMessage.trim() && pendingFiles.length === 0) || !selectedAgentName || !selectedNamespace) { return; } @@ -238,6 +270,16 @@ export default function ChatInterface({ selectedAgentName, selectedNamespace, se const userMessageText = currentInputMessage; + let fileParts: FilePart[] = []; + if (pendingFiles.length > 0) { + try { + fileParts = await Promise.all(pendingFiles.map(fileToFilePart)); + } catch (err) { + toast.error(`Failed to read file: ${err instanceof Error ? err.message : "unknown error"}`); + return; + } + } + // Cross-tab guard: fetch the latest session state before mutating anything. // Two cases: (1) another tab is still streaming — reconnect instead of sending; // (2) another tab completed a turn we haven't loaded — reload so the user sees @@ -259,6 +301,7 @@ export default function ChatInterface({ selectedAgentName, selectedNamespace, se } setCurrentInputMessage(""); + setPendingFiles([]); setChatStatus("thinking"); setStoredMessages(prev => [...prev, ...streamingMessages]); setStreamingMessages([]); @@ -275,10 +318,10 @@ export default function ChatInterface({ selectedAgentName, selectedNamespace, se kind: "message", messageId, role: "user", - parts: [{ - kind: "text", - text: userMessageText - }], + parts: [ + { kind: "text", text: userMessageText }, + ...fileParts, + ], contextId: guardSessionId, metadata: { timestamp: Date.now() @@ -307,9 +350,12 @@ export default function ChatInterface({ selectedAgentName, selectedNamespace, se isCreatingSessionRef.current = true; setIsFirstMessage(true); + const sessionName = userMessageText.trim() + ? deriveSessionTitle(userMessageText) + : (fileParts[0]?.file.name ?? "File upload"); const newSessionResponse = await createSession({ agent_ref: `${selectedNamespace}/${selectedAgentName}`, - name: deriveSessionTitle(userMessageText), + name: sessionName, }); if (newSessionResponse.error || !newSessionResponse.data) { @@ -382,6 +428,7 @@ export default function ChatInterface({ selectedAgentName, selectedNamespace, se const a2aMessage = createMessage(userMessageText, "user", { messageId, contextId: currentSessionId, + fileParts, }); await streamA2AMessage(a2aMessage, { @@ -1021,7 +1068,37 @@ export default function ChatInterface({ selectedAgentName, selectedNamespace, se {sessionStats.total > 0 && } + {pendingFiles.length > 0 && ( +
+ {pendingFiles.map((file, index) => ( +
+ + {file.name} + +
+ ))} +
+ )} +
+