From 9f6bb492bbcd6def92d0bc195faeb751e1591535 Mon Sep 17 00:00:00 2001
From: gary rong <garyrong0905@gmail.com>
Date: Thu, 10 Dec 2020 21:33:52 +0800
Subject: [PATCH] les, light: remove untrusted header retrieval in ODR (#21907)

* les, light: remove untrusted header retrieval in ODR

* les: polish

* light: check the hash equality in odr
---
 les/client_handler.go | 62 +++++++++++++++++++++++++++++++++++++------
 les/odr.go            | 22 ++++++++-------
 les/odr_requests.go   | 52 +++++++++++++++---------------------
 les/retrieve.go       |  9 +++++++
 les/sync.go           |  6 ++---
 les/sync_test.go      |  2 +-
 light/odr.go          | 11 +++-----
 light/odr_util.go     | 39 +++++++++++----------------
 light/trie.go         |  4 +++
 9 files changed, 125 insertions(+), 82 deletions(-)

diff --git a/les/client_handler.go b/les/client_handler.go
index 77a0ea5c6..d7ca1c54f 100644
--- a/les/client_handler.go
+++ b/les/client_handler.go
@@ -17,6 +17,7 @@
 package les
 
 import (
+	"context"
 	"math/big"
 	"sync"
 	"sync/atomic"
@@ -200,14 +201,23 @@ func (h *clientHandler) handleMsg(p *serverPeer) error {
 		p.fcServer.ReceivedReply(resp.ReqID, resp.BV)
 		p.answeredRequest(resp.ReqID)
 
-		// Filter out any explicitly requested headers, deliver the rest to the downloader
-		filter := len(headers) == 1
-		if filter {
-			headers = h.fetcher.deliverHeaders(p, resp.ReqID, resp.Headers)
-		}
-		if len(headers) != 0 || !filter {
-			if err := h.downloader.DeliverHeaders(p.id, headers); err != nil {
-				log.Debug("Failed to deliver headers", "err", err)
+		// Filter out the explicitly requested header by the retriever
+		if h.backend.retriever.requested(resp.ReqID) {
+			deliverMsg = &Msg{
+				MsgType: MsgBlockHeaders,
+				ReqID:   resp.ReqID,
+				Obj:     resp.Headers,
+			}
+		} else {
+			// Filter out any explicitly requested headers, deliver the rest to the downloader
+			filter := len(headers) == 1
+			if filter {
+				headers = h.fetcher.deliverHeaders(p, resp.ReqID, resp.Headers)
+			}
+			if len(headers) != 0 || !filter {
+				if err := h.downloader.DeliverHeaders(p.id, headers); err != nil {
+					log.Debug("Failed to deliver headers", "err", err)
+				}
 			}
 		}
 	case BlockBodiesMsg:
@@ -394,6 +404,42 @@ func (pc *peerConnection) RequestHeadersByNumber(origin uint64, amount int, skip
 	return nil
 }
 
+// RetrieveSingleHeaderByNumber requests a single header by the specified block
+// number. This function will wait the response until it's timeout or delivered.
+func (pc *peerConnection) RetrieveSingleHeaderByNumber(context context.Context, number uint64) (*types.Header, error) {
+	reqID := genReqID()
+	rq := &distReq{
+		getCost: func(dp distPeer) uint64 {
+			peer := dp.(*serverPeer)
+			return peer.getRequestCost(GetBlockHeadersMsg, 1)
+		},
+		canSend: func(dp distPeer) bool {
+			return dp.(*serverPeer) == pc.peer
+		},
+		request: func(dp distPeer) func() {
+			peer := dp.(*serverPeer)
+			cost := peer.getRequestCost(GetBlockHeadersMsg, 1)
+			peer.fcServer.QueuedRequest(reqID, cost)
+			return func() { peer.requestHeadersByNumber(reqID, number, 1, 0, false) }
+		},
+	}
+	var header *types.Header
+	if err := pc.handler.backend.retriever.retrieve(context, reqID, rq, func(peer distPeer, msg *Msg) error {
+		if msg.MsgType != MsgBlockHeaders {
+			return errInvalidMessageType
+		}
+		headers := msg.Obj.([]*types.Header)
+		if len(headers) != 1 {
+			return errInvalidEntryCount
+		}
+		header = headers[0]
+		return nil
+	}, nil); err != nil {
+		return nil, err
+	}
+	return header, nil
+}
+
 // downloaderPeerNotify implements peerSetNotify
 type downloaderPeerNotify clientHandler
 
diff --git a/les/odr.go b/les/odr.go
index f8469cc10..2c36f512d 100644
--- a/les/odr.go
+++ b/les/odr.go
@@ -24,7 +24,6 @@ import (
 	"github.com/ethereum/go-ethereum/core"
 	"github.com/ethereum/go-ethereum/ethdb"
 	"github.com/ethereum/go-ethereum/light"
-	"github.com/ethereum/go-ethereum/log"
 )
 
 // LesOdr implements light.OdrBackend
@@ -83,7 +82,8 @@ func (odr *LesOdr) IndexerConfig() *light.IndexerConfig {
 }
 
 const (
-	MsgBlockBodies = iota
+	MsgBlockHeaders = iota
+	MsgBlockBodies
 	MsgCode
 	MsgReceipts
 	MsgProofsV2
@@ -122,13 +122,17 @@ func (odr *LesOdr) Retrieve(ctx context.Context, req light.OdrRequest) (err erro
 			return func() { lreq.Request(reqID, p) }
 		},
 	}
-	sent := mclock.Now()
-	if err = odr.retriever.retrieve(ctx, reqID, rq, func(p distPeer, msg *Msg) error { return lreq.Validate(odr.db, msg) }, odr.stop); err == nil {
-		// retrieved from network, store in db
-		req.StoreResult(odr.db)
+
+	defer func(sent mclock.AbsTime) {
+		if err != nil {
+			return
+		}
 		requestRTT.Update(time.Duration(mclock.Now() - sent))
-	} else {
-		log.Debug("Failed to retrieve data from network", "err", err)
+	}(mclock.Now())
+
+	if err := odr.retriever.retrieve(ctx, reqID, rq, func(p distPeer, msg *Msg) error { return lreq.Validate(odr.db, msg) }, odr.stop); err != nil {
+		return err
 	}
-	return
+	req.StoreResult(odr.db)
+	return nil
 }
diff --git a/les/odr_requests.go b/les/odr_requests.go
index 3704436a0..eb1d3602e 100644
--- a/les/odr_requests.go
+++ b/les/odr_requests.go
@@ -327,9 +327,6 @@ func (r *ChtRequest) CanSend(peer *serverPeer) bool {
 	peer.lock.RLock()
 	defer peer.lock.RUnlock()
 
-	if r.Untrusted {
-		return peer.headInfo.Number >= r.BlockNum && peer.id == r.PeerId
-	}
 	return peer.headInfo.Number >= r.Config.ChtConfirms && r.ChtNum <= (peer.headInfo.Number-r.Config.ChtConfirms)/r.Config.ChtSize
 }
 
@@ -369,39 +366,34 @@ func (r *ChtRequest) Validate(db ethdb.Database, msg *Msg) error {
 	if err := rlp.DecodeBytes(headerEnc, header); err != nil {
 		return errHeaderUnavailable
 	}
-
 	// Verify the CHT
-	// Note: For untrusted CHT request, there is no proof response but
-	// header data.
-	var node light.ChtNode
-	if !r.Untrusted {
-		var encNumber [8]byte
-		binary.BigEndian.PutUint64(encNumber[:], r.BlockNum)
-
-		reads := &readTraceDB{db: nodeSet}
-		value, err := trie.VerifyProof(r.ChtRoot, encNumber[:], reads)
-		if err != nil {
-			return fmt.Errorf("merkle proof verification failed: %v", err)
-		}
-		if len(reads.reads) != nodeSet.KeyCount() {
-			return errUselessNodes
-		}
+	var (
+		node      light.ChtNode
+		encNumber [8]byte
+	)
+	binary.BigEndian.PutUint64(encNumber[:], r.BlockNum)
 
-		if err := rlp.DecodeBytes(value, &node); err != nil {
-			return err
-		}
-		if node.Hash != header.Hash() {
-			return errCHTHashMismatch
-		}
-		if r.BlockNum != header.Number.Uint64() {
-			return errCHTNumberMismatch
-		}
+	reads := &readTraceDB{db: nodeSet}
+	value, err := trie.VerifyProof(r.ChtRoot, encNumber[:], reads)
+	if err != nil {
+		return fmt.Errorf("merkle proof verification failed: %v", err)
+	}
+	if len(reads.reads) != nodeSet.KeyCount() {
+		return errUselessNodes
+	}
+	if err := rlp.DecodeBytes(value, &node); err != nil {
+		return err
+	}
+	if node.Hash != header.Hash() {
+		return errCHTHashMismatch
+	}
+	if r.BlockNum != header.Number.Uint64() {
+		return errCHTNumberMismatch
 	}
 	// Verifications passed, store and return
 	r.Header = header
 	r.Proof = nodeSet
-	r.Td = node.Td // For untrusted request, td here is nil, todo improve the les/2 protocol
-
+	r.Td = node.Td
 	return nil
 }
 
diff --git a/les/retrieve.go b/les/retrieve.go
index 4f77004f2..ca4f867ea 100644
--- a/les/retrieve.go
+++ b/les/retrieve.go
@@ -155,6 +155,15 @@ func (rm *retrieveManager) sendReq(reqID uint64, req *distReq, val validatorFunc
 	return r
 }
 
+// requested reports whether the request with given reqid is sent by the retriever.
+func (rm *retrieveManager) requested(reqId uint64) bool {
+	rm.lock.RLock()
+	defer rm.lock.RUnlock()
+
+	_, ok := rm.sentReqs[reqId]
+	return ok
+}
+
 // deliver is called by the LES protocol manager to deliver reply messages to waiting requests
 func (rm *retrieveManager) deliver(peer distPeer, msg *Msg) error {
 	rm.lock.RLock()
diff --git a/les/sync.go b/les/sync.go
index d2568d45b..ad3a0e0f3 100644
--- a/les/sync.go
+++ b/les/sync.go
@@ -56,8 +56,8 @@ func (h *clientHandler) validateCheckpoint(peer *serverPeer) error {
 	defer cancel()
 
 	// Fetch the block header corresponding to the checkpoint registration.
-	cp := peer.checkpoint
-	header, err := light.GetUntrustedHeaderByNumber(ctx, h.backend.odr, peer.checkpointNumber, peer.id)
+	wrapPeer := &peerConnection{handler: h, peer: peer}
+	header, err := wrapPeer.RetrieveSingleHeaderByNumber(ctx, peer.checkpointNumber)
 	if err != nil {
 		return err
 	}
@@ -66,7 +66,7 @@ func (h *clientHandler) validateCheckpoint(peer *serverPeer) error {
 	if err != nil {
 		return err
 	}
-	events := h.backend.oracle.Contract().LookupCheckpointEvents(logs, cp.SectionIndex, cp.Hash())
+	events := h.backend.oracle.Contract().LookupCheckpointEvents(logs, peer.checkpoint.SectionIndex, peer.checkpoint.Hash())
 	if len(events) == 0 {
 		return errInvalidCheckpoint
 	}
diff --git a/les/sync_test.go b/les/sync_test.go
index 6924e7b43..2eb0f88bf 100644
--- a/les/sync_test.go
+++ b/les/sync_test.go
@@ -53,7 +53,7 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) {
 			time.Sleep(10 * time.Millisecond)
 		}
 	}
-	// Generate 512+4 blocks (totally 1 CHT sections)
+	// Generate 128+1 blocks (totally 1 CHT sections)
 	server, client, tearDown := newClientServerEnv(t, int(config.ChtSize+config.ChtConfirms), protocol, waitIndexers, nil, 0, false, false, true)
 	defer tearDown()
 
diff --git a/light/odr.go b/light/odr.go
index 7016ef8ef..bb243f915 100644
--- a/light/odr.go
+++ b/light/odr.go
@@ -135,8 +135,6 @@ func (req *ReceiptsRequest) StoreResult(db ethdb.Database) {
 
 // ChtRequest is the ODR request type for retrieving header by Canonical Hash Trie
 type ChtRequest struct {
-	Untrusted        bool   // Indicator whether the result retrieved is trusted or not
-	PeerId           string // The specified peer id from which to retrieve data.
 	Config           *IndexerConfig
 	ChtNum, BlockNum uint64
 	ChtRoot          common.Hash
@@ -148,12 +146,9 @@ type ChtRequest struct {
 // StoreResult stores the retrieved data in local database
 func (req *ChtRequest) StoreResult(db ethdb.Database) {
 	hash, num := req.Header.Hash(), req.Header.Number.Uint64()
-
-	if !req.Untrusted {
-		rawdb.WriteHeader(db, req.Header)
-		rawdb.WriteTd(db, hash, num, req.Td)
-		rawdb.WriteCanonicalHash(db, hash, num)
-	}
+	rawdb.WriteHeader(db, req.Header)
+	rawdb.WriteTd(db, hash, num, req.Td)
+	rawdb.WriteCanonicalHash(db, hash, num)
 }
 
 // BloomRequest is the ODR request type for retrieving bloom filters from a CHT structure
diff --git a/light/odr_util.go b/light/odr_util.go
index aec0c7b69..ec2d1e649 100644
--- a/light/odr_util.go
+++ b/light/odr_util.go
@@ -19,20 +19,23 @@ package light
 import (
 	"bytes"
 	"context"
+	"errors"
 	"math/big"
 
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/core"
 	"github.com/ethereum/go-ethereum/core/rawdb"
 	"github.com/ethereum/go-ethereum/core/types"
-	"github.com/ethereum/go-ethereum/crypto"
 	"github.com/ethereum/go-ethereum/rlp"
 )
 
-var sha3Nil = crypto.Keccak256Hash(nil)
+// errNonCanonicalHash is returned if the requested chain data doesn't belong
+// to the canonical chain. ODR can only retrieve the canonical chain data covered
+// by the CHT or Bloom trie for verification.
+var errNonCanonicalHash = errors.New("hash is not currently canonical")
 
 // GetHeaderByNumber retrieves the canonical block header corresponding to the
-// given number.
+// given number. The returned header is proven by local CHT.
 func GetHeaderByNumber(ctx context.Context, odr OdrBackend, number uint64) (*types.Header, error) {
 	// Try to find it in the local database first.
 	db := odr.Database()
@@ -63,25 +66,6 @@ func GetHeaderByNumber(ctx context.Context, odr OdrBackend, number uint64) (*typ
 	return r.Header, nil
 }
 
-// GetUntrustedHeaderByNumber retrieves specified block header without
-// correctness checking. Note this function should only be used in light
-// client checkpoint syncing.
-func GetUntrustedHeaderByNumber(ctx context.Context, odr OdrBackend, number uint64, peerId string) (*types.Header, error) {
-	// todo(rjl493456442) it's a hack to retrieve headers which is not covered
-	// by CHT. Fix it in LES4
-	r := &ChtRequest{
-		BlockNum:  number,
-		ChtNum:    number / odr.IndexerConfig().ChtSize,
-		Untrusted: true,
-		PeerId:    peerId,
-		Config:    odr.IndexerConfig(),
-	}
-	if err := odr.Retrieve(ctx, r); err != nil {
-		return nil, err
-	}
-	return r.Header, nil
-}
-
 // GetCanonicalHash retrieves the canonical block hash corresponding to the number.
 func GetCanonicalHash(ctx context.Context, odr OdrBackend, number uint64) (common.Hash, error) {
 	hash := rawdb.ReadCanonicalHash(odr.Database(), number)
@@ -102,10 +86,13 @@ func GetTd(ctx context.Context, odr OdrBackend, hash common.Hash, number uint64)
 	if td != nil {
 		return td, nil
 	}
-	_, err := GetHeaderByNumber(ctx, odr, number)
+	header, err := GetHeaderByNumber(ctx, odr, number)
 	if err != nil {
 		return nil, err
 	}
+	if header.Hash() != hash {
+		return nil, errNonCanonicalHash
+	}
 	// <hash, number> -> td mapping already be stored in db, get it.
 	return rawdb.ReadTd(odr.Database(), hash, number), nil
 }
@@ -120,6 +107,9 @@ func GetBodyRLP(ctx context.Context, odr OdrBackend, hash common.Hash, number ui
 	if err != nil {
 		return nil, errNoHeader
 	}
+	if header.Hash() != hash {
+		return nil, errNonCanonicalHash
+	}
 	r := &BlockRequest{Hash: hash, Number: number, Header: header}
 	if err := odr.Retrieve(ctx, r); err != nil {
 		return nil, err
@@ -167,6 +157,9 @@ func GetBlockReceipts(ctx context.Context, odr OdrBackend, hash common.Hash, num
 		if err != nil {
 			return nil, errNoHeader
 		}
+		if header.Hash() != hash {
+			return nil, errNonCanonicalHash
+		}
 		r := &ReceiptsRequest{Hash: hash, Number: number, Header: header}
 		if err := odr.Retrieve(ctx, r); err != nil {
 			return nil, err
diff --git a/light/trie.go b/light/trie.go
index 3eb05f4a3..0516b9448 100644
--- a/light/trie.go
+++ b/light/trie.go
@@ -30,6 +30,10 @@ import (
 	"github.com/ethereum/go-ethereum/trie"
 )
 
+var (
+	sha3Nil = crypto.Keccak256Hash(nil)
+)
+
 func NewState(ctx context.Context, head *types.Header, odr OdrBackend) *state.StateDB {
 	state, _ := state.New(head.Root, NewStateDatabase(ctx, head, odr), nil)
 	return state
-- 
GitLab