Skip to content

Commit

Permalink
Add broadcast filter
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Nov 6, 2023
1 parent 1a00992 commit da350ec
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 6 deletions.
12 changes: 12 additions & 0 deletions stack.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package tun

import (
"context"
"encoding/binary"
"net"
"net/netip"

"github.com/sagernet/sing/common/control"
Expand Down Expand Up @@ -52,3 +54,13 @@ func NewStack(
return nil, E.New("unknown stack: ", stack)
}
}

func BroadcastAddr(inet4Address []netip.Prefix) netip.Addr {
if len(inet4Address) == 0 {
return netip.Addr{}
}
prefix := inet4Address[0]
var broadcastAddr [4]byte
binary.BigEndian.PutUint32(broadcastAddr[:], binary.BigEndian.Uint32(prefix.Masked().Addr().AsSlice())|^binary.BigEndian.Uint32(net.CIDRMask(prefix.Bits(), 32)))
return netip.AddrFrom4(broadcastAddr)
}
4 changes: 3 additions & 1 deletion stack_gvisor.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ type GVisor struct {
tunMtu uint32
endpointIndependentNat bool
udpTimeout int64
broadcastAddr netip.Addr
handler Handler
logger logger.Logger
stack *stack.Stack
Expand All @@ -59,6 +60,7 @@ func NewGVisor(
tunMtu: options.MTU,
endpointIndependentNat: options.EndpointIndependentNat,
udpTimeout: options.UDPTimeout,
broadcastAddr: BroadcastAddr(options.Inet4Address),
handler: options.Handler,
logger: options.Logger,
}
Expand All @@ -70,7 +72,7 @@ func (t *GVisor) Start() error {
if err != nil {
return err
}
linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.tun.CreateVectorisedWriter()}
linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.broadcastAddr, t.tun.CreateVectorisedWriter()}
ipStack, err := newGVisorStack(linkEndpoint)
if err != nil {
return err
Expand Down
12 changes: 8 additions & 4 deletions stack_gvisor_filter.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
package tun

import (
"net/netip"

"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
Expand All @@ -14,18 +16,20 @@ var _ stack.LinkEndpoint = (*LinkEndpointFilter)(nil)

type LinkEndpointFilter struct {
stack.LinkEndpoint
Writer N.VectorisedWriter
BroadcastAddress netip.Addr
Writer N.VectorisedWriter
}

func (w *LinkEndpointFilter) Attach(dispatcher stack.NetworkDispatcher) {
w.LinkEndpoint.Attach(&networkDispatcherFilter{dispatcher, w.Writer})
w.LinkEndpoint.Attach(&networkDispatcherFilter{dispatcher, w.BroadcastAddress, w.Writer})
}

var _ stack.NetworkDispatcher = (*networkDispatcherFilter)(nil)

type networkDispatcherFilter struct {
stack.NetworkDispatcher
writer N.VectorisedWriter
broadcastAddress netip.Addr
writer N.VectorisedWriter
}

func (w *networkDispatcherFilter) DeliverNetworkPacket(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBufferPtr) {
Expand All @@ -44,7 +48,7 @@ func (w *networkDispatcherFilter) DeliverNetworkPacket(protocol tcpip.NetworkPro
return
}
destination := AddrFromAddress(network.DestinationAddress())
if destination.IsMulticast() || !destination.IsGlobalUnicast() {
if destination == w.broadcastAddress || !destination.IsGlobalUnicast() {
_, _ = bufio.WriteVectorised(w.writer, pkt.AsSlices())
return
}
Expand Down
4 changes: 3 additions & 1 deletion stack_system.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ type System struct {
inet4Address netip.Addr
inet6ServerAddress netip.Addr
inet6Address netip.Addr
broadcastAddr netip.Addr
udpTimeout int64
tcpListener net.Listener
tcpListener6 net.Listener
Expand Down Expand Up @@ -60,6 +61,7 @@ func NewSystem(options StackOptions) (Stack, error) {
logger: options.Logger,
inet4Prefixes: options.Inet4Address,
inet6Prefixes: options.Inet6Address,
broadcastAddr: BroadcastAddr(options.Inet4Address),
bindInterface: options.ForwarderBindInterface,
interfaceFinder: options.InterfaceFinder,
}
Expand Down Expand Up @@ -234,7 +236,7 @@ func (s *System) acceptLoop(listener net.Listener) {

func (s *System) processIPv4(packet clashtcpip.IPv4Packet) error {
destination := packet.DestinationIP()
if destination.IsMulticast() || !destination.IsGlobalUnicast() {
if destination == s.broadcastAddr || !destination.IsGlobalUnicast() {
return common.Error(s.tun.Write(packet))
}
switch packet.Protocol() {
Expand Down

0 comments on commit da350ec

Please sign in to comment.