good morning!!!!

Skip to content
Snippets Groups Projects
Unverified Commit f178ccfa authored by Anmol Sethi's avatar Anmol Sethi Committed by GitHub
Browse files

Merge pull request #169 from nhooyr/race

Fix race with c.readerShouldLock
parents e36318f9 8b47056a
Branches
Tags v1.7.3
No related merge requests found
......@@ -82,7 +82,6 @@ type Conn struct {
readerMsgHeader header
readerFrameEOF bool
readerMaskPos int
readerShouldLock bool
setReadTimeout chan context.Context
setWriteTimeout chan context.Context
......@@ -237,6 +236,10 @@ func (c *Conn) readTillMsg(ctx context.Context) (header, error) {
if h.opcode.controlOp() {
err = c.handleControl(ctx, h)
if err != nil {
// Pass through CloseErrors when receiving a close frame.
if h.opcode == opClose && CloseStatus(err) != -1 {
return header{}, err
}
return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err)
}
continue
......@@ -445,7 +448,6 @@ func (c *Conn) reader(ctx context.Context, lock bool) (MessageType, io.Reader, e
c.readerFrameEOF = false
c.readerMaskPos = 0
c.readMsgLeft = c.msgReadLimit.Load()
c.readerShouldLock = lock
r := &messageReader{
c: c,
......@@ -465,7 +467,11 @@ func (r *messageReader) eof() bool {
// Read reads as many bytes as possible into p.
func (r *messageReader) Read(p []byte) (int, error) {
n, err := r.read(p)
return r.exportedRead(p, true)
}
func (r *messageReader) exportedRead(p []byte, lock bool) (int, error) {
n, err := r.read(p, lock)
if err != nil {
// Have to return io.EOF directly for now, we cannot wrap as errors.Is
// isn't used widely yet.
......@@ -477,17 +483,29 @@ func (r *messageReader) Read(p []byte) (int, error) {
return n, nil
}
func (r *messageReader) read(p []byte) (int, error) {
if r.c.readerShouldLock {
err := r.c.acquireLock(r.c.readerMsgCtx, r.c.readLock)
if err != nil {
return 0, err
func (r *messageReader) readUnlocked(p []byte) (int, error) {
return r.exportedRead(p, false)
}
func (r *messageReader) read(p []byte, lock bool) (int, error) {
if lock {
// If we cannot acquire the read lock, then
// there is either a concurrent read or the close handshake
// is proceeding.
select {
case r.c.readLock <- struct{}{}:
defer r.c.releaseLock(r.c.readLock)
default:
if r.c.closing.Load() == 1 {
<-r.c.closed
return 0, r.c.closeErr
}
return 0, errors.New("concurrent read detected")
}
}
if r.eof() {
return 0, fmt.Errorf("cannot use EOFed reader")
return 0, errors.New("cannot use EOFed reader")
}
if r.c.readMsgLeft <= 0 {
......@@ -950,8 +968,6 @@ func (c *Conn) waitClose() error {
return c.closeReceived
}
c.readerShouldLock = false
b := bpool.Get()
buf := b.Bytes()
buf = buf[:cap(buf)]
......@@ -965,7 +981,8 @@ func (c *Conn) waitClose() error {
}
}
_, err = io.CopyBuffer(ioutil.Discard, c.activeReader, buf)
r := readerFunc(c.activeReader.readUnlocked)
_, err = io.CopyBuffer(ioutil.Discard, r, buf)
if err != nil {
return err
}
......@@ -1019,6 +1036,12 @@ func (c *Conn) ping(ctx context.Context, p string) error {
}
}
type readerFunc func(p []byte) (int, error)
func (f readerFunc) Read(p []byte) (int, error) {
return f(p)
}
type writerFunc func(p []byte) (int, error)
func (f writerFunc) Write(p []byte) (int, error) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment