From 9e1bd0f3671d19d4964ed8c8a95edfd12413d8c3 Mon Sep 17 00:00:00 2001
From: gary rong <garyrong0905@gmail.com>
Date: Fri, 22 Jan 2021 17:11:24 +0800
Subject: [PATCH] trie: fix range prover (#22210)

Fixes a special case when the trie only has a single trie node and the range proof only contains a single element.
---
 trie/proof.go      | 44 +++++++++++++++++++++++++++++---------------
 trie/proof_test.go | 19 +++++++++++++++++++
 2 files changed, 48 insertions(+), 15 deletions(-)

diff --git a/trie/proof.go b/trie/proof.go
index e7102f12b..61c35a842 100644
--- a/trie/proof.go
+++ b/trie/proof.go
@@ -216,7 +216,7 @@ func proofToPath(rootHash common.Hash, root node, key []byte, proofDb ethdb.KeyV
 //
 // Note we have the assumption here the given boundary keys are different
 // and right is larger than left.
-func unsetInternal(n node, left []byte, right []byte) error {
+func unsetInternal(n node, left []byte, right []byte) (bool, error) {
 	left, right = keybytesToHex(left), keybytesToHex(right)
 
 	// Step down to the fork point. There are two scenarios can happen:
@@ -278,45 +278,55 @@ findFork:
 		// - left proof points to the shortnode, but right proof is greater
 		// - right proof points to the shortnode, but left proof is less
 		if shortForkLeft == -1 && shortForkRight == -1 {
-			return errors.New("empty range")
+			return false, errors.New("empty range")
 		}
 		if shortForkLeft == 1 && shortForkRight == 1 {
-			return errors.New("empty range")
+			return false, errors.New("empty range")
 		}
 		if shortForkLeft != 0 && shortForkRight != 0 {
+			// The fork point is root node, unset the entire trie
+			if parent == nil {
+				return true, nil
+			}
 			parent.(*fullNode).Children[left[pos-1]] = nil
-			return nil
+			return false, nil
 		}
 		// Only one proof points to non-existent key.
 		if shortForkRight != 0 {
-			// Unset left proof's path
 			if _, ok := rn.Val.(valueNode); ok {
+				// The fork point is root node, unset the entire trie
+				if parent == nil {
+					return true, nil
+				}
 				parent.(*fullNode).Children[left[pos-1]] = nil
-				return nil
+				return false, nil
 			}
-			return unset(rn, rn.Val, left[pos:], len(rn.Key), false)
+			return false, unset(rn, rn.Val, left[pos:], len(rn.Key), false)
 		}
 		if shortForkLeft != 0 {
-			// Unset right proof's path.
 			if _, ok := rn.Val.(valueNode); ok {
+				// The fork point is root node, unset the entire trie
+				if parent == nil {
+					return true, nil
+				}
 				parent.(*fullNode).Children[right[pos-1]] = nil
-				return nil
+				return false, nil
 			}
-			return unset(rn, rn.Val, right[pos:], len(rn.Key), true)
+			return false, unset(rn, rn.Val, right[pos:], len(rn.Key), true)
 		}
-		return nil
+		return false, nil
 	case *fullNode:
 		// unset all internal nodes in the forkpoint
 		for i := left[pos] + 1; i < right[pos]; i++ {
 			rn.Children[i] = nil
 		}
 		if err := unset(rn, rn.Children[left[pos]], left[pos:], 1, false); err != nil {
-			return err
+			return false, err
 		}
 		if err := unset(rn, rn.Children[right[pos]], right[pos:], 1, true); err != nil {
-			return err
+			return false, err
 		}
-		return nil
+		return false, nil
 	default:
 		panic(fmt.Sprintf("%T: invalid node: %v", n, n))
 	}
@@ -560,7 +570,8 @@ func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, key
 	}
 	// Remove all internal references. All the removed parts should
 	// be re-filled(or re-constructed) by the given leaves range.
-	if err := unsetInternal(root, firstKey, lastKey); err != nil {
+	empty, err := unsetInternal(root, firstKey, lastKey)
+	if err != nil {
 		return nil, nil, nil, false, err
 	}
 	// Rebuild the trie with the leaf stream, the shape of trie
@@ -570,6 +581,9 @@ func VerifyRangeProof(rootHash common.Hash, firstKey []byte, lastKey []byte, key
 		triedb = NewDatabase(diskdb)
 	)
 	tr := &Trie{root: root, db: triedb}
+	if empty {
+		tr.root = nil
+	}
 	for index, key := range keys {
 		tr.TryUpdate(key, values[index])
 	}
diff --git a/trie/proof_test.go b/trie/proof_test.go
index 3ecd31888..304affa9f 100644
--- a/trie/proof_test.go
+++ b/trie/proof_test.go
@@ -384,6 +384,25 @@ func TestOneElementRangeProof(t *testing.T) {
 	if err != nil {
 		t.Fatalf("Expected no error, got %v", err)
 	}
+
+	// Test the mini trie with only a single element.
+	tinyTrie := new(Trie)
+	entry := &kv{randBytes(32), randBytes(20), false}
+	tinyTrie.Update(entry.k, entry.v)
+
+	first = common.HexToHash("0x0000000000000000000000000000000000000000000000000000000000000000").Bytes()
+	last = entry.k
+	proof = memorydb.New()
+	if err := tinyTrie.Prove(first, 0, proof); err != nil {
+		t.Fatalf("Failed to prove the first node %v", err)
+	}
+	if err := tinyTrie.Prove(last, 0, proof); err != nil {
+		t.Fatalf("Failed to prove the last node %v", err)
+	}
+	_, _, _, _, err = VerifyRangeProof(tinyTrie.Hash(), first, last, [][]byte{entry.k}, [][]byte{entry.v}, proof)
+	if err != nil {
+		t.Fatalf("Expected no error, got %v", err)
+	}
 }
 
 // TestAllElementsProof tests the range proof with all elements.
-- 
GitLab