From 37034d884bd7b7aca7f94bb5684ce813a51c3be1 Mon Sep 17 00:00:00 2001 From: John Arbash Meinel Date: Mon, 23 Jan 2017 00:02:31 +0400 Subject: [PATCH] start factoring out the 'ssh reachable' code. We want to integrate with the golang crypto ssh library so that we not only check that we can get a TCP connection to the port, but also so that we can check that there is a valid SSH that is presenting the right public key on the other side. Also, our code was causing the goroutines to block indefinitely, as they'd never be able to send on the channel once we find one that is correct, so close a done channel to signal they have nothing to do. --- network/hostport.go | 38 -------- network/hostport_test.go | 45 --------- network/ssh/package_test.go | 14 +++ network/ssh/reachable.go | 65 +++++++++++++ network/ssh/reachable_test.go | 168 ++++++++++++++++++++++++++++++++++ 5 files changed, 247 insertions(+), 83 deletions(-) create mode 100644 network/ssh/package_test.go create mode 100644 network/ssh/reachable.go create mode 100644 network/ssh/reachable_test.go diff --git a/network/hostport.go b/network/hostport.go index 731467646a6..c45ba57593a 100644 --- a/network/hostport.go +++ b/network/hostport.go @@ -7,7 +7,6 @@ import ( "net" "sort" "strconv" - "time" "github.com/juju/errors" "github.com/juju/utils/set" @@ -241,40 +240,3 @@ func UniqueHostPorts(hostPorts []HostPort) []HostPort { return results } - -// Dialer defines a Dial() method matching the signature of net.Dial(). -type Dialer interface { - Dial(network, address string) (net.Conn, error) -} - -// ReachableHostPort dials the entries in the given hostPorts, in parallel, -// using the given dialer, closing successfully established connections -// immediately. Individual connection errors are discarded, and an error is -// returned only if none of the hostPorts can be reached when the given timeout -// expires. -// -// Usually, a net.Dialer initialized with a non-empty Timeout field is passed -// for dialer. -func ReachableHostPort(hostPorts []HostPort, dialer Dialer, timeout time.Duration) (HostPort, error) { - uniqueHPs := UniqueHostPorts(hostPorts) - successful := make(chan HostPort, 1) - - for _, hostPort := range uniqueHPs { - go func(hp HostPort) { - conn, err := dialer.Dial("tcp", hp.NetAddr()) - if err == nil { - conn.Close() - successful <- hp - } - }(hostPort) - } - - select { - case result := <-successful: - logger.Infof("dialed %q successfully", result) - return result, nil - - case <-time.After(timeout): - return HostPort{}, errors.Errorf("cannot connect to any address: %v", hostPorts) - } -} diff --git a/network/hostport_test.go b/network/hostport_test.go index 9b6a4bccfd1..2b9d769bf11 100644 --- a/network/hostport_test.go +++ b/network/hostport_test.go @@ -501,30 +501,6 @@ func (s *HostPortSuite) TestUniqueHostPortsHugeUniqueInput(c *gc.C) { c.Assert(results, jc.DeepEquals, expected) } -func (s *HostPortSuite) TestReachableHostPortAllUnreachable(c *gc.C) { - dialer := &net.Dialer{Timeout: 100 * time.Millisecond} - unreachableHPs := s.manyHostPorts(c, maxTCPPort, nil) // use IANA reserved port - timeout := 300 * time.Millisecond - - best, err := network.ReachableHostPort(unreachableHPs, dialer, timeout) - c.Check(err, gc.ErrorMatches, "cannot connect to any address: .*") - c.Check(best, gc.Equals, network.HostPort{}) -} - -func (s *HostPortSuite) TestReachableHostPortRealDial(c *gc.C) { - fakeHostPort := network.NewHostPorts(1234, "127.0.0.1")[0] - hostPorts := []network.HostPort{ - fakeHostPort, - testTCPServer(c), - } - timeout := 300 * time.Millisecond - - dialer := &net.Dialer{Timeout: 100 * time.Millisecond} - best, err := network.ReachableHostPort(hostPorts, dialer, timeout) - c.Check(err, jc.ErrorIsNil) - c.Check(best, jc.DeepEquals, hostPorts[1]) // the only real listener -} - const maxTCPPort = 65535 func (s *HostPortSuite) manyHostPorts(c *gc.C, count int, addressFunc func(index int) string) []network.HostPort { @@ -542,24 +518,3 @@ func (s *HostPortSuite) manyHostPorts(c *gc.C, count int, addressFunc func(index } return results } - -func testTCPServer(c *gc.C) network.HostPort { - listener, err := net.Listen("tcp", "127.0.0.1:0") - c.Assert(err, jc.ErrorIsNil) - - listenAddress := listener.Addr().String() - hostPort, err := network.ParseHostPort(listenAddress) - c.Assert(err, jc.ErrorIsNil) - c.Logf("listening on %q", hostPort) - - go func() { - conn, _ := listener.Accept() - if conn != nil { - c.Logf("accepted connection on %q from %s", hostPort, conn.RemoteAddr()) - conn.Close() - } - listener.Close() - }() - - return *hostPort -} diff --git a/network/ssh/package_test.go b/network/ssh/package_test.go new file mode 100644 index 00000000000..036da5d8e8c --- /dev/null +++ b/network/ssh/package_test.go @@ -0,0 +1,14 @@ +// Copyright 2016 Canonical Ltd. +// Licensed under the AGPLv3, see LICENCE file for details. + +package ssh_test + +import ( + "testing" + + gc "gopkg.in/check.v1" +) + +func TestAll(t *testing.T) { + gc.TestingT(t) +} diff --git a/network/ssh/reachable.go b/network/ssh/reachable.go new file mode 100644 index 00000000000..aa93de7271a --- /dev/null +++ b/network/ssh/reachable.go @@ -0,0 +1,65 @@ +// Copyright 2014 Canonical Ltd. +// Licensed under the AGPLv3, see LICENCE file for details. + +package ssh + +import ( + "net" + "time" + + "golang.org/x/crypto/ssh" + "github.com/juju/errors" + "github.com/juju/juju/network" + + "github.com/juju/loggo" +) + + +var logger = loggo.GetLogger("juju.network.ssh") + +// Dialer defines a Dial() method matching the signature of net.Dial(). +type Dialer interface { + Dial(network, address string) (net.Conn, error) +} + +// ReachableHostPort dials the entries in the given hostPorts, in parallel, +// using the given dialer, closing successfully established connections +// after checking the ssh key. Individual connection errors are discarded, and +// an error is returned only if none of the hostPorts can be reached when the +// given timeout expires. +// If publicKeys is a non empty list, then the SSH host public key will be +// checked. If it is not in the list, then that host is not considered valid. +// +// Usually, a net.Dialer initialized with a non-empty Timeout field is passed +// for dialer. +func ReachableHostPort(hostPorts []network.HostPort, publicKeys []string, dialer Dialer, timeout time.Duration) (network.HostPort, error) { + uniqueHPs := network.UniqueHostPorts(hostPorts) + successful := make(chan network.HostPort, 1) + done := make(chan struct{}, 0) + + for _, hostPort := range uniqueHPs { + go func(hp network.HostPort) { + conn, err := dialer.Dial("tcp", hp.NetAddr()) + if err == nil { + conn.Close() + select { + case successful <- hp: + return + case <-done: + return + } + } + }(hostPort) + } + + select { + case result := <-successful: + logger.Infof("dialed %q successfully", result) + close(done) + return result, nil + + case <-time.After(timeout): + close(done) + return network.HostPort{}, errors.Errorf("cannot connect to any address: %v", hostPorts) + } +} diff --git a/network/ssh/reachable_test.go b/network/ssh/reachable_test.go new file mode 100644 index 00000000000..1f2a3e283a5 --- /dev/null +++ b/network/ssh/reachable_test.go @@ -0,0 +1,168 @@ +// Copyright 2014 Canonical Ltd. +// Licensed under the AGPLv3, see LICENCE file for details. + +package ssh_test + +import ( + "net" + "time" + + _ "github.com/juju/errors" + jc "github.com/juju/testing/checkers" + gc "gopkg.in/check.v1" + + "github.com/juju/juju/network" + "github.com/juju/juju/network/ssh" + coretesting "github.com/juju/juju/testing" +) + +type SSHReachableHostPortSuite struct { + coretesting.BaseSuite +} + +var _ = gc.Suite(&SSHReachableHostPortSuite{}) + +func (s *SSHReachableHostPortSuite) TestAllUnreachable(c *gc.C) { + dialer := &net.Dialer{Timeout: 50 * time.Millisecond} + unreachableHPs := closedTCPHostPorts(c, 10) + timeout := 100 * time.Millisecond + + best, err := ssh.ReachableHostPort(unreachableHPs, nil, dialer, timeout) + c.Check(err, gc.ErrorMatches, "cannot connect to any address: .*") + c.Check(best, gc.Equals, network.HostPort{}) +} + +func (s *SSHReachableHostPortSuite) TestReachableInvalidPublicKey(c *gc.C) { + hostPorts := []network.HostPort{ + testSSHServer(c, "wrong public-key"), + } + timeout := 300 * time.Millisecond + + dialer := &net.Dialer{Timeout: 100 * time.Millisecond} + best, err := ssh.ReachableHostPort(hostPorts, []string{"public-key"}, dialer, timeout) + c.Check(err, gc.ErrorMatches, "cannot connect to any address: .*") + c.Check(best, gc.Equals, network.HostPort{}) +} + +func (s *SSHReachableHostPortSuite) TestReachableValidPublicKey(c *gc.C) { + hostPorts := []network.HostPort{ + testSSHServer(c, "public-key"), + } + timeout := 300 * time.Millisecond + + dialer := &net.Dialer{Timeout: 100 * time.Millisecond} + best, err := ssh.ReachableHostPort(hostPorts, []string{"public-key"}, dialer, timeout) + c.Check(err, jc.ErrorIsNil) + c.Check(best, gc.Equals, hostPorts[0]) +} + +func (s *SSHReachableHostPortSuite) TestReachableMixedPublicKeys(c *gc.C) { + // One is just closed, one is TCP only, one is SSH but the wrong key, one + // is SSH with the right key + fakeHostPort := closedTCPHostPorts(c, 1)[0] + hostPorts := []network.HostPort{ + fakeHostPort, + testTCPServer(c), + testSSHServer(c, "wrong public-key"), + testSSHServer(c, "public-key"), + } + timeout := 300 * time.Millisecond + dialer := &net.Dialer{Timeout: 100 * time.Millisecond} + best, err := ssh.ReachableHostPort(hostPorts, []string{"public-key"}, dialer, timeout) + c.Check(best, gc.Equals, network.HostPort{}) + c.Check(err, jc.ErrorIsNil) + c.Check(best, jc.DeepEquals, hostPorts[3]) +} + +func (s *SSHReachableHostPortSuite) TestReachableNoPublicKeysPassed(c *gc.C) { + fakeHostPort := closedTCPHostPorts(c, 1)[0] + hostPorts := []network.HostPort{ + fakeHostPort, + testTCPServer(c), + } + timeout := 300 * time.Millisecond + + dialer := &net.Dialer{Timeout: 100 * time.Millisecond} + best, err := ssh.ReachableHostPort(hostPorts, nil, dialer, timeout) + c.Check(err, jc.ErrorIsNil) + c.Check(best, jc.DeepEquals, hostPorts[1]) // the only real listener +} + +func (s *SSHReachableHostPortSuite) TestReachableNoPublicKeysAvailable(c *gc.C) { + fakeHostPort := closedTCPHostPorts(c, 1)[0] + hostPorts := []network.HostPort{ + fakeHostPort, + testTCPServer(c), + } + timeout := 300 * time.Millisecond + + dialer := &net.Dialer{Timeout: 100 * time.Millisecond} + best, err := ssh.ReachableHostPort(hostPorts, []string{"public-key"}, dialer, timeout) + c.Check(err, gc.ErrorMatches, "cannot connect to any address: .*") + c.Check(best, gc.Equals, network.HostPort{}) +} + +const maxTCPPort = 65535 + +// closedTCPHostPorts opens and then immediately closes a bunch of ports and +// saves their port numbers so we're unlikely to find a real listener at that +// address. +func closedTCPHostPorts(c *gc.C, count int) []network.HostPort { + ports := make([]network.HostPort, count) + for i := 0; i < count; i++ { + listener, err := net.Listen("tcp", "127.0.0.1:0") + c.Assert(err, jc.ErrorIsNil) + defer listener.Close() + listenAddress := listener.Addr().String() + port, err := network.ParseHostPort(listenAddress) + c.Assert(err, jc.ErrorIsNil) + ports[i] = *port + } + // By the time we return all the listeners are closed + return ports +} + +// testTCPServer only listens on the socket, but doesn't speak SSH +func testTCPServer(c *gc.C) network.HostPort { + listener, err := net.Listen("tcp", "127.0.0.1:0") + c.Assert(err, jc.ErrorIsNil) + + listenAddress := listener.Addr().String() + hostPort, err := network.ParseHostPort(listenAddress) + c.Assert(err, jc.ErrorIsNil) + c.Logf("listening on %q", hostPort) + + go func() { + conn, _ := listener.Accept() + if conn != nil { + c.Logf("accepted connection on %q from %s", hostPort, conn.RemoteAddr()) + conn.Close() + } + listener.Close() + }() + + return *hostPort +} + +// testSSHServer will listen on the socket and respond with the appropriate +// public key information and then die. +func testSSHServer(c *gc.C, publicKey string) network.HostPort { + listener, err := net.Listen("tcp", "127.0.0.1:0") + c.Assert(err, jc.ErrorIsNil) + + listenAddress := listener.Addr().String() + hostPort, err := network.ParseHostPort(listenAddress) + c.Assert(err, jc.ErrorIsNil) + c.Logf("listening on %q", hostPort) + + go func() { + conn, _ := listener.Accept() + if conn != nil { + c.Logf("accepted connection on %q from %s", hostPort, conn.RemoteAddr()) + conn.Close() + } + listener.Close() + }() + + return *hostPort +}