From f40ff23b7bf5396936159ae500ed5b18fdaec27d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?P=C3=A9ter=20Szil=C3=A1gyi?= <peterke@gmail.com>
Date: Wed, 18 Sep 2019 10:53:01 +0300
Subject: [PATCH] core: fix tx dedup return error count

---
 core/tx_pool.go      | 31 +++++++++++++++++----
 core/tx_pool_test.go | 65 ++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 90 insertions(+), 6 deletions(-)

diff --git a/core/tx_pool.go b/core/tx_pool.go
index d67be96e5..f7032dbd1 100644
--- a/core/tx_pool.go
+++ b/core/tx_pool.go
@@ -766,21 +766,40 @@ func (pool *TxPool) AddRemote(tx *types.Transaction) error {
 // addTxs attempts to queue a batch of transactions if they are valid.
 func (pool *TxPool) addTxs(txs []*types.Transaction, local, sync bool) []error {
 	// Filter out known ones without obtaining the pool lock or recovering signatures
-	for i := 0; i < len(txs); i++ {
-		if pool.all.Get(txs[i].Hash()) != nil {
+	var (
+		errs = make([]error, len(txs))
+		news = make([]*types.Transaction, 0, len(txs))
+	)
+	for i, tx := range txs {
+		// If the transaction is known, pre-set the error slot
+		if pool.all.Get(tx.Hash()) != nil {
+			errs[i] = fmt.Errorf("known transaction: %x", tx.Hash())
 			knownTxMeter.Mark(1)
-			txs = append(txs[:i], txs[i+1:]...)
-			i--
+			continue
 		}
+		// Accumulate all unknown transactions for deeper processing
+		news = append(news, tx)
+	}
+	if len(news) == 0 {
+		return errs
 	}
 	// Cache senders in transactions before obtaining lock (pool.signer is immutable)
-	for _, tx := range txs {
+	for _, tx := range news {
 		types.Sender(pool.signer, tx)
 	}
+	// Process all the new transaction and merge any errors into the original slice
 	pool.mu.Lock()
-	errs, dirtyAddrs := pool.addTxsLocked(txs, local)
+	newErrs, dirtyAddrs := pool.addTxsLocked(news, local)
 	pool.mu.Unlock()
 
+	var nilSlot = 0
+	for _, err := range newErrs {
+		for errs[nilSlot] != nil {
+			nilSlot++
+		}
+		errs[nilSlot] = err
+	}
+	// Reorg the pool internals if needed and return
 	done := pool.requestPromoteExecutables(dirtyAddrs)
 	if sync {
 		<-done
diff --git a/core/tx_pool_test.go b/core/tx_pool_test.go
index 388668ed8..0f1e7ac8f 100644
--- a/core/tx_pool_test.go
+++ b/core/tx_pool_test.go
@@ -1438,6 +1438,71 @@ func TestTransactionPoolStableUnderpricing(t *testing.T) {
 	}
 }
 
+// Tests that the pool rejects duplicate transactions.
+func TestTransactionDeduplication(t *testing.T) {
+	t.Parallel()
+
+	// Create the pool to test the pricing enforcement with
+	statedb, _ := state.New(common.Hash{}, state.NewDatabase(rawdb.NewMemoryDatabase()))
+	blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)}
+
+	pool := NewTxPool(testTxPoolConfig, params.TestChainConfig, blockchain)
+	defer pool.Stop()
+
+	// Create a test account to add transactions with
+	key, _ := crypto.GenerateKey()
+	pool.currentState.AddBalance(crypto.PubkeyToAddress(key.PublicKey), big.NewInt(1000000000))
+
+	// Create a batch of transactions and add a few of them
+	txs := make([]*types.Transaction, 16)
+	for i := 0; i < len(txs); i++ {
+		txs[i] = pricedTransaction(uint64(i), 100000, big.NewInt(1), key)
+	}
+	var firsts []*types.Transaction
+	for i := 0; i < len(txs); i += 2 {
+		firsts = append(firsts, txs[i])
+	}
+	errs := pool.AddRemotesSync(firsts)
+	if len(errs) != len(firsts) {
+		t.Fatalf("first add mismatching result count: have %d, want %d", len(errs), len(firsts))
+	}
+	for i, err := range errs {
+		if err != nil {
+			t.Errorf("add %d failed: %v", i, err)
+		}
+	}
+	pending, queued := pool.Stats()
+	if pending != 1 {
+		t.Fatalf("pending transactions mismatched: have %d, want %d", pending, 1)
+	}
+	if queued != len(txs)/2-1 {
+		t.Fatalf("queued transactions mismatched: have %d, want %d", queued, len(txs)/2-1)
+	}
+	// Try to add all of them now and ensure previous ones error out as knowns
+	errs = pool.AddRemotesSync(txs)
+	if len(errs) != len(txs) {
+		t.Fatalf("all add mismatching result count: have %d, want %d", len(errs), len(txs))
+	}
+	for i, err := range errs {
+		if i%2 == 0 && err == nil {
+			t.Errorf("add %d succeeded, should have failed as known", i)
+		}
+		if i%2 == 1 && err != nil {
+			t.Errorf("add %d failed: %v", i, err)
+		}
+	}
+	pending, queued = pool.Stats()
+	if pending != len(txs) {
+		t.Fatalf("pending transactions mismatched: have %d, want %d", pending, len(txs))
+	}
+	if queued != 0 {
+		t.Fatalf("queued transactions mismatched: have %d, want %d", queued, 0)
+	}
+	if err := validateTxPoolInternals(pool); err != nil {
+		t.Fatalf("pool internal state corrupted: %v", err)
+	}
+}
+
 // Tests that the pool rejects replacement transactions that don't meet the minimum
 // price bump required.
 func TestTransactionReplacement(t *testing.T) {
-- 
GitLab