diff --git a/README.md b/README.md index 79a98ae83..164f4b3b6 100644 --- a/README.md +++ b/README.md @@ -172,7 +172,7 @@ The Defang CLI recognizes the following environment variables: - `DEFANG_NO_CACHE` - If set to `true`, disables pull-through caching of container images; defaults to `false` - `DEFANG_ORG` - The name of the organization to use; defaults to the user's GitHub name - `DEFANG_PREFIX` - The prefix to use for all BYOC resources; defaults to `Defang` -- `DEFANG_PROVIDER` - The name of the cloud provider to use, `auto` (default), `aws`, `digitalocean`, `gcp`, or `defang` +- `DEFANG_PROVIDER` - The name of the cloud provider to use, `aws`, `digitalocean`, `gcp`, or `defang` - `DEFANG_PULUMI_BACKEND` - The Pulumi backend URL or `"pulumi-cloud"`; defaults to a self-hosted backend - `DEFANG_PULUMI_DEBUG` - If set to `true`, enables debug logging for Pulumi operations; defaults to `false` - `DEFANG_PULUMI_DIFF` - If set to `true`, shows the Pulumi diff during deployments; defaults to `false` diff --git a/pkgs/npm/README.md b/pkgs/npm/README.md index 2857ed89d..c76501efa 100644 --- a/pkgs/npm/README.md +++ b/pkgs/npm/README.md @@ -42,7 +42,7 @@ The Defang CLI recognizes the following environment variables: - `DEFANG_NO_CACHE` - If set to `true`, disables pull-through caching of container images; defaults to `false` - `DEFANG_ORG` - The name of the organization to use; defaults to the user's GitHub name - `DEFANG_PREFIX` - The prefix to use for all BYOC resources; defaults to `Defang` -- `DEFANG_PROVIDER` - The name of the cloud provider to use, `auto` (default), `aws`, `digitalocean`, `gcp`, or `defang` +- `DEFANG_PROVIDER` - The name of the cloud provider to use, `aws`, `digitalocean`, `gcp`, or `defang` - `DEFANG_PULUMI_BACKEND` - The Pulumi backend URL or `"pulumi-cloud"`; defaults to a self-hosted backend - `DEFANG_PULUMI_DEBUG` - If set to `true`, enables debug logging for Pulumi operations; defaults to `false` - `DEFANG_PULUMI_DIFF` - If set to `true`, shows the Pulumi diff during deployments; defaults to `false` diff --git a/src/.goreleaser.yml b/src/.goreleaser.yml index b715a3354..1f2138b1b 100644 --- a/src/.goreleaser.yml +++ b/src/.goreleaser.yml @@ -88,7 +88,7 @@ release: 2. Extract the archive. This should reveal the binary file for Defang. 3. Manually place the binary file in a directory that's included in your system's `PATH` environment variable. ### Additional Step for MacOS Users - If you're having trouble running the binary on MacOS, please check our [FAQs](https://docs.defang.io/docs/faq#im-having-trouble-running-the-binary-on-my-mac-what-should-i-do). + If you're having trouble running the binary on MacOS, please check our [FAQs](https://docs.defang.io/docs/intro/faq/questions#im-having-trouble-running-the-binary-on-my-mac-what-should-i-do). Please remember this software is in beta, so please report any issues or feedback through our GitHub page. Your help in improving Defang is greatly appreciated! # mode: keep-existing diff --git a/src/README.md b/src/README.md index 2857ed89d..c76501efa 100644 --- a/src/README.md +++ b/src/README.md @@ -42,7 +42,7 @@ The Defang CLI recognizes the following environment variables: - `DEFANG_NO_CACHE` - If set to `true`, disables pull-through caching of container images; defaults to `false` - `DEFANG_ORG` - The name of the organization to use; defaults to the user's GitHub name - `DEFANG_PREFIX` - The prefix to use for all BYOC resources; defaults to `Defang` -- `DEFANG_PROVIDER` - The name of the cloud provider to use, `auto` (default), `aws`, `digitalocean`, `gcp`, or `defang` +- `DEFANG_PROVIDER` - The name of the cloud provider to use, `aws`, `digitalocean`, `gcp`, or `defang` - `DEFANG_PULUMI_BACKEND` - The Pulumi backend URL or `"pulumi-cloud"`; defaults to a self-hosted backend - `DEFANG_PULUMI_DEBUG` - If set to `true`, enables debug logging for Pulumi operations; defaults to `false` - `DEFANG_PULUMI_DIFF` - If set to `true`, shows the Pulumi diff during deployments; defaults to `false` diff --git a/src/cmd/cli/command/cd.go b/src/cmd/cli/command/cd.go index be0ee40e3..ca77f9195 100644 --- a/src/cmd/cli/command/cd.go +++ b/src/cmd/cli/command/cd.go @@ -5,7 +5,6 @@ import ( "os" "github.com/DefangLabs/defang/src/pkg/cli" - cliClient "github.com/DefangLabs/defang/src/pkg/cli/client" "github.com/DefangLabs/defang/src/pkg/cli/client/byoc/aws" "github.com/DefangLabs/defang/src/pkg/cli/compose" "github.com/DefangLabs/defang/src/pkg/term" @@ -36,19 +35,19 @@ var cdCmd = &cobra.Command{ func bootstrapCommand(cmd *cobra.Command, args []string, command string) error { ctx := cmd.Context() loader := configureLoader(cmd) - provider, err := newProviderChecked(ctx, loader) - if err != nil { - return err - } if len(args) == 0 { - projectName, err := cliClient.LoadProjectNameWithFallback(ctx, loader, provider) + projectName, err := loader.LoadProjectName(ctx) if err != nil { return err } args = []string{projectName} } + provider, err := newProviderChecked(ctx, args[0], false) + if err != nil { + return err + } var errs []error for _, projectName := range args { err := canIUseProvider(ctx, provider, projectName, 0) @@ -103,8 +102,7 @@ var cdTearDownCmd = &cobra.Command{ RunE: func(cmd *cobra.Command, args []string) error { force, _ := cmd.Flags().GetBool("force") - loader := configureLoader(cmd) - provider, err := newProviderChecked(cmd.Context(), loader) + provider, err := newProviderChecked(cmd.Context(), "", false) if err != nil { return err } @@ -122,7 +120,7 @@ var cdListCmd = &cobra.Command{ remote, _ := cmd.Flags().GetBool("remote") all, _ := cmd.Flags().GetBool("all") - provider, err := newProviderChecked(cmd.Context(), nil) + provider, err := newProviderChecked(cmd.Context(), "", false) if err != nil { return err } @@ -156,7 +154,9 @@ var cdPreviewCmd = &cobra.Command{ return err } - provider, err := newProviderChecked(cmd.Context(), loader) + projectNameFlag, _ := cmd.Flags().GetString("project-name") + saveStacksToWkDir := projectNameFlag == "" + provider, err := newProviderChecked(cmd.Context(), project.Name, saveStacksToWkDir) if err != nil { return err } @@ -183,8 +183,7 @@ var cdInstallCmd = &cobra.Command{ Short: "Install the CD resources into the cluster", Hidden: true, // users shouldn't have to run this manually, because it's done on deploy RunE: func(cmd *cobra.Command, args []string) error { - loader := configureLoader(cmd) - provider, err := newProviderChecked(cmd.Context(), loader) + provider, err := newProviderChecked(cmd.Context(), "", false) if err != nil { return err } diff --git a/src/cmd/cli/command/commands.go b/src/cmd/cli/command/commands.go index 23d58a1ad..6abb85c5f 100644 --- a/src/cmd/cli/command/commands.go +++ b/src/cmd/cli/command/commands.go @@ -16,6 +16,7 @@ import ( "github.com/AlecAivazis/survey/v2" "github.com/DefangLabs/defang/src/pkg" "github.com/DefangLabs/defang/src/pkg/agent" + agentTools "github.com/DefangLabs/defang/src/pkg/agent/tools" "github.com/DefangLabs/defang/src/pkg/cli" cliClient "github.com/DefangLabs/defang/src/pkg/cli/client" "github.com/DefangLabs/defang/src/pkg/cli/client/byoc" @@ -24,6 +25,7 @@ import ( "github.com/DefangLabs/defang/src/pkg/clouds/aws" "github.com/DefangLabs/defang/src/pkg/debug" "github.com/DefangLabs/defang/src/pkg/dryrun" + "github.com/DefangLabs/defang/src/pkg/elicitations" "github.com/DefangLabs/defang/src/pkg/login" "github.com/DefangLabs/defang/src/pkg/logs" "github.com/DefangLabs/defang/src/pkg/mcp" @@ -86,7 +88,7 @@ func Execute(ctx context.Context) error { if strings.Contains(err.Error(), "maximum number of projects") { projectName := "" - provider, err := newProviderChecked(ctx, nil) + provider, err := newProviderChecked(ctx, projectName, false) if err != nil { return err } @@ -433,10 +435,17 @@ var RootCmd = &cobra.Command{ } } - // Read the global flags again from any .defang files in the cwd - err = global.loadDotDefang(global.getStackName(cmd.Flags())) - if err != nil { - return err + if !cmd.Flags().Changed("project-name") { + // TODO: consider connecting to fabric before loading the stack + // so we can support loading stacks from remote + // Read the global flags from the selected stack file + err = global.loadStackFile(global.getStackName(cmd.Flags())) + if err != nil { + // if the stack file does not exist, continue without error, we will load it from the remote later. + if !errors.Is(err, os.ErrNotExist) { + return err + } + } } err = global.syncFlagsWithEnv(cmd.Flags()) @@ -524,9 +533,8 @@ var whoamiCmd = &cobra.Command{ Args: cobra.NoArgs, Short: "Show the current user", RunE: func(cmd *cobra.Command, args []string) error { - loader := configureLoader(cmd) global.NonInteractive = true // don't show provider prompt - provider, err := newProvider(cmd.Context(), loader) + provider, err := newProviderChecked(cmd.Context(), "", false) if err != nil { term.Debug("unable to get provider:", err) } @@ -576,7 +584,9 @@ var certGenerateCmd = &cobra.Command{ return err } - provider, err := newProviderChecked(cmd.Context(), loader) + projectNameFlag, _ := cmd.Flags().GetString("project-name") + saveStacksToWkDir := projectNameFlag == "" + provider, err := newProviderChecked(cmd.Context(), project.Name, saveStacksToWkDir) if err != nil { return err } @@ -748,12 +758,13 @@ var configSetCmd = &cobra.Command{ // Make sure we have a project to set config for before asking for a value loader := configureLoader(cmd) - provider, err := newProviderChecked(cmd.Context(), loader) + projectName, err := loader.LoadProjectName(cmd.Context()) if err != nil { return err } - - projectName, err := cliClient.LoadProjectNameWithFallback(cmd.Context(), loader, provider) + projectNameFlag, _ := cmd.Flags().GetString("project-name") + saveStacksToWkDir := projectNameFlag == "" + provider, err := newProviderChecked(cmd.Context(), projectName, saveStacksToWkDir) if err != nil { return err } @@ -879,12 +890,13 @@ var configDeleteCmd = &cobra.Command{ Short: "Removes one or more config values", RunE: func(cmd *cobra.Command, names []string) error { loader := configureLoader(cmd) - provider, err := newProviderChecked(cmd.Context(), loader) + projectName, err := loader.LoadProjectName(cmd.Context()) if err != nil { return err } - - projectName, err := cliClient.LoadProjectNameWithFallback(cmd.Context(), loader, provider) + projectNameFlag, _ := cmd.Flags().GetString("project-name") + saveStacksToWkDir := projectNameFlag == "" + provider, err := newProviderChecked(cmd.Context(), projectName, saveStacksToWkDir) if err != nil { return err } @@ -912,12 +924,14 @@ var configListCmd = &cobra.Command{ Short: "List configs", RunE: func(cmd *cobra.Command, args []string) error { loader := configureLoader(cmd) - provider, err := newProviderChecked(cmd.Context(), loader) + + projectName, err := loader.LoadProjectName(cmd.Context()) if err != nil { return err } - - projectName, err := cliClient.LoadProjectNameWithFallback(cmd.Context(), loader, provider) + projectNameFlag, _ := cmd.Flags().GetString("project-name") + saveStacksToWkDir := projectNameFlag == "" + provider, err := newProviderChecked(cmd.Context(), projectName, saveStacksToWkDir) if err != nil { return err } @@ -943,12 +957,13 @@ var debugCmd = &cobra.Command{ } loader := configureLoader(cmd) - _, err := newProviderChecked(ctx, loader) + project, err := loader.LoadProject(ctx) if err != nil { return err } - - project, err := loader.LoadProject(ctx) + projectNameFlag, _ := cmd.Flags().GetString("project-name") + saveStacksToWkDir := projectNameFlag == "" + _, err = newProviderChecked(ctx, project.Name, saveStacksToWkDir) if err != nil { return err } @@ -992,12 +1007,13 @@ var deleteCmd = &cobra.Command{ var tail, _ = cmd.Flags().GetBool("tail") loader := configureLoader(cmd) - provider, err := newProviderChecked(cmd.Context(), loader) + projectName, err := loader.LoadProjectName(cmd.Context()) if err != nil { return err } - - projectName, err := cliClient.LoadProjectNameWithFallback(cmd.Context(), loader, provider) + projectNameFlag, _ := cmd.Flags().GetString("project-name") + saveStacksToWkDir := projectNameFlag == "" + provider, err := newProviderChecked(cmd.Context(), projectName, saveStacksToWkDir) if err != nil { return err } @@ -1238,7 +1254,7 @@ var providerDescription = map[cliClient.ProviderID]string{ cliClient.ProviderGCP: "Deploy to Google Cloud Platform using gcloud Application Default Credentials.", } -func updateProviderID(ctx context.Context, loader cliClient.Loader) error { +func updateProviderID(ctx context.Context) error { extraMsg := "" whence := "default project" @@ -1255,24 +1271,16 @@ func updateProviderID(ctx context.Context, loader cliClient.Loader) error { switch global.ProviderID { case cliClient.ProviderAuto: - if global.NonInteractive { - // Defaults to defang provider in non-interactive mode - if awsInEnv() { - term.Warn("Using Defang playground, but AWS environment variables were detected; did you forget --provider=aws or DEFANG_PROVIDER=aws?") - } - if doInEnv() { - term.Warn("Using Defang playground, but DIGITALOCEAN_TOKEN environment variable was detected; did you forget --provider=digitalocean or DEFANG_PROVIDER=digitalocean?") - } - if gcpInEnv() { - term.Warn("Using Defang playground, but GCP_PROJECT_ID/CLOUDSDK_CORE_PROJECT environment variable was detected; did you forget --provider=gcp or DEFANG_PROVIDER=gcp?") - } - global.ProviderID = cliClient.ProviderDefang - } else { - var err error - if whence, err = determineProviderID(ctx, loader); err != nil { - return err - } + if awsInEnv() { + term.Warn("Using Defang playground, but AWS environment variables were detected; did you forget --provider=aws or DEFANG_PROVIDER=aws?") + } + if doInEnv() { + term.Warn("Using Defang playground, but DIGITALOCEAN_TOKEN environment variable was detected; did you forget --provider=digitalocean or DEFANG_PROVIDER=digitalocean?") + } + if gcpInEnv() { + term.Warn("Using Defang playground, but GCP_PROJECT_ID/CLOUDSDK_CORE_PROJECT environment variable was detected; did you forget --provider=gcp or DEFANG_PROVIDER=gcp?") } + global.ProviderID = cliClient.ProviderDefang case cliClient.ProviderAWS: if !awsInConfig(ctx) { term.Warn("AWS provider was selected, but AWS environment is not set") @@ -1294,8 +1302,8 @@ func updateProviderID(ctx context.Context, loader cliClient.Loader) error { return nil } -func newProvider(ctx context.Context, loader cliClient.Loader) (cliClient.Provider, error) { - if err := updateProviderID(ctx, loader); err != nil { +func newProvider(ctx context.Context) (cliClient.Provider, error) { + if err := updateProviderID(ctx); err != nil { return nil, err } @@ -1303,10 +1311,23 @@ func newProvider(ctx context.Context, loader cliClient.Loader) (cliClient.Provid return provider, nil } -func newProviderChecked(ctx context.Context, loader cliClient.Loader) (cliClient.Provider, error) { - provider, err := newProvider(ctx, loader) +type providerCreator struct{} + +func (pc *providerCreator) NewProvider(ctx context.Context, providerId cliClient.ProviderID, client cliClient.FabricClient, stack string) cliClient.Provider { + return cli.NewProvider(ctx, providerId, client, stack) +} + +func newProviderChecked(ctx context.Context, projectName string, useWkDir bool) (cliClient.Provider, error) { + if global.NonInteractive { + return newProvider(ctx) + } + pc := &providerCreator{} + elicitationsClient := elicitations.NewSurveyClient(os.Stdin, os.Stdout, os.Stderr) + ec := elicitations.NewController(elicitationsClient) + pp := agentTools.NewProviderPreparer(pc, ec, global.Client) + _, provider, err := pp.SetupProvider(ctx, projectName, &global.Stack, useWkDir) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to setup provider: %w", err) } _, err = provider.AccountInfo(ctx) return provider, err @@ -1316,40 +1337,6 @@ func canIUseProvider(ctx context.Context, provider cliClient.Provider, projectNa return cliClient.CanIUseProvider(ctx, global.Client, provider, projectName, global.Stack, serviceCount) } -func determineProviderID(ctx context.Context, loader cliClient.Loader) (string, error) { - var projectName string - if loader != nil { - var err error - projectName, err = loader.LoadProjectName(ctx) - if err != nil { - term.Warnf("Unable to load project: %v", err) - } - - if projectName != "" && !RootCmd.PersistentFlags().Changed("provider") { // If user manually selected auto provider, do not load from remote - resp, err := global.Client.GetSelectedProvider(ctx, &defangv1.GetSelectedProviderRequest{Project: projectName}) - if err != nil { - term.Debugf("Unable to get selected provider: %v", err) - } else if resp.Provider != defangv1.Provider_PROVIDER_UNSPECIFIED { - global.ProviderID.SetValue(resp.Provider) - return "stored preference", nil - } - } - } - - whence, err := interactiveSelectProvider(cliClient.AllProviders()) - - // Save the selected provider to the fabric - if projectName != "" { - if err := global.Client.SetSelectedProvider(ctx, &defangv1.SetSelectedProviderRequest{Project: projectName, Provider: global.ProviderID.Value()}); err != nil { - term.Debugf("Unable to save selected provider to defang server: %v", err) - } else { - term.Printf("%v is now the default provider for project %v and will auto-select next time if no other provider is specified. Use --provider=auto to reselect.", global.ProviderID, projectName) - } - } - - return whence, err -} - func interactiveSelectProvider(providers []cliClient.ProviderID) (string, error) { if len(providers) < 2 { panic("interactiveSelectProvider called with less than 2 providers") diff --git a/src/cmd/cli/command/commands_test.go b/src/cmd/cli/command/commands_test.go index 704798e2c..34292c18c 100644 --- a/src/cmd/cli/command/commands_test.go +++ b/src/cmd/cli/command/commands_test.go @@ -12,10 +12,8 @@ import ( cliClient "github.com/DefangLabs/defang/src/pkg/cli/client" "github.com/DefangLabs/defang/src/pkg/cli/client/byoc/aws" "github.com/DefangLabs/defang/src/pkg/cli/client/byoc/gcp" - "github.com/DefangLabs/defang/src/pkg/cli/compose" pkg "github.com/DefangLabs/defang/src/pkg/clouds/aws" gcpdriver "github.com/DefangLabs/defang/src/pkg/clouds/gcp" - "github.com/DefangLabs/defang/src/pkg/term" defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1" "github.com/DefangLabs/defang/src/protos/io/defang/v1/defangv1connect" "github.com/aws/aws-sdk-go-v2/service/ssm" @@ -81,6 +79,12 @@ func (m *mockFabricService) SetSelectedProvider(context.Context, *connect.Reques return connect.NewResponse(&emptypb.Empty{}), nil } +func (m *mockFabricService) ListDeployments(context.Context, *connect.Request[defangv1.ListDeploymentsRequest]) (*connect.Response[defangv1.ListDeploymentsResponse], error) { + return connect.NewResponse(&defangv1.ListDeploymentsResponse{ + Deployments: []*defangv1.Deployment{}, + }), nil +} + func init() { SetupCommands(context.Background(), "0.0.0-test") } @@ -245,7 +249,6 @@ func TestGetProvider(t *testing.T) { } mockClient.SetClient(mockCtrl) global.Client = &mockClient - loader := cliClient.MockLoader{Project: compose.Project{Name: "empty"}} oldRootCmd := RootCmd t.Cleanup(func() { RootCmd = oldRootCmd @@ -261,154 +264,6 @@ func TestGetProvider(t *testing.T) { ctx := t.Context() - t.Run("Nil loader auto provider non-interactive should load playground provider", func(t *testing.T) { - global.ProviderID = "auto" - os.Unsetenv("DEFANG_PROVIDER") - RootCmd = FakeRootWithProviderParam("") - - p, err := newProvider(ctx, nil) - if err != nil { - t.Fatalf("getProvider() failed: %v", err) - } - if _, ok := p.(*cliClient.PlaygroundProvider); !ok { - t.Errorf("Expected provider to be of type *cliClient.PlaygroundProvider, got %T", p) - } - }) - - t.Run("Auto provider should get provider from client", func(t *testing.T) { - global.ProviderID = "auto" - os.Unsetenv("DEFANG_PROVIDER") - t.Setenv("AWS_REGION", "us-west-2") - RootCmd = FakeRootWithProviderParam("") - - mockCtrl.savedProvider = map[string]defangv1.Provider{"empty": defangv1.Provider_AWS} - - ni := global.NonInteractive - sts := aws.StsClient - aws.StsClient = &mockStsProviderAPI{} - global.NonInteractive = false - t.Cleanup(func() { - global.NonInteractive = ni - aws.StsClient = sts - mockCtrl.savedProvider = nil - }) - - p, err := newProvider(ctx, loader) - if err != nil { - t.Fatalf("getProvider() failed: %v", err) - } - if _, ok := p.(*aws.ByocAws); !ok { - t.Errorf("Expected provider to be of type *aws.ByocAws, got %T", p) - } - }) - - t.Run("Auto provider from param with saved provider should go interactive and save", func(t *testing.T) { - global.ProviderID = "auto" - os.Unsetenv("DEFANG_PROVIDER") - t.Setenv("AWS_REGION", "us-west-2") - mockCtrl.savedProvider = map[string]defangv1.Provider{"someotherproj": defangv1.Provider_AWS} - RootCmd = FakeRootWithProviderParam("") - - ni := global.NonInteractive - sts := aws.StsClient - aws.StsClient = &mockStsProviderAPI{} - global.NonInteractive = false - oldTerm := term.DefaultTerm - term.DefaultTerm = term.NewTerm( - &FakeStdin{bytes.NewReader([]byte("aws\n"))}, - &FakeStdout{new(bytes.Buffer)}, - new(bytes.Buffer), - ) - t.Cleanup(func() { - global.NonInteractive = ni - aws.StsClient = sts - mockCtrl.savedProvider = nil - term.DefaultTerm = oldTerm - }) - - p, err := newProvider(ctx, loader) - if err != nil { - t.Fatalf("getProvider() failed: %v", err) - } - if _, ok := p.(*aws.ByocAws); !ok { - t.Errorf("Expected provider to be of type *aws.ByocAws, got %T", p) - } - if mockCtrl.savedProvider["empty"] != defangv1.Provider_AWS { - t.Errorf("Expected provider to be saved as AWS, got %v", mockCtrl.savedProvider["empty"]) - } - }) - - t.Run("Interactive provider prompt infer default provider from environment variable", func(t *testing.T) { - if testing.Short() { - t.Skip("Skip digitalocean test") - } - global.ProviderID = "auto" - os.Unsetenv("DEFANG_PROVIDER") - os.Unsetenv("AWS_PROFILE") - t.Setenv("AWS_REGION", "us-west-2") - t.Setenv("DIGITALOCEAN_TOKEN", "test-token") - mockCtrl.savedProvider = map[string]defangv1.Provider{"someotherproj": defangv1.Provider_AWS} - RootCmd = FakeRootWithProviderParam("") - - ni := global.NonInteractive - sts := aws.StsClient - aws.StsClient = &mockStsProviderAPI{} - global.NonInteractive = false - oldTerm := term.DefaultTerm - term.DefaultTerm = term.NewTerm( - &FakeStdin{bytes.NewReader([]byte("\n"))}, // Use default option, which should be DO from env var - &FakeStdout{new(bytes.Buffer)}, - new(bytes.Buffer), - ) - t.Cleanup(func() { - global.NonInteractive = ni - aws.StsClient = sts - mockCtrl.savedProvider = nil - term.DefaultTerm = oldTerm - }) - - _, err := newProvider(ctx, loader) - if err != nil && !strings.HasPrefix(err.Error(), "GET https://api.digitalocean.com/v2/account: 401") { - t.Fatalf("getProvider() failed: %v", err) - } - if mockCtrl.savedProvider["empty"] != defangv1.Provider_DIGITALOCEAN { - t.Errorf("Expected provider to be saved as DIGITALOCEAN, got %v", mockCtrl.savedProvider["empty"]) - } - }) - - t.Run("Auto provider from param with saved provider should go interactive and save", func(t *testing.T) { - os.Unsetenv("GCP_PROJECT_ID") // To trigger error - os.Unsetenv("DEFANG_PROVIDER") - global.ProviderID = "auto" - mockCtrl.savedProvider = map[string]defangv1.Provider{"empty": defangv1.Provider_AWS} - RootCmd = FakeRootWithProviderParam("auto") - - ni := global.NonInteractive - sts := aws.StsClient - aws.StsClient = &mockStsProviderAPI{} - global.NonInteractive = false - oldTerm := term.DefaultTerm - term.DefaultTerm = term.NewTerm( - &FakeStdin{bytes.NewReader([]byte("gcp\n"))}, - &FakeStdout{new(bytes.Buffer)}, - new(bytes.Buffer), - ) - t.Cleanup(func() { - global.NonInteractive = ni - aws.StsClient = sts - mockCtrl.savedProvider = nil - term.DefaultTerm = oldTerm - }) - - _, err := newProvider(ctx, loader) - if err != nil && err.Error() != "GCP_PROJECT_ID or CLOUDSDK_CORE_PROJECT must be set for GCP projects" { - t.Fatalf("getProvider() failed: %v", err) - } - if mockCtrl.savedProvider["empty"] != defangv1.Provider_GCP { - t.Errorf("Expected provider to be saved as GCP, got %v", mockCtrl.savedProvider["empty"]) - } - }) - t.Run("Should take provider from param without updating saved provider", func(t *testing.T) { os.Unsetenv("DIGITALOCEAN_TOKEN") os.Unsetenv("DEFANG_PROVIDER") @@ -421,7 +276,7 @@ func TestGetProvider(t *testing.T) { mockCtrl.savedProvider = nil }) - _, err := newProvider(ctx, loader) + _, err := newProvider(ctx) if err != nil && !strings.HasPrefix(err.Error(), "DIGITALOCEAN_TOKEN must be set") { t.Fatalf("getProvider() failed: %v", err) } @@ -440,7 +295,7 @@ func TestGetProvider(t *testing.T) { aws.StsClient = sts }) - p, err := newProvider(ctx, loader) + p, err := newProvider(ctx) if err != nil { t.Errorf("getProvider() failed: %v", err) } @@ -459,7 +314,7 @@ func TestGetProvider(t *testing.T) { }, nil } - p, err := newProvider(ctx, loader) + p, err := newProvider(ctx) if err != nil { t.Errorf("getProvider() failed: %v", err) } @@ -480,7 +335,7 @@ func TestGetProvider(t *testing.T) { mockCtrl.canIUseResponse.CdImage = "" }) - p, err := newProvider(ctx, loader) + p, err := newProvider(ctx) if err != nil { t.Errorf("getProvider() failed: %v", err) } @@ -513,7 +368,7 @@ func TestGetProvider(t *testing.T) { mockCtrl.canIUseResponse.CdImage = "" }) - p, err := newProvider(ctx, loader) + p, err := newProvider(ctx) if err != nil { t.Errorf("getProvider() failed: %v", err) } diff --git a/src/cmd/cli/command/compose.go b/src/cmd/cli/command/compose.go index a98bc68f6..59045ba0b 100644 --- a/src/cmd/cli/command/compose.go +++ b/src/cmd/cli/command/compose.go @@ -5,11 +5,9 @@ import ( "errors" "fmt" "io" - "slices" "strings" "time" - "github.com/AlecAivazis/survey/v2" "github.com/DefangLabs/defang/src/pkg" "github.com/DefangLabs/defang/src/pkg/cli" cliClient "github.com/DefangLabs/defang/src/pkg/cli/client" @@ -19,7 +17,6 @@ import ( "github.com/DefangLabs/defang/src/pkg/dryrun" "github.com/DefangLabs/defang/src/pkg/logs" "github.com/DefangLabs/defang/src/pkg/modes" - "github.com/DefangLabs/defang/src/pkg/stacks" "github.com/DefangLabs/defang/src/pkg/term" "github.com/DefangLabs/defang/src/pkg/timeutils" "github.com/DefangLabs/defang/src/pkg/track" @@ -78,7 +75,7 @@ func makeComposeUpCmd() *cobra.Command { }, loadErr) } - provider, err := newProviderChecked(ctx, loader) + provider, err := newProviderChecked(ctx, project.Name, true) if err != nil { return err } @@ -89,25 +86,6 @@ func makeComposeUpCmd() *cobra.Command { return err } - // Check if the project is already deployed and warn the user if they're deploying it elsewhere - if resp, err := global.Client.ListDeployments(ctx, &defangv1.ListDeploymentsRequest{ - Project: project.Name, - Type: defangv1.DeploymentType_DEPLOYMENT_TYPE_ACTIVE, - }); err != nil { - term.Debugf("ListDeployments failed: %v", err) - } else if accountInfo, err := provider.AccountInfo(ctx); err != nil { - term.Debugf("AccountInfo failed: %v", err) - } else if len(resp.Deployments) > 0 { - handleExistingDeployments(resp.Deployments, accountInfo, project.Name) - } else if global.Stack == "" { - promptToCreateStack(stacks.StackParameters{ - Name: stacks.MakeDefaultName(accountInfo.Provider, accountInfo.Region), - Provider: accountInfo.Provider, - Region: accountInfo.Region, - Mode: global.Mode, - }) - } - // Show a warning for any (managed) services that we cannot monitor var managedServices []string for _, service := range project.Services { @@ -201,84 +179,6 @@ func makeComposeUpCmd() *cobra.Command { return composeUpCmd } -func handleExistingDeployments(existingDeployments []*defangv1.Deployment, accountInfo *cliClient.AccountInfo, projectName string) error { - samePlace := slices.ContainsFunc(existingDeployments, func(dep *defangv1.Deployment) bool { - // Old deployments may not have a region or account ID, so we check for empty values too - return dep.Provider == global.ProviderID.Value() && (dep.ProviderAccountId == accountInfo.AccountID || dep.ProviderAccountId == "") && (dep.Region == accountInfo.Region || dep.Region == "") - }) - if samePlace { - return nil - } - if err := confirmDeploymentToNewLocation(projectName, existingDeployments); err != nil { - return err - } - if global.Stack == "" { - stackName := "beta" - _, err := stacks.Create(stacks.StackParameters{ - Name: stackName, - Provider: accountInfo.Provider, - Region: accountInfo.Region, - Mode: global.Mode, - }) - if err != nil { - term.Debugf("Failed to create stack %v", err) - } else { - term.Info(stacks.PostCreateMessage(stackName)) - } - } - return nil -} - -func printExistingDeployments(existingDeployments []*defangv1.Deployment) { - term.Info("This project has already deployed to the following locations:") - deploymentStrings := make([]string, 0, len(existingDeployments)) - for _, dep := range existingDeployments { - var providerId cliClient.ProviderID - providerId.SetValue(dep.Provider) - deploymentStrings = append(deploymentStrings, fmt.Sprintf(" - %v", cliClient.AccountInfo{Provider: providerId, AccountID: dep.ProviderAccountId, Region: dep.Region})) - } - // sort and remove duplicates - slices.Sort(deploymentStrings) - deploymentStrings = slices.Compact(deploymentStrings) - term.Println(strings.Join(deploymentStrings, "\n")) -} - -func confirmDeploymentToNewLocation(projectName string, existingDeployments []*defangv1.Deployment) error { - printExistingDeployments(existingDeployments) - var confirm bool - if err := survey.AskOne(&survey.Confirm{ - Message: "Are you sure you want to continue?", - Default: false, - }, &confirm, survey.WithStdio(term.DefaultTerm.Stdio())); err != nil { - return err - } else if !confirm { - return fmt.Errorf("deployment of project %q was canceled", projectName) - } - return nil -} - -func promptToCreateStack(params stacks.StackParameters) error { - if global.NonInteractive { - term.Info("Consider creating a stack to manage your deployments.") - printDefangHint("To create a stack, do:", "stack new --name="+params.Name) - return nil - } - - err := PromptForStackParameters(¶ms) - if err != nil { - return err - } - - _, err = stacks.Create(params) - if err != nil { - return err - } - - term.Info(stacks.PostCreateMessage(params.Name)) - - return nil -} - func handleComposeUpErr(ctx context.Context, debugger *debug.Debugger, project *compose.Project, provider cliClient.Provider, err error) error { if errors.Is(err, types.ErrComposeFileNotFound) { // TODO: generate a compose file based on the current project @@ -418,12 +318,13 @@ func makeComposeDownCmd() *cobra.Command { } loader := configureLoader(cmd) - provider, err := newProviderChecked(cmd.Context(), loader) + projectName, err := loader.LoadProjectName(cmd.Context()) if err != nil { return err } - - projectName, err := cliClient.LoadProjectNameWithFallback(cmd.Context(), loader, provider) + projectNameFlag, _ := cmd.Flags().GetString("project-name") + saveStacksToWkDir := projectNameFlag == "" + provider, err := newProviderChecked(cmd.Context(), projectName, saveStacksToWkDir) if err != nil { return err } @@ -541,7 +442,7 @@ func makeComposeConfigCmd() *cobra.Command { }, loadErr) } - provider, err := newProvider(ctx, loader) + provider, err := newProviderChecked(ctx, project.Name, true) if err != nil { return err } @@ -571,12 +472,13 @@ func makeComposePsCmd() *cobra.Command { long, _ := cmd.Flags().GetBool("long") loader := configureLoader(cmd) - provider, err := newProviderChecked(cmd.Context(), loader) + projectName, err := loader.LoadProjectName(cmd.Context()) if err != nil { return err } - - projectName, err := cliClient.LoadProjectNameWithFallback(cmd.Context(), loader, provider) + projectNameFlag, _ := cmd.Flags().GetString("project-name") + saveStacksToWkDir := projectNameFlag == "" + provider, err := newProviderChecked(cmd.Context(), projectName, saveStacksToWkDir) if err != nil { return err } @@ -696,12 +598,13 @@ func handleLogsCmd(cmd *cobra.Command, args []string) error { } loader := configureLoader(cmd) - provider, err := newProviderChecked(cmd.Context(), loader) + projectName, err := loader.LoadProjectName(cmd.Context()) if err != nil { return err } - - projectName, err := cliClient.LoadProjectNameWithFallback(cmd.Context(), loader, provider) + projectNameFlag, _ := cmd.Flags().GetString("project-name") + saveStacksToWkDir := projectNameFlag == "" + provider, err := newProviderChecked(cmd.Context(), projectName, saveStacksToWkDir) if err != nil { return err } diff --git a/src/cmd/cli/command/globals.go b/src/cmd/cli/command/globals.go index d4b7eef4e..9eb868c0e 100644 --- a/src/cmd/cli/command/globals.go +++ b/src/cmd/cli/command/globals.go @@ -242,7 +242,7 @@ func (r *GlobalConfig) syncFlagsWithEnv(flags *pflag.FlagSet) error { } /* -loadDotDefang loads configuration values from .defang files into environment variables. +loadStackFile loads configuration values from .defang files into environment variables. Loading order: @@ -257,7 +257,7 @@ are considered required when specified, while the general RC file is optional. This function also checks for conflicts between environment variables in the stack file and existing shell environment variables, and warns the user if any are found. */ -func (r *GlobalConfig) loadDotDefang(stackName string) error { +func (r *GlobalConfig) loadStackFile(stackName string) error { if stackName != "" { // Check for conflicts before loading err := checkEnvConflicts(stackName) @@ -277,7 +277,6 @@ in the file conflict with existing shell environment variables. If conflicts are found, it warns the user that the shell environment variable will take precedence. */ func checkEnvConflicts(stackName string) error { - path, err := filepath.Abs(filepath.Join(stacks.Directory, stackName)) if err != nil { return err diff --git a/src/cmd/cli/command/globals_test.go b/src/cmd/cli/command/globals_test.go index b96dfab78..b502576af 100644 --- a/src/cmd/cli/command/globals_test.go +++ b/src/cmd/cli/command/globals_test.go @@ -21,7 +21,7 @@ func Test_readGlobals(t *testing.T) { t.Run("OS env beats any .defang file", func(t *testing.T) { t.Chdir("testdata/with-stack") t.Setenv("VALUE", "from OS env") - err := testConfig.loadDotDefang("test") + err := testConfig.loadStackFile("test") if err != nil { t.Fatalf("%v", err) } @@ -32,7 +32,7 @@ func Test_readGlobals(t *testing.T) { }) t.Run("incorrect stackname used if no stack", func(t *testing.T) { - err := testConfig.loadDotDefang("non-existent-stack") + err := testConfig.loadStackFile("non-existent-stack") if err == nil { t.Fatalf("this test should fail for non-existent stack: %v", err) } @@ -316,7 +316,7 @@ func Test_configurationPrecedence(t *testing.T) { t.Chdir(tempDir) // simulates the actual loading sequence - err := testConfig.loadDotDefang(tt.rcStack.stackname) + err := testConfig.loadStackFile(tt.rcStack.stackname) if err != nil { t.Fatalf("failed to load env file: %v", err) } @@ -424,7 +424,6 @@ AWS_REGION="us-east-1"`, for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - prevTerm := term.DefaultTerm var stdout, stderr bytes.Buffer term.DefaultTerm = term.NewTerm(os.Stdin, &stdout, &stderr) diff --git a/src/pkg/agent/tools/default_tool_cli.go b/src/pkg/agent/tools/default_tool_cli.go index f7802af27..f7b164cc0 100644 --- a/src/pkg/agent/tools/default_tool_cli.go +++ b/src/pkg/agent/tools/default_tool_cli.go @@ -29,7 +29,7 @@ type StackConfig struct { type DefaultToolCLI struct{} -func (DefaultToolCLI) CanIUseProvider(ctx context.Context, client *cliClient.GrpcClient, providerId cliClient.ProviderID, projectName string, provider cliClient.Provider, serviceCount int) error { +func (DefaultToolCLI) CanIUseProvider(ctx context.Context, client cliClient.FabricClient, providerId cliClient.ProviderID, projectName string, provider cliClient.Provider, serviceCount int) error { return cliClient.CanIUseProvider(ctx, client, provider, projectName, "", serviceCount) // TODO: add stack } @@ -37,7 +37,7 @@ func (DefaultToolCLI) ConfigSet(ctx context.Context, projectName string, provide return cli.ConfigSet(ctx, projectName, provider, name, value) } -func (DefaultToolCLI) RunEstimate(ctx context.Context, project *compose.Project, client *cliClient.GrpcClient, provider cliClient.Provider, providerId cliClient.ProviderID, region string, mode modes.Mode) (*defangv1.EstimateResponse, error) { +func (DefaultToolCLI) RunEstimate(ctx context.Context, project *compose.Project, client cliClient.FabricClient, provider cliClient.Provider, providerId cliClient.ProviderID, region string, mode modes.Mode) (*defangv1.EstimateResponse, error) { return cli.RunEstimate(ctx, project, client, provider, providerId, region, mode) } @@ -46,11 +46,11 @@ func (DefaultToolCLI) ListConfig(ctx context.Context, provider cliClient.Provide return provider.ListConfig(ctx, req) } -func (DefaultToolCLI) Connect(ctx context.Context, cluster string) (*cliClient.GrpcClient, error) { +func (DefaultToolCLI) Connect(ctx context.Context, cluster string) (cliClient.FabricClient, error) { return cli.Connect(ctx, cluster) } -func (DefaultToolCLI) ComposeUp(ctx context.Context, client *cliClient.GrpcClient, provider cliClient.Provider, params cli.ComposeUpParams) (*defangv1.DeployResponse, *compose.Project, error) { +func (DefaultToolCLI) ComposeUp(ctx context.Context, client cliClient.FabricClient, provider cliClient.Provider, params cli.ComposeUpParams) (*defangv1.DeployResponse, *compose.Project, error) { return cli.ComposeUp(ctx, client, provider, params) } @@ -58,14 +58,10 @@ func (DefaultToolCLI) Tail(ctx context.Context, provider cliClient.Provider, pro return cli.Tail(ctx, provider, projectName, options) } -func (DefaultToolCLI) ComposeDown(ctx context.Context, projectName string, client *cliClient.GrpcClient, provider cliClient.Provider) (string, error) { +func (DefaultToolCLI) ComposeDown(ctx context.Context, projectName string, client cliClient.FabricClient, provider cliClient.Provider) (string, error) { return cli.ComposeDown(ctx, projectName, client, provider) } -func (DefaultToolCLI) LoadProjectNameWithFallback(ctx context.Context, loader cliClient.Loader, provider cliClient.Provider) (string, error) { - return cliClient.LoadProjectNameWithFallback(ctx, loader, provider) -} - func (DefaultToolCLI) ConfigDelete(ctx context.Context, projectName string, provider cliClient.Provider, name string) error { return cli.ConfigDelete(ctx, projectName, provider, name) } @@ -91,7 +87,11 @@ func (DefaultToolCLI) LoadProject(ctx context.Context, loader cliClient.Loader) return loader.LoadProject(ctx) } -func (DefaultToolCLI) CreatePlaygroundProvider(client *cliClient.GrpcClient) cliClient.Provider { +func (DefaultToolCLI) LoadProjectName(ctx context.Context, loader cliClient.Loader) (string, error) { + return loader.LoadProjectName(ctx) +} + +func (DefaultToolCLI) CreatePlaygroundProvider(client cliClient.FabricClient) cliClient.Provider { return &cliClient.PlaygroundProvider{FabricClient: client} } @@ -104,7 +104,7 @@ func (DefaultToolCLI) GenerateAuthURL(authPort int) string { return "Please open this URL in your browser: http://127.0.0.1:" + strconv.Itoa(authPort) + " to login" } -func (DefaultToolCLI) InteractiveLoginMCP(ctx context.Context, client *cliClient.GrpcClient, cluster string, mcpClient string) error { +func (DefaultToolCLI) InteractiveLoginMCP(ctx context.Context, client cliClient.FabricClient, cluster string, mcpClient string) error { return login.InteractiveLoginMCP(ctx, client, cluster, mcpClient) } diff --git a/src/pkg/agent/tools/deploy.go b/src/pkg/agent/tools/deploy.go index c1c9a1b97..69d7df732 100644 --- a/src/pkg/agent/tools/deploy.go +++ b/src/pkg/agent/tools/deploy.go @@ -45,7 +45,7 @@ func HandleDeployTool(ctx context.Context, loader cliClient.ProjectLoader, cli C } pp := NewProviderPreparer(cli, ec, client) - providerID, provider, err := pp.SetupProvider(ctx, config.Stack) + providerID, provider, err := pp.SetupProvider(ctx, project.Name, config.Stack, true) if err != nil { return "", fmt.Errorf("failed to setup provider: %w", err) } diff --git a/src/pkg/agent/tools/deploy_test.go b/src/pkg/agent/tools/deploy_test.go index bb7fa68fb..3b949ff53 100644 --- a/src/pkg/agent/tools/deploy_test.go +++ b/src/pkg/agent/tools/deploy_test.go @@ -41,7 +41,7 @@ type MockDeployCLI struct { CallLog []string } -func (m *MockDeployCLI) Connect(ctx context.Context, cluster string) (*client.GrpcClient, error) { +func (m *MockDeployCLI) Connect(ctx context.Context, cluster string) (client.FabricClient, error) { m.CallLog = append(m.CallLog, fmt.Sprintf("Connect(%s)", cluster)) if m.ConnectError != nil { return &client.GrpcClient{}, m.ConnectError @@ -55,12 +55,12 @@ func (m *MockDeployCLI) NewProvider(ctx context.Context, providerId client.Provi return nil } -func (m *MockDeployCLI) InteractiveLoginMCP(ctx context.Context, client *client.GrpcClient, cluster string, mcpClient string) error { +func (m *MockDeployCLI) InteractiveLoginMCP(ctx context.Context, client client.FabricClient, cluster string, mcpClient string) error { m.CallLog = append(m.CallLog, "InteractiveLoginMCP") return m.InteractiveLoginMCPError } -func (m *MockDeployCLI) ComposeUp(ctx context.Context, fabric *client.GrpcClient, provider client.Provider, params cli.ComposeUpParams) (*defangv1.DeployResponse, *compose.Project, error) { +func (m *MockDeployCLI) ComposeUp(ctx context.Context, fabric client.FabricClient, provider client.Provider, params cli.ComposeUpParams) (*defangv1.DeployResponse, *compose.Project, error) { m.CallLog = append(m.CallLog, "ComposeUp") if m.ComposeUpError != nil { return nil, nil, m.ComposeUpError @@ -81,7 +81,7 @@ func (m *MockDeployCLI) TailAndMonitor(ctx context.Context, project *compose.Pro return nil, nil } -func (m *MockDeployCLI) CanIUseProvider(ctx context.Context, client *client.GrpcClient, providerId client.ProviderID, projectName string, provider client.Provider, serviceCount int) error { +func (m *MockDeployCLI) CanIUseProvider(ctx context.Context, client client.FabricClient, providerId client.ProviderID, projectName string, provider client.Provider, serviceCount int) error { m.CallLog = append(m.CallLog, "CanIUseProvider") return nil } @@ -174,7 +174,6 @@ func TestHandleDeployTool(t *testing.T) { loader := &client.MockLoader{} ec := elicitations.NewController(&mockElicitationsClient{ responses: map[string]string{ - "strategy": "profile", "profile_name": "default", }, }) diff --git a/src/pkg/agent/tools/destroy.go b/src/pkg/agent/tools/destroy.go index 9adcd4509..942f73eec 100644 --- a/src/pkg/agent/tools/destroy.go +++ b/src/pkg/agent/tools/destroy.go @@ -23,15 +23,14 @@ func HandleDestroyTool(ctx context.Context, loader cliClient.ProjectLoader, cli return "", fmt.Errorf("could not connect: %w", err) } - pp := NewProviderPreparer(cli, ec, client) - _, provider, err := pp.SetupProvider(ctx, config.Stack) + projectName, err := cli.LoadProjectName(ctx, loader) if err != nil { - return "", fmt.Errorf("failed to setup provider: %w", err) + return "", fmt.Errorf("failed to load project name: %w", err) } - term.Debug("Function invoked: cliClient.LoadProjectNameWithFallback") - projectName, err := cli.LoadProjectNameWithFallback(ctx, loader, provider) + pp := NewProviderPreparer(cli, ec, client) + _, provider, err := pp.SetupProvider(ctx, projectName, config.Stack, false) if err != nil { - return "", fmt.Errorf("failed to load project name: %w", err) + return "", fmt.Errorf("failed to setup provider: %w", err) } if config.ProviderID == nil { diff --git a/src/pkg/agent/tools/destroy_test.go b/src/pkg/agent/tools/destroy_test.go index cb89436d2..29cc39e89 100644 --- a/src/pkg/agent/tools/destroy_test.go +++ b/src/pkg/agent/tools/destroy_test.go @@ -9,29 +9,33 @@ import ( "github.com/DefangLabs/defang/src/pkg/cli/client" "github.com/DefangLabs/defang/src/pkg/elicitations" + defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1" "github.com/bufbuild/connect-go" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) // MockDestroyCLI implements CLIInterface for testing type MockDestroyCLI struct { CLIInterface - ConnectError error - ComposeDownError error - LoadProjectNameWithFallbackError error - CanIUseProviderError error - ComposeDownResult string - ProjectName string - CallLog []string + ConnectError error + ComposeDownError error + LoadProjectNameError error + CanIUseProviderError error + ComposeDownResult string + ProjectName string + CallLog []string } -func (m *MockDestroyCLI) Connect(ctx context.Context, cluster string) (*client.GrpcClient, error) { +var mockFC *mockFabricClient + +func (m *MockDestroyCLI) Connect(ctx context.Context, cluster string) (client.FabricClient, error) { m.CallLog = append(m.CallLog, fmt.Sprintf("Connect(%s)", cluster)) if m.ConnectError != nil { return nil, m.ConnectError } - return &client.GrpcClient{}, nil + return mockFC, nil } func (m *MockDestroyCLI) NewProvider(ctx context.Context, providerId client.ProviderID, grpcClient client.FabricClient, stack string) client.Provider { @@ -39,7 +43,7 @@ func (m *MockDestroyCLI) NewProvider(ctx context.Context, providerId client.Prov return nil } -func (m *MockDestroyCLI) ComposeDown(ctx context.Context, projectName string, grpcClient *client.GrpcClient, provider client.Provider) (string, error) { +func (m *MockDestroyCLI) ComposeDown(ctx context.Context, projectName string, grpcClient client.FabricClient, provider client.Provider) (string, error) { m.CallLog = append(m.CallLog, fmt.Sprintf("ComposeDown(%s)", projectName)) if m.ComposeDownError != nil { return "", m.ComposeDownError @@ -47,15 +51,15 @@ func (m *MockDestroyCLI) ComposeDown(ctx context.Context, projectName string, gr return m.ComposeDownResult, nil } -func (m *MockDestroyCLI) LoadProjectNameWithFallback(ctx context.Context, loader client.Loader, provider client.Provider) (string, error) { - m.CallLog = append(m.CallLog, "LoadProjectNameWithFallback") - if m.LoadProjectNameWithFallbackError != nil { - return "", m.LoadProjectNameWithFallbackError +func (m *MockDestroyCLI) LoadProjectName(ctx context.Context, loader client.Loader) (string, error) { + m.CallLog = append(m.CallLog, "LoadProjectName") + if m.LoadProjectNameError != nil { + return "", m.LoadProjectNameError } return m.ProjectName, nil } -func (m *MockDestroyCLI) CanIUseProvider(ctx context.Context, grpcClient *client.GrpcClient, providerId client.ProviderID, projectName string, provider client.Provider, serviceCount int) error { +func (m *MockDestroyCLI) CanIUseProvider(ctx context.Context, grpcClient client.FabricClient, providerId client.ProviderID, projectName string, provider client.Provider, serviceCount int) error { m.CallLog = append(m.CallLog, fmt.Sprintf("CanIUseProvider(%s, %s)", providerId, projectName)) if m.CanIUseProviderError != nil { return m.CanIUseProviderError @@ -64,6 +68,7 @@ func (m *MockDestroyCLI) CanIUseProvider(ctx context.Context, grpcClient *client } func TestHandleDestroyTool(t *testing.T) { + mockFC = &mockFabricClient{} tests := []struct { name string providerID client.ProviderID @@ -83,7 +88,7 @@ func TestHandleDestroyTool(t *testing.T) { name: "load_project_name_error", providerID: client.ProviderAWS, setupMock: func(m *MockDestroyCLI) { - m.LoadProjectNameWithFallbackError = errors.New("failed to load project name") + m.LoadProjectNameError = errors.New("failed to load project name") }, expectedError: "failed to load project name: failed to load project name", }, @@ -141,11 +146,23 @@ func TestHandleDestroyTool(t *testing.T) { loader := &client.MockLoader{} ec := elicitations.NewController(&mockElicitationsClient{ responses: map[string]string{ - "strategy": "profile", "profile_name": "default", }, }) stackName := "test-stack" + mockFC.On("ListDeployments", mock.Anything, mock.Anything).Return(&defangv1.ListDeploymentsResponse{ + Deployments: []*defangv1.Deployment{ + { + Id: "deployment-123", + Project: "test-project", + Stack: stackName, + Region: "us-test-2", + Provider: defangv1.Provider_AWS, + ProviderAccountId: "123456789012", + }, + }, + }, nil) + result, err := HandleDestroyTool(t.Context(), loader, mockCLI, ec, StackConfig{ Cluster: "test-cluster", ProviderID: &tt.providerID, @@ -166,8 +183,8 @@ func TestHandleDestroyTool(t *testing.T) { if tt.expectedError == "" && tt.name == "successful_destroy" { expectedCalls := []string{ "Connect(test-cluster)", + "LoadProjectName", "NewProvider(aws)", - "LoadProjectNameWithFallback", "CanIUseProvider(aws, test-project)", "ComposeDown(test-project)", } diff --git a/src/pkg/agent/tools/estimate_test.go b/src/pkg/agent/tools/estimate_test.go index eba3ac2ed..778315e78 100644 --- a/src/pkg/agent/tools/estimate_test.go +++ b/src/pkg/agent/tools/estimate_test.go @@ -28,7 +28,7 @@ type MockEstimateCLI struct { ProviderIDAfterSet client.ProviderID // Track the providerID that gets set } -func (m *MockEstimateCLI) Connect(ctx context.Context, cluster string) (*client.GrpcClient, error) { +func (m *MockEstimateCLI) Connect(ctx context.Context, cluster string) (client.FabricClient, error) { m.CallLog = append(m.CallLog, fmt.Sprintf("Connect(%s)", cluster)) if m.ConnectError != nil { return nil, m.ConnectError @@ -44,7 +44,7 @@ func (m *MockEstimateCLI) LoadProject(ctx context.Context, loader client.Loader) return m.Project, nil } -func (m *MockEstimateCLI) RunEstimate(ctx context.Context, project *compose.Project, grpcClient *client.GrpcClient, provider client.Provider, providerId client.ProviderID, region string, mode modes.Mode) (*defangv1.EstimateResponse, error) { +func (m *MockEstimateCLI) RunEstimate(ctx context.Context, project *compose.Project, FabricClient client.FabricClient, provider client.Provider, providerId client.ProviderID, region string, mode modes.Mode) (*defangv1.EstimateResponse, error) { projectName := "" if project != nil { projectName = project.Name @@ -56,7 +56,7 @@ func (m *MockEstimateCLI) RunEstimate(ctx context.Context, project *compose.Proj return m.EstimateResponse, nil } -func (m *MockEstimateCLI) CreatePlaygroundProvider(grpcClient *client.GrpcClient) client.Provider { +func (m *MockEstimateCLI) CreatePlaygroundProvider(FabricClient client.FabricClient) client.Provider { m.CallLog = append(m.CallLog, "CreatePlaygroundProvider") return nil } @@ -121,7 +121,7 @@ func TestHandleEstimateTool(t *testing.T) { setupMock: func(m *MockEstimateCLI) { m.Project = &compose.Project{Name: "test-project"} }, - expectedError: "provider not one of [auto defang aws digitalocean gcp]", + expectedError: "provider not one of [defang aws digitalocean gcp]", }, { name: "run_estimate_error", diff --git a/src/pkg/agent/tools/interfaces.go b/src/pkg/agent/tools/interfaces.go index 23860fe34..4c876f738 100644 --- a/src/pkg/agent/tools/interfaces.go +++ b/src/pkg/agent/tools/interfaces.go @@ -15,22 +15,22 @@ import ( ) type CLIInterface interface { - CanIUseProvider(ctx context.Context, client *cliClient.GrpcClient, providerId cliClient.ProviderID, projectName string, provider cliClient.Provider, serviceCount int) error - ComposeDown(ctx context.Context, projectName string, client *cliClient.GrpcClient, provider cliClient.Provider) (string, error) - ComposeUp(ctx context.Context, client *cliClient.GrpcClient, provider cliClient.Provider, params cli.ComposeUpParams) (*defangv1.DeployResponse, *compose.Project, error) + CanIUseProvider(ctx context.Context, client cliClient.FabricClient, providerId cliClient.ProviderID, projectName string, provider cliClient.Provider, serviceCount int) error + ComposeDown(ctx context.Context, projectName string, client cliClient.FabricClient, provider cliClient.Provider) (string, error) + ComposeUp(ctx context.Context, client cliClient.FabricClient, provider cliClient.Provider, params cli.ComposeUpParams) (*defangv1.DeployResponse, *compose.Project, error) ConfigDelete(ctx context.Context, projectName string, provider cliClient.Provider, name string) error ConfigSet(ctx context.Context, projectName string, provider cliClient.Provider, name, value string) error - Connect(ctx context.Context, cluster string) (*cliClient.GrpcClient, error) - CreatePlaygroundProvider(client *cliClient.GrpcClient) cliClient.Provider + Connect(ctx context.Context, cluster string) (cliClient.FabricClient, error) + CreatePlaygroundProvider(client cliClient.FabricClient) cliClient.Provider GenerateAuthURL(authPort int) string GetServices(ctx context.Context, projectName string, provider cliClient.Provider) ([]deployment_info.Service, error) - InteractiveLoginMCP(ctx context.Context, client *cliClient.GrpcClient, cluster string, mcpClient string) error + InteractiveLoginMCP(ctx context.Context, client cliClient.FabricClient, cluster string, mcpClient string) error ListConfig(ctx context.Context, provider cliClient.Provider, projectName string) (*defangv1.Secrets, error) LoadProject(ctx context.Context, loader cliClient.Loader) (*compose.Project, error) - LoadProjectNameWithFallback(ctx context.Context, loader cliClient.Loader, provider cliClient.Provider) (string, error) + LoadProjectName(ctx context.Context, loader cliClient.Loader) (string, error) NewProvider(ctx context.Context, providerId cliClient.ProviderID, client cliClient.FabricClient, stack string) cliClient.Provider PrintEstimate(mode modes.Mode, estimate *defangv1.EstimateResponse) string - RunEstimate(ctx context.Context, project *compose.Project, client *cliClient.GrpcClient, provider cliClient.Provider, providerId cliClient.ProviderID, region string, mode modes.Mode) (*defangv1.EstimateResponse, error) + RunEstimate(ctx context.Context, project *compose.Project, client cliClient.FabricClient, provider cliClient.Provider, providerId cliClient.ProviderID, region string, mode modes.Mode) (*defangv1.EstimateResponse, error) Tail(ctx context.Context, provider cliClient.Provider, projectName string, options cliTypes.TailOptions) error TailAndMonitor(ctx context.Context, project *compose.Project, provider cliClient.Provider, waitTimeout time.Duration, options cliTypes.TailOptions) (cli.ServiceStates, error) } diff --git a/src/pkg/agent/tools/listConfig.go b/src/pkg/agent/tools/listConfig.go index e62bcf1c8..9e723e148 100644 --- a/src/pkg/agent/tools/listConfig.go +++ b/src/pkg/agent/tools/listConfig.go @@ -20,21 +20,19 @@ func HandleListConfigTool(ctx context.Context, loader cliClient.ProjectLoader, c term.Debug("Function invoked: cli.Connect") client, err := cli.Connect(ctx, sc.Cluster) if err != nil { - return "", fmt.Errorf("Could not connect: %w", err) + return "", fmt.Errorf("could not connect: %w", err) } - pp := NewProviderPreparer(cli, ec, client) - _, provider, err := pp.SetupProvider(ctx, sc.Stack) - if err != nil { - return "", fmt.Errorf("failed to setup provider: %w", err) - } - - term.Debug("Function invoked: cliClient.LoadProjectNameWithFallback") - projectName, err := cli.LoadProjectNameWithFallback(ctx, loader, provider) + projectName, err := cli.LoadProjectName(ctx, loader) if err != nil { return "", fmt.Errorf("failed to load project name: %w", err) } term.Debug("Project name loaded:", projectName) + pp := NewProviderPreparer(cli, ec, client) + _, provider, err := pp.SetupProvider(ctx, projectName, sc.Stack, false) + if err != nil { + return "", fmt.Errorf("failed to setup provider: %w", err) + } term.Debug("Function invoked: cli.ConfigList") config, err := cli.ListConfig(ctx, provider, projectName) diff --git a/src/pkg/agent/tools/listConfig_test.go b/src/pkg/agent/tools/listConfig_test.go index cb77ba141..26182c668 100644 --- a/src/pkg/agent/tools/listConfig_test.go +++ b/src/pkg/agent/tools/listConfig_test.go @@ -11,6 +11,7 @@ import ( "github.com/DefangLabs/defang/src/pkg/elicitations" defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -25,12 +26,12 @@ type MockListConfigCLI struct { CallLog []string } -func (m *MockListConfigCLI) Connect(ctx context.Context, cluster string) (*client.GrpcClient, error) { +func (m *MockListConfigCLI) Connect(ctx context.Context, cluster string) (client.FabricClient, error) { m.CallLog = append(m.CallLog, fmt.Sprintf("Connect(%s)", cluster)) if m.ConnectError != nil { return nil, m.ConnectError } - return &client.GrpcClient{}, nil + return mockFC, nil } func (m *MockListConfigCLI) NewProvider(ctx context.Context, providerId client.ProviderID, client client.FabricClient, stack string) client.Provider { @@ -38,8 +39,8 @@ func (m *MockListConfigCLI) NewProvider(ctx context.Context, providerId client.P return nil // Mock provider } -func (m *MockListConfigCLI) LoadProjectNameWithFallback(ctx context.Context, loader client.Loader, provider client.Provider) (string, error) { - m.CallLog = append(m.CallLog, "LoadProjectNameWithFallback") +func (m *MockListConfigCLI) LoadProjectName(ctx context.Context, loader client.Loader) (string, error) { + m.CallLog = append(m.CallLog, "LoadProjectName") if m.LoadProjectNameError != nil { return "", m.LoadProjectNameError } @@ -55,6 +56,7 @@ func (m *MockListConfigCLI) ListConfig(ctx context.Context, provider client.Prov } func TestHandleListConfigTool(t *testing.T) { + mockFC = &mockFabricClient{} tests := []struct { name string providerID client.ProviderID @@ -68,7 +70,7 @@ func TestHandleListConfigTool(t *testing.T) { setupMock: func(m *MockListConfigCLI) { m.ConnectError = errors.New("connection failed") }, - expectedError: "Could not connect: connection failed", + expectedError: "could not connect: connection failed", }, { name: "load_project_name_error", @@ -128,12 +130,24 @@ func TestHandleListConfigTool(t *testing.T) { loader := &client.MockLoader{} ec := elicitations.NewController(&mockElicitationsClient{ responses: map[string]string{ - "strategy": "profile", "profile_name": "default", }, }) stackName := "test-stack" + mockFC.On("ListDeployments", mock.Anything, mock.Anything).Return(&defangv1.ListDeploymentsResponse{ + Deployments: []*defangv1.Deployment{ + { + Id: "deployment-123", + Project: "test-project", + Stack: stackName, + Region: "us-test-2", + Provider: defangv1.Provider_AWS, + ProviderAccountId: "123456789012", + }, + }, + }, nil) + result, err := HandleListConfigTool(t.Context(), loader, mockCLI, ec, StackConfig{ Cluster: "test-cluster", ProviderID: &tt.providerID, @@ -154,8 +168,8 @@ func TestHandleListConfigTool(t *testing.T) { if tt.expectedError == "" && tt.name == "successful_list_single_config" { expectedCalls := []string{ "Connect(test-cluster)", + "LoadProjectName", "NewProvider(aws)", - "LoadProjectNameWithFallback", "ListConfig(test-project)", } assert.Equal(t, expectedCalls, mockCLI.CallLog) diff --git a/src/pkg/agent/tools/logs.go b/src/pkg/agent/tools/logs.go index 8793bb0c6..9c53f0198 100644 --- a/src/pkg/agent/tools/logs.go +++ b/src/pkg/agent/tools/logs.go @@ -45,18 +45,17 @@ func HandleLogsTool(ctx context.Context, loader cliClient.ProjectLoader, params return "", fmt.Errorf("could not connect: %w", err) } - pp := NewProviderPreparer(cli, ec, client) - _, provider, err := pp.SetupProvider(ctx, config.Stack) + projectName, err := cli.LoadProjectName(ctx, loader) if err != nil { - return "", fmt.Errorf("failed to setup provider: %w", err) + return "", fmt.Errorf("failed to load project name: %w", err) } + term.Debug("Project name loaded:", projectName) - term.Debug("Function invoked: cli.LoadProjectNameWithFallback") - projectName, err := cli.LoadProjectNameWithFallback(ctx, loader, provider) + pp := NewProviderPreparer(cli, ec, client) + _, provider, err := pp.SetupProvider(ctx, projectName, config.Stack, false) if err != nil { - return "", fmt.Errorf("failed to load project name: %w", err) + return "", fmt.Errorf("failed to setup provider: %w", err) } - term.Debug("Project name loaded:", projectName) if config.ProviderID == nil { return "", errors.New("provider ID is required to fetch logs") @@ -78,7 +77,6 @@ func HandleLogsTool(ctx context.Context, loader cliClient.ProjectLoader, params }) if err != nil { - err = fmt.Errorf("failed to fetch logs: %w", err) term.Error("Failed to fetch logs", "error", err) return "", fmt.Errorf("failed to fetch logs: %w", err) } diff --git a/src/pkg/agent/tools/provider.go b/src/pkg/agent/tools/provider.go index 18c4ff77e..b7427d68d 100644 --- a/src/pkg/agent/tools/provider.go +++ b/src/pkg/agent/tools/provider.go @@ -6,13 +6,16 @@ import ( "errors" "fmt" "os" + "path/filepath" "sort" "strings" + "time" cliClient "github.com/DefangLabs/defang/src/pkg/cli/client" "github.com/DefangLabs/defang/src/pkg/elicitations" "github.com/DefangLabs/defang/src/pkg/stacks" "github.com/DefangLabs/defang/src/pkg/term" + defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1" ) const CreateNewStack = "Create new stack" @@ -21,10 +24,40 @@ type ProviderCreator interface { NewProvider(ctx context.Context, providerId cliClient.ProviderID, client cliClient.FabricClient, stack string) cliClient.Provider } +type StacksManager interface { + Create(params stacks.StackParameters) (string, error) + Read(stackName string) (*stacks.StackParameters, error) + LoadParameters(*stacks.StackParameters) + List() ([]stacks.StackListItem, error) +} + +type stacksManager struct{} + +func NewStacksManager() *stacksManager { + return &stacksManager{} +} + +func (sm *stacksManager) Create(params stacks.StackParameters) (string, error) { + return stacks.Create(params) +} + +func (sm *stacksManager) Read(stackName string) (*stacks.StackParameters, error) { + return stacks.Read(stackName) +} + +func (sm *stacksManager) LoadParameters(params *stacks.StackParameters) { + stacks.LoadParameters(params) +} + +func (sm *stacksManager) List() ([]stacks.StackListItem, error) { + return stacks.List() +} + type providerPreparer struct { pc ProviderCreator ec elicitations.Controller fc cliClient.FabricClient + sm StacksManager } func NewProviderPreparer(pc ProviderCreator, ec elicitations.Controller, fc cliClient.FabricClient) *providerPreparer { @@ -32,38 +65,44 @@ func NewProviderPreparer(pc ProviderCreator, ec elicitations.Controller, fc cliC pc: pc, ec: ec, fc: fc, + sm: NewStacksManager(), } } -func (pp *providerPreparer) SetupProvider(ctx context.Context, stackName *string) (*cliClient.ProviderID, cliClient.Provider, error) { +func (pp *providerPreparer) SetupProvider(ctx context.Context, projectName string, stackName *string, useWkDir bool) (*cliClient.ProviderID, cliClient.Provider, error) { var providerID cliClient.ProviderID var err error var stack *stacks.StackParameters if stackName == nil { return nil, nil, errors.New("stackName cannot be nil") } - if *stackName != "" { - stack, err = stacks.Read(*stackName) - if err != nil { - return nil, nil, fmt.Errorf("failed to read stack: %w", err) - } - err = stacks.Load(*stackName) + if *stackName == "" { + stack, err = pp.selectOrCreateStack(ctx, projectName, useWkDir) if err != nil { - return nil, nil, fmt.Errorf("failed to load stack: %w", err) + return nil, nil, fmt.Errorf("failed to setup stack: %w", err) } + *stackName = stack.Name } else { - stack, err = pp.setupStack(ctx) + stack, err = pp.getStackParameters(ctx, projectName, *stackName, useWkDir) if err != nil { - return nil, nil, fmt.Errorf("failed to setup stack: %w", err) + return nil, nil, fmt.Errorf("failed to load stack: %w", err) } - *stackName = stack.Name } + term.Debugf("Loading stack params %v", stack) + pp.sm.LoadParameters(stack) err = providerID.Set(stack.Provider.Name()) if err != nil { return nil, nil, fmt.Errorf("failed to set provider ID: %w", err) } + if useWkDir { + _, err = pp.sm.Create(*stack) + if err != nil { + term.Warnf("Failed to create stackfile: %v", err) + } + } + err = pp.setupProviderAuthentication(ctx, providerID) if err != nil { return nil, nil, fmt.Errorf("failed to setup provider authentication: %w", err) @@ -74,53 +113,239 @@ func (pp *providerPreparer) SetupProvider(ctx context.Context, stackName *string return &providerID, provider, nil } -func selectStack(ctx context.Context, ec elicitations.Controller) (string, error) { - stackList, err := stacks.List() +type StackOption struct { + Name string + Local bool + LastDeployedAt time.Time + Parameters *stacks.StackParameters +} + +func (pp *providerPreparer) collectStackOptions(ctx context.Context, projectName string, useWkDir bool) (map[string]StackOption, error) { + // Merge remote and local stacks into a single list of type StackOption, + // prefer local if both exist + stackMap := make(map[string]StackOption) + remoteStackList, err := pp.collectPreviouslyDeployedStacks(ctx, projectName) if err != nil { - return "", fmt.Errorf("failed to list stacks: %w", err) + return nil, fmt.Errorf("failed to collect existing stacks: %w", err) + } + for _, remoteStack := range remoteStackList { + stackMap[remoteStack.Name] = StackOption{ + Name: remoteStack.Name, + Local: false, + LastDeployedAt: remoteStack.DeployedAt, + Parameters: &remoteStack.StackParameters, + } + } + + if useWkDir { + localStackList, err := pp.sm.List() + if err != nil { + return nil, fmt.Errorf("failed to list stacks: %w", err) + } + + for _, localStack := range localStackList { + existing, exists := stackMap[localStack.Name] + lastDeployedAt := time.Time{} + if exists { + lastDeployedAt = existing.LastDeployedAt + } + stackMap[localStack.Name] = StackOption{ + Name: localStack.Name, + Local: true, + LastDeployedAt: lastDeployedAt, + Parameters: nil, + } + } + } + + stackLabelMap := make(map[string]StackOption) + for _, stackOption := range stackMap { + label := stackOption.Name + if !stackOption.LastDeployedAt.IsZero() { + label = fmt.Sprintf("%s (last deployed %s)", stackOption.Name, stackOption.LastDeployedAt.Local().Format(time.RFC822)) + } + stackLabelMap[label] = stackOption + } + + return stackLabelMap, nil +} + +func printStacksInfoMessage(stacks map[string]StackOption) { + _, betaExists := stacks["beta"] + if betaExists { + infoLine := "This project was deployed with an implicit Stack called 'beta' before Stacks were introduced." + if len(stacks) == 1 { + infoLine += "\n To update your existing deployment, select the 'beta' Stack.\n" + + "Creating a new Stack will result in a separate deployment instance." + } + infoLine += "\n To learn more about Stacks, visit: https://docs.defang.io/docs/concepts/stacks" + term.Info(infoLine + "\n") } + executable, _ := os.Executable() + term.Infof("To skip this prompt, run %s up --stack=%s", filepath.Base(executable), "") +} - if len(stackList) == 0 { - return CreateNewStack, nil +func (pp *providerPreparer) selectStack(ctx context.Context, projectName string, useWkDir bool) (*StackOption, error) { + stackOptions, err := pp.collectStackOptions(ctx, projectName, useWkDir) + if err != nil { + return nil, fmt.Errorf("failed to collect stack options: %w", err) } + if len(stackOptions) == 0 { + return &StackOption{Name: CreateNewStack}, nil + } + + printStacksInfoMessage(stackOptions) - stackNames := make([]string, 0, len(stackList)+1) - for _, s := range stackList { - stackNames = append(stackNames, s.Name) + // Convert map back to slice + stackLabels := make([]string, 0, len(stackOptions)+1) + for label := range stackOptions { + stackLabels = append(stackLabels, label) + } + if useWkDir { + stackLabels = append(stackLabels, CreateNewStack) } - stackNames = append(stackNames, CreateNewStack) - selectedStackName, err := ec.RequestEnum(ctx, "Select a stack", "stack", stackNames) + selectedStackLabel, err := pp.ec.RequestEnum(ctx, "Select a stack", "stack", stackLabels) if err != nil { - return "", fmt.Errorf("failed to elicit stack choice: %w", err) + return nil, fmt.Errorf("failed to elicit stack choice: %w", err) + } + + // Handle special case where user selects "Create new stack" + if selectedStackLabel == CreateNewStack { + return &StackOption{Name: CreateNewStack}, nil + } + + selectedStackOption, ok := stackOptions[selectedStackLabel] + if !ok { + return nil, fmt.Errorf("selected stack label %q not found in stack options map", selectedStackLabel) + } + if selectedStackOption.Local { + return &selectedStackOption, nil + } + + if selectedStackOption.Parameters == nil { + return nil, fmt.Errorf("stack parameters for remote stack %q are nil", selectedStackLabel) + } + + if useWkDir { + term.Debugf("Importing stack %s from remote", selectedStackLabel) + _, err = pp.sm.Create(*selectedStackOption.Parameters) + if err != nil { + return nil, fmt.Errorf("failed to create local stack from remote: %w", err) + } } - return selectedStackName, nil + return &selectedStackOption, nil +} + +type ExistingStack struct { + stacks.StackParameters + DeployedAt time.Time } -func (pp *providerPreparer) setupStack(ctx context.Context) (*stacks.StackParameters, error) { - selectedStackName, err := selectStack(ctx, pp.ec) +func (pp *providerPreparer) collectPreviouslyDeployedStacks(ctx context.Context, projectName string) ([]*ExistingStack, error) { + resp, err := pp.fc.ListDeployments(ctx, &defangv1.ListDeploymentsRequest{ + Project: projectName, + }) + if err != nil { + return nil, fmt.Errorf("failed to list deployments: %w", err) + } + deployments := resp.GetDeployments() + stackMap := make(map[string]*ExistingStack) + for _, deployment := range deployments { + stackName := deployment.GetStack() + if stackName == "" { + stackName = "beta" + } + var providerID cliClient.ProviderID + providerID.SetValue(deployment.GetProvider()) + // avoid overwriting existing entries, deployments are already sorted by deployed_at desc + if _, exists := stackMap[stackName]; !exists { + var deployedAt time.Time + if ts := deployment.GetTimestamp(); ts != nil { + deployedAt = ts.AsTime() + } + stackMap[stackName] = &ExistingStack{ + StackParameters: stacks.StackParameters{ + Name: stackName, + Provider: providerID, + Region: deployment.GetRegion(), + }, + DeployedAt: deployedAt, + } + } + } + stackParams := make([]*ExistingStack, 0, len(stackMap)) + for _, params := range stackMap { + stackParams = append(stackParams, params) + } + return stackParams, nil +} + +func (pp *providerPreparer) selectOrCreateStack(ctx context.Context, projectName string, useWkDir bool) (*stacks.StackParameters, error) { + selectedStack, err := pp.selectStack(ctx, projectName, useWkDir) if err != nil { return nil, fmt.Errorf("failed to select stack: %w", err) } - if selectedStackName == CreateNewStack { - newStack, err := pp.createNewStack(ctx) + if selectedStack.Name == CreateNewStack { + newStack, err := pp.promptForStackParameters(ctx) if err != nil { return nil, fmt.Errorf("failed to create new stack: %w", err) } - selectedStackName = newStack.Name + return newStack, nil + } + + // For local stacks, parameters need to be loaded from the stack manager + if selectedStack.Local && selectedStack.Parameters == nil { + params, err := pp.sm.Read(selectedStack.Name) + if err != nil { + return nil, fmt.Errorf("failed to read local stack %q: %w", selectedStack.Name, err) + } + return params, nil + } + + return selectedStack.Parameters, nil +} + +func (pp *providerPreparer) getStackParameters(ctx context.Context, projectName, stackName string, useWkDir bool) (*stacks.StackParameters, error) { + if !useWkDir { + return pp.importRemoteStack(ctx, projectName, stackName) } - err = stacks.Load(selectedStackName) + stack, err := pp.sm.Read(stackName) + if err != nil { + if !errors.Is(err, os.ErrNotExist) { + return nil, fmt.Errorf("failed to read stack: %w", err) + } + stack, err = pp.importRemoteStack(ctx, projectName, stackName) + if err != nil { + return nil, fmt.Errorf("failed to import remote stack: %w", err) + } + if stack == nil { + return nil, fmt.Errorf("stack %q does not exist locally or remotely", stackName) + } + return stack, nil + } + + return stack, nil +} + +func (pp *providerPreparer) importRemoteStack(ctx context.Context, projectName, stackName string) (*stacks.StackParameters, error) { + existingStacks, err := pp.collectPreviouslyDeployedStacks(ctx, projectName) if err != nil { - return nil, fmt.Errorf("failed to load stack: %w", err) + return nil, fmt.Errorf("failed to collect existing stacks: %w", err) + } + for _, existingStack := range existingStacks { + if existingStack.Name == stackName { + return &existingStack.StackParameters, nil + } } - return stacks.Read(selectedStackName) + return nil, fmt.Errorf("stack %q does not exist remotely", stackName) } -func (pp *providerPreparer) createNewStack(ctx context.Context) (*stacks.StackListItem, error) { +func (pp *providerPreparer) promptForStackParameters(ctx context.Context) (*stacks.StackParameters, error) { var providerNames []string for _, p := range cliClient.AllProviders() { providerNames = append(providerNames, p.Name()) @@ -160,16 +385,8 @@ func (pp *providerPreparer) createNewStack(ctx context.Context) (*stacks.StackLi Region: region, Name: name, } - _, err = stacks.Create(params) - if err != nil { - return nil, fmt.Errorf("failed to create stack: %w", err) - } - return &stacks.StackListItem{ - Name: name, - Provider: providerID.Name(), - Region: region, - }, nil + return ¶ms, nil } func (pp *providerPreparer) setupProviderAuthentication(ctx context.Context, providerId cliClient.ProviderID) error { @@ -185,51 +402,50 @@ func (pp *providerPreparer) setupProviderAuthentication(ctx context.Context, pro } func (pp *providerPreparer) SetupAWSAuthentication(ctx context.Context) error { - if os.Getenv("AWS_PROFILE") != "" || (os.Getenv("AWS_ACCESS_KEY_ID") != "" && os.Getenv("AWS_SECRET_ACCESS_KEY") != "") { + if os.Getenv("AWS_PROFILE") != "" { + return nil + } + + if os.Getenv("AWS_ACCESS_KEY_ID") != "" && os.Getenv("AWS_SECRET_ACCESS_KEY") != "" { return nil } - // TODO: check the fs for AWS credentials file or config for profile names // TODO: add support for aws sso strategy - strategy, err := pp.ec.RequestEnum(ctx, "How do you authenticate to AWS?", "strategy", []string{ - "profile", - "access_key", - }) + knownProfiles, err := listAWSProfiles() if err != nil { - return fmt.Errorf("failed to elicit AWS Access Key ID: %w", err) + return fmt.Errorf("failed to list AWS profiles: %w", err) } - if strategy == "profile" { - if os.Getenv("AWS_PROFILE") == "" { - knownProfiles, err := listAWSProfiles() - if err != nil { - return fmt.Errorf("failed to list AWS profiles: %w", err) - } - profile, err := pp.ec.RequestEnum(ctx, "Select your profile", "profile_name", knownProfiles) + if len(knownProfiles) > 0 { + const useAccessKeysOption = "Use Access Key ID and Secret Access Key" + knownProfiles = append(knownProfiles, useAccessKeysOption) + profile, err := pp.ec.RequestEnum(ctx, "Select your profile", "profile_name", knownProfiles) + if err != nil { + return fmt.Errorf("failed to elicit AWS Profile Name: %w", err) + } + if profile != useAccessKeysOption { + err := os.Setenv("AWS_PROFILE", profile) if err != nil { - return fmt.Errorf("failed to elicit AWS Profile Name: %w", err) - } - if err := os.Setenv("AWS_PROFILE", profile); err != nil { return fmt.Errorf("failed to set AWS_PROFILE environment variable: %w", err) } + return nil } - } else { - if os.Getenv("AWS_ACCESS_KEY_ID") == "" { - accessKeyID, err := pp.ec.RequestString(ctx, "Enter your AWS Access Key ID:", "access_key_id") - if err != nil { - return fmt.Errorf("failed to elicit AWS Access Key ID: %w", err) - } - if err := os.Setenv("AWS_ACCESS_KEY_ID", accessKeyID); err != nil { - return fmt.Errorf("failed to set AWS_ACCESS_KEY_ID environment variable: %w", err) - } + } + if os.Getenv("AWS_ACCESS_KEY_ID") == "" { + accessKeyID, err := pp.ec.RequestString(ctx, "Enter your AWS Access Key ID:", "access_key_id") + if err != nil { + return fmt.Errorf("failed to elicit AWS Access Key ID: %w", err) } - if os.Getenv("AWS_SECRET_ACCESS_KEY") == "" { - accessKeySecret, err := pp.ec.RequestString(ctx, "Enter your AWS Secret Access Key:", "access_key_secret") - if err != nil { - return fmt.Errorf("failed to elicit AWS Secret Access Key: %w", err) - } - if err := os.Setenv("AWS_SECRET_ACCESS_KEY", accessKeySecret); err != nil { - return fmt.Errorf("failed to set AWS_SECRET_ACCESS_KEY environment variable: %w", err) - } + if err := os.Setenv("AWS_ACCESS_KEY_ID", accessKeyID); err != nil { + return fmt.Errorf("failed to set AWS_ACCESS_KEY_ID environment variable: %w", err) + } + } + if os.Getenv("AWS_SECRET_ACCESS_KEY") == "" { + accessKeySecret, err := pp.ec.RequestString(ctx, "Enter your AWS Secret Access Key:", "access_key_secret") + if err != nil { + return fmt.Errorf("failed to elicit AWS Secret Access Key: %w", err) + } + if err := os.Setenv("AWS_SECRET_ACCESS_KEY", accessKeySecret); err != nil { + return fmt.Errorf("failed to set AWS_SECRET_ACCESS_KEY environment variable: %w", err) } } return nil @@ -312,6 +528,10 @@ func listAWSProfiles() ([]string, error) { profiles[section] = struct{}{} } } + if err := scanner.Err(); err != nil { + f.Close() + return nil, fmt.Errorf("error reading %s: %w", file, err) + } f.Close() } diff --git a/src/pkg/agent/tools/provider_test.go b/src/pkg/agent/tools/provider_test.go new file mode 100644 index 000000000..509978e31 --- /dev/null +++ b/src/pkg/agent/tools/provider_test.go @@ -0,0 +1,681 @@ +package tools + +import ( + "context" + "iter" + "os" + "strings" + "testing" + "time" + + cliClient "github.com/DefangLabs/defang/src/pkg/cli/client" + "github.com/DefangLabs/defang/src/pkg/stacks" + "github.com/DefangLabs/defang/src/pkg/types" + defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1" + defangv1connect "github.com/DefangLabs/defang/src/protos/io/defang/v1/defangv1connect" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/timestamppb" +) + +// Mock implementations +type mockProviderCreator struct { + mock.Mock +} + +func (m *mockProviderCreator) NewProvider(ctx context.Context, providerId cliClient.ProviderID, client cliClient.FabricClient, stack string) cliClient.Provider { + args := m.Called(ctx, providerId, client, stack) + provider, ok := args.Get(0).(cliClient.Provider) + if !ok { + return nil + } + return provider +} + +type mockElicitationsController struct { + mock.Mock +} + +func (m *mockElicitationsController) RequestString(ctx context.Context, message, field string) (string, error) { + args := m.Called(ctx, message, field) + return args.String(0), args.Error(1) +} + +func (m *mockElicitationsController) RequestStringWithDefault(ctx context.Context, message, field, defaultValue string) (string, error) { + args := m.Called(ctx, message, field, defaultValue) + return args.String(0), args.Error(1) +} + +func (m *mockElicitationsController) RequestEnum(ctx context.Context, message, field string, options []string) (string, error) { + args := m.Called(ctx, message, field, options) + return args.String(0), args.Error(1) +} + +type mockFabricClient struct { + mock.Mock +} + +func (m *mockFabricClient) ListDeployments(ctx context.Context, req *defangv1.ListDeploymentsRequest) (*defangv1.ListDeploymentsResponse, error) { + args := m.Called(ctx, req) + resp, ok := args.Get(0).(*defangv1.ListDeploymentsResponse) + if !ok { + return nil, args.Error(1) + } + return resp, args.Error(1) +} + +// We only need to implement ListDeployments for our tests, so we'll embed the interface +// and only override the method we care about +func (m *mockFabricClient) AgreeToS(context.Context) error { return nil } +func (m *mockFabricClient) CanIUse(context.Context, *defangv1.CanIUseRequest) (*defangv1.CanIUseResponse, error) { + return nil, nil +} +func (m *mockFabricClient) CheckLoginAndToS(context.Context) error { return nil } +func (m *mockFabricClient) Debug(context.Context, *defangv1.DebugRequest) (*defangv1.DebugResponse, error) { + return nil, nil +} +func (m *mockFabricClient) DelegateSubdomainZone(context.Context, *defangv1.DelegateSubdomainZoneRequest) (*defangv1.DelegateSubdomainZoneResponse, error) { + return nil, nil +} +func (m *mockFabricClient) DeleteSubdomainZone(context.Context, *defangv1.DeleteSubdomainZoneRequest) error { + return nil +} +func (m *mockFabricClient) Estimate(context.Context, *defangv1.EstimateRequest) (*defangv1.EstimateResponse, error) { + return nil, nil +} +func (m *mockFabricClient) GenerateCompose(context.Context, *defangv1.GenerateComposeRequest) (*defangv1.GenerateComposeResponse, error) { + return nil, nil +} +func (m *mockFabricClient) GenerateFiles(context.Context, *defangv1.GenerateFilesRequest) (*defangv1.GenerateFilesResponse, error) { + return nil, nil +} +func (m *mockFabricClient) GetController() defangv1connect.FabricControllerClient { return nil } +func (m *mockFabricClient) GetDelegateSubdomainZone(context.Context, *defangv1.GetDelegateSubdomainZoneRequest) (*defangv1.DelegateSubdomainZoneResponse, error) { + return nil, nil +} +func (m *mockFabricClient) GetPlaygroundProjectDomain(context.Context) (*defangv1.GetPlaygroundProjectDomainResponse, error) { + return nil, nil +} +func (m *mockFabricClient) GetSelectedProvider(context.Context, *defangv1.GetSelectedProviderRequest) (*defangv1.GetSelectedProviderResponse, error) { + return nil, nil +} +func (m *mockFabricClient) GetTenantName() types.TenantName { return "" } +func (m *mockFabricClient) GetVersions(context.Context) (*defangv1.Version, error) { return nil, nil } +func (m *mockFabricClient) Preview(context.Context, *defangv1.PreviewRequest) (*defangv1.PreviewResponse, error) { + return nil, nil +} +func (m *mockFabricClient) Publish(context.Context, *defangv1.PublishRequest) error { return nil } +func (m *mockFabricClient) PutDeployment(context.Context, *defangv1.PutDeploymentRequest) error { + return nil +} +func (m *mockFabricClient) RevokeToken(context.Context) error { return nil } +func (m *mockFabricClient) SetSelectedProvider(context.Context, *defangv1.SetSelectedProviderRequest) error { + return nil +} +func (m *mockFabricClient) Token(context.Context, *defangv1.TokenRequest) (*defangv1.TokenResponse, error) { + return nil, nil +} +func (m *mockFabricClient) Track(string, ...cliClient.Property) error { return nil } +func (m *mockFabricClient) VerifyDNSSetup(context.Context, *defangv1.VerifyDNSSetupRequest) error { + return nil +} +func (m *mockFabricClient) WhoAmI(context.Context) (*defangv1.WhoAmIResponse, error) { return nil, nil } + +type mockStacksManager struct { + mock.Mock +} + +func (m *mockStacksManager) Create(params stacks.StackParameters) (string, error) { + args := m.Called(params) + return args.String(0), args.Error(1) +} + +func (m *mockStacksManager) Read(stackName string) (*stacks.StackParameters, error) { + args := m.Called(stackName) + param, ok := args.Get(0).(*stacks.StackParameters) + if !ok { + return nil, args.Error(1) + } + return param, args.Error(1) +} + +func (m *mockStacksManager) LoadParameters(params *stacks.StackParameters) { + m.Called(params) +} + +func (m *mockStacksManager) List() ([]stacks.StackListItem, error) { + args := m.Called() + list, ok := args.Get(0).([]stacks.StackListItem) + if !ok { + return nil, args.Error(1) + } + return list, args.Error(1) +} + +type mockProvider struct { + mock.Mock +} + +// Implement DNSResolver interface +func (m *mockProvider) ServicePrivateDNS(name string) string { return "" } +func (m *mockProvider) ServicePublicDNS(name string, projectName string) string { return "" } +func (m *mockProvider) UpdateShardDomain(ctx context.Context) error { return nil } + +// Implement Provider interface +func (m *mockProvider) AccountInfo(context.Context) (*cliClient.AccountInfo, error) { return nil, nil } +func (m *mockProvider) BootstrapCommand(context.Context, cliClient.BootstrapCommandRequest) (types.ETag, error) { + return "", nil +} +func (m *mockProvider) BootstrapList(context.Context, bool) (iter.Seq[string], error) { + return nil, nil +} +func (m *mockProvider) CreateUploadURL(context.Context, *defangv1.UploadURLRequest) (*defangv1.UploadURLResponse, error) { + return nil, nil +} +func (m *mockProvider) DelayBeforeRetry(context.Context) error { return nil } +func (m *mockProvider) Delete(context.Context, *defangv1.DeleteRequest) (*defangv1.DeleteResponse, error) { + return nil, nil +} +func (m *mockProvider) DeleteConfig(context.Context, *defangv1.Secrets) error { return nil } +func (m *mockProvider) Deploy(context.Context, *defangv1.DeployRequest) (*defangv1.DeployResponse, error) { + return nil, nil +} +func (m *mockProvider) Destroy(context.Context, *defangv1.DestroyRequest) (types.ETag, error) { + return "", nil +} +func (m *mockProvider) GetDeploymentStatus(context.Context) error { return nil } +func (m *mockProvider) GetProjectUpdate(context.Context, string) (*defangv1.ProjectUpdate, error) { + return nil, nil +} +func (m *mockProvider) GetService(context.Context, *defangv1.GetRequest) (*defangv1.ServiceInfo, error) { + return nil, nil +} +func (m *mockProvider) GetServices(context.Context, *defangv1.GetServicesRequest) (*defangv1.GetServicesResponse, error) { + return nil, nil +} +func (m *mockProvider) ListConfig(context.Context, *defangv1.ListConfigsRequest) (*defangv1.Secrets, error) { + return nil, nil +} +func (m *mockProvider) PrepareDomainDelegation(context.Context, cliClient.PrepareDomainDelegationRequest) (*cliClient.PrepareDomainDelegationResponse, error) { + return nil, nil +} +func (m *mockProvider) Preview(context.Context, *defangv1.DeployRequest) (*defangv1.DeployResponse, error) { + return nil, nil +} +func (m *mockProvider) PutConfig(context.Context, *defangv1.PutConfigRequest) error { return nil } +func (m *mockProvider) QueryForDebug(context.Context, *defangv1.DebugRequest) error { return nil } +func (m *mockProvider) QueryLogs(context.Context, *defangv1.TailRequest) (cliClient.ServerStream[defangv1.TailResponse], error) { + return nil, nil +} +func (m *mockProvider) Subscribe(context.Context, *defangv1.SubscribeRequest) (cliClient.ServerStream[defangv1.SubscribeResponse], error) { + return nil, nil +} +func (m *mockProvider) TearDown(context.Context) error { return nil } +func (m *mockProvider) RemoteProjectName(context.Context) (string, error) { return "", nil } +func (m *mockProvider) SetCanIUseConfig(*defangv1.CanIUseResponse) {} +func (m *mockProvider) SetUpCD(context.Context) error { return nil } +func (m *mockProvider) TearDownCD(context.Context) error { return nil } + +// Helper function to create a test stack parameters +func createTestStackParameters(name string, provider cliClient.ProviderID, region string) *stacks.StackParameters { + return &stacks.StackParameters{ + Name: name, + Provider: provider, + Region: region, + } +} + +// Helper function to create deployment response +func createDeploymentResponse(stackName, provider, region string, deployedAt time.Time) *defangv1.ListDeploymentsResponse { + var providerEnum defangv1.Provider + if provider == "defang" { + providerEnum = defangv1.Provider_DEFANG + } else { + providerEnum = defangv1.Provider(defangv1.Provider_value[strings.ToUpper(provider)]) + } + + return &defangv1.ListDeploymentsResponse{ + Deployments: []*defangv1.Deployment{ + { + Stack: stackName, + Provider: providerEnum, + Region: region, + Timestamp: timestamppb.New(deployedAt), + }, + }, + } +} + +type TestCase struct { + name string + inputStackName string + useWkDir bool + localStackExists bool + remoteStackExists bool + hasOtherLocalStacks bool + hasOtherRemoteStacks bool + userSelectsCreateNew bool + userProviderChoice string + userRegionChoice string + userStackNameChoice string + userStackSelection string + expectedStackName string + expectedProvider cliClient.ProviderID + expectedRegion string + expectStackFileWritten bool + expectNewStackCreated bool +} + +func TestSetupProvider(t *testing.T) { + testCases := []TestCase{ + { + name: "stackname provided, stackfile exists locally, no previous deployments, useWkDir true, expect stack to be loaded", + inputStackName: "teststack", + useWkDir: true, + localStackExists: true, + remoteStackExists: false, + expectedStackName: "teststack", + expectedProvider: cliClient.ProviderDefang, + expectedRegion: "", + expectStackFileWritten: false, + expectNewStackCreated: false, + }, + { + name: "stackname provided, stackfile exists locally, no previous deployments, useWkDir false, expect stack to be loaded", + inputStackName: "teststack", + useWkDir: false, + localStackExists: true, + remoteStackExists: true, // Need remote stack for useWkDir=false case + expectedStackName: "teststack", + expectedProvider: cliClient.ProviderDefang, + expectedRegion: "", + expectStackFileWritten: false, + expectNewStackCreated: false, + }, + { + name: "stackname provided, stackfile doesn't exist locally, exists in previous deployments, useWkDir true, expect stackfile written and stack loaded", + inputStackName: "remotestack", + useWkDir: true, + localStackExists: false, + remoteStackExists: true, + expectedStackName: "remotestack", + expectedProvider: cliClient.ProviderDefang, + expectedRegion: "", + expectStackFileWritten: true, + expectNewStackCreated: false, + }, + { + name: "stackname provided, stackfile doesn't exist locally, exists in previous deployments, useWkDir false, expect remote stack loaded but stackfile not written", + inputStackName: "remotestack", + useWkDir: false, + localStackExists: false, + remoteStackExists: true, + expectedStackName: "remotestack", + expectedProvider: cliClient.ProviderDefang, + expectedRegion: "", + expectStackFileWritten: false, + expectNewStackCreated: false, + }, + { + name: "no stackname, no local stackfiles, no deployments, useWkDir true, create stack with stackfile", + inputStackName: "", + useWkDir: true, + localStackExists: false, + remoteStackExists: false, + hasOtherLocalStacks: false, + hasOtherRemoteStacks: false, + userProviderChoice: "Defang Playground", + userStackNameChoice: "newstack", + expectedStackName: "newstack", + expectedProvider: cliClient.ProviderDefang, + expectedRegion: "", + expectStackFileWritten: false, // createNewStack doesn't write the file during tests + expectNewStackCreated: true, + }, + { + name: "no stackname, no local stackfiles, no deployments, useWkDir false, create stack without stackfile", + inputStackName: "", + useWkDir: false, + localStackExists: false, + remoteStackExists: false, + hasOtherLocalStacks: false, + hasOtherRemoteStacks: false, + userProviderChoice: "Defang Playground", + userStackNameChoice: "newstack", + expectedStackName: "newstack", + expectedProvider: cliClient.ProviderDefang, + expectedRegion: "", + expectStackFileWritten: false, + expectNewStackCreated: true, + }, + { + name: "no stackname, local stackfile exists, useWkDir true, user selects existing stackfile", + inputStackName: "", + useWkDir: true, + localStackExists: false, + remoteStackExists: false, + hasOtherLocalStacks: true, + hasOtherRemoteStacks: false, + userStackSelection: "existingstack", + expectedStackName: "existingstack", + expectedProvider: cliClient.ProviderDefang, + expectedRegion: "", + expectStackFileWritten: false, + expectNewStackCreated: false, + }, + { + name: "no stackname, local stackfile exists, useWkDir true, user creates new stack", + inputStackName: "", + useWkDir: true, + localStackExists: false, + remoteStackExists: false, + hasOtherLocalStacks: true, + hasOtherRemoteStacks: false, + userSelectsCreateNew: true, + userProviderChoice: "Defang Playground", + userStackNameChoice: "brandnew", + expectedStackName: "brandnew", + expectedProvider: cliClient.ProviderDefang, + expectedRegion: "", + expectStackFileWritten: false, + expectNewStackCreated: true, + }, + { + name: "no stackname, local stackfile exists, useWkDir false, create new stack without stackfile", + inputStackName: "", + useWkDir: false, + localStackExists: false, + remoteStackExists: false, + hasOtherLocalStacks: true, + hasOtherRemoteStacks: false, + userSelectsCreateNew: true, + userProviderChoice: "Defang Playground", + userStackNameChoice: "newdefang", + expectedStackName: "newdefang", + expectedProvider: cliClient.ProviderDefang, + expectedRegion: "", + expectStackFileWritten: false, + expectNewStackCreated: true, + }, + { + name: "no stackname, previous deployment exists, useWkDir true, user selects existing deployment", + inputStackName: "", + useWkDir: true, + localStackExists: false, + remoteStackExists: false, + hasOtherLocalStacks: false, + hasOtherRemoteStacks: true, + userStackSelection: "remotestack (last deployed TIME)", + expectedStackName: "remotestack", + expectedProvider: cliClient.ProviderDefang, + expectedRegion: "", + expectStackFileWritten: true, + expectNewStackCreated: false, + }, + { + name: "no stackname, previous deployment exists, useWkDir false, user selects existing deployment", + inputStackName: "", + useWkDir: false, + localStackExists: false, + remoteStackExists: false, + hasOtherLocalStacks: false, + hasOtherRemoteStacks: true, + userStackSelection: "remotestack (last deployed TIME)", + expectedStackName: "remotestack", + expectedProvider: cliClient.ProviderDefang, + expectedRegion: "", + expectStackFileWritten: false, + expectNewStackCreated: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + projectName := "test-project" + stackName := tc.inputStackName + deployTime := time.Now().Add(-24 * time.Hour) + + // Create fresh mocks for each test + mockPC := &mockProviderCreator{} + mockEC := &mockElicitationsController{} + mockFC := &mockFabricClient{} + mockSM := &mockStacksManager{} + mockProv := &mockProvider{} + + pp := &providerPreparer{ + pc: mockPC, + ec: mockEC, + fc: mockFC, + sm: mockSM, + } + + // Set up mocks based on test case + setupMocks(t, tc, ctx, projectName, deployTime, mockPC, mockEC, mockFC, mockSM, mockProv) + + // Call the function under test - SetupProvider + providerID, provider, err := pp.SetupProvider(ctx, projectName, &stackName, tc.useWkDir) + + // Assertions + require.NoError(t, err) + require.NotNil(t, providerID) + require.Equal(t, tc.expectedProvider, *providerID) + require.Equal(t, mockProv, provider) + require.Equal(t, tc.expectedStackName, stackName) + + // Verify mocks were called as expected + mockSM.AssertExpectations(t) + mockFC.AssertExpectations(t) + mockPC.AssertExpectations(t) + if tc.userProviderChoice != "" || tc.userStackSelection != "" || tc.userStackNameChoice != "" { + mockEC.AssertExpectations(t) + } + }) + } +} + +func setupMocks(t *testing.T, tc TestCase, ctx context.Context, projectName string, deployTime time.Time, mockPC *mockProviderCreator, mockEC *mockElicitationsController, mockFC *mockFabricClient, mockSM *mockStacksManager, mockProv *mockProvider) { + t.Helper() + + // Handle stackname provided scenarios + if tc.inputStackName != "" { + expectedStack := createTestStackParameters(tc.inputStackName, tc.expectedProvider, tc.expectedRegion) + + if tc.localStackExists && tc.useWkDir { + // When useWkDir=true, it first tries to Read from local + mockSM.On("Read", tc.inputStackName).Return(expectedStack, nil) + } else if tc.remoteStackExists || !tc.useWkDir { + // When useWkDir=false, it goes directly to importRemoteStack + // When useWkDir=true but local doesn't exist, it tries to Read first then importRemoteStack + if tc.useWkDir && !tc.localStackExists { + mockSM.On("Read", tc.inputStackName).Return((*stacks.StackParameters)(nil), os.ErrNotExist) + } + + // Mock ListDeployments for importRemoteStack + if tc.remoteStackExists { + mockFC.On("ListDeployments", ctx, &defangv1.ListDeploymentsRequest{ + Project: projectName, + }).Return(createDeploymentResponse(tc.inputStackName, "defang", tc.expectedRegion, deployTime), nil) + } else { + mockFC.On("ListDeployments", ctx, &defangv1.ListDeploymentsRequest{ + Project: projectName, + }).Return(&defangv1.ListDeploymentsResponse{Deployments: []*defangv1.Deployment{}}, nil) + } + + if tc.expectStackFileWritten { + mockSM.On("Create", *expectedStack).Return(".defang/"+tc.inputStackName, nil) + } + } + + mockSM.On("LoadParameters", expectedStack).Return() + + // When useWkDir is true, SetupProvider always calls Create to ensure the stack file exists + if tc.useWkDir { + mockSM.On("Create", *expectedStack).Return(".defang/"+tc.inputStackName, nil) + } + + mockPC.On("NewProvider", ctx, tc.expectedProvider, mockFC, tc.inputStackName).Return(mockProv) + return + } + + // Handle no stackname scenarios + deployments := []*defangv1.Deployment{} + if tc.hasOtherRemoteStacks { + deployments = append(deployments, &defangv1.Deployment{ + Stack: "remotestack", + Provider: defangv1.Provider_DEFANG, + Region: tc.expectedRegion, + Timestamp: timestamppb.New(deployTime), + }) + } + mockFC.On("ListDeployments", ctx, &defangv1.ListDeploymentsRequest{ + Project: projectName, + }).Return(&defangv1.ListDeploymentsResponse{Deployments: deployments}, nil) + + // Mock List for local stacks (only called when useWkDir=true) + if tc.useWkDir { + localStacks := []stacks.StackListItem{} + if tc.hasOtherLocalStacks { + localStacks = append(localStacks, stacks.StackListItem{ + Name: "existingstack", + Provider: tc.expectedProvider.String(), + Region: tc.expectedRegion, + }) + } + mockSM.On("List").Return(localStacks, nil) + } + + // Mock stack selection or creation + // Local stacks only matter when useWkDir is true + if (tc.hasOtherLocalStacks && tc.useWkDir) || tc.hasOtherRemoteStacks { + stackOptions := []string{} + if tc.hasOtherLocalStacks && tc.useWkDir { + stackOptions = append(stackOptions, "existingstack") + } + if tc.hasOtherRemoteStacks { + label := "remotestack (last deployed " + deployTime.Local().Format(time.RFC822) + ")" + stackOptions = append(stackOptions, label) + // Update the expected selection to match the actual format + if strings.Contains(tc.userStackSelection, "remotestack") { + tc.userStackSelection = label + } + } + if tc.useWkDir { + stackOptions = append(stackOptions, CreateNewStack) + } + + selectedOption := tc.userStackSelection + if tc.userSelectsCreateNew { + selectedOption = CreateNewStack + } + mockEC.On("RequestEnum", ctx, "Select a stack", "stack", stackOptions).Return(selectedOption, nil) + + if tc.userSelectsCreateNew { + setupNewStackCreationMocks(tc, ctx, mockEC, mockSM) + } else if tc.userStackSelection == "existingstack" { + setupExistingLocalStackMocks("existingstack", tc.expectedProvider, tc.expectedRegion, mockSM) + // Add Create mock for useWkDir case + if tc.useWkDir { + expectedStack := createTestStackParameters("existingstack", tc.expectedProvider, tc.expectedRegion) + mockSM.On("Create", *expectedStack).Return(".defang/existingstack", nil) + } + } else if strings.Contains(tc.userStackSelection, "remotestack") { + setupExistingRemoteStackMocks("remotestack", tc.expectedProvider, tc.expectedRegion, tc.expectStackFileWritten, mockSM) + // Add Create mock for useWkDir case + if tc.useWkDir && !tc.expectStackFileWritten { + expectedStack := createTestStackParameters("remotestack", tc.expectedProvider, tc.expectedRegion) + mockSM.On("Create", *expectedStack).Return(".defang/remotestack", nil) + } + } + } else if shouldCreateNewStack(tc) { + // No stack selection needed, directly create new stack + setupNewStackCreationMocks(tc, ctx, mockEC, mockSM) + } + + // For new stack creation scenarios, need to mock LoadParameters + expectedStack := createTestStackParameters(tc.expectedStackName, tc.expectedProvider, tc.expectedRegion) + if tc.expectNewStackCreated { + mockSM.On("LoadParameters", expectedStack).Return() + + // When useWkDir=true, SetupProvider always calls Create to ensure the stack file exists + if tc.useWkDir { + // Note: Create call is already mocked in setupNewStackCreationMocks + // but we may need a second call from SetupProvider line 100 + mockSM.On("Create", *expectedStack).Return(".defang/"+tc.expectedStackName, nil) + } + } + + mockPC.On("NewProvider", ctx, tc.expectedProvider, mockFC, tc.expectedStackName).Return(mockProv) +} + +func setupExistingLocalStackMocks(stackName string, expectedProvider cliClient.ProviderID, expectedRegion string, mockSM *mockStacksManager) { + expectedStack := createTestStackParameters(stackName, expectedProvider, expectedRegion) + mockSM.On("Read", stackName).Return(expectedStack, nil) + mockSM.On("LoadParameters", expectedStack).Return() +} + +func setupExistingRemoteStackMocks(stackName string, expectedProvider cliClient.ProviderID, expectedRegion string, expectStackFileWritten bool, mockSM *mockStacksManager) { + expectedStack := createTestStackParameters(stackName, expectedProvider, expectedRegion) + if expectStackFileWritten { + mockSM.On("Create", *expectedStack).Return(".defang/"+stackName, nil) + } + mockSM.On("LoadParameters", expectedStack).Return() +} + +func shouldCreateNewStack(tc TestCase) bool { + // User explicitly wants to create new stack + if tc.userSelectsCreateNew { + return true + } + + // When useWkDir=false and there are existing local stacks, still create new + if !tc.useWkDir && tc.hasOtherLocalStacks { + return true + } + + // No existing stacks anywhere, so must create new + if !tc.hasOtherLocalStacks && !tc.hasOtherRemoteStacks { + return true + } + + return false +} + +func setupNewStackCreationMocks(tc TestCase, ctx context.Context, mockEC *mockElicitationsController, mockSM *mockStacksManager) { + providerNames := []string{"Defang Playground", "AWS", "DigitalOcean", "Google Cloud Platform"} + mockEC.On("RequestEnum", ctx, "Where do you want to deploy?", "provider", providerNames).Return(tc.userProviderChoice, nil) + + if tc.userProviderChoice != "Defang Playground" { + var defaultRegion string + switch tc.userProviderChoice { + case "AWS": + defaultRegion = "us-east-1" + case "Google Cloud Platform": + defaultRegion = "us-central1" + case "DigitalOcean": + defaultRegion = "nyc1" + } + mockEC.On("RequestStringWithDefault", ctx, "Which region do you want to deploy to?", "region", defaultRegion).Return(tc.userRegionChoice, nil) + } + + var defaultName string + switch tc.userProviderChoice { + case "Defang Playground": + defaultName = "defang" + default: + // For other providers, would need region processing + defaultName = "defang" + } + + mockEC.On("RequestStringWithDefault", ctx, "Enter a name for your stack:", "stack_name", defaultName).Return(tc.userStackNameChoice, nil) + + // Mock Create for when useWkDir is true (this will be called by createNewStack) + if tc.useWkDir { + expectedParams := stacks.StackParameters{ + Provider: tc.expectedProvider, + Region: tc.expectedRegion, + Name: tc.userStackNameChoice, + } + mockSM.On("Create", expectedParams).Return(".defang/"+tc.userStackNameChoice, nil) + } +} diff --git a/src/pkg/agent/tools/removeConfig.go b/src/pkg/agent/tools/removeConfig.go index 5a96cc0a0..4e0484910 100644 --- a/src/pkg/agent/tools/removeConfig.go +++ b/src/pkg/agent/tools/removeConfig.go @@ -21,18 +21,17 @@ func HandleRemoveConfigTool(ctx context.Context, loader cliClient.ProjectLoader, term.Debug("Function invoked: cli.Connect") client, err := cli.Connect(ctx, sc.Cluster) if err != nil { - return "", fmt.Errorf("Could not connect: %w", err) + return "", fmt.Errorf("could not connect: %w", err) } - pp := NewProviderPreparer(cli, ec, client) - _, provider, err := pp.SetupProvider(ctx, sc.Stack) + projectName, err := cli.LoadProjectName(ctx, loader) if err != nil { - return "", fmt.Errorf("failed to setup provider: %w", err) + return "", fmt.Errorf("failed to load project name: %w", err) } - term.Debug("Function invoked: cliClient.LoadProjectNameWithFallback") - projectName, err := cli.LoadProjectNameWithFallback(ctx, loader, provider) + pp := NewProviderPreparer(cli, ec, client) + _, provider, err := pp.SetupProvider(ctx, projectName, sc.Stack, false) if err != nil { - return "", fmt.Errorf("failed to load project name: %w", err) + return "", fmt.Errorf("failed to setup provider: %w", err) } if err := cli.ConfigDelete(ctx, projectName, provider, params.Name); err != nil { // Show a warning (not an error) if the config was not found diff --git a/src/pkg/agent/tools/removeConfig_test.go b/src/pkg/agent/tools/removeConfig_test.go index 73ad5e770..172cfe394 100644 --- a/src/pkg/agent/tools/removeConfig_test.go +++ b/src/pkg/agent/tools/removeConfig_test.go @@ -10,8 +10,10 @@ import ( "github.com/DefangLabs/defang/src/pkg/cli/client" "github.com/DefangLabs/defang/src/pkg/cli/compose" "github.com/DefangLabs/defang/src/pkg/elicitations" + defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1" "github.com/bufbuild/connect-go" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -26,12 +28,12 @@ type MockRemoveConfigCLI struct { CallLog []string } -func (m *MockRemoveConfigCLI) Connect(ctx context.Context, cluster string) (*client.GrpcClient, error) { +func (m *MockRemoveConfigCLI) Connect(ctx context.Context, cluster string) (client.FabricClient, error) { m.CallLog = append(m.CallLog, fmt.Sprintf("Connect(%s)", cluster)) if m.ConnectError != nil { return nil, m.ConnectError } - return &client.GrpcClient{}, nil + return mockFC, nil } func (m *MockRemoveConfigCLI) NewProvider(ctx context.Context, providerId client.ProviderID, client client.FabricClient, stack string) client.Provider { @@ -39,8 +41,8 @@ func (m *MockRemoveConfigCLI) NewProvider(ctx context.Context, providerId client return nil // Mock provider } -func (m *MockRemoveConfigCLI) LoadProjectNameWithFallback(ctx context.Context, loader client.Loader, provider client.Provider) (string, error) { - m.CallLog = append(m.CallLog, "LoadProjectNameWithFallback") +func (m *MockRemoveConfigCLI) LoadProjectName(ctx context.Context, loader client.Loader) (string, error) { + m.CallLog = append(m.CallLog, "LoadProjectName") if m.LoadProjectNameError != nil { return "", m.LoadProjectNameError } @@ -56,6 +58,7 @@ func (m *MockRemoveConfigCLI) ConfigDelete(ctx context.Context, projectName stri } func TestHandleRemoveConfigTool(t *testing.T) { + mockFC = &mockFabricClient{} tests := []struct { name string configName string @@ -72,7 +75,7 @@ func TestHandleRemoveConfigTool(t *testing.T) { m.ConnectError = errors.New("connection failed") }, expectError: true, - expectedError: "Could not connect: connection failed", + expectedError: "could not connect: connection failed", }, { name: "load_project_name_error", @@ -145,12 +148,23 @@ func TestHandleRemoveConfigTool(t *testing.T) { } ec := elicitations.NewController(&mockElicitationsClient{ responses: map[string]string{ - "strategy": "profile", "profile_name": "default", }, }) provider := client.ProviderAWS stackName := "test-stack" + mockFC.On("ListDeployments", mock.Anything, mock.Anything).Return(&defangv1.ListDeploymentsResponse{ + Deployments: []*defangv1.Deployment{ + { + Id: "deployment-123", + Project: "test-project", + Stack: stackName, + Region: "us-test-2", + Provider: defangv1.Provider_AWS, + ProviderAccountId: "123456789012", + }, + }, + }, nil) result, err := HandleRemoveConfigTool(t.Context(), loader, params, mockCLI, ec, StackConfig{ Cluster: "test-cluster", ProviderID: &provider, @@ -174,8 +188,8 @@ func TestHandleRemoveConfigTool(t *testing.T) { if !tt.expectError && tt.name == "successful_config_removal" { expectedCalls := []string{ "Connect(test-cluster)", + "LoadProjectName", "NewProvider(aws)", - "LoadProjectNameWithFallback", "ConfigDelete(test-project, DATABASE_URL)", } assert.Equal(t, expectedCalls, mockCLI.CallLog) diff --git a/src/pkg/agent/tools/services.go b/src/pkg/agent/tools/services.go index d12222229..16539285c 100644 --- a/src/pkg/agent/tools/services.go +++ b/src/pkg/agent/tools/services.go @@ -26,20 +26,19 @@ func HandleServicesTool(ctx context.Context, loader cliClient.ProjectLoader, cli return "", fmt.Errorf("could not connect: %w", err) } - pp := NewProviderPreparer(cli, ec, client) - _, provider, err := pp.SetupProvider(ctx, config.Stack) - if err != nil { - return "", fmt.Errorf("failed to setup provider: %w", err) - } - term.Debug("Function invoked: cli.LoadProjectNameWithFallback") - projectName, err := cli.LoadProjectNameWithFallback(ctx, loader, provider) - term.Debugf("Project name loaded: %s", projectName) + projectName, err := cli.LoadProjectName(ctx, loader) if err != nil { if strings.Contains(err.Error(), "no projects found") { return "no projects found on Playground", nil } return "", fmt.Errorf("failed to load project name: %w", err) } + term.Debugf("Project name loaded: %s", projectName) + pp := NewProviderPreparer(cli, ec, client) + _, provider, err := pp.SetupProvider(ctx, projectName, config.Stack, false) + if err != nil { + return "", fmt.Errorf("failed to setup provider: %w", err) + } serviceResponse, err := cli.GetServices(ctx, projectName, provider) if err != nil { diff --git a/src/pkg/agent/tools/services_test.go b/src/pkg/agent/tools/services_test.go index ff9d3b050..cd648e350 100644 --- a/src/pkg/agent/tools/services_test.go +++ b/src/pkg/agent/tools/services_test.go @@ -16,17 +16,18 @@ import ( defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1" "github.com/bufbuild/connect-go" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) // MockCLI implements CLIInterface for testing type MockCLI struct { CLIInterface - ConnectError error - LoadProjectNameWithFallbackError error - MockClient *client.GrpcClient - MockProvider client.Provider - MockProjectName string + ConnectError error + LoadProjectNameError error + MockClient client.FabricClient + MockProvider client.Provider + MockProjectName string GetServicesError error MockServices []deployment_info.Service @@ -35,7 +36,7 @@ type MockCLI struct { GetServicesProvider client.Provider } -func (m *MockCLI) Connect(ctx context.Context, cluster string) (*client.GrpcClient, error) { +func (m *MockCLI) Connect(ctx context.Context, cluster string) (client.FabricClient, error) { if m.ConnectError != nil { return nil, m.ConnectError } @@ -46,9 +47,9 @@ func (m *MockCLI) NewProvider(ctx context.Context, providerId client.ProviderID, return m.MockProvider } -func (m *MockCLI) LoadProjectNameWithFallback(ctx context.Context, loader client.Loader, provider client.Provider) (string, error) { - if m.LoadProjectNameWithFallbackError != nil { - return "", m.LoadProjectNameWithFallbackError +func (m *MockCLI) LoadProjectName(ctx context.Context, loader client.Loader) (string, error) { + if m.LoadProjectNameError != nil { + return "", m.LoadProjectNameError } if m.MockProjectName != "" { return m.MockProjectName, nil @@ -66,11 +67,11 @@ func (m *MockCLI) GetServices(ctx context.Context, projectName string, provider return m.MockServices, nil } -func (m *MockCLI) ComposeDown(ctx context.Context, projectName string, client *client.GrpcClient, provider client.Provider) (string, error) { +func (m *MockCLI) ComposeDown(ctx context.Context, projectName string, client client.FabricClient, provider client.Provider) (string, error) { return "", nil } -func (m *MockCLI) ComposeUp(ctx context.Context, client *client.GrpcClient, provider client.Provider, params defangcli.ComposeUpParams) (*defangv1.DeployResponse, *compose.Project, error) { +func (m *MockCLI) ComposeUp(ctx context.Context, client client.FabricClient, provider client.Provider, params defangcli.ComposeUpParams) (*defangv1.DeployResponse, *compose.Project, error) { return nil, nil, nil } @@ -82,7 +83,7 @@ func (m *MockCLI) ConfigSet(ctx context.Context, projectName string, provider cl return nil } -func (m *MockCLI) CreatePlaygroundProvider(client *client.GrpcClient) client.Provider { +func (m *MockCLI) CreatePlaygroundProvider(client client.FabricClient) client.Provider { return m.MockProvider } @@ -90,7 +91,7 @@ func (m *MockCLI) GenerateAuthURL(authPort int) string { return "" } -func (m *MockCLI) InteractiveLoginMCP(ctx context.Context, client *client.GrpcClient, cluster string, mcpClient string) error { +func (m *MockCLI) InteractiveLoginMCP(ctx context.Context, client client.FabricClient, cluster string, mcpClient string) error { return nil } @@ -106,7 +107,7 @@ func (m *MockCLI) PrintEstimate(mode modes.Mode, estimate *defangv1.EstimateResp return "" } -func (m *MockCLI) RunEstimate(ctx context.Context, project *compose.Project, client *client.GrpcClient, provider client.Provider, providerId client.ProviderID, region string, mode modes.Mode) (*defangv1.EstimateResponse, error) { +func (m *MockCLI) RunEstimate(ctx context.Context, project *compose.Project, client client.FabricClient, provider client.Provider, providerId client.ProviderID, region string, mode modes.Mode) (*defangv1.EstimateResponse, error) { return nil, nil } @@ -150,6 +151,7 @@ func (m *mockElicitationsClient) Request(ctx context.Context, req elicitations.R } func TestHandleServicesToolWithMockCLI(t *testing.T) { + mockFC = &mockFabricClient{} tests := []struct { name string providerId client.ProviderID @@ -177,9 +179,9 @@ func TestHandleServicesToolWithMockCLI(t *testing.T) { name: "load_project_name_error", providerId: client.ProviderDefang, mockCLI: &MockCLI{ - MockClient: &client.GrpcClient{}, - MockProvider: &client.PlaygroundProvider{}, - LoadProjectNameWithFallbackError: errors.New("failed to load project name"), + MockClient: mockFC, + MockProvider: &client.PlaygroundProvider{}, + LoadProjectNameError: errors.New("failed to load project name"), }, expectedError: true, @@ -192,7 +194,7 @@ func TestHandleServicesToolWithMockCLI(t *testing.T) { name: "get_services_no_services_error", providerId: client.ProviderDefang, mockCLI: &MockCLI{ - MockClient: &client.GrpcClient{}, + MockClient: mockFC, MockProvider: &client.PlaygroundProvider{}, MockProjectName: "test-project", GetServicesError: defangcli.ErrNoServices{ProjectName: "test-project"}, @@ -206,7 +208,7 @@ func TestHandleServicesToolWithMockCLI(t *testing.T) { name: "get_services_project_not_deployed", providerId: client.ProviderDefang, mockCLI: &MockCLI{ - MockClient: &client.GrpcClient{}, + MockClient: mockFC, MockProvider: &client.PlaygroundProvider{}, MockProjectName: "test-project", GetServicesError: createConnectError(connect.CodeNotFound, "project test-project is not deployed in Playground"), @@ -220,7 +222,7 @@ func TestHandleServicesToolWithMockCLI(t *testing.T) { name: "get_services_generic_error", providerId: client.ProviderDefang, mockCLI: &MockCLI{ - MockClient: &client.GrpcClient{}, + MockClient: mockFC, MockProvider: &client.PlaygroundProvider{}, MockProjectName: "test-project", GetServicesError: errors.New("generic GetServices failure"), @@ -235,7 +237,7 @@ func TestHandleServicesToolWithMockCLI(t *testing.T) { name: "successful_cli_operations_until_get_services", providerId: client.ProviderDefang, mockCLI: &MockCLI{ - MockClient: &client.GrpcClient{}, + MockClient: mockFC, MockProvider: &client.PlaygroundProvider{}, MockProjectName: "test-project", MockServices: []deployment_info.Service{ @@ -263,11 +265,22 @@ func TestHandleServicesToolWithMockCLI(t *testing.T) { loader := &client.MockLoader{} ec := elicitations.NewController(&mockElicitationsClient{ responses: map[string]string{ - "strategy": "profile", "profile_name": "default", }, }) stackName := "test-stack" + mockFC.On("ListDeployments", mock.Anything, mock.Anything).Return(&defangv1.ListDeploymentsResponse{ + Deployments: []*defangv1.Deployment{ + { + Id: "deployment-123", + Project: "test-project", + Stack: stackName, + Region: "us-test-2", + Provider: defangv1.Provider_AWS, + ProviderAccountId: "123456789012", + }, + }, + }, nil) result, err := HandleServicesTool(t.Context(), loader, tt.mockCLI, ec, StackConfig{ Cluster: "test-cluster", ProviderID: &tt.providerId, diff --git a/src/pkg/agent/tools/setConfig.go b/src/pkg/agent/tools/setConfig.go index 0cf54bd2f..049286f29 100644 --- a/src/pkg/agent/tools/setConfig.go +++ b/src/pkg/agent/tools/setConfig.go @@ -21,24 +21,23 @@ func HandleSetConfig(ctx context.Context, loader cliClient.ProjectLoader, params term.Debug("Function invoked: cli.Connect") client, err := cli.Connect(ctx, sc.Cluster) if err != nil { - return "", fmt.Errorf("Could not connect: %w", err) + return "", fmt.Errorf("could not connect: %w", err) } pp := NewProviderPreparer(cli, ec, client) - _, provider, err := pp.SetupProvider(ctx, sc.Stack) - if err != nil { - return "", fmt.Errorf("failed to setup provider: %w", err) - } - if params.ProjectName == "" { - term.Debug("Function invoked: cliClient.LoadProjectNameWithFallback") - projectName, err := cli.LoadProjectNameWithFallback(ctx, loader, provider) + projectName, err := cli.LoadProjectName(ctx, loader) if err != nil { return "", fmt.Errorf("failed to load project name: %w", err) } params.ProjectName = projectName } + _, provider, err := pp.SetupProvider(ctx, params.ProjectName, sc.Stack, false) + if err != nil { + return "", fmt.Errorf("failed to setup provider: %w", err) + } + if !pkg.IsValidSecretName(params.Name) { return "", fmt.Errorf("Invalid config name: secret name %q is not valid", params.Name) } diff --git a/src/pkg/agent/tools/setConfig_test.go b/src/pkg/agent/tools/setConfig_test.go index 04a15edb6..5a1637860 100644 --- a/src/pkg/agent/tools/setConfig_test.go +++ b/src/pkg/agent/tools/setConfig_test.go @@ -8,7 +8,9 @@ import ( "github.com/DefangLabs/defang/src/pkg/cli/client" "github.com/DefangLabs/defang/src/pkg/elicitations" + defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" ) @@ -22,7 +24,6 @@ type MockSetConfigCLI struct { NewProviderCalled bool LoadProjectNameCalled bool ConfigSetCalled bool - ReturnedGrpcClient *client.GrpcClient ReturnedProvider client.Provider ReturnedProjectName string ConfigSetProjectName string @@ -31,16 +32,12 @@ type MockSetConfigCLI struct { ConfigSetValue string } -func (m *MockSetConfigCLI) Connect(ctx context.Context, cluster string) (*client.GrpcClient, error) { +func (m *MockSetConfigCLI) Connect(ctx context.Context, cluster string) (client.FabricClient, error) { m.ConnectCalled = true if m.ConnectError != nil { return nil, m.ConnectError } - if m.ReturnedGrpcClient == nil { - // Return a non-nil client to avoid nil pointer issues - m.ReturnedGrpcClient = &client.GrpcClient{} - } - return m.ReturnedGrpcClient, nil + return mockFC, nil } func (m *MockSetConfigCLI) NewProvider(ctx context.Context, providerId client.ProviderID, fabricClient client.FabricClient, stack string) client.Provider { @@ -62,7 +59,7 @@ func (p *MockProvider) AccountInfo(context.Context) (*client.AccountInfo, error) return &client.AccountInfo{}, nil } -func (m *MockSetConfigCLI) LoadProjectNameWithFallback(ctx context.Context, loader client.Loader, provider client.Provider) (string, error) { +func (m *MockSetConfigCLI) LoadProjectName(ctx context.Context, loader client.Loader) (string, error) { m.LoadProjectNameCalled = true if m.LoadProjectNameError != nil { return "", m.LoadProjectNameError @@ -83,6 +80,7 @@ func (m *MockSetConfigCLI) ConfigSet(ctx context.Context, projectName string, pr } func TestHandleSetConfig(t *testing.T) { + mockFC = &mockFabricClient{} // Common test data const ( testCluster = "test-cluster" @@ -161,7 +159,7 @@ func TestHandleSetConfig(t *testing.T) { requestArgs: map[string]interface{}{"name": testConfigName, "value": testValue}, mockCLI: &MockSetConfigCLI{ConnectError: errors.New("connection failed")}, expectedError: true, - errorMessage: "Could not connect: connection failed", + errorMessage: "could not connect: connection failed", expectedConnectCalls: true, }, { @@ -173,7 +171,7 @@ func TestHandleSetConfig(t *testing.T) { expectedError: true, errorMessage: "failed to load project name: project loading failed", expectedConnectCalls: true, - expectedProviderCalls: true, + expectedProviderCalls: false, expectedProjectNameCalls: true, }, { @@ -242,11 +240,22 @@ func TestHandleSetConfig(t *testing.T) { } ec := elicitations.NewController(&mockElicitationsClient{ responses: map[string]string{ - "strategy": "profile", "profile_name": "default", }, }) stackName := "test-stack" + mockFC.On("ListDeployments", mock.Anything, mock.Anything).Return(&defangv1.ListDeploymentsResponse{ + Deployments: []*defangv1.Deployment{ + { + Id: "deployment-123", + Project: "test-project", + Stack: stackName, + Region: "us-test-2", + Provider: defangv1.Provider_AWS, + ProviderAccountId: "123456789012", + }, + }, + }, nil) result, err := HandleSetConfig(t.Context(), loader, params, tt.mockCLI, ec, StackConfig{ Cluster: tt.cluster, ProviderID: &tt.providerId, diff --git a/src/pkg/cli/client/provider_id.go b/src/pkg/cli/client/provider_id.go index 46c4b1038..4de971cfe 100644 --- a/src/pkg/cli/client/provider_id.go +++ b/src/pkg/cli/client/provider_id.go @@ -75,7 +75,7 @@ func (p *ProviderID) Set(str string) error { } } - return fmt.Errorf("provider not one of %v", allProviders) + return fmt.Errorf("provider not one of %v", AllProviders()) } func (p *ProviderID) SetValue(val defangv1.Provider) { diff --git a/src/pkg/migrate/heroku.go b/src/pkg/migrate/heroku.go index beb393cc9..5b1f33d56 100644 --- a/src/pkg/migrate/heroku.go +++ b/src/pkg/migrate/heroku.go @@ -334,10 +334,10 @@ func authenticateHerokuCLI() error { } // getHerokuAuthTokenFromCLI obtains a short-lived Heroku API token by invoking the local Heroku CLI. -// +// // It checks that the `heroku` executable is available, ensures the CLI is authenticated, runs // `heroku authorizations:create --expires-in=300 --json`, and parses the resulting JSON for the token. -// +// // The returned string is the extracted access token. An error is returned if the CLI is not installed, // authentication fails, the command cannot be executed, or the command output cannot be parsed. func getHerokuAuthTokenFromCLI() (string, error) { diff --git a/src/pkg/stacks/stacks.go b/src/pkg/stacks/stacks.go index 9eb0b7038..d84134738 100644 --- a/src/pkg/stacks/stacks.go +++ b/src/pkg/stacks/stacks.go @@ -144,7 +144,7 @@ func Parse(content string) (StackParameters, error) { return params, nil } -func Marshal(params StackParameters) (string, error) { +func paramsToMap(params *StackParameters) map[string]string { var properties map[string]string = make(map[string]string) properties["DEFANG_PROVIDER"] = strings.ToLower(params.Provider.String()) if params.Region != "" { @@ -162,6 +162,11 @@ func Marshal(params StackParameters) (string, error) { if params.Mode != modes.ModeUnspecified { properties["DEFANG_MODE"] = strings.ToLower(params.Mode.String()) } + return properties +} + +func Marshal(params StackParameters) (string, error) { + properties := paramsToMap(¶ms) return godotenv.Marshal(properties) } @@ -207,6 +212,23 @@ func Load(name string) error { return nil } +func LoadParameters(params *StackParameters) { + // copied from godotenv Load function with slight modification to load from StackParameters + currentEnv := map[string]bool{} + rawEnv := os.Environ() + for _, rawEnvLine := range rawEnv { + key := strings.Split(rawEnvLine, "=")[0] + currentEnv[key] = true + } + + properties := paramsToMap(params) + for key, value := range properties { + if !currentEnv[key] { + _ = os.Setenv(key, value) + } + } +} + func PostCreateMessage(stackName string) string { return fmt.Sprintf( "A stackfile has been created at `.defang/%s`.\n"+