diff --git a/read.go b/read.go
index bd0ddf95a1e34014b121a65d447faf61465b60bf..8bd736950db10b53eb1bfd43dbfe06e19bafe374 100644
--- a/read.go
+++ b/read.go
@@ -64,9 +64,10 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
 // This function is idempotent.
 func (c *Conn) CloseRead(ctx context.Context) context.Context {
 	c.closeReadMu.Lock()
-	if c.closeReadCtx != nil {
+	ctx2 := c.closeReadCtx
+	if ctx2 != nil {
 		c.closeReadMu.Unlock()
-		return c.closeReadCtx
+		return ctx2
 	}
 	ctx, cancel := context.WithCancel(ctx)
 	c.closeReadCtx = ctx
diff --git a/ws_js.go b/ws_js.go
index 2b8e3b3db5c61900cbfff139c3443d5c4bc30e73..02d61f28c13e6ddacf36126227ad2a3d48209ed3 100644
--- a/ws_js.go
+++ b/ws_js.go
@@ -47,9 +47,10 @@ type Conn struct {
 	// read limit for a message in bytes.
 	msgReadLimit xsync.Int64
 
-	wg            sync.WaitGroup
+	closeReadMu  sync.Mutex
+	closeReadCtx context.Context
+
 	closingMu     sync.Mutex
-	isReadClosed  xsync.Int64
 	closeOnce     sync.Once
 	closed        chan struct{}
 	closeErrOnce  sync.Once
@@ -130,7 +131,10 @@ func (c *Conn) closeWithInternal() {
 // Read attempts to read a message from the connection.
 // The maximum time spent waiting is bounded by the context.
 func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
-	if c.isReadClosed.Load() == 1 {
+	c.closeReadMu.Lock()
+	closedRead := c.closeReadCtx != nil
+	c.closeReadMu.Unlock()
+	if closedRead {
 		return 0, nil, errors.New("WebSocket connection read closed")
 	}
 
@@ -387,14 +391,19 @@ func (w *writer) Close() error {
 
 // CloseRead implements *Conn.CloseRead for wasm.
 func (c *Conn) CloseRead(ctx context.Context) context.Context {
-	c.isReadClosed.Store(1)
-
+	c.closeReadMu.Lock()
+	ctx2 := c.closeReadCtx
+	if ctx2 != nil {
+		c.closeReadMu.Unlock()
+		return ctx2
+	}
 	ctx, cancel := context.WithCancel(ctx)
-	c.wg.Add(1)
+	c.closeReadCtx = ctx
+	c.closeReadMu.Unlock()
+
 	go func() {
-		defer c.CloseNow()
-		defer c.wg.Done()
 		defer cancel()
+		defer c.CloseNow()
 		_, _, err := c.read(ctx)
 		if err != nil {
 			c.Close(StatusPolicyViolation, "unexpected data message")