From cc313e78b7dfbf44dec64ac0d22e4134ee44cc74 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?P=C3=A9ter=20Szil=C3=A1gyi?= <peterke@gmail.com>
Date: Tue, 5 Sep 2017 19:49:37 +0300
Subject: [PATCH] core: use blocks and avoid deep reorgs in txpool

---
 core/tx_pool.go      | 84 +++++++++++++++++++++++++-------------------
 core/tx_pool_test.go | 16 ++++-----
 2 files changed, 55 insertions(+), 45 deletions(-)

diff --git a/core/tx_pool.go b/core/tx_pool.go
index f41fbe069..0ad765179 100644
--- a/core/tx_pool.go
+++ b/core/tx_pool.go
@@ -19,6 +19,7 @@ package core
 import (
 	"errors"
 	"fmt"
+	"math"
 	"math/big"
 	"sort"
 	"sync"
@@ -105,11 +106,11 @@ var (
 // blockChain provides the state of blockchain and current gas limit to do
 // some pre checks in tx pool and event subscribers.
 type blockChain interface {
-	CurrentHeader() *types.Header
-	SubscribeChainHeadEvent(ch chan<- ChainHeadEvent) event.Subscription
-
+	CurrentBlock() *types.Block
 	GetBlock(hash common.Hash, number uint64) *types.Block
 	StateAt(root common.Hash) (*state.StateDB, error)
+
+	SubscribeChainHeadEvent(ch chan<- ChainHeadEvent) event.Subscription
 }
 
 // TxPoolConfig are the configuration parameters of the transaction pool.
@@ -223,7 +224,7 @@ func NewTxPool(config TxPoolConfig, chainconfig *params.ChainConfig, chain block
 	}
 	pool.locals = newAccountSet(pool.signer)
 	pool.priced = newTxPricedList(&pool.all)
-	pool.reset(nil, chain.CurrentHeader())
+	pool.reset(nil, chain.CurrentBlock().Header())
 
 	// If local transactions and journaling is enabled, load from disk
 	if !config.NoLocals && config.Journal != "" {
@@ -265,7 +266,7 @@ func (pool *TxPool) loop() {
 	defer journal.Stop()
 
 	// Track the previous head headers for transaction reorgs
-	head := pool.chain.CurrentHeader()
+	head := pool.chain.CurrentBlock()
 
 	// Keep waiting for and reacting to the various events
 	for {
@@ -277,8 +278,8 @@ func (pool *TxPool) loop() {
 				if pool.chainconfig.IsHomestead(ev.Block.Number()) {
 					pool.homestead = true
 				}
-				pool.reset(head, ev.Block.Header())
-				head = ev.Block.Header()
+				pool.reset(head.Header(), ev.Block.Header())
+				head = ev.Block
 
 				pool.mu.Unlock()
 			}
@@ -344,43 +345,52 @@ func (pool *TxPool) reset(oldHead, newHead *types.Header) {
 	var reinject types.Transactions
 
 	if oldHead != nil && oldHead.Hash() != newHead.ParentHash {
-		var discarded, included types.Transactions
-
-		var (
-			rem = pool.chain.GetBlock(oldHead.Hash(), oldHead.Number.Uint64())
-			add = pool.chain.GetBlock(newHead.Hash(), newHead.Number.Uint64())
-		)
-		for rem.NumberU64() > add.NumberU64() {
-			discarded = append(discarded, rem.Transactions()...)
-			if rem = pool.chain.GetBlock(rem.ParentHash(), rem.NumberU64()-1); rem == nil {
-				log.Error("Unrooted old chain seen by tx pool", "block", oldHead.Number, "hash", oldHead.Hash())
-				return
-			}
-		}
-		for add.NumberU64() > rem.NumberU64() {
-			included = append(included, add.Transactions()...)
-			if add = pool.chain.GetBlock(add.ParentHash(), add.NumberU64()-1); add == nil {
-				log.Error("Unrooted new chain seen by tx pool", "block", newHead.Number, "hash", newHead.Hash())
-				return
+		// If the reorg is too deep, avoid doing it (will happen during fast sync)
+		oldNum := oldHead.Number.Uint64()
+		newNum := newHead.Number.Uint64()
+
+		if depth := uint64(math.Abs(float64(oldNum) - float64(newNum))); depth > 64 {
+			log.Warn("Skipping deep transaction reorg", "depth", depth)
+		} else {
+			// Reorg seems shallow enough to pull in all transactions into memory
+			var discarded, included types.Transactions
+
+			var (
+				rem = pool.chain.GetBlock(oldHead.Hash(), oldHead.Number.Uint64())
+				add = pool.chain.GetBlock(newHead.Hash(), newHead.Number.Uint64())
+			)
+			for rem.NumberU64() > add.NumberU64() {
+				discarded = append(discarded, rem.Transactions()...)
+				if rem = pool.chain.GetBlock(rem.ParentHash(), rem.NumberU64()-1); rem == nil {
+					log.Error("Unrooted old chain seen by tx pool", "block", oldHead.Number, "hash", oldHead.Hash())
+					return
+				}
 			}
-		}
-		for rem.Hash() != add.Hash() {
-			discarded = append(discarded, rem.Transactions()...)
-			if rem = pool.chain.GetBlock(rem.ParentHash(), rem.NumberU64()-1); rem == nil {
-				log.Error("Unrooted old chain seen by tx pool", "block", oldHead.Number, "hash", oldHead.Hash())
-				return
+			for add.NumberU64() > rem.NumberU64() {
+				included = append(included, add.Transactions()...)
+				if add = pool.chain.GetBlock(add.ParentHash(), add.NumberU64()-1); add == nil {
+					log.Error("Unrooted new chain seen by tx pool", "block", newHead.Number, "hash", newHead.Hash())
+					return
+				}
 			}
-			included = append(included, add.Transactions()...)
-			if add = pool.chain.GetBlock(add.ParentHash(), add.NumberU64()-1); add == nil {
-				log.Error("Unrooted new chain seen by tx pool", "block", newHead.Number, "hash", newHead.Hash())
-				return
+			for rem.Hash() != add.Hash() {
+				discarded = append(discarded, rem.Transactions()...)
+				if rem = pool.chain.GetBlock(rem.ParentHash(), rem.NumberU64()-1); rem == nil {
+					log.Error("Unrooted old chain seen by tx pool", "block", oldHead.Number, "hash", oldHead.Hash())
+					return
+				}
+				included = append(included, add.Transactions()...)
+				if add = pool.chain.GetBlock(add.ParentHash(), add.NumberU64()-1); add == nil {
+					log.Error("Unrooted new chain seen by tx pool", "block", newHead.Number, "hash", newHead.Hash())
+					return
+				}
 			}
+			reinject = types.TxDifference(discarded, included)
 		}
-		reinject = types.TxDifference(discarded, included)
 	}
 	// Initialize the internal state to the current head
 	if newHead == nil {
-		newHead = pool.chain.CurrentHeader() // Special case during testing
+		newHead = pool.chain.CurrentBlock().Header() // Special case during testing
 	}
 	statedb, err := pool.chain.StateAt(newHead.Root)
 	if err != nil {
diff --git a/core/tx_pool_test.go b/core/tx_pool_test.go
index cdd45b4b1..17d736877 100644
--- a/core/tx_pool_test.go
+++ b/core/tx_pool_test.go
@@ -50,24 +50,24 @@ type testBlockChain struct {
 	chainHeadFeed *event.Feed
 }
 
-func (bc *testBlockChain) CurrentHeader() *types.Header {
-	return &types.Header{
+func (bc *testBlockChain) CurrentBlock() *types.Block {
+	return types.NewBlock(&types.Header{
 		GasLimit: bc.gasLimit,
-	}
-}
-
-func (bc *testBlockChain) SubscribeChainHeadEvent(ch chan<- ChainHeadEvent) event.Subscription {
-	return bc.chainHeadFeed.Subscribe(ch)
+	}, nil, nil, nil)
 }
 
 func (bc *testBlockChain) GetBlock(hash common.Hash, number uint64) *types.Block {
-	return types.NewBlock(bc.CurrentHeader(), nil, nil, nil)
+	return bc.CurrentBlock()
 }
 
 func (bc *testBlockChain) StateAt(common.Hash) (*state.StateDB, error) {
 	return bc.statedb, nil
 }
 
+func (bc *testBlockChain) SubscribeChainHeadEvent(ch chan<- ChainHeadEvent) event.Subscription {
+	return bc.chainHeadFeed.Subscribe(ch)
+}
+
 func transaction(nonce uint64, gaslimit *big.Int, key *ecdsa.PrivateKey) *types.Transaction {
 	return pricedTransaction(nonce, gaslimit, big.NewInt(1), key)
 }
-- 
GitLab