From cad64fb911e7029bef876f16e0956b3b0b4bb4d0 Mon Sep 17 00:00:00 2001
From: Felix Lange <fjl@twurst.com>
Date: Fri, 17 Apr 2015 01:16:46 +0200
Subject: [PATCH] rlp: stricter rules for structs and pointers

The rules have changed as follows:

* When decoding into pointers, empty values no longer produce
  a nil pointer. This can be overriden for struct fields using the
  struct tag "nil".
* When decoding into structs, the input list must contain an element
  for each field.
---
 rlp/decode.go      | 76 ++++++++++++++++++++++++++++++++--------------
 rlp/decode_test.go | 65 +++++++++++++++++++++++++++++++--------
 rlp/encode.go      |  8 ++---
 rlp/typecache.go   | 51 ++++++++++++++++++++++---------
 4 files changed, 148 insertions(+), 52 deletions(-)

diff --git a/rlp/decode.go b/rlp/decode.go
index 43dd716b5..394f83fb2 100644
--- a/rlp/decode.go
+++ b/rlp/decode.go
@@ -36,17 +36,26 @@ type Decoder interface {
 // If the type implements the Decoder interface, decode calls
 // DecodeRLP.
 //
-// To decode into a pointer, Decode will set the pointer to nil if the
-// input has size zero. If the input has nonzero size, Decode will
-// parse the input data into a value of the type being pointed to.
-// If the pointer is non-nil, the existing value will reused.
+// To decode into a pointer, Decode will decode into the value pointed
+// to. If the pointer is nil, a new value of the pointer's element
+// type is allocated. If the pointer is non-nil, the existing value
+// will reused.
 //
 // To decode into a struct, Decode expects the input to be an RLP
 // list. The decoded elements of the list are assigned to each public
-// field in the order given by the struct's definition. If the input
-// list has too few elements, no error is returned and the remaining
-// fields will have the zero value.
-// Recursive struct types are supported.
+// field in the order given by the struct's definition. The input list
+// must contain an element for each decoded field. Decode returns an
+// error if there are too few or too many elements.
+//
+// The decoding of struct fields honours one particular struct tag,
+// "nil". This tag applies to pointer-typed fields and changes the
+// decoding rules for the field such that input values of size zero
+// decode as a nil pointer. This tag can be useful when decoding recursive
+// types.
+//
+//     type StructWithEmptyOK struct {
+//         Foo *[20]byte `rlp:"nil"`
+//     }
 //
 // To decode into a slice, the input must be a list and the resulting
 // slice will contain the input elements in order.
@@ -54,7 +63,7 @@ type Decoder interface {
 // can also be an RLP string.
 //
 // To decode into a Go string, the input must be an RLP string. The
-// bytes are taken as-is and will not necessarily be valid UTF-8.
+// input bytes are taken as-is and will not necessarily be valid UTF-8.
 //
 // To decode into an unsigned integer type, the input must also be an RLP
 // string. The bytes are interpreted as a big endian representation of
@@ -65,8 +74,8 @@ type Decoder interface {
 // To decode into an interface value, Decode stores one of these
 // in the value:
 //
-//	[]interface{}, for RLP lists
-//	[]byte, for RLP strings
+//	  []interface{}, for RLP lists
+//	  []byte, for RLP strings
 //
 // Non-empty interface types are not supported, nor are booleans,
 // signed integers, floating point numbers, maps, channels and
@@ -136,7 +145,7 @@ var (
 	bigInt           = reflect.TypeOf(big.Int{})
 )
 
-func makeDecoder(typ reflect.Type) (dec decoder, err error) {
+func makeDecoder(typ reflect.Type, tags tags) (dec decoder, err error) {
 	kind := typ.Kind()
 	switch {
 	case typ.Implements(decoderInterface):
@@ -156,6 +165,9 @@ func makeDecoder(typ reflect.Type) (dec decoder, err error) {
 	case kind == reflect.Struct:
 		return makeStructDecoder(typ)
 	case kind == reflect.Ptr:
+		if tags.nilOK {
+			return makeOptionalPtrDecoder(typ)
+		}
 		return makePtrDecoder(typ)
 	case kind == reflect.Interface:
 		return decodeInterface, nil
@@ -214,7 +226,7 @@ func makeListDecoder(typ reflect.Type) (decoder, error) {
 			return decodeByteSlice, nil
 		}
 	}
-	etypeinfo, err := cachedTypeInfo1(etype)
+	etypeinfo, err := cachedTypeInfo1(etype, tags{})
 	if err != nil {
 		return nil, err
 	}
@@ -352,11 +364,6 @@ func zero(val reflect.Value, start int) {
 	}
 }
 
-type field struct {
-	index int
-	info  *typeinfo
-}
-
 func makeStructDecoder(typ reflect.Type) (decoder, error) {
 	fields, err := structFields(typ)
 	if err != nil {
@@ -369,8 +376,7 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) {
 		for _, f := range fields {
 			err = f.info.decoder(s, val.Field(f.index))
 			if err == EOL {
-				// too few elements. leave the rest at their zero value.
-				break
+				return &decodeError{msg: "too few elements", typ: typ}
 			} else if err != nil {
 				return addErrorContext(err, "."+typ.Field(f.index).Name)
 			}
@@ -380,9 +386,35 @@ func makeStructDecoder(typ reflect.Type) (decoder, error) {
 	return dec, nil
 }
 
+// makePtrDecoder creates a decoder that decodes into
+// the pointer's element type.
 func makePtrDecoder(typ reflect.Type) (decoder, error) {
 	etype := typ.Elem()
-	etypeinfo, err := cachedTypeInfo1(etype)
+	etypeinfo, err := cachedTypeInfo1(etype, tags{})
+	if err != nil {
+		return nil, err
+	}
+	dec := func(s *Stream, val reflect.Value) (err error) {
+		newval := val
+		if val.IsNil() {
+			newval = reflect.New(etype)
+		}
+		if err = etypeinfo.decoder(s, newval.Elem()); err == nil {
+			val.Set(newval)
+		}
+		return err
+	}
+	return dec, nil
+}
+
+// makeOptionalPtrDecoder creates a decoder that decodes empty values
+// as nil. Non-empty values are decoded into a value of the element type,
+// just like makePtrDecoder does.
+//
+// This decoder is used for pointer-typed struct fields with struct tag "nil".
+func makeOptionalPtrDecoder(typ reflect.Type) (decoder, error) {
+	etype := typ.Elem()
+	etypeinfo, err := cachedTypeInfo1(etype, tags{})
 	if err != nil {
 		return nil, err
 	}
@@ -706,7 +738,7 @@ func (s *Stream) Decode(val interface{}) error {
 	if rval.IsNil() {
 		return errDecodeIntoNil
 	}
-	info, err := cachedTypeInfo(rtyp.Elem())
+	info, err := cachedTypeInfo(rtyp.Elem(), tags{})
 	if err != nil {
 		return err
 	}
diff --git a/rlp/decode_test.go b/rlp/decode_test.go
index 7e2ea2041..fd52bd1be 100644
--- a/rlp/decode_test.go
+++ b/rlp/decode_test.go
@@ -280,7 +280,7 @@ type simplestruct struct {
 
 type recstruct struct {
 	I     uint
-	Child *recstruct
+	Child *recstruct `rlp:"nil"`
 }
 
 var (
@@ -390,15 +390,33 @@ var decodeTests = []decodeTest{
 	{input: "8105", ptr: new(big.Int), error: "rlp: non-canonical size information for *big.Int"},
 
 	// structs
-	{input: "C0", ptr: new(simplestruct), value: simplestruct{0, ""}},
-	{input: "C105", ptr: new(simplestruct), value: simplestruct{5, ""}},
-	{input: "C50583343434", ptr: new(simplestruct), value: simplestruct{5, "444"}},
 	{
-		input: "C501C302C103",
+		input: "C50583343434",
+		ptr:   new(simplestruct),
+		value: simplestruct{5, "444"},
+	},
+	{
+		input: "C601C402C203C0",
 		ptr:   new(recstruct),
 		value: recstruct{1, &recstruct{2, &recstruct{3, nil}}},
 	},
 
+	// struct errors
+	{
+		input: "C0",
+		ptr:   new(simplestruct),
+		error: "rlp: too few elements for rlp.simplestruct",
+	},
+	{
+		input: "C105",
+		ptr:   new(simplestruct),
+		error: "rlp: too few elements for rlp.simplestruct",
+	},
+	{
+		input: "C7C50583343434C0",
+		ptr:   new([]*simplestruct),
+		error: "rlp: too few elements for rlp.simplestruct, decoding into ([]*rlp.simplestruct)[1]",
+	},
 	{
 		input: "83222222",
 		ptr:   new(simplestruct),
@@ -417,19 +435,15 @@ var decodeTests = []decodeTest{
 
 	// pointers
 	{input: "00", ptr: new(*[]byte), value: &[]byte{0}},
-	{input: "80", ptr: new(*uint), value: (*uint)(nil)},
-	{input: "C0", ptr: new(*uint), value: (*uint)(nil)},
+	{input: "80", ptr: new(*uint), value: uintp(0)},
+	{input: "C0", ptr: new(*uint), error: "rlp: expected input string or byte for uint"},
 	{input: "07", ptr: new(*uint), value: uintp(7)},
 	{input: "8158", ptr: new(*uint), value: uintp(0x58)},
 	{input: "C109", ptr: new(*[]uint), value: &[]uint{9}},
 	{input: "C58403030303", ptr: new(*[][]byte), value: &[][]byte{{3, 3, 3, 3}}},
 
 	// check that input position is advanced also for empty values.
-	{input: "C3808005", ptr: new([]*uint), value: []*uint{nil, nil, uintp(5)}},
-
-	// pointer should be reset to nil
-	{input: "05", ptr: sharedPtr, value: uintp(5)},
-	{input: "80", ptr: sharedPtr, value: (*uint)(nil)},
+	{input: "C3808005", ptr: new([]*uint), value: []*uint{uintp(0), uintp(0), uintp(5)}},
 
 	// interface{}
 	{input: "00", ptr: new(interface{}), value: []byte{0}},
@@ -599,6 +613,33 @@ func ExampleDecode() {
 	// Decoded value: rlp.example{A:0xa, B:0x14, private:0x0, String:"foobar"}
 }
 
+func ExampleDecode_structTagNil() {
+	// In this example, we'll use the "nil" struct tag to change
+	// how a pointer-typed field is decoded. The input contains an RLP
+	// list of one element, an empty string.
+	input := []byte{0xC1, 0x80}
+
+	// This type uses the normal rules.
+	// The empty input string is decoded as a pointer to an empty Go string.
+	var normalRules struct {
+		String *string
+	}
+	Decode(bytes.NewReader(input), &normalRules)
+	fmt.Printf("normal: String = %q\n", *normalRules.String)
+
+	// This type uses the struct tag.
+	// The empty input string is decoded as a nil pointer.
+	var withEmptyOK struct {
+		String *string `rlp:"nil"`
+	}
+	Decode(bytes.NewReader(input), &withEmptyOK)
+	fmt.Printf("with nil tag: String = %v\n", withEmptyOK.String)
+
+	// Output:
+	// normal: String = ""
+	// with nil tag: String = <nil>
+}
+
 func ExampleStream() {
 	input, _ := hex.DecodeString("C90A1486666F6F626172")
 	s := NewStream(bytes.NewReader(input), 0)
diff --git a/rlp/encode.go b/rlp/encode.go
index 6cf6776d6..10ff0ae79 100644
--- a/rlp/encode.go
+++ b/rlp/encode.go
@@ -194,7 +194,7 @@ func (w *encbuf) Write(b []byte) (int, error) {
 
 func (w *encbuf) encode(val interface{}) error {
 	rval := reflect.ValueOf(val)
-	ti, err := cachedTypeInfo(rval.Type())
+	ti, err := cachedTypeInfo(rval.Type(), tags{})
 	if err != nil {
 		return err
 	}
@@ -485,7 +485,7 @@ func writeInterface(val reflect.Value, w *encbuf) error {
 		return nil
 	}
 	eval := val.Elem()
-	ti, err := cachedTypeInfo(eval.Type())
+	ti, err := cachedTypeInfo(eval.Type(), tags{})
 	if err != nil {
 		return err
 	}
@@ -493,7 +493,7 @@ func writeInterface(val reflect.Value, w *encbuf) error {
 }
 
 func makeSliceWriter(typ reflect.Type) (writer, error) {
-	etypeinfo, err := cachedTypeInfo1(typ.Elem())
+	etypeinfo, err := cachedTypeInfo1(typ.Elem(), tags{})
 	if err != nil {
 		return nil, err
 	}
@@ -530,7 +530,7 @@ func makeStructWriter(typ reflect.Type) (writer, error) {
 }
 
 func makePtrWriter(typ reflect.Type) (writer, error) {
-	etypeinfo, err := cachedTypeInfo1(typ.Elem())
+	etypeinfo, err := cachedTypeInfo1(typ.Elem(), tags{})
 	if err != nil {
 		return nil, err
 	}
diff --git a/rlp/typecache.go b/rlp/typecache.go
index 398f25d90..d512012e9 100644
--- a/rlp/typecache.go
+++ b/rlp/typecache.go
@@ -7,7 +7,7 @@ import (
 
 var (
 	typeCacheMutex sync.RWMutex
-	typeCache      = make(map[reflect.Type]*typeinfo)
+	typeCache      = make(map[typekey]*typeinfo)
 )
 
 type typeinfo struct {
@@ -15,13 +15,25 @@ type typeinfo struct {
 	writer
 }
 
+// represents struct tags
+type tags struct {
+	nilOK bool
+}
+
+type typekey struct {
+	reflect.Type
+	// the key must include the struct tags because they
+	// might generate a different decoder.
+	tags
+}
+
 type decoder func(*Stream, reflect.Value) error
 
 type writer func(reflect.Value, *encbuf) error
 
-func cachedTypeInfo(typ reflect.Type) (*typeinfo, error) {
+func cachedTypeInfo(typ reflect.Type, tags tags) (*typeinfo, error) {
 	typeCacheMutex.RLock()
-	info := typeCache[typ]
+	info := typeCache[typekey{typ, tags}]
 	typeCacheMutex.RUnlock()
 	if info != nil {
 		return info, nil
@@ -29,11 +41,12 @@ func cachedTypeInfo(typ reflect.Type) (*typeinfo, error) {
 	// not in the cache, need to generate info for this type.
 	typeCacheMutex.Lock()
 	defer typeCacheMutex.Unlock()
-	return cachedTypeInfo1(typ)
+	return cachedTypeInfo1(typ, tags)
 }
 
-func cachedTypeInfo1(typ reflect.Type) (*typeinfo, error) {
-	info := typeCache[typ]
+func cachedTypeInfo1(typ reflect.Type, tags tags) (*typeinfo, error) {
+	key := typekey{typ, tags}
+	info := typeCache[key]
 	if info != nil {
 		// another goroutine got the write lock first
 		return info, nil
@@ -41,21 +54,27 @@ func cachedTypeInfo1(typ reflect.Type) (*typeinfo, error) {
 	// put a dummmy value into the cache before generating.
 	// if the generator tries to lookup itself, it will get
 	// the dummy value and won't call itself recursively.
-	typeCache[typ] = new(typeinfo)
-	info, err := genTypeInfo(typ)
+	typeCache[key] = new(typeinfo)
+	info, err := genTypeInfo(typ, tags)
 	if err != nil {
 		// remove the dummy value if the generator fails
-		delete(typeCache, typ)
+		delete(typeCache, key)
 		return nil, err
 	}
-	*typeCache[typ] = *info
-	return typeCache[typ], err
+	*typeCache[key] = *info
+	return typeCache[key], err
+}
+
+type field struct {
+	index int
+	info  *typeinfo
 }
 
 func structFields(typ reflect.Type) (fields []field, err error) {
 	for i := 0; i < typ.NumField(); i++ {
 		if f := typ.Field(i); f.PkgPath == "" { // exported
-			info, err := cachedTypeInfo1(f.Type)
+			tags := parseStructTag(f.Tag.Get("rlp"))
+			info, err := cachedTypeInfo1(f.Type, tags)
 			if err != nil {
 				return nil, err
 			}
@@ -65,9 +84,13 @@ func structFields(typ reflect.Type) (fields []field, err error) {
 	return fields, nil
 }
 
-func genTypeInfo(typ reflect.Type) (info *typeinfo, err error) {
+func parseStructTag(tag string) tags {
+	return tags{nilOK: tag == "nil"}
+}
+
+func genTypeInfo(typ reflect.Type, tags tags) (info *typeinfo, err error) {
 	info = new(typeinfo)
-	if info.decoder, err = makeDecoder(typ); err != nil {
+	if info.decoder, err = makeDecoder(typ, tags); err != nil {
 		return nil, err
 	}
 	if info.writer, err = makeWriter(typ); err != nil {
-- 
GitLab