From 224ef23799cb71fd8fabc27a20503e978a18e048 Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Mon, 7 Oct 2019 01:20:55 -0400
Subject: [PATCH] Cleanup close handshake implementation

---
 conn.go             | 157 ++++++++++++++++++++------------------------
 conn_export_test.go |   4 +-
 go.mod              |   1 -
 websocket_js.go     |  54 +++++----------
 4 files changed, 92 insertions(+), 124 deletions(-)

diff --git a/conn.go b/conn.go
index 73d6490..b162a42 100644
--- a/conn.go
+++ b/conn.go
@@ -17,8 +17,6 @@ import (
 	"sync/atomic"
 	"time"
 
-	"golang.org/x/xerrors"
-
 	"nhooyr.io/websocket/internal/bpool"
 )
 
@@ -66,7 +64,6 @@ type Conn struct {
 	writeMsgOpcode opcode
 	writeMsgCtx    context.Context
 	readMsgLeft    int64
-	readCloseFrame CloseError
 
 	// Used to ensure the previous reader is read till EOF before allowing
 	// a new one.
@@ -74,7 +71,6 @@ type Conn struct {
 	// readFrameLock is acquired to read from bw.
 	readFrameLock     chan struct{}
 	isReadClosed      *atomicInt64
-	isCloseHandshake  *atomicInt64
 	readHeaderBuf     []byte
 	controlPayloadBuf []byte
 
@@ -102,7 +98,6 @@ func (c *Conn) init() {
 	c.writeFrameLock = make(chan struct{}, 1)
 
 	c.readFrameLock = make(chan struct{}, 1)
-	c.isCloseHandshake = &atomicInt64{}
 
 	c.setReadTimeout = make(chan context.Context)
 	c.setWriteTimeout = make(chan context.Context)
@@ -206,20 +201,20 @@ func (c *Conn) releaseLock(lock chan struct{}) {
 	}
 }
 
-func (c *Conn) readTillMsg(ctx context.Context) (header, error) {
+func (c *Conn) readTillMsg(ctx context.Context, lock bool) (header, error) {
 	for {
-		h, err := c.readFrameHeader(ctx)
+		h, err := c.readFrameHeader(ctx, lock)
 		if err != nil {
 			return header{}, err
 		}
 
 		if h.rsv1 || h.rsv2 || h.rsv3 {
-			c.Close(StatusProtocolError, fmt.Sprintf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3))
+			c.writeClose(StatusProtocolError, fmt.Sprintf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3), false)
 			return header{}, c.closeErr
 		}
 
 		if h.opcode.controlOp() {
-			err = c.handleControl(ctx, h)
+			err = c.handleControl(ctx, h, lock)
 			if err != nil {
 				return header{}, fmt.Errorf("failed to handle control frame: %w", err)
 			}
@@ -230,18 +225,20 @@ func (c *Conn) readTillMsg(ctx context.Context) (header, error) {
 		case opBinary, opText, opContinuation:
 			return h, nil
 		default:
-			c.Close(StatusProtocolError, fmt.Sprintf("received unknown opcode %v", h.opcode))
+			c.writeClose(StatusProtocolError, fmt.Sprintf("received unknown opcode %v", h.opcode), false)
 			return header{}, c.closeErr
 		}
 	}
 }
 
-func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
-	err := c.acquireLock(ctx, c.readFrameLock)
-	if err != nil {
-		return header{}, err
+func (c *Conn) readFrameHeader(ctx context.Context, lock bool) (header, error) {
+	if lock {
+		err := c.acquireLock(ctx, c.readFrameLock)
+		if err != nil {
+			return header{}, err
+		}
+		defer c.releaseLock(c.readFrameLock)
 	}
-	defer c.releaseLock(c.readFrameLock)
 
 	select {
 	case <-c.closed:
@@ -273,14 +270,14 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
 	return h, nil
 }
 
-func (c *Conn) handleControl(ctx context.Context, h header) error {
+func (c *Conn) handleControl(ctx context.Context, h header, lock bool) error {
 	if h.payloadLength > maxControlFramePayload {
-		c.Close(StatusProtocolError, fmt.Sprintf("control frame too large at %v bytes", h.payloadLength))
+		c.writeClose(StatusProtocolError, fmt.Sprintf("control frame too large at %v bytes", h.payloadLength), false)
 		return c.closeErr
 	}
 
 	if !h.fin {
-		c.Close(StatusProtocolError, "received fragmented control frame")
+		c.writeClose(StatusProtocolError, "received fragmented control frame", false)
 		return c.closeErr
 	}
 
@@ -288,7 +285,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) error {
 	defer cancel()
 
 	b := c.controlPayloadBuf[:h.payloadLength]
-	_, err := c.readFramePayload(ctx, b)
+	_, err := c.readFramePayload(ctx, b, lock)
 	if err != nil {
 		return err
 	}
@@ -312,16 +309,14 @@ func (c *Conn) handleControl(ctx context.Context, h header) error {
 		ce, err := parseClosePayload(b)
 		if err != nil {
 			err = fmt.Errorf("received invalid close payload: %w", err)
-			c.Close(StatusProtocolError, err.Error())
+			c.writeClose(StatusProtocolError, err.Error(), false)
 			return c.closeErr
 		}
 
 		// This ensures the closeErr of the Conn is always the received CloseError
 		// in case the echo close frame write fails.
 		// See https://github.com/nhooyr/websocket/issues/109
-		c.setCloseErr(fmt.Errorf("received close frame: %w", ce))
-
-		c.readCloseFrame = ce
+		c.setCloseErr(ce)
 
 		func() {
 			ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
@@ -329,6 +324,9 @@ func (c *Conn) handleControl(ctx context.Context, h header) error {
 			c.writeControl(ctx, opClose, b)
 		}()
 
+		if !lock {
+			c.releaseLock(c.readFrameLock)
+		}
 		// We close with nil since the error is already set above.
 		c.close(nil)
 		return c.closeErr
@@ -362,16 +360,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) error {
 // Most users should not need this.
 func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
 	if c.isReadClosed.Load() == 1 {
-		return 0, nil, fmt.Errorf("websocket connection read closed")
-	}
-
-	if c.isCloseHandshake.Load() == 1 {
-		select {
-		case <-ctx.Done():
-			return 0, nil, fmt.Errorf("failed to get reader: %w", ctx.Err())
-		case <-c.closed:
-			return 0, nil, fmt.Errorf("failed to get reader: %w", c.closeErr)
-		}
+		return 0, nil, errors.New("websocket connection read closed")
 	}
 
 	typ, r, err := c.reader(ctx)
@@ -381,23 +370,23 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
 	return typ, r, nil
 }
 
-func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
+func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) {
 	if c.activeReader != nil && !c.readerFrameEOF {
 		// The only way we know for sure the previous reader is not yet complete is
 		// if there is an active frame not yet fully read.
 		// Otherwise, a user may have read the last byte but not the EOF if the EOF
 		// is in the next frame so we check for that below.
-		return 0, nil, fmt.Errorf("previous message not read to completion")
+		return 0, nil, errors.New("previous message not read to completion")
 	}
 
-	h, err := c.readTillMsg(ctx)
+	h, err := c.readTillMsg(ctx, true)
 	if err != nil {
 		return 0, nil, err
 	}
 
 	if c.activeReader != nil && !c.activeReader.eof() {
 		if h.opcode != opContinuation {
-			c.Close(StatusProtocolError, "received new data message without finishing the previous message")
+			c.writeClose(StatusProtocolError, "received new data message without finishing the previous message", false)
 			return 0, nil, c.closeErr
 		}
 
@@ -407,12 +396,12 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
 
 		c.activeReader = nil
 
-		h, err = c.readTillMsg(ctx)
+		h, err = c.readTillMsg(ctx, true)
 		if err != nil {
 			return 0, nil, err
 		}
 	} else if h.opcode == opContinuation {
-		c.Close(StatusProtocolError, "received continuation frame not after data or text frame")
+		c.writeClose(StatusProtocolError, "received continuation frame not after data or text frame", false)
 		return 0, nil, c.closeErr
 	}
 
@@ -458,7 +447,7 @@ func (r *messageReader) read(p []byte) (int, error) {
 	}
 
 	if r.c.readMsgLeft <= 0 {
-		r.c.Close(StatusMessageTooBig, fmt.Sprintf("read limited at %v bytes", r.c.msgReadLimit))
+		r.c.writeClose(StatusMessageTooBig, fmt.Sprintf("read limited at %v bytes", r.c.msgReadLimit), false)
 		return 0, r.c.closeErr
 	}
 
@@ -467,13 +456,13 @@ func (r *messageReader) read(p []byte) (int, error) {
 	}
 
 	if r.c.readerFrameEOF {
-		h, err := r.c.readTillMsg(r.c.readerMsgCtx)
+		h, err := r.c.readTillMsg(r.c.readerMsgCtx, true)
 		if err != nil {
 			return 0, err
 		}
 
 		if h.opcode != opContinuation {
-			r.c.Close(StatusProtocolError, "received new data message without finishing the previous message")
+			r.c.writeClose(StatusProtocolError, "received new data message without finishing the previous message", false)
 			return 0, r.c.closeErr
 		}
 
@@ -487,7 +476,7 @@ func (r *messageReader) read(p []byte) (int, error) {
 		p = p[:h.payloadLength]
 	}
 
-	n, err := r.c.readFramePayload(r.c.readerMsgCtx, p)
+	n, err := r.c.readFramePayload(r.c.readerMsgCtx, p, true)
 
 	h.payloadLength -= int64(n)
 	r.c.readMsgLeft -= int64(n)
@@ -512,12 +501,14 @@ func (r *messageReader) read(p []byte) (int, error) {
 	return n, nil
 }
 
-func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
-	err := c.acquireLock(ctx, c.readFrameLock)
-	if err != nil {
-		return 0, err
+func (c *Conn) readFramePayload(ctx context.Context, p []byte, lock bool) (int, error) {
+	if lock {
+		err := c.acquireLock(ctx, c.readFrameLock)
+		if err != nil {
+			return 0, err
+		}
+		defer c.releaseLock(c.readFrameLock)
 	}
-	defer c.releaseLock(c.readFrameLock)
 
 	select {
 	case <-c.closed:
@@ -813,14 +804,14 @@ func (c *Conn) writePong(p []byte) error {
 // Close will unblock all goroutines interacting with the connection once
 // complete.
 func (c *Conn) Close(code StatusCode, reason string) error {
-	err := c.closeHandshake(code, reason)
+	err := c.writeClose(code, reason, true)
 	if err != nil {
 		return fmt.Errorf("failed to close websocket connection: %w", err)
 	}
 	return nil
 }
 
-func (c *Conn) closeHandshake(code StatusCode, reason string) error {
+func (c *Conn) writeClose(code StatusCode, reason string, handshake bool) error {
 	ce := CloseError{
 		Code:   code,
 		Reason: reason,
@@ -838,60 +829,58 @@ func (c *Conn) closeHandshake(code StatusCode, reason string) error {
 		p, _ = ce.bytes()
 	}
 
+	// Give the handshake 10 seconds.
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
 	defer cancel()
 
-	// Ensures the connection is closed if everything below succeeds.
-	// Up here because we must release the read lock first.
-	// nil because of the setCloseErr call below.
-	defer c.close(nil)
-
-	// CloseErrors sent are made opaque to prevent applications from thinking
-	// they received a given status.
-	sentErr := fmt.Errorf("sent close frame: %v", ce)
-	// Other connections should only see this error.
-	c.setCloseErr(sentErr)
-
 	err = c.writeControl(ctx, opClose, p)
 	if err != nil {
 		return err
 	}
+	c.setCloseErr(ce)
+	defer c.close(nil)
 
-	// Wait for close frame from peer.
-	err = c.waitClose(ctx)
-	// We didn't read a close frame.
-	if c.readCloseFrame == (CloseError{}) {
-		if ctx.Err() != nil {
-			return xerrors.Errorf("failed to wait for peer close frame: %w", ctx.Err())
-		}
-		// We need to make the err returned from c.waitClose accurate.
-		return xerrors.Errorf("failed to read peer close frame for unknown reason")
+	if handshake {
+		// Try to wait for close frame peer but don't complain
+		// if one is not received since we already decided the
+		// close status of the connection above.
+		c.waitClose(ctx)
 	}
+
 	return nil
 }
 
 func (c *Conn) waitClose(ctx context.Context) error {
+	err := c.acquireLock(ctx, c.readFrameLock)
+	if err != nil {
+		return err
+	}
+	defer c.releaseLock(c.readFrameLock)
+
 	b := bpool.Get()
-	buf := b.Bytes()
-	buf = buf[:cap(buf)]
 	defer bpool.Put(b)
 
-	// Prevent reads from user code as we are going to be
-	// discarding all messages so they cannot rely on any ordering.
-	c.isCloseHandshake.Store(1)
-
-	// From this point forward, any reader we receive means we are
-	// now the sole readers of the connection and so it is safe
-	// to discard all payloads.
+	var h header
+	if c.activeReader != nil && !c.readerFrameEOF {
+		h = c.readerMsgHeader
+	}
 
 	for {
-		_, r, err := c.reader(ctx)
-		if err != nil {
-			return err
+		for h.payloadLength > 0 {
+			buf := b.Bytes()
+			if int64(cap(buf)) > h.payloadLength {
+				buf = buf[:h.payloadLength]
+			} else {
+				buf = buf[:cap(buf)]
+			}
+			n, err := c.readFramePayload(ctx, buf, false)
+			if err != nil {
+				return err
+			}
+			h.payloadLength -= int64(n)
 		}
 
-		// Discard all payloads.
-		_, err = io.CopyBuffer(ioutil.Discard, r, buf)
+		h, err = c.readTillMsg(ctx, false)
 		if err != nil {
 			return err
 		}
diff --git a/conn_export_test.go b/conn_export_test.go
index 32340b5..0fa3272 100644
--- a/conn_export_test.go
+++ b/conn_export_test.go
@@ -23,12 +23,12 @@ const (
 )
 
 func (c *Conn) ReadFrame(ctx context.Context) (OpCode, []byte, error) {
-	h, err := c.readFrameHeader(ctx)
+	h, err := c.readFrameHeader(ctx, true)
 	if err != nil {
 		return 0, nil, err
 	}
 	b := make([]byte, h.payloadLength)
-	_, err = c.readFramePayload(ctx, b)
+	_, err = c.readFramePayload(ctx, b, true)
 	if err != nil {
 		return 0, nil, err
 	}
diff --git a/go.mod b/go.mod
index 2a5bbae..0e39836 100644
--- a/go.mod
+++ b/go.mod
@@ -21,6 +21,5 @@ require (
 	golang.org/x/sys v0.0.0-20190927073244-c990c680b611 // indirect
 	golang.org/x/time v0.0.0-20190308202827-9d24e82272b4
 	golang.org/x/tools v0.0.0-20190920225731-5eefd052ad72
-	golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7
 	gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
 )
diff --git a/websocket_js.go b/websocket_js.go
index 27b8371..d11266d 100644
--- a/websocket_js.go
+++ b/websocket_js.go
@@ -23,11 +23,12 @@ type Conn struct {
 	// read limit for a message in bytes.
 	msgReadLimit *atomicInt64
 
-	isReadClosed *atomicInt64
-	closeOnce    sync.Once
-	closed       chan struct{}
-	closeErrOnce sync.Once
-	closeErr     error
+	isReadClosed  *atomicInt64
+	closeOnce     sync.Once
+	closed        chan struct{}
+	closeErrOnce  sync.Once
+	closeErr      error
+	closeWasClean bool
 
 	releaseOnClose   func()
 	releaseOnMessage func()
@@ -35,15 +36,14 @@ type Conn struct {
 	readSignal chan struct{}
 	readBufMu  sync.Mutex
 	readBuf    []wsjs.MessageEvent
-
-	closeEventCh chan wsjs.CloseEvent
 }
 
-func (c *Conn) close(err error) {
+func (c *Conn) close(err error, wasClean bool) {
 	c.closeOnce.Do(func() {
 		runtime.SetFinalizer(c, nil)
 
 		c.setCloseErr(err)
+		c.closeWasClean = wasClean
 		close(c.closed)
 	})
 }
@@ -57,18 +57,15 @@ func (c *Conn) init() {
 
 	c.isReadClosed = &atomicInt64{}
 
-	c.closeEventCh = make(chan wsjs.CloseEvent, 1)
-
 	c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) {
-		c.closeEventCh <- e
-		close(c.closeEventCh)
-
-		cerr := CloseError{
+		var err error = CloseError{
 			Code:   StatusCode(e.Code),
 			Reason: e.Reason,
 		}
-
-		c.close(fmt.Errorf("received close frame: %w", cerr))
+		if !e.WasClean {
+			err = fmt.Errorf("connection close was not clean: %w", err)
+		}
+		c.close(err, e.WasClean)
 
 		c.releaseOnClose()
 		c.releaseOnMessage()
@@ -209,32 +206,15 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error {
 		return fmt.Errorf("already closed: %w", c.closeErr)
 	}
 
-	cerr := CloseError{
-		Code:   code,
-		Reason: reason,
-	}
-	closeErr := fmt.Errorf("sent close frame: %v", cerr)
-	c.close(closeErr)
-	if !errors.Is(c.closeErr, closeErr) {
-		return c.closeErr
-	}
-
-	// We're the only goroutine allowed to get this far.
 	// The only possible error from closing the connection here
 	// is that the connection is already closed in which case,
-	// we do not really care.
+	// we do not really care since c.closed will immediately return.
 	c.ws.Close(int(code), reason)
 
-	// Guaranteed for this channel receive to succeed since the above
-	// if statement means we are the goroutine that closed this connection.
-	ev := <-c.closeEventCh
-	if !ev.WasClean {
-		return fmt.Errorf("unclean connection close: %v", CloseError{
-			Code:   StatusCode(ev.Code),
-			Reason: ev.Reason,
-		})
+	<-c.closed
+	if !c.closeWasClean {
+		return c.closeErr
 	}
-
 	return nil
 }
 
-- 
GitLab