diff --git a/websocket.go b/websocket.go index 275af9da72d3be94438701401bb17e189148bbd0..912508d5635321679d8c2459a66035fac11ab9fb 100644 --- a/websocket.go +++ b/websocket.go @@ -61,8 +61,6 @@ type Conn struct { } func (c *Conn) close(err error) { - err = xerrors.Errorf("websocket closed: %w", err) - c.closeOnce.Do(func() { runtime.SetFinalizer(c, nil) @@ -71,7 +69,7 @@ func (c *Conn) close(err error) { cerr = err } - c.closeErr = cerr + c.closeErr = xerrors.Errorf("websocket closed: %w", cerr) close(c.closed) }) @@ -98,7 +96,7 @@ func (c *Conn) init() { c.readDone = make(chan int) runtime.SetFinalizer(c, func(c *Conn) { - c.Close(StatusInternalError, "connection garbage collected") + c.close(xerrors.New("connection garbage collected")) }) go c.writeLoop() @@ -238,7 +236,7 @@ func (c *Conn) handleControl(h header) { case opClose: ce, err := parseClosePayload(b) if err != nil { - c.close(xerrors.Errorf("read invalid close payload: %w", err)) + c.close(xerrors.Errorf("received invalid close payload: %w", err)) return } if ce.Code == StatusNoStatusRcvd { @@ -302,7 +300,7 @@ func (c *Conn) readLoop() { } } -func (c *Conn) dataReadLoop(h header) (err error) { +func (c *Conn) dataReadLoop(h header) error { maskPos := 0 left := h.payloadLength firstReadDone := false @@ -355,7 +353,6 @@ func (c *Conn) writePong(p []byte) error { // Close closes the WebSocket connection with the given status code and reason. // It will write a WebSocket close frame with a timeout of 5 seconds. -// Concurrent calls to Close are ok. func (c *Conn) Close(code StatusCode, reason string) error { err := c.exportedClose(code, reason) if err != nil { @@ -400,7 +397,7 @@ func (c *Conn) writeClose(p []byte, cerr CloseError) error { return err } - if cerr != c.closeErr { + if !xerrors.Is(c.closeErr, cerr) { return c.closeErr } @@ -420,9 +417,8 @@ func (c *Conn) writeSingleFrame(ctx context.Context, opcode opcode, p []byte) er payload: p, }: case <-ctx.Done(): - err := xerrors.Errorf("control frame write timed out: %w", ctx.Err()) - c.close(err) - return err + c.close(xerrors.Errorf("control frame write timed out: %w", ctx.Err())) + return ctx.Err() } select { @@ -487,7 +483,7 @@ func (w messageWriter) write(p []byte) (int, error) { select { case <-w.ctx.Done(): w.c.close(xerrors.Errorf("data write timed out: %w", w.ctx.Err())) - // Wait for writeLoop to complete so we know p is done. + // Wait for writeLoop to complete so we know p is done with. <-w.c.writeDone return 0, w.ctx.Err() case _, ok := <-w.c.writeDone: @@ -542,25 +538,21 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { } func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { - for !atomic.CompareAndSwapInt64(&c.activeReader, 0, 1) { - select { - case <-c.closed: - return 0, nil, c.closeErr - case c.readBytes <- nil: - select { - case <-ctx.Done(): - return 0, nil, ctx.Err() - case _, ok := <-c.readDone: - if !ok { - return 0, nil, c.closeErr - } - if atomic.LoadInt64(&c.activeReader) == 1 { - return 0, nil, xerrors.New("previous message not fully read") - } - } - case <-ctx.Done(): - return 0, nil, ctx.Err() + if !atomic.CompareAndSwapInt64(&c.activeReader, 0, 1) { + // If the next read yields io.EOF we are good to go. + r := messageReader{ + ctx: ctx, + c: c, } + _, err := r.Read(nil) + if err == nil { + return 0, nil, xerrors.New("previous message not fully read") + } + if !xerrors.Is(err, io.EOF) { + return 0, nil, xerrors.Errorf("failed to check if last message at io.EOF: %w", err) + } + + atomic.StoreInt64(&c.activeReader, 1) } select { @@ -586,7 +578,8 @@ type messageReader struct { func (r messageReader) Read(p []byte) (int, error) { n, err := r.read(p) if err != nil { - // Have to return io.EOF directly for now, cannot wrap. + // Have to return io.EOF directly for now, we cannot wrap as xerrors + // isn't used in stdlib. if err == io.EOF { return n, io.EOF }