From 1e1865b73f6b0d2fef656d2f37e20e41d13a5ef0 Mon Sep 17 00:00:00 2001
From: Martin Holst Swende <martin@swende.se>
Date: Wed, 5 Feb 2020 13:12:09 +0100
Subject: [PATCH] core: implement background trie prefetcher

Squashed from the following commits:

core/state: lazily init snapshot storage map
core/state: fix flawed meter on storage reads
core/state: make statedb/stateobjects reuse a hasher
core/blockchain, core/state: implement new trie prefetcher
core: make trie prefetcher deliver tries to statedb
core/state: refactor trie_prefetcher, export storage tries
blockchain: re-enable the next-block-prefetcher
state: remove panics in trie prefetcher
core/state/trie_prefetcher: address some review concerns

sq
---
 core/blockchain.go            |  27 +++-
 core/state/database.go        |  12 +-
 core/state/state_object.go    |  72 +++++++---
 core/state/statedb.go         |  45 +++++-
 core/state/trie_prefetcher.go | 249 ++++++++++++++++++++++++++++++++++
 crypto/crypto.go              |  17 ++-
 crypto/crypto_test.go         |   7 +
 7 files changed, 395 insertions(+), 34 deletions(-)
 create mode 100644 core/state/trie_prefetcher.go

diff --git a/core/blockchain.go b/core/blockchain.go
index b8f483b85..ccb99bded 100644
--- a/core/blockchain.go
+++ b/core/blockchain.go
@@ -201,11 +201,12 @@ type BlockChain struct {
 	running       int32          // 0 if chain is running, 1 when stopped
 	procInterrupt int32          // interrupt signaler for block processing
 
-	engine     consensus.Engine
-	validator  Validator  // Block and state validator interface
-	prefetcher Prefetcher // Block state prefetcher interface
-	processor  Processor  // Block transaction processor interface
-	vmConfig   vm.Config
+	engine         consensus.Engine
+	validator      Validator             // Block and state validator interface
+	triePrefetcher *state.TriePrefetcher // Trie prefetcher interface
+	prefetcher     Prefetcher
+	processor      Processor // Block transaction processor interface
+	vmConfig       vm.Config
 
 	shouldPreserve     func(*types.Block) bool        // Function used to determine whether should preserve the given block.
 	terminateInsert    func(common.Hash, uint64) bool // Testing hook used to terminate ancient receipt chain insertion.
@@ -249,6 +250,15 @@ func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *par
 	}
 	bc.validator = NewBlockValidator(chainConfig, bc, engine)
 	bc.prefetcher = newStatePrefetcher(chainConfig, bc, engine)
+	tp := state.NewTriePrefetcher(bc.stateCache)
+
+	bc.wg.Add(1)
+	go func() {
+		tp.Loop()
+		bc.wg.Done()
+	}()
+	bc.triePrefetcher = tp
+
 	bc.processor = NewStateProcessor(chainConfig, bc, engine)
 
 	var err error
@@ -991,6 +1001,9 @@ func (bc *BlockChain) Stop() {
 	bc.scope.Close()
 	close(bc.quit)
 	bc.StopInsert()
+	if bc.triePrefetcher != nil {
+		bc.triePrefetcher.Close()
+	}
 	bc.wg.Wait()
 
 	// Ensure that the entirety of the state snapshot is journalled to disk.
@@ -1857,6 +1870,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals bool) (int, er
 			parent = bc.GetHeader(block.ParentHash(), block.NumberU64()-1)
 		}
 		statedb, err := state.New(parent.Root, bc.stateCache, bc.snaps)
+		statedb.UsePrefetcher(bc.triePrefetcher)
 		if err != nil {
 			return it.index, err
 		}
@@ -1891,8 +1905,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks, verifySeals bool) (int, er
 		storageUpdateTimer.Update(statedb.StorageUpdates)             // Storage updates are complete, we can mark them
 		snapshotAccountReadTimer.Update(statedb.SnapshotAccountReads) // Account reads are complete, we can mark them
 		snapshotStorageReadTimer.Update(statedb.SnapshotStorageReads) // Storage reads are complete, we can mark them
-
-		triehash := statedb.AccountHashes + statedb.StorageHashes // Save to not double count in validation
+		triehash := statedb.AccountHashes + statedb.StorageHashes     // Save to not double count in validation
 		trieproc := statedb.SnapshotAccountReads + statedb.AccountReads + statedb.AccountUpdates
 		trieproc += statedb.SnapshotStorageReads + statedb.StorageReads + statedb.StorageUpdates
 
diff --git a/core/state/database.go b/core/state/database.go
index 83f7b2a83..1a06e3340 100644
--- a/core/state/database.go
+++ b/core/state/database.go
@@ -129,12 +129,20 @@ type cachingDB struct {
 
 // OpenTrie opens the main account trie at a specific root hash.
 func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) {
-	return trie.NewSecure(root, db.db)
+	tr, err := trie.NewSecure(root, db.db)
+	if err != nil {
+		return nil, err // explicit nil; presumably NewSecure yields a concrete pointer that would otherwise wrap into a non-nil Trie — TODO confirm
+	}
+	return tr, nil
 }
 
 // OpenStorageTrie opens the storage trie of an account.
 func (db *cachingDB) OpenStorageTrie(addrHash, root common.Hash) (Trie, error) {
-	return trie.NewSecure(root, db.db)
+	tr, err := trie.NewSecure(root, db.db) // addrHash is not used here, the trie is opened by root alone
+	if err != nil {
+		return nil, err // explicit nil; presumably NewSecure yields a concrete pointer that would otherwise wrap into a non-nil Trie — TODO confirm
+	}
+	return tr, nil
 }
 
 // CopyTrie returns an independent copy of the given trie.
diff --git a/core/state/state_object.go b/core/state/state_object.go
index d0d3b4513..43c5074d9 100644
--- a/core/state/state_object.go
+++ b/core/state/state_object.go
@@ -157,11 +157,20 @@ func (s *stateObject) touch() {
 
 func (s *stateObject) getTrie(db Database) Trie {
 	if s.trie == nil {
-		var err error
-		s.trie, err = db.OpenStorageTrie(s.addrHash, s.data.Root)
-		if err != nil {
-			s.trie, _ = db.OpenStorageTrie(s.addrHash, common.Hash{})
-			s.setError(fmt.Errorf("can't create storage trie: %v", err))
+		// Try fetching from prefetcher first
+		// We don't prefetch empty tries
+		if s.data.Root != emptyRoot && s.db.prefetcher != nil {
+			// When the miner is creating the pending state, there is no
+			// prefetcher
+			s.trie = s.db.prefetcher.GetTrie(s.data.Root)
+		}
+		if s.trie == nil {
+			var err error
+			s.trie, err = db.OpenStorageTrie(s.addrHash, s.data.Root)
+			if err != nil {
+				s.trie, _ = db.OpenStorageTrie(s.addrHash, common.Hash{})
+				s.setError(fmt.Errorf("can't create storage trie: %v", err))
+			}
 		}
 	}
 	return s.trie
@@ -197,12 +206,24 @@ func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Has
 	}
 	// If no live objects are available, attempt to use snapshots
 	var (
-		enc []byte
-		err error
+		enc   []byte
+		err   error
+		meter *time.Duration
 	)
+	readStart := time.Now()
+	if metrics.EnabledExpensive {
+		// If the snap is 'under construction', the first lookup may fail. If that
+		// happens, we don't want to double-count the time elapsed. Thus this
+		// dance with the metering.
+		defer func() {
+			if meter != nil {
+				*meter += time.Since(readStart)
+			}
+		}()
+	}
 	if s.db.snap != nil {
 		if metrics.EnabledExpensive {
-			defer func(start time.Time) { s.db.SnapshotStorageReads += time.Since(start) }(time.Now())
+			meter = &s.db.SnapshotStorageReads
 		}
 		// If the object was destructed in *this* block (and potentially resurrected),
 		// the storage has been cleared out, and we should *not* consult the previous
@@ -217,8 +238,14 @@ func (s *stateObject) GetCommittedState(db Database, key common.Hash) common.Has
 	}
 	// If snapshot unavailable or reading from it failed, load from the database
 	if s.db.snap == nil || err != nil {
+		if meter != nil {
+			// If we already spent time checking the snapshot, account for it
+			// and reset the readStart
+			*meter += time.Since(readStart)
+			readStart = time.Now()
+		}
 		if metrics.EnabledExpensive {
-			defer func(start time.Time) { s.db.StorageReads += time.Since(start) }(time.Now())
+			meter = &s.db.StorageReads
 		}
 		if enc, err = s.getTrie(db).TryGet(key.Bytes()); err != nil {
 			s.setError(err)
@@ -283,8 +310,13 @@ func (s *stateObject) setState(key, value common.Hash) {
 // finalise moves all dirty storage slots into the pending area to be hashed or
 // committed later. It is invoked at the end of every transaction.
 func (s *stateObject) finalise() {
+	trieChanges := make([]common.Hash, 0, len(s.dirtyStorage))
 	for key, value := range s.dirtyStorage {
 		s.pendingStorage[key] = value
+		trieChanges = append(trieChanges, key)
+	}
+	if len(trieChanges) > 0 && s.db.prefetcher != nil && s.data.Root != emptyRoot {
+		s.db.prefetcher.PrefetchStorage(s.data.Root, trieChanges)
 	}
 	if len(s.dirtyStorage) > 0 {
 		s.dirtyStorage = make(Storage)
@@ -303,18 +335,11 @@ func (s *stateObject) updateTrie(db Database) Trie {
 	if metrics.EnabledExpensive {
 		defer func(start time.Time) { s.db.StorageUpdates += time.Since(start) }(time.Now())
 	}
-	// Retrieve the snapshot storage map for the object
+	// The snapshot storage map for the object
 	var storage map[common.Hash][]byte
-	if s.db.snap != nil {
-		// Retrieve the old storage map, if available, create a new one otherwise
-		storage = s.db.snapStorage[s.addrHash]
-		if storage == nil {
-			storage = make(map[common.Hash][]byte)
-			s.db.snapStorage[s.addrHash] = storage
-		}
-	}
 	// Insert all the pending updates into the trie
 	tr := s.getTrie(db)
+	hasher := s.db.hasher
 	for key, value := range s.pendingStorage {
 		// Skip noop changes, persist actual changes
 		if value == s.originStorage[key] {
@@ -331,8 +356,15 @@ func (s *stateObject) updateTrie(db Database) Trie {
 			s.setError(tr.TryUpdate(key[:], v))
 		}
 		// If state snapshotting is active, cache the data til commit
-		if storage != nil {
-			storage[crypto.Keccak256Hash(key[:])] = v // v will be nil if value is 0x00
+		if s.db.snap != nil {
+			if storage == nil {
+				// Retrieve the old storage map, if available, create a new one otherwise
+				if storage = s.db.snapStorage[s.addrHash]; storage == nil {
+					storage = make(map[common.Hash][]byte)
+					s.db.snapStorage[s.addrHash] = storage
+				}
+			}
+			storage[crypto.HashData(hasher, key[:])] = v // v will be nil if value is 0x00
 		}
 	}
 	if len(s.pendingStorage) > 0 {
diff --git a/core/state/statedb.go b/core/state/statedb.go
index a9d1de2e0..ce50962e8 100644
--- a/core/state/statedb.go
+++ b/core/state/statedb.go
@@ -62,8 +62,11 @@ func (n *proofList) Delete(key []byte) error {
 // * Contracts
 // * Accounts
 type StateDB struct {
-	db   Database
-	trie Trie
+	db           Database
+	prefetcher   *TriePrefetcher
+	originalRoot common.Hash // The pre-state root, before any changes were made
+	trie         Trie
+	hasher       crypto.KeccakState
 
 	snaps         *snapshot.Tree
 	snap          snapshot.Snapshot
@@ -125,6 +128,7 @@ func New(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error)
 	sdb := &StateDB{
 		db:                  db,
 		trie:                tr,
+		originalRoot:        root,
 		snaps:               snaps,
 		stateObjects:        make(map[common.Address]*stateObject),
 		stateObjectsPending: make(map[common.Address]struct{}),
@@ -133,6 +137,7 @@ func New(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error)
 		preimages:           make(map[common.Hash][]byte),
 		journal:             newJournal(),
 		accessList:          newAccessList(),
+		hasher:              crypto.NewKeccakState(),
 	}
 	if sdb.snaps != nil {
 		if sdb.snap = sdb.snaps.Snapshot(root); sdb.snap != nil {
@@ -144,6 +149,13 @@ func New(root common.Hash, db Database, snaps *snapshot.Tree) (*StateDB, error)
 	return sdb, nil
 }
 
+func (s *StateDB) UsePrefetcher(prefetcher *TriePrefetcher) {
+	if s != nil && prefetcher != nil { // s may be nil: insertChain calls this before checking the error from New
+		s.prefetcher = prefetcher
+		s.prefetcher.Resume(s.originalRoot)
+	}
+}
+
 // setError remembers the first non-nil error it is called with.
 func (s *StateDB) setError(err error) {
 	if s.dbErr == nil {
@@ -532,7 +544,7 @@ func (s *StateDB) getDeletedStateObject(addr common.Address) *stateObject {
 			defer func(start time.Time) { s.SnapshotAccountReads += time.Since(start) }(time.Now())
 		}
 		var acc *snapshot.Account
-		if acc, err = s.snap.Account(crypto.Keccak256Hash(addr.Bytes())); err == nil {
+		if acc, err = s.snap.Account(crypto.HashData(s.hasher, addr.Bytes())); err == nil {
 			if acc == nil {
 				return nil
 			}
@@ -675,6 +687,7 @@ func (s *StateDB) Copy() *StateDB {
 		logSize:             s.logSize,
 		preimages:           make(map[common.Hash][]byte, len(s.preimages)),
 		journal:             newJournal(),
+		hasher:              crypto.NewKeccakState(),
 	}
 	// Copy the dirty states, logs, and preimages
 	for addr := range s.journal.dirties {
@@ -760,6 +773,7 @@ func (s *StateDB) GetRefund() uint64 {
 // the journal as well as the refunds. Finalise, however, will not push any updates
 // into the tries just yet. Only IntermediateRoot or Commit will do that.
 func (s *StateDB) Finalise(deleteEmptyObjects bool) {
+	var addressesToPrefetch []common.Address
 	for addr := range s.journal.dirties {
 		obj, exist := s.stateObjects[addr]
 		if !exist {
@@ -788,7 +802,17 @@ func (s *StateDB) Finalise(deleteEmptyObjects bool) {
 		}
 		s.stateObjectsPending[addr] = struct{}{}
 		s.stateObjectsDirty[addr] = struct{}{}
+		// At this point, also ship the address off to the precacher. The precacher
+		// will start loading tries, and when the change is eventually committed,
+		// the commit-phase will be a lot faster
+		if s.prefetcher != nil {
+			addressesToPrefetch = append(addressesToPrefetch, addr)
+		}
+	}
+	if s.prefetcher != nil {
+		s.prefetcher.PrefetchAddresses(addressesToPrefetch)
 	}
+
 	// Invalidate journal because reverting across transactions is not allowed.
 	s.clearJournalAndRefund()
 }
@@ -800,6 +824,21 @@ func (s *StateDB) IntermediateRoot(deleteEmptyObjects bool) common.Hash {
 	// Finalise all the dirty storage states and write them into the tries
 	s.Finalise(deleteEmptyObjects)
 
+	// Now we're about to start to write changes to the trie. The trie is so
+	// far _untouched_. We can check with the prefetcher, if it can give us
+	// a trie which has the same root, but also has some content loaded into it.
+	// If so, use that one instead.
+	if s.prefetcher != nil {
+		s.prefetcher.Pause()
+		// We only want to do this _once_, if someone calls IntermediateRoot again,
+		// we shouldn't fetch the trie again
+		if s.originalRoot != (common.Hash{}) {
+			if trie := s.prefetcher.GetTrie(s.originalRoot); trie != nil {
+				s.trie = trie
+			}
+			s.originalRoot = common.Hash{}
+		}
+	}
 	for addr := range s.stateObjectsPending {
 		obj := s.stateObjects[addr]
 		if obj.deleted {
diff --git a/core/state/trie_prefetcher.go b/core/state/trie_prefetcher.go
new file mode 100644
index 000000000..8a1aab325
--- /dev/null
+++ b/core/state/trie_prefetcher.go
@@ -0,0 +1,249 @@
+// Copyright 2020 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package state
+
+import (
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/log"
+	"github.com/ethereum/go-ethereum/metrics"
+)
+
+var (
+	// trieDeliveryMissMeter counts how many times the prefetcher was unable to supply
+	// the statedb with a prefilled trie. This meter should be zero -- if it's not, that
+	// needs to be investigated
+	trieDeliveryMissMeter = metrics.NewRegisteredMeter("trie/prefetch/deliverymiss", nil)
+
+	triePrefetchFetchMeter = metrics.NewRegisteredMeter("trie/prefetch/fetch", nil) // Slots/addresses looked up by Loop
+	triePrefetchSkipMeter  = metrics.NewRegisteredMeter("trie/prefetch/skip", nil)  // Slots/addresses Loop discarded without a lookup
+	triePrefetchDropMeter  = metrics.NewRegisteredMeter("trie/prefetch/drop", nil)  // Slots/addresses rejected because requestCh was full
+)
+
+// TriePrefetcher is an active prefetcher, which receives batches of accounts or
+// storage slots on a request channel, and does trie-loading of the items.
+// The goal is to get as much useful content into the caches as possible
+type TriePrefetcher struct {
+	requestCh  chan (fetchRequest) // Chan to receive requests for data to fetch
+	cmdCh      chan (*cmd)         // Chan to control activity, pause/new root
+	quitCh     chan (struct{})     // Closed by Close, makes Loop return
+	deliveryCh chan (struct{})     // Loop signals here once a command is handled
+	db         Database            // Used by Loop to open the tries
+
+	paused bool // Caller-side state kept by Resume/Pause; Loop tracks its own flag
+
+	storageTries    map[common.Hash]Trie // Storage tries by root, published by Loop on pause
+	accountTrie     Trie                 // Account trie, published by Loop on pause
+	accountTrieRoot common.Hash          // Root of accountTrie, published by Loop on pause
+}
+
+func NewTriePrefetcher(db Database) *TriePrefetcher {
+	return &TriePrefetcher{
+		requestCh:  make(chan fetchRequest, 200), // Buffered: up to 200 requests queue up before senders start dropping
+		cmdCh:      make(chan *cmd),              // Unbuffered: Resume/Pause hand over to Loop synchronously
+		quitCh:     make(chan struct{}),          // Never sent on, only closed
+		deliveryCh: make(chan struct{}),          // Unbuffered: Loop blocks until the caller picks up the signal
+		db:         db,
+	}
+}
+
+type cmd struct {
+	root common.Hash // Account trie root sent by Resume; the zero hash when sent by Pause
+}
+
+type fetchRequest struct {
+	slots       []common.Hash    // Storage keys to look up in the trie at storageRoot
+	storageRoot *common.Hash     // Non-nil marks a storage request, nil an account request
+	addresses   []common.Address // Addresses to look up in the account trie
+}
+
+// Loop is the prefetcher's event loop: it serves Resume/Pause commands and
+// fetch requests until the quit channel is closed. Run it in its own goroutine.
+func (p *TriePrefetcher) Loop() {
+	var (
+		accountTrieRoot common.Hash
+		accountTrie     Trie
+		storageTries    map[common.Hash]Trie
+
+		err error
+		// Some tracking of performance
+		skipped int64
+		fetched int64
+
+		paused = true
+	)
+	// The prefetcher loop has two distinct phases:
+	// 1: Paused: when in this state, the accumulated tries are accessible to outside
+	// callers.
+	// 2: Active prefetching, awaiting slots and accounts to prefetch
+	for {
+		select {
+		case <-p.quitCh:
+			return
+		case cmd := <-p.cmdCh:
+			// Clear out any old requests
+		drain:
+			for {
+				select {
+				case req := <-p.requestCh:
+					skipped += int64(len(req.slots) + len(req.addresses)) // only one of the two is populated
+				default:
+					break drain
+				}
+			}
+			if paused {
+				// Clear old data
+				p.storageTries = nil
+				p.accountTrie = nil
+				p.accountTrieRoot = common.Hash{}
+				// Resume again
+				storageTries = make(map[common.Hash]Trie)
+				accountTrieRoot = cmd.root
+				accountTrie, err = p.db.OpenTrie(accountTrieRoot)
+				if err != nil {
+					log.Error("Trie prefetcher failed opening trie", "root", accountTrieRoot, "err", err)
+				}
+				if accountTrieRoot == (common.Hash{}) {
+					log.Error("Trie prefetcher unpaused with bad root")
+				}
+				paused = false
+			} else {
+				// Update metrics at new block events
+				triePrefetchFetchMeter.Mark(fetched)
+				triePrefetchSkipMeter.Mark(skipped)
+				fetched, skipped = 0, 0
+				// Make the tries accessible
+				p.accountTrie = accountTrie
+				p.storageTries = storageTries
+				p.accountTrieRoot = accountTrieRoot
+				if cmd.root != (common.Hash{}) {
+					log.Error("Trie prefetcher paused with non-empty root")
+				}
+				paused = true
+			}
+			p.deliveryCh <- struct{}{}
+		case req := <-p.requestCh:
+			if paused {
+				continue
+			}
+			if sRoot := req.storageRoot; sRoot != nil {
+				// Storage slots to fetch
+				var (
+					storageTrie Trie
+					err         error
+				)
+				if storageTrie = storageTries[*sRoot]; storageTrie == nil {
+					if storageTrie, err = p.db.OpenTrie(*sRoot); err != nil {
+						log.Warn("trie prefetcher failed opening storage trie", "root", *sRoot, "err", err)
+						skipped += int64(len(req.slots))
+						continue
+					}
+					storageTries[*sRoot] = storageTrie
+				}
+				for _, key := range req.slots {
+					storageTrie.TryGet(key[:])
+				}
+				fetched += int64(len(req.slots))
+			} else if accountTrie == nil { // opening the account trie failed on resume, there is nothing to load into
+				skipped += int64(len(req.addresses))
+			} else { // an account
+				for _, addr := range req.addresses {
+					accountTrie.TryGet(addr[:])
+				}
+				fetched += int64(len(req.addresses))
+			}
+		}
+	}
+}
+
+// Close stops the prefetcher. Calling it again is a no-op.
+func (p *TriePrefetcher) Close() {
+	select {
+	case <-p.quitCh: // Already closed
+	default:
+		close(p.quitCh) // The field is left set, as Loop, Resume and Pause select on it
+	}
+}
+
+// Resume makes the prefetcher drop old data and start fetching under root.
+func (p *TriePrefetcher) Resume(root common.Hash) {
+	p.paused = false
+	select {
+	case p.cmdCh <- &cmd{root: root}:
+		<-p.deliveryCh // Wait for it
+	case <-p.quitCh: // Closed: Loop won't answer, so drop the old data ourselves
+		p.storageTries, p.accountTrie, p.accountTrieRoot = nil, nil, common.Hash{}
+	}
+}
+
+// Pause stops prefetching and makes the tries accessible via GetTrie.
+func (p *TriePrefetcher) Pause() {
+	if p.paused {
+		return
+	}
+	p.paused = true
+	select {
+	case p.cmdCh <- &cmd{}: // The zero root signals a pause
+		<-p.deliveryCh // Wait for it
+	case <-p.quitCh: // Closed: Loop won't answer and nothing gets published
+	}
+}
+
+// PrefetchAddresses queues a batch of account addresses for prefetching. The
+// send never blocks: if the request queue is full, the whole batch is dropped
+// and its size is recorded on the drop meter.
+func (p *TriePrefetcher) PrefetchAddresses(addresses []common.Address) {
+	req := fetchRequest{addresses: addresses}
+	// The send is asynchronous so that the caller is never held up; a batch
+	// that does not fit into the queue is counted and forgotten.
+	select {
+	case p.requestCh <- req:
+	default:
+		triePrefetchDropMeter.Mark(int64(len(addresses)))
+	}
+}
+
+// PrefetchStorage queues a set of storage slot keys for prefetching from the
+// storage trie with the given root. The send never blocks: if the request
+// queue is full, the whole set is dropped and its size is recorded on the
+// drop meter.
+func (p *TriePrefetcher) PrefetchStorage(root common.Hash, slots []common.Hash) {
+	req := fetchRequest{storageRoot: &root, slots: slots}
+	// The send is asynchronous so that the caller is never held up; a set
+	// that does not fit into the queue is counted and forgotten.
+	select {
+	case p.requestCh <- req:
+	default:
+		triePrefetchDropMeter.Mark(int64(len(slots)))
+	}
+}
+
+// GetTrie returns the published trie matching the root hash, or nil if the
+// prefetcher doesn't have it. A storage trie is only ever handed out once.
+func (p *TriePrefetcher) GetTrie(root common.Hash) Trie {
+	if root == p.accountTrieRoot {
+		return p.accountTrie // Unlike storage tries, not removed: repeated calls get the same instance
+	}
+	if storageTrie, ok := p.storageTries[root]; ok {
+		// Two accounts may well have the same storage root, but we cannot allow
+		// them both to make updates to the same trie instance. Therefore,
+		// we need to either delete the trie now, or deliver a copy of the trie.
+		delete(p.storageTries, root)
+		return storageTrie
+	}
+	trieDeliveryMissMeter.Mark(1) // Neither the account trie nor a storage trie matched
+	return nil
+}
diff --git a/crypto/crypto.go b/crypto/crypto.go
index a4a49136a..40969a289 100644
--- a/crypto/crypto.go
+++ b/crypto/crypto.go
@@ -60,10 +60,23 @@ type KeccakState interface {
 	Read([]byte) (int, error)
 }
 
+// NewKeccakState creates a new KeccakState backed by sha3.NewLegacyKeccak256.
+func NewKeccakState() KeccakState {
+	return sha3.NewLegacyKeccak256().(KeccakState)
+}
+
+// HashData hashes the provided data using the KeccakState and returns a 32 byte hash
+func HashData(kh KeccakState, data []byte) (h common.Hash) {
+	kh.Reset()     // start from a clean state, so one hasher can be reused across calls
+	kh.Write(data) // the returned error is ignored
+	kh.Read(h[:])  // reads the digest straight into the named result; the error is ignored
+	return h
+}
+
 // Keccak256 calculates and returns the Keccak256 hash of the input data.
 func Keccak256(data ...[]byte) []byte {
 	b := make([]byte, 32)
-	d := sha3.NewLegacyKeccak256().(KeccakState)
+	d := NewKeccakState()
 	for _, b := range data {
 		d.Write(b)
 	}
@@ -74,7 +87,7 @@ func Keccak256(data ...[]byte) []byte {
 // Keccak256Hash calculates and returns the Keccak256 hash of the input data,
 // converting it to an internal Hash data structure.
 func Keccak256Hash(data ...[]byte) (h common.Hash) {
-	d := sha3.NewLegacyKeccak256().(KeccakState)
+	d := NewKeccakState()
 	for _, b := range data {
 		d.Write(b)
 	}
diff --git a/crypto/crypto_test.go b/crypto/crypto_test.go
index f71ae8232..f9b0d3e83 100644
--- a/crypto/crypto_test.go
+++ b/crypto/crypto_test.go
@@ -42,6 +42,13 @@ func TestKeccak256Hash(t *testing.T) {
 	checkhash(t, "Sha3-256-array", func(in []byte) []byte { h := Keccak256Hash(in); return h[:] }, msg, exp)
 }
 
+func TestKeccak256Hasher(t *testing.T) {
+	msg := []byte("abc")
+	exp, _ := hex.DecodeString("4e03657aea45a94fc7d47ba826c8d667c0d1e6e33a64a036ec44f58fa12d6c45")
+	hasher := NewKeccakState() // one instance is enough: HashData resets it before every use
+	checkhash(t, "Sha3-256-array", func(in []byte) []byte { h := HashData(hasher, in); return h[:] }, msg, exp)
+}
+
 func TestToECDSAErrors(t *testing.T) {
 	if _, err := HexToECDSA("0000000000000000000000000000000000000000000000000000000000000000"); err == nil {
 		t.Fatal("HexToECDSA should've returned error")
-- 
GitLab