diff --git a/core/state/database.go b/core/state/database.go index 8665894540ea709d5f154b48cb2233146fcac562..78146d2e9505425661581f0019eae89863611802 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -719,7 +719,7 @@ func (tds *TrieDbState) UnwindTo(blockNr uint64) error { b.storageUpdates[addrHashWithVersion] = m } if len(value) > 0 { - m[keyHash] = AddExtraRLPLevel(value) + m[keyHash] = value if err := tds.db.Put(dbutils.StorageBucket, key[:common.HashLength+IncarnationLength+common.HashLength], value); err != nil { return err } @@ -867,12 +867,7 @@ func (tds *TrieDbState) ReadAccountStorage(address common.Address, incarnation u } enc, ok := tds.t.Get(dbutils.GenerateCompositeTrieKey(addrHash, seckey)) - if ok { - // Unwrap one RLP level - if len(enc) > 1 { - enc = enc[1:] - } - } else { + if !ok { // Not present in the trie, try database if tds.historical { enc, err = tds.db.GetAsOf(dbutils.StorageBucket, dbutils.StorageHistoryBucket, dbutils.GenerateCompositeStorageKey(addrHash, incarnation, seckey), tds.blockNr) @@ -1158,8 +1153,7 @@ func (tsw *TrieStateWriter) WriteAccountStorage(_ context.Context, address commo return err } if len(v) > 0 { - // Write into 1 extra RLP level - m[seckey] = AddExtraRLPLevel(v) + m[seckey] = v } else { m[seckey] = nil } diff --git a/core/state/helper.go b/core/state/helper.go index dc4e4e1d1c8f09436b2b40545ccfa85140aec0bf..7bf2df5b48611948695d1a1fceca0602fb1860a2 100644 --- a/core/state/helper.go +++ b/core/state/helper.go @@ -1,16 +1 @@ package state - -//Write into 1 extra RLP level -func AddExtraRLPLevel(v []byte) []byte { - var vv []byte - - if len(v) > 1 || v[0] >= 128 { - vv = make([]byte, len(v)+1) - vv[0] = byte(128 + len(v)) - copy(vv[1:], v) - } else { - vv = make([]byte, 1) - vv[0] = v[0] - } - return vv -} diff --git a/core/state/repair.go b/core/state/repair.go index f013d34c0e6b9f4fc19f4e3eea9edf9a53e4d971..7c1f9665805f948257bb58af19ce63efe416c1dd 100644 --- a/core/state/repair.go +++ 
b/core/state/repair.go @@ -370,7 +370,7 @@ func (rds *RepairDbState) WriteAccountStorage(address common.Address, incarnatio rds.storageUpdates[address] = m } if len(v) > 0 { - m[seckey] = AddExtraRLPLevel(v) + m[seckey] = v } else { m[seckey] = nil } diff --git a/core/state/stateless.go b/core/state/stateless.go index 2f12352d321ae0184eb0b095d6ff8273d2bf2da8..5283850c86225d1a499647b27af3c0a0e27fe033 100644 --- a/core/state/stateless.go +++ b/core/state/stateless.go @@ -197,8 +197,7 @@ func (s *Stateless) WriteAccountStorage(_ context.Context, address common.Addres return err } if len(v) > 0 { - // Write into 1 extra RLP level - m[seckey] = AddExtraRLPLevel(v) + m[seckey] = v } else { m[seckey] = nil } diff --git a/core/types/derive_sha.go b/core/types/derive_sha.go index 74c8d8d78c06b40dae3b9440c147beebc73d09b1..762d18e21b436c66bc9b1ecec7b3812a050f2761 100644 --- a/core/types/derive_sha.go +++ b/core/types/derive_sha.go @@ -40,7 +40,7 @@ func DeriveSha(list DerivableList) common.Hash { hb := trie.NewHashBuilder() hb.SetKeyTape(curr) - hb.SetValueTape(value) + hb.SetValueTape(trie.NewRlpEncodedBytesTape(value)) hb.Reset() prev.Reset() diff --git a/core/types/derive_sha_test.go b/core/types/derive_sha_test.go index 47e1122a8b17a4570da34113a83f6e864d1db8b7..f4474edfcff0f1559a75e1a4d241e7641fa2c796 100644 --- a/core/types/derive_sha_test.go +++ b/core/types/derive_sha_test.go @@ -75,7 +75,7 @@ func hashesEqual(h1, h2 common.Hash) bool { func legacyDeriveSha(list DerivableList) common.Hash { keybuf := new(bytes.Buffer) - trie := trie.New(common.Hash{}) + trie := trie.NewTestRLPTrie(common.Hash{}) for i := 0; i < list.Len(); i++ { keybuf.Reset() _ = rlp.Encode(keybuf, uint(i)) diff --git a/trie/account_node_test.go b/trie/account_node_test.go index 7e6c36b4d926007e0ec4efdc0db57cd9e68dc844..4fd84bea72eb550c9b78094523b3b52f30aeb5ec 100644 --- a/trie/account_node_test.go +++ b/trie/account_node_test.go @@ -2,14 +2,15 @@ package trie import ( "crypto/ecdsa" + "math/big" + 
"reflect" + "testing" + "github.com/ledgerwatch/turbo-geth/common" "github.com/ledgerwatch/turbo-geth/common/dbutils" "github.com/ledgerwatch/turbo-geth/core/types/accounts" "github.com/ledgerwatch/turbo-geth/crypto" "golang.org/x/crypto/sha3" - "math/big" - "reflect" - "testing" ) func TestGetAccount(t *testing.T) { @@ -119,7 +120,7 @@ func TestHash(t *testing.T) { } trie := New(common.Hash{}) - trie2 := New(common.Hash{}) + trie2 := NewTestRLPTrie(common.Hash{}) trie.UpdateAccount(addr1.Bytes(), acc1) trie.UpdateAccount(addr2.Bytes(), acc2) diff --git a/trie/bytes_rlp.go b/trie/bytes_rlp.go new file mode 100644 index 0000000000000000000000000000000000000000..7cf8b1458ee566226e684bd10897bff6b5644014 --- /dev/null +++ b/trie/bytes_rlp.go @@ -0,0 +1,67 @@ +package trie + +import ( + "io" +) + +type RlpSerializableBytes []byte + +func (b RlpSerializableBytes) ToDoubleRLP(w io.Writer) error { + return encodeBytesAsRlpToWriter(b, w, generateByteArrayLenDouble, 8) +} + +func (b RlpSerializableBytes) RawBytes() []byte { + return b +} + +func (b RlpSerializableBytes) DoubleRLPLen() int { + if len(b) < 1 { + return 0 + } + return generateRlpPrefixLenDouble(len(b), b[0]) + len(b) +} + +type RlpEncodedBytes []byte + +func (b RlpEncodedBytes) ToDoubleRLP(w io.Writer) error { + return encodeBytesAsRlpToWriter(b, w, generateByteArrayLen, 4) +} + +func (b RlpEncodedBytes) RawBytes() []byte { + return b +} + +func (b RlpEncodedBytes) DoubleRLPLen() int { + return generateRlpPrefixLen(len(b)) + len(b) +} + +func encodeBytesAsRlpToWriter(source []byte, w io.Writer, prefixGenFunc func([]byte, int, int) int, prefixBufferSize uint) error { + // > 1 byte, write a prefix or prefixes first + if len(source) > 1 || (len(source) == 1 && source[0] >= 0x80) { + prefix := make([]byte, prefixBufferSize) + prefixLen := prefixGenFunc(prefix, 0, len(source)) + + if _, err := w.Write(prefix[:prefixLen]); err != nil { + return err + } + } + + _, err := w.Write(source) + return err +} + +type 
ByteArrayWriter struct { + dest []byte + pos int +} + +func (w *ByteArrayWriter) Setup(dest []byte, pos int) { + w.dest = dest + w.pos = pos +} + +func (w *ByteArrayWriter) Write(data []byte) (int, error) { + copy(w.dest[w.pos:], data) + w.pos += len(data) + return len(data), nil +} diff --git a/trie/bytes_rlp_test.go b/trie/bytes_rlp_test.go new file mode 100644 index 0000000000000000000000000000000000000000..3954634ee888576dd7eb0fcc03b311b484a3a07f --- /dev/null +++ b/trie/bytes_rlp_test.go @@ -0,0 +1,64 @@ +package trie + +import ( + "bytes" + "testing" + + "github.com/ledgerwatch/turbo-geth/rlp" +) + +func TestFastDoubleRlpForByteArrays(t *testing.T) { + for i := 0; i < 256; i++ { + doTestWithByte(t, byte(i), 1) + } + doTestWithByte(t, 0x0, 100000) + doTestWithByte(t, 0xC, 100000) + doTestWithByte(t, 0xAB, 100000) +} + +func doTestWithByte(t *testing.T, b byte, iterations int) { + buffer := new(bytes.Buffer) + + for i := 0; i < iterations; i++ { + buffer.WriteByte(b) + source := buffer.Bytes() + + encSingle, _ := rlp.EncodeToBytes(source) + + encDouble, _ := rlp.EncodeToBytes(encSingle) + + if RlpSerializableBytes(source).DoubleRLPLen() != len(encDouble) { + t.Errorf("source [%2x * %d] wrong RlpSerializableBytes#DoubleRLPLen prediction: %d (expected %d)", source[0], len(source), RlpSerializableBytes(source).DoubleRLPLen(), len(encDouble)) + } + + if RlpEncodedBytes(encSingle).DoubleRLPLen() != len(encDouble) { + t.Errorf("source [%2x * %d] wrong RlpEncodedBytes#DoubleRLPLen prediction: %d (expected %d)", source[0], len(source), RlpEncodedBytes(encSingle).DoubleRLPLen(), len(encDouble)) + } + + buffDouble := new(bytes.Buffer) + if err := RlpSerializableBytes(source).ToDoubleRLP(buffDouble); err != nil { + t.Errorf("test failed, err = %v", err) + } + + buffSingle := new(bytes.Buffer) + if err := RlpEncodedBytes(encSingle).ToDoubleRLP(buffSingle); err != nil { + t.Errorf("test failed, err = %v", err) + } + + if !bytes.Equal(buffDouble.Bytes(), encDouble) { + 
t.Errorf("source [%2x * %d] wrong RlpSerializableBytes#ToDoubleRLP prediction: %x (expected %x)", source[0], len(source), displayOf(buffDouble.Bytes()), displayOf(encDouble)) + } + + if !bytes.Equal(buffSingle.Bytes(), encDouble) { + t.Errorf("source [%2x * %d] wrong RlpEncodedBytes#ToDoubleRLP prediction: %x (expected %x)", source[0], len(source), displayOf(buffSingle.Bytes()), displayOf(encDouble)) + } + } +} + +func displayOf(bytes []byte) []byte { + if len(bytes) < 20 { + return bytes + } + + return bytes[:20] +} diff --git a/trie/hasher.go b/trie/hasher.go index 453cd390342b3f80fc2d65fad6a44f5270903b6b..8ba5fd6094c50800220bdfa28c3e4100fc42059b 100644 --- a/trie/hasher.go +++ b/trie/hasher.go @@ -27,9 +27,9 @@ import ( ) type hasher struct { - sha keccakState - encodeToBytes bool - buffers [1024 * 1024]byte + sha keccakState + valueNodesRlpEncoded bool + buffers [1024 * 1024]byte } // keccakState wraps sha3.state. In addition to the usual hash methods, it also supports @@ -43,14 +43,14 @@ type keccakState interface { // hashers live in a global db. 
var hasherPool = make(chan *hasher, 128) -func newHasher(encodeToBytes bool) *hasher { +func newHasher(valueNodesRlpEncoded bool) *hasher { var h *hasher select { case h = <-hasherPool: default: h = &hasher{sha: sha3.NewLegacyKeccak256().(keccakState)} } - h.encodeToBytes = encodeToBytes + h.valueNodesRlpEncoded = valueNodesRlpEncoded return h } @@ -155,6 +155,50 @@ func generateByteArrayLen(buffer []byte, pos int, l int) int { return pos } +func generateRlpPrefixLen(l int) int { + if l < 2 { + return 0 + } + if l < 56 { + return 1 + } + if l < 256 { + return 2 + } + if l < 65536 { + return 3 + } + return 4 +} + +func generateRlpPrefixLenDouble(l int, firstByte byte) int { + if l < 2 { + if firstByte >= 0x80 { + return 2 + } + return 0 + } + if l < 55 { + return 2 + } + if l < 56 { // 2 + 1 + return 3 + } + if l < 254 { + return 4 + } + if l < 256 { + return 5 + } + if l < 65533 { + return 6 + } + if l < 65536 { + return 7 + } + return 8 +} + func generateByteArrayLenDouble(buffer []byte, pos int, l int) int { if l < 55 { // After first wrapping, the length will be l + 1 < 56 @@ -162,6 +206,13 @@ func generateByteArrayLenDouble(buffer []byte, pos int, l int) int { pos++ buffer[pos] = byte(128 + l) pos++ + } else if l < 56 { + buffer[pos] = byte(183 + 1) + pos++ + buffer[pos] = byte(l + 1) + pos++ + buffer[pos] = byte(128 + l) + pos++ } else if l < 254 { // After first wrapping, the length will be l + 2 < 256 buffer[pos] = byte(183 + 1) @@ -184,7 +235,7 @@ func generateByteArrayLenDouble(buffer []byte, pos int, l int) int { pos++ buffer[pos] = byte(l) pos++ - } else if l < 65534 { + } else if l < 65533 { // Both wrappings are 3 bytes buffer[pos] = byte(183 + 2) pos++ @@ -243,6 +294,7 @@ func generateByteArrayLenDouble(buffer []byte, pos int, l int) int { func (h *hasher) hashChildren(original node, bufOffset int) []byte { buffer := h.buffers[bufOffset:] pos := 4 + switch n := original.(type) { case *shortNode: // Starting at position 3, to leave space for len prefix 
@@ -263,11 +315,11 @@ func (h *hasher) hashChildren(original node, bufOffset int) []byte { buffer[pos] = vn[0] pos++ } else { - if h.encodeToBytes { - // Wrapping into another byte array - pos = generateByteArrayLenDouble(buffer, pos, len(vn)) - } else { + if h.valueNodesRlpEncoded { pos = generateByteArrayLen(buffer, pos, len(vn)) + } else { + // value node contains raw values + pos = generateByteArrayLenDouble(buffer, pos, len(vn)) } copy(buffer[pos:], vn) pos += len(vn) @@ -370,6 +422,8 @@ func (h *hasher) hashChildren(original node, bufOffset int) []byte { } } var enc []byte + needsDoubleRlpEncoding := false + switch n := n.Children[16].(type) { case *accountNode: encodedAccount := pool.GetBuffer(n.EncodingLengthForHashing()) @@ -377,9 +431,10 @@ func (h *hasher) hashChildren(original node, bufOffset int) []byte { enc = encodedAccount.Bytes() pool.PutBuffer(encodedAccount) case valueNode: + needsDoubleRlpEncoding = !h.valueNodesRlpEncoded enc = n case nil: - // skip + // skip default: // skip } @@ -391,7 +446,11 @@ func (h *hasher) hashChildren(original node, bufOffset int) []byte { buffer[pos] = enc[0] pos++ } else { - pos = generateByteArrayLen(buffer, pos, len(enc)) + if needsDoubleRlpEncoding { + pos = generateByteArrayLenDouble(buffer, pos, len(enc)) + } else { + pos = generateByteArrayLen(buffer, pos, len(enc)) + } copy(buffer[pos:], enc) pos += len(enc) } @@ -402,9 +461,10 @@ func (h *hasher) hashChildren(original node, bufOffset int) []byte { buffer[pos] = n[0] pos++ } else { - if h.encodeToBytes { - // Wrapping into another byte array + if h.valueNodesRlpEncoded { pos = generateByteArrayLen(buffer, pos, len(n)) + } else { + pos = generateByteArrayLenDouble(buffer, pos, len(n)) } copy(buffer[pos:], n) pos += len(n) @@ -420,10 +480,6 @@ func (h *hasher) hashChildren(original node, bufOffset int) []byte { buffer[pos] = enc[0] pos++ } else { - if h.encodeToBytes { - // Wrapping into another byte array - pos = generateByteArrayLen(buffer, pos, len(enc)) - } 
copy(buffer[pos:], enc) pos += len(enc) } diff --git a/trie/proof_generator.go b/trie/proof_generator.go index 5fe4673ecf0e1e3aba70adbe01230930767c9fa5..479384607f236341638b67986e817b14a8a1ed63 100644 --- a/trie/proof_generator.go +++ b/trie/proof_generator.go @@ -589,7 +589,10 @@ func BlockWitnessToTrie(bw []byte, trace bool) (*Trie, map[common.Hash][]byte, e hb.SetKeyTape(NewCborBytesTape(bw[startOffset:endOffset])) startOffset = endOffset endOffset = startOffset + lens[ValueTape] - hb.SetValueTape(NewCborBytesTape(bw[startOffset:endOffset])) + hb.SetValueTape( + NewRlpSerializableBytesTape( + NewCborBytesTape(bw[startOffset:endOffset]))) + startOffset = endOffset endOffset = startOffset + lens[NonceTape] hb.SetNonceTape(NewCborUint64Tape(bw[startOffset:endOffset])) diff --git a/trie/proof_generator_test.go b/trie/proof_generator_test.go index d8c826ede174be8087c12d8396a9d24a821a3712..652f51a8808a4f1814e0f3cd4097b8d4f5fcf03a 100644 --- a/trie/proof_generator_test.go +++ b/trie/proof_generator_test.go @@ -182,7 +182,8 @@ func TestSerialiseBlockWitness(t *testing.T) { if err := bwb.WriteTo(&b); err != nil { t.Errorf("Could not make block witness: %v", err) } - expected := common.FromHex("0xa76862616c616e6365730065636f64657300666861736865731822646b65797300666e6f6e63657300697374727563747572650b6676616c75657300582023181a62d35fe01562158be610f84e047f99f5e74d896da21682d925964ece3a0601024704010402040304") + + expected := common.FromHex("0xa76862616c616e6365730065636f64657300666861736865731822646b65797300666e6f6e63657300697374727563747572650b6676616c756573005820858f70a4b1e6aa71a7edc574d2ca946495a038aa37ce13dc7b7ed15661a6ff2f0601024704010402040304") if !bytes.Equal(expected, b.Bytes()) { t.Errorf("Expected %x, got: %x", expected, b.Bytes()) } diff --git a/trie/resolver.go b/trie/resolver.go index 884493d720b7e8ec0369f0a7af8b50e4345a86e2..5cfcd26064bb4192290990522ea4d972e95a48f0 100644 --- a/trie/resolver.go +++ b/trie/resolver.go @@ -113,7 +113,7 @@ func 
NewResolver(topLevels int, forAccounts bool, blockNr uint64) *Resolver { hb: NewHashBuilder(), } tr.hb.SetKeyTape(&tr.curr) - tr.hb.SetValueTape(&tr.value) + tr.hb.SetValueTape(NewRlpSerializableBytesTape(&tr.value)) tr.hb.SetNonceTape((*OneUint64Tape)(&tr.a.Nonce)) tr.hb.SetBalanceTape((*OneBalanceTape)(&tr.a.Balance)) tr.hb.SetHashTape(&tr.hashes) @@ -325,9 +325,6 @@ func (tr *Resolver) Walker(keyIdx int, k []byte, v []byte) (bool, error) { } } else { tr.value.Buffer.Reset() - if len(v) > 1 || v[0] >= 128 { - tr.value.Buffer.WriteByte(byte(128 + len(v))) - } tr.value.Buffer.Write(v) tr.fieldSet = AccountFieldSetNotAccount } diff --git a/trie/resolver_test.go b/trie/resolver_test.go index d8f1d838bbc6113b2345ba37747281bd5e01e72a..3a4a22e9859c46203fc44ad7849a9fb07ef4adfe 100644 --- a/trie/resolver_test.go +++ b/trie/resolver_test.go @@ -68,7 +68,7 @@ func TestResolve1(t *testing.T) { t: tr, resolveHex: keybytesToHex([]byte("aaaaabbbbbaaaaabbbbbaaaaabbbbbaa")), resolvePos: 10, // 5 bytes is 10 nibbles - resolveHash: hashNode(common.HexToHash("6556dfaac213851c044228962a8dc179125d81e496805ef0f4b891e9109135e2").Bytes()), + resolveHash: hashNode(common.HexToHash("bfb355c9a7c26a9c173a9c30e1fb2895fd9908726a8d3dd097203b207d852cf5").Bytes()), } r := NewResolver(0, false, 0) r.AddRequest(req) @@ -91,7 +91,7 @@ func TestResolve2(t *testing.T) { t: tr, resolveHex: keybytesToHex([]byte("aaaaabbbbbaaaaabbbbbaaaaabbbbbaa")), resolvePos: 10, // 5 bytes is 10 nibbles - resolveHash: hashNode(common.HexToHash("ca8155b4955b3723207ba30103f1759effbf87e5d8193fa215e5fe9818a00e2a").Bytes()), + resolveHash: hashNode(common.HexToHash("38eb1d28b717978c8cb21b6939dc69ba445d5dea67ca0e948bbf0aef9f1bc2fb").Bytes()), } r := NewResolver(0, false, 0) r.AddRequest(req) @@ -114,7 +114,7 @@ func TestResolve2Keep(t *testing.T) { t: tr, resolveHex: keybytesToHex([]byte("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")), resolvePos: 10, // 5 bytes is 10 nibbles - resolveHash: 
hashNode(common.HexToHash("ca8155b4955b3723207ba30103f1759effbf87e5d8193fa215e5fe9818a00e2a").Bytes()), + resolveHash: hashNode(common.HexToHash("38eb1d28b717978c8cb21b6939dc69ba445d5dea67ca0e948bbf0aef9f1bc2fb").Bytes()), } r := NewResolver(0, false, 0) r.AddRequest(req) @@ -140,7 +140,7 @@ func TestResolve3Keep(t *testing.T) { t: tr, resolveHex: keybytesToHex([]byte("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")), resolvePos: 10, // 5 bytes is 10 nibbles - resolveHash: hashNode(common.HexToHash("037d4f8cdf09ad062c866adefc24115c9e84e07399bd6ea058ed386b76dafde2").Bytes()), + resolveHash: hashNode(common.HexToHash("b780e7d2bc3b7ab7f85084edb2fff42facefa0df9dd1e8190470f277d8183e7c").Bytes()), } r := NewResolver(0, false, 0) r.AddRequest(req) @@ -179,7 +179,7 @@ func TestTrieResolver(t *testing.T) { t: tr, resolveHex: keybytesToHex([]byte("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")), resolvePos: 10, // 5 bytes is 10 nibbles - resolveHash: hashNode(common.HexToHash("ca8155b4955b3723207ba30103f1759effbf87e5d8193fa215e5fe9818a00e2a").Bytes()), + resolveHash: hashNode(common.HexToHash("38eb1d28b717978c8cb21b6939dc69ba445d5dea67ca0e948bbf0aef9f1bc2fb").Bytes()), } req2 := &ResolveRequest{ t: tr, @@ -191,7 +191,7 @@ func TestTrieResolver(t *testing.T) { t: tr, resolveHex: keybytesToHex([]byte("bbbaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")), resolvePos: 2, // 1 bytes is 2 nibbles - resolveHash: hashNode(common.HexToHash("79d4d20420e467bc56adad82c454d68bc72ffbe7a26ad33028002bcbd1d41a05").Bytes()), + resolveHash: hashNode(common.HexToHash("df6fd126d62ec79182d9ab6f879b63dfacb9ce2e1cb765b17b9752de9c2cbaa7").Bytes()), } resolver := NewResolver(0, false, 0) resolver.AddRequest(req3) diff --git a/trie/structural_2.go b/trie/structural_2.go index 51e6204ad77b0f1efd803d105d73c333144d7566..c130611082231162e293aa0263241c767a3a19b1 100644 --- a/trie/structural_2.go +++ b/trie/structural_2.go @@ -18,6 +18,7 @@ package trie import ( "fmt" + "io" "math/big" "math/bits" @@ -155,47 +156,21 @@ func GenStructStep( return 
GenStructStep(fieldSet, hashOnly, true, newPrec, newCurr, succ, e, groups) } -// BytesTape is an abstraction for an input tape that allows reading binary strings ([]byte) sequentially -// To be used for keys and binary string values -type BytesTape interface { - // Returned slice is only valid until the next invocation of NextBytes() - // i.e. the underlying array/slice may be shared between invocations - Next() ([]byte, error) -} - -// Uint64Tape is an abstraction for an input tape that allows reading unsigned 64-bit integers sequentially -// To be used for nonces of the accounts -type Uint64Tape interface { - Next() (uint64, error) -} - -// BigIntTape is an abstraction for an input tape that allows reading *big.Int values sequentially -// To be used for balances of the accounts -type BigIntTape interface { - // Returned pointer is only valid until the next invocation of NextBitInt() - // i.e. the underlying big.Int object may be shared between invocation - Next() (*big.Int, error) -} - -// HashTape is an abstraction for an input table that allows reading 32-byte hashes (common.Hash) sequentially -// To be used for intermediate hashes in the Patricia Merkle tree -type HashTape interface { - Next() (common.Hash, error) -} - const hashStackStride = common.HashLength + 1 // + 1 byte for RLP encoding // HashBuilder implements the interface `structInfoReceiver` and opcodes that the structural information of the trie // is comprised of // DESCRIBED: docs/programmers_guide/guide.md#separation-of-keys-and-the-structure type HashBuilder struct { - keyTape BytesTape // the source of key sequence - valueTape BytesTape // the source of values (for values that are not accounts or contracts) - nonceTape Uint64Tape // the source of nonces for accounts and contracts (field 0) - balanceTape BigIntTape // the source of balances for accounts and contracts (field 1) - sSizeTape Uint64Tape // the source of storage sizes for contracts (field 4) - hashTape HashTape // the source of 
hashes - codeTape BytesTape // the source of bytecodes + keyTape BytesTape // the source of key sequence + valueTape RlpSerializableTape // the source of values (for values that are not accounts or contracts) + nonceTape Uint64Tape // the source of nonces for accounts and contracts (field 0) + balanceTape BigIntTape // the source of balances for accounts and contracts (field 1) + sSizeTape Uint64Tape // the source of storage sizes for contracts (field 4) + hashTape HashTape // the source of hashes + codeTape BytesTape // the source of bytecodes + + byteArrayWriter *ByteArrayWriter hashStack []byte // Stack of sub-slices, each 33 bytes each, containing RLP encodings of node hashes (or of nodes themselves, if shorter than 32 bytes) nodeStack []node // Stack of nodes @@ -206,7 +181,8 @@ type HashBuilder struct { // NewHashBuilder creates a new HashBuilder func NewHashBuilder() *HashBuilder { return &HashBuilder{ - sha: sha3.NewLegacyKeccak256().(keccakState), + sha: sha3.NewLegacyKeccak256().(keccakState), + byteArrayWriter: &ByteArrayWriter{}, } } @@ -216,7 +192,7 @@ func (hb *HashBuilder) SetKeyTape(keyTape BytesTape) { } // SetValueTape sets the value tape to be used by this builder (opcodes leaf and leafHash) -func (hb *HashBuilder) SetValueTape(valueTape BytesTape) { +func (hb *HashBuilder) SetValueTape(valueTape RlpSerializableTape) { hb.valueTape = valueTape } @@ -262,19 +238,18 @@ func (hb *HashBuilder) leaf(length int) error { if err != nil { return err } - s := &shortNode{Key: common.CopyBytes(key), Val: valueNode(common.CopyBytes(val))} + s := &shortNode{Key: common.CopyBytes(key), Val: valueNode(common.CopyBytes(val.RawBytes()))} hb.nodeStack = append(hb.nodeStack, s) return hb.leafHashWithKeyVal(key, val) } // To be called internally -func (hb *HashBuilder) leafHashWithKeyVal(key, val []byte) error { +func (hb *HashBuilder) leafHashWithKeyVal(key []byte, val RlpSerializable) error { var hash [hashStackStride]byte // RLP representation of hash (or of 
un-hashed value if short) // Compute the total length of binary representation var keyPrefix [1]byte - var valPrefix [4]byte var lenPrefix [4]byte - var kp, vp, kl, vl int + var kp, kl int // Write key var compactLen int var ni int @@ -301,62 +276,53 @@ func (hb *HashBuilder) leafHashWithKeyVal(key, val []byte) error { } else { kl = 1 } - if len(val) > 1 || val[0] >= rlp.EmptyStringCode { - vp = generateByteArrayLen(valPrefix[:], 0, len(val)) - vl = len(val) - } else { - vl = 1 - } - totalLen := kp + kl + vp + vl + + totalLen := kp + kl + val.DoubleRLPLen() pt := generateStructLen(lenPrefix[:], totalLen) - if pt+totalLen < common.HashLength { + + var writer io.Writer + var reader io.Reader + + if totalLen+pt < common.HashLength { // Embedded node - pos := 0 - copy(hash[pos:], lenPrefix[:pt]) - pos += pt - copy(hash[pos:], keyPrefix[:kp]) - pos += kp - hash[pos] = compact0 - pos++ - for i := 1; i < compactLen; i++ { - hash[pos] = key[ni]*16 + key[ni+1] - pos++ - ni += 2 - } - copy(hash[pos:], valPrefix[:vp]) - pos += vp - copy(hash[pos:], val) + hb.byteArrayWriter.Setup(hash[:], 0) + writer = hb.byteArrayWriter } else { hb.sha.Reset() - if _, err := hb.sha.Write(lenPrefix[:pt]); err != nil { - return err - } - if _, err := hb.sha.Write(keyPrefix[:kp]); err != nil { - return err - } - var b [1]byte - b[0] = compact0 - if _, err := hb.sha.Write(b[:]); err != nil { - return err - } - for i := 1; i < compactLen; i++ { - b[0] = key[ni]*16 + key[ni+1] - if _, err := hb.sha.Write(b[:]); err != nil { - return err - } - ni += 2 - } - if _, err := hb.sha.Write(valPrefix[:vp]); err != nil { - return err - } - if _, err := hb.sha.Write(val); err != nil { + writer = hb.sha + reader = hb.sha + } + + if _, err := writer.Write(lenPrefix[:pt]); err != nil { + return err + } + if _, err := writer.Write(keyPrefix[:kp]); err != nil { + return err + } + var b [1]byte + b[0] = compact0 + if _, err := writer.Write(b[:]); err != nil { + return err + } + for i := 1; i < compactLen; i++ { + 
b[0] = key[ni]*16 + key[ni+1] + if _, err := writer.Write(b[:]); err != nil { return err } + ni += 2 + } + + if err := val.ToDoubleRLP(writer); err != nil { + return err + } + + if reader != nil { hash[0] = rlp.EmptyStringCode + common.HashLength - if _, err := hb.sha.Read(hash[1:]); err != nil { + if _, err := reader.Read(hash[1:]); err != nil { return err } } + hb.hashStack = append(hb.hashStack, hash[:]...) if len(hb.hashStack) > hashStackStride*len(hb.nodeStack) { hb.nodeStack = append(hb.nodeStack, nil) diff --git a/trie/structural_test.go b/trie/structural_test.go index 05b9c959f49795896690eade663d3cd77c113687..fef2060750706486864c9327ba792d7d45878745 100644 --- a/trie/structural_test.go +++ b/trie/structural_test.go @@ -60,7 +60,7 @@ func TestV2HashBuilding(t *testing.T) { var curr OneBytesTape var valueTape OneBytesTape hb.SetKeyTape(&curr) - hb.SetValueTape(&valueTape) + hb.SetValueTape(NewRlpSerializableBytesTape(&valueTape)) var groups []uint16 for i, key := range keys { prec.Reset() @@ -134,7 +134,7 @@ func TestV2Resolution(t *testing.T) { var curr OneBytesTape var valueTape OneBytesTape hb.SetKeyTape(&curr) - hb.SetValueTape(&valueTape) + hb.SetValueTape(NewRlpSerializableBytesTape(&valueTape)) var groups []uint16 for _, key := range keys { prec.Reset() diff --git a/trie/tapes.go b/trie/tapes.go new file mode 100644 index 0000000000000000000000000000000000000000..d1a7dd36821ebb8ee7dc160ff158a69c394a8c70 --- /dev/null +++ b/trie/tapes.go @@ -0,0 +1,81 @@ +package trie + +import ( + "io" + "math/big" + + "github.com/ledgerwatch/turbo-geth/common" +) + +// BytesTape is an abstraction for an input tape that allows reading binary strings ([]byte) sequentially +// To be used for keys and binary string values +type BytesTape interface { + // Returned slice is only valid until the next invocation of Next() + // i.e.
the underlying array/slice may be shared between invocations + Next() ([]byte, error) +} + +// Uint64Tape is an abstraction for an input tape that allows reading unsigned 64-bit integers sequentially +// To be used for nonces of the accounts +type Uint64Tape interface { + Next() (uint64, error) +} + +// BigIntTape is an abstraction for an input tape that allows reading *big.Int values sequentially +// To be used for balances of the accounts +type BigIntTape interface { + // Returned pointer is only valid until the next invocation of Next() + // i.e. the underlying big.Int object may be shared between invocations + Next() (*big.Int, error) +} + +// HashTape is an abstraction for an input tape that allows reading 32-byte hashes (common.Hash) sequentially +// To be used for intermediate hashes in the Patricia Merkle tree +type HashTape interface { + Next() (common.Hash, error) +} + +// RlpSerializableTape is an abstraction for an input tape that allows reading RlpSerializable values sequentially (FIXME: document RlpSerializable) +type RlpSerializableTape interface { + Next() (RlpSerializable, error) +} + +type RlpSerializable interface { + ToDoubleRLP(io.Writer) error + DoubleRLPLen() int + RawBytes() []byte +} + +func NewRlpSerializableBytesTape(inner BytesTape) RlpSerializableTape { + return &RlpSerializableBytesTape{inner} +} + +type RlpSerializableBytesTape struct { + inner BytesTape +} + +func (t *RlpSerializableBytesTape) Next() (RlpSerializable, error) { + value, err := t.inner.Next() + if err != nil { + return nil, err + } + + return RlpSerializableBytes(value), nil +} + +func NewRlpEncodedBytesTape(inner BytesTape) RlpSerializableTape { + return &RlpBytesTape{inner} +} + +type RlpBytesTape struct { + inner BytesTape +} + +func (t *RlpBytesTape) Next() (RlpSerializable, error) { + value, err := t.inner.Next() + if err != nil { + return nil, err + } + + return RlpEncodedBytes(value), nil +} diff --git a/trie/trie.go b/trie/trie.go index 09e25ca9415955e74bdf482ed418085fd183530a..46beecf5cfc7d34af29edeecd50039f20532e74c 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -46,6 +46,8
@@ type Trie struct { touchFunc func(hex []byte, del bool) + newHasherFunc func() *hasher + Version uint8 } @@ -57,7 +59,21 @@ type Trie struct { // not exist in the database. Accessing the trie loads nodes from db on demand. func New(root common.Hash) *Trie { trie := &Trie{ - touchFunc: func([]byte, bool) {}, + touchFunc: func([]byte, bool) {}, + newHasherFunc: func() *hasher { return newHasher( /*valueNodesRlpEncoded = */ false) }, + } + if (root != common.Hash{}) && root != EmptyRoot { + trie.root = hashNode(root[:]) + } + return trie +} + +// NewTestRLPTrie treats all the data provided to the `Update` function as RLP-encoded. +// It is usually used for testing purposes. +func NewTestRLPTrie(root common.Hash) *Trie { + trie := &Trie{ + touchFunc: func([]byte, bool) {}, + newHasherFunc: func() *hasher { return newHasher( /*valueNodesRlpEncoded = */ true) }, } if (root != common.Hash{}) && root != EmptyRoot { trie.root = hashNode(root[:]) @@ -1080,7 +1096,7 @@ func (t *Trie) DeepHash(keyPrefix []byte) (bool, common.Hash) { accNode.Root = EmptyRoot accNode.hashCorrect = true } else { - h := newHasher(false) + h := t.newHasherFunc() defer returnHasherToPool(h) h.hash(accNode.storage, true, accNode.Root[:]) } @@ -1229,7 +1245,7 @@ func (t *Trie) hashRoot() (node, error) { if t.root == nil { return hashNode(EmptyRoot.Bytes()), nil } - h := newHasher(false) + h := t.newHasherFunc() defer returnHasherToPool(h) var hn common.Hash h.hash(t.root, true, hn[:])