diff --git a/close.go b/close.go
index ff2e878a93b5ce989c7566d95cca5eb9085602e3..f94951dcc47c29ff364b4b9eecc649da70977cda 100644
--- a/close.go
+++ b/close.go
@@ -100,7 +100,7 @@ func CloseStatus(err error) StatusCode {
 func (c *Conn) Close(code StatusCode, reason string) (err error) {
 	defer errd.Wrap(&err, "failed to close WebSocket")
 
-	if !c.casClosing() {
+	if c.casClosing() {
 		err = c.waitGoroutines()
 		if err != nil {
 			return err
@@ -133,7 +133,7 @@ func (c *Conn) Close(code StatusCode, reason string) (err error) {
 func (c *Conn) CloseNow() (err error) {
 	defer errd.Wrap(&err, "failed to immediately close WebSocket")
 
-	if !c.casClosing() {
+	if c.casClosing() {
 		err = c.waitGoroutines()
 		if err != nil {
 			return err
@@ -329,13 +329,7 @@ func (ce CloseError) bytesErr() ([]byte, error) {
 }
 
 func (c *Conn) casClosing() bool {
-	c.closeMu.Lock()
-	defer c.closeMu.Unlock()
-	if !c.closing {
-		c.closing = true
-		return true
-	}
-	return false
+	return c.closing.Swap(true)
 }
 
 func (c *Conn) isClosed() bool {
diff --git a/conn.go b/conn.go
index d7434a9d55cf94008bc87dfbe00f66720230c326..76b057dd29d8ca8b21ac9bf6c9f041d01f81d013 100644
--- a/conn.go
+++ b/conn.go
@@ -69,13 +69,19 @@ type Conn struct {
 	writeHeaderBuf [8]byte
 	writeHeader    header
 
+	// Close handshake state.
+	closeStateMu     sync.RWMutex
+	closeReceivedErr error
+	closeSentErr     error
+
+	// CloseRead state.
 	closeReadMu   sync.Mutex
 	closeReadCtx  context.Context
 	closeReadDone chan struct{}
 
+	closing atomic.Bool
+	closeMu sync.Mutex // Protects following.
 	closed  chan struct{}
-	closeMu sync.Mutex
-	closing bool
 
 	pingCounter   atomic.Int64
 	activePingsMu sync.Mutex
diff --git a/conn_test.go b/conn_test.go
index b4d57f2108cbcaadd84778346cdd3b973ecef295..9ed8c7ea256c7125e3f47fad4b1c87fb8588ec29 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -8,6 +8,7 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"net"
 	"net/http"
 	"net/http/httptest"
 	"os"
@@ -460,7 +461,7 @@ func (tt *connTest) goDiscardLoop(c *websocket.Conn) {
 }
 
 func BenchmarkConn(b *testing.B) {
-	var benchCases = []struct {
+	benchCases := []struct {
 		name string
 		mode websocket.CompressionMode
 	}{
@@ -625,3 +626,149 @@ func TestConcurrentClosePing(t *testing.T) {
 		}()
 	}
 }
+
+func TestConnClosePropagation(t *testing.T) {
+	t.Parallel()
+
+	want := []byte("hello")
+	keepWriting := func(c *websocket.Conn) <-chan error {
+		return xsync.Go(func() error {
+			for {
+				err := c.Write(context.Background(), websocket.MessageText, want)
+				if err != nil {
+					return err
+				}
+			}
+		})
+	}
+	keepReading := func(c *websocket.Conn) <-chan error {
+		return xsync.Go(func() error {
+			for {
+				_, got, err := c.Read(context.Background())
+				if err != nil {
+					return err
+				}
+				if !bytes.Equal(want, got) {
+					return fmt.Errorf("unexpected message: want %q, got %q", want, got)
+				}
+			}
+		})
+	}
+	checkReadErr := func(t *testing.T, err error) {
+		// Check read error (output depends on when read is called in relation to connection closure).
+		var ce websocket.CloseError
+		if errors.As(err, &ce) {
+			assert.Equal(t, "", websocket.StatusNormalClosure, ce.Code)
+		} else {
+			assert.ErrorIs(t, net.ErrClosed, err)
+		}
+	}
+	checkConnErrs := func(t *testing.T, conn ...*websocket.Conn) {
+		for _, c := range conn {
+			// Check write error.
+			err := c.Write(context.Background(), websocket.MessageText, want)
+			assert.ErrorIs(t, net.ErrClosed, err)
+
+			_, _, err = c.Read(context.Background())
+			checkReadErr(t, err)
+		}
+	}
+
+	t.Run("CloseOtherSideDuringWrite", func(t *testing.T) {
+		tt, this, other := newConnTest(t, nil, nil)
+
+		_ = this.CloseRead(tt.ctx)
+		thisWriteErr := keepWriting(this)
+
+		_, got, err := other.Read(tt.ctx)
+		assert.Success(t, err)
+		assert.Equal(t, "msg", want, got)
+
+		err = other.Close(websocket.StatusNormalClosure, "")
+		assert.Success(t, err)
+
+		select {
+		case err := <-thisWriteErr:
+			assert.ErrorIs(t, net.ErrClosed, err)
+		case <-tt.ctx.Done():
+			t.Fatal(tt.ctx.Err())
+		}
+
+		checkConnErrs(t, this, other)
+	})
+	t.Run("CloseThisSideDuringWrite", func(t *testing.T) {
+		tt, this, other := newConnTest(t, nil, nil)
+
+		_ = this.CloseRead(tt.ctx)
+		thisWriteErr := keepWriting(this)
+		otherReadErr := keepReading(other)
+
+		err := this.Close(websocket.StatusNormalClosure, "")
+		assert.Success(t, err)
+
+		select {
+		case err := <-thisWriteErr:
+			assert.ErrorIs(t, net.ErrClosed, err)
+		case <-tt.ctx.Done():
+			t.Fatal(tt.ctx.Err())
+		}
+
+		select {
+		case err := <-otherReadErr:
+			checkReadErr(t, err)
+		case <-tt.ctx.Done():
+			t.Fatal(tt.ctx.Err())
+		}
+
+		checkConnErrs(t, this, other)
+	})
+	t.Run("CloseOtherSideDuringRead", func(t *testing.T) {
+		tt, this, other := newConnTest(t, nil, nil)
+
+		_ = other.CloseRead(tt.ctx)
+		errs := keepReading(this)
+
+		err := other.Write(tt.ctx, websocket.MessageText, want)
+		assert.Success(t, err)
+
+		err = other.Close(websocket.StatusNormalClosure, "")
+		assert.Success(t, err)
+
+		select {
+		case err := <-errs:
+			checkReadErr(t, err)
+		case <-tt.ctx.Done():
+			t.Fatal(tt.ctx.Err())
+		}
+
+		checkConnErrs(t, this, other)
+	})
+	t.Run("CloseThisSideDuringRead", func(t *testing.T) {
+		tt, this, other := newConnTest(t, nil, nil)
+
+		thisReadErr := keepReading(this)
+		otherReadErr := keepReading(other)
+
+		err := other.Write(tt.ctx, websocket.MessageText, want)
+		assert.Success(t, err)
+
+		err = this.Close(websocket.StatusNormalClosure, "")
+		assert.Success(t, err)
+
+		select {
+		case err := <-thisReadErr:
+			checkReadErr(t, err)
+		case <-tt.ctx.Done():
+			t.Fatal(tt.ctx.Err())
+		}
+
+		select {
+		case err := <-otherReadErr:
+			checkReadErr(t, err)
+		case <-tt.ctx.Done():
+			t.Fatal(tt.ctx.Err())
+		}
+
+		checkConnErrs(t, this, other)
+	})
+}
diff --git a/read.go b/read.go
index e2699da55f511ac68383dee425f9edb1883dc0a8..1267b5b91bb56d3cbaec1ab8cabf7e3f5f3834af 100644
--- a/read.go
+++ b/read.go
@@ -217,57 +217,68 @@ func (c *Conn) readLoop(ctx context.Context) (header, error) {
 	}
 }
 
-func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
+// prepareRead sets the readTimeout context and returns a done function
+// to be called after the read is done. It also returns an error if the
+// connection is closed. The reference to the error is used to assign
+// an error depending on if the connection closed or the context timed
+// out during use. Typically the referenced error is a named return
+// variable of the function calling this method.
+func (c *Conn) prepareRead(ctx context.Context, err *error) (func(), error) {
 	select {
 	case <-c.closed:
-		return header{}, net.ErrClosed
+		return nil, net.ErrClosed
 	case c.readTimeout <- ctx:
 	}
 
-	h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
-	if err != nil {
+	done := func() {
 		select {
 		case <-c.closed:
-			return header{}, net.ErrClosed
-		case <-ctx.Done():
-			return header{}, ctx.Err()
-		default:
-			return header{}, err
+			if *err != nil {
+				*err = net.ErrClosed
+			}
+		case c.readTimeout <- context.Background():
+		}
+		if *err != nil && ctx.Err() != nil {
+			*err = ctx.Err()
 		}
 	}
 
-	select {
-	case <-c.closed:
-		return header{}, net.ErrClosed
-	case c.readTimeout <- context.Background():
+	c.closeStateMu.Lock()
+	closeReceivedErr := c.closeReceivedErr
+	c.closeStateMu.Unlock()
+	if closeReceivedErr != nil {
+		defer done()
+		return nil, closeReceivedErr
 	}
 
-	return h, nil
+	return done, nil
 }
 
-func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
-	select {
-	case <-c.closed:
-		return 0, net.ErrClosed
-	case c.readTimeout <- ctx:
+func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) {
+	readDone, err := c.prepareRead(ctx, &err)
+	if err != nil {
+		return header{}, err
 	}
+	defer readDone()
 
-	n, err := io.ReadFull(c.br, p)
+	h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
 	if err != nil {
-		select {
-		case <-c.closed:
-			return n, net.ErrClosed
-		case <-ctx.Done():
-			return n, ctx.Err()
-		default:
-			return n, fmt.Errorf("failed to read frame payload: %w", err)
-		}
+		return header{}, err
 	}
 
-	select {
-	case <-c.closed:
-		return n, net.ErrClosed
-	case c.readTimeout <- context.Background():
+	return h, nil
+}
+
+func (c *Conn) readFramePayload(ctx context.Context, p []byte) (_ int, err error) {
+	readDone, err := c.prepareRead(ctx, &err)
+	if err != nil {
+		return 0, err
+	}
+	defer readDone()
+
+	n, err := io.ReadFull(c.br, p)
+	if err != nil {
+		return n, fmt.Errorf("failed to read frame payload: %w", err)
 	}
 
 	return n, err
@@ -325,9 +336,22 @@ func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
 	}
 
 	err = fmt.Errorf("received close frame: %w", ce)
-	c.writeClose(ce.Code, ce.Reason)
-	c.readMu.unlock()
-	c.close()
+	c.closeStateMu.Lock()
+	c.closeReceivedErr = err
+	closeSent := c.closeSentErr != nil
+	c.closeStateMu.Unlock()
+
+	// Only unlock readMu if this connection is being closed becaue
+	// c.close will try to acquire the readMu lock. We unlock for
+	// writeClose as well because it may also call c.close.
+	if !closeSent {
+		c.readMu.unlock()
+		_ = c.writeClose(ce.Code, ce.Reason)
+	}
+	if !c.casClosing() {
+		c.readMu.unlock()
+		_ = c.close()
+	}
 	return err
 }
 
diff --git a/write.go b/write.go
index e294a680e534703a21a6bddc063422f41c23ce46..7324de7427ede017a180f1848d53477ca9bf7108 100644
--- a/write.go
+++ b/write.go
@@ -5,6 +5,7 @@ package websocket
 
 import (
 	"bufio"
+	"compress/flate"
 	"context"
 	"crypto/rand"
 	"encoding/binary"
@@ -14,8 +15,6 @@ import (
 	"net"
 	"time"
 
-	"compress/flate"
-
 	"github.com/coder/websocket/internal/errd"
 	"github.com/coder/websocket/internal/util"
 )
@@ -249,22 +248,36 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco
 	}
 	defer c.writeFrameMu.unlock()
 
+	defer func() {
+		if c.isClosed() && opcode == opClose {
+			err = nil
+		}
+		if err != nil {
+			if ctx.Err() != nil {
+				err = ctx.Err()
+			} else if c.isClosed() {
+				err = net.ErrClosed
+			}
+			err = fmt.Errorf("failed to write frame: %w", err)
+		}
+	}()
+
+	c.closeStateMu.Lock()
+	closeSentErr := c.closeSentErr
+	c.closeStateMu.Unlock()
+	if closeSentErr != nil {
+		return 0, net.ErrClosed
+	}
+
 	select {
 	case <-c.closed:
 		return 0, net.ErrClosed
 	case c.writeTimeout <- ctx:
 	}
-
 	defer func() {
-		if err != nil {
-			select {
-			case <-c.closed:
-				err = net.ErrClosed
-			case <-ctx.Done():
-				err = ctx.Err()
-			default:
-			}
-			err = fmt.Errorf("failed to write frame: %w", err)
+		select {
+		case <-c.closed:
+		case c.writeTimeout <- context.Background():
 		}
 	}()
 
@@ -303,13 +316,16 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco
 		}
 	}
 
-	select {
-	case <-c.closed:
-		if opcode == opClose {
-			return n, nil
+	if opcode == opClose {
+		c.closeStateMu.Lock()
+		c.closeSentErr = fmt.Errorf("sent close frame: %w", net.ErrClosed)
+		closeReceived := c.closeReceivedErr != nil
+		c.closeStateMu.Unlock()
+
+		if closeReceived && !c.casClosing() {
+			c.writeFrameMu.unlock()
+			_ = c.close()
 		}
-		return n, net.ErrClosed
-	case c.writeTimeout <- context.Background():
 	}
 
 	return n, nil