From 6a8c0713bd25be8a8abbfe39aa74fc02708f9d1f Mon Sep 17 00:00:00 2001 From: Stephan Ferlin-Reiter Date: Wed, 25 Jun 2025 21:59:56 +0200 Subject: [PATCH 1/2] Add Rule.Equal method. This adds an implementation of Rule.Equal (and RulePortRange.Equal, RuleUIDRange.Equal) which allows us to compare two Rule objects. --- rule.go | 49 +++++++++++++++++++++ rule_linux.go | 10 ----- rule_test.go | 119 +++++++++++++++++++++++++++++++++++++++++--------- 3 files changed, 148 insertions(+), 30 deletions(-) diff --git a/rule.go b/rule.go index 9d74c7cd..759d078d 100644 --- a/rule.go +++ b/rule.go @@ -31,6 +31,47 @@ type Rule struct { Type uint8 } +func (r Rule) Equal(x Rule) bool { + return r.Table == x.Table && + ((r.Src == nil && x.Src == nil) || + (r.Src != nil && x.Src != nil && r.Src.String() == x.Src.String())) && + ((r.Dst == nil && x.Dst == nil) || + (r.Dst != nil && x.Dst != nil && r.Dst.String() == x.Dst.String())) && + r.OifName == x.OifName && + r.Priority == x.Priority && + r.Family == x.Family && + r.IifName == x.IifName && + r.Invert == x.Invert && + r.Tos == x.Tos && + r.Type == x.Type && + r.IPProto == x.IPProto && + r.Protocol == x.Protocol && + r.Mark == x.Mark && + // For non-zero marks, mask defaults to 0xFFFFFFFF if not set. So if either mask is nil + // while the other is 0xFFFFFFFF when mark is non-zero, treat the masks as identical. + // See kernel source: https://github.com/torvalds/linux/blob/v6.15/net/core/fib_rules.c#L624 + (ptrEqual(r.Mask, x.Mask) || (r.Mark != 0 && + (r.Mask == nil && *x.Mask == 0xFFFFFFFF || x.Mask == nil && *r.Mask == 0xFFFFFFFF))) && + r.TunID == x.TunID && + r.Goto == x.Goto && + r.Flow == x.Flow && + r.SuppressIfgroup == x.SuppressIfgroup && + r.SuppressPrefixlen == x.SuppressPrefixlen && + (r.Dport == x.Dport || (r.Dport != nil && x.Dport != nil && r.Dport.Equal(*x.Dport))) && + (r.Sport == x.Sport || (r.Sport != nil && x.Sport != nil && r.Sport.Equal(*x.Sport))) && + (r.UIDRange == x.UIDRange || (r.UIDRange != nil && x.UIDRange != nil && r.UIDRange.Equal(*x.UIDRange))) +} + +func ptrEqual(a, b *uint32) bool { + if a == b { + return true + } + if (a == nil) || (b == nil) { + return false + } + return *a == *b +} + func (r Rule) String() string { from := "all" if r.Src != nil && r.Src.String() != "" { @@ -70,6 +111,10 @@ type RulePortRange struct { End uint16 } +func (r RulePortRange) Equal(x RulePortRange) bool { + return r.Start == x.Start && r.End == x.End +} + // NewRuleUIDRange creates rule uid range. func NewRuleUIDRange(start, end uint32) *RuleUIDRange { return &RuleUIDRange{Start: start, End: end} @@ -80,3 +125,7 @@ type RuleUIDRange struct { Start uint32 End uint32 } + +func (r RuleUIDRange) Equal(x RuleUIDRange) bool { + return r.Start == x.Start && r.End == x.End +} diff --git a/rule_linux.go b/rule_linux.go index dba99147..eae5eabe 100644 --- a/rule_linux.go +++ b/rule_linux.go @@ -336,16 +336,6 @@ func (pr *RuleUIDRange) toRtAttrData() []byte { return bytes.Join(b, []byte{}) } -func ptrEqual(a, b *uint32) bool { - if a == b { - return true - } - if (a == nil) || (b == nil) { - return false - } - return *a == *b -} - func (r Rule) typeString() string { switch r.Type { case unix.RTN_UNSPEC: // zero diff --git a/rule_test.go b/rule_test.go index 4420e5b5..a558088f 100644 --- a/rule_test.go +++ b/rule_test.go @@ -5,6 +5,7 @@ package netlink import ( "net" + "strconv" "testing" "time" @@ -600,7 +601,7 @@ func runRuleListFiltered(t *testing.T, family int, srcNet, dstNet *net.IPNet) { t.Errorf("Expected len: %d, got: %d", len(wantRules), len(rules)) } else { for i := range wantRules { - if !ruleEquals(wantRules[i], rules[i]) { + if !wantRules[i].Equal(rules[i]) { t.Errorf("Rules mismatch, want %v, got %v", wantRules[i], rules[i]) } } @@ -666,7 +667,7 @@ func TestRuleString(t *testing.T) { func ruleExists(rules []Rule, rule Rule) bool { for i := range rules { - if ruleEquals(rules[i], rule) { + if rules[i].Equal(rule) { return true } } @@ -674,22 +675,100 @@ func ruleExists(rules []Rule, rule Rule) bool { return false } -func ruleEquals(a, b Rule) bool { - return a.Table == b.Table && - ((a.Src == nil && b.Src == nil) || - (a.Src != nil && b.Src != nil && a.Src.String() == b.Src.String())) && - ((a.Dst == nil && b.Dst == nil) || - (a.Dst != nil && b.Dst != nil && a.Dst.String() == b.Dst.String())) && - a.OifName == b.OifName && - a.Priority == b.Priority && - a.Family == b.Family && - a.IifName == b.IifName && - a.Invert == b.Invert && - a.Tos == b.Tos && - a.Type == b.Type && - a.IPProto == b.IPProto && - a.Protocol == b.Protocol && - a.Mark == b.Mark && - (ptrEqual(a.Mask, b.Mask) || (a.Mark != 0 && - (a.Mask == nil && *b.Mask == 0xFFFFFFFF || b.Mask == nil && *a.Mask == 0xFFFFFFFF))) +func TestRuleEqual(t *testing.T) { + cases := []Rule{ + {Priority: 1000}, + {Family: FAMILY_V6}, + {Table: 10}, + {Mark: 1}, + {Mask: &[]uint32{0x1}[0]}, + {Tos: 1}, + {TunID: 3}, + {Goto: 10}, + {Src: &net.IPNet{IP: net.IPv4(172, 16, 0, 1), Mask: net.CIDRMask(16, 32)}}, + {Dst: &net.IPNet{IP: net.IPv4(172, 16, 1, 1), Mask: net.CIDRMask(24, 32)}}, + {Flow: 3}, + {IifName: "IifName"}, + {OifName: "OifName"}, + {SuppressIfgroup: 7}, + {SuppressPrefixlen: 16}, + {Invert: true}, + {Dport: &RulePortRange{Start: 10, End: 20}}, + {Sport: &RulePortRange{Start: 1, End: 2}}, + {IPProto: unix.IPPROTO_TCP}, + {UIDRange: &RuleUIDRange{Start: 3, End: 5}}, + {Protocol: FAMILY_V6}, + {Type: unix.RTN_UNREACHABLE}, + } + for i1 := range cases { + for i2 := range cases { + got := cases[i1].Equal(cases[i2]) + expected := i1 == i2 + if got != expected { + t.Errorf("Equal(%q,%q) == %s but expected %s", + cases[i1], cases[i2], + strconv.FormatBool(got), + strconv.FormatBool(expected)) + } + } + } +} + +func TestRuleEqualMaskMark(t *testing.T) { + a := Rule{Mark: 1, Mask: nil} + b := Rule{Mark: 1, Mask: &[]uint32{0xFFFFFFFF}[0]} + if !a.Equal(b) || !b.Equal(a) { + t.Errorf("Rules are expected to be equal") + } + + b = Rule{Mark: 2, Mask: &[]uint32{0xFFFFFFFF}[0]} + if a.Equal(b) || b.Equal(a) { + t.Errorf("Rules are not expected to be equal") + } + + a = Rule{Mark: 0, Mask: nil} + b = Rule{Mark: 0, Mask: &[]uint32{0xFFFFFFFF}[0]} + if a.Equal(b) || b.Equal(a) { + t.Errorf("Rules are not expected to be equal") + } +} + +func TestRulePortRangeEqual(t *testing.T) { + cases := []RulePortRange{ + {Start: 10, End: 10}, + {Start: 10, End: 22}, + {Start: 11, End: 22}, + } + for i1 := range cases { + for i2 := range cases { + got := cases[i1].Equal(cases[i2]) + expected := i1 == i2 + if got != expected { + t.Errorf("Equal(%q,%q) == %s but expected %s", + cases[i1], cases[i2], + strconv.FormatBool(got), + strconv.FormatBool(expected)) + } + } + } +} + +func TestRuleUIDRangeEqual(t *testing.T) { + cases := []RuleUIDRange{ + {Start: 10, End: 10}, + {Start: 10, End: 22}, + {Start: 11, End: 22}, + } + for i1 := range cases { + for i2 := range cases { + got := cases[i1].Equal(cases[i2]) + expected := i1 == i2 + if got != expected { + t.Errorf("Equal(%q,%q) == %s but expected %s", + cases[i1], cases[i2], + strconv.FormatBool(got), + strconv.FormatBool(expected)) + } + } + } } From 393298de9e9c7a61954593feec0b0eb4d28ec014 Mon Sep 17 00:00:00 2001 From: Stephan Ferlin-Reiter Date: Wed, 25 Jun 2025 22:02:16 +0200 Subject: [PATCH 2/2] Fix reading of rule type and add l3mdev flag. When reading rules, makes sure to populate the type field of Rule. Note that an unspecified type is interpreted as UNICAST, for example, when adding a rule. To make sure that a rule with unspecified type, which was just added, compares equal to that rule being read back, we we update Rule.Equal to treat unspecified and UNICAST types as equal. Also adds a l3mdev flag to Rule, reads it when listing rules, sets it when updating rules, and makes sure it is taken into account when comparing rules using Equal as introduced in #1095. --- rule.go | 7 +++++-- rule_linux.go | 7 +++++++ rule_test.go | 11 ++++++++++- 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/rule.go b/rule.go index 759d078d..536f74a8 100644 --- a/rule.go +++ b/rule.go @@ -29,6 +29,7 @@ type Rule struct { UIDRange *RuleUIDRange Protocol uint8 Type uint8 + L3mdev uint8 } func (r Rule) Equal(x Rule) bool { @@ -43,7 +44,8 @@ func (r Rule) Equal(x Rule) bool { r.IifName == x.IifName && r.Invert == x.Invert && r.Tos == x.Tos && - r.Type == x.Type && + (r.Type == x.Type || + (r.Type == 0 && x.Type == 1 || r.Type == 1 && x.Type == 0)) && // 1 is unix.RTN_UNICAST r.IPProto == x.IPProto && r.Protocol == x.Protocol && r.Mark == x.Mark && @@ -59,7 +61,8 @@ func (r Rule) Equal(x Rule) bool { r.SuppressPrefixlen == x.SuppressPrefixlen && (r.Dport == x.Dport || (r.Dport != nil && x.Dport != nil && r.Dport.Equal(*x.Dport))) && (r.Sport == x.Sport || (r.Sport != nil && x.Sport != nil && r.Sport.Equal(*x.Sport))) && - (r.UIDRange == x.UIDRange || (r.UIDRange != nil && x.UIDRange != nil && r.UIDRange.Equal(*x.UIDRange))) + (r.UIDRange == x.UIDRange || (r.UIDRange != nil && x.UIDRange != nil && r.UIDRange.Equal(*x.UIDRange))) && + r.L3mdev == x.L3mdev } func ptrEqual(a, b *uint32) bool { diff --git a/rule_linux.go b/rule_linux.go index eae5eabe..1fcfc3b7 100644 --- a/rule_linux.go +++ b/rule_linux.go @@ -178,6 +178,10 @@ func ruleHandle(rule *Rule, req *nl.NetlinkRequest) error { req.AddData(nl.NewRtAttr(nl.FRA_PROTOCOL, nl.Uint8Attr(rule.Protocol))) } + if rule.L3mdev > 0 { + req.AddData(nl.NewRtAttr(nl.FRA_L3MDEV, nl.Uint8Attr(rule.L3mdev))) + } + _, err := req.Execute(unix.NETLINK_ROUTE, 0) return err } @@ -239,6 +243,7 @@ func (h *Handle) RuleListFiltered(family int, filter *Rule, filterMask uint64) ( rule.Invert = msg.Flags&FibRuleInvert > 0 rule.Family = int(msg.Family) rule.Tos = uint(msg.Tos) + rule.Type = msg.Type for j := range attrs { switch attrs[j].Attr.Type { @@ -291,6 +296,8 @@ func (h *Handle) RuleListFiltered(family int, filter *Rule, filterMask uint64) ( rule.UIDRange = NewRuleUIDRange(native.Uint32(attrs[j].Value[0:4]), native.Uint32(attrs[j].Value[4:8])) case nl.FRA_PROTOCOL: rule.Protocol = uint8(attrs[j].Value[0]) + case nl.FRA_L3MDEV: + rule.L3mdev = uint8(attrs[j].Value[0]) } } diff --git a/rule_test.go b/rule_test.go index a558088f..29490cb8 100644 --- a/rule_test.go +++ b/rule_test.go @@ -55,7 +55,7 @@ func TestRuleAddDel(t *testing.T) { // find this rule found := ruleExists(rules, *rule) if !found { - t.Fatal("Rule has diffrent options than one added") + t.Fatal("Rule has different options than one added") } if err := RuleDel(rule); err != nil { @@ -699,6 +699,7 @@ func TestRuleEqual(t *testing.T) { {UIDRange: &RuleUIDRange{Start: 3, End: 5}}, {Protocol: FAMILY_V6}, {Type: unix.RTN_UNREACHABLE}, + {L3mdev: 1}, } for i1 := range cases { for i2 := range cases { @@ -714,6 +715,14 @@ func TestRuleEqual(t *testing.T) { } } +func TestRuleEqualTypeUnspecifiedEqualsUnicast(t *testing.T) { + a := Rule{Type: unix.RTN_UNSPEC} + b := Rule{Type: unix.RTN_UNICAST} + if !a.Equal(b) || !b.Equal(a) { + t.Errorf("Rules are expected to be equal") + } +} + func TestRuleEqualMaskMark(t *testing.T) { a := Rule{Mark: 1, Mask: nil} b := Rule{Mark: 1, Mask: &[]uint32{0xFFFFFFFF}[0]}