diff --git a/.gitignore b/.gitignore deleted file mode 100644 index 7fffaa26ba2cb59a5768fb1bbbf7d04b45677f93..0000000000000000000000000000000000000000 --- a/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -wstest_reports -websocket.test -profs diff --git a/README.md b/README.md index 9c44e4e1445f99cf27aee868cf2b16af6fd69641..4199423cbcd86b120daaaf93a17e2b123e1ed30f 100644 --- a/README.md +++ b/README.md @@ -5,24 +5,22 @@ websocket is a minimal and idiomatic WebSocket library for Go. -This library is not final and the API is subject to change. - ## Install ```bash -go get nhooyr.io/websocket@v0.2.0 +go get nhooyr.io/websocket@v1.0.0 ``` ## Features - Minimal and idiomatic API -- Tiny codebase at 1400 lines +- Tiny codebase at 1700 lines - First class context.Context support - Thorough tests, fully passes the [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) - Zero dependencies outside of the stdlib for the core library - JSON and ProtoBuf helpers in the wsjson and wspb subpackages -- High performance -- Concurrent reads and writes out of the box +- Highly optimized by default +- Concurrent writes out of the box ## Roadmap @@ -88,8 +86,9 @@ c.Close(websocket.StatusNormalClosure, "") - net.Conn is never exposed as WebSocket over HTTP/2 will not have a net.Conn. - Using net/http's Client for dialing means we do not have to reinvent dialing hooks and configurations like other WebSocket libraries -- We do not support the compression extension because Go's compress/flate library is very memory intensive - and browsers do not handle WebSocket compression intelligently. See [#5](https://github.com/nhooyr/websocket/issues/5) +- We do not support the deflate compression extension because Go's compress/flate library + is very memory intensive and browsers do not handle WebSocket compression intelligently. + See [#5](https://github.com/nhooyr/websocket/issues/5) ## Comparison @@ -111,7 +110,7 @@ Just compare the godoc of The API for nhooyr/websocket has been designed such that there is only one way to do things which makes it easy to use correctly. Not only is the API simpler, the implementation is -only 1400 lines whereas gorilla/websocket is at 3500 lines. That's more code to maintain, +only 1700 lines whereas gorilla/websocket is at 3500 lines. That's more code to maintain, more code to test, more code to document and more surface area for bugs. The future of gorilla/websocket is also uncertain. See [gorilla/websocket#370](https://github.com/gorilla/websocket/issues/370). @@ -121,11 +120,23 @@ also uses net/http's Client and ResponseWriter directly for WebSocket handshakes gorilla/websocket writes its handshakes to the underlying net.Conn which means it has to reinvent hooks for TLS and proxies and prevents support of HTTP/2. -Some more advantages of nhooyr/websocket are that it supports concurrent reads, -writes and makes it very easy to close the connection with a status code and reason. +Some more advantages of nhooyr/websocket are that it supports concurrent writes and +makes it very easy to close the connection with a status code and reason. + +nhooyr/websocket also responds to pings, pongs and close frames in a separate goroutine so that +your application doesn't always need to read from the connection unless it expects a data message. +gorilla/websocket requires you to constantly read from the connection to respond to control frames +even if you don't expect the peer to send any messages. + +In terms of performance, the differences depend on your application code. nhooyr/websocket +reuses buffers efficiently out of the box if you use the wsjson and wspb subpackages whereas +gorilla/websocket does not. As mentioned above, nhooyr/websocket also supports concurrent +writers out of the box. -In terms of performance, the only difference is nhooyr/websocket is forced to use one extra -goroutine for context.Context support. Otherwise, they perform identically. +The only performance con to nhooyr/websocket is that uses two extra goroutines. One for +reading pings, pongs and close frames async to application code and another to support +context.Context cancellation. This costs 4 KB of memory which is cheap compared +to the benefits. ### x/net/websocket diff --git a/accept.go b/accept.go index a80f70aa97aeaf0815088cb72f1a4b3246c160a9..6e2141110a04526b99342e03567983d62bacdb1b 100644 --- a/accept.go +++ b/accept.go @@ -76,12 +76,12 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { } // Accept accepts a WebSocket handshake from a client and upgrades the -// the connection to WebSocket. +// the connection to a WebSocket. // // Accept will reject the handshake if the Origin domain is not the same as the Host unless // the InsecureSkipVerify option is set. // -// The returned connection will be bound by r.Context(). Use c.Context() to change +// The returned connection will be bound by r.Context(). Use conn.Context() to change // the bounding context. func Accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, error) { c, err := accept(w, r, opts) @@ -107,7 +107,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, hj, ok := w.(http.Hijacker) if !ok { - err = xerrors.New("response writer must implement http.Hijacker") + err = xerrors.New("passed ResponseWriter does not implement http.Hijacker") http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return nil, err } @@ -115,7 +115,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, w.Header().Set("Upgrade", "websocket") w.Header().Set("Connection", "Upgrade") - handleKey(w, r) + handleSecWebSocketKey(w, r) subproto := selectSubprotocol(r, opts.Subprotocols) if subproto != "" { @@ -163,7 +163,7 @@ func selectSubprotocol(r *http.Request, subprotocols []string) string { var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") -func handleKey(w http.ResponseWriter, r *http.Request) { +func handleSecWebSocketKey(w http.ResponseWriter, r *http.Request) { key := r.Header.Get("Sec-WebSocket-Key") h := sha1.New() h.Write([]byte(key)) @@ -185,5 +185,5 @@ func authenticateOrigin(r *http.Request) error { if strings.EqualFold(u.Host, r.Host) { return nil } - return xerrors.Errorf("request origin %q is not authorized for host %q", origin, r.Host) + return xerrors.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) } diff --git a/ci/bench/entrypoint.sh b/ci/bench/entrypoint.sh index 5f7dcf73c1dcc136cbf40b304d0b97c1e8872774..a8350c9d25c2677cf493bb41d587e86f0520f8f3 100755 --- a/ci/bench/entrypoint.sh +++ b/ci/bench/entrypoint.sh @@ -2,16 +2,14 @@ source ci/lib.sh || exit 1 -mkdir -p profs - -go test --vet=off --run=^$ -bench=. \ - -cpuprofile=profs/cpu \ - -memprofile=profs/mem \ - -blockprofile=profs/block \ - -mutexprofile=profs/mutex \ +go test --vet=off --run=^$ -bench=. -o=ci/out/websocket.test \ + -cpuprofile=ci/out/cpu.prof \ + -memprofile=ci/out/mem.prof \ + -blockprofile=ci/out/block.prof \ + -mutexprofile=ci/out/mutex.prof \ . set +x echo -echo "profiles are in ./profs +echo "profiles are in ./ci/out/*.prof keep in mind that every profiler Go provides is enabled so that may skew the benchmarks" diff --git a/ci/lint/entrypoint.sh b/ci/lint/entrypoint.sh index 09c3168322beecaa4a665a3f3e0d7180bea42432..62f7402245220d5c7073b2826e03d1f8485ecec1 100755 --- a/ci/lint/entrypoint.sh +++ b/ci/lint/entrypoint.sh @@ -7,5 +7,5 @@ source ci/lib.sh || exit 1 shellcheck ./**/*.sh ) -go vet -composites=false -lostcancel=false ./... +go vet ./... go run golang.org/x/lint/golint -set_exit_status ./... diff --git a/ci/out/.gitignore b/ci/out/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..72e8ffc0db8aad71a934dd11e5968bd5109e54b4 --- /dev/null +++ b/ci/out/.gitignore @@ -0,0 +1 @@ +* diff --git a/ci/test/entrypoint.sh b/ci/test/entrypoint.sh index 2a39593fa426ac6c18cb859a0236764dc9e0aec6..c9a0e80a0a572c38b5959ccd321d6b6202d6c406 100755 --- a/ci/test/entrypoint.sh +++ b/ci/test/entrypoint.sh @@ -2,8 +2,6 @@ source ci/lib.sh || exit 1 -mkdir -p profs - set +x echo echo "this step includes benchmarks for race detection and coverage purposes @@ -12,15 +10,15 @@ accurate numbers" echo set -x -go test -race -coverprofile=profs/coverage --vet=off -bench=. ./... -go tool cover -func=profs/coverage +go test -race -coverprofile=ci/out/coverage.prof --vet=off -bench=. ./... +go tool cover -func=ci/out/coverage.prof if [[ $CI ]]; then - bash <(curl -s https://codecov.io/bash) -f profs/coverage + bash <(curl -s https://codecov.io/bash) -f ci/out/coverage.prof else - go tool cover -html=profs/coverage -o=profs/coverage.html + go tool cover -html=ci/out/coverage.prof -o=ci/out/coverage.html set +x echo - echo "please open profs/coverage.html to see detailed test coverage stats" + echo "please open ci/out/coverage.html to see detailed test coverage stats" fi diff --git a/dial.go b/dial.go index 53acd32ce6207c4d19d1d9c2e53087dd754c2d17..64d2820d505225857cdb8317b1997b29b3d712a9 100644 --- a/dial.go +++ b/dial.go @@ -18,9 +18,9 @@ import ( // DialOptions represents the options available to pass to Dial. type DialOptions struct { // HTTPClient is the http client used for the handshake. - // Its Transport must use HTTP/1.1 and return writable bodies - // for WebSocket handshakes. This was introduced in Go 1.12. - // http.Transport does this all correctly. + // Its Transport must return writable bodies + // for WebSocket handshakes. + // http.Transport does this correctly beginning with Go 1.12. HTTPClient *http.Client // HTTPHeader specifies the HTTP headers included in the handshake request. @@ -30,7 +30,7 @@ type DialOptions struct { Subprotocols []string } -// We use this key for all client requests as the Sec-WebSocket-Key header is useless. +// We use this key for all client requests as the Sec-WebSocket-Key header doesn't do anything. // See https://stackoverflow.com/a/37074398/4283659. // We also use the same mask key for every message as it too does not make a difference. var secWebSocketKey = base64.StdEncoding.EncodeToString(make([]byte, 16)) @@ -108,7 +108,7 @@ func dial(ctx context.Context, u string, opts DialOptions) (_ *Conn, _ *http.Res rwc, ok := resp.Body.(io.ReadWriteCloser) if !ok { - return nil, resp, xerrors.Errorf("response body is not a read write closer: %T", rwc) + return nil, resp, xerrors.Errorf("response body is not a io.ReadWriteCloser: %T", rwc) } c := &Conn{ diff --git a/example_echo_test.go b/example_echo_test.go index 405c7a4167f66b1e1c1094335d137f6dc391f394..6923bc0490bc4380263a968737bb5d4b8df649dc 100644 --- a/example_echo_test.go +++ b/example_echo_test.go @@ -20,7 +20,7 @@ import ( // dials the server and then sends 5 different messages // and prints out the server's responses. func Example_echo() { - // First we listen on port 0, that means the OS will + // First we listen on port 0 which means the OS will // assign us a random free port. This is the listener // the server will serve on and the client will connect to. l, err := net.Listen("tcp", "localhost:0") @@ -51,7 +51,6 @@ func Example_echo() { // Now we dial the server, send the messages and echo the responses. err = client("ws://" + l.Addr().String()) - time.Sleep(time.Second) if err != nil { log.Fatalf("client failed: %v", err) } diff --git a/limitedreader.go b/limitedreader.go new file mode 100644 index 0000000000000000000000000000000000000000..63bf40c45efc378cb2f8fe6b59d7854a3b30361f --- /dev/null +++ b/limitedreader.go @@ -0,0 +1,34 @@ +package websocket + +import ( + "fmt" + "io" + + "golang.org/x/xerrors" +) + +type limitedReader struct { + c *Conn + r io.Reader + left int64 + limit int64 +} + +func (lr *limitedReader) Read(p []byte) (int, error) { + if lr.limit == 0 { + lr.limit = lr.left + } + + if lr.left <= 0 { + msg := fmt.Sprintf("read limited at %v bytes", lr.limit) + lr.c.Close(StatusPolicyViolation, msg) + return 0, xerrors.Errorf(msg) + } + + if int64(len(p)) > lr.left { + p = p[:lr.left] + } + n, err := lr.r.Read(p) + lr.left -= int64(n) + return n, err +} diff --git a/messagetype.go b/messagetype.go index 1fd9cd6e32c58f5267748b9d2946349e7a49b49b..6a1205ee3b9e3d6047383ffd8e3e6d70029a5787 100644 --- a/messagetype.go +++ b/messagetype.go @@ -13,3 +13,5 @@ const ( // MessageBinary is for binary messages like Protobufs. MessageBinary MessageType = MessageType(opBinary) ) + +// Above I've explicitly included the types of the constants for stringer. diff --git a/statuscode.go b/statuscode.go index c7b20367046bebea6d757ff0169cbdb833d7dd94..661c6693e9eb562ca65f0084f5aba0a6a16904a5 100644 --- a/statuscode.go +++ b/statuscode.go @@ -60,7 +60,7 @@ func parseClosePayload(p []byte) (CloseError, error) { } if len(p) < 2 { - return CloseError{}, xerrors.Errorf("close payload too small, cannot even contain the 2 byte status code") + return CloseError{}, xerrors.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) } ce := CloseError{ @@ -78,13 +78,13 @@ func parseClosePayload(p []byte) (CloseError, error) { // See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number // and https://tools.ietf.org/html/rfc6455#section-7.4.1 func validWireCloseCode(code StatusCode) bool { - if code >= StatusNormalClosure && code <= statusTLSHandshake { - switch code { - case 1004, StatusNoStatusRcvd, statusAbnormalClosure, statusTLSHandshake: - return false - default: - return true - } + switch code { + case 1004, StatusNoStatusRcvd, statusAbnormalClosure, statusTLSHandshake: + return false + } + + if code >= StatusNormalClosure && code <= StatusBadGateway { + return true } if code >= 3000 && code <= 4999 { return true diff --git a/websocket.go b/websocket.go index db2e82e70c78f9347f64e33ec967210727d3d800..d59812b82dce937a46dec912f241035ac8b1c3f0 100644 --- a/websocket.go +++ b/websocket.go @@ -11,17 +11,17 @@ import ( "runtime" "strconv" "sync" - "sync/atomic" "time" "golang.org/x/xerrors" ) // Conn represents a WebSocket connection. -// All methods may be called concurrently. +// All methods may be called concurrently except for Reader, Read +// and SetReadLimit. // // Please be sure to call Close on the connection when you -// are finished with it to release resources. +// are finished with it to release the associated resources. type Conn struct { subprotocol string br *bufio.Reader @@ -29,84 +29,38 @@ type Conn struct { closer io.Closer client bool + // read limit for a message in bytes. msgReadLimit int64 closeOnce sync.Once closeErr error closed chan struct{} - writeDataLock chan struct{} + // writeMsgLock is acquired to write a data message. + writeMsgLock chan struct{} + // writeFrameLock is acquired to write a single frame. + // Effectively meaning whoever holds it gets to write to bw. writeFrameLock chan struct{} - readMsgLock chan struct{} - readMsg chan header - readMsgDone chan struct{} + // Used to ensure the previous reader is read till EOF before allowing + // a new one. + previousReader *messageReader + // readFrameLock is acquired to read from bw. readFrameLock chan struct{} + // readMsg is used by messageReader to receive frames from + // readLoop. + readMsg chan header + // readMsgDone is used to tell the readLoop to continue after + // messageReader has read a frame. + readMsgDone chan struct{} setReadTimeout chan context.Context setWriteTimeout chan context.Context setConnContext chan context.Context getConnContext chan context.Context - pingListenerMu sync.Mutex - pingListener map[string]chan<- struct{} -} - -// Context returns a context derived from parent that will be cancelled -// when the connection is closed or broken. -// If the parent context is cancelled, the connection will be closed. -// -// This is an experimental API that may be removed in the future. -// Please let me know how you feel about it in https://github.com/nhooyr/websocket/issues/79 -func (c *Conn) Context(parent context.Context) context.Context { - select { - case <-c.closed: - ctx, cancel := context.WithCancel(parent) - cancel() - return ctx - case c.setConnContext <- parent: - } - - select { - case <-c.closed: - ctx, cancel := context.WithCancel(parent) - cancel() - return ctx - case ctx := <-c.getConnContext: - return ctx - } -} - -func (c *Conn) close(err error) { - c.closeOnce.Do(func() { - runtime.SetFinalizer(c, nil) - - cerr := c.closer.Close() - if err != nil { - cerr = err - } - - c.closeErr = xerrors.Errorf("websocket closed: %w", cerr) - - close(c.closed) - - // This ensures every goroutine that interacts - // with the conn closes before it can interact with the connection - c.readFrameLock <- struct{}{} - c.writeFrameLock <- struct{}{} - - // See comment in dial.go - if c.client { - returnBufioReader(c.br) - returnBufioWriter(c.bw) - } - }) -} - -// Subprotocol returns the negotiated subprotocol. -// An empty string means the default protocol. -func (c *Conn) Subprotocol() string { - return c.subprotocol + activePingsMu sync.Mutex + activePings map[string]chan<- struct{} } func (c *Conn) init() { @@ -114,20 +68,19 @@ func (c *Conn) init() { c.msgReadLimit = 32768 - c.writeDataLock = make(chan struct{}, 1) + c.writeMsgLock = make(chan struct{}, 1) c.writeFrameLock = make(chan struct{}, 1) + c.readFrameLock = make(chan struct{}, 1) c.readMsg = make(chan header) c.readMsgDone = make(chan struct{}) - c.readMsgLock = make(chan struct{}, 1) - c.readFrameLock = make(chan struct{}, 1) c.setReadTimeout = make(chan context.Context) c.setWriteTimeout = make(chan context.Context) c.setConnContext = make(chan context.Context) c.getConnContext = make(chan context.Context) - c.pingListener = make(map[string]chan<- struct{}) + c.activePings = make(map[string]chan<- struct{}) runtime.SetFinalizer(c, func(c *Conn) { c.close(xerrors.New("connection garbage collected")) @@ -137,73 +90,44 @@ func (c *Conn) init() { go c.readLoop() } -// We never mask inside here because our mask key is always 0,0,0,0. -// See comment on secWebSocketKey. -func (c *Conn) writeFrame(ctx context.Context, h header, p []byte) (err error) { - err = c.acquireLock(ctx, c.writeFrameLock) - if err != nil { - return err - } - defer c.releaseLock(c.writeFrameLock) - - select { - case <-ctx.Done(): - return ctx.Err() - case <-c.closed: - return c.closeErr - case c.setWriteTimeout <- ctx: - } - defer func() { - // We have to remove the write timeout, even if ctx is cancelled. - select { - case <-c.closed: - return - case c.setWriteTimeout <- context.Background(): - } - }() +// Subprotocol returns the negotiated subprotocol. +// An empty string means the default protocol. +func (c *Conn) Subprotocol() string { + return c.subprotocol +} - defer func() { - if err != nil { - // We need to always release the lock first before closing the connection to ensure - // the lock can be acquired inside close. - c.releaseLock(c.writeFrameLock) - c.close(err) - } - }() +func (c *Conn) close(err error) { + c.closeOnce.Do(func() { + runtime.SetFinalizer(c, nil) - h.masked = c.client - h.payloadLength = int64(len(p)) + c.closeErr = xerrors.Errorf("websocket closed: %w", err) + close(c.closed) - b2 := marshalHeader(h) - _, err = c.bw.Write(b2) - if err != nil { - return xerrors.Errorf("failed to write to connection: %w", err) - } - _, err = c.bw.Write(p) - if err != nil { - return xerrors.Errorf("failed to write to connection: %w", err) + // Have to close after c.closed is closed to ensure any goroutine that wakes up + // from the connection being closed also sees that c.closed is closed and returns + // closeErr. + c.closer.Close() - } + // See comment in dial.go + if c.client { + // By acquiring the locks, we ensure no goroutine will touch the bufio reader or writer + // and we can safely return them. + // Whenever a caller holds this lock and calls close, it ensures to release the lock to prevent + // a deadlock. + // As of now, this is in writeFrame, readFramePayload and readHeader. + c.readFrameLock <- struct{}{} + returnBufioReader(c.br) - if h.fin { - err := c.bw.Flush() - if err != nil { - return xerrors.Errorf("failed to write to connection: %w", err) + c.writeFrameLock <- struct{}{} + returnBufioWriter(c.bw) } - } - - return nil + }) } func (c *Conn) timeoutLoop() { readCtx := context.Background() writeCtx := context.Background() parentCtx := context.Background() - cancelCtx := func() {} - defer func() { - // We do not defer cancelCtx directly because its value may change. - cancelCtx() - }() for { select { @@ -219,8 +143,9 @@ func (c *Conn) timeoutLoop() { c.close(xerrors.Errorf("parent context cancelled: %w", parentCtx.Err())) return case parentCtx = <-c.setConnContext: - var ctx context.Context - ctx, cancelCtx = context.WithCancel(parentCtx) + ctx, cancelCtx := context.WithCancel(parentCtx) + defer cancelCtx() + select { case <-c.closed: return @@ -230,68 +155,92 @@ func (c *Conn) timeoutLoop() { } } -func (c *Conn) handleControl(h header) { - if h.payloadLength > maxControlFramePayload { - c.Close(StatusProtocolError, "control frame too large") - return +// Context returns a context derived from parent that will be cancelled +// when the connection is closed or broken. +// If the parent context is cancelled, the connection will be closed. +// +// This is an experimental API. +// Please let me know how you feel about it in https://github.com/nhooyr/websocket/issues/79 +func (c *Conn) Context(parent context.Context) context.Context { + select { + case <-c.closed: + ctx, cancel := context.WithCancel(parent) + cancel() + return ctx + case c.setConnContext <- parent: } - if !h.fin { - c.Close(StatusProtocolError, "control frame cannot be fragmented") - return + select { + case <-c.closed: + ctx, cancel := context.WithCancel(parent) + cancel() + return ctx + case ctx := <-c.getConnContext: + return ctx } +} - b := make([]byte, h.payloadLength) - _, err := io.ReadFull(c.br, b) - if err != nil { - c.close(xerrors.Errorf("failed to read control frame payload: %w", err)) - return +func (c *Conn) acquireLock(ctx context.Context, lock chan struct{}) error { + select { + case <-ctx.Done(): + var err error + switch lock { + case c.writeFrameLock, c.writeMsgLock: + err = xerrors.Errorf("could not acquire write lock: %v", ctx.Err()) + case c.readFrameLock: + err = xerrors.Errorf("could not acquire read lock: %v", ctx.Err()) + default: + panic(fmt.Sprintf("websocket: failed to acquire unknown lock: %v", ctx.Err())) + } + c.close(err) + return ctx.Err() + case <-c.closed: + return c.closeErr + case lock <- struct{}{}: + return nil } +} - if h.masked { - fastXOR(h.maskKey, 0, b) +func (c *Conn) releaseLock(lock chan struct{}) { + // Allow multiple releases. + select { + case <-lock: + default: } +} - switch h.opcode { - case opPing: - c.writePong(b) - case opPong: - c.pingListenerMu.Lock() - listener, ok := c.pingListener[string(b)] - c.pingListenerMu.Unlock() - if ok { - close(listener) - } - case opClose: - ce, err := parseClosePayload(b) +func (c *Conn) readLoop() { + for { + h, err := c.readTillMsg() if err != nil { - c.close(xerrors.Errorf("received invalid close payload: %w", err)) return } - if ce.Code == StatusNoStatusRcvd { - c.writeClose(nil, ce) - } else { - c.Close(ce.Code, ce.Reason) + + select { + case <-c.closed: + return + case c.readMsg <- h: + } + + select { + case <-c.closed: + return + case <-c.readMsgDone: } - default: - panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h)) } } -func (c *Conn) readTillData() (header, error) { +func (c *Conn) readTillMsg() (header, error) { for { - h, err := c.readHeader() + h, err := c.readFrameHeader() if err != nil { return header{}, err } if h.rsv1 || h.rsv2 || h.rsv3 { - ce := CloseError{ - Code: StatusProtocolError, - Reason: fmt.Sprintf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3), - } - c.Close(ce.Code, ce.Reason) - return header{}, ce + err := xerrors.Errorf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3) + c.Close(StatusProtocolError, err.Error()) + return header{}, err } if h.opcode.controlOp() { @@ -303,17 +252,14 @@ func (c *Conn) readTillData() (header, error) { case opBinary, opText, opContinuation: return h, nil default: - ce := CloseError{ - Code: StatusProtocolError, - Reason: fmt.Sprintf("unknown opcode %v", h.opcode), - } - c.Close(ce.Code, ce.Reason) - return header{}, ce + err := xerrors.Errorf("received unknown opcode %v", h.opcode) + c.Close(StatusProtocolError, err.Error()) + return header{}, err } } } -func (c *Conn) readHeader() (header, error) { +func (c *Conn) readFrameHeader() (header, error) { err := c.acquireLock(context.Background(), c.readFrameLock) if err != nil { return header{}, err @@ -322,146 +268,300 @@ func (c *Conn) readHeader() (header, error) { h, err := readHeader(c.br) if err != nil { - return header{}, xerrors.Errorf("failed to read header: %w", err) + err := xerrors.Errorf("failed to read header: %w", err) + c.releaseLock(c.readFrameLock) + c.close(err) + return header{}, err } return h, nil } -func (c *Conn) readLoop() { - for { - h, err := c.readTillData() +func (c *Conn) handleControl(h header) { + if h.payloadLength > maxControlFramePayload { + c.Close(StatusProtocolError, "control frame too large") + return + } + + if !h.fin { + c.Close(StatusProtocolError, "control frame cannot be fragmented") + return + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + b := make([]byte, h.payloadLength) + + _, err := c.readFramePayload(ctx, b) + if err != nil { + return + } + + if h.masked { + fastXOR(h.maskKey, 0, b) + } + + switch h.opcode { + case opPing: + c.writePong(b) + case opPong: + c.activePingsMu.Lock() + pong, ok := c.activePings[string(b)] + c.activePingsMu.Unlock() + if ok { + close(pong) + } + case opClose: + ce, err := parseClosePayload(b) if err != nil { - c.close(err) + c.close(xerrors.Errorf("received invalid close payload: %w", err)) return } + if ce.Code == StatusNoStatusRcvd { + c.writeClose(nil, ce) + } else { + c.Close(ce.Code, ce.Reason) + } + default: + panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h)) + } +} - select { - case <-c.closed: - return - case c.readMsg <- h: +// Reader waits until there is a WebSocket data message to read +// from the connection. +// It returns the type of the message and a reader to read it. +// The passed context will also bound the reader. +// Ensure you read to EOF otherwise the connection will hang. +// +// Control (ping, pong, close) frames will be handled automatically +// in a separate goroutine so if you do not expect any data messages, +// you do not need to read from the connection. However, if the peer +// sends a data message, further pings, pongs and close frames will not +// be read if you do not read the message from the connection. +// +// Only one Reader may be open at a time. +func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { + typ, r, err := c.reader(ctx) + if err != nil { + return 0, nil, xerrors.Errorf("failed to get reader: %w", err) + } + return typ, &limitedReader{ + c: c, + r: r, + left: c.msgReadLimit, + }, nil +} + +func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { + if c.previousReader != nil && c.previousReader.h != nil { + // The only way we know for sure the previous reader is not yet complete is + // if there is an active frame not yet fully read. + // Otherwise, a user may have read the last byte but not the EOF if the EOF + // is in the next frame so we check for that below. + return 0, nil, xerrors.Errorf("previous message not read to completion") + } + + select { + case <-c.closed: + return 0, nil, c.closeErr + case <-ctx.Done(): + return 0, nil, ctx.Err() + case h := <-c.readMsg: + if c.previousReader != nil && !c.previousReader.done { + if h.opcode != opContinuation { + err := xerrors.Errorf("received new data message without finishing the previous message") + c.Close(StatusProtocolError, err.Error()) + return 0, nil, err + } + + if !h.fin || h.payloadLength > 0 { + return 0, nil, xerrors.Errorf("previous message not read to completion") + } + + c.previousReader.done = true + + select { + case <-c.closed: + return 0, nil, c.closeErr + case c.readMsgDone <- struct{}{}: + } + + return c.reader(ctx) + } else if h.opcode == opContinuation { + err := xerrors.Errorf("received continuation frame not after data or text frame") + c.Close(StatusProtocolError, err.Error()) + return 0, nil, err } - select { - case <-c.closed: - return - case <-c.readMsgDone: + r := &messageReader{ + ctx: ctx, + c: c, + + h: &h, } + c.previousReader = r + return MessageType(h.opcode), r, nil } } -func (c *Conn) writePong(p []byte) error { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() +// messageReader enables reading a data frame from the WebSocket connection. +type messageReader struct { + ctx context.Context + c *Conn - err := c.writeMessage(ctx, opPong, p) - return err + h *header + maskPos int + done bool } -// Close closes the WebSocket connection with the given status code and reason. -// -// It will write a WebSocket close frame with a timeout of 5 seconds. -// The connection can only be closed once. Additional calls to Close -// are no-ops. -// -// The maximum length of reason must be 125 bytes otherwise an internal -// error will be sent to the peer. For this reason, you should avoid -// sending a dynamic reason. -// -// Close will unblock all goroutines interacting with the connection. -func (c *Conn) Close(code StatusCode, reason string) error { - err := c.exportedClose(code, reason) +// Read reads as many bytes as possible into p. +func (r *messageReader) Read(p []byte) (int, error) { + n, err := r.read(p) if err != nil { - return xerrors.Errorf("failed to close connection: %w", err) + // Have to return io.EOF directly for now, we cannot wrap as xerrors + // isn't used in stdlib. + if xerrors.Is(err, io.EOF) { + return n, io.EOF + } + return n, xerrors.Errorf("failed to read: %w", err) } - return nil + return n, nil } -func (c *Conn) exportedClose(code StatusCode, reason string) error { - ce := CloseError{ - Code: code, - Reason: reason, +func (r *messageReader) read(p []byte) (int, error) { + if r.done { + return 0, xerrors.Errorf("cannot use EOFed reader") } - // This function also will not wait for a close frame from the peer like the RFC - // wants because that makes no sense and I don't think anyone actually follows that. - // Definitely worth seeing what popular browsers do later. - p, err := ce.bytes() - if err != nil { - fmt.Fprintf(os.Stderr, "websocket: failed to marshal close frame: %v\n", err) - ce = CloseError{ - Code: StatusInternalError, + if r.h == nil { + select { + case <-r.c.closed: + return 0, r.c.closeErr + case <-r.ctx.Done(): + r.c.close(xerrors.Errorf("failed to read: %w", r.ctx.Err())) + return 0, r.ctx.Err() + case h := <-r.c.readMsg: + if h.opcode != opContinuation { + err := xerrors.Errorf("received new data frame without finishing the previous frame") + r.c.Close(StatusProtocolError, err.Error()) + return 0, err + } + r.h = &h } - p, _ = ce.bytes() } - return c.writeClose(p, ce) -} - -func (c *Conn) writeClose(p []byte, cerr CloseError) error { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() + if int64(len(p)) > r.h.payloadLength { + p = p[:r.h.payloadLength] + } - err := c.writeMessage(ctx, opClose, p) + n, err := r.c.readFramePayload(r.ctx, p) - c.close(cerr) + r.h.payloadLength -= int64(n) + if r.h.masked { + r.maskPos = fastXOR(r.h.maskKey, r.maskPos, p) + } if err != nil { - return err + return n, err } - if !xerrors.Is(c.closeErr, cerr) { - return c.closeErr + if r.h.payloadLength == 0 { + select { + case <-r.c.closed: + return n, r.c.closeErr + case r.c.readMsgDone <- struct{}{}: + } + + fin := r.h.fin + + // Need to nil this as Reader uses it to check + // whether there is active data on the previous reader and + // now there isn't. + r.h = nil + + if fin { + r.done = true + return n, io.EOF + } + + r.maskPos = 0 } - return nil + return n, nil } -func (c *Conn) acquireLock(ctx context.Context, lock chan struct{}) error { +func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { + err := c.acquireLock(ctx, c.readFrameLock) + if err != nil { + return 0, err + } + defer c.releaseLock(c.readFrameLock) + select { - case <-ctx.Done(): - return ctx.Err() case <-c.closed: - return c.closeErr - case lock <- struct{}{}: - return nil + return 0, c.closeErr + case c.setReadTimeout <- ctx: + } + + n, err := io.ReadFull(c.br, p) + if err != nil { + select { + case <-c.closed: + return n, c.closeErr + case <-ctx.Done(): + err = ctx.Err() + default: + } + err = xerrors.Errorf("failed to read from connection: %w", err) + c.releaseLock(c.readFrameLock) + c.close(err) + return n, err } -} -func (c *Conn) releaseLock(lock chan struct{}) { - // Allow multiple releases. select { - case <-lock: - default: + case <-c.closed: + return n, c.closeErr + case c.setReadTimeout <- context.Background(): } + + return n, err } -func (c *Conn) writeMessage(ctx context.Context, opcode opcode, p []byte) error { - if !opcode.controlOp() { - err := c.acquireLock(ctx, c.writeDataLock) - if err != nil { - return err - } - defer c.releaseLock(c.writeDataLock) - } +// SetReadLimit sets the max number of bytes to read for a single message. +// It applies to the Reader and Read methods. +// +// By default, the connection has a message read limit of 32768 bytes. +// +// When the limit is hit, the connection will be closed with StatusPolicyViolation. +func (c *Conn) SetReadLimit(n int64) { + c.msgReadLimit = n +} - err := c.writeFrame(ctx, header{ - fin: true, - opcode: opcode, - }, p) +// Read is a convenience method to read a single message from the connection. +// +// See the Reader method if you want to be able to reuse buffers or want to stream a message. +// The docs on Reader apply to this method as well. +// +// This is an experimental API, please let me know how you feel about it in +// https://github.com/nhooyr/websocket/issues/62 +func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { + typ, r, err := c.Reader(ctx) if err != nil { - return xerrors.Errorf("failed to write frame: %v", err) + return 0, nil, err } - return nil + + b, err := ioutil.ReadAll(r) + return typ, b, err } // Writer returns a writer bounded by the context that will write // a WebSocket message of type dataType to the connection. // -// Ensure you close the writer once you have written the entire message. -// Concurrent calls to Writer are ok. -// Only one writer can be open at a time so Writer will block if there is -// another goroutine with an open writer until that writer is closed. +// You must close the writer once you have written the entire message. +// +// Only one writer can be open at a time, multiple calls will block until the previous writer +// is closed. func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { wc, err := c.writer(ctx, typ) if err != nil { @@ -471,7 +571,7 @@ func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, err } func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { - err := c.acquireLock(ctx, c.writeDataLock) + err := c.acquireLock(ctx, c.writeMsgLock) if err != nil { return nil, err } @@ -482,34 +582,30 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err }, nil } -// Read is a convenience method to read a single message from the connection. +// Write is a convenience method to write a message to the connection. // -// See the Reader method if you want to be able to reuse buffers or want to stream a message. +// See the Writer method if you want to stream a message. The docs on Writer +// regarding concurrency also apply to this method. // // This is an experimental API, please let me know how you feel about it in // https://github.com/nhooyr/websocket/issues/62 -func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { - typ, r, err := c.Reader(ctx) +func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { + err := c.write(ctx, typ, p) if err != nil { - return 0, nil, err + return xerrors.Errorf("failed to write msg: %w", err) } + return nil +} - b, err := ioutil.ReadAll(r) +func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error { + err := c.acquireLock(ctx, c.writeMsgLock) if err != nil { - return typ, b, err + return err } + defer c.releaseLock(c.writeMsgLock) - return typ, b, nil -} - -// Write is a convenience method to write a message to the connection. -// -// See the Writer method if you want to stream a message. -// -// This is an experimental API, please let me know how you feel about it in -// https://github.com/nhooyr/websocket/issues/62 -func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { - return c.writeMessage(ctx, opcode(typ), p) + err = c.writeFrame(ctx, true, opcode(typ), p) + return err } // messageWriter enables writing to a WebSocket connection. @@ -533,11 +629,9 @@ func (w *messageWriter) write(p []byte) (int, error) { if w.closed { return 0, xerrors.Errorf("cannot use closed writer") } - err := w.c.writeFrame(w.ctx, header{ - opcode: w.opcode, - }, p) + err := w.c.writeFrame(w.ctx, false, w.opcode, p) if err != nil { - return 0, err + return 0, xerrors.Errorf("failed to write data frame: %w", err) } w.opcode = opContinuation return len(p), nil @@ -559,166 +653,155 @@ func (w *messageWriter) close() error { } w.closed = true - err := w.c.writeFrame(w.ctx, header{ - fin: true, - opcode: w.opcode, - }, nil) + err := w.c.writeFrame(w.ctx, true, w.opcode, nil) if err != nil { - return err + return xerrors.Errorf("failed to write fin frame: %w", err) } - w.c.releaseLock(w.c.writeDataLock) + w.c.releaseLock(w.c.writeMsgLock) return nil } -// Reader will wait until there is a WebSocket data message to read from the connection. -// It returns the type of the message and a reader to read it. -// The passed context will also bound the reader. -// -// Your application must keep reading messages for the Conn to automatically respond to ping -// and close frames and not become stuck waiting for a data message to be read. -// Please ensure to read the full message from io.Reader. -// -// You can only read a single message at a time so do not call this method -// concurrently. -func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { - typ, r, err := c.reader(ctx) +func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { + err := c.writeFrame(ctx, true, opcode, p) if err != nil { - return 0, nil, xerrors.Errorf("failed to get reader: %w", err) + return xerrors.Errorf("failed to write control frame: %w", err) } - return typ, io.LimitReader(r, c.msgReadLimit), nil + return nil } -func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) { - err = c.acquireLock(ctx, c.readMsgLock) +// writeFrame handles all writes to the connection. +// We never mask inside here because our mask key is always 0,0,0,0. +// See comment on secWebSocketKey for why. +func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte) error { + h := header{ + fin: fin, + opcode: opcode, + masked: c.client, + payloadLength: int64(len(p)), + } + b2 := marshalHeader(h) + + err := c.acquireLock(ctx, c.writeFrameLock) if err != nil { - return 0, nil, err + return err } + defer c.releaseLock(c.writeFrameLock) select { case <-c.closed: - return 0, nil, c.closeErr - case <-ctx.Done(): - return 0, nil, ctx.Err() - case h := <-c.readMsg: - if h.opcode == opContinuation { - ce := CloseError{ - Code: StatusProtocolError, - Reason: "continuation frame not after data or text frame", - } - c.Close(ce.Code, ce.Reason) - return 0, nil, ce - } - return MessageType(h.opcode), &messageReader{ - ctx: ctx, - h: &h, - c: c, - }, nil + return c.closeErr + case c.setWriteTimeout <- ctx: } -} - -// messageReader enables reading a data frame from the WebSocket connection. -type messageReader struct { - ctx context.Context - maskPos int - h *header - c *Conn - eofed bool -} -// Read reads as many bytes as possible into p. -func (r *messageReader) Read(p []byte) (int, error) { - n, err := r.read(p) - if err != nil { - // Have to return io.EOF directly for now, we cannot wrap as xerrors - // isn't used in stdlib. - if xerrors.Is(err, io.EOF) { - return n, io.EOF + writeErr := func(err error) error { + select { + case <-c.closed: + return c.closeErr + case <-ctx.Done(): + err = ctx.Err() + default: } - return n, xerrors.Errorf("failed to read: %w", err) - } - return n, nil -} -func (r *messageReader) read(p []byte) (int, error) { - if r.eofed { - return 0, xerrors.Errorf("cannot use EOFed reader") + err = xerrors.Errorf("failed to write to connection: %w", err) + // We need to release the lock first before closing the connection to ensure + // the lock can be acquired inside close to ensure no one can access c.bw. + c.releaseLock(c.writeFrameLock) + c.close(err) + + return err } - if r.h == nil { - select { - case <-r.c.closed: - return 0, r.c.closeErr - case h := <-r.c.readMsg: - if h.opcode != opContinuation { - ce := CloseError{ - Code: StatusProtocolError, - Reason: "cannot read new data frame when previous frame is not finished", - } - r.c.Close(ce.Code, ce.Reason) - return 0, ce - } - r.h = &h - } + _, err = c.bw.Write(b2) + if err != nil { + return writeErr(err) + } + _, err = c.bw.Write(p) + if err != nil { + return writeErr(err) } - if int64(len(p)) > r.h.payloadLength { - p = p[:r.h.payloadLength] + if fin { + err = c.bw.Flush() + if err != nil { + return writeErr(err) + } } + // We already finished writing, no need to potentially brick the connection if + // the context expires. select { - case <-r.c.closed: - return 0, r.c.closeErr - case r.c.setReadTimeout <- r.ctx: + case <-c.closed: + return c.closeErr + case c.setWriteTimeout <- context.Background(): } - err := r.c.acquireLock(r.ctx, r.c.readFrameLock) + return nil +} + +func (c *Conn) writePong(p []byte) error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + err := c.writeControl(ctx, opPong, p) + return err +} + +// Close closes the WebSocket connection with the given status code and reason. +// +// It will write a WebSocket close frame with a timeout of 5 seconds. +// The connection can only be closed once. Additional calls to Close +// are no-ops. +// +// The maximum length of reason must be 125 bytes otherwise an internal +// error will be sent to the peer. For this reason, you should avoid +// sending a dynamic reason. +// +// Close will unblock all goroutines interacting with the connection. +func (c *Conn) Close(code StatusCode, reason string) error { + err := c.exportedClose(code, reason) if err != nil { - return 0, err + return xerrors.Errorf("failed to close connection: %w", err) } - n, err := io.ReadFull(r.c.br, p) - r.c.releaseLock(r.c.readFrameLock) + return nil +} - select { - case <-r.c.closed: - return 0, r.c.closeErr - case r.c.setReadTimeout <- context.Background(): +func (c *Conn) exportedClose(code StatusCode, reason string) error { + ce := CloseError{ + Code: code, + Reason: reason, } - r.h.payloadLength -= int64(n) - if r.h.masked { - r.maskPos = fastXOR(r.h.maskKey, r.maskPos, p) + // This function also will not wait for a close frame from the peer like the RFC + // wants because that makes no sense and I don't think anyone actually follows that. + // Definitely worth seeing what popular browsers do later. + p, err := ce.bytes() + if err != nil { + fmt.Fprintf(os.Stderr, "websocket: failed to marshal close frame: %v\n", err) + ce = CloseError{ + Code: StatusInternalError, + } + p, _ = ce.bytes() } + return c.writeClose(p, ce) +} + +func (c *Conn) writeClose(p []byte, cerr CloseError) error { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + err := c.writeControl(ctx, opClose, p) if err != nil { - r.c.close(xerrors.Errorf("failed to read control frame payload: %w", err)) - return n, r.c.closeErr + return err } - if r.h.payloadLength == 0 { - select { - case <-r.c.closed: - return n, r.c.closeErr - case r.c.readMsgDone <- struct{}{}: - } - if r.h.fin { - r.eofed = true - r.c.releaseLock(r.c.readMsgLock) - return n, io.EOF - } - r.maskPos = 0 - r.h = nil + c.close(cerr) + if !xerrors.Is(c.closeErr, cerr) { + return c.closeErr } - return n, nil -} - -// SetReadLimit sets the max number of bytes to read for a single message. -// It applies to the Reader and Read methods. -// -// By default, the connection has a message read limit of 32768 bytes. -func (c *Conn) SetReadLimit(n int64) { - atomic.StoreInt64(&c.msgReadLimit, n) + return nil } func init() { @@ -728,7 +811,7 @@ func init() { // Ping sends a ping to the peer and waits for a pong. // Use this to measure latency or ensure the peer is responsive. // -// This API is experimental and subject to change. +// This API is experimental. // Please provide feedback in https://github.com/nhooyr/websocket/issues/1. func (c *Conn) Ping(ctx context.Context) error { err := c.ping(ctx) @@ -744,23 +827,26 @@ func (c *Conn) ping(ctx context.Context) error { pong := make(chan struct{}) - c.pingListenerMu.Lock() - c.pingListener[p] = pong - c.pingListenerMu.Unlock() + c.activePingsMu.Lock() + c.activePings[p] = pong + c.activePingsMu.Unlock() defer func() { - c.pingListenerMu.Lock() - delete(c.pingListener, p) - c.pingListenerMu.Unlock() + c.activePingsMu.Lock() + delete(c.activePings, p) + c.activePingsMu.Unlock() }() - err := c.writeMessage(ctx, opPing, []byte(p)) + err := c.writeControl(ctx, opPing, []byte(p)) if err != nil { return err } select { + case <-c.closed: + return c.closeErr case <-ctx.Done(): + c.close(xerrors.Errorf("failed to ping: %w", ctx.Err())) return ctx.Err() case <-pong: return nil diff --git a/websocket_test.go b/websocket_test.go index f1905c30fbb93ecdd813deff9a592ef3cc1642e2..b1c5b9d485d74243c89cd31c1bcf1138ff48325e 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -390,6 +390,11 @@ func TestHandshake(t *testing.T) { return err } + err = c.Write(r.Context(), websocket.MessageText, []byte("hi")) + if err != nil { + return err + } + c.Close(websocket.StatusNormalClosure, "") return nil }, @@ -405,10 +410,52 @@ func TestHandshake(t *testing.T) { return err } + _, _, err = c.Read(ctx) + if err != nil { + return err + } + c.Close(websocket.StatusNormalClosure, "") return nil }, }, + { + name: "readLimit", + server: func(w http.ResponseWriter, r *http.Request) error { + c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + if err != nil { + return err + } + defer c.Close(websocket.StatusInternalError, "") + + _, _, err = c.Read(r.Context()) + if err == nil { + return xerrors.Errorf("expected error but got nil") + } + return nil + }, + client: func(ctx context.Context, u string) error { + c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{}) + if err != nil { + return err + } + defer c.Close(websocket.StatusInternalError, "") + + err = c.Write(ctx, websocket.MessageBinary, []byte(strings.Repeat("x", 32769))) + if err != nil { + return err + } + + err = c.Ping(ctx) + + var ce websocket.CloseError + if !xerrors.As(err, &ce) || ce.Code != websocket.StatusPolicyViolation { + return xerrors.Errorf("unexpected error: %w", err) + } + + return nil + }, + }, } for _, tc := range testCases { @@ -477,17 +524,20 @@ func TestAutobahnServer(t *testing.T) { defer s.Close() spec := map[string]interface{}{ - "outdir": "wstest_reports/server", + "outdir": "ci/out/wstestServerReports", "servers": []interface{}{ map[string]interface{}{ "agent": "main", "url": strings.Replace(s.URL, "http", "ws", 1), }, }, - "cases": []string{"*"}, + "cases": []string{"*"}, + // We skip the UTF-8 handling tests as there isn't any reason to reject invalid UTF-8, just + // more performance overhead. 7.5.1 is the same. + // 12.* and 13.* as we do not support compression. "exclude-cases": []string{"6.*", "7.5.1", "12.*", "13.*"}, } - specFile, err := ioutil.TempFile("", "websocket_fuzzingclient.json") + specFile, err := ioutil.TempFile("", "websocketFuzzingClient.json") if err != nil { t.Fatalf("failed to create temp file for fuzzingclient.json: %v", err) } @@ -516,7 +566,7 @@ func TestAutobahnServer(t *testing.T) { t.Fatalf("failed to run wstest: %v\nout:\n%s", err, out) } - checkWSTestIndex(t, "./wstest_reports/server/index.json") + checkWSTestIndex(t, "./ci/out/wstestServerReports/index.json") } func echoLoop(ctx context.Context, c *websocket.Conn) { @@ -593,12 +643,13 @@ func TestAutobahnClient(t *testing.T) { t.Parallel() spec := map[string]interface{}{ - "url": "ws://localhost:9001", - "outdir": "wstest_reports/client", - "cases": []string{"*"}, + "url": "ws://localhost:9001", + "outdir": "ci/out/wstestClientReports", + "cases": []string{"*"}, + // See TestAutobahnServer for the reasons why we exclude these. "exclude-cases": []string{"6.*", "7.5.1", "12.*", "13.*"}, } - specFile, err := ioutil.TempFile("", "websocket_fuzzingserver.json") + specFile, err := ioutil.TempFile("", "websocketFuzzingServer.json") if err != nil { t.Fatalf("failed to create temp file for fuzzingserver.json: %v", err) } @@ -682,7 +733,7 @@ func TestAutobahnClient(t *testing.T) { } c.Close(websocket.StatusNormalClosure, "") - checkWSTestIndex(t, "./wstest_reports/client/index.json") + checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json") } func checkWSTestIndex(t *testing.T, path string) { diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go index d85700bc55c458a7f4ba44bf10806201800e4f4f..994ffad194cd0a4b570ecd89559a04edbb23f302 100644 --- a/wsjson/wsjson.go +++ b/wsjson/wsjson.go @@ -1,10 +1,9 @@ -// Package wsjson provides helpers for JSON messages. +// Package wsjson provides websocket helpers for JSON messages. package wsjson import ( "context" "encoding/json" - "io" "golang.org/x/xerrors" @@ -12,6 +11,8 @@ import ( ) // Read reads a json message from c into v. +// If the message is larger than 128 bytes, it will use a buffer +// from a pool instead of performing an allocation. func Read(ctx context.Context, c *websocket.Conn, v interface{}) error { err := read(ctx, c, v) if err != nil { @@ -21,7 +22,7 @@ func Read(ctx context.Context, c *websocket.Conn, v interface{}) error { } func read(ctx context.Context, c *websocket.Conn, v interface{}) error { - typ, r, err := c.Reader(ctx) + typ, b, err := c.Read(ctx) if err != nil { return err } @@ -31,27 +32,16 @@ func read(ctx context.Context, c *websocket.Conn, v interface{}) error { return xerrors.Errorf("unexpected frame type for json (expected %v): %v", websocket.MessageText, typ) } - d := json.NewDecoder(r) - err = d.Decode(v) + err = json.Unmarshal(b, v) if err != nil { - return xerrors.Errorf("failed to decode json: %w", err) - } - - // Have to ensure we read till EOF. - // Unfortunate but necessary evil for now. Can improve later. - // The code to do this automatically gets complicated fast because - // we support concurrent reading. - // So the Reader has to synchronize with Read somehow. - // Maybe its best to bring back the old readLoop? - _, err = r.Read([]byte{0}) - if !xerrors.Is(err, io.EOF) { - return xerrors.Errorf("more data than needed in reader") + return xerrors.Errorf("failed to unmarshal json: %w", err) } return nil } // Write writes the json message v to c. +// It uses json.Encoder which automatically reuses buffers. func Write(ctx context.Context, c *websocket.Conn, v interface{}) error { err := write(ctx, c, v) if err != nil { @@ -66,6 +56,8 @@ func write(ctx context.Context, c *websocket.Conn, v interface{}) error { return err } + // We use Encode because it automatically enables buffer reuse without us + // needing to do anything. Though see https://github.com/golang/go/issues/27735 e := json.NewEncoder(w) err = e.Encode(v) if err != nil { diff --git a/wspb/wspb.go b/wspb/wspb.go index edffede18b98ea905e6ac23c9c287c46fcd421e5..e6c9169309bb9d18674cfa619983f9a3ab59ce0d 100644 --- a/wspb/wspb.go +++ b/wspb/wspb.go @@ -1,9 +1,8 @@ -// Package wspb provides helpers for protobuf messages. +// Package wspb provides websocket helpers for protobuf messages. package wspb import ( "context" - "io/ioutil" "github.com/golang/protobuf/proto" "golang.org/x/xerrors" @@ -12,6 +11,7 @@ import ( ) // Read reads a protobuf message from c into v. +// It will reuse buffers to avoid allocations. func Read(ctx context.Context, c *websocket.Conn, v proto.Message) error { err := read(ctx, c, v) if err != nil { @@ -21,7 +21,7 @@ func Read(ctx context.Context, c *websocket.Conn, v proto.Message) error { } func read(ctx context.Context, c *websocket.Conn, v proto.Message) error { - typ, r, err := c.Reader(ctx) + typ, b, err := c.Read(ctx) if err != nil { return err } @@ -31,11 +31,6 @@ func read(ctx context.Context, c *websocket.Conn, v proto.Message) error { return xerrors.Errorf("unexpected frame type for protobuf (expected %v): %v", websocket.MessageBinary, typ) } - b, err := ioutil.ReadAll(r) - if err != nil { - return xerrors.Errorf("failed to read message: %w", err) - } - err = proto.Unmarshal(b, v) if err != nil { return xerrors.Errorf("failed to unmarshal protobuf: %w", err) @@ -45,6 +40,7 @@ func read(ctx context.Context, c *websocket.Conn, v proto.Message) error { } // Write writes the protobuf message v to c. +// It will reuse buffers to avoid allocations. func Write(ctx context.Context, c *websocket.Conn, v proto.Message) error { err := write(ctx, c, v) if err != nil {