diff --git a/src/cmd/cli/command/commands.go b/src/cmd/cli/command/commands.go index 1f8999a1d..385aa93c8 100644 --- a/src/cmd/cli/command/commands.go +++ b/src/cmd/cli/command/commands.go @@ -26,6 +26,7 @@ import ( "github.com/DefangLabs/defang/src/pkg/cluster" "github.com/DefangLabs/defang/src/pkg/debug" "github.com/DefangLabs/defang/src/pkg/dryrun" + "github.com/DefangLabs/defang/src/pkg/elicitations" "github.com/DefangLabs/defang/src/pkg/github" "github.com/DefangLabs/defang/src/pkg/login" "github.com/DefangLabs/defang/src/pkg/logs" @@ -195,6 +196,7 @@ func SetupCommands(ctx context.Context, version string) { return completions, cobra.ShellCompDirectiveNoFileComp }) // RootCmd.Flag("provider").NoOptDefVal = "auto" NO this will break the "--provider aws" + RootCmd.Flags().MarkDeprecated("provider", "please use --stack instead") RootCmd.PersistentFlags().BoolVarP(&global.Verbose, "verbose", "v", global.Verbose, "verbose logging") // backwards compat: only used by tail RootCmd.PersistentFlags().BoolVar(&global.Debug, "debug", global.Debug, "debug logging for troubleshooting the CLI") RootCmd.PersistentFlags().BoolVar(&dryrun.DoDryRun, "dry-run", false, "dry run (don't actually change anything)") @@ -552,7 +554,22 @@ var whoamiCmd = &cobra.Command{ loader := configureLoader(cmd) global.NonInteractive = true // don't show provider prompt - provider, err := newProvider(cmd.Context(), loader) + ctx := cmd.Context() + projectName, err := loader.LoadProjectName(ctx) + if err != nil { + term.Warnf("Unable to load project: %v", err) + } + elicitationsClient := elicitations.NewSurveyClient(os.Stdin, os.Stdout, os.Stderr) + ec := elicitations.NewController(elicitationsClient) + wd, err := os.Getwd() + if err != nil { + return fmt.Errorf("failed to get working directory: %w", err) + } + sm, err := stacks.NewManager(global.Client, wd, projectName) + if err != nil { + return fmt.Errorf("failed to create stack manager: %w", err) + } + provider, err := newProvider(cmd.Context(), ec, sm) if err != nil { term.Debug("unable to get provider:", err) } @@ -1281,41 +1298,149 @@ var providerDescription = map[cliClient.ProviderID]string{ cliClient.ProviderGCP: "Deploy to Google Cloud Platform using gcloud Application Default Credentials.", } -func updateProviderID(ctx context.Context, loader cliClient.Loader) error { - extraMsg := "" - whence := "default project" +func getStack(ctx context.Context, ec elicitations.Controller, sm stacks.Manager) (*stacks.StackParameters, string, error) { + stackSelector := stacks.NewSelector(ec, sm) + + var whence string + stack := &stacks.StackParameters{ + Name: "", + Provider: cliClient.ProviderAuto, + Mode: modes.ModeUnspecified, + } - // Command line flag takes precedence over environment variable + // This code unfortunately replicates the provider precedence rules in the + // RoomCmd's PersistentPreRunE func, I think we should avoid reading the + // stack file during startup, and only read it here instead. + if os.Getenv("DEFANG_STACK") != "" || RootCmd.PersistentFlags().Changed("stack") { + whence = "stack file" + stackName := os.Getenv("DEFANG_STACK") + if stackName == "" { + stackName = RootCmd.Flags().Lookup("stack").Value.String() + } + stackParams, err := sm.Load(stackName) + if err != nil { + return nil, "", fmt.Errorf("unable to load stack %q: %w", stackName, err) + } + stack = stackParams + + if stack.Provider == cliClient.ProviderAuto { + return nil, "", fmt.Errorf("stack %q has an invalid provider %q", stack.Name, stack.Provider) + } + return stack, whence, nil + } + + knownStacks, err := sm.List(ctx) + if err != nil { + return nil, "", fmt.Errorf("unable to list stacks: %w", err) + } + stackNames := make([]string, 0, len(knownStacks)) + for _, s := range knownStacks { + stackNames = append(stackNames, s.Name) + } if RootCmd.PersistentFlags().Changed("provider") { - whence = "command line flag" - } else if val, ok := os.LookupEnv("DEFANG_PROVIDER"); ok { - // Sanitize the provider value from the environment variable - if err := global.Stack.Provider.Set(val); err != nil { - return fmt.Errorf("invalid provider '%v' in environment variable DEFANG_PROVIDER, supported providers are: %v", val, cliClient.AllProviders()) + term.Warn("Warning: --provider flag is deprecated. Please use --stack instead. To learn about stacks, visit https://docs.defang.io/docs/concepts/stacks") + providerIDString := RootCmd.Flags().Lookup("provider").Value.String() + err := stack.Provider.Set(providerIDString) + if err != nil { + return nil, "", fmt.Errorf("invalid provider %q: %w", providerIDString, err) + } + } else if _, ok := os.LookupEnv("DEFANG_PROVIDER"); ok { + term.Warn("Warning: DEFANG_PROVIDER environment variable is deprecated. Please use --stack instead. To learn about stacks, visit https://docs.defang.io/docs/concepts/stacks") + providerIDString := os.Getenv("DEFANG_PROVIDER") + err := stack.Provider.Set(providerIDString) + if err != nil { + return nil, "", fmt.Errorf("invalid provider %q: %w", providerIDString, err) } - whence = "environment variable" + } + if global.NonInteractive && stack.Provider == cliClient.ProviderAuto { + whence = "non-interactive default" + stack.Name = "beta" + stack.Provider = cliClient.ProviderDefang + return stack, whence, nil } - switch global.Stack.Provider { - case cliClient.ProviderAuto: - if global.NonInteractive { - // Defaults to defang provider in non-interactive mode - if awsInEnv() { - term.Warn("Using Defang playground, but AWS environment variables were detected; did you forget --provider=aws or DEFANG_PROVIDER=aws?") - } - if doInEnv() { - term.Warn("Using Defang playground, but DIGITALOCEAN_TOKEN environment variable was detected; did you forget --provider=digitalocean or DEFANG_PROVIDER=digitalocean?") - } - if gcpInEnv() { - term.Warn("Using Defang playground, but GCP_PROJECT_ID/CLOUDSDK_CORE_PROJECT environment variable was detected; did you forget --provider=gcp or DEFANG_PROVIDER=gcp?") + // if there is exactly one stack with that provider, use it + if len(knownStacks) == 1 && (stack.Provider == cliClient.ProviderAuto || knownStacks[0].Provider == stack.Provider.String()) { + knownStack := knownStacks[0] + // try to read the stackfile + stack, loadErr := sm.Load(knownStack.Name) + if loadErr != nil { + var outsideErr *stacks.OutsideError + if errors.Is(loadErr, os.ErrNotExist) || errors.As(loadErr, &outsideErr) { + var importErr error + term.Warn("unable to load stack from file, attempting to import from previous deployments", loadErr) + stack, importErr = importStack(sm, knownStack) + if importErr != nil { + return nil, "", fmt.Errorf("unable to load or import stack: %w", errors.Join(loadErr, importErr)) + } } - global.Stack.Provider = cliClient.ProviderDefang + } + + whence = "only stack" + return stack, whence, nil + } + + // if there are zero known stacks or more than one known stack, prompt the user to create or select a stack + if global.NonInteractive { + if len(stackNames) > 0 { + return nil, "", fmt.Errorf("please specify a stack using --stack. The following stacks are available: %v", stackNames) } else { - var err error - if whence, err = determineProviderID(ctx, loader); err != nil { - return err - } + return nil, "", fmt.Errorf("no stacks are configured; please create a stack using 'defang stack create --provider=%s'", stack.Provider) + } + } + + stackParameters, err := stackSelector.SelectStack(ctx) + if err != nil { + return nil, "", fmt.Errorf("failed to select stack: %w", err) + } + stack = stackParameters + whence = "interactive selection" + return stack, whence, nil +} + +func importStack(sm stacks.Manager, stack stacks.StackListItem) (*stacks.StackParameters, error) { + var providerID cliClient.ProviderID + err := providerID.Set(stack.Provider) + if err != nil { + return nil, fmt.Errorf("invalid provider %q in stack %q: %w", stack.Provider, stack.Name, err) + } + mode := modes.ModeUnspecified + if stack.Mode != "" { + err = mode.Set(stack.Mode) + if err != nil { + return nil, fmt.Errorf("invalid mode %q in stack %q: %w", stack.Mode, stack.Name, err) } + } + params := &stacks.StackParameters{ + Name: stack.Name, + Provider: providerID, + Region: stack.Region, + Mode: mode, + } + err = sm.LoadParameters(params.ToMap(), false) + if err != nil { + return nil, fmt.Errorf("unable to load parameters for stack %q: %w", stack.Name, err) + } + + return params, nil +} + +func printProviderMismatchWarnings(ctx context.Context, provider cliClient.ProviderID) { + if provider == cliClient.ProviderDefang { + // Ignore any env vars when explicitly using the Defang playground provider + // Defaults to defang provider in non-interactive mode + if awsInEnv() { + term.Warn("AWS environment variables were detected; did you forget --provider=aws or DEFANG_PROVIDER=aws?") + } + if doInEnv() { + term.Warn("DIGITALOCEAN_TOKEN environment variable was detected; did you forget --provider=digitalocean or DEFANG_PROVIDER=digitalocean?") + } + if gcpInEnv() { + term.Warn("GCP_PROJECT_ID/CLOUDSDK_CORE_PROJECT environment variable was detected; did you forget --provider=gcp or DEFANG_PROVIDER=gcp?") + } + } + + switch provider { case cliClient.ProviderAWS: if !awsInConfig(ctx) { term.Warn("AWS provider was selected, but AWS environment is not set") @@ -1328,105 +1453,66 @@ func updateProviderID(ctx context.Context, loader cliClient.Loader) error { if !gcpInEnv() { term.Warn("GCP provider was selected, but GCP_PROJECT_ID environment variable is not set") } - case cliClient.ProviderDefang: - // Ignore any env vars when explicitly using the Defang playground provider - extraMsg = "; consider using BYOC (https://s.defang.io/byoc)" } - - term.Infof("Using %s provider from %s%s", global.Stack.Provider.Name(), whence, extraMsg) - return nil } -func newProvider(ctx context.Context, loader cliClient.Loader) (cliClient.Provider, error) { - if err := updateProviderID(ctx, loader); err != nil { +func newProvider(ctx context.Context, ec elicitations.Controller, sm stacks.Manager) (cliClient.Provider, error) { + stack, whence, err := getStack(ctx, ec, sm) + if err != nil { return nil, err } - provider := cli.NewProvider(ctx, global.Stack.Provider, global.Client, global.Stack.Name) - return provider, nil -} + // TODO: avoid writing to this global variable once all readers are removed + global.Stack = *stack -func newProviderChecked(ctx context.Context, loader cliClient.Loader) (cliClient.Provider, error) { - provider, err := newProvider(ctx, loader) - if err != nil { - return nil, err + extraMsg := "" + if stack.Provider == cliClient.ProviderDefang { + extraMsg = "; consider using BYOC (https://s.defang.io/byoc)" } - _, err = provider.AccountInfo(ctx) - return provider, err -} + term.Infof("Using the %q stack on %s from %s%s", stack.Name, stack.Provider, whence, extraMsg) -func canIUseProvider(ctx context.Context, provider cliClient.Provider, projectName string, serviceCount int) error { - return cliClient.CanIUseProvider(ctx, global.Client, provider, projectName, global.Stack.Name, serviceCount) + printProviderMismatchWarnings(ctx, stack.Provider) + provider := cli.NewProvider(ctx, stack.Provider, global.Client, stack.Name) + return provider, nil } -func determineProviderID(ctx context.Context, loader cliClient.Loader) (string, error) { - var projectName string +func newProviderChecked(ctx context.Context, loader cliClient.Loader) (cliClient.Provider, error) { + var err error + projectName := "" + outside := true if loader != nil { - var err error projectName, err = loader.LoadProjectName(ctx) if err != nil { term.Warnf("Unable to load project: %v", err) } - - if projectName != "" && !RootCmd.PersistentFlags().Changed("provider") { // If user manually selected auto provider, do not load from remote - resp, err := global.Client.GetSelectedProvider(ctx, &defangv1.GetSelectedProviderRequest{Project: projectName}) - if err != nil { - term.Debugf("Unable to get selected provider: %v", err) - } else if resp.Provider != defangv1.Provider_PROVIDER_UNSPECIFIED { - global.Stack.Provider.SetValue(resp.Provider) - return "stored preference", nil - } - } + outside = loader.OutsideWorkingDirectory() } - - whence, err := interactiveSelectProvider(cliClient.AllProviders()) - - // Save the selected provider to the fabric - if projectName != "" { - if err := global.Client.SetSelectedProvider(ctx, &defangv1.SetSelectedProviderRequest{Project: projectName, Provider: global.Stack.Provider.Value()}); err != nil { - term.Debugf("Unable to save selected provider to defang server: %v", err) - } else { - term.Printf("%v is now the default provider for project %v and will auto-select next time if no other provider is specified. Use --provider=auto to reselect.", global.Stack.Provider, projectName) + elicitationsClient := elicitations.NewSurveyClient(os.Stdin, os.Stdout, os.Stderr) + ec := elicitations.NewController(elicitationsClient) + var sm stacks.Manager + if outside { + sm, err = stacks.NewManager(global.Client, "", projectName) + if err != nil { + return nil, fmt.Errorf("failed to create stack manager: %w", err) + } + } else { + wd, err := os.Getwd() + if err != nil { + return nil, fmt.Errorf("failed to get working directory: %w", err) + } + sm, err = stacks.NewManager(global.Client, wd, projectName) + if err != nil { + return nil, fmt.Errorf("failed to create stack manager: %w", err) } } - - return whence, err -} - -func interactiveSelectProvider(providers []cliClient.ProviderID) (string, error) { - if len(providers) < 2 { - panic("interactiveSelectProvider called with less than 2 providers") - } - // Prompt the user to choose a provider if in interactive mode - options := []string{} - for _, p := range providers { - options = append(options, p.String()) - } - // Default to the provider in the environment if available - var defaultOption any // not string! - if awsInEnv() { - defaultOption = cliClient.ProviderAWS.String() - } else if doInEnv() { - defaultOption = cliClient.ProviderDO.String() - } else if gcpInEnv() { - defaultOption = cliClient.ProviderGCP.String() - } - var optionValue string - if err := survey.AskOne(&survey.Select{ - Default: defaultOption, - Message: "Choose a cloud provider:", - Options: options, - Help: "The provider you choose will be used for deploying services.", - Description: func(value string, i int) string { - return providerDescription[cliClient.ProviderID(value)] - }, - }, &optionValue, survey.WithStdio(term.DefaultTerm.Stdio())); err != nil { - return "", fmt.Errorf("failed to select provider: %w", err) - } - track.Evt("ProviderSelected", P("provider", optionValue)) - if err := global.Stack.Provider.Set(optionValue); err != nil { - panic(err) + provider, err := newProvider(ctx, ec, sm) + if err != nil { + return nil, err } + _, err = provider.AccountInfo(ctx) + return provider, err +} - return "interactive prompt", nil +func canIUseProvider(ctx context.Context, provider cliClient.Provider, projectName string, serviceCount int) error { + return cliClient.CanIUseProvider(ctx, global.Client, provider, projectName, serviceCount) } diff --git a/src/cmd/cli/command/commands_test.go b/src/cmd/cli/command/commands_test.go index a328015fd..ab02a45cf 100644 --- a/src/cmd/cli/command/commands_test.go +++ b/src/cmd/cli/command/commands_test.go @@ -13,11 +13,9 @@ import ( "github.com/DefangLabs/defang/src/pkg/auth" cliClient "github.com/DefangLabs/defang/src/pkg/cli/client" "github.com/DefangLabs/defang/src/pkg/cli/client/byoc/aws" - "github.com/DefangLabs/defang/src/pkg/cli/client/byoc/gcp" - "github.com/DefangLabs/defang/src/pkg/cli/compose" pkg "github.com/DefangLabs/defang/src/pkg/clouds/aws" - gcpdriver "github.com/DefangLabs/defang/src/pkg/clouds/gcp" - "github.com/DefangLabs/defang/src/pkg/term" + "github.com/DefangLabs/defang/src/pkg/modes" + "github.com/DefangLabs/defang/src/pkg/stacks" defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1" "github.com/DefangLabs/defang/src/protos/io/defang/v1/defangv1connect" "github.com/aws/aws-sdk-go-v2/service/ssm" @@ -25,7 +23,6 @@ import ( "github.com/aws/smithy-go/ptr" "github.com/bufbuild/connect-go" "github.com/spf13/cobra" - "golang.org/x/oauth2/google" "google.golang.org/protobuf/types/known/emptypb" ) @@ -83,6 +80,18 @@ func (m *mockFabricService) SetSelectedProvider(context.Context, *connect.Reques return connect.NewResponse(&emptypb.Empty{}), nil } +func (m *mockFabricService) ListDeployments(context.Context, *connect.Request[defangv1.ListDeploymentsRequest]) (*connect.Response[defangv1.ListDeploymentsResponse], error) { + return connect.NewResponse(&defangv1.ListDeploymentsResponse{ + Deployments: []*defangv1.Deployment{ + { + Stack: "beta", + Provider: defangv1.Provider_AWS, + Mode: defangv1.DeploymentMode_DEVELOPMENT, + }, + }, + }), nil +} + func init() { SetupCommands(context.Background(), "0.0.0-test") } @@ -141,8 +150,6 @@ func TestCommandGates(t *testing.T) { _, handler := defangv1connect.NewFabricControllerHandler(mockService) t.Chdir("../../../../src/testdata/sanity") - t.Setenv("AWS_REGION", "us-west-2") - userinfoServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.URL.Path != "/userinfo" { t.Fatalf("unexpected path %q", r.URL.Path) @@ -239,10 +246,19 @@ func (m *MockFabricControllerClient) GetSelectedProvider(ctx context.Context, re } func (m *MockFabricControllerClient) SetSelectedProvider(ctx context.Context, req *connect.Request[defangv1.SetSelectedProviderRequest]) (*connect.Response[emptypb.Empty], error) { + if m.savedProvider == nil { + m.savedProvider = make(map[string]defangv1.Provider) + } m.savedProvider[req.Msg.Project] = req.Msg.Provider return connect.NewResponse(&emptypb.Empty{}), nil } +func (m *MockFabricControllerClient) ListDeployments(ctx context.Context, req *connect.Request[defangv1.ListDeploymentsRequest]) (*connect.Response[defangv1.ListDeploymentsResponse], error) { + return connect.NewResponse(&defangv1.ListDeploymentsResponse{ + Deployments: []*defangv1.Deployment{}, + }), nil +} + type FakeStdin struct { *bytes.Reader } @@ -259,17 +275,92 @@ func (f *FakeStdout) Fd() uintptr { return os.Stdout.Fd() } -func TestGetProvider(t *testing.T) { +type mockStackManager struct { + t *testing.T + expectedProvider cliClient.ProviderID + expectedRegion string + listResult []stacks.StackListItem + listError error + loadResults map[string]*stacks.StackParameters + loadResult *stacks.StackParameters + loadError error + createError error + createResult *stacks.StackParameters +} + +func NewMockStackManager(t *testing.T, expectedProvider cliClient.ProviderID, expectedRegion string) *mockStackManager { + return &mockStackManager{ + t: t, + expectedProvider: expectedProvider, + expectedRegion: expectedRegion, + listResult: []stacks.StackListItem{}, + } +} + +func (m *mockStackManager) List(ctx context.Context) ([]stacks.StackListItem, error) { + if m.listError != nil { + return nil, m.listError + } + return m.listResult, nil +} + +func (m *mockStackManager) Load(name string) (*stacks.StackParameters, error) { + if m.loadError != nil { + return nil, m.loadError + } + + // Check for specific stack name first + if m.loadResults != nil { + if result, exists := m.loadResults[name]; exists { + return result, nil + } + } + + // If we have an explicit loadResult, return it + if m.loadResult != nil { + return m.loadResult, nil + } + + // If we have expected provider/region (from old NewMockStackManager usage), create default params + if m.expectedProvider != "" && m.expectedRegion != "" { + params := stacks.StackParameters{ + Name: name, + Provider: m.expectedProvider, + Region: m.expectedRegion, + Mode: modes.ModeAffordable, + } + stacks.LoadParameters(params.ToMap(), true) + return ¶ms, nil + } + + return nil, os.ErrNotExist +} + +func (m *mockStackManager) Create(params stacks.StackParameters) (string, error) { + if m.createError != nil { + return "", m.createError + } + if m.createResult != nil { + m.loadResult = m.createResult + } + return params.Name, nil +} + +func (m *mockStackManager) LoadParameters(params map[string]string, overload bool) error { + return stacks.LoadParameters(params, overload) +} + +func TestNewProvider(t *testing.T) { mockClient := cliClient.GrpcClient{} mockCtrl := &MockFabricControllerClient{ canIUseResponse: defangv1.CanIUseResponse{}, } mockClient.SetClient(mockCtrl) global.Client = &mockClient - loader := cliClient.MockLoader{Project: compose.Project{Name: "empty"}} oldRootCmd := RootCmd t.Cleanup(func() { RootCmd = oldRootCmd + global.Stack = stacks.StackParameters{} }) FakeRootWithProviderParam := func(provider string) *cobra.Command { cmd := &cobra.Command{} @@ -287,7 +378,11 @@ func TestGetProvider(t *testing.T) { os.Unsetenv("DEFANG_PROVIDER") RootCmd = FakeRootWithProviderParam("") - p, err := newProvider(ctx, nil) + // Create a mock stacks manager that returns empty stack list + mockEC := &mockElicitationsController{} + mockSM := NewMockStackManager(t, cliClient.ProviderAWS, "us-west-2") + + p, err := newProvider(ctx, mockEC, mockSM) if err != nil { t.Fatalf("getProvider() failed: %v", err) } @@ -296,202 +391,20 @@ func TestGetProvider(t *testing.T) { } }) - t.Run("Auto provider should get provider from client", func(t *testing.T) { - global.Stack.Provider = "auto" - os.Unsetenv("DEFANG_PROVIDER") - t.Setenv("AWS_REGION", "us-west-2") - RootCmd = FakeRootWithProviderParam("") - - mockCtrl.savedProvider = map[string]defangv1.Provider{"empty": defangv1.Provider_AWS} - - ni := global.NonInteractive - sts := aws.StsClient - aws.StsClient = &mockStsProviderAPI{} - global.NonInteractive = false - t.Cleanup(func() { - global.NonInteractive = ni - aws.StsClient = sts - mockCtrl.savedProvider = nil - }) - - p, err := newProvider(ctx, loader) - if err != nil { - t.Fatalf("getProvider() failed: %v", err) - } - if _, ok := p.(*aws.ByocAws); !ok { - t.Errorf("Expected provider to be of type *aws.ByocAws, got %T", p) - } - }) - - t.Run("Auto provider from param with saved provider should go interactive and save", func(t *testing.T) { - global.Stack.Provider = "auto" - os.Unsetenv("DEFANG_PROVIDER") - t.Setenv("AWS_REGION", "us-west-2") - mockCtrl.savedProvider = map[string]defangv1.Provider{"someotherproj": defangv1.Provider_AWS} - RootCmd = FakeRootWithProviderParam("") - - ni := global.NonInteractive - sts := aws.StsClient - aws.StsClient = &mockStsProviderAPI{} - global.NonInteractive = false - oldTerm := term.DefaultTerm - term.DefaultTerm = term.NewTerm( - &FakeStdin{bytes.NewReader([]byte("aws\n"))}, - &FakeStdout{new(bytes.Buffer)}, - new(bytes.Buffer), - ) - t.Cleanup(func() { - global.NonInteractive = ni - aws.StsClient = sts - mockCtrl.savedProvider = nil - term.DefaultTerm = oldTerm - }) - - p, err := newProvider(ctx, loader) - if err != nil { - t.Fatalf("getProvider() failed: %v", err) - } - if _, ok := p.(*aws.ByocAws); !ok { - t.Errorf("Expected provider to be of type *aws.ByocAws, got %T", p) - } - if mockCtrl.savedProvider["empty"] != defangv1.Provider_AWS { - t.Errorf("Expected provider to be saved as AWS, got %v", mockCtrl.savedProvider["empty"]) - } - }) - - t.Run("Interactive provider prompt infer default provider from environment variable", func(t *testing.T) { - if testing.Short() { - t.Skip("Skip digitalocean test") - } - global.Stack.Provider = "auto" - os.Unsetenv("DEFANG_PROVIDER") - os.Unsetenv("AWS_PROFILE") - t.Setenv("AWS_REGION", "us-west-2") - t.Setenv("DIGITALOCEAN_TOKEN", "test-token") - mockCtrl.savedProvider = map[string]defangv1.Provider{"someotherproj": defangv1.Provider_AWS} - RootCmd = FakeRootWithProviderParam("") - - ni := global.NonInteractive - sts := aws.StsClient - aws.StsClient = &mockStsProviderAPI{} - global.NonInteractive = false - oldTerm := term.DefaultTerm - term.DefaultTerm = term.NewTerm( - &FakeStdin{bytes.NewReader([]byte("\n"))}, // Use default option, which should be DO from env var - &FakeStdout{new(bytes.Buffer)}, - new(bytes.Buffer), - ) - t.Cleanup(func() { - global.NonInteractive = ni - aws.StsClient = sts - mockCtrl.savedProvider = nil - term.DefaultTerm = oldTerm - }) - - _, err := newProvider(ctx, loader) - if err != nil && !strings.HasPrefix(err.Error(), "GET https://api.digitalocean.com/v2/account: 401") { - t.Fatalf("getProvider() failed: %v", err) - } - if mockCtrl.savedProvider["empty"] != defangv1.Provider_DIGITALOCEAN { - t.Errorf("Expected provider to be saved as DIGITALOCEAN, got %v", mockCtrl.savedProvider["empty"]) - } - }) - - t.Run("Auto provider from param with saved provider should go interactive and save", func(t *testing.T) { - os.Unsetenv("GCP_PROJECT_ID") // To trigger error - os.Unsetenv("DEFANG_PROVIDER") - global.Stack.Provider = "auto" - mockCtrl.savedProvider = map[string]defangv1.Provider{"empty": defangv1.Provider_AWS} - RootCmd = FakeRootWithProviderParam("auto") - - ni := global.NonInteractive - sts := aws.StsClient - aws.StsClient = &mockStsProviderAPI{} - global.NonInteractive = false - oldTerm := term.DefaultTerm - term.DefaultTerm = term.NewTerm( - &FakeStdin{bytes.NewReader([]byte("gcp\n"))}, - &FakeStdout{new(bytes.Buffer)}, - new(bytes.Buffer), - ) - t.Cleanup(func() { - global.NonInteractive = ni - aws.StsClient = sts - mockCtrl.savedProvider = nil - term.DefaultTerm = oldTerm - }) - - _, err := newProvider(ctx, loader) - if err != nil && err.Error() != "GCP_PROJECT_ID or CLOUDSDK_CORE_PROJECT must be set for GCP projects" { - t.Fatalf("getProvider() failed: %v", err) - } - if mockCtrl.savedProvider["empty"] != defangv1.Provider_GCP { - t.Errorf("Expected provider to be saved as GCP, got %v", mockCtrl.savedProvider["empty"]) - } - }) - - t.Run("Should take provider from param without updating saved provider", func(t *testing.T) { - os.Unsetenv("DIGITALOCEAN_TOKEN") - os.Unsetenv("DEFANG_PROVIDER") - mockCtrl.savedProvider = map[string]defangv1.Provider{"empty": defangv1.Provider_AWS} - RootCmd = FakeRootWithProviderParam("digitalocean") - ni := global.NonInteractive - global.NonInteractive = false - t.Cleanup(func() { - global.NonInteractive = ni - mockCtrl.savedProvider = nil - }) - - _, err := newProvider(ctx, loader) - if err != nil && !strings.HasPrefix(err.Error(), "DIGITALOCEAN_TOKEN must be set") { - t.Fatalf("getProvider() failed: %v", err) - } - if mockCtrl.savedProvider["empty"] != defangv1.Provider_AWS { - t.Errorf("Expected provider to stay as AWS, but got %v", mockCtrl.savedProvider["empty"]) - } - }) - - t.Run("Should take provider from env aws", func(t *testing.T) { - t.Setenv("DEFANG_PROVIDER", "aws") - t.Setenv("AWS_REGION", "us-west-2") - RootCmd = FakeRootWithProviderParam("") - sts := aws.StsClient - aws.StsClient = &mockStsProviderAPI{} - t.Cleanup(func() { - aws.StsClient = sts - }) - - p, err := newProvider(ctx, loader) - if err != nil { - t.Errorf("getProvider() failed: %v", err) - } - if _, ok := p.(*aws.ByocAws); !ok { - t.Errorf("Expected provider to be of type *aws.ByocAws, got %T", p) - } - }) + t.Run("Should set cd image from canIUse response", func(t *testing.T) { + t.Chdir("../../../../src/testdata/sanity") + t.Setenv("DEFANG_STACK", "beta") - t.Run("Should take provider from env gcp", func(t *testing.T) { - t.Setenv("DEFANG_PROVIDER", "gcp") - t.Setenv("GCP_PROJECT_ID", "test_proj_id") - RootCmd = FakeRootWithProviderParam("") - gcpdriver.FindGoogleDefaultCredentials = func(ctx context.Context, scopes ...string) (*google.Credentials, error) { - return &google.Credentials{ - JSON: []byte(`{"client_email":"test@email.com"}`), - }, nil - } + // Set up RootCmd with required flags for getStack function + RootCmd = &cobra.Command{Use: "defang"} + RootCmd.PersistentFlags().StringVarP(&global.Stack.Name, "stack", "s", global.Stack.Name, "stack name") + RootCmd.PersistentFlags().VarP(&global.Stack.Provider, "provider", "P", "provider") + RootCmd.PersistentFlags().StringP("project-name", "p", "", "project name") + RootCmd.PersistentFlags().StringArrayP("file", "f", []string{}, "compose file path(s)") - p, err := newProvider(ctx, loader) - if err != nil { - t.Errorf("getProvider() failed: %v", err) - } - if _, ok := p.(*gcp.ByocGcp); !ok { - t.Errorf("Expected provider to be of type *aws.ByocGcp, got %T", p) - } - }) + // Parse the flags to initialize the flag system + RootCmd.ParseFlags([]string{}) - t.Run("Should set cd image from canIUse response", func(t *testing.T) { - t.Setenv("DEFANG_PROVIDER", "aws") - t.Setenv("AWS_REGION", "us-west-2") sts := aws.StsClient aws.StsClient = &mockStsProviderAPI{} const cdImageTag = "site/registry/repo:tag@sha256:digest" @@ -499,9 +412,12 @@ func TestGetProvider(t *testing.T) { t.Cleanup(func() { aws.StsClient = sts mockCtrl.canIUseResponse.CdImage = "" + global.Stack = stacks.StackParameters{} }) - p, err := newProvider(ctx, loader) + mockEC := &mockElicitationsController{} + mockSM := NewMockStackManager(t, cliClient.ProviderAWS, "us-west-2") + p, err := newProvider(ctx, mockEC, mockSM) if err != nil { t.Errorf("getProvider() failed: %v", err) } @@ -521,8 +437,8 @@ func TestGetProvider(t *testing.T) { }) t.Run("Can override cd image from environment variable", func(t *testing.T) { - t.Setenv("DEFANG_PROVIDER", "aws") - t.Setenv("AWS_REGION", "us-west-2") + t.Chdir("../../../../src/testdata/sanity") + t.Setenv("DEFANG_STACK", "beta") sts := aws.StsClient aws.StsClient = &mockStsProviderAPI{} const cdImageTag = "site/registry/repo:tag@sha256:digest" @@ -532,9 +448,12 @@ func TestGetProvider(t *testing.T) { t.Cleanup(func() { aws.StsClient = sts mockCtrl.canIUseResponse.CdImage = "" + global.Stack = stacks.StackParameters{} }) - p, err := newProvider(ctx, loader) + mockEC := &mockElicitationsController{} + mockSM := NewMockStackManager(t, cliClient.ProviderAWS, "us-west-2") + p, err := newProvider(ctx, mockEC, mockSM) if err != nil { t.Errorf("getProvider() failed: %v", err) } @@ -553,3 +472,340 @@ func TestGetProvider(t *testing.T) { } }) } + +type mockElicitationsController struct { + isSupported bool + enumChoice string +} + +func (m *mockElicitationsController) RequestString(ctx context.Context, message, field string) (string, error) { + return "", nil +} + +func (m *mockElicitationsController) RequestStringWithDefault(ctx context.Context, message, field, defaultValue string) (string, error) { + return defaultValue, nil +} + +func (m *mockElicitationsController) RequestEnum(ctx context.Context, message, field string, options []string) (string, error) { + if m.enumChoice != "" { + return m.enumChoice, nil + } + if len(options) > 0 { + return options[0], nil + } + return "", nil +} + +func (m *mockElicitationsController) SetSupported(supported bool) { + m.isSupported = supported +} + +func (m *mockElicitationsController) IsSupported() bool { + return m.isSupported +} + +func TestGetStack(t *testing.T) { + ctx := context.Background() + + // Save original state + origRootCmd := RootCmd + origGlobalNonInteractive := global.NonInteractive + defer func() { + RootCmd = origRootCmd + global.NonInteractive = origGlobalNonInteractive + global.Stack = stacks.StackParameters{} + }() + + testCases := []struct { + name string + setup func(t *testing.T) (*mockElicitationsController, *mockStackManager) + stackFlag string + providerFlag string + envProvider string + nonInteractive bool + expectedStack *stacks.StackParameters + expectedWhence string + expectedError string + expectWarning bool + }{ + { + name: "stack flag provided with valid stack", + setup: func(t *testing.T) (*mockElicitationsController, *mockStackManager) { + ec := &mockElicitationsController{} + sm := &mockStackManager{ + loadResult: &stacks.StackParameters{ + Name: "test-stack", + Provider: cliClient.ProviderAWS, + Region: "us-west-2", + }, + } + return ec, sm + }, + stackFlag: "test-stack", + expectedStack: &stacks.StackParameters{ + Name: "test-stack", + Provider: cliClient.ProviderAWS, + Region: "us-west-2", + }, + expectedWhence: "stack file", + }, + { + name: "stack flag provided with invalid stack", + setup: func(t *testing.T) (*mockElicitationsController, *mockStackManager) { + ec := &mockElicitationsController{} + sm := &mockStackManager{ + loadError: errors.New("stack not found"), + } + return ec, sm + }, + stackFlag: "nonexistent-stack", + expectedError: "unable to load stack \"nonexistent-stack\": stack not found", + }, + { + name: "stack flag with auto provider should error", + setup: func(t *testing.T) (*mockElicitationsController, *mockStackManager) { + ec := &mockElicitationsController{} + sm := &mockStackManager{ + loadResult: &stacks.StackParameters{ + Name: "auto-stack", + Provider: cliClient.ProviderAuto, + Region: "us-west-2", + }, + } + return ec, sm + }, + stackFlag: "auto-stack", + expectedError: "stack \"auto-stack\" has an invalid provider \"auto\"", + }, + { + name: "provider flag provided with warning and existing stacks", + setup: func(t *testing.T) (*mockElicitationsController, *mockStackManager) { + ec := &mockElicitationsController{ + isSupported: true, + enumChoice: "existing-stack", + } + sm := &mockStackManager{ + listResult: []stacks.StackListItem{ + {Name: "existing-stack", Provider: "aws"}, + }, + loadResult: &stacks.StackParameters{ + Name: "existing-stack", + Provider: cliClient.ProviderAWS, + }, + } + return ec, sm + }, + providerFlag: "aws", + expectWarning: true, + expectedStack: &stacks.StackParameters{ + Name: "existing-stack", + Provider: cliClient.ProviderAWS, + }, + expectedWhence: "only stack", + }, + { + name: "env provider with warning and existing stacks", + setup: func(t *testing.T) (*mockElicitationsController, *mockStackManager) { + ec := &mockElicitationsController{ + isSupported: true, + enumChoice: "existing-stack", + } + sm := &mockStackManager{ + listResult: []stacks.StackListItem{ + {Name: "existing-stack", Provider: "aws"}, // Different provider to avoid "only stack" path + {Name: "other-stack", Provider: "gcp"}, + }, + loadResult: &stacks.StackParameters{ + Name: "existing-stack", + Provider: cliClient.ProviderAWS, + }, + } + return ec, sm + }, + envProvider: "gcp", + expectWarning: true, + expectedStack: &stacks.StackParameters{ + Name: "existing-stack", + Provider: cliClient.ProviderAWS, + }, + expectedWhence: "interactive selection", + }, + { + name: "non-interactive with auto provider returns default", + setup: func(t *testing.T) (*mockElicitationsController, *mockStackManager) { + ec := &mockElicitationsController{} + sm := &mockStackManager{ + listResult: []stacks.StackListItem{}, + } + return ec, sm + }, + nonInteractive: true, + expectedStack: &stacks.StackParameters{ + Name: "beta", + Provider: cliClient.ProviderDefang, + Mode: modes.ModeUnspecified, + }, + expectedWhence: "non-interactive default", + }, + { + name: "single stack matches provider", + setup: func(t *testing.T) (*mockElicitationsController, *mockStackManager) { + ec := &mockElicitationsController{} + sm := &mockStackManager{ + listResult: []stacks.StackListItem{ + {Name: "only-stack", Provider: "aws", Mode: "affordable"}, + }, + } + return ec, sm + }, + expectedStack: &stacks.StackParameters{ + Name: "only-stack", + Provider: cliClient.ProviderAWS, + Mode: modes.ModeAffordable, + }, + expectedWhence: "only stack", + }, + { + name: "interactive selection succeeds", + setup: func(t *testing.T) (*mockElicitationsController, *mockStackManager) { + ec := &mockElicitationsController{ + isSupported: true, + enumChoice: "stack1", + } + sm := &mockStackManager{ + listResult: []stacks.StackListItem{ + {Name: "stack1", Provider: "aws"}, + {Name: "stack2", Provider: "gcp"}, + }, + loadResult: &stacks.StackParameters{ + Name: "stack1", + Provider: cliClient.ProviderAWS, + }, + } + return ec, sm + }, + expectedStack: &stacks.StackParameters{ + Name: "stack1", + Provider: cliClient.ProviderAWS, + }, + expectedWhence: "interactive selection", + }, + { + name: "sm.List error should propagate", + setup: func(t *testing.T) (*mockElicitationsController, *mockStackManager) { + ec := &mockElicitationsController{} + sm := &mockStackManager{ + listError: errors.New("failed to list stacks"), + } + return ec, sm + }, + expectedError: "unable to list stacks: failed to list stacks", + }, + { + name: "stackSelector.SelectStack error should propagate", + setup: func(t *testing.T) (*mockElicitationsController, *mockStackManager) { + ec := &mockElicitationsController{isSupported: false} // Will cause SelectStack to fail + sm := &mockStackManager{ + listResult: []stacks.StackListItem{ + {Name: "stack1", Provider: "aws"}, + {Name: "stack2", Provider: "aws"}, + }, + } + return ec, sm + }, + expectedError: "failed to select stack:", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Setup mocks + ec, sm := tc.setup(t) + + // Create a new root command for this test + testRootCmd := &cobra.Command{Use: "defang"} + testRootCmd.PersistentFlags().String("stack", "", "stack name") + testRootCmd.PersistentFlags().VarP(&global.Stack.Provider, "provider", "P", "provider") + + // Set flags if provided + var args []string + if tc.stackFlag != "" { + args = append(args, "--stack", tc.stackFlag) + } + if tc.providerFlag != "" { + args = append(args, "--provider", tc.providerFlag) + } + + if len(args) > 0 { + testRootCmd.ParseFlags(args) + } + + // Set environment variable if provided + if tc.envProvider != "" { + t.Setenv("DEFANG_PROVIDER", tc.envProvider) + } else { + os.Unsetenv("DEFANG_PROVIDER") + } + + // Set global state + RootCmd = testRootCmd + global.NonInteractive = tc.nonInteractive + + // Reset global stack state + global.Stack.Provider = cliClient.ProviderAuto + + // Capture output to check for warnings + var output bytes.Buffer + + // Call the function under test + stack, whence, err := getStack(ctx, ec, sm) + + // Check error expectations + if tc.expectedError != "" { + if err == nil { + t.Fatalf("expected error %q, got nil", tc.expectedError) + } + if !strings.Contains(err.Error(), tc.expectedError) { + t.Fatalf("expected error to contain %q, got %q", tc.expectedError, err.Error()) + } + return + } + + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Check stack expectations + if tc.expectedStack != nil { + if stack == nil { + t.Fatal("expected stack to be non-nil") + } + if stack.Name != tc.expectedStack.Name { + t.Errorf("expected stack name %q, got %q", tc.expectedStack.Name, stack.Name) + } + if stack.Provider != tc.expectedStack.Provider { + t.Errorf("expected stack provider %q, got %q", tc.expectedStack.Provider, stack.Provider) + } + if tc.expectedStack.Region != "" && stack.Region != tc.expectedStack.Region { + t.Errorf("expected stack region %q, got %q", tc.expectedStack.Region, stack.Region) + } + } + + // Check whence expectations + if tc.expectedWhence != "" && whence != tc.expectedWhence { + t.Errorf("expected whence %q, got %q", tc.expectedWhence, whence) + } + + // Check warning expectations + if tc.expectWarning { + // Since we can't easily capture term.Warn output in tests, we just verify + // that the code path that would produce warnings was taken + if tc.providerFlag != "" && !testRootCmd.PersistentFlags().Changed("provider") { + t.Error("expected provider flag to be marked as changed for warning path") + } + } + + _ = output // Suppress unused variable warning for now + }) + } +} diff --git a/src/cmd/cli/command/compose.go b/src/cmd/cli/command/compose.go index 98e916dee..7edfd341b 100644 --- a/src/cmd/cli/command/compose.go +++ b/src/cmd/cli/command/compose.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "os" "slices" "strings" "time" @@ -18,6 +19,7 @@ import ( pcluster "github.com/DefangLabs/defang/src/pkg/cluster" "github.com/DefangLabs/defang/src/pkg/debug" "github.com/DefangLabs/defang/src/pkg/dryrun" + "github.com/DefangLabs/defang/src/pkg/elicitations" "github.com/DefangLabs/defang/src/pkg/logs" "github.com/DefangLabs/defang/src/pkg/modes" "github.com/DefangLabs/defang/src/pkg/stacks" @@ -112,7 +114,7 @@ func makeComposeUpCmd() *cobra.Command { } else if accountInfo, err := provider.AccountInfo(ctx); err != nil { term.Debugf("AccountInfo failed: %v", err) } else if len(resp.Deployments) > 0 { - handleExistingDeployments(resp.Deployments, accountInfo, project.Name) + handleExistingDeployments(resp.Deployments, accountInfo, project.Name, provider.GetStackName()) } else if global.Stack.Name == "" { err = promptToCreateStack(ctx, stacks.StackParameters{ Name: stacks.MakeDefaultName(accountInfo.Provider, accountInfo.Region), @@ -218,10 +220,13 @@ func makeComposeUpCmd() *cobra.Command { return composeUpCmd } -func handleExistingDeployments(existingDeployments []*defangv1.Deployment, accountInfo *cliClient.AccountInfo, projectName string) error { +func handleExistingDeployments(existingDeployments []*defangv1.Deployment, accountInfo *cliClient.AccountInfo, projectName string, stackName string) error { samePlace := slices.ContainsFunc(existingDeployments, func(dep *defangv1.Deployment) bool { + if dep.Provider != accountInfo.Provider.Value() { + return false + } // Old deployments may not have a region or account ID, so we check for empty values too - return dep.Provider == global.Stack.Provider.Value() && (dep.ProviderAccountId == accountInfo.AccountID || dep.ProviderAccountId == "") && (dep.Region == accountInfo.Region || dep.Region == "") + return (dep.ProviderAccountId == accountInfo.AccountID || dep.ProviderAccountId == "") && (dep.Region == accountInfo.Region || dep.Region == "") }) if samePlace { return nil @@ -229,8 +234,8 @@ func handleExistingDeployments(existingDeployments []*defangv1.Deployment, accou if err := confirmDeploymentToNewLocation(projectName, existingDeployments); err != nil { return err } - if global.Stack.Name == "" { - stackName := "beta" + if stackName == "" { + stackName = "beta" _, err := stacks.Create(stacks.StackParameters{ Name: stackName, Provider: accountInfo.Provider, @@ -558,7 +563,18 @@ func makeComposeConfigCmd() *cobra.Command { }, loadErr) } - provider, err := newProvider(ctx, loader) + elicitationsClient := elicitations.NewSurveyClient(os.Stdin, os.Stdout, os.Stderr) + ec := elicitations.NewController(elicitationsClient) + wd, err := os.Getwd() + if err != nil { + return fmt.Errorf("failed to get working directory: %w", err) + } + sm, err := stacks.NewManager(global.Client, wd, project.Name) + if err != nil { + return fmt.Errorf("failed to create stack manager: %w", err) + } + + provider, err := newProvider(ctx, ec, sm) if err != nil { return err } diff --git a/src/cmd/cli/command/estimate.go b/src/cmd/cli/command/estimate.go index 5a18ea5bd..178d29ecb 100644 --- a/src/cmd/cli/command/estimate.go +++ b/src/cmd/cli/command/estimate.go @@ -3,10 +3,12 @@ package command import ( "fmt" + "github.com/AlecAivazis/survey/v2" "github.com/DefangLabs/defang/src/pkg/cli" cliClient "github.com/DefangLabs/defang/src/pkg/cli/client" "github.com/DefangLabs/defang/src/pkg/modes" "github.com/DefangLabs/defang/src/pkg/term" + "github.com/DefangLabs/defang/src/pkg/track" "github.com/spf13/cobra" ) @@ -27,13 +29,14 @@ func makeEstimateCmd() *cobra.Command { } if global.Stack.Provider == cliClient.ProviderAuto { - _, err = interactiveSelectProvider([]cliClient.ProviderID{ + providerID, err := interactiveSelectProvider([]cliClient.ProviderID{ cliClient.ProviderAWS, cliClient.ProviderGCP, }) if err != nil { return fmt.Errorf("failed to select provider: %w", err) } + global.Stack.Provider = providerID } var previewProvider cliClient.Provider = &cliClient.PlaygroundProvider{FabricClient: global.Client} @@ -62,3 +65,40 @@ func makeEstimateCmd() *cobra.Command { estimateCmd.Flags().StringVarP(&global.Stack.Region, "region", "r", "", "which cloud region to estimate") return estimateCmd } + +func interactiveSelectProvider(providers []cliClient.ProviderID) (cliClient.ProviderID, error) { + if len(providers) < 2 { + panic("interactiveSelectProvider called with less than 2 providers") + } + // Prompt the user to choose a provider if in interactive mode + options := []string{} + for _, p := range providers { + options = append(options, p.String()) + } + // Default to the provider in the environment if available + var defaultOption any // not string! + if awsInEnv() { + defaultOption = cliClient.ProviderAWS.String() + } else if gcpInEnv() { + defaultOption = cliClient.ProviderGCP.String() + } + var optionValue string + if err := survey.AskOne(&survey.Select{ + Default: defaultOption, + Message: "Choose a cloud provider:", + Options: options, + Help: "The provider you choose will be used for deploying services.", + Description: func(value string, i int) string { + return providerDescription[cliClient.ProviderID(value)] + }, + }, &optionValue, survey.WithStdio(term.DefaultTerm.Stdio())); err != nil { + return "", fmt.Errorf("failed to select provider: %w", err) + } + track.Evt("ProviderSelected", P("provider", optionValue)) + var providerID cliClient.ProviderID + err := providerID.Set(optionValue) + if err != nil { + return "", err + } + return providerID, nil +} diff --git a/src/cmd/cli/command/globals.go b/src/cmd/cli/command/globals.go index 7e5da8feb..98ddf5601 100644 --- a/src/cmd/cli/command/globals.go +++ b/src/cmd/cli/command/globals.go @@ -95,7 +95,7 @@ var global GlobalConfig = GlobalConfig{ HasTty: term.IsTerminal(), HideUpdate: false, NonInteractive: !term.IsTerminal(), - Stack: stacks.StackParameters{Provider: cliClient.ProviderAuto, Mode: modes.ModeUnspecified}, + Stack: stacks.StackParameters{Name: "", Provider: cliClient.ProviderAuto, Mode: modes.ModeUnspecified}, SourcePlatform: migrate.SourcePlatformUnspecified, // default to auto-detecting the source platform Verbose: false, } diff --git a/src/cmd/cli/command/stack.go b/src/cmd/cli/command/stack.go index 6ee787634..13270b927 100644 --- a/src/cmd/cli/command/stack.go +++ b/src/cmd/cli/command/stack.go @@ -87,7 +87,24 @@ func makeStackListCmd() *cobra.Command { RunE: func(cmd *cobra.Command, args []string) error { jsonMode, _ := cmd.Flags().GetBool("json") - stacks, err := stacks.List() + wd, err := os.Getwd() + if err != nil { + return err + } + + ctx := cmd.Context() + loader := configureLoader(cmd) + projectName, err := loader.LoadProjectName(ctx) + if err != nil { + return err + } + + sm, err := stacks.NewManager(global.Client, wd, projectName) + if err != nil { + return err + } + + stacks, err := sm.List(ctx) if err != nil { return err } @@ -110,7 +127,8 @@ func makeStackListCmd() *cobra.Command { return err } - return term.Table(stacks, "Name", "Provider", "Region", "Mode") + columns := []string{"Name", "Provider", "Region", "Mode", "DeployedAt"} + return term.Table(stacks, columns...) }, } stackListCmd.Flags().Bool("json", false, "Output in JSON format") diff --git a/src/cmd/cli/command/stack_test.go b/src/cmd/cli/command/stack_test.go index e1dceca4d..571ab26ba 100644 --- a/src/cmd/cli/command/stack_test.go +++ b/src/cmd/cli/command/stack_test.go @@ -2,12 +2,15 @@ package command import ( "bytes" + "os" "testing" cliClient "github.com/DefangLabs/defang/src/pkg/cli/client" "github.com/DefangLabs/defang/src/pkg/modes" "github.com/DefangLabs/defang/src/pkg/stacks" "github.com/DefangLabs/defang/src/pkg/term" + defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1" + "github.com/spf13/cobra" "github.com/stretchr/testify/assert" ) @@ -25,8 +28,35 @@ func MockTerm(t *testing.T, stdout *bytes.Buffer, stdin *bytes.Reader) { } func TestStackListCmd(t *testing.T) { + // Save original RootCmd and restore after test + origRootCmd := RootCmd + origClient := global.Client + defer func() { + RootCmd = origRootCmd + global.Client = origClient + }() + + // Set up a mock client + mockClient := cliClient.GrpcClient{} + mockCtrl := &MockFabricControllerClient{ + canIUseResponse: defangv1.CanIUseResponse{}, + } + mockClient.SetClient(mockCtrl) + global.Client = &mockClient + + // Set up a fake RootCmd with required flags + RootCmd = &cobra.Command{Use: "defang"} + RootCmd.PersistentFlags().StringVarP(&global.Stack.Name, "stack", "s", global.Stack.Name, "stack name") + RootCmd.PersistentFlags().VarP(&global.Stack.Provider, "provider", "P", "provider") + RootCmd.PersistentFlags().StringP("project-name", "p", "", "project name") + RootCmd.PersistentFlags().StringArrayP("file", "f", []string{}, "compose file path(s)") + + // Create stackListCmd with manual RunE to avoid configureLoader call during test var stackListCmd = makeStackListCmd() + // Add stackListCmd as a child of RootCmd + RootCmd.AddCommand(stackListCmd) + tests := []struct { name string stacks []stacks.StackParameters @@ -53,15 +83,23 @@ func TestStackListCmd(t *testing.T) { Mode: modes.ModeBalanced, }, }, - expectOutput: "NAME PROVIDER REGION MODE\n" + - "teststack1 aws us-west-2 AFFORDABLE \n" + - "teststack2 gcp us-central1 BALANCED \n", + expectOutput: "NAME PROVIDER REGION MODE DEPLOYEDAT\n" + + "teststack1 aws us-west-2 AFFORDABLE 0001-01-01 00:00:00 +0000 UTC \n" + + "teststack2 gcp us-central1 BALANCED 0001-01-01 00:00:00 +0000 UTC \n", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Setup stacks t.Chdir(t.TempDir()) + // create a compose file so stackListCmd doesn't error out + os.WriteFile( + "compose.yaml", + []byte(`services: + web: + image: nginx`), + os.FileMode(0644), + ) for _, stack := range tt.stacks { stacks.Create(stack) } @@ -70,7 +108,8 @@ func TestStackListCmd(t *testing.T) { mockStdin := bytes.NewReader([]byte{}) MockTerm(t, buffer, mockStdin) - err := stackListCmd.RunE(stackListCmd, []string{}) + RootCmd.SetArgs([]string{"list"}) + err := RootCmd.Execute() assert.NoError(t, err) assert.Equal(t, tt.expectOutput, buffer.String()) }) @@ -126,3 +165,36 @@ func TestNonInteractiveStackNewCmd(t *testing.T) { }) } } + +func TestLoadParameters(t *testing.T) { + params := map[string]string{ + "DEFANG_PROVIDER": "aws", + "AWS_REGION": "us-west-2", + "AWS_PROFILE": "default", + "DEFANG_MODE": "AFFORDABLE", + } + + // Clear any existing env vars that might interfere with the test + os.Unsetenv("DEFANG_PROVIDER") + os.Unsetenv("AWS_REGION") + os.Unsetenv("AWS_PROFILE") + os.Unsetenv("DEFANG_MODE") + + defer func() { + // Clean up environment variables after test + os.Unsetenv("DEFANG_PROVIDER") + os.Unsetenv("AWS_REGION") + os.Unsetenv("AWS_PROFILE") + os.Unsetenv("DEFANG_MODE") + }() + + err := stacks.LoadParameters(params, true) + if err != nil { + t.Fatalf("LoadParameters() error = %v", err) + } + + assert.Equal(t, "aws", os.Getenv("DEFANG_PROVIDER")) + assert.Equal(t, "us-west-2", os.Getenv("AWS_REGION")) + assert.Equal(t, "default", os.Getenv("AWS_PROFILE")) + assert.Equal(t, "AFFORDABLE", os.Getenv("DEFANG_MODE")) +} diff --git a/src/cmd/cli/command/workspace_test.go b/src/cmd/cli/command/workspace_test.go index f83cabd7c..e79bb1625 100644 --- a/src/cmd/cli/command/workspace_test.go +++ b/src/cmd/cli/command/workspace_test.go @@ -57,6 +57,9 @@ func TestWorkspaceListJSON(t *testing.T) { oldGlobal := global t.Cleanup(func() { global = oldGlobal }) + // Reset stack name to prevent loading stack files + global.Stack.Name = "" + if err := testCommand([]string{"workspace", "ls", "--json", "--non-interactive"}, clusterURL); err != nil { t.Fatalf("workspace ls --json failed: %v", err) } @@ -89,6 +92,9 @@ func TestWorkspaceListVerboseTable(t *testing.T) { oldGlobal := global t.Cleanup(func() { global = oldGlobal }) + // Reset stack name to prevent loading stack files + global.Stack.Name = "" + if err := testCommand([]string{"workspace", "ls", "--verbose", "--json=false", "--non-interactive"}, clusterURL); err != nil { t.Fatalf("workspace ls --verbose failed: %v", err) } diff --git a/src/pkg/agent/tools/default_tool_cli.go b/src/pkg/agent/tools/default_tool_cli.go index 0e13e6ba4..22a40c97f 100644 --- a/src/pkg/agent/tools/default_tool_cli.go +++ b/src/pkg/agent/tools/default_tool_cli.go @@ -27,8 +27,8 @@ type StackConfig struct { type DefaultToolCLI struct{} -func (DefaultToolCLI) CanIUseProvider(ctx context.Context, client *cliClient.GrpcClient, projectName, stackName string, provider cliClient.Provider, serviceCount int) error { - return cliClient.CanIUseProvider(ctx, client, provider, projectName, stackName, serviceCount) +func (DefaultToolCLI) CanIUseProvider(ctx context.Context, client *cliClient.GrpcClient, provider cliClient.Provider, projectName string, serviceCount int) error { + return cliClient.CanIUseProvider(ctx, client, provider, projectName, serviceCount) } func (DefaultToolCLI) ConfigSet(ctx context.Context, projectName string, provider cliClient.Provider, name, value string) error { diff --git a/src/pkg/agent/tools/deploy.go b/src/pkg/agent/tools/deploy.go index 5518af254..12527fc27 100644 --- a/src/pkg/agent/tools/deploy.go +++ b/src/pkg/agent/tools/deploy.go @@ -22,7 +22,7 @@ type DeployParams struct { common.LoaderParams } -func HandleDeployTool(ctx context.Context, loader cliClient.ProjectLoader, params DeployParams, cli CLIInterface, ec elicitations.Controller, config StackConfig) (string, error) { +func HandleDeployTool(ctx context.Context, loader cliClient.Loader, params DeployParams, cli CLIInterface, ec elicitations.Controller, config StackConfig) (string, error) { term.Debug("Function invoked: loader.LoadProject") project, err := cli.LoadProject(ctx, loader) if err != nil { @@ -44,14 +44,17 @@ func HandleDeployTool(ctx context.Context, loader cliClient.ProjectLoader, param } } - sm := stacks.NewManager(params.WorkingDirectory) + sm, err := stacks.NewManager(client, params.WorkingDirectory, params.ProjectName) + if err != nil { + return "", fmt.Errorf("failed to create stack manager: %w", err) + } pp := NewProviderPreparer(cli, ec, client, sm) _, provider, err := pp.SetupProvider(ctx, config.Stack) if err != nil { return "", fmt.Errorf("failed to setup provider: %w", err) } - err = cli.CanIUseProvider(ctx, client, project.Name, config.Stack.Name, provider, len(project.Services)) + err = cli.CanIUseProvider(ctx, client, provider, project.Name, len(project.Services)) if err != nil { return "", fmt.Errorf("failed to use provider: %w", err) } diff --git a/src/pkg/agent/tools/deploy_test.go b/src/pkg/agent/tools/deploy_test.go index d81fab71d..946e11d51 100644 --- a/src/pkg/agent/tools/deploy_test.go +++ b/src/pkg/agent/tools/deploy_test.go @@ -77,7 +77,7 @@ func (m *MockDeployCLI) LoadProject(ctx context.Context, loader client.Loader) ( return m.Project, nil } -func (m *MockDeployCLI) CanIUseProvider(ctx context.Context, client *client.GrpcClient, projectName, stackName string, provider client.Provider, serviceCount int) error { +func (m *MockDeployCLI) CanIUseProvider(ctx context.Context, client *client.GrpcClient, provider client.Provider, projectName string, serviceCount int) error { m.CallLog = append(m.CallLog, "CanIUseProvider") return nil } diff --git a/src/pkg/agent/tools/destroy.go b/src/pkg/agent/tools/destroy.go index b0af31a84..4aaa99bbc 100644 --- a/src/pkg/agent/tools/destroy.go +++ b/src/pkg/agent/tools/destroy.go @@ -17,14 +17,17 @@ type DestroyParams struct { common.LoaderParams } -func HandleDestroyTool(ctx context.Context, loader cliClient.ProjectLoader, params DestroyParams, cli CLIInterface, ec elicitations.Controller, config StackConfig) (string, error) { +func HandleDestroyTool(ctx context.Context, loader cliClient.Loader, params DestroyParams, cli CLIInterface, ec elicitations.Controller, config StackConfig) (string, error) { term.Debug("Function invoked: cli.Connect") client, err := cli.Connect(ctx, config.Cluster) if err != nil { return "", fmt.Errorf("could not connect: %w", err) } - sm := stacks.NewManager(params.WorkingDirectory) + sm, err := stacks.NewManager(client, params.WorkingDirectory, params.ProjectName) + if err != nil { + return "", fmt.Errorf("failed to create stack manager: %w", err) + } pp := NewProviderPreparer(cli, ec, client, sm) _, provider, err := pp.SetupProvider(ctx, config.Stack) if err != nil { @@ -36,7 +39,7 @@ func HandleDestroyTool(ctx context.Context, loader cliClient.ProjectLoader, para return "", fmt.Errorf("failed to load project name: %w", err) } - err = cli.CanIUseProvider(ctx, client, projectName, config.Stack.Name, provider, 0) + err = cli.CanIUseProvider(ctx, client, provider, projectName, 0) if err != nil { return "", fmt.Errorf("failed to use provider: %w", err) } diff --git a/src/pkg/agent/tools/destroy_test.go b/src/pkg/agent/tools/destroy_test.go index 3e5fc610d..5c850eed1 100644 --- a/src/pkg/agent/tools/destroy_test.go +++ b/src/pkg/agent/tools/destroy_test.go @@ -57,7 +57,7 @@ func (m *MockDestroyCLI) LoadProjectNameWithFallback(ctx context.Context, loader return m.ProjectName, nil } -func (m *MockDestroyCLI) CanIUseProvider(ctx context.Context, grpcClient *client.GrpcClient, projectName, stackName string, provider client.Provider, serviceCount int) error { +func (m *MockDestroyCLI) CanIUseProvider(ctx context.Context, grpcClient *client.GrpcClient, provider client.Provider, projectName string, serviceCount int) error { m.CallLog = append(m.CallLog, fmt.Sprintf("CanIUseProvider(%s)", projectName)) if m.CanIUseProviderError != nil { return m.CanIUseProviderError diff --git a/src/pkg/agent/tools/estimate.go b/src/pkg/agent/tools/estimate.go index 17bbad552..688302988 100644 --- a/src/pkg/agent/tools/estimate.go +++ b/src/pkg/agent/tools/estimate.go @@ -17,7 +17,7 @@ type EstimateParams struct { Region string `json:"region,omitempty" jsonschema:"description=The region in which to estimate costs."` } -func HandleEstimateTool(ctx context.Context, loader cliClient.ProjectLoader, params EstimateParams, cli CLIInterface, sc StackConfig) (string, error) { +func HandleEstimateTool(ctx context.Context, loader cliClient.Loader, params EstimateParams, cli CLIInterface, sc StackConfig) (string, error) { term.Debug("Function invoked: loader.LoadProject") project, err := cli.LoadProject(ctx, loader) if err != nil { diff --git a/src/pkg/agent/tools/interfaces.go b/src/pkg/agent/tools/interfaces.go index 03ffa6bc5..5afa3a998 100644 --- a/src/pkg/agent/tools/interfaces.go +++ b/src/pkg/agent/tools/interfaces.go @@ -12,7 +12,7 @@ import ( ) type CLIInterface interface { - CanIUseProvider(ctx context.Context, client *cliClient.GrpcClient, projectName, stackName string, provider cliClient.Provider, serviceCount int) error + CanIUseProvider(ctx context.Context, client *cliClient.GrpcClient, provider cliClient.Provider, projectName string, serviceCount int) error ComposeDown(ctx context.Context, projectName string, client *cliClient.GrpcClient, provider cliClient.Provider) (string, error) ComposeUp(ctx context.Context, client *cliClient.GrpcClient, provider cliClient.Provider, params cli.ComposeUpParams) (*defangv1.DeployResponse, *compose.Project, error) ConfigDelete(ctx context.Context, projectName string, provider cliClient.Provider, name string) error diff --git a/src/pkg/agent/tools/listConfig.go b/src/pkg/agent/tools/listConfig.go index cd2688975..7f05be5dc 100644 --- a/src/pkg/agent/tools/listConfig.go +++ b/src/pkg/agent/tools/listConfig.go @@ -17,14 +17,18 @@ type ListConfigParams struct { } // HandleListConfigTool handles the list config tool logic -func HandleListConfigTool(ctx context.Context, loader cliClient.ProjectLoader, params ListConfigParams, cli CLIInterface, ec elicitations.Controller, sc StackConfig) (string, error) { +func HandleListConfigTool(ctx context.Context, loader cliClient.Loader, params ListConfigParams, cli CLIInterface, ec elicitations.Controller, sc StackConfig) (string, error) { term.Debug("Function invoked: cli.Connect") client, err := cli.Connect(ctx, sc.Cluster) if err != nil { return "", fmt.Errorf("Could not connect: %w", err) } - pp := NewProviderPreparer(cli, ec, client, stacks.NewManager(params.WorkingDirectory)) + sm, err := stacks.NewManager(client, params.WorkingDirectory, params.ProjectName) + if err != nil { + return "", fmt.Errorf("failed to create stack manager: %w", err) + } + pp := NewProviderPreparer(cli, ec, client, sm) _, provider, err := pp.SetupProvider(ctx, sc.Stack) if err != nil { return "", fmt.Errorf("failed to setup provider: %w", err) diff --git a/src/pkg/agent/tools/logs.go b/src/pkg/agent/tools/logs.go index 3640f5e72..6fc7c99ee 100644 --- a/src/pkg/agent/tools/logs.go +++ b/src/pkg/agent/tools/logs.go @@ -22,7 +22,7 @@ type LogsParams struct { Until string `json:"until,omitempty" jsonschema:"description=Optional: Retrieve logs written before this time. Format as RFC3339 or duration (e.g., '2023-10-01T15:04:05Z' or '1h')."` } -func HandleLogsTool(ctx context.Context, loader cliClient.ProjectLoader, params LogsParams, cli CLIInterface, ec elicitations.Controller, config StackConfig) (string, error) { +func HandleLogsTool(ctx context.Context, loader cliClient.Loader, params LogsParams, cli CLIInterface, ec elicitations.Controller, config StackConfig) (string, error) { var sinceTime, untilTime time.Time var err error now := time.Now() @@ -45,7 +45,10 @@ func HandleLogsTool(ctx context.Context, loader cliClient.ProjectLoader, params return "", fmt.Errorf("could not connect: %w", err) } - sm := stacks.NewManager(params.WorkingDirectory) + sm, err := stacks.NewManager(client, params.WorkingDirectory, params.ProjectName) + if err != nil { + return "", fmt.Errorf("failed to create stack manager: %w", err) + } pp := NewProviderPreparer(cli, ec, client, sm) _, provider, err := pp.SetupProvider(ctx, config.Stack) if err != nil { @@ -59,7 +62,7 @@ func HandleLogsTool(ctx context.Context, loader cliClient.ProjectLoader, params } term.Debug("Project name loaded:", projectName) - err = cli.CanIUseProvider(ctx, client, projectName, config.Stack.Name, provider, 0) + err = cli.CanIUseProvider(ctx, client, provider, projectName, 0) if err != nil { return "", fmt.Errorf("failed to use provider: %w", err) } diff --git a/src/pkg/agent/tools/provider.go b/src/pkg/agent/tools/provider.go index c9f089649..ed22bb8e9 100644 --- a/src/pkg/agent/tools/provider.go +++ b/src/pkg/agent/tools/provider.go @@ -2,7 +2,6 @@ package tools import ( "context" - "errors" "fmt" cliClient "github.com/DefangLabs/defang/src/pkg/cli/client" @@ -37,7 +36,8 @@ func (pp *providerPreparer) SetupProvider(ctx context.Context, stack *stacks.Sta var providerID cliClient.ProviderID var err error if stack.Name == "" { - newStack, err := pp.setupStack(ctx) + selector := stacks.NewSelector(pp.ec, pp.sm) + newStack, err := selector.SelectStack(ctx) if err != nil { return nil, nil, fmt.Errorf("failed to setup stack: %w", err) } @@ -53,53 +53,3 @@ func (pp *providerPreparer) SetupProvider(ctx context.Context, stack *stacks.Sta provider := pp.pc.NewProvider(ctx, providerID, pp.fc, stack.Name) return &providerID, provider, nil } - -func (pp *providerPreparer) selectStack(ctx context.Context, ec elicitations.Controller) (string, error) { - stackList, err := pp.sm.List() - if err != nil { - return "", fmt.Errorf("failed to list stacks: %w", err) - } - - if len(stackList) == 0 { - return CreateNewStack, nil - } - - stackNames := make([]string, 0, len(stackList)+1) - for _, s := range stackList { - stackNames = append(stackNames, s.Name) - } - stackNames = append(stackNames, CreateNewStack) - - selectedStackName, err := ec.RequestEnum(ctx, "Select a stack", "stack", stackNames) - if err != nil { - return "", fmt.Errorf("failed to elicit stack choice: %w", err) - } - - return selectedStackName, nil -} - -func (pp *providerPreparer) setupStack(ctx context.Context) (*stacks.StackParameters, error) { - if !pp.ec.IsSupported() { - return nil, errors.New("your mcp client does not support elicitations, use the 'select_stack' tool to choose a stack") - } - selectedStackName, err := pp.selectStack(ctx, pp.ec) - if err != nil { - return nil, fmt.Errorf("failed to select stack: %w", err) - } - - if selectedStackName == CreateNewStack { - wizard := stacks.NewWizard(pp.ec) - params, err := wizard.CollectParameters(ctx) - if err != nil { - return nil, fmt.Errorf("failed to collect stack parameters: %w", err) - } - _, err = pp.sm.Create(*params) - if err != nil { - return nil, fmt.Errorf("failed to create stack: %w", err) - } - - selectedStackName = params.Name - } - - return pp.sm.Load(selectedStackName) -} diff --git a/src/pkg/agent/tools/removeConfig.go b/src/pkg/agent/tools/removeConfig.go index 0d5e279b9..16f6d726d 100644 --- a/src/pkg/agent/tools/removeConfig.go +++ b/src/pkg/agent/tools/removeConfig.go @@ -18,14 +18,17 @@ type RemoveConfigParams struct { } // HandleRemoveConfigTool handles the remove config tool logic -func HandleRemoveConfigTool(ctx context.Context, loader cliClient.ProjectLoader, params RemoveConfigParams, cli CLIInterface, ec elicitations.Controller, sc StackConfig) (string, error) { +func HandleRemoveConfigTool(ctx context.Context, loader cliClient.Loader, params RemoveConfigParams, cli CLIInterface, ec elicitations.Controller, sc StackConfig) (string, error) { term.Debug("Function invoked: cli.Connect") client, err := cli.Connect(ctx, sc.Cluster) if err != nil { return "", fmt.Errorf("Could not connect: %w", err) } - sm := stacks.NewManager(params.WorkingDirectory) + sm, err := stacks.NewManager(client, params.WorkingDirectory, params.ProjectName) + if err != nil { + return "", fmt.Errorf("failed to create stack manager: %w", err) + } pp := NewProviderPreparer(cli, ec, client, sm) _, provider, err := pp.SetupProvider(ctx, sc.Stack) if err != nil { diff --git a/src/pkg/agent/tools/services.go b/src/pkg/agent/tools/services.go index 6c589db03..90508ef4f 100644 --- a/src/pkg/agent/tools/services.go +++ b/src/pkg/agent/tools/services.go @@ -20,14 +20,18 @@ type ServicesParams struct { common.LoaderParams } -func HandleServicesTool(ctx context.Context, loader cliClient.ProjectLoader, params ServicesParams, cli CLIInterface, ec elicitations.Controller, config StackConfig) (string, error) { +func HandleServicesTool(ctx context.Context, loader cliClient.Loader, params ServicesParams, cli CLIInterface, ec elicitations.Controller, config StackConfig) (string, error) { term.Debug("Function invoked: cli.Connect") client, err := cli.Connect(ctx, config.Cluster) if err != nil { return "", fmt.Errorf("could not connect: %w", err) } - pp := NewProviderPreparer(cli, ec, client, stacks.NewManager(params.WorkingDirectory)) + sm, err := stacks.NewManager(client, params.WorkingDirectory, params.ProjectName) + if err != nil { + return "", fmt.Errorf("failed to create stack manager: %w", err) + } + pp := NewProviderPreparer(cli, ec, client, sm) _, provider, err := pp.SetupProvider(ctx, config.Stack) if err != nil { return "", fmt.Errorf("failed to setup provider: %w", err) diff --git a/src/pkg/agent/tools/setConfig.go b/src/pkg/agent/tools/setConfig.go index 0ecc11764..393b5c63e 100644 --- a/src/pkg/agent/tools/setConfig.go +++ b/src/pkg/agent/tools/setConfig.go @@ -18,14 +18,17 @@ type SetConfigParams struct { Value string `json:"value" jsonschema:"required"` } -func HandleSetConfig(ctx context.Context, loader cliClient.ProjectLoader, params SetConfigParams, cli CLIInterface, ec elicitations.Controller, sc StackConfig) (string, error) { +func HandleSetConfig(ctx context.Context, loader cliClient.Loader, params SetConfigParams, cli CLIInterface, ec elicitations.Controller, sc StackConfig) (string, error) { term.Debug("Function invoked: cli.Connect") client, err := cli.Connect(ctx, sc.Cluster) if err != nil { return "", fmt.Errorf("Could not connect: %w", err) } - sm := stacks.NewManager(params.WorkingDirectory) + sm, err := stacks.NewManager(client, params.WorkingDirectory, params.ProjectName) + if err != nil { + return "", fmt.Errorf("failed to create stack manager: %w", err) + } pp := NewProviderPreparer(cli, ec, client, sm) _, provider, err := pp.SetupProvider(ctx, sc.Stack) if err != nil { diff --git a/src/pkg/cli/client/byoc/baseclient.go b/src/pkg/cli/client/byoc/baseclient.go index b45431a1d..16013ea16 100644 --- a/src/pkg/cli/client/byoc/baseclient.go +++ b/src/pkg/cli/client/byoc/baseclient.go @@ -71,6 +71,10 @@ func NewByocBaseClient(tenantName types.TenantNameOrID, backend ProjectBackend, return b } +func (b *ByocBaseClient) GetStackName() string { + return b.PulumiStack +} + func (b *ByocBaseClient) Debug(context.Context, *defangv1.DebugRequest) (*defangv1.DebugResponse, error) { return nil, client.ErrNotImplemented("AI debugging is not yet supported for BYOC") } @@ -114,9 +118,6 @@ func (b *ByocBaseClient) GetProjectDomain(projectName, zone string) string { return "" // no project name => no custom domain } domain := dns.Normalize(zone) - if hasStack, ok := b.projectBackend.(HasStackSupport); ok { - domain = hasStack.GetStackName() + "." + domain - } return domain } diff --git a/src/pkg/cli/client/byoc/gcp/byoc.go b/src/pkg/cli/client/byoc/gcp/byoc.go index d62383c1a..40de9c98d 100644 --- a/src/pkg/cli/client/byoc/gcp/byoc.go +++ b/src/pkg/cli/client/byoc/gcp/byoc.go @@ -303,7 +303,7 @@ func (b *ByocGcp) BootstrapList(ctx context.Context, _allRegions bool) (iter.Seq func (b *ByocGcp) AccountInfo(ctx context.Context) (*client.AccountInfo, error) { projectId := getGcpProjectID() if projectId == "" { - return nil, errors.New("GCP_PROJECT_ID or CLOUDSDK_CORE_PROJECT must be set for GCP projects") + return nil, errors.New("GCP_PROJECT_ID or CLOUDSDK_CORE_PROJECT must be set for GCP projects; use 'gcloud projects list' to see available project ids") } // check whether the ADC is logged in by trying to get the current account email diff --git a/src/pkg/cli/client/caniuse.go b/src/pkg/cli/client/caniuse.go index de6e48b32..6f9c2a6d2 100644 --- a/src/pkg/cli/client/caniuse.go +++ b/src/pkg/cli/client/caniuse.go @@ -7,7 +7,7 @@ import ( defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1" ) -func CanIUseProvider(ctx context.Context, client FabricClient, provider Provider, projectName, stack string, serviceCount int) error { +func CanIUseProvider(ctx context.Context, client FabricClient, provider Provider, projectName string, serviceCount int) error { info, err := provider.AccountInfo(ctx) if err != nil { return err @@ -19,7 +19,7 @@ func CanIUseProvider(ctx context.Context, client FabricClient, provider Provider ProviderAccountId: info.AccountID, Region: info.Region, ServiceCount: int32(serviceCount), // #nosec G115 - service count will not overflow int32 - Stack: stack, + Stack: provider.GetStackName(), } resp, err := client.CanIUse(ctx, &canUseReq) diff --git a/src/pkg/cli/client/client.go b/src/pkg/cli/client/client.go index 2290c2bf0..a547cad99 100644 --- a/src/pkg/cli/client/client.go +++ b/src/pkg/cli/client/client.go @@ -6,14 +6,8 @@ import ( "github.com/DefangLabs/defang/src/pkg/types" defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1" "github.com/DefangLabs/defang/src/protos/io/defang/v1/defangv1connect" - composeTypes "github.com/compose-spec/compose-go/v2/types" ) -type ProjectLoader interface { - LoadProjectName(context.Context) (string, error) - LoadProject(context.Context) (*composeTypes.Project, error) -} - type FabricClient interface { AgreeToS(context.Context) error CanIUse(context.Context, *defangv1.CanIUseRequest) (*defangv1.CanIUseResponse, error) diff --git a/src/pkg/cli/client/mock.go b/src/pkg/cli/client/mock.go index c7a3c412b..601e69a9a 100644 --- a/src/pkg/cli/client/mock.go +++ b/src/pkg/cli/client/mock.go @@ -171,3 +171,7 @@ func (m MockLoader) LoadProject(ctx context.Context) (*composeTypes.Project, err func (m MockLoader) LoadProjectName(ctx context.Context) (string, error) { return m.Project.Name, m.Error } + +func (m MockLoader) OutsideWorkingDirectory() bool { + return false +} diff --git a/src/pkg/cli/client/playground.go b/src/pkg/cli/client/playground.go index e52d71cee..1fe3ecb84 100644 --- a/src/pkg/cli/client/playground.go +++ b/src/pkg/cli/client/playground.go @@ -18,10 +18,18 @@ type PlaygroundProvider struct { FabricClient RetryDelayer shardDomain string + PulumiStack string } var _ Provider = (*PlaygroundProvider)(nil) +func (g *PlaygroundProvider) GetStackName() string { + if g.PulumiStack == "" { + return "beta" + } + return g.PulumiStack +} + func (g *PlaygroundProvider) Deploy(ctx context.Context, req *defangv1.DeployRequest) (*defangv1.DeployResponse, error) { if os.Getenv("DEFANG_PULUMI_DIR") != "" { return nil, errors.New("DEFANG_PULUMI_DIR is set, but not supported by the Playground provider") diff --git a/src/pkg/cli/client/provider.go b/src/pkg/cli/client/provider.go index 5a8827b48..27798baed 100644 --- a/src/pkg/cli/client/provider.go +++ b/src/pkg/cli/client/provider.go @@ -66,11 +66,13 @@ type Provider interface { SetUpCD(context.Context) error Subscribe(context.Context, *defangv1.SubscribeRequest) (ServerStream[defangv1.SubscribeResponse], error) TearDownCD(context.Context) error + GetStackName() string } type Loader interface { LoadProject(context.Context) (*composeTypes.Project, error) LoadProjectName(context.Context) (string, error) + OutsideWorkingDirectory() bool } type RetryDelayer struct { diff --git a/src/pkg/cli/compose/loader.go b/src/pkg/cli/compose/loader.go index c17d5cd03..7a5f5abd3 100644 --- a/src/pkg/cli/compose/loader.go +++ b/src/pkg/cli/compose/loader.go @@ -91,6 +91,11 @@ func (l *Loader) LoadProject(ctx context.Context) (*Project, error) { return l.loadProject(ctx, false) } +func (l *Loader) OutsideWorkingDirectory() bool { + // if --project-name is provider, we assume we are outside the project's working directory + return l.options.ProjectName != "" +} + func (l *Loader) loadProject(ctx context.Context, suppressWarn bool) (*Project, error) { if l.cached != nil { return l.cached, nil diff --git a/src/pkg/stacks/manager.go b/src/pkg/stacks/manager.go index d3f89c9aa..be1de685f 100644 --- a/src/pkg/stacks/manager.go +++ b/src/pkg/stacks/manager.go @@ -1,37 +1,206 @@ package stacks +import ( + "context" + "errors" + "fmt" + "os" + "path/filepath" + "slices" + "time" + + cliClient "github.com/DefangLabs/defang/src/pkg/cli/client" + "github.com/DefangLabs/defang/src/pkg/modes" + defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1" +) + type Manager interface { - List() ([]StackListItem, error) + List(ctx context.Context) ([]StackListItem, error) Load(name string) (*StackParameters, error) + LoadParameters(params map[string]string, overload bool) error Create(params StackParameters) (string, error) } +type DeploymentLister interface { + ListDeployments(ctx context.Context, req *defangv1.ListDeploymentsRequest) (*defangv1.ListDeploymentsResponse, error) +} + type manager struct { + fabric DeploymentLister + targetDirectory string + projectName string + outside bool workingDirectory string } -func NewManager(workingDirectory string) *manager { +func NewManager(fabric DeploymentLister, targetDirectory string, projectName string) (*manager, error) { + workingDirectory, err := os.Getwd() + if err != nil { + return nil, fmt.Errorf("failed to get working directory: %w", err) + } + var outside bool + var absTargetDirectory string + if targetDirectory == "" { + outside = true + absTargetDirectory = "" + } else { + // abs path for targetDirectory + var err error + absTargetDirectory, err = filepath.Abs(targetDirectory) + if err != nil { + return nil, fmt.Errorf("failed to get absolute path for target directory: %w", err) + } + outside = workingDirectory != absTargetDirectory + } return &manager{ + fabric: fabric, + targetDirectory: absTargetDirectory, + projectName: projectName, + outside: outside, workingDirectory: workingDirectory, + }, nil +} + +func (sm *manager) List(ctx context.Context) ([]StackListItem, error) { + remoteStacks, err := sm.ListRemote(ctx) + if err != nil { + return nil, fmt.Errorf("failed to list remote stacks: %w", err) + } + localStacks, err := sm.ListLocal() + if err != nil { + var outsideErr *OutsideError + if !errors.As(err, &outsideErr) { + return nil, fmt.Errorf("failed to list local stacks: %w", err) + } + } + // Merge remote and local stacks into a single list of type StackOption, + // prefer remote if both exist, so we can show last deployed time + stackMap := make(map[string]StackListItem) + for _, local := range localStacks { + stackMap[local.Name] = StackListItem{ + Name: local.Name, + Provider: local.Provider, + Region: local.Region, + Mode: local.Mode, + AWSProfile: local.AWSProfile, + GCPProjectID: local.GCPProjectID, + DeployedAt: time.Time{}, // No deployed time for local + } + } + for _, remote := range remoteStacks { + stackMap[remote.StackParameters.Name] = StackListItem{ + Name: remote.StackParameters.Name, + Provider: remote.StackParameters.Provider.String(), + Region: remote.StackParameters.Region, + Mode: remote.StackParameters.Mode.String(), + AWSProfile: remote.StackParameters.AWSProfile, + GCPProjectID: remote.StackParameters.GCPProjectID, + DeployedAt: remote.DeployedAt.Local(), + } + } + + stackList := make([]StackListItem, 0, len(stackMap)) + for _, stack := range stackMap { + stackList = append(stackList, stack) + } + // sort stacks by name asc + // sort stacks by name asc + slices.SortFunc(stackList, func(a, b StackListItem) int { + if a.Name < b.Name { + return -1 + } + if a.Name > b.Name { + return 1 + } + return 0 + }) + + return stackList, nil +} + +func (sm *manager) ListLocal() ([]StackListItem, error) { + if sm.outside { + return nil, &OutsideError{TargetDirectory: sm.targetDirectory, WorkingDirectory: sm.workingDirectory} } + return ListInDirectory(sm.targetDirectory) +} + +type RemoteStack struct { + StackParameters + DeployedAt time.Time } -func (sm *manager) List() ([]StackListItem, error) { - return ListInDirectory(sm.workingDirectory) +func (sm *manager) ListRemote(ctx context.Context) ([]RemoteStack, error) { + resp, err := sm.fabric.ListDeployments(ctx, &defangv1.ListDeploymentsRequest{ + Project: sm.projectName, + }) + if err != nil { + return nil, fmt.Errorf("failed to list deployments: %w", err) + } + deployments := resp.GetDeployments() + stackMap := make(map[string]RemoteStack) + for _, deployment := range deployments { + stackName := deployment.GetStack() + if stackName == "" { + stackName = "beta" + } + var providerID cliClient.ProviderID + providerID.SetValue(deployment.GetProvider()) + // avoid overwriting existing entries, deployments are already sorted by deployed_at desc + if _, exists := stackMap[stackName]; !exists { + var deployedAt time.Time + if ts := deployment.GetTimestamp(); ts != nil { + deployedAt = ts.AsTime() + } + stackMap[stackName] = RemoteStack{ + StackParameters: StackParameters{ + Name: stackName, + Provider: providerID, + Region: deployment.GetRegion(), + Mode: modes.Mode(deployment.GetMode()), + }, + DeployedAt: deployedAt, + } + } + } + stackParams := make([]RemoteStack, 0, len(stackMap)) + for _, params := range stackMap { + stackParams = append(stackParams, params) + } + return stackParams, nil +} + +type OutsideError struct { + TargetDirectory string + WorkingDirectory string +} + +func (e *OutsideError) Error() string { + return fmt.Sprintf("operation not allowed: target directory (%s) is different from working directory (%s)", e.TargetDirectory, e.WorkingDirectory) } func (sm *manager) Load(name string) (*StackParameters, error) { - params, err := ReadInDirectory(sm.workingDirectory, name) + if sm.outside { + return nil, &OutsideError{TargetDirectory: sm.targetDirectory, WorkingDirectory: sm.workingDirectory} + } + params, err := ReadInDirectory(sm.targetDirectory, name) if err != nil { return nil, err } - err = LoadInDirectory(sm.workingDirectory, name) + err = LoadInDirectory(sm.targetDirectory, name) if err != nil { return nil, err } return params, nil } +func (sm *manager) LoadParameters(params map[string]string, overload bool) error { + return LoadParameters(params, overload) +} + func (sm *manager) Create(params StackParameters) (string, error) { - return CreateInDirectory(sm.workingDirectory, params) + if sm.outside { + return "", &OutsideError{TargetDirectory: sm.targetDirectory, WorkingDirectory: sm.workingDirectory} + } + return CreateInDirectory(sm.targetDirectory, params) } diff --git a/src/pkg/stacks/manager_test.go b/src/pkg/stacks/manager_test.go index afcc83627..486299d86 100644 --- a/src/pkg/stacks/manager_test.go +++ b/src/pkg/stacks/manager_test.go @@ -1,17 +1,42 @@ package stacks import ( + "context" + "errors" "os" "path/filepath" + "strings" "testing" + "time" "github.com/DefangLabs/defang/src/pkg/cli/client" "github.com/DefangLabs/defang/src/pkg/modes" + defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1" + "google.golang.org/protobuf/types/known/timestamppb" ) +// mockFabricClient implements FabricClient interface for testing +type mockFabricClient struct { + deployments []*defangv1.Deployment + listErr error +} + +func (m *mockFabricClient) ListDeployments(ctx context.Context, req *defangv1.ListDeploymentsRequest) (*defangv1.ListDeploymentsResponse, error) { + if m.listErr != nil { + return nil, m.listErr + } + return &defangv1.ListDeploymentsResponse{ + Deployments: m.deployments, + }, nil +} + func TestNewManager(t *testing.T) { workingDir := "/tmp/test-dir" - manager := NewManager(workingDir) + mockClient := &mockFabricClient{} + manager, err := NewManager(mockClient, workingDir, "test-project") + if err != nil { + t.Fatalf("NewManager failed: %v", err) + } if manager == nil { t.Error("NewManager should not return nil") @@ -22,10 +47,17 @@ func TestManager_CreateListLoad(t *testing.T) { // Create a temporary directory for testing tmpDir := t.TempDir() - manager := NewManager(tmpDir) + // Change to temp directory so working directory matches target directory + t.Chdir(tmpDir) + + mockClient := &mockFabricClient{} + manager, err := NewManager(mockClient, tmpDir, "test-project") + if err != nil { + t.Fatalf("NewManager failed: %v", err) + } // Test that listing returns empty when no stacks exist - stacks, err := manager.List() + stacks, err := manager.List(context.Background()) if err != nil { t.Fatalf("List() should not error on empty directory: %v", err) } @@ -58,7 +90,7 @@ func TestManager_CreateListLoad(t *testing.T) { } // Test listing after creating a stack - stacks, err = manager.List() + stacks, err = manager.List(context.Background()) if err != nil { t.Fatalf("List() failed: %v", err) } @@ -104,7 +136,14 @@ func TestManager_CreateGCPStack(t *testing.T) { // Create a temporary directory for testing tmpDir := t.TempDir() - manager := NewManager(tmpDir) + // Change to temp directory so working directory matches target directory + t.Chdir(tmpDir) + + mockClient := &mockFabricClient{} + manager, err := NewManager(mockClient, tmpDir, "test-project") + if err != nil { + t.Fatalf("NewManager failed: %v", err) + } // Test creating a GCP stack params := StackParameters{ @@ -145,7 +184,14 @@ func TestManager_CreateMultipleStacks(t *testing.T) { // Create a temporary directory for testing tmpDir := t.TempDir() - manager := NewManager(tmpDir) + // Change to temp directory so working directory matches target directory + t.Chdir(tmpDir) + + mockClient := &mockFabricClient{} + manager, err := NewManager(mockClient, tmpDir, "test-project") + if err != nil { + t.Fatalf("NewManager failed: %v", err) + } // Create multiple stacks stacks := []StackParameters{ @@ -173,14 +219,14 @@ func TestManager_CreateMultipleStacks(t *testing.T) { // Create all stacks for _, params := range stacks { - _, err := manager.Create(params) + _, err = manager.Create(params) if err != nil { t.Fatalf("Create() failed for stack %s: %v", params.Name, err) } } // List all stacks - listedStacks, err := manager.List() + listedStacks, err := manager.List(context.Background()) if err != nil { t.Fatalf("List() failed: %v", err) } @@ -208,10 +254,14 @@ func TestManager_LoadNonexistentStack(t *testing.T) { // Create a temporary directory for testing tmpDir := t.TempDir() - manager := NewManager(tmpDir) + mockClient := &mockFabricClient{} + manager, err := NewManager(mockClient, tmpDir, "test-project") + if err != nil { + t.Fatalf("NewManager failed: %v", err) + } // Try to load a stack that doesn't exist - _, err := manager.Load("nonexistent") + _, err = manager.Load("nonexistent") if err == nil { t.Error("Load() should return error for nonexistent stack") } @@ -221,7 +271,11 @@ func TestManager_CreateInvalidStackName(t *testing.T) { // Create a temporary directory for testing tmpDir := t.TempDir() - manager := NewManager(tmpDir) + mockClient := &mockFabricClient{} + manager, err := NewManager(mockClient, tmpDir, "test-project") + if err != nil { + t.Fatalf("NewManager failed: %v", err) + } // Test with empty name params := StackParameters{ @@ -230,7 +284,7 @@ func TestManager_CreateInvalidStackName(t *testing.T) { Region: "us-east-1", } - _, err := manager.Create(params) + _, err = manager.Create(params) if err == nil { t.Error("Create() should return error for empty stack name") } @@ -254,7 +308,14 @@ func TestManager_CreateDuplicateStack(t *testing.T) { // Create a temporary directory for testing tmpDir := t.TempDir() - manager := NewManager(tmpDir) + // Change to temp directory so working directory matches target directory + t.Chdir(tmpDir) + + mockClient := &mockFabricClient{} + manager, err := NewManager(mockClient, tmpDir, "test-project") + if err != nil { + t.Fatalf("NewManager failed: %v", err) + } params := StackParameters{ Name: "duplicatestack", @@ -264,7 +325,7 @@ func TestManager_CreateDuplicateStack(t *testing.T) { } // Create the first stack - _, err := manager.Create(params) + _, err = manager.Create(params) if err != nil { t.Fatalf("First Create() failed: %v", err) } @@ -275,3 +336,417 @@ func TestManager_CreateDuplicateStack(t *testing.T) { t.Error("Create() should return error for duplicate stack name") } } + +func TestManager_ListRemote(t *testing.T) { + tmpDir := t.TempDir() + + deployedAt := time.Now() + mockClient := &mockFabricClient{ + deployments: []*defangv1.Deployment{ + { + Stack: "remotestack1", + Provider: defangv1.Provider_AWS, + Region: "us-east-1", + Timestamp: timestamppb.New(deployedAt), + }, + { + Stack: "remotestack2", + Provider: defangv1.Provider_GCP, + Region: "us-central1", + Timestamp: timestamppb.New(deployedAt.Add(-time.Hour)), + }, + }, + } + + manager, err := NewManager(mockClient, tmpDir, "test-project") + if err != nil { + t.Fatalf("NewManager failed: %v", err) + } + + remoteStacks, err := manager.ListRemote(context.Background()) + if err != nil { + t.Fatalf("ListRemote() failed: %v", err) + } + + if len(remoteStacks) != 2 { + t.Errorf("Expected 2 remote stacks, got %d", len(remoteStacks)) + } + + // Check first remote stack + if remoteStacks[0].Name != "remotestack1" && remoteStacks[1].Name != "remotestack1" { + t.Error("Expected to find remotestack1") + } + + // Check second remote stack + if remoteStacks[0].Name != "remotestack2" && remoteStacks[1].Name != "remotestack2" { + t.Error("Expected to find remotestack2") + } + + // Verify deployed time is set + for _, stack := range remoteStacks { + if stack.DeployedAt.IsZero() { + t.Errorf("Expected DeployedAt to be set for stack %s", stack.Name) + } + } +} + +func TestManager_ListRemoteError(t *testing.T) { + tmpDir := t.TempDir() + + mockClient := &mockFabricClient{ + listErr: errors.New("network error"), + } + + manager, err := NewManager(mockClient, tmpDir, "test-project") + if err != nil { + t.Fatalf("NewManager failed: %v", err) + } + + _, err = manager.ListRemote(context.Background()) + if err == nil { + t.Error("ListRemote() should return error when fabric client fails") + } +} + +func TestManager_ListMerged(t *testing.T) { + tmpDir := t.TempDir() + + // Change to temp directory so working directory matches target directory + t.Chdir(tmpDir) + + deployedAt := time.Now() + mockClient := &mockFabricClient{ + deployments: []*defangv1.Deployment{ + { + Stack: "sharedstack", + Provider: defangv1.Provider_AWS, + Region: "us-east-1", + Timestamp: timestamppb.New(deployedAt), + }, + { + Stack: "remoteonlystack", + Provider: defangv1.Provider_GCP, + Region: "us-central1", + Timestamp: timestamppb.New(deployedAt), + }, + }, + } + + manager, err := NewManager(mockClient, tmpDir, "test-project") + if err != nil { + t.Fatalf("NewManager failed: %v", err) + } + + // Create a local stack that exists remotely too + localParams := StackParameters{ + Name: "sharedstack", + Provider: client.ProviderAWS, + Region: "us-west-2", // Different region locally + AWSProfile: "default", + Mode: modes.ModeAffordable, + } + _, err = manager.Create(localParams) + if err != nil { + t.Fatalf("Create() failed: %v", err) + } + + // Create a local-only stack + localOnlyParams := StackParameters{ + Name: "localonlystack", + Provider: client.ProviderAWS, + Region: "us-west-1", + AWSProfile: "default", + Mode: modes.ModeAffordable, + } + _, err = manager.Create(localOnlyParams) + if err != nil { + t.Fatalf("Create() failed: %v", err) + } + + // List merged stacks + stacks, err := manager.List(context.Background()) + if err != nil { + t.Fatalf("List() failed: %v", err) + } + + if len(stacks) != 3 { + t.Errorf("Expected 3 merged stacks, got %d", len(stacks)) + } + + stackMap := make(map[string]StackListItem) + for _, stack := range stacks { + stackMap[stack.Name] = stack + } + + // Check shared stack prefers remote (should have deployed time and remote region) + sharedStack, exists := stackMap["sharedstack"] + if !exists { + t.Error("Expected to find sharedstack") + } else { + if sharedStack.Region != "us-east-1" { + t.Errorf("Expected shared stack to use remote region us-east-1, got %s", sharedStack.Region) + } + if sharedStack.DeployedAt.IsZero() { + t.Error("Expected shared stack to have deployment time from remote") + } + } + + // Check remote-only stack exists + _, exists = stackMap["remoteonlystack"] + if !exists { + t.Error("Expected to find remoteonlystack") + } + + // Check local-only stack exists and has no deployed time + localOnlyStack, exists := stackMap["localonlystack"] + if !exists { + t.Error("Expected to find localonlystack") + } else { + if !localOnlyStack.DeployedAt.IsZero() { + t.Error("Expected local-only stack to have zero deployed time") + } + } +} + +func TestManager_ListRemoteWithBetaStack(t *testing.T) { + tmpDir := t.TempDir() + + deployedAt := time.Now() + mockClient := &mockFabricClient{ + deployments: []*defangv1.Deployment{ + { + Stack: "", // Empty stack name should default to "beta" + Provider: defangv1.Provider_AWS, + Region: "us-east-1", + Timestamp: timestamppb.New(deployedAt), + }, + }, + } + + manager, err := NewManager(mockClient, tmpDir, "test-project") + if err != nil { + t.Fatalf("NewManager failed: %v", err) + } + + remoteStacks, err := manager.ListRemote(context.Background()) + if err != nil { + t.Fatalf("ListRemote() failed: %v", err) + } + + if len(remoteStacks) != 1 { + t.Errorf("Expected 1 remote stack, got %d", len(remoteStacks)) + } + + if remoteStacks[0].Name != "beta" { + t.Errorf("Expected stack name to be 'beta', got '%s'", remoteStacks[0].Name) + } +} + +func TestManager_ListRemoteDuplicateDeployments(t *testing.T) { + tmpDir := t.TempDir() + + deployedAt := time.Now() + mockClient := &mockFabricClient{ + deployments: []*defangv1.Deployment{ + { + Stack: "duplicatestack", + Provider: defangv1.Provider_AWS, + Region: "us-east-1", + Timestamp: timestamppb.New(deployedAt), // Most recent + }, + { + Stack: "duplicatestack", + Provider: defangv1.Provider_AWS, + Region: "us-west-2", + Timestamp: timestamppb.New(deployedAt.Add(-time.Hour)), // Older + }, + }, + } + + manager, err := NewManager(mockClient, tmpDir, "test-project") + if err != nil { + t.Fatalf("NewManager failed: %v", err) + } + + remoteStacks, err := manager.ListRemote(context.Background()) + if err != nil { + t.Fatalf("ListRemote() failed: %v", err) + } + + if len(remoteStacks) != 1 { + t.Errorf("Expected 1 remote stack (duplicates should be merged), got %d", len(remoteStacks)) + } + + // Should use the first deployment (most recent) since they're already sorted + if remoteStacks[0].Region != "us-east-1" { + t.Errorf("Expected region from first deployment (us-east-1), got %s", remoteStacks[0].Region) + } +} + +func TestManager_WorkingDirectoryMatches(t *testing.T) { + // Create a temporary directory for testing + tmpDir := t.TempDir() + + // Change to temp directory so working directory matches target directory + t.Chdir(tmpDir) + + mockClient := &mockFabricClient{} + manager, err := NewManager(mockClient, tmpDir, "test-project") + if err != nil { + t.Fatalf("NewManager failed: %v", err) + } + + // Test that local operations work when working directory matches target directory + params := StackParameters{ + Name: "teststack", + Provider: client.ProviderAWS, + Region: "us-east-1", + AWSProfile: "default", + Mode: modes.ModeAffordable, + } + + // Create should work + filename, err := manager.Create(params) + if err != nil { + t.Fatalf("Create() failed when directories match: %v", err) + } + + expectedPath := filepath.Join(tmpDir, Directory, "teststack") + if filename != expectedPath { + t.Errorf("Expected filename %s, got %s", expectedPath, filename) + } + + // List should work + stacks, err := manager.List(context.Background()) + if err != nil { + t.Fatalf("List() failed when directories match: %v", err) + } + + if len(stacks) != 1 { + t.Errorf("Expected 1 stack, got %d", len(stacks)) + } + + // Load should work + loadedParams, err := manager.Load("teststack") + if err != nil { + t.Fatalf("Load() failed when directories match: %v", err) + } + + if loadedParams.Name != "teststack" { + t.Errorf("Expected loaded stack name 'teststack', got '%s'", loadedParams.Name) + } +} + +func TestManager_WorkingDirectoryDifferent(t *testing.T) { + // Create a temporary directory for testing but don't change to it + tmpDir := t.TempDir() + + deployedAt := time.Now() + mockClient := &mockFabricClient{ + deployments: []*defangv1.Deployment{ + { + Stack: "remotestack1", + Provider: defangv1.Provider_AWS, + Region: "us-east-1", + Timestamp: timestamppb.New(deployedAt), + }, + { + Stack: "remotestack2", + Provider: defangv1.Provider_GCP, + Region: "us-central1", + Timestamp: timestamppb.New(deployedAt), + }, + }, + } + manager, err := NewManager(mockClient, tmpDir, "test-project") + if err != nil { + t.Fatalf("NewManager failed: %v", err) + } + + // Test that local operations are blocked when working directory differs from target directory + params := StackParameters{ + Name: "teststack", + Provider: client.ProviderAWS, + Region: "us-east-1", + AWSProfile: "default", + Mode: modes.ModeAffordable, + } + + // Create should fail + _, err = manager.Create(params) + if err == nil { + t.Error("Create() should fail when target directory differs from working directory") + } + if !strings.Contains(err.Error(), "operation not allowed: target directory") { + t.Errorf("Expected specific error message about operation not allowed, got: %v", err) + } + + // List should return only remote stacks (no error) + stacks, err := manager.List(context.Background()) + if err != nil { + t.Fatalf("List() should not fail when target directory differs from working directory: %v", err) + } + if len(stacks) != 2 { + t.Errorf("Expected 2 remote stacks, got %d", len(stacks)) + } + + // Verify the returned stacks are remote stacks + stackNames := make(map[string]bool) + for _, stack := range stacks { + stackNames[stack.Name] = true + if stack.DeployedAt.IsZero() { + t.Errorf("Expected remote stack %s to have deployment time", stack.Name) + } + } + if !stackNames["remotestack1"] { + t.Error("Expected to find remotestack1") + } + if !stackNames["remotestack2"] { + t.Error("Expected to find remotestack2") + } + + // Load should fail + _, err = manager.Load("teststack") + if err == nil { + t.Error("Load() should fail when target directory differs from working directory") + } + if !strings.Contains(err.Error(), "operation not allowed: target directory") { + t.Errorf("Expected specific error message about operation not allowed, got: %v", err) + } +} + +func TestManager_RemoteOperationsWorkRegardlessOfDirectory(t *testing.T) { + // Create a temporary directory for testing but don't change to it + tmpDir := t.TempDir() + + deployedAt := time.Now() + mockClient := &mockFabricClient{ + deployments: []*defangv1.Deployment{ + { + Stack: "remotestack", + Provider: defangv1.Provider_AWS, + Region: "us-east-1", + Timestamp: timestamppb.New(deployedAt), + }, + }, + } + + manager, err := NewManager(mockClient, tmpDir, "test-project") + if err != nil { + t.Fatalf("NewManager failed: %v", err) + } + + // Remote operations should work even when directories don't match + remoteStacks, err := manager.ListRemote(context.Background()) + if err != nil { + t.Fatalf("ListRemote() should work even when directories don't match: %v", err) + } + + if len(remoteStacks) != 1 { + t.Errorf("Expected 1 remote stack, got %d", len(remoteStacks)) + } + + if remoteStacks[0].Name != "remotestack" { + t.Errorf("Expected stack name 'remotestack', got '%s'", remoteStacks[0].Name) + } +} diff --git a/src/pkg/stacks/selector.go b/src/pkg/stacks/selector.go new file mode 100644 index 000000000..73e8589c2 --- /dev/null +++ b/src/pkg/stacks/selector.go @@ -0,0 +1,115 @@ +package stacks + +import ( + "context" + "errors" + "fmt" + "os" + "path/filepath" + "slices" + + "github.com/DefangLabs/defang/src/pkg/elicitations" + "github.com/DefangLabs/defang/src/pkg/term" +) + +const CreateNewStack = "Create new stack" + +type stackSelector struct { + ec elicitations.Controller + sm Manager +} + +func NewSelector(ec elicitations.Controller, sm Manager) *stackSelector { + return &stackSelector{ + ec: ec, + sm: sm, + } +} + +func (ss *stackSelector) SelectStack(ctx context.Context) (*StackParameters, error) { + if !ss.ec.IsSupported() { + return nil, errors.New("your mcp client does not support elicitations, use the 'select_stack' tool to choose a stack") + } + selectedStackName, err := ss.elicitStackSelection(ctx, ss.ec) + if err != nil { + return nil, fmt.Errorf("failed to select stack: %w", err) + } + + if selectedStackName == CreateNewStack { + wizard := NewWizard(ss.ec) + params, err := wizard.CollectParameters(ctx) + if err != nil { + return nil, fmt.Errorf("failed to collect stack parameters: %w", err) + } + _, err = ss.sm.Create(*params) + if err != nil { + return nil, fmt.Errorf("failed to create stack: %w", err) + } + + selectedStackName = params.Name + } + + return ss.sm.Load(selectedStackName) +} + +func (ss *stackSelector) elicitStackSelection(ctx context.Context, ec elicitations.Controller) (string, error) { + stackList, err := ss.sm.List(ctx) + if err != nil { + return "", fmt.Errorf("failed to list stacks: %w", err) + } + + if len(stackList) == 0 { + return CreateNewStack, nil + } + + stackLabels := make([]string, 0, len(stackList)+1) + stackNames := make([]string, 0, len(stackList)) + labelMap := make(map[string]string) + for _, s := range stackList { + var label string + if s.DeployedAt.IsZero() { + label = s.Name + } else { + label = fmt.Sprintf("%s (deployed %s)", s.Name, s.DeployedAt.Format("Jan 2")) + } + stackLabels = append(stackLabels, label) + stackNames = append(stackNames, s.Name) + labelMap[label] = s.Name + } + stackLabels = append(stackLabels, CreateNewStack) + + printStacksInfoMessage(stackNames) + selectedLabel, err := ec.RequestEnum(ctx, "Select a stack", "stack", stackLabels) + if err != nil { + return "", fmt.Errorf("failed to elicit stack choice: %w", err) + } + + // If "Create new stack" was selected, return as-is + if selectedLabel == CreateNewStack { + return CreateNewStack, nil + } + + // Otherwise, map back to the actual stack name + selectedName, exists := labelMap[selectedLabel] + if !exists { + return "", fmt.Errorf("invalid stack selection: %s", selectedLabel) + } + + return selectedName, nil +} + +func printStacksInfoMessage(stacks []string) { + // If there is a stack named "beta", print an informational message about it + betaExists := slices.Contains(stacks, "beta") + if betaExists { + infoLine := "This project was deployed with an implicit Stack called 'beta' before Stacks were introduced." + if len(stacks) == 1 { + infoLine += "\n To update your existing deployment, select the 'beta' Stack.\n" + + "Creating a new Stack will result in a separate deployment instance." + } + infoLine += "\n To learn more about Stacks, visit: https://docs.defang.io/docs/concepts/stacks" + term.Info(infoLine + "\n") + } + executable, _ := os.Executable() + term.Infof("To skip this prompt, run %s up --stack=%s", filepath.Base(executable), "") +} diff --git a/src/pkg/stacks/selector_test.go b/src/pkg/stacks/selector_test.go new file mode 100644 index 000000000..012e5045e --- /dev/null +++ b/src/pkg/stacks/selector_test.go @@ -0,0 +1,541 @@ +package stacks + +import ( + "context" + "errors" + "fmt" + "testing" + + cliClient "github.com/DefangLabs/defang/src/pkg/cli/client" + "github.com/DefangLabs/defang/src/pkg/elicitations" + "github.com/DefangLabs/defang/src/pkg/modes" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +// MockElicitationsController mocks the elicitations.Controller interface +type MockElicitationsController struct { + mock.Mock +} + +func (m *MockElicitationsController) RequestString(ctx context.Context, message, field string) (string, error) { + args := m.Called(ctx, message, field) + return args.String(0), args.Error(1) +} + +func (m *MockElicitationsController) RequestStringWithDefault(ctx context.Context, message, field, defaultValue string) (string, error) { + args := m.Called(ctx, message, field, defaultValue) + return args.String(0), args.Error(1) +} + +func (m *MockElicitationsController) RequestEnum(ctx context.Context, message, field string, options []string) (string, error) { + args := m.Called(ctx, message, field, options) + return args.String(0), args.Error(1) +} + +func (m *MockElicitationsController) SetSupported(supported bool) { + m.Called(supported) +} + +func (m *MockElicitationsController) IsSupported() bool { + args := m.Called() + return args.Bool(0) +} + +// MockStacksManager mocks the stacks.Manager interface +type MockStacksManager struct { + mock.Mock +} + +func (m *MockStacksManager) List(ctx context.Context) ([]StackListItem, error) { + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) + } + result, ok := args.Get(0).([]StackListItem) + if !ok { + return nil, args.Error(1) + } + return result, args.Error(1) +} + +func (m *MockStacksManager) Load(name string) (*StackParameters, error) { + args := m.Called(name) + if args.Get(0) == nil { + return nil, args.Error(1) + } + result, ok := args.Get(0).(*StackParameters) + if !ok { + return nil, args.Error(1) + } + return result, args.Error(1) +} + +func (m *MockStacksManager) LoadParameters(params map[string]string, overload bool) error { + args := m.Called(params, overload) + return args.Error(0) +} + +func (m *MockStacksManager) Create(params StackParameters) (string, error) { + args := m.Called(params) + return args.String(0), args.Error(1) +} + +func TestStackSelector_SelectStack_ExistingStack(t *testing.T) { + ctx := context.Background() + + mockEC := &MockElicitationsController{} + mockSM := &MockStacksManager{} + + // Mock that elicitations are supported + mockEC.On("IsSupported").Return(true) + + // Mock existing stacks list + existingStacks := []StackListItem{ + {Name: "production", Provider: "aws", Region: "us-west-2"}, + {Name: "development", Provider: "aws", Region: "us-east-1"}, + } + mockSM.On("List", ctx).Return(existingStacks, nil) + + // Mock user selecting existing stack + expectedOptions := []string{"production", "development", CreateNewStack} + mockEC.On("RequestEnum", ctx, "Select a stack", "stack", expectedOptions).Return("production", nil) + + // Mock loading the selected stack + expectedParams := &StackParameters{ + Name: "production", + Provider: cliClient.ProviderAWS, + Region: "us-west-2", + AWSProfile: "default", + Mode: modes.ModeBalanced, + } + mockSM.On("Load", "production").Return(expectedParams, nil) + + selector := NewSelector(mockEC, mockSM) + + result, err := selector.SelectStack(ctx) + + assert.NoError(t, err) + assert.Equal(t, expectedParams, result) + + mockEC.AssertExpectations(t) + mockSM.AssertExpectations(t) +} + +// WizardInterface defines the interface for collecting stack parameters +type WizardInterface interface { + CollectParameters(ctx context.Context) (*StackParameters, error) +} + +// MockWizardInterface mocks the WizardInterface +type MockWizardInterface struct { + mock.Mock +} + +func (m *MockWizardInterface) CollectParameters(ctx context.Context) (*StackParameters, error) { + args := m.Called(ctx) + if args.Get(0) == nil { + return nil, args.Error(1) + } + result, ok := args.Get(0).(*StackParameters) + if !ok { + return nil, args.Error(1) + } + return result, args.Error(1) +} + +// testableStackSelector extends stackSelector to allow wizard injection for testing +type testableStackSelector struct { + ec elicitations.Controller + sm Manager + wizard WizardInterface +} + +func (tss *testableStackSelector) SelectStack(ctx context.Context) (*StackParameters, error) { + if !tss.ec.IsSupported() { + return nil, errors.New("your mcp client does not support elicitations, use the 'select_stack' tool to choose a stack") + } + selectedStackName, err := tss.elicitStackSelection(ctx, tss.ec) + if err != nil { + return nil, fmt.Errorf("failed to select stack: %w", err) + } + + if selectedStackName == CreateNewStack { + params, err := tss.wizard.CollectParameters(ctx) + if err != nil { + return nil, fmt.Errorf("failed to collect stack parameters: %w", err) + } + _, err = tss.sm.Create(*params) + if err != nil { + return nil, fmt.Errorf("failed to create stack: %w", err) + } + + selectedStackName = params.Name + } + + return tss.sm.Load(selectedStackName) +} + +func (tss *testableStackSelector) elicitStackSelection(ctx context.Context, ec elicitations.Controller) (string, error) { + stackList, err := tss.sm.List(ctx) + if err != nil { + return "", fmt.Errorf("failed to list stacks: %w", err) + } + + if len(stackList) == 0 { + return CreateNewStack, nil + } + + stackNames := make([]string, 0, len(stackList)+1) + for _, s := range stackList { + stackNames = append(stackNames, s.Name) + } + stackNames = append(stackNames, CreateNewStack) + + selectedStackName, err := ec.RequestEnum(ctx, "Select a stack", "stack", stackNames) + if err != nil { + return "", fmt.Errorf("failed to elicit stack choice: %w", err) + } + + return selectedStackName, nil +} + +func TestStackSelector_SelectStack_CreateNewStack(t *testing.T) { + ctx := context.Background() + + mockEC := &MockElicitationsController{} + mockSM := &MockStacksManager{} + mockWizard := &MockWizardInterface{} + + // Mock that elicitations are supported + mockEC.On("IsSupported").Return(true) + + // Mock existing stacks list + existingStacks := []StackListItem{ + {Name: "production", Provider: "aws", Region: "us-west-2"}, + } + mockSM.On("List", ctx).Return(existingStacks, nil) + + // Mock user selecting to create new stack + expectedOptions := []string{"production", CreateNewStack} + mockEC.On("RequestEnum", ctx, "Select a stack", "stack", expectedOptions).Return(CreateNewStack, nil) + + // Mock wizard parameter collection + newStackParams := &StackParameters{ + Name: "staging", + Provider: cliClient.ProviderAWS, + Region: "us-east-1", + AWSProfile: "staging", + Mode: modes.ModeAffordable, + } + mockWizard.On("CollectParameters", ctx).Return(newStackParams, nil) + + // Mock stack creation + mockSM.On("Create", *newStackParams).Return("staging", nil) + + // Mock loading the created stack + mockSM.On("Load", "staging").Return(newStackParams, nil) + + selector := &testableStackSelector{ + ec: mockEC, + sm: mockSM, + wizard: mockWizard, + } + + result, err := selector.SelectStack(ctx) + + assert.NoError(t, err) + assert.Equal(t, newStackParams, result) + + mockEC.AssertExpectations(t) + mockSM.AssertExpectations(t) + mockWizard.AssertExpectations(t) +} + +func TestStackSelector_SelectStack_NoExistingStacks(t *testing.T) { + ctx := context.Background() + + mockEC := &MockElicitationsController{} + mockSM := &MockStacksManager{} + mockWizard := &MockWizardInterface{} + + // Mock that elicitations are supported + mockEC.On("IsSupported").Return(true) + + // Mock empty stacks list - when no stacks exist, it should automatically proceed to create new + mockSM.On("List", ctx).Return([]StackListItem{}, nil) + + // Mock wizard parameter collection + newStackParams := &StackParameters{ + Name: "firststack", + Provider: cliClient.ProviderAWS, + Region: "us-west-2", + AWSProfile: "default", + Mode: modes.ModeBalanced, + } + mockWizard.On("CollectParameters", ctx).Return(newStackParams, nil) + + // Mock stack creation + mockSM.On("Create", *newStackParams).Return("firststack", nil) + + // Mock loading the created stack + mockSM.On("Load", "firststack").Return(newStackParams, nil) + + selector := &testableStackSelector{ + ec: mockEC, + sm: mockSM, + wizard: mockWizard, + } + + result, err := selector.SelectStack(ctx) + + assert.NoError(t, err) + assert.Equal(t, newStackParams, result) + + mockEC.AssertExpectations(t) + mockSM.AssertExpectations(t) + mockWizard.AssertExpectations(t) +} + +func TestStackSelector_SelectStack_ElicitationsNotSupported(t *testing.T) { + ctx := context.Background() + + mockEC := &MockElicitationsController{} + mockSM := &MockStacksManager{} + + // Mock that elicitations are not supported + mockEC.On("IsSupported").Return(false) + + selector := NewSelector(mockEC, mockSM) + + result, err := selector.SelectStack(ctx) + + assert.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "your mcp client does not support elicitations") + + mockEC.AssertExpectations(t) + mockSM.AssertNotCalled(t, "List") +} + +func TestStackSelector_SelectStack_ListStacksError(t *testing.T) { + ctx := context.Background() + + mockEC := &MockElicitationsController{} + mockSM := &MockStacksManager{} + + // Mock that elicitations are supported + mockEC.On("IsSupported").Return(true) + + // Mock error when listing stacks + mockSM.On("List", ctx).Return([]StackListItem{}, errors.New("failed to access stack storage")) + + selector := NewSelector(mockEC, mockSM) + + result, err := selector.SelectStack(ctx) + + assert.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "failed to select stack") + assert.Contains(t, err.Error(), "failed to list stacks") + + mockEC.AssertExpectations(t) + mockSM.AssertExpectations(t) +} + +func TestStackSelector_SelectStack_ElicitationError(t *testing.T) { + ctx := context.Background() + + mockEC := &MockElicitationsController{} + mockSM := &MockStacksManager{} + + // Mock that elicitations are supported + mockEC.On("IsSupported").Return(true) + + // Mock existing stacks list + existingStacks := []StackListItem{ + {Name: "production", Provider: "aws", Region: "us-west-2"}, + } + mockSM.On("List", ctx).Return(existingStacks, nil) + + // Mock error during elicitation + expectedOptions := []string{"production", CreateNewStack} + mockEC.On("RequestEnum", ctx, "Select a stack", "stack", expectedOptions).Return("", errors.New("user cancelled selection")) + + selector := NewSelector(mockEC, mockSM) + + result, err := selector.SelectStack(ctx) + + assert.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "failed to select stack") + assert.Contains(t, err.Error(), "failed to elicit stack choice") + + mockEC.AssertExpectations(t) + mockSM.AssertExpectations(t) +} + +func TestStackSelector_SelectStack_LoadStackError(t *testing.T) { + ctx := context.Background() + + mockEC := &MockElicitationsController{} + mockSM := &MockStacksManager{} + + // Mock that elicitations are supported + mockEC.On("IsSupported").Return(true) + + // Mock existing stacks list + existingStacks := []StackListItem{ + {Name: "production", Provider: "aws", Region: "us-west-2"}, + } + mockSM.On("List", ctx).Return(existingStacks, nil) + + // Mock user selecting existing stack + expectedOptions := []string{"production", CreateNewStack} + mockEC.On("RequestEnum", ctx, "Select a stack", "stack", expectedOptions).Return("production", nil) + + // Mock error when loading the selected stack + mockSM.On("Load", "production").Return((*StackParameters)(nil), errors.New("stack file corrupted")) + + selector := NewSelector(mockEC, mockSM) + + result, err := selector.SelectStack(ctx) + + assert.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "stack file corrupted") + + mockEC.AssertExpectations(t) + mockSM.AssertExpectations(t) +} + +func TestStackSelector_SelectStack_WizardError(t *testing.T) { + ctx := context.Background() + + mockEC := &MockElicitationsController{} + mockSM := &MockStacksManager{} + mockWizard := &MockWizardInterface{} + + // Mock that elicitations are supported + mockEC.On("IsSupported").Return(true) + + // Mock existing stacks list + existingStacks := []StackListItem{ + {Name: "production", Provider: "aws", Region: "us-west-2"}, + } + mockSM.On("List", ctx).Return(existingStacks, nil) + + // Mock user selecting to create new stack + expectedOptions := []string{"production", CreateNewStack} + mockEC.On("RequestEnum", ctx, "Select a stack", "stack", expectedOptions).Return(CreateNewStack, nil) + + // Mock wizard parameter collection error + mockWizard.On("CollectParameters", ctx).Return((*StackParameters)(nil), errors.New("user cancelled wizard")) + + selector := &testableStackSelector{ + ec: mockEC, + sm: mockSM, + wizard: mockWizard, + } + + result, err := selector.SelectStack(ctx) + + assert.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "failed to collect stack parameters") + assert.Contains(t, err.Error(), "user cancelled wizard") + + mockEC.AssertExpectations(t) + mockSM.AssertExpectations(t) + mockWizard.AssertExpectations(t) +} + +func TestStackSelector_SelectStack_CreateStackError(t *testing.T) { + ctx := context.Background() + + mockEC := &MockElicitationsController{} + mockSM := &MockStacksManager{} + mockWizard := &MockWizardInterface{} + + // Mock that elicitations are supported + mockEC.On("IsSupported").Return(true) + + // Mock existing stacks list + existingStacks := []StackListItem{ + {Name: "production", Provider: "aws", Region: "us-west-2"}, + } + mockSM.On("List", ctx).Return(existingStacks, nil) + + // Mock user selecting to create new stack + expectedOptions := []string{"production", CreateNewStack} + mockEC.On("RequestEnum", ctx, "Select a stack", "stack", expectedOptions).Return(CreateNewStack, nil) + + // Mock wizard parameter collection + newStackParams := &StackParameters{ + Name: "staging", + Provider: cliClient.ProviderAWS, + Region: "us-east-1", + AWSProfile: "staging", + Mode: modes.ModeAffordable, + } + mockWizard.On("CollectParameters", ctx).Return(newStackParams, nil) + + // Mock stack creation error + mockSM.On("Create", *newStackParams).Return("", errors.New("invalid stack configuration")) + + selector := &testableStackSelector{ + ec: mockEC, + sm: mockSM, + wizard: mockWizard, + } + + result, err := selector.SelectStack(ctx) + + assert.Error(t, err) + assert.Nil(t, result) + assert.Contains(t, err.Error(), "failed to create stack") + assert.Contains(t, err.Error(), "invalid stack configuration") + + mockEC.AssertExpectations(t) + mockSM.AssertExpectations(t) + mockWizard.AssertExpectations(t) +} + +func TestStackSelector_ElicitStackSelection(t *testing.T) { + ctx := context.Background() + + mockEC := &MockElicitationsController{} + mockSM := &MockStacksManager{} + + // Test case: multiple stacks available + t.Run("multiple stacks", func(t *testing.T) { + existingStacks := []StackListItem{ + {Name: "prod", Provider: "aws", Region: "us-west-2"}, + {Name: "dev", Provider: "gcp", Region: "us-central1"}, + } + mockSM.On("List", ctx).Return(existingStacks, nil).Once() + + expectedOptions := []string{"prod", "dev", CreateNewStack} + mockEC.On("RequestEnum", ctx, "Select a stack", "stack", expectedOptions).Return("dev", nil).Once() + + selector := NewSelector(mockEC, mockSM) + result, err := selector.elicitStackSelection(ctx, mockEC) + + assert.NoError(t, err) + assert.Equal(t, "dev", result) + }) + + // Test case: no stacks available + t.Run("no stacks", func(t *testing.T) { + mockSM.On("List", ctx).Return([]StackListItem{}, nil).Once() + + selector := NewSelector(mockEC, mockSM) + result, err := selector.elicitStackSelection(ctx, mockEC) + + assert.NoError(t, err) + assert.Equal(t, CreateNewStack, result) + }) + + mockEC.AssertExpectations(t) + mockSM.AssertExpectations(t) +} diff --git a/src/pkg/stacks/stacks.go b/src/pkg/stacks/stacks.go index 085f20663..b5ade4fbe 100644 --- a/src/pkg/stacks/stacks.go +++ b/src/pkg/stacks/stacks.go @@ -7,6 +7,7 @@ import ( "path/filepath" "regexp" "strings" + "time" "github.com/DefangLabs/defang/src/pkg/cli/client" "github.com/DefangLabs/defang/src/pkg/modes" @@ -23,6 +24,61 @@ type StackParameters struct { Mode modes.Mode } +func (params StackParameters) ToMap() map[string]string { + var properties map[string]string = make(map[string]string) + properties["DEFANG_PROVIDER"] = strings.ToLower(params.Provider.String()) + if params.Region != "" { + var regionVarName string + switch params.Provider { + case client.ProviderAWS: + regionVarName = "AWS_REGION" + case client.ProviderGCP: + regionVarName = "GCP_LOCATION" + } + if regionVarName != "" { + properties[regionVarName] = strings.ToLower(params.Region) + } + } + if params.Mode != modes.ModeUnspecified { + properties["DEFANG_MODE"] = strings.ToLower(params.Mode.String()) + } + + if params.Provider == client.ProviderAWS && params.AWSProfile != "" { + properties["AWS_PROFILE"] = params.AWSProfile + } + if params.Provider == client.ProviderGCP && params.GCPProjectID != "" { + properties["GCP_PROJECT_ID"] = params.GCPProjectID + } + return properties +} + +func ParamsFromMap(properties map[string]string) (StackParameters, error) { + var params StackParameters + for key, value := range properties { + switch key { + case "DEFANG_PROVIDER": + if err := params.Provider.Set(value); err != nil { + return params, err + } + case "AWS_REGION": + params.Region = value + case "GCP_LOCATION": + params.Region = value + case "AWS_PROFILE": + params.AWSProfile = value + case "GCP_PROJECT_ID": + params.GCPProjectID = value + case "DEFANG_MODE": + mode, err := modes.Parse(value) + if err != nil { + return params, err + } + params.Mode = mode + } + } + return params, nil +} + var validStackName = regexp.MustCompile(`^[a-z][a-z0-9]*$`) const Directory = ".defang" @@ -90,6 +146,7 @@ type StackListItem struct { Provider string Region string Mode string + DeployedAt time.Time } func List() ([]StackListItem, error) { @@ -137,58 +194,12 @@ func Parse(content string) (StackParameters, error) { if err != nil { return StackParameters{}, err } - var params StackParameters - for key, value := range properties { - switch key { - case "DEFANG_PROVIDER": - if err := params.Provider.Set(value); err != nil { - return params, err - } - case "AWS_REGION": - params.Region = value - case "GCP_LOCATION": - params.Region = value - case "AWS_PROFILE": - params.AWSProfile = value - case "GCP_PROJECT_ID": - params.GCPProjectID = value - case "DEFANG_MODE": - mode, err := modes.Parse(value) - if err != nil { - return params, err - } - params.Mode = mode - } - } - return params, nil + + return ParamsFromMap(properties) } func Marshal(params *StackParameters) (string, error) { - var properties map[string]string = make(map[string]string) - properties["DEFANG_PROVIDER"] = strings.ToLower(params.Provider.String()) - if params.Region != "" { - var regionVarName string - switch params.Provider { - case client.ProviderAWS: - regionVarName = "AWS_REGION" - case client.ProviderGCP: - regionVarName = "GCP_LOCATION" - } - if regionVarName != "" { - properties[regionVarName] = strings.ToLower(params.Region) - } - } - if params.Mode != modes.ModeUnspecified { - properties["DEFANG_MODE"] = strings.ToLower(params.Mode.String()) - } - - if params.Provider == client.ProviderAWS && params.AWSProfile != "" { - properties["AWS_PROFILE"] = params.AWSProfile - } - if params.Provider == client.ProviderGCP && params.GCPProjectID != "" { - properties["GCP_PROJECT_ID"] = params.GCPProjectID - } - return godotenv.Marshal(properties) + return godotenv.Marshal(params.ToMap()) } func Remove(name string) error { @@ -250,6 +261,28 @@ func OverloadInDirectory(workingDirectory, name string) error { return nil } +// This was basically ripped out of godotenv.Overload/Load. Unfortunately, they don't export +// a function that loads a map[string]string, so we have to reimplement it here. +func LoadParameters(params map[string]string, overload bool) error { + currentEnv := map[string]bool{} + rawEnv := os.Environ() + for _, rawEnvLine := range rawEnv { + key := strings.Split(rawEnvLine, "=")[0] + currentEnv[key] = true + } + + for key, value := range params { + if !currentEnv[key] || overload { + err := os.Setenv(key, value) + if err != nil { + return fmt.Errorf("could not set env var %q: %w", key, err) + } + } + } + + return nil +} + func filename(workingDirectory, stackname string) string { return filepath.Join(workingDirectory, Directory, stackname) } diff --git a/src/pkg/stacks/stacks_test.go b/src/pkg/stacks/stacks_test.go index caf2c60ab..5361f79fc 100644 --- a/src/pkg/stacks/stacks_test.go +++ b/src/pkg/stacks/stacks_test.go @@ -379,3 +379,113 @@ func TestLoad(t *testing.T) { assert.Equal(t, os.Getenv("GCP_LOCATION"), stackParams.Region) }) } + +func TestParamsToMap(t *testing.T) { + tests := []struct { + name string + params StackParameters + expectedMap map[string]string + }{ + { + name: "AWS params", + params: StackParameters{ + Name: "teststack", + Provider: cliClient.ProviderAWS, + Region: "us-west-2", + AWSProfile: "default", + GCPProjectID: "", + Mode: modes.ModeAffordable, + }, + expectedMap: map[string]string{ + "DEFANG_PROVIDER": "aws", + "AWS_REGION": "us-west-2", + "AWS_PROFILE": "default", + "DEFANG_MODE": "affordable", + }, + }, + { + name: "GCP params", + params: StackParameters{ + Name: "gcpstack", + Provider: cliClient.ProviderGCP, + Region: "us-central1", + AWSProfile: "", + GCPProjectID: "gcp-project-123", + Mode: modes.ModeBalanced, + }, + expectedMap: map[string]string{ + "DEFANG_PROVIDER": "gcp", + "GCP_LOCATION": "us-central1", + "GCP_PROJECT_ID": "gcp-project-123", + "DEFANG_MODE": "balanced", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resultMap := tt.params.ToMap() + if len(resultMap) != len(tt.expectedMap) { + t.Errorf("Params.ToMap() = %v, want %v", resultMap, tt.expectedMap) + } + for key, expectedValue := range tt.expectedMap { + if resultMap[key] != expectedValue { + t.Errorf("Params.ToMap()[%q] = %q, want %q", key, resultMap[key], expectedValue) + } + } + }) + } +} + +func TestParamsFromMap(t *testing.T) { + tests := []struct { + name string + inputMap map[string]string + expectedParams StackParameters + }{ + { + name: "GCP params", + inputMap: map[string]string{ + "DEFANG_PROVIDER": "gcp", + "GCP_LOCATION": "us-central1", + "DEFANG_MODE": "balanced", + }, + expectedParams: StackParameters{ + Provider: cliClient.ProviderGCP, + Region: "us-central1", + Mode: modes.ModeBalanced, + }, + }, + { + name: "AWS params", + inputMap: map[string]string{ + "DEFANG_PROVIDER": "aws", + "AWS_REGION": "us-west-2", + "AWS_PROFILE": "default", + "DEFANG_MODE": "affordable", + }, + expectedParams: StackParameters{ + Provider: cliClient.ProviderAWS, + Region: "us-west-2", + AWSProfile: "default", + Mode: modes.ModeAffordable, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + resultParams, err := ParamsFromMap(tt.inputMap) + if err != nil { + t.Errorf("ParamsFromMap() error = %v", err) + } + + if resultParams.Provider != tt.expectedParams.Provider || + resultParams.Region != tt.expectedParams.Region || + resultParams.Mode != tt.expectedParams.Mode || + resultParams.AWSProfile != tt.expectedParams.AWSProfile { + t.Errorf("ParamsFromMap() = %+v, want %+v", resultParams, tt.expectedParams) + } + }) + } +} diff --git a/src/testdata/sanity/.defang/beta b/src/testdata/sanity/.defang/beta new file mode 100644 index 000000000..4fd4ce067 --- /dev/null +++ b/src/testdata/sanity/.defang/beta @@ -0,0 +1,2 @@ +DEFANG_PROVIDER=aws +AWS_REGION=us-west-2 \ No newline at end of file