Skip to content

Commit

Permalink
Add a timeout around the upstream dialer. Default is 30 seconds.
Browse files Browse the repository at this point in the history
  • Loading branch information
film42 committed Jul 22, 2017
1 parent 4050af0 commit 61826d4
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 3 deletions.
1 change: 1 addition & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type Config struct {
StripProxyHeaders bool `json:"strip_proxy_headers"`
Port int
UseIncomingLocalAddr bool `json:"use_incoming_local_addr"`
DialTimeout int
}

func (config *Config) AuthenticationRequired() bool {
Expand Down
10 changes: 8 additions & 2 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"net"
"net/http"
"strings"
"time"
)

const ProxyAuthenticationRequired = "HTTP/1.0 407 Proxy authentication required\r\n\r\n"
Expand All @@ -23,6 +24,8 @@ type connection struct {
}

func (c *connection) Dial(network, address string) (net.Conn, error) {
timeout := time.Second * time.Duration(config.DialTimeout)

if config.UseIncomingLocalAddr {
if c.localAddr == nil {
logger.Warn.Println(c.id, "Missing local net.Addr: a default local net.Addr will be used")
Expand All @@ -39,7 +42,10 @@ func (c *connection) Dial(network, address string) (net.Conn, error) {
goto fallback
}

dialer := &net.Dialer{LocalAddr: c.localAddr}
dialer := &net.Dialer{
LocalAddr: c.localAddr,
Timeout: timeout,
}

// Try to dial with the incoming LocalAddr to keep the incoming and outgoing IPs the same.
conn, err := dialer.Dial(network, address)
Expand All @@ -54,7 +60,7 @@ func (c *connection) Dial(network, address string) (net.Conn, error) {
}

fallback:
return net.Dial(network, address)
return net.DialTimeout(network, address, timeout)
}

func (c *connection) Handle() {
Expand Down
4 changes: 3 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ func main() {
configPtr := flag.String("config", "", "config file")
portPtr := flag.Int("port", 25000, "listen port")
stripProxyHeadersPtr := flag.Bool("strip-proxy-headers", true, "strip proxy headers from http requests")
useIncomingLocalAddr := flag.Bool("use-incoming-local-addr", true, "Attempt to use the local address of the incoming connection when connecting upstream")
useIncomingLocalAddr := flag.Bool("use-incoming-local-addr", true, "attempt to use the local address of the incoming connection when connecting upstream")
dialTimeoutPtr := flag.Int("dial-timeout", 30, " timeout for connecting to an upstream server in seconds")
usernamePtr := flag.String("username", "", "username for proxy authentication")
passwordPtr := flag.String("password", "", "password for proxy authentication")
flag.Parse()
Expand All @@ -77,6 +78,7 @@ func main() {
Port: *portPtr,
StripProxyHeaders: *stripProxyHeadersPtr,
UseIncomingLocalAddr: *useIncomingLocalAddr,
DialTimeout: *dialTimeoutPtr,
}

if *usernamePtr != "" {
Expand Down

0 comments on commit 61826d4

Please sign in to comment.