From 65ce550b37670ce34aeaeaa6e66510028d2f7603 Mon Sep 17 00:00:00 2001
From: gary rong <garyrong0905@gmail.com>
Date: Wed, 20 May 2020 20:45:38 +0800
Subject: [PATCH] trie: extend range proofs with non-existence (#21000)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* trie: implement range proof with non-existent edge proof

* trie: fix cornercase

* trie: consider empty range

* trie: add singleSide test

* trie: support all-elements range proof

* trie: fix typo

* trie: tiny typos and formulations

Co-authored-by: Péter Szilágyi <peterke@gmail.com>
---
 go.sum             |   2 +
 trie/proof.go      | 213 +++++++++++++++++++++-----
 trie/proof_test.go | 368 ++++++++++++++++++++++++++++++++++++++-------
 3 files changed, 491 insertions(+), 92 deletions(-)

diff --git a/go.sum b/go.sum
index 23b89d464..2ab3873f4 100644
--- a/go.sum
+++ b/go.sum
@@ -202,6 +202,8 @@ golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5h
 golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527 h1:uYVVQ9WP/Ds2ROhcaGPeIdVq0RIXVLwsHlnvJ+cT1So=
+golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884=
 golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
diff --git a/trie/proof.go b/trie/proof.go
index 07ce8e6d8..0f7d56a64 100644
--- a/trie/proof.go
+++ b/trie/proof.go
@@ -133,7 +133,7 @@ func VerifyProof(rootHash common.Hash, key []byte, proofDb ethdb.KeyValueReader)
 // The main purpose of this function is recovering a node
 // path from the merkle proof stream. All necessary nodes
 // will be resolved and leave the remaining as hashnode.
-func proofToPath(rootHash common.Hash, root node, key []byte, proofDb ethdb.KeyValueReader) (node, error) {
+func proofToPath(rootHash common.Hash, root node, key []byte, proofDb ethdb.KeyValueReader, allowNonExistent bool) (node, error) {
 	// resolveNode retrieves and resolves trie node from merkle proof stream
 	resolveNode := func(hash common.Hash) (node, error) {
 		buf, _ := proofDb.Get(hash[:])
@@ -146,7 +146,8 @@ func proofToPath(rootHash common.Hash, root node, key []byte, proofDb ethdb.KeyV
 		}
 		return n, err
 	}
-	// If the root node is empty, resolve it first
+	// If the root node is empty, resolve it first.
+	// Root node must be included in the proof.
 	if root == nil {
 		n, err := resolveNode(rootHash)
 		if err != nil {
@@ -165,7 +166,13 @@ func proofToPath(rootHash common.Hash, root node, key []byte, proofDb ethdb.KeyV
 		keyrest, child = get(parent, key, false)
 		switch cld := child.(type) {
 		case nil:
-			// The trie doesn't contain the key.
+			// The trie doesn't contain the key. It's possible
+			// the proof is a non-existing proof, but at least
+			// we can prove all resolved nodes are correct, it's
+			// enough for us to prove range.
+			if allowNonExistent {
+				return root, nil
+			}
 			return nil, errors.New("the node is not contained in trie")
 		case *shortNode:
 			key, parent = keyrest, child // Already resolved
@@ -205,7 +212,7 @@ func proofToPath(rootHash common.Hash, root node, key []byte, proofDb ethdb.KeyV
 // since the node content might be modified. Besides it can happen that some
 // fullnodes only have one child which is disallowed. But if the proof is valid,
 // the missing children will be filled, otherwise it will be thrown anyway.
-func unsetInternal(node node, left []byte, right []byte) error {
+func unsetInternal(n node, left []byte, right []byte) error {
 	left, right = keybytesToHex(left), keybytesToHex(right)
 
 	// todo(rjl493456442) different length edge keys should be supported
@@ -214,25 +221,37 @@ func unsetInternal(node node, left []byte, right []byte) error {
 	}
 	// Step down to the fork point
 	prefix, pos := prefixLen(left, right), 0
+	var parent node
 	for {
 		if pos >= prefix {
 			break
 		}
-		switch n := (node).(type) {
+		switch rn := (n).(type) {
 		case *shortNode:
-			if len(left)-pos < len(n.Key) || !bytes.Equal(n.Key, left[pos:pos+len(n.Key)]) {
+			if len(right)-pos < len(rn.Key) || !bytes.Equal(rn.Key, right[pos:pos+len(rn.Key)]) {
 				return errors.New("invalid edge path")
 			}
-			n.flags = nodeFlag{dirty: true}
-			node, pos = n.Val, pos+len(n.Key)
+			// Special case, the non-existent proof points to the same path
+			// as the existent proof, but the path of existent proof is longer.
+			// In this case, truncate the extra path(it should be recovered
+			// by node insertion).
+			if len(left)-pos < len(rn.Key) || !bytes.Equal(rn.Key, left[pos:pos+len(rn.Key)]) {
+				fn := parent.(*fullNode)
+				fn.Children[left[pos-1]] = nil
+				return nil
+			}
+			rn.flags = nodeFlag{dirty: true}
+			parent = n
+			n, pos = rn.Val, pos+len(rn.Key)
 		case *fullNode:
-			n.flags = nodeFlag{dirty: true}
-			node, pos = n.Children[left[pos]], pos+1
+			rn.flags = nodeFlag{dirty: true}
+			parent = n
+			n, pos = rn.Children[right[pos]], pos+1
 		default:
-			panic(fmt.Sprintf("%T: invalid node: %v", node, node))
+			panic(fmt.Sprintf("%T: invalid node: %v", n, n))
 		}
 	}
-	fn, ok := node.(*fullNode)
+	fn, ok := n.(*fullNode)
 	if !ok {
 		return errors.New("the fork point must be a fullnode")
 	}
@@ -241,50 +260,164 @@ func unsetInternal(node node, left []byte, right []byte) error {
 		fn.Children[i] = nil
 	}
 	fn.flags = nodeFlag{dirty: true}
-	unset(fn.Children[left[prefix]], left[prefix+1:], false)
-	unset(fn.Children[right[prefix]], right[prefix+1:], true)
+	if err := unset(fn, fn.Children[left[prefix]], left[prefix:], 1, false); err != nil {
+		return err
+	}
+	if err := unset(fn, fn.Children[right[prefix]], right[prefix:], 1, true); err != nil {
+		return err
+	}
 	return nil
 }
 
 // unset removes all internal node references either the left most or right most.
-func unset(root node, rest []byte, removeLeft bool) {
-	switch rn := root.(type) {
+// If we try to unset all right most references, it can meet these scenarios:
+//
+// - The given path is existent in the trie, unset the associated shortnode
+// - The given path is non-existent in the trie
+//   - the fork point is a fullnode, the corresponding child pointed by path
+//     is nil, return
+//   - the fork point is a shortnode, the key of shortnode is less than path,
+//     keep the entire branch and return.
+//   - the fork point is a shortnode, the key of shortnode is greater than path,
+//     unset the entire branch.
+//
+// If we try to unset all left most references, then the given path should
+// be existent.
+func unset(parent node, child node, key []byte, pos int, removeLeft bool) error {
+	switch cld := child.(type) {
 	case *fullNode:
 		if removeLeft {
-			for i := 0; i < int(rest[0]); i++ {
-				rn.Children[i] = nil
+			for i := 0; i < int(key[pos]); i++ {
+				cld.Children[i] = nil
 			}
-			rn.flags = nodeFlag{dirty: true}
+			cld.flags = nodeFlag{dirty: true}
 		} else {
-			for i := rest[0] + 1; i < 16; i++ {
-				rn.Children[i] = nil
+			for i := key[pos] + 1; i < 16; i++ {
+				cld.Children[i] = nil
 			}
-			rn.flags = nodeFlag{dirty: true}
+			cld.flags = nodeFlag{dirty: true}
 		}
-		unset(rn.Children[rest[0]], rest[1:], removeLeft)
+		return unset(cld, cld.Children[key[pos]], key, pos+1, removeLeft)
 	case *shortNode:
-		rn.flags = nodeFlag{dirty: true}
-		if _, ok := rn.Val.(valueNode); ok {
-			rn.Val = nilValueNode
-			return
+		if len(key[pos:]) < len(cld.Key) || !bytes.Equal(cld.Key, key[pos:pos+len(cld.Key)]) {
+			// Find the fork point, it's an non-existent branch.
+			if removeLeft {
+				return errors.New("invalid right edge proof")
+			}
+			if bytes.Compare(cld.Key, key[pos:]) > 0 {
+				// The key of fork shortnode is greater than the
+				// path(it belongs to the range), unset the entrie
+				// branch. The parent must be a fullnode.
+				fn := parent.(*fullNode)
+				fn.Children[key[pos-1]] = nil
+			} else {
+				// The key of fork shortnode is less than the
+				// path(it doesn't belong to the range), keep
+				// it with the cached hash available.
+				return nil
+			}
+		}
+		if _, ok := cld.Val.(valueNode); ok {
+			fn := parent.(*fullNode)
+			fn.Children[key[pos-1]] = nil
+			return nil
 		}
-		unset(rn.Val, rest[len(rn.Key):], removeLeft)
-	case hashNode, nil, valueNode:
-		panic("it shouldn't happen")
+		cld.flags = nodeFlag{dirty: true}
+		return unset(cld, cld.Val, key, pos+len(cld.Key), removeLeft)
+	case nil:
+		// If the node is nil, it's a child of the fork point
+		// fullnode(it's an non-existent branch).
+		if removeLeft {
+			return errors.New("invalid right edge proof")
+		}
+		return nil
+	default:
+		panic("it shouldn't happen") // hashNode, valueNode
 	}
 }
 
-// VerifyRangeProof checks whether the given leave nodes and edge proofs
+// VerifyRangeProof checks whether the given leaf nodes and edge proofs
 // can prove the given trie leaves range is matched with given root hash
 // and the range is consecutive(no gap inside).
-func VerifyRangeProof(rootHash common.Hash, keys [][]byte, values [][]byte, firstProof ethdb.KeyValueReader, lastProof ethdb.KeyValueReader) error {
+//
+// Note the given first edge proof can be non-existing proof. For example
+// the first proof is for an non-existent values 0x03. The given batch
+// leaves are [0x04, 0x05, .. 0x09]. It's still feasible to prove. But the
+// last edge proof should always be an existent proof.
+//
+// The firstKey is paired with firstProof, not necessarily the same as keys[0]
+// (unless firstProof is an existent proof).
+//
+// Expect the normal case, this function can also be used to verify the following
+// range proofs:
+//
+// - All elements proof. In this case the left and right proof can be nil, but the
+//   range should be all the leaves in the trie.
+//
+// - Zero element proof(left edge proof should be a non-existent proof). In this
+//   case if there are still some other leaves available on the right side, then
+//   an error will be returned.
+//
+// - One element proof. In this case no matter the left edge proof is a non-existent
+//   proof or not, we can always verify the correctness of the proof.
+func VerifyRangeProof(rootHash common.Hash, firstKey []byte, keys [][]byte, values [][]byte, firstProof ethdb.KeyValueReader, lastProof ethdb.KeyValueReader) error {
 	if len(keys) != len(values) {
 		return fmt.Errorf("inconsistent proof data, keys: %d, values: %d", len(keys), len(values))
 	}
+	// Special case, there is no edge proof at all. The given range is expected
+	// to be the whole leaf-set in the trie.
+	if firstProof == nil && lastProof == nil {
+		emptytrie, err := New(common.Hash{}, NewDatabase(memorydb.New()))
+		if err != nil {
+			return err
+		}
+		for index, key := range keys {
+			emptytrie.TryUpdate(key, values[index])
+		}
+		if emptytrie.Hash() != rootHash {
+			return fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, emptytrie.Hash())
+		}
+		return nil
+	}
+	// Special case, there is a provided non-existence proof and zero key/value
+	// pairs, meaning there are no more accounts / slots in the trie.
 	if len(keys) == 0 {
-		return fmt.Errorf("nothing to verify")
+		// Recover the non-existent proof to a path, ensure there is nothing left
+		root, err := proofToPath(rootHash, nil, firstKey, firstProof, true)
+		if err != nil {
+			return err
+		}
+		node, pos, firstKey := root, 0, keybytesToHex(firstKey)
+		for node != nil {
+			switch rn := node.(type) {
+			case *fullNode:
+				for i := firstKey[pos] + 1; i < 16; i++ {
+					if rn.Children[i] != nil {
+						return errors.New("more leaves available")
+					}
+				}
+				node, pos = rn.Children[firstKey[pos]], pos+1
+			case *shortNode:
+				if len(firstKey)-pos < len(rn.Key) || !bytes.Equal(rn.Key, firstKey[pos:pos+len(rn.Key)]) {
+					if bytes.Compare(rn.Key, firstKey[pos:]) < 0 {
+						node = nil
+						continue
+					} else {
+						return errors.New("more leaves available")
+					}
+				}
+				node, pos = rn.Val, pos+len(rn.Key)
+			case valueNode, hashNode:
+				return errors.New("more leaves available")
+			}
+		}
+		// Yeah, although we receive nothing, but we can prove
+		// there is no more leaf in the trie, return nil.
+		return nil
 	}
-	if len(keys) == 1 {
+	// Special case, there is only one element and left edge
+	// proof is an existent one.
+	if len(keys) == 1 && bytes.Equal(keys[0], firstKey) {
 		value, err := VerifyProof(rootHash, keys[0], firstProof)
 		if err != nil {
 			return err
@@ -296,19 +429,21 @@ func VerifyRangeProof(rootHash common.Hash, keys [][]byte, values [][]byte, firs
 	}
 	// Convert the edge proofs to edge trie paths. Then we can
 	// have the same tree architecture with the original one.
-	root, err := proofToPath(rootHash, nil, keys[0], firstProof)
+	// For the first edge proof, non-existent proof is allowed.
+	root, err := proofToPath(rootHash, nil, firstKey, firstProof, true)
 	if err != nil {
 		return err
 	}
 	// Pass the root node here, the second path will be merged
-	// with the first one.
-	root, err = proofToPath(rootHash, root, keys[len(keys)-1], lastProof)
+	// with the first one. For the last edge proof, non-existent
+	// proof is not allowed.
+	root, err = proofToPath(rootHash, root, keys[len(keys)-1], lastProof, false)
 	if err != nil {
 		return err
 	}
 	// Remove all internal references. All the removed parts should
 	// be re-filled(or re-constructed) by the given leaves range.
-	if err := unsetInternal(root, keys[0], keys[len(keys)-1]); err != nil {
+	if err := unsetInternal(root, firstKey, keys[len(keys)-1]); err != nil {
 		return err
 	}
 	// Rebuild the trie with the leave stream, the shape of trie
@@ -318,7 +453,7 @@ func VerifyRangeProof(rootHash common.Hash, keys [][]byte, values [][]byte, firs
 		newtrie.TryUpdate(key, values[index])
 	}
 	if newtrie.Hash() != rootHash {
-		return fmt.Errorf("invalid proof, wanthash %x, got %x", rootHash, newtrie.Hash())
+		return fmt.Errorf("invalid proof, want hash %x, got %x", rootHash, newtrie.Hash())
 	}
 	return nil
 }
diff --git a/trie/proof_test.go b/trie/proof_test.go
index ea02c289e..a68503f7d 100644
--- a/trie/proof_test.go
+++ b/trie/proof_test.go
@@ -98,12 +98,65 @@ func TestOneElementProof(t *testing.T) {
 	}
 }
 
+func TestBadProof(t *testing.T) {
+	trie, vals := randomTrie(800)
+	root := trie.Hash()
+	for i, prover := range makeProvers(trie) {
+		for _, kv := range vals {
+			proof := prover(kv.k)
+			if proof == nil {
+				t.Fatalf("prover %d: nil proof", i)
+			}
+			it := proof.NewIterator(nil, nil)
+			for i, d := 0, mrand.Intn(proof.Len()); i <= d; i++ {
+				it.Next()
+			}
+			key := it.Key()
+			val, _ := proof.Get(key)
+			proof.Delete(key)
+			it.Release()
+
+			mutateByte(val)
+			proof.Put(crypto.Keccak256(val), val)
+
+			if _, err := VerifyProof(root, kv.k, proof); err == nil {
+				t.Fatalf("prover %d: expected proof to fail for key %x", i, kv.k)
+			}
+		}
+	}
+}
+
+// Tests that missing keys can also be proven. The test explicitly uses a single
+// entry trie and checks for missing keys both before and after the single entry.
+func TestMissingKeyProof(t *testing.T) {
+	trie := new(Trie)
+	updateString(trie, "k", "v")
+
+	for i, key := range []string{"a", "j", "l", "z"} {
+		proof := memorydb.New()
+		trie.Prove([]byte(key), 0, proof)
+
+		if proof.Len() != 1 {
+			t.Errorf("test %d: proof should have one element", i)
+		}
+		val, err := VerifyProof(trie.Hash(), []byte(key), proof)
+		if err != nil {
+			t.Fatalf("test %d: failed to verify proof: %v\nraw proof: %x", i, err, proof)
+		}
+		if val != nil {
+			t.Fatalf("test %d: verified value mismatch: have %x, want nil", i, val)
+		}
+	}
+}
+
 type entrySlice []*kv
 
 func (p entrySlice) Len() int           { return len(p) }
 func (p entrySlice) Less(i, j int) bool { return bytes.Compare(p[i].k, p[j].k) < 0 }
 func (p entrySlice) Swap(i, j int)      { p[i], p[j] = p[j], p[i] }
 
+// TestRangeProof tests normal range proof with both edge proofs
+// as the existent proof. The test cases are generated randomly.
 func TestRangeProof(t *testing.T) {
 	trie, vals := randomTrie(4096)
 	var entries entrySlice
@@ -130,13 +183,253 @@ func TestRangeProof(t *testing.T) {
 			keys = append(keys, entries[i].k)
 			vals = append(vals, entries[i].v)
 		}
-		err := VerifyRangeProof(trie.Hash(), keys, vals, firstProof, lastProof)
+		err := VerifyRangeProof(trie.Hash(), keys[0], keys, vals, firstProof, lastProof)
+		if err != nil {
+			t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err)
+		}
+	}
+}
+
+// TestRangeProof tests normal range proof with the first edge proof
+// as the non-existent proof. The test cases are generated randomly.
+func TestRangeProofWithNonExistentProof(t *testing.T) {
+	trie, vals := randomTrie(4096)
+	var entries entrySlice
+	for _, kv := range vals {
+		entries = append(entries, kv)
+	}
+	sort.Sort(entries)
+	for i := 0; i < 500; i++ {
+		start := mrand.Intn(len(entries))
+		end := mrand.Intn(len(entries)-start) + start
+		if start == end {
+			continue
+		}
+		firstProof, lastProof := memorydb.New(), memorydb.New()
+
+		first := decreseKey(common.CopyBytes(entries[start].k))
+		if start != 0 && bytes.Equal(first, entries[start-1].k) {
+			continue
+		}
+		if err := trie.Prove(first, 0, firstProof); err != nil {
+			t.Fatalf("Failed to prove the first node %v", err)
+		}
+		if err := trie.Prove(entries[end-1].k, 0, lastProof); err != nil {
+			t.Fatalf("Failed to prove the last node %v", err)
+		}
+		var keys [][]byte
+		var vals [][]byte
+		for i := start; i < end; i++ {
+			keys = append(keys, entries[i].k)
+			vals = append(vals, entries[i].v)
+		}
+		err := VerifyRangeProof(trie.Hash(), first, keys, vals, firstProof, lastProof)
 		if err != nil {
 			t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err)
 		}
 	}
 }
 
+// TestRangeProofWithInvalidNonExistentProof tests such scenarios:
+// - The last edge proof is an non-existent proof
+// - There exists a gap between the first element and the left edge proof
+func TestRangeProofWithInvalidNonExistentProof(t *testing.T) {
+	trie, vals := randomTrie(4096)
+	var entries entrySlice
+	for _, kv := range vals {
+		entries = append(entries, kv)
+	}
+	sort.Sort(entries)
+
+	// Case 1
+	start, end := 100, 200
+	first, last := decreseKey(common.CopyBytes(entries[start].k)), increseKey(common.CopyBytes(entries[end].k))
+	firstProof, lastProof := memorydb.New(), memorydb.New()
+	if err := trie.Prove(first, 0, firstProof); err != nil {
+		t.Fatalf("Failed to prove the first node %v", err)
+	}
+	if err := trie.Prove(last, 0, lastProof); err != nil {
+		t.Fatalf("Failed to prove the last node %v", err)
+	}
+	var k [][]byte
+	var v [][]byte
+	for i := start; i < end; i++ {
+		k = append(k, entries[i].k)
+		v = append(v, entries[i].v)
+	}
+	err := VerifyRangeProof(trie.Hash(), first, k, v, firstProof, lastProof)
+	if err == nil {
+		t.Fatalf("Expected to detect the error, got nil")
+	}
+
+	// Case 2
+	start, end = 100, 200
+	first = decreseKey(common.CopyBytes(entries[start].k))
+
+	firstProof, lastProof = memorydb.New(), memorydb.New()
+	if err := trie.Prove(first, 0, firstProof); err != nil {
+		t.Fatalf("Failed to prove the first node %v", err)
+	}
+	if err := trie.Prove(entries[end-1].k, 0, lastProof); err != nil {
+		t.Fatalf("Failed to prove the last node %v", err)
+	}
+	start = 105 // Gap created
+	k = make([][]byte, 0)
+	v = make([][]byte, 0)
+	for i := start; i < end; i++ {
+		k = append(k, entries[i].k)
+		v = append(v, entries[i].v)
+	}
+	err = VerifyRangeProof(trie.Hash(), first, k, v, firstProof, lastProof)
+	if err == nil {
+		t.Fatalf("Expected to detect the error, got nil")
+	}
+}
+
+// TestOneElementRangeProof tests the proof with only one
+// element. The first edge proof can be existent one or
+// non-existent one.
+func TestOneElementRangeProof(t *testing.T) {
+	trie, vals := randomTrie(4096)
+	var entries entrySlice
+	for _, kv := range vals {
+		entries = append(entries, kv)
+	}
+	sort.Sort(entries)
+
+	// One element with existent edge proof
+	start := 1000
+	firstProof, lastProof := memorydb.New(), memorydb.New()
+	if err := trie.Prove(entries[start].k, 0, firstProof); err != nil {
+		t.Fatalf("Failed to prove the first node %v", err)
+	}
+	if err := trie.Prove(entries[start].k, 0, lastProof); err != nil {
+		t.Fatalf("Failed to prove the last node %v", err)
+	}
+	err := VerifyRangeProof(trie.Hash(), entries[start].k, [][]byte{entries[start].k}, [][]byte{entries[start].v}, firstProof, lastProof)
+	if err != nil {
+		t.Fatalf("Expected no error, got %v", err)
+	}
+
+	// One element with non-existent edge proof
+	start = 1000
+	first := decreseKey(common.CopyBytes(entries[start].k))
+	firstProof, lastProof = memorydb.New(), memorydb.New()
+	if err := trie.Prove(first, 0, firstProof); err != nil {
+		t.Fatalf("Failed to prove the first node %v", err)
+	}
+	if err := trie.Prove(entries[start].k, 0, lastProof); err != nil {
+		t.Fatalf("Failed to prove the last node %v", err)
+	}
+	err = VerifyRangeProof(trie.Hash(), first, [][]byte{entries[start].k}, [][]byte{entries[start].v}, firstProof, lastProof)
+	if err != nil {
+		t.Fatalf("Expected no error, got %v", err)
+	}
+}
+
+// TestEmptyRangeProof tests the range proof with "no" element.
+// The first edge proof must be a non-existent proof.
+func TestEmptyRangeProof(t *testing.T) {
+	trie, vals := randomTrie(4096)
+	var entries entrySlice
+	for _, kv := range vals {
+		entries = append(entries, kv)
+	}
+	sort.Sort(entries)
+
+	var cases = []struct {
+		pos int
+		err bool
+	}{
+		{len(entries) - 1, false},
+		{500, true},
+	}
+	for _, c := range cases {
+		firstProof := memorydb.New()
+		first := increseKey(common.CopyBytes(entries[c.pos].k))
+		if err := trie.Prove(first, 0, firstProof); err != nil {
+			t.Fatalf("Failed to prove the first node %v", err)
+		}
+		err := VerifyRangeProof(trie.Hash(), first, nil, nil, firstProof, nil)
+		if c.err && err == nil {
+			t.Fatalf("Expected error, got nil")
+		}
+		if !c.err && err != nil {
+			t.Fatalf("Expected no error, got %v", err)
+		}
+	}
+}
+
+// TestAllElementsProof tests the range proof with all elements.
+// The edge proofs can be nil.
+func TestAllElementsProof(t *testing.T) {
+	trie, vals := randomTrie(4096)
+	var entries entrySlice
+	for _, kv := range vals {
+		entries = append(entries, kv)
+	}
+	sort.Sort(entries)
+
+	var k [][]byte
+	var v [][]byte
+	for i := 0; i < len(entries); i++ {
+		k = append(k, entries[i].k)
+		v = append(v, entries[i].v)
+	}
+	err := VerifyRangeProof(trie.Hash(), k[0], k, v, nil, nil)
+	if err != nil {
+		t.Fatalf("Expected no error, got %v", err)
+	}
+
+	// Even with edge proofs, it should still work.
+	firstProof, lastProof := memorydb.New(), memorydb.New()
+	if err := trie.Prove(entries[0].k, 0, firstProof); err != nil {
+		t.Fatalf("Failed to prove the first node %v", err)
+	}
+	if err := trie.Prove(entries[len(entries)-1].k, 0, lastProof); err != nil {
+		t.Fatalf("Failed to prove the last node %v", err)
+	}
+	err = VerifyRangeProof(trie.Hash(), k[0], k, v, firstProof, lastProof)
+	if err != nil {
+		t.Fatalf("Expected no error, got %v", err)
+	}
+}
+
+// TestSingleSideRangeProof tests the range starts from zero.
+func TestSingleSideRangeProof(t *testing.T) {
+	trie := new(Trie)
+	var entries entrySlice
+	for i := 0; i < 4096; i++ {
+		value := &kv{randBytes(32), randBytes(20), false}
+		trie.Update(value.k, value.v)
+		entries = append(entries, value)
+	}
+	sort.Sort(entries)
+
+	var cases = []int{0, 1, 50, 100, 1000, 2000, len(entries) - 1}
+	for _, pos := range cases {
+		firstProof, lastProof := memorydb.New(), memorydb.New()
+		if err := trie.Prove(common.Hash{}.Bytes(), 0, firstProof); err != nil {
+			t.Fatalf("Failed to prove the first node %v", err)
+		}
+		if err := trie.Prove(entries[pos].k, 0, lastProof); err != nil {
+			t.Fatalf("Failed to prove the first node %v", err)
+		}
+		k := make([][]byte, 0)
+		v := make([][]byte, 0)
+		for i := 0; i <= pos; i++ {
+			k = append(k, entries[i].k)
+			v = append(v, entries[i].v)
+		}
+		err := VerifyRangeProof(trie.Hash(), common.Hash{}.Bytes(), k, v, firstProof, lastProof)
+		if err != nil {
+			t.Fatalf("Expected no error, got %v", err)
+		}
+	}
+}
+
+// TestBadRangeProof tests a few cases which the proof is wrong.
+// The prover is expected to detect the error.
 func TestBadRangeProof(t *testing.T) {
 	trie, vals := randomTrie(4096)
 	var entries entrySlice
@@ -208,7 +501,7 @@ func TestBadRangeProof(t *testing.T) {
 			index = mrand.Intn(end - start)
 			vals[index] = nil
 		}
-		err := VerifyRangeProof(trie.Hash(), keys, vals, firstProof, lastProof)
+		err := VerifyRangeProof(trie.Hash(), keys[0], keys, vals, firstProof, lastProof)
 		if err == nil {
 			t.Fatalf("%d Case %d index %d range: (%d->%d) expect error, got nil", i, testcase, index, start, end-1)
 		}
@@ -242,72 +535,41 @@ func TestGappedRangeProof(t *testing.T) {
 		keys = append(keys, entries[i].k)
 		vals = append(vals, entries[i].v)
 	}
-	err := VerifyRangeProof(trie.Hash(), keys, vals, firstProof, lastProof)
+	err := VerifyRangeProof(trie.Hash(), keys[0], keys, vals, firstProof, lastProof)
 	if err == nil {
 		t.Fatal("expect error, got nil")
 	}
 }
 
-func TestBadProof(t *testing.T) {
-	trie, vals := randomTrie(800)
-	root := trie.Hash()
-	for i, prover := range makeProvers(trie) {
-		for _, kv := range vals {
-			proof := prover(kv.k)
-			if proof == nil {
-				t.Fatalf("prover %d: nil proof", i)
-			}
-			it := proof.NewIterator(nil, nil)
-			for i, d := 0, mrand.Intn(proof.Len()); i <= d; i++ {
-				it.Next()
-			}
-			key := it.Key()
-			val, _ := proof.Get(key)
-			proof.Delete(key)
-			it.Release()
-
-			mutateByte(val)
-			proof.Put(crypto.Keccak256(val), val)
-
-			if _, err := VerifyProof(root, kv.k, proof); err == nil {
-				t.Fatalf("prover %d: expected proof to fail for key %x", i, kv.k)
-			}
+// mutateByte changes one byte in b.
+func mutateByte(b []byte) {
+	for r := mrand.Intn(len(b)); ; {
+		new := byte(mrand.Intn(255))
+		if new != b[r] {
+			b[r] = new
+			break
 		}
 	}
 }
 
-// Tests that missing keys can also be proven. The test explicitly uses a single
-// entry trie and checks for missing keys both before and after the single entry.
-func TestMissingKeyProof(t *testing.T) {
-	trie := new(Trie)
-	updateString(trie, "k", "v")
-
-	for i, key := range []string{"a", "j", "l", "z"} {
-		proof := memorydb.New()
-		trie.Prove([]byte(key), 0, proof)
-
-		if proof.Len() != 1 {
-			t.Errorf("test %d: proof should have one element", i)
-		}
-		val, err := VerifyProof(trie.Hash(), []byte(key), proof)
-		if err != nil {
-			t.Fatalf("test %d: failed to verify proof: %v\nraw proof: %x", i, err, proof)
-		}
-		if val != nil {
-			t.Fatalf("test %d: verified value mismatch: have %x, want nil", i, val)
+func increseKey(key []byte) []byte {
+	for i := len(key) - 1; i >= 0; i-- {
+		key[i]++
+		if key[i] != 0x0 {
+			break
 		}
 	}
+	return key
 }
 
-// mutateByte changes one byte in b.
-func mutateByte(b []byte) {
-	for r := mrand.Intn(len(b)); ; {
-		new := byte(mrand.Intn(255))
-		if new != b[r] {
-			b[r] = new
+func decreseKey(key []byte) []byte {
+	for i := len(key) - 1; i >= 0; i-- {
+		key[i]--
+		if key[i] != 0xff {
 			break
 		}
 	}
+	return key
 }
 
 func BenchmarkProve(b *testing.B) {
@@ -379,7 +641,7 @@ func benchmarkVerifyRangeProof(b *testing.B, size int) {
 
 	b.ResetTimer()
 	for i := 0; i < b.N; i++ {
-		err := VerifyRangeProof(trie.Hash(), keys, values, firstProof, lastProof)
+		err := VerifyRangeProof(trie.Hash(), keys[0], keys, values, firstProof, lastProof)
 		if err != nil {
 			b.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err)
 		}
-- 
GitLab