From b39ca873380498fd7ac2bb1d9fa221404cf90da8 Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Wed, 29 May 2019 23:21:55 -0400
Subject: [PATCH] Fix bugs and improve docs

---
 README.md            |   6 +--
 accept.go            |   9 +++-
 dial.go              |  43 ++++++++++++++++--
 example_echo_test.go |   5 ++-
 go.mod               |   2 +-
 go.sum               |   2 +
 statuscode.go        |   2 +-
 websocket.go         | 101 +++++++++++++++++++++++--------------------
 websocket_test.go    |  14 +++---
 wsjson/wsjson.go     |  14 +++++-
 10 files changed, 128 insertions(+), 70 deletions(-)

diff --git a/README.md b/README.md
index 3f42742..1b9af61 100644
--- a/README.md
+++ b/README.md
@@ -22,7 +22,7 @@ go get nhooyr.io/websocket@v0.2.0
 - Zero dependencies outside of the stdlib for the core library
 - JSON and ProtoBuf helpers in the wsjson and wspb subpackages
 - High performance
-- Concurrent writes
+- Concurrent reads and writes out of the box
 
 ## Roadmap
 
@@ -122,8 +122,8 @@ also uses net/http's Client and ResponseWriter directly for WebSocket handshakes
 gorilla/websocket writes its handshakes to the underlying net.Conn which means
 it has to reinvent hooks for TLS and proxies and prevents support of HTTP/2.
 
-Some more advantages of nhooyr/websocket are that it supports concurrent writes and makes it
-very easy to close the connection with a status code and reason.
+Some more advantages of nhooyr/websocket are that it supports concurrent reads,
+writes and makes it very easy to close the connection with a status code and reason.
 
 In terms of performance, there is no significant difference between the two. Will update 
 with benchmarks soon ([#75](https://github.com/nhooyr/websocket/issues/75)).
diff --git a/accept.go b/accept.go
index 207ecc7..17016d2 100644
--- a/accept.go
+++ b/accept.go
@@ -1,8 +1,10 @@
 package websocket
 
 import (
+	"bytes"
 	"crypto/sha1"
 	"encoding/base64"
+	"io"
 	"net/http"
 	"net/textproto"
 	"net/url"
@@ -78,6 +80,9 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error {
 //
 // Accept will reject the handshake if the Origin domain is not the same as the Host unless
 // the InsecureSkipVerify option is set.
+//
+// The returned connection will be bound by r.Context(). Use c.Context() to change
+// the bounding context.
 func Accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, error) {
 	c, err := accept(w, r, opts)
 	if err != nil {
@@ -126,6 +131,9 @@ func accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn,
 		return nil, err
 	}
 
+	b, _ := brw.Reader.Peek(brw.Reader.Buffered())
+	brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn))
+
 	c := &Conn{
 		subprotocol: w.Header().Get("Sec-WebSocket-Protocol"),
 		br:          brw.Reader,
@@ -133,7 +141,6 @@ func accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn,
 		closer:      netConn,
 	}
 	c.init()
-	// TODO document.
 	c.Context(r.Context())
 
 	return c, nil
diff --git a/dial.go b/dial.go
index 3c7e71d..f1ad725 100644
--- a/dial.go
+++ b/dial.go
@@ -5,13 +5,13 @@ import (
 	"bytes"
 	"context"
 	"encoding/base64"
+	"golang.org/x/xerrors"
 	"io"
 	"io/ioutil"
 	"net/http"
 	"net/url"
 	"strings"
-
-	"golang.org/x/xerrors"
+	"sync"
 )
 
 // DialOptions represents the options available to pass to Dial.
@@ -112,8 +112,8 @@ func dial(ctx context.Context, u string, opts DialOptions) (_ *Conn, _ *http.Res
 
 	c := &Conn{
 		subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"),
-		br:          bufio.NewReader(rwc),
-		bw:          bufio.NewWriter(rwc),
+		br:          getBufioReader(rwc),
+		bw:          getBufioWriter(rwc),
 		closer:      rwc,
 		client:      true,
 	}
@@ -140,3 +140,38 @@ func verifyServerResponse(resp *http.Response) error {
 
 	return nil
 }
+
+// The below pools can only be used by the client because http.Hijacker will always
+// have a bufio.Reader/Writer for us so it doesn't make sense to use a pool on top.
+
+var bufioReaderPool = sync.Pool{
+	New: func() interface{} {
+		return bufio.NewReader(nil)
+	},
+}
+
+func getBufioReader(r io.Reader) *bufio.Reader {
+	br := bufioReaderPool.Get().(*bufio.Reader)
+	br.Reset(r)
+	return br
+}
+
+func returnBufioReader(br *bufio.Reader) {
+	bufioReaderPool.Put(br)
+}
+
+var bufioWriterPool = sync.Pool{
+	New: func() interface{} {
+		return bufio.NewWriter(nil)
+	},
+}
+
+func getBufioWriter(w io.Writer) *bufio.Writer {
+	bw := bufioWriterPool.Get().(*bufio.Writer)
+	bw.Reset(w)
+	return bw
+}
+
+func returnBufioWriter(bw *bufio.Writer) {
+	bufioWriterPool.Put(bw)
+}
diff --git a/example_echo_test.go b/example_echo_test.go
index ab0e8e7..a86d5b8 100644
--- a/example_echo_test.go
+++ b/example_echo_test.go
@@ -51,6 +51,7 @@ func Example_echo() {
 
 	// Now we dial the server, send the messages and echo the responses.
 	err = client("ws://" + l.Addr().String())
+	time.Sleep(time.Second)
 	if err != nil {
 		log.Fatalf("client failed: %v", err)
 	}
@@ -66,6 +67,8 @@ func Example_echo() {
 // It ensures the client speaks the echo subprotocol and
 // only allows one message every 100ms with a 10 message burst.
 func echoServer(w http.ResponseWriter, r *http.Request) error {
+	log.Printf("serving %v", r.RemoteAddr)
+
 	c, err := websocket.Accept(w, r, websocket.AcceptOptions{
 		Subprotocols: []string{"echo"},
 	})
@@ -83,7 +86,7 @@ func echoServer(w http.ResponseWriter, r *http.Request) error {
 	for {
 		err = echo(r.Context(), c, l)
 		if err != nil {
-			return xerrors.Errorf("failed to echo: %w", err)
+			return xerrors.Errorf("failed to echo with %v: %w", r.RemoteAddr, err)
 		}
 	}
 }
diff --git a/go.mod b/go.mod
index f747eec..cc9a865 100644
--- a/go.mod
+++ b/go.mod
@@ -12,6 +12,6 @@ require (
 	golang.org/x/text v0.3.2 // indirect
 	golang.org/x/time v0.0.0-20190308202827-9d24e82272b4
 	golang.org/x/tools v0.0.0-20190429184909-35c670923e21
-	golang.org/x/xerrors v0.0.0-20190315151331-d61658bd2e18
+	golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522
 	mvdan.cc/sh v2.6.4+incompatible
 )
diff --git a/go.sum b/go.sum
index 63aaa2a..90c9346 100644
--- a/go.sum
+++ b/go.sum
@@ -30,5 +30,7 @@ golang.org/x/tools v0.0.0-20190429184909-35c670923e21 h1:Kjcw+D2LTzLmxOHrMK9uvYP
 golang.org/x/tools v0.0.0-20190429184909-35c670923e21/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
 golang.org/x/xerrors v0.0.0-20190315151331-d61658bd2e18 h1:1AGvnywFL1aB5KLRxyLseWJI6aSYPo3oF7HSpXdWQdU=
 golang.org/x/xerrors v0.0.0-20190315151331-d61658bd2e18/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
+golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 h1:bhOzK9QyoD0ogCnFro1m2mz41+Ib0oOhfJnBp5MR4K4=
+golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
 mvdan.cc/sh v2.6.4+incompatible h1:eD6tDeh0pw+/TOTI1BBEryZ02rD2nMcFsgcvde7jffM=
 mvdan.cc/sh v2.6.4+incompatible/go.mod h1:IeeQbZq+x2SUGBensq/jge5lLQbS3XT2ktyp3wrt4x8=
diff --git a/statuscode.go b/statuscode.go
index d422374..c7b2036 100644
--- a/statuscode.go
+++ b/statuscode.go
@@ -49,7 +49,7 @@ type CloseError struct {
 }
 
 func (ce CloseError) Error() string {
-	return fmt.Sprintf("websocket closed with status = %v and reason = %q", ce.Code, ce.Reason)
+	return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason)
 }
 
 func parseClosePayload(p []byte) (CloseError, error) {
diff --git a/websocket.go b/websocket.go
index 6e35281..0b77966 100644
--- a/websocket.go
+++ b/websocket.go
@@ -14,7 +14,8 @@ import (
 )
 
 // Conn represents a WebSocket connection.
-// All methods except Reader can be used concurrently.
+// All methods may be called concurrently.
+//
 // Please be sure to call Close on the connection when you
 // are finished with it to release resources.
 type Conn struct {
@@ -31,8 +32,10 @@ type Conn struct {
 	writeDataLock  chan struct{}
 	writeFrameLock chan struct{}
 
-	readData chan header
-	readDone chan struct{}
+	readDataLock chan struct{}
+	readData     chan header
+	readDone     chan struct{}
+	readLoopDone chan struct{}
 
 	setReadTimeout  chan context.Context
 	setWriteTimeout chan context.Context
@@ -44,7 +47,7 @@ type Conn struct {
 // when the connection is closed.
 // If the parent context is cancelled, the connection will be closed.
 //
-// This is an experimental API meaning it may be remove in the future.
+// This is an experimental API that may be remove in the future.
 // Please let me know how you feel about it.
 func (c *Conn) Context(parent context.Context) context.Context {
 	select {
@@ -77,6 +80,18 @@ func (c *Conn) close(err error) {
 		c.closeErr = xerrors.Errorf("websocket closed: %w", cerr)
 
 		close(c.closed)
+
+		// See comment in dial.go
+		if c.client {
+			go func() {
+				<-c.readLoopDone
+				c.readDataLock <- struct{}{}
+				c.writeFrameLock <- struct{}{}
+
+				returnBufioReader(c.br)
+				returnBufioWriter(c.bw)
+			}()
+		}
 	})
 }
 
@@ -94,6 +109,8 @@ func (c *Conn) init() {
 
 	c.readData = make(chan header)
 	c.readDone = make(chan struct{})
+	c.readDataLock = make(chan struct{}, 1)
+	c.readLoopDone = make(chan struct{})
 
 	c.setReadTimeout = make(chan context.Context)
 	c.setWriteTimeout = make(chan context.Context)
@@ -174,8 +191,8 @@ func (c *Conn) timeoutLoop() {
 		select {
 		case <-c.closed:
 			return
-		case readCtx = <-c.setWriteTimeout:
-		case writeCtx = <-c.setReadTimeout:
+		case writeCtx = <-c.setWriteTimeout:
+		case readCtx = <-c.setReadTimeout:
 		case <-readCtx.Done():
 			c.close(xerrors.Errorf("data read timed out: %w", readCtx.Err()))
 		case <-writeCtx.Done():
@@ -276,6 +293,8 @@ func (c *Conn) readTillData() (header, error) {
 }
 
 func (c *Conn) readLoop() {
+	defer close(c.readLoopDone)
+
 	for {
 		h, err := c.readTillData()
 		if err != nil {
@@ -487,8 +506,7 @@ func (w *messageWriter) close() error {
 //
 // Your application must keep reading messages for the Conn to automatically respond to ping
 // and close frames and not become stuck waiting for a data message to be read.
-// Please ensure to read the full message from io.Reader. If you do not read till
-// io.EOF, the connection will break unless the next read would have yielded io.EOF.
+// Please ensure to read the full message from io.Reader.
 //
 // You can only read a single message at a time so do not call this method
 // concurrently.
@@ -500,30 +518,10 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
 	return typ, r, nil
 }
 
-func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
-	// if !atomic.CompareAndSwapInt64(&c.activeReader, 0, 1) {
-	// 	// If the next read yields io.EOF we are good to go.
-	// 	r := messageReader{
-	// 		ctx: ctx,
-	// 		c:   c,
-	// 	}
-	// 	_, err := r.Read(nil)
-	// 	if err == nil {
-	// 		return 0, nil, xerrors.New("previous message not fully read")
-	// 	}
-	// 	if !xerrors.Is(err, io.EOF) {
-	// 		return 0, nil, xerrors.Errorf("failed to check if last message at io.EOF: %w", err)
-	// 	}
-	//
-	// 	atomic.StoreInt64(&c.activeReader, 1)
-	// }
-
-	select {
-	case <-c.closed:
-		return 0, nil, c.closeErr
-	case <-ctx.Done():
-		return 0, nil, ctx.Err()
-	case c.setReadTimeout <- ctx:
+func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) {
+	err = c.acquireLock(ctx, c.readDataLock)
+	if err != nil {
+		return 0, nil, err
 	}
 
 	select {
@@ -533,25 +531,24 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
 		return 0, nil, ctx.Err()
 	case h := <-c.readData:
 		if h.opcode == opContinuation {
-			if h.fin && h.payloadLength == 0 {
-				select {
-				case <-c.closed:
-					return 0, nil, c.closeErr
-				case c.readDone <- struct{}{}:
-					return c.reader(ctx)
-				}
+			ce := CloseError{
+				Code:   StatusProtocolError,
+				Reason: "continuation frame not after data or text frame",
 			}
-			return 0, nil, xerrors.Errorf("previous reader was not read to EOF")
+			c.Close(ce.Code, ce.Reason)
+			return 0, nil, ce
 		}
 		return MessageType(h.opcode), &messageReader{
-			h: &h,
-			c: c,
+			ctx: ctx,
+			h:   &h,
+			c:   c,
 		}, nil
 	}
 }
 
 // messageReader enables reading a data frame from the WebSocket connection.
 type messageReader struct {
+	ctx     context.Context
 	maskPos int
 	h       *header
 	c       *Conn
@@ -598,8 +595,20 @@ func (r *messageReader) read(p []byte) (int, error) {
 		p = p[:r.h.payloadLength]
 	}
 
+	select {
+	case <-r.c.closed:
+		return 0, r.c.closeErr
+	case r.c.setReadTimeout <- r.ctx:
+	}
+
 	n, err := io.ReadFull(r.c.br, p)
 
+	select {
+	case <-r.c.closed:
+		return 0, r.c.closeErr
+	case r.c.setReadTimeout <- context.Background():
+	}
+
 	r.h.payloadLength -= int64(n)
 	if r.h.masked {
 		r.maskPos = fastXOR(r.h.maskKey, r.maskPos, p)
@@ -618,12 +627,8 @@ func (r *messageReader) read(p []byte) (int, error) {
 		}
 		if r.h.fin {
 			r.eofed = true
-			select {
-			case <-r.c.closed:
-				return n, r.c.closeErr
-			case r.c.setReadTimeout <- context.Background():
-				return n, io.EOF
-			}
+			r.c.releaseLock(r.c.readDataLock)
+			return n, io.EOF
 		}
 		r.maskPos = 0
 		r.h = nil
diff --git a/websocket_test.go b/websocket_test.go
index 8d18c73..0ac0557 100644
--- a/websocket_test.go
+++ b/websocket_test.go
@@ -293,10 +293,6 @@ func TestHandshake(t *testing.T) {
 				if err != nil {
 					return err
 				}
-				err = write()
-				if err != nil {
-					return err
-				}
 
 				c.Close(websocket.StatusNormalClosure, "")
 				return nil
@@ -329,11 +325,6 @@ func TestHandshake(t *testing.T) {
 				if err != nil {
 					return err
 				}
-				// Read twice to ensure the un EOFed previous reader works correctly.
-				err = read()
-				if err != nil {
-					return err
-				}
 
 				c.Close(websocket.StatusNormalClosure, "")
 				return nil
@@ -766,6 +757,11 @@ func benchConn(b *testing.B, echo, stream bool, size int) {
 			if err != nil {
 				b.Fatal(err)
 			}
+
+			_, err = r.Read(nil)
+			if !xerrors.Is(err, io.EOF) {
+				b.Fatalf("more data in reader than needed")
+			}
 		}
 	}
 	b.StopTimer()
diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go
index 9dd61bd..853369e 100644
--- a/wsjson/wsjson.go
+++ b/wsjson/wsjson.go
@@ -4,9 +4,8 @@ package wsjson
 import (
 	"context"
 	"encoding/json"
-	"io"
-
 	"golang.org/x/xerrors"
+	"io"
 
 	"nhooyr.io/websocket"
 )
@@ -41,6 +40,17 @@ func read(ctx context.Context, c *websocket.Conn, v interface{}) error {
 		return xerrors.Errorf("failed to decode json: %w", err)
 	}
 
+	// Have to ensure we read till EOF.
+	// Unfortunate but necessary evil for now. Can improve later.
+	// The code to do this automatically gets complicated fast because
+	// we support concurrent reading.
+	// So the Reader has to synchronize with Read somehow.
+	// Maybe its best to bring back the old readLoop?
+	_, err = r.Read(nil)
+	if !xerrors.Is(err, io.EOF) {
+		return xerrors.Errorf("more data than needed in reader")
+	}
+
 	return nil
 }
 
-- 
GitLab