From 8d56bf5ceb74a7ed45c986450848a89e2df61189 Mon Sep 17 00:00:00 2001
From: Felix Lange <fjl@twurst.com>
Date: Mon, 17 Oct 2016 23:01:29 +0200
Subject: [PATCH] trie: ensure dirty flag is unset for embedded child nodes

This was caught by the new invariant check.
---
 trie/hasher.go    | 31 ++++++++++++++-----------------
 trie/trie_test.go | 43 ++++++++++++++++++++++++++++---------------
 2 files changed, 42 insertions(+), 32 deletions(-)

diff --git a/trie/hasher.go b/trie/hasher.go
index 57e156ebf..b6223bf32 100644
--- a/trie/hasher.go
+++ b/trie/hasher.go
@@ -75,23 +75,20 @@ func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error)
 	if err != nil {
 		return hashNode{}, n, err
 	}
-	// Cache the hash of the ndoe for later reuse.
-	if hash, ok := hashed.(hashNode); ok && !force {
-		switch cached := cached.(type) {
-		case *shortNode:
-			cached = cached.copy()
-			cached.flags.hash = hash
-			if db != nil {
-				cached.flags.dirty = false
-			}
-			return hashed, cached, nil
-		case *fullNode:
-			cached = cached.copy()
-			cached.flags.hash = hash
-			if db != nil {
-				cached.flags.dirty = false
-			}
-			return hashed, cached, nil
+	// Cache the hash of the node for later reuse and remove
+	// the dirty flag in commit mode. It's fine to assign these values directly
+	// without copying the node first because hashChildren copies it.
+	cachedHash, _ := hashed.(hashNode)
+	switch cn := cached.(type) {
+	case *shortNode:
+		cn.flags.hash = cachedHash
+		if db != nil {
+			cn.flags.dirty = false
+		}
+	case *fullNode:
+		cn.flags.hash = cachedHash
+		if db != nil {
+			cn.flags.dirty = false
 		}
 	}
 	return hashed, cached, nil
diff --git a/trie/trie_test.go b/trie/trie_test.go
index da0d2360b..14ac5a666 100644
--- a/trie/trie_test.go
+++ b/trie/trie_test.go
@@ -462,31 +462,44 @@ func runRandTest(rt randTest) bool {
 				return false
 			}
 		case opCheckCacheInvariant:
-			return checkCacheInvariant(tr.root, tr.cachegen, 0)
+			return checkCacheInvariant(tr.root, nil, tr.cachegen, false, 0)
 		}
 	}
 	return true
 }
 
-func checkCacheInvariant(n node, parentCachegen uint16, depth int) bool {
+func checkCacheInvariant(n, parent node, parentCachegen uint16, parentDirty bool, depth int) bool {
+	var children []node
+	var flag nodeFlag
 	switch n := n.(type) {
 	case *shortNode:
-		if n.flags.gen > parentCachegen {
-			fmt.Printf("cache invariant violation: %d > %d\nat depth %d node %s", n.flags.gen, parentCachegen, depth, spew.Sdump(n))
-			return false
-		}
-		return checkCacheInvariant(n.Val, n.flags.gen, depth+1)
+		flag = n.flags
+		children = []node{n.Val}
 	case *fullNode:
-		if n.flags.gen > parentCachegen {
-			fmt.Printf("cache invariant violation: %d > %d\nat depth %d node %s", n.flags.gen, parentCachegen, depth, spew.Sdump(n))
+		flag = n.flags
+		children = n.Children[:]
+	default:
+		return true
+	}
+
+	showerror := func() {
+		fmt.Printf("at depth %d node %s", depth, spew.Sdump(n))
+		fmt.Printf("parent: %s", spew.Sdump(parent))
+	}
+	if flag.gen > parentCachegen {
+		fmt.Printf("cache invariant violation: %d > %d\n", flag.gen, parentCachegen)
+		showerror()
+		return false
+	}
+	if depth > 0 && !parentDirty && flag.dirty {
+		fmt.Printf("cache invariant violation: child is dirty but parent isn't\n")
+		showerror()
+		return false
+	}
+	for _, child := range children {
+		if !checkCacheInvariant(child, n, flag.gen, flag.dirty, depth+1) {
 			return false
 		}
-		for _, child := range n.Children {
-			if !checkCacheInvariant(child, n.flags.gen, depth+1) {
-				return false
-			}
-		}
-		return true
 	}
 	return true
 }
-- 
GitLab