From f7112cc182ec9ec43ff56d4ff3c84d2518aa30ff Mon Sep 17 00:00:00 2001
From: Felix Lange <fjl@twurst.com>
Date: Mon, 14 Sep 2020 19:23:01 +0200
Subject: [PATCH] rlp: add SplitUint64 (#21563)

This can be useful when working with raw RLP data.
---
 rlp/raw.go      | 26 ++++++++++++++++++++++++++
 rlp/raw_test.go | 45 +++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 71 insertions(+)

diff --git a/rlp/raw.go b/rlp/raw.go
index 2b3f328f6..c2a8517f6 100644
--- a/rlp/raw.go
+++ b/rlp/raw.go
@@ -57,6 +57,32 @@ func SplitString(b []byte) (content, rest []byte, err error) {
 	return content, rest, nil
 }
 
+// SplitUint64 decodes an integer at the beginning of b.
+// It also returns the remaining data after the integer in 'rest'.
+func SplitUint64(b []byte) (x uint64, rest []byte, err error) {
+	content, rest, err := SplitString(b)
+	if err != nil {
+		return 0, b, err
+	}
+	switch {
+	case len(content) == 0:
+		return 0, rest, nil
+	case len(content) == 1:
+		if content[0] == 0 {
+			return 0, b, ErrCanonInt
+		}
+		return uint64(content[0]), rest, nil
+	case len(content) > 8:
+		return 0, b, errUintOverflow
+	default:
+		x, err = readSize(content, byte(len(content)))
+		if err != nil {
+			return 0, b, ErrCanonInt
+		}
+		return x, rest, nil
+	}
+}
+
 // SplitList splits b into the content of a list and any remaining
 // bytes after the list.
 func SplitList(b []byte) (content, rest []byte, err error) {
diff --git a/rlp/raw_test.go b/rlp/raw_test.go
index 2aad04210..cdae4ff08 100644
--- a/rlp/raw_test.go
+++ b/rlp/raw_test.go
@@ -71,6 +71,49 @@ func TestSplitTypes(t *testing.T) {
 	}
 }
 
+func TestSplitUint64(t *testing.T) {
+	tests := []struct {
+		input string
+		val   uint64
+		rest  string
+		err   error
+	}{
+		{"01", 1, "", nil},
+		{"7FFF", 0x7F, "FF", nil},
+		{"80FF", 0, "FF", nil},
+		{"81FAFF", 0xFA, "FF", nil},
+		{"82FAFAFF", 0xFAFA, "FF", nil},
+		{"83FAFAFAFF", 0xFAFAFA, "FF", nil},
+		{"84FAFAFAFAFF", 0xFAFAFAFA, "FF", nil},
+		{"85FAFAFAFAFAFF", 0xFAFAFAFAFA, "FF", nil},
+		{"86FAFAFAFAFAFAFF", 0xFAFAFAFAFAFA, "FF", nil},
+		{"87FAFAFAFAFAFAFAFF", 0xFAFAFAFAFAFAFA, "FF", nil},
+		{"88FAFAFAFAFAFAFAFAFF", 0xFAFAFAFAFAFAFAFA, "FF", nil},
+
+		// errors
+		{"", 0, "", io.ErrUnexpectedEOF},
+		{"00", 0, "00", ErrCanonInt},
+		{"81", 0, "81", ErrValueTooLarge},
+		{"8100", 0, "8100", ErrCanonSize},
+		{"8200FF", 0, "8200FF", ErrCanonInt},
+		{"8103FF", 0, "8103FF", ErrCanonSize},
+		{"89FAFAFAFAFAFAFAFAFAFF", 0, "89FAFAFAFAFAFAFAFAFAFF", errUintOverflow},
+	}
+
+	for i, test := range tests {
+		val, rest, err := SplitUint64(unhex(test.input))
+		if val != test.val {
+			t.Errorf("test %d: val mismatch: got %x, want %x (input %q)", i, val, test.val, test.input)
+		}
+		if !bytes.Equal(rest, unhex(test.rest)) {
+			t.Errorf("test %d: rest mismatch: got %x, want %s (input %q)", i, rest, test.rest, test.input)
+		}
+		if err != test.err {
+			t.Errorf("test %d: error mismatch: got %q, want %q", i, err, test.err)
+		}
+	}
+}
+
 func TestSplit(t *testing.T) {
 	tests := []struct {
 		input     string
@@ -78,7 +121,9 @@ func TestSplit(t *testing.T) {
 		val, rest string
 		err       error
 	}{
+		{input: "00FFFF", kind: Byte, val: "00", rest: "FFFF"},
 		{input: "01FFFF", kind: Byte, val: "01", rest: "FFFF"},
+		{input: "7FFFFF", kind: Byte, val: "7F", rest: "FFFF"},
 		{input: "80FFFF", kind: String, val: "", rest: "FFFF"},
 		{input: "C3010203", kind: List, val: "010203"},
 
-- 
GitLab