From e04f7fc9f8a3da6784118d0038b22a7290e0a0ef Mon Sep 17 00:00:00 2001
From: Alex Sharov <AskAlexSharov@gmail.com>
Date: Tue, 26 Apr 2022 12:54:05 +0700
Subject: [PATCH] Integration: allow headers --reset (#3972)

---
 cmd/integration/commands/stages.go | 91 ++++++++++++++++++++----------
 eth/stagedsync/stage_headers.go    |  2 +-
 2 files changed, 63 insertions(+), 30 deletions(-)

diff --git a/cmd/integration/commands/stages.go b/cmd/integration/commands/stages.go
index 9155d5e889..8318260036 100644
--- a/cmd/integration/commands/stages.go
+++ b/cmd/integration/commands/stages.go
@@ -322,6 +322,7 @@ func init() {
 
 	withDataDir(cmdStageHeaders)
 	withUnwind(cmdStageHeaders)
+	withReset(cmdStageHeaders)
 	withChain(cmdStageHeaders)
 	withHeimdall(cmdStageHeaders)
 
@@ -441,48 +442,80 @@ func init() {
 	rootCmd.AddCommand(cmdSetPrune)
 }
 
+// max is a helper function which returns the larger of the two given integers.
+func max(a, b uint64) uint64 { //nolint:unparam
+	if a > b {
+		return a
+	}
+	return b
+}
+
 func stageHeaders(db kv.RwDB, ctx context.Context) error {
 	return db.Update(ctx, func(tx kv.RwTx) error {
-		if unwind > 0 {
+		if !(unwind > 0 || reset) {
+			log.Info("This command only works with --unwind or --reset options")
+		}
+
+		if reset {
 			progress, err := stages.GetStageProgress(tx, stages.Headers)
 			if err != nil {
 				return fmt.Errorf("read Bodies progress: %w", err)
 			}
-			if unwind > progress {
-				return fmt.Errorf("cannot unwind past 0")
-			}
-			if err = stages.SaveStageProgress(tx, stages.Headers, progress-unwind); err != nil {
-				return fmt.Errorf("saving Bodies progress failed: %w", err)
-			}
-			progress, err = stages.GetStageProgress(tx, stages.Headers)
-			if err != nil {
-				return fmt.Errorf("re-read Bodies progress: %w", err)
-			}
+			unwind = progress
+		}
+
+		progress, err := stages.GetStageProgress(tx, stages.Headers)
+		if err != nil {
+			return fmt.Errorf("read Bodies progress: %w", err)
+		}
+		var unwindTo uint64
+		if unwind > progress {
+			unwindTo = 1 // keep genesis
+		} else {
+			unwindTo = max(1, progress-unwind)
+		}
+
+		if err = stages.SaveStageProgress(tx, stages.Headers, unwindTo); err != nil {
+			return fmt.Errorf("saving Bodies progress failed: %w", err)
+		}
+		progress, err = stages.GetStageProgress(tx, stages.Headers)
+		if err != nil {
+			return fmt.Errorf("re-read Bodies progress: %w", err)
+		}
+		{ // hard-unwind stage_body also
 			if err := rawdb.DeleteNewBlocks(tx, progress+1); err != nil {
 				return err
 			}
-			// remove all canonical markers from this point
-			if err := tx.ForEach(kv.HeaderCanonical, dbutils.EncodeBlockNumber(progress+1), func(k, v []byte) error {
-				return tx.Delete(kv.HeaderCanonical, k, nil)
-			}); err != nil {
-				return err
-			}
-			if err := tx.ForEach(kv.HeaderTD, dbutils.EncodeBlockNumber(progress+1), func(k, v []byte) error {
-				return tx.Delete(kv.HeaderTD, k, nil)
-			}); err != nil {
-				return err
-			}
-			hash, err := rawdb.ReadCanonicalHash(tx, progress-1)
+			progressBodies, err := stages.GetStageProgress(tx, stages.Bodies)
 			if err != nil {
-				return err
+				return fmt.Errorf("read Bodies progress: %w", err)
 			}
-			if err = tx.Put(kv.HeadHeaderKey, []byte(kv.HeadHeaderKey), hash[:]); err != nil {
-				return err
+			if progress < progressBodies {
+				if err = stages.SaveStageProgress(tx, stages.Bodies, progress); err != nil {
+					return fmt.Errorf("saving Bodies progress failed: %w", err)
+				}
 			}
-			log.Info("Progress", "headers", progress)
-			return nil
 		}
-		log.Info("This command only works with --unwind option")
+		// remove all canonical markers from this point
+		if err := tx.ForEach(kv.HeaderCanonical, dbutils.EncodeBlockNumber(progress+1), func(k, v []byte) error {
+			return tx.Delete(kv.HeaderCanonical, k, nil)
+		}); err != nil {
+			return err
+		}
+		if err := tx.ForEach(kv.HeaderTD, dbutils.EncodeBlockNumber(progress+1), func(k, v []byte) error {
+			return tx.Delete(kv.HeaderTD, k, nil)
+		}); err != nil {
+			return err
+		}
+		hash, err := rawdb.ReadCanonicalHash(tx, progress-1)
+		if err != nil {
+			return err
+		}
+		if err = tx.Put(kv.HeadHeaderKey, []byte(kv.HeadHeaderKey), hash[:]); err != nil {
+			return err
+		}
+
+		log.Info("Progress", "headers", progress)
 		return nil
 	})
 }
diff --git a/eth/stagedsync/stage_headers.go b/eth/stagedsync/stage_headers.go
index 2c13774062..c23f84ed11 100644
--- a/eth/stagedsync/stage_headers.go
+++ b/eth/stagedsync/stage_headers.go
@@ -1100,7 +1100,7 @@ func DownloadAndIndexSnapshotsIfNeed(s *StageState, ctx context.Context, tx kv.R
 		}
 	}
 
-	if s.BlockNumber == 0 {
+	if s.BlockNumber < 2 { // allow genesis
 		logEvery := time.NewTicker(logInterval)
 		defer logEvery.Stop()
 
-- 
GitLab