From a902880e99c5e46d23124ef702298b151d13ce7e Mon Sep 17 00:00:00 2001
From: TBC Dev <48684072+tbcd@users.noreply.github.com>
Date: Mon, 15 Nov 2021 11:07:57 +0800
Subject: [PATCH] Refactor sentry peers (#2961)

---
 cmd/sentry/download/sentry.go | 135 ++++++++++++++++------------------
 1 file changed, 62 insertions(+), 73 deletions(-)

diff --git a/cmd/sentry/download/sentry.go b/cmd/sentry/download/sentry.go
index 4f014d7d8d..017431205d 100644
--- a/cmd/sentry/download/sentry.go
+++ b/cmd/sentry/download/sentry.go
@@ -452,7 +452,7 @@ func NewSentryServer(ctx context.Context, dialCandidates enode.Iterator, readNod
 		DialCandidates: dialCandidates,
 		Run: func(peer *p2p.Peer, rw p2p.MsgReadWriter) error {
 			peerID := peer.ID().String()
-			if _, ok := ss.GoodPeers.Load(peerID); ok {
+			if ss.getPeer(peerID) != nil {
 				log.Trace(fmt.Sprintf("[%s] Peer already has connection", peerID))
 				return nil
 			}
@@ -491,11 +491,11 @@ func NewSentryServer(ctx context.Context, dialCandidates enode.Iterator, readNod
 			return readNodeInfo()
 		},
 		PeerInfo: func(id enode.ID) interface{} {
-			p, ok := ss.GoodPeers.Load(id.String())
-			if !ok {
-				return nil
+			peerID := id.String()
+			if peerInfo := ss.getPeer(peerID); peerInfo != nil {
+				return peerInfo.peer.Info()
 			}
-			return p.(*PeerInfo).peer.Info()
+			return nil
 		},
 		//Attributes: []enr.Entry{eth.CurrentENREntry(chainConfig, genesisHash, headHeight)},
 	}
@@ -540,6 +540,46 @@ type SentryServerImpl struct {
 	p2p                  *p2p.Config
 }
 
+func (ss *SentryServerImpl) rangePeers(f func(peerID string, peerInfo *PeerInfo) bool) {
+	ss.GoodPeers.Range(func(key, value interface{}) bool {
+		peerInfo, _ := value.(*PeerInfo)
+		if peerInfo == nil {
+			return true
+		}
+		peerID := key.(string)
+		return f(peerID, peerInfo)
+	})
+}
+
+func (ss *SentryServerImpl) getPeer(peerID string) (peerInfo *PeerInfo) {
+	if value, ok := ss.GoodPeers.Load(peerID); ok {
+		peerInfo := value.(*PeerInfo)
+		if peerInfo != nil {
+			return peerInfo
+		}
+		ss.GoodPeers.Delete(peerID)
+	}
+	return nil
+}
+
+func (ss *SentryServerImpl) removePeer(peerID string) {
+	if value, ok := ss.GoodPeers.LoadAndDelete(peerID); ok {
+		peerInfo := value.(*PeerInfo)
+		if peerInfo != nil {
+			peerInfo.Remove()
+		}
+	}
+}
+
+func (ss *SentryServerImpl) writePeer(peerID string, peerInfo *PeerInfo, msgcode uint64, data []byte) error {
+	err := peerInfo.rw.WriteMsg(p2p.Msg{Code: msgcode, Size: uint32(len(data)), Payload: bytes.NewReader(data)})
+	if err != nil {
+		peerInfo.Remove()
+		ss.GoodPeers.Delete(peerID)
+	}
+	return err
+}
+
 func (ss *SentryServerImpl) startSync(ctx context.Context, bestHash common.Hash, peerID string) error {
 	switch ss.Protocol.Version {
 	case eth.ETH66:
@@ -570,26 +610,17 @@ func (ss *SentryServerImpl) startSync(ctx context.Context, bestHash common.Hash,
 
 func (ss *SentryServerImpl) PenalizePeer(_ context.Context, req *proto_sentry.PenalizePeerRequest) (*emptypb.Empty, error) {
 	//log.Warn("Received penalty", "kind", req.GetPenalty().Descriptor().FullName, "from", fmt.Sprintf("%s", req.GetPeerId()))
-	strId := string(gointerfaces.ConvertH512ToBytes(req.PeerId))
-	if x, ok := ss.GoodPeers.Load(strId); ok {
-		peerInfo := x.(*PeerInfo)
-		if peerInfo != nil {
-			peerInfo.Remove()
-		}
-	}
-	ss.GoodPeers.Delete(strId)
+	peerID := string(gointerfaces.ConvertH512ToBytes(req.PeerId))
+	ss.removePeer(peerID)
 	return &emptypb.Empty{}, nil
 }
 
 func (ss *SentryServerImpl) PeerMinBlock(_ context.Context, req *proto_sentry.PeerMinBlockRequest) (*emptypb.Empty, error) {
 	peerID := string(gointerfaces.ConvertH512ToBytes(req.PeerId))
-	x, _ := ss.GoodPeers.Load(peerID)
-	peerInfo, _ := x.(*PeerInfo)
-	if peerInfo == nil {
-		return &emptypb.Empty{}, nil
-	}
-	if req.MinBlock > peerInfo.Height() {
-		peerInfo.SetHeight(req.MinBlock)
+	if peerInfo := ss.getPeer(peerID); peerInfo != nil {
+		if req.MinBlock > peerInfo.Height() {
+			peerInfo.SetHeight(req.MinBlock)
+		}
 	}
 	return &emptypb.Empty{}, nil
 }
@@ -600,13 +631,7 @@ func (ss *SentryServerImpl) findPeer(minBlock uint64) (string, *PeerInfo, bool)
 	var foundPeerInfo *PeerInfo
 	var maxPermits int
 	now := time.Now()
-	ss.GoodPeers.Range(func(key, value interface{}) bool {
-		peerID := key.(string)
-		x, _ := ss.GoodPeers.Load(peerID)
-		peerInfo, _ := x.(*PeerInfo)
-		if peerInfo == nil {
-			return true
-		}
+	ss.rangePeers(func(peerID string, peerInfo *PeerInfo) bool {
 		if peerInfo.Height() >= minBlock {
 			deadlines := peerInfo.ClearDeadlines(now, false /* givePermit */)
 			//fmt.Printf("%d deadlines for peer %s\n", deadlines, peerID)
@@ -635,14 +660,7 @@ func (ss *SentryServerImpl) SendMessageByMinBlock(_ context.Context, inreq *prot
 		msgcode != eth.GetPooledTransactionsMsg {
 		return &proto_sentry.SentPeers{}, fmt.Errorf("sendMessageByMinBlock not implemented for message Id: %s", inreq.Data.Id)
 	}
-	if err := peerInfo.rw.WriteMsg(p2p.Msg{Code: msgcode, Size: uint32(len(inreq.Data.Data)), Payload: bytes.NewReader(inreq.Data.Data)}); err != nil {
-		if x, ok := ss.GoodPeers.Load(peerID); ok {
-			peerInfo := x.(*PeerInfo)
-			if peerInfo != nil {
-				peerInfo.Remove()
-			}
-		}
-		ss.GoodPeers.Delete(peerID)
+	if err := ss.writePeer(peerID, peerInfo, msgcode, inreq.Data.Data); err != nil {
 		return &proto_sentry.SentPeers{}, fmt.Errorf("sendMessageByMinBlock to peer %s: %w", peerID, err)
 	}
 	peerInfo.AddDeadline(time.Now().Add(30 * time.Second))
@@ -651,13 +669,12 @@ func (ss *SentryServerImpl) SendMessageByMinBlock(_ context.Context, inreq *prot
 
 func (ss *SentryServerImpl) SendMessageById(_ context.Context, inreq *proto_sentry.SendMessageByIdRequest) (*proto_sentry.SentPeers, error) {
 	peerID := string(gointerfaces.ConvertH512ToBytes(inreq.PeerId))
-	x, ok := ss.GoodPeers.Load(peerID)
-	if !ok {
+	peerInfo := ss.getPeer(peerID)
+	if peerInfo == nil {
 		//TODO: enable after support peer to sentry mapping
 		//return &proto_sentry.SentPeers{}, fmt.Errorf("peer not found: %s", peerID)
 		return &proto_sentry.SentPeers{}, nil
 	}
-	peerInfo := x.(*PeerInfo)
 	msgcode := eth.FromProto[ss.Protocol.Version][inreq.Data.Id]
 	if msgcode != eth.GetBlockHeadersMsg &&
 		msgcode != eth.BlockHeadersMsg &&
@@ -670,14 +687,7 @@ func (ss *SentryServerImpl) SendMessageById(_ context.Context, inreq *proto_sent
 		return &proto_sentry.SentPeers{}, fmt.Errorf("sendMessageById not implemented for message Id: %s", inreq.Data.Id)
 	}
 
-	if err := peerInfo.rw.WriteMsg(p2p.Msg{Code: msgcode, Size: uint32(len(inreq.Data.Data)), Payload: bytes.NewReader(inreq.Data.Data)}); err != nil {
-		if x, ok := ss.GoodPeers.Load(peerID); ok {
-			peerInfo := x.(*PeerInfo)
-			if peerInfo != nil {
-				peerInfo.Remove()
-			}
-		}
-		ss.GoodPeers.Delete(peerID)
+	if err := ss.writePeer(peerID, peerInfo, msgcode, inreq.Data.Data); err != nil {
 		return &proto_sentry.SentPeers{}, fmt.Errorf("sendMessageById to peer %s: %w", peerID, err)
 	}
 	return &proto_sentry.SentPeers{Peers: []*proto_types.H512{inreq.PeerId}}, nil
@@ -692,7 +702,7 @@ func (ss *SentryServerImpl) SendMessageToRandomPeers(ctx context.Context, req *p
 	}
 
 	amount := uint64(0)
-	ss.GoodPeers.Range(func(key, value interface{}) bool {
+	ss.rangePeers(func(peerID string, peerInfo *PeerInfo) bool {
 		amount++
 		return true
 	})
@@ -705,15 +715,8 @@ func (ss *SentryServerImpl) SendMessageToRandomPeers(ctx context.Context, req *p
 	i := 0
 	var innerErr error
 	reply := &proto_sentry.SentPeers{Peers: []*proto_types.H512{}}
-	ss.GoodPeers.Range(func(key, value interface{}) bool {
-		peerID := key.(string)
-		peerInfo, _ := value.(*PeerInfo)
-		if peerInfo == nil {
-			return true
-		}
-		if err := peerInfo.rw.WriteMsg(p2p.Msg{Code: msgcode, Size: uint32(len(req.Data.Data)), Payload: bytes.NewReader(req.Data.Data)}); err != nil {
-			peerInfo.Remove()
-			ss.GoodPeers.Delete(peerID)
+	ss.rangePeers(func(peerID string, peerInfo *PeerInfo) bool {
+		if err := ss.writePeer(peerID, peerInfo, msgcode, req.Data.Data); err != nil {
 			innerErr = err
 			return true
 		}
@@ -737,16 +740,8 @@ func (ss *SentryServerImpl) SendMessageToAll(ctx context.Context, req *proto_sen
 
 	var innerErr error
 	reply := &proto_sentry.SentPeers{Peers: []*proto_types.H512{}}
-	ss.GoodPeers.Range(func(key, value interface{}) bool {
-		peerID := key.(string)
-		peerInfo, _ := value.(*PeerInfo)
-		if peerInfo == nil {
-			return true
-		}
-
-		if err := peerInfo.rw.WriteMsg(p2p.Msg{Code: msgcode, Size: uint32(len(req.Data)), Payload: bytes.NewReader(req.Data)}); err != nil {
-			peerInfo.Remove()
-			ss.GoodPeers.Delete(peerID)
+	ss.rangePeers(func(peerID string, peerInfo *PeerInfo) bool {
+		if err := ss.writePeer(peerID, peerInfo, msgcode, req.Data); err != nil {
 			innerErr = err
 			return true
 		}
@@ -806,13 +801,7 @@ func (ss *SentryServerImpl) SetStatus(_ context.Context, statusData *proto_sentr
 }
 
 func (ss *SentryServerImpl) SimplePeerCount() (pc int) {
-	ss.GoodPeers.Range(func(key, value interface{}) bool {
-		peerID := key.(string)
-		x, _ := ss.GoodPeers.Load(peerID)
-		peerInfo, _ := x.(*PeerInfo)
-		if peerInfo == nil {
-			return true
-		}
+	ss.rangePeers(func(peerID string, peerInfo *PeerInfo) bool {
 		pc++
 		return true
 	})
-- 
GitLab