good morning!!!!

Skip to content
Snippets Groups Projects
Commit d3386500 authored by Jeffrey Wilcke's avatar Jeffrey Wilcke
Browse files

compilable trie (tests fail)

parent 20b7162a
No related branches found
No related tags found
No related merge requests found
...@@ -27,27 +27,27 @@ func StringToAddress(s string) Address { return BytesToAddress([]byte(s)) } ...@@ -27,27 +27,27 @@ func StringToAddress(s string) Address { return BytesToAddress([]byte(s)) }
// Don't use the default 'String' method in case we want to overwrite // Don't use the default 'String' method in case we want to overwrite
// Get the string representation of the underlying hash // Get the string representation of the underlying hash
func (h Hash) Str() string { func (h *Hash) Str() string {
return string(h[:]) return string(h[:])
} }
// Sets the hash to the value of b. If b is larger than len(h) it will panic // Sets the hash to the value of b. If b is larger than len(h) it will panic
func (h Hash) SetBytes(b []byte) { func (h *Hash) SetBytes(b []byte) {
if len(b) > len(h) { if len(b) > len(h) {
panic("unable to set bytes. too big") panic("unable to set bytes. too big")
} }
// reverse loop // reverse loop
for i := len(b); i >= 0; i-- { for i := len(b) - 1; i >= 0; i-- {
h[i] = b[i] h[i] = b[i]
} }
} }
// Set string `s` to h. If s is larger than len(h) it will panic // Set string `s` to h. If s is larger than len(h) it will panic
func (h Hash) SetString(s string) { h.SetBytes([]byte(s)) } func (h *Hash) SetString(s string) { h.SetBytes([]byte(s)) }
// Sets h to other // Sets h to other
func (h Hash) Set(other Hash) { func (h *Hash) Set(other Hash) {
for i, v := range other { for i, v := range other {
h[i] = v h[i] = v
} }
......
...@@ -2,17 +2,19 @@ package trie ...@@ -2,17 +2,19 @@ package trie
import ( import (
"bytes" "bytes"
"github.com/ethereum/go-ethereum/common"
) )
type Iterator struct { type Iterator struct {
trie *Trie trie *Trie
Key []byte Key common.Hash
Value []byte Value []byte
} }
func NewIterator(trie *Trie) *Iterator { func NewIterator(trie *Trie) *Iterator {
return &Iterator{trie: trie, Key: nil} return &Iterator{trie: trie}
} }
func (self *Iterator) Next() bool { func (self *Iterator) Next() bool {
...@@ -20,15 +22,15 @@ func (self *Iterator) Next() bool { ...@@ -20,15 +22,15 @@ func (self *Iterator) Next() bool {
defer self.trie.mu.Unlock() defer self.trie.mu.Unlock()
isIterStart := false isIterStart := false
if self.Key == nil { if (self.Key == common.Hash{}) {
isIterStart = true isIterStart = true
self.Key = make([]byte, 32) //self.Key = make([]byte, 32)
} }
key := RemTerm(CompactHexDecode(string(self.Key))) key := RemTerm(CompactHexDecode(self.Key.Str()))
k := self.next(self.trie.root, key, isIterStart) k := self.next(self.trie.root, key, isIterStart)
self.Key = []byte(DecodeCompact(k)) self.Key = common.StringToHash(DecodeCompact(k))
return len(k) > 0 return len(k) > 0
} }
......
...@@ -22,7 +22,7 @@ func TestIterator(t *testing.T) { ...@@ -22,7 +22,7 @@ func TestIterator(t *testing.T) {
it := trie.Iterator() it := trie.Iterator()
for it.Next() { for it.Next() {
v[string(it.Key)] = true v[it.Key.Str()] = true
} }
for k, found := range v { for k, found := range v {
......
package trie package trie
import "github.com/ethereum/go-ethereum/crypto" import (
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
)
type SecureTrie struct { type SecureTrie struct {
*Trie *Trie
} }
func NewSecure(root []byte, backend Backend) *SecureTrie { func NewSecure(root common.Hash, backend Backend) *SecureTrie {
return &SecureTrie{New(root, backend)} return &SecureTrie{New(root, backend)}
} }
func (self *SecureTrie) Update(key, value []byte) Node { func (self *SecureTrie) Update(key common.Hash, value []byte) Node {
return self.Trie.Update(crypto.Sha3(key), value) return self.Trie.Update(common.BytesToHash(crypto.Sha3(key[:])), value)
} }
func (self *SecureTrie) UpdateString(key, value string) Node { func (self *SecureTrie) UpdateString(key, value string) Node {
return self.Update([]byte(key), []byte(value)) return self.Update(common.StringToHash(key), []byte(value))
} }
func (self *SecureTrie) Get(key []byte) []byte { func (self *SecureTrie) Get(key common.Hash) []byte {
return self.Trie.Get(crypto.Sha3(key)) return self.Trie.Get(common.BytesToHash(crypto.Sha3(key[:])))
} }
func (self *SecureTrie) GetString(key string) []byte { func (self *SecureTrie) GetString(key string) []byte {
return self.Get([]byte(key)) return self.Get(common.StringToHash(key))
} }
func (self *SecureTrie) Delete(key []byte) Node { func (self *SecureTrie) Delete(key common.Hash) Node {
return self.Trie.Delete(crypto.Sha3(key)) return self.Trie.Delete(common.BytesToHash(crypto.Sha3(key[:])))
} }
func (self *SecureTrie) DeleteString(key string) Node { func (self *SecureTrie) DeleteString(key string) Node {
return self.Delete([]byte(key)) return self.Delete(common.StringToHash(key))
} }
func (self *SecureTrie) Copy() *SecureTrie { func (self *SecureTrie) Copy() *SecureTrie {
......
...@@ -11,14 +11,15 @@ import ( ...@@ -11,14 +11,15 @@ import (
) )
func ParanoiaCheck(t1 *Trie, backend Backend) (bool, *Trie) { func ParanoiaCheck(t1 *Trie, backend Backend) (bool, *Trie) {
t2 := New(nil, backend) t2 := New(common.Hash{}, backend)
it := t1.Iterator() it := t1.Iterator()
for it.Next() { for it.Next() {
t2.Update(it.Key, it.Value) t2.Update(it.Key, it.Value)
} }
return bytes.Equal(t2.Hash(), t1.Hash()), t2 a, b := t2.Hash(), t1.Hash()
return bytes.Equal(a[:], b[:]), t2
} }
type Trie struct { type Trie struct {
...@@ -38,8 +39,8 @@ func New(root common.Hash, backend Backend) *Trie { ...@@ -38,8 +39,8 @@ func New(root common.Hash, backend Backend) *Trie {
trie.cache = NewCache(backend) trie.cache = NewCache(backend)
} }
if root != nil { if (root != common.Hash{}) {
value := common.NewValueFromBytes(trie.cache.Get(root)) value := common.NewValueFromBytes(trie.cache.Get(root[:]))
trie.root = trie.mknode(value) trie.root = trie.mknode(value)
} }
...@@ -51,12 +52,13 @@ func (self *Trie) Iterator() *Iterator { ...@@ -51,12 +52,13 @@ func (self *Trie) Iterator() *Iterator {
} }
func (self *Trie) Copy() *Trie { func (self *Trie) Copy() *Trie {
//cpy := make([]byte, 32)
//copy(cpy, self.roothash)
// cheap copying method // cheap copying method
var cpy common.Hash var cpy common.Hash
cpy.Set(self.roothash[:]) cpy.Set(self.roothash)
cpy := make([]byte, 32) trie := New(common.Hash{}, nil)
copy(cpy, self.roothash)
trie := New(nil, nil)
trie.cache = self.cache.Copy() trie.cache = self.cache.Copy()
if self.root != nil { if self.root != nil {
trie.root = self.root.Copy(trie) trie.root = self.root.Copy(trie)
...@@ -66,21 +68,21 @@ func (self *Trie) Copy() *Trie { ...@@ -66,21 +68,21 @@ func (self *Trie) Copy() *Trie {
} }
// Legacy support // Legacy support
func (self *Trie) Root() []byte { return self.Hash() } func (self *Trie) Root() common.Hash { return self.Hash() }
func (self *Trie) Hash() []byte { func (self *Trie) Hash() common.Hash {
var hash []byte var hash common.Hash
if self.root != nil { if self.root != nil {
t := self.root.Hash() t := self.root.Hash()
if byts, ok := t.([]byte); ok && len(byts) > 0 { if h, ok := t.(common.Hash); ok && (h != common.Hash{}) {
hash = byts hash = h
} else { } else {
hash = crypto.Sha3(common.Encode(self.root.RlpData())) hash = common.BytesToHash(crypto.Sha3(common.Encode(self.root.RlpData())))
} }
} else { } else {
hash = crypto.Sha3(common.Encode("")) hash = common.BytesToHash(crypto.Sha3(common.Encode("")))
} }
if !bytes.Equal(hash, self.roothash) { if hash != self.roothash {
self.revisions.PushBack(self.roothash) self.revisions.PushBack(self.roothash)
self.roothash = hash self.roothash = hash
} }
...@@ -105,19 +107,21 @@ func (self *Trie) Reset() { ...@@ -105,19 +107,21 @@ func (self *Trie) Reset() {
self.cache.Reset() self.cache.Reset()
if self.revisions.Len() > 0 { if self.revisions.Len() > 0 {
revision := self.revisions.Remove(self.revisions.Back()).([]byte) revision := self.revisions.Remove(self.revisions.Back()).(common.Hash)
self.roothash = revision self.roothash = revision
} }
value := common.NewValueFromBytes(self.cache.Get(self.roothash)) value := common.NewValueFromBytes(self.cache.Get(self.roothash[:]))
self.root = self.mknode(value) self.root = self.mknode(value)
} }
func (self *Trie) UpdateString(key, value string) Node { return self.Update([]byte(key), []byte(value)) } func (self *Trie) UpdateString(key, value string) Node {
func (self *Trie) Update(key, value []byte) Node { return self.Update(common.StringToHash(key), []byte(value))
}
func (self *Trie) Update(key common.Hash, value []byte) Node {
self.mu.Lock() self.mu.Lock()
defer self.mu.Unlock() defer self.mu.Unlock()
k := CompactHexDecode(string(key)) k := CompactHexDecode(key.Str())
if len(value) != 0 { if len(value) != 0 {
self.root = self.insert(self.root, k, &ValueNode{self, value}) self.root = self.insert(self.root, k, &ValueNode{self, value})
...@@ -128,12 +132,12 @@ func (self *Trie) Update(key, value []byte) Node { ...@@ -128,12 +132,12 @@ func (self *Trie) Update(key, value []byte) Node {
return self.root return self.root
} }
func (self *Trie) GetString(key string) []byte { return self.Get([]byte(key)) } func (self *Trie) GetString(key string) []byte { return self.Get(common.StringToHash(key)) }
func (self *Trie) Get(key []byte) []byte { func (self *Trie) Get(key common.Hash) []byte {
self.mu.Lock() self.mu.Lock()
defer self.mu.Unlock() defer self.mu.Unlock()
k := CompactHexDecode(string(key)) k := CompactHexDecode(key.Str())
n := self.get(self.root, k) n := self.get(self.root, k)
if n != nil { if n != nil {
...@@ -143,12 +147,12 @@ func (self *Trie) Get(key []byte) []byte { ...@@ -143,12 +147,12 @@ func (self *Trie) Get(key []byte) []byte {
return nil return nil
} }
func (self *Trie) DeleteString(key string) Node { return self.Delete([]byte(key)) } func (self *Trie) DeleteString(key string) Node { return self.Delete(common.StringToHash(key)) }
func (self *Trie) Delete(key []byte) Node { func (self *Trie) Delete(key common.Hash) Node {
self.mu.Lock() self.mu.Lock()
defer self.mu.Unlock() defer self.mu.Unlock()
k := CompactHexDecode(string(key)) k := CompactHexDecode(key.Str())
self.root = self.delete(self.root, k) self.root = self.delete(self.root, k)
return self.root return self.root
......
...@@ -5,8 +5,8 @@ import ( ...@@ -5,8 +5,8 @@ import (
"fmt" "fmt"
"testing" "testing"
"github.com/ethereum/go-ethereum/crypto"
"github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/crypto"
) )
type Db map[string][]byte type Db map[string][]byte
...@@ -16,18 +16,18 @@ func (self Db) Put(k, v []byte) { self[string(k)] = v } ...@@ -16,18 +16,18 @@ func (self Db) Put(k, v []byte) { self[string(k)] = v }
// Used for testing // Used for testing
func NewEmpty() *Trie { func NewEmpty() *Trie {
return New(nil, make(Db)) return New(common.Hash{}, make(Db))
} }
func NewEmptySecure() *SecureTrie { func NewEmptySecure() *SecureTrie {
return NewSecure(nil, make(Db)) return NewSecure(common.Hash{}, make(Db))
} }
func TestEmptyTrie(t *testing.T) { func TestEmptyTrie(t *testing.T) {
trie := NewEmpty() trie := NewEmpty()
res := trie.Hash() res := trie.Hash()
exp := crypto.Sha3(common.Encode("")) exp := crypto.Sha3(common.Encode(""))
if !bytes.Equal(res, exp) { if !bytes.Equal(res[:], exp[:]) {
t.Errorf("expected %x got %x", exp, res) t.Errorf("expected %x got %x", exp, res)
} }
} }
...@@ -41,7 +41,7 @@ func TestInsert(t *testing.T) { ...@@ -41,7 +41,7 @@ func TestInsert(t *testing.T) {
exp := common.Hex2Bytes("8aad789dff2f538bca5d8ea56e8abe10f4c7ba3a5dea95fea4cd6e7c3a1168d3") exp := common.Hex2Bytes("8aad789dff2f538bca5d8ea56e8abe10f4c7ba3a5dea95fea4cd6e7c3a1168d3")
root := trie.Hash() root := trie.Hash()
if !bytes.Equal(root, exp) { if !bytes.Equal(root[:], exp[:]) {
t.Errorf("exp %x got %x", exp, root) t.Errorf("exp %x got %x", exp, root)
} }
...@@ -50,7 +50,7 @@ func TestInsert(t *testing.T) { ...@@ -50,7 +50,7 @@ func TestInsert(t *testing.T) {
exp = common.Hex2Bytes("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab") exp = common.Hex2Bytes("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab")
root = trie.Hash() root = trie.Hash()
if !bytes.Equal(root, exp) { if !bytes.Equal(root[:], exp) {
t.Errorf("exp %x got %x", exp, root) t.Errorf("exp %x got %x", exp, root)
} }
} }
...@@ -96,7 +96,7 @@ func TestDelete(t *testing.T) { ...@@ -96,7 +96,7 @@ func TestDelete(t *testing.T) {
hash := trie.Hash() hash := trie.Hash()
exp := common.Hex2Bytes("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84") exp := common.Hex2Bytes("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84")
if !bytes.Equal(hash, exp) { if !bytes.Equal(hash[:], exp) {
t.Errorf("expected %x got %x", exp, hash) t.Errorf("expected %x got %x", exp, hash)
} }
} }
...@@ -120,7 +120,7 @@ func TestEmptyValues(t *testing.T) { ...@@ -120,7 +120,7 @@ func TestEmptyValues(t *testing.T) {
hash := trie.Hash() hash := trie.Hash()
exp := common.Hex2Bytes("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84") exp := common.Hex2Bytes("5991bb8c6514148a29db676a14ac506cd2cd5775ace63c30a4fe457715e9ac84")
if !bytes.Equal(hash, exp) { if !bytes.Equal(hash[:], exp) {
t.Errorf("expected %x got %x", exp, hash) t.Errorf("expected %x got %x", exp, hash)
} }
} }
...@@ -150,7 +150,7 @@ func TestReplication(t *testing.T) { ...@@ -150,7 +150,7 @@ func TestReplication(t *testing.T) {
hash := trie2.Hash() hash := trie2.Hash()
exp := trie.Hash() exp := trie.Hash()
if !bytes.Equal(hash, exp) { if !bytes.Equal(hash[:], exp[:]) {
t.Errorf("root failure. expected %x got %x", exp, hash) t.Errorf("root failure. expected %x got %x", exp, hash)
} }
...@@ -168,7 +168,9 @@ func TestReset(t *testing.T) { ...@@ -168,7 +168,9 @@ func TestReset(t *testing.T) {
} }
trie.Commit() trie.Commit()
before := common.CopyBytes(trie.roothash) var before common.Hash
before.Set(trie.roothash)
trie.UpdateString("should", "revert") trie.UpdateString("should", "revert")
trie.Hash() trie.Hash()
// Should have no effect // Should have no effect
...@@ -177,9 +179,11 @@ func TestReset(t *testing.T) { ...@@ -177,9 +179,11 @@ func TestReset(t *testing.T) {
// ### // ###
trie.Reset() trie.Reset()
after := common.CopyBytes(trie.roothash)
if !bytes.Equal(before, after) { var after common.Hash
after.Set(trie.roothash)
if before != after {
t.Errorf("expected roots to be equal. %x - %x", before, after) t.Errorf("expected roots to be equal. %x - %x", before, after)
} }
} }
...@@ -248,7 +252,7 @@ func BenchmarkGets(b *testing.B) { ...@@ -248,7 +252,7 @@ func BenchmarkGets(b *testing.B) {
b.ResetTimer() b.ResetTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
trie.Get([]byte("horse")) trie.GetString("horse")
} }
} }
...@@ -263,7 +267,8 @@ func BenchmarkUpdate(b *testing.B) { ...@@ -263,7 +267,8 @@ func BenchmarkUpdate(b *testing.B) {
} }
type kv struct { type kv struct {
k, v []byte k common.Hash
v []byte
t bool t bool
} }
...@@ -272,17 +277,21 @@ func TestLargeData(t *testing.T) { ...@@ -272,17 +277,21 @@ func TestLargeData(t *testing.T) {
vals := make(map[string]*kv) vals := make(map[string]*kv)
for i := byte(0); i < 255; i++ { for i := byte(0); i < 255; i++ {
value := &kv{common.LeftPadBytes([]byte{i}, 32), []byte{i}, false} var k1 common.Hash
value2 := &kv{common.LeftPadBytes([]byte{10, i}, 32), []byte{i}, false} k1.SetBytes([]byte{i})
var k2 common.Hash
k2.SetBytes([]byte{10, i})
value := &kv{k1, []byte{i}, false}
value2 := &kv{k2, []byte{i}, false}
trie.Update(value.k, value.v) trie.Update(value.k, value.v)
trie.Update(value2.k, value2.v) trie.Update(value2.k, value2.v)
vals[string(value.k)] = value vals[value.k.Str()] = value
vals[string(value2.k)] = value2 vals[value2.k.Str()] = value2
} }
it := trie.Iterator() it := trie.Iterator()
for it.Next() { for it.Next() {
vals[string(it.Key)].t = true vals[it.Key.Str()].t = true
} }
var untouched []*kv var untouched []*kv
...@@ -323,7 +332,7 @@ func TestSecureDelete(t *testing.T) { ...@@ -323,7 +332,7 @@ func TestSecureDelete(t *testing.T) {
hash := trie.Hash() hash := trie.Hash()
exp := common.Hex2Bytes("29b235a58c3c25ab83010c327d5932bcf05324b7d6b1185e650798034783ca9d") exp := common.Hex2Bytes("29b235a58c3c25ab83010c327d5932bcf05324b7d6b1185e650798034783ca9d")
if !bytes.Equal(hash, exp) { if !bytes.Equal(hash[:], exp) {
t.Errorf("expected %x got %x", exp, hash) t.Errorf("expected %x got %x", exp, hash)
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment