good morning!!!!

Skip to content
Snippets Groups Projects
Unverified Commit 43cb01ea authored by Anmol Sethi's avatar Anmol Sethi
Browse files

Refactor read.go/write.go

parent 746140b8
No related branches found
No related tags found
No related merge requests found
......@@ -24,7 +24,7 @@ go get nhooyr.io/websocket
- Concurrent writes
- [Close handshake](https://godoc.org/nhooyr.io/websocket#Conn.Close)
- [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper
- [Pings](https://godoc.org/nhooyr.io/websocket#Conn.Ping)
- [Ping pong](https://godoc.org/nhooyr.io/websocket#Conn.Ping)
- [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression
- Compile to [Wasm](https://godoc.org/nhooyr.io/websocket#hdr-Wasm)
......@@ -88,26 +88,27 @@ c.Close(websocket.StatusNormalClosure, "")
[gorilla/websocket](https://github.com/gorilla/websocket) is a widely used and mature library.
Advantages of nhooyr.io/websocket:
- Minimal and idiomatic API
- Compare godoc of [nhooyr.io/websocket](https://godoc.org/nhooyr.io/websocket) with [gorilla/websocket](https://godoc.org/github.com/gorilla/websocket) side by side.
- [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper
- Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535))
- Full [context.Context](https://blog.golang.org/context) support
- Uses [net/http.Client](https://golang.org/pkg/net/http/#Client) for dialing
- Will enable easy HTTP/2 support in the future
- Gorilla writes directly to a net.Conn and so duplicates features from net/http.Client.
- Concurrent writes
- Close handshake ([gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448))
- Idiomatic [ping](https://godoc.org/nhooyr.io/websocket#Conn.Ping) API
- gorilla/websocket requires registering a pong callback and then sending a Ping
- Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432))
- Transparent buffer reuse with [wsjson](https://godoc.org/nhooyr.io/websocket/wsjson) and [wspb](https://godoc.org/nhooyr.io/websocket/wspb) subpackages
- [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go
- Gorilla's implementation depends on unsafe and is slower
- Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support
- Minimal and idiomatic API
- Compare godoc of [nhooyr.io/websocket](https://godoc.org/nhooyr.io/websocket) with [gorilla/websocket](https://godoc.org/github.com/gorilla/websocket) side by side.
- [net.Conn](https://godoc.org/nhooyr.io/websocket#NetConn) wrapper
- Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535))
- Full [context.Context](https://blog.golang.org/context) support
- Uses [net/http.Client](https://golang.org/pkg/net/http/#Client) for dialing
- Will enable easy HTTP/2 support in the future
- Gorilla writes directly to a net.Conn and so duplicates features from net/http.Client.
- Concurrent writes
- Close handshake ([gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448))
- Idiomatic [ping](https://godoc.org/nhooyr.io/websocket#Conn.Ping) API
- gorilla/websocket requires registering a pong callback and then sending a Ping
- Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432))
- Transparent message buffer reuse with [wsjson](https://godoc.org/nhooyr.io/websocket/wsjson) and [wspb](https://godoc.org/nhooyr.io/websocket/wspb) subpackages
- [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go
- Gorilla's implementation depends on unsafe and is slower
- Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support
- Gorilla only supports no context takeover mode
- [CloseRead](https://godoc.org/nhooyr.io/websocket#Conn.CloseRead) helper
- Actively maintained ([gorilla/websocket#370](https://github.com/gorilla/websocket/issues/370))
- [CloseRead](https://godoc.org/nhooyr.io/websocket#Conn.CloseRead) helper
- Actively maintained ([gorilla/websocket#370](https://github.com/gorilla/websocket/issues/370))
#### golang.org/x/net/websocket
......@@ -120,7 +121,7 @@ to nhooyr.io/websocket.
#### gobwas/ws
[gobwas/ws](https://github.com/gobwas/ws) has an extremely flexible API that allows it to be used
in an event driven style for performance. See the author's [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb).
in an event driven style for performance. See the author's [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb).
However when writing idiomatic Go, nhooyr.io/websocket will be faster and easier to use.
......
......@@ -4,12 +4,11 @@ import (
"context"
"crypto/rand"
"io"
"strings"
"testing"
"nhooyr.io/websocket"
"nhooyr.io/websocket/internal/assert"
"nhooyr.io/websocket/wsjson"
"strings"
"testing"
)
func randBytes(t *testing.T, n int) []byte {
......@@ -21,12 +20,15 @@ func randBytes(t *testing.T, n int) []byte {
func assertJSONEcho(t *testing.T, ctx context.Context, c *websocket.Conn, n int) {
t.Helper()
defer c.Close(websocket.StatusInternalError, "")
exp := randString(t, n)
err := wsjson.Write(ctx, c, exp)
assert.Success(t, err)
assertJSONRead(t, ctx, c, exp)
c.Close(websocket.StatusNormalClosure, "")
}
func assertJSONRead(t *testing.T, ctx context.Context, c *websocket.Conn, exp interface{}) {
......@@ -74,5 +76,10 @@ func assertSubprotocol(t *testing.T, c *websocket.Conn, exp string) {
func assertCloseStatus(t *testing.T, exp websocket.StatusCode, err error) {
t.Helper()
defer func() {
if t.Failed() {
t.Logf("error: %+v", err)
}
}()
assert.Equal(t, exp, websocket.CloseStatus(err), "StatusCode")
}
......@@ -7,9 +7,6 @@ import (
"fmt"
"log"
"nhooyr.io/websocket/internal/errd"
"time"
"nhooyr.io/websocket/internal/bpool"
)
// StatusCode represents a WebSocket status code.
......@@ -103,59 +100,58 @@ func (c *Conn) Close(code StatusCode, reason string) error {
func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) {
defer errd.Wrap(&err, "failed to close WebSocket")
err = c.cw.sendClose(code, reason)
err = c.writeClose(code, reason)
if err != nil {
return err
}
return c.cr.waitClose()
return c.waitClose()
}
func (cw *connWriter) error(code StatusCode, err error) {
cw.c.setCloseErr(err)
cw.sendClose(code, err.Error())
cw.c.closeWithErr(nil)
func (c *Conn) writeError(code StatusCode, err error) {
c.setCloseErr(err)
c.writeClose(code, err.Error())
c.closeWithErr(nil)
}
func (cw *connWriter) sendClose(code StatusCode, reason string) error {
func (c *Conn) writeClose(code StatusCode, reason string) error {
ce := CloseError{
Code: code,
Reason: reason,
}
cw.c.setCloseErr(fmt.Errorf("sent close frame: %w", ce))
c.setCloseErr(fmt.Errorf("sent close frame: %w", ce))
var p []byte
if ce.Code != StatusNoStatusRcvd {
p = ce.bytes()
}
return cw.control(context.Background(), opClose, p)
return c.writeControl(context.Background(), opClose, p)
}
func (cr *connReader) waitClose() error {
defer cr.c.closeWithErr(nil)
func (c *Conn) waitClose() error {
defer c.closeWithErr(nil)
return nil
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
err := cr.mu.Lock(ctx)
if err != nil {
return err
}
defer cr.mu.Unlock()
b := bpool.Get()
buf := b.Bytes()
buf = buf[:cap(buf)]
defer bpool.Put(b)
for {
// TODO
return nil
}
// ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
// defer cancel()
//
// err := cr.mu.Lock(ctx)
// if err != nil {
// return err
// }
// defer cr.mu.Unlock()
//
// b := bpool.Get()
// buf := b.Bytes()
// buf = buf[:cap(buf)]
// defer bpool.Put(b)
//
// for {
// return nil
// }
}
func parseClosePayload(p []byte) (CloseError, error) {
......@@ -230,11 +226,11 @@ func (ce CloseError) bytesErr() ([]byte, error) {
func (c *Conn) setCloseErr(err error) {
c.closeMu.Lock()
c.setCloseErrNoLock(err)
c.setCloseErrLocked(err)
c.closeMu.Unlock()
}
func (c *Conn) setCloseErrNoLock(err error) {
func (c *Conn) setCloseErrLocked(err error) {
if c.closeErr == nil {
c.closeErr = fmt.Errorf("WebSocket closed: %w", err)
}
......
......@@ -30,7 +30,7 @@ const (
// All methods may be called concurrently except for Reader and Read.
//
// You must always read from the connection. Otherwise control
// frames will not be handled. See the docs on Reader and CloseRead.
// frames will not be handled. See Reader and CloseRead.
//
// Be sure to call Close on the connection when you
// are finished with it to release associated resources.
......@@ -42,9 +42,22 @@ type Conn struct {
rwc io.ReadWriteCloser
client bool
copts *compressionOptions
br *bufio.Reader
bw *bufio.Writer
cr connReader
cw connWriter
readTimeout chan context.Context
writeTimeout chan context.Context
// Read state.
readMu mu
readControlBuf [maxControlPayload]byte
msgReader *msgReader
// Write state.
msgWriter *msgWriter
writeFrameMu mu
writeBuf []byte
writeHeader header
closed chan struct{}
......@@ -63,8 +76,8 @@ type connConfig struct {
client bool
copts *compressionOptions
bw *bufio.Writer
br *bufio.Reader
bw *bufio.Writer
}
func newConn(cfg connConfig) *Conn {
......@@ -73,13 +86,23 @@ func newConn(cfg connConfig) *Conn {
rwc: cfg.rwc,
client: cfg.client,
copts: cfg.copts,
br: cfg.br,
bw: cfg.bw,
readTimeout: make(chan context.Context),
writeTimeout: make(chan context.Context),
closed: make(chan struct{}),
activePings: make(map[string]chan<- struct{}),
}
c.cr.init(c, cfg.br)
c.cw.init(c, cfg.bw)
c.msgReader = newMsgReader(c)
c.closed = make(chan struct{})
c.activePings = make(map[string]chan<- struct{})
c.msgWriter = newMsgWriter(c)
if c.client {
c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc)
}
runtime.SetFinalizer(c, func(c *Conn) {
c.closeWithErr(errors.New("connection garbage collected"))
......@@ -90,6 +113,34 @@ func newConn(cfg connConfig) *Conn {
return c
}
func newMsgReader(c *Conn) *msgReader {
mr := &msgReader{
c: c,
fin: true,
}
mr.limitReader = newLimitReader(c, readerFunc(mr.read), 32768)
if c.deflateNegotiated() && mr.contextTakeover() {
mr.ensureFlateReader()
}
return mr
}
func newMsgWriter(c *Conn) *msgWriter {
mw := &msgWriter{
c: c,
}
mw.trimWriter = &trimLastFourBytesWriter{
w: writerFunc(mw.write),
}
if c.deflateNegotiated() && mw.contextTakeover() {
mw.ensureFlateWriter()
}
return mw
}
// Subprotocol returns the negotiated subprotocol.
// An empty string means the default protocol.
func (c *Conn) Subprotocol() string {
......@@ -105,7 +156,7 @@ func (c *Conn) closeWithErr(err error) {
}
close(c.closed)
runtime.SetFinalizer(c, nil)
c.setCloseErrNoLock(err)
c.setCloseErrLocked(err)
// Have to close after c.closed is closed to ensure any goroutine that wakes up
// from the connection being closed also sees that c.closed is closed and returns
......@@ -113,8 +164,18 @@ func (c *Conn) closeWithErr(err error) {
c.rwc.Close()
go func() {
c.cr.close()
c.cw.close()
if c.client {
c.writeFrameMu.Lock(context.Background())
putBufioWriter(c.bw)
}
c.msgWriter.close()
if c.client {
c.readMu.Lock(context.Background())
putBufioReader(c.br)
c.readMu.Unlock()
}
c.msgReader.close()
}()
}
......@@ -127,13 +188,12 @@ func (c *Conn) timeoutLoop() {
case <-c.closed:
return
case writeCtx = <-c.cw.timeout:
case readCtx = <-c.cr.timeout:
case writeCtx = <-c.writeTimeout:
case readCtx = <-c.readTimeout:
case <-readCtx.Done():
c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err()))
c.cw.error(StatusPolicyViolation, errors.New("timed out"))
return
go c.writeError(StatusPolicyViolation, errors.New("timed out"))
case <-writeCtx.Done():
c.closeWithErr(fmt.Errorf("write timed out: %w", writeCtx.Err()))
return
......@@ -175,7 +235,7 @@ func (c *Conn) ping(ctx context.Context, p string) error {
c.activePingsMu.Unlock()
}()
err := c.cw.control(ctx, opPing, []byte(p))
err := c.writeControl(ctx, opPing, []byte(p))
if err != nil {
return err
}
......
......@@ -25,6 +25,7 @@ func TestConn(t *testing.T) {
c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
Subprotocols: []string{"echo"},
InsecureSkipVerify: true,
// CompressionMode: websocket.CompressionDisabled,
})
assert.Success(t, err)
defer c.Close(websocket.StatusInternalError, "")
......@@ -41,12 +42,12 @@ func TestConn(t *testing.T) {
opts := &websocket.DialOptions{
Subprotocols: []string{"echo"},
// CompressionMode: websocket.CompressionDisabled,
}
opts.HTTPClient = s.Client()
c, _, err := websocket.Dial(ctx, wsURL, opts)
assert.Success(t, err)
assertJSONEcho(t, ctx, c, 2)
})
}
......
......@@ -23,7 +23,7 @@ func NotEqual(t testing.TB, exp, act interface{}, name string) {
func Success(t testing.TB, err error) {
t.Helper()
if err != nil {
t.Fatalf("unexpected error : %+v", err)
t.Fatalf("unexpected error: %+v", err)
}
}
......
package websocket
import (
"bufio"
"context"
"errors"
"fmt"
......@@ -14,41 +13,22 @@ import (
"nhooyr.io/websocket/internal/errd"
)
// Reader waits until there is a WebSocket data message to read
// from the connection.
// It returns the type of the message and a reader to read it.
// Reader reads from the connection until until there is a WebSocket
// data message to be read. It will handle ping, pong and close frames as appropriate.
//
// It returns the type of the message and an io.Reader to read it.
// The passed context will also bound the reader.
// Ensure you read to EOF otherwise the connection will hang.
//
// All returned errors will cause the connection
// to be closed so you do not need to write your own error message.
// This applies to the Read methods in the wsjson/wspb subpackages as well.
//
// You must read from the connection for control frames to be handled.
// Thus if you expect messages to take a long time to be responded to,
// you should handle such messages async to reading from the connection
// to ensure control frames are promptly handled.
//
// If you do not expect any data messages from the peer, call CloseRead.
// Call CloseRead if you do not expect any data messages from the peer.
//
// Only one Reader may be open at a time.
//
// If you need a separate timeout on the Reader call and then the message
// Read, use time.AfterFunc to cancel the context passed in early.
// See https://github.com/nhooyr/websocket/issues/87#issue-451703332
// Most users should not need this.
func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
typ, r, err := c.cr.reader(ctx)
if err != nil {
return 0, nil, fmt.Errorf("failed to get reader: %w", err)
}
return typ, r, nil
return c.reader(ctx)
}
// Read is a convenience method to read a single message from the connection.
//
// See the Reader method to reuse buffers or for streaming.
// The docs on Reader apply to this method as well.
// Read is a convenience method around Reader to read a single message
// from the connection.
func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
typ, r, err := c.Reader(ctx)
if err != nil {
......@@ -59,14 +39,17 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
return typ, b, err
}
// CloseRead will start a goroutine to read from the connection until it is closed or a data message
// is received. If a data message is received, the connection will be closed with StatusPolicyViolation.
// Since CloseRead reads from the connection, it will respond to ping, pong and close frames.
// After calling this method, you cannot read any data messages from the connection.
// CloseRead starts a goroutine to read from the connection until it is closed
// or a data message is received.
//
// Once CloseRead is called you cannot read any messages from the connection.
// The returned context will be cancelled when the connection is closed.
//
// Use this when you do not want to read data messages from the connection anymore but will
// want to write messages to it.
// If a data message is received, the connection will be closed with StatusPolicyViolation.
//
// Call CloseRead when you do not expect to read any more messages.
// Since it actively reads from the connection, it will ensure that ping, pong and close
// frames are responded to.
func (c *Conn) CloseRead(ctx context.Context) context.Context {
ctx, cancel := context.WithCancel(ctx)
go func() {
......@@ -84,60 +67,32 @@ func (c *Conn) CloseRead(ctx context.Context) context.Context {
//
// When the limit is hit, the connection will be closed with StatusMessageTooBig.
func (c *Conn) SetReadLimit(n int64) {
c.cr.mr.lr.limit.Store(n)
}
type connReader struct {
c *Conn
br *bufio.Reader
timeout chan context.Context
mu mu
controlPayloadBuf [maxControlPayload]byte
mr *msgReader
}
func (cr *connReader) init(c *Conn, br *bufio.Reader) {
cr.c = c
cr.br = br
cr.timeout = make(chan context.Context)
cr.mr = &msgReader{
cr: cr,
fin: true,
}
cr.mr.lr = newLimitReader(c, readerFunc(cr.mr.read), 32768)
if c.deflateNegotiated() && cr.contextTakeover() {
cr.ensureFlateReader()
}
c.msgReader.limitReader.setLimit(n)
}
func (cr *connReader) ensureFlateReader() {
cr.mr.fr = getFlateReader(readerFunc(cr.mr.read))
cr.mr.lr.reset(cr.mr.fr)
func (mr *msgReader) ensureFlateReader() {
mr.flateReader = getFlateReader(readerFunc(mr.read))
mr.limitReader.reset(mr.flateReader)
}
func (cr *connReader) close() {
cr.mu.Lock(context.Background())
if cr.c.client {
putBufioReader(cr.br)
}
if cr.c.deflateNegotiated() && cr.contextTakeover() {
putFlateReader(cr.mr.fr)
func (mr *msgReader) close() {
if mr.c.deflateNegotiated() && mr.contextTakeover() {
mr.c.readMu.Lock(context.Background())
putFlateReader(mr.flateReader)
mr.c.readMu.Unlock()
}
}
func (cr *connReader) contextTakeover() bool {
if cr.c.client {
return cr.c.copts.serverNoContextTakeover
func (mr *msgReader) contextTakeover() bool {
if mr.c.client {
return mr.c.copts.serverNoContextTakeover
}
return cr.c.copts.clientNoContextTakeover
return mr.c.copts.clientNoContextTakeover
}
func (cr *connReader) rsv1Illegal(h header) bool {
func (c *Conn) readRSV1Illegal(h header) bool {
// If compression is enabled, rsv1 is always illegal.
if !cr.c.deflateNegotiated() {
if !c.deflateNegotiated() {
return true
}
// rsv1 is only allowed on data frames beginning messages.
......@@ -147,26 +102,26 @@ func (cr *connReader) rsv1Illegal(h header) bool {
return false
}
func (cr *connReader) loop(ctx context.Context) (header, error) {
func (c *Conn) readLoop(ctx context.Context) (header, error) {
for {
h, err := cr.frameHeader(ctx)
h, err := c.readFrameHeader(ctx)
if err != nil {
return header{}, err
}
if h.rsv1 && cr.rsv1Illegal(h) || h.rsv2 || h.rsv3 {
if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 {
err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)
cr.c.cw.error(StatusProtocolError, err)
c.writeError(StatusProtocolError, err)
return header{}, err
}
if !cr.c.client && !h.masked {
if !c.client && !h.masked {
return header{}, errors.New("received unmasked frame from client")
}
switch h.opcode {
case opClose, opPing, opPong:
err = cr.control(ctx, h)
err = c.handleControl(ctx, h)
if err != nil {
// Pass through CloseErrors when receiving a close frame.
if h.opcode == opClose && CloseStatus(err) != -1 {
......@@ -178,95 +133,89 @@ func (cr *connReader) loop(ctx context.Context) (header, error) {
return h, nil
default:
err := fmt.Errorf("received unknown opcode %v", h.opcode)
cr.c.cw.error(StatusProtocolError, err)
c.writeError(StatusProtocolError, err)
return header{}, err
}
}
}
func (cr *connReader) frameHeader(ctx context.Context) (header, error) {
func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
select {
case <-cr.c.closed:
return header{}, cr.c.closeErr
case cr.timeout <- ctx:
case <-c.closed:
return header{}, c.closeErr
case c.readTimeout <- ctx:
}
h, err := readFrameHeader(cr.br)
h, err := readFrameHeader(c.br)
if err != nil {
select {
case <-cr.c.closed:
return header{}, cr.c.closeErr
case <-c.closed:
return header{}, c.closeErr
case <-ctx.Done():
return header{}, ctx.Err()
default:
cr.c.closeWithErr(err)
c.closeWithErr(err)
return header{}, err
}
}
select {
case <-cr.c.closed:
return header{}, cr.c.closeErr
case cr.timeout <- context.Background():
case <-c.closed:
return header{}, c.closeErr
case c.readTimeout <- context.Background():
}
return h, nil
}
func (cr *connReader) framePayload(ctx context.Context, p []byte) (int, error) {
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
select {
case <-cr.c.closed:
return 0, cr.c.closeErr
case cr.timeout <- ctx:
case <-c.closed:
return 0, c.closeErr
case c.readTimeout <- ctx:
}
n, err := io.ReadFull(cr.br, p)
n, err := io.ReadFull(c.br, p)
if err != nil {
select {
case <-cr.c.closed:
return n, cr.c.closeErr
case <-c.closed:
return n, c.closeErr
case <-ctx.Done():
return n, ctx.Err()
default:
err = fmt.Errorf("failed to read frame payload: %w", err)
cr.c.closeWithErr(err)
c.closeWithErr(err)
return n, err
}
}
select {
case <-cr.c.closed:
return n, cr.c.closeErr
case cr.timeout <- context.Background():
case <-c.closed:
return n, c.closeErr
case c.readTimeout <- context.Background():
}
return n, err
}
func (cr *connReader) control(ctx context.Context, h header) error {
if h.payloadLength < 0 {
err := fmt.Errorf("received header with negative payload length: %v", h.payloadLength)
cr.c.cw.error(StatusProtocolError, err)
return err
}
if h.payloadLength > maxControlPayload {
err := fmt.Errorf("received too big control frame at %v bytes", h.payloadLength)
cr.c.cw.error(StatusProtocolError, err)
func (c *Conn) handleControl(ctx context.Context, h header) error {
if h.payloadLength < 0 || h.payloadLength > maxControlPayload {
err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength)
c.writeError(StatusProtocolError, err)
return err
}
if !h.fin {
err := errors.New("received fragmented control frame")
cr.c.cw.error(StatusProtocolError, err)
c.writeError(StatusProtocolError, err)
return err
}
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
b := cr.controlPayloadBuf[:h.payloadLength]
_, err := cr.framePayload(ctx, b)
b := c.readControlBuf[:h.payloadLength]
_, err := c.readFramePayload(ctx, b)
if err != nil {
return err
}
......@@ -277,11 +226,11 @@ func (cr *connReader) control(ctx context.Context, h header) error {
switch h.opcode {
case opPing:
return cr.c.cw.control(ctx, opPong, b)
return c.writeControl(ctx, opPong, b)
case opPong:
cr.c.activePingsMu.Lock()
pong, ok := cr.c.activePings[string(b)]
cr.c.activePingsMu.Unlock()
c.activePingsMu.Lock()
pong, ok := c.activePings[string(b)]
c.activePingsMu.Unlock()
if ok {
close(pong)
}
......@@ -291,53 +240,56 @@ func (cr *connReader) control(ctx context.Context, h header) error {
ce, err := parseClosePayload(b)
if err != nil {
err = fmt.Errorf("received invalid close payload: %w", err)
cr.c.cw.error(StatusProtocolError, err)
c.writeError(StatusProtocolError, err)
return err
}
err = fmt.Errorf("received close frame: %w", ce)
cr.c.setCloseErr(err)
cr.c.cw.control(context.Background(), opClose, ce.bytes())
c.setCloseErr(err)
c.writeControl(context.Background(), opClose, ce.bytes())
return err
}
func (cr *connReader) reader(ctx context.Context) (MessageType, io.Reader, error) {
err := cr.mu.Lock(ctx)
func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) {
defer errd.Wrap(&err, "failed to get reader")
err = c.readMu.Lock(ctx)
if err != nil {
return 0, nil, err
}
defer cr.mu.Unlock()
defer c.readMu.Unlock()
if !cr.mr.fin {
if !c.msgReader.fin {
return 0, nil, errors.New("previous message not read to completion")
}
h, err := cr.loop(ctx)
h, err := c.readLoop(ctx)
if err != nil {
return 0, nil, err
}
if h.opcode == opContinuation {
err := errors.New("received continuation frame without text or binary frame")
cr.c.cw.error(StatusProtocolError, err)
c.writeError(StatusProtocolError, err)
return 0, nil, err
}
cr.mr.reset(ctx, h)
c.msgReader.reset(ctx, h)
return MessageType(h.opcode), cr.mr, nil
return MessageType(h.opcode), c.msgReader, nil
}
type msgReader struct {
cr *connReader
fr io.Reader
lr *limitReader
c *Conn
ctx context.Context
deflate bool
flateReader io.Reader
deflateTail strings.Reader
limitReader *limitReader
payloadLength int64
maskKey uint32
fin bool
......@@ -348,8 +300,8 @@ func (mr *msgReader) reset(ctx context.Context, h header) {
mr.deflate = h.rsv1
if mr.deflate {
mr.deflateTail.Reset(deflateMessageTail)
if !mr.cr.contextTakeover() {
mr.cr.ensureFlateReader()
if !mr.contextTakeover() {
mr.ensureFlateReader()
}
}
mr.setFrame(h)
......@@ -370,34 +322,42 @@ func (mr *msgReader) Read(p []byte) (_ int, err error) {
}
}()
err = mr.cr.mu.Lock(mr.ctx)
err = mr.c.readMu.Lock(mr.ctx)
if err != nil {
return 0, err
}
defer mr.cr.mu.Unlock()
defer mr.c.readMu.Unlock()
if mr.payloadLength == 0 && mr.fin {
if mr.cr.c.deflateNegotiated() && !mr.cr.contextTakeover() {
if mr.fr != nil {
putFlateReader(mr.fr)
mr.fr = nil
if mr.c.deflateNegotiated() && !mr.contextTakeover() {
if mr.flateReader != nil {
putFlateReader(mr.flateReader)
mr.flateReader = nil
}
}
return 0, io.EOF
}
return mr.lr.Read(p)
return mr.limitReader.Read(p)
}
func (mr *msgReader) read(p []byte) (int, error) {
if mr.payloadLength == 0 {
h, err := mr.cr.loop(mr.ctx)
if mr.fin {
if mr.deflate {
n, _ := mr.deflateTail.Read(p[:4])
return n, nil
}
return 0, io.EOF
}
h, err := mr.c.readLoop(mr.ctx)
if err != nil {
return 0, err
}
if h.opcode != opContinuation {
err := errors.New("received new data message without finishing the previous message")
mr.cr.c.cw.error(StatusProtocolError, err)
mr.c.writeError(StatusProtocolError, err)
return 0, err
}
mr.setFrame(h)
......@@ -407,14 +367,14 @@ func (mr *msgReader) read(p []byte) (int, error) {
p = p[:mr.payloadLength]
}
n, err := mr.cr.framePayload(mr.ctx, p)
n, err := mr.c.readFramePayload(mr.ctx, p)
if err != nil {
return n, err
}
mr.payloadLength -= int64(n)
if !mr.cr.c.client {
if !mr.c.client {
mr.maskKey = mask(mr.maskKey, p)
}
......@@ -442,10 +402,14 @@ func (lr *limitReader) reset(r io.Reader) {
lr.r = r
}
func (lr *limitReader) setLimit(limit int64) {
lr.limit.Store(limit)
}
func (lr *limitReader) Read(p []byte) (int, error) {
if lr.n <= 0 {
err := fmt.Errorf("read limited at %v bytes", lr.limit.Load())
lr.c.cw.error(StatusMessageTooBig, err)
lr.c.writeError(StatusMessageTooBig, err)
return 0, err
}
......
......@@ -24,7 +24,7 @@ import (
//
// Never close the returned writer twice.
func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
w, err := c.cw.writer(ctx, typ)
w, err := c.writer(ctx, typ)
if err != nil {
return nil, fmt.Errorf("failed to get writer: %w", err)
}
......@@ -38,111 +38,68 @@ func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, err
// If compression is disabled, then it is guaranteed to write the message
// in a single frame.
func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
_, err := c.cw.write(ctx, typ, p)
_, err := c.write(ctx, typ, p)
if err != nil {
return fmt.Errorf("failed to write msg: %w", err)
}
return nil
}
type connWriter struct {
c *Conn
bw *bufio.Writer
writeBuf []byte
mw *messageWriter
frameMu mu
h header
timeout chan context.Context
func (mw *msgWriter) ensureFlateWriter() {
mw.flateWriter = getFlateWriter(mw.trimWriter)
}
func (cw *connWriter) init(c *Conn, bw *bufio.Writer) {
cw.c = c
cw.bw = bw
if cw.c.client {
cw.writeBuf = extractBufioWriterBuf(cw.bw, c.rwc)
}
cw.timeout = make(chan context.Context)
cw.mw = &messageWriter{
cw: cw,
func (mw *msgWriter) contextTakeover() bool {
if mw.c.client {
return mw.c.copts.clientNoContextTakeover
}
cw.mw.tw = &trimLastFourBytesWriter{
w: writerFunc(cw.mw.write),
}
if cw.c.deflateNegotiated() && cw.mw.contextTakeover() {
cw.mw.ensureFlateWriter()
}
}
func (mw *messageWriter) ensureFlateWriter() {
mw.fw = getFlateWriter(mw.tw)
return mw.c.copts.serverNoContextTakeover
}
func (cw *connWriter) close() {
if cw.c.client {
cw.frameMu.Lock(context.Background())
putBufioWriter(cw.bw)
}
if cw.c.deflateNegotiated() && cw.mw.contextTakeover() {
cw.mw.mu.Lock(context.Background())
putFlateWriter(cw.mw.fw)
}
}
func (mw *messageWriter) contextTakeover() bool {
if mw.cw.c.client {
return mw.cw.c.copts.clientNoContextTakeover
}
return mw.cw.c.copts.serverNoContextTakeover
}
func (cw *connWriter) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
err := cw.mw.reset(ctx, typ)
func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
err := c.msgWriter.reset(ctx, typ)
if err != nil {
return nil, err
}
return cw.mw, nil
return c.msgWriter, nil
}
func (cw *connWriter) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
ww, err := cw.writer(ctx, typ)
func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
mw, err := c.writer(ctx, typ)
if err != nil {
return 0, err
}
if !cw.c.deflateNegotiated() {
if !c.deflateNegotiated() {
// Fast single frame path.
defer cw.mw.mu.Unlock()
return cw.frame(ctx, true, cw.mw.opcode, p)
defer c.msgWriter.mu.Unlock()
return c.writeFrame(ctx, true, c.msgWriter.opcode, p)
}
n, err := ww.Write(p)
n, err := mw.Write(p)
if err != nil {
return n, err
}
err = ww.Close()
err = mw.Close()
return n, err
}
type messageWriter struct {
cw *connWriter
type msgWriter struct {
c *Conn
mu mu
compress bool
tw *trimLastFourBytesWriter
fw *flate.Writer
ctx context.Context
opcode opcode
closed bool
mu mu
deflate bool
ctx context.Context
opcode opcode
closed bool
trimWriter *trimLastFourBytesWriter
flateWriter *flate.Writer
}
func (mw *messageWriter) reset(ctx context.Context, typ MessageType) error {
func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {
err := mw.mu.Lock(ctx)
if err != nil {
return err
......@@ -155,30 +112,30 @@ func (mw *messageWriter) reset(ctx context.Context, typ MessageType) error {
}
// Write writes the given bytes to the WebSocket connection.
func (mw *messageWriter) Write(p []byte) (_ int, err error) {
func (mw *msgWriter) Write(p []byte) (_ int, err error) {
defer errd.Wrap(&err, "failed to write")
if mw.closed {
return 0, errors.New("cannot use closed writer")
}
if mw.cw.c.deflateNegotiated() {
if !mw.compress {
if mw.c.deflateNegotiated() {
if !mw.deflate {
if !mw.contextTakeover() {
mw.ensureFlateWriter()
}
mw.tw.reset()
mw.compress = true
mw.trimWriter.reset()
mw.deflate = true
}
return mw.fw.Write(p)
return mw.flateWriter.Write(p)
}
return mw.write(p)
}
func (mw *messageWriter) write(p []byte) (int, error) {
n, err := mw.cw.frame(mw.ctx, false, mw.opcode, p)
func (mw *msgWriter) write(p []byte) (int, error) {
n, err := mw.c.writeFrame(mw.ctx, false, mw.opcode, p)
if err != nil {
return n, fmt.Errorf("failed to write data frame: %w", err)
}
......@@ -187,8 +144,7 @@ func (mw *messageWriter) write(p []byte) (int, error) {
}
// Close flushes the frame to the connection.
// This must be called for every messageWriter.
func (mw *messageWriter) Close() (err error) {
func (mw *msgWriter) Close() (err error) {
defer errd.Wrap(&err, "failed to close writer")
if mw.closed {
......@@ -196,32 +152,39 @@ func (mw *messageWriter) Close() (err error) {
}
mw.closed = true
if mw.cw.c.deflateNegotiated() {
err = mw.fw.Flush()
if mw.c.deflateNegotiated() {
err = mw.flateWriter.Flush()
if err != nil {
return fmt.Errorf("failed to flush flate writer: %w", err)
}
}
_, err = mw.cw.frame(mw.ctx, true, mw.opcode, nil)
_, err = mw.c.writeFrame(mw.ctx, true, mw.opcode, nil)
if err != nil {
return fmt.Errorf("failed to write fin frame: %w", err)
}
if mw.compress && !mw.contextTakeover() {
putFlateWriter(mw.fw)
mw.compress = false
if mw.deflate && !mw.contextTakeover() {
putFlateWriter(mw.flateWriter)
mw.deflate = false
}
mw.mu.Unlock()
return nil
}
func (cw *connWriter) control(ctx context.Context, opcode opcode, p []byte) error {
func (cw *msgWriter) close() {
if cw.c.deflateNegotiated() && cw.contextTakeover() {
cw.mu.Lock(context.Background())
putFlateWriter(cw.flateWriter)
}
}
func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
_, err := cw.frame(ctx, true, opcode, p)
_, err := c.writeFrame(ctx, true, opcode, p)
if err != nil {
return fmt.Errorf("failed to write control frame %v: %w", opcode, err)
}
......@@ -229,94 +192,94 @@ func (cw *connWriter) control(ctx context.Context, opcode opcode, p []byte) erro
}
// frame handles all writes to the connection.
func (cw *connWriter) frame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) {
err := cw.frameMu.Lock(ctx)
func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) {
err := c.writeFrameMu.Lock(ctx)
if err != nil {
return 0, err
}
defer cw.frameMu.Unlock()
defer c.writeFrameMu.Unlock()
select {
case <-cw.c.closed:
return 0, cw.c.closeErr
case cw.timeout <- ctx:
case <-c.closed:
return 0, c.closeErr
case c.writeTimeout <- ctx:
}
cw.h.fin = fin
cw.h.opcode = opcode
cw.h.masked = cw.c.client
cw.h.payloadLength = int64(len(p))
cw.h.rsv1 = false
if cw.mw.compress && (opcode == opText || opcode == opBinary) {
cw.h.rsv1 = true
}
c.writeHeader.fin = fin
c.writeHeader.opcode = opcode
c.writeHeader.payloadLength = int64(len(p))
if cw.h.masked {
err = binary.Read(rand.Reader, binary.LittleEndian, &cw.h.maskKey)
if c.client {
c.writeHeader.masked = true
err = binary.Read(rand.Reader, binary.LittleEndian, &c.writeHeader.maskKey)
if err != nil {
return 0, fmt.Errorf("failed to generate masking key: %w", err)
}
}
err = writeFrameHeader(cw.h, cw.bw)
c.writeHeader.rsv1 = false
if c.msgWriter.deflate && (opcode == opText || opcode == opBinary) {
c.writeHeader.rsv1 = true
}
err = writeFrameHeader(c.writeHeader, c.bw)
if err != nil {
return 0, err
}
n, err := cw.framePayload(p)
n, err := c.writeFramePayload(p)
if err != nil {
return n, err
}
if cw.h.fin {
err = cw.bw.Flush()
if c.writeHeader.fin {
err = c.bw.Flush()
if err != nil {
return n, fmt.Errorf("failed to flush: %w", err)
}
}
select {
case <-cw.c.closed:
return n, cw.c.closeErr
case cw.timeout <- context.Background():
case <-c.closed:
return n, c.closeErr
case c.writeTimeout <- context.Background():
}
return n, nil
}
func (cw *connWriter) framePayload(p []byte) (_ int, err error) {
func (c *Conn) writeFramePayload(p []byte) (_ int, err error) {
defer errd.Wrap(&err, "failed to write frame payload")
if !cw.h.masked {
return cw.bw.Write(p)
if !c.writeHeader.masked {
return c.bw.Write(p)
}
var n int
maskKey := cw.h.maskKey
maskKey := c.writeHeader.maskKey
for len(p) > 0 {
// If the buffer is full, we need to flush.
if cw.bw.Available() == 0 {
err = cw.bw.Flush()
if c.bw.Available() == 0 {
err = c.bw.Flush()
if err != nil {
return n, err
}
}
// Start of next write in the buffer.
i := cw.bw.Buffered()
i := c.bw.Buffered()
j := len(p)
if j > cw.bw.Available() {
j = cw.bw.Available()
if j > c.bw.Available() {
j = c.bw.Available()
}
_, err := cw.bw.Write(p[:j])
_, err := c.bw.Write(p[:j])
if err != nil {
return n, err
}
maskKey = mask(maskKey, cw.writeBuf[i:cw.bw.Buffered()])
maskKey = mask(maskKey, c.writeBuf[i:c.bw.Buffered()])
p = p[j:]
n += j
......
......@@ -5,7 +5,6 @@ import (
"context"
"encoding/json"
"fmt"
"nhooyr.io/websocket"
"nhooyr.io/websocket/internal/bpool"
"nhooyr.io/websocket/internal/errd"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment