From c6649f584b6bc357bb10c0df564c72e4f919e4a9 Mon Sep 17 00:00:00 2001
From: battlmonstr <battlmonstr@users.noreply.github.com>
Date: Thu, 28 Apr 2022 04:21:52 +0200
Subject: [PATCH] p2p: refactor MaxPendingPeers handling (#3981)

* use semaphore instead of a chan struct{}
* move MaxPendingPeers default value to DefaultConfig.P2P
* log Error if Accept fails
* replace quit channel with context
---
 cmd/utils/flags.go                 |  2 +-
 node/defaults.go                   |  9 +--
 p2p/dial.go                        |  3 -
 p2p/server.go                      | 82 +++++++++++++++------------
 p2p/server_test.go                 | 89 ++++++++++++++++--------------
 p2p/simulations/adapters/inproc.go |  1 +
 6 files changed, 101 insertions(+), 85 deletions(-)

diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go
index e58cccb147..bc089079b9 100644
--- a/cmd/utils/flags.go
+++ b/cmd/utils/flags.go
@@ -491,7 +491,7 @@ var (
 	}
 	MaxPendingPeersFlag = cli.IntFlag{
 		Name:  "maxpendpeers",
-		Usage: "Maximum number of pending connection attempts (defaults used if set to 0)",
+		Usage: "Maximum number of TCP connections pending to become connected peers",
 		Value: node.DefaultConfig.P2P.MaxPendingPeers,
 	}
 	ListenPortFlag = cli.IntFlag{
diff --git a/node/defaults.go b/node/defaults.go
index 951571f899..34b31f63ab 100644
--- a/node/defaults.go
+++ b/node/defaults.go
@@ -43,9 +43,10 @@ var DefaultConfig = Config{
 	WSPort:           DefaultWSPort,
 	WSModules:        []string{"net", "web3"},
 	P2P: p2p.Config{
-		ListenAddr:   ":30303",
-		ListenAddr65: ":30304",
-		MaxPeers:     100,
-		NAT:          nat.Any(),
+		ListenAddr:      ":30303",
+		ListenAddr65:    ":30304",
+		MaxPeers:        100,
+		MaxPendingPeers: 50,
+		NAT:             nat.Any(),
 	},
 }
diff --git a/p2p/dial.go b/p2p/dial.go
index d15458f4d3..d1ce4d57d6 100644
--- a/p2p/dial.go
+++ b/p2p/dial.go
@@ -146,9 +146,6 @@ type dialConfig struct {
 }
 
 func (cfg dialConfig) withDefaults() dialConfig {
-	if cfg.maxActiveDials == 0 {
-		cfg.maxActiveDials = defaultMaxPendingPeers
-	}
 	if cfg.log == nil {
 		cfg.log = log.Root()
 	}
diff --git a/p2p/server.go b/p2p/server.go
index b11e8a182f..595052fc06 100644
--- a/p2p/server.go
+++ b/p2p/server.go
@@ -24,6 +24,7 @@ import (
 	"encoding/hex"
 	"errors"
 	"fmt"
+	"golang.org/x/sync/semaphore"
 	"net"
 	"sort"
 	"sync"
@@ -52,8 +53,7 @@ const (
 	discmixTimeout = 5 * time.Second
 
 	// Connectivity defaults.
-	defaultMaxPendingPeers = 50
-	defaultDialRatio       = 3
+	defaultDialRatio = 3
 
 	// This time limits inbound connection attempts per source IP.
 	inboundThrottleTime = 30 * time.Second
@@ -79,7 +79,7 @@ type Config struct {
 
 	// MaxPendingPeers is the maximum number of peers that can be pending in the
 	// handshake phase, counted separately for inbound and outbound connections.
-	// Zero defaults to preset values.
+	// It must be greater than zero.
 	MaxPendingPeers int `toml:",omitempty"`
 
 	// DialRatio controls the ratio of inbound to dialed connections.
@@ -191,7 +191,9 @@ type Server struct {
 	dialsched *dialScheduler
 
 	// Channels into the run loop.
-	quit                    chan struct{}
+	quitCtx                 context.Context
+	quitFunc                context.CancelFunc
+	quit                    <-chan struct{}
 	addtrusted              chan *enode.Node
 	removetrusted           chan *enode.Node
 	peerOp                  chan peerOpFunc
@@ -409,10 +411,10 @@ func (srv *Server) Stop() {
 		return
 	}
 	srv.running = false
-	close(srv.quit)
+	srv.quitFunc()
 	if srv.listener != nil {
 		// this unblocks listener Accept
-		srv.listener.Close()
+		_ = srv.listener.Close()
 	}
 	if srv.nodedb != nil {
 		srv.nodedb.Close()
@@ -476,13 +478,17 @@ func (srv *Server) Start(ctx context.Context) error {
 	if srv.PrivateKey == nil {
 		return errors.New("Server.PrivateKey must be set to a non-nil key")
 	}
+	if srv.MaxPendingPeers <= 0 {
+		return errors.New("MaxPendingPeers must be greater than zero")
+	}
 	if srv.newTransport == nil {
 		srv.newTransport = newRLPX
 	}
 	if srv.listenFunc == nil {
 		srv.listenFunc = net.Listen
 	}
-	srv.quit = make(chan struct{})
+	srv.quitCtx, srv.quitFunc = context.WithCancel(ctx)
+	srv.quit = srv.quitCtx.Done()
 	srv.delpeer = make(chan peerDrop)
 	srv.checkpointPostHandshake = make(chan *conn)
 	srv.checkpointAddPeer = make(chan *conn)
@@ -495,11 +501,11 @@ func (srv *Server) Start(ctx context.Context) error {
 		return err
 	}
 	if srv.ListenAddr != "" {
-		if err := srv.setupListening(); err != nil {
+		if err := srv.setupListening(srv.quitCtx); err != nil {
 			return err
 		}
 	}
-	if err := srv.setupDiscovery(ctx); err != nil {
+	if err := srv.setupDiscovery(srv.quitCtx); err != nil {
 		return err
 	}
 	srv.setupDialScheduler()
@@ -586,8 +592,8 @@ func (srv *Server) setupDiscovery(ctx context.Context) error {
 			srv.loopWG.Add(1)
 			go func() {
 				defer debug.LogPanic()
+				defer srv.loopWG.Done()
 				nat.Map(srv.NAT, srv.quit, "udp", realaddr.Port, realaddr.Port, "ethereum discovery")
-				srv.loopWG.Done()
 			}()
 		}
 	}
@@ -682,7 +688,7 @@ func (srv *Server) maxDialedConns() (limit int) {
 	return limit
 }
 
-func (srv *Server) setupListening() error {
+func (srv *Server) setupListening(ctx context.Context) error {
 	// Launch the listener.
 	listener, err := srv.listenFunc("tcp", srv.ListenAddr)
 	if err != nil {
@@ -698,14 +704,18 @@ func (srv *Server) setupListening() error {
 			srv.loopWG.Add(1)
 			go func() {
 				defer debug.LogPanic()
+				defer srv.loopWG.Done()
 				nat.Map(srv.NAT, srv.quit, "tcp", tcp.Port, tcp.Port, "ethereum p2p")
-				srv.loopWG.Done()
 			}()
 		}
 	}
 
 	srv.loopWG.Add(1)
-	go srv.listenLoop()
+	go func() {
+		defer debug.LogPanic()
+		defer srv.loopWG.Done()
+		srv.listenLoop(ctx)
+	}()
 	return nil
 }
 
@@ -857,32 +867,26 @@ func (srv *Server) addPeerChecks(peers map[enode.ID]*Peer, inboundCount int, c *
 
 // listenLoop runs in its own goroutine and accepts
 // inbound connections.
-func (srv *Server) listenLoop() {
-	defer debug.LogPanic()
+func (srv *Server) listenLoop(ctx context.Context) {
 	srv.log.Trace("TCP listener up", "addr", srv.listener.Addr())
 
-	// The slots channel limits accepts of new connections.
-	tokens := defaultMaxPendingPeers
-	if srv.MaxPendingPeers > 0 {
-		tokens = srv.MaxPendingPeers
-	}
-	slots := make(chan struct{}, tokens)
-	for i := 0; i < tokens; i++ {
-		slots <- struct{}{}
-	}
+	// The slots limit accepts of new connections.
+	slots := semaphore.NewWeighted(int64(srv.MaxPendingPeers))
 
 	// Wait for slots to be returned on exit. This ensures all connection goroutines
 	// are down before listenLoop returns.
-	defer srv.loopWG.Done()
 	defer func() {
-		for i := 0; i < cap(slots); i++ {
-			<-slots
-		}
+		_ = slots.Acquire(ctx, int64(srv.MaxPendingPeers))
 	}()
 
 	for {
 		// Wait for a free slot before accepting.
-		<-slots
+		if slotErr := slots.Acquire(ctx, 1); slotErr != nil {
+			if !errors.Is(slotErr, context.Canceled) {
+				srv.log.Error("Failed to get a peer connection slot", "err", slotErr)
+			}
+			return
+		}
 
 		var (
 			fd      net.Conn
@@ -899,8 +903,13 @@ func (srv *Server) listenLoop() {
 				time.Sleep(time.Millisecond * 200)
 				continue
 			} else if err != nil {
-				srv.log.Trace("Read error", "err", err)
-				slots <- struct{}{}
+				// Log the error unless the server is shutting down.
+				select {
+				case <-srv.quit:
+				default:
+					srv.log.Error("Server listener failed to accept a connection", "err", err)
+				}
+				slots.Release(1)
 				return
 			}
 			break
@@ -908,9 +917,9 @@ func (srv *Server) listenLoop() {
 
 		remoteIP := netutil.AddrIP(fd.RemoteAddr())
 		if err := srv.checkInboundConn(fd, remoteIP); err != nil {
-			srv.log.Trace("Rejected inbound connnection", "addr", fd.RemoteAddr(), "err", err)
-			fd.Close()
-			slots <- struct{}{}
+			srv.log.Trace("Rejected inbound connection", "addr", fd.RemoteAddr(), "err", err)
+			_ = fd.Close()
+			slots.Release(1)
 			continue
 		}
 		if remoteIP != nil {
@@ -923,8 +932,9 @@ func (srv *Server) listenLoop() {
 		}
 		go func() {
 			defer debug.LogPanic()
-			srv.SetupConn(fd, inboundConn, nil)
-			slots <- struct{}{}
+			defer slots.Release(1)
+			// The error is logged in Server.setupConn().
+			_ = srv.SetupConn(fd, inboundConn, nil)
 		}()
 	}
 }
diff --git a/p2p/server_test.go b/p2p/server_test.go
index ac5dcb7fd7..ad4a92f5c2 100644
--- a/p2p/server_test.go
+++ b/p2p/server_test.go
@@ -69,12 +69,13 @@ func (c *testTransport) close(err error) {
 
 func startTestServer(t *testing.T, remoteKey *ecdsa.PublicKey, pf func(*Peer)) *Server {
 	config := Config{
-		Name:        "test",
-		MaxPeers:    10,
-		ListenAddr:  "127.0.0.1:0",
-		NoDiscovery: true,
-		PrivateKey:  newkey(),
-		Log:         testlog.Logger(t, log.LvlError),
+		Name:            "test",
+		MaxPeers:        10,
+		MaxPendingPeers: 10,
+		ListenAddr:      "127.0.0.1:0",
+		NoDiscovery:     true,
+		PrivateKey:      newkey(),
+		Log:             testlog.Logger(t, log.LvlError),
 	}
 	server := &Server{
 		Config:      config,
@@ -211,18 +212,20 @@ func TestServerDial(t *testing.T) {
 // This test checks that RemovePeer disconnects the peer if it is connected.
 func TestServerRemovePeerDisconnect(t *testing.T) {
 	srv1 := &Server{Config: Config{
-		PrivateKey:  newkey(),
-		MaxPeers:    1,
-		NoDiscovery: true,
-		Log:         testlog.Logger(t, log.LvlTrace).New("server", "1"),
+		PrivateKey:      newkey(),
+		MaxPeers:        1,
+		MaxPendingPeers: 1,
+		NoDiscovery:     true,
+		Log:             testlog.Logger(t, log.LvlTrace).New("server", "1"),
 	}}
 	srv2 := &Server{Config: Config{
-		PrivateKey:  newkey(),
-		MaxPeers:    1,
-		NoDiscovery: true,
-		NoDial:      true,
-		ListenAddr:  "127.0.0.1:0",
-		Log:         testlog.Logger(t, log.LvlTrace).New("server", "2"),
+		PrivateKey:      newkey(),
+		MaxPeers:        1,
+		MaxPendingPeers: 1,
+		NoDiscovery:     true,
+		NoDial:          true,
+		ListenAddr:      "127.0.0.1:0",
+		Log:             testlog.Logger(t, log.LvlTrace).New("server", "2"),
 	}}
 	if err := srv1.TestStart(); err != nil {
 		t.Fatal("cant start srv1")
@@ -249,12 +252,13 @@ func TestServerAtCap(t *testing.T) {
 	trustedID := enode.PubkeyToIDV4(&trustedNode.PublicKey)
 	srv := &Server{
 		Config: Config{
-			PrivateKey:   newkey(),
-			MaxPeers:     10,
-			NoDial:       true,
-			NoDiscovery:  true,
-			TrustedNodes: []*enode.Node{newNode(trustedID, "")},
-			Log:          testlog.Logger(t, log.LvlTrace),
+			PrivateKey:      newkey(),
+			MaxPeers:        10,
+			MaxPendingPeers: 10,
+			NoDial:          true,
+			NoDiscovery:     true,
+			TrustedNodes:    []*enode.Node{newNode(trustedID, "")},
+			Log:             testlog.Logger(t, log.LvlTrace),
 		},
 	}
 	if err := srv.TestStart(); err != nil {
@@ -325,12 +329,13 @@ func TestServerPeerLimits(t *testing.T) {
 
 	srv := &Server{
 		Config: Config{
-			PrivateKey:  srvkey,
-			MaxPeers:    0,
-			NoDial:      true,
-			NoDiscovery: true,
-			Protocols:   []Protocol{discard},
-			Log:         testlog.Logger(t, log.LvlTrace),
+			PrivateKey:      srvkey,
+			MaxPeers:        0,
+			MaxPendingPeers: 50,
+			NoDial:          true,
+			NoDiscovery:     true,
+			Protocols:       []Protocol{discard},
+			Log:             testlog.Logger(t, log.LvlTrace),
 		},
 		newTransport: func(fd net.Conn, dialDest *ecdsa.PublicKey) transport { return tp },
 	}
@@ -432,12 +437,13 @@ func TestServerSetupConn(t *testing.T) {
 	for i, test := range tests {
 		t.Run(test.wantCalls, func(t *testing.T) {
 			cfg := Config{
-				PrivateKey:  srvkey,
-				MaxPeers:    10,
-				NoDial:      true,
-				NoDiscovery: true,
-				Protocols:   []Protocol{discard},
-				Log:         testlog.Logger(t, log.LvlTrace),
+				PrivateKey:      srvkey,
+				MaxPeers:        10,
+				MaxPendingPeers: 10,
+				NoDial:          true,
+				NoDiscovery:     true,
+				Protocols:       []Protocol{discard},
+				Log:             testlog.Logger(t, log.LvlTrace),
 			}
 			srv := &Server{
 				Config:       cfg,
@@ -518,13 +524,14 @@ func TestServerInboundThrottle(t *testing.T) {
 	newTransportCalled := make(chan struct{})
 	srv := &Server{
 		Config: Config{
-			PrivateKey:  newkey(),
-			ListenAddr:  "127.0.0.1:0",
-			MaxPeers:    10,
-			NoDial:      true,
-			NoDiscovery: true,
-			Protocols:   []Protocol{discard},
-			Log:         testlog.Logger(t, log.LvlTrace),
+			PrivateKey:      newkey(),
+			ListenAddr:      "127.0.0.1:0",
+			MaxPeers:        10,
+			MaxPendingPeers: 10,
+			NoDial:          true,
+			NoDiscovery:     true,
+			Protocols:       []Protocol{discard},
+			Log:             testlog.Logger(t, log.LvlTrace),
 		},
 		newTransport: func(fd net.Conn, dialDest *ecdsa.PublicKey) transport {
 			newTransportCalled <- struct{}{}
diff --git a/p2p/simulations/adapters/inproc.go b/p2p/simulations/adapters/inproc.go
index 8f505fc760..68a22002cc 100644
--- a/p2p/simulations/adapters/inproc.go
+++ b/p2p/simulations/adapters/inproc.go
@@ -96,6 +96,7 @@ func (s *SimAdapter) NewNode(config *NodeConfig) (Node, error) {
 		P2P: p2p.Config{
 			PrivateKey:      config.PrivateKey,
 			MaxPeers:        math.MaxInt32,
+			MaxPendingPeers: 50,
 			NoDiscovery:     true,
 			Dialer:          s,
 			EnableMsgEvents: config.EnableMsgEvents,
-- 
GitLab