good morning!!!!

Skip to content
Snippets Groups Projects
Unverified Commit 780bda41 authored by Anmol Sethi's avatar Anmol Sethi
Browse files

Fix race with c.readerShouldLock

Closes #168
parent e36318f9
No related branches found
No related tags found
No related merge requests found
......@@ -78,11 +78,10 @@ type Conn struct {
readLock chan struct{}
// messageReader state.
readerMsgCtx context.Context
readerMsgHeader header
readerFrameEOF bool
readerMaskPos int
readerShouldLock bool
readerMsgCtx context.Context
readerMsgHeader header
readerFrameEOF bool
readerMaskPos int
setReadTimeout chan context.Context
setWriteTimeout chan context.Context
......@@ -445,7 +444,6 @@ func (c *Conn) reader(ctx context.Context, lock bool) (MessageType, io.Reader, e
c.readerFrameEOF = false
c.readerMaskPos = 0
c.readMsgLeft = c.msgReadLimit.Load()
c.readerShouldLock = lock
r := &messageReader{
c: c,
......@@ -465,7 +463,11 @@ func (r *messageReader) eof() bool {
// Read reads as many bytes as possible into p.
func (r *messageReader) Read(p []byte) (int, error) {
n, err := r.read(p)
return r.exportedRead(p, true)
}
func (r *messageReader) exportedRead(p []byte, lock bool) (int, error) {
n, err := r.read(p, lock)
if err != nil {
// Have to return io.EOF directly for now, we cannot wrap as errors.Is
// isn't used widely yet.
......@@ -477,17 +479,29 @@ func (r *messageReader) Read(p []byte) (int, error) {
return n, nil
}
func (r *messageReader) read(p []byte) (int, error) {
if r.c.readerShouldLock {
err := r.c.acquireLock(r.c.readerMsgCtx, r.c.readLock)
if err != nil {
return 0, err
func (r *messageReader) readUnlocked(p []byte) (int, error) {
return r.exportedRead(p, false)
}
func (r *messageReader) read(p []byte, lock bool) (int, error) {
if lock {
// If we cannot acquire the read lock, then
// there is either a concurrent read or the close handshake
// is proceeding.
select {
case r.c.readLock <- struct{}{}:
defer r.c.releaseLock(r.c.readLock)
default:
if r.c.closing.Load() == 1 {
<-r.c.closed
return 0, r.c.closeErr
}
return 0, errors.New("concurrent read detected")
}
defer r.c.releaseLock(r.c.readLock)
}
if r.eof() {
return 0, fmt.Errorf("cannot use EOFed reader")
return 0, errors.New("cannot use EOFed reader")
}
if r.c.readMsgLeft <= 0 {
......@@ -950,8 +964,6 @@ func (c *Conn) waitClose() error {
return c.closeReceived
}
c.readerShouldLock = false
b := bpool.Get()
buf := b.Bytes()
buf = buf[:cap(buf)]
......@@ -965,7 +977,8 @@ func (c *Conn) waitClose() error {
}
}
_, err = io.CopyBuffer(ioutil.Discard, c.activeReader, buf)
r := readerFunc(c.activeReader.readUnlocked)
_, err = io.CopyBuffer(ioutil.Discard, r, buf)
if err != nil {
return err
}
......@@ -1019,6 +1032,12 @@ func (c *Conn) ping(ctx context.Context, p string) error {
}
}
type readerFunc func(p []byte) (int, error)
func (f readerFunc) Read(p []byte) (int, error) {
return f(p)
}
type writerFunc func(p []byte) (int, error)
func (f writerFunc) Write(p []byte) (int, error) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment