From e7c24976bf265f7fed72131150a61da53c0de599 Mon Sep 17 00:00:00 2001 From: mzz <2017@duck.com> Date: Wed, 19 Feb 2025 17:08:53 +0800 Subject: [PATCH] fix: panic due to goroutine setting returned error (#748) --- common/netutils/context_dialer.go | 38 ---------------- common/netutils/dns.go | 19 ++++++++ common/netutils/ip46.go | 45 +++++++------------ common/netutils/ip46_test.go | 7 +-- component/dns/upstream.go | 5 +-- .../outbound/dialer/connectivity_check.go | 12 ++--- component/sniffing/sniffer.go | 2 +- control/control_plane.go | 2 +- control/dns.go | 2 +- 9 files changed, 48 insertions(+), 84 deletions(-) delete mode 100644 common/netutils/context_dialer.go diff --git a/common/netutils/context_dialer.go b/common/netutils/context_dialer.go deleted file mode 100644 index 99af3e6e00..0000000000 --- a/common/netutils/context_dialer.go +++ /dev/null @@ -1,38 +0,0 @@ -/* - * SPDX-License-Identifier: AGPL-3.0-only - * Copyright (c) 2022-2024, daeuniverse Organization - */ - -package netutils - -import ( - "context" - "net" -) - -type ContextDialer struct { - Dialer net.Dialer -} - -func (d *ContextDialer) DialContext(ctx context.Context, network, addr string) (c net.Conn, err error) { - var done = make(chan struct{}) - go func() { - c, err = d.Dialer.Dial(network, addr) - if err != nil { - close(done) - return - } - select { - case <-ctx.Done(): - _ = c.Close() - default: - close(done) - } - }() - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-done: - return c, err - } -} diff --git a/common/netutils/dns.go b/common/netutils/dns.go index fdb3eb6dff..e65022b9c0 100644 --- a/common/netutils/dns.go +++ b/common/netutils/dns.go @@ -140,6 +140,25 @@ func ResolveNS(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host return records, nil } +func ResolveSOA(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, network string) (records []string, err error) { + typ := dnsmessage.TypeSOA + resources, err := resolve(ctx, d, dns, host, typ, network) + if err != nil { + return nil, err + } + for _, ans := range resources { + if ans.Header().Rrtype != typ { + continue + } + ns, ok := ans.(*dnsmessage.SOA) + if !ok { + return nil, ErrBadDnsAns + } + records = append(records, ns.Ns) + } + return records, nil +} + func resolve(ctx context.Context, d netproxy.Dialer, dns netip.AddrPort, host string, typ uint16, network string) (ans []dnsmessage.RR, err error) { ctx, cancel := context.WithCancel(ctx) defer cancel() diff --git a/common/netutils/ip46.go b/common/netutils/ip46.go index ca9c28d79b..10171687a6 100644 --- a/common/netutils/ip46.go +++ b/common/netutils/ip46.go @@ -8,7 +8,6 @@ package netutils import ( "context" "errors" - "fmt" "net/netip" "sync" @@ -22,24 +21,22 @@ type Ip46 struct { Ip6 netip.Addr } -func ResolveIp46(ctx context.Context, dialer netproxy.Dialer, dns netip.AddrPort, host string, network string, race bool) (ipv46 *Ip46, err error) { +func ResolveIp46(ctx context.Context, dialer netproxy.Dialer, dns netip.AddrPort, host string, network string, race bool) (ipv46 *Ip46, err4, err6 error) { var log *logrus.Logger if _log := ctx.Value("logger"); _log != nil { log = _log.(*logrus.Logger) defer func() { - if err == nil { - log.Tracef("ResolveIp46 %v using %v: A(%v) AAAA(%v)", host, systemDns, ipv46.Ip4, ipv46.Ip6) - } else { - log.Tracef("ResolveIp46 %v using %v: %v", host, systemDns, err) - } + log.WithField("err4", err4). + WithField("err6", err6). + Tracef("ResolveIp46 %v using %v: A(%v) AAAA(%v)", host, systemDns, ipv46.Ip4, ipv46.Ip6) }() } var wg sync.WaitGroup wg.Add(2) - var err4, err6 error var addrs4, addrs6 []netip.Addr ctx4, cancel4 := context.WithCancel(ctx) ctx6, cancel6 := context.WithCancel(ctx) + var _err4, _err6 error go func() { defer func() { wg.Done() @@ -50,13 +47,10 @@ func ResolveIp46(ctx context.Context, dialer netproxy.Dialer, dns netip.AddrPort }() var e error addrs4, e = ResolveNetip(ctx4, dialer, dns, host, dnsmessage.TypeA, network) - if err != nil && !errors.Is(e, context.Canceled) { - err4 = e + if e != nil && !errors.Is(e, context.Canceled) { + _err4 = e return } - if len(addrs4) == 0 { - addrs4 = []netip.Addr{{}} - } }() go func() { defer func() { @@ -68,27 +62,18 @@ func ResolveIp46(ctx context.Context, dialer netproxy.Dialer, dns netip.AddrPort }() var e error addrs6, e = ResolveNetip(ctx6, dialer, dns, host, dnsmessage.TypeAAAA, network) - if err != nil && !errors.Is(e, context.Canceled) { + if e != nil && !errors.Is(e, context.Canceled) { err6 = e return } - if len(addrs6) == 0 { - addrs6 = []netip.Addr{{}} - } }() wg.Wait() - if err4 != nil || err6 != nil { - if err4 != nil && err6 != nil { - return nil, fmt.Errorf("%w: %v", err4, err6) - } - if err4 != nil { - return nil, err4 - } else { - return nil, err6 - } + ipv46 = &Ip46{} + if len(addrs4) != 0 { + ipv46.Ip4 = addrs4[0] + } + if len(addrs6) != 0 { + ipv46.Ip6 = addrs6[0] } - return &Ip46{ - Ip4: addrs4[0], - Ip6: addrs6[0], - }, nil + return ipv46, _err4, _err6 } diff --git a/common/netutils/ip46_test.go b/common/netutils/ip46_test.go index 20724fb8e9..59943417fc 100644 --- a/common/netutils/ip46_test.go +++ b/common/netutils/ip46_test.go @@ -17,11 +17,12 @@ import ( func TestResolveIp46(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - ip46, err := ResolveIp46(ctx, direct.SymmetricDirect, netip.MustParseAddrPort("223.5.5.5:53"), "www.apple.com", "udp", false) - if err != nil { - t.Fatal(err) + ip46, err4, err6 := ResolveIp46(ctx, direct.SymmetricDirect, netip.MustParseAddrPort("223.5.5.5:53"), "ipv6.google.com", "udp", false) + if err4 != nil || err6 != nil { + t.Fatal(err4, err6) } if !ip46.Ip4.IsValid() && !ip46.Ip6.IsValid() { t.Fatal("No record") } + t.Log(ip46) } diff --git a/component/dns/upstream.go b/component/dns/upstream.go index 0adee2fe19..59ecfa797a 100644 --- a/component/dns/upstream.go +++ b/component/dns/upstream.go @@ -112,10 +112,7 @@ func NewUpstream(ctx context.Context, upstream *url.URL, resolverNetwork string) } }() - ip46, err := netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, hostname, resolverNetwork, false) - if err != nil { - return nil, fmt.Errorf("failed to resolve dns_upstream: %w", err) - } + ip46, _, _ := netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, hostname, resolverNetwork, false) if !ip46.Ip4.IsValid() && !ip46.Ip6.IsValid() { return nil, fmt.Errorf("dns_upstream %v has no record", upstream.String()) } diff --git a/component/outbound/dialer/connectivity_check.go b/component/outbound/dialer/connectivity_check.go index a725a96273..d06987cd0c 100644 --- a/component/outbound/dialer/connectivity_check.go +++ b/component/outbound/dialer/connectivity_check.go @@ -160,9 +160,9 @@ func ParseTcpCheckOption(ctx context.Context, rawURL []string, method string, re if len(rawURL) > 1 { ip46 = parseIp46FromList(rawURL[1:]) } else { - ip46, err = netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, u.Hostname(), resolverNetwork, false) - if err != nil { - return nil, err + ip46, _, _ = netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, u.Hostname(), resolverNetwork, false) + if !ip46.Ip4.IsValid() && !ip46.Ip6.IsValid() { + return nil, fmt.Errorf("ResolveIp46: no valid ip for %v", u.Hostname()) } } return &TcpCheckOption{ @@ -205,9 +205,9 @@ func ParseCheckDnsOption(ctx context.Context, dnsHostPort []string, resolverNetw if len(dnsHostPort) > 1 { ip46 = parseIp46FromList(dnsHostPort[1:]) } else { - ip46, err = netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, host, resolverNetwork, false) - if err != nil { - return nil, err + ip46, _, _ = netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, host, resolverNetwork, false) + if !ip46.Ip4.IsValid() && !ip46.Ip6.IsValid() { + return nil, fmt.Errorf("ResolveIp46: no valid ip for %v", host) } } return &CheckDnsOption{ diff --git a/component/sniffing/sniffer.go b/component/sniffing/sniffer.go index 3bc94afe94..454daa682e 100644 --- a/component/sniffing/sniffer.go +++ b/component/sniffing/sniffer.go @@ -107,7 +107,7 @@ func (s *Sniffer) SniffTcp() (d string, err error) { if s.stream { go func() { // Read once. - _, err = s.buf.ReadFromOnce(s.r) + _, err := s.buf.ReadFromOnce(s.r) if err != nil { s.dataError = err } diff --git a/control/control_plane.go b/control/control_plane.go index d534db9504..885833d7b6 100644 --- a/control/control_plane.go +++ b/control/control_plane.go @@ -662,7 +662,7 @@ func (c *ControlPlane) ChooseDialTarget(outbound consts.OutboundIndex, dst netip // TODO: use DNS controller and re-route by control plane. systemDns, err := netutils.SystemDns() if err == nil { - if ip46, err := netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, domain, common.MagicNetwork("udp", c.soMarkFromDae, c.mptcp), true); err == nil && (ip46.Ip4.IsValid() || ip46.Ip6.IsValid()) { + if ip46, _, _ := netutils.ResolveIp46(ctx, direct.SymmetricDirect, systemDns, domain, common.MagicNetwork("udp", c.soMarkFromDae, c.mptcp), true); ip46.Ip4.IsValid() || ip46.Ip6.IsValid() { // Has A/AAAA records. It is a real domain. dialMode = consts.DialMode_Domain // Add it to real-domain set. diff --git a/control/dns.go b/control/dns.go index 1884b17734..c8f0e85da7 100644 --- a/control/dns.go +++ b/control/dns.go @@ -317,7 +317,7 @@ func (d *DoUDP) ForwardDNS(ctx context.Context, data []byte) (*dnsmessage.Msg, e go func() { // Send DNS request every seconds. for { - _, err = conn.Write(data) + _, _ = conn.Write(data) // if err != nil { // if c.log.IsLevelEnabled(logrus.DebugLevel) { // c.log.WithFields(logrus.Fields{