diff --git a/websocket.go b/websocket.go index 7d33a36a6a0dc4adecef1d2ee4ec7229143ce408..6480aed66ff1b1b8c39da1cf034b4d29df8fa101 100644 --- a/websocket.go +++ b/websocket.go @@ -38,10 +38,10 @@ type Conn struct { writeDataLock chan struct{} writeFrameLock chan struct{} - readDataLock chan struct{} - readData chan header - readDone chan struct{} - readLoopDone chan struct{} + readMsgLock chan struct{} + readMsg chan header + readMsgDone chan struct{} + readFrameLock chan struct{} setReadTimeout chan context.Context setWriteTimeout chan context.Context @@ -90,17 +90,15 @@ func (c *Conn) close(err error) { close(c.closed) + // This ensures every goroutine that interacts + // with the conn closes before it can interact with the connection + c.readFrameLock <- struct{}{} + c.writeFrameLock <- struct{}{} + // See comment in dial.go if c.client { - go func() { - <-c.readLoopDone - // TODO this does not work if reader errors out so skip for now. - // c.readDataLock <- struct{}{} - // c.writeFrameLock <- struct{}{} - // - // returnBufioReader(c.br) - // returnBufioWriter(c.bw) - }() + returnBufioReader(c.br) + returnBufioWriter(c.bw) } }) } @@ -119,10 +117,10 @@ func (c *Conn) init() { c.writeDataLock = make(chan struct{}, 1) c.writeFrameLock = make(chan struct{}, 1) - c.readData = make(chan header) - c.readDone = make(chan struct{}) - c.readDataLock = make(chan struct{}, 1) - c.readLoopDone = make(chan struct{}) + c.readMsg = make(chan header) + c.readMsgDone = make(chan struct{}) + c.readMsgLock = make(chan struct{}, 1) + c.readFrameLock = make(chan struct{}, 1) c.setReadTimeout = make(chan context.Context) c.setWriteTimeout = make(chan context.Context) @@ -141,8 +139,8 @@ func (c *Conn) init() { // We never mask inside here because our mask key is always 0,0,0,0. // See comment on secWebSocketKey. -func (c *Conn) writeFrame(ctx context.Context, h header, p []byte) error { - err := c.acquireLock(ctx, c.writeFrameLock) +func (c *Conn) writeFrame(ctx context.Context, h header, p []byte) (err error) { + err = c.acquireLock(ctx, c.writeFrameLock) if err != nil { return err } @@ -164,27 +162,33 @@ func (c *Conn) writeFrame(ctx context.Context, h header, p []byte) error { } }() + defer func() { + if err != nil { + // We need to always release the lock first before closing the connection to ensure + // the lock can be acquired inside close. + c.releaseLock(c.writeFrameLock) + c.close(err) + } + }() + h.masked = c.client h.payloadLength = int64(len(p)) b2 := marshalHeader(h) _, err = c.bw.Write(b2) if err != nil { - c.close(xerrors.Errorf("failed to write to connection: %w", err)) - return c.closeErr + return xerrors.Errorf("failed to write to connection: %w", err) } _, err = c.bw.Write(p) if err != nil { - c.close(xerrors.Errorf("failed to write to connection: %w", err)) - return c.closeErr + return xerrors.Errorf("failed to write to connection: %w", err) } if h.fin { err := c.bw.Flush() if err != nil { - c.close(xerrors.Errorf("failed to write to connection: %w", err)) - return c.closeErr + return xerrors.Errorf("failed to write to connection: %w", err) } } @@ -279,9 +283,9 @@ func (c *Conn) handleControl(h header) { func (c *Conn) readTillData() (header, error) { for { - h, err := readHeader(c.br) + h, err := c.readHeader() if err != nil { - return header{}, xerrors.Errorf("failed to read header: %w", err) + return header{}, err } if h.rsv1 || h.rsv2 || h.rsv3 { @@ -312,9 +316,22 @@ func (c *Conn) readTillData() (header, error) { } } -func (c *Conn) readLoop() { - defer close(c.readLoopDone) +func (c *Conn) readHeader() (header, error) { + err := c.acquireLock(context.Background(), c.readFrameLock) + if err != nil { + return header{}, err + } + defer c.releaseLock(c.readFrameLock) + h, err := readHeader(c.br) + if err != nil { + return header{}, xerrors.Errorf("failed to read header: %w", err) + } + + return h, nil +} + +func (c *Conn) readLoop() { for { h, err := c.readTillData() if err != nil { @@ -325,13 +342,13 @@ func (c *Conn) readLoop() { select { case <-c.closed: return - case c.readData <- h: + case c.readMsg <- h: } select { case <-c.closed: return - case <-c.readDone: + case <-c.readMsgDone: } } } @@ -374,7 +391,7 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error { // Definitely worth seeing what popular browsers do later. p, err := ce.bytes() if err != nil { - fmt.Fprintf(os.Stderr, "failed to marshal close frame: %v\n", err) + fmt.Fprintf(os.Stderr, "websocket: failed to marshal close frame: %v\n", err) ce = CloseError{ Code: StatusInternalError, } @@ -415,7 +432,11 @@ func (c *Conn) acquireLock(ctx context.Context, lock chan struct{}) error { } func (c *Conn) releaseLock(lock chan struct{}) { - <-lock + // Allow multiple releases. + select { + case <-lock: + default: + } } func (c *Conn) writeMessage(ctx context.Context, opcode opcode, p []byte) error { @@ -572,7 +593,7 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { } func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) { - err = c.acquireLock(ctx, c.readDataLock) + err = c.acquireLock(ctx, c.readMsgLock) if err != nil { return 0, nil, err } @@ -582,7 +603,7 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro return 0, nil, c.closeErr case <-ctx.Done(): return 0, nil, ctx.Err() - case h := <-c.readData: + case h := <-c.readMsg: if h.opcode == opContinuation { ce := CloseError{ Code: StatusProtocolError, @@ -631,7 +652,7 @@ func (r *messageReader) read(p []byte) (int, error) { select { case <-r.c.closed: return 0, r.c.closeErr - case h := <-r.c.readData: + case h := <-r.c.readMsg: if h.opcode != opContinuation { ce := CloseError{ Code: StatusProtocolError, @@ -654,7 +675,12 @@ func (r *messageReader) read(p []byte) (int, error) { case r.c.setReadTimeout <- r.ctx: } + err := r.c.acquireLock(r.ctx, r.c.readFrameLock) + if err != nil { + return 0, err + } n, err := io.ReadFull(r.c.br, p) + r.c.releaseLock(r.c.readFrameLock) select { case <-r.c.closed: @@ -676,11 +702,11 @@ func (r *messageReader) read(p []byte) (int, error) { select { case <-r.c.closed: return n, r.c.closeErr - case r.c.readDone <- struct{}{}: + case r.c.readMsgDone <- struct{}{}: } if r.h.fin { r.eofed = true - r.c.releaseLock(r.c.readDataLock) + r.c.releaseLock(r.c.readMsgLock) return n, io.EOF } r.maskPos = 0