diff --git a/config.go b/config.go index cde7e17..c8b2bac 100644 --- a/config.go +++ b/config.go @@ -18,6 +18,7 @@ type Config struct { StripProxyHeaders bool `json:"strip_proxy_headers"` Port int UseIncomingLocalAddr bool `json:"use_incoming_local_addr"` + DialTimeout int } func (config *Config) AuthenticationRequired() bool { diff --git a/connection.go b/connection.go index 72431f3..4fe5576 100644 --- a/connection.go +++ b/connection.go @@ -10,6 +10,7 @@ import ( "net" "net/http" "strings" + "time" ) const ProxyAuthenticationRequired = "HTTP/1.0 407 Proxy authentication required\r\n\r\n" @@ -23,6 +24,8 @@ type connection struct { } func (c *connection) Dial(network, address string) (net.Conn, error) { + timeout := time.Second * time.Duration(config.DialTimeout) + if config.UseIncomingLocalAddr { if c.localAddr == nil { logger.Warn.Println(c.id, "Missing local net.Addr: a default local net.Addr will be used") @@ -39,7 +42,10 @@ func (c *connection) Dial(network, address string) (net.Conn, error) { goto fallback } - dialer := &net.Dialer{LocalAddr: c.localAddr} + dialer := &net.Dialer{ + LocalAddr: c.localAddr, + Timeout: timeout, + } // Try to dial with the incoming LocalAddr to keep the incoming and outgoing IPs the same. conn, err := dialer.Dial(network, address) @@ -54,7 +60,7 @@ func (c *connection) Dial(network, address string) (net.Conn, error) { } fallback: - return net.Dial(network, address) + return net.DialTimeout(network, address, timeout) } func (c *connection) Handle() { diff --git a/main.go b/main.go index 89cbb81..35e9908 100644 --- a/main.go +++ b/main.go @@ -53,7 +53,8 @@ func main() { configPtr := flag.String("config", "", "config file") portPtr := flag.Int("port", 25000, "listen port") stripProxyHeadersPtr := flag.Bool("strip-proxy-headers", true, "strip proxy headers from http requests") - useIncomingLocalAddr := flag.Bool("use-incoming-local-addr", true, "Attempt to use the local address of the incoming connection when connecting upstream") + useIncomingLocalAddr := flag.Bool("use-incoming-local-addr", true, "attempt to use the local address of the incoming connection when connecting upstream") + dialTimeoutPtr := flag.Int("dial-timeout", 30, " timeout for connecting to an upstream server in seconds") usernamePtr := flag.String("username", "", "username for proxy authentication") passwordPtr := flag.String("password", "", "password for proxy authentication") flag.Parse() @@ -77,6 +78,7 @@ func main() { Port: *portPtr, StripProxyHeaders: *stripProxyHeadersPtr, UseIncomingLocalAddr: *useIncomingLocalAddr, + DialTimeout: *dialTimeoutPtr, } if *usernamePtr != "" {