From a829a5658723ff3681f14650818ef050cb0a7fa8 Mon Sep 17 00:00:00 2001
From: Felix Lange <fjl@twurst.com>
Date: Sat, 21 Mar 2015 00:49:31 +0100
Subject: [PATCH] rlp: add Stream.Raw

---
 rlp/decode.go      | 25 +++++++++++++++++++++++++
 rlp/decode_test.go | 16 +++++++++++++++-
 rlp/encode.go      | 30 +++++++++++++++++++-----------
 3 files changed, 59 insertions(+), 12 deletions(-)

diff --git a/rlp/decode.go b/rlp/decode.go
index 0e99d9caa..0fde0a947 100644
--- a/rlp/decode.go
+++ b/rlp/decode.go
@@ -540,6 +540,31 @@ func (s *Stream) Bytes() ([]byte, error) {
 	}
 }
 
+// Raw reads a raw encoded value including RLP type information.
+func (s *Stream) Raw() ([]byte, error) {
+	kind, size, err := s.Kind()
+	if err != nil {
+		return nil, err
+	}
+	if kind == Byte {
+		s.kind = -1 // rearm Kind
+		return []byte{s.byteval}, nil
+	}
+	// the original header has already been read and is no longer
+	// available. read content and put a new header in front of it.
+	start := headsize(size)
+	buf := make([]byte, uint64(start)+size)
+	if err := s.readFull(buf[start:]); err != nil {
+		return nil, err
+	}
+	if kind == String {
+		puthead(buf, 0x80, 0xB8, size)
+	} else {
+		puthead(buf, 0xC0, 0xF7, size)
+	}
+	return buf, nil
+}
+
 var errUintOverflow = errors.New("rlp: uint overflow")
 
 // Uint reads an RLP string of up to 8 bytes and returns its contents
diff --git a/rlp/decode_test.go b/rlp/decode_test.go
index 0f034d5d8..a18ff1d08 100644
--- a/rlp/decode_test.go
+++ b/rlp/decode_test.go
@@ -165,6 +165,20 @@ func TestStreamList(t *testing.T) {
 	}
 }
 
+func TestStreamRaw(t *testing.T) {
+	s := NewStream(bytes.NewReader(unhex("C58401010101")))
+	s.List()
+
+	want := unhex("8401010101")
+	raw, err := s.Raw()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !bytes.Equal(want, raw) {
+		t.Errorf("raw mismatch: got %x, want %x", raw, want)
+	}
+}
+
 func TestDecodeErrors(t *testing.T) {
 	r := bytes.NewReader(nil)
 
@@ -331,7 +345,7 @@ var decodeTests = []decodeTest{
 	{input: "C109", ptr: new(*[]uint), value: &[]uint{9}},
 	{input: "C58403030303", ptr: new(*[][]byte), value: &[][]byte{{3, 3, 3, 3}}},
 
-	// check that input position is advanced also empty values.
+	// check that input position is advanced also for empty values.
 	{input: "C3808005", ptr: new([]*uint), value: []*uint{nil, nil, uintp(5)}},
 
 	// pointer should be reset to nil
diff --git a/rlp/encode.go b/rlp/encode.go
index 7ac74d8fb..289bc4eaa 100644
--- a/rlp/encode.go
+++ b/rlp/encode.go
@@ -70,7 +70,7 @@ func (e flatenc) EncodeRLP(out io.Writer) error {
 	newhead := eb.lheads[prevnheads]
 	copy(eb.lheads[prevnheads:], eb.lheads[prevnheads+1:])
 	eb.lheads = eb.lheads[:len(eb.lheads)-1]
-	eb.lhsize -= newhead.tagsize()
+	eb.lhsize -= headsize(uint64(newhead.size))
 	return nil
 }
 
@@ -155,21 +155,29 @@ type listhead struct {
 // encode writes head to the given buffer, which must be at least
 // 9 bytes long. It returns the encoded bytes.
 func (head *listhead) encode(buf []byte) []byte {
-	if head.size < 56 {
-		buf[0] = 0xC0 + byte(head.size)
-		return buf[:1]
-	} else {
-		sizesize := putint(buf[1:], uint64(head.size))
-		buf[0] = 0xF7 + byte(sizesize)
-		return buf[:sizesize+1]
+	return buf[:puthead(buf, 0xC0, 0xF7, uint64(head.size))]
+}
+
+// headsize returns the size of a list or string header
+// for a value of the given size.
+func headsize(size uint64) int {
+	if size < 56 {
+		return 1
 	}
+	return 1 + intsize(size)
 }
 
-func (head *listhead) tagsize() int {
-	if head.size < 56 {
+// puthead writes a list or string header to buf.
+// buf must be at least 9 bytes long.
+func puthead(buf []byte, smalltag, largetag byte, size uint64) int {
+	if size < 56 {
+		buf[0] = smalltag + byte(size)
 		return 1
+	} else {
+		sizesize := putint(buf[1:], size)
+		buf[0] = largetag + byte(sizesize)
+		return sizesize + 1
 	}
-	return 1 + intsize(uint64(head.size))
 }
 
 func newencbuf() *encbuf {
-- 
GitLab