From a13e920af01692cb07a520cda688f1cc5b5469dd Mon Sep 17 00:00:00 2001
From: Felix Lange <fjl@twurst.com>
Date: Tue, 18 Apr 2017 13:37:10 +0200
Subject: [PATCH] trie: clean up iterator constructors

Make it so each iterator has exactly one public constructor:

- NodeIterators can be created through a method.
- Iterators can be created through NewIterator on any NodeIterator.
---
 core/state/dump.go     |  5 +++--
 core/state/iterator.go |  2 +-
 core/state/statedb.go  |  2 +-
 trie/iterator.go       | 15 ++++-----------
 trie/iterator_test.go  | 14 +++++++-------
 trie/secure_trie.go    |  6 +-----
 trie/sync_test.go      |  2 +-
 trie/trie.go           |  4 ++--
 trie/trie_test.go      |  2 +-
 9 files changed, 21 insertions(+), 31 deletions(-)

diff --git a/core/state/dump.go b/core/state/dump.go
index 8294d61b9..6338ddf88 100644
--- a/core/state/dump.go
+++ b/core/state/dump.go
@@ -22,6 +22,7 @@ import (
 
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/rlp"
+	"github.com/ethereum/go-ethereum/trie"
 )
 
 type DumpAccount struct {
@@ -44,7 +45,7 @@ func (self *StateDB) RawDump() Dump {
 		Accounts: make(map[string]DumpAccount),
 	}
 
-	it := self.trie.Iterator()
+	it := trie.NewIterator(self.trie.NodeIterator())
 	for it.Next() {
 		addr := self.trie.GetKey(it.Key)
 		var data Account
@@ -61,7 +62,7 @@ func (self *StateDB) RawDump() Dump {
 			Code:     common.Bytes2Hex(obj.Code(self.db)),
 			Storage:  make(map[string]string),
 		}
-		storageIt := obj.getTrie(self.db).Iterator()
+		storageIt := trie.NewIterator(obj.getTrie(self.db).NodeIterator())
 		for storageIt.Next() {
 			account.Storage[common.Bytes2Hex(self.trie.GetKey(storageIt.Key))] = common.Bytes2Hex(storageIt.Value)
 		}
diff --git a/core/state/iterator.go b/core/state/iterator.go
index 170aec983..d2dd5a74e 100644
--- a/core/state/iterator.go
+++ b/core/state/iterator.go
@@ -118,7 +118,7 @@ func (it *NodeIterator) step() error {
 	if err != nil {
 		return err
 	}
-	it.dataIt = trie.NewNodeIterator(dataTrie)
+	it.dataIt = dataTrie.NodeIterator()
 	if !it.dataIt.Next(true) {
 		it.dataIt = nil
 	}
diff --git a/core/state/statedb.go b/core/state/statedb.go
index 0c72fc6b0..24381ced5 100644
--- a/core/state/statedb.go
+++ b/core/state/statedb.go
@@ -481,7 +481,7 @@ func (db *StateDB) ForEachStorage(addr common.Address, cb func(key, value common
 		cb(h, value)
 	}
 
-	it := so.getTrie(db.db).Iterator()
+	it := trie.NewIterator(so.getTrie(db.db).NodeIterator())
 	for it.Next() {
 		// ignore cached values
 		key := common.BytesToHash(db.trie.GetKey(it.Key))
diff --git a/trie/iterator.go b/trie/iterator.go
index dd63a0c5a..fef5b2593 100644
--- a/trie/iterator.go
+++ b/trie/iterator.go
@@ -31,15 +31,8 @@ type Iterator struct {
 	Value []byte // Current data value on which the iterator is positioned on
 }
 
-// NewIterator creates a new key-value iterator.
-func NewIterator(trie *Trie) *Iterator {
-	return &Iterator{
-		nodeIt: NewNodeIterator(trie),
-	}
-}
-
-// FromNodeIterator creates a new key-value iterator from a node iterator
-func NewIteratorFromNodeIterator(it NodeIterator) *Iterator {
+// NewIterator creates a new key-value iterator from a node iterator
+func NewIterator(it NodeIterator) *Iterator {
 	return &Iterator{
 		nodeIt: it,
 	}
@@ -99,8 +92,8 @@ type nodeIterator struct {
 	path []byte // Path to the current node
 }
 
-// NewNodeIterator creates an post-order trie iterator.
-func NewNodeIterator(trie *Trie) NodeIterator {
+// newNodeIterator creates an post-order trie iterator.
+func newNodeIterator(trie *Trie) NodeIterator {
 	if trie.Hash() == emptyState {
 		return new(nodeIterator)
 	}
diff --git a/trie/iterator_test.go b/trie/iterator_test.go
index c101bb7b0..04d51aaf5 100644
--- a/trie/iterator_test.go
+++ b/trie/iterator_test.go
@@ -42,7 +42,7 @@ func TestIterator(t *testing.T) {
 	trie.Commit()
 
 	found := make(map[string]string)
-	it := NewIterator(trie)
+	it := NewIterator(trie.NodeIterator())
 	for it.Next() {
 		found[string(it.Key)] = string(it.Value)
 	}
@@ -72,7 +72,7 @@ func TestIteratorLargeData(t *testing.T) {
 		vals[string(value2.k)] = value2
 	}
 
-	it := NewIterator(trie)
+	it := NewIterator(trie.NodeIterator())
 	for it.Next() {
 		vals[string(it.Key)].t = true
 	}
@@ -99,7 +99,7 @@ func TestNodeIteratorCoverage(t *testing.T) {
 
 	// Gather all the node hashes found by the iterator
 	hashes := make(map[common.Hash]struct{})
-	for it := NewNodeIterator(trie); it.Next(true); {
+	for it := trie.NodeIterator(); it.Next(true); {
 		if it.Hash() != (common.Hash{}) {
 			hashes[it.Hash()] = struct{}{}
 		}
@@ -154,8 +154,8 @@ func TestDifferenceIterator(t *testing.T) {
 	trieb.Commit()
 
 	found := make(map[string]string)
-	di, _ := NewDifferenceIterator(NewNodeIterator(triea), NewNodeIterator(trieb))
-	it := NewIteratorFromNodeIterator(di)
+	di, _ := NewDifferenceIterator(triea.NodeIterator(), trieb.NodeIterator())
+	it := NewIterator(di)
 	for it.Next() {
 		found[string(it.Key)] = string(it.Value)
 	}
@@ -189,8 +189,8 @@ func TestUnionIterator(t *testing.T) {
 	}
 	trieb.Commit()
 
-	di, _ := NewUnionIterator([]NodeIterator{NewNodeIterator(triea), NewNodeIterator(trieb)})
-	it := NewIteratorFromNodeIterator(di)
+	di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(), trieb.NodeIterator()})
+	it := NewIterator(di)
 
 	all := []struct{ k, v string }{
 		{"aardvark", "c"},
diff --git a/trie/secure_trie.go b/trie/secure_trie.go
index 113fb6a1a..201716d18 100644
--- a/trie/secure_trie.go
+++ b/trie/secure_trie.go
@@ -156,12 +156,8 @@ func (t *SecureTrie) Root() []byte {
 	return t.trie.Root()
 }
 
-func (t *SecureTrie) Iterator() *Iterator {
-	return t.trie.Iterator()
-}
-
 func (t *SecureTrie) NodeIterator() NodeIterator {
-	return NewNodeIterator(&t.trie)
+	return t.trie.NodeIterator()
 }
 
 // CommitTo writes all nodes and the secure hash pre-images to the given database.
diff --git a/trie/sync_test.go b/trie/sync_test.go
index acae039cd..6d345ad3f 100644
--- a/trie/sync_test.go
+++ b/trie/sync_test.go
@@ -80,7 +80,7 @@ func checkTrieConsistency(db Database, root common.Hash) error {
 	if err != nil {
 		return nil // // Consider a non existent state consistent
 	}
-	it := NewNodeIterator(trie)
+	it := trie.NodeIterator()
 	for it.Next(true) {
 	}
 	return it.Error()
diff --git a/trie/trie.go b/trie/trie.go
index e61bd0383..dbffc0ac3 100644
--- a/trie/trie.go
+++ b/trie/trie.go
@@ -126,8 +126,8 @@ func New(root common.Hash, db Database) (*Trie, error) {
 }
 
 // Iterator returns an iterator over all mappings in the trie.
-func (t *Trie) Iterator() *Iterator {
-	return NewIterator(t)
+func (t *Trie) NodeIterator() NodeIterator {
+	return newNodeIterator(t)
 }
 
 // Get returns the value for key stored in the trie.
diff --git a/trie/trie_test.go b/trie/trie_test.go
index 01ae3a4e7..cacb08824 100644
--- a/trie/trie_test.go
+++ b/trie/trie_test.go
@@ -439,7 +439,7 @@ func runRandTest(rt randTest) bool {
 			tr = newtr
 		case opItercheckhash:
 			checktr, _ := New(common.Hash{}, nil)
-			it := tr.Iterator()
+			it := NewIterator(tr.NodeIterator())
 			for it.Next() {
 				checktr.Update(it.Key, it.Value)
 			}
-- 
GitLab