good morning!!!!

Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found
Select Git revision
  • dependabot/go_modules/internal/examples/golang.org/x/time-0.6.0
  • dependabot/go_modules/internal/examples/golang.org/x/time-0.7.0
  • dependabot/go_modules/internal/examples/internal-deps-022ca1aea3
  • dependabot/go_modules/internal/examples/internal-deps-46eeb9c117
  • dependabot/go_modules/internal/examples/internal-deps-4cadc2be3d
  • dependabot/go_modules/internal/thirdparty/github.com/gin-gonic/gin-1.10.0
  • dependabot/go_modules/internal/thirdparty/github.com/gobwas/ws-1.4.0
  • dependabot/go_modules/internal/thirdparty/github.com/gorilla/websocket-1.5.3
  • dependabot/go_modules/internal/thirdparty/github.com/lesismal/nbio-1.5.10
  • dependabot/go_modules/internal/thirdparty/github.com/lesismal/nbio-1.5.11
  • dependabot/go_modules/internal/thirdparty/github.com/lesismal/nbio-1.5.12
  • dev
  • docs
  • ethan/close-order
  • fix-license-1
  • mafredri/build-add-makefile
  • mafredri/build-update-to-go1.22
  • mafredri/chore-remove-funding
  • mafredri/chore-update-dependabot
  • mafredri/fix-ci
  • mafredri/fix-ci-script-tool-version
  • mafredri/fix-coverage
  • mafredri/fix-coverage2
  • mafredri/fix-coverage3
  • mafredri/fix-coverage4
  • mafredri/fix-coverage5
  • mafredri/r-context-docs
  • mafredri/support-http-responsecontroller
  • master
  • matifali/enable-dependabot
  • matifali/optimize-CI
  • readme-update
  • v0.1.0
  • v0.2.0
  • v1.0.0
  • v1.1.0
  • v1.1.1
  • v1.2.0
  • v1.2.1
  • v1.3.0
  • v1.3.1
  • v1.3.2
  • v1.3.3
  • v1.4.0
  • v1.5.0
  • v1.5.1
  • v1.6.0
  • v1.6.1
  • v1.6.2
  • v1.6.3
  • v1.6.4
  • v1.6.5
  • v1.7.0
  • v1.7.1
  • v1.7.2
  • v1.7.3
  • v1.7.4
  • v1.8.0
  • v1.8.1
  • v1.8.10
  • v1.8.11
  • v1.8.12
  • v1.8.13
  • v1.8.2
  • v1.8.3
  • v1.8.4
  • v1.8.5
  • v1.8.6
  • v1.8.7
  • v1.8.8
  • v1.8.9
71 results

Target

Select target project
  • github/nhooyr/websocket
  • open/websocket
2 results
Select Git revision
  • dev
  • docs
  • master
  • v0.1.0
  • v0.2.0
  • v1.0.0
  • v1.1.0
  • v1.1.1
  • v1.10.0
  • v1.2.0
  • v1.2.1
  • v1.3.0
  • v1.3.1
  • v1.3.2
  • v1.3.3
  • v1.4.0
  • v1.5.0
  • v1.5.1
  • v1.6.0
  • v1.6.1
  • v1.6.2
  • v1.6.3
  • v1.6.4
  • v1.6.5
  • v1.7.0
  • v1.7.1
  • v1.7.2
  • v1.7.3
  • v1.7.4
  • v1.8.0
  • v1.8.1
  • v1.8.2
  • v1.8.3
  • v1.8.4
  • v1.8.5
  • v1.8.6
  • v1.8.7
  • v1.9.0
  • v1.9.1
  • v1.9.2
40 results
Show changes

Commits on Source 88

github: nhooyr
version: 2
updates:
# Track in case we ever add dependencies.
- package-ecosystem: 'gomod'
directory: '/'
schedule:
interval: 'weekly'
commit-message:
prefix: 'chore'
# Keep example and test/benchmark deps up-to-date.
- package-ecosystem: 'gomod'
directories:
- '/internal/examples'
- '/internal/thirdparty'
schedule:
interval: 'monthly'
commit-message:
prefix: 'chore'
labels: []
groups:
internal-deps:
patterns:
- '*'
name: ci name: ci
on: [push, pull_request] on:
push:
branches:
- master
pull_request:
branches:
- master
concurrency: concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }} group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}
cancel-in-progress: true cancel-in-progress: true
...@@ -9,30 +15,36 @@ jobs: ...@@ -9,30 +15,36 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v4 - uses: actions/setup-go@v5
with: with:
go-version-file: ./go.mod go-version-file: ./go.mod
- run: ./ci/fmt.sh - run: make fmt
lint: lint:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- run: go version - run: go version
- uses: actions/setup-go@v4 - uses: actions/setup-go@v5
with: with:
go-version-file: ./go.mod go-version-file: ./go.mod
- run: ./ci/lint.sh - run: make lint
test: test:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Disable AppArmor
if: runner.os == 'Linux'
run: |
# Disable AppArmor for Ubuntu 23.10+.
# https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md
echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v4 - uses: actions/setup-go@v5
with: with:
go-version-file: ./go.mod go-version-file: ./go.mod
- run: ./ci/test.sh - run: make test
- uses: actions/upload-artifact@v3 - uses: actions/upload-artifact@v4
with: with:
name: coverage.html name: coverage.html
path: ./ci/out/coverage.html path: ./ci/out/coverage.html
...@@ -41,7 +53,7 @@ jobs: ...@@ -41,7 +53,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v4 - uses: actions/setup-go@v5
with: with:
go-version-file: ./go.mod go-version-file: ./go.mod
- run: ./ci/bench.sh - run: make bench
...@@ -12,19 +12,25 @@ jobs: ...@@ -12,19 +12,25 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v4 - uses: actions/setup-go@v5
with: with:
go-version-file: ./go.mod go-version-file: ./go.mod
- run: AUTOBAHN=1 ./ci/bench.sh - run: AUTOBAHN=1 make bench
test: test:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Disable AppArmor
if: runner.os == 'Linux'
run: |
# Disable AppArmor for Ubuntu 23.10+.
# https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md
echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v4 - uses: actions/setup-go@v5
with: with:
go-version-file: ./go.mod go-version-file: ./go.mod
- run: AUTOBAHN=1 ./ci/test.sh - run: AUTOBAHN=1 make test
- uses: actions/upload-artifact@v3 - uses: actions/upload-artifact@v4
with: with:
name: coverage.html name: coverage.html
path: ./ci/out/coverage.html path: ./ci/out/coverage.html
...@@ -34,21 +40,27 @@ jobs: ...@@ -34,21 +40,27 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
with: with:
ref: dev ref: dev
- uses: actions/setup-go@v4 - uses: actions/setup-go@v5
with: with:
go-version-file: ./go.mod go-version-file: ./go.mod
- run: AUTOBAHN=1 ./ci/bench.sh - run: AUTOBAHN=1 make bench
test-dev: test-dev:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- name: Disable AppArmor
if: runner.os == 'Linux'
run: |
# Disable AppArmor for Ubuntu 23.10+.
# https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md
echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns
- uses: actions/checkout@v4 - uses: actions/checkout@v4
with: with:
ref: dev ref: dev
- uses: actions/setup-go@v4 - uses: actions/setup-go@v5
with: with:
go-version-file: ./go.mod go-version-file: ./go.mod
- run: AUTOBAHN=1 ./ci/test.sh - run: AUTOBAHN=1 make test
- uses: actions/upload-artifact@v3 - uses: actions/upload-artifact@v4
with: with:
name: coverage.html name: coverage-dev.html
path: ./ci/out/coverage.html path: ./ci/out/coverage.html
name: static
on:
push:
branches: ['master']
workflow_dispatch:
# Set permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages.
permissions:
contents: read
pages: write
id-token: write
concurrency:
group: pages
cancel-in-progress: true
jobs:
deploy:
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
runs-on: ubuntu-latest
steps:
- name: Disable AppArmor
if: runner.os == 'Linux'
run: |
# Disable AppArmor for Ubuntu 23.10+.
# https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md
echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns
- name: Checkout
uses: actions/checkout@v4
- name: Setup Pages
uses: actions/configure-pages@v5
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- name: Generate coverage and badge
run: |
make test
mkdir -p ./ci/out/static
cp ./ci/out/coverage.html ./ci/out/static/coverage.html
percent=$(go tool cover -func ./ci/out/coverage.prof | tail -n1 | awk '{print $3}' | tr -d '%')
wget -O ./ci/out/static/coverage.svg "https://img.shields.io/badge/coverage-${percent}%25-success"
- name: Upload artifact
uses: actions/upload-pages-artifact@v3
with:
path: ./ci/out/static/
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4
Copyright (c) 2023 Anmol Sethi <hi@nhooyr.io> Copyright (c) 2025 Coder
Permission to use, copy, modify, and distribute this software for any Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above purpose with or without fee is hereby granted, provided that the above
......
.PHONY: all
all: fmt lint test
.PHONY: fmt
fmt:
./ci/fmt.sh
.PHONY: lint
lint:
./ci/lint.sh
.PHONY: test
test:
./ci/test.sh
.PHONY: bench
bench:
./ci/bench.sh
\ No newline at end of file
# websocket # websocket
[![Go Reference](https://pkg.go.dev/badge/nhooyr.io/websocket.svg)](https://pkg.go.dev/nhooyr.io/websocket) [![Go Reference](https://pkg.go.dev/badge/github.com/coder/websocket.svg)](https://pkg.go.dev/github.com/coder/websocket)
[![Go Coverage](https://img.shields.io/badge/coverage-91%25-success)](https://nhooyr.io/websocket/coverage.html) [![Go Coverage](https://coder.github.io/websocket/coverage.svg)](https://coder.github.io/websocket/coverage.html)
websocket is a minimal and idiomatic WebSocket library for Go. websocket is a minimal and idiomatic WebSocket library for Go.
## Install ## Install
```sh ```sh
go get nhooyr.io/websocket go get github.com/coder/websocket
``` ```
> [!NOTE]
> Coder now maintains this project as explained in [this blog post](https://coder.com/blog/websocket).
> We're grateful to [nhooyr](https://github.com/nhooyr) for authoring and maintaining this project from
> 2019 to 2024.
## Highlights ## Highlights
- Minimal and idiomatic API - Minimal and idiomatic API
- First class [context.Context](https://blog.golang.org/context) support - First class [context.Context](https://blog.golang.org/context) support
- Fully passes the WebSocket [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) - Fully passes the WebSocket [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite)
- [Zero dependencies](https://pkg.go.dev/nhooyr.io/websocket?tab=imports) - [Zero dependencies](https://pkg.go.dev/github.com/coder/websocket?tab=imports)
- JSON helpers in the [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) subpackage - JSON helpers in the [wsjson](https://pkg.go.dev/github.com/coder/websocket/wsjson) subpackage
- Zero alloc reads and writes - Zero alloc reads and writes
- Concurrent writes - Concurrent writes
- [Close handshake](https://pkg.go.dev/nhooyr.io/websocket#Conn.Close) - [Close handshake](https://pkg.go.dev/github.com/coder/websocket#Conn.Close)
- [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) wrapper - [net.Conn](https://pkg.go.dev/github.com/coder/websocket#NetConn) wrapper
- [Ping pong](https://pkg.go.dev/nhooyr.io/websocket#Conn.Ping) API - [Ping pong](https://pkg.go.dev/github.com/coder/websocket#Conn.Ping) API
- [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression - [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression
- [CloseRead](https://pkg.go.dev/nhooyr.io/websocket#Conn.CloseRead) helper for write only connections - [CloseRead](https://pkg.go.dev/github.com/coder/websocket#Conn.CloseRead) helper for write only connections
- Compile to [Wasm](https://pkg.go.dev/nhooyr.io/websocket#hdr-Wasm) - Compile to [Wasm](https://pkg.go.dev/github.com/coder/websocket#hdr-Wasm)
## Roadmap ## Roadmap
...@@ -58,10 +63,12 @@ http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { ...@@ -58,10 +63,12 @@ http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) {
} }
defer c.CloseNow() defer c.CloseNow()
ctx, cancel := context.WithTimeout(r.Context(), time.Second*10) // Set the context as needed. Use of r.Context() is not recommended
// to avoid surprising behavior (see http.Hijacker).
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel() defer cancel()
var v interface{} var v any
err = wsjson.Read(ctx, c, &v) err = wsjson.Read(ctx, c, &v)
if err != nil { if err != nil {
// ... // ...
...@@ -102,14 +109,14 @@ Advantages of [gorilla/websocket](https://github.com/gorilla/websocket): ...@@ -102,14 +109,14 @@ Advantages of [gorilla/websocket](https://github.com/gorilla/websocket):
- Mature and widely used - Mature and widely used
- [Prepared writes](https://pkg.go.dev/github.com/gorilla/websocket#PreparedMessage) - [Prepared writes](https://pkg.go.dev/github.com/gorilla/websocket#PreparedMessage)
- Configurable [buffer sizes](https://pkg.go.dev/github.com/gorilla/websocket#hdr-Buffers) - Configurable [buffer sizes](https://pkg.go.dev/github.com/gorilla/websocket#hdr-Buffers)
- No extra goroutine per connection to support cancellation with context.Context. This costs nhooyr.io/websocket 2 KB of memory per connection. - No extra goroutine per connection to support cancellation with context.Context. This costs github.com/coder/websocket 2 KB of memory per connection.
- Will be removed soon with [context.AfterFunc](https://github.com/golang/go/issues/57928). See [#411](https://github.com/nhooyr/websocket/issues/411) - Will be removed soon with [context.AfterFunc](https://github.com/golang/go/issues/57928). See [#411](https://github.com/nhooyr/websocket/issues/411)
Advantages of nhooyr.io/websocket: Advantages of github.com/coder/websocket:
- Minimal and idiomatic API - Minimal and idiomatic API
- Compare godoc of [nhooyr.io/websocket](https://pkg.go.dev/nhooyr.io/websocket) with [gorilla/websocket](https://pkg.go.dev/github.com/gorilla/websocket) side by side. - Compare godoc of [github.com/coder/websocket](https://pkg.go.dev/github.com/coder/websocket) with [gorilla/websocket](https://pkg.go.dev/github.com/gorilla/websocket) side by side.
- [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) wrapper - [net.Conn](https://pkg.go.dev/github.com/coder/websocket#NetConn) wrapper
- Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535)) - Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535))
- Full [context.Context](https://blog.golang.org/context) support - Full [context.Context](https://blog.golang.org/context) support
- Dial uses [net/http.Client](https://golang.org/pkg/net/http/#Client) - Dial uses [net/http.Client](https://golang.org/pkg/net/http/#Client)
...@@ -117,24 +124,24 @@ Advantages of nhooyr.io/websocket: ...@@ -117,24 +124,24 @@ Advantages of nhooyr.io/websocket:
- Gorilla writes directly to a net.Conn and so duplicates features of net/http.Client. - Gorilla writes directly to a net.Conn and so duplicates features of net/http.Client.
- Concurrent writes - Concurrent writes
- Close handshake ([gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448)) - Close handshake ([gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448))
- Idiomatic [ping pong](https://pkg.go.dev/nhooyr.io/websocket#Conn.Ping) API - Idiomatic [ping pong](https://pkg.go.dev/github.com/coder/websocket#Conn.Ping) API
- Gorilla requires registering a pong callback before sending a Ping - Gorilla requires registering a pong callback before sending a Ping
- Can target Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432)) - Can target Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432))
- Transparent message buffer reuse with [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) subpackage - Transparent message buffer reuse with [wsjson](https://pkg.go.dev/github.com/coder/websocket/wsjson) subpackage
- [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go - [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go
- Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/). - Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/).
Soon we'll have assembly and be 3x faster [#326](https://github.com/nhooyr/websocket/pull/326) Soon we'll have assembly and be 3x faster [#326](https://github.com/nhooyr/websocket/pull/326)
- Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support - Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support
- Gorilla only supports no context takeover mode - Gorilla only supports no context takeover mode
- [CloseRead](https://pkg.go.dev/nhooyr.io/websocket#Conn.CloseRead) helper for write only connections ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492)) - [CloseRead](https://pkg.go.dev/github.com/coder/websocket#Conn.CloseRead) helper for write only connections ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492))
#### golang.org/x/net/websocket #### golang.org/x/net/websocket
[golang.org/x/net/websocket](https://pkg.go.dev/golang.org/x/net/websocket) is deprecated. [golang.org/x/net/websocket](https://pkg.go.dev/golang.org/x/net/websocket) is deprecated.
See [golang/go/issues/18152](https://github.com/golang/go/issues/18152). See [golang/go/issues/18152](https://github.com/golang/go/issues/18152).
The [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) can help in transitioning The [net.Conn](https://pkg.go.dev/github.com/coder/websocket#NetConn) can help in transitioning
to nhooyr.io/websocket. to github.com/coder/websocket.
#### gobwas/ws #### gobwas/ws
...@@ -143,7 +150,7 @@ in an event driven style for performance. See the author's [blog post](https://m ...@@ -143,7 +150,7 @@ in an event driven style for performance. See the author's [blog post](https://m
However it is quite bloated. See https://pkg.go.dev/github.com/gobwas/ws However it is quite bloated. See https://pkg.go.dev/github.com/gobwas/ws
When writing idiomatic Go, nhooyr.io/websocket will be faster and easier to use. When writing idiomatic Go, github.com/coder/websocket will be faster and easier to use.
#### lesismal/nbio #### lesismal/nbio
...@@ -152,4 +159,4 @@ event driven for performance reasons. ...@@ -152,4 +159,4 @@ event driven for performance reasons.
However it is quite bloated. See https://pkg.go.dev/github.com/lesismal/nbio However it is quite bloated. See https://pkg.go.dev/github.com/lesismal/nbio
When writing idiomatic Go, nhooyr.io/websocket will be faster and easier to use. When writing idiomatic Go, github.com/coder/websocket will be faster and easier to use.
//go:build !js //go:build !js
// +build !js
package websocket package websocket
import ( import (
"bytes" "bytes"
"context"
"crypto/sha1" "crypto/sha1"
"encoding/base64" "encoding/base64"
"errors" "errors"
...@@ -14,10 +14,10 @@ import ( ...@@ -14,10 +14,10 @@ import (
"net/http" "net/http"
"net/textproto" "net/textproto"
"net/url" "net/url"
"path/filepath" "path"
"strings" "strings"
"nhooyr.io/websocket/internal/errd" "github.com/coder/websocket/internal/errd"
) )
// AcceptOptions represents Accept's options. // AcceptOptions represents Accept's options.
...@@ -41,8 +41,8 @@ type AcceptOptions struct { ...@@ -41,8 +41,8 @@ type AcceptOptions struct {
// One would set this field to []string{"example.com"} to authorize example.com to connect. // One would set this field to []string{"example.com"} to authorize example.com to connect.
// //
// Each pattern is matched case insensitively against the request origin host // Each pattern is matched case insensitively against the request origin host
// with filepath.Match. // with path.Match.
// See https://golang.org/pkg/path/filepath/#Match // See https://golang.org/pkg/path/#Match
// //
// Please ensure you understand the ramifications of enabling this. // Please ensure you understand the ramifications of enabling this.
// If used incorrectly your WebSocket server will be open to CSRF attacks. // If used incorrectly your WebSocket server will be open to CSRF attacks.
...@@ -62,6 +62,22 @@ type AcceptOptions struct { ...@@ -62,6 +62,22 @@ type AcceptOptions struct {
// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
// for CompressionContextTakeover. // for CompressionContextTakeover.
CompressionThreshold int CompressionThreshold int
// OnPingReceived is an optional callback invoked synchronously when a ping frame is received.
//
// The payload contains the application data of the ping frame.
// If the callback returns false, the subsequent pong frame will not be sent.
// To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
OnPingReceived func(ctx context.Context, payload []byte) bool
// OnPongReceived is an optional callback invoked synchronously when a pong frame is received.
//
// The payload contains the application data of the pong frame.
// To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
//
// Unlike OnPingReceived, this callback does not return a value because a pong frame
// is a response to a ping and does not trigger any further frame transmission.
OnPongReceived func(ctx context.Context, payload []byte)
} }
func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions { func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions {
...@@ -79,6 +95,9 @@ func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions { ...@@ -79,6 +95,9 @@ func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions {
// See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests. // See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests.
// //
// Accept will write a response to w on all errors. // Accept will write a response to w on all errors.
//
// Note that using the http.Request Context after Accept returns may lead to
// unexpected behavior (see http.Hijacker).
func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
return accept(w, r, opts) return accept(w, r, opts)
} }
...@@ -96,7 +115,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con ...@@ -96,7 +115,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
if !opts.InsecureSkipVerify { if !opts.InsecureSkipVerify {
err = authenticateOrigin(r, opts.OriginPatterns) err = authenticateOrigin(r, opts.OriginPatterns)
if err != nil { if err != nil {
if errors.Is(err, filepath.ErrBadPattern) { if errors.Is(err, path.ErrBadPattern) {
log.Printf("websocket: %v", err) log.Printf("websocket: %v", err)
err = errors.New(http.StatusText(http.StatusForbidden)) err = errors.New(http.StatusText(http.StatusForbidden))
} }
...@@ -105,7 +124,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con ...@@ -105,7 +124,7 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
} }
} }
hj, ok := w.(http.Hijacker) hj, ok := hijacker(w)
if !ok { if !ok {
err = errors.New("http.ResponseWriter does not implement http.Hijacker") err = errors.New("http.ResponseWriter does not implement http.Hijacker")
http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented)
...@@ -153,6 +172,8 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con ...@@ -153,6 +172,8 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con
client: false, client: false,
copts: copts, copts: copts,
flateThreshold: opts.CompressionThreshold, flateThreshold: opts.CompressionThreshold,
onPingReceived: opts.OnPingReceived,
onPongReceived: opts.OnPongReceived,
br: brw.Reader, br: brw.Reader,
bw: brw.Writer, bw: brw.Writer,
...@@ -221,7 +242,7 @@ func authenticateOrigin(r *http.Request, originHosts []string) error { ...@@ -221,7 +242,7 @@ func authenticateOrigin(r *http.Request, originHosts []string) error {
for _, hostPattern := range originHosts { for _, hostPattern := range originHosts {
matched, err := match(hostPattern, u.Host) matched, err := match(hostPattern, u.Host)
if err != nil { if err != nil {
return fmt.Errorf("failed to parse filepath pattern %q: %w", hostPattern, err) return fmt.Errorf("failed to parse path pattern %q: %w", hostPattern, err)
} }
if matched { if matched {
return nil return nil
...@@ -234,7 +255,7 @@ func authenticateOrigin(r *http.Request, originHosts []string) error { ...@@ -234,7 +255,7 @@ func authenticateOrigin(r *http.Request, originHosts []string) error {
} }
func match(pattern, s string) (bool, error) { func match(pattern, s string) (bool, error) {
return filepath.Match(strings.ToLower(pattern), strings.ToLower(s)) return path.Match(strings.ToLower(pattern), strings.ToLower(s))
} }
func selectSubprotocol(r *http.Request, subprotocols []string) string { func selectSubprotocol(r *http.Request, subprotocols []string) string {
......
//go:build !js //go:build !js
// +build !js
package websocket package websocket
...@@ -10,10 +9,11 @@ import ( ...@@ -10,10 +9,11 @@ import (
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
"sync"
"testing" "testing"
"nhooyr.io/websocket/internal/test/assert" "github.com/coder/websocket/internal/test/assert"
"nhooyr.io/websocket/internal/test/xrand" "github.com/coder/websocket/internal/test/xrand"
) )
func TestAccept(t *testing.T) { func TestAccept(t *testing.T) {
...@@ -142,6 +142,69 @@ func TestAccept(t *testing.T) { ...@@ -142,6 +142,69 @@ func TestAccept(t *testing.T) {
_, err := Accept(w, r, nil) _, err := Accept(w, r, nil)
assert.Contains(t, err, `failed to hijack connection`) assert.Contains(t, err, `failed to hijack connection`)
}) })
t.Run("wrapperHijackerIsUnwrapped", func(t *testing.T) {
t.Parallel()
rr := httptest.NewRecorder()
w := mockUnwrapper{
ResponseWriter: rr,
unwrap: func() http.ResponseWriter {
return mockHijacker{
ResponseWriter: rr,
hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) {
return nil, nil, errors.New("haha")
},
}
},
}
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
_, err := Accept(w, r, nil)
assert.Contains(t, err, "failed to hijack connection")
})
t.Run("closeRace", func(t *testing.T) {
t.Parallel()
server, _ := net.Pipe()
rw := bufio.NewReadWriter(bufio.NewReader(server), bufio.NewWriter(server))
newResponseWriter := func() http.ResponseWriter {
return mockHijacker{
ResponseWriter: httptest.NewRecorder(),
hijack: func() (net.Conn, *bufio.ReadWriter, error) {
return server, rw, nil
},
}
}
w := newResponseWriter()
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
c, err := Accept(w, r, nil)
wg := &sync.WaitGroup{}
wg.Add(2)
go func() {
c.Close(StatusInternalError, "the sky is falling")
wg.Done()
}()
go func() {
c.CloseNow()
wg.Done()
}()
wg.Wait()
assert.Success(t, err)
})
} }
func Test_verifyClientHandshake(t *testing.T) { func Test_verifyClientHandshake(t *testing.T) {
...@@ -497,3 +560,14 @@ var _ http.Hijacker = mockHijacker{} ...@@ -497,3 +560,14 @@ var _ http.Hijacker = mockHijacker{}
func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return mj.hijack() return mj.hijack()
} }
type mockUnwrapper struct {
http.ResponseWriter
unwrap func() http.ResponseWriter
}
var _ rwUnwrapper = mockUnwrapper{}
func (mu mockUnwrapper) Unwrap() http.ResponseWriter {
return mu.unwrap()
}
//go:build !js //go:build !js
// +build !js
package websocket_test package websocket_test
...@@ -17,11 +16,11 @@ import ( ...@@ -17,11 +16,11 @@ import (
"testing" "testing"
"time" "time"
"nhooyr.io/websocket" "github.com/coder/websocket"
"nhooyr.io/websocket/internal/errd" "github.com/coder/websocket/internal/errd"
"nhooyr.io/websocket/internal/test/assert" "github.com/coder/websocket/internal/test/assert"
"nhooyr.io/websocket/internal/test/wstest" "github.com/coder/websocket/internal/test/wstest"
"nhooyr.io/websocket/internal/util" "github.com/coder/websocket/internal/util"
) )
var excludedAutobahnCases = []string{ var excludedAutobahnCases = []string{
...@@ -92,7 +91,7 @@ func TestAutobahn(t *testing.T) { ...@@ -92,7 +91,7 @@ func TestAutobahn(t *testing.T) {
} }
}) })
c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/updateReports?agent=main"), nil) c, _, err := websocket.Dial(ctx, wstestURL+"/updateReports?agent=main", nil)
assert.Success(t, err) assert.Success(t, err)
c.Close(websocket.StatusNormalClosure, "") c.Close(websocket.StatusNormalClosure, "")
...@@ -130,7 +129,7 @@ func wstestServer(tb testing.TB, ctx context.Context) (url string, closeFn func( ...@@ -130,7 +129,7 @@ func wstestServer(tb testing.TB, ctx context.Context) (url string, closeFn func(
url = "ws://" + serverAddr url = "ws://" + serverAddr
const outDir = "ci/out/autobahn-report" const outDir = "ci/out/autobahn-report"
specFile, err := tempJSONFile(map[string]interface{}{ specFile, err := tempJSONFile(map[string]any{
"url": url, "url": url,
"outdir": outDir, "outdir": outDir,
"cases": autobahnCases, "cases": autobahnCases,
...@@ -280,7 +279,7 @@ func unusedListenAddr() (_ string, err error) { ...@@ -280,7 +279,7 @@ func unusedListenAddr() (_ string, err error) {
return l.Addr().String(), nil return l.Addr().String(), nil
} }
func tempJSONFile(v interface{}) (string, error) { func tempJSONFile(v any) (string, error) {
f, err := os.CreateTemp("", "temp.json") f, err := os.CreateTemp("", "temp.json")
if err != nil { if err != nil {
return "", fmt.Errorf("temp file: %w", err) return "", fmt.Errorf("temp file: %w", err)
......
...@@ -2,8 +2,19 @@ ...@@ -2,8 +2,19 @@
set -eu set -eu
cd -- "$(dirname "$0")/.." cd -- "$(dirname "$0")/.."
go test --run=^$ --bench=. --benchmem --memprofile ci/out/prof.mem --cpuprofile ci/out/prof.cpu -o ci/out/websocket.test "$@" . go test --run=^$ --bench=. --benchmem "$@" ./...
# For profiling add: --memprofile ci/out/prof.mem --cpuprofile ci/out/prof.cpu -o ci/out/websocket.test
( (
cd ./internal/thirdparty cd ./internal/thirdparty
go test --run=^$ --bench=. --benchmem --memprofile ../../ci/out/prof-thirdparty.mem --cpuprofile ../../ci/out/prof-thirdparty.cpu -o ../../ci/out/thirdparty.test "$@" . go test --run=^$ --bench=. --benchmem "$@" .
GOARCH=arm64 go test -c -o ../../ci/out/thirdparty-arm64.test "$@" .
if [ "$#" -eq 0 ]; then
if [ "${CI-}" ]; then
sudo apt-get update
sudo apt-get install -y qemu-user-static
ln -s /usr/bin/qemu-aarch64-static /usr/local/bin/qemu-aarch64
fi
qemu-aarch64 ../../ci/out/thirdparty-arm64.test --test.run=^$ --test.bench=Benchmark_mask --test.benchmem
fi
) )
...@@ -2,19 +2,24 @@ ...@@ -2,19 +2,24 @@
set -eu set -eu
cd -- "$(dirname "$0")/.." cd -- "$(dirname "$0")/.."
X_TOOLS_VERSION=v0.31.0
go mod tidy go mod tidy
(cd ./internal/thirdparty && go mod tidy) (cd ./internal/thirdparty && go mod tidy)
(cd ./internal/examples && go mod tidy) (cd ./internal/examples && go mod tidy)
gofmt -w -s . gofmt -w -s .
go run golang.org/x/tools/cmd/goimports@latest -w "-local=$(go list -m)" . go run golang.org/x/tools/cmd/goimports@${X_TOOLS_VERSION} -w "-local=$(go list -m)" .
npx prettier@3.0.3 \ git ls-files "*.yml" "*.md" "*.js" "*.css" "*.html" | xargs npx prettier@3.3.3 \
--write \ --check \
--log-level=warn \ --log-level=warn \
--print-width=90 \ --print-width=90 \
--no-semi \ --no-semi \
--single-quote \ --single-quote \
--arrow-parens=avoid \ --arrow-parens=avoid
$(git ls-files "*.yml" "*.md" "*.js" "*.css" "*.html")
go run golang.org/x/tools/cmd/stringer@${X_TOOLS_VERSION} -type=opcode,MessageType,StatusCode -output=stringer.go
go run golang.org/x/tools/cmd/stringer@latest -type=opcode,MessageType,StatusCode -output=stringer.go if [ "${CI-}" ]; then
git diff --exit-code
fi
...@@ -2,10 +2,13 @@ ...@@ -2,10 +2,13 @@
set -eu set -eu
cd -- "$(dirname "$0")/.." cd -- "$(dirname "$0")/.."
STATICCHECK_VERSION=v0.6.1
GOVULNCHECK_VERSION=v1.1.4
go vet ./... go vet ./...
GOOS=js GOARCH=wasm go vet ./... GOOS=js GOARCH=wasm go vet ./...
go install honnef.co/go/tools/cmd/staticcheck@latest go install honnef.co/go/tools/cmd/staticcheck@${STATICCHECK_VERSION}
staticcheck ./... staticcheck ./...
GOOS=js GOARCH=wasm staticcheck ./... GOOS=js GOARCH=wasm staticcheck ./...
...@@ -15,7 +18,7 @@ govulncheck() { ...@@ -15,7 +18,7 @@ govulncheck() {
cat "$tmpf" cat "$tmpf"
fi fi
} }
go install golang.org/x/vuln/cmd/govulncheck@latest go install golang.org/x/vuln/cmd/govulncheck@${GOVULNCHECK_VERSION}
govulncheck ./... govulncheck ./...
GOOS=js GOARCH=wasm govulncheck ./... GOOS=js GOARCH=wasm govulncheck ./...
......
...@@ -11,7 +11,20 @@ cd -- "$(dirname "$0")/.." ...@@ -11,7 +11,20 @@ cd -- "$(dirname "$0")/.."
go test "$@" ./... go test "$@" ./...
) )
go install github.com/agnivade/wasmbrowsertest@latest (
GOARCH=arm64 go test -c -o ./ci/out/websocket-arm64.test "$@" .
if [ "$#" -eq 0 ]; then
if [ "${CI-}" ]; then
sudo apt-get update
sudo apt-get install -y qemu-user-static
ln -s /usr/bin/qemu-aarch64-static /usr/local/bin/qemu-aarch64
fi
qemu-aarch64 ./ci/out/websocket-arm64.test -test.run=TestMask
fi
)
go install github.com/agnivade/wasmbrowsertest@8be019f6c6dceae821467b4c589eb195c2b761ce
go test --race --bench=. --timeout=1h --covermode=atomic --coverprofile=ci/out/coverage.prof --coverpkg=./... "$@" ./... go test --race --bench=. --timeout=1h --covermode=atomic --coverprofile=ci/out/coverage.prof --coverpkg=./... "$@" ./...
sed -i.bak '/stringer\.go/d' ci/out/coverage.prof sed -i.bak '/stringer\.go/d' ci/out/coverage.prof
sed -i.bak '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof sed -i.bak '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof
......
//go:build !js //go:build !js
// +build !js
package websocket package websocket
...@@ -11,7 +10,7 @@ import ( ...@@ -11,7 +10,7 @@ import (
"net" "net"
"time" "time"
"nhooyr.io/websocket/internal/errd" "github.com/coder/websocket/internal/errd"
) )
// StatusCode represents a WebSocket status code. // StatusCode represents a WebSocket status code.
...@@ -93,85 +92,110 @@ func CloseStatus(err error) StatusCode { ...@@ -93,85 +92,110 @@ func CloseStatus(err error) StatusCode {
// The connection can only be closed once. Additional calls to Close // The connection can only be closed once. Additional calls to Close
// are no-ops. // are no-ops.
// //
// The maximum length of reason must be 125 bytes. Avoid // The maximum length of reason must be 125 bytes. Avoid sending a dynamic reason.
// sending a dynamic reason.
// //
// Close will unblock all goroutines interacting with the connection once // Close will unblock all goroutines interacting with the connection once
// complete. // complete.
func (c *Conn) Close(code StatusCode, reason string) error { func (c *Conn) Close(code StatusCode, reason string) (err error) {
defer c.wg.Wait() defer errd.Wrap(&err, "failed to close WebSocket")
return c.closeHandshake(code, reason)
if c.casClosing() {
err = c.waitGoroutines()
if err != nil {
return err
}
return net.ErrClosed
}
defer func() {
if errors.Is(err, net.ErrClosed) {
err = nil
}
}()
err = c.closeHandshake(code, reason)
err2 := c.close()
if err == nil && err2 != nil {
err = err2
}
err2 = c.waitGoroutines()
if err == nil && err2 != nil {
err = err2
}
return err
} }
// CloseNow closes the WebSocket connection without attempting a close handshake. // CloseNow closes the WebSocket connection without attempting a close handshake.
// Use when you do not want the overhead of the close handshake. // Use when you do not want the overhead of the close handshake.
func (c *Conn) CloseNow() (err error) { func (c *Conn) CloseNow() (err error) {
defer c.wg.Wait() defer errd.Wrap(&err, "failed to immediately close WebSocket")
defer errd.Wrap(&err, "failed to close WebSocket")
if c.isClosed() { if c.casClosing() {
err = c.waitGoroutines()
if err != nil {
return err
}
return net.ErrClosed return net.ErrClosed
} }
defer func() {
c.close(nil) if errors.Is(err, net.ErrClosed) {
return c.closeErr err = nil
} }
}()
func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { err = c.close()
defer errd.Wrap(&err, "failed to close WebSocket")
writeErr := c.writeClose(code, reason)
closeHandshakeErr := c.waitCloseHandshake()
if writeErr != nil { err2 := c.waitGoroutines()
return writeErr if err == nil && err2 != nil {
err = err2
}
return err
} }
if CloseStatus(closeHandshakeErr) == -1 && !errors.Is(net.ErrClosed, closeHandshakeErr) { func (c *Conn) closeHandshake(code StatusCode, reason string) error {
return closeHandshakeErr err := c.writeClose(code, reason)
if err != nil {
return err
} }
err = c.waitCloseHandshake()
if CloseStatus(err) != code {
return err
}
return nil return nil
} }
func (c *Conn) writeClose(code StatusCode, reason string) error { func (c *Conn) writeClose(code StatusCode, reason string) error {
c.closeMu.Lock()
wroteClose := c.wroteClose
c.wroteClose = true
c.closeMu.Unlock()
if wroteClose {
return net.ErrClosed
}
ce := CloseError{ ce := CloseError{
Code: code, Code: code,
Reason: reason, Reason: reason,
} }
var p []byte var p []byte
var marshalErr error var err error
if ce.Code != StatusNoStatusRcvd { if ce.Code != StatusNoStatusRcvd {
p, marshalErr = ce.bytes() p, err = ce.bytes()
if err != nil {
return err
} }
writeErr := c.writeControl(context.Background(), opClose, p)
if CloseStatus(writeErr) != -1 {
// Not a real error if it's due to a close frame being received.
writeErr = nil
} }
// We do this after in case there was an error writing the close frame. ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
c.setCloseErr(fmt.Errorf("sent close frame: %w", ce)) defer cancel()
if marshalErr != nil { err = c.writeControl(ctx, opClose, p)
return marshalErr // If the connection closed as we're writing we ignore the error as we might
// have written the close frame, the peer responded and then someone else read it
// and closed the connection.
if err != nil && !errors.Is(err, net.ErrClosed) {
return err
} }
return writeErr return nil
} }
func (c *Conn) waitCloseHandshake() error { func (c *Conn) waitCloseHandshake() error {
defer c.close(nil)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel() defer cancel()
...@@ -181,10 +205,6 @@ func (c *Conn) waitCloseHandshake() error { ...@@ -181,10 +205,6 @@ func (c *Conn) waitCloseHandshake() error {
} }
defer c.readMu.unlock() defer c.readMu.unlock()
if c.readCloseFrameErr != nil {
return c.readCloseFrameErr
}
for i := int64(0); i < c.msgReader.payloadLength; i++ { for i := int64(0); i < c.msgReader.payloadLength; i++ {
_, err := c.br.ReadByte() _, err := c.br.ReadByte()
if err != nil { if err != nil {
...@@ -207,6 +227,36 @@ func (c *Conn) waitCloseHandshake() error { ...@@ -207,6 +227,36 @@ func (c *Conn) waitCloseHandshake() error {
} }
} }
func (c *Conn) waitGoroutines() error {
t := time.NewTimer(time.Second * 15)
defer t.Stop()
select {
case <-c.timeoutLoopDone:
case <-t.C:
return errors.New("failed to wait for timeoutLoop goroutine to exit")
}
c.closeReadMu.Lock()
closeRead := c.closeReadCtx != nil
c.closeReadMu.Unlock()
if closeRead {
select {
case <-c.closeReadDone:
case <-t.C:
return errors.New("failed to wait for close read goroutine to exit")
}
}
select {
case <-c.closed:
case <-t.C:
return errors.New("failed to wait for connection to be closed")
}
return nil
}
func parseClosePayload(p []byte) (CloseError, error) { func parseClosePayload(p []byte) (CloseError, error) {
if len(p) == 0 { if len(p) == 0 {
return CloseError{ return CloseError{
...@@ -277,16 +327,8 @@ func (ce CloseError) bytesErr() ([]byte, error) { ...@@ -277,16 +327,8 @@ func (ce CloseError) bytesErr() ([]byte, error) {
return buf, nil return buf, nil
} }
func (c *Conn) setCloseErr(err error) { func (c *Conn) casClosing() bool {
c.closeMu.Lock() return c.closing.Swap(true)
c.setCloseErrLocked(err)
c.closeMu.Unlock()
}
func (c *Conn) setCloseErrLocked(err error) {
if c.closeErr == nil && err != nil {
c.closeErr = fmt.Errorf("WebSocket closed: %w", err)
}
} }
func (c *Conn) isClosed() bool { func (c *Conn) isClosed() bool {
......
//go:build !js //go:build !js
// +build !js
package websocket package websocket
...@@ -9,7 +8,7 @@ import ( ...@@ -9,7 +8,7 @@ import (
"strings" "strings"
"testing" "testing"
"nhooyr.io/websocket/internal/test/assert" "github.com/coder/websocket/internal/test/assert"
) )
func TestCloseError(t *testing.T) { func TestCloseError(t *testing.T) {
......
//go:build !js //go:build !js
// +build !js
package websocket package websocket
...@@ -168,8 +167,10 @@ type slidingWindow struct { ...@@ -168,8 +167,10 @@ type slidingWindow struct {
buf []byte buf []byte
} }
var swPoolMu sync.RWMutex var (
var swPool = map[int]*sync.Pool{} swPoolMu sync.RWMutex
swPool = map[int]*sync.Pool{}
)
func slidingWindowPool(n int) *sync.Pool { func slidingWindowPool(n int) *sync.Pool {
swPoolMu.RLock() swPoolMu.RLock()
......
//go:build !js //go:build !js
// +build !js
package websocket package websocket
...@@ -10,8 +9,8 @@ import ( ...@@ -10,8 +9,8 @@ import (
"strings" "strings"
"testing" "testing"
"nhooyr.io/websocket/internal/test/assert" "github.com/coder/websocket/internal/test/assert"
"nhooyr.io/websocket/internal/test/xrand" "github.com/coder/websocket/internal/test/xrand"
) )
func Test_slidingWindow(t *testing.T) { func Test_slidingWindow(t *testing.T) {
...@@ -19,7 +18,7 @@ func Test_slidingWindow(t *testing.T) { ...@@ -19,7 +18,7 @@ func Test_slidingWindow(t *testing.T) {
const testCount = 99 const testCount = 99
const maxWindow = 99999 const maxWindow = 99999
for i := 0; i < testCount; i++ { for range testCount {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
t.Parallel() t.Parallel()
......
//go:build !js //go:build !js
// +build !js
package websocket package websocket
import ( import (
"bufio" "bufio"
"context" "context"
"errors"
"fmt" "fmt"
"io" "io"
"net" "net"
...@@ -55,13 +53,13 @@ type Conn struct { ...@@ -55,13 +53,13 @@ type Conn struct {
readTimeout chan context.Context readTimeout chan context.Context
writeTimeout chan context.Context writeTimeout chan context.Context
timeoutLoopDone chan struct{}
// Read state. // Read state.
readMu *mu readMu *mu
readHeaderBuf [8]byte readHeaderBuf [8]byte
readControlBuf [maxControlPayload]byte readControlBuf [maxControlPayload]byte
msgReader *msgReader msgReader *msgReader
readCloseFrameErr error
// Write state. // Write state.
msgWriter *msgWriter msgWriter *msgWriter
...@@ -70,15 +68,25 @@ type Conn struct { ...@@ -70,15 +68,25 @@ type Conn struct {
writeHeaderBuf [8]byte writeHeaderBuf [8]byte
writeHeader header writeHeader header
wg sync.WaitGroup // Close handshake state.
closeStateMu sync.RWMutex
closeReceivedErr error
closeSentErr error
// CloseRead state.
closeReadMu sync.Mutex
closeReadCtx context.Context
closeReadDone chan struct{}
closing atomic.Bool
closeMu sync.Mutex // Protects following.
closed chan struct{} closed chan struct{}
closeMu sync.Mutex
closeErr error
wroteClose bool
pingCounter int32 pingCounter atomic.Int64
activePingsMu sync.Mutex activePingsMu sync.Mutex
activePings map[string]chan<- struct{} activePings map[string]chan<- struct{}
onPingReceived func(context.Context, []byte) bool
onPongReceived func(context.Context, []byte)
} }
type connConfig struct { type connConfig struct {
...@@ -87,6 +95,8 @@ type connConfig struct { ...@@ -87,6 +95,8 @@ type connConfig struct {
client bool client bool
copts *compressionOptions copts *compressionOptions
flateThreshold int flateThreshold int
onPingReceived func(context.Context, []byte) bool
onPongReceived func(context.Context, []byte)
br *bufio.Reader br *bufio.Reader
bw *bufio.Writer bw *bufio.Writer
...@@ -105,9 +115,12 @@ func newConn(cfg connConfig) *Conn { ...@@ -105,9 +115,12 @@ func newConn(cfg connConfig) *Conn {
readTimeout: make(chan context.Context), readTimeout: make(chan context.Context),
writeTimeout: make(chan context.Context), writeTimeout: make(chan context.Context),
timeoutLoopDone: make(chan struct{}),
closed: make(chan struct{}), closed: make(chan struct{}),
activePings: make(map[string]chan<- struct{}), activePings: make(map[string]chan<- struct{}),
onPingReceived: cfg.onPingReceived,
onPongReceived: cfg.onPongReceived,
} }
c.readMu = newMu(c) c.readMu = newMu(c)
...@@ -128,14 +141,10 @@ func newConn(cfg connConfig) *Conn { ...@@ -128,14 +141,10 @@ func newConn(cfg connConfig) *Conn {
} }
runtime.SetFinalizer(c, func(c *Conn) { runtime.SetFinalizer(c, func(c *Conn) {
c.close(errors.New("connection garbage collected")) c.close()
}) })
c.wg.Add(1) go c.timeoutLoop()
go func() {
defer c.wg.Done()
c.timeoutLoop()
}()
return c return c
} }
...@@ -146,35 +155,29 @@ func (c *Conn) Subprotocol() string { ...@@ -146,35 +155,29 @@ func (c *Conn) Subprotocol() string {
return c.subprotocol return c.subprotocol
} }
func (c *Conn) close(err error) { func (c *Conn) close() error {
c.closeMu.Lock() c.closeMu.Lock()
defer c.closeMu.Unlock() defer c.closeMu.Unlock()
if c.isClosed() { if c.isClosed() {
return return net.ErrClosed
}
if err == nil {
err = c.rwc.Close()
} }
c.setCloseErrLocked(err)
close(c.closed)
runtime.SetFinalizer(c, nil) runtime.SetFinalizer(c, nil)
close(c.closed)
// Have to close after c.closed is closed to ensure any goroutine that wakes up // Have to close after c.closed is closed to ensure any goroutine that wakes up
// from the connection being closed also sees that c.closed is closed and returns // from the connection being closed also sees that c.closed is closed and returns
// closeErr. // closeErr.
c.rwc.Close() err := c.rwc.Close()
// With the close of rwc, these become safe to close.
c.wg.Add(1)
go func() {
defer c.wg.Done()
c.msgWriter.close() c.msgWriter.close()
c.msgReader.close() c.msgReader.close()
}() return err
} }
func (c *Conn) timeoutLoop() { func (c *Conn) timeoutLoop() {
defer close(c.timeoutLoopDone)
readCtx := context.Background() readCtx := context.Background()
writeCtx := context.Background() writeCtx := context.Background()
...@@ -187,14 +190,10 @@ func (c *Conn) timeoutLoop() { ...@@ -187,14 +190,10 @@ func (c *Conn) timeoutLoop() {
case readCtx = <-c.readTimeout: case readCtx = <-c.readTimeout:
case <-readCtx.Done(): case <-readCtx.Done():
c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err())) c.close()
c.wg.Add(1) return
go func() {
defer c.wg.Done()
c.writeError(StatusPolicyViolation, errors.New("read timed out"))
}()
case <-writeCtx.Done(): case <-writeCtx.Done():
c.close(fmt.Errorf("write timed out: %w", writeCtx.Err())) c.close()
return return
} }
} }
...@@ -212,9 +211,9 @@ func (c *Conn) flate() bool { ...@@ -212,9 +211,9 @@ func (c *Conn) flate() bool {
// //
// TCP Keepalives should suffice for most use cases. // TCP Keepalives should suffice for most use cases.
func (c *Conn) Ping(ctx context.Context) error { func (c *Conn) Ping(ctx context.Context) error {
p := atomic.AddInt32(&c.pingCounter, 1) p := c.pingCounter.Add(1)
err := c.ping(ctx, strconv.Itoa(int(p))) err := c.ping(ctx, strconv.FormatInt(p, 10))
if err != nil { if err != nil {
return fmt.Errorf("failed to ping: %w", err) return fmt.Errorf("failed to ping: %w", err)
} }
...@@ -243,9 +242,7 @@ func (c *Conn) ping(ctx context.Context, p string) error { ...@@ -243,9 +242,7 @@ func (c *Conn) ping(ctx context.Context, p string) error {
case <-c.closed: case <-c.closed:
return net.ErrClosed return net.ErrClosed
case <-ctx.Done(): case <-ctx.Done():
err := fmt.Errorf("failed to wait for pong: %w", ctx.Err()) return fmt.Errorf("failed to wait for pong: %w", ctx.Err())
c.close(err)
return err
case <-pong: case <-pong:
return nil return nil
} }
...@@ -281,9 +278,7 @@ func (m *mu) lock(ctx context.Context) error { ...@@ -281,9 +278,7 @@ func (m *mu) lock(ctx context.Context) error {
case <-m.c.closed: case <-m.c.closed:
return net.ErrClosed return net.ErrClosed
case <-ctx.Done(): case <-ctx.Done():
err := fmt.Errorf("failed to acquire lock: %w", ctx.Err()) return fmt.Errorf("failed to acquire lock: %w", ctx.Err())
m.c.close(err)
return err
case m.ch <- struct{}{}: case m.ch <- struct{}{}:
// To make sure the connection is certainly alive. // To make sure the connection is certainly alive.
// As it's possible the send on m.ch was selected // As it's possible the send on m.ch was selected
......