From 40a3856af9d28fce6550509a01cf926525da5d22 Mon Sep 17 00:00:00 2001
From: Miya Chen <miyatlchen@gmail.com>
Date: Tue, 10 Oct 2017 16:53:05 +0800
Subject: [PATCH] eth/fetcher: check the origin of filter tasks (#14975)

* eth/fetcher: check the origin of filter task

* eth/fetcher: add some details to fetcher logs
---
 eth/fetcher/fetcher.go      | 18 +++++----
 eth/fetcher/fetcher_test.go | 79 +++++++++++++++++++++----------------
 eth/handler.go              |  4 +-
 3 files changed, 56 insertions(+), 45 deletions(-)

diff --git a/eth/fetcher/fetcher.go b/eth/fetcher/fetcher.go
index 98cc1a76b..50966f5ee 100644
--- a/eth/fetcher/fetcher.go
+++ b/eth/fetcher/fetcher.go
@@ -83,6 +83,7 @@ type announce struct {
 
 // headerFilterTask represents a batch of headers needing fetcher filtering.
 type headerFilterTask struct {
+	peer    string          // The source peer of block headers
 	headers []*types.Header // Collection of headers to filter
 	time    time.Time       // Arrival time of the headers
 }
@@ -90,6 +91,7 @@ type headerFilterTask struct {
 // headerFilterTask represents a batch of block bodies (transactions and uncles)
 // needing fetcher filtering.
 type bodyFilterTask struct {
+	peer         string                 // The source peer of block bodies
 	transactions [][]*types.Transaction // Collection of transactions per block bodies
 	uncles       [][]*types.Header      // Collection of uncles per block bodies
 	time         time.Time              // Arrival time of the blocks' contents
@@ -218,8 +220,8 @@ func (f *Fetcher) Enqueue(peer string, block *types.Block) error {
 
 // FilterHeaders extracts all the headers that were explicitly requested by the fetcher,
 // returning those that should be handled differently.
-func (f *Fetcher) FilterHeaders(headers []*types.Header, time time.Time) []*types.Header {
-	log.Trace("Filtering headers", "headers", len(headers))
+func (f *Fetcher) FilterHeaders(peer string, headers []*types.Header, time time.Time) []*types.Header {
+	log.Trace("Filtering headers", "peer", peer, "headers", len(headers))
 
 	// Send the filter channel to the fetcher
 	filter := make(chan *headerFilterTask)
@@ -231,7 +233,7 @@ func (f *Fetcher) FilterHeaders(headers []*types.Header, time time.Time) []*type
 	}
 	// Request the filtering of the header list
 	select {
-	case filter <- &headerFilterTask{headers: headers, time: time}:
+	case filter <- &headerFilterTask{peer: peer, headers: headers, time: time}:
 	case <-f.quit:
 		return nil
 	}
@@ -246,8 +248,8 @@ func (f *Fetcher) FilterHeaders(headers []*types.Header, time time.Time) []*type
 
 // FilterBodies extracts all the block bodies that were explicitly requested by
 // the fetcher, returning those that should be handled differently.
-func (f *Fetcher) FilterBodies(transactions [][]*types.Transaction, uncles [][]*types.Header, time time.Time) ([][]*types.Transaction, [][]*types.Header) {
-	log.Trace("Filtering bodies", "txs", len(transactions), "uncles", len(uncles))
+func (f *Fetcher) FilterBodies(peer string, transactions [][]*types.Transaction, uncles [][]*types.Header, time time.Time) ([][]*types.Transaction, [][]*types.Header) {
+	log.Trace("Filtering bodies", "peer", peer, "txs", len(transactions), "uncles", len(uncles))
 
 	// Send the filter channel to the fetcher
 	filter := make(chan *bodyFilterTask)
@@ -259,7 +261,7 @@ func (f *Fetcher) FilterBodies(transactions [][]*types.Transaction, uncles [][]*
 	}
 	// Request the filtering of the body list
 	select {
-	case filter <- &bodyFilterTask{transactions: transactions, uncles: uncles, time: time}:
+	case filter <- &bodyFilterTask{peer: peer, transactions: transactions, uncles: uncles, time: time}:
 	case <-f.quit:
 		return nil, nil
 	}
@@ -444,7 +446,7 @@ func (f *Fetcher) loop() {
 				hash := header.Hash()
 
 				// Filter fetcher-requested headers from other synchronisation algorithms
-				if announce := f.fetching[hash]; announce != nil && f.fetched[hash] == nil && f.completing[hash] == nil && f.queued[hash] == nil {
+				if announce := f.fetching[hash]; announce != nil && announce.origin == task.peer && f.fetched[hash] == nil && f.completing[hash] == nil && f.queued[hash] == nil {
 					// If the delivered header does not match the promised number, drop the announcer
 					if header.Number.Uint64() != announce.number {
 						log.Trace("Invalid block number fetched", "peer", announce.origin, "hash", header.Hash(), "announced", announce.number, "provided", header.Number)
@@ -523,7 +525,7 @@ func (f *Fetcher) loop() {
 						txnHash := types.DeriveSha(types.Transactions(task.transactions[i]))
 						uncleHash := types.CalcUncleHash(task.uncles[i])
 
-						if txnHash == announce.header.TxHash && uncleHash == announce.header.UncleHash {
+						if txnHash == announce.header.TxHash && uncleHash == announce.header.UncleHash && announce.origin == task.peer {
 							// Mark the body matched, reassemble if still unknown
 							matched = true
 
diff --git a/eth/fetcher/fetcher_test.go b/eth/fetcher/fetcher_test.go
index 85d2f8645..9889e6cc5 100644
--- a/eth/fetcher/fetcher_test.go
+++ b/eth/fetcher/fetcher_test.go
@@ -153,7 +153,7 @@ func (f *fetcherTester) dropPeer(peer string) {
 }
 
 // makeHeaderFetcher retrieves a block header fetcher associated with a simulated peer.
-func (f *fetcherTester) makeHeaderFetcher(blocks map[common.Hash]*types.Block, drift time.Duration) headerRequesterFn {
+func (f *fetcherTester) makeHeaderFetcher(peer string, blocks map[common.Hash]*types.Block, drift time.Duration) headerRequesterFn {
 	closure := make(map[common.Hash]*types.Block)
 	for hash, block := range blocks {
 		closure[hash] = block
@@ -166,14 +166,14 @@ func (f *fetcherTester) makeHeaderFetcher(blocks map[common.Hash]*types.Block, d
 			headers = append(headers, block.Header())
 		}
 		// Return on a new thread
-		go f.fetcher.FilterHeaders(headers, time.Now().Add(drift))
+		go f.fetcher.FilterHeaders(peer, headers, time.Now().Add(drift))
 
 		return nil
 	}
 }
 
 // makeBodyFetcher retrieves a block body fetcher associated with a simulated peer.
-func (f *fetcherTester) makeBodyFetcher(blocks map[common.Hash]*types.Block, drift time.Duration) bodyRequesterFn {
+func (f *fetcherTester) makeBodyFetcher(peer string, blocks map[common.Hash]*types.Block, drift time.Duration) bodyRequesterFn {
 	closure := make(map[common.Hash]*types.Block)
 	for hash, block := range blocks {
 		closure[hash] = block
@@ -191,7 +191,7 @@ func (f *fetcherTester) makeBodyFetcher(blocks map[common.Hash]*types.Block, dri
 			}
 		}
 		// Return on a new thread
-		go f.fetcher.FilterBodies(transactions, uncles, time.Now().Add(drift))
+		go f.fetcher.FilterBodies(peer, transactions, uncles, time.Now().Add(drift))
 
 		return nil
 	}
@@ -282,8 +282,8 @@ func testSequentialAnnouncements(t *testing.T, protocol int) {
 	hashes, blocks := makeChain(targetBlocks, 0, genesis)
 
 	tester := newTester()
-	headerFetcher := tester.makeHeaderFetcher(blocks, -gatherSlack)
-	bodyFetcher := tester.makeBodyFetcher(blocks, 0)
+	headerFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack)
+	bodyFetcher := tester.makeBodyFetcher("valid", blocks, 0)
 
 	// Iteratively announce blocks until all are imported
 	imported := make(chan *types.Block)
@@ -309,22 +309,28 @@ func testConcurrentAnnouncements(t *testing.T, protocol int) {
 
 	// Assemble a tester with a built in counter for the requests
 	tester := newTester()
-	headerFetcher := tester.makeHeaderFetcher(blocks, -gatherSlack)
-	bodyFetcher := tester.makeBodyFetcher(blocks, 0)
+	firstHeaderFetcher := tester.makeHeaderFetcher("first", blocks, -gatherSlack)
+	firstBodyFetcher := tester.makeBodyFetcher("first", blocks, 0)
+	secondHeaderFetcher := tester.makeHeaderFetcher("second", blocks, -gatherSlack)
+	secondBodyFetcher := tester.makeBodyFetcher("second", blocks, 0)
 
 	counter := uint32(0)
-	headerWrapper := func(hash common.Hash) error {
+	firstHeaderWrapper := func(hash common.Hash) error {
+		atomic.AddUint32(&counter, 1)
+		return firstHeaderFetcher(hash)
+	}
+	secondHeaderWrapper := func(hash common.Hash) error {
 		atomic.AddUint32(&counter, 1)
-		return headerFetcher(hash)
+		return secondHeaderFetcher(hash)
 	}
 	// Iteratively announce blocks until all are imported
 	imported := make(chan *types.Block)
 	tester.fetcher.importedHook = func(block *types.Block) { imported <- block }
 
 	for i := len(hashes) - 2; i >= 0; i-- {
-		tester.fetcher.Notify("first", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), headerWrapper, bodyFetcher)
-		tester.fetcher.Notify("second", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout+time.Millisecond), headerWrapper, bodyFetcher)
-		tester.fetcher.Notify("second", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout-time.Millisecond), headerWrapper, bodyFetcher)
+		tester.fetcher.Notify("first", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout), firstHeaderWrapper, firstBodyFetcher)
+		tester.fetcher.Notify("second", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout+time.Millisecond), secondHeaderWrapper, secondBodyFetcher)
+		tester.fetcher.Notify("second", hashes[i], uint64(len(hashes)-i-1), time.Now().Add(-arriveTimeout-time.Millisecond), secondHeaderWrapper, secondBodyFetcher)
 		verifyImportEvent(t, imported, true)
 	}
 	verifyImportDone(t, imported)
@@ -347,8 +353,8 @@ func testOverlappingAnnouncements(t *testing.T, protocol int) {
 	hashes, blocks := makeChain(targetBlocks, 0, genesis)
 
 	tester := newTester()
-	headerFetcher := tester.makeHeaderFetcher(blocks, -gatherSlack)
-	bodyFetcher := tester.makeBodyFetcher(blocks, 0)
+	headerFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack)
+	bodyFetcher := tester.makeBodyFetcher("valid", blocks, 0)
 
 	// Iteratively announce blocks, but overlap them continuously
 	overlap := 16
@@ -381,8 +387,8 @@ func testPendingDeduplication(t *testing.T, protocol int) {
 
 	// Assemble a tester with a built in counter and delayed fetcher
 	tester := newTester()
-	headerFetcher := tester.makeHeaderFetcher(blocks, -gatherSlack)
-	bodyFetcher := tester.makeBodyFetcher(blocks, 0)
+	headerFetcher := tester.makeHeaderFetcher("repeater", blocks, -gatherSlack)
+	bodyFetcher := tester.makeBodyFetcher("repeater", blocks, 0)
 
 	delay := 50 * time.Millisecond
 	counter := uint32(0)
@@ -425,8 +431,8 @@ func testRandomArrivalImport(t *testing.T, protocol int) {
 	skip := targetBlocks / 2
 
 	tester := newTester()
-	headerFetcher := tester.makeHeaderFetcher(blocks, -gatherSlack)
-	bodyFetcher := tester.makeBodyFetcher(blocks, 0)
+	headerFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack)
+	bodyFetcher := tester.makeBodyFetcher("valid", blocks, 0)
 
 	// Iteratively announce blocks, skipping one entry
 	imported := make(chan *types.Block, len(hashes)-1)
@@ -456,8 +462,8 @@ func testQueueGapFill(t *testing.T, protocol int) {
 	skip := targetBlocks / 2
 
 	tester := newTester()
-	headerFetcher := tester.makeHeaderFetcher(blocks, -gatherSlack)
-	bodyFetcher := tester.makeBodyFetcher(blocks, 0)
+	headerFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack)
+	bodyFetcher := tester.makeBodyFetcher("valid", blocks, 0)
 
 	// Iteratively announce blocks, skipping one entry
 	imported := make(chan *types.Block, len(hashes)-1)
@@ -486,8 +492,8 @@ func testImportDeduplication(t *testing.T, protocol int) {
 
 	// Create the tester and wrap the importer with a counter
 	tester := newTester()
-	headerFetcher := tester.makeHeaderFetcher(blocks, -gatherSlack)
-	bodyFetcher := tester.makeBodyFetcher(blocks, 0)
+	headerFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack)
+	bodyFetcher := tester.makeBodyFetcher("valid", blocks, 0)
 
 	counter := uint32(0)
 	tester.fetcher.insertChain = func(blocks types.Blocks) (int, error) {
@@ -570,8 +576,8 @@ func testDistantAnnouncementDiscarding(t *testing.T, protocol int) {
 	tester.blocks = map[common.Hash]*types.Block{head: blocks[head]}
 	tester.lock.Unlock()
 
-	headerFetcher := tester.makeHeaderFetcher(blocks, -gatherSlack)
-	bodyFetcher := tester.makeBodyFetcher(blocks, 0)
+	headerFetcher := tester.makeHeaderFetcher("lower", blocks, -gatherSlack)
+	bodyFetcher := tester.makeBodyFetcher("lower", blocks, 0)
 
 	fetching := make(chan struct{}, 2)
 	tester.fetcher.fetchingHook = func(hashes []common.Hash) { fetching <- struct{}{} }
@@ -603,14 +609,14 @@ func testInvalidNumberAnnouncement(t *testing.T, protocol int) {
 	hashes, blocks := makeChain(1, 0, genesis)
 
 	tester := newTester()
-	headerFetcher := tester.makeHeaderFetcher(blocks, -gatherSlack)
-	bodyFetcher := tester.makeBodyFetcher(blocks, 0)
+	badHeaderFetcher := tester.makeHeaderFetcher("bad", blocks, -gatherSlack)
+	badBodyFetcher := tester.makeBodyFetcher("bad", blocks, 0)
 
 	imported := make(chan *types.Block)
 	tester.fetcher.importedHook = func(block *types.Block) { imported <- block }
 
 	// Announce a block with a bad number, check for immediate drop
-	tester.fetcher.Notify("bad", hashes[0], 2, time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher)
+	tester.fetcher.Notify("bad", hashes[0], 2, time.Now().Add(-arriveTimeout), badHeaderFetcher, badBodyFetcher)
 	verifyImportEvent(t, imported, false)
 
 	tester.lock.RLock()
@@ -620,8 +626,11 @@ func testInvalidNumberAnnouncement(t *testing.T, protocol int) {
 	if !dropped {
 		t.Fatalf("peer with invalid numbered announcement not dropped")
 	}
+
+	goodHeaderFetcher := tester.makeHeaderFetcher("good", blocks, -gatherSlack)
+	goodBodyFetcher := tester.makeBodyFetcher("good", blocks, 0)
 	// Make sure a good announcement passes without a drop
-	tester.fetcher.Notify("good", hashes[0], 1, time.Now().Add(-arriveTimeout), headerFetcher, bodyFetcher)
+	tester.fetcher.Notify("good", hashes[0], 1, time.Now().Add(-arriveTimeout), goodHeaderFetcher, goodBodyFetcher)
 	verifyImportEvent(t, imported, true)
 
 	tester.lock.RLock()
@@ -645,8 +654,8 @@ func testEmptyBlockShortCircuit(t *testing.T, protocol int) {
 	hashes, blocks := makeChain(32, 0, genesis)
 
 	tester := newTester()
-	headerFetcher := tester.makeHeaderFetcher(blocks, -gatherSlack)
-	bodyFetcher := tester.makeBodyFetcher(blocks, 0)
+	headerFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack)
+	bodyFetcher := tester.makeBodyFetcher("valid", blocks, 0)
 
 	// Add a monitoring hook for all internal events
 	fetching := make(chan []common.Hash)
@@ -697,12 +706,12 @@ func testHashMemoryExhaustionAttack(t *testing.T, protocol int) {
 	// Create a valid chain and an infinite junk chain
 	targetBlocks := hashLimit + 2*maxQueueDist
 	hashes, blocks := makeChain(targetBlocks, 0, genesis)
-	validHeaderFetcher := tester.makeHeaderFetcher(blocks, -gatherSlack)
-	validBodyFetcher := tester.makeBodyFetcher(blocks, 0)
+	validHeaderFetcher := tester.makeHeaderFetcher("valid", blocks, -gatherSlack)
+	validBodyFetcher := tester.makeBodyFetcher("valid", blocks, 0)
 
 	attack, _ := makeChain(targetBlocks, 0, unknownBlock)
-	attackerHeaderFetcher := tester.makeHeaderFetcher(nil, -gatherSlack)
-	attackerBodyFetcher := tester.makeBodyFetcher(nil, 0)
+	attackerHeaderFetcher := tester.makeHeaderFetcher("attacker", nil, -gatherSlack)
+	attackerBodyFetcher := tester.makeBodyFetcher("attacker", nil, 0)
 
 	// Feed the tester a huge hashset from the attacker, and a limited from the valid peer
 	for i := 0; i < len(attack); i++ {
diff --git a/eth/handler.go b/eth/handler.go
index cee719ddb..bec5126dc 100644
--- a/eth/handler.go
+++ b/eth/handler.go
@@ -450,7 +450,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 				return nil
 			}
 			// Irrelevant of the fork checks, send the header to the fetcher just in case
-			headers = pm.fetcher.FilterHeaders(headers, time.Now())
+			headers = pm.fetcher.FilterHeaders(p.id, headers, time.Now())
 		}
 		if len(headers) > 0 || !filter {
 			err := pm.downloader.DeliverHeaders(p.id, headers)
@@ -503,7 +503,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 		// Filter out any explicitly requested bodies, deliver the rest to the downloader
 		filter := len(trasactions) > 0 || len(uncles) > 0
 		if filter {
-			trasactions, uncles = pm.fetcher.FilterBodies(trasactions, uncles, time.Now())
+			trasactions, uncles = pm.fetcher.FilterBodies(p.id, trasactions, uncles, time.Now())
 		}
 		if len(trasactions) > 0 || len(uncles) > 0 || !filter {
 			err := pm.downloader.DeliverBodies(p.id, trasactions, uncles)
-- 
GitLab