From 3c7181d28f1f24aaea2da5cce664ffac52f369df Mon Sep 17 00:00:00 2001
From: obscuren <geffobscura@gmail.com>
Date: Mon, 2 Feb 2015 19:58:34 -0800
Subject: [PATCH] Fixed a copy issue in the trie which could cause a consensus
 failure

---
 trie/cache.go     | 10 +++++++++-
 trie/fullnode.go  |  6 +++---
 trie/hashnode.go  | 15 +++++++++------
 trie/node.go      |  8 ++++++--
 trie/shortnode.go |  8 +++++++-
 trie/trie.go      | 21 ++++++++++++++++-----
 trie/valuenode.go |  4 +++-
 7 files changed, 53 insertions(+), 19 deletions(-)

diff --git a/trie/cache.go b/trie/cache.go
index e03702b25..2143785fa 100644
--- a/trie/cache.go
+++ b/trie/cache.go
@@ -37,6 +37,14 @@ func (self *Cache) Flush() {
 	//self.Reset()
 }
 
+func (self *Cache) Copy() *Cache {
+	cache := NewCache(self.backend)
+	for k, v := range self.store {
+		cache.store[k] = v
+	}
+	return cache
+}
+
 func (self *Cache) Reset() {
-	self.store = make(map[string][]byte)
+	//self.store = make(map[string][]byte)
 }
diff --git a/trie/fullnode.go b/trie/fullnode.go
index ebbe7f384..522fdb373 100644
--- a/trie/fullnode.go
+++ b/trie/fullnode.go
@@ -20,11 +20,11 @@ func (self *FullNode) Branches() []Node {
 	return self.nodes[:16]
 }
 
-func (self *FullNode) Copy() Node {
-	nnode := NewFullNode(self.trie)
+func (self *FullNode) Copy(t *Trie) Node {
+	nnode := NewFullNode(t)
 	for i, node := range self.nodes {
 		if node != nil {
-			nnode.nodes[i] = node
+			nnode.nodes[i] = node.Copy(t)
 		}
 	}
 
diff --git a/trie/hashnode.go b/trie/hashnode.go
index 40ccd54c3..e46628368 100644
--- a/trie/hashnode.go
+++ b/trie/hashnode.go
@@ -1,11 +1,14 @@
 package trie
 
+import "github.com/ethereum/go-ethereum/ethutil"
+
 type HashNode struct {
-	key []byte
+	key  []byte
+	trie *Trie
 }
 
-func NewHash(key []byte) *HashNode {
-	return &HashNode{key}
+func NewHash(key []byte, trie *Trie) *HashNode {
+	return &HashNode{key, trie}
 }
 
 func (self *HashNode) RlpData() interface{} {
@@ -17,6 +20,6 @@ func (self *HashNode) Hash() interface{} {
 }
 
 // These methods will never be called but we have to satisfy Node interface
-func (self *HashNode) Value() Node { return nil }
-func (self *HashNode) Dirty() bool { return true }
-func (self *HashNode) Copy() Node  { return self }
+func (self *HashNode) Value() Node       { return nil }
+func (self *HashNode) Dirty() bool       { return true }
+func (self *HashNode) Copy(t *Trie) Node { return NewHash(ethutil.CopyBytes(self.key), t) }
diff --git a/trie/node.go b/trie/node.go
index f28f24771..0d8a7cff9 100644
--- a/trie/node.go
+++ b/trie/node.go
@@ -6,7 +6,7 @@ var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b
 
 type Node interface {
 	Value() Node
-	Copy() Node // All nodes, for now, return them self
+	Copy(*Trie) Node // All nodes, for now, return them self
 	Dirty() bool
 	fstring(string) string
 	Hash() interface{}
@@ -18,7 +18,11 @@ func (self *ValueNode) String() string            { return self.fstring("") }
 func (self *FullNode) String() string             { return self.fstring("") }
 func (self *ShortNode) String() string            { return self.fstring("") }
 func (self *ValueNode) fstring(ind string) string { return fmt.Sprintf("%x ", self.data) }
-func (self *HashNode) fstring(ind string) string  { return fmt.Sprintf("< %x > ", self.key) }
+
+//func (self *HashNode) fstring(ind string) string  { return fmt.Sprintf("< %x > ", self.key) }
+func (self *HashNode) fstring(ind string) string {
+	return fmt.Sprintf("%v", self.trie.trans(self))
+}
 
 // Full node
 func (self *FullNode) fstring(ind string) string {
diff --git a/trie/shortnode.go b/trie/shortnode.go
index f132b56d9..d96492958 100644
--- a/trie/shortnode.go
+++ b/trie/shortnode.go
@@ -1,5 +1,7 @@
 package trie
 
+import "github.com/ethereum/go-ethereum/ethutil"
+
 type ShortNode struct {
 	trie  *Trie
 	key   []byte
@@ -15,7 +17,11 @@ func (self *ShortNode) Value() Node {
 	return self.value
 }
 func (self *ShortNode) Dirty() bool { return true }
-func (self *ShortNode) Copy() Node  { return NewShortNode(self.trie, self.key, self.value) }
+func (self *ShortNode) Copy(t *Trie) Node {
+	node := &ShortNode{t, nil, self.value.Copy(t)}
+	node.key = ethutil.CopyBytes(self.key)
+	return node
+}
 
 func (self *ShortNode) RlpData() interface{} {
 	return []interface{}{self.key, self.value.Hash()}
diff --git a/trie/trie.go b/trie/trie.go
index 36f2af5d2..9087f7bda 100644
--- a/trie/trie.go
+++ b/trie/trie.go
@@ -34,7 +34,9 @@ func New(root []byte, backend Backend) *Trie {
 	trie := &Trie{}
 	trie.revisions = list.New()
 	trie.roothash = root
-	trie.cache = NewCache(backend)
+	if backend != nil {
+		trie.cache = NewCache(backend)
+	}
 
 	if root != nil {
 		value := ethutil.NewValueFromBytes(trie.cache.Get(root))
@@ -49,7 +51,15 @@ func (self *Trie) Iterator() *Iterator {
 }
 
 func (self *Trie) Copy() *Trie {
-	return New(self.roothash, self.cache.backend)
+	cpy := make([]byte, 32)
+	copy(cpy, self.roothash)
+	trie := New(nil, nil)
+	trie.cache = self.cache.Copy()
+	if self.root != nil {
+		trie.root = self.root.Copy(trie)
+	}
+
+	return trie
 }
 
 // Legacy support
@@ -177,7 +187,7 @@ func (self *Trie) insert(node Node, key []byte, value Node) Node {
 		return NewShortNode(self, key[:matchlength], n)
 
 	case *FullNode:
-		cpy := node.Copy().(*FullNode)
+		cpy := node.Copy(self).(*FullNode)
 		cpy.set(key[0], self.insert(node.branch(key[0]), key[1:], value))
 
 		return cpy
@@ -244,7 +254,7 @@ func (self *Trie) delete(node Node, key []byte) Node {
 		}
 
 	case *FullNode:
-		n := node.Copy().(*FullNode)
+		n := node.Copy(self).(*FullNode)
 		n.set(key[0], self.delete(n.branch(key[0]), key[1:]))
 
 		pos := -1
@@ -301,7 +311,7 @@ func (self *Trie) mknode(value *ethutil.Value) Node {
 		}
 		return fnode
 	case 32:
-		return &HashNode{value.Bytes()}
+		return &HashNode{value.Bytes(), self}
 	}
 
 	return &ValueNode{self, value.Bytes()}
@@ -331,4 +341,5 @@ func (self *Trie) store(node Node) interface{} {
 
 func (self *Trie) PrintRoot() {
 	fmt.Println(self.root)
+	fmt.Printf("root=%x\n", self.Root())
 }
diff --git a/trie/valuenode.go b/trie/valuenode.go
index 689befb2a..8912b1c82 100644
--- a/trie/valuenode.go
+++ b/trie/valuenode.go
@@ -1,5 +1,7 @@
 package trie
 
+import "github.com/ethereum/go-ethereum/ethutil"
+
 type ValueNode struct {
 	trie *Trie
 	data []byte
@@ -8,6 +10,6 @@ type ValueNode struct {
 func (self *ValueNode) Value() Node          { return self } // Best not to call :-)
 func (self *ValueNode) Val() []byte          { return self.data }
 func (self *ValueNode) Dirty() bool          { return true }
-func (self *ValueNode) Copy() Node           { return &ValueNode{self.trie, self.data} }
+func (self *ValueNode) Copy(t *Trie) Node    { return &ValueNode{t, ethutil.CopyBytes(self.data)} }
 func (self *ValueNode) RlpData() interface{} { return self.data }
 func (self *ValueNode) Hash() interface{}    { return self.data }
-- 
GitLab