Skip to content

Commit

Permalink
Add multicast filter
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Nov 4, 2023
1 parent b93db96 commit 150b116
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 0 deletions.
1 change: 1 addition & 0 deletions stack_gvisor.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ func (t *GVisor) Start() error {
if err != nil {
return err
}
linkEndpoint = &LinkEndpointFilter{linkEndpoint, t.tun.CreateVectorisedWriter()}
ipStack, err := newGVisorStack(linkEndpoint)
if err != nil {
return err
Expand Down
52 changes: 52 additions & 0 deletions stack_gvisor_filter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
//go:build with_gvisor

package tun

import (
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/sing/common/bufio"
N "github.com/sagernet/sing/common/network"
)

var _ stack.LinkEndpoint = (*LinkEndpointFilter)(nil)

type LinkEndpointFilter struct {
stack.LinkEndpoint
Writer N.VectorisedWriter
}

func (w *LinkEndpointFilter) Attach(dispatcher stack.NetworkDispatcher) {
w.LinkEndpoint.Attach(&networkDispatcherFilter{dispatcher, w.Writer})
}

var _ stack.NetworkDispatcher = (*networkDispatcherFilter)(nil)

type networkDispatcherFilter struct {
stack.NetworkDispatcher
writer N.VectorisedWriter
}

func (w *networkDispatcherFilter) DeliverNetworkPacket(protocol tcpip.NetworkProtocolNumber, pkt stack.PacketBufferPtr) {
var network header.Network
if protocol == header.IPv4ProtocolNumber {
if headerPackets, loaded := pkt.Data().PullUp(header.IPv4MinimumSize); loaded {
network = header.IPv4(headerPackets)
}
} else {
if headerPackets, loaded := pkt.Data().PullUp(header.IPv6MinimumSize); loaded {
network = header.IPv6(headerPackets)
}
}
if network == nil {
w.NetworkDispatcher.DeliverNetworkPacket(protocol, pkt)
return
}
destination := AddrFromAddress(network.DestinationAddress())
if destination.IsMulticast() || !destination.IsGlobalUnicast() {
_, _ = bufio.WriteVectorised(w.writer, pkt.AsSlices())
return
}
w.NetworkDispatcher.DeliverNetworkPacket(protocol, pkt)
}
7 changes: 7 additions & 0 deletions stack_system.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,10 @@ func (s *System) acceptLoop(listener net.Listener) {
}

func (s *System) processIPv4(packet clashtcpip.IPv4Packet) error {
destination := packet.DestinationIP()
if destination.IsMulticast() || !destination.IsGlobalUnicast() {
return common.Error(s.tun.Write(packet))
}
switch packet.Protocol() {
case clashtcpip.TCP:
return s.processIPv4TCP(packet, packet.Payload())
Expand All @@ -246,6 +250,9 @@ func (s *System) processIPv4(packet clashtcpip.IPv4Packet) error {
}

func (s *System) processIPv6(packet clashtcpip.IPv6Packet) error {
if !packet.DestinationIP().IsGlobalUnicast() {
return common.Error(s.tun.Write(packet))
}
switch packet.Protocol() {
case clashtcpip.TCP:
return s.processIPv6TCP(packet, packet.Payload())
Expand Down

0 comments on commit 150b116

Please sign in to comment.