From 489566567073f697c619a61217ab5bd810bc833b Mon Sep 17 00:00:00 2001
From: gary rong <garyrong0905@gmail.com>
Date: Mon, 25 Jun 2018 16:52:25 +0800
Subject: [PATCH] les: handle conn/disc/reg logic in the eventloop (#16981)

* les: handle conn/disc/reg logic in the eventloop

* les: try to dial before start eventloop

* les: handle disconnect logic more safely

* les: grammar fix
---
 les/serverpool.go | 209 ++++++++++++++++++++++++++++++----------------
 1 file changed, 135 insertions(+), 74 deletions(-)

diff --git a/les/serverpool.go b/les/serverpool.go
index a39f88355..1a4c75229 100644
--- a/les/serverpool.go
+++ b/les/serverpool.go
@@ -87,6 +87,27 @@ const (
 	initStatsWeight = 1
 )
 
+// connReq represents a request for peer connection.
+type connReq struct {
+	p      *peer
+	ip     net.IP
+	port   uint16
+	result chan *poolEntry
+}
+
+// disconnReq represents a request for peer disconnection.
+type disconnReq struct {
+	entry   *poolEntry
+	stopped bool
+	done    chan struct{}
+}
+
+// registerReq represents a request for peer registration.
+type registerReq struct {
+	entry *poolEntry
+	done  chan struct{}
+}
+
 // serverPool implements a pool for storing and selecting newly discovered and already
 // known light server nodes. It received discovered nodes, stores statistics about
 // known nodes and takes care of always having enough good quality servers connected.
@@ -105,10 +126,13 @@ type serverPool struct {
 	discLookups   chan bool
 
 	entries              map[discover.NodeID]*poolEntry
-	lock                 sync.Mutex
 	timeout, enableRetry chan *poolEntry
 	adjustStats          chan poolStatAdjust
 
+	connCh     chan *connReq
+	disconnCh  chan *disconnReq
+	registerCh chan *registerReq
+
 	knownQueue, newQueue       poolEntryQueue
 	knownSelect, newSelect     *weightedRandomSelect
 	knownSelected, newSelected int
@@ -125,6 +149,9 @@ func newServerPool(db ethdb.Database, quit chan struct{}, wg *sync.WaitGroup) *s
 		timeout:      make(chan *poolEntry, 1),
 		adjustStats:  make(chan poolStatAdjust, 100),
 		enableRetry:  make(chan *poolEntry, 1),
+		connCh:       make(chan *connReq),
+		disconnCh:    make(chan *disconnReq),
+		registerCh:   make(chan *registerReq),
 		knownSelect:  newWeightedRandomSelect(),
 		newSelect:    newWeightedRandomSelect(),
 		fastDiscover: true,
@@ -147,9 +174,8 @@ func (pool *serverPool) start(server *p2p.Server, topic discv5.Topic) {
 		pool.discLookups = make(chan bool, 100)
 		go pool.server.DiscV5.SearchTopic(pool.topic, pool.discSetPeriod, pool.discNodes, pool.discLookups)
 	}
-
-	go pool.eventLoop()
 	pool.checkDial()
+	go pool.eventLoop()
 }
 
 // connect should be called upon any incoming connection. If the connection has been
@@ -158,83 +184,44 @@ func (pool *serverPool) start(server *p2p.Server, topic discv5.Topic) {
 // Note that whenever a connection has been accepted and a pool entry has been returned,
 // disconnect should also always be called.
 func (pool *serverPool) connect(p *peer, ip net.IP, port uint16) *poolEntry {
-	pool.lock.Lock()
-	defer pool.lock.Unlock()
-	entry := pool.entries[p.ID()]
-	if entry == nil {
-		entry = pool.findOrNewNode(p.ID(), ip, port)
-	}
-	p.Log().Debug("Connecting to new peer", "state", entry.state)
-	if entry.state == psConnected || entry.state == psRegistered {
+	log.Debug("Connect new entry", "enode", p.id)
+	req := &connReq{p: p, ip: ip, port: port, result: make(chan *poolEntry, 1)}
+	select {
+	case pool.connCh <- req:
+	case <-pool.quit:
 		return nil
 	}
-	pool.connWg.Add(1)
-	entry.peer = p
-	entry.state = psConnected
-	addr := &poolEntryAddress{
-		ip:       ip,
-		port:     port,
-		lastSeen: mclock.Now(),
-	}
-	entry.lastConnected = addr
-	entry.addr = make(map[string]*poolEntryAddress)
-	entry.addr[addr.strKey()] = addr
-	entry.addrSelect = *newWeightedRandomSelect()
-	entry.addrSelect.update(addr)
-	return entry
+	return <-req.result
 }
 
 // registered should be called after a successful handshake
 func (pool *serverPool) registered(entry *poolEntry) {
 	log.Debug("Registered new entry", "enode", entry.id)
-	pool.lock.Lock()
-	defer pool.lock.Unlock()
-
-	entry.state = psRegistered
-	entry.regTime = mclock.Now()
-	if !entry.known {
-		pool.newQueue.remove(entry)
-		entry.known = true
+	req := &registerReq{entry: entry, done: make(chan struct{})}
+	select {
+	case pool.registerCh <- req:
+	case <-pool.quit:
+		return
 	}
-	pool.knownQueue.setLatest(entry)
-	entry.shortRetry = shortRetryCnt
+	<-req.done
 }
 
 // disconnect should be called when ending a connection. Service quality statistics
 // can be updated optionally (not updated if no registration happened, in this case
 // only connection statistics are updated, just like in case of timeout)
 func (pool *serverPool) disconnect(entry *poolEntry) {
-	log.Debug("Disconnected old entry", "enode", entry.id)
-	pool.lock.Lock()
-	defer pool.lock.Unlock()
-
-	if entry.state == psRegistered {
-		connTime := mclock.Now() - entry.regTime
-		connAdjust := float64(connTime) / float64(targetConnTime)
-		if connAdjust > 1 {
-			connAdjust = 1
-		}
-		stopped := false
-		select {
-		case <-pool.quit:
-			stopped = true
-		default:
-		}
-		if stopped {
-			entry.connectStats.add(1, connAdjust)
-		} else {
-			entry.connectStats.add(connAdjust, 1)
-		}
+	stopped := false
+	select {
+	case <-pool.quit:
+		stopped = true
+	default:
 	}
+	log.Debug("Disconnected old entry", "enode", entry.id)
+	req := &disconnReq{entry: entry, stopped: stopped, done: make(chan struct{})}
 
-	entry.state = psNotConnected
-	if entry.knownSelected {
-		pool.knownSelected--
-	} else {
-		pool.newSelected--
-	}
-	pool.setRetryDial(entry)
-	pool.connWg.Done()
+	// Block until disconnection request is served.
+	pool.disconnCh <- req
+	<-req.done
 }
 
 const (
@@ -277,25 +264,51 @@ func (pool *serverPool) eventLoop() {
 	if pool.discSetPeriod != nil {
 		pool.discSetPeriod <- time.Millisecond * 100
 	}
+
+	// disconnect updates service quality statistics depending on the connection time
+	// and disconnection initiator.
+	disconnect := func(req *disconnReq, stopped bool) {
+		// Handle peer disconnection requests.
+		entry := req.entry
+		if entry.state == psRegistered {
+			connAdjust := float64(mclock.Now()-entry.regTime) / float64(targetConnTime)
+			if connAdjust > 1 {
+				connAdjust = 1
+			}
+			if stopped {
+				// disconnect requested by ourselves.
+				entry.connectStats.add(1, connAdjust)
+			} else {
+				// disconnect requested by server side.
+				entry.connectStats.add(connAdjust, 1)
+			}
+		}
+		entry.state = psNotConnected
+
+		if entry.knownSelected {
+			pool.knownSelected--
+		} else {
+			pool.newSelected--
+		}
+		pool.setRetryDial(entry)
+		pool.connWg.Done()
+		close(req.done)
+	}
+
 	for {
 		select {
 		case entry := <-pool.timeout:
-			pool.lock.Lock()
 			if !entry.removed {
 				pool.checkDialTimeout(entry)
 			}
-			pool.lock.Unlock()
 
 		case entry := <-pool.enableRetry:
-			pool.lock.Lock()
 			if !entry.removed {
 				entry.delayedRetry = false
 				pool.updateCheckDial(entry)
 			}
-			pool.lock.Unlock()
 
 		case adj := <-pool.adjustStats:
-			pool.lock.Lock()
 			switch adj.adjustType {
 			case pseBlockDelay:
 				adj.entry.delayStats.add(float64(adj.time), 1)
@@ -305,13 +318,10 @@ func (pool *serverPool) eventLoop() {
 			case pseResponseTimeout:
 				adj.entry.timeoutStats.add(1, 1)
 			}
-			pool.lock.Unlock()
 
 		case node := <-pool.discNodes:
-			pool.lock.Lock()
 			entry := pool.findOrNewNode(discover.NodeID(node.ID), node.IP, node.TCP)
 			pool.updateCheckDial(entry)
-			pool.lock.Unlock()
 
 		case conv := <-pool.discLookups:
 			if conv {
@@ -327,15 +337,66 @@ func (pool *serverPool) eventLoop() {
 				}
 			}
 
+		case req := <-pool.connCh:
+			// Handle peer connection requests.
+			entry := pool.entries[req.p.ID()]
+			if entry == nil {
+				entry = pool.findOrNewNode(req.p.ID(), req.ip, req.port)
+			}
+			if entry.state == psConnected || entry.state == psRegistered {
+				req.result <- nil
+				continue
+			}
+			pool.connWg.Add(1)
+			entry.peer = req.p
+			entry.state = psConnected
+			addr := &poolEntryAddress{
+				ip:       req.ip,
+				port:     req.port,
+				lastSeen: mclock.Now(),
+			}
+			entry.lastConnected = addr
+			entry.addr = make(map[string]*poolEntryAddress)
+			entry.addr[addr.strKey()] = addr
+			entry.addrSelect = *newWeightedRandomSelect()
+			entry.addrSelect.update(addr)
+			req.result <- entry
+
+		case req := <-pool.registerCh:
+			// Handle peer registration requests.
+			entry := req.entry
+			entry.state = psRegistered
+			entry.regTime = mclock.Now()
+			if !entry.known {
+				pool.newQueue.remove(entry)
+				entry.known = true
+			}
+			pool.knownQueue.setLatest(entry)
+			entry.shortRetry = shortRetryCnt
+			close(req.done)
+
+		case req := <-pool.disconnCh:
+			// Handle peer disconnection requests.
+			disconnect(req, req.stopped)
+
 		case <-pool.quit:
 			if pool.discSetPeriod != nil {
 				close(pool.discSetPeriod)
 			}
-			pool.connWg.Wait()
+
+			// Spawn a goroutine to close the disconnCh after all connections are disconnected.
+			go func() {
+				pool.connWg.Wait()
+				close(pool.disconnCh)
+			}()
+
+			// Handle all remaining disconnection requests before exit.
+			for req := range pool.disconnCh {
+				disconnect(req, true)
+			}
 			pool.saveNodes()
 			pool.wg.Done()
 			return
-
 		}
 	}
 }
-- 
GitLab