diff --git a/conn.go b/conn.go index 43a94397a3caf85a6cf03e0db08af9e38b8e4a7e..861b2390a9de905278ca10a75f2bc22d54ffc2b4 100644 --- a/conn.go +++ b/conn.go @@ -851,6 +851,13 @@ func (c *Conn) realWriteFrame(ctx context.Context, h header, p []byte) (n int, e // complete. func (c *Conn) Close(code StatusCode, reason string) error { err := c.exportedClose(code, reason, true) + var ec errClosing + if errors.As(err, &ec) { + <-c.closed + // We wait until the connection closes. + // We use writeClose and not exportedClose to avoid a second failed to marshal close frame error. + err = c.writeClose(nil, ec.ce, true) + } if err != nil { return fmt.Errorf("failed to close websocket connection: %w", err) } @@ -878,15 +885,31 @@ func (c *Conn) exportedClose(code StatusCode, reason string, handshake bool) err return c.writeClose(p, fmt.Errorf("sent close: %w", ce), handshake) } +type errClosing struct { + ce error +} + +func (e errClosing) Error() string { + return "already closing connection" +} + func (c *Conn) writeClose(p []byte, ce error, handshake bool) error { - select { - case <-c.closed: - return fmt.Errorf("tried to close with %v but connection already closed: %w", ce, c.closeErr) - default: + if c.isClosed() { + return fmt.Errorf("tried to close with %q but connection already closed: %w", ce, c.closeErr) } if !c.closing.CAS(0, 1) { - return fmt.Errorf("another goroutine is closing") + // Normally, we would want to wait until the connection is closed, + // at least for when a user calls into Close, so we handle that case in + // the exported Close function. + // + // But for internal library usage, we always want to return early, e.g. + // if we are performing a close handshake and the peer sends their close frame, + // we do not want to block here waiting for c.closed to close because it won't, + // at least not until we return since the gorouine that will close it is this one. + return errClosing{ + ce: ce, + } } // No matter what happens next, close error should be set. diff --git a/conn_common.go b/conn_common.go index 9f0b045a1c54f94057ae2fe932c14c7ccd6dfdf6..1247df6e13517a1bd89a04e290bfa046019ea724 100644 --- a/conn_common.go +++ b/conn_common.go @@ -234,3 +234,12 @@ func (v *atomicInt64) Increment(delta int64) int64 { func (v *atomicInt64) CAS(old, new int64) (swapped bool) { return atomic.CompareAndSwapInt64(&v.v, old, new) } + +func (c *Conn) isClosed() bool { + select { + case <-c.closed: + return true + default: + return false + } +} diff --git a/conn_test.go b/conn_test.go index 1acdf5951ec38d8786447702a289a654aa7e44ec..8413c4c2afc2ae27df72612669d07452dbc6aff6 100644 --- a/conn_test.go +++ b/conn_test.go @@ -602,7 +602,11 @@ func TestConn(t *testing.T) { { name: "largeControlFrame", server: func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, true, websocket.OpClose, []byte(strings.Repeat("x", 4096))) + err := c.WriteHeader(ctx, websocket.Header{ + Fin: true, + OpCode: websocket.OpClose, + PayloadLength: 4096, + }) if err != nil { return err } diff --git a/websocket_js.go b/websocket_js.go index f297f9d4f6147ca15d3bb6c789315532da57f774..d27809cf71772f8d5720df08e8153de00d979639 100644 --- a/websocket_js.go +++ b/websocket_js.go @@ -23,7 +23,7 @@ type Conn struct { // read limit for a message in bytes. msgReadLimit *atomicInt64 - closeMu sync.Mutex + closingMu sync.Mutex isReadClosed *atomicInt64 closeOnce sync.Once closed chan struct{} @@ -43,6 +43,9 @@ func (c *Conn) close(err error, wasClean bool) { c.closeOnce.Do(func() { runtime.SetFinalizer(c, nil) + if !wasClean { + err = fmt.Errorf("unclean connection close: %w", err) + } c.setCloseErr(err) c.closeWasClean = wasClean close(c.closed) @@ -59,14 +62,11 @@ func (c *Conn) init() { c.isReadClosed = &atomicInt64{} c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) { - var err error = CloseError{ + err := CloseError{ Code: StatusCode(e.Code), Reason: e.Reason, } - if !e.WasClean { - err = fmt.Errorf("connection close was not clean: %w", err) - } - c.close(err, e.WasClean) + c.close(fmt.Errorf("received close: %w", err), e.WasClean) c.releaseOnClose() c.releaseOnMessage() @@ -182,15 +182,6 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error { } } -func (c *Conn) isClosed() bool { - select { - case <-c.closed: - return true - default: - return false - } -} - // Close closes the websocket with the given code and reason. // It will wait until the peer responds with a close frame // or the connection is closed. @@ -204,13 +195,19 @@ func (c *Conn) Close(code StatusCode, reason string) error { } func (c *Conn) exportedClose(code StatusCode, reason string) error { - c.closeMu.Lock() - defer c.closeMu.Unlock() + c.closingMu.Lock() + defer c.closingMu.Unlock() + + ce := fmt.Errorf("sent close: %w", CloseError{ + Code: code, + Reason: reason, + }) if c.isClosed() { - return fmt.Errorf("already closed: %w", c.closeErr) + return fmt.Errorf("tried to close with %q but connection already closed: %w", ce, c.closeErr) } + c.setCloseErr(ce) err := c.ws.Close(int(code), reason) if err != nil { return err