good morning!!!!

Skip to content
Snippets Groups Projects
Commit e07603bb authored by Janos Guljas's avatar Janos Guljas
Browse files

p2p/testing: check for all expectations in TestExchanges

Handle all expectations in ProtocolSession.TestExchanges in any
order that are received.
parent 40733908
No related branches found
No related tags found
No related merge requests found
......@@ -19,13 +19,17 @@ package testing
import (
"errors"
"fmt"
"sync"
"time"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/simulations/adapters"
)
var errTimedOut = errors.New("timed out")
// ProtocolSession is a quasi simulation of a pivot node running
// a service and a number of dummy peers that can send (trigger) or
// receive (expect) messages
......@@ -46,6 +50,7 @@ type Exchange struct {
Label string
Triggers []Trigger
Expects []Expect
Timeout time.Duration
}
// Trigger is part of the exchange, incoming message for the pivot node
......@@ -102,76 +107,145 @@ func (self *ProtocolSession) trigger(trig Trigger) error {
}
// expect checks an expectation of a message sent out by the pivot node
func (self *ProtocolSession) expect(exp Expect) error {
if exp.Msg == nil {
return errors.New("no message to expect")
}
simNode, ok := self.adapter.GetNode(exp.Peer)
if !ok {
return fmt.Errorf("trigger: peer %v does not exist (1- %v)", exp.Peer, len(self.IDs))
func (self *ProtocolSession) expect(exps []Expect) error {
// construct a map of expectations for each node
peerExpects := make(map[discover.NodeID][]Expect)
for _, exp := range exps {
if exp.Msg == nil {
return errors.New("no message to expect")
}
peerExpects[exp.Peer] = append(peerExpects[exp.Peer], exp)
}
mockNode, ok := simNode.Services()[0].(*mockNode)
if !ok {
return fmt.Errorf("trigger: peer %v is not a mock", exp.Peer)
// construct a map of mockNodes for each node
mockNodes := make(map[discover.NodeID]*mockNode)
for nodeID := range peerExpects {
simNode, ok := self.adapter.GetNode(nodeID)
if !ok {
return fmt.Errorf("trigger: peer %v does not exist (1- %v)", nodeID, len(self.IDs))
}
mockNode, ok := simNode.Services()[0].(*mockNode)
if !ok {
return fmt.Errorf("trigger: peer %v is not a mock", nodeID)
}
mockNodes[nodeID] = mockNode
}
// done chanell cancels all created goroutines when function returns
done := make(chan struct{})
defer close(done)
// errc catches the first error from
errc := make(chan error)
wg := &sync.WaitGroup{}
wg.Add(len(mockNodes))
for nodeID, mockNode := range mockNodes {
nodeID := nodeID
mockNode := mockNode
go func() {
defer wg.Done()
// Sum all Expect timeouts to give the maximum
// time for all expectations to finish.
// mockNode.Expect checks all received messages against
// a list of expected messages and timeout for each
// of them can not be checked separately.
var t time.Duration
for _, exp := range peerExpects[nodeID] {
if exp.Timeout == time.Duration(0) {
t += 2000 * time.Millisecond
} else {
t += exp.Timeout
}
}
alarm := time.NewTimer(t)
defer alarm.Stop()
// expectErrc is used to check if error returned
// from mockNode.Expect is not nil and to send it to
// errc only in that case.
// done channel will be closed when function
expectErrc := make(chan error)
go func() {
select {
case expectErrc <- mockNode.Expect(peerExpects[nodeID]...):
case <-done:
case <-alarm.C:
}
}()
select {
case err := <-expectErrc:
if err != nil {
select {
case errc <- err:
case <-done:
case <-alarm.C:
errc <- errTimedOut
}
}
case <-done:
case <-alarm.C:
errc <- errTimedOut
}
}()
}
go func() {
errc <- mockNode.Expect(&exp)
wg.Wait()
// close errc when all goroutines finish to return nill err from errc
close(errc)
}()
t := exp.Timeout
if t == time.Duration(0) {
t = 2000 * time.Millisecond
}
select {
case err := <-errc:
return err
case <-time.After(t):
return fmt.Errorf("timout expecting %v sent to peer %v", exp.Msg, exp.Peer)
}
return <-errc
}
// TestExchanges tests a series of exchanges against the session
func (self *ProtocolSession) TestExchanges(exchanges ...Exchange) error {
// launch all triggers of this exchanges
for i, e := range exchanges {
if err := self.testExchange(e); err != nil {
return fmt.Errorf("exchange #%d %q: %v", i, e.Label, err)
}
log.Trace(fmt.Sprintf("exchange #%d %q: run successfully", i, e.Label))
}
return nil
}
// testExchange tests a single Exchange.
// Default timeout value is 2 seconds.
func (self *ProtocolSession) testExchange(e Exchange) error {
errc := make(chan error)
done := make(chan struct{})
defer close(done)
for _, e := range exchanges {
errc := make(chan error, len(e.Triggers)+len(e.Expects))
go func() {
for _, trig := range e.Triggers {
errc <- self.trigger(trig)
err := self.trigger(trig)
if err != nil {
errc <- err
return
}
}
// each expectation is spawned in separate go-routine
// expectations of an exchange are conjunctive but unordered, i.e.,
// only all of them arriving constitutes a pass
// each expectation is meant to be for a different peer, otherwise they are expected to panic
// testing of an exchange blocks until all expectations are decided
// an expectation is decided if
// expected message arrives OR
// an unexpected message arrives (panic)
// times out on their individual timeout
for _, ex := range e.Expects {
// expect msg spawned to separate go routine
go func(exp Expect) {
errc <- self.expect(exp)
}(ex)
select {
case errc <- self.expect(e.Expects):
case <-done:
}
}()
// time out globally or finish when all expectations satisfied
timeout := time.After(5 * time.Second)
for i := 0; i < len(e.Triggers)+len(e.Expects); i++ {
select {
case err := <-errc:
if err != nil {
return fmt.Errorf("exchange failed with: %v", err)
}
case <-timeout:
return fmt.Errorf("exchange %v: '%v' timed out", i, e.Label)
}
}
// time out globally or finish when all expectations satisfied
t := e.Timeout
if t == 0 {
t = 2000 * time.Millisecond
}
alarm := time.NewTimer(t)
select {
case err := <-errc:
return err
case <-alarm.C:
return errTimedOut
}
return nil
}
// TestDisconnected tests the disconnections given as arguments
......
......@@ -24,7 +24,11 @@ that can be used to send and receive messages
package testing
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"strings"
"sync"
"testing"
......@@ -34,6 +38,7 @@ import (
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/ethereum/go-ethereum/p2p/simulations"
"github.com/ethereum/go-ethereum/p2p/simulations/adapters"
"github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/rpc"
)
......@@ -152,7 +157,7 @@ type mockNode struct {
testNode
trigger chan *Trigger
expect chan *Expect
expect chan []Expect
err chan error
stop chan struct{}
stopOnce sync.Once
......@@ -161,7 +166,7 @@ type mockNode struct {
func newMockNode() *mockNode {
mock := &mockNode{
trigger: make(chan *Trigger),
expect: make(chan *Expect),
expect: make(chan []Expect),
err: make(chan error),
stop: make(chan struct{}),
}
......@@ -176,8 +181,8 @@ func (m *mockNode) Run(peer *p2p.Peer, rw p2p.MsgReadWriter) error {
select {
case trig := <-m.trigger:
m.err <- p2p.Send(rw, trig.Code, trig.Msg)
case exp := <-m.expect:
m.err <- p2p.ExpectMsg(rw, exp.Code, exp.Msg)
case exps := <-m.expect:
m.err <- expectMsgs(rw, exps)
case <-m.stop:
return nil
}
......@@ -189,7 +194,7 @@ func (m *mockNode) Trigger(trig *Trigger) error {
return <-m.err
}
func (m *mockNode) Expect(exp *Expect) error {
func (m *mockNode) Expect(exp ...Expect) error {
m.expect <- exp
return <-m.err
}
......@@ -198,3 +203,67 @@ func (m *mockNode) Stop() error {
m.stopOnce.Do(func() { close(m.stop) })
return nil
}
func expectMsgs(rw p2p.MsgReadWriter, exps []Expect) error {
matched := make([]bool, len(exps))
for {
msg, err := rw.ReadMsg()
if err != nil {
if err == io.EOF {
break
}
return err
}
actualContent, err := ioutil.ReadAll(msg.Payload)
if err != nil {
return err
}
var found bool
for i, exp := range exps {
if exp.Code == msg.Code && bytes.Equal(actualContent, mustEncodeMsg(exp.Msg)) {
if matched[i] {
return fmt.Errorf("message #%d received two times", i)
}
matched[i] = true
found = true
break
}
}
if !found {
expected := make([]string, 0)
for i, exp := range exps {
if matched[i] {
continue
}
expected = append(expected, fmt.Sprintf("code %d payload %x", exp.Code, mustEncodeMsg(exp.Msg)))
}
return fmt.Errorf("unexpected message code %d payload %x, expected %s", msg.Code, actualContent, strings.Join(expected, " or "))
}
done := true
for _, m := range matched {
if !m {
done = false
break
}
}
if done {
return nil
}
}
for i, m := range matched {
if !m {
return fmt.Errorf("expected message #%d not received", i)
}
}
return nil
}
// mustEncodeMsg uses rlp to encode a message.
// In case of error it panics.
func mustEncodeMsg(msg interface{}) []byte {
contentEnc, err := rlp.EncodeToBytes(msg)
if err != nil {
panic("content encode error: " + err.Error())
}
return contentEnc
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment