Skip to content

Commit 6a8c071

Browse files
committed
Add Rule.Equal method.
This adds an implementation of Rule.Equal (and RulePortRange.Equal, RuleUIDRange.Equal) which allows us to compare two Rule objects.
1 parent 20a4b9a commit 6a8c071

File tree

3 files changed

+148
-30
lines changed

3 files changed

+148
-30
lines changed

rule.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,47 @@ type Rule struct {
3131
Type uint8
3232
}
3333

34+
func (r Rule) Equal(x Rule) bool {
35+
return r.Table == x.Table &&
36+
((r.Src == nil && x.Src == nil) ||
37+
(r.Src != nil && x.Src != nil && r.Src.String() == x.Src.String())) &&
38+
((r.Dst == nil && x.Dst == nil) ||
39+
(r.Dst != nil && x.Dst != nil && r.Dst.String() == x.Dst.String())) &&
40+
r.OifName == x.OifName &&
41+
r.Priority == x.Priority &&
42+
r.Family == x.Family &&
43+
r.IifName == x.IifName &&
44+
r.Invert == x.Invert &&
45+
r.Tos == x.Tos &&
46+
r.Type == x.Type &&
47+
r.IPProto == x.IPProto &&
48+
r.Protocol == x.Protocol &&
49+
r.Mark == x.Mark &&
50+
// For non-zero marks, mask defaults to 0xFFFFFFFF if not set. So if either mask is nil
51+
// while the other is 0xFFFFFFFF when mark is non-zero, treat the masks as identical.
52+
// See kernel source: https://github.com/torvalds/linux/blob/v6.15/net/core/fib_rules.c#L624
53+
(ptrEqual(r.Mask, x.Mask) || (r.Mark != 0 &&
54+
(r.Mask == nil && *x.Mask == 0xFFFFFFFF || x.Mask == nil && *r.Mask == 0xFFFFFFFF))) &&
55+
r.TunID == x.TunID &&
56+
r.Goto == x.Goto &&
57+
r.Flow == x.Flow &&
58+
r.SuppressIfgroup == x.SuppressIfgroup &&
59+
r.SuppressPrefixlen == x.SuppressPrefixlen &&
60+
(r.Dport == x.Dport || (r.Dport != nil && x.Dport != nil && r.Dport.Equal(*x.Dport))) &&
61+
(r.Sport == x.Sport || (r.Sport != nil && x.Sport != nil && r.Sport.Equal(*x.Sport))) &&
62+
(r.UIDRange == x.UIDRange || (r.UIDRange != nil && x.UIDRange != nil && r.UIDRange.Equal(*x.UIDRange)))
63+
}
64+
65+
func ptrEqual(a, b *uint32) bool {
66+
if a == b {
67+
return true
68+
}
69+
if (a == nil) || (b == nil) {
70+
return false
71+
}
72+
return *a == *b
73+
}
74+
3475
func (r Rule) String() string {
3576
from := "all"
3677
if r.Src != nil && r.Src.String() != "<nil>" {
@@ -70,6 +111,10 @@ type RulePortRange struct {
70111
End uint16
71112
}
72113

114+
func (r RulePortRange) Equal(x RulePortRange) bool {
115+
return r.Start == x.Start && r.End == x.End
116+
}
117+
73118
// NewRuleUIDRange creates rule uid range.
74119
func NewRuleUIDRange(start, end uint32) *RuleUIDRange {
75120
return &RuleUIDRange{Start: start, End: end}
@@ -80,3 +125,7 @@ type RuleUIDRange struct {
80125
Start uint32
81126
End uint32
82127
}
128+
129+
func (r RuleUIDRange) Equal(x RuleUIDRange) bool {
130+
return r.Start == x.Start && r.End == x.End
131+
}

rule_linux.go

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -336,16 +336,6 @@ func (pr *RuleUIDRange) toRtAttrData() []byte {
336336
return bytes.Join(b, []byte{})
337337
}
338338

339-
func ptrEqual(a, b *uint32) bool {
340-
if a == b {
341-
return true
342-
}
343-
if (a == nil) || (b == nil) {
344-
return false
345-
}
346-
return *a == *b
347-
}
348-
349339
func (r Rule) typeString() string {
350340
switch r.Type {
351341
case unix.RTN_UNSPEC: // zero

rule_test.go

Lines changed: 99 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package netlink
55

66
import (
77
"net"
8+
"strconv"
89
"testing"
910
"time"
1011

@@ -600,7 +601,7 @@ func runRuleListFiltered(t *testing.T, family int, srcNet, dstNet *net.IPNet) {
600601
t.Errorf("Expected len: %d, got: %d", len(wantRules), len(rules))
601602
} else {
602603
for i := range wantRules {
603-
if !ruleEquals(wantRules[i], rules[i]) {
604+
if !wantRules[i].Equal(rules[i]) {
604605
t.Errorf("Rules mismatch, want %v, got %v", wantRules[i], rules[i])
605606
}
606607
}
@@ -666,30 +667,108 @@ func TestRuleString(t *testing.T) {
666667

667668
func ruleExists(rules []Rule, rule Rule) bool {
668669
for i := range rules {
669-
if ruleEquals(rules[i], rule) {
670+
if rules[i].Equal(rule) {
670671
return true
671672
}
672673
}
673674

674675
return false
675676
}
676677

677-
func ruleEquals(a, b Rule) bool {
678-
return a.Table == b.Table &&
679-
((a.Src == nil && b.Src == nil) ||
680-
(a.Src != nil && b.Src != nil && a.Src.String() == b.Src.String())) &&
681-
((a.Dst == nil && b.Dst == nil) ||
682-
(a.Dst != nil && b.Dst != nil && a.Dst.String() == b.Dst.String())) &&
683-
a.OifName == b.OifName &&
684-
a.Priority == b.Priority &&
685-
a.Family == b.Family &&
686-
a.IifName == b.IifName &&
687-
a.Invert == b.Invert &&
688-
a.Tos == b.Tos &&
689-
a.Type == b.Type &&
690-
a.IPProto == b.IPProto &&
691-
a.Protocol == b.Protocol &&
692-
a.Mark == b.Mark &&
693-
(ptrEqual(a.Mask, b.Mask) || (a.Mark != 0 &&
694-
(a.Mask == nil && *b.Mask == 0xFFFFFFFF || b.Mask == nil && *a.Mask == 0xFFFFFFFF)))
678+
func TestRuleEqual(t *testing.T) {
679+
cases := []Rule{
680+
{Priority: 1000},
681+
{Family: FAMILY_V6},
682+
{Table: 10},
683+
{Mark: 1},
684+
{Mask: &[]uint32{0x1}[0]},
685+
{Tos: 1},
686+
{TunID: 3},
687+
{Goto: 10},
688+
{Src: &net.IPNet{IP: net.IPv4(172, 16, 0, 1), Mask: net.CIDRMask(16, 32)}},
689+
{Dst: &net.IPNet{IP: net.IPv4(172, 16, 1, 1), Mask: net.CIDRMask(24, 32)}},
690+
{Flow: 3},
691+
{IifName: "IifName"},
692+
{OifName: "OifName"},
693+
{SuppressIfgroup: 7},
694+
{SuppressPrefixlen: 16},
695+
{Invert: true},
696+
{Dport: &RulePortRange{Start: 10, End: 20}},
697+
{Sport: &RulePortRange{Start: 1, End: 2}},
698+
{IPProto: unix.IPPROTO_TCP},
699+
{UIDRange: &RuleUIDRange{Start: 3, End: 5}},
700+
{Protocol: FAMILY_V6},
701+
{Type: unix.RTN_UNREACHABLE},
702+
}
703+
for i1 := range cases {
704+
for i2 := range cases {
705+
got := cases[i1].Equal(cases[i2])
706+
expected := i1 == i2
707+
if got != expected {
708+
t.Errorf("Equal(%q,%q) == %s but expected %s",
709+
cases[i1], cases[i2],
710+
strconv.FormatBool(got),
711+
strconv.FormatBool(expected))
712+
}
713+
}
714+
}
715+
}
716+
717+
func TestRuleEqualMaskMark(t *testing.T) {
718+
a := Rule{Mark: 1, Mask: nil}
719+
b := Rule{Mark: 1, Mask: &[]uint32{0xFFFFFFFF}[0]}
720+
if !a.Equal(b) || !b.Equal(a) {
721+
t.Errorf("Rules are expected to be equal")
722+
}
723+
724+
b = Rule{Mark: 2, Mask: &[]uint32{0xFFFFFFFF}[0]}
725+
if a.Equal(b) || b.Equal(a) {
726+
t.Errorf("Rules are not expected to be equal")
727+
}
728+
729+
a = Rule{Mark: 0, Mask: nil}
730+
b = Rule{Mark: 0, Mask: &[]uint32{0xFFFFFFFF}[0]}
731+
if a.Equal(b) || b.Equal(a) {
732+
t.Errorf("Rules are not expected to be equal")
733+
}
734+
}
735+
736+
func TestRulePortRangeEqual(t *testing.T) {
737+
cases := []RulePortRange{
738+
{Start: 10, End: 10},
739+
{Start: 10, End: 22},
740+
{Start: 11, End: 22},
741+
}
742+
for i1 := range cases {
743+
for i2 := range cases {
744+
got := cases[i1].Equal(cases[i2])
745+
expected := i1 == i2
746+
if got != expected {
747+
t.Errorf("Equal(%q,%q) == %s but expected %s",
748+
cases[i1], cases[i2],
749+
strconv.FormatBool(got),
750+
strconv.FormatBool(expected))
751+
}
752+
}
753+
}
754+
}
755+
756+
func TestRuleUIDRangeEqual(t *testing.T) {
757+
cases := []RuleUIDRange{
758+
{Start: 10, End: 10},
759+
{Start: 10, End: 22},
760+
{Start: 11, End: 22},
761+
}
762+
for i1 := range cases {
763+
for i2 := range cases {
764+
got := cases[i1].Equal(cases[i2])
765+
expected := i1 == i2
766+
if got != expected {
767+
t.Errorf("Equal(%q,%q) == %s but expected %s",
768+
cases[i1], cases[i2],
769+
strconv.FormatBool(got),
770+
strconv.FormatBool(expected))
771+
}
772+
}
773+
}
695774
}

0 commit comments

Comments
 (0)