From 55599ee95d4151a2502465e0afc7c47bd1acba77 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?P=C3=A9ter=20Szil=C3=A1gyi?= <peterke@gmail.com>
Date: Mon, 5 Feb 2018 18:40:32 +0200
Subject: [PATCH] core, trie: intermediate mempool between trie and database
 (#15857)

This commit reduces database I/O by not writing every state trie to disk.
---
 accounts/abi/bind/backends/simulated.go |  14 +-
 cmd/evm/runner.go                       |   4 +-
 cmd/geth/chaincmd.go                    |   4 +-
 cmd/geth/main.go                        |   3 +
 cmd/geth/usage.go                       |   6 +-
 cmd/utils/cmd.go                        |  26 +-
 cmd/utils/flags.go                      |  47 +++-
 common/size.go                          |  27 +-
 consensus/errors.go                     |   4 +
 core/bench_test.go                      |   4 +-
 core/block_validator.go                 |   9 +-
 core/block_validator_test.go            |   8 +-
 core/blockchain.go                      | 334 +++++++++++++++++-----
 core/blockchain_test.go                 | 145 ++++++++--
 core/chain_indexer.go                   |   3 +
 core/chain_makers.go                    |   9 +-
 core/chain_makers_test.go               |   2 +-
 core/dao_test.go                        |  24 +-
 core/genesis.go                         |  24 +-
 core/genesis_test.go                    |   6 +-
 core/state/database.go                  |  51 +++-
 core/state/iterator_test.go             |  17 +-
 core/state/state_object.go              |   5 +-
 core/state/state_test.go                |   6 +-
 core/state/statedb.go                   |  38 ++-
 core/state/statedb_test.go              |  14 +-
 core/state/sync_test.go                 |  44 +--
 core/tx_pool_test.go                    |   4 +-
 core/types/block.go                     |   9 +
 core/types/receipt.go                   |  13 +
 core/types/transaction.go               |   2 +
 eth/api.go                              |   4 +-
 eth/api_tracer.go                       | 135 ++-------
 eth/backend.go                          |   8 +-
 eth/config.go                           |   8 +-
 eth/downloader/downloader.go            | 317 +++++++++++----------
 eth/downloader/downloader_test.go       | 192 +++----------
 eth/downloader/queue.go                 | 169 +++++------
 eth/downloader/statesync.go             |  31 +--
 eth/handler.go                          |   6 +-
 eth/handler_test.go                     |  16 +-
 eth/helper_test.go                      |  14 +-
 eth/protocol_test.go                    |   6 +-
 eth/sync_test.go                        |   4 +-
 internal/ethapi/api.go                  |   2 +-
 les/handler.go                          | 193 +++++++------
 les/handler_test.go                     |   2 +-
 les/helper_test.go                      |   2 +-
 les/odr_test.go                         |   1 -
 light/lightchain.go                     |   7 +
 light/nodeset.go                        |   8 +-
 light/odr_test.go                       |   4 +-
 light/postprocess.go                    |  64 +++--
 light/trie.go                           |  18 +-
 light/trie_test.go                      |   2 +-
 light/txpool_test.go                    |   2 +-
 miner/worker.go                         |   2 +-
 tests/block_test_util.go                |   2 +-
 tests/state_test_util.go                |   6 +-
 trie/database.go                        | 355 ++++++++++++++++++++++++
 trie/hasher.go                          |  61 +++-
 trie/iterator_test.go                   | 125 ++++++---
 trie/proof.go                           |  47 ++--
 trie/secure_trie.go                     |  62 ++---
 trie/secure_trie_test.go                |  20 +-
 trie/sync.go                            |  14 +-
 trie/sync_test.go                       | 103 ++++---
 trie/trie.go                            |  90 +++---
 trie/trie_test.go                       | 104 +++----
 69 files changed, 1953 insertions(+), 1159 deletions(-)
 create mode 100644 trie/database.go

diff --git a/accounts/abi/bind/backends/simulated.go b/accounts/abi/bind/backends/simulated.go
index 1803d3f23..bd342a8cb 100644
--- a/accounts/abi/bind/backends/simulated.go
+++ b/accounts/abi/bind/backends/simulated.go
@@ -68,7 +68,7 @@ func NewSimulatedBackend(alloc core.GenesisAlloc) *SimulatedBackend {
 	database, _ := ethdb.NewMemDatabase()
 	genesis := core.Genesis{Config: params.AllEthashProtocolChanges, Alloc: alloc}
 	genesis.MustCommit(database)
-	blockchain, _ := core.NewBlockChain(database, genesis.Config, ethash.NewFaker(), vm.Config{})
+	blockchain, _ := core.NewBlockChain(database, nil, genesis.Config, ethash.NewFaker(), vm.Config{})
 
 	backend := &SimulatedBackend{
 		database:   database,
@@ -102,8 +102,10 @@ func (b *SimulatedBackend) Rollback() {
 
 func (b *SimulatedBackend) rollback() {
 	blocks, _ := core.GenerateChain(b.config, b.blockchain.CurrentBlock(), ethash.NewFaker(), b.database, 1, func(int, *core.BlockGen) {})
+	statedb, _ := b.blockchain.State()
+
 	b.pendingBlock = blocks[0]
-	b.pendingState, _ = state.New(b.pendingBlock.Root(), state.NewDatabase(b.database))
+	b.pendingState, _ = state.New(b.pendingBlock.Root(), statedb.Database())
 }
 
 // CodeAt returns the code associated with a certain account in the blockchain.
@@ -309,8 +311,10 @@ func (b *SimulatedBackend) SendTransaction(ctx context.Context, tx *types.Transa
 		}
 		block.AddTx(tx)
 	})
+	statedb, _ := b.blockchain.State()
+
 	b.pendingBlock = blocks[0]
-	b.pendingState, _ = state.New(b.pendingBlock.Root(), state.NewDatabase(b.database))
+	b.pendingState, _ = state.New(b.pendingBlock.Root(), statedb.Database())
 	return nil
 }
 
@@ -386,8 +390,10 @@ func (b *SimulatedBackend) AdjustTime(adjustment time.Duration) error {
 		}
 		block.OffsetTime(int64(adjustment.Seconds()))
 	})
+	statedb, _ := b.blockchain.State()
+
 	b.pendingBlock = blocks[0]
-	b.pendingState, _ = state.New(b.pendingBlock.Root(), state.NewDatabase(b.database))
+	b.pendingState, _ = state.New(b.pendingBlock.Root(), statedb.Database())
 
 	return nil
 }
diff --git a/cmd/evm/runner.go b/cmd/evm/runner.go
index 96de0c76a..a9a2e5420 100644
--- a/cmd/evm/runner.go
+++ b/cmd/evm/runner.go
@@ -96,7 +96,9 @@ func runCmd(ctx *cli.Context) error {
 	}
 	if ctx.GlobalString(GenesisFlag.Name) != "" {
 		gen := readGenesis(ctx.GlobalString(GenesisFlag.Name))
-		_, statedb = gen.ToBlock()
+		db, _ := ethdb.NewMemDatabase()
+		genesis := gen.ToBlock(db)
+		statedb, _ = state.New(genesis.Root(), state.NewDatabase(db))
 		chainConfig = gen.Config
 	} else {
 		db, _ := ethdb.NewMemDatabase()
diff --git a/cmd/geth/chaincmd.go b/cmd/geth/chaincmd.go
index 4a9a7b11b..35bf576e1 100644
--- a/cmd/geth/chaincmd.go
+++ b/cmd/geth/chaincmd.go
@@ -202,7 +202,7 @@ func importChain(ctx *cli.Context) error {
 
 	if len(ctx.Args()) == 1 {
 		if err := utils.ImportChain(chain, ctx.Args().First()); err != nil {
-			utils.Fatalf("Import error: %v", err)
+			log.Error("Import error", "err", err)
 		}
 	} else {
 		for _, arg := range ctx.Args() {
@@ -211,7 +211,7 @@ func importChain(ctx *cli.Context) error {
 			}
 		}
 	}
-
+	chain.Stop()
 	fmt.Printf("Import done in %v.\n\n", time.Since(start))
 
 	// Output pre-compaction stats mostly to see the import trashing
diff --git a/cmd/geth/main.go b/cmd/geth/main.go
index b955bd243..cb8d63bf7 100644
--- a/cmd/geth/main.go
+++ b/cmd/geth/main.go
@@ -85,10 +85,13 @@ var (
 		utils.FastSyncFlag,
 		utils.LightModeFlag,
 		utils.SyncModeFlag,
+		utils.GCModeFlag,
 		utils.LightServFlag,
 		utils.LightPeersFlag,
 		utils.LightKDFFlag,
 		utils.CacheFlag,
+		utils.CacheDatabaseFlag,
+		utils.CacheGCFlag,
 		utils.TrieCacheGenFlag,
 		utils.ListenPortFlag,
 		utils.MaxPeersFlag,
diff --git a/cmd/geth/usage.go b/cmd/geth/usage.go
index a834d5b7a..a2bcaff02 100644
--- a/cmd/geth/usage.go
+++ b/cmd/geth/usage.go
@@ -22,10 +22,11 @@ import (
 	"io"
 	"sort"
 
+	"strings"
+
 	"github.com/ethereum/go-ethereum/cmd/utils"
 	"github.com/ethereum/go-ethereum/internal/debug"
 	"gopkg.in/urfave/cli.v1"
-	"strings"
 )
 
 // AppHelpTemplate is the test template for the default, global app help topic.
@@ -74,6 +75,7 @@ var AppHelpFlagGroups = []flagGroup{
 			utils.TestnetFlag,
 			utils.RinkebyFlag,
 			utils.SyncModeFlag,
+			utils.GCModeFlag,
 			utils.EthStatsURLFlag,
 			utils.IdentityFlag,
 			utils.LightServFlag,
@@ -127,6 +129,8 @@ var AppHelpFlagGroups = []flagGroup{
 		Name: "PERFORMANCE TUNING",
 		Flags: []cli.Flag{
 			utils.CacheFlag,
+			utils.CacheDatabaseFlag,
+			utils.CacheGCFlag,
 			utils.TrieCacheGenFlag,
 		},
 	},
diff --git a/cmd/utils/cmd.go b/cmd/utils/cmd.go
index 23b10c2d7..53cdf7861 100644
--- a/cmd/utils/cmd.go
+++ b/cmd/utils/cmd.go
@@ -116,7 +116,6 @@ func ImportChain(chain *core.BlockChain, fn string) error {
 			return err
 		}
 	}
-
 	stream := rlp.NewStream(reader, 0)
 
 	// Run actual the import.
@@ -150,25 +149,34 @@ func ImportChain(chain *core.BlockChain, fn string) error {
 		if checkInterrupt() {
 			return fmt.Errorf("interrupted")
 		}
-		if hasAllBlocks(chain, blocks[:i]) {
+		missing := missingBlocks(chain, blocks[:i])
+		if len(missing) == 0 {
 			log.Info("Skipping batch as all blocks present", "batch", batch, "first", blocks[0].Hash(), "last", blocks[i-1].Hash())
 			continue
 		}
-
-		if _, err := chain.InsertChain(blocks[:i]); err != nil {
+		if _, err := chain.InsertChain(missing); err != nil {
 			return fmt.Errorf("invalid block %d: %v", n, err)
 		}
 	}
 	return nil
 }
 
-func hasAllBlocks(chain *core.BlockChain, bs []*types.Block) bool {
-	for _, b := range bs {
-		if !chain.HasBlock(b.Hash(), b.NumberU64()) {
-			return false
+func missingBlocks(chain *core.BlockChain, blocks []*types.Block) []*types.Block {
+	head := chain.CurrentBlock()
+	for i, block := range blocks {
+		// If we're behind the chain head, only check block, state is available at head
+		if head.NumberU64() > block.NumberU64() {
+			if !chain.HasBlock(block.Hash(), block.NumberU64()) {
+				return blocks[i:]
+			}
+			continue
+		}
+		// If we're above the chain head, state availability is a must
+		if !chain.HasBlockAndState(block.Hash(), block.NumberU64()) {
+			return blocks[i:]
 		}
 	}
-	return true
+	return nil
 }
 
 func ExportChain(blockchain *core.BlockChain, fn string) error {
diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go
index 833cd95de..2a2909ff2 100644
--- a/cmd/utils/flags.go
+++ b/cmd/utils/flags.go
@@ -170,7 +170,11 @@ var (
 		Usage: `Blockchain sync mode ("fast", "full", or "light")`,
 		Value: &defaultSyncMode,
 	}
-
+	GCModeFlag = cli.StringFlag{
+		Name:  "gcmode",
+		Usage: `Blockchain garbage collection mode ("full", "archive")`,
+		Value: "full",
+	}
 	LightServFlag = cli.IntFlag{
 		Name:  "lightserv",
 		Usage: "Maximum percentage of time allowed for serving LES requests (0-90)",
@@ -293,8 +297,18 @@ var (
 	// Performance tuning settings
 	CacheFlag = cli.IntFlag{
 		Name:  "cache",
-		Usage: "Megabytes of memory allocated to internal caching (min 16MB / database forced)",
-		Value: 128,
+		Usage: "Megabytes of memory allocated to internal caching",
+		Value: 1024,
+	}
+	CacheDatabaseFlag = cli.IntFlag{
+		Name:  "cache.database",
+		Usage: "Percentage of cache memory allowance to use for database io",
+		Value: 75,
+	}
+	CacheGCFlag = cli.IntFlag{
+		Name:  "cache.gc",
+		Usage: "Percentage of cache memory allowance to use for trie pruning",
+		Value: 25,
 	}
 	TrieCacheGenFlag = cli.IntFlag{
 		Name:  "trie-cache-gens",
@@ -1021,11 +1035,19 @@ func SetEthConfig(ctx *cli.Context, stack *node.Node, cfg *eth.Config) {
 		cfg.NetworkId = ctx.GlobalUint64(NetworkIdFlag.Name)
 	}
 
-	if ctx.GlobalIsSet(CacheFlag.Name) {
-		cfg.DatabaseCache = ctx.GlobalInt(CacheFlag.Name)
+	if ctx.GlobalIsSet(CacheFlag.Name) || ctx.GlobalIsSet(CacheDatabaseFlag.Name) {
+		cfg.DatabaseCache = ctx.GlobalInt(CacheFlag.Name) * ctx.GlobalInt(CacheDatabaseFlag.Name) / 100
 	}
 	cfg.DatabaseHandles = makeDatabaseHandles()
 
+	if gcmode := ctx.GlobalString(GCModeFlag.Name); gcmode != "full" && gcmode != "archive" {
+		Fatalf("--%s must be either 'full' or 'archive'", GCModeFlag.Name)
+	}
+	cfg.NoPruning = ctx.GlobalString(GCModeFlag.Name) == "archive"
+
+	if ctx.GlobalIsSet(CacheFlag.Name) || ctx.GlobalIsSet(CacheGCFlag.Name) {
+		cfg.TrieCache = ctx.GlobalInt(CacheFlag.Name) * ctx.GlobalInt(CacheGCFlag.Name) / 100
+	}
 	if ctx.GlobalIsSet(MinerThreadsFlag.Name) {
 		cfg.MinerThreads = ctx.GlobalInt(MinerThreadsFlag.Name)
 	}
@@ -1157,7 +1179,7 @@ func SetupNetwork(ctx *cli.Context) {
 // MakeChainDatabase open an LevelDB using the flags passed to the client and will hard crash if it fails.
 func MakeChainDatabase(ctx *cli.Context, stack *node.Node) ethdb.Database {
 	var (
-		cache   = ctx.GlobalInt(CacheFlag.Name)
+		cache   = ctx.GlobalInt(CacheFlag.Name) * ctx.GlobalInt(CacheDatabaseFlag.Name) / 100
 		handles = makeDatabaseHandles()
 	)
 	name := "chaindata"
@@ -1209,8 +1231,19 @@ func MakeChain(ctx *cli.Context, stack *node.Node) (chain *core.BlockChain, chai
 			})
 		}
 	}
+	if gcmode := ctx.GlobalString(GCModeFlag.Name); gcmode != "full" && gcmode != "archive" {
+		Fatalf("--%s must be either 'full' or 'archive'", GCModeFlag.Name)
+	}
+	cache := &core.CacheConfig{
+		Disabled:      ctx.GlobalString(GCModeFlag.Name) == "archive",
+		TrieNodeLimit: eth.DefaultConfig.TrieCache,
+		TrieTimeLimit: eth.DefaultConfig.TrieTimeout,
+	}
+	if ctx.GlobalIsSet(CacheFlag.Name) || ctx.GlobalIsSet(CacheGCFlag.Name) {
+		cache.TrieNodeLimit = ctx.GlobalInt(CacheFlag.Name) * ctx.GlobalInt(CacheGCFlag.Name) / 100
+	}
 	vmcfg := vm.Config{EnablePreimageRecording: ctx.GlobalBool(VMEnableDebugFlag.Name)}
-	chain, err = core.NewBlockChain(chainDb, config, engine, vmcfg)
+	chain, err = core.NewBlockChain(chainDb, cache, config, engine, vmcfg)
 	if err != nil {
 		Fatalf("Can't create BlockChain: %v", err)
 	}
diff --git a/common/size.go b/common/size.go
index c5a0cb0f2..bd0fc85c7 100644
--- a/common/size.go
+++ b/common/size.go
@@ -20,18 +20,29 @@ import (
 	"fmt"
 )
 
+// StorageSize is a wrapper around a float value that supports user friendly
+// formatting.
 type StorageSize float64
 
-func (self StorageSize) String() string {
-	if self > 1000000 {
-		return fmt.Sprintf("%.2f mB", self/1000000)
-	} else if self > 1000 {
-		return fmt.Sprintf("%.2f kB", self/1000)
+// String implements the stringer interface.
+func (s StorageSize) String() string {
+	if s > 1000000 {
+		return fmt.Sprintf("%.2f mB", s/1000000)
+	} else if s > 1000 {
+		return fmt.Sprintf("%.2f kB", s/1000)
 	} else {
-		return fmt.Sprintf("%.2f B", self)
+		return fmt.Sprintf("%.2f B", s)
 	}
 }
 
-func (self StorageSize) Int64() int64 {
-	return int64(self)
+// TerminalString implements log.TerminalStringer, formatting a string for console
+// output during logging.
+func (s StorageSize) TerminalString() string {
+	if s > 1000000 {
+		return fmt.Sprintf("%.2fmB", s/1000000)
+	} else if s > 1000 {
+		return fmt.Sprintf("%.2fkB", s/1000)
+	} else {
+		return fmt.Sprintf("%.2fB", s)
+	}
 }
diff --git a/consensus/errors.go b/consensus/errors.go
index 3b136dbdd..a005c5f63 100644
--- a/consensus/errors.go
+++ b/consensus/errors.go
@@ -23,6 +23,10 @@ var (
 	// that is unknown.
 	ErrUnknownAncestor = errors.New("unknown ancestor")
 
+	// ErrPrunedAncestor is returned when validating a block requires an ancestor
+	// that is known, but the state of which is not available.
+	ErrPrunedAncestor = errors.New("pruned ancestor")
+
 	// ErrFutureBlock is returned when a block's timestamp is in the future according
 	// to the current node.
 	ErrFutureBlock = errors.New("block in the future")
diff --git a/core/bench_test.go b/core/bench_test.go
index f976331d1..e23f0d19d 100644
--- a/core/bench_test.go
+++ b/core/bench_test.go
@@ -173,7 +173,7 @@ func benchInsertChain(b *testing.B, disk bool, gen func(int, *BlockGen)) {
 
 	// Time the insertion of the new chain.
 	// State and blocks are stored in the same DB.
-	chainman, _ := NewBlockChain(db, gspec.Config, ethash.NewFaker(), vm.Config{})
+	chainman, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
 	defer chainman.Stop()
 	b.ReportAllocs()
 	b.ResetTimer()
@@ -283,7 +283,7 @@ func benchReadChain(b *testing.B, full bool, count uint64) {
 		if err != nil {
 			b.Fatalf("error opening database at %v: %v", dir, err)
 		}
-		chain, err := NewBlockChain(db, params.TestChainConfig, ethash.NewFaker(), vm.Config{})
+		chain, err := NewBlockChain(db, nil, params.TestChainConfig, ethash.NewFaker(), vm.Config{})
 		if err != nil {
 			b.Fatalf("error creating chain: %v", err)
 		}
diff --git a/core/block_validator.go b/core/block_validator.go
index 143728bb8..98958809b 100644
--- a/core/block_validator.go
+++ b/core/block_validator.go
@@ -50,11 +50,14 @@ func NewBlockValidator(config *params.ChainConfig, blockchain *BlockChain, engin
 // validated at this point.
 func (v *BlockValidator) ValidateBody(block *types.Block) error {
 	// Check whether the block's known, and if not, that it's linkable
-	if v.bc.HasBlockAndState(block.Hash()) {
+	if v.bc.HasBlockAndState(block.Hash(), block.NumberU64()) {
 		return ErrKnownBlock
 	}
-	if !v.bc.HasBlockAndState(block.ParentHash()) {
-		return consensus.ErrUnknownAncestor
+	if !v.bc.HasBlockAndState(block.ParentHash(), block.NumberU64()-1) {
+		if !v.bc.HasBlock(block.ParentHash(), block.NumberU64()-1) {
+			return consensus.ErrUnknownAncestor
+		}
+		return consensus.ErrPrunedAncestor
 	}
 	// Header validity is known at this point, check the uncles and transactions
 	header := block.Header()
diff --git a/core/block_validator_test.go b/core/block_validator_test.go
index e668601f3..e334b3c3c 100644
--- a/core/block_validator_test.go
+++ b/core/block_validator_test.go
@@ -42,7 +42,7 @@ func TestHeaderVerification(t *testing.T) {
 		headers[i] = block.Header()
 	}
 	// Run the header checker for blocks one-by-one, checking for both valid and invalid nonces
-	chain, _ := NewBlockChain(testdb, params.TestChainConfig, ethash.NewFaker(), vm.Config{})
+	chain, _ := NewBlockChain(testdb, nil, params.TestChainConfig, ethash.NewFaker(), vm.Config{})
 	defer chain.Stop()
 
 	for i := 0; i < len(blocks); i++ {
@@ -106,11 +106,11 @@ func testHeaderConcurrentVerification(t *testing.T, threads int) {
 		var results <-chan error
 
 		if valid {
-			chain, _ := NewBlockChain(testdb, params.TestChainConfig, ethash.NewFaker(), vm.Config{})
+			chain, _ := NewBlockChain(testdb, nil, params.TestChainConfig, ethash.NewFaker(), vm.Config{})
 			_, results = chain.engine.VerifyHeaders(chain, headers, seals)
 			chain.Stop()
 		} else {
-			chain, _ := NewBlockChain(testdb, params.TestChainConfig, ethash.NewFakeFailer(uint64(len(headers)-1)), vm.Config{})
+			chain, _ := NewBlockChain(testdb, nil, params.TestChainConfig, ethash.NewFakeFailer(uint64(len(headers)-1)), vm.Config{})
 			_, results = chain.engine.VerifyHeaders(chain, headers, seals)
 			chain.Stop()
 		}
@@ -173,7 +173,7 @@ func testHeaderConcurrentAbortion(t *testing.T, threads int) {
 	defer runtime.GOMAXPROCS(old)
 
 	// Start the verifications and immediately abort
-	chain, _ := NewBlockChain(testdb, params.TestChainConfig, ethash.NewFakeDelayer(time.Millisecond), vm.Config{})
+	chain, _ := NewBlockChain(testdb, nil, params.TestChainConfig, ethash.NewFakeDelayer(time.Millisecond), vm.Config{})
 	defer chain.Stop()
 
 	abort, results := chain.engine.VerifyHeaders(chain, headers, seals)
diff --git a/core/blockchain.go b/core/blockchain.go
index d5e139e31..8d141fddb 100644
--- a/core/blockchain.go
+++ b/core/blockchain.go
@@ -42,6 +42,7 @@ import (
 	"github.com/ethereum/go-ethereum/rlp"
 	"github.com/ethereum/go-ethereum/trie"
 	"github.com/hashicorp/golang-lru"
+	"gopkg.in/karalabe/cookiejar.v2/collections/prque"
 )
 
 var (
@@ -56,11 +57,20 @@ const (
 	maxFutureBlocks     = 256
 	maxTimeFutureBlocks = 30
 	badBlockLimit       = 10
+	triesInMemory       = 128
 
 	// BlockChainVersion ensures that an incompatible database forces a resync from scratch.
 	BlockChainVersion = 3
 )
 
+// CacheConfig contains the configuration values for the trie caching/pruning
+// that's resident in a blockchain.
+type CacheConfig struct {
+	Disabled      bool          // Whether to disable trie write caching (archive node)
+	TrieNodeLimit int           // Memory limit (MB) at which to flush the current in-memory trie to disk
+	TrieTimeLimit time.Duration // Time limit after which to flush the current in-memory trie to disk
+}
+
 // BlockChain represents the canonical chain given a database with a genesis
 // block. The Blockchain manages chain imports, reverts, chain reorganisations.
 //
@@ -76,10 +86,14 @@ const (
 // included in the canonical one where as GetBlockByNumber always represents the
 // canonical chain.
 type BlockChain struct {
-	config *params.ChainConfig // chain & network configuration
+	chainConfig *params.ChainConfig // Chain & network configuration
+	cacheConfig *CacheConfig        // Cache configuration for pruning
+
+	db     ethdb.Database // Low level persistent database to store final content in
+	triegc *prque.Prque   // Priority queue mapping block numbers to tries to gc
+	gcproc time.Duration  // Accumulates canonical block processing for trie dumping
 
 	hc            *HeaderChain
-	chainDb       ethdb.Database
 	rmLogsFeed    event.Feed
 	chainFeed     event.Feed
 	chainSideFeed event.Feed
@@ -119,7 +133,13 @@ type BlockChain struct {
 // NewBlockChain returns a fully initialised block chain using information
 // available in the database. It initialises the default Ethereum Validator and
 // Processor.
-func NewBlockChain(chainDb ethdb.Database, config *params.ChainConfig, engine consensus.Engine, vmConfig vm.Config) (*BlockChain, error) {
+func NewBlockChain(db ethdb.Database, cacheConfig *CacheConfig, chainConfig *params.ChainConfig, engine consensus.Engine, vmConfig vm.Config) (*BlockChain, error) {
+	if cacheConfig == nil {
+		cacheConfig = &CacheConfig{
+			TrieNodeLimit: 256 * 1024 * 1024,
+			TrieTimeLimit: 5 * time.Minute,
+		}
+	}
 	bodyCache, _ := lru.New(bodyCacheLimit)
 	bodyRLPCache, _ := lru.New(bodyCacheLimit)
 	blockCache, _ := lru.New(blockCacheLimit)
@@ -127,9 +147,11 @@ func NewBlockChain(chainDb ethdb.Database, config *params.ChainConfig, engine co
 	badBlocks, _ := lru.New(badBlockLimit)
 
 	bc := &BlockChain{
-		config:       config,
-		chainDb:      chainDb,
-		stateCache:   state.NewDatabase(chainDb),
+		chainConfig:  chainConfig,
+		cacheConfig:  cacheConfig,
+		db:           db,
+		triegc:       prque.New(),
+		stateCache:   state.NewDatabase(db),
 		quit:         make(chan struct{}),
 		bodyCache:    bodyCache,
 		bodyRLPCache: bodyRLPCache,
@@ -139,11 +161,11 @@ func NewBlockChain(chainDb ethdb.Database, config *params.ChainConfig, engine co
 		vmConfig:     vmConfig,
 		badBlocks:    badBlocks,
 	}
-	bc.SetValidator(NewBlockValidator(config, bc, engine))
-	bc.SetProcessor(NewStateProcessor(config, bc, engine))
+	bc.SetValidator(NewBlockValidator(chainConfig, bc, engine))
+	bc.SetProcessor(NewStateProcessor(chainConfig, bc, engine))
 
 	var err error
-	bc.hc, err = NewHeaderChain(chainDb, config, engine, bc.getProcInterrupt)
+	bc.hc, err = NewHeaderChain(db, chainConfig, engine, bc.getProcInterrupt)
 	if err != nil {
 		return nil, err
 	}
@@ -180,7 +202,7 @@ func (bc *BlockChain) getProcInterrupt() bool {
 // assumes that the chain manager mutex is held.
 func (bc *BlockChain) loadLastState() error {
 	// Restore the last known head block
-	head := GetHeadBlockHash(bc.chainDb)
+	head := GetHeadBlockHash(bc.db)
 	if head == (common.Hash{}) {
 		// Corrupt or empty database, init from scratch
 		log.Warn("Empty database, resetting chain")
@@ -196,15 +218,17 @@ func (bc *BlockChain) loadLastState() error {
 	// Make sure the state associated with the block is available
 	if _, err := state.New(currentBlock.Root(), bc.stateCache); err != nil {
 		// Dangling block without a state associated, init from scratch
-		log.Warn("Head state missing, resetting chain", "number", currentBlock.Number(), "hash", currentBlock.Hash())
-		return bc.Reset()
+		log.Warn("Head state missing, repairing chain", "number", currentBlock.Number(), "hash", currentBlock.Hash())
+		if err := bc.repair(&currentBlock); err != nil {
+			return err
+		}
 	}
 	// Everything seems to be fine, set as the head block
 	bc.currentBlock = currentBlock
 
 	// Restore the last known head header
 	currentHeader := bc.currentBlock.Header()
-	if head := GetHeadHeaderHash(bc.chainDb); head != (common.Hash{}) {
+	if head := GetHeadHeaderHash(bc.db); head != (common.Hash{}) {
 		if header := bc.GetHeaderByHash(head); header != nil {
 			currentHeader = header
 		}
@@ -213,7 +237,7 @@ func (bc *BlockChain) loadLastState() error {
 
 	// Restore the last known head fast block
 	bc.currentFastBlock = bc.currentBlock
-	if head := GetHeadFastBlockHash(bc.chainDb); head != (common.Hash{}) {
+	if head := GetHeadFastBlockHash(bc.db); head != (common.Hash{}) {
 		if block := bc.GetBlockByHash(head); block != nil {
 			bc.currentFastBlock = block
 		}
@@ -243,7 +267,7 @@ func (bc *BlockChain) SetHead(head uint64) error {
 
 	// Rewind the header chain, deleting all block bodies until then
 	delFn := func(hash common.Hash, num uint64) {
-		DeleteBody(bc.chainDb, hash, num)
+		DeleteBody(bc.db, hash, num)
 	}
 	bc.hc.SetHead(head, delFn)
 	currentHeader := bc.hc.CurrentHeader()
@@ -275,10 +299,10 @@ func (bc *BlockChain) SetHead(head uint64) error {
 	if bc.currentFastBlock == nil {
 		bc.currentFastBlock = bc.genesisBlock
 	}
-	if err := WriteHeadBlockHash(bc.chainDb, bc.currentBlock.Hash()); err != nil {
+	if err := WriteHeadBlockHash(bc.db, bc.currentBlock.Hash()); err != nil {
 		log.Crit("Failed to reset head full block", "err", err)
 	}
-	if err := WriteHeadFastBlockHash(bc.chainDb, bc.currentFastBlock.Hash()); err != nil {
+	if err := WriteHeadFastBlockHash(bc.db, bc.currentFastBlock.Hash()); err != nil {
 		log.Crit("Failed to reset head fast block", "err", err)
 	}
 	return bc.loadLastState()
@@ -292,7 +316,7 @@ func (bc *BlockChain) FastSyncCommitHead(hash common.Hash) error {
 	if block == nil {
 		return fmt.Errorf("non existent block [%x…]", hash[:4])
 	}
-	if _, err := trie.NewSecure(block.Root(), bc.chainDb, 0); err != nil {
+	if _, err := trie.NewSecure(block.Root(), bc.stateCache.TrieDB(), 0); err != nil {
 		return err
 	}
 	// If all checks out, manually set the head block
@@ -387,7 +411,7 @@ func (bc *BlockChain) ResetWithGenesisBlock(genesis *types.Block) error {
 	if err := bc.hc.WriteTd(genesis.Hash(), genesis.NumberU64(), genesis.Difficulty()); err != nil {
 		log.Crit("Failed to write genesis block TD", "err", err)
 	}
-	if err := WriteBlock(bc.chainDb, genesis); err != nil {
+	if err := WriteBlock(bc.db, genesis); err != nil {
 		log.Crit("Failed to write genesis block", "err", err)
 	}
 	bc.genesisBlock = genesis
@@ -400,6 +424,24 @@ func (bc *BlockChain) ResetWithGenesisBlock(genesis *types.Block) error {
 	return nil
 }
 
+// repair tries to repair the current blockchain by rolling back the current block
+// until one with associated state is found. This is needed to fix incomplete db
+// writes caused either by crashes/power outages, or simply non-committed tries.
+//
+// This method only rolls back the current block. The current header and current
+// fast block are left intact.
+func (bc *BlockChain) repair(head **types.Block) error {
+	for {
+		// Abort if we've rewound to a head block that does have associated state
+		if _, err := state.New((*head).Root(), bc.stateCache); err == nil {
+			log.Info("Rewound blockchain to past state", "number", (*head).Number(), "hash", (*head).Hash())
+			return nil
+		}
+		// Otherwise rewind one block and recheck state availability there
+		(*head) = bc.GetBlock((*head).ParentHash(), (*head).NumberU64()-1)
+	}
+}
+
 // Export writes the active chain to the given writer.
 func (bc *BlockChain) Export(w io.Writer) error {
 	return bc.ExportN(w, uint64(0), bc.currentBlock.NumberU64())
@@ -437,13 +479,13 @@ func (bc *BlockChain) ExportN(w io.Writer, first uint64, last uint64) error {
 // Note, this function assumes that the `mu` mutex is held!
 func (bc *BlockChain) insert(block *types.Block) {
 	// If the block is on a side chain or an unknown one, force other heads onto it too
-	updateHeads := GetCanonicalHash(bc.chainDb, block.NumberU64()) != block.Hash()
+	updateHeads := GetCanonicalHash(bc.db, block.NumberU64()) != block.Hash()
 
 	// Add the block to the canonical chain number scheme and mark as the head
-	if err := WriteCanonicalHash(bc.chainDb, block.Hash(), block.NumberU64()); err != nil {
+	if err := WriteCanonicalHash(bc.db, block.Hash(), block.NumberU64()); err != nil {
 		log.Crit("Failed to insert block number", "err", err)
 	}
-	if err := WriteHeadBlockHash(bc.chainDb, block.Hash()); err != nil {
+	if err := WriteHeadBlockHash(bc.db, block.Hash()); err != nil {
 		log.Crit("Failed to insert head block hash", "err", err)
 	}
 	bc.currentBlock = block
@@ -452,7 +494,7 @@ func (bc *BlockChain) insert(block *types.Block) {
 	if updateHeads {
 		bc.hc.SetCurrentHeader(block.Header())
 
-		if err := WriteHeadFastBlockHash(bc.chainDb, block.Hash()); err != nil {
+		if err := WriteHeadFastBlockHash(bc.db, block.Hash()); err != nil {
 			log.Crit("Failed to insert head fast block hash", "err", err)
 		}
 		bc.currentFastBlock = block
@@ -472,7 +514,7 @@ func (bc *BlockChain) GetBody(hash common.Hash) *types.Body {
 		body := cached.(*types.Body)
 		return body
 	}
-	body := GetBody(bc.chainDb, hash, bc.hc.GetBlockNumber(hash))
+	body := GetBody(bc.db, hash, bc.hc.GetBlockNumber(hash))
 	if body == nil {
 		return nil
 	}
@@ -488,7 +530,7 @@ func (bc *BlockChain) GetBodyRLP(hash common.Hash) rlp.RawValue {
 	if cached, ok := bc.bodyRLPCache.Get(hash); ok {
 		return cached.(rlp.RawValue)
 	}
-	body := GetBodyRLP(bc.chainDb, hash, bc.hc.GetBlockNumber(hash))
+	body := GetBodyRLP(bc.db, hash, bc.hc.GetBlockNumber(hash))
 	if len(body) == 0 {
 		return nil
 	}
@@ -502,21 +544,25 @@ func (bc *BlockChain) HasBlock(hash common.Hash, number uint64) bool {
 	if bc.blockCache.Contains(hash) {
 		return true
 	}
-	ok, _ := bc.chainDb.Has(blockBodyKey(hash, number))
+	ok, _ := bc.db.Has(blockBodyKey(hash, number))
 	return ok
 }
 
+// HasState checks if state trie is fully present in the database or not.
+func (bc *BlockChain) HasState(hash common.Hash) bool {
+	_, err := bc.stateCache.OpenTrie(hash)
+	return err == nil
+}
+
 // HasBlockAndState checks if a block and associated state trie is fully present
 // in the database or not, caching it if present.
-func (bc *BlockChain) HasBlockAndState(hash common.Hash) bool {
+func (bc *BlockChain) HasBlockAndState(hash common.Hash, number uint64) bool {
 	// Check first that the block itself is known
-	block := bc.GetBlockByHash(hash)
+	block := bc.GetBlock(hash, number)
 	if block == nil {
 		return false
 	}
-	// Ensure the associated state is also present
-	_, err := bc.stateCache.OpenTrie(block.Root())
-	return err == nil
+	return bc.HasState(block.Root())
 }
 
 // GetBlock retrieves a block from the database by hash and number,
@@ -526,7 +572,7 @@ func (bc *BlockChain) GetBlock(hash common.Hash, number uint64) *types.Block {
 	if block, ok := bc.blockCache.Get(hash); ok {
 		return block.(*types.Block)
 	}
-	block := GetBlock(bc.chainDb, hash, number)
+	block := GetBlock(bc.db, hash, number)
 	if block == nil {
 		return nil
 	}
@@ -543,13 +589,18 @@ func (bc *BlockChain) GetBlockByHash(hash common.Hash) *types.Block {
 // GetBlockByNumber retrieves a block from the database by number, caching it
 // (associated with its hash) if found.
 func (bc *BlockChain) GetBlockByNumber(number uint64) *types.Block {
-	hash := GetCanonicalHash(bc.chainDb, number)
+	hash := GetCanonicalHash(bc.db, number)
 	if hash == (common.Hash{}) {
 		return nil
 	}
 	return bc.GetBlock(hash, number)
 }
 
+// GetReceiptsByHash retrieves the receipts for all transactions in a given block.
+func (bc *BlockChain) GetReceiptsByHash(hash common.Hash) types.Receipts {
+	return GetBlockReceipts(bc.db, hash, GetBlockNumber(bc.db, hash))
+}
+
 // GetBlocksFromHash returns the block corresponding to hash and up to n-1 ancestors.
 // [deprecated by eth/62]
 func (bc *BlockChain) GetBlocksFromHash(hash common.Hash, n int) (blocks []*types.Block) {
@@ -577,6 +628,12 @@ func (bc *BlockChain) GetUnclesInChain(block *types.Block, length int) []*types.
 	return uncles
 }
 
+// TrieNode retrieves a blob of data associated with a trie node (or code hash)
+// either from ephemeral in-memory cache, or from persistent storage.
+func (bc *BlockChain) TrieNode(hash common.Hash) ([]byte, error) {
+	return bc.stateCache.TrieDB().Node(hash)
+}
+
 // Stop stops the blockchain service. If any imports are currently in progress
 // it will abort them using the procInterrupt.
 func (bc *BlockChain) Stop() {
@@ -589,6 +646,33 @@ func (bc *BlockChain) Stop() {
 	atomic.StoreInt32(&bc.procInterrupt, 1)
 
 	bc.wg.Wait()
+
+	// Ensure the state of a recent block is also stored to disk before exiting.
+	// It is fine if this state does not exist (fast start/stop cycle), but it is
+	// advisable to leave an N block gap from the head so 1) a restart loads up
+	// the last N blocks as sync assistance to remote nodes; 2) a restart during
+	// a (small) reorg doesn't require deep reprocesses; 3) chain "repair" from
+	// missing states are constantly tested.
+	//
+	// This may be tuned a bit on mainnet if its too annoying to reprocess the last
+	// N blocks.
+	if !bc.cacheConfig.Disabled {
+		triedb := bc.stateCache.TrieDB()
+		if number := bc.CurrentBlock().NumberU64(); number >= triesInMemory {
+			recent := bc.GetBlockByNumber(bc.CurrentBlock().NumberU64() - triesInMemory + 1)
+
+			log.Info("Writing cached state to disk", "block", recent.Number(), "hash", recent.Hash(), "root", recent.Root())
+			if err := triedb.Commit(recent.Root(), true); err != nil {
+				log.Error("Failed to commit recent state trie", "err", err)
+			}
+		}
+		for !bc.triegc.Empty() {
+			triedb.Dereference(bc.triegc.PopItem().(common.Hash), common.Hash{})
+		}
+		if size := triedb.Size(); size != 0 {
+			log.Error("Dangling trie nodes after full cleanup")
+		}
+	}
 	log.Info("Blockchain manager stopped")
 }
 
@@ -633,11 +717,11 @@ func (bc *BlockChain) Rollback(chain []common.Hash) {
 		}
 		if bc.currentFastBlock.Hash() == hash {
 			bc.currentFastBlock = bc.GetBlock(bc.currentFastBlock.ParentHash(), bc.currentFastBlock.NumberU64()-1)
-			WriteHeadFastBlockHash(bc.chainDb, bc.currentFastBlock.Hash())
+			WriteHeadFastBlockHash(bc.db, bc.currentFastBlock.Hash())
 		}
 		if bc.currentBlock.Hash() == hash {
 			bc.currentBlock = bc.GetBlock(bc.currentBlock.ParentHash(), bc.currentBlock.NumberU64()-1)
-			WriteHeadBlockHash(bc.chainDb, bc.currentBlock.Hash())
+			WriteHeadBlockHash(bc.db, bc.currentBlock.Hash())
 		}
 	}
 }
@@ -696,7 +780,7 @@ func (bc *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain [
 		stats = struct{ processed, ignored int32 }{}
 		start = time.Now()
 		bytes = 0
-		batch = bc.chainDb.NewBatch()
+		batch = bc.db.NewBatch()
 	)
 	for i, block := range blockChain {
 		receipts := receiptChain[i]
@@ -714,7 +798,7 @@ func (bc *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain [
 			continue
 		}
 		// Compute all the non-consensus fields of the receipts
-		SetReceiptsData(bc.config, block, receipts)
+		SetReceiptsData(bc.chainConfig, block, receipts)
 		// Write all the data out into the database
 		if err := WriteBody(batch, block.Hash(), block.NumberU64(), block.Body()); err != nil {
 			return i, fmt.Errorf("failed to write block body: %v", err)
@@ -747,7 +831,7 @@ func (bc *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain [
 	head := blockChain[len(blockChain)-1]
 	if td := bc.GetTd(head.Hash(), head.NumberU64()); td != nil { // Rewind may have occurred, skip in that case
 		if bc.GetTd(bc.currentFastBlock.Hash(), bc.currentFastBlock.NumberU64()).Cmp(td) < 0 {
-			if err := WriteHeadFastBlockHash(bc.chainDb, head.Hash()); err != nil {
+			if err := WriteHeadFastBlockHash(bc.db, head.Hash()); err != nil {
 				log.Crit("Failed to update head fast block hash", "err", err)
 			}
 			bc.currentFastBlock = head
@@ -758,15 +842,33 @@ func (bc *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain [
 	log.Info("Imported new block receipts",
 		"count", stats.processed,
 		"elapsed", common.PrettyDuration(time.Since(start)),
-		"bytes", bytes,
 		"number", head.Number(),
 		"hash", head.Hash(),
+		"size", common.StorageSize(bytes),
 		"ignored", stats.ignored)
 	return 0, nil
 }
 
-// WriteBlock writes the block to the chain.
-func (bc *BlockChain) WriteBlockAndState(block *types.Block, receipts []*types.Receipt, state *state.StateDB) (status WriteStatus, err error) {
+var lastWrite uint64
+
+// WriteBlockWithoutState writes only the block and its metadata to the database,
+// but does not write any state. This is used to construct competing side forks
+// up to the point where they exceed the canonical total difficulty.
+func (bc *BlockChain) WriteBlockWithoutState(block *types.Block, td *big.Int) (err error) {
+	bc.wg.Add(1)
+	defer bc.wg.Done()
+
+	if err := bc.hc.WriteTd(block.Hash(), block.NumberU64(), td); err != nil {
+		return err
+	}
+	if err := WriteBlock(bc.db, block); err != nil {
+		return err
+	}
+	return nil
+}
+
+// WriteBlockWithState writes the block and all associated state to the database.
+func (bc *BlockChain) WriteBlockWithState(block *types.Block, receipts []*types.Receipt, state *state.StateDB) (status WriteStatus, err error) {
 	bc.wg.Add(1)
 	defer bc.wg.Done()
 
@@ -787,17 +889,73 @@ func (bc *BlockChain) WriteBlockAndState(block *types.Block, receipts []*types.R
 		return NonStatTy, err
 	}
 	// Write other block data using a batch.
-	batch := bc.chainDb.NewBatch()
+	batch := bc.db.NewBatch()
 	if err := WriteBlock(batch, block); err != nil {
 		return NonStatTy, err
 	}
-	if _, err := state.CommitTo(batch, bc.config.IsEIP158(block.Number())); err != nil {
+	root, err := state.Commit(bc.chainConfig.IsEIP158(block.Number()))
+	if err != nil {
 		return NonStatTy, err
 	}
+	triedb := bc.stateCache.TrieDB()
+
+	// If we're running an archive node, always flush
+	if bc.cacheConfig.Disabled {
+		if err := triedb.Commit(root, false); err != nil {
+			return NonStatTy, err
+		}
+	} else {
+		// Full but not archive node, do proper garbage collection
+		triedb.Reference(root, common.Hash{}) // metadata reference to keep trie alive
+		bc.triegc.Push(root, -float32(block.NumberU64()))
+
+		if current := block.NumberU64(); current > triesInMemory {
+			// Find the next state trie we need to commit
+			header := bc.GetHeaderByNumber(current - triesInMemory)
+			chosen := header.Number.Uint64()
+
+			// Only write to disk if we exceeded our memory allowance *and* also have at
+			// least a given number of tries gapped.
+			var (
+				size  = triedb.Size()
+				limit = common.StorageSize(bc.cacheConfig.TrieNodeLimit) * 1024 * 1024
+			)
+			if size > limit || bc.gcproc > bc.cacheConfig.TrieTimeLimit {
+				// If we're exceeding limits but haven't reached a large enough memory gap,
+				// warn the user that the system is becoming unstable.
+				if chosen < lastWrite+triesInMemory {
+					switch {
+					case size >= 2*limit:
+						log.Error("Trie memory critical, forcing to disk", "size", size, "limit", limit, "optimum", float64(chosen-lastWrite)/triesInMemory)
+					case bc.gcproc >= 2*bc.cacheConfig.TrieTimeLimit:
+						log.Error("Trie timing critical, forcing to disk", "time", bc.gcproc, "allowance", bc.cacheConfig.TrieTimeLimit, "optimum", float64(chosen-lastWrite)/triesInMemory)
+					case size > limit:
+						log.Warn("Trie memory at dangerous levels", "size", size, "limit", limit, "optimum", float64(chosen-lastWrite)/triesInMemory)
+					case bc.gcproc > bc.cacheConfig.TrieTimeLimit:
+						log.Warn("Trie timing at dangerous levels", "time", bc.gcproc, "limit", bc.cacheConfig.TrieTimeLimit, "optimum", float64(chosen-lastWrite)/triesInMemory)
+					}
+				}
+				// If optimum or critical limits reached, write to disk
+				if chosen >= lastWrite+triesInMemory || size >= 2*limit || bc.gcproc >= 2*bc.cacheConfig.TrieTimeLimit {
+					triedb.Commit(header.Root, true)
+					lastWrite = chosen
+					bc.gcproc = 0
+				}
+			}
+			// Garbage collect anything below our required write retention
+			for !bc.triegc.Empty() {
+				root, number := bc.triegc.Pop()
+				if uint64(-number) > chosen {
+					bc.triegc.Push(root, number)
+					break
+				}
+				triedb.Dereference(root.(common.Hash), common.Hash{})
+			}
+		}
+	}
 	if err := WriteBlockReceipts(batch, block.Hash(), block.NumberU64(), receipts); err != nil {
 		return NonStatTy, err
 	}
-
 	// If the total difficulty is higher than our known, add it to the canonical chain
 	// Second clause in the if statement reduces the vulnerability to selfish mining.
 	// Please refer to http://www.cs.cornell.edu/~ie53/publications/btcProcFC.pdf
@@ -818,7 +976,7 @@ func (bc *BlockChain) WriteBlockAndState(block *types.Block, receipts []*types.R
 			return NonStatTy, err
 		}
 		// Write hash preimages
-		if err := WritePreimages(bc.chainDb, block.NumberU64(), state.Preimages()); err != nil {
+		if err := WritePreimages(bc.db, block.NumberU64(), state.Preimages()); err != nil {
 			return NonStatTy, err
 		}
 		status = CanonStatTy
@@ -910,31 +1068,60 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty
 		if err == nil {
 			err = bc.Validator().ValidateBody(block)
 		}
-		if err != nil {
-			if err == ErrKnownBlock {
-				stats.ignored++
-				continue
+		switch {
+		case err == ErrKnownBlock:
+			stats.ignored++
+			continue
+
+		case err == consensus.ErrFutureBlock:
+			// Allow up to MaxFuture second in the future blocks. If this limit is exceeded
+			// the chain is discarded and processed at a later time if given.
+			max := big.NewInt(time.Now().Unix() + maxTimeFutureBlocks)
+			if block.Time().Cmp(max) > 0 {
+				return i, events, coalescedLogs, fmt.Errorf("future block: %v > %v", block.Time(), max)
 			}
+			bc.futureBlocks.Add(block.Hash(), block)
+			stats.queued++
+			continue
+
+		case err == consensus.ErrUnknownAncestor && bc.futureBlocks.Contains(block.ParentHash()):
+			bc.futureBlocks.Add(block.Hash(), block)
+			stats.queued++
+			continue
 
-			if err == consensus.ErrFutureBlock {
-				// Allow up to MaxFuture second in the future blocks. If this limit
-				// is exceeded the chain is discarded and processed at a later time
-				// if given.
-				max := big.NewInt(time.Now().Unix() + maxTimeFutureBlocks)
-				if block.Time().Cmp(max) > 0 {
-					return i, events, coalescedLogs, fmt.Errorf("future block: %v > %v", block.Time(), max)
+		case err == consensus.ErrPrunedAncestor:
+			// Block competing with the canonical chain, store in the db, but don't process
+			// until the competitor TD goes above the canonical TD
+			localTd := bc.GetTd(bc.currentBlock.Hash(), bc.currentBlock.NumberU64())
+			externTd := new(big.Int).Add(bc.GetTd(block.ParentHash(), block.NumberU64()-1), block.Difficulty())
+			if localTd.Cmp(externTd) > 0 {
+				if err = bc.WriteBlockWithoutState(block, externTd); err != nil {
+					return i, events, coalescedLogs, err
 				}
-				bc.futureBlocks.Add(block.Hash(), block)
-				stats.queued++
 				continue
 			}
+			// Competitor chain beat canonical, gather all blocks from the common ancestor
+			var winner []*types.Block
 
-			if err == consensus.ErrUnknownAncestor && bc.futureBlocks.Contains(block.ParentHash()) {
-				bc.futureBlocks.Add(block.Hash(), block)
-				stats.queued++
-				continue
+			parent := bc.GetBlock(block.ParentHash(), block.NumberU64()-1)
+			for !bc.HasState(parent.Root()) {
+				winner = append(winner, parent)
+				parent = bc.GetBlock(parent.ParentHash(), parent.NumberU64()-1)
+			}
+			for j := 0; j < len(winner)/2; j++ {
+				winner[j], winner[len(winner)-1-j] = winner[len(winner)-1-j], winner[j]
+			}
+			// Import all the pruned blocks to make the state available
+			bc.chainmu.Unlock()
+			_, evs, logs, err := bc.insertChain(winner)
+			bc.chainmu.Lock()
+			events, coalescedLogs = evs, logs
+
+			if err != nil {
+				return i, events, coalescedLogs, err
 			}
 
+		case err != nil:
 			bc.reportBlock(block, nil, err)
 			return i, events, coalescedLogs, err
 		}
@@ -962,8 +1149,10 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty
 			bc.reportBlock(block, receipts, err)
 			return i, events, coalescedLogs, err
 		}
+		proctime := time.Since(bstart)
+
 		// Write the block to the chain and get the status.
-		status, err := bc.WriteBlockAndState(block, receipts, state)
+		status, err := bc.WriteBlockWithState(block, receipts, state)
 		if err != nil {
 			return i, events, coalescedLogs, err
 		}
@@ -977,6 +1166,9 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty
 			events = append(events, ChainEvent{block, block.Hash(), logs})
 			lastCanon = block
 
+			// Only count canonical blocks for GC processing time
+			bc.gcproc += proctime
+
 		case SideStatTy:
 			log.Debug("Inserted forked block", "number", block.Number(), "hash", block.Hash(), "diff", block.Difficulty(), "elapsed",
 				common.PrettyDuration(time.Since(bstart)), "txs", len(block.Transactions()), "gas", block.GasUsed(), "uncles", len(block.Uncles()))
@@ -986,7 +1178,7 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty
 		}
 		stats.processed++
 		stats.usedGas += usedGas
-		stats.report(chain, i)
+		stats.report(chain, i, bc.stateCache.TrieDB().Size())
 	}
 	// Append a single chain head event if we've progressed the chain
 	if lastCanon != nil && bc.CurrentBlock().Hash() == lastCanon.Hash() {
@@ -1009,7 +1201,7 @@ const statsReportLimit = 8 * time.Second
 
 // report prints statistics if some number of blocks have been processed
 // or more than a few seconds have passed since the last message.
-func (st *insertStats) report(chain []*types.Block, index int) {
+func (st *insertStats) report(chain []*types.Block, index int, cache common.StorageSize) {
 	// Fetch the timings for the batch
 	var (
 		now     = mclock.Now()
@@ -1024,7 +1216,7 @@ func (st *insertStats) report(chain []*types.Block, index int) {
 		context := []interface{}{
 			"blocks", st.processed, "txs", txs, "mgas", float64(st.usedGas) / 1000000,
 			"elapsed", common.PrettyDuration(elapsed), "mgasps", float64(st.usedGas) * 1000 / float64(elapsed),
-			"number", end.Number(), "hash", end.Hash(),
+			"number", end.Number(), "hash", end.Hash(), "cache", cache,
 		}
 		if st.queued > 0 {
 			context = append(context, []interface{}{"queued", st.queued}...)
@@ -1060,7 +1252,7 @@ func (bc *BlockChain) reorg(oldBlock, newBlock *types.Block) error {
 		// These logs are later announced as deleted.
 		collectLogs = func(h common.Hash) {
 			// Coalesce logs and set 'Removed'.
-			receipts := GetBlockReceipts(bc.chainDb, h, bc.hc.GetBlockNumber(h))
+			receipts := GetBlockReceipts(bc.db, h, bc.hc.GetBlockNumber(h))
 			for _, receipt := range receipts {
 				for _, log := range receipt.Logs {
 					del := *log
@@ -1129,7 +1321,7 @@ func (bc *BlockChain) reorg(oldBlock, newBlock *types.Block) error {
 		// insert the block in the canonical way, re-writing history
 		bc.insert(newChain[i])
 		// write lookup entries for hash based transaction/receipt searches
-		if err := WriteTxLookupEntries(bc.chainDb, newChain[i]); err != nil {
+		if err := WriteTxLookupEntries(bc.db, newChain[i]); err != nil {
 			return err
 		}
 		addedTxs = append(addedTxs, newChain[i].Transactions()...)
@@ -1139,7 +1331,7 @@ func (bc *BlockChain) reorg(oldBlock, newBlock *types.Block) error {
 	// When transactions get deleted from the database that means the
 	// receipts that were created in the fork must also be deleted
 	for _, tx := range diff {
-		DeleteTxLookupEntry(bc.chainDb, tx.Hash())
+		DeleteTxLookupEntry(bc.db, tx.Hash())
 	}
 	if len(deletedLogs) > 0 {
 		go bc.rmLogsFeed.Send(RemovedLogsEvent{deletedLogs})
@@ -1231,7 +1423,7 @@ Hash: 0x%x
 
 Error: %v
 ##############################
-`, bc.config, block.Number(), block.Hash(), receiptString, err))
+`, bc.chainConfig, block.Number(), block.Hash(), receiptString, err))
 }
 
 // InsertHeaderChain attempts to insert the given header chain in to the local
@@ -1338,7 +1530,7 @@ func (bc *BlockChain) GetHeaderByNumber(number uint64) *types.Header {
 }
 
 // Config retrieves the blockchain's chain configuration.
-func (bc *BlockChain) Config() *params.ChainConfig { return bc.config }
+func (bc *BlockChain) Config() *params.ChainConfig { return bc.chainConfig }
 
 // Engine retrieves the blockchain's consensus engine.
 func (bc *BlockChain) Engine() consensus.Engine { return bc.engine }
diff --git a/core/blockchain_test.go b/core/blockchain_test.go
index cbde3bcd2..635379161 100644
--- a/core/blockchain_test.go
+++ b/core/blockchain_test.go
@@ -46,7 +46,7 @@ func newTestBlockChain(fake bool) *BlockChain {
 	if !fake {
 		engine = ethash.NewTester()
 	}
-	blockchain, err := NewBlockChain(db, gspec.Config, engine, vm.Config{})
+	blockchain, err := NewBlockChain(db, nil, gspec.Config, engine, vm.Config{})
 	if err != nil {
 		panic(err)
 	}
@@ -148,9 +148,9 @@ func testBlockChainImport(chain types.Blocks, blockchain *BlockChain) error {
 			return err
 		}
 		blockchain.mu.Lock()
-		WriteTd(blockchain.chainDb, block.Hash(), block.NumberU64(), new(big.Int).Add(block.Difficulty(), blockchain.GetTdByHash(block.ParentHash())))
-		WriteBlock(blockchain.chainDb, block)
-		statedb.CommitTo(blockchain.chainDb, false)
+		WriteTd(blockchain.db, block.Hash(), block.NumberU64(), new(big.Int).Add(block.Difficulty(), blockchain.GetTdByHash(block.ParentHash())))
+		WriteBlock(blockchain.db, block)
+		statedb.Commit(false)
 		blockchain.mu.Unlock()
 	}
 	return nil
@@ -166,8 +166,8 @@ func testHeaderChainImport(chain []*types.Header, blockchain *BlockChain) error
 		}
 		// Manually insert the header into the database, but don't reorganise (allows subsequent testing)
 		blockchain.mu.Lock()
-		WriteTd(blockchain.chainDb, header.Hash(), header.Number.Uint64(), new(big.Int).Add(header.Difficulty, blockchain.GetTdByHash(header.ParentHash)))
-		WriteHeader(blockchain.chainDb, header)
+		WriteTd(blockchain.db, header.Hash(), header.Number.Uint64(), new(big.Int).Add(header.Difficulty, blockchain.GetTdByHash(header.ParentHash)))
+		WriteHeader(blockchain.db, header)
 		blockchain.mu.Unlock()
 	}
 	return nil
@@ -186,9 +186,9 @@ func TestLastBlock(t *testing.T) {
 	bchain := newTestBlockChain(false)
 	defer bchain.Stop()
 
-	block := makeBlockChain(bchain.CurrentBlock(), 1, ethash.NewFaker(), bchain.chainDb, 0)[0]
+	block := makeBlockChain(bchain.CurrentBlock(), 1, ethash.NewFaker(), bchain.db, 0)[0]
 	bchain.insert(block)
-	if block.Hash() != GetHeadBlockHash(bchain.chainDb) {
+	if block.Hash() != GetHeadBlockHash(bchain.db) {
 		t.Errorf("Write/Get HeadBlockHash failed")
 	}
 }
@@ -496,7 +496,7 @@ func testReorgBadHashes(t *testing.T, full bool) {
 	}
 
 	// Create a new BlockChain and check that it rolled back the state.
-	ncm, err := NewBlockChain(bc.chainDb, bc.config, ethash.NewFaker(), vm.Config{})
+	ncm, err := NewBlockChain(bc.db, nil, bc.chainConfig, ethash.NewFaker(), vm.Config{})
 	if err != nil {
 		t.Fatalf("failed to create new chain manager: %v", err)
 	}
@@ -609,7 +609,7 @@ func TestFastVsFullChains(t *testing.T) {
 	// Import the chain as an archive node for the comparison baseline
 	archiveDb, _ := ethdb.NewMemDatabase()
 	gspec.MustCommit(archiveDb)
-	archive, _ := NewBlockChain(archiveDb, gspec.Config, ethash.NewFaker(), vm.Config{})
+	archive, _ := NewBlockChain(archiveDb, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
 	defer archive.Stop()
 
 	if n, err := archive.InsertChain(blocks); err != nil {
@@ -618,7 +618,7 @@ func TestFastVsFullChains(t *testing.T) {
 	// Fast import the chain as a non-archive node to test
 	fastDb, _ := ethdb.NewMemDatabase()
 	gspec.MustCommit(fastDb)
-	fast, _ := NewBlockChain(fastDb, gspec.Config, ethash.NewFaker(), vm.Config{})
+	fast, _ := NewBlockChain(fastDb, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
 	defer fast.Stop()
 
 	headers := make([]*types.Header, len(blocks))
@@ -696,7 +696,7 @@ func TestLightVsFastVsFullChainHeads(t *testing.T) {
 	archiveDb, _ := ethdb.NewMemDatabase()
 	gspec.MustCommit(archiveDb)
 
-	archive, _ := NewBlockChain(archiveDb, gspec.Config, ethash.NewFaker(), vm.Config{})
+	archive, _ := NewBlockChain(archiveDb, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
 	if n, err := archive.InsertChain(blocks); err != nil {
 		t.Fatalf("failed to process block %d: %v", n, err)
 	}
@@ -709,7 +709,7 @@ func TestLightVsFastVsFullChainHeads(t *testing.T) {
 	// Import the chain as a non-archive node and ensure all pointers are updated
 	fastDb, _ := ethdb.NewMemDatabase()
 	gspec.MustCommit(fastDb)
-	fast, _ := NewBlockChain(fastDb, gspec.Config, ethash.NewFaker(), vm.Config{})
+	fast, _ := NewBlockChain(fastDb, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
 	defer fast.Stop()
 
 	headers := make([]*types.Header, len(blocks))
@@ -730,7 +730,7 @@ func TestLightVsFastVsFullChainHeads(t *testing.T) {
 	lightDb, _ := ethdb.NewMemDatabase()
 	gspec.MustCommit(lightDb)
 
-	light, _ := NewBlockChain(lightDb, gspec.Config, ethash.NewFaker(), vm.Config{})
+	light, _ := NewBlockChain(lightDb, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
 	if n, err := light.InsertHeaderChain(headers, 1); err != nil {
 		t.Fatalf("failed to insert header %d: %v", n, err)
 	}
@@ -799,7 +799,7 @@ func TestChainTxReorgs(t *testing.T) {
 		}
 	})
 	// Import the chain. This runs all block validation rules.
-	blockchain, _ := NewBlockChain(db, gspec.Config, ethash.NewFaker(), vm.Config{})
+	blockchain, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
 	if i, err := blockchain.InsertChain(chain); err != nil {
 		t.Fatalf("failed to insert original chain[%d]: %v", i, err)
 	}
@@ -870,7 +870,7 @@ func TestLogReorgs(t *testing.T) {
 		signer  = types.NewEIP155Signer(gspec.Config.ChainId)
 	)
 
-	blockchain, _ := NewBlockChain(db, gspec.Config, ethash.NewFaker(), vm.Config{})
+	blockchain, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
 	defer blockchain.Stop()
 
 	rmLogsCh := make(chan RemovedLogsEvent)
@@ -917,7 +917,7 @@ func TestReorgSideEvent(t *testing.T) {
 		signer  = types.NewEIP155Signer(gspec.Config.ChainId)
 	)
 
-	blockchain, _ := NewBlockChain(db, gspec.Config, ethash.NewFaker(), vm.Config{})
+	blockchain, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
 	defer blockchain.Stop()
 
 	chain, _ := GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, 3, func(i int, gen *BlockGen) {})
@@ -992,7 +992,7 @@ func TestCanonicalBlockRetrieval(t *testing.T) {
 	bc := newTestBlockChain(true)
 	defer bc.Stop()
 
-	chain, _ := GenerateChain(bc.config, bc.genesisBlock, ethash.NewFaker(), bc.chainDb, 10, func(i int, gen *BlockGen) {})
+	chain, _ := GenerateChain(bc.chainConfig, bc.genesisBlock, ethash.NewFaker(), bc.db, 10, func(i int, gen *BlockGen) {})
 
 	var pend sync.WaitGroup
 	pend.Add(len(chain))
@@ -1003,14 +1003,14 @@ func TestCanonicalBlockRetrieval(t *testing.T) {
 
 			// try to retrieve a block by its canonical hash and see if the block data can be retrieved.
 			for {
-				ch := GetCanonicalHash(bc.chainDb, block.NumberU64())
+				ch := GetCanonicalHash(bc.db, block.NumberU64())
 				if ch == (common.Hash{}) {
 					continue // busy wait for canonical hash to be written
 				}
 				if ch != block.Hash() {
 					t.Fatalf("unknown canonical hash, want %s, got %s", block.Hash().Hex(), ch.Hex())
 				}
-				fb := GetBlock(bc.chainDb, ch, block.NumberU64())
+				fb := GetBlock(bc.db, ch, block.NumberU64())
 				if fb == nil {
 					t.Fatalf("unable to retrieve block %d for canonical hash: %s", block.NumberU64(), ch.Hex())
 				}
@@ -1043,7 +1043,7 @@ func TestEIP155Transition(t *testing.T) {
 		genesis = gspec.MustCommit(db)
 	)
 
-	blockchain, _ := NewBlockChain(db, gspec.Config, ethash.NewFaker(), vm.Config{})
+	blockchain, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
 	defer blockchain.Stop()
 
 	blocks, _ := GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, 4, func(i int, block *BlockGen) {
@@ -1151,7 +1151,7 @@ func TestEIP161AccountRemoval(t *testing.T) {
 		}
 		genesis = gspec.MustCommit(db)
 	)
-	blockchain, _ := NewBlockChain(db, gspec.Config, ethash.NewFaker(), vm.Config{})
+	blockchain, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
 	defer blockchain.Stop()
 
 	blocks, _ := GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, 3, func(i int, block *BlockGen) {
@@ -1226,7 +1226,7 @@ func TestBlockchainHeaderchainReorgConsistency(t *testing.T) {
 	diskdb, _ := ethdb.NewMemDatabase()
 	new(Genesis).MustCommit(diskdb)
 
-	chain, err := NewBlockChain(diskdb, params.TestChainConfig, engine, vm.Config{})
+	chain, err := NewBlockChain(diskdb, nil, params.TestChainConfig, engine, vm.Config{})
 	if err != nil {
 		t.Fatalf("failed to create tester chain: %v", err)
 	}
@@ -1245,3 +1245,102 @@ func TestBlockchainHeaderchainReorgConsistency(t *testing.T) {
 		}
 	}
 }
+
+// Tests that importing small side forks doesn't leave junk in the trie database
+// cache (which would eventually cause memory issues).
+func TestTrieForkGC(t *testing.T) {
+	// Generate a canonical chain to act as the main dataset
+	engine := ethash.NewFaker()
+
+	db, _ := ethdb.NewMemDatabase()
+	genesis := new(Genesis).MustCommit(db)
+	blocks, _ := GenerateChain(params.TestChainConfig, genesis, engine, db, 2*triesInMemory, func(i int, b *BlockGen) { b.SetCoinbase(common.Address{1}) })
+
+	// Generate a bunch of fork blocks, each side forking from the canonical chain
+	forks := make([]*types.Block, len(blocks))
+	for i := 0; i < len(forks); i++ {
+		parent := genesis
+		if i > 0 {
+			parent = blocks[i-1]
+		}
+		fork, _ := GenerateChain(params.TestChainConfig, parent, engine, db, 1, func(i int, b *BlockGen) { b.SetCoinbase(common.Address{2}) })
+		forks[i] = fork[0]
+	}
+	// Import the canonical and fork chain side by side, forcing the trie cache to cache both
+	diskdb, _ := ethdb.NewMemDatabase()
+	new(Genesis).MustCommit(diskdb)
+
+	chain, err := NewBlockChain(diskdb, nil, params.TestChainConfig, engine, vm.Config{})
+	if err != nil {
+		t.Fatalf("failed to create tester chain: %v", err)
+	}
+	for i := 0; i < len(blocks); i++ {
+		if _, err := chain.InsertChain(blocks[i : i+1]); err != nil {
+			t.Fatalf("block %d: failed to insert into chain: %v", i, err)
+		}
+		if _, err := chain.InsertChain(forks[i : i+1]); err != nil {
+			t.Fatalf("fork %d: failed to insert into chain: %v", i, err)
+		}
+	}
+	// Dereference all the recent tries and ensure no past trie is left in
+	for i := 0; i < triesInMemory; i++ {
+		chain.stateCache.TrieDB().Dereference(blocks[len(blocks)-1-i].Root(), common.Hash{})
+		chain.stateCache.TrieDB().Dereference(forks[len(blocks)-1-i].Root(), common.Hash{})
+	}
+	if len(chain.stateCache.TrieDB().Nodes()) > 0 {
+		t.Fatalf("stale tries still alive after garbase collection")
+	}
+}
+
+// Tests that doing large reorgs works even if the state associated with the
+// forking point is not available any more.
+func TestLargeReorgTrieGC(t *testing.T) {
+	// Generate the original common chain segment and the two competing forks
+	engine := ethash.NewFaker()
+
+	db, _ := ethdb.NewMemDatabase()
+	genesis := new(Genesis).MustCommit(db)
+
+	shared, _ := GenerateChain(params.TestChainConfig, genesis, engine, db, 64, func(i int, b *BlockGen) { b.SetCoinbase(common.Address{1}) })
+	original, _ := GenerateChain(params.TestChainConfig, shared[len(shared)-1], engine, db, 2*triesInMemory, func(i int, b *BlockGen) { b.SetCoinbase(common.Address{2}) })
+	competitor, _ := GenerateChain(params.TestChainConfig, shared[len(shared)-1], engine, db, 2*triesInMemory+1, func(i int, b *BlockGen) { b.SetCoinbase(common.Address{3}) })
+
+	// Import the shared chain and the original canonical one
+	diskdb, _ := ethdb.NewMemDatabase()
+	new(Genesis).MustCommit(diskdb)
+
+	chain, err := NewBlockChain(diskdb, nil, params.TestChainConfig, engine, vm.Config{})
+	if err != nil {
+		t.Fatalf("failed to create tester chain: %v", err)
+	}
+	if _, err := chain.InsertChain(shared); err != nil {
+		t.Fatalf("failed to insert shared chain: %v", err)
+	}
+	if _, err := chain.InsertChain(original); err != nil {
+		t.Fatalf("failed to insert shared chain: %v", err)
+	}
+	// Ensure that the state associated with the forking point is pruned away
+	if node, _ := chain.stateCache.TrieDB().Node(shared[len(shared)-1].Root()); node != nil {
+		t.Fatalf("common-but-old ancestor still cache")
+	}
+	// Import the competitor chain without exceeding the canonical's TD and ensure
+	// we have not processed any of the blocks (protection against malicious blocks)
+	if _, err := chain.InsertChain(competitor[:len(competitor)-2]); err != nil {
+		t.Fatalf("failed to insert competitor chain: %v", err)
+	}
+	for i, block := range competitor[:len(competitor)-2] {
+		if node, _ := chain.stateCache.TrieDB().Node(block.Root()); node != nil {
+			t.Fatalf("competitor %d: low TD chain became processed", i)
+		}
+	}
+	// Import the head of the competitor chain, triggering the reorg and ensure we
+	// successfully reprocess all the stashed away blocks.
+	if _, err := chain.InsertChain(competitor[len(competitor)-2:]); err != nil {
+		t.Fatalf("failed to finalize competitor chain: %v", err)
+	}
+	for i, block := range competitor[:len(competitor)-triesInMemory] {
+		if node, _ := chain.stateCache.TrieDB().Node(block.Root()); node != nil {
+			t.Fatalf("competitor %d: competing chain state missing", i)
+		}
+	}
+}
diff --git a/core/chain_indexer.go b/core/chain_indexer.go
index 7fb184aaa..158ed8324 100644
--- a/core/chain_indexer.go
+++ b/core/chain_indexer.go
@@ -203,6 +203,9 @@ func (c *ChainIndexer) eventLoop(currentHeader *types.Header, events chan ChainE
 			if header.ParentHash != prevHash {
 				// Reorg to the common ancestor (might not exist in light sync mode, skip reorg then)
 				// TODO(karalabe, zsfelfoldi): This seems a bit brittle, can we detect this case explicitly?
+
+				// TODO(karalabe): This operation is expensive and might block, causing the event system to
+				// potentially also lock up. We need to do with on a different thread somehow.
 				if h := FindCommonAncestor(c.chainDb, prevHeader, header); h != nil {
 					c.newHead(h.Number.Uint64(), true)
 				}
diff --git a/core/chain_makers.go b/core/chain_makers.go
index 5e264a994..6744428ff 100644
--- a/core/chain_makers.go
+++ b/core/chain_makers.go
@@ -166,7 +166,7 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, engine conse
 	genblock := func(i int, parent *types.Block, statedb *state.StateDB) (*types.Block, types.Receipts) {
 		// TODO(karalabe): This is needed for clique, which depends on multiple blocks.
 		// It's nonetheless ugly to spin up a blockchain here. Get rid of this somehow.
-		blockchain, _ := NewBlockChain(db, config, engine, vm.Config{})
+		blockchain, _ := NewBlockChain(db, nil, config, engine, vm.Config{})
 		defer blockchain.Stop()
 
 		b := &BlockGen{i: i, parent: parent, chain: blocks, chainReader: blockchain, statedb: statedb, config: config, engine: engine}
@@ -192,10 +192,13 @@ func GenerateChain(config *params.ChainConfig, parent *types.Block, engine conse
 		if b.engine != nil {
 			block, _ := b.engine.Finalize(b.chainReader, b.header, statedb, b.txs, b.uncles, b.receipts)
 			// Write state changes to db
-			_, err := statedb.CommitTo(db, config.IsEIP158(b.header.Number))
+			root, err := statedb.Commit(config.IsEIP158(b.header.Number))
 			if err != nil {
 				panic(fmt.Sprintf("state write error: %v", err))
 			}
+			if err := statedb.Database().TrieDB().Commit(root, false); err != nil {
+				panic(fmt.Sprintf("trie write error: %v", err))
+			}
 			return block, b.receipts
 		}
 		return nil, nil
@@ -246,7 +249,7 @@ func newCanonical(engine consensus.Engine, n int, full bool) (ethdb.Database, *B
 	db, _ := ethdb.NewMemDatabase()
 	genesis := gspec.MustCommit(db)
 
-	blockchain, _ := NewBlockChain(db, params.AllEthashProtocolChanges, engine, vm.Config{})
+	blockchain, _ := NewBlockChain(db, nil, params.AllEthashProtocolChanges, engine, vm.Config{})
 	// Create and inject the requested chain
 	if n == 0 {
 		return db, blockchain, nil
diff --git a/core/chain_makers_test.go b/core/chain_makers_test.go
index a3b80da29..93be43ddc 100644
--- a/core/chain_makers_test.go
+++ b/core/chain_makers_test.go
@@ -79,7 +79,7 @@ func ExampleGenerateChain() {
 	})
 
 	// Import the chain. This runs all block validation rules.
-	blockchain, _ := NewBlockChain(db, gspec.Config, ethash.NewFaker(), vm.Config{})
+	blockchain, _ := NewBlockChain(db, nil, gspec.Config, ethash.NewFaker(), vm.Config{})
 	defer blockchain.Stop()
 
 	if i, err := blockchain.InsertChain(chain); err != nil {
diff --git a/core/dao_test.go b/core/dao_test.go
index 43e2982a5..e0a3e3ff3 100644
--- a/core/dao_test.go
+++ b/core/dao_test.go
@@ -45,7 +45,7 @@ func TestDAOForkRangeExtradata(t *testing.T) {
 	proConf.DAOForkBlock = forkBlock
 	proConf.DAOForkSupport = true
 
-	proBc, _ := NewBlockChain(proDb, &proConf, ethash.NewFaker(), vm.Config{})
+	proBc, _ := NewBlockChain(proDb, nil, &proConf, ethash.NewFaker(), vm.Config{})
 	defer proBc.Stop()
 
 	conDb, _ := ethdb.NewMemDatabase()
@@ -55,7 +55,7 @@ func TestDAOForkRangeExtradata(t *testing.T) {
 	conConf.DAOForkBlock = forkBlock
 	conConf.DAOForkSupport = false
 
-	conBc, _ := NewBlockChain(conDb, &conConf, ethash.NewFaker(), vm.Config{})
+	conBc, _ := NewBlockChain(conDb, nil, &conConf, ethash.NewFaker(), vm.Config{})
 	defer conBc.Stop()
 
 	if _, err := proBc.InsertChain(prefix); err != nil {
@@ -69,7 +69,7 @@ func TestDAOForkRangeExtradata(t *testing.T) {
 		// Create a pro-fork block, and try to feed into the no-fork chain
 		db, _ = ethdb.NewMemDatabase()
 		gspec.MustCommit(db)
-		bc, _ := NewBlockChain(db, &conConf, ethash.NewFaker(), vm.Config{})
+		bc, _ := NewBlockChain(db, nil, &conConf, ethash.NewFaker(), vm.Config{})
 		defer bc.Stop()
 
 		blocks := conBc.GetBlocksFromHash(conBc.CurrentBlock().Hash(), int(conBc.CurrentBlock().NumberU64()))
@@ -79,6 +79,9 @@ func TestDAOForkRangeExtradata(t *testing.T) {
 		if _, err := bc.InsertChain(blocks); err != nil {
 			t.Fatalf("failed to import contra-fork chain for expansion: %v", err)
 		}
+		if err := bc.stateCache.TrieDB().Commit(bc.CurrentHeader().Root, true); err != nil {
+			t.Fatalf("failed to commit contra-fork head for expansion: %v", err)
+		}
 		blocks, _ = GenerateChain(&proConf, conBc.CurrentBlock(), ethash.NewFaker(), db, 1, func(i int, gen *BlockGen) {})
 		if _, err := conBc.InsertChain(blocks); err == nil {
 			t.Fatalf("contra-fork chain accepted pro-fork block: %v", blocks[0])
@@ -91,7 +94,7 @@ func TestDAOForkRangeExtradata(t *testing.T) {
 		// Create a no-fork block, and try to feed into the pro-fork chain
 		db, _ = ethdb.NewMemDatabase()
 		gspec.MustCommit(db)
-		bc, _ = NewBlockChain(db, &proConf, ethash.NewFaker(), vm.Config{})
+		bc, _ = NewBlockChain(db, nil, &proConf, ethash.NewFaker(), vm.Config{})
 		defer bc.Stop()
 
 		blocks = proBc.GetBlocksFromHash(proBc.CurrentBlock().Hash(), int(proBc.CurrentBlock().NumberU64()))
@@ -101,6 +104,9 @@ func TestDAOForkRangeExtradata(t *testing.T) {
 		if _, err := bc.InsertChain(blocks); err != nil {
 			t.Fatalf("failed to import pro-fork chain for expansion: %v", err)
 		}
+		if err := bc.stateCache.TrieDB().Commit(bc.CurrentHeader().Root, true); err != nil {
+			t.Fatalf("failed to commit pro-fork head for expansion: %v", err)
+		}
 		blocks, _ = GenerateChain(&conConf, proBc.CurrentBlock(), ethash.NewFaker(), db, 1, func(i int, gen *BlockGen) {})
 		if _, err := proBc.InsertChain(blocks); err == nil {
 			t.Fatalf("pro-fork chain accepted contra-fork block: %v", blocks[0])
@@ -114,7 +120,7 @@ func TestDAOForkRangeExtradata(t *testing.T) {
 	// Verify that contra-forkers accept pro-fork extra-datas after forking finishes
 	db, _ = ethdb.NewMemDatabase()
 	gspec.MustCommit(db)
-	bc, _ := NewBlockChain(db, &conConf, ethash.NewFaker(), vm.Config{})
+	bc, _ := NewBlockChain(db, nil, &conConf, ethash.NewFaker(), vm.Config{})
 	defer bc.Stop()
 
 	blocks := conBc.GetBlocksFromHash(conBc.CurrentBlock().Hash(), int(conBc.CurrentBlock().NumberU64()))
@@ -124,6 +130,9 @@ func TestDAOForkRangeExtradata(t *testing.T) {
 	if _, err := bc.InsertChain(blocks); err != nil {
 		t.Fatalf("failed to import contra-fork chain for expansion: %v", err)
 	}
+	if err := bc.stateCache.TrieDB().Commit(bc.CurrentHeader().Root, true); err != nil {
+		t.Fatalf("failed to commit contra-fork head for expansion: %v", err)
+	}
 	blocks, _ = GenerateChain(&proConf, conBc.CurrentBlock(), ethash.NewFaker(), db, 1, func(i int, gen *BlockGen) {})
 	if _, err := conBc.InsertChain(blocks); err != nil {
 		t.Fatalf("contra-fork chain didn't accept pro-fork block post-fork: %v", err)
@@ -131,7 +140,7 @@ func TestDAOForkRangeExtradata(t *testing.T) {
 	// Verify that pro-forkers accept contra-fork extra-datas after forking finishes
 	db, _ = ethdb.NewMemDatabase()
 	gspec.MustCommit(db)
-	bc, _ = NewBlockChain(db, &proConf, ethash.NewFaker(), vm.Config{})
+	bc, _ = NewBlockChain(db, nil, &proConf, ethash.NewFaker(), vm.Config{})
 	defer bc.Stop()
 
 	blocks = proBc.GetBlocksFromHash(proBc.CurrentBlock().Hash(), int(proBc.CurrentBlock().NumberU64()))
@@ -141,6 +150,9 @@ func TestDAOForkRangeExtradata(t *testing.T) {
 	if _, err := bc.InsertChain(blocks); err != nil {
 		t.Fatalf("failed to import pro-fork chain for expansion: %v", err)
 	}
+	if err := bc.stateCache.TrieDB().Commit(bc.CurrentHeader().Root, true); err != nil {
+		t.Fatalf("failed to commit pro-fork head for expansion: %v", err)
+	}
 	blocks, _ = GenerateChain(&conConf, proBc.CurrentBlock(), ethash.NewFaker(), db, 1, func(i int, gen *BlockGen) {})
 	if _, err := proBc.InsertChain(blocks); err != nil {
 		t.Fatalf("pro-fork chain didn't accept contra-fork block post-fork: %v", err)
diff --git a/core/genesis.go b/core/genesis.go
index e22985b80..b6ead2250 100644
--- a/core/genesis.go
+++ b/core/genesis.go
@@ -169,10 +169,9 @@ func SetupGenesisBlock(db ethdb.Database, genesis *Genesis) (*params.ChainConfig
 
 	// Check whether the genesis block is already written.
 	if genesis != nil {
-		block, _ := genesis.ToBlock()
-		hash := block.Hash()
+		hash := genesis.ToBlock(nil).Hash()
 		if hash != stored {
-			return genesis.Config, block.Hash(), &GenesisMismatchError{stored, hash}
+			return genesis.Config, hash, &GenesisMismatchError{stored, hash}
 		}
 	}
 
@@ -220,9 +219,12 @@ func (g *Genesis) configOrDefault(ghash common.Hash) *params.ChainConfig {
 	}
 }
 
-// ToBlock creates the block and state of a genesis specification.
-func (g *Genesis) ToBlock() (*types.Block, *state.StateDB) {
-	db, _ := ethdb.NewMemDatabase()
+// ToBlock creates the genesis block and writes state of a genesis specification
+// to the given database (or discards it if nil).
+func (g *Genesis) ToBlock(db ethdb.Database) *types.Block {
+	if db == nil {
+		db, _ = ethdb.NewMemDatabase()
+	}
 	statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
 	for addr, account := range g.Alloc {
 		statedb.AddBalance(addr, account.Balance)
@@ -252,19 +254,19 @@ func (g *Genesis) ToBlock() (*types.Block, *state.StateDB) {
 	if g.Difficulty == nil {
 		head.Difficulty = params.GenesisDifficulty
 	}
-	return types.NewBlock(head, nil, nil, nil), statedb
+	statedb.Commit(false)
+	statedb.Database().TrieDB().Commit(root, true)
+
+	return types.NewBlock(head, nil, nil, nil)
 }
 
 // Commit writes the block and state of a genesis specification to the database.
 // The block is committed as the canonical head block.
 func (g *Genesis) Commit(db ethdb.Database) (*types.Block, error) {
-	block, statedb := g.ToBlock()
+	block := g.ToBlock(db)
 	if block.Number().Sign() != 0 {
 		return nil, fmt.Errorf("can't commit genesis block with number > 0")
 	}
-	if _, err := statedb.CommitTo(db, false); err != nil {
-		return nil, fmt.Errorf("cannot write state: %v", err)
-	}
 	if err := WriteTd(db, block.Hash(), block.NumberU64(), g.Difficulty); err != nil {
 		return nil, err
 	}
diff --git a/core/genesis_test.go b/core/genesis_test.go
index 2fe931b24..cd548d4b1 100644
--- a/core/genesis_test.go
+++ b/core/genesis_test.go
@@ -30,11 +30,11 @@ import (
 )
 
 func TestDefaultGenesisBlock(t *testing.T) {
-	block, _ := DefaultGenesisBlock().ToBlock()
+	block := DefaultGenesisBlock().ToBlock(nil)
 	if block.Hash() != params.MainnetGenesisHash {
 		t.Errorf("wrong mainnet genesis hash, got %v, want %v", block.Hash(), params.MainnetGenesisHash)
 	}
-	block, _ = DefaultTestnetGenesisBlock().ToBlock()
+	block = DefaultTestnetGenesisBlock().ToBlock(nil)
 	if block.Hash() != params.TestnetGenesisHash {
 		t.Errorf("wrong testnet genesis hash, got %v, want %v", block.Hash(), params.TestnetGenesisHash)
 	}
@@ -118,7 +118,7 @@ func TestSetupGenesis(t *testing.T) {
 				// Commit the 'old' genesis block with Homestead transition at #2.
 				// Advance to block #4, past the homestead transition block of customg.
 				genesis := oldcustomg.MustCommit(db)
-				bc, _ := NewBlockChain(db, oldcustomg.Config, ethash.NewFullFaker(), vm.Config{})
+				bc, _ := NewBlockChain(db, nil, oldcustomg.Config, ethash.NewFullFaker(), vm.Config{})
 				defer bc.Stop()
 				bc.SetValidator(bproc{})
 				bc.InsertChain(makeBlockChainWithDiff(genesis, []int{2, 3, 4, 5}, 0))
diff --git a/core/state/database.go b/core/state/database.go
index 946625e76..36926ec69 100644
--- a/core/state/database.go
+++ b/core/state/database.go
@@ -40,16 +40,23 @@ const (
 
 // Database wraps access to tries and contract code.
 type Database interface {
-	// Accessing tries:
 	// OpenTrie opens the main account trie.
-	// OpenStorageTrie opens the storage trie of an account.
 	OpenTrie(root common.Hash) (Trie, error)
+
+	// OpenStorageTrie opens the storage trie of an account.
 	OpenStorageTrie(addrHash, root common.Hash) (Trie, error)
-	// Accessing contract code:
-	ContractCode(addrHash, codeHash common.Hash) ([]byte, error)
-	ContractCodeSize(addrHash, codeHash common.Hash) (int, error)
+
 	// CopyTrie returns an independent copy of the given trie.
 	CopyTrie(Trie) Trie
+
+	// ContractCode retrieves a particular contract's code.
+	ContractCode(addrHash, codeHash common.Hash) ([]byte, error)
+
+	// ContractCodeSize retrieves a particular contracts code's size.
+	ContractCodeSize(addrHash, codeHash common.Hash) (int, error)
+
+	// TrieDB retrieves the low level trie database used for data storage.
+	TrieDB() *trie.Database
 }
 
 // Trie is a Ethereum Merkle Trie.
@@ -57,26 +64,33 @@ type Trie interface {
 	TryGet(key []byte) ([]byte, error)
 	TryUpdate(key, value []byte) error
 	TryDelete(key []byte) error
-	CommitTo(trie.DatabaseWriter) (common.Hash, error)
+	Commit(onleaf trie.LeafCallback) (common.Hash, error)
 	Hash() common.Hash
 	NodeIterator(startKey []byte) trie.NodeIterator
 	GetKey([]byte) []byte // TODO(fjl): remove this when SecureTrie is removed
+	Prove(key []byte, fromLevel uint, proofDb ethdb.Putter) error
 }
 
 // NewDatabase creates a backing store for state. The returned database is safe for
-// concurrent use and retains cached trie nodes in memory.
+// concurrent use and retains cached trie nodes in memory. The pool is an optional
+// intermediate trie-node memory pool between the low level storage layer and the
+// high level trie abstraction.
 func NewDatabase(db ethdb.Database) Database {
 	csc, _ := lru.New(codeSizeCacheSize)
-	return &cachingDB{db: db, codeSizeCache: csc}
+	return &cachingDB{
+		db:            trie.NewDatabase(db),
+		codeSizeCache: csc,
+	}
 }
 
 type cachingDB struct {
-	db            ethdb.Database
+	db            *trie.Database
 	mu            sync.Mutex
 	pastTries     []*trie.SecureTrie
 	codeSizeCache *lru.Cache
 }
 
+// OpenTrie opens the main account trie.
 func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) {
 	db.mu.Lock()
 	defer db.mu.Unlock()
@@ -105,10 +119,12 @@ func (db *cachingDB) pushTrie(t *trie.SecureTrie) {
 	}
 }
 
+// OpenStorageTrie opens the storage trie of an account.
 func (db *cachingDB) OpenStorageTrie(addrHash, root common.Hash) (Trie, error) {
 	return trie.NewSecure(root, db.db, 0)
 }
 
+// CopyTrie returns an independent copy of the given trie.
 func (db *cachingDB) CopyTrie(t Trie) Trie {
 	switch t := t.(type) {
 	case cachedTrie:
@@ -120,14 +136,16 @@ func (db *cachingDB) CopyTrie(t Trie) Trie {
 	}
 }
 
+// ContractCode retrieves a particular contract's code.
 func (db *cachingDB) ContractCode(addrHash, codeHash common.Hash) ([]byte, error) {
-	code, err := db.db.Get(codeHash[:])
+	code, err := db.db.Node(codeHash)
 	if err == nil {
 		db.codeSizeCache.Add(codeHash, len(code))
 	}
 	return code, err
 }
 
+// ContractCodeSize retrieves a particular contracts code's size.
 func (db *cachingDB) ContractCodeSize(addrHash, codeHash common.Hash) (int, error) {
 	if cached, ok := db.codeSizeCache.Get(codeHash); ok {
 		return cached.(int), nil
@@ -139,16 +157,25 @@ func (db *cachingDB) ContractCodeSize(addrHash, codeHash common.Hash) (int, erro
 	return len(code), err
 }
 
+// TrieDB retrieves any intermediate trie-node caching layer.
+func (db *cachingDB) TrieDB() *trie.Database {
+	return db.db
+}
+
 // cachedTrie inserts its trie into a cachingDB on commit.
 type cachedTrie struct {
 	*trie.SecureTrie
 	db *cachingDB
 }
 
-func (m cachedTrie) CommitTo(dbw trie.DatabaseWriter) (common.Hash, error) {
-	root, err := m.SecureTrie.CommitTo(dbw)
+func (m cachedTrie) Commit(onleaf trie.LeafCallback) (common.Hash, error) {
+	root, err := m.SecureTrie.Commit(onleaf)
 	if err == nil {
 		m.db.pushTrie(m.SecureTrie)
 	}
 	return root, err
 }
+
+func (m cachedTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.Putter) error {
+	return m.SecureTrie.Prove(key, fromLevel, proofDb)
+}
diff --git a/core/state/iterator_test.go b/core/state/iterator_test.go
index ff66ba7a9..9e46c851c 100644
--- a/core/state/iterator_test.go
+++ b/core/state/iterator_test.go
@@ -21,12 +21,13 @@ import (
 	"testing"
 
 	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/ethdb"
 )
 
 // Tests that the node iterator indeed walks over the entire database contents.
 func TestNodeIteratorCoverage(t *testing.T) {
 	// Create some arbitrary test state to iterate
-	db, mem, root, _ := makeTestState()
+	db, root, _ := makeTestState()
 
 	state, err := New(root, db)
 	if err != nil {
@@ -39,14 +40,18 @@ func TestNodeIteratorCoverage(t *testing.T) {
 			hashes[it.Hash] = struct{}{}
 		}
 	}
-
-	// Cross check the hashes and the database itself
+	// Cross check the iterated hashes and the database/nodepool content
 	for hash := range hashes {
-		if _, err := mem.Get(hash.Bytes()); err != nil {
-			t.Errorf("failed to retrieve reported node %x: %v", hash, err)
+		if _, err := db.TrieDB().Node(hash); err != nil {
+			t.Errorf("failed to retrieve reported node %x", hash)
+		}
+	}
+	for _, hash := range db.TrieDB().Nodes() {
+		if _, ok := hashes[hash]; !ok {
+			t.Errorf("state entry not reported %x", hash)
 		}
 	}
-	for _, key := range mem.Keys() {
+	for _, key := range db.TrieDB().DiskDB().(*ethdb.MemDatabase).Keys() {
 		if bytes.HasPrefix(key, []byte("secure-key-")) {
 			continue
 		}
diff --git a/core/state/state_object.go b/core/state/state_object.go
index b2378c69c..b2112bfae 100644
--- a/core/state/state_object.go
+++ b/core/state/state_object.go
@@ -25,7 +25,6 @@ import (
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/crypto"
 	"github.com/ethereum/go-ethereum/rlp"
-	"github.com/ethereum/go-ethereum/trie"
 )
 
 var emptyCodeHash = crypto.Keccak256(nil)
@@ -238,12 +237,12 @@ func (self *stateObject) updateRoot(db Database) {
 
 // CommitTrie the storage trie of the object to dwb.
 // This updates the trie root.
-func (self *stateObject) CommitTrie(db Database, dbw trie.DatabaseWriter) error {
+func (self *stateObject) CommitTrie(db Database) error {
 	self.updateTrie(db)
 	if self.dbErr != nil {
 		return self.dbErr
 	}
-	root, err := self.trie.CommitTo(dbw)
+	root, err := self.trie.Commit(nil)
 	if err == nil {
 		self.data.Root = root
 	}
diff --git a/core/state/state_test.go b/core/state/state_test.go
index bbae3685b..6d42d63d8 100644
--- a/core/state/state_test.go
+++ b/core/state/state_test.go
@@ -48,7 +48,7 @@ func (s *StateSuite) TestDump(c *checker.C) {
 	// write some of them to the trie
 	s.state.updateStateObject(obj1)
 	s.state.updateStateObject(obj2)
-	s.state.CommitTo(s.db, false)
+	s.state.Commit(false)
 
 	// check that dump contains the state objects that are in trie
 	got := string(s.state.Dump())
@@ -97,7 +97,7 @@ func (s *StateSuite) TestNull(c *checker.C) {
 	//value := common.FromHex("0x823140710bf13990e4500136726d8b55")
 	var value common.Hash
 	s.state.SetState(address, common.Hash{}, value)
-	s.state.CommitTo(s.db, false)
+	s.state.Commit(false)
 	value = s.state.GetState(address, common.Hash{})
 	if !common.EmptyHash(value) {
 		c.Errorf("expected empty hash. got %x", value)
@@ -155,7 +155,7 @@ func TestSnapshot2(t *testing.T) {
 	so0.deleted = false
 	state.setStateObject(so0)
 
-	root, _ := state.CommitTo(db, false)
+	root, _ := state.Commit(false)
 	state.Reset(root)
 
 	// and one with deleted == true
diff --git a/core/state/statedb.go b/core/state/statedb.go
index 8e29104d5..776693e24 100644
--- a/core/state/statedb.go
+++ b/core/state/statedb.go
@@ -36,6 +36,14 @@ type revision struct {
 	journalIndex int
 }
 
+var (
+	// emptyState is the known hash of an empty state trie entry.
+	emptyState = crypto.Keccak256Hash(nil)
+
+	// emptyCode is the known hash of the empty EVM bytecode.
+	emptyCode = crypto.Keccak256Hash(nil)
+)
+
 // StateDBs within the ethereum protocol are used to store anything
 // within the merkle trie. StateDBs take care of caching and storing
 // nested states. It's the general query interface to retrieve:
@@ -235,6 +243,11 @@ func (self *StateDB) GetState(a common.Address, b common.Hash) common.Hash {
 	return common.Hash{}
 }
 
+// Database retrieves the low level database supporting the lower level trie ops.
+func (self *StateDB) Database() Database {
+	return self.db
+}
+
 // StorageTrie returns the storage trie of an account.
 // The return value is a copy and is nil for non-existent accounts.
 func (self *StateDB) StorageTrie(a common.Address) Trie {
@@ -568,8 +581,8 @@ func (s *StateDB) clearJournalAndRefund() {
 	s.refund = 0
 }
 
-// CommitTo writes the state to the given database.
-func (s *StateDB) CommitTo(dbw trie.DatabaseWriter, deleteEmptyObjects bool) (root common.Hash, err error) {
+// Commit writes the state to the underlying in-memory trie database.
+func (s *StateDB) Commit(deleteEmptyObjects bool) (root common.Hash, err error) {
 	defer s.clearJournalAndRefund()
 
 	// Commit objects to the trie.
@@ -583,13 +596,11 @@ func (s *StateDB) CommitTo(dbw trie.DatabaseWriter, deleteEmptyObjects bool) (ro
 		case isDirty:
 			// Write any contract code associated with the state object
 			if stateObject.code != nil && stateObject.dirtyCode {
-				if err := dbw.Put(stateObject.CodeHash(), stateObject.code); err != nil {
-					return common.Hash{}, err
-				}
+				s.db.TrieDB().Insert(common.BytesToHash(stateObject.CodeHash()), stateObject.code)
 				stateObject.dirtyCode = false
 			}
 			// Write any storage changes in the state object to its storage trie.
-			if err := stateObject.CommitTrie(s.db, dbw); err != nil {
+			if err := stateObject.CommitTrie(s.db); err != nil {
 				return common.Hash{}, err
 			}
 			// Update the object in the main account trie.
@@ -598,7 +609,20 @@ func (s *StateDB) CommitTo(dbw trie.DatabaseWriter, deleteEmptyObjects bool) (ro
 		delete(s.stateObjectsDirty, addr)
 	}
 	// Write trie changes.
-	root, err = s.trie.CommitTo(dbw)
+	root, err = s.trie.Commit(func(leaf []byte, parent common.Hash) error {
+		var account Account
+		if err := rlp.DecodeBytes(leaf, &account); err != nil {
+			return nil
+		}
+		if account.Root != emptyState {
+			s.db.TrieDB().Reference(account.Root, parent)
+		}
+		code := common.BytesToHash(account.CodeHash)
+		if code != emptyCode {
+			s.db.TrieDB().Reference(code, parent)
+		}
+		return nil
+	})
 	log.Debug("Trie cache stats after commit", "misses", trie.CacheMisses(), "unloads", trie.CacheUnloads())
 	return root, err
 }
diff --git a/core/state/statedb_test.go b/core/state/statedb_test.go
index 5c80e3aa5..d9e3d9b79 100644
--- a/core/state/statedb_test.go
+++ b/core/state/statedb_test.go
@@ -97,10 +97,10 @@ func TestIntermediateLeaks(t *testing.T) {
 	}
 
 	// Commit and cross check the databases.
-	if _, err := transState.CommitTo(transDb, false); err != nil {
+	if _, err := transState.Commit(false); err != nil {
 		t.Fatalf("failed to commit transition state: %v", err)
 	}
-	if _, err := finalState.CommitTo(finalDb, false); err != nil {
+	if _, err := finalState.Commit(false); err != nil {
 		t.Fatalf("failed to commit final state: %v", err)
 	}
 	for _, key := range finalDb.Keys() {
@@ -122,8 +122,8 @@ func TestIntermediateLeaks(t *testing.T) {
 // https://github.com/ethereum/go-ethereum/pull/15549.
 func TestCopy(t *testing.T) {
 	// Create a random state test to copy and modify "independently"
-	mem, _ := ethdb.NewMemDatabase()
-	orig, _ := New(common.Hash{}, NewDatabase(mem))
+	db, _ := ethdb.NewMemDatabase()
+	orig, _ := New(common.Hash{}, NewDatabase(db))
 
 	for i := byte(0); i < 255; i++ {
 		obj := orig.GetOrNewStateObject(common.BytesToAddress([]byte{i}))
@@ -346,11 +346,10 @@ func (test *snapshotTest) run() bool {
 		}
 		action.fn(action, state)
 	}
-
 	// Revert all snapshots in reverse order. Each revert must yield a state
 	// that is equivalent to fresh state with all actions up the snapshot applied.
 	for sindex--; sindex >= 0; sindex-- {
-		checkstate, _ := New(common.Hash{}, NewDatabase(db))
+		checkstate, _ := New(common.Hash{}, state.Database())
 		for _, action := range test.actions[:test.snapshots[sindex]] {
 			action.fn(action, checkstate)
 		}
@@ -409,7 +408,7 @@ func (test *snapshotTest) checkEqual(state, checkstate *StateDB) error {
 
 func (s *StateSuite) TestTouchDelete(c *check.C) {
 	s.state.GetOrNewStateObject(common.Address{})
-	root, _ := s.state.CommitTo(s.db, false)
+	root, _ := s.state.Commit(false)
 	s.state.Reset(root)
 
 	snapshot := s.state.Snapshot()
@@ -417,7 +416,6 @@ func (s *StateSuite) TestTouchDelete(c *check.C) {
 	if len(s.state.stateObjectsDirty) != 1 {
 		c.Fatal("expected one dirty state object")
 	}
-
 	s.state.RevertToSnapshot(snapshot)
 	if len(s.state.stateObjectsDirty) != 0 {
 		c.Fatal("expected no dirty state object")
diff --git a/core/state/sync_test.go b/core/state/sync_test.go
index 06c572ea6..8f14a44e7 100644
--- a/core/state/sync_test.go
+++ b/core/state/sync_test.go
@@ -36,10 +36,10 @@ type testAccount struct {
 }
 
 // makeTestState create a sample test state to test node-wise reconstruction.
-func makeTestState() (Database, *ethdb.MemDatabase, common.Hash, []*testAccount) {
+func makeTestState() (Database, common.Hash, []*testAccount) {
 	// Create an empty state
-	mem, _ := ethdb.NewMemDatabase()
-	db := NewDatabase(mem)
+	diskdb, _ := ethdb.NewMemDatabase()
+	db := NewDatabase(diskdb)
 	state, _ := New(common.Hash{}, db)
 
 	// Fill it with some arbitrary data
@@ -61,10 +61,10 @@ func makeTestState() (Database, *ethdb.MemDatabase, common.Hash, []*testAccount)
 		state.updateStateObject(obj)
 		accounts = append(accounts, acc)
 	}
-	root, _ := state.CommitTo(mem, false)
+	root, _ := state.Commit(false)
 
 	// Return the generated state
-	return db, mem, root, accounts
+	return db, root, accounts
 }
 
 // checkStateAccounts cross references a reconstructed state with an expected
@@ -96,7 +96,7 @@ func checkTrieConsistency(db ethdb.Database, root common.Hash) error {
 	if v, _ := db.Get(root[:]); v == nil {
 		return nil // Consider a non existent state consistent.
 	}
-	trie, err := trie.New(root, db)
+	trie, err := trie.New(root, trie.NewDatabase(db))
 	if err != nil {
 		return err
 	}
@@ -138,7 +138,7 @@ func TestIterativeStateSyncBatched(t *testing.T)    { testIterativeStateSync(t,
 
 func testIterativeStateSync(t *testing.T, batch int) {
 	// Create a random state to copy
-	_, srcMem, srcRoot, srcAccounts := makeTestState()
+	srcDb, srcRoot, srcAccounts := makeTestState()
 
 	// Create a destination state and sync with the scheduler
 	dstDb, _ := ethdb.NewMemDatabase()
@@ -148,9 +148,9 @@ func testIterativeStateSync(t *testing.T, batch int) {
 	for len(queue) > 0 {
 		results := make([]trie.SyncResult, len(queue))
 		for i, hash := range queue {
-			data, err := srcMem.Get(hash.Bytes())
+			data, err := srcDb.TrieDB().Node(hash)
 			if err != nil {
-				t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
+				t.Fatalf("failed to retrieve node data for %x", hash)
 			}
 			results[i] = trie.SyncResult{Hash: hash, Data: data}
 		}
@@ -170,7 +170,7 @@ func testIterativeStateSync(t *testing.T, batch int) {
 // partial results are returned, and the others sent only later.
 func TestIterativeDelayedStateSync(t *testing.T) {
 	// Create a random state to copy
-	_, srcMem, srcRoot, srcAccounts := makeTestState()
+	srcDb, srcRoot, srcAccounts := makeTestState()
 
 	// Create a destination state and sync with the scheduler
 	dstDb, _ := ethdb.NewMemDatabase()
@@ -181,9 +181,9 @@ func TestIterativeDelayedStateSync(t *testing.T) {
 		// Sync only half of the scheduled nodes
 		results := make([]trie.SyncResult, len(queue)/2+1)
 		for i, hash := range queue[:len(results)] {
-			data, err := srcMem.Get(hash.Bytes())
+			data, err := srcDb.TrieDB().Node(hash)
 			if err != nil {
-				t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
+				t.Fatalf("failed to retrieve node data for %x", hash)
 			}
 			results[i] = trie.SyncResult{Hash: hash, Data: data}
 		}
@@ -207,7 +207,7 @@ func TestIterativeRandomStateSyncBatched(t *testing.T)    { testIterativeRandomS
 
 func testIterativeRandomStateSync(t *testing.T, batch int) {
 	// Create a random state to copy
-	_, srcMem, srcRoot, srcAccounts := makeTestState()
+	srcDb, srcRoot, srcAccounts := makeTestState()
 
 	// Create a destination state and sync with the scheduler
 	dstDb, _ := ethdb.NewMemDatabase()
@@ -221,9 +221,9 @@ func testIterativeRandomStateSync(t *testing.T, batch int) {
 		// Fetch all the queued nodes in a random order
 		results := make([]trie.SyncResult, 0, len(queue))
 		for hash := range queue {
-			data, err := srcMem.Get(hash.Bytes())
+			data, err := srcDb.TrieDB().Node(hash)
 			if err != nil {
-				t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
+				t.Fatalf("failed to retrieve node data for %x", hash)
 			}
 			results = append(results, trie.SyncResult{Hash: hash, Data: data})
 		}
@@ -247,7 +247,7 @@ func testIterativeRandomStateSync(t *testing.T, batch int) {
 // partial results are returned (Even those randomly), others sent only later.
 func TestIterativeRandomDelayedStateSync(t *testing.T) {
 	// Create a random state to copy
-	_, srcMem, srcRoot, srcAccounts := makeTestState()
+	srcDb, srcRoot, srcAccounts := makeTestState()
 
 	// Create a destination state and sync with the scheduler
 	dstDb, _ := ethdb.NewMemDatabase()
@@ -263,9 +263,9 @@ func TestIterativeRandomDelayedStateSync(t *testing.T) {
 		for hash := range queue {
 			delete(queue, hash)
 
-			data, err := srcMem.Get(hash.Bytes())
+			data, err := srcDb.TrieDB().Node(hash)
 			if err != nil {
-				t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
+				t.Fatalf("failed to retrieve node data for %x", hash)
 			}
 			results = append(results, trie.SyncResult{Hash: hash, Data: data})
 
@@ -292,9 +292,9 @@ func TestIterativeRandomDelayedStateSync(t *testing.T) {
 // the database.
 func TestIncompleteStateSync(t *testing.T) {
 	// Create a random state to copy
-	_, srcMem, srcRoot, srcAccounts := makeTestState()
+	srcDb, srcRoot, srcAccounts := makeTestState()
 
-	checkTrieConsistency(srcMem, srcRoot)
+	checkTrieConsistency(srcDb.TrieDB().DiskDB().(ethdb.Database), srcRoot)
 
 	// Create a destination state and sync with the scheduler
 	dstDb, _ := ethdb.NewMemDatabase()
@@ -306,9 +306,9 @@ func TestIncompleteStateSync(t *testing.T) {
 		// Fetch a batch of state nodes
 		results := make([]trie.SyncResult, len(queue))
 		for i, hash := range queue {
-			data, err := srcMem.Get(hash.Bytes())
+			data, err := srcDb.TrieDB().Node(hash)
 			if err != nil {
-				t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
+				t.Fatalf("failed to retrieve node data for %x", hash)
 			}
 			results[i] = trie.SyncResult{Hash: hash, Data: data}
 		}
diff --git a/core/tx_pool_test.go b/core/tx_pool_test.go
index cd11f2ba2..158b9776b 100644
--- a/core/tx_pool_test.go
+++ b/core/tx_pool_test.go
@@ -78,8 +78,8 @@ func pricedTransaction(nonce uint64, gaslimit uint64, gasprice *big.Int, key *ec
 }
 
 func setupTxPool() (*TxPool, *ecdsa.PrivateKey) {
-	db, _ := ethdb.NewMemDatabase()
-	statedb, _ := state.New(common.Hash{}, state.NewDatabase(db))
+	diskdb, _ := ethdb.NewMemDatabase()
+	statedb, _ := state.New(common.Hash{}, state.NewDatabase(diskdb))
 	blockchain := &testBlockChain{statedb, 1000000, new(event.Feed)}
 
 	key, _ := crypto.GenerateKey()
diff --git a/core/types/block.go b/core/types/block.go
index ffe317342..92b868d9d 100644
--- a/core/types/block.go
+++ b/core/types/block.go
@@ -25,6 +25,7 @@ import (
 	"sort"
 	"sync/atomic"
 	"time"
+	"unsafe"
 
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/common/hexutil"
@@ -121,6 +122,12 @@ func (h *Header) HashNoNonce() common.Hash {
 	})
 }
 
+// Size returns the approximate memory used by all internal contents. It is used
+// to approximate and limit the memory consumption of various caches.
+func (h *Header) Size() common.StorageSize {
+	return common.StorageSize(unsafe.Sizeof(*h)) + common.StorageSize(len(h.Extra)+(h.Difficulty.BitLen()+h.Number.BitLen()+h.Time.BitLen())/8)
+}
+
 func rlpHash(x interface{}) (h common.Hash) {
 	hw := sha3.NewKeccak256()
 	rlp.Encode(hw, x)
@@ -322,6 +329,8 @@ func (b *Block) HashNoNonce() common.Hash {
 	return b.header.HashNoNonce()
 }
 
+// Size returns the true RLP encoded storage size of the block, either by encoding
+// and returning it, or returning a previsouly cached value.
 func (b *Block) Size() common.StorageSize {
 	if size := b.size.Load(); size != nil {
 		return size.(common.StorageSize)
diff --git a/core/types/receipt.go b/core/types/receipt.go
index 208d54aaa..f945f6f6a 100644
--- a/core/types/receipt.go
+++ b/core/types/receipt.go
@@ -20,6 +20,7 @@ import (
 	"bytes"
 	"fmt"
 	"io"
+	"unsafe"
 
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/common/hexutil"
@@ -136,6 +137,18 @@ func (r *Receipt) statusEncoding() []byte {
 	return r.PostState
 }
 
+// Size returns the approximate memory used by all internal contents. It is used
+// to approximate and limit the memory consumption of various caches.
+func (r *Receipt) Size() common.StorageSize {
+	size := common.StorageSize(unsafe.Sizeof(*r)) + common.StorageSize(len(r.PostState))
+
+	size += common.StorageSize(len(r.Logs)) * common.StorageSize(unsafe.Sizeof(Log{}))
+	for _, log := range r.Logs {
+		size += common.StorageSize(len(log.Topics)*common.HashLength + len(log.Data))
+	}
+	return size
+}
+
 // String implements the Stringer interface.
 func (r *Receipt) String() string {
 	if len(r.PostState) == 0 {
diff --git a/core/types/transaction.go b/core/types/transaction.go
index a7ed211e4..5660582ba 100644
--- a/core/types/transaction.go
+++ b/core/types/transaction.go
@@ -206,6 +206,8 @@ func (tx *Transaction) Hash() common.Hash {
 	return v
 }
 
+// Size returns the true RLP encoded storage size of the transaction, either by
+// encoding and returning it, or returning a previsouly cached value.
 func (tx *Transaction) Size() common.StorageSize {
 	if size := tx.size.Load(); size != nil {
 		return size.(common.StorageSize)
diff --git a/eth/api.go b/eth/api.go
index 0db3eb554..a345b57e4 100644
--- a/eth/api.go
+++ b/eth/api.go
@@ -462,11 +462,11 @@ func (api *PrivateDebugAPI) getModifiedAccounts(startBlock, endBlock *types.Bloc
 		return nil, fmt.Errorf("start block height (%d) must be less than end block height (%d)", startBlock.Number().Uint64(), endBlock.Number().Uint64())
 	}
 
-	oldTrie, err := trie.NewSecure(startBlock.Root(), api.eth.chainDb, 0)
+	oldTrie, err := trie.NewSecure(startBlock.Root(), trie.NewDatabase(api.eth.chainDb), 0)
 	if err != nil {
 		return nil, err
 	}
-	newTrie, err := trie.NewSecure(endBlock.Root(), api.eth.chainDb, 0)
+	newTrie, err := trie.NewSecure(endBlock.Root(), trie.NewDatabase(api.eth.chainDb), 0)
 	if err != nil {
 		return nil, err
 	}
diff --git a/eth/api_tracer.go b/eth/api_tracer.go
index d49f077ae..07c4457bc 100644
--- a/eth/api_tracer.go
+++ b/eth/api_tracer.go
@@ -24,7 +24,6 @@ import (
 	"io/ioutil"
 	"runtime"
 	"sync"
-	"sync/atomic"
 	"time"
 
 	"github.com/ethereum/go-ethereum/common"
@@ -34,7 +33,6 @@ import (
 	"github.com/ethereum/go-ethereum/core/types"
 	"github.com/ethereum/go-ethereum/core/vm"
 	"github.com/ethereum/go-ethereum/eth/tracers"
-	"github.com/ethereum/go-ethereum/ethdb"
 	"github.com/ethereum/go-ethereum/internal/ethapi"
 	"github.com/ethereum/go-ethereum/log"
 	"github.com/ethereum/go-ethereum/rlp"
@@ -72,6 +70,7 @@ type txTraceResult struct {
 type blockTraceTask struct {
 	statedb *state.StateDB   // Intermediate state prepped for tracing
 	block   *types.Block     // Block to trace the transactions from
+	rootref common.Hash      // Trie root reference held for this task
 	results []*txTraceResult // Trace results procudes by the task
 }
 
@@ -90,59 +89,6 @@ type txTraceTask struct {
 	index   int            // Transaction offset in the block
 }
 
-// ephemeralDatabase is a memory wrapper around a proper database, which acts as
-// an ephemeral write layer. This construct is used by the chain tracer to write
-// state tries for intermediate blocks without serializing to disk, but at the
-// same time to allow disk fallback for reads that do no hit the memory layer.
-type ephemeralDatabase struct {
-	diskdb ethdb.Database     // Persistent disk database to fall back to with reads
-	memdb  *ethdb.MemDatabase // Ephemeral memory database for primary reads and writes
-}
-
-func (db *ephemeralDatabase) Put(key []byte, value []byte) error { return db.memdb.Put(key, value) }
-func (db *ephemeralDatabase) Delete(key []byte) error            { return errors.New("delete not supported") }
-func (db *ephemeralDatabase) Close()                             { db.memdb.Close() }
-func (db *ephemeralDatabase) NewBatch() ethdb.Batch {
-	return db.memdb.NewBatch()
-}
-func (db *ephemeralDatabase) Has(key []byte) (bool, error) {
-	if has, _ := db.memdb.Has(key); has {
-		return has, nil
-	}
-	return db.diskdb.Has(key)
-}
-func (db *ephemeralDatabase) Get(key []byte) ([]byte, error) {
-	if blob, _ := db.memdb.Get(key); blob != nil {
-		return blob, nil
-	}
-	return db.diskdb.Get(key)
-}
-
-// Prune does a state sync into a new memory write layer and replaces the old one.
-// This allows us to discard entries that are no longer referenced from the current
-// state.
-func (db *ephemeralDatabase) Prune(root common.Hash) {
-	// Pull the still relevant state data into memory
-	sync := state.NewStateSync(root, db.diskdb)
-	for sync.Pending() > 0 {
-		hash := sync.Missing(1)[0]
-
-		// Move the next trie node from the memory layer into a sync struct
-		node, err := db.memdb.Get(hash[:])
-		if err != nil {
-			panic(err) // memdb must have the data
-		}
-		if _, _, err := sync.Process([]trie.SyncResult{{Hash: hash, Data: node}}); err != nil {
-			panic(err) // it's not possible to fail processing a node
-		}
-	}
-	// Discard the old memory layer and write a new one
-	db.memdb, _ = ethdb.NewMemDatabaseWithCap(db.memdb.Len())
-	if _, err := sync.Commit(db); err != nil {
-		panic(err) // writing into a memdb cannot fail
-	}
-}
-
 // TraceChain returns the structured logs created during the execution of EVM
 // between two blocks (excluding start) and returns them as a JSON object.
 func (api *PrivateDebugAPI) TraceChain(ctx context.Context, start, end rpc.BlockNumber, config *TraceConfig) (*rpc.Subscription, error) {
@@ -188,19 +134,15 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
 
 	// Ensure we have a valid starting state before doing any work
 	origin := start.NumberU64()
+	database := state.NewDatabase(api.eth.ChainDb())
 
-	memdb, _ := ethdb.NewMemDatabase()
-	db := &ephemeralDatabase{
-		diskdb: api.eth.ChainDb(),
-		memdb:  memdb,
-	}
 	if number := start.NumberU64(); number > 0 {
 		start = api.eth.blockchain.GetBlock(start.ParentHash(), start.NumberU64()-1)
 		if start == nil {
 			return nil, fmt.Errorf("parent block #%d not found", number-1)
 		}
 	}
-	statedb, err := state.New(start.Root(), state.NewDatabase(db))
+	statedb, err := state.New(start.Root(), database)
 	if err != nil {
 		// If the starting state is missing, allow some number of blocks to be reexecuted
 		reexec := defaultTraceReexec
@@ -213,7 +155,7 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
 			if start == nil {
 				break
 			}
-			if statedb, err = state.New(start.Root(), state.NewDatabase(db)); err == nil {
+			if statedb, err = state.New(start.Root(), database); err == nil {
 				break
 			}
 		}
@@ -256,7 +198,7 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
 					res, err := api.traceTx(ctx, msg, vmctx, task.statedb, config)
 					if err != nil {
 						task.results[i] = &txTraceResult{Error: err.Error()}
-						log.Warn("Tracing failed", "err", err)
+						log.Warn("Tracing failed", "hash", tx.Hash(), "block", task.block.NumberU64(), "err", err)
 						break
 					}
 					task.statedb.DeleteSuicides()
@@ -273,7 +215,6 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
 	}
 	// Start a goroutine to feed all the blocks into the tracers
 	begin := time.Now()
-	complete := start.NumberU64()
 
 	go func() {
 		var (
@@ -281,6 +222,7 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
 			number uint64
 			traced uint64
 			failed error
+			proot  common.Hash
 		)
 		// Ensure everything is properly cleaned up on any exit path
 		defer func() {
@@ -308,7 +250,7 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
 			// Print progress logs if long enough time elapsed
 			if time.Since(logged) > 8*time.Second {
 				if number > origin {
-					log.Info("Tracing chain segment", "start", origin, "end", end.NumberU64(), "current", number, "transactions", traced, "elapsed", time.Since(begin))
+					log.Info("Tracing chain segment", "start", origin, "end", end.NumberU64(), "current", number, "transactions", traced, "elapsed", time.Since(begin), "memory", database.TrieDB().Size())
 				} else {
 					log.Info("Preparing state for chain trace", "block", number, "start", origin, "elapsed", time.Since(begin))
 				}
@@ -325,13 +267,11 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
 				txs := block.Transactions()
 
 				select {
-				case tasks <- &blockTraceTask{statedb: statedb.Copy(), block: block, results: make([]*txTraceResult, len(txs))}:
+				case tasks <- &blockTraceTask{statedb: statedb.Copy(), block: block, rootref: proot, results: make([]*txTraceResult, len(txs))}:
 				case <-notifier.Closed():
 					return
 				}
 				traced += uint64(len(txs))
-			} else {
-				atomic.StoreUint64(&complete, number)
 			}
 			// Generate the next state snapshot fast without tracing
 			_, _, _, err := api.eth.blockchain.Processor().Process(block, statedb, vm.Config{})
@@ -340,7 +280,7 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
 				break
 			}
 			// Finalize the state so any modifications are written to the trie
-			root, err := statedb.CommitTo(db, true)
+			root, err := statedb.Commit(true)
 			if err != nil {
 				failed = err
 				break
@@ -349,26 +289,14 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
 				failed = err
 				break
 			}
-			// After every N blocks, prune the database to only retain relevant data
-			if (number-start.NumberU64())%4096 == 0 {
-				// Wait until currently pending trace jobs finish
-				for atomic.LoadUint64(&complete) != number {
-					select {
-					case <-time.After(100 * time.Millisecond):
-					case <-notifier.Closed():
-						return
-					}
-				}
-				// No more concurrent access at this point, prune the database
-				var (
-					nodes = db.memdb.Len()
-					start = time.Now()
-				)
-				db.Prune(root)
-				log.Info("Pruned tracer state entries", "deleted", nodes-db.memdb.Len(), "left", db.memdb.Len(), "elapsed", time.Since(start))
-
-				statedb, _ = state.New(root, state.NewDatabase(db))
+			// Reference the trie twice, once for us, once for the trancer
+			database.TrieDB().Reference(root, common.Hash{})
+			if number >= origin {
+				database.TrieDB().Reference(root, common.Hash{})
 			}
+			// Dereference all past tries we ourselves are done working with
+			database.TrieDB().Dereference(proot, common.Hash{})
+			proot = root
 		}
 	}()
 
@@ -387,12 +315,14 @@ func (api *PrivateDebugAPI) traceChain(ctx context.Context, start, end *types.Bl
 			}
 			done[uint64(result.Block)] = result
 
+			// Dereference any paret tries held in memory by this task
+			database.TrieDB().Dereference(res.rootref, common.Hash{})
+
 			// Stream completed traces to the user, aborting on the first error
 			for result, ok := done[next]; ok; result, ok = done[next] {
 				if len(result.Traces) > 0 || next == end.NumberU64() {
 					notifier.Notify(sub.ID, result)
 				}
-				atomic.StoreUint64(&complete, next)
 				delete(done, next)
 				next++
 			}
@@ -544,18 +474,14 @@ func (api *PrivateDebugAPI) computeStateDB(block *types.Block, reexec uint64) (*
 	}
 	// Otherwise try to reexec blocks until we find a state or reach our limit
 	origin := block.NumberU64()
+	database := state.NewDatabase(api.eth.ChainDb())
 
-	memdb, _ := ethdb.NewMemDatabase()
-	db := &ephemeralDatabase{
-		diskdb: api.eth.ChainDb(),
-		memdb:  memdb,
-	}
 	for i := uint64(0); i < reexec; i++ {
 		block = api.eth.blockchain.GetBlock(block.ParentHash(), block.NumberU64()-1)
 		if block == nil {
 			break
 		}
-		if statedb, err = state.New(block.Root(), state.NewDatabase(db)); err == nil {
+		if statedb, err = state.New(block.Root(), database); err == nil {
 			break
 		}
 	}
@@ -571,6 +497,7 @@ func (api *PrivateDebugAPI) computeStateDB(block *types.Block, reexec uint64) (*
 	var (
 		start  = time.Now()
 		logged time.Time
+		proot  common.Hash
 	)
 	for block.NumberU64() < origin {
 		// Print progress logs if long enough time elapsed
@@ -587,26 +514,18 @@ func (api *PrivateDebugAPI) computeStateDB(block *types.Block, reexec uint64) (*
 			return nil, err
 		}
 		// Finalize the state so any modifications are written to the trie
-		root, err := statedb.CommitTo(db, true)
+		root, err := statedb.Commit(true)
 		if err != nil {
 			return nil, err
 		}
 		if err := statedb.Reset(root); err != nil {
 			return nil, err
 		}
-		// After every N blocks, prune the database to only retain relevant data
-		if block.NumberU64()%4096 == 0 || block.NumberU64() == origin {
-			var (
-				nodes = db.memdb.Len()
-				begin = time.Now()
-			)
-			db.Prune(root)
-			log.Info("Pruned tracer state entries", "deleted", nodes-db.memdb.Len(), "left", db.memdb.Len(), "elapsed", time.Since(begin))
-
-			statedb, _ = state.New(root, state.NewDatabase(db))
-		}
+		database.TrieDB().Reference(root, common.Hash{})
+		database.TrieDB().Dereference(proot, common.Hash{})
+		proot = root
 	}
-	log.Info("Historical state regenerated", "block", block.NumberU64(), "elapsed", time.Since(start))
+	log.Info("Historical state regenerated", "block", block.NumberU64(), "elapsed", time.Since(start), "size", database.TrieDB().Size())
 	return statedb, nil
 }
 
diff --git a/eth/backend.go b/eth/backend.go
index bcd724c0c..94aad2310 100644
--- a/eth/backend.go
+++ b/eth/backend.go
@@ -144,9 +144,11 @@ func New(ctx *node.ServiceContext, config *Config) (*Ethereum, error) {
 		}
 		core.WriteBlockChainVersion(chainDb, core.BlockChainVersion)
 	}
-
-	vmConfig := vm.Config{EnablePreimageRecording: config.EnablePreimageRecording}
-	eth.blockchain, err = core.NewBlockChain(chainDb, eth.chainConfig, eth.engine, vmConfig)
+	var (
+		vmConfig    = vm.Config{EnablePreimageRecording: config.EnablePreimageRecording}
+		cacheConfig = &core.CacheConfig{Disabled: config.NoPruning, TrieNodeLimit: config.TrieCache, TrieTimeLimit: config.TrieTimeout}
+	)
+	eth.blockchain, err = core.NewBlockChain(chainDb, cacheConfig, eth.chainConfig, eth.engine, vmConfig)
 	if err != nil {
 		return nil, err
 	}
diff --git a/eth/config.go b/eth/config.go
index 2158c71ba..dd7f42c7d 100644
--- a/eth/config.go
+++ b/eth/config.go
@@ -22,6 +22,7 @@ import (
 	"os/user"
 	"path/filepath"
 	"runtime"
+	"time"
 
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/common/hexutil"
@@ -44,7 +45,9 @@ var DefaultConfig = Config{
 	},
 	NetworkId:     1,
 	LightPeers:    100,
-	DatabaseCache: 128,
+	DatabaseCache: 768,
+	TrieCache:     256,
+	TrieTimeout:   5 * time.Minute,
 	GasPrice:      big.NewInt(18 * params.Shannon),
 
 	TxPool: core.DefaultTxPoolConfig,
@@ -78,6 +81,7 @@ type Config struct {
 	// Protocol options
 	NetworkId uint64 // Network ID to use for selecting peers to connect to
 	SyncMode  downloader.SyncMode
+	NoPruning bool
 
 	// Light client options
 	LightServ  int `toml:",omitempty"` // Maximum percentage of time allowed for serving LES requests
@@ -87,6 +91,8 @@ type Config struct {
 	SkipBcVersionCheck bool `toml:"-"`
 	DatabaseHandles    int  `toml:"-"`
 	DatabaseCache      int
+	TrieCache          int
+	TrieTimeout        time.Duration
 
 	// Mining-related options
 	Etherbase    common.Address `toml:",omitempty"`
diff --git a/eth/downloader/downloader.go b/eth/downloader/downloader.go
index 746c6a402..7f490d9e9 100644
--- a/eth/downloader/downloader.go
+++ b/eth/downloader/downloader.go
@@ -18,10 +18,8 @@
 package downloader
 
 import (
-	"crypto/rand"
 	"errors"
 	"fmt"
-	"math"
 	"math/big"
 	"sync"
 	"sync/atomic"
@@ -61,12 +59,11 @@ var (
 	maxHeadersProcess = 2048      // Number of header download results to import at once into the chain
 	maxResultsProcess = 2048      // Number of content download results to import at once into the chain
 
-	fsHeaderCheckFrequency = 100        // Verification frequency of the downloaded headers during fast sync
-	fsHeaderSafetyNet      = 2048       // Number of headers to discard in case a chain violation is detected
-	fsHeaderForceVerify    = 24         // Number of headers to verify before and after the pivot to accept it
-	fsPivotInterval        = 256        // Number of headers out of which to randomize the pivot point
-	fsMinFullBlocks        = 64         // Number of blocks to retrieve fully even in fast sync
-	fsCriticalTrials       = uint32(32) // Number of times to retry in the cricical section before bailing
+	fsHeaderCheckFrequency = 100             // Verification frequency of the downloaded headers during fast sync
+	fsHeaderSafetyNet      = 2048            // Number of headers to discard in case a chain violation is detected
+	fsHeaderForceVerify    = 24              // Number of headers to verify before and after the pivot to accept it
+	fsHeaderContCheck      = 3 * time.Second // Time interval to check for header continuations during state download
+	fsMinFullBlocks        = 64              // Number of blocks to retrieve fully even in fast sync
 )
 
 var (
@@ -102,9 +99,6 @@ type Downloader struct {
 	peers   *peerSet // Set of active peers from which download can proceed
 	stateDB ethdb.Database
 
-	fsPivotLock  *types.Header // Pivot header on critical section entry (cannot change between retries)
-	fsPivotFails uint32        // Number of subsequent fast sync failures in the critical section
-
 	rttEstimate   uint64 // Round trip time to target for download requests
 	rttConfidence uint64 // Confidence in the estimated RTT (unit: millionths to allow atomic ops)
 
@@ -124,6 +118,7 @@ type Downloader struct {
 	synchroniseMock func(id string, hash common.Hash) error // Replacement for synchronise during testing
 	synchronising   int32
 	notified        int32
+	committed       int32
 
 	// Channels
 	headerCh      chan dataPack        // [eth/62] Channel receiving inbound block headers
@@ -156,7 +151,7 @@ type Downloader struct {
 // LightChain encapsulates functions required to synchronise a light chain.
 type LightChain interface {
 	// HasHeader verifies a header's presence in the local chain.
-	HasHeader(h common.Hash, number uint64) bool
+	HasHeader(common.Hash, uint64) bool
 
 	// GetHeaderByHash retrieves a header from the local chain.
 	GetHeaderByHash(common.Hash) *types.Header
@@ -179,7 +174,7 @@ type BlockChain interface {
 	LightChain
 
 	// HasBlockAndState verifies block and associated states' presence in the local chain.
-	HasBlockAndState(common.Hash) bool
+	HasBlockAndState(common.Hash, uint64) bool
 
 	// GetBlockByHash retrieves a block from the local chain.
 	GetBlockByHash(common.Hash) *types.Block
@@ -391,9 +386,7 @@ func (d *Downloader) synchronise(id string, hash common.Hash, td *big.Int, mode
 
 	// Set the requested sync mode, unless it's forbidden
 	d.mode = mode
-	if d.mode == FastSync && atomic.LoadUint32(&d.fsPivotFails) >= fsCriticalTrials {
-		d.mode = FullSync
-	}
+
 	// Retrieve the origin peer and initiate the downloading process
 	p := d.peers.Peer(id)
 	if p == nil {
@@ -441,57 +434,40 @@ func (d *Downloader) syncWithPeer(p *peerConnection, hash common.Hash, td *big.I
 	d.syncStatsChainHeight = height
 	d.syncStatsLock.Unlock()
 
-	// Initiate the sync using a concurrent header and content retrieval algorithm
+	// Ensure our origin point is below any fast sync pivot point
 	pivot := uint64(0)
-	switch d.mode {
-	case LightSync:
-		pivot = height
-	case FastSync:
-		// Calculate the new fast/slow sync pivot point
-		if d.fsPivotLock == nil {
-			pivotOffset, err := rand.Int(rand.Reader, big.NewInt(int64(fsPivotInterval)))
-			if err != nil {
-				panic(fmt.Sprintf("Failed to access crypto random source: %v", err))
-			}
-			if height > uint64(fsMinFullBlocks)+pivotOffset.Uint64() {
-				pivot = height - uint64(fsMinFullBlocks) - pivotOffset.Uint64()
-			}
+	if d.mode == FastSync {
+		if height <= uint64(fsMinFullBlocks) {
+			origin = 0
 		} else {
-			// Pivot point locked in, use this and do not pick a new one!
-			pivot = d.fsPivotLock.Number.Uint64()
-		}
-		// If the point is below the origin, move origin back to ensure state download
-		if pivot < origin {
-			if pivot > 0 {
+			pivot = height - uint64(fsMinFullBlocks)
+			if pivot <= origin {
 				origin = pivot - 1
-			} else {
-				origin = 0
 			}
 		}
-		log.Debug("Fast syncing until pivot block", "pivot", pivot)
 	}
-	d.queue.Prepare(origin+1, d.mode, pivot, latest)
+	d.committed = 1
+	if d.mode == FastSync && pivot != 0 {
+		d.committed = 0
+	}
+	// Initiate the sync using a concurrent header and content retrieval algorithm
+	d.queue.Prepare(origin+1, d.mode)
 	if d.syncInitHook != nil {
 		d.syncInitHook(origin, height)
 	}
 
 	fetchers := []func() error{
-		func() error { return d.fetchHeaders(p, origin+1) }, // Headers are always retrieved
-		func() error { return d.fetchBodies(origin + 1) },   // Bodies are retrieved during normal and fast sync
-		func() error { return d.fetchReceipts(origin + 1) }, // Receipts are retrieved during fast sync
-		func() error { return d.processHeaders(origin+1, td) },
+		func() error { return d.fetchHeaders(p, origin+1, pivot) }, // Headers are always retrieved
+		func() error { return d.fetchBodies(origin + 1) },          // Bodies are retrieved during normal and fast sync
+		func() error { return d.fetchReceipts(origin + 1) },        // Receipts are retrieved during fast sync
+		func() error { return d.processHeaders(origin+1, pivot, td) },
 	}
 	if d.mode == FastSync {
 		fetchers = append(fetchers, func() error { return d.processFastSyncContent(latest) })
 	} else if d.mode == FullSync {
 		fetchers = append(fetchers, d.processFullSyncContent)
 	}
-	err = d.spawnSync(fetchers)
-	if err != nil && d.mode == FastSync && d.fsPivotLock != nil {
-		// If sync failed in the critical section, bump the fail counter.
-		atomic.AddUint32(&d.fsPivotFails, 1)
-	}
-	return err
+	return d.spawnSync(fetchers)
 }
 
 // spawnSync runs d.process and all given fetcher functions to completion in
@@ -671,7 +647,7 @@ func (d *Downloader) findAncestor(p *peerConnection, height uint64) (uint64, err
 					continue
 				}
 				// Otherwise check if we already know the header or not
-				if (d.mode == FullSync && d.blockchain.HasBlockAndState(headers[i].Hash())) || (d.mode != FullSync && d.lightchain.HasHeader(headers[i].Hash(), headers[i].Number.Uint64())) {
+				if (d.mode == FullSync && d.blockchain.HasBlockAndState(headers[i].Hash(), headers[i].Number.Uint64())) || (d.mode != FullSync && d.lightchain.HasHeader(headers[i].Hash(), headers[i].Number.Uint64())) {
 					number, hash = headers[i].Number.Uint64(), headers[i].Hash()
 
 					// If every header is known, even future ones, the peer straight out lied about its head
@@ -736,7 +712,7 @@ func (d *Downloader) findAncestor(p *peerConnection, height uint64) (uint64, err
 				arrived = true
 
 				// Modify the search interval based on the response
-				if (d.mode == FullSync && !d.blockchain.HasBlockAndState(headers[0].Hash())) || (d.mode != FullSync && !d.lightchain.HasHeader(headers[0].Hash(), headers[0].Number.Uint64())) {
+				if (d.mode == FullSync && !d.blockchain.HasBlockAndState(headers[0].Hash(), headers[0].Number.Uint64())) || (d.mode != FullSync && !d.lightchain.HasHeader(headers[0].Hash(), headers[0].Number.Uint64())) {
 					end = check
 					break
 				}
@@ -774,7 +750,7 @@ func (d *Downloader) findAncestor(p *peerConnection, height uint64) (uint64, err
 // other peers are only accepted if they map cleanly to the skeleton. If no one
 // can fill in the skeleton - not even the origin peer - it's assumed invalid and
 // the origin is dropped.
-func (d *Downloader) fetchHeaders(p *peerConnection, from uint64) error {
+func (d *Downloader) fetchHeaders(p *peerConnection, from uint64, pivot uint64) error {
 	p.log.Debug("Directing header downloads", "origin", from)
 	defer p.log.Debug("Header download terminated")
 
@@ -825,6 +801,18 @@ func (d *Downloader) fetchHeaders(p *peerConnection, from uint64) error {
 			}
 			// If no more headers are inbound, notify the content fetchers and return
 			if packet.Items() == 0 {
+				// Don't abort header fetches while the pivot is downloading
+				if atomic.LoadInt32(&d.committed) == 0 && pivot <= from {
+					p.log.Debug("No headers, waiting for pivot commit")
+					select {
+					case <-time.After(fsHeaderContCheck):
+						getHeaders(from)
+						continue
+					case <-d.cancelCh:
+						return errCancelHeaderFetch
+					}
+				}
+				// Pivot done (or not in fast sync) and no more headers, terminate the process
 				p.log.Debug("No more headers available")
 				select {
 				case d.headerProcCh <- nil:
@@ -1129,10 +1117,8 @@ func (d *Downloader) fetchParts(errCancel error, deliveryCh chan dataPack, deliv
 				}
 				if request.From > 0 {
 					peer.log.Trace("Requesting new batch of data", "type", kind, "from", request.From)
-				} else if len(request.Headers) > 0 {
-					peer.log.Trace("Requesting new batch of data", "type", kind, "count", len(request.Headers), "from", request.Headers[0].Number)
 				} else {
-					peer.log.Trace("Requesting new batch of data", "type", kind, "count", len(request.Hashes))
+					peer.log.Trace("Requesting new batch of data", "type", kind, "count", len(request.Headers), "from", request.Headers[0].Number)
 				}
 				// Fetch the chunk and make sure any errors return the hashes to the queue
 				if fetchHook != nil {
@@ -1160,10 +1146,7 @@ func (d *Downloader) fetchParts(errCancel error, deliveryCh chan dataPack, deliv
 // processHeaders takes batches of retrieved headers from an input channel and
 // keeps processing and scheduling them into the header chain and downloader's
 // queue until the stream ends or a failure occurs.
-func (d *Downloader) processHeaders(origin uint64, td *big.Int) error {
-	// Calculate the pivoting point for switching from fast to slow sync
-	pivot := d.queue.FastSyncPivot()
-
+func (d *Downloader) processHeaders(origin uint64, pivot uint64, td *big.Int) error {
 	// Keep a count of uncertain headers to roll back
 	rollback := []*types.Header{}
 	defer func() {
@@ -1188,19 +1171,6 @@ func (d *Downloader) processHeaders(origin uint64, td *big.Int) error {
 				"header", fmt.Sprintf("%d->%d", lastHeader, d.lightchain.CurrentHeader().Number),
 				"fast", fmt.Sprintf("%d->%d", lastFastBlock, curFastBlock),
 				"block", fmt.Sprintf("%d->%d", lastBlock, curBlock))
-
-			// If we're already past the pivot point, this could be an attack, thread carefully
-			if rollback[len(rollback)-1].Number.Uint64() > pivot {
-				// If we didn't ever fail, lock in the pivot header (must! not! change!)
-				if atomic.LoadUint32(&d.fsPivotFails) == 0 {
-					for _, header := range rollback {
-						if header.Number.Uint64() == pivot {
-							log.Warn("Fast-sync pivot locked in", "number", pivot, "hash", header.Hash())
-							d.fsPivotLock = header
-						}
-					}
-				}
-			}
 		}
 	}()
 
@@ -1302,13 +1272,6 @@ func (d *Downloader) processHeaders(origin uint64, td *big.Int) error {
 						rollback = append(rollback[:0], rollback[len(rollback)-fsHeaderSafetyNet:]...)
 					}
 				}
-				// If we're fast syncing and just pulled in the pivot, make sure it's the one locked in
-				if d.mode == FastSync && d.fsPivotLock != nil && chunk[0].Number.Uint64() <= pivot && chunk[len(chunk)-1].Number.Uint64() >= pivot {
-					if pivot := chunk[int(pivot-chunk[0].Number.Uint64())]; pivot.Hash() != d.fsPivotLock.Hash() {
-						log.Warn("Pivot doesn't match locked in one", "remoteNumber", pivot.Number, "remoteHash", pivot.Hash(), "localNumber", d.fsPivotLock.Number, "localHash", d.fsPivotLock.Hash())
-						return errInvalidChain
-					}
-				}
 				// Unless we're doing light chains, schedule the headers for associated content retrieval
 				if d.mode == FullSync || d.mode == FastSync {
 					// If we've reached the allowed number of pending headers, stall a bit
@@ -1343,7 +1306,7 @@ func (d *Downloader) processHeaders(origin uint64, td *big.Int) error {
 // processFullSyncContent takes fetch results from the queue and imports them into the chain.
 func (d *Downloader) processFullSyncContent() error {
 	for {
-		results := d.queue.WaitResults()
+		results := d.queue.Results(true)
 		if len(results) == 0 {
 			return nil
 		}
@@ -1357,30 +1320,28 @@ func (d *Downloader) processFullSyncContent() error {
 }
 
 func (d *Downloader) importBlockResults(results []*fetchResult) error {
-	for len(results) != 0 {
-		// Check for any termination requests. This makes clean shutdown faster.
-		select {
-		case <-d.quitCh:
-			return errCancelContentProcessing
-		default:
-		}
-		// Retrieve the a batch of results to import
-		items := int(math.Min(float64(len(results)), float64(maxResultsProcess)))
-		first, last := results[0].Header, results[items-1].Header
-		log.Debug("Inserting downloaded chain", "items", len(results),
-			"firstnum", first.Number, "firsthash", first.Hash(),
-			"lastnum", last.Number, "lasthash", last.Hash(),
-		)
-		blocks := make([]*types.Block, items)
-		for i, result := range results[:items] {
-			blocks[i] = types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles)
-		}
-		if index, err := d.blockchain.InsertChain(blocks); err != nil {
-			log.Debug("Downloaded item processing failed", "number", results[index].Header.Number, "hash", results[index].Header.Hash(), "err", err)
-			return errInvalidChain
-		}
-		// Shift the results to the next batch
-		results = results[items:]
+	// Check for any early termination requests
+	if len(results) == 0 {
+		return nil
+	}
+	select {
+	case <-d.quitCh:
+		return errCancelContentProcessing
+	default:
+	}
+	// Retrieve the a batch of results to import
+	first, last := results[0].Header, results[len(results)-1].Header
+	log.Debug("Inserting downloaded chain", "items", len(results),
+		"firstnum", first.Number, "firsthash", first.Hash(),
+		"lastnum", last.Number, "lasthash", last.Hash(),
+	)
+	blocks := make([]*types.Block, len(results))
+	for i, result := range results {
+		blocks[i] = types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles)
+	}
+	if index, err := d.blockchain.InsertChain(blocks); err != nil {
+		log.Debug("Downloaded item processing failed", "number", results[index].Header.Number, "hash", results[index].Header.Hash(), "err", err)
+		return errInvalidChain
 	}
 	return nil
 }
@@ -1388,35 +1349,92 @@ func (d *Downloader) importBlockResults(results []*fetchResult) error {
 // processFastSyncContent takes fetch results from the queue and writes them to the
 // database. It also controls the synchronisation of state nodes of the pivot block.
 func (d *Downloader) processFastSyncContent(latest *types.Header) error {
-	// Start syncing state of the reported head block.
-	// This should get us most of the state of the pivot block.
+	// Start syncing state of the reported head block. This should get us most of
+	// the state of the pivot block.
 	stateSync := d.syncState(latest.Root)
 	defer stateSync.Cancel()
 	go func() {
-		if err := stateSync.Wait(); err != nil {
+		if err := stateSync.Wait(); err != nil && err != errCancelStateFetch {
 			d.queue.Close() // wake up WaitResults
 		}
 	}()
-
-	pivot := d.queue.FastSyncPivot()
+	// Figure out the ideal pivot block. Note, that this goalpost may move if the
+	// sync takes long enough for the chain head to move significantly.
+	pivot := uint64(0)
+	if height := latest.Number.Uint64(); height > uint64(fsMinFullBlocks) {
+		pivot = height - uint64(fsMinFullBlocks)
+	}
+	// To cater for moving pivot points, track the pivot block and subsequently
+	// accumulated download results separatey.
+	var (
+		oldPivot *fetchResult   // Locked in pivot block, might change eventually
+		oldTail  []*fetchResult // Downloaded content after the pivot
+	)
 	for {
-		results := d.queue.WaitResults()
+		// Wait for the next batch of downloaded data to be available, and if the pivot
+		// block became stale, move the goalpost
+		results := d.queue.Results(oldPivot == nil) // Block if we're not monitoring pivot staleness
 		if len(results) == 0 {
-			return stateSync.Cancel()
+			// If pivot sync is done, stop
+			if oldPivot == nil {
+				return stateSync.Cancel()
+			}
+			// If sync failed, stop
+			select {
+			case <-d.cancelCh:
+				return stateSync.Cancel()
+			default:
+			}
 		}
 		if d.chainInsertHook != nil {
 			d.chainInsertHook(results)
 		}
+		if oldPivot != nil {
+			results = append(append([]*fetchResult{oldPivot}, oldTail...), results...)
+		}
+		// Split around the pivot block and process the two sides via fast/full sync
+		if atomic.LoadInt32(&d.committed) == 0 {
+			latest = results[len(results)-1].Header
+			if height := latest.Number.Uint64(); height > pivot+2*uint64(fsMinFullBlocks) {
+				log.Warn("Pivot became stale, moving", "old", pivot, "new", height-uint64(fsMinFullBlocks))
+				pivot = height - uint64(fsMinFullBlocks)
+			}
+		}
 		P, beforeP, afterP := splitAroundPivot(pivot, results)
 		if err := d.commitFastSyncData(beforeP, stateSync); err != nil {
 			return err
 		}
 		if P != nil {
-			stateSync.Cancel()
-			if err := d.commitPivotBlock(P); err != nil {
-				return err
+			// If new pivot block found, cancel old state retrieval and restart
+			if oldPivot != P {
+				stateSync.Cancel()
+
+				stateSync = d.syncState(P.Header.Root)
+				defer stateSync.Cancel()
+				go func() {
+					if err := stateSync.Wait(); err != nil && err != errCancelStateFetch {
+						d.queue.Close() // wake up WaitResults
+					}
+				}()
+				oldPivot = P
+			}
+			// Wait for completion, occasionally checking for pivot staleness
+			select {
+			case <-stateSync.done:
+				if stateSync.err != nil {
+					return stateSync.err
+				}
+				if err := d.commitPivotBlock(P); err != nil {
+					return err
+				}
+				oldPivot = nil
+
+			case <-time.After(time.Second):
+				oldTail = afterP
+				continue
 			}
 		}
+		// Fast sync done, pivot commit done, full import
 		if err := d.importBlockResults(afterP); err != nil {
 			return err
 		}
@@ -1439,52 +1457,49 @@ func splitAroundPivot(pivot uint64, results []*fetchResult) (p *fetchResult, bef
 }
 
 func (d *Downloader) commitFastSyncData(results []*fetchResult, stateSync *stateSync) error {
-	for len(results) != 0 {
-		// Check for any termination requests.
-		select {
-		case <-d.quitCh:
-			return errCancelContentProcessing
-		case <-stateSync.done:
-			if err := stateSync.Wait(); err != nil {
-				return err
-			}
-		default:
-		}
-		// Retrieve the a batch of results to import
-		items := int(math.Min(float64(len(results)), float64(maxResultsProcess)))
-		first, last := results[0].Header, results[items-1].Header
-		log.Debug("Inserting fast-sync blocks", "items", len(results),
-			"firstnum", first.Number, "firsthash", first.Hash(),
-			"lastnumn", last.Number, "lasthash", last.Hash(),
-		)
-		blocks := make([]*types.Block, items)
-		receipts := make([]types.Receipts, items)
-		for i, result := range results[:items] {
-			blocks[i] = types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles)
-			receipts[i] = result.Receipts
-		}
-		if index, err := d.blockchain.InsertReceiptChain(blocks, receipts); err != nil {
-			log.Debug("Downloaded item processing failed", "number", results[index].Header.Number, "hash", results[index].Header.Hash(), "err", err)
-			return errInvalidChain
+	// Check for any early termination requests
+	if len(results) == 0 {
+		return nil
+	}
+	select {
+	case <-d.quitCh:
+		return errCancelContentProcessing
+	case <-stateSync.done:
+		if err := stateSync.Wait(); err != nil {
+			return err
 		}
-		// Shift the results to the next batch
-		results = results[items:]
+	default:
+	}
+	// Retrieve the a batch of results to import
+	first, last := results[0].Header, results[len(results)-1].Header
+	log.Debug("Inserting fast-sync blocks", "items", len(results),
+		"firstnum", first.Number, "firsthash", first.Hash(),
+		"lastnumn", last.Number, "lasthash", last.Hash(),
+	)
+	blocks := make([]*types.Block, len(results))
+	receipts := make([]types.Receipts, len(results))
+	for i, result := range results {
+		blocks[i] = types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles)
+		receipts[i] = result.Receipts
+	}
+	if index, err := d.blockchain.InsertReceiptChain(blocks, receipts); err != nil {
+		log.Debug("Downloaded item processing failed", "number", results[index].Header.Number, "hash", results[index].Header.Hash(), "err", err)
+		return errInvalidChain
 	}
 	return nil
 }
 
 func (d *Downloader) commitPivotBlock(result *fetchResult) error {
-	b := types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles)
-	// Sync the pivot block state. This should complete reasonably quickly because
-	// we've already synced up to the reported head block state earlier.
-	if err := d.syncState(b.Root()).Wait(); err != nil {
+	block := types.NewBlockWithHeader(result.Header).WithBody(result.Transactions, result.Uncles)
+	log.Debug("Committing fast sync pivot as new head", "number", block.Number(), "hash", block.Hash())
+	if _, err := d.blockchain.InsertReceiptChain([]*types.Block{block}, []types.Receipts{result.Receipts}); err != nil {
 		return err
 	}
-	log.Debug("Committing fast sync pivot as new head", "number", b.Number(), "hash", b.Hash())
-	if _, err := d.blockchain.InsertReceiptChain([]*types.Block{b}, []types.Receipts{result.Receipts}); err != nil {
+	if err := d.blockchain.FastSyncCommitHead(block.Hash()); err != nil {
 		return err
 	}
-	return d.blockchain.FastSyncCommitHead(b.Hash())
+	atomic.StoreInt32(&d.committed, 1)
+	return nil
 }
 
 // DeliverHeaders injects a new batch of block headers received from a remote
diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go
index e9c7b6170..d94d55f11 100644
--- a/eth/downloader/downloader_test.go
+++ b/eth/downloader/downloader_test.go
@@ -28,7 +28,6 @@ import (
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/consensus/ethash"
 	"github.com/ethereum/go-ethereum/core"
-	"github.com/ethereum/go-ethereum/core/state"
 	"github.com/ethereum/go-ethereum/core/types"
 	"github.com/ethereum/go-ethereum/crypto"
 	"github.com/ethereum/go-ethereum/ethdb"
@@ -45,8 +44,8 @@ var (
 // Reduce some of the parameters to make the tester faster.
 func init() {
 	MaxForkAncestry = uint64(10000)
-	blockCacheLimit = 1024
-	fsCriticalTrials = 10
+	blockCacheItems = 1024
+	fsHeaderContCheck = 500 * time.Millisecond
 }
 
 // downloadTester is a test simulator for mocking out local block chain.
@@ -223,7 +222,7 @@ func (dl *downloadTester) HasHeader(hash common.Hash, number uint64) bool {
 }
 
 // HasBlockAndState checks if a block and associated state is present in the testers canonical chain.
-func (dl *downloadTester) HasBlockAndState(hash common.Hash) bool {
+func (dl *downloadTester) HasBlockAndState(hash common.Hash, number uint64) bool {
 	block := dl.GetBlockByHash(hash)
 	if block == nil {
 		return false
@@ -293,7 +292,7 @@ func (dl *downloadTester) CurrentFastBlock() *types.Block {
 func (dl *downloadTester) FastSyncCommitHead(hash common.Hash) error {
 	// For now only check that the state trie is correct
 	if block := dl.GetBlockByHash(hash); block != nil {
-		_, err := trie.NewSecure(block.Root(), dl.stateDb, 0)
+		_, err := trie.NewSecure(block.Root(), trie.NewDatabase(dl.stateDb), 0)
 		return err
 	}
 	return fmt.Errorf("non existent block: %x", hash[:4])
@@ -619,28 +618,22 @@ func assertOwnChain(t *testing.T, tester *downloadTester, length int) {
 // number of items of the various chain components.
 func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, lengths []int) {
 	// Initialize the counters for the first fork
-	headers, blocks := lengths[0], lengths[0]
+	headers, blocks, receipts := lengths[0], lengths[0], lengths[0]-fsMinFullBlocks
 
-	minReceipts, maxReceipts := lengths[0]-fsMinFullBlocks-fsPivotInterval, lengths[0]-fsMinFullBlocks
-	if minReceipts < 0 {
-		minReceipts = 1
-	}
-	if maxReceipts < 0 {
-		maxReceipts = 1
+	if receipts < 0 {
+		receipts = 1
 	}
 	// Update the counters for each subsequent fork
 	for _, length := range lengths[1:] {
 		headers += length - common
 		blocks += length - common
-
-		minReceipts += length - common - fsMinFullBlocks - fsPivotInterval
-		maxReceipts += length - common - fsMinFullBlocks
+		receipts += length - common - fsMinFullBlocks
 	}
 	switch tester.downloader.mode {
 	case FullSync:
-		minReceipts, maxReceipts = 1, 1
+		receipts = 1
 	case LightSync:
-		blocks, minReceipts, maxReceipts = 1, 1, 1
+		blocks, receipts = 1, 1
 	}
 	if hs := len(tester.ownHeaders); hs != headers {
 		t.Fatalf("synchronised headers mismatch: have %v, want %v", hs, headers)
@@ -648,11 +641,12 @@ func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, leng
 	if bs := len(tester.ownBlocks); bs != blocks {
 		t.Fatalf("synchronised blocks mismatch: have %v, want %v", bs, blocks)
 	}
-	if rs := len(tester.ownReceipts); rs < minReceipts || rs > maxReceipts {
-		t.Fatalf("synchronised receipts mismatch: have %v, want between [%v, %v]", rs, minReceipts, maxReceipts)
+	if rs := len(tester.ownReceipts); rs != receipts {
+		t.Fatalf("synchronised receipts mismatch: have %v, want %v", rs, receipts)
 	}
 	// Verify the state trie too for fast syncs
-	if tester.downloader.mode == FastSync {
+	/*if tester.downloader.mode == FastSync {
+		pivot := uint64(0)
 		var index int
 		if pivot := int(tester.downloader.queue.fastSyncPivot); pivot < common {
 			index = pivot
@@ -660,11 +654,11 @@ func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, leng
 			index = len(tester.ownHashes) - lengths[len(lengths)-1] + int(tester.downloader.queue.fastSyncPivot)
 		}
 		if index > 0 {
-			if statedb, err := state.New(tester.ownHeaders[tester.ownHashes[index]].Root, state.NewDatabase(tester.stateDb)); statedb == nil || err != nil {
+			if statedb, err := state.New(tester.ownHeaders[tester.ownHashes[index]].Root, state.NewDatabase(trie.NewDatabase(tester.stateDb))); statedb == nil || err != nil {
 				t.Fatalf("state reconstruction failed: %v", err)
 			}
 		}
-	}
+	}*/
 }
 
 // Tests that simple synchronization against a canonical chain works correctly.
@@ -684,7 +678,7 @@ func testCanonicalSynchronisation(t *testing.T, protocol int, mode SyncMode) {
 	defer tester.terminate()
 
 	// Create a small enough block chain to download
-	targetBlocks := blockCacheLimit - 15
+	targetBlocks := blockCacheItems - 15
 	hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
 
 	tester.newPeer("peer", protocol, hashes, headers, blocks, receipts)
@@ -710,7 +704,7 @@ func testThrottling(t *testing.T, protocol int, mode SyncMode) {
 	defer tester.terminate()
 
 	// Create a long block chain to download and the tester
-	targetBlocks := 8 * blockCacheLimit
+	targetBlocks := 8 * blockCacheItems
 	hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
 
 	tester.newPeer("peer", protocol, hashes, headers, blocks, receipts)
@@ -745,9 +739,9 @@ func testThrottling(t *testing.T, protocol int, mode SyncMode) {
 			cached = len(tester.downloader.queue.blockDonePool)
 			if mode == FastSync {
 				if receipts := len(tester.downloader.queue.receiptDonePool); receipts < cached {
-					if tester.downloader.queue.resultCache[receipts].Header.Number.Uint64() < tester.downloader.queue.fastSyncPivot {
-						cached = receipts
-					}
+					//if tester.downloader.queue.resultCache[receipts].Header.Number.Uint64() < tester.downloader.queue.fastSyncPivot {
+					cached = receipts
+					//}
 				}
 			}
 			frozen = int(atomic.LoadUint32(&blocked))
@@ -755,7 +749,7 @@ func testThrottling(t *testing.T, protocol int, mode SyncMode) {
 			tester.downloader.queue.lock.Unlock()
 			tester.lock.Unlock()
 
-			if cached == blockCacheLimit || retrieved+cached+frozen == targetBlocks+1 {
+			if cached == blockCacheItems || retrieved+cached+frozen == targetBlocks+1 {
 				break
 			}
 		}
@@ -765,8 +759,8 @@ func testThrottling(t *testing.T, protocol int, mode SyncMode) {
 		tester.lock.RLock()
 		retrieved = len(tester.ownBlocks)
 		tester.lock.RUnlock()
-		if cached != blockCacheLimit && retrieved+cached+frozen != targetBlocks+1 {
-			t.Fatalf("block count mismatch: have %v, want %v (owned %v, blocked %v, target %v)", cached, blockCacheLimit, retrieved, frozen, targetBlocks+1)
+		if cached != blockCacheItems && retrieved+cached+frozen != targetBlocks+1 {
+			t.Fatalf("block count mismatch: have %v, want %v (owned %v, blocked %v, target %v)", cached, blockCacheItems, retrieved, frozen, targetBlocks+1)
 		}
 		// Permit the blocked blocks to import
 		if atomic.LoadUint32(&blocked) > 0 {
@@ -974,7 +968,7 @@ func testCancel(t *testing.T, protocol int, mode SyncMode) {
 	defer tester.terminate()
 
 	// Create a small enough block chain to download and the tester
-	targetBlocks := blockCacheLimit - 15
+	targetBlocks := blockCacheItems - 15
 	if targetBlocks >= MaxHashFetch {
 		targetBlocks = MaxHashFetch - 15
 	}
@@ -1016,12 +1010,12 @@ func testMultiSynchronisation(t *testing.T, protocol int, mode SyncMode) {
 
 	// Create various peers with various parts of the chain
 	targetPeers := 8
-	targetBlocks := targetPeers*blockCacheLimit - 15
+	targetBlocks := targetPeers*blockCacheItems - 15
 	hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
 
 	for i := 0; i < targetPeers; i++ {
 		id := fmt.Sprintf("peer #%d", i)
-		tester.newPeer(id, protocol, hashes[i*blockCacheLimit:], headers, blocks, receipts)
+		tester.newPeer(id, protocol, hashes[i*blockCacheItems:], headers, blocks, receipts)
 	}
 	if err := tester.sync("peer #0", nil, mode); err != nil {
 		t.Fatalf("failed to synchronise blocks: %v", err)
@@ -1045,7 +1039,7 @@ func testMultiProtoSync(t *testing.T, protocol int, mode SyncMode) {
 	defer tester.terminate()
 
 	// Create a small enough block chain to download
-	targetBlocks := blockCacheLimit - 15
+	targetBlocks := blockCacheItems - 15
 	hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
 
 	// Create peers of every type
@@ -1084,7 +1078,7 @@ func testEmptyShortCircuit(t *testing.T, protocol int, mode SyncMode) {
 	defer tester.terminate()
 
 	// Create a block chain to download
-	targetBlocks := 2*blockCacheLimit - 15
+	targetBlocks := 2*blockCacheItems - 15
 	hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
 
 	tester.newPeer("peer", protocol, hashes, headers, blocks, receipts)
@@ -1110,8 +1104,8 @@ func testEmptyShortCircuit(t *testing.T, protocol int, mode SyncMode) {
 			bodiesNeeded++
 		}
 	}
-	for hash, receipt := range receipts {
-		if mode == FastSync && len(receipt) > 0 && headers[hash].Number.Uint64() <= tester.downloader.queue.fastSyncPivot {
+	for _, receipt := range receipts {
+		if mode == FastSync && len(receipt) > 0 {
 			receiptsNeeded++
 		}
 	}
@@ -1139,7 +1133,7 @@ func testMissingHeaderAttack(t *testing.T, protocol int, mode SyncMode) {
 	defer tester.terminate()
 
 	// Create a small enough block chain to download
-	targetBlocks := blockCacheLimit - 15
+	targetBlocks := blockCacheItems - 15
 	hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
 
 	// Attempt a full sync with an attacker feeding gapped headers
@@ -1174,7 +1168,7 @@ func testShiftedHeaderAttack(t *testing.T, protocol int, mode SyncMode) {
 	defer tester.terminate()
 
 	// Create a small enough block chain to download
-	targetBlocks := blockCacheLimit - 15
+	targetBlocks := blockCacheItems - 15
 	hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
 
 	// Attempt a full sync with an attacker feeding shifted headers
@@ -1208,7 +1202,7 @@ func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) {
 	defer tester.terminate()
 
 	// Create a small enough block chain to download
-	targetBlocks := 3*fsHeaderSafetyNet + fsPivotInterval + fsMinFullBlocks
+	targetBlocks := 3*fsHeaderSafetyNet + 256 + fsMinFullBlocks
 	hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
 
 	// Attempt to sync with an attacker that feeds junk during the fast sync phase.
@@ -1248,7 +1242,6 @@ func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) {
 	tester.newPeer("withhold-attack", protocol, hashes, headers, blocks, receipts)
 	missing = 3*fsHeaderSafetyNet + MaxHeaderFetch + 1
 
-	tester.downloader.fsPivotFails = 0
 	tester.downloader.syncInitHook = func(uint64, uint64) {
 		for i := missing; i <= len(hashes); i++ {
 			delete(tester.peerHeaders["withhold-attack"], hashes[len(hashes)-i])
@@ -1267,8 +1260,6 @@ func testInvalidHeaderRollback(t *testing.T, protocol int, mode SyncMode) {
 			t.Errorf("fast sync pivot block #%d not rolled back", head)
 		}
 	}
-	tester.downloader.fsPivotFails = fsCriticalTrials
-
 	// Synchronise with the valid peer and make sure sync succeeds. Since the last
 	// rollback should also disable fast syncing for this process, verify that we
 	// did a fresh full sync. Note, we can't assert anything about the receipts
@@ -1383,7 +1374,7 @@ func testSyncProgress(t *testing.T, protocol int, mode SyncMode) {
 	defer tester.terminate()
 
 	// Create a small enough block chain to download
-	targetBlocks := blockCacheLimit - 15
+	targetBlocks := blockCacheItems - 15
 	hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
 
 	// Set a sync init hook to catch progress changes
@@ -1532,7 +1523,7 @@ func testFailedSyncProgress(t *testing.T, protocol int, mode SyncMode) {
 	defer tester.terminate()
 
 	// Create a small enough block chain to download
-	targetBlocks := blockCacheLimit - 15
+	targetBlocks := blockCacheItems - 15
 	hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
 
 	// Set a sync init hook to catch progress changes
@@ -1609,7 +1600,7 @@ func testFakedSyncProgress(t *testing.T, protocol int, mode SyncMode) {
 	defer tester.terminate()
 
 	// Create a small block chain
-	targetBlocks := blockCacheLimit - 15
+	targetBlocks := blockCacheItems - 15
 	hashes, headers, blocks, receipts := tester.makeChain(targetBlocks+3, 0, tester.genesis, nil, false)
 
 	// Set a sync init hook to catch progress changes
@@ -1697,6 +1688,7 @@ func TestDeliverHeadersHang(t *testing.T) {
 type floodingTestPeer struct {
 	peer   Peer
 	tester *downloadTester
+	pend   sync.WaitGroup
 }
 
 func (ftp *floodingTestPeer) Head() (common.Hash, *big.Int) { return ftp.peer.Head() }
@@ -1717,9 +1709,12 @@ func (ftp *floodingTestPeer) RequestHeadersByNumber(from uint64, count, skip int
 	deliveriesDone := make(chan struct{}, 500)
 	for i := 0; i < cap(deliveriesDone); i++ {
 		peer := fmt.Sprintf("fake-peer%d", i)
+		ftp.pend.Add(1)
+
 		go func() {
 			ftp.tester.downloader.DeliverHeaders(peer, []*types.Header{{}, {}, {}, {}})
 			deliveriesDone <- struct{}{}
+			ftp.pend.Done()
 		}()
 	}
 	// Deliver the actual requested headers.
@@ -1751,110 +1746,15 @@ func testDeliverHeadersHang(t *testing.T, protocol int, mode SyncMode) {
 		// Whenever the downloader requests headers, flood it with
 		// a lot of unrequested header deliveries.
 		tester.downloader.peers.peers["peer"].peer = &floodingTestPeer{
-			tester.downloader.peers.peers["peer"].peer,
-			tester,
+			peer:   tester.downloader.peers.peers["peer"].peer,
+			tester: tester,
 		}
 		if err := tester.sync("peer", nil, mode); err != nil {
-			t.Errorf("sync failed: %v", err)
+			t.Errorf("test %d: sync failed: %v", i, err)
 		}
 		tester.terminate()
-	}
-}
-
-// Tests that if fast sync aborts in the critical section, it can restart a few
-// times before giving up.
-// We use data driven subtests to manage this so that it will be parallel on its own
-// and not with the other tests, avoiding intermittent failures.
-func TestFastCriticalRestarts(t *testing.T) {
-	testCases := []struct {
-		protocol int
-		progress bool
-	}{
-		{63, false},
-		{64, false},
-		{63, true},
-		{64, true},
-	}
-	for _, tc := range testCases {
-		t.Run(fmt.Sprintf("protocol %d progress %v", tc.protocol, tc.progress), func(t *testing.T) {
-			testFastCriticalRestarts(t, tc.protocol, tc.progress)
-		})
-	}
-}
-
-func testFastCriticalRestarts(t *testing.T, protocol int, progress bool) {
-	t.Parallel()
-
-	tester := newTester()
-	defer tester.terminate()
-
-	// Create a large enough blockchin to actually fast sync on
-	targetBlocks := fsMinFullBlocks + 2*fsPivotInterval - 15
-	hashes, headers, blocks, receipts := tester.makeChain(targetBlocks, 0, tester.genesis, nil, false)
-
-	// Create a tester peer with a critical section header missing (force failures)
-	tester.newPeer("peer", protocol, hashes, headers, blocks, receipts)
-	delete(tester.peerHeaders["peer"], hashes[fsMinFullBlocks-1])
-	tester.downloader.dropPeer = func(id string) {} // We reuse the same "faulty" peer throughout the test
-
-	// Remove all possible pivot state roots and slow down replies (test failure resets later)
-	for i := 0; i < fsPivotInterval; i++ {
-		tester.peerMissingStates["peer"][headers[hashes[fsMinFullBlocks+i]].Root] = true
-	}
-	(tester.downloader.peers.peers["peer"].peer).(*downloadTesterPeer).setDelay(500 * time.Millisecond) // Enough to reach the critical section
-
-	// Synchronise with the peer a few times and make sure they fail until the retry limit
-	for i := 0; i < int(fsCriticalTrials)-1; i++ {
-		// Attempt a sync and ensure it fails properly
-		if err := tester.sync("peer", nil, FastSync); err == nil {
-			t.Fatalf("failing fast sync succeeded: %v", err)
-		}
-		time.Sleep(150 * time.Millisecond) // Make sure no in-flight requests remain
-
-		// If it's the first failure, pivot should be locked => reenable all others to detect pivot changes
-		if i == 0 {
-			time.Sleep(150 * time.Millisecond) // Make sure no in-flight requests remain
-			if tester.downloader.fsPivotLock == nil {
-				time.Sleep(400 * time.Millisecond) // Make sure the first huge timeout expires too
-				t.Fatalf("pivot block not locked in after critical section failure")
-			}
-			tester.lock.Lock()
-			tester.peerHeaders["peer"][hashes[fsMinFullBlocks-1]] = headers[hashes[fsMinFullBlocks-1]]
-			tester.peerMissingStates["peer"] = map[common.Hash]bool{tester.downloader.fsPivotLock.Root: true}
-			(tester.downloader.peers.peers["peer"].peer).(*downloadTesterPeer).setDelay(0)
-			tester.lock.Unlock()
-		}
-	}
-	// Return all nodes if we're testing fast sync progression
-	if progress {
-		tester.lock.Lock()
-		tester.peerMissingStates["peer"] = map[common.Hash]bool{}
-		tester.lock.Unlock()
-
-		if err := tester.sync("peer", nil, FastSync); err != nil {
-			t.Fatalf("failed to synchronise blocks in progressed fast sync: %v", err)
-		}
-		time.Sleep(150 * time.Millisecond) // Make sure no in-flight requests remain
 
-		if fails := atomic.LoadUint32(&tester.downloader.fsPivotFails); fails != 1 {
-			t.Fatalf("progressed pivot trial count mismatch: have %v, want %v", fails, 1)
-		}
-		assertOwnChain(t, tester, targetBlocks+1)
-	} else {
-		if err := tester.sync("peer", nil, FastSync); err == nil {
-			t.Fatalf("succeeded to synchronise blocks in failed fast sync")
-		}
-		time.Sleep(150 * time.Millisecond) // Make sure no in-flight requests remain
-
-		if fails := atomic.LoadUint32(&tester.downloader.fsPivotFails); fails != fsCriticalTrials {
-			t.Fatalf("failed pivot trial count mismatch: have %v, want %v", fails, fsCriticalTrials)
-		}
-	}
-	// Retry limit exhausted, downloader will switch to full sync, should succeed
-	if err := tester.sync("peer", nil, FastSync); err != nil {
-		t.Fatalf("failed to synchronise blocks in slow sync: %v", err)
+		// Flush all goroutines to prevent messing with subsequent tests
+		tester.downloader.peers.peers["peer"].peer.(*floodingTestPeer).pend.Wait()
 	}
-	// Note, we can't assert the chain here because the test asserter assumes sync
-	// completed using a single mode of operation, whereas fast-then-slow can result
-	// in arbitrary intermediate state that's not cleanly verifiable.
 }
diff --git a/eth/downloader/queue.go b/eth/downloader/queue.go
index 6926f1d8c..a1a70e46e 100644
--- a/eth/downloader/queue.go
+++ b/eth/downloader/queue.go
@@ -32,7 +32,11 @@ import (
 	"gopkg.in/karalabe/cookiejar.v2/collections/prque"
 )
 
-var blockCacheLimit = 8192 // Maximum number of blocks to cache before throttling the download
+var (
+	blockCacheItems      = 8192             // Maximum number of blocks to cache before throttling the download
+	blockCacheMemory     = 64 * 1024 * 1024 // Maximum amount of memory to use for block caching
+	blockCacheSizeWeight = 0.1              // Multiplier to approximate the average block size based on past ones
+)
 
 var (
 	errNoFetchesPending = errors.New("no fetches pending")
@@ -41,17 +45,17 @@ var (
 
 // fetchRequest is a currently running data retrieval operation.
 type fetchRequest struct {
-	Peer    *peerConnection     // Peer to which the request was sent
-	From    uint64              // [eth/62] Requested chain element index (used for skeleton fills only)
-	Hashes  map[common.Hash]int // [eth/61] Requested hashes with their insertion index (priority)
-	Headers []*types.Header     // [eth/62] Requested headers, sorted by request order
-	Time    time.Time           // Time when the request was made
+	Peer    *peerConnection // Peer to which the request was sent
+	From    uint64          // [eth/62] Requested chain element index (used for skeleton fills only)
+	Headers []*types.Header // [eth/62] Requested headers, sorted by request order
+	Time    time.Time       // Time when the request was made
 }
 
 // fetchResult is a struct collecting partial results from data fetchers until
 // all outstanding pieces complete and the result as a whole can be processed.
 type fetchResult struct {
-	Pending int // Number of data fetches still pending
+	Pending int         // Number of data fetches still pending
+	Hash    common.Hash // Hash of the header to prevent recalculating
 
 	Header       *types.Header
 	Uncles       []*types.Header
@@ -61,12 +65,10 @@ type fetchResult struct {
 
 // queue represents hashes that are either need fetching or are being fetched
 type queue struct {
-	mode          SyncMode // Synchronisation mode to decide on the block parts to schedule for fetching
-	fastSyncPivot uint64   // Block number where the fast sync pivots into archive synchronisation mode
-
-	headerHead common.Hash // [eth/62] Hash of the last queued header to verify order
+	mode SyncMode // Synchronisation mode to decide on the block parts to schedule for fetching
 
 	// Headers are "special", they download in batches, supported by a skeleton chain
+	headerHead      common.Hash                    // [eth/62] Hash of the last queued header to verify order
 	headerTaskPool  map[uint64]*types.Header       // [eth/62] Pending header retrieval tasks, mapping starting indexes to skeleton headers
 	headerTaskQueue *prque.Prque                   // [eth/62] Priority queue of the skeleton indexes to fetch the filling headers for
 	headerPeerMiss  map[string]map[uint64]struct{} // [eth/62] Set of per-peer header batches known to be unavailable
@@ -87,8 +89,9 @@ type queue struct {
 	receiptPendPool  map[string]*fetchRequest      // [eth/63] Currently pending receipt retrieval operations
 	receiptDonePool  map[common.Hash]struct{}      // [eth/63] Set of the completed receipt fetches
 
-	resultCache  []*fetchResult // Downloaded but not yet delivered fetch results
-	resultOffset uint64         // Offset of the first cached fetch result in the block chain
+	resultCache  []*fetchResult     // Downloaded but not yet delivered fetch results
+	resultOffset uint64             // Offset of the first cached fetch result in the block chain
+	resultSize   common.StorageSize // Approximate size of a block (exponential moving average)
 
 	lock   *sync.Mutex
 	active *sync.Cond
@@ -109,7 +112,7 @@ func newQueue() *queue {
 		receiptTaskQueue: prque.New(),
 		receiptPendPool:  make(map[string]*fetchRequest),
 		receiptDonePool:  make(map[common.Hash]struct{}),
-		resultCache:      make([]*fetchResult, blockCacheLimit),
+		resultCache:      make([]*fetchResult, blockCacheItems),
 		active:           sync.NewCond(lock),
 		lock:             lock,
 	}
@@ -122,10 +125,8 @@ func (q *queue) Reset() {
 
 	q.closed = false
 	q.mode = FullSync
-	q.fastSyncPivot = 0
 
 	q.headerHead = common.Hash{}
-
 	q.headerPendPool = make(map[string]*fetchRequest)
 
 	q.blockTaskPool = make(map[common.Hash]*types.Header)
@@ -138,7 +139,7 @@ func (q *queue) Reset() {
 	q.receiptPendPool = make(map[string]*fetchRequest)
 	q.receiptDonePool = make(map[common.Hash]struct{})
 
-	q.resultCache = make([]*fetchResult, blockCacheLimit)
+	q.resultCache = make([]*fetchResult, blockCacheItems)
 	q.resultOffset = 0
 }
 
@@ -214,27 +215,13 @@ func (q *queue) Idle() bool {
 	return (queued + pending + cached) == 0
 }
 
-// FastSyncPivot retrieves the currently used fast sync pivot point.
-func (q *queue) FastSyncPivot() uint64 {
-	q.lock.Lock()
-	defer q.lock.Unlock()
-
-	return q.fastSyncPivot
-}
-
 // ShouldThrottleBlocks checks if the download should be throttled (active block (body)
 // fetches exceed block cache).
 func (q *queue) ShouldThrottleBlocks() bool {
 	q.lock.Lock()
 	defer q.lock.Unlock()
 
-	// Calculate the currently in-flight block (body) requests
-	pending := 0
-	for _, request := range q.blockPendPool {
-		pending += len(request.Hashes) + len(request.Headers)
-	}
-	// Throttle if more blocks (bodies) are in-flight than free space in the cache
-	return pending >= len(q.resultCache)-len(q.blockDonePool)
+	return q.resultSlots(q.blockPendPool, q.blockDonePool) <= 0
 }
 
 // ShouldThrottleReceipts checks if the download should be throttled (active receipt
@@ -243,13 +230,39 @@ func (q *queue) ShouldThrottleReceipts() bool {
 	q.lock.Lock()
 	defer q.lock.Unlock()
 
-	// Calculate the currently in-flight receipt requests
+	return q.resultSlots(q.receiptPendPool, q.receiptDonePool) <= 0
+}
+
+// resultSlots calculates the number of results slots available for requests
+// whilst adhering to both the item and the memory limit too of the results
+// cache.
+func (q *queue) resultSlots(pendPool map[string]*fetchRequest, donePool map[common.Hash]struct{}) int {
+	// Calculate the maximum length capped by the memory limit
+	limit := len(q.resultCache)
+	if common.StorageSize(len(q.resultCache))*q.resultSize > common.StorageSize(blockCacheMemory) {
+		limit = int((common.StorageSize(blockCacheMemory) + q.resultSize - 1) / q.resultSize)
+	}
+	// Calculate the number of slots already finished
+	finished := 0
+	for _, result := range q.resultCache[:limit] {
+		if result == nil {
+			break
+		}
+		if _, ok := donePool[result.Hash]; ok {
+			finished++
+		}
+	}
+	// Calculate the number of slots currently downloading
 	pending := 0
-	for _, request := range q.receiptPendPool {
-		pending += len(request.Headers)
+	for _, request := range pendPool {
+		for _, header := range request.Headers {
+			if header.Number.Uint64() < q.resultOffset+uint64(limit) {
+				pending++
+			}
+		}
 	}
-	// Throttle if more receipts are in-flight than free space in the cache
-	return pending >= len(q.resultCache)-len(q.receiptDonePool)
+	// Return the free slots to distribute
+	return limit - finished - pending
 }
 
 // ScheduleSkeleton adds a batch of header retrieval tasks to the queue to fill
@@ -323,8 +336,7 @@ func (q *queue) Schedule(headers []*types.Header, from uint64) []*types.Header {
 		q.blockTaskPool[hash] = header
 		q.blockTaskQueue.Push(header, -float32(header.Number.Uint64()))
 
-		if q.mode == FastSync && header.Number.Uint64() <= q.fastSyncPivot {
-			// Fast phase of the fast sync, retrieve receipts too
+		if q.mode == FastSync {
 			q.receiptTaskPool[hash] = header
 			q.receiptTaskQueue.Push(header, -float32(header.Number.Uint64()))
 		}
@@ -335,18 +347,25 @@ func (q *queue) Schedule(headers []*types.Header, from uint64) []*types.Header {
 	return inserts
 }
 
-// WaitResults retrieves and permanently removes a batch of fetch
-// results from the cache. the result slice will be empty if the queue
-// has been closed.
-func (q *queue) WaitResults() []*fetchResult {
+// Results retrieves and permanently removes a batch of fetch results from
+// the cache. the result slice will be empty if the queue has been closed.
+func (q *queue) Results(block bool) []*fetchResult {
 	q.lock.Lock()
 	defer q.lock.Unlock()
 
+	// Count the number of items available for processing
 	nproc := q.countProcessableItems()
 	for nproc == 0 && !q.closed {
+		if !block {
+			return nil
+		}
 		q.active.Wait()
 		nproc = q.countProcessableItems()
 	}
+	// Since we have a batch limit, don't pull more into "dangling" memory
+	if nproc > maxResultsProcess {
+		nproc = maxResultsProcess
+	}
 	results := make([]*fetchResult, nproc)
 	copy(results, q.resultCache[:nproc])
 	if len(results) > 0 {
@@ -363,6 +382,21 @@ func (q *queue) WaitResults() []*fetchResult {
 		}
 		// Advance the expected block number of the first cache entry.
 		q.resultOffset += uint64(nproc)
+
+		// Recalculate the result item weights to prevent memory exhaustion
+		for _, result := range results {
+			size := result.Header.Size()
+			for _, uncle := range result.Uncles {
+				size += uncle.Size()
+			}
+			for _, receipt := range result.Receipts {
+				size += receipt.Size()
+			}
+			for _, tx := range result.Transactions {
+				size += tx.Size()
+			}
+			q.resultSize = common.StorageSize(blockCacheSizeWeight)*size + (1-common.StorageSize(blockCacheSizeWeight))*q.resultSize
+		}
 	}
 	return results
 }
@@ -370,21 +404,9 @@ func (q *queue) WaitResults() []*fetchResult {
 // countProcessableItems counts the processable items.
 func (q *queue) countProcessableItems() int {
 	for i, result := range q.resultCache {
-		// Don't process incomplete or unavailable items.
 		if result == nil || result.Pending > 0 {
 			return i
 		}
-		// Stop before processing the pivot block to ensure that
-		// resultCache has space for fsHeaderForceVerify items. Not
-		// doing this could leave us unable to download the required
-		// amount of headers.
-		if q.mode == FastSync && result.Header.Number.Uint64() == q.fastSyncPivot {
-			for j := 0; j < fsHeaderForceVerify; j++ {
-				if i+j+1 >= len(q.resultCache) || q.resultCache[i+j+1] == nil {
-					return i
-				}
-			}
-		}
 	}
 	return len(q.resultCache)
 }
@@ -473,10 +495,8 @@ func (q *queue) reserveHeaders(p *peerConnection, count int, taskPool map[common
 		return nil, false, nil
 	}
 	// Calculate an upper limit on the items we might fetch (i.e. throttling)
-	space := len(q.resultCache) - len(donePool)
-	for _, request := range pendPool {
-		space -= len(request.Headers)
-	}
+	space := q.resultSlots(pendPool, donePool)
+
 	// Retrieve a batch of tasks, skipping previously failed ones
 	send := make([]*types.Header, 0, count)
 	skip := make([]*types.Header, 0)
@@ -484,6 +504,7 @@ func (q *queue) reserveHeaders(p *peerConnection, count int, taskPool map[common
 	progress := false
 	for proc := 0; proc < space && len(send) < count && !taskQueue.Empty(); proc++ {
 		header := taskQueue.PopItem().(*types.Header)
+		hash := header.Hash()
 
 		// If we're the first to request this task, initialise the result container
 		index := int(header.Number.Int64() - int64(q.resultOffset))
@@ -493,18 +514,19 @@ func (q *queue) reserveHeaders(p *peerConnection, count int, taskPool map[common
 		}
 		if q.resultCache[index] == nil {
 			components := 1
-			if q.mode == FastSync && header.Number.Uint64() <= q.fastSyncPivot {
+			if q.mode == FastSync {
 				components = 2
 			}
 			q.resultCache[index] = &fetchResult{
 				Pending: components,
+				Hash:    hash,
 				Header:  header,
 			}
 		}
 		// If this fetch task is a noop, skip this fetch operation
 		if isNoop(header) {
-			donePool[header.Hash()] = struct{}{}
-			delete(taskPool, header.Hash())
+			donePool[hash] = struct{}{}
+			delete(taskPool, hash)
 
 			space, proc = space-1, proc-1
 			q.resultCache[index].Pending--
@@ -512,7 +534,7 @@ func (q *queue) reserveHeaders(p *peerConnection, count int, taskPool map[common
 			continue
 		}
 		// Otherwise unless the peer is known not to have the data, add to the retrieve list
-		if p.Lacks(header.Hash()) {
+		if p.Lacks(hash) {
 			skip = append(skip, header)
 		} else {
 			send = append(send, header)
@@ -565,9 +587,6 @@ func (q *queue) cancel(request *fetchRequest, taskQueue *prque.Prque, pendPool m
 	if request.From > 0 {
 		taskQueue.Push(request.From, -float32(request.From))
 	}
-	for hash, index := range request.Hashes {
-		taskQueue.Push(hash, float32(index))
-	}
 	for _, header := range request.Headers {
 		taskQueue.Push(header, -float32(header.Number.Uint64()))
 	}
@@ -640,18 +659,11 @@ func (q *queue) expire(timeout time.Duration, pendPool map[string]*fetchRequest,
 			if request.From > 0 {
 				taskQueue.Push(request.From, -float32(request.From))
 			}
-			for hash, index := range request.Hashes {
-				taskQueue.Push(hash, float32(index))
-			}
 			for _, header := range request.Headers {
 				taskQueue.Push(header, -float32(header.Number.Uint64()))
 			}
 			// Add the peer to the expiry report along the the number of failed requests
-			expirations := len(request.Hashes)
-			if expirations < len(request.Headers) {
-				expirations = len(request.Headers)
-			}
-			expiries[id] = expirations
+			expiries[id] = len(request.Headers)
 		}
 	}
 	// Remove the expired requests from the pending pool
@@ -828,14 +840,16 @@ func (q *queue) deliver(id string, taskPool map[common.Hash]*types.Header, taskQ
 			failure = err
 			break
 		}
-		donePool[header.Hash()] = struct{}{}
+		hash := header.Hash()
+
+		donePool[hash] = struct{}{}
 		q.resultCache[index].Pending--
 		useful = true
 		accepted++
 
 		// Clean up a successful fetch
 		request.Headers[i] = nil
-		delete(taskPool, header.Hash())
+		delete(taskPool, hash)
 	}
 	// Return all failed or missing fetches to the queue
 	for _, header := range request.Headers {
@@ -860,7 +874,7 @@ func (q *queue) deliver(id string, taskPool map[common.Hash]*types.Header, taskQ
 
 // Prepare configures the result cache to allow accepting and caching inbound
 // fetch results.
-func (q *queue) Prepare(offset uint64, mode SyncMode, pivot uint64, head *types.Header) {
+func (q *queue) Prepare(offset uint64, mode SyncMode) {
 	q.lock.Lock()
 	defer q.lock.Unlock()
 
@@ -868,6 +882,5 @@ func (q *queue) Prepare(offset uint64, mode SyncMode, pivot uint64, head *types.
 	if q.resultOffset < offset {
 		q.resultOffset = offset
 	}
-	q.fastSyncPivot = pivot
 	q.mode = mode
 }
diff --git a/eth/downloader/statesync.go b/eth/downloader/statesync.go
index 937828b94..9cc65a208 100644
--- a/eth/downloader/statesync.go
+++ b/eth/downloader/statesync.go
@@ -20,7 +20,6 @@ import (
 	"fmt"
 	"hash"
 	"sync"
-	"sync/atomic"
 	"time"
 
 	"github.com/ethereum/go-ethereum/common"
@@ -294,6 +293,9 @@ func (s *stateSync) loop() error {
 		case <-s.cancel:
 			return errCancelStateFetch
 
+		case <-s.d.cancelCh:
+			return errCancelStateFetch
+
 		case req := <-s.deliver:
 			// Response, disconnect or timeout triggered, drop the peer if stalling
 			log.Trace("Received node data response", "peer", req.peer.id, "count", len(req.response), "dropped", req.dropped, "timeout", !req.dropped && req.timedOut())
@@ -304,15 +306,11 @@ func (s *stateSync) loop() error {
 				s.d.dropPeer(req.peer.id)
 			}
 			// Process all the received blobs and check for stale delivery
-			stale, err := s.process(req)
-			if err != nil {
+			if err := s.process(req); err != nil {
 				log.Warn("Node data write error", "err", err)
 				return err
 			}
-			// The the delivery contains requested data, mark the node idle (otherwise it's a timed out delivery)
-			if !stale {
-				req.peer.SetNodeDataIdle(len(req.response))
-			}
+			req.peer.SetNodeDataIdle(len(req.response))
 		}
 	}
 	return s.commit(true)
@@ -352,6 +350,7 @@ func (s *stateSync) assignTasks() {
 			case s.d.trackStateReq <- req:
 				req.peer.FetchNodeData(req.items)
 			case <-s.cancel:
+			case <-s.d.cancelCh:
 			}
 		}
 	}
@@ -390,7 +389,7 @@ func (s *stateSync) fillTasks(n int, req *stateReq) {
 // process iterates over a batch of delivered state data, injecting each item
 // into a running state sync, re-queuing any items that were requested but not
 // delivered.
-func (s *stateSync) process(req *stateReq) (bool, error) {
+func (s *stateSync) process(req *stateReq) error {
 	// Collect processing stats and update progress if valid data was received
 	duplicate, unexpected := 0, 0
 
@@ -401,7 +400,7 @@ func (s *stateSync) process(req *stateReq) (bool, error) {
 	}(time.Now())
 
 	// Iterate over all the delivered data and inject one-by-one into the trie
-	progress, stale := false, len(req.response) > 0
+	progress := false
 
 	for _, blob := range req.response {
 		prog, hash, err := s.processNodeData(blob)
@@ -415,20 +414,12 @@ func (s *stateSync) process(req *stateReq) (bool, error) {
 		case trie.ErrAlreadyProcessed:
 			duplicate++
 		default:
-			return stale, fmt.Errorf("invalid state node %s: %v", hash.TerminalString(), err)
+			return fmt.Errorf("invalid state node %s: %v", hash.TerminalString(), err)
 		}
-		// If the node delivered a requested item, mark the delivery non-stale
 		if _, ok := req.tasks[hash]; ok {
 			delete(req.tasks, hash)
-			stale = false
 		}
 	}
-	// If we're inside the critical section, reset fail counter since we progressed.
-	if progress && atomic.LoadUint32(&s.d.fsPivotFails) > 1 {
-		log.Trace("Fast-sync progressed, resetting fail counter", "previous", atomic.LoadUint32(&s.d.fsPivotFails))
-		atomic.StoreUint32(&s.d.fsPivotFails, 1) // Don't ever reset to 0, as that will unlock the pivot block
-	}
-
 	// Put unfulfilled tasks back into the retry queue
 	npeers := s.d.peers.Len()
 	for hash, task := range req.tasks {
@@ -441,12 +432,12 @@ func (s *stateSync) process(req *stateReq) (bool, error) {
 		// If we've requested the node too many times already, it may be a malicious
 		// sync where nobody has the right data. Abort.
 		if len(task.attempts) >= npeers {
-			return stale, fmt.Errorf("state node %s failed with all peers (%d tries, %d peers)", hash.TerminalString(), len(task.attempts), npeers)
+			return fmt.Errorf("state node %s failed with all peers (%d tries, %d peers)", hash.TerminalString(), len(task.attempts), npeers)
 		}
 		// Missing item, place into the retry queue.
 		s.tasks[hash] = task
 	}
-	return stale, nil
+	return nil
 }
 
 // processNodeData tries to inject a trie node data blob delivered from a remote
diff --git a/eth/handler.go b/eth/handler.go
index fcd53c5a6..c2426544f 100644
--- a/eth/handler.go
+++ b/eth/handler.go
@@ -71,7 +71,6 @@ type ProtocolManager struct {
 
 	txpool      txPool
 	blockchain  *core.BlockChain
-	chaindb     ethdb.Database
 	chainconfig *params.ChainConfig
 	maxPeers    int
 
@@ -106,7 +105,6 @@ func NewProtocolManager(config *params.ChainConfig, mode downloader.SyncMode, ne
 		eventMux:    mux,
 		txpool:      txpool,
 		blockchain:  blockchain,
-		chaindb:     chaindb,
 		chainconfig: config,
 		peers:       newPeerSet(),
 		newPeerCh:   make(chan *peer),
@@ -538,7 +536,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 				return errResp(ErrDecode, "msg %v: %v", msg, err)
 			}
 			// Retrieve the requested state entry, stopping if enough was found
-			if entry, err := pm.chaindb.Get(hash.Bytes()); err == nil {
+			if entry, err := pm.blockchain.TrieNode(hash); err == nil {
 				data = append(data, entry)
 				bytes += len(entry)
 			}
@@ -576,7 +574,7 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 				return errResp(ErrDecode, "msg %v: %v", msg, err)
 			}
 			// Retrieve the requested block's receipts, skipping if unknown to us
-			results := core.GetBlockReceipts(pm.chaindb, hash, core.GetBlockNumber(pm.chaindb, hash))
+			results := pm.blockchain.GetReceiptsByHash(hash)
 			if results == nil {
 				if header := pm.blockchain.GetHeaderByHash(hash); header == nil || header.ReceiptHash != types.EmptyRootHash {
 					continue
diff --git a/eth/handler_test.go b/eth/handler_test.go
index 9a02eddfb..e336dfa28 100644
--- a/eth/handler_test.go
+++ b/eth/handler_test.go
@@ -56,7 +56,7 @@ func TestProtocolCompatibility(t *testing.T) {
 	for i, tt := range tests {
 		ProtocolVersions = []uint{tt.version}
 
-		pm, err := newTestProtocolManager(tt.mode, 0, nil, nil)
+		pm, _, err := newTestProtocolManager(tt.mode, 0, nil, nil)
 		if pm != nil {
 			defer pm.Stop()
 		}
@@ -71,7 +71,7 @@ func TestGetBlockHeaders62(t *testing.T) { testGetBlockHeaders(t, 62) }
 func TestGetBlockHeaders63(t *testing.T) { testGetBlockHeaders(t, 63) }
 
 func testGetBlockHeaders(t *testing.T, protocol int) {
-	pm := newTestProtocolManagerMust(t, downloader.FullSync, downloader.MaxHashFetch+15, nil, nil)
+	pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, downloader.MaxHashFetch+15, nil, nil)
 	peer, _ := newTestPeer("peer", protocol, pm, true)
 	defer peer.close()
 
@@ -230,7 +230,7 @@ func TestGetBlockBodies62(t *testing.T) { testGetBlockBodies(t, 62) }
 func TestGetBlockBodies63(t *testing.T) { testGetBlockBodies(t, 63) }
 
 func testGetBlockBodies(t *testing.T, protocol int) {
-	pm := newTestProtocolManagerMust(t, downloader.FullSync, downloader.MaxBlockFetch+15, nil, nil)
+	pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, downloader.MaxBlockFetch+15, nil, nil)
 	peer, _ := newTestPeer("peer", protocol, pm, true)
 	defer peer.close()
 
@@ -337,13 +337,13 @@ func testGetNodeData(t *testing.T, protocol int) {
 		}
 	}
 	// Assemble the test environment
-	pm := newTestProtocolManagerMust(t, downloader.FullSync, 4, generator, nil)
+	pm, db := newTestProtocolManagerMust(t, downloader.FullSync, 4, generator, nil)
 	peer, _ := newTestPeer("peer", protocol, pm, true)
 	defer peer.close()
 
 	// Fetch for now the entire chain db
 	hashes := []common.Hash{}
-	for _, key := range pm.chaindb.(*ethdb.MemDatabase).Keys() {
+	for _, key := range db.Keys() {
 		if len(key) == len(common.Hash{}) {
 			hashes = append(hashes, common.BytesToHash(key))
 		}
@@ -429,7 +429,7 @@ func testGetReceipt(t *testing.T, protocol int) {
 		}
 	}
 	// Assemble the test environment
-	pm := newTestProtocolManagerMust(t, downloader.FullSync, 4, generator, nil)
+	pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 4, generator, nil)
 	peer, _ := newTestPeer("peer", protocol, pm, true)
 	defer peer.close()
 
@@ -439,7 +439,7 @@ func testGetReceipt(t *testing.T, protocol int) {
 		block := pm.blockchain.GetBlockByNumber(i)
 
 		hashes = append(hashes, block.Hash())
-		receipts = append(receipts, core.GetBlockReceipts(pm.chaindb, block.Hash(), block.NumberU64()))
+		receipts = append(receipts, pm.blockchain.GetReceiptsByHash(block.Hash()))
 	}
 	// Send the hash request and verify the response
 	p2p.Send(peer.app, 0x0f, hashes)
@@ -472,7 +472,7 @@ func testDAOChallenge(t *testing.T, localForked, remoteForked bool, timeout bool
 		config        = &params.ChainConfig{DAOForkBlock: big.NewInt(1), DAOForkSupport: localForked}
 		gspec         = &core.Genesis{Config: config}
 		genesis       = gspec.MustCommit(db)
-		blockchain, _ = core.NewBlockChain(db, config, pow, vm.Config{})
+		blockchain, _ = core.NewBlockChain(db, nil, config, pow, vm.Config{})
 	)
 	pm, err := NewProtocolManager(config, downloader.FullSync, DefaultConfig.NetworkId, evmux, new(testTxPool), pow, blockchain, db)
 	if err != nil {
diff --git a/eth/helper_test.go b/eth/helper_test.go
index 9a4dc9010..2b05cea80 100644
--- a/eth/helper_test.go
+++ b/eth/helper_test.go
@@ -49,7 +49,7 @@ var (
 // newTestProtocolManager creates a new protocol manager for testing purposes,
 // with the given number of blocks already known, and potential notification
 // channels for different events.
-func newTestProtocolManager(mode downloader.SyncMode, blocks int, generator func(int, *core.BlockGen), newtx chan<- []*types.Transaction) (*ProtocolManager, error) {
+func newTestProtocolManager(mode downloader.SyncMode, blocks int, generator func(int, *core.BlockGen), newtx chan<- []*types.Transaction) (*ProtocolManager, *ethdb.MemDatabase, error) {
 	var (
 		evmux  = new(event.TypeMux)
 		engine = ethash.NewFaker()
@@ -59,7 +59,7 @@ func newTestProtocolManager(mode downloader.SyncMode, blocks int, generator func
 			Alloc:  core.GenesisAlloc{testBank: {Balance: big.NewInt(1000000)}},
 		}
 		genesis       = gspec.MustCommit(db)
-		blockchain, _ = core.NewBlockChain(db, gspec.Config, engine, vm.Config{})
+		blockchain, _ = core.NewBlockChain(db, nil, gspec.Config, engine, vm.Config{})
 	)
 	chain, _ := core.GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, blocks, generator)
 	if _, err := blockchain.InsertChain(chain); err != nil {
@@ -68,22 +68,22 @@ func newTestProtocolManager(mode downloader.SyncMode, blocks int, generator func
 
 	pm, err := NewProtocolManager(gspec.Config, mode, DefaultConfig.NetworkId, evmux, &testTxPool{added: newtx}, engine, blockchain, db)
 	if err != nil {
-		return nil, err
+		return nil, nil, err
 	}
 	pm.Start(1000)
-	return pm, nil
+	return pm, db, nil
 }
 
 // newTestProtocolManagerMust creates a new protocol manager for testing purposes,
 // with the given number of blocks already known, and potential notification
 // channels for different events. In case of an error, the constructor force-
 // fails the test.
-func newTestProtocolManagerMust(t *testing.T, mode downloader.SyncMode, blocks int, generator func(int, *core.BlockGen), newtx chan<- []*types.Transaction) *ProtocolManager {
-	pm, err := newTestProtocolManager(mode, blocks, generator, newtx)
+func newTestProtocolManagerMust(t *testing.T, mode downloader.SyncMode, blocks int, generator func(int, *core.BlockGen), newtx chan<- []*types.Transaction) (*ProtocolManager, *ethdb.MemDatabase) {
+	pm, db, err := newTestProtocolManager(mode, blocks, generator, newtx)
 	if err != nil {
 		t.Fatalf("Failed to create protocol manager: %v", err)
 	}
-	return pm
+	return pm, db
 }
 
 // testTxPool is a fake, helper transaction pool for testing purposes
diff --git a/eth/protocol_test.go b/eth/protocol_test.go
index 7cbcba571..b2f93d8dd 100644
--- a/eth/protocol_test.go
+++ b/eth/protocol_test.go
@@ -41,7 +41,7 @@ func TestStatusMsgErrors62(t *testing.T) { testStatusMsgErrors(t, 62) }
 func TestStatusMsgErrors63(t *testing.T) { testStatusMsgErrors(t, 63) }
 
 func testStatusMsgErrors(t *testing.T, protocol int) {
-	pm := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, nil)
+	pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, nil)
 	var (
 		genesis = pm.blockchain.Genesis()
 		head    = pm.blockchain.CurrentHeader()
@@ -98,7 +98,7 @@ func TestRecvTransactions63(t *testing.T) { testRecvTransactions(t, 63) }
 
 func testRecvTransactions(t *testing.T, protocol int) {
 	txAdded := make(chan []*types.Transaction)
-	pm := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, txAdded)
+	pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, txAdded)
 	pm.acceptTxs = 1 // mark synced to accept transactions
 	p, _ := newTestPeer("peer", protocol, pm, true)
 	defer pm.Stop()
@@ -125,7 +125,7 @@ func TestSendTransactions62(t *testing.T) { testSendTransactions(t, 62) }
 func TestSendTransactions63(t *testing.T) { testSendTransactions(t, 63) }
 
 func testSendTransactions(t *testing.T, protocol int) {
-	pm := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, nil)
+	pm, _ := newTestProtocolManagerMust(t, downloader.FullSync, 0, nil, nil)
 	defer pm.Stop()
 
 	// Fill the pool with big transactions.
diff --git a/eth/sync_test.go b/eth/sync_test.go
index 9eaa1156f..88c10c7f7 100644
--- a/eth/sync_test.go
+++ b/eth/sync_test.go
@@ -30,12 +30,12 @@ import (
 // imported into the blockchain.
 func TestFastSyncDisabling(t *testing.T) {
 	// Create a pristine protocol manager, check that fast sync is left enabled
-	pmEmpty := newTestProtocolManagerMust(t, downloader.FastSync, 0, nil, nil)
+	pmEmpty, _ := newTestProtocolManagerMust(t, downloader.FastSync, 0, nil, nil)
 	if atomic.LoadUint32(&pmEmpty.fastSync) == 0 {
 		t.Fatalf("fast sync disabled on pristine blockchain")
 	}
 	// Create a full protocol manager, check that fast sync gets disabled
-	pmFull := newTestProtocolManagerMust(t, downloader.FastSync, 1024, nil, nil)
+	pmFull, _ := newTestProtocolManagerMust(t, downloader.FastSync, 1024, nil, nil)
 	if atomic.LoadUint32(&pmFull.fastSync) == 1 {
 		t.Fatalf("fast sync not disabled on non-empty blockchain")
 	}
diff --git a/internal/ethapi/api.go b/internal/ethapi/api.go
index a4cba7a4d..314086335 100644
--- a/internal/ethapi/api.go
+++ b/internal/ethapi/api.go
@@ -808,7 +808,7 @@ func (s *PublicBlockChainAPI) rpcOutputBlock(b *types.Block, inclTx bool, fullTx
 		"difficulty":       (*hexutil.Big)(head.Difficulty),
 		"totalDifficulty":  (*hexutil.Big)(s.b.GetTd(b.Hash())),
 		"extraData":        hexutil.Bytes(head.Extra),
-		"size":             hexutil.Uint64(uint64(b.Size().Int64())),
+		"size":             hexutil.Uint64(b.Size()),
 		"gasLimit":         hexutil.Uint64(head.GasLimit),
 		"gasUsed":          hexutil.Uint64(head.GasUsed),
 		"timestamp":        (*hexutil.Big)(head.Time),
diff --git a/les/handler.go b/les/handler.go
index 8cd37c7ab..5c93133fb 100644
--- a/les/handler.go
+++ b/les/handler.go
@@ -18,7 +18,6 @@
 package les
 
 import (
-	"bytes"
 	"encoding/binary"
 	"errors"
 	"fmt"
@@ -78,6 +77,7 @@ type BlockChain interface {
 	GetHeaderByHash(hash common.Hash) *types.Header
 	CurrentHeader() *types.Header
 	GetTd(hash common.Hash, number uint64) *big.Int
+	State() (*state.StateDB, error)
 	InsertHeaderChain(chain []*types.Header, checkFreq int) (int, error)
 	Rollback(chain []common.Hash)
 	GetHeaderByNumber(number uint64) *types.Header
@@ -579,17 +579,19 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 		for _, req := range req.Reqs {
 			// Retrieve the requested state entry, stopping if enough was found
 			if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil {
-				if trie, _ := trie.New(header.Root, pm.chainDb); trie != nil {
-					sdata := trie.Get(req.AccKey)
-					var acc state.Account
-					if err := rlp.DecodeBytes(sdata, &acc); err == nil {
-						entry, _ := pm.chainDb.Get(acc.CodeHash)
-						if bytes+len(entry) >= softResponseLimit {
-							break
-						}
-						data = append(data, entry)
-						bytes += len(entry)
-					}
+				statedb, err := pm.blockchain.State()
+				if err != nil {
+					continue
+				}
+				account, err := pm.getAccount(statedb, header.Root, common.BytesToHash(req.AccKey))
+				if err != nil {
+					continue
+				}
+				code, _ := statedb.Database().TrieDB().Node(common.BytesToHash(account.CodeHash))
+
+				data = append(data, code)
+				if bytes += len(code); bytes >= softResponseLimit {
+					break
 				}
 			}
 		}
@@ -701,25 +703,29 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 			return errResp(ErrRequestRejected, "")
 		}
 		for _, req := range req.Reqs {
-			if bytes >= softResponseLimit {
-				break
-			}
 			// Retrieve the requested state entry, stopping if enough was found
 			if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil {
-				if tr, _ := trie.New(header.Root, pm.chainDb); tr != nil {
-					if len(req.AccKey) > 0 {
-						sdata := tr.Get(req.AccKey)
-						tr = nil
-						var acc state.Account
-						if err := rlp.DecodeBytes(sdata, &acc); err == nil {
-							tr, _ = trie.New(acc.Root, pm.chainDb)
-						}
+				statedb, err := pm.blockchain.State()
+				if err != nil {
+					continue
+				}
+				var trie state.Trie
+				if len(req.AccKey) > 0 {
+					account, err := pm.getAccount(statedb, header.Root, common.BytesToHash(req.AccKey))
+					if err != nil {
+						continue
 					}
-					if tr != nil {
-						var proof light.NodeList
-						tr.Prove(req.Key, 0, &proof)
-						proofs = append(proofs, proof)
-						bytes += proof.DataSize()
+					trie, _ = statedb.Database().OpenStorageTrie(common.BytesToHash(req.AccKey), account.Root)
+				} else {
+					trie, _ = statedb.Database().OpenTrie(header.Root)
+				}
+				if trie != nil {
+					var proof light.NodeList
+					trie.Prove(req.Key, 0, &proof)
+
+					proofs = append(proofs, proof)
+					if bytes += proof.DataSize(); bytes >= softResponseLimit {
+						break
 					}
 				}
 			}
@@ -740,9 +746,9 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 		}
 		// Gather state data until the fetch or network limits is reached
 		var (
-			lastBHash  common.Hash
-			lastAccKey []byte
-			tr, str    *trie.Trie
+			lastBHash common.Hash
+			statedb   *state.StateDB
+			root      common.Hash
 		)
 		reqCnt := len(req.Reqs)
 		if reject(uint64(reqCnt), MaxProofsFetch) {
@@ -752,35 +758,36 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 		nodes := light.NewNodeSet()
 
 		for _, req := range req.Reqs {
-			if nodes.DataSize() >= softResponseLimit {
-				break
-			}
-			if tr == nil || req.BHash != lastBHash {
+			// Look up the state belonging to the request
+			if statedb == nil || req.BHash != lastBHash {
+				statedb, root, lastBHash = nil, common.Hash{}, req.BHash
+
 				if header := core.GetHeader(pm.chainDb, req.BHash, core.GetBlockNumber(pm.chainDb, req.BHash)); header != nil {
-					tr, _ = trie.New(header.Root, pm.chainDb)
-				} else {
-					tr = nil
+					statedb, _ = pm.blockchain.State()
+					root = header.Root
 				}
-				lastBHash = req.BHash
-				str = nil
 			}
-			if tr != nil {
-				if len(req.AccKey) > 0 {
-					if str == nil || !bytes.Equal(req.AccKey, lastAccKey) {
-						sdata := tr.Get(req.AccKey)
-						str = nil
-						var acc state.Account
-						if err := rlp.DecodeBytes(sdata, &acc); err == nil {
-							str, _ = trie.New(acc.Root, pm.chainDb)
-						}
-						lastAccKey = common.CopyBytes(req.AccKey)
-					}
-					if str != nil {
-						str.Prove(req.Key, req.FromLevel, nodes)
-					}
-				} else {
-					tr.Prove(req.Key, req.FromLevel, nodes)
+			if statedb == nil {
+				continue
+			}
+			// Pull the account or storage trie of the request
+			var trie state.Trie
+			if len(req.AccKey) > 0 {
+				account, err := pm.getAccount(statedb, root, common.BytesToHash(req.AccKey))
+				if err != nil {
+					continue
 				}
+				trie, _ = statedb.Database().OpenStorageTrie(common.BytesToHash(req.AccKey), account.Root)
+			} else {
+				trie, _ = statedb.Database().OpenTrie(root)
+			}
+			if trie == nil {
+				continue
+			}
+			// Prove the user's request from the account or stroage trie
+			trie.Prove(req.Key, req.FromLevel, nodes)
+			if nodes.DataSize() >= softResponseLimit {
+				break
 			}
 		}
 		proofs := nodes.NodeList()
@@ -849,23 +856,29 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 		if reject(uint64(reqCnt), MaxHelperTrieProofsFetch) {
 			return errResp(ErrRequestRejected, "")
 		}
-		trieDb := ethdb.NewTable(pm.chainDb, light.ChtTablePrefix)
 		for _, req := range req.Reqs {
-			if bytes >= softResponseLimit {
-				break
-			}
-
 			if header := pm.blockchain.GetHeaderByNumber(req.BlockNum); header != nil {
 				sectionHead := core.GetCanonicalHash(pm.chainDb, req.ChtNum*light.ChtV1Frequency-1)
 				if root := light.GetChtRoot(pm.chainDb, req.ChtNum-1, sectionHead); root != (common.Hash{}) {
-					if tr, _ := trie.New(root, trieDb); tr != nil {
-						var encNumber [8]byte
-						binary.BigEndian.PutUint64(encNumber[:], req.BlockNum)
-						var proof light.NodeList
-						tr.Prove(encNumber[:], 0, &proof)
-						proofs = append(proofs, ChtResp{Header: header, Proof: proof})
-						bytes += proof.DataSize() + estHeaderRlpSize
+					statedb, err := pm.blockchain.State()
+					if err != nil {
+						continue
 					}
+					trie, err := statedb.Database().OpenTrie(root)
+					if err != nil {
+						continue
+					}
+					var encNumber [8]byte
+					binary.BigEndian.PutUint64(encNumber[:], req.BlockNum)
+
+					var proof light.NodeList
+					trie.Prove(encNumber[:], 0, &proof)
+
+					proofs = append(proofs, ChtResp{Header: header, Proof: proof})
+					if bytes += proof.DataSize() + estHeaderRlpSize; bytes >= softResponseLimit {
+						break
+					}
+
 				}
 			}
 		}
@@ -897,25 +910,21 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 			lastIdx  uint64
 			lastType uint
 			root     common.Hash
-			tr       *trie.Trie
+			statedb  *state.StateDB
+			trie     state.Trie
 		)
 
 		nodes := light.NewNodeSet()
 
 		for _, req := range req.Reqs {
-			if nodes.DataSize()+auxBytes >= softResponseLimit {
-				break
-			}
-			if tr == nil || req.HelperTrieType != lastType || req.TrieIdx != lastIdx {
-				var prefix string
-				root, prefix = pm.getHelperTrie(req.HelperTrieType, req.TrieIdx)
-				if root != (common.Hash{}) {
-					if t, err := trie.New(root, ethdb.NewTable(pm.chainDb, prefix)); err == nil {
-						tr = t
+			if trie == nil || req.HelperTrieType != lastType || req.TrieIdx != lastIdx {
+				statedb, trie, lastType, lastIdx = nil, nil, req.HelperTrieType, req.TrieIdx
+
+				if root, _ = pm.getHelperTrie(req.HelperTrieType, req.TrieIdx); root != (common.Hash{}) {
+					if statedb, _ = pm.blockchain.State(); statedb != nil {
+						trie, _ = statedb.Database().OpenTrie(root)
 					}
 				}
-				lastType = req.HelperTrieType
-				lastIdx = req.TrieIdx
 			}
 			if req.AuxReq == auxRoot {
 				var data []byte
@@ -925,8 +934,8 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 				auxData = append(auxData, data)
 				auxBytes += len(data)
 			} else {
-				if tr != nil {
-					tr.Prove(req.Key, req.FromLevel, nodes)
+				if trie != nil {
+					trie.Prove(req.Key, req.FromLevel, nodes)
 				}
 				if req.AuxReq != 0 {
 					data := pm.getHelperTrieAuxData(req)
@@ -934,6 +943,9 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 					auxBytes += len(data)
 				}
 			}
+			if nodes.DataSize()+auxBytes >= softResponseLimit {
+				break
+			}
 		}
 		proofs := nodes.NodeList()
 		bv, rcost := p.fcClient.RequestProcessed(costs.baseCost + uint64(reqCnt)*costs.reqCost)
@@ -1090,6 +1102,23 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 	return nil
 }
 
+// getAccount retrieves an account from the state based at root.
+func (pm *ProtocolManager) getAccount(statedb *state.StateDB, root, hash common.Hash) (state.Account, error) {
+	trie, err := trie.New(root, statedb.Database().TrieDB())
+	if err != nil {
+		return state.Account{}, err
+	}
+	blob, err := trie.TryGet(hash[:])
+	if err != nil {
+		return state.Account{}, err
+	}
+	var account state.Account
+	if err = rlp.DecodeBytes(blob, &account); err != nil {
+		return state.Account{}, err
+	}
+	return account, nil
+}
+
 // getHelperTrie returns the post-processed trie root for the given trie ID and section index
 func (pm *ProtocolManager) getHelperTrie(id uint, idx uint64) (common.Hash, string) {
 	switch id {
diff --git a/les/handler_test.go b/les/handler_test.go
index 10e5499a3..e5446c031 100644
--- a/les/handler_test.go
+++ b/les/handler_test.go
@@ -359,7 +359,7 @@ func testGetProofs(t *testing.T, protocol int) {
 	for i := uint64(0); i <= bc.CurrentBlock().NumberU64(); i++ {
 		header := bc.GetHeaderByNumber(i)
 		root := header.Root
-		trie, _ := trie.New(root, db)
+		trie, _ := trie.New(root, trie.NewDatabase(db))
 
 		for _, acc := range accounts {
 			req := ProofReq{
diff --git a/les/helper_test.go b/les/helper_test.go
index 1c1de64ad..bf08e1e2f 100644
--- a/les/helper_test.go
+++ b/les/helper_test.go
@@ -146,7 +146,7 @@ func newTestProtocolManager(lightSync bool, blocks int, generator func(int, *cor
 	if lightSync {
 		chain, _ = light.NewLightChain(odr, gspec.Config, engine)
 	} else {
-		blockchain, _ := core.NewBlockChain(db, gspec.Config, engine, vm.Config{})
+		blockchain, _ := core.NewBlockChain(db, nil, gspec.Config, engine, vm.Config{})
 		gchain, _ := core.GenerateChain(gspec.Config, genesis, ethash.NewFaker(), db, blocks, generator)
 		if _, err := blockchain.InsertChain(gchain); err != nil {
 			panic(err)
diff --git a/les/odr_test.go b/les/odr_test.go
index cf609be88..88e121cda 100644
--- a/les/odr_test.go
+++ b/les/odr_test.go
@@ -101,7 +101,6 @@ func odrAccounts(ctx context.Context, db ethdb.Database, config *params.ChainCon
 			res = append(res, rlp...)
 		}
 	}
-
 	return res
 }
 
diff --git a/light/lightchain.go b/light/lightchain.go
index f47957512..24529ef82 100644
--- a/light/lightchain.go
+++ b/light/lightchain.go
@@ -18,6 +18,7 @@ package light
 
 import (
 	"context"
+	"errors"
 	"math/big"
 	"sync"
 	"sync/atomic"
@@ -26,6 +27,7 @@ import (
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/consensus"
 	"github.com/ethereum/go-ethereum/core"
+	"github.com/ethereum/go-ethereum/core/state"
 	"github.com/ethereum/go-ethereum/core/types"
 	"github.com/ethereum/go-ethereum/ethdb"
 	"github.com/ethereum/go-ethereum/event"
@@ -212,6 +214,11 @@ func (bc *LightChain) Genesis() *types.Block {
 	return bc.genesisBlock
 }
 
+// State returns a new mutable state based on the current HEAD block.
+func (bc *LightChain) State() (*state.StateDB, error) {
+	return nil, errors.New("not implemented, needs client/server interface split")
+}
+
 // GetBody retrieves a block body (transactions and uncles) from the database
 // or ODR service by hash, caching it if found.
 func (self *LightChain) GetBody(ctx context.Context, hash common.Hash) (*types.Body, error) {
diff --git a/light/nodeset.go b/light/nodeset.go
index c530a4fbe..ffdb71bb7 100644
--- a/light/nodeset.go
+++ b/light/nodeset.go
@@ -22,8 +22,8 @@ import (
 
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/crypto"
+	"github.com/ethereum/go-ethereum/ethdb"
 	"github.com/ethereum/go-ethereum/rlp"
-	"github.com/ethereum/go-ethereum/trie"
 )
 
 // NodeSet stores a set of trie nodes. It implements trie.Database and can also
@@ -99,7 +99,7 @@ func (db *NodeSet) NodeList() NodeList {
 }
 
 // Store writes the contents of the set to the given database
-func (db *NodeSet) Store(target trie.Database) {
+func (db *NodeSet) Store(target ethdb.Putter) {
 	db.lock.RLock()
 	defer db.lock.RUnlock()
 
@@ -108,11 +108,11 @@ func (db *NodeSet) Store(target trie.Database) {
 	}
 }
 
-// NodeList stores an ordered list of trie nodes. It implements trie.DatabaseWriter.
+// NodeList stores an ordered list of trie nodes. It implements ethdb.Putter.
 type NodeList []rlp.RawValue
 
 // Store writes the contents of the list to the given database
-func (n NodeList) Store(db trie.Database) {
+func (n NodeList) Store(db ethdb.Putter) {
 	for _, node := range n {
 		db.Put(crypto.Keccak256(node), node)
 	}
diff --git a/light/odr_test.go b/light/odr_test.go
index e3d07518a..d3f9374fd 100644
--- a/light/odr_test.go
+++ b/light/odr_test.go
@@ -74,7 +74,7 @@ func (odr *testOdr) Retrieve(ctx context.Context, req OdrRequest) error {
 	case *ReceiptsRequest:
 		req.Receipts = core.GetBlockReceipts(odr.sdb, req.Hash, core.GetBlockNumber(odr.sdb, req.Hash))
 	case *TrieRequest:
-		t, _ := trie.New(req.Id.Root, odr.sdb)
+		t, _ := trie.New(req.Id.Root, trie.NewDatabase(odr.sdb))
 		nodes := NewNodeSet()
 		t.Prove(req.Key, 0, nodes)
 		req.Proof = nodes
@@ -239,7 +239,7 @@ func testChainOdr(t *testing.T, protocol int, fn odrTestFn) {
 	)
 	gspec.MustCommit(ldb)
 	// Assemble the test environment
-	blockchain, _ := core.NewBlockChain(sdb, params.TestChainConfig, ethash.NewFullFaker(), vm.Config{})
+	blockchain, _ := core.NewBlockChain(sdb, nil, params.TestChainConfig, ethash.NewFullFaker(), vm.Config{})
 	gchain, _ := core.GenerateChain(params.TestChainConfig, genesis, ethash.NewFaker(), sdb, 4, testChainGen)
 	if _, err := blockchain.InsertChain(gchain); err != nil {
 		t.Fatal(err)
diff --git a/light/postprocess.go b/light/postprocess.go
index 32dbc102b..bbac58d12 100644
--- a/light/postprocess.go
+++ b/light/postprocess.go
@@ -113,7 +113,8 @@ func StoreChtRoot(db ethdb.Database, sectionIdx uint64, sectionHead, root common
 
 // ChtIndexerBackend implements core.ChainIndexerBackend
 type ChtIndexerBackend struct {
-	db, cdb              ethdb.Database
+	diskdb               ethdb.Database
+	triedb               *trie.Database
 	section, sectionSize uint64
 	lastHash             common.Hash
 	trie                 *trie.Trie
@@ -121,8 +122,6 @@ type ChtIndexerBackend struct {
 
 // NewBloomTrieIndexer creates a BloomTrie chain indexer
 func NewChtIndexer(db ethdb.Database, clientMode bool) *core.ChainIndexer {
-	cdb := ethdb.NewTable(db, ChtTablePrefix)
-	idb := ethdb.NewTable(db, "chtIndex-")
 	var sectionSize, confirmReq uint64
 	if clientMode {
 		sectionSize = ChtFrequency
@@ -131,17 +130,23 @@ func NewChtIndexer(db ethdb.Database, clientMode bool) *core.ChainIndexer {
 		sectionSize = ChtV1Frequency
 		confirmReq = HelperTrieProcessConfirmations
 	}
-	return core.NewChainIndexer(db, idb, &ChtIndexerBackend{db: db, cdb: cdb, sectionSize: sectionSize}, sectionSize, confirmReq, time.Millisecond*100, "cht")
+	idb := ethdb.NewTable(db, "chtIndex-")
+	backend := &ChtIndexerBackend{
+		diskdb:      db,
+		triedb:      trie.NewDatabase(ethdb.NewTable(db, ChtTablePrefix)),
+		sectionSize: sectionSize,
+	}
+	return core.NewChainIndexer(db, idb, backend, sectionSize, confirmReq, time.Millisecond*100, "cht")
 }
 
 // Reset implements core.ChainIndexerBackend
 func (c *ChtIndexerBackend) Reset(section uint64, lastSectionHead common.Hash) error {
 	var root common.Hash
 	if section > 0 {
-		root = GetChtRoot(c.db, section-1, lastSectionHead)
+		root = GetChtRoot(c.diskdb, section-1, lastSectionHead)
 	}
 	var err error
-	c.trie, err = trie.New(root, c.cdb)
+	c.trie, err = trie.New(root, c.triedb)
 	c.section = section
 	return err
 }
@@ -151,7 +156,7 @@ func (c *ChtIndexerBackend) Process(header *types.Header) {
 	hash, num := header.Hash(), header.Number.Uint64()
 	c.lastHash = hash
 
-	td := core.GetTd(c.db, hash, num)
+	td := core.GetTd(c.diskdb, hash, num)
 	if td == nil {
 		panic(nil)
 	}
@@ -163,17 +168,16 @@ func (c *ChtIndexerBackend) Process(header *types.Header) {
 
 // Commit implements core.ChainIndexerBackend
 func (c *ChtIndexerBackend) Commit() error {
-	batch := c.cdb.NewBatch()
-	root, err := c.trie.CommitTo(batch)
+	root, err := c.trie.Commit(nil)
 	if err != nil {
 		return err
-	} else {
-		batch.Write()
-		if ((c.section+1)*c.sectionSize)%ChtFrequency == 0 {
-			log.Info("Storing CHT", "idx", c.section*c.sectionSize/ChtFrequency, "sectionHead", fmt.Sprintf("%064x", c.lastHash), "root", fmt.Sprintf("%064x", root))
-		}
-		StoreChtRoot(c.db, c.section, c.lastHash, root)
 	}
+	c.triedb.Commit(root, false)
+
+	if ((c.section+1)*c.sectionSize)%ChtFrequency == 0 {
+		log.Info("Storing CHT", "idx", c.section*c.sectionSize/ChtFrequency, "sectionHead", fmt.Sprintf("%064x", c.lastHash), "root", fmt.Sprintf("%064x", root))
+	}
+	StoreChtRoot(c.diskdb, c.section, c.lastHash, root)
 	return nil
 }
 
@@ -205,7 +209,8 @@ func StoreBloomTrieRoot(db ethdb.Database, sectionIdx uint64, sectionHead, root
 
 // BloomTrieIndexerBackend implements core.ChainIndexerBackend
 type BloomTrieIndexerBackend struct {
-	db, cdb                                    ethdb.Database
+	diskdb                                     ethdb.Database
+	triedb                                     *trie.Database
 	section, parentSectionSize, bloomTrieRatio uint64
 	trie                                       *trie.Trie
 	sectionHeads                               []common.Hash
@@ -213,9 +218,12 @@ type BloomTrieIndexerBackend struct {
 
 // NewBloomTrieIndexer creates a BloomTrie chain indexer
 func NewBloomTrieIndexer(db ethdb.Database, clientMode bool) *core.ChainIndexer {
-	cdb := ethdb.NewTable(db, BloomTrieTablePrefix)
+	backend := &BloomTrieIndexerBackend{
+		diskdb: db,
+		triedb: trie.NewDatabase(ethdb.NewTable(db, BloomTrieTablePrefix)),
+	}
 	idb := ethdb.NewTable(db, "bltIndex-")
-	backend := &BloomTrieIndexerBackend{db: db, cdb: cdb}
+
 	var confirmReq uint64
 	if clientMode {
 		backend.parentSectionSize = BloomTrieFrequency
@@ -233,10 +241,10 @@ func NewBloomTrieIndexer(db ethdb.Database, clientMode bool) *core.ChainIndexer
 func (b *BloomTrieIndexerBackend) Reset(section uint64, lastSectionHead common.Hash) error {
 	var root common.Hash
 	if section > 0 {
-		root = GetBloomTrieRoot(b.db, section-1, lastSectionHead)
+		root = GetBloomTrieRoot(b.diskdb, section-1, lastSectionHead)
 	}
 	var err error
-	b.trie, err = trie.New(root, b.cdb)
+	b.trie, err = trie.New(root, b.triedb)
 	b.section = section
 	return err
 }
@@ -259,7 +267,7 @@ func (b *BloomTrieIndexerBackend) Commit() error {
 		binary.BigEndian.PutUint64(encKey[2:10], b.section)
 		var decomp []byte
 		for j := uint64(0); j < b.bloomTrieRatio; j++ {
-			data, err := core.GetBloomBits(b.db, i, b.section*b.bloomTrieRatio+j, b.sectionHeads[j])
+			data, err := core.GetBloomBits(b.diskdb, i, b.section*b.bloomTrieRatio+j, b.sectionHeads[j])
 			if err != nil {
 				return err
 			}
@@ -279,17 +287,15 @@ func (b *BloomTrieIndexerBackend) Commit() error {
 			b.trie.Delete(encKey[:])
 		}
 	}
-
-	batch := b.cdb.NewBatch()
-	root, err := b.trie.CommitTo(batch)
+	root, err := b.trie.Commit(nil)
 	if err != nil {
 		return err
-	} else {
-		batch.Write()
-		sectionHead := b.sectionHeads[b.bloomTrieRatio-1]
-		log.Info("Storing BloomTrie", "section", b.section, "sectionHead", fmt.Sprintf("%064x", sectionHead), "root", fmt.Sprintf("%064x", root), "compression ratio", float64(compSize)/float64(decompSize))
-		StoreBloomTrieRoot(b.db, b.section, sectionHead, root)
 	}
+	b.triedb.Commit(root, false)
+
+	sectionHead := b.sectionHeads[b.bloomTrieRatio-1]
+	log.Info("Storing BloomTrie", "section", b.section, "sectionHead", fmt.Sprintf("%064x", sectionHead), "root", fmt.Sprintf("%064x", root), "compression ratio", float64(compSize)/float64(decompSize))
+	StoreBloomTrieRoot(b.diskdb, b.section, sectionHead, root)
 
 	return nil
 }
diff --git a/light/trie.go b/light/trie.go
index 7a9c86b98..c07e99461 100644
--- a/light/trie.go
+++ b/light/trie.go
@@ -18,12 +18,14 @@ package light
 
 import (
 	"context"
+	"errors"
 	"fmt"
 
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/core/state"
 	"github.com/ethereum/go-ethereum/core/types"
 	"github.com/ethereum/go-ethereum/crypto"
+	"github.com/ethereum/go-ethereum/ethdb"
 	"github.com/ethereum/go-ethereum/trie"
 )
 
@@ -83,6 +85,10 @@ func (db *odrDatabase) ContractCodeSize(addrHash, codeHash common.Hash) (int, er
 	return len(code), err
 }
 
+func (db *odrDatabase) TrieDB() *trie.Database {
+	return nil
+}
+
 type odrTrie struct {
 	db   *odrDatabase
 	id   *TrieID
@@ -113,11 +119,11 @@ func (t *odrTrie) TryDelete(key []byte) error {
 	})
 }
 
-func (t *odrTrie) CommitTo(db trie.DatabaseWriter) (common.Hash, error) {
+func (t *odrTrie) Commit(onleaf trie.LeafCallback) (common.Hash, error) {
 	if t.trie == nil {
 		return t.id.Root, nil
 	}
-	return t.trie.CommitTo(db)
+	return t.trie.Commit(onleaf)
 }
 
 func (t *odrTrie) Hash() common.Hash {
@@ -135,13 +141,17 @@ func (t *odrTrie) GetKey(sha []byte) []byte {
 	return nil
 }
 
+func (t *odrTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.Putter) error {
+	return errors.New("not implemented, needs client/server interface split")
+}
+
 // do tries and retries to execute a function until it returns with no error or
 // an error type other than MissingNodeError
 func (t *odrTrie) do(key []byte, fn func() error) error {
 	for {
 		var err error
 		if t.trie == nil {
-			t.trie, err = trie.New(t.id.Root, t.db.backend.Database())
+			t.trie, err = trie.New(t.id.Root, trie.NewDatabase(t.db.backend.Database()))
 		}
 		if err == nil {
 			err = fn()
@@ -167,7 +177,7 @@ func newNodeIterator(t *odrTrie, startkey []byte) trie.NodeIterator {
 	// Open the actual non-ODR trie if that hasn't happened yet.
 	if t.trie == nil {
 		it.do(func() error {
-			t, err := trie.New(t.id.Root, t.db.backend.Database())
+			t, err := trie.New(t.id.Root, trie.NewDatabase(t.db.backend.Database()))
 			if err == nil {
 				it.t.trie = t
 			}
diff --git a/light/trie_test.go b/light/trie_test.go
index d99664718..0d6b2cc1d 100644
--- a/light/trie_test.go
+++ b/light/trie_test.go
@@ -40,7 +40,7 @@ func TestNodeIterator(t *testing.T) {
 		genesis    = gspec.MustCommit(fulldb)
 	)
 	gspec.MustCommit(lightdb)
-	blockchain, _ := core.NewBlockChain(fulldb, params.TestChainConfig, ethash.NewFullFaker(), vm.Config{})
+	blockchain, _ := core.NewBlockChain(fulldb, nil, params.TestChainConfig, ethash.NewFullFaker(), vm.Config{})
 	gchain, _ := core.GenerateChain(params.TestChainConfig, genesis, ethash.NewFaker(), fulldb, 4, testChainGen)
 	if _, err := blockchain.InsertChain(gchain); err != nil {
 		panic(err)
diff --git a/light/txpool_test.go b/light/txpool_test.go
index b343f79b0..13d7d3ceb 100644
--- a/light/txpool_test.go
+++ b/light/txpool_test.go
@@ -88,7 +88,7 @@ func TestTxPool(t *testing.T) {
 	)
 	gspec.MustCommit(ldb)
 	// Assemble the test environment
-	blockchain, _ := core.NewBlockChain(sdb, params.TestChainConfig, ethash.NewFullFaker(), vm.Config{})
+	blockchain, _ := core.NewBlockChain(sdb, nil, params.TestChainConfig, ethash.NewFullFaker(), vm.Config{})
 	gchain, _ := core.GenerateChain(params.TestChainConfig, genesis, ethash.NewFaker(), sdb, poolTestBlocks, txPoolTestChainGen)
 	if _, err := blockchain.InsertChain(gchain); err != nil {
 		panic(err)
diff --git a/miner/worker.go b/miner/worker.go
index 1520277e1..15395ae0b 100644
--- a/miner/worker.go
+++ b/miner/worker.go
@@ -309,7 +309,7 @@ func (self *worker) wait() {
 			for _, log := range work.state.Logs() {
 				log.BlockHash = block.Hash()
 			}
-			stat, err := self.chain.WriteBlockAndState(block, work.receipts, work.state)
+			stat, err := self.chain.WriteBlockWithState(block, work.receipts, work.state)
 			if err != nil {
 				log.Error("Failed writing block to chain", "err", err)
 				continue
diff --git a/tests/block_test_util.go b/tests/block_test_util.go
index 4bfd6433f..beba48483 100644
--- a/tests/block_test_util.go
+++ b/tests/block_test_util.go
@@ -110,7 +110,7 @@ func (t *BlockTest) Run() error {
 		return fmt.Errorf("genesis block state root does not match test: computed=%x, test=%x", gblock.Root().Bytes()[:6], t.json.Genesis.StateRoot[:6])
 	}
 
-	chain, err := core.NewBlockChain(db, config, ethash.NewShared(), vm.Config{})
+	chain, err := core.NewBlockChain(db, nil, config, ethash.NewShared(), vm.Config{})
 	if err != nil {
 		return err
 	}
diff --git a/tests/state_test_util.go b/tests/state_test_util.go
index 78c05b024..18280d2a4 100644
--- a/tests/state_test_util.go
+++ b/tests/state_test_util.go
@@ -125,7 +125,7 @@ func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config) (*state.StateD
 	if !ok {
 		return nil, UnsupportedForkError{subtest.Fork}
 	}
-	block, _ := t.genesis(config).ToBlock()
+	block := t.genesis(config).ToBlock(nil)
 	db, _ := ethdb.NewMemDatabase()
 	statedb := MakePreState(db, t.json.Pre)
 
@@ -147,7 +147,7 @@ func (t *StateTest) Run(subtest StateSubtest, vmconfig vm.Config) (*state.StateD
 	if logs := rlpHash(statedb.Logs()); logs != common.Hash(post.Logs) {
 		return statedb, fmt.Errorf("post state logs hash mismatch: got %x, want %x", logs, post.Logs)
 	}
-	root, _ := statedb.CommitTo(db, config.IsEIP158(block.Number()))
+	root, _ := statedb.Commit(config.IsEIP158(block.Number()))
 	if root != common.Hash(post.Root) {
 		return statedb, fmt.Errorf("post state root mismatch: got %x, want %x", root, post.Root)
 	}
@@ -170,7 +170,7 @@ func MakePreState(db ethdb.Database, accounts core.GenesisAlloc) *state.StateDB
 		}
 	}
 	// Commit and re-open to start with a clean state.
-	root, _ := statedb.CommitTo(db, false)
+	root, _ := statedb.Commit(false)
 	statedb, _ = state.New(root, sdb)
 	return statedb
 }
diff --git a/trie/database.go b/trie/database.go
new file mode 100644
index 000000000..d79120813
--- /dev/null
+++ b/trie/database.go
@@ -0,0 +1,355 @@
+// Copyright 2017 The go-ethereum Authors
+// This file is part of the go-ethereum library.
+//
+// The go-ethereum library is free software: you can redistribute it and/or modify
+// it under the terms of the GNU Lesser General Public License as published by
+// the Free Software Foundation, either version 3 of the License, or
+// (at your option) any later version.
+//
+// The go-ethereum library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+// GNU Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public License
+// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
+
+package trie
+
+import (
+	"sync"
+	"time"
+
+	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/ethdb"
+	"github.com/ethereum/go-ethereum/log"
+)
+
+// secureKeyPrefix is the database key prefix used to store trie node preimages.
+var secureKeyPrefix = []byte("secure-key-")
+
+// secureKeyLength is the length of the above prefix + 32byte hash.
+const secureKeyLength = 11 + 32
+
+// DatabaseReader wraps the Get and Has method of a backing store for the trie.
+type DatabaseReader interface {
+	// Get retrieves the value associated with key form the database.
+	Get(key []byte) (value []byte, err error)
+
+	// Has retrieves whether a key is present in the database.
+	Has(key []byte) (bool, error)
+}
+
+// Database is an intermediate write layer between the trie data structures and
+// the disk database. The aim is to accumulate trie writes in-memory and only
+// periodically flush a couple tries to disk, garbage collecting the remainder.
+type Database struct {
+	diskdb ethdb.Database // Persistent storage for matured trie nodes
+
+	nodes     map[common.Hash]*cachedNode // Data and references relationships of a node
+	preimages map[common.Hash][]byte      // Preimages of nodes from the secure trie
+	seckeybuf [secureKeyLength]byte       // Ephemeral buffer for calculating preimage keys
+
+	gctime  time.Duration      // Time spent on garbage collection since last commit
+	gcnodes uint64             // Nodes garbage collected since last commit
+	gcsize  common.StorageSize // Data storage garbage collected since last commit
+
+	nodesSize     common.StorageSize // Storage size of the nodes cache
+	preimagesSize common.StorageSize // Storage size of the preimages cache
+
+	lock sync.RWMutex
+}
+
+// cachedNode is all the information we know about a single cached node in the
+// memory database write layer.
+type cachedNode struct {
+	blob     []byte              // Cached data block of the trie node
+	parents  int                 // Number of live nodes referencing this one
+	children map[common.Hash]int // Children referenced by this nodes
+}
+
+// NewDatabase creates a new trie database to store ephemeral trie content before
+// its written out to disk or garbage collected.
+func NewDatabase(diskdb ethdb.Database) *Database {
+	return &Database{
+		diskdb: diskdb,
+		nodes: map[common.Hash]*cachedNode{
+			{}: {children: make(map[common.Hash]int)},
+		},
+		preimages: make(map[common.Hash][]byte),
+	}
+}
+
+// DiskDB retrieves the persistent storage backing the trie database.
+func (db *Database) DiskDB() DatabaseReader {
+	return db.diskdb
+}
+
+// Insert writes a new trie node to the memory database if it's yet unknown. The
+// method will make a copy of the slice.
+func (db *Database) Insert(hash common.Hash, blob []byte) {
+	db.lock.Lock()
+	defer db.lock.Unlock()
+
+	db.insert(hash, blob)
+}
+
+// insert is the private locked version of Insert.
+func (db *Database) insert(hash common.Hash, blob []byte) {
+	if _, ok := db.nodes[hash]; ok {
+		return
+	}
+	db.nodes[hash] = &cachedNode{
+		blob:     common.CopyBytes(blob),
+		children: make(map[common.Hash]int),
+	}
+	db.nodesSize += common.StorageSize(common.HashLength + len(blob))
+}
+
+// insertPreimage writes a new trie node pre-image to the memory database if it's
+// yet unknown. The method will make a copy of the slice.
+//
+// Note, this method assumes that the database's lock is held!
+func (db *Database) insertPreimage(hash common.Hash, preimage []byte) {
+	if _, ok := db.preimages[hash]; ok {
+		return
+	}
+	db.preimages[hash] = common.CopyBytes(preimage)
+	db.preimagesSize += common.StorageSize(common.HashLength + len(preimage))
+}
+
+// Node retrieves a cached trie node from memory. If it cannot be found cached,
+// the method queries the persistent database for the content.
+func (db *Database) Node(hash common.Hash) ([]byte, error) {
+	// Retrieve the node from cache if available
+	db.lock.RLock()
+	node := db.nodes[hash]
+	db.lock.RUnlock()
+
+	if node != nil {
+		return node.blob, nil
+	}
+	// Content unavailable in memory, attempt to retrieve from disk
+	return db.diskdb.Get(hash[:])
+}
+
+// preimage retrieves a cached trie node pre-image from memory. If it cannot be
+// found cached, the method queries the persistent database for the content.
+func (db *Database) preimage(hash common.Hash) ([]byte, error) {
+	// Retrieve the node from cache if available
+	db.lock.RLock()
+	preimage := db.preimages[hash]
+	db.lock.RUnlock()
+
+	if preimage != nil {
+		return preimage, nil
+	}
+	// Content unavailable in memory, attempt to retrieve from disk
+	return db.diskdb.Get(db.secureKey(hash[:]))
+}
+
+// secureKey returns the database key for the preimage of key, as an ephemeral
+// buffer. The caller must not hold onto the return value because it will become
+// invalid on the next call.
+func (db *Database) secureKey(key []byte) []byte {
+	buf := append(db.seckeybuf[:0], secureKeyPrefix...)
+	buf = append(buf, key...)
+	return buf
+}
+
+// Nodes retrieves the hashes of all the nodes cached within the memory database.
+// This method is extremely expensive and should only be used to validate internal
+// states in test code.
+func (db *Database) Nodes() []common.Hash {
+	db.lock.RLock()
+	defer db.lock.RUnlock()
+
+	var hashes = make([]common.Hash, 0, len(db.nodes))
+	for hash := range db.nodes {
+		if hash != (common.Hash{}) { // Special case for "root" references/nodes
+			hashes = append(hashes, hash)
+		}
+	}
+	return hashes
+}
+
+// Reference adds a new reference from a parent node to a child node.
+func (db *Database) Reference(child common.Hash, parent common.Hash) {
+	db.lock.RLock()
+	defer db.lock.RUnlock()
+
+	db.reference(child, parent)
+}
+
+// reference is the private locked version of Reference.
+func (db *Database) reference(child common.Hash, parent common.Hash) {
+	// If the node does not exist, it's a node pulled from disk, skip
+	node, ok := db.nodes[child]
+	if !ok {
+		return
+	}
+	// If the reference already exists, only duplicate for roots
+	if _, ok = db.nodes[parent].children[child]; ok && parent != (common.Hash{}) {
+		return
+	}
+	node.parents++
+	db.nodes[parent].children[child]++
+}
+
+// Dereference removes an existing reference from a parent node to a child node.
+func (db *Database) Dereference(child common.Hash, parent common.Hash) {
+	db.lock.Lock()
+	defer db.lock.Unlock()
+
+	nodes, storage, start := len(db.nodes), db.nodesSize, time.Now()
+	db.dereference(child, parent)
+
+	db.gcnodes += uint64(nodes - len(db.nodes))
+	db.gcsize += storage - db.nodesSize
+	db.gctime += time.Since(start)
+
+	log.Debug("Dereferenced trie from memory database", "nodes", nodes-len(db.nodes), "size", storage-db.nodesSize, "time", time.Since(start),
+		"gcnodes", db.gcnodes, "gcsize", db.gcsize, "gctime", db.gctime, "livenodes", len(db.nodes), "livesize", db.nodesSize)
+}
+
+// dereference is the private locked version of Dereference.
+func (db *Database) dereference(child common.Hash, parent common.Hash) {
+	// Dereference the parent-child
+	node := db.nodes[parent]
+
+	node.children[child]--
+	if node.children[child] == 0 {
+		delete(node.children, child)
+	}
+	// If the node does not exist, it's a previously committed node.
+	node, ok := db.nodes[child]
+	if !ok {
+		return
+	}
+	// If there are no more references to the child, delete it and cascade
+	node.parents--
+	if node.parents == 0 {
+		for hash := range node.children {
+			db.dereference(hash, child)
+		}
+		delete(db.nodes, child)
+		db.nodesSize -= common.StorageSize(common.HashLength + len(node.blob))
+	}
+}
+
+// Commit iterates over all the children of a particular node, writes them out
+// to disk, forcefully tearing down all references in both directions.
+//
+// As a side effect, all pre-images accumulated up to this point are also written.
+func (db *Database) Commit(node common.Hash, report bool) error {
+	// Create a database batch to flush persistent data out. It is important that
+	// outside code doesn't see an inconsistent state (referenced data removed from
+	// memory cache during commit but not yet in persistent storage). This is ensured
+	// by only uncaching existing data when the database write finalizes.
+	db.lock.RLock()
+
+	start := time.Now()
+	batch := db.diskdb.NewBatch()
+
+	// Move all of the accumulated preimages into a write batch
+	for hash, preimage := range db.preimages {
+		if err := batch.Put(db.secureKey(hash[:]), preimage); err != nil {
+			log.Error("Failed to commit preimage from trie database", "err", err)
+			db.lock.RUnlock()
+			return err
+		}
+		if batch.ValueSize() > ethdb.IdealBatchSize {
+			if err := batch.Write(); err != nil {
+				return err
+			}
+			batch.Reset()
+		}
+	}
+	// Move the trie itself into the batch, flushing if enough data is accumulated
+	nodes, storage := len(db.nodes), db.nodesSize+db.preimagesSize
+	if err := db.commit(node, batch); err != nil {
+		log.Error("Failed to commit trie from trie database", "err", err)
+		db.lock.RUnlock()
+		return err
+	}
+	// Write batch ready, unlock for readers during persistence
+	if err := batch.Write(); err != nil {
+		log.Error("Failed to write trie to disk", "err", err)
+		db.lock.RUnlock()
+		return err
+	}
+	db.lock.RUnlock()
+
+	// Write successful, clear out the flushed data
+	db.lock.Lock()
+	defer db.lock.Unlock()
+
+	db.preimages = make(map[common.Hash][]byte)
+	db.preimagesSize = 0
+
+	db.uncache(node)
+
+	logger := log.Info
+	if !report {
+		logger = log.Debug
+	}
+	logger("Persisted trie from memory database", "nodes", nodes-len(db.nodes), "size", storage-db.nodesSize, "time", time.Since(start),
+		"gcnodes", db.gcnodes, "gcsize", db.gcsize, "gctime", db.gctime, "livenodes", len(db.nodes), "livesize", db.nodesSize)
+
+	// Reset the garbage collection statistics
+	db.gcnodes, db.gcsize, db.gctime = 0, 0, 0
+
+	return nil
+}
+
+// commit is the private locked version of Commit.
+func (db *Database) commit(hash common.Hash, batch ethdb.Batch) error {
+	// If the node does not exist, it's a previously committed node
+	node, ok := db.nodes[hash]
+	if !ok {
+		return nil
+	}
+	for child := range node.children {
+		if err := db.commit(child, batch); err != nil {
+			return err
+		}
+	}
+	if err := batch.Put(hash[:], node.blob); err != nil {
+		return err
+	}
+	// If we've reached an optimal match size, commit and start over
+	if batch.ValueSize() >= ethdb.IdealBatchSize {
+		if err := batch.Write(); err != nil {
+			return err
+		}
+		batch.Reset()
+	}
+	return nil
+}
+
+// uncache is the post-processing step of a commit operation where the already
+// persisted trie is removed from the cache. The reason behind the two-phase
+// commit is to ensure consistent data availability while moving from memory
+// to disk.
+func (db *Database) uncache(hash common.Hash) {
+	// If the node does not exist, we're done on this path
+	node, ok := db.nodes[hash]
+	if !ok {
+		return
+	}
+	// Otherwise uncache the node's subtries and remove the node itself too
+	for child := range node.children {
+		db.uncache(child)
+	}
+	delete(db.nodes, hash)
+	db.nodesSize -= common.StorageSize(common.HashLength + len(node.blob))
+}
+
+// Size returns the current storage size of the memory cache in front of the
+// persistent database layer.
+func (db *Database) Size() common.StorageSize {
+	db.lock.RLock()
+	defer db.lock.RUnlock()
+
+	return db.nodesSize + db.preimagesSize
+}
diff --git a/trie/hasher.go b/trie/hasher.go
index 4719aabf6..2fc44787a 100644
--- a/trie/hasher.go
+++ b/trie/hasher.go
@@ -27,21 +27,23 @@ import (
 )
 
 type hasher struct {
-	tmp                  *bytes.Buffer
-	sha                  hash.Hash
-	cachegen, cachelimit uint16
+	tmp        *bytes.Buffer
+	sha        hash.Hash
+	cachegen   uint16
+	cachelimit uint16
+	onleaf     LeafCallback
 }
 
-// hashers live in a global pool.
+// hashers live in a global db.
 var hasherPool = sync.Pool{
 	New: func() interface{} {
 		return &hasher{tmp: new(bytes.Buffer), sha: sha3.NewKeccak256()}
 	},
 }
 
-func newHasher(cachegen, cachelimit uint16) *hasher {
+func newHasher(cachegen, cachelimit uint16, onleaf LeafCallback) *hasher {
 	h := hasherPool.Get().(*hasher)
-	h.cachegen, h.cachelimit = cachegen, cachelimit
+	h.cachegen, h.cachelimit, h.onleaf = cachegen, cachelimit, onleaf
 	return h
 }
 
@@ -51,7 +53,7 @@ func returnHasherToPool(h *hasher) {
 
 // hash collapses a node down into a hash node, also returning a copy of the
 // original node initialized with the computed hash to replace the original one.
-func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error) {
+func (h *hasher) hash(n node, db *Database, force bool) (node, node, error) {
 	// If we're not storing the node, just hashing, use available cached data
 	if hash, dirty := n.cache(); hash != nil {
 		if db == nil {
@@ -98,7 +100,7 @@ func (h *hasher) hash(n node, db DatabaseWriter, force bool) (node, node, error)
 // hashChildren replaces the children of a node with their hashes if the encoded
 // size of the child is larger than a hash, returning the collapsed node as well
 // as a replacement for the original node with the child hashes cached in.
-func (h *hasher) hashChildren(original node, db DatabaseWriter) (node, node, error) {
+func (h *hasher) hashChildren(original node, db *Database) (node, node, error) {
 	var err error
 
 	switch n := original.(type) {
@@ -145,7 +147,10 @@ func (h *hasher) hashChildren(original node, db DatabaseWriter) (node, node, err
 	}
 }
 
-func (h *hasher) store(n node, db DatabaseWriter, force bool) (node, error) {
+// store hashes the node n and if we have a storage layer specified, it writes
+// the key/value pair to it and tracks any node->child references as well as any
+// node->external trie references.
+func (h *hasher) store(n node, db *Database, force bool) (node, error) {
 	// Don't store hashes or empty nodes.
 	if _, isHash := n.(hashNode); n == nil || isHash {
 		return n, nil
@@ -155,7 +160,6 @@ func (h *hasher) store(n node, db DatabaseWriter, force bool) (node, error) {
 	if err := rlp.Encode(h.tmp, n); err != nil {
 		panic("encode error: " + err.Error())
 	}
-
 	if h.tmp.Len() < 32 && !force {
 		return n, nil // Nodes smaller than 32 bytes are stored inside their parent
 	}
@@ -167,7 +171,42 @@ func (h *hasher) store(n node, db DatabaseWriter, force bool) (node, error) {
 		hash = hashNode(h.sha.Sum(nil))
 	}
 	if db != nil {
-		return hash, db.Put(hash, h.tmp.Bytes())
+		// We are pooling the trie nodes into an intermediate memory cache
+		db.lock.Lock()
+
+		hash := common.BytesToHash(hash)
+		db.insert(hash, h.tmp.Bytes())
+
+		// Track all direct parent->child node references
+		switch n := n.(type) {
+		case *shortNode:
+			if child, ok := n.Val.(hashNode); ok {
+				db.reference(common.BytesToHash(child), hash)
+			}
+		case *fullNode:
+			for i := 0; i < 16; i++ {
+				if child, ok := n.Children[i].(hashNode); ok {
+					db.reference(common.BytesToHash(child), hash)
+				}
+			}
+		}
+		db.lock.Unlock()
+
+		// Track external references from account->storage trie
+		if h.onleaf != nil {
+			switch n := n.(type) {
+			case *shortNode:
+				if child, ok := n.Val.(valueNode); ok {
+					h.onleaf(child, hash)
+				}
+			case *fullNode:
+				for i := 0; i < 16; i++ {
+					if child, ok := n.Children[i].(valueNode); ok {
+						h.onleaf(child, hash)
+					}
+				}
+			}
+		}
 	}
 	return hash, nil
 }
diff --git a/trie/iterator_test.go b/trie/iterator_test.go
index 4808d8b0c..dce1c78b5 100644
--- a/trie/iterator_test.go
+++ b/trie/iterator_test.go
@@ -42,7 +42,7 @@ func TestIterator(t *testing.T) {
 		all[val.k] = val.v
 		trie.Update([]byte(val.k), []byte(val.v))
 	}
-	trie.Commit()
+	trie.Commit(nil)
 
 	found := make(map[string]string)
 	it := NewIterator(trie.NodeIterator(nil))
@@ -109,11 +109,18 @@ func TestNodeIteratorCoverage(t *testing.T) {
 	}
 	// Cross check the hashes and the database itself
 	for hash := range hashes {
-		if _, err := db.Get(hash.Bytes()); err != nil {
+		if _, err := db.Node(hash); err != nil {
 			t.Errorf("failed to retrieve reported node %x: %v", hash, err)
 		}
 	}
-	for _, key := range db.(*ethdb.MemDatabase).Keys() {
+	for hash, obj := range db.nodes {
+		if obj != nil && hash != (common.Hash{}) {
+			if _, ok := hashes[hash]; !ok {
+				t.Errorf("state entry not reported %x", hash)
+			}
+		}
+	}
+	for _, key := range db.diskdb.(*ethdb.MemDatabase).Keys() {
 		if _, ok := hashes[common.BytesToHash(key)]; !ok {
 			t.Errorf("state entry not reported %x", key)
 		}
@@ -191,13 +198,13 @@ func TestDifferenceIterator(t *testing.T) {
 	for _, val := range testdata1 {
 		triea.Update([]byte(val.k), []byte(val.v))
 	}
-	triea.Commit()
+	triea.Commit(nil)
 
 	trieb := newEmpty()
 	for _, val := range testdata2 {
 		trieb.Update([]byte(val.k), []byte(val.v))
 	}
-	trieb.Commit()
+	trieb.Commit(nil)
 
 	found := make(map[string]string)
 	di, _ := NewDifferenceIterator(triea.NodeIterator(nil), trieb.NodeIterator(nil))
@@ -227,13 +234,13 @@ func TestUnionIterator(t *testing.T) {
 	for _, val := range testdata1 {
 		triea.Update([]byte(val.k), []byte(val.v))
 	}
-	triea.Commit()
+	triea.Commit(nil)
 
 	trieb := newEmpty()
 	for _, val := range testdata2 {
 		trieb.Update([]byte(val.k), []byte(val.v))
 	}
-	trieb.Commit()
+	trieb.Commit(nil)
 
 	di, _ := NewUnionIterator([]NodeIterator{triea.NodeIterator(nil), trieb.NodeIterator(nil)})
 	it := NewIterator(di)
@@ -278,43 +285,75 @@ func TestIteratorNoDups(t *testing.T) {
 }
 
 // This test checks that nodeIterator.Next can be retried after inserting missing trie nodes.
-func TestIteratorContinueAfterError(t *testing.T) {
-	db, _ := ethdb.NewMemDatabase()
-	tr, _ := New(common.Hash{}, db)
+func TestIteratorContinueAfterErrorDisk(t *testing.T)    { testIteratorContinueAfterError(t, false) }
+func TestIteratorContinueAfterErrorMemonly(t *testing.T) { testIteratorContinueAfterError(t, true) }
+
+func testIteratorContinueAfterError(t *testing.T, memonly bool) {
+	diskdb, _ := ethdb.NewMemDatabase()
+	triedb := NewDatabase(diskdb)
+
+	tr, _ := New(common.Hash{}, triedb)
 	for _, val := range testdata1 {
 		tr.Update([]byte(val.k), []byte(val.v))
 	}
-	tr.Commit()
+	tr.Commit(nil)
+	if !memonly {
+		triedb.Commit(tr.Hash(), true)
+	}
 	wantNodeCount := checkIteratorNoDups(t, tr.NodeIterator(nil), nil)
-	keys := db.Keys()
-	t.Log("node count", wantNodeCount)
 
+	var (
+		diskKeys [][]byte
+		memKeys  []common.Hash
+	)
+	if memonly {
+		memKeys = triedb.Nodes()
+	} else {
+		diskKeys = diskdb.Keys()
+	}
 	for i := 0; i < 20; i++ {
 		// Create trie that will load all nodes from DB.
-		tr, _ := New(tr.Hash(), db)
+		tr, _ := New(tr.Hash(), triedb)
 
 		// Remove a random node from the database. It can't be the root node
 		// because that one is already loaded.
-		var rkey []byte
+		var (
+			rkey common.Hash
+			rval []byte
+			robj *cachedNode
+		)
 		for {
-			if rkey = keys[rand.Intn(len(keys))]; !bytes.Equal(rkey, tr.Hash().Bytes()) {
+			if memonly {
+				rkey = memKeys[rand.Intn(len(memKeys))]
+			} else {
+				copy(rkey[:], diskKeys[rand.Intn(len(diskKeys))])
+			}
+			if rkey != tr.Hash() {
 				break
 			}
 		}
-		rval, _ := db.Get(rkey)
-		db.Delete(rkey)
-
+		if memonly {
+			robj = triedb.nodes[rkey]
+			delete(triedb.nodes, rkey)
+		} else {
+			rval, _ = diskdb.Get(rkey[:])
+			diskdb.Delete(rkey[:])
+		}
 		// Iterate until the error is hit.
 		seen := make(map[string]bool)
 		it := tr.NodeIterator(nil)
 		checkIteratorNoDups(t, it, seen)
 		missing, ok := it.Error().(*MissingNodeError)
-		if !ok || !bytes.Equal(missing.NodeHash[:], rkey) {
+		if !ok || missing.NodeHash != rkey {
 			t.Fatal("didn't hit missing node, got", it.Error())
 		}
 
 		// Add the node back and continue iteration.
-		db.Put(rkey, rval)
+		if memonly {
+			triedb.nodes[rkey] = robj
+		} else {
+			diskdb.Put(rkey[:], rval)
+		}
 		checkIteratorNoDups(t, it, seen)
 		if it.Error() != nil {
 			t.Fatal("unexpected error", it.Error())
@@ -328,21 +367,41 @@ func TestIteratorContinueAfterError(t *testing.T) {
 // Similar to the test above, this one checks that failure to create nodeIterator at a
 // certain key prefix behaves correctly when Next is called. The expectation is that Next
 // should retry seeking before returning true for the first time.
-func TestIteratorContinueAfterSeekError(t *testing.T) {
+func TestIteratorContinueAfterSeekErrorDisk(t *testing.T) {
+	testIteratorContinueAfterSeekError(t, false)
+}
+func TestIteratorContinueAfterSeekErrorMemonly(t *testing.T) {
+	testIteratorContinueAfterSeekError(t, true)
+}
+
+func testIteratorContinueAfterSeekError(t *testing.T, memonly bool) {
 	// Commit test trie to db, then remove the node containing "bars".
-	db, _ := ethdb.NewMemDatabase()
-	ctr, _ := New(common.Hash{}, db)
+	diskdb, _ := ethdb.NewMemDatabase()
+	triedb := NewDatabase(diskdb)
+
+	ctr, _ := New(common.Hash{}, triedb)
 	for _, val := range testdata1 {
 		ctr.Update([]byte(val.k), []byte(val.v))
 	}
-	root, _ := ctr.Commit()
+	root, _ := ctr.Commit(nil)
+	if !memonly {
+		triedb.Commit(root, true)
+	}
 	barNodeHash := common.HexToHash("05041990364eb72fcb1127652ce40d8bab765f2bfe53225b1170d276cc101c2e")
-	barNode, _ := db.Get(barNodeHash[:])
-	db.Delete(barNodeHash[:])
-
+	var (
+		barNodeBlob []byte
+		barNodeObj  *cachedNode
+	)
+	if memonly {
+		barNodeObj = triedb.nodes[barNodeHash]
+		delete(triedb.nodes, barNodeHash)
+	} else {
+		barNodeBlob, _ = diskdb.Get(barNodeHash[:])
+		diskdb.Delete(barNodeHash[:])
+	}
 	// Create a new iterator that seeks to "bars". Seeking can't proceed because
 	// the node is missing.
-	tr, _ := New(root, db)
+	tr, _ := New(root, triedb)
 	it := tr.NodeIterator([]byte("bars"))
 	missing, ok := it.Error().(*MissingNodeError)
 	if !ok {
@@ -350,10 +409,12 @@ func TestIteratorContinueAfterSeekError(t *testing.T) {
 	} else if missing.NodeHash != barNodeHash {
 		t.Fatal("wrong node missing")
 	}
-
 	// Reinsert the missing node.
-	db.Put(barNodeHash[:], barNode[:])
-
+	if memonly {
+		triedb.nodes[barNodeHash] = barNodeObj
+	} else {
+		diskdb.Put(barNodeHash[:], barNodeBlob)
+	}
 	// Check that iteration produces the right set of values.
 	if err := checkIteratorOrder(testdata1[2:], NewIterator(it)); err != nil {
 		t.Fatal(err)
diff --git a/trie/proof.go b/trie/proof.go
index 5e886a259..508e4a6cf 100644
--- a/trie/proof.go
+++ b/trie/proof.go
@@ -22,20 +22,19 @@ import (
 
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/crypto"
+	"github.com/ethereum/go-ethereum/ethdb"
 	"github.com/ethereum/go-ethereum/log"
 	"github.com/ethereum/go-ethereum/rlp"
 )
 
-// Prove constructs a merkle proof for key. The result contains all
-// encoded nodes on the path to the value at key. The value itself is
-// also included in the last node and can be retrieved by verifying
-// the proof.
+// Prove constructs a merkle proof for key. The result contains all encoded nodes
+// on the path to the value at key. The value itself is also included in the last
+// node and can be retrieved by verifying the proof.
 //
-// If the trie does not contain a value for key, the returned proof
-// contains all nodes of the longest existing prefix of the key
-// (at least the root node), ending with the node that proves the
-// absence of the key.
-func (t *Trie) Prove(key []byte, fromLevel uint, proofDb DatabaseWriter) error {
+// If the trie does not contain a value for key, the returned proof contains all
+// nodes of the longest existing prefix of the key (at least the root node), ending
+// with the node that proves the absence of the key.
+func (t *Trie) Prove(key []byte, fromLevel uint, proofDb ethdb.Putter) error {
 	// Collect all nodes on the path to key.
 	key = keybytesToHex(key)
 	nodes := []node{}
@@ -66,7 +65,7 @@ func (t *Trie) Prove(key []byte, fromLevel uint, proofDb DatabaseWriter) error {
 			panic(fmt.Sprintf("%T: invalid node: %v", tn, tn))
 		}
 	}
-	hasher := newHasher(0, 0)
+	hasher := newHasher(0, 0, nil)
 	for i, n := range nodes {
 		// Don't bother checking for errors here since hasher panics
 		// if encoding doesn't work and we're not writing to any database.
@@ -89,19 +88,29 @@ func (t *Trie) Prove(key []byte, fromLevel uint, proofDb DatabaseWriter) error {
 	return nil
 }
 
-// VerifyProof checks merkle proofs. The given proof must contain the
-// value for key in a trie with the given root hash. VerifyProof
-// returns an error if the proof contains invalid trie nodes or the
-// wrong value.
+// Prove constructs a merkle proof for key. The result contains all encoded nodes
+// on the path to the value at key. The value itself is also included in the last
+// node and can be retrieved by verifying the proof.
+//
+// If the trie does not contain a value for key, the returned proof contains all
+// nodes of the longest existing prefix of the key (at least the root node), ending
+// with the node that proves the absence of the key.
+func (t *SecureTrie) Prove(key []byte, fromLevel uint, proofDb ethdb.Putter) error {
+	return t.trie.Prove(key, fromLevel, proofDb)
+}
+
+// VerifyProof checks merkle proofs. The given proof must contain the value for
+// key in a trie with the given root hash. VerifyProof returns an error if the
+// proof contains invalid trie nodes or the wrong value.
 func VerifyProof(rootHash common.Hash, key []byte, proofDb DatabaseReader) (value []byte, err error, nodes int) {
 	key = keybytesToHex(key)
-	wantHash := rootHash[:]
+	wantHash := rootHash
 	for i := 0; ; i++ {
-		buf, _ := proofDb.Get(wantHash)
+		buf, _ := proofDb.Get(wantHash[:])
 		if buf == nil {
-			return nil, fmt.Errorf("proof node %d (hash %064x) missing", i, wantHash[:]), i
+			return nil, fmt.Errorf("proof node %d (hash %064x) missing", i, wantHash), i
 		}
-		n, err := decodeNode(wantHash, buf, 0)
+		n, err := decodeNode(wantHash[:], buf, 0)
 		if err != nil {
 			return nil, fmt.Errorf("bad proof node %d: %v", i, err), i
 		}
@@ -112,7 +121,7 @@ func VerifyProof(rootHash common.Hash, key []byte, proofDb DatabaseReader) (valu
 			return nil, nil, i
 		case hashNode:
 			key = keyrest
-			wantHash = cld
+			copy(wantHash[:], cld)
 		case valueNode:
 			return cld, nil, i + 1
 		}
diff --git a/trie/secure_trie.go b/trie/secure_trie.go
index 20c303f31..3881ee18a 100644
--- a/trie/secure_trie.go
+++ b/trie/secure_trie.go
@@ -23,10 +23,6 @@ import (
 	"github.com/ethereum/go-ethereum/log"
 )
 
-var secureKeyPrefix = []byte("secure-key-")
-
-const secureKeyLength = 11 + 32 // Length of the above prefix + 32byte hash
-
 // SecureTrie wraps a trie with key hashing. In a secure trie, all
 // access operations hash the key using keccak256. This prevents
 // calling code from creating long chains of nodes that
@@ -39,25 +35,25 @@ const secureKeyLength = 11 + 32 // Length of the above prefix + 32byte hash
 // SecureTrie is not safe for concurrent use.
 type SecureTrie struct {
 	trie             Trie
-	hashKeyBuf       [secureKeyLength]byte
-	secKeyBuf        [200]byte
+	hashKeyBuf       [common.HashLength]byte
 	secKeyCache      map[string][]byte
 	secKeyCacheOwner *SecureTrie // Pointer to self, replace the key cache on mismatch
 }
 
-// NewSecure creates a trie with an existing root node from db.
+// NewSecure creates a trie with an existing root node from a backing database
+// and optional intermediate in-memory node pool.
 //
 // If root is the zero hash or the sha3 hash of an empty string, the
 // trie is initially empty. Otherwise, New will panic if db is nil
 // and returns MissingNodeError if the root node cannot be found.
 //
-// Accessing the trie loads nodes from db on demand.
+// Accessing the trie loads nodes from the database or node pool on demand.
 // Loaded nodes are kept around until their 'cache generation' expires.
 // A new cache generation is created by each call to Commit.
 // cachelimit sets the number of past cache generations to keep.
-func NewSecure(root common.Hash, db Database, cachelimit uint16) (*SecureTrie, error) {
+func NewSecure(root common.Hash, db *Database, cachelimit uint16) (*SecureTrie, error) {
 	if db == nil {
-		panic("NewSecure called with nil database")
+		panic("trie.NewSecure called without a database")
 	}
 	trie, err := New(root, db)
 	if err != nil {
@@ -135,7 +131,7 @@ func (t *SecureTrie) GetKey(shaKey []byte) []byte {
 	if key, ok := t.getSecKeyCache()[string(shaKey)]; ok {
 		return key
 	}
-	key, _ := t.trie.db.Get(t.secKey(shaKey))
+	key, _ := t.trie.db.preimage(common.BytesToHash(shaKey))
 	return key
 }
 
@@ -144,8 +140,19 @@ func (t *SecureTrie) GetKey(shaKey []byte) []byte {
 //
 // Committing flushes nodes from memory. Subsequent Get calls will load nodes
 // from the database.
-func (t *SecureTrie) Commit() (root common.Hash, err error) {
-	return t.CommitTo(t.trie.db)
+func (t *SecureTrie) Commit(onleaf LeafCallback) (root common.Hash, err error) {
+	// Write all the pre-images to the actual disk database
+	if len(t.getSecKeyCache()) > 0 {
+		t.trie.db.lock.Lock()
+		for hk, key := range t.secKeyCache {
+			t.trie.db.insertPreimage(common.BytesToHash([]byte(hk)), key)
+		}
+		t.trie.db.lock.Unlock()
+
+		t.secKeyCache = make(map[string][]byte)
+	}
+	// Commit the trie to its intermediate node database
+	return t.trie.Commit(onleaf)
 }
 
 func (t *SecureTrie) Hash() common.Hash {
@@ -167,38 +174,11 @@ func (t *SecureTrie) NodeIterator(start []byte) NodeIterator {
 	return t.trie.NodeIterator(start)
 }
 
-// CommitTo writes all nodes and the secure hash pre-images to the given database.
-// Nodes are stored with their sha3 hash as the key.
-//
-// Committing flushes nodes from memory. Subsequent Get calls will load nodes from
-// the trie's database. Calling code must ensure that the changes made to db are
-// written back to the trie's attached database before using the trie.
-func (t *SecureTrie) CommitTo(db DatabaseWriter) (root common.Hash, err error) {
-	if len(t.getSecKeyCache()) > 0 {
-		for hk, key := range t.secKeyCache {
-			if err := db.Put(t.secKey([]byte(hk)), key); err != nil {
-				return common.Hash{}, err
-			}
-		}
-		t.secKeyCache = make(map[string][]byte)
-	}
-	return t.trie.CommitTo(db)
-}
-
-// secKey returns the database key for the preimage of key, as an ephemeral buffer.
-// The caller must not hold onto the return value because it will become
-// invalid on the next call to hashKey or secKey.
-func (t *SecureTrie) secKey(key []byte) []byte {
-	buf := append(t.secKeyBuf[:0], secureKeyPrefix...)
-	buf = append(buf, key...)
-	return buf
-}
-
 // hashKey returns the hash of key as an ephemeral buffer.
 // The caller must not hold onto the return value because it will become
 // invalid on the next call to hashKey or secKey.
 func (t *SecureTrie) hashKey(key []byte) []byte {
-	h := newHasher(0, 0)
+	h := newHasher(0, 0, nil)
 	h.sha.Reset()
 	h.sha.Write(key)
 	buf := h.sha.Sum(t.hashKeyBuf[:0])
diff --git a/trie/secure_trie_test.go b/trie/secure_trie_test.go
index d74102e2a..aedf5a1cd 100644
--- a/trie/secure_trie_test.go
+++ b/trie/secure_trie_test.go
@@ -28,16 +28,20 @@ import (
 )
 
 func newEmptySecure() *SecureTrie {
-	db, _ := ethdb.NewMemDatabase()
-	trie, _ := NewSecure(common.Hash{}, db, 0)
+	diskdb, _ := ethdb.NewMemDatabase()
+	triedb := NewDatabase(diskdb)
+
+	trie, _ := NewSecure(common.Hash{}, triedb, 0)
 	return trie
 }
 
 // makeTestSecureTrie creates a large enough secure trie for testing.
-func makeTestSecureTrie() (ethdb.Database, *SecureTrie, map[string][]byte) {
+func makeTestSecureTrie() (*Database, *SecureTrie, map[string][]byte) {
 	// Create an empty trie
-	db, _ := ethdb.NewMemDatabase()
-	trie, _ := NewSecure(common.Hash{}, db, 0)
+	diskdb, _ := ethdb.NewMemDatabase()
+	triedb := NewDatabase(diskdb)
+
+	trie, _ := NewSecure(common.Hash{}, triedb, 0)
 
 	// Fill it with some arbitrary data
 	content := make(map[string][]byte)
@@ -58,10 +62,10 @@ func makeTestSecureTrie() (ethdb.Database, *SecureTrie, map[string][]byte) {
 			trie.Update(key, val)
 		}
 	}
-	trie.Commit()
+	trie.Commit(nil)
 
 	// Return the generated trie
-	return db, trie, content
+	return triedb, trie, content
 }
 
 func TestSecureDelete(t *testing.T) {
@@ -137,7 +141,7 @@ func TestSecureTrieConcurrency(t *testing.T) {
 					tries[index].Update(key, val)
 				}
 			}
-			tries[index].Commit()
+			tries[index].Commit(nil)
 		}(i)
 	}
 	// Wait for all threads to finish
diff --git a/trie/sync.go b/trie/sync.go
index fea10051f..b573a9f73 100644
--- a/trie/sync.go
+++ b/trie/sync.go
@@ -21,6 +21,7 @@ import (
 	"fmt"
 
 	"github.com/ethereum/go-ethereum/common"
+	"github.com/ethereum/go-ethereum/ethdb"
 	"gopkg.in/karalabe/cookiejar.v2/collections/prque"
 )
 
@@ -42,7 +43,7 @@ type request struct {
 	depth   int        // Depth level within the trie the node is located to prioritise DFS
 	deps    int        // Number of dependencies before allowed to commit this node
 
-	callback TrieSyncLeafCallback // Callback to invoke if a leaf node it reached on this branch
+	callback LeafCallback // Callback to invoke if a leaf node it reached on this branch
 }
 
 // SyncResult is a simple list to return missing nodes along with their request
@@ -67,11 +68,6 @@ func newSyncMemBatch() *syncMemBatch {
 	}
 }
 
-// TrieSyncLeafCallback is a callback type invoked when a trie sync reaches a
-// leaf node. It's used by state syncing to check if the leaf node requires some
-// further data syncing.
-type TrieSyncLeafCallback func(leaf []byte, parent common.Hash) error
-
 // TrieSync is the main state trie synchronisation scheduler, which provides yet
 // unknown trie hashes to retrieve, accepts node data associated with said hashes
 // and reconstructs the trie step by step until all is done.
@@ -83,7 +79,7 @@ type TrieSync struct {
 }
 
 // NewTrieSync creates a new trie data download scheduler.
-func NewTrieSync(root common.Hash, database DatabaseReader, callback TrieSyncLeafCallback) *TrieSync {
+func NewTrieSync(root common.Hash, database DatabaseReader, callback LeafCallback) *TrieSync {
 	ts := &TrieSync{
 		database: database,
 		membatch: newSyncMemBatch(),
@@ -95,7 +91,7 @@ func NewTrieSync(root common.Hash, database DatabaseReader, callback TrieSyncLea
 }
 
 // AddSubTrie registers a new trie to the sync code, rooted at the designated parent.
-func (s *TrieSync) AddSubTrie(root common.Hash, depth int, parent common.Hash, callback TrieSyncLeafCallback) {
+func (s *TrieSync) AddSubTrie(root common.Hash, depth int, parent common.Hash, callback LeafCallback) {
 	// Short circuit if the trie is empty or already known
 	if root == emptyRoot {
 		return
@@ -217,7 +213,7 @@ func (s *TrieSync) Process(results []SyncResult) (bool, int, error) {
 
 // Commit flushes the data stored in the internal membatch out to persistent
 // storage, returning th enumber of items written and any occurred error.
-func (s *TrieSync) Commit(dbw DatabaseWriter) (int, error) {
+func (s *TrieSync) Commit(dbw ethdb.Putter) (int, error) {
 	// Dump the membatch into a database dbw
 	for i, key := range s.membatch.order {
 		if err := dbw.Put(key[:], s.membatch.batch[key]); err != nil {
diff --git a/trie/sync_test.go b/trie/sync_test.go
index ec16a25bd..4a720612b 100644
--- a/trie/sync_test.go
+++ b/trie/sync_test.go
@@ -25,10 +25,11 @@ import (
 )
 
 // makeTestTrie create a sample test trie to test node-wise reconstruction.
-func makeTestTrie() (ethdb.Database, *Trie, map[string][]byte) {
+func makeTestTrie() (*Database, *Trie, map[string][]byte) {
 	// Create an empty trie
-	db, _ := ethdb.NewMemDatabase()
-	trie, _ := New(common.Hash{}, db)
+	diskdb, _ := ethdb.NewMemDatabase()
+	triedb := NewDatabase(diskdb)
+	trie, _ := New(common.Hash{}, triedb)
 
 	// Fill it with some arbitrary data
 	content := make(map[string][]byte)
@@ -49,15 +50,15 @@ func makeTestTrie() (ethdb.Database, *Trie, map[string][]byte) {
 			trie.Update(key, val)
 		}
 	}
-	trie.Commit()
+	trie.Commit(nil)
 
 	// Return the generated trie
-	return db, trie, content
+	return triedb, trie, content
 }
 
 // checkTrieContents cross references a reconstructed trie with an expected data
 // content map.
-func checkTrieContents(t *testing.T, db Database, root []byte, content map[string][]byte) {
+func checkTrieContents(t *testing.T, db *Database, root []byte, content map[string][]byte) {
 	// Check root availability and trie contents
 	trie, err := New(common.BytesToHash(root), db)
 	if err != nil {
@@ -74,7 +75,7 @@ func checkTrieContents(t *testing.T, db Database, root []byte, content map[strin
 }
 
 // checkTrieConsistency checks that all nodes in a trie are indeed present.
-func checkTrieConsistency(db Database, root common.Hash) error {
+func checkTrieConsistency(db *Database, root common.Hash) error {
 	// Create and iterate a trie rooted in a subnode
 	trie, err := New(root, db)
 	if err != nil {
@@ -88,12 +89,18 @@ func checkTrieConsistency(db Database, root common.Hash) error {
 
 // Tests that an empty trie is not scheduled for syncing.
 func TestEmptyTrieSync(t *testing.T) {
-	emptyA, _ := New(common.Hash{}, nil)
-	emptyB, _ := New(emptyRoot, nil)
+	diskdbA, _ := ethdb.NewMemDatabase()
+	triedbA := NewDatabase(diskdbA)
+
+	diskdbB, _ := ethdb.NewMemDatabase()
+	triedbB := NewDatabase(diskdbB)
+
+	emptyA, _ := New(common.Hash{}, triedbA)
+	emptyB, _ := New(emptyRoot, triedbB)
 
 	for i, trie := range []*Trie{emptyA, emptyB} {
-		db, _ := ethdb.NewMemDatabase()
-		if req := NewTrieSync(common.BytesToHash(trie.Root()), db, nil).Missing(1); len(req) != 0 {
+		diskdb, _ := ethdb.NewMemDatabase()
+		if req := NewTrieSync(trie.Hash(), diskdb, nil).Missing(1); len(req) != 0 {
 			t.Errorf("test %d: content requested for empty trie: %v", i, req)
 		}
 	}
@@ -109,14 +116,15 @@ func testIterativeTrieSync(t *testing.T, batch int) {
 	srcDb, srcTrie, srcData := makeTestTrie()
 
 	// Create a destination trie and sync with the scheduler
-	dstDb, _ := ethdb.NewMemDatabase()
-	sched := NewTrieSync(common.BytesToHash(srcTrie.Root()), dstDb, nil)
+	diskdb, _ := ethdb.NewMemDatabase()
+	triedb := NewDatabase(diskdb)
+	sched := NewTrieSync(srcTrie.Hash(), diskdb, nil)
 
 	queue := append([]common.Hash{}, sched.Missing(batch)...)
 	for len(queue) > 0 {
 		results := make([]SyncResult, len(queue))
 		for i, hash := range queue {
-			data, err := srcDb.Get(hash.Bytes())
+			data, err := srcDb.Node(hash)
 			if err != nil {
 				t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
 			}
@@ -125,13 +133,13 @@ func testIterativeTrieSync(t *testing.T, batch int) {
 		if _, index, err := sched.Process(results); err != nil {
 			t.Fatalf("failed to process result #%d: %v", index, err)
 		}
-		if index, err := sched.Commit(dstDb); err != nil {
+		if index, err := sched.Commit(diskdb); err != nil {
 			t.Fatalf("failed to commit data #%d: %v", index, err)
 		}
 		queue = append(queue[:0], sched.Missing(batch)...)
 	}
 	// Cross check that the two tries are in sync
-	checkTrieContents(t, dstDb, srcTrie.Root(), srcData)
+	checkTrieContents(t, triedb, srcTrie.Root(), srcData)
 }
 
 // Tests that the trie scheduler can correctly reconstruct the state even if only
@@ -141,15 +149,16 @@ func TestIterativeDelayedTrieSync(t *testing.T) {
 	srcDb, srcTrie, srcData := makeTestTrie()
 
 	// Create a destination trie and sync with the scheduler
-	dstDb, _ := ethdb.NewMemDatabase()
-	sched := NewTrieSync(common.BytesToHash(srcTrie.Root()), dstDb, nil)
+	diskdb, _ := ethdb.NewMemDatabase()
+	triedb := NewDatabase(diskdb)
+	sched := NewTrieSync(srcTrie.Hash(), diskdb, nil)
 
 	queue := append([]common.Hash{}, sched.Missing(10000)...)
 	for len(queue) > 0 {
 		// Sync only half of the scheduled nodes
 		results := make([]SyncResult, len(queue)/2+1)
 		for i, hash := range queue[:len(results)] {
-			data, err := srcDb.Get(hash.Bytes())
+			data, err := srcDb.Node(hash)
 			if err != nil {
 				t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
 			}
@@ -158,13 +167,13 @@ func TestIterativeDelayedTrieSync(t *testing.T) {
 		if _, index, err := sched.Process(results); err != nil {
 			t.Fatalf("failed to process result #%d: %v", index, err)
 		}
-		if index, err := sched.Commit(dstDb); err != nil {
+		if index, err := sched.Commit(diskdb); err != nil {
 			t.Fatalf("failed to commit data #%d: %v", index, err)
 		}
 		queue = append(queue[len(results):], sched.Missing(10000)...)
 	}
 	// Cross check that the two tries are in sync
-	checkTrieContents(t, dstDb, srcTrie.Root(), srcData)
+	checkTrieContents(t, triedb, srcTrie.Root(), srcData)
 }
 
 // Tests that given a root hash, a trie can sync iteratively on a single thread,
@@ -178,8 +187,9 @@ func testIterativeRandomTrieSync(t *testing.T, batch int) {
 	srcDb, srcTrie, srcData := makeTestTrie()
 
 	// Create a destination trie and sync with the scheduler
-	dstDb, _ := ethdb.NewMemDatabase()
-	sched := NewTrieSync(common.BytesToHash(srcTrie.Root()), dstDb, nil)
+	diskdb, _ := ethdb.NewMemDatabase()
+	triedb := NewDatabase(diskdb)
+	sched := NewTrieSync(srcTrie.Hash(), diskdb, nil)
 
 	queue := make(map[common.Hash]struct{})
 	for _, hash := range sched.Missing(batch) {
@@ -189,7 +199,7 @@ func testIterativeRandomTrieSync(t *testing.T, batch int) {
 		// Fetch all the queued nodes in a random order
 		results := make([]SyncResult, 0, len(queue))
 		for hash := range queue {
-			data, err := srcDb.Get(hash.Bytes())
+			data, err := srcDb.Node(hash)
 			if err != nil {
 				t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
 			}
@@ -199,7 +209,7 @@ func testIterativeRandomTrieSync(t *testing.T, batch int) {
 		if _, index, err := sched.Process(results); err != nil {
 			t.Fatalf("failed to process result #%d: %v", index, err)
 		}
-		if index, err := sched.Commit(dstDb); err != nil {
+		if index, err := sched.Commit(diskdb); err != nil {
 			t.Fatalf("failed to commit data #%d: %v", index, err)
 		}
 		queue = make(map[common.Hash]struct{})
@@ -208,7 +218,7 @@ func testIterativeRandomTrieSync(t *testing.T, batch int) {
 		}
 	}
 	// Cross check that the two tries are in sync
-	checkTrieContents(t, dstDb, srcTrie.Root(), srcData)
+	checkTrieContents(t, triedb, srcTrie.Root(), srcData)
 }
 
 // Tests that the trie scheduler can correctly reconstruct the state even if only
@@ -218,8 +228,9 @@ func TestIterativeRandomDelayedTrieSync(t *testing.T) {
 	srcDb, srcTrie, srcData := makeTestTrie()
 
 	// Create a destination trie and sync with the scheduler
-	dstDb, _ := ethdb.NewMemDatabase()
-	sched := NewTrieSync(common.BytesToHash(srcTrie.Root()), dstDb, nil)
+	diskdb, _ := ethdb.NewMemDatabase()
+	triedb := NewDatabase(diskdb)
+	sched := NewTrieSync(srcTrie.Hash(), diskdb, nil)
 
 	queue := make(map[common.Hash]struct{})
 	for _, hash := range sched.Missing(10000) {
@@ -229,7 +240,7 @@ func TestIterativeRandomDelayedTrieSync(t *testing.T) {
 		// Sync only half of the scheduled nodes, even those in random order
 		results := make([]SyncResult, 0, len(queue)/2+1)
 		for hash := range queue {
-			data, err := srcDb.Get(hash.Bytes())
+			data, err := srcDb.Node(hash)
 			if err != nil {
 				t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
 			}
@@ -243,7 +254,7 @@ func TestIterativeRandomDelayedTrieSync(t *testing.T) {
 		if _, index, err := sched.Process(results); err != nil {
 			t.Fatalf("failed to process result #%d: %v", index, err)
 		}
-		if index, err := sched.Commit(dstDb); err != nil {
+		if index, err := sched.Commit(diskdb); err != nil {
 			t.Fatalf("failed to commit data #%d: %v", index, err)
 		}
 		for _, result := range results {
@@ -254,7 +265,7 @@ func TestIterativeRandomDelayedTrieSync(t *testing.T) {
 		}
 	}
 	// Cross check that the two tries are in sync
-	checkTrieContents(t, dstDb, srcTrie.Root(), srcData)
+	checkTrieContents(t, triedb, srcTrie.Root(), srcData)
 }
 
 // Tests that a trie sync will not request nodes multiple times, even if they
@@ -264,8 +275,9 @@ func TestDuplicateAvoidanceTrieSync(t *testing.T) {
 	srcDb, srcTrie, srcData := makeTestTrie()
 
 	// Create a destination trie and sync with the scheduler
-	dstDb, _ := ethdb.NewMemDatabase()
-	sched := NewTrieSync(common.BytesToHash(srcTrie.Root()), dstDb, nil)
+	diskdb, _ := ethdb.NewMemDatabase()
+	triedb := NewDatabase(diskdb)
+	sched := NewTrieSync(srcTrie.Hash(), diskdb, nil)
 
 	queue := append([]common.Hash{}, sched.Missing(0)...)
 	requested := make(map[common.Hash]struct{})
@@ -273,7 +285,7 @@ func TestDuplicateAvoidanceTrieSync(t *testing.T) {
 	for len(queue) > 0 {
 		results := make([]SyncResult, len(queue))
 		for i, hash := range queue {
-			data, err := srcDb.Get(hash.Bytes())
+			data, err := srcDb.Node(hash)
 			if err != nil {
 				t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
 			}
@@ -287,13 +299,13 @@ func TestDuplicateAvoidanceTrieSync(t *testing.T) {
 		if _, index, err := sched.Process(results); err != nil {
 			t.Fatalf("failed to process result #%d: %v", index, err)
 		}
-		if index, err := sched.Commit(dstDb); err != nil {
+		if index, err := sched.Commit(diskdb); err != nil {
 			t.Fatalf("failed to commit data #%d: %v", index, err)
 		}
 		queue = append(queue[:0], sched.Missing(0)...)
 	}
 	// Cross check that the two tries are in sync
-	checkTrieContents(t, dstDb, srcTrie.Root(), srcData)
+	checkTrieContents(t, triedb, srcTrie.Root(), srcData)
 }
 
 // Tests that at any point in time during a sync, only complete sub-tries are in
@@ -303,8 +315,9 @@ func TestIncompleteTrieSync(t *testing.T) {
 	srcDb, srcTrie, _ := makeTestTrie()
 
 	// Create a destination trie and sync with the scheduler
-	dstDb, _ := ethdb.NewMemDatabase()
-	sched := NewTrieSync(common.BytesToHash(srcTrie.Root()), dstDb, nil)
+	diskdb, _ := ethdb.NewMemDatabase()
+	triedb := NewDatabase(diskdb)
+	sched := NewTrieSync(srcTrie.Hash(), diskdb, nil)
 
 	added := []common.Hash{}
 	queue := append([]common.Hash{}, sched.Missing(1)...)
@@ -312,7 +325,7 @@ func TestIncompleteTrieSync(t *testing.T) {
 		// Fetch a batch of trie nodes
 		results := make([]SyncResult, len(queue))
 		for i, hash := range queue {
-			data, err := srcDb.Get(hash.Bytes())
+			data, err := srcDb.Node(hash)
 			if err != nil {
 				t.Fatalf("failed to retrieve node data for %x: %v", hash, err)
 			}
@@ -322,7 +335,7 @@ func TestIncompleteTrieSync(t *testing.T) {
 		if _, index, err := sched.Process(results); err != nil {
 			t.Fatalf("failed to process result #%d: %v", index, err)
 		}
-		if index, err := sched.Commit(dstDb); err != nil {
+		if index, err := sched.Commit(diskdb); err != nil {
 			t.Fatalf("failed to commit data #%d: %v", index, err)
 		}
 		for _, result := range results {
@@ -330,7 +343,7 @@ func TestIncompleteTrieSync(t *testing.T) {
 		}
 		// Check that all known sub-tries in the synced trie are complete
 		for _, root := range added {
-			if err := checkTrieConsistency(dstDb, root); err != nil {
+			if err := checkTrieConsistency(triedb, root); err != nil {
 				t.Fatalf("trie inconsistent: %v", err)
 			}
 		}
@@ -340,12 +353,12 @@ func TestIncompleteTrieSync(t *testing.T) {
 	// Sanity check that removing any node from the database is detected
 	for _, node := range added[1:] {
 		key := node.Bytes()
-		value, _ := dstDb.Get(key)
+		value, _ := diskdb.Get(key)
 
-		dstDb.Delete(key)
-		if err := checkTrieConsistency(dstDb, added[0]); err == nil {
+		diskdb.Delete(key)
+		if err := checkTrieConsistency(triedb, added[0]); err == nil {
 			t.Fatalf("trie inconsistency not caught, missing: %x", key)
 		}
-		dstDb.Put(key, value)
+		diskdb.Put(key, value)
 	}
 }
diff --git a/trie/trie.go b/trie/trie.go
index 8fe98d835..e37a1ae10 100644
--- a/trie/trie.go
+++ b/trie/trie.go
@@ -22,16 +22,17 @@ import (
 	"fmt"
 
 	"github.com/ethereum/go-ethereum/common"
-	"github.com/ethereum/go-ethereum/crypto/sha3"
+	"github.com/ethereum/go-ethereum/crypto"
 	"github.com/ethereum/go-ethereum/log"
 	"github.com/rcrowley/go-metrics"
 )
 
 var (
-	// This is the known root hash of an empty trie.
+	// emptyRoot is the known root hash of an empty trie.
 	emptyRoot = common.HexToHash("56e81f171bcc55a6ff8345e692c0f86e5b48e01b996cadc001622fb5e363b421")
-	// This is the known hash of an empty state trie entry.
-	emptyState common.Hash
+
+	// emptyState is the known hash of an empty state trie entry.
+	emptyState = crypto.Keccak256Hash(nil)
 )
 
 var (
@@ -53,29 +54,10 @@ func CacheUnloads() int64 {
 	return cacheUnloadCounter.Count()
 }
 
-func init() {
-	sha3.NewKeccak256().Sum(emptyState[:0])
-}
-
-// Database must be implemented by backing stores for the trie.
-type Database interface {
-	DatabaseReader
-	DatabaseWriter
-}
-
-// DatabaseReader wraps the Get method of a backing store for the trie.
-type DatabaseReader interface {
-	Get(key []byte) (value []byte, err error)
-	Has(key []byte) (bool, error)
-}
-
-// DatabaseWriter wraps the Put method of a backing store for the trie.
-type DatabaseWriter interface {
-	// Put stores the mapping key->value in the database.
-	// Implementations must not hold onto the value bytes, the trie
-	// will reuse the slice across calls to Put.
-	Put(key, value []byte) error
-}
+// LeafCallback is a callback type invoked when a trie operation reaches a leaf
+// node. It's used by state sync and commit to allow handling external references
+// between account and storage tries.
+type LeafCallback func(leaf []byte, parent common.Hash) error
 
 // Trie is a Merkle Patricia Trie.
 // The zero value is an empty trie with no database.
@@ -83,8 +65,8 @@ type DatabaseWriter interface {
 //
 // Trie is not safe for concurrent use.
 type Trie struct {
+	db           *Database
 	root         node
-	db           Database
 	originalRoot common.Hash
 
 	// Cache generation values.
@@ -111,12 +93,15 @@ func (t *Trie) newFlag() nodeFlag {
 // trie is initially empty and does not require a database. Otherwise,
 // New will panic if db is nil and returns a MissingNodeError if root does
 // not exist in the database. Accessing the trie loads nodes from db on demand.
-func New(root common.Hash, db Database) (*Trie, error) {
-	trie := &Trie{db: db, originalRoot: root}
+func New(root common.Hash, db *Database) (*Trie, error) {
+	if db == nil {
+		panic("trie.New called without a database")
+	}
+	trie := &Trie{
+		db:           db,
+		originalRoot: root,
+	}
 	if (root != common.Hash{}) && root != emptyRoot {
-		if db == nil {
-			panic("trie.New: cannot use existing root without a database")
-		}
 		rootnode, err := trie.resolveHash(root[:], nil)
 		if err != nil {
 			return nil, err
@@ -447,12 +432,13 @@ func (t *Trie) resolve(n node, prefix []byte) (node, error) {
 func (t *Trie) resolveHash(n hashNode, prefix []byte) (node, error) {
 	cacheMissCounter.Inc(1)
 
-	enc, err := t.db.Get(n)
+	hash := common.BytesToHash(n)
+
+	enc, err := t.db.Node(hash)
 	if err != nil || enc == nil {
-		return nil, &MissingNodeError{NodeHash: common.BytesToHash(n), Path: prefix}
+		return nil, &MissingNodeError{NodeHash: hash, Path: prefix}
 	}
-	dec := mustDecodeNode(n, enc, t.cachegen)
-	return dec, nil
+	return mustDecodeNode(n, enc, t.cachegen), nil
 }
 
 // Root returns the root hash of the trie.
@@ -462,32 +448,18 @@ func (t *Trie) Root() []byte { return t.Hash().Bytes() }
 // Hash returns the root hash of the trie. It does not write to the
 // database and can be used even if the trie doesn't have one.
 func (t *Trie) Hash() common.Hash {
-	hash, cached, _ := t.hashRoot(nil)
+	hash, cached, _ := t.hashRoot(nil, nil)
 	t.root = cached
 	return common.BytesToHash(hash.(hashNode))
 }
 
-// Commit writes all nodes to the trie's database.
-// Nodes are stored with their sha3 hash as the key.
-//
-// Committing flushes nodes from memory.
-// Subsequent Get calls will load nodes from the database.
-func (t *Trie) Commit() (root common.Hash, err error) {
+// Commit writes all nodes to the trie's memory database, tracking the internal
+// and external (for account tries) references.
+func (t *Trie) Commit(onleaf LeafCallback) (root common.Hash, err error) {
 	if t.db == nil {
-		panic("Commit called on trie with nil database")
+		panic("commit called on trie with nil database")
 	}
-	return t.CommitTo(t.db)
-}
-
-// CommitTo writes all nodes to the given database.
-// Nodes are stored with their sha3 hash as the key.
-//
-// Committing flushes nodes from memory. Subsequent Get calls will
-// load nodes from the trie's database. Calling code must ensure that
-// the changes made to db are written back to the trie's attached
-// database before using the trie.
-func (t *Trie) CommitTo(db DatabaseWriter) (root common.Hash, err error) {
-	hash, cached, err := t.hashRoot(db)
+	hash, cached, err := t.hashRoot(t.db, onleaf)
 	if err != nil {
 		return common.Hash{}, err
 	}
@@ -496,11 +468,11 @@ func (t *Trie) CommitTo(db DatabaseWriter) (root common.Hash, err error) {
 	return common.BytesToHash(hash.(hashNode)), nil
 }
 
-func (t *Trie) hashRoot(db DatabaseWriter) (node, node, error) {
+func (t *Trie) hashRoot(db *Database, onleaf LeafCallback) (node, node, error) {
 	if t.root == nil {
 		return hashNode(emptyRoot.Bytes()), nil, nil
 	}
-	h := newHasher(t.cachegen, t.cachelimit)
+	h := newHasher(t.cachegen, t.cachelimit, onleaf)
 	defer returnHasherToPool(h)
 	return h.hash(t.root, db, true)
 }
diff --git a/trie/trie_test.go b/trie/trie_test.go
index 1e28c3bc4..997222628 100644
--- a/trie/trie_test.go
+++ b/trie/trie_test.go
@@ -43,8 +43,8 @@ func init() {
 
 // Used for testing
 func newEmpty() *Trie {
-	db, _ := ethdb.NewMemDatabase()
-	trie, _ := New(common.Hash{}, db)
+	diskdb, _ := ethdb.NewMemDatabase()
+	trie, _ := New(common.Hash{}, NewDatabase(diskdb))
 	return trie
 }
 
@@ -68,8 +68,8 @@ func TestNull(t *testing.T) {
 }
 
 func TestMissingRoot(t *testing.T) {
-	db, _ := ethdb.NewMemDatabase()
-	trie, err := New(common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33"), db)
+	diskdb, _ := ethdb.NewMemDatabase()
+	trie, err := New(common.HexToHash("0beec7b5ea3f0fdbc95d0dd47f3c5bc275da8a33"), NewDatabase(diskdb))
 	if trie != nil {
 		t.Error("New returned non-nil trie for invalid root")
 	}
@@ -78,70 +78,75 @@ func TestMissingRoot(t *testing.T) {
 	}
 }
 
-func TestMissingNode(t *testing.T) {
-	db, _ := ethdb.NewMemDatabase()
-	trie, _ := New(common.Hash{}, db)
+func TestMissingNodeDisk(t *testing.T)    { testMissingNode(t, false) }
+func TestMissingNodeMemonly(t *testing.T) { testMissingNode(t, true) }
+
+func testMissingNode(t *testing.T, memonly bool) {
+	diskdb, _ := ethdb.NewMemDatabase()
+	triedb := NewDatabase(diskdb)
+
+	trie, _ := New(common.Hash{}, triedb)
 	updateString(trie, "120000", "qwerqwerqwerqwerqwerqwerqwerqwer")
 	updateString(trie, "123456", "asdfasdfasdfasdfasdfasdfasdfasdf")
-	root, _ := trie.Commit()
+	root, _ := trie.Commit(nil)
+	if !memonly {
+		triedb.Commit(root, true)
+	}
 
-	trie, _ = New(root, db)
+	trie, _ = New(root, triedb)
 	_, err := trie.TryGet([]byte("120000"))
 	if err != nil {
 		t.Errorf("Unexpected error: %v", err)
 	}
-
-	trie, _ = New(root, db)
+	trie, _ = New(root, triedb)
 	_, err = trie.TryGet([]byte("120099"))
 	if err != nil {
 		t.Errorf("Unexpected error: %v", err)
 	}
-
-	trie, _ = New(root, db)
+	trie, _ = New(root, triedb)
 	_, err = trie.TryGet([]byte("123456"))
 	if err != nil {
 		t.Errorf("Unexpected error: %v", err)
 	}
-
-	trie, _ = New(root, db)
+	trie, _ = New(root, triedb)
 	err = trie.TryUpdate([]byte("120099"), []byte("zxcvzxcvzxcvzxcvzxcvzxcvzxcvzxcv"))
 	if err != nil {
 		t.Errorf("Unexpected error: %v", err)
 	}
-
-	trie, _ = New(root, db)
+	trie, _ = New(root, triedb)
 	err = trie.TryDelete([]byte("123456"))
 	if err != nil {
 		t.Errorf("Unexpected error: %v", err)
 	}
 
-	db.Delete(common.FromHex("e1d943cc8f061a0c0b98162830b970395ac9315654824bf21b73b891365262f9"))
+	hash := common.HexToHash("0xe1d943cc8f061a0c0b98162830b970395ac9315654824bf21b73b891365262f9")
+	if memonly {
+		delete(triedb.nodes, hash)
+	} else {
+		diskdb.Delete(hash[:])
+	}
 
-	trie, _ = New(root, db)
+	trie, _ = New(root, triedb)
 	_, err = trie.TryGet([]byte("120000"))
 	if _, ok := err.(*MissingNodeError); !ok {
 		t.Errorf("Wrong error: %v", err)
 	}
-
-	trie, _ = New(root, db)
+	trie, _ = New(root, triedb)
 	_, err = trie.TryGet([]byte("120099"))
 	if _, ok := err.(*MissingNodeError); !ok {
 		t.Errorf("Wrong error: %v", err)
 	}
-
-	trie, _ = New(root, db)
+	trie, _ = New(root, triedb)
 	_, err = trie.TryGet([]byte("123456"))
 	if err != nil {
 		t.Errorf("Unexpected error: %v", err)
 	}
-
-	trie, _ = New(root, db)
+	trie, _ = New(root, triedb)
 	err = trie.TryUpdate([]byte("120099"), []byte("zxcv"))
 	if _, ok := err.(*MissingNodeError); !ok {
 		t.Errorf("Wrong error: %v", err)
 	}
-
-	trie, _ = New(root, db)
+	trie, _ = New(root, triedb)
 	err = trie.TryDelete([]byte("123456"))
 	if _, ok := err.(*MissingNodeError); !ok {
 		t.Errorf("Wrong error: %v", err)
@@ -165,7 +170,7 @@ func TestInsert(t *testing.T) {
 	updateString(trie, "A", "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
 
 	exp = common.HexToHash("d23786fb4a010da3ce639d66d5e904a11dbc02746d1ce25029e53290cabf28ab")
-	root, err := trie.Commit()
+	root, err := trie.Commit(nil)
 	if err != nil {
 		t.Fatalf("commit error: %v", err)
 	}
@@ -194,7 +199,7 @@ func TestGet(t *testing.T) {
 		if i == 1 {
 			return
 		}
-		trie.Commit()
+		trie.Commit(nil)
 	}
 }
 
@@ -263,7 +268,7 @@ func TestReplication(t *testing.T) {
 	for _, val := range vals {
 		updateString(trie, val.k, val.v)
 	}
-	exp, err := trie.Commit()
+	exp, err := trie.Commit(nil)
 	if err != nil {
 		t.Fatalf("commit error: %v", err)
 	}
@@ -278,7 +283,7 @@ func TestReplication(t *testing.T) {
 			t.Errorf("trie2 doesn't have %q => %q", kv.k, kv.v)
 		}
 	}
-	hash, err := trie2.Commit()
+	hash, err := trie2.Commit(nil)
 	if err != nil {
 		t.Fatalf("commit error: %v", err)
 	}
@@ -314,7 +319,7 @@ func TestLargeValue(t *testing.T) {
 }
 
 type countingDB struct {
-	Database
+	ethdb.Database
 	gets map[string]int
 }
 
@@ -332,19 +337,20 @@ func TestCacheUnload(t *testing.T) {
 	key2 := "---some other branch"
 	updateString(trie, key1, "this is the branch of key1.")
 	updateString(trie, key2, "this is the branch of key2.")
-	root, _ := trie.Commit()
+
+	root, _ := trie.Commit(nil)
+	trie.db.Commit(root, true)
 
 	// Commit the trie repeatedly and access key1.
 	// The branch containing it is loaded from DB exactly two times:
 	// in the 0th and 6th iteration.
-	db := &countingDB{Database: trie.db, gets: make(map[string]int)}
-	trie, _ = New(root, db)
+	db := &countingDB{Database: trie.db.diskdb, gets: make(map[string]int)}
+	trie, _ = New(root, NewDatabase(db))
 	trie.SetCacheLimit(5)
 	for i := 0; i < 12; i++ {
 		getString(trie, key1)
-		trie.Commit()
+		trie.Commit(nil)
 	}
-
 	// Check that it got loaded two times.
 	for dbkey, count := range db.gets {
 		if count != 2 {
@@ -407,8 +413,10 @@ func (randTest) Generate(r *rand.Rand, size int) reflect.Value {
 }
 
 func runRandTest(rt randTest) bool {
-	db, _ := ethdb.NewMemDatabase()
-	tr, _ := New(common.Hash{}, db)
+	diskdb, _ := ethdb.NewMemDatabase()
+	triedb := NewDatabase(diskdb)
+
+	tr, _ := New(common.Hash{}, triedb)
 	values := make(map[string]string) // tracks content of the trie
 
 	for i, step := range rt {
@@ -426,23 +434,23 @@ func runRandTest(rt randTest) bool {
 				rt[i].err = fmt.Errorf("mismatch for key 0x%x, got 0x%x want 0x%x", step.key, v, want)
 			}
 		case opCommit:
-			_, rt[i].err = tr.Commit()
+			_, rt[i].err = tr.Commit(nil)
 		case opHash:
 			tr.Hash()
 		case opReset:
-			hash, err := tr.Commit()
+			hash, err := tr.Commit(nil)
 			if err != nil {
 				rt[i].err = err
 				return false
 			}
-			newtr, err := New(hash, db)
+			newtr, err := New(hash, triedb)
 			if err != nil {
 				rt[i].err = err
 				return false
 			}
 			tr = newtr
 		case opItercheckhash:
-			checktr, _ := New(common.Hash{}, nil)
+			checktr, _ := New(common.Hash{}, triedb)
 			it := NewIterator(tr.NodeIterator(nil))
 			for it.Next() {
 				checktr.Update(it.Key, it.Value)
@@ -524,7 +532,7 @@ func benchGet(b *testing.B, commit bool) {
 	}
 	binary.LittleEndian.PutUint64(k, benchElemCount/2)
 	if commit {
-		trie.Commit()
+		trie.Commit(nil)
 	}
 
 	b.ResetTimer()
@@ -534,7 +542,7 @@ func benchGet(b *testing.B, commit bool) {
 	b.StopTimer()
 
 	if commit {
-		ldb := trie.db.(*ethdb.LDBDatabase)
+		ldb := trie.db.diskdb.(*ethdb.LDBDatabase)
 		ldb.Close()
 		os.RemoveAll(ldb.Path())
 	}
@@ -585,16 +593,16 @@ func BenchmarkHash(b *testing.B) {
 	trie.Hash()
 }
 
-func tempDB() (string, Database) {
+func tempDB() (string, *Database) {
 	dir, err := ioutil.TempDir("", "trie-bench")
 	if err != nil {
 		panic(fmt.Sprintf("can't create temporary directory: %v", err))
 	}
-	db, err := ethdb.NewLDBDatabase(dir, 256, 0)
+	diskdb, err := ethdb.NewLDBDatabase(dir, 256, 0)
 	if err != nil {
 		panic(fmt.Sprintf("can't create temporary database: %v", err))
 	}
-	return dir, db
+	return dir, NewDatabase(diskdb)
 }
 
 func getString(trie *Trie, k string) []byte {
-- 
GitLab