Skip to content

Commit 4065af3

Browse files
authored
[#674] Support env var in ssh config (#683)
1 parent 87a33e6 commit 4065af3

File tree

2 files changed

+127
-8
lines changed

2 files changed

+127
-8
lines changed

internal/dag/executor/ssh.go

+40-8
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ import (
2020
"errors"
2121
"fmt"
2222
"io"
23+
"net"
2324
"os"
25+
"reflect"
2426
"strings"
2527

2628
"github.com/mitchellh/mapstructure"
@@ -37,15 +39,23 @@ type sshExec struct {
3739
session *ssh.Session
3840
}
3941

40-
type sshExecConfig struct {
42+
type sshExecConfigDefinition struct {
4143
User string
4244
IP string
43-
Port int
45+
Port any
4446
Key string
4547
Password string
4648
StrictHostKeyChecking bool
4749
}
4850

51+
type sshExecConfig struct {
52+
User string
53+
IP string
54+
Port string
55+
Key string
56+
Password string
57+
}
58+
4959
// selectSSHAuthMethod selects the authentication method based on the configuration.
5060
// If the key is provided, it will use the public key authentication method.
5161
// Otherwise, it will use the password authentication method.
@@ -67,10 +77,21 @@ func selectSSHAuthMethod(cfg *sshExecConfig) (ssh.AuthMethod, error) {
6777
return ssh.Password(cfg.Password), nil
6878
}
6979

80+
// expandEnvHook is a mapstructure decode hook that expands environment variables in string fields
81+
func expandEnvHook(f reflect.Type, t reflect.Type, data any) (any, error) {
82+
if f.Kind() != reflect.String || t.Kind() != reflect.String {
83+
return data, nil
84+
}
85+
return os.ExpandEnv(data.(string)), nil
86+
}
87+
7088
func newSSHExec(_ context.Context, step dag.Step) (Executor, error) {
71-
cfg := new(sshExecConfig)
89+
def := new(sshExecConfigDefinition)
7290
md, err := mapstructure.NewDecoder(
73-
&mapstructure.DecoderConfig{Result: cfg},
91+
&mapstructure.DecoderConfig{
92+
Result: def,
93+
DecodeHook: expandEnvHook,
94+
},
7495
)
7596

7697
if err != nil {
@@ -81,11 +102,22 @@ func newSSHExec(_ context.Context, step dag.Step) (Executor, error) {
81102
return nil, err
82103
}
83104

84-
if cfg.Port == 0 {
85-
cfg.Port = 22
105+
cfg := &sshExecConfig{
106+
User: def.User,
107+
IP: def.IP,
108+
Key: def.Key,
109+
Password: def.Password,
110+
}
111+
112+
// Handle Port as either string or int
113+
port := os.ExpandEnv(fmt.Sprintf("%v", def.Port))
114+
if port == "" {
115+
port = "22"
86116
}
117+
cfg.Port = port
87118

88-
if cfg.StrictHostKeyChecking {
119+
// StrictHostKeyChecking is not supported yet.
120+
if def.StrictHostKeyChecking {
89121
return nil, errStrictHostKey
90122
}
91123

@@ -130,7 +162,7 @@ func (e *sshExec) Kill(_ os.Signal) error {
130162
}
131163

132164
func (e *sshExec) Run() error {
133-
addr := fmt.Sprintf("%s:%d", e.config.IP, e.config.Port)
165+
addr := net.JoinHostPort(e.config.IP, e.config.Port)
134166
conn, err := ssh.Dial("tcp", addr, e.sshConfig)
135167
if err != nil {
136168
return err

internal/dag/executor/ssh_test.go

+87
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// Copyright (C) 2024 The Dagu Authors
2+
//
3+
// This program is free software: you can redistribute it and/or modify
4+
// it under the terms of the GNU General Public License as published by
5+
// the Free Software Foundation, either version 3 of the License, or
6+
// (at your option) any later version.
7+
//
8+
// This program is distributed in the hope that it will be useful,
9+
// but WITHOUT ANY WARRANTY; without even the implied warranty of
10+
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11+
// GNU General Public License for more details.
12+
//
13+
// You should have received a copy of the GNU General Public License
14+
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15+
16+
package executor
17+
18+
import (
19+
"context"
20+
"os"
21+
"testing"
22+
23+
"github.com/dagu-org/dagu/internal/dag"
24+
"github.com/stretchr/testify/assert"
25+
"github.com/stretchr/testify/require"
26+
)
27+
28+
func TestSSHExecutor(t *testing.T) {
29+
t.Parallel()
30+
31+
t.Run("Basic", func(t *testing.T) {
32+
step := dag.Step{
33+
Name: "ssh-exec",
34+
ExecutorConfig: dag.ExecutorConfig{
35+
Type: "ssh",
36+
Config: map[string]any{
37+
"User": "testuser",
38+
"IP": "testip",
39+
"Port": 25,
40+
"Password": "testpassword",
41+
},
42+
},
43+
}
44+
ctx := context.Background()
45+
exec, err := newSSHExec(ctx, step)
46+
require.NoError(t, err)
47+
48+
sshExec, ok := exec.(*sshExec)
49+
require.True(t, ok)
50+
51+
assert.Equal(t, "testuser", sshExec.config.User)
52+
assert.Equal(t, "testip", sshExec.config.IP)
53+
assert.Equal(t, "25", sshExec.config.Port)
54+
assert.Equal(t, "testpassword", sshExec.config.Password)
55+
})
56+
57+
t.Run("ExpandEnv", func(t *testing.T) {
58+
os.Setenv("TEST_SSH_EXEC_USER", "testuser")
59+
os.Setenv("TEST_SSH_EXEC_IP", "testip")
60+
os.Setenv("TEST_SSH_EXEC_PORT", "23")
61+
os.Setenv("TEST_SSH_EXEC_PASSWORD", "testpassword")
62+
63+
step := dag.Step{
64+
Name: "ssh-exec",
65+
ExecutorConfig: dag.ExecutorConfig{
66+
Type: "ssh",
67+
Config: map[string]any{
68+
"User": "${TEST_SSH_EXEC_USER}",
69+
"IP": "${TEST_SSH_EXEC_IP}",
70+
"Port": "${TEST_SSH_EXEC_PORT}",
71+
"Password": "${TEST_SSH_EXEC_PASSWORD}",
72+
},
73+
},
74+
}
75+
ctx := context.Background()
76+
exec, err := newSSHExec(ctx, step)
77+
require.NoError(t, err)
78+
79+
sshExec, ok := exec.(*sshExec)
80+
require.True(t, ok)
81+
82+
assert.Equal(t, "testuser", sshExec.config.User)
83+
assert.Equal(t, "testip", sshExec.config.IP)
84+
assert.Equal(t, "23", sshExec.config.Port)
85+
assert.Equal(t, "testpassword", sshExec.config.Password)
86+
})
87+
}

0 commit comments

Comments
 (0)