diff --git a/internal/cmd/calendar_build.go b/internal/cmd/calendar_build.go index eb3537481..149ff55c6 100644 --- a/internal/cmd/calendar_build.go +++ b/internal/cmd/calendar_build.go @@ -57,6 +57,16 @@ func extractTimezone(value string) string { return candidate } } + + // Fallback for fixed whole-hour offsets when no regional timezone match is found. + // NOTE: IANA "Etc/GMT" names use reversed signs (e.g. +02:00 => Etc/GMT-2). + if offset%3600 == 0 { + hours := offset / 3600 + if hours > 0 { + return fmt.Sprintf("Etc/GMT-%d", hours) + } + return fmt.Sprintf("Etc/GMT+%d", -hours) + } return "" } diff --git a/internal/cmd/calendar_build_test.go b/internal/cmd/calendar_build_test.go index 58277ef2e..74a7dfec9 100644 --- a/internal/cmd/calendar_build_test.go +++ b/internal/cmd/calendar_build_test.go @@ -17,7 +17,8 @@ func TestExtractTimezone(t *testing.T) { {"2026-01-08T16:00:00Z", "UTC"}, {"2026-01-08T11:00:00+00:00", "UTC"}, {"invalid", ""}, - {"2026-01-08T11:00:00-04:00", ""}, // not a common US offset on this date + {"2026-01-08T11:00:00-04:00", "Etc/GMT+4"}, + {"2026-01-08T11:00:00+02:00", "Etc/GMT-2"}, {"2026-01-08T11:00:00+05:30", ""}, // India - not mapped } diff --git a/internal/cmd/calendar_create_update_test.go b/internal/cmd/calendar_create_update_test.go index 0439f2301..308b30afc 100644 --- a/internal/cmd/calendar_create_update_test.go +++ b/internal/cmd/calendar_create_update_test.go @@ -120,6 +120,63 @@ func TestCalendarCreateCmd_WithMeetAndAttachments(t *testing.T) { } } +func TestCalendarCreateCmd_RecurringOffsetTimezoneFallback(t *testing.T) { + origNew := newCalendarService + t.Cleanup(func() { newCalendarService = origNew }) + + var gotEvent calendar.Event + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + path := strings.TrimPrefix(r.URL.Path, "/calendar/v3") + if r.Method == http.MethodPost && path == "/calendars/cal/events" { + _ = json.NewDecoder(r.Body).Decode(&gotEvent) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "id": "ev3", + }) + return + } + http.NotFound(w, r) + })) + defer srv.Close() + + svc, err := calendar.NewService(context.Background(), + option.WithoutAuthentication(), + option.WithHTTPClient(srv.Client()), + option.WithEndpoint(srv.URL+"/"), + ) + if err != nil { + t.Fatalf("NewService: %v", err) + } + newCalendarService = func(context.Context, string) (*calendar.Service, error) { return svc, nil } + + u, err := ui.New(ui.Options{Stdout: os.Stdout, Stderr: os.Stderr, Color: "never"}) + if err != nil { + t.Fatalf("ui.New: %v", err) + } + ctx := outfmt.WithMode(ui.WithUI(context.Background(), u), outfmt.Mode{JSON: true}) + + cmd := &CalendarCreateCmd{} + if err := runKong(t, cmd, []string{ + "cal", + "--summary", "Recurring Test", + "--from", "2026-02-13T08:00:00+02:00", + "--to", "2026-02-13T09:00:00+02:00", + "--rrule", "FREQ=WEEKLY;BYDAY=TU,TH", + }, ctx, &RootFlags{Account: "a@b.com"}); err != nil { + t.Fatalf("runKong: %v", err) + } + + if gotEvent.Start == nil || gotEvent.Start.TimeZone != "Etc/GMT-2" { + t.Fatalf("expected start timezone fallback Etc/GMT-2, got %#v", gotEvent.Start) + } + if gotEvent.End == nil || gotEvent.End.TimeZone != "Etc/GMT-2" { + t.Fatalf("expected end timezone fallback Etc/GMT-2, got %#v", gotEvent.End) + } + if len(gotEvent.Recurrence) == 0 { + t.Fatalf("expected recurrence to be set") + } +} + func TestCalendarUpdateCmd_RunJSON(t *testing.T) { origNew := newCalendarService t.Cleanup(func() { newCalendarService = origNew }) diff --git a/internal/cmd/execute_drive_more_commands_test.go b/internal/cmd/execute_drive_more_commands_test.go index 6d9fadb8d..09b7e9d9f 100644 --- a/internal/cmd/execute_drive_more_commands_test.go +++ b/internal/cmd/execute_drive_more_commands_test.go @@ -59,6 +59,9 @@ func TestExecute_DriveMoreCommands_JSON(t *testing.T) { }) return case strings.Contains(path, "/files/id1") && r.Method == http.MethodDelete: + if got := r.URL.Query().Get("supportsAllDrives"); got != "true" { + t.Fatalf("expected supportsAllDrives=true, got: %q (raw=%q)", got, r.URL.RawQuery) + } w.WriteHeader(http.StatusNoContent) return case strings.Contains(path, "/files/id1") && (r.Method == http.MethodPatch || r.Method == http.MethodPut): @@ -224,6 +227,9 @@ func TestExecute_DriveMoreCommands_Text(t *testing.T) { }) return case strings.Contains(path, "/files/id1") && r.Method == http.MethodDelete: + if got := r.URL.Query().Get("supportsAllDrives"); got != "true" { + t.Fatalf("expected supportsAllDrives=true, got: %q (raw=%q)", got, r.URL.RawQuery) + } w.WriteHeader(http.StatusNoContent) return case strings.Contains(path, "/files/id1") && (r.Method == http.MethodPatch || r.Method == http.MethodPut): diff --git a/internal/cmd/execute_gmail_attachment_test.go b/internal/cmd/execute_gmail_attachment_test.go index 7d025249a..d193f6a7a 100644 --- a/internal/cmd/execute_gmail_attachment_test.go +++ b/internal/cmd/execute_gmail_attachment_test.go @@ -21,6 +21,7 @@ func TestExecute_GmailAttachment_OutPath_JSON(t *testing.T) { t.Cleanup(func() { newGmailService = origNew }) var attachmentCalls int32 + var messageCalls int32 // 2 bytes => base64 has padding; exercises padded-base64 fallback decode path. attachmentData := []byte("ab") attachmentEncoded := base64.URLEncoding.EncodeToString(attachmentData) @@ -32,10 +33,25 @@ func TestExecute_GmailAttachment_OutPath_JSON(t *testing.T) { w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(map[string]any{"data": attachmentEncoded}) return + case strings.Contains(r.URL.Path, "/gmail/v1/users/me/messages/m1") && !strings.Contains(r.URL.Path, "/attachments/"): + atomic.AddInt32(&messageCalls, 1) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "id": "m1", + "payload": map[string]any{ + "parts": []map[string]any{ + { + "filename": "a.bin", + "body": map[string]any{ + "attachmentId": "a1", + "size": len(attachmentData), + }, + }, + }, + }, + }) + return default: - if strings.Contains(r.URL.Path, "/gmail/v1/users/me/messages/m1") && !strings.Contains(r.URL.Path, "/attachments/") { - t.Fatalf("unexpected messages.get call: %s", r.URL.Path) - } http.NotFound(w, r) return } @@ -75,6 +91,9 @@ func TestExecute_GmailAttachment_OutPath_JSON(t *testing.T) { } parsed1 := run() + if atomic.LoadInt32(&messageCalls) != 1 { + t.Fatalf("messageCalls=%d", messageCalls) + } if atomic.LoadInt32(&attachmentCalls) != 1 { t.Fatalf("attachmentCalls=%d", attachmentCalls) } @@ -97,6 +116,9 @@ func TestExecute_GmailAttachment_OutPath_JSON(t *testing.T) { } parsed2 := run() + if atomic.LoadInt32(&messageCalls) != 2 { + t.Fatalf("messageCalls=%d", messageCalls) + } if atomic.LoadInt32(&attachmentCalls) != 1 { t.Fatalf("attachmentCalls=%d", attachmentCalls) } @@ -119,10 +141,24 @@ func TestExecute_GmailAttachment_NameOverride_ConfigDir_JSON(t *testing.T) { w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(map[string]any{"data": attachmentEncoded}) return + case strings.Contains(r.URL.Path, "/gmail/v1/users/me/messages/m1") && !strings.Contains(r.URL.Path, "/attachments/"): + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "id": "m1", + "payload": map[string]any{ + "parts": []map[string]any{ + { + "filename": "override.bin", + "body": map[string]any{ + "attachmentId": "a1", + "size": len(attachmentData), + }, + }, + }, + }, + }) + return default: - if strings.Contains(r.URL.Path, "/gmail/v1/users/me/messages/m1") && !strings.Contains(r.URL.Path, "/attachments/") { - t.Fatalf("unexpected messages.get call: %s", r.URL.Path) - } http.NotFound(w, r) return } @@ -174,14 +210,31 @@ func TestExecute_GmailAttachment_NotFound(t *testing.T) { t.Cleanup(func() { newGmailService = origNew }) srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if strings.Contains(r.URL.Path, "/gmail/v1/users/me/messages/m1/attachments/") { + switch { + case strings.Contains(r.URL.Path, "/gmail/v1/users/me/messages/m1/attachments/"): + http.NotFound(w, r) + return + case strings.Contains(r.URL.Path, "/gmail/v1/users/me/messages/m1") && !strings.Contains(r.URL.Path, "/attachments/"): + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "id": "m1", + "payload": map[string]any{ + "parts": []map[string]any{ + { + "filename": "a.bin", + "body": map[string]any{ + "attachmentId": "a1", + "size": 2, + }, + }, + }, + }, + }) + return + default: http.NotFound(w, r) return } - if strings.Contains(r.URL.Path, "/gmail/v1/users/me/messages/m1") { - t.Fatalf("unexpected messages.get call: %s", r.URL.Path) - } - http.NotFound(w, r) })) defer srv.Close() @@ -210,3 +263,183 @@ func TestExecute_GmailAttachment_NotFound(t *testing.T) { t.Fatalf("expected no file written, stat=%v", statErr) } } + +func TestExecute_GmailAttachment_OutDirWithName_JSON(t *testing.T) { + origNew := newGmailService + t.Cleanup(func() { newGmailService = origNew }) + + var attachmentCalls int32 + attachmentData := []byte("hello") + attachmentEncoded := base64.URLEncoding.EncodeToString(attachmentData) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/gmail/v1/users/me/messages/m1/attachments/a1"): + atomic.AddInt32(&attachmentCalls, 1) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{"data": attachmentEncoded}) + return + case strings.Contains(r.URL.Path, "/gmail/v1/users/me/messages/m1") && !strings.Contains(r.URL.Path, "/attachments/"): + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "id": "m1", + "payload": map[string]any{ + "parts": []map[string]any{ + { + "filename": "ignored.bin", + "body": map[string]any{ + "attachmentId": "a1", + "size": len(attachmentData), + }, + }, + }, + }, + }) + return + default: + http.NotFound(w, r) + return + } + })) + defer srv.Close() + + svc, err := gmail.NewService(context.Background(), + option.WithoutAuthentication(), + option.WithHTTPClient(srv.Client()), + option.WithEndpoint(srv.URL+"/"), + ) + if err != nil { + t.Fatalf("NewService: %v", err) + } + newGmailService = func(context.Context, string) (*gmail.Service, error) { return svc, nil } + + outDir := t.TempDir() + wantPath := filepath.Join(outDir, "invoice.pdf") + + run := func() map[string]any { + out := captureStdout(t, func() { + _ = captureStderr(t, func() { + if execErr := Execute([]string{ + "--json", + "--account", "a@b.com", + "gmail", "attachment", "m1", "a1", + "--out", outDir, + "--name", "invoice.pdf", + }); execErr != nil { + t.Fatalf("Execute: %v", execErr) + } + }) + }) + + var parsed map[string]any + if unmarshalErr := json.Unmarshal([]byte(out), &parsed); unmarshalErr != nil { + t.Fatalf("json parse: %v\nout=%q", unmarshalErr, out) + } + return parsed + } + + parsed1 := run() + if parsed1["path"] != wantPath { + t.Fatalf("path=%v want=%s", parsed1["path"], wantPath) + } + if parsed1["cached"] != false { + t.Fatalf("cached=%v", parsed1["cached"]) + } + if atomic.LoadInt32(&attachmentCalls) != 1 { + t.Fatalf("attachmentCalls=%d", attachmentCalls) + } + + parsed2 := run() + if parsed2["cached"] != true { + t.Fatalf("cached=%v", parsed2["cached"]) + } + if atomic.LoadInt32(&attachmentCalls) != 1 { + t.Fatalf("attachmentCalls=%d", attachmentCalls) + } +} + +func TestExecute_GmailAttachment_StaleFileIsRedownloaded(t *testing.T) { + origNew := newGmailService + t.Cleanup(func() { newGmailService = origNew }) + + var attachmentCalls int32 + attachmentData := []byte("fresh-bytes") + attachmentEncoded := base64.URLEncoding.EncodeToString(attachmentData) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case strings.Contains(r.URL.Path, "/gmail/v1/users/me/messages/m1/attachments/a1"): + atomic.AddInt32(&attachmentCalls, 1) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{"data": attachmentEncoded}) + return + case strings.Contains(r.URL.Path, "/gmail/v1/users/me/messages/m1") && !strings.Contains(r.URL.Path, "/attachments/"): + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]any{ + "id": "m1", + "payload": map[string]any{ + "parts": []map[string]any{ + { + "filename": "a.bin", + "body": map[string]any{ + "attachmentId": "a1", + "size": len(attachmentData), + }, + }, + }, + }, + }) + return + default: + http.NotFound(w, r) + return + } + })) + defer srv.Close() + + svc, err := gmail.NewService(context.Background(), + option.WithoutAuthentication(), + option.WithHTTPClient(srv.Client()), + option.WithEndpoint(srv.URL+"/"), + ) + if err != nil { + t.Fatalf("NewService: %v", err) + } + newGmailService = func(context.Context, string) (*gmail.Service, error) { return svc, nil } + + outPath := filepath.Join(t.TempDir(), "invoice.pdf") + if writeErr := os.WriteFile(outPath, []byte("stale"), 0o600); writeErr != nil { + t.Fatalf("WriteFile: %v", err) + } + + out := captureStdout(t, func() { + _ = captureStderr(t, func() { + if execErr := Execute([]string{ + "--json", + "--account", "a@b.com", + "gmail", "attachment", "m1", "a1", + "--out", outPath, + }); execErr != nil { + t.Fatalf("Execute: %v", execErr) + } + }) + }) + + var parsed map[string]any + if unmarshalErr := json.Unmarshal([]byte(out), &parsed); unmarshalErr != nil { + t.Fatalf("json parse: %v\nout=%q", unmarshalErr, out) + } + if parsed["cached"] != false { + t.Fatalf("cached=%v", parsed["cached"]) + } + if atomic.LoadInt32(&attachmentCalls) != 1 { + t.Fatalf("attachmentCalls=%d", attachmentCalls) + } + b, err := os.ReadFile(outPath) + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + if string(b) != string(attachmentData) { + t.Fatalf("content=%q", string(b)) + } +} diff --git a/internal/cmd/gmail_attachment.go b/internal/cmd/gmail_attachment.go index 4e07f6dc9..b1eeb6ece 100644 --- a/internal/cmd/gmail_attachment.go +++ b/internal/cmd/gmail_attachment.go @@ -20,7 +20,7 @@ type GmailAttachmentCmd struct { MessageID string `arg:"" name:"messageId" help:"Message ID"` AttachmentID string `arg:"" name:"attachmentId" help:"Attachment ID"` Output OutputPathFlag `embed:""` - Name string `name:"name" help:"Filename (only used when --out is empty)"` + Name string `name:"name" help:"Filename (used when --out is empty or points to a directory)"` } func (c *GmailAttachmentCmd) Run(ctx context.Context, flags *RootFlags) error { @@ -40,42 +40,12 @@ func (c *GmailAttachmentCmd) Run(ctx context.Context, flags *RootFlags) error { return err } - if strings.TrimSpace(c.Output.Path) == "" { - dir, dirErr := config.EnsureGmailAttachmentsDir() - if dirErr != nil { - return dirErr - } - filename := strings.TrimSpace(c.Name) - if filename == "" { - filename = "attachment.bin" - } - safeFilename := filepath.Base(filename) - if safeFilename == "" || safeFilename == "." || safeFilename == ".." { - safeFilename = "attachment.bin" - } - shortID := attachmentID - if len(shortID) > 8 { - shortID = shortID[:8] - } - destPath := filepath.Join(dir, fmt.Sprintf("%s_%s_%s", messageID, shortID, safeFilename)) - path, cached, bytes, dlErr := downloadAttachmentToPath(ctx, svc, messageID, attachmentID, destPath, -1) - if dlErr != nil { - return dlErr - } - if outfmt.IsJSON(ctx) { - return outfmt.WriteJSON(os.Stdout, map[string]any{"path": path, "cached": cached, "bytes": bytes}) - } - u.Out().Printf("path\t%s", path) - u.Out().Printf("cached\t%t", cached) - u.Out().Printf("bytes\t%d", bytes) - return nil - } - - outPath, err := config.ExpandPath(c.Output.Path) + destPath, err := resolveAttachmentOutputPath(messageID, attachmentID, c.Output.Path, c.Name) if err != nil { return err } - path, cached, bytes, err := downloadAttachmentToPath(ctx, svc, messageID, attachmentID, outPath, -1) + expectedSize := lookupAttachmentSizeEstimate(ctx, svc, messageID, attachmentID) + path, cached, bytes, err := downloadAttachmentToPath(ctx, svc, messageID, attachmentID, destPath, expectedSize) if err != nil { return err } @@ -88,6 +58,66 @@ func (c *GmailAttachmentCmd) Run(ctx context.Context, flags *RootFlags) error { return nil } +func resolveAttachmentOutputPath(messageID, attachmentID, outPathFlag, name string) (string, error) { + shortID := attachmentID + if len(shortID) > 8 { + shortID = shortID[:8] + } + safeFilename := sanitizeAttachmentFilename(name, "attachment.bin") + + if strings.TrimSpace(outPathFlag) == "" { + dir, err := config.EnsureGmailAttachmentsDir() + if err != nil { + return "", err + } + return filepath.Join(dir, fmt.Sprintf("%s_%s_%s", messageID, shortID, safeFilename)), nil + } + + outPath, err := config.ExpandPath(outPathFlag) + if err != nil { + return "", err + } + + if st, statErr := os.Stat(outPath); statErr == nil && st.IsDir() { + filename := safeFilename + if strings.TrimSpace(name) == "" { + filename = fmt.Sprintf("%s_%s_attachment.bin", messageID, shortID) + } + return filepath.Join(outPath, filename), nil + } + + // Treat paths ending with a separator as directory targets even if they don't exist yet. + if outPath != "" && os.IsPathSeparator(outPath[len(outPath)-1]) { + return filepath.Join(outPath, safeFilename), nil + } + + return outPath, nil +} + +func sanitizeAttachmentFilename(name, fallback string) string { + safeFilename := filepath.Base(strings.TrimSpace(name)) + if safeFilename == "" || safeFilename == "." || safeFilename == ".." { + return fallback + } + return safeFilename +} + +func lookupAttachmentSizeEstimate(ctx context.Context, svc *gmail.Service, messageID, attachmentID string) int64 { + if svc == nil { + return -1 + } + msg, err := svc.Users.Messages.Get("me", messageID).Format("full").Fields("payload").Context(ctx).Do() + if err != nil || msg == nil { + return -1 + } + for _, a := range collectAttachments(msg.Payload) { + if a.AttachmentID == attachmentID && a.Size > 0 { + return a.Size + } + } + return -1 +} + func downloadAttachmentToPath( ctx context.Context, svc *gmail.Service, @@ -100,16 +130,16 @@ func downloadAttachmentToPath( return "", false, 0, errors.New("missing outPath") } - if expectedSize > 0 { - if st, err := os.Stat(outPath); err == nil && st.Size() == expectedSize { - return outPath, true, st.Size(), nil - } - } else if expectedSize == -1 { - if st, err := os.Stat(outPath); err == nil && st.Size() > 0 { + if st, err := os.Stat(outPath); err == nil && st.Mode().IsRegular() { + if expectedSize > 0 && st.Size() == expectedSize { return outPath, true, st.Size(), nil } } + if svc == nil { + return "", false, 0, errors.New("missing gmail service") + } + body, err := svc.Users.Messages.Attachments.Get("me", messageID, attachmentID).Context(ctx).Do() if err != nil { return "", false, 0, err diff --git a/internal/cmd/gmail_attachment_more_test.go b/internal/cmd/gmail_attachment_more_test.go index 7c000af79..f6de4a58c 100644 --- a/internal/cmd/gmail_attachment_more_test.go +++ b/internal/cmd/gmail_attachment_more_test.go @@ -37,16 +37,31 @@ func TestDownloadAttachmentToPath_CachedBySize(t *testing.T) { func TestDownloadAttachmentToPath_CachedByAnySize(t *testing.T) { path := filepath.Join(t.TempDir(), "b.bin") - if err := os.WriteFile(path, []byte("abcd"), 0o600); err != nil { + if err := os.WriteFile(path, []byte("stale"), 0o600); err != nil { t.Fatalf("WriteFile: %v", err) } - gotPath, cached, bytes, err := downloadAttachmentToPath(context.Background(), nil, "m1", "a1", path, -1) + srv := httptestServerForAttachment(t, base64.RawURLEncoding.EncodeToString([]byte("fresh"))) + + gsvc, err := gmail.NewService(context.Background(), + option.WithoutAuthentication(), + option.WithHTTPClient(srv.Client()), + option.WithEndpoint(srv.URL+"/"), + ) + if err != nil { + t.Fatalf("NewService: %v", err) + } + gotPath, cached, bytes, err := downloadAttachmentToPath(context.Background(), gsvc, "m1", "a1", path, -1) if err != nil { t.Fatalf("downloadAttachmentToPath: %v", err) } - if gotPath != path || !cached || bytes != 4 { + if gotPath != path || cached || bytes != 5 { t.Fatalf("unexpected result: path=%q cached=%v bytes=%d", gotPath, cached, bytes) } + if data, err := os.ReadFile(path); err != nil { + t.Fatalf("ReadFile: %v", err) + } else if string(data) != "fresh" { + t.Fatalf("unexpected data: %q", string(data)) + } } func TestDownloadAttachmentToPath_Base64Fallback(t *testing.T) { @@ -94,6 +109,108 @@ func TestDownloadAttachmentToPath_EmptyData(t *testing.T) { } } +func TestDownloadAttachmentToPath_DirectoryNotCacheHit(t *testing.T) { + dir := t.TempDir() + srv := httptestServerForAttachment(t, base64.RawURLEncoding.EncodeToString([]byte("x"))) + + gsvc, err := gmail.NewService(context.Background(), + option.WithoutAuthentication(), + option.WithHTTPClient(srv.Client()), + option.WithEndpoint(srv.URL+"/"), + ) + if err != nil { + t.Fatalf("NewService: %v", err) + } + + if _, _, _, err := downloadAttachmentToPath(context.Background(), gsvc, "m1", "a1", dir, -1); err == nil { + t.Fatalf("expected error for directory output path") + } +} + +func TestSanitizeAttachmentFilename(t *testing.T) { + tests := []struct { + name string + fallback string + want string + }{ + {"report.pdf", "attachment.bin", "report.pdf"}, + {"", "attachment.bin", "attachment.bin"}, + {" ", "attachment.bin", "attachment.bin"}, + {".", "attachment.bin", "attachment.bin"}, + {"..", "attachment.bin", "attachment.bin"}, + {"../../etc/passwd", "attachment.bin", "passwd"}, + {"../../../secret.txt", "attachment.bin", "secret.txt"}, + {"/absolute/path/file.txt", "attachment.bin", "file.txt"}, + {"dir/subdir/file.txt", "attachment.bin", "file.txt"}, + {"normal.txt", "fallback.dat", "normal.txt"}, + } + for _, tt := range tests { + got := sanitizeAttachmentFilename(tt.name, tt.fallback) + if got != tt.want { + t.Errorf("sanitizeAttachmentFilename(%q, %q) = %q, want %q", tt.name, tt.fallback, got, tt.want) + } + } +} + +func TestResolveAttachmentOutputPath(t *testing.T) { + t.Run("explicit file path", func(t *testing.T) { + path, err := resolveAttachmentOutputPath("m1", "a1", "/tmp/out.bin", "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if path != "/tmp/out.bin" { + t.Fatalf("got %q, want /tmp/out.bin", path) + } + }) + + t.Run("directory target appends filename", func(t *testing.T) { + dir := t.TempDir() + path, err := resolveAttachmentOutputPath("m1", "abcdefghij", dir, "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + want := filepath.Join(dir, "m1_abcdefgh_attachment.bin") + if path != want { + t.Fatalf("got %q, want %q", path, want) + } + }) + + t.Run("directory target with custom name", func(t *testing.T) { + dir := t.TempDir() + path, err := resolveAttachmentOutputPath("m1", "abcdefghij", dir, "report.pdf") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + want := filepath.Join(dir, "report.pdf") + if path != want { + t.Fatalf("got %q, want %q", path, want) + } + }) + + t.Run("traversal in name is stripped", func(t *testing.T) { + dir := t.TempDir() + path, err := resolveAttachmentOutputPath("m1", "abcdefghij", dir, "../../etc/passwd") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + want := filepath.Join(dir, "passwd") + if path != want { + t.Fatalf("got %q, want %q", path, want) + } + }) + + t.Run("trailing separator treated as directory", func(t *testing.T) { + path, err := resolveAttachmentOutputPath("m1", "abcdefghij", "/tmp/newdir/", "report.pdf") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + want := filepath.Join("/tmp/newdir", "report.pdf") + if path != want { + t.Fatalf("got %q, want %q", path, want) + } + }) +} + func httptestServerForAttachment(t *testing.T, data string) *httptest.Server { t.Helper() return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 28012debe..6313929c4 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -6,6 +6,7 @@ import ( "fmt" "log/slog" "os" + "strconv" "github.com/alecthomas/kong" @@ -162,10 +163,7 @@ func envOr(key, fallback string) string { } func boolString(v bool) string { - if v { - return "true" - } - return "false" + return strconv.FormatBool(v) } func newParser(description string) (*kong.Kong, *CLI, error) { diff --git a/internal/googleapi/client.go b/internal/googleapi/client.go index b5ac74ea6..40a77d736 100644 --- a/internal/googleapi/client.go +++ b/internal/googleapi/client.go @@ -122,11 +122,7 @@ func optionsForAccountScopes(ctx context.Context, serviceLabel string, email str ts = tokenSource } } - baseTransport := &http.Transport{ - TLSClientConfig: &tls.Config{ - MinVersion: tls.VersionTLS12, - }, - } + baseTransport := newBaseTransport() // Wrap with retry logic for 429 and 5xx errors retryTransport := NewRetryTransport(&oauth2.Transport{ Source: ts, @@ -141,3 +137,28 @@ func optionsForAccountScopes(ctx context.Context, serviceLabel string, email str return []option.ClientOption{option.WithHTTPClient(c)}, nil } + +func newBaseTransport() *http.Transport { + defaultTransport, ok := http.DefaultTransport.(*http.Transport) + if !ok || defaultTransport == nil { + return &http.Transport{ + Proxy: http.ProxyFromEnvironment, + TLSClientConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + }, + } + } + + // Clone() deep-copies TLSClientConfig, so no additional clone needed. + transport := defaultTransport.Clone() + if transport.TLSClientConfig == nil { + transport.TLSClientConfig = &tls.Config{MinVersion: tls.VersionTLS12} + return transport + } + + if transport.TLSClientConfig.MinVersion < tls.VersionTLS12 { + transport.TLSClientConfig.MinVersion = tls.VersionTLS12 + } + + return transport +} diff --git a/internal/googleapi/client_more_test.go b/internal/googleapi/client_more_test.go index 06ecce89c..d520b86d0 100644 --- a/internal/googleapi/client_more_test.go +++ b/internal/googleapi/client_more_test.go @@ -2,9 +2,12 @@ package googleapi import ( "context" + "crypto/tls" "errors" + "net/http" "os" "path/filepath" + "strings" "testing" "github.com/99designs/keyring" @@ -257,3 +260,38 @@ func TestOptionsForAccountScopes_ServiceAccountPreferred(t *testing.T) { t.Fatalf("expected client options") } } + +func TestNewBaseTransport_RespectsProxyAndTLSMinimum(t *testing.T) { + t.Setenv("HTTPS_PROXY", "http://127.0.0.1:8888") + + transport := newBaseTransport() + if transport == nil { + t.Fatalf("expected transport") + } + + if transport.Proxy == nil { + t.Fatalf("expected proxy func") + } + + if transport.TLSClientConfig == nil { + t.Fatalf("expected TLS config") + } + + if transport.TLSClientConfig.MinVersion < tls.VersionTLS12 { + t.Fatalf("expected TLS min version >= 1.2, got %d", transport.TLSClientConfig.MinVersion) + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "https://www.googleapis.com", nil) + if err != nil { + t.Fatalf("new request: %v", err) + } + + proxyURL, err := transport.Proxy(req) + if err != nil { + t.Fatalf("proxy lookup: %v", err) + } + + if proxyURL == nil || !strings.Contains(proxyURL.String(), "127.0.0.1:8888") { + t.Fatalf("expected HTTPS proxy to be honored, got: %v", proxyURL) + } +}