From 2f0758e81719a6ac59d0e6e12556c1fdb5aa40f9 Mon Sep 17 00:00:00 2001 From: TBC Dev <48684072+tbcd@users.noreply.github.com> Date: Sun, 28 Nov 2021 01:28:17 +0800 Subject: [PATCH] ChainSegment efficiency (#3042) * De-dup blockHeaders66() and blockHeaders65() * Simplify loops and EOL detection * Add ChainSegmentHeader struct and refactor * Add RawRlpHash() to avoid re-encode for header hash * Avoid multiple redundant rlpHash() * Sort headers by height,hash to make dups consecutive * Flip condition to reduce map lookups * Remove redundant check * Use rlp.RawValue rather than []byte to help self-document --- cmd/sentry/download/downloader.go | 140 ++++------- core/types/hashing.go | 9 + turbo/stages/headerdownload/header_algos.go | 217 ++++++++++-------- .../headerdownload/header_data_struct.go | 13 +- turbo/stages/headerdownload/header_test.go | 55 +++-- 5 files changed, 215 insertions(+), 219 deletions(-) diff --git a/cmd/sentry/download/downloader.go b/cmd/sentry/download/downloader.go index 10fc8a1677..18a15e6b90 100644 --- a/cmd/sentry/download/downloader.go +++ b/cmd/sentry/download/downloader.go @@ -16,6 +16,7 @@ import ( "github.com/ledgerwatch/erigon-lib/direct" "github.com/ledgerwatch/erigon-lib/gointerfaces" proto_sentry "github.com/ledgerwatch/erigon-lib/gointerfaces/sentry" + proto_types "github.com/ledgerwatch/erigon-lib/gointerfaces/types" "github.com/ledgerwatch/erigon-lib/kv" "github.com/ledgerwatch/erigon/common" "github.com/ledgerwatch/erigon/consensus" @@ -468,7 +469,13 @@ func (cs *ControlServerImpl) newBlockHashes65(ctx context.Context, req *proto_se } func (cs *ControlServerImpl) blockHeaders66(ctx context.Context, in *proto_sentry.InboundMessage, sentry direct.SentryClient) error { - // Extract header from the block + // Parse the entire packet from scratch + var pkt eth.BlockHeadersPacket66 + if err := rlp.DecodeBytes(in.Data, &pkt); err != nil { + return fmt.Errorf("decode 1 BlockHeadersPacket66: %w", err) + } + + // Prepare to extract raw headers from the block rlpStream := rlp.NewStream(bytes.NewReader(in.Data), uint64(len(in.Data))) if _, err := rlpStream.List(); err != nil { // Now stream is at the beginning of 66 object return fmt.Errorf("decode 1 BlockHeadersPacket66: %w", err) @@ -476,114 +483,55 @@ func (cs *ControlServerImpl) blockHeaders66(ctx context.Context, in *proto_sentr if _, err := rlpStream.Uint(); err != nil { // Now stream is at the requestID field return fmt.Errorf("decode 2 BlockHeadersPacket66: %w", err) } - if _, err := rlpStream.List(); err != nil { // Now stream is at the BlockHeadersPacket, which is list of headers - return fmt.Errorf("decode 3 BlockHeadersPacket66: %w", err) - } - var headersRaw [][]byte - for headerRaw, err := rlpStream.Raw(); ; headerRaw, err = rlpStream.Raw() { - if err != nil { - if !errors.Is(err, rlp.EOL) { - return fmt.Errorf("decode 4 BlockHeadersPacket66: %w", err) - } - break - } + // Now stream is at the BlockHeadersPacket, which is list of headers - headersRaw = append(headersRaw, headerRaw) - } + return cs.blockHeaders(ctx, pkt.BlockHeadersPacket, rlpStream, in.PeerId, sentry) +} - // Parse the entire request from scratch - var request eth.BlockHeadersPacket66 - if err := rlp.DecodeBytes(in.Data, &request); err != nil { - return fmt.Errorf("decode 5 BlockHeadersPacket66: %w", err) - } - headers := request.BlockHeadersPacket - var heighestBlock uint64 - for _, h := range headers { - if h.Number.Uint64() > heighestBlock { - heighestBlock = h.Number.Uint64() - } +func (cs *ControlServerImpl) blockHeaders65(ctx context.Context, in *proto_sentry.InboundMessage, sentry direct.SentryClient) error { + // Parse the entire packet from scratch + var pkt eth.BlockHeadersPacket + if err := rlp.DecodeBytes(in.Data, &pkt); err != nil { + return fmt.Errorf("decode 1 BlockHeadersPacket65: %w", err) } - if segments, penalty, err := cs.Hd.SplitIntoSegments(headersRaw, headers); err == nil { - if penalty == headerdownload.NoPenalty { - var canRequestMore bool - for _, segment := range segments { - requestMore, penalties := cs.Hd.ProcessSegment(segment, false /* newBlock */, ConvertH256ToPeerID(in.PeerId)) - canRequestMore = canRequestMore || requestMore - if len(penalties) > 0 { - cs.Penalize(ctx, penalties) - } - } + // Prepare to extract raw headers from the block + rlpStream := rlp.NewStream(bytes.NewReader(in.Data), uint64(len(in.Data))) + // Now stream is at the BlockHeadersPacket, which is list of headers - if canRequestMore { - currentTime := uint64(time.Now().Unix()) - req, penalties := cs.Hd.RequestMoreHeaders(currentTime) - if req != nil { - if _, ok := cs.SendHeaderRequest(ctx, req); ok { - cs.Hd.SentRequest(req, currentTime, 5 /* timeout */) - log.Trace("Sent request", "height", req.Number) - } - } - cs.Penalize(ctx, penalties) - } - } else { - outreq := proto_sentry.PenalizePeerRequest{ - PeerId: in.PeerId, - Penalty: proto_sentry.PenaltyKind_Kick, // TODO: Extend penalty kinds - } - if _, err1 := sentry.PenalizePeer(ctx, &outreq, &grpc.EmptyCallOption{}); err1 != nil { - log.Error("Could not send penalty", "err", err1) - } - } - } else { - return fmt.Errorf("singleHeaderAsSegment failed: %w", err) - } - outreq := proto_sentry.PeerMinBlockRequest{ - PeerId: in.PeerId, - MinBlock: heighestBlock, - } - if _, err1 := sentry.PeerMinBlock(ctx, &outreq, &grpc.EmptyCallOption{}); err1 != nil { - log.Error("Could not send min block for peer", "err", err1) - } - return nil + return cs.blockHeaders(ctx, pkt, rlpStream, in.PeerId, sentry) } -func (cs *ControlServerImpl) blockHeaders65(ctx context.Context, in *proto_sentry.InboundMessage, sentry direct.SentryClient) error { - // Extract header from the block - rlpStream := rlp.NewStream(bytes.NewReader(in.Data), uint64(len(in.Data))) - if _, err := rlpStream.List(); err != nil { // Now stream is at the BlockHeadersPacket, which is list of headers - return fmt.Errorf("decode 3 BlockHeadersPacket66: %w", err) +func (cs *ControlServerImpl) blockHeaders(ctx context.Context, pkt eth.BlockHeadersPacket, rlpStream *rlp.Stream, peerID *proto_types.H256, sentry direct.SentryClient) error { + // Stream is at the BlockHeadersPacket, which is list of headers + if _, err := rlpStream.List(); err != nil { + return fmt.Errorf("decode 2 BlockHeadersPacket65: %w", err) } - var headersRaw [][]byte - for headerRaw, err := rlpStream.Raw(); ; headerRaw, err = rlpStream.Raw() { + // Extract headers from the block + var highestBlock uint64 + csHeaders := make([]headerdownload.ChainSegmentHeader, 0, len(pkt)) + for _, header := range pkt { + headerRaw, err := rlpStream.Raw() if err != nil { - if !errors.Is(err, rlp.EOL) { - return fmt.Errorf("decode 4 BlockHeadersPacket66: %w", err) - } - break + return fmt.Errorf("decode 3 BlockHeadersPacket65: %w", err) } - - headersRaw = append(headersRaw, headerRaw) - } - - // Parse the entire request from scratch - var request eth.BlockHeadersPacket - if err := rlp.DecodeBytes(in.Data, &request); err != nil { - return fmt.Errorf("decode 5 BlockHeadersPacket66: %w", err) - } - headers := request - var heighestBlock uint64 - for _, h := range headers { - if h.Number.Uint64() > heighestBlock { - heighestBlock = h.Number.Uint64() + number := header.Number.Uint64() + if number > highestBlock { + highestBlock = number } + csHeaders = append(csHeaders, headerdownload.ChainSegmentHeader{ + Header: header, + HeaderRaw: headerRaw, + Hash: types.RawRlpHash(headerRaw), + Number: number, + }) } - if segments, penalty, err := cs.Hd.SplitIntoSegments(headersRaw, headers); err == nil { + if segments, penalty, err := cs.Hd.SplitIntoSegments(csHeaders); err == nil { if penalty == headerdownload.NoPenalty { var canRequestMore bool for _, segment := range segments { - requestMore, penalties := cs.Hd.ProcessSegment(segment, false /* newBlock */, ConvertH256ToPeerID(in.PeerId)) + requestMore, penalties := cs.Hd.ProcessSegment(segment, false /* newBlock */, ConvertH256ToPeerID(peerID)) canRequestMore = canRequestMore || requestMore if len(penalties) > 0 { cs.Penalize(ctx, penalties) @@ -603,7 +551,7 @@ func (cs *ControlServerImpl) blockHeaders65(ctx context.Context, in *proto_sentr } } else { outreq := proto_sentry.PenalizePeerRequest{ - PeerId: in.PeerId, + PeerId: peerID, Penalty: proto_sentry.PenaltyKind_Kick, // TODO: Extend penalty kinds } if _, err1 := sentry.PenalizePeer(ctx, &outreq, &grpc.EmptyCallOption{}); err1 != nil { @@ -614,8 +562,8 @@ func (cs *ControlServerImpl) blockHeaders65(ctx context.Context, in *proto_sentr return fmt.Errorf("singleHeaderAsSegment failed: %w", err) } outreq := proto_sentry.PeerMinBlockRequest{ - PeerId: in.PeerId, - MinBlock: heighestBlock, + PeerId: peerID, + MinBlock: highestBlock, } if _, err1 := sentry.PeerMinBlock(ctx, &outreq, &grpc.EmptyCallOption{}); err1 != nil { log.Error("Could not send min block for peer", "err", err1) diff --git a/core/types/hashing.go b/core/types/hashing.go index d1173e3c50..8b08629ed7 100644 --- a/core/types/hashing.go +++ b/core/types/hashing.go @@ -167,6 +167,15 @@ var hasherPool = sync.Pool{ }, } +func RawRlpHash(rawRlpData rlp.RawValue) (h common.Hash) { + sha := hasherPool.Get().(crypto.KeccakState) + defer hasherPool.Put(sha) + sha.Reset() + sha.Write(rawRlpData) //nolint:errcheck + sha.Read(h[:]) //nolint:errcheck + return h +} + func rlpHash(x interface{}) (h common.Hash) { sha := hasherPool.Get().(crypto.KeccakState) defer hasherPool.Put(sha) diff --git a/turbo/stages/headerdownload/header_algos.go b/turbo/stages/headerdownload/header_algos.go index 0b5ce43ba3..f0f49808e7 100644 --- a/turbo/stages/headerdownload/header_algos.go +++ b/turbo/stages/headerdownload/header_algos.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "io" + "math" "math/big" "sort" "strings" @@ -28,61 +29,67 @@ import ( ) // Implements sort.Interface so we can sort the incoming header in the message by block height -type HeadersByBlockHeight []*types.Header +type HeadersByHeightAndHash []ChainSegmentHeader -func (h HeadersByBlockHeight) Len() int { +func (h HeadersByHeightAndHash) Len() int { return len(h) } -func (h HeadersByBlockHeight) Less(i, j int) bool { +func (h HeadersByHeightAndHash) Less(i, j int) bool { // Note - the ordering is the inverse ordering of the block heights - return h[i].Number.Cmp(h[j].Number) > 0 + if h[i].Number != h[j].Number { + return h[i].Number > h[j].Number + } + return bytes.Compare(h[i].Hash[:], h[j].Hash[:]) > 0 } -func (h HeadersByBlockHeight) Swap(i, j int) { +func (h HeadersByHeightAndHash) Swap(i, j int) { h[i], h[j] = h[j], h[i] } // SplitIntoSegments converts message containing headers into a collection of chain segments -func (hd *HeaderDownload) SplitIntoSegments(headersRaw [][]byte, msg []*types.Header) ([]*ChainSegment, Penalty, error) { +func (hd *HeaderDownload) SplitIntoSegments(csHeaders []ChainSegmentHeader) ([]ChainSegment, Penalty, error) { hd.lock.RLock() defer hd.lock.RUnlock() - sort.Sort(HeadersByBlockHeight(msg)) + sort.Sort(HeadersByHeightAndHash(csHeaders)) // Now all headers are order from the highest block height to the lowest - var segments []*ChainSegment // Segments being built + var segments []ChainSegment // Segments being built segmentMap := make(map[common.Hash]int) // Mapping of the header hash to the index of the chain segment it belongs childrenMap := make(map[common.Hash][]*types.Header) // Mapping parent hash to the children - dedupMap := make(map[common.Hash]struct{}) // Map used for detecting duplicate headers - for i, header := range msg { - headerHash := header.Hash() - if _, bad := hd.badHeaders[headerHash]; bad { - return nil, BadBlockPenalty, nil - } - if _, duplicate := dedupMap[headerHash]; duplicate { + + number := uint64(math.MaxUint64) + var hash common.Hash + for _, h := range csHeaders { + // Headers are sorted by number, then by hash, so any dups will be consecutive + if h.Number == number && h.Hash == hash { return nil, DuplicateHeaderPenalty, nil } - dedupMap[headerHash] = struct{}{} + number = h.Number + hash = h.Hash + + if _, bad := hd.badHeaders[hash]; bad { + return nil, BadBlockPenalty, nil + } var segmentIdx int - children := childrenMap[headerHash] + children := childrenMap[hash] for _, child := range children { - if valid, penalty := hd.childParentValid(child, header); !valid { + if valid, penalty := hd.childParentValid(child, h.Header); !valid { return nil, penalty, nil } } if len(children) == 1 { // Single child, extract segmentIdx - segmentIdx = segmentMap[headerHash] + segmentIdx = segmentMap[hash] } else { // No children, or more than one child, create new segment segmentIdx = len(segments) - segments = append(segments, &ChainSegment{}) + segments = append(segments, ChainSegment{}) } - segments[segmentIdx].Headers = append(segments[segmentIdx].Headers, header) - segments[segmentIdx].HeadersRaw = append(segments[segmentIdx].HeadersRaw, headersRaw[i]) - segmentMap[header.ParentHash] = segmentIdx - siblings := childrenMap[header.ParentHash] - siblings = append(siblings, header) - childrenMap[header.ParentHash] = siblings + segments[segmentIdx] = append(segments[segmentIdx], h) + segmentMap[h.Header.ParentHash] = segmentIdx + siblings := childrenMap[h.Header.ParentHash] + siblings = append(siblings, h.Header) + childrenMap[h.Header.ParentHash] = siblings } return segments, NoPenalty, nil } @@ -97,14 +104,21 @@ func (hd *HeaderDownload) childParentValid(child, parent *types.Header) (bool, P } // SingleHeaderAsSegment converts message containing 1 header into one singleton chain segment -func (hd *HeaderDownload) SingleHeaderAsSegment(headerRaw []byte, header *types.Header) ([]*ChainSegment, Penalty, error) { +func (hd *HeaderDownload) SingleHeaderAsSegment(headerRaw []byte, header *types.Header) ([]ChainSegment, Penalty, error) { hd.lock.RLock() defer hd.lock.RUnlock() - headerHash := header.Hash() + + headerHash := types.RawRlpHash(headerRaw) if _, bad := hd.badHeaders[headerHash]; bad { return nil, BadBlockPenalty, nil } - return []*ChainSegment{{HeadersRaw: [][]byte{headerRaw}, Headers: []*types.Header{header}}}, NoPenalty, nil + h := ChainSegmentHeader{ + Header: header, + HeaderRaw: headerRaw, + Hash: headerHash, + Number: header.Number.Uint64(), + } + return []ChainSegment{{h}}, NoPenalty, nil } // ReportBadHeader - @@ -122,11 +136,11 @@ func (hd *HeaderDownload) IsBadHeader(headerHash common.Hash) bool { } // FindAnchors attempts to find anchors to which given chain segment can be attached to -func (hd *HeaderDownload) findAnchors(segment *ChainSegment) (found bool, start int) { +func (hd *HeaderDownload) findAnchors(segment ChainSegment) (found bool, start int) { // Walk the segment from children towards parents - for i, header := range segment.Headers { + for i, h := range segment { // Check if the header can be attached to an anchor of a working tree - if _, attaching := hd.anchors[header.Hash()]; attaching { + if _, attaching := hd.anchors[h.Hash]; attaching { return true, i } } @@ -134,25 +148,25 @@ func (hd *HeaderDownload) findAnchors(segment *ChainSegment) (found bool, start } // FindLink attempts to find a non-persisted link that given chain segment can be attached to. -func (hd *HeaderDownload) findLink(segment *ChainSegment, start int) (found bool, end int) { - if _, duplicate := hd.getLink(segment.Headers[start].Hash()); duplicate { +func (hd *HeaderDownload) findLink(segment ChainSegment, start int) (found bool, end int) { + if _, duplicate := hd.getLink(segment[start].Hash); duplicate { return false, 0 } // Walk the segment from children towards parents - for i, header := range segment.Headers[start:] { + for i, h := range segment[start:] { // Check if the header can be attached to any links - if _, attaching := hd.getLink(header.ParentHash); attaching { + if _, attaching := hd.getLink(h.Header.ParentHash); attaching { return true, start + i + 1 } } - return false, len(segment.Headers) + return false, len(segment) } func (hd *HeaderDownload) removeUpwards(toRemove []*Link) { for len(toRemove) > 0 { removal := toRemove[len(toRemove)-1] toRemove = toRemove[:len(toRemove)-1] - delete(hd.links, removal.header.Hash()) + delete(hd.links, removal.hash) heap.Remove(hd.linkQueue, removal.idx) toRemove = append(toRemove, removal.next...) } @@ -167,12 +181,12 @@ func (hd *HeaderDownload) markPreverified(link *Link) { } // ExtendUp extends a working tree up from the link, using given chain segment -func (hd *HeaderDownload) extendUp(segment *ChainSegment, start, end int) error { +func (hd *HeaderDownload) extendUp(segment ChainSegment, start, end int) error { // Find attachment link again - linkHeader := segment.Headers[end-1] - attachmentLink, attaching := hd.getLink(linkHeader.ParentHash) + linkH := segment[end-1] + attachmentLink, attaching := hd.getLink(linkH.Header.ParentHash) if !attaching { - return fmt.Errorf("extendUp attachment link not found for %x", linkHeader.ParentHash) + return fmt.Errorf("extendUp attachment link not found for %x", linkH.Header.ParentHash) } if attachmentLink.preverified && len(attachmentLink.next) > 0 { return fmt.Errorf("cannot extendUp from preverified link %d with children", attachmentLink.blockHeight) @@ -180,7 +194,7 @@ func (hd *HeaderDownload) extendUp(segment *ChainSegment, start, end int) error // Iterate over headers backwards (from parents towards children) prevLink := attachmentLink for i := end - 1; i >= start; i-- { - link := hd.addHeaderAsLink(segment.Headers[i], false /* persisted */) + link := hd.addHeaderAsLink(segment[i], false /* persisted */) prevLink.next = append(prevLink.next, link) prevLink = link if _, ok := hd.preverifiedHashes[link.hash]; ok { @@ -189,7 +203,7 @@ func (hd *HeaderDownload) extendUp(segment *ChainSegment, start, end int) error } if _, bad := hd.badHeaders[attachmentLink.hash]; !bad && attachmentLink.persisted { - link := hd.links[linkHeader.Hash()] + link := hd.links[linkH.Hash] hd.insertList = append(hd.insertList, link) } return nil @@ -197,10 +211,10 @@ func (hd *HeaderDownload) extendUp(segment *ChainSegment, start, end int) error // ExtendDown extends some working trees down from the anchor, using given chain segment // it creates a new anchor and collects all the links from the attached anchors to it -func (hd *HeaderDownload) extendDown(segment *ChainSegment, start, end int) (bool, error) { +func (hd *HeaderDownload) extendDown(segment ChainSegment, start, end int) (bool, error) { // Find attachment anchor again - anchorHeader := segment.Headers[start] - if anchor, attaching := hd.anchors[anchorHeader.Hash()]; attaching { + anchorHash := segment[start].Hash + if anchor, attaching := hd.anchors[anchorHash]; attaching { anchorPreverified := false for _, link := range anchor.links { if link.preverified { @@ -208,7 +222,8 @@ func (hd *HeaderDownload) extendDown(segment *ChainSegment, start, end int) (boo break } } - newAnchorHeader := segment.Headers[end-1] + newAnchorH := segment[end-1] + newAnchorHeader := newAnchorH.Header var newAnchor *Anchor newAnchor, preExisting := hd.anchors[newAnchorHeader.ParentHash] if !preExisting { @@ -216,7 +231,7 @@ func (hd *HeaderDownload) extendDown(segment *ChainSegment, start, end int) (boo parentHash: newAnchorHeader.ParentHash, timestamp: 0, peerID: anchor.peerID, - blockHeight: newAnchorHeader.Number.Uint64(), + blockHeight: newAnchorH.Number, } if newAnchor.blockHeight > 0 { hd.anchors[newAnchorHeader.ParentHash] = newAnchor @@ -231,7 +246,7 @@ func (hd *HeaderDownload) extendDown(segment *ChainSegment, start, end int) (boo // Add all headers in the segments as links to this anchor var prevLink *Link for i := end - 1; i >= start; i-- { - link := hd.addHeaderAsLink(segment.Headers[i], false /* pesisted */) + link := hd.addHeaderAsLink(segment[i], false /* pesisted */) if prevLink == nil { newAnchor.links = append(newAnchor.links, link) } else { @@ -250,25 +265,25 @@ func (hd *HeaderDownload) extendDown(segment *ChainSegment, start, end int) (boo } return !preExisting, nil } - return false, fmt.Errorf("extendDown attachment anchors not found for %x", anchorHeader.Hash()) + return false, fmt.Errorf("extendDown attachment anchors not found for %x", anchorHash) } // Connect connects some working trees using anchors of some, and a link of another -func (hd *HeaderDownload) connect(segment *ChainSegment, start, end int) ([]PenaltyItem, error) { +func (hd *HeaderDownload) connect(segment ChainSegment, start, end int) ([]PenaltyItem, error) { // Find attachment link again - linkHeader := segment.Headers[end-1] + linkH := segment[end-1] // Find attachement anchors again - anchorHeader := segment.Headers[start] - attachmentLink, ok1 := hd.getLink(linkHeader.ParentHash) + anchorHash := segment[start].Hash + attachmentLink, ok1 := hd.getLink(linkH.Header.ParentHash) if !ok1 { - return nil, fmt.Errorf("connect attachment link not found for %x", linkHeader.ParentHash) + return nil, fmt.Errorf("connect attachment link not found for %x", linkH.Header.ParentHash) } if attachmentLink.preverified && len(attachmentLink.next) > 0 { return nil, fmt.Errorf("cannot connect to preverified link %d with children", attachmentLink.blockHeight) } - anchor, ok2 := hd.anchors[anchorHeader.Hash()] + anchor, ok2 := hd.anchors[anchorHash] if !ok2 { - return nil, fmt.Errorf("connect attachment anchors not found for %x", anchorHeader.Hash()) + return nil, fmt.Errorf("connect attachment anchors not found for %x", anchorHash) } anchorPreverified := false for _, link := range anchor.links { @@ -284,7 +299,7 @@ func (hd *HeaderDownload) connect(segment *ChainSegment, start, end int) ([]Pena // Iterate over headers backwards (from parents towards children) prevLink := attachmentLink for i := end - 1; i >= start; i-- { - link := hd.addHeaderAsLink(segment.Headers[i], false /* persisted */) + link := hd.addHeaderAsLink(segment[i], false /* persisted */) prevLink.next = append(prevLink.next, link) prevLink = link if _, ok := hd.preverifiedHashes[link.hash]; ok { @@ -302,18 +317,18 @@ func (hd *HeaderDownload) connect(segment *ChainSegment, start, end int) ([]Pena hd.invalidateAnchor(anchor) penalties = append(penalties, PenaltyItem{Penalty: AbandonedAnchorPenalty, PeerID: anchor.peerID}) } else if attachmentLink.persisted { - link := hd.links[linkHeader.Hash()] + link := hd.links[linkH.Hash] hd.insertList = append(hd.insertList, link) } return penalties, nil } -func (hd *HeaderDownload) removeAnchor(segment *ChainSegment, start int) error { +func (hd *HeaderDownload) removeAnchor(segment ChainSegment, start int) error { // Find attachement anchors again - anchorHeader := segment.Headers[start] - anchor, ok := hd.anchors[anchorHeader.Hash()] + anchorHash := segment[start].Hash + anchor, ok := hd.anchors[anchorHash] if !ok { - return fmt.Errorf("connect attachment anchors not found for %x", anchorHeader.Hash()) + return fmt.Errorf("connect attachment anchors not found for %x", anchorHash) } // Anchor is removed from the map, but not from the anchorQueue // This is because it is hard to find the index under which the anchor is stored in the anchorQueue @@ -323,14 +338,15 @@ func (hd *HeaderDownload) removeAnchor(segment *ChainSegment, start int) error { } // if anchor will be abandoned - given peerID will get Penalty -func (hd *HeaderDownload) newAnchor(segment *ChainSegment, start, end int, peerID enode.ID) (bool, error) { - anchorHeader := segment.Headers[end-1] +func (hd *HeaderDownload) newAnchor(segment ChainSegment, start, end int, peerID enode.ID) (bool, error) { + anchorH := segment[end-1] + anchorHeader := anchorH.Header var anchor *Anchor anchor, preExisting := hd.anchors[anchorHeader.ParentHash] if !preExisting { - if anchorHeader.Number.Uint64() < hd.highestInDb { - return false, fmt.Errorf("new anchor too far in the past: %d, latest header in db: %d", anchorHeader.Number.Uint64(), hd.highestInDb) + if anchorH.Number < hd.highestInDb { + return false, fmt.Errorf("new anchor too far in the past: %d, latest header in db: %d", anchorH.Number, hd.highestInDb) } if len(hd.anchors) >= hd.anchorLimit { return false, fmt.Errorf("too many anchors: %d, limit %d", len(hd.anchors), hd.anchorLimit) @@ -339,7 +355,7 @@ func (hd *HeaderDownload) newAnchor(segment *ChainSegment, start, end int, peerI parentHash: anchorHeader.ParentHash, peerID: peerID, timestamp: 0, - blockHeight: anchorHeader.Number.Uint64(), + blockHeight: anchorH.Number, } hd.anchors[anchorHeader.ParentHash] = anchor heap.Push(hd.anchorQueue, anchor) @@ -347,8 +363,7 @@ func (hd *HeaderDownload) newAnchor(segment *ChainSegment, start, end int, peerI // Iterate over headers backwards (from parents towards children) var prevLink *Link for i := end - 1; i >= start; i-- { - header := segment.Headers[i] - link := hd.addHeaderAsLink(header, false /* persisted */) + link := hd.addHeaderAsLink(segment[i], false /* persisted */) if prevLink == nil { anchor.links = append(anchor.links, link) } else { @@ -483,11 +498,17 @@ func (hd *HeaderDownload) RecoverFromDb(db kv.RoDB) error { if err != nil { return err } - var h types.Header - if err = rlp.DecodeBytes(v, &h); err != nil { + var header types.Header + if err = rlp.DecodeBytes(v, &header); err != nil { return err } - hd.addHeaderAsLink(&h, true /* persisted */) + h := ChainSegmentHeader{ + HeaderRaw: v, + Header: &header, + Hash: types.RawRlpHash(v), + Number: header.Number.Uint64(), + } + hd.addHeaderAsLink(h, true /* persisted */) select { case <-logEvery.C: @@ -614,11 +635,11 @@ func (hd *HeaderDownload) InsertHeaders(hf func(header *types.Header, hash commo if _, bad := hd.badHeaders[link.hash]; bad { skip = true } else if err := hd.engine.VerifyHeader(hd.headerReader, link.header, true /* seal */); err != nil { - log.Warn("Verification failed for header", "hash", link.header.Hash(), "height", link.blockHeight, "error", err) + log.Warn("Verification failed for header", "hash", link.hash, "height", link.blockHeight, "error", err) if errors.Is(err, consensus.ErrFutureBlock) { // This may become valid later linksInFuture = append(linksInFuture, link) - log.Warn("Added future link", "hash", link.header.Hash(), "height", link.blockHeight, "timestamp", link.header.Time) + log.Warn("Added future link", "hash", link.hash, "height", link.blockHeight, "timestamp", link.header.Time) continue // prevent removal of the link from the hd.linkQueue } else { skip = true @@ -702,16 +723,14 @@ func (hd *HeaderDownload) getLink(linkHash common.Hash) (*Link, bool) { } // addHeaderAsLink wraps header into a link and adds it to either queue of persisted links or queue of non-persisted links -func (hd *HeaderDownload) addHeaderAsLink(header *types.Header, persisted bool) *Link { - height := header.Number.Uint64() - linkHash := header.Hash() +func (hd *HeaderDownload) addHeaderAsLink(h ChainSegmentHeader, persisted bool) *Link { link := &Link{ - blockHeight: height, - hash: linkHash, - header: header, + blockHeight: h.Number, + hash: h.Hash, + header: h.Header, persisted: persisted, } - hd.links[linkHash] = link + hd.links[h.Hash] = link if persisted { heap.Push(hd.persistedLinkQueue, link) } else { @@ -860,8 +879,8 @@ func (hi *HeaderInserter) BestHeaderChanged() bool { // it allows higher-level algo immediately request more headers without waiting all stages precessing, // speeds up visibility of new blocks // It remember peerID - then later - if anchors created from segments will abandoned - this peerID gonna get Penalty -func (hd *HeaderDownload) ProcessSegment(segment *ChainSegment, newBlock bool, peerID enode.ID) (requestMore bool, penalties []PenaltyItem) { - log.Trace("processSegment", "from", segment.Headers[0].Number.Uint64(), "to", segment.Headers[len(segment.Headers)-1].Number.Uint64()) +func (hd *HeaderDownload) ProcessSegment(segment ChainSegment, newBlock bool, peerID enode.ID) (requestMore bool, penalties []PenaltyItem) { + log.Trace("processSegment", "from", segment[0].Number, "to", segment[len(segment)-1].Number) hd.lock.Lock() defer hd.lock.Unlock() foundAnchor, start := hd.findAnchors(segment) @@ -877,15 +896,14 @@ func (hd *HeaderDownload) ProcessSegment(segment *ChainSegment, newBlock bool, p } return } - height := segment.Headers[len(segment.Headers)-1].Number.Uint64() - hash := segment.Headers[len(segment.Headers)-1].Hash() - if newBlock || hd.seenAnnounces.Seen(hash) { - if height > hd.topSeenHeight { - hd.topSeenHeight = height + highest := segment[len(segment)-1] + if highest.Number > hd.topSeenHeight { + if newBlock || hd.seenAnnounces.Seen(highest.Hash) { + hd.topSeenHeight = highest.Number } } - startNum := segment.Headers[start].Number.Uint64() - endNum := segment.Headers[end-1].Number.Uint64() + startNum := segment[start].Number + endNum := segment[end-1].Number // There are 4 cases if foundAnchor { if foundTip { @@ -906,14 +924,12 @@ func (hd *HeaderDownload) ProcessSegment(segment *ChainSegment, newBlock bool, p log.Trace("Extended Down", "start", startNum, "end", endNum) } } else if foundTip { - if end > 0 { - // ExtendUp - if err := hd.extendUp(segment, start, end); err != nil { - log.Debug("ExtendUp failed", "error", err) - return - } - log.Trace("Extended Up", "start", startNum, "end", endNum) + // ExtendUp + if err := hd.extendUp(segment, start, end); err != nil { + log.Debug("ExtendUp failed", "error", err) + return } + log.Trace("Extended Up", "start", startNum, "end", endNum) } else { // NewAnchor var err error @@ -1052,7 +1068,8 @@ func DecodeTips(encodings []string) (map[common.Hash]HeaderRecord, error) { return nil, fmt.Errorf("parsing hard coded header on %d: %w", i, err) } - hardTips[h.Hash()] = HeaderRecord{Raw: b, Header: &h} + headerHash := types.RawRlpHash(res) + hardTips[headerHash] = HeaderRecord{Raw: b, Header: &h} buf.Reset() } diff --git a/turbo/stages/headerdownload/header_data_struct.go b/turbo/stages/headerdownload/header_data_struct.go index 149d9c7635..b6bc67e2b7 100644 --- a/turbo/stages/headerdownload/header_data_struct.go +++ b/turbo/stages/headerdownload/header_data_struct.go @@ -11,6 +11,7 @@ import ( "github.com/ledgerwatch/erigon/consensus" "github.com/ledgerwatch/erigon/core/types" "github.com/ledgerwatch/erigon/p2p/enode" + "github.com/ledgerwatch/erigon/rlp" ) // Link is a chain link that can be connect to other chain links @@ -118,12 +119,16 @@ func (aq *AnchorQueue) Pop() interface{} { return x } +type ChainSegmentHeader struct { + HeaderRaw rlp.RawValue + Header *types.Header + Hash common.Hash + Number uint64 +} + // First item in ChainSegment is the anchor // ChainSegment must be contigous and must not include bad headers -type ChainSegment struct { - HeadersRaw [][]byte - Headers []*types.Header -} +type ChainSegment []ChainSegmentHeader type PeerHandle int // This is int just for the PoC phase - will be replaced by more appropriate type to find a peer diff --git a/turbo/stages/headerdownload/header_test.go b/turbo/stages/headerdownload/header_test.go index f78070c6f2..266f322cc8 100644 --- a/turbo/stages/headerdownload/header_test.go +++ b/turbo/stages/headerdownload/header_test.go @@ -6,14 +6,30 @@ import ( "github.com/ledgerwatch/erigon/consensus/ethash" "github.com/ledgerwatch/erigon/core/types" + "github.com/ledgerwatch/erigon/rlp" ) +func newCSHeaders(headers ...*types.Header) []ChainSegmentHeader { + csHeaders := make([]ChainSegmentHeader, 0, len(headers)) + for _, header := range headers { + headerRaw, _ := rlp.EncodeToBytes(header) + h := ChainSegmentHeader{ + HeaderRaw: headerRaw, + Header: header, + Hash: header.Hash(), + Number: header.Number.Uint64(), + } + csHeaders = append(csHeaders, h) + } + return csHeaders +} + func TestSplitIntoSegments(t *testing.T) { engine := ethash.NewFaker() hd := NewHeaderDownload(100, 100, engine) // Empty message - if chainSegments, penalty, err := hd.SplitIntoSegments([][]byte{}, []*types.Header{}); err == nil { + if chainSegments, penalty, err := hd.SplitIntoSegments([]ChainSegmentHeader{}); err == nil { if penalty != NoPenalty { t.Errorf("unexpected penalty: %s", penalty) } @@ -27,7 +43,7 @@ func TestSplitIntoSegments(t *testing.T) { // Single header var h types.Header h.Number = big.NewInt(5) - if chainSegments, penalty, err := hd.SplitIntoSegments([][]byte{{}}, []*types.Header{&h}); err == nil { + if chainSegments, penalty, err := hd.SplitIntoSegments(newCSHeaders(&h)); err == nil { if penalty != NoPenalty { t.Errorf("unexpected penalty: %s", penalty) } @@ -39,7 +55,7 @@ func TestSplitIntoSegments(t *testing.T) { } // Same header repeated twice - if chainSegments, penalty, err := hd.SplitIntoSegments([][]byte{{}, {}}, []*types.Header{&h, &h}); err == nil { + if chainSegments, penalty, err := hd.SplitIntoSegments(newCSHeaders(&h, &h)); err == nil { if penalty != DuplicateHeaderPenalty { t.Errorf("expected DuplicateHeader penalty, got %s", penalty) } @@ -52,7 +68,7 @@ func TestSplitIntoSegments(t *testing.T) { // Single header with a bad hash hd.ReportBadHeader(h.Hash()) - if chainSegments, penalty, err := hd.SplitIntoSegments([][]byte{{}}, []*types.Header{&h}); err == nil { + if chainSegments, penalty, err := hd.SplitIntoSegments(newCSHeaders(&h)); err == nil { if penalty != BadBlockPenalty { t.Errorf("expected BadBlock penalty, got %s", penalty) } @@ -70,17 +86,17 @@ func TestSplitIntoSegments(t *testing.T) { h2.Number = big.NewInt(2) h2.Difficulty = big.NewInt(1010) h2.ParentHash = h1.Hash() - if chainSegments, penalty, err := hd.SplitIntoSegments([][]byte{{}, {}}, []*types.Header{&h1, &h2}); err == nil { + if chainSegments, penalty, err := hd.SplitIntoSegments(newCSHeaders(&h1, &h2)); err == nil { if penalty != NoPenalty { t.Errorf("unexpected penalty: %s", penalty) } if len(chainSegments) != 1 { t.Errorf("expected 1 chainSegments, got %d", len(chainSegments)) } - if len(chainSegments[0].Headers) != 2 { - t.Errorf("expected chainSegment of the length 2, got %d", len(chainSegments[0].Headers)) + if len(chainSegments[0]) != 2 { + t.Errorf("expected chainSegment of the length 2, got %d", len(chainSegments[0])) } - if chainSegments[0].Headers[0] != &h2 { + if chainSegments[0][0].Header != &h2 { t.Errorf("expected h2 to be the root") } } else { @@ -89,7 +105,7 @@ func TestSplitIntoSegments(t *testing.T) { // Two connected headers with wrong numbers h2.Number = big.NewInt(3) // Child number 3, parent number 1 - if chainSegments, penalty, err := hd.SplitIntoSegments([][]byte{{}, {}}, []*types.Header{&h1, &h2}); err == nil { + if chainSegments, penalty, err := hd.SplitIntoSegments(newCSHeaders(&h1, &h2)); err == nil { if penalty != WrongChildBlockHeightPenalty { t.Errorf("expected WrongChildBlockHeight penalty, got %s", penalty) } @@ -113,17 +129,17 @@ func TestSplitIntoSegments(t *testing.T) { h3.Extra = []byte("I'm different") // To make sure the hash of h3 is different from the hash of h2 // Same three headers, but in a reverse order - if chainSegments, penalty, err := hd.SplitIntoSegments([][]byte{{}, {}, {}}, []*types.Header{&h3, &h2, &h1}); err == nil { + if chainSegments, penalty, err := hd.SplitIntoSegments(newCSHeaders(&h3, &h2, &h1)); err == nil { if penalty != NoPenalty { t.Errorf("unexpected penalty: %s", penalty) } if len(chainSegments) != 3 { t.Errorf("expected 3 chainSegments, got %d", len(chainSegments)) } - if len(chainSegments[0].Headers) != 1 { - t.Errorf("expected chainSegment of the length 1, got %d", len(chainSegments[0].Headers)) + if len(chainSegments[0]) != 1 { + t.Errorf("expected chainSegment of the length 1, got %d", len(chainSegments[0])) } - if chainSegments[2].Headers[0] != &h1 { + if chainSegments[2][0].Header != &h1 { t.Errorf("expected h1 to be the root") } } else { @@ -131,7 +147,7 @@ func TestSplitIntoSegments(t *testing.T) { } // Two headers not connected to each other - if chainSegments, penalty, err := hd.SplitIntoSegments([][]byte{{}, {}}, []*types.Header{&h3, &h2}); err == nil { + if chainSegments, penalty, err := hd.SplitIntoSegments(newCSHeaders(&h3, &h2)); err == nil { if penalty != NoPenalty { t.Errorf("unexpected penalty: %s", penalty) } @@ -148,17 +164,18 @@ func TestSingleHeaderAsSegment(t *testing.T) { hd := NewHeaderDownload(100, 100, engine) var h types.Header h.Number = big.NewInt(5) - if chainSegments, penalty, err := hd.SingleHeaderAsSegment([]byte{}, &h); err == nil { + headerRaw, _ := rlp.EncodeToBytes(h) + if chainSegments, penalty, err := hd.SingleHeaderAsSegment(headerRaw, &h); err == nil { if penalty != NoPenalty { t.Errorf("unexpected penalty: %s", penalty) } if len(chainSegments) != 1 { t.Errorf("expected 1 chainSegments, got %d", len(chainSegments)) } - if len(chainSegments[0].Headers) != 1 { - t.Errorf("expected chainSegment of the length 1, got %d", len(chainSegments[0].Headers)) + if len(chainSegments[0]) != 1 { + t.Errorf("expected chainSegment of the length 1, got %d", len(chainSegments[0])) } - if chainSegments[0].Headers[0] != &h { + if chainSegments[0][0].Header != &h { t.Errorf("expected h to be the root") } } else { @@ -167,7 +184,7 @@ func TestSingleHeaderAsSegment(t *testing.T) { // Same header with a bad hash hd.ReportBadHeader(h.Hash()) - if chainSegments, penalty, err := hd.SingleHeaderAsSegment([]byte{}, &h); err == nil { + if chainSegments, penalty, err := hd.SingleHeaderAsSegment(headerRaw, &h); err == nil { if penalty != BadBlockPenalty { t.Errorf("expected BadBlock penalty, got %s", penalty) } -- GitLab