good morning!!!!

Skip to content
Snippets Groups Projects
Unverified Commit d432e6ba authored by Anmol Sethi's avatar Anmol Sethi
Browse files

Implement complete close handshake

I changed my mind after #103 as browsers include a wasClean event to indicate
whether the connection was closed cleanly. From my tests, if a server using
this library prior to this commit initiates the close handshake, wasClean
will be false for the browser as the connection was closed before it could
respond with a close frame. Thus, I believe it's necessary to fully implement
the close handshake.

@stephenyama You'll enjoy this.
parent e795e467
No related branches found
No related tags found
No related merge requests found
......@@ -16,6 +16,10 @@ import (
"sync"
"sync/atomic"
"time"
"golang.org/x/xerrors"
"nhooyr.io/websocket/internal/bpool"
)
// Conn represents a WebSocket connection.
......@@ -62,6 +66,7 @@ type Conn struct {
writeMsgOpcode opcode
writeMsgCtx context.Context
readMsgLeft int64
readCloseFrame CloseError
// Used to ensure the previous reader is read till EOF before allowing
// a new one.
......@@ -69,6 +74,7 @@ type Conn struct {
// readFrameLock is acquired to read from bw.
readFrameLock chan struct{}
isReadClosed *atomicInt64
isCloseHandshake *atomicInt64
readHeaderBuf []byte
controlPayloadBuf []byte
......@@ -96,6 +102,7 @@ func (c *Conn) init() {
c.writeFrameLock = make(chan struct{}, 1)
c.readFrameLock = make(chan struct{}, 1)
c.isCloseHandshake = &atomicInt64{}
c.setReadTimeout = make(chan context.Context)
c.setWriteTimeout = make(chan context.Context)
......@@ -230,7 +237,7 @@ func (c *Conn) readTillMsg(ctx context.Context) (header, error) {
}
func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
err := c.acquireLock(context.Background(), c.readFrameLock)
err := c.acquireLock(ctx, c.readFrameLock)
if err != nil {
return header{}, err
}
......@@ -308,11 +315,22 @@ func (c *Conn) handleControl(ctx context.Context, h header) error {
c.Close(StatusProtocolError, err.Error())
return c.closeErr
}
// This ensures the closeErr of the Conn is always the received CloseError
// in case the echo close frame write fails.
// See https://github.com/nhooyr/websocket/issues/109
c.setCloseErr(fmt.Errorf("received close frame: %w", ce))
c.writeClose(b, nil)
c.readCloseFrame = ce
func() {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
c.writeControl(ctx, opClose, b)
}()
// We close with nil since the error is already set above.
c.close(nil)
return c.closeErr
default:
panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h))
......@@ -347,6 +365,15 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
return 0, nil, fmt.Errorf("websocket connection read closed")
}
if c.isCloseHandshake.Load() == 1 {
select {
case <-ctx.Done():
return 0, nil, fmt.Errorf("failed to get reader: %w", ctx.Err())
case <-c.closed:
return 0, nil, fmt.Errorf("failed to get reader: %w", c.closeErr)
}
}
typ, r, err := c.reader(ctx)
if err != nil {
return 0, nil, fmt.Errorf("failed to get reader: %w", err)
......@@ -772,27 +799,28 @@ func (c *Conn) writePong(p []byte) error {
// Close closes the WebSocket connection with the given status code and reason.
//
// It will write a WebSocket close frame with a timeout of 5 seconds.
// It will write a WebSocket close frame and then wait for the peer to respond
// with its own close frame. The entire process must complete within 10 seconds.
// Thus, it implements the full WebSocket close handshake.
//
// The connection can only be closed once. Additional calls to Close
// are no-ops.
//
// This does not perform a WebSocket close handshake.
// See https://github.com/nhooyr/websocket/issues/103 for details on why.
//
// The maximum length of reason must be 125 bytes otherwise an internal
// error will be sent to the peer. For this reason, you should avoid
// sending a dynamic reason.
//
// Close will unblock all goroutines interacting with the connection.
// Close will unblock all goroutines interacting with the connection once
// complete.
func (c *Conn) Close(code StatusCode, reason string) error {
err := c.exportedClose(code, reason)
err := c.closeHandshake(code, reason)
if err != nil {
return fmt.Errorf("failed to close websocket connection: %w", err)
}
return nil
}
func (c *Conn) exportedClose(code StatusCode, reason string) error {
func (c *Conn) closeHandshake(code StatusCode, reason string) error {
ce := CloseError{
Code: code,
Reason: reason,
......@@ -810,34 +838,64 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error {
p, _ = ce.bytes()
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
// Ensures the connection is closed if everything below succeeds.
// Up here because we must release the read lock first.
// nil because of the setCloseErr call below.
defer c.close(nil)
// CloseErrors sent are made opaque to prevent applications from thinking
// they received a given status.
sentErr := fmt.Errorf("sent close frame: %v", ce)
err = c.writeClose(p, sentErr)
// Other connections should only see this error.
c.setCloseErr(sentErr)
err = c.writeControl(ctx, opClose, p)
if err != nil {
return err
}
if !errors.Is(c.closeErr, sentErr) {
return c.closeErr
// Wait for close frame from peer.
err = c.waitClose(ctx)
// We didn't read a close frame.
if c.readCloseFrame == (CloseError{}) {
if ctx.Err() != nil {
return xerrors.Errorf("failed to wait for peer close frame: %w", ctx.Err())
}
// We need to make the err returned from c.waitClose accurate.
return xerrors.Errorf("failed to read peer close frame for unknown reason")
}
return nil
}
func (c *Conn) writeClose(p []byte, cerr error) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
func (c *Conn) waitClose(ctx context.Context) error {
b := bpool.Get()
buf := b.Bytes()
buf = buf[:cap(buf)]
defer bpool.Put(b)
// If this fails, the connection had to have died.
err := c.writeControl(ctx, opClose, p)
// Prevent reads from user code as we are going to be
// discarding all messages so they cannot rely on any ordering.
c.isCloseHandshake.Store(1)
// From this point forward, any reader we receive means we are
// now the sole readers of the connection and so it is safe
// to discard all payloads.
for {
_, r, err := c.reader(ctx)
if err != nil {
return err
}
c.close(cerr)
return nil
// Discard all payloads.
_, err = io.CopyBuffer(ioutil.Discard, r, buf)
if err != nil {
return err
}
}
}
// Ping sends a ping to the peer and waits for a pong.
......
......@@ -230,3 +230,7 @@ func (v *atomicInt64) String() string {
func (v *atomicInt64) Increment(delta int64) int64 {
return atomic.AddInt64(&v.v, delta)
}
func (v *atomicInt64) CAS(old, new int64) (swapped bool) {
return atomic.CompareAndSwapInt64(&v.v, old, new)
}
......@@ -856,6 +856,15 @@ func TestConn(t *testing.T) {
return nil
},
},
{
name: "closeHandshake",
server: func(ctx context.Context, c *websocket.Conn) error {
return c.Close(websocket.StatusNormalClosure, "")
},
client: func(ctx context.Context, c *websocket.Conn) error {
return c.Close(websocket.StatusNormalClosure, "")
},
},
}
for _, tc := range testCases {
tc := tc
......
......@@ -21,5 +21,6 @@ require (
golang.org/x/sys v0.0.0-20190927073244-c990c680b611 // indirect
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4
golang.org/x/tools v0.0.0-20190920225731-5eefd052ad72
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
)
......@@ -36,8 +36,7 @@ type Conn struct {
readBufMu sync.Mutex
readBuf []wsjs.MessageEvent
// Only used by tests
receivedCloseFrame chan struct{}
closeEventCh chan wsjs.CloseEvent
}
func (c *Conn) close(err error) {
......@@ -58,10 +57,11 @@ func (c *Conn) init() {
c.isReadClosed = &atomicInt64{}
c.receivedCloseFrame = make(chan struct{})
c.closeEventCh = make(chan wsjs.CloseEvent, 1)
c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) {
close(c.receivedCloseFrame)
c.closeEventCh <- e
close(c.closeEventCh)
cerr := CloseError{
Code: StatusCode(e.Code),
......@@ -193,24 +193,46 @@ func (c *Conn) isClosed() bool {
}
// Close closes the websocket with the given code and reason.
// It will wait until the peer responds with a close frame
// or the connection is closed.
// It thus performs the full WebSocket close handshake.
func (c *Conn) Close(code StatusCode, reason string) error {
err := c.exportedClose(code, reason)
if err != nil {
return fmt.Errorf("failed to close websocket: %w", err)
}
return nil
}
func (c *Conn) exportedClose(code StatusCode, reason string) error {
if c.isClosed() {
return fmt.Errorf("already closed: %w", c.closeErr)
}
err := fmt.Errorf("sent close frame: %v", CloseError{
cerr := CloseError{
Code: code,
Reason: reason,
})
err2 := c.ws.Close(int(code), reason)
if err2 != nil {
err = err2
}
c.close(err)
closeErr := fmt.Errorf("sent close frame: %v", cerr)
c.close(closeErr)
if !errors.Is(c.closeErr, closeErr) {
return c.closeErr
}
if !errors.Is(c.closeErr, err) {
return fmt.Errorf("failed to close websocket: %w", err)
// We're the only goroutine allowed to get this far.
// The only possible error from closing the connection here
// is that the connection is already closed in which case,
// we do not really care.
c.ws.Close(int(code), reason)
// Guaranteed for this channel receive to succeed since the above
// if statement means we are the goroutine that closed this connection.
ev := <-c.closeEventCh
if !ev.WasClean {
return fmt.Errorf("unclean connection close: %v", CloseError{
Code: StatusCode(ev.Code),
Reason: ev.Reason,
})
}
return nil
......
// +build js
package websocket
import (
"context"
"fmt"
)
func (c *Conn) WaitCloseFrame(ctx context.Context) error {
select {
case <-c.receivedCloseFrame:
return nil
case <-ctx.Done():
return fmt.Errorf("failed to wait for close frame: %w", ctx.Err())
}
}
......@@ -49,9 +49,4 @@ func TestConn(t *testing.T) {
if err != nil {
t.Fatal(err)
}
err = c.WaitCloseFrame(ctx)
if err != nil {
t.Fatal(err)
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment