-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
start factoring out the 'ssh reachable' code.
We want to integrate with the golang crypto ssh library so that we not only check that we can get a TCP connection to the port, but also so that we can check that there is a valid SSH that is presenting the right public key on the other side. Also, our code was causing the goroutines to block indefinitely, as they'd never be able to send on the channel once we find one that is correct, so close a done channel to signal they have nothing to do.
- Loading branch information
Showing
5 changed files
with
247 additions
and
83 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
// Copyright 2016 Canonical Ltd. | ||
// Licensed under the AGPLv3, see LICENCE file for details. | ||
|
||
package ssh_test | ||
|
||
import ( | ||
"testing" | ||
|
||
gc "gopkg.in/check.v1" | ||
) | ||
|
||
func TestAll(t *testing.T) { | ||
gc.TestingT(t) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
// Copyright 2014 Canonical Ltd. | ||
// Licensed under the AGPLv3, see LICENCE file for details. | ||
|
||
package ssh | ||
|
||
import ( | ||
"net" | ||
"time" | ||
|
||
"golang.org/x/crypto/ssh" | ||
"github.com/juju/errors" | ||
"github.com/juju/juju/network" | ||
|
||
"github.com/juju/loggo" | ||
) | ||
|
||
|
||
var logger = loggo.GetLogger("juju.network.ssh") | ||
|
||
// Dialer defines a Dial() method matching the signature of net.Dial(). | ||
type Dialer interface { | ||
Dial(network, address string) (net.Conn, error) | ||
} | ||
|
||
// ReachableHostPort dials the entries in the given hostPorts, in parallel, | ||
// using the given dialer, closing successfully established connections | ||
// after checking the ssh key. Individual connection errors are discarded, and | ||
// an error is returned only if none of the hostPorts can be reached when the | ||
// given timeout expires. | ||
// If publicKeys is a non empty list, then the SSH host public key will be | ||
// checked. If it is not in the list, then that host is not considered valid. | ||
// | ||
// Usually, a net.Dialer initialized with a non-empty Timeout field is passed | ||
// for dialer. | ||
func ReachableHostPort(hostPorts []network.HostPort, publicKeys []string, dialer Dialer, timeout time.Duration) (network.HostPort, error) { | ||
uniqueHPs := network.UniqueHostPorts(hostPorts) | ||
successful := make(chan network.HostPort, 1) | ||
done := make(chan struct{}, 0) | ||
|
||
for _, hostPort := range uniqueHPs { | ||
go func(hp network.HostPort) { | ||
conn, err := dialer.Dial("tcp", hp.NetAddr()) | ||
if err == nil { | ||
conn.Close() | ||
select { | ||
case successful <- hp: | ||
return | ||
case <-done: | ||
return | ||
} | ||
} | ||
}(hostPort) | ||
} | ||
|
||
select { | ||
case result := <-successful: | ||
logger.Infof("dialed %q successfully", result) | ||
close(done) | ||
return result, nil | ||
|
||
case <-time.After(timeout): | ||
close(done) | ||
return network.HostPort{}, errors.Errorf("cannot connect to any address: %v", hostPorts) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,168 @@ | ||
// Copyright 2014 Canonical Ltd. | ||
// Licensed under the AGPLv3, see LICENCE file for details. | ||
|
||
package ssh_test | ||
|
||
import ( | ||
"net" | ||
"time" | ||
|
||
_ "github.com/juju/errors" | ||
jc "github.com/juju/testing/checkers" | ||
gc "gopkg.in/check.v1" | ||
|
||
"github.com/juju/juju/network" | ||
"github.com/juju/juju/network/ssh" | ||
coretesting "github.com/juju/juju/testing" | ||
) | ||
|
||
type SSHReachableHostPortSuite struct { | ||
coretesting.BaseSuite | ||
} | ||
|
||
var _ = gc.Suite(&SSHReachableHostPortSuite{}) | ||
|
||
func (s *SSHReachableHostPortSuite) TestAllUnreachable(c *gc.C) { | ||
dialer := &net.Dialer{Timeout: 50 * time.Millisecond} | ||
unreachableHPs := closedTCPHostPorts(c, 10) | ||
timeout := 100 * time.Millisecond | ||
|
||
best, err := ssh.ReachableHostPort(unreachableHPs, nil, dialer, timeout) | ||
c.Check(err, gc.ErrorMatches, "cannot connect to any address: .*") | ||
c.Check(best, gc.Equals, network.HostPort{}) | ||
} | ||
|
||
func (s *SSHReachableHostPortSuite) TestReachableInvalidPublicKey(c *gc.C) { | ||
hostPorts := []network.HostPort{ | ||
testSSHServer(c, "wrong public-key"), | ||
} | ||
timeout := 300 * time.Millisecond | ||
|
||
dialer := &net.Dialer{Timeout: 100 * time.Millisecond} | ||
best, err := ssh.ReachableHostPort(hostPorts, []string{"public-key"}, dialer, timeout) | ||
c.Check(err, gc.ErrorMatches, "cannot connect to any address: .*") | ||
c.Check(best, gc.Equals, network.HostPort{}) | ||
} | ||
|
||
func (s *SSHReachableHostPortSuite) TestReachableValidPublicKey(c *gc.C) { | ||
hostPorts := []network.HostPort{ | ||
testSSHServer(c, "public-key"), | ||
} | ||
timeout := 300 * time.Millisecond | ||
|
||
dialer := &net.Dialer{Timeout: 100 * time.Millisecond} | ||
best, err := ssh.ReachableHostPort(hostPorts, []string{"public-key"}, dialer, timeout) | ||
c.Check(err, jc.ErrorIsNil) | ||
c.Check(best, gc.Equals, hostPorts[0]) | ||
} | ||
|
||
func (s *SSHReachableHostPortSuite) TestReachableMixedPublicKeys(c *gc.C) { | ||
// One is just closed, one is TCP only, one is SSH but the wrong key, one | ||
// is SSH with the right key | ||
fakeHostPort := closedTCPHostPorts(c, 1)[0] | ||
hostPorts := []network.HostPort{ | ||
fakeHostPort, | ||
testTCPServer(c), | ||
testSSHServer(c, "wrong public-key"), | ||
testSSHServer(c, "public-key"), | ||
} | ||
timeout := 300 * time.Millisecond | ||
dialer := &net.Dialer{Timeout: 100 * time.Millisecond} | ||
best, err := ssh.ReachableHostPort(hostPorts, []string{"public-key"}, dialer, timeout) | ||
c.Check(best, gc.Equals, network.HostPort{}) | ||
c.Check(err, jc.ErrorIsNil) | ||
c.Check(best, jc.DeepEquals, hostPorts[3]) | ||
} | ||
|
||
func (s *SSHReachableHostPortSuite) TestReachableNoPublicKeysPassed(c *gc.C) { | ||
fakeHostPort := closedTCPHostPorts(c, 1)[0] | ||
hostPorts := []network.HostPort{ | ||
fakeHostPort, | ||
testTCPServer(c), | ||
} | ||
timeout := 300 * time.Millisecond | ||
|
||
dialer := &net.Dialer{Timeout: 100 * time.Millisecond} | ||
best, err := ssh.ReachableHostPort(hostPorts, nil, dialer, timeout) | ||
c.Check(err, jc.ErrorIsNil) | ||
c.Check(best, jc.DeepEquals, hostPorts[1]) // the only real listener | ||
} | ||
|
||
func (s *SSHReachableHostPortSuite) TestReachableNoPublicKeysAvailable(c *gc.C) { | ||
fakeHostPort := closedTCPHostPorts(c, 1)[0] | ||
hostPorts := []network.HostPort{ | ||
fakeHostPort, | ||
testTCPServer(c), | ||
} | ||
timeout := 300 * time.Millisecond | ||
|
||
dialer := &net.Dialer{Timeout: 100 * time.Millisecond} | ||
best, err := ssh.ReachableHostPort(hostPorts, []string{"public-key"}, dialer, timeout) | ||
c.Check(err, gc.ErrorMatches, "cannot connect to any address: .*") | ||
c.Check(best, gc.Equals, network.HostPort{}) | ||
} | ||
|
||
const maxTCPPort = 65535 | ||
|
||
// closedTCPHostPorts opens and then immediately closes a bunch of ports and | ||
// saves their port numbers so we're unlikely to find a real listener at that | ||
// address. | ||
func closedTCPHostPorts(c *gc.C, count int) []network.HostPort { | ||
ports := make([]network.HostPort, count) | ||
for i := 0; i < count; i++ { | ||
listener, err := net.Listen("tcp", "127.0.0.1:0") | ||
c.Assert(err, jc.ErrorIsNil) | ||
defer listener.Close() | ||
listenAddress := listener.Addr().String() | ||
port, err := network.ParseHostPort(listenAddress) | ||
c.Assert(err, jc.ErrorIsNil) | ||
ports[i] = *port | ||
} | ||
// By the time we return all the listeners are closed | ||
return ports | ||
} | ||
|
||
// testTCPServer only listens on the socket, but doesn't speak SSH | ||
func testTCPServer(c *gc.C) network.HostPort { | ||
listener, err := net.Listen("tcp", "127.0.0.1:0") | ||
c.Assert(err, jc.ErrorIsNil) | ||
|
||
listenAddress := listener.Addr().String() | ||
hostPort, err := network.ParseHostPort(listenAddress) | ||
c.Assert(err, jc.ErrorIsNil) | ||
c.Logf("listening on %q", hostPort) | ||
|
||
go func() { | ||
conn, _ := listener.Accept() | ||
if conn != nil { | ||
c.Logf("accepted connection on %q from %s", hostPort, conn.RemoteAddr()) | ||
conn.Close() | ||
} | ||
listener.Close() | ||
}() | ||
|
||
return *hostPort | ||
} | ||
|
||
// testSSHServer will listen on the socket and respond with the appropriate | ||
// public key information and then die. | ||
func testSSHServer(c *gc.C, publicKey string) network.HostPort { | ||
listener, err := net.Listen("tcp", "127.0.0.1:0") | ||
c.Assert(err, jc.ErrorIsNil) | ||
|
||
listenAddress := listener.Addr().String() | ||
hostPort, err := network.ParseHostPort(listenAddress) | ||
c.Assert(err, jc.ErrorIsNil) | ||
c.Logf("listening on %q", hostPort) | ||
|
||
go func() { | ||
conn, _ := listener.Accept() | ||
if conn != nil { | ||
c.Logf("accepted connection on %q from %s", hostPort, conn.RemoteAddr()) | ||
conn.Close() | ||
} | ||
listener.Close() | ||
}() | ||
|
||
return *hostPort | ||
} |