Skip to content

Commit

Permalink
feat: support convert mrs format back to text format
Browse files Browse the repository at this point in the history
  • Loading branch information
wwqgtxx committed Jul 28, 2024
1 parent 1db3e45 commit c5a4a91
Show file tree
Hide file tree
Showing 10 changed files with 131 additions and 12 deletions.
10 changes: 10 additions & 0 deletions component/cidr/ipcidr_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,16 @@ func (set *IpCidrSet) Merge() error {
return nil
}

func (set *IpCidrSet) Foreach(f func(prefix netip.Prefix) bool) {
for _, r := range set.rr {
for _, prefix := range r.Prefixes() {
if !f(prefix) {
return
}
}
}
}

// ToIPSet not safe convert to *netipx.IPSet
// be careful, must be used after Merge
func (set *IpCidrSet) ToIPSet() *netipx.IPSet {
Expand Down
19 changes: 13 additions & 6 deletions component/trie/domain.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,27 +123,34 @@ func (t *DomainTrie[T]) Optimize() {
t.root.optimize()
}

func (t *DomainTrie[T]) Foreach(print func(domain string, data T)) {
func (t *DomainTrie[T]) Foreach(fn func(domain string, data T) bool) {
for key, data := range t.root.getChildren() {
recursion([]string{key}, data, print)
recursion([]string{key}, data, fn)
if data != nil && data.inited {
print(joinDomain([]string{key}), data.data)
if !fn(joinDomain([]string{key}), data.data) {
return
}
}
}
}

func recursion[T any](items []string, node *Node[T], fn func(domain string, data T)) {
func recursion[T any](items []string, node *Node[T], fn func(domain string, data T) bool) bool {
for key, data := range node.getChildren() {
newItems := append([]string{key}, items...)
if data != nil && data.inited {
domain := joinDomain(newItems)
if domain[0] == domainStepByte {
domain = complexWildcard + domain
}
fn(domain, data.Data())
if !fn(domain, data.Data()) {
return false
}
}
if !recursion(newItems, data, fn) {
return false
}
recursion(newItems, data, fn)
}
return true
}

func joinDomain(items []string) string {
Expand Down
38 changes: 37 additions & 1 deletion component/trie/domain_set.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ type qElt struct{ s, e, col int }
// NewDomainSet creates a new *DomainSet struct, from a DomainTrie.
func (t *DomainTrie[T]) NewDomainSet() *DomainSet {
reserveDomains := make([]string, 0)
t.Foreach(func(domain string, data T) {
t.Foreach(func(domain string, data T) bool {
reserveDomains = append(reserveDomains, utils.Reverse(domain))
return true
})
// ensure that the same prefix is continuous
// and according to the ascending sequence of length
Expand Down Expand Up @@ -136,6 +137,41 @@ func (ss *DomainSet) Has(key string) bool {

}

func (ss *DomainSet) keys(f func(key string) bool) {
var currentKey []byte
var traverse func(int, int) bool
traverse = func(nodeId, bmIdx int) bool {
if getBit(ss.leaves, nodeId) != 0 {
if !f(string(currentKey)) {
return false
}
}

for ; ; bmIdx++ {
if getBit(ss.labelBitmap, bmIdx) != 0 {
return true
}
nextLabel := ss.labels[bmIdx-nodeId]
currentKey = append(currentKey, nextLabel)
nextNodeId := countZeros(ss.labelBitmap, ss.ranks, bmIdx+1)
nextBmIdx := selectIthOne(ss.labelBitmap, ss.ranks, ss.selects, nextNodeId-1) + 1
if !traverse(nextNodeId, nextBmIdx) {
return false
}
currentKey = currentKey[:len(currentKey)-1]
}
}

traverse(0, 0)
return
}

func (ss *DomainSet) Foreach(f func(key string) bool) {
ss.keys(func(key string) bool {
return f(utils.Reverse(key))
})
}

func setBit(bm *[]uint64, i int, v int) {
for i>>6 >= len(*bm) {
*bm = append(*bm, 0)
Expand Down
20 changes: 20 additions & 0 deletions component/trie/domain_set_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,29 @@
package trie_test

import (
"golang.org/x/exp/slices"
"testing"

"github.com/metacubex/mihomo/component/trie"
"github.com/stretchr/testify/assert"
)

func testDump(t *testing.T, tree *trie.DomainTrie[struct{}], set *trie.DomainSet) {
var dataSrc []string
tree.Foreach(func(domain string, data struct{}) bool {
dataSrc = append(dataSrc, domain)
return true
})
slices.Sort(dataSrc)
var dataSet []string
set.Foreach(func(key string) bool {
dataSet = append(dataSet, key)
return true
})
slices.Sort(dataSet)
assert.Equal(t, dataSrc, dataSet)
}

func TestDomainSet(t *testing.T) {
tree := trie.New[struct{}]()
domainSet := []string{
Expand All @@ -33,6 +50,7 @@ func TestDomainSet(t *testing.T) {
assert.True(t, set.Has("google.com"))
assert.False(t, set.Has("qq.com"))
assert.False(t, set.Has("www.baidu.com"))
testDump(t, tree, set)
}

func TestDomainSetComplexWildcard(t *testing.T) {
Expand All @@ -55,6 +73,7 @@ func TestDomainSetComplexWildcard(t *testing.T) {
assert.False(t, set.Has("google.com"))
assert.True(t, set.Has("www.baidu.com"))
assert.True(t, set.Has("test.test.baidu.com"))
testDump(t, tree, set)
}

func TestDomainSetWildcard(t *testing.T) {
Expand Down Expand Up @@ -82,4 +101,5 @@ func TestDomainSetWildcard(t *testing.T) {
assert.False(t, set.Has("a.www.google.com"))
assert.False(t, set.Has("test.qq.com"))
assert.False(t, set.Has("test.test.test.qq.com"))
testDump(t, tree, set)
}
3 changes: 2 additions & 1 deletion component/trie/domain_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,9 @@ func TestTrie_Foreach(t *testing.T) {
assert.NoError(t, tree.Insert(domain, localIP))
}
count := 0
tree.Foreach(func(domain string, data netip.Addr) {
tree.Foreach(func(domain string, data netip.Addr) bool {
count++
return true
})
assert.Equal(t, 7, count)
}
15 changes: 11 additions & 4 deletions docs/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -944,10 +944,17 @@ rule-providers:
type: file
rule3:
# mrs类型ruleset,目前仅支持domain和ipcidr(即不支持classical),
# behavior=domain,format=yaml 可以通过“mihomo convert-ruleset domain yaml XXX.yaml XXX.mrs”转换得到
# behavior=domain,format=text 可以通过“mihomo convert-ruleset domain text XXX.text XXX.mrs”转换得到
# behavior=ipcidr,format=yaml 可以通过“mihomo convert-ruleset ipcidr yaml XXX.yaml XXX.mrs”转换得到
# behavior=ipcidr,format=text 可以通过“mihomo convert-ruleset ipcidr text XXX.text XXX.mrs”转换得到
#
# 对于behavior=domain:
# - format=yaml 可以通过“mihomo convert-ruleset domain yaml XXX.yaml XXX.mrs”转换到mrs格式
# - format=text 可以通过“mihomo convert-ruleset domain text XXX.text XXX.mrs”转换到mrs格式
# - XXX.mrs 可以通过"mihomo convert-ruleset domain mrs XXX.mrs XXX.text"转换回text格式(暂不支持转换回ymal格式)
#
# 对于behavior=ipcidr:
# - format=yaml 可以通过“mihomo convert-ruleset ipcidr yaml XXX.yaml XXX.mrs”转换到mrs格式
# - format=text 可以通过“mihomo convert-ruleset ipcidr text XXX.text XXX.mrs”转换到mrs格式
# - XXX.mrs 可以通过"mihomo convert-ruleset ipcidr mrs XXX.mrs XXX.text"转换回text格式(暂不支持转换回ymal格式)
#
type: http
url: "url"
format: mrs
Expand Down
16 changes: 16 additions & 0 deletions rules/provider/domain_strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,22 @@ func (d *domainStrategy) WriteMrs(w io.Writer) error {
return d.domainSet.WriteBin(w)
}

func (d *domainStrategy) DumpMrs(f func(key string) bool) {
if d.domainSet != nil {
var lastKey string
d.domainSet.Foreach(func(key string) bool {
defer func() { lastKey = key }()
if key != "+."+lastKey && lastKey != "" { // Clear the rules added by trie internal processing
return f(lastKey)
}
return true
})
if lastKey != "" {
f(lastKey)
}
}
}

var _ mrsRuleStrategy = (*domainStrategy)(nil)

func NewDomainStrategy() *domainStrategy {
Expand Down
9 changes: 9 additions & 0 deletions rules/provider/ipcidr_strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package provider
import (
"errors"
"io"
"net/netip"

"github.com/metacubex/mihomo/component/cidr"
C "github.com/metacubex/mihomo/constant"
Expand Down Expand Up @@ -82,6 +83,14 @@ func (i *ipcidrStrategy) WriteMrs(w io.Writer) error {
return i.cidrSet.WriteBin(w)
}

func (i *ipcidrStrategy) DumpMrs(f func(key string) bool) {
if i.cidrSet != nil {
i.cidrSet.Foreach(func(prefix netip.Prefix) bool {
return f(prefix.String())
})
}
}

func (i *ipcidrStrategy) ToIpCidr() *netipx.IPSet {
return i.cidrSet.ToIPSet()
}
Expand Down
12 changes: 12 additions & 0 deletions rules/provider/mrs_converter.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package provider
import (
"encoding/binary"
"errors"
"fmt"
"io"
"os"

Expand All @@ -21,6 +22,17 @@ func ConvertToMrs(buf []byte, behavior P.RuleBehavior, format P.RuleFormat, w io
return errors.New("empty rule")
}
if _strategy, ok := strategy.(mrsRuleStrategy); ok {
if format == P.MrsRule { // export to TextRule
_strategy.DumpMrs(func(key string) bool {
_, err = fmt.Fprintln(w, key)
if err != nil {
return false
}
return true
})
return nil
}

var encoder *zstd.Encoder
encoder, err = zstd.NewWriter(w)
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions rules/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ type mrsRuleStrategy interface {
ruleStrategy
FromMrs(r io.Reader, count int) error
WriteMrs(w io.Writer) error
DumpMrs(f func(key string) bool)
}

func (rp *ruleSetProvider) Type() P.ProviderType {
Expand Down

0 comments on commit c5a4a91

Please sign in to comment.