Skip to content
Merged
2 changes: 1 addition & 1 deletion src/cmd/cli/command/commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,7 @@ var RootCmd = &cobra.Command{
return err
}

prompt := "Welcome to Defang. I can help you deploy your project to the cloud"
prompt := "Welcome to Defang. I can help you deploy your project to the cloud."
ag, err := agent.New(ctx, getCluster(), &global.ProviderID, &global.Stack)
if err != nil {
return err
Expand Down
36 changes: 23 additions & 13 deletions src/pkg/agent/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"os"
"os/signal"
"regexp"
"time"

Expand Down Expand Up @@ -97,27 +98,29 @@ func New(ctx context.Context, clusterAddr string, providerId *client.ProviderID,

func (a *Agent) StartWithUserPrompt(ctx context.Context, userPrompt string) error {
a.printer.Printf("\n%s\n", userPrompt)
a.printer.Printf("Type '/exit' to quit.\n")
return a.startSession(ctx)
// The userPrompt is for the user only. Start the session with an empty message for the agent.
return a.startSession(ctx, "")
}

func (a *Agent) StartWithMessage(ctx context.Context, msg string) error {
return a.startSession(ctx, msg)
}

func (a *Agent) startSession(ctx context.Context, initialMessage string) error {
signal.Reset(os.Interrupt) // unsubscribe the top-level signal handler

a.printer.Printf("Type '/exit' to quit.\n")

if err := a.handleUserMessage(ctx, msg); err != nil {
return fmt.Errorf("error handling initial message: %w", err)
if initialMessage != "" {
if err := a.handleUserMessage(ctx, initialMessage); err != nil {
return fmt.Errorf("error handling initial message: %w", err)
}
}

return a.startSession(ctx)
}

func (a *Agent) startSession(ctx context.Context) error {
for {
var input string
err := survey.AskOne(
&survey.Input{
Message: "",
},
&survey.Input{Message: ""},
&input,
survey.WithStdio(term.DefaultTerm.Stdio()),
survey.WithIcons(func(icons *survey.IconSet) {
Expand All @@ -141,13 +144,20 @@ func (a *Agent) startSession(ctx context.Context) error {
}

if err := a.handleUserMessage(ctx, input); err != nil {
a.printer.Println("Error handling message: %v", err)
if errors.Is(err, context.Canceled) {
continue
}
a.printer.Println("Error handling message:", err)
}
}
}

func (a *Agent) handleUserMessage(ctx context.Context, msg string) error {
maxTurns := 8
// Handle Ctrl+C during message handling / tool calls
ctx, cancel := signal.NotifyContext(ctx, os.Interrupt)
defer cancel()

const maxTurns = 8
for {
err := a.generator.HandleMessage(ctx, a.system, maxTurns, ai.NewUserMessage(ai.NewTextPart(msg)))
if err == nil {
Expand Down
12 changes: 11 additions & 1 deletion src/pkg/agent/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package agent
import (
"context"
"encoding/json"
"errors"

"github.com/DefangLabs/defang/src/pkg/term"
"github.com/firebase/genkit/go/ai"
Expand Down Expand Up @@ -61,12 +62,17 @@ func (e *maxTurnsReachedError) Error() string {
}

func (g *Generator) HandleMessage(ctx context.Context, prompt string, maxTurns int, message *ai.Message) error {
g.toolManager.ClearPrevious()

if message != nil {
g.messages = append(g.messages, message)
}
for range maxTurns {
resp, err := g.generate(ctx, prompt, g.messages)
if err != nil {
if errors.Is(err, context.Canceled) {
return err
}
term.Debugf("error: %v", err)
continue
}
Expand All @@ -78,7 +84,11 @@ func (g *Generator) HandleMessage(ctx context.Context, prompt string, maxTurns i
return nil
}

toolResp := g.toolManager.HandleToolCalls(ctx, toolRequests)
toolResp, err := g.toolManager.HandleToolCalls(ctx, toolRequests)
if err != nil {
// HandleToolCalls only ever returns "fatal" errors
return err
}
g.messages = append(g.messages, toolResp)
}

Expand Down
20 changes: 14 additions & 6 deletions src/pkg/agent/toolmanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,25 @@ func (t *ToolManager) RegisterTools(tools ...ai.Tool) {
}
}

func (t *ToolManager) HandleToolCalls(ctx context.Context, requests []*ai.ToolRequest) *ai.Message {
func (t *ToolManager) HandleToolCalls(ctx context.Context, requests []*ai.ToolRequest) (*ai.Message, error) {
if t.EqualPrevious(requests) {
return ai.NewMessage(ai.RoleTool, nil, ai.NewToolResponsePart(&ai.ToolResponse{
Name: "error",
Ref: "error",
Output: "The same tool request was made in the previous turn. To prevent infinite loops, no action was taken.",
}))
})), nil
}

parts := []*ai.Part{}
for _, req := range requests {
var part *ai.Part
toolResp, err := t.handleToolRequest(ctx, req)
if err != nil {
t.printer.Printf("! %v", err)
if errors.Is(err, context.Canceled) {
return nil, err
}
// If the error is not context.Canceled, let the agent know and respond
t.printer.Println("!", err)
part = ai.NewToolResponsePart(&ai.ToolResponse{
Name: req.Name,
Ref: req.Ref,
Expand All @@ -85,7 +89,7 @@ func (t *ToolManager) HandleToolCalls(ctx context.Context, requests []*ai.ToolRe
parts = append(parts, part)
}

return ai.NewMessage(ai.RoleTool, nil, parts...)
return ai.NewMessage(ai.RoleTool, nil, parts...), nil
}

func (t *ToolManager) handleToolRequest(ctx context.Context, req *ai.ToolRequest) (*ai.ToolResponse, error) {
Expand Down Expand Up @@ -129,8 +133,8 @@ func (t *ToolManager) EqualPrevious(toolRequests []*ai.ToolRequest) bool {

isEqual := len(newToolsRequestsJSON) == len(t.prevTurnToolRequestsJSON)
if isEqual {
for key := range newToolsRequestsJSON {
if !t.prevTurnToolRequestsJSON[key] {
for prevJSON := range newToolsRequestsJSON {
if !t.prevTurnToolRequestsJSON[prevJSON] {
isEqual = false
break
}
Expand All @@ -140,3 +144,7 @@ func (t *ToolManager) EqualPrevious(toolRequests []*ai.ToolRequest) bool {
t.prevTurnToolRequestsJSON = newToolsRequestsJSON
return isEqual
}

func (t *ToolManager) ClearPrevious() {
t.prevTurnToolRequestsJSON = make(map[string]bool)
}
4 changes: 2 additions & 2 deletions src/pkg/cli/client/byoc/aws/byoc.go
Original file line number Diff line number Diff line change
Expand Up @@ -780,8 +780,8 @@ func (b *ByocAws) getLogGroupInputs(etag types.ETag, projectName, service, filte
if b.driver.LogGroupARN == "" {
term.Debug("CD stack LogGroupARN is not set; skipping CD logs")
} else {
cdTail := ecs.LogGroupInput{LogGroupARN: b.driver.LogGroupARN, LogEventFilterPattern: pattern} // TODO: filter by etag
// If we know the CD task ARN, only tail the logstream for that CD task
cdTail := ecs.LogGroupInput{LogGroupARN: b.driver.LogGroupARN, LogEventFilterPattern: pattern}
// If we know the CD task ARN, only tail the logstream for that CD task; FIXME: store the task ID in the project's ProjectUpdate in S3 and use that
if b.cdTaskArn != nil && b.cdEtag == etag {
cdTail.LogStreamNames = []string{ecs.GetCDLogStreamForTaskID(ecs.GetTaskID(b.cdTaskArn))}
}
Expand Down
25 changes: 17 additions & 8 deletions src/pkg/logs/slog.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package logs
import (
"context"
"log/slog"
"strings"

"github.com/DefangLabs/defang/src/pkg/term"
)
Expand All @@ -20,23 +21,31 @@ func NewTermLogger(t *term.Term) *slog.Logger {
}

func (h *termHandler) Handle(ctx context.Context, r slog.Record) error {
msg := r.Message
// Format attrs if any
var attrs string
if r.NumAttrs() > 0 {
var builder strings.Builder
builder.WriteString(msg)
opened := false
r.Attrs(func(a slog.Attr) bool {
if attrs == "" {
attrs = " {"
if !opened {
builder.WriteString(" {")
opened = true
} else {
attrs += ", "
builder.WriteString(", ")
}
attrs += a.String()
strVal := a.String()
if len(strVal) > 80 {
runes := []rune(strVal)
strVal = string(runes[:77]) + "..."
}
builder.WriteString(strVal)
return true
})
attrs += "}"
builder.WriteString("}")
msg = builder.String()
}

msg := r.Message + attrs

switch r.Level {
case slog.LevelDebug:
_, err := h.t.Debug(msg)
Expand Down
18 changes: 15 additions & 3 deletions src/pkg/migrate/heroku.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,10 +273,15 @@ func (h *HerokuClient) GetPGInfo(ctx context.Context, addonID string) (PGInfo, e
return herokuGet[PGInfo](ctx, h, url)
}

// herokuGet performs an HTTP GET to the given URL using the HerokuClient's token,
// decodes the JSON response into a value of type T, and returns that value.
// The request uses the provided context and sets Heroku-specific Accept and
// Content-Type headers. If the response status is >= 400 or the body cannot
// be decoded as JSON, an error is returned describing the failure.
func herokuGet[T any](ctx context.Context, h *HerokuClient, url string) (T, error) {
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return *new(T), fmt.Errorf("failed to create request: %v", err)
return *new(T), fmt.Errorf("failed to create request: %w", err)
}

// Set headers
Expand All @@ -301,7 +306,7 @@ func herokuGet[T any](ctx context.Context, h *HerokuClient, url string) (T, erro
var data T
decoder := json.NewDecoder(resp.Body)
if err := decoder.Decode(&data); err != nil {
return *new(T), fmt.Errorf("failed to unmarshal JSON: %v", err)
return *new(T), fmt.Errorf("failed to unmarshal JSON: %w", err)
}

return data, nil
Expand All @@ -328,10 +333,17 @@ func authenticateHerokuCLI() error {
return nil
}

// getHerokuAuthTokenFromCLI obtains a short-lived Heroku API token by invoking the local Heroku CLI.
//
// It checks that the `heroku` executable is available, ensures the CLI is authenticated, runs
// `heroku authorizations:create --expires-in=300 --json`, and parses the resulting JSON for the token.
//
// The returned string is the extracted access token. An error is returned if the CLI is not installed,
// authentication fails, the command cannot be executed, or the command output cannot be parsed.
func getHerokuAuthTokenFromCLI() (string, error) {
_, err := exec.LookPath("heroku")
if err != nil {
return "", fmt.Errorf("Heroku CLI is not installed: %v", err)
return "", fmt.Errorf("Heroku CLI is not installed: %w", err)
}
term.Info("The Heroku CLI is installed, we'll use it to generate a short-lived authorization token")
err = authenticateHerokuCLI()
Expand Down
Loading