From a1313b5b1e9713e7e10f5ef29b7981c9c49d688b Mon Sep 17 00:00:00 2001
From: Martin Holst Swende <martin@swende.se>
Date: Tue, 4 Feb 2020 13:02:38 +0100
Subject: [PATCH] trie: make hasher parallel when number of changes are large
 (#20488)

* trie: make hasher parallel when number of changes are large

* trie: remove unused field dirtyCount

* trie: rename unhashedCount/unhashed
---
 trie/hasher.go      | 37 ++++++++++++++++++++++++++++---------
 trie/iterator.go    |  2 +-
 trie/proof.go       |  2 +-
 trie/secure_trie.go |  2 +-
 trie/trie.go        | 14 +++++++++++---
 5 files changed, 42 insertions(+), 15 deletions(-)

diff --git a/trie/hasher.go b/trie/hasher.go
index 71a3aec3b..8e8eec9f6 100644
--- a/trie/hasher.go
+++ b/trie/hasher.go
@@ -46,8 +46,9 @@ func (b *sliceBuffer) Reset() {
 // hasher is a type used for the trie Hash operation. A hasher has some
 // internal preallocated temp space
 type hasher struct {
-	sha keccakState
-	tmp sliceBuffer
+	sha      keccakState
+	tmp      sliceBuffer
+	parallel bool // Whether to use paralallel threads when hashing
 }
 
 // hasherPool holds pureHashers
@@ -60,8 +61,9 @@ var hasherPool = sync.Pool{
 	},
 }
 
-func newHasher() *hasher {
+func newHasher(parallel bool) *hasher {
 	h := hasherPool.Get().(*hasher)
+	h.parallel = parallel
 	return h
 }
 
@@ -126,14 +128,31 @@ func (h *hasher) hashFullNodeChildren(n *fullNode) (collapsed *fullNode, cached
 	// Hash the full node's children, caching the newly hashed subtrees
 	cached = n.copy()
 	collapsed = n.copy()
-	for i := 0; i < 16; i++ {
-		if child := n.Children[i]; child != nil {
-			collapsed.Children[i], cached.Children[i] = h.hash(child, false)
-		} else {
-			collapsed.Children[i] = nilValueNode
+	if h.parallel {
+		var wg sync.WaitGroup
+		wg.Add(16)
+		for i := 0; i < 16; i++ {
+			go func(i int) {
+				hasher := newHasher(false)
+				if child := n.Children[i]; child != nil {
+					collapsed.Children[i], cached.Children[i] = hasher.hash(child, false)
+				} else {
+					collapsed.Children[i] = nilValueNode
+				}
+				returnHasherToPool(hasher)
+				wg.Done()
+			}(i)
+		}
+		wg.Wait()
+	} else {
+		for i := 0; i < 16; i++ {
+			if child := n.Children[i]; child != nil {
+				collapsed.Children[i], cached.Children[i] = h.hash(child, false)
+			} else {
+				collapsed.Children[i] = nilValueNode
+			}
 		}
 	}
-	cached.Children[16] = n.Children[16]
 	return collapsed, cached
 }
 
diff --git a/trie/iterator.go b/trie/iterator.go
index 94b36a018..bb4025d8f 100644
--- a/trie/iterator.go
+++ b/trie/iterator.go
@@ -182,7 +182,7 @@ func (it *nodeIterator) LeafBlob() []byte {
 func (it *nodeIterator) LeafProof() [][]byte {
 	if len(it.stack) > 0 {
 		if _, ok := it.stack[len(it.stack)-1].node.(valueNode); ok {
-			hasher := newHasher()
+			hasher := newHasher(false)
 			defer returnHasherToPool(hasher)
 			proofs := make([][]byte, 0, len(it.stack))
 
diff --git a/trie/proof.go b/trie/proof.go
index f2c4658c4..58ca69c68 100644
--- a/trie/proof.go
+++ b/trie/proof.go
@@ -64,7 +64,7 @@ func (t *Trie) Prove(key []byte, fromLevel uint, proofDb ethdb.KeyValueWriter) e
 			panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
 		}
 	}
-	hasher := newHasher()
+	hasher := newHasher(false)
 	defer returnHasherToPool(hasher)
 
 	for i, n := range nodes {
diff --git a/trie/secure_trie.go b/trie/secure_trie.go
index b76a1dc8a..955771495 100644
--- a/trie/secure_trie.go
+++ b/trie/secure_trie.go
@@ -176,7 +176,7 @@ func (t *SecureTrie) NodeIterator(start []byte) NodeIterator {
 // The caller must not hold onto the return value because it will become
 // invalid on the next call to hashKey or secKey.
 func (t *SecureTrie) hashKey(key []byte) []byte {
-	h := newHasher()
+	h := newHasher(false)
 	h.sha.Reset()
 	h.sha.Write(key)
 	buf := h.sha.Sum(t.hashKeyBuf[:0])
diff --git a/trie/trie.go b/trie/trie.go
index dd26f9b34..78e2eff53 100644
--- a/trie/trie.go
+++ b/trie/trie.go
@@ -48,6 +48,10 @@ type LeafCallback func(leaf []byte, parent common.Hash) error
 type Trie struct {
 	db   *Database
 	root node
+	// Keep track of the number leafs which have been inserted since the last
+	// hashing operation. This number will not directly map to the number of
+	// actually unhashed nodes
+	unhashed int
 }
 
 // newFlag returns the cache flag value for a newly created node.
@@ -163,6 +167,7 @@ func (t *Trie) Update(key, value []byte) {
 //
 // If a node was not found in the database, a MissingNodeError is returned.
 func (t *Trie) TryUpdate(key, value []byte) error {
+	t.unhashed++
 	k := keybytesToHex(key)
 	if len(value) != 0 {
 		_, n, err := t.insert(t.root, nil, k, valueNode(value))
@@ -259,6 +264,7 @@ func (t *Trie) Delete(key []byte) {
 // TryDelete removes any existing value for key from the trie.
 // If a node was not found in the database, a MissingNodeError is returned.
 func (t *Trie) TryDelete(key []byte) error {
+	t.unhashed++
 	k := keybytesToHex(key)
 	_, n, err := t.delete(t.root, nil, k)
 	if err != nil {
@@ -405,7 +411,7 @@ func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) {
 // Hash returns the root hash of the trie. It does not write to the
 // database and can be used even if the trie doesn't have one.
 func (t *Trie) Hash() common.Hash {
-	hash, cached, _ := t.hashRoot(nil, nil)
+	hash, cached, _ := t.hashRoot(nil)
 	t.root = cached
 	return common.BytesToHash(hash.(hashNode))
 }
@@ -456,12 +462,14 @@ func (t *Trie) Commit(onleaf LeafCallback) (root common.Hash, err error) {
 }
 
 // hashRoot calculates the root hash of the given trie
-func (t *Trie) hashRoot(db *Database, onleaf LeafCallback) (node, node, error) {
+func (t *Trie) hashRoot(db *Database) (node, node, error) {
 	if t.root == nil {
 		return hashNode(emptyRoot.Bytes()), nil, nil
 	}
-	h := newHasher()
+	// If the number of changes is below 100, we let one thread handle it
+	h := newHasher(t.unhashed >= 100)
 	defer returnHasherToPool(h)
 	hashed, cached := h.hash(t.root, true)
+	t.unhashed = 0
 	return hashed, cached, nil
 }
-- 
GitLab