check function argument when Sub

This commit is contained in:
pocke 2015-01-11 13:08:26 +09:00
parent 04c9d89576
commit 9a8ffa6529
2 changed files with 53 additions and 9 deletions

43
pfs.go
View File

@ -9,14 +9,14 @@ import (
type PFS struct {
// listeners are listener functions.
listeners []reflect.Value
lmu *sync.RWMutex
lmu sync.RWMutex
argTypes []reflect.Type
tmu sync.RWMutex
}
func New() *PFS {
return &PFS{
listeners: make([]reflect.Value, 0),
lmu: &sync.RWMutex{},
}
return &PFS{}
}
func (p *PFS) Pub(args ...interface{}) bool {
@ -64,11 +64,40 @@ func (p *PFS) checkFuncSignature(f interface{}) (*reflect.Value, error) {
return nil, fmt.Errorf("Argument should be a function")
}
types := fnArgTypes(fn)
p.lmu.RLock()
defer p.lmu.RUnlock()
if len(p.listeners) != 0 {
// TODO: check fn arguments
if len(p.listeners) == 0 {
p.tmu.Lock()
defer p.tmu.Unlock()
p.argTypes = types
return &fn, nil
}
p.tmu.RLock()
defer p.tmu.RUnlock()
if len(types) != len(p.argTypes) {
return nil, fmt.Errorf("Argument length expected %d, but got %d", len(p.argTypes), len(types))
}
for i, t := range types {
if t != p.argTypes[i] {
return nil, fmt.Errorf("Argument Error. Args[%d] expected %s, but got %s", i, p.argTypes[i], t)
}
}
return &fn, nil
}
func fnArgTypes(fn reflect.Value) []reflect.Type {
fnType := fn.Type()
fnNum := fnType.NumIn()
types := make([]reflect.Type, 0, fnNum)
for i := 0; i < fnNum; i++ {
types = append(types, fnType.In(i))
}
return types
}

View File

@ -66,9 +66,24 @@ func TestManySub(t *testing.T) {
}
func TestSubWhenNotFunction(t *testing.T) {
pfs := pfs.New()
err := pfs.Sub("foobar")
p := pfs.New()
err := p.Sub("foobar")
if err == nil {
t.Error("should return error When recieve not function. But got nil.")
}
}
func TestSubWhenInvalidArgs(t *testing.T) {
p := pfs.New()
p.Sub(func(i int) {})
err := p.Sub(func() {})
if err == nil {
t.Error("Should return error when different argument num. But got nil")
}
err = p.Sub(func(s string) {})
if err == nil {
t.Error("Should return error when different args type. But got nil")
}
}