good morning!!!!

Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • github/nhooyr/websocket
  • open/websocket
2 results
Show changes
Commits on Source (544)
version: 2
jobs:
fmt:
docker:
- image: golang:1
steps:
- checkout
- restore_cache:
keys:
- go-{{ checksum "go.sum" }}
# Fallback to using the latest cache if no exact match is found.
- go-
- run: ./ci/fmt.sh
- save_cache:
paths:
- /go
- /root/.cache/go-build
key: go-{{ checksum "go.sum" }}
lint:
docker:
- image: golang:1
steps:
- checkout
- restore_cache:
keys:
- go-{{ checksum "go.sum" }}
# Fallback to using the latest cache if no exact match is found.
- go-
- run: ./ci/lint.sh
- save_cache:
paths:
- /go
- /root/.cache/go-build
key: go-{{ checksum "go.sum" }}
test:
docker:
- image: golang:1
steps:
- checkout
- restore_cache:
keys:
- go-{{ checksum "go.sum" }}
# Fallback to using the latest cache if no exact match is found.
- go-
- run: ./ci/test.sh
- save_cache:
paths:
- /go
- /root/.cache/go-build
key: go-{{ checksum "go.sum" }}
workflows:
version: 2
fmt:
jobs:
- fmt
lint:
jobs:
- lint
test:
jobs:
- test
version: 2
updates:
# Track in case we ever add dependencies.
- package-ecosystem: 'gomod'
directory: '/'
schedule:
interval: 'weekly'
commit-message:
prefix: 'chore'
# Keep example and test/benchmark deps up-to-date.
- package-ecosystem: 'gomod'
directories:
- '/internal/examples'
- '/internal/thirdparty'
schedule:
interval: 'monthly'
commit-message:
prefix: 'chore'
labels: []
groups:
internal-deps:
patterns:
- '*'
name: ci
on:
push:
branches:
- master
pull_request:
branches:
- master
concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}
cancel-in-progress: true
jobs:
fmt:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- run: make fmt
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- run: go version
- uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- run: make lint
test:
runs-on: ubuntu-latest
steps:
- name: Disable AppArmor
if: runner.os == 'Linux'
run: |
# Disable AppArmor for Ubuntu 23.10+.
# https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md
echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- run: make test
- uses: actions/upload-artifact@v4
with:
name: coverage.html
path: ./ci/out/coverage.html
bench:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- run: make bench
name: daily
on:
workflow_dispatch:
schedule:
- cron: '42 0 * * *' # daily at 00:42
concurrency:
group: ${{ github.workflow }}
cancel-in-progress: true
jobs:
bench:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- run: AUTOBAHN=1 make bench
test:
runs-on: ubuntu-latest
steps:
- name: Disable AppArmor
if: runner.os == 'Linux'
run: |
# Disable AppArmor for Ubuntu 23.10+.
# https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md
echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- run: AUTOBAHN=1 make test
- uses: actions/upload-artifact@v4
with:
name: coverage.html
path: ./ci/out/coverage.html
bench-dev:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
ref: dev
- uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- run: AUTOBAHN=1 make bench
test-dev:
runs-on: ubuntu-latest
steps:
- name: Disable AppArmor
if: runner.os == 'Linux'
run: |
# Disable AppArmor for Ubuntu 23.10+.
# https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md
echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns
- uses: actions/checkout@v4
with:
ref: dev
- uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- run: AUTOBAHN=1 make test
- uses: actions/upload-artifact@v4
with:
name: coverage-dev.html
path: ./ci/out/coverage.html
name: static
on:
push:
branches: ['master']
workflow_dispatch:
# Set permissions of the GITHUB_TOKEN to allow deployment to GitHub Pages.
permissions:
contents: read
pages: write
id-token: write
concurrency:
group: pages
cancel-in-progress: true
jobs:
deploy:
environment:
name: github-pages
url: ${{ steps.deployment.outputs.page_url }}
runs-on: ubuntu-latest
steps:
- name: Disable AppArmor
if: runner.os == 'Linux'
run: |
# Disable AppArmor for Ubuntu 23.10+.
# https://chromium.googlesource.com/chromium/src/+/main/docs/security/apparmor-userns-restrictions.md
echo 0 | sudo tee /proc/sys/kernel/apparmor_restrict_unprivileged_userns
- name: Checkout
uses: actions/checkout@v4
- name: Setup Pages
uses: actions/configure-pages@v5
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version-file: ./go.mod
- name: Generate coverage and badge
run: |
make test
mkdir -p ./ci/out/static
cp ./ci/out/coverage.html ./ci/out/static/coverage.html
percent=$(go tool cover -func ./ci/out/coverage.prof | tail -n1 | awk '{print $3}' | tr -d '%')
wget -O ./ci/out/static/coverage.svg "https://img.shields.io/badge/coverage-${percent}%25-success"
- name: Upload artifact
uses: actions/upload-pages-artifact@v3
with:
path: ./ci/out/static/
- name: Deploy to GitHub Pages
id: deployment
uses: actions/deploy-pages@v4
MIT License Copyright (c) 2025 Coder
Copyright (c) 2018 Anmol Sethi Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
Permission is hereby granted, free of charge, to any person obtaining a copy copyright notice and this permission notice appear in all copies.
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
copies of the Software, and to permit persons to whom the Software is MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
furnished to do so, subject to the following conditions: ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
The above copyright notice and this permission notice shall be included in all ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
copies or substantial portions of the Software. OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
.PHONY: all
all: fmt lint test
.PHONY: fmt
fmt:
./ci/fmt.sh
.PHONY: lint
lint:
./ci/lint.sh
.PHONY: test
test:
./ci/test.sh
.PHONY: bench
bench:
./ci/bench.sh
\ No newline at end of file
# websocket # websocket
[![GoDoc](https://godoc.org/nhooyr.io/websocket?status.svg)](https://godoc.org/nhooyr.io/websocket) [![Go Reference](https://pkg.go.dev/badge/github.com/coder/websocket.svg)](https://pkg.go.dev/github.com/coder/websocket)
[![Codecov](https://img.shields.io/codecov/c/github/nhooyr/websocket.svg?color=brightgreen)](https://codecov.io/gh/nhooyr/websocket) [![Go Coverage](https://coder.github.io/websocket/coverage.svg)](https://coder.github.io/websocket/coverage.html)
websocket is a minimal and idiomatic WebSocket library for Go. websocket is a minimal and idiomatic WebSocket library for Go.
## Install ## Install
```bash ```sh
go get nhooyr.io/websocket@v1.3.1 go get github.com/coder/websocket
``` ```
## Features > [!NOTE]
> Coder now maintains this project as explained in [this blog post](https://coder.com/blog/websocket).
> We're grateful to [nhooyr](https://github.com/nhooyr) for authoring and maintaining this project from
> 2019 to 2024.
## Highlights
- Minimal and idiomatic API - Minimal and idiomatic API
- Tiny codebase at 1700 lines - First class [context.Context](https://blog.golang.org/context) support
- First class context.Context support - Fully passes the WebSocket [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite)
- Thorough tests, fully passes the [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) - [Zero dependencies](https://pkg.go.dev/github.com/coder/websocket?tab=imports)
- Zero dependencies outside of the stdlib for the core library - JSON helpers in the [wsjson](https://pkg.go.dev/github.com/coder/websocket/wsjson) subpackage
- JSON and ProtoBuf helpers in the wsjson and wspb subpackages - Zero alloc reads and writes
- Highly optimized by default - Concurrent writes
- Concurrent writes out of the box - [Close handshake](https://pkg.go.dev/github.com/coder/websocket#Conn.Close)
- [net.Conn](https://pkg.go.dev/github.com/coder/websocket#NetConn) wrapper
- [Ping pong](https://pkg.go.dev/github.com/coder/websocket#Conn.Ping) API
- [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression
- [CloseRead](https://pkg.go.dev/github.com/coder/websocket#Conn.CloseRead) helper for write only connections
- Compile to [Wasm](https://pkg.go.dev/github.com/coder/websocket#hdr-Wasm)
## Roadmap ## Roadmap
- [ ] WebSockets over HTTP/2 [#4](https://github.com/nhooyr/websocket/issues/4) See GitHub issues for minor issues but the major future enhancements are:
- [ ] Perfect examples [#217](https://github.com/nhooyr/websocket/issues/217)
- [ ] wstest.Pipe for in memory testing [#340](https://github.com/nhooyr/websocket/issues/340)
- [ ] Ping pong heartbeat helper [#267](https://github.com/nhooyr/websocket/issues/267)
- [ ] Ping pong instrumentation callbacks [#246](https://github.com/nhooyr/websocket/issues/246)
- [ ] Graceful shutdown helpers [#209](https://github.com/nhooyr/websocket/issues/209)
- [ ] Assembly for WebSocket masking [#16](https://github.com/nhooyr/websocket/issues/16)
- WIP at [#326](https://github.com/nhooyr/websocket/pull/326), about 3x faster
- [ ] HTTP/2 [#4](https://github.com/nhooyr/websocket/issues/4)
- [ ] The holy grail [#402](https://github.com/nhooyr/websocket/issues/402)
## Examples ## Examples
For a production quality example that shows off the full API, see the [echo example on the godoc](https://godoc.org/nhooyr.io/websocket#example-package--Echo). On github, the example is at [example_echo_test.go](./example_echo_test.go). For a production quality example that demonstrates the complete API, see the
[echo example](./internal/examples/echo).
For a full stack example, see the [chat example](./internal/examples/chat).
### Server ### Server
```go ```go
http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) c, err := websocket.Accept(w, r, nil)
if err != nil { if err != nil {
// ... // ...
} }
defer c.Close(websocket.StatusInternalError, "the sky is falling") defer c.CloseNow()
ctx, cancel := context.WithTimeout(r.Context(), time.Second*10) // Set the context as needed. Use of r.Context() is not recommended
// to avoid surprising behavior (see http.Hijacker).
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel() defer cancel()
var v interface{} var v interface{}
err = wsjson.Read(ctx, c, &v) err = wsjson.Read(ctx, c, &v)
if err != nil { if err != nil {
// ... // ...
} }
log.Printf("received: %v", v) log.Printf("received: %v", v)
c.Close(websocket.StatusNormalClosure, "") c.Close(websocket.StatusNormalClosure, "")
}) })
``` ```
### Client ### Client
The client side of this library requires at minimum Go 1.12 as it uses a [new feature
in net/http](https://github.com/golang/go/issues/26937#issuecomment-415855861) to perform WebSocket handshakes.
```go ```go
ctx, cancel := context.WithTimeout(context.Background(), time.Minute) ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel() defer cancel()
c, _, err := websocket.Dial(ctx, "ws://localhost:8080", websocket.DialOptions{}) c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil)
if err != nil { if err != nil {
// ... // ...
} }
defer c.Close(websocket.StatusInternalError, "the sky is falling") defer c.CloseNow()
err = wsjson.Write(ctx, c, "hi") err = wsjson.Write(ctx, c, "hi")
if err != nil { if err != nil {
...@@ -78,89 +100,63 @@ if err != nil { ...@@ -78,89 +100,63 @@ if err != nil {
c.Close(websocket.StatusNormalClosure, "") c.Close(websocket.StatusNormalClosure, "")
``` ```
## Design justifications
- A minimal API is easier to maintain due to less docs, tests and bugs
- A minimal API is also easier to use and learn
- Context based cancellation is more ergonomic and robust than setting deadlines
- net.Conn is never exposed as WebSocket over HTTP/2 will not have a net.Conn.
- Using net/http's Client for dialing means we do not have to reinvent dialing hooks
and configurations like other WebSocket libraries
- We do not support the deflate compression extension because Go's compress/flate library
is very memory intensive and browsers do not handle WebSocket compression intelligently.
See [#5](https://github.com/nhooyr/websocket/issues/5)
## Comparison ## Comparison
Before the comparison, I want to point out that both gorilla/websocket and gobwas/ws were
extremely useful in implementing the WebSocket protocol correctly so *big thanks* to the
authors of both. In particular, I made sure to go through the issue tracker of gorilla/websocket
to ensure I implemented details correctly and understood how people were using WebSockets in
production.
### gorilla/websocket ### gorilla/websocket
https://github.com/gorilla/websocket Advantages of [gorilla/websocket](https://github.com/gorilla/websocket):
This package is the community standard but it is 6 years old and over time
has accumulated cruft. There are too many ways to do the same thing.
Just compare the godoc of
[nhooyr/websocket](https://godoc.org/github.com/nhooyr/websocket) side by side with
[gorilla/websocket](https://godoc.org/github.com/gorilla/websocket).
The API for nhooyr/websocket has been designed such that there is only one way to do things
which makes it easy to use correctly. Not only is the API simpler, the implementation is
only 1700 lines whereas gorilla/websocket is at 3500 lines. That's more code to maintain,
more code to test, more code to document and more surface area for bugs.
Moreover, nhooyr/websocket has support for newer Go idioms such as context.Context and
also uses net/http's Client and ResponseWriter directly for WebSocket handshakes.
gorilla/websocket writes its handshakes to the underlying net.Conn which means
it has to reinvent hooks for TLS and proxies and prevents support of HTTP/2.
Some more advantages of nhooyr/websocket are that it supports concurrent writes and - Mature and widely used
makes it very easy to close the connection with a status code and reason. - [Prepared writes](https://pkg.go.dev/github.com/gorilla/websocket#PreparedMessage)
- Configurable [buffer sizes](https://pkg.go.dev/github.com/gorilla/websocket#hdr-Buffers)
- No extra goroutine per connection to support cancellation with context.Context. This costs github.com/coder/websocket 2 KB of memory per connection.
- Will be removed soon with [context.AfterFunc](https://github.com/golang/go/issues/57928). See [#411](https://github.com/nhooyr/websocket/issues/411)
The ping API is also nicer. gorilla/websocket requires registering a pong handler on the Conn Advantages of github.com/coder/websocket:
which results in awkward control flow. With nhooyr/websocket you use the Ping method on the Conn
that sends a ping and also waits for the pong.
In terms of performance, the differences mostly depend on your application code. nhooyr/websocket - Minimal and idiomatic API
reuses message buffers out of the box if you use the wsjson and wspb subpackages. - Compare godoc of [github.com/coder/websocket](https://pkg.go.dev/github.com/coder/websocket) with [gorilla/websocket](https://pkg.go.dev/github.com/gorilla/websocket) side by side.
As mentioned above, nhooyr/websocket also supports concurrent writers. - [net.Conn](https://pkg.go.dev/github.com/coder/websocket#NetConn) wrapper
- Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535))
The only performance con to nhooyr/websocket is that uses one extra goroutine to support - Full [context.Context](https://blog.golang.org/context) support
cancellation with context.Context. This costs 2 KB of memory which is cheap compared to - Dial uses [net/http.Client](https://golang.org/pkg/net/http/#Client)
simplicity benefits. - Will enable easy HTTP/2 support in the future
- Gorilla writes directly to a net.Conn and so duplicates features of net/http.Client.
### x/net/websocket - Concurrent writes
- Close handshake ([gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448))
https://godoc.org/golang.org/x/net/websocket - Idiomatic [ping pong](https://pkg.go.dev/github.com/coder/websocket#Conn.Ping) API
- Gorilla requires registering a pong callback before sending a Ping
- Can target Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432))
- Transparent message buffer reuse with [wsjson](https://pkg.go.dev/github.com/coder/websocket/wsjson) subpackage
- [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go
- Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/).
Soon we'll have assembly and be 3x faster [#326](https://github.com/nhooyr/websocket/pull/326)
- Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support
- Gorilla only supports no context takeover mode
- [CloseRead](https://pkg.go.dev/github.com/coder/websocket#Conn.CloseRead) helper for write only connections ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492))
Unmaintained and the API does not reflect WebSocket semantics. Should never be used. #### golang.org/x/net/websocket
See https://github.com/golang/go/issues/18152 [golang.org/x/net/websocket](https://pkg.go.dev/golang.org/x/net/websocket) is deprecated.
See [golang/go/issues/18152](https://github.com/golang/go/issues/18152).
### gobwas/ws The [net.Conn](https://pkg.go.dev/github.com/coder/websocket#NetConn) can help in transitioning
to github.com/coder/websocket.
https://github.com/gobwas/ws #### gobwas/ws
This library has an extremely flexible API but that comes at the cost of usability [gobwas/ws](https://github.com/gobwas/ws) has an extremely flexible API that allows it to be used
and clarity. in an event driven style for performance. See the author's [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb).
This library is fantastic in terms of performance. The author put in significant However it is quite bloated. See https://pkg.go.dev/github.com/gobwas/ws
effort to ensure its speed and I have applied as many of its optimizations as
I could into nhooyr/websocket. Definitely check out his fantastic [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb)
about performant WebSocket servers.
If you want a library that gives you absolute control over everything, this is the library, When writing idiomatic Go, github.com/coder/websocket will be faster and easier to use.
but for most users, the API provided by nhooyr/websocket will fit better as it is nearly just
as performant but much easier to use correctly and idiomatic.
## Users #### lesismal/nbio
This is a list of companies or projects that use this library. [lesismal/nbio](https://github.com/lesismal/nbio) is similar to gobwas/ws in that the API is
event driven for performance reasons.
- [Coder](https://github.com/cdr) However it is quite bloated. See https://pkg.go.dev/github.com/lesismal/nbio
If your company or project is using this library, please feel free to open a PR to amend the list. When writing idiomatic Go, github.com/coder/websocket will be faster and easier to use.
//go:build !js
// +build !js
package websocket package websocket
import ( import (
"bytes" "bytes"
"context"
"crypto/sha1" "crypto/sha1"
"encoding/base64" "encoding/base64"
"errors"
"fmt"
"io" "io"
"log"
"net/http" "net/http"
"net/textproto" "net/textproto"
"net/url" "net/url"
"path"
"strings" "strings"
"golang.org/x/net/http/httpguts" "github.com/coder/websocket/internal/errd"
"golang.org/x/xerrors"
) )
// AcceptOptions represents the options available to pass to Accept. // AcceptOptions represents Accept's options.
type AcceptOptions struct { type AcceptOptions struct {
// Subprotocols lists the websocket subprotocols that Accept will negotiate with a client. // Subprotocols lists the WebSocket subprotocols that Accept will negotiate with the client.
// The empty subprotocol will always be negotiated as per RFC 6455. If you would like to // The empty subprotocol will always be negotiated as per RFC 6455. If you would like to
// reject it, close the connection if c.Subprotocol() == "". // reject it, close the connection when c.Subprotocol() == "".
Subprotocols []string Subprotocols []string
// InsecureSkipVerify disables Accept's origin verification // InsecureSkipVerify is used to disable Accept's origin verification behaviour.
// behaviour. By default Accept only allows the handshake to
// succeed if the javascript that is initiating the handshake
// is on the same domain as the server. This is to prevent CSRF
// attacks when secure data is stored in a cookie as there is no same
// origin policy for WebSockets. In other words, javascript from
// any domain can perform a WebSocket dial on an arbitrary server.
// This dial will include cookies which means the arbitrary javascript
// can perform actions as the authenticated user.
//
// See https://stackoverflow.com/a/37837709/4283659
// //
// The only time you need this is if your javascript is running on a different domain // You probably want to use OriginPatterns instead.
// than your WebSocket server.
// Please think carefully about whether you really need this option before you use it.
// If you do, remember that if you store secure data in cookies, you wil need to verify the
// Origin header yourself otherwise you are exposing yourself to a CSRF attack.
InsecureSkipVerify bool InsecureSkipVerify bool
}
func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { // OriginPatterns lists the host patterns for authorized origins.
if !headerValuesContainsToken(r.Header, "Connection", "Upgrade") { // The request host is always authorized.
err := xerrors.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) // Use this to enable cross origin WebSockets.
http.Error(w, err.Error(), http.StatusBadRequest) //
return err // i.e javascript running on example.com wants to access a WebSocket server at chat.example.com.
} // In such a case, example.com is the origin and chat.example.com is the request host.
// One would set this field to []string{"example.com"} to authorize example.com to connect.
//
// Each pattern is matched case insensitively against the request origin host
// with path.Match.
// See https://golang.org/pkg/path/#Match
//
// Please ensure you understand the ramifications of enabling this.
// If used incorrectly your WebSocket server will be open to CSRF attacks.
//
// Do not use * as a pattern to allow any origin, prefer to use InsecureSkipVerify instead
// to bring attention to the danger of such a setting.
OriginPatterns []string
if !headerValuesContainsToken(r.Header, "Upgrade", "WebSocket") { // CompressionMode controls the compression mode.
err := xerrors.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) // Defaults to CompressionDisabled.
http.Error(w, err.Error(), http.StatusBadRequest) //
return err // See docs on CompressionMode for details.
} CompressionMode CompressionMode
if r.Method != "GET" { // CompressionThreshold controls the minimum size of a message before compression is applied.
err := xerrors.Errorf("websocket protocol violation: handshake request method is not GET but %q", r.Method) //
http.Error(w, err.Error(), http.StatusBadRequest) // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
return err // for CompressionContextTakeover.
} CompressionThreshold int
if r.Header.Get("Sec-WebSocket-Version") != "13" { // OnPingReceived is an optional callback invoked synchronously when a ping frame is received.
err := xerrors.Errorf("unsupported websocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version")) //
http.Error(w, err.Error(), http.StatusBadRequest) // The payload contains the application data of the ping frame.
return err // If the callback returns false, the subsequent pong frame will not be sent.
} // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
OnPingReceived func(ctx context.Context, payload []byte) bool
if r.Header.Get("Sec-WebSocket-Key") == "" { // OnPongReceived is an optional callback invoked synchronously when a pong frame is received.
err := xerrors.New("websocket protocol violation: missing Sec-WebSocket-Key") //
http.Error(w, err.Error(), http.StatusBadRequest) // The payload contains the application data of the pong frame.
return err // To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
} //
// Unlike OnPingReceived, this callback does not return a value because a pong frame
// is a response to a ping and does not trigger any further frame transmission.
OnPongReceived func(ctx context.Context, payload []byte)
}
return nil func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions {
var o AcceptOptions
if opts != nil {
o = *opts
}
return &o
} }
// Accept accepts a WebSocket handshake from a client and upgrades the // Accept accepts a WebSocket handshake from a client and upgrades the
// the connection to a WebSocket. // the connection to a WebSocket.
// //
// Accept will reject the handshake if the Origin domain is not the same as the Host unless // Accept will not allow cross origin requests by default.
// the InsecureSkipVerify option is set. In other words, by default it does not allow // See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests.
// cross origin requests.
// //
// If an error occurs, Accept will always write an appropriate response so you do not // Accept will write a response to w on all errors.
// have to. //
func Accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, error) { // Note that using the http.Request Context after Accept returns may lead to
c, err := accept(w, r, opts) // unexpected behavior (see http.Hijacker).
if err != nil { func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
return nil, xerrors.Errorf("failed to accept websocket connection: %w", err) return accept(w, r, opts)
}
return c, nil
} }
func accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, error) { func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) {
err := verifyClientRequest(w, r) defer errd.Wrap(&err, "failed to accept WebSocket connection")
errCode, err := verifyClientRequest(w, r)
if err != nil { if err != nil {
http.Error(w, err.Error(), errCode)
return nil, err return nil, err
} }
opts = opts.cloneWithDefaults()
if !opts.InsecureSkipVerify { if !opts.InsecureSkipVerify {
err = authenticateOrigin(r) err = authenticateOrigin(r, opts.OriginPatterns)
if err != nil { if err != nil {
if errors.Is(err, path.ErrBadPattern) {
log.Printf("websocket: %v", err)
err = errors.New(http.StatusText(http.StatusForbidden))
}
http.Error(w, err.Error(), http.StatusForbidden) http.Error(w, err.Error(), http.StatusForbidden)
return nil, err return nil, err
} }
} }
hj, ok := w.(http.Hijacker) hj, ok := hijacker(w)
if !ok { if !ok {
err = xerrors.New("passed ResponseWriter does not implement http.Hijacker") err = errors.New("http.ResponseWriter does not implement http.Hijacker")
http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented)
return nil, err return nil, err
} }
...@@ -116,18 +135,30 @@ func accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, ...@@ -116,18 +135,30 @@ func accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn,
w.Header().Set("Upgrade", "websocket") w.Header().Set("Upgrade", "websocket")
w.Header().Set("Connection", "Upgrade") w.Header().Set("Connection", "Upgrade")
handleSecWebSocketKey(w, r) key := r.Header.Get("Sec-WebSocket-Key")
w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
subproto := selectSubprotocol(r, opts.Subprotocols) subproto := selectSubprotocol(r, opts.Subprotocols)
if subproto != "" { if subproto != "" {
w.Header().Set("Sec-WebSocket-Protocol", subproto) w.Header().Set("Sec-WebSocket-Protocol", subproto)
} }
copts, ok := selectDeflate(websocketExtensions(r.Header), opts.CompressionMode)
if ok {
w.Header().Set("Sec-WebSocket-Extensions", copts.String())
}
w.WriteHeader(http.StatusSwitchingProtocols) w.WriteHeader(http.StatusSwitchingProtocols)
// See https://github.com/nhooyr/websocket/issues/166
if ginWriter, ok := w.(interface {
WriteHeaderNow()
}); ok {
ginWriter.WriteHeaderNow()
}
netConn, brw, err := hj.Hijack() netConn, brw, err := hj.Hijack()
if err != nil { if err != nil {
err = xerrors.Errorf("failed to hijack connection: %w", err) err = fmt.Errorf("failed to hijack connection: %w", err)
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return nil, err return nil, err
} }
...@@ -136,38 +167,204 @@ func accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, ...@@ -136,38 +167,204 @@ func accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn,
b, _ := brw.Reader.Peek(brw.Reader.Buffered()) b, _ := brw.Reader.Peek(brw.Reader.Buffered())
brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn)) brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn))
c := &Conn{ return newConn(connConfig{
subprotocol: w.Header().Get("Sec-WebSocket-Protocol"), subprotocol: w.Header().Get("Sec-WebSocket-Protocol"),
br: brw.Reader, rwc: netConn,
bw: brw.Writer, client: false,
closer: netConn, copts: copts,
flateThreshold: opts.CompressionThreshold,
onPingReceived: opts.OnPingReceived,
onPongReceived: opts.OnPongReceived,
br: brw.Reader,
bw: brw.Writer,
}), nil
}
func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) {
if !r.ProtoAtLeast(1, 1) {
return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto)
}
if !headerContainsTokenIgnoreCase(r.Header, "Connection", "Upgrade") {
w.Header().Set("Connection", "Upgrade")
w.Header().Set("Upgrade", "websocket")
return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection"))
}
if !headerContainsTokenIgnoreCase(r.Header, "Upgrade", "websocket") {
w.Header().Set("Connection", "Upgrade")
w.Header().Set("Upgrade", "websocket")
return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade"))
}
if r.Method != "GET" {
return http.StatusMethodNotAllowed, fmt.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method)
}
if r.Header.Get("Sec-WebSocket-Version") != "13" {
w.Header().Set("Sec-WebSocket-Version", "13")
return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version"))
}
websocketSecKeys := r.Header.Values("Sec-WebSocket-Key")
if len(websocketSecKeys) == 0 {
return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key")
}
if len(websocketSecKeys) > 1 {
return http.StatusBadRequest, errors.New("WebSocket protocol violation: multiple Sec-WebSocket-Key headers")
} }
c.init()
return c, nil // The RFC states to remove any leading or trailing whitespace.
websocketSecKey := strings.TrimSpace(websocketSecKeys[0])
if v, err := base64.StdEncoding.DecodeString(websocketSecKey); err != nil || len(v) != 16 {
return http.StatusBadRequest, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Key %q, must be a 16 byte base64 encoded string", websocketSecKey)
}
return 0, nil
} }
func headerValuesContainsToken(h http.Header, key, val string) bool { func authenticateOrigin(r *http.Request, originHosts []string) error {
key = textproto.CanonicalMIMEHeaderKey(key) origin := r.Header.Get("Origin")
return httpguts.HeaderValuesContainsToken(h[key], val) if origin == "" {
return nil
}
u, err := url.Parse(origin)
if err != nil {
return fmt.Errorf("failed to parse Origin header %q: %w", origin, err)
}
if strings.EqualFold(r.Host, u.Host) {
return nil
}
for _, hostPattern := range originHosts {
matched, err := match(hostPattern, u.Host)
if err != nil {
return fmt.Errorf("failed to parse path pattern %q: %w", hostPattern, err)
}
if matched {
return nil
}
}
if u.Host == "" {
return fmt.Errorf("request Origin %q is not a valid URL with a host", origin)
}
return fmt.Errorf("request Origin %q is not authorized for Host %q", u.Host, r.Host)
}
func match(pattern, s string) (bool, error) {
return path.Match(strings.ToLower(pattern), strings.ToLower(s))
} }
func selectSubprotocol(r *http.Request, subprotocols []string) string { func selectSubprotocol(r *http.Request, subprotocols []string) string {
cps := headerTokens(r.Header, "Sec-WebSocket-Protocol")
for _, sp := range subprotocols { for _, sp := range subprotocols {
if headerValuesContainsToken(r.Header, "Sec-WebSocket-Protocol", sp) { for _, cp := range cps {
return sp if strings.EqualFold(sp, cp) {
return cp
}
} }
} }
return "" return ""
} }
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") func selectDeflate(extensions []websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
if mode == CompressionDisabled {
return nil, false
}
for _, ext := range extensions {
switch ext.name {
// We used to implement x-webkit-deflate-frame too for Safari but Safari has bugs...
// See https://github.com/nhooyr/websocket/issues/218
case "permessage-deflate":
copts, ok := acceptDeflate(ext, mode)
if ok {
return copts, true
}
}
}
return nil, false
}
func handleSecWebSocketKey(w http.ResponseWriter, r *http.Request) { func acceptDeflate(ext websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
key := r.Header.Get("Sec-WebSocket-Key") copts := mode.opts()
w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) for _, p := range ext.params {
switch p {
case "client_no_context_takeover":
copts.clientNoContextTakeover = true
continue
case "server_no_context_takeover":
copts.serverNoContextTakeover = true
continue
case "client_max_window_bits",
"server_max_window_bits=15":
continue
}
if strings.HasPrefix(p, "client_max_window_bits=") {
// We can't adjust the deflate window, but decoding with a larger window is acceptable.
continue
}
return nil, false
}
return copts, true
}
func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool {
for _, t := range headerTokens(h, key) {
if strings.EqualFold(t, token) {
return true
}
}
return false
}
type websocketExtension struct {
name string
params []string
}
func websocketExtensions(h http.Header) []websocketExtension {
var exts []websocketExtension
extStrs := headerTokens(h, "Sec-WebSocket-Extensions")
for _, extStr := range extStrs {
if extStr == "" {
continue
}
vals := strings.Split(extStr, ";")
for i := range vals {
vals[i] = strings.TrimSpace(vals[i])
}
e := websocketExtension{
name: vals[0],
params: vals[1:],
}
exts = append(exts, e)
}
return exts
}
func headerTokens(h http.Header, key string) []string {
key = textproto.CanonicalMIMEHeaderKey(key)
var tokens []string
for _, v := range h[key] {
v = strings.TrimSpace(v)
for _, t := range strings.Split(v, ",") {
t = strings.TrimSpace(t)
tokens = append(tokens, t)
}
}
return tokens
} }
var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
func secWebSocketAccept(secWebSocketKey string) string { func secWebSocketAccept(secWebSocketKey string) string {
h := sha1.New() h := sha1.New()
h.Write([]byte(secWebSocketKey)) h.Write([]byte(secWebSocketKey))
...@@ -175,18 +372,3 @@ func secWebSocketAccept(secWebSocketKey string) string { ...@@ -175,18 +372,3 @@ func secWebSocketAccept(secWebSocketKey string) string {
return base64.StdEncoding.EncodeToString(h.Sum(nil)) return base64.StdEncoding.EncodeToString(h.Sum(nil))
} }
func authenticateOrigin(r *http.Request) error {
origin := r.Header.Get("Origin")
if origin == "" {
return nil
}
u, err := url.Parse(origin)
if err != nil {
return xerrors.Errorf("failed to parse Origin header %q: %w", origin, err)
}
if strings.EqualFold(u.Host, r.Host) {
return nil
}
return xerrors.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host)
}
//go:build !js
// +build !js
package websocket package websocket
import ( import (
"bufio"
"errors"
"net"
"net/http"
"net/http/httptest" "net/http/httptest"
"strings" "strings"
"sync"
"testing" "testing"
"github.com/coder/websocket/internal/test/assert"
"github.com/coder/websocket/internal/test/xrand"
) )
func TestAccept(t *testing.T) {
t.Parallel()
t.Run("badClientHandshake", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
_, err := Accept(w, r, nil)
assert.Contains(t, err, "protocol violation")
})
t.Run("badOrigin", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
r.Header.Set("Origin", "harhar.com")
_, err := Accept(w, r, nil)
assert.Contains(t, err, `request Origin "harhar.com" is not a valid URL with a host`)
})
// #247
t.Run("unauthorizedOriginErrorMessage", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
r.Header.Set("Origin", "https://harhar.com")
_, err := Accept(w, r, nil)
assert.Contains(t, err, `request Origin "harhar.com" is not authorized for Host "example.com"`)
})
t.Run("badCompression", func(t *testing.T) {
t.Parallel()
newRequest := func(extensions string) *http.Request {
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
r.Header.Set("Sec-WebSocket-Extensions", extensions)
return r
}
errHijack := errors.New("hijack error")
newResponseWriter := func() http.ResponseWriter {
return mockHijacker{
ResponseWriter: httptest.NewRecorder(),
hijack: func() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, errHijack
},
}
}
t.Run("withoutFallback", func(t *testing.T) {
t.Parallel()
w := newResponseWriter()
r := newRequest("permessage-deflate; harharhar")
_, err := Accept(w, r, &AcceptOptions{
CompressionMode: CompressionNoContextTakeover,
})
assert.ErrorIs(t, errHijack, err)
assert.Equal(t, "extension header", w.Header().Get("Sec-WebSocket-Extensions"), "")
})
t.Run("withFallback", func(t *testing.T) {
t.Parallel()
w := newResponseWriter()
r := newRequest("permessage-deflate; harharhar, permessage-deflate")
_, err := Accept(w, r, &AcceptOptions{
CompressionMode: CompressionNoContextTakeover,
})
assert.ErrorIs(t, errHijack, err)
assert.Equal(t, "extension header",
w.Header().Get("Sec-WebSocket-Extensions"),
CompressionNoContextTakeover.opts().String(),
)
})
})
t.Run("requireHttpHijacker", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
_, err := Accept(w, r, nil)
assert.Contains(t, err, `http.ResponseWriter does not implement http.Hijacker`)
})
t.Run("badHijack", func(t *testing.T) {
t.Parallel()
w := mockHijacker{
ResponseWriter: httptest.NewRecorder(),
hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) {
return nil, nil, errors.New("haha")
},
}
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
_, err := Accept(w, r, nil)
assert.Contains(t, err, `failed to hijack connection`)
})
t.Run("wrapperHijackerIsUnwrapped", func(t *testing.T) {
t.Parallel()
rr := httptest.NewRecorder()
w := mockUnwrapper{
ResponseWriter: rr,
unwrap: func() http.ResponseWriter {
return mockHijacker{
ResponseWriter: rr,
hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) {
return nil, nil, errors.New("haha")
},
}
},
}
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
_, err := Accept(w, r, nil)
assert.Contains(t, err, "failed to hijack connection")
})
t.Run("closeRace", func(t *testing.T) {
t.Parallel()
server, _ := net.Pipe()
rw := bufio.NewReadWriter(bufio.NewReader(server), bufio.NewWriter(server))
newResponseWriter := func() http.ResponseWriter {
return mockHijacker{
ResponseWriter: httptest.NewRecorder(),
hijack: func() (net.Conn, *bufio.ReadWriter, error) {
return server, rw, nil
},
}
}
w := newResponseWriter()
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", xrand.Base64(16))
c, err := Accept(w, r, nil)
wg := &sync.WaitGroup{}
wg.Add(2)
go func() {
c.Close(StatusInternalError, "the sky is falling")
wg.Done()
}()
go func() {
c.CloseNow()
wg.Done()
}()
wg.Wait()
assert.Success(t, err)
})
}
func Test_verifyClientHandshake(t *testing.T) { func Test_verifyClientHandshake(t *testing.T) {
t.Parallel() t.Parallel()
testCases := []struct { testCases := []struct {
name string name string
method string method string
http1 bool
h map[string]string h map[string]string
success bool success bool
}{ }{
...@@ -45,7 +248,15 @@ func Test_verifyClientHandshake(t *testing.T) { ...@@ -45,7 +248,15 @@ func Test_verifyClientHandshake(t *testing.T) {
}, },
}, },
{ {
name: "badWebSocketKey", name: "missingWebSocketKey",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
},
},
{
name: "emptyWebSocketKey",
h: map[string]string{ h: map[string]string{
"Connection": "Upgrade", "Connection": "Upgrade",
"Upgrade": "websocket", "Upgrade": "websocket",
...@@ -54,12 +265,62 @@ func Test_verifyClientHandshake(t *testing.T) { ...@@ -54,12 +265,62 @@ func Test_verifyClientHandshake(t *testing.T) {
}, },
}, },
{ {
name: "success", name: "shortWebSocketKey",
h: map[string]string{ h: map[string]string{
"Connection": "Upgrade", "Connection": "Upgrade",
"Upgrade": "websocket", "Upgrade": "websocket",
"Sec-WebSocket-Version": "13", "Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "meow123", "Sec-WebSocket-Key": xrand.Base64(15),
},
},
{
name: "invalidWebSocketKey",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": "notbase64",
},
},
{
name: "extraWebSocketKey",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
// Kinda cheeky, but http headers are case-insensitive.
// If 2 sec keys are present, this is a failure condition.
"Sec-WebSocket-Key": xrand.Base64(16),
"sec-webSocket-key": xrand.Base64(16),
},
},
{
name: "badHTTPVersion",
h: map[string]string{
"Connection": "Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": xrand.Base64(16),
},
http1: true,
},
{
name: "success",
h: map[string]string{
"Connection": "keep-alive, Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": xrand.Base64(16),
},
success: true,
},
{
name: "successSecKeyExtraSpace",
h: map[string]string{
"Connection": "keep-alive, Upgrade",
"Upgrade": "websocket",
"Sec-WebSocket-Version": "13",
"Sec-WebSocket-Key": " " + xrand.Base64(16) + " ",
}, },
success: true, success: true,
}, },
...@@ -70,16 +331,23 @@ func Test_verifyClientHandshake(t *testing.T) { ...@@ -70,16 +331,23 @@ func Test_verifyClientHandshake(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()
w := httptest.NewRecorder()
r := httptest.NewRequest(tc.method, "/", nil) r := httptest.NewRequest(tc.method, "/", nil)
r.ProtoMajor = 1
r.ProtoMinor = 1
if tc.http1 {
r.ProtoMinor = 0
}
for k, v := range tc.h { for k, v := range tc.h {
r.Header.Set(k, v) r.Header.Add(k, v)
} }
err := verifyClientRequest(w, r) _, err := verifyClientRequest(httptest.NewRecorder(), r)
if (err == nil) != tc.success { if tc.success {
t.Fatalf("unexpected error value: %+v", err) assert.Success(t, err)
} else {
assert.Error(t, err)
} }
}) })
} }
...@@ -118,6 +386,12 @@ func Test_selectSubprotocol(t *testing.T) { ...@@ -118,6 +386,12 @@ func Test_selectSubprotocol(t *testing.T) {
serverProtocols: []string{"echo2", "echo3"}, serverProtocols: []string{"echo2", "echo3"},
negotiated: "echo3", negotiated: "echo3",
}, },
{
name: "clientCasePresered",
clientProtocols: []string{"Echo1"},
serverProtocols: []string{"echo1"},
negotiated: "Echo1",
},
} }
for _, tc := range testCases { for _, tc := range testCases {
...@@ -129,9 +403,7 @@ func Test_selectSubprotocol(t *testing.T) { ...@@ -129,9 +403,7 @@ func Test_selectSubprotocol(t *testing.T) {
r.Header.Set("Sec-WebSocket-Protocol", strings.Join(tc.clientProtocols, ",")) r.Header.Set("Sec-WebSocket-Protocol", strings.Join(tc.clientProtocols, ","))
negotiated := selectSubprotocol(r, tc.serverProtocols) negotiated := selectSubprotocol(r, tc.serverProtocols)
if tc.negotiated != negotiated { assert.Equal(t, "negotiated", tc.negotiated, negotiated)
t.Fatalf("expected %q but got %q", tc.negotiated, negotiated)
}
}) })
} }
} }
...@@ -140,10 +412,11 @@ func Test_authenticateOrigin(t *testing.T) { ...@@ -140,10 +412,11 @@ func Test_authenticateOrigin(t *testing.T) {
t.Parallel() t.Parallel()
testCases := []struct { testCases := []struct {
name string name string
origin string origin string
host string host string
success bool originPatterns []string
success bool
}{ }{
{ {
name: "none", name: "none",
...@@ -174,6 +447,26 @@ func Test_authenticateOrigin(t *testing.T) { ...@@ -174,6 +447,26 @@ func Test_authenticateOrigin(t *testing.T) {
host: "example.com", host: "example.com",
success: true, success: true,
}, },
{
name: "originPatterns",
origin: "https://two.examplE.com",
host: "example.com",
originPatterns: []string{
"*.example.com",
"bar.com",
},
success: true,
},
{
name: "originPatternsUnauthorized",
origin: "https://two.examplE.com",
host: "example.com",
originPatterns: []string{
"exam3.com",
"bar.com",
},
success: false,
},
} }
for _, tc := range testCases { for _, tc := range testCases {
...@@ -184,10 +477,98 @@ func Test_authenticateOrigin(t *testing.T) { ...@@ -184,10 +477,98 @@ func Test_authenticateOrigin(t *testing.T) {
r := httptest.NewRequest("GET", "http://"+tc.host+"/", nil) r := httptest.NewRequest("GET", "http://"+tc.host+"/", nil)
r.Header.Set("Origin", tc.origin) r.Header.Set("Origin", tc.origin)
err := authenticateOrigin(r) err := authenticateOrigin(r, tc.originPatterns)
if (err == nil) != tc.success { if tc.success {
t.Fatalf("unexpected error value: %+v", err) assert.Success(t, err)
} else {
assert.Error(t, err)
} }
}) })
} }
} }
func Test_selectDeflate(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
mode CompressionMode
header string
expCopts *compressionOptions
expOK bool
}{
{
name: "disabled",
mode: CompressionDisabled,
expCopts: nil,
expOK: false,
},
{
name: "noClientSupport",
mode: CompressionNoContextTakeover,
expCopts: nil,
expOK: false,
},
{
name: "permessage-deflate",
mode: CompressionNoContextTakeover,
header: "permessage-deflate; client_max_window_bits",
expCopts: &compressionOptions{
clientNoContextTakeover: true,
serverNoContextTakeover: true,
},
expOK: true,
},
{
name: "permessage-deflate/unknown-parameter",
mode: CompressionNoContextTakeover,
header: "permessage-deflate; meow",
expOK: false,
},
{
name: "permessage-deflate/unknown-parameter",
mode: CompressionNoContextTakeover,
header: "permessage-deflate; meow, permessage-deflate; client_max_window_bits",
expCopts: &compressionOptions{
clientNoContextTakeover: true,
serverNoContextTakeover: true,
},
expOK: true,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
h := http.Header{}
h.Set("Sec-WebSocket-Extensions", tc.header)
copts, ok := selectDeflate(websocketExtensions(h), tc.mode)
assert.Equal(t, "selected options", tc.expOK, ok)
assert.Equal(t, "compression options", tc.expCopts, copts)
})
}
}
type mockHijacker struct {
http.ResponseWriter
hijack func() (net.Conn, *bufio.ReadWriter, error)
}
var _ http.Hijacker = mockHijacker{}
func (mj mockHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return mj.hijack()
}
type mockUnwrapper struct {
http.ResponseWriter
unwrap func() http.ResponseWriter
}
var _ rwUnwrapper = mockUnwrapper{}
func (mu mockUnwrapper) Unwrap() http.ResponseWriter {
return mu.unwrap()
}
//go:build !js
// +build !js
package websocket_test
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"os"
"os/exec"
"strconv"
"strings"
"testing"
"time"
"github.com/coder/websocket"
"github.com/coder/websocket/internal/errd"
"github.com/coder/websocket/internal/test/assert"
"github.com/coder/websocket/internal/test/wstest"
"github.com/coder/websocket/internal/util"
)
var excludedAutobahnCases = []string{
// We skip the UTF-8 handling tests as there isn't any reason to reject invalid UTF-8, just
// more performance overhead.
"6.*", "7.5.1",
// We skip the tests related to requestMaxWindowBits as that is unimplemented due
// to limitations in compress/flate. See https://github.com/golang/go/issues/3155
"13.3.*", "13.4.*", "13.5.*", "13.6.*",
}
var autobahnCases = []string{"*"}
// Used to run individual test cases. autobahnCases runs only those cases matched
// and not excluded by excludedAutobahnCases. Adding cases here means excludedAutobahnCases
// is niled.
var onlyAutobahnCases = []string{}
func TestAutobahn(t *testing.T) {
t.Parallel()
if os.Getenv("AUTOBAHN") == "" {
t.SkipNow()
}
if os.Getenv("AUTOBAHN") == "fast" {
// These are the slow tests.
excludedAutobahnCases = append(excludedAutobahnCases,
"9.*", "12.*", "13.*",
)
}
if len(onlyAutobahnCases) > 0 {
excludedAutobahnCases = []string{}
autobahnCases = onlyAutobahnCases
}
ctx, cancel := context.WithTimeout(context.Background(), time.Hour)
defer cancel()
wstestURL, closeFn, err := wstestServer(t, ctx)
assert.Success(t, err)
defer func() {
assert.Success(t, closeFn())
}()
err = waitWS(ctx, wstestURL)
assert.Success(t, err)
cases, err := wstestCaseCount(ctx, wstestURL)
assert.Success(t, err)
t.Run("cases", func(t *testing.T) {
for i := 1; i <= cases; i++ {
i := i
t.Run("", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
defer cancel()
c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/runCase?case=%v&agent=main", i), &websocket.DialOptions{
CompressionMode: websocket.CompressionContextTakeover,
})
assert.Success(t, err)
err = wstest.EchoLoop(ctx, c)
t.Logf("echoLoop: %v", err)
})
}
})
c, _, err := websocket.Dial(ctx, wstestURL+"/updateReports?agent=main", nil)
assert.Success(t, err)
c.Close(websocket.StatusNormalClosure, "")
checkWSTestIndex(t, "./ci/out/autobahn-report/index.json")
}
func waitWS(ctx context.Context, url string) error {
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
for ctx.Err() == nil {
c, _, err := websocket.Dial(ctx, url, nil)
if err != nil {
continue
}
c.Close(websocket.StatusNormalClosure, "")
return nil
}
return ctx.Err()
}
func wstestServer(tb testing.TB, ctx context.Context) (url string, closeFn func() error, err error) {
defer errd.Wrap(&err, "failed to start autobahn wstest server")
serverAddr, err := unusedListenAddr()
if err != nil {
return "", nil, err
}
_, serverPort, err := net.SplitHostPort(serverAddr)
if err != nil {
return "", nil, err
}
url = "ws://" + serverAddr
const outDir = "ci/out/autobahn-report"
specFile, err := tempJSONFile(map[string]interface{}{
"url": url,
"outdir": outDir,
"cases": autobahnCases,
"exclude-cases": excludedAutobahnCases,
})
if err != nil {
return "", nil, fmt.Errorf("failed to write spec: %w", err)
}
ctx, cancel := context.WithTimeout(ctx, time.Hour)
defer func() {
if err != nil {
cancel()
}
}()
dockerPull := exec.CommandContext(ctx, "docker", "pull", "crossbario/autobahn-testsuite")
dockerPull.Stdout = util.WriterFunc(func(p []byte) (int, error) {
tb.Log(string(p))
return len(p), nil
})
dockerPull.Stderr = util.WriterFunc(func(p []byte) (int, error) {
tb.Log(string(p))
return len(p), nil
})
tb.Log(dockerPull)
err = dockerPull.Run()
if err != nil {
return "", nil, fmt.Errorf("failed to pull docker image: %w", err)
}
wd, err := os.Getwd()
if err != nil {
return "", nil, err
}
var args []string
args = append(args, "run", "-i", "--rm",
"-v", fmt.Sprintf("%s:%[1]s", specFile),
"-v", fmt.Sprintf("%s/ci:/ci", wd),
fmt.Sprintf("-p=%s:%s", serverAddr, serverPort),
"crossbario/autobahn-testsuite",
)
args = append(args, "wstest", "--mode", "fuzzingserver", "--spec", specFile,
// Disables some server that runs as part of fuzzingserver mode.
// See https://github.com/crossbario/autobahn-testsuite/blob/058db3a36b7c3a1edf68c282307c6b899ca4857f/autobahntestsuite/autobahntestsuite/wstest.py#L124
"--webport=0",
)
wstest := exec.CommandContext(ctx, "docker", args...)
wstest.Stdout = util.WriterFunc(func(p []byte) (int, error) {
tb.Log(string(p))
return len(p), nil
})
wstest.Stderr = util.WriterFunc(func(p []byte) (int, error) {
tb.Log(string(p))
return len(p), nil
})
tb.Log(wstest)
err = wstest.Start()
if err != nil {
return "", nil, fmt.Errorf("failed to start wstest: %w", err)
}
return url, func() error {
err = wstest.Process.Kill()
if err != nil {
return fmt.Errorf("failed to kill wstest: %w", err)
}
err = wstest.Wait()
var ee *exec.ExitError
if errors.As(err, &ee) && ee.ExitCode() == -1 {
return nil
}
return err
}, nil
}
func wstestCaseCount(ctx context.Context, url string) (cases int, err error) {
defer errd.Wrap(&err, "failed to get case count")
c, _, err := websocket.Dial(ctx, url+"/getCaseCount", nil)
if err != nil {
return 0, err
}
defer c.Close(websocket.StatusInternalError, "")
_, r, err := c.Reader(ctx)
if err != nil {
return 0, err
}
b, err := io.ReadAll(r)
if err != nil {
return 0, err
}
cases, err = strconv.Atoi(string(b))
if err != nil {
return 0, err
}
c.Close(websocket.StatusNormalClosure, "")
return cases, nil
}
func checkWSTestIndex(t *testing.T, path string) {
wstestOut, err := os.ReadFile(path)
assert.Success(t, err)
var indexJSON map[string]map[string]struct {
Behavior string `json:"behavior"`
BehaviorClose string `json:"behaviorClose"`
}
err = json.Unmarshal(wstestOut, &indexJSON)
assert.Success(t, err)
for _, tests := range indexJSON {
for test, result := range tests {
t.Run(test, func(t *testing.T) {
switch result.BehaviorClose {
case "OK", "INFORMATIONAL":
default:
t.Errorf("bad close behaviour")
}
switch result.Behavior {
case "OK", "NON-STRICT", "INFORMATIONAL":
default:
t.Errorf("failed")
}
})
}
}
if t.Failed() {
htmlPath := strings.Replace(path, ".json", ".html", 1)
t.Errorf("detected autobahn violation, see %q", htmlPath)
}
}
func unusedListenAddr() (_ string, err error) {
defer errd.Wrap(&err, "failed to get unused listen address")
l, err := net.Listen("tcp", "localhost:0")
if err != nil {
return "", err
}
l.Close()
return l.Addr().String(), nil
}
func tempJSONFile(v interface{}) (string, error) {
f, err := os.CreateTemp("", "temp.json")
if err != nil {
return "", fmt.Errorf("temp file: %w", err)
}
defer f.Close()
e := json.NewEncoder(f)
e.SetIndent("", "\t")
err = e.Encode(v)
if err != nil {
return "", fmt.Errorf("json encode: %w", err)
}
err = f.Close()
if err != nil {
return "", fmt.Errorf("close temp file: %w", err)
}
return f.Name(), nil
}
comment: off
coverage:
status:
# Prevent small changes in coverage from failing CI.
project:
default:
threshold: 5
patch:
default:
threshold: 5
#!/usr/bin/env bash #!/bin/sh
set -eu
cd -- "$(dirname "$0")/.."
set -euo pipefail go test --run=^$ --bench=. --benchmem "$@" ./...
cd "$(dirname "${0}")" # For profiling add: --memprofile ci/out/prof.mem --cpuprofile ci/out/prof.cpu -o ci/out/websocket.test
source ./lib.sh (
cd ./internal/thirdparty
go test --run=^$ --bench=. --benchmem "$@" .
go test --vet=off --run=^$ -bench=. -o=ci/out/websocket.test \ GOARCH=arm64 go test -c -o ../../ci/out/thirdparty-arm64.test "$@" .
-cpuprofile=ci/out/cpu.prof \ if [ "$#" -eq 0 ]; then
-memprofile=ci/out/mem.prof \ if [ "${CI-}" ]; then
-blockprofile=ci/out/block.prof \ sudo apt-get update
-mutexprofile=ci/out/mutex.prof \ sudo apt-get install -y qemu-user-static
. ln -s /usr/bin/qemu-aarch64-static /usr/local/bin/qemu-aarch64
fi
echo qemu-aarch64 ../../ci/out/thirdparty-arm64.test --test.run=^$ --test.bench=Benchmark_mask --test.benchmem
echo "Profiles are in ./ci/out/*.prof fi
Keep in mind that every profiler Go provides is enabled so that may skew the benchmarks." )
#!/usr/bin/env bash #!/bin/sh
set -eu
cd -- "$(dirname "$0")/.."
set -euo pipefail X_TOOLS_VERSION=v0.31.0
cd "$(dirname "${0}")"
source ./lib.sh
unstaged_files() { go mod tidy
git ls-files --other --modified --exclude-standard (cd ./internal/thirdparty && go mod tidy)
} (cd ./internal/examples && go mod tidy)
gofmt -w -s .
go run golang.org/x/tools/cmd/goimports@${X_TOOLS_VERSION} -w "-local=$(go list -m)" .
gen() { git ls-files "*.yml" "*.md" "*.js" "*.css" "*.html" | xargs npx prettier@3.3.3 \
# Unfortunately, this is the only way to ensure go.mod and go.sum are correct. --check \
# See https://github.com/golang/go/issues/27005 --log-level=warn \
go list ./... > /dev/null --print-width=90 \
go mod tidy --no-semi \
--single-quote \
--arrow-parens=avoid
go generate ./... go run golang.org/x/tools/cmd/stringer@${X_TOOLS_VERSION} -type=opcode,MessageType,StatusCode -output=stringer.go
}
fmt() { if [ "${CI-}" ]; then
gofmt -w -s . git diff --exit-code
go run go.coder.com/go-tools/cmd/goimports -w "-local=$(go list -m)" .
go run mvdan.cc/sh/cmd/shfmt -i 2 -w -s -sr .
}
gen
fmt
if [[ $CI && $(unstaged_files) != "" ]]; then
echo
echo "Files either need generation or are formatted incorrectly."
echo "Please run:"
echo "./ci/fmt.sh"
echo
git status
exit 1
fi fi
#!/usr/bin/env bash
set -euo pipefail
# Ensures $CI can be used if it's set or not.
export CI=${CI:-}
if [[ $CI ]]; then
export GOFLAGS=-mod=readonly
export DEBIAN_FRONTEND=noninteractive
fi
cd "$(git rev-parse --show-toplevel)"
#!/usr/bin/env bash #!/bin/sh
set -eu
cd -- "$(dirname "$0")/.."
set -euo pipefail STATICCHECK_VERSION=v0.6.1
cd "$(dirname "${0}")" GOVULNCHECK_VERSION=v1.1.4
source ./lib.sh
if [[ $CI ]]; then
apt-get update -qq
apt-get install -qq shellcheck > /dev/null
fi
# shellcheck disable=SC2046
shellcheck -e SC1091 -x $(git ls-files "*.sh")
go vet ./... go vet ./...
go run golang.org/x/lint/golint -set_exit_status ./... GOOS=js GOARCH=wasm go vet ./...
go install honnef.co/go/tools/cmd/staticcheck@${STATICCHECK_VERSION}
staticcheck ./...
GOOS=js GOARCH=wasm staticcheck ./...
govulncheck() {
tmpf=$(mktemp)
if ! command govulncheck "$@" >"$tmpf" 2>&1; then
cat "$tmpf"
fi
}
go install golang.org/x/vuln/cmd/govulncheck@${GOVULNCHECK_VERSION}
govulncheck ./...
GOOS=js GOARCH=wasm govulncheck ./...
(
cd ./internal/examples
go vet ./...
staticcheck ./...
govulncheck ./...
)
(
cd ./internal/thirdparty
go vet ./...
staticcheck ./...
govulncheck ./...
)
#!/usr/bin/env bash
# This script is for local testing. See .circleci for CI.
set -euo pipefail
cd "$(dirname "${0}")"
source ./lib.sh
./fmt.sh
./lint.sh
./test.sh
#!/usr/bin/env bash #!/bin/sh
set -eu
cd -- "$(dirname "$0")/.."
set -euo pipefail (
cd "$(dirname "${0}")" cd ./internal/examples
source ./lib.sh go test "$@" ./...
)
(
cd ./internal/thirdparty
go test "$@" ./...
)
echo "This step includes benchmarks for race detection and coverage purposes (
but the numbers will be misleading. please see the bench step or ./bench.sh for GOARCH=arm64 go test -c -o ./ci/out/websocket-arm64.test "$@" .
more accurate numbers." if [ "$#" -eq 0 ]; then
echo if [ "${CI-}" ]; then
sudo apt-get update
sudo apt-get install -y qemu-user-static
ln -s /usr/bin/qemu-aarch64-static /usr/local/bin/qemu-aarch64
fi
qemu-aarch64 ./ci/out/websocket-arm64.test -test.run=TestMask
fi
)
if [[ $CI ]]; then
apt-get update -qq
apt-get install -qq python-pip > /dev/null
# Need to add pip install directory to $PATH.
export PATH="/home/circleci/.local/bin:$PATH"
pip install -qqq autobahntestsuite
fi
go test -race -coverprofile=ci/out/coverage.prof --vet=off -bench=. -coverpkg=./... ./... go install github.com/agnivade/wasmbrowsertest@8be019f6c6dceae821467b4c589eb195c2b761ce
go tool cover -func=ci/out/coverage.prof go test --race --bench=. --timeout=1h --covermode=atomic --coverprofile=ci/out/coverage.prof --coverpkg=./... "$@" ./...
sed -i.bak '/stringer\.go/d' ci/out/coverage.prof
sed -i.bak '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof
sed -i.bak '/examples/d' ci/out/coverage.prof
if [[ $CI ]]; then # Last line is the total coverage.
bash <(curl -s https://codecov.io/bash) -f ci/out/coverage.prof go tool cover -func ci/out/coverage.prof | tail -n1
else
go tool cover -html=ci/out/coverage.prof -o=ci/out/coverage.html
echo go tool cover -html=ci/out/coverage.prof -o=ci/out/coverage.html
echo "Please open ci/out/coverage.html to see detailed test coverage stats."
fi
//go:build !js
// +build !js
package websocket
import (
"context"
"encoding/binary"
"errors"
"fmt"
"net"
"time"
"github.com/coder/websocket/internal/errd"
)
// StatusCode represents a WebSocket status code.
// https://tools.ietf.org/html/rfc6455#section-7.4
type StatusCode int
// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
//
// These are only the status codes defined by the protocol.
//
// You can define custom codes in the 3000-4999 range.
// The 3000-3999 range is reserved for use by libraries, frameworks and applications.
// The 4000-4999 range is reserved for private use.
const (
StatusNormalClosure StatusCode = 1000
StatusGoingAway StatusCode = 1001
StatusProtocolError StatusCode = 1002
StatusUnsupportedData StatusCode = 1003
// 1004 is reserved and so unexported.
statusReserved StatusCode = 1004
// StatusNoStatusRcvd cannot be sent in a close message.
// It is reserved for when a close message is received without
// a status code.
StatusNoStatusRcvd StatusCode = 1005
// StatusAbnormalClosure is exported for use only with Wasm.
// In non Wasm Go, the returned error will indicate whether the
// connection was closed abnormally.
StatusAbnormalClosure StatusCode = 1006
StatusInvalidFramePayloadData StatusCode = 1007
StatusPolicyViolation StatusCode = 1008
StatusMessageTooBig StatusCode = 1009
StatusMandatoryExtension StatusCode = 1010
StatusInternalError StatusCode = 1011
StatusServiceRestart StatusCode = 1012
StatusTryAgainLater StatusCode = 1013
StatusBadGateway StatusCode = 1014
// StatusTLSHandshake is only exported for use with Wasm.
// In non Wasm Go, the returned error will indicate whether there was
// a TLS handshake failure.
StatusTLSHandshake StatusCode = 1015
)
// CloseError is returned when the connection is closed with a status and reason.
//
// Use Go 1.13's errors.As to check for this error.
// Also see the CloseStatus helper.
type CloseError struct {
Code StatusCode
Reason string
}
func (ce CloseError) Error() string {
return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason)
}
// CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab
// the status code from a CloseError.
//
// -1 will be returned if the passed error is nil or not a CloseError.
func CloseStatus(err error) StatusCode {
var ce CloseError
if errors.As(err, &ce) {
return ce.Code
}
return -1
}
// Close performs the WebSocket close handshake with the given status code and reason.
//
// It will write a WebSocket close frame with a timeout of 5s and then wait 5s for
// the peer to send a close frame.
// All data messages received from the peer during the close handshake will be discarded.
//
// The connection can only be closed once. Additional calls to Close
// are no-ops.
//
// The maximum length of reason must be 125 bytes. Avoid sending a dynamic reason.
//
// Close will unblock all goroutines interacting with the connection once
// complete.
func (c *Conn) Close(code StatusCode, reason string) (err error) {
defer errd.Wrap(&err, "failed to close WebSocket")
if c.casClosing() {
err = c.waitGoroutines()
if err != nil {
return err
}
return net.ErrClosed
}
defer func() {
if errors.Is(err, net.ErrClosed) {
err = nil
}
}()
err = c.closeHandshake(code, reason)
err2 := c.close()
if err == nil && err2 != nil {
err = err2
}
err2 = c.waitGoroutines()
if err == nil && err2 != nil {
err = err2
}
return err
}
// CloseNow closes the WebSocket connection without attempting a close handshake.
// Use when you do not want the overhead of the close handshake.
func (c *Conn) CloseNow() (err error) {
defer errd.Wrap(&err, "failed to immediately close WebSocket")
if c.casClosing() {
err = c.waitGoroutines()
if err != nil {
return err
}
return net.ErrClosed
}
defer func() {
if errors.Is(err, net.ErrClosed) {
err = nil
}
}()
err = c.close()
err2 := c.waitGoroutines()
if err == nil && err2 != nil {
err = err2
}
return err
}
func (c *Conn) closeHandshake(code StatusCode, reason string) error {
err := c.writeClose(code, reason)
if err != nil {
return err
}
err = c.waitCloseHandshake()
if CloseStatus(err) != code {
return err
}
return nil
}
func (c *Conn) writeClose(code StatusCode, reason string) error {
ce := CloseError{
Code: code,
Reason: reason,
}
var p []byte
var err error
if ce.Code != StatusNoStatusRcvd {
p, err = ce.bytes()
if err != nil {
return err
}
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
err = c.writeControl(ctx, opClose, p)
// If the connection closed as we're writing we ignore the error as we might
// have written the close frame, the peer responded and then someone else read it
// and closed the connection.
if err != nil && !errors.Is(err, net.ErrClosed) {
return err
}
return nil
}
func (c *Conn) waitCloseHandshake() error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
err := c.readMu.lock(ctx)
if err != nil {
return err
}
defer c.readMu.unlock()
for i := int64(0); i < c.msgReader.payloadLength; i++ {
_, err := c.br.ReadByte()
if err != nil {
return err
}
}
for {
h, err := c.readLoop(ctx)
if err != nil {
return err
}
for i := int64(0); i < h.payloadLength; i++ {
_, err := c.br.ReadByte()
if err != nil {
return err
}
}
}
}
func (c *Conn) waitGoroutines() error {
t := time.NewTimer(time.Second * 15)
defer t.Stop()
select {
case <-c.timeoutLoopDone:
case <-t.C:
return errors.New("failed to wait for timeoutLoop goroutine to exit")
}
c.closeReadMu.Lock()
closeRead := c.closeReadCtx != nil
c.closeReadMu.Unlock()
if closeRead {
select {
case <-c.closeReadDone:
case <-t.C:
return errors.New("failed to wait for close read goroutine to exit")
}
}
select {
case <-c.closed:
case <-t.C:
return errors.New("failed to wait for connection to be closed")
}
return nil
}
func parseClosePayload(p []byte) (CloseError, error) {
if len(p) == 0 {
return CloseError{
Code: StatusNoStatusRcvd,
}, nil
}
if len(p) < 2 {
return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p)
}
ce := CloseError{
Code: StatusCode(binary.BigEndian.Uint16(p)),
Reason: string(p[2:]),
}
if !validWireCloseCode(ce.Code) {
return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code)
}
return ce, nil
}
// See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
// and https://tools.ietf.org/html/rfc6455#section-7.4.1
func validWireCloseCode(code StatusCode) bool {
switch code {
case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake:
return false
}
if code >= StatusNormalClosure && code <= StatusBadGateway {
return true
}
if code >= 3000 && code <= 4999 {
return true
}
return false
}
func (ce CloseError) bytes() ([]byte, error) {
p, err := ce.bytesErr()
if err != nil {
err = fmt.Errorf("failed to marshal close frame: %w", err)
ce = CloseError{
Code: StatusInternalError,
}
p, _ = ce.bytesErr()
}
return p, err
}
const maxCloseReason = maxControlPayload - 2
func (ce CloseError) bytesErr() ([]byte, error) {
if len(ce.Reason) > maxCloseReason {
return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason))
}
if !validWireCloseCode(ce.Code) {
return nil, fmt.Errorf("status code %v cannot be set", ce.Code)
}
buf := make([]byte, 2+len(ce.Reason))
binary.BigEndian.PutUint16(buf, uint16(ce.Code))
copy(buf[2:], ce.Reason)
return buf, nil
}
func (c *Conn) casClosing() bool {
return c.closing.Swap(true)
}
func (c *Conn) isClosed() bool {
select {
case <-c.closed:
return true
default:
return false
}
}
//go:build !js
// +build !js
package websocket
import (
"io"
"math"
"strings"
"testing"
"github.com/coder/websocket/internal/test/assert"
)
func TestCloseError(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
ce CloseError
success bool
}{
{
name: "normal",
ce: CloseError{
Code: StatusNormalClosure,
Reason: strings.Repeat("x", maxCloseReason),
},
success: true,
},
{
name: "bigReason",
ce: CloseError{
Code: StatusNormalClosure,
Reason: strings.Repeat("x", maxCloseReason+1),
},
success: false,
},
{
name: "bigCode",
ce: CloseError{
Code: math.MaxUint16,
Reason: strings.Repeat("x", maxCloseReason),
},
success: false,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
_, err := tc.ce.bytesErr()
if tc.success {
assert.Success(t, err)
} else {
assert.Error(t, err)
}
})
}
t.Run("Error", func(t *testing.T) {
exp := `status = StatusInternalError and reason = "meow"`
act := CloseError{
Code: StatusInternalError,
Reason: "meow",
}.Error()
assert.Equal(t, "CloseError.Error()", exp, act)
})
}
func Test_parseClosePayload(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
p []byte
success bool
ce CloseError
}{
{
name: "normal",
p: append([]byte{0x3, 0xE8}, []byte("hello")...),
success: true,
ce: CloseError{
Code: StatusNormalClosure,
Reason: "hello",
},
},
{
name: "nothing",
success: true,
ce: CloseError{
Code: StatusNoStatusRcvd,
},
},
{
name: "oneByte",
p: []byte{0},
success: false,
},
{
name: "badStatusCode",
p: []byte{0x17, 0x70},
success: false,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ce, err := parseClosePayload(tc.p)
if tc.success {
assert.Success(t, err)
assert.Equal(t, "close payload", tc.ce, ce)
} else {
assert.Error(t, err)
}
})
}
}
func Test_validWireCloseCode(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
code StatusCode
valid bool
}{
{
name: "normal",
code: StatusNormalClosure,
valid: true,
},
{
name: "noStatus",
code: StatusNoStatusRcvd,
valid: false,
},
{
name: "3000",
code: 3000,
valid: true,
},
{
name: "4999",
code: 4999,
valid: true,
},
{
name: "unknown",
code: 5000,
valid: false,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
act := validWireCloseCode(tc.code)
assert.Equal(t, "wire close code", tc.valid, act)
})
}
}
func TestCloseStatus(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
in error
exp StatusCode
}{
{
name: "nil",
in: nil,
exp: -1,
},
{
name: "io.EOF",
in: io.EOF,
exp: -1,
},
{
name: "StatusInternalError",
in: CloseError{
Code: StatusInternalError,
},
exp: StatusInternalError,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
act := CloseStatus(tc.in)
assert.Equal(t, "close status", tc.exp, act)
})
}
}