From 052901509166acd546b0a086327d89d639dce328 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Viktor=20Tr=C3=B3n?= <viktor.tron@gmail.com>
Date: Tue, 2 Apr 2019 09:15:16 +0200
Subject: [PATCH] swarm/network:   hive bug: needed shallow peers are not sent
 to nodes beyond connection's proximity order (#19326)

* swarm/network: fix hive bug not sending shallow peers

-  hive bug: needed shallow peers were not sent to nodes beyond connection's proximity order
- add extensive protocol exchange tests for initial subPeersMsg-peersMsg exchange
- modify bzzProtocolTester to allow pregenerated overlay addresses

* swarm/network: attempt to fix hive persistance test

* swarm/network: fix TestHiveStatePersistance (#1320)

* swarm/network: remove trace lines from the hive persistance test

* address PR review comments

* swarm/network: address PR comments on TestInitialPeersMsg

 * eliminate *testing.T argument from bzz/hive protocoltesters
 * add sorting (only runs in test code) on peersMsg payload
 * add random (0 to MaxPeersPerPO) peers for each po
 * add extra peers closer to pivot than control
---
 swarm/network/discovery.go      |  27 ++++-
 swarm/network/discovery_test.go | 206 +++++++++++++++++++++++++++++++-
 swarm/network/hive_test.go      |  84 ++++++-------
 swarm/network/protocol.go       |   3 +
 swarm/network/protocol_test.go  |  58 ++++++---
 5 files changed, 314 insertions(+), 64 deletions(-)

diff --git a/swarm/network/discovery.go b/swarm/network/discovery.go
index 4c503047a..54ecf257c 100644
--- a/swarm/network/discovery.go
+++ b/swarm/network/discovery.go
@@ -26,6 +26,8 @@ import (
 
 // discovery bzz extension for requesting and relaying node address records
 
+var sortPeers = noSortPeers
+
 // Peer wraps BzzPeer and embeds Kademlia overlay connectivity driver
 type Peer struct {
 	*BzzPeer
@@ -156,28 +158,39 @@ func (msg subPeersMsg) String() string {
 	return fmt.Sprintf("%T: request peers > PO%02d. ", msg, msg.Depth)
 }
 
+// handleSubPeersMsg handles incoming subPeersMsg
+// this message represents the saturation depth of the remote peer
+// saturation depth is the radius within which the peer subscribes to peers
+// the first time this is received we send peer info on all
+// our connected peers that fall within peers saturation depth
+// otherwise this depth is just recorded on the peer, so that
+// subsequent new connections are sent iff they fall within the radius
 func (d *Peer) handleSubPeersMsg(msg *subPeersMsg) error {
+	d.setDepth(msg.Depth)
+	// only send peers after the initial subPeersMsg
 	if !d.sentPeers {
-		d.setDepth(msg.Depth)
 		var peers []*BzzAddr
+		// iterate connection in ascending order of disctance from the remote address
 		d.kad.EachConn(d.Over(), 255, func(p *Peer, po int) bool {
-			if pob, _ := Pof(d, d.kad.BaseAddr(), 0); pob > po {
+			// terminate if we are beyond the radius
+			if uint8(po) < msg.Depth {
 				return false
 			}
-			if !d.seen(p.BzzAddr) {
+			if !d.seen(p.BzzAddr) { // here just records the peer sent
 				peers = append(peers, p.BzzAddr)
 			}
 			return true
 		})
+		// if useful  peers are found, send them over
 		if len(peers) > 0 {
-			go d.Send(context.TODO(), &peersMsg{Peers: peers})
+			go d.Send(context.TODO(), &peersMsg{Peers: sortPeers(peers)})
 		}
 	}
 	d.sentPeers = true
 	return nil
 }
 
-// seen takes an peer address and checks if it was sent to a peer already
+// seen takes a peer address and checks if it was sent to a peer already
 // if not, marks the peer as sent
 func (d *Peer) seen(p *BzzAddr) bool {
 	d.mtx.Lock()
@@ -201,3 +214,7 @@ func (d *Peer) setDepth(depth uint8) {
 	defer d.mtx.Unlock()
 	d.depth = depth
 }
+
+func noSortPeers(peers []*BzzAddr) []*BzzAddr {
+	return peers
+}
diff --git a/swarm/network/discovery_test.go b/swarm/network/discovery_test.go
index ea0d776e6..04e1b36fe 100644
--- a/swarm/network/discovery_test.go
+++ b/swarm/network/discovery_test.go
@@ -17,9 +17,22 @@
 package network
 
 import (
+	"crypto/ecdsa"
+	crand "crypto/rand"
+	"encoding/binary"
+	"fmt"
+	"math/rand"
+	"net"
+	"sort"
 	"testing"
+	"time"
 
+	"github.com/ethereum/go-ethereum/crypto"
+	"github.com/ethereum/go-ethereum/p2p"
+	"github.com/ethereum/go-ethereum/p2p/enode"
+	"github.com/ethereum/go-ethereum/p2p/protocols"
 	p2ptest "github.com/ethereum/go-ethereum/p2p/testing"
+	"github.com/ethereum/go-ethereum/swarm/pot"
 )
 
 /***
@@ -27,9 +40,9 @@ import (
  * - after connect, that outgoing subpeersmsg is sent
  *
  */
-func TestDiscovery(t *testing.T) {
+func TestSubPeersMsg(t *testing.T) {
 	params := NewHiveParams()
-	s, pp, err := newHiveTester(t, params, 1, nil)
+	s, pp, err := newHiveTester(params, 1, nil)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -58,3 +71,192 @@ func TestDiscovery(t *testing.T) {
 		t.Fatal(err)
 	}
 }
+
+const (
+	maxPO         = 8 // PO of pivot and control; chosen to test enough cases but not run too long
+	maxPeerPO     = 6 // pivot has no peers closer than this to the control peer
+	maxPeersPerPO = 3
+)
+
+// TestInitialPeersMsg tests if peersMsg response to incoming subPeersMsg is correct
+func TestInitialPeersMsg(t *testing.T) {
+	for po := 0; po < maxPO; po++ {
+		for depth := 0; depth < maxPO; depth++ {
+			t.Run(fmt.Sprintf("PO=%d,advertised depth=%d", po, depth), func(t *testing.T) {
+				testInitialPeersMsg(t, po, depth)
+			})
+		}
+	}
+}
+
+// testInitialPeersMsg tests that the correct set of peer info is sent
+// to another peer after receiving their subPeersMsg request
+func testInitialPeersMsg(t *testing.T, peerPO, peerDepth int) {
+	// generate random pivot address
+	prvkey, err := crypto.GenerateKey()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	defer func(orig func([]*BzzAddr) []*BzzAddr) {
+		sortPeers = orig
+	}(sortPeers)
+	sortPeers = testSortPeers
+	pivotAddr := pot.NewAddressFromBytes(PrivateKeyToBzzKey(prvkey))
+	// generate control peers address at peerPO wrt pivot
+	peerAddr := pot.RandomAddressAt(pivotAddr, peerPO)
+	// construct kademlia and hive
+	to := NewKademlia(pivotAddr[:], NewKadParams())
+	hive := NewHive(NewHiveParams(), to, nil)
+
+	// expected addrs in peersMsg response
+	var expBzzAddrs []*BzzAddr
+	connect := func(a pot.Address, po int) (addrs []*BzzAddr) {
+		n := rand.Intn(maxPeersPerPO)
+		for i := 0; i < n; i++ {
+			peer, err := newDiscPeer(pot.RandomAddressAt(a, po))
+			if err != nil {
+				t.Fatal(err)
+			}
+			hive.On(peer)
+			addrs = append(addrs, peer.BzzAddr)
+		}
+		return addrs
+	}
+	register := func(a pot.Address, po int) {
+		addr := pot.RandomAddressAt(a, po)
+		hive.Register(&BzzAddr{OAddr: addr[:]})
+	}
+
+	// generate connected and just registered peers
+	for po := maxPeerPO; po >= 0; po-- {
+		// create a fake connected peer at po from peerAddr
+		ons := connect(peerAddr, po)
+		// create a fake registered address at po from peerAddr
+		register(peerAddr, po)
+		// we collect expected peer addresses only up till peerPO
+		if po < peerDepth {
+			continue
+		}
+		expBzzAddrs = append(expBzzAddrs, ons...)
+	}
+
+	// add extra connections closer to pivot than control
+	for po := peerPO + 1; po < maxPO; po++ {
+		ons := connect(pivotAddr, po)
+		if peerDepth <= peerPO {
+			expBzzAddrs = append(expBzzAddrs, ons...)
+		}
+	}
+
+	// create a special bzzBaseTester in which we can associate `enode.ID` to the `bzzAddr` we created above
+	s, _, err := newBzzBaseTesterWithAddrs(prvkey, [][]byte{peerAddr[:]}, DiscoverySpec, hive.Run)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// peerID to use in the protocol tester testExchange expect/trigger
+	peerID := s.Nodes[0].ID()
+	// block until control peer is found among hive peers
+	found := false
+	for attempts := 0; attempts < 20; attempts++ {
+		if _, found = hive.peers[peerID]; found {
+			break
+		}
+		time.Sleep(1 * time.Millisecond)
+	}
+
+	if !found {
+		t.Fatal("timeout waiting for peer connection to start")
+	}
+
+	// pivotDepth is the advertised depth of the pivot node we expect in the outgoing subPeersMsg
+	pivotDepth := hive.saturation()
+	// the test exchange is as follows:
+	// 1. pivot sends to the control peer a `subPeersMsg` advertising its depth (ignored)
+	// 2. peer sends to pivot a `subPeersMsg` advertising its own depth (arbitrarily chosen)
+	// 3. pivot responds with `peersMsg` with the set of expected peers
+	err = s.TestExchanges(
+		p2ptest.Exchange{
+			Label: "outgoing subPeersMsg",
+			Expects: []p2ptest.Expect{
+				{
+					Code: 1,
+					Msg:  &subPeersMsg{Depth: uint8(pivotDepth)},
+					Peer: peerID,
+				},
+			},
+		},
+		p2ptest.Exchange{
+			Label: "trigger subPeersMsg and expect peersMsg",
+			Triggers: []p2ptest.Trigger{
+				{
+					Code: 1,
+					Msg:  &subPeersMsg{Depth: uint8(peerDepth)},
+					Peer: peerID,
+				},
+			},
+			Expects: []p2ptest.Expect{
+				{
+					Code:    0,
+					Msg:     &peersMsg{Peers: testSortPeers(expBzzAddrs)},
+					Peer:    peerID,
+					Timeout: 100 * time.Millisecond,
+				},
+			},
+		})
+
+	// for values MaxPeerPO < peerPO < MaxPO the pivot has no peers to offer to the control peer
+	// in this case, no peersMsg will be sent out, and we would run into a time out
+	if len(expBzzAddrs) == 0 {
+		if err != nil {
+			if err.Error() != "exchange #1 \"trigger subPeersMsg and expect peersMsg\": timed out" {
+				t.Fatalf("expected timeout, got %v", err)
+			}
+			return
+		}
+		t.Fatalf("expected timeout, got no error")
+	}
+
+	if err != nil {
+		t.Fatal(err)
+	}
+}
+
+func testSortPeers(peers []*BzzAddr) []*BzzAddr {
+	comp := func(i, j int) bool {
+		vi := binary.BigEndian.Uint64(peers[i].OAddr)
+		vj := binary.BigEndian.Uint64(peers[j].OAddr)
+		return vi < vj
+	}
+	sort.Slice(peers, comp)
+	return peers
+}
+
+// as we are not creating a real node via the protocol,
+// we need to create the discovery peer objects for the additional kademlia
+// nodes manually
+func newDiscPeer(addr pot.Address) (*Peer, error) {
+	pKey, err := ecdsa.GenerateKey(crypto.S256(), crand.Reader)
+	if err != nil {
+		return nil, err
+	}
+	pubKey := pKey.PublicKey
+	nod := enode.NewV4(&pubKey, net.IPv4(127, 0, 0, 1), 0, 0)
+	bzzAddr := &BzzAddr{OAddr: addr[:], UAddr: []byte(nod.String())}
+	id := nod.ID()
+	p2pPeer := p2p.NewPeer(id, id.String(), nil)
+	return NewPeer(&BzzPeer{
+		Peer:    protocols.NewPeer(p2pPeer, &dummyMsgRW{}, DiscoverySpec),
+		BzzAddr: bzzAddr,
+	}, nil), nil
+}
+
+type dummyMsgRW struct{}
+
+func (d *dummyMsgRW) ReadMsg() (p2p.Msg, error) {
+	return p2p.Msg{}, nil
+}
+func (d *dummyMsgRW) WriteMsg(msg p2p.Msg) error {
+	return nil
+}
diff --git a/swarm/network/hive_test.go b/swarm/network/hive_test.go
index ddae95a45..d03db42bc 100644
--- a/swarm/network/hive_test.go
+++ b/swarm/network/hive_test.go
@@ -23,11 +23,12 @@ import (
 	"time"
 
 	"github.com/ethereum/go-ethereum/crypto"
+	"github.com/ethereum/go-ethereum/p2p"
 	p2ptest "github.com/ethereum/go-ethereum/p2p/testing"
 	"github.com/ethereum/go-ethereum/swarm/state"
 )
 
-func newHiveTester(t *testing.T, params *HiveParams, n int, store state.Store) (*bzzTester, *Hive, error) {
+func newHiveTester(params *HiveParams, n int, store state.Store) (*bzzTester, *Hive, error) {
 	// setup
 	prvkey, err := crypto.GenerateKey()
 	if err != nil {
@@ -37,7 +38,7 @@ func newHiveTester(t *testing.T, params *HiveParams, n int, store state.Store) (
 	to := NewKademlia(addr, NewKadParams())
 	pp := NewHive(params, to, store) // hive
 
-	bt, err := newBzzBaseTester(t, n, prvkey, DiscoverySpec, pp.Run)
+	bt, err := newBzzBaseTester(n, prvkey, DiscoverySpec, pp.Run)
 	if err != nil {
 		return nil, nil, err
 	}
@@ -48,7 +49,7 @@ func newHiveTester(t *testing.T, params *HiveParams, n int, store state.Store) (
 // and that the peer connection exists afterwards
 func TestRegisterAndConnect(t *testing.T) {
 	params := NewHiveParams()
-	s, pp, err := newHiveTester(t, params, 1, nil)
+	s, pp, err := newHiveTester(params, 1, nil)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -108,65 +109,66 @@ func TestRegisterAndConnect(t *testing.T) {
 // Actual connectivity is not in scope for this test, as the peers loaded from state are not known to
 // the simulation; the test only verifies that the peers are known to the node
 func TestHiveStatePersistance(t *testing.T) {
-
 	dir, err := ioutil.TempDir("", "hive_test_store")
 	if err != nil {
-		panic(err)
+		t.Fatal(err)
 	}
 	defer os.RemoveAll(dir)
 
-	store, err := state.NewDBStore(dir) //start the hive with an empty dbstore
-	if err != nil {
-		t.Fatal(err)
-	}
+	const peersCount = 5
 
-	params := NewHiveParams()
-	s, pp, err := newHiveTester(t, params, 5, store)
-	if err != nil {
-		t.Fatal(err)
-	}
-	peers := make(map[string]bool)
-	for _, node := range s.Nodes {
-		raddr := NewAddr(node)
-		pp.Register(raddr)
-		peers[raddr.String()] = true
-	}
+	startHive := func(t *testing.T, dir string) (h *Hive) {
+		store, err := state.NewDBStore(dir)
+		if err != nil {
+			t.Fatal(err)
+		}
 
-	// start and stop the hive
-	// the known peers should be saved upon stopping
-	err = pp.Start(s.Server)
-	if err != nil {
-		t.Fatal(err)
-	}
-	pp.Stop()
-	store.Close()
+		params := NewHiveParams()
+		params.Discovery = false
 
-	// start the hive with an empty dbstore
-	persistedStore, err := state.NewDBStore(dir)
-	if err != nil {
-		t.Fatal(err)
+		prvkey, err := crypto.GenerateKey()
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		h = NewHive(params, NewKademlia(PrivateKeyToBzzKey(prvkey), NewKadParams()), store)
+		s := p2ptest.NewProtocolTester(prvkey, 0, func(p *p2p.Peer, rw p2p.MsgReadWriter) error { return nil })
+
+		if err := h.Start(s.Server); err != nil {
+			t.Fatal(err)
+		}
+		return h
 	}
 
-	s1, pp, err := newHiveTester(t, params, 0, persistedStore)
-	if err != nil {
+	h1 := startHive(t, dir)
+	peers := make(map[string]bool)
+	for i := 0; i < peersCount; i++ {
+		raddr := RandomAddr()
+		h1.Register(raddr)
+		peers[raddr.String()] = true
+	}
+	if err = h1.Stop(); err != nil {
 		t.Fatal(err)
 	}
 
 	// start the hive and check that we know of all expected peers
-	pp.Start(s1.Server)
+	h2 := startHive(t, dir)
+	defer func() {
+		if err = h2.Stop(); err != nil {
+			t.Fatal(err)
+		}
+	}()
+
 	i := 0
-	pp.Kademlia.EachAddr(nil, 256, func(addr *BzzAddr, po int) bool {
+	h2.Kademlia.EachAddr(nil, 256, func(addr *BzzAddr, po int) bool {
 		delete(peers, addr.String())
 		i++
 		return true
 	})
-	// TODO remove this line when verified that test passes
-	time.Sleep(time.Second)
-	if i != 5 {
-		t.Fatalf("invalid number of entries: got %v, want %v", i, 5)
+	if i != peersCount {
+		t.Fatalf("invalid number of entries: got %v, want %v", i, peersCount)
 	}
 	if len(peers) != 0 {
 		t.Fatalf("%d peers left over: %v", len(peers), peers)
 	}
-
 }
diff --git a/swarm/network/protocol.go b/swarm/network/protocol.go
index fcceb5c31..ad3f8df8f 100644
--- a/swarm/network/protocol.go
+++ b/swarm/network/protocol.go
@@ -20,6 +20,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"math/rand"
 	"sync"
 	"time"
 
@@ -37,6 +38,8 @@ const (
 	bzzHandshakeTimeout = 3000 * time.Millisecond
 )
 
+var DefaultTestNetworkID = rand.Uint64()
+
 // BzzSpec is the spec of the generic swarm handshake
 var BzzSpec = &protocols.Spec{
 	Name:       "bzz",
diff --git a/swarm/network/protocol_test.go b/swarm/network/protocol_test.go
index 1e7bb04aa..b562a4253 100644
--- a/swarm/network/protocol_test.go
+++ b/swarm/network/protocol_test.go
@@ -21,6 +21,7 @@ import (
 	"flag"
 	"fmt"
 	"os"
+	"sync"
 	"testing"
 	"time"
 
@@ -31,13 +32,15 @@ import (
 	"github.com/ethereum/go-ethereum/p2p/enr"
 	"github.com/ethereum/go-ethereum/p2p/protocols"
 	p2ptest "github.com/ethereum/go-ethereum/p2p/testing"
+	"github.com/ethereum/go-ethereum/swarm/pot"
 )
 
 const (
-	TestProtocolVersion   = 8
-	TestProtocolNetworkID = 3
+	TestProtocolVersion = 8
 )
 
+var TestProtocolNetworkID = DefaultTestNetworkID
+
 var (
 	loglevel = flag.Int("loglevel", 2, "verbosity of logs")
 )
@@ -70,20 +73,37 @@ func HandshakeMsgExchange(lhs, rhs *HandshakeMsg, id enode.ID) []p2ptest.Exchang
 	}
 }
 
-func newBzzBaseTester(t *testing.T, n int, prvkey *ecdsa.PrivateKey, spec *protocols.Spec, run func(*BzzPeer) error) (*bzzTester, error) {
-	cs := make(map[string]chan bool)
+func newBzzBaseTester(n int, prvkey *ecdsa.PrivateKey, spec *protocols.Spec, run func(*BzzPeer) error) (*bzzTester, error) {
+	var addrs [][]byte
+	for i := 0; i < n; i++ {
+		addr := pot.RandomAddress()
+		addrs = append(addrs, addr[:])
+	}
+	pt, _, err := newBzzBaseTesterWithAddrs(prvkey, addrs, spec, run)
+	return pt, err
+}
+
+func newBzzBaseTesterWithAddrs(prvkey *ecdsa.PrivateKey, addrs [][]byte, spec *protocols.Spec, run func(*BzzPeer) error) (*bzzTester, [][]byte, error) {
+	n := len(addrs)
+	cs := make(map[enode.ID]chan bool)
 
 	srv := func(p *BzzPeer) error {
 		defer func() {
-			if cs[p.ID().String()] != nil {
-				close(cs[p.ID().String()])
+			if cs[p.ID()] != nil {
+				close(cs[p.ID()])
 			}
 		}()
 		return run(p)
 	}
-
+	mu := &sync.Mutex{}
+	nodeToAddr := make(map[enode.ID][]byte)
 	protocol := func(p *p2p.Peer, rw p2p.MsgReadWriter) error {
-		return srv(&BzzPeer{Peer: protocols.NewPeer(p, rw, spec), BzzAddr: NewAddr(p.Node())})
+		mu.Lock()
+		defer mu.Unlock()
+		nodeToAddr[p.ID()] = addrs[0]
+		bzzAddr := &BzzAddr{addrs[0], []byte(p.Node().String())}
+		addrs = addrs[1:]
+		return srv(&BzzPeer{Peer: protocols.NewPeer(p, rw, spec), BzzAddr: bzzAddr})
 	}
 
 	s := p2ptest.NewProtocolTester(prvkey, n, protocol)
@@ -92,30 +112,36 @@ func newBzzBaseTester(t *testing.T, n int, prvkey *ecdsa.PrivateKey, spec *proto
 	record.Set(NewENRAddrEntry(bzzKey))
 	err := enode.SignV4(&record, prvkey)
 	if err != nil {
-		return nil, fmt.Errorf("unable to generate ENR: %v", err)
+		return nil, nil, fmt.Errorf("unable to generate ENR: %v", err)
 	}
 	nod, err := enode.New(enode.V4ID{}, &record)
 	if err != nil {
-		return nil, fmt.Errorf("unable to create enode: %v", err)
+		return nil, nil, fmt.Errorf("unable to create enode: %v", err)
 	}
 	addr := getENRBzzAddr(nod)
 
 	for _, node := range s.Nodes {
 		log.Warn("node", "node", node)
-		cs[node.ID().String()] = make(chan bool)
+		cs[node.ID()] = make(chan bool)
 	}
 
-	return &bzzTester{
+	var nodeAddrs [][]byte
+	pt := &bzzTester{
 		addr:           addr,
 		ProtocolTester: s,
 		cs:             cs,
-	}, nil
+	}
+	for _, n := range pt.Nodes {
+		nodeAddrs = append(nodeAddrs, nodeToAddr[n.ID()])
+	}
+
+	return pt, nodeAddrs, nil
 }
 
 type bzzTester struct {
 	*p2ptest.ProtocolTester
 	addr *BzzAddr
-	cs   map[string]chan bool
+	cs   map[enode.ID]chan bool
 	bzz  *Bzz
 }
 
@@ -124,7 +150,7 @@ func newBzz(addr *BzzAddr, lightNode bool) *Bzz {
 		OverlayAddr:  addr.Over(),
 		UnderlayAddr: addr.Under(),
 		HiveParams:   NewHiveParams(),
-		NetworkID:    DefaultNetworkID,
+		NetworkID:    DefaultTestNetworkID,
 		LightNode:    lightNode,
 	}
 	kad := NewKademlia(addr.OAddr, NewKadParams())
@@ -207,7 +233,7 @@ func TestBzzHandshakeNetworkIDMismatch(t *testing.T) {
 	err = s.testHandshake(
 		correctBzzHandshake(s.addr, lightNode),
 		&HandshakeMsg{Version: TestProtocolVersion, NetworkID: 321, Addr: NewAddr(node)},
-		&p2ptest.Disconnect{Peer: node.ID(), Error: fmt.Errorf("Handshake error: Message handler error: (msg code 0): network id mismatch 321 (!= 3)")},
+		&p2ptest.Disconnect{Peer: node.ID(), Error: fmt.Errorf("Handshake error: Message handler error: (msg code 0): network id mismatch 321 (!= %v)", TestProtocolNetworkID)},
 	)
 
 	if err != nil {
-- 
GitLab