Skip to content

Commit 677a969

Browse files
committed
Add Rule.Equal method.
1 parent 9d88d83 commit 677a969

File tree

2 files changed

+135
-20
lines changed

2 files changed

+135
-20
lines changed

rule.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,34 @@ type Rule struct {
3131
Type uint8
3232
}
3333

34+
func (r Rule) Equal(x Rule) bool {
35+
return r.Table == x.Table &&
36+
((r.Src == nil && x.Src == nil) ||
37+
(r.Src != nil && x.Src != nil && r.Src.String() == x.Src.String())) &&
38+
((r.Dst == nil && x.Dst == nil) ||
39+
(r.Dst != nil && x.Dst != nil && r.Dst.String() == x.Dst.String())) &&
40+
r.OifName == x.OifName &&
41+
r.Priority == x.Priority &&
42+
r.Family == x.Family &&
43+
r.IifName == x.IifName &&
44+
r.Invert == x.Invert &&
45+
r.Tos == x.Tos &&
46+
r.Type == x.Type &&
47+
r.IPProto == x.IPProto &&
48+
r.Protocol == x.Protocol &&
49+
r.Mark == x.Mark &&
50+
(ptrEqual(r.Mask, x.Mask) || (r.Mark != 0 &&
51+
(r.Mask == nil && *x.Mask == 0xFFFFFFFF || x.Mask == nil && *r.Mask == 0xFFFFFFFF))) &&
52+
r.TunID == x.TunID &&
53+
r.Goto == x.Goto &&
54+
r.Flow == x.Flow &&
55+
r.SuppressIfgroup == x.SuppressIfgroup &&
56+
r.SuppressPrefixlen == x.SuppressPrefixlen &&
57+
(r.Dport == x.Dport || (r.Dport != nil && x.Dport != nil && r.Dport.Equal(*x.Dport))) &&
58+
(r.Sport == x.Sport || (r.Sport != nil && x.Sport != nil && r.Sport.Equal(*x.Sport))) &&
59+
(r.UIDRange == x.UIDRange || (r.UIDRange != nil && x.UIDRange != nil && r.UIDRange.Equal(*x.UIDRange)))
60+
}
61+
3462
func (r Rule) String() string {
3563
from := "all"
3664
if r.Src != nil && r.Src.String() != "<nil>" {
@@ -70,6 +98,10 @@ type RulePortRange struct {
7098
End uint16
7199
}
72100

101+
func (r RulePortRange) Equal(x RulePortRange) bool {
102+
return r.Start == x.Start && r.End == x.End
103+
}
104+
73105
// NewRuleUIDRange creates rule uid range.
74106
func NewRuleUIDRange(start, end uint32) *RuleUIDRange {
75107
return &RuleUIDRange{Start: start, End: end}
@@ -80,3 +112,7 @@ type RuleUIDRange struct {
80112
Start uint32
81113
End uint32
82114
}
115+
116+
func (r RuleUIDRange) Equal(x RuleUIDRange) bool {
117+
return r.Start == x.Start && r.End == x.End
118+
}

rule_test.go

Lines changed: 99 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package netlink
55

66
import (
77
"net"
8+
"strconv"
89
"testing"
910
"time"
1011

@@ -600,7 +601,7 @@ func runRuleListFiltered(t *testing.T, family int, srcNet, dstNet *net.IPNet) {
600601
t.Errorf("Expected len: %d, got: %d", len(wantRules), len(rules))
601602
} else {
602603
for i := range wantRules {
603-
if !ruleEquals(wantRules[i], rules[i]) {
604+
if !wantRules[i].Equal(rules[i]) {
604605
t.Errorf("Rules mismatch, want %v, got %v", wantRules[i], rules[i])
605606
}
606607
}
@@ -666,30 +667,108 @@ func TestRuleString(t *testing.T) {
666667

667668
func ruleExists(rules []Rule, rule Rule) bool {
668669
for i := range rules {
669-
if ruleEquals(rules[i], rule) {
670+
if rules[i].Equal(rule) {
670671
return true
671672
}
672673
}
673674

674675
return false
675676
}
676677

677-
func ruleEquals(a, b Rule) bool {
678-
return a.Table == b.Table &&
679-
((a.Src == nil && b.Src == nil) ||
680-
(a.Src != nil && b.Src != nil && a.Src.String() == b.Src.String())) &&
681-
((a.Dst == nil && b.Dst == nil) ||
682-
(a.Dst != nil && b.Dst != nil && a.Dst.String() == b.Dst.String())) &&
683-
a.OifName == b.OifName &&
684-
a.Priority == b.Priority &&
685-
a.Family == b.Family &&
686-
a.IifName == b.IifName &&
687-
a.Invert == b.Invert &&
688-
a.Tos == b.Tos &&
689-
a.Type == b.Type &&
690-
a.IPProto == b.IPProto &&
691-
a.Protocol == b.Protocol &&
692-
a.Mark == b.Mark &&
693-
(ptrEqual(a.Mask, b.Mask) || (a.Mark != 0 &&
694-
(a.Mask == nil && *b.Mask == 0xFFFFFFFF || b.Mask == nil && *a.Mask == 0xFFFFFFFF)))
678+
func TestRuleEqual(t *testing.T) {
679+
cases := []Rule{
680+
{Priority: 1000},
681+
{Family: FAMILY_V6},
682+
{Table: 10},
683+
{Mark: 1},
684+
{Mask: &[]uint32{0x1}[0]},
685+
{Tos: 1},
686+
{TunID: 3},
687+
{Goto: 10},
688+
{Src: &net.IPNet{IP: net.IPv4(172, 16, 0, 1), Mask: net.CIDRMask(16, 32)}},
689+
{Dst: &net.IPNet{IP: net.IPv4(172, 16, 1, 1), Mask: net.CIDRMask(24, 32)}},
690+
{Flow: 3},
691+
{IifName: "IifName"},
692+
{OifName: "OifName"},
693+
{SuppressIfgroup: 7},
694+
{SuppressPrefixlen: 16},
695+
{Invert: true},
696+
{Dport: &RulePortRange{Start: 10, End: 20}},
697+
{Sport: &RulePortRange{Start: 1, End: 2}},
698+
{IPProto: unix.IPPROTO_TCP},
699+
{UIDRange: &RuleUIDRange{Start: 3, End: 5}},
700+
{Protocol: FAMILY_V6},
701+
{Type: unix.RTN_UNREACHABLE},
702+
}
703+
for i1 := range cases {
704+
for i2 := range cases {
705+
got := cases[i1].Equal(cases[i2])
706+
expected := i1 == i2
707+
if got != expected {
708+
t.Errorf("Equal(%q,%q) == %s but expected %s",
709+
cases[i1], cases[i2],
710+
strconv.FormatBool(got),
711+
strconv.FormatBool(expected))
712+
}
713+
}
714+
}
715+
}
716+
717+
func TestRuleEqualmaskMark(t *testing.T) {
718+
expected := true // These are equal
719+
m := &[]uint32{0xFFFFFFFF}[0]
720+
cases := []Rule{
721+
{Mark: 1, Mask: nil},
722+
{Mark: 1, Mask: m},
723+
{Mark: 1, Mask: &[]uint32{0xFFFFFFFF}[0]},
724+
}
725+
for i1 := range cases {
726+
for i2 := range cases {
727+
got := cases[i1].Equal(cases[i2])
728+
if got != expected {
729+
t.Errorf("Equal(%q,%q) == %s but expected %s",
730+
cases[i1], cases[i2],
731+
strconv.FormatBool(got),
732+
strconv.FormatBool(expected))
733+
}
734+
}
735+
}
736+
}
737+
738+
func TestRulePortRangeEqual(t *testing.T) {
739+
cases := []RulePortRange{
740+
{Start: 10, End: 10},
741+
{Start: 11, End: 22},
742+
}
743+
for i1 := range cases {
744+
for i2 := range cases {
745+
got := cases[i1].Equal(cases[i2])
746+
expected := i1 == i2
747+
if got != expected {
748+
t.Errorf("Equal(%q,%q) == %s but expected %s",
749+
cases[i1], cases[i2],
750+
strconv.FormatBool(got),
751+
strconv.FormatBool(expected))
752+
}
753+
}
754+
}
755+
}
756+
757+
func TestRuleUIDRangeEqual(t *testing.T) {
758+
cases := []RuleUIDRange{
759+
{Start: 10, End: 10},
760+
{Start: 11, End: 22},
761+
}
762+
for i1 := range cases {
763+
for i2 := range cases {
764+
got := cases[i1].Equal(cases[i2])
765+
expected := i1 == i2
766+
if got != expected {
767+
t.Errorf("Equal(%q,%q) == %s but expected %s",
768+
cases[i1], cases[i2],
769+
strconv.FormatBool(got),
770+
strconv.FormatBool(expected))
771+
}
772+
}
773+
}
695774
}

0 commit comments

Comments
 (0)