From f79544d6499dbcc9a83966372502b328a49aedc6 Mon Sep 17 00:00:00 2001 From: tanghang Date: Fri, 13 Dec 2024 11:33:27 +0800 Subject: [PATCH] feat: support timeout when Dial --- client_unix.go | 15 +++++++++++++++ client_windows.go | 26 ++++++++++++++++++++++++++ 2 files changed, 41 insertions(+) diff --git a/client_unix.go b/client_unix.go index 728b3b615..4533e14a4 100644 --- a/client_unix.go +++ b/client_unix.go @@ -23,6 +23,7 @@ import ( "net" "strconv" "syscall" + "time" "golang.org/x/sync/errgroup" "golang.org/x/sys/unix" @@ -148,6 +149,11 @@ func (cli *Client) Dial(network, address string) (Conn, error) { return cli.DialContext(network, address, nil) } +// DialTimeout is like net.DialTimeout(). +func (cli *Client) DialTimeout(network, address string, timeout time.Duration) (Conn, error) { + return cli.DialContextTimeout(network, address, nil, timeout) +} + // DialContext is like Dial but also accepts an empty interface ctx that can be obtained later via Conn.Context. func (cli *Client) DialContext(network, address string, ctx any) (Conn, error) { c, err := net.Dial(network, address) @@ -157,6 +163,15 @@ func (cli *Client) DialContext(network, address string, ctx any) (Conn, error) { return cli.EnrollContext(c, ctx) } +// DialContextTimeout is like DialContext but also accepts a timeout. +func (cli *Client) DialContextTimeout(network, address string, ctx any, timeout time.Duration) (Conn, error) { + c, err := net.DialTimeout(network, address, timeout) + if err != nil { + return nil, err + } + return cli.EnrollContext(c, ctx) +} + // Enroll converts a net.Conn to gnet.Conn and then adds it into Client. func (cli *Client) Enroll(c net.Conn) (Conn, error) { return cli.EnrollContext(c, nil) diff --git a/client_windows.go b/client_windows.go index 96806414d..46ac7f484 100644 --- a/client_windows.go +++ b/client_windows.go @@ -20,6 +20,7 @@ import ( "os" "path/filepath" "sync" + "time" "golang.org/x/sync/errgroup" @@ -118,6 +119,10 @@ func (cli *Client) Dial(network, addr string) (Conn, error) { return cli.DialContext(network, addr, nil) } +func (cli *Client) DialTimeout(network, addr string, timeout time.Duration) (Conn, error) { + return cli.DialContextTimeout(network, addr, nil, timeout) +} + func (cli *Client) DialContext(network, addr string, ctx any) (Conn, error) { var ( c net.Conn @@ -139,6 +144,27 @@ func (cli *Client) DialContext(network, addr string, ctx any) (Conn, error) { return cli.EnrollContext(c, ctx) } +func (cli *Client) DialContextTimeout(network, addr string, ctx any, timeout time.Duration) (Conn, error) { + var ( + c net.Conn + err error + ) + if network == "unix" { + laddr, _ := net.ResolveUnixAddr(network, unixAddr(addr)) + raddr, _ := net.ResolveUnixAddr(network, addr) + c, err = net.DialUnix(network, laddr, raddr) + if err != nil { + return nil, err + } + } else { + c, err = net.DialTimeout(network, addr, timeout) + if err != nil { + return nil, err + } + } + return cli.EnrollContext(c, ctx) +} + func (cli *Client) Enroll(nc net.Conn) (gc Conn, err error) { return cli.EnrollContext(nc, nil) }