good morning!!!!

Skip to content
Snippets Groups Projects
Unverified Commit f6287497 authored by Anmol Sethi's avatar Anmol Sethi
Browse files

Remove readLoop

Closes #93
parent d8e872c7
Branches
Tags
No related merge requests found
......@@ -123,24 +123,18 @@ it has to reinvent hooks for TLS and proxies and prevents support of HTTP/2.
Some more advantages of nhooyr/websocket are that it supports concurrent writes and
makes it very easy to close the connection with a status code and reason.
nhooyr/websocket also responds to pings, pongs and close frames in a separate goroutine so that
your application doesn't always need to read from the connection unless it expects a data message.
gorilla/websocket requires you to constantly read from the connection to respond to control frames
even if you don't expect the peer to send any messages.
The ping API is also much nicer. gorilla/websocket requires registering a pong handler on the Conn
which results in awkward control flow. With nhooyr/websocket you use the Ping method on the Conn
that sends a ping and also waits for the pong.
In terms of performance, the differences depend on your application code. nhooyr/websocket
reuses buffers efficiently out of the box if you use the wsjson and wspb subpackages whereas
gorilla/websocket does not. As mentioned above, nhooyr/websocket also supports concurrent
gorilla/websocket does not at all. As mentioned above, nhooyr/websocket also supports concurrent
writers out of the box.
The only performance con to nhooyr/websocket is that uses two extra goroutines. One for
reading pings, pongs and close frames async to application code and another to support
context.Context cancellation. This costs 4 KB of memory which is cheap compared
to the benefits.
The only performance con to nhooyr/websocket is that uses one extra goroutine to support
cancellation with context.Context and the net/http client side body upgrade.
This costs 2 KB of memory which is cheap compared to simplicity benefits.
### x/net/websocket
......
......@@ -81,9 +81,6 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error {
// Accept will reject the handshake if the Origin domain is not the same as the Host unless
// the InsecureSkipVerify option is set. In other words, by default it does not allow
// cross origin requests.
//
// The returned connection will be bound by r.Context(). Use conn.Context() to change
// the bounding context.
func Accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, error) {
c, err := accept(w, r, opts)
if err != nil {
......@@ -143,7 +140,6 @@ func accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn,
closer: netConn,
}
c.init()
c.Context(r.Context())
return c, nil
}
......
package websocket
import (
"fmt"
"io"
"golang.org/x/xerrors"
......@@ -20,9 +19,9 @@ func (lr *limitedReader) Read(p []byte) (int, error) {
}
if lr.left <= 0 {
msg := fmt.Sprintf("read limited at %v bytes", lr.limit)
lr.c.Close(StatusPolicyViolation, msg)
return 0, xerrors.Errorf(msg)
err := xerrors.Errorf("read limited at %v bytes", lr.limit)
lr.c.Close(StatusMessageTooBig, err.Error())
return 0, err
}
if int64(len(p)) > lr.left {
......
......@@ -28,7 +28,7 @@ type Conn struct {
br *bufio.Reader
bw *bufio.Writer
// writeBuf is used for masking, its the buffer in bufio.Writer.
// Only used by the client.
// Only used by the client for masking the bytes in the buffer.
writeBuf []byte
closer io.Closer
client bool
......@@ -51,17 +51,9 @@ type Conn struct {
previousReader *messageReader
// readFrameLock is acquired to read from bw.
readFrameLock chan struct{}
// readMsg is used by messageReader to receive frames from
// readLoop.
readMsg chan header
// readMsgDone is used to tell the readLoop to continue after
// messageReader has read a frame.
readMsgDone chan struct{}
setReadTimeout chan context.Context
setWriteTimeout chan context.Context
setConnContext chan context.Context
getConnContext chan context.Context
activePingsMu sync.Mutex
activePings map[string]chan<- struct{}
......@@ -76,13 +68,9 @@ func (c *Conn) init() {
c.writeFrameLock = make(chan struct{}, 1)
c.readFrameLock = make(chan struct{}, 1)
c.readMsg = make(chan header)
c.readMsgDone = make(chan struct{})
c.setReadTimeout = make(chan context.Context)
c.setWriteTimeout = make(chan context.Context)
c.setConnContext = make(chan context.Context)
c.getConnContext = make(chan context.Context)
c.activePings = make(map[string]chan<- struct{})
......@@ -91,7 +79,6 @@ func (c *Conn) init() {
})
go c.timeoutLoop()
go c.readLoop()
}
// Subprotocol returns the negotiated subprotocol.
......@@ -131,53 +118,20 @@ func (c *Conn) close(err error) {
func (c *Conn) timeoutLoop() {
readCtx := context.Background()
writeCtx := context.Background()
parentCtx := context.Background()
for {
select {
case <-c.closed:
return
case writeCtx = <-c.setWriteTimeout:
case readCtx = <-c.setReadTimeout:
case <-readCtx.Done():
c.close(xerrors.Errorf("data read timed out: %w", readCtx.Err()))
case <-writeCtx.Done():
c.close(xerrors.Errorf("data write timed out: %w", writeCtx.Err()))
case <-parentCtx.Done():
c.close(xerrors.Errorf("parent context cancelled: %w", parentCtx.Err()))
return
case parentCtx = <-c.setConnContext:
ctx, cancelCtx := context.WithCancel(parentCtx)
defer cancelCtx()
select {
case <-c.closed:
return
case c.getConnContext <- ctx:
}
}
}
}
// Context returns a context derived from parent that will be cancelled
// when the connection is closed or broken.
// If the parent context is cancelled, the connection will be closed.
func (c *Conn) Context(parent context.Context) context.Context {
select {
case <-c.closed:
ctx, cancel := context.WithCancel(parent)
cancel()
return ctx
case c.setConnContext <- parent:
}
select {
case <-c.closed:
ctx, cancel := context.WithCancel(parent)
cancel()
return ctx
case ctx := <-c.getConnContext:
return ctx
}
}
......@@ -210,30 +164,9 @@ func (c *Conn) releaseLock(lock chan struct{}) {
}
}
func (c *Conn) readLoop() {
for {
h, err := c.readTillMsg()
if err != nil {
return
}
select {
case <-c.closed:
return
case c.readMsg <- h:
}
select {
case <-c.closed:
return
case <-c.readMsgDone:
}
}
}
func (c *Conn) readTillMsg() (header, error) {
func (c *Conn) readTillMsg(ctx context.Context) (header, error) {
for {
h, err := c.readFrameHeader()
h, err := c.readFrameHeader(ctx)
if err != nil {
return header{}, err
}
......@@ -245,7 +178,10 @@ func (c *Conn) readTillMsg() (header, error) {
}
if h.opcode.controlOp() {
c.handleControl(h)
err = c.handleControl(ctx, h)
if err != nil {
return header{}, err
}
continue
}
......@@ -260,43 +196,64 @@ func (c *Conn) readTillMsg() (header, error) {
}
}
func (c *Conn) readFrameHeader() (header, error) {
func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
err := c.acquireLock(context.Background(), c.readFrameLock)
if err != nil {
return header{}, err
}
defer c.releaseLock(c.readFrameLock)
select {
case <-c.closed:
return header{}, c.closeErr
case c.setReadTimeout <- ctx:
}
h, err := readHeader(c.br)
if err != nil {
select {
case <-c.closed:
return header{}, c.closeErr
case <-ctx.Done():
err = ctx.Err()
default:
}
err := xerrors.Errorf("failed to read header: %w", err)
c.releaseLock(c.readFrameLock)
c.close(err)
return header{}, err
}
select {
case <-c.closed:
return header{}, c.closeErr
case c.setReadTimeout <- context.Background():
}
return h, nil
}
func (c *Conn) handleControl(h header) {
func (c *Conn) handleControl(ctx context.Context, h header) error {
if h.payloadLength > maxControlFramePayload {
c.Close(StatusProtocolError, "control frame too large")
return
err := xerrors.Errorf("control frame too large at %v bytes", h.payloadLength)
c.Close(StatusProtocolError, err.Error())
return err
}
if !h.fin {
c.Close(StatusProtocolError, "control frame cannot be fragmented")
return
err := xerrors.Errorf("received fragmented control frame")
c.Close(StatusProtocolError, err.Error())
return err
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
b := make([]byte, h.payloadLength)
_, err := c.readFramePayload(ctx, b)
if err != nil {
return
return err
}
if h.masked {
......@@ -305,7 +262,7 @@ func (c *Conn) handleControl(h header) {
switch h.opcode {
case opPing:
c.writePong(b)
return c.writePong(b)
case opPong:
c.activePingsMu.Lock()
pong, ok := c.activePings[string(b)]
......@@ -313,17 +270,20 @@ func (c *Conn) handleControl(h header) {
if ok {
close(pong)
}
return nil
case opClose:
ce, err := parseClosePayload(b)
if err != nil {
c.close(xerrors.Errorf("received invalid close payload: %w", err))
return
err = xerrors.Errorf("received invalid close payload: %w", err)
c.close(err)
return err
}
if ce.Code == StatusNoStatusRcvd {
c.writeClose(nil, ce)
} else {
c.Close(ce.Code, ce.Reason)
}
return c.closeErr
default:
panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h))
}
......@@ -335,11 +295,10 @@ func (c *Conn) handleControl(h header) {
// The passed context will also bound the reader.
// Ensure you read to EOF otherwise the connection will hang.
//
// Control (ping, pong, close) frames will be handled automatically
// in a separate goroutine so if you do not expect any data messages,
// you do not need to read from the connection. However, if the peer
// sends a data message, further pings, pongs and close frames will not
// be read if you do not read the message from the connection.
// You must read from the connection for close frames to be read.
// If you do not expect any data messages from the peer, just call
// Reader in a separate goroutine and close the connection with StatusPolicyViolation
// when it returns. Example at // TODO
//
// Only one Reader may be open at a time.
//
......@@ -368,12 +327,11 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
return 0, nil, xerrors.Errorf("previous message not read to completion")
}
select {
case <-c.closed:
return 0, nil, c.closeErr
case <-ctx.Done():
return 0, nil, ctx.Err()
case h := <-c.readMsg:
h, err := c.readTillMsg(ctx)
if err != nil {
return 0, nil, err
}
if c.previousReader != nil && !c.previousReader.done {
if h.opcode != opContinuation {
err := xerrors.Errorf("received new data message without finishing the previous message")
......@@ -387,12 +345,6 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
c.previousReader.done = true
select {
case <-c.closed:
return 0, nil, c.closeErr
case c.readMsgDone <- struct{}{}:
}
return c.reader(ctx)
} else if h.opcode == opContinuation {
err := xerrors.Errorf("received continuation frame not after data or text frame")
......@@ -409,7 +361,6 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
c.previousReader = r
return MessageType(h.opcode), r, nil
}
}
// messageReader enables reading a data frame from the WebSocket connection.
type messageReader struct {
......@@ -441,13 +392,11 @@ func (r *messageReader) read(p []byte) (int, error) {
}
if r.h == nil {
select {
case <-r.c.closed:
return 0, r.c.closeErr
case <-r.ctx.Done():
r.c.close(xerrors.Errorf("failed to read: %w", r.ctx.Err()))
return 0, r.ctx.Err()
case h := <-r.c.readMsg:
h, err := r.c.readTillMsg(r.ctx)
if err != nil {
return 0, err
}
if h.opcode != opContinuation {
err := xerrors.Errorf("received new data frame without finishing the previous frame")
r.c.Close(StatusProtocolError, err.Error())
......@@ -455,7 +404,6 @@ func (r *messageReader) read(p []byte) (int, error) {
}
r.h = &h
}
}
if int64(len(p)) > r.h.payloadLength {
p = p[:r.h.payloadLength]
......@@ -473,12 +421,6 @@ func (r *messageReader) read(p []byte) (int, error) {
}
if r.h.payloadLength == 0 {
select {
case <-r.c.closed:
return n, r.c.closeErr
case r.c.readMsgDone <- struct{}{}:
}
fin := r.h.fin
// Need to nil this as Reader uses it to check
......@@ -539,7 +481,7 @@ func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
//
// By default, the connection has a message read limit of 32768 bytes.
//
// When the limit is hit, the connection will be closed with StatusPolicyViolation.
// When the limit is hit, the connection will be closed with StatusMessageTooBig.
func (c *Conn) SetReadLimit(n int64) {
c.msgReadLimit = n
}
......
......@@ -383,6 +383,8 @@ func TestHandshake(t *testing.T) {
}
defer c.Close(websocket.StatusInternalError, "")
go c.Reader(r.Context())
err = c.Ping(r.Context())
if err != nil {
return err
......@@ -403,10 +405,10 @@ func TestHandshake(t *testing.T) {
}
defer c.Close(websocket.StatusInternalError, "")
err = c.Ping(ctx)
if err != nil {
return err
}
errc := make(chan error, 1)
go func() {
errc <- c.Ping(ctx)
}()
_, _, err = c.Read(ctx)
if err != nil {
......@@ -414,7 +416,7 @@ func TestHandshake(t *testing.T) {
}
c.Close(websocket.StatusNormalClosure, "")
return nil
return <-errc
},
},
{
......@@ -439,6 +441,8 @@ func TestHandshake(t *testing.T) {
}
defer c.Close(websocket.StatusInternalError, "")
go c.Reader(ctx)
err = c.Write(ctx, websocket.MessageBinary, []byte(strings.Repeat("x", 32769)))
if err != nil {
return err
......@@ -454,46 +458,6 @@ func TestHandshake(t *testing.T) {
return nil
},
},
{
name: "context",
server: func(w http.ResponseWriter, r *http.Request) error {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
ctx, cancel := context.WithTimeout(r.Context(), time.Second)
defer cancel()
c.Context(ctx)
for r.Context().Err() == nil {
err = c.Ping(ctx)
if err != nil {
return nil
}
}
return xerrors.Errorf("all pings succeeded")
},
client: func(ctx context.Context, u string) error {
c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
cctx := c.Context(ctx)
select {
case <-ctx.Done():
return xerrors.Errorf("child context never cancelled")
case <-cctx.Done():
return nil
}
},
},
}
for _, tc := range testCases {
......
......@@ -44,6 +44,7 @@ func read(ctx context.Context, c *websocket.Conn, v interface{}) error {
err = json.Unmarshal(b.Bytes(), v)
if err != nil {
c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal JSON")
return xerrors.Errorf("failed to unmarshal json: %w", err)
}
......
......@@ -46,6 +46,7 @@ func read(ctx context.Context, c *websocket.Conn, v proto.Message) error {
err = proto.Unmarshal(b.Bytes(), v)
if err != nil {
c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal protobuf")
return xerrors.Errorf("failed to unmarshal protobuf: %w", err)
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment