From 531d4fab2b30955df6ca43aea0417eb7aa60d515 Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Tue, 12 Nov 2019 11:15:17 -0500 Subject: [PATCH] Improve general compression API and write docs --- README.md | 46 +- accept.go | 330 +++++++++ handshake_test.go => accept_test.go | 143 ---- assert_test.go | 56 +- ci/fmt.mk | 2 +- close.go | 181 +++++ close_test.go | 196 ++++++ compress.go | 78 +++ conn.go | 297 +++++--- dial.go | 219 ++++++ dial_test.go | 149 ++++ doc.go | 2 +- frame.go | 445 ------------ frame_test.go | 457 ------------- handshake.go | 637 ------------------ internal/assert/assert.go | 18 +- internal/atomicint/atomicint.go | 32 + internal/{bpool/bpool.go => bufpool/buf.go} | 2 +- .../bpool_test.go => bufpool/buf_test.go} | 2 +- internal/bufpool/bufio.go | 40 ++ internal/wsframe/frame.go | 194 ++++++ .../wsframe/frame_stringer.go | 20 +- internal/wsframe/frame_test.go | 157 +++++ internal/wsframe/mask.go | 128 ++++ internal/wsframe/mask_test.go | 118 ++++ js_test.go | 50 ++ conn_common.go => netconn.go | 78 --- reader.go | 31 + websocket_js_test.go | 52 -- writer.go | 5 + websocket_js.go => ws_js.go | 58 +- ws_js_test.go | 22 + wsjson/wsjson.go | 7 +- wspb/wspb.go | 10 +- 34 files changed, 2243 insertions(+), 2019 deletions(-) create mode 100644 accept.go rename handshake_test.go => accept_test.go (62%) create mode 100644 close.go create mode 100644 close_test.go create mode 100644 compress.go create mode 100644 dial.go create mode 100644 dial_test.go delete mode 100644 frame.go delete mode 100644 frame_test.go delete mode 100644 handshake.go create mode 100644 internal/atomicint/atomicint.go rename internal/{bpool/bpool.go => bufpool/buf.go} (95%) rename internal/{bpool/bpool_test.go => bufpool/buf_test.go} (97%) create mode 100644 internal/bufpool/bufio.go create mode 100644 internal/wsframe/frame.go rename frame_stringer.go => internal/wsframe/frame_stringer.go (90%) create mode 100644 internal/wsframe/frame_test.go create mode 100644 internal/wsframe/mask.go create mode 100644 internal/wsframe/mask_test.go create mode 100644 js_test.go rename conn_common.go => netconn.go (60%) create mode 100644 reader.go delete mode 100644 websocket_js_test.go create mode 100644 writer.go rename websocket_js.go => ws_js.go (88%) create mode 100644 ws_js_test.go diff --git a/README.md b/README.md index b5adc59..17c7c83 100644 --- a/README.md +++ b/README.md @@ -22,13 +22,14 @@ go get nhooyr.io/websocket - [Zero dependencies](https://godoc.org/nhooyr.io/websocket?imports) - JSON and ProtoBuf helpers in the [wsjson](https://godoc.org/nhooyr.io/websocket/wsjson) and [wspb](https://godoc.org/nhooyr.io/websocket/wspb) subpackages - Highly optimized by default + - Zero alloc reads and writes - Concurrent writes out of the box - [Complete Wasm](https://godoc.org/nhooyr.io/websocket#hdr-Wasm) support - [Close handshake](https://godoc.org/nhooyr.io/websocket#Conn.Close) +- Full support of [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression extension ## Roadmap -- [ ] Compression Extensions [#163](https://github.com/nhooyr/websocket/pull/163) - [ ] HTTP/2 [#4](https://github.com/nhooyr/websocket/issues/4) ## Examples @@ -84,22 +85,12 @@ if err != nil { c.Close(websocket.StatusNormalClosure, "") ``` -## Design justifications - -- A minimal API is easier to maintain due to less docs, tests and bugs -- A minimal API is also easier to use and learn -- Context based cancellation is more ergonomic and robust than setting deadlines -- net.Conn is never exposed as WebSocket over HTTP/2 will not have a net.Conn. -- Using net/http's Client for dialing means we do not have to reinvent dialing hooks - and configurations like other WebSocket libraries - ## Comparison -Before the comparison, I want to point out that both gorilla/websocket and gobwas/ws were -extremely useful in implementing the WebSocket protocol correctly so _big thanks_ to the -authors of both. In particular, I made sure to go through the issue tracker of gorilla/websocket -to ensure I implemented details correctly and understood how people were using WebSockets in -production. +Before the comparison, I want to point out that gorilla/websocket was extremely useful in implementing the +WebSocket protocol correctly so _big thanks_ to its authors. In particular, I made sure to go through the +issue tracker of gorilla/websocket to ensure I implemented details correctly and understood how people were +using WebSockets in production. ### gorilla/websocket @@ -121,7 +112,7 @@ more code to test, more code to document and more surface area for bugs. Moreover, nhooyr.io/websocket supports newer Go idioms such as context.Context. It also uses net/http's Client and ResponseWriter directly for WebSocket handshakes. gorilla/websocket writes its handshakes to the underlying net.Conn. -Thus it has to reinvent hooks for TLS and proxies and prevents support of HTTP/2. +Thus it has to reinvent hooks for TLS and proxies and prevents easy support of HTTP/2. Some more advantages of nhooyr.io/websocket are that it supports concurrent writes and makes it very easy to close the connection with a status code and reason. In fact, @@ -138,10 +129,14 @@ In terms of performance, the differences mostly depend on your application code. reuses message buffers out of the box if you use the wsjson and wspb subpackages. As mentioned above, nhooyr.io/websocket also supports concurrent writers. -The WebSocket masking algorithm used by this package is also [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) -faster than gorilla/websocket or gobwas/ws while using only pure safe Go. +The WebSocket masking algorithm used by this package is [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) +faster than gorilla/websocket while using only pure safe Go. -The only performance con to nhooyr.io/websocket is that it uses one extra goroutine to support +The [permessage-deflate compression extension](https://tools.ietf.org/html/rfc7692) is fully supported by this library +whereas gorilla only supports no context takeover mode. See our godoc for the differences. This will make a big +difference on bandwidth used in most use cases. + +The only performance con to nhooyr.io/websocket is that it uses a goroutine to support cancellation with context.Context. This costs 2 KB of memory which is cheap compared to the benefits. @@ -160,14 +155,15 @@ https://github.com/gobwas/ws This library has an extremely flexible API but that comes at the cost of usability and clarity. -This library is fantastic in terms of performance. The author put in significant -effort to ensure its speed and I have applied as many of its optimizations as -I could into nhooyr.io/websocket. Definitely check out his fantastic [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb) -about performant WebSocket servers. +Due to its flexibility, it can be used in a event driven style for performance. +Definitely check out his fantastic [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb) about performant WebSocket servers. If you want a library that gives you absolute control over everything, this is the library. -But for 99.9% of use cases, nhooyr.io/websocket will fit better. It's nearly as performant -but much easier to use. +But for 99.9% of use cases, nhooyr.io/websocket will fit better as it is both easier and +faster for normal idiomatic Go. The masking implementation is [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) +faster, the compression extensions are fully supported and as much as possible is reused by default. + +See the gorilla/websocket comparison for more performance details. ## Contributing diff --git a/accept.go b/accept.go new file mode 100644 index 0000000..5ff2ea4 --- /dev/null +++ b/accept.go @@ -0,0 +1,330 @@ +package websocket + +import ( + "bytes" + "crypto/sha1" + "encoding/base64" + "errors" + "fmt" + "io" + "net/http" + "net/textproto" + "net/url" + "strings" +) + +// AcceptOptions represents the options available to pass to Accept. +type AcceptOptions struct { + // Subprotocols lists the websocket subprotocols that Accept will negotiate with a client. + // The empty subprotocol will always be negotiated as per RFC 6455. If you would like to + // reject it, close the connection if c.Subprotocol() == "". + Subprotocols []string + + // InsecureSkipVerify disables Accept's origin verification + // behaviour. By default Accept only allows the handshake to + // succeed if the javascript that is initiating the handshake + // is on the same domain as the server. This is to prevent CSRF + // attacks when secure data is stored in a cookie as there is no same + // origin policy for WebSockets. In other words, javascript from + // any domain can perform a WebSocket dial on an arbitrary server. + // This dial will include cookies which means the arbitrary javascript + // can perform actions as the authenticated user. + // + // See https://stackoverflow.com/a/37837709/4283659 + // + // The only time you need this is if your javascript is running on a different domain + // than your WebSocket server. + // Think carefully about whether you really need this option before you use it. + // If you do, remember that if you store secure data in cookies, you wil need to verify the + // Origin header yourself otherwise you are exposing yourself to a CSRF attack. + InsecureSkipVerify bool + + // CompressionMode sets the compression mode. + // See docs on the CompressionMode type and defined constants. + CompressionMode CompressionMode +} + +// Accept accepts a WebSocket HTTP handshake from a client and upgrades the +// the connection to a WebSocket. +// +// Accept will reject the handshake if the Origin domain is not the same as the Host unless +// the InsecureSkipVerify option is set. In other words, by default it does not allow +// cross origin requests. +// +// If an error occurs, Accept will write a response with a safe error message to w. +func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { + c, err := accept(w, r, opts) + if err != nil { + return nil, fmt.Errorf("failed to accept websocket connection: %w", err) + } + return c, nil +} + +func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { + if opts == nil { + opts = &AcceptOptions{} + } + + err := verifyClientRequest(w, r) + if err != nil { + return nil, err + } + + if !opts.InsecureSkipVerify { + err = authenticateOrigin(r) + if err != nil { + http.Error(w, err.Error(), http.StatusForbidden) + return nil, err + } + } + + hj, ok := w.(http.Hijacker) + if !ok { + err = errors.New("passed ResponseWriter does not implement http.Hijacker") + http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) + return nil, err + } + + w.Header().Set("Upgrade", "websocket") + w.Header().Set("Connection", "Upgrade") + + handleSecWebSocketKey(w, r) + + subproto := selectSubprotocol(r, opts.Subprotocols) + if subproto != "" { + w.Header().Set("Sec-WebSocket-Protocol", subproto) + } + + copts, err := acceptCompression(r, w, opts.CompressionMode) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return nil, err + } + + w.WriteHeader(http.StatusSwitchingProtocols) + + netConn, brw, err := hj.Hijack() + if err != nil { + err = fmt.Errorf("failed to hijack connection: %w", err) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return nil, err + } + + // https://github.com/golang/go/issues/32314 + b, _ := brw.Reader.Peek(brw.Reader.Buffered()) + brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn)) + + c := &Conn{ + subprotocol: w.Header().Get("Sec-WebSocket-Protocol"), + br: brw.Reader, + bw: brw.Writer, + closer: netConn, + copts: copts, + } + c.init() + + return c, nil +} + +func authenticateOrigin(r *http.Request) error { + origin := r.Header.Get("Origin") + if origin == "" { + return nil + } + u, err := url.Parse(origin) + if err != nil { + return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) + } + if !strings.EqualFold(u.Host, r.Host) { + return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) + } + return nil +} + +func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { + if !r.ProtoAtLeast(1, 1) { + err := fmt.Errorf("websocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto) + http.Error(w, err.Error(), http.StatusBadRequest) + return err + } + + if !headerContainsToken(r.Header, "Connection", "Upgrade") { + err := fmt.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) + http.Error(w, err.Error(), http.StatusBadRequest) + return err + } + + if !headerContainsToken(r.Header, "Upgrade", "WebSocket") { + err := fmt.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) + http.Error(w, err.Error(), http.StatusBadRequest) + return err + } + + if r.Method != "GET" { + err := fmt.Errorf("websocket protocol violation: handshake request method is not GET but %q", r.Method) + http.Error(w, err.Error(), http.StatusBadRequest) + return err + } + + if r.Header.Get("Sec-WebSocket-Version") != "13" { + err := fmt.Errorf("unsupported websocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version")) + http.Error(w, err.Error(), http.StatusBadRequest) + return err + } + + if r.Header.Get("Sec-WebSocket-Key") == "" { + err := errors.New("websocket protocol violation: missing Sec-WebSocket-Key") + http.Error(w, err.Error(), http.StatusBadRequest) + return err + } + + return nil +} + +func handleSecWebSocketKey(w http.ResponseWriter, r *http.Request) { + key := r.Header.Get("Sec-WebSocket-Key") + w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) +} + +func selectSubprotocol(r *http.Request, subprotocols []string) string { + for _, sp := range subprotocols { + if headerContainsToken(r.Header, "Sec-WebSocket-Protocol", sp) { + return sp + } + } + return "" +} + +func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionMode) (*compressionOptions, error) { + if mode == CompressionDisabled { + return nil, nil + } + + for _, ext := range websocketExtensions(r.Header) { + switch ext.name { + case "permessage-deflate": + return acceptDeflate(w, ext, mode) + case "x-webkit-deflate-frame": + return acceptWebkitDeflate(w, ext, mode) + } + } + return nil, nil +} + +func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { + copts := mode.opts() + + for _, p := range ext.params { + switch p { + case "client_no_context_takeover": + copts.clientNoContextTakeover = true + continue + case "server_no_context_takeover": + copts.serverNoContextTakeover = true + continue + case "client_max_window_bits", "server-max-window-bits": + continue + } + + return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p) + } + + copts.setHeader(w.Header()) + + return copts, nil +} + +func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { + copts := mode.opts() + // The peer must explicitly request it. + copts.serverNoContextTakeover = false + + for _, p := range ext.params { + if p == "no_context_takeover" { + copts.serverNoContextTakeover = true + continue + } + + // We explicitly fail on x-webkit-deflate-frame's max_window_bits parameter instead + // of ignoring it as the draft spec is unclear. It says the server can ignore it + // but the server has no way of signalling to the client it was ignored as the parameters + // are set one way. + // Thus us ignoring it would make the client think we understood it which would cause issues. + // See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06#section-4.1 + // + // Either way, we're only implementing this for webkit which never sends the max_window_bits + // parameter so we don't need to worry about it. + return nil, fmt.Errorf("unsupported x-webkit-deflate-frame parameter: %q", p) + } + + s := "x-webkit-deflate-frame" + if copts.clientNoContextTakeover { + s += "; no_context_takeover" + } + w.Header().Set("Sec-WebSocket-Extensions", s) + + return copts, nil +} + + +func headerContainsToken(h http.Header, key, token string) bool { + token = strings.ToLower(token) + + for _, t := range headerTokens(h, key) { + if t == token { + return true + } + } + return false +} + +type websocketExtension struct { + name string + params []string +} + +func websocketExtensions(h http.Header) []websocketExtension { + var exts []websocketExtension + extStrs := headerTokens(h, "Sec-WebSocket-Extensions") + for _, extStr := range extStrs { + if extStr == "" { + continue + } + + vals := strings.Split(extStr, ";") + for i := range vals { + vals[i] = strings.TrimSpace(vals[i]) + } + + e := websocketExtension{ + name: vals[0], + params: vals[1:], + } + + exts = append(exts, e) + } + return exts +} + +func headerTokens(h http.Header, key string) []string { + key = textproto.CanonicalMIMEHeaderKey(key) + var tokens []string + for _, v := range h[key] { + v = strings.TrimSpace(v) + for _, t := range strings.Split(v, ",") { + t = strings.ToLower(t) + tokens = append(tokens, t) + } + } + return tokens +} + +var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + +func secWebSocketAccept(secWebSocketKey string) string { + h := sha1.New() + h.Write([]byte(secWebSocketKey)) + h.Write(keyGUID) + + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} diff --git a/handshake_test.go b/accept_test.go similarity index 62% rename from handshake_test.go rename to accept_test.go index 82f958e..9598cd5 100644 --- a/handshake_test.go +++ b/accept_test.go @@ -1,14 +1,9 @@ -// +build !js - package websocket import ( - "context" - "net/http" "net/http/httptest" "strings" "testing" - "time" ) func TestAccept(t *testing.T) { @@ -246,141 +241,3 @@ func Test_authenticateOrigin(t *testing.T) { }) } } - -func TestBadDials(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - url string - opts *DialOptions - }{ - { - name: "badURL", - url: "://noscheme", - }, - { - name: "badURLScheme", - url: "ftp://nhooyr.io", - }, - { - name: "badHTTPClient", - url: "ws://nhooyr.io", - opts: &DialOptions{ - HTTPClient: &http.Client{ - Timeout: time.Minute, - }, - }, - }, - { - name: "badTLS", - url: "wss://totallyfake.nhooyr.io", - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - _, _, err := Dial(ctx, tc.url, tc.opts) - if err == nil { - t.Fatalf("expected non nil error: %+v", err) - } - }) - } -} - -func Test_verifyServerHandshake(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - response func(w http.ResponseWriter) - success bool - }{ - { - name: "badStatus", - response: func(w http.ResponseWriter) { - w.WriteHeader(http.StatusOK) - }, - success: false, - }, - { - name: "badConnection", - response: func(w http.ResponseWriter) { - w.Header().Set("Connection", "???") - w.WriteHeader(http.StatusSwitchingProtocols) - }, - success: false, - }, - { - name: "badUpgrade", - response: func(w http.ResponseWriter) { - w.Header().Set("Connection", "Upgrade") - w.Header().Set("Upgrade", "???") - w.WriteHeader(http.StatusSwitchingProtocols) - }, - success: false, - }, - { - name: "badSecWebSocketAccept", - response: func(w http.ResponseWriter) { - w.Header().Set("Connection", "Upgrade") - w.Header().Set("Upgrade", "websocket") - w.Header().Set("Sec-WebSocket-Accept", "xd") - w.WriteHeader(http.StatusSwitchingProtocols) - }, - success: false, - }, - { - name: "badSecWebSocketProtocol", - response: func(w http.ResponseWriter) { - w.Header().Set("Connection", "Upgrade") - w.Header().Set("Upgrade", "websocket") - w.Header().Set("Sec-WebSocket-Protocol", "xd") - w.WriteHeader(http.StatusSwitchingProtocols) - }, - success: false, - }, - { - name: "success", - response: func(w http.ResponseWriter) { - w.Header().Set("Connection", "Upgrade") - w.Header().Set("Upgrade", "websocket") - w.WriteHeader(http.StatusSwitchingProtocols) - }, - success: true, - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - w := httptest.NewRecorder() - tc.response(w) - resp := w.Result() - - r := httptest.NewRequest("GET", "/", nil) - key, err := makeSecWebSocketKey() - if err != nil { - t.Fatal(err) - } - r.Header.Set("Sec-WebSocket-Key", key) - - if resp.Header.Get("Sec-WebSocket-Accept") == "" { - resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) - } - - _, err = verifyServerResponse(r, resp, &DialOptions{}) - if (err == nil) != tc.success { - t.Fatalf("unexpected error: %+v", err) - } - }) - } -} diff --git a/assert_test.go b/assert_test.go index 26fd1d4..af30099 100644 --- a/assert_test.go +++ b/assert_test.go @@ -4,6 +4,7 @@ import ( "context" "math/rand" "strings" + "testing" "time" "nhooyr.io/websocket" @@ -15,36 +16,30 @@ func init() { rand.Seed(time.Now().UnixNano()) } -func assertJSONEcho(ctx context.Context, c *websocket.Conn, n int) error { +func randBytes(n int) []byte { + b := make([]byte, n) + rand.Read(b) + return b +} + +func assertJSONEcho(t *testing.T, ctx context.Context, c *websocket.Conn, n int) { exp := randString(n) err := wsjson.Write(ctx, c, exp) - if err != nil { - return err - } + assert.Success(t, err) var act interface{} err = wsjson.Read(ctx, c, &act) - if err != nil { - return err - } + assert.Success(t, err) - return assert.Equalf(exp, act, "unexpected JSON") + assert.Equalf(t, exp, act, "unexpected JSON") } -func assertJSONRead(ctx context.Context, c *websocket.Conn, exp interface{}) error { +func assertJSONRead(t *testing.T, ctx context.Context, c *websocket.Conn, exp interface{}) { var act interface{} err := wsjson.Read(ctx, c, &act) - if err != nil { - return err - } - - return assert.Equalf(exp, act, "unexpected JSON") -} + assert.Success(t, err) -func randBytes(n int) []byte { - b := make([]byte, n) - rand.Read(b) - return b + assert.Equalf(t, exp, act, "unexpected JSON") } func randString(n int) string { @@ -60,23 +55,18 @@ func randString(n int) string { return s } -func assertEcho(ctx context.Context, c *websocket.Conn, typ websocket.MessageType, n int) error { +func assertEcho(t *testing.T, ctx context.Context, c *websocket.Conn, typ websocket.MessageType, n int) { p := randBytes(n) err := c.Write(ctx, typ, p) - if err != nil { - return err - } + assert.Success(t, err) + typ2, p2, err := c.Read(ctx) - if err != nil { - return err - } - err = assert.Equalf(typ, typ2, "unexpected data type") - if err != nil { - return err - } - return assert.Equalf(p, p2, "unexpected payload") + assert.Success(t, err) + + assert.Equalf(t, typ, typ2, "unexpected data type") + assert.Equalf(t, p, p2, "unexpected payload") } -func assertSubprotocol(c *websocket.Conn, exp string) error { - return assert.Equalf(exp, c.Subprotocol(), "unexpected subprotocol") +func assertSubprotocol(t *testing.T, c *websocket.Conn, exp string) { + assert.Equalf(t, exp, c.Subprotocol(), "unexpected subprotocol") } diff --git a/ci/fmt.mk b/ci/fmt.mk index 8e61bc2..3637c1a 100644 --- a/ci/fmt.mk +++ b/ci/fmt.mk @@ -22,4 +22,4 @@ prettier: prettier --write --print-width=120 --no-semi --trailing-comma=all --loglevel=warn $$(git ls-files "*.yml" "*.md") gen: - go generate ./... + stringer -type=Opcode,MessageType,StatusCode -output=websocket_stringer.go diff --git a/close.go b/close.go new file mode 100644 index 0000000..4f48f1b --- /dev/null +++ b/close.go @@ -0,0 +1,181 @@ +package websocket + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + "nhooyr.io/websocket/internal/wsframe" +) + +// StatusCode represents a WebSocket status code. +// https://tools.ietf.org/html/rfc6455#section-7.4 +type StatusCode int + +// These codes were retrieved from: +// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number +// +// The defined constants only represent the status codes registered with IANA. +// The 4000-4999 range of status codes is reserved for arbitrary use by applications. +const ( + StatusNormalClosure StatusCode = 1000 + StatusGoingAway StatusCode = 1001 + StatusProtocolError StatusCode = 1002 + StatusUnsupportedData StatusCode = 1003 + + // 1004 is reserved and so not exported. + statusReserved StatusCode = 1004 + + // StatusNoStatusRcvd cannot be sent in a close message. + // It is reserved for when a close message is received without + // an explicit status. + StatusNoStatusRcvd StatusCode = 1005 + + // StatusAbnormalClosure is only exported for use with Wasm. + // In non Wasm Go, the returned error will indicate whether the connection was closed or not or what happened. + StatusAbnormalClosure StatusCode = 1006 + + StatusInvalidFramePayloadData StatusCode = 1007 + StatusPolicyViolation StatusCode = 1008 + StatusMessageTooBig StatusCode = 1009 + StatusMandatoryExtension StatusCode = 1010 + StatusInternalError StatusCode = 1011 + StatusServiceRestart StatusCode = 1012 + StatusTryAgainLater StatusCode = 1013 + StatusBadGateway StatusCode = 1014 + + // StatusTLSHandshake is only exported for use with Wasm. + // In non Wasm Go, the returned error will indicate whether there was a TLS handshake failure. + StatusTLSHandshake StatusCode = 1015 +) + +// CloseError represents a WebSocket close frame. +// It is returned by Conn's methods when a WebSocket close frame is received from +// the peer. +// You will need to use the https://golang.org/pkg/errors/#As function, new in Go 1.13, +// to check for this error. See the CloseError example. +type CloseError struct { + Code StatusCode + Reason string +} + +func (ce CloseError) Error() string { + return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason) +} + +// CloseStatus is a convenience wrapper around errors.As to grab +// the status code from a *CloseError. If the passed error is nil +// or not a *CloseError, the returned StatusCode will be -1. +func CloseStatus(err error) StatusCode { + var ce CloseError + if errors.As(err, &ce) { + return ce.Code + } + return -1 +} + +func parseClosePayload(p []byte) (CloseError, error) { + if len(p) == 0 { + return CloseError{ + Code: StatusNoStatusRcvd, + }, nil + } + + code, reason, err := wsframe.ParseClosePayload(p) + if err != nil { + return CloseError{}, err + } + + ce := CloseError{ + Code: StatusCode(code), + Reason: reason, + } + + if !validWireCloseCode(ce.Code) { + return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code) + } + + return ce, nil +} + +// See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number +// and https://tools.ietf.org/html/rfc6455#section-7.4.1 +func validWireCloseCode(code StatusCode) bool { + switch code { + case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake: + return false + } + + if code >= StatusNormalClosure && code <= StatusBadGateway { + return true + } + if code >= 3000 && code <= 4999 { + return true + } + + return false +} + +func (ce CloseError) bytes() ([]byte, error) { + // TODO move check into frame write + if len(ce.Reason) > wsframe.MaxControlFramePayload-2 { + return nil, fmt.Errorf("reason string max is %v but got %q with length %v", wsframe.MaxControlFramePayload-2, ce.Reason, len(ce.Reason)) + } + if !validWireCloseCode(ce.Code) { + return nil, fmt.Errorf("status code %v cannot be set", ce.Code) + } + + buf := make([]byte, 2+len(ce.Reason)) + binary.BigEndian.PutUint16(buf, uint16(ce.Code)) + copy(buf[2:], ce.Reason) + return buf, nil +} + +// CloseRead will start a goroutine to read from the connection until it is closed or a data message +// is received. If a data message is received, the connection will be closed with StatusPolicyViolation. +// Since CloseRead reads from the connection, it will respond to ping, pong and close frames. +// After calling this method, you cannot read any data messages from the connection. +// The returned context will be cancelled when the connection is closed. +// +// Use this when you do not want to read data messages from the connection anymore but will +// want to write messages to it. +func (c *Conn) CloseRead(ctx context.Context) context.Context { + c.isReadClosed.Store(1) + + ctx, cancel := context.WithCancel(ctx) + go func() { + defer cancel() + // We use the unexported reader method so that we don't get the read closed error. + c.reader(ctx, true) + // Either the connection is already closed since there was a read error + // or the context was cancelled or a message was read and we should close + // the connection. + c.Close(StatusPolicyViolation, "unexpected data message") + }() + return ctx +} + +// SetReadLimit sets the max number of bytes to read for a single message. +// It applies to the Reader and Read methods. +// +// By default, the connection has a message read limit of 32768 bytes. +// +// When the limit is hit, the connection will be closed with StatusMessageTooBig. +func (c *Conn) SetReadLimit(n int64) { + c.msgReadLimit.Store(n) +} + +func (c *Conn) setCloseErr(err error) { + c.closeErrOnce.Do(func() { + c.closeErr = fmt.Errorf("websocket closed: %w", err) + }) +} + +func (c *Conn) isClosed() bool { + select { + case <-c.closed: + return true + default: + return false + } +} diff --git a/close_test.go b/close_test.go new file mode 100644 index 0000000..78096d7 --- /dev/null +++ b/close_test.go @@ -0,0 +1,196 @@ +package websocket + +import ( + "github.com/google/go-cmp/cmp" + "io" + "math" + "nhooyr.io/websocket/internal/assert" + "nhooyr.io/websocket/internal/wsframe" + "strings" + "testing" +) + +func TestCloseError(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + ce CloseError + success bool + }{ + { + name: "normal", + ce: CloseError{ + Code: StatusNormalClosure, + Reason: strings.Repeat("x", wsframe.MaxControlFramePayload-2), + }, + success: true, + }, + { + name: "bigReason", + ce: CloseError{ + Code: StatusNormalClosure, + Reason: strings.Repeat("x", wsframe.MaxControlFramePayload-1), + }, + success: false, + }, + { + name: "bigCode", + ce: CloseError{ + Code: math.MaxUint16, + Reason: strings.Repeat("x", wsframe.MaxControlFramePayload-2), + }, + success: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + _, err := tc.ce.bytes() + if (err == nil) != tc.success { + t.Fatalf("unexpected error value: %+v", err) + } + }) + } +} + +func Test_parseClosePayload(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + p []byte + success bool + ce CloseError + }{ + { + name: "normal", + p: append([]byte{0x3, 0xE8}, []byte("hello")...), + success: true, + ce: CloseError{ + Code: StatusNormalClosure, + Reason: "hello", + }, + }, + { + name: "nothing", + success: true, + ce: CloseError{ + Code: StatusNoStatusRcvd, + }, + }, + { + name: "oneByte", + p: []byte{0}, + success: false, + }, + { + name: "badStatusCode", + p: []byte{0x17, 0x70}, + success: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ce, err := parseClosePayload(tc.p) + if (err == nil) != tc.success { + t.Fatalf("unexpected expected error value: %+v", err) + } + + if tc.success && tc.ce != ce { + t.Fatalf("unexpected close error: %v", cmp.Diff(tc.ce, ce)) + } + }) + } +} + +func Test_validWireCloseCode(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + code StatusCode + valid bool + }{ + { + name: "normal", + code: StatusNormalClosure, + valid: true, + }, + { + name: "noStatus", + code: StatusNoStatusRcvd, + valid: false, + }, + { + name: "3000", + code: 3000, + valid: true, + }, + { + name: "4999", + code: 4999, + valid: true, + }, + { + name: "unknown", + code: 5000, + valid: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + if valid := validWireCloseCode(tc.code); tc.valid != valid { + t.Fatalf("expected %v for %v but got %v", tc.valid, tc.code, valid) + } + }) + } +} + +func TestCloseStatus(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + in error + exp StatusCode + }{ + { + name: "nil", + in: nil, + exp: -1, + }, + { + name: "io.EOF", + in: io.EOF, + exp: -1, + }, + { + name: "StatusInternalError", + in: CloseError{ + Code: StatusInternalError, + }, + exp: StatusInternalError, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + assert.Equalf(t, tc.exp, CloseStatus(tc.in), "unexpected close status") + }) + } +} diff --git a/compress.go b/compress.go new file mode 100644 index 0000000..5b5fdce --- /dev/null +++ b/compress.go @@ -0,0 +1,78 @@ +// +build !js + +package websocket + +import ( + "net/http" +) + +// CompressionMode controls the modes available RFC 7692's deflate extension. +// See https://tools.ietf.org/html/rfc7692 +// +// A compatibility layer is implemented for the older deflate-frame extension used +// by safari. See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06 +// It will work the same in every way except that we cannot signal to the peer we +// want to use no context takeover on our side, we can only signal that they should. +type CompressionMode int + +const ( + // CompressionContextTakeover uses a flate.Reader and flate.Writer per connection. + // This enables reusing the sliding window from previous messages. + // As most WebSocket protocols are repetitive, this is the default. + // + // The message will only be compressed if greater than or equal to 128 bytes. + // + // If the peer negotiates NoContextTakeover on the client or server side, it will be + // used instead as this is required by the RFC. + CompressionContextTakeover CompressionMode = iota + + // CompressionNoContextTakeover grabs a new flate.Reader and flate.Writer as needed + // for every message. This applies to both server and client side. + // + // This means less efficient compression as the sliding window from previous messages + // will not be used but the memory overhead will be much lower if the connections + // are long lived and seldom used. + // + // The message will only be compressed if greater than or equal to 512 bytes. + CompressionNoContextTakeover + + // CompressionDisabled disables the deflate extension. + // + // Use this if you are using a predominantly binary protocol with very + // little duplication in between messages or CPU and memory are more + // important than bandwidth. + CompressionDisabled +) + +func (m CompressionMode) opts() *compressionOptions { + if m == CompressionDisabled { + return nil + } + return &compressionOptions{ + clientNoContextTakeover: m == CompressionNoContextTakeover, + serverNoContextTakeover: m == CompressionNoContextTakeover, + } +} + +type compressionOptions struct { + clientNoContextTakeover bool + serverNoContextTakeover bool +} + +func (copts *compressionOptions) setHeader(h http.Header) { + s := "permessage-deflate" + if copts.clientNoContextTakeover { + s += "; client_no_context_takeover" + } + if copts.serverNoContextTakeover { + s += "; server_no_context_takeover" + } + h.Set("Sec-WebSocket-Extensions", s) +} + +// These bytes are required to get flate.Reader to return. +// They are removed when sending to avoid the overhead as +// WebSocket framing tell's when the message has ended but then +// we need to add them back otherwise flate.Reader keeps +// trying to return more bytes. +const deflateMessageTail = "\x00\x00\xff\xff" diff --git a/conn.go b/conn.go index 32dfa81..791d9b4 100644 --- a/conn.go +++ b/conn.go @@ -13,13 +13,28 @@ import ( "io" "io/ioutil" "log" + "nhooyr.io/websocket/internal/atomicint" + "nhooyr.io/websocket/internal/wsframe" "runtime" "strconv" + "strings" "sync" "sync/atomic" "time" - "nhooyr.io/websocket/internal/bpool" + "nhooyr.io/websocket/internal/bufpool" +) + +// MessageType represents the type of a WebSocket message. +// See https://tools.ietf.org/html/rfc6455#section-5.6 +type MessageType int + +// MessageType constants. +const ( + // MessageText is for UTF-8 encoded text messages like JSON. + MessageText MessageType = iota + 1 + // MessageBinary is for binary messages like Protobufs. + MessageBinary ) // Conn represents a WebSocket connection. @@ -36,20 +51,20 @@ import ( // This applies to the Read methods in the wsjson/wspb subpackages as well. type Conn struct { subprotocol string - br *bufio.Reader + fw *flate.Writer bw *bufio.Writer // writeBuf is used for masking, its the buffer in bufio.Writer. // Only used by the client for masking the bytes in the buffer. writeBuf []byte closer io.Closer client bool - copts *CompressionOptions + copts *compressionOptions closeOnce sync.Once closeErrOnce sync.Once closeErr error closed chan struct{} - closing *atomicInt64 + closing *atomicint.Int64 closeReceived error // messageWriter state. @@ -61,35 +76,18 @@ type Conn struct { writeHeaderBuf []byte writeHeader *header // read limit for a message in bytes. - msgReadLimit *atomicInt64 + msgReadLimit *atomicint.Int64 // Used to ensure a previous writer is not used after being closed. activeWriter atomic.Value // messageWriter state. writeMsgOpcode opcode writeMsgCtx context.Context - readMsgLeft int64 - - // Used to ensure the previous reader is read till EOF before allowing - // a new one. - activeReader *messageReader - // readFrameLock is acquired to read from bw. - readFrameLock chan struct{} - isReadClosed *atomicInt64 - readHeaderBuf []byte - controlPayloadBuf []byte - readLock chan struct{} - - // messageReader state. - readerMsgCtx context.Context - readerMsgHeader header - readerFrameEOF bool - readerMaskKey uint32 setReadTimeout chan context.Context setWriteTimeout chan context.Context - pingCounter *atomicInt64 + pingCounter *atomicint.Int64 activePingsMu sync.Mutex activePings map[string]chan<- struct{} @@ -98,9 +96,9 @@ type Conn struct { func (c *Conn) init() { c.closed = make(chan struct{}) - c.closing = &atomicInt64{} + c.closing = &atomicint.Int64{} - c.msgReadLimit = &atomicInt64{} + c.msgReadLimit = &atomicint.Int64{} c.msgReadLimit.Store(32768) c.writeMsgLock = make(chan struct{}, 1) @@ -108,17 +106,18 @@ func (c *Conn) init() { c.readFrameLock = make(chan struct{}, 1) c.readLock = make(chan struct{}, 1) + c.payloadReader = framePayloadReader{c} c.setReadTimeout = make(chan context.Context) c.setWriteTimeout = make(chan context.Context) - c.pingCounter = &atomicInt64{} + c.pingCounter = &atomicint.Int64{} c.activePings = make(map[string]chan<- struct{}) c.writeHeaderBuf = makeWriteHeaderBuf() c.writeHeader = &header{} c.readHeaderBuf = makeReadHeaderBuf() - c.isReadClosed = &atomicInt64{} + c.isReadClosed = &atomicint.Int64{} c.controlPayloadBuf = make([]byte, maxControlFramePayload) runtime.SetFinalizer(c, func(c *Conn) { @@ -127,6 +126,15 @@ func (c *Conn) init() { c.logf = log.Printf + if c.copts != nil { + if !c.readNoContextTakeOver() { + c.fr = getFlateReader(c.payloadReader) + } + if !c.writeNoContextTakeOver() { + c.fw = getFlateWriter(c.bw) + } + } + go c.timeoutLoop() } @@ -148,19 +156,26 @@ func (c *Conn) close(err error) { // closeErr. c.closer.Close() - // See comment on bufioReaderPool in handshake.go + // By acquiring the locks, we ensure no goroutine will touch the bufio reader or writer + // and we can safely return them. + // Whenever a caller holds this lock and calls close, it ensures to release the lock to prevent + // a deadlock. + // As of now, this is in writeFrame, readFramePayload and readHeader. + c.readFrameLock <- struct{}{} if c.client { - // By acquiring the locks, we ensure no goroutine will touch the bufio reader or writer - // and we can safely return them. - // Whenever a caller holds this lock and calls close, it ensures to release the lock to prevent - // a deadlock. - // As of now, this is in writeFrame, readFramePayload and readHeader. - c.readFrameLock <- struct{}{} returnBufioReader(c.br) + } + if c.fr != nil { + putFlateReader(c.fr) + } - c.writeFrameLock <- struct{}{} + c.writeFrameLock <- struct{}{} + if c.client { returnBufioWriter(c.bw) } + if c.fw != nil { + putFlateWriter(c.fw) + } }) } @@ -230,7 +245,7 @@ func (c *Conn) readTillMsg(ctx context.Context) (header, error) { return header{}, err } - if h.rsv1 || h.rsv2 || h.rsv3 { + if (h.rsv1 && (c.copts == nil || h.opcode.controlOp() || h.opcode == opContinuation)) || h.rsv2 || h.rsv3 { err := fmt.Errorf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3) c.exportedClose(StatusProtocolError, err.Error(), false) return header{}, err @@ -448,6 +463,13 @@ func (c *Conn) reader(ctx context.Context, lock bool) (MessageType, io.Reader, e c.readerMsgCtx = ctx c.readerMsgHeader = h + + c.readerPayloadCompressed = h.rsv1 + + if c.readerPayloadCompressed { + c.readerCompressTail.Reset(deflateMessageTail) + } + c.readerFrameEOF = false c.readerMaskKey = h.maskKey c.readMsgLeft = c.msgReadLimit.Load() @@ -456,9 +478,67 @@ func (c *Conn) reader(ctx context.Context, lock bool) (MessageType, io.Reader, e c: c, } c.activeReader = r + if c.readerPayloadCompressed && c.readNoContextTakeOver() { + c.fr = getFlateReader(c.payloadReader) + } return MessageType(h.opcode), r, nil } +type framePayloadReader struct { + c *Conn +} + +func (r framePayloadReader) Read(p []byte) (int, error) { + if r.c.readerFrameEOF { + if r.c.readerPayloadCompressed && r.c.readerMsgHeader.fin { + n, _ := r.c.readerCompressTail.Read(p) + return n, nil + } + + h, err := r.c.readTillMsg(r.c.readerMsgCtx) + if err != nil { + return 0, err + } + + if h.opcode != opContinuation { + err := errors.New("received new data message without finishing the previous message") + r.c.exportedClose(StatusProtocolError, err.Error(), false) + return 0, err + } + + r.c.readerMsgHeader = h + r.c.readerFrameEOF = false + r.c.readerMaskKey = h.maskKey + } + + h := r.c.readerMsgHeader + if int64(len(p)) > h.payloadLength { + p = p[:h.payloadLength] + } + + n, err := r.c.readFramePayload(r.c.readerMsgCtx, p) + + h.payloadLength -= int64(n) + if h.masked { + r.c.readerMaskKey = mask(r.c.readerMaskKey, p) + } + r.c.readerMsgHeader = h + + if err != nil { + return n, err + } + + if h.payloadLength == 0 { + r.c.readerFrameEOF = true + + if h.fin && !r.c.readerPayloadCompressed { + return n, io.EOF + } + } + + return n, nil +} + // messageReader enables reading a data frame from the WebSocket connection. type messageReader struct { c *Conn @@ -521,51 +601,27 @@ func (r *messageReader) read(p []byte, lock bool) (int, error) { p = p[:r.c.readMsgLeft] } - if r.c.readerFrameEOF { - h, err := r.c.readTillMsg(r.c.readerMsgCtx) - if err != nil { - return 0, err - } - - if h.opcode != opContinuation { - err := errors.New("received new data message without finishing the previous message") - r.c.exportedClose(StatusProtocolError, err.Error(), false) - return 0, err - } - - r.c.readerMsgHeader = h - r.c.readerFrameEOF = false - r.c.readerMaskKey = h.maskKey + pr := io.Reader(r.c.payloadReader) + if r.c.readerPayloadCompressed { + pr = r.c.fr } - h := r.c.readerMsgHeader - if int64(len(p)) > h.payloadLength { - p = p[:h.payloadLength] - } - - n, err := r.c.readFramePayload(r.c.readerMsgCtx, p) + n, err := pr.Read(p) - h.payloadLength -= int64(n) r.c.readMsgLeft -= int64(n) - if h.masked { - r.c.readerMaskKey = mask(r.c.readerMaskKey, p) - } - r.c.readerMsgHeader = h - if err != nil { - return n, err - } - - if h.payloadLength == 0 { - r.c.readerFrameEOF = true - - if h.fin { - r.c.activeReader = nil - return n, io.EOF + if r.c.readerFrameEOF && r.c.readerMsgHeader.fin { + if r.c.readerPayloadCompressed && r.c.readNoContextTakeOver() { + putFlateReader(r.c.fr) + r.c.fr = nil + } + r.c.activeReader = nil + if err == nil { + err = io.EOF } } - return n, nil + return n, err } func (c *Conn) readFramePayload(ctx context.Context, p []byte) (_ int, err error) { @@ -971,10 +1027,10 @@ func (c *Conn) waitClose() error { return c.closeReceived } - b := bpool.Get() + b := bufpool.Get() buf := b.Bytes() buf = buf[:cap(buf)] - defer bpool.Put(b) + defer bufpool.Put(b) for { if c.activeReader == nil || c.readerFrameEOF { @@ -1065,40 +1121,21 @@ func (c *Conn) extractBufioWriterBuf(w io.Writer) { c.bw.Reset(w) } -var flateWriterPoolsMu sync.Mutex -var flateWriterPools = make(map[int]*sync.Pool) - -func getFlateWriterPool(level int) *sync.Pool { - flateWriterPoolsMu.Lock() - defer flateWriterPoolsMu.Unlock() - - p, ok := flateWriterPools[level] - if !ok { - p = &sync.Pool{ - New: func() interface{} { - w, err := flate.NewWriter(nil, level) - if err != nil { - panic("websocket: unexpected error from flate.NewWriter: " + err.Error()) - } - return w - }, - } - flateWriterPools[level] = p - } - - return p +var flateWriterPool = &sync.Pool{ + New: func() interface{} { + w, _ := flate.NewWriter(nil, flate.BestSpeed) + return w + }, } -func getFlateWriter(w io.Writer, level int) *flate.Writer { - p := getFlateWriterPool(level) - fw := p.Get().(*flate.Writer) +func getFlateWriter(w io.Writer) *flate.Writer { + fw := flateWriterPool.Get().(*flate.Writer) fw.Reset(w) return fw } -func putFlateWriter(w *flate.Writer, level int) { - p := getFlateWriterPool(level) - p.Put(w) +func putFlateWriter(w *flate.Writer) { + flateWriterPool.Put(w) } var flateReaderPool = &sync.Pool{ @@ -1107,12 +1144,60 @@ var flateReaderPool = &sync.Pool{ }, } -func getFlateReader(r flate.Reader) io.ReadCloser { - fr := flateReaderPool.Get().(io.ReadCloser) +func getFlateReader(r io.Reader) io.Reader { + fr := flateReaderPool.Get().(io.Reader) fr.(flate.Resetter).Reset(r, nil) return fr } -func putFlateReader(fr io.ReadCloser) { +func putFlateReader(fr io.Reader) { flateReaderPool.Put(fr) } + +func (c *Conn) writeNoContextTakeOver() bool { + return c.client && c.copts.clientNoContextTakeover || !c.client && c.copts.serverNoContextTakeover +} + +func (c *Conn) readNoContextTakeOver() bool { + return !c.client && c.copts.clientNoContextTakeover || c.client && c.copts.serverNoContextTakeover +} + +type trimLastFourBytesWriter struct { + w io.Writer + tail []byte +} + +func (w *trimLastFourBytesWriter) Write(p []byte) (int, error) { + extra := len(w.tail) + len(p) - 4 + + if extra <= 0 { + w.tail = append(w.tail, p...) + return len(p), nil + } + + // Now we need to write as many extra bytes as we can from the previous tail. + if extra > len(w.tail) { + extra = len(w.tail) + } + if extra > 0 { + _, err := w.Write(w.tail[:extra]) + if err != nil { + return 0, err + } + w.tail = w.tail[extra:] + } + + // If p is less than or equal to 4 bytes, + // all of it is is part of the tail. + if len(p) <= 4 { + w.tail = append(w.tail, p...) + return len(p), nil + } + + // Otherwise, only the last 4 bytes are. + w.tail = append(w.tail, p[len(p)-4:]...) + + p = p[:len(p)-4] + n, err := w.w.Write(p) + return n + 4, err +} diff --git a/dial.go b/dial.go new file mode 100644 index 0000000..1008868 --- /dev/null +++ b/dial.go @@ -0,0 +1,219 @@ +package websocket + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/base64" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/url" + "nhooyr.io/websocket/internal/bufpool" + "strings" +) + +// DialOptions represents the options available to pass to Dial. +type DialOptions struct { + // HTTPClient is the http client used for the handshake. + // Its Transport must return writable bodies + // for WebSocket handshakes. + // http.Transport does this correctly beginning with Go 1.12. + HTTPClient *http.Client + + // HTTPHeader specifies the HTTP headers included in the handshake request. + HTTPHeader http.Header + + // Subprotocols lists the subprotocols to negotiate with the server. + Subprotocols []string + + // See docs on CompressionMode. + CompressionMode CompressionMode +} + +// Dial performs a WebSocket handshake on the given url with the given options. +// The response is the WebSocket handshake response from the server. +// If an error occurs, the returned response may be non nil. However, you can only +// read the first 1024 bytes of its body. +// +// You never need to close the resp.Body yourself. +// +// This function requires at least Go 1.12 to succeed as it uses a new feature +// in net/http to perform WebSocket handshakes and get a writable body +// from the transport. See https://github.com/golang/go/issues/26937#issuecomment-415855861 +func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) { + c, r, err := dial(ctx, u, opts) + if err != nil { + return nil, r, fmt.Errorf("failed to websocket dial: %w", err) + } + return c, r, nil +} + +func (opts *DialOptions) fill() (*DialOptions, error) { + if opts == nil { + opts = &DialOptions{} + } else { + opts = &*opts + } + + if opts.HTTPClient == nil { + opts.HTTPClient = http.DefaultClient + } + if opts.HTTPClient.Timeout > 0 { + return nil, fmt.Errorf("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67") + } + if opts.HTTPHeader == nil { + opts.HTTPHeader = http.Header{} + } + + return opts, nil +} + +func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Response, err error) { + opts, err = opts.fill() + if err != nil { + return nil, nil, err + } + + parsedURL, err := url.Parse(u) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse url: %w", err) + } + + switch parsedURL.Scheme { + case "ws": + parsedURL.Scheme = "http" + case "wss": + parsedURL.Scheme = "https" + default: + return nil, nil, fmt.Errorf("unexpected url scheme: %q", parsedURL.Scheme) + } + + req, _ := http.NewRequest("GET", parsedURL.String(), nil) + req = req.WithContext(ctx) + req.Header = opts.HTTPHeader + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Sec-WebSocket-Version", "13") + secWebSocketKey, err := secWebSocketKey() + if err != nil { + return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err) + } + req.Header.Set("Sec-WebSocket-Key", secWebSocketKey) + if len(opts.Subprotocols) > 0 { + req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) + } + copts := opts.CompressionMode.opts() + copts.setHeader(req.Header) + + resp, err := opts.HTTPClient.Do(req) + if err != nil { + return nil, nil, fmt.Errorf("failed to send handshake request: %w", err) + } + defer func() { + if err != nil { + // We read a bit of the body for easier debugging. + r := io.LimitReader(resp.Body, 1024) + b, _ := ioutil.ReadAll(r) + resp.Body.Close() + resp.Body = ioutil.NopCloser(bytes.NewReader(b)) + } + }() + + copts, err = verifyServerResponse(req, resp, opts) + if err != nil { + return nil, resp, err + } + + rwc, ok := resp.Body.(io.ReadWriteCloser) + if !ok { + return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", rwc) + } + + c := &Conn{ + subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"), + br: bufpool.GetReader(rwc), + bw: bufpool.GetWriter(rwc), + closer: rwc, + client: true, + copts: copts, + } + c.extractBufioWriterBuf(rwc) + c.init() + + return c, resp, nil +} + +func secWebSocketKey() (string, error) { + b := make([]byte, 16) + _, err := io.ReadFull(rand.Reader, b) + if err != nil { + return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err) + } + return base64.StdEncoding.EncodeToString(b), nil +} + +func verifyServerResponse(r *http.Request, resp *http.Response, opts *DialOptions) (*compressionOptions, error) { + if resp.StatusCode != http.StatusSwitchingProtocols { + return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) + } + + if !headerContainsToken(resp.Header, "Connection", "Upgrade") { + return nil, fmt.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) + } + + if !headerContainsToken(resp.Header, "Upgrade", "WebSocket") { + return nil, fmt.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) + } + + if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")) { + return nil, fmt.Errorf("websocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", + resp.Header.Get("Sec-WebSocket-Accept"), + r.Header.Get("Sec-WebSocket-Key"), + ) + } + + if proto := resp.Header.Get("Sec-WebSocket-Protocol"); proto != "" && !headerContainsToken(r.Header, "Sec-WebSocket-Protocol", proto) { + return nil, fmt.Errorf("websocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) + } + + copts, err := verifyServerExtensions(resp.Header, opts.CompressionMode) + if err != nil { + return nil, err + } + + return copts, nil +} + +func verifyServerExtensions(h http.Header, mode CompressionMode) (*compressionOptions, error) { + exts := websocketExtensions(h) + if len(exts) == 0 { + return nil, nil + } + + ext := exts[0] + if ext.name != "permessage-deflate" { + return nil, fmt.Errorf("unexpected extension from server: %q", ext) + } + + if len(exts) > 1 { + return nil, fmt.Errorf("unexpected extra extensions from server: %+v", exts[1:]) + } + + copts := mode.opts() + for _, p := range ext.params { + switch p { + case "client_no_context_takeover": + copts.clientNoContextTakeover = true + continue + case "server_no_context_takeover": + copts.serverNoContextTakeover = true + continue + } + + return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p) + } + + return copts, nil +} diff --git a/dial_test.go b/dial_test.go new file mode 100644 index 0000000..391aa1c --- /dev/null +++ b/dial_test.go @@ -0,0 +1,149 @@ +// +build !js + +package websocket + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +func TestBadDials(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + url string + opts *DialOptions + }{ + { + name: "badURL", + url: "://noscheme", + }, + { + name: "badURLScheme", + url: "ftp://nhooyr.io", + }, + { + name: "badHTTPClient", + url: "ws://nhooyr.io", + opts: &DialOptions{ + HTTPClient: &http.Client{ + Timeout: time.Minute, + }, + }, + }, + { + name: "badTLS", + url: "wss://totallyfake.nhooyr.io", + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + _, _, err := Dial(ctx, tc.url, tc.opts) + if err == nil { + t.Fatalf("expected non nil error: %+v", err) + } + }) + } +} + +func Test_verifyServerHandshake(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + response func(w http.ResponseWriter) + success bool + }{ + { + name: "badStatus", + response: func(w http.ResponseWriter) { + w.WriteHeader(http.StatusOK) + }, + success: false, + }, + { + name: "badConnection", + response: func(w http.ResponseWriter) { + w.Header().Set("Connection", "???") + w.WriteHeader(http.StatusSwitchingProtocols) + }, + success: false, + }, + { + name: "badUpgrade", + response: func(w http.ResponseWriter) { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "???") + w.WriteHeader(http.StatusSwitchingProtocols) + }, + success: false, + }, + { + name: "badSecWebSocketAccept", + response: func(w http.ResponseWriter) { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "websocket") + w.Header().Set("Sec-WebSocket-Accept", "xd") + w.WriteHeader(http.StatusSwitchingProtocols) + }, + success: false, + }, + { + name: "badSecWebSocketProtocol", + response: func(w http.ResponseWriter) { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "websocket") + w.Header().Set("Sec-WebSocket-Protocol", "xd") + w.WriteHeader(http.StatusSwitchingProtocols) + }, + success: false, + }, + { + name: "success", + response: func(w http.ResponseWriter) { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "websocket") + w.WriteHeader(http.StatusSwitchingProtocols) + }, + success: true, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + tc.response(w) + resp := w.Result() + + r := httptest.NewRequest("GET", "/", nil) + key, err := secWebSocketKey() + if err != nil { + t.Fatal(err) + } + r.Header.Set("Sec-WebSocket-Key", key) + + if resp.Header.Get("Sec-WebSocket-Accept") == "" { + resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) + } + + _, err = verifyServerResponse(r, resp, &DialOptions{}) + if (err == nil) != tc.success { + t.Fatalf("unexpected error: %+v", err) + } + }) + } +} diff --git a/doc.go b/doc.go index 804665f..5285a78 100644 --- a/doc.go +++ b/doc.go @@ -1,6 +1,6 @@ // +build !js -// Package websocket is a minimal and idiomatic implementation of the WebSocket protocol. +// Package websocket implements the RFC 6455 WebSocket protocol. // // https://tools.ietf.org/html/rfc6455 // diff --git a/frame.go b/frame.go deleted file mode 100644 index e4bf931..0000000 --- a/frame.go +++ /dev/null @@ -1,445 +0,0 @@ -package websocket - -import ( - "encoding/binary" - "errors" - "fmt" - "io" - "math" - "math/bits" -) - -//go:generate stringer -type=opcode,MessageType,StatusCode -output=frame_stringer.go - -// opcode represents a WebSocket Opcode. -type opcode int - -// opcode constants. -const ( - opContinuation opcode = iota - opText - opBinary - // 3 - 7 are reserved for further non-control frames. - _ - _ - _ - _ - _ - opClose - opPing - opPong - // 11-16 are reserved for further control frames. -) - -func (o opcode) controlOp() bool { - switch o { - case opClose, opPing, opPong: - return true - } - return false -} - -// MessageType represents the type of a WebSocket message. -// See https://tools.ietf.org/html/rfc6455#section-5.6 -type MessageType int - -// MessageType constants. -const ( - // MessageText is for UTF-8 encoded text messages like JSON. - MessageText MessageType = iota + 1 - // MessageBinary is for binary messages like Protobufs. - MessageBinary -) - -// First byte contains fin, rsv1, rsv2, rsv3. -// Second byte contains mask flag and payload len. -// Next 8 bytes are the maximum extended payload length. -// Last 4 bytes are the mask key. -// https://tools.ietf.org/html/rfc6455#section-5.2 -const maxHeaderSize = 1 + 1 + 8 + 4 - -// header represents a WebSocket frame header. -// See https://tools.ietf.org/html/rfc6455#section-5.2 -type header struct { - fin bool - rsv1 bool - rsv2 bool - rsv3 bool - opcode opcode - - payloadLength int64 - - masked bool - maskKey uint32 -} - -func makeWriteHeaderBuf() []byte { - return make([]byte, maxHeaderSize) -} - -// bytes returns the bytes of the header. -// See https://tools.ietf.org/html/rfc6455#section-5.2 -func writeHeader(b []byte, h header) []byte { - if b == nil { - b = makeWriteHeaderBuf() - } - - b = b[:2] - b[0] = 0 - - if h.fin { - b[0] |= 1 << 7 - } - if h.rsv1 { - b[0] |= 1 << 6 - } - if h.rsv2 { - b[0] |= 1 << 5 - } - if h.rsv3 { - b[0] |= 1 << 4 - } - - b[0] |= byte(h.opcode) - - switch { - case h.payloadLength < 0: - panic(fmt.Sprintf("websocket: invalid header: negative length: %v", h.payloadLength)) - case h.payloadLength <= 125: - b[1] = byte(h.payloadLength) - case h.payloadLength <= math.MaxUint16: - b[1] = 126 - b = b[:len(b)+2] - binary.BigEndian.PutUint16(b[len(b)-2:], uint16(h.payloadLength)) - default: - b[1] = 127 - b = b[:len(b)+8] - binary.BigEndian.PutUint64(b[len(b)-8:], uint64(h.payloadLength)) - } - - if h.masked { - b[1] |= 1 << 7 - b = b[:len(b)+4] - binary.LittleEndian.PutUint32(b[len(b)-4:], h.maskKey) - } - - return b -} - -func makeReadHeaderBuf() []byte { - return make([]byte, maxHeaderSize-2) -} - -// readHeader reads a header from the reader. -// See https://tools.ietf.org/html/rfc6455#section-5.2 -func readHeader(b []byte, r io.Reader) (header, error) { - if b == nil { - b = makeReadHeaderBuf() - } - - // We read the first two bytes first so that we know - // exactly how long the header is. - b = b[:2] - _, err := io.ReadFull(r, b) - if err != nil { - return header{}, err - } - - var h header - h.fin = b[0]&(1<<7) != 0 - h.rsv1 = b[0]&(1<<6) != 0 - h.rsv2 = b[0]&(1<<5) != 0 - h.rsv3 = b[0]&(1<<4) != 0 - - h.opcode = opcode(b[0] & 0xf) - - var extra int - - h.masked = b[1]&(1<<7) != 0 - if h.masked { - extra += 4 - } - - payloadLength := b[1] &^ (1 << 7) - switch { - case payloadLength < 126: - h.payloadLength = int64(payloadLength) - case payloadLength == 126: - extra += 2 - case payloadLength == 127: - extra += 8 - } - - if extra == 0 { - return h, nil - } - - b = b[:extra] - _, err = io.ReadFull(r, b) - if err != nil { - return header{}, err - } - - switch { - case payloadLength == 126: - h.payloadLength = int64(binary.BigEndian.Uint16(b)) - b = b[2:] - case payloadLength == 127: - h.payloadLength = int64(binary.BigEndian.Uint64(b)) - if h.payloadLength < 0 { - return header{}, fmt.Errorf("header with negative payload length: %v", h.payloadLength) - } - b = b[8:] - } - - if h.masked { - h.maskKey = binary.LittleEndian.Uint32(b) - } - - return h, nil -} - -// StatusCode represents a WebSocket status code. -// https://tools.ietf.org/html/rfc6455#section-7.4 -type StatusCode int - -// These codes were retrieved from: -// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number -// -// The defined constants only represent the status codes registered with IANA. -// The 4000-4999 range of status codes is reserved for arbitrary use by applications. -const ( - StatusNormalClosure StatusCode = 1000 - StatusGoingAway StatusCode = 1001 - StatusProtocolError StatusCode = 1002 - StatusUnsupportedData StatusCode = 1003 - - // 1004 is reserved and so not exported. - statusReserved StatusCode = 1004 - - // StatusNoStatusRcvd cannot be sent in a close message. - // It is reserved for when a close message is received without - // an explicit status. - StatusNoStatusRcvd StatusCode = 1005 - - // StatusAbnormalClosure is only exported for use with Wasm. - // In non Wasm Go, the returned error will indicate whether the connection was closed or not or what happened. - StatusAbnormalClosure StatusCode = 1006 - - StatusInvalidFramePayloadData StatusCode = 1007 - StatusPolicyViolation StatusCode = 1008 - StatusMessageTooBig StatusCode = 1009 - StatusMandatoryExtension StatusCode = 1010 - StatusInternalError StatusCode = 1011 - StatusServiceRestart StatusCode = 1012 - StatusTryAgainLater StatusCode = 1013 - StatusBadGateway StatusCode = 1014 - - // StatusTLSHandshake is only exported for use with Wasm. - // In non Wasm Go, the returned error will indicate whether there was a TLS handshake failure. - StatusTLSHandshake StatusCode = 1015 -) - -// CloseError represents a WebSocket close frame. -// It is returned by Conn's methods when a WebSocket close frame is received from -// the peer. -// You will need to use the https://golang.org/pkg/errors/#As function, new in Go 1.13, -// to check for this error. See the CloseError example. -type CloseError struct { - Code StatusCode - Reason string -} - -func (ce CloseError) Error() string { - return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason) -} - -// CloseStatus is a convenience wrapper around errors.As to grab -// the status code from a *CloseError. If the passed error is nil -// or not a *CloseError, the returned StatusCode will be -1. -func CloseStatus(err error) StatusCode { - var ce CloseError - if errors.As(err, &ce) { - return ce.Code - } - return -1 -} - -func parseClosePayload(p []byte) (CloseError, error) { - if len(p) == 0 { - return CloseError{ - Code: StatusNoStatusRcvd, - }, nil - } - - if len(p) < 2 { - return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) - } - - ce := CloseError{ - Code: StatusCode(binary.BigEndian.Uint16(p)), - Reason: string(p[2:]), - } - - if !validWireCloseCode(ce.Code) { - return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code) - } - - return ce, nil -} - -// See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number -// and https://tools.ietf.org/html/rfc6455#section-7.4.1 -func validWireCloseCode(code StatusCode) bool { - switch code { - case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake: - return false - } - - if code >= StatusNormalClosure && code <= StatusBadGateway { - return true - } - if code >= 3000 && code <= 4999 { - return true - } - - return false -} - -const maxControlFramePayload = 125 - -func (ce CloseError) bytes() ([]byte, error) { - if len(ce.Reason) > maxControlFramePayload-2 { - return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxControlFramePayload-2, ce.Reason, len(ce.Reason)) - } - if !validWireCloseCode(ce.Code) { - return nil, fmt.Errorf("status code %v cannot be set", ce.Code) - } - - buf := make([]byte, 2+len(ce.Reason)) - binary.BigEndian.PutUint16(buf, uint16(ce.Code)) - copy(buf[2:], ce.Reason) - return buf, nil -} - -// fastMask applies the WebSocket masking algorithm to p -// with the given key. -// See https://tools.ietf.org/html/rfc6455#section-5.3 -// -// The returned value is the correctly rotated key to -// to continue to mask/unmask the message. -// -// It is optimized for LittleEndian and expects the key -// to be in little endian. -// -// See https://github.com/golang/go/issues/31586 -func mask(key uint32, b []byte) uint32 { - if len(b) >= 8 { - key64 := uint64(key)<<32 | uint64(key) - - // At some point in the future we can clean these unrolled loops up. - // See https://github.com/golang/go/issues/31586#issuecomment-487436401 - - // Then we xor until b is less than 128 bytes. - for len(b) >= 128 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^key64) - v = binary.LittleEndian.Uint64(b[8:16]) - binary.LittleEndian.PutUint64(b[8:16], v^key64) - v = binary.LittleEndian.Uint64(b[16:24]) - binary.LittleEndian.PutUint64(b[16:24], v^key64) - v = binary.LittleEndian.Uint64(b[24:32]) - binary.LittleEndian.PutUint64(b[24:32], v^key64) - v = binary.LittleEndian.Uint64(b[32:40]) - binary.LittleEndian.PutUint64(b[32:40], v^key64) - v = binary.LittleEndian.Uint64(b[40:48]) - binary.LittleEndian.PutUint64(b[40:48], v^key64) - v = binary.LittleEndian.Uint64(b[48:56]) - binary.LittleEndian.PutUint64(b[48:56], v^key64) - v = binary.LittleEndian.Uint64(b[56:64]) - binary.LittleEndian.PutUint64(b[56:64], v^key64) - v = binary.LittleEndian.Uint64(b[64:72]) - binary.LittleEndian.PutUint64(b[64:72], v^key64) - v = binary.LittleEndian.Uint64(b[72:80]) - binary.LittleEndian.PutUint64(b[72:80], v^key64) - v = binary.LittleEndian.Uint64(b[80:88]) - binary.LittleEndian.PutUint64(b[80:88], v^key64) - v = binary.LittleEndian.Uint64(b[88:96]) - binary.LittleEndian.PutUint64(b[88:96], v^key64) - v = binary.LittleEndian.Uint64(b[96:104]) - binary.LittleEndian.PutUint64(b[96:104], v^key64) - v = binary.LittleEndian.Uint64(b[104:112]) - binary.LittleEndian.PutUint64(b[104:112], v^key64) - v = binary.LittleEndian.Uint64(b[112:120]) - binary.LittleEndian.PutUint64(b[112:120], v^key64) - v = binary.LittleEndian.Uint64(b[120:128]) - binary.LittleEndian.PutUint64(b[120:128], v^key64) - b = b[128:] - } - - // Then we xor until b is less than 64 bytes. - for len(b) >= 64 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^key64) - v = binary.LittleEndian.Uint64(b[8:16]) - binary.LittleEndian.PutUint64(b[8:16], v^key64) - v = binary.LittleEndian.Uint64(b[16:24]) - binary.LittleEndian.PutUint64(b[16:24], v^key64) - v = binary.LittleEndian.Uint64(b[24:32]) - binary.LittleEndian.PutUint64(b[24:32], v^key64) - v = binary.LittleEndian.Uint64(b[32:40]) - binary.LittleEndian.PutUint64(b[32:40], v^key64) - v = binary.LittleEndian.Uint64(b[40:48]) - binary.LittleEndian.PutUint64(b[40:48], v^key64) - v = binary.LittleEndian.Uint64(b[48:56]) - binary.LittleEndian.PutUint64(b[48:56], v^key64) - v = binary.LittleEndian.Uint64(b[56:64]) - binary.LittleEndian.PutUint64(b[56:64], v^key64) - b = b[64:] - } - - // Then we xor until b is less than 32 bytes. - for len(b) >= 32 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^key64) - v = binary.LittleEndian.Uint64(b[8:16]) - binary.LittleEndian.PutUint64(b[8:16], v^key64) - v = binary.LittleEndian.Uint64(b[16:24]) - binary.LittleEndian.PutUint64(b[16:24], v^key64) - v = binary.LittleEndian.Uint64(b[24:32]) - binary.LittleEndian.PutUint64(b[24:32], v^key64) - b = b[32:] - } - - // Then we xor until b is less than 16 bytes. - for len(b) >= 16 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^key64) - v = binary.LittleEndian.Uint64(b[8:16]) - binary.LittleEndian.PutUint64(b[8:16], v^key64) - b = b[16:] - } - - // Then we xor until b is less than 8 bytes. - for len(b) >= 8 { - v := binary.LittleEndian.Uint64(b) - binary.LittleEndian.PutUint64(b, v^key64) - b = b[8:] - } - } - - // Then we xor until b is less than 4 bytes. - for len(b) >= 4 { - v := binary.LittleEndian.Uint32(b) - binary.LittleEndian.PutUint32(b, v^key) - b = b[4:] - } - - // xor remaining bytes. - for i := range b { - b[i] ^= byte(key) - key = bits.RotateLeft32(key, -8) - } - - return key -} diff --git a/frame_test.go b/frame_test.go deleted file mode 100644 index 571e68f..0000000 --- a/frame_test.go +++ /dev/null @@ -1,457 +0,0 @@ -// +build !js - -package websocket - -import ( - "bytes" - "encoding/binary" - "io" - "math" - "math/bits" - "math/rand" - "strconv" - "strings" - "testing" - "time" - _ "unsafe" - - "github.com/gobwas/ws" - "github.com/google/go-cmp/cmp" - _ "github.com/gorilla/websocket" - - "nhooyr.io/websocket/internal/assert" -) - -func init() { - rand.Seed(time.Now().UnixNano()) -} - -func randBool() bool { - return rand.Intn(1) == 0 -} - -func TestHeader(t *testing.T) { - t.Parallel() - - t.Run("eof", func(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - bytes []byte - }{ - { - "start", - []byte{0xff}, - }, - { - "middle", - []byte{0xff, 0xff, 0xff}, - }, - } - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - b := bytes.NewBuffer(tc.bytes) - _, err := readHeader(nil, b) - if io.ErrUnexpectedEOF != err { - t.Fatalf("expected %v but got: %v", io.ErrUnexpectedEOF, err) - } - }) - } - }) - - t.Run("writeNegativeLength", func(t *testing.T) { - t.Parallel() - - defer func() { - r := recover() - if r == nil { - t.Fatal("failed to induce panic in writeHeader with negative payload length") - } - }() - - writeHeader(nil, header{ - payloadLength: -1, - }) - }) - - t.Run("readNegativeLength", func(t *testing.T) { - t.Parallel() - - b := writeHeader(nil, header{ - payloadLength: 1<<16 + 1, - }) - - // Make length negative - b[2] |= 1 << 7 - - r := bytes.NewReader(b) - _, err := readHeader(nil, r) - if err == nil { - t.Fatalf("unexpected error value: %+v", err) - } - }) - - t.Run("lengths", func(t *testing.T) { - t.Parallel() - - lengths := []int{ - 124, - 125, - 126, - 4096, - 16384, - 65535, - 65536, - 65537, - 131072, - } - - for _, n := range lengths { - n := n - t.Run(strconv.Itoa(n), func(t *testing.T) { - t.Parallel() - - testHeader(t, header{ - payloadLength: int64(n), - }) - }) - } - }) - - t.Run("fuzz", func(t *testing.T) { - t.Parallel() - - for i := 0; i < 10000; i++ { - h := header{ - fin: randBool(), - rsv1: randBool(), - rsv2: randBool(), - rsv3: randBool(), - opcode: opcode(rand.Intn(1 << 4)), - - masked: randBool(), - payloadLength: rand.Int63(), - } - - if h.masked { - h.maskKey = rand.Uint32() - } - - testHeader(t, h) - } - }) -} - -func testHeader(t *testing.T, h header) { - b := writeHeader(nil, h) - r := bytes.NewReader(b) - h2, err := readHeader(nil, r) - if err != nil { - t.Logf("header: %#v", h) - t.Logf("bytes: %b", b) - t.Fatalf("failed to read header: %v", err) - } - - if !cmp.Equal(h, h2, cmp.AllowUnexported(header{})) { - t.Logf("header: %#v", h) - t.Logf("bytes: %b", b) - t.Fatalf("parsed and read header differ: %v", cmp.Diff(h, h2, cmp.AllowUnexported(header{}))) - } -} - -func TestCloseError(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - ce CloseError - success bool - }{ - { - name: "normal", - ce: CloseError{ - Code: StatusNormalClosure, - Reason: strings.Repeat("x", maxControlFramePayload-2), - }, - success: true, - }, - { - name: "bigReason", - ce: CloseError{ - Code: StatusNormalClosure, - Reason: strings.Repeat("x", maxControlFramePayload-1), - }, - success: false, - }, - { - name: "bigCode", - ce: CloseError{ - Code: math.MaxUint16, - Reason: strings.Repeat("x", maxControlFramePayload-2), - }, - success: false, - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - _, err := tc.ce.bytes() - if (err == nil) != tc.success { - t.Fatalf("unexpected error value: %+v", err) - } - }) - } -} - -func Test_parseClosePayload(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - p []byte - success bool - ce CloseError - }{ - { - name: "normal", - p: append([]byte{0x3, 0xE8}, []byte("hello")...), - success: true, - ce: CloseError{ - Code: StatusNormalClosure, - Reason: "hello", - }, - }, - { - name: "nothing", - success: true, - ce: CloseError{ - Code: StatusNoStatusRcvd, - }, - }, - { - name: "oneByte", - p: []byte{0}, - success: false, - }, - { - name: "badStatusCode", - p: []byte{0x17, 0x70}, - success: false, - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - ce, err := parseClosePayload(tc.p) - if (err == nil) != tc.success { - t.Fatalf("unexpected expected error value: %+v", err) - } - - if tc.success && tc.ce != ce { - t.Fatalf("unexpected close error: %v", cmp.Diff(tc.ce, ce)) - } - }) - } -} - -func Test_validWireCloseCode(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - code StatusCode - valid bool - }{ - { - name: "normal", - code: StatusNormalClosure, - valid: true, - }, - { - name: "noStatus", - code: StatusNoStatusRcvd, - valid: false, - }, - { - name: "3000", - code: 3000, - valid: true, - }, - { - name: "4999", - code: 4999, - valid: true, - }, - { - name: "unknown", - code: 5000, - valid: false, - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - if valid := validWireCloseCode(tc.code); tc.valid != valid { - t.Fatalf("expected %v for %v but got %v", tc.valid, tc.code, valid) - } - }) - } -} - -func Test_mask(t *testing.T) { - t.Parallel() - - key := []byte{0xa, 0xb, 0xc, 0xff} - key32 := binary.LittleEndian.Uint32(key) - p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc} - gotKey32 := mask(key32, p) - - if exp := []byte{0, 0, 0, 0x0d, 0x6}; !cmp.Equal(exp, p) { - t.Fatalf("unexpected mask: %v", cmp.Diff(exp, p)) - } - - if exp := bits.RotateLeft32(key32, -8); !cmp.Equal(exp, gotKey32) { - t.Fatalf("unexpected mask key: %v", cmp.Diff(exp, gotKey32)) - } -} - -func basicMask(maskKey [4]byte, pos int, b []byte) int { - for i := range b { - b[i] ^= maskKey[pos&3] - pos++ - } - return pos & 3 -} - -//go:linkname gorillaMaskBytes github.com/gorilla/websocket.maskBytes -func gorillaMaskBytes(key [4]byte, pos int, b []byte) int - -func Benchmark_mask(b *testing.B) { - sizes := []int{ - 2, - 3, - 4, - 8, - 16, - 32, - 128, - 512, - 4096, - 16384, - } - - fns := []struct { - name string - fn func(b *testing.B, key [4]byte, p []byte) - }{ - { - name: "basic", - fn: func(b *testing.B, key [4]byte, p []byte) { - for i := 0; i < b.N; i++ { - basicMask(key, 0, p) - } - }, - }, - - { - name: "nhooyr", - fn: func(b *testing.B, key [4]byte, p []byte) { - key32 := binary.LittleEndian.Uint32(key[:]) - b.ResetTimer() - - for i := 0; i < b.N; i++ { - mask(key32, p) - } - }, - }, - { - name: "gorilla", - fn: func(b *testing.B, key [4]byte, p []byte) { - for i := 0; i < b.N; i++ { - gorillaMaskBytes(key, 0, p) - } - }, - }, - { - name: "gobwas", - fn: func(b *testing.B, key [4]byte, p []byte) { - for i := 0; i < b.N; i++ { - ws.Cipher(p, key, 0) - } - }, - }, - } - - var key [4]byte - _, err := rand.Read(key[:]) - if err != nil { - b.Fatalf("failed to populate mask key: %v", err) - } - - for _, size := range sizes { - p := make([]byte, size) - - b.Run(strconv.Itoa(size), func(b *testing.B) { - for _, fn := range fns { - b.Run(fn.name, func(b *testing.B) { - b.SetBytes(int64(size)) - - fn.fn(b, key, p) - }) - } - }) - } -} - -func TestCloseStatus(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - in error - exp StatusCode - }{ - { - name: "nil", - in: nil, - exp: -1, - }, - { - name: "io.EOF", - in: io.EOF, - exp: -1, - }, - { - name: "StatusInternalError", - in: CloseError{ - Code: StatusInternalError, - }, - exp: StatusInternalError, - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - err := assert.Equalf(tc.exp, CloseStatus(tc.in), "unexpected close status") - if err != nil { - t.Fatal(err) - } - }) - } -} diff --git a/handshake.go b/handshake.go deleted file mode 100644 index 0333103..0000000 --- a/handshake.go +++ /dev/null @@ -1,637 +0,0 @@ -// +build !js - -package websocket - -import ( - "bufio" - "bytes" - "context" - "crypto/rand" - "crypto/sha1" - "encoding/base64" - "errors" - "fmt" - "io" - "io/ioutil" - "net/http" - "net/textproto" - "net/url" - "strings" - "sync" -) - -// AcceptOptions represents the options available to pass to Accept. -type AcceptOptions struct { - // Subprotocols lists the websocket subprotocols that Accept will negotiate with a client. - // The empty subprotocol will always be negotiated as per RFC 6455. If you would like to - // reject it, close the connection if c.Subprotocol() == "". - Subprotocols []string - - // InsecureSkipVerify disables Accept's origin verification - // behaviour. By default Accept only allows the handshake to - // succeed if the javascript that is initiating the handshake - // is on the same domain as the server. This is to prevent CSRF - // attacks when secure data is stored in a cookie as there is no same - // origin policy for WebSockets. In other words, javascript from - // any domain can perform a WebSocket dial on an arbitrary server. - // This dial will include cookies which means the arbitrary javascript - // can perform actions as the authenticated user. - // - // See https://stackoverflow.com/a/37837709/4283659 - // - // The only time you need this is if your javascript is running on a different domain - // than your WebSocket server. - // Think carefully about whether you really need this option before you use it. - // If you do, remember that if you store secure data in cookies, you wil need to verify the - // Origin header yourself otherwise you are exposing yourself to a CSRF attack. - InsecureSkipVerify bool - - // Compression sets the compression options. - // By default, compression is disabled. - // See docs on the CompressionOptions type. - Compression *CompressionOptions -} - -func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { - if !r.ProtoAtLeast(1, 1) { - err := fmt.Errorf("websocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto) - http.Error(w, err.Error(), http.StatusBadRequest) - return err - } - - if !headerContainsToken(r.Header, "Connection", "Upgrade") { - err := fmt.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) - http.Error(w, err.Error(), http.StatusBadRequest) - return err - } - - if !headerContainsToken(r.Header, "Upgrade", "WebSocket") { - err := fmt.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) - http.Error(w, err.Error(), http.StatusBadRequest) - return err - } - - if r.Method != "GET" { - err := fmt.Errorf("websocket protocol violation: handshake request method is not GET but %q", r.Method) - http.Error(w, err.Error(), http.StatusBadRequest) - return err - } - - if r.Header.Get("Sec-WebSocket-Version") != "13" { - err := fmt.Errorf("unsupported websocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version")) - http.Error(w, err.Error(), http.StatusBadRequest) - return err - } - - if r.Header.Get("Sec-WebSocket-Key") == "" { - err := errors.New("websocket protocol violation: missing Sec-WebSocket-Key") - http.Error(w, err.Error(), http.StatusBadRequest) - return err - } - - return nil -} - -// Accept accepts a WebSocket handshake from a client and upgrades the -// the connection to a WebSocket. -// -// Accept will reject the handshake if the Origin domain is not the same as the Host unless -// the InsecureSkipVerify option is set. In other words, by default it does not allow -// cross origin requests. -// -// If an error occurs, Accept will always write an appropriate response so you do not -// have to. -func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { - c, err := accept(w, r, opts) - if err != nil { - return nil, fmt.Errorf("failed to accept websocket connection: %w", err) - } - return c, nil -} - -func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { - if opts == nil { - opts = &AcceptOptions{} - } - - err := verifyClientRequest(w, r) - if err != nil { - return nil, err - } - - if !opts.InsecureSkipVerify { - err = authenticateOrigin(r) - if err != nil { - http.Error(w, err.Error(), http.StatusForbidden) - return nil, err - } - } - - hj, ok := w.(http.Hijacker) - if !ok { - err = errors.New("passed ResponseWriter does not implement http.Hijacker") - http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) - return nil, err - } - - w.Header().Set("Upgrade", "websocket") - w.Header().Set("Connection", "Upgrade") - - handleSecWebSocketKey(w, r) - - subproto := selectSubprotocol(r, opts.Subprotocols) - if subproto != "" { - w.Header().Set("Sec-WebSocket-Protocol", subproto) - } - - var copts *CompressionOptions - if opts.Compression != nil { - copts, err = negotiateCompression(r.Header, opts.Compression) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return nil, err - } - if copts != nil { - copts.setHeader(w.Header(), false) - } - } - - w.WriteHeader(http.StatusSwitchingProtocols) - - netConn, brw, err := hj.Hijack() - if err != nil { - err = fmt.Errorf("failed to hijack connection: %w", err) - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) - return nil, err - } - - // https://github.com/golang/go/issues/32314 - b, _ := brw.Reader.Peek(brw.Reader.Buffered()) - brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn)) - - c := &Conn{ - subprotocol: w.Header().Get("Sec-WebSocket-Protocol"), - br: brw.Reader, - bw: brw.Writer, - closer: netConn, - copts: copts, - } - c.init() - - return c, nil -} - -func headerContainsToken(h http.Header, key, token string) bool { - key = textproto.CanonicalMIMEHeaderKey(key) - - token = strings.ToLower(token) - match := func(t string) bool { - return t == token - } - - for _, v := range h[key] { - if searchHeaderTokens(v, match) { - return true - } - } - - return false -} - -// readCompressionExtensionHeader extracts compression extension info from h. -// The standard says we should support multiple compression extension configurations -// from the client but we don't need to as there is only a single deflate extension -// and we support every configuration without error so we only need to check the first -// and thus preferred configuration. -func readCompressionExtensionHeader(h http.Header) (xWebkitDeflateFrame bool, params []string, ok bool) { - match := func(t string) bool { - vals := strings.Split(t, ";") - for i := range vals { - vals[i] = strings.TrimSpace(vals[i]) - } - params = vals[1:] - - if vals[0] == "permessage-deflate" { - return true - } - - // See https://bugs.webkit.org/show_bug.cgi?id=115504 - if vals[0] == "x-webkit-deflate-frame" { - xWebkitDeflateFrame = true - return true - } - - return false - } - - key := textproto.CanonicalMIMEHeaderKey("Sec-WebSocket-Extensions") - for _, v := range h[key] { - if searchHeaderTokens(v, match) { - return xWebkitDeflateFrame, params, true - } - } - - return false, nil, false -} - -func searchHeaderTokens(v string, match func(val string) bool) bool { - v = strings.ToLower(v) - v = strings.TrimSpace(v) - - for _, v2 := range strings.Split(v, ",") { - v2 = strings.TrimSpace(v2) - if match(v2) { - return true - } - } - - return false -} - -func selectSubprotocol(r *http.Request, subprotocols []string) string { - for _, sp := range subprotocols { - if headerContainsToken(r.Header, "Sec-WebSocket-Protocol", sp) { - return sp - } - } - return "" -} - -var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") - -func handleSecWebSocketKey(w http.ResponseWriter, r *http.Request) { - key := r.Header.Get("Sec-WebSocket-Key") - w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) -} - -func secWebSocketAccept(secWebSocketKey string) string { - h := sha1.New() - h.Write([]byte(secWebSocketKey)) - h.Write(keyGUID) - - return base64.StdEncoding.EncodeToString(h.Sum(nil)) -} - -func authenticateOrigin(r *http.Request) error { - origin := r.Header.Get("Origin") - if origin == "" { - return nil - } - u, err := url.Parse(origin) - if err != nil { - return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) - } - if !strings.EqualFold(u.Host, r.Host) { - return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) - } - return nil -} - -// DialOptions represents the options available to pass to Dial. -type DialOptions struct { - // HTTPClient is the http client used for the handshake. - // Its Transport must return writable bodies - // for WebSocket handshakes. - // http.Transport does this correctly beginning with Go 1.12. - HTTPClient *http.Client - - // HTTPHeader specifies the HTTP headers included in the handshake request. - HTTPHeader http.Header - - // Subprotocols lists the subprotocols to negotiate with the server. - Subprotocols []string - - // Compression sets the compression options. - // By default, compression is disabled. - // See docs on the CompressionOptions type. - Compression *CompressionOptions -} - -// CompressionOptions describes the available compression options. -// -// See https://tools.ietf.org/html/rfc7692 -// -// The NoContextTakeover variables control whether a flate.Writer or flate.Reader is allocated -// for every connection (context takeover) versus shared from a pool (no context takeover). -// -// The advantage to context takeover is more efficient compression as the sliding window from previous -// messages will be used instead of being reset between every message. -// -// The advantage to no context takeover is that the flate structures are allocated as needed -// and shared between connections instead of giving each connection a fixed flate.Writer and -// flate.Reader. -// -// See https://www.igvita.com/2013/11/27/configuring-and-optimizing-websocket-compression. -// -// Enabling compression will increase memory and CPU usage and should -// be profiled before enabling in production. -// See https://github.com/gorilla/websocket/issues/203 -// -// This API is experimental and subject to change. -type CompressionOptions struct { - // ClientNoContextTakeover controls whether the client should use context takeover. - // See docs on CompressionOptions for discussion regarding context takeover. - // - // If set by the server, will guarantee that the client does not use context takeover. - ClientNoContextTakeover bool - - // ServerNoContextTakeover controls whether the server should use context takeover. - // See docs on CompressionOptions for discussion regarding context takeover. - // - // If set by the client, will guarantee that the server does not use context takeover. - ServerNoContextTakeover bool - - // Level controls the compression level used. - // Defaults to flate.BestSpeed. - Level int - - // Threshold controls the minimum message size in bytes before compression is used. - // Must not be greater than 4096 as that is the write buffer's size. - // - // Defaults to 256. - Threshold int - - // This is used for supporting Safari as it still uses x-webkit-deflate-frame. - // See negotiateCompression. - xWebkitDeflateFrame bool -} - -// Dial performs a WebSocket handshake on the given url with the given options. -// The response is the WebSocket handshake response from the server. -// If an error occurs, the returned response may be non nil. However, you can only -// read the first 1024 bytes of its body. -// -// You never need to close the resp.Body yourself. -// -// This function requires at least Go 1.12 to succeed as it uses a new feature -// in net/http to perform WebSocket handshakes and get a writable body -// from the transport. See https://github.com/golang/go/issues/26937#issuecomment-415855861 -func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) { - c, r, err := dial(ctx, u, opts) - if err != nil { - return nil, r, fmt.Errorf("failed to websocket dial: %w", err) - } - return c, r, nil -} - -func (opts *DialOptions) ensure() (*DialOptions, error) { - if opts == nil { - opts = &DialOptions{} - } else { - opts = &*opts - } - - if opts.HTTPClient == nil { - opts.HTTPClient = http.DefaultClient - } - if opts.HTTPClient.Timeout > 0 { - return nil, fmt.Errorf("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67") - } - if opts.HTTPHeader == nil { - opts.HTTPHeader = http.Header{} - } - - return opts, nil -} - -func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Response, err error) { - opts, err = opts.ensure() - if err != nil { - return nil, nil, err - } - - parsedURL, err := url.Parse(u) - if err != nil { - return nil, nil, fmt.Errorf("failed to parse url: %w", err) - } - - switch parsedURL.Scheme { - case "ws": - parsedURL.Scheme = "http" - case "wss": - parsedURL.Scheme = "https" - default: - return nil, nil, fmt.Errorf("unexpected url scheme: %q", parsedURL.Scheme) - } - - req, _ := http.NewRequest("GET", parsedURL.String(), nil) - req = req.WithContext(ctx) - req.Header = opts.HTTPHeader - req.Header.Set("Connection", "Upgrade") - req.Header.Set("Upgrade", "websocket") - req.Header.Set("Sec-WebSocket-Version", "13") - secWebSocketKey, err := makeSecWebSocketKey() - if err != nil { - return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err) - } - req.Header.Set("Sec-WebSocket-Key", secWebSocketKey) - if len(opts.Subprotocols) > 0 { - req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) - } - if opts.Compression != nil { - opts.Compression.setHeader(req.Header, true) - } - - resp, err := opts.HTTPClient.Do(req) - if err != nil { - return nil, nil, fmt.Errorf("failed to send handshake request: %w", err) - } - defer func() { - if err != nil { - // We read a bit of the body for easier debugging. - r := io.LimitReader(resp.Body, 1024) - b, _ := ioutil.ReadAll(r) - resp.Body.Close() - resp.Body = ioutil.NopCloser(bytes.NewReader(b)) - } - }() - - copts, err := verifyServerResponse(req, resp, opts) - if err != nil { - return nil, resp, err - } - - rwc, ok := resp.Body.(io.ReadWriteCloser) - if !ok { - return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", resp.Body) - } - - c := &Conn{ - subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"), - br: getBufioReader(rwc), - bw: getBufioWriter(rwc), - closer: rwc, - client: true, - copts: copts, - } - c.extractBufioWriterBuf(rwc) - c.init() - - return c, resp, nil -} - -func verifyServerResponse(r *http.Request, resp *http.Response, opts *DialOptions) (*CompressionOptions, error) { - if resp.StatusCode != http.StatusSwitchingProtocols { - return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) - } - - if !headerContainsToken(resp.Header, "Connection", "Upgrade") { - return nil, fmt.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) - } - - if !headerContainsToken(resp.Header, "Upgrade", "WebSocket") { - return nil, fmt.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) - } - - if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")) { - return nil, fmt.Errorf("websocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", - resp.Header.Get("Sec-WebSocket-Accept"), - r.Header.Get("Sec-WebSocket-Key"), - ) - } - - if proto := resp.Header.Get("Sec-WebSocket-Protocol"); proto != "" && !headerContainsToken(r.Header, "Sec-WebSocket-Protocol", proto) { - return nil, fmt.Errorf("websocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) - } - - var copts *CompressionOptions - if opts.Compression != nil { - var err error - copts, err = negotiateCompression(resp.Header, opts.Compression) - if err != nil { - return nil, err - } - } - - return copts, nil -} - -// The below pools can only be used by the client because http.Hijacker will always -// have a bufio.Reader/Writer for us so it doesn't make sense to use a pool on top. - -var bufioReaderPool = sync.Pool{ - New: func() interface{} { - return bufio.NewReader(nil) - }, -} - -func getBufioReader(r io.Reader) *bufio.Reader { - br := bufioReaderPool.Get().(*bufio.Reader) - br.Reset(r) - return br -} - -func returnBufioReader(br *bufio.Reader) { - bufioReaderPool.Put(br) -} - -var bufioWriterPool = sync.Pool{ - New: func() interface{} { - return bufio.NewWriter(nil) - }, -} - -func getBufioWriter(w io.Writer) *bufio.Writer { - bw := bufioWriterPool.Get().(*bufio.Writer) - bw.Reset(w) - return bw -} - -func returnBufioWriter(bw *bufio.Writer) { - bufioWriterPool.Put(bw) -} - -func makeSecWebSocketKey() (string, error) { - b := make([]byte, 16) - _, err := io.ReadFull(rand.Reader, b) - if err != nil { - return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err) - } - return base64.StdEncoding.EncodeToString(b), nil -} - -func negotiateCompression(h http.Header, copts *CompressionOptions) (*CompressionOptions, error) { - xWebkitDeflateFrame, params, ok := readCompressionExtensionHeader(h) - if !ok { - return nil, nil - } - - // Ensures our changes do not modify the real compression options. - copts = &*copts - copts.xWebkitDeflateFrame = xWebkitDeflateFrame - - // We are the client if the header contains the accept header, meaning its from the server. - client := h.Get("Sec-WebSocket-Accept") == "" - - if copts.xWebkitDeflateFrame { - // The other endpoint dictates whether or not we can - // use context takeover on our side. We cannot force it. - // Likewise, we tell the other side so we can force that. - if client { - copts.ClientNoContextTakeover = false - } else { - copts.ServerNoContextTakeover = false - } - } - - for _, p := range params { - switch p { - case "client_no_context_takeover": - copts.ClientNoContextTakeover = true - continue - case "server_no_context_takeover": - copts.ServerNoContextTakeover = true - continue - case "client_max_window_bits", "server-max-window-bits": - if !client { - // If we are the server, we are allowed to ignore these parameters. - // However, if we are the client, we must obey them but because of - // https://github.com/golang/go/issues/3155 we cannot. - continue - } - case "no_context_takeover": - if copts.xWebkitDeflateFrame { - if client { - copts.ClientNoContextTakeover = true - } else { - copts.ServerNoContextTakeover = true - } - continue - } - - // We explicitly fail on x-webkit-deflate-frame's max_window_bits parameter instead - // of ignoring it as the draft spec is unclear. It says the server can ignore it - // but the server has no way of signalling to the client it was ignored as parameters - // are set one way. - // Thus us ignoring it would make the client think we understood it which would cause issues. - // See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06#section-4.1 - // - // Either way, we're only implementing this for webkit which never sends the max_window_bits - // parameter so we don't need to worry about it. - } - - return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p) - } - - return copts, nil -} - -func (copts *CompressionOptions) setHeader(h http.Header, client bool) { - var s string - if !copts.xWebkitDeflateFrame { - s := "permessage-deflate" - if copts.ClientNoContextTakeover { - s += "; client_no_context_takeover" - } - if copts.ServerNoContextTakeover { - s += "; server_no_context_takeover" - } - } else { - s = "x-webkit-deflate-frame" - // We can only set no context takeover for the peer. - if client && copts.ServerNoContextTakeover || !client && copts.ClientNoContextTakeover { - s += "; no_context_takeover" - } - } - h.Set("Sec-WebSocket-Extensions", s) -} diff --git a/internal/assert/assert.go b/internal/assert/assert.go index e57abfd..372d546 100644 --- a/internal/assert/assert.go +++ b/internal/assert/assert.go @@ -1,8 +1,8 @@ package assert import ( - "fmt" "reflect" + "testing" "github.com/google/go-cmp/cmp" ) @@ -53,11 +53,15 @@ func structTypes(v reflect.Value, m map[reflect.Type]struct{}) { } } -// Equalf compares exp to act and if they are not equal, returns -// an error describing an error. -func Equalf(exp, act interface{}, f string, v ...interface{}) error { - if diff := cmpDiff(exp, act); diff != "" { - return fmt.Errorf(f+": %v", append(v, diff)...) +func Equalf(t *testing.T, exp, act interface{}, f string, v ...interface{}) { + t.Helper() + diff := cmpDiff(exp, act) + if diff != "" { + t.Fatalf(f+": %v", append(v, diff)...) } - return nil +} + +func Success(t *testing.T, err error) { + t.Helper() + Equalf(t, error(nil), err, "unexpected failure") } diff --git a/internal/atomicint/atomicint.go b/internal/atomicint/atomicint.go new file mode 100644 index 0000000..668b3b4 --- /dev/null +++ b/internal/atomicint/atomicint.go @@ -0,0 +1,32 @@ +package atomicint + +import ( + "fmt" + "sync/atomic" +) + +// See https://github.com/nhooyr/websocket/issues/153 +type Int64 struct { + v int64 +} + +func (v *Int64) Load() int64 { + return atomic.LoadInt64(&v.v) +} + +func (v *Int64) Store(i int64) { + atomic.StoreInt64(&v.v, i) +} + +func (v *Int64) String() string { + return fmt.Sprint(v.Load()) +} + +// Increment increments the value and returns the new value. +func (v *Int64) Increment(delta int64) int64 { + return atomic.AddInt64(&v.v, delta) +} + +func (v *Int64) CAS(old, new int64) (swapped bool) { + return atomic.CompareAndSwapInt64(&v.v, old, new) +} diff --git a/internal/bpool/bpool.go b/internal/bufpool/buf.go similarity index 95% rename from internal/bpool/bpool.go rename to internal/bufpool/buf.go index 4266c23..324a17e 100644 --- a/internal/bpool/bpool.go +++ b/internal/bufpool/buf.go @@ -1,4 +1,4 @@ -package bpool +package bufpool import ( "bytes" diff --git a/internal/bpool/bpool_test.go b/internal/bufpool/buf_test.go similarity index 97% rename from internal/bpool/bpool_test.go rename to internal/bufpool/buf_test.go index 5dfe56e..42a2fea 100644 --- a/internal/bpool/bpool_test.go +++ b/internal/bufpool/buf_test.go @@ -1,4 +1,4 @@ -package bpool +package bufpool import ( "strconv" diff --git a/internal/bufpool/bufio.go b/internal/bufpool/bufio.go new file mode 100644 index 0000000..875bbf4 --- /dev/null +++ b/internal/bufpool/bufio.go @@ -0,0 +1,40 @@ +package bufpool + +import ( + "bufio" + "io" + "sync" +) + +var readerPool = sync.Pool{ + New: func() interface{} { + return bufio.NewReader(nil) + }, +} + +func GetReader(r io.Reader) *bufio.Reader { + br := readerPool.Get().(*bufio.Reader) + br.Reset(r) + return br +} + +func PutReader(br *bufio.Reader) { + readerPool.Put(br) +} + +var writerPool = sync.Pool{ + New: func() interface{} { + return bufio.NewWriter(nil) + }, +} + +func GetWriter(w io.Writer) *bufio.Writer { + bw := writerPool.Get().(*bufio.Writer) + bw.Reset(w) + return bw +} + +func PutWriter(bw *bufio.Writer) { + writerPool.Put(bw) +} + diff --git a/internal/wsframe/frame.go b/internal/wsframe/frame.go new file mode 100644 index 0000000..50ff8c1 --- /dev/null +++ b/internal/wsframe/frame.go @@ -0,0 +1,194 @@ +package wsframe + +import ( + "encoding/binary" + "fmt" + "io" + "math" +) + +// Opcode represents a WebSocket Opcode. +type Opcode int + +// Opcode constants. +const ( + OpContinuation Opcode = iota + OpText + OpBinary + // 3 - 7 are reserved for further non-control frames. + _ + _ + _ + _ + _ + OpClose + OpPing + OpPong + // 11-16 are reserved for further control frames. +) + +func (o Opcode) Control() bool { + switch o { + case OpClose, OpPing, OpPong: + return true + } + return false +} + +func (o Opcode) Data() bool { + switch o { + case OpText, OpBinary: + return true + } + return false +} + +// First byte contains fin, rsv1, rsv2, rsv3. +// Second byte contains mask flag and payload len. +// Next 8 bytes are the maximum extended payload length. +// Last 4 bytes are the mask key. +// https://tools.ietf.org/html/rfc6455#section-5.2 +const maxHeaderSize = 1 + 1 + 8 + 4 + +// Header represents a WebSocket frame Header. +// See https://tools.ietf.org/html/rfc6455#section-5.2 +type Header struct { + Fin bool + RSV1 bool + RSV2 bool + RSV3 bool + Opcode Opcode + + PayloadLength int64 + + Masked bool + MaskKey uint32 +} + +// bytes returns the bytes of the Header. +// See https://tools.ietf.org/html/rfc6455#section-5.2 +func (h Header) Bytes(b []byte) []byte { + if b == nil { + b = make([]byte, maxHeaderSize) + } + + b = b[:2] + b[0] = 0 + + if h.Fin { + b[0] |= 1 << 7 + } + if h.RSV1 { + b[0] |= 1 << 6 + } + if h.RSV2 { + b[0] |= 1 << 5 + } + if h.RSV3 { + b[0] |= 1 << 4 + } + + b[0] |= byte(h.Opcode) + + switch { + case h.PayloadLength < 0: + panic(fmt.Sprintf("websocket: invalid Header: negative length: %v", h.PayloadLength)) + case h.PayloadLength <= 125: + b[1] = byte(h.PayloadLength) + case h.PayloadLength <= math.MaxUint16: + b[1] = 126 + b = b[:len(b)+2] + binary.BigEndian.PutUint16(b[len(b)-2:], uint16(h.PayloadLength)) + default: + b[1] = 127 + b = b[:len(b)+8] + binary.BigEndian.PutUint64(b[len(b)-8:], uint64(h.PayloadLength)) + } + + if h.Masked { + b[1] |= 1 << 7 + b = b[:len(b)+4] + binary.LittleEndian.PutUint32(b[len(b)-4:], h.MaskKey) + } + + return b +} + +func MakeReadHeaderBuf() []byte { + return make([]byte, maxHeaderSize-2) +} + +// ReadHeader reads a Header from the reader. +// See https://tools.ietf.org/html/rfc6455#section-5.2 +func ReadHeader(r io.Reader, b []byte) (Header, error) { + // We read the first two bytes first so that we know + // exactly how long the Header is. + b = b[:2] + _, err := io.ReadFull(r, b) + if err != nil { + return Header{}, err + } + + var h Header + h.Fin = b[0]&(1<<7) != 0 + h.RSV1 = b[0]&(1<<6) != 0 + h.RSV2 = b[0]&(1<<5) != 0 + h.RSV3 = b[0]&(1<<4) != 0 + + h.Opcode = Opcode(b[0] & 0xf) + + var extra int + + h.Masked = b[1]&(1<<7) != 0 + if h.Masked { + extra += 4 + } + + payloadLength := b[1] &^ (1 << 7) + switch { + case payloadLength < 126: + h.PayloadLength = int64(payloadLength) + case payloadLength == 126: + extra += 2 + case payloadLength == 127: + extra += 8 + } + + if extra == 0 { + return h, nil + } + + b = b[:extra] + _, err = io.ReadFull(r, b) + if err != nil { + return Header{}, err + } + + switch { + case payloadLength == 126: + h.PayloadLength = int64(binary.BigEndian.Uint16(b)) + b = b[2:] + case payloadLength == 127: + h.PayloadLength = int64(binary.BigEndian.Uint64(b)) + if h.PayloadLength < 0 { + return Header{}, fmt.Errorf("Header with negative payload length: %v", h.PayloadLength) + } + b = b[8:] + } + + if h.Masked { + h.MaskKey = binary.LittleEndian.Uint32(b) + } + + return h, nil +} + +const MaxControlFramePayload = 125 + +func ParseClosePayload(p []byte) (uint16, string, error) { + if len(p) < 2 { + return 0, "", fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) + } + + return binary.BigEndian.Uint16(p), string(p[2:]), nil +} diff --git a/frame_stringer.go b/internal/wsframe/frame_stringer.go similarity index 90% rename from frame_stringer.go rename to internal/wsframe/frame_stringer.go index 72b865f..b2e7f42 100644 --- a/frame_stringer.go +++ b/internal/wsframe/frame_stringer.go @@ -1,6 +1,6 @@ -// Code generated by "stringer -type=opcode,MessageType,StatusCode -output=frame_stringer.go"; DO NOT EDIT. +// Code generated by "stringer -type=Opcode,MessageType,StatusCode -output=frame_stringer.go"; DO NOT EDIT. -package websocket +package wsframe import "strconv" @@ -8,12 +8,12 @@ func _() { // An "invalid array index" compiler error signifies that the constant values have changed. // Re-run the stringer command to generate them again. var x [1]struct{} - _ = x[opContinuation-0] - _ = x[opText-1] - _ = x[opBinary-2] - _ = x[opClose-8] - _ = x[opPing-9] - _ = x[opPong-10] + _ = x[OpContinuation-0] + _ = x[OpText-1] + _ = x[OpBinary-2] + _ = x[OpClose-8] + _ = x[OpPing-9] + _ = x[OpPong-10] } const ( @@ -26,7 +26,7 @@ var ( _opcode_index_1 = [...]uint8{0, 7, 13, 19} ) -func (i opcode) String() string { +func (i Opcode) String() string { switch { case 0 <= i && i <= 2: return _opcode_name_0[_opcode_index_0[i]:_opcode_index_0[i+1]] @@ -34,7 +34,7 @@ func (i opcode) String() string { i -= 8 return _opcode_name_1[_opcode_index_1[i]:_opcode_index_1[i+1]] default: - return "opcode(" + strconv.FormatInt(int64(i), 10) + ")" + return "Opcode(" + strconv.FormatInt(int64(i), 10) + ")" } } func _() { diff --git a/internal/wsframe/frame_test.go b/internal/wsframe/frame_test.go new file mode 100644 index 0000000..d6b66e7 --- /dev/null +++ b/internal/wsframe/frame_test.go @@ -0,0 +1,157 @@ +// +build !js + +package wsframe + +import ( + "bytes" + "io" + "math/rand" + "strconv" + "testing" + "time" + _ "unsafe" + + "github.com/google/go-cmp/cmp" + _ "github.com/gorilla/websocket" +) + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func randBool() bool { + return rand.Intn(1) == 0 +} + +func TestHeader(t *testing.T) { + t.Parallel() + + t.Run("eof", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + bytes []byte + }{ + { + "start", + []byte{0xff}, + }, + { + "middle", + []byte{0xff, 0xff, 0xff}, + }, + } + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + b := bytes.NewBuffer(tc.bytes) + _, err := ReadHeader(nil, b) + if io.ErrUnexpectedEOF != err { + t.Fatalf("expected %v but got: %v", io.ErrUnexpectedEOF, err) + } + }) + } + }) + + t.Run("writeNegativeLength", func(t *testing.T) { + t.Parallel() + + defer func() { + r := recover() + if r == nil { + t.Fatal("failed to induce panic in writeHeader with negative payload length") + } + }() + + Header{ + PayloadLength: -1, + }.Bytes(nil) + }) + + t.Run("readNegativeLength", func(t *testing.T) { + t.Parallel() + + b := Header{ + PayloadLength: 1<<16 + 1, + }.Bytes(nil) + + // Make length negative + b[2] |= 1 << 7 + + r := bytes.NewReader(b) + _, err := ReadHeader(nil, r) + if err == nil { + t.Fatalf("unexpected error value: %+v", err) + } + }) + + t.Run("lengths", func(t *testing.T) { + t.Parallel() + + lengths := []int{ + 124, + 125, + 126, + 4096, + 16384, + 65535, + 65536, + 65537, + 131072, + } + + for _, n := range lengths { + n := n + t.Run(strconv.Itoa(n), func(t *testing.T) { + t.Parallel() + + testHeader(t, Header{ + PayloadLength: int64(n), + }) + }) + } + }) + + t.Run("fuzz", func(t *testing.T) { + t.Parallel() + + for i := 0; i < 10000; i++ { + h := Header{ + Fin: randBool(), + RSV1: randBool(), + RSV2: randBool(), + RSV3: randBool(), + Opcode: Opcode(rand.Intn(1 << 4)), + + Masked: randBool(), + PayloadLength: rand.Int63(), + } + + if h.Masked { + h.MaskKey = rand.Uint32() + } + + testHeader(t, h) + } + }) +} + +func testHeader(t *testing.T, h Header) { + b := h.Bytes(nil) + r := bytes.NewReader(b) + h2, err := ReadHeader(r, nil) + if err != nil { + t.Logf("Header: %#v", h) + t.Logf("bytes: %b", b) + t.Fatalf("failed to read Header: %v", err) + } + + if !cmp.Equal(h, h2, cmp.AllowUnexported(Header{})) { + t.Logf("Header: %#v", h) + t.Logf("bytes: %b", b) + t.Fatalf("parsed and read Header differ: %v", cmp.Diff(h, h2, cmp.AllowUnexported(Header{}))) + } +} diff --git a/internal/wsframe/mask.go b/internal/wsframe/mask.go new file mode 100644 index 0000000..2da4c11 --- /dev/null +++ b/internal/wsframe/mask.go @@ -0,0 +1,128 @@ +package wsframe + +import ( + "encoding/binary" + "math/bits" +) + +// Mask applies the WebSocket masking algorithm to p +// with the given key. +// See https://tools.ietf.org/html/rfc6455#section-5.3 +// +// The returned value is the correctly rotated key to +// to continue to mask/unmask the message. +// +// It is optimized for LittleEndian and expects the key +// to be in little endian. +// +// See https://github.com/golang/go/issues/31586 +func Mask(key uint32, b []byte) uint32 { + if len(b) >= 8 { + key64 := uint64(key)<<32 | uint64(key) + + // At some point in the future we can clean these unrolled loops up. + // See https://github.com/golang/go/issues/31586#issuecomment-487436401 + + // Then we xor until b is less than 128 bytes. + for len(b) >= 128 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + v = binary.LittleEndian.Uint64(b[16:24]) + binary.LittleEndian.PutUint64(b[16:24], v^key64) + v = binary.LittleEndian.Uint64(b[24:32]) + binary.LittleEndian.PutUint64(b[24:32], v^key64) + v = binary.LittleEndian.Uint64(b[32:40]) + binary.LittleEndian.PutUint64(b[32:40], v^key64) + v = binary.LittleEndian.Uint64(b[40:48]) + binary.LittleEndian.PutUint64(b[40:48], v^key64) + v = binary.LittleEndian.Uint64(b[48:56]) + binary.LittleEndian.PutUint64(b[48:56], v^key64) + v = binary.LittleEndian.Uint64(b[56:64]) + binary.LittleEndian.PutUint64(b[56:64], v^key64) + v = binary.LittleEndian.Uint64(b[64:72]) + binary.LittleEndian.PutUint64(b[64:72], v^key64) + v = binary.LittleEndian.Uint64(b[72:80]) + binary.LittleEndian.PutUint64(b[72:80], v^key64) + v = binary.LittleEndian.Uint64(b[80:88]) + binary.LittleEndian.PutUint64(b[80:88], v^key64) + v = binary.LittleEndian.Uint64(b[88:96]) + binary.LittleEndian.PutUint64(b[88:96], v^key64) + v = binary.LittleEndian.Uint64(b[96:104]) + binary.LittleEndian.PutUint64(b[96:104], v^key64) + v = binary.LittleEndian.Uint64(b[104:112]) + binary.LittleEndian.PutUint64(b[104:112], v^key64) + v = binary.LittleEndian.Uint64(b[112:120]) + binary.LittleEndian.PutUint64(b[112:120], v^key64) + v = binary.LittleEndian.Uint64(b[120:128]) + binary.LittleEndian.PutUint64(b[120:128], v^key64) + b = b[128:] + } + + // Then we xor until b is less than 64 bytes. + for len(b) >= 64 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + v = binary.LittleEndian.Uint64(b[16:24]) + binary.LittleEndian.PutUint64(b[16:24], v^key64) + v = binary.LittleEndian.Uint64(b[24:32]) + binary.LittleEndian.PutUint64(b[24:32], v^key64) + v = binary.LittleEndian.Uint64(b[32:40]) + binary.LittleEndian.PutUint64(b[32:40], v^key64) + v = binary.LittleEndian.Uint64(b[40:48]) + binary.LittleEndian.PutUint64(b[40:48], v^key64) + v = binary.LittleEndian.Uint64(b[48:56]) + binary.LittleEndian.PutUint64(b[48:56], v^key64) + v = binary.LittleEndian.Uint64(b[56:64]) + binary.LittleEndian.PutUint64(b[56:64], v^key64) + b = b[64:] + } + + // Then we xor until b is less than 32 bytes. + for len(b) >= 32 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + v = binary.LittleEndian.Uint64(b[16:24]) + binary.LittleEndian.PutUint64(b[16:24], v^key64) + v = binary.LittleEndian.Uint64(b[24:32]) + binary.LittleEndian.PutUint64(b[24:32], v^key64) + b = b[32:] + } + + // Then we xor until b is less than 16 bytes. + for len(b) >= 16 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + b = b[16:] + } + + // Then we xor until b is less than 8 bytes. + for len(b) >= 8 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + b = b[8:] + } + } + + // Then we xor until b is less than 4 bytes. + for len(b) >= 4 { + v := binary.LittleEndian.Uint32(b) + binary.LittleEndian.PutUint32(b, v^key) + b = b[4:] + } + + // xor remaining bytes. + for i := range b { + b[i] ^= byte(key) + key = bits.RotateLeft32(key, -8) + } + + return key +} diff --git a/internal/wsframe/mask_test.go b/internal/wsframe/mask_test.go new file mode 100644 index 0000000..fbd2989 --- /dev/null +++ b/internal/wsframe/mask_test.go @@ -0,0 +1,118 @@ +package wsframe_test + +import ( + "crypto/rand" + "encoding/binary" + "github.com/gobwas/ws" + "github.com/google/go-cmp/cmp" + "math/bits" + "nhooyr.io/websocket/internal/wsframe" + "strconv" + "testing" + _ "unsafe" +) + +func Test_mask(t *testing.T) { + t.Parallel() + + key := []byte{0xa, 0xb, 0xc, 0xff} + key32 := binary.LittleEndian.Uint32(key) + p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc} + gotKey32 := wsframe.Mask(key32, p) + + if exp := []byte{0, 0, 0, 0x0d, 0x6}; !cmp.Equal(exp, p) { + t.Fatalf("unexpected mask: %v", cmp.Diff(exp, p)) + } + + if exp := bits.RotateLeft32(key32, -8); !cmp.Equal(exp, gotKey32) { + t.Fatalf("unexpected mask key: %v", cmp.Diff(exp, gotKey32)) + } +} + +func basicMask(maskKey [4]byte, pos int, b []byte) int { + for i := range b { + b[i] ^= maskKey[pos&3] + pos++ + } + return pos & 3 +} + +//go:linkname gorillaMaskBytes github.com/gorilla/websocket.maskBytes +func gorillaMaskBytes(key [4]byte, pos int, b []byte) int + +func Benchmark_mask(b *testing.B) { + sizes := []int{ + 2, + 3, + 4, + 8, + 16, + 32, + 128, + 512, + 4096, + 16384, + } + + fns := []struct { + name string + fn func(b *testing.B, key [4]byte, p []byte) + }{ + { + name: "basic", + fn: func(b *testing.B, key [4]byte, p []byte) { + for i := 0; i < b.N; i++ { + basicMask(key, 0, p) + } + }, + }, + + { + name: "nhooyr", + fn: func(b *testing.B, key [4]byte, p []byte) { + key32 := binary.LittleEndian.Uint32(key[:]) + b.ResetTimer() + + for i := 0; i < b.N; i++ { + wsframe.Mask(key32, p) + } + }, + }, + { + name: "gorilla", + fn: func(b *testing.B, key [4]byte, p []byte) { + for i := 0; i < b.N; i++ { + gorillaMaskBytes(key, 0, p) + } + }, + }, + { + name: "gobwas", + fn: func(b *testing.B, key [4]byte, p []byte) { + for i := 0; i < b.N; i++ { + ws.Cipher(p, key, 0) + } + }, + }, + } + + var key [4]byte + _, err := rand.Read(key[:]) + if err != nil { + b.Fatalf("failed to populate mask key: %v", err) + } + + for _, size := range sizes { + p := make([]byte, size) + + b.Run(strconv.Itoa(size), func(b *testing.B) { + for _, fn := range fns { + b.Run(fn.name, func(b *testing.B) { + b.SetBytes(int64(size)) + + fn.fn(b, key, p) + }) + } + }) + } +} diff --git a/js_test.go b/js_test.go new file mode 100644 index 0000000..80af789 --- /dev/null +++ b/js_test.go @@ -0,0 +1,50 @@ +package websocket_test + +import ( + "context" + "fmt" + "net/http" + "nhooyr.io/websocket/internal/wsecho" + "os" + "os/exec" + "strings" + "testing" + "time" + + "nhooyr.io/websocket" +) + +func TestJS(t *testing.T) { + t.Parallel() + + s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) error { + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"echo"}, + InsecureSkipVerify: true, + }) + if err != nil { + return err + } + defer c.Close(websocket.StatusInternalError, "") + + err = wsecho.Loop(r.Context(), c) + if websocket.CloseStatus(err) != websocket.StatusNormalClosure { + return err + } + return nil + }, false) + defer closeFn() + + wsURL := strings.Replace(s.URL, "http", "ws", 1) + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", "./...") + cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", wsURL)) + + b, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("wasm test binary failed: %v:\n%s", err, b) + } +} diff --git a/conn_common.go b/netconn.go similarity index 60% rename from conn_common.go rename to netconn.go index 1247df6..74a2c7c 100644 --- a/conn_common.go +++ b/netconn.go @@ -1,6 +1,3 @@ -// This file contains *Conn symbols relevant to both -// Wasm and non Wasm builds. - package websocket import ( @@ -10,7 +7,6 @@ import ( "math" "net" "sync" - "sync/atomic" "time" ) @@ -169,77 +165,3 @@ func (c *netConn) SetReadDeadline(t time.Time) error { return nil } -// CloseRead will start a goroutine to read from the connection until it is closed or a data message -// is received. If a data message is received, the connection will be closed with StatusPolicyViolation. -// Since CloseRead reads from the connection, it will respond to ping, pong and close frames. -// After calling this method, you cannot read any data messages from the connection. -// The returned context will be cancelled when the connection is closed. -// -// Use this when you do not want to read data messages from the connection anymore but will -// want to write messages to it. -func (c *Conn) CloseRead(ctx context.Context) context.Context { - c.isReadClosed.Store(1) - - ctx, cancel := context.WithCancel(ctx) - go func() { - defer cancel() - // We use the unexported reader method so that we don't get the read closed error. - c.reader(ctx, true) - // Either the connection is already closed since there was a read error - // or the context was cancelled or a message was read and we should close - // the connection. - c.Close(StatusPolicyViolation, "unexpected data message") - }() - return ctx -} - -// SetReadLimit sets the max number of bytes to read for a single message. -// It applies to the Reader and Read methods. -// -// By default, the connection has a message read limit of 32768 bytes. -// -// When the limit is hit, the connection will be closed with StatusMessageTooBig. -func (c *Conn) SetReadLimit(n int64) { - c.msgReadLimit.Store(n) -} - -func (c *Conn) setCloseErr(err error) { - c.closeErrOnce.Do(func() { - c.closeErr = fmt.Errorf("websocket closed: %w", err) - }) -} - -// See https://github.com/nhooyr/websocket/issues/153 -type atomicInt64 struct { - v int64 -} - -func (v *atomicInt64) Load() int64 { - return atomic.LoadInt64(&v.v) -} - -func (v *atomicInt64) Store(i int64) { - atomic.StoreInt64(&v.v, i) -} - -func (v *atomicInt64) String() string { - return fmt.Sprint(v.Load()) -} - -// Increment increments the value and returns the new value. -func (v *atomicInt64) Increment(delta int64) int64 { - return atomic.AddInt64(&v.v, delta) -} - -func (v *atomicInt64) CAS(old, new int64) (swapped bool) { - return atomic.CompareAndSwapInt64(&v.v, old, new) -} - -func (c *Conn) isClosed() bool { - select { - case <-c.closed: - return true - default: - return false - } -} diff --git a/reader.go b/reader.go new file mode 100644 index 0000000..fe71656 --- /dev/null +++ b/reader.go @@ -0,0 +1,31 @@ +package websocket + +import ( + "bufio" + "context" + "io" + "nhooyr.io/websocket/internal/atomicint" + "nhooyr.io/websocket/internal/wsframe" + "strings" +) + +type reader struct { + // Acquired before performing any sort of read operation. + readLock chan struct{} + + c *Conn + + deflateReader io.Reader + br *bufio.Reader + + readClosed *atomicint.Int64 + readHeaderBuf []byte + controlPayloadBuf []byte + + msgCtx context.Context + msgCompressed bool + frameHeader wsframe.Header + frameMaskKey uint32 + frameEOF bool + deflateTail strings.Reader +} diff --git a/websocket_js_test.go b/websocket_js_test.go deleted file mode 100644 index 9b7bb81..0000000 --- a/websocket_js_test.go +++ /dev/null @@ -1,52 +0,0 @@ -package websocket_test - -import ( - "context" - "net/http" - "os" - "testing" - "time" - - "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/assert" -) - -func TestConn(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - - c, resp, err := websocket.Dial(ctx, os.Getenv("WS_ECHO_SERVER_URL"), &websocket.DialOptions{ - Subprotocols: []string{"echo"}, - }) - if err != nil { - t.Fatal(err) - } - defer c.Close(websocket.StatusInternalError, "") - - err = assertSubprotocol(c, "echo") - if err != nil { - t.Fatal(err) - } - - err = assert.Equalf(&http.Response{}, resp, "unexpected http response") - if err != nil { - t.Fatal(err) - } - - err = assertJSONEcho(ctx, c, 1024) - if err != nil { - t.Fatal(err) - } - - err = assertEcho(ctx, c, websocket.MessageBinary, 1024) - if err != nil { - t.Fatal(err) - } - - err = c.Close(websocket.StatusNormalClosure, "") - if err != nil { - t.Fatal(err) - } -} diff --git a/writer.go b/writer.go new file mode 100644 index 0000000..b31d57a --- /dev/null +++ b/writer.go @@ -0,0 +1,5 @@ +package websocket + +type writer struct { + +} diff --git a/websocket_js.go b/ws_js.go similarity index 88% rename from websocket_js.go rename to ws_js.go index d27809c..4c06743 100644 --- a/websocket_js.go +++ b/ws_js.go @@ -1,3 +1,5 @@ +// +build js + package websocket // import "nhooyr.io/websocket" import ( @@ -7,12 +9,13 @@ import ( "fmt" "io" "net/http" + "nhooyr.io/websocket/internal/atomicint" "reflect" "runtime" "sync" "syscall/js" - "nhooyr.io/websocket/internal/bpool" + "nhooyr.io/websocket/internal/bufpool" "nhooyr.io/websocket/internal/wsjs" ) @@ -21,10 +24,10 @@ type Conn struct { ws wsjs.WebSocket // read limit for a message in bytes. - msgReadLimit *atomicInt64 + msgReadLimit *atomicint.Int64 closingMu sync.Mutex - isReadClosed *atomicInt64 + isReadClosed *atomicint.Int64 closeOnce sync.Once closed chan struct{} closeErrOnce sync.Once @@ -56,17 +59,20 @@ func (c *Conn) init() { c.closed = make(chan struct{}) c.readSignal = make(chan struct{}, 1) - c.msgReadLimit = &atomicInt64{} + c.msgReadLimit = &atomicint.Int64{} c.msgReadLimit.Store(32768) - c.isReadClosed = &atomicInt64{} + c.isReadClosed = &atomicint.Int64{} c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) { err := CloseError{ Code: StatusCode(e.Code), Reason: e.Reason, } - c.close(fmt.Errorf("received close: %w", err), e.WasClean) + // We do not know if we sent or received this close as + // its possible the browser triggered it without us + // explicitly sending it. + c.close(err, e.WasClean) c.releaseOnClose() c.releaseOnMessage() @@ -288,11 +294,6 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { return typ, bytes.NewReader(p), nil } -// Only implemented for use by *Conn.CloseRead in conn_common.go -func (c *Conn) reader(ctx context.Context, _ bool) { - c.read(ctx) -} - // Writer returns a writer to write a WebSocket data message to the connection. // It buffers the entire message in memory and then sends it when the writer // is closed. @@ -301,7 +302,7 @@ func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, err c: c, ctx: ctx, typ: typ, - b: bpool.Get(), + b: bufpool.Get(), }, nil } @@ -331,7 +332,7 @@ func (w writer) Close() error { return errors.New("cannot close closed writer") } w.closed = true - defer bpool.Put(w.b) + defer bufpool.Put(w.b) err := w.c.Write(w.ctx, w.typ, w.b.Bytes()) if err != nil { @@ -339,3 +340,34 @@ func (w writer) Close() error { } return nil } + +func (c *Conn) CloseRead(ctx context.Context) context.Context { + c.isReadClosed.Store(1) + + ctx, cancel := context.WithCancel(ctx) + go func() { + defer cancel() + c.read(ctx) + c.Close(StatusPolicyViolation, "unexpected data message") + }() + return ctx +} + +func (c *Conn) SetReadLimit(n int64) { + c.msgReadLimit.Store(n) +} + +func (c *Conn) setCloseErr(err error) { + c.closeErrOnce.Do(func() { + c.closeErr = fmt.Errorf("websocket closed: %w", err) + }) +} + +func (c *Conn) isClosed() bool { + select { + case <-c.closed: + return true + default: + return false + } +} diff --git a/ws_js_test.go b/ws_js_test.go new file mode 100644 index 0000000..abd950c --- /dev/null +++ b/ws_js_test.go @@ -0,0 +1,22 @@ +package websocket + +func TestEcho(t *testing.T) { + t.Parallel() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + c, resp, err := websocket.Dial(ctx, os.Getenv("WS_ECHO_SERVER_URL"), &websocket.DialOptions{ + Subprotocols: []string{"echo"}, + }) + assert.Success(t, err) + defer c.Close(websocket.StatusInternalError, "") + + assertSubprotocol(t, c, "echo") + assert.Equalf(t, &http.Response{}, resp, "unexpected http response") + assertJSONEcho(t, ctx, c, 1024) + assertEcho(t, ctx, c, websocket.MessageBinary, 1024) + + err = c.Close(websocket.StatusNormalClosure, "") + assert.Success(t, err) +} diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go index fe935fa..9fa8b54 100644 --- a/wsjson/wsjson.go +++ b/wsjson/wsjson.go @@ -5,9 +5,8 @@ import ( "context" "encoding/json" "fmt" - "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/bpool" + "nhooyr.io/websocket/internal/bufpool" ) // Read reads a json message from c into v. @@ -31,8 +30,8 @@ func read(ctx context.Context, c *websocket.Conn, v interface{}) error { return fmt.Errorf("unexpected frame type for json (expected %v): %v", websocket.MessageText, typ) } - b := bpool.Get() - defer bpool.Put(b) + b := bufpool.Get() + defer bufpool.Put(b) _, err = b.ReadFrom(r) if err != nil { diff --git a/wspb/wspb.go b/wspb/wspb.go index 3c9e0f7..52ddcd5 100644 --- a/wspb/wspb.go +++ b/wspb/wspb.go @@ -9,7 +9,7 @@ import ( "github.com/golang/protobuf/proto" "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/bpool" + "nhooyr.io/websocket/internal/bufpool" ) // Read reads a protobuf message from c into v. @@ -33,8 +33,8 @@ func read(ctx context.Context, c *websocket.Conn, v proto.Message) error { return fmt.Errorf("unexpected frame type for protobuf (expected %v): %v", websocket.MessageBinary, typ) } - b := bpool.Get() - defer bpool.Put(b) + b := bufpool.Get() + defer bufpool.Put(b) _, err = b.ReadFrom(r) if err != nil { @@ -61,10 +61,10 @@ func Write(ctx context.Context, c *websocket.Conn, v proto.Message) error { } func write(ctx context.Context, c *websocket.Conn, v proto.Message) error { - b := bpool.Get() + b := bufpool.Get() pb := proto.NewBuffer(b.Bytes()) defer func() { - bpool.Put(bytes.NewBuffer(pb.Bytes())) + bufpool.Put(bytes.NewBuffer(pb.Bytes())) }() err := pb.Marshal(v) -- GitLab