diff --git a/cmd/topo/vscode.go b/cmd/topo/vscode.go index 5c6566e..fe04b30 100644 --- a/cmd/topo/vscode.go +++ b/cmd/topo/vscode.go @@ -1,8 +1,11 @@ package main import ( + "encoding/json" + "fmt" "os" + "github.com/arm/topo/internal/ssh" "github.com/arm/topo/internal/vscode" "github.com/spf13/cobra" ) @@ -19,6 +22,26 @@ var getProjectCmd = &cobra.Command{ }, } +var listCandidateTargets = &cobra.Command{ + Use: "list-candidate-targets ", + Short: "Prints a list of candidate ssh targets defined in the given config file as JSON", + Hidden: true, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + cmd.SilenceUsage = true + configPath := args[0] + + hosts := ssh.ListHosts(configPath) + data, err := json.Marshal(hosts) + if err != nil { + return fmt.Errorf("failed to marshal ssh hosts: %w", err) + } + _, err = fmt.Fprintf(os.Stdout, "%s\n", data) + return err + }, +} + func init() { rootCmd.AddCommand(getProjectCmd) + rootCmd.AddCommand(listCandidateTargets) } diff --git a/internal/collections/set.go b/internal/collections/set.go new file mode 100644 index 0000000..d42664f --- /dev/null +++ b/internal/collections/set.go @@ -0,0 +1,31 @@ +package collections + +import ( + "maps" + "slices" +) + +type Set[T comparable] struct { + elements map[T]struct{} +} + +func NewSet[T comparable](items ...T) Set[T] { + s := Set[T]{elements: make(map[T]struct{})} + for _, item := range items { + s.Add(item) + } + return s +} + +func (s *Set[T]) Add(item T) { + s.elements[item] = struct{}{} +} + +func (s *Set[T]) ToSlice() []T { + return slices.Collect(maps.Keys(s.elements)) +} + +func (s *Set[T]) Contains(item T) bool { + _, exists := s.elements[item] + return exists +} diff --git a/internal/collections/set_test.go b/internal/collections/set_test.go new file mode 100644 index 0000000..5b9ea15 --- /dev/null +++ b/internal/collections/set_test.go @@ -0,0 +1,75 @@ +package collections_test + +import ( + "testing" + + "github.com/arm/topo/internal/collections" + "github.com/stretchr/testify/assert" +) + +func TestSet(t *testing.T) { + t.Run("NewSet", func(t *testing.T) { + t.Run("creates a set containing the provided items", func(t *testing.T) { + set := collections.NewSet("a", "b", "c") + + assert.True(t, set.Contains("a")) + assert.True(t, set.Contains("b")) + assert.True(t, set.Contains("c")) + assert.Len(t, set.ToSlice(), 3) + }) + + t.Run("deduplicates items", func(t *testing.T) { + set := collections.NewSet("a", "a", "b") + + got := set.ToSlice() + + assert.Len(t, got, 2) + }) + }) + + t.Run("Add", func(t *testing.T) { + t.Run("adds an item to the set", func(t *testing.T) { + set := collections.NewSet[string]() + + set.Add("x") + + assert.True(t, set.Contains("x")) + }) + }) + + t.Run("Contains", func(t *testing.T) { + t.Run("returns true for an item in the set", func(t *testing.T) { + set := collections.NewSet("present") + + got := set.Contains("present") + + assert.True(t, got) + }) + + t.Run("returns false for an item not in the set", func(t *testing.T) { + set := collections.NewSet("present") + + got := set.Contains("absent") + + assert.False(t, got) + }) + }) + + t.Run("ToSlice", func(t *testing.T) { + t.Run("returns all elements as a slice", func(t *testing.T) { + set := collections.NewSet(1, 2, 3) + + got := set.ToSlice() + + assert.ElementsMatch(t, []int{1, 2, 3}, got) + }) + + t.Run("returns an empty slice for an empty set", func(t *testing.T) { + set := collections.NewSet[int]() + + got := set.ToSlice() + + assert.Empty(t, got) + }) + }) +} diff --git a/internal/ssh/list.go b/internal/ssh/list.go new file mode 100644 index 0000000..cb2b491 --- /dev/null +++ b/internal/ssh/list.go @@ -0,0 +1,61 @@ +package ssh + +import ( + "strings" + + "github.com/arm/topo/internal/collections" + "github.com/arm/topo/internal/output/logger" + sshconfig "github.com/kevinburke/ssh_config" +) + +func gatherIncludedConfigPaths(cfg *sshconfig.Config) []string { + includedPaths := []string{} + + for _, host := range cfg.Hosts { + for _, node := range host.Nodes { + if include, ok := node.(*sshconfig.Include); ok { + includePath := strings.TrimSpace(strings.TrimPrefix(include.String(), "Include")) + if includePath != "" { + includedPaths = append(includedPaths, includePath) + } + } + } + } + + return includedPaths +} + +func ListHosts(configPath string) []string { + queue := []string{configPath} + seen := collections.NewSet[string]() + hosts := collections.NewSet[string]() + + for len(queue) > 0 { + currentPath := queue[0] + queue = queue[1:] + if seen.Contains(currentPath) { + continue + } + seen.Add(currentPath) + + cfg, err := readConfigFile(currentPath) + if err != nil { + logger.Error("failed to read ssh config file while listing hosts", "path", currentPath, "error", err) + continue + } + + for _, host := range cfg.Hosts { + queue = append(queue, gatherIncludedConfigPaths(cfg)...) + + for _, pattern := range host.Patterns { + patternStr := pattern.String() + if patternStr == "*" { + continue + } + hosts.Add(patternStr) + } + } + } + + return hosts.ToSlice() +} diff --git a/internal/ssh/list_test.go b/internal/ssh/list_test.go new file mode 100644 index 0000000..72fd548 --- /dev/null +++ b/internal/ssh/list_test.go @@ -0,0 +1,128 @@ +package ssh_test + +import ( + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/arm/topo/internal/ssh" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func writeSSHConfig(t *testing.T, dir, name, content string) string { + t.Helper() + path := filepath.Join(dir, name) + require.NoError(t, os.WriteFile(path, []byte(content), 0o600)) + return filepath.ToSlash(path) +} + +func TestListHosts(t *testing.T) { + t.Run("returns hosts from a single config file", func(t *testing.T) { + tmp := t.TempDir() + configPath := writeSSHConfig(t, tmp, "config", ` +Host board1 + HostName 192.168.0.1 + +Host board2 + HostName 192.168.0.2 +`) + + got := ssh.ListHosts(configPath) + + assert.ElementsMatch(t, []string{"board1", "board2"}, got) + }) + + t.Run("excludes wildcard host pattern", func(t *testing.T) { + tmp := t.TempDir() + configPath := writeSSHConfig(t, tmp, "config", ` +Host * + ServerAliveInterval 60 + +Host myhost + HostName 10.0.0.1 +`) + + got := ssh.ListHosts(configPath) + + assert.ElementsMatch(t, []string{"myhost"}, got) + }) + + t.Run("follows include directives", func(t *testing.T) { + tmp := t.TempDir() + includedPath := writeSSHConfig(t, tmp, "extra_config", ` +Host included-host + HostName 10.0.0.2 +`) + configPath := writeSSHConfig(t, tmp, "config", fmt.Sprintf(` +Include %s + +Host main-host + HostName 10.0.0.1 +`, includedPath)) + + got := ssh.ListHosts(configPath) + + assert.ElementsMatch(t, []string{"main-host", "included-host"}, got) + }) + + t.Run("deduplicates hosts across files", func(t *testing.T) { + tmp := t.TempDir() + includedPath := writeSSHConfig(t, tmp, "extra_config", ` +Host shared-host + HostName 10.0.0.1 +`) + configPath := writeSSHConfig(t, tmp, "config", ` +Include `+includedPath+` + +Host shared-host + HostName 10.0.0.1 +`) + + got := ssh.ListHosts(configPath) + + assert.ElementsMatch(t, []string{"shared-host"}, got) + }) + + t.Run("handles cyclic includes without infinite loop", func(t *testing.T) { + tmp := t.TempDir() + configAPath := filepath.Join(tmp, "config_a") + configBPath := filepath.Join(tmp, "config_b") + + writeSSHConfig(t, tmp, "config_a", ` +Include `+configBPath+` + +Host host-a + HostName 10.0.0.1 +`) + writeSSHConfig(t, tmp, "config_b", ` +Include `+configAPath+` + +Host host-b + HostName 10.0.0.2 +`) + + got := ssh.ListHosts(configAPath) + + assert.Nil(t, got, "should return without hanging on cyclic includes") + }) + + t.Run("returns empty slice for nonexistent config file", func(t *testing.T) { + got := ssh.ListHosts("/nonexistent/path/config") + + assert.Empty(t, got) + }) + + t.Run("returns empty slice for config with only wildcard host", func(t *testing.T) { + tmp := t.TempDir() + configPath := writeSSHConfig(t, tmp, "config", ` +Host * + ServerAliveInterval 60 +`) + + got := ssh.ListHosts(configPath) + + assert.Empty(t, got) + }) +}