From 9ec3329899a0ff62ed2f83c61b50140881a577a8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?P=C3=A9ter=20Szil=C3=A1gyi?= <peterke@gmail.com>
Date: Tue, 16 Feb 2021 09:04:07 +0200
Subject: [PATCH] core/state/snapshot: ensure Cap retains a min number of
 layers

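Previously Cap counted the bottom-most accumulator diff layer towards
the requested layer count, so a Cap(root, n) call could leave fewer
than n diff layers alive whenever the accumulator overflowed and was
merged into the disk layer. The capping loop now omits the accumulator
from the count, guaranteeing that at least n diff layers survive the
call; one extra layer (the accumulator) survives on top of that as
long as it stays below the aggregator memory limit. The layers == 1
special case in Cap is folded into the common path.

A rough sketch of the guarantee from a caller's perspective; snaps,
makeRoot and setAccount below mirror the helpers used in the updated
tests and are illustrative only:

    // Pile diff layers on top of the disk layer, capping to 128
    // diff layers after every update.
    last, head := base.root, common.Hash{}
    for i := 0; i < 200; i++ {
        head = makeRoot(uint64(i + 2))
        snaps.Update(head, last, nil, setAccount(fmt.Sprintf("%d", i+2)), nil)
        last = head
        snaps.Cap(head, 128)
    }
    // At least 128 diff layers remain: usually 128 diffs + the live
    // accumulator + the disk layer (130 in total); if the accumulator
    // overflowed, exactly 128 diffs + the disk layer.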
---
 core/state/snapshot/snapshot.go      |  47 ++++--------
 core/state/snapshot/snapshot_test.go | 111 ++++++++++++---------------
 2 files changed, 65 insertions(+), 93 deletions(-)

diff --git a/core/state/snapshot/snapshot.go b/core/state/snapshot/snapshot.go
index df2b1ed33..aa5f5900b 100644
--- a/core/state/snapshot/snapshot.go
+++ b/core/state/snapshot/snapshot.go
@@ -300,6 +300,12 @@ func (t *Tree) Update(blockRoot common.Hash, parentRoot common.Hash, destructs m
 // Cap traverses downwards the snapshot tree from a head block hash until the
 // number of allowed layers are crossed. All layers beyond the permitted number
 // are flattened downwards.
+//
+// Note, the final diff layer count in general will be one more than the amount
+// requested. This happens because the bottom-most diff layer is the accumulator
+// which may or may not overflow and cascade to disk. Since this last layer's
+// survival is only known *after* capping, we need to omit it from the count if
+// we want to ensure that *at least* the requested number of diff layers remain.
 func (t *Tree) Cap(root common.Hash, layers int) error {
 	// Retrieve the head snapshot to cap from
 	snap := t.Snapshot(root)
@@ -324,10 +330,7 @@ func (t *Tree) Cap(root common.Hash, layers int) error {
 	// Flattening the bottom-most diff layer requires special casing since there's
 	// no child to rewire to the grandparent. In that case we can fake a temporary
 	// child for the capping and then remove it.
-	var persisted *diskLayer
-
-	switch layers {
-	case 0:
+	if layers == 0 {
 		// If full commit was requested, flatten the diffs and merge onto disk
 		diff.lock.RLock()
 		base := diffToDisk(diff.flatten().(*diffLayer))
@@ -336,33 +339,9 @@ func (t *Tree) Cap(root common.Hash, layers int) error {
 		// Replace the entire snapshot tree with the flat base
 		t.layers = map[common.Hash]snapshot{base.root: base}
 		return nil
-
-	case 1:
-		// If full flattening was requested, flatten the diffs but only merge if the
-		// memory limit was reached
-		var (
-			bottom *diffLayer
-			base   *diskLayer
-		)
-		diff.lock.RLock()
-		bottom = diff.flatten().(*diffLayer)
-		if bottom.memory >= aggregatorMemoryLimit {
-			base = diffToDisk(bottom)
-		}
-		diff.lock.RUnlock()
-
-		// If all diff layers were removed, replace the entire snapshot tree
-		if base != nil {
-			t.layers = map[common.Hash]snapshot{base.root: base}
-			return nil
-		}
-		// Merge the new aggregated layer into the snapshot tree, clean stales below
-		t.layers[bottom.root] = bottom
-
-	default:
-		// Many layers requested to be retained, cap normally
-		persisted = t.cap(diff, layers)
 	}
+	persisted := t.cap(diff, layers)
+
 	// Remove any layer that is stale or links into a stale layer
 	children := make(map[common.Hash][]common.Hash)
 	for root, snap := range t.layers {
@@ -405,9 +384,15 @@ func (t *Tree) Cap(root common.Hash, layers int) error {
 // layer limit is reached, memory cap is also enforced (but not before).
 //
 // The method returns the new disk layer if diffs were persisted into it.
+//
+// Note, the final diff layer count in general will be one more than the amount
+// requested. This happens because the bottom-most diff layer is the accumulator
+// which may or may not overflow and cascade to disk. Since this last layer's
+// survival is only known *after* capping, we need to omit it from the count if
+// we want to ensure that *at least* the requested number of diff layers remain.
 func (t *Tree) cap(diff *diffLayer, layers int) *diskLayer {
 	// Dive until we run out of layers or reach the persistent database
-	for ; layers > 2; layers-- {
+	for i := 0; i < layers-1; i++ {
 		// If we still have diff layers below, continue down
 		if parent, ok := diff.parent.(*diffLayer); ok {
 			diff = parent
diff --git a/core/state/snapshot/snapshot_test.go b/core/state/snapshot/snapshot_test.go
index a315fd216..4b787cfe2 100644
--- a/core/state/snapshot/snapshot_test.go
+++ b/core/state/snapshot/snapshot_test.go
@@ -162,57 +162,10 @@ func TestDiskLayerExternalInvalidationPartialFlatten(t *testing.T) {
 	defer func(memcap uint64) { aggregatorMemoryLimit = memcap }(aggregatorMemoryLimit)
 	aggregatorMemoryLimit = 0
 
-	if err := snaps.Cap(common.HexToHash("0x03"), 2); err != nil {
-		t.Fatalf("failed to merge diff layer onto disk: %v", err)
-	}
-	// Since the base layer was modified, ensure that data retrievald on the external reference fail
-	if acc, err := ref.Account(common.HexToHash("0x01")); err != ErrSnapshotStale {
-		t.Errorf("stale reference returned account: %#x (err: %v)", acc, err)
-	}
-	if slot, err := ref.Storage(common.HexToHash("0xa1"), common.HexToHash("0xb1")); err != ErrSnapshotStale {
-		t.Errorf("stale reference returned storage slot: %#x (err: %v)", slot, err)
-	}
-	if n := len(snaps.layers); n != 2 {
-		t.Errorf("post-cap layer count mismatch: have %d, want %d", n, 2)
-		fmt.Println(snaps.layers)
-	}
-}
-
-// Tests that if a diff layer becomes stale, no active external references will
-// be returned with junk data. This version of the test flattens every diff layer
-// to check internal corner case around the bottom-most memory accumulator.
-func TestDiffLayerExternalInvalidationFullFlatten(t *testing.T) {
-	// Create an empty base layer and a snapshot tree out of it
-	base := &diskLayer{
-		diskdb: rawdb.NewMemoryDatabase(),
-		root:   common.HexToHash("0x01"),
-		cache:  fastcache.New(1024 * 500),
-	}
-	snaps := &Tree{
-		layers: map[common.Hash]snapshot{
-			base.root: base,
-		},
-	}
-	// Commit two diffs on top and retrieve a reference to the bottommost
-	accounts := map[common.Hash][]byte{
-		common.HexToHash("0xa1"): randomAccount(),
-	}
-	if err := snaps.Update(common.HexToHash("0x02"), common.HexToHash("0x01"), nil, accounts, nil); err != nil {
-		t.Fatalf("failed to create a diff layer: %v", err)
-	}
-	if err := snaps.Update(common.HexToHash("0x03"), common.HexToHash("0x02"), nil, accounts, nil); err != nil {
-		t.Fatalf("failed to create a diff layer: %v", err)
-	}
-	if n := len(snaps.layers); n != 3 {
-		t.Errorf("pre-cap layer count mismatch: have %d, want %d", n, 3)
-	}
-	ref := snaps.Snapshot(common.HexToHash("0x02"))
-
-	// Flatten the diff layer into the bottom accumulator
 	if err := snaps.Cap(common.HexToHash("0x03"), 1); err != nil {
-		t.Fatalf("failed to flatten diff layer into accumulator: %v", err)
+		t.Fatalf("failed to merge accumulator onto disk: %v", err)
 	}
-	// Since the accumulator diff layer was modified, ensure that data retrievald on the external reference fail
+	// Since the base layer was modified, ensure that data retrieval on the external reference fails
 	if acc, err := ref.Account(common.HexToHash("0x01")); err != ErrSnapshotStale {
 		t.Errorf("stale reference returned account: %#x (err: %v)", acc, err)
 	}
@@ -267,7 +220,7 @@ func TestDiffLayerExternalInvalidationPartialFlatten(t *testing.T) {
 		t.Errorf("layers modified, got %d exp %d", got, exp)
 	}
 	// Flatten the diff layer into the bottom accumulator
-	if err := snaps.Cap(common.HexToHash("0x04"), 2); err != nil {
+	if err := snaps.Cap(common.HexToHash("0x04"), 1); err != nil {
 		t.Fatalf("failed to flatten diff layer into accumulator: %v", err)
 	}
 	// Since the accumulator diff layer was modified, ensure that data retrievald on the external reference fail
@@ -389,7 +342,7 @@ func TestSnaphots(t *testing.T) {
 	// Create a starting base layer and a snapshot tree out of it
 	base := &diskLayer{
 		diskdb: rawdb.NewMemoryDatabase(),
-		root:   common.HexToHash("0x01"),
+		root:   makeRoot(1),
 		cache:  fastcache.New(1024 * 500),
 	}
 	snaps := &Tree{
@@ -397,17 +350,16 @@ func TestSnaphots(t *testing.T) {
 			base.root: base,
 		},
 	}
-	// Construct the snapshots with 128 layers
+	// Construct the snapshots with 129 layers, flattening whatever's above that
 	var (
 		last = common.HexToHash("0x01")
 		head common.Hash
 	)
-	// Flush another 128 layers, one diff will be flatten into the parent.
-	for i := 0; i < 128; i++ {
+	for i := 0; i < 129; i++ {
 		head = makeRoot(uint64(i + 2))
 		snaps.Update(head, last, nil, setAccount(fmt.Sprintf("%d", i+2)), nil)
 		last = head
-		snaps.Cap(head, 128) // 129 layers(128 diffs + 1 disk) are allowed, 129th is the persistent layer
+		snaps.Cap(head, 128) // 130 layers (128 diffs + 1 accumulator + 1 disk)
 	}
 	var cases = []struct {
 		headRoot     common.Hash
@@ -417,22 +369,57 @@ func TestSnaphots(t *testing.T) {
 		expectBottom common.Hash
 	}{
 		{head, 0, false, 0, common.Hash{}},
-		{head, 64, false, 64, makeRoot(127 + 2 - 63)},
-		{head, 128, false, 128, makeRoot(2)},              // All diff layers
-		{head, 129, true, 128, makeRoot(2)},               // All diff layers
-		{head, 129, false, 129, common.HexToHash("0x01")}, // All diff layers + disk layer
+		{head, 64, false, 64, makeRoot(129 + 2 - 64)},
+		{head, 128, false, 128, makeRoot(3)}, // Normal diff layers, no accumulator
+		{head, 129, true, 129, makeRoot(2)},  // All diff layers, including accumulator
+		{head, 130, false, 130, makeRoot(1)}, // All diff layers + disk layer
+	}
+	for i, c := range cases {
+		layers := snaps.Snapshots(c.headRoot, c.limit, c.nodisk)
+		if len(layers) != c.expected {
+			t.Errorf("non-overflow test %d: returned snapshot layers are mismatched, want %v, got %v", i, c.expected, len(layers))
+		}
+		if len(layers) == 0 {
+			continue
+		}
+		bottommost := layers[len(layers)-1]
+		if bottommost.Root() != c.expectBottom {
+			t.Errorf("non-overflow test %d: snapshot mismatch, want %v, get %v", i, c.expectBottom, bottommost.Root())
+		}
+	}
+	// Above we've tested the normal capping, which leaves the accumulator live.
+	// Now test that if the bottom-most accumulator diff layer overflows the
+	// allowed memory limit (forced here by zeroing it), it gets flattened into
+	// the disk layer and the snapshot tree ends up one layer shorter.
+	defer func(memcap uint64) { aggregatorMemoryLimit = memcap }(aggregatorMemoryLimit)
+	aggregatorMemoryLimit = 0
+
+	snaps.Cap(head, 128) // 129 remain (128 diffs + 1 disk), the overflown accumulator is merged into disk
+
+	cases = []struct {
+		headRoot     common.Hash
+		limit        int
+		nodisk       bool
+		expected     int
+		expectBottom common.Hash
+	}{
+		{head, 0, false, 0, common.Hash{}},
+		{head, 64, false, 64, makeRoot(129 + 2 - 64)},
+		{head, 128, false, 128, makeRoot(3)}, // All diff layers, accumulator was flattened
+		{head, 129, true, 128, makeRoot(3)},  // All diff layers, accumulator was flattened
+		{head, 130, false, 129, makeRoot(2)}, // All diff layers + disk layer
 	}
-	for _, c := range cases {
+	for i, c := range cases {
 		layers := snaps.Snapshots(c.headRoot, c.limit, c.nodisk)
 		if len(layers) != c.expected {
-			t.Fatalf("Returned snapshot layers are mismatched, want %v, got %v", c.expected, len(layers))
+			t.Errorf("overflow test %d: returned snapshot layers are mismatched, want %v, got %v", i, c.expected, len(layers))
 		}
 		if len(layers) == 0 {
 			continue
 		}
 		bottommost := layers[len(layers)-1]
 		if bottommost.Root() != c.expectBottom {
-			t.Fatalf("Snapshot mismatch, want %v, get %v", c.expectBottom, bottommost.Root())
+			t.Errorf("overflow test %d: snapshot mismatch, want %v, get %v", i, c.expectBottom, bottommost.Root())
 		}
 	}
 }
-- 
GitLab