diff --git a/internal/cli/inference.go b/internal/cli/inference.go new file mode 100644 index 000000000..6aa62b653 --- /dev/null +++ b/internal/cli/inference.go @@ -0,0 +1,406 @@ +package cli + +import ( + "encoding/json" + "fmt" + "regexp" + "strings" + + "github.com/spf13/cobra" + + "github.com/fullsend-ai/fullsend/internal/dispatch/gcf" + "github.com/fullsend-ai/fullsend/internal/ui" +) + +var gcpProjectPattern = regexp.MustCompile(`^[a-z][a-z0-9-]{4,28}[a-z0-9]$`) + +func newInferenceCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "inference", + Short: "Manage inference credentials (requires GCP access)", + Long: `Commands for provisioning and inspecting inference WIF infrastructure. + +These commands only require GCP project access — no GitHub token or +mint project is needed. Use them to set up Workload Identity Federation +for Vertex AI inference, then hand off the WIF provider resource name +to the GitHub admin who runs 'fullsend admin install'.`, + } + cmd.AddCommand(newInferenceProvisionCmd()) + cmd.AddCommand(newInferenceStatusCmd()) + return cmd +} + +// parseOrgOrRepo determines whether the argument is an org name or owner/repo. +// Returns (org, "", nil) for org-scoped or (owner, "owner/repo", nil) for repo-scoped. +func parseOrgOrRepo(arg string) (org string, repo string, err error) { + if strings.Contains(arg, "/") { + parts := strings.SplitN(arg, "/", 2) + owner, repoName := parts[0], parts[1] + if owner == "" || repoName == "" { + return "", "", fmt.Errorf("invalid repo format: expected owner/repo, got %q", arg) + } + if !githubOwnerPattern.MatchString(owner) { + return "", "", fmt.Errorf("invalid owner name %q: must contain only alphanumeric characters and hyphens", owner) + } + if !githubRepoPattern.MatchString(repoName) { + return "", "", fmt.Errorf("invalid repo name %q: must contain only alphanumeric characters, hyphens, dots, or underscores", repoName) + } + return owner, arg, nil + } + + if err := validateOrgName(arg); err != nil { + return "", "", err + } + return arg, "", nil +} + +func newInferenceProvisionCmd() *cobra.Command { + var project string + var pool string + var provider string + var dryRun bool + + cmd := &cobra.Command{ + Use: "provision ", + Short: "Create WIF infrastructure for inference", + Long: `Provisions Workload Identity Federation infrastructure in a GCP project +for GitHub Actions to authenticate and access Vertex AI. + +Org-scoped mode (e.g. 'fullsend inference provision acme'): + Creates a WIF pool and provider scoped to all repos in the org. + +Repo-scoped mode (e.g. 'fullsend inference provision acme/widget'): + Creates a WIF pool and a dedicated provider scoped to a single repo. + +After provisioning, prints the WIF provider resource name for handoff +to the GitHub admin who runs 'fullsend admin install'. + +WIF pools are always created at locations/global.`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + if project == "" { + return fmt.Errorf("--project is required") + } + if !gcpProjectPattern.MatchString(project) { + return fmt.Errorf("invalid GCP project ID %q: must be 6-30 lowercase letters, digits, and hyphens", project) + } + + org, repo, err := parseOrgOrRepo(args[0]) + if err != nil { + return err + } + + if repo != "" && cmd.Flags().Changed("provider") { + return fmt.Errorf("--provider is not supported in repo-scoped mode (provider ID is auto-generated from owner/repo)") + } + + printer := ui.New(cmd.OutOrStdout()) + + if dryRun { + return runInferenceProvisionDryRun(cmd, printer, org, repo, project, pool, provider) + } + + return runInferenceProvision(cmd, printer, org, repo, project, pool, provider) + }, + } + + cmd.Flags().StringVar(&project, "project", "", "GCP project ID for Vertex AI (required)") + cmd.Flags().StringVar(&pool, "pool", "fullsend-pool", "WIF pool name") + cmd.Flags().StringVar(&provider, "provider", "github-oidc", "WIF provider name (org-scoped only)") + cmd.Flags().BoolVar(&dryRun, "dry-run", false, "preview changes without making them") + + return cmd +} + +func runInferenceProvisionDryRun(cmd *cobra.Command, printer *ui.Printer, org, repo, project, pool, provider string) error { + printer.Banner() + printer.Blank() + + if repo != "" { + printer.Header("Dry run: provision WIF for repo-scoped inference") + printer.Blank() + printer.StepInfo(fmt.Sprintf("Repository: %s", repo)) + parts := strings.SplitN(repo, "/", 2) + providerID := gcf.BuildRepoProviderID(parts[0], parts[1]) + printer.StepInfo(fmt.Sprintf("WIF provider: %s (repo-scoped)", providerID)) + printer.StepInfo(fmt.Sprintf("Condition: assertion.repository == '%s'", strings.ToLower(repo))) + } else { + printer.Header("Dry run: provision WIF for org-scoped inference") + printer.Blank() + printer.StepInfo(fmt.Sprintf("Organization: %s", org)) + printer.StepInfo(fmt.Sprintf("WIF provider: %s (org-scoped)", provider)) + printer.StepInfo(fmt.Sprintf("Condition: assertion.repository_owner == '%s'", strings.ToLower(org))) + } + + printer.Blank() + printer.StepInfo(fmt.Sprintf("GCP project: %s", project)) + printer.StepInfo(fmt.Sprintf("WIF pool: %s", pool)) + printer.Blank() + printer.StepInfo("Would create/update:") + printer.StepInfo(fmt.Sprintf(" - WIF pool: %s", pool)) + printer.StepInfo(" - WIF OIDC provider") + printer.StepInfo(" - IAM binding: roles/aiplatform.user") + printer.Blank() + + return nil +} + +func runInferenceProvision(cmd *cobra.Command, printer *ui.Printer, org, repo, project, pool, provider string) error { + printer.Banner() + printer.Blank() + + if repo != "" { + printer.Header("Provisioning WIF for repo-scoped inference: " + repo) + } else { + printer.Header("Provisioning WIF for org-scoped inference: " + org) + } + printer.Blank() + + ctx := cmd.Context() + + gcpClient := gcf.NewLiveGCFClient() + provisioner := gcf.NewProvisioner(gcf.Config{ + ProjectID: project, + GitHubOrgs: []string{org}, + Repo: repo, + WIFPoolName: pool, + WIFProvider: provider, + }, gcpClient) + + printer.StepStart("Provisioning WIF infrastructure") + wifProvider, err := provisioner.ProvisionWIF(ctx) + if err != nil { + printer.StepFail("WIF provisioning failed") + return fmt.Errorf("provisioning WIF for inference: %w", err) + } + printer.StepDone("WIF infrastructure ready") + printer.Blank() + + printer.KeyValue("WIF Provider", wifProvider) + printer.Blank() + + targetArg := org + if repo != "" { + targetArg = repo + } + printer.StepInfo("Pass this value to the GitHub setup command:") + printer.StepInfo(fmt.Sprintf(" fullsend admin install %s \\", targetArg)) + printer.StepInfo(fmt.Sprintf(" --inference-project=%s \\", project)) + printer.StepInfo(fmt.Sprintf(" --inference-wif-provider=%s", wifProvider)) + printer.Blank() + printer.StepWarn("IAM policy changes may take up to 7 minutes to propagate") + printer.Blank() + + return nil +} + +// inferenceStatusResult holds the data returned by the status command. +type inferenceStatusResult struct { + Status string + ProjectID string + WIFProvider string + Details []string // human-readable status lines +} + +func newInferenceStatusCmd() *cobra.Command { + var project string + var pool string + var provider string + var format string + + cmd := &cobra.Command{ + Use: "status ", + Short: "Check inference WIF health and print config", + Long: `Checks the health of inference WIF infrastructure and displays +configuration values for handoff to the GitHub admin. + +Use --format=env to print KEY=value pairs suitable for copying. +Use --format=json to get a machine-readable status + config output.`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + if project == "" { + return fmt.Errorf("--project is required") + } + if !gcpProjectPattern.MatchString(project) { + return fmt.Errorf("invalid GCP project ID %q: must be 6-30 lowercase letters, digits, and hyphens", project) + } + + switch format { + case "text", "json", "env": + // valid + default: + return fmt.Errorf("--format must be one of: text, json, env (got %q)", format) + } + + org, repo, err := parseOrgOrRepo(args[0]) + if err != nil { + return err + } + + if repo != "" && cmd.Flags().Changed("provider") { + return fmt.Errorf("--provider is not supported in repo-scoped mode (provider ID is auto-generated from owner/repo)") + } + + return runInferenceStatus(cmd, org, repo, project, pool, provider, format) + }, + } + + cmd.Flags().StringVar(&project, "project", "", "GCP project ID for Vertex AI (required)") + cmd.Flags().StringVar(&pool, "pool", "fullsend-pool", "WIF pool name") + cmd.Flags().StringVar(&provider, "provider", "github-oidc", "WIF provider name") + cmd.Flags().StringVar(&format, "format", "text", "output format: text, json, env") + + return cmd +} + +func runInferenceStatus(cmd *cobra.Command, org, repo, project, pool, provider, format string) error { + ctx := cmd.Context() + gcpClient := gcf.NewLiveGCFClient() + + poolName := pool + providerName := provider + if repo != "" { + parts := strings.SplitN(repo, "/", 2) + providerName = gcf.BuildRepoProviderID(parts[0], parts[1]) + } + + result := &inferenceStatusResult{ + ProjectID: project, + } + + // Step 1: Look up project number. + projectNumber, err := gcpClient.GetProjectNumber(ctx, project) + if err != nil { + result.Status = "error" + result.Details = append(result.Details, fmt.Sprintf("Failed to get project number: %v", err)) + return outputStatus(cmd, result, format) + } + result.Details = append(result.Details, "Project number: "+projectNumber) + + // Step 2: Check WIF provider exists. + providerInfo, err := gcpClient.GetWIFProvider(ctx, projectNumber, poolName, providerName) + if err != nil { + result.Status = "error" + result.Details = append(result.Details, fmt.Sprintf("Failed to check WIF provider: %v", err)) + return outputStatus(cmd, result, format) + } + + if providerInfo == nil { + result.Status = "not_provisioned" + result.Details = append(result.Details, fmt.Sprintf("WIF pool %q or provider %q not found", poolName, providerName)) + result.Details = append(result.Details, "Run 'fullsend inference provision' to create the infrastructure") + return outputStatus(cmd, result, format) + } + + // Step 3: Build WIF provider resource name. + wifProvider := fmt.Sprintf("projects/%s/locations/global/workloadIdentityPools/%s/providers/%s", + projectNumber, poolName, providerName) + result.WIFProvider = wifProvider + + // Step 4: Parse attribute condition for validation. + condition := providerInfo.AttributeCondition + result.Details = append(result.Details, "WIF provider: "+wifProvider) + result.Details = append(result.Details, "Attribute condition: "+condition) + + conditionOK := true + if repo != "" { + expected := fmt.Sprintf("assertion.repository == '%s'", strings.ToLower(repo)) + if condition == expected { + result.Details = append(result.Details, "Condition matches repo: OK") + } else { + result.Details = append(result.Details, fmt.Sprintf("Condition mismatch: expected %q", expected)) + conditionOK = false + } + } else { + expected := fmt.Sprintf("assertion.repository_owner == '%s'", strings.ToLower(org)) + if condition == expected { + result.Details = append(result.Details, "Condition matches org: OK") + } else if strings.Contains(condition, "repository_owner") && strings.Contains(condition, fmt.Sprintf("'%s'", strings.ToLower(org))) { + result.Details = append(result.Details, "Condition includes org (multi-org pool): OK") + } else { + result.Details = append(result.Details, fmt.Sprintf("Condition does not include org %q", org)) + conditionOK = false + } + } + + if conditionOK { + result.Status = "healthy" + } else { + result.Status = "unhealthy" + } + return outputStatus(cmd, result, format) +} + +func outputStatus(cmd *cobra.Command, result *inferenceStatusResult, format string) error { + switch format { + case "json": + output, err := formatStatusJSON(result) + if err != nil { + return err + } + fmt.Fprintln(cmd.OutOrStdout(), output) + case "env": + fmt.Fprint(cmd.OutOrStdout(), formatStatusEnv(result)) + default: + printer := ui.New(cmd.OutOrStdout()) + printer.Banner() + printer.Blank() + printer.Header("Inference Status") + printer.Blank() + + switch result.Status { + case "healthy": + printer.StepDone("Status: healthy") + case "unhealthy": + printer.StepWarn("Status: unhealthy (condition mismatch)") + case "not_provisioned": + printer.StepFail("Status: not provisioned") + default: + printer.StepFail("Status: " + result.Status) + } + + for _, detail := range result.Details { + printer.StepInfo(detail) + } + + printer.Blank() + if result.WIFProvider != "" { + printer.Header("Config values for handoff") + printer.Blank() + printer.KeyValue("FULLSEND_GCP_PROJECT_ID", result.ProjectID) + printer.KeyValue("FULLSEND_GCP_WIF_PROVIDER", result.WIFProvider) + printer.Blank() + } + } + + if result.Status != "healthy" { + return fmt.Errorf("inference status: %s", result.Status) + } + return nil +} + +func formatStatusJSON(result *inferenceStatusResult) (string, error) { + data := map[string]interface{}{ + "status": result.Status, + "details": result.Details, + } + if result.WIFProvider != "" { + data["FULLSEND_GCP_PROJECT_ID"] = result.ProjectID + data["FULLSEND_GCP_WIF_PROVIDER"] = result.WIFProvider + } + b, err := json.MarshalIndent(data, "", " ") + if err != nil { + return "", fmt.Errorf("marshaling status JSON: %w", err) + } + return string(b), nil +} + +func formatStatusEnv(result *inferenceStatusResult) string { + var sb strings.Builder + sb.WriteString(fmt.Sprintf("FULLSEND_INFERENCE_STATUS=%s\n", result.Status)) + if result.WIFProvider != "" { + sb.WriteString(fmt.Sprintf("FULLSEND_GCP_PROJECT_ID=%s\n", result.ProjectID)) + sb.WriteString(fmt.Sprintf("FULLSEND_GCP_WIF_PROVIDER=%s\n", result.WIFProvider)) + } + return sb.String() +} diff --git a/internal/cli/inference_test.go b/internal/cli/inference_test.go new file mode 100644 index 000000000..c3b48e4ea --- /dev/null +++ b/internal/cli/inference_test.go @@ -0,0 +1,379 @@ +package cli + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInferenceCommand_HasSubcommands(t *testing.T) { + cmd := newInferenceCmd() + names := make(map[string]bool) + for _, sub := range cmd.Commands() { + names[sub.Name()] = true + } + assert.True(t, names["provision"], "expected provision subcommand") + assert.True(t, names["status"], "expected status subcommand") +} + +func TestInferenceCommand_RegisteredInRoot(t *testing.T) { + cmd := newRootCmd() + found := false + for _, sub := range cmd.Commands() { + if sub.Use == "inference" { + found = true + break + } + } + assert.True(t, found, "expected inference subcommand registered in root") +} + +// --- provision tests --- + +func TestInferenceProvisionCmd_RequiresArg(t *testing.T) { + cmd := newRootCmd() + cmd.SetArgs([]string{"inference", "provision"}) + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "accepts 1 arg(s)") +} + +func TestInferenceProvisionCmd_RequiresProject(t *testing.T) { + cmd := newRootCmd() + cmd.SetArgs([]string{"inference", "provision", "acme"}) + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "--project is required") +} + +func TestInferenceProvisionCmd_RejectsInvalidProjectID(t *testing.T) { + tests := []struct { + name string + project string + }{ + {"uppercase", "MY-PROJECT"}, + {"too short", "ab"}, + {"starts with digit", "1project"}, + {"starts with hyphen", "-project"}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + cmd := newRootCmd() + cmd.SetArgs([]string{"inference", "provision", "acme", + "--project", tc.project, "--dry-run"}) + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid GCP project ID") + }) + } +} + +func TestInferenceProvisionCmd_Flags(t *testing.T) { + cmd := newInferenceProvisionCmd() + + projectFlag := cmd.Flags().Lookup("project") + require.NotNil(t, projectFlag, "expected --project flag") + + poolFlag := cmd.Flags().Lookup("pool") + require.NotNil(t, poolFlag, "expected --pool flag") + assert.Equal(t, "fullsend-pool", poolFlag.DefValue) + + providerFlag := cmd.Flags().Lookup("provider") + require.NotNil(t, providerFlag, "expected --provider flag") + assert.Equal(t, "github-oidc", providerFlag.DefValue) + + dryRunFlag := cmd.Flags().Lookup("dry-run") + require.NotNil(t, dryRunFlag, "expected --dry-run flag") + + assert.Nil(t, cmd.Flags().Lookup("region"), "should not have --region flag") +} + +func TestInferenceProvisionCmd_DetectsOrgMode(t *testing.T) { + // Org-scoped: arg without "/" + cmd := newRootCmd() + cmd.SetArgs([]string{"inference", "provision", "acme", + "--project", "my-project", + "--dry-run"}) + err := cmd.Execute() + // Should succeed (dry-run prints what would happen) + require.NoError(t, err) +} + +func TestInferenceProvisionCmd_DetectsRepoMode(t *testing.T) { + // Repo-scoped: arg with "/" + cmd := newRootCmd() + cmd.SetArgs([]string{"inference", "provision", "acme/widget", + "--project", "my-project", + "--dry-run"}) + err := cmd.Execute() + // Should succeed (dry-run prints what would happen) + require.NoError(t, err) +} + +func TestInferenceProvisionCmd_DryRunOrgSucceeds(t *testing.T) { + cmd := newRootCmd() + cmd.SetArgs([]string{"inference", "provision", "acme", + "--project", "my-project", + "--dry-run"}) + err := cmd.Execute() + require.NoError(t, err) +} + +func TestInferenceProvisionCmd_DryRunRepoSucceeds(t *testing.T) { + cmd := newRootCmd() + cmd.SetArgs([]string{"inference", "provision", "acme/widget", + "--project", "my-project", + "--dry-run"}) + err := cmd.Execute() + require.NoError(t, err) +} + +func TestInferenceProvisionCmd_DryRunCustomPool(t *testing.T) { + cmd := newRootCmd() + cmd.SetArgs([]string{"inference", "provision", "acme", + "--project", "my-project", + "--pool", "custom-pool", + "--provider", "custom-provider", + "--dry-run"}) + err := cmd.Execute() + require.NoError(t, err) +} + +func TestInferenceProvisionCmd_RejectsInvalidOrgName(t *testing.T) { + cmd := newRootCmd() + cmd.SetArgs([]string{"inference", "provision", "-invalid", + "--project", "my-project"}) + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid") +} + +func TestInferenceProvisionCmd_RejectsInvalidRepoFormat(t *testing.T) { + cmd := newRootCmd() + cmd.SetArgs([]string{"inference", "provision", "acme/", + "--project", "my-project"}) + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid") +} + +func TestInferenceProvisionCmd_DoesNotRequireGitHubToken(t *testing.T) { + // Unset all GitHub tokens to prove they're not needed. + t.Setenv("GH_TOKEN", "") + t.Setenv("GITHUB_TOKEN", "") + + cmd := newRootCmd() + cmd.SetArgs([]string{"inference", "provision", "acme", + "--project", "my-project", + "--dry-run"}) + err := cmd.Execute() + // Should not fail with "no GitHub token found" + require.NoError(t, err) +} + +// --- status tests --- + +func TestInferenceStatusCmd_RequiresArg(t *testing.T) { + cmd := newRootCmd() + cmd.SetArgs([]string{"inference", "status"}) + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "accepts 1 arg(s)") +} + +func TestInferenceStatusCmd_RequiresProject(t *testing.T) { + cmd := newRootCmd() + cmd.SetArgs([]string{"inference", "status", "acme"}) + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "--project is required") +} + +func TestInferenceStatusCmd_RejectsInvalidProjectID(t *testing.T) { + cmd := newRootCmd() + cmd.SetArgs([]string{"inference", "status", "acme", + "--project", "UPPER-CASE"}) + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid GCP project ID") +} + +func TestInferenceStatusCmd_Flags(t *testing.T) { + cmd := newInferenceStatusCmd() + + projectFlag := cmd.Flags().Lookup("project") + require.NotNil(t, projectFlag, "expected --project flag") + + poolFlag := cmd.Flags().Lookup("pool") + require.NotNil(t, poolFlag, "expected --pool flag") + assert.Equal(t, "fullsend-pool", poolFlag.DefValue) + + providerFlag := cmd.Flags().Lookup("provider") + require.NotNil(t, providerFlag, "expected --provider flag") + assert.Equal(t, "github-oidc", providerFlag.DefValue) + + formatFlag := cmd.Flags().Lookup("format") + require.NotNil(t, formatFlag, "expected --format flag") + assert.Equal(t, "text", formatFlag.DefValue) + + assert.Nil(t, cmd.Flags().Lookup("region"), "should not have --region flag") +} + +func TestInferenceStatusCmd_RejectsInvalidFormat(t *testing.T) { + cmd := newRootCmd() + cmd.SetArgs([]string{"inference", "status", "acme", + "--project", "my-project", + "--format", "yaml"}) + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "--format must be one of: text, json, env") +} + +func TestInferenceStatusCmd_DoesNotRequireGitHubToken(t *testing.T) { + // Unset all GitHub tokens to prove they're not needed. + t.Setenv("GH_TOKEN", "") + t.Setenv("GITHUB_TOKEN", "") + + // Status without dry-run will try to reach GCP, which will fail, + // but it should NOT fail with "no GitHub token found". + cmd := newRootCmd() + cmd.SetArgs([]string{"inference", "status", "acme", + "--project", "my-project"}) + err := cmd.Execute() + if err != nil { + assert.NotContains(t, err.Error(), "no GitHub token found") + } +} + +// --- parseOrgOrRepo tests --- + +func TestParseOrgOrRepo_OrgMode(t *testing.T) { + org, repo, err := parseOrgOrRepo("acme") + require.NoError(t, err) + assert.Equal(t, "acme", org) + assert.Equal(t, "", repo) +} + +func TestParseOrgOrRepo_RepoMode(t *testing.T) { + org, repo, err := parseOrgOrRepo("acme/widget") + require.NoError(t, err) + assert.Equal(t, "acme", org) + assert.Equal(t, "acme/widget", repo) +} + +func TestParseOrgOrRepo_Invalid(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + {"empty owner in repo", "/widget", "invalid"}, + {"empty repo in repo", "acme/", "invalid"}, + {"leading hyphen", "-acme", "hyphen"}, + {"trailing hyphen", "acme-", "hyphen"}, + {"invalid chars", "ac me", "invalid"}, + {"dots in owner", "ac.me/widget", "invalid"}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + _, _, err := parseOrgOrRepo(tc.input) + require.Error(t, err) + assert.Contains(t, err.Error(), tc.want) + }) + } +} + +// --- formatStatusJSON tests --- + +func TestFormatStatusJSON(t *testing.T) { + result := &inferenceStatusResult{ + Status: "healthy", + ProjectID: "my-project", + WIFProvider: "projects/123/locations/global/workloadIdentityPools/fullsend-pool/providers/github-oidc", + Details: []string{"Project number: 123", "WIF provider: found"}, + } + + output, err := formatStatusJSON(result) + require.NoError(t, err) + + var parsed map[string]interface{} + err = json.Unmarshal([]byte(output), &parsed) + require.NoError(t, err) + + assert.Equal(t, "healthy", parsed["status"]) + assert.Equal(t, "my-project", parsed["FULLSEND_GCP_PROJECT_ID"]) + assert.Equal(t, "projects/123/locations/global/workloadIdentityPools/fullsend-pool/providers/github-oidc", parsed["FULLSEND_GCP_WIF_PROVIDER"]) + details, ok := parsed["details"].([]interface{}) + require.True(t, ok, "expected details to be an array") + assert.Len(t, details, 2) +} + +func TestFormatStatusJSON_Unhealthy(t *testing.T) { + result := &inferenceStatusResult{ + Status: "error", + ProjectID: "my-project", + Details: []string{"Failed to get project number"}, + } + + output, err := formatStatusJSON(result) + require.NoError(t, err) + + var parsed map[string]interface{} + err = json.Unmarshal([]byte(output), &parsed) + require.NoError(t, err) + + assert.Equal(t, "error", parsed["status"]) + assert.Nil(t, parsed["FULLSEND_GCP_PROJECT_ID"], "should not include config keys when unhealthy") + assert.Nil(t, parsed["FULLSEND_GCP_WIF_PROVIDER"], "should not include config keys when unhealthy") +} + +// --- formatStatusEnv tests --- + +func TestFormatStatusEnv(t *testing.T) { + result := &inferenceStatusResult{ + Status: "healthy", + ProjectID: "my-project", + WIFProvider: "projects/123/locations/global/workloadIdentityPools/fullsend-pool/providers/github-oidc", + } + + output := formatStatusEnv(result) + assert.Contains(t, output, "FULLSEND_INFERENCE_STATUS=healthy") + assert.Contains(t, output, "FULLSEND_GCP_PROJECT_ID=my-project") + assert.Contains(t, output, "FULLSEND_GCP_WIF_PROVIDER=projects/123/locations/global/workloadIdentityPools/fullsend-pool/providers/github-oidc") + assert.NotContains(t, output, "FULLSEND_GCP_REGION") + assert.NotContains(t, output, "Status:") +} + +func TestFormatStatusEnv_Unhealthy(t *testing.T) { + result := &inferenceStatusResult{ + Status: "unhealthy", + ProjectID: "my-project", + } + + output := formatStatusEnv(result) + assert.Contains(t, output, "FULLSEND_INFERENCE_STATUS=unhealthy") + assert.NotContains(t, output, "FULLSEND_GCP_PROJECT_ID") + assert.NotContains(t, output, "FULLSEND_GCP_WIF_PROVIDER") +} + +func TestInferenceStatusCmd_RejectsProviderInRepoMode(t *testing.T) { + cmd := newRootCmd() + cmd.SetArgs([]string{"inference", "status", "acme/widget", + "--project", "my-project", + "--provider", "custom-provider"}) + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "--provider is not supported in repo-scoped mode") +} + +func TestInferenceProvisionCmd_RejectsProviderInRepoMode(t *testing.T) { + cmd := newRootCmd() + cmd.SetArgs([]string{"inference", "provision", "acme/widget", + "--project", "my-project", + "--provider", "custom-provider"}) + err := cmd.Execute() + require.Error(t, err) + assert.Contains(t, err.Error(), "--provider is not supported in repo-scoped mode") +} diff --git a/internal/cli/root.go b/internal/cli/root.go index d77ea02d5..1bce14026 100644 --- a/internal/cli/root.go +++ b/internal/cli/root.go @@ -21,6 +21,7 @@ func newRootCmd() *cobra.Command { Version: version, } cmd.AddCommand(newAdminCmd()) + cmd.AddCommand(newInferenceCmd()) cmd.AddCommand(newRunCmd()) cmd.AddCommand(newScanCmd()) cmd.AddCommand(newPostReviewCmd())