Feature: domain trie support dot dot wildcard

This commit is contained in:
Dreamacro 2020-04-08 15:45:59 +08:00
parent 5591e15452
commit 65dab4e34f
2 changed files with 82 additions and 57 deletions

View File

@ -6,8 +6,9 @@ import (
)
const (
wildcard = "*"
domainStep = "."
wildcard = "*"
dotWildcard = ""
domainStep = "."
)
var (
@ -21,8 +22,23 @@ type Trie struct {
root *Node
}
func isValidDomain(domain string) bool {
return domain != "" && domain[0] != '.' && domain[len(domain)-1] != '.'
func validAndSplitDomain(domain string) ([]string, bool) {
if domain != "" && domain[len(domain)-1] == '.' {
return nil, false
}
parts := strings.Split(domain, domainStep)
if len(parts) == 1 {
return nil, false
}
for _, part := range parts[1:] {
if part == "" {
return nil, false
}
}
return parts, true
}
// Insert adds a node to the trie.
@ -30,12 +46,13 @@ func isValidDomain(domain string) bool {
// 1. www.example.com
// 2. *.example.com
// 3. subdomain.*.example.com
// 4. .example.com
func (t *Trie) Insert(domain string, data interface{}) error {
if !isValidDomain(domain) {
parts, valid := validAndSplitDomain(domain)
if !valid {
return ErrInvalidDomain
}
parts := strings.Split(domain, domainStep)
node := t.root
// reverse storage domain part to save space
for i := len(parts) - 1; i >= 0; i-- {
@ -55,28 +72,38 @@ func (t *Trie) Insert(domain string, data interface{}) error {
// Priority as:
// 1. static part
// 2. wildcard domain
// 2. dot wildcard domain
func (t *Trie) Search(domain string) *Node {
if !isValidDomain(domain) {
parts, valid := validAndSplitDomain(domain)
if !valid || parts[0] == "" {
return nil
}
parts := strings.Split(domain, domainStep)
n := t.root
var dotWildcardNode *Node
for i := len(parts) - 1; i >= 0; i-- {
part := parts[i]
var child *Node
if !n.hasChild(part) {
if !n.hasChild(wildcard) {
return nil
}
child = n.getChild(wildcard)
} else {
child = n.getChild(part)
if node := n.getChild(dotWildcard); node != nil {
dotWildcardNode = node
}
n = child
if n.hasChild(part) {
n = n.getChild(part)
} else {
n = n.getChild(wildcard)
}
if n == nil {
break
}
}
if n == nil {
if dotWildcardNode != nil {
return dotWildcardNode
}
return nil
}
if n.Data == nil {

View File

@ -3,6 +3,8 @@ package trie
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
)
var localIP = net.IP{127, 0, 0, 1}
@ -19,17 +21,9 @@ func TestTrie_Basic(t *testing.T) {
}
node := tree.Search("example.com")
if node == nil {
t.Error("should not recv nil")
}
if !node.Data.(net.IP).Equal(localIP) {
t.Error("should equal 127.0.0.1")
}
if tree.Insert("", localIP) == nil {
t.Error("should return error")
}
assert.NotNil(t, node)
assert.True(t, node.Data.(net.IP).Equal(localIP))
assert.NotNil(t, tree.Insert("", localIP))
}
func TestTrie_Wildcard(t *testing.T) {
@ -38,50 +32,54 @@ func TestTrie_Wildcard(t *testing.T) {
"*.example.com",
"sub.*.example.com",
"*.dev",
".org",
".example.net",
}
for _, domain := range domains {
tree.Insert(domain, localIP)
}
if tree.Search("sub.example.com") == nil {
t.Error("should not recv nil")
assert.NotNil(t, tree.Search("sub.example.com"))
assert.NotNil(t, tree.Search("sub.foo.example.com"))
assert.NotNil(t, tree.Search("test.org"))
assert.NotNil(t, tree.Search("test.example.net"))
assert.Nil(t, tree.Search("foo.sub.example.com"))
assert.Nil(t, tree.Search("foo.example.dev"))
assert.Nil(t, tree.Search("example.com"))
}
func TestTrie_Priority(t *testing.T) {
tree := New()
domains := []string{
".dev",
"example.dev",
"*.example.dev",
"test.example.dev",
}
if tree.Search("sub.foo.example.com") == nil {
t.Error("should not recv nil")
assertFn := func(domain string, data int) {
node := tree.Search(domain)
assert.NotNil(t, node)
assert.Equal(t, data, node.Data)
}
if tree.Search("foo.sub.example.com") != nil {
t.Error("should recv nil")
for idx, domain := range domains {
tree.Insert(domain, idx)
}
if tree.Search("foo.example.dev") != nil {
t.Error("should recv nil")
}
if tree.Search("example.com") != nil {
t.Error("should recv nil")
}
assertFn("test.dev", 0)
assertFn("foo.bar.dev", 0)
assertFn("example.dev", 1)
assertFn("foo.example.dev", 2)
assertFn("test.example.dev", 3)
}
func TestTrie_Boundary(t *testing.T) {
tree := New()
tree.Insert("*.dev", localIP)
if err := tree.Insert(".", localIP); err == nil {
t.Error("should recv err")
}
if err := tree.Insert(".com", localIP); err == nil {
t.Error("should recv err")
}
if tree.Search("dev") != nil {
t.Error("should recv nil")
}
if tree.Search(".dev") != nil {
t.Error("should recv nil")
}
assert.NotNil(t, tree.Insert(".", localIP))
assert.NotNil(t, tree.Insert("..dev", localIP))
assert.Nil(t, tree.Search("dev"))
}