From 049f5b357200406c01f7bf2d2e2936d1d28a319b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Felf=C3=B6ldi=20Zsolt?= <zsfelfoldi@gmail.com>
Date: Tue, 12 Jun 2018 15:52:54 +0200
Subject: [PATCH] core, eth, les: more efficient hash-based header chain
 retrieval (#16946)

---
 core/blockchain.go  | 12 ++++++++++++
 core/headerchain.go | 37 +++++++++++++++++++++++++++++++++++++
 eth/handler.go      | 37 +++++++++++++++++++++++--------------
 les/handler.go      | 35 ++++++++++++++++++++++-------------
 light/lightchain.go | 12 ++++++++++++
 5 files changed, 106 insertions(+), 27 deletions(-)

diff --git a/core/blockchain.go b/core/blockchain.go
index ea26fa034..34832252a 100644
--- a/core/blockchain.go
+++ b/core/blockchain.go
@@ -1524,6 +1524,18 @@ func (bc *BlockChain) GetBlockHashesFromHash(hash common.Hash, max uint64) []com
 	return bc.hc.GetBlockHashesFromHash(hash, max)
 }
 
+// GetAncestor retrieves the Nth ancestor of a given block. It assumes that either the given block or
+// a close ancestor of it is canonical. maxNonCanonical points to a downwards counter limiting the
+// number of blocks to be individually checked before we reach the canonical chain.
+//
+// Note: ancestor == 0 returns the same block, 1 returns its parent and so on.
+func (bc *BlockChain) GetAncestor(hash common.Hash, number, ancestor uint64, maxNonCanonical *uint64) (common.Hash, uint64) {
+	bc.chainmu.Lock()
+	defer bc.chainmu.Unlock()
+
+	return bc.hc.GetAncestor(hash, number, ancestor, maxNonCanonical)
+}
+
 // GetHeaderByNumber retrieves a block header from the database by number,
 // caching it (associated with its hash) if found.
 func (bc *BlockChain) GetHeaderByNumber(number uint64) *types.Header {
diff --git a/core/headerchain.go b/core/headerchain.go
index 2ac0cccc7..6e759ed1c 100644
--- a/core/headerchain.go
+++ b/core/headerchain.go
@@ -307,6 +307,43 @@ func (hc *HeaderChain) GetBlockHashesFromHash(hash common.Hash, max uint64) []co
 	return chain
 }
 
+// GetAncestor retrieves the Nth ancestor of a given block. It assumes that either the given block or
+// a close ancestor of it is canonical. maxNonCanonical points to a downwards counter limiting the
+// number of blocks to be individually checked before we reach the canonical chain.
+//
+// Note: ancestor == 0 returns the same block, 1 returns its parent and so on.
+func (hc *HeaderChain) GetAncestor(hash common.Hash, number, ancestor uint64, maxNonCanonical *uint64) (common.Hash, uint64) {
+	if ancestor > number {
+		return common.Hash{}, 0
+	}
+	if ancestor == 1 {
+		// in this case it is cheaper to just read the header
+		if header := hc.GetHeader(hash, number); header != nil {
+			return header.ParentHash, number - 1
+		} else {
+			return common.Hash{}, 0
+		}
+	}
+	for ancestor != 0 {
+		if rawdb.ReadCanonicalHash(hc.chainDb, number) == hash {
+			number -= ancestor
+			return rawdb.ReadCanonicalHash(hc.chainDb, number), number
+		}
+		if *maxNonCanonical == 0 {
+			return common.Hash{}, 0
+		}
+		*maxNonCanonical--
+		ancestor--
+		header := hc.GetHeader(hash, number)
+		if header == nil {
+			return common.Hash{}, 0
+		}
+		hash = header.ParentHash
+		number--
+	}
+	return hash, number
+}
+
 // GetTd retrieves a block's total difficulty in the canonical chain from the
 // database by hash and number, caching it if found.
 func (hc *HeaderChain) GetTd(hash common.Hash, number uint64) *big.Int {
diff --git a/eth/handler.go b/eth/handler.go
index 918d71088..a46e7f13c 100644
--- a/eth/handler.go
+++ b/eth/handler.go
@@ -340,6 +340,8 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 			return errResp(ErrDecode, "%v: %v", msg, err)
 		}
 		hashMode := query.Origin.Hash != (common.Hash{})
+		first := true
+		maxNonCanonical := uint64(100)
 
 		// Gather headers until the fetch or network limits is reached
 		var (
@@ -351,31 +353,36 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 			// Retrieve the next header satisfying the query
 			var origin *types.Header
 			if hashMode {
-				origin = pm.blockchain.GetHeaderByHash(query.Origin.Hash)
+				if first {
+					first = false
+					origin = pm.blockchain.GetHeaderByHash(query.Origin.Hash)
+					if origin != nil {
+						query.Origin.Number = origin.Number.Uint64()
+					}
+				} else {
+					origin = pm.blockchain.GetHeader(query.Origin.Hash, query.Origin.Number)
+				}
 			} else {
 				origin = pm.blockchain.GetHeaderByNumber(query.Origin.Number)
 			}
 			if origin == nil {
 				break
 			}
-			number := origin.Number.Uint64()
 			headers = append(headers, origin)
 			bytes += estHeaderRlpSize
 
 			// Advance to the next header of the query
 			switch {
-			case query.Origin.Hash != (common.Hash{}) && query.Reverse:
+			case hashMode && query.Reverse:
 				// Hash based traversal towards the genesis block
-				for i := 0; i < int(query.Skip)+1; i++ {
-					if header := pm.blockchain.GetHeader(query.Origin.Hash, number); header != nil {
-						query.Origin.Hash = header.ParentHash
-						number--
-					} else {
-						unknown = true
-						break
-					}
+				ancestor := query.Skip + 1
+				if ancestor == 0 {
+					unknown = true
+				} else {
+					query.Origin.Hash, query.Origin.Number = pm.blockchain.GetAncestor(query.Origin.Hash, query.Origin.Number, ancestor, &maxNonCanonical)
+					unknown = (query.Origin.Hash == common.Hash{})
 				}
-			case query.Origin.Hash != (common.Hash{}) && !query.Reverse:
+			case hashMode && !query.Reverse:
 				// Hash based traversal towards the leaf block
 				var (
 					current = origin.Number.Uint64()
@@ -387,8 +394,10 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 					unknown = true
 				} else {
 					if header := pm.blockchain.GetHeaderByNumber(next); header != nil {
-						if pm.blockchain.GetBlockHashesFromHash(header.Hash(), query.Skip+1)[query.Skip] == query.Origin.Hash {
-							query.Origin.Hash = header.Hash()
+						nextHash := header.Hash()
+						expOldHash, _ := pm.blockchain.GetAncestor(nextHash, next, query.Skip+1, &maxNonCanonical)
+						if expOldHash == query.Origin.Hash {
+							query.Origin.Hash, query.Origin.Number = nextHash, next
 						} else {
 							unknown = true
 						}
diff --git a/les/handler.go b/les/handler.go
index 38f810d72..a1c16cb87 100644
--- a/les/handler.go
+++ b/les/handler.go
@@ -83,7 +83,7 @@ type BlockChain interface {
 	InsertHeaderChain(chain []*types.Header, checkFreq int) (int, error)
 	Rollback(chain []common.Hash)
 	GetHeaderByNumber(number uint64) *types.Header
-	GetBlockHashesFromHash(hash common.Hash, max uint64) []common.Hash
+	GetAncestor(hash common.Hash, number, ancestor uint64, maxNonCanonical *uint64) (common.Hash, uint64)
 	Genesis() *types.Block
 	SubscribeChainHeadEvent(ch chan<- core.ChainHeadEvent) event.Subscription
 }
@@ -419,6 +419,8 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 		}
 
 		hashMode := query.Origin.Hash != (common.Hash{})
+		first := true
+		maxNonCanonical := uint64(100)
 
 		// Gather headers until the fetch or network limits is reached
 		var (
@@ -430,14 +432,21 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 			// Retrieve the next header satisfying the query
 			var origin *types.Header
 			if hashMode {
-				origin = pm.blockchain.GetHeaderByHash(query.Origin.Hash)
+				if first {
+					first = false
+					origin = pm.blockchain.GetHeaderByHash(query.Origin.Hash)
+					if origin != nil {
+						query.Origin.Number = origin.Number.Uint64()
+					}
+				} else {
+					origin = pm.blockchain.GetHeader(query.Origin.Hash, query.Origin.Number)
+				}
 			} else {
 				origin = pm.blockchain.GetHeaderByNumber(query.Origin.Number)
 			}
 			if origin == nil {
 				break
 			}
-			number := origin.Number.Uint64()
 			headers = append(headers, origin)
 			bytes += estHeaderRlpSize
 
@@ -445,14 +454,12 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 			switch {
 			case hashMode && query.Reverse:
 				// Hash based traversal towards the genesis block
-				for i := 0; i < int(query.Skip)+1; i++ {
-					if header := pm.blockchain.GetHeader(query.Origin.Hash, number); header != nil {
-						query.Origin.Hash = header.ParentHash
-						number--
-					} else {
-						unknown = true
-						break
-					}
+				ancestor := query.Skip + 1
+				if ancestor == 0 {
+					unknown = true
+				} else {
+					query.Origin.Hash, query.Origin.Number = pm.blockchain.GetAncestor(query.Origin.Hash, query.Origin.Number, ancestor, &maxNonCanonical)
+					unknown = (query.Origin.Hash == common.Hash{})
 				}
 			case hashMode && !query.Reverse:
 				// Hash based traversal towards the leaf block
@@ -466,8 +473,10 @@ func (pm *ProtocolManager) handleMsg(p *peer) error {
 					unknown = true
 				} else {
 					if header := pm.blockchain.GetHeaderByNumber(next); header != nil {
-						if pm.blockchain.GetBlockHashesFromHash(header.Hash(), query.Skip+1)[query.Skip] == query.Origin.Hash {
-							query.Origin.Hash = header.Hash()
+						nextHash := header.Hash()
+						expOldHash, _ := pm.blockchain.GetAncestor(nextHash, next, query.Skip+1, &maxNonCanonical)
+						if expOldHash == query.Origin.Hash {
+							query.Origin.Hash, query.Origin.Number = nextHash, next
 						} else {
 							unknown = true
 						}
diff --git a/light/lightchain.go b/light/lightchain.go
index 9d0a4e4f7..30b9bd89a 100644
--- a/light/lightchain.go
+++ b/light/lightchain.go
@@ -433,6 +433,18 @@ func (self *LightChain) GetBlockHashesFromHash(hash common.Hash, max uint64) []c
 	return self.hc.GetBlockHashesFromHash(hash, max)
 }
 
+// GetAncestor retrieves the Nth ancestor of a given block. It assumes that either the given block or
+// a close ancestor of it is canonical. maxNonCanonical points to a downwards counter limiting the
+// number of blocks to be individually checked before we reach the canonical chain.
+//
+// Note: ancestor == 0 returns the same block, 1 returns its parent and so on.
+func (bc *LightChain) GetAncestor(hash common.Hash, number, ancestor uint64, maxNonCanonical *uint64) (common.Hash, uint64) {
+	bc.chainmu.Lock()
+	defer bc.chainmu.Unlock()
+
+	return bc.hc.GetAncestor(hash, number, ancestor, maxNonCanonical)
+}
+
 // GetHeaderByNumber retrieves a block header from the database by number,
 // caching it (associated with its hash) if found.
 func (self *LightChain) GetHeaderByNumber(number uint64) *types.Header {
-- 
GitLab