From 6e7e5d5fd56a9a6f73e51239ed6648d76db9650d Mon Sep 17 00:00:00 2001
From: Felix Lange <fjl@twurst.com>
Date: Wed, 4 Mar 2015 13:12:50 +0100
Subject: [PATCH] eth, whisper: fix msg.Payload reads

---
 eth/protocol.go | 36 +++++++++++++++++++++---------------
 whisper/peer.go | 27 ++++++++++-----------------
 2 files changed, 31 insertions(+), 32 deletions(-)

diff --git a/eth/protocol.go b/eth/protocol.go
index 663af43fe..b86f33614 100644
--- a/eth/protocol.go
+++ b/eth/protocol.go
@@ -3,7 +3,6 @@ package eth
 import (
 	"bytes"
 	"fmt"
-	"io"
 	"math/big"
 
 	"github.com/ethereum/go-ethereum/core/types"
@@ -188,33 +187,37 @@ func (self *ethProtocol) handle() error {
 
 	case BlockHashesMsg:
 		msgStream := rlp.NewStream(msg.Payload)
-		var err error
-		var i int
+		if _, err := msgStream.List(); err != nil {
+			return err
+		}
 
+		var i int
 		iter := func() (hash []byte, ok bool) {
-			hash, err = msgStream.Bytes()
-			if err == nil {
-				i++
-				ok = true
-			} else {
-				if err != io.EOF {
-					self.protoError(ErrDecode, "msg %v: after %v hashes : %v", msg, i, err)
-				}
+			hash, err := msgStream.Bytes()
+			if err == rlp.EOL {
+				return nil, false
+			} else if err != nil {
+				self.protoError(ErrDecode, "msg %v: after %v hashes : %v", msg, i, err)
+				return nil, false
 			}
-			return
+			i++
+			return hash, true
 		}
-
 		self.blockPool.AddBlockHashes(iter, self.id)
 
 	case GetBlocksMsg:
 		msgStream := rlp.NewStream(msg.Payload)
+		if _, err := msgStream.List(); err != nil {
+			return err
+		}
+
 		var blocks []interface{}
 		var i int
 		for {
 			i++
 			var hash []byte
 			if err := msgStream.Decode(&hash); err != nil {
-				if err == io.EOF {
+				if err == rlp.EOL {
 					break
 				} else {
 					return self.protoError(ErrDecode, "msg %v: %v", msg, err)
@@ -232,10 +235,13 @@ func (self *ethProtocol) handle() error {
 
 	case BlocksMsg:
 		msgStream := rlp.NewStream(msg.Payload)
+		if _, err := msgStream.List(); err != nil {
+			return err
+		}
 		for {
 			var block types.Block
 			if err := msgStream.Decode(&block); err != nil {
-				if err == io.EOF {
+				if err == rlp.EOL {
 					break
 				} else {
 					return self.protoError(ErrDecode, "msg %v: %v", msg, err)
diff --git a/whisper/peer.go b/whisper/peer.go
index 332ddd22a..66cfec88c 100644
--- a/whisper/peer.go
+++ b/whisper/peer.go
@@ -2,10 +2,10 @@ package whisper
 
 import (
 	"fmt"
-	"io/ioutil"
 	"time"
 
 	"github.com/ethereum/go-ethereum/p2p"
+	"github.com/ethereum/go-ethereum/rlp"
 	"gopkg.in/fatih/set.v0"
 )
 
@@ -77,8 +77,7 @@ func (self *peer) broadcast(envelopes []*Envelope) error {
 	}
 
 	if i > 0 {
-		msg := p2p.NewMsg(envelopesMsg, envs[:i]...)
-		if err := self.ws.WriteMsg(msg); err != nil {
+		if err := p2p.EncodeMsg(self.ws, envelopesMsg, envs[:i]...); err != nil {
 			return err
 		}
 		self.peer.DebugDetailln("broadcasted", i, "message(s)")
@@ -93,34 +92,28 @@ func (self *peer) addKnown(envelope *Envelope) {
 
 func (self *peer) handleStatus() error {
 	ws := self.ws
-
 	if err := ws.WriteMsg(self.statusMsg()); err != nil {
 		return err
 	}
-
 	msg, err := ws.ReadMsg()
 	if err != nil {
 		return err
 	}
-
 	if msg.Code != statusMsg {
 		return fmt.Errorf("peer send %x before status msg", msg.Code)
 	}
-
-	data, err := ioutil.ReadAll(msg.Payload)
-	if err != nil {
-		return err
+	s := rlp.NewStream(msg.Payload)
+	if _, err := s.List(); err != nil {
+		return fmt.Errorf("bad status message: %v", err)
 	}
-
-	if len(data) == 0 {
-		return fmt.Errorf("malformed status. data len = 0")
+	pv, err := s.Uint()
+	if err != nil {
+		return fmt.Errorf("bad status message: %v", err)
 	}
-
-	if pv := data[0]; pv != protocolVersion {
+	if pv != protocolVersion {
 		return fmt.Errorf("protocol version mismatch %d != %d", pv, protocolVersion)
 	}
-
-	return nil
+	return msg.Discard() // ignore anything after protocol version
 }
 
 func (self *peer) statusMsg() p2p.Msg {
-- 
GitLab