From b9d4412715ccacd0a8adac13b0b587db193936c6 Mon Sep 17 00:00:00 2001
From: Martin Holst Swende <martin@swende.se>
Date: Wed, 7 Jul 2021 17:28:14 +0200
Subject: [PATCH] cmd/devp2p: fixes for eth and discv4 tests (#23155)

This PR fixes a false positive PONG 'to' endpoint mismatch seen in hive tests:

    got {IP:172.17.0.7 UDP:44025 TCP:44025}, want {IP:172.17.0.7 UDP:44025 TCP:0}

Co-authored-by: Felix Lange <fjl@twurst.com>
---
 cmd/devp2p/internal/ethtest/helpers.go    | 81 +++++++++++++----------
 cmd/devp2p/internal/v4test/discv4tests.go | 13 ++--
 2 files changed, 53 insertions(+), 41 deletions(-)

diff --git a/cmd/devp2p/internal/ethtest/helpers.go b/cmd/devp2p/internal/ethtest/helpers.go
index a9a213f33..6f7365483 100644
--- a/cmd/devp2p/internal/ethtest/helpers.go
+++ b/cmd/devp2p/internal/ethtest/helpers.go
@@ -24,6 +24,7 @@ import (
 	"time"
 
 	"github.com/davecgh/go-spew/spew"
+	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/core/types"
 	"github.com/ethereum/go-ethereum/crypto"
 	"github.com/ethereum/go-ethereum/eth/protocols/eth"
@@ -649,58 +650,68 @@ func (s *Suite) hashAnnounce(isEth66 bool) error {
 		return fmt.Errorf("peering failed: %v", err)
 	}
 	// create NewBlockHashes announcement
-	nextBlock := s.fullChain.blocks[s.chain.Len()]
-	newBlockHash := &NewBlockHashes{
-		{Hash: nextBlock.Hash(), Number: nextBlock.Number().Uint64()},
+	type anno struct {
+		Hash   common.Hash // Hash of one particular block being announced
+		Number uint64      // Number of one particular block being announced
 	}
-
+	nextBlock := s.fullChain.blocks[s.chain.Len()]
+	announcement := anno{Hash: nextBlock.Hash(), Number: nextBlock.Number().Uint64()}
+	newBlockHash := &NewBlockHashes{announcement}
 	if err := sendConn.Write(newBlockHash); err != nil {
 		return fmt.Errorf("failed to write to connection: %v", err)
 	}
+	// Announcement sent, now wait for a header request
+	var (
+		id             uint64
+		msg            Message
+		blockHeaderReq GetBlockHeaders
+	)
 	if isEth66 {
-		// expect GetBlockHeaders request, and respond
-		id, msg := sendConn.Read66()
+		id, msg = sendConn.Read66()
 		switch msg := msg.(type) {
 		case GetBlockHeaders:
-			blockHeaderReq := msg
-			if blockHeaderReq.Amount != 1 {
-				return fmt.Errorf("unexpected number of block headers requested: %v", blockHeaderReq.Amount)
-			}
-			if blockHeaderReq.Origin.Hash != nextBlock.Hash() {
-				return fmt.Errorf("unexpected block header requested: %v", pretty.Sdump(blockHeaderReq))
-			}
-			resp := &eth.BlockHeadersPacket66{
-				RequestId: id,
-				BlockHeadersPacket: eth.BlockHeadersPacket{
-					nextBlock.Header(),
-				},
-			}
-			if err := sendConn.Write66(resp, BlockHeaders{}.Code()); err != nil {
-				return fmt.Errorf("failed to write to connection: %v", err)
-			}
+			blockHeaderReq = msg
 		default:
 			return fmt.Errorf("unexpected %s", pretty.Sdump(msg))
 		}
+		if blockHeaderReq.Amount != 1 {
+			return fmt.Errorf("unexpected number of block headers requested: %v", blockHeaderReq.Amount)
+		}
+		if blockHeaderReq.Origin.Hash != announcement.Hash {
+			return fmt.Errorf("unexpected block header requested. Announced:\n %v\n Remote request:\n%v",
+				pretty.Sdump(announcement),
+				pretty.Sdump(blockHeaderReq))
+		}
+		if err := sendConn.Write66(&eth.BlockHeadersPacket66{
+			RequestId: id,
+			BlockHeadersPacket: eth.BlockHeadersPacket{
+				nextBlock.Header(),
+			},
+		}, BlockHeaders{}.Code()); err != nil {
+			return fmt.Errorf("failed to write to connection: %v", err)
+		}
 	} else {
-		// expect GetBlockHeaders request, and respond
-		switch msg := sendConn.Read().(type) {
+		msg = sendConn.Read()
+		switch msg := msg.(type) {
 		case *GetBlockHeaders:
-			blockHeaderReq := *msg
-			if blockHeaderReq.Amount != 1 {
-				return fmt.Errorf("unexpected number of block headers requested: %v", blockHeaderReq.Amount)
-			}
-			if blockHeaderReq.Origin.Hash != nextBlock.Hash() {
-				return fmt.Errorf("unexpected block header requested: %v", pretty.Sdump(blockHeaderReq))
-			}
-			if err := sendConn.Write(&BlockHeaders{nextBlock.Header()}); err != nil {
-				return fmt.Errorf("failed to write to connection: %v", err)
-			}
+			blockHeaderReq = *msg
 		default:
 			return fmt.Errorf("unexpected %s", pretty.Sdump(msg))
 		}
+		if blockHeaderReq.Amount != 1 {
+			return fmt.Errorf("unexpected number of block headers requested: %v", blockHeaderReq.Amount)
+		}
+		if blockHeaderReq.Origin.Hash != announcement.Hash {
+			return fmt.Errorf("unexpected block header requested. Announced:\n %v\n Remote request:\n%v",
+				pretty.Sdump(announcement),
+				pretty.Sdump(blockHeaderReq))
+		}
+		if err := sendConn.Write(&BlockHeaders{nextBlock.Header()}); err != nil {
+			return fmt.Errorf("failed to write to connection: %v", err)
+		}
 	}
 	// wait for block announcement
-	msg := recvConn.readAndServe(s.chain, timeout)
+	msg = recvConn.readAndServe(s.chain, timeout)
 	switch msg := msg.(type) {
 	case *NewBlockHashes:
 		hashes := *msg
diff --git a/cmd/devp2p/internal/v4test/discv4tests.go b/cmd/devp2p/internal/v4test/discv4tests.go
index 140b96bfa..1b5e5304e 100644
--- a/cmd/devp2p/internal/v4test/discv4tests.go
+++ b/cmd/devp2p/internal/v4test/discv4tests.go
@@ -21,7 +21,6 @@ import (
 	"crypto/rand"
 	"fmt"
 	"net"
-	"reflect"
 	"time"
 
 	"github.com/ethereum/go-ethereum/crypto"
@@ -89,16 +88,18 @@ func BasicPing(t *utesting.T) {
 
 // checkPong verifies that reply is a valid PONG matching the given ping hash.
 func (te *testenv) checkPong(reply v4wire.Packet, pingHash []byte) error {
-	if reply == nil || reply.Kind() != v4wire.PongPacket {
-		return fmt.Errorf("expected PONG reply, got %v", reply)
+	if reply == nil {
+		return fmt.Errorf("expected PONG reply, got nil")
+	}
+	if reply.Kind() != v4wire.PongPacket {
+		return fmt.Errorf("expected PONG reply, got %v %v", reply.Name(), reply)
 	}
 	pong := reply.(*v4wire.Pong)
 	if !bytes.Equal(pong.ReplyTok, pingHash) {
 		return fmt.Errorf("PONG reply token mismatch: got %x, want %x", pong.ReplyTok, pingHash)
 	}
-	wantEndpoint := te.localEndpoint(te.l1)
-	if !reflect.DeepEqual(pong.To, wantEndpoint) {
-		return fmt.Errorf("PONG 'to' endpoint mismatch: got %+v, want %+v", pong.To, wantEndpoint)
+	if want := te.localEndpoint(te.l1); !want.IP.Equal(pong.To.IP) || want.UDP != pong.To.UDP {
+		return fmt.Errorf("PONG 'to' endpoint mismatch: got %+v, want %+v", pong.To, want)
 	}
 	if v4wire.Expired(pong.Expiration) {
 		return fmt.Errorf("PONG is expired (%v)", pong.Expiration)
-- 
GitLab