From 780bda4159cd001ed4e1704327c1292a1d21336d Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Mon, 4 Nov 2019 18:50:29 -0500
Subject: [PATCH] Fix race with c.readerShouldLock

Closes #168
---
 conn.go | 53 ++++++++++++++++++++++++++++++++++++-----------------
 1 file changed, 36 insertions(+), 17 deletions(-)

diff --git a/conn.go b/conn.go
index cbb7fa5..7d48b8a 100644
--- a/conn.go
+++ b/conn.go
@@ -78,11 +78,10 @@ type Conn struct {
 	readLock          chan struct{}
 
 	// messageReader state.
-	readerMsgCtx     context.Context
-	readerMsgHeader  header
-	readerFrameEOF   bool
-	readerMaskPos    int
-	readerShouldLock bool
+	readerMsgCtx    context.Context
+	readerMsgHeader header
+	readerFrameEOF  bool
+	readerMaskPos   int
 
 	setReadTimeout  chan context.Context
 	setWriteTimeout chan context.Context
@@ -445,7 +444,6 @@ func (c *Conn) reader(ctx context.Context, lock bool) (MessageType, io.Reader, e
 	c.readerFrameEOF = false
 	c.readerMaskPos = 0
 	c.readMsgLeft = c.msgReadLimit.Load()
-	c.readerShouldLock = lock
 
 	r := &messageReader{
 		c: c,
@@ -465,7 +463,11 @@ func (r *messageReader) eof() bool {
 
 // Read reads as many bytes as possible into p.
 func (r *messageReader) Read(p []byte) (int, error) {
-	n, err := r.read(p)
+	return r.exportedRead(p, true)
+}
+
+func (r *messageReader) exportedRead(p []byte, lock bool) (int, error) {
+	n, err := r.read(p, lock)
 	if err != nil {
 		// Have to return io.EOF directly for now, we cannot wrap as errors.Is
 		// isn't used widely yet.
@@ -477,17 +479,29 @@ func (r *messageReader) Read(p []byte) (int, error) {
 	return n, nil
 }
 
-func (r *messageReader) read(p []byte) (int, error) {
-	if r.c.readerShouldLock {
-		err := r.c.acquireLock(r.c.readerMsgCtx, r.c.readLock)
-		if err != nil {
-			return 0, err
+func (r *messageReader) readUnlocked(p []byte) (int, error) {
+	return r.exportedRead(p, false)
+}
+
+func (r *messageReader) read(p []byte, lock bool) (int, error) {
+	if lock {
+		// If we cannot acquire the read lock, then
+		// there is either a concurrent read or the close handshake
+		// is proceeding.
+		select {
+		case r.c.readLock <- struct{}{}:
+			defer r.c.releaseLock(r.c.readLock)
+		default:
+			if r.c.closing.Load() == 1 {
+				<-r.c.closed
+				return 0, r.c.closeErr
+			}
+			return 0, errors.New("concurrent read detected")
 		}
-		defer r.c.releaseLock(r.c.readLock)
 	}
 
 	if r.eof() {
-		return 0, fmt.Errorf("cannot use EOFed reader")
+		return 0, errors.New("cannot use EOFed reader")
 	}
 
 	if r.c.readMsgLeft <= 0 {
@@ -950,8 +964,6 @@ func (c *Conn) waitClose() error {
 		return c.closeReceived
 	}
 
-	c.readerShouldLock = false
-
 	b := bpool.Get()
 	buf := b.Bytes()
 	buf = buf[:cap(buf)]
@@ -965,7 +977,8 @@ func (c *Conn) waitClose() error {
 			}
 		}
 
-		_, err = io.CopyBuffer(ioutil.Discard, c.activeReader, buf)
+		r := readerFunc(c.activeReader.readUnlocked)
+		_, err = io.CopyBuffer(ioutil.Discard, r, buf)
 		if err != nil {
 			return err
 		}
@@ -1019,6 +1032,12 @@ func (c *Conn) ping(ctx context.Context, p string) error {
 	}
 }
 
+type readerFunc func(p []byte) (int, error)
+
+func (f readerFunc) Read(p []byte) (int, error) {
+	return f(p)
+}
+
 type writerFunc func(p []byte) (int, error)
 
 func (f writerFunc) Write(p []byte) (int, error) {
-- 
GitLab