From 224ef23799cb71fd8fabc27a20503e978a18e048 Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Mon, 7 Oct 2019 01:20:55 -0400 Subject: [PATCH] Cleanup close handshake implementation --- conn.go | 157 ++++++++++++++++++++------------------------ conn_export_test.go | 4 +- go.mod | 1 - websocket_js.go | 54 +++++---------- 4 files changed, 92 insertions(+), 124 deletions(-) diff --git a/conn.go b/conn.go index 73d6490..b162a42 100644 --- a/conn.go +++ b/conn.go @@ -17,8 +17,6 @@ import ( "sync/atomic" "time" - "golang.org/x/xerrors" - "nhooyr.io/websocket/internal/bpool" ) @@ -66,7 +64,6 @@ type Conn struct { writeMsgOpcode opcode writeMsgCtx context.Context readMsgLeft int64 - readCloseFrame CloseError // Used to ensure the previous reader is read till EOF before allowing // a new one. @@ -74,7 +71,6 @@ type Conn struct { // readFrameLock is acquired to read from bw. readFrameLock chan struct{} isReadClosed *atomicInt64 - isCloseHandshake *atomicInt64 readHeaderBuf []byte controlPayloadBuf []byte @@ -102,7 +98,6 @@ func (c *Conn) init() { c.writeFrameLock = make(chan struct{}, 1) c.readFrameLock = make(chan struct{}, 1) - c.isCloseHandshake = &atomicInt64{} c.setReadTimeout = make(chan context.Context) c.setWriteTimeout = make(chan context.Context) @@ -206,20 +201,20 @@ func (c *Conn) releaseLock(lock chan struct{}) { } } -func (c *Conn) readTillMsg(ctx context.Context) (header, error) { +func (c *Conn) readTillMsg(ctx context.Context, lock bool) (header, error) { for { - h, err := c.readFrameHeader(ctx) + h, err := c.readFrameHeader(ctx, lock) if err != nil { return header{}, err } if h.rsv1 || h.rsv2 || h.rsv3 { - c.Close(StatusProtocolError, fmt.Sprintf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)) + c.writeClose(StatusProtocolError, fmt.Sprintf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3), false) return header{}, c.closeErr } if h.opcode.controlOp() { - err = c.handleControl(ctx, h) + err = c.handleControl(ctx, h, lock) if err != nil { return header{}, fmt.Errorf("failed to handle control frame: %w", err) } @@ -230,18 +225,20 @@ func (c *Conn) readTillMsg(ctx context.Context) (header, error) { case opBinary, opText, opContinuation: return h, nil default: - c.Close(StatusProtocolError, fmt.Sprintf("received unknown opcode %v", h.opcode)) + c.writeClose(StatusProtocolError, fmt.Sprintf("received unknown opcode %v", h.opcode), false) return header{}, c.closeErr } } } -func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { - err := c.acquireLock(ctx, c.readFrameLock) - if err != nil { - return header{}, err +func (c *Conn) readFrameHeader(ctx context.Context, lock bool) (header, error) { + if lock { + err := c.acquireLock(ctx, c.readFrameLock) + if err != nil { + return header{}, err + } + defer c.releaseLock(c.readFrameLock) } - defer c.releaseLock(c.readFrameLock) select { case <-c.closed: @@ -273,14 +270,14 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { return h, nil } -func (c *Conn) handleControl(ctx context.Context, h header) error { +func (c *Conn) handleControl(ctx context.Context, h header, lock bool) error { if h.payloadLength > maxControlFramePayload { - c.Close(StatusProtocolError, fmt.Sprintf("control frame too large at %v bytes", h.payloadLength)) + c.writeClose(StatusProtocolError, fmt.Sprintf("control frame too large at %v bytes", h.payloadLength), false) return c.closeErr } if !h.fin { - c.Close(StatusProtocolError, "received fragmented control frame") + c.writeClose(StatusProtocolError, "received fragmented control frame", false) return c.closeErr } @@ -288,7 +285,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) error { defer cancel() b := c.controlPayloadBuf[:h.payloadLength] - _, err := c.readFramePayload(ctx, b) + _, err := c.readFramePayload(ctx, b, lock) if err != nil { return err } @@ -312,16 +309,14 @@ func (c *Conn) handleControl(ctx context.Context, h header) error { ce, err := parseClosePayload(b) if err != nil { err = fmt.Errorf("received invalid close payload: %w", err) - c.Close(StatusProtocolError, err.Error()) + c.writeClose(StatusProtocolError, err.Error(), false) return c.closeErr } // This ensures the closeErr of the Conn is always the received CloseError // in case the echo close frame write fails. // See https://github.com/nhooyr/websocket/issues/109 - c.setCloseErr(fmt.Errorf("received close frame: %w", ce)) - - c.readCloseFrame = ce + c.setCloseErr(ce) func() { ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) @@ -329,6 +324,9 @@ func (c *Conn) handleControl(ctx context.Context, h header) error { c.writeControl(ctx, opClose, b) }() + if !lock { + c.releaseLock(c.readFrameLock) + } // We close with nil since the error is already set above. c.close(nil) return c.closeErr @@ -362,16 +360,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) error { // Most users should not need this. func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { if c.isReadClosed.Load() == 1 { - return 0, nil, fmt.Errorf("websocket connection read closed") - } - - if c.isCloseHandshake.Load() == 1 { - select { - case <-ctx.Done(): - return 0, nil, fmt.Errorf("failed to get reader: %w", ctx.Err()) - case <-c.closed: - return 0, nil, fmt.Errorf("failed to get reader: %w", c.closeErr) - } + return 0, nil, errors.New("websocket connection read closed") } typ, r, err := c.reader(ctx) @@ -381,23 +370,23 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { return typ, r, nil } -func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { +func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) { if c.activeReader != nil && !c.readerFrameEOF { // The only way we know for sure the previous reader is not yet complete is // if there is an active frame not yet fully read. // Otherwise, a user may have read the last byte but not the EOF if the EOF // is in the next frame so we check for that below. - return 0, nil, fmt.Errorf("previous message not read to completion") + return 0, nil, errors.New("previous message not read to completion") } - h, err := c.readTillMsg(ctx) + h, err := c.readTillMsg(ctx, true) if err != nil { return 0, nil, err } if c.activeReader != nil && !c.activeReader.eof() { if h.opcode != opContinuation { - c.Close(StatusProtocolError, "received new data message without finishing the previous message") + c.writeClose(StatusProtocolError, "received new data message without finishing the previous message", false) return 0, nil, c.closeErr } @@ -407,12 +396,12 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { c.activeReader = nil - h, err = c.readTillMsg(ctx) + h, err = c.readTillMsg(ctx, true) if err != nil { return 0, nil, err } } else if h.opcode == opContinuation { - c.Close(StatusProtocolError, "received continuation frame not after data or text frame") + c.writeClose(StatusProtocolError, "received continuation frame not after data or text frame", false) return 0, nil, c.closeErr } @@ -458,7 +447,7 @@ func (r *messageReader) read(p []byte) (int, error) { } if r.c.readMsgLeft <= 0 { - r.c.Close(StatusMessageTooBig, fmt.Sprintf("read limited at %v bytes", r.c.msgReadLimit)) + r.c.writeClose(StatusMessageTooBig, fmt.Sprintf("read limited at %v bytes", r.c.msgReadLimit), false) return 0, r.c.closeErr } @@ -467,13 +456,13 @@ func (r *messageReader) read(p []byte) (int, error) { } if r.c.readerFrameEOF { - h, err := r.c.readTillMsg(r.c.readerMsgCtx) + h, err := r.c.readTillMsg(r.c.readerMsgCtx, true) if err != nil { return 0, err } if h.opcode != opContinuation { - r.c.Close(StatusProtocolError, "received new data message without finishing the previous message") + r.c.writeClose(StatusProtocolError, "received new data message without finishing the previous message", false) return 0, r.c.closeErr } @@ -487,7 +476,7 @@ func (r *messageReader) read(p []byte) (int, error) { p = p[:h.payloadLength] } - n, err := r.c.readFramePayload(r.c.readerMsgCtx, p) + n, err := r.c.readFramePayload(r.c.readerMsgCtx, p, true) h.payloadLength -= int64(n) r.c.readMsgLeft -= int64(n) @@ -512,12 +501,14 @@ func (r *messageReader) read(p []byte) (int, error) { return n, nil } -func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { - err := c.acquireLock(ctx, c.readFrameLock) - if err != nil { - return 0, err +func (c *Conn) readFramePayload(ctx context.Context, p []byte, lock bool) (int, error) { + if lock { + err := c.acquireLock(ctx, c.readFrameLock) + if err != nil { + return 0, err + } + defer c.releaseLock(c.readFrameLock) } - defer c.releaseLock(c.readFrameLock) select { case <-c.closed: @@ -813,14 +804,14 @@ func (c *Conn) writePong(p []byte) error { // Close will unblock all goroutines interacting with the connection once // complete. func (c *Conn) Close(code StatusCode, reason string) error { - err := c.closeHandshake(code, reason) + err := c.writeClose(code, reason, true) if err != nil { return fmt.Errorf("failed to close websocket connection: %w", err) } return nil } -func (c *Conn) closeHandshake(code StatusCode, reason string) error { +func (c *Conn) writeClose(code StatusCode, reason string, handshake bool) error { ce := CloseError{ Code: code, Reason: reason, @@ -838,60 +829,58 @@ func (c *Conn) closeHandshake(code StatusCode, reason string) error { p, _ = ce.bytes() } + // Give the handshake 10 seconds. ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() - // Ensures the connection is closed if everything below succeeds. - // Up here because we must release the read lock first. - // nil because of the setCloseErr call below. - defer c.close(nil) - - // CloseErrors sent are made opaque to prevent applications from thinking - // they received a given status. - sentErr := fmt.Errorf("sent close frame: %v", ce) - // Other connections should only see this error. - c.setCloseErr(sentErr) - err = c.writeControl(ctx, opClose, p) if err != nil { return err } + c.setCloseErr(ce) + defer c.close(nil) - // Wait for close frame from peer. - err = c.waitClose(ctx) - // We didn't read a close frame. - if c.readCloseFrame == (CloseError{}) { - if ctx.Err() != nil { - return xerrors.Errorf("failed to wait for peer close frame: %w", ctx.Err()) - } - // We need to make the err returned from c.waitClose accurate. - return xerrors.Errorf("failed to read peer close frame for unknown reason") + if handshake { + // Try to wait for close frame peer but don't complain + // if one is not received since we already decided the + // close status of the connection above. + c.waitClose(ctx) } + return nil } func (c *Conn) waitClose(ctx context.Context) error { + err := c.acquireLock(ctx, c.readFrameLock) + if err != nil { + return err + } + defer c.releaseLock(c.readFrameLock) + b := bpool.Get() - buf := b.Bytes() - buf = buf[:cap(buf)] defer bpool.Put(b) - // Prevent reads from user code as we are going to be - // discarding all messages so they cannot rely on any ordering. - c.isCloseHandshake.Store(1) - - // From this point forward, any reader we receive means we are - // now the sole readers of the connection and so it is safe - // to discard all payloads. + var h header + if c.activeReader != nil && !c.readerFrameEOF { + h = c.readerMsgHeader + } for { - _, r, err := c.reader(ctx) - if err != nil { - return err + for h.payloadLength > 0 { + buf := b.Bytes() + if int64(cap(buf)) > h.payloadLength { + buf = buf[:h.payloadLength] + } else { + buf = buf[:cap(buf)] + } + n, err := c.readFramePayload(ctx, buf, false) + if err != nil { + return err + } + h.payloadLength -= int64(n) } - // Discard all payloads. - _, err = io.CopyBuffer(ioutil.Discard, r, buf) + h, err = c.readTillMsg(ctx, false) if err != nil { return err } diff --git a/conn_export_test.go b/conn_export_test.go index 32340b5..0fa3272 100644 --- a/conn_export_test.go +++ b/conn_export_test.go @@ -23,12 +23,12 @@ const ( ) func (c *Conn) ReadFrame(ctx context.Context) (OpCode, []byte, error) { - h, err := c.readFrameHeader(ctx) + h, err := c.readFrameHeader(ctx, true) if err != nil { return 0, nil, err } b := make([]byte, h.payloadLength) - _, err = c.readFramePayload(ctx, b) + _, err = c.readFramePayload(ctx, b, true) if err != nil { return 0, nil, err } diff --git a/go.mod b/go.mod index 2a5bbae..0e39836 100644 --- a/go.mod +++ b/go.mod @@ -21,6 +21,5 @@ require ( golang.org/x/sys v0.0.0-20190927073244-c990c680b611 // indirect golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 golang.org/x/tools v0.0.0-20190920225731-5eefd052ad72 - golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect ) diff --git a/websocket_js.go b/websocket_js.go index 27b8371..d11266d 100644 --- a/websocket_js.go +++ b/websocket_js.go @@ -23,11 +23,12 @@ type Conn struct { // read limit for a message in bytes. msgReadLimit *atomicInt64 - isReadClosed *atomicInt64 - closeOnce sync.Once - closed chan struct{} - closeErrOnce sync.Once - closeErr error + isReadClosed *atomicInt64 + closeOnce sync.Once + closed chan struct{} + closeErrOnce sync.Once + closeErr error + closeWasClean bool releaseOnClose func() releaseOnMessage func() @@ -35,15 +36,14 @@ type Conn struct { readSignal chan struct{} readBufMu sync.Mutex readBuf []wsjs.MessageEvent - - closeEventCh chan wsjs.CloseEvent } -func (c *Conn) close(err error) { +func (c *Conn) close(err error, wasClean bool) { c.closeOnce.Do(func() { runtime.SetFinalizer(c, nil) c.setCloseErr(err) + c.closeWasClean = wasClean close(c.closed) }) } @@ -57,18 +57,15 @@ func (c *Conn) init() { c.isReadClosed = &atomicInt64{} - c.closeEventCh = make(chan wsjs.CloseEvent, 1) - c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) { - c.closeEventCh <- e - close(c.closeEventCh) - - cerr := CloseError{ + var err error = CloseError{ Code: StatusCode(e.Code), Reason: e.Reason, } - - c.close(fmt.Errorf("received close frame: %w", cerr)) + if !e.WasClean { + err = fmt.Errorf("connection close was not clean: %w", err) + } + c.close(err, e.WasClean) c.releaseOnClose() c.releaseOnMessage() @@ -209,32 +206,15 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error { return fmt.Errorf("already closed: %w", c.closeErr) } - cerr := CloseError{ - Code: code, - Reason: reason, - } - closeErr := fmt.Errorf("sent close frame: %v", cerr) - c.close(closeErr) - if !errors.Is(c.closeErr, closeErr) { - return c.closeErr - } - - // We're the only goroutine allowed to get this far. // The only possible error from closing the connection here // is that the connection is already closed in which case, - // we do not really care. + // we do not really care since c.closed will immediately return. c.ws.Close(int(code), reason) - // Guaranteed for this channel receive to succeed since the above - // if statement means we are the goroutine that closed this connection. - ev := <-c.closeEventCh - if !ev.WasClean { - return fmt.Errorf("unclean connection close: %v", CloseError{ - Code: StatusCode(ev.Code), - Reason: ev.Reason, - }) + <-c.closed + if !c.closeWasClean { + return c.closeErr } - return nil } -- GitLab