From 250db1efbe15806649120e4f6748de43859b5d12 Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Sun, 7 Apr 2024 07:47:26 -0700
Subject: [PATCH] read: Fix CloseRead to have its own done channel

Context can be cancelled by parent. Doesn't indicate the CloseRead goroutine
has exited.
---
 close.go    |  6 +++---
 conn.go     | 13 +++++++------
 mask_asm.go |  2 +-
 read.go     |  2 ++
 4 files changed, 13 insertions(+), 10 deletions(-)

diff --git a/close.go b/close.go
index d151259..625ed12 100644
--- a/close.go
+++ b/close.go
@@ -239,11 +239,11 @@ func (c *Conn) waitGoroutines() error {
 	}
 
 	c.closeReadMu.Lock()
-	ctx := c.closeReadCtx
+	closeRead := c.closeReadCtx != nil
 	c.closeReadMu.Unlock()
-	if ctx != nil {
+	if closeRead {
 		select {
-		case <-ctx.Done():
+		case <-c.closeReadDone:
 		case <-t.C:
 			return errors.New("failed to wait for close read goroutine to exit")
 		}
diff --git a/conn.go b/conn.go
index f5da573..8690fb3 100644
--- a/conn.go
+++ b/conn.go
@@ -57,10 +57,10 @@ type Conn struct {
 	timeoutLoopDone chan struct{}
 
 	// Read state.
-	readMu            *mu
-	readHeaderBuf     [8]byte
-	readControlBuf    [maxControlPayload]byte
-	msgReader         *msgReader
+	readMu         *mu
+	readHeaderBuf  [8]byte
+	readControlBuf [maxControlPayload]byte
+	msgReader      *msgReader
 
 	// Write state.
 	msgWriter      *msgWriter
@@ -69,8 +69,9 @@ type Conn struct {
 	writeHeaderBuf [8]byte
 	writeHeader    header
 
-	closeReadMu  sync.Mutex
-	closeReadCtx context.Context
+	closeReadMu   sync.Mutex
+	closeReadCtx  context.Context
+	closeReadDone chan struct{}
 
 	closed  chan struct{}
 	closeMu sync.Mutex
diff --git a/mask_asm.go b/mask_asm.go
index 60c0290..f9484b5 100644
--- a/mask_asm.go
+++ b/mask_asm.go
@@ -3,7 +3,7 @@
 package websocket
 
 func mask(b []byte, key uint32) uint32 {
-    // TODO: Will enable in v1.9.0.
+	// TODO: Will enable in v1.9.0.
 	return maskGo(b, key)
 	/*
 		if len(b) > 0 {
diff --git a/read.go b/read.go
index 5df031c..a59e71d 100644
--- a/read.go
+++ b/read.go
@@ -71,9 +71,11 @@ func (c *Conn) CloseRead(ctx context.Context) context.Context {
 	}
 	ctx, cancel := context.WithCancel(ctx)
 	c.closeReadCtx = ctx
+	c.closeReadDone = make(chan struct{})
 	c.closeReadMu.Unlock()
 
 	go func() {
+		defer close(c.closeReadDone)
 		defer cancel()
 		defer c.close()
 		_, _, err := c.Reader(ctx)
-- 
GitLab