From fae165a5def1a335594cf6761164e31fa4e8d27d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?P=C3=A9ter=20Szil=C3=A1gyi?= <peterke@gmail.com>
Date: Wed, 28 Apr 2021 23:09:15 +0300
Subject: [PATCH] core, eth, ethdb, trie: simplify range proofs

---
 core/rawdb/table.go                           |   5 -
 core/state/snapshot/generate.go               |   2 +-
 eth/protocols/snap/sync.go                    | 121 +++++++++---------
 ethdb/batch.go                                |  28 +++-
 ethdb/leveldb/leveldb.go                      |   9 +-
 ethdb/memorydb/memorydb.go                    |   9 +-
 tests/fuzzers/rangeproof/rangeproof-fuzzer.go |   9 +-
 tests/fuzzers/stacktrie/trie_fuzzer.go        |   1 -
 trie/notary.go                                |  57 ---------
 trie/proof.go                                 |  82 +++++-------
 trie/proof_test.go                            |  62 ++++-----
 trie/trie_test.go                             |   1 -
 12 files changed, 149 insertions(+), 237 deletions(-)
 delete mode 100644 trie/notary.go

diff --git a/core/rawdb/table.go b/core/rawdb/table.go
index 4daa6b534..323ef6293 100644
--- a/core/rawdb/table.go
+++ b/core/rawdb/table.go
@@ -176,11 +176,6 @@ func (b *tableBatch) Delete(key []byte) error {
 	return b.batch.Delete(append([]byte(b.prefix), key...))
 }
 
-// KeyCount retrieves the number of keys queued up for writing.
-func (b *tableBatch) KeyCount() int {
-	return b.batch.KeyCount()
-}
-
 // ValueSize retrieves the amount of data queued up for writing.
 func (b *tableBatch) ValueSize() int {
 	return b.batch.ValueSize()
diff --git a/core/state/snapshot/generate.go b/core/state/snapshot/generate.go
index 78fca45e4..8992d3f91 100644
--- a/core/state/snapshot/generate.go
+++ b/core/state/snapshot/generate.go
@@ -368,7 +368,7 @@ func (dl *diskLayer) proveRange(stats *generatorStats, root common.Hash, prefix
 	}
 	// Verify the snapshot segment with range prover, ensure that all flat states
 	// in this range correspond to merkle trie.
-	_, cont, err := trie.VerifyRangeProof(root, origin, last, keys, vals, proof)
+	cont, err := trie.VerifyRangeProof(root, origin, last, keys, vals, proof)
 	return &proofResult{
 			keys:     keys,
 			vals:     vals,
diff --git a/eth/protocols/snap/sync.go b/eth/protocols/snap/sync.go
index 287ac8d72..d9c0cb9b1 100644
--- a/eth/protocols/snap/sync.go
+++ b/eth/protocols/snap/sync.go
@@ -202,9 +202,8 @@ type storageResponse struct {
 	accounts []common.Hash // Account hashes requested, may be only partially filled
 	roots    []common.Hash // Storage roots requested, may be only partially filled
 
-	hashes [][]common.Hash       // Storage slot hashes in the returned range
-	slots  [][][]byte            // Storage slot values in the returned range
-	nodes  []ethdb.KeyValueStore // Database containing the reconstructed trie nodes
+	hashes [][]common.Hash // Storage slot hashes in the returned range
+	slots  [][][]byte      // Storage slot values in the returned range
 
 	cont bool // Whether the last storage range has a continuation
 }
@@ -680,12 +679,22 @@ func (s *Syncer) loadSyncStatus() {
 			}
 			s.tasks = progress.Tasks
 			for _, task := range s.tasks {
-				task.genBatch = s.db.NewBatch()
+				task.genBatch = ethdb.HookedBatch{
+					Batch: s.db.NewBatch(),
+					OnPut: func(key []byte, value []byte) {
+						s.accountBytes += common.StorageSize(len(key) + len(value))
+					},
+				}
 				task.genTrie = trie.NewStackTrie(task.genBatch)
 
 				for _, subtasks := range task.SubTasks {
 					for _, subtask := range subtasks {
-						subtask.genBatch = s.db.NewBatch()
+						subtask.genBatch = ethdb.HookedBatch{
+							Batch: s.db.NewBatch(),
+							OnPut: func(key []byte, value []byte) {
+								s.storageBytes += common.StorageSize(len(key) + len(value))
+							},
+						}
 						subtask.genTrie = trie.NewStackTrie(task.genBatch)
 					}
 				}
@@ -729,7 +738,12 @@ func (s *Syncer) loadSyncStatus() {
 			// Make sure we don't overflow if the step is not a proper divisor
 			last = common.HexToHash("0xffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff")
 		}
-		batch := s.db.NewBatch()
+		batch := ethdb.HookedBatch{
+			Batch: s.db.NewBatch(),
+			OnPut: func(key []byte, value []byte) {
+				s.accountBytes += common.StorageSize(len(key) + len(value))
+			},
+		}
 		s.tasks = append(s.tasks, &accountTask{
 			Next:     next,
 			Last:     last,
@@ -746,19 +760,14 @@ func (s *Syncer) loadSyncStatus() {
 func (s *Syncer) saveSyncStatus() {
 	// Serialize any partial progress to disk before spinning down
 	for _, task := range s.tasks {
-		keys, bytes := task.genBatch.KeyCount(), task.genBatch.ValueSize()
 		if err := task.genBatch.Write(); err != nil {
 			log.Error("Failed to persist account slots", "err", err)
 		}
-		s.accountBytes += common.StorageSize(keys*common.HashLength + bytes)
-
 		for _, subtasks := range task.SubTasks {
 			for _, subtask := range subtasks {
-				keys, bytes := subtask.genBatch.KeyCount(), subtask.genBatch.ValueSize()
 				if err := subtask.genBatch.Write(); err != nil {
 					log.Error("Failed to persist storage slots", "err", err)
 				}
-				s.accountBytes += common.StorageSize(keys*common.HashLength + bytes)
 			}
 		}
 	}
@@ -1763,12 +1772,15 @@ func (s *Syncer) processStorageResponse(res *storageResponse) {
 	if res.subTask != nil {
 		res.subTask.req = nil
 	}
-	batch := s.db.NewBatch()
-
+	batch := ethdb.HookedBatch{
+		Batch: s.db.NewBatch(),
+		OnPut: func(key []byte, value []byte) {
+			s.storageBytes += common.StorageSize(len(key) + len(value))
+		},
+	}
 	var (
-		slots int
-		nodes int
-		bytes common.StorageSize
+		slots           int
+		oldStorageBytes = s.storageBytes
 	)
 	// Iterate over all the accounts and reconstruct their storage tries from the
 	// delivered slots
@@ -1829,7 +1841,12 @@ func (s *Syncer) processStorageResponse(res *storageResponse) {
 					r := newHashRange(lastKey, chunks)
 
 					// Our first task is the one that was just filled by this response.
-					batch := s.db.NewBatch()
+					batch := ethdb.HookedBatch{
+						Batch: s.db.NewBatch(),
+						OnPut: func(key []byte, value []byte) {
+							s.storageBytes += common.StorageSize(len(key) + len(value))
+						},
+					}
 					tasks = append(tasks, &storageTask{
 						Next:     common.Hash{},
 						Last:     r.End(),
@@ -1838,7 +1855,12 @@ func (s *Syncer) processStorageResponse(res *storageResponse) {
 						genTrie:  trie.NewStackTrie(batch),
 					})
 					for r.Next() {
-						batch := s.db.NewBatch()
+						batch := ethdb.HookedBatch{
+							Batch: s.db.NewBatch(),
+							OnPut: func(key []byte, value []byte) {
+								s.storageBytes += common.StorageSize(len(key) + len(value))
+							},
+						}
 						tasks = append(tasks, &storageTask{
 							Next:     r.Start(),
 							Last:     r.End(),
@@ -1883,27 +1905,23 @@ func (s *Syncer) processStorageResponse(res *storageResponse) {
 				}
 			}
 		}
-		// Iterate over all the reconstructed trie nodes and push them to disk
-		// if the contract is fully delivered. If it's chunked, the trie nodes
-		// will be reconstructed later.
+		// Iterate over all the complete contracts, reconstruct the trie nodes and
+		// push them to disk. If the contract is chunked, the trie nodes will be
+		// reconstructed later.
 		slots += len(res.hashes[i])
 
 		if i < len(res.hashes)-1 || res.subTask == nil {
-			it := res.nodes[i].NewIterator(nil, nil)
-			for it.Next() {
-				batch.Put(it.Key(), it.Value())
-
-				bytes += common.StorageSize(common.HashLength + len(it.Value()))
-				nodes++
+			tr := trie.NewStackTrie(batch)
+			for j := 0; j < len(res.hashes[i]); j++ {
+				tr.Update(res.hashes[i][j][:], res.slots[i][j])
 			}
-			it.Release()
+			tr.Commit()
 		}
 		// Persist the received storage segements. These flat state maybe
 		// outdated during the sync, but it can be fixed later during the
 		// snapshot generation.
 		for j := 0; j < len(res.hashes[i]); j++ {
 			rawdb.WriteStorageSnapshot(batch, account, res.hashes[i][j], res.slots[i][j])
-			bytes += common.StorageSize(1 + 2*common.HashLength + len(res.slots[i][j]))
 
 			// If we're storing large contracts, generate the trie nodes
 			// on the fly to not trash the gluing points
@@ -1926,15 +1944,11 @@ func (s *Syncer) processStorageResponse(res *storageResponse) {
 				}
 			}
 		}
-		if data := res.subTask.genBatch.ValueSize(); data > ethdb.IdealBatchSize || res.subTask.done {
-			keys := res.subTask.genBatch.KeyCount()
+		if res.subTask.genBatch.ValueSize() > ethdb.IdealBatchSize || res.subTask.done {
 			if err := res.subTask.genBatch.Write(); err != nil {
 				log.Error("Failed to persist stack slots", "err", err)
 			}
 			res.subTask.genBatch.Reset()
-
-			bytes += common.StorageSize(keys*common.HashLength + data)
-			nodes += keys
 		}
 	}
 	// Flush anything written just now and update the stats
@@ -1942,9 +1956,8 @@ func (s *Syncer) processStorageResponse(res *storageResponse) {
 		log.Crit("Failed to persist storage slots", "err", err)
 	}
 	s.storageSynced += uint64(slots)
-	s.storageBytes += bytes
 
-	log.Debug("Persisted set of storage slots", "accounts", len(res.hashes), "slots", slots, "nodes", nodes, "bytes", bytes)
+	log.Debug("Persisted set of storage slots", "accounts", len(res.hashes), "slots", slots, "bytes", s.storageBytes-oldStorageBytes)
 
 	// If this delivery completed the last pending task, forward the account task
 	// to the next chunk
@@ -2042,18 +2055,20 @@ func (s *Syncer) forwardAccountTask(task *accountTask) {
 	// Persist the received account segements. These flat state maybe
 	// outdated during the sync, but it can be fixed later during the
 	// snapshot generation.
-	var (
-		nodes int
-		bytes common.StorageSize
-	)
-	batch := s.db.NewBatch()
+	oldAccountBytes := s.accountBytes
+
+	batch := ethdb.HookedBatch{
+		Batch: s.db.NewBatch(),
+		OnPut: func(key []byte, value []byte) {
+			s.accountBytes += common.StorageSize(len(key) + len(value))
+		},
+	}
 	for i, hash := range res.hashes {
 		if task.needCode[i] || task.needState[i] {
 			break
 		}
 		slim := snapshot.SlimAccountRLP(res.accounts[i].Nonce, res.accounts[i].Balance, res.accounts[i].Root, res.accounts[i].CodeHash)
 		rawdb.WriteAccountSnapshot(batch, hash, slim)
-		bytes += common.StorageSize(1 + common.HashLength + len(slim))
 
 		// If the task is complete, drop it into the stack trie to generate
 		// account trie nodes for it
@@ -2069,7 +2084,6 @@ func (s *Syncer) forwardAccountTask(task *accountTask) {
 	if err := batch.Write(); err != nil {
 		log.Crit("Failed to persist accounts", "err", err)
 	}
-	s.accountBytes += bytes
 	s.accountSynced += uint64(len(res.accounts))
 
 	// Task filling persisted, push it the chunk marker forward to the first
@@ -2091,17 +2105,13 @@ func (s *Syncer) forwardAccountTask(task *accountTask) {
 			log.Error("Failed to commit stack account", "err", err)
 		}
 	}
-	if data := task.genBatch.ValueSize(); data > ethdb.IdealBatchSize || task.done {
-		keys := task.genBatch.KeyCount()
+	if task.genBatch.ValueSize() > ethdb.IdealBatchSize || task.done {
 		if err := task.genBatch.Write(); err != nil {
 			log.Error("Failed to persist stack account", "err", err)
 		}
 		task.genBatch.Reset()
-
-		nodes += keys
-		bytes += common.StorageSize(keys*common.HashLength + data)
 	}
-	log.Debug("Persisted range of accounts", "accounts", len(res.accounts), "nodes", nodes, "bytes", bytes)
+	log.Debug("Persisted range of accounts", "accounts", len(res.accounts), "bytes", s.accountBytes-oldAccountBytes)
 }
 
 // OnAccounts is a callback method to invoke when a range of accounts are
@@ -2176,7 +2186,7 @@ func (s *Syncer) OnAccounts(peer SyncPeer, id uint64, hashes []common.Hash, acco
 	if len(keys) > 0 {
 		end = keys[len(keys)-1]
 	}
-	_, cont, err := trie.VerifyRangeProof(root, req.origin[:], end, keys, accounts, proofdb)
+	cont, err := trie.VerifyRangeProof(root, req.origin[:], end, keys, accounts, proofdb)
 	if err != nil {
 		logger.Warn("Account range failed proof", "err", err)
 		// Signal this request as failed, and ready for rescheduling
@@ -2393,10 +2403,8 @@ func (s *Syncer) OnStorage(peer SyncPeer, id uint64, hashes [][]common.Hash, slo
 	s.lock.Unlock()
 
 	// Reconstruct the partial tries from the response and verify them
-	var (
-		dbs  = make([]ethdb.KeyValueStore, len(hashes))
-		cont bool
-	)
+	var cont bool
+
 	for i := 0; i < len(hashes); i++ {
 		// Convert the keys and proofs into an internal format
 		keys := make([][]byte, len(hashes[i]))
@@ -2413,7 +2421,7 @@ func (s *Syncer) OnStorage(peer SyncPeer, id uint64, hashes [][]common.Hash, slo
 		if len(nodes) == 0 {
 			// No proof has been attached, the response must cover the entire key
 			// space and hash to the origin root.
-			dbs[i], _, err = trie.VerifyRangeProof(req.roots[i], nil, nil, keys, slots[i], nil)
+			_, err = trie.VerifyRangeProof(req.roots[i], nil, nil, keys, slots[i], nil)
 			if err != nil {
 				s.scheduleRevertStorageRequest(req) // reschedule request
 				logger.Warn("Storage slots failed proof", "err", err)
@@ -2428,7 +2436,7 @@ func (s *Syncer) OnStorage(peer SyncPeer, id uint64, hashes [][]common.Hash, slo
 			if len(keys) > 0 {
 				end = keys[len(keys)-1]
 			}
-			dbs[i], cont, err = trie.VerifyRangeProof(req.roots[i], req.origin[:], end, keys, slots[i], proofdb)
+			cont, err = trie.VerifyRangeProof(req.roots[i], req.origin[:], end, keys, slots[i], proofdb)
 			if err != nil {
 				s.scheduleRevertStorageRequest(req) // reschedule request
 				logger.Warn("Storage range failed proof", "err", err)
@@ -2444,7 +2452,6 @@ func (s *Syncer) OnStorage(peer SyncPeer, id uint64, hashes [][]common.Hash, slo
 		roots:    req.roots,
 		hashes:   hashes,
 		slots:    slots,
-		nodes:    dbs,
 		cont:     cont,
 	}
 	select {
diff --git a/ethdb/batch.go b/ethdb/batch.go
index 5f8207fc4..135369331 100644
--- a/ethdb/batch.go
+++ b/ethdb/batch.go
@@ -25,9 +25,6 @@ const IdealBatchSize = 100 * 1024
 type Batch interface {
 	KeyValueWriter
 
-	// KeyCount retrieves the number of keys queued up for writing.
-	KeyCount() int
-
 	// ValueSize retrieves the amount of data queued up for writing.
 	ValueSize() int
 
@@ -47,3 +44,28 @@ type Batcher interface {
 	// until a final write is called.
 	NewBatch() Batch
 }
+
+// HookedBatch wraps an arbitrary batch where each operation may be hooked into
+// to monitor from black box code.
+type HookedBatch struct {
+	Batch
+
+	OnPut    func(key []byte, value []byte) // Callback if a key is inserted
+	OnDelete func(key []byte)               // Callback if a key is deleted
+}
+
+// Put inserts the given value into the key-value data store.
+func (b HookedBatch) Put(key []byte, value []byte) error {
+	if b.OnPut != nil {
+		b.OnPut(key, value)
+	}
+	return b.Batch.Put(key, value)
+}
+
+// Delete removes the key from the key-value data store.
+func (b HookedBatch) Delete(key []byte) error {
+	if b.OnDelete != nil {
+		b.OnDelete(key)
+	}
+	return b.Batch.Delete(key)
+}
diff --git a/ethdb/leveldb/leveldb.go b/ethdb/leveldb/leveldb.go
index da00226e9..5d19cc357 100644
--- a/ethdb/leveldb/leveldb.go
+++ b/ethdb/leveldb/leveldb.go
@@ -448,7 +448,6 @@ func (db *Database) meter(refresh time.Duration) {
 type batch struct {
 	db   *leveldb.DB
 	b    *leveldb.Batch
-	keys int
 	size int
 }
 
@@ -462,16 +461,10 @@ func (b *batch) Put(key, value []byte) error {
 // Delete inserts the a key removal into the batch for later committing.
 func (b *batch) Delete(key []byte) error {
 	b.b.Delete(key)
-	b.keys++
 	b.size += len(key)
 	return nil
 }
 
-// KeyCount retrieves the number of keys queued up for writing.
-func (b *batch) KeyCount() int {
-	return b.keys
-}
-
 // ValueSize retrieves the amount of data queued up for writing.
 func (b *batch) ValueSize() int {
 	return b.size
@@ -485,7 +478,7 @@ func (b *batch) Write() error {
 // Reset resets the batch for reuse.
 func (b *batch) Reset() {
 	b.b.Reset()
-	b.keys, b.size = 0, 0
+	b.size = 0
 }
 
 // Replay replays the batch contents.
diff --git a/ethdb/memorydb/memorydb.go b/ethdb/memorydb/memorydb.go
index ded9f5e66..fedc9e326 100644
--- a/ethdb/memorydb/memorydb.go
+++ b/ethdb/memorydb/memorydb.go
@@ -198,7 +198,6 @@ type keyvalue struct {
 type batch struct {
 	db     *Database
 	writes []keyvalue
-	keys   int
 	size   int
 }
 
@@ -212,16 +211,10 @@ func (b *batch) Put(key, value []byte) error {
 // Delete inserts the a key removal into the batch for later committing.
 func (b *batch) Delete(key []byte) error {
 	b.writes = append(b.writes, keyvalue{common.CopyBytes(key), nil, true})
-	b.keys++
 	b.size += len(key)
 	return nil
 }
 
-// KeyCount retrieves the number of keys queued up for writing.
-func (b *batch) KeyCount() int {
-	return b.keys
-}
-
 // ValueSize retrieves the amount of data queued up for writing.
 func (b *batch) ValueSize() int {
 	return b.size
@@ -245,7 +238,7 @@ func (b *batch) Write() error {
 // Reset resets the batch for reuse.
 func (b *batch) Reset() {
 	b.writes = b.writes[:0]
-	b.keys, b.size = 0, 0
+	b.size = 0
 }
 
 // Replay replays the batch contents.
diff --git a/tests/fuzzers/rangeproof/rangeproof-fuzzer.go b/tests/fuzzers/rangeproof/rangeproof-fuzzer.go
index 984bb9d0a..09ee6bb9c 100644
--- a/tests/fuzzers/rangeproof/rangeproof-fuzzer.go
+++ b/tests/fuzzers/rangeproof/rangeproof-fuzzer.go
@@ -170,18 +170,11 @@ func (f *fuzzer) fuzz() int {
 		}
 		ok = 1
 		//nodes, subtrie
-		nodes, hasMore, err := trie.VerifyRangeProof(tr.Hash(), first, last, keys, vals, proof)
+		hasMore, err := trie.VerifyRangeProof(tr.Hash(), first, last, keys, vals, proof)
 		if err != nil {
-			if nodes != nil {
-				panic("err != nil && nodes != nil")
-			}
 			if hasMore {
 				panic("err != nil && hasMore == true")
 			}
-		} else {
-			if nodes == nil {
-				panic("err == nil && nodes == nil")
-			}
 		}
 	}
 	return ok
diff --git a/tests/fuzzers/stacktrie/trie_fuzzer.go b/tests/fuzzers/stacktrie/trie_fuzzer.go
index 0013c919c..5cea7769c 100644
--- a/tests/fuzzers/stacktrie/trie_fuzzer.go
+++ b/tests/fuzzers/stacktrie/trie_fuzzer.go
@@ -90,7 +90,6 @@ func (b *spongeBatch) Put(key, value []byte) error {
 	return nil
 }
 func (b *spongeBatch) Delete(key []byte) error             { panic("implement me") }
-func (b *spongeBatch) KeyCount() int                       { panic("not implemented") }
 func (b *spongeBatch) ValueSize() int                      { return 100 }
 func (b *spongeBatch) Write() error                        { return nil }
 func (b *spongeBatch) Reset()                              {}
diff --git a/trie/notary.go b/trie/notary.go
deleted file mode 100644
index 10c7628f5..000000000
--- a/trie/notary.go
+++ /dev/null
@@ -1,57 +0,0 @@
-// Copyright 2020 The go-ethereum Authors
-// This file is part of the go-ethereum library.
-//
-// The go-ethereum library is free software: you can redistribute it and/or modify
-// it under the terms of the GNU Lesser General Public License as published by
-// the Free Software Foundation, either version 3 of the License, or
-// (at your option) any later version.
-//
-// The go-ethereum library is distributed in the hope that it will be useful,
-// but WITHOUT ANY WARRANTY; without even the implied warranty of
-// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
-// GNU Lesser General Public License for more details.
-//
-// You should have received a copy of the GNU Lesser General Public License
-// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
-
-package trie
-
-import (
-	"github.com/ethereum/go-ethereum/ethdb"
-	"github.com/ethereum/go-ethereum/ethdb/memorydb"
-)
-
-// keyValueNotary tracks which keys have been accessed through a key-value reader
-// with te scope of verifying if certain proof datasets are maliciously bloated.
-type keyValueNotary struct {
-	ethdb.KeyValueReader
-	reads map[string]struct{}
-}
-
-// newKeyValueNotary wraps a key-value database with an access notary to track
-// which items have bene accessed.
-func newKeyValueNotary(db ethdb.KeyValueReader) *keyValueNotary {
-	return &keyValueNotary{
-		KeyValueReader: db,
-		reads:          make(map[string]struct{}),
-	}
-}
-
-// Get retrieves an item from the underlying database, but also tracks it as an
-// accessed slot for bloat checks.
-func (k *keyValueNotary) Get(key []byte) ([]byte, error) {
-	k.reads[string(key)] = struct{}{}
-	return k.KeyValueReader.Get(key)
-}
-
-// Accessed returns s snapshot of the original key-value store containing only the
-// data accessed through the notary.
-func (k *keyValueNotary) Accessed() ethdb.KeyValueStore {
-	db := memorydb.New()
-	for keystr := range k.reads {
-		key := []byte(keystr)
-		val, _ := k.KeyValueReader.Get(key)
-		db.Put(key, val)
-	}
-	return db
-}
diff --git a/trie/proof.go b/trie/proof.go
index 2feed24de..08a9e4042 100644
--- a/trie/proof.go
+++ b/trie/proof.go
@@ -464,108 +464,91 @@ func hasRightElement(node node, key []byte) bool {
 //
 // Except returning the error to indicate the proof is valid or not, the function will
 // also return a flag to indicate whether there exists more accounts/slots in the trie.
-func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, keys [][]byte, values [][]byte, proof ethdb.KeyValueReader) (ethdb.KeyValueStore, bool, error) {
+//
+// Note: This method does not verify that the proof is of minimal form. If the input
+// proofs are 'bloated' with neighbour leaves or random data, aside from the 'useful'
+// data, then the proof will still be accepted.
+func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, keys [][]byte, values [][]byte, proof ethdb.KeyValueReader) (bool, error) {
 	if len(keys) != len(values) {
-		return nil, false, fmt.Errorf("inconsistent proof data, keys: %d, values: %d", len(keys), len(values))
+		return false, fmt.Errorf("inconsistent proof data, keys: %d, values: %d", len(keys), len(values))
 	}
 	// Ensure the received batch is monotonic increasing.
 	for i := 0; i < len(keys)-1; i++ {
 		if bytes.Compare(keys[i], keys[i+1]) >= 0 {
-			return nil, false, errors.New("range is not monotonically increasing")
+			return false, errors.New("range is not monotonically increasing")
 		}
 	}
-	// Create a key-value notary to track which items from the given proof the
-	// range prover actually needed to verify the data
-	notary := newKeyValueNotary(proof)
-
 	// Special case, there is no edge proof at all. The given range is expected
 	// to be the whole leaf-set in the trie.
 	if proof == nil {
-		var (
-			diskdb = memorydb.New()
-			tr     = NewStackTrie(diskdb)
-		)
+		tr := NewStackTrie(nil)
 		for index, key := range keys {
 			tr.TryUpdate(key, values[index])
 		}
 		if have, want := tr.Hash(), rootHash; have != want {
-			return nil, false, fmt.Errorf("invalid proof, want hash %x, got %x", want, have)
-		}
-		// Proof seems valid, serialize remaining nodes into the database
-		if _, err := tr.Commit(); err != nil {
-			return nil, false, err
+			return false, fmt.Errorf("invalid proof, want hash %x, got %x", want, have)
 		}
-		return diskdb, false, nil // No more elements
+		return false, nil // No more elements
 	}
 	// Special case, there is a provided edge proof but zero key/value
 	// pairs, ensure there are no more accounts / slots in the trie.
 	if len(keys) == 0 {
-		root, val, err := proofToPath(rootHash, nil, firstKey, notary, true)
+		root, val, err := proofToPath(rootHash, nil, firstKey, proof, true)
 		if err != nil {
-			return nil, false, err
+			return false, err
 		}
 		if val != nil || hasRightElement(root, firstKey) {
-			return nil, false, errors.New("more entries available")
+			return false, errors.New("more entries available")
 		}
-		// Since the entire proof is a single path, we can construct a trie and a
-		// node database directly out of the inputs, no need to generate them
-		diskdb := notary.Accessed()
-		return diskdb, hasRightElement(root, firstKey), nil
+		return hasRightElement(root, firstKey), nil
 	}
 	// Special case, there is only one element and two edge keys are same.
 	// In this case, we can't construct two edge paths. So handle it here.
 	if len(keys) == 1 && bytes.Equal(firstKey, lastKey) {
-		root, val, err := proofToPath(rootHash, nil, firstKey, notary, false)
+		root, val, err := proofToPath(rootHash, nil, firstKey, proof, false)
 		if err != nil {
-			return nil, false, err
+			return false, err
 		}
 		if !bytes.Equal(firstKey, keys[0]) {
-			return nil, false, errors.New("correct proof but invalid key")
+			return false, errors.New("correct proof but invalid key")
 		}
 		if !bytes.Equal(val, values[0]) {
-			return nil, false, errors.New("correct proof but invalid data")
+			return false, errors.New("correct proof but invalid data")
 		}
-		// Since the entire proof is a single path, we can construct a trie and a
-		// node database directly out of the inputs, no need to generate them
-		diskdb := notary.Accessed()
-		return diskdb, hasRightElement(root, firstKey), nil
+		return hasRightElement(root, firstKey), nil
 	}
 	// Ok, in all other cases, we require two edge paths available.
 	// First check the validity of edge keys.
 	if bytes.Compare(firstKey, lastKey) >= 0 {
-		return nil, false, errors.New("invalid edge keys")
+		return false, errors.New("invalid edge keys")
 	}
 	// todo(rjl493456442) different length edge keys should be supported
 	if len(firstKey) != len(lastKey) {
-		return nil, false, errors.New("inconsistent edge keys")
+		return false, errors.New("inconsistent edge keys")
 	}
 	// Convert the edge proofs to edge trie paths. Then we can
 	// have the same tree architecture with the original one.
 	// For the first edge proof, non-existent proof is allowed.
-	root, _, err := proofToPath(rootHash, nil, firstKey, notary, true)
+	root, _, err := proofToPath(rootHash, nil, firstKey, proof, true)
 	if err != nil {
-		return nil, false, err
+		return false, err
 	}
 	// Pass the root node here, the second path will be merged
 	// with the first one. For the last edge proof, non-existent
 	// proof is also allowed.
-	root, _, err = proofToPath(rootHash, root, lastKey, notary, true)
+	root, _, err = proofToPath(rootHash, root, lastKey, proof, true)
 	if err != nil {
-		return nil, false, err
+		return false, err
 	}
 	// Remove all internal references. All the removed parts should
 	// be re-filled(or re-constructed) by the given leaves range.
 	empty, err := unsetInternal(root, firstKey, lastKey)
 	if err != nil {
-		return nil, false, err
+		return false, err
 	}
 	// Rebuild the trie with the leaf stream, the shape of trie
 	// should be same with the original one.
-	var (
-		diskdb = memorydb.New()
-		triedb = NewDatabase(diskdb)
-	)
-	tr := &Trie{root: root, db: triedb}
+	tr := &Trie{root: root, db: NewDatabase(memorydb.New())}
 	if empty {
 		tr.root = nil
 	}
@@ -573,16 +556,9 @@ func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, key
 		tr.TryUpdate(key, values[index])
 	}
 	if tr.Hash() != rootHash {
-		return nil, false, fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, tr.Hash())
-	}
-	// Proof seems valid, serialize all the nodes into the database
-	if _, err := tr.Commit(nil); err != nil {
-		return nil, false, err
-	}
-	if err := triedb.Commit(rootHash, false, nil); err != nil {
-		return nil, false, err
+		return false, fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, tr.Hash())
 	}
-	return diskdb, hasRightElement(root, keys[len(keys)-1]), nil
+	return hasRightElement(root, keys[len(keys)-1]), nil
 }
 
 // get returns the child of the given node. Return nil if the
diff --git a/trie/proof_test.go b/trie/proof_test.go
index 7a906e254..a35b7144c 100644
--- a/trie/proof_test.go
+++ b/trie/proof_test.go
@@ -182,7 +182,7 @@ func TestRangeProof(t *testing.T) {
 			keys = append(keys, entries[i].k)
 			vals = append(vals, entries[i].v)
 		}
-		_, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof)
+		_, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof)
 		if err != nil {
 			t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err)
 		}
@@ -233,7 +233,7 @@ func TestRangeProofWithNonExistentProof(t *testing.T) {
 			keys = append(keys, entries[i].k)
 			vals = append(vals, entries[i].v)
 		}
-		_, _, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof)
+		_, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof)
 		if err != nil {
 			t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err)
 		}
@@ -254,7 +254,7 @@ func TestRangeProofWithNonExistentProof(t *testing.T) {
 		k = append(k, entries[i].k)
 		v = append(v, entries[i].v)
 	}
-	_, _, err := VerifyRangeProof(trie.Hash(), first, last, k, v, proof)
+	_, err := VerifyRangeProof(trie.Hash(), first, last, k, v, proof)
 	if err != nil {
 		t.Fatal("Failed to verify whole rang with non-existent edges")
 	}
@@ -289,7 +289,7 @@ func TestRangeProofWithInvalidNonExistentProof(t *testing.T) {
 		k = append(k, entries[i].k)
 		v = append(v, entries[i].v)
 	}
-	_, _, err := VerifyRangeProof(trie.Hash(), first, k[len(k)-1], k, v, proof)
+	_, err := VerifyRangeProof(trie.Hash(), first, k[len(k)-1], k, v, proof)
 	if err == nil {
 		t.Fatalf("Expected to detect the error, got nil")
 	}
@@ -311,7 +311,7 @@ func TestRangeProofWithInvalidNonExistentProof(t *testing.T) {
 		k = append(k, entries[i].k)
 		v = append(v, entries[i].v)
 	}
-	_, _, err = VerifyRangeProof(trie.Hash(), k[0], last, k, v, proof)
+	_, err = VerifyRangeProof(trie.Hash(), k[0], last, k, v, proof)
 	if err == nil {
 		t.Fatalf("Expected to detect the error, got nil")
 	}
@@ -335,7 +335,7 @@ func TestOneElementRangeProof(t *testing.T) {
 	if err := trie.Prove(entries[start].k, 0, proof); err != nil {
 		t.Fatalf("Failed to prove the first node %v", err)
 	}
-	_, _, err := VerifyRangeProof(trie.Hash(), entries[start].k, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
+	_, err := VerifyRangeProof(trie.Hash(), entries[start].k, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
 	if err != nil {
 		t.Fatalf("Expected no error, got %v", err)
 	}
@@ -350,7 +350,7 @@ func TestOneElementRangeProof(t *testing.T) {
 	if err := trie.Prove(entries[start].k, 0, proof); err != nil {
 		t.Fatalf("Failed to prove the last node %v", err)
 	}
-	_, _, err = VerifyRangeProof(trie.Hash(), first, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
+	_, err = VerifyRangeProof(trie.Hash(), first, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
 	if err != nil {
 		t.Fatalf("Expected no error, got %v", err)
 	}
@@ -365,7 +365,7 @@ func TestOneElementRangeProof(t *testing.T) {
 	if err := trie.Prove(last, 0, proof); err != nil {
 		t.Fatalf("Failed to prove the last node %v", err)
 	}
-	_, _, err = VerifyRangeProof(trie.Hash(), entries[start].k, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
+	_, err = VerifyRangeProof(trie.Hash(), entries[start].k, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
 	if err != nil {
 		t.Fatalf("Expected no error, got %v", err)
 	}
@@ -380,7 +380,7 @@ func TestOneElementRangeProof(t *testing.T) {
 	if err := trie.Prove(last, 0, proof); err != nil {
 		t.Fatalf("Failed to prove the last node %v", err)
 	}
-	_, _, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
+	_, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
 	if err != nil {
 		t.Fatalf("Expected no error, got %v", err)
 	}
@@ -399,7 +399,7 @@ func TestOneElementRangeProof(t *testing.T) {
 	if err := tinyTrie.Prove(last, 0, proof); err != nil {
 		t.Fatalf("Failed to prove the last node %v", err)
 	}
-	_, _, err = VerifyRangeProof(tinyTrie.Hash(), first, last, [][]byte{entry.k}, [][]byte{entry.v}, proof)
+	_, err = VerifyRangeProof(tinyTrie.Hash(), first, last, [][]byte{entry.k}, [][]byte{entry.v}, proof)
 	if err != nil {
 		t.Fatalf("Expected no error, got %v", err)
 	}
@@ -421,7 +421,7 @@ func TestAllElementsProof(t *testing.T) {
 		k = append(k, entries[i].k)
 		v = append(v, entries[i].v)
 	}
-	_, _, err := VerifyRangeProof(trie.Hash(), nil, nil, k, v, nil)
+	_, err := VerifyRangeProof(trie.Hash(), nil, nil, k, v, nil)
 	if err != nil {
 		t.Fatalf("Expected no error, got %v", err)
 	}
@@ -434,7 +434,7 @@ func TestAllElementsProof(t *testing.T) {
 	if err := trie.Prove(entries[len(entries)-1].k, 0, proof); err != nil {
 		t.Fatalf("Failed to prove the last node %v", err)
 	}
-	_, _, err = VerifyRangeProof(trie.Hash(), k[0], k[len(k)-1], k, v, proof)
+	_, err = VerifyRangeProof(trie.Hash(), k[0], k[len(k)-1], k, v, proof)
 	if err != nil {
 		t.Fatalf("Expected no error, got %v", err)
 	}
@@ -449,7 +449,7 @@ func TestAllElementsProof(t *testing.T) {
 	if err := trie.Prove(last, 0, proof); err != nil {
 		t.Fatalf("Failed to prove the last node %v", err)
 	}
-	_, _, err = VerifyRangeProof(trie.Hash(), first, last, k, v, proof)
+	_, err = VerifyRangeProof(trie.Hash(), first, last, k, v, proof)
 	if err != nil {
 		t.Fatalf("Expected no error, got %v", err)
 	}
@@ -482,7 +482,7 @@ func TestSingleSideRangeProof(t *testing.T) {
 				k = append(k, entries[i].k)
 				v = append(v, entries[i].v)
 			}
-			_, _, err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k[len(k)-1], k, v, proof)
+			_, err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k[len(k)-1], k, v, proof)
 			if err != nil {
 				t.Fatalf("Expected no error, got %v", err)
 			}
@@ -518,7 +518,7 @@ func TestReverseSingleSideRangeProof(t *testing.T) {
 				k = append(k, entries[i].k)
 				v = append(v, entries[i].v)
 			}
-			_, _, err := VerifyRangeProof(trie.Hash(), k[0], last.Bytes(), k, v, proof)
+			_, err := VerifyRangeProof(trie.Hash(), k[0], last.Bytes(), k, v, proof)
 			if err != nil {
 				t.Fatalf("Expected no error, got %v", err)
 			}
@@ -590,7 +590,7 @@ func TestBadRangeProof(t *testing.T) {
 			index = mrand.Intn(end - start)
 			vals[index] = nil
 		}
-		_, _, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof)
+		_, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof)
 		if err == nil {
 			t.Fatalf("%d Case %d index %d range: (%d->%d) expect error, got nil", i, testcase, index, start, end-1)
 		}
@@ -624,7 +624,7 @@ func TestGappedRangeProof(t *testing.T) {
 		keys = append(keys, entries[i].k)
 		vals = append(vals, entries[i].v)
 	}
-	_, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof)
+	_, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof)
 	if err == nil {
 		t.Fatal("expect error, got nil")
 	}
@@ -651,7 +651,7 @@ func TestSameSideProofs(t *testing.T) {
 	if err := trie.Prove(last, 0, proof); err != nil {
 		t.Fatalf("Failed to prove the last node %v", err)
 	}
-	_, _, err := VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof)
+	_, err := VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof)
 	if err == nil {
 		t.Fatalf("Expected error, got nil")
 	}
@@ -667,7 +667,7 @@ func TestSameSideProofs(t *testing.T) {
 	if err := trie.Prove(last, 0, proof); err != nil {
 		t.Fatalf("Failed to prove the last node %v", err)
 	}
-	_, _, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof)
+	_, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof)
 	if err == nil {
 		t.Fatalf("Expected error, got nil")
 	}
@@ -735,7 +735,7 @@ func TestHasRightElement(t *testing.T) {
 			k = append(k, entries[i].k)
 			v = append(v, entries[i].v)
 		}
-		_, hasMore, err := VerifyRangeProof(trie.Hash(), firstKey, lastKey, k, v, proof)
+		hasMore, err := VerifyRangeProof(trie.Hash(), firstKey, lastKey, k, v, proof)
 		if err != nil {
 			t.Fatalf("Expected no error, got %v", err)
 		}
@@ -768,25 +768,19 @@ func TestEmptyRangeProof(t *testing.T) {
 		if err := trie.Prove(first, 0, proof); err != nil {
 			t.Fatalf("Failed to prove the first node %v", err)
 		}
-		db, _, err := VerifyRangeProof(trie.Hash(), first, nil, nil, nil, proof)
+		_, err := VerifyRangeProof(trie.Hash(), first, nil, nil, nil, proof)
 		if c.err && err == nil {
 			t.Fatalf("Expected error, got nil")
 		}
 		if !c.err && err != nil {
 			t.Fatalf("Expected no error, got %v", err)
 		}
-		// If no error was returned, ensure the returned database contains
-		// the entire proof, since there's no value
-		if !c.err {
-			if memdb := db.(*memorydb.Database); memdb.Len() != proof.Len() {
-				t.Errorf("database entry count mismatch: have %d, want %d", memdb.Len(), proof.Len())
-			}
-		}
 	}
 }
 
 // TestBloatedProof tests a malicious proof, where the proof is more or less the
-// whole trie.
+// whole trie. Previously we didn't accept such packets, but the new APIs do, so
+// lets leave this test as a bit weird, but present.
 func TestBloatedProof(t *testing.T) {
 	// Use a small trie
 	trie, kvs := nonRandomTrie(100)
@@ -814,10 +808,8 @@ func TestBloatedProof(t *testing.T) {
 	trie.Prove(keys[0], 0, want)
 	trie.Prove(keys[len(keys)-1], 0, want)
 
-	db, _, _ := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof)
-	// The db should not contain anything of the bloated data
-	if used := db.(*memorydb.Database); used.Len() != want.Len() {
-		t.Fatalf("notary proof size mismatch: have %d, want %d", used.Len(), want.Len())
+	if _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof); err != nil {
+		t.Fatalf("expected bloated proof to succeed, got %v", err)
 	}
 }
 
@@ -921,7 +913,7 @@ func benchmarkVerifyRangeProof(b *testing.B, size int) {
 
 	b.ResetTimer()
 	for i := 0; i < b.N; i++ {
-		_, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, values, proof)
+		_, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, values, proof)
 		if err != nil {
 			b.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err)
 		}
@@ -948,7 +940,7 @@ func benchmarkVerifyRangeNoProof(b *testing.B, size int) {
 	}
 	b.ResetTimer()
 	for i := 0; i < b.N; i++ {
-		_, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, values, nil)
+		_, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, values, nil)
 		if err != nil {
 			b.Fatalf("Expected no error, got %v", err)
 		}
diff --git a/trie/trie_test.go b/trie/trie_test.go
index 44fddf87e..492b423c2 100644
--- a/trie/trie_test.go
+++ b/trie/trie_test.go
@@ -706,7 +706,6 @@ func (b *spongeBatch) Put(key, value []byte) error {
 	return nil
 }
 func (b *spongeBatch) Delete(key []byte) error             { panic("implement me") }
-func (b *spongeBatch) KeyCount() int                       { return 100 }
 func (b *spongeBatch) ValueSize() int                      { return 100 }
 func (b *spongeBatch) Write() error                        { return nil }
 func (b *spongeBatch) Reset()                              {}
-- 
GitLab