From e146fbe4e739e5912aadcceb77f9aff803b4a052 Mon Sep 17 00:00:00 2001
From: Martin Holst Swende <martin@swende.se>
Date: Fri, 4 Oct 2019 15:24:01 +0200
Subject: [PATCH] core/state: lazy sorting, snapshot invalidation

---
 core/state/snapshot/difflayer.go      | 177 ++++++----
 core/state/snapshot/difflayer_test.go | 448 ++++++++++++++++++++++++++
 core/state/snapshot/disklayer.go      |  44 ++-
 core/state/snapshot/generate.go       |   5 +-
 core/state/snapshot/snapshot.go       |  17 +-
 core/state/snapshot/sort.go           |  30 ++
 core/state/state_object.go            |   8 +-
 core/state/statedb.go                 |  34 +-
 8 files changed, 671 insertions(+), 92 deletions(-)
 create mode 100644 core/state/snapshot/difflayer_test.go

diff --git a/core/state/snapshot/difflayer.go b/core/state/snapshot/difflayer.go
index f163feb56..c7a65e2a4 100644
--- a/core/state/snapshot/difflayer.go
+++ b/core/state/snapshot/difflayer.go
@@ -40,13 +40,12 @@ type diffLayer struct {
 
 	number uint64      // Block number to which this snapshot diff belongs to
 	root   common.Hash // Root hash to which this snapshot diff belongs to
+	stale  bool        // Signals that the layer became stale (state progressed)
 
-	accountList   []common.Hash                          // List of account for iteration, might not be sorted yet (lazy)
-	accountSorted bool                                   // Flag whether the account list has alreayd been sorted or not
-	accountData   map[common.Hash][]byte                 // Keyed accounts for direct retrival (nil means deleted)
-	storageList   map[common.Hash][]common.Hash          // List of storage slots for iterated retrievals, one per account
-	storageSorted map[common.Hash]bool                   // Flag whether the storage slot list has alreayd been sorted or not
-	storageData   map[common.Hash]map[common.Hash][]byte // Keyed storage slots for direct retrival. one per account (nil means deleted)
+	accountList []common.Hash                          // List of account for iteration. If it exists, it's sorted, otherwise it's nil
+	accountData map[common.Hash][]byte                 // Keyed accounts for direct retrival (nil means deleted)
+	storageList map[common.Hash][]common.Hash          // List of storage slots for iterated retrievals, one per account. Any existing lists are sorted if non-nil
+	storageData map[common.Hash]map[common.Hash][]byte // Keyed storage slots for direct retrival. one per account (nil means deleted)
 
 	lock sync.RWMutex
 }
@@ -62,21 +61,13 @@ func newDiffLayer(parent snapshot, number uint64, root common.Hash, accounts map
 		accountData: accounts,
 		storageData: storage,
 	}
-	// Fill the account hashes and sort them for the iterator
-	accountList := make([]common.Hash, 0, len(accounts))
-	for hash, data := range accounts {
-		accountList = append(accountList, hash)
+	// Determine mem size
+	for _, data := range accounts {
 		dl.memory += uint64(len(data))
 	}
-	sort.Sort(hashes(accountList))
-	dl.accountList = accountList
-	dl.accountSorted = true
-
-	dl.memory += uint64(len(dl.accountList) * common.HashLength)
 
 	// Fill the storage hashes and sort them for the iterator
-	dl.storageList = make(map[common.Hash][]common.Hash, len(storage))
-	dl.storageSorted = make(map[common.Hash]bool, len(storage))
+	dl.storageList = make(map[common.Hash][]common.Hash)
 
 	for accountHash, slots := range storage {
 		// If the slots are nil, sanity check that it's a deleted account
@@ -93,19 +84,11 @@ func newDiffLayer(parent snapshot, number uint64, root common.Hash, accounts map
 		// account was just updated.
 		if account, ok := accounts[accountHash]; account == nil || !ok {
 			log.Error(fmt.Sprintf("storage in %#x exists, but account nil (exists: %v)", accountHash, ok))
-			//panic(fmt.Sprintf("storage in %#x exists, but account nil (exists: %v)", accountHash, ok))
 		}
-		// Fill the storage hashes for this account and sort them for the iterator
-		storageList := make([]common.Hash, 0, len(slots))
-		for storageHash, data := range slots {
-			storageList = append(storageList, storageHash)
+		// Determine mem size
+		for _, data := range slots {
 			dl.memory += uint64(len(data))
 		}
-		sort.Sort(hashes(storageList))
-		dl.storageList[accountHash] = storageList
-		dl.storageSorted[accountHash] = true
-
-		dl.memory += uint64(len(storageList) * common.HashLength)
 	}
 	dl.memory += uint64(len(dl.storageList) * common.HashLength)
 
@@ -119,28 +102,36 @@ func (dl *diffLayer) Info() (uint64, common.Hash) {
 
 // Account directly retrieves the account associated with a particular hash in
 // the snapshot slim data format.
-func (dl *diffLayer) Account(hash common.Hash) *Account {
-	data := dl.AccountRLP(hash)
+func (dl *diffLayer) Account(hash common.Hash) (*Account, error) {
+	data, err := dl.AccountRLP(hash)
+	if err != nil {
+		return nil, err
+	}
 	if len(data) == 0 { // can be both nil and []byte{}
-		return nil
+		return nil, nil
 	}
 	account := new(Account)
 	if err := rlp.DecodeBytes(data, account); err != nil {
 		panic(err)
 	}
-	return account
+	return account, nil
 }
 
 // AccountRLP directly retrieves the account RLP associated with a particular
 // hash in the snapshot slim data format.
-func (dl *diffLayer) AccountRLP(hash common.Hash) []byte {
+func (dl *diffLayer) AccountRLP(hash common.Hash) ([]byte, error) {
 	dl.lock.RLock()
 	defer dl.lock.RUnlock()
 
+	// If the layer was flattened into, consider it invalid (any live reference to
+	// the original should be marked as unusable).
+	if dl.stale {
+		return nil, ErrSnapshotStale
+	}
 	// If the account is known locally, return it. Note, a nil account means it was
 	// deleted, and is a different notion than an unknown account!
 	if data, ok := dl.accountData[hash]; ok {
-		return data
+		return data, nil
 	}
 	// Account unknown to this diff, resolve from parent
 	return dl.parent.AccountRLP(hash)
@@ -149,18 +140,23 @@ func (dl *diffLayer) AccountRLP(hash common.Hash) []byte {
 // Storage directly retrieves the storage data associated with a particular hash,
 // within a particular account. If the slot is unknown to this diff, it's parent
 // is consulted.
-func (dl *diffLayer) Storage(accountHash, storageHash common.Hash) []byte {
+func (dl *diffLayer) Storage(accountHash, storageHash common.Hash) ([]byte, error) {
 	dl.lock.RLock()
 	defer dl.lock.RUnlock()
 
+	// If the layer was flattened into, consider it invalid (any live reference to
+	// the original should be marked as unusable).
+	if dl.stale {
+		return nil, ErrSnapshotStale
+	}
 	// If the account is known locally, try to resolve the slot locally. Note, a nil
 	// account means it was deleted, and is a different notion than an unknown account!
 	if storage, ok := dl.storageData[accountHash]; ok {
 		if storage == nil {
-			return nil
+			return nil, nil
 		}
 		if data, ok := storage[storageHash]; ok {
-			return data
+			return data, nil
 		}
 	}
 	// Account - or slot within - unknown to this diff, resolve from parent
@@ -193,13 +189,17 @@ func (dl *diffLayer) Cap(layers int, memory uint64) (uint64, uint64) {
 	case *diskLayer:
 		return parent.number, dl.number
 	case *diffLayer:
+		// Flatten the parent into the grandparent. The flattening internally obtains a
+		// write lock on grandparent.
+		flattened := parent.flatten().(*diffLayer)
+
 		dl.lock.Lock()
 		defer dl.lock.Unlock()
 
-		dl.parent = parent.flatten()
-		if dl.parent.(*diffLayer).memory < memory {
-			diskNumber, _ := parent.parent.Info()
-			return diskNumber, parent.number
+		dl.parent = flattened
+		if flattened.memory < memory {
+			diskNumber, _ := flattened.parent.Info()
+			return diskNumber, flattened.number
 		}
 	default:
 		panic(fmt.Sprintf("unknown data layer: %T", parent))
@@ -213,10 +213,18 @@ func (dl *diffLayer) Cap(layers int, memory uint64) (uint64, uint64) {
 	parent.lock.RLock()
 	defer parent.lock.RUnlock()
 
-	// Start by temporarilly deleting the current snapshot block marker. This
+	// Start by temporarily deleting the current snapshot block marker. This
 	// ensures that in the case of a crash, the entire snapshot is invalidated.
 	rawdb.DeleteSnapshotBlock(batch)
 
+	// Mark the original base as stale as we're going to create a new wrapper
+	base.lock.Lock()
+	if base.stale {
+		panic("parent disk layer is stale") // we've committed into the same base from two children, boo
+	}
+	base.stale = true
+	base.lock.Unlock()
+
 	// Push all the accounts into the database
 	for hash, data := range parent.accountData {
 		if len(data) > 0 {
@@ -264,15 +272,20 @@ func (dl *diffLayer) Cap(layers int, memory uint64) (uint64, uint64) {
 		}
 	}
 	// Update the snapshot block marker and write any remainder data
-	base.number, base.root = parent.number, parent.root
-
-	rawdb.WriteSnapshotBlock(batch, base.number, base.root)
+	newBase := &diskLayer{
+		root:    parent.root,
+		number:  parent.number,
+		cache:   base.cache,
+		db:      base.db,
+		journal: base.journal,
+	}
+	rawdb.WriteSnapshotBlock(batch, newBase.number, newBase.root)
 	if err := batch.Write(); err != nil {
 		log.Crit("Failed to write leftover snapshot", "err", err)
 	}
-	dl.parent = base
+	dl.parent = newBase
 
-	return base.number, dl.number
+	return newBase.number, dl.number
 }
 
 // flatten pushes all data from this point downwards, flattening everything into
@@ -289,19 +302,25 @@ func (dl *diffLayer) flatten() snapshot {
 	// be smarter about grouping flattens together).
 	parent = parent.flatten().(*diffLayer)
 
+	parent.lock.Lock()
+	defer parent.lock.Unlock()
+
+	// Before actually writing all our data to the parent, first ensure that the
+	// parent hasn't been 'corrupted' by someone else already flattening into it
+	if parent.stale {
+		panic("parent diff layer is stale") // we've flattened into the same parent from two children, boo
+	}
+	parent.stale = true
+
 	// Overwrite all the updated accounts blindly, merge the sorted list
 	for hash, data := range dl.accountData {
 		parent.accountData[hash] = data
 	}
-	parent.accountList = append(parent.accountList, dl.accountList...) // TODO(karalabe): dedup!!
-	parent.accountSorted = false
-
 	// Overwrite all the updates storage slots (individually)
 	for accountHash, storage := range dl.storageData {
 		// If storage didn't exist (or was deleted) in the parent; or if the storage
 		// was freshly deleted in the child, overwrite blindly
 		if parent.storageData[accountHash] == nil || storage == nil {
-			parent.storageList[accountHash] = dl.storageList[accountHash]
 			parent.storageData[accountHash] = storage
 			continue
 		}
@@ -311,14 +330,18 @@ func (dl *diffLayer) flatten() snapshot {
 			comboData[storageHash] = data
 		}
 		parent.storageData[accountHash] = comboData
-		parent.storageList[accountHash] = append(parent.storageList[accountHash], dl.storageList[accountHash]...) // TODO(karalabe): dedup!!
-		parent.storageSorted[accountHash] = false
 	}
 	// Return the combo parent
-	parent.number = dl.number
-	parent.root = dl.root
-	parent.memory += dl.memory
-	return parent
+	return &diffLayer{
+		parent:      parent.parent,
+		number:      dl.number,
+		root:        dl.root,
+		storageList: parent.storageList,
+		storageData: parent.storageData,
+		accountList: parent.accountList,
+		accountData: parent.accountData,
+		memory:      parent.memory + dl.memory,
+	}
 }
 
 // Journal commits an entire diff hierarchy to disk into a single journal file.
@@ -335,3 +358,45 @@ func (dl *diffLayer) Journal() error {
 	writer.Close()
 	return nil
 }
+
+// AccountList returns a sorted list of all accounts in this difflayer.
+func (dl *diffLayer) AccountList() []common.Hash {
+	dl.lock.Lock()
+	defer dl.lock.Unlock()
+	if dl.accountList != nil {
+		return dl.accountList
+	}
+	accountList := make([]common.Hash, len(dl.accountData))
+	i := 0
+	for k, _ := range dl.accountData {
+		accountList[i] = k
+		i++
+		// This would be a pretty good opportunity to also
+		// calculate the size, if we want to
+	}
+	sort.Sort(hashes(accountList))
+	dl.accountList = accountList
+	return dl.accountList
+}
+
+// StorageList returns a sorted list of all storage slot hashes
+// in this difflayer for the given account.
+func (dl *diffLayer) StorageList(accountHash common.Hash) []common.Hash {
+	dl.lock.Lock()
+	defer dl.lock.Unlock()
+	if dl.storageList[accountHash] != nil {
+		return dl.storageList[accountHash]
+	}
+	accountStorageMap := dl.storageData[accountHash]
+	accountStorageList := make([]common.Hash, len(accountStorageMap))
+	i := 0
+	for k, _ := range accountStorageMap {
+		accountStorageList[i] = k
+		i++
+		// This would be a pretty good opportunity to also
+		// calculate the size, if we want to
+	}
+	sort.Sort(hashes(accountStorageList))
+	dl.storageList[accountHash] = accountStorageList
+	return accountStorageList
+}
diff --git a/core/state/snapshot/difflayer_test.go b/core/state/snapshot/difflayer_test.go
new file mode 100644
index 000000000..5a718c617
--- /dev/null
+++ b/core/state/snapshot/difflayer_test.go
@@ -0,0 +1,448 @@
+// Copyright 2019 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package snapshot
+
+import (
+	"bytes"
+	"fmt"
+	"math/big"
+	"math/rand"
+	"testing"
+	"time"
+
+	"github.com/allegro/bigcache"
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/core/rawdb"
+	"github.com/ethereum/go-ethereum/rlp"
+)
+
+func randomAccount() []byte {
+	root := randomHash()
+	a := Account{
+		Balance:  big.NewInt(rand.Int63()),
+		Nonce:    rand.Uint64(),
+		Root:     root[:],
+		CodeHash: emptyCode[:],
+	}
+	data, _ := rlp.EncodeToBytes(a)
+	return data
+}
+
+// TestMergeBasics tests some simple merges
+func TestMergeBasics(t *testing.T) {
+	var (
+		accounts = make(map[common.Hash][]byte)
+		storage  = make(map[common.Hash]map[common.Hash][]byte)
+	)
+	// Fill up a parent
+	for i := 0; i < 100; i++ {
+		h := randomHash()
+		data := randomAccount()
+
+		accounts[h] = data
+		if rand.Intn(20) < 10 {
+			accStorage := make(map[common.Hash][]byte)
+			value := make([]byte, 32)
+			rand.Read(value)
+			accStorage[randomHash()] = value
+			storage[h] = accStorage
+		}
+	}
+	// Add some (identical) layers on top
+	parent := newDiffLayer(emptyLayer{}, 1, common.Hash{}, accounts, storage)
+	child := newDiffLayer(parent, 1, common.Hash{}, accounts, storage)
+	child = newDiffLayer(child, 1, common.Hash{}, accounts, storage)
+	child = newDiffLayer(child, 1, common.Hash{}, accounts, storage)
+	child = newDiffLayer(child, 1, common.Hash{}, accounts, storage)
+	// And flatten
+	merged := (child.flatten()).(*diffLayer)
+
+	{ // Check account lists
+		// Should be zero/nil first
+		if got, exp := len(merged.accountList), 0; got != exp {
+			t.Errorf("accountList wrong, got %v exp %v", got, exp)
+		}
+		// Then set when we call AccountList
+		if got, exp := len(merged.AccountList()), len(accounts); got != exp {
+			t.Errorf("AccountList() wrong, got %v exp %v", got, exp)
+		}
+		if got, exp := len(merged.accountList), len(accounts); got != exp {
+			t.Errorf("accountList [2] wrong, got %v exp %v", got, exp)
+		}
+	}
+	{ // Check storage lists
+		i := 0
+		for aHash, sMap := range storage {
+			if got, exp := len(merged.storageList), i; got != exp {
+				t.Errorf("[1] storageList wrong, got %v exp %v", got, exp)
+			}
+			if got, exp := len(merged.StorageList(aHash)), len(sMap); got != exp {
+				t.Errorf("[2] StorageList() wrong, got %v exp %v", got, exp)
+			}
+			if got, exp := len(merged.storageList[aHash]), len(sMap); got != exp {
+				t.Errorf("storageList wrong, got %v exp %v", got, exp)
+			}
+			i++
+		}
+	}
+}
+
+// TestMergeDelete tests some deletion
+func TestMergeDelete(t *testing.T) {
+	var (
+		storage = make(map[common.Hash]map[common.Hash][]byte)
+	)
+	// Fill up a parent
+	h1 := common.HexToHash("0x01")
+	h2 := common.HexToHash("0x02")
+
+	flip := func() map[common.Hash][]byte {
+		accs := make(map[common.Hash][]byte)
+		accs[h1] = randomAccount()
+		accs[h2] = nil
+		return accs
+	}
+	flop := func() map[common.Hash][]byte {
+		accs := make(map[common.Hash][]byte)
+		accs[h1] = nil
+		accs[h2] = randomAccount()
+		return accs
+	}
+
+	// Add some flip-flopping layers on top
+	parent := newDiffLayer(emptyLayer{}, 1, common.Hash{}, flip(), storage)
+	child := parent.Update(common.Hash{}, flop(), storage)
+	child = child.Update(common.Hash{}, flip(), storage)
+	child = child.Update(common.Hash{}, flop(), storage)
+	child = child.Update(common.Hash{}, flip(), storage)
+	child = child.Update(common.Hash{}, flop(), storage)
+	child = child.Update(common.Hash{}, flip(), storage)
+
+	if data, _ := child.Account(h1); data == nil {
+		t.Errorf("last diff layer: expected %x to be non-nil", h1)
+	}
+	if data, _ := child.Account(h2); data != nil {
+		t.Errorf("last diff layer: expected %x to be nil", h2)
+	}
+	// And flatten
+	merged := (child.flatten()).(*diffLayer)
+
+	// check number
+	if got, exp := merged.number, child.number; got != exp {
+		t.Errorf("merged layer: wrong number - exp %d got %d", exp, got)
+	}
+	if data, _ := merged.Account(h1); data == nil {
+		t.Errorf("merged layer: expected %x to be non-nil", h1)
+	}
+	if data, _ := merged.Account(h2); data != nil {
+		t.Errorf("merged layer: expected %x to be nil", h2)
+	}
+	// If we add more granular metering of memory, we can enable this again,
+	// but it's not implemented for now
+	//if got, exp := merged.memory, child.memory; got != exp {
+	//	t.Errorf("mem wrong, got %d, exp %d", got, exp)
+	//}
+}
+
+// This tests that if we create a new account, and set a slot, and then merge
+// it, the lists will be correct.
+func TestInsertAndMerge(t *testing.T) {
+	// Fill up a parent
+	var (
+		acc    = common.HexToHash("0x01")
+		slot   = common.HexToHash("0x02")
+		parent *diffLayer
+		child  *diffLayer
+	)
+	{
+		var accounts = make(map[common.Hash][]byte)
+		var storage = make(map[common.Hash]map[common.Hash][]byte)
+		parent = newDiffLayer(emptyLayer{}, 1, common.Hash{}, accounts, storage)
+	}
+	{
+		var accounts = make(map[common.Hash][]byte)
+		var storage = make(map[common.Hash]map[common.Hash][]byte)
+		accounts[acc] = randomAccount()
+		accstorage := make(map[common.Hash][]byte)
+		storage[acc] = accstorage
+		storage[acc][slot] = []byte{0x01}
+		child = newDiffLayer(parent, 2, common.Hash{}, accounts, storage)
+	}
+	// And flatten
+	merged := (child.flatten()).(*diffLayer)
+	{ // Check that slot value is present
+		got, _ := merged.Storage(acc, slot)
+		if exp := []byte{0x01}; bytes.Compare(got, exp) != 0 {
+			t.Errorf("merged slot value wrong, got %x, exp %x", got, exp)
+		}
+	}
+}
+
+// TestCapTree tests some functionality regarding capping/flattening
+func TestCapTree(t *testing.T) {
+
+	var (
+		storage = make(map[common.Hash]map[common.Hash][]byte)
+	)
+	setAccount := func(accKey string) map[common.Hash][]byte {
+		return map[common.Hash][]byte{
+			common.HexToHash(accKey): randomAccount(),
+		}
+	}
+	// the bottom-most layer, aside from the 'disk layer'
+	cache, _ := bigcache.NewBigCache(bigcache.Config{ // TODO(karalabe): dedup
+		Shards:             1,
+		LifeWindow:         time.Hour,
+		MaxEntriesInWindow: 1 * 1024,
+		MaxEntrySize:       1,
+		HardMaxCacheSize:   1,
+	})
+
+	base := &diskLayer{
+		journal: "",
+		db:      rawdb.NewMemoryDatabase(),
+		cache:   cache,
+		number:  0,
+		root:    common.HexToHash("0x01"),
+	}
+	// The lowest difflayer
+	a1 := base.Update(common.HexToHash("0xa1"), setAccount("0xa1"), storage)
+
+	a2 := a1.Update(common.HexToHash("0xa2"), setAccount("0xa2"), storage)
+	b2 := a1.Update(common.HexToHash("0xb2"), setAccount("0xb2"), storage)
+
+	a3 := a2.Update(common.HexToHash("0xa3"), setAccount("0xa3"), storage)
+	b3 := b2.Update(common.HexToHash("0xb3"), setAccount("0xb3"), storage)
+
+	checkExist := func(layer *diffLayer, key string) error {
+		accountKey := common.HexToHash(key)
+		data, _ := layer.Account(accountKey)
+		if data == nil {
+			return fmt.Errorf("expected %x to exist, got nil", accountKey)
+		}
+		return nil
+	}
+	shouldErr := func(layer *diffLayer, key string) error {
+		accountKey := common.HexToHash(key)
+		data, err := layer.Account(accountKey)
+		if err == nil {
+			return fmt.Errorf("expected error, got data %x", data)
+		}
+		return nil
+	}
+
+	// check basics
+	if err := checkExist(b3, "0xa1"); err != nil {
+		t.Error(err)
+	}
+	if err := checkExist(b3, "0xb2"); err != nil {
+		t.Error(err)
+	}
+	if err := checkExist(b3, "0xb3"); err != nil {
+		t.Error(err)
+	}
+	// Now, merge the a-chain
+	diskNum, diffNum := a3.Cap(0, 1024)
+	if diskNum != 0 {
+		t.Errorf("disk layer err, got %d exp %d", diskNum, 0)
+	}
+	if diffNum != 2 {
+		t.Errorf("diff layer err, got %d exp %d", diffNum, 2)
+	}
+	// At this point, a2 got merged into a1. Thus, a1 is now modified,
+	// and as a1 is the parent of b2, b2 should no longer be able to iterate into parent
+
+	// These should still be accessible
+	if err := checkExist(b3, "0xb2"); err != nil {
+		t.Error(err)
+	}
+	if err := checkExist(b3, "0xb3"); err != nil {
+		t.Error(err)
+	}
+	//b2ParentNum, _ := b2.parent.Info()
+	//if b2.parent.invalid == false
+	//	t.Errorf("err, exp parent to be invalid, got %v", b2.parent, b2ParentNum)
+	//}
+	// But these would need iteration into the modified parent:
+	if err := shouldErr(b3, "0xa1"); err != nil {
+		t.Error(err)
+	}
+	if err := shouldErr(b3, "0xa2"); err != nil {
+		t.Error(err)
+	}
+	if err := shouldErr(b3, "0xa3"); err != nil {
+		t.Error(err)
+	}
+}
+
+type emptyLayer struct{}
+
+func (emptyLayer) Update(blockRoot common.Hash, accounts map[common.Hash][]byte, storage map[common.Hash]map[common.Hash][]byte) *diffLayer {
+	panic("implement me")
+}
+
+func (emptyLayer) Cap(layers int, memory uint64) (uint64, uint64) {
+	panic("implement me")
+}
+
+func (emptyLayer) Journal() error {
+	panic("implement me")
+}
+
+func (emptyLayer) Info() (uint64, common.Hash) {
+	return 0, common.Hash{}
+}
+func (emptyLayer) Number() uint64 {
+	return 0
+}
+
+func (emptyLayer) Account(hash common.Hash) (*Account, error) {
+	return nil, nil
+}
+
+func (emptyLayer) AccountRLP(hash common.Hash) ([]byte, error) {
+	return nil, nil
+}
+
+func (emptyLayer) Storage(accountHash, storageHash common.Hash) ([]byte, error) {
+	return nil, nil
+}
+
+// BenchmarkSearch checks how long it takes to find a non-existing key
+// BenchmarkSearch-6   	  200000	     10481 ns/op (1K per layer)
+// BenchmarkSearch-6   	  200000	     10760 ns/op (10K per layer)
+// BenchmarkSearch-6   	  100000	     17866 ns/op
+//
+// BenchmarkSearch-6   	  500000	      3723 ns/op (10k per layer, only top-level RLock()
+func BenchmarkSearch(b *testing.B) {
+	// First, we set up 128 diff layers, with 1K items each
+
+	blocknum := uint64(0)
+	fill := func(parent snapshot) *diffLayer {
+		accounts := make(map[common.Hash][]byte)
+		storage := make(map[common.Hash]map[common.Hash][]byte)
+
+		for i := 0; i < 10000; i++ {
+			accounts[randomHash()] = randomAccount()
+		}
+		blocknum++
+		return newDiffLayer(parent, blocknum, common.Hash{}, accounts, storage)
+	}
+
+	var layer snapshot
+	layer = emptyLayer{}
+	for i := 0; i < 128; i++ {
+		layer = fill(layer)
+	}
+
+	key := common.Hash{}
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		layer.AccountRLP(key)
+	}
+}
+
+// BenchmarkSearchSlot checks how long it takes to find a non-existing key
+// - Number of layers: 128
+// - Each layers contains the account, with a couple of storage slots
+// BenchmarkSearchSlot-6   	  100000	     14554 ns/op
+// BenchmarkSearchSlot-6   	  100000	     22254 ns/op (when checking parent root using mutex)
+// BenchmarkSearchSlot-6   	  100000	     14551 ns/op (when checking parent number using atomic)
+func BenchmarkSearchSlot(b *testing.B) {
+	// First, we set up 128 diff layers, with 1K items each
+
+	blocknum := uint64(0)
+	accountKey := common.Hash{}
+	storageKey := common.HexToHash("0x1337")
+	accountRLP := randomAccount()
+	fill := func(parent snapshot) *diffLayer {
+		accounts := make(map[common.Hash][]byte)
+		accounts[accountKey] = accountRLP
+		storage := make(map[common.Hash]map[common.Hash][]byte)
+
+		accStorage := make(map[common.Hash][]byte)
+		for i := 0; i < 5; i++ {
+			value := make([]byte, 32)
+			rand.Read(value)
+			accStorage[randomHash()] = value
+			storage[accountKey] = accStorage
+		}
+		blocknum++
+		return newDiffLayer(parent, blocknum, common.Hash{}, accounts, storage)
+	}
+
+	var layer snapshot
+	layer = emptyLayer{}
+	for i := 0; i < 128; i++ {
+		layer = fill(layer)
+	}
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		layer.Storage(accountKey, storageKey)
+	}
+}
+
+// With accountList and sorting
+//BenchmarkFlatten-6   	      50	  29890856 ns/op
+//
+// Without sorting and tracking accountlist
+// BenchmarkFlatten-6   	     300	   5511511 ns/op
+func BenchmarkFlatten(b *testing.B) {
+
+	fill := func(parent snapshot, blocknum int) *diffLayer {
+		accounts := make(map[common.Hash][]byte)
+		storage := make(map[common.Hash]map[common.Hash][]byte)
+
+		for i := 0; i < 100; i++ {
+			accountKey := randomHash()
+			accounts[accountKey] = randomAccount()
+
+			accStorage := make(map[common.Hash][]byte)
+			for i := 0; i < 20; i++ {
+				value := make([]byte, 32)
+				rand.Read(value)
+				accStorage[randomHash()] = value
+
+			}
+			storage[accountKey] = accStorage
+		}
+		return newDiffLayer(parent, uint64(blocknum), common.Hash{}, accounts, storage)
+	}
+
+	b.ResetTimer()
+
+	for i := 0; i < b.N; i++ {
+		b.StopTimer()
+		var layer snapshot
+		layer = emptyLayer{}
+		for i := 1; i < 128; i++ {
+			layer = fill(layer, i)
+		}
+		b.StartTimer()
+
+		for i := 1; i < 128; i++ {
+			dl, ok := layer.(*diffLayer)
+			if !ok {
+				break
+			}
+
+			layer = dl.flatten()
+		}
+		b.StopTimer()
+	}
+}
diff --git a/core/state/snapshot/disklayer.go b/core/state/snapshot/disklayer.go
index 0406d298f..a9839f01a 100644
--- a/core/state/snapshot/disklayer.go
+++ b/core/state/snapshot/disklayer.go
@@ -17,6 +17,8 @@
 package snapshot
 
 import (
+	"sync"
+
 	"github.com/allegro/bigcache"
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/core/rawdb"
@@ -32,6 +34,9 @@ type diskLayer struct {
 
 	number uint64      // Block number of the base snapshot
 	root   common.Hash // Root hash of the base snapshot
+	stale  bool        // Signals that the layer became stale (state progressed)
+
+	lock sync.RWMutex
 }
 
 // Info returns the block number and root hash for which this snapshot was made.
@@ -41,28 +46,39 @@ func (dl *diskLayer) Info() (uint64, common.Hash) {
 
 // Account directly retrieves the account associated with a particular hash in
 // the snapshot slim data format.
-func (dl *diskLayer) Account(hash common.Hash) *Account {
-	data := dl.AccountRLP(hash)
+func (dl *diskLayer) Account(hash common.Hash) (*Account, error) {
+	data, err := dl.AccountRLP(hash)
+	if err != nil {
+		return nil, err
+	}
 	if len(data) == 0 { // can be both nil and []byte{}
-		return nil
+		return nil, nil
 	}
 	account := new(Account)
 	if err := rlp.DecodeBytes(data, account); err != nil {
 		panic(err)
 	}
-	return account
+	return account, nil
 }
 
 // AccountRLP directly retrieves the account RLP associated with a particular
 // hash in the snapshot slim data format.
-func (dl *diskLayer) AccountRLP(hash common.Hash) []byte {
+func (dl *diskLayer) AccountRLP(hash common.Hash) ([]byte, error) {
+	dl.lock.RLock()
+	defer dl.lock.RUnlock()
+
+	// If the layer was flattened into, consider it invalid (any live reference to
+	// the original should be marked as unusable).
+	if dl.stale {
+		return nil, ErrSnapshotStale
+	}
 	key := string(hash[:])
 
 	// Try to retrieve the account from the memory cache
 	if blob, err := dl.cache.Get(key); err == nil {
 		snapshotCleanHitMeter.Mark(1)
 		snapshotCleanReadMeter.Mark(int64(len(blob)))
-		return blob
+		return blob, nil
 	}
 	// Cache doesn't contain account, pull from disk and cache for later
 	blob := rawdb.ReadAccountSnapshot(dl.db, hash)
@@ -71,19 +87,27 @@ func (dl *diskLayer) AccountRLP(hash common.Hash) []byte {
 	snapshotCleanMissMeter.Mark(1)
 	snapshotCleanWriteMeter.Mark(int64(len(blob)))
 
-	return blob
+	return blob, nil
 }
 
 // Storage directly retrieves the storage data associated with a particular hash,
 // within a particular account.
-func (dl *diskLayer) Storage(accountHash, storageHash common.Hash) []byte {
+func (dl *diskLayer) Storage(accountHash, storageHash common.Hash) ([]byte, error) {
+	dl.lock.RLock()
+	defer dl.lock.RUnlock()
+
+	// If the layer was flattened into, consider it invalid (any live reference to
+	// the original should be marked as unusable).
+	if dl.stale {
+		return nil, ErrSnapshotStale
+	}
 	key := string(append(accountHash[:], storageHash[:]...))
 
 	// Try to retrieve the storage slot from the memory cache
 	if blob, err := dl.cache.Get(key); err == nil {
 		snapshotCleanHitMeter.Mark(1)
 		snapshotCleanReadMeter.Mark(int64(len(blob)))
-		return blob
+		return blob, nil
 	}
 	// Cache doesn't contain storage slot, pull from disk and cache for later
 	blob := rawdb.ReadStorageSnapshot(dl.db, accountHash, storageHash)
@@ -92,7 +116,7 @@ func (dl *diskLayer) Storage(accountHash, storageHash common.Hash) []byte {
 	snapshotCleanMissMeter.Mark(1)
 	snapshotCleanWriteMeter.Mark(int64(len(blob)))
 
-	return blob
+	return blob, nil
 }
 
 // Update creates a new layer on top of the existing snapshot diff tree with
diff --git a/core/state/snapshot/generate.go b/core/state/snapshot/generate.go
index 0d451fe50..4a66e0626 100644
--- a/core/state/snapshot/generate.go
+++ b/core/state/snapshot/generate.go
@@ -135,6 +135,7 @@ func generateSnapshot(db ethdb.KeyValueStore, journal string, headNumber uint64,
 			curStorageNodes int
 			curAccountSize  common.StorageSize
 			curStorageSize  common.StorageSize
+			accountHash     = common.BytesToHash(accIt.Key)
 		)
 		var acc struct {
 			Nonce    uint64
@@ -148,7 +149,7 @@ func generateSnapshot(db ethdb.KeyValueStore, journal string, headNumber uint64,
 		data := AccountRLP(acc.Nonce, acc.Balance, acc.Root, acc.CodeHash)
 		curAccountSize += common.StorageSize(1 + common.HashLength + len(data))
 
-		rawdb.WriteAccountSnapshot(batch, common.BytesToHash(accIt.Key), data)
+		rawdb.WriteAccountSnapshot(batch, accountHash, data)
 		if batch.ValueSize() > ethdb.IdealBatchSize {
 			batch.Write()
 			batch.Reset()
@@ -163,7 +164,7 @@ func generateSnapshot(db ethdb.KeyValueStore, journal string, headNumber uint64,
 				curStorageSize += common.StorageSize(1 + 2*common.HashLength + len(storeIt.Value))
 				curStorageCount++
 
-				rawdb.WriteStorageSnapshot(batch, common.BytesToHash(accIt.Key), common.BytesToHash(storeIt.Key), storeIt.Value)
+				rawdb.WriteStorageSnapshot(batch, accountHash, common.BytesToHash(storeIt.Key), storeIt.Value)
 				if batch.ValueSize() > ethdb.IdealBatchSize {
 					batch.Write()
 					batch.Reset()
diff --git a/core/state/snapshot/snapshot.go b/core/state/snapshot/snapshot.go
index 6d4df96da..6a21d57dc 100644
--- a/core/state/snapshot/snapshot.go
+++ b/core/state/snapshot/snapshot.go
@@ -38,6 +38,11 @@ var (
 	snapshotCleanMissMeter  = metrics.NewRegisteredMeter("state/snapshot/clean/miss", nil)
 	snapshotCleanReadMeter  = metrics.NewRegisteredMeter("state/snapshot/clean/read", nil)
 	snapshotCleanWriteMeter = metrics.NewRegisteredMeter("state/snapshot/clean/write", nil)
+
+	// ErrSnapshotStale is returned from data accessors if the underlying snapshot
+	// layer had been invalidated due to the chain progressing forward far enough
+	// to not maintain the layer's original state.
+	ErrSnapshotStale = errors.New("snapshot stale")
 )
 
 // Snapshot represents the functionality supported by a snapshot storage layer.
@@ -47,15 +52,15 @@ type Snapshot interface {
 
 	// Account directly retrieves the account associated with a particular hash in
 	// the snapshot slim data format.
-	Account(hash common.Hash) *Account
+	Account(hash common.Hash) (*Account, error)
 
 	// AccountRLP directly retrieves the account RLP associated with a particular
 	// hash in the snapshot slim data format.
-	AccountRLP(hash common.Hash) []byte
+	AccountRLP(hash common.Hash) ([]byte, error)
 
 	// Storage directly retrieves the storage data associated with a particular hash,
 	// within a particular account.
-	Storage(accountHash, storageHash common.Hash) []byte
+	Storage(accountHash, storageHash common.Hash) ([]byte, error)
 }
 
 // snapshot is the internal version of the snapshot data layer that supports some
@@ -80,7 +85,7 @@ type snapshot interface {
 }
 
 // SnapshotTree is an Ethereum state snapshot tree. It consists of one persistent
-// base layer backed by a key-value store, on top of which arbitrarilly many in-
+// base layer backed by a key-value store, on top of which arbitrarily many in-
 // memory diff layers are topped. The memory diffs can form a tree with branching,
 // but the disk layer is singleton and common to all. If a reorg goes deeper than
 // the disk layer, everything needs to be deleted.
@@ -220,7 +225,7 @@ func loadSnapshot(db ethdb.KeyValueStore, journal string, headNumber uint64, hea
 	if _, err := os.Stat(journal); os.IsNotExist(err) {
 		// Journal doesn't exist, don't worry if it's not supposed to
 		if number != headNumber || root != headRoot {
-			return nil, fmt.Errorf("snapshot journal missing, head does't match snapshot: #%d [%#x] vs. #%d [%#x]",
+			return nil, fmt.Errorf("snapshot journal missing, head doesn't match snapshot: #%d [%#x] vs. #%d [%#x]",
 				headNumber, headRoot, number, root)
 		}
 		return base, nil
@@ -237,7 +242,7 @@ func loadSnapshot(db ethdb.KeyValueStore, journal string, headNumber uint64, hea
 	// Journal doesn't exist, don't worry if it's not supposed to
 	number, root = snapshot.Info()
 	if number != headNumber || root != headRoot {
-		return nil, fmt.Errorf("head does't match snapshot: #%d [%#x] vs. #%d [%#x]",
+		return nil, fmt.Errorf("head doesn't match snapshot: #%d [%#x] vs. #%d [%#x]",
 			headNumber, headRoot, number, root)
 	}
 	return snapshot, nil
diff --git a/core/state/snapshot/sort.go b/core/state/snapshot/sort.go
index 04729c60b..ee7cc4990 100644
--- a/core/state/snapshot/sort.go
+++ b/core/state/snapshot/sort.go
@@ -60,3 +60,33 @@ func merge(a, b []common.Hash) []common.Hash {
 	}
 	return result
 }
+
+// dedupMerge combines two sorted lists of hashes into a combo sorted one,
+// and removes duplicates in the process
+func dedupMerge(a, b []common.Hash) []common.Hash {
+	result := make([]common.Hash, len(a)+len(b))
+	i := 0
+	for len(a) > 0 && len(b) > 0 {
+		if diff := bytes.Compare(a[0][:], b[0][:]); diff < 0 {
+			result[i] = a[0]
+			a = a[1:]
+		} else {
+			result[i] = b[0]
+			b = b[1:]
+			// If they were equal, progress a too
+			if diff == 0 {
+				a = a[1:]
+			}
+		}
+		i++
+	}
+	for j := 0; j < len(a); j++ {
+		result[i] = a[j]
+		i++
+	}
+	for j := 0; j < len(b); j++ {
+		result[i] = b[j]
+		i++
+	}
+	return result[:i]
+}
diff --git a/core/state/state_object.go b/core/state/state_object.go
index 98be56671..d10caa831 100644
--- a/core/state/state_object.go
+++ b/core/state/state_object.go
@@ -204,13 +204,13 @@ func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Has
 		if metrics.EnabledExpensive {
 			defer func(start time.Time) { s.db.SnapshotStorageReads += time.Since(start) }(time.Now())
 		}
-		enc = s.db.snap.Storage(s.addrHash, crypto.Keccak256Hash(key[:]))
-	} else {
-		// Track the amount of time wasted on reading the storage trie
+		enc, err = s.db.snap.Storage(s.addrHash, crypto.Keccak256Hash(key[:]))
+	}
+	// If snapshot unavailable or reading from it failed, load from the database
+	if s.db.snap == nil || err != nil {
 		if metrics.EnabledExpensive {
 			defer func(start time.Time) { s.db.StorageReads += time.Since(start) }(time.Now())
 		}
-		// Otherwise load the value from the database
 		if enc, err = s.getTrie(db).TryGet(key[:]); err != nil {
 			s.setError(err)
 			return common.Hash{}
diff --git a/core/state/statedb.go b/core/state/statedb.go
index 0fb1095ce..7d7499892 100644
--- a/core/state/statedb.go
+++ b/core/state/statedb.go
@@ -511,25 +511,31 @@ func (s *StateDB) getDeletedStateObject(addr common.Address) *stateObject {
 		return obj
 	}
 	// If no live objects are available, attempt to use snapshots
-	var data Account
+	var (
+		data Account
+		err  error
+	)
 	if s.snap != nil {
 		if metrics.EnabledExpensive {
 			defer func(start time.Time) { s.SnapshotAccountReads += time.Since(start) }(time.Now())
 		}
-		acc := s.snap.Account(crypto.Keccak256Hash(addr[:]))
-		if acc == nil {
-			return nil
-		}
-		data.Nonce, data.Balance, data.CodeHash = acc.Nonce, acc.Balance, acc.CodeHash
-		if len(data.CodeHash) == 0 {
-			data.CodeHash = emptyCodeHash
-		}
-		data.Root = common.BytesToHash(acc.Root)
-		if data.Root == (common.Hash{}) {
-			data.Root = emptyRoot
+		var acc *snapshot.Account
+		if acc, err = s.snap.Account(crypto.Keccak256Hash(addr[:])); err == nil {
+			if acc == nil {
+				return nil
+			}
+			data.Nonce, data.Balance, data.CodeHash = acc.Nonce, acc.Balance, acc.CodeHash
+			if len(data.CodeHash) == 0 {
+				data.CodeHash = emptyCodeHash
+			}
+			data.Root = common.BytesToHash(acc.Root)
+			if data.Root == (common.Hash{}) {
+				data.Root = emptyRoot
+			}
 		}
-	} else {
-		// Snapshot unavailable, fall back to the trie
+	}
+	// If snapshot unavailable or reading from it failed, load from the database
+	if s.snap == nil || err != nil {
 		if metrics.EnabledExpensive {
 			defer func(start time.Time) { s.AccountReads += time.Since(start) }(time.Now())
 		}
-- 
GitLab