diff --git a/core/state/snapshot/difflayer.go b/core/state/snapshot/difflayer.go
index 05d55a6fa38e0646d84387fb22af2d0a7e647eaa..855d862de2c879530e06173b3221da3f14d0fe97 100644
--- a/core/state/snapshot/difflayer.go
+++ b/core/state/snapshot/difflayer.go
@@ -229,6 +229,11 @@ func (dl *diffLayer) Root() common.Hash {
 	return dl.root
 }
 
+// Parent returns the subsequent layer of a diff layer.
+func (dl *diffLayer) Parent() snapshot {
+	return dl.parent
+}
+
 // Stale return whether this layer has become stale (was flattened across) or if
 // it's still live.
 func (dl *diffLayer) Stale() bool {
@@ -405,7 +410,7 @@ func (dl *diffLayer) flatten() snapshot {
 	for hash, data := range dl.accountData {
 		parent.accountData[hash] = data
 	}
-	// Overwrite all the updates storage slots (individually)
+	// Overwrite all the updated storage slots (individually)
 	for accountHash, storage := range dl.storageData {
 		// If storage didn't exist (or was deleted) in the parent; or if the storage
 		// was freshly deleted in the child, overwrite blindly
@@ -425,53 +430,62 @@ func (dl *diffLayer) flatten() snapshot {
 		parent:      parent.parent,
 		origin:      parent.origin,
 		root:        dl.root,
-		storageList: parent.storageList,
-		storageData: parent.storageData,
-		accountList: parent.accountList,
 		accountData: parent.accountData,
+		storageData: parent.storageData,
+		storageList: make(map[common.Hash][]common.Hash),
 		diffed:      dl.diffed,
 		memory:      parent.memory + dl.memory,
 	}
 }
 
-// AccountList returns a sorted list of all accounts in this difflayer.
+// AccountList returns a sorted list of all accounts in this difflayer, including
+// the deleted ones.
+//
+// Note, the returned slice is not a copy, so do not modify it.
 func (dl *diffLayer) AccountList() []common.Hash {
+	// If an old list already exists, return it
+	dl.lock.RLock()
+	list := dl.accountList
+	dl.lock.RUnlock()
+
+	if list != nil {
+		return list
+	}
+	// No old sorted account list exists, generate a new one
 	dl.lock.Lock()
 	defer dl.lock.Unlock()
-	if dl.accountList != nil {
-		return dl.accountList
-	}
-	accountList := make([]common.Hash, len(dl.accountData))
-	i := 0
-	for k, _ := range dl.accountData {
-		accountList[i] = k
-		i++
-		// This would be a pretty good opportunity to also
-		// calculate the size, if we want to
+
+	dl.accountList = make([]common.Hash, 0, len(dl.accountData))
+	for hash := range dl.accountData {
+		dl.accountList = append(dl.accountList, hash)
 	}
-	sort.Sort(hashes(accountList))
-	dl.accountList = accountList
+	sort.Sort(hashes(dl.accountList))
 	return dl.accountList
 }
 
-// StorageList returns a sorted list of all storage slot hashes
-// in this difflayer for the given account.
+// StorageList returns a sorted list of all storage slot hashes in this difflayer
+// for the given account.
+//
+// Note, the returned slice is not a copy, so do not modify it.
 func (dl *diffLayer) StorageList(accountHash common.Hash) []common.Hash {
+	// If an old list already exists, return it
+	dl.lock.RLock()
+	list := dl.storageList[accountHash]
+	dl.lock.RUnlock()
+
+	if list != nil {
+		return list
+	}
+	// No old sorted storage list exists, generate a new one
 	dl.lock.Lock()
 	defer dl.lock.Unlock()
-	if dl.storageList[accountHash] != nil {
-		return dl.storageList[accountHash]
-	}
-	accountStorageMap := dl.storageData[accountHash]
-	accountStorageList := make([]common.Hash, len(accountStorageMap))
-	i := 0
-	for k, _ := range accountStorageMap {
-		accountStorageList[i] = k
-		i++
-		// This would be a pretty good opportunity to also
-		// calculate the size, if we want to
+
+	storageMap := dl.storageData[accountHash]
+	storageList := make([]common.Hash, 0, len(storageMap))
+	for k := range storageMap {
+		storageList = append(storageList, k)
 	}
-	sort.Sort(hashes(accountStorageList))
-	dl.storageList[accountHash] = accountStorageList
-	return accountStorageList
+	sort.Sort(hashes(storageList))
+	dl.storageList[accountHash] = storageList
+	return storageList
 }
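
Note: the rewritten AccountList and StorageList above both follow the same lazily-built, cached sorted-list pattern (read-lock fast path, write-lock rebuild). A minimal standalone sketch of that pattern, using a hypothetical layer type that only mirrors the locking shape, could look like this:

package main

import (
	"fmt"
	"sort"
	"sync"
)

// layer is a hypothetical stand-in for a diff layer: an unsorted map as the
// source of truth plus a lazily built, cached sorted key list.
type layer struct {
	lock sync.RWMutex
	data map[string][]byte
	list []string
}

// Keys returns a sorted list of all keys in the layer. The returned slice is
// shared with the cache, so callers must not modify it.
func (l *layer) Keys() []string {
	// Fast path: hand out the cached list under a read lock if it exists.
	l.lock.RLock()
	list := l.list
	l.lock.RUnlock()

	if list != nil {
		return list
	}
	// Slow path: rebuild the sorted list under the write lock. Two racing
	// callers may both rebuild it, which is wasteful but harmless.
	l.lock.Lock()
	defer l.lock.Unlock()

	l.list = make([]string, 0, len(l.data))
	for k := range l.data {
		l.list = append(l.list, k)
	}
	sort.Strings(l.list)
	return l.list
}

func main() {
	l := &layer{data: map[string][]byte{"0xff": nil, "0xaa": nil, "0xee": nil}}
	fmt.Println(l.Keys()) // [0xaa 0xee 0xff]
}
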
diff --git a/core/state/snapshot/difflayer_test.go b/core/state/snapshot/difflayer_test.go
index 7d7b21eb05818c634f6593217b5d9ee007ec0794..80a9b4093f26ab8f9b968c3f4773e792f56c8dee 100644
--- a/core/state/snapshot/difflayer_test.go
+++ b/core/state/snapshot/difflayer_test.go
@@ -18,7 +18,6 @@ package snapshot
 
 import (
 	"bytes"
-	"math/big"
 	"math/rand"
 	"testing"
 
@@ -26,21 +25,8 @@ import (
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/crypto"
 	"github.com/ethereum/go-ethereum/ethdb/memorydb"
-	"github.com/ethereum/go-ethereum/rlp"
 )
 
-func randomAccount() []byte {
-	root := randomHash()
-	a := Account{
-		Balance:  big.NewInt(rand.Int63()),
-		Nonce:    rand.Uint64(),
-		Root:     root[:],
-		CodeHash: emptyCode[:],
-	}
-	data, _ := rlp.EncodeToBytes(a)
-	return data
-}
-
 // TestMergeBasics tests some simple merges
 func TestMergeBasics(t *testing.T) {
 	var (
diff --git a/core/state/snapshot/disklayer.go b/core/state/snapshot/disklayer.go
index 7c5b3e3e91a2fc80050d670cbd3d36bbdb7f9e48..0c4c3deb1b03a6312b3db8cb14e47bd707eb59bf 100644
--- a/core/state/snapshot/disklayer.go
+++ b/core/state/snapshot/disklayer.go
@@ -48,6 +48,11 @@ func (dl *diskLayer) Root() common.Hash {
 	return dl.root
 }
 
+// Parent always returns nil as there's no layer below the disk.
+func (dl *diskLayer) Parent() snapshot {
+	return nil
+}
+
 // Stale return whether this layer has become stale (was flattened across) or if
 // it's still live.
 func (dl *diskLayer) Stale() bool {
diff --git a/core/state/snapshot/iterator.go b/core/state/snapshot/iterator.go
index 6df7b3147e027f4559fa66c231de6276f68a0a10..4005cb3ca196415ff43c4dc8dd709b1eef95708b 100644
--- a/core/state/snapshot/iterator.go
+++ b/core/state/snapshot/iterator.go
@@ -18,18 +18,17 @@ package snapshot
 
 import (
 	"bytes"
+	"fmt"
 	"sort"
 
 	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/core/rawdb"
+	"github.com/ethereum/go-ethereum/ethdb"
 )
 
 // AccountIterator is an iterator to step over all the accounts in a snapshot,
 // which may or may not be composed of multiple layers.
 type AccountIterator interface {
-	// Seek steps the iterator forward as many elements as needed, so that after
-	// calling Next(), the iterator will be at a key higher than the given hash.
-	Seek(hash common.Hash)
-
 	// Next steps the iterator forward one element, returning false if exhausted,
 	// or an error if iteration failed for some reason (e.g. root being iterated
 	// becomes stale and garbage collected).
@@ -39,78 +38,159 @@ type AccountIterator interface {
 	// caused a premature iteration exit (e.g. snapshot stack becoming stale).
 	Error() error
 
-	// Key returns the hash of the account the iterator is currently at.
-	Key() common.Hash
+	// Hash returns the hash of the account the iterator is currently at.
+	Hash() common.Hash
 
-	// Value returns the RLP encoded slim account the iterator is currently at.
+	// Account returns the RLP encoded slim account the iterator is currently at.
 	// An error will be returned if the iterator becomes invalid (e.g. the
 	// snapshot stack becoming stale).
-	Value() []byte
+	Account() []byte
+
+	// Release releases associated resources. Release should always succeed and
+	// can be called multiple times without causing error.
+	Release()
 }
 
 // diffAccountIterator is an account iterator that steps over the accounts (both
-// live and deleted) contained within a single
+// live and deleted) contained within a single diff layer. Higher order iterators
+// will use the deleted accounts to skip deeper iterators.
 type diffAccountIterator struct {
-	layer *diffLayer
-	index int
+	// curHash is the current hash the iterator is positioned on. The field is
+	// explicitly tracked since the referenced diff layer might go stale after
+	// the iterator was positioned and we don't want to fail accessing the old
+	// hash as long as the iterator is not touched any more.
+	curHash common.Hash
+
+	// curAccount is the current value the iterator is positioned on. The field
+	// is explicitly tracked since the referenced diff layer might go stale after
+	// the iterator was positioned and we don't want to fail accessing the old
+	// value as long as the iterator is not touched any more.
+	curAccount []byte
+
+	layer *diffLayer    // Live layer to retrieve values from
+	keys  []common.Hash // Keys left in the layer to iterate
+	fail  error         // Any failures encountered (stale)
 }
 
-func (dl *diffLayer) newAccountIterator() *diffAccountIterator {
-	dl.AccountList()
-	return &diffAccountIterator{layer: dl, index: -1}
-}
-
-// Seek steps the iterator forward as many elements as needed, so that after
-// calling Next(), the iterator will be at a key higher than the given hash.
-func (it *diffAccountIterator) Seek(key common.Hash) {
-	// Search uses binary search to find and return the smallest index i
-	// in [0, n) at which f(i) is true
-	index := sort.Search(len(it.layer.accountList), func(i int) bool {
-		return bytes.Compare(key[:], it.layer.accountList[i][:]) < 0
+// AccountIterator creates an account iterator over a single diff layer.
+func (dl *diffLayer) AccountIterator(seek common.Hash) AccountIterator {
+	// Seek out the requested starting account
+	hashes := dl.AccountList()
+	index := sort.Search(len(hashes), func(i int) bool {
+		return bytes.Compare(seek[:], hashes[i][:]) < 0
 	})
-	it.index = index - 1
+	// Assemble and return the already seeked iterator
+	return &diffAccountIterator{
+		layer: dl,
+		keys:  hashes[index:],
+	}
 }
 
 // Next steps the iterator forward one element, returning false if exhausted.
 func (it *diffAccountIterator) Next() bool {
-	if it.index < len(it.layer.accountList) {
-		it.index++
+	// If the iterator was already stale, consider it a programmer error. Although
+	// we could just return false here, triggering this path would probably mean
+	// somebody forgot to check for Error, so let's blow up instead of undefined
+	// behavior that's hard to debug.
+	if it.fail != nil {
+		panic(fmt.Sprintf("called Next of failed iterator: %v", it.fail))
+	}
+	// Stop iterating if all keys were exhausted
+	if len(it.keys) == 0 {
+		return false
+	}
+	// Iterator seems to be still alive, retrieve and cache the live hash and
+	// account value, or fail now if layer became stale
+	it.layer.lock.RLock()
+	defer it.layer.lock.RUnlock()
+
+	if it.layer.stale {
+		it.fail, it.keys = ErrSnapshotStale, nil
+		return false
 	}
-	return it.index < len(it.layer.accountList)
+	it.curHash = it.keys[0]
+	blob, ok := it.layer.accountData[it.curHash]
+	if !ok {
+		panic(fmt.Sprintf("iterator referenced non-existent account: %x", it.curHash))
+	}
+	it.curAccount = blob
+	// Values cached, shift the iterator and notify the user of success
+	it.keys = it.keys[1:]
+	return true
 }
 
 // Error returns any failure that occurred during iteration, which might have
 // caused a premature iteration exit (e.g. snapshot stack becoming stale).
-//
-// A diff layer is immutable after creation content wise and can always be fully
-// iterated without error, so this method always returns nil.
 func (it *diffAccountIterator) Error() error {
-	return nil
+	return it.fail
 }
 
-// Key returns the hash of the account the iterator is currently at.
-func (it *diffAccountIterator) Key() common.Hash {
-	if it.index < len(it.layer.accountList) {
-		return it.layer.accountList[it.index]
-	}
-	return common.Hash{}
+// Hash returns the hash of the account the iterator is currently at.
+func (it *diffAccountIterator) Hash() common.Hash {
+	return it.curHash
 }
 
-// Value returns the RLP encoded slim account the iterator is currently at.
-func (it *diffAccountIterator) Value() []byte {
-	it.layer.lock.RLock()
-	defer it.layer.lock.RUnlock()
+// Account returns the RLP encoded slim account the iterator is currently at.
+func (it *diffAccountIterator) Account() []byte {
+	return it.curAccount
+}
+
+// Release is a noop for diff account iterators as there are no held resources.
+func (it *diffAccountIterator) Release() {}
 
-	hash := it.layer.accountList[it.index]
-	if data, ok := it.layer.accountData[hash]; ok {
-		return data
+// diskAccountIterator is an account iterator that steps over the live accounts
+// contained within a disk layer.
+type diskAccountIterator struct {
+	layer *diskLayer
+	it    ethdb.Iterator
+}
+
+// AccountIterator creates an account iterator over a disk layer.
+func (dl *diskLayer) AccountIterator(seek common.Hash) AccountIterator {
+	return &diskAccountIterator{
+		layer: dl,
+		it:    dl.diskdb.NewIteratorWithPrefix(append(rawdb.SnapshotAccountPrefix, seek[:]...)),
 	}
-	panic("iterator references non-existent layer account")
 }
 
-func (dl *diffLayer) iterators() []AccountIterator {
-	if parent, ok := dl.parent.(*diffLayer); ok {
-		iterators := parent.iterators()
-		return append(iterators, dl.newAccountIterator())
+// Next steps the iterator forward one element, returning false if exhausted.
+func (it *diskAccountIterator) Next() bool {
+	// If the iterator was already exhausted, don't bother
+	if it.it == nil {
+		return false
+	}
+	// Try to advance the iterator and release it if we reached the end
+	if !it.it.Next() || !bytes.HasPrefix(it.it.Key(), rawdb.SnapshotAccountPrefix) {
+		it.it.Release()
+		it.it = nil
+		return false
+	}
+	return true
+}
+
+// Error returns any failure that occurred during iteration, which might have
+// caused a premature iteration exit (e.g. the underlying database iterator
+// failing).
+func (it *diskAccountIterator) Error() error {
+	if it.it == nil {
+		return nil // Iterator is exhausted and released
+	}
+	return it.it.Error()
+}
+
+// Hash returns the hash of the account the iterator is currently at.
+func (it *diskAccountIterator) Hash() common.Hash {
+	return common.BytesToHash(it.it.Key())
+}
+
+// Account returns the RLP encoded slim account the iterator is currently at.
+func (it *diskAccountIterator) Account() []byte {
+	return it.it.Value()
+}
+
+// Release releases the database snapshot held during iteration.
+func (it *diskAccountIterator) Release() {
+	// The iterator is auto-released on exhaustion, so make sure it's still alive
+	if it.it != nil {
+		it.it.Release()
+		it.it = nil
 	}
-	return []AccountIterator{dl.newAccountIterator()}
 }
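
For reference, the consumer pattern the new AccountIterator interface is designed for (drain with Next, read Hash/Account at each position, check Error once exhausted, always Release) might look roughly like the hypothetical helper below; it assumes the snapshot package context and go-ethereum's common package:

// consumeAccounts is a hypothetical helper (not part of this change) showing
// the intended usage of an AccountIterator, whichever layer produced it.
func consumeAccounts(it AccountIterator) (map[common.Hash][]byte, error) {
	defer it.Release() // always free the held resources, even on early exit

	accounts := make(map[common.Hash][]byte)
	for it.Next() {
		// Hash and Account remain valid for the current position even if a
		// referenced diff layer goes stale before the next call to Next.
		accounts[it.Hash()] = it.Account()
	}
	// Next returning false may mean exhaustion or failure; Error tells which.
	return accounts, it.Error()
}
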
diff --git a/core/state/snapshot/iterator_binary.go b/core/state/snapshot/iterator_binary.go
index 7ff6e3337dd8e80fa0db02d3bbd98511db7cf6cc..39288e6fb9dd6d0151f2bf71123dee9fc7ab16cd 100644
--- a/core/state/snapshot/iterator_binary.go
+++ b/core/state/snapshot/iterator_binary.go
@@ -40,10 +40,10 @@ func (dl *diffLayer) newBinaryAccountIterator() AccountIterator {
 	parent, ok := dl.parent.(*diffLayer)
 	if !ok {
 		// parent is the disk layer
-		return dl.newAccountIterator()
+		return dl.AccountIterator(common.Hash{})
 	}
 	l := &binaryAccountIterator{
-		a: dl.newAccountIterator(),
+		a: dl.AccountIterator(common.Hash{}).(*diffAccountIterator),
 		b: parent.newBinaryAccountIterator(),
 	}
 	l.aDone = !l.a.Next()
@@ -51,12 +51,6 @@ func (dl *diffLayer) newBinaryAccountIterator() AccountIterator {
 	return l
 }
 
-// Seek steps the iterator forward as many elements as needed, so that after
-// calling Next(), the iterator will be at a key higher than the given hash.
-func (it *binaryAccountIterator) Seek(key common.Hash) {
-	panic("todo: implement")
-}
-
 // Next steps the iterator forward one element, returning false if exhausted,
 // or an error if iteration failed for some reason (e.g. root being iterated
 // becomes stale and garbage collected).
@@ -64,9 +58,9 @@ func (it *binaryAccountIterator) Next() bool {
 	if it.aDone && it.bDone {
 		return false
 	}
-	nextB := it.b.Key()
+	nextB := it.b.Hash()
 first:
-	nextA := it.a.Key()
+	nextA := it.a.Hash()
 	if it.aDone {
 		it.bDone = !it.b.Next()
 		it.k = nextB
@@ -97,15 +91,15 @@ func (it *binaryAccountIterator) Error() error {
 	return it.fail
 }
 
-// Key returns the hash of the account the iterator is currently at.
-func (it *binaryAccountIterator) Key() common.Hash {
+// Hash returns the hash of the account the iterator is currently at.
+func (it *binaryAccountIterator) Hash() common.Hash {
 	return it.k
 }
 
-// Value returns the RLP encoded slim account the iterator is currently at, or
+// Account returns the RLP encoded slim account the iterator is currently at, or
 // nil if the iterated snapshot stack became stale (you can check Error after
 // to see if it failed or not).
-func (it *binaryAccountIterator) Value() []byte {
+func (it *binaryAccountIterator) Account() []byte {
 	blob, err := it.a.layer.AccountRLP(it.k)
 	if err != nil {
 		it.fail = err
@@ -113,3 +107,9 @@ func (it *binaryAccountIterator) Value() []byte {
 	}
 	return blob
 }
+
+// Release recursively releases all the iterators in the stack.
+func (it *binaryAccountIterator) Release() {
+	it.a.Release()
+	it.b.Release()
+}
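
The heart of the binary iterator (unchanged here apart from the renames) is a two-way merge of sorted hash lists in which the child layer wins on equal keys. A standalone sketch of that merge rule, with plain strings standing in for account hashes, could read:

package main

import "fmt"

// mergeSorted merges two sorted key streams, emitting each key once and, on a
// tie, preferring the child (more recent layer) over the parent.
func mergeSorted(child, parent []string) []string {
	var out []string
	for len(child) > 0 || len(parent) > 0 {
		switch {
		case len(parent) == 0 || (len(child) > 0 && child[0] < parent[0]):
			out = append(out, child[0])
			child = child[1:]
		case len(child) == 0 || parent[0] < child[0]:
			out = append(out, parent[0])
			parent = parent[1:]
		default: // equal keys: the child layer shadows the parent
			out = append(out, child[0])
			child, parent = child[1:], parent[1:]
		}
	}
	return out
}

func main() {
	fmt.Println(mergeSorted(
		[]string{"0xbb", "0xdd", "0xf0"},
		[]string{"0xaa", "0xee", "0xf0", "0xff"},
	)) // [0xaa 0xbb 0xdd 0xee 0xf0 0xff]
}
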
diff --git a/core/state/snapshot/iterator_fast.go b/core/state/snapshot/iterator_fast.go
index 8df037e9f4abca96056e5e482a6f9a80be37f94e..676a3af175b04fd40e3340ba19cdcfb322e9144e 100644
--- a/core/state/snapshot/iterator_fast.go
+++ b/core/state/snapshot/iterator_fast.go
@@ -24,90 +24,121 @@ import (
 	"github.com/ethereum/go-ethereum/common"
 )
 
-type weightedIterator struct {
+// weightedAccountIterator is an account iterator with an assigned weight. It is
+// used to prioritise which account is the correct one if multiple iterators find
+// the same one (modified in multiple consecutive blocks).
+type weightedAccountIterator struct {
 	it       AccountIterator
 	priority int
 }
 
+// weightedAccountIterators is a set of iterators implementing the sort.Interface.
+type weightedAccountIterators []*weightedAccountIterator
+
+// Len implements sort.Interface, returning the number of active iterators.
+func (its weightedAccountIterators) Len() int { return len(its) }
+
+// Less implements sort.Interface, returning which of two iterators in the stack
+// is before the other.
+func (its weightedAccountIterators) Less(i, j int) bool {
+	// Order the iterators primarily by the account hashes
+	hashI := its[i].it.Hash()
+	hashJ := its[j].it.Hash()
+
+	switch bytes.Compare(hashI[:], hashJ[:]) {
+	case -1:
+		return true
+	case 1:
+		return false
+	}
+	// Same account in multiple layers, split by priority
+	return its[i].priority < its[j].priority
+}
+
+// Swap implements sort.Interface, swapping two entries in the iterator stack.
+func (its weightedAccountIterators) Swap(i, j int) {
+	its[i], its[j] = its[j], its[i]
+}
+
 // fastAccountIterator is a more optimized multi-layer iterator which maintains a
-// direct mapping of all iterators leading down to the bottom layer
+// direct mapping of all iterators leading down to the bottom layer.
 type fastAccountIterator struct {
-	iterators []*weightedIterator
+	tree *Tree       // Snapshot tree to reinitialize stale sub-iterators with
+	root common.Hash // Root hash to reinitialize stale sub-iterators through
+
+	iterators weightedAccountIterators
 	initiated bool
 	fail      error
 }
 
-// newFastAccountIterator creates a new fastAccountIterator
-func (dl *diffLayer) newFastAccountIterator() AccountIterator {
-	f := &fastAccountIterator{
-		initiated: false,
+// newFastAccountIterator creates a new hierarchical account iterator with one
+// element per diff layer. The returned combo iterator can be used to walk over
+// the entire snapshot diff stack simultaneously.
+func newFastAccountIterator(tree *Tree, root common.Hash, seek common.Hash) (AccountIterator, error) {
+	snap := tree.Snapshot(root)
+	if snap == nil {
+		return nil, fmt.Errorf("unknown snapshot: %x", root)
 	}
-	for i, it := range dl.iterators() {
-		f.iterators = append(f.iterators, &weightedIterator{it, -i})
+	fi := &fastAccountIterator{
+		tree: tree,
+		root: root,
 	}
-	f.Seek(common.Hash{})
-	return f
-}
-
-// Len returns the number of active iterators
-func (fi *fastAccountIterator) Len() int {
-	return len(fi.iterators)
-}
-
-// Less implements sort.Interface
-func (fi *fastAccountIterator) Less(i, j int) bool {
-	a := fi.iterators[i].it.Key()
-	b := fi.iterators[j].it.Key()
-	bDiff := bytes.Compare(a[:], b[:])
-	if bDiff < 0 {
-		return true
-	}
-	if bDiff > 0 {
-		return false
+	current := snap.(snapshot)
+	for depth := 0; current != nil; depth++ {
+		fi.iterators = append(fi.iterators, &weightedAccountIterator{
+			it:       current.AccountIterator(seek),
+			priority: depth,
+		})
+		current = current.Parent()
 	}
-	// keys are equal, sort by iterator priority
-	return fi.iterators[i].priority < fi.iterators[j].priority
-}
-
-// Swap implements sort.Interface
-func (fi *fastAccountIterator) Swap(i, j int) {
-	fi.iterators[i], fi.iterators[j] = fi.iterators[j], fi.iterators[i]
+	fi.init()
+	return fi, nil
 }
 
-func (fi *fastAccountIterator) Seek(key common.Hash) {
-	// We need to apply this across all iterators
-	var seen = make(map[common.Hash]int)
+// init walks over all the iterators and resolves any clashes between them, after
+// which it prepares the stack for step-by-step iteration.
+func (fi *fastAccountIterator) init() {
+	// Track which account hashes the iterators are positioned on
+	var positioned = make(map[common.Hash]int)
 
-	length := len(fi.iterators)
+	// Position all iterators and track how many remain live
 	for i := 0; i < len(fi.iterators); i++ {
-		//for i, it := range fi.iterators {
+		// Retrieve the first element and if it clashes with a previous iterator,
+		// advance either the current one or the old one. Repeat until nothing is
+		// clashing any more.
 		it := fi.iterators[i]
-		it.it.Seek(key)
 		for {
+			// If the iterator is exhausted, drop it off the end
 			if !it.it.Next() {
-				// To be removed
-				// swap it to the last position for now
-				fi.iterators[i], fi.iterators[length-1] = fi.iterators[length-1], fi.iterators[i]
-				length--
+				it.it.Release()
+				last := len(fi.iterators) - 1
+
+				fi.iterators[i] = fi.iterators[last]
+				fi.iterators[last] = nil
+				fi.iterators = fi.iterators[:last]
+
+				i--
 				break
 			}
-			v := it.it.Key()
-			if other, exist := seen[v]; !exist {
-				seen[v] = i
+			// The iterator is still alive, check for collisions with previous ones
+			hash := it.it.Hash()
+			if other, exist := positioned[hash]; !exist {
+				positioned[hash] = i
 				break
 			} else {
+				// Iterators collide, one needs to be progressed, use priority to
+				// determine which.
+				//
 				// This whole else-block can be avoided, if we instead
 				// do an initial priority-sort of the iterators. If we do that,
 				// then we'll only wind up here if a lower-priority (preferred) iterator
 				// has the same value, and then we will always just continue.
 				// However, it costs an extra sort, so it's probably not better
-
-				// One needs to be progressed, use priority to determine which
 				if fi.iterators[other].priority < it.priority {
-					// the 'it' should be progressed
+					// The 'it' should be progressed
 					continue
 				} else {
-					// the 'other' should be progressed - swap them
+					// The 'other' should be progressed, swap them
 					it = fi.iterators[other]
 					fi.iterators[other], fi.iterators[i] = fi.iterators[i], fi.iterators[other]
 					continue
@@ -115,15 +146,12 @@ func (fi *fastAccountIterator) Seek(key common.Hash) {
 			}
 		}
 	}
-	// Now remove those that were placed in the end
-	fi.iterators = fi.iterators[:length]
-	// The list is now totally unsorted, need to re-sort the entire list
-	sort.Sort(fi)
+	// Re-sort the entire list
+	sort.Sort(fi.iterators)
 	fi.initiated = false
 }
 
-// Next implements the Iterator interface. It returns false if no more elemnts
-// can be retrieved (false == exhausted)
+// Next steps the iterator forward one element, returning false if exhausted.
 func (fi *fastAccountIterator) Next() bool {
 	if len(fi.iterators) == 0 {
 		return false
@@ -134,101 +162,88 @@ func (fi *fastAccountIterator) Next() bool {
 		fi.initiated = true
 		return true
 	}
-	return fi.innerNext(0)
+	return fi.next(0)
 }
 
-// innerNext handles the next operation internally,
-// and should be invoked when we know that two elements in the list may have
-// the same value.
-// For example, if the list becomes [2,3,5,5,8,9,10], then we should invoke
-// innerNext(3), which will call Next on elem 3 (the second '5'). It will continue
-// along the list and apply the same operation if needed
-func (fi *fastAccountIterator) innerNext(pos int) bool {
-	if !fi.iterators[pos].it.Next() {
-		//Exhausted, remove this iterator
-		fi.remove(pos)
-		if len(fi.iterators) == 0 {
-			return false
-		}
-		return true
+// next handles the next operation internally and should be invoked when we know
+// that two elements in the list may have the same value.
+//
+// For example, if the iterated hashes become [2,3,5,5,8,9,10], then we should
+// invoke next(3), which will call Next on elem 3 (the second '5') and will
+// cascade along the list, applying the same operation if needed.
+func (fi *fastAccountIterator) next(idx int) bool {
+	// If this particular iterator got exhausted, remove it and return true (the
+	// next one is surely not exhausted yet, otherwise it would have been removed
+	// already).
+	if it := fi.iterators[idx].it; !it.Next() {
+		it.Release()
+
+		fi.iterators = append(fi.iterators[:idx], fi.iterators[idx+1:]...)
+		return len(fi.iterators) > 0
 	}
-	if pos == len(fi.iterators)-1 {
-		// Only one iterator left
+	// If there's no one left to cascade into, return
+	if idx == len(fi.iterators)-1 {
 		return true
 	}
-	// We next:ed the elem at 'pos'. Now we may have to re-sort that elem
+	// We next-ed the iterator at 'idx', now we may have to re-sort that element
 	var (
-		current, neighbour = fi.iterators[pos], fi.iterators[pos+1]
-		val, neighbourVal  = current.it.Key(), neighbour.it.Key()
+		cur, next         = fi.iterators[idx], fi.iterators[idx+1]
+		curHash, nextHash = cur.it.Hash(), next.it.Hash()
 	)
-	if diff := bytes.Compare(val[:], neighbourVal[:]); diff < 0 {
+	if diff := bytes.Compare(curHash[:], nextHash[:]); diff < 0 {
 		// It is still in correct place
 		return true
-	} else if diff == 0 && current.priority < neighbour.priority {
-		// So still in correct place, but we need to iterate on the neighbour
-		fi.innerNext(pos + 1)
+	} else if diff == 0 && cur.priority < next.priority {
+		// So still in correct place, but we need to iterate on the next
+		fi.next(idx + 1)
 		return true
 	}
-	// At this point, the elem is in the wrong location, but the
-	// remaining list is sorted. Find out where to move the elem
-	iteratee := -1
+	// At this point, the iterator is in the wrong location, but the remaining
+	// list is sorted. Find out where to move the item.
+	clash := -1
 	index := sort.Search(len(fi.iterators), func(n int) bool {
-		if n < pos {
-			// No need to search 'behind' us
+		// The iterator always advances forward, so anything before the old slot
+		// is known to be behind us, so just skip them altogether. This actually
+		// is an important clause since the sort order got invalidated.
+		if n < idx {
 			return false
 		}
 		if n == len(fi.iterators)-1 {
 			// Can always place an elem last
 			return true
 		}
-		neighbour := fi.iterators[n+1].it.Key()
-		if diff := bytes.Compare(val[:], neighbour[:]); diff < 0 {
+		nextHash := fi.iterators[n+1].it.Hash()
+		if diff := bytes.Compare(curHash[:], nextHash[:]); diff < 0 {
 			return true
 		} else if diff > 0 {
 			return false
 		}
 		// The elem we're placing it next to has the same value,
 		// so whichever winds up on n+1 will need further iteration
-		iteratee = n + 1
-		if current.priority < fi.iterators[n+1].priority {
+		clash = n + 1
+		if cur.priority < fi.iterators[n+1].priority {
 			// We can drop the iterator here
 			return true
 		}
 		// We need to move it one step further
 		return false
 		// TODO benchmark which is best, this works too:
-		//iteratee = n
+		//clash = n
 		//return true
 		// Doing so should finish the current search earlier
 	})
-	fi.move(pos, index)
-	if iteratee != -1 {
-		fi.innerNext(iteratee)
+	fi.move(idx, index)
+	if clash != -1 {
+		fi.next(clash)
 	}
 	return true
 }
 
-// move moves an iterator to another position in the list
+// move advances an iterator to another position in the list.
 func (fi *fastAccountIterator) move(index, newpos int) {
-	if newpos > len(fi.iterators)-1 {
-		newpos = len(fi.iterators) - 1
-	}
-	var (
-		elem   = fi.iterators[index]
-		middle = fi.iterators[index+1 : newpos+1]
-		suffix []*weightedIterator
-	)
-	if newpos < len(fi.iterators)-1 {
-		suffix = fi.iterators[newpos+1:]
-	}
-	fi.iterators = append(fi.iterators[:index], middle...)
-	fi.iterators = append(fi.iterators, elem)
-	fi.iterators = append(fi.iterators, suffix...)
-}
-
-// remove drops an iterator from the list
-func (fi *fastAccountIterator) remove(index int) {
-	fi.iterators = append(fi.iterators[:index], fi.iterators[index+1:]...)
+	elem := fi.iterators[index]
+	copy(fi.iterators[index:], fi.iterators[index+1:newpos+1])
+	fi.iterators[newpos] = elem
 }
 
 // Error returns any failure that occurred during iteration, which might have
@@ -237,20 +252,29 @@ func (fi *fastAccountIterator) Error() error {
 	return fi.fail
 }
 
-// Key returns the current key
-func (fi *fastAccountIterator) Key() common.Hash {
-	return fi.iterators[0].it.Key()
+// Hash returns the hash of the account the iterator is currently at.
+func (fi *fastAccountIterator) Hash() common.Hash {
+	return fi.iterators[0].it.Hash()
 }
 
-// Value returns the current key
-func (fi *fastAccountIterator) Value() []byte {
-	return fi.iterators[0].it.Value()
+// Account returns the RLP encoded slim account the iterator is currently at.
+func (fi *fastAccountIterator) Account() []byte {
+	return fi.iterators[0].it.Account()
+}
+
+// Release iterates over all the remaining live layer iterators and releases each
+// of them individually.
+func (fi *fastAccountIterator) Release() {
+	for _, it := range fi.iterators {
+		it.it.Release()
+	}
+	fi.iterators = nil
 }
 
 // Debug is a convenience helper during testing
 func (fi *fastAccountIterator) Debug() {
 	for _, it := range fi.iterators {
-		fmt.Printf("[p=%v v=%v] ", it.priority, it.it.Key()[0])
+		fmt.Printf("[p=%v v=%v] ", it.priority, it.it.Hash()[0])
 	}
 	fmt.Println()
 }
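
What the fast iterator computes lazily over its priority-sorted stack is, semantically, just the union of the per-layer account lists with the newest layer winning on duplicates. A standalone sketch of that end result (reusing the same three account sets as TestIteratorTraversal in the test diff below) is:

package main

import (
	"fmt"
	"sort"
)

// mergeLayers unions per-layer key lists, letting the newest layer (lowest
// index, i.e. lowest priority value) win whenever a key repeats, and returns
// the result sorted -- the same sequence the fast iterator yields lazily.
func mergeLayers(layers ...[]string) []string {
	seen := make(map[string]bool)
	var out []string
	for _, layer := range layers { // newest layer first
		for _, key := range layer {
			if !seen[key] {
				seen[key] = true
				out = append(out, key)
			}
		}
	}
	sort.Strings(out)
	return out
}

func main() {
	fmt.Println(mergeLayers(
		[]string{"0xcc", "0xf0", "0xff"}, // newest diff layer
		[]string{"0xbb", "0xdd", "0xf0"},
		[]string{"0xaa", "0xee", "0xff", "0xf0"}, // oldest diff layer
	)) // [0xaa 0xbb 0xcc 0xdd 0xee 0xf0 0xff]
}
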
diff --git a/core/state/snapshot/iterator_test.go b/core/state/snapshot/iterator_test.go
index 01e525653dd3cfe8e13b4e05ff9481cc168c37d3..902985cf6cf930d0e844e814aeba6b996a0825ba 100644
--- a/core/state/snapshot/iterator_test.go
+++ b/core/state/snapshot/iterator_test.go
@@ -23,7 +23,9 @@ import (
 	"math/rand"
 	"testing"
 
+	"github.com/VictoriaMetrics/fastcache"
 	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/core/rawdb"
 )
 
 // TestIteratorBasics tests some simple single-layer iteration
@@ -47,7 +49,7 @@ func TestIteratorBasics(t *testing.T) {
 	}
 	// Add some (identical) layers on top
 	parent := newDiffLayer(emptyLayer(), common.Hash{}, accounts, storage)
-	it := parent.newAccountIterator()
+	it := parent.AccountIterator(common.Hash{})
 	verifyIterator(t, 100, it)
 }
 
@@ -75,14 +77,16 @@ func (ti *testIterator) Error() error {
 	panic("implement me")
 }
 
-func (ti *testIterator) Key() common.Hash {
+func (ti *testIterator) Hash() common.Hash {
 	return common.BytesToHash([]byte{ti.values[0]})
 }
 
-func (ti *testIterator) Value() []byte {
+func (ti *testIterator) Account() []byte {
 	panic("implement me")
 }
 
+func (ti *testIterator) Release() {}
+
 func TestFastIteratorBasics(t *testing.T) {
 	type testCase struct {
 		lists   [][]byte
@@ -96,10 +100,10 @@ func TestFastIteratorBasics(t *testing.T) {
 			{9, 10}, {10, 13, 15, 16}},
 			expKeys: []byte{0, 1, 2, 7, 8, 9, 10, 13, 14, 15, 16}},
 	} {
-		var iterators []*weightedIterator
+		var iterators []*weightedAccountIterator
 		for i, data := range tc.lists {
 			it := newTestIterator(data...)
-			iterators = append(iterators, &weightedIterator{it, i})
+			iterators = append(iterators, &weightedAccountIterator{it, i})
 
 		}
 		fi := &fastAccountIterator{
@@ -108,7 +112,7 @@ func TestFastIteratorBasics(t *testing.T) {
 		}
 		count := 0
 		for fi.Next() {
-			if got, exp := fi.Key()[31], tc.expKeys[count]; exp != got {
+			if got, exp := fi.Hash()[31], tc.expKeys[count]; exp != got {
 				t.Errorf("tc %d, [%d]: got %d exp %d", i, count, got, exp)
 			}
 			count++
@@ -117,68 +121,86 @@ func TestFastIteratorBasics(t *testing.T) {
 }
 
 func verifyIterator(t *testing.T, expCount int, it AccountIterator) {
+	t.Helper()
+
 	var (
-		i    = 0
-		last = common.Hash{}
+		count = 0
+		last  = common.Hash{}
 	)
 	for it.Next() {
-		v := it.Key()
-		if bytes.Compare(last[:], v[:]) >= 0 {
-			t.Errorf("Wrong order:\n%x \n>=\n%x", last, v)
+		if hash := it.Hash(); bytes.Compare(last[:], hash[:]) >= 0 {
+			t.Errorf("wrong order: %x >= %x", last, hash)
 		}
-		i++
+		count++
 	}
-	if i != expCount {
-		t.Errorf("iterator len wrong, expected %d, got %d", expCount, i)
+	if count != expCount {
+		t.Errorf("iterator count mismatch: have %d, want %d", count, expCount)
+	}
+	if err := it.Error(); err != nil {
+		t.Errorf("iterator failed: %v", err)
 	}
 }
 
-// TestIteratorTraversal tests some simple multi-layer iteration
+// TestIteratorTraversal tests some simple multi-layer iteration.
 func TestIteratorTraversal(t *testing.T) {
-	var (
-		storage = make(map[common.Hash]map[common.Hash][]byte)
-	)
-
-	mkAccounts := func(args ...string) map[common.Hash][]byte {
-		accounts := make(map[common.Hash][]byte)
-		for _, h := range args {
-			accounts[common.HexToHash(h)] = randomAccount()
-		}
-		return accounts
+	// Create an empty base layer and a snapshot tree out of it
+	base := &diskLayer{
+		diskdb: rawdb.NewMemoryDatabase(),
+		root:   common.HexToHash("0x01"),
+		cache:  fastcache.New(1024 * 500),
 	}
-	// entries in multiple layers should only become output once
-	parent := newDiffLayer(emptyLayer(), common.Hash{},
-		mkAccounts("0xaa", "0xee", "0xff", "0xf0"), storage)
+	snaps := &Tree{
+		layers: map[common.Hash]snapshot{
+			base.root: base,
+		},
+	}
+	// Stack three diff layers on top with various overlaps
+	snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"),
+		randomAccountSet("0xaa", "0xee", "0xff", "0xf0"), nil)
+
+	snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"),
+		randomAccountSet("0xbb", "0xdd", "0xf0"), nil)
 
-	child := parent.Update(common.Hash{},
-		mkAccounts("0xbb", "0xdd", "0xf0"), storage)
+	snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"),
+		randomAccountSet("0xcc", "0xf0", "0xff"), nil)
 
-	child = child.Update(common.Hash{},
-		mkAccounts("0xcc", "0xf0", "0xff"), storage)
+	// Verify the single and multi-layer iterators
+	head := snaps.Snapshot(common.HexToHash("0x04"))
 
-	// single layer iterator
-	verifyIterator(t, 3, child.newAccountIterator())
-	// multi-layered binary iterator
-	verifyIterator(t, 7, child.newBinaryAccountIterator())
-	// multi-layered fast iterator
-	verifyIterator(t, 7, child.newFastAccountIterator())
+	verifyIterator(t, 3, head.(snapshot).AccountIterator(common.Hash{}))
+	verifyIterator(t, 7, head.(*diffLayer).newBinaryAccountIterator())
+
+	it, _ := snaps.AccountIterator(common.HexToHash("0x04"), common.Hash{})
+	defer it.Release()
+
+	verifyIterator(t, 7, it)
 }
 
 // TestIteratorTraversalValues tests some multi-layer iteration, where we
-// also expect the correct values to show up
+// also expect the correct values to show up.
 func TestIteratorTraversalValues(t *testing.T) {
+	// Create an empty base layer and a snapshot tree out of it
+	base := &diskLayer{
+		diskdb: rawdb.NewMemoryDatabase(),
+		root:   common.HexToHash("0x01"),
+		cache:  fastcache.New(1024 * 500),
+	}
+	snaps := &Tree{
+		layers: map[common.Hash]snapshot{
+			base.root: base,
+		},
+	}
+	// Create a batch of account sets to seed subsequent layers with
 	var (
-		storage = make(map[common.Hash]map[common.Hash][]byte)
-		a       = make(map[common.Hash][]byte)
-		b       = make(map[common.Hash][]byte)
-		c       = make(map[common.Hash][]byte)
-		d       = make(map[common.Hash][]byte)
-		e       = make(map[common.Hash][]byte)
-		f       = make(map[common.Hash][]byte)
-		g       = make(map[common.Hash][]byte)
-		h       = make(map[common.Hash][]byte)
+		a = make(map[common.Hash][]byte)
+		b = make(map[common.Hash][]byte)
+		c = make(map[common.Hash][]byte)
+		d = make(map[common.Hash][]byte)
+		e = make(map[common.Hash][]byte)
+		f = make(map[common.Hash][]byte)
+		g = make(map[common.Hash][]byte)
+		h = make(map[common.Hash][]byte)
 	)
-	// entries in multiple layers should only become output once
 	for i := byte(2); i < 0xff; i++ {
 		a[common.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 0, i))
 		if i > 20 && i%2 == 0 {
@@ -203,35 +225,36 @@ func TestIteratorTraversalValues(t *testing.T) {
 			h[common.Hash{i}] = []byte(fmt.Sprintf("layer-%d, key %d", 7, i))
 		}
 	}
-	child := newDiffLayer(emptyLayer(), common.Hash{}, a, storage).
-		Update(common.Hash{}, b, storage).
-		Update(common.Hash{}, c, storage).
-		Update(common.Hash{}, d, storage).
-		Update(common.Hash{}, e, storage).
-		Update(common.Hash{}, f, storage).
-		Update(common.Hash{}, g, storage).
-		Update(common.Hash{}, h, storage)
-
-	it := child.newFastAccountIterator()
+	// Assemble a stack of snapshots from the account layers
+	snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), a, nil)
+	snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), b, nil)
+	snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"), c, nil)
+	snaps.Update(common.HexToHash("0x05"), common.HexToHash("0x04"), d, nil)
+	snaps.Update(common.HexToHash("0x06"), common.HexToHash("0x05"), e, nil)
+	snaps.Update(common.HexToHash("0x07"), common.HexToHash("0x06"), f, nil)
+	snaps.Update(common.HexToHash("0x08"), common.HexToHash("0x07"), g, nil)
+	snaps.Update(common.HexToHash("0x09"), common.HexToHash("0x08"), h, nil)
+
+	it, _ := snaps.AccountIterator(common.HexToHash("0x09"), common.Hash{})
+	defer it.Release()
+
+	head := snaps.Snapshot(common.HexToHash("0x09"))
 	for it.Next() {
-		key := it.Key()
-		exp, err := child.accountRLP(key, 0)
+		hash := it.Hash()
+		want, err := head.AccountRLP(hash)
 		if err != nil {
-			t.Fatal(err)
+			t.Fatalf("failed to retrieve expected account: %v", err)
 		}
-		got := it.Value()
-		if !bytes.Equal(exp, got) {
-			t.Fatalf("Error on key %x, got %v exp %v", key, string(got), string(exp))
+		if have := it.Account(); !bytes.Equal(want, have) {
+			t.Fatalf("hash %x: account mismatch: have %x, want %x", hash, have, want)
 		}
-		//fmt.Printf("val: %v\n", string(it.Value()))
 	}
 }
 
+// This testcase is notorious: all layers contain the exact same 200 accounts.
 func TestIteratorLargeTraversal(t *testing.T) {
-	// This testcase is a bit notorious -- all layers contain the exact
-	// same 200 accounts.
-	var storage = make(map[common.Hash]map[common.Hash][]byte)
-	mkAccounts := func(num int) map[common.Hash][]byte {
+	// Create a custom account factory to recreate the same addresses
+	makeAccounts := func(num int) map[common.Hash][]byte {
 		accounts := make(map[common.Hash][]byte)
 		for i := 0; i < num; i++ {
 			h := common.Hash{}
@@ -240,25 +263,121 @@ func TestIteratorLargeTraversal(t *testing.T) {
 		}
 		return accounts
 	}
-	parent := newDiffLayer(emptyLayer(), common.Hash{},
-		mkAccounts(200), storage)
-	child := parent.Update(common.Hash{},
-		mkAccounts(200), storage)
-	for i := 2; i < 100; i++ {
-		child = child.Update(common.Hash{},
-			mkAccounts(200), storage)
-	}
-	// single layer iterator
-	verifyIterator(t, 200, child.newAccountIterator())
-	// multi-layered binary iterator
-	verifyIterator(t, 200, child.newBinaryAccountIterator())
-	// multi-layered fast iterator
-	verifyIterator(t, 200, child.newFastAccountIterator())
+	// Build up a large stack of snapshots
+	base := &diskLayer{
+		diskdb: rawdb.NewMemoryDatabase(),
+		root:   common.HexToHash("0x01"),
+		cache:  fastcache.New(1024 * 500),
+	}
+	snaps := &Tree{
+		layers: map[common.Hash]snapshot{
+			base.root: base,
+		},
+	}
+	for i := 1; i < 128; i++ {
+		snaps.Update(common.HexToHash(fmt.Sprintf("0x%02x", i+1)), common.HexToHash(fmt.Sprintf("0x%02x", i)), makeAccounts(200), nil)
+	}
+	// Iterate the entire stack and ensure everything is hit only once
+	head := snaps.Snapshot(common.HexToHash("0x80"))
+	verifyIterator(t, 200, head.(snapshot).AccountIterator(common.Hash{}))
+	verifyIterator(t, 200, head.(*diffLayer).newBinaryAccountIterator())
+
+	it, _ := snaps.AccountIterator(common.HexToHash("0x80"), common.Hash{})
+	defer it.Release()
+
+	verifyIterator(t, 200, it)
 }
 
-// BenchmarkIteratorTraversal is a bit a bit notorious -- all layers contain the exact
-// same 200 accounts. That means that we need to process 2000 items, but only
-// spit out 200 values eventually.
+// TestIteratorFlattening tests what happens when we
+// - have a live iterator on child C (parent C1 -> C2 .. CN)
+// - flatten C2 all the way into CN
+// - continue iterating
+func TestIteratorFlattening(t *testing.T) {
+	// Create an empty base layer and a snapshot tree out of it
+	base := &diskLayer{
+		diskdb: rawdb.NewMemoryDatabase(),
+		root:   common.HexToHash("0x01"),
+		cache:  fastcache.New(1024 * 500),
+	}
+	snaps := &Tree{
+		layers: map[common.Hash]snapshot{
+			base.root: base,
+		},
+	}
+	// Create a stack of diffs on top
+	snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"),
+		randomAccountSet("0xaa", "0xee", "0xff", "0xf0"), nil)
+
+	snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"),
+		randomAccountSet("0xbb", "0xdd", "0xf0"), nil)
+
+	snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"),
+		randomAccountSet("0xcc", "0xf0", "0xff"), nil)
+
+	// Create an iterator and flatten the data from underneath it
+	it, _ := snaps.AccountIterator(common.HexToHash("0x04"), common.Hash{})
+	defer it.Release()
+
+	if err := snaps.Cap(common.HexToHash("0x04"), 1); err != nil {
+		t.Fatalf("failed to flatten snapshot stack: %v", err)
+	}
+	//verifyIterator(t, 7, it)
+}
+
+func TestIteratorSeek(t *testing.T) {
+	// Create a snapshot stack with some initial data
+	base := &diskLayer{
+		diskdb: rawdb.NewMemoryDatabase(),
+		root:   common.HexToHash("0x01"),
+		cache:  fastcache.New(1024 * 500),
+	}
+	snaps := &Tree{
+		layers: map[common.Hash]snapshot{
+			base.root: base,
+		},
+	}
+	snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"),
+		randomAccountSet("0xaa", "0xee", "0xff", "0xf0"), nil)
+
+	snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"),
+		randomAccountSet("0xbb", "0xdd", "0xf0"), nil)
+
+	snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"),
+		randomAccountSet("0xcc", "0xf0", "0xff"), nil)
+
+	// Construct various iterators and ensure their traversal is correct
+	it, _ := snaps.AccountIterator(common.HexToHash("0x02"), common.HexToHash("0xdd"))
+	defer it.Release()
+	verifyIterator(t, 3, it) // expected: ee, f0, ff
+
+	it, _ = snaps.AccountIterator(common.HexToHash("0x02"), common.HexToHash("0xaa"))
+	defer it.Release()
+	verifyIterator(t, 3, it) // expected: ee, f0, ff
+
+	it, _ = snaps.AccountIterator(common.HexToHash("0x02"), common.HexToHash("0xff"))
+	defer it.Release()
+	verifyIterator(t, 0, it) // expected: nothing
+
+	it, _ = snaps.AccountIterator(common.HexToHash("0x04"), common.HexToHash("0xbb"))
+	defer it.Release()
+	verifyIterator(t, 5, it) // expected: cc, dd, ee, f0, ff
+
+	it, _ = snaps.AccountIterator(common.HexToHash("0x04"), common.HexToHash("0xef"))
+	defer it.Release()
+	verifyIterator(t, 2, it) // expected: f0, ff
+
+	it, _ = snaps.AccountIterator(common.HexToHash("0x04"), common.HexToHash("0xf0"))
+	defer it.Release()
+	verifyIterator(t, 1, it) // expected: ff
+
+	it, _ = snaps.AccountIterator(common.HexToHash("0x04"), common.HexToHash("0xff"))
+	defer it.Release()
+	verifyIterator(t, 0, it) // expected: nothing
+}
+
+// BenchmarkIteratorTraversal is a bit notorious -- all layers contain the
+// exact same 200 accounts. That means that we need to process 2000 items, but
+// only spit out 200 values eventually.
 //
 // The value-fetching benchmark is easy on the binary iterator, since it never has to reach
 // down at any depth for retrieving the values -- all are on the topmost layer
@@ -267,12 +386,9 @@ func TestIteratorLargeTraversal(t *testing.T) {
 // BenchmarkIteratorTraversal/binary_iterator_values-6       	    2403	    501810 ns/op
 // BenchmarkIteratorTraversal/fast_iterator_keys-6           	    1923	    677966 ns/op
 // BenchmarkIteratorTraversal/fast_iterator_values-6         	    1741	    649967 ns/op
-//
 func BenchmarkIteratorTraversal(b *testing.B) {
-
-	var storage = make(map[common.Hash]map[common.Hash][]byte)
-
-	mkAccounts := func(num int) map[common.Hash][]byte {
+	// Create a custom account factory to recreate the same addresses
+	makeAccounts := func(num int) map[common.Hash][]byte {
 		accounts := make(map[common.Hash][]byte)
 		for i := 0; i < num; i++ {
 			h := common.Hash{}
@@ -281,24 +397,29 @@ func BenchmarkIteratorTraversal(b *testing.B) {
 		}
 		return accounts
 	}
-	parent := newDiffLayer(emptyLayer(), common.Hash{},
-		mkAccounts(200), storage)
-
-	child := parent.Update(common.Hash{},
-		mkAccounts(200), storage)
-
-	for i := 2; i < 100; i++ {
-		child = child.Update(common.Hash{},
-			mkAccounts(200), storage)
-
+	// Build up a large stack of snapshots
+	base := &diskLayer{
+		diskdb: rawdb.NewMemoryDatabase(),
+		root:   common.HexToHash("0x01"),
+		cache:  fastcache.New(1024 * 500),
+	}
+	snaps := &Tree{
+		layers: map[common.Hash]snapshot{
+			base.root: base,
+		},
+	}
+	for i := 1; i <= 100; i++ {
+		snaps.Update(common.HexToHash(fmt.Sprintf("0x%02x", i+1)), common.HexToHash(fmt.Sprintf("0x%02x", i)), makeAccounts(200), nil)
 	}
 	// We call this once before the benchmark, so the creation of
 	// sorted accountlists are not included in the results.
-	child.newBinaryAccountIterator()
+	head := snaps.Snapshot(common.HexToHash("0x65"))
+	head.(*diffLayer).newBinaryAccountIterator()
+
 	b.Run("binary iterator keys", func(b *testing.B) {
 		for i := 0; i < b.N; i++ {
 			got := 0
-			it := child.newBinaryAccountIterator()
+			it := head.(*diffLayer).newBinaryAccountIterator()
 			for it.Next() {
 				got++
 			}
@@ -310,10 +431,10 @@ func BenchmarkIteratorTraversal(b *testing.B) {
 	b.Run("binary iterator values", func(b *testing.B) {
 		for i := 0; i < b.N; i++ {
 			got := 0
-			it := child.newBinaryAccountIterator()
+			it := head.(*diffLayer).newBinaryAccountIterator()
 			for it.Next() {
 				got++
-				child.accountRLP(it.Key(), 0)
+				head.(*diffLayer).accountRLP(it.Hash(), 0)
 			}
 			if exp := 200; got != exp {
 				b.Errorf("iterator len wrong, expected %d, got %d", exp, got)
@@ -322,8 +443,10 @@ func BenchmarkIteratorTraversal(b *testing.B) {
 	})
 	b.Run("fast iterator keys", func(b *testing.B) {
 		for i := 0; i < b.N; i++ {
+			it, _ := snaps.AccountIterator(common.HexToHash("0x65"), common.Hash{})
+			defer it.Release()
+
 			got := 0
-			it := child.newFastAccountIterator()
 			for it.Next() {
 				got++
 			}
@@ -334,11 +457,13 @@ func BenchmarkIteratorTraversal(b *testing.B) {
 	})
 	b.Run("fast iterator values", func(b *testing.B) {
 		for i := 0; i < b.N; i++ {
+			it, _ := snaps.AccountIterator(common.HexToHash("0x65"), common.Hash{})
+			defer it.Release()
+
 			got := 0
-			it := child.newFastAccountIterator()
 			for it.Next() {
 				got++
-				it.Value()
+				it.Account()
 			}
 			if exp := 200; got != exp {
 				b.Errorf("iterator len wrong, expected %d, got %d", exp, got)
@@ -354,13 +479,12 @@ func BenchmarkIteratorTraversal(b *testing.B) {
 // call recursively 100 times for the majority of the values
 //
 // BenchmarkIteratorLargeBaselayer/binary_iterator_(keys)-6         	     514	   1971999 ns/op
-// BenchmarkIteratorLargeBaselayer/fast_iterator_(keys)-6           	   10000	    114385 ns/op
 // BenchmarkIteratorLargeBaselayer/binary_iterator_(values)-6       	      61	  18997492 ns/op
+// BenchmarkIteratorLargeBaselayer/fast_iterator_(keys)-6           	   10000	    114385 ns/op
 // BenchmarkIteratorLargeBaselayer/fast_iterator_(values)-6         	    4047	    296823 ns/op
 func BenchmarkIteratorLargeBaselayer(b *testing.B) {
-	var storage = make(map[common.Hash]map[common.Hash][]byte)
-
-	mkAccounts := func(num int) map[common.Hash][]byte {
+	// Create a custom account factory to recreate the same addresses
+	makeAccounts := func(num int) map[common.Hash][]byte {
 		accounts := make(map[common.Hash][]byte)
 		for i := 0; i < num; i++ {
 			h := common.Hash{}
@@ -369,25 +493,30 @@ func BenchmarkIteratorLargeBaselayer(b *testing.B) {
 		}
 		return accounts
 	}
-
-	parent := newDiffLayer(emptyLayer(), common.Hash{},
-		mkAccounts(2000), storage)
-
-	child := parent.Update(common.Hash{},
-		mkAccounts(20), storage)
-
-	for i := 2; i < 100; i++ {
-		child = child.Update(common.Hash{},
-			mkAccounts(20), storage)
-
+	// Build up a large stack of snapshots
+	base := &diskLayer{
+		diskdb: rawdb.NewMemoryDatabase(),
+		root:   common.HexToHash("0x01"),
+		cache:  fastcache.New(1024 * 500),
+	}
+	snaps := &Tree{
+		layers: map[common.Hash]snapshot{
+			base.root: base,
+		},
+	}
+	snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), makeAccounts(2000), nil)
+	for i := 2; i <= 100; i++ {
+		snaps.Update(common.HexToHash(fmt.Sprintf("0x%02x", i+1)), common.HexToHash(fmt.Sprintf("0x%02x", i)), makeAccounts(20), nil)
 	}
 	// We call this once before the benchmark, so the creation of
 	// sorted accountlists are not included in the results.
-	child.newBinaryAccountIterator()
+	head := snaps.Snapshot(common.HexToHash("0x65"))
+	head.(*diffLayer).newBinaryAccountIterator()
+
 	b.Run("binary iterator (keys)", func(b *testing.B) {
 		for i := 0; i < b.N; i++ {
 			got := 0
-			it := child.newBinaryAccountIterator()
+			it := head.(*diffLayer).newBinaryAccountIterator()
 			for it.Next() {
 				got++
 			}
@@ -396,39 +525,42 @@ func BenchmarkIteratorLargeBaselayer(b *testing.B) {
 			}
 		}
 	})
-	b.Run("fast iterator (keys)", func(b *testing.B) {
+	b.Run("binary iterator (values)", func(b *testing.B) {
 		for i := 0; i < b.N; i++ {
 			got := 0
-			it := child.newFastAccountIterator()
+			it := head.(*diffLayer).newBinaryAccountIterator()
 			for it.Next() {
 				got++
+				v := it.Hash()
+				head.(*diffLayer).accountRLP(v, 0)
 			}
 			if exp := 2000; got != exp {
 				b.Errorf("iterator len wrong, expected %d, got %d", exp, got)
 			}
 		}
 	})
-	b.Run("binary iterator (values)", func(b *testing.B) {
+	b.Run("fast iterator (keys)", func(b *testing.B) {
 		for i := 0; i < b.N; i++ {
+			it, _ := snaps.AccountIterator(common.HexToHash("0x65"), common.Hash{})
+			defer it.Release()
+
 			got := 0
-			it := child.newBinaryAccountIterator()
 			for it.Next() {
 				got++
-				v := it.Key()
-				child.accountRLP(v, -0)
 			}
 			if exp := 2000; got != exp {
 				b.Errorf("iterator len wrong, expected %d, got %d", exp, got)
 			}
 		}
 	})
-
 	b.Run("fast iterator (values)", func(b *testing.B) {
 		for i := 0; i < b.N; i++ {
+			it, _ := snaps.AccountIterator(common.HexToHash("0x65"), common.Hash{})
+			defer it.Release()
+
 			got := 0
-			it := child.newFastAccountIterator()
 			for it.Next() {
-				it.Value()
+				it.Account()
 				got++
 			}
 			if exp := 2000; got != exp {
@@ -438,117 +570,38 @@ func BenchmarkIteratorLargeBaselayer(b *testing.B) {
 	})
 }
 
-// TestIteratorFlatting tests what happens when we
-// - have a live iterator on child C (parent C1 -> C2 .. CN)
-// - flattens C2 all the way into CN
-// - continues iterating
-// Right now, this "works" simply because the keys do not change -- the
-// iterator is not aware that a layer has become stale. This naive
-// solution probably won't work in the long run, however
-func TestIteratorFlattning(t *testing.T) {
-	var (
-		storage = make(map[common.Hash]map[common.Hash][]byte)
-	)
-	mkAccounts := func(args ...string) map[common.Hash][]byte {
-		accounts := make(map[common.Hash][]byte)
-		for _, h := range args {
-			accounts[common.HexToHash(h)] = randomAccount()
-		}
-		return accounts
-	}
-	// entries in multiple layers should only become output once
-	parent := newDiffLayer(emptyLayer(), common.Hash{},
-		mkAccounts("0xaa", "0xee", "0xff", "0xf0"), storage)
-
-	child := parent.Update(common.Hash{},
-		mkAccounts("0xbb", "0xdd", "0xf0"), storage)
-
-	child = child.Update(common.Hash{},
-		mkAccounts("0xcc", "0xf0", "0xff"), storage)
+/*
+func BenchmarkBinaryAccountIteration(b *testing.B) {
+	benchmarkAccountIteration(b, func(snap snapshot) AccountIterator {
+		return snap.(*diffLayer).newBinaryAccountIterator()
+	})
+}
 
-	it := child.newFastAccountIterator()
-	child.parent.(*diffLayer).flatten()
-	// The parent should now be stale
-	verifyIterator(t, 7, it)
+func BenchmarkFastAccountIteration(b *testing.B) {
+	benchmarkAccountIteration(b, newFastAccountIterator)
 }
 
-func TestIteratorSeek(t *testing.T) {
-	storage := make(map[common.Hash]map[common.Hash][]byte)
-	mkAccounts := func(args ...string) map[common.Hash][]byte {
-		accounts := make(map[common.Hash][]byte)
-		for _, h := range args {
-			accounts[common.HexToHash(h)] = randomAccount()
-		}
-		return accounts
+func benchmarkAccountIteration(b *testing.B, iterator func(snap snapshot) AccountIterator) {
+	// Create a diff stack and randomize the accounts across them
+	layers := make([]map[common.Hash][]byte, 128)
+	for i := 0; i < len(layers); i++ {
+		layers[i] = make(map[common.Hash][]byte)
 	}
-	parent := newDiffLayer(emptyLayer(), common.Hash{},
-		mkAccounts("0xaa", "0xee", "0xff", "0xf0"), storage)
-	it := AccountIterator(parent.newAccountIterator())
-	// expected: ee, f0, ff
-	it.Seek(common.HexToHash("0xdd"))
-	verifyIterator(t, 3, it)
-
-	it = parent.newAccountIterator()
-	// expected: ee, f0, ff
-	it.Seek(common.HexToHash("0xaa"))
-	verifyIterator(t, 3, it)
-
-	it = parent.newAccountIterator()
-	// expected: nothing
-	it.Seek(common.HexToHash("0xff"))
-	verifyIterator(t, 0, it)
-
-	child := parent.Update(common.Hash{},
-		mkAccounts("0xbb", "0xdd", "0xf0"), storage)
-
-	child = child.Update(common.Hash{},
-		mkAccounts("0xcc", "0xf0", "0xff"), storage)
-
-	it = child.newFastAccountIterator()
-	// expected: cc, dd, ee, f0, ff
-	it.Seek(common.HexToHash("0xbb"))
-	verifyIterator(t, 5, it)
-
-	it = child.newFastAccountIterator()
-	it.Seek(common.HexToHash("0xef"))
-	// exp: f0, ff
-	verifyIterator(t, 2, it)
-
-	it = child.newFastAccountIterator()
-	it.Seek(common.HexToHash("0xf0"))
-	verifyIterator(t, 1, it)
-
-	it.Seek(common.HexToHash("0xff"))
-	verifyIterator(t, 0, it)
-}
+	for i := 0; i < b.N; i++ {
+		depth := rand.Intn(len(layers))
+		layers[depth][randomHash()] = randomAccount()
+	}
+	stack := snapshot(emptyLayer())
+	for _, layer := range layers {
+		stack = stack.Update(common.Hash{}, layer, nil)
+	}
+	// Reset the timers and report all the stats
+	it := iterator(stack)
 
-//BenchmarkIteratorSeek/init+seek-6         	    4328	    245477 ns/op
-func BenchmarkIteratorSeek(b *testing.B) {
+	b.ResetTimer()
+	b.ReportAllocs()
 
-	var storage = make(map[common.Hash]map[common.Hash][]byte)
-	mkAccounts := func(num int) map[common.Hash][]byte {
-		accounts := make(map[common.Hash][]byte)
-		for i := 0; i < num; i++ {
-			h := common.Hash{}
-			binary.BigEndian.PutUint64(h[:], uint64(i+1))
-			accounts[h] = randomAccount()
-		}
-		return accounts
-	}
-	layer := newDiffLayer(emptyLayer(), common.Hash{}, mkAccounts(200), storage)
-	for i := 1; i < 100; i++ {
-		layer = layer.Update(common.Hash{},
-			mkAccounts(200), storage)
+	for it.Next() {
 	}
-	b.Run("init+seek", func(b *testing.B) {
-		b.ResetTimer()
-		seekpos := make([]byte, 20)
-		for i := 0; i < b.N; i++ {
-			b.StopTimer()
-			rand.Read(seekpos)
-			it := layer.newFastAccountIterator()
-			b.StartTimer()
-			it.Seek(common.BytesToHash(seekpos))
-		}
-	})
 }
+*/
diff --git a/core/state/snapshot/snapshot.go b/core/state/snapshot/snapshot.go
index 7650cf2c13a60989f8eb38d90d8745539c6d7635..5f9a8be637e0899acddefe945251fcddc5a592cb 100644
--- a/core/state/snapshot/snapshot.go
+++ b/core/state/snapshot/snapshot.go
@@ -113,9 +113,17 @@ type Snapshot interface {
 type snapshot interface {
 	Snapshot
 
+	// Parent returns the subsequent layer of a snapshot, or nil if the base was
+	// reached.
+	//
+	// Note, the method is an internal helper to avoid type switching between the
+	// disk and diff layers. There is no locking involved.
+	Parent() snapshot
+
 	// Update creates a new layer on top of the existing snapshot diff tree with
-	// the specified data items. Note, the maps are retained by the method to avoid
-	// copying everything.
+	// the specified data items.
+	//
+	// Note, the maps are retained by the method to avoid copying everything.
 	Update(blockRoot common.Hash, accounts map[common.Hash][]byte, storage map[common.Hash]map[common.Hash][]byte) *diffLayer
 
 	// Journal commits an entire diff hierarchy to disk into a single journal entry.
@@ -126,6 +134,9 @@ type snapshot interface {
 	// Stale returns whether this layer has become stale (was flattened across) or
 	// if it's still live.
 	Stale() bool
+
+	// AccountIterator creates an account iterator over an arbitrary layer.
+	AccountIterator(seek common.Hash) AccountIterator
 }
 
 // SnapshotTree is an Ethereum state snapshot tree. It consists of one persistent
@@ -170,15 +181,7 @@ func New(diskdb ethdb.KeyValueStore, triedb *trie.Database, cache int, root comm
 	// Existing snapshot loaded, seed all the layers
 	for head != nil {
 		snap.layers[head.Root()] = head
-
-		switch self := head.(type) {
-		case *diffLayer:
-			head = self.parent
-		case *diskLayer:
-			head = nil
-		default:
-			panic(fmt.Sprintf("unknown data layer: %T", self))
-		}
+		head = head.Parent()
 	}
 	return snap
 }
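The Parent accessor removes the need for the type switch deleted above. For illustration only (not part of the diff), the same walk can be expressed as a generic helper over the internal snapshot interface:

func layerChain(head snapshot) []snapshot {
	// Collect every layer from the given head down to and including the
	// disk layer, relying solely on Parent instead of type switching.
	var chain []snapshot
	for head != nil {
		chain = append(chain, head)
		head = head.Parent()
	}
	return chain
}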
@@ -563,3 +566,9 @@ func (t *Tree) Rebuild(root common.Hash) {
 		root: generateSnapshot(t.diskdb, t.triedb, t.cache, root, wiper),
 	}
 }
+
+// AccountIterator creates a new account iterator for the specified root hash and
+// seeks to a starting account hash.
+func (t *Tree) AccountIterator(root common.Hash, seek common.Hash) (AccountIterator, error) {
+	return newFastAccountIterator(t, root, seek)
+}
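A hedged usage sketch for the new public entry point on the tree: it counts every account reachable from a given state root, using only the Next method the iterator is shown to expose (the helper name is hypothetical and error handling is minimal):

// countAccounts is an illustrative helper, not part of the diff.
func countAccounts(snaps *Tree, root common.Hash) (int, error) {
	it, err := snaps.AccountIterator(root, common.Hash{})
	if err != nil {
		return 0, err // e.g. the requested root is unknown to the tree
	}
	count := 0
	for it.Next() {
		count++
	}
	return count, nil
}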
diff --git a/core/state/snapshot/snapshot_test.go b/core/state/snapshot/snapshot_test.go
index 44b8f3cefda4449db5dac970ed3f0c471c97eff3..2b14828178af91e8cd8214c2b554d9ac282b6916 100644
--- a/core/state/snapshot/snapshot_test.go
+++ b/core/state/snapshot/snapshot_test.go
@@ -18,13 +18,48 @@ package snapshot
 
 import (
 	"fmt"
+	"math/big"
+	"math/rand"
 	"testing"
 
 	"github.com/VictoriaMetrics/fastcache"
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/core/rawdb"
+	"github.com/ethereum/go-ethereum/rlp"
 )
 
+// randomHash generates a random blob of data and returns it as a hash.
+func randomHash() common.Hash {
+	var hash common.Hash
+	if n, err := rand.Read(hash[:]); n != common.HashLength || err != nil {
+		panic(err)
+	}
+	return hash
+}
+
+// randomAccount generates a random account and returns it RLP encoded.
+func randomAccount() []byte {
+	root := randomHash()
+	a := Account{
+		Balance:  big.NewInt(rand.Int63()),
+		Nonce:    rand.Uint64(),
+		Root:     root[:],
+		CodeHash: emptyCode[:],
+	}
+	data, _ := rlp.EncodeToBytes(a)
+	return data
+}
+
+// randomAccountSet generates a set of random accounts with the given strings as
+// the account address hashes.
+func randomAccountSet(hashes ...string) map[common.Hash][]byte {
+	accounts := make(map[common.Hash][]byte)
+	for _, hash := range hashes {
+		accounts[common.HexToHash(hash)] = randomAccount()
+	}
+	return accounts
+}
+
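Since the helpers above return raw RLP blobs, a brief sanity-check sketch (illustrative, not part of the diff; the test name is hypothetical) confirms they decode back into the Account type via the rlp package imported above:

func TestRandomAccountRoundTrip(t *testing.T) {
	for hash, blob := range randomAccountSet("0xa1", "0xa2", "0xa3") {
		// Every generated blob should be valid RLP for the snapshot Account type.
		var acc Account
		if err := rlp.DecodeBytes(blob, &acc); err != nil {
			t.Fatalf("account %x: failed to decode RLP: %v", hash, err)
		}
	}
}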
 // Tests that if a disk layer becomes stale, no active external references will
 // be returned with junk data. This version of the test flattens every diff layer
 // to check internal corner case around the bottom-most memory accumulator.
@@ -46,8 +81,7 @@ func TestDiskLayerExternalInvalidationFullFlatten(t *testing.T) {
 	accounts := map[common.Hash][]byte{
 		common.HexToHash("0xa1"): randomAccount(),
 	}
-	storage := make(map[common.Hash]map[common.Hash][]byte)
-	if err := snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), accounts, storage); err != nil {
+	if err := snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), accounts, nil); err != nil {
 		t.Fatalf("failed to create a diff layer: %v", err)
 	}
 	if n := len(snaps.layers); n != 2 {
@@ -91,11 +125,10 @@ func TestDiskLayerExternalInvalidationPartialFlatten(t *testing.T) {
 	accounts := map[common.Hash][]byte{
 		common.HexToHash("0xa1"): randomAccount(),
 	}
-	storage := make(map[common.Hash]map[common.Hash][]byte)
-	if err := snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), accounts, storage); err != nil {
+	if err := snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), accounts, nil); err != nil {
 		t.Fatalf("failed to create a diff layer: %v", err)
 	}
-	if err := snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), accounts, storage); err != nil {
+	if err := snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), accounts, nil); err != nil {
 		t.Fatalf("failed to create a diff layer: %v", err)
 	}
 	if n := len(snaps.layers); n != 3 {
@@ -140,11 +173,10 @@ func TestDiffLayerExternalInvalidationFullFlatten(t *testing.T) {
 	accounts := map[common.Hash][]byte{
 		common.HexToHash("0xa1"): randomAccount(),
 	}
-	storage := make(map[common.Hash]map[common.Hash][]byte)
-	if err := snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), accounts, storage); err != nil {
+	if err := snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), accounts, nil); err != nil {
 		t.Fatalf("failed to create a diff layer: %v", err)
 	}
-	if err := snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), accounts, storage); err != nil {
+	if err := snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), accounts, nil); err != nil {
 		t.Fatalf("failed to create a diff layer: %v", err)
 	}
 	if n := len(snaps.layers); n != 3 {
@@ -188,14 +220,13 @@ func TestDiffLayerExternalInvalidationPartialFlatten(t *testing.T) {
 	accounts := map[common.Hash][]byte{
 		common.HexToHash("0xa1"): randomAccount(),
 	}
-	storage := make(map[common.Hash]map[common.Hash][]byte)
-	if err := snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), accounts, storage); err != nil {
+	if err := snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), accounts, nil); err != nil {
 		t.Fatalf("failed to create a diff layer: %v", err)
 	}
-	if err := snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), accounts, storage); err != nil {
+	if err := snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), accounts, nil); err != nil {
 		t.Fatalf("failed to create a diff layer: %v", err)
 	}
-	if err := snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"), accounts, storage); err != nil {
+	if err := snaps.Update(common.HexToHash("0x04"), common.HexToHash("0x03"), accounts, nil); err != nil {
 		t.Fatalf("failed to create a diff layer: %v", err)
 	}
 	if n := len(snaps.layers); n != 4 {
diff --git a/core/state/snapshot/wipe_test.go b/core/state/snapshot/wipe_test.go
index f12769a950f73d11daf081e4efc04223334b83f6..cb6e174b3184317e4fa0ecbba4c191af3bb4cc0e 100644
--- a/core/state/snapshot/wipe_test.go
+++ b/core/state/snapshot/wipe_test.go
@@ -25,15 +25,6 @@ import (
 	"github.com/ethereum/go-ethereum/ethdb/memorydb"
 )
 
-// randomHash generates a random blob of data and returns it as a hash.
-func randomHash() common.Hash {
-	var hash common.Hash
-	if n, err := rand.Read(hash[:]); n != common.HashLength || err != nil {
-		panic(err)
-	}
-	return hash
-}
-
 // Tests that given a database with random data content, all parts of a snapshot
 // can be correctly wiped without touching anything else.
 func TestWipe(t *testing.T) {