diff --git a/component/domain-trie/tire.go b/component/domain-trie/tire.go index f4e87ff..b4f2bce 100644 --- a/component/domain-trie/tire.go +++ b/component/domain-trie/tire.go @@ -6,8 +6,9 @@ import ( ) const ( - wildcard = "*" - domainStep = "." + wildcard = "*" + dotWildcard = "" + domainStep = "." ) var ( @@ -21,8 +22,23 @@ type Trie struct { root *Node } -func isValidDomain(domain string) bool { - return domain != "" && domain[0] != '.' && domain[len(domain)-1] != '.' +func validAndSplitDomain(domain string) ([]string, bool) { + if domain != "" && domain[len(domain)-1] == '.' { + return nil, false + } + + parts := strings.Split(domain, domainStep) + if len(parts) == 1 { + return nil, false + } + + for _, part := range parts[1:] { + if part == "" { + return nil, false + } + } + + return parts, true } // Insert adds a node to the trie. @@ -30,12 +46,13 @@ func isValidDomain(domain string) bool { // 1. www.example.com // 2. *.example.com // 3. subdomain.*.example.com +// 4. .example.com func (t *Trie) Insert(domain string, data interface{}) error { - if !isValidDomain(domain) { + parts, valid := validAndSplitDomain(domain) + if !valid { return ErrInvalidDomain } - parts := strings.Split(domain, domainStep) node := t.root // reverse storage domain part to save space for i := len(parts) - 1; i >= 0; i-- { @@ -55,28 +72,38 @@ func (t *Trie) Insert(domain string, data interface{}) error { // Priority as: // 1. static part // 2. wildcard domain +// 2. dot wildcard domain func (t *Trie) Search(domain string) *Node { - if !isValidDomain(domain) { + parts, valid := validAndSplitDomain(domain) + if !valid || parts[0] == "" { return nil } - parts := strings.Split(domain, domainStep) n := t.root + var dotWildcardNode *Node for i := len(parts) - 1; i >= 0; i-- { part := parts[i] - var child *Node - if !n.hasChild(part) { - if !n.hasChild(wildcard) { - return nil - } - - child = n.getChild(wildcard) - } else { - child = n.getChild(part) + if node := n.getChild(dotWildcard); node != nil { + dotWildcardNode = node } - n = child + if n.hasChild(part) { + n = n.getChild(part) + } else { + n = n.getChild(wildcard) + } + + if n == nil { + break + } + } + + if n == nil { + if dotWildcardNode != nil { + return dotWildcardNode + } + return nil } if n.Data == nil { diff --git a/component/domain-trie/trie_test.go b/component/domain-trie/trie_test.go index 228106c..927e434 100644 --- a/component/domain-trie/trie_test.go +++ b/component/domain-trie/trie_test.go @@ -3,6 +3,8 @@ package trie import ( "net" "testing" + + "github.com/stretchr/testify/assert" ) var localIP = net.IP{127, 0, 0, 1} @@ -19,17 +21,9 @@ func TestTrie_Basic(t *testing.T) { } node := tree.Search("example.com") - if node == nil { - t.Error("should not recv nil") - } - - if !node.Data.(net.IP).Equal(localIP) { - t.Error("should equal 127.0.0.1") - } - - if tree.Insert("", localIP) == nil { - t.Error("should return error") - } + assert.NotNil(t, node) + assert.True(t, node.Data.(net.IP).Equal(localIP)) + assert.NotNil(t, tree.Insert("", localIP)) } func TestTrie_Wildcard(t *testing.T) { @@ -38,50 +32,54 @@ func TestTrie_Wildcard(t *testing.T) { "*.example.com", "sub.*.example.com", "*.dev", + ".org", + ".example.net", } for _, domain := range domains { tree.Insert(domain, localIP) } - if tree.Search("sub.example.com") == nil { - t.Error("should not recv nil") + assert.NotNil(t, tree.Search("sub.example.com")) + assert.NotNil(t, tree.Search("sub.foo.example.com")) + assert.NotNil(t, tree.Search("test.org")) + assert.NotNil(t, tree.Search("test.example.net")) + assert.Nil(t, tree.Search("foo.sub.example.com")) + assert.Nil(t, tree.Search("foo.example.dev")) + assert.Nil(t, tree.Search("example.com")) +} + +func TestTrie_Priority(t *testing.T) { + tree := New() + domains := []string{ + ".dev", + "example.dev", + "*.example.dev", + "test.example.dev", } - if tree.Search("sub.foo.example.com") == nil { - t.Error("should not recv nil") + assertFn := func(domain string, data int) { + node := tree.Search(domain) + assert.NotNil(t, node) + assert.Equal(t, data, node.Data) } - if tree.Search("foo.sub.example.com") != nil { - t.Error("should recv nil") + for idx, domain := range domains { + tree.Insert(domain, idx) } - if tree.Search("foo.example.dev") != nil { - t.Error("should recv nil") - } - - if tree.Search("example.com") != nil { - t.Error("should recv nil") - } + assertFn("test.dev", 0) + assertFn("foo.bar.dev", 0) + assertFn("example.dev", 1) + assertFn("foo.example.dev", 2) + assertFn("test.example.dev", 3) } func TestTrie_Boundary(t *testing.T) { tree := New() tree.Insert("*.dev", localIP) - if err := tree.Insert(".", localIP); err == nil { - t.Error("should recv err") - } - - if err := tree.Insert(".com", localIP); err == nil { - t.Error("should recv err") - } - - if tree.Search("dev") != nil { - t.Error("should recv nil") - } - - if tree.Search(".dev") != nil { - t.Error("should recv nil") - } + assert.NotNil(t, tree.Insert(".", localIP)) + assert.NotNil(t, tree.Insert("..dev", localIP)) + assert.Nil(t, tree.Search("dev")) }