From 85d81b2cdde6f5377fa3af6e108ca0b84a6266bd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Felf=C3=B6ldi=20Zsolt?= <zsfelfoldi@gmail.com>
Date: Wed, 21 Oct 2020 10:56:33 +0200
Subject: [PATCH] les: remove clientPeerSet and serverSet (#21566)

* les: move NodeStateMachine from clientPool to LesServer

* les: new header broadcaster

* les: peerCommons.headInfo always contains last announced head

* les: remove clientPeerSet and serverSet

* les: fixed panic

* les: fixed --nodiscover option

* les: disconnect all peers at ns.Stop()

* les: added comments and fixed signed broadcasts

* les: removed unused parameter, fixed tests
---
 les/client_handler.go             |   8 +-
 les/clientpool.go                 |  49 +++----
 les/clientpool_test.go            |  43 ++++--
 les/enr_entry.go                  |   2 +-
 les/lespay/server/prioritypool.go |   6 +-
 les/peer.go                       | 223 ++++--------------------------
 les/protocol.go                   |  10 +-
 les/server.go                     |  64 ++++++---
 les/server_handler.go             | 150 +++++++++++++-------
 les/test_helper.go                |  16 ++-
 10 files changed, 239 insertions(+), 332 deletions(-)

diff --git a/les/client_handler.go b/les/client_handler.go
index cfeec7a03..77a0ea5c6 100644
--- a/les/client_handler.go
+++ b/les/client_handler.go
@@ -102,13 +102,7 @@ func (h *clientHandler) handle(p *serverPeer) error {
 	p.Log().Debug("Light Ethereum peer connected", "name", p.Name())
 
 	// Execute the LES handshake
-	var (
-		head   = h.backend.blockchain.CurrentHeader()
-		hash   = head.Hash()
-		number = head.Number.Uint64()
-		td     = h.backend.blockchain.GetTd(hash, number)
-	)
-	if err := p.Handshake(td, hash, number, h.backend.blockchain.Genesis().Hash(), nil); err != nil {
+	if err := p.Handshake(h.backend.blockchain.Genesis().Hash()); err != nil {
 		p.Log().Debug("Light Ethereum handshake failed", "err", err)
 		return err
 	}
diff --git a/les/clientpool.go b/les/clientpool.go
index 4f6e3fafe..da0db6e62 100644
--- a/les/clientpool.go
+++ b/les/clientpool.go
@@ -18,7 +18,6 @@ package les
 
 import (
 	"fmt"
-	"reflect"
 	"sync"
 	"time"
 
@@ -46,19 +45,6 @@ const (
 	inactiveTimeout      = time.Second * 10
 )
 
-var (
-	clientPoolSetup     = &nodestate.Setup{}
-	clientField         = clientPoolSetup.NewField("clientInfo", reflect.TypeOf(&clientInfo{}))
-	connAddressField    = clientPoolSetup.NewField("connAddr", reflect.TypeOf(""))
-	balanceTrackerSetup = lps.NewBalanceTrackerSetup(clientPoolSetup)
-	priorityPoolSetup   = lps.NewPriorityPoolSetup(clientPoolSetup)
-)
-
-func init() {
-	balanceTrackerSetup.Connect(connAddressField, priorityPoolSetup.CapacityField)
-	priorityPoolSetup.Connect(balanceTrackerSetup.BalanceField, balanceTrackerSetup.UpdateFlag) // NodeBalance implements nodePriority
-}
-
 // clientPool implements a client database that assigns a priority to each client
 // based on a positive and negative balance. Positive balance is externally assigned
 // to prioritized clients and is decreased with connection time and processed
@@ -119,8 +105,7 @@ type clientInfo struct {
 }
 
 // newClientPool creates a new client pool
-func newClientPool(lespayDb ethdb.Database, minCap uint64, connectedBias time.Duration, clock mclock.Clock, removePeer func(enode.ID)) *clientPool {
-	ns := nodestate.NewNodeStateMachine(nil, nil, clock, clientPoolSetup)
+func newClientPool(ns *nodestate.NodeStateMachine, lespayDb ethdb.Database, minCap uint64, connectedBias time.Duration, clock mclock.Clock, removePeer func(enode.ID)) *clientPool {
 	pool := &clientPool{
 		ns:                  ns,
 		BalanceTrackerSetup: balanceTrackerSetup,
@@ -147,7 +132,7 @@ func newClientPool(lespayDb ethdb.Database, minCap uint64, connectedBias time.Du
 	})
 
 	ns.SubscribeState(pool.ActiveFlag.Or(pool.PriorityFlag), func(node *enode.Node, oldState, newState nodestate.Flags) {
-		c, _ := ns.GetField(node, clientField).(*clientInfo)
+		c, _ := ns.GetField(node, clientInfoField).(*clientInfo)
 		if c == nil {
 			return
 		}
@@ -172,7 +157,7 @@ func newClientPool(lespayDb ethdb.Database, minCap uint64, connectedBias time.Du
 		if oldState.Equals(pool.ActiveFlag) && newState.Equals(pool.InactiveFlag) {
 			clientDeactivatedMeter.Mark(1)
 			log.Debug("Client deactivated", "id", node.ID())
-			c, _ := ns.GetField(node, clientField).(*clientInfo)
+			c, _ := ns.GetField(node, clientInfoField).(*clientInfo)
 			if c == nil || !c.peer.allowInactive() {
 				pool.removePeer(node.ID())
 			}
@@ -190,13 +175,11 @@ func newClientPool(lespayDb ethdb.Database, minCap uint64, connectedBias time.Du
 		newCap, _ := newValue.(uint64)
 		totalConnected += newCap - oldCap
 		totalConnectedGauge.Update(int64(totalConnected))
-		c, _ := ns.GetField(node, clientField).(*clientInfo)
+		c, _ := ns.GetField(node, clientInfoField).(*clientInfo)
 		if c != nil {
 			c.peer.updateCapacity(newCap)
 		}
 	})
-
-	ns.Start()
 	return pool
 }
 
@@ -210,7 +193,6 @@ func (f *clientPool) stop() {
 		f.disconnectNode(node)
 	})
 	f.bt.Stop()
-	f.ns.Stop()
 }
 
 // connect should be called after a successful handshake. If the connection was
@@ -225,7 +207,7 @@ func (f *clientPool) connect(peer clientPoolPeer) (uint64, error) {
 	}
 	// Dedup connected peers.
 	node, freeID := peer.Node(), peer.freeClientId()
-	if f.ns.GetField(node, clientField) != nil {
+	if f.ns.GetField(node, clientInfoField) != nil {
 		log.Debug("Client already connected", "address", freeID, "id", node.ID().String())
 		return 0, fmt.Errorf("Client already connected address=%s id=%s", freeID, node.ID().String())
 	}
@@ -237,7 +219,7 @@ func (f *clientPool) connect(peer clientPoolPeer) (uint64, error) {
 		connected:   true,
 		connectedAt: now,
 	}
-	f.ns.SetField(node, clientField, c)
+	f.ns.SetField(node, clientInfoField, c)
 	f.ns.SetField(node, connAddressField, freeID)
 	if c.balance, _ = f.ns.GetField(node, f.BalanceField).(*lps.NodeBalance); c.balance == nil {
 		f.disconnect(peer)
@@ -280,7 +262,7 @@ func (f *clientPool) disconnect(p clientPoolPeer) {
 // disconnectNode removes node fields and flags related to connected status
 func (f *clientPool) disconnectNode(node *enode.Node) {
 	f.ns.SetField(node, connAddressField, nil)
-	f.ns.SetField(node, clientField, nil)
+	f.ns.SetField(node, clientInfoField, nil)
 }
 
 // setDefaultFactors sets the default price factors applied to subsequently connected clients
@@ -299,7 +281,8 @@ func (f *clientPool) capacityInfo() (uint64, uint64, uint64) {
 	defer f.lock.Unlock()
 
 	// total priority active cap will be supported when the token issuer module is added
-	return f.capLimit, f.pp.ActiveCapacity(), 0
+	_, activeCap := f.pp.Active()
+	return f.capLimit, activeCap, 0
 }
 
 // setLimits sets the maximum number and total capacity of connected clients,
@@ -314,13 +297,13 @@ func (f *clientPool) setLimits(totalConn int, totalCap uint64) {
 
 // setCapacity sets the assigned capacity of a connected client
 func (f *clientPool) setCapacity(node *enode.Node, freeID string, capacity uint64, bias time.Duration, setCap bool) (uint64, error) {
-	c, _ := f.ns.GetField(node, clientField).(*clientInfo)
+	c, _ := f.ns.GetField(node, clientInfoField).(*clientInfo)
 	if c == nil {
 		if setCap {
 			return 0, fmt.Errorf("client %064x is not connected", node.ID())
 		}
 		c = &clientInfo{node: node}
-		f.ns.SetField(node, clientField, c)
+		f.ns.SetField(node, clientInfoField, c)
 		f.ns.SetField(node, connAddressField, freeID)
 		if c.balance, _ = f.ns.GetField(node, f.BalanceField).(*lps.NodeBalance); c.balance == nil {
 			log.Error("BalanceField is missing", "node", node.ID())
@@ -328,7 +311,7 @@ func (f *clientPool) setCapacity(node *enode.Node, freeID string, capacity uint6
 		}
 		defer func() {
 			f.ns.SetField(node, connAddressField, nil)
-			f.ns.SetField(node, clientField, nil)
+			f.ns.SetField(node, clientInfoField, nil)
 		}()
 	}
 	var (
@@ -370,7 +353,7 @@ func (f *clientPool) forClients(ids []enode.ID, cb func(client *clientInfo)) {
 
 	if len(ids) == 0 {
 		f.ns.ForEach(nodestate.Flags{}, nodestate.Flags{}, func(node *enode.Node, state nodestate.Flags) {
-			c, _ := f.ns.GetField(node, clientField).(*clientInfo)
+			c, _ := f.ns.GetField(node, clientInfoField).(*clientInfo)
 			if c != nil {
 				cb(c)
 			}
@@ -381,12 +364,12 @@ func (f *clientPool) forClients(ids []enode.ID, cb func(client *clientInfo)) {
 			if node == nil {
 				node = enode.SignNull(&enr.Record{}, id)
 			}
-			c, _ := f.ns.GetField(node, clientField).(*clientInfo)
+			c, _ := f.ns.GetField(node, clientInfoField).(*clientInfo)
 			if c != nil {
 				cb(c)
 			} else {
 				c = &clientInfo{node: node}
-				f.ns.SetField(node, clientField, c)
+				f.ns.SetField(node, clientInfoField, c)
 				f.ns.SetField(node, connAddressField, "")
 				if c.balance, _ = f.ns.GetField(node, f.BalanceField).(*lps.NodeBalance); c.balance != nil {
 					cb(c)
@@ -394,7 +377,7 @@ func (f *clientPool) forClients(ids []enode.ID, cb func(client *clientInfo)) {
 					log.Error("BalanceField is missing")
 				}
 				f.ns.SetField(node, connAddressField, nil)
-				f.ns.SetField(node, clientField, nil)
+				f.ns.SetField(node, clientInfoField, nil)
 			}
 		}
 	}
diff --git a/les/clientpool_test.go b/les/clientpool_test.go
index cfd1486b4..b1c38d374 100644
--- a/les/clientpool_test.go
+++ b/les/clientpool_test.go
@@ -64,6 +64,11 @@ type poolTestPeer struct {
 	inactiveAllowed bool
 }
 
+func testStateMachine() *nodestate.NodeStateMachine {
+	return nodestate.NewNodeStateMachine(nil, nil, mclock.System{}, serverSetup)
+
+}
+
 func newPoolTestPeer(i int, disconnCh chan int) *poolTestPeer {
 	return &poolTestPeer{
 		index:     i,
@@ -91,7 +96,7 @@ func (i *poolTestPeer) allowInactive() bool {
 }
 
 func getBalance(pool *clientPool, p *poolTestPeer) (pos, neg uint64) {
-	temp := pool.ns.GetField(p.node, clientField) == nil
+	temp := pool.ns.GetField(p.node, clientInfoField) == nil
 	if temp {
 		pool.ns.SetField(p.node, connAddressField, p.freeClientId())
 	}
@@ -128,8 +133,9 @@ func testClientPool(t *testing.T, activeLimit, clientCount, paidCount int, rando
 		disconnFn = func(id enode.ID) {
 			disconnCh <- int(id[0]) + int(id[1])<<8
 		}
-		pool = newClientPool(db, 1, 0, &clock, disconnFn)
+		pool = newClientPool(testStateMachine(), db, 1, 0, &clock, disconnFn)
 	)
+	pool.ns.Start()
 
 	pool.setLimits(activeLimit, uint64(activeLimit))
 	pool.setDefaultFactors(lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1})
@@ -233,7 +239,8 @@ func TestConnectPaidClient(t *testing.T) {
 		clock mclock.Simulated
 		db    = rawdb.NewMemoryDatabase()
 	)
-	pool := newClientPool(db, 1, defaultConnectedBias, &clock, func(id enode.ID) {})
+	pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, func(id enode.ID) {})
+	pool.ns.Start()
 	defer pool.stop()
 	pool.setLimits(10, uint64(10))
 	pool.setDefaultFactors(lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1})
@@ -248,7 +255,8 @@ func TestConnectPaidClientToSmallPool(t *testing.T) {
 		clock mclock.Simulated
 		db    = rawdb.NewMemoryDatabase()
 	)
-	pool := newClientPool(db, 1, defaultConnectedBias, &clock, func(id enode.ID) {})
+	pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, func(id enode.ID) {})
+	pool.ns.Start()
 	defer pool.stop()
 	pool.setLimits(10, uint64(10)) // Total capacity limit is 10
 	pool.setDefaultFactors(lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1})
@@ -266,7 +274,8 @@ func TestConnectPaidClientToFullPool(t *testing.T) {
 		db    = rawdb.NewMemoryDatabase()
 	)
 	removeFn := func(enode.ID) {} // Noop
-	pool := newClientPool(db, 1, defaultConnectedBias, &clock, removeFn)
+	pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, removeFn)
+	pool.ns.Start()
 	defer pool.stop()
 	pool.setLimits(10, uint64(10)) // Total capacity limit is 10
 	pool.setDefaultFactors(lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1})
@@ -295,7 +304,8 @@ func TestPaidClientKickedOut(t *testing.T) {
 	removeFn := func(id enode.ID) {
 		kickedCh <- int(id[0])
 	}
-	pool := newClientPool(db, 1, defaultConnectedBias, &clock, removeFn)
+	pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, removeFn)
+	pool.ns.Start()
 	pool.bt.SetExpirationTCs(0, 0)
 	defer pool.stop()
 	pool.setLimits(10, uint64(10)) // Total capacity limit is 10
@@ -325,7 +335,8 @@ func TestConnectFreeClient(t *testing.T) {
 		clock mclock.Simulated
 		db    = rawdb.NewMemoryDatabase()
 	)
-	pool := newClientPool(db, 1, defaultConnectedBias, &clock, func(id enode.ID) {})
+	pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, func(id enode.ID) {})
+	pool.ns.Start()
 	defer pool.stop()
 	pool.setLimits(10, uint64(10))
 	pool.setDefaultFactors(lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1})
@@ -341,7 +352,8 @@ func TestConnectFreeClientToFullPool(t *testing.T) {
 		db    = rawdb.NewMemoryDatabase()
 	)
 	removeFn := func(enode.ID) {} // Noop
-	pool := newClientPool(db, 1, defaultConnectedBias, &clock, removeFn)
+	pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, removeFn)
+	pool.ns.Start()
 	defer pool.stop()
 	pool.setLimits(10, uint64(10)) // Total capacity limit is 10
 	pool.setDefaultFactors(lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1})
@@ -370,7 +382,8 @@ func TestFreeClientKickedOut(t *testing.T) {
 		kicked = make(chan int, 100)
 	)
 	removeFn := func(id enode.ID) { kicked <- int(id[0]) }
-	pool := newClientPool(db, 1, defaultConnectedBias, &clock, removeFn)
+	pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, removeFn)
+	pool.ns.Start()
 	defer pool.stop()
 	pool.setLimits(10, uint64(10)) // Total capacity limit is 10
 	pool.setDefaultFactors(lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1})
@@ -411,7 +424,8 @@ func TestPositiveBalanceCalculation(t *testing.T) {
 		kicked = make(chan int, 10)
 	)
 	removeFn := func(id enode.ID) { kicked <- int(id[0]) } // Noop
-	pool := newClientPool(db, 1, defaultConnectedBias, &clock, removeFn)
+	pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, removeFn)
+	pool.ns.Start()
 	defer pool.stop()
 	pool.setLimits(10, uint64(10)) // Total capacity limit is 10
 	pool.setDefaultFactors(lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1})
@@ -434,7 +448,8 @@ func TestDowngradePriorityClient(t *testing.T) {
 		kicked = make(chan int, 10)
 	)
 	removeFn := func(id enode.ID) { kicked <- int(id[0]) } // Noop
-	pool := newClientPool(db, 1, defaultConnectedBias, &clock, removeFn)
+	pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, removeFn)
+	pool.ns.Start()
 	defer pool.stop()
 	pool.setLimits(10, uint64(10)) // Total capacity limit is 10
 	pool.setDefaultFactors(lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 1, CapacityFactor: 0, RequestFactor: 1})
@@ -468,7 +483,8 @@ func TestNegativeBalanceCalculation(t *testing.T) {
 		clock mclock.Simulated
 		db    = rawdb.NewMemoryDatabase()
 	)
-	pool := newClientPool(db, 1, defaultConnectedBias, &clock, func(id enode.ID) {})
+	pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, func(id enode.ID) {})
+	pool.ns.Start()
 	defer pool.stop()
 	pool.setLimits(10, uint64(10)) // Total capacity limit is 10
 	pool.setDefaultFactors(lps.PriceFactors{TimeFactor: 1e-3, CapacityFactor: 0, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 1e-3, CapacityFactor: 0, RequestFactor: 1})
@@ -503,7 +519,8 @@ func TestInactiveClient(t *testing.T) {
 		clock mclock.Simulated
 		db    = rawdb.NewMemoryDatabase()
 	)
-	pool := newClientPool(db, 1, defaultConnectedBias, &clock, func(id enode.ID) {})
+	pool := newClientPool(testStateMachine(), db, 1, defaultConnectedBias, &clock, func(id enode.ID) {})
+	pool.ns.Start()
 	defer pool.stop()
 	pool.setLimits(2, uint64(2))
 
diff --git a/les/enr_entry.go b/les/enr_entry.go
index 65d0d1fdb..11e6273be 100644
--- a/les/enr_entry.go
+++ b/les/enr_entry.go
@@ -36,7 +36,7 @@ func (e lesEntry) ENRKey() string {
 
 // setupDiscovery creates the node discovery source for the eth protocol.
 func (eth *LightEthereum) setupDiscovery(cfg *p2p.Config) (enode.Iterator, error) {
-	if /*cfg.NoDiscovery || */ len(eth.config.DiscoveryURLs) == 0 {
+	if cfg.NoDiscovery || len(eth.config.DiscoveryURLs) == 0 {
 		return nil, nil
 	}
 	client := dnsdisc.NewClient(dnsdisc.Config{})
diff --git a/les/lespay/server/prioritypool.go b/les/lespay/server/prioritypool.go
index 52224e093..c0c33840c 100644
--- a/les/lespay/server/prioritypool.go
+++ b/les/lespay/server/prioritypool.go
@@ -253,12 +253,12 @@ func (pp *PriorityPool) SetActiveBias(bias time.Duration) {
 	pp.tryActivate()
 }
 
-// ActiveCapacity returns the total capacity of currently active nodes
-func (pp *PriorityPool) ActiveCapacity() uint64 {
+// Active returns the number and total capacity of currently active nodes
+func (pp *PriorityPool) Active() (uint64, uint64) {
 	pp.lock.Lock()
 	defer pp.lock.Unlock()
 
-	return pp.activeCap
+	return pp.activeCount, pp.activeCap
 }
 
 // inactiveSetIndex callback updates ppNodeInfo item index in inactiveQueue
diff --git a/les/peer.go b/les/peer.go
index 0549daf9a..2b0117bed 100644
--- a/les/peer.go
+++ b/les/peer.go
@@ -126,7 +126,7 @@ type peerCommons struct {
 	frozen       uint32    // Flag whether the peer is frozen.
 	announceType uint64    // New block announcement type.
 	serving      uint32    // The status indicates the peer is served.
-	headInfo     blockInfo // Latest block information.
+	headInfo     blockInfo // Last announced block information.
 
 	// Background task queue for caching peer tasks and executing in order.
 	sendQueue *utils.ExecQueue
@@ -255,6 +255,8 @@ func (p *peerCommons) handshake(td *big.Int, head common.Hash, headNum uint64, g
 	// Add some basic handshake fields
 	send = send.add("protocolVersion", uint64(p.version))
 	send = send.add("networkId", p.network)
+	// Note: the head info announced at handshake is only used in case of server peers
+	// but dummy values are still announced by clients for compatibility with older servers
 	send = send.add("headTd", td)
 	send = send.add("headHash", head)
 	send = send.add("headNum", headNum)
@@ -273,24 +275,14 @@ func (p *peerCommons) handshake(td *big.Int, head common.Hash, headNum uint64, g
 	if size > allowedUpdateBytes {
 		return errResp(ErrRequestRejected, "")
 	}
-	var rGenesis, rHash common.Hash
-	var rVersion, rNetwork, rNum uint64
-	var rTd *big.Int
+	var rGenesis common.Hash
+	var rVersion, rNetwork uint64
 	if err := recv.get("protocolVersion", &rVersion); err != nil {
 		return err
 	}
 	if err := recv.get("networkId", &rNetwork); err != nil {
 		return err
 	}
-	if err := recv.get("headTd", &rTd); err != nil {
-		return err
-	}
-	if err := recv.get("headHash", &rHash); err != nil {
-		return err
-	}
-	if err := recv.get("headNum", &rNum); err != nil {
-		return err
-	}
 	if err := recv.get("genesisHash", &rGenesis); err != nil {
 		return err
 	}
@@ -303,7 +295,6 @@ func (p *peerCommons) handshake(td *big.Int, head common.Hash, headNum uint64, g
 	if int(rVersion) != p.version {
 		return errResp(ErrProtocolVersionMismatch, "%d (!= %d)", rVersion, p.version)
 	}
-	p.headInfo = blockInfo{Hash: rHash, Number: rNum, Td: rTd}
 	if recvCallback != nil {
 		return recvCallback(recv)
 	}
@@ -569,9 +560,11 @@ func (p *serverPeer) updateHead(hash common.Hash, number uint64, td *big.Int) {
 }
 
 // Handshake executes the les protocol handshake, negotiating version number,
-// network IDs, difficulties, head and genesis blocks.
-func (p *serverPeer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis common.Hash, server *LesServer) error {
-	return p.handshake(td, head, headNum, genesis, func(lists *keyValueList) {
+// network IDs and genesis blocks.
+func (p *serverPeer) Handshake(genesis common.Hash) error {
+	// Note: there is no need to share local head with a server but older servers still
+	// require these fields so we announce zero values.
+	return p.handshake(common.Big0, common.Hash{}, 0, genesis, func(lists *keyValueList) {
 		// Add some client-specific handshake fields
 		//
 		// Enable signed announcement randomly even the server is not trusted.
@@ -581,6 +574,21 @@ func (p *serverPeer) Handshake(td *big.Int, head common.Hash, headNum uint64, ge
 		}
 		*lists = (*lists).add("announceType", p.announceType)
 	}, func(recv keyValueMap) error {
+		var (
+			rHash common.Hash
+			rNum  uint64
+			rTd   *big.Int
+		)
+		if err := recv.get("headTd", &rTd); err != nil {
+			return err
+		}
+		if err := recv.get("headHash", &rHash); err != nil {
+			return err
+		}
+		if err := recv.get("headNum", &rNum); err != nil {
+			return err
+		}
+		p.headInfo = blockInfo{Hash: rHash, Number: rNum, Td: rTd}
 		if recv.get("serveChainSince", &p.chainSince) != nil {
 			p.onlyAnnounce = true
 		}
@@ -937,6 +945,9 @@ func (p *clientPeer) freezeClient() {
 // Handshake executes the les protocol handshake, negotiating version number,
 // network IDs, difficulties, head and genesis blocks.
 func (p *clientPeer) Handshake(td *big.Int, head common.Hash, headNum uint64, genesis common.Hash, server *LesServer) error {
+	// Note: clientPeer.headInfo should contain the last head announced to the client by us.
+	// The values announced in the handshake are dummy values for compatibility reasons and should be ignored.
+	p.headInfo = blockInfo{Hash: head, Number: headNum, Td: td}
 	return p.handshake(td, head, headNum, genesis, func(lists *keyValueList) {
 		// Add some information which services server can offer.
 		if !server.config.UltraLightOnlyAnnounce {
@@ -1009,145 +1020,6 @@ type serverPeerSubscriber interface {
 	unregisterPeer(*serverPeer)
 }
 
-// clientPeerSubscriber is an interface to notify services about added or
-// removed client peers
-type clientPeerSubscriber interface {
-	registerPeer(*clientPeer)
-	unregisterPeer(*clientPeer)
-}
-
-// clientPeerSet represents the set of active client peers currently
-// participating in the Light Ethereum sub-protocol.
-type clientPeerSet struct {
-	peers map[string]*clientPeer
-	// subscribers is a batch of subscribers and peerset will notify
-	// these subscribers when the peerset changes(new client peer is
-	// added or removed)
-	subscribers []clientPeerSubscriber
-	closed      bool
-	lock        sync.RWMutex
-}
-
-// newClientPeerSet creates a new peer set to track the client peers.
-func newClientPeerSet() *clientPeerSet {
-	return &clientPeerSet{peers: make(map[string]*clientPeer)}
-}
-
-// subscribe adds a service to be notified about added or removed
-// peers and also register all active peers into the given service.
-func (ps *clientPeerSet) subscribe(sub clientPeerSubscriber) {
-	ps.lock.Lock()
-	defer ps.lock.Unlock()
-
-	ps.subscribers = append(ps.subscribers, sub)
-	for _, p := range ps.peers {
-		sub.registerPeer(p)
-	}
-}
-
-// unSubscribe removes the specified service from the subscriber pool.
-func (ps *clientPeerSet) unSubscribe(sub clientPeerSubscriber) {
-	ps.lock.Lock()
-	defer ps.lock.Unlock()
-
-	for i, s := range ps.subscribers {
-		if s == sub {
-			ps.subscribers = append(ps.subscribers[:i], ps.subscribers[i+1:]...)
-			return
-		}
-	}
-}
-
-// register adds a new peer into the peer set, or returns an error if the
-// peer is already known.
-func (ps *clientPeerSet) register(peer *clientPeer) error {
-	ps.lock.Lock()
-	defer ps.lock.Unlock()
-
-	if ps.closed {
-		return errClosed
-	}
-	if _, exist := ps.peers[peer.id]; exist {
-		return errAlreadyRegistered
-	}
-	ps.peers[peer.id] = peer
-	for _, sub := range ps.subscribers {
-		sub.registerPeer(peer)
-	}
-	return nil
-}
-
-// unregister removes a remote peer from the peer set, disabling any further
-// actions to/from that particular entity. It also initiates disconnection
-// at the networking layer.
-func (ps *clientPeerSet) unregister(id string) error {
-	ps.lock.Lock()
-	defer ps.lock.Unlock()
-
-	p, ok := ps.peers[id]
-	if !ok {
-		return errNotRegistered
-	}
-	delete(ps.peers, id)
-	for _, sub := range ps.subscribers {
-		sub.unregisterPeer(p)
-	}
-	p.Peer.Disconnect(p2p.DiscRequested)
-	return nil
-}
-
-// ids returns a list of all registered peer IDs
-func (ps *clientPeerSet) ids() []string {
-	ps.lock.RLock()
-	defer ps.lock.RUnlock()
-
-	var ids []string
-	for id := range ps.peers {
-		ids = append(ids, id)
-	}
-	return ids
-}
-
-// peer retrieves the registered peer with the given id.
-func (ps *clientPeerSet) peer(id string) *clientPeer {
-	ps.lock.RLock()
-	defer ps.lock.RUnlock()
-
-	return ps.peers[id]
-}
-
-// len returns if the current number of peers in the set.
-func (ps *clientPeerSet) len() int {
-	ps.lock.RLock()
-	defer ps.lock.RUnlock()
-
-	return len(ps.peers)
-}
-
-// allClientPeers returns all client peers in a list.
-func (ps *clientPeerSet) allPeers() []*clientPeer {
-	ps.lock.RLock()
-	defer ps.lock.RUnlock()
-
-	list := make([]*clientPeer, 0, len(ps.peers))
-	for _, p := range ps.peers {
-		list = append(list, p)
-	}
-	return list
-}
-
-// close disconnects all peers. No new peers can be registered
-// after close has returned.
-func (ps *clientPeerSet) close() {
-	ps.lock.Lock()
-	defer ps.lock.Unlock()
-
-	for _, p := range ps.peers {
-		p.Disconnect(p2p.DiscQuitting)
-	}
-	ps.closed = true
-}
-
 // serverPeerSet represents the set of active server peers currently
 // participating in the Light Ethereum sub-protocol.
 type serverPeerSet struct {
@@ -1298,42 +1170,3 @@ func (ps *serverPeerSet) close() {
 	}
 	ps.closed = true
 }
-
-// serverSet is a special set which contains all connected les servers.
-// Les servers will also be discovered by discovery protocol because they
-// also run the LES protocol. We can't drop them although they are useless
-// for us(server) but for other protocols(e.g. ETH) upon the devp2p they
-// may be useful.
-type serverSet struct {
-	lock   sync.Mutex
-	set    map[string]*clientPeer
-	closed bool
-}
-
-func newServerSet() *serverSet {
-	return &serverSet{set: make(map[string]*clientPeer)}
-}
-
-func (s *serverSet) register(peer *clientPeer) error {
-	s.lock.Lock()
-	defer s.lock.Unlock()
-
-	if s.closed {
-		return errClosed
-	}
-	if _, exist := s.set[peer.id]; exist {
-		return errAlreadyRegistered
-	}
-	s.set[peer.id] = peer
-	return nil
-}
-
-func (s *serverSet) close() {
-	s.lock.Lock()
-	defer s.lock.Unlock()
-
-	for _, p := range s.set {
-		p.Disconnect(p2p.DiscQuitting)
-	}
-	s.closed = true
-}
diff --git a/les/protocol.go b/les/protocol.go
index 4fd19f9be..19a9561ce 100644
--- a/les/protocol.go
+++ b/les/protocol.go
@@ -174,12 +174,6 @@ var errorToString = map[int]string{
 	ErrMissingKey:              "Key missing from list",
 }
 
-type announceBlock struct {
-	Hash   common.Hash // Hash of one particular block being announced
-	Number uint64      // Number of one particular block being announced
-	Td     *big.Int    // Total difficulty of one particular block being announced
-}
-
 // announceData is the network packet for the block announcements.
 type announceData struct {
 	Hash       common.Hash // Hash of one particular block being announced
@@ -199,7 +193,7 @@ func (a *announceData) sanityCheck() error {
 
 // sign adds a signature to the block announcement by the given privKey
 func (a *announceData) sign(privKey *ecdsa.PrivateKey) {
-	rlp, _ := rlp.EncodeToBytes(announceBlock{a.Hash, a.Number, a.Td})
+	rlp, _ := rlp.EncodeToBytes(blockInfo{a.Hash, a.Number, a.Td})
 	sig, _ := crypto.Sign(crypto.Keccak256(rlp), privKey)
 	a.Update = a.Update.add("sign", sig)
 }
@@ -210,7 +204,7 @@ func (a *announceData) checkSignature(id enode.ID, update keyValueMap) error {
 	if err := update.get("sign", &sig); err != nil {
 		return err
 	}
-	rlp, _ := rlp.EncodeToBytes(announceBlock{a.Hash, a.Number, a.Td})
+	rlp, _ := rlp.EncodeToBytes(blockInfo{a.Hash, a.Number, a.Td})
 	recPubkey, err := crypto.SigToPub(crypto.Keccak256(rlp), sig)
 	if err != nil {
 		return err
diff --git a/les/server.go b/les/server.go
index 225a7ad1f..cbedce136 100644
--- a/les/server.go
+++ b/les/server.go
@@ -18,6 +18,7 @@ package les
 
 import (
 	"crypto/ecdsa"
+	"reflect"
 	"time"
 
 	"github.com/ethereum/go-ethereum/common/mclock"
@@ -31,17 +32,32 @@ import (
 	"github.com/ethereum/go-ethereum/p2p/discv5"
 	"github.com/ethereum/go-ethereum/p2p/enode"
 	"github.com/ethereum/go-ethereum/p2p/enr"
+	"github.com/ethereum/go-ethereum/p2p/nodestate"
 	"github.com/ethereum/go-ethereum/params"
 	"github.com/ethereum/go-ethereum/rpc"
 )
 
+var (
+	serverSetup         = &nodestate.Setup{}
+	clientPeerField     = serverSetup.NewField("clientPeer", reflect.TypeOf(&clientPeer{}))
+	clientInfoField     = serverSetup.NewField("clientInfo", reflect.TypeOf(&clientInfo{}))
+	connAddressField    = serverSetup.NewField("connAddr", reflect.TypeOf(""))
+	balanceTrackerSetup = lps.NewBalanceTrackerSetup(serverSetup)
+	priorityPoolSetup   = lps.NewPriorityPoolSetup(serverSetup)
+)
+
+func init() {
+	balanceTrackerSetup.Connect(connAddressField, priorityPoolSetup.CapacityField)
+	priorityPoolSetup.Connect(balanceTrackerSetup.BalanceField, balanceTrackerSetup.UpdateFlag) // NodeBalance implements nodePriority
+}
+
 type LesServer struct {
 	lesCommons
 
+	ns          *nodestate.NodeStateMachine
 	archiveMode bool // Flag whether the ethereum node runs in archive mode.
-	peers       *clientPeerSet
-	serverset   *serverSet
 	handler     *serverHandler
+	broadcaster *broadcaster
 	lesTopics   []discv5.Topic
 	privateKey  *ecdsa.PrivateKey
 
@@ -60,6 +76,7 @@ type LesServer struct {
 }
 
 func NewLesServer(node *node.Node, e *eth.Ethereum, config *eth.Config) (*LesServer, error) {
+	ns := nodestate.NewNodeStateMachine(nil, nil, mclock.System{}, serverSetup)
 	// Collect les protocol version information supported by local node.
 	lesTopics := make([]discv5.Topic, len(AdvertiseProtocolVersions))
 	for i, pv := range AdvertiseProtocolVersions {
@@ -83,9 +100,9 @@ func NewLesServer(node *node.Node, e *eth.Ethereum, config *eth.Config) (*LesSer
 			bloomTrieIndexer: light.NewBloomTrieIndexer(e.ChainDb(), nil, params.BloomBitsBlocks, params.BloomTrieFrequency, true),
 			closeCh:          make(chan struct{}),
 		},
+		ns:           ns,
 		archiveMode:  e.ArchiveMode(),
-		peers:        newClientPeerSet(),
-		serverset:    newServerSet(),
+		broadcaster:  newBroadcaster(ns),
 		lesTopics:    lesTopics,
 		fcManager:    flowcontrol.NewClientManager(nil, &mclock.System{}),
 		servingQueue: newServingQueue(int64(time.Millisecond*10), float64(config.LightServ)/100),
@@ -116,7 +133,7 @@ func NewLesServer(node *node.Node, e *eth.Ethereum, config *eth.Config) (*LesSer
 		srv.maxCapacity = totalRecharge
 	}
 	srv.fcManager.SetCapacityLimits(srv.minCapacity, srv.maxCapacity, srv.minCapacity*2)
-	srv.clientPool = newClientPool(srv.chainDb, srv.minCapacity, defaultConnectedBias, mclock.System{}, func(id enode.ID) { go srv.peers.unregister(id.String()) })
+	srv.clientPool = newClientPool(ns, srv.chainDb, srv.minCapacity, defaultConnectedBias, mclock.System{}, srv.dropClient)
 	srv.clientPool.setDefaultFactors(lps.PriceFactors{TimeFactor: 0, CapacityFactor: 1, RequestFactor: 1}, lps.PriceFactors{TimeFactor: 0, CapacityFactor: 1, RequestFactor: 1})
 
 	checkpoint := srv.latestLocalCheckpoint()
@@ -130,6 +147,13 @@ func NewLesServer(node *node.Node, e *eth.Ethereum, config *eth.Config) (*LesSer
 	node.RegisterAPIs(srv.APIs())
 	node.RegisterLifecycle(srv)
 
+	// disconnect all peers at nsm shutdown
+	ns.SubscribeField(clientPeerField, func(node *enode.Node, state nodestate.Flags, oldValue, newValue interface{}) {
+		if state.Equals(serverSetup.OfflineFlag()) && oldValue != nil {
+			oldValue.(*clientPeer).Peer.Disconnect(p2p.DiscRequested)
+		}
+	})
+	ns.Start()
 	return srv, nil
 }
 
@@ -158,7 +182,7 @@ func (s *LesServer) APIs() []rpc.API {
 
 func (s *LesServer) Protocols() []p2p.Protocol {
 	ps := s.makeProtocols(ServerProtocolVersions, s.handler.runPeer, func(id enode.ID) interface{} {
-		if p := s.peers.peer(id.String()); p != nil {
+		if p := s.getClient(id); p != nil {
 			return p.Info()
 		}
 		return nil
@@ -173,6 +197,7 @@ func (s *LesServer) Protocols() []p2p.Protocol {
 // Start starts the LES server
 func (s *LesServer) Start() error {
 	s.privateKey = s.p2pSrv.PrivateKey
+	s.broadcaster.setSignerKey(s.privateKey)
 	s.handler.start()
 
 	s.wg.Add(1)
@@ -198,19 +223,11 @@ func (s *LesServer) Start() error {
 func (s *LesServer) Stop() error {
 	close(s.closeCh)
 
-	// Disconnect existing connections with other LES servers.
-	s.serverset.close()
-
-	// Disconnect existing sessions.
-	// This also closes the gate for any new registrations on the peer set.
-	// sessions which are already established but not added to pm.peers yet
-	// will exit when they try to register.
-	s.peers.close()
-
+	s.clientPool.stop()
+	s.ns.Stop()
 	s.fcManager.Stop()
 	s.costTracker.stop()
 	s.handler.stop()
-	s.clientPool.stop() // client pool should be closed after handler.
 	s.servingQueue.stop()
 
 	// Note, bloom trie indexer is closed by parent bloombits indexer.
@@ -279,3 +296,18 @@ func (s *LesServer) capacityManagement() {
 		}
 	}
 }
+
+func (s *LesServer) getClient(id enode.ID) *clientPeer {
+	if node := s.ns.GetNode(id); node != nil {
+		if p, ok := s.ns.GetField(node, clientPeerField).(*clientPeer); ok {
+			return p
+		}
+	}
+	return nil
+}
+
+func (s *LesServer) dropClient(id enode.ID) {
+	if p := s.getClient(id); p != nil {
+		p.Peer.Disconnect(p2p.DiscRequested)
+	}
+}
diff --git a/les/server_handler.go b/les/server_handler.go
index 583df9600..c657d37f1 100644
--- a/les/server_handler.go
+++ b/les/server_handler.go
@@ -17,6 +17,7 @@
 package les
 
 import (
+	"crypto/ecdsa"
 	"encoding/binary"
 	"encoding/json"
 	"errors"
@@ -36,6 +37,8 @@ import (
 	"github.com/ethereum/go-ethereum/log"
 	"github.com/ethereum/go-ethereum/metrics"
 	"github.com/ethereum/go-ethereum/p2p"
+	"github.com/ethereum/go-ethereum/p2p/enode"
+	"github.com/ethereum/go-ethereum/p2p/nodestate"
 	"github.com/ethereum/go-ethereum/rlp"
 	"github.com/ethereum/go-ethereum/trie"
 )
@@ -91,7 +94,7 @@ func newServerHandler(server *LesServer, blockchain *core.BlockChain, chainDb et
 // start starts the server handler.
 func (h *serverHandler) start() {
 	h.wg.Add(1)
-	go h.broadcastHeaders()
+	go h.broadcastLoop()
 }
 
 // stop stops the server handler.
@@ -123,47 +126,58 @@ func (h *serverHandler) handle(p *clientPeer) error {
 		p.Log().Debug("Light Ethereum handshake failed", "err", err)
 		return err
 	}
-	if p.server {
-		if err := h.server.serverset.register(p); err != nil {
-			return err
-		}
-		// connected to another server, no messages expected, just wait for disconnection
-		_, err := p.rw.ReadMsg()
-		return err
-	}
 	// Reject light clients if server is not synced.
 	if !h.synced() {
 		p.Log().Debug("Light server not synced, rejecting peer")
 		return p2p.DiscRequested
 	}
-	defer p.fcClient.Disconnect()
+	var registered bool
+	if err := h.server.ns.Operation(func() {
+		if h.server.ns.GetField(p.Node(), clientPeerField) != nil {
+			registered = true
+		} else {
+			h.server.ns.SetFieldSub(p.Node(), clientPeerField, p)
+		}
+	}); err != nil {
+		return err
+	}
+	if registered {
+		return errAlreadyRegistered
+	}
+
+	defer func() {
+		h.server.ns.SetField(p.Node(), clientPeerField, nil)
+		if p.fcClient != nil { // is nil when connecting another server
+			p.fcClient.Disconnect()
+		}
+	}()
+	if p.server {
+		// connected to another server, no messages expected, just wait for disconnection
+		_, err := p.rw.ReadMsg()
+		return err
+	}
 
 	// Disconnect the inbound peer if it's rejected by clientPool
 	if cap, err := h.server.clientPool.connect(p); cap != p.fcParams.MinRecharge || err != nil {
 		p.Log().Debug("Light Ethereum peer rejected", "err", errFullClientPool)
 		return errFullClientPool
 	}
-	p.balance, _ = h.server.clientPool.ns.GetField(p.Node(), h.server.clientPool.BalanceField).(*lps.NodeBalance)
+	p.balance, _ = h.server.ns.GetField(p.Node(), h.server.clientPool.BalanceField).(*lps.NodeBalance)
 	if p.balance == nil {
 		return p2p.DiscRequested
 	}
-	// Register the peer locally
-	if err := h.server.peers.register(p); err != nil {
-		h.server.clientPool.disconnect(p)
-		p.Log().Error("Light Ethereum peer registration failed", "err", err)
-		return err
-	}
-	clientConnectionGauge.Update(int64(h.server.peers.len()))
+	activeCount, _ := h.server.clientPool.pp.Active()
+	clientConnectionGauge.Update(int64(activeCount))
 
 	var wg sync.WaitGroup // Wait group used to track all in-flight task routines.
 
 	connectedAt := mclock.Now()
 	defer func() {
 		wg.Wait() // Ensure all background task routines have exited.
-		h.server.peers.unregister(p.id)
 		h.server.clientPool.disconnect(p)
 		p.balance = nil
-		clientConnectionGauge.Update(int64(h.server.peers.len()))
+		activeCount, _ := h.server.clientPool.pp.Active()
+		clientConnectionGauge.Update(int64(activeCount))
 		connectionTimer.Update(time.Duration(mclock.Now() - connectedAt))
 	}()
 	// Mark the peer starts to be served.
@@ -911,11 +925,11 @@ func (h *serverHandler) txStatus(hash common.Hash) light.TxStatus {
 	return stat
 }
 
-// broadcastHeaders broadcasts new block information to all connected light
+// broadcastLoop broadcasts new block information to all connected light
 // clients. According to the agreement between client and server, server should
 // only broadcast new announcement if the total difficulty is higher than the
 // last one. Besides server will add the signature if client requires.
-func (h *serverHandler) broadcastHeaders() {
+func (h *serverHandler) broadcastLoop() {
 	defer h.wg.Done()
 
 	headCh := make(chan core.ChainHeadEvent, 10)
@@ -929,10 +943,6 @@ func (h *serverHandler) broadcastHeaders() {
 	for {
 		select {
 		case ev := <-headCh:
-			peers := h.server.peers.allPeers()
-			if len(peers) == 0 {
-				continue
-			}
 			header := ev.Block.Header()
 			hash, number := header.Hash(), header.Number.Uint64()
 			td := h.blockchain.GetTd(hash, number)
@@ -944,33 +954,73 @@ func (h *serverHandler) broadcastHeaders() {
 				reorg = lastHead.Number.Uint64() - rawdb.FindCommonAncestor(h.chainDb, header, lastHead).Number.Uint64()
 			}
 			lastHead, lastTd = header, td
-
 			log.Debug("Announcing block to peers", "number", number, "hash", hash, "td", td, "reorg", reorg)
-			var (
-				signed         bool
-				signedAnnounce announceData
-			)
-			announce := announceData{Hash: hash, Number: number, Td: td, ReorgDepth: reorg}
-			for _, p := range peers {
-				p := p
-				switch p.announceType {
-				case announceTypeSimple:
-					if !p.queueSend(func() { p.sendAnnounce(announce) }) {
-						log.Debug("Drop announcement because queue is full", "number", number, "hash", hash)
-					}
-				case announceTypeSigned:
-					if !signed {
-						signedAnnounce = announce
-						signedAnnounce.sign(h.server.privateKey)
-						signed = true
-					}
-					if !p.queueSend(func() { p.sendAnnounce(signedAnnounce) }) {
-						log.Debug("Drop announcement because queue is full", "number", number, "hash", hash)
-					}
-				}
-			}
+			h.server.broadcaster.broadcast(announceData{Hash: hash, Number: number, Td: td, ReorgDepth: reorg})
 		case <-h.closeCh:
 			return
 		}
 	}
 }
+
+// broadcaster sends new header announcements to active client peers
+type broadcaster struct {
+	ns                           *nodestate.NodeStateMachine
+	privateKey                   *ecdsa.PrivateKey
+	lastAnnounce, signedAnnounce announceData
+}
+
+// newBroadcaster creates a new broadcaster
+func newBroadcaster(ns *nodestate.NodeStateMachine) *broadcaster {
+	b := &broadcaster{ns: ns}
+	ns.SubscribeState(priorityPoolSetup.ActiveFlag, func(node *enode.Node, oldState, newState nodestate.Flags) {
+		if newState.Equals(priorityPoolSetup.ActiveFlag) {
+			// send last announcement to activated peers
+			b.sendTo(node)
+		}
+	})
+	return b
+}
+
+// setSignerKey sets the signer key for signed announcements. Should be called before
+// starting the protocol handler.
+func (b *broadcaster) setSignerKey(privateKey *ecdsa.PrivateKey) {
+	b.privateKey = privateKey
+}
+
+// broadcast sends the given announcements to all active peers
+func (b *broadcaster) broadcast(announce announceData) {
+	b.ns.Operation(func() {
+		// iterate in an Operation to ensure that the active set does not change while iterating
+		b.lastAnnounce = announce
+		b.ns.ForEach(priorityPoolSetup.ActiveFlag, nodestate.Flags{}, func(node *enode.Node, state nodestate.Flags) {
+			b.sendTo(node)
+		})
+	})
+}
+
+// sendTo sends the most recent announcement to the given node unless the same or higher Td
+// announcement has already been sent.
+func (b *broadcaster) sendTo(node *enode.Node) {
+	if b.lastAnnounce.Td == nil {
+		return
+	}
+	if p, _ := b.ns.GetField(node, clientPeerField).(*clientPeer); p != nil {
+		if p.headInfo.Td == nil || b.lastAnnounce.Td.Cmp(p.headInfo.Td) > 0 {
+			switch p.announceType {
+			case announceTypeSimple:
+				if !p.queueSend(func() { p.sendAnnounce(b.lastAnnounce) }) {
+					log.Debug("Drop announcement because queue is full", "number", b.lastAnnounce.Number, "hash", b.lastAnnounce.Hash)
+				}
+			case announceTypeSigned:
+				if b.signedAnnounce.Hash != b.lastAnnounce.Hash {
+					b.signedAnnounce = b.lastAnnounce
+					b.signedAnnounce.sign(b.privateKey)
+				}
+				if !p.queueSend(func() { p.sendAnnounce(b.signedAnnounce) }) {
+					log.Debug("Drop announcement because queue is full", "number", b.lastAnnounce.Number, "hash", b.lastAnnounce.Hash)
+				}
+			}
+			p.headInfo = blockInfo{b.lastAnnounce.Hash, b.lastAnnounce.Number, b.lastAnnounce.Td}
+		}
+	}
+}
diff --git a/les/test_helper.go b/les/test_helper.go
index 9f9b28721..5a8d64f76 100644
--- a/les/test_helper.go
+++ b/les/test_helper.go
@@ -46,6 +46,7 @@ import (
 	"github.com/ethereum/go-ethereum/light"
 	"github.com/ethereum/go-ethereum/p2p"
 	"github.com/ethereum/go-ethereum/p2p/enode"
+	"github.com/ethereum/go-ethereum/p2p/nodestate"
 	"github.com/ethereum/go-ethereum/params"
 )
 
@@ -227,7 +228,7 @@ func newTestClientHandler(backend *backends.SimulatedBackend, odr *LesOdr, index
 	return client.handler
 }
 
-func newTestServerHandler(blocks int, indexers []*core.ChainIndexer, db ethdb.Database, peers *clientPeerSet, clock mclock.Clock) (*serverHandler, *backends.SimulatedBackend) {
+func newTestServerHandler(blocks int, indexers []*core.ChainIndexer, db ethdb.Database, clock mclock.Clock) (*serverHandler, *backends.SimulatedBackend) {
 	var (
 		gspec = core.Genesis{
 			Config:   params.AllEthashProtocolChanges,
@@ -263,6 +264,7 @@ func newTestServerHandler(blocks int, indexers []*core.ChainIndexer, db ethdb.Da
 		}
 		oracle = checkpointoracle.New(checkpointConfig, getLocal)
 	}
+	ns := nodestate.NewNodeStateMachine(nil, nil, mclock.System{}, serverSetup)
 	server := &LesServer{
 		lesCommons: lesCommons{
 			genesis:     genesis.Hash(),
@@ -274,7 +276,8 @@ func newTestServerHandler(blocks int, indexers []*core.ChainIndexer, db ethdb.Da
 			oracle:      oracle,
 			closeCh:     make(chan struct{}),
 		},
-		peers:        peers,
+		ns:           ns,
+		broadcaster:  newBroadcaster(ns),
 		servingQueue: newServingQueue(int64(time.Millisecond*10), 1),
 		defParams: flowcontrol.ServerParams{
 			BufLimit:    testBufLimit,
@@ -284,13 +287,14 @@ func newTestServerHandler(blocks int, indexers []*core.ChainIndexer, db ethdb.Da
 	}
 	server.costTracker, server.minCapacity = newCostTracker(db, server.config)
 	server.costTracker.testCostList = testCostList(0) // Disable flow control mechanism.
-	server.clientPool = newClientPool(db, testBufRecharge, defaultConnectedBias, clock, func(id enode.ID) {})
+	server.clientPool = newClientPool(ns, db, testBufRecharge, defaultConnectedBias, clock, func(id enode.ID) {})
 	server.clientPool.setLimits(10000, 10000) // Assign enough capacity for clientpool
 	server.handler = newServerHandler(server, simulation.Blockchain(), db, txpool, func() bool { return true })
 	if server.oracle != nil {
 		server.oracle.Start(simulation)
 	}
 	server.servingQueue.setThreads(4)
+	ns.Start()
 	server.handler.start()
 	return server.handler, simulation
 }
@@ -463,7 +467,7 @@ func newServerEnv(t *testing.T, blocks int, protocol int, callback indexerCallba
 	if simClock {
 		clock = &mclock.Simulated{}
 	}
-	handler, b := newTestServerHandler(blocks, indexers, db, newClientPeerSet(), clock)
+	handler, b := newTestServerHandler(blocks, indexers, db, clock)
 
 	var peer *testPeer
 	if newPeer {
@@ -502,7 +506,7 @@ func newServerEnv(t *testing.T, blocks int, protocol int, callback indexerCallba
 
 func newClientServerEnv(t *testing.T, blocks int, protocol int, callback indexerCallback, ulcServers []string, ulcFraction int, simClock bool, connect bool, disablePruning bool) (*testServer, *testClient, func()) {
 	sdb, cdb := rawdb.NewMemoryDatabase(), rawdb.NewMemoryDatabase()
-	speers, cpeers := newServerPeerSet(), newClientPeerSet()
+	speers := newServerPeerSet()
 
 	var clock mclock.Clock = &mclock.System{}
 	if simClock {
@@ -519,7 +523,7 @@ func newClientServerEnv(t *testing.T, blocks int, protocol int, callback indexer
 	ccIndexer, cbIndexer, cbtIndexer := cIndexers[0], cIndexers[1], cIndexers[2]
 	odr.SetIndexers(ccIndexer, cbIndexer, cbtIndexer)
 
-	server, b := newTestServerHandler(blocks, sindexers, sdb, cpeers, clock)
+	server, b := newTestServerHandler(blocks, sindexers, sdb, clock)
 	client := newTestClientHandler(b, odr, cIndexers, cdb, speers, ulcServers, ulcFraction)
 
 	scIndexer.Start(server.blockchain)
-- 
GitLab