mirror of https://github.com/pocke/goevent.git
check function argument when Sub
This commit is contained in:
parent
04c9d89576
commit
9a8ffa6529
43
pfs.go
43
pfs.go
|
@ -9,14 +9,14 @@ import (
|
||||||
type PFS struct {
|
type PFS struct {
|
||||||
// listeners are listener functions.
|
// listeners are listener functions.
|
||||||
listeners []reflect.Value
|
listeners []reflect.Value
|
||||||
lmu *sync.RWMutex
|
lmu sync.RWMutex
|
||||||
|
|
||||||
|
argTypes []reflect.Type
|
||||||
|
tmu sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func New() *PFS {
|
func New() *PFS {
|
||||||
return &PFS{
|
return &PFS{}
|
||||||
listeners: make([]reflect.Value, 0),
|
|
||||||
lmu: &sync.RWMutex{},
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *PFS) Pub(args ...interface{}) bool {
|
func (p *PFS) Pub(args ...interface{}) bool {
|
||||||
|
@ -64,11 +64,40 @@ func (p *PFS) checkFuncSignature(f interface{}) (*reflect.Value, error) {
|
||||||
return nil, fmt.Errorf("Argument should be a function")
|
return nil, fmt.Errorf("Argument should be a function")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
types := fnArgTypes(fn)
|
||||||
|
|
||||||
p.lmu.RLock()
|
p.lmu.RLock()
|
||||||
defer p.lmu.RUnlock()
|
defer p.lmu.RUnlock()
|
||||||
if len(p.listeners) != 0 {
|
if len(p.listeners) == 0 {
|
||||||
// TODO: check fn arguments
|
p.tmu.Lock()
|
||||||
|
defer p.tmu.Unlock()
|
||||||
|
p.argTypes = types
|
||||||
|
return &fn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
p.tmu.RLock()
|
||||||
|
defer p.tmu.RUnlock()
|
||||||
|
if len(types) != len(p.argTypes) {
|
||||||
|
return nil, fmt.Errorf("Argument length expected %d, but got %d", len(p.argTypes), len(types))
|
||||||
|
}
|
||||||
|
for i, t := range types {
|
||||||
|
if t != p.argTypes[i] {
|
||||||
|
return nil, fmt.Errorf("Argument Error. Args[%d] expected %s, but got %s", i, p.argTypes[i], t)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return &fn, nil
|
return &fn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func fnArgTypes(fn reflect.Value) []reflect.Type {
|
||||||
|
fnType := fn.Type()
|
||||||
|
fnNum := fnType.NumIn()
|
||||||
|
|
||||||
|
types := make([]reflect.Type, 0, fnNum)
|
||||||
|
|
||||||
|
for i := 0; i < fnNum; i++ {
|
||||||
|
types = append(types, fnType.In(i))
|
||||||
|
}
|
||||||
|
|
||||||
|
return types
|
||||||
|
}
|
||||||
|
|
19
pfs_test.go
19
pfs_test.go
|
@ -66,9 +66,24 @@ func TestManySub(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSubWhenNotFunction(t *testing.T) {
|
func TestSubWhenNotFunction(t *testing.T) {
|
||||||
pfs := pfs.New()
|
p := pfs.New()
|
||||||
err := pfs.Sub("foobar")
|
err := p.Sub("foobar")
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Error("should return error When recieve not function. But got nil.")
|
t.Error("should return error When recieve not function. But got nil.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSubWhenInvalidArgs(t *testing.T) {
|
||||||
|
p := pfs.New()
|
||||||
|
p.Sub(func(i int) {})
|
||||||
|
|
||||||
|
err := p.Sub(func() {})
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Should return error when different argument num. But got nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = p.Sub(func(s string) {})
|
||||||
|
if err == nil {
|
||||||
|
t.Error("Should return error when different args type. But got nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue