From 4f4b395aa46b486332ce942a28345f93bb88cb13 Mon Sep 17 00:00:00 2001 From: Igor Mandrigin <mandrigin@users.noreply.github.com> Date: Tue, 24 Mar 2020 01:10:36 +0300 Subject: [PATCH] Introduce code node & get rid of code map (#398) * introduce code node * replace codeMap with code touches * fix a comment * fixups to tests * fix compile error * fix getnodedata tests * add tests and test stubs * add more test stubs * add test method bodies * add and fix more tests on trie for new codenode * add test change code between blocks * fix crash in stateless * remove unneded files * remove comment * fix deleted account code * fix resolve set builder for code nodes --- cmd/pics/state.go | 70 ++----- cmd/state/stateless/stateless.go | 5 +- core/state/database.go | 143 ++++++++++---- core/state/database_test.go | 48 +++++ core/state/stateless.go | 45 +++-- eth/handler.go | 3 +- trie/debug.go | 7 + trie/hashbuilder.go | 25 ++- trie/node.go | 5 + trie/resolve_set.go | 27 ++- trie/resolve_set_builder.go | 42 ++-- trie/resolver.go | 31 ++- trie/resolver_stateful.go | 14 ++ trie/resolver_stateless.go | 2 +- trie/resolver_stateless_test.go | 6 +- trie/trie.go | 96 ++++++++- trie/trie_from_witness.go | 28 ++- trie/trie_test.go | 324 +++++++++++++++++++++++++++++++ trie/trie_transform.go | 7 +- trie/trie_witness.go | 8 +- trie/visual.go | 21 +- trie/witness_builder.go | 16 +- trie/witness_builder_test.go | 8 +- 23 files changed, 774 insertions(+), 207 deletions(-) diff --git a/cmd/pics/state.go b/cmd/pics/state.go index e3ddaab0d3..799e7dc868 100644 --- a/cmd/pics/state.go +++ b/cmd/pics/state.go @@ -39,7 +39,7 @@ func constructCodeMap(tds *state.TrieDbState) (map[common.Hash][]byte, error) { return codeMap, nil } -func statePicture(t *trie.Trie, codeMap map[common.Hash][]byte, number int, keyCompression int, codeCompressed bool, valCompressed bool, +func statePicture(t *trie.Trie, number int, keyCompression int, codeCompressed bool, valCompressed bool, quadTrie bool, quadColors bool, highlights [][]byte) (*trie.Trie, error) { filename := fmt.Sprintf("state_%d.dot", number) f, err := os.Create(filename) @@ -62,7 +62,6 @@ func statePicture(t *trie.Trie, codeMap map[common.Hash][]byte, number int, keyC FontColors: fontColors, Values: true, CutTerminals: keyCompression, - CodeMap: codeMap, CodeCompressed: codeCompressed, ValCompressed: valCompressed, ValHex: true, @@ -332,14 +331,10 @@ func initialState1() error { return err } t := tds.Trie() - var codeMap map[common.Hash][]byte - if codeMap, err = constructCodeMap(tds); err != nil { + if _, err = statePicture(t, 0, 0, false, false, false, false, nil); err != nil { return err } - if _, err = statePicture(t, codeMap, 0, 0, false, false, false, false, nil); err != nil { - return err - } - if _, err = statePicture(t, codeMap, 1, 48, false, false, false, false, nil); err != nil { + if _, err = statePicture(t, 1, 48, false, false, false, false, nil); err != nil { return err } if err = stateDatabaseComparison(snapshotDb, dbBolt, 0); err != nil { @@ -460,13 +455,10 @@ func initialState1() error { if _, err = blockchain.InsertChain(context.Background(), types.Blocks{blocks[0]}); err != nil { return err } - if codeMap, err = constructCodeMap(tds); err != nil { - return err - } if err = stateDatabaseComparison(snapshotDb, dbBolt, 1); err != nil { return err } - if _, err = statePicture(t, codeMap, 2, 48, false, false, false, false, nil); err != nil { + if _, err = statePicture(t, 2, 48, false, false, false, false, nil); err != nil { return err } @@ -476,13 +468,10 @@ func initialState1() error { if _, err = blockchain.InsertChain(context.Background(), types.Blocks{blocks[1]}); err != nil { return err } - if codeMap, err = constructCodeMap(tds); err != nil { - return err - } if err = stateDatabaseComparison(snapshotDb, dbBolt, 2); err != nil { return err } - if _, err = statePicture(t, codeMap, 3, 48, false, false, false, false, nil); err != nil { + if _, err = statePicture(t, 3, 48, false, false, false, false, nil); err != nil { return err } @@ -492,19 +481,16 @@ func initialState1() error { if _, err = blockchain.InsertChain(context.Background(), types.Blocks{blocks[2]}); err != nil { return err } - if codeMap, err = constructCodeMap(tds); err != nil { - return err - } if err = stateDatabaseComparison(snapshotDb, dbBolt, 3); err != nil { return err } - if _, err = statePicture(t, codeMap, 4, 48, false, false, false, false, nil); err != nil { + if _, err = statePicture(t, 4, 48, false, false, false, false, nil); err != nil { return err } - if _, err = statePicture(t, codeMap, 5, 48, true, false, false, false, nil); err != nil { + if _, err = statePicture(t, 5, 48, true, false, false, false, nil); err != nil { return err } - if _, err = statePicture(t, codeMap, 6, 48, true, true, false, false, nil); err != nil { + if _, err = statePicture(t, 6, 48, true, true, false, false, nil); err != nil { return err } @@ -517,10 +503,7 @@ func initialState1() error { if err = stateDatabaseComparison(snapshotDb, dbBolt, 4); err != nil { return err } - if codeMap, err = constructCodeMap(tds); err != nil { - return err - } - if _, err = statePicture(t, codeMap, 7, 48, true, true, false, false, nil); err != nil { + if _, err = statePicture(t, 7, 48, true, true, false, false, nil); err != nil { return err } @@ -533,10 +516,7 @@ func initialState1() error { if err = stateDatabaseComparison(snapshotDb, dbBolt, 5); err != nil { return err } - if codeMap, err = constructCodeMap(tds); err != nil { - return err - } - if _, err = statePicture(t, codeMap, 8, 54, true, true, false, false, nil); err != nil { + if _, err = statePicture(t, 8, 54, true, true, false, false, nil); err != nil { return err } @@ -549,13 +529,10 @@ func initialState1() error { if err = stateDatabaseComparison(snapshotDb, dbBolt, 5); err != nil { return err } - if codeMap, err = constructCodeMap(tds); err != nil { - return err - } - if _, err = statePicture(t, codeMap, 9, 54, true, true, false, false, nil); err != nil { + if _, err = statePicture(t, 9, 54, true, true, false, false, nil); err != nil { return err } - if _, err = statePicture(t, codeMap, 10, 110, true, true, true, true, nil); err != nil { + if _, err = statePicture(t, 10, 110, true, true, true, true, nil); err != nil { return err } @@ -568,10 +545,7 @@ func initialState1() error { if err = stateDatabaseComparison(snapshotDb, dbBolt, 7); err != nil { return err } - if codeMap, err = constructCodeMap(tds); err != nil { - return err - } - quadTrie, err := statePicture(t, codeMap, 11, 110, true, true, true, true, nil) + quadTrie, err := statePicture(t, 11, 110, true, true, true, true, nil) if err != nil { return err } @@ -586,10 +560,7 @@ func initialState1() error { if err = stateDatabaseComparison(snapshotDb, dbBolt, 8); err != nil { return err } - if codeMap, err = constructCodeMap(tds); err != nil { - return err - } - if _, err = statePicture(t, codeMap, 12, 110, true, true, true, true, nil); err != nil { + if _, err = statePicture(t, 12, 110, true, true, true, true, nil); err != nil { return err } @@ -607,23 +578,18 @@ func initialState1() error { touchQuads = append(touchQuads, touchQuad) } - if codeMap, err = constructCodeMap(tds); err != nil { - return err - } - var witness *trie.Witness - if witness, err = quadTrie.ExtractWitness(0, false, rs, codeMap); err != nil { + if witness, err = quadTrie.ExtractWitness(0, false, rs); err != nil { return err } var witnessTrie *trie.Trie - var witnessCodeMap map[common.Hash][]byte - if witnessTrie, witnessCodeMap, err = trie.BuildTrieFromWitness(witness, false, false); err != nil { + if witnessTrie, err = trie.BuildTrieFromWitness(witness, false, false); err != nil { return err } - if _, err = statePicture(witnessTrie, witnessCodeMap, 13, 110, true, true, false /*already quad*/, true, touchQuads); err != nil { + if _, err = statePicture(witnessTrie, 13, 110, true, true, false /*already quad*/, true, touchQuads); err != nil { return err } @@ -632,7 +598,7 @@ func initialState1() error { } // Repeat the block witness illustration, but without any highlighted keys - if _, err = statePicture(witnessTrie, witnessCodeMap, 15, 110, true, true, false /*already quad*/, true, nil); err != nil { + if _, err = statePicture(witnessTrie, 15, 110, true, true, false /*already quad*/, true, nil); err != nil { return err } diff --git a/cmd/state/stateless/stateless.go b/cmd/state/stateless/stateless.go index 1bdaeda326..6e3c0de7f2 100644 --- a/cmd/state/stateless/stateless.go +++ b/cmd/state/stateless/stateless.go @@ -70,7 +70,7 @@ func runBlock(ibs *state.IntraBlockState, txnWriter state.StateWriter, blockWrit return nil } -func statePicture(t *trie.Trie, codeMap map[common.Hash][]byte, number uint64) error { +func statePicture(t *trie.Trie, number uint64) error { filename := fmt.Sprintf("state_%d.dot", number) f, err := os.Create(filename) if err != nil { @@ -85,7 +85,6 @@ func statePicture(t *trie.Trie, codeMap map[common.Hash][]byte, number uint64) e FontColors: fontColors, Values: true, CutTerminals: 0, - CodeMap: codeMap, CodeCompressed: false, ValCompressed: false, ValHex: true, @@ -375,7 +374,7 @@ func Stateless( return } if _, ok := starkBlocks[blockNum-1]; ok { - err = statePicture(s.GetTrie(), s.GetCodeMap(), blockNum-1) + err = statePicture(s.GetTrie(), blockNum-1) check(err) } ibs := state.New(s) diff --git a/core/state/database.go b/core/state/database.go index 60f55deb76..824c0592ed 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -28,7 +28,6 @@ import ( "sync" "sync/atomic" - lru "github.com/hashicorp/golang-lru" "github.com/ledgerwatch/bolt" "github.com/ledgerwatch/turbo-geth/common" "github.com/ledgerwatch/turbo-geth/common/dbutils" @@ -95,6 +94,8 @@ func (nw *NoopWriter) CreateContract(address common.Address) error { // Structure holding updates, deletes, and reads registered within one change period // A change period can be transaction within a block, or a block within group of blocks type Buffer struct { + codeReads map[common.Hash]common.Hash + codeUpdates map[common.Hash][]byte storageUpdates map[common.Hash]map[common.Hash][]byte storageReads map[common.Hash]map[common.Hash]struct{} accountUpdates map[common.Hash]*accounts.Account @@ -105,6 +106,8 @@ type Buffer struct { // Prepares buffer for work or clears previous data func (b *Buffer) initialise() { + b.codeReads = make(map[common.Hash]common.Hash) + b.codeUpdates = make(map[common.Hash][]byte) b.storageUpdates = make(map[common.Hash]map[common.Hash][]byte) b.storageReads = make(map[common.Hash]map[common.Hash]struct{}) b.accountUpdates = make(map[common.Hash]*accounts.Account) @@ -124,6 +127,14 @@ func (b *Buffer) detachAccounts() { // Merges the content of another buffer into this one func (b *Buffer) merge(other *Buffer) { + for addrHash, codeHash := range other.codeReads { + b.codeReads[addrHash] = codeHash + } + + for addrHash, code := range other.codeUpdates { + b.codeUpdates[addrHash] = code + } + for addrHash, om := range other.storageUpdates { m, ok := b.storageUpdates[addrHash] if !ok { @@ -167,8 +178,6 @@ type TrieDbState struct { buffers []*Buffer aggregateBuffer *Buffer // Merge of all buffers currentBuffer *Buffer - codeCache *lru.Cache - codeSizeCache *lru.Cache historical bool noHistory bool resolveReads bool @@ -183,14 +192,6 @@ type TrieDbState struct { } func NewTrieDbState(root common.Hash, db ethdb.Database, blockNr uint64) (*TrieDbState, error) { - csc, err := lru.New(100000) - if err != nil { - return nil, err - } - cc, err := lru.New(10000) - if err != nil { - return nil, err - } t := trie.New(root) tp := trie.NewTriePruning(blockNr) @@ -199,8 +200,6 @@ func NewTrieDbState(root common.Hash, db ethdb.Database, blockNr uint64) (*TrieD tMu: new(sync.Mutex), db: db, blockNr: blockNr, - codeCache: cc, - codeSizeCache: csc, resolveSetBuilder: trie.NewResolveSetBuilder(), tp: tp, savePreimages: true, @@ -573,8 +572,6 @@ func (tds *TrieDbState) WithNewBuffer() *TrieDbState { buffers: buffers, aggregateBuffer: aggregateBuffer, currentBuffer: currentBuffer, - codeCache: tds.codeCache, - codeSizeCache: tds.codeSizeCache, historical: tds.historical, noHistory: tds.noHistory, resolveReads: tds.resolveReads, @@ -704,6 +701,10 @@ func (tds *TrieDbState) populateStorageBlockProof(storageTouches common.StorageK return nil } +func (tds *TrieDbState) buildCodeTouches(withReads bool) map[common.Hash]common.Hash { + return tds.aggregateBuffer.codeReads +} + // Builds a sorted list of all address hashes that were touched within the // period for which we are aggregating updates func (tds *TrieDbState) buildAccountTouches(withReads bool, withValues bool) (common.Hashes, []*accounts.Account) { @@ -748,9 +749,30 @@ func (tds *TrieDbState) buildAccountTouches(withReads bool, withValues bool) (co return accountTouches, aValues } +func (tds *TrieDbState) resolveCodeTouches(codeTouches map[common.Hash]common.Hash, resolveFunc trie.ResolveFunc) error { + firstRequest := true + for address, codeHash := range codeTouches { + if need, req := tds.t.NeedResolutonForCode(address, codeHash); need { + if tds.resolver == nil { + tds.resolver = trie.NewResolver(0, true, tds.blockNr) + tds.resolver.SetHistorical(tds.historical) + } else if firstRequest { + tds.resolver.Reset(0, true, tds.blockNr) + } + firstRequest = false + tds.resolver.AddCodeRequest(req) + } + } + + if !firstRequest { + return resolveFunc(tds.resolver) + } + return nil +} + // Expands the accounts trie (by loading data from the database) if it is required // for accessing accounts whose addresses are contained in the accountTouches -func (tds *TrieDbState) resolveAccountTouches(accountTouches common.Hashes, resolveFunc func(*trie.Resolver) error) error { +func (tds *TrieDbState) resolveAccountTouches(accountTouches common.Hashes, resolveFunc trie.ResolveFunc) error { var firstRequest = true for _, addrHash := range accountTouches { if need, req := tds.t.NeedResolution(nil, addrHash[:]); need { @@ -784,7 +806,7 @@ func (tds *TrieDbState) ExtractTouches() (accountTouches [][]byte, storageTouche return tds.resolveSetBuilder.ExtractTouches() } -func (tds *TrieDbState) resolveStateTrieWithFunc(resolveFunc func(*trie.Resolver) error) error { +func (tds *TrieDbState) resolveStateTrieWithFunc(resolveFunc trie.ResolveFunc) error { // Aggregating the current buffer, if any if tds.currentBuffer != nil { if tds.aggregateBuffer == nil { @@ -805,12 +827,20 @@ func (tds *TrieDbState) resolveStateTrieWithFunc(resolveFunc func(*trie.Resolver // Prepare (resolve) accounts trie so that actual modifications can proceed without database access accountTouches, _ := tds.buildAccountTouches(tds.resolveReads, false) + + // Prepare (resolve) contract code reads so that actual modifications can proceed without database access + codeTouches := tds.buildCodeTouches(tds.resolveReads) + var err error if err = tds.resolveAccountTouches(accountTouches, resolveFunc); err != nil { return err } + if err = tds.resolveCodeTouches(codeTouches, resolveFunc); err != nil { + return err + } + if tds.resolveReads { tds.populateAccountBlockProof(accountTouches) } @@ -943,6 +973,13 @@ func (tds *TrieDbState) updateTrieRoots(forward bool) ([]common.Hash, error) { tds.t.UpdateAccount(addrHash[:], account) } else { tds.t.Delete(addrHash[:]) + delete(b.codeUpdates, addrHash) + } + } + + for addrHash, newCode := range b.codeUpdates { + if err := tds.t.UpdateAccountCode(addrHash[:], newCode); err != nil { + return nil, err } } for addrHash, m := range b.storageUpdates { @@ -1331,18 +1368,35 @@ func (tds *TrieDbState) ReadAccountStorage(address common.Address, incarnation u return enc, nil } +func (tds *TrieDbState) ReadCodeByHash(codeHash common.Hash) (code []byte, err error) { + if bytes.Equal(codeHash[:], emptyCodeHash) { + return nil, nil + } + + code, err = tds.db.Get(dbutils.CodeBucket, codeHash[:]) + if tds.resolveReads { + // we have to be careful, because the code might change + // during the block executuion, so we are always + // storing the latest code hash + tds.resolveSetBuilder.ReadCode(codeHash) + } + return code, err +} + func (tds *TrieDbState) ReadAccountCode(address common.Address, codeHash common.Hash) (code []byte, err error) { if bytes.Equal(codeHash[:], emptyCodeHash) { return nil, nil } - if cached, ok := tds.codeCache.Get(codeHash); ok { - code, err = cached.([]byte), nil + + addrHash, err := tds.HashAddress(address, false /*save*/) + if err != nil { + return nil, err + } + + if cached, ok := tds.t.GetAccountCode(addrHash[:]); ok { + code, err = cached, nil } else { code, err = tds.db.Get(dbutils.CodeBucket, codeHash[:]) - if err == nil { - tds.codeSizeCache.Add(codeHash, len(code)) - tds.codeCache.Add(codeHash, code) - } } if tds.resolveReads { addrHash, err1 := common.HashData(address[:]) @@ -1352,25 +1406,23 @@ func (tds *TrieDbState) ReadAccountCode(address common.Address, codeHash common. if _, ok := tds.currentBuffer.accountUpdates[addrHash]; !ok { tds.currentBuffer.accountReads[addrHash] = struct{}{} } - tds.resolveSetBuilder.ReadCode(codeHash, code) + // we have to be careful, because the code might change + // during the block executuion, so we are always + // storing the latest code hash + tds.currentBuffer.codeReads[addrHash] = codeHash + tds.resolveSetBuilder.ReadCode(codeHash) } return code, err } func (tds *TrieDbState) ReadAccountCodeSize(address common.Address, codeHash common.Hash) (codeSize int, err error) { - var code []byte - if cached, ok := tds.codeSizeCache.Get(codeHash); ok { - codeSize, err = cached.(int), nil - if tds.resolveReads { - if cachedCode, ok := tds.codeCache.Get(codeHash); ok { - code, err = cachedCode.([]byte), nil - } else { - code, err = tds.ReadAccountCode(address, codeHash) - if err != nil { - return 0, err - } - } - } + addrHash, err := tds.HashAddress(address, false /*save*/) + if err != nil { + return 0, err + } + + if code, ok := tds.t.GetAccountCode(addrHash[:]); ok { + codeSize, err = len(code), nil } else { code, err = tds.ReadAccountCode(address, codeHash) if err != nil { @@ -1386,7 +1438,11 @@ func (tds *TrieDbState) ReadAccountCodeSize(address common.Address, codeHash com if _, ok := tds.currentBuffer.accountUpdates[addrHash]; !ok { tds.currentBuffer.accountReads[addrHash] = struct{}{} } - tds.resolveSetBuilder.ReadCode(codeHash, code) + // we have to be careful, because the code might change + // during the block executuion, so we are always + // storing the latest code hash + tds.currentBuffer.codeReads[addrHash] = codeHash + tds.resolveSetBuilder.ReadCode(codeHash) } return codeSize, nil } @@ -1519,8 +1575,9 @@ func (tsw *TrieStateWriter) DeleteAccount(_ context.Context, address common.Addr func (tsw *TrieStateWriter) UpdateAccountCode(addrHash common.Hash, incarnation uint64, codeHash common.Hash, code []byte) error { if tsw.tds.resolveReads { - tsw.tds.resolveSetBuilder.CreateCode(codeHash, code) + tsw.tds.resolveSetBuilder.CreateCode(codeHash) } + tsw.tds.currentBuffer.codeUpdates[addrHash] = code return nil } @@ -1551,12 +1608,12 @@ func (tsw *TrieStateWriter) WriteAccountStorage(_ context.Context, address commo // ExtractWitness produces block witness for the block just been processed, in a serialised form func (tds *TrieDbState) ExtractWitness(trace bool, isBinary bool) (*trie.Witness, error) { - rs, codeMap := tds.resolveSetBuilder.Build(isBinary) + rs := tds.resolveSetBuilder.Build(isBinary) - return tds.makeBlockWitness(trace, rs, codeMap, isBinary) + return tds.makeBlockWitness(trace, rs, isBinary) } -func (tds *TrieDbState) makeBlockWitness(trace bool, rs *trie.ResolveSet, codeMap map[common.Hash][]byte, isBinary bool) (*trie.Witness, error) { +func (tds *TrieDbState) makeBlockWitness(trace bool, rs *trie.ResolveSet, isBinary bool) (*trie.Witness, error) { tds.tMu.Lock() defer tds.tMu.Unlock() @@ -1565,7 +1622,7 @@ func (tds *TrieDbState) makeBlockWitness(trace bool, rs *trie.ResolveSet, codeMa t = trie.HexToBin(tds.t).Trie() } - return t.ExtractWitness(tds.blockNr, trace, rs, codeMap) + return t.ExtractWitness(tds.blockNr, trace, rs) } func (tsw *TrieStateWriter) CreateContract(address common.Address) error { diff --git a/core/state/database_test.go b/core/state/database_test.go index 3c3d960e35..d23189d370 100644 --- a/core/state/database_test.go +++ b/core/state/database_test.go @@ -1370,3 +1370,51 @@ func TestClearTombstonesForReCreatedAccount(t *testing.T) { assert.Equal(expect, ok, k) } } + +func TestChangeAccountCodeBetweenBlocks(t *testing.T) { + contract := common.HexToAddress("0x71dd1027069078091B3ca48093B00E4735B20624") + + db := ethdb.NewMemDatabase() + tds, err := state.NewTrieDbState(common.Hash{}, db, 0) + if err != nil { + t.Errorf("could not create TrieDbState: %v", err) + } + tsw := tds.TrieStateWriter() + intraBlockState := state.New(tds) + ctx := context.Background() + // Start the 1st transaction + tds.StartNewBuffer() + intraBlockState.CreateAccount(contract, true) + + oldCode := []byte{0x01, 0x02, 0x03, 0x04} + + intraBlockState.SetCode(contract, oldCode) + intraBlockState.AddBalance(contract, big.NewInt(1000000000)) + if err = intraBlockState.FinalizeTx(ctx, tsw); err != nil { + t.Errorf("error finalising 1st tx: %v", err) + } + + tds.ComputeTrieRoots() + + oldCodeHash := common.BytesToHash(crypto.Keccak256(oldCode)) + + trieCode, err := tds.ReadAccountCode(contract, oldCodeHash) + assert.NoError(t, err, "you can receive the new code") + assert.Equal(t, oldCode, trieCode, "new code should be received") + + tds.StartNewBuffer() + + newCode := []byte{0x04, 0x04, 0x04, 0x04} + intraBlockState.SetCode(contract, newCode) + + if err = intraBlockState.FinalizeTx(ctx, tsw); err != nil { + t.Errorf("error finalising 1st tx: %v", err) + } + + tds.ComputeTrieRoots() + + newCodeHash := common.BytesToHash(crypto.Keccak256(newCode)) + trieCode, err = tds.ReadAccountCode(contract, newCodeHash) + assert.NoError(t, err, "you can receive the new code") + assert.Equal(t, newCode, trieCode, "new code should be received") +} diff --git a/core/state/stateless.go b/core/state/stateless.go index 6b5a8cfd21..720df5785e 100644 --- a/core/state/stateless.go +++ b/core/state/stateless.go @@ -34,7 +34,7 @@ import ( // during the execution of block(s) type Stateless struct { t *trie.Trie // State trie - codeMap map[common.Hash][]byte // Lookup index from code hashes to corresponding bytecode + codeUpdates map[common.Hash][]byte // Lookup index from code hashes to corresponding bytecode blockNr uint64 // Current block number storageUpdates map[common.Hash]map[common.Hash][]byte accountUpdates map[common.Hash]*accounts.Account @@ -47,10 +47,11 @@ type Stateless struct { // It deserialises the block witness and creates the state trie out of it, checking that the root of the constructed // state trie matches the value of `stateRoot` parameter func NewStateless(stateRoot common.Hash, blockWitness *trie.Witness, blockNr uint64, trace bool, isBinary bool) (*Stateless, error) { - t, codeMap, err := trie.BuildTrieFromWitness(blockWitness, isBinary, trace) + t, err := trie.BuildTrieFromWitness(blockWitness, isBinary, trace) if err != nil { return nil, err } + if !isBinary { if t.Hash() != stateRoot { filename := fmt.Sprintf("root_%d.txt", blockNr) @@ -64,7 +65,7 @@ func NewStateless(stateRoot common.Hash, blockWitness *trie.Witness, blockNr uin } return &Stateless{ t: t, - codeMap: codeMap, + codeUpdates: make(map[common.Hash][]byte), storageUpdates: make(map[common.Hash]map[common.Hash][]byte), accountUpdates: make(map[common.Hash]*accounts.Account), deleted: make(map[common.Hash]struct{}), @@ -113,7 +114,6 @@ func (s *Stateless) ReadAccountStorage(address common.Address, incarnation uint6 } // ReadAccountCode is a part of the StateReader interface -// This implementation looks the code up in the codeMap, failing if the code is not found. func (s *Stateless) ReadAccountCode(address common.Address, codeHash common.Hash) (code []byte, err error) { if bytes.Equal(codeHash[:], emptyCodeHash) { return nil, nil @@ -121,10 +121,20 @@ func (s *Stateless) ReadAccountCode(address common.Address, codeHash common.Hash if s.trace { fmt.Printf("Getting code for %x\n", codeHash) } - if code, ok := s.codeMap[codeHash]; ok { + + addrHash, err := common.HashData(address[:]) + if err != nil { + return nil, err + } + + if code, ok := s.codeUpdates[addrHash]; ok { + return code, nil + } + + if code, ok := s.t.GetAccountCode(addrHash[:]); ok { return code, nil } - return nil, fmt.Errorf("could not find bytecode for hash %x", codeHash) + return nil, fmt.Errorf("could not find bytecode for acc: %x hash %x", address, codeHash) } // ReadAccountCodeSize is a part of the StateReader interface @@ -134,9 +144,20 @@ func (s *Stateless) ReadAccountCodeSize(address common.Address, codeHash common. if bytes.Equal(codeHash[:], emptyCodeHash) { return 0, nil } - if code, ok := s.codeMap[codeHash]; ok { + + addrHash, err := common.HashData(address[:]) + if err != nil { + return 0, err + } + + if code, ok := s.codeUpdates[addrHash]; ok { + return len(code), nil + } + + if code, ok := s.t.GetAccountCode(addrHash[:]); ok { return len(code), nil } + return 0, fmt.Errorf("could not find bytecode for hash %x", codeHash) } @@ -172,9 +193,9 @@ func (s *Stateless) DeleteAccount(_ context.Context, address common.Address, ori // UpdateAccountCode is a part of the StateWriter interface // This implementation adds the code to the codeMap to make it available for further accesses func (s *Stateless) UpdateAccountCode(addrHash common.Hash, incarnation uint64, codeHash common.Hash, code []byte) error { - if _, ok := s.codeMap[codeHash]; !ok { - s.codeMap[codeHash] = code - } + + s.codeUpdates[codeHash] = code + if s.trace { fmt.Printf("Stateless: UpdateAccountCode %x codeHash %x\n", addrHash, codeHash) } @@ -311,7 +332,3 @@ func (s *Stateless) CheckRoot(expected common.Hash) error { func (s *Stateless) GetTrie() *trie.Trie { return s.t } - -func (s *Stateless) GetCodeMap() map[common.Hash][]byte { - return s.codeMap -} diff --git a/eth/handler.go b/eth/handler.go index b7c4e7038a..0b2c038fea 100644 --- a/eth/handler.go +++ b/eth/handler.go @@ -712,8 +712,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error { } // Now attempt to get the byte code - var zeroAddress common.Address - code, err := tds.ReadAccountCode(zeroAddress, hash) + code, err := tds.ReadCodeByHash(hash) if err == nil { data = append(data, code) bytes += len(code) diff --git a/trie/debug.go b/trie/debug.go index b9fb404cf7..d3ae43e2d4 100644 --- a/trie/debug.go +++ b/trie/debug.go @@ -252,6 +252,13 @@ func (n valueNode) print(w io.Writer) { fmt.Fprintf(w, "v(%x)", []byte(n)) } +func (n codeNode) fstring(ind string) string { + return fmt.Sprintf("code: %x ", []byte(n)) +} +func (n codeNode) print(w io.Writer) { + fmt.Fprintf(w, "code(%x)", []byte(n)) +} + func (an accountNode) fstring(ind string) string { encodedAccount := pool.GetBuffer(an.EncodingLengthForHashing()) an.EncodeForHashing(encodedAccount.B) diff --git a/trie/hashbuilder.go b/trie/hashbuilder.go index 55311b49e6..4929e4396a 100644 --- a/trie/hashbuilder.go +++ b/trie/hashbuilder.go @@ -1,6 +1,7 @@ package trie import ( + "bytes" "fmt" "io" "math/big" @@ -210,14 +211,26 @@ func (hb *HashBuilder) accountLeaf(length int, keyHex []byte, storageSize uint64 } popped++ } + var accountCode codeNode if fieldSet&uint32(8) != 0 { copy(hb.acc.CodeHash[:], hb.hashStack[len(hb.hashStack)-popped*hashStackStride-common.HashLength:len(hb.hashStack)-popped*hashStackStride]) + ok := false + if !bytes.Equal(hb.acc.CodeHash[:], EmptyCodeHash[:]) { + stackTop := hb.nodeStack[len(hb.nodeStack)-popped-1] + if stackTop != nil { // if we don't have any stack top it might be okay because we didn't resolve the code yet (stateful resolver) + // but if we have something on top of the stack that isn't `nil`, it has to be a codeNode + accountCode, ok = stackTop.(codeNode) + if !ok { + return fmt.Errorf("unexpected node type on the node stack, wanted codeNode, got %t:%s", stackTop, stackTop) + } + } + } popped++ } var accCopy accounts.Account accCopy.Copy(&hb.acc) - s := &shortNode{Key: common.CopyBytes(key), Val: &accountNode{accCopy, root, true}} + s := &shortNode{Key: common.CopyBytes(key), Val: &accountNode{accCopy, root, true, accountCode}} // this invocation will take care of the popping given number of items from both hash stack and node stack, // pushing resulting hash to the hash stack, and nil to the node stack if err = hb.accountLeafHashWithKey(key, popped); err != nil { @@ -520,23 +533,23 @@ func (hb *HashBuilder) hash(hash []byte) error { return nil } -func (hb *HashBuilder) code(code []byte) (common.Hash, error) { +func (hb *HashBuilder) code(code []byte) error { if hb.trace { fmt.Printf("CODE\n") } codeCopy := common.CopyBytes(code) - hb.nodeStack = append(hb.nodeStack, nil) + hb.nodeStack = append(hb.nodeStack, codeNode(codeCopy)) hb.sha.Reset() if _, err := hb.sha.Write(codeCopy); err != nil { - return common.Hash{}, err + return err } var hash [hashStackStride]byte // RLP representation of hash (or un-hashes value) hash[0] = 0x80 + common.HashLength if _, err := hb.sha.Read(hash[1:]); err != nil { - return common.Hash{}, err + return err } hb.hashStack = append(hb.hashStack, hash[:]...) - return common.BytesToHash(hash[1:]), nil + return nil } func (hb *HashBuilder) emptyRoot() { diff --git a/trie/node.go b/trie/node.go index 660863d24e..314d264411 100644 --- a/trie/node.go +++ b/trie/node.go @@ -62,7 +62,10 @@ type ( accounts.Account storage node rootCorrect bool + code codeNode } + + codeNode []byte ) // nilValueNode is used when collapsing internal trie nodes for hashing, since @@ -211,6 +214,7 @@ type nodeRef struct { func (n hashNode) reference() []byte { return n } func (n valueNode) reference() []byte { return nil } +func (n codeNode) reference() []byte { return nil } func (n *fullNode) reference() []byte { return n.ref.data[0:n.ref.len] } func (n *duoNode) reference() []byte { return n.ref.data[0:n.ref.len] } func (n *shortNode) reference() []byte { return n.ref.data[0:n.ref.len] } @@ -222,4 +226,5 @@ func (n duoNode) String() string { return n.fstring("") } func (n shortNode) String() string { return n.fstring("") } func (n hashNode) String() string { return n.fstring("") } func (n valueNode) String() string { return n.fstring("") } +func (n codeNode) String() string { return n.fstring("") } func (an accountNode) String() string { return an.fstring("") } diff --git a/trie/resolve_set.go b/trie/resolve_set.go index 655bf2d6db..2da87c149c 100644 --- a/trie/resolve_set.go +++ b/trie/resolve_set.go @@ -19,6 +19,8 @@ package trie import ( "bytes" "sort" + + "github.com/ledgerwatch/turbo-geth/common" ) // ResolveSet encapsulates the set of keys that are required to be fully available, or resolved @@ -26,20 +28,21 @@ import ( // pairs // DESCRIBED: docs/programmers_guide/guide.md#converting-sequence-of-keys-and-value-into-a-multiproof type ResolveSet struct { - inited bool // Whether keys are sorted and "LTE" and "GT" indices set - binary bool // if true, use binary encoding instead of Hex - minLength int // Mininum length of prefixes for which `HashOnly` function can return `true` - lteIndex int // Index of the "LTE" key in the keys slice. Next one is "GT" - hexes sortable + inited bool // Whether keys are sorted and "LTE" and "GT" indices set + binary bool // if true, use binary encoding instead of Hex + minLength int // Mininum length of prefixes for which `HashOnly` function can return `true` + lteIndex int // Index of the "LTE" key in the keys slice. Next one is "GT" + hexes sortable + codeTouches map[common.Hash]struct{} } // NewResolveSet creates new ResolveSet func NewResolveSet(minLength int) *ResolveSet { - return &ResolveSet{minLength: minLength} + return &ResolveSet{minLength: minLength, codeTouches: make(map[common.Hash]struct{})} } func NewBinaryResolveSet(minLength int) *ResolveSet { - return &ResolveSet{minLength: minLength, binary: true} + return &ResolveSet{minLength: minLength, codeTouches: make(map[common.Hash]struct{}), binary: true} } // AddKey adds a new key (in KEY encoding) to the set @@ -57,6 +60,16 @@ func (rs *ResolveSet) AddHex(hex []byte) { } } +// AddCodeTouch adds a new code touch into the resolve set +func (rs *ResolveSet) AddCodeTouch(codeHash common.Hash) { + rs.codeTouches[codeHash] = struct{}{} +} + +func (rs *ResolveSet) IsCodeTouched(codeHash common.Hash) bool { + _, ok := rs.codeTouches[codeHash] + return ok +} + func (rs *ResolveSet) ensureInited() { if rs.inited { return diff --git a/trie/resolve_set_builder.go b/trie/resolve_set_builder.go index 21f5e8a838..6e30530e89 100644 --- a/trie/resolve_set_builder.go +++ b/trie/resolve_set_builder.go @@ -6,17 +6,17 @@ import "github.com/ledgerwatch/turbo-geth/common" // the execution of a block. It also tracks the contract codes that were created and used during the execution // of a block type ResolveSetBuilder struct { - touches [][]byte // Read/change set of account keys (account hashes) - storageTouches [][]byte // Read/change set of storage keys (account hashes concatenated with storage key hashes) - proofCodes map[common.Hash][]byte // Contract codes that have been accessed - createdCodes map[common.Hash][]byte // Contract codes that were created (deployed) + touches [][]byte // Read/change set of account keys (account hashes) + storageTouches [][]byte // Read/change set of storage keys (account hashes concatenated with storage key hashes) + proofCodes map[common.Hash]struct{} // Contract codes that have been accessed (codeHash) + createdCodes map[common.Hash]struct{} // Contract codes that were created (deployed) (codeHash) } // NewResolveSetBuilder creates new ProofGenerator and initialised its maps func NewResolveSetBuilder() *ResolveSetBuilder { return &ResolveSetBuilder{ - proofCodes: make(map[common.Hash][]byte), - createdCodes: make(map[common.Hash][]byte), + proofCodes: make(map[common.Hash]struct{}), + createdCodes: make(map[common.Hash]struct{}), } } @@ -39,30 +39,30 @@ func (pg *ResolveSetBuilder) ExtractTouches() ([][]byte, [][]byte) { return touches, storageTouches } -// extractCodeMap returns the map of all contract codes that were required during the block's execution -// but were not created during that same block. It also clears the maps for the next block's execution -func (pg *ResolveSetBuilder) extractCodeMap() map[common.Hash][]byte { +// extractCodeTouches returns the set of all contract codes that were required during the block's execution +// but were not created during that same block. It also clears the set for the next block's execution +func (pg *ResolveSetBuilder) extractCodeTouches() map[common.Hash]struct{} { proofCodes := pg.proofCodes - pg.proofCodes = make(map[common.Hash][]byte) - pg.createdCodes = make(map[common.Hash][]byte) + pg.proofCodes = make(map[common.Hash]struct{}) + pg.createdCodes = make(map[common.Hash]struct{}) return proofCodes } // ReadCode registers that given contract code has been accessed during current block's execution -func (pg *ResolveSetBuilder) ReadCode(codeHash common.Hash, code []byte) { - if _, ok := pg.createdCodes[codeHash]; !ok { - pg.proofCodes[codeHash] = code +func (pg *ResolveSetBuilder) ReadCode(codeHash common.Hash) { + if _, ok := pg.proofCodes[codeHash]; !ok { + pg.proofCodes[codeHash] = struct{}{} } } // CreateCode registers that given contract code has been created (deployed) during current block's execution -func (pg *ResolveSetBuilder) CreateCode(codeHash common.Hash, code []byte) { +func (pg *ResolveSetBuilder) CreateCode(codeHash common.Hash) { if _, ok := pg.proofCodes[codeHash]; !ok { - pg.createdCodes[codeHash] = code + pg.createdCodes[codeHash] = struct{}{} } } -func (pg *ResolveSetBuilder) Build(isBinary bool) (*ResolveSet, CodeMap) { +func (pg *ResolveSetBuilder) Build(isBinary bool) *ResolveSet { var rs *ResolveSet if isBinary { rs = NewBinaryResolveSet(0) @@ -71,6 +71,7 @@ func (pg *ResolveSetBuilder) Build(isBinary bool) (*ResolveSet, CodeMap) { } touches, storageTouches := pg.ExtractTouches() + codeTouches := pg.extractCodeTouches() for _, touch := range touches { rs.AddKey(touch) @@ -78,6 +79,9 @@ func (pg *ResolveSetBuilder) Build(isBinary bool) (*ResolveSet, CodeMap) { for _, touch := range storageTouches { rs.AddKey(touch) } - codeMap := pg.extractCodeMap() - return rs, codeMap + for codeHash, _ := range codeTouches { + rs.AddCodeTouch(codeHash) + } + + return rs } diff --git a/trie/resolver.go b/trie/resolver.go index 1ca198f571..92715e45ef 100644 --- a/trie/resolver.go +++ b/trie/resolver.go @@ -13,6 +13,8 @@ import ( var emptyHash [32]byte +type ResolveFunc func(*Resolver) error + func (t *Trie) Rebuild(db ethdb.Database, blockNr uint64) error { if t.root == nil { return nil @@ -38,15 +40,17 @@ type Resolver struct { blockNr uint64 topLevels int // How many top levels of the trie to keep (not roll into hashes) requests []*ResolveRequest + codeRequests []*ResolveRequestForCode witnesses []*Witness // list of witnesses for resolved subtries, nil if `collectWitnesses` is false } func NewResolver(topLevels int, forAccounts bool, blockNr uint64) *Resolver { tr := Resolver{ - accounts: forAccounts, - requests: []*ResolveRequest{}, - blockNr: blockNr, - topLevels: topLevels, + accounts: forAccounts, + requests: []*ResolveRequest{}, + codeRequests: []*ResolveRequestForCode{}, + blockNr: blockNr, + topLevels: topLevels, } return &tr } @@ -56,6 +60,7 @@ func (tr *Resolver) Reset(topLevels int, forAccounts bool, blockNr uint64) { tr.accounts = forAccounts tr.blockNr = blockNr tr.requests = tr.requests[:0] + tr.codeRequests = tr.codeRequests[:0] tr.witnesses = nil tr.collectWitnesses = false tr.historical = false @@ -76,6 +81,11 @@ func (tr *Resolver) SetHistorical(h bool) { tr.historical = h } +// AddCodeRequest add a request for code resolution +func (tr *Resolver) AddCodeRequest(req *ResolveRequestForCode) { + tr.codeRequests = append(tr.codeRequests, req) +} + // Resolver implements sort.Interface // and sorts by resolve requests // (more general requests come first) @@ -152,7 +162,10 @@ func (tr *Resolver) ResolveStateful(db ethdb.Database, blockNr uint64) error { sort.Stable(tr) resolver := NewResolverStateful(tr.topLevels, tr.requests, hf) - return resolver.RebuildTrie(db, blockNr, tr.accounts, tr.historical) + if err := resolver.RebuildTrie(db, blockNr, tr.accounts, tr.historical); err != nil { + return err + } + return resolver.AttachRequestedCode(db, tr.codeRequests) } func (tr *Resolver) ResolveStatefulCached(db ethdb.Database, blockNr uint64) error { @@ -166,7 +179,10 @@ func (tr *Resolver) ResolveStatefulCached(db ethdb.Database, blockNr uint64) err sort.Stable(tr) resolver := NewResolverStatefulCached(tr.topLevels, tr.requests, hf) - return resolver.RebuildTrie(db, blockNr, tr.accounts, tr.historical) + if err := resolver.RebuildTrie(db, blockNr, tr.accounts, tr.historical); err != nil { + return err + } + return resolver.AttachRequestedCode(db, tr.codeRequests) } // ResolveStateless resolves and hooks subtries using a witnesses database instead of @@ -174,6 +190,7 @@ func (tr *Resolver) ResolveStatefulCached(db ethdb.Database, blockNr uint64) err func (tr *Resolver) ResolveStateless(db WitnessStorage, blockNr uint64, trieLimit uint32, startPos int64) (int64, error) { sort.Stable(tr) resolver := NewResolverStateless(tr.requests, hookSubtrie) + // we expect CodeNodes to be already attached to the trie in stateless resolution return resolver.RebuildTrie(db, blockNr, trieLimit, startPos) } @@ -212,7 +229,7 @@ func (tr *Resolver) extractWitnessAndHookSubtrie(currentReq *ResolveRequest, hbR tr.witnesses = make([]*Witness, 0) } - witness, err := extractWitnessFromRootNode(hbRoot, tr.blockNr, false /*tr.hb.trace*/, nil, nil) + witness, err := extractWitnessFromRootNode(hbRoot, tr.blockNr, false /*tr.hb.trace*/, nil) if err != nil { return fmt.Errorf("error while extracting witness for resolver: %w", err) } diff --git a/trie/resolver_stateful.go b/trie/resolver_stateful.go index d282c88d56..f5bd455f18 100644 --- a/trie/resolver_stateful.go +++ b/trie/resolver_stateful.go @@ -186,6 +186,20 @@ func (tr *ResolverStateful) RebuildTrie( return nil } +func (tr *ResolverStateful) AttachRequestedCode(db ethdb.Getter, requests []*ResolveRequestForCode) error { + for _, req := range requests { + codeHash := req.codeHash + code, err := db.Get(dbutils.CodeBucket, codeHash[:]) + if err != nil { + return err + } + if err := req.t.UpdateAccountCode(req.addrHash[:], codeNode(code)); err != nil { + return err + } + } + return nil +} + func (tr *ResolverStateful) WalkerAccounts(keyIdx int, k []byte, v []byte) error { return tr.Walker(true, keyIdx, k, v) } diff --git a/trie/resolver_stateless.go b/trie/resolver_stateless.go index 6f43bc8659..5fa5c18031 100644 --- a/trie/resolver_stateless.go +++ b/trie/resolver_stateless.go @@ -42,7 +42,7 @@ func (r *ResolverStateless) RebuildTrie(db WitnessStorage, blockNr uint64, trieL return 0, err } - trie, _, err := BuildTrieFromWitness(witness, false /*is-binary*/, false /*trace*/) + trie, err := BuildTrieFromWitness(witness, false /*is-binary*/, false /*trace*/) if err != nil { return 0, err } diff --git a/trie/resolver_stateless_test.go b/trie/resolver_stateless_test.go index 0e5e81392c..c04882f43e 100644 --- a/trie/resolver_stateless_test.go +++ b/trie/resolver_stateless_test.go @@ -35,17 +35,17 @@ func TestRebuildTrie(t *testing.T) { trie2 := buildTestTrie(10) trie3 := buildTestTrie(100) - w1, err := extractWitnessFromRootNode(trie1.root, 1, false, nil, nil) + w1, err := extractWitnessFromRootNode(trie1.root, 1, false, nil) if err != nil { t.Error(err) } - w2, err := extractWitnessFromRootNode(trie2.root, 1, false, nil, nil) + w2, err := extractWitnessFromRootNode(trie2.root, 1, false, nil) if err != nil { t.Error(err) } - w3, err := extractWitnessFromRootNode(trie3.root, 1, false, nil, nil) + w3, err := extractWitnessFromRootNode(trie3.root, 1, false, nil) if err != nil { t.Error(err) } diff --git a/trie/trie.go b/trie/trie.go index 7a1f2c87af..95275e9b44 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -26,7 +26,9 @@ import ( "github.com/ledgerwatch/turbo-geth/common/debug" "github.com/ledgerwatch/turbo-geth/core/types/accounts" "github.com/ledgerwatch/turbo-geth/crypto" + "github.com/ledgerwatch/turbo-geth/ethdb" "github.com/ledgerwatch/turbo-geth/log" + "github.com/pkg/errors" ) var ( @@ -131,6 +133,31 @@ func (t *Trie) GetAccount(key []byte) (value *accounts.Account, gotValue bool) { return nil, gotValue } +func (t *Trie) GetAccountCode(key []byte) (value []byte, gotValue bool) { + if t.root == nil { + return nil, true + } + + hex := keybytesToHex(key) + if t.binary { + hex = keyHexToBin(hex) + } + + accNode, gotValue := t.getAccount(t.root, hex, 0) + if accNode != nil { + if bytes.Equal(accNode.Account.CodeHash[:], EmptyCodeHash[:]) { + return nil, gotValue + } + + if accNode.code == nil { + return nil, false + } + + return accNode.code, gotValue + } + return nil, gotValue +} + func (t *Trie) getAccount(origNode node, key []byte, pos int) (value *accountNode, gotValue bool) { switch n := (origNode).(type) { case nil: @@ -247,20 +274,59 @@ func (t *Trie) UpdateAccount(key []byte, acc *accounts.Account) { if t.root == nil { var newnode node if value.Root == EmptyRoot || value.Root == (common.Hash{}) { - newnode = &shortNode{Key: hex, Val: &accountNode{*value, nil, true}} + newnode = &shortNode{Key: hex, Val: &accountNode{*value, nil, true, nil}} } else { - newnode = &shortNode{Key: hex, Val: &accountNode{*value, hashNode(value.Root[:]), true}} + newnode = &shortNode{Key: hex, Val: &accountNode{*value, hashNode(value.Root[:]), true, nil}} } t.root = newnode } else { if value.Root == EmptyRoot || value.Root == (common.Hash{}) { - _, t.root = t.insert(t.root, hex, 0, &accountNode{*value, nil, true}) + _, t.root = t.insert(t.root, hex, 0, &accountNode{*value, nil, true, nil}) } else { - _, t.root = t.insert(t.root, hex, 0, &accountNode{*value, hashNode(value.Root[:]), true}) + _, t.root = t.insert(t.root, hex, 0, &accountNode{*value, hashNode(value.Root[:]), true, nil}) } } } +// UpdateAccountCode attaches the code node to an account at specified key +func (t *Trie) UpdateAccountCode(key []byte, code codeNode) error { + if t.root == nil { + return nil + } + + hex := keybytesToHex(key) + if t.binary { + hex = keyHexToBin(hex) + } + + accNode, gotValue := t.getAccount(t.root, hex, 0) + if accNode == nil || !gotValue { + return errors.Wrapf(ethdb.ErrKeyNotFound, "account not found with key: %x", key) + } + + actualCodeHash := crypto.Keccak256(code) + if !bytes.Equal(accNode.CodeHash[:], actualCodeHash) { + return fmt.Errorf("inserted code mismatch account hash (acc.CodeHash=%x codeHash=%x)", accNode.CodeHash[:], actualCodeHash) + } + + accNode.code = code + + _, t.root = t.insert(t.root, hex, 0, accNode) + return nil +} + +// ResolveRequestFor Code expresses the need to fetch code from the DB (by its hash) and attach +// to a specific account leaf in the trie. +type ResolveRequestForCode struct { + t *Trie + addrHash common.Hash // contract address hash + codeHash common.Hash +} + +func (rr *ResolveRequestForCode) String() string { + return fmt.Sprintf("rr_code{addrHash:%x,codeHash:%x}", rr.addrHash, rr.codeHash) +} + // ResolveRequest expresses the need to fetch a subtrie from the database. The location of this // subtrie is specified by the resolveHex[:resolvePos]. The remaining part of resolveHex (if present) // is useful to ensure that specific leaves of the trie are fully expanded (and not rolled into @@ -283,6 +349,8 @@ type ResolveRequest struct { NodeRLP []byte // [OUT] RLP of the resolved node } +/* add code hash there, if nil -- don't do anything with the code */ + // NewResolveRequest creates a new ResolveRequest. // contract must be either address hash + incarnation (32+8 bytes) or nil func (t *Trie) NewResolveRequest(contract []byte, hex []byte, pos int, resolveHash []byte) *ResolveRequest { @@ -293,6 +361,23 @@ func (rr *ResolveRequest) String() string { return fmt.Sprintf("rr{t:%x,resolveHex:%x,resolvePos:%d,resolveHash:%s}", rr.contract, rr.resolveHex, rr.resolvePos, rr.resolveHash) } +func (t *Trie) NewResolveRequestForCode(addrHash common.Hash, codeHash common.Hash) *ResolveRequestForCode { + return &ResolveRequestForCode{t, addrHash, codeHash} +} + +func (t *Trie) NeedResolutonForCode(addrHash common.Hash, codeHash common.Hash) (bool, *ResolveRequestForCode) { + if bytes.Equal(codeHash[:], EmptyCodeHash[:]) { + return false, nil + } + + _, ok := t.GetAccountCode(addrHash[:]) + if !ok { + return true, t.NewResolveRequestForCode(addrHash, codeHash) + } + + return false, nil +} + // NeedResolution determines whether the trie needs to be extended (resolved) by fetching data // from the database, if one were to access the key specified // In the case of "Yes", also returns a corresponding ResolveRequest @@ -380,6 +465,9 @@ func (t *Trie) insert(origNode node, key []byte, pos int, value node) (updated b if origNok && vnok { updated = !origAccN.Equals(&vAccN.Account) if updated { + if !bytes.Equal(origAccN.CodeHash[:], vAccN.CodeHash[:]) { + origAccN.code = nil + } origAccN.Account.Copy(&vAccN.Account) origAccN.rootCorrect = false } diff --git a/trie/trie_from_witness.go b/trie/trie_from_witness.go index 184ceec652..47d1f2093c 100644 --- a/trie/trie_from_witness.go +++ b/trie/trie_from_witness.go @@ -4,12 +4,10 @@ import ( "fmt" "math/big" - "github.com/ledgerwatch/turbo-geth/common" "github.com/ledgerwatch/turbo-geth/trie/rlphacks" ) -func BuildTrieFromWitness(witness *Witness, isBinary bool, trace bool) (*Trie, CodeMap, error) { - codeMap := make(map[common.Hash][]byte) +func BuildTrieFromWitness(witness *Witness, isBinary bool, trace bool) (*Trie, error) { hb := NewHashBuilder(false) for _, operator := range witness.Operators { switch op := operator.(type) { @@ -20,38 +18,36 @@ func BuildTrieFromWitness(witness *Witness, isBinary bool, trace bool) (*Trie, C keyHex := op.Key val := op.Value if err := hb.leaf(len(op.Key), keyHex, rlphacks.RlpSerializableBytes(val)); err != nil { - return nil, nil, err + return nil, err } case *OperatorExtension: if trace { fmt.Printf("EXTENSION ") } if err := hb.extension(op.Key); err != nil { - return nil, nil, err + return nil, err } case *OperatorBranch: if trace { fmt.Printf("BRANCH ") } if err := hb.branch(uint16(op.Mask)); err != nil { - return nil, nil, err + return nil, err } case *OperatorHash: if trace { fmt.Printf("HASH ") } if err := hb.hash(op.Hash[:]); err != nil { - return nil, nil, err + return nil, err } case *OperatorCode: if trace { fmt.Printf("CODE ") } - if codeHash, err := hb.code(op.Code); err == nil { - codeMap[codeHash] = op.Code - } else { - return nil, nil, err + if err := hb.code(op.Code); err != nil { + return nil, err } case *OperatorLeafAccount: @@ -74,7 +70,7 @@ func BuildTrieFromWitness(witness *Witness, isBinary bool, trace bool) (*Trie, C incarnaton := uint64(0) if err := hb.accountLeaf(len(op.Key), op.Key, 0, balance, nonce, incarnaton, fieldSet); err != nil { - return nil, nil, err + return nil, err } case *OperatorEmptyRoot: if trace { @@ -82,7 +78,7 @@ func BuildTrieFromWitness(witness *Witness, isBinary bool, trace bool) (*Trie, C } hb.emptyRoot() default: - return nil, nil, fmt.Errorf("unknown operand type: %T", operator) + return nil, fmt.Errorf("unknown operand type: %T", operator) } } if trace { @@ -90,9 +86,9 @@ func BuildTrieFromWitness(witness *Witness, isBinary bool, trace bool) (*Trie, C } if !hb.hasRoot() { if isBinary { - return NewBinary(EmptyRoot), nil, nil + return NewBinary(EmptyRoot), nil } - return New(EmptyRoot), nil, nil + return New(EmptyRoot), nil } r := hb.root() var tr *Trie @@ -102,5 +98,5 @@ func BuildTrieFromWitness(witness *Witness, isBinary bool, trace bool) (*Trie, C tr = New(hb.rootHash()) } tr.root = r - return tr, codeMap, nil + return tr, nil } diff --git a/trie/trie_test.go b/trie/trie_test.go index c7f44f01b7..5066b925a0 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -566,3 +566,327 @@ func TestHashMapLeak(t *testing.T) { assert.GreaterOrEqual(t, nHashes, nExpected*7/8) assert.LessOrEqual(t, nHashes, nExpected*9/8) } + +func genRandomByteArrayOfLen(length uint) []byte { + array := make([]byte, length) + for i := uint(0); i < length; i++ { + array[i] = byte(rand.Intn(256)) + } + return array +} + +func getAddressForIndex(index int) [20]byte { + var address [20]byte + binary.BigEndian.PutUint32(address[:], uint32(index)) + return address +} + +func TestCodeNodeValid(t *testing.T) { + trie := newEmpty() + + random := rand.New(rand.NewSource(0)) + + numberOfAccounts := 20 + + addresses := make([][20]byte, numberOfAccounts) + for i := 0; i < len(addresses); i++ { + addresses[i] = getAddressForIndex(i) + } + codeValues := make([][]byte, len(addresses)) + for i := 0; i < len(addresses); i++ { + codeValues[i] = genRandomByteArrayOfLen(128) + codeHash := common.BytesToHash(crypto.Keccak256(codeValues[i])) + balance := new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil)) + acc := accounts.NewAccount() + acc.Nonce = uint64(random.Int63()) + acc.Balance = *balance + acc.Root = EmptyRoot + acc.CodeHash = codeHash + + trie.UpdateAccount(crypto.Keccak256(addresses[i][:]), &acc) + err := trie.UpdateAccountCode(crypto.Keccak256(addresses[i][:]), codeValues[i]) + assert.Nil(t, err, "should successfully insert code") + } + + for i := 0; i < len(addresses); i++ { + value, gotValue := trie.GetAccountCode(crypto.Keccak256(addresses[i][:])) + assert.True(t, gotValue, "should receive code value") + assert.True(t, bytes.Equal(value, codeValues[i]), "should receive the right code") + } +} + +func TestCodeNodeUpdateNotExisting(t *testing.T) { + trie := newEmpty() + + random := rand.New(rand.NewSource(0)) + + address := getAddressForIndex(0) + codeValue := genRandomByteArrayOfLen(128) + + codeHash := common.BytesToHash(crypto.Keccak256(codeValue)) + balance := new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil)) + + acc := accounts.NewAccount() + acc.Nonce = uint64(random.Int63()) + acc.Balance = *balance + acc.Root = EmptyRoot + acc.CodeHash = codeHash + + trie.UpdateAccount(crypto.Keccak256(address[:]), &acc) + err := trie.UpdateAccountCode(crypto.Keccak256(address[:]), codeValue) + assert.Nil(t, err, "should successfully insert code") + + nonExistingAddress := getAddressForIndex(9999) + codeValue2 := genRandomByteArrayOfLen(128) + + err = trie.UpdateAccountCode(crypto.Keccak256(nonExistingAddress[:]), codeValue2) + assert.Error(t, err, "should return an error for non existing acc") +} + +func TestCodeNodeGetNotExistingAccount(t *testing.T) { + trie := newEmpty() + + random := rand.New(rand.NewSource(0)) + + address := getAddressForIndex(0) + codeValue := genRandomByteArrayOfLen(128) + + codeHash := common.BytesToHash(crypto.Keccak256(codeValue)) + balance := new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil)) + + acc := accounts.NewAccount() + acc.Nonce = uint64(random.Int63()) + acc.Balance = *balance + acc.Root = EmptyRoot + acc.CodeHash = codeHash + + trie.UpdateAccount(crypto.Keccak256(address[:]), &acc) + err := trie.UpdateAccountCode(crypto.Keccak256(address[:]), codeValue) + assert.Nil(t, err, "should successfully insert code") + + nonExistingAddress := getAddressForIndex(9999) + + value, gotValue := trie.GetAccountCode(crypto.Keccak256(nonExistingAddress[:])) + assert.True(t, gotValue, "should indicate that account doesn't exist at all (not just hashed)") + assert.Nil(t, value, "the value should be nil") +} + +func TestCodeNodeGetHashedAccount(t *testing.T) { + trie := newEmpty() + + address := getAddressForIndex(0) + + fakeAccount := genRandomByteArrayOfLen(50) + fakeAccountHash := common.BytesToHash(crypto.Keccak256(fakeAccount)) + + hex := keybytesToHex(crypto.Keccak256(address[:])) + + _, trie.root = trie.insert(trie.root, hex, 0, hashNode(fakeAccountHash[:])) + + value, gotValue := trie.GetAccountCode(crypto.Keccak256(address[:])) + assert.False(t, gotValue, "should indicate that account exists but hashed") + assert.Nil(t, value, "the value should be nil") +} + +func TestCodeNodeGetExistingAccountNoCodeNotEmpty(t *testing.T) { + trie := newEmpty() + + random := rand.New(rand.NewSource(0)) + + address := getAddressForIndex(0) + codeValue := genRandomByteArrayOfLen(128) + + codeHash := common.BytesToHash(crypto.Keccak256(codeValue)) + balance := new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil)) + + acc := accounts.NewAccount() + acc.Nonce = uint64(random.Int63()) + acc.Balance = *balance + acc.Root = EmptyRoot + acc.CodeHash = codeHash + + trie.UpdateAccount(crypto.Keccak256(address[:]), &acc) + + value, gotValue := trie.GetAccountCode(crypto.Keccak256(address[:])) + assert.False(t, gotValue, "should indicate that account exists with code but the code isn't in cache") + assert.Nil(t, value, "the value should be nil") +} + +func TestCodeNodeGetExistingAccountEmptyCode(t *testing.T) { + trie := newEmpty() + + random := rand.New(rand.NewSource(0)) + + address := getAddressForIndex(0) + + codeHash := EmptyCodeHash + balance := new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil)) + + acc := accounts.NewAccount() + acc.Nonce = uint64(random.Int63()) + acc.Balance = *balance + acc.Root = EmptyRoot + acc.CodeHash = codeHash + + trie.UpdateAccount(crypto.Keccak256(address[:]), &acc) + + value, gotValue := trie.GetAccountCode(crypto.Keccak256(address[:])) + assert.True(t, gotValue, "should indicate that account exists with empty code") + assert.Nil(t, value, "the value should be nil") +} + +func TestCodeNodeWrongHash(t *testing.T) { + trie := newEmpty() + + random := rand.New(rand.NewSource(0)) + + address := getAddressForIndex(0) + + codeValue1 := genRandomByteArrayOfLen(128) + codeHash1 := common.BytesToHash(crypto.Keccak256(codeValue1)) + + balance := new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil)) + + acc := accounts.NewAccount() + acc.Nonce = uint64(random.Int63()) + acc.Balance = *balance + acc.Root = EmptyRoot + acc.CodeHash = codeHash1 + + trie.UpdateAccount(crypto.Keccak256(address[:]), &acc) + + codeValue2 := genRandomByteArrayOfLen(128) + err := trie.UpdateAccountCode(crypto.Keccak256(address[:]), codeValue2) + assert.Error(t, err, "should NOT be able to insert code with wrong hash") +} + +func TestCodeNodeUpdateAccountAndCodeValidHash(t *testing.T) { + trie := newEmpty() + + random := rand.New(rand.NewSource(0)) + + address := getAddressForIndex(0) + + codeValue1 := genRandomByteArrayOfLen(128) + codeHash1 := common.BytesToHash(crypto.Keccak256(codeValue1)) + + balance := new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil)) + + acc := accounts.NewAccount() + acc.Nonce = uint64(random.Int63()) + acc.Balance = *balance + acc.Root = EmptyRoot + acc.CodeHash = codeHash1 + + trie.UpdateAccount(crypto.Keccak256(address[:]), &acc) + err := trie.UpdateAccountCode(crypto.Keccak256(address[:]), codeValue1) + assert.Nil(t, err, "should successfully insert code") + + codeValue2 := genRandomByteArrayOfLen(128) + codeHash2 := common.BytesToHash(crypto.Keccak256(codeValue2)) + + acc.CodeHash = codeHash2 + + trie.UpdateAccount(crypto.Keccak256(address[:]), &acc) + err = trie.UpdateAccountCode(crypto.Keccak256(address[:]), codeValue2) + assert.Nil(t, err, "should successfully insert code") +} + +func TestCodeNodeUpdateAccountAndCodeInvalidHash(t *testing.T) { + trie := newEmpty() + + random := rand.New(rand.NewSource(0)) + + address := getAddressForIndex(0) + + codeValue1 := genRandomByteArrayOfLen(128) + codeHash1 := common.BytesToHash(crypto.Keccak256(codeValue1)) + + balance := new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil)) + + acc := accounts.NewAccount() + acc.Nonce = uint64(random.Int63()) + acc.Balance = *balance + acc.Root = EmptyRoot + acc.CodeHash = codeHash1 + + trie.UpdateAccount(crypto.Keccak256(address[:]), &acc) + err := trie.UpdateAccountCode(crypto.Keccak256(address[:]), codeValue1) + assert.Nil(t, err, "should successfully insert code") + + codeValue2 := genRandomByteArrayOfLen(128) + codeHash2 := common.BytesToHash(crypto.Keccak256(codeValue2)) + + codeValue3 := genRandomByteArrayOfLen(128) + + acc.CodeHash = codeHash2 + + trie.UpdateAccount(crypto.Keccak256(address[:]), &acc) + err = trie.UpdateAccountCode(crypto.Keccak256(address[:]), codeValue3) + assert.Error(t, err, "should NOT be able to insert code with wrong hash") +} + +func TestCodeNodeUpdateAccountChangeCodeHash(t *testing.T) { + trie := newEmpty() + + random := rand.New(rand.NewSource(0)) + + address := getAddressForIndex(0) + + codeValue1 := genRandomByteArrayOfLen(128) + codeHash1 := common.BytesToHash(crypto.Keccak256(codeValue1)) + + balance := new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil)) + + acc := accounts.NewAccount() + acc.Nonce = uint64(random.Int63()) + acc.Balance = *balance + acc.Root = EmptyRoot + acc.CodeHash = codeHash1 + + trie.UpdateAccount(crypto.Keccak256(address[:]), &acc) + err := trie.UpdateAccountCode(crypto.Keccak256(address[:]), codeValue1) + assert.Nil(t, err, "should successfully insert code") + + codeValue2 := genRandomByteArrayOfLen(128) + codeHash2 := common.BytesToHash(crypto.Keccak256(codeValue2)) + + acc.CodeHash = codeHash2 + + trie.UpdateAccount(crypto.Keccak256(address[:]), &acc) + value, gotValue := trie.GetAccountCode(crypto.Keccak256(address[:])) + assert.Nil(t, value, "the value should reset after the code change happen") + assert.False(t, gotValue, "should indicate that the code isn't in the cache") +} + +func TestCodeNodeUpdateAccountNoChangeCodeHash(t *testing.T) { + trie := newEmpty() + + random := rand.New(rand.NewSource(0)) + + address := getAddressForIndex(0) + + codeValue1 := genRandomByteArrayOfLen(128) + codeHash1 := common.BytesToHash(crypto.Keccak256(codeValue1)) + + balance := new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil)) + + acc := accounts.NewAccount() + acc.Nonce = uint64(random.Int63()) + acc.Balance = *balance + acc.Root = EmptyRoot + acc.CodeHash = codeHash1 + + trie.UpdateAccount(crypto.Keccak256(address[:]), &acc) + err := trie.UpdateAccountCode(crypto.Keccak256(address[:]), codeValue1) + assert.Nil(t, err, "should successfully insert code") + + acc.Nonce = uint64(random.Int63()) + balance = new(big.Int).Rand(random, new(big.Int).Exp(common.Big2, common.Big256, nil)) + acc.Balance = *balance + + trie.UpdateAccount(crypto.Keccak256(address[:]), &acc) + value, gotValue := trie.GetAccountCode(crypto.Keccak256(address[:])) + assert.Equal(t, value, codeValue1, "the value should NOT reset after account's non codehash had changed") + assert.True(t, gotValue, "should indicate that the code is still in the cache") +} diff --git a/trie/trie_transform.go b/trie/trie_transform.go index e8d71be077..d89929ea9f 100644 --- a/trie/trie_transform.go +++ b/trie/trie_transform.go @@ -18,7 +18,12 @@ func transformSubTrie(nd node, hex []byte, newTrie *Trie, transformFunc keyTrans case *accountNode: accountCopy := accounts.NewAccount() accountCopy.Copy(&n.Account) - _, newTrie.root = newTrie.insert(newTrie.root, transformFunc(hex), 0, &accountNode{accountCopy, nil, true}) + var code []byte = nil + if n.code != nil { + code = make([]byte, len(n.code)) + copy(code, n.code) + } + _, newTrie.root = newTrie.insert(newTrie.root, transformFunc(hex), 0, &accountNode{accountCopy, nil, true, codeNode(code)}) aHex := hex if aHex[len(aHex)-1] == 16 { aHex = aHex[:len(aHex)-1] diff --git a/trie/trie_witness.go b/trie/trie_witness.go index 43eb2d6fea..520b107f30 100644 --- a/trie/trie_witness.go +++ b/trie/trie_witness.go @@ -1,14 +1,14 @@ package trie -func (t *Trie) ExtractWitness(blockNr uint64, trace bool, rs *ResolveSet, codeMap CodeMap) (*Witness, error) { - return extractWitnessFromRootNode(t.root, blockNr, trace, rs, codeMap) +func (t *Trie) ExtractWitness(blockNr uint64, trace bool, rs *ResolveSet) (*Witness, error) { + return extractWitnessFromRootNode(t.root, blockNr, trace, rs) } // extractWitnessFromRootNode extracts a witness for a subtrie starting from the specified root // if hashOnly param is nil it will make a witness for the full subtrie, // if hashOnly param is set to a ResolveSet instance, it will make a witness for only the accounts/storages that were actually touched; other paths will be hashed. -func extractWitnessFromRootNode(root node, blockNr uint64, trace bool, hashOnly HashOnly, codeMap CodeMap) (*Witness, error) { - builder := NewWitnessBuilder(root, blockNr, trace, codeMap) +func extractWitnessFromRootNode(root node, blockNr uint64, trace bool, hashOnly HashOnly) (*Witness, error) { + builder := NewWitnessBuilder(root, blockNr, trace) var limiter *MerklePathLimiter if hashOnly != nil { hr := newHasher(false) diff --git a/trie/visual.go b/trie/visual.go index e08a30a350..862ab2212b 100644 --- a/trie/visual.go +++ b/trie/visual.go @@ -30,16 +30,15 @@ import ( // VisualOpts contains various configuration options fo the Visual function // It has been introduced as a replacement for too many arguments with options type VisualOpts struct { - Highlights [][]byte // Collection of keys, in the HEX encoding, that need to be highlighted with digits - IndexColors []string // Array of colors for representing digits as colored boxes - FontColors []string // Array of colors, the same length as indexColors, for the textual digits inside the coloured boxes - CutTerminals int // Specifies how many digits to cut from the terminal short node keys for a more convinient display - Values bool // Whether to display value nodes (as box with rounded corners) - CodeCompressed bool // Whether to turn the code from a large rectange to a small square for a more convinient display - ValCompressed bool // Whether long values (over 10 characters) are shortened using ... in the middle - ValHex bool // Whether values should be displayed as hex numbers (otherwise they are displayed as just strings) - SameLevel bool // Whether the leaves (and hashes) need to be on the same horizontal level - CodeMap map[common.Hash][]byte // Map that allows looking up bytecode of contracts by the bytecode's hash + Highlights [][]byte // Collection of keys, in the HEX encoding, that need to be highlighted with digits + IndexColors []string // Array of colors for representing digits as colored boxes + FontColors []string // Array of colors, the same length as indexColors, for the textual digits inside the coloured boxes + CutTerminals int // Specifies how many digits to cut from the terminal short node keys for a more convinient display + Values bool // Whether to display value nodes (as box with rounded corners) + CodeCompressed bool // Whether to turn the code from a large rectange to a small square for a more convinient display + ValCompressed bool // Whether long values (over 10 characters) are shortened using ... in the middle + ValHex bool // Whether values should be displayed as hex numbers (otherwise they are displayed as just strings) + SameLevel bool // Whether the leaves (and hashes) need to be on the same horizontal level } // Visual creates visualisation of trie with highlighting @@ -108,7 +107,7 @@ func visualNode(nd node, hex []byte, w io.Writer, highlights [][]byte, opts *Vis `n_%x -> e_%x; `, hex, accountHex) if !a.IsEmptyCodeHash() { - if code, ok := opts.CodeMap[a.CodeHash]; ok { + if code := a.code; code != nil { codeHex := keybytesToHex(code) codeHex = codeHex[:len(codeHex)-1] visual.HexBox(w, fmt.Sprintf("c_%x", accountHex), codeHex, 32, opts.CodeCompressed, false) diff --git a/trie/witness_builder.go b/trie/witness_builder.go index af7251b691..5706076733 100644 --- a/trie/witness_builder.go +++ b/trie/witness_builder.go @@ -10,6 +10,7 @@ import ( type HashNodeFunc func(node, bool, []byte) (int, error) type HashOnly interface { HashOnly([]byte) bool + IsCodeTouched(common.Hash) bool Current() []byte } @@ -18,22 +19,18 @@ type MerklePathLimiter struct { HashFunc HashNodeFunc } -type CodeMap map[common.Hash][]byte - type WitnessBuilder struct { root node blockNr uint64 trace bool - codeMap CodeMap operands []WitnessOperator } -func NewWitnessBuilder(root node, blockNr uint64, trace bool, codeMap CodeMap) *WitnessBuilder { +func NewWitnessBuilder(root node, blockNr uint64, trace bool) *WitnessBuilder { return &WitnessBuilder{ root: root, blockNr: blockNr, trace: trace, - codeMap: codeMap, operands: make([]WitnessOperator, 0), } } @@ -158,17 +155,16 @@ func (b *WitnessBuilder) addEmptyRoot() error { return nil } -func (b *WitnessBuilder) processAccountCode(n *accountNode) error { +func (b *WitnessBuilder) processAccountCode(n *accountNode, hashOnly HashOnly) error { if n.IsEmptyRoot() && n.IsEmptyCodeHash() { return nil } - code, ok := b.codeMap[n.CodeHash] - if !ok { + if n.code == nil || !hashOnly.IsCodeTouched(n.CodeHash) { return b.addHashOp(hashNode(n.CodeHash[:])) } - return b.addCodeOp(code) + return b.addCodeOp(n.code) } func (b *WitnessBuilder) processAccountStorage(n *accountNode, hex []byte, limiter *MerklePathLimiter) error { @@ -188,7 +184,7 @@ func (b *WitnessBuilder) makeBlockWitness( nd node, hex []byte, limiter *MerklePathLimiter, force bool) error { processAccountNode := func(key []byte, storageKey []byte, n *accountNode) error { - if err := b.processAccountCode(n); err != nil { + if err := b.processAccountCode(n, limiter.HashOnly); err != nil { return err } if err := b.processAccountStorage(n, storageKey, limiter); err != nil { diff --git a/trie/witness_builder_test.go b/trie/witness_builder_test.go index 9aa492cc86..d319b49b70 100644 --- a/trie/witness_builder_test.go +++ b/trie/witness_builder_test.go @@ -18,7 +18,7 @@ func TestBlockWitnessBinary(t *testing.T) { rs := NewBinaryResolveSet(2) rs.AddKey([]byte("ABCD0001")) - bwb := NewWitnessBuilder(trBin.Trie().root, 1, false, nil) + bwb := NewWitnessBuilder(trBin.Trie().root, 1, false) hr := newHasher(false) defer returnHasherToPool(hr) @@ -29,7 +29,7 @@ func TestBlockWitnessBinary(t *testing.T) { t.Errorf("Could not make block witness: %v", err) } - trBin1, _, err := BuildTrieFromWitness(w, true /*is-binary*/, false /*trace*/) + trBin1, err := BuildTrieFromWitness(w, true /*is-binary*/, false /*trace*/) if err != nil { t.Errorf("Could not restore trie from the block witness: %v", err) } @@ -57,7 +57,7 @@ func TestBlockWitnessBinaryAccount(t *testing.T) { rs := NewBinaryResolveSet(2) rs.AddKey([]byte("ABCD0001")) - bwb := NewWitnessBuilder(trBin.Trie().root, 1, false, nil) + bwb := NewWitnessBuilder(trBin.Trie().root, 1, false) hr := newHasher(false) defer returnHasherToPool(hr) @@ -68,7 +68,7 @@ func TestBlockWitnessBinaryAccount(t *testing.T) { t.Errorf("Could not make block witness: %v", err) } - trBin1, _, err := BuildTrieFromWitness(w, true /*is-binary*/, false /*trace*/) + trBin1, err := BuildTrieFromWitness(w, true /*is-binary*/, false /*trace*/) if err != nil { t.Errorf("Could not restore trie from the block witness: %v", err) } -- GitLab