good morning!!!!

Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • github/nhooyr/websocket
  • open/websocket
2 results
Show changes
Commits on Source (758)
Showing with 1584 additions and 277 deletions
version: 2
updates:
# Track in case we ever add dependencies.
- package-ecosystem: 'gomod'
directory: '/'
schedule:
interval: 'weekly'
commit-message:
prefix: 'chore'
# Keep example and test/benchmark deps up-to-date.
- package-ecosystem: 'gomod'
directories:
- '/internal/examples'
- '/internal/thirdparty'
schedule:
interval: 'monthly'
commit-message:
prefix: 'chore'
labels: []
groups:
internal-deps:
patterns:
- '*'
FROM golang:1.12
LABEL "com.github.actions.name"="fmt"
LABEL "com.github.actions.description"="fmt"
LABEL "com.github.actions.icon"="code"
LABEL "com.github.actions.color"="purple"
COPY entrypoint.sh /entrypoint.sh
CMD ["/entrypoint.sh"]
#!/usr/bin/env bash
source .github/lib.sh || exit 1
gen() {
# Unfortunately, this is the only way to ensure go.mod and go.sum are correct.
# See https://github.com/golang/go/issues/27005
go list ./... > /dev/null
go mod tidy
go install golang.org/x/tools/cmd/stringer
go generate ./...
}
fmt() {
gofmt -w -s .
go run go.coder.com/go-tools/cmd/goimports -w "-local=$(go list -m)" .
go run mvdan.cc/sh/cmd/shfmt -w -s -sr .
}
gen
fmt
if [[ $CI && $(unstaged_files) != "" ]]; then
set +x
echo
echo "files either need generation or are formatted incorrectly"
echo "please run:"
echo "./test.sh"
echo
git status
exit 1
fi
#!/usr/bin/env bash
set -euxo pipefail
export GO111MODULE=on
export PAGER=cat
# shellcheck disable=SC2034
# CI is used by the scripts that source this file.
export CI=${GITHUB_ACTION-}
if [[ $CI ]]; then
export GOFLAGS=-mod=readonly
fi
unstaged_files() {
git ls-files --other --modified --exclude-standard
}
FROM codercom/playcicache
LABEL "com.github.actions.name"="lint"
LABEL "com.github.actions.description"="lint"
LABEL "com.github.actions.icon"="code"
LABEL "com.github.actions.color"="purple"
COPY entrypoint.sh /entrypoint.sh
CMD ["/entrypoint.sh"]
#!/usr/bin/env bash
source .github/lib.sh || exit 1
(
shopt -s globstar nullglob dotglob
shellcheck ./**/*.sh
)
go vet -composites=false ./...
go run golang.org/x/lint/golint -set_exit_status ./...
workflow "main" {
on = "push"
resolves = ["fmt", "lint", "test"]
}
action "lint" {
uses = "./.github/lint"
}
action "fmt" {
uses = "./.github/fmt"
}
action "test" {
uses = "./.github/test"
secrets = ["CODECOV_TOKEN"]
}
FROM golang:1.12
LABEL "com.github.actions.name"="test"
LABEL "com.github.actions.description"="test"
LABEL "com.github.actions.icon"="code"
LABEL "com.github.actions.color"="purple"
RUN apt update && apt install -y shellcheck
COPY entrypoint.sh /entrypoint.sh
CMD ["/entrypoint.sh"]
#!/usr/bin/env bash
source .github/lib.sh || exit 1
COVERAGE_PROFILE=$(mktemp)
go test -race -v "-coverprofile=${COVERAGE_PROFILE}" -vet=off ./...
go tool cover "-func=${COVERAGE_PROFILE}"
if [[ $CI ]]; then
bash <(curl -s https://codecov.io/bash) -f "$COVERAGE_PROFILE"
else
go tool cover "-html=${COVERAGE_PROFILE}" -o=coverage.html
set +x
echo
echo "please open coverage.html to see detailed test coverage stats"
fi
name: ci
on:
push:
branches:
- master
pull_request:
branches:
- master
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}
cancel-in-progress: true
jobs:
fmt:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- run: make fmt
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- run: go version
- uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- run: make lint
test:
runs-on: ubuntu-latest
steps:
- name: Disable AppArmor
if: runner.os == 'Linux'
run: |
# Disable AppArmor for Ubuntu 23.10+.
# https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md
echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- run: make test
- uses: actions/upload-artifact@v4
with:
name: coverage.html
path: ./ci/out/coverage.html
bench:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- run: make bench
name: daily
on:
workflow_dispatch:
schedule:
- cron: '42 0 * * *' # daily at 00:42
concurrency:
group: ${{ github.workflow }}
cancel-in-progress: true
jobs:
bench:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- run: AUTOBAHN=1 make bench
test:
runs-on: ubuntu-latest
steps:
- name: Disable AppArmor
if: runner.os == 'Linux'
run: |
# Disable AppArmor for Ubuntu 23.10+.
# https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md
echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- run: AUTOBAHN=1 make test
- uses: actions/upload-artifact@v4
with:
name: coverage.html
path: ./ci/out/coverage.html
bench-dev:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
ref: dev
- uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- run: AUTOBAHN=1 make bench
test-dev:
runs-on: ubuntu-latest
steps:
- name: Disable AppArmor
if: runner.os == 'Linux'
run: |
# Disable AppArmor for Ubuntu 23.10+.
# https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md
echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns
- uses: actions/checkout@v4
with:
ref: dev
- uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- run: AUTOBAHN=1 make test
- uses: actions/upload-artifact@v4
with:
name: coverage-dev.html
path: ./ci/out/coverage.html
name: static
on:
push:
branches: ['master']
workflow_dispatch:
# Set permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages.
permissions:
contents: read
pages: write
id-token: write
concurrency:
group: pages
cancel-in-progress: true
jobs:
deploy:
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
runs-on: ubuntu-latest
steps:
- name: Disable AppArmor
if: runner.os == 'Linux'
run: |
# Disable AppArmor for Ubuntu 23.10+.
# https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md
echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns
- name: Checkout
uses: actions/checkout@v4
- name: Setup Pages
uses: actions/configure-pages@v5
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- name: Generate coverage and badge
run: |
make test
mkdir -p ./ci/out/static
cp ./ci/out/coverage.html ./ci/out/static/coverage.html
percent=$(go tool cover -func ./ci/out/coverage.prof | tail -n1 | awk '{print $3}' | tr -d '%')
wget -O ./ci/out/static/coverage.svg "https://img.shields.io/badge/coverage-${percent}%25-success"
- name: Upload artifact
uses: actions/upload-pages-artifact@v3
with:
path: ./ci/out/static/
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4
coverage.html
MIT License
Copyright (c) 2018 Anmol Sethi
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Copyright (c) 2025 Coder
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
.PHONY: all
all: fmt lint test
.PHONY: fmt
fmt:
./ci/fmt.sh
.PHONY: lint
lint:
./ci/lint.sh
.PHONY: test
test:
./ci/test.sh
.PHONY: bench
bench:
./ci/bench.sh
\ No newline at end of file
# websocket
[![GoDoc](https://godoc.org/nhooyr.io/websocket?status.svg)](https://godoc.org/nhooyr.io/websocket)
[![Codecov](https://img.shields.io/codecov/c/github/nhooyr/websocket.svg)](https://codecov.io/gh/nhooyr/websocket)
[![GitHub release](https://img.shields.io/github/release/nhooyr/websocket.svg)](https://github.com/nhooyr/websocket/releases)
[![Go Reference](https://pkg.go.dev/badge/github.com/coder/websocket.svg)](https://pkg.go.dev/github.com/coder/websocket)
[![Go Coverage](https://coder.github.io/websocket/coverage.svg)](https://coder.github.io/websocket/coverage.html)
websocket is a minimal and idiomatic WebSocket library for Go.
This library is in heavy development.
## Install
```bash
go get nhooyr.io/websocket@master
```sh
go get github.com/coder/websocket
```
## Features
> [!NOTE]
> Coder now maintains this project as explained in [this blog post](https://coder.com/blog/websocket).
> We're grateful to [nhooyr](https://github.com/nhooyr) for authoring and maintaining this project from
> 2019 to 2024.
- Full support of the WebSocket protocol
- Simple to use because of the minimal API
- Uses the context package for cancellation
- Uses net/http's Client to do WebSocket dials
- Compression of text frames larger than 1024 bytes by default
- Highly optimized
- API will transparently work with WebSockets over HTTP/2
- WASM support
## Highlights
## Example
- Minimal and idiomatic API
- First class [context.Context](https://blog.golang.org/context) support
- Fully passes the WebSocket [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite)
- [Zero dependencies](https://pkg.go.dev/github.com/coder/websocket?tab=imports)
- JSON helpers in the [wsjson](https://pkg.go.dev/github.com/coder/websocket/wsjson) subpackage
- Zero alloc reads and writes
- Concurrent writes
- [Close handshake](https://pkg.go.dev/github.com/coder/websocket#Conn.Close)
- [net.Conn](https://pkg.go.dev/github.com/coder/websocket#NetConn) wrapper
- [Ping pong](https://pkg.go.dev/github.com/coder/websocket#Conn.Ping) API
- [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression
- [CloseRead](https://pkg.go.dev/github.com/coder/websocket#Conn.CloseRead) helper for write only connections
- Compile to [Wasm](https://pkg.go.dev/github.com/coder/websocket#hdr-Wasm)
### Server
## Roadmap
```go
func main() {
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := websocket.Accept(w, r,
websocket.AcceptSubprotocols("test"),
)
if err != nil {
log.Printf("server handshake failed: %v", err)
return
}
defer c.Close(websocket.StatusInternalError, "")
ctx, cancel := context.WithTimeout(r.Context(), time.Second*10)
defer cancel()
v := map[string]interface{}{
"my_field": "foo",
}
err = websocket.WriteJSON(ctx, c, v)
if err != nil {
log.Printf("failed to write json: %v", err)
return
}
log.Printf("wrote %v", v)
c.Close(websocket.StatusNormalClosure, "")
})
err := http.ListenAndServe("localhost:8080", fn)
if err != nil {
log.Fatalf("failed to listen and serve: %v", err)
}
}
```
See GitHub issues for minor issues but the major future enhancements are:
For a production quality example that shows off the low level API, see the [echo example](https://godoc.org/nhooyr.io/websocket#ex-Accept--Echo).
- [ ] Perfect examples [#217](https://github.com/nhooyr/websocket/issues/217)
- [ ] wstest.Pipe for in memory testing [#340](https://github.com/nhooyr/websocket/issues/340)
- [ ] Ping pong heartbeat helper [#267](https://github.com/nhooyr/websocket/issues/267)
- [ ] Ping pong instrumentation callbacks [#246](https://github.com/nhooyr/websocket/issues/246)
- [ ] Graceful shutdown helpers [#209](https://github.com/nhooyr/websocket/issues/209)
- [ ] Assembly for WebSocket masking [#16](https://github.com/nhooyr/websocket/issues/16)
- WIP at [#326](https://github.com/nhooyr/websocket/pull/326), about 3x faster
- [ ] HTTP/2 [#4](https://github.com/nhooyr/websocket/issues/4)
- [ ] The holy grail [#402](https://github.com/nhooyr/websocket/issues/402)
### Client
## Examples
```go
func main() {
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel()
For a production quality example that demonstrates the complete API, see the
[echo example](./internal/examples/echo).
For a full stack example, see the [chat example](./internal/examples/chat).
c, _, err := websocket.Dial(ctx, "ws://localhost:8080",
websocket.DialSubprotocols("test"),
)
### Server
```go
http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) {
c, err := websocket.Accept(w, r, nil)
if err != nil {
log.Fatalf("failed to ws dial: %v", err)
// ...
}
defer c.Close(websocket.StatusInternalError, "")
defer c.CloseNow()
// Set the context as needed. Use of r.Context() is not recommended
// to avoid surprising behavior (see http.Hijacker).
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
var v interface{}
err = websocket.ReadJSON(ctx, c, v)
err = wsjson.Read(ctx, c, &v)
if err != nil {
log.Fatalf("failed to read json: %v", err)
// ...
}
log.Printf("received %v", v)
log.Printf("received: %v", v)
c.Close(websocket.StatusNormalClosure, "")
}
})
```
See [example_test.go](example_test.go) for more examples.
### Client
## Design considerations
```go
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
- Minimal API is easier to maintain and for others to learn
- Context based cancellation is more ergonomic and robust than setting deadlines
- No pings or pongs because TCP keep alives work fine for HTTP/1.1 and they do not make
sense with HTTP/2
- net.Conn is never exposed as WebSocket's over HTTP/2 will not have a net.Conn.
- Functional options make the API very clean and easy to extend
- Compression is very useful for JSON payloads
- Protobuf and JSON helpers make code terse
- Using net/http's Client for dialing means we do not have to reinvent dialing hooks
and configurations. Just pass in a custom net/http client if you want custom dialing.
c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil)
if err != nil {
// ...
}
defer c.CloseNow()
err = wsjson.Write(ctx, c, "hi")
if err != nil {
// ...
}
c.Close(websocket.StatusNormalClosure, "")
```
## Comparison
### gorilla/websocket
https://github.com/gorilla/websocket
Advantages of [gorilla/websocket](https://github.com/gorilla/websocket):
- Mature and widely used
- [Prepared writes](https://pkg.go.dev/github.com/gorilla/websocket#PreparedMessage)
- Configurable [buffer sizes](https://pkg.go.dev/github.com/gorilla/websocket#hdr-Buffers)
- No extra goroutine per connection to support cancellation with context.Context. This costs github.com/coder/websocket 2 KB of memory per connection.
- Will be removed soon with [context.AfterFunc](https://github.com/golang/go/issues/57928). See [#411](https://github.com/nhooyr/websocket/issues/411)
Advantages of github.com/coder/websocket:
- Minimal and idiomatic API
- Compare godoc of [github.com/coder/websocket](https://pkg.go.dev/github.com/coder/websocket) with [gorilla/websocket](https://pkg.go.dev/github.com/gorilla/websocket) side by side.
- [net.Conn](https://pkg.go.dev/github.com/coder/websocket#NetConn) wrapper
- Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535))
- Full [context.Context](https://blog.golang.org/context) support
- Dial uses [net/http.Client](https://golang.org/pkg/net/http/#Client)
- Will enable easy HTTP/2 support in the future
- Gorilla writes directly to a net.Conn and so duplicates features of net/http.Client.
- Concurrent writes
- Close handshake ([gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448))
- Idiomatic [ping pong](https://pkg.go.dev/github.com/coder/websocket#Conn.Ping) API
- Gorilla requires registering a pong callback before sending a Ping
- Can target Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432))
- Transparent message buffer reuse with [wsjson](https://pkg.go.dev/github.com/coder/websocket/wsjson) subpackage
- [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go
- Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/).
Soon we'll have assembly and be 3x faster [#326](https://github.com/nhooyr/websocket/pull/326)
- Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support
- Gorilla only supports no context takeover mode
- [CloseRead](https://pkg.go.dev/github.com/coder/websocket#Conn.CloseRead) helper for write only connections ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492))
This package is the community standard but it is very old and over timennn
has accumulated cruft. There are many ways to do the same thing and the API
overall is just not very clear. Just compare the godoc of
[nhooyr/websocket](godoc.org/github.com/nhooyr/websocket) side by side with
[gorilla/websocket](godoc.org/github.com/gorilla/websocket).
#### golang.org/x/net/websocket
The API for nhooyr/websocket has been designed such that there is only one way to do things
and with HTTP/2 in mind which makes using it correctly and safely much easier.
[golang.org/x/net/websocket](https://pkg.go.dev/golang.org/x/net/websocket) is deprecated.
See [golang/go/issues/18152](https://github.com/golang/go/issues/18152).
### x/net/websocket
The [net.Conn](https://pkg.go.dev/github.com/coder/websocket#NetConn) can help in transitioning
to github.com/coder/websocket.
https://godoc.org/golang.org/x/net/websocket
#### gobwas/ws
Unmaintained and the API does not reflect WebSocket semantics. Should never be used.
[gobwas/ws](https://github.com/gobwas/ws) has an extremely flexible API that allows it to be used
in an event driven style for performance. See the author's [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb).
See https://github.com/golang/go/issues/18152
However it is quite bloated. See https://pkg.go.dev/github.com/gobwas/ws
### gobwas/ws
When writing idiomatic Go, github.com/coder/websocket will be faster and easier to use.
https://github.com/gobwas/ws
#### lesismal/nbio
This library has an extremely flexible API but that comes at the cost of usability
and clarity. Its just not clear how to do things in a safe manner.
[lesismal/nbio](https://github.com/lesismal/nbio) is similar to gobwas/ws in that the API is
event driven for performance reasons.
This library is fantastic in terms of performance though. The author put in significant
effort to ensure its speed and I have tried to apply as many of its teachings as
I could into nhooyr/websocket.
However it is quite bloated. See https://pkg.go.dev/github.com/lesismal/nbio
If you want a library that gives you absolute control over everything, this is the library,
but for most users, the API provided by nhooyr/websocket will definitely fit better as it will
be just as performant but much easier to use.
When writing idiomatic Go, github.com/coder/websocket will be faster and easier to use.
//go:build !js
// +build !js
package websocket
import (
"bytes"
"context"
"crypto/sha1"
"encoding/base64"
"errors"
"fmt"
"io"
"log"
"net/http"
"net/textproto"
"net/url"
"path"
"strings"
"github.com/coder/websocket/internal/errd"
)
// AcceptOption is an option that can be passed to Accept.
type AcceptOption interface {
acceptOption()
fmt.Stringer
// AcceptOptions represents Accept's options.
type AcceptOptions struct {
// Subprotocols lists the WebSocket subprotocols that Accept will negotiate with the client.
// The empty subprotocol will always be negotiated as per RFC 6455. If you would like to
// reject it, close the connection when c.Subprotocol() == "".
Subprotocols []string
// InsecureSkipVerify is used to disable Accept's origin verification behaviour.
//
// You probably want to use OriginPatterns instead.
InsecureSkipVerify bool
// OriginPatterns lists the host patterns for authorized origins.
// The request host is always authorized.
// Use this to enable cross origin WebSockets.
//
// i.e javascript running on example.com wants to access a WebSocket server at chat.example.com.
// In such a case, example.com is the origin and chat.example.com is the request host.
// One would set this field to []string{"example.com"} to authorize example.com to connect.
//
// Each pattern is matched case insensitively against the request origin host
// with path.Match.
// See https://golang.org/pkg/path/#Match
//
// Please ensure you understand the ramifications of enabling this.
// If used incorrectly your WebSocket server will be open to CSRF attacks.
//
// Do not use * as a pattern to allow any origin, prefer to use InsecureSkipVerify instead
// to bring attention to the danger of such a setting.
OriginPatterns []string
// CompressionMode controls the compression mode.
// Defaults to CompressionDisabled.
//
// See docs on CompressionMode for details.
CompressionMode CompressionMode
// CompressionThreshold controls the minimum size of a message before compression is applied.
//
// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
// for CompressionContextTakeover.
CompressionThreshold int
// OnPingReceived is an optional callback invoked synchronously when a ping frame is received.
//
// The payload contains the application data of the ping frame.
// If the callback returns false, the subsequent pong frame will not be sent.
// To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
OnPingReceived func(ctx context.Context, payload []byte) bool
// OnPongReceived is an optional callback invoked synchronously when a pong frame is received.
//
// The payload contains the application data of the pong frame.
// To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
//
// Unlike OnPingReceived, this callback does not return a value because a pong frame
// is a response to a ping and does not trigger any further frame transmission.
OnPongReceived func(ctx context.Context, payload []byte)
}
// AcceptSubprotocols list the subprotocols that Accept will negotiate with a client.
// The first protocol that a client supports will be negotiated.
// Pass "" as a subprotocol if you would like to allow the default protocol.
func AcceptSubprotocols(subprotocols ...string) AcceptOption {
panic("TODO")
func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions {
var o AcceptOptions
if opts != nil {
o = *opts
}
return &o
}
// AcceptOrigins lists the origins that Accept will accept.
// Accept will always accept r.Host as the origin so you do not need to
// specify that with this option.
// Accept accepts a WebSocket handshake from a client and upgrades the
// the connection to a WebSocket.
//
// Accept will not allow cross origin requests by default.
// See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests.
//
// Use this option with caution to avoid exposing your WebSocket
// server to a CSRF attack.
// See https://stackoverflow.com/a/37837709/4283659
// You can use a * to specify wildcards.
func AcceptOrigins(origins ...string) AcceptOption {
panic("TODO")
// Accept will write a response to w on all errors.
//
// Note that using the http.Request Context after Accept returns may lead to
// unexpected behavior (see http.Hijacker).
func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
return accept(w, r, opts)
}
// Accept accepts a WebSocket handshake from a client and upgrades the
// the connection to WebSocket.
// Accept will reject the handshake if the Origin is not the same as the Host unless
// InsecureAcceptOrigin is passed.
// Accept uses w to write the handshake response so the timeouts on the http.Server apply.
func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn, error) {
panic("TODO")
func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) {
defer errd.Wrap(&err, "failed to accept WebSocket connection")
errCode, err := verifyClientRequest(w, r)
if err != nil {
http.Error(w, err.Error(), errCode)
return nil, err
}
opts = opts.cloneWithDefaults()
if !opts.InsecureSkipVerify {
err = authenticateOrigin(r, opts.OriginPatterns)
if err != nil {
if errors.Is(err, path.ErrBadPattern) {
log.Printf("websocket: %v", err)
err = errors.New(http.StatusText(http.StatusForbidden))
}
http.Error(w, err.Error(), http.StatusForbidden)
return nil, err
}
}
hj, ok := hijacker(w)
if !ok {
err = errors.New("http.ResponseWriter does not implement http.Hijacker")
http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented)
return nil, err
}
w.Header().Set("Upgrade", "websocket")
w.Header().Set("Connection", "Upgrade")
key := r.Header.Get("Sec-WebSocket-Key")
w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
subproto := selectSubprotocol(r, opts.Subprotocols)
if subproto != "" {
w.Header().Set("Sec-WebSocket-Protocol", subproto)
}
copts, ok := selectDeflate(websocketExtensions(r.Header), opts.CompressionMode)
if ok {
w.Header().Set("Sec-WebSocket-Extensions", copts.String())
}
w.WriteHeader(http.StatusSwitchingProtocols)
// See https://github.com/nhooyr/websocket/issues/166
if ginWriter, ok := w.(interface {
WriteHeaderNow()
}); ok {
ginWriter.WriteHeaderNow()
}
netConn, brw, err := hj.Hijack()
if err != nil {
err = fmt.Errorf("failed to hijack connection: %w", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return nil, err
}
// https://github.com/golang/go/issues/32314
b, _ := brw.Reader.Peek(brw.Reader.Buffered())
brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn))
return newConn(connConfig{
subprotocol: w.Header().Get("Sec-WebSocket-Protocol"),
rwc: netConn,
client: false,
copts: copts,
flateThreshold: opts.CompressionThreshold,
onPingReceived: opts.OnPingReceived,
onPongReceived: opts.OnPongReceived,
br: brw.Reader,
bw: brw.Writer,
}), nil
}
func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) {
if !r.ProtoAtLeast(1, 1) {
return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto)
}
if !headerContainsTokenIgnoreCase(r.Header, "Connection", "Upgrade") {
w.Header().Set("Connection", "Upgrade")
w.Header().Set("Upgrade", "websocket")
return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection"))
}
if !headerContainsTokenIgnoreCase(r.Header, "Upgrade", "websocket") {
w.Header().Set("Connection", "Upgrade")
w.Header().Set("Upgrade", "websocket")
return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade"))
}
if r.Method != "GET" {
return http.StatusMethodNotAllowed, fmt.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method)
}
if r.Header.Get("Sec-WebSocket-Version") != "13" {
w.Header().Set("Sec-WebSocket-Version", "13")
return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version"))
}
websocketSecKeys := r.Header.Values("Sec-WebSocket-Key")
if len(websocketSecKeys) == 0 {
return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key")
}
if len(websocketSecKeys) > 1 {
return http.StatusBadRequest, errors.New("WebSocket protocol violation: multiple Sec-WebSocket-Key headers")
}
// The RFC states to remove any leading or trailing whitespace.
websocketSecKey := strings.TrimSpace(websocketSecKeys[0])
if v, err := base64.StdEncoding.DecodeString(websocketSecKey); err != nil || len(v) != 16 {
return http.StatusBadRequest, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Key %q, must be a 16 byte base64 encoded string", websocketSecKey)
}
return 0, nil
}
func authenticateOrigin(r *http.Request, originHosts []string) error {
origin := r.Header.Get("Origin")
if origin == "" {
return nil
}
u, err := url.Parse(origin)
if err != nil {
return fmt.Errorf("failed to parse Origin header %q: %w", origin, err)
}
if strings.EqualFold(r.Host, u.Host) {
return nil
}
for _, hostPattern := range originHosts {
matched, err := match(hostPattern, u.Host)
if err != nil {
return fmt.Errorf("failed to parse path pattern %q: %w", hostPattern, err)
}
if matched {
return nil
}
}
if u.Host == "" {
return fmt.Errorf("request Origin %q is not a valid URL with a host", origin)
}
return fmt.Errorf("request Origin %q is not authorized for Host %q", u.Host, r.Host)
}
func match(pattern, s string) (bool, error) {
return path.Match(strings.ToLower(pattern), strings.ToLower(s))
}
func selectSubprotocol(r *http.Request, subprotocols []string) string {
cps := headerTokens(r.Header, "Sec-WebSocket-Protocol")
for _, sp := range subprotocols {
for _, cp := range cps {
if strings.EqualFold(sp, cp) {
return cp
}
}
}
return ""
}
func selectDeflate(extensions []websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
if mode == CompressionDisabled {
return nil, false
}
for _, ext := range extensions {
switch ext.name {
// We used to implement x-webkit-deflate-frame too for Safari but Safari has bugs...
// See https://github.com/nhooyr/websocket/issues/218
case "permessage-deflate":
copts, ok := acceptDeflate(ext, mode)
if ok {
return copts, true
}
}
}
return nil, false
}
func acceptDeflate(ext websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
copts := mode.opts()
for _, p := range ext.params {
switch p {
case "client_no_context_takeover":
copts.clientNoContextTakeover = true
continue
case "server_no_context_takeover":
copts.serverNoContextTakeover = true
continue
case "client_max_window_bits",
"server_max_window_bits=15":
continue
}
if strings.HasPrefix(p, "client_max_window_bits=") {
// We can't adjust the deflate window, but decoding with a larger window is acceptable.
continue
}
return nil, false
}
return copts, true
}
func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool {
for _, t := range headerTokens(h, key) {
if strings.EqualFold(t, token) {
return true
}
}
return false
}
type websocketExtension struct {
name string
params []string
}
func websocketExtensions(h http.Header) []websocketExtension {
var exts []websocketExtension
extStrs := headerTokens(h, "Sec-WebSocket-Extensions")
for _, extStr := range extStrs {
if extStr == "" {
continue
}
vals := strings.Split(extStr, ";")
for i := range vals {
vals[i] = strings.TrimSpace(vals[i])
}
e := websocketExtension{
name: vals[0],
params: vals[1:],
}
exts = append(exts, e)
}
return exts
}
func headerTokens(h http.Header, key string) []string {
key = textproto.CanonicalMIMEHeaderKey(key)
var tokens []string
for _, v := range h[key] {
v = strings.TrimSpace(v)
for _, t := range strings.Split(v, ",") {
t = strings.TrimSpace(t)
tokens = append(tokens, t)
}
}
return tokens
}
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
func secWebSocketAccept(secWebSocketKey string) string {
h := sha1.New()
h.Write([]byte(secWebSocketKey))
h.Write(keyGUID)
return base64.StdEncoding.EncodeToString(h.Sum(nil))
}
//go:build !js
// +build !js
package websocket
import (
"bufio"
"errors"
"net"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"github.com/coder/websocket/internal/test/assert"
"github.com/coder/websocket/internal/test/xrand"
)
func TestAccept(t *testing.T) {
t.Parallel()
t.Run("badClientHandshake", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
_, err := Accept(w, r, nil)
assert.Contains(t, err, "protocol violation")
})
t.Run("badOrigin", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
r.Header.Set("Origin", "harhar.com")
_, err := Accept(w, r, nil)
assert.Contains(t, err, `request Origin "harhar.com" is not a valid URL with a host`)
})
// #247
t.Run("unauthorizedOriginErrorMessage", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
r.Header.Set("Origin", "https://harhar.com")
_, err := Accept(w, r, nil)
assert.Contains(t, err, `request Origin "harhar.com" is not authorized for Host "example.com"`)
})
t.Run("badCompression", func(t *testing.T) {
t.Parallel()
newRequest := func(extensions string) *http.Request {
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
r.Header.Set("Sec-WebSocket-Extensions", extensions)
return r
}
errHijack := errors.New("hijack error")
newResponseWriter := func() http.ResponseWriter {
return mockHijacker{
ResponseWriter: httptest.NewRecorder(),
hijack: func() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, errHijack
},
}
}
t.Run("withoutFallback", func(t *testing.T) {
t.Parallel()
w := newResponseWriter()
r := newRequest("permessage-deflate; harharhar")
_, err := Accept(w, r, &AcceptOptions{
CompressionMode: CompressionNoContextTakeover,
})
assert.ErrorIs(t, errHijack, err)
assert.Equal(t, "extension header", w.Header().Get("Sec-WebSocket-Extensions"), "")
})
t.Run("withFallback", func(t *testing.T) {
t.Parallel()
w := newResponseWriter()
r := newRequest("permessage-deflate; harharhar, permessage-deflate")
_, err := Accept(w, r, &AcceptOptions{
CompressionMode: CompressionNoContextTakeover,
})
assert.ErrorIs(t, errHijack, err)
assert.Equal(t, "extension header",
w.Header().Get("Sec-WebSocket-Extensions"),
CompressionNoContextTakeover.opts().String(),
)
})
})
t.Run("requireHttpHijacker", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
_, err := Accept(w, r, nil)
assert.Contains(t, err, `http.ResponseWriter does not implement http.Hijacker`)
})
t.Run("badHijack", func(t *testing.T) {
t.Parallel()
w := mockHijacker{
ResponseWriter: httptest.NewRecorder(),
hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) {
return nil, nil, errors.New("haha")
},
}
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
_, err := Accept(w, r, nil)
assert.Contains(t, err, `failed to hijack connection`)
})
t.Run("wrapperHijackerIsUnwrapped", func(t *testing.T) {
t.Parallel()
rr := httptest.NewRecorder()
w := mockUnwrapper{
ResponseWriter: rr,
unwrap: func() http.ResponseWriter {
return mockHijacker{
ResponseWriter: rr,
hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) {
return nil, nil, errors.New("haha")
},
}
},
}
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
_, err := Accept(w, r, nil)
assert.Contains(t, err, "failed to hijack connection")
})
t.Run("closeRace", func(t *testing.T) {
t.Parallel()
server, _ := net.Pipe()
rw := bufio.NewReadWriter(bufio.NewReader(server), bufio.NewWriter(server))
newResponseWriter := func() http.ResponseWriter {
return mockHijacker{
ResponseWriter: httptest.NewRecorder(),
hijack: func() (net.Conn, *bufio.ReadWriter, error) {
return server, rw, nil
},
}
}
w := newResponseWriter()
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
c, err := Accept(w, r, nil)
wg := &sync.WaitGroup{}
wg.Add(2)
go func() {
c.Close(StatusInternalError, "the sky is falling")
wg.Done()
}()
go func() {
c.CloseNow()
wg.Done()
}()
wg.Wait()
assert.Success(t, err)
})
}
func Test_verifyClientHandshake(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
method string
http1 bool
h map[string]string
success bool
}{
{
name: "badConnection",
h: map[string]string{
"Connection": "notUpgrade",
},
},
{
name: "badUpgrade",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "notWebSocket",
},
},
{
name: "badMethod",
method: "POST",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
},
},
{
name: "badWebSocketVersion",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "14",
},
},
{
name: "missingWebSocketKey",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
},
},
{
name: "emptyWebSocketKey",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "",
},
},
{
name: "shortWebSocketKey",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": xrand.Base64(15),
},
},
{
name: "invalidWebSocketKey",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "notbase64",
},
},
{
name: "extraWebSocketKey",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
// Kinda cheeky, but http headers are case-insensitive.
// If 2 sec keys are present, this is a failure condition.
"Sec-WebSocket-Key": xrand.Base64(16),
"sec-webSocket-key": xrand.Base64(16),
},
},
{
name: "badHTTPVersion",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": xrand.Base64(16),
},
http1: true,
},
{
name: "success",
h: map[string]string{
"Connection": "keep-alive, Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": xrand.Base64(16),
},
success: true,
},
{
name: "successSecKeyExtraSpace",
h: map[string]string{
"Connection": "keep-alive, Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": " " + xrand.Base64(16) + " ",
},
success: true,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
r := httptest.NewRequest(tc.method, "/", nil)
r.ProtoMajor = 1
r.ProtoMinor = 1
if tc.http1 {
r.ProtoMinor = 0
}
for k, v := range tc.h {
r.Header.Add(k, v)
}
_, err := verifyClientRequest(httptest.NewRecorder(), r)
if tc.success {
assert.Success(t, err)
} else {
assert.Error(t, err)
}
})
}
}
func Test_selectSubprotocol(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
clientProtocols []string
serverProtocols []string
negotiated string
}{
{
name: "empty",
clientProtocols: nil,
serverProtocols: nil,
negotiated: "",
},
{
name: "basic",
clientProtocols: []string{"echo", "echo2"},
serverProtocols: []string{"echo2", "echo"},
negotiated: "echo2",
},
{
name: "none",
clientProtocols: []string{"echo", "echo3"},
serverProtocols: []string{"echo2", "echo4"},
negotiated: "",
},
{
name: "fallback",
clientProtocols: []string{"echo", "echo3"},
serverProtocols: []string{"echo2", "echo3"},
negotiated: "echo3",
},
{
name: "clientCasePresered",
clientProtocols: []string{"Echo1"},
serverProtocols: []string{"echo1"},
negotiated: "Echo1",
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Sec-WebSocket-Protocol", strings.Join(tc.clientProtocols, ","))
negotiated := selectSubprotocol(r, tc.serverProtocols)
assert.Equal(t, "negotiated", tc.negotiated, negotiated)
})
}
}
func Test_authenticateOrigin(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
origin string
host string
originPatterns []string
success bool
}{
{
name: "none",
success: true,
host: "example.com",
},
{
name: "invalid",
origin: "$#)(*)$#@*$(#@*$)#@*%)#(@*%)#(@%#@$#@$#$#@$#@}{}{}",
host: "example.com",
success: false,
},
{
name: "unauthorized",
origin: "https://example.com",
host: "example1.com",
success: false,
},
{
name: "authorized",
origin: "https://example.com",
host: "example.com",
success: true,
},
{
name: "authorizedCaseInsensitive",
origin: "https://examplE.com",
host: "example.com",
success: true,
},
{
name: "originPatterns",
origin: "https://two.examplE.com",
host: "example.com",
originPatterns: []string{
"*.example.com",
"bar.com",
},
success: true,
},
{
name: "originPatternsUnauthorized",
origin: "https://two.examplE.com",
host: "example.com",
originPatterns: []string{
"exam3.com",
"bar.com",
},
success: false,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
r := httptest.NewRequest("GET", "http://"+tc.host+"/", nil)
r.Header.Set("Origin", tc.origin)
err := authenticateOrigin(r, tc.originPatterns)
if tc.success {
assert.Success(t, err)
} else {
assert.Error(t, err)
}
})
}
}
func Test_selectDeflate(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
mode CompressionMode
header string
expCopts *compressionOptions
expOK bool
}{
{
name: "disabled",
mode: CompressionDisabled,
expCopts: nil,
expOK: false,
},
{
name: "noClientSupport",
mode: CompressionNoContextTakeover,
expCopts: nil,
expOK: false,
},
{
name: "permessage-deflate",
mode: CompressionNoContextTakeover,
header: "permessage-deflate; client_max_window_bits",
expCopts: &compressionOptions{
clientNoContextTakeover: true,
serverNoContextTakeover: true,
},
expOK: true,
},
{
name: "permessage-deflate/unknown-parameter",
mode: CompressionNoContextTakeover,
header: "permessage-deflate; meow",
expOK: false,
},
{
name: "permessage-deflate/unknown-parameter",
mode: CompressionNoContextTakeover,
header: "permessage-deflate; meow, permessage-deflate; client_max_window_bits",
expCopts: &compressionOptions{
clientNoContextTakeover: true,
serverNoContextTakeover: true,
},
expOK: true,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
h := http.Header{}
h.Set("Sec-WebSocket-Extensions", tc.header)
copts, ok := selectDeflate(websocketExtensions(h), tc.mode)
assert.Equal(t, "selected options", tc.expOK, ok)
assert.Equal(t, "compression options", tc.expCopts, copts)
})
}
}
type mockHijacker struct {
http.ResponseWriter
hijack func() (net.Conn, *bufio.ReadWriter, error)
}
var _ http.Hijacker = mockHijacker{}
func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return mj.hijack()
}
type mockUnwrapper struct {
http.ResponseWriter
unwrap func() http.ResponseWriter
}
var _ rwUnwrapper = mockUnwrapper{}
func (mu mockUnwrapper) Unwrap() http.ResponseWriter {
return mu.unwrap()
}
//go:build !js
// +build !js
package websocket_test
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"os"
"os/exec"
"strconv"
"strings"
"testing"
"time"
"github.com/coder/websocket"
"github.com/coder/websocket/internal/errd"
"github.com/coder/websocket/internal/test/assert"
"github.com/coder/websocket/internal/test/wstest"
"github.com/coder/websocket/internal/util"
)
var excludedAutobahnCases = []string{
// We skip the UTF-8 handling tests as there isn't any reason to reject invalid UTF-8, just
// more performance overhead.
"6.*", "7.5.1",
// We skip the tests related to requestMaxWindowBits as that is unimplemented due
// to limitations in compress/flate. See https://github.com/golang/go/issues/3155
"13.3.*", "13.4.*", "13.5.*", "13.6.*",
}
var autobahnCases = []string{"*"}
// Used to run individual test cases. autobahnCases runs only those cases matched
// and not excluded by excludedAutobahnCases. Adding cases here means excludedAutobahnCases
// is niled.
var onlyAutobahnCases = []string{}
func TestAutobahn(t *testing.T) {
t.Parallel()
if os.Getenv("AUTOBAHN") == "" {
t.SkipNow()
}
if os.Getenv("AUTOBAHN") == "fast" {
// These are the slow tests.
excludedAutobahnCases = append(excludedAutobahnCases,
"9.*", "12.*", "13.*",
)
}
if len(onlyAutobahnCases) > 0 {
excludedAutobahnCases = []string{}
autobahnCases = onlyAutobahnCases
}
ctx, cancel := context.WithTimeout(context.Background(), time.Hour)
defer cancel()
wstestURL, closeFn, err := wstestServer(t, ctx)
assert.Success(t, err)
defer func() {
assert.Success(t, closeFn())
}()
err = waitWS(ctx, wstestURL)
assert.Success(t, err)
cases, err := wstestCaseCount(ctx, wstestURL)
assert.Success(t, err)
t.Run("cases", func(t *testing.T) {
for i := 1; i <= cases; i++ {
i := i
t.Run("", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
defer cancel()
c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/runCase?case=%v&agent=main", i), &websocket.DialOptions{
CompressionMode: websocket.CompressionContextTakeover,
})
assert.Success(t, err)
err = wstest.EchoLoop(ctx, c)
t.Logf("echoLoop: %v", err)
})
}
})
c, _, err := websocket.Dial(ctx, wstestURL+"/updateReports?agent=main", nil)
assert.Success(t, err)
c.Close(websocket.StatusNormalClosure, "")
checkWSTestIndex(t, "./ci/out/autobahn-report/index.json")
}
func waitWS(ctx context.Context, url string) error {
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
for ctx.Err() == nil {
c, _, err := websocket.Dial(ctx, url, nil)
if err != nil {
continue
}
c.Close(websocket.StatusNormalClosure, "")
return nil
}
return ctx.Err()
}
func wstestServer(tb testing.TB, ctx context.Context) (url string, closeFn func() error, err error) {
defer errd.Wrap(&err, "failed to start autobahn wstest server")
serverAddr, err := unusedListenAddr()
if err != nil {
return "", nil, err
}
_, serverPort, err := net.SplitHostPort(serverAddr)
if err != nil {
return "", nil, err
}
url = "ws://" + serverAddr
const outDir = "ci/out/autobahn-report"
specFile, err := tempJSONFile(map[string]interface{}{
"url": url,
"outdir": outDir,
"cases": autobahnCases,
"exclude-cases": excludedAutobahnCases,
})
if err != nil {
return "", nil, fmt.Errorf("failed to write spec: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, time.Hour)
defer func() {
if err != nil {
cancel()
}
}()
dockerPull := exec.CommandContext(ctx, "docker", "pull", "crossbario/autobahn-testsuite")
dockerPull.Stdout = util.WriterFunc(func(p []byte) (int, error) {
tb.Log(string(p))
return len(p), nil
})
dockerPull.Stderr = util.WriterFunc(func(p []byte) (int, error) {
tb.Log(string(p))
return len(p), nil
})
tb.Log(dockerPull)
err = dockerPull.Run()
if err != nil {
return "", nil, fmt.Errorf("failed to pull docker image: %w", err)
}
wd, err := os.Getwd()
if err != nil {
return "", nil, err
}
var args []string
args = append(args, "run", "-i", "--rm",
"-v", fmt.Sprintf("%s:%[1]s", specFile),
"-v", fmt.Sprintf("%s/ci:/ci", wd),
fmt.Sprintf("-p=%s:%s", serverAddr, serverPort),
"crossbario/autobahn-testsuite",
)
args = append(args, "wstest", "--mode", "fuzzingserver", "--spec", specFile,
// Disables some server that runs as part of fuzzingserver mode.
// See https://github.com/crossbario/autobahn-testsuite/blob/058db3a36b7c3a1edf68c282307c6b899ca4857f/autobahntestsuite/autobahntestsuite/wstest.py#L124
"--webport=0",
)
wstest := exec.CommandContext(ctx, "docker", args...)
wstest.Stdout = util.WriterFunc(func(p []byte) (int, error) {
tb.Log(string(p))
return len(p), nil
})
wstest.Stderr = util.WriterFunc(func(p []byte) (int, error) {
tb.Log(string(p))
return len(p), nil
})
tb.Log(wstest)
err = wstest.Start()
if err != nil {
return "", nil, fmt.Errorf("failed to start wstest: %w", err)
}
return url, func() error {
err = wstest.Process.Kill()
if err != nil {
return fmt.Errorf("failed to kill wstest: %w", err)
}
err = wstest.Wait()
var ee *exec.ExitError
if errors.As(err, &ee) && ee.ExitCode() == -1 {
return nil
}
return err
}, nil
}
func wstestCaseCount(ctx context.Context, url string) (cases int, err error) {
defer errd.Wrap(&err, "failed to get case count")
c, _, err := websocket.Dial(ctx, url+"/getCaseCount", nil)
if err != nil {
return 0, err
}
defer c.Close(websocket.StatusInternalError, "")
_, r, err := c.Reader(ctx)
if err != nil {
return 0, err
}
b, err := io.ReadAll(r)
if err != nil {
return 0, err
}
cases, err = strconv.Atoi(string(b))
if err != nil {
return 0, err
}
c.Close(websocket.StatusNormalClosure, "")
return cases, nil
}
func checkWSTestIndex(t *testing.T, path string) {
wstestOut, err := os.ReadFile(path)
assert.Success(t, err)
var indexJSON map[string]map[string]struct {
Behavior string `json:"behavior"`
BehaviorClose string `json:"behaviorClose"`
}
err = json.Unmarshal(wstestOut, &indexJSON)
assert.Success(t, err)
for _, tests := range indexJSON {
for test, result := range tests {
t.Run(test, func(t *testing.T) {
switch result.BehaviorClose {
case "OK", "INFORMATIONAL":
default:
t.Errorf("bad close behaviour")
}
switch result.Behavior {
case "OK", "NON-STRICT", "INFORMATIONAL":
default:
t.Errorf("failed")
}
})
}
}
if t.Failed() {
htmlPath := strings.Replace(path, ".json", ".html", 1)
t.Errorf("detected autobahn violation, see %q", htmlPath)
}
}
func unusedListenAddr() (_ string, err error) {
defer errd.Wrap(&err, "failed to get unused listen address")
l, err := net.Listen("tcp", "localhost:0")
if err != nil {
return "", err
}
l.Close()
return l.Addr().String(), nil
}
func tempJSONFile(v interface{}) (string, error) {
f, err := os.CreateTemp("", "temp.json")
if err != nil {
return "", fmt.Errorf("temp file: %w", err)
}
defer f.Close()
e := json.NewEncoder(f)
e.SetIndent("", "\t")
err = e.Encode(v)
if err != nil {
return "", fmt.Errorf("json encode: %w", err)
}
err = f.Close()
if err != nil {
return "", fmt.Errorf("close temp file: %w", err)
}
return f.Name(), nil
}