From dc109cce26da8a93f74a998f9dd7fc2ac0ab98d2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Felf=C3=B6ldi=20Zsolt?= <zsfelfoldi@gmail.com>
Date: Thu, 25 Feb 2021 21:08:34 +0100
Subject: [PATCH] les: move server pool to les/vflux/client (#22377)

* les: move serverPool to les/vflux/client

* les: add metrics

* les: moved ValueTracker inside ServerPool

* les: protect against node registration before server pool is started

* les/vflux/client: fixed tests

* les: make peer registration safe
---
 internal/web3ext/web3ext.go               |  14 +-
 les/client.go                             |  38 +----
 les/client_handler.go                     |  14 ++
 les/peer.go                               |   9 +-
 les/vflux/client/queueiterator_test.go    |  11 --
 les/{ => vflux/client}/serverpool.go      | 173 +++++++++++++---------
 les/{ => vflux/client}/serverpool_test.go |  64 ++++----
 les/vflux/client/valuetracker.go          |  77 +++++-----
 les/vflux/client/valuetracker_test.go     |   4 +-
 9 files changed, 206 insertions(+), 198 deletions(-)
 rename les/{ => vflux/client}/serverpool.go (76%)
 rename les/{ => vflux/client}/serverpool_test.go (86%)

diff --git a/internal/web3ext/web3ext.go b/internal/web3ext/web3ext.go
index 77954bbbf..6fcf4b838 100644
--- a/internal/web3ext/web3ext.go
+++ b/internal/web3ext/web3ext.go
@@ -33,7 +33,7 @@ var Modules = map[string]string{
 	"swarmfs":    SwarmfsJs,
 	"txpool":     TxpoolJs,
 	"les":        LESJs,
-	"lespay":     LESPayJs,
+	"vflux":      VfluxJs,
 }
 
 const ChequebookJs = `
@@ -877,24 +877,24 @@ web3._extend({
 });
 `
 
-const LESPayJs = `
+const VfluxJs = `
 web3._extend({
-	property: 'lespay',
+	property: 'vflux',
 	methods:
 	[
 		new web3._extend.Method({
 			name: 'distribution',
-			call: 'lespay_distribution',
+			call: 'vflux_distribution',
 			params: 2
 		}),
 		new web3._extend.Method({
 			name: 'timeout',
-			call: 'lespay_timeout',
+			call: 'vflux_timeout',
 			params: 2
 		}),
 		new web3._extend.Method({
 			name: 'value',
-			call: 'lespay_value',
+			call: 'vflux_value',
 			params: 2
 		}),
 	],
@@ -902,7 +902,7 @@ web3._extend({
 	[
 		new web3._extend.Property({
 			name: 'requestStats',
-			getter: 'lespay_requestStats'
+			getter: 'vflux_requestStats'
 		}),
 	]
 });
diff --git a/les/client.go b/les/client.go
index 053118df5..e20519fd9 100644
--- a/les/client.go
+++ b/les/client.go
@@ -57,8 +57,7 @@ type LightEthereum struct {
 	handler        *clientHandler
 	txPool         *light.TxPool
 	blockchain     *light.LightChain
-	serverPool     *serverPool
-	valueTracker   *vfc.ValueTracker
+	serverPool     *vfc.ServerPool
 	dialCandidates enode.Iterator
 	pruner         *pruner
 
@@ -109,17 +108,14 @@ func New(stack *node.Node, config *ethconfig.Config) (*LightEthereum, error) {
 		engine:         ethconfig.CreateConsensusEngine(stack, chainConfig, &config.Ethash, nil, false, chainDb),
 		bloomRequests:  make(chan chan *bloombits.Retrieval),
 		bloomIndexer:   core.NewBloomIndexer(chainDb, params.BloomBitsBlocksClient, params.HelperTrieConfirmations),
-		valueTracker:   vfc.NewValueTracker(lesDb, &mclock.System{}, requestList, time.Minute, 1/float64(time.Hour), 1/float64(time.Hour*100), 1/float64(time.Hour*1000)),
 		p2pServer:      stack.Server(),
 		p2pConfig:      &stack.Config().P2P,
 	}
-	peers.subscribe((*vtSubscription)(leth.valueTracker))
 
-	leth.serverPool = newServerPool(lesDb, []byte("serverpool:"), leth.valueTracker, time.Second, nil, &mclock.System{}, config.UltraLightServers)
-	peers.subscribe(leth.serverPool)
-	leth.dialCandidates = leth.serverPool.dialIterator
+	leth.serverPool, leth.dialCandidates = vfc.NewServerPool(lesDb, []byte("serverpool:"), time.Second, nil, &mclock.System{}, config.UltraLightServers, requestList)
+	leth.serverPool.AddMetrics(suggestedTimeoutGauge, totalValueGauge, serverSelectableGauge, serverConnectedGauge, sessionValueMeter, serverDialedMeter)
 
-	leth.retriever = newRetrieveManager(peers, leth.reqDist, leth.serverPool.getTimeout)
+	leth.retriever = newRetrieveManager(peers, leth.reqDist, leth.serverPool.GetTimeout)
 	leth.relay = newLesTxRelay(peers, leth.retriever)
 
 	leth.odr = NewLesOdr(chainDb, light.DefaultClientIndexerConfig, leth.peers, leth.retriever)
@@ -193,23 +189,6 @@ func New(stack *node.Node, config *ethconfig.Config) (*LightEthereum, error) {
 	return leth, nil
 }
 
-// vtSubscription implements serverPeerSubscriber
-type vtSubscription vfc.ValueTracker
-
-// registerPeer implements serverPeerSubscriber
-func (v *vtSubscription) registerPeer(p *serverPeer) {
-	vt := (*vfc.ValueTracker)(v)
-	p.setValueTracker(vt, vt.Register(p.ID()))
-	p.updateVtParams()
-}
-
-// unregisterPeer implements serverPeerSubscriber
-func (v *vtSubscription) unregisterPeer(p *serverPeer) {
-	vt := (*vfc.ValueTracker)(v)
-	vt.Unregister(p.ID())
-	p.setValueTracker(nil, nil)
-}
-
 type LightDummyAPI struct{}
 
 // Etherbase is the address that mining rewards will be send to
@@ -266,7 +245,7 @@ func (s *LightEthereum) APIs() []rpc.API {
 		}, {
 			Namespace: "vflux",
 			Version:   "1.0",
-			Service:   vfc.NewPrivateClientAPI(s.valueTracker),
+			Service:   s.serverPool.API(),
 			Public:    false,
 		},
 	}...)
@@ -302,8 +281,8 @@ func (s *LightEthereum) Start() error {
 	if err != nil {
 		return err
 	}
-	s.serverPool.addSource(discovery)
-	s.serverPool.start()
+	s.serverPool.AddSource(discovery)
+	s.serverPool.Start()
 	// Start bloom request workers.
 	s.wg.Add(bloomServiceThreads)
 	s.startBloomHandlers(params.BloomBitsBlocksClient)
@@ -316,8 +295,7 @@ func (s *LightEthereum) Start() error {
 // Ethereum protocol.
 func (s *LightEthereum) Stop() error {
 	close(s.closeCh)
-	s.serverPool.stop()
-	s.valueTracker.Stop()
+	s.serverPool.Stop()
 	s.peers.close()
 	s.reqDist.close()
 	s.odr.Stop()
diff --git a/les/client_handler.go b/les/client_handler.go
index 6cd786cda..f8e9edc9f 100644
--- a/les/client_handler.go
+++ b/les/client_handler.go
@@ -114,11 +114,25 @@ func (h *clientHandler) handle(p *serverPeer) error {
 		p.Log().Debug("Light Ethereum handshake failed", "err", err)
 		return err
 	}
+	// Register peer with the server pool
+	if h.backend.serverPool != nil {
+		if nvt, err := h.backend.serverPool.RegisterNode(p.Node()); err == nil {
+			p.setValueTracker(nvt)
+			p.updateVtParams()
+			defer func() {
+				p.setValueTracker(nil)
+				h.backend.serverPool.UnregisterNode(p.Node())
+			}()
+		} else {
+			return err
+		}
+	}
 	// Register the peer locally
 	if err := h.backend.peers.register(p); err != nil {
 		p.Log().Error("Light Ethereum peer registration failed", "err", err)
 		return err
 	}
+
 	serverConnectionGauge.Update(int64(h.backend.peers.len()))
 
 	connectedAt := mclock.Now()
diff --git a/les/peer.go b/les/peer.go
index 0361167ee..78019b1d8 100644
--- a/les/peer.go
+++ b/les/peer.go
@@ -349,7 +349,6 @@ type serverPeer struct {
 
 	fcServer         *flowcontrol.ServerNode // Client side mirror token bucket.
 	vtLock           sync.Mutex
-	valueTracker     *vfc.ValueTracker
 	nodeValueTracker *vfc.NodeValueTracker
 	sentReqs         map[uint64]sentReqEntry
 
@@ -676,9 +675,8 @@ func (p *serverPeer) Handshake(genesis common.Hash, forkid forkid.ID, forkFilter
 
 // setValueTracker sets the value tracker references for connected servers. Note that the
 // references should be removed upon disconnection by setValueTracker(nil, nil).
-func (p *serverPeer) setValueTracker(vt *vfc.ValueTracker, nvt *vfc.NodeValueTracker) {
+func (p *serverPeer) setValueTracker(nvt *vfc.NodeValueTracker) {
 	p.vtLock.Lock()
-	p.valueTracker = vt
 	p.nodeValueTracker = nvt
 	if nvt != nil {
 		p.sentReqs = make(map[uint64]sentReqEntry)
@@ -705,7 +703,7 @@ func (p *serverPeer) updateVtParams() {
 			}
 		}
 	}
-	p.valueTracker.UpdateCosts(p.nodeValueTracker, reqCosts)
+	p.nodeValueTracker.UpdateCosts(reqCosts)
 }
 
 // sentReqEntry remembers sent requests and their sending times
@@ -732,7 +730,6 @@ func (p *serverPeer) answeredRequest(id uint64) {
 	}
 	e, ok := p.sentReqs[id]
 	delete(p.sentReqs, id)
-	vt := p.valueTracker
 	nvt := p.nodeValueTracker
 	p.vtLock.Unlock()
 	if !ok {
@@ -752,7 +749,7 @@ func (p *serverPeer) answeredRequest(id uint64) {
 		vtReqs[1] = vfc.ServedRequest{ReqType: uint32(m.rest), Amount: e.amount - 1}
 	}
 	dt := time.Duration(mclock.Now() - e.at)
-	vt.Served(nvt, vtReqs[:reqCount], dt)
+	nvt.Served(vtReqs[:reqCount], dt)
 }
 
 // clientPeer represents each node to which the les server is connected.
diff --git a/les/vflux/client/queueiterator_test.go b/les/vflux/client/queueiterator_test.go
index a74301c7d..400d978e1 100644
--- a/les/vflux/client/queueiterator_test.go
+++ b/les/vflux/client/queueiterator_test.go
@@ -26,17 +26,6 @@ import (
 	"github.com/ethereum/go-ethereum/p2p/nodestate"
 )
 
-func testNodeID(i int) enode.ID {
-	return enode.ID{42, byte(i % 256), byte(i / 256)}
-}
-
-func testNodeIndex(id enode.ID) int {
-	if id[0] != 42 {
-		return -1
-	}
-	return int(id[1]) + int(id[2])*256
-}
-
 func testNode(i int) *enode.Node {
 	return enode.SignNull(new(enr.Record), testNodeID(i))
 }
diff --git a/les/serverpool.go b/les/vflux/client/serverpool.go
similarity index 76%
rename from les/serverpool.go
rename to les/vflux/client/serverpool.go
index 977579988..95f724609 100644
--- a/les/serverpool.go
+++ b/les/vflux/client/serverpool.go
@@ -14,7 +14,7 @@
 // You should have received a copy of the GNU Lesser General Public License
 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
 
-package les
+package client
 
 import (
 	"errors"
@@ -27,8 +27,8 @@ import (
 	"github.com/ethereum/go-ethereum/common/mclock"
 	"github.com/ethereum/go-ethereum/ethdb"
 	"github.com/ethereum/go-ethereum/les/utils"
-	vfc "github.com/ethereum/go-ethereum/les/vflux/client"
 	"github.com/ethereum/go-ethereum/log"
+	"github.com/ethereum/go-ethereum/metrics"
 	"github.com/ethereum/go-ethereum/p2p/enode"
 	"github.com/ethereum/go-ethereum/p2p/enr"
 	"github.com/ethereum/go-ethereum/p2p/nodestate"
@@ -50,31 +50,34 @@ const (
 	maxQueryFails       = 100                    // number of consecutive UDP query failures before we print a warning
 )
 
-// serverPool provides a node iterator for dial candidates. The output is a mix of newly discovered
+// ServerPool provides a node iterator for dial candidates. The output is a mix of newly discovered
 // nodes, a weighted random selection of known (previously valuable) nodes and trusted/paid nodes.
-type serverPool struct {
+type ServerPool struct {
 	clock    mclock.Clock
 	unixTime func() int64
 	db       ethdb.KeyValueStore
 
-	ns           *nodestate.NodeStateMachine
-	vt           *vfc.ValueTracker
-	mixer        *enode.FairMix
-	mixSources   []enode.Iterator
-	dialIterator enode.Iterator
-	validSchemes enr.IdentityScheme
-	trustedURLs  []string
-	fillSet      *vfc.FillSet
-	queryFails   uint32
+	ns                  *nodestate.NodeStateMachine
+	vt                  *ValueTracker
+	mixer               *enode.FairMix
+	mixSources          []enode.Iterator
+	dialIterator        enode.Iterator
+	validSchemes        enr.IdentityScheme
+	trustedURLs         []string
+	fillSet             *FillSet
+	started, queryFails uint32
 
 	timeoutLock      sync.RWMutex
 	timeout          time.Duration
-	timeWeights      vfc.ResponseTimeWeights
+	timeWeights      ResponseTimeWeights
 	timeoutRefreshed mclock.AbsTime
+
+	suggestedTimeoutGauge, totalValueGauge metrics.Gauge
+	sessionValueMeter                      metrics.Meter
 }
 
 // nodeHistory keeps track of dial costs which determine node weight together with the
-// service value calculated by vfc.ValueTracker.
+// service value calculated by ValueTracker.
 type nodeHistory struct {
 	dialCost                       utils.ExpiredValue
 	redialWaitStart, redialWaitEnd int64 // unix time (seconds)
@@ -91,18 +94,18 @@ type nodeHistoryEnc struct {
 type queryFunc func(*enode.Node) int
 
 var (
-	serverPoolSetup    = &nodestate.Setup{Version: 1}
-	sfHasValue         = serverPoolSetup.NewPersistentFlag("hasValue")
-	sfQueried          = serverPoolSetup.NewFlag("queried")
-	sfCanDial          = serverPoolSetup.NewFlag("canDial")
-	sfDialing          = serverPoolSetup.NewFlag("dialed")
-	sfWaitDialTimeout  = serverPoolSetup.NewFlag("dialTimeout")
-	sfConnected        = serverPoolSetup.NewFlag("connected")
-	sfRedialWait       = serverPoolSetup.NewFlag("redialWait")
-	sfAlwaysConnect    = serverPoolSetup.NewFlag("alwaysConnect")
+	clientSetup        = &nodestate.Setup{Version: 1}
+	sfHasValue         = clientSetup.NewPersistentFlag("hasValue")
+	sfQueried          = clientSetup.NewFlag("queried")
+	sfCanDial          = clientSetup.NewFlag("canDial")
+	sfDialing          = clientSetup.NewFlag("dialed")
+	sfWaitDialTimeout  = clientSetup.NewFlag("dialTimeout")
+	sfConnected        = clientSetup.NewFlag("connected")
+	sfRedialWait       = clientSetup.NewFlag("redialWait")
+	sfAlwaysConnect    = clientSetup.NewFlag("alwaysConnect")
 	sfDisableSelection = nodestate.MergeFlags(sfQueried, sfCanDial, sfDialing, sfConnected, sfRedialWait)
 
-	sfiNodeHistory = serverPoolSetup.NewPersistentField("nodeHistory", reflect.TypeOf(nodeHistory{}),
+	sfiNodeHistory = clientSetup.NewPersistentField("nodeHistory", reflect.TypeOf(nodeHistory{}),
 		func(field interface{}) ([]byte, error) {
 			if n, ok := field.(nodeHistory); ok {
 				ne := nodeHistoryEnc{
@@ -126,25 +129,25 @@ var (
 			return n, err
 		},
 	)
-	sfiNodeWeight     = serverPoolSetup.NewField("nodeWeight", reflect.TypeOf(uint64(0)))
-	sfiConnectedStats = serverPoolSetup.NewField("connectedStats", reflect.TypeOf(vfc.ResponseTimeStats{}))
+	sfiNodeWeight     = clientSetup.NewField("nodeWeight", reflect.TypeOf(uint64(0)))
+	sfiConnectedStats = clientSetup.NewField("connectedStats", reflect.TypeOf(ResponseTimeStats{}))
 )
 
 // newServerPool creates a new server pool
-func newServerPool(db ethdb.KeyValueStore, dbKey []byte, vt *vfc.ValueTracker, mixTimeout time.Duration, query queryFunc, clock mclock.Clock, trustedURLs []string) *serverPool {
-	s := &serverPool{
+func NewServerPool(db ethdb.KeyValueStore, dbKey []byte, mixTimeout time.Duration, query queryFunc, clock mclock.Clock, trustedURLs []string, requestList []RequestInfo) (*ServerPool, enode.Iterator) {
+	s := &ServerPool{
 		db:           db,
 		clock:        clock,
 		unixTime:     func() int64 { return time.Now().Unix() },
 		validSchemes: enode.ValidSchemes,
 		trustedURLs:  trustedURLs,
-		vt:           vt,
-		ns:           nodestate.NewNodeStateMachine(db, []byte(string(dbKey)+"ns:"), clock, serverPoolSetup),
+		vt:           NewValueTracker(db, &mclock.System{}, requestList, time.Minute, 1/float64(time.Hour), 1/float64(time.Hour*100), 1/float64(time.Hour*1000)),
+		ns:           nodestate.NewNodeStateMachine(db, []byte(string(dbKey)+"ns:"), clock, clientSetup),
 	}
 	s.recalTimeout()
 	s.mixer = enode.NewFairMix(mixTimeout)
-	knownSelector := vfc.NewWrsIterator(s.ns, sfHasValue, sfDisableSelection, sfiNodeWeight)
-	alwaysConnect := vfc.NewQueueIterator(s.ns, sfAlwaysConnect, sfDisableSelection, true, nil)
+	knownSelector := NewWrsIterator(s.ns, sfHasValue, sfDisableSelection, sfiNodeWeight)
+	alwaysConnect := NewQueueIterator(s.ns, sfAlwaysConnect, sfDisableSelection, true, nil)
 	s.mixSources = append(s.mixSources, knownSelector)
 	s.mixSources = append(s.mixSources, alwaysConnect)
 
@@ -166,14 +169,30 @@ func newServerPool(db ethdb.KeyValueStore, dbKey []byte, vt *vfc.ValueTracker, m
 		}
 	})
 
-	s.ns.AddLogMetrics(sfHasValue, sfDisableSelection, "selectable", nil, nil, serverSelectableGauge)
-	s.ns.AddLogMetrics(sfDialing, nodestate.Flags{}, "dialed", serverDialedMeter, nil, nil)
-	s.ns.AddLogMetrics(sfConnected, nodestate.Flags{}, "connected", nil, nil, serverConnectedGauge)
-	return s
+	return s, s.dialIterator
+}
+
+// AddMetrics adds metrics to the server pool. Should be called before Start().
+func (s *ServerPool) AddMetrics(
+	suggestedTimeoutGauge, totalValueGauge, serverSelectableGauge, serverConnectedGauge metrics.Gauge,
+	sessionValueMeter, serverDialedMeter metrics.Meter) {
+
+	s.suggestedTimeoutGauge = suggestedTimeoutGauge
+	s.totalValueGauge = totalValueGauge
+	s.sessionValueMeter = sessionValueMeter
+	if serverSelectableGauge != nil {
+		s.ns.AddLogMetrics(sfHasValue, sfDisableSelection, "selectable", nil, nil, serverSelectableGauge)
+	}
+	if serverDialedMeter != nil {
+		s.ns.AddLogMetrics(sfDialing, nodestate.Flags{}, "dialed", serverDialedMeter, nil, nil)
+	}
+	if serverConnectedGauge != nil {
+		s.ns.AddLogMetrics(sfConnected, nodestate.Flags{}, "connected", nil, nil, serverConnectedGauge)
+	}
 }
 
-// addSource adds a node discovery source to the server pool (should be called before start)
-func (s *serverPool) addSource(source enode.Iterator) {
+// AddSource adds a node discovery source to the server pool (should be called before start)
+func (s *ServerPool) AddSource(source enode.Iterator) {
 	if source != nil {
 		s.mixSources = append(s.mixSources, source)
 	}
@@ -182,8 +201,8 @@ func (s *serverPool) addSource(source enode.Iterator) {
 // addPreNegFilter installs a node filter mechanism that performs a pre-negotiation query.
 // Nodes that are filtered out and does not appear on the output iterator are put back
 // into redialWait state.
-func (s *serverPool) addPreNegFilter(input enode.Iterator, query queryFunc) enode.Iterator {
-	s.fillSet = vfc.NewFillSet(s.ns, input, sfQueried)
+func (s *ServerPool) addPreNegFilter(input enode.Iterator, query queryFunc) enode.Iterator {
+	s.fillSet = NewFillSet(s.ns, input, sfQueried)
 	s.ns.SubscribeState(sfQueried, func(n *enode.Node, oldState, newState nodestate.Flags) {
 		if newState.Equals(sfQueried) {
 			fails := atomic.LoadUint32(&s.queryFails)
@@ -221,7 +240,7 @@ func (s *serverPool) addPreNegFilter(input enode.Iterator, query queryFunc) enod
 			}()
 		}
 	})
-	return vfc.NewQueueIterator(s.ns, sfCanDial, nodestate.Flags{}, false, func(waiting bool) {
+	return NewQueueIterator(s.ns, sfCanDial, nodestate.Flags{}, false, func(waiting bool) {
 		if waiting {
 			s.fillSet.SetTarget(preNegLimit)
 		} else {
@@ -231,7 +250,7 @@ func (s *serverPool) addPreNegFilter(input enode.Iterator, query queryFunc) enod
 }
 
 // start starts the server pool. Note that NodeStateMachine should be started first.
-func (s *serverPool) start() {
+func (s *ServerPool) Start() {
 	s.ns.Start()
 	for _, iter := range s.mixSources {
 		// add sources to mixer at startup because the mixer instantly tries to read them
@@ -261,10 +280,11 @@ func (s *serverPool) start() {
 			}
 		})
 	})
+	atomic.StoreUint32(&s.started, 1)
 }
 
 // stop stops the server pool
-func (s *serverPool) stop() {
+func (s *ServerPool) Stop() {
 	s.dialIterator.Close()
 	if s.fillSet != nil {
 		s.fillSet.Close()
@@ -276,32 +296,34 @@ func (s *serverPool) stop() {
 		})
 	})
 	s.ns.Stop()
+	s.vt.Stop()
 }
 
 // registerPeer implements serverPeerSubscriber
-func (s *serverPool) registerPeer(p *serverPeer) {
-	s.ns.SetState(p.Node(), sfConnected, sfDialing.Or(sfWaitDialTimeout), 0)
-	nvt := s.vt.Register(p.ID())
-	s.ns.SetField(p.Node(), sfiConnectedStats, nvt.RtStats())
-	p.setValueTracker(s.vt, nvt)
-	p.updateVtParams()
+func (s *ServerPool) RegisterNode(node *enode.Node) (*NodeValueTracker, error) {
+	if atomic.LoadUint32(&s.started) == 0 {
+		return nil, errors.New("server pool not started yet")
+	}
+	s.ns.SetState(node, sfConnected, sfDialing.Or(sfWaitDialTimeout), 0)
+	nvt := s.vt.Register(node.ID())
+	s.ns.SetField(node, sfiConnectedStats, nvt.RtStats())
+	return nvt, nil
 }
 
 // unregisterPeer implements serverPeerSubscriber
-func (s *serverPool) unregisterPeer(p *serverPeer) {
+func (s *ServerPool) UnregisterNode(node *enode.Node) {
 	s.ns.Operation(func() {
-		s.setRedialWait(p.Node(), dialCost, dialWaitStep)
-		s.ns.SetStateSub(p.Node(), nodestate.Flags{}, sfConnected, 0)
-		s.ns.SetFieldSub(p.Node(), sfiConnectedStats, nil)
+		s.setRedialWait(node, dialCost, dialWaitStep)
+		s.ns.SetStateSub(node, nodestate.Flags{}, sfConnected, 0)
+		s.ns.SetFieldSub(node, sfiConnectedStats, nil)
 	})
-	s.vt.Unregister(p.ID())
-	p.setValueTracker(nil, nil)
+	s.vt.Unregister(node.ID())
 }
 
 // recalTimeout calculates the current recommended timeout. This value is used by
 // the client as a "soft timeout" value. It also affects the service value calculation
 // of individual nodes.
-func (s *serverPool) recalTimeout() {
+func (s *ServerPool) recalTimeout() {
 	// Use cached result if possible, avoid recalculating too frequently.
 	s.timeoutLock.RLock()
 	refreshed := s.timeoutRefreshed
@@ -330,17 +352,21 @@ func (s *serverPool) recalTimeout() {
 	s.timeoutLock.Lock()
 	if s.timeout != timeout {
 		s.timeout = timeout
-		s.timeWeights = vfc.TimeoutWeights(s.timeout)
+		s.timeWeights = TimeoutWeights(s.timeout)
 
-		suggestedTimeoutGauge.Update(int64(s.timeout / time.Millisecond))
-		totalValueGauge.Update(int64(rts.Value(s.timeWeights, s.vt.StatsExpFactor())))
+		if s.suggestedTimeoutGauge != nil {
+			s.suggestedTimeoutGauge.Update(int64(s.timeout / time.Millisecond))
+		}
+		if s.totalValueGauge != nil {
+			s.totalValueGauge.Update(int64(rts.Value(s.timeWeights, s.vt.StatsExpFactor())))
+		}
 	}
 	s.timeoutRefreshed = now
 	s.timeoutLock.Unlock()
 }
 
-// getTimeout returns the recommended request timeout.
-func (s *serverPool) getTimeout() time.Duration {
+// GetTimeout returns the recommended request timeout.
+func (s *ServerPool) GetTimeout() time.Duration {
 	s.recalTimeout()
 	s.timeoutLock.RLock()
 	defer s.timeoutLock.RUnlock()
@@ -349,7 +375,7 @@ func (s *serverPool) getTimeout() time.Duration {
 
 // getTimeoutAndWeight returns the recommended request timeout as well as the
 // response time weight which is necessary to calculate service value.
-func (s *serverPool) getTimeoutAndWeight() (time.Duration, vfc.ResponseTimeWeights) {
+func (s *ServerPool) getTimeoutAndWeight() (time.Duration, ResponseTimeWeights) {
 	s.recalTimeout()
 	s.timeoutLock.RLock()
 	defer s.timeoutLock.RUnlock()
@@ -358,7 +384,7 @@ func (s *serverPool) getTimeoutAndWeight() (time.Duration, vfc.ResponseTimeWeigh
 
 // addDialCost adds the given amount of dial cost to the node history and returns the current
 // amount of total dial cost
-func (s *serverPool) addDialCost(n *nodeHistory, amount int64) uint64 {
+func (s *ServerPool) addDialCost(n *nodeHistory, amount int64) uint64 {
 	logOffset := s.vt.StatsExpirer().LogOffset(s.clock.Now())
 	if amount > 0 {
 		n.dialCost.Add(amount, logOffset)
@@ -371,7 +397,7 @@ func (s *serverPool) addDialCost(n *nodeHistory, amount int64) uint64 {
 }
 
 // serviceValue returns the service value accumulated in this session and in total
-func (s *serverPool) serviceValue(node *enode.Node) (sessionValue, totalValue float64) {
+func (s *ServerPool) serviceValue(node *enode.Node) (sessionValue, totalValue float64) {
 	nvt := s.vt.GetNode(node.ID())
 	if nvt == nil {
 		return 0, 0
@@ -381,11 +407,13 @@ func (s *serverPool) serviceValue(node *enode.Node) (sessionValue, totalValue fl
 	expFactor := s.vt.StatsExpFactor()
 
 	totalValue = currentStats.Value(timeWeights, expFactor)
-	if connStats, ok := s.ns.GetField(node, sfiConnectedStats).(vfc.ResponseTimeStats); ok {
+	if connStats, ok := s.ns.GetField(node, sfiConnectedStats).(ResponseTimeStats); ok {
 		diff := currentStats
 		diff.SubStats(&connStats)
 		sessionValue = diff.Value(timeWeights, expFactor)
-		sessionValueMeter.Mark(int64(sessionValue))
+		if s.sessionValueMeter != nil {
+			s.sessionValueMeter.Mark(int64(sessionValue))
+		}
 	}
 	return
 }
@@ -393,7 +421,7 @@ func (s *serverPool) serviceValue(node *enode.Node) (sessionValue, totalValue fl
 // updateWeight calculates the node weight and updates the nodeWeight field and the
 // hasValue flag. It also saves the node state if necessary.
 // Note: this function should run inside a NodeStateMachine operation
-func (s *serverPool) updateWeight(node *enode.Node, totalValue float64, totalDialCost uint64) {
+func (s *ServerPool) updateWeight(node *enode.Node, totalValue float64, totalDialCost uint64) {
 	weight := uint64(totalValue * nodeWeightMul / float64(totalDialCost))
 	if weight >= nodeWeightThreshold {
 		s.ns.SetStateSub(node, sfHasValue, nodestate.Flags{}, 0)
@@ -415,7 +443,7 @@ func (s *serverPool) updateWeight(node *enode.Node, totalValue float64, totalDia
 // to the minimum.
 // Note: node weight is also recalculated and updated by this function.
 // Note 2: this function should run inside a NodeStateMachine operation
-func (s *serverPool) setRedialWait(node *enode.Node, addDialCost int64, waitStep float64) {
+func (s *ServerPool) setRedialWait(node *enode.Node, addDialCost int64, waitStep float64) {
 	n, _ := s.ns.GetField(node, sfiNodeHistory).(nodeHistory)
 	sessionValue, totalValue := s.serviceValue(node)
 	totalDialCost := s.addDialCost(&n, addDialCost)
@@ -481,9 +509,14 @@ func (s *serverPool) setRedialWait(node *enode.Node, addDialCost int64, waitStep
 // This function should be called during startup and shutdown only, otherwise setRedialWait
 // will keep the weights updated as the underlying statistics are adjusted.
 // Note: this function should run inside a NodeStateMachine operation
-func (s *serverPool) calculateWeight(node *enode.Node) {
+func (s *ServerPool) calculateWeight(node *enode.Node) {
 	n, _ := s.ns.GetField(node, sfiNodeHistory).(nodeHistory)
 	_, totalValue := s.serviceValue(node)
 	totalDialCost := s.addDialCost(&n, 0)
 	s.updateWeight(node, totalValue, totalDialCost)
 }
+
+// API returns the vflux client API
+func (s *ServerPool) API() *PrivateClientAPI {
+	return NewPrivateClientAPI(s.vt)
+}
diff --git a/les/serverpool_test.go b/les/vflux/client/serverpool_test.go
similarity index 86%
rename from les/serverpool_test.go
rename to les/vflux/client/serverpool_test.go
index 5c8ae56f6..3af3db95b 100644
--- a/les/serverpool_test.go
+++ b/les/vflux/client/serverpool_test.go
@@ -14,10 +14,11 @@
 // You should have received a copy of the GNU Lesser General Public License
 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
 
-package les
+package client
 
 import (
 	"math/rand"
+	"strconv"
 	"sync/atomic"
 	"testing"
 	"time"
@@ -25,8 +26,6 @@ import (
 	"github.com/ethereum/go-ethereum/common/mclock"
 	"github.com/ethereum/go-ethereum/ethdb"
 	"github.com/ethereum/go-ethereum/ethdb/memorydb"
-	vfc "github.com/ethereum/go-ethereum/les/vflux/client"
-	"github.com/ethereum/go-ethereum/p2p"
 	"github.com/ethereum/go-ethereum/p2p/enode"
 	"github.com/ethereum/go-ethereum/p2p/enr"
 )
@@ -50,13 +49,13 @@ func testNodeIndex(id enode.ID) int {
 	return int(id[1]) + int(id[2])*256
 }
 
-type serverPoolTest struct {
+type ServerPoolTest struct {
 	db                   ethdb.KeyValueStore
 	clock                *mclock.Simulated
 	quit                 chan struct{}
 	preNeg, preNegFail   bool
-	vt                   *vfc.ValueTracker
-	sp                   *serverPool
+	vt                   *ValueTracker
+	sp                   *ServerPool
 	input                enode.Iterator
 	testNodes            []spTestNode
 	trusted              []string
@@ -71,15 +70,15 @@ type spTestNode struct {
 	connectCycles, waitCycles int
 	nextConnCycle, totalConn  int
 	connected, service        bool
-	peer                      *serverPeer
+	node                      *enode.Node
 }
 
-func newServerPoolTest(preNeg, preNegFail bool) *serverPoolTest {
+func newServerPoolTest(preNeg, preNegFail bool) *ServerPoolTest {
 	nodes := make([]*enode.Node, spTestNodes)
 	for i := range nodes {
 		nodes[i] = enode.SignNull(&enr.Record{}, testNodeID(i))
 	}
-	return &serverPoolTest{
+	return &ServerPoolTest{
 		clock:      &mclock.Simulated{},
 		db:         memorydb.New(),
 		input:      enode.CycleNodes(nodes),
@@ -89,7 +88,7 @@ func newServerPoolTest(preNeg, preNegFail bool) *serverPoolTest {
 	}
 }
 
-func (s *serverPoolTest) beginWait() {
+func (s *ServerPoolTest) beginWait() {
 	// ensure that dialIterator and the maximal number of pre-neg queries are not all stuck in a waiting state
 	for atomic.AddInt32(&s.waitCount, 1) > preNegLimit {
 		atomic.AddInt32(&s.waitCount, -1)
@@ -97,16 +96,16 @@ func (s *serverPoolTest) beginWait() {
 	}
 }
 
-func (s *serverPoolTest) endWait() {
+func (s *ServerPoolTest) endWait() {
 	atomic.AddInt32(&s.waitCount, -1)
 	atomic.AddInt32(&s.waitEnded, 1)
 }
 
-func (s *serverPoolTest) addTrusted(i int) {
+func (s *ServerPoolTest) addTrusted(i int) {
 	s.trusted = append(s.trusted, enode.SignNull(&enr.Record{}, testNodeID(i)).String())
 }
 
-func (s *serverPoolTest) start() {
+func (s *ServerPoolTest) start() {
 	var testQuery queryFunc
 	if s.preNeg {
 		testQuery = func(node *enode.Node) int {
@@ -144,13 +143,17 @@ func (s *serverPoolTest) start() {
 		}
 	}
 
-	s.vt = vfc.NewValueTracker(s.db, s.clock, requestList, time.Minute, 1/float64(time.Hour), 1/float64(time.Hour*100), 1/float64(time.Hour*1000))
-	s.sp = newServerPool(s.db, []byte("serverpool:"), s.vt, 0, testQuery, s.clock, s.trusted)
-	s.sp.addSource(s.input)
+	requestList := make([]RequestInfo, testReqTypes)
+	for i := range requestList {
+		requestList[i] = RequestInfo{Name: "testreq" + strconv.Itoa(i), InitAmount: 1, InitValue: 1}
+	}
+
+	s.sp, _ = NewServerPool(s.db, []byte("sp:"), 0, testQuery, s.clock, s.trusted, requestList)
+	s.sp.AddSource(s.input)
 	s.sp.validSchemes = enode.ValidSchemesForTesting
 	s.sp.unixTime = func() int64 { return int64(s.clock.Now()) / int64(time.Second) }
 	s.disconnect = make(map[int][]int)
-	s.sp.start()
+	s.sp.Start()
 	s.quit = make(chan struct{})
 	go func() {
 		last := int32(-1)
@@ -170,31 +173,30 @@ func (s *serverPoolTest) start() {
 	}()
 }
 
-func (s *serverPoolTest) stop() {
+func (s *ServerPoolTest) stop() {
 	close(s.quit)
-	s.sp.stop()
-	s.vt.Stop()
+	s.sp.Stop()
 	for i := range s.testNodes {
 		n := &s.testNodes[i]
 		if n.connected {
 			n.totalConn += s.cycle
 		}
 		n.connected = false
-		n.peer = nil
+		n.node = nil
 		n.nextConnCycle = 0
 	}
 	s.conn, s.servedConn = 0, 0
 }
 
-func (s *serverPoolTest) run() {
+func (s *ServerPoolTest) run() {
 	for count := spTestLength; count > 0; count-- {
 		if dcList := s.disconnect[s.cycle]; dcList != nil {
 			for _, idx := range dcList {
 				n := &s.testNodes[idx]
-				s.sp.unregisterPeer(n.peer)
+				s.sp.UnregisterNode(n.node)
 				n.totalConn += s.cycle
 				n.connected = false
-				n.peer = nil
+				n.node = nil
 				s.conn--
 				if n.service {
 					s.servedConn--
@@ -221,10 +223,10 @@ func (s *serverPoolTest) run() {
 				n.connected = true
 				dc := s.cycle + n.connectCycles
 				s.disconnect[dc] = append(s.disconnect[dc], idx)
-				n.peer = &serverPeer{peerCommons: peerCommons{Peer: p2p.NewPeer(id, "", nil)}}
-				s.sp.registerPeer(n.peer)
+				n.node = dial
+				nv, _ := s.sp.RegisterNode(n.node)
 				if n.service {
-					s.vt.Served(s.vt.GetNode(id), []vfc.ServedRequest{{ReqType: 0, Amount: 100}}, 0)
+					nv.Served([]ServedRequest{{ReqType: 0, Amount: 100}}, 0)
 				}
 			}
 		}
@@ -234,7 +236,7 @@ func (s *serverPoolTest) run() {
 	}
 }
 
-func (s *serverPoolTest) setNodes(count, conn, wait int, service, trusted bool) (res []int) {
+func (s *ServerPoolTest) setNodes(count, conn, wait int, service, trusted bool) (res []int) {
 	for ; count > 0; count-- {
 		idx := rand.Intn(spTestNodes)
 		for s.testNodes[idx].connectCycles != 0 || s.testNodes[idx].connected {
@@ -253,11 +255,11 @@ func (s *serverPoolTest) setNodes(count, conn, wait int, service, trusted bool)
 	return
 }
 
-func (s *serverPoolTest) resetNodes() {
+func (s *ServerPoolTest) resetNodes() {
 	for i, n := range s.testNodes {
 		if n.connected {
 			n.totalConn += s.cycle
-			s.sp.unregisterPeer(n.peer)
+			s.sp.UnregisterNode(n.node)
 		}
 		s.testNodes[i] = spTestNode{totalConn: n.totalConn}
 	}
@@ -266,7 +268,7 @@ func (s *serverPoolTest) resetNodes() {
 	s.trusted = nil
 }
 
-func (s *serverPoolTest) checkNodes(t *testing.T, nodes []int) {
+func (s *ServerPoolTest) checkNodes(t *testing.T, nodes []int) {
 	var sum int
 	for _, idx := range nodes {
 		n := &s.testNodes[idx]
diff --git a/les/vflux/client/valuetracker.go b/les/vflux/client/valuetracker.go
index 4e67b31d9..f5390d092 100644
--- a/les/vflux/client/valuetracker.go
+++ b/les/vflux/client/valuetracker.go
@@ -45,6 +45,7 @@ var (
 type NodeValueTracker struct {
 	lock sync.Mutex
 
+	vt                   *ValueTracker
 	rtStats, lastRtStats ResponseTimeStats
 	lastTransfer         mclock.AbsTime
 	basket               serverBasket
@@ -52,15 +53,12 @@ type NodeValueTracker struct {
 	reqValues            *[]float64
 }
 
-// init initializes a NodeValueTracker.
-// Note that the contents of the referenced reqValues slice will not change; a new
-// reference is passed if the values are updated by ValueTracker.
-func (nv *NodeValueTracker) init(now mclock.AbsTime, reqValues *[]float64) {
-	reqTypeCount := len(*reqValues)
-	nv.reqCosts = make([]uint64, reqTypeCount)
-	nv.lastTransfer = now
-	nv.reqValues = reqValues
-	nv.basket.init(reqTypeCount)
+// UpdateCosts updates the node value tracker's request cost table
+func (nv *NodeValueTracker) UpdateCosts(reqCosts []uint64) {
+	nv.vt.lock.Lock()
+	defer nv.vt.lock.Unlock()
+
+	nv.updateCosts(reqCosts, &nv.vt.refBasket.reqValues, nv.vt.refBasket.reqValueFactor(reqCosts))
 }
 
 // updateCosts updates the request cost table of the server. The request value factor
@@ -97,6 +95,28 @@ func (nv *NodeValueTracker) transferStats(now mclock.AbsTime, transferRate float
 	return nv.basket.transfer(-math.Expm1(-transferRate * float64(dt))), recentRtStats
 }
 
+type ServedRequest struct {
+	ReqType, Amount uint32
+}
+
+// Served adds a served request to the node's statistics. An actual request may be composed
+// of one or more request types (service vector indices).
+func (nv *NodeValueTracker) Served(reqs []ServedRequest, respTime time.Duration) {
+	nv.vt.statsExpLock.RLock()
+	expFactor := nv.vt.statsExpFactor
+	nv.vt.statsExpLock.RUnlock()
+
+	nv.lock.Lock()
+	defer nv.lock.Unlock()
+
+	var value float64
+	for _, r := range reqs {
+		nv.basket.add(r.ReqType, r.Amount, nv.reqCosts[r.ReqType]*uint64(r.Amount), expFactor)
+		value += (*nv.reqValues)[r.ReqType] * float64(r.Amount)
+	}
+	nv.rtStats.Add(respTime, value, expFactor)
+}
+
 // RtStats returns the node's own response time distribution statistics
 func (nv *NodeValueTracker) RtStats() ResponseTimeStats {
 	nv.lock.Lock()
@@ -333,7 +353,12 @@ func (vt *ValueTracker) Register(id enode.ID) *NodeValueTracker {
 		return nil
 	}
 	nv := vt.loadOrNewNode(id)
-	nv.init(vt.clock.Now(), &vt.refBasket.reqValues)
+	reqTypeCount := len(vt.refBasket.reqValues)
+	nv.reqCosts = make([]uint64, reqTypeCount)
+	nv.lastTransfer = vt.clock.Now()
+	nv.reqValues = &vt.refBasket.reqValues
+	nv.basket.init(reqTypeCount)
+
 	vt.connected[id] = nv
 	return nv
 }
@@ -364,7 +389,7 @@ func (vt *ValueTracker) loadOrNewNode(id enode.ID) *NodeValueTracker {
 	if nv, ok := vt.connected[id]; ok {
 		return nv
 	}
-	nv := &NodeValueTracker{lastTransfer: vt.clock.Now()}
+	nv := &NodeValueTracker{vt: vt, lastTransfer: vt.clock.Now()}
 	enc, err := vt.db.Get(append(vtNodeKey, id[:]...))
 	if err != nil {
 		return nv
@@ -425,14 +450,6 @@ func (vt *ValueTracker) saveNode(id enode.ID, nv *NodeValueTracker) {
 	}
 }
 
-// UpdateCosts updates the node value tracker's request cost table
-func (vt *ValueTracker) UpdateCosts(nv *NodeValueTracker, reqCosts []uint64) {
-	vt.lock.Lock()
-	defer vt.lock.Unlock()
-
-	nv.updateCosts(reqCosts, &vt.refBasket.reqValues, vt.refBasket.reqValueFactor(reqCosts))
-}
-
 // RtStats returns the global response time distribution statistics
 func (vt *ValueTracker) RtStats() ResponseTimeStats {
 	vt.lock.Lock()
@@ -464,28 +481,6 @@ func (vt *ValueTracker) periodicUpdate() {
 	vt.saveToDb()
 }
 
-type ServedRequest struct {
-	ReqType, Amount uint32
-}
-
-// Served adds a served request to the node's statistics. An actual request may be composed
-// of one or more request types (service vector indices).
-func (vt *ValueTracker) Served(nv *NodeValueTracker, reqs []ServedRequest, respTime time.Duration) {
-	vt.statsExpLock.RLock()
-	expFactor := vt.statsExpFactor
-	vt.statsExpLock.RUnlock()
-
-	nv.lock.Lock()
-	defer nv.lock.Unlock()
-
-	var value float64
-	for _, r := range reqs {
-		nv.basket.add(r.ReqType, r.Amount, nv.reqCosts[r.ReqType]*uint64(r.Amount), expFactor)
-		value += (*nv.reqValues)[r.ReqType] * float64(r.Amount)
-	}
-	nv.rtStats.Add(respTime, value, vt.statsExpFactor)
-}
-
 type RequestStatsItem struct {
 	Name                string
 	ReqAmount, ReqValue float64
diff --git a/les/vflux/client/valuetracker_test.go b/les/vflux/client/valuetracker_test.go
index ad398749e..87a337be8 100644
--- a/les/vflux/client/valuetracker_test.go
+++ b/les/vflux/client/valuetracker_test.go
@@ -64,7 +64,7 @@ func TestValueTracker(t *testing.T) {
 			for j := range costList {
 				costList[j] = uint64(baseCost * relPrices[j])
 			}
-			vt.UpdateCosts(nodes[i], costList)
+			nodes[i].UpdateCosts(costList)
 		}
 		for i := range nodes {
 			nodes[i] = vt.Register(enode.ID{byte(i)})
@@ -77,7 +77,7 @@ func TestValueTracker(t *testing.T) {
 				node := rand.Intn(testNodeCount)
 				respTime := time.Duration((rand.Float64() + 1) * float64(time.Second) * float64(node+1) / testNodeCount)
 				totalAmount[reqType] += uint64(reqAmount)
-				vt.Served(nodes[node], []ServedRequest{{uint32(reqType), uint32(reqAmount)}}, respTime)
+				nodes[node].Served([]ServedRequest{{uint32(reqType), uint32(reqAmount)}}, respTime)
 				clock.Run(time.Second)
 			}
 		} else {
-- 
GitLab