From c1a352c1085baa5c5f7650d331603bbb5532dea4 Mon Sep 17 00:00:00 2001
From: Felix Lange <fjl@twurst.com>
Date: Wed, 9 Sep 2015 03:35:41 +0200
Subject: [PATCH] trie: add merkle proof functions

---
 trie/proof.go      | 122 +++++++++++++++++++++++++++++++++++++++
 trie/proof_test.go | 139 +++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 261 insertions(+)
 create mode 100644 trie/proof.go
 create mode 100644 trie/proof_test.go

diff --git a/trie/proof.go b/trie/proof.go
new file mode 100644
index 000000000..a705c49db
--- /dev/null
+++ b/trie/proof.go
@@ -0,0 +1,122 @@
+package trie
+
+import (
+	"bytes"
+	"errors"
+	"fmt"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/crypto/sha3"
+	"github.com/ethereum/go-ethereum/rlp"
+)
+
+// Prove constructs a merkle proof for key. The result contains all
+// encoded nodes on the path to the value at key. The value itself is
+// also included in the last node and can be retrieved by verifying
+// the proof.
+//
+// The returned proof is nil if the trie does not contain a value for key.
+// For existing keys, the proof will have at least one element.
+func (t *Trie) Prove(key []byte) []rlp.RawValue {
+	// Collect all nodes on the path to key.
+	key = compactHexDecode(key)
+	nodes := []node{}
+	tn := t.root
+	for len(key) > 0 {
+		switch n := tn.(type) {
+		case shortNode:
+			if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) {
+				// The trie doesn't contain the key.
+				return nil
+			}
+			tn, key = n.Val, key[len(n.Key):]
+			nodes = append(nodes, n)
+		case fullNode:
+			tn, key = n[key[0]], key[1:]
+			nodes = append(nodes, n)
+		case nil:
+			return nil
+		case hashNode:
+			tn = t.resolveHash(n)
+		default:
+			panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
+		}
+	}
+	if _, ok := tn.(valueNode); !ok {
+		// The key was consumed without reaching a value, so it is absent.
+		return nil
+	}
+	if t.hasher == nil {
+		t.hasher = newHasher()
+	}
+	proof := make([]rlp.RawValue, 0, len(nodes))
+	for i, n := range nodes {
+		// Errors are not checked: the hasher panics if encoding fails and no database is written.
+		n, _ = t.hasher.replaceChildren(n, nil)
+		hn, _ := t.hasher.store(n, nil, false)
+		if _, ok := hn.(hashNode); ok || i == 0 {
+			// A node stored as a hash (or the root node) becomes a proof element.
+			enc, _ := rlp.EncodeToBytes(n)
+			proof = append(proof, enc)
+		}
+	}
+	return proof
+}
+
+// VerifyProof validates a merkle proof against rootHash. Every proof
+// element must hash to the reference held by its predecessor (the first
+// one to rootHash) and the last element must hold the value for key,
+// which is returned. Any other shape of proof yields an error.
+func VerifyProof(rootHash common.Hash, key []byte, proof []rlp.RawValue) (value []byte, err error) {
+	key = compactHexDecode(key)
+	hasher := sha3.NewKeccak256()
+	expected := rootHash.Bytes()
+	for i, enc := range proof {
+		// The encoded node has to match the hash its parent committed to.
+		hasher.Reset()
+		hasher.Write(enc)
+		if sum := hasher.Sum(nil); !bytes.Equal(sum, expected) {
+			return nil, fmt.Errorf("bad proof node %d: hash mismatch", i)
+		}
+		decoded, err := decodeNode(enc)
+		if err != nil {
+			return nil, fmt.Errorf("bad proof node %d: %v", i, err)
+		}
+		rest, child := get(decoded, key)
+		switch child := child.(type) {
+		case nil:
+			return nil, fmt.Errorf("key mismatch at proof node %d", i)
+		case hashNode:
+			key, expected = rest, child
+		case valueNode:
+			if i == len(proof)-1 {
+				return child, nil
+			}
+			return nil, errors.New("additional nodes at end of proof")
+		}
+	}
+	return nil, errors.New("unexpected end of proof")
+}
+
+// get walks tn along key, returning the remaining key and the hashNode, value or nil it ends on.
+func get(tn node, key []byte) ([]byte, node) {
+	for len(key) > 0 {
+		switch n := tn.(type) {
+		case shortNode:
+			if len(key) < len(n.Key) || !bytes.Equal(n.Key, key[:len(n.Key)]) {
+				return nil, nil
+			}
+			tn, key = n.Val, key[len(n.Key):]
+		case fullNode:
+			tn, key = n[key[0]], key[1:]
+		case hashNode, nil:
+			return key, tn
+		default:
+			panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
+		}
+	}
+	if v, ok := tn.(valueNode); ok {
+		return nil, v
+	}
+	return nil, nil // key ended on something other than a value: report a miss, don't panic
+}
diff --git a/trie/proof_test.go b/trie/proof_test.go
new file mode 100644
index 000000000..6b5bef05c
--- /dev/null
+++ b/trie/proof_test.go
@@ -0,0 +1,139 @@
+package trie
+
+import (
+	"bytes"
+	crand "crypto/rand"
+	mrand "math/rand"
+	"testing"
+	"time"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/rlp"
+)
+
+func init() {
+	mrand.Seed(time.Now().Unix()) // seed math/rand from the current Unix time (seconds)
+}
+
+// TestProof proves and verifies every key of a random trie and checks the recovered values.
+func TestProof(t *testing.T) {
+	tr, entries := randomTrie(500)
+	rootHash := tr.Hash()
+	for _, entry := range entries {
+		nodes := tr.Prove(entry.k)
+		if nodes == nil {
+			t.Fatalf("missing key %x while constructing proof", entry.k)
+		}
+		switch got, err := VerifyProof(rootHash, entry.k, nodes); {
+		case err != nil:
+			t.Fatalf("VerifyProof error for key %x: %v\nraw proof: %x", entry.k, err, nodes)
+		case !bytes.Equal(got, entry.v):
+			t.Fatalf("VerifyProof returned wrong value for key %x: got %x, want %x", entry.k, got, entry.v)
+		}
+	}
+}
+
+// TestOneElementProof checks the proof for the only key of a single-entry trie.
+func TestOneElementProof(t *testing.T) {
+	trie := new(Trie)
+	updateString(trie, "k", "v")
+	proof := trie.Prove([]byte("k"))
+	if proof == nil {
+		t.Fatal("nil proof")
+	}
+	if len(proof) != 1 {
+		t.Error("proof should have one element")
+	}
+	switch val, err := VerifyProof(trie.Hash(), []byte("k"), proof); {
+	case err != nil:
+		t.Fatalf("VerifyProof error: %v\nraw proof: %x", err, proof)
+	case !bytes.Equal(val, []byte("v")):
+		t.Fatalf("VerifyProof returned wrong value: got %x, want 'v'", val)
+	}
+}
+
+func TestVerifyBadProof(t *testing.T) {
+	tr, entries := randomTrie(800)
+	rootHash := tr.Hash()
+	for _, entry := range entries {
+		nodes := tr.Prove(entry.k)
+		if nodes == nil {
+			t.Fatal("nil proof")
+		}
+		mutateByte(nodes[mrand.Intn(len(nodes))]) // corrupt one byte of one randomly chosen proof element
+		if _, err := VerifyProof(rootHash, entry.k, nodes); err == nil {
+			t.Fatalf("expected proof to fail for key %x", entry.k)
+		}
+	}
+}
+
+// mutateByte changes one byte in b. The position is picked at random and
+// the byte there is XORed with a random non-zero value, which guarantees
+// that it differs from the original and lets it take any of the other 255
+// values. b must not be empty.
+func mutateByte(b []byte) {
+	pos := mrand.Intn(len(b))
+	// A mask in [1, 255] always flips at least one bit, so no retry loop
+	// is needed.
+	b[pos] ^= byte(1 + mrand.Intn(255))
+}
+
+// BenchmarkProve measures Prove over the keys of a random trie, cycling through them.
+func BenchmarkProve(b *testing.B) {
+	tr, entries := randomTrie(100)
+	keys := make([]string, 0, len(entries))
+	for k := range entries {
+		keys = append(keys, k)
+	}
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		entry := entries[keys[i%len(keys)]]
+		if tr.Prove(entry.k) == nil {
+			b.Fatalf("nil proof for %x", entry.k)
+		}
+	}
+}
+
+// BenchmarkVerifyProof measures VerifyProof on proofs built before the timer starts.
+func BenchmarkVerifyProof(b *testing.B) {
+	trie, vals := randomTrie(100)
+	root := trie.Hash()
+	var keys []string
+	var proofs [][]rlp.RawValue
+	for k := range vals {
+		keys = append(keys, k)
+		proofs = append(proofs, trie.Prove([]byte(k)))
+	}
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		im := i % len(keys)
+		if _, err := VerifyProof(root, []byte(keys[im]), proofs[im]); err != nil {
+			b.Fatalf("key %x: %v", keys[im], err)
+		}
+	}
+}
+
+// randomTrie builds a trie from the fixed keys written by the first loop plus
+// n random ones; the map holds the last entry written for each key.
+func randomTrie(n int) (*Trie, map[string]*kv) {
+	tr := new(Trie)
+	entries := make(map[string]*kv)
+	add := func(e *kv) {
+		tr.Update(e.k, e.v)
+		entries[string(e.k)] = e
+	}
+	for i := byte(0); i < 100; i++ {
+		add(&kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false})
+		add(&kv{common.LeftPadBytes([]byte{i + 10}, 32), []byte{i}, false})
+	}
+	for i := 0; i < n; i++ {
+		add(&kv{randBytes(32), randBytes(20), false})
+	}
+	return tr, entries
+}
+
+func randBytes(n int) []byte {
+	r := make([]byte, n)
+	crand.Read(r) // fills r from crypto/rand; the error return is not checked
+	return r
+}
-- 
GitLab