diff --git a/netconn.go b/netconn.go index 0de2f1cb0f9823455143832599eead9cac45d096..184d5d6c5df1f25afec13f1409c51e2701c0d0ff 100644 --- a/netconn.go +++ b/netconn.go @@ -2,11 +2,12 @@ package websocket import ( "context" - "golang.org/x/xerrors" "io" "math" "net" "time" + + "golang.org/x/xerrors" ) // NetConn converts a *websocket.Conn into a net.Conn. diff --git a/websocket.go b/websocket.go index e7fb0dfafb2088d5e0d5887eb1fe13bf1c9cc71e..f875a14267f3809deade516d6bd15cff8380732f 100644 --- a/websocket.go +++ b/websocket.go @@ -12,6 +12,7 @@ import ( "runtime" "strconv" "sync" + "sync/atomic" "time" "golang.org/x/xerrors" @@ -64,6 +65,7 @@ type Conn struct { previousReader *messageReader // readFrameLock is acquired to read from bw. readFrameLock chan struct{} + readClosed int64 readHeaderBuf []byte controlPayloadBuf []byte @@ -329,6 +331,10 @@ func (c *Conn) handleControl(ctx context.Context, h header) error { // See https://github.com/nhooyr/websocket/issues/87#issue-451703332 // Most users should not need this. func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { + if atomic.LoadInt64(&c.readClosed) == 1 { + return 0, nil, xerrors.Errorf("websocket connection read closed") + } + typ, r, err := c.reader(ctx) if err != nil { return 0, nil, xerrors.Errorf("failed to get reader: %w", err) @@ -395,10 +401,13 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { // Use this when you do not want to read data messages from the connection anymore but will // want to write messages to it. func (c *Conn) CloseRead(ctx context.Context) context.Context { + atomic.StoreInt64(&c.readClosed, 1) + ctx, cancel := context.WithCancel(ctx) go func() { defer cancel() - c.Reader(ctx) + // We use the unexported reader so that we don't get the read closed error. + c.reader(ctx) c.Close(StatusPolicyViolation, "unexpected data message") }() return ctx