From 63f27e246b5bad3c2d0cc674be91fa03bae40d36 Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Tue, 20 Aug 2019 17:46:00 -0400
Subject: [PATCH] Reduce Reader/Writer allocations

Closes #116
---
 websocket.go | 43 +++++++++++++++++++++++++++----------------
 1 file changed, 27 insertions(+), 16 deletions(-)

diff --git a/websocket.go b/websocket.go
index ee61f54..5942b68 100644
--- a/websocket.go
+++ b/websocket.go
@@ -56,6 +56,8 @@ type Conn struct {
 	// read limit for a message in bytes.
 	msgReadLimit int64
 
+	// Used to ensure a previous writer is not used after being closed.
+	activeWriter *messageWriter
 	// messageWriter state.
 	writeMsgOpcode opcode
 	writeMsgCtx    context.Context
@@ -63,7 +65,7 @@ type Conn struct {
 
 	// Used to ensure the previous reader is read till EOF before allowing
 	// a new one.
-	previousReader *messageReader
+	activeReader *messageReader
 	// readFrameLock is acquired to read from bw.
 	readFrameLock     chan struct{}
 	readClosed        int64
@@ -358,7 +360,7 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
 }
 
 func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
-	if c.previousReader != nil && !c.readFrameEOF {
+	if c.activeReader != nil && !c.readFrameEOF {
 		// The only way we know for sure the previous reader is not yet complete is
 		// if there is an active frame not yet fully read.
 		// Otherwise, a user may have read the last byte but not the EOF if the EOF
@@ -371,7 +373,7 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
 		return 0, nil, err
 	}
 
-	if c.previousReader != nil && !c.previousReader.eof {
+	if c.activeReader != nil && !c.activeReader.eof() {
 		if h.opcode != opContinuation {
 			err := xerrors.Errorf("received new data message without finishing the previous message")
 			c.Close(StatusProtocolError, err.Error())
@@ -382,7 +384,7 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
 			return 0, nil, xerrors.Errorf("previous message not read to completion")
 		}
 
-		c.previousReader.eof = true
+		c.activeReader = nil
 
 		h, err = c.readTillMsg(ctx)
 		if err != nil {
@@ -403,7 +405,7 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
 	r := &messageReader{
 		c: c,
 	}
-	c.previousReader = r
+	c.activeReader = r
 	return MessageType(h.opcode), r, nil
 }
 
@@ -430,8 +432,11 @@ func (c *Conn) CloseRead(ctx context.Context) context.Context {
 
 // messageReader enables reading a data frame from the WebSocket connection.
 type messageReader struct {
-	c   *Conn
-	eof bool
+	c *Conn
+}
+
+func (r *messageReader) eof() bool {
+	return r.c.activeReader != r
 }
 
 // Read reads as many bytes as possible into p.
@@ -449,7 +454,7 @@ func (r *messageReader) Read(p []byte) (int, error) {
 }
 
 func (r *messageReader) read(p []byte) (int, error) {
-	if r.eof {
+	if r.eof() {
 		return 0, xerrors.Errorf("cannot use EOFed reader")
 	}
 
@@ -502,7 +507,7 @@ func (r *messageReader) read(p []byte) (int, error) {
 		r.c.readFrameEOF = true
 
 		if h.fin {
-			r.eof = true
+			r.c.activeReader = nil
 			return n, io.EOF
 		}
 	}
@@ -593,9 +598,11 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err
 	}
 	c.writeMsgCtx = ctx
 	c.writeMsgOpcode = opcode(typ)
-	return &messageWriter{
+	w := &messageWriter{
 		c: c,
-	}, nil
+	}
+	c.activeWriter = w
+	return w, nil
 }
 
 // Write is a convenience method to write a message to the connection.
@@ -622,8 +629,11 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error
 
 // messageWriter enables writing to a WebSocket connection.
 type messageWriter struct {
-	c      *Conn
-	closed bool
+	c *Conn
+}
+
+func (w *messageWriter) closed() bool {
+	return w != w.c.activeWriter
 }
 
 // Write writes the given bytes to the WebSocket connection.
@@ -636,7 +646,7 @@ func (w *messageWriter) Write(p []byte) (int, error) {
 }
 
 func (w *messageWriter) write(p []byte) (int, error) {
-	if w.closed {
+	if w.closed() {
 		return 0, xerrors.Errorf("cannot use closed writer")
 	}
 	n, err := w.c.writeFrame(w.c.writeMsgCtx, false, w.c.writeMsgOpcode, p)
@@ -658,16 +668,17 @@ func (w *messageWriter) Close() error {
 }
 
 func (w *messageWriter) close() error {
-	if w.closed {
+	if w.closed() {
 		return xerrors.Errorf("cannot use closed writer")
 	}
-	w.closed = true
+	w.closed()
 
 	_, err := w.c.writeFrame(w.c.writeMsgCtx, true, w.c.writeMsgOpcode, nil)
 	if err != nil {
 		return xerrors.Errorf("failed to write fin frame: %w", err)
 	}
 
+	w.c.activeWriter = nil
 	w.c.releaseLock(w.c.writeMsgLock)
 	return nil
 }
-- 
GitLab