From 4f014d23d63b60a6ed127590482938c65181405d Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Wed, 23 Oct 2019 09:53:41 -0400 Subject: [PATCH] Fix concurrent read with close Closes #164 --- conn.go | 19 ++++++++++++++----- conn_test.go | 23 +++++++++++++++++++++++ 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/conn.go b/conn.go index 861b239..cbb7fa5 100644 --- a/conn.go +++ b/conn.go @@ -42,11 +42,12 @@ type Conn struct { closer io.Closer client bool - closeOnce sync.Once - closeErrOnce sync.Once - closeErr error - closed chan struct{} - closing *atomicInt64 + closeOnce sync.Once + closeErrOnce sync.Once + closeErr error + closed chan struct{} + closing *atomicInt64 + closeReceived error // messageWriter state. // writeMsgLock is acquired to write a data message. @@ -339,10 +340,12 @@ func (c *Conn) handleControl(ctx context.Context, h header) error { if err != nil { err = fmt.Errorf("received invalid close payload: %w", err) c.exportedClose(StatusProtocolError, err.Error(), false) + c.closeReceived = err return err } err = fmt.Errorf("received close: %w", ce) + c.closeReceived = err c.writeClose(b, err, false) if ctx.Err() != nil { @@ -941,6 +944,12 @@ func (c *Conn) waitClose() error { return err } defer c.releaseLock(c.readLock) + + if c.closeReceived != nil { + // goroutine reading just received the close. + return c.closeReceived + } + c.readerShouldLock = false b := bpool.Get() diff --git a/conn_test.go b/conn_test.go index d924fd0..83f09db 100644 --- a/conn_test.go +++ b/conn_test.go @@ -868,6 +868,29 @@ func TestConn(t *testing.T) { return c.Close(websocket.StatusNormalClosure, "") }, }, + { + // Issue #164 + name: "closeHandshake_concurrentRead", + server: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Read(ctx) + return assertCloseStatus(err, websocket.StatusNormalClosure) + }, + client: func(ctx context.Context, c *websocket.Conn) error { + errc := make(chan error, 1) + go func() { + _, _, err := c.Read(ctx) + errc <- err + }() + + err := c.Close(websocket.StatusNormalClosure, "") + if err != nil { + return err + } + + err = <-errc + return assertCloseStatus(err, websocket.StatusNormalClosure) + }, + }, } for _, tc := range testCases { tc := tc -- GitLab