Skip to content

Commit

Permalink
start factoring out the 'ssh reachable' code.
Browse files Browse the repository at this point in the history
We want to integrate with the golang crypto ssh library so that
we not only check that we can get a TCP connection to the port, but
also so that we can check that there is a valid SSH that is presenting
the right public key on the other side.
Also, our code was causing the goroutines to block indefinitely, as
they'd never be able to send on the channel once we find one that is
correct, so close a done channel to signal they have nothing to do.
  • Loading branch information
jameinel committed Jan 22, 2017
1 parent 4728ad2 commit 37034d8
Show file tree
Hide file tree
Showing 5 changed files with 247 additions and 83 deletions.
38 changes: 0 additions & 38 deletions network/hostport.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"net"
"sort"
"strconv"
"time"

"github.com/juju/errors"
"github.com/juju/utils/set"
Expand Down Expand Up @@ -241,40 +240,3 @@ func UniqueHostPorts(hostPorts []HostPort) []HostPort {

return results
}

// Dialer defines a Dial() method matching the signature of net.Dial().
type Dialer interface {
Dial(network, address string) (net.Conn, error)
}

// ReachableHostPort dials the entries in the given hostPorts, in parallel,
// using the given dialer, closing successfully established connections
// immediately. Individual connection errors are discarded, and an error is
// returned only if none of the hostPorts can be reached when the given timeout
// expires.
//
// Usually, a net.Dialer initialized with a non-empty Timeout field is passed
// for dialer.
func ReachableHostPort(hostPorts []HostPort, dialer Dialer, timeout time.Duration) (HostPort, error) {
uniqueHPs := UniqueHostPorts(hostPorts)
successful := make(chan HostPort, 1)

for _, hostPort := range uniqueHPs {
go func(hp HostPort) {
conn, err := dialer.Dial("tcp", hp.NetAddr())
if err == nil {
conn.Close()
successful <- hp
}
}(hostPort)
}

select {
case result := <-successful:
logger.Infof("dialed %q successfully", result)
return result, nil

case <-time.After(timeout):
return HostPort{}, errors.Errorf("cannot connect to any address: %v", hostPorts)
}
}
45 changes: 0 additions & 45 deletions network/hostport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -501,30 +501,6 @@ func (s *HostPortSuite) TestUniqueHostPortsHugeUniqueInput(c *gc.C) {
c.Assert(results, jc.DeepEquals, expected)
}

func (s *HostPortSuite) TestReachableHostPortAllUnreachable(c *gc.C) {
dialer := &net.Dialer{Timeout: 100 * time.Millisecond}
unreachableHPs := s.manyHostPorts(c, maxTCPPort, nil) // use IANA reserved port
timeout := 300 * time.Millisecond

best, err := network.ReachableHostPort(unreachableHPs, dialer, timeout)
c.Check(err, gc.ErrorMatches, "cannot connect to any address: .*")
c.Check(best, gc.Equals, network.HostPort{})
}

func (s *HostPortSuite) TestReachableHostPortRealDial(c *gc.C) {
fakeHostPort := network.NewHostPorts(1234, "127.0.0.1")[0]
hostPorts := []network.HostPort{
fakeHostPort,
testTCPServer(c),
}
timeout := 300 * time.Millisecond

dialer := &net.Dialer{Timeout: 100 * time.Millisecond}
best, err := network.ReachableHostPort(hostPorts, dialer, timeout)
c.Check(err, jc.ErrorIsNil)
c.Check(best, jc.DeepEquals, hostPorts[1]) // the only real listener
}

const maxTCPPort = 65535

func (s *HostPortSuite) manyHostPorts(c *gc.C, count int, addressFunc func(index int) string) []network.HostPort {
Expand All @@ -542,24 +518,3 @@ func (s *HostPortSuite) manyHostPorts(c *gc.C, count int, addressFunc func(index
}
return results
}

func testTCPServer(c *gc.C) network.HostPort {
listener, err := net.Listen("tcp", "127.0.0.1:0")
c.Assert(err, jc.ErrorIsNil)

listenAddress := listener.Addr().String()
hostPort, err := network.ParseHostPort(listenAddress)
c.Assert(err, jc.ErrorIsNil)
c.Logf("listening on %q", hostPort)

go func() {
conn, _ := listener.Accept()
if conn != nil {
c.Logf("accepted connection on %q from %s", hostPort, conn.RemoteAddr())
conn.Close()
}
listener.Close()
}()

return *hostPort
}
14 changes: 14 additions & 0 deletions network/ssh/package_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// Copyright 2016 Canonical Ltd.
// Licensed under the AGPLv3, see LICENCE file for details.

package ssh_test

import (
"testing"

gc "gopkg.in/check.v1"
)

func TestAll(t *testing.T) {
gc.TestingT(t)
}
65 changes: 65 additions & 0 deletions network/ssh/reachable.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
// Copyright 2014 Canonical Ltd.
// Licensed under the AGPLv3, see LICENCE file for details.

package ssh

import (
"net"
"time"

"golang.org/x/crypto/ssh"
"github.com/juju/errors"
"github.com/juju/juju/network"

"github.com/juju/loggo"
)


var logger = loggo.GetLogger("juju.network.ssh")

// Dialer defines a Dial() method matching the signature of net.Dial().
type Dialer interface {
Dial(network, address string) (net.Conn, error)
}

// ReachableHostPort dials the entries in the given hostPorts, in parallel,
// using the given dialer, closing successfully established connections
// after checking the ssh key. Individual connection errors are discarded, and
// an error is returned only if none of the hostPorts can be reached when the
// given timeout expires.
// If publicKeys is a non empty list, then the SSH host public key will be
// checked. If it is not in the list, then that host is not considered valid.
//
// Usually, a net.Dialer initialized with a non-empty Timeout field is passed
// for dialer.
func ReachableHostPort(hostPorts []network.HostPort, publicKeys []string, dialer Dialer, timeout time.Duration) (network.HostPort, error) {
uniqueHPs := network.UniqueHostPorts(hostPorts)
successful := make(chan network.HostPort, 1)
done := make(chan struct{}, 0)

for _, hostPort := range uniqueHPs {
go func(hp network.HostPort) {
conn, err := dialer.Dial("tcp", hp.NetAddr())
if err == nil {
conn.Close()
select {
case successful <- hp:
return
case <-done:
return
}
}
}(hostPort)
}

select {
case result := <-successful:
logger.Infof("dialed %q successfully", result)
close(done)
return result, nil

case <-time.After(timeout):
close(done)
return network.HostPort{}, errors.Errorf("cannot connect to any address: %v", hostPorts)
}
}
168 changes: 168 additions & 0 deletions network/ssh/reachable_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
// Copyright 2014 Canonical Ltd.
// Licensed under the AGPLv3, see LICENCE file for details.

package ssh_test

import (
"net"
"time"

_ "github.com/juju/errors"
jc "github.com/juju/testing/checkers"
gc "gopkg.in/check.v1"

"github.com/juju/juju/network"
"github.com/juju/juju/network/ssh"
coretesting "github.com/juju/juju/testing"
)

type SSHReachableHostPortSuite struct {
coretesting.BaseSuite
}

var _ = gc.Suite(&SSHReachableHostPortSuite{})

func (s *SSHReachableHostPortSuite) TestAllUnreachable(c *gc.C) {
dialer := &net.Dialer{Timeout: 50 * time.Millisecond}
unreachableHPs := closedTCPHostPorts(c, 10)
timeout := 100 * time.Millisecond

best, err := ssh.ReachableHostPort(unreachableHPs, nil, dialer, timeout)
c.Check(err, gc.ErrorMatches, "cannot connect to any address: .*")
c.Check(best, gc.Equals, network.HostPort{})
}

func (s *SSHReachableHostPortSuite) TestReachableInvalidPublicKey(c *gc.C) {
hostPorts := []network.HostPort{
testSSHServer(c, "wrong public-key"),
}
timeout := 300 * time.Millisecond

dialer := &net.Dialer{Timeout: 100 * time.Millisecond}
best, err := ssh.ReachableHostPort(hostPorts, []string{"public-key"}, dialer, timeout)
c.Check(err, gc.ErrorMatches, "cannot connect to any address: .*")
c.Check(best, gc.Equals, network.HostPort{})
}

func (s *SSHReachableHostPortSuite) TestReachableValidPublicKey(c *gc.C) {
hostPorts := []network.HostPort{
testSSHServer(c, "public-key"),
}
timeout := 300 * time.Millisecond

dialer := &net.Dialer{Timeout: 100 * time.Millisecond}
best, err := ssh.ReachableHostPort(hostPorts, []string{"public-key"}, dialer, timeout)
c.Check(err, jc.ErrorIsNil)
c.Check(best, gc.Equals, hostPorts[0])
}

func (s *SSHReachableHostPortSuite) TestReachableMixedPublicKeys(c *gc.C) {
// One is just closed, one is TCP only, one is SSH but the wrong key, one
// is SSH with the right key
fakeHostPort := closedTCPHostPorts(c, 1)[0]
hostPorts := []network.HostPort{
fakeHostPort,
testTCPServer(c),
testSSHServer(c, "wrong public-key"),
testSSHServer(c, "public-key"),
}
timeout := 300 * time.Millisecond
dialer := &net.Dialer{Timeout: 100 * time.Millisecond}
best, err := ssh.ReachableHostPort(hostPorts, []string{"public-key"}, dialer, timeout)
c.Check(best, gc.Equals, network.HostPort{})
c.Check(err, jc.ErrorIsNil)
c.Check(best, jc.DeepEquals, hostPorts[3])
}

func (s *SSHReachableHostPortSuite) TestReachableNoPublicKeysPassed(c *gc.C) {
fakeHostPort := closedTCPHostPorts(c, 1)[0]
hostPorts := []network.HostPort{
fakeHostPort,
testTCPServer(c),
}
timeout := 300 * time.Millisecond

dialer := &net.Dialer{Timeout: 100 * time.Millisecond}
best, err := ssh.ReachableHostPort(hostPorts, nil, dialer, timeout)
c.Check(err, jc.ErrorIsNil)
c.Check(best, jc.DeepEquals, hostPorts[1]) // the only real listener
}

func (s *SSHReachableHostPortSuite) TestReachableNoPublicKeysAvailable(c *gc.C) {
fakeHostPort := closedTCPHostPorts(c, 1)[0]
hostPorts := []network.HostPort{
fakeHostPort,
testTCPServer(c),
}
timeout := 300 * time.Millisecond

dialer := &net.Dialer{Timeout: 100 * time.Millisecond}
best, err := ssh.ReachableHostPort(hostPorts, []string{"public-key"}, dialer, timeout)
c.Check(err, gc.ErrorMatches, "cannot connect to any address: .*")
c.Check(best, gc.Equals, network.HostPort{})
}

const maxTCPPort = 65535

// closedTCPHostPorts opens and then immediately closes a bunch of ports and
// saves their port numbers so we're unlikely to find a real listener at that
// address.
func closedTCPHostPorts(c *gc.C, count int) []network.HostPort {
ports := make([]network.HostPort, count)
for i := 0; i < count; i++ {
listener, err := net.Listen("tcp", "127.0.0.1:0")
c.Assert(err, jc.ErrorIsNil)
defer listener.Close()
listenAddress := listener.Addr().String()
port, err := network.ParseHostPort(listenAddress)
c.Assert(err, jc.ErrorIsNil)
ports[i] = *port
}
// By the time we return all the listeners are closed
return ports
}

// testTCPServer only listens on the socket, but doesn't speak SSH
func testTCPServer(c *gc.C) network.HostPort {
listener, err := net.Listen("tcp", "127.0.0.1:0")
c.Assert(err, jc.ErrorIsNil)

listenAddress := listener.Addr().String()
hostPort, err := network.ParseHostPort(listenAddress)
c.Assert(err, jc.ErrorIsNil)
c.Logf("listening on %q", hostPort)

go func() {
conn, _ := listener.Accept()
if conn != nil {
c.Logf("accepted connection on %q from %s", hostPort, conn.RemoteAddr())
conn.Close()
}
listener.Close()
}()

return *hostPort
}

// testSSHServer will listen on the socket and respond with the appropriate
// public key information and then die.
func testSSHServer(c *gc.C, publicKey string) network.HostPort {
listener, err := net.Listen("tcp", "127.0.0.1:0")
c.Assert(err, jc.ErrorIsNil)

listenAddress := listener.Addr().String()
hostPort, err := network.ParseHostPort(listenAddress)
c.Assert(err, jc.ErrorIsNil)
c.Logf("listening on %q", hostPort)

go func() {
conn, _ := listener.Accept()
if conn != nil {
c.Logf("accepted connection on %q from %s", hostPort, conn.RemoteAddr())
conn.Close()
}
listener.Close()
}()

return *hostPort
}

0 comments on commit 37034d8

Please sign in to comment.