diff --git a/websocket.go b/websocket.go
index 7d33a36a6a0dc4adecef1d2ee4ec7229143ce408..6480aed66ff1b1b8c39da1cf034b4d29df8fa101 100644
--- a/websocket.go
+++ b/websocket.go
@@ -38,10 +38,10 @@ type Conn struct {
 	writeDataLock  chan struct{}
 	writeFrameLock chan struct{}
 
-	readDataLock chan struct{}
-	readData     chan header
-	readDone     chan struct{}
-	readLoopDone chan struct{}
+	readMsgLock   chan struct{}
+	readMsg       chan header
+	readMsgDone   chan struct{}
+	readFrameLock chan struct{}
 
 	setReadTimeout  chan context.Context
 	setWriteTimeout chan context.Context
@@ -90,17 +90,15 @@ func (c *Conn) close(err error) {
 
 		close(c.closed)
 
+		// This ensures every goroutine that interacts
+		// with the conn closes before it can interact with the connection
+		c.readFrameLock <- struct{}{}
+		c.writeFrameLock <- struct{}{}
+
 		// See comment in dial.go
 		if c.client {
-			go func() {
-				<-c.readLoopDone
-				// TODO this does not work if reader errors out so skip for now.
-				// c.readDataLock <- struct{}{}
-				// c.writeFrameLock <- struct{}{}
-				//
-				// returnBufioReader(c.br)
-				// returnBufioWriter(c.bw)
-			}()
+			returnBufioReader(c.br)
+			returnBufioWriter(c.bw)
 		}
 	})
 }
@@ -119,10 +117,10 @@ func (c *Conn) init() {
 	c.writeDataLock = make(chan struct{}, 1)
 	c.writeFrameLock = make(chan struct{}, 1)
 
-	c.readData = make(chan header)
-	c.readDone = make(chan struct{})
-	c.readDataLock = make(chan struct{}, 1)
-	c.readLoopDone = make(chan struct{})
+	c.readMsg = make(chan header)
+	c.readMsgDone = make(chan struct{})
+	c.readMsgLock = make(chan struct{}, 1)
+	c.readFrameLock = make(chan struct{}, 1)
 
 	c.setReadTimeout = make(chan context.Context)
 	c.setWriteTimeout = make(chan context.Context)
@@ -141,8 +139,8 @@ func (c *Conn) init() {
 
 // We never mask inside here because our mask key is always 0,0,0,0.
 // See comment on secWebSocketKey.
-func (c *Conn) writeFrame(ctx context.Context, h header, p []byte) error {
-	err := c.acquireLock(ctx, c.writeFrameLock)
+func (c *Conn) writeFrame(ctx context.Context, h header, p []byte) (err error) {
+	err = c.acquireLock(ctx, c.writeFrameLock)
 	if err != nil {
 		return err
 	}
@@ -164,27 +162,33 @@ func (c *Conn) writeFrame(ctx context.Context, h header, p []byte) error {
 		}
 	}()
 
+	defer func() {
+		if err != nil {
+			// We need to always release the lock first before closing the connection to ensure
+			// the lock can be acquired inside close.
+			c.releaseLock(c.writeFrameLock)
+			c.close(err)
+		}
+	}()
+
 	h.masked = c.client
 	h.payloadLength = int64(len(p))
 
 	b2 := marshalHeader(h)
 	_, err = c.bw.Write(b2)
 	if err != nil {
-		c.close(xerrors.Errorf("failed to write to connection: %w", err))
-		return c.closeErr
+		return xerrors.Errorf("failed to write to connection: %w", err)
 	}
 	_, err = c.bw.Write(p)
 	if err != nil {
-		c.close(xerrors.Errorf("failed to write to connection: %w", err))
-		return c.closeErr
+		return xerrors.Errorf("failed to write to connection: %w", err)
 
 	}
 
 	if h.fin {
 		err := c.bw.Flush()
 		if err != nil {
-			c.close(xerrors.Errorf("failed to write to connection: %w", err))
-			return c.closeErr
+			return xerrors.Errorf("failed to write to connection: %w", err)
 		}
 	}
 
@@ -279,9 +283,9 @@ func (c *Conn) handleControl(h header) {
 
 func (c *Conn) readTillData() (header, error) {
 	for {
-		h, err := readHeader(c.br)
+		h, err := c.readHeader()
 		if err != nil {
-			return header{}, xerrors.Errorf("failed to read header: %w", err)
+			return header{}, err
 		}
 
 		if h.rsv1 || h.rsv2 || h.rsv3 {
@@ -312,9 +316,22 @@ func (c *Conn) readTillData() (header, error) {
 	}
 }
 
-func (c *Conn) readLoop() {
-	defer close(c.readLoopDone)
+func (c *Conn) readHeader() (header, error) {
+	err := c.acquireLock(context.Background(), c.readFrameLock)
+	if err != nil {
+		return header{}, err
+	}
+	defer c.releaseLock(c.readFrameLock)
 
+	h, err := readHeader(c.br)
+	if err != nil {
+		return header{}, xerrors.Errorf("failed to read header: %w", err)
+	}
+
+	return h, nil
+}
+
+func (c *Conn) readLoop() {
 	for {
 		h, err := c.readTillData()
 		if err != nil {
@@ -325,13 +342,13 @@ func (c *Conn) readLoop() {
 		select {
 		case <-c.closed:
 			return
-		case c.readData <- h:
+		case c.readMsg <- h:
 		}
 
 		select {
 		case <-c.closed:
 			return
-		case <-c.readDone:
+		case <-c.readMsgDone:
 		}
 	}
 }
@@ -374,7 +391,7 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error {
 	// Definitely worth seeing what popular browsers do later.
 	p, err := ce.bytes()
 	if err != nil {
-		fmt.Fprintf(os.Stderr, "failed to marshal close frame: %v\n", err)
+		fmt.Fprintf(os.Stderr, "websocket: failed to marshal close frame: %v\n", err)
 		ce = CloseError{
 			Code: StatusInternalError,
 		}
@@ -415,7 +432,11 @@ func (c *Conn) acquireLock(ctx context.Context, lock chan struct{}) error {
 }
 
 func (c *Conn) releaseLock(lock chan struct{}) {
-	<-lock
+	// Allow multiple releases.
+	select {
+	case <-lock:
+	default:
+	}
 }
 
 func (c *Conn) writeMessage(ctx context.Context, opcode opcode, p []byte) error {
@@ -572,7 +593,7 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
 }
 
 func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) {
-	err = c.acquireLock(ctx, c.readDataLock)
+	err = c.acquireLock(ctx, c.readMsgLock)
 	if err != nil {
 		return 0, nil, err
 	}
@@ -582,7 +603,7 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro
 		return 0, nil, c.closeErr
 	case <-ctx.Done():
 		return 0, nil, ctx.Err()
-	case h := <-c.readData:
+	case h := <-c.readMsg:
 		if h.opcode == opContinuation {
 			ce := CloseError{
 				Code:   StatusProtocolError,
@@ -631,7 +652,7 @@ func (r *messageReader) read(p []byte) (int, error) {
 		select {
 		case <-r.c.closed:
 			return 0, r.c.closeErr
-		case h := <-r.c.readData:
+		case h := <-r.c.readMsg:
 			if h.opcode != opContinuation {
 				ce := CloseError{
 					Code:   StatusProtocolError,
@@ -654,7 +675,12 @@ func (r *messageReader) read(p []byte) (int, error) {
 	case r.c.setReadTimeout <- r.ctx:
 	}
 
+	err := r.c.acquireLock(r.ctx, r.c.readFrameLock)
+	if err != nil {
+		return 0, err
+	}
 	n, err := io.ReadFull(r.c.br, p)
+	r.c.releaseLock(r.c.readFrameLock)
 
 	select {
 	case <-r.c.closed:
@@ -676,11 +702,11 @@ func (r *messageReader) read(p []byte) (int, error) {
 		select {
 		case <-r.c.closed:
 			return n, r.c.closeErr
-		case r.c.readDone <- struct{}{}:
+		case r.c.readMsgDone <- struct{}{}:
 		}
 		if r.h.fin {
 			r.eofed = true
-			r.c.releaseLock(r.c.readDataLock)
+			r.c.releaseLock(r.c.readMsgLock)
 			return n, io.EOF
 		}
 		r.maskPos = 0