good morning!!!!

Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • github/nhooyr/websocket
  • open/websocket
2 results
Show changes
Commits on Source (170)
github: nhooyr
version: 2
updates:
# Track in case we ever add dependencies.
- package-ecosystem: 'gomod'
directory: '/'
schedule:
interval: 'weekly'
commit-message:
prefix: 'chore'
# Keep example and test/benchmark deps up-to-date.
- package-ecosystem: 'gomod'
directories:
- '/internal/examples'
- '/internal/thirdparty'
schedule:
interval: 'monthly'
commit-message:
prefix: 'chore'
labels: []
groups:
internal-deps:
patterns:
- '*'
name: ci name: ci
on: [push, pull_request] on:
push:
branches:
- master
pull_request:
branches:
- master
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }} group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}
cancel-in-progress: true cancel-in-progress: true
...@@ -9,30 +15,45 @@ jobs: ...@@ -9,30 +15,45 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v4 - uses: actions/setup-go@v5
with: with:
go-version-file: ./go.mod go-version-file: ./go.mod
- run: ./ci/fmt.sh - run: make fmt
lint: lint:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- run: go version - run: go version
- uses: actions/setup-go@v4 - uses: actions/setup-go@v5
with: with:
go-version-file: ./go.mod go-version-file: ./go.mod
- run: ./ci/lint.sh - run: make lint
test: test:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Disable AppArmor
if: runner.os == 'Linux'
run: |
# Disable AppArmor for Ubuntu 23.10+.
# https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md
echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v4 - uses: actions/setup-go@v5
with: with:
go-version-file: ./go.mod go-version-file: ./go.mod
- run: ./ci/test.sh - run: make test
- uses: actions/upload-artifact@v3 - uses: actions/upload-artifact@v4
with: with:
name: coverage.html name: coverage.html
path: ./ci/out/coverage.html path: ./ci/out/coverage.html
bench:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- run: make bench
name: daily
on:
workflow_dispatch:
schedule:
- cron: '42 0 * * *' # daily at 00:42
concurrency:
group: ${{ github.workflow }}
cancel-in-progress: true
jobs:
bench:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- run: AUTOBAHN=1 make bench
test:
runs-on: ubuntu-latest
steps:
- name: Disable AppArmor
if: runner.os == 'Linux'
run: |
# Disable AppArmor for Ubuntu 23.10+.
# https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md
echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- run: AUTOBAHN=1 make test
- uses: actions/upload-artifact@v4
with:
name: coverage.html
path: ./ci/out/coverage.html
bench-dev:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
ref: dev
- uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- run: AUTOBAHN=1 make bench
test-dev:
runs-on: ubuntu-latest
steps:
- name: Disable AppArmor
if: runner.os == 'Linux'
run: |
# Disable AppArmor for Ubuntu 23.10+.
# https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md
echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns
- uses: actions/checkout@v4
with:
ref: dev
- uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- run: AUTOBAHN=1 make test
- uses: actions/upload-artifact@v4
with:
name: coverage-dev.html
path: ./ci/out/coverage.html
name: static
on:
push:
branches: ['master']
workflow_dispatch:
# Set permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages.
permissions:
contents: read
pages: write
id-token: write
concurrency:
group: pages
cancel-in-progress: true
jobs:
deploy:
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
runs-on: ubuntu-latest
steps:
- name: Disable AppArmor
if: runner.os == 'Linux'
run: |
# Disable AppArmor for Ubuntu 23.10+.
# https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md
echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns
- name: Checkout
uses: actions/checkout@v4
- name: Setup Pages
uses: actions/configure-pages@v5
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- name: Generate coverage and badge
run: |
make test
mkdir -p ./ci/out/static
cp ./ci/out/coverage.html ./ci/out/static/coverage.html
percent=$(go tool cover -func ./ci/out/coverage.prof | tail -n1 | awk '{print $3}' | tr -d '%')
wget -O ./ci/out/static/coverage.svg "https://img.shields.io/badge/coverage-${percent}%25-success"
- name: Upload artifact
uses: actions/upload-pages-artifact@v3
with:
path: ./ci/out/static/
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4
websocket.test
Copyright (c) 2023 Anmol Sethi <hi@nhooyr.io> Copyright (c) 2025 Coder
Permission to use, copy, modify, and distribute this software for any Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above purpose with or without fee is hereby granted, provided that the above
......
.PHONY: all
all: fmt lint test
.PHONY: fmt
fmt:
./ci/fmt.sh
.PHONY: lint
lint:
./ci/lint.sh
.PHONY: test
test:
./ci/test.sh
.PHONY: bench
bench:
./ci/bench.sh
\ No newline at end of file
# websocket # websocket
[![godoc](https://godoc.org/nhooyr.io/websocket?status.svg)](https://pkg.go.dev/nhooyr.io/websocket) [![Go Reference](https://pkg.go.dev/badge/github.com/coder/websocket.svg)](https://pkg.go.dev/github.com/coder/websocket)
[![coverage](https://img.shields.io/badge/coverage-86%25-success)](https://nhooyrio-websocket-coverage.netlify.app) [![Go Coverage](https://coder.github.io/websocket/coverage.svg)](https://coder.github.io/websocket/coverage.html)
websocket is a minimal and idiomatic WebSocket library for Go. websocket is a minimal and idiomatic WebSocket library for Go.
> **note**: I haven't been responsive for questions/reports on the issue tracker but I do
> read through and there are no outstanding bugs. There are certainly some nice to haves
> that I should merge in/figure out but nothing critical. I haven't given up on adding new
> features and cleaning up the code further, just been busy. Should anything critical
> arise, I will fix it.
## Install ## Install
```bash ```sh
go get nhooyr.io/websocket go get github.com/coder/websocket
``` ```
> [!NOTE]
> Coder now maintains this project as explained in [this blog post](https://coder.com/blog/websocket).
> We're grateful to [nhooyr](https://github.com/nhooyr) for authoring and maintaining this project from
> 2019 to 2024.
## Highlights ## Highlights
- Minimal and idiomatic API - Minimal and idiomatic API
- First class [context.Context](https://blog.golang.org/context) support - First class [context.Context](https://blog.golang.org/context) support
- Fully passes the WebSocket [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) - Fully passes the WebSocket [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite)
- [Zero dependencies](https://pkg.go.dev/nhooyr.io/websocket?tab=imports) - [Zero dependencies](https://pkg.go.dev/github.com/coder/websocket?tab=imports)
- JSON and protobuf helpers in the [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) and [wspb](https://pkg.go.dev/nhooyr.io/websocket/wspb) subpackages - JSON helpers in the [wsjson](https://pkg.go.dev/github.com/coder/websocket/wsjson) subpackage
- Zero alloc reads and writes - Zero alloc reads and writes
- Concurrent writes - Concurrent writes
- [Close handshake](https://pkg.go.dev/nhooyr.io/websocket#Conn.Close) - [Close handshake](https://pkg.go.dev/github.com/coder/websocket#Conn.Close)
- [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) wrapper - [net.Conn](https://pkg.go.dev/github.com/coder/websocket#NetConn) wrapper
- [Ping pong](https://pkg.go.dev/nhooyr.io/websocket#Conn.Ping) API - [Ping pong](https://pkg.go.dev/github.com/coder/websocket#Conn.Ping) API
- [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression - [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression
- Compile to [Wasm](https://pkg.go.dev/nhooyr.io/websocket#hdr-Wasm) - [CloseRead](https://pkg.go.dev/github.com/coder/websocket#Conn.CloseRead) helper for write only connections
- Compile to [Wasm](https://pkg.go.dev/github.com/coder/websocket#hdr-Wasm)
## Roadmap ## Roadmap
See GitHub issues for minor issues but the major future enhancements are:
- [ ] Perfect examples [#217](https://github.com/nhooyr/websocket/issues/217)
- [ ] wstest.Pipe for in memory testing [#340](https://github.com/nhooyr/websocket/issues/340)
- [ ] Ping pong heartbeat helper [#267](https://github.com/nhooyr/websocket/issues/267)
- [ ] Ping pong instrumentation callbacks [#246](https://github.com/nhooyr/websocket/issues/246)
- [ ] Graceful shutdown helpers [#209](https://github.com/nhooyr/websocket/issues/209)
- [ ] Assembly for WebSocket masking [#16](https://github.com/nhooyr/websocket/issues/16)
- WIP at [#326](https://github.com/nhooyr/websocket/pull/326), about 3x faster
- [ ] HTTP/2 [#4](https://github.com/nhooyr/websocket/issues/4) - [ ] HTTP/2 [#4](https://github.com/nhooyr/websocket/issues/4)
- [ ] The holy grail [#402](https://github.com/nhooyr/websocket/issues/402)
## Examples ## Examples
For a production quality example that demonstrates the complete API, see the For a production quality example that demonstrates the complete API, see the
[echo example](./examples/echo). [echo example](./internal/examples/echo).
For a full stack example, see the [chat example](./examples/chat). For a full stack example, see the [chat example](./internal/examples/chat).
### Server ### Server
...@@ -51,9 +61,11 @@ http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { ...@@ -51,9 +61,11 @@ http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) {
if err != nil { if err != nil {
// ... // ...
} }
defer c.Close(websocket.StatusInternalError, "the sky is falling") defer c.CloseNow()
ctx, cancel := context.WithTimeout(r.Context(), time.Second*10) // Set the context as needed. Use of r.Context() is not recommended
// to avoid surprising behavior (see http.Hijacker).
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel() defer cancel()
var v interface{} var v interface{}
...@@ -78,7 +90,7 @@ c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil) ...@@ -78,7 +90,7 @@ c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil)
if err != nil { if err != nil {
// ... // ...
} }
defer c.Close(websocket.StatusInternalError, "the sky is falling") defer c.CloseNow()
err = wsjson.Write(ctx, c, "hi") err = wsjson.Write(ctx, c, "hi")
if err != nil { if err != nil {
...@@ -97,12 +109,14 @@ Advantages of [gorilla/websocket](https://github.com/gorilla/websocket): ...@@ -97,12 +109,14 @@ Advantages of [gorilla/websocket](https://github.com/gorilla/websocket):
- Mature and widely used - Mature and widely used
- [Prepared writes](https://pkg.go.dev/github.com/gorilla/websocket#PreparedMessage) - [Prepared writes](https://pkg.go.dev/github.com/gorilla/websocket#PreparedMessage)
- Configurable [buffer sizes](https://pkg.go.dev/github.com/gorilla/websocket#hdr-Buffers) - Configurable [buffer sizes](https://pkg.go.dev/github.com/gorilla/websocket#hdr-Buffers)
- No extra goroutine per connection to support cancellation with context.Context. This costs github.com/coder/websocket 2 KB of memory per connection.
- Will be removed soon with [context.AfterFunc](https://github.com/golang/go/issues/57928). See [#411](https://github.com/nhooyr/websocket/issues/411)
Advantages of nhooyr.io/websocket: Advantages of github.com/coder/websocket:
- Minimal and idiomatic API - Minimal and idiomatic API
- Compare godoc of [nhooyr.io/websocket](https://pkg.go.dev/nhooyr.io/websocket) with [gorilla/websocket](https://pkg.go.dev/github.com/gorilla/websocket) side by side. - Compare godoc of [github.com/coder/websocket](https://pkg.go.dev/github.com/coder/websocket) with [gorilla/websocket](https://pkg.go.dev/github.com/gorilla/websocket) side by side.
- [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) wrapper - [net.Conn](https://pkg.go.dev/github.com/coder/websocket#NetConn) wrapper
- Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535)) - Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535))
- Full [context.Context](https://blog.golang.org/context) support - Full [context.Context](https://blog.golang.org/context) support
- Dial uses [net/http.Client](https://golang.org/pkg/net/http/#Client) - Dial uses [net/http.Client](https://golang.org/pkg/net/http/#Client)
...@@ -110,28 +124,39 @@ Advantages of nhooyr.io/websocket: ...@@ -110,28 +124,39 @@ Advantages of nhooyr.io/websocket:
- Gorilla writes directly to a net.Conn and so duplicates features of net/http.Client. - Gorilla writes directly to a net.Conn and so duplicates features of net/http.Client.
- Concurrent writes - Concurrent writes
- Close handshake ([gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448)) - Close handshake ([gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448))
- Idiomatic [ping pong](https://pkg.go.dev/nhooyr.io/websocket#Conn.Ping) API - Idiomatic [ping pong](https://pkg.go.dev/github.com/coder/websocket#Conn.Ping) API
- Gorilla requires registering a pong callback before sending a Ping - Gorilla requires registering a pong callback before sending a Ping
- Can target Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432)) - Can target Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432))
- Transparent message buffer reuse with [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) and [wspb](https://pkg.go.dev/nhooyr.io/websocket/wspb) subpackages - Transparent message buffer reuse with [wsjson](https://pkg.go.dev/github.com/coder/websocket/wsjson) subpackage
- [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go - [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go
- Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/). - Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/).
Soon we'll have assembly and be 3x faster [#326](https://github.com/nhooyr/websocket/pull/326)
- Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support - Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support
- Gorilla only supports no context takeover mode - Gorilla only supports no context takeover mode
- [CloseRead](https://pkg.go.dev/nhooyr.io/websocket#Conn.CloseRead) helper ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492)) - [CloseRead](https://pkg.go.dev/github.com/coder/websocket#Conn.CloseRead) helper for write only connections ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492))
- Actively maintained ([gorilla/websocket#370](https://github.com/gorilla/websocket/issues/370))
#### golang.org/x/net/websocket #### golang.org/x/net/websocket
[golang.org/x/net/websocket](https://pkg.go.dev/golang.org/x/net/websocket) is deprecated. [golang.org/x/net/websocket](https://pkg.go.dev/golang.org/x/net/websocket) is deprecated.
See [golang/go/issues/18152](https://github.com/golang/go/issues/18152). See [golang/go/issues/18152](https://github.com/golang/go/issues/18152).
The [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) can help in transitioning The [net.Conn](https://pkg.go.dev/github.com/coder/websocket#NetConn) can help in transitioning
to nhooyr.io/websocket. to github.com/coder/websocket.
#### gobwas/ws #### gobwas/ws
[gobwas/ws](https://github.com/gobwas/ws) has an extremely flexible API that allows it to be used [gobwas/ws](https://github.com/gobwas/ws) has an extremely flexible API that allows it to be used
in an event driven style for performance. See the author's [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb). in an event driven style for performance. See the author's [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb).
However when writing idiomatic Go, nhooyr.io/websocket will be faster and easier to use. However it is quite bloated. See https://pkg.go.dev/github.com/gobwas/ws
When writing idiomatic Go, github.com/coder/websocket will be faster and easier to use.
#### lesismal/nbio
[lesismal/nbio](https://github.com/lesismal/nbio) is similar to gobwas/ws in that the API is
event driven for performance reasons.
However it is quite bloated. See https://pkg.go.dev/github.com/lesismal/nbio
When writing idiomatic Go, github.com/coder/websocket will be faster and easier to use.
...@@ -5,6 +5,7 @@ package websocket ...@@ -5,6 +5,7 @@ package websocket
import ( import (
"bytes" "bytes"
"context"
"crypto/sha1" "crypto/sha1"
"encoding/base64" "encoding/base64"
"errors" "errors"
...@@ -14,10 +15,10 @@ import ( ...@@ -14,10 +15,10 @@ import (
"net/http" "net/http"
"net/textproto" "net/textproto"
"net/url" "net/url"
"path/filepath" "path"
"strings" "strings"
"nhooyr.io/websocket/internal/errd" "github.com/coder/websocket/internal/errd"
) )
// AcceptOptions represents Accept's options. // AcceptOptions represents Accept's options.
...@@ -41,8 +42,8 @@ type AcceptOptions struct { ...@@ -41,8 +42,8 @@ type AcceptOptions struct {
// One would set this field to []string{"example.com"} to authorize example.com to connect. // One would set this field to []string{"example.com"} to authorize example.com to connect.
// //
// Each pattern is matched case insensitively against the request origin host // Each pattern is matched case insensitively against the request origin host
// with filepath.Match. // with path.Match.
// See https://golang.org/pkg/path/filepath/#Match // See https://golang.org/pkg/path/#Match
// //
// Please ensure you understand the ramifications of enabling this. // Please ensure you understand the ramifications of enabling this.
// If used incorrectly your WebSocket server will be open to CSRF attacks. // If used incorrectly your WebSocket server will be open to CSRF attacks.
...@@ -62,6 +63,22 @@ type AcceptOptions struct { ...@@ -62,6 +63,22 @@ type AcceptOptions struct {
// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
// for CompressionContextTakeover. // for CompressionContextTakeover.
CompressionThreshold int CompressionThreshold int
// OnPingReceived is an optional callback invoked synchronously when a ping frame is received.
//
// The payload contains the application data of the ping frame.
// If the callback returns false, the subsequent pong frame will not be sent.
// To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
OnPingReceived func(ctx context.Context, payload []byte) bool
// OnPongReceived is an optional callback invoked synchronously when a pong frame is received.
//
// The payload contains the application data of the pong frame.
// To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
//
// Unlike OnPingReceived, this callback does not return a value because a pong frame
// is a response to a ping and does not trigger any further frame transmission.
OnPongReceived func(ctx context.Context, payload []byte)
} }
func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions { func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions {
...@@ -79,6 +96,9 @@ func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions { ...@@ -79,6 +96,9 @@ func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions {
// See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests. // See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests.
// //
// Accept will write a response to w on all errors. // Accept will write a response to w on all errors.
//
// Note that using the http.Request Context after Accept returns may lead to
// unexpected behavior (see http.Hijacker).
func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
return accept(w, r, opts) return accept(w, r, opts)
} }
...@@ -96,7 +116,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con ...@@ -96,7 +116,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
if !opts.InsecureSkipVerify { if !opts.InsecureSkipVerify {
err = authenticateOrigin(r, opts.OriginPatterns) err = authenticateOrigin(r, opts.OriginPatterns)
if err != nil { if err != nil {
if errors.Is(err, filepath.ErrBadPattern) { if errors.Is(err, path.ErrBadPattern) {
log.Printf("websocket: %v", err) log.Printf("websocket: %v", err)
err = errors.New(http.StatusText(http.StatusForbidden)) err = errors.New(http.StatusText(http.StatusForbidden))
} }
...@@ -105,7 +125,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con ...@@ -105,7 +125,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
} }
} }
hj, ok := w.(http.Hijacker) hj, ok := hijacker(w)
if !ok { if !ok {
err = errors.New("http.ResponseWriter does not implement http.Hijacker") err = errors.New("http.ResponseWriter does not implement http.Hijacker")
http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented)
...@@ -123,9 +143,9 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con ...@@ -123,9 +143,9 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
w.Header().Set("Sec-WebSocket-Protocol", subproto) w.Header().Set("Sec-WebSocket-Protocol", subproto)
} }
copts, err := acceptCompression(r, w, opts.CompressionMode) copts, ok := selectDeflate(websocketExtensions(r.Header), opts.CompressionMode)
if err != nil { if ok {
return nil, err w.Header().Set("Sec-WebSocket-Extensions", copts.String())
} }
w.WriteHeader(http.StatusSwitchingProtocols) w.WriteHeader(http.StatusSwitchingProtocols)
...@@ -153,6 +173,8 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con ...@@ -153,6 +173,8 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
client: false, client: false,
copts: copts, copts: copts,
flateThreshold: opts.CompressionThreshold, flateThreshold: opts.CompressionThreshold,
onPingReceived: opts.OnPingReceived,
onPongReceived: opts.OnPongReceived,
br: brw.Reader, br: brw.Reader,
bw: brw.Writer, bw: brw.Writer,
...@@ -185,10 +207,21 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ ...@@ -185,10 +207,21 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _
return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version")) return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version"))
} }
if r.Header.Get("Sec-WebSocket-Key") == "" { websocketSecKeys := r.Header.Values("Sec-WebSocket-Key")
if len(websocketSecKeys) == 0 {
return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key") return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key")
} }
if len(websocketSecKeys) > 1 {
return http.StatusBadRequest, errors.New("WebSocket protocol violation: multiple Sec-WebSocket-Key headers")
}
// The RFC states to remove any leading or trailing whitespace.
websocketSecKey := strings.TrimSpace(websocketSecKeys[0])
if v, err := base64.StdEncoding.DecodeString(websocketSecKey); err != nil || len(v) != 16 {
return http.StatusBadRequest, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Key %q, must be a 16 byte base64 encoded string", websocketSecKey)
}
return 0, nil return 0, nil
} }
...@@ -210,7 +243,7 @@ func authenticateOrigin(r *http.Request, originHosts []string) error { ...@@ -210,7 +243,7 @@ func authenticateOrigin(r *http.Request, originHosts []string) error {
for _, hostPattern := range originHosts { for _, hostPattern := range originHosts {
matched, err := match(hostPattern, u.Host) matched, err := match(hostPattern, u.Host)
if err != nil { if err != nil {
return fmt.Errorf("failed to parse filepath pattern %q: %w", hostPattern, err) return fmt.Errorf("failed to parse path pattern %q: %w", hostPattern, err)
} }
if matched { if matched {
return nil return nil
...@@ -223,7 +256,7 @@ func authenticateOrigin(r *http.Request, originHosts []string) error { ...@@ -223,7 +256,7 @@ func authenticateOrigin(r *http.Request, originHosts []string) error {
} }
func match(pattern, s string) (bool, error) { func match(pattern, s string) (bool, error) {
return filepath.Match(strings.ToLower(pattern), strings.ToLower(s)) return path.Match(strings.ToLower(pattern), strings.ToLower(s))
} }
func selectSubprotocol(r *http.Request, subprotocols []string) string { func selectSubprotocol(r *http.Request, subprotocols []string) string {
...@@ -238,26 +271,26 @@ func selectSubprotocol(r *http.Request, subprotocols []string) string { ...@@ -238,26 +271,26 @@ func selectSubprotocol(r *http.Request, subprotocols []string) string {
return "" return ""
} }
func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionMode) (*compressionOptions, error) { func selectDeflate(extensions []websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
if mode == CompressionDisabled { if mode == CompressionDisabled {
return nil, nil return nil, false
} }
for _, ext := range extensions {
for _, ext := range websocketExtensions(r.Header) {
switch ext.name { switch ext.name {
// We used to implement x-webkit-deflate-frame too for Safari but Safari has bugs...
// See https://github.com/nhooyr/websocket/issues/218
case "permessage-deflate": case "permessage-deflate":
return acceptDeflate(w, ext, mode) copts, ok := acceptDeflate(ext, mode)
// Disabled for now, see https://github.com/nhooyr/websocket/issues/218 if ok {
// case "x-webkit-deflate-frame": return copts, true
// return acceptWebkitDeflate(w, ext, mode) }
} }
} }
return nil, nil return nil, false
} }
func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { func acceptDeflate(ext websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
copts := mode.opts() copts := mode.opts()
for _, p := range ext.params { for _, p := range ext.params {
switch p { switch p {
case "client_no_context_takeover": case "client_no_context_takeover":
...@@ -266,55 +299,18 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi ...@@ -266,55 +299,18 @@ func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode Compressi
case "server_no_context_takeover": case "server_no_context_takeover":
copts.serverNoContextTakeover = true copts.serverNoContextTakeover = true
continue continue
} case "client_max_window_bits",
"server_max_window_bits=15":
if strings.HasPrefix(p, "client_max_window_bits") {
// We cannot adjust the read sliding window so cannot make use of this.
continue continue
} }
err := fmt.Errorf("unsupported permessage-deflate parameter: %q", p) if strings.HasPrefix(p, "client_max_window_bits=") {
http.Error(w, err.Error(), http.StatusBadRequest) // We can't adjust the deflate window, but decoding with a larger window is acceptable.
return nil, err
}
copts.setHeader(w.Header())
return copts, nil
}
func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
copts := mode.opts()
// The peer must explicitly request it.
copts.serverNoContextTakeover = false
for _, p := range ext.params {
if p == "no_context_takeover" {
copts.serverNoContextTakeover = true
continue continue
} }
return nil, false
// We explicitly fail on x-webkit-deflate-frame's max_window_bits parameter instead
// of ignoring it as the draft spec is unclear. It says the server can ignore it
// but the server has no way of signalling to the client it was ignored as the parameters
// are set one way.
// Thus us ignoring it would make the client think we understood it which would cause issues.
// See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06#section-4.1
//
// Either way, we're only implementing this for webkit which never sends the max_window_bits
// parameter so we don't need to worry about it.
err := fmt.Errorf("unsupported x-webkit-deflate-frame parameter: %q", p)
http.Error(w, err.Error(), http.StatusBadRequest)
return nil, err
}
s := "x-webkit-deflate-frame"
if copts.clientNoContextTakeover {
s += "; no_context_takeover"
} }
w.Header().Set("Sec-WebSocket-Extensions", s) return copts, true
return copts, nil
} }
func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool { func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool {
......
...@@ -10,9 +10,11 @@ import ( ...@@ -10,9 +10,11 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
"sync"
"testing" "testing"
"nhooyr.io/websocket/internal/test/assert" "github.com/coder/websocket/internal/test/assert"
"github.com/coder/websocket/internal/test/xrand"
) )
func TestAccept(t *testing.T) { func TestAccept(t *testing.T) {
...@@ -36,7 +38,7 @@ func TestAccept(t *testing.T) { ...@@ -36,7 +38,7 @@ func TestAccept(t *testing.T) {
r.Header.Set("Connection", "Upgrade") r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket") r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13") r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", "meow123") r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
r.Header.Set("Origin", "harhar.com") r.Header.Set("Origin", "harhar.com")
_, err := Accept(w, r, nil) _, err := Accept(w, r, nil)
...@@ -52,7 +54,7 @@ func TestAccept(t *testing.T) { ...@@ -52,7 +54,7 @@ func TestAccept(t *testing.T) {
r.Header.Set("Connection", "Upgrade") r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket") r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13") r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", "meow123") r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
r.Header.Set("Origin", "https://harhar.com") r.Header.Set("Origin", "https://harhar.com")
_, err := Accept(w, r, nil) _, err := Accept(w, r, nil)
...@@ -62,20 +64,50 @@ func TestAccept(t *testing.T) { ...@@ -62,20 +64,50 @@ func TestAccept(t *testing.T) {
t.Run("badCompression", func(t *testing.T) { t.Run("badCompression", func(t *testing.T) {
t.Parallel() t.Parallel()
w := mockHijacker{ newRequest := func(extensions string) *http.Request {
ResponseWriter: httptest.NewRecorder(), r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
r.Header.Set("Sec-WebSocket-Extensions", extensions)
return r
}
errHijack := errors.New("hijack error")
newResponseWriter := func() http.ResponseWriter {
return mockHijacker{
ResponseWriter: httptest.NewRecorder(),
hijack: func() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, errHijack
},
}
} }
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", "meow123")
r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; harharhar")
_, err := Accept(w, r, &AcceptOptions{ t.Run("withoutFallback", func(t *testing.T) {
CompressionMode: CompressionContextTakeover, t.Parallel()
w := newResponseWriter()
r := newRequest("permessage-deflate; harharhar")
_, err := Accept(w, r, &AcceptOptions{
CompressionMode: CompressionNoContextTakeover,
})
assert.ErrorIs(t, errHijack, err)
assert.Equal(t, "extension header", w.Header().Get("Sec-WebSocket-Extensions"), "")
})
t.Run("withFallback", func(t *testing.T) {
t.Parallel()
w := newResponseWriter()
r := newRequest("permessage-deflate; harharhar, permessage-deflate")
_, err := Accept(w, r, &AcceptOptions{
CompressionMode: CompressionNoContextTakeover,
})
assert.ErrorIs(t, errHijack, err)
assert.Equal(t, "extension header",
w.Header().Get("Sec-WebSocket-Extensions"),
CompressionNoContextTakeover.opts().String(),
)
}) })
assert.Contains(t, err, `unsupported permessage-deflate parameter`)
}) })
t.Run("requireHttpHijacker", func(t *testing.T) { t.Run("requireHttpHijacker", func(t *testing.T) {
...@@ -86,7 +118,7 @@ func TestAccept(t *testing.T) { ...@@ -86,7 +118,7 @@ func TestAccept(t *testing.T) {
r.Header.Set("Connection", "Upgrade") r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket") r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13") r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", "meow123") r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
_, err := Accept(w, r, nil) _, err := Accept(w, r, nil)
assert.Contains(t, err, `http.ResponseWriter does not implement http.Hijacker`) assert.Contains(t, err, `http.ResponseWriter does not implement http.Hijacker`)
...@@ -106,11 +138,74 @@ func TestAccept(t *testing.T) { ...@@ -106,11 +138,74 @@ func TestAccept(t *testing.T) {
r.Header.Set("Connection", "Upgrade") r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket") r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13") r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", "meow123") r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
_, err := Accept(w, r, nil) _, err := Accept(w, r, nil)
assert.Contains(t, err, `failed to hijack connection`) assert.Contains(t, err, `failed to hijack connection`)
}) })
t.Run("wrapperHijackerIsUnwrapped", func(t *testing.T) {
t.Parallel()
rr := httptest.NewRecorder()
w := mockUnwrapper{
ResponseWriter: rr,
unwrap: func() http.ResponseWriter {
return mockHijacker{
ResponseWriter: rr,
hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) {
return nil, nil, errors.New("haha")
},
}
},
}
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
_, err := Accept(w, r, nil)
assert.Contains(t, err, "failed to hijack connection")
})
t.Run("closeRace", func(t *testing.T) {
t.Parallel()
server, _ := net.Pipe()
rw := bufio.NewReadWriter(bufio.NewReader(server), bufio.NewWriter(server))
newResponseWriter := func() http.ResponseWriter {
return mockHijacker{
ResponseWriter: httptest.NewRecorder(),
hijack: func() (net.Conn, *bufio.ReadWriter, error) {
return server, rw, nil
},
}
}
w := newResponseWriter()
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
c, err := Accept(w, r, nil)
wg := &sync.WaitGroup{}
wg.Add(2)
go func() {
c.Close(StatusInternalError, "the sky is falling")
wg.Done()
}()
go func() {
c.CloseNow()
wg.Done()
}()
wg.Wait()
assert.Success(t, err)
})
} }
func Test_verifyClientHandshake(t *testing.T) { func Test_verifyClientHandshake(t *testing.T) {
...@@ -153,7 +248,15 @@ func Test_verifyClientHandshake(t *testing.T) { ...@@ -153,7 +248,15 @@ func Test_verifyClientHandshake(t *testing.T) {
}, },
}, },
{ {
name: "badWebSocketKey", name: "missingWebSocketKey",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
},
},
{
name: "emptyWebSocketKey",
h: map[string]string{ h: map[string]string{
"Connection": "Upgrade", "Connection": "Upgrade",
"Upgrade": "websocket", "Upgrade": "websocket",
...@@ -161,13 +264,43 @@ func Test_verifyClientHandshake(t *testing.T) { ...@@ -161,13 +264,43 @@ func Test_verifyClientHandshake(t *testing.T) {
"Sec-WebSocket-Key": "", "Sec-WebSocket-Key": "",
}, },
}, },
{
name: "shortWebSocketKey",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": xrand.Base64(15),
},
},
{
name: "invalidWebSocketKey",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "notbase64",
},
},
{
name: "extraWebSocketKey",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
// Kinda cheeky, but http headers are case-insensitive.
// If 2 sec keys are present, this is a failure condition.
"Sec-WebSocket-Key": xrand.Base64(16),
"sec-webSocket-key": xrand.Base64(16),
},
},
{ {
name: "badHTTPVersion", name: "badHTTPVersion",
h: map[string]string{ h: map[string]string{
"Connection": "Upgrade", "Connection": "Upgrade",
"Upgrade": "websocket", "Upgrade": "websocket",
"Sec-WebSocket-Version": "13", "Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "meow123", "Sec-WebSocket-Key": xrand.Base64(16),
}, },
http1: true, http1: true,
}, },
...@@ -177,7 +310,17 @@ func Test_verifyClientHandshake(t *testing.T) { ...@@ -177,7 +310,17 @@ func Test_verifyClientHandshake(t *testing.T) {
"Connection": "keep-alive, Upgrade", "Connection": "keep-alive, Upgrade",
"Upgrade": "websocket", "Upgrade": "websocket",
"Sec-WebSocket-Version": "13", "Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "meow123", "Sec-WebSocket-Key": xrand.Base64(16),
},
success: true,
},
{
name: "successSecKeyExtraSpace",
h: map[string]string{
"Connection": "keep-alive, Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": " " + xrand.Base64(16) + " ",
}, },
success: true, success: true,
}, },
...@@ -197,7 +340,7 @@ func Test_verifyClientHandshake(t *testing.T) { ...@@ -197,7 +340,7 @@ func Test_verifyClientHandshake(t *testing.T) {
} }
for k, v := range tc.h { for k, v := range tc.h {
r.Header.Set(k, v) r.Header.Add(k, v)
} }
_, err := verifyClientRequest(httptest.NewRecorder(), r) _, err := verifyClientRequest(httptest.NewRecorder(), r)
...@@ -344,59 +487,54 @@ func Test_authenticateOrigin(t *testing.T) { ...@@ -344,59 +487,54 @@ func Test_authenticateOrigin(t *testing.T) {
} }
} }
func Test_acceptCompression(t *testing.T) { func Test_selectDeflate(t *testing.T) {
t.Parallel() t.Parallel()
testCases := []struct { testCases := []struct {
name string name string
mode CompressionMode mode CompressionMode
reqSecWebSocketExtensions string header string
respSecWebSocketExtensions string expCopts *compressionOptions
expCopts *compressionOptions expOK bool
error bool
}{ }{
{ {
name: "disabled", name: "disabled",
mode: CompressionDisabled, mode: CompressionDisabled,
expCopts: nil, expCopts: nil,
expOK: false,
}, },
{ {
name: "noClientSupport", name: "noClientSupport",
mode: CompressionNoContextTakeover, mode: CompressionNoContextTakeover,
expCopts: nil, expCopts: nil,
expOK: false,
}, },
{ {
name: "permessage-deflate", name: "permessage-deflate",
mode: CompressionNoContextTakeover, mode: CompressionNoContextTakeover,
reqSecWebSocketExtensions: "permessage-deflate; client_max_window_bits", header: "permessage-deflate; client_max_window_bits",
respSecWebSocketExtensions: "permessage-deflate; client_no_context_takeover; server_no_context_takeover",
expCopts: &compressionOptions{ expCopts: &compressionOptions{
clientNoContextTakeover: true, clientNoContextTakeover: true,
serverNoContextTakeover: true, serverNoContextTakeover: true,
}, },
expOK: true,
}, },
{ {
name: "permessage-deflate/error", name: "permessage-deflate/unknown-parameter",
mode: CompressionNoContextTakeover, mode: CompressionNoContextTakeover,
reqSecWebSocketExtensions: "permessage-deflate; meow", header: "permessage-deflate; meow",
error: true, expOK: false,
},
{
name: "permessage-deflate/unknown-parameter",
mode: CompressionNoContextTakeover,
header: "permessage-deflate; meow, permessage-deflate; client_max_window_bits",
expCopts: &compressionOptions{
clientNoContextTakeover: true,
serverNoContextTakeover: true,
},
expOK: true,
}, },
// {
// name: "x-webkit-deflate-frame",
// mode: CompressionNoContextTakeover,
// reqSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover",
// respSecWebSocketExtensions: "x-webkit-deflate-frame; no_context_takeover",
// expCopts: &compressionOptions{
// clientNoContextTakeover: true,
// serverNoContextTakeover: true,
// },
// },
// {
// name: "x-webkit-deflate/error",
// mode: CompressionNoContextTakeover,
// reqSecWebSocketExtensions: "x-webkit-deflate-frame; max_window_bits",
// error: true,
// },
} }
for _, tc := range testCases { for _, tc := range testCases {
...@@ -404,19 +542,11 @@ func Test_acceptCompression(t *testing.T) { ...@@ -404,19 +542,11 @@ func Test_acceptCompression(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()
r := httptest.NewRequest(http.MethodGet, "/", nil) h := http.Header{}
r.Header.Set("Sec-WebSocket-Extensions", tc.reqSecWebSocketExtensions) h.Set("Sec-WebSocket-Extensions", tc.header)
copts, ok := selectDeflate(websocketExtensions(h), tc.mode)
w := httptest.NewRecorder() assert.Equal(t, "selected options", tc.expOK, ok)
copts, err := acceptCompression(r, w, tc.mode)
if tc.error {
assert.Error(t, err)
return
}
assert.Success(t, err)
assert.Equal(t, "compression options", tc.expCopts, copts) assert.Equal(t, "compression options", tc.expCopts, copts)
assert.Equal(t, "Sec-WebSocket-Extensions", tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions"))
}) })
} }
} }
...@@ -431,3 +561,14 @@ var _ http.Hijacker = mockHijacker{} ...@@ -431,3 +561,14 @@ var _ http.Hijacker = mockHijacker{}
func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return mj.hijack() return mj.hijack()
} }
type mockUnwrapper struct {
http.ResponseWriter
unwrap func() http.ResponseWriter
}
var _ rwUnwrapper = mockUnwrapper{}
func (mu mockUnwrapper) Unwrap() http.ResponseWriter {
return mu.unwrap()
}
...@@ -6,8 +6,9 @@ package websocket_test ...@@ -6,8 +6,9 @@ package websocket_test
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"io/ioutil" "io"
"net" "net"
"os" "os"
"os/exec" "os/exec"
...@@ -16,10 +17,11 @@ import ( ...@@ -16,10 +17,11 @@ import (
"testing" "testing"
"time" "time"
"nhooyr.io/websocket" "github.com/coder/websocket"
"nhooyr.io/websocket/internal/errd" "github.com/coder/websocket/internal/errd"
"nhooyr.io/websocket/internal/test/assert" "github.com/coder/websocket/internal/test/assert"
"nhooyr.io/websocket/internal/test/wstest" "github.com/coder/websocket/internal/test/wstest"
"github.com/coder/websocket/internal/util"
) )
var excludedAutobahnCases = []string{ var excludedAutobahnCases = []string{
...@@ -37,8 +39,7 @@ var autobahnCases = []string{"*"} ...@@ -37,8 +39,7 @@ var autobahnCases = []string{"*"}
// Used to run individual test cases. autobahnCases runs only those cases matched // Used to run individual test cases. autobahnCases runs only those cases matched
// and not excluded by excludedAutobahnCases. Adding cases here means excludedAutobahnCases // and not excluded by excludedAutobahnCases. Adding cases here means excludedAutobahnCases
// is niled. // is niled.
// TODO: var onlyAutobahnCases = []string{}
var forceAutobahnCases = []string{}
func TestAutobahn(t *testing.T) { func TestAutobahn(t *testing.T) {
t.Parallel() t.Parallel()
...@@ -54,10 +55,15 @@ func TestAutobahn(t *testing.T) { ...@@ -54,10 +55,15 @@ func TestAutobahn(t *testing.T) {
) )
} }
if len(onlyAutobahnCases) > 0 {
excludedAutobahnCases = []string{}
autobahnCases = onlyAutobahnCases
}
ctx, cancel := context.WithTimeout(context.Background(), time.Hour) ctx, cancel := context.WithTimeout(context.Background(), time.Hour)
defer cancel() defer cancel()
wstestURL, closeFn, err := wstestServer(ctx) wstestURL, closeFn, err := wstestServer(t, ctx)
assert.Success(t, err) assert.Success(t, err)
defer func() { defer func() {
assert.Success(t, closeFn()) assert.Success(t, closeFn())
...@@ -86,11 +92,11 @@ func TestAutobahn(t *testing.T) { ...@@ -86,11 +92,11 @@ func TestAutobahn(t *testing.T) {
} }
}) })
c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/updateReports?agent=main"), nil) c, _, err := websocket.Dial(ctx, wstestURL+"/updateReports?agent=main", nil)
assert.Success(t, err) assert.Success(t, err)
c.Close(websocket.StatusNormalClosure, "") c.Close(websocket.StatusNormalClosure, "")
checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json") checkWSTestIndex(t, "./ci/out/autobahn-report/index.json")
} }
func waitWS(ctx context.Context, url string) error { func waitWS(ctx context.Context, url string) error {
...@@ -109,9 +115,9 @@ func waitWS(ctx context.Context, url string) error { ...@@ -109,9 +115,9 @@ func waitWS(ctx context.Context, url string) error {
return ctx.Err() return ctx.Err()
} }
// TODO: Let docker pick the port and use docker port to find it. func wstestServer(tb testing.TB, ctx context.Context) (url string, closeFn func() error, err error) {
// Does mean we can't use -i but that's fine. defer errd.Wrap(&err, "failed to start autobahn wstest server")
func wstestServer(ctx context.Context) (url string, closeFn func() error, err error) {
serverAddr, err := unusedListenAddr() serverAddr, err := unusedListenAddr()
if err != nil { if err != nil {
return "", nil, err return "", nil, err
...@@ -122,7 +128,7 @@ func wstestServer(ctx context.Context) (url string, closeFn func() error, err er ...@@ -122,7 +128,7 @@ func wstestServer(ctx context.Context) (url string, closeFn func() error, err er
} }
url = "ws://" + serverAddr url = "ws://" + serverAddr
const outDir = "ci/out/wstestClientReports" const outDir = "ci/out/autobahn-report"
specFile, err := tempJSONFile(map[string]interface{}{ specFile, err := tempJSONFile(map[string]interface{}{
"url": url, "url": url,
...@@ -141,6 +147,21 @@ func wstestServer(ctx context.Context) (url string, closeFn func() error, err er ...@@ -141,6 +147,21 @@ func wstestServer(ctx context.Context) (url string, closeFn func() error, err er
} }
}() }()
dockerPull := exec.CommandContext(ctx, "docker", "pull", "crossbario/autobahn-testsuite")
dockerPull.Stdout = util.WriterFunc(func(p []byte) (int, error) {
tb.Log(string(p))
return len(p), nil
})
dockerPull.Stderr = util.WriterFunc(func(p []byte) (int, error) {
tb.Log(string(p))
return len(p), nil
})
tb.Log(dockerPull)
err = dockerPull.Run()
if err != nil {
return "", nil, fmt.Errorf("failed to pull docker image: %w", err)
}
wd, err := os.Getwd() wd, err := os.Getwd()
if err != nil { if err != nil {
return "", nil, err return "", nil, err
...@@ -158,24 +179,32 @@ func wstestServer(ctx context.Context) (url string, closeFn func() error, err er ...@@ -158,24 +179,32 @@ func wstestServer(ctx context.Context) (url string, closeFn func() error, err er
// See https://github.com/crossbario/autobahn-testsuite/blob/058db3a36b7c3a1edf68c282307c6b899ca4857f/autobahntestsuite/autobahntestsuite/wstest.py#L124 // See https://github.com/crossbario/autobahn-testsuite/blob/058db3a36b7c3a1edf68c282307c6b899ca4857f/autobahntestsuite/autobahntestsuite/wstest.py#L124
"--webport=0", "--webport=0",
) )
fmt.Println(strings.Join(args, " "))
// TODO: pull image in advance
wstest := exec.CommandContext(ctx, "docker", args...) wstest := exec.CommandContext(ctx, "docker", args...)
// TODO: log to *testing.T wstest.Stdout = util.WriterFunc(func(p []byte) (int, error) {
wstest.Stdout = os.Stdout tb.Log(string(p))
wstest.Stderr = os.Stderr return len(p), nil
})
wstest.Stderr = util.WriterFunc(func(p []byte) (int, error) {
tb.Log(string(p))
return len(p), nil
})
tb.Log(wstest)
err = wstest.Start() err = wstest.Start()
if err != nil { if err != nil {
return "", nil, fmt.Errorf("failed to start wstest: %w", err) return "", nil, fmt.Errorf("failed to start wstest: %w", err)
} }
// TODO: kill
return url, func() error { return url, func() error {
err = wstest.Process.Kill() err = wstest.Process.Kill()
if err != nil { if err != nil {
return fmt.Errorf("failed to kill wstest: %w", err) return fmt.Errorf("failed to kill wstest: %w", err)
} }
return nil err = wstest.Wait()
var ee *exec.ExitError
if errors.As(err, &ee) && ee.ExitCode() == -1 {
return nil
}
return err
}, nil }, nil
} }
...@@ -192,7 +221,7 @@ func wstestCaseCount(ctx context.Context, url string) (cases int, err error) { ...@@ -192,7 +221,7 @@ func wstestCaseCount(ctx context.Context, url string) (cases int, err error) {
if err != nil { if err != nil {
return 0, err return 0, err
} }
b, err := ioutil.ReadAll(r) b, err := io.ReadAll(r)
if err != nil { if err != nil {
return 0, err return 0, err
} }
...@@ -207,7 +236,7 @@ func wstestCaseCount(ctx context.Context, url string) (cases int, err error) { ...@@ -207,7 +236,7 @@ func wstestCaseCount(ctx context.Context, url string) (cases int, err error) {
} }
func checkWSTestIndex(t *testing.T, path string) { func checkWSTestIndex(t *testing.T, path string) {
wstestOut, err := ioutil.ReadFile(path) wstestOut, err := os.ReadFile(path)
assert.Success(t, err) assert.Success(t, err)
var indexJSON map[string]map[string]struct { var indexJSON map[string]map[string]struct {
...@@ -252,7 +281,7 @@ func unusedListenAddr() (_ string, err error) { ...@@ -252,7 +281,7 @@ func unusedListenAddr() (_ string, err error) {
} }
func tempJSONFile(v interface{}) (string, error) { func tempJSONFile(v interface{}) (string, error) {
f, err := ioutil.TempFile("", "temp.json") f, err := os.CreateTemp("", "temp.json")
if err != nil { if err != nil {
return "", fmt.Errorf("temp file: %w", err) return "", fmt.Errorf("temp file: %w", err)
} }
......
#!/bin/sh
set -eu
cd -- "$(dirname "$0")/.."
go test --run=^$ --bench=. --benchmem "$@" ./...
# For profiling add: --memprofile ci/out/prof.mem --cpuprofile ci/out/prof.cpu -o ci/out/websocket.test
(
cd ./internal/thirdparty
go test --run=^$ --bench=. --benchmem "$@" .
GOARCH=arm64 go test -c -o ../../ci/out/thirdparty-arm64.test "$@" .
if [ "$#" -eq 0 ]; then
if [ "${CI-}" ]; then
sudo apt-get update
sudo apt-get install -y qemu-user-static
ln -s /usr/bin/qemu-aarch64-static /usr/local/bin/qemu-aarch64
fi
qemu-aarch64 ../../ci/out/thirdparty-arm64.test --test.run=^$ --test.bench=Benchmark_mask --test.benchmem
fi
)
...@@ -2,17 +2,24 @@ ...@@ -2,17 +2,24 @@
set -eu set -eu
cd -- "$(dirname "$0")/.." cd -- "$(dirname "$0")/.."
X_TOOLS_VERSION=v0.31.0
go mod tidy go mod tidy
(cd ./internal/thirdparty && go mod tidy)
(cd ./internal/examples && go mod tidy)
gofmt -w -s . gofmt -w -s .
go run golang.org/x/tools/cmd/goimports@latest -w "-local=$(go list -m)" . go run golang.org/x/tools/cmd/goimports@${X_TOOLS_VERSION} -w "-local=$(go list -m)" .
npx prettier@3.0.3 \ git ls-files "*.yml" "*.md" "*.js" "*.css" "*.html" | xargs npx prettier@3.3.3 \
--write \ --check \
--log-level=warn \ --log-level=warn \
--print-width=90 \ --print-width=90 \
--no-semi \ --no-semi \
--single-quote \ --single-quote \
--arrow-parens=avoid \ --arrow-parens=avoid
$(git ls-files "*.yml" "*.md" "*.js" "*.css" "*.html")
go run golang.org/x/tools/cmd/stringer@${X_TOOLS_VERSION} -type=opcode,MessageType,StatusCode -output=stringer.go
go run golang.org/x/tools/cmd/stringer@latest -type=opcode,MessageType,StatusCode -output=stringer.go if [ "${CI-}" ]; then
git diff --exit-code
fi
...@@ -2,13 +2,35 @@ ...@@ -2,13 +2,35 @@
set -eu set -eu
cd -- "$(dirname "$0")/.." cd -- "$(dirname "$0")/.."
STATICCHECK_VERSION=v0.6.1
GOVULNCHECK_VERSION=v1.1.4
go vet ./... go vet ./...
GOOS=js GOARCH=wasm go vet ./... GOOS=js GOARCH=wasm go vet ./...
go install golang.org/x/lint/golint@latest go install honnef.co/go/tools/cmd/staticcheck@${STATICCHECK_VERSION}
golint -set_exit_status ./...
GOOS=js GOARCH=wasm golint -set_exit_status ./...
go install honnef.co/go/tools/cmd/staticcheck@latest
staticcheck ./... staticcheck ./...
GOOS=js GOARCH=wasm staticcheck ./... GOOS=js GOARCH=wasm staticcheck ./...
govulncheck() {
tmpf=$(mktemp)
if ! command govulncheck "$@" >"$tmpf" 2>&1; then
cat "$tmpf"
fi
}
go install golang.org/x/vuln/cmd/govulncheck@${GOVULNCHECK_VERSION}
govulncheck ./...
GOOS=js GOARCH=wasm govulncheck ./...
(
cd ./internal/examples
go vet ./...
staticcheck ./...
govulncheck ./...
)
(
cd ./internal/thirdparty
go vet ./...
staticcheck ./...
govulncheck ./...
)
...@@ -2,8 +2,30 @@ ...@@ -2,8 +2,30 @@
set -eu set -eu
cd -- "$(dirname "$0")/.." cd -- "$(dirname "$0")/.."
go install github.com/agnivade/wasmbrowsertest@latest (
go test --race --timeout=1h --covermode=atomic --coverprofile=ci/out/coverage.prof --coverpkg=./... "$@" ./... cd ./internal/examples
go test "$@" ./...
)
(
cd ./internal/thirdparty
go test "$@" ./...
)
(
GOARCH=arm64 go test -c -o ./ci/out/websocket-arm64.test "$@" .
if [ "$#" -eq 0 ]; then
if [ "${CI-}" ]; then
sudo apt-get update
sudo apt-get install -y qemu-user-static
ln -s /usr/bin/qemu-aarch64-static /usr/local/bin/qemu-aarch64
fi
qemu-aarch64 ./ci/out/websocket-arm64.test -test.run=TestMask
fi
)
go install github.com/agnivade/wasmbrowsertest@8be019f6c6dceae821467b4c589eb195c2b761ce
go test --race --bench=. --timeout=1h --covermode=atomic --coverprofile=ci/out/coverage.prof --coverpkg=./... "$@" ./...
sed -i.bak '/stringer\.go/d' ci/out/coverage.prof sed -i.bak '/stringer\.go/d' ci/out/coverage.prof
sed -i.bak '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof sed -i.bak '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof
sed -i.bak '/examples/d' ci/out/coverage.prof sed -i.bak '/examples/d' ci/out/coverage.prof
......
...@@ -8,10 +8,10 @@ import ( ...@@ -8,10 +8,10 @@ import (
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
"log" "net"
"time" "time"
"nhooyr.io/websocket/internal/errd" "github.com/coder/websocket/internal/errd"
) )
// StatusCode represents a WebSocket status code. // StatusCode represents a WebSocket status code.
...@@ -93,75 +93,110 @@ func CloseStatus(err error) StatusCode { ...@@ -93,75 +93,110 @@ func CloseStatus(err error) StatusCode {
// The connection can only be closed once. Additional calls to Close // The connection can only be closed once. Additional calls to Close
// are no-ops. // are no-ops.
// //
// The maximum length of reason must be 125 bytes. Avoid // The maximum length of reason must be 125 bytes. Avoid sending a dynamic reason.
// sending a dynamic reason.
// //
// Close will unblock all goroutines interacting with the connection once // Close will unblock all goroutines interacting with the connection once
// complete. // complete.
func (c *Conn) Close(code StatusCode, reason string) error { func (c *Conn) Close(code StatusCode, reason string) (err error) {
return c.closeHandshake(code, reason)
}
func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) {
defer errd.Wrap(&err, "failed to close WebSocket") defer errd.Wrap(&err, "failed to close WebSocket")
writeErr := c.writeClose(code, reason) if c.casClosing() {
closeHandshakeErr := c.waitCloseHandshake() err = c.waitGoroutines()
if err != nil {
return err
}
return net.ErrClosed
}
defer func() {
if errors.Is(err, net.ErrClosed) {
err = nil
}
}()
err = c.closeHandshake(code, reason)
if writeErr != nil { err2 := c.close()
return writeErr if err == nil && err2 != nil {
err = err2
} }
if CloseStatus(closeHandshakeErr) == -1 { err2 = c.waitGoroutines()
return closeHandshakeErr if err == nil && err2 != nil {
err = err2
} }
return nil return err
} }
var errAlreadyWroteClose = errors.New("already wrote close") // CloseNow closes the WebSocket connection without attempting a close handshake.
// Use when you do not want the overhead of the close handshake.
func (c *Conn) CloseNow() (err error) {
defer errd.Wrap(&err, "failed to immediately close WebSocket")
func (c *Conn) writeClose(code StatusCode, reason string) error { if c.casClosing() {
c.closeMu.Lock() err = c.waitGoroutines()
wroteClose := c.wroteClose if err != nil {
c.wroteClose = true return err
c.closeMu.Unlock() }
if wroteClose { return net.ErrClosed
return errAlreadyWroteClose
} }
defer func() {
if errors.Is(err, net.ErrClosed) {
err = nil
}
}()
err = c.close()
err2 := c.waitGoroutines()
if err == nil && err2 != nil {
err = err2
}
return err
}
func (c *Conn) closeHandshake(code StatusCode, reason string) error {
err := c.writeClose(code, reason)
if err != nil {
return err
}
err = c.waitCloseHandshake()
if CloseStatus(err) != code {
return err
}
return nil
}
func (c *Conn) writeClose(code StatusCode, reason string) error {
ce := CloseError{ ce := CloseError{
Code: code, Code: code,
Reason: reason, Reason: reason,
} }
var p []byte var p []byte
var marshalErr error var err error
if ce.Code != StatusNoStatusRcvd { if ce.Code != StatusNoStatusRcvd {
p, marshalErr = ce.bytes() p, err = ce.bytes()
if marshalErr != nil { if err != nil {
log.Printf("websocket: %v", marshalErr) return err
} }
} }
writeErr := c.writeControl(context.Background(), opClose, p) ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
if CloseStatus(writeErr) != -1 { defer cancel()
// Not a real error if it's due to a close frame being received.
writeErr = nil
}
// We do this after in case there was an error writing the close frame.
c.setCloseErr(fmt.Errorf("sent close frame: %w", ce))
if marshalErr != nil { err = c.writeControl(ctx, opClose, p)
return marshalErr // If the connection closed as we're writing we ignore the error as we might
// have written the close frame, the peer responded and then someone else read it
// and closed the connection.
if err != nil && !errors.Is(err, net.ErrClosed) {
return err
} }
return writeErr return nil
} }
func (c *Conn) waitCloseHandshake() error { func (c *Conn) waitCloseHandshake() error {
defer c.close(nil)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel() defer cancel()
...@@ -171,8 +206,11 @@ func (c *Conn) waitCloseHandshake() error { ...@@ -171,8 +206,11 @@ func (c *Conn) waitCloseHandshake() error {
} }
defer c.readMu.unlock() defer c.readMu.unlock()
if c.readCloseFrameErr != nil { for i := int64(0); i < c.msgReader.payloadLength; i++ {
return c.readCloseFrameErr _, err := c.br.ReadByte()
if err != nil {
return err
}
} }
for { for {
...@@ -190,6 +228,36 @@ func (c *Conn) waitCloseHandshake() error { ...@@ -190,6 +228,36 @@ func (c *Conn) waitCloseHandshake() error {
} }
} }
func (c *Conn) waitGoroutines() error {
t := time.NewTimer(time.Second * 15)
defer t.Stop()
select {
case <-c.timeoutLoopDone:
case <-t.C:
return errors.New("failed to wait for timeoutLoop goroutine to exit")
}
c.closeReadMu.Lock()
closeRead := c.closeReadCtx != nil
c.closeReadMu.Unlock()
if closeRead {
select {
case <-c.closeReadDone:
case <-t.C:
return errors.New("failed to wait for close read goroutine to exit")
}
}
select {
case <-c.closed:
case <-t.C:
return errors.New("failed to wait for connection to be closed")
}
return nil
}
func parseClosePayload(p []byte) (CloseError, error) { func parseClosePayload(p []byte) (CloseError, error) {
if len(p) == 0 { if len(p) == 0 {
return CloseError{ return CloseError{
...@@ -260,16 +328,8 @@ func (ce CloseError) bytesErr() ([]byte, error) { ...@@ -260,16 +328,8 @@ func (ce CloseError) bytesErr() ([]byte, error) {
return buf, nil return buf, nil
} }
func (c *Conn) setCloseErr(err error) { func (c *Conn) casClosing() bool {
c.closeMu.Lock() return c.closing.Swap(true)
c.setCloseErrLocked(err)
c.closeMu.Unlock()
}
func (c *Conn) setCloseErrLocked(err error) {
if c.closeErr == nil {
c.closeErr = fmt.Errorf("WebSocket closed: %w", err)
}
} }
func (c *Conn) isClosed() bool { func (c *Conn) isClosed() bool {
......
...@@ -9,7 +9,7 @@ import ( ...@@ -9,7 +9,7 @@ import (
"strings" "strings"
"testing" "testing"
"nhooyr.io/websocket/internal/test/assert" "github.com/coder/websocket/internal/test/assert"
) )
func TestCloseError(t *testing.T) { func TestCloseError(t *testing.T) {
......
...@@ -6,50 +6,47 @@ package websocket ...@@ -6,50 +6,47 @@ package websocket
import ( import (
"compress/flate" "compress/flate"
"io" "io"
"net/http"
"sync" "sync"
) )
// CompressionMode represents the modes available to the deflate extension. // CompressionMode represents the modes available to the permessage-deflate extension.
// See https://tools.ietf.org/html/rfc7692 // See https://tools.ietf.org/html/rfc7692
// //
// A compatibility layer is implemented for the older deflate-frame extension used // Works in all modern browsers except Safari which does not implement the permessage-deflate extension.
// by safari. See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06 //
// It will work the same in every way except that we cannot signal to the peer we // Compression is only used if the peer supports the mode selected.
// want to use no context takeover on our side, we can only signal that they should.
// But it is currently disabled due to Safari bugs. See https://github.com/nhooyr/websocket/issues/218
type CompressionMode int type CompressionMode int
const ( const (
// CompressionDisabled disables the deflate extension. // CompressionDisabled disables the negotiation of the permessage-deflate extension.
//
// Use this if you are using a predominantly binary protocol with very
// little duplication in between messages or CPU and memory are more
// important than bandwidth.
// //
// This is the default. // This is the default. Do not enable compression without benchmarking for your particular use case first.
CompressionDisabled CompressionMode = iota CompressionDisabled CompressionMode = iota
// CompressionContextTakeover uses a 32 kB sliding window and flate.Writer per connection. // CompressionContextTakeover compresses each message greater than 128 bytes reusing the 32 KB sliding window from
// It reusing the sliding window from previous messages. // previous messages. i.e compression context across messages is preserved.
// As most WebSocket protocols are repetitive, this can be very efficient. //
// It carries an overhead of 32 kB + 1.2 MB for every connection compared to CompressionNoContextTakeover. // As most WebSocket protocols are text based and repetitive, this compression mode can be very efficient.
//
// The memory overhead is a fixed 32 KB sliding window, a fixed 1.2 MB flate.Writer and a sync.Pool of 40 KB flate.Reader's
// that are used when reading and then returned.
// //
// Sometime in the future it will carry 65 kB overhead instead once https://github.com/golang/go/issues/36919 // Thus, it uses more memory than CompressionNoContextTakeover but compresses more efficiently.
// is fixed.
// //
// If the peer negotiates NoContextTakeover on the client or server side, it will be // If the peer does not support CompressionContextTakeover then we will fall back to CompressionNoContextTakeover.
// used instead as this is required by the RFC.
CompressionContextTakeover CompressionContextTakeover
// CompressionNoContextTakeover grabs a new flate.Reader and flate.Writer as needed // CompressionNoContextTakeover compresses each message greater than 512 bytes. Each message is compressed with
// for every message. This applies to both server and client side. // a new 1.2 MB flate.Writer pulled from a sync.Pool. Each message is read with a 40 KB flate.Reader pulled from
// a sync.Pool.
// //
// This means less efficient compression as the sliding window from previous messages // This means less efficient compression as the sliding window from previous messages will not be used but the
// will not be used but the memory overhead will be lower if the connections // memory overhead will be lower as there will be no fixed cost for the flate.Writer nor the 32 KB sliding window.
// are long lived and seldom used. // Especially if the connections are long lived and seldom written to.
// //
// The message will only be compressed if greater than 512 bytes. // Thus, it uses less memory than CompressionContextTakeover but compresses less efficiently.
//
// If the peer does not support CompressionNoContextTakeover then we will fall back to CompressionDisabled.
CompressionNoContextTakeover CompressionNoContextTakeover
) )
...@@ -65,7 +62,7 @@ type compressionOptions struct { ...@@ -65,7 +62,7 @@ type compressionOptions struct {
serverNoContextTakeover bool serverNoContextTakeover bool
} }
func (copts *compressionOptions) setHeader(h http.Header) { func (copts *compressionOptions) String() string {
s := "permessage-deflate" s := "permessage-deflate"
if copts.clientNoContextTakeover { if copts.clientNoContextTakeover {
s += "; client_no_context_takeover" s += "; client_no_context_takeover"
...@@ -73,14 +70,14 @@ func (copts *compressionOptions) setHeader(h http.Header) { ...@@ -73,14 +70,14 @@ func (copts *compressionOptions) setHeader(h http.Header) {
if copts.serverNoContextTakeover { if copts.serverNoContextTakeover {
s += "; server_no_context_takeover" s += "; server_no_context_takeover"
} }
h.Set("Sec-WebSocket-Extensions", s) return s
} }
// These bytes are required to get flate.Reader to return. // These bytes are required to get flate.Reader to return.
// They are removed when sending to avoid the overhead as // They are removed when sending to avoid the overhead as
// WebSocket framing tell's when the message has ended but then // WebSocket framing tell's when the message has ended but then
// we need to add them back otherwise flate.Reader keeps // we need to add them back otherwise flate.Reader keeps
// trying to return more bytes. // trying to read more bytes.
const deflateMessageTail = "\x00\x00\xff\xff" const deflateMessageTail = "\x00\x00\xff\xff"
type trimLastFourBytesWriter struct { type trimLastFourBytesWriter struct {
...@@ -201,23 +198,19 @@ func (sw *slidingWindow) init(n int) { ...@@ -201,23 +198,19 @@ func (sw *slidingWindow) init(n int) {
} }
p := slidingWindowPool(n) p := slidingWindowPool(n)
buf, ok := p.Get().([]byte) sw2, ok := p.Get().(*slidingWindow)
if ok { if ok {
sw.buf = buf[:0] *sw = *sw2
} else { } else {
sw.buf = make([]byte, 0, n) sw.buf = make([]byte, 0, n)
} }
} }
func (sw *slidingWindow) close() { func (sw *slidingWindow) close() {
if sw.buf == nil { sw.buf = sw.buf[:0]
return
}
swPoolMu.Lock() swPoolMu.Lock()
swPool[cap(sw.buf)].Put(sw.buf) swPool[cap(sw.buf)].Put(sw)
swPoolMu.Unlock() swPoolMu.Unlock()
sw.buf = nil
} }
func (sw *slidingWindow) write(p []byte) { func (sw *slidingWindow) write(p []byte) {
......
...@@ -4,11 +4,14 @@ ...@@ -4,11 +4,14 @@
package websocket package websocket
import ( import (
"bytes"
"compress/flate"
"io"
"strings" "strings"
"testing" "testing"
"nhooyr.io/websocket/internal/test/assert" "github.com/coder/websocket/internal/test/assert"
"nhooyr.io/websocket/internal/test/xrand" "github.com/coder/websocket/internal/test/xrand"
) )
func Test_slidingWindow(t *testing.T) { func Test_slidingWindow(t *testing.T) {
...@@ -33,3 +36,27 @@ func Test_slidingWindow(t *testing.T) { ...@@ -33,3 +36,27 @@ func Test_slidingWindow(t *testing.T) {
}) })
} }
} }
func BenchmarkFlateWriter(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
w, _ := flate.NewWriter(io.Discard, flate.BestSpeed)
// We have to write a byte to get the writer to allocate to its full extent.
w.Write([]byte{'a'})
w.Flush()
}
}
func BenchmarkFlateReader(b *testing.B) {
b.ReportAllocs()
var buf bytes.Buffer
w, _ := flate.NewWriter(&buf, flate.BestSpeed)
w.Write([]byte{'a'})
w.Flush()
for i := 0; i < b.N; i++ {
r := flate.NewReader(bytes.NewReader(buf.Bytes()))
io.ReadAll(r)
}
}