From d9c5ef3eb9962b18f0c676b734560f936cf98131 Mon Sep 17 00:00:00 2001
From: Alex Sharov <AskAlexSharov@gmail.com>
Date: Sun, 18 Jul 2021 15:59:05 +0700
Subject: [PATCH] Pruning stages order support (#2393)

---
 cmd/integration/commands/stages.go |   7 +-
 eth/backend.go                     |   2 +-
 eth/stagedsync/default_stages.go   | 142 ++++++++++++-----------
 eth/stagedsync/stage.go            |   7 +-
 eth/stagedsync/stagebuilder.go     |  11 --
 eth/stagedsync/sync.go             | 176 ++++++++++++++++-------------
 eth/stagedsync/sync_test.go        |  52 ++++-----
 turbo/stages/mock_sentry.go        | 147 ++++++++++++------------
 turbo/stages/stageloop.go          | 160 +++++++++++---------------
 9 files changed, 351 insertions(+), 353 deletions(-)

diff --git a/cmd/integration/commands/stages.go b/cmd/integration/commands/stages.go
index 43f058a4b1..bd8f86b082 100644
--- a/cmd/integration/commands/stages.go
+++ b/cmd/integration/commands/stages.go
@@ -826,7 +826,8 @@ func newSync(ctx context.Context, db ethdb.RwKV, miningConfig *params.MiningConf
 			stagedsync.StageTrieCfg(db, false, true, tmpdir),
 			stagedsync.StageMiningFinishCfg(db, *chainConfig, engine, miner, ctx.Done()),
 		),
-		stagedsync.MiningUnwindOrder(),
+		stagedsync.MiningUnwindOrder,
+		stagedsync.MiningPruneOrder,
 	)
 
 	return sm, engine, chainConfig, vmConfig, txPool, sync, miningSync, miner
@@ -840,8 +841,8 @@ func progress(tx ethdb.KVGetter, stage stages.SyncStage) uint64 {
 	return res
 }
 
-func stage(st *stagedsync.Sync, db ethdb.KVGetter, stage stages.SyncStage) *stagedsync.StageState {
-	res, err := st.StageState(stage, db)
+func stage(st *stagedsync.Sync, tx ethdb.Tx, stage stages.SyncStage) *stagedsync.StageState {
+	res, err := st.StageState(stage, tx, nil)
 	if err != nil {
 		panic(err)
 	}
diff --git a/eth/backend.go b/eth/backend.go
index 46fe263a3e..6905c11668 100644
--- a/eth/backend.go
+++ b/eth/backend.go
@@ -276,7 +276,7 @@ func New(stack *node.Node, config *ethconfig.Config) (*Ethereum, error) {
 			stagedsync.StageHashStateCfg(backend.chainKV, tmpdir),
 			stagedsync.StageTrieCfg(backend.chainKV, false, true, tmpdir),
 			stagedsync.StageMiningFinishCfg(backend.chainKV, *backend.chainConfig, backend.engine, miner, backend.miningSealingQuit),
-		), stagedsync.MiningUnwindOrder())
+		), stagedsync.MiningUnwindOrder, stagedsync.MiningPruneOrder)
 
 	var ethashApi *ethash.API
 	if casted, ok := backend.engine.(*ethash.Ethash); ok {
diff --git a/eth/stagedsync/default_stages.go b/eth/stagedsync/default_stages.go
index 97dc3d25e0..7d8d5762fb 100644
--- a/eth/stagedsync/default_stages.go
+++ b/eth/stagedsync/default_stages.go
@@ -285,78 +285,88 @@ func DefaultStages(ctx context.Context,
 	}
 }
 
-func DefaultForwardOrder() UnwindOrder {
-	return []stages.SyncStage{
-		stages.Headers,
-		stages.BlockHashes,
-		stages.CreateHeadersSnapshot,
-		stages.Bodies,
-		stages.CreateBodiesSnapshot,
-		stages.Senders,
-		stages.Execution,
-		stages.Translation,
-		stages.CreateStateSnapshot,
-		stages.HashState,
-		stages.IntermediateHashes,
-		stages.CallTraces,
-		stages.AccountHistoryIndex,
-		stages.StorageHistoryIndex,
-		stages.LogIndex,
-		stages.TxLookup,
-		stages.TxPool,
-		stages.Finish,
-	}
+var DefaultForwardOrder = UnwindOrder{
+	stages.Headers,
+	stages.BlockHashes,
+	stages.CreateHeadersSnapshot,
+	stages.Bodies,
+	stages.CreateBodiesSnapshot,
+	stages.Senders,
+	stages.Execution,
+	stages.Translation,
+	stages.CreateStateSnapshot,
+	stages.HashState,
+	stages.IntermediateHashes,
+	stages.CallTraces,
+	stages.AccountHistoryIndex,
+	stages.StorageHistoryIndex,
+	stages.LogIndex,
+	stages.TxLookup,
+	stages.TxPool,
+	stages.Finish,
 }
 
-func DefaultPruningOrder() UnwindOrder {
-	return []stages.SyncStage{
-		stages.Headers,
-		stages.BlockHashes,
-		stages.CreateHeadersSnapshot,
-		stages.Bodies,
-		stages.CreateBodiesSnapshot,
-		stages.Senders,
-		stages.Execution,
-		stages.Translation,
-		stages.CreateStateSnapshot,
-		stages.HashState,
-		stages.IntermediateHashes,
-		stages.CallTraces,
-		stages.AccountHistoryIndex,
-		stages.StorageHistoryIndex,
-		stages.LogIndex,
-		stages.TxLookup,
-		stages.TxPool,
-		stages.Finish,
-	}
+var DefaultPruneOrder = PruneOrder{
+	stages.Headers,
+	stages.BlockHashes,
+	stages.CreateHeadersSnapshot,
+	stages.Bodies,
+	stages.CreateBodiesSnapshot,
+
+	// Prune order mirrors the unwind order: tx pool is placed before
+	// senders/execution here (see DefaultUnwindOrder for the rationale)
+	stages.TxPool,
+
+	stages.Senders,
+	stages.Execution,
+	stages.Translation,
+	stages.CreateStateSnapshot,
+
+	// Unwinding of IHashes needs to happen after unwinding HashState
+	stages.IntermediateHashes,
+	stages.HashState,
+
+	stages.CallTraces,
+	stages.AccountHistoryIndex,
+	stages.StorageHistoryIndex,
+	stages.LogIndex,
+	stages.TxLookup,
+	stages.Finish,
 }
 
-func DefaultUnwindOrder() UnwindOrder {
-	return []stages.SyncStage{
-		stages.Headers,
-		stages.BlockHashes,
-		stages.CreateHeadersSnapshot,
-		stages.Bodies,
-		stages.CreateBodiesSnapshot,
+// UnwindOrder represents the order in which the stages need to be unwound.
+// The unwind order is important and not always just the stages in reverse.
+// For example, the tx pool can only be unwound after execution is unwound.
+type UnwindOrder []stages.SyncStage
+type PruneOrder []stages.SyncStage
 
-		// Unwinding of tx pool (reinjecting transactions into the pool needs to happen after unwinding execution)
-		// also tx pool is before senders because senders unwind is inside cycle transaction
-		stages.TxPool,
+var DefaultUnwindOrder = UnwindOrder{
+	stages.Headers,
+	stages.BlockHashes,
+	stages.CreateHeadersSnapshot,
+	stages.Bodies,
+	stages.CreateBodiesSnapshot,
 
-		stages.Senders,
-		stages.Execution,
-		stages.Translation,
-		stages.CreateStateSnapshot,
+	// Unwinding of tx pool (reinjecting transactions into the pool needs to happen after unwinding execution)
+	// also tx pool is before senders because senders unwind is inside cycle transaction
+	stages.TxPool,
 
-		// Unwinding of IHashes needs to happen after unwinding HashState
-		stages.IntermediateHashes,
-		stages.HashState,
+	stages.Senders,
+	stages.Execution,
+	stages.Translation,
+	stages.CreateStateSnapshot,
 
-		stages.CallTraces,
-		stages.AccountHistoryIndex,
-		stages.StorageHistoryIndex,
-		stages.LogIndex,
-		stages.TxLookup,
-		stages.Finish,
-	}
+	// Unwinding of IHashes needs to happen after unwinding HashState
+	stages.IntermediateHashes,
+	stages.HashState,
+
+	stages.CallTraces,
+	stages.AccountHistoryIndex,
+	stages.StorageHistoryIndex,
+	stages.LogIndex,
+	stages.TxLookup,
+	stages.Finish,
 }
+
+var MiningUnwindOrder = UnwindOrder{} // nothing to unwind in mining - because mining does not commit db changes
+var MiningPruneOrder = PruneOrder{}   // nothing to prune in mining - because mining does not commit db changes
diff --git a/eth/stagedsync/stage.go b/eth/stagedsync/stage.go
index f70830a417..dd430c4aad 100644
--- a/eth/stagedsync/stage.go
+++ b/eth/stagedsync/stage.go
@@ -82,9 +82,10 @@ func (u *UnwindState) Done(db ethdb.Putter) error {
 }
 
 type PruneState struct {
-	ID         stages.SyncStage
-	PrunePoint uint64 // PrunePoint is the block to prune to.
-	state      *Sync
+	ID                 stages.SyncStage
+	PrunePoint         uint64 // PrunePoint is the block to prune to.
+	CurrentBlockNumber uint64
+	state              *Sync
 }
 
 func (u *PruneState) LogPrefix() string { return u.state.LogPrefix() }
diff --git a/eth/stagedsync/stagebuilder.go b/eth/stagedsync/stagebuilder.go
index 79de385eef..25e57b4427 100644
--- a/eth/stagedsync/stagebuilder.go
+++ b/eth/stagedsync/stagebuilder.go
@@ -81,14 +81,3 @@ func MiningStages(
 		},
 	}
 }
-
-// UnwindOrder represents the order in which the stages needs to be unwound.
-// Currently it is using indexes of stages, 0-based.
-// The unwind order is important and not always just stages going backwards.
-// Let's say, there is tx pool (state 10) can be unwound only after execution
-// is fully unwound (stages 9...3).
-type UnwindOrder []stages.SyncStage
-
-func MiningUnwindOrder() UnwindOrder {
-	return []stages.SyncStage{}
-}
diff --git a/eth/stagedsync/sync.go b/eth/stagedsync/sync.go
index 643eb9e1dd..02203feff8 100644
--- a/eth/stagedsync/sync.go
+++ b/eth/stagedsync/sync.go
@@ -13,6 +13,7 @@ import (
 	"github.com/ledgerwatch/erigon/eth/stagedsync/stages"
 	"github.com/ledgerwatch/erigon/ethdb"
 	"github.com/ledgerwatch/erigon/log"
+	"github.com/ledgerwatch/erigon/params"
 )
 
 type Sync struct {
@@ -22,6 +23,7 @@ type Sync struct {
 
 	stages       []*Stage
 	unwindOrder  []*Stage
+	pruningOrder []*Stage
 	currentStage uint
 }
 
@@ -32,6 +34,10 @@ func (s *Sync) NewUnwindState(id stages.SyncStage, unwindPoint, currentProgress
 	return &UnwindState{id, unwindPoint, currentProgress, common.Hash{}, s}
 }
 
+func (s *Sync) NewPruneState(id stages.SyncStage, prunePoint, currentProgress uint64) *PruneState {
+	return &PruneState{id, prunePoint, currentProgress, s}
+}
+
 func (s *Sync) NextStage() {
 	if s == nil {
 		return
@@ -73,11 +79,6 @@ func (s *Sync) IsAfter(stage1, stage2 stages.SyncStage) bool {
 	return idx1 > idx2
 }
 
-func (s *Sync) GetLocalHeight(db ethdb.KVGetter) (uint64, error) {
-	state, err := s.StageState(stages.Headers, db)
-	return state.BlockNumber, err
-}
-
 func (s *Sync) UnwindTo(unwindPoint uint64, badBlock common.Hash) {
 	log.Info("UnwindTo", "block", unwindPoint, "bad_block_hash", badBlock.String())
 	s.unwindPoint = &unwindPoint
@@ -88,10 +89,6 @@ func (s *Sync) IsDone() bool {
 	return s.currentStage >= uint(len(s.stages)) && s.unwindPoint == nil
 }
 
-func (s *Sync) CurrentStage() (uint, *Stage) {
-	return s.currentStage, s.stages[s.currentStage]
-}
-
 func (s *Sync) LogPrefix() string {
 	if s == nil {
 		return ""
@@ -109,18 +106,8 @@ func (s *Sync) SetCurrentStage(id stages.SyncStage) error {
 	return fmt.Errorf("stage not found with id: %v", id)
 }
 
-func (s *Sync) StageByID(id stages.SyncStage) (*Stage, error) {
-	for _, stage := range s.stages {
-		if stage.ID == id {
-			return stage, nil
-		}
-	}
-	return nil, fmt.Errorf("stage not found with id: %v", id)
-}
-
-func New(stagesList []*Stage, unwindOrder []stages.SyncStage) *Sync {
+func New(stagesList []*Stage, unwindOrder UnwindOrder, pruneOrder PruneOrder) *Sync {
 	unwindStages := make([]*Stage, len(stagesList))
-
 	for i, stageIndex := range unwindOrder {
 		for _, s := range stagesList {
 			if s.ID == stageIndex {
@@ -129,21 +116,45 @@ func New(stagesList []*Stage, unwindOrder []stages.SyncStage) *Sync {
 			}
 		}
 	}
+	pruneStages := make([]*Stage, len(stagesList))
+	for i, stageIndex := range pruneOrder {
+		for _, s := range stagesList {
+			if s.ID == stageIndex {
+				pruneStages[i] = s
+				break
+			}
+		}
+	}
 
-	st := &Sync{
+	return &Sync{
 		stages:       stagesList,
 		currentStage: 0,
 		unwindOrder:  unwindStages,
+		//pruningOrder: pruneStages,
 	}
-
-	return st
 }
 
-func (s *Sync) StageState(stage stages.SyncStage, db ethdb.KVGetter) (*StageState, error) {
-	blockNum, err := stages.GetStageProgress(db, stage)
-	if err != nil {
-		return nil, err
+func (s *Sync) StageState(stage stages.SyncStage, tx ethdb.Tx, db ethdb.RoKV) (*StageState, error) {
+	var blockNum uint64
+	var err error
+	useExternalTx := tx != nil
+	if useExternalTx {
+		blockNum, err = stages.GetStageProgress(tx, stage)
+		if err != nil {
+			return nil, err
+		}
+	} else {
+		if err = db.View(context.Background(), func(tx ethdb.Tx) error {
+			blockNum, err = stages.GetStageProgress(tx, stage)
+			if err != nil {
+				return err
+			}
+			return nil
+		}); err != nil {
+			return nil, err
+		}
 	}
+
 	return &StageState{s, stage, blockNum}, nil
 }
 
@@ -152,14 +163,11 @@ func (s *Sync) Run(db ethdb.RwKV, tx ethdb.RwTx, firstCycle bool) error {
 	for !s.IsDone() {
 		if s.unwindPoint != nil {
 			for i := len(s.unwindOrder) - 1; i >= 0; i-- {
-				if err := s.SetCurrentStage(s.unwindOrder[i].ID); err != nil {
-					return err
-				}
-				if s.unwindOrder[i].Disabled {
+				if s.unwindOrder[i].Disabled || s.unwindOrder[i].Unwind == nil {
 					continue
 				}
 				t := time.Now()
-				if err := s.unwindStage(firstCycle, s.unwindOrder[i].ID, db, tx); err != nil {
+				if err := s.unwindStage(firstCycle, s.unwindOrder[i], db, tx); err != nil {
 					return err
 				}
 				timings = append(timings, "Unwind "+string(s.unwindOrder[i].ID), time.Since(t))
@@ -172,21 +180,16 @@ func (s *Sync) Run(db ethdb.RwKV, tx ethdb.RwTx, firstCycle bool) error {
 			}
 		}
 
-		_, stage := s.CurrentStage()
+		stage := s.stages[s.currentStage]
 
 		if string(stage.ID) == debug.StopBeforeStage() { // stop process for debugging reasons
 			log.Error("STOP_BEFORE_STAGE env flag forced to stop app")
 			os.Exit(1)
 		}
 
-		if stage.Disabled {
+		if stage.Disabled || stage.Forward == nil {
 			logPrefix := s.LogPrefix()
-			message := fmt.Sprintf(
-				"[%s] disabled. %s",
-				logPrefix, stage.DisabledDescription,
-			)
-
-			log.Debug(message)
+			log.Debug(fmt.Sprintf("[%s] disabled. %s", logPrefix, stage.DisabledDescription))
 
 			s.NextStage()
 			continue
@@ -201,6 +204,20 @@ func (s *Sync) Run(db ethdb.RwKV, tx ethdb.RwTx, firstCycle bool) error {
 		s.NextStage()
 	}
 
+	for i := len(s.pruningOrder) - 1; i >= 0; i-- {
+		if s.pruningOrder[i].Disabled || s.pruningOrder[i].Prune == nil {
+			continue
+		}
+		t := time.Now()
+		if err := s.pruneStage(firstCycle, s.pruningOrder[i], db, tx); err != nil {
+			return err
+		}
+		timings = append(timings, "Pruning "+string(s.pruningOrder[i].ID), time.Since(t))
+	}
+	if err := s.SetCurrentStage(s.stages[0].ID); err != nil {
+		return err
+	}
+
 	if err := printLogs(tx, timings); err != nil {
 		return err
 	}
@@ -242,27 +259,12 @@ func printLogs(tx ethdb.RwTx, timings []interface{}) error {
 	return nil
 }
 
-func (s *Sync) runStage(stage *Stage, db ethdb.RwKV, tx ethdb.RwTx, firstCycle bool) error {
-	useExternalTx := tx != nil
-	if !useExternalTx {
-		var err error
-		tx, err = db.BeginRw(context.Background())
-		if err != nil {
-			return err
-		}
-		defer tx.Rollback()
-	}
-
-	stageState, err := s.StageState(stage.ID, tx)
+func (s *Sync) runStage(stage *Stage, db ethdb.RwKV, tx ethdb.RwTx, firstCycle bool) (err error) {
+	stageState, err := s.StageState(stage.ID, tx, db)
 	if err != nil {
 		return err
 	}
 
-	if !useExternalTx {
-		tx.Rollback()
-		tx = nil
-	}
-
 	start := time.Now()
 	logPrefix := s.LogPrefix()
 	if err = stage.Forward(firstCycle, stageState, s, tx); err != nil {
@@ -275,49 +277,61 @@ func (s *Sync) runStage(stage *Stage, db ethdb.RwKV, tx ethdb.RwTx, firstCycle b
 	return nil
 }
 
-func (s *Sync) unwindStage(firstCycle bool, stageID stages.SyncStage, db ethdb.RwKV, tx ethdb.RwTx) error {
-	useExternalTx := tx != nil
-	if !useExternalTx {
-		var err error
-		tx, err = db.BeginRw(context.Background())
-		if err != nil {
-			return err
-		}
-		defer tx.Rollback()
-	}
-
+func (s *Sync) unwindStage(firstCycle bool, stage *Stage, db ethdb.RwKV, tx ethdb.RwTx) error {
 	start := time.Now()
-	log.Info("Unwinding...", "stage", stageID)
-	stage, err := s.StageByID(stageID)
+	log.Info("Unwind...", "stage", stage.ID)
+	stageState, err := s.StageState(stage.ID, tx, db)
 	if err != nil {
 		return err
 	}
-	if stage.Unwind == nil {
+
+	unwind := s.NewUnwindState(stage.ID, *s.unwindPoint, stageState.BlockNumber)
+	unwind.BadBlock = s.badBlock
+
+	if stageState.BlockNumber <= unwind.UnwindPoint {
 		return nil
 	}
-	var stageState *StageState
-	stageState, err = s.StageState(stageID, tx)
+
+	if err = s.SetCurrentStage(stage.ID); err != nil {
+		return err
+	}
+
+	err = stage.Unwind(firstCycle, unwind, stageState, tx)
 	if err != nil {
 		return err
 	}
-	unwind := s.NewUnwindState(stageID, *s.unwindPoint, stageState.BlockNumber)
-	unwind.BadBlock = s.badBlock
 
-	if stageState.BlockNumber <= unwind.UnwindPoint {
+	if time.Since(start) > 30*time.Second {
+		log.Info("Unwind... DONE!", "stage", string(unwind.ID))
+	}
+	return nil
+}
+
+func (s *Sync) pruneStage(firstCycle bool, stage *Stage, db ethdb.RwKV, tx ethdb.RwTx) error {
+	start := time.Now()
+	log.Info("Prune...", "stage", stage.ID)
+
+	stageState, err := s.StageState(stage.ID, tx, db)
+	if err != nil {
+		return err
+	}
+
+	prunePoint := stageState.BlockNumber - params.FullImmutabilityThreshold // TODO: cli-customizable
+	prune := s.NewPruneState(stage.ID, prunePoint, stageState.BlockNumber)
+	if stageState.BlockNumber <= prune.PrunePoint {
 		return nil
 	}
-	if !useExternalTx {
-		tx.Rollback()
-		tx = nil
+	if err = s.SetCurrentStage(stage.ID); err != nil {
+		return err
 	}
 
-	err = stage.Unwind(firstCycle, unwind, stageState, tx)
+	err = stage.Prune(firstCycle, prune, tx)
 	if err != nil {
 		return err
 	}
 
 	if time.Since(start) > 30*time.Second {
-		log.Info("Unwinding... DONE!", "stage", string(unwind.ID))
+		log.Info("Prune... DONE!", "stage", string(prune.ID))
 	}
 	return nil
 }
diff --git a/eth/stagedsync/sync_test.go b/eth/stagedsync/sync_test.go
index 56c79dcb29..7bfb669a30 100644
--- a/eth/stagedsync/sync_test.go
+++ b/eth/stagedsync/sync_test.go
@@ -39,7 +39,7 @@ func TestStagesSuccess(t *testing.T) {
 			},
 		},
 	}
-	state := New(s, nil)
+	state := New(s, nil, nil)
 	db, tx := kv.NewTestTx(t)
 	err := state.Run(db, tx, true)
 	assert.NoError(t, err)
@@ -79,7 +79,7 @@ func TestDisabledStages(t *testing.T) {
 			},
 		},
 	}
-	state := New(s, nil)
+	state := New(s, nil, nil)
 	db, tx := kv.NewTestTx(t)
 	err := state.Run(db, tx, true)
 	assert.NoError(t, err)
@@ -119,7 +119,7 @@ func TestErroredStage(t *testing.T) {
 			},
 		},
 	}
-	state := New(s, []stages.SyncStage{s[0].ID, s[1].ID, s[2].ID})
+	state := New(s, []stages.SyncStage{s[0].ID, s[1].ID, s[2].ID}, nil)
 	db, tx := kv.NewTestTx(t)
 	err := state.Run(db, tx, true)
 	assert.Equal(t, expectedErr, err)
@@ -202,7 +202,7 @@ func TestUnwindSomeStagesBehindUnwindPoint(t *testing.T) {
 			},
 		},
 	}
-	state := New(s, []stages.SyncStage{s[0].ID, s[1].ID, s[2].ID, s[3].ID})
+	state := New(s, []stages.SyncStage{s[0].ID, s[1].ID, s[2].ID, s[3].ID}, nil)
 	db, tx := kv.NewTestTx(t)
 	err := state.Run(db, tx, true)
 	assert.NoError(t, err)
@@ -215,15 +215,15 @@ func TestUnwindSomeStagesBehindUnwindPoint(t *testing.T) {
 	}
 	assert.Equal(t, expectedFlow, flow)
 
-	stageState, err := state.StageState(stages.Headers, tx)
+	stageState, err := state.StageState(stages.Headers, tx, nil)
 	assert.NoError(t, err)
 	assert.Equal(t, 1500, int(stageState.BlockNumber))
 
-	stageState, err = state.StageState(stages.Bodies, tx)
+	stageState, err = state.StageState(stages.Bodies, tx, nil)
 	assert.NoError(t, err)
 	assert.Equal(t, 1000, int(stageState.BlockNumber))
 
-	stageState, err = state.StageState(stages.Senders, tx)
+	stageState, err = state.StageState(stages.Senders, tx, nil)
 	assert.NoError(t, err)
 	assert.Equal(t, 1500, int(stageState.BlockNumber))
 }
@@ -295,7 +295,7 @@ func TestUnwind(t *testing.T) {
 			},
 		},
 	}
-	state := New(s, []stages.SyncStage{s[0].ID, s[1].ID, s[2].ID, s[3].ID})
+	state := New(s, []stages.SyncStage{s[0].ID, s[1].ID, s[2].ID, s[3].ID}, nil)
 	db, tx := kv.NewTestTx(t)
 	err := state.Run(db, tx, true)
 	assert.NoError(t, err)
@@ -308,15 +308,15 @@ func TestUnwind(t *testing.T) {
 
 	assert.Equal(t, expectedFlow, flow)
 
-	stageState, err := state.StageState(stages.Headers, tx)
+	stageState, err := state.StageState(stages.Headers, tx, nil)
 	assert.NoError(t, err)
 	assert.Equal(t, 500, int(stageState.BlockNumber))
 
-	stageState, err = state.StageState(stages.Bodies, tx)
+	stageState, err = state.StageState(stages.Bodies, tx, nil)
 	assert.NoError(t, err)
 	assert.Equal(t, 500, int(stageState.BlockNumber))
 
-	stageState, err = state.StageState(stages.Senders, tx)
+	stageState, err = state.StageState(stages.Senders, tx, nil)
 	assert.NoError(t, err)
 	assert.Equal(t, 500, int(stageState.BlockNumber))
 
@@ -384,7 +384,7 @@ func TestUnwindEmptyUnwinder(t *testing.T) {
 			},
 		},
 	}
-	state := New(s, []stages.SyncStage{s[0].ID, s[1].ID, s[2].ID})
+	state := New(s, []stages.SyncStage{s[0].ID, s[1].ID, s[2].ID}, nil)
 	db, tx := kv.NewTestTx(t)
 	err := state.Run(db, tx, true)
 	assert.NoError(t, err)
@@ -397,15 +397,15 @@ func TestUnwindEmptyUnwinder(t *testing.T) {
 
 	assert.Equal(t, expectedFlow, flow)
 
-	stageState, err := state.StageState(stages.Headers, tx)
+	stageState, err := state.StageState(stages.Headers, tx, nil)
 	assert.NoError(t, err)
 	assert.Equal(t, 500, int(stageState.BlockNumber))
 
-	stageState, err = state.StageState(stages.Bodies, tx)
+	stageState, err = state.StageState(stages.Bodies, tx, nil)
 	assert.NoError(t, err)
 	assert.Equal(t, 2000, int(stageState.BlockNumber))
 
-	stageState, err = state.StageState(stages.Senders, tx)
+	stageState, err = state.StageState(stages.Senders, tx, nil)
 	assert.NoError(t, err)
 	assert.Equal(t, 500, int(stageState.BlockNumber))
 }
@@ -440,12 +440,12 @@ func TestSyncDoTwice(t *testing.T) {
 		},
 	}
 
-	state := New(s, nil)
+	state := New(s, nil, nil)
 	db, tx := kv.NewTestTx(t)
 	err := state.Run(db, tx, true)
 	assert.NoError(t, err)
 
-	state = New(s, nil)
+	state = New(s, nil, nil)
 	err = state.Run(db, tx, true)
 	assert.NoError(t, err)
 
@@ -455,15 +455,15 @@ func TestSyncDoTwice(t *testing.T) {
 	}
 	assert.Equal(t, expectedFlow, flow)
 
-	stageState, err := state.StageState(stages.Headers, tx)
+	stageState, err := state.StageState(stages.Headers, tx, nil)
 	assert.NoError(t, err)
 	assert.Equal(t, 200, int(stageState.BlockNumber))
 
-	stageState, err = state.StageState(stages.Bodies, tx)
+	stageState, err = state.StageState(stages.Bodies, tx, nil)
 	assert.NoError(t, err)
 	assert.Equal(t, 400, int(stageState.BlockNumber))
 
-	stageState, err = state.StageState(stages.Senders, tx)
+	stageState, err = state.StageState(stages.Senders, tx, nil)
 	assert.NoError(t, err)
 	assert.Equal(t, 600, int(stageState.BlockNumber))
 }
@@ -498,14 +498,14 @@ func TestStateSyncInterruptRestart(t *testing.T) {
 		},
 	}
 
-	state := New(s, nil)
+	state := New(s, nil, nil)
 	db, tx := kv.NewTestTx(t)
 	err := state.Run(db, tx, true)
 	assert.Equal(t, expectedErr, err)
 
 	expectedErr = nil
 
-	state = New(s, nil)
+	state = New(s, nil, nil)
 	err = state.Run(db, tx, true)
 	assert.NoError(t, err)
 
@@ -577,7 +577,7 @@ func TestSyncInterruptLongUnwind(t *testing.T) {
 			},
 		},
 	}
-	state := New(s, []stages.SyncStage{s[0].ID, s[1].ID, s[2].ID})
+	state := New(s, []stages.SyncStage{s[0].ID, s[1].ID, s[2].ID}, nil)
 	db, tx := kv.NewTestTx(t)
 	err := state.Run(db, tx, true)
 	assert.Error(t, errInterrupted, err)
@@ -599,15 +599,15 @@ func TestSyncInterruptLongUnwind(t *testing.T) {
 
 	assert.Equal(t, expectedFlow, flow)
 
-	stageState, err := state.StageState(stages.Headers, tx)
+	stageState, err := state.StageState(stages.Headers, tx, nil)
 	assert.NoError(t, err)
 	assert.Equal(t, 500, int(stageState.BlockNumber))
 
-	stageState, err = state.StageState(stages.Bodies, tx)
+	stageState, err = state.StageState(stages.Bodies, tx, nil)
 	assert.NoError(t, err)
 	assert.Equal(t, 500, int(stageState.BlockNumber))
 
-	stageState, err = state.StageState(stages.Senders, tx)
+	stageState, err = state.StageState(stages.Senders, tx, nil)
 	assert.NoError(t, err)
 	assert.Equal(t, 500, int(stageState.BlockNumber))
 }
diff --git a/turbo/stages/mock_sentry.go b/turbo/stages/mock_sentry.go
index c8585cc17d..b1bf864e07 100644
--- a/turbo/stages/mock_sentry.go
+++ b/turbo/stages/mock_sentry.go
@@ -227,77 +227,81 @@ func MockWithEverything(t *testing.T, gspec *core.Genesis, key *ecdsa.PrivateKey
 			panic(err)
 		}
 	}
-
-	mock.Sync = NewStagedSync(mock.Ctx, sm,
-		stagedsync.StageHeadersCfg(
-			mock.DB,
-			mock.downloader.Hd,
-			*mock.ChainConfig,
-			sendHeaderRequest,
-			propagateNewBlockHashes,
-			penalize,
-			cfg.BatchSize,
-		),
-		stagedsync.StageBlockHashesCfg(mock.DB, mock.tmpdir),
-		stagedsync.StageSnapshotHeadersCfg(mock.DB, ethconfig.Snapshot{Enabled: false}, nil, nil),
-		stagedsync.StageBodiesCfg(
-			mock.DB,
-			mock.downloader.Bd,
-			sendBodyRequest,
-			penalize,
-			blockPropagator,
-			cfg.BodyDownloadTimeoutSeconds,
-			*mock.ChainConfig,
-			cfg.BatchSize,
-		),
-		stagedsync.StageSnapshotBodiesCfg(
-			mock.DB,
-			ethconfig.Snapshot{Enabled: false},
-			nil, nil,
-			"",
-		),
-		stagedsync.StageSendersCfg(mock.DB, mock.ChainConfig, mock.tmpdir),
-		stagedsync.StageExecuteBlocksCfg(
-			mock.DB,
-			sm.Receipts,
-			sm.CallTraces,
-			sm.TEVM,
-			0,
-			cfg.BatchSize,
-			nil,
-			mock.ChainConfig,
-			mock.Engine,
-			&vm.Config{NoReceipts: !sm.Receipts},
-			nil,
-			cfg.StateStream,
-			mock.tmpdir,
-		),
-		stagedsync.StageTranspileCfg(
-			mock.DB,
-			cfg.BatchSize,
-			mock.ChainConfig,
-		),
-		stagedsync.StageSnapshotStateCfg(
-			mock.DB,
-			ethconfig.Snapshot{Enabled: false},
-			"",
-			nil, nil,
+	mock.Sync = stagedsync.New(
+		stagedsync.DefaultStages(
+			mock.Ctx, sm,
+			stagedsync.StageHeadersCfg(
+				mock.DB,
+				mock.downloader.Hd,
+				*mock.ChainConfig,
+				sendHeaderRequest,
+				propagateNewBlockHashes,
+				penalize,
+				cfg.BatchSize,
+			),
+			stagedsync.StageBlockHashesCfg(mock.DB, mock.tmpdir),
+			stagedsync.StageSnapshotHeadersCfg(mock.DB, ethconfig.Snapshot{Enabled: false}, nil, nil),
+			stagedsync.StageBodiesCfg(
+				mock.DB,
+				mock.downloader.Bd,
+				sendBodyRequest,
+				penalize,
+				blockPropagator,
+				cfg.BodyDownloadTimeoutSeconds,
+				*mock.ChainConfig,
+				cfg.BatchSize,
+			),
+			stagedsync.StageSnapshotBodiesCfg(
+				mock.DB,
+				ethconfig.Snapshot{Enabled: false},
+				nil, nil,
+				"",
+			),
+			stagedsync.StageSendersCfg(mock.DB, mock.ChainConfig, mock.tmpdir),
+			stagedsync.StageExecuteBlocksCfg(
+				mock.DB,
+				sm.Receipts,
+				sm.CallTraces,
+				sm.TEVM,
+				0,
+				cfg.BatchSize,
+				nil,
+				mock.ChainConfig,
+				mock.Engine,
+				&vm.Config{NoReceipts: !sm.Receipts},
+				nil,
+				cfg.StateStream,
+				mock.tmpdir,
+			),
+			stagedsync.StageTranspileCfg(
+				mock.DB,
+				cfg.BatchSize,
+				mock.ChainConfig,
+			),
+			stagedsync.StageSnapshotStateCfg(
+				mock.DB,
+				ethconfig.Snapshot{Enabled: false},
+				"",
+				nil, nil,
+			),
+			stagedsync.StageHashStateCfg(mock.DB, mock.tmpdir),
+			stagedsync.StageTrieCfg(mock.DB, true, true, mock.tmpdir),
+			stagedsync.StageHistoryCfg(mock.DB, mock.tmpdir),
+			stagedsync.StageLogIndexCfg(mock.DB, mock.tmpdir),
+			stagedsync.StageCallTracesCfg(mock.DB, 0, mock.tmpdir),
+			stagedsync.StageTxLookupCfg(mock.DB, mock.tmpdir),
+			stagedsync.StageTxPoolCfg(mock.DB, txPool, func() {
+				mock.StreamWg.Add(1)
+				go txpool.RecvTxMessageLoop(mock.Ctx, mock.SentryClient, mock.downloader, mock.TxPoolP2PServer.HandleInboundMessage, &mock.ReceiveWg)
+				go txpropagate.BroadcastPendingTxsToNetwork(mock.Ctx, txPool, mock.TxPoolP2PServer.RecentPeers, mock.downloader)
+				mock.StreamWg.Wait()
+				mock.TxPoolP2PServer.TxFetcher.Start()
+			}),
+			stagedsync.StageFinishCfg(mock.DB, mock.tmpdir, nil, nil),
+			true, /* test */
 		),
-		stagedsync.StageHashStateCfg(mock.DB, mock.tmpdir),
-		stagedsync.StageTrieCfg(mock.DB, true, true, mock.tmpdir),
-		stagedsync.StageHistoryCfg(mock.DB, mock.tmpdir),
-		stagedsync.StageLogIndexCfg(mock.DB, mock.tmpdir),
-		stagedsync.StageCallTracesCfg(mock.DB, 0, mock.tmpdir),
-		stagedsync.StageTxLookupCfg(mock.DB, mock.tmpdir),
-		stagedsync.StageTxPoolCfg(mock.DB, txPool, func() {
-			mock.StreamWg.Add(1)
-			go txpool.RecvTxMessageLoop(mock.Ctx, mock.SentryClient, mock.downloader, mock.TxPoolP2PServer.HandleInboundMessage, &mock.ReceiveWg)
-			go txpropagate.BroadcastPendingTxsToNetwork(mock.Ctx, txPool, mock.TxPoolP2PServer.RecentPeers, mock.downloader)
-			mock.StreamWg.Wait()
-			mock.TxPoolP2PServer.TxFetcher.Start()
-		}),
-		stagedsync.StageFinishCfg(mock.DB, mock.tmpdir, nil, nil),
-		true, /* test */
+		stagedsync.DefaultUnwindOrder,
+		stagedsync.DefaultPruneOrder,
 	)
 
 	miningConfig := cfg.Miner
@@ -318,7 +322,8 @@ func MockWithEverything(t *testing.T, gspec *core.Genesis, key *ecdsa.PrivateKey
 			stagedsync.StageTrieCfg(mock.DB, false, true, mock.tmpdir),
 			stagedsync.StageMiningFinishCfg(mock.DB, *mock.ChainConfig, mock.Engine, miner, mock.Ctx.Done()),
 		),
-		stagedsync.MiningUnwindOrder(),
+		stagedsync.MiningUnwindOrder,
+		stagedsync.MiningPruneOrder,
 	)
 
 	mock.StreamWg.Add(1)
diff --git a/turbo/stages/stageloop.go b/turbo/stages/stageloop.go
index 5e0ea5893e..d273d0b415 100644
--- a/turbo/stages/stageloop.go
+++ b/turbo/stages/stageloop.go
@@ -26,34 +26,6 @@ import (
 	"github.com/ledgerwatch/erigon/turbo/txpool"
 )
 
-func NewStagedSync(
-	ctx context.Context,
-	sm ethdb.StorageMode,
-	headers stagedsync.HeadersCfg,
-	blockHashes stagedsync.BlockHashesCfg,
-	snapshotHeader stagedsync.SnapshotHeadersCfg,
-	bodies stagedsync.BodiesCfg,
-	snapshotBodies stagedsync.SnapshotBodiesCfg,
-	senders stagedsync.SendersCfg,
-	exec stagedsync.ExecuteBlockCfg,
-	trans stagedsync.TranspileCfg,
-	snapshotState stagedsync.SnapshotStateCfg,
-	hashState stagedsync.HashStateCfg,
-	trieCfg stagedsync.TrieCfg,
-	history stagedsync.HistoryCfg,
-	logIndex stagedsync.LogIndexCfg,
-	callTraces stagedsync.CallTracesCfg,
-	txLookup stagedsync.TxLookupCfg,
-	txPool stagedsync.TxPoolCfg,
-	finish stagedsync.FinishCfg,
-	test bool,
-) *stagedsync.Sync {
-	return stagedsync.New(
-		stagedsync.DefaultStages(ctx, sm, headers, blockHashes, snapshotHeader, bodies, snapshotBodies, senders, exec, trans, snapshotState, hashState, trieCfg, history, logIndex, callTraces, txLookup, txPool, finish, test),
-		stagedsync.DefaultUnwindOrder(),
-	)
-}
-
 // StageLoop runs the continuous loop of staged sync
 func StageLoop(
 	ctx context.Context,
@@ -239,69 +211,75 @@ func NewStagedSync2(
 		pruningDistance = params.FullImmutabilityThreshold
 	}
 
-	return NewStagedSync(ctx, cfg.StorageMode,
-		stagedsync.StageHeadersCfg(
-			db,
-			controlServer.Hd,
-			*controlServer.ChainConfig,
-			controlServer.SendHeaderRequest,
-			controlServer.PropagateNewBlockHashes,
-			controlServer.Penalize,
-			cfg.BatchSize,
-		),
-		stagedsync.StageBlockHashesCfg(db, tmpdir),
-		stagedsync.StageSnapshotHeadersCfg(db, cfg.Snapshot, client, snapshotMigrator),
-		stagedsync.StageBodiesCfg(
-			db,
-			controlServer.Bd,
-			controlServer.SendBodyRequest,
-			controlServer.Penalize,
-			controlServer.BroadcastNewBlock,
-			cfg.BodyDownloadTimeoutSeconds,
-			*controlServer.ChainConfig,
-			cfg.BatchSize,
-		),
-		stagedsync.StageSnapshotBodiesCfg(db, cfg.Snapshot, client, snapshotMigrator, tmpdir),
-		stagedsync.StageSendersCfg(db, controlServer.ChainConfig, tmpdir),
-		stagedsync.StageExecuteBlocksCfg(
-			db,
-			cfg.StorageMode.Receipts,
-			cfg.StorageMode.CallTraces,
-			cfg.StorageMode.TEVM,
-			pruningDistance,
-			cfg.BatchSize,
-			nil,
-			controlServer.ChainConfig,
-			controlServer.Engine,
-			&vm.Config{NoReceipts: !cfg.StorageMode.Receipts, EnableTEMV: cfg.StorageMode.TEVM},
-			accumulator,
-			cfg.StateStream,
-			tmpdir,
-		),
-		stagedsync.StageTranspileCfg(
-			db,
-			cfg.BatchSize,
-			controlServer.ChainConfig,
+	return stagedsync.New(
+		stagedsync.DefaultStages(
+			ctx,
+			cfg.StorageMode,
+			stagedsync.StageHeadersCfg(
+				db,
+				controlServer.Hd,
+				*controlServer.ChainConfig,
+				controlServer.SendHeaderRequest,
+				controlServer.PropagateNewBlockHashes,
+				controlServer.Penalize,
+				cfg.BatchSize,
+			),
+			stagedsync.StageBlockHashesCfg(db, tmpdir),
+			stagedsync.StageSnapshotHeadersCfg(db, cfg.Snapshot, client, snapshotMigrator),
+			stagedsync.StageBodiesCfg(
+				db,
+				controlServer.Bd,
+				controlServer.SendBodyRequest,
+				controlServer.Penalize,
+				controlServer.BroadcastNewBlock,
+				cfg.BodyDownloadTimeoutSeconds,
+				*controlServer.ChainConfig,
+				cfg.BatchSize,
+			),
+			stagedsync.StageSnapshotBodiesCfg(db, cfg.Snapshot, client, snapshotMigrator, tmpdir),
+			stagedsync.StageSendersCfg(db, controlServer.ChainConfig, tmpdir),
+			stagedsync.StageExecuteBlocksCfg(
+				db,
+				cfg.StorageMode.Receipts,
+				cfg.StorageMode.CallTraces,
+				cfg.StorageMode.TEVM,
+				pruningDistance,
+				cfg.BatchSize,
+				nil,
+				controlServer.ChainConfig,
+				controlServer.Engine,
+				&vm.Config{NoReceipts: !cfg.StorageMode.Receipts, EnableTEMV: cfg.StorageMode.TEVM},
+				accumulator,
+				cfg.StateStream,
+				tmpdir,
+			),
+			stagedsync.StageTranspileCfg(
+				db,
+				cfg.BatchSize,
+				controlServer.ChainConfig,
+			),
+			stagedsync.StageSnapshotStateCfg(db, cfg.Snapshot, tmpdir, client, snapshotMigrator),
+			stagedsync.StageHashStateCfg(db, tmpdir),
+			stagedsync.StageTrieCfg(db, true, true, tmpdir),
+			stagedsync.StageHistoryCfg(db, tmpdir),
+			stagedsync.StageLogIndexCfg(db, tmpdir),
+			stagedsync.StageCallTracesCfg(db, 0, tmpdir),
+			stagedsync.StageTxLookupCfg(db, tmpdir),
+			stagedsync.StageTxPoolCfg(db, txPool, func() {
+				for i := range txPoolServer.Sentries {
+					go func(i int) {
+						txpool.RecvTxMessageLoop(ctx, txPoolServer.Sentries[i], controlServer, txPoolServer.HandleInboundMessage, nil)
+					}(i)
+					go func(i int) {
+						txpool.RecvPeersLoop(ctx, txPoolServer.Sentries[i], controlServer, txPoolServer.RecentPeers, nil)
+					}(i)
+				}
+				txPoolServer.TxFetcher.Start()
+			}),
+			stagedsync.StageFinishCfg(db, tmpdir, client, snapshotMigrator),
+			false, /* test */
 		),
-		stagedsync.StageSnapshotStateCfg(db, cfg.Snapshot, tmpdir, client, snapshotMigrator),
-		stagedsync.StageHashStateCfg(db, tmpdir),
-		stagedsync.StageTrieCfg(db, true, true, tmpdir),
-		stagedsync.StageHistoryCfg(db, tmpdir),
-		stagedsync.StageLogIndexCfg(db, tmpdir),
-		stagedsync.StageCallTracesCfg(db, 0, tmpdir),
-		stagedsync.StageTxLookupCfg(db, tmpdir),
-		stagedsync.StageTxPoolCfg(db, txPool, func() {
-			for i := range txPoolServer.Sentries {
-				go func(i int) {
-					txpool.RecvTxMessageLoop(ctx, txPoolServer.Sentries[i], controlServer, txPoolServer.HandleInboundMessage, nil)
-				}(i)
-				go func(i int) {
-					txpool.RecvPeersLoop(ctx, txPoolServer.Sentries[i], controlServer, txPoolServer.RecentPeers, nil)
-				}(i)
-			}
-			txPoolServer.TxFetcher.Start()
-		}),
-		stagedsync.StageFinishCfg(db, tmpdir, client, snapshotMigrator),
-		false, /* test */
+		stagedsync.DefaultUnwindOrder,
+		stagedsync.DefaultPruneOrder,
 	), nil
 }
-- 
GitLab