good morning!!!!

Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • github/nhooyr/websocket
  • open/websocket
2 results
Show changes
Commits on Source (170)
github: nhooyr
version: 2
updates:
# Track in case we ever add dependencies.
- package-ecosystem: 'gomod'
directory: '/'
schedule:
interval: 'weekly'
commit-message:
prefix: 'chore'
# Keep example and test/benchmark deps up-to-date.
- package-ecosystem: 'gomod'
directories:
- '/internal/examples'
- '/internal/thirdparty'
schedule:
interval: 'monthly'
commit-message:
prefix: 'chore'
labels: []
groups:
internal-deps:
patterns:
- '*'
name: ci
on: [push, pull_request]
on:
push:
branches:
- master
pull_request:
branches:
- master
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}
cancel-in-progress: true
......@@ -9,31 +15,45 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v4
- uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- run: ./ci/fmt.sh
- run: make fmt
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- run: go version
- uses: actions/setup-go@v4
- uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- run: ./ci/lint.sh
- run: make lint
test:
runs-on: ubuntu-latest
steps:
- name: Disable AppArmor
if: runner.os == 'Linux'
run: |
# Disable AppArmor for Ubuntu 23.10+.
# https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md
echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns
- uses: actions/checkout@v4
- uses: actions/setup-go@v4
- uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- run: ./ci/test.sh
- uses: actions/upload-artifact@v2
if: always()
- run: make test
- uses: actions/upload-artifact@v4
with:
name: coverage.html
path: ./ci/out/coverage.html
bench:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- run: make bench
name: daily
on:
workflow_dispatch:
schedule:
- cron: '42 0 * * *' # daily at 00:42
concurrency:
group: ${{ github.workflow }}
cancel-in-progress: true
jobs:
bench:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- run: AUTOBAHN=1 make bench
test:
runs-on: ubuntu-latest
steps:
- name: Disable AppArmor
if: runner.os == 'Linux'
run: |
# Disable AppArmor for Ubuntu 23.10+.
# https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md
echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- run: AUTOBAHN=1 make test
- uses: actions/upload-artifact@v4
with:
name: coverage.html
path: ./ci/out/coverage.html
bench-dev:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
ref: dev
- uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- run: AUTOBAHN=1 make bench
test-dev:
runs-on: ubuntu-latest
steps:
- name: Disable AppArmor
if: runner.os == 'Linux'
run: |
# Disable AppArmor for Ubuntu 23.10+.
# https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md
echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns
- uses: actions/checkout@v4
with:
ref: dev
- uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- run: AUTOBAHN=1 make test
- uses: actions/upload-artifact@v4
with:
name: coverage-dev.html
path: ./ci/out/coverage.html
name: static
on:
push:
branches: ['master']
workflow_dispatch:
# Set permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages.
permissions:
contents: read
pages: write
id-token: write
concurrency:
group: pages
cancel-in-progress: true
jobs:
deploy:
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
runs-on: ubuntu-latest
steps:
- name: Disable AppArmor
if: runner.os == 'Linux'
run: |
# Disable AppArmor for Ubuntu 23.10+.
# https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md
echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns
- name: Checkout
uses: actions/checkout@v4
- name: Setup Pages
uses: actions/configure-pages@v5
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- name: Generate coverage and badge
run: |
make test
mkdir -p ./ci/out/static
cp ./ci/out/coverage.html ./ci/out/static/coverage.html
percent=$(go tool cover -func ./ci/out/coverage.prof | tail -n1 | awk '{print $3}' | tr -d '%')
wget -O ./ci/out/static/coverage.svg "https://img.shields.io/badge/coverage-${percent}%25-success"
- name: Upload artifact
uses: actions/upload-pages-artifact@v3
with:
path: ./ci/out/static/
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4
websocket.test
.PHONY: all
all: fmt lint test
.PHONY: fmt
fmt:
./ci/fmt.sh
.PHONY: lint
lint:
./ci/lint.sh
.PHONY: test
test:
./ci/test.sh
.PHONY: bench
bench:
./ci/bench.sh
\ No newline at end of file
# websocket
[![godoc](https://godoc.org/nhooyr.io/websocket?status.svg)](https://pkg.go.dev/nhooyr.io/websocket)
[![coverage](https://img.shields.io/badge/coverage-86%25-success)](https://nhooyrio-websocket-coverage.netlify.app)
[![Go Reference](https://pkg.go.dev/badge/github.com/coder/websocket.svg)](https://pkg.go.dev/github.com/coder/websocket)
[![Go Coverage](https://coder.github.io/websocket/coverage.svg)](https://coder.github.io/websocket/coverage.html)
websocket is a minimal and idiomatic WebSocket library for Go.
> **note**: I haven't been responsive for questions/reports on the issue tracker but I do
> read through and there are no outstanding bugs. There are certainly some nice to haves
> that I should merge in/figure out but nothing critical. I haven't given up on adding new
> features and cleaning up the code further, just been busy. Should anything critical
> arise, I will fix it.
## Install
```bash
go get nhooyr.io/websocket
```sh
go get github.com/coder/websocket
```
> [!NOTE]
> Coder now maintains this project as explained in [this blog post](https://coder.com/blog/websocket).
> We're grateful to [nhooyr](https://github.com/nhooyr) for authoring and maintaining this project from
> 2019 to 2024.
## Highlights
- Minimal and idiomatic API
- First class [context.Context](https://blog.golang.org/context) support
- Fully passes the WebSocket [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite)
- [Zero dependencies](https://pkg.go.dev/nhooyr.io/websocket?tab=imports)
- JSON and protobuf helpers in the [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) and [wspb](https://pkg.go.dev/nhooyr.io/websocket/wspb) subpackages
- [Zero dependencies](https://pkg.go.dev/github.com/coder/websocket?tab=imports)
- JSON helpers in the [wsjson](https://pkg.go.dev/github.com/coder/websocket/wsjson) subpackage
- Zero alloc reads and writes
- Concurrent writes
- [Close handshake](https://pkg.go.dev/nhooyr.io/websocket#Conn.Close)
- [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) wrapper
- [Ping pong](https://pkg.go.dev/nhooyr.io/websocket#Conn.Ping) API
- [Close handshake](https://pkg.go.dev/github.com/coder/websocket#Conn.Close)
- [net.Conn](https://pkg.go.dev/github.com/coder/websocket#NetConn) wrapper
- [Ping pong](https://pkg.go.dev/github.com/coder/websocket#Conn.Ping) API
- [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression
- Compile to [Wasm](https://pkg.go.dev/nhooyr.io/websocket#hdr-Wasm)
- [CloseRead](https://pkg.go.dev/github.com/coder/websocket#Conn.CloseRead) helper for write only connections
- Compile to [Wasm](https://pkg.go.dev/github.com/coder/websocket#hdr-Wasm)
## Roadmap
See GitHub issues for minor issues but the major future enhancements are:
- [ ] Perfect examples [#217](https://github.com/nhooyr/websocket/issues/217)
- [ ] wstest.Pipe for in memory testing [#340](https://github.com/nhooyr/websocket/issues/340)
- [ ] Ping pong heartbeat helper [#267](https://github.com/nhooyr/websocket/issues/267)
- [ ] Ping pong instrumentation callbacks [#246](https://github.com/nhooyr/websocket/issues/246)
- [ ] Graceful shutdown helpers [#209](https://github.com/nhooyr/websocket/issues/209)
- [ ] Assembly for WebSocket masking [#16](https://github.com/nhooyr/websocket/issues/16)
- WIP at [#326](https://github.com/nhooyr/websocket/pull/326), about 3x faster
- [ ] HTTP/2 [#4](https://github.com/nhooyr/websocket/issues/4)
- [ ] The holy grail [#402](https://github.com/nhooyr/websocket/issues/402)
## Examples
For a production quality example that demonstrates the complete API, see the
[echo example](./examples/echo).
[echo example](./internal/examples/echo).
For a full stack example, see the [chat example](./examples/chat).
For a full stack example, see the [chat example](./internal/examples/chat).
### Server
......@@ -51,9 +61,11 @@ http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) {
if err != nil {
// ...
}
defer c.Close(websocket.StatusInternalError, "the sky is falling")
defer c.CloseNow()
ctx, cancel := context.WithTimeout(r.Context(), time.Second*10)
// Set the context as needed. Use of r.Context() is not recommended
// to avoid surprising behavior (see http.Hijacker).
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
var v interface{}
......@@ -78,7 +90,7 @@ c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil)
if err != nil {
// ...
}
defer c.Close(websocket.StatusInternalError, "the sky is falling")
defer c.CloseNow()
err = wsjson.Write(ctx, c, "hi")
if err != nil {
......@@ -97,12 +109,14 @@ Advantages of [gorilla/websocket](https://github.com/gorilla/websocket):
- Mature and widely used
- [Prepared writes](https://pkg.go.dev/github.com/gorilla/websocket#PreparedMessage)
- Configurable [buffer sizes](https://pkg.go.dev/github.com/gorilla/websocket#hdr-Buffers)
- No extra goroutine per connection to support cancellation with context.Context. This costs github.com/coder/websocket 2 KB of memory per connection.
- Will be removed soon with [context.AfterFunc](https://github.com/golang/go/issues/57928). See [#411](https://github.com/nhooyr/websocket/issues/411)
Advantages of nhooyr.io/websocket:
Advantages of github.com/coder/websocket:
- Minimal and idiomatic API
- Compare godoc of [nhooyr.io/websocket](https://pkg.go.dev/nhooyr.io/websocket) with [gorilla/websocket](https://pkg.go.dev/github.com/gorilla/websocket) side by side.
- [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) wrapper
- Compare godoc of [github.com/coder/websocket](https://pkg.go.dev/github.com/coder/websocket) with [gorilla/websocket](https://pkg.go.dev/github.com/gorilla/websocket) side by side.
- [net.Conn](https://pkg.go.dev/github.com/coder/websocket#NetConn) wrapper
- Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535))
- Full [context.Context](https://blog.golang.org/context) support
- Dial uses [net/http.Client](https://golang.org/pkg/net/http/#Client)
......@@ -110,28 +124,39 @@ Advantages of nhooyr.io/websocket:
- Gorilla writes directly to a net.Conn and so duplicates features of net/http.Client.
- Concurrent writes
- Close handshake ([gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448))
- Idiomatic [ping pong](https://pkg.go.dev/nhooyr.io/websocket#Conn.Ping) API
- Idiomatic [ping pong](https://pkg.go.dev/github.com/coder/websocket#Conn.Ping) API
- Gorilla requires registering a pong callback before sending a Ping
- Can target Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432))
- Transparent message buffer reuse with [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) and [wspb](https://pkg.go.dev/nhooyr.io/websocket/wspb) subpackages
- Transparent message buffer reuse with [wsjson](https://pkg.go.dev/github.com/coder/websocket/wsjson) subpackage
- [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go
- Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/).
Soon we'll have assembly and be 3x faster [#326](https://github.com/nhooyr/websocket/pull/326)
- Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support
- Gorilla only supports no context takeover mode
- [CloseRead](https://pkg.go.dev/nhooyr.io/websocket#Conn.CloseRead) helper ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492))
- Actively maintained ([gorilla/websocket#370](https://github.com/gorilla/websocket/issues/370))
- [CloseRead](https://pkg.go.dev/github.com/coder/websocket#Conn.CloseRead) helper for write only connections ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492))
#### golang.org/x/net/websocket
[golang.org/x/net/websocket](https://pkg.go.dev/golang.org/x/net/websocket) is deprecated.
See [golang/go/issues/18152](https://github.com/golang/go/issues/18152).
The [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) can help in transitioning
to nhooyr.io/websocket.
The [net.Conn](https://pkg.go.dev/github.com/coder/websocket#NetConn) can help in transitioning
to github.com/coder/websocket.
#### gobwas/ws
[gobwas/ws](https://github.com/gobwas/ws) has an extremely flexible API that allows it to be used
in an event driven style for performance. See the author's [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb).
However when writing idiomatic Go, nhooyr.io/websocket will be faster and easier to use.
However it is quite bloated. See https://pkg.go.dev/github.com/gobwas/ws
When writing idiomatic Go, github.com/coder/websocket will be faster and easier to use.
#### lesismal/nbio
[lesismal/nbio](https://github.com/lesismal/nbio) is similar to gobwas/ws in that the API is
event driven for performance reasons.
However it is quite bloated. See https://pkg.go.dev/github.com/lesismal/nbio
When writing idiomatic Go, github.com/coder/websocket will be faster and easier to use.
......@@ -5,6 +5,7 @@ package websocket
import (
"bytes"
"context"
"crypto/sha1"
"encoding/base64"
"errors"
......@@ -14,10 +15,10 @@ import (
"net/http"
"net/textproto"
"net/url"
"path/filepath"
"path"
"strings"
"nhooyr.io/websocket/internal/errd"
"github.com/coder/websocket/internal/errd"
)
// AcceptOptions represents Accept's options.
......@@ -41,8 +42,8 @@ type AcceptOptions struct {
// One would set this field to []string{"example.com"} to authorize example.com to connect.
//
// Each pattern is matched case insensitively against the request origin host
// with filepath.Match.
// See https://golang.org/pkg/path/filepath/#Match
// with path.Match.
// See https://golang.org/pkg/path/#Match
//
// Please ensure you understand the ramifications of enabling this.
// If used incorrectly your WebSocket server will be open to CSRF attacks.
......@@ -62,6 +63,22 @@ type AcceptOptions struct {
// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
// for CompressionContextTakeover.
CompressionThreshold int
// OnPingReceived is an optional callback invoked synchronously when a ping frame is received.
//
// The payload contains the application data of the ping frame.
// If the callback returns false, the subsequent pong frame will not be sent.
// To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
OnPingReceived func(ctx context.Context, payload []byte) bool
// OnPongReceived is an optional callback invoked synchronously when a pong frame is received.
//
// The payload contains the application data of the pong frame.
// To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
//
// Unlike OnPingReceived, this callback does not return a value because a pong frame
// is a response to a ping and does not trigger any further frame transmission.
OnPongReceived func(ctx context.Context, payload []byte)
}
func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions {
......@@ -79,6 +96,9 @@ func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions {
// See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests.
//
// Accept will write a response to w on all errors.
//
// Note that using the http.Request Context after Accept returns may lead to
// unexpected behavior (see http.Hijacker).
func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
return accept(w, r, opts)
}
......@@ -96,7 +116,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
if !opts.InsecureSkipVerify {
err = authenticateOrigin(r, opts.OriginPatterns)
if err != nil {
if errors.Is(err, filepath.ErrBadPattern) {
if errors.Is(err, path.ErrBadPattern) {
log.Printf("websocket: %v", err)
err = errors.New(http.StatusText(http.StatusForbidden))
}
......@@ -105,7 +125,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
}
}
hj, ok := w.(http.Hijacker)
hj, ok := hijacker(w)
if !ok {
err = errors.New("http.ResponseWriter does not implement http.Hijacker")
http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented)
......@@ -123,9 +143,9 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
w.Header().Set("Sec-WebSocket-Protocol", subproto)
}
copts, err := acceptCompression(r, w, opts.CompressionMode)
if err != nil {
return nil, err
copts, ok := selectDeflate(websocketExtensions(r.Header), opts.CompressionMode)
if ok {
w.Header().Set("Sec-WebSocket-Extensions", copts.String())
}
w.WriteHeader(http.StatusSwitchingProtocols)
......@@ -153,6 +173,8 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
client: false,
copts: copts,
flateThreshold: opts.CompressionThreshold,
onPingReceived: opts.OnPingReceived,
onPongReceived: opts.OnPongReceived,
br: brw.Reader,
bw: brw.Writer,
......@@ -185,10 +207,21 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _
return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version"))
}
if r.Header.Get("Sec-WebSocket-Key") == "" {
websocketSecKeys := r.Header.Values("Sec-WebSocket-Key")
if len(websocketSecKeys) == 0 {
return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key")
}
if len(websocketSecKeys) > 1 {
return http.StatusBadRequest, errors.New("WebSocket protocol violation: multiple Sec-WebSocket-Key headers")
}
// The RFC states to remove any leading or trailing whitespace.
websocketSecKey := strings.TrimSpace(websocketSecKeys[0])
if v, err := base64.StdEncoding.DecodeString(websocketSecKey); err != nil || len(v) != 16 {
return http.StatusBadRequest, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Key %q, must be a 16 byte base64 encoded string", websocketSecKey)
}
return 0, nil
}
......@@ -210,7 +243,7 @@ func authenticateOrigin(r *http.Request, originHosts []string) error {
for _, hostPattern := range originHosts {
matched, err := match(hostPattern, u.Host)
if err != nil {
return fmt.Errorf("failed to parse filepath pattern %q: %w", hostPattern, err)
return fmt.Errorf("failed to parse path pattern %q: %w", hostPattern, err)
}
if matched {
return nil
......@@ -223,7 +256,7 @@ func authenticateOrigin(r *http.Request, originHosts []string) error {
}
func match(pattern, s string) (bool, error) {
return filepath.Match(strings.ToLower(pattern), strings.ToLower(s))
return path.Match(strings.ToLower(pattern), strings.ToLower(s))
}
func selectSubprotocol(r *http.Request, subprotocols []string) string {
......@@ -238,26 +271,26 @@ func selectSubprotocol(r *http.Request, subprotocols []string) string {
return ""
}
func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionMode) (*compressionOptions, error) {
func selectDeflate(extensions []websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
if mode == CompressionDisabled {
return nil, nil
return nil, false
}
for _, ext := range websocketExtensions(r.Header) {
for _, ext := range extensions {
switch ext.name {
// We used to implement x-webkit-deflate-frame too for Safari but Safari has bugs...
// See https://github.com/nhooyr/websocket/issues/218
case "permessage-deflate":
return acceptDeflate(w, ext, mode)
// Disabled for now, see https://github.com/nhooyr/websocket/issues/218
// case "x-webkit-deflate-frame":
// return acceptWebkitDeflate(w, ext, mode)
copts, ok := acceptDeflate(ext, mode)
if ok {
return copts, true
}
}
}
return nil, nil
return nil, false
}
func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
func acceptDeflate(ext websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
copts := mode.opts()
for _, p := range ext.params {
switch p {
case "client_no_context_takeover":
......@@ -266,55 +299,18 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
case "server_no_context_takeover":
copts.serverNoContextTakeover = true
continue
}
if strings.HasPrefix(p, "client_max_window_bits") {
// We cannot adjust the read sliding window so cannot make use of this.
case "client_max_window_bits",
"server_max_window_bits=15":
continue
}
err := fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
http.Error(w, err.Error(), http.StatusBadRequest)
return nil, err
}
copts.setHeader(w.Header())
return copts, nil
}
func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
copts := mode.opts()
// The peer must explicitly request it.
copts.serverNoContextTakeover = false
for _, p := range ext.params {
if p == "no_context_takeover" {
copts.serverNoContextTakeover = true
if strings.HasPrefix(p, "client_max_window_bits=") {
// We can't adjust the deflate window, but decoding with a larger window is acceptable.
continue
}
// We explicitly fail on x-webkit-deflate-frame's max_window_bits parameter instead
// of ignoring it as the draft spec is unclear. It says the server can ignore it
// but the server has no way of signalling to the client it was ignored as the parameters
// are set one way.
// Thus us ignoring it would make the client think we understood it which would cause issues.
// See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06#section-4.1
//
// Either way, we're only implementing this for webkit which never sends the max_window_bits
// parameter so we don't need to worry about it.
err := fmt.Errorf("unsupported x-webkit-deflate-frame parameter: %q", p)
http.Error(w, err.Error(), http.StatusBadRequest)
return nil, err
}
s := "x-webkit-deflate-frame"
if copts.clientNoContextTakeover {
s += "; no_context_takeover"
return nil, false
}
w.Header().Set("Sec-WebSocket-Extensions", s)
return copts, nil
return copts, true
}
func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool {
......
......@@ -10,9 +10,11 @@ import (
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"nhooyr.io/websocket/internal/test/assert"
"github.com/coder/websocket/internal/test/assert"
"github.com/coder/websocket/internal/test/xrand"
)
func TestAccept(t *testing.T) {
......@@ -36,7 +38,7 @@ func TestAccept(t *testing.T) {
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", "meow123")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
r.Header.Set("Origin", "harhar.com")
_, err := Accept(w, r, nil)
......@@ -52,7 +54,7 @@ func TestAccept(t *testing.T) {
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", "meow123")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
r.Header.Set("Origin", "https://harhar.com")
_, err := Accept(w, r, nil)
......@@ -62,20 +64,50 @@ func TestAccept(t *testing.T) {
t.Run("badCompression", func(t *testing.T) {
t.Parallel()
w := mockHijacker{
ResponseWriter: httptest.NewRecorder(),
newRequest := func(extensions string) *http.Request {
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
r.Header.Set("Sec-WebSocket-Extensions", extensions)
return r
}
errHijack := errors.New("hijack error")
newResponseWriter := func() http.ResponseWriter {
return mockHijacker{
ResponseWriter: httptest.NewRecorder(),
hijack: func() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, errHijack
},
}
}
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", "meow123")
r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; harharhar")
_, err := Accept(w, r, &AcceptOptions{
CompressionMode: CompressionContextTakeover,
t.Run("withoutFallback", func(t *testing.T) {
t.Parallel()
w := newResponseWriter()
r := newRequest("permessage-deflate; harharhar")
_, err := Accept(w, r, &AcceptOptions{
CompressionMode: CompressionNoContextTakeover,
})
assert.ErrorIs(t, errHijack, err)
assert.Equal(t, "extension header", w.Header().Get("Sec-WebSocket-Extensions"), "")
})
t.Run("withFallback", func(t *testing.T) {
t.Parallel()
w := newResponseWriter()
r := newRequest("permessage-deflate; harharhar, permessage-deflate")
_, err := Accept(w, r, &AcceptOptions{
CompressionMode: CompressionNoContextTakeover,
})
assert.ErrorIs(t, errHijack, err)
assert.Equal(t, "extension header",
w.Header().Get("Sec-WebSocket-Extensions"),
CompressionNoContextTakeover.opts().String(),
)
})
assert.Contains(t, err, `unsupported permessage-deflate parameter`)
})
t.Run("requireHttpHijacker", func(t *testing.T) {
......@@ -86,7 +118,7 @@ func TestAccept(t *testing.T) {
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", "meow123")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
_, err := Accept(w, r, nil)
assert.Contains(t, err, `http.ResponseWriter does not implement http.Hijacker`)
......@@ -106,11 +138,74 @@ func TestAccept(t *testing.T) {
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", "meow123")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
_, err := Accept(w, r, nil)
assert.Contains(t, err, `failed to hijack connection`)
})
t.Run("wrapperHijackerIsUnwrapped", func(t *testing.T) {
t.Parallel()
rr := httptest.NewRecorder()
w := mockUnwrapper{
ResponseWriter: rr,
unwrap: func() http.ResponseWriter {
return mockHijacker{
ResponseWriter: rr,
hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) {
return nil, nil, errors.New("haha")
},
}
},
}
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
_, err := Accept(w, r, nil)
assert.Contains(t, err, "failed to hijack connection")
})
t.Run("closeRace", func(t *testing.T) {
t.Parallel()
server, _ := net.Pipe()
rw := bufio.NewReadWriter(bufio.NewReader(server), bufio.NewWriter(server))
newResponseWriter := func() http.ResponseWriter {
return mockHijacker{
ResponseWriter: httptest.NewRecorder(),
hijack: func() (net.Conn, *bufio.ReadWriter, error) {
return server, rw, nil
},
}
}
w := newResponseWriter()
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
c, err := Accept(w, r, nil)
wg := &sync.WaitGroup{}
wg.Add(2)
go func() {
c.Close(StatusInternalError, "the sky is falling")
wg.Done()
}()
go func() {
c.CloseNow()
wg.Done()
}()
wg.Wait()
assert.Success(t, err)
})
}
func Test_verifyClientHandshake(t *testing.T) {
......@@ -153,7 +248,15 @@ func Test_verifyClientHandshake(t *testing.T) {
},
},
{
name: "badWebSocketKey",
name: "missingWebSocketKey",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
},
},
{
name: "emptyWebSocketKey",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
......@@ -161,13 +264,43 @@ func Test_verifyClientHandshake(t *testing.T) {
"Sec-WebSocket-Key": "",
},
},
{
name: "shortWebSocketKey",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": xrand.Base64(15),
},
},
{
name: "invalidWebSocketKey",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "notbase64",
},
},
{
name: "extraWebSocketKey",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
// Kinda cheeky, but http headers are case-insensitive.
// If 2 sec keys are present, this is a failure condition.
"Sec-WebSocket-Key": xrand.Base64(16),
"sec-webSocket-key": xrand.Base64(16),
},
},
{
name: "badHTTPVersion",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "meow123",
"Sec-WebSocket-Key": xrand.Base64(16),
},
http1: true,
},
......@@ -177,7 +310,17 @@ func Test_verifyClientHandshake(t *testing.T) {
"Connection": "keep-alive, Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "meow123",
"Sec-WebSocket-Key": xrand.Base64(16),
},
success: true,
},
{
name: "successSecKeyExtraSpace",
h: map[string]string{
"Connection": "keep-alive, Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": " " + xrand.Base64(16) + " ",
},
success: true,
},
......@@ -197,7 +340,7 @@ func Test_verifyClientHandshake(t *testing.T) {
}
for k, v := range tc.h {
r.Header.Set(k, v)
r.Header.Add(k, v)
}
_, err := verifyClientRequest(httptest.NewRecorder(), r)
......@@ -344,59 +487,54 @@ func Test_authenticateOrigin(t *testing.T) {
}
}
func Test_acceptCompression(t *testing.T) {
func Test_selectDeflate(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
mode CompressionMode
reqSecWebSocketExtensions string
respSecWebSocketExtensions string
expCopts *compressionOptions
error bool
name string
mode CompressionMode
header string
expCopts *compressionOptions
expOK bool
}{
{
name: "disabled",
mode: CompressionDisabled,
expCopts: nil,
expOK: false,
},
{
name: "noClientSupport",
mode: CompressionNoContextTakeover,
expCopts: nil,
expOK: false,
},
{
name: "permessage-deflate",
mode: CompressionNoContextTakeover,
reqSecWebSocketExtensions: "permessage-deflate; client_max_window_bits",
respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover",
name: "permessage-deflate",
mode: CompressionNoContextTakeover,
header: "permessage-deflate; client_max_window_bits",
expCopts: &compressionOptions{
clientNoContextTakeover: true,
serverNoContextTakeover: true,
},
expOK: true,
},
{
name: "permessage-deflate/error",
mode: CompressionNoContextTakeover,
reqSecWebSocketExtensions: "permessage-deflate; meow",
error: true,
name: "permessage-deflate/unknown-parameter",
mode: CompressionNoContextTakeover,
header: "permessage-deflate; meow",
expOK: false,
},
{
name: "permessage-deflate/unknown-parameter",
mode: CompressionNoContextTakeover,
header: "permessage-deflate; meow, permessage-deflate; client_max_window_bits",
expCopts: &compressionOptions{
clientNoContextTakeover: true,
serverNoContextTakeover: true,
},
expOK: true,
},
// {
// name: "x-webkit-deflate-frame",
// mode: CompressionNoContextTakeover,
// reqSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover",
// respSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover",
// expCopts: &compressionOptions{
// clientNoContextTakeover: true,
// serverNoContextTakeover: true,
// },
// },
// {
// name: "x-webkit-deflate/error",
// mode: CompressionNoContextTakeover,
// reqSecWebSocketExtensions: "x-webkit-deflate-frame; max_window_bits",
// error: true,
// },
}
for _, tc := range testCases {
......@@ -404,19 +542,11 @@ func Test_acceptCompression(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("Sec-WebSocket-Extensions", tc.reqSecWebSocketExtensions)
w := httptest.NewRecorder()
copts, err := acceptCompression(r, w, tc.mode)
if tc.error {
assert.Error(t, err)
return
}
assert.Success(t, err)
h := http.Header{}
h.Set("Sec-WebSocket-Extensions", tc.header)
copts, ok := selectDeflate(websocketExtensions(h), tc.mode)
assert.Equal(t, "selected options", tc.expOK, ok)
assert.Equal(t, "compression options", tc.expCopts, copts)
assert.Equal(t, "Sec-WebSocket-Extensions", tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions"))
})
}
}
......@@ -431,3 +561,14 @@ var _ http.Hijacker = mockHijacker{}
func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return mj.hijack()
}
type mockUnwrapper struct {
http.ResponseWriter
unwrap func() http.ResponseWriter
}
var _ rwUnwrapper = mockUnwrapper{}
func (mu mockUnwrapper) Unwrap() http.ResponseWriter {
return mu.unwrap()
}
......@@ -6,8 +6,9 @@ package websocket_test
import (
"context"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"io"
"net"
"os"
"os/exec"
......@@ -16,10 +17,11 @@ import (
"testing"
"time"
"nhooyr.io/websocket"
"nhooyr.io/websocket/internal/errd"
"nhooyr.io/websocket/internal/test/assert"
"nhooyr.io/websocket/internal/test/wstest"
"github.com/coder/websocket"
"github.com/coder/websocket/internal/errd"
"github.com/coder/websocket/internal/test/assert"
"github.com/coder/websocket/internal/test/wstest"
"github.com/coder/websocket/internal/util"
)
var excludedAutobahnCases = []string{
......@@ -37,8 +39,7 @@ var autobahnCases = []string{"*"}
// Used to run individual test cases. autobahnCases runs only those cases matched
// and not excluded by excludedAutobahnCases. Adding cases here means excludedAutobahnCases
// is niled.
// TODO:
var forceAutobahnCases = []string{}
var onlyAutobahnCases = []string{}
func TestAutobahn(t *testing.T) {
t.Parallel()
......@@ -54,10 +55,15 @@ func TestAutobahn(t *testing.T) {
)
}
if len(onlyAutobahnCases) > 0 {
excludedAutobahnCases = []string{}
autobahnCases = onlyAutobahnCases
}
ctx, cancel := context.WithTimeout(context.Background(), time.Hour)
defer cancel()
wstestURL, closeFn, err := wstestServer(ctx)
wstestURL, closeFn, err := wstestServer(t, ctx)
assert.Success(t, err)
defer func() {
assert.Success(t, closeFn())
......@@ -86,11 +92,11 @@ func TestAutobahn(t *testing.T) {
}
})
c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/updateReports?agent=main"), nil)
c, _, err := websocket.Dial(ctx, wstestURL+"/updateReports?agent=main", nil)
assert.Success(t, err)
c.Close(websocket.StatusNormalClosure, "")
checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json")
checkWSTestIndex(t, "./ci/out/autobahn-report/index.json")
}
func waitWS(ctx context.Context, url string) error {
......@@ -109,9 +115,9 @@ func waitWS(ctx context.Context, url string) error {
return ctx.Err()
}
// TODO: Let docker pick the port and use docker port to find it.
// Does mean we can't use -i but that's fine.
func wstestServer(ctx context.Context) (url string, closeFn func() error, err error) {
func wstestServer(tb testing.TB, ctx context.Context) (url string, closeFn func() error, err error) {
defer errd.Wrap(&err, "failed to start autobahn wstest server")
serverAddr, err := unusedListenAddr()
if err != nil {
return "", nil, err
......@@ -122,7 +128,7 @@ func wstestServer(ctx context.Context) (url string, closeFn func() error, err er
}
url = "ws://" + serverAddr
const outDir = "ci/out/wstestClientReports"
const outDir = "ci/out/autobahn-report"
specFile, err := tempJSONFile(map[string]interface{}{
"url": url,
......@@ -141,6 +147,21 @@ func wstestServer(ctx context.Context) (url string, closeFn func() error, err er
}
}()
dockerPull := exec.CommandContext(ctx, "docker", "pull", "crossbario/autobahn-testsuite")
dockerPull.Stdout = util.WriterFunc(func(p []byte) (int, error) {
tb.Log(string(p))
return len(p), nil
})
dockerPull.Stderr = util.WriterFunc(func(p []byte) (int, error) {
tb.Log(string(p))
return len(p), nil
})
tb.Log(dockerPull)
err = dockerPull.Run()
if err != nil {
return "", nil, fmt.Errorf("failed to pull docker image: %w", err)
}
wd, err := os.Getwd()
if err != nil {
return "", nil, err
......@@ -158,24 +179,32 @@ func wstestServer(ctx context.Context) (url string, closeFn func() error, err er
// See https://github.com/crossbario/autobahn-testsuite/blob/058db3a36b7c3a1edf68c282307c6b899ca4857f/autobahntestsuite/autobahntestsuite/wstest.py#L124
"--webport=0",
)
fmt.Println(strings.Join(args, " "))
// TODO: pull image in advance
wstest := exec.CommandContext(ctx, "docker", args...)
// TODO: log to *testing.T
wstest.Stdout = os.Stdout
wstest.Stderr = os.Stderr
wstest.Stdout = util.WriterFunc(func(p []byte) (int, error) {
tb.Log(string(p))
return len(p), nil
})
wstest.Stderr = util.WriterFunc(func(p []byte) (int, error) {
tb.Log(string(p))
return len(p), nil
})
tb.Log(wstest)
err = wstest.Start()
if err != nil {
return "", nil, fmt.Errorf("failed to start wstest: %w", err)
}
// TODO: kill
return url, func() error {
err = wstest.Process.Kill()
if err != nil {
return fmt.Errorf("failed to kill wstest: %w", err)
}
return nil
err = wstest.Wait()
var ee *exec.ExitError
if errors.As(err, &ee) && ee.ExitCode() == -1 {
return nil
}
return err
}, nil
}
......@@ -192,7 +221,7 @@ func wstestCaseCount(ctx context.Context, url string) (cases int, err error) {
if err != nil {
return 0, err
}
b, err := ioutil.ReadAll(r)
b, err := io.ReadAll(r)
if err != nil {
return 0, err
}
......@@ -207,7 +236,7 @@ func wstestCaseCount(ctx context.Context, url string) (cases int, err error) {
}
func checkWSTestIndex(t *testing.T, path string) {
wstestOut, err := ioutil.ReadFile(path)
wstestOut, err := os.ReadFile(path)
assert.Success(t, err)
var indexJSON map[string]map[string]struct {
......@@ -252,7 +281,7 @@ func unusedListenAddr() (_ string, err error) {
}
func tempJSONFile(v interface{}) (string, error) {
f, err := ioutil.TempFile("", "temp.json")
f, err := os.CreateTemp("", "temp.json")
if err != nil {
return "", fmt.Errorf("temp file: %w", err)
}
......
#!/bin/sh
set -eu
cd -- "$(dirname "$0")/.."
go test --run=^$ --bench=. --benchmem "$@" ./...
# For profiling add: --memprofile ci/out/prof.mem --cpuprofile ci/out/prof.cpu -o ci/out/websocket.test
(
cd ./internal/thirdparty
go test --run=^$ --bench=. --benchmem "$@" .
GOARCH=arm64 go test -c -o ../../ci/out/thirdparty-arm64.test "$@" .
if [ "$#" -eq 0 ]; then
if [ "${CI-}" ]; then
sudo apt-get update
sudo apt-get install -y qemu-user-static
ln -s /usr/bin/qemu-aarch64-static /usr/local/bin/qemu-aarch64
fi
qemu-aarch64 ../../ci/out/thirdparty-arm64.test --test.run=^$ --test.bench=Benchmark_mask --test.benchmem
fi
)
......@@ -2,17 +2,24 @@
set -eu
cd -- "$(dirname "$0")/.."
X_TOOLS_VERSION=v0.31.0
go mod tidy
(cd ./internal/thirdparty && go mod tidy)
(cd ./internal/examples && go mod tidy)
gofmt -w -s .
go run golang.org/x/tools/cmd/goimports@latest -w "-local=$(go list -m)" .
go run golang.org/x/tools/cmd/goimports@${X_TOOLS_VERSION} -w "-local=$(go list -m)" .
npx prettier@3.0.3 \
--write \
git ls-files "*.yml" "*.md" "*.js" "*.css" "*.html" | xargs npx prettier@3.3.3 \
--check \
--log-level=warn \
--print-width=90 \
--no-semi \
--single-quote \
--arrow-parens=avoid \
$(git ls-files "*.yml" "*.md" "*.js" "*.css" "*.html")
--arrow-parens=avoid
go run golang.org/x/tools/cmd/stringer@${X_TOOLS_VERSION} -type=opcode,MessageType,StatusCode -output=stringer.go
go run golang.org/x/tools/cmd/stringer@latest -type=opcode,MessageType,StatusCode -output=stringer.go
if [ "${CI-}" ]; then
git diff --exit-code
fi
......@@ -2,13 +2,35 @@
set -eu
cd -- "$(dirname "$0")/.."
STATICCHECK_VERSION=v0.6.1
GOVULNCHECK_VERSION=v1.1.4
go vet ./...
GOOS=js GOARCH=wasm go vet ./...
go install golang.org/x/lint/golint@latest
golint -set_exit_status ./...
GOOS=js GOARCH=wasm golint -set_exit_status ./...
go install honnef.co/go/tools/cmd/staticcheck@latest
go install honnef.co/go/tools/cmd/staticcheck@${STATICCHECK_VERSION}
staticcheck ./...
GOOS=js GOARCH=wasm staticcheck ./...
govulncheck() {
tmpf=$(mktemp)
if ! command govulncheck "$@" >"$tmpf" 2>&1; then
cat "$tmpf"
fi
}
go install golang.org/x/vuln/cmd/govulncheck@${GOVULNCHECK_VERSION}
govulncheck ./...
GOOS=js GOARCH=wasm govulncheck ./...
(
cd ./internal/examples
go vet ./...
staticcheck ./...
govulncheck ./...
)
(
cd ./internal/thirdparty
go vet ./...
staticcheck ./...
govulncheck ./...
)
......@@ -2,8 +2,30 @@
set -eu
cd -- "$(dirname "$0")/.."
go install github.com/agnivade/wasmbrowsertest@latest
go test --race --timeout=1h --covermode=atomic --coverprofile=ci/out/coverage.prof --coverpkg=./... "$@" ./...
(
cd ./internal/examples
go test "$@" ./...
)
(
cd ./internal/thirdparty
go test "$@" ./...
)
(
GOARCH=arm64 go test -c -o ./ci/out/websocket-arm64.test "$@" .
if [ "$#" -eq 0 ]; then
if [ "${CI-}" ]; then
sudo apt-get update
sudo apt-get install -y qemu-user-static
ln -s /usr/bin/qemu-aarch64-static /usr/local/bin/qemu-aarch64
fi
qemu-aarch64 ./ci/out/websocket-arm64.test -test.run=TestMask
fi
)
go install github.com/agnivade/wasmbrowsertest@8be019f6c6dceae821467b4c589eb195c2b761ce
go test --race --bench=. --timeout=1h --covermode=atomic --coverprofile=ci/out/coverage.prof --coverpkg=./... "$@" ./...
sed -i.bak '/stringer\.go/d' ci/out/coverage.prof
sed -i.bak '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof
sed -i.bak '/examples/d' ci/out/coverage.prof
......
......@@ -8,10 +8,10 @@ import (
"encoding/binary"
"errors"
"fmt"
"log"
"net"
"time"
"nhooyr.io/websocket/internal/errd"
"github.com/coder/websocket/internal/errd"
)
// StatusCode represents a WebSocket status code.
......@@ -93,75 +93,110 @@ func CloseStatus(err error) StatusCode {
// The connection can only be closed once. Additional calls to Close
// are no-ops.
//
// The maximum length of reason must be 125 bytes. Avoid
// sending a dynamic reason.
// The maximum length of reason must be 125 bytes. Avoid sending a dynamic reason.
//
// Close will unblock all goroutines interacting with the connection once
// complete.
func (c *Conn) Close(code StatusCode, reason string) error {
return c.closeHandshake(code, reason)
}
func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) {
func (c *Conn) Close(code StatusCode, reason string) (err error) {
defer errd.Wrap(&err, "failed to close WebSocket")
writeErr := c.writeClose(code, reason)
closeHandshakeErr := c.waitCloseHandshake()
if c.casClosing() {
err = c.waitGoroutines()
if err != nil {
return err
}
return net.ErrClosed
}
defer func() {
if errors.Is(err, net.ErrClosed) {
err = nil
}
}()
err = c.closeHandshake(code, reason)
if writeErr != nil {
return writeErr
err2 := c.close()
if err == nil && err2 != nil {
err = err2
}
if CloseStatus(closeHandshakeErr) == -1 {
return closeHandshakeErr
err2 = c.waitGoroutines()
if err == nil && err2 != nil {
err = err2
}
return nil
return err
}
var errAlreadyWroteClose = errors.New("already wrote close")
// CloseNow closes the WebSocket connection without attempting a close handshake.
// Use when you do not want the overhead of the close handshake.
func (c *Conn) CloseNow() (err error) {
defer errd.Wrap(&err, "failed to immediately close WebSocket")
func (c *Conn) writeClose(code StatusCode, reason string) error {
c.closeMu.Lock()
wroteClose := c.wroteClose
c.wroteClose = true
c.closeMu.Unlock()
if wroteClose {
return errAlreadyWroteClose
if c.casClosing() {
err = c.waitGoroutines()
if err != nil {
return err
}
return net.ErrClosed
}
defer func() {
if errors.Is(err, net.ErrClosed) {
err = nil
}
}()
err = c.close()
err2 := c.waitGoroutines()
if err == nil && err2 != nil {
err = err2
}
return err
}
func (c *Conn) closeHandshake(code StatusCode, reason string) error {
err := c.writeClose(code, reason)
if err != nil {
return err
}
err = c.waitCloseHandshake()
if CloseStatus(err) != code {
return err
}
return nil
}
func (c *Conn) writeClose(code StatusCode, reason string) error {
ce := CloseError{
Code: code,
Reason: reason,
}
var p []byte
var marshalErr error
var err error
if ce.Code != StatusNoStatusRcvd {
p, marshalErr = ce.bytes()
if marshalErr != nil {
log.Printf("websocket: %v", marshalErr)
p, err = ce.bytes()
if err != nil {
return err
}
}
writeErr := c.writeControl(context.Background(), opClose, p)
if CloseStatus(writeErr) != -1 {
// Not a real error if it's due to a close frame being received.
writeErr = nil
}
// We do this after in case there was an error writing the close frame.
c.setCloseErr(fmt.Errorf("sent close frame: %w", ce))
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
if marshalErr != nil {
return marshalErr
err = c.writeControl(ctx, opClose, p)
// If the connection closed as we're writing we ignore the error as we might
// have written the close frame, the peer responded and then someone else read it
// and closed the connection.
if err != nil && !errors.Is(err, net.ErrClosed) {
return err
}
return writeErr
return nil
}
func (c *Conn) waitCloseHandshake() error {
defer c.close(nil)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
......@@ -171,8 +206,11 @@ func (c *Conn) waitCloseHandshake() error {
}
defer c.readMu.unlock()
if c.readCloseFrameErr != nil {
return c.readCloseFrameErr
for i := int64(0); i < c.msgReader.payloadLength; i++ {
_, err := c.br.ReadByte()
if err != nil {
return err
}
}
for {
......@@ -190,6 +228,36 @@ func (c *Conn) waitCloseHandshake() error {
}
}
func (c *Conn) waitGoroutines() error {
t := time.NewTimer(time.Second * 15)
defer t.Stop()
select {
case <-c.timeoutLoopDone:
case <-t.C:
return errors.New("failed to wait for timeoutLoop goroutine to exit")
}
c.closeReadMu.Lock()
closeRead := c.closeReadCtx != nil
c.closeReadMu.Unlock()
if closeRead {
select {
case <-c.closeReadDone:
case <-t.C:
return errors.New("failed to wait for close read goroutine to exit")
}
}
select {
case <-c.closed:
case <-t.C:
return errors.New("failed to wait for connection to be closed")
}
return nil
}
func parseClosePayload(p []byte) (CloseError, error) {
if len(p) == 0 {
return CloseError{
......@@ -260,16 +328,8 @@ func (ce CloseError) bytesErr() ([]byte, error) {
return buf, nil
}
func (c *Conn) setCloseErr(err error) {
c.closeMu.Lock()
c.setCloseErrLocked(err)
c.closeMu.Unlock()
}
func (c *Conn) setCloseErrLocked(err error) {
if c.closeErr == nil {
c.closeErr = fmt.Errorf("WebSocket closed: %w", err)
}
func (c *Conn) casClosing() bool {
return c.closing.Swap(true)
}
func (c *Conn) isClosed() bool {
......
......@@ -9,7 +9,7 @@ import (
"strings"
"testing"
"nhooyr.io/websocket/internal/test/assert"
"github.com/coder/websocket/internal/test/assert"
)
func TestCloseError(t *testing.T) {
......
......@@ -6,50 +6,47 @@ package websocket
import (
"compress/flate"
"io"
"net/http"
"sync"
)
// CompressionMode represents the modes available to the deflate extension.
// CompressionMode represents the modes available to the permessage-deflate extension.
// See https://tools.ietf.org/html/rfc7692
//
// A compatibility layer is implemented for the older deflate-frame extension used
// by safari. See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06
// It will work the same in every way except that we cannot signal to the peer we
// want to use no context takeover on our side, we can only signal that they should.
// But it is currently disabled due to Safari bugs. See https://github.com/nhooyr/websocket/issues/218
// Works in all modern browsers except Safari which does not implement the permessage-deflate extension.
//
// Compression is only used if the peer supports the mode selected.
type CompressionMode int
const (
// CompressionDisabled disables the deflate extension.
//
// Use this if you are using a predominantly binary protocol with very
// little duplication in between messages or CPU and memory are more
// important than bandwidth.
// CompressionDisabled disables the negotiation of the permessage-deflate extension.
//
// This is the default.
// This is the default. Do not enable compression without benchmarking for your particular use case first.
CompressionDisabled CompressionMode = iota
// CompressionContextTakeover uses a 32 kB sliding window and flate.Writer per connection.
// It reusing the sliding window from previous messages.
// As most WebSocket protocols are repetitive, this can be very efficient.
// It carries an overhead of 32 kB + 1.2 MB for every connection compared to CompressionNoContextTakeover.
// CompressionContextTakeover compresses each message greater than 128 bytes reusing the 32 KB sliding window from
// previous messages. i.e compression context across messages is preserved.
//
// As most WebSocket protocols are text based and repetitive, this compression mode can be very efficient.
//
// The memory overhead is a fixed 32 KB sliding window, a fixed 1.2 MB flate.Writer and a sync.Pool of 40 KB flate.Reader's
// that are used when reading and then returned.
//
// Sometime in the future it will carry 65 kB overhead instead once https://github.com/golang/go/issues/36919
// is fixed.
// Thus, it uses more memory than CompressionNoContextTakeover but compresses more efficiently.
//
// If the peer negotiates NoContextTakeover on the client or server side, it will be
// used instead as this is required by the RFC.
// If the peer does not support CompressionContextTakeover then we will fall back to CompressionNoContextTakeover.
CompressionContextTakeover
// CompressionNoContextTakeover grabs a new flate.Reader and flate.Writer as needed
// for every message. This applies to both server and client side.
// CompressionNoContextTakeover compresses each message greater than 512 bytes. Each message is compressed with
// a new 1.2 MB flate.Writer pulled from a sync.Pool. Each message is read with a 40 KB flate.Reader pulled from
// a sync.Pool.
//
// This means less efficient compression as the sliding window from previous messages
// will not be used but the memory overhead will be lower if the connections
// are long lived and seldom used.
// This means less efficient compression as the sliding window from previous messages will not be used but the
// memory overhead will be lower as there will be no fixed cost for the flate.Writer nor the 32 KB sliding window.
// Especially if the connections are long lived and seldom written to.
//
// The message will only be compressed if greater than 512 bytes.
// Thus, it uses less memory than CompressionContextTakeover but compresses less efficiently.
//
// If the peer does not support CompressionNoContextTakeover then we will fall back to CompressionDisabled.
CompressionNoContextTakeover
)
......@@ -65,7 +62,7 @@ type compressionOptions struct {
serverNoContextTakeover bool
}
func (copts *compressionOptions) setHeader(h http.Header) {
func (copts *compressionOptions) String() string {
s := "permessage-deflate"
if copts.clientNoContextTakeover {
s += "; client_no_context_takeover"
......@@ -73,14 +70,14 @@ func (copts *compressionOptions) setHeader(h http.Header) {
if copts.serverNoContextTakeover {
s += "; server_no_context_takeover"
}
h.Set("Sec-WebSocket-Extensions", s)
return s
}
// These bytes are required to get flate.Reader to return.
// They are removed when sending to avoid the overhead as
// WebSocket framing tell's when the message has ended but then
// we need to add them back otherwise flate.Reader keeps
// trying to return more bytes.
// trying to read more bytes.
const deflateMessageTail = "\x00\x00\xff\xff"
type trimLastFourBytesWriter struct {
......@@ -201,23 +198,19 @@ func (sw *slidingWindow) init(n int) {
}
p := slidingWindowPool(n)
buf, ok := p.Get().([]byte)
sw2, ok := p.Get().(*slidingWindow)
if ok {
sw.buf = buf[:0]
*sw = *sw2
} else {
sw.buf = make([]byte, 0, n)
}
}
func (sw *slidingWindow) close() {
if sw.buf == nil {
return
}
sw.buf = sw.buf[:0]
swPoolMu.Lock()
swPool[cap(sw.buf)].Put(sw.buf)
swPool[cap(sw.buf)].Put(sw)
swPoolMu.Unlock()
sw.buf = nil
}
func (sw *slidingWindow) write(p []byte) {
......
......@@ -4,11 +4,14 @@
package websocket
import (
"bytes"
"compress/flate"
"io"
"strings"
"testing"
"nhooyr.io/websocket/internal/test/assert"
"nhooyr.io/websocket/internal/test/xrand"
"github.com/coder/websocket/internal/test/assert"
"github.com/coder/websocket/internal/test/xrand"
)
func Test_slidingWindow(t *testing.T) {
......@@ -33,3 +36,27 @@ func Test_slidingWindow(t *testing.T) {
})
}
}
func BenchmarkFlateWriter(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
w, _ := flate.NewWriter(io.Discard, flate.BestSpeed)
// We have to write a byte to get the writer to allocate to its full extent.
w.Write([]byte{'a'})
w.Flush()
}
}
func BenchmarkFlateReader(b *testing.B) {
b.ReportAllocs()
var buf bytes.Buffer
w, _ := flate.NewWriter(&buf, flate.BestSpeed)
w.Write([]byte{'a'})
w.Flush()
for i := 0; i < b.N; i++ {
r := flate.NewReader(bytes.NewReader(buf.Bytes()))
io.ReadAll(r)
}
}
......@@ -6,9 +6,9 @@ package websocket
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"net"
"runtime"
"strconv"
"sync"
......@@ -42,6 +42,8 @@ const (
// This applies to context expirations as well unfortunately.
// See https://github.com/nhooyr/websocket/issues/242#issuecomment-633182220
type Conn struct {
noCopy noCopy
subprotocol string
rwc io.ReadWriteCloser
client bool
......@@ -50,31 +52,42 @@ type Conn struct {
br *bufio.Reader
bw *bufio.Writer
readTimeout chan context.Context
writeTimeout chan context.Context
readTimeout chan context.Context
writeTimeout chan context.Context
timeoutLoopDone chan struct{}
// Read state.
readMu *mu
readHeaderBuf [8]byte
readControlBuf [maxControlPayload]byte
msgReader *msgReader
readCloseFrameErr error
readMu *mu
readHeaderBuf [8]byte
readControlBuf [maxControlPayload]byte
msgReader *msgReader
// Write state.
msgWriterState *msgWriterState
msgWriter *msgWriter
writeFrameMu *mu
writeBuf []byte
writeHeaderBuf [8]byte
writeHeader header
closed chan struct{}
closeMu sync.Mutex
closeErr error
wroteClose bool
pingCounter int32
activePingsMu sync.Mutex
activePings map[string]chan<- struct{}
// Close handshake state.
closeStateMu sync.RWMutex
closeReceivedErr error
closeSentErr error
// CloseRead state.
closeReadMu sync.Mutex
closeReadCtx context.Context
closeReadDone chan struct{}
closing atomic.Bool
closeMu sync.Mutex // Protects following.
closed chan struct{}
pingCounter atomic.Int64
activePingsMu sync.Mutex
activePings map[string]chan<- struct{}
onPingReceived func(context.Context, []byte) bool
onPongReceived func(context.Context, []byte)
}
type connConfig struct {
......@@ -83,6 +96,8 @@ type connConfig struct {
client bool
copts *compressionOptions
flateThreshold int
onPingReceived func(context.Context, []byte) bool
onPongReceived func(context.Context, []byte)
br *bufio.Reader
bw *bufio.Writer
......@@ -99,11 +114,14 @@ func newConn(cfg connConfig) *Conn {
br: cfg.br,
bw: cfg.bw,
readTimeout: make(chan context.Context),
writeTimeout: make(chan context.Context),
readTimeout: make(chan context.Context),
writeTimeout: make(chan context.Context),
timeoutLoopDone: make(chan struct{}),
closed: make(chan struct{}),
activePings: make(map[string]chan<- struct{}),
closed: make(chan struct{}),
activePings: make(map[string]chan<- struct{}),
onPingReceived: cfg.onPingReceived,
onPongReceived: cfg.onPongReceived,
}
c.readMu = newMu(c)
......@@ -111,20 +129,20 @@ func newConn(cfg connConfig) *Conn {
c.msgReader = newMsgReader(c)
c.msgWriterState = newMsgWriterState(c)
c.msgWriter = newMsgWriter(c)
if c.client {
c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc)
}
if c.flate() && c.flateThreshold == 0 {
c.flateThreshold = 128
if !c.msgWriterState.flateContextTakeover() {
if !c.msgWriter.flateContextTakeover() {
c.flateThreshold = 512
}
}
runtime.SetFinalizer(c, func(c *Conn) {
c.close(errors.New("connection garbage collected"))
c.close()
})
go c.timeoutLoop()
......@@ -138,30 +156,29 @@ func (c *Conn) Subprotocol() string {
return c.subprotocol
}
func (c *Conn) close(err error) {
func (c *Conn) close() error {
c.closeMu.Lock()
defer c.closeMu.Unlock()
if c.isClosed() {
return
return net.ErrClosed
}
c.setCloseErrLocked(err)
close(c.closed)
runtime.SetFinalizer(c, nil)
close(c.closed)
// Have to close after c.closed is closed to ensure any goroutine that wakes up
// from the connection being closed also sees that c.closed is closed and returns
// closeErr.
c.rwc.Close()
go func() {
c.msgWriterState.close()
c.msgReader.close()
}()
err := c.rwc.Close()
// With the close of rwc, these become safe to close.
c.msgWriter.close()
c.msgReader.close()
return err
}
func (c *Conn) timeoutLoop() {
defer close(c.timeoutLoopDone)
readCtx := context.Background()
writeCtx := context.Background()
......@@ -174,10 +191,10 @@ func (c *Conn) timeoutLoop() {
case readCtx = <-c.readTimeout:
case <-readCtx.Done():
c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err()))
go c.writeError(StatusPolicyViolation, errors.New("timed out"))
c.close()
return
case <-writeCtx.Done():
c.close(fmt.Errorf("write timed out: %w", writeCtx.Err()))
c.close()
return
}
}
......@@ -195,9 +212,9 @@ func (c *Conn) flate() bool {
//
// TCP Keepalives should suffice for most use cases.
func (c *Conn) Ping(ctx context.Context) error {
p := atomic.AddInt32(&c.pingCounter, 1)
p := c.pingCounter.Add(1)
err := c.ping(ctx, strconv.Itoa(int(p)))
err := c.ping(ctx, strconv.FormatInt(p, 10))
if err != nil {
return fmt.Errorf("failed to ping: %w", err)
}
......@@ -224,11 +241,9 @@ func (c *Conn) ping(ctx context.Context, p string) error {
select {
case <-c.closed:
return c.closeErr
return net.ErrClosed
case <-ctx.Done():
err := fmt.Errorf("failed to wait for pong: %w", ctx.Err())
c.close(err)
return err
return fmt.Errorf("failed to wait for pong: %w", ctx.Err())
case <-pong:
return nil
}
......@@ -262,11 +277,9 @@ func (m *mu) tryLock() bool {
func (m *mu) lock(ctx context.Context) error {
select {
case <-m.c.closed:
return m.c.closeErr
return net.ErrClosed
case <-ctx.Done():
err := fmt.Errorf("failed to acquire lock: %w", ctx.Err())
m.c.close(err)
return err
return fmt.Errorf("failed to acquire lock: %w", ctx.Err())
case m.ch <- struct{}{}:
// To make sure the connection is certainly alive.
// As it's possible the send on m.ch was selected
......@@ -275,7 +288,7 @@ func (m *mu) lock(ctx context.Context) error {
case <-m.c.closed:
// Make sure to release.
m.unlock()
return m.c.closeErr
return net.ErrClosed
default:
}
return nil
......@@ -288,3 +301,7 @@ func (m *mu) unlock() {
default:
}
}
type noCopy struct{}
func (*noCopy) Lock() {}