From c60aff8058085ba5597ada4b64f37dd220ec3c5f Mon Sep 17 00:00:00 2001
From: Igor Mandrigin <mandrigin@users.noreply.github.com>
Date: Tue, 26 May 2020 16:37:25 +0300
Subject: [PATCH] Promote hashed state initially (#577)
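
When executing over the plain state (core.UsePlainStateExecution), build the
hashed state before the final root-hash check: hash the keys of
PlainStateBucket and PlainContractCodeBucket, sort the entries in batches on
disk, and merge the sorted temp files into CurrentStateBucket and
ContractCodeBucket. The merge heap moves to stagedsync_utils.go and the test
helpers to stagedsync_testutil.go so they can be shared between stages.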

---
 eth/downloader/stagedsync_downloader.go       |  11 +-
 .../stagedsync_stage_execute_test.go          | 134 +-------
 eth/downloader/stagedsync_stage_hashcheck.go  | 324 +++++++++++++++++-
 .../stagedsync_stage_hashcheck_test.go        | 138 ++++++++
 eth/downloader/stagedsync_stage_indexes.go    |  49 +--
 eth/downloader/stagedsync_testutil.go         | 132 +++++++
 eth/downloader/stagedsync_utils.go            |  40 +++
 trie/flatdb_sub_trie_loader.go                |   2 +-
 trie/sub_trie_loader.go                       |   2 +-
 9 files changed, 642 insertions(+), 190 deletions(-)
 create mode 100644 eth/downloader/stagedsync_stage_hashcheck_test.go
 create mode 100644 eth/downloader/stagedsync_testutil.go
 create mode 100644 eth/downloader/stagedsync_utils.go

diff --git a/eth/downloader/stagedsync_downloader.go b/eth/downloader/stagedsync_downloader.go
index b9db114da4..851ab80bdc 100644
--- a/eth/downloader/stagedsync_downloader.go
+++ b/eth/downloader/stagedsync_downloader.go
@@ -8,7 +8,6 @@ import (
 )
 
 func (d *Downloader) doStagedSyncWithFetchers(p *peerConnection, headersFetchers []func() error) error {
-	fmt.Println("doStagedSyncWithFetchers")
 	log.Info("Sync stage 1/7. Downloading headers...")
 
 	var err error
@@ -39,9 +38,7 @@ func (d *Downloader) doStagedSyncWithFetchers(p *peerConnection, headersFetchers
 		case Execution:
 			err = unwindExecutionStage(unwindPoint, d.stateDB)
 		case HashCheck:
-			if !core.UsePlainStateExecution {
-				err = d.unwindHashCheckStage(unwindPoint)
-			}
+			err = unwindHashCheckStage(unwindPoint, d.stateDB)
 		case AccountHistoryIndex:
 			err = unwindAccountHistoryIndex(unwindPoint, d.stateDB, core.UsePlainStateExecution)
 		case StorageHistoryIndex:
@@ -96,10 +93,8 @@ func (d *Downloader) doStagedSyncWithFetchers(p *peerConnection, headersFetchers
 
 	// Further stages go here
 	log.Info("Sync stage 5/7. Validating final hash")
-	if !core.UsePlainStateExecution {
-		if err = d.spawnCheckFinalHashStage(syncHeadNumber); err != nil {
-			return err
-		}
+	if err = spawnCheckFinalHashStage(d.stateDB, syncHeadNumber, d.datadir); err != nil {
+		return err
 	}
 
 	log.Info("Sync stage 5/7. Validating final hash... Complete!")
diff --git a/eth/downloader/stagedsync_stage_execute_test.go b/eth/downloader/stagedsync_stage_execute_test.go
index e6970cd177..7d43c25b4f 100644
--- a/eth/downloader/stagedsync_stage_execute_test.go
+++ b/eth/downloader/stagedsync_stage_execute_test.go
@@ -1,34 +1,19 @@
 package downloader
 
 import (
-	"context"
-	"fmt"
-	"math/big"
 	"testing"
 
-	"github.com/holiman/uint256"
-	"github.com/stretchr/testify/assert"
-
-	"github.com/ledgerwatch/turbo-geth/common"
 	"github.com/ledgerwatch/turbo-geth/common/dbutils"
 	"github.com/ledgerwatch/turbo-geth/core"
-	"github.com/ledgerwatch/turbo-geth/core/state"
-	"github.com/ledgerwatch/turbo-geth/core/types/accounts"
 	"github.com/ledgerwatch/turbo-geth/ethdb"
 )
 
-const (
-	staticCodeStaticIncarnations         = iota // no incarnation changes, no code changes
-	changeCodeWithIncarnations                  // code changes with incarnation
-	changeCodeIndepenentlyOfIncarnations        // code changes with and without incarnation
-)
-
 func TestUnwindExecutionStageHashedStatic(t *testing.T) {
 	initialDb := ethdb.NewMemDatabase()
-	generateBlocks(t, 50, hashedWriterGen(initialDb), staticCodeStaticIncarnations)
+	generateBlocks(t, 1, 50, hashedWriterGen(initialDb), staticCodeStaticIncarnations)
 
 	mutation := ethdb.NewMemDatabase()
-	generateBlocks(t, 100, hashedWriterGen(mutation), staticCodeStaticIncarnations)
+	generateBlocks(t, 1, 100, hashedWriterGen(mutation), staticCodeStaticIncarnations)
 
 	err := SaveStageProgress(mutation, Execution, 100)
 	if err != nil {
@@ -45,10 +30,10 @@ func TestUnwindExecutionStageHashedStatic(t *testing.T) {
 
 func TestUnwindExecutionStageHashedWithIncarnationChanges(t *testing.T) {
 	initialDb := ethdb.NewMemDatabase()
-	generateBlocks(t, 50, hashedWriterGen(initialDb), changeCodeWithIncarnations)
+	generateBlocks(t, 1, 50, hashedWriterGen(initialDb), changeCodeWithIncarnations)
 
 	mutation := ethdb.NewMemDatabase()
-	generateBlocks(t, 100, hashedWriterGen(mutation), changeCodeWithIncarnations)
+	generateBlocks(t, 1, 100, hashedWriterGen(mutation), changeCodeWithIncarnations)
 
 	err := SaveStageProgress(mutation, Execution, 100)
 	if err != nil {
@@ -65,10 +50,10 @@ func TestUnwindExecutionStageHashedWithIncarnationChanges(t *testing.T) {
 func TestUnwindExecutionStageHashedWithCodeChanges(t *testing.T) {
 	t.Skip("not supported yet, to be restored")
 	initialDb := ethdb.NewMemDatabase()
-	generateBlocks(t, 50, hashedWriterGen(initialDb), changeCodeIndepenentlyOfIncarnations)
+	generateBlocks(t, 1, 50, hashedWriterGen(initialDb), changeCodeIndependentlyOfIncarnations)
 
 	mutation := ethdb.NewMemDatabase()
-	generateBlocks(t, 100, hashedWriterGen(mutation), changeCodeIndepenentlyOfIncarnations)
+	generateBlocks(t, 1, 100, hashedWriterGen(mutation), changeCodeIndependentlyOfIncarnations)
 
 	err := SaveStageProgress(mutation, Execution, 100)
 	if err != nil {
@@ -84,10 +69,10 @@ func TestUnwindExecutionStageHashedWithCodeChanges(t *testing.T) {
 
 func TestUnwindExecutionStagePlainStatic(t *testing.T) {
 	initialDb := ethdb.NewMemDatabase()
-	generateBlocks(t, 50, plainWriterGen(initialDb), staticCodeStaticIncarnations)
+	generateBlocks(t, 1, 50, plainWriterGen(initialDb), staticCodeStaticIncarnations)
 
 	mutation := ethdb.NewMemDatabase()
-	generateBlocks(t, 100, plainWriterGen(mutation), staticCodeStaticIncarnations)
+	generateBlocks(t, 1, 100, plainWriterGen(mutation), staticCodeStaticIncarnations)
 
 	err := SaveStageProgress(mutation, Execution, 100)
 	if err != nil {
@@ -104,10 +89,10 @@ func TestUnwindExecutionStagePlainStatic(t *testing.T) {
 
 func TestUnwindExecutionStagePlainWithIncarnationChanges(t *testing.T) {
 	initialDb := ethdb.NewMemDatabase()
-	generateBlocks(t, 50, plainWriterGen(initialDb), changeCodeWithIncarnations)
+	generateBlocks(t, 1, 50, plainWriterGen(initialDb), changeCodeWithIncarnations)
 
 	mutation := ethdb.NewMemDatabase()
-	generateBlocks(t, 100, plainWriterGen(mutation), changeCodeWithIncarnations)
+	generateBlocks(t, 1, 100, plainWriterGen(mutation), changeCodeWithIncarnations)
 
 	err := SaveStageProgress(mutation, Execution, 100)
 	if err != nil {
@@ -125,10 +110,10 @@ func TestUnwindExecutionStagePlainWithIncarnationChanges(t *testing.T) {
 func TestUnwindExecutionStagePlainWithCodeChanges(t *testing.T) {
 	t.Skip("not supported yet, to be restored")
 	initialDb := ethdb.NewMemDatabase()
-	generateBlocks(t, 50, plainWriterGen(initialDb), changeCodeIndepenentlyOfIncarnations)
+	generateBlocks(t, 1, 50, plainWriterGen(initialDb), changeCodeIndependentlyOfIncarnations)
 
 	mutation := ethdb.NewMemDatabase()
-	generateBlocks(t, 100, plainWriterGen(mutation), changeCodeIndepenentlyOfIncarnations)
+	generateBlocks(t, 1, 100, plainWriterGen(mutation), changeCodeIndependentlyOfIncarnations)
 
 	err := SaveStageProgress(mutation, Execution, 100)
 	if err != nil {
@@ -142,98 +127,3 @@ func TestUnwindExecutionStagePlainWithCodeChanges(t *testing.T) {
 
 	compareCurrentState(t, initialDb, mutation, dbutils.PlainStateBucket, dbutils.PlainContractCodeBucket)
 }
-
-func generateBlocks(t *testing.T, numberOfBlocks uint64, stateWriterGen stateWriterGen, difficulty int) {
-	from := uint64(1)
-	ctx := context.Background()
-	acc1 := accounts.NewAccount()
-	acc1.Incarnation = 1
-	acc := &acc1
-	acc.Initialised = true
-	var addr common.Address = common.HexToAddress("0x1234567890")
-	acc.Balance.SetUint64(0)
-	for blockNumber := from; blockNumber < from+numberOfBlocks; blockNumber++ {
-		updateIncarnation := difficulty != staticCodeStaticIncarnations && blockNumber%10 == 0
-		newAcc := acc.SelfCopy()
-		newAcc.Balance.SetUint64(blockNumber)
-		if updateIncarnation {
-			newAcc.Incarnation = acc.Incarnation + 1
-		}
-		blockWriter := stateWriterGen(blockNumber)
-
-		var oldValue, newValue uint256.Int
-		newValue.SetOne()
-		var location common.Hash
-		location.SetBytes(big.NewInt(int64(blockNumber)).Bytes())
-
-		if blockNumber == 1 {
-			err := blockWriter.CreateContract(addr)
-			if err != nil {
-				t.Fatal(err)
-			}
-		}
-		if blockNumber == 1 || updateIncarnation || difficulty == changeCodeIndepenentlyOfIncarnations {
-			code := []byte(fmt.Sprintf("acc-code-%v", blockNumber))
-			codeHash, _ := common.HashData(code)
-			if err := blockWriter.UpdateAccountCode(addr, newAcc.Incarnation, codeHash, code); err != nil {
-				t.Fatal(err)
-			}
-			newAcc.CodeHash = codeHash
-		}
-		if err := blockWriter.WriteAccountStorage(ctx, addr, newAcc.Incarnation, &location, &oldValue, &newValue); err != nil {
-			t.Fatal(err)
-		}
-		if err := blockWriter.UpdateAccountData(ctx, addr, acc /* original */, newAcc /* new account */); err != nil {
-			t.Fatal(err)
-		}
-		if err := blockWriter.WriteChangeSets(); err != nil {
-			t.Fatal(err)
-		}
-		acc = newAcc
-	}
-}
-
-func compareCurrentState(
-	t *testing.T,
-	db1 ethdb.Database,
-	db2 ethdb.Database,
-	buckets ...[]byte,
-) {
-	for _, bucket := range buckets {
-		compareBucket(t, db1, db2, bucket)
-	}
-}
-
-func compareBucket(t *testing.T, db1, db2 ethdb.Database, bucketName []byte) {
-	var err error
-
-	bucket1 := make(map[string][]byte)
-	err = db1.Walk(bucketName, nil, 0, func(k, v []byte) (bool, error) {
-		bucket1[string(k)] = v
-		return true, nil
-	})
-	assert.Nil(t, err)
-
-	bucket2 := make(map[string][]byte)
-	err = db2.Walk(bucketName, nil, 0, func(k, v []byte) (bool, error) {
-		bucket2[string(k)] = v
-		return true, nil
-	})
-	assert.Nil(t, err)
-
-	assert.Equal(t, bucket1 /*expected*/, bucket2 /*actual*/)
-}
-
-type stateWriterGen func(uint64) state.WriterWithChangeSets
-
-func hashedWriterGen(db ethdb.Database) stateWriterGen {
-	return func(blockNum uint64) state.WriterWithChangeSets {
-		return state.NewDbStateWriter(db, db, blockNum)
-	}
-}
-
-func plainWriterGen(db ethdb.Database) stateWriterGen {
-	return func(blockNum uint64) state.WriterWithChangeSets {
-		return state.NewPlainStateWriter(db, db, blockNum)
-	}
-}
diff --git a/eth/downloader/stagedsync_stage_hashcheck.go b/eth/downloader/stagedsync_stage_hashcheck.go
index 80e29096fd..d42a60b346 100644
--- a/eth/downloader/stagedsync_stage_hashcheck.go
+++ b/eth/downloader/stagedsync_stage_hashcheck.go
@@ -1,15 +1,32 @@
 package downloader
 
 import (
+	"bufio"
+	"bytes"
+	"container/heap"
 	"fmt"
+	"io"
+	"io/ioutil"
+	"os"
+	"runtime"
+	"sort"
 
+	"github.com/ledgerwatch/turbo-geth/common"
+	"github.com/ledgerwatch/turbo-geth/common/dbutils"
+	"github.com/ledgerwatch/turbo-geth/core"
+	"github.com/ledgerwatch/turbo-geth/core/rawdb"
+	"github.com/ledgerwatch/turbo-geth/ethdb"
 	"github.com/ledgerwatch/turbo-geth/log"
 	"github.com/ledgerwatch/turbo-geth/trie"
+
 	"github.com/pkg/errors"
+	"github.com/ugorji/go/codec"
 )
 
-func (d *Downloader) spawnCheckFinalHashStage(syncHeadNumber uint64) error {
-	hashProgress, err := GetStageProgress(d.stateDB, HashCheck)
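+// cbor is the codec handle shared by the encoders and decoders that write and
+// read the on-disk sort buffers.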
+var cbor codec.CborHandle
+
+func spawnCheckFinalHashStage(stateDB ethdb.Database, syncHeadNumber uint64, datadir string) error {
+	hashProgress, err := GetStageProgress(stateDB, HashCheck)
 	if err != nil {
 		return err
 	}
@@ -18,24 +35,36 @@ func (d *Downloader) spawnCheckFinalHashStage(syncHeadNumber uint64) error {
 	if hashProgress == 0 {
 		return nil
 	}
+
 	if hashProgress == syncHeadNumber {
 		// we already did hash check for this block
 		// we don't do the obvious `if hashProgress > syncHeadNumber` to support reorgs more naturally
 		return nil
 	}
 
-	syncHeadBlock := d.blockchain.GetBlockByNumber(syncHeadNumber)
+	hashedStatePromotion := stateDB.NewBatch()
 
-	// make sure that we won't write the the real DB
-	// should never be commited
-	euphemeralMutation := d.stateDB.NewBatch()
+	if core.UsePlainStateExecution {
+		err = promoteHashedState(hashedStatePromotion, hashProgress, datadir)
+		if err != nil {
+			return err
+		}
+	}
+
+	_, err = hashedStatePromotion.Commit()
+	if err != nil {
+		return err
+	}
+
+	hash := rawdb.ReadCanonicalHash(stateDB, syncHeadNumber)
+	syncHeadBlock := rawdb.ReadBlock(stateDB, hash, syncHeadNumber)
 
 	blockNr := syncHeadBlock.Header().Number.Uint64()
 
 	log.Info("Validating root hash", "block", blockNr, "blockRoot", syncHeadBlock.Root().Hex())
 	loader := trie.NewSubTrieLoader(blockNr)
 	rl := trie.NewRetainList(0)
-	subTries, err1 := loader.LoadFromFlatDB(euphemeralMutation, rl, [][]byte{nil}, []int{0}, false)
+	subTries, err1 := loader.LoadFromFlatDB(stateDB, rl, [][]byte{nil}, []int{0}, false)
 	if err1 != nil {
 		return errors.Wrap(err1, "checking root hash failed")
 	}
@@ -46,28 +75,24 @@ func (d *Downloader) spawnCheckFinalHashStage(syncHeadNumber uint64) error {
 		return fmt.Errorf("wrong trie root: %x, expected (from header): %x", subTries.Hashes[0], syncHeadBlock.Root())
 	}
 
-	return SaveStageProgress(d.stateDB, HashCheck, blockNr)
+	return SaveStageProgress(stateDB, HashCheck, blockNr)
 }
 
-func (d *Downloader) unwindHashCheckStage(unwindPoint uint64) error {
+func unwindHashCheckStage(unwindPoint uint64, stateDB ethdb.Database) error {
 	// Currently it does not require unwinding because it does not create any intermediate hash records
 	// and recomputes the state root from scratch
-	lastProcessedBlockNumber, err := GetStageProgress(d.stateDB, HashCheck)
+	lastProcessedBlockNumber, err := GetStageProgress(stateDB, HashCheck)
 	if err != nil {
 		return fmt.Errorf("unwind HashCheck: get stage progress: %v", err)
 	}
-	unwindPoint, err1 := GetStageUnwind(d.stateDB, HashCheck)
-	if err1 != nil {
-		return err1
-	}
 	if unwindPoint >= lastProcessedBlockNumber {
-		err = SaveStageUnwind(d.stateDB, HashCheck, 0)
+		err = SaveStageUnwind(stateDB, HashCheck, 0)
 		if err != nil {
 			return fmt.Errorf("unwind HashCheck: reset: %v", err)
 		}
 		return nil
 	}
-	mutation := d.stateDB.NewBatch()
+	mutation := stateDB.NewBatch()
 	err = SaveStageUnwind(mutation, HashCheck, 0)
 	if err != nil {
 		return fmt.Errorf("unwind HashCheck: reset: %v", err)
@@ -78,3 +103,270 @@ func (d *Downloader) unwindHashCheckStage(unwindPoint uint64) error {
 	}
 	return nil
 }
+
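+// promoteHashedState converts the plain state into its hashed form. Only the
+// initial promotion from scratch (progress == 0) is implemented; incremental
+// promotion returns an error for now.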
+func promoteHashedState(db ethdb.Database, progress uint64, datadir string) error {
+	if progress == 0 {
+		return promoteHashedStateCleanly(db, datadir)
+	}
+	return errors.New("incremental state promotion not implemented")
+}
+
+func promoteHashedStateCleanly(db ethdb.Database, datadir string) error {
+	err := copyBucket(datadir, db, dbutils.PlainStateBucket, dbutils.CurrentStateBucket, transformPlainStateKey)
+	if err != nil {
+		return err
+	}
+	return copyBucket(datadir, db, dbutils.PlainContractCodeBucket, dbutils.ContractCodeBucket, transformContractCodeKey)
+}
+
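+// copyBucket walks sourceBucket, transforms every key with transformKeyFunc
+// and copies the entries into destBucket. Because hashing destroys the key
+// ordering, entries are accumulated in an in-memory buffer that is sorted and
+// spilled to a temp file in datadir whenever it grows past OptimalSize; the
+// sorted files are then merged into destBucket in a single pass.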
+func copyBucket(
+	datadir string,
+	db ethdb.Database,
+	sourceBucket,
+	destBucket []byte,
+	transformKeyFunc func([]byte) ([]byte, error)) error {
+
+	var m runtime.MemStats
+
+	buffer := newSortableBuffer()
+	files := []string{}
+
+	defer func() {
+		deleteFiles(files)
+	}()
+
+	err := db.Walk(sourceBucket, nil, 0, func(k, v []byte) (bool, error) {
+		newK, err := transformKeyFunc(k)
+		if err != nil {
+			return false, err
+		}
+		buffer.Put(newK, v)
+
+		bufferSize := buffer.Size()
+		if bufferSize >= buffer.OptimalSize {
+			sort.Sort(buffer)
+			file, err := buffer.FlushToDisk(datadir)
+			if err != nil {
+				return false, err
+			}
+			if len(file) > 0 {
+				files = append(files, file)
+				// read memory stats just before logging so the reported numbers are fresh
+				runtime.ReadMemStats(&m)
+				log.Info("Plain -> Hashed / created a buffer file",
+					"bucket", string(sourceBucket),
+					"name", file,
+					"size", bufferSize,
+					"plainKey", fmt.Sprintf("%x...", k[:4]),
+					"alloc", int(m.Alloc/1024), "sys", int(m.Sys/1024), "numGC", int(m.NumGC))
+			}
+		}
+		return true, nil
+	})
+	if err != nil {
+		return err
+	}
+
+	sort.Sort(buffer)
+	var file string
+	bufferSize := buffer.Size()
+	file, err = buffer.FlushToDisk(datadir)
+	if err != nil {
+		return err
+	}
+	if len(file) > 0 {
+		files = append(files, file)
+		// refresh memory stats for the final log line
+		runtime.ReadMemStats(&m)
+		log.Info("Plain -> Hashed / created a buffer file (final)",
+			"bucket", string(sourceBucket),
+			"name", file,
+			"size", bufferSize,
+			"alloc", int(m.Alloc/1024), "sys", int(m.Sys/1024), "numGC", int(m.NumGC))
+	}
+	return mergeTempFilesIntoBucket(db, files, destBucket)
+}
+
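+// transformPlainStateKey converts a plain state key into its hashed-state
+// equivalent: an account address becomes keccak(address), and a composite
+// storage key address|incarnation|location becomes
+// keccak(address)|incarnation|keccak(location).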
+func transformPlainStateKey(key []byte) ([]byte, error) {
+	switch len(key) {
+	case common.AddressLength:
+		// account
+		hash, err := common.HashData(key)
+		return hash[:], err
+	case common.AddressLength + common.IncarnationLength + common.HashLength:
+		// storage
+		address, incarnation, key := dbutils.PlainParseCompositeStorageKey(key)
+		addrHash, err := common.HashData(address[:])
+		if err != nil {
+			return nil, err
+		}
+		secKey, err := common.HashData(key[:])
+		if err != nil {
+			return nil, err
+		}
+		compositeKey := dbutils.GenerateCompositeStorageKey(addrHash, incarnation, secKey)
+		return compositeKey, nil
+	default:
+		// no other keys are supported
+		return nil, fmt.Errorf("could not convert key from plain to hashed, unexpected len: %d", len(key))
+	}
+}
+
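+// transformContractCodeKey converts a plain contract code key
+// address|incarnation into its hashed equivalent keccak(address)|incarnation.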
+func transformContractCodeKey(key []byte) ([]byte, error) {
+	if len(key) != common.AddressLength+common.IncarnationLength {
+		return nil, fmt.Errorf("could not convert code key from plain to hashed, unexpected len: %d", len(key))
+	}
+	address, incarnation := dbutils.PlainParseStoragePrefix(key)
+
+	addrHash, err := common.HashData(address[:])
+	if err != nil {
+		return nil, err
+	}
+
+	compositeKey := dbutils.GenerateStoragePrefix(addrHash[:], incarnation)
+	return compositeKey, nil
+}
+
+type sortableBufferEntry struct {
+	key   []byte
+	value []byte
+}
+
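+// sortableBuffer accumulates key/value pairs and implements sort.Interface so
+// that entries can be ordered by key before being flushed to disk; size tracks
+// the accumulated payload in bytes so the caller knows when to spill.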
+type sortableBuffer struct {
+	entries     []sortableBufferEntry
+	size        int
+	OptimalSize int
+	encoder     *codec.Encoder
+}
+
+func (b *sortableBuffer) Put(k, v []byte) {
+	b.size += len(k)
+	b.size += len(v)
+	b.entries = append(b.entries, sortableBufferEntry{k, v})
+}
+
+func (b *sortableBuffer) Size() int {
+	return b.size
+}
+
+func (b *sortableBuffer) Len() int {
+	return len(b.entries)
+}
+
+func (b *sortableBuffer) Less(i, j int) bool {
+	return bytes.Compare(b.entries[i].key, b.entries[j].key) < 0
+}
+
+func (b *sortableBuffer) Swap(i, j int) {
+	b.entries[i], b.entries[j] = b.entries[j], b.entries[i]
+}
+
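+// FlushToDisk writes the (pre-sorted) entries into a fresh temp file in
+// datadir and truncates the buffer, keeping its capacity for reuse. An empty
+// buffer yields an empty filename.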
+func (b *sortableBuffer) FlushToDisk(datadir string) (string, error) {
+	if len(b.entries) == 0 {
+		return "", nil
+	}
+	bufferFile, err := ioutil.TempFile(datadir, "tg-sync-sortable-buf")
+	if err != nil {
+		return "", err
+	}
+	defer bufferFile.Close() //nolint:errcheck
+
+	filename := bufferFile.Name()
+	w := bufio.NewWriter(bufferFile)
+	defer w.Flush() //nolint:errcheck
+	b.encoder.Reset(w)
+
+	for i := range b.entries {
+		err = writeToDisk(b.encoder, b.entries[i].key, b.entries[i].value)
+		if err != nil {
+			return "", fmt.Errorf("error writing entries to disk: %v", err)
+		}
+	}
+
+	b.entries = b.entries[:0] // keep the capacity
+	return filename, nil
+}
+
+func newSortableBuffer() *sortableBuffer {
+	return &sortableBuffer{
+		entries:     make([]sortableBufferEntry, 0),
+		size:        0,
+		OptimalSize: 256 * 1024 * 1024, /* 256 MiB */
+		encoder:     codec.NewEncoder(nil, &cbor),
+	}
+}
+
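+// writeToDisk and readElementFromDisk frame every buffer entry as a CBOR
+// array of two byte strings: [key, value].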
+func writeToDisk(encoder *codec.Encoder, key []byte, value []byte) error {
+	toWrite := [][]byte{key, value}
+	return encoder.Encode(toWrite)
+}
+
+func readElementFromDisk(decoder *codec.Decoder) ([]byte, []byte, error) {
+	result := make([][]byte, 2)
+	err := decoder.Decode(&result)
+	return result[0], result[1], err
+}
+
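+// mergeTempFilesIntoBucket performs a k-way merge of the sorted temp files:
+// the current entry of every file sits in a min-heap, so popping the heap
+// yields keys in globally sorted order, and the entries are written to the
+// bucket in ideally-sized batches.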
+func mergeTempFilesIntoBucket(db ethdb.Database, files []string, bucket []byte) error {
+	decoder := codec.NewDecoder(nil, &cbor)
+	var m runtime.MemStats
+	h := &Heap{}
+	heap.Init(h)
+	readers := make([]io.Reader, len(files))
+	for i, filename := range files {
+		if f, err := os.Open(filename); err == nil {
+			readers[i] = bufio.NewReader(f)
+			defer f.Close() //nolint:errcheck
+		} else {
+			return err
+		}
+		decoder.Reset(readers[i])
+		if key, value, err := readElementFromDisk(decoder); err == nil {
+			he := HeapElem{key, i, value}
+			heap.Push(h, he)
+		} else /* we must have at least one entry per file */ {
+			return fmt.Errorf("error reading first readers: n=%d current=%d filename=%s err=%v",
+				len(files), i, filename, err)
+		}
+	}
+	batch := db.NewBatch()
+	for h.Len() > 0 {
+		element := (heap.Pop(h)).(HeapElem)
+		reader := readers[element.timeIdx]
+		if err := batch.Put(bucket, element.key, element.value); err != nil {
+			return err
+		}
+		batchSize := batch.BatchSize()
+		if batchSize > batch.IdealBatchSize() {
+			if _, err := batch.Commit(); err != nil {
+				return err
+			}
+			// read memory stats so the logged numbers reflect the current heap
+			runtime.ReadMemStats(&m)
+			log.Info(
+				"Committed hashed state batch",
+				"bucket", string(bucket),
+				"size", common.StorageSize(batchSize),
+				"hashedKey", fmt.Sprintf("%x...", element.key[:4]),
+				"alloc", int(m.Alloc/1024), "sys", int(m.Sys/1024), "numGC", int(m.NumGC))
+		}
+		var err error
+		decoder.Reset(reader)
+		if element.key, element.value, err = readElementFromDisk(decoder); err == nil {
+			heap.Push(h, element)
+		} else if err != io.EOF {
+			return fmt.Errorf("error while reading next element from disk: %v", err)
+		}
+	}
+	_, err := batch.Commit()
+	return err
+}
+
+func deleteFiles(files []string) {
+	for _, filename := range files {
+		err := os.Remove(filename)
+		if err != nil {
+			log.Warn("promoting hashed state, error while removing temp file", "file", filename, "err", err)
+		} else {
+			log.Warn("promoting hashed state, removed temp", "file", filename)
+		}
+	}
+}
diff --git a/eth/downloader/stagedsync_stage_hashcheck_test.go b/eth/downloader/stagedsync_stage_hashcheck_test.go
new file mode 100644
index 0000000000..a8bfd8beaf
--- /dev/null
+++ b/eth/downloader/stagedsync_stage_hashcheck_test.go
@@ -0,0 +1,138 @@
+package downloader
+
+import (
+	"bytes"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"testing"
+
+	"github.com/ledgerwatch/turbo-geth/common/dbutils"
+	"github.com/ledgerwatch/turbo-geth/ethdb"
+	"github.com/stretchr/testify/assert"
+	"github.com/ugorji/go/codec"
+)
+
+func TestWriteAndReadBufferEntry(t *testing.T) {
+	buffer := bytes.NewBuffer(make([]byte, 0))
+	encoder := codec.NewEncoder(buffer, &cbor)
+
+	keys := make([]string, 100)
+	vals := make([]string, 100)
+
+	for i := range keys {
+		keys[i] = fmt.Sprintf("key-%d", i)
+		vals[i] = fmt.Sprintf("value-%d", i)
+	}
+
+	for i := range keys {
+		if err := writeToDisk(encoder, []byte(keys[i]), []byte(vals[i])); err != nil {
+			t.Error(err)
+		}
+	}
+
+	bb := buffer.Bytes()
+
+	readBuffer := bytes.NewReader(bb)
+
+	decoder := codec.NewDecoder(readBuffer, &cbor)
+
+	for i := range keys {
+		k, v, err := readElementFromDisk(decoder)
+		if err != nil {
+			t.Error(err)
+		}
+		assert.Equal(t, keys[i], string(k))
+		assert.Equal(t, vals[i], string(v))
+	}
+
+	_, _, err := readElementFromDisk(decoder)
+	assert.Equal(t, io.EOF, err)
+}
+
+func getDataDir() string {
+	name, err := ioutil.TempDir("", "geth-tests-staged-sync")
+	if err != nil {
+		panic(err)
+	}
+	return name
+}
+
+func TestPromoteHashedStateClearState(t *testing.T) {
+	db1 := ethdb.NewMemDatabase()
+	db2 := ethdb.NewMemDatabase()
+
+	generateBlocks(t, 1, 50, hashedWriterGen(db1), changeCodeWithIncarnations)
+
+	generateBlocks(t, 1, 50, plainWriterGen(db2), changeCodeWithIncarnations)
+
+	m2 := db2.NewBatch()
+	err := promoteHashedState(m2, 0, getDataDir())
+	if err != nil {
+		t.Errorf("error while promoting state: %v", err)
+	}
+	_, err = m2.Commit()
+	if err != nil {
+		t.Errorf("error while commiting state: %v", err)
+	}
+
+	compareCurrentState(t, db1, db2, dbutils.CurrentStateBucket, dbutils.ContractCodeBucket)
+}
+
+func TestPromoteHashedStateIncremental(t *testing.T) {
+	t.Skip("not implemented yet")
+	db1 := ethdb.NewMemDatabase()
+	db2 := ethdb.NewMemDatabase()
+
+	generateBlocks(t, 1, 50, hashedWriterGen(db1), changeCodeWithIncarnations)
+	generateBlocks(t, 1, 50, plainWriterGen(db2), changeCodeWithIncarnations)
+
+	m2 := db2.NewBatch()
+	err := promoteHashedState(m2, 0, getDataDir())
+	if err != nil {
+		t.Errorf("error while promoting state: %v", err)
+	}
+	_, err = m2.Commit()
+	if err != nil {
+		t.Errorf("error while commiting state: %v", err)
+	}
+
+	generateBlocks(t, 51, 50, hashedWriterGen(db1), changeCodeWithIncarnations)
+	generateBlocks(t, 51, 50, plainWriterGen(db2), changeCodeWithIncarnations)
+
+	m2 = db2.NewBatch()
+	err = promoteHashedState(m2, 50, getDataDir())
+	if err != nil {
+		t.Errorf("error while promoting state: %v", err)
+	}
+	_, err = m2.Commit()
+	if err != nil {
+		t.Errorf("error while commiting state: %v", err)
+	}
+
+	compareCurrentState(t, db1, db2, dbutils.CurrentStateBucket, dbutils.ContractCodeBucket)
+}
+
+func TestPromoteHashedStateIncrementalMixed(t *testing.T) {
+	t.Skip("not implemented yet")
+	db1 := ethdb.NewMemDatabase()
+	db2 := ethdb.NewMemDatabase()
+
+	generateBlocks(t, 1, 100, hashedWriterGen(db1), changeCodeWithIncarnations)
+	generateBlocks(t, 1, 50, hashedWriterGen(db1), changeCodeWithIncarnations)
+	generateBlocks(t, 51, 50, plainWriterGen(db2), changeCodeWithIncarnations)
+
+	m2 := db2.NewBatch()
+	err := promoteHashedState(m2, 50, getDataDir())
+	if err != nil {
+		t.Errorf("error while promoting state: %v", err)
+	}
+
+	_, err = m2.Commit()
+	if err != nil {
+		t.Errorf("error while commiting state: %v", err)
+	}
+
+	compareCurrentState(t, db1, db2, dbutils.CurrentStateBucket, dbutils.ContractCodeBucket)
+}
diff --git a/eth/downloader/stagedsync_stage_indexes.go b/eth/downloader/stagedsync_stage_indexes.go
index 7b0412c97a..3666f98c4c 100644
--- a/eth/downloader/stagedsync_stage_indexes.go
+++ b/eth/downloader/stagedsync_stage_indexes.go
@@ -6,17 +6,18 @@ import (
 	"container/heap"
 	"encoding/binary"
 	"fmt"
+	"io"
+	"io/ioutil"
+	"os"
+	"runtime"
+	"sort"
+
 	"github.com/ledgerwatch/turbo-geth/common"
 	"github.com/ledgerwatch/turbo-geth/common/changeset"
 	"github.com/ledgerwatch/turbo-geth/common/dbutils"
 	"github.com/ledgerwatch/turbo-geth/core"
 	"github.com/ledgerwatch/turbo-geth/ethdb"
 	"github.com/ledgerwatch/turbo-geth/log"
-	"io"
-	"io/ioutil"
-	"os"
-	"runtime"
-	"sort"
 )
 
 func fillChangeSetBuffer(db ethdb.Database, bucket []byte, blockNum uint64, changesets []byte, offsets []int, blockNums []uint64) (bool, uint64, []int, []uint64, error) {
@@ -87,42 +88,6 @@ func writeBufferMapToTempFile(datadir string, pattern string, bufferMap map[stri
 	return filename, nil
 }
 
-type HeapElem struct {
-	key     []byte
-	timeIdx int
-}
-
-type Heap []HeapElem
-
-func (h Heap) Len() int {
-	return len(h)
-}
-
-func (h Heap) Less(i, j int) bool {
-	if c := bytes.Compare(h[i].key, h[j].key); c != 0 {
-		return c < 0
-	}
-	return h[i].timeIdx < h[j].timeIdx
-}
-
-func (h Heap) Swap(i, j int) {
-	h[i], h[j] = h[j], h[i]
-}
-
-func (h *Heap) Push(x interface{}) {
-	// Push and Pop use pointer receivers because they modify the slice's length,
-	// not just its contents.
-	*h = append(*h, x.(HeapElem))
-}
-
-func (h *Heap) Pop() interface{} {
-	old := *h
-	n := len(old)
-	x := old[n-1]
-	*h = old[0 : n-1]
-	return x
-}
-
 func mergeFilesIntoBucket(bufferFileNames []string, db ethdb.Database, bucket []byte, keyLength int) error {
 	var m runtime.MemStats
 	h := &Heap{}
@@ -139,7 +104,7 @@ func mergeFilesIntoBucket(bufferFileNames []string, db ethdb.Database, bucket []
 		// Read first key
 		keyBuf := make([]byte, keyLength)
 		if n, err := io.ReadFull(readers[i], keyBuf); err == nil && n == keyLength {
-			heap.Push(h, HeapElem{keyBuf, i})
+			heap.Push(h, HeapElem{keyBuf, i, nil})
 		} else {
 			return fmt.Errorf("init reading from account buffer file: %d %x %v", n, keyBuf[:n], err)
 		}
diff --git a/eth/downloader/stagedsync_testutil.go b/eth/downloader/stagedsync_testutil.go
new file mode 100644
index 0000000000..4fbf45a304
--- /dev/null
+++ b/eth/downloader/stagedsync_testutil.go
@@ -0,0 +1,132 @@
+package downloader
+
+import (
+	"context"
+	"fmt"
+	"math/big"
+	"testing"
+
+	"github.com/holiman/uint256"
+	"github.com/ledgerwatch/turbo-geth/common"
+	"github.com/ledgerwatch/turbo-geth/core/state"
+	"github.com/ledgerwatch/turbo-geth/core/types/accounts"
+	"github.com/ledgerwatch/turbo-geth/ethdb"
+	"github.com/stretchr/testify/assert"
+)
+
+const (
+	staticCodeStaticIncarnations          = iota // no incarnation changes, no code changes
+	changeCodeWithIncarnations                   // code changes with incarnation
+	changeCodeIndependentlyOfIncarnations        // code changes with and without incarnation
+)
+
+func compareCurrentState(
+	t *testing.T,
+	db1 ethdb.Database,
+	db2 ethdb.Database,
+	buckets ...[]byte,
+) {
+	for _, bucket := range buckets {
+		compareBucket(t, db1, db2, bucket)
+	}
+}
+
+func compareBucket(t *testing.T, db1, db2 ethdb.Database, bucketName []byte) {
+	var err error
+
+	bucket1 := make(map[string][]byte)
+	err = db1.Walk(bucketName, nil, 0, func(k, v []byte) (bool, error) {
+		bucket1[string(k)] = v
+		return true, nil
+	})
+	assert.Nil(t, err)
+
+	bucket2 := make(map[string][]byte)
+	err = db2.Walk(bucketName, nil, 0, func(k, v []byte) (bool, error) {
+		bucket2[string(k)] = v
+		return true, nil
+	})
+	assert.Nil(t, err)
+
+	assert.Equal(t, bucket1 /*expected*/, bucket2 /*actual*/)
+}
+
+type stateWriterGen func(uint64) state.WriterWithChangeSets
+
+func hashedWriterGen(db ethdb.Database) stateWriterGen {
+	return func(blockNum uint64) state.WriterWithChangeSets {
+		return state.NewDbStateWriter(db, db, blockNum)
+	}
+}
+
+func plainWriterGen(db ethdb.Database) stateWriterGen {
+	return func(blockNum uint64) state.WriterWithChangeSets {
+		return state.NewPlainStateWriter(db, db, blockNum)
+	}
+}
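+
+// generateBlocks writes numberOfBlocks artificial state updates, starting at
+// block `from`, for two test accounts (a contract and a non-contract one),
+// optionally bumping the contract's incarnation and code according to the
+// difficulty mode.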
+func generateBlocks(t *testing.T, from uint64, numberOfBlocks uint64, stateWriterGen stateWriterGen, difficulty int) {
+	acc1 := accounts.NewAccount()
+	acc1.Incarnation = 1
+	acc1.Initialised = true
+	acc1.Balance.SetUint64(0)
+
+	acc2 := accounts.NewAccount()
+	acc2.Incarnation = 0
+	acc2.Initialised = true
+	acc2.Balance.SetUint64(0)
+
+	testAccounts := []*accounts.Account{
+		&acc1,
+		&acc2,
+	}
+	ctx := context.Background()
+
+	for blockNumber := from; blockNumber < from+numberOfBlocks; blockNumber++ {
+		updateIncarnation := difficulty != staticCodeStaticIncarnations && blockNumber%10 == 0
+		blockWriter := stateWriterGen(blockNumber)
+
+		for i, oldAcc := range testAccounts {
+			addr := common.HexToAddress(fmt.Sprintf("0x1234567890%d", i))
+
+			newAcc := oldAcc.SelfCopy()
+			newAcc.Balance.SetUint64(blockNumber)
+			if updateIncarnation && oldAcc.Incarnation > 0 /* only update for contracts */ {
+				newAcc.Incarnation = oldAcc.Incarnation + 1
+			}
+
+			if blockNumber == 1 && newAcc.Incarnation > 0 {
+				err := blockWriter.CreateContract(addr)
+				if err != nil {
+					t.Fatal(err)
+				}
+			}
+			if blockNumber == 1 || updateIncarnation || difficulty == changeCodeIndependentlyOfIncarnations {
+				if newAcc.Incarnation > 0 {
+					code := []byte(fmt.Sprintf("acc-code-%v", blockNumber))
+					codeHash, _ := common.HashData(code)
+					if err := blockWriter.UpdateAccountCode(addr, newAcc.Incarnation, codeHash, code); err != nil {
+						t.Fatal(err)
+					}
+					newAcc.CodeHash = codeHash
+				}
+			}
+
+			if newAcc.Incarnation > 0 {
+				var oldValue, newValue uint256.Int
+				newValue.SetOne()
+				var location common.Hash
+				location.SetBytes(big.NewInt(int64(blockNumber)).Bytes())
+				if err := blockWriter.WriteAccountStorage(ctx, addr, newAcc.Incarnation, &location, &oldValue, &newValue); err != nil {
+					t.Fatal(err)
+				}
+			}
+			if err := blockWriter.UpdateAccountData(ctx, addr, oldAcc /* original */, newAcc /* new account */); err != nil {
+				t.Fatal(err)
+			}
+			if err := blockWriter.WriteChangeSets(); err != nil {
+				t.Fatal(err)
+			}
+			testAccounts[i] = newAcc
+		}
+	}
+}
diff --git a/eth/downloader/stagedsync_utils.go b/eth/downloader/stagedsync_utils.go
new file mode 100644
index 0000000000..26501a98a4
--- /dev/null
+++ b/eth/downloader/stagedsync_utils.go
@@ -0,0 +1,40 @@
+package downloader
+
+import "bytes"
+
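+// HeapElem is an entry of the merge heap: the current key/value pair read
+// from temp file number timeIdx. Heap orders elements by key, breaking ties
+// by timeIdx, which keeps a k-way merge stable.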
+type HeapElem struct {
+	key     []byte
+	timeIdx int
+	value   []byte
+}
+
+type Heap []HeapElem
+
+func (h Heap) Len() int {
+	return len(h)
+}
+
+func (h Heap) Less(i, j int) bool {
+	if c := bytes.Compare(h[i].key, h[j].key); c != 0 {
+		return c < 0
+	}
+	return h[i].timeIdx < h[j].timeIdx
+}
+
+func (h Heap) Swap(i, j int) {
+	h[i], h[j] = h[j], h[i]
+}
+
+func (h *Heap) Push(x interface{}) {
+	// Push and Pop use pointer receivers because they modify the slice's length,
+	// not just its contents.
+	*h = append(*h, x.(HeapElem))
+}
+
+func (h *Heap) Pop() interface{} {
+	old := *h
+	n := len(old)
+	x := old[n-1]
+	*h = old[0 : n-1]
+	return x
+}
diff --git a/trie/flatdb_sub_trie_loader.go b/trie/flatdb_sub_trie_loader.go
index 1f62c87a28..dbda82d7e7 100644
--- a/trie/flatdb_sub_trie_loader.go
+++ b/trie/flatdb_sub_trie_loader.go
@@ -105,7 +105,7 @@ func NewFlatDbSubTrieLoader() *FlatDbSubTrieLoader {
 }
 
 // Reset prepares the loader for reuse
-func (fstl *FlatDbSubTrieLoader) Reset(db ethdb.Database, rl RetainDecider, dbPrefixes [][]byte, fixedbits []int, trace bool) error {
+func (fstl *FlatDbSubTrieLoader) Reset(db ethdb.Getter, rl RetainDecider, dbPrefixes [][]byte, fixedbits []int, trace bool) error {
 	fstl.defaultReceiver.Reset(rl, trace)
 	fstl.receiver = fstl.defaultReceiver
 	fstl.rangeIdx = 0
diff --git a/trie/sub_trie_loader.go b/trie/sub_trie_loader.go
index c66eceb4d4..5bc4d20b23 100644
--- a/trie/sub_trie_loader.go
+++ b/trie/sub_trie_loader.go
@@ -59,7 +59,7 @@ func (stl *SubTrieLoader) LoadSubTries(db ethdb.Database, blockNr uint64, rl Ret
 	return stl.LoadFromFlatDB(db, rl, dbPrefixes, fixedbits, trace)
 }
 
-func (stl *SubTrieLoader) LoadFromFlatDB(db ethdb.Database, rl RetainDecider, dbPrefixes [][]byte, fixedbits []int, trace bool) (SubTries, error) {
+func (stl *SubTrieLoader) LoadFromFlatDB(db ethdb.Getter, rl RetainDecider, dbPrefixes [][]byte, fixedbits []int, trace bool) (SubTries, error) {
 	loader := NewFlatDbSubTrieLoader()
 	if err1 := loader.Reset(db, rl, dbPrefixes, fixedbits, trace); err1 != nil {
 		return SubTries{}, err1
-- 
GitLab