From a81cf0d2b3497e5d78b2c06427953b90c1a0d70f Mon Sep 17 00:00:00 2001
From: Martin Holst Swende <martin@swende.se>
Date: Wed, 28 Apr 2021 21:47:48 +0200
Subject: [PATCH] trie: remove redundant returns + use stacktrie where
 applicable (#22760)

* trie: add benchmark for proofless range

* trie: remove unused returns + use stacktrie
---
 core/state/snapshot/generate.go               |  2 +-
 eth/protocols/snap/sync.go                    |  6 +-
 tests/fuzzers/rangeproof/rangeproof-fuzzer.go | 14 +--
 trie/notary.go                                | 14 +--
 trie/proof.go                                 | 69 ++++++---------
 trie/proof_test.go                            | 88 ++++++++++++-------
 6 files changed, 96 insertions(+), 97 deletions(-)

diff --git a/core/state/snapshot/generate.go b/core/state/snapshot/generate.go
index 13b34f4d6..78fca45e4 100644
--- a/core/state/snapshot/generate.go
+++ b/core/state/snapshot/generate.go
@@ -368,7 +368,7 @@ func (dl *diskLayer) proveRange(stats *generatorStats, root common.Hash, prefix
 	}
 	// Verify the snapshot segment with range prover, ensure that all flat states
 	// in this range correspond to merkle trie.
-	_, _, _, cont, err := trie.VerifyRangeProof(root, origin, last, keys, vals, proof)
+	_, cont, err := trie.VerifyRangeProof(root, origin, last, keys, vals, proof)
 	return &proofResult{
 			keys:     keys,
 			vals:     vals,
diff --git a/eth/protocols/snap/sync.go b/eth/protocols/snap/sync.go
index 2ad677f94..287ac8d72 100644
--- a/eth/protocols/snap/sync.go
+++ b/eth/protocols/snap/sync.go
@@ -2176,7 +2176,7 @@ func (s *Syncer) OnAccounts(peer SyncPeer, id uint64, hashes []common.Hash, acco
 	if len(keys) > 0 {
 		end = keys[len(keys)-1]
 	}
-	_, _, _, cont, err := trie.VerifyRangeProof(root, req.origin[:], end, keys, accounts, proofdb)
+	_, cont, err := trie.VerifyRangeProof(root, req.origin[:], end, keys, accounts, proofdb)
 	if err != nil {
 		logger.Warn("Account range failed proof", "err", err)
 		// Signal this request as failed, and ready for rescheduling
@@ -2413,7 +2413,7 @@ func (s *Syncer) OnStorage(peer SyncPeer, id uint64, hashes [][]common.Hash, slo
 		if len(nodes) == 0 {
 			// No proof has been attached, the response must cover the entire key
 			// space and hash to the origin root.
-			dbs[i], _, _, _, err = trie.VerifyRangeProof(req.roots[i], nil, nil, keys, slots[i], nil)
+			dbs[i], _, err = trie.VerifyRangeProof(req.roots[i], nil, nil, keys, slots[i], nil)
 			if err != nil {
 				s.scheduleRevertStorageRequest(req) // reschedule request
 				logger.Warn("Storage slots failed proof", "err", err)
@@ -2428,7 +2428,7 @@ func (s *Syncer) OnStorage(peer SyncPeer, id uint64, hashes [][]common.Hash, slo
 			if len(keys) > 0 {
 				end = keys[len(keys)-1]
 			}
-			dbs[i], _, _, cont, err = trie.VerifyRangeProof(req.roots[i], req.origin[:], end, keys, slots[i], proofdb)
+			dbs[i], cont, err = trie.VerifyRangeProof(req.roots[i], req.origin[:], end, keys, slots[i], proofdb)
 			if err != nil {
 				s.scheduleRevertStorageRequest(req) // reschedule request
 				logger.Warn("Storage range failed proof", "err", err)
diff --git a/tests/fuzzers/rangeproof/rangeproof-fuzzer.go b/tests/fuzzers/rangeproof/rangeproof-fuzzer.go
index b82a38072..984bb9d0a 100644
--- a/tests/fuzzers/rangeproof/rangeproof-fuzzer.go
+++ b/tests/fuzzers/rangeproof/rangeproof-fuzzer.go
@@ -170,17 +170,11 @@ func (f *fuzzer) fuzz() int {
 		}
 		ok = 1
 		//nodes, subtrie
-		nodes, subtrie, notary, hasMore, err := trie.VerifyRangeProof(tr.Hash(), first, last, keys, vals, proof)
+		nodes, hasMore, err := trie.VerifyRangeProof(tr.Hash(), first, last, keys, vals, proof)
 		if err != nil {
 			if nodes != nil {
 				panic("err != nil && nodes != nil")
 			}
-			if subtrie != nil {
-				panic("err != nil && subtrie != nil")
-			}
-			if notary != nil {
-				panic("err != nil && notary != nil")
-			}
 			if hasMore {
 				panic("err != nil && hasMore == true")
 			}
@@ -188,12 +182,6 @@ func (f *fuzzer) fuzz() int {
 			if nodes == nil {
 				panic("err == nil && nodes == nil")
 			}
-			if subtrie == nil {
-				panic("err == nil && subtrie == nil")
-			}
-			if notary == nil {
-				panic("err == nil && subtrie == nil")
-			}
 		}
 	}
 	return ok
diff --git a/trie/notary.go b/trie/notary.go
index 5a64727aa..10c7628f5 100644
--- a/trie/notary.go
+++ b/trie/notary.go
@@ -21,17 +21,17 @@ import (
 	"github.com/ethereum/go-ethereum/ethdb/memorydb"
 )
 
-// KeyValueNotary tracks which keys have been accessed through a key-value reader
+// keyValueNotary tracks which keys have been accessed through a key-value reader
 // with te scope of verifying if certain proof datasets are maliciously bloated.
-type KeyValueNotary struct {
+type keyValueNotary struct {
 	ethdb.KeyValueReader
 	reads map[string]struct{}
 }
 
-// NewKeyValueNotary wraps a key-value database with an access notary to track
+// newKeyValueNotary wraps a key-value database with an access notary to track
 // which items have bene accessed.
-func NewKeyValueNotary(db ethdb.KeyValueReader) *KeyValueNotary {
-	return &KeyValueNotary{
+func newKeyValueNotary(db ethdb.KeyValueReader) *keyValueNotary {
+	return &keyValueNotary{
 		KeyValueReader: db,
 		reads:          make(map[string]struct{}),
 	}
@@ -39,14 +39,14 @@ func NewKeyValueNotary(db ethdb.KeyValueReader) *KeyValueNotary {
 
 // Get retrieves an item from the underlying database, but also tracks it as an
 // accessed slot for bloat checks.
-func (k *KeyValueNotary) Get(key []byte) ([]byte, error) {
+func (k *keyValueNotary) Get(key []byte) ([]byte, error) {
 	k.reads[string(key)] = struct{}{}
 	return k.KeyValueReader.Get(key)
 }
 
 // Accessed returns s snapshot of the original key-value store containing only the
 // data accessed through the notary.
-func (k *KeyValueNotary) Accessed() ethdb.KeyValueStore {
+func (k *keyValueNotary) Accessed() ethdb.KeyValueStore {
 	db := memorydb.New()
 	for keystr := range k.reads {
 		key := []byte(keystr)
diff --git a/trie/proof.go b/trie/proof.go
index 61c35a842..2feed24de 100644
--- a/trie/proof.go
+++ b/trie/proof.go
@@ -464,115 +464,100 @@ func hasRightElement(node node, key []byte) bool {
 //
 // Except returning the error to indicate the proof is valid or not, the function will
 // also return a flag to indicate whether there exists more accounts/slots in the trie.
-func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, keys [][]byte, values [][]byte, proof ethdb.KeyValueReader) (ethdb.KeyValueStore, *Trie, *KeyValueNotary, bool, error) {
+func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, keys [][]byte, values [][]byte, proof ethdb.KeyValueReader) (ethdb.KeyValueStore, bool, error) {
 	if len(keys) != len(values) {
-		return nil, nil, nil, false, fmt.Errorf("inconsistent proof data, keys: %d, values: %d", len(keys), len(values))
+		return nil, false, fmt.Errorf("inconsistent proof data, keys: %d, values: %d", len(keys), len(values))
 	}
 	// Ensure the received batch is monotonic increasing.
 	for i := 0; i < len(keys)-1; i++ {
 		if bytes.Compare(keys[i], keys[i+1]) >= 0 {
-			return nil, nil, nil, false, errors.New("range is not monotonically increasing")
+			return nil, false, errors.New("range is not monotonically increasing")
 		}
 	}
 	// Create a key-value notary to track which items from the given proof the
 	// range prover actually needed to verify the data
-	notary := NewKeyValueNotary(proof)
+	notary := newKeyValueNotary(proof)
 
 	// Special case, there is no edge proof at all. The given range is expected
 	// to be the whole leaf-set in the trie.
 	if proof == nil {
 		var (
 			diskdb = memorydb.New()
-			triedb = NewDatabase(diskdb)
+			tr     = NewStackTrie(diskdb)
 		)
-		tr, err := New(common.Hash{}, triedb)
-		if err != nil {
-			return nil, nil, nil, false, err
-		}
 		for index, key := range keys {
 			tr.TryUpdate(key, values[index])
 		}
-		if tr.Hash() != rootHash {
-			return nil, nil, nil, false, fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, tr.Hash())
-		}
-		// Proof seems valid, serialize all the nodes into the database
-		if _, err := tr.Commit(nil); err != nil {
-			return nil, nil, nil, false, err
+		if have, want := tr.Hash(), rootHash; have != want {
+			return nil, false, fmt.Errorf("invalid proof, want hash %x, got %x", want, have)
 		}
-		if err := triedb.Commit(rootHash, false, nil); err != nil {
-			return nil, nil, nil, false, err
+		// Proof seems valid, serialize remaining nodes into the database
+		if _, err := tr.Commit(); err != nil {
+			return nil, false, err
 		}
-		return diskdb, tr, notary, false, nil // No more elements
+		return diskdb, false, nil // No more elements
 	}
 	// Special case, there is a provided edge proof but zero key/value
 	// pairs, ensure there are no more accounts / slots in the trie.
 	if len(keys) == 0 {
 		root, val, err := proofToPath(rootHash, nil, firstKey, notary, true)
 		if err != nil {
-			return nil, nil, nil, false, err
+			return nil, false, err
 		}
 		if val != nil || hasRightElement(root, firstKey) {
-			return nil, nil, nil, false, errors.New("more entries available")
+			return nil, false, errors.New("more entries available")
 		}
 		// Since the entire proof is a single path, we can construct a trie and a
 		// node database directly out of the inputs, no need to generate them
 		diskdb := notary.Accessed()
-		tr := &Trie{
-			db:   NewDatabase(diskdb),
-			root: root,
-		}
-		return diskdb, tr, notary, hasRightElement(root, firstKey), nil
+		return diskdb, hasRightElement(root, firstKey), nil
 	}
 	// Special case, there is only one element and two edge keys are same.
 	// In this case, we can't construct two edge paths. So handle it here.
 	if len(keys) == 1 && bytes.Equal(firstKey, lastKey) {
 		root, val, err := proofToPath(rootHash, nil, firstKey, notary, false)
 		if err != nil {
-			return nil, nil, nil, false, err
+			return nil, false, err
 		}
 		if !bytes.Equal(firstKey, keys[0]) {
-			return nil, nil, nil, false, errors.New("correct proof but invalid key")
+			return nil, false, errors.New("correct proof but invalid key")
 		}
 		if !bytes.Equal(val, values[0]) {
-			return nil, nil, nil, false, errors.New("correct proof but invalid data")
+			return nil, false, errors.New("correct proof but invalid data")
 		}
 		// Since the entire proof is a single path, we can construct a trie and a
 		// node database directly out of the inputs, no need to generate them
 		diskdb := notary.Accessed()
-		tr := &Trie{
-			db:   NewDatabase(diskdb),
-			root: root,
-		}
-		return diskdb, tr, notary, hasRightElement(root, firstKey), nil
+		return diskdb, hasRightElement(root, firstKey), nil
 	}
 	// Ok, in all other cases, we require two edge paths available.
 	// First check the validity of edge keys.
 	if bytes.Compare(firstKey, lastKey) >= 0 {
-		return nil, nil, nil, false, errors.New("invalid edge keys")
+		return nil, false, errors.New("invalid edge keys")
 	}
 	// todo(rjl493456442) different length edge keys should be supported
 	if len(firstKey) != len(lastKey) {
-		return nil, nil, nil, false, errors.New("inconsistent edge keys")
+		return nil, false, errors.New("inconsistent edge keys")
 	}
 	// Convert the edge proofs to edge trie paths. Then we can
 	// have the same tree architecture with the original one.
 	// For the first edge proof, non-existent proof is allowed.
 	root, _, err := proofToPath(rootHash, nil, firstKey, notary, true)
 	if err != nil {
-		return nil, nil, nil, false, err
+		return nil, false, err
 	}
 	// Pass the root node here, the second path will be merged
 	// with the first one. For the last edge proof, non-existent
 	// proof is also allowed.
 	root, _, err = proofToPath(rootHash, root, lastKey, notary, true)
 	if err != nil {
-		return nil, nil, nil, false, err
+		return nil, false, err
 	}
 	// Remove all internal references. All the removed parts should
 	// be re-filled(or re-constructed) by the given leaves range.
 	empty, err := unsetInternal(root, firstKey, lastKey)
 	if err != nil {
-		return nil, nil, nil, false, err
+		return nil, false, err
 	}
 	// Rebuild the trie with the leaf stream, the shape of trie
 	// should be same with the original one.
@@ -588,16 +573,16 @@ func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, key
 		tr.TryUpdate(key, values[index])
 	}
 	if tr.Hash() != rootHash {
-		return nil, nil, nil, false, fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, tr.Hash())
+		return nil, false, fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, tr.Hash())
 	}
 	// Proof seems valid, serialize all the nodes into the database
 	if _, err := tr.Commit(nil); err != nil {
-		return nil, nil, nil, false, err
+		return nil, false, err
 	}
 	if err := triedb.Commit(rootHash, false, nil); err != nil {
-		return nil, nil, nil, false, err
+		return nil, false, err
 	}
-	return diskdb, tr, notary, hasRightElement(root, keys[len(keys)-1]), nil
+	return diskdb, hasRightElement(root, keys[len(keys)-1]), nil
 }
 
 // get returns the child of the given node. Return nil if the
diff --git a/trie/proof_test.go b/trie/proof_test.go
index 304affa9f..7a906e254 100644
--- a/trie/proof_test.go
+++ b/trie/proof_test.go
@@ -182,7 +182,7 @@ func TestRangeProof(t *testing.T) {
 			keys = append(keys, entries[i].k)
 			vals = append(vals, entries[i].v)
 		}
-		_, _, _, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof)
+		_, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof)
 		if err != nil {
 			t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err)
 		}
@@ -233,7 +233,7 @@ func TestRangeProofWithNonExistentProof(t *testing.T) {
 			keys = append(keys, entries[i].k)
 			vals = append(vals, entries[i].v)
 		}
-		_, _, _, _, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof)
+		_, _, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof)
 		if err != nil {
 			t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err)
 		}
@@ -254,7 +254,7 @@ func TestRangeProofWithNonExistentProof(t *testing.T) {
 		k = append(k, entries[i].k)
 		v = append(v, entries[i].v)
 	}
-	_, _, _, _, err := VerifyRangeProof(trie.Hash(), first, last, k, v, proof)
+	_, _, err := VerifyRangeProof(trie.Hash(), first, last, k, v, proof)
 	if err != nil {
 		t.Fatal("Failed to verify whole rang with non-existent edges")
 	}
@@ -289,7 +289,7 @@ func TestRangeProofWithInvalidNonExistentProof(t *testing.T) {
 		k = append(k, entries[i].k)
 		v = append(v, entries[i].v)
 	}
-	_, _, _, _, err := VerifyRangeProof(trie.Hash(), first, k[len(k)-1], k, v, proof)
+	_, _, err := VerifyRangeProof(trie.Hash(), first, k[len(k)-1], k, v, proof)
 	if err == nil {
 		t.Fatalf("Expected to detect the error, got nil")
 	}
@@ -311,7 +311,7 @@ func TestRangeProofWithInvalidNonExistentProof(t *testing.T) {
 		k = append(k, entries[i].k)
 		v = append(v, entries[i].v)
 	}
-	_, _, _, _, err = VerifyRangeProof(trie.Hash(), k[0], last, k, v, proof)
+	_, _, err = VerifyRangeProof(trie.Hash(), k[0], last, k, v, proof)
 	if err == nil {
 		t.Fatalf("Expected to detect the error, got nil")
 	}
@@ -335,7 +335,7 @@ func TestOneElementRangeProof(t *testing.T) {
 	if err := trie.Prove(entries[start].k, 0, proof); err != nil {
 		t.Fatalf("Failed to prove the first node %v", err)
 	}
-	_, _, _, _, err := VerifyRangeProof(trie.Hash(), entries[start].k, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
+	_, _, err := VerifyRangeProof(trie.Hash(), entries[start].k, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
 	if err != nil {
 		t.Fatalf("Expected no error, got %v", err)
 	}
@@ -350,7 +350,7 @@ func TestOneElementRangeProof(t *testing.T) {
 	if err := trie.Prove(entries[start].k, 0, proof); err != nil {
 		t.Fatalf("Failed to prove the last node %v", err)
 	}
-	_, _, _, _, err = VerifyRangeProof(trie.Hash(), first, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
+	_, _, err = VerifyRangeProof(trie.Hash(), first, entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
 	if err != nil {
 		t.Fatalf("Expected no error, got %v", err)
 	}
@@ -365,7 +365,7 @@ func TestOneElementRangeProof(t *testing.T) {
 	if err := trie.Prove(last, 0, proof); err != nil {
 		t.Fatalf("Failed to prove the last node %v", err)
 	}
-	_, _, _, _, err = VerifyRangeProof(trie.Hash(), entries[start].k, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
+	_, _, err = VerifyRangeProof(trie.Hash(), entries[start].k, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
 	if err != nil {
 		t.Fatalf("Expected no error, got %v", err)
 	}
@@ -380,7 +380,7 @@ func TestOneElementRangeProof(t *testing.T) {
 	if err := trie.Prove(last, 0, proof); err != nil {
 		t.Fatalf("Failed to prove the last node %v", err)
 	}
-	_, _, _, _, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
+	_, _, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[start].k}, [][]byte{entries[start].v}, proof)
 	if err != nil {
 		t.Fatalf("Expected no error, got %v", err)
 	}
@@ -399,7 +399,7 @@ func TestOneElementRangeProof(t *testing.T) {
 	if err := tinyTrie.Prove(last, 0, proof); err != nil {
 		t.Fatalf("Failed to prove the last node %v", err)
 	}
-	_, _, _, _, err = VerifyRangeProof(tinyTrie.Hash(), first, last, [][]byte{entry.k}, [][]byte{entry.v}, proof)
+	_, _, err = VerifyRangeProof(tinyTrie.Hash(), first, last, [][]byte{entry.k}, [][]byte{entry.v}, proof)
 	if err != nil {
 		t.Fatalf("Expected no error, got %v", err)
 	}
@@ -421,7 +421,7 @@ func TestAllElementsProof(t *testing.T) {
 		k = append(k, entries[i].k)
 		v = append(v, entries[i].v)
 	}
-	_, _, _, _, err := VerifyRangeProof(trie.Hash(), nil, nil, k, v, nil)
+	_, _, err := VerifyRangeProof(trie.Hash(), nil, nil, k, v, nil)
 	if err != nil {
 		t.Fatalf("Expected no error, got %v", err)
 	}
@@ -434,7 +434,7 @@ func TestAllElementsProof(t *testing.T) {
 	if err := trie.Prove(entries[len(entries)-1].k, 0, proof); err != nil {
 		t.Fatalf("Failed to prove the last node %v", err)
 	}
-	_, _, _, _, err = VerifyRangeProof(trie.Hash(), k[0], k[len(k)-1], k, v, proof)
+	_, _, err = VerifyRangeProof(trie.Hash(), k[0], k[len(k)-1], k, v, proof)
 	if err != nil {
 		t.Fatalf("Expected no error, got %v", err)
 	}
@@ -449,7 +449,7 @@ func TestAllElementsProof(t *testing.T) {
 	if err := trie.Prove(last, 0, proof); err != nil {
 		t.Fatalf("Failed to prove the last node %v", err)
 	}
-	_, _, _, _, err = VerifyRangeProof(trie.Hash(), first, last, k, v, proof)
+	_, _, err = VerifyRangeProof(trie.Hash(), first, last, k, v, proof)
 	if err != nil {
 		t.Fatalf("Expected no error, got %v", err)
 	}
@@ -482,7 +482,7 @@ func TestSingleSideRangeProof(t *testing.T) {
 				k = append(k, entries[i].k)
 				v = append(v, entries[i].v)
 			}
-			_, _, _, _, err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k[len(k)-1], k, v, proof)
+			_, _, err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k[len(k)-1], k, v, proof)
 			if err != nil {
 				t.Fatalf("Expected no error, got %v", err)
 			}
@@ -518,7 +518,7 @@ func TestReverseSingleSideRangeProof(t *testing.T) {
 				k = append(k, entries[i].k)
 				v = append(v, entries[i].v)
 			}
-			_, _, _, _, err := VerifyRangeProof(trie.Hash(), k[0], last.Bytes(), k, v, proof)
+			_, _, err := VerifyRangeProof(trie.Hash(), k[0], last.Bytes(), k, v, proof)
 			if err != nil {
 				t.Fatalf("Expected no error, got %v", err)
 			}
@@ -590,7 +590,7 @@ func TestBadRangeProof(t *testing.T) {
 			index = mrand.Intn(end - start)
 			vals[index] = nil
 		}
-		_, _, _, _, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof)
+		_, _, err := VerifyRangeProof(trie.Hash(), first, last, keys, vals, proof)
 		if err == nil {
 			t.Fatalf("%d Case %d index %d range: (%d->%d) expect error, got nil", i, testcase, index, start, end-1)
 		}
@@ -624,7 +624,7 @@ func TestGappedRangeProof(t *testing.T) {
 		keys = append(keys, entries[i].k)
 		vals = append(vals, entries[i].v)
 	}
-	_, _, _, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof)
+	_, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof)
 	if err == nil {
 		t.Fatal("expect error, got nil")
 	}
@@ -651,7 +651,7 @@ func TestSameSideProofs(t *testing.T) {
 	if err := trie.Prove(last, 0, proof); err != nil {
 		t.Fatalf("Failed to prove the last node %v", err)
 	}
-	_, _, _, _, err := VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof)
+	_, _, err := VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof)
 	if err == nil {
 		t.Fatalf("Expected error, got nil")
 	}
@@ -667,7 +667,7 @@ func TestSameSideProofs(t *testing.T) {
 	if err := trie.Prove(last, 0, proof); err != nil {
 		t.Fatalf("Failed to prove the last node %v", err)
 	}
-	_, _, _, _, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof)
+	_, _, err = VerifyRangeProof(trie.Hash(), first, last, [][]byte{entries[pos].k}, [][]byte{entries[pos].v}, proof)
 	if err == nil {
 		t.Fatalf("Expected error, got nil")
 	}
@@ -735,7 +735,7 @@ func TestHasRightElement(t *testing.T) {
 			k = append(k, entries[i].k)
 			v = append(v, entries[i].v)
 		}
-		_, _, _, hasMore, err := VerifyRangeProof(trie.Hash(), firstKey, lastKey, k, v, proof)
+		_, hasMore, err := VerifyRangeProof(trie.Hash(), firstKey, lastKey, k, v, proof)
 		if err != nil {
 			t.Fatalf("Expected no error, got %v", err)
 		}
@@ -768,25 +768,19 @@ func TestEmptyRangeProof(t *testing.T) {
 		if err := trie.Prove(first, 0, proof); err != nil {
 			t.Fatalf("Failed to prove the first node %v", err)
 		}
-		db, tr, not, _, err := VerifyRangeProof(trie.Hash(), first, nil, nil, nil, proof)
+		db, _, err := VerifyRangeProof(trie.Hash(), first, nil, nil, nil, proof)
 		if c.err && err == nil {
 			t.Fatalf("Expected error, got nil")
 		}
 		if !c.err && err != nil {
 			t.Fatalf("Expected no error, got %v", err)
 		}
-		// If no error was returned, ensure the returned trie and database contains
+		// If no error was returned, ensure the returned database contains
 		// the entire proof, since there's no value
 		if !c.err {
-			if err := tr.Prove(first, 0, memorydb.New()); err != nil {
-				t.Errorf("returned trie doesn't contain original proof: %v", err)
-			}
 			if memdb := db.(*memorydb.Database); memdb.Len() != proof.Len() {
 				t.Errorf("database entry count mismatch: have %d, want %d", memdb.Len(), proof.Len())
 			}
-			if not == nil {
-				t.Errorf("missing notary")
-			}
 		}
 	}
 }
@@ -805,6 +799,8 @@ func TestBloatedProof(t *testing.T) {
 	var vals [][]byte
 
 	proof := memorydb.New()
+	// In the 'malicious' case, we add proofs for every single item
+	// (but only one key/value pair used as leaf)
 	for i, entry := range entries {
 		trie.Prove(entry.k, 0, proof)
 		if i == 50 {
@@ -812,12 +808,15 @@ func TestBloatedProof(t *testing.T) {
 			vals = append(vals, entry.v)
 		}
 	}
+	// For reference, we use the same function, but _only_ prove the first
+	// and last element
 	want := memorydb.New()
 	trie.Prove(keys[0], 0, want)
 	trie.Prove(keys[len(keys)-1], 0, want)
 
-	_, _, notary, _, _ := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof)
-	if used := notary.Accessed().(*memorydb.Database); used.Len() != want.Len() {
+	db, _, _ := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, vals, proof)
+	// The db should not contain anything of the bloated data
+	if used := db.(*memorydb.Database); used.Len() != want.Len() {
 		t.Fatalf("notary proof size mismatch: have %d, want %d", used.Len(), want.Len())
 	}
 }
@@ -922,13 +921,40 @@ func benchmarkVerifyRangeProof(b *testing.B, size int) {
 
 	b.ResetTimer()
 	for i := 0; i < b.N; i++ {
-		_, _, _, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, values, proof)
+		_, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, values, proof)
 		if err != nil {
 			b.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err)
 		}
 	}
 }
 
+func BenchmarkVerifyRangeNoProof10(b *testing.B)   { benchmarkVerifyRangeNoProof(b, 100) }
+func BenchmarkVerifyRangeNoProof500(b *testing.B)  { benchmarkVerifyRangeNoProof(b, 500) }
+func BenchmarkVerifyRangeNoProof1000(b *testing.B) { benchmarkVerifyRangeNoProof(b, 1000) }
+
+func benchmarkVerifyRangeNoProof(b *testing.B, size int) {
+	trie, vals := randomTrie(size)
+	var entries entrySlice
+	for _, kv := range vals {
+		entries = append(entries, kv)
+	}
+	sort.Sort(entries)
+
+	var keys [][]byte
+	var values [][]byte
+	for _, entry := range entries {
+		keys = append(keys, entry.k)
+		values = append(values, entry.v)
+	}
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_, _, err := VerifyRangeProof(trie.Hash(), keys[0], keys[len(keys)-1], keys, values, nil)
+		if err != nil {
+			b.Fatalf("Expected no error, got %v", err)
+		}
+	}
+}
+
 func randomTrie(n int) (*Trie, map[string]*kv) {
 	trie := new(Trie)
 	vals := make(map[string]*kv)
-- 
GitLab