From bf06095811795d8929e7f42da57811c9bcef7916 Mon Sep 17 00:00:00 2001
From: TBC Dev <48684072+tbcd@users.noreply.github.com>
Date: Tue, 5 Apr 2022 09:07:07 +0800
Subject: [PATCH] hack: Allow blockTotal <= 0 as offset from head (#3818)

---
 cmd/hack/hack.go | 43 +++++++++++++++++++++++++++++++++++--------
 1 file changed, 35 insertions(+), 8 deletions(-)

diff --git a/cmd/hack/hack.go b/cmd/hack/hack.go
index 992ee5b9cf..b6b089d8e9 100644
--- a/cmd/hack/hack.go
+++ b/cmd/hack/hack.go
@@ -66,7 +66,7 @@ var (
 	cpuprofile = flag.String("cpuprofile", "", "write cpu profile `file`")
 	rewind     = flag.Int("rewind", 1, "rewind to given number of blocks")
 	block      = flag.Int("block", 1, "specifies a block number for operation")
-	blockTotal = flag.Int("blocktotal", 1, "specifies a total amount of blocks to process")
+	blockTotal = flag.Int("blocktotal", 1, "specifies a total amount of blocks to process (will offset from head block if <= 0)")
 	account    = flag.String("account", "0x", "specifies account to investigate")
 	name       = flag.String("name", "", "name to add to the file names")
 	chaindata  = flag.String("chaindata", "chaindata", "path to the chaindata database file")
@@ -647,13 +647,23 @@ func testBlockHashes(chaindata string, block int, stateRoot common.Hash) {
 	}))
 }
 
+func getCurrentBlockNumber(tx kv.Tx) *uint64 {
+	hash := rawdb.ReadHeadBlockHash(tx)
+	if hash == (common.Hash{}) {
+		return nil
+	}
+	return rawdb.ReadHeaderNumber(tx, hash)
+}
+
 func printCurrentBlockNumber(chaindata string) {
 	ethDb := mdbx.MustOpen(chaindata)
 	defer ethDb.Close()
 	ethDb.View(context.Background(), func(tx kv.Tx) error {
-		hash := rawdb.ReadHeadBlockHash(tx)
-		number := rawdb.ReadHeaderNumber(tx, hash)
-		fmt.Printf("Block number: %d\n", *number)
+		if number := getCurrentBlockNumber(tx); number != nil {
+			fmt.Printf("Block number: %d\n", *number)
+		} else {
+			fmt.Println("Block number: <nil>")
+		}
 		return nil
 	})
 }
@@ -1724,7 +1734,21 @@ func mint(chaindata string, block uint64) error {
 	return tx.Commit()
 }
 
-func extractHashes(chaindata string, blockStep uint64, blockTotal uint64, name string) error {
+func getBlockTotal(tx kv.Tx, blockFrom uint64, blockTotalOrOffset int64) uint64 {
+	if blockTotalOrOffset > 0 {
+		return uint64(blockTotalOrOffset)
+	}
+	if head := getCurrentBlockNumber(tx); head != nil {
+		if blockSub := uint64(-blockTotalOrOffset); blockSub <= *head {
+			if blockEnd := *head - blockSub; blockEnd > blockFrom {
+				return blockEnd - blockFrom + 1
+			}
+		}
+	}
+	return 1
+}
+
+func extractHashes(chaindata string, blockStep uint64, blockTotalOrOffset int64, name string) error {
 	db := mdbx.MustOpen(chaindata)
 	defer db.Close()
 
@@ -1742,6 +1766,8 @@ func extractHashes(chaindata string, blockStep uint64, blockTotal uint64, name s
 
 	b := uint64(0)
 	tool.Check(db.View(context.Background(), func(tx kv.Tx) error {
+		blockTotal := getBlockTotal(tx, b, blockTotalOrOffset)
+		// Note: blockTotal used here as block number rather than block count
 		for b <= blockTotal {
 			hash, err := rawdb.ReadCanonicalHash(tx, b)
 			if err != nil {
@@ -1765,7 +1791,7 @@ func extractHashes(chaindata string, blockStep uint64, blockTotal uint64, name s
 	return nil
 }
 
-func extractHeaders(chaindata string, block uint64, blockTotal uint64) error {
+func extractHeaders(chaindata string, block uint64, blockTotalOrOffset int64) error {
 	db := mdbx.MustOpen(chaindata)
 	defer db.Close()
 	tx, err := db.BeginRo(context.Background())
@@ -1779,6 +1805,7 @@ func extractHeaders(chaindata string, block uint64, blockTotal uint64) error {
 	}
 	defer c.Close()
 	blockEncoded := dbutils.EncodeBlockNumber(block)
+	blockTotal := getBlockTotal(tx, block, blockTotalOrOffset)
 	for k, v, err := c.Seek(blockEncoded); k != nil && blockTotal > 0; k, v, err = c.Next() {
 		if err != nil {
 			return err
@@ -2604,10 +2631,10 @@ func main() {
 		err = mint(*chaindata, uint64(*block))
 
 	case "extractHeaders":
-		err = extractHeaders(*chaindata, uint64(*block), uint64(*blockTotal))
+		err = extractHeaders(*chaindata, uint64(*block), int64(*blockTotal))
 
 	case "extractHashes":
-		err = extractHashes(*chaindata, uint64(*block), uint64(*blockTotal), *name)
+		err = extractHashes(*chaindata, uint64(*block), int64(*blockTotal), *name)
 
 	case "defrag":
 		err = hackdb.Defrag()
-- 
GitLab