From 2e4b1105932814e737c4fa3b5048bc9d72d7dea3 Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Mon, 1 Jul 2019 10:36:09 -0400
Subject: [PATCH] Protect against Reader after CloseRead

Closes #101
---
 netconn.go   |  3 ++-
 websocket.go | 11 ++++++++++-
 2 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/netconn.go b/netconn.go
index 0de2f1c..184d5d6 100644
--- a/netconn.go
+++ b/netconn.go
@@ -2,11 +2,12 @@ package websocket
 
 import (
 	"context"
-	"golang.org/x/xerrors"
 	"io"
 	"math"
 	"net"
 	"time"
+
+	"golang.org/x/xerrors"
 )
 
 // NetConn converts a *websocket.Conn into a net.Conn.
diff --git a/websocket.go b/websocket.go
index e7fb0df..f875a14 100644
--- a/websocket.go
+++ b/websocket.go
@@ -12,6 +12,7 @@ import (
 	"runtime"
 	"strconv"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"golang.org/x/xerrors"
@@ -64,6 +65,7 @@ type Conn struct {
 	previousReader *messageReader
 	// readFrameLock is acquired to read from bw.
 	readFrameLock     chan struct{}
+	readClosed        int64
 	readHeaderBuf     []byte
 	controlPayloadBuf []byte
 
@@ -329,6 +331,10 @@ func (c *Conn) handleControl(ctx context.Context, h header) error {
 // See https://github.com/nhooyr/websocket/issues/87#issue-451703332
 // Most users should not need this.
 func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
+	if atomic.LoadInt64(&c.readClosed) == 1 {
+		return 0, nil, xerrors.Errorf("websocket connection read closed")
+	}
+
 	typ, r, err := c.reader(ctx)
 	if err != nil {
 		return 0, nil, xerrors.Errorf("failed to get reader: %w", err)
@@ -395,10 +401,13 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
 // Use this when you do not want to read data messages from the connection anymore but will
 // want to write messages to it.
 func (c *Conn) CloseRead(ctx context.Context) context.Context {
+	atomic.StoreInt64(&c.readClosed, 1)
+
 	ctx, cancel := context.WithCancel(ctx)
 	go func() {
 		defer cancel()
-		c.Reader(ctx)
+		// We use the unexported reader so that we don't get the read closed error.
+		c.reader(ctx)
 		c.Close(StatusPolicyViolation, "unexpected data message")
 	}()
 	return ctx
-- 
GitLab