diff --git a/websocket.go b/websocket.go
index 275af9da72d3be94438701401bb17e189148bbd0..912508d5635321679d8c2459a66035fac11ab9fb 100644
--- a/websocket.go
+++ b/websocket.go
@@ -61,8 +61,6 @@ type Conn struct {
 }
 
 func (c *Conn) close(err error) {
-	err = xerrors.Errorf("websocket closed: %w", err)
-
 	c.closeOnce.Do(func() {
 		runtime.SetFinalizer(c, nil)
 
@@ -71,7 +69,7 @@ func (c *Conn) close(err error) {
 			cerr = err
 		}
 
-		c.closeErr = cerr
+		c.closeErr = xerrors.Errorf("websocket closed: %w", cerr)
 
 		close(c.closed)
 	})
@@ -98,7 +96,7 @@ func (c *Conn) init() {
 	c.readDone = make(chan int)
 
 	runtime.SetFinalizer(c, func(c *Conn) {
-		c.Close(StatusInternalError, "connection garbage collected")
+		c.close(xerrors.New("connection garbage collected"))
 	})
 
 	go c.writeLoop()
@@ -238,7 +236,7 @@ func (c *Conn) handleControl(h header) {
 	case opClose:
 		ce, err := parseClosePayload(b)
 		if err != nil {
-			c.close(xerrors.Errorf("read invalid close payload: %w", err))
+			c.close(xerrors.Errorf("received invalid close payload: %w", err))
 			return
 		}
 		if ce.Code == StatusNoStatusRcvd {
@@ -302,7 +300,7 @@ func (c *Conn) readLoop() {
 	}
 }
 
-func (c *Conn) dataReadLoop(h header) (err error) {
+func (c *Conn) dataReadLoop(h header) error {
 	maskPos := 0
 	left := h.payloadLength
 	firstReadDone := false
@@ -355,7 +353,6 @@ func (c *Conn) writePong(p []byte) error {
 
 // Close closes the WebSocket connection with the given status code and reason.
 // It will write a WebSocket close frame with a timeout of 5 seconds.
-// Concurrent calls to Close are ok.
 func (c *Conn) Close(code StatusCode, reason string) error {
 	err := c.exportedClose(code, reason)
 	if err != nil {
@@ -400,7 +397,7 @@ func (c *Conn) writeClose(p []byte, cerr CloseError) error {
 		return err
 	}
 
-	if cerr != c.closeErr {
+	if !xerrors.Is(c.closeErr, cerr) {
 		return c.closeErr
 	}
 
@@ -420,9 +417,8 @@ func (c *Conn) writeSingleFrame(ctx context.Context, opcode opcode, p []byte) er
 		payload: p,
 	}:
 	case <-ctx.Done():
-		err := xerrors.Errorf("control frame write timed out: %w", ctx.Err())
-		c.close(err)
-		return err
+		c.close(xerrors.Errorf("control frame write timed out: %w", ctx.Err()))
+		return ctx.Err()
 	}
 
 	select {
@@ -487,7 +483,7 @@ func (w messageWriter) write(p []byte) (int, error) {
 		select {
 		case <-w.ctx.Done():
 			w.c.close(xerrors.Errorf("data write timed out: %w", w.ctx.Err()))
-			// Wait for writeLoop to complete so we know p is done.
+			// Wait for writeLoop to complete so we know p is done with.
 			<-w.c.writeDone
 			return 0, w.ctx.Err()
 		case _, ok := <-w.c.writeDone:
@@ -542,25 +538,21 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
 }
 
 func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
-	for !atomic.CompareAndSwapInt64(&c.activeReader, 0, 1) {
-		select {
-		case <-c.closed:
-			return 0, nil, c.closeErr
-		case c.readBytes <- nil:
-			select {
-			case <-ctx.Done():
-				return 0, nil, ctx.Err()
-			case _, ok := <-c.readDone:
-				if !ok {
-					return 0, nil, c.closeErr
-				}
-				if atomic.LoadInt64(&c.activeReader) == 1 {
-					return 0, nil, xerrors.New("previous message not fully read")
-				}
-			}
-		case <-ctx.Done():
-			return 0, nil, ctx.Err()
+	if !atomic.CompareAndSwapInt64(&c.activeReader, 0, 1) {
+		// If the next read yields io.EOF we are good to go.
+		r := messageReader{
+			ctx: ctx,
+			c:   c,
 		}
+		_, err := r.Read(nil)
+		if err == nil {
+			return 0, nil, xerrors.New("previous message not fully read")
+		}
+		if !xerrors.Is(err, io.EOF) {
+			return 0, nil, xerrors.Errorf("failed to check if last message at io.EOF: %w", err)
+		}
+
+		atomic.StoreInt64(&c.activeReader, 1)
 	}
 
 	select {
@@ -586,7 +578,8 @@ type messageReader struct {
 func (r messageReader) Read(p []byte) (int, error) {
 	n, err := r.read(p)
 	if err != nil {
-		// Have to return io.EOF directly for now, cannot wrap.
+		// Have to return io.EOF directly for now, we cannot wrap as xerrors
+		// isn't used in stdlib.
 		if err == io.EOF {
 			return n, io.EOF
 		}