diff --git a/cmd/state/commands/stateless.go b/cmd/state/commands/stateless.go index 4566b477583a067daad5695a676dcfe3f4615a79..46896a9b15092a09f03b244027d05f179d3c21a6 100644 --- a/cmd/state/commands/stateless.go +++ b/cmd/state/commands/stateless.go @@ -7,16 +7,18 @@ import ( ) var ( - statefile string - triesize uint32 - preroot bool - snapshotInterval uint64 - snapshotFrom uint64 - witnessInterval uint64 - noverify bool - bintries bool - starkBlocksFile string - starkStatsBase string + statefile string + triesize uint32 + preroot bool + snapshotInterval uint64 + snapshotFrom uint64 + witnessInterval uint64 + noverify bool + bintries bool + starkBlocksFile string + starkStatsBase string + statelessResolver bool + witnessDatabase string ) func init() { @@ -34,6 +36,11 @@ func init() { statelessCmd.Flags().BoolVar(&bintries, "bintries", false, "use binary tries instead of hexary to generate/load block witnesses") statelessCmd.Flags().StringVar(&starkBlocksFile, "starkBlocksFile", "", "file with the list of blocks for which to produce stark data") statelessCmd.Flags().StringVar(&starkStatsBase, "starkStatsBase", "stark_stats", "template for names of the files to write stark stats in") + statelessCmd.Flags().BoolVar(&statelessResolver, "statelessResolver", false, "use a witness DB instead of the state when resolving tries") + statelessCmd.Flags().StringVar(&witnessDatabase, "witnessDbFile", "", "optional path to a database where to store witnesses (empty string -- do not store witnesses") + if err := statelessCmd.MarkFlagFilename("witnessDbFile", ""); err != nil { + panic(err) + } rootCmd.AddCommand(statelessCmd) @@ -63,6 +70,8 @@ var statelessCmd = &cobra.Command{ createDb, starkBlocksFile, starkStatsBase, + statelessResolver, + witnessDatabase, ) return nil diff --git a/cmd/state/stateless/stateless.go b/cmd/state/stateless/stateless.go index c14b5288e795eac9177ed17c0ef97b14c2ff385a..fa813fdaee3d728a5c9a88cf19604dc4e58950b3 100644 --- a/cmd/state/stateless/stateless.go +++ b/cmd/state/stateless/stateless.go @@ -3,6 +3,7 @@ package stateless import ( "bytes" "context" + "encoding/csv" "fmt" "io/ioutil" "os" @@ -152,7 +153,9 @@ func Stateless( binary bool, createDb CreateDbFunc, starkBlocksFile string, - starkStatsBase string) { + starkStatsBase string, + useStatelessResolver bool, + witnessDatabasePath string) { state.MaxTrieCacheGen = triesize startTime := time.Now() @@ -229,6 +232,40 @@ func Stateless( defer func() { fmt.Printf("stoppped at block number: %d\n", blockNum) }() + var witnessDBWriter *WitnessDBWriter + var witnessDBReader *WitnessDBReader + + if useStatelessResolver && witnessDatabasePath == "" { + panic("to use stateless resolver, set the witness DB path") + } + + if witnessDatabasePath != "" { + var db ethdb.Database + db, err = createDb(witnessDatabasePath) + check(err) + defer db.Close() + + if useStatelessResolver { + witnessDBReader = NewWitnessDBReader(db) + fmt.Printf("Will use the stateless resolver with DB: %s\n", witnessDatabasePath) + } else { + statsFilePath := fmt.Sprintf("%v.stats.csv", witnessDatabasePath) + + var file *os.File + file, err = os.OpenFile(statsFilePath, os.O_RDWR|os.O_CREATE, os.ModePerm) + check(err) + defer file.Close() + + statsFileCsv := csv.NewWriter(file) + defer statsFileCsv.Flush() + + witnessDBWriter, err = NewWitnessDBWriter(db, statsFileCsv) + check(err) + fmt.Printf("witnesses will be stored to a db at path: %s\n\tstats: %s\n", witnessDatabasePath, statsFilePath) + } + + } + for !interrupt { trace := blockNum == 50111 // false // blockNum == 545080 tds.SetResolveReads(blockNum >= witnessThreshold) @@ -269,10 +306,25 @@ func Stateless( return } - if err = tds.ResolveStateTrie(); err != nil { - fmt.Printf("Failed to resolve state trie: %v\n", err) - return + if witnessDBReader != nil { + tds.SetBlockNr(blockNum) + err = tds.ResolveStateTrieStateless(witnessDBReader) + if err != nil { + fmt.Printf("Failed to statelessly resolve state trie: %v\n", err) + return + } + } else { + var resolveWitnesses []*trie.Witness + if resolveWitnesses, err = tds.ResolveStateTrie(witnessDBWriter != nil); err != nil { + fmt.Printf("Failed to resolve state trie: %v\n", err) + return + } + + if len(resolveWitnesses) > 0 { + witnessDBWriter.MustUpsert(blockNum, state.MaxTrieCacheGen, resolveWitnesses) + } } + blockWitness = nil if blockNum >= witnessThreshold { // Witness has to be extracted before the state trie is modified diff --git a/cmd/state/stateless/witness_db.go b/cmd/state/stateless/witness_db.go new file mode 100644 index 0000000000000000000000000000000000000000..c643902ce9b2f38c08c95e54f8ac787bec5189cd --- /dev/null +++ b/cmd/state/stateless/witness_db.go @@ -0,0 +1,84 @@ +package stateless + +import ( + "bytes" + "encoding/binary" + "encoding/csv" + "fmt" + + "github.com/ledgerwatch/turbo-geth/ethdb" + "github.com/ledgerwatch/turbo-geth/trie" +) + +var ( + witnessesBucket = []byte("witnesses") +) + +type WitnessDBWriter struct { + putter ethdb.Putter + statsWriter *csv.Writer +} + +func NewWitnessDBWriter(putter ethdb.Putter, statsWriter *csv.Writer) (*WitnessDBWriter, error) { + err := statsWriter.Write([]string{ + "blockNum", "maxTrieSize", "witnessesSize", + }) + if err != nil { + return nil, err + } + return &WitnessDBWriter{putter, statsWriter}, nil +} + +func (db *WitnessDBWriter) MustUpsert(blockNumber uint64, maxTrieSize uint32, resolveWitnesses []*trie.Witness) { + key := deriveDbKey(blockNumber, maxTrieSize) + + var buf bytes.Buffer + + for i, witness := range resolveWitnesses { + if _, err := witness.WriteTo(&buf); err != nil { + panic(fmt.Errorf("error while writing witness to a buffer: %w", err)) + } + if i < len(resolveWitnesses)-1 { + buf.WriteByte(byte(trie.OpNewTrie)) + } + } + + bytes := buf.Bytes() + err := db.putter.Put(witnessesBucket, key, bytes) + + if err != nil { + panic(fmt.Errorf("error while upserting witness: %w", err)) + } + + err = db.statsWriter.Write([]string{ + fmt.Sprintf("%v", blockNumber), + fmt.Sprintf("%v", maxTrieSize), + fmt.Sprintf("%v", len(bytes)), + }) + + if err != nil { + panic(fmt.Errorf("error while writing stats: %w", err)) + } +} + +type WitnessDBReader struct { + getter ethdb.Getter +} + +func NewWitnessDBReader(getter ethdb.Getter) *WitnessDBReader { + return &WitnessDBReader{getter} +} + +func (db *WitnessDBReader) GetWitnessesForBlock(blockNumber uint64, maxTrieSize uint32) ([]byte, error) { + key := deriveDbKey(blockNumber, maxTrieSize) + return db.getter.Get(witnessesBucket, key) +} + +func deriveDbKey(blockNumber uint64, maxTrieSize uint32) []byte { + buffer := make([]byte, 8+4) + + binary.LittleEndian.PutUint64(buffer[:], blockNumber) + binary.LittleEndian.PutUint32(buffer[8:], maxTrieSize) + + return buffer +} diff --git a/cmd/state/stateless/witness_db_test.go b/cmd/state/stateless/witness_db_test.go new file mode 100644 index 0000000000000000000000000000000000000000..37002c8392ea1c58b6bda1f8d86f7f3c9c55f044 --- /dev/null +++ b/cmd/state/stateless/witness_db_test.go @@ -0,0 +1,24 @@ +package stateless + +import ( + "encoding/binary" + "testing" +) + +func TestDeriveDbKey(t *testing.T) { + for i := 0; i < 1000000; i += 1003 { + for j := 0; j < 10000; j += 111 { + key := deriveDbKey(uint64(i), uint32(j)) + + block := binary.LittleEndian.Uint64(key[:8]) + if uint64(i) != block { + t.Errorf("cant unmarshall a block number from key; expected: %v got: %v", i, block) + } + + limit := binary.LittleEndian.Uint32(key[8:]) + if uint32(j) != limit { + t.Errorf("cant unmarshall a limit from key; expected: %v got: %v", j, limit) + } + } + } +} diff --git a/core/state/database.go b/core/state/database.go index 5bc08a57a31e1b08186091aa53dd1ff66ff21833..a84ab5e673fead64b173047a2aa345efe14e841f 100644 --- a/core/state/database.go +++ b/core/state/database.go @@ -356,7 +356,7 @@ func (tds *TrieDbState) LastRoot() common.Hash { // ComputeTrieRoots is a combination of `ResolveStateTrie` and `UpdateStateTrie` // DESCRIBED: docs/programmers_guide/guide.md#organising-ethereum-state-into-a-merkle-tree func (tds *TrieDbState) ComputeTrieRoots() ([]common.Hash, error) { - if err := tds.ResolveStateTrie(); err != nil { + if _, err := tds.ResolveStateTrie(false); err != nil { return nil, err } return tds.UpdateStateTrie() @@ -432,7 +432,7 @@ func (tds *TrieDbState) buildStorageTouches(withReads bool, withValues bool) (co // Expands the storage tries (by loading data from the database) if it is required // for accessing storage slots containing in the storageTouches map -func (tds *TrieDbState) resolveStorageTouches(storageTouches common.StorageKeys) error { +func (tds *TrieDbState) resolveStorageTouches(storageTouches common.StorageKeys, resolveFunc func(*trie.Resolver) error) error { var resolver *trie.Resolver for _, storageKey := range storageTouches { if need, req := tds.t.NeedResolution(storageKey[:common.HashLength], storageKey[:]); need { @@ -443,12 +443,7 @@ func (tds *TrieDbState) resolveStorageTouches(storageTouches common.StorageKeys) resolver.AddRequest(req) } } - if resolver != nil { - if err := resolver.ResolveWithDb(tds.db, tds.blockNr); err != nil { - return err - } - } - return nil + return resolveFunc(resolver) } // Populate pending block proof so that it will be sufficient for accessing all storage slots in storageTouches @@ -505,7 +500,7 @@ func (tds *TrieDbState) buildAccountTouches(withReads bool, withValues bool) (co // Expands the accounts trie (by loading data from the database) if it is required // for accessing accounts whose addresses are contained in the accountTouches -func (tds *TrieDbState) resolveAccountTouches(accountTouches common.Hashes) error { +func (tds *TrieDbState) resolveAccountTouches(accountTouches common.Hashes, resolveFunc func(*trie.Resolver) error) error { var resolver *trie.Resolver for _, addrHash := range accountTouches { if need, req := tds.t.NeedResolution(nil, addrHash[:]); need { @@ -516,13 +511,7 @@ func (tds *TrieDbState) resolveAccountTouches(accountTouches common.Hashes) erro resolver.AddRequest(req) } } - if resolver != nil { - if err := resolver.ResolveWithDb(tds.db, tds.blockNr); err != nil { - return err - } - resolver = nil - } - return nil + return resolveFunc(resolver) } func (tds *TrieDbState) populateAccountBlockProof(accountTouches common.Hashes) { @@ -539,9 +528,7 @@ func (tds *TrieDbState) ExtractTouches() (accountTouches [][]byte, storageTouche return tds.resolveSetBuilder.ExtractTouches() } -// ResolveStateTrie resolves parts of the state trie that would be necessary for any updates -// (and reads, if `resolveReads` is set). -func (tds *TrieDbState) ResolveStateTrie() error { +func (tds *TrieDbState) resolveStateTrieWithFunc(resolveFunc func(*trie.Resolver) error) error { // Aggregating the current buffer, if any if tds.currentBuffer != nil { if tds.aggregateBuffer == nil { @@ -562,7 +549,9 @@ func (tds *TrieDbState) ResolveStateTrie() error { // Prepare (resolve) accounts trie so that actual modifications can proceed without database access accountTouches, _ := tds.buildAccountTouches(tds.resolveReads, false) - if err := tds.resolveAccountTouches(accountTouches); err != nil { + var err error + + if err = tds.resolveAccountTouches(accountTouches, resolveFunc); err != nil { return err } @@ -570,9 +559,10 @@ func (tds *TrieDbState) ResolveStateTrie() error { tds.populateAccountBlockProof(accountTouches) } - if err := tds.resolveStorageTouches(storageTouches); err != nil { + if err = tds.resolveStorageTouches(storageTouches, resolveFunc); err != nil { return err } + if tds.resolveReads { if err := tds.populateStorageBlockProof(storageTouches); err != nil { return err @@ -581,6 +571,64 @@ func (tds *TrieDbState) ResolveStateTrie() error { return nil } +// ResolveStateTrie resolves parts of the state trie that would be necessary for any updates +// (and reads, if `resolveReads` is set). +func (tds *TrieDbState) ResolveStateTrie(extractWitnesses bool) ([]*trie.Witness, error) { + var witnesses []*trie.Witness + + resolveFunc := func(resolver *trie.Resolver) error { + if resolver == nil { + return nil + } + resolver.CollectWitnesses(extractWitnesses) + if err := resolver.ResolveWithDb(tds.db, tds.blockNr); err != nil { + return err + } + + if !extractWitnesses { + return nil + } + + resolverWitnesses := resolver.PopCollectedWitnesses() + if len(resolverWitnesses) == 0 { + return nil + } + + if witnesses == nil { + witnesses = resolverWitnesses + } else { + witnesses = append(witnesses, resolverWitnesses...) + } + + return nil + } + if err := tds.resolveStateTrieWithFunc(resolveFunc); err != nil { + return nil, err + } + + return witnesses, nil +} + +// ResolveStateTrieStateless uses a witness DB to resolve subtries +func (tds *TrieDbState) ResolveStateTrieStateless(database trie.WitnessStorage) error { + var startPos int64 + resolveFunc := func(resolver *trie.Resolver) error { + if resolver == nil { + return nil + } + + pos, err := resolver.ResolveStateless(database, tds.blockNr, MaxTrieCacheGen, startPos) + if err != nil { + return err + } + + startPos = pos + return nil + } + + return tds.resolveStateTrieWithFunc(resolveFunc) +} + // CalcTrieRoots calculates trie roots without modifying the state trie func (tds *TrieDbState) CalcTrieRoots(trace bool) (common.Hash, error) { tds.tMu.Lock() @@ -831,7 +879,7 @@ func (tds *TrieDbState) UnwindTo(blockNr uint64) error { }); err != nil { return err } - if err := tds.ResolveStateTrie(); err != nil { + if _, err := tds.ResolveStateTrie(false); err != nil { return err } diff --git a/miner/worker.go b/miner/worker.go index c40ae9878f2fdb66ca7c6d509f9955711f4dab83..e634ff8de853dfe8651ed85d68ee6ad67e5bb3ac 100644 --- a/miner/worker.go +++ b/miner/worker.go @@ -1031,7 +1031,7 @@ func NewBlock(engine consensus.Engine, s *state.IntraBlockState, tds *state.Trie return nil, err } - if err = tds.ResolveStateTrie(); err != nil { + if _, err = tds.ResolveStateTrie(false); err != nil { return nil, err } diff --git a/trie/resolver.go b/trie/resolver.go index 681ec277489a3a449bf30197e037b211af8ad17d..4fb3b5de418a8cbb188e43bc8add926c14867c21 100644 --- a/trie/resolver.go +++ b/trie/resolver.go @@ -3,15 +3,11 @@ package trie import ( "bytes" "fmt" - "runtime/debug" "sort" - "strings" - "github.com/ledgerwatch/turbo-geth/common/dbutils" - "github.com/ledgerwatch/turbo-geth/core/types/accounts" + "github.com/ledgerwatch/turbo-geth/common" "github.com/ledgerwatch/turbo-geth/ethdb" "github.com/ledgerwatch/turbo-geth/log" - "github.com/ledgerwatch/turbo-geth/trie/rlphacks" ) var emptyHash [32]byte @@ -35,37 +31,36 @@ func (t *Trie) Rebuild(db ethdb.Database, blockNr uint64) error { // One resolver per trie (prefix). // See also ResolveRequest in trie.go type Resolver struct { - accounts bool // Is this a resolver for accounts or for storage - topLevels int // How many top levels of the trie to keep (not roll into hashes) - requests []*ResolveRequest - reqIndices []int // Indices pointing back to request slice from slices returned by PrepareResolveParams - keyIdx int - currentReq *ResolveRequest // Request currently being handled - currentRs *ResolveSet // ResolveSet currently being used - historical bool - blockNr uint64 - hb *HashBuilder - fieldSet uint32 // fieldSet for the next invocation of genStructStep - rss []*ResolveSet - curr bytes.Buffer // Current key for the structure generation algorithm, as well as the input tape for the hash builder - succ bytes.Buffer - value bytes.Buffer // Current value to be used as the value tape for the hash builder - groups []uint16 - a accounts.Account + accounts bool // Is this a resolver for accounts or for storage + requests []*ResolveRequest + historical bool + blockNr uint64 + collectWitnesses bool // if true, stores witnesses for all the subtries that are being resolved + witnesses []*Witness // list of witnesses for resolved subtries, nil if `collectWitnesses` is false + topLevels int // How many top levels of the trie to keep (not roll into hashes) } func NewResolver(topLevels int, forAccounts bool, blockNr uint64) *Resolver { tr := Resolver{ - accounts: forAccounts, - topLevels: topLevels, - requests: []*ResolveRequest{}, - reqIndices: []int{}, - blockNr: blockNr, - hb: NewHashBuilder(false), + accounts: forAccounts, + requests: []*ResolveRequest{}, + blockNr: blockNr, + topLevels: topLevels, } return &tr } +func (tr *Resolver) CollectWitnesses(c bool) { + tr.collectWitnesses = c +} + +// PopCollectedWitnesses returns all the collected witnesses and clears the storage in this resolver +func (tr *Resolver) PopCollectedWitnesses() []*Witness { + result := tr.witnesses + tr.witnesses = nil + return result +} + func (tr *Resolver) SetHistorical(h bool) { tr.historical = h } @@ -113,104 +108,6 @@ func (tr *Resolver) Print() { } } -// PrepareResolveParams prepares information for the MultiWalk -func (tr *Resolver) PrepareResolveParams() ([][]byte, []uint) { - // Remove requests strictly contained in the preceding ones - startkeys := [][]byte{} - fixedbits := []uint{} - tr.rss = nil - if len(tr.requests) == 0 { - return startkeys, fixedbits - } - sort.Stable(tr) - var prevReq *ResolveRequest - for i, req := range tr.requests { - if prevReq == nil || - !bytes.Equal(req.contract, prevReq.contract) || - !bytes.Equal(req.resolveHex[:req.resolvePos], prevReq.resolveHex[:prevReq.resolvePos]) { - - tr.reqIndices = append(tr.reqIndices, i) - pLen := len(req.contract) - key := make([]byte, pLen+32) - copy(key[:], req.contract) - decodeNibbles(req.resolveHex[:req.resolvePos], key[pLen:]) - startkeys = append(startkeys, key) - req.extResolvePos = req.resolvePos + 2*pLen - fixedbits = append(fixedbits, uint(4*req.extResolvePos)) - prevReq = req - var minLength int - if req.resolvePos >= tr.topLevels { - minLength = 0 - } else { - minLength = tr.topLevels - req.resolvePos - } - rs := NewResolveSet(minLength) - tr.rss = append(tr.rss, rs) - rs.AddHex(req.resolveHex[req.resolvePos:]) - } else { - rs := tr.rss[len(tr.rss)-1] - rs.AddHex(req.resolveHex[req.resolvePos:]) - } - } - tr.currentReq = tr.requests[tr.reqIndices[0]] - tr.currentRs = tr.rss[0] - return startkeys, fixedbits -} - -func (tr *Resolver) finaliseRoot() error { - tr.curr.Reset() - tr.curr.Write(tr.succ.Bytes()) - tr.succ.Reset() - if tr.curr.Len() > 0 { - var err error - var data GenStructStepData - if tr.fieldSet == 0 { - data = GenStructStepLeafData{Value: rlphacks.RlpSerializableBytes(tr.value.Bytes())} - } else { - data = GenStructStepAccountData{ - FieldSet: tr.fieldSet, - StorageSize: tr.a.StorageSize, - Balance: &tr.a.Balance, - Nonce: tr.a.Nonce, - Incarnation: tr.a.Incarnation, - } - } - tr.groups, err = GenStructStep(tr.currentRs.HashOnly, tr.curr.Bytes(), tr.succ.Bytes(), tr.hb, data, tr.groups) - if err != nil { - return err - } - } - if tr.hb.hasRoot() { - hbRoot := tr.hb.root() - hbHash := tr.hb.rootHash() - - if tr.currentReq.RequiresRLP { - hasher := newHasher(false) - defer returnHasherToPool(hasher) - h, err := hasher.hashChildren(hbRoot, 0) - if err != nil { - return err - } - tr.currentReq.NodeRLP = h - } - var hookKey []byte - if tr.currentReq.contract == nil { - hookKey = tr.currentReq.resolveHex[:tr.currentReq.resolvePos] - } else { - contractHex := keybytesToHex(tr.currentReq.contract) - contractHex = contractHex[:len(contractHex)-1-16] // Remove terminal nibble and incarnation bytes - hookKey = append(contractHex, tr.currentReq.resolveHex[:tr.currentReq.resolvePos]...) - } - //fmt.Printf("hookKey: %x, %s\n", hookKey, hbRoot.fstring("")) - tr.currentReq.t.hook(hookKey, hbRoot) - if len(tr.currentReq.resolveHash) > 0 && !bytes.Equal(tr.currentReq.resolveHash, hbHash[:]) { - return fmt.Errorf("mismatching hash: %s %x for prefix %x, resolveHex %x, resolvePos %d", - tr.currentReq.resolveHash, hbHash, tr.currentReq.contract, tr.currentReq.resolveHex, tr.currentReq.resolvePos) - } - } - return nil -} - // Various values of the account field set const ( AccountFieldNonceOnly uint32 = 0x01 @@ -224,114 +121,71 @@ const ( AccountFieldSetContractWithSize uint32 = 0x1f // Bits 0-4 are set for nonce, balance, storageRoot, codeHash and storageSize ) -// Walker - k, v - shouldn't be reused in the caller's code -func (tr *Resolver) Walker(keyIdx int, k []byte, v []byte) error { - //fmt.Printf("keyIdx: %d key:%x value:%x, accounts: %t\n", keyIdx, k, v, tr.accounts) - if keyIdx != tr.keyIdx { - if err := tr.finaliseRoot(); err != nil { +// ResolveWithDb resolves and hooks subtries using a state database. +func (tr *Resolver) ResolveWithDb(db ethdb.Database, blockNr uint64) error { + var hf hookFunction + if tr.collectWitnesses { + hf = tr.extractWitnessAndHookSubtrie + } else { + hf = hookSubtrie + } + + sort.Stable(tr) + resolver := NewResolverStateful(tr.topLevels, tr.requests, hf) + return resolver.RebuildTrie(db, blockNr, tr.accounts, tr.historical) +} + +// ResolveStateless resolves and hooks subtries using a witnesses database instead of +// the state DB. +func (tr *Resolver) ResolveStateless(db WitnessStorage, blockNr uint64, trieLimit uint32, startPos int64) (int64, error) { + sort.Stable(tr) + resolver := NewResolverStateless(tr.requests, hookSubtrie) + return resolver.RebuildTrie(db, blockNr, trieLimit, startPos) +} + +func hookSubtrie(currentReq *ResolveRequest, hbRoot node, hbHash common.Hash) error { + if currentReq.RequiresRLP { + hasher := newHasher(false) + defer returnHasherToPool(hasher) + h, err := hasher.hashChildren(hbRoot, 0) + if err != nil { return err } - tr.hb.Reset() - tr.groups = nil - tr.keyIdx = keyIdx - tr.currentReq = tr.requests[tr.reqIndices[keyIdx]] - tr.currentRs = tr.rss[keyIdx] - tr.curr.Reset() + currentReq.NodeRLP = h } - if len(v) > 0 { - tr.curr.Reset() - tr.curr.Write(tr.succ.Bytes()) - tr.succ.Reset() - skip := tr.currentReq.extResolvePos // how many first nibbles to skip - i := 0 - for _, b := range k { - if i >= skip { - tr.succ.WriteByte(b / 16) - } - i++ - if i >= skip { - tr.succ.WriteByte(b % 16) - } - i++ - } - tr.succ.WriteByte(16) - if tr.curr.Len() > 0 { - var err error - var data GenStructStepData - if tr.fieldSet == 0 { - data = GenStructStepLeafData{Value: rlphacks.RlpSerializableBytes(tr.value.Bytes())} - } else { - data = GenStructStepAccountData{ - FieldSet: tr.fieldSet, - StorageSize: tr.a.StorageSize, - Balance: &tr.a.Balance, - Nonce: tr.a.Nonce, - Incarnation: tr.a.Incarnation, - } - } - tr.groups, err = GenStructStep(tr.currentRs.HashOnly, tr.curr.Bytes(), tr.succ.Bytes(), tr.hb, data, tr.groups) - if err != nil { - return err - } - } - // Remember the current key and value - if tr.accounts { - if err := tr.a.DecodeForStorage(v); err != nil { - return err - } - if tr.a.IsEmptyCodeHash() && tr.a.IsEmptyRoot() { - tr.fieldSet = AccountFieldSetNotContract - } else { - if tr.a.HasStorageSize { - tr.fieldSet = AccountFieldSetContractWithSize - } else { - tr.fieldSet = AccountFieldSetContract - } - // the first item ends up deepest on the stack, the seccond item - on the top - if err := tr.hb.hash(tr.a.CodeHash); err != nil { - return err - } - if err := tr.hb.hash(tr.a.Root); err != nil { - return err - } - } - } else { - tr.value.Reset() - tr.value.Write(v) - tr.fieldSet = AccountFieldSetNotAccount - } + + var hookKey []byte + if currentReq.contract == nil { + hookKey = currentReq.resolveHex[:currentReq.resolvePos] + } else { + contractHex := keybytesToHex(currentReq.contract) + contractHex = contractHex[:len(contractHex)-1-16] // Remove terminal nibble and incarnation bytes + hookKey = append(contractHex, currentReq.resolveHex[:currentReq.resolvePos]...) } + + //fmt.Printf("hookKey: %x, %s\n", hookKey, hbRoot.fstring("")) + currentReq.t.hook(hookKey, hbRoot) + if len(currentReq.resolveHash) > 0 && !bytes.Equal(currentReq.resolveHash, hbHash[:]) { + return fmt.Errorf("mismatching hash: %s %x for prefix %x, resolveHex %x, resolvePos %d", + currentReq.resolveHash, hbHash, currentReq.contract, currentReq.resolveHex, currentReq.resolvePos) + } + return nil } -func (tr *Resolver) ResolveWithDb(db ethdb.Database, blockNr uint64) error { - startkeys, fixedbits := tr.PrepareResolveParams() - var err error - if db == nil { - var b strings.Builder - fmt.Fprintf(&b, "ResolveWithDb(db=nil), tr.accounts: %t\n", tr.accounts) - for i, sk := range startkeys { - fmt.Fprintf(&b, "sk %x, bits: %d\n", sk, fixedbits[i]) - } - return fmt.Errorf("Unexpected resolution: %s at %s", b.String(), debug.Stack()) - } - if tr.accounts { - if tr.historical { - err = db.MultiWalkAsOf(dbutils.AccountsBucket, dbutils.AccountsHistoryBucket, startkeys, fixedbits, blockNr+1, tr.Walker) - } else { - err = db.MultiWalk(dbutils.AccountsBucket, startkeys, fixedbits, tr.Walker) - } - } else { - if tr.historical { - err = db.MultiWalkAsOf(dbutils.StorageBucket, dbutils.StorageHistoryBucket, startkeys, fixedbits, blockNr+1, tr.Walker) - } else { - err = db.MultiWalk(dbutils.StorageBucket, startkeys, fixedbits, tr.Walker) - } +func (tr *Resolver) extractWitnessAndHookSubtrie(currentReq *ResolveRequest, hbRoot node, hbHash common.Hash) error { + if tr.witnesses == nil { + tr.witnesses = make([]*Witness, 0) } + + witness, err := extractWitnessFromRootNode(hbRoot, tr.blockNr, false /*tr.hb.trace*/, nil, nil) if err != nil { - return err + return fmt.Errorf("error while extracting witness for resolver: %w", err) } - return tr.finaliseRoot() + + tr.witnesses = append(tr.witnesses, witness) + + return hookSubtrie(currentReq, hbRoot, hbHash) } func (t *Trie) rebuildHashes(db ethdb.Database, key []byte, pos int, blockNr uint64, accounts bool, expected hashNode) error { diff --git a/trie/resolver_stateful.go b/trie/resolver_stateful.go new file mode 100644 index 0000000000000000000000000000000000000000..a562ffaccefa114ec9aaed35e93b2ae144235abc --- /dev/null +++ b/trie/resolver_stateful.go @@ -0,0 +1,250 @@ +package trie + +import ( + "bytes" + "fmt" + "runtime/debug" + "strings" + + "github.com/ledgerwatch/turbo-geth/common" + "github.com/ledgerwatch/turbo-geth/common/dbutils" + "github.com/ledgerwatch/turbo-geth/core/types/accounts" + "github.com/ledgerwatch/turbo-geth/ethdb" + "github.com/ledgerwatch/turbo-geth/trie/rlphacks" +) + +type hookFunction func(*ResolveRequest, node, common.Hash) error + +type ResolverStateful struct { + rss []*ResolveSet + curr bytes.Buffer // Current key for the structure generation algorithm, as well as the input tape for the hash builder + succ bytes.Buffer + value bytes.Buffer // Current value to be used as the value tape for the hash builder + groups []uint16 + reqIndices []int // Indices pointing back to request slice from slices returned by PrepareResolveParams + hb *HashBuilder + topLevels int // How many top levels of the trie to keep (not roll into hashes) + currentReq *ResolveRequest // Request currently being handled + currentRs *ResolveSet // ResolveSet currently being used + keyIdx int + fieldSet uint32 // fieldSet for the next invocation of genStructStep + a accounts.Account + + requests []*ResolveRequest + + roots []node // roots of the tries that are being built + hookFunction hookFunction +} + +func NewResolverStateful(topLevels int, requests []*ResolveRequest, hookFunction hookFunction) *ResolverStateful { + return &ResolverStateful{ + topLevels: topLevels, + hb: NewHashBuilder(false), + reqIndices: []int{}, + requests: requests, + hookFunction: hookFunction, + } +} + +func (tr *ResolverStateful) PopRoots() []node { + roots := tr.roots + tr.roots = nil + return roots +} + +// PrepareResolveParams prepares information for the MultiWalk +func (tr *ResolverStateful) PrepareResolveParams() ([][]byte, []uint) { + // Remove requests strictly contained in the preceding ones + startkeys := [][]byte{} + fixedbits := []uint{} + tr.rss = nil + if len(tr.requests) == 0 { + return startkeys, fixedbits + } + var prevReq *ResolveRequest + for i, req := range tr.requests { + if prevReq == nil || + !bytes.Equal(req.contract, prevReq.contract) || + !bytes.Equal(req.resolveHex[:req.resolvePos], prevReq.resolveHex[:prevReq.resolvePos]) { + + tr.reqIndices = append(tr.reqIndices, i) + pLen := len(req.contract) + key := make([]byte, pLen+32) + copy(key[:], req.contract) + decodeNibbles(req.resolveHex[:req.resolvePos], key[pLen:]) + startkeys = append(startkeys, key) + req.extResolvePos = req.resolvePos + 2*pLen + fixedbits = append(fixedbits, uint(4*req.extResolvePos)) + prevReq = req + var minLength int + if req.resolvePos >= tr.topLevels { + minLength = 0 + } else { + minLength = tr.topLevels - req.resolvePos + } + rs := NewResolveSet(minLength) + tr.rss = append(tr.rss, rs) + rs.AddHex(req.resolveHex[req.resolvePos:]) + } else { + rs := tr.rss[len(tr.rss)-1] + rs.AddHex(req.resolveHex[req.resolvePos:]) + } + } + tr.currentReq = tr.requests[tr.reqIndices[0]] + tr.currentRs = tr.rss[0] + return startkeys, fixedbits +} + +func (tr *ResolverStateful) finaliseRoot() error { + tr.curr.Reset() + tr.curr.Write(tr.succ.Bytes()) + tr.succ.Reset() + if tr.curr.Len() > 0 { + var err error + var data GenStructStepData + if tr.fieldSet == 0 { + data = GenStructStepLeafData{Value: rlphacks.RlpSerializableBytes(tr.value.Bytes())} + } else { + data = GenStructStepAccountData{ + FieldSet: tr.fieldSet, + StorageSize: tr.a.StorageSize, + Balance: &tr.a.Balance, + Nonce: tr.a.Nonce, + Incarnation: tr.a.Incarnation, + } + } + tr.groups, err = GenStructStep(tr.currentRs.HashOnly, tr.curr.Bytes(), tr.succ.Bytes(), tr.hb, data, tr.groups) + if err != nil { + return err + } + } + if tr.hb.hasRoot() { + hbRoot := tr.hb.root() + hbHash := tr.hb.rootHash() + return tr.hookFunction(tr.currentReq, hbRoot, hbHash) + } + return nil +} + +func (tr *ResolverStateful) RebuildTrie( + db ethdb.Database, + blockNr uint64, + accounts bool, + historical bool) error { + startkeys, fixedbits := tr.PrepareResolveParams() + if db == nil { + var b strings.Builder + fmt.Fprintf(&b, "ResolveWithDb(db=nil), accounts: %t\n", accounts) + for i, sk := range startkeys { + fmt.Fprintf(&b, "sk %x, bits: %d\n", sk, fixedbits[i]) + } + return fmt.Errorf("unexpected resolution: %s at %s", b.String(), debug.Stack()) + } + + var err error + if accounts { + if historical { + err = db.MultiWalkAsOf(dbutils.AccountsBucket, dbutils.AccountsHistoryBucket, startkeys, fixedbits, blockNr+1, tr.WalkerAccounts) + } else { + err = db.MultiWalk(dbutils.AccountsBucket, startkeys, fixedbits, tr.WalkerAccounts) + } + } else { + if historical { + err = db.MultiWalkAsOf(dbutils.StorageBucket, dbutils.StorageHistoryBucket, startkeys, fixedbits, blockNr+1, tr.WalkerStorage) + } else { + err = db.MultiWalk(dbutils.StorageBucket, startkeys, fixedbits, tr.WalkerStorage) + } + } + if err != nil { + return err + } + return tr.finaliseRoot() +} + +func (tr *ResolverStateful) WalkerAccounts(keyIdx int, k []byte, v []byte) error { + return tr.Walker(true, keyIdx, k, v) +} + +func (tr *ResolverStateful) WalkerStorage(keyIdx int, k []byte, v []byte) error { + return tr.Walker(false, keyIdx, k, v) +} + +// Walker - k, v - shouldn't be reused in the caller's code +func (tr *ResolverStateful) Walker(isAccount bool, keyIdx int, k []byte, v []byte) error { + //fmt.Printf("keyIdx: %d key:%x value:%x, accounts: %t\n", keyIdx, k, v, tr.accounts) + if keyIdx != tr.keyIdx { + if err := tr.finaliseRoot(); err != nil { + return err + } + tr.hb.Reset() + tr.groups = nil + tr.keyIdx = keyIdx + tr.currentReq = tr.requests[tr.reqIndices[keyIdx]] + tr.currentRs = tr.rss[keyIdx] + tr.curr.Reset() + } + if len(v) > 0 { + tr.curr.Reset() + tr.curr.Write(tr.succ.Bytes()) + tr.succ.Reset() + skip := tr.currentReq.extResolvePos // how many first nibbles to skip + i := 0 + for _, b := range k { + if i >= skip { + tr.succ.WriteByte(b / 16) + } + i++ + if i >= skip { + tr.succ.WriteByte(b % 16) + } + i++ + } + tr.succ.WriteByte(16) + if tr.curr.Len() > 0 { + var err error + var data GenStructStepData + if tr.fieldSet == 0 { + data = GenStructStepLeafData{Value: rlphacks.RlpSerializableBytes(tr.value.Bytes())} + } else { + data = GenStructStepAccountData{ + FieldSet: tr.fieldSet, + StorageSize: tr.a.StorageSize, + Balance: &tr.a.Balance, + Nonce: tr.a.Nonce, + Incarnation: tr.a.Incarnation, + } + } + tr.groups, err = GenStructStep(tr.currentRs.HashOnly, tr.curr.Bytes(), tr.succ.Bytes(), tr.hb, data, tr.groups) + if err != nil { + return err + } + } + // Remember the current key and value + if isAccount { + if err := tr.a.DecodeForStorage(v); err != nil { + return err + } + if tr.a.IsEmptyCodeHash() && tr.a.IsEmptyRoot() { + tr.fieldSet = AccountFieldSetNotContract + } else { + if tr.a.HasStorageSize { + tr.fieldSet = AccountFieldSetContractWithSize + } else { + tr.fieldSet = AccountFieldSetContract + } + // the first item ends up deepest on the stack, the seccond item - on the top + if err := tr.hb.hash(tr.a.CodeHash); err != nil { + return err + } + if err := tr.hb.hash(tr.a.Root); err != nil { + return err + } + } + } else { + tr.value.Reset() + tr.value.Write(v) + tr.fieldSet = AccountFieldSetNotAccount + } + } + return nil +} diff --git a/trie/resolver_stateless.go b/trie/resolver_stateless.go new file mode 100644 index 0000000000000000000000000000000000000000..6f43bc865930bac2b32c00cb2a485413734e2e36 --- /dev/null +++ b/trie/resolver_stateless.go @@ -0,0 +1,63 @@ +package trie + +import ( + "bytes" + "io" +) + +type ResolverStateless struct { + requests []*ResolveRequest + hookFunction hookFunction +} + +func NewResolverStateless(requests []*ResolveRequest, hookFunction hookFunction) *ResolverStateless { + return &ResolverStateless{ + requests: requests, + hookFunction: hookFunction, + } +} + +func (r *ResolverStateless) RebuildTrie(db WitnessStorage, blockNr uint64, trieLimit uint32, startPos int64) (int64, error) { + serializedWitness, err := db.GetWitnessesForBlock(blockNr, trieLimit) + if err != nil { + return 0, err + } + + witnessReader := bytes.NewReader(serializedWitness) + if _, err := witnessReader.Seek(startPos, io.SeekStart); err != nil { + return 0, err + } + + var prevReq *ResolveRequest + requestIndex := 0 + + for witnessReader.Len() > 0 && requestIndex < len(r.requests) { + req := r.requests[requestIndex] + if prevReq == nil || + !bytes.Equal(req.contract, prevReq.contract) || + !bytes.Equal(req.resolveHex[:req.resolvePos], prevReq.resolveHex[:prevReq.resolvePos]) { + witness, err := NewWitnessFromReader(witnessReader, false /*trace*/) + + if err != nil { + return 0, err + } + + trie, _, err := BuildTrieFromWitness(witness, false /*is-binary*/, false /*trace*/) + if err != nil { + return 0, err + } + rootNode := trie.root + rootHash := trie.Hash() + + err = r.hookFunction(req, rootNode, rootHash) + if err != nil { + return 0, err + } + prevReq = req + } + requestIndex++ + } + + bytesRead := int64(len(serializedWitness) - witnessReader.Len()) + return bytesRead, nil +} diff --git a/trie/resolver_stateless_test.go b/trie/resolver_stateless_test.go new file mode 100644 index 0000000000000000000000000000000000000000..8d28d51ade447318315dea23afd0631e8a175194 --- /dev/null +++ b/trie/resolver_stateless_test.go @@ -0,0 +1,132 @@ +package trie + +import ( + "bytes" + "fmt" + "testing" + + "github.com/ledgerwatch/turbo-geth/common" +) + +type testWitnessStorage []byte + +func (s *testWitnessStorage) GetWitnessesForBlock(_ uint64, _ uint32) ([]byte, error) { + return []byte(*s), nil +} + +func generateKey(i int) []byte { + return []byte(fmt.Sprintf("key-number-%05d", i)) +} + +func generateValue(i int) []byte { + return []byte(fmt.Sprintf("value-number-%05d", i)) +} + +func buildTestTrie(numberOfNodes int) *Trie { + trie := New(EmptyRoot) + for i := 0; i < numberOfNodes; i++ { + trie.Update(generateKey(i), generateValue(i), 1) + } + return trie +} + +func TestRebuildTrie(t *testing.T) { + trie1 := buildTestTrie(0) + trie2 := buildTestTrie(10) + trie3 := buildTestTrie(100) + + w1, err := extractWitnessFromRootNode(trie1.root, 1, false, nil, nil) + if err != nil { + t.Error(err) + } + + w2, err := extractWitnessFromRootNode(trie2.root, 1, false, nil, nil) + if err != nil { + t.Error(err) + } + + w3, err := extractWitnessFromRootNode(trie3.root, 1, false, nil, nil) + if err != nil { + t.Error(err) + } + + var buff bytes.Buffer + _, err = w1.WriteTo(&buff) + if err != nil { + t.Error(err) + } + + err = buff.WriteByte(byte(OpNewTrie)) + if err != nil { + t.Error(err) + } + + _, err = w2.WriteTo(&buff) + if err != nil { + t.Error(err) + } + + err = buff.WriteByte(byte(OpNewTrie)) + if err != nil { + t.Error(err) + } + + _, err = w3.WriteTo(&buff) + if err != nil { + t.Error(err) + } + + storage := testWitnessStorage(buff.Bytes()) + + resolvedTries := make([]*Trie, 3) + + currentTrie := 0 + + req1 := trie1.NewResolveRequest(nil, []byte{0x01}, 1, trie1.Hash().Bytes()) + req2 := trie2.NewResolveRequest(nil, []byte{0x02}, 1, trie2.Hash().Bytes()) + req21 := trie2.NewResolveRequest(nil, []byte{0x02}, 1, trie2.Hash().Bytes()) + req3 := trie3.NewResolveRequest(nil, []byte{0x03}, 1, trie3.Hash().Bytes()) + req31 := trie3.NewResolveRequest(nil, []byte{0x03}, 1, trie3.Hash().Bytes()) + + hookFunction := func(req *ResolveRequest, root node, rootHash common.Hash) error { + trie := New(rootHash) + trie.root = root + resolvedTries[currentTrie] = trie + currentTrie++ + if !bytes.Equal(req.resolveHash.hash(), rootHash.Bytes()) { + return fmt.Errorf("root hash mismatch: expected %x got %x", + req.resolveHash.hash(), rootHash.Bytes()) + } + return nil + } + + // it should ignore duplicate resolve requests + resolver := NewResolverStateless([]*ResolveRequest{req1, req2, req21}, hookFunction) + + pos, err := resolver.RebuildTrie(&storage, 1, 1, 0) + if err != nil { + t.Error(err) + } + + // we also support partial resolution with continuation (for storage tries) + // so basically we first resolve accounts, then storages separately + // but we still want to keep one entry in a DB per block, so we store the last read position + // and then use it as a start + resolver = NewResolverStateless([]*ResolveRequest{req3, req31}, hookFunction) + _, err = resolver.RebuildTrie(&storage, 1, 1, pos) + if err != nil { + t.Error(err) + } + + if !bytes.Equal(resolvedTries[0].Hash().Bytes(), trie1.Hash().Bytes()) { + t.Errorf("tries are different") + } + + if !bytes.Equal(resolvedTries[1].Hash().Bytes(), trie2.Hash().Bytes()) { + t.Errorf("tries are different") + } + + if !bytes.Equal(resolvedTries[2].Hash().Bytes(), trie3.Hash().Bytes()) { + t.Errorf("tries are different") + } +} diff --git a/trie/trie_from_witness.go b/trie/trie_from_witness.go index 37dcfd0043d4a3217638fa84e5ab6f2c938c0f31..532c11ac31c92e6d809972dcfdb55918b70ed53a 100644 --- a/trie/trie_from_witness.go +++ b/trie/trie_from_witness.go @@ -88,6 +88,12 @@ func BuildTrieFromWitness(witness *Witness, isBinary bool, trace bool) (*Trie, C if trace { fmt.Printf("\n") } + if !hb.hasRoot() { + if isBinary { + return NewBinary(EmptyRoot), nil, nil + } + return New(EmptyRoot), nil, nil + } r := hb.root() var tr *Trie if isBinary { diff --git a/trie/trie_witness.go b/trie/trie_witness.go index b4e54d9a7567913da7558f0a9c691313b56e6008..43eb2d6feafe1fdc215ecdd0d2a8f5e06381319e 100644 --- a/trie/trie_witness.go +++ b/trie/trie_witness.go @@ -1,8 +1,20 @@ package trie func (t *Trie) ExtractWitness(blockNr uint64, trace bool, rs *ResolveSet, codeMap CodeMap) (*Witness, error) { - builder := NewWitnessBuilder(t, blockNr, trace, rs, codeMap) - hr := newHasher(false) - defer returnHasherToPool(hr) - return builder.Build(hr.hash) + return extractWitnessFromRootNode(t.root, blockNr, trace, rs, codeMap) +} + +// extractWitnessFromRootNode extracts a witness for a subtrie starting from the specified root +// if hashOnly param is nil it will make a witness for the full subtrie, +// if hashOnly param is set to a ResolveSet instance, it will make a witness for only the accounts/storages that were actually touched; other paths will be hashed. +func extractWitnessFromRootNode(root node, blockNr uint64, trace bool, hashOnly HashOnly, codeMap CodeMap) (*Witness, error) { + builder := NewWitnessBuilder(root, blockNr, trace, codeMap) + var limiter *MerklePathLimiter + if hashOnly != nil { + hr := newHasher(false) + defer returnHasherToPool(hr) + limiter = &MerklePathLimiter{hashOnly, hr.hash} + } + + return builder.Build(limiter) } diff --git a/trie/witness.go b/trie/witness.go index 691fb87cd7ffdae510992bac575e7cd34467d772..e2f1fb2fca8eacafa6f4e400b1b120fe4ae31d99 100644 --- a/trie/witness.go +++ b/trie/witness.go @@ -6,6 +6,11 @@ import ( "io" ) +// WitnessStorage is an interface representing a single +type WitnessStorage interface { + GetWitnessesForBlock(uint64, uint32) ([]byte, error) +} + // WitnessVersion represents the current version of the block witness // in case of incompatible changes it should be updated and the code to migrate the // old witness format should be present @@ -101,10 +106,17 @@ func NewWitnessFromReader(input io.Reader, trace bool) (*Witness, error) { op = &OperatorEmptyRoot{} case OpExtension: op = &OperatorExtension{} + case OpNewTrie: + /* end of the current trie, end the function */ + break default: return nil, fmt.Errorf("unexpected opcode while reading witness: %x", opcode[0]) } + if op == nil { + break + } + err = op.LoadFrom(operatorLoader) if err != nil { return nil, err diff --git a/trie/witness_builder.go b/trie/witness_builder.go index 517a9f3b74994ca0e26a0aff63f921363ada5cd7..af7251b6910ef3554a2f1fb51b5c8b2c79c40eae 100644 --- a/trie/witness_builder.go +++ b/trie/witness_builder.go @@ -8,31 +8,38 @@ import ( ) type HashNodeFunc func(node, bool, []byte) (int, error) +type HashOnly interface { + HashOnly([]byte) bool + Current() []byte +} + +type MerklePathLimiter struct { + HashOnly HashOnly + HashFunc HashNodeFunc +} type CodeMap map[common.Hash][]byte type WitnessBuilder struct { - t *Trie + root node blockNr uint64 trace bool - rs *ResolveSet codeMap CodeMap operands []WitnessOperator } -func NewWitnessBuilder(t *Trie, blockNr uint64, trace bool, rs *ResolveSet, codeMap CodeMap) *WitnessBuilder { +func NewWitnessBuilder(root node, blockNr uint64, trace bool, codeMap CodeMap) *WitnessBuilder { return &WitnessBuilder{ - t: t, + root: root, blockNr: blockNr, trace: trace, - rs: rs, codeMap: codeMap, operands: make([]WitnessOperator, 0), } } -func (b *WitnessBuilder) Build(hashNodeFunc HashNodeFunc) (*Witness, error) { - err := b.makeBlockWitness(b.t.root, []byte{}, hashNodeFunc, true) +func (b *WitnessBuilder) Build(limiter *MerklePathLimiter) (*Witness, error) { + err := b.makeBlockWitness(b.root, []byte{}, limiter, true) witness := NewWitness(b.operands) b.operands = nil return witness, err @@ -92,23 +99,26 @@ func (b *WitnessBuilder) addExtensionOp(key []byte) error { return nil } -func (b *WitnessBuilder) addHashOp(n node, force bool, hashNodeFunc HashNodeFunc) error { - var hash common.Hash +func (b *WitnessBuilder) makeHashNode(n node, force bool, hashNodeFunc HashNodeFunc) (hashNode, error) { switch n := n.(type) { case hashNode: - copy(hash[:], n[:]) + return n, nil default: + var hash common.Hash if _, err := hashNodeFunc(n, force, hash[:]); err != nil { - return err + return nil, err } + return hashNode(hash[:]), nil } +} +func (b *WitnessBuilder) addHashOp(n hashNode) error { if b.trace { - fmt.Printf("HASH: type: %T v %x\n", n, hash) + fmt.Printf("HASH: type: %T v %x\n", n, n) } var op OperatorHash - op.Hash = hash + op.Hash = common.BytesToHash(n) b.operands = append(b.operands, &op) return nil @@ -155,14 +165,13 @@ func (b *WitnessBuilder) processAccountCode(n *accountNode) error { code, ok := b.codeMap[n.CodeHash] if !ok { - // FIXME: these parameters aren't harmful, but probably addHashOp does too much - return b.addHashOp(hashNode(n.CodeHash[:]), false, nil) + return b.addHashOp(hashNode(n.CodeHash[:])) } return b.addCodeOp(code) } -func (b *WitnessBuilder) processAccountStorage(n *accountNode, hex []byte, hashNodeFunc HashNodeFunc) error { +func (b *WitnessBuilder) processAccountStorage(n *accountNode, hex []byte, limiter *MerklePathLimiter) error { if n.IsEmptyRoot() && n.IsEmptyCodeHash() { return nil } @@ -172,17 +181,17 @@ func (b *WitnessBuilder) processAccountStorage(n *accountNode, hex []byte, hashN } // Here we substitute rs parameter for storageRs, because it needs to become the default - return b.makeBlockWitness(n.storage, hex, hashNodeFunc, true) + return b.makeBlockWitness(n.storage, hex, limiter, true) } func (b *WitnessBuilder) makeBlockWitness( - nd node, hex []byte, hashNodeFunc HashNodeFunc, force bool) error { + nd node, hex []byte, limiter *MerklePathLimiter, force bool) error { processAccountNode := func(key []byte, storageKey []byte, n *accountNode) error { if err := b.processAccountCode(n); err != nil { return err } - if err := b.processAccountStorage(n, storageKey, hashNodeFunc); err != nil { + if err := b.processAccountStorage(n, storageKey, limiter); err != nil { return err } return b.addAccountLeafOp(key, n) @@ -208,41 +217,49 @@ func (b *WitnessBuilder) makeBlockWitness( case *accountNode: return processAccountNode(n.Key, hexVal, v) default: - if err := b.makeBlockWitness(n.Val, hexVal, hashNodeFunc, false); err != nil { + if err := b.makeBlockWitness(n.Val, hexVal, limiter, false); err != nil { return err } return b.addExtensionOp(n.Key) } case *duoNode: - hashOnly := b.rs.HashOnly(hex) // Save this because rs can move on to other keys during the recursive invocation + hashOnly := limiter != nil && limiter.HashOnly.HashOnly(hex) // Save this because rs can move on to other keys during the recursive invocation if b.trace { - fmt.Printf("b.rs.HashOnly(%x) -> %v\n", hex, hashOnly) + fmt.Printf("b.hashOnly.HashOnly(%x) -> %v\n", hex, hashOnly) } if hashOnly { - return b.addHashOp(n, force, hashNodeFunc) + hn, err := b.makeHashNode(n, force, limiter.HashFunc) + if err != nil { + return err + } + return b.addHashOp(hn) } i1, i2 := n.childrenIdx() - if err := b.makeBlockWitness(n.child1, expandKeyHex(hex, i1), hashNodeFunc, false); err != nil { + if err := b.makeBlockWitness(n.child1, expandKeyHex(hex, i1), limiter, false); err != nil { return err } - if err := b.makeBlockWitness(n.child2, expandKeyHex(hex, i2), hashNodeFunc, false); err != nil { + if err := b.makeBlockWitness(n.child2, expandKeyHex(hex, i2), limiter, false); err != nil { return err } return b.addBranchOp(n.mask) case *fullNode: - hashOnly := b.rs.HashOnly(hex) // Save this because rs can move on to other keys during the recursive invocation + hashOnly := limiter != nil && limiter.HashOnly.HashOnly(hex) // Save this because rs can move on to other keys during the recursive invocation if hashOnly { - return b.addHashOp(n, force, hashNodeFunc) + hn, err := b.makeHashNode(n, force, limiter.HashFunc) + if err != nil { + return err + } + return b.addHashOp(hn) } var mask uint32 for i, child := range n.Children { if child != nil { - if err := b.makeBlockWitness(child, expandKeyHex(hex, byte(i)), hashNodeFunc, false); err != nil { + if err := b.makeBlockWitness(child, expandKeyHex(hex, byte(i)), limiter, false); err != nil { return err } mask |= (uint32(1) << uint(i)) @@ -251,16 +268,16 @@ func (b *WitnessBuilder) makeBlockWitness( return b.addBranchOp(mask) case hashNode: - hashOnly := b.rs.HashOnly(hex) + hashOnly := limiter == nil || limiter.HashOnly.HashOnly(hex) if !hashOnly { - if c := b.rs.Current(); len(c) == len(hex)+1 && c[len(c)-1] == 16 { + if c := limiter.HashOnly.Current(); len(c) == len(hex)+1 && c[len(c)-1] == 16 { hashOnly = true } } if hashOnly { - return b.addHashOp(n, force, hashNodeFunc) + return b.addHashOp(n) } - return fmt.Errorf("unexpected hashNode: %s, at hex: %x, rs.Current: %x (%d)", n, hex, b.rs.Current(), len(hex)) + return fmt.Errorf("unexpected hashNode: %s, at hex: %x, rs.Current: %x (%d)", n, hex, limiter.HashOnly.Current(), len(hex)) default: return fmt.Errorf("unexpected node type: %T", nd) } diff --git a/trie/witness_builder_test.go b/trie/witness_builder_test.go index 699bbb3e5768cba7cee7a50ae3a9513609c166f1..9db850b3e912e10e1fc9ad6c59023706c4458c22 100644 --- a/trie/witness_builder_test.go +++ b/trie/witness_builder_test.go @@ -18,14 +18,14 @@ func TestBlockWitnessBinary(t *testing.T) { rs := NewBinaryResolveSet(2) rs.AddKey([]byte("ABCD0001")) - bwb := NewWitnessBuilder(trBin.Trie(), 1, false, rs, nil) + bwb := NewWitnessBuilder(trBin.Trie().root, 1, false, nil) hr := newHasher(false) defer returnHasherToPool(hr) var w *Witness var err error - if w, err = bwb.Build(hr.hash); err != nil { + if w, err = bwb.Build(&MerklePathLimiter{rs, hr.hash}); err != nil { t.Errorf("Could not make block witness: %v", err) } @@ -57,14 +57,14 @@ func TestBlockWitnessBinaryAccount(t *testing.T) { rs := NewBinaryResolveSet(2) rs.AddKey([]byte("ABCD0001")) - bwb := NewWitnessBuilder(trBin.Trie(), 1, false, rs, nil) + bwb := NewWitnessBuilder(trBin.Trie().root, 1, false, nil) hr := newHasher(false) defer returnHasherToPool(hr) var w *Witness var err error - if w, err = bwb.Build(hr.hash); err != nil { + if w, err = bwb.Build(&MerklePathLimiter{rs, hr.hash}); err != nil { t.Errorf("Could not make block witness: %v", err) } diff --git a/trie/witness_operators.go b/trie/witness_operators.go index b150974af0cd76d8d2aa0ccf32a676f3b4268cdf..9deab6d006052d4f6a12786e9171a70811fdbf75 100644 --- a/trie/witness_operators.go +++ b/trie/witness_operators.go @@ -34,6 +34,9 @@ const ( OpAccountLeaf // OpEmptyRoot places nil onto the node stack, and empty root hash onto the hash stack. OpEmptyRoot + + // OpNewTrie stops the processing, because another trie is encoded into the witness. + OpNewTrie = OperatorKindCode(0xBB) ) // WitnessOperator is a single operand in the block witness. It knows how to serialize/deserialize itself.