From e4270cacf4aa26875affc619dbf82ad18d06226e Mon Sep 17 00:00:00 2001
From: rene <41963722+renaynay@users.noreply.github.com>
Date: Wed, 28 Apr 2021 21:38:38 +0200
Subject: [PATCH] cmd/devp2p: fix flaky SameRequestID test (#22754)

---
 cmd/devp2p/internal/ethtest/eth66_suite.go    | 71 ++++++++++++++-----
 .../internal/ethtest/eth66_suiteHelpers.go    | 38 ++++++----
 2 files changed, 78 insertions(+), 31 deletions(-)

diff --git a/cmd/devp2p/internal/ethtest/eth66_suite.go b/cmd/devp2p/internal/ethtest/eth66_suite.go
index 4265b25f6..41177189d 100644
--- a/cmd/devp2p/internal/ethtest/eth66_suite.go
+++ b/cmd/devp2p/internal/ethtest/eth66_suite.go
@@ -76,9 +76,14 @@ func (s *Suite) TestGetBlockHeaders_66(t *utesting.T) {
 		},
 	}
 	// write message
-	headers := s.getBlockHeaders66(t, conn, req, req.RequestId)
+	headers, err := s.getBlockHeaders66(conn, req, req.RequestId)
+	if err != nil {
+		t.Fatalf("could not get block headers: %v", err)
+	}
 	// check for correct headers
-	headersMatch(t, s.chain, headers)
+	if !headersMatch(t, s.chain, headers) {
+		t.Fatal("received wrong header(s)")
+	}
 }
 
 // TestSimultaneousRequests_66 sends two simultaneous `GetBlockHeader` requests
@@ -115,12 +120,25 @@ func (s *Suite) TestSimultaneousRequests_66(t *utesting.T) {
 	// wait for headers for first request
 	headerChan := make(chan BlockHeaders, 1)
 	go func(headers chan BlockHeaders) {
-		headers <- s.getBlockHeaders66(t, conn1, req1, req1.RequestId)
+		recvHeaders, err := s.getBlockHeaders66(conn1, req1, req1.RequestId)
+		if err != nil {
+			t.Fatalf("could not get block headers: %v", err)
+			return
+		}
+		headers <- recvHeaders
 	}(headerChan)
 	// check headers of second request
-	headersMatch(t, s.chain, s.getBlockHeaders66(t, conn2, req2, req2.RequestId))
+	headers1, err := s.getBlockHeaders66(conn2, req2, req2.RequestId)
+	if err != nil {
+		t.Fatalf("could not get block headers: %v", err)
+	}
+	if !headersMatch(t, s.chain, headers1) {
+		t.Fatal("wrong header(s) in response to req2")
+	}
 	// check headers of first request
-	headersMatch(t, s.chain, <-headerChan)
+	if !headersMatch(t, s.chain, <-headerChan) {
+		t.Fatal("wrong header(s) in response to req1")
+	}
 }
 
 // TestBroadcast_66 tests whether a block announcement is correctly
@@ -377,26 +395,31 @@ func (s *Suite) TestZeroRequestID_66(t *utesting.T) {
 			Amount: 2,
 		},
 	}
-	headersMatch(t, s.chain, s.getBlockHeaders66(t, conn, req, req.RequestId))
+	headers, err := s.getBlockHeaders66(conn, req, req.RequestId)
+	if err != nil {
+		t.Fatalf("could not get block headers: %v", err)
+	}
+	if !headersMatch(t, s.chain, headers) {
+		t.Fatal("received wrong header(s)")
+	}
 }
 
 // TestSameRequestID_66 sends two requests with the same request ID
 // concurrently to a single node.
 func (s *Suite) TestSameRequestID_66(t *utesting.T) {
 	conn := s.setupConnection66(t)
-	defer conn.Close()
-	// create two separate requests with same ID
+	// create two requests with the same request ID
 	reqID := uint64(1234)
-	req1 := &eth.GetBlockHeadersPacket66{
+	request1 := &eth.GetBlockHeadersPacket66{
 		RequestId: reqID,
 		GetBlockHeadersPacket: &eth.GetBlockHeadersPacket{
 			Origin: eth.HashOrNumber{
-				Number: 0,
+				Number: 1,
 			},
 			Amount: 2,
 		},
 	}
-	req2 := &eth.GetBlockHeadersPacket66{
+	request2 := &eth.GetBlockHeadersPacket66{
 		RequestId: reqID,
 		GetBlockHeadersPacket: &eth.GetBlockHeadersPacket{
 			Origin: eth.HashOrNumber{
@@ -405,12 +428,26 @@ func (s *Suite) TestSameRequestID_66(t *utesting.T) {
 			Amount: 2,
 		},
 	}
-	// send requests concurrently
-	go func() {
-		headersMatch(t, s.chain, s.getBlockHeaders66(t, conn, req2, reqID))
-	}()
-	// check response from first request
-	headersMatch(t, s.chain, s.getBlockHeaders66(t, conn, req1, reqID))
+	// write the first request
+	err := conn.write66(request1, GetBlockHeaders{}.Code())
+	if err != nil {
+		t.Fatalf("could not write to connection: %v", err)
+	}
+	// perform second request
+	headers2, err := s.getBlockHeaders66(conn, request2, reqID)
+	if err != nil {
+		t.Fatalf("could not get block headers: %v", err)
+		return
+	}
+	// wait for response to first request
+	headers1, err := s.waitForBlockHeadersResponse66(conn, reqID)
+	if err != nil {
+		t.Fatalf("could not get BlockHeaders response: %v", err)
+	}
+	// check if headers match
+	if !headersMatch(t, s.chain, headers1) || !headersMatch(t, s.chain, headers2) {
+		t.Fatal("received wrong header(s)")
+	}
 }
 
 // TestLargeTxRequest_66 tests whether a node can fulfill a large GetPooledTransactions
diff --git a/cmd/devp2p/internal/ethtest/eth66_suiteHelpers.go b/cmd/devp2p/internal/ethtest/eth66_suiteHelpers.go
index 3af8295c6..fec02b524 100644
--- a/cmd/devp2p/internal/ethtest/eth66_suiteHelpers.go
+++ b/cmd/devp2p/internal/ethtest/eth66_suiteHelpers.go
@@ -18,6 +18,7 @@ package ethtest
 
 import (
 	"fmt"
+	"reflect"
 	"time"
 
 	"github.com/ethereum/go-ethereum/core/types"
@@ -150,8 +151,7 @@ func (c *Conn) waitForResponse(chain *Chain, timeout time.Duration, requestID ui
 func (c *Conn) readAndServe66(chain *Chain, timeout time.Duration) (uint64, Message) {
 	start := time.Now()
 	for time.Since(start) < timeout {
-		timeout := time.Now().Add(10 * time.Second)
-		c.SetReadDeadline(timeout)
+		c.SetReadDeadline(time.Now().Add(10 * time.Second))
 
 		reqID, msg := c.read66()
 
@@ -257,6 +257,9 @@ func (c *Conn) waitForBlock66(block *types.Block) error {
 				return nil
 			}
 			time.Sleep(100 * time.Millisecond)
+		case *NewPooledTransactionHashes:
+			// ignore old announcements
+			continue
 		default:
 			return fmt.Errorf("invalid message: %s", pretty.Sdump(msg))
 		}
@@ -269,31 +272,38 @@ func sendSuccessfulTx66(t *utesting.T, s *Suite, tx *types.Transaction) {
 	sendSuccessfulTxWithConn(t, s, tx, sendConn)
 }
 
-func (s *Suite) getBlockHeaders66(t *utesting.T, conn *Conn, req eth.Packet, expectedID uint64) BlockHeaders {
-	if err := conn.write66(req, GetBlockHeaders{}.Code()); err != nil {
-		t.Fatalf("could not write to connection: %v", err)
-	}
-	// check block headers response
+// waitForBlockHeadersResponse66 waits for a BlockHeaders message with the given expected request ID
+func (s *Suite) waitForBlockHeadersResponse66(conn *Conn, expectedID uint64) (BlockHeaders, error) {
 	reqID, msg := conn.readAndServe66(s.chain, timeout)
-
 	switch msg := msg.(type) {
 	case BlockHeaders:
 		if reqID != expectedID {
-			t.Fatalf("request ID mismatch: wanted %d, got %d", expectedID, reqID)
+			return nil, fmt.Errorf("request ID mismatch: wanted %d, got %d", expectedID, reqID)
 		}
-		return msg
+		return msg, nil
 	default:
-		t.Fatalf("unexpected: %s", pretty.Sdump(msg))
-		return nil
+		return nil, fmt.Errorf("unexpected: %s", pretty.Sdump(msg))
 	}
 }
 
-func headersMatch(t *utesting.T, chain *Chain, headers BlockHeaders) {
+func (s *Suite) getBlockHeaders66(conn *Conn, req eth.Packet, expectedID uint64) (BlockHeaders, error) {
+	if err := conn.write66(req, GetBlockHeaders{}.Code()); err != nil {
+		return nil, fmt.Errorf("could not write to connection: %v", err)
+	}
+	return s.waitForBlockHeadersResponse66(conn, expectedID)
+}
+
+func headersMatch(t *utesting.T, chain *Chain, headers BlockHeaders) bool {
+	mismatched := 0
 	for _, header := range headers {
 		num := header.Number.Uint64()
 		t.Logf("received header (%d): %s", num, pretty.Sdump(header.Hash()))
-		assert.Equal(t, chain.blocks[int(num)].Header(), header)
+		if !reflect.DeepEqual(chain.blocks[int(num)].Header(), header) {
+			mismatched += 1
+			t.Logf("received wrong header: %v", pretty.Sdump(header))
+		}
 	}
+	return mismatched == 0
 }
 
 func (s *Suite) sendNextBlock66(t *utesting.T) {
-- 
GitLab