diff --git a/nats.go b/nats.go index fbb86c031..b3ba1f7b6 100644 --- a/nats.go +++ b/nats.go @@ -123,6 +123,12 @@ type asyncCB func() // Option is a function on the options for a connection. type Option func(*Options) error +// CustomDialer can be used to specify any dialer, not necessarily +// a *net.Dialer. +type CustomDialer interface { + Dial(network, address string) (net.Conn, error) +} + // Options can be used to create a customized connection. type Options struct { @@ -225,9 +231,14 @@ type Options struct { // Token sets the token to be used when connecting to a server. Token string - // Dialer allows a custom Dialer when forming connections. + // Dialer allows a custom net.Dialer when forming connections. + // DEPRECATED: should use CustomDialer instead. Dialer *net.Dialer + // CustomDialer allows to specify a custom dialer (not necessarily + // a *net.Dialer). + CustomDialer CustomDialer + // UseOldRequestStyle forces the old method of Requests that utilize // a new Inbox and a new Subscription for each request. UseOldRequestStyle bool @@ -586,6 +597,7 @@ func Token(token string) Option { // Dialer is an Option to set the dialer which will be used when // attempting to establish a connection. +// DEPRECATED: Should use CustomDialer instead. func Dialer(dialer *net.Dialer) Option { return func(o *Options) error { o.Dialer = dialer @@ -593,6 +605,16 @@ func Dialer(dialer *net.Dialer) Option { } } +// SetCustomDialer is an Option to set a custom dialer which will be +// used when attempting to establish a connection. If both Dialer +// and CustomDialer are specified, CustomDialer takes precedence. +func SetCustomDialer(dialer CustomDialer) Option { + return func(o *Options) error { + o.CustomDialer = dialer + return nil + } +} + // UseOldRequestyStyle is an Option to force usage of the old Request style. func UseOldRequestStyle() Option { return func(o *Options) error { @@ -877,7 +899,13 @@ func (nc *Conn) createConn() (err error) { cur.lastAttempt = time.Now() } - dialer := nc.Opts.Dialer + // CustomDialer takes precedence. If not set, use Opts.Dialer which + // is set to a default *net.Dialer (in Connect()) if not explicitly + // set by the user. + dialer := nc.Opts.CustomDialer + if dialer == nil { + dialer = nc.Opts.Dialer + } nc.conn, err = dialer.Dial("tcp", nc.url.Host) if err != nil { return err diff --git a/test/conn_test.go b/test/conn_test.go index fa338bb9a..6d2a77d47 100644 --- a/test/conn_test.go +++ b/test/conn_test.go @@ -1266,6 +1266,15 @@ func TestNoRaceOnLastError(t *testing.T) { } } +type customDialer struct { + ch chan bool +} + +func (cd *customDialer) Dial(network, address string) (net.Conn, error) { + cd.ch <- true + return nil, fmt.Errorf("on purpose") +} + func TestUseCustomDialer(t *testing.T) { s := RunDefaultServer() defer s.Shutdown() @@ -1310,6 +1319,51 @@ func TestUseCustomDialer(t *testing.T) { if nc3.Opts.Dialer.Timeout != nats.DefaultTimeout { t.Fatalf("Expected DialTimeout to be set to %v, got %v", nats.DefaultTimeout, nc.Opts.Dialer.Timeout) } + + // Create custom dialer that return error on Dial(). + cdialer := &customDialer{ch: make(chan bool, 1)} + + // When both Dialer and CustomDialer are set, CustomDialer + // should take precedence. That means that the connection + // should fail for these two set of options. + options := []*nats.Options{ + &nats.Options{Dialer: dialer, CustomDialer: cdialer}, + &nats.Options{CustomDialer: cdialer}, + } + for _, o := range options { + o.Servers = []string{nats.DefaultURL} + nc, err := o.Connect() + // As of now, Connect() would not return the actual dialer error, + // instead it returns "no server available for connections". + // So use go channel to ensure that custom dialer's Dial() method + // was invoked. + if err == nil { + if nc != nil { + nc.Close() + } + t.Fatal("Expected error, got none") + } + if err := Wait(cdialer.ch); err != nil { + t.Fatal("Did not get our notification") + } + } + // Same with variadic + foptions := [][]nats.Option{ + []nats.Option{nats.Dialer(dialer), nats.SetCustomDialer(cdialer)}, + []nats.Option{nats.SetCustomDialer(cdialer)}, + } + for _, fos := range foptions { + nc, err := nats.Connect(nats.DefaultURL, fos...) + if err == nil { + if nc != nil { + nc.Close() + } + t.Fatal("Expected error, got none") + } + if err := Wait(cdialer.ch); err != nil { + t.Fatal("Did not get our notification") + } + } } func TestDefaultOptionsDialer(t *testing.T) {