From 780bda4159cd001ed4e1704327c1292a1d21336d Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Mon, 4 Nov 2019 18:50:29 -0500 Subject: [PATCH] Fix race with c.readerShouldLock Closes #168 --- conn.go | 53 ++++++++++++++++++++++++++++++++++++----------------- 1 file changed, 36 insertions(+), 17 deletions(-) diff --git a/conn.go b/conn.go index cbb7fa5..7d48b8a 100644 --- a/conn.go +++ b/conn.go @@ -78,11 +78,10 @@ type Conn struct { readLock chan struct{} // messageReader state. - readerMsgCtx context.Context - readerMsgHeader header - readerFrameEOF bool - readerMaskPos int - readerShouldLock bool + readerMsgCtx context.Context + readerMsgHeader header + readerFrameEOF bool + readerMaskPos int setReadTimeout chan context.Context setWriteTimeout chan context.Context @@ -445,7 +444,6 @@ func (c *Conn) reader(ctx context.Context, lock bool) (MessageType, io.Reader, e c.readerFrameEOF = false c.readerMaskPos = 0 c.readMsgLeft = c.msgReadLimit.Load() - c.readerShouldLock = lock r := &messageReader{ c: c, @@ -465,7 +463,11 @@ func (r *messageReader) eof() bool { // Read reads as many bytes as possible into p. func (r *messageReader) Read(p []byte) (int, error) { - n, err := r.read(p) + return r.exportedRead(p, true) +} + +func (r *messageReader) exportedRead(p []byte, lock bool) (int, error) { + n, err := r.read(p, lock) if err != nil { // Have to return io.EOF directly for now, we cannot wrap as errors.Is // isn't used widely yet. @@ -477,17 +479,29 @@ func (r *messageReader) Read(p []byte) (int, error) { return n, nil } -func (r *messageReader) read(p []byte) (int, error) { - if r.c.readerShouldLock { - err := r.c.acquireLock(r.c.readerMsgCtx, r.c.readLock) - if err != nil { - return 0, err +func (r *messageReader) readUnlocked(p []byte) (int, error) { + return r.exportedRead(p, false) +} + +func (r *messageReader) read(p []byte, lock bool) (int, error) { + if lock { + // If we cannot acquire the read lock, then + // there is either a concurrent read or the close handshake + // is proceeding. + select { + case r.c.readLock <- struct{}{}: + defer r.c.releaseLock(r.c.readLock) + default: + if r.c.closing.Load() == 1 { + <-r.c.closed + return 0, r.c.closeErr + } + return 0, errors.New("concurrent read detected") } - defer r.c.releaseLock(r.c.readLock) } if r.eof() { - return 0, fmt.Errorf("cannot use EOFed reader") + return 0, errors.New("cannot use EOFed reader") } if r.c.readMsgLeft <= 0 { @@ -950,8 +964,6 @@ func (c *Conn) waitClose() error { return c.closeReceived } - c.readerShouldLock = false - b := bpool.Get() buf := b.Bytes() buf = buf[:cap(buf)] @@ -965,7 +977,8 @@ func (c *Conn) waitClose() error { } } - _, err = io.CopyBuffer(ioutil.Discard, c.activeReader, buf) + r := readerFunc(c.activeReader.readUnlocked) + _, err = io.CopyBuffer(ioutil.Discard, r, buf) if err != nil { return err } @@ -1019,6 +1032,12 @@ func (c *Conn) ping(ctx context.Context, p string) error { } } +type readerFunc func(p []byte) (int, error) + +func (f readerFunc) Read(p []byte) (int, error) { + return f(p) +} + type writerFunc func(p []byte) (int, error) func (f writerFunc) Write(p []byte) (int, error) { -- GitLab