From f83237573f0922dfc9fef17f79ccd06305ab6d16 Mon Sep 17 00:00:00 2001
From: Martin Holst Swende <martin@swende.se>
Date: Mon, 26 Feb 2018 10:53:10 +0100
Subject: [PATCH] core: make current*Block atomic, and accessor functions
 mutex-free (#16171)

* core: make current*Block atomic, and accessor functions mutex-free

* core: fix review concerns

* core: fix error in atomic assignment

* core/light: implement atomic getter/setter for headerchain
---
 core/blockchain.go  | 112 ++++++++++++++++++++++----------------------
 core/headerchain.go |  39 ++++++++-------
 light/lightchain.go |   6 ---
 3 files changed, 76 insertions(+), 81 deletions(-)

diff --git a/core/blockchain.go b/core/blockchain.go
index 53fe7ee2e..6006e6674 100644
--- a/core/blockchain.go
+++ b/core/blockchain.go
@@ -107,8 +107,8 @@ type BlockChain struct {
 	procmu  sync.RWMutex // block processor lock
 
 	checkpoint       int          // checkpoint counts towards the new checkpoint
-	currentBlock     *types.Block // Current head of the block chain
-	currentFastBlock *types.Block // Current head of the fast-sync chain (may be above the block chain!)
+	currentBlock     atomic.Value // Current head of the block chain
+	currentFastBlock atomic.Value // Current head of the fast-sync chain (may be above the block chain!)
 
 	stateCache   state.Database // State database to reuse between imports (contains state cache)
 	bodyCache    *lru.Cache     // Cache for the most recent block bodies
@@ -224,10 +224,10 @@ func (bc *BlockChain) loadLastState() error {
 		}
 	}
 	// Everything seems to be fine, set as the head block
-	bc.currentBlock = currentBlock
+	bc.currentBlock.Store(currentBlock)
 
 	// Restore the last known head header
-	currentHeader := bc.currentBlock.Header()
+	currentHeader := currentBlock.Header()
 	if head := GetHeadHeaderHash(bc.db); head != (common.Hash{}) {
 		if header := bc.GetHeaderByHash(head); header != nil {
 			currentHeader = header
@@ -236,21 +236,23 @@ func (bc *BlockChain) loadLastState() error {
 	bc.hc.SetCurrentHeader(currentHeader)
 
 	// Restore the last known head fast block
-	bc.currentFastBlock = bc.currentBlock
+	bc.currentFastBlock.Store(currentBlock)
 	if head := GetHeadFastBlockHash(bc.db); head != (common.Hash{}) {
 		if block := bc.GetBlockByHash(head); block != nil {
-			bc.currentFastBlock = block
+			bc.currentFastBlock.Store(block)
 		}
 	}
 
 	// Issue a status log for the user
+	currentFastBlock := bc.CurrentFastBlock()
+
 	headerTd := bc.GetTd(currentHeader.Hash(), currentHeader.Number.Uint64())
-	blockTd := bc.GetTd(bc.currentBlock.Hash(), bc.currentBlock.NumberU64())
-	fastTd := bc.GetTd(bc.currentFastBlock.Hash(), bc.currentFastBlock.NumberU64())
+	blockTd := bc.GetTd(currentBlock.Hash(), currentBlock.NumberU64())
+	fastTd := bc.GetTd(currentFastBlock.Hash(), currentFastBlock.NumberU64())
 
 	log.Info("Loaded most recent local header", "number", currentHeader.Number, "hash", currentHeader.Hash(), "td", headerTd)
-	log.Info("Loaded most recent local full block", "number", bc.currentBlock.Number(), "hash", bc.currentBlock.Hash(), "td", blockTd)
-	log.Info("Loaded most recent local fast block", "number", bc.currentFastBlock.Number(), "hash", bc.currentFastBlock.Hash(), "td", fastTd)
+	log.Info("Loaded most recent local full block", "number", currentBlock.Number(), "hash", currentBlock.Hash(), "td", blockTd)
+	log.Info("Loaded most recent local fast block", "number", currentFastBlock.Number(), "hash", currentFastBlock.Hash(), "td", fastTd)
 
 	return nil
 }
@@ -279,30 +281,32 @@ func (bc *BlockChain) SetHead(head uint64) error {
 	bc.futureBlocks.Purge()
 
 	// Rewind the block chain, ensuring we don't end up with a stateless head block
-	if bc.currentBlock != nil && currentHeader.Number.Uint64() < bc.currentBlock.NumberU64() {
-		bc.currentBlock = bc.GetBlock(currentHeader.Hash(), currentHeader.Number.Uint64())
+	if currentBlock := bc.CurrentBlock(); currentBlock != nil && currentHeader.Number.Uint64() < currentBlock.NumberU64() {
+		bc.currentBlock.Store(bc.GetBlock(currentHeader.Hash(), currentHeader.Number.Uint64()))
 	}
-	if bc.currentBlock != nil {
-		if _, err := state.New(bc.currentBlock.Root(), bc.stateCache); err != nil {
+	if currentBlock := bc.CurrentBlock(); currentBlock != nil {
+		if _, err := state.New(currentBlock.Root(), bc.stateCache); err != nil {
 			// Rewound state missing, rolled back to before pivot, reset to genesis
-			bc.currentBlock = nil
+			bc.currentBlock.Store(bc.genesisBlock)
 		}
 	}
 	// Rewind the fast block in a simpleton way to the target head
-	if bc.currentFastBlock != nil && currentHeader.Number.Uint64() < bc.currentFastBlock.NumberU64() {
-		bc.currentFastBlock = bc.GetBlock(currentHeader.Hash(), currentHeader.Number.Uint64())
+	if currentFastBlock := bc.CurrentFastBlock(); currentFastBlock != nil && currentHeader.Number.Uint64() < currentFastBlock.NumberU64() {
+		bc.currentFastBlock.Store(bc.GetBlock(currentHeader.Hash(), currentHeader.Number.Uint64()))
 	}
 	// If either blocks reached nil, reset to the genesis state
-	if bc.currentBlock == nil {
-		bc.currentBlock = bc.genesisBlock
+	if currentBlock := bc.CurrentBlock(); currentBlock == nil {
+		bc.currentBlock.Store(bc.genesisBlock)
 	}
-	if bc.currentFastBlock == nil {
-		bc.currentFastBlock = bc.genesisBlock
+	if currentFastBlock := bc.CurrentFastBlock(); currentFastBlock == nil {
+		bc.currentFastBlock.Store(bc.genesisBlock)
 	}
-	if err := WriteHeadBlockHash(bc.db, bc.currentBlock.Hash()); err != nil {
+	currentBlock := bc.CurrentBlock()
+	currentFastBlock := bc.CurrentFastBlock()
+	if err := WriteHeadBlockHash(bc.db, currentBlock.Hash()); err != nil {
 		log.Crit("Failed to reset head full block", "err", err)
 	}
-	if err := WriteHeadFastBlockHash(bc.db, bc.currentFastBlock.Hash()); err != nil {
+	if err := WriteHeadFastBlockHash(bc.db, currentFastBlock.Hash()); err != nil {
 		log.Crit("Failed to reset head fast block", "err", err)
 	}
 	return bc.loadLastState()
@@ -321,7 +325,7 @@ func (bc *BlockChain) FastSyncCommitHead(hash common.Hash) error {
 	}
 	// If all checks out, manually set the head block
 	bc.mu.Lock()
-	bc.currentBlock = block
+	bc.currentBlock.Store(block)
 	bc.mu.Unlock()
 
 	log.Info("Committed new head block", "number", block.Number(), "hash", hash)
@@ -330,28 +334,19 @@ func (bc *BlockChain) FastSyncCommitHead(hash common.Hash) error {
 
 // GasLimit returns the gas limit of the current HEAD block.
 func (bc *BlockChain) GasLimit() uint64 {
-	bc.mu.RLock()
-	defer bc.mu.RUnlock()
-
-	return bc.currentBlock.GasLimit()
+	return bc.CurrentBlock().GasLimit()
 }
 
 // CurrentBlock retrieves the current head block of the canonical chain. The
 // block is retrieved from the blockchain's internal cache.
 func (bc *BlockChain) CurrentBlock() *types.Block {
-	bc.mu.RLock()
-	defer bc.mu.RUnlock()
-
-	return bc.currentBlock
+	return bc.currentBlock.Load().(*types.Block)
 }
 
 // CurrentFastBlock retrieves the current fast-sync head block of the canonical
 // chain. The block is retrieved from the blockchain's internal cache.
 func (bc *BlockChain) CurrentFastBlock() *types.Block {
-	bc.mu.RLock()
-	defer bc.mu.RUnlock()
-
-	return bc.currentFastBlock
+	return bc.currentFastBlock.Load().(*types.Block)
 }
 
 // SetProcessor sets the processor required for making state modifications.
@@ -416,10 +411,10 @@ func (bc *BlockChain) ResetWithGenesisBlock(genesis *types.Block) error {
 	}
 	bc.genesisBlock = genesis
 	bc.insert(bc.genesisBlock)
-	bc.currentBlock = bc.genesisBlock
+	bc.currentBlock.Store(bc.genesisBlock)
 	bc.hc.SetGenesis(bc.genesisBlock.Header())
 	bc.hc.SetCurrentHeader(bc.genesisBlock.Header())
-	bc.currentFastBlock = bc.genesisBlock
+	bc.currentFastBlock.Store(bc.genesisBlock)
 
 	return nil
 }
@@ -444,7 +439,7 @@ func (bc *BlockChain) repair(head **types.Block) error {
 
 // Export writes the active chain to the given writer.
 func (bc *BlockChain) Export(w io.Writer) error {
-	return bc.ExportN(w, uint64(0), bc.currentBlock.NumberU64())
+	return bc.ExportN(w, uint64(0), bc.CurrentBlock().NumberU64())
 }
 
 // ExportN writes a subset of the active chain to the given writer.
@@ -488,7 +483,7 @@ func (bc *BlockChain) insert(block *types.Block) {
 	if err := WriteHeadBlockHash(bc.db, block.Hash()); err != nil {
 		log.Crit("Failed to insert head block hash", "err", err)
 	}
-	bc.currentBlock = block
+	bc.currentBlock.Store(block)
 
 	// If the block is better than our head or is on a different chain, force update heads
 	if updateHeads {
@@ -497,7 +492,7 @@ func (bc *BlockChain) insert(block *types.Block) {
 		if err := WriteHeadFastBlockHash(bc.db, block.Hash()); err != nil {
 			log.Crit("Failed to insert head fast block hash", "err", err)
 		}
-		bc.currentFastBlock = block
+		bc.currentFastBlock.Store(block)
 	}
 }
 
@@ -714,13 +709,15 @@ func (bc *BlockChain) Rollback(chain []common.Hash) {
 		if currentHeader.Hash() == hash {
 			bc.hc.SetCurrentHeader(bc.GetHeader(currentHeader.ParentHash, currentHeader.Number.Uint64()-1))
 		}
-		if bc.currentFastBlock.Hash() == hash {
-			bc.currentFastBlock = bc.GetBlock(bc.currentFastBlock.ParentHash(), bc.currentFastBlock.NumberU64()-1)
-			WriteHeadFastBlockHash(bc.db, bc.currentFastBlock.Hash())
+		if currentFastBlock := bc.CurrentFastBlock(); currentFastBlock.Hash() == hash {
+			newFastBlock := bc.GetBlock(currentFastBlock.ParentHash(), currentFastBlock.NumberU64()-1)
+			bc.currentFastBlock.Store(newFastBlock)
+			WriteHeadFastBlockHash(bc.db, newFastBlock.Hash())
 		}
-		if bc.currentBlock.Hash() == hash {
-			bc.currentBlock = bc.GetBlock(bc.currentBlock.ParentHash(), bc.currentBlock.NumberU64()-1)
-			WriteHeadBlockHash(bc.db, bc.currentBlock.Hash())
+		if currentBlock := bc.CurrentBlock(); currentBlock.Hash() == hash {
+			newBlock := bc.GetBlock(currentBlock.ParentHash(), currentBlock.NumberU64()-1)
+			bc.currentBlock.Store(newBlock)
+			WriteHeadBlockHash(bc.db, newBlock.Hash())
 		}
 	}
 }
@@ -829,11 +826,12 @@ func (bc *BlockChain) InsertReceiptChain(blockChain types.Blocks, receiptChain [
 	bc.mu.Lock()
 	head := blockChain[len(blockChain)-1]
 	if td := bc.GetTd(head.Hash(), head.NumberU64()); td != nil { // Rewind may have occurred, skip in that case
-		if bc.GetTd(bc.currentFastBlock.Hash(), bc.currentFastBlock.NumberU64()).Cmp(td) < 0 {
+		currentFastBlock := bc.CurrentFastBlock()
+		if bc.GetTd(currentFastBlock.Hash(), currentFastBlock.NumberU64()).Cmp(td) < 0 {
 			if err := WriteHeadFastBlockHash(bc.db, head.Hash()); err != nil {
 				log.Crit("Failed to update head fast block hash", "err", err)
 			}
-			bc.currentFastBlock = head
+			bc.currentFastBlock.Store(head)
 		}
 	}
 	bc.mu.Unlock()
@@ -880,7 +878,8 @@ func (bc *BlockChain) WriteBlockWithState(block *types.Block, receipts []*types.
 	bc.mu.Lock()
 	defer bc.mu.Unlock()
 
-	localTd := bc.GetTd(bc.currentBlock.Hash(), bc.currentBlock.NumberU64())
+	currentBlock := bc.CurrentBlock()
+	localTd := bc.GetTd(currentBlock.Hash(), currentBlock.NumberU64())
 	externTd := new(big.Int).Add(block.Difficulty(), ptd)
 
 	// Irrelevant of the canonical status, write the block itself to the database
@@ -955,14 +954,15 @@ func (bc *BlockChain) WriteBlockWithState(block *types.Block, receipts []*types.
 	// Second clause in the if statement reduces the vulnerability to selfish mining.
 	// Please refer to http://www.cs.cornell.edu/~ie53/publications/btcProcFC.pdf
 	reorg := externTd.Cmp(localTd) > 0
+	currentBlock = bc.CurrentBlock()
 	if !reorg && externTd.Cmp(localTd) == 0 {
 		// Split same-difficulty blocks by number, then at random
-		reorg = block.NumberU64() < bc.currentBlock.NumberU64() || (block.NumberU64() == bc.currentBlock.NumberU64() && mrand.Float64() < 0.5)
+		reorg = block.NumberU64() < currentBlock.NumberU64() || (block.NumberU64() == currentBlock.NumberU64() && mrand.Float64() < 0.5)
 	}
 	if reorg {
 		// Reorganise the chain if the parent is not the head block
-		if block.ParentHash() != bc.currentBlock.Hash() {
-			if err := bc.reorg(bc.currentBlock, block); err != nil {
+		if block.ParentHash() != currentBlock.Hash() {
+			if err := bc.reorg(currentBlock, block); err != nil {
 				return NonStatTy, err
 			}
 		}
@@ -1091,7 +1091,8 @@ func (bc *BlockChain) insertChain(chain types.Blocks) (int, []interface{}, []*ty
 		case err == consensus.ErrPrunedAncestor:
 			// Block competing with the canonical chain, store in the db, but don't process
 			// until the competitor TD goes above the canonical TD
-			localTd := bc.GetTd(bc.currentBlock.Hash(), bc.currentBlock.NumberU64())
+			currentBlock := bc.CurrentBlock()
+			localTd := bc.GetTd(currentBlock.Hash(), currentBlock.NumberU64())
 			externTd := new(big.Int).Add(bc.GetTd(block.ParentHash(), block.NumberU64()-1), block.Difficulty())
 			if localTd.Cmp(externTd) > 0 {
 				if err = bc.WriteBlockWithoutState(block, externTd); err != nil {
@@ -1480,9 +1481,6 @@ func (bc *BlockChain) writeHeader(header *types.Header) error {
 // CurrentHeader retrieves the current head header of the canonical chain. The
 // header is retrieved from the HeaderChain's internal cache.
 func (bc *BlockChain) CurrentHeader() *types.Header {
-	bc.mu.RLock()
-	defer bc.mu.RUnlock()
-
 	return bc.hc.CurrentHeader()
 }
 
diff --git a/core/headerchain.go b/core/headerchain.go
index 0e5215293..73cd5d2c4 100644
--- a/core/headerchain.go
+++ b/core/headerchain.go
@@ -32,6 +32,7 @@ import (
 	"github.com/ethereum/go-ethereum/log"
 	"github.com/ethereum/go-ethereum/params"
 	"github.com/hashicorp/golang-lru"
+	"sync/atomic"
 )
 
 const (
@@ -51,8 +52,8 @@ type HeaderChain struct {
 	chainDb       ethdb.Database
 	genesisHeader *types.Header
 
-	currentHeader     *types.Header // Current head of the header chain (may be above the block chain!)
-	currentHeaderHash common.Hash   // Hash of the current head of the header chain (prevent recomputing all the time)
+	currentHeader     atomic.Value // Current head of the header chain (may be above the block chain!)
+	currentHeaderHash common.Hash  // Hash of the current head of the header chain (prevent recomputing all the time)
 
 	headerCache *lru.Cache // Cache for the most recent block headers
 	tdCache     *lru.Cache // Cache for the most recent block total difficulties
@@ -95,13 +96,13 @@ func NewHeaderChain(chainDb ethdb.Database, config *params.ChainConfig, engine c
 		return nil, ErrNoGenesis
 	}
 
-	hc.currentHeader = hc.genesisHeader
+	hc.currentHeader.Store(hc.genesisHeader)
 	if head := GetHeadBlockHash(chainDb); head != (common.Hash{}) {
 		if chead := hc.GetHeaderByHash(head); chead != nil {
-			hc.currentHeader = chead
+			hc.currentHeader.Store(chead)
 		}
 	}
-	hc.currentHeaderHash = hc.currentHeader.Hash()
+	hc.currentHeaderHash = hc.CurrentHeader().Hash()
 
 	return hc, nil
 }
@@ -139,7 +140,7 @@ func (hc *HeaderChain) WriteHeader(header *types.Header) (status WriteStatus, er
 	if ptd == nil {
 		return NonStatTy, consensus.ErrUnknownAncestor
 	}
-	localTd := hc.GetTd(hc.currentHeaderHash, hc.currentHeader.Number.Uint64())
+	localTd := hc.GetTd(hc.currentHeaderHash, hc.CurrentHeader().Number.Uint64())
 	externTd := new(big.Int).Add(header.Difficulty, ptd)
 
 	// Irrelevant of the canonical status, write the td and header to the database
@@ -181,7 +182,8 @@ func (hc *HeaderChain) WriteHeader(header *types.Header) (status WriteStatus, er
 		if err := WriteHeadHeaderHash(hc.chainDb, hash); err != nil {
 			log.Crit("Failed to insert head header hash", "err", err)
 		}
-		hc.currentHeaderHash, hc.currentHeader = hash, types.CopyHeader(header)
+		hc.currentHeaderHash = hash
+		hc.currentHeader.Store(types.CopyHeader(header))
 
 		status = CanonStatTy
 	} else {
@@ -383,7 +385,7 @@ func (hc *HeaderChain) GetHeaderByNumber(number uint64) *types.Header {
 // CurrentHeader retrieves the current head header of the canonical chain. The
 // header is retrieved from the HeaderChain's internal cache.
 func (hc *HeaderChain) CurrentHeader() *types.Header {
-	return hc.currentHeader
+	return hc.currentHeader.Load().(*types.Header)
 }
 
 // SetCurrentHeader sets the current head header of the canonical chain.
@@ -391,7 +393,7 @@ func (hc *HeaderChain) SetCurrentHeader(head *types.Header) {
 	if err := WriteHeadHeaderHash(hc.chainDb, head.Hash()); err != nil {
 		log.Crit("Failed to insert head header hash", "err", err)
 	}
-	hc.currentHeader = head
+	hc.currentHeader.Store(head)
 	hc.currentHeaderHash = head.Hash()
 }
 
@@ -403,19 +405,20 @@ type DeleteCallback func(common.Hash, uint64)
 // will be deleted and the new one set.
 func (hc *HeaderChain) SetHead(head uint64, delFn DeleteCallback) {
 	height := uint64(0)
-	if hc.currentHeader != nil {
-		height = hc.currentHeader.Number.Uint64()
+
+	if hdr := hc.CurrentHeader(); hdr != nil {
+		height = hdr.Number.Uint64()
 	}
 
-	for hc.currentHeader != nil && hc.currentHeader.Number.Uint64() > head {
-		hash := hc.currentHeader.Hash()
-		num := hc.currentHeader.Number.Uint64()
+	for hdr := hc.CurrentHeader(); hdr != nil && hdr.Number.Uint64() > head; hdr = hc.CurrentHeader() {
+		hash := hdr.Hash()
+		num := hdr.Number.Uint64()
 		if delFn != nil {
 			delFn(hash, num)
 		}
 		DeleteHeader(hc.chainDb, hash, num)
 		DeleteTd(hc.chainDb, hash, num)
-		hc.currentHeader = hc.GetHeader(hc.currentHeader.ParentHash, hc.currentHeader.Number.Uint64()-1)
+		hc.currentHeader.Store(hc.GetHeader(hdr.ParentHash, hdr.Number.Uint64()-1))
 	}
 	// Roll back the canonical chain numbering
 	for i := height; i > head; i-- {
@@ -426,10 +429,10 @@ func (hc *HeaderChain) SetHead(head uint64, delFn DeleteCallback) {
 	hc.tdCache.Purge()
 	hc.numberCache.Purge()
 
-	if hc.currentHeader == nil {
-		hc.currentHeader = hc.genesisHeader
+	if hc.CurrentHeader() == nil {
+		hc.currentHeader.Store(hc.genesisHeader)
 	}
-	hc.currentHeaderHash = hc.currentHeader.Hash()
+	hc.currentHeaderHash = hc.CurrentHeader().Hash()
 
 	if err := WriteHeadHeaderHash(hc.chainDb, hc.currentHeaderHash); err != nil {
 		log.Crit("Failed to reset head header hash", "err", err)
diff --git a/light/lightchain.go b/light/lightchain.go
index 181a1c2a6..2784615d3 100644
--- a/light/lightchain.go
+++ b/light/lightchain.go
@@ -171,9 +171,6 @@ func (bc *LightChain) SetHead(head uint64) {
 
 // GasLimit returns the gas limit of the current HEAD block.
 func (self *LightChain) GasLimit() uint64 {
-	self.mu.RLock()
-	defer self.mu.RUnlock()
-
 	return self.hc.CurrentHeader().GasLimit
 }
 
@@ -387,9 +384,6 @@ func (self *LightChain) InsertHeaderChain(chain []*types.Header, checkFreq int)
 // CurrentHeader retrieves the current head header of the canonical chain. The
 // header is retrieved from the HeaderChain's internal cache.
 func (self *LightChain) CurrentHeader() *types.Header {
-	self.mu.RLock()
-	defer self.mu.RUnlock()
-
 	return self.hc.CurrentHeader()
 }
 
-- 
GitLab