Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 149 additions & 2 deletions internal/cmd/auth_credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@ import (
)

type AuthCredentialsCmd struct {
Set AuthCredentialsSetCmd `cmd:"" default:"withargs" help:"Store OAuth client credentials"`
List AuthCredentialsListCmd `cmd:"" name:"list" help:"List stored OAuth client credentials"`
Set AuthCredentialsSetCmd `cmd:"" default:"withargs" help:"Store OAuth client credentials"`
List AuthCredentialsListCmd `cmd:"" name:"list" help:"List stored OAuth client credentials"`
Remove AuthCredentialsRemoveCmd `cmd:"" name:"remove" help:"Remove stored OAuth client credentials"`
}

type AuthCredentialsSetCmd struct {
Expand Down Expand Up @@ -160,3 +161,149 @@ func (c *AuthCredentialsListCmd) Run(ctx context.Context, _ *RootFlags) error {
}
return nil
}

type AuthCredentialsRemoveCmd struct {
Client string `arg:"" optional:"" name:"client" help:"Client name to remove (omit for default, or 'all' to remove every client)"`
}

func (c *AuthCredentialsRemoveCmd) Run(ctx context.Context, flags *RootFlags) error {
u := ui.FromContext(ctx)

// Determine target client(s): explicit arg > --client flag > default.
target := strings.TrimSpace(c.Client)
if target == "" {
t, err := normalizeClientForFlag(authclient.ClientOverrideFromContext(ctx))
if err != nil {
return err
}
target = t
}

if strings.EqualFold(target, "all") {
return c.removeAll(ctx, flags, u)
}

client, err := config.NormalizeClientNameOrDefault(target)
if err != nil {
return err
}

accounts := findAccountsForClient(client)

action := fmt.Sprintf("remove OAuth credentials for client %q", client)
if len(accounts) > 0 {
action += fmt.Sprintf(" and %d associated token(s) (%s)", len(accounts), strings.Join(accounts, ", "))
}
if err := confirmDestructive(ctx, flags, action); err != nil {
return err
}

if err := config.DeleteClientCredentialsFor(client); err != nil {
return err
}

tokensRemoved := removeTokensForClient(client, accounts)
domainsRemoved := removeDomainMappings(client)

return writeResult(ctx, u,
kv("removed", true),
kv("client", client),
kv("tokens_removed", tokensRemoved),
kv("domains_removed", domainsRemoved),
)
}

func (c *AuthCredentialsRemoveCmd) removeAll(ctx context.Context, flags *RootFlags, u *ui.UI) error {
creds, err := config.ListClientCredentials()
if err != nil {
return err
}
if len(creds) == 0 {
return writeResult(ctx, u, kv("removed", 0))
}

names := make([]string, 0, len(creds))
for _, info := range creds {
names = append(names, info.Client)
}
if err := confirmDestructive(ctx, flags, fmt.Sprintf("remove all OAuth credentials (%s)", strings.Join(names, ", "))); err != nil {
return err
}

var allTokens []string
for _, info := range creds {
accounts := findAccountsForClient(info.Client)
if err := config.DeleteClientCredentialsFor(info.Client); err != nil {
return err
}
allTokens = append(allTokens, removeTokensForClient(info.Client, accounts)...)
removeDomainMappings(info.Client)
}

return writeResult(ctx, u,
kv("removed", len(creds)),
kv("clients", names),
kv("tokens_removed", allTokens),
)
}

// findAccountsForClient returns emails that have tokens stored under the given client.
func findAccountsForClient(client string) []string {
store, err := openSecretsStore()
if err != nil {
return nil
}
tokens, err := store.ListTokens()
if err != nil {
return nil
}
var emails []string
for _, tok := range tokens {
tokClient, _ := config.NormalizeClientNameOrDefault(tok.Client)
if tokClient == client {
emails = append(emails, tok.Email)
}
}
return emails
}

// removeTokensForClient deletes tokens for the given accounts under the specified client.
func removeTokensForClient(client string, emails []string) []string {
if len(emails) == 0 {
return nil
}
store, err := openSecretsStore()
if err != nil {
return nil
}
var removed []string
for _, email := range emails {
if err := store.DeleteToken(client, email); err == nil {
removed = append(removed, email)
}
}
return removed
}

// removeDomainMappings deletes config domain entries that point to the given client.
func removeDomainMappings(client string) []string {
cfg, err := config.ReadConfig()
if err != nil {
return nil
}
var removed []string
for domain, mapped := range cfg.ClientDomains {
normalized, nerr := config.NormalizeClientNameOrDefault(mapped)
if nerr != nil {
continue
}
if normalized == client {
removed = append(removed, domain)
delete(cfg.ClientDomains, domain)
}
}
if len(removed) > 0 {
_ = config.WriteConfig(cfg)
}
return removed
}
17 changes: 17 additions & 0 deletions internal/config/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,23 @@ func ReadClientCredentialsFor(client string) (ClientCredentials, error) {
return c, nil
}

func DeleteClientCredentialsFor(client string) error {
path, err := ClientCredentialsPathFor(client)
if err != nil {
return fmt.Errorf("resolve credentials path: %w", err)
}

if err := os.Remove(path); err != nil {
if os.IsNotExist(err) {
return &CredentialsMissingError{Path: path, Cause: err}
}

return fmt.Errorf("delete credentials: %w", err)
}

return nil
}

func ClientCredentialsExists(client string) (bool, error) {
path, err := ClientCredentialsPathFor(client)
if err != nil {
Expand Down