diff --git a/conn.go b/conn.go index bc115e38f4d8b328c3c6848c831cf3bdf93ebbb6..e12e1443027071feca94ee2220686a2b5fa7dab3 100644 --- a/conn.go +++ b/conn.go @@ -406,27 +406,6 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { return MessageType(h.opcode), r, nil } -// CloseRead will start a goroutine to read from the connection until it is closed or a data message -// is received. If a data message is received, the connection will be closed with StatusPolicyViolation. -// Since CloseRead reads from the connection, it will respond to ping, pong and close frames. -// After calling this method, you cannot read any data messages from the connection. -// The returned context will be cancelled when the connection is closed. -// -// Use this when you do not want to read data messages from the connection anymore but will -// want to write messages to it. -func (c *Conn) CloseRead(ctx context.Context) context.Context { - atomic.StoreInt64(&c.readClosed, 1) - - ctx, cancel := context.WithCancel(ctx) - go func() { - defer cancel() - // We use the unexported reader so that we don't get the read closed error. - c.reader(ctx) - c.Close(StatusPolicyViolation, "unexpected data message") - }() - return ctx -} - // messageReader enables reading a data frame from the WebSocket connection. type messageReader struct { c *Conn diff --git a/doc.go b/doc.go index da6f32227207318d0ffa0aaa8cd3677d144a386f..2a5a0a1ab27f82884baf911515407a5f07338ad7 100644 --- a/doc.go +++ b/doc.go @@ -25,10 +25,10 @@ // // See https://developer.mozilla.org/en-US/docs/Web/API/WebSocket // -// Thus the unsupported features when compiling to WASM are: +// Thus the unsupported features (not compiled in) for WASM are: // - Accept and AcceptOptions -// - Conn's Reader, Writer, SetReadLimit, Ping methods -// - HTTPClient and HTTPHeader dial options +// - Conn's Reader, Writer, SetReadLimit and Ping methods +// - HTTPClient and HTTPHeader fields in DialOptions // // The *http.Response returned by Dial will always either be nil or &http.Response{} as // we do not have access to the handshake response in the browser. diff --git a/netconn.go b/netconn.go index 8efdade22d9b172c29dfe6053cb7c7e292b5a32d..34f771c6d1a84429822aa7eb72569cc32a6ed120 100644 --- a/netconn.go +++ b/netconn.go @@ -7,6 +7,7 @@ import ( "io" "math" "net" + "sync/atomic" "time" ) @@ -159,3 +160,27 @@ func (c *netConn) SetReadDeadline(t time.Time) error { } return nil } + +// CloseRead will start a goroutine to read from the connection until it is closed or a data message +// is received. If a data message is received, the connection will be closed with StatusPolicyViolation. +// Since CloseRead reads from the connection, it will respond to ping, pong and close frames. +// After calling this method, you cannot read any data messages from the connection. +// The returned context will be cancelled when the connection is closed. +// +// Use this when you do not want to read data messages from the connection anymore but will +// want to write messages to it. +func (c *Conn) CloseRead(ctx context.Context) context.Context { + atomic.StoreInt64(&c.readClosed, 1) + + ctx, cancel := context.WithCancel(ctx) + go func() { + defer cancel() + // We use the unexported reader method so that we don't get the read closed error. + c.reader(ctx) + // Either the connection is already closed since there was a read error + // or the context was cancelled or a message was read and we should close + // the connection. + c.Close(StatusPolicyViolation, "unexpected data message") + }() + return ctx +} diff --git a/websocket_js.go b/websocket_js.go index 14f198d15364d0d2046d360cabf5b63c7f478f6a..123bc8f4070f1459130369da32b0c7bd240eada4 100644 --- a/websocket_js.go +++ b/websocket_js.go @@ -10,6 +10,7 @@ import ( "reflect" "runtime" "sync" + "sync/atomic" "syscall/js" "nhooyr.io/websocket/internal/wsjs" @@ -19,9 +20,10 @@ import ( type Conn struct { ws wsjs.WebSocket - closeOnce sync.Once - closed chan struct{} - closeErr error + readClosed int64 + closeOnce sync.Once + closed chan struct{} + closeErr error releaseOnClose func() releaseOnMessage func() @@ -67,6 +69,10 @@ func (c *Conn) init() { // Read attempts to read a message from the connection. // The maximum time spent waiting is bounded by the context. func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { + if atomic.LoadInt64(&c.readClosed) == 1 { + return 0, nil, fmt.Errorf("websocket connection read closed") + } + typ, p, err := c.read(ctx) if err != nil { return 0, nil, fmt.Errorf("failed to read: %w", err) @@ -78,6 +84,7 @@ func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) { var me wsjs.MessageEvent select { case <-ctx.Done(): + c.Close(StatusPolicyViolation, "read timed out") return 0, nil, ctx.Err() case me = <-c.readch: case <-c.closed: @@ -198,6 +205,7 @@ func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Resp select { case <-ctx.Done(): + c.Close(StatusPolicyViolation, "dial timed out") return nil, nil, ctx.Err() case <-opench: case <-c.closed: @@ -215,3 +223,8 @@ func (c *netConn) netConnReader(ctx context.Context) (MessageType, io.Reader, er } return typ, bytes.NewReader(p), nil } + +// Only implemented for use by *Conn.CloseRead in netconn.go +func (c *Conn) reader(ctx context.Context) { + c.read(ctx) +}