From 43cb01eaf9fad1e2052a18b69b777db62820aae7 Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Fri, 29 Nov 2019 00:00:52 -0500 Subject: [PATCH] Refactor read.go/write.go --- README.md | 43 +++--- assert_test.go | 13 +- close.go | 64 +++++---- conn.go | 92 ++++++++++--- conn_test.go | 3 +- internal/assert/assert.go | 2 +- read.go | 266 ++++++++++++++++---------------------- write.go | 215 +++++++++++++----------------- wsjson/wsjson.go | 1 - 9 files changed, 345 insertions(+), 354 deletions(-) diff --git a/README.md b/README.md index efb4a59..f0babdf 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ go get nhooyr.io/websocket - Concurrent writes - [Close handshake](https://godoc.org/nhooyr.io/websocket#Conn.Close) - [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper -- [Pings](https://godoc.org/nhooyr.io/websocket#Conn.Ping) +- [Ping pong](https://godoc.org/nhooyr.io/websocket#Conn.Ping) - [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression - Compile to [Wasm](https://godoc.org/nhooyr.io/websocket#hdr-Wasm) @@ -88,26 +88,27 @@ c.Close(websocket.StatusNormalClosure, "") [gorilla/websocket](https://github.com/gorilla/websocket) is a widely used and mature library. Advantages of nhooyr.io/websocket: - - Minimal and idiomatic API - - Compare godoc of [nhooyr.io/websocket](https://godoc.org/nhooyr.io/websocket) with [gorilla/websocket](https://godoc.org/github.com/gorilla/websocket) side by side. - - [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper - - Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535)) - - Full [context.Context](https://blog.golang.org/context) support - - Uses [net/http.Client](https://golang.org/pkg/net/http/#Client) for dialing - - Will enable easy HTTP/2 support in the future - - Gorilla writes directly to a net.Conn and so duplicates features from net/http.Client. - - Concurrent writes - - Close handshake ([gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448)) - - Idiomatic [ping](https://godoc.org/nhooyr.io/websocket#Conn.Ping) API - - gorilla/websocket requires registering a pong callback and then sending a Ping - - Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432)) - - Transparent buffer reuse with [wsjson](https://godoc.org/nhooyr.io/websocket/wsjson) and [wspb](https://godoc.org/nhooyr.io/websocket/wspb) subpackages - - [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go - - Gorilla's implementation depends on unsafe and is slower - - Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support + +- Minimal and idiomatic API + - Compare godoc of [nhooyr.io/websocket](https://godoc.org/nhooyr.io/websocket) with [gorilla/websocket](https://godoc.org/github.com/gorilla/websocket) side by side. +- [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper +- Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535)) +- Full [context.Context](https://blog.golang.org/context) support +- Uses [net/http.Client](https://golang.org/pkg/net/http/#Client) for dialing + - Will enable easy HTTP/2 support in the future + - Gorilla writes directly to a net.Conn and so duplicates features from net/http.Client. +- Concurrent writes +- Close handshake ([gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448)) +- Idiomatic [ping](https://godoc.org/nhooyr.io/websocket#Conn.Ping) API + - gorilla/websocket requires registering a pong callback and then sending a Ping +- Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432)) +- Transparent message buffer reuse with [wsjson](https://godoc.org/nhooyr.io/websocket/wsjson) and [wspb](https://godoc.org/nhooyr.io/websocket/wspb) subpackages +- [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go + - Gorilla's implementation depends on unsafe and is slower +- Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support - Gorilla only supports no context takeover mode - - [CloseRead](https://godoc.org/nhooyr.io/websocket#Conn.CloseRead) helper - - Actively maintained ([gorilla/websocket#370](https://github.com/gorilla/websocket/issues/370)) +- [CloseRead](https://godoc.org/nhooyr.io/websocket#Conn.CloseRead) helper +- Actively maintained ([gorilla/websocket#370](https://github.com/gorilla/websocket/issues/370)) #### golang.org/x/net/websocket @@ -120,7 +121,7 @@ to nhooyr.io/websocket. #### gobwas/ws [gobwas/ws](https://github.com/gobwas/ws) has an extremely flexible API that allows it to be used -in an event driven style for performance. See the author's [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb). +in an event driven style for performance. See the author's [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb). However when writing idiomatic Go, nhooyr.io/websocket will be faster and easier to use. diff --git a/assert_test.go b/assert_test.go index b6e50a4..e431993 100644 --- a/assert_test.go +++ b/assert_test.go @@ -4,12 +4,11 @@ import ( "context" "crypto/rand" "io" - "strings" - "testing" - "nhooyr.io/websocket" "nhooyr.io/websocket/internal/assert" "nhooyr.io/websocket/wsjson" + "strings" + "testing" ) func randBytes(t *testing.T, n int) []byte { @@ -21,12 +20,15 @@ func randBytes(t *testing.T, n int) []byte { func assertJSONEcho(t *testing.T, ctx context.Context, c *websocket.Conn, n int) { t.Helper() + defer c.Close(websocket.StatusInternalError, "") exp := randString(t, n) err := wsjson.Write(ctx, c, exp) assert.Success(t, err) assertJSONRead(t, ctx, c, exp) + + c.Close(websocket.StatusNormalClosure, "") } func assertJSONRead(t *testing.T, ctx context.Context, c *websocket.Conn, exp interface{}) { @@ -74,5 +76,10 @@ func assertSubprotocol(t *testing.T, c *websocket.Conn, exp string) { func assertCloseStatus(t *testing.T, exp websocket.StatusCode, err error) { t.Helper() + defer func() { + if t.Failed() { + t.Logf("error: %+v", err) + } + }() assert.Equal(t, exp, websocket.CloseStatus(err), "StatusCode") } diff --git a/close.go b/close.go index a02dc7d..4c474b7 100644 --- a/close.go +++ b/close.go @@ -7,9 +7,6 @@ import ( "fmt" "log" "nhooyr.io/websocket/internal/errd" - "time" - - "nhooyr.io/websocket/internal/bpool" ) // StatusCode represents a WebSocket status code. @@ -103,59 +100,58 @@ func (c *Conn) Close(code StatusCode, reason string) error { func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { defer errd.Wrap(&err, "failed to close WebSocket") - err = c.cw.sendClose(code, reason) + err = c.writeClose(code, reason) if err != nil { return err } - return c.cr.waitClose() + return c.waitClose() } -func (cw *connWriter) error(code StatusCode, err error) { - cw.c.setCloseErr(err) - cw.sendClose(code, err.Error()) - cw.c.closeWithErr(nil) +func (c *Conn) writeError(code StatusCode, err error) { + c.setCloseErr(err) + c.writeClose(code, err.Error()) + c.closeWithErr(nil) } -func (cw *connWriter) sendClose(code StatusCode, reason string) error { +func (c *Conn) writeClose(code StatusCode, reason string) error { ce := CloseError{ Code: code, Reason: reason, } - cw.c.setCloseErr(fmt.Errorf("sent close frame: %w", ce)) + c.setCloseErr(fmt.Errorf("sent close frame: %w", ce)) var p []byte if ce.Code != StatusNoStatusRcvd { p = ce.bytes() } - return cw.control(context.Background(), opClose, p) + return c.writeControl(context.Background(), opClose, p) } -func (cr *connReader) waitClose() error { - defer cr.c.closeWithErr(nil) +func (c *Conn) waitClose() error { + defer c.closeWithErr(nil) return nil - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - err := cr.mu.Lock(ctx) - if err != nil { - return err - } - defer cr.mu.Unlock() - - b := bpool.Get() - buf := b.Bytes() - buf = buf[:cap(buf)] - defer bpool.Put(b) - - for { - // TODO - return nil - } + // ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + // defer cancel() + // + // err := cr.mu.Lock(ctx) + // if err != nil { + // return err + // } + // defer cr.mu.Unlock() + // + // b := bpool.Get() + // buf := b.Bytes() + // buf = buf[:cap(buf)] + // defer bpool.Put(b) + // + // for { + // return nil + // } } func parseClosePayload(p []byte) (CloseError, error) { @@ -230,11 +226,11 @@ func (ce CloseError) bytesErr() ([]byte, error) { func (c *Conn) setCloseErr(err error) { c.closeMu.Lock() - c.setCloseErrNoLock(err) + c.setCloseErrLocked(err) c.closeMu.Unlock() } -func (c *Conn) setCloseErrNoLock(err error) { +func (c *Conn) setCloseErrLocked(err error) { if c.closeErr == nil { c.closeErr = fmt.Errorf("WebSocket closed: %w", err) } diff --git a/conn.go b/conn.go index d900179..dc067d1 100644 --- a/conn.go +++ b/conn.go @@ -30,7 +30,7 @@ const ( // All methods may be called concurrently except for Reader and Read. // // You must always read from the connection. Otherwise control -// frames will not be handled. See the docs on Reader and CloseRead. +// frames will not be handled. See Reader and CloseRead. // // Be sure to call Close on the connection when you // are finished with it to release associated resources. @@ -42,9 +42,22 @@ type Conn struct { rwc io.ReadWriteCloser client bool copts *compressionOptions + br *bufio.Reader + bw *bufio.Writer - cr connReader - cw connWriter + readTimeout chan context.Context + writeTimeout chan context.Context + + // Read state. + readMu mu + readControlBuf [maxControlPayload]byte + msgReader *msgReader + + // Write state. + msgWriter *msgWriter + writeFrameMu mu + writeBuf []byte + writeHeader header closed chan struct{} @@ -63,8 +76,8 @@ type connConfig struct { client bool copts *compressionOptions - bw *bufio.Writer br *bufio.Reader + bw *bufio.Writer } func newConn(cfg connConfig) *Conn { @@ -73,13 +86,23 @@ func newConn(cfg connConfig) *Conn { rwc: cfg.rwc, client: cfg.client, copts: cfg.copts, + + br: cfg.br, + bw: cfg.bw, + + readTimeout: make(chan context.Context), + writeTimeout: make(chan context.Context), + + closed: make(chan struct{}), + activePings: make(map[string]chan<- struct{}), } - c.cr.init(c, cfg.br) - c.cw.init(c, cfg.bw) + c.msgReader = newMsgReader(c) - c.closed = make(chan struct{}) - c.activePings = make(map[string]chan<- struct{}) + c.msgWriter = newMsgWriter(c) + if c.client { + c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc) + } runtime.SetFinalizer(c, func(c *Conn) { c.closeWithErr(errors.New("connection garbage collected")) @@ -90,6 +113,34 @@ func newConn(cfg connConfig) *Conn { return c } +func newMsgReader(c *Conn) *msgReader { + mr := &msgReader{ + c: c, + fin: true, + } + + mr.limitReader = newLimitReader(c, readerFunc(mr.read), 32768) + if c.deflateNegotiated() && mr.contextTakeover() { + mr.ensureFlateReader() + } + + return mr +} + +func newMsgWriter(c *Conn) *msgWriter { + mw := &msgWriter{ + c: c, + } + mw.trimWriter = &trimLastFourBytesWriter{ + w: writerFunc(mw.write), + } + if c.deflateNegotiated() && mw.contextTakeover() { + mw.ensureFlateWriter() + } + + return mw +} + // Subprotocol returns the negotiated subprotocol. // An empty string means the default protocol. func (c *Conn) Subprotocol() string { @@ -105,7 +156,7 @@ func (c *Conn) closeWithErr(err error) { } close(c.closed) runtime.SetFinalizer(c, nil) - c.setCloseErrNoLock(err) + c.setCloseErrLocked(err) // Have to close after c.closed is closed to ensure any goroutine that wakes up // from the connection being closed also sees that c.closed is closed and returns @@ -113,8 +164,18 @@ func (c *Conn) closeWithErr(err error) { c.rwc.Close() go func() { - c.cr.close() - c.cw.close() + if c.client { + c.writeFrameMu.Lock(context.Background()) + putBufioWriter(c.bw) + } + c.msgWriter.close() + + if c.client { + c.readMu.Lock(context.Background()) + putBufioReader(c.br) + c.readMu.Unlock() + } + c.msgReader.close() }() } @@ -127,13 +188,12 @@ func (c *Conn) timeoutLoop() { case <-c.closed: return - case writeCtx = <-c.cw.timeout: - case readCtx = <-c.cr.timeout: + case writeCtx = <-c.writeTimeout: + case readCtx = <-c.readTimeout: case <-readCtx.Done(): c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err())) - c.cw.error(StatusPolicyViolation, errors.New("timed out")) - return + go c.writeError(StatusPolicyViolation, errors.New("timed out")) case <-writeCtx.Done(): c.closeWithErr(fmt.Errorf("write timed out: %w", writeCtx.Err())) return @@ -175,7 +235,7 @@ func (c *Conn) ping(ctx context.Context, p string) error { c.activePingsMu.Unlock() }() - err := c.cw.control(ctx, opPing, []byte(p)) + err := c.writeControl(ctx, opPing, []byte(p)) if err != nil { return err } diff --git a/conn_test.go b/conn_test.go index 6b8a778..cf2334f 100644 --- a/conn_test.go +++ b/conn_test.go @@ -25,6 +25,7 @@ func TestConn(t *testing.T) { c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ Subprotocols: []string{"echo"}, InsecureSkipVerify: true, + // CompressionMode: websocket.CompressionDisabled, }) assert.Success(t, err) defer c.Close(websocket.StatusInternalError, "") @@ -41,12 +42,12 @@ func TestConn(t *testing.T) { opts := &websocket.DialOptions{ Subprotocols: []string{"echo"}, + // CompressionMode: websocket.CompressionDisabled, } opts.HTTPClient = s.Client() c, _, err := websocket.Dial(ctx, wsURL, opts) assert.Success(t, err) - assertJSONEcho(t, ctx, c, 2) }) } diff --git a/internal/assert/assert.go b/internal/assert/assert.go index 4ebdb51..b448711 100644 --- a/internal/assert/assert.go +++ b/internal/assert/assert.go @@ -23,7 +23,7 @@ func NotEqual(t testing.TB, exp, act interface{}, name string) { func Success(t testing.TB, err error) { t.Helper() if err != nil { - t.Fatalf("unexpected error : %+v", err) + t.Fatalf("unexpected error: %+v", err) } } diff --git a/read.go b/read.go index 7dba832..d8691d6 100644 --- a/read.go +++ b/read.go @@ -1,7 +1,6 @@ package websocket import ( - "bufio" "context" "errors" "fmt" @@ -14,41 +13,22 @@ import ( "nhooyr.io/websocket/internal/errd" ) -// Reader waits until there is a WebSocket data message to read -// from the connection. -// It returns the type of the message and a reader to read it. +// Reader reads from the connection until until there is a WebSocket +// data message to be read. It will handle ping, pong and close frames as appropriate. +// +// It returns the type of the message and an io.Reader to read it. // The passed context will also bound the reader. // Ensure you read to EOF otherwise the connection will hang. // -// All returned errors will cause the connection -// to be closed so you do not need to write your own error message. -// This applies to the Read methods in the wsjson/wspb subpackages as well. -// -// You must read from the connection for control frames to be handled. -// Thus if you expect messages to take a long time to be responded to, -// you should handle such messages async to reading from the connection -// to ensure control frames are promptly handled. -// -// If you do not expect any data messages from the peer, call CloseRead. +// Call CloseRead if you do not expect any data messages from the peer. // // Only one Reader may be open at a time. -// -// If you need a separate timeout on the Reader call and then the message -// Read, use time.AfterFunc to cancel the context passed in early. -// See https://github.com/nhooyr/websocket/issues/87#issue-451703332 -// Most users should not need this. func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { - typ, r, err := c.cr.reader(ctx) - if err != nil { - return 0, nil, fmt.Errorf("failed to get reader: %w", err) - } - return typ, r, nil + return c.reader(ctx) } -// Read is a convenience method to read a single message from the connection. -// -// See the Reader method to reuse buffers or for streaming. -// The docs on Reader apply to this method as well. +// Read is a convenience method around Reader to read a single message +// from the connection. func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { typ, r, err := c.Reader(ctx) if err != nil { @@ -59,14 +39,17 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { return typ, b, err } -// CloseRead will start a goroutine to read from the connection until it is closed or a data message -// is received. If a data message is received, the connection will be closed with StatusPolicyViolation. -// Since CloseRead reads from the connection, it will respond to ping, pong and close frames. -// After calling this method, you cannot read any data messages from the connection. +// CloseRead starts a goroutine to read from the connection until it is closed +// or a data message is received. +// +// Once CloseRead is called you cannot read any messages from the connection. // The returned context will be cancelled when the connection is closed. // -// Use this when you do not want to read data messages from the connection anymore but will -// want to write messages to it. +// If a data message is received, the connection will be closed with StatusPolicyViolation. +// +// Call CloseRead when you do not expect to read any more messages. +// Since it actively reads from the connection, it will ensure that ping, pong and close +// frames are responded to. func (c *Conn) CloseRead(ctx context.Context) context.Context { ctx, cancel := context.WithCancel(ctx) go func() { @@ -84,60 +67,32 @@ func (c *Conn) CloseRead(ctx context.Context) context.Context { // // When the limit is hit, the connection will be closed with StatusMessageTooBig. func (c *Conn) SetReadLimit(n int64) { - c.cr.mr.lr.limit.Store(n) -} - -type connReader struct { - c *Conn - br *bufio.Reader - timeout chan context.Context - - mu mu - controlPayloadBuf [maxControlPayload]byte - mr *msgReader -} - -func (cr *connReader) init(c *Conn, br *bufio.Reader) { - cr.c = c - cr.br = br - cr.timeout = make(chan context.Context) - - cr.mr = &msgReader{ - cr: cr, - fin: true, - } - - cr.mr.lr = newLimitReader(c, readerFunc(cr.mr.read), 32768) - if c.deflateNegotiated() && cr.contextTakeover() { - cr.ensureFlateReader() - } + c.msgReader.limitReader.setLimit(n) } -func (cr *connReader) ensureFlateReader() { - cr.mr.fr = getFlateReader(readerFunc(cr.mr.read)) - cr.mr.lr.reset(cr.mr.fr) +func (mr *msgReader) ensureFlateReader() { + mr.flateReader = getFlateReader(readerFunc(mr.read)) + mr.limitReader.reset(mr.flateReader) } -func (cr *connReader) close() { - cr.mu.Lock(context.Background()) - if cr.c.client { - putBufioReader(cr.br) - } - if cr.c.deflateNegotiated() && cr.contextTakeover() { - putFlateReader(cr.mr.fr) +func (mr *msgReader) close() { + if mr.c.deflateNegotiated() && mr.contextTakeover() { + mr.c.readMu.Lock(context.Background()) + putFlateReader(mr.flateReader) + mr.c.readMu.Unlock() } } -func (cr *connReader) contextTakeover() bool { - if cr.c.client { - return cr.c.copts.serverNoContextTakeover +func (mr *msgReader) contextTakeover() bool { + if mr.c.client { + return mr.c.copts.serverNoContextTakeover } - return cr.c.copts.clientNoContextTakeover + return mr.c.copts.clientNoContextTakeover } -func (cr *connReader) rsv1Illegal(h header) bool { +func (c *Conn) readRSV1Illegal(h header) bool { // If compression is enabled, rsv1 is always illegal. - if !cr.c.deflateNegotiated() { + if !c.deflateNegotiated() { return true } // rsv1 is only allowed on data frames beginning messages. @@ -147,26 +102,26 @@ func (cr *connReader) rsv1Illegal(h header) bool { return false } -func (cr *connReader) loop(ctx context.Context) (header, error) { +func (c *Conn) readLoop(ctx context.Context) (header, error) { for { - h, err := cr.frameHeader(ctx) + h, err := c.readFrameHeader(ctx) if err != nil { return header{}, err } - if h.rsv1 && cr.rsv1Illegal(h) || h.rsv2 || h.rsv3 { + if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 { err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3) - cr.c.cw.error(StatusProtocolError, err) + c.writeError(StatusProtocolError, err) return header{}, err } - if !cr.c.client && !h.masked { + if !c.client && !h.masked { return header{}, errors.New("received unmasked frame from client") } switch h.opcode { case opClose, opPing, opPong: - err = cr.control(ctx, h) + err = c.handleControl(ctx, h) if err != nil { // Pass through CloseErrors when receiving a close frame. if h.opcode == opClose && CloseStatus(err) != -1 { @@ -178,95 +133,89 @@ func (cr *connReader) loop(ctx context.Context) (header, error) { return h, nil default: err := fmt.Errorf("received unknown opcode %v", h.opcode) - cr.c.cw.error(StatusProtocolError, err) + c.writeError(StatusProtocolError, err) return header{}, err } } } -func (cr *connReader) frameHeader(ctx context.Context) (header, error) { +func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { select { - case <-cr.c.closed: - return header{}, cr.c.closeErr - case cr.timeout <- ctx: + case <-c.closed: + return header{}, c.closeErr + case c.readTimeout <- ctx: } - h, err := readFrameHeader(cr.br) + h, err := readFrameHeader(c.br) if err != nil { select { - case <-cr.c.closed: - return header{}, cr.c.closeErr + case <-c.closed: + return header{}, c.closeErr case <-ctx.Done(): return header{}, ctx.Err() default: - cr.c.closeWithErr(err) + c.closeWithErr(err) return header{}, err } } select { - case <-cr.c.closed: - return header{}, cr.c.closeErr - case cr.timeout <- context.Background(): + case <-c.closed: + return header{}, c.closeErr + case c.readTimeout <- context.Background(): } return h, nil } -func (cr *connReader) framePayload(ctx context.Context, p []byte) (int, error) { +func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { select { - case <-cr.c.closed: - return 0, cr.c.closeErr - case cr.timeout <- ctx: + case <-c.closed: + return 0, c.closeErr + case c.readTimeout <- ctx: } - n, err := io.ReadFull(cr.br, p) + n, err := io.ReadFull(c.br, p) if err != nil { select { - case <-cr.c.closed: - return n, cr.c.closeErr + case <-c.closed: + return n, c.closeErr case <-ctx.Done(): return n, ctx.Err() default: err = fmt.Errorf("failed to read frame payload: %w", err) - cr.c.closeWithErr(err) + c.closeWithErr(err) return n, err } } select { - case <-cr.c.closed: - return n, cr.c.closeErr - case cr.timeout <- context.Background(): + case <-c.closed: + return n, c.closeErr + case c.readTimeout <- context.Background(): } return n, err } -func (cr *connReader) control(ctx context.Context, h header) error { - if h.payloadLength < 0 { - err := fmt.Errorf("received header with negative payload length: %v", h.payloadLength) - cr.c.cw.error(StatusProtocolError, err) - return err - } - - if h.payloadLength > maxControlPayload { - err := fmt.Errorf("received too big control frame at %v bytes", h.payloadLength) - cr.c.cw.error(StatusProtocolError, err) +func (c *Conn) handleControl(ctx context.Context, h header) error { + if h.payloadLength < 0 || h.payloadLength > maxControlPayload { + err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength) + c.writeError(StatusProtocolError, err) return err } if !h.fin { err := errors.New("received fragmented control frame") - cr.c.cw.error(StatusProtocolError, err) + c.writeError(StatusProtocolError, err) return err } ctx, cancel := context.WithTimeout(ctx, time.Second*5) defer cancel() - b := cr.controlPayloadBuf[:h.payloadLength] - _, err := cr.framePayload(ctx, b) + b := c.readControlBuf[:h.payloadLength] + _, err := c.readFramePayload(ctx, b) if err != nil { return err } @@ -277,11 +226,11 @@ func (cr *connReader) control(ctx context.Context, h header) error { switch h.opcode { case opPing: - return cr.c.cw.control(ctx, opPong, b) + return c.writeControl(ctx, opPong, b) case opPong: - cr.c.activePingsMu.Lock() - pong, ok := cr.c.activePings[string(b)] - cr.c.activePingsMu.Unlock() + c.activePingsMu.Lock() + pong, ok := c.activePings[string(b)] + c.activePingsMu.Unlock() if ok { close(pong) } @@ -291,53 +240,56 @@ func (cr *connReader) control(ctx context.Context, h header) error { ce, err := parseClosePayload(b) if err != nil { err = fmt.Errorf("received invalid close payload: %w", err) - cr.c.cw.error(StatusProtocolError, err) + c.writeError(StatusProtocolError, err) return err } err = fmt.Errorf("received close frame: %w", ce) - cr.c.setCloseErr(err) - cr.c.cw.control(context.Background(), opClose, ce.bytes()) + c.setCloseErr(err) + c.writeControl(context.Background(), opClose, ce.bytes()) return err } -func (cr *connReader) reader(ctx context.Context) (MessageType, io.Reader, error) { - err := cr.mu.Lock(ctx) +func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) { + defer errd.Wrap(&err, "failed to get reader") + + err = c.readMu.Lock(ctx) if err != nil { return 0, nil, err } - defer cr.mu.Unlock() + defer c.readMu.Unlock() - if !cr.mr.fin { + if !c.msgReader.fin { return 0, nil, errors.New("previous message not read to completion") } - h, err := cr.loop(ctx) + h, err := c.readLoop(ctx) if err != nil { return 0, nil, err } if h.opcode == opContinuation { err := errors.New("received continuation frame without text or binary frame") - cr.c.cw.error(StatusProtocolError, err) + c.writeError(StatusProtocolError, err) return 0, nil, err } - cr.mr.reset(ctx, h) + c.msgReader.reset(ctx, h) - return MessageType(h.opcode), cr.mr, nil + return MessageType(h.opcode), c.msgReader, nil } type msgReader struct { - cr *connReader - fr io.Reader - lr *limitReader + c *Conn ctx context.Context deflate bool + flateReader io.Reader deflateTail strings.Reader + limitReader *limitReader + payloadLength int64 maskKey uint32 fin bool @@ -348,8 +300,8 @@ func (mr *msgReader) reset(ctx context.Context, h header) { mr.deflate = h.rsv1 if mr.deflate { mr.deflateTail.Reset(deflateMessageTail) - if !mr.cr.contextTakeover() { - mr.cr.ensureFlateReader() + if !mr.contextTakeover() { + mr.ensureFlateReader() } } mr.setFrame(h) @@ -370,34 +322,42 @@ func (mr *msgReader) Read(p []byte) (_ int, err error) { } }() - err = mr.cr.mu.Lock(mr.ctx) + err = mr.c.readMu.Lock(mr.ctx) if err != nil { return 0, err } - defer mr.cr.mu.Unlock() + defer mr.c.readMu.Unlock() if mr.payloadLength == 0 && mr.fin { - if mr.cr.c.deflateNegotiated() && !mr.cr.contextTakeover() { - if mr.fr != nil { - putFlateReader(mr.fr) - mr.fr = nil + if mr.c.deflateNegotiated() && !mr.contextTakeover() { + if mr.flateReader != nil { + putFlateReader(mr.flateReader) + mr.flateReader = nil } } return 0, io.EOF } - return mr.lr.Read(p) + return mr.limitReader.Read(p) } func (mr *msgReader) read(p []byte) (int, error) { if mr.payloadLength == 0 { - h, err := mr.cr.loop(mr.ctx) + if mr.fin { + if mr.deflate { + n, _ := mr.deflateTail.Read(p[:4]) + return n, nil + } + return 0, io.EOF + } + + h, err := mr.c.readLoop(mr.ctx) if err != nil { return 0, err } if h.opcode != opContinuation { err := errors.New("received new data message without finishing the previous message") - mr.cr.c.cw.error(StatusProtocolError, err) + mr.c.writeError(StatusProtocolError, err) return 0, err } mr.setFrame(h) @@ -407,14 +367,14 @@ func (mr *msgReader) read(p []byte) (int, error) { p = p[:mr.payloadLength] } - n, err := mr.cr.framePayload(mr.ctx, p) + n, err := mr.c.readFramePayload(mr.ctx, p) if err != nil { return n, err } mr.payloadLength -= int64(n) - if !mr.cr.c.client { + if !mr.c.client { mr.maskKey = mask(mr.maskKey, p) } @@ -442,10 +402,14 @@ func (lr *limitReader) reset(r io.Reader) { lr.r = r } +func (lr *limitReader) setLimit(limit int64) { + lr.limit.Store(limit) +} + func (lr *limitReader) Read(p []byte) (int, error) { if lr.n <= 0 { err := fmt.Errorf("read limited at %v bytes", lr.limit.Load()) - lr.c.cw.error(StatusMessageTooBig, err) + lr.c.writeError(StatusMessageTooBig, err) return 0, err } diff --git a/write.go b/write.go index 9cafc5c..0ddf11e 100644 --- a/write.go +++ b/write.go @@ -24,7 +24,7 @@ import ( // // Never close the returned writer twice. func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { - w, err := c.cw.writer(ctx, typ) + w, err := c.writer(ctx, typ) if err != nil { return nil, fmt.Errorf("failed to get writer: %w", err) } @@ -38,111 +38,68 @@ func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, err // If compression is disabled, then it is guaranteed to write the message // in a single frame. func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { - _, err := c.cw.write(ctx, typ, p) + _, err := c.write(ctx, typ, p) if err != nil { return fmt.Errorf("failed to write msg: %w", err) } return nil } -type connWriter struct { - c *Conn - bw *bufio.Writer - - writeBuf []byte - - mw *messageWriter - frameMu mu - h header - - timeout chan context.Context +func (mw *msgWriter) ensureFlateWriter() { + mw.flateWriter = getFlateWriter(mw.trimWriter) } -func (cw *connWriter) init(c *Conn, bw *bufio.Writer) { - cw.c = c - cw.bw = bw - - if cw.c.client { - cw.writeBuf = extractBufioWriterBuf(cw.bw, c.rwc) - } - - cw.timeout = make(chan context.Context) - - cw.mw = &messageWriter{ - cw: cw, +func (mw *msgWriter) contextTakeover() bool { + if mw.c.client { + return mw.c.copts.clientNoContextTakeover } - cw.mw.tw = &trimLastFourBytesWriter{ - w: writerFunc(cw.mw.write), - } - if cw.c.deflateNegotiated() && cw.mw.contextTakeover() { - cw.mw.ensureFlateWriter() - } -} - -func (mw *messageWriter) ensureFlateWriter() { - mw.fw = getFlateWriter(mw.tw) + return mw.c.copts.serverNoContextTakeover } -func (cw *connWriter) close() { - if cw.c.client { - cw.frameMu.Lock(context.Background()) - putBufioWriter(cw.bw) - } - if cw.c.deflateNegotiated() && cw.mw.contextTakeover() { - cw.mw.mu.Lock(context.Background()) - putFlateWriter(cw.mw.fw) - } -} - -func (mw *messageWriter) contextTakeover() bool { - if mw.cw.c.client { - return mw.cw.c.copts.clientNoContextTakeover - } - return mw.cw.c.copts.serverNoContextTakeover -} - -func (cw *connWriter) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { - err := cw.mw.reset(ctx, typ) +func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { + err := c.msgWriter.reset(ctx, typ) if err != nil { return nil, err } - return cw.mw, nil + return c.msgWriter, nil } -func (cw *connWriter) write(ctx context.Context, typ MessageType, p []byte) (int, error) { - ww, err := cw.writer(ctx, typ) +func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) { + mw, err := c.writer(ctx, typ) if err != nil { return 0, err } - if !cw.c.deflateNegotiated() { + if !c.deflateNegotiated() { // Fast single frame path. - defer cw.mw.mu.Unlock() - return cw.frame(ctx, true, cw.mw.opcode, p) + defer c.msgWriter.mu.Unlock() + return c.writeFrame(ctx, true, c.msgWriter.opcode, p) } - n, err := ww.Write(p) + n, err := mw.Write(p) if err != nil { return n, err } - err = ww.Close() + err = mw.Close() return n, err } -type messageWriter struct { - cw *connWriter +type msgWriter struct { + c *Conn - mu mu - compress bool - tw *trimLastFourBytesWriter - fw *flate.Writer - ctx context.Context - opcode opcode - closed bool + mu mu + + deflate bool + ctx context.Context + opcode opcode + closed bool + + trimWriter *trimLastFourBytesWriter + flateWriter *flate.Writer } -func (mw *messageWriter) reset(ctx context.Context, typ MessageType) error { +func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error { err := mw.mu.Lock(ctx) if err != nil { return err @@ -155,30 +112,30 @@ func (mw *messageWriter) reset(ctx context.Context, typ MessageType) error { } // Write writes the given bytes to the WebSocket connection. -func (mw *messageWriter) Write(p []byte) (_ int, err error) { +func (mw *msgWriter) Write(p []byte) (_ int, err error) { defer errd.Wrap(&err, "failed to write") if mw.closed { return 0, errors.New("cannot use closed writer") } - if mw.cw.c.deflateNegotiated() { - if !mw.compress { + if mw.c.deflateNegotiated() { + if !mw.deflate { if !mw.contextTakeover() { mw.ensureFlateWriter() } - mw.tw.reset() - mw.compress = true + mw.trimWriter.reset() + mw.deflate = true } - return mw.fw.Write(p) + return mw.flateWriter.Write(p) } return mw.write(p) } -func (mw *messageWriter) write(p []byte) (int, error) { - n, err := mw.cw.frame(mw.ctx, false, mw.opcode, p) +func (mw *msgWriter) write(p []byte) (int, error) { + n, err := mw.c.writeFrame(mw.ctx, false, mw.opcode, p) if err != nil { return n, fmt.Errorf("failed to write data frame: %w", err) } @@ -187,8 +144,7 @@ func (mw *messageWriter) write(p []byte) (int, error) { } // Close flushes the frame to the connection. -// This must be called for every messageWriter. -func (mw *messageWriter) Close() (err error) { +func (mw *msgWriter) Close() (err error) { defer errd.Wrap(&err, "failed to close writer") if mw.closed { @@ -196,32 +152,39 @@ func (mw *messageWriter) Close() (err error) { } mw.closed = true - if mw.cw.c.deflateNegotiated() { - err = mw.fw.Flush() + if mw.c.deflateNegotiated() { + err = mw.flateWriter.Flush() if err != nil { return fmt.Errorf("failed to flush flate writer: %w", err) } } - _, err = mw.cw.frame(mw.ctx, true, mw.opcode, nil) + _, err = mw.c.writeFrame(mw.ctx, true, mw.opcode, nil) if err != nil { return fmt.Errorf("failed to write fin frame: %w", err) } - if mw.compress && !mw.contextTakeover() { - putFlateWriter(mw.fw) - mw.compress = false + if mw.deflate && !mw.contextTakeover() { + putFlateWriter(mw.flateWriter) + mw.deflate = false } mw.mu.Unlock() return nil } -func (cw *connWriter) control(ctx context.Context, opcode opcode, p []byte) error { +func (cw *msgWriter) close() { + if cw.c.deflateNegotiated() && cw.contextTakeover() { + cw.mu.Lock(context.Background()) + putFlateWriter(cw.flateWriter) + } +} + +func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { ctx, cancel := context.WithTimeout(ctx, time.Second*5) defer cancel() - _, err := cw.frame(ctx, true, opcode, p) + _, err := c.writeFrame(ctx, true, opcode, p) if err != nil { return fmt.Errorf("failed to write control frame %v: %w", opcode, err) } @@ -229,94 +192,94 @@ func (cw *connWriter) control(ctx context.Context, opcode opcode, p []byte) erro } // frame handles all writes to the connection. -func (cw *connWriter) frame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) { - err := cw.frameMu.Lock(ctx) +func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) { + err := c.writeFrameMu.Lock(ctx) if err != nil { return 0, err } - defer cw.frameMu.Unlock() + defer c.writeFrameMu.Unlock() select { - case <-cw.c.closed: - return 0, cw.c.closeErr - case cw.timeout <- ctx: + case <-c.closed: + return 0, c.closeErr + case c.writeTimeout <- ctx: } - cw.h.fin = fin - cw.h.opcode = opcode - cw.h.masked = cw.c.client - cw.h.payloadLength = int64(len(p)) - - cw.h.rsv1 = false - if cw.mw.compress && (opcode == opText || opcode == opBinary) { - cw.h.rsv1 = true - } + c.writeHeader.fin = fin + c.writeHeader.opcode = opcode + c.writeHeader.payloadLength = int64(len(p)) - if cw.h.masked { - err = binary.Read(rand.Reader, binary.LittleEndian, &cw.h.maskKey) + if c.client { + c.writeHeader.masked = true + err = binary.Read(rand.Reader, binary.LittleEndian, &c.writeHeader.maskKey) if err != nil { return 0, fmt.Errorf("failed to generate masking key: %w", err) } } - err = writeFrameHeader(cw.h, cw.bw) + c.writeHeader.rsv1 = false + if c.msgWriter.deflate && (opcode == opText || opcode == opBinary) { + c.writeHeader.rsv1 = true + } + + err = writeFrameHeader(c.writeHeader, c.bw) if err != nil { return 0, err } - n, err := cw.framePayload(p) + n, err := c.writeFramePayload(p) if err != nil { return n, err } - if cw.h.fin { - err = cw.bw.Flush() + if c.writeHeader.fin { + err = c.bw.Flush() if err != nil { return n, fmt.Errorf("failed to flush: %w", err) } } select { - case <-cw.c.closed: - return n, cw.c.closeErr - case cw.timeout <- context.Background(): + case <-c.closed: + return n, c.closeErr + case c.writeTimeout <- context.Background(): } return n, nil } -func (cw *connWriter) framePayload(p []byte) (_ int, err error) { +func (c *Conn) writeFramePayload(p []byte) (_ int, err error) { defer errd.Wrap(&err, "failed to write frame payload") - if !cw.h.masked { - return cw.bw.Write(p) + if !c.writeHeader.masked { + return c.bw.Write(p) } var n int - maskKey := cw.h.maskKey + maskKey := c.writeHeader.maskKey for len(p) > 0 { // If the buffer is full, we need to flush. - if cw.bw.Available() == 0 { - err = cw.bw.Flush() + if c.bw.Available() == 0 { + err = c.bw.Flush() if err != nil { return n, err } } // Start of next write in the buffer. - i := cw.bw.Buffered() + i := c.bw.Buffered() j := len(p) - if j > cw.bw.Available() { - j = cw.bw.Available() + if j > c.bw.Available() { + j = c.bw.Available() } - _, err := cw.bw.Write(p[:j]) + _, err := c.bw.Write(p[:j]) if err != nil { return n, err } - maskKey = mask(maskKey, cw.writeBuf[i:cw.bw.Buffered()]) + maskKey = mask(maskKey, c.writeBuf[i:c.bw.Buffered()]) p = p[j:] n += j diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go index 99996a6..36dd2df 100644 --- a/wsjson/wsjson.go +++ b/wsjson/wsjson.go @@ -5,7 +5,6 @@ import ( "context" "encoding/json" "fmt" - "nhooyr.io/websocket" "nhooyr.io/websocket/internal/bpool" "nhooyr.io/websocket/internal/errd" -- GitLab