From 44ff3f3dc98a5bd72e06ff6b05739c2dce8c9b62 Mon Sep 17 00:00:00 2001
From: gary rong <garyrong0905@gmail.com>
Date: Fri, 24 Apr 2020 19:37:56 +0800
Subject: [PATCH] trie: initial implementation for range proof (#20908)

* trie: initial implementation for range proof

* trie: add benchmark

* trie: fix lint

* trie: fix minor issue

* trie: unset the edge valuenode as well

* trie: unset the edge valuenode as nilValuenode
---
 les/odr_requests.go |   6 +-
 trie/proof.go       | 221 ++++++++++++++++++++++++++++++++++++++++++--
 trie/proof_test.go  | 188 ++++++++++++++++++++++++++++++++++++-
 3 files changed, 400 insertions(+), 15 deletions(-)

diff --git a/les/odr_requests.go b/les/odr_requests.go
index 146da2213..c4b38060c 100644
--- a/les/odr_requests.go
+++ b/les/odr_requests.go
@@ -224,7 +224,7 @@ func (r *TrieRequest) Validate(db ethdb.Database, msg *Msg) error {
 	// Verify the proof and store if checks out
 	nodeSet := proofs.NodeSet()
 	reads := &readTraceDB{db: nodeSet}
-	if _, _, err := trie.VerifyProof(r.Id.Root, r.Key, reads); err != nil {
+	if _, err := trie.VerifyProof(r.Id.Root, r.Key, reads); err != nil {
 		return fmt.Errorf("merkle proof verification failed: %v", err)
 	}
 	// check if all nodes have been read by VerifyProof
@@ -378,7 +378,7 @@ func (r *ChtRequest) Validate(db ethdb.Database, msg *Msg) error {
 		binary.BigEndian.PutUint64(encNumber[:], r.BlockNum)
 
 		reads := &readTraceDB{db: nodeSet}
-		value, _, err := trie.VerifyProof(r.ChtRoot, encNumber[:], reads)
+		value, err := trie.VerifyProof(r.ChtRoot, encNumber[:], reads)
 		if err != nil {
 			return fmt.Errorf("merkle proof verification failed: %v", err)
 		}
@@ -470,7 +470,7 @@ func (r *BloomRequest) Validate(db ethdb.Database, msg *Msg) error {
 
 	for i, idx := range r.SectionIndexList {
 		binary.BigEndian.PutUint64(encNumber[2:], idx)
-		value, _, err := trie.VerifyProof(r.BloomTrieRoot, encNumber[:], reads)
+		value, err := trie.VerifyProof(r.BloomTrieRoot, encNumber[:], reads)
 		if err != nil {
 			return err
 		}
diff --git a/trie/proof.go b/trie/proof.go
index 58ca69c68..07ce8e6d8 100644
--- a/trie/proof.go
+++ b/trie/proof.go
@@ -18,10 +18,12 @@ package trie
 
 import (
 	"bytes"
+	"errors"
 	"fmt"
 
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/ethdb"
+	"github.com/ethereum/go-ethereum/ethdb/memorydb"
 	"github.com/ethereum/go-ethereum/log"
 	"github.com/ethereum/go-ethereum/rlp"
 )
@@ -101,33 +103,232 @@ func (t *SecureTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWri
 // VerifyProof checks merkle proofs. The given proof must contain the value for
 // key in a trie with the given root hash. VerifyProof returns an error if the
 // proof contains invalid trie nodes or the wrong value.
-func VerifyProof(rootHash common.Hash, key []byte, proofDb ethdb.KeyValueReader) (value []byte, nodes int, err error) {
+func VerifyProof(rootHash common.Hash, key []byte, proofDb ethdb.KeyValueReader) (value []byte, err error) {
 	key = keybytesToHex(key)
 	wantHash := rootHash
 	for i := 0; ; i++ {
 		buf, _ := proofDb.Get(wantHash[:])
 		if buf == nil {
-			return nil, i, fmt.Errorf("proof node %d (hash %064x) missing", i, wantHash)
+			return nil, fmt.Errorf("proof node %d (hash %064x) missing", i, wantHash)
 		}
 		n, err := decodeNode(wantHash[:], buf)
 		if err != nil {
-			return nil, i, fmt.Errorf("bad proof node %d: %v", i, err)
+			return nil, fmt.Errorf("bad proof node %d: %v", i, err)
 		}
-		keyrest, cld := get(n, key)
+		keyrest, cld := get(n, key, true)
 		switch cld := cld.(type) {
 		case nil:
 			// The trie doesn't contain the key.
-			return nil, i, nil
+			return nil, nil
 		case hashNode:
 			key = keyrest
 			copy(wantHash[:], cld)
 		case valueNode:
-			return cld, i + 1, nil
+			return cld, nil
 		}
 	}
 }
 
-func get(tn node, key []byte) ([]byte, node) {
+// proofToPath converts a merkle proof to trie node path.
+// The main purpose of this function is recovering a node
+// path from the merkle proof stream. All necessary nodes
+// will be resolved and leave the remaining as hashnode.
+func proofToPath(rootHash common.Hash, root node, key []byte, proofDb ethdb.KeyValueReader) (node, error) {
+	// resolveNode retrieves and resolves trie node from merkle proof stream
+	resolveNode := func(hash common.Hash) (node, error) {
+		buf, _ := proofDb.Get(hash[:])
+		if buf == nil {
+			return nil, fmt.Errorf("proof node (hash %064x) missing", hash)
+		}
+		n, err := decodeNode(hash[:], buf)
+		if err != nil {
+			return nil, fmt.Errorf("bad proof node %v", err)
+		}
+		return n, err
+	}
+	// If the root node is empty, resolve it first
+	if root == nil {
+		n, err := resolveNode(rootHash)
+		if err != nil {
+			return nil, err
+		}
+		root = n
+	}
+	var (
+		err           error
+		child, parent node
+		keyrest       []byte
+		terminate     bool
+	)
+	key, parent = keybytesToHex(key), root
+	for {
+		keyrest, child = get(parent, key, false)
+		switch cld := child.(type) {
+		case nil:
+			// The trie doesn't contain the key.
+			return nil, errors.New("the node is not contained in trie")
+		case *shortNode:
+			key, parent = keyrest, child // Already resolved
+			continue
+		case *fullNode:
+			key, parent = keyrest, child // Already resolved
+			continue
+		case hashNode:
+			child, err = resolveNode(common.BytesToHash(cld))
+			if err != nil {
+				return nil, err
+			}
+		case valueNode:
+			terminate = true
+		}
+		// Link the parent and child.
+		switch pnode := parent.(type) {
+		case *shortNode:
+			pnode.Val = child
+		case *fullNode:
+			pnode.Children[key[0]] = child
+		default:
+			panic(fmt.Sprintf("%T: invalid node: %v", pnode, pnode))
+		}
+		if terminate {
+			return root, nil // The whole path is resolved
+		}
+		key, parent = keyrest, child
+	}
+}
+
+// unsetInternal removes all internal node references(hashnode, embedded node).
+// It should be called after a trie is constructed with two edge proofs. Also
+// the given boundary keys must be the one used to construct the edge proofs.
+//
+// It's the key step for range proof. All visited nodes should be marked dirty
+// since the node content might be modified. Besides it can happen that some
+// fullnodes only have one child which is disallowed. But if the proof is valid,
+// the missing children will be filled, otherwise it will be thrown anyway.
+func unsetInternal(node node, left []byte, right []byte) error {
+	left, right = keybytesToHex(left), keybytesToHex(right)
+
+	// todo(rjl493456442) different length edge keys should be supported
+	if len(left) != len(right) {
+		return errors.New("inconsistent edge path")
+	}
+	// Step down to the fork point
+	prefix, pos := prefixLen(left, right), 0
+	for {
+		if pos >= prefix {
+			break
+		}
+		switch n := (node).(type) {
+		case *shortNode:
+			if len(left)-pos < len(n.Key) || !bytes.Equal(n.Key, left[pos:pos+len(n.Key)]) {
+				return errors.New("invalid edge path")
+			}
+			n.flags = nodeFlag{dirty: true}
+			node, pos = n.Val, pos+len(n.Key)
+		case *fullNode:
+			n.flags = nodeFlag{dirty: true}
+			node, pos = n.Children[left[pos]], pos+1
+		default:
+			panic(fmt.Sprintf("%T: invalid node: %v", node, node))
+		}
+	}
+	fn, ok := node.(*fullNode)
+	if !ok {
+		return errors.New("the fork point must be a fullnode")
+	}
+	// Find the fork point! Unset all intermediate references
+	for i := left[prefix] + 1; i < right[prefix]; i++ {
+		fn.Children[i] = nil
+	}
+	fn.flags = nodeFlag{dirty: true}
+	unset(fn.Children[left[prefix]], left[prefix+1:], false)
+	unset(fn.Children[right[prefix]], right[prefix+1:], true)
+	return nil
+}
+
+// unset removes all internal node references either the left most or right most.
+func unset(root node, rest []byte, removeLeft bool) {
+	switch rn := root.(type) {
+	case *fullNode:
+		if removeLeft {
+			for i := 0; i < int(rest[0]); i++ {
+				rn.Children[i] = nil
+			}
+			rn.flags = nodeFlag{dirty: true}
+		} else {
+			for i := rest[0] + 1; i < 16; i++ {
+				rn.Children[i] = nil
+			}
+			rn.flags = nodeFlag{dirty: true}
+		}
+		unset(rn.Children[rest[0]], rest[1:], removeLeft)
+	case *shortNode:
+		rn.flags = nodeFlag{dirty: true}
+		if _, ok := rn.Val.(valueNode); ok {
+			rn.Val = nilValueNode
+			return
+		}
+		unset(rn.Val, rest[len(rn.Key):], removeLeft)
+	case hashNode, nil, valueNode:
+		panic("it shouldn't happen")
+	}
+}
+
+// VerifyRangeProof checks whether the given leave nodes and edge proofs
+// can prove the given trie leaves range is matched with given root hash
+// and the range is consecutive(no gap inside).
+func VerifyRangeProof(rootHash common.Hash, keys [][]byte, values [][]byte, firstProof ethdb.KeyValueReader, lastProof ethdb.KeyValueReader) error {
+	if len(keys) != len(values) {
+		return fmt.Errorf("inconsistent proof data, keys: %d, values: %d", len(keys), len(values))
+	}
+	if len(keys) == 0 {
+		return fmt.Errorf("nothing to verify")
+	}
+	if len(keys) == 1 {
+		value, err := VerifyProof(rootHash, keys[0], firstProof)
+		if err != nil {
+			return err
+		}
+		if !bytes.Equal(value, values[0]) {
+			return fmt.Errorf("correct proof but invalid data")
+		}
+		return nil
+	}
+	// Convert the edge proofs to edge trie paths. Then we can
+	// have the same tree architecture with the original one.
+	root, err := proofToPath(rootHash, nil, keys[0], firstProof)
+	if err != nil {
+		return err
+	}
+	// Pass the root node here, the second path will be merged
+	// with the first one.
+	root, err = proofToPath(rootHash, root, keys[len(keys)-1], lastProof)
+	if err != nil {
+		return err
+	}
+	// Remove all internal references. All the removed parts should
+	// be re-filled(or re-constructed) by the given leaves range.
+	if err := unsetInternal(root, keys[0], keys[len(keys)-1]); err != nil {
+		return err
+	}
+	// Rebuild the trie with the leave stream, the shape of trie
+	// should be same with the original one.
+	newtrie := &Trie{root: root, db: NewDatabase(memorydb.New())}
+	for index, key := range keys {
+		newtrie.TryUpdate(key, values[index])
+	}
+	if newtrie.Hash() != rootHash {
+		return fmt.Errorf("invalid proof, wanthash %x, got %x", rootHash, newtrie.Hash())
+	}
+	return nil
+}
+
+// get returns the child of the given node. Return nil if the
+// node with specified key doesn't exist at all.
+//
+// There is an additional flag `skipResolved`. If it's set then
+// all resolved nodes won't be returned.
+func get(tn node, key []byte, skipResolved bool) ([]byte, node) {
 	for {
 		switch n := tn.(type) {
 		case *shortNode:
@@ -136,9 +337,15 @@ func get(tn node, key []byte) ([]byte, node) {
 			}
 			tn = n.Val
 			key = key[len(n.Key):]
+			if !skipResolved {
+				return key, tn
+			}
 		case *fullNode:
 			tn = n.Children[key[0]]
 			key = key[1:]
+			if !skipResolved {
+				return key, tn
+			}
 		case hashNode:
 			return key, n
 		case nil:
diff --git a/trie/proof_test.go b/trie/proof_test.go
index 4caae7338..781702b87 100644
--- a/trie/proof_test.go
+++ b/trie/proof_test.go
@@ -20,6 +20,7 @@ import (
 	"bytes"
 	crand "crypto/rand"
 	mrand "math/rand"
+	"sort"
 	"testing"
 	"time"
 
@@ -65,7 +66,7 @@ func TestProof(t *testing.T) {
 			if proof == nil {
 				t.Fatalf("prover %d: missing key %x while constructing proof", i, kv.k)
 			}
-			val, _, err := VerifyProof(root, kv.k, proof)
+			val, err := VerifyProof(root, kv.k, proof)
 			if err != nil {
 				t.Fatalf("prover %d: failed to verify proof for key %x: %v\nraw proof: %x", i, kv.k, err, proof)
 			}
@@ -87,7 +88,7 @@ func TestOneElementProof(t *testing.T) {
 		if proof.Len() != 1 {
 			t.Errorf("prover %d: proof should have one element", i)
 		}
-		val, _, err := VerifyProof(trie.Hash(), []byte("k"), proof)
+		val, err := VerifyProof(trie.Hash(), []byte("k"), proof)
 		if err != nil {
 			t.Fatalf("prover %d: failed to verify proof: %v\nraw proof: %x", i, err, proof)
 		}
@@ -97,6 +98,145 @@ func TestOneElementProof(t *testing.T) {
 	}
 }
 
+type entrySlice []*kv
+
+func (p entrySlice) Len() int           { return len(p) }
+func (p entrySlice) Less(i, j int) bool { return bytes.Compare(p[i].k, p[j].k) < 0 }
+func (p entrySlice) Swap(i, j int)      { p[i], p[j] = p[j], p[i] }
+
+func TestRangeProof(t *testing.T) {
+	trie, vals := randomTrie(4096)
+	var entries entrySlice
+	for _, kv := range vals {
+		entries = append(entries, kv)
+	}
+	sort.Sort(entries)
+	for i := 0; i < 500; i++ {
+		start := mrand.Intn(len(entries))
+		end := mrand.Intn(len(entries)-start) + start
+		if start == end {
+			continue
+		}
+		firstProof, lastProof := memorydb.New(), memorydb.New()
+		if err := trie.Prove(entries[start].k, 0, firstProof); err != nil {
+			t.Fatalf("Failed to prove the first node %v", err)
+		}
+		if err := trie.Prove(entries[end-1].k, 0, lastProof); err != nil {
+			t.Fatalf("Failed to prove the last node %v", err)
+		}
+		var keys [][]byte
+		var vals [][]byte
+		for i := start; i < end; i++ {
+			keys = append(keys, entries[i].k)
+			vals = append(vals, entries[i].v)
+		}
+		err := VerifyRangeProof(trie.Hash(), keys, vals, firstProof, lastProof)
+		if err != nil {
+			t.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err)
+		}
+	}
+}
+
+func TestBadRangeProof(t *testing.T) {
+	trie, vals := randomTrie(4096)
+	var entries entrySlice
+	for _, kv := range vals {
+		entries = append(entries, kv)
+	}
+	sort.Sort(entries)
+
+	for i := 0; i < 500; i++ {
+		start := mrand.Intn(len(entries))
+		end := mrand.Intn(len(entries)-start) + start
+		if start == end {
+			continue
+		}
+		firstProof, lastProof := memorydb.New(), memorydb.New()
+		if err := trie.Prove(entries[start].k, 0, firstProof); err != nil {
+			t.Fatalf("Failed to prove the first node %v", err)
+		}
+		if err := trie.Prove(entries[end-1].k, 0, lastProof); err != nil {
+			t.Fatalf("Failed to prove the last node %v", err)
+		}
+		var keys [][]byte
+		var vals [][]byte
+		for i := start; i < end; i++ {
+			keys = append(keys, entries[i].k)
+			vals = append(vals, entries[i].v)
+		}
+		testcase := mrand.Intn(6)
+		var index int
+		switch testcase {
+		case 0:
+			// Modified key
+			index = mrand.Intn(end - start)
+			keys[index] = randBytes(32) // In theory it can't be same
+		case 1:
+			// Modified val
+			index = mrand.Intn(end - start)
+			vals[index] = randBytes(20) // In theory it can't be same
+		case 2:
+			// Gapped entry slice
+			index = mrand.Intn(end - start)
+			keys = append(keys[:index], keys[index+1:]...)
+			vals = append(vals[:index], vals[index+1:]...)
+			if len(keys) <= 1 {
+				continue
+			}
+		case 3:
+			// Switched entry slice, same effect with gapped
+			index = mrand.Intn(end - start)
+			keys[index] = entries[len(entries)-1].k
+			vals[index] = entries[len(entries)-1].v
+		case 4:
+			// Set random key to nil
+			index = mrand.Intn(end - start)
+			keys[index] = nil
+		case 5:
+			// Set random value to nil
+			index = mrand.Intn(end - start)
+			vals[index] = nil
+		}
+		err := VerifyRangeProof(trie.Hash(), keys, vals, firstProof, lastProof)
+		if err == nil {
+			t.Fatalf("%d Case %d index %d range: (%d->%d) expect error, got nil", i, testcase, index, start, end-1)
+		}
+	}
+}
+
+// TestGappedRangeProof focuses on the small trie with embedded nodes.
+// If the gapped node is embedded in the trie, it should be detected too.
+func TestGappedRangeProof(t *testing.T) {
+	trie := new(Trie)
+	var entries []*kv // Sorted entries
+	for i := byte(0); i < 10; i++ {
+		value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false}
+		trie.Update(value.k, value.v)
+		entries = append(entries, value)
+	}
+	first, last := 2, 8
+	firstProof, lastProof := memorydb.New(), memorydb.New()
+	if err := trie.Prove(entries[first].k, 0, firstProof); err != nil {
+		t.Fatalf("Failed to prove the first node %v", err)
+	}
+	if err := trie.Prove(entries[last-1].k, 0, lastProof); err != nil {
+		t.Fatalf("Failed to prove the last node %v", err)
+	}
+	var keys [][]byte
+	var vals [][]byte
+	for i := first; i < last; i++ {
+		if i == (first+last)/2 {
+			continue
+		}
+		keys = append(keys, entries[i].k)
+		vals = append(vals, entries[i].v)
+	}
+	err := VerifyRangeProof(trie.Hash(), keys, vals, firstProof, lastProof)
+	if err == nil {
+		t.Fatal("expect error, got nil")
+	}
+}
+
 func TestBadProof(t *testing.T) {
 	trie, vals := randomTrie(800)
 	root := trie.Hash()
@@ -118,7 +258,7 @@ func TestBadProof(t *testing.T) {
 			mutateByte(val)
 			proof.Put(crypto.Keccak256(val), val)
 
-			if _, _, err := VerifyProof(root, kv.k, proof); err == nil {
+			if _, err := VerifyProof(root, kv.k, proof); err == nil {
 				t.Fatalf("prover %d: expected proof to fail for key %x", i, kv.k)
 			}
 		}
@@ -138,7 +278,7 @@ func TestMissingKeyProof(t *testing.T) {
 		if proof.Len() != 1 {
 			t.Errorf("test %d: proof should have one element", i)
 		}
-		val, _, err := VerifyProof(trie.Hash(), []byte(key), proof)
+		val, err := VerifyProof(trie.Hash(), []byte(key), proof)
 		if err != nil {
 			t.Fatalf("test %d: failed to verify proof: %v\nraw proof: %x", i, err, proof)
 		}
@@ -191,12 +331,50 @@ func BenchmarkVerifyProof(b *testing.B) {
 	b.ResetTimer()
 	for i := 0; i < b.N; i++ {
 		im := i % len(keys)
-		if _, _, err := VerifyProof(root, []byte(keys[im]), proofs[im]); err != nil {
+		if _, err := VerifyProof(root, []byte(keys[im]), proofs[im]); err != nil {
 			b.Fatalf("key %x: %v", keys[im], err)
 		}
 	}
 }
 
+func BenchmarkVerifyRangeProof10(b *testing.B)   { benchmarkVerifyRangeProof(b, 10) }
+func BenchmarkVerifyRangeProof100(b *testing.B)  { benchmarkVerifyRangeProof(b, 100) }
+func BenchmarkVerifyRangeProof1000(b *testing.B) { benchmarkVerifyRangeProof(b, 1000) }
+func BenchmarkVerifyRangeProof5000(b *testing.B) { benchmarkVerifyRangeProof(b, 5000) }
+
+func benchmarkVerifyRangeProof(b *testing.B, size int) {
+	trie, vals := randomTrie(8192)
+	var entries entrySlice
+	for _, kv := range vals {
+		entries = append(entries, kv)
+	}
+	sort.Sort(entries)
+
+	start := 2
+	end := start + size
+	firstProof, lastProof := memorydb.New(), memorydb.New()
+	if err := trie.Prove(entries[start].k, 0, firstProof); err != nil {
+		b.Fatalf("Failed to prove the first node %v", err)
+	}
+	if err := trie.Prove(entries[end-1].k, 0, lastProof); err != nil {
+		b.Fatalf("Failed to prove the last node %v", err)
+	}
+	var keys [][]byte
+	var values [][]byte
+	for i := start; i < end; i++ {
+		keys = append(keys, entries[i].k)
+		values = append(values, entries[i].v)
+	}
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		err := VerifyRangeProof(trie.Hash(), keys, values, firstProof, lastProof)
+		if err != nil {
+			b.Fatalf("Case %d(%d->%d) expect no error, got %v", i, start, end-1, err)
+		}
+	}
+}
+
 func randomTrie(n int) (*Trie, map[string]*kv) {
 	trie := new(Trie)
 	vals := make(map[string]*kv)
-- 
GitLab