Newer
Older
// This file is part of the go-ethereum library.
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
"github.com/ethereum/go-ethereum/crypto/sha3"
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
// init has no effect as written: both statements in its body are commented
// out. They can presumably be re-enabled to get verbose glog output on stderr
// while debugging these tests.
func init() {
// glog.SetV(6)
// glog.SetToStderr(true)
}
// testTransport is a transport stub for the server tests. It embeds *rlpx and
// declares its own doEncHandshake, doProtoHandshake and close methods, which
// return canned results instead of performing real handshakes.
type testTransport struct {
// id is the node ID reported by doEncHandshake and doProtoHandshake.
id discover.NodeID
*rlpx
// closeErr holds the error most recently passed to close.
closeErr error
}
// newTestTransport builds a testTransport around fd that reports id as the
// remote identity. The underlying rlpx value gets its frame reader/writer
// replaced by one created from fixed secrets: zero16 for both the MAC and AES
// entries and a fresh Keccak256 hash for each MAC direction.
func newTestTransport(id discover.NodeID, fd net.Conn) transport {
	inner := newRLPX(fd).(*rlpx)
	keys := secrets{
		MAC:        zero16,
		AES:        zero16,
		IngressMAC: sha3.NewKeccak256(),
		EgressMAC:  sha3.NewKeccak256(),
	}
	inner.rw = newRLPXFrameRW(fd, keys)
	return &testTransport{id: id, rlpx: inner}
}
// doEncHandshake performs no handshake at all. It reports the transport's
// configured id together with a nil error; prv and dialDest are ignored.
func (c *testTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) {
	remote := c.id
	return remote, nil
}
// doProtoHandshake returns a canned protocol handshake carrying the
// transport's id and the name "test". The our argument is ignored and the
// returned error is always nil.
func (c *testTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
	reply := &protoHandshake{
		ID:   c.id,
		Name: "test",
	}
	return reply, nil
}
// close shuts down the connection held by the embedded rlpx value and then
// stores err in closeErr so a test can inspect it. Whatever Close returns is
// discarded.
func (c *testTransport) close(err error) {
	conn := c.rlpx.fd
	conn.Close()
	c.closeErr = err
}
// startTestServer sets up a Server for tests and starts it, failing the test
// via t.Fatalf if Start returns an error. The visible configuration uses the
// name "test", MaxPeers 10, listen address 127.0.0.1:0, a fresh key from
// newkey, pf as newPeerHook, and a newTransport hook that returns
// newTestTransport(id, fd).
//
// NOTE(review): this block is incomplete as extracted. The statement that
// opens the literal holding the fields below, its closing brace, the closing
// braces after t.Fatalf and the return of the server are not present here.
// Restore them from the upstream file.
func startTestServer(t *testing.T, id discover.NodeID, pf func(*Peer)) *Server {
Name: "test",
MaxPeers: 10,
ListenAddr: "127.0.0.1:0",
PrivateKey: newkey(),
newPeerHook: pf,
newTransport: func(fd net.Conn) transport { return newTestTransport(id, fd) },
if err := server.Start(); err != nil {
t.Fatalf("Could not start server: %v", err)
// TestServerListen dials the listen address of a test server over TCP and
// inspects the peer received on the connected channel: its ID (checked in the
// peer hook), its local address against the dialing conn's remote address,
// and that srv.Peers() contains exactly that peer. It reports an error if no
// peer arrives within one second.
//
// NOTE(review): this block is incomplete as extracted. Among other things the
// condition guarding the "nil conn" error, the statement that delivers the
// peer on connected (presumably a send inside the hook), the arguments of the
// first t.Errorf and several closing braces are missing. Restore them from
// the upstream file.
func TestServerListen(t *testing.T) {
// start the test server
connected := make(chan *Peer)
remid := randomID()
srv := startTestServer(t, remid, func(p *Peer) {
if p.ID() != remid {
t.Error("peer func called with wrong node id")
}
t.Error("peer func called with nil conn")
}
})
defer close(connected)
defer srv.Stop()
// dial the test server
conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second)
if err != nil {
t.Fatalf("could not dial: %v", err)
select {
case peer := <-connected:
// The server-side peer's local address must be the client conn's remote address.
if peer.LocalAddr().String() != conn.RemoteAddr().String() {
t.Errorf("peer started with wrong conn: got %v, want %v",
peers := srv.Peers()
if !reflect.DeepEqual(peers, []*Peer{peer}) {
t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
}
case <-time.After(1 * time.Second):
t.Error("server did not accept within one second")
// TestServerDial starts a one-shot TCP listener, asks a test server to connect
// to it via AddPeer, and then checks the resulting peer: its ID, its name
// ("test"), its remote address against the accepted conn's local address, and
// that srv.Peers() contains exactly that peer. One-second timeouts guard the
// waits.
//
// NOTE(review): this block is incomplete as extracted. The declaration of the
// connected channel, the receive that binds peer, the arguments of the
// "wrong conn" t.Errorf and several closing braces are missing. Restore them
// from the upstream file.
func TestServerDial(t *testing.T) {
// run a one-shot TCP server to handle the connection.
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
// NOTE(review): the format string contains %v but no argument is passed;
// err is presumably the intended argument.
t.Fatalf("could not setup listener: %v")
}
defer listener.Close()
accepted := make(chan net.Conn)
go func() {
conn, err := listener.Accept()
// NOTE(review): as shown, an Accept error is not reported and conn is sent
// regardless; the body of this branch may have been lost in extraction.
if err != nil {
}
accepted <- conn
}()
remid := randomID()
srv := startTestServer(t, remid, func(p *Peer) { connected <- p })
defer close(connected)
defer srv.Stop()
// tell the server to connect
tcpAddr := listener.Addr().(*net.TCPAddr)
srv.AddPeer(&discover.Node{ID: remid, IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)})
select {
case conn := <-accepted:
if peer.ID() != remid {
t.Errorf("peer has wrong id")
}
if peer.Name() != "test" {
t.Errorf("peer has wrong name")
}
// The dialing peer's remote address must be the accepted conn's local address.
if peer.RemoteAddr().String() != conn.LocalAddr().String() {
t.Errorf("peer started with wrong conn: got %v, want %v",
peers := srv.Peers()
if !reflect.DeepEqual(peers, []*Peer{peer}) {
t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
}
case <-time.After(1 * time.Second):
t.Error("server did not launch peer within one second")
case <-time.After(1 * time.Second):
t.Error("server did not connect within one second")
// This test checks that tasks generated by dialstate are
// actually executed and taskdone is called for them.
//
// It drives srv.run with a taskgen whose newFunc hands out one testTask per
// call (indexed by the counter tc) and whose doneFunc forwards finished tasks
// on the done channel. The first 100 finished tasks are collected and must
// arrive in index order with called set.
//
// NOTE(review): this block is incomplete as extracted. The closing of the var
// block and of the taskgen and Server literals, the ends of both loops, the
// statement that makes run return (presumably closing quit) and the select
// header before "case <-returned" are missing. Restore them from the upstream
// file.
func TestServerTaskScheduling(t *testing.T) {
var (
done = make(chan *testTask)
quit, returned = make(chan struct{}), make(chan struct{})
tc = 0
tg = taskgen{
newFunc: func(running int, peers map[discover.NodeID]*Peer) []task {
tc++
return []task{&testTask{index: tc - 1}}
},
doneFunc: func(t task) {
// Either hand the finished task to the test or give up once quit is signalled.
select {
case done <- t.(*testTask):
case <-quit:
}
},
// The Server in this test isn't actually running
// because we're only interested in what run does.
srv := &Server{
MaxPeers: 10,
quit: make(chan struct{}),
ntab: fakeTable{},
running: true,
srv.loopWG.Add(1)
go func() {
srv.run(tg)
// returned is closed once run has come back, so the test can wait on it.
close(returned)
}()
var gotdone []*testTask
for i := 0; i < 100; i++ {
gotdone = append(gotdone, <-done)
for i, task := range gotdone {
if task.index != i {
t.Errorf("task %d has wrong index, got %d", i, task.index)
break
}
if !task.called {
t.Errorf("task %d was not called", i)
break
case <-returned:
case <-time.After(500 * time.Millisecond):
t.Error("Server.run did not return within 500ms")
// taskgen wraps two plain functions so they can be invoked through the
// newTasks and taskDone methods declared on this type.
type taskgen struct {
// newFunc is called by newTasks with the running count and the peer map.
newFunc func(running int, peers map[discover.NodeID]*Peer) []task
// doneFunc is called by taskDone with the task that was passed in.
doneFunc func(task)
}
// newTasks delegates to tg.newFunc, passing along running and peers and
// returning its result unchanged. The now argument is not used.
func (tg taskgen) newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task {
	generate := tg.newFunc
	return generate(running, peers)
}
// taskDone delegates to tg.doneFunc with the given task. The now argument is
// not used.
func (tg taskgen) taskDone(t task, now time.Time) {
	notify := tg.doneFunc
	notify(t)
}
// addStatic is a no-op; the node argument is ignored.
func (tg taskgen) addStatic(*discover.Node) {
}
// testTask is the task type handed out by the taskgen in
// TestServerTaskScheduling.
type testTask struct {
	index  int  // position of the task in generation order (tc - 1 at creation)
	called bool // set to true by Do
}
// Do marks the task as having been executed by setting called. The srv
// argument is not used.
func (tt *testTask) Do(srv *Server) {
	tt.called = true
}
// This test checks that connections are disconnected
// just after the encryption handshake when the server is
// at capacity. Trusted connections should still be accepted.
//
// Ten inbound conns with random IDs are pushed through the addpeer checkpoint.
// After that, a further random conn must be rejected at posthandshake with
// DiscTooManyPeers, while a conn using trustedID must pass posthandshake and
// carry the trustedConn flag.
//
// NOTE(review): this block is incomplete as extracted. The remaining fields
// and closing brace of the Server literal (MaxPeers is not visible here) and
// the closing braces of several if statements and of the function are
// missing. Restore them from the upstream file.
func TestServerAtCap(t *testing.T) {
trustedID := randomID()
srv := &Server{
PrivateKey: newkey(),
TrustedNodes: []*discover.Node{{ID: trustedID}},
if err := srv.Start(); err != nil {
t.Fatalf("could not start: %v", err)
// newconn builds an inbound conn for id on one end of an in-memory pipe; the
// other end of the pipe is discarded.
newconn := func(id discover.NodeID) *conn {
fd, _ := net.Pipe()
tx := newTestTransport(id, fd)
return &conn{fd: fd, transport: tx, flags: inboundConn, id: id, cont: make(chan error)}
}
// Inject a few connections to fill up the peer set.
for i := 0; i < 10; i++ {
c := newconn(randomID())
if err := srv.checkpoint(c, srv.addpeer); err != nil {
t.Fatalf("could not add conn %d: %v", i, err)
}
}
// Try inserting a non-trusted connection.
c := newconn(randomID())
if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers {
t.Error("wrong error for insert:", err)
// Try inserting a trusted connection.
c = newconn(trustedID)
if err := srv.checkpoint(c, srv.posthandshake); err != nil {
t.Error("unexpected error for trusted conn @posthandshake:", err)
if !c.is(trustedConn) {
t.Error("Server did not set trusted flag")
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
// TestServerSetupConn is a table-driven test of Server.setupConn. Each case
// supplies a setupTransport stub (returned by the server's newTransport hook),
// conn flags and an optional dial destination. After setupConn returns, the
// test compares the error the stub received in close with wantCloseErr and
// the recorded call sequence with wantCalls. Cases with dontstart set run
// against a server that was never started.
//
// NOTE(review): this block is incomplete as extracted. The closing brace of
// the test table, the closing brace of the Server literal, and the closing
// braces of the start-up if statements, the loop and the function are
// missing. Restore them from the upstream file.
func TestServerSetupConn(t *testing.T) {
id := randomID()
srvkey := newkey()
srvid := discover.PubkeyID(&srvkey.PublicKey)
tests := []struct {
dontstart bool
tt *setupTransport
flags connFlag
dialDest *discover.Node
wantCloseErr error
wantCalls string
}{
// Server not started: the conn is closed without any handshake call.
{
dontstart: true,
tt: &setupTransport{id: id},
wantCalls: "close,",
wantCloseErr: errServerStopped,
},
// Encryption handshake fails: its error is passed to close.
{
tt: &setupTransport{id: id, encHandshakeErr: errors.New("read error")},
flags: inboundConn,
wantCalls: "doEncHandshake,close,",
wantCloseErr: errors.New("read error"),
},
// Encryption handshake yields an ID different from the dial destination.
{
tt: &setupTransport{id: id},
dialDest: &discover.Node{ID: randomID()},
flags: dynDialedConn,
wantCalls: "doEncHandshake,close,",
wantCloseErr: DiscUnexpectedIdentity,
},
// Protocol handshake reports an ID different from the encryption handshake.
{
tt: &setupTransport{id: id, phs: &protoHandshake{ID: randomID()}},
dialDest: &discover.Node{ID: id},
flags: dynDialedConn,
wantCalls: "doEncHandshake,doProtoHandshake,close,",
wantCloseErr: DiscUnexpectedIdentity,
},
// Protocol handshake fails: its error is passed to close.
{
tt: &setupTransport{id: id, protoHandshakeErr: errors.New("foo")},
dialDest: &discover.Node{ID: id},
flags: dynDialedConn,
wantCalls: "doEncHandshake,doProtoHandshake,close,",
wantCloseErr: errors.New("foo"),
},
// Remote ID equals the server's own ID.
{
tt: &setupTransport{id: srvid, phs: &protoHandshake{ID: srvid}},
flags: inboundConn,
wantCalls: "doEncHandshake,close,",
wantCloseErr: DiscSelf,
},
// Both handshakes succeed but the conn is still closed with DiscUselessPeer.
{
tt: &setupTransport{id: id, phs: &protoHandshake{ID: id}},
flags: inboundConn,
wantCalls: "doEncHandshake,doProtoHandshake,close,",
wantCloseErr: DiscUselessPeer,
},
for i, test := range tests {
srv := &Server{
PrivateKey: srvkey,
MaxPeers: 10,
NoDial: true,
Protocols: []Protocol{discard},
newTransport: func(fd net.Conn) transport { return test.tt },
if !test.dontstart {
if err := srv.Start(); err != nil {
t.Fatalf("couldn't start server: %v", err)
// Only one end of the pipe is used; the other end is discarded.
p1, _ := net.Pipe()
srv.setupConn(p1, test.flags, test.dialDest)
// DeepEqual is used so that distinct errors.New values with equal text compare equal.
if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) {
t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr)
}
if test.tt.calls != test.wantCalls {
t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls)
}
// setupTransport is a transport stub used by TestServerSetupConn. Its methods
// return preconfigured results and record which of them were called.
type setupTransport struct {
// id is returned by doEncHandshake.
id discover.NodeID
// encHandshakeErr is returned by doEncHandshake alongside id.
encHandshakeErr error
// phs is returned by doProtoHandshake when protoHandshakeErr is nil.
phs *protoHandshake
// protoHandshakeErr, if non-nil, is returned by doProtoHandshake instead of phs.
protoHandshakeErr error
// calls accumulates the names of invoked methods, each followed by a comma.
calls string
// closeErr holds the error most recently passed to close.
closeErr error
}
// doEncHandshake logs the call in c.calls and then returns the configured id
// and encHandshakeErr. prv and dialDest are ignored.
func (c *setupTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) {
	c.calls = c.calls + "doEncHandshake,"
	remote, err := c.id, c.encHandshakeErr
	return remote, err
}
func (c *setupTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
c.calls += "doProtoHandshake,"
if c.protoHandshakeErr != nil {
return nil, c.protoHandshakeErr
return c.phs, nil
}
// close logs the call in c.calls and keeps err in closeErr so the test can
// compare it with the expected close error.
func (c *setupTransport) close(err error) {
	c.calls = c.calls + "close,"
	c.closeErr = err
}
// WriteMsg panics unconditionally. setupConn shouldn't write to or read from
// the connection, so reaching this method means the code under test did
// something unexpected.
func (c *setupTransport) WriteMsg(Msg) error {
	const reason = "WriteMsg called on setupTransport"
	panic(reason)
}
func (c *setupTransport) ReadMsg() (Msg, error) {
panic("ReadMsg called on setupTransport")
// newkey returns a freshly generated private key from crypto.GenerateKey.
// It panics with a descriptive message if key generation fails.
func newkey() *ecdsa.PrivateKey {
	generated, err := crypto.GenerateKey()
	if err == nil {
		return generated
	}
	panic("couldn't generate key: " + err.Error())
}
// randomID returns a node ID in which every byte is drawn independently from
// the rand package's top-level generator (presumably math/rand, given the use
// of Intn).
func randomID() (id discover.NodeID) {
	for i := range id {
		// Intn(256) yields 0..255 and so covers the whole byte range;
		// Intn(255) stops at 254 and could never produce 0xff.
		id[i] = byte(rand.Intn(256))
	}
	return id
}