diff --git a/websocket.go b/websocket.go index 88520858324c450661d5fb4032c212d04f1f0a07..09a94e7809ca6c8860093acb4ce46a34bef0e38a 100644 --- a/websocket.go +++ b/websocket.go @@ -49,13 +49,6 @@ type Conn struct { readDone chan int } -func (c *Conn) getCloseErr() error { - if c.closeErr != nil { - return c.closeErr - } - return nil -} - func (c *Conn) close(err error) { if err != nil { err = xerrors.Errorf("websocket: connection broken: %w", err) @@ -160,8 +153,12 @@ messageLoop: masked: c.client, } c.writeFrame(h, control.payload) - c.writeDone <- struct{}{} - continue + select { + case <-c.closed: + return + case c.writeDone <- struct{}{}: + continue + } case b, ok := <-c.writeBytes: h := header{ fin: !ok, @@ -349,14 +346,14 @@ func (c *Conn) Close(code StatusCode, reason string) error { p, _ = closePayload(StatusInternalError, fmt.Sprintf("websocket: application tried to send code %v but code or reason was invalid", code)) } - err2 := c.writeClose(p, CloseError{ + cerr := c.writeClose(p, CloseError{ Code: code, Reason: reason, }) if err != nil { return err } - return err2 + return cerr } func (c *Conn) writeClose(p []byte, cerr CloseError) error { @@ -381,19 +378,19 @@ func (c *Conn) writeClose(p []byte, cerr CloseError) error { func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { select { case <-c.closed: - return c.getCloseErr() + return c.closeErr case c.control <- control{ opcode: opcode, payload: p, }: case <-ctx.Done(): c.close(xerrors.New("force closed: close frame write timed out")) - return c.getCloseErr() + return c.closeErr } select { case <-c.closed: - return c.getCloseErr() + return c.closeErr case <-c.writeDone: return nil case <-ctx.Done(): @@ -420,9 +417,6 @@ type messageWriter struct { ctx context.Context c *Conn acquiredLock bool - sentFirst bool - - done chan struct{} } // Write writes the given bytes to the WebSocket connection. @@ -430,24 +424,18 @@ type messageWriter struct { // with the buffers obtained from http.Hijacker. // Please ensure you call Close once you have written the full message. func (w *messageWriter) Write(p []byte) (int, error) { - if !w.acquiredLock { - select { - case <-w.c.closed: - return 0, w.c.getCloseErr() - case w.c.write <- w.datatype: - w.acquiredLock = true - case <-w.ctx.Done(): - return 0, w.ctx.Err() - } + err := w.acquire() + if err != nil { + return 0, err } select { case <-w.c.closed: - return 0, w.c.getCloseErr() + return 0, w.c.closeErr case w.c.writeBytes <- p: select { case <-w.c.closed: - return 0, w.c.getCloseErr() + return 0, w.c.closeErr case <-w.c.writeDone: return len(p), nil case <-w.ctx.Done(): @@ -458,23 +446,32 @@ func (w *messageWriter) Write(p []byte) (int, error) { } } -// Close flushes the frame to the connection. -// This must be called for every messageWriter. -func (w *messageWriter) Close() error { +func (w *messageWriter) acquire() error { if !w.acquiredLock { select { case <-w.c.closed: - return w.c.getCloseErr() + return w.c.closeErr case w.c.write <- w.datatype: w.acquiredLock = true case <-w.ctx.Done(): return w.ctx.Err() } } + return nil +} + +// Close flushes the frame to the connection. +// This must be called for every messageWriter. +func (w *messageWriter) Close() error { + err := w.acquire() + if err != nil { + return err + } + close(w.c.writeBytes) select { case <-w.c.closed: - return w.c.getCloseErr() + return w.c.closeErr case <-w.ctx.Done(): return w.ctx.Err() case <-w.c.writeDone: @@ -490,7 +487,7 @@ func (w *messageWriter) Close() error { func (c *Conn) Read(ctx context.Context) (DataType, io.Reader, error) { select { case <-c.closed: - return 0, nil, xerrors.Errorf("failed to read message: %w", c.getCloseErr()) + return 0, nil, xerrors.Errorf("failed to read message: %w", c.closeErr) case opcode := <-c.read: return DataType(opcode), &messageReader{ ctx: ctx, @@ -507,24 +504,17 @@ type messageReader struct { c *Conn } -// SetContext bounds the read operation to the ctx. -// By default, the context is the one passed to conn.ReadMessage. -// You still almost always want a separate context for reading the message though. -func (r *messageReader) SetContext(ctx context.Context) { - r.ctx = ctx -} - // Read reads as many bytes as possible into p. func (r *messageReader) Read(p []byte) (n int, err error) { select { case <-r.c.closed: - return 0, r.c.getCloseErr() + return 0, r.c.closeErr case <-r.c.readDone: return 0, io.EOF case r.c.readBytes <- p: select { case <-r.c.closed: - return 0, r.c.getCloseErr() + return 0, r.c.closeErr case n := <-r.c.readDone: return n, nil case <-r.ctx.Done():