good morning!!!!

Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • github/nhooyr/websocket
  • open/websocket
2 results
Show changes
FROM golang:1.12
LABEL "com.github.actions.name"="lint"
LABEL "com.github.actions.description"=""
LABEL "com.github.actions.icon"="code"
LABEL "com.github.actions.color"="purple"
RUN apt update && apt install -y shellcheck
COPY entrypoint.sh /entrypoint.sh
CMD ["/entrypoint.sh"]
#!/usr/bin/env bash
source ci/lib.sh || exit 1
(
shopt -s globstar nullglob dotglob
shellcheck ./**/*.sh
)
go vet ./...
go run golang.org/x/lint/golint -set_exit_status ./...
#!/usr/bin/env bash
# This script is for local testing. See .github for CI.
cd "$(dirname "${0}")/.." || exit 1
source ci/lib.sh || exit 1
function docker_run() {
local DIR="$1"
local IMAGE
IMAGE="$(docker build -q "$DIR")"
docker run \
-it \
-v "${PWD}:/repo" \
-v "$(go env GOPATH):/go" \
-v "$(go env GOCACHE):/root/.cache/go-build" \
-w /repo \
"${IMAGE}"
}
function help() {
set +x
echo
echo "$0 [-h] <step>"
cat << EOF
If you do not pass in an explicit step, all steps will be ran in order.
Pass "analyze" as the step to be put into an interactive container to analyze
profiles.
EOF
exit 1
}
# Use this to analyze benchmark profiles.
if [[ ${1-} == "analyze" ]]; then
docker run \
-it \
-v "${PWD}:/repo" \
-v "$(go env GOPATH):/go" \
-v "$(go env GOCACHE):/root/.cache/go-build" \
-w /repo \
golang:1.12
fi
if [[ ${1-} == "-h" || ${1-} == "--help" || ${1-} == "help" ]]; then
help
fi
if [[ $# -gt 0 ]]; then
if [[ ! -d "ci/$*" ]]; then
help
fi
docker_run "ci/$*"
exit 0
fi
docker_run ci/fmt
docker_run ci/lint
docker_run ci/test
docker_run ci/bench
#!/bin/sh
set -eu
cd -- "$(dirname "$0")/.."
(
cd ./internal/examples
go test "$@" ./...
)
(
cd ./internal/thirdparty
go test "$@" ./...
)
(
GOARCH=arm64 go test -c -o ./ci/out/websocket-arm64.test "$@" .
if [ "$#" -eq 0 ]; then
if [ "${CI-}" ]; then
sudo apt-get update
sudo apt-get install -y qemu-user-static
ln -s /usr/bin/qemu-aarch64-static /usr/local/bin/qemu-aarch64
fi
qemu-aarch64 ./ci/out/websocket-arm64.test -test.run=TestMask
fi
)
go install github.com/agnivade/wasmbrowsertest@8be019f6c6dceae821467b4c589eb195c2b761ce
go test --race --bench=. --timeout=1h --covermode=atomic --coverprofile=ci/out/coverage.prof --coverpkg=./... "$@" ./...
sed -i.bak '/stringer\.go/d' ci/out/coverage.prof
sed -i.bak '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof
sed -i.bak '/examples/d' ci/out/coverage.prof
# Last line is the total coverage.
go tool cover -func ci/out/coverage.prof | tail -n1
go tool cover -html=ci/out/coverage.prof -o=ci/out/coverage.html
FROM golang:1.12
LABEL "com.github.actions.name"="test"
LABEL "com.github.actions.description"=""
LABEL "com.github.actions.icon"="code"
LABEL "com.github.actions.color"="green"
RUN apt update && \
apt install -y shellcheck python-pip && \
pip install autobahntestsuite
COPY entrypoint.sh /entrypoint.sh
CMD ["/entrypoint.sh"]
#!/usr/bin/env bash
source ci/lib.sh || exit 1
set +x
echo
echo "this step includes benchmarks for race detection and coverage purposes
but the numbers will be misleading. please see the bench step for more
accurate numbers"
echo
set -x
go test -race -coverprofile=ci/out/coverage.prof --vet=off -bench=. -coverpkg=./... ./...
go tool cover -func=ci/out/coverage.prof
if [[ $CI ]]; then
bash <(curl -s https://codecov.io/bash) -f ci/out/coverage.prof
else
go tool cover -html=ci/out/coverage.prof -o=ci/out/coverage.html
set +x
echo
echo "please open ci/out/coverage.html to see detailed test coverage stats"
fi
//go:build !js
// +build !js
package websocket
import (
"context"
"encoding/binary"
"errors"
"fmt"
"net"
"time"
"github.com/coder/websocket/internal/errd"
)
// StatusCode represents a WebSocket status code.
// https://tools.ietf.org/html/rfc6455#section-7.4
type StatusCode int
// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
//
// These are only the status codes defined by the protocol.
//
// You can define custom codes in the 3000-4999 range.
// The 3000-3999 range is reserved for use by libraries, frameworks and applications.
// The 4000-4999 range is reserved for private use.
const (
StatusNormalClosure StatusCode = 1000
StatusGoingAway StatusCode = 1001
StatusProtocolError StatusCode = 1002
StatusUnsupportedData StatusCode = 1003
// 1004 is reserved and so unexported.
statusReserved StatusCode = 1004
// StatusNoStatusRcvd cannot be sent in a close message.
// It is reserved for when a close message is received without
// a status code.
StatusNoStatusRcvd StatusCode = 1005
// StatusAbnormalClosure is exported for use only with Wasm.
// In non Wasm Go, the returned error will indicate whether the
// connection was closed abnormally.
StatusAbnormalClosure StatusCode = 1006
StatusInvalidFramePayloadData StatusCode = 1007
StatusPolicyViolation StatusCode = 1008
StatusMessageTooBig StatusCode = 1009
StatusMandatoryExtension StatusCode = 1010
StatusInternalError StatusCode = 1011
StatusServiceRestart StatusCode = 1012
StatusTryAgainLater StatusCode = 1013
StatusBadGateway StatusCode = 1014
// StatusTLSHandshake is only exported for use with Wasm.
// In non Wasm Go, the returned error will indicate whether there was
// a TLS handshake failure.
StatusTLSHandshake StatusCode = 1015
)
// CloseError is returned when the connection is closed with a status and reason.
//
// Use Go 1.13's errors.As to check for this error.
// Also see the CloseStatus helper.
type CloseError struct {
Code StatusCode
Reason string
}
func (ce CloseError) Error() string {
return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason)
}
// CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab
// the status code from a CloseError.
//
// -1 will be returned if the passed error is nil or not a CloseError.
func CloseStatus(err error) StatusCode {
var ce CloseError
if errors.As(err, &ce) {
return ce.Code
}
return -1
}
// Close performs the WebSocket close handshake with the given status code and reason.
//
// It will write a WebSocket close frame with a timeout of 5s and then wait 5s for
// the peer to send a close frame.
// All data messages received from the peer during the close handshake will be discarded.
//
// The connection can only be closed once. Additional calls to Close
// are no-ops.
//
// The maximum length of reason must be 125 bytes. Avoid sending a dynamic reason.
//
// Close will unblock all goroutines interacting with the connection once
// complete.
func (c *Conn) Close(code StatusCode, reason string) (err error) {
defer errd.Wrap(&err, "failed to close WebSocket")
if c.casClosing() {
err = c.waitGoroutines()
if err != nil {
return err
}
return net.ErrClosed
}
defer func() {
if errors.Is(err, net.ErrClosed) {
err = nil
}
}()
err = c.closeHandshake(code, reason)
err2 := c.close()
if err == nil && err2 != nil {
err = err2
}
err2 = c.waitGoroutines()
if err == nil && err2 != nil {
err = err2
}
return err
}
// CloseNow closes the WebSocket connection without attempting a close handshake.
// Use when you do not want the overhead of the close handshake.
func (c *Conn) CloseNow() (err error) {
defer errd.Wrap(&err, "failed to immediately close WebSocket")
if c.casClosing() {
err = c.waitGoroutines()
if err != nil {
return err
}
return net.ErrClosed
}
defer func() {
if errors.Is(err, net.ErrClosed) {
err = nil
}
}()
err = c.close()
err2 := c.waitGoroutines()
if err == nil && err2 != nil {
err = err2
}
return err
}
func (c *Conn) closeHandshake(code StatusCode, reason string) error {
err := c.writeClose(code, reason)
if err != nil {
return err
}
err = c.waitCloseHandshake()
if CloseStatus(err) != code {
return err
}
return nil
}
func (c *Conn) writeClose(code StatusCode, reason string) error {
ce := CloseError{
Code: code,
Reason: reason,
}
var p []byte
var err error
if ce.Code != StatusNoStatusRcvd {
p, err = ce.bytes()
if err != nil {
return err
}
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
err = c.writeControl(ctx, opClose, p)
// If the connection closed as we're writing we ignore the error as we might
// have written the close frame, the peer responded and then someone else read it
// and closed the connection.
if err != nil && !errors.Is(err, net.ErrClosed) {
return err
}
return nil
}
func (c *Conn) waitCloseHandshake() error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
err := c.readMu.lock(ctx)
if err != nil {
return err
}
defer c.readMu.unlock()
for i := int64(0); i < c.msgReader.payloadLength; i++ {
_, err := c.br.ReadByte()
if err != nil {
return err
}
}
for {
h, err := c.readLoop(ctx)
if err != nil {
return err
}
for i := int64(0); i < h.payloadLength; i++ {
_, err := c.br.ReadByte()
if err != nil {
return err
}
}
}
}
func (c *Conn) waitGoroutines() error {
t := time.NewTimer(time.Second * 15)
defer t.Stop()
select {
case <-c.timeoutLoopDone:
case <-t.C:
return errors.New("failed to wait for timeoutLoop goroutine to exit")
}
c.closeReadMu.Lock()
closeRead := c.closeReadCtx != nil
c.closeReadMu.Unlock()
if closeRead {
select {
case <-c.closeReadDone:
case <-t.C:
return errors.New("failed to wait for close read goroutine to exit")
}
}
select {
case <-c.closed:
case <-t.C:
return errors.New("failed to wait for connection to be closed")
}
return nil
}
func parseClosePayload(p []byte) (CloseError, error) {
if len(p) == 0 {
return CloseError{
Code: StatusNoStatusRcvd,
}, nil
}
if len(p) < 2 {
return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p)
}
ce := CloseError{
Code: StatusCode(binary.BigEndian.Uint16(p)),
Reason: string(p[2:]),
}
if !validWireCloseCode(ce.Code) {
return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code)
}
return ce, nil
}
// See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
// and https://tools.ietf.org/html/rfc6455#section-7.4.1
func validWireCloseCode(code StatusCode) bool {
switch code {
case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake:
return false
}
if code >= StatusNormalClosure && code <= StatusBadGateway {
return true
}
if code >= 3000 && code <= 4999 {
return true
}
return false
}
func (ce CloseError) bytes() ([]byte, error) {
p, err := ce.bytesErr()
if err != nil {
err = fmt.Errorf("failed to marshal close frame: %w", err)
ce = CloseError{
Code: StatusInternalError,
}
p, _ = ce.bytesErr()
}
return p, err
}
const maxCloseReason = maxControlPayload - 2
func (ce CloseError) bytesErr() ([]byte, error) {
if len(ce.Reason) > maxCloseReason {
return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason))
}
if !validWireCloseCode(ce.Code) {
return nil, fmt.Errorf("status code %v cannot be set", ce.Code)
}
buf := make([]byte, 2+len(ce.Reason))
binary.BigEndian.PutUint16(buf, uint16(ce.Code))
copy(buf[2:], ce.Reason)
return buf, nil
}
func (c *Conn) casClosing() bool {
return c.closing.Swap(true)
}
func (c *Conn) isClosed() bool {
select {
case <-c.closed:
return true
default:
return false
}
}
//go:build !js
// +build !js
package websocket
import (
"io"
"math"
"strings"
"testing"
"github.com/coder/websocket/internal/test/assert"
)
func TestCloseError(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
ce CloseError
success bool
}{
{
name: "normal",
ce: CloseError{
Code: StatusNormalClosure,
Reason: strings.Repeat("x", maxCloseReason),
},
success: true,
},
{
name: "bigReason",
ce: CloseError{
Code: StatusNormalClosure,
Reason: strings.Repeat("x", maxCloseReason+1),
},
success: false,
},
{
name: "bigCode",
ce: CloseError{
Code: math.MaxUint16,
Reason: strings.Repeat("x", maxCloseReason),
},
success: false,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
_, err := tc.ce.bytesErr()
if tc.success {
assert.Success(t, err)
} else {
assert.Error(t, err)
}
})
}
t.Run("Error", func(t *testing.T) {
exp := `status = StatusInternalError and reason = "meow"`
act := CloseError{
Code: StatusInternalError,
Reason: "meow",
}.Error()
assert.Equal(t, "CloseError.Error()", exp, act)
})
}
func Test_parseClosePayload(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
p []byte
success bool
ce CloseError
}{
{
name: "normal",
p: append([]byte{0x3, 0xE8}, []byte("hello")...),
success: true,
ce: CloseError{
Code: StatusNormalClosure,
Reason: "hello",
},
},
{
name: "nothing",
success: true,
ce: CloseError{
Code: StatusNoStatusRcvd,
},
},
{
name: "oneByte",
p: []byte{0},
success: false,
},
{
name: "badStatusCode",
p: []byte{0x17, 0x70},
success: false,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ce, err := parseClosePayload(tc.p)
if tc.success {
assert.Success(t, err)
assert.Equal(t, "close payload", tc.ce, ce)
} else {
assert.Error(t, err)
}
})
}
}
func Test_validWireCloseCode(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
code StatusCode
valid bool
}{
{
name: "normal",
code: StatusNormalClosure,
valid: true,
},
{
name: "noStatus",
code: StatusNoStatusRcvd,
valid: false,
},
{
name: "3000",
code: 3000,
valid: true,
},
{
name: "4999",
code: 4999,
valid: true,
},
{
name: "unknown",
code: 5000,
valid: false,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
act := validWireCloseCode(tc.code)
assert.Equal(t, "wire close code", tc.valid, act)
})
}
}
func TestCloseStatus(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
in error
exp StatusCode
}{
{
name: "nil",
in: nil,
exp: -1,
},
{
name: "io.EOF",
in: io.EOF,
exp: -1,
},
{
name: "StatusInternalError",
in: CloseError{
Code: StatusInternalError,
},
exp: StatusInternalError,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
act := CloseStatus(tc.in)
assert.Equal(t, "close status", tc.exp, act)
})
}
}
//go:build !js
// +build !js
package websocket
import (
"compress/flate"
"io"
"sync"
)
// CompressionMode represents the modes available to the permessage-deflate extension.
// See https://tools.ietf.org/html/rfc7692
//
// Works in all modern browsers except Safari which does not implement the permessage-deflate extension.
//
// Compression is only used if the peer supports the mode selected.
type CompressionMode int
const (
// CompressionDisabled disables the negotiation of the permessage-deflate extension.
//
// This is the default. Do not enable compression without benchmarking for your particular use case first.
CompressionDisabled CompressionMode = iota
// CompressionContextTakeover compresses each message greater than 128 bytes reusing the 32 KB sliding window from
// previous messages. i.e compression context across messages is preserved.
//
// As most WebSocket protocols are text based and repetitive, this compression mode can be very efficient.
//
// The memory overhead is a fixed 32 KB sliding window, a fixed 1.2 MB flate.Writer and a sync.Pool of 40 KB flate.Reader's
// that are used when reading and then returned.
//
// Thus, it uses more memory than CompressionNoContextTakeover but compresses more efficiently.
//
// If the peer does not support CompressionContextTakeover then we will fall back to CompressionNoContextTakeover.
CompressionContextTakeover
// CompressionNoContextTakeover compresses each message greater than 512 bytes. Each message is compressed with
// a new 1.2 MB flate.Writer pulled from a sync.Pool. Each message is read with a 40 KB flate.Reader pulled from
// a sync.Pool.
//
// This means less efficient compression as the sliding window from previous messages will not be used but the
// memory overhead will be lower as there will be no fixed cost for the flate.Writer nor the 32 KB sliding window.
// Especially if the connections are long lived and seldom written to.
//
// Thus, it uses less memory than CompressionContextTakeover but compresses less efficiently.
//
// If the peer does not support CompressionNoContextTakeover then we will fall back to CompressionDisabled.
CompressionNoContextTakeover
)
func (m CompressionMode) opts() *compressionOptions {
return &compressionOptions{
clientNoContextTakeover: m == CompressionNoContextTakeover,
serverNoContextTakeover: m == CompressionNoContextTakeover,
}
}
type compressionOptions struct {
clientNoContextTakeover bool
serverNoContextTakeover bool
}
func (copts *compressionOptions) String() string {
s := "permessage-deflate"
if copts.clientNoContextTakeover {
s += "; client_no_context_takeover"
}
if copts.serverNoContextTakeover {
s += "; server_no_context_takeover"
}
return s
}
// These bytes are required to get flate.Reader to return.
// They are removed when sending to avoid the overhead as
// WebSocket framing tell's when the message has ended but then
// we need to add them back otherwise flate.Reader keeps
// trying to read more bytes.
const deflateMessageTail = "\x00\x00\xff\xff"
type trimLastFourBytesWriter struct {
w io.Writer
tail []byte
}
func (tw *trimLastFourBytesWriter) reset() {
if tw != nil && tw.tail != nil {
tw.tail = tw.tail[:0]
}
}
func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) {
if tw.tail == nil {
tw.tail = make([]byte, 0, 4)
}
extra := len(tw.tail) + len(p) - 4
if extra <= 0 {
tw.tail = append(tw.tail, p...)
return len(p), nil
}
// Now we need to write as many extra bytes as we can from the previous tail.
if extra > len(tw.tail) {
extra = len(tw.tail)
}
if extra > 0 {
_, err := tw.w.Write(tw.tail[:extra])
if err != nil {
return 0, err
}
// Shift remaining bytes in tail over.
n := copy(tw.tail, tw.tail[extra:])
tw.tail = tw.tail[:n]
}
// If p is less than or equal to 4 bytes,
// all of it is is part of the tail.
if len(p) <= 4 {
tw.tail = append(tw.tail, p...)
return len(p), nil
}
// Otherwise, only the last 4 bytes are.
tw.tail = append(tw.tail, p[len(p)-4:]...)
p = p[:len(p)-4]
n, err := tw.w.Write(p)
return n + 4, err
}
var flateReaderPool sync.Pool
func getFlateReader(r io.Reader, dict []byte) io.Reader {
fr, ok := flateReaderPool.Get().(io.Reader)
if !ok {
return flate.NewReaderDict(r, dict)
}
fr.(flate.Resetter).Reset(r, dict)
return fr
}
func putFlateReader(fr io.Reader) {
flateReaderPool.Put(fr)
}
var flateWriterPool sync.Pool
func getFlateWriter(w io.Writer) *flate.Writer {
fw, ok := flateWriterPool.Get().(*flate.Writer)
if !ok {
fw, _ = flate.NewWriter(w, flate.BestSpeed)
return fw
}
fw.Reset(w)
return fw
}
func putFlateWriter(w *flate.Writer) {
flateWriterPool.Put(w)
}
type slidingWindow struct {
buf []byte
}
var swPoolMu sync.RWMutex
var swPool = map[int]*sync.Pool{}
func slidingWindowPool(n int) *sync.Pool {
swPoolMu.RLock()
p, ok := swPool[n]
swPoolMu.RUnlock()
if ok {
return p
}
p = &sync.Pool{}
swPoolMu.Lock()
swPool[n] = p
swPoolMu.Unlock()
return p
}
func (sw *slidingWindow) init(n int) {
if sw.buf != nil {
return
}
if n == 0 {
n = 32768
}
p := slidingWindowPool(n)
sw2, ok := p.Get().(*slidingWindow)
if ok {
*sw = *sw2
} else {
sw.buf = make([]byte, 0, n)
}
}
func (sw *slidingWindow) close() {
sw.buf = sw.buf[:0]
swPoolMu.Lock()
swPool[cap(sw.buf)].Put(sw)
swPoolMu.Unlock()
}
func (sw *slidingWindow) write(p []byte) {
if len(p) >= cap(sw.buf) {
sw.buf = sw.buf[:cap(sw.buf)]
p = p[len(p)-cap(sw.buf):]
copy(sw.buf, p)
return
}
left := cap(sw.buf) - len(sw.buf)
if left < len(p) {
// We need to shift spaceNeeded bytes from the end to make room for p at the end.
spaceNeeded := len(p) - left
copy(sw.buf, sw.buf[spaceNeeded:])
sw.buf = sw.buf[:len(sw.buf)-spaceNeeded]
}
sw.buf = append(sw.buf, p...)
}
//go:build !js
// +build !js
package websocket
import (
"bytes"
"compress/flate"
"io"
"strings"
"testing"
"github.com/coder/websocket/internal/test/assert"
"github.com/coder/websocket/internal/test/xrand"
)
func Test_slidingWindow(t *testing.T) {
t.Parallel()
const testCount = 99
const maxWindow = 99999
for i := 0; i < testCount; i++ {
t.Run("", func(t *testing.T) {
t.Parallel()
input := xrand.String(maxWindow)
windowLength := xrand.Int(maxWindow)
var sw slidingWindow
sw.init(windowLength)
sw.write([]byte(input))
assert.Equal(t, "window length", windowLength, cap(sw.buf))
if !strings.HasSuffix(input, string(sw.buf)) {
t.Fatalf("r.buf is not a suffix of input: %q and %q", input, sw.buf)
}
})
}
}
func BenchmarkFlateWriter(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
w, _ := flate.NewWriter(io.Discard, flate.BestSpeed)
// We have to write a byte to get the writer to allocate to its full extent.
w.Write([]byte{'a'})
w.Flush()
}
}
func BenchmarkFlateReader(b *testing.B) {
b.ReportAllocs()
var buf bytes.Buffer
w, _ := flate.NewWriter(&buf, flate.BestSpeed)
w.Write([]byte{'a'})
w.Flush()
for i := 0; i < b.N; i++ {
r := flate.NewReader(bytes.NewReader(buf.Bytes()))
io.ReadAll(r)
}
}
//go:build !js
// +build !js
package websocket
import (
"bufio"
"context"
"fmt"
"io"
"net"
"runtime"
"strconv"
"sync"
"sync/atomic"
)
// MessageType represents the type of a WebSocket message.
// See https://tools.ietf.org/html/rfc6455#section-5.6
type MessageType int
// MessageType constants.
const (
// MessageText is for UTF-8 encoded text messages like JSON.
MessageText MessageType = iota + 1
// MessageBinary is for binary messages like protobufs.
MessageBinary
)
// Conn represents a WebSocket connection.
// All methods may be called concurrently except for Reader and Read.
//
// You must always read from the connection. Otherwise control
// frames will not be handled. See Reader and CloseRead.
//
// Be sure to call Close on the connection when you
// are finished with it to release associated resources.
//
// On any error from any method, the connection is closed
// with an appropriate reason.
//
// This applies to context expirations as well unfortunately.
// See https://github.com/nhooyr/websocket/issues/242#issuecomment-633182220
type Conn struct {
noCopy noCopy
subprotocol string
rwc io.ReadWriteCloser
client bool
copts *compressionOptions
flateThreshold int
br *bufio.Reader
bw *bufio.Writer
readTimeout chan context.Context
writeTimeout chan context.Context
timeoutLoopDone chan struct{}
// Read state.
readMu *mu
readHeaderBuf [8]byte
readControlBuf [maxControlPayload]byte
msgReader *msgReader
// Write state.
msgWriter *msgWriter
writeFrameMu *mu
writeBuf []byte
writeHeaderBuf [8]byte
writeHeader header
// Close handshake state.
closeStateMu sync.RWMutex
closeReceivedErr error
closeSentErr error
// CloseRead state.
closeReadMu sync.Mutex
closeReadCtx context.Context
closeReadDone chan struct{}
closing atomic.Bool
closeMu sync.Mutex // Protects following.
closed chan struct{}
pingCounter atomic.Int64
activePingsMu sync.Mutex
activePings map[string]chan<- struct{}
onPingReceived func(context.Context, []byte) bool
onPongReceived func(context.Context, []byte)
}
type connConfig struct {
subprotocol string
rwc io.ReadWriteCloser
client bool
copts *compressionOptions
flateThreshold int
onPingReceived func(context.Context, []byte) bool
onPongReceived func(context.Context, []byte)
br *bufio.Reader
bw *bufio.Writer
}
func newConn(cfg connConfig) *Conn {
c := &Conn{
subprotocol: cfg.subprotocol,
rwc: cfg.rwc,
client: cfg.client,
copts: cfg.copts,
flateThreshold: cfg.flateThreshold,
br: cfg.br,
bw: cfg.bw,
readTimeout: make(chan context.Context),
writeTimeout: make(chan context.Context),
timeoutLoopDone: make(chan struct{}),
closed: make(chan struct{}),
activePings: make(map[string]chan<- struct{}),
onPingReceived: cfg.onPingReceived,
onPongReceived: cfg.onPongReceived,
}
c.readMu = newMu(c)
c.writeFrameMu = newMu(c)
c.msgReader = newMsgReader(c)
c.msgWriter = newMsgWriter(c)
if c.client {
c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc)
}
if c.flate() && c.flateThreshold == 0 {
c.flateThreshold = 128
if !c.msgWriter.flateContextTakeover() {
c.flateThreshold = 512
}
}
runtime.SetFinalizer(c, func(c *Conn) {
c.close()
})
go c.timeoutLoop()
return c
}
// Subprotocol returns the negotiated subprotocol.
// An empty string means the default protocol.
func (c *Conn) Subprotocol() string {
return c.subprotocol
}
func (c *Conn) close() error {
c.closeMu.Lock()
defer c.closeMu.Unlock()
if c.isClosed() {
return net.ErrClosed
}
runtime.SetFinalizer(c, nil)
close(c.closed)
// Have to close after c.closed is closed to ensure any goroutine that wakes up
// from the connection being closed also sees that c.closed is closed and returns
// closeErr.
err := c.rwc.Close()
// With the close of rwc, these become safe to close.
c.msgWriter.close()
c.msgReader.close()
return err
}
func (c *Conn) timeoutLoop() {
defer close(c.timeoutLoopDone)
readCtx := context.Background()
writeCtx := context.Background()
for {
select {
case <-c.closed:
return
case writeCtx = <-c.writeTimeout:
case readCtx = <-c.readTimeout:
case <-readCtx.Done():
c.close()
return
case <-writeCtx.Done():
c.close()
return
}
}
}
func (c *Conn) flate() bool {
return c.copts != nil
}
// Ping sends a ping to the peer and waits for a pong.
// Use this to measure latency or ensure the peer is responsive.
// Ping must be called concurrently with Reader as it does
// not read from the connection but instead waits for a Reader call
// to read the pong.
//
// TCP Keepalives should suffice for most use cases.
func (c *Conn) Ping(ctx context.Context) error {
p := c.pingCounter.Add(1)
err := c.ping(ctx, strconv.FormatInt(p, 10))
if err != nil {
return fmt.Errorf("failed to ping: %w", err)
}
return nil
}
func (c *Conn) ping(ctx context.Context, p string) error {
pong := make(chan struct{}, 1)
c.activePingsMu.Lock()
c.activePings[p] = pong
c.activePingsMu.Unlock()
defer func() {
c.activePingsMu.Lock()
delete(c.activePings, p)
c.activePingsMu.Unlock()
}()
err := c.writeControl(ctx, opPing, []byte(p))
if err != nil {
return err
}
select {
case <-c.closed:
return net.ErrClosed
case <-ctx.Done():
return fmt.Errorf("failed to wait for pong: %w", ctx.Err())
case <-pong:
return nil
}
}
type mu struct {
c *Conn
ch chan struct{}
}
func newMu(c *Conn) *mu {
return &mu{
c: c,
ch: make(chan struct{}, 1),
}
}
func (m *mu) forceLock() {
m.ch <- struct{}{}
}
func (m *mu) tryLock() bool {
select {
case m.ch <- struct{}{}:
return true
default:
return false
}
}
func (m *mu) lock(ctx context.Context) error {
select {
case <-m.c.closed:
return net.ErrClosed
case <-ctx.Done():
return fmt.Errorf("failed to acquire lock: %w", ctx.Err())
case m.ch <- struct{}{}:
// To make sure the connection is certainly alive.
// As it's possible the send on m.ch was selected
// over the receive on closed.
select {
case <-m.c.closed:
// Make sure to release.
m.unlock()
return net.ErrClosed
default:
}
return nil
}
}
func (m *mu) unlock() {
select {
case <-m.ch:
default:
}
}
type noCopy struct{}
func (*noCopy) Lock() {}
//go:build !js
package websocket_test
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"os"
"os/exec"
"strings"
"testing"
"time"
"github.com/coder/websocket"
"github.com/coder/websocket/internal/errd"
"github.com/coder/websocket/internal/test/assert"
"github.com/coder/websocket/internal/test/wstest"
"github.com/coder/websocket/internal/test/xrand"
"github.com/coder/websocket/internal/xsync"
"github.com/coder/websocket/wsjson"
)
func TestConn(t *testing.T) {
t.Parallel()
t.Run("fuzzData", func(t *testing.T) {
t.Parallel()
compressionMode := func() websocket.CompressionMode {
return websocket.CompressionMode(xrand.Int(int(websocket.CompressionContextTakeover) + 1))
}
for i := 0; i < 5; i++ {
t.Run("", func(t *testing.T) {
tt, c1, c2 := newConnTest(t, &websocket.DialOptions{
CompressionMode: compressionMode(),
CompressionThreshold: xrand.Int(9999),
}, &websocket.AcceptOptions{
CompressionMode: compressionMode(),
CompressionThreshold: xrand.Int(9999),
})
tt.goEchoLoop(c2)
c1.SetReadLimit(131072)
for i := 0; i < 5; i++ {
err := wstest.Echo(tt.ctx, c1, 131072)
assert.Success(t, err)
}
err := c1.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)
})
}
})
t.Run("badClose", func(t *testing.T) {
tt, c1, c2 := newConnTest(t, nil, nil)
c2.CloseRead(tt.ctx)
err := c1.Close(-1, "")
assert.Contains(t, err, "failed to marshal close frame: status code StatusCode(-1) cannot be set")
})
t.Run("ping", func(t *testing.T) {
tt, c1, c2 := newConnTest(t, nil, nil)
c1.CloseRead(tt.ctx)
c2.CloseRead(tt.ctx)
for i := 0; i < 10; i++ {
err := c1.Ping(tt.ctx)
assert.Success(t, err)
}
err := c1.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)
})
t.Run("badPing", func(t *testing.T) {
tt, c1, c2 := newConnTest(t, nil, nil)
c2.CloseRead(tt.ctx)
ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100)
defer cancel()
err := c1.Ping(ctx)
assert.Contains(t, err, "failed to wait for pong")
})
t.Run("pingReceivedPongReceived", func(t *testing.T) {
var pingReceived1, pongReceived1 bool
var pingReceived2, pongReceived2 bool
tt, c1, c2 := newConnTest(t,
&websocket.DialOptions{
OnPingReceived: func(ctx context.Context, payload []byte) bool {
pingReceived1 = true
return true
},
OnPongReceived: func(ctx context.Context, payload []byte) {
pongReceived1 = true
},
}, &websocket.AcceptOptions{
OnPingReceived: func(ctx context.Context, payload []byte) bool {
pingReceived2 = true
return true
},
OnPongReceived: func(ctx context.Context, payload []byte) {
pongReceived2 = true
},
},
)
c1.CloseRead(tt.ctx)
c2.CloseRead(tt.ctx)
ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100)
defer cancel()
err := c1.Ping(ctx)
assert.Success(t, err)
c1.CloseNow()
c2.CloseNow()
assert.Equal(t, "only one side receives the ping", false, pingReceived1 && pingReceived2)
assert.Equal(t, "only one side receives the pong", false, pongReceived1 && pongReceived2)
assert.Equal(t, "ping and pong received", true, (pingReceived1 && pongReceived2) || (pingReceived2 && pongReceived1))
})
t.Run("pingReceivedPongNotReceived", func(t *testing.T) {
var pingReceived1, pongReceived1 bool
var pingReceived2, pongReceived2 bool
tt, c1, c2 := newConnTest(t,
&websocket.DialOptions{
OnPingReceived: func(ctx context.Context, payload []byte) bool {
pingReceived1 = true
return false
},
OnPongReceived: func(ctx context.Context, payload []byte) {
pongReceived1 = true
},
}, &websocket.AcceptOptions{
OnPingReceived: func(ctx context.Context, payload []byte) bool {
pingReceived2 = true
return false
},
OnPongReceived: func(ctx context.Context, payload []byte) {
pongReceived2 = true
},
},
)
c1.CloseRead(tt.ctx)
c2.CloseRead(tt.ctx)
ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100)
defer cancel()
err := c1.Ping(ctx)
assert.Contains(t, err, "failed to wait for pong")
c1.CloseNow()
c2.CloseNow()
assert.Equal(t, "only one side receives the ping", false, pingReceived1 && pingReceived2)
assert.Equal(t, "ping received and pong not received", true, (pingReceived1 && !pongReceived2) || (pingReceived2 && !pongReceived1))
})
t.Run("concurrentWrite", func(t *testing.T) {
tt, c1, c2 := newConnTest(t, nil, nil)
tt.goDiscardLoop(c2)
msg := xrand.Bytes(xrand.Int(9999))
const count = 100
errs := make(chan error, count)
for i := 0; i < count; i++ {
go func() {
select {
case errs <- c1.Write(tt.ctx, websocket.MessageBinary, msg):
case <-tt.ctx.Done():
return
}
}()
}
for i := 0; i < count; i++ {
select {
case err := <-errs:
assert.Success(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
}
err := c1.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)
})
t.Run("concurrentWriteError", func(t *testing.T) {
tt, c1, _ := newConnTest(t, nil, nil)
_, err := c1.Writer(tt.ctx, websocket.MessageText)
assert.Success(t, err)
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
defer cancel()
err = c1.Write(ctx, websocket.MessageText, []byte("x"))
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("unexpected error: %#v", err)
}
})
t.Run("netConn", func(t *testing.T) {
tt, c1, c2 := newConnTest(t, nil, nil)
n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary)
// Does not give any confidence but at least ensures no crashes.
d, _ := tt.ctx.Deadline()
n1.SetDeadline(d)
n1.SetDeadline(time.Time{})
assert.Equal(t, "remote addr", n1.RemoteAddr(), n1.LocalAddr())
assert.Equal(t, "remote addr string", "pipe", n1.RemoteAddr().String())
assert.Equal(t, "remote addr network", "pipe", n1.RemoteAddr().Network())
errs := xsync.Go(func() error {
_, err := n2.Write([]byte("hello"))
if err != nil {
return err
}
return n2.Close()
})
b, err := io.ReadAll(n1)
assert.Success(t, err)
_, err = n1.Read(nil)
assert.Equal(t, "read error", err, io.EOF)
select {
case err := <-errs:
assert.Success(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
assert.Equal(t, "read msg", []byte("hello"), b)
})
t.Run("netConn/BadMsg", func(t *testing.T) {
tt, c1, c2 := newConnTest(t, nil, nil)
n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageText)
c2.CloseRead(tt.ctx)
errs := xsync.Go(func() error {
_, err := n2.Write([]byte("hello"))
return err
})
_, err := io.ReadAll(n1)
assert.Contains(t, err, `unexpected frame type read (expected MessageBinary): MessageText`)
select {
case err := <-errs:
assert.Success(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
})
t.Run("netConn/readLimit", func(t *testing.T) {
tt, c1, c2 := newConnTest(t, nil, nil)
n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary)
s := strings.Repeat("papa", 1<<20)
errs := xsync.Go(func() error {
_, err := n2.Write([]byte(s))
if err != nil {
return err
}
return n2.Close()
})
b, err := io.ReadAll(n1)
assert.Success(t, err)
_, err = n1.Read(nil)
assert.Equal(t, "read error", err, io.EOF)
select {
case err := <-errs:
assert.Success(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
assert.Equal(t, "read msg", s, string(b))
})
t.Run("netConn/pastDeadline", func(t *testing.T) {
tt, c1, c2 := newConnTest(t, nil, nil)
n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary)
n1.SetDeadline(time.Now().Add(-time.Minute))
n2.SetDeadline(time.Now().Add(-time.Minute))
// No panic we're good.
})
t.Run("wsjson", func(t *testing.T) {
tt, c1, c2 := newConnTest(t, nil, nil)
tt.goEchoLoop(c2)
c1.SetReadLimit(1 << 30)
exp := xrand.String(xrand.Int(131072))
werr := xsync.Go(func() error {
return wsjson.Write(tt.ctx, c1, exp)
})
var act interface{}
err := wsjson.Read(tt.ctx, c1, &act)
assert.Success(t, err)
assert.Equal(t, "read msg", exp, act)
select {
case err := <-werr:
assert.Success(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
err = c1.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)
})
t.Run("HTTPClient.Timeout", func(t *testing.T) {
tt, c1, c2 := newConnTest(t, &websocket.DialOptions{
HTTPClient: &http.Client{Timeout: time.Second * 5},
}, nil)
tt.goEchoLoop(c2)
c1.SetReadLimit(1 << 30)
exp := xrand.String(xrand.Int(131072))
werr := xsync.Go(func() error {
return wsjson.Write(tt.ctx, c1, exp)
})
var act interface{}
err := wsjson.Read(tt.ctx, c1, &act)
assert.Success(t, err)
assert.Equal(t, "read msg", exp, act)
select {
case err := <-werr:
assert.Success(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
err = c1.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)
})
t.Run("CloseNow", func(t *testing.T) {
_, c1, c2 := newConnTest(t, nil, nil)
err1 := c1.CloseNow()
err2 := c2.CloseNow()
assert.Success(t, err1)
assert.Success(t, err2)
err1 = c1.CloseNow()
err2 = c2.CloseNow()
assert.ErrorIs(t, websocket.ErrClosed, err1)
assert.ErrorIs(t, websocket.ErrClosed, err2)
})
t.Run("MidReadClose", func(t *testing.T) {
tt, c1, c2 := newConnTest(t, nil, nil)
tt.goEchoLoop(c2)
c1.SetReadLimit(131072)
for i := 0; i < 5; i++ {
err := wstest.Echo(tt.ctx, c1, 131072)
assert.Success(t, err)
}
err := wsjson.Write(tt.ctx, c1, "four")
assert.Success(t, err)
_, _, err = c1.Reader(tt.ctx)
assert.Success(t, err)
err = c1.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)
})
}
func TestWasm(t *testing.T) {
t.Parallel()
if os.Getenv("CI") == "" {
t.SkipNow()
}
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := echoServer(w, r, &websocket.AcceptOptions{
Subprotocols: []string{"echo"},
InsecureSkipVerify: true,
})
if err != nil {
t.Error(err)
}
}))
defer s.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", ".", "-v")
cmd.Env = append(cleanEnv(os.Environ()), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", s.URL))
b, err := cmd.CombinedOutput()
if err != nil {
t.Fatalf("wasm test binary failed: %v:\n%s", err, b)
}
}
func cleanEnv(env []string) (out []string) {
for _, e := range env {
// Filter out GITHUB envs and anything with token in it,
// especially GITHUB_TOKEN in CI as it breaks TestWasm.
if strings.HasPrefix(e, "GITHUB") || strings.Contains(e, "TOKEN") {
continue
}
out = append(out, e)
}
return out
}
func assertCloseStatus(exp websocket.StatusCode, err error) error {
if websocket.CloseStatus(err) == -1 {
return fmt.Errorf("expected websocket.CloseError: %T %v", err, err)
}
if websocket.CloseStatus(err) != exp {
return fmt.Errorf("expected close status %v but got %v", exp, err)
}
return nil
}
type connTest struct {
t testing.TB
ctx context.Context
}
func newConnTest(t testing.TB, dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (tt *connTest, c1, c2 *websocket.Conn) {
if t, ok := t.(*testing.T); ok {
t.Parallel()
}
t.Helper()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
tt = &connTest{t: t, ctx: ctx}
t.Cleanup(cancel)
c1, c2 = wstest.Pipe(dialOpts, acceptOpts)
if xrand.Bool() {
c1, c2 = c2, c1
}
t.Cleanup(func() {
c2.CloseNow()
c1.CloseNow()
})
return tt, c1, c2
}
func (tt *connTest) goEchoLoop(c *websocket.Conn) {
ctx, cancel := context.WithCancel(tt.ctx)
echoLoopErr := xsync.Go(func() error {
err := wstest.EchoLoop(ctx, c)
return assertCloseStatus(websocket.StatusNormalClosure, err)
})
tt.t.Cleanup(func() {
cancel()
err := <-echoLoopErr
if err != nil {
tt.t.Errorf("echo loop error: %v", err)
}
})
}
func (tt *connTest) goDiscardLoop(c *websocket.Conn) {
ctx, cancel := context.WithCancel(tt.ctx)
discardLoopErr := xsync.Go(func() error {
defer c.Close(websocket.StatusInternalError, "")
for {
_, _, err := c.Read(ctx)
if err != nil {
return assertCloseStatus(websocket.StatusNormalClosure, err)
}
}
})
tt.t.Cleanup(func() {
cancel()
err := <-discardLoopErr
if err != nil {
tt.t.Errorf("discard loop error: %v", err)
}
})
}
func BenchmarkConn(b *testing.B) {
benchCases := []struct {
name string
mode websocket.CompressionMode
}{
{
name: "disabledCompress",
mode: websocket.CompressionDisabled,
},
{
name: "compressContextTakeover",
mode: websocket.CompressionContextTakeover,
},
{
name: "compressNoContext",
mode: websocket.CompressionNoContextTakeover,
},
}
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
bb, c1, c2 := newConnTest(b, &websocket.DialOptions{
CompressionMode: bc.mode,
}, &websocket.AcceptOptions{
CompressionMode: bc.mode,
})
bb.goEchoLoop(c2)
bytesWritten := c1.RecordBytesWritten()
bytesRead := c1.RecordBytesRead()
msg := []byte(strings.Repeat("1234", 128))
readBuf := make([]byte, len(msg))
writes := make(chan struct{})
defer close(writes)
werrs := make(chan error)
go func() {
for range writes {
select {
case werrs <- c1.Write(bb.ctx, websocket.MessageText, msg):
case <-bb.ctx.Done():
return
}
}
}()
b.SetBytes(int64(len(msg)))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
select {
case writes <- struct{}{}:
case <-bb.ctx.Done():
b.Fatal(bb.ctx.Err())
}
typ, r, err := c1.Reader(bb.ctx)
if err != nil {
b.Fatal(i, err)
}
if websocket.MessageText != typ {
assert.Equal(b, "data type", websocket.MessageText, typ)
}
_, err = io.ReadFull(r, readBuf)
if err != nil {
b.Fatal(err)
}
n2, err := r.Read(readBuf)
if err != io.EOF {
assert.Equal(b, "read err", io.EOF, err)
}
if n2 != 0 {
assert.Equal(b, "n2", 0, n2)
}
if !bytes.Equal(msg, readBuf) {
assert.Equal(b, "msg", msg, readBuf)
}
select {
case err = <-werrs:
case <-bb.ctx.Done():
b.Fatal(bb.ctx.Err())
}
if err != nil {
b.Fatal(err)
}
}
b.StopTimer()
b.ReportMetric(float64(*bytesWritten/b.N), "written/op")
b.ReportMetric(float64(*bytesRead/b.N), "read/op")
err := c1.Close(websocket.StatusNormalClosure, "")
assert.Success(b, err)
})
}
}
func echoServer(w http.ResponseWriter, r *http.Request, opts *websocket.AcceptOptions) (err error) {
defer errd.Wrap(&err, "echo server failed")
c, err := websocket.Accept(w, r, opts)
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
err = wstest.EchoLoop(r.Context(), c)
return assertCloseStatus(websocket.StatusNormalClosure, err)
}
func assertEcho(tb testing.TB, ctx context.Context, c *websocket.Conn) {
exp := xrand.String(xrand.Int(131072))
werr := xsync.Go(func() error {
return wsjson.Write(ctx, c, exp)
})
var act interface{}
c.SetReadLimit(1 << 30)
err := wsjson.Read(ctx, c, &act)
assert.Success(tb, err)
assert.Equal(tb, "read msg", exp, act)
select {
case err := <-werr:
assert.Success(tb, err)
case <-ctx.Done():
tb.Fatal(ctx.Err())
}
}
func assertClose(tb testing.TB, c *websocket.Conn) {
tb.Helper()
err := c.Close(websocket.StatusNormalClosure, "")
assert.Success(tb, err)
}
func TestConcurrentClosePing(t *testing.T) {
t.Parallel()
for i := 0; i < 64; i++ {
func() {
c1, c2 := wstest.Pipe(nil, nil)
defer c1.CloseNow()
defer c2.CloseNow()
c1.CloseRead(context.Background())
c2.CloseRead(context.Background())
errc := xsync.Go(func() error {
for range time.Tick(time.Millisecond) {
err := c1.Ping(context.Background())
if err != nil {
return err
}
}
panic("unreachable")
})
time.Sleep(10 * time.Millisecond)
assert.Success(t, c1.Close(websocket.StatusNormalClosure, ""))
<-errc
}()
}
}
func TestConnClosePropagation(t *testing.T) {
t.Parallel()
want := []byte("hello")
keepWriting := func(c *websocket.Conn) <-chan error {
return xsync.Go(func() error {
for {
err := c.Write(context.Background(), websocket.MessageText, want)
if err != nil {
return err
}
}
})
}
keepReading := func(c *websocket.Conn) <-chan error {
return xsync.Go(func() error {
for {
_, got, err := c.Read(context.Background())
if err != nil {
return err
}
if !bytes.Equal(want, got) {
return fmt.Errorf("unexpected message: want %q, got %q", want, got)
}
}
})
}
checkReadErr := func(t *testing.T, err error) {
// Check read error (output depends on when read is called in relation to connection closure).
var ce websocket.CloseError
if errors.As(err, &ce) {
assert.Equal(t, "", websocket.StatusNormalClosure, ce.Code)
} else {
assert.ErrorIs(t, net.ErrClosed, err)
}
}
checkConnErrs := func(t *testing.T, conn ...*websocket.Conn) {
for _, c := range conn {
// Check write error.
err := c.Write(context.Background(), websocket.MessageText, want)
assert.ErrorIs(t, net.ErrClosed, err)
_, _, err = c.Read(context.Background())
checkReadErr(t, err)
}
}
t.Run("CloseOtherSideDuringWrite", func(t *testing.T) {
tt, this, other := newConnTest(t, nil, nil)
_ = this.CloseRead(tt.ctx)
thisWriteErr := keepWriting(this)
_, got, err := other.Read(tt.ctx)
assert.Success(t, err)
assert.Equal(t, "msg", want, got)
err = other.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)
select {
case err := <-thisWriteErr:
assert.ErrorIs(t, net.ErrClosed, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
checkConnErrs(t, this, other)
})
t.Run("CloseThisSideDuringWrite", func(t *testing.T) {
tt, this, other := newConnTest(t, nil, nil)
_ = this.CloseRead(tt.ctx)
thisWriteErr := keepWriting(this)
otherReadErr := keepReading(other)
err := this.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)
select {
case err := <-thisWriteErr:
assert.ErrorIs(t, net.ErrClosed, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
select {
case err := <-otherReadErr:
checkReadErr(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
checkConnErrs(t, this, other)
})
t.Run("CloseOtherSideDuringRead", func(t *testing.T) {
tt, this, other := newConnTest(t, nil, nil)
_ = other.CloseRead(tt.ctx)
errs := keepReading(this)
err := other.Write(tt.ctx, websocket.MessageText, want)
assert.Success(t, err)
err = other.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)
select {
case err := <-errs:
checkReadErr(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
checkConnErrs(t, this, other)
})
t.Run("CloseThisSideDuringRead", func(t *testing.T) {
tt, this, other := newConnTest(t, nil, nil)
thisReadErr := keepReading(this)
otherReadErr := keepReading(other)
err := other.Write(tt.ctx, websocket.MessageText, want)
assert.Success(t, err)
err = this.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)
select {
case err := <-thisReadErr:
checkReadErr(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
select {
case err := <-otherReadErr:
checkReadErr(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
checkConnErrs(t, this, other)
})
}
//go:build !js
// +build !js
package websocket
import (
"bufio"
"bytes"
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"io"
"io/ioutil"
"math/rand"
"net/http"
"net/url"
"strings"
"sync"
"time"
"golang.org/x/xerrors"
"github.com/coder/websocket/internal/errd"
)
// DialOptions represents the options available to pass to Dial.
// DialOptions represents Dial's options.
type DialOptions struct {
// HTTPClient is the http client used for the handshake.
// Its Transport must return writable bodies
// for WebSocket handshakes.
// http.Transport does this correctly beginning with Go 1.12.
// HTTPClient is used for the connection.
// Its Transport must return writable bodies for WebSocket handshakes.
// http.Transport does beginning with Go 1.12.
HTTPClient *http.Client
// HTTPHeader specifies the HTTP headers included in the handshake request.
HTTPHeader http.Header
// Subprotocols lists the subprotocols to negotiate with the server.
// Host optionally overrides the Host HTTP header to send. If empty, the value
// of URL.Host will be used.
Host string
// Subprotocols lists the WebSocket subprotocols to negotiate with the server.
Subprotocols []string
// CompressionMode controls the compression mode.
// Defaults to CompressionDisabled.
//
// See docs on CompressionMode for details.
CompressionMode CompressionMode
// CompressionThreshold controls the minimum size of a message before compression is applied.
//
// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
// for CompressionContextTakeover.
CompressionThreshold int
// OnPingReceived is an optional callback invoked synchronously when a ping frame is received.
//
// The payload contains the application data of the ping frame.
// If the callback returns false, the subsequent pong frame will not be sent.
// To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
OnPingReceived func(ctx context.Context, payload []byte) bool
// OnPongReceived is an optional callback invoked synchronously when a pong frame is received.
//
// The payload contains the application data of the pong frame.
// To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
//
// Unlike OnPingReceived, this callback does not return a value because a pong frame
// is a response to a ping and does not trigger any further frame transmission.
OnPongReceived func(ctx context.Context, payload []byte)
}
// Dial performs a WebSocket handshake on the given url with the given options.
func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) {
var cancel context.CancelFunc
var o DialOptions
if opts != nil {
o = *opts
}
if o.HTTPClient == nil {
o.HTTPClient = http.DefaultClient
}
if o.HTTPClient.Timeout > 0 {
ctx, cancel = context.WithTimeout(ctx, o.HTTPClient.Timeout)
newClient := *o.HTTPClient
newClient.Timeout = 0
o.HTTPClient = &newClient
}
if o.HTTPHeader == nil {
o.HTTPHeader = http.Header{}
}
newClient := *o.HTTPClient
oldCheckRedirect := o.HTTPClient.CheckRedirect
newClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
switch req.URL.Scheme {
case "ws":
req.URL.Scheme = "http"
case "wss":
req.URL.Scheme = "https"
}
if oldCheckRedirect != nil {
return oldCheckRedirect(req, via)
}
return nil
}
o.HTTPClient = &newClient
return ctx, cancel, &o
}
// Dial performs a WebSocket handshake on url.
//
// The response is the WebSocket handshake response from the server.
// If an error occurs, the returned response may be non nil. However, you can only
// read the first 1024 bytes of its body.
// You never need to close resp.Body yourself.
//
// If an error occurs, the returned response may be non nil.
// However, you can only read the first 1024 bytes of the body.
//
// You never need to close the resp.Body yourself.
// This function requires at least Go 1.12 as it uses a new feature
// in net/http to perform WebSocket handshakes.
// See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861
//
// This function requires at least Go 1.12 to succeed as it uses a new feature
// in net/http to perform WebSocket handshakes and get a writable body
// from the transport. See https://github.com/golang/go/issues/26937#issuecomment-415855861
func Dial(ctx context.Context, u string, opts DialOptions) (*Conn, *http.Response, error) {
c, r, err := dial(ctx, u, opts)
// URLs with http/https schemes will work and are interpreted as ws/wss.
func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) {
return dial(ctx, u, opts, nil)
}
func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) {
defer errd.Wrap(&err, "failed to WebSocket dial")
var cancel context.CancelFunc
ctx, cancel, opts = opts.cloneWithDefaults(ctx)
if cancel != nil {
defer cancel()
}
secWebSocketKey, err := secWebSocketKey(rand)
if err != nil {
return nil, r, xerrors.Errorf("failed to websocket dial: %w", err)
return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err)
}
return c, r, nil
}
func dial(ctx context.Context, u string, opts DialOptions) (_ *Conn, _ *http.Response, err error) {
if opts.HTTPClient == nil {
opts.HTTPClient = http.DefaultClient
var copts *compressionOptions
if opts.CompressionMode != CompressionDisabled {
copts = opts.CompressionMode.opts()
}
if opts.HTTPClient.Timeout > 0 {
return nil, nil, xerrors.Errorf("please use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67")
resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey)
if err != nil {
return nil, resp, err
}
respBody := resp.Body
resp.Body = nil
defer func() {
if err != nil {
// We read a bit of the body for easier debugging.
r := io.LimitReader(respBody, 1024)
timer := time.AfterFunc(time.Second*3, func() {
respBody.Close()
})
defer timer.Stop()
b, _ := io.ReadAll(r)
respBody.Close()
resp.Body = io.NopCloser(bytes.NewReader(b))
}
}()
copts, err = verifyServerResponse(opts, copts, secWebSocketKey, resp)
if err != nil {
return nil, resp, err
}
if opts.HTTPHeader == nil {
opts.HTTPHeader = http.Header{}
rwc, ok := respBody.(io.ReadWriteCloser)
if !ok {
return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", respBody)
}
parsedURL, err := url.Parse(u)
return newConn(connConfig{
subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"),
rwc: rwc,
client: true,
copts: copts,
flateThreshold: opts.CompressionThreshold,
onPingReceived: opts.OnPingReceived,
onPongReceived: opts.OnPongReceived,
br: getBufioReader(rwc),
bw: getBufioWriter(rwc),
}), resp, nil
}
func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) {
u, err := url.Parse(urls)
if err != nil {
return nil, nil, xerrors.Errorf("failed to parse url: %w", err)
return nil, fmt.Errorf("failed to parse url: %w", err)
}
switch parsedURL.Scheme {
switch u.Scheme {
case "ws":
parsedURL.Scheme = "http"
u.Scheme = "http"
case "wss":
parsedURL.Scheme = "https"
u.Scheme = "https"
case "http", "https":
default:
return nil, nil, xerrors.Errorf("unexpected url scheme: %q", parsedURL.Scheme)
return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme)
}
req, _ := http.NewRequest("GET", parsedURL.String(), nil)
req = req.WithContext(ctx)
req.Header = opts.HTTPHeader
req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
if err != nil {
return nil, fmt.Errorf("failed to create new http request: %w", err)
}
if len(opts.Host) > 0 {
req.Host = opts.Host
}
req.Header = opts.HTTPHeader.Clone()
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Sec-WebSocket-Version", "13")
req.Header.Set("Sec-WebSocket-Key", makeSecWebSocketKey())
req.Header.Set("Sec-WebSocket-Key", secWebSocketKey)
if len(opts.Subprotocols) > 0 {
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
}
if copts != nil {
req.Header.Set("Sec-WebSocket-Extensions", copts.String())
}
resp, err := opts.HTTPClient.Do(req)
if err != nil {
return nil, nil, xerrors.Errorf("failed to send handshake request: %w", err)
return nil, fmt.Errorf("failed to send handshake request: %w", err)
}
defer func() {
if err != nil {
// We read a bit of the body for easier debugging.
r := io.LimitReader(resp.Body, 1024)
b, _ := ioutil.ReadAll(r)
resp.Body.Close()
resp.Body = ioutil.NopCloser(bytes.NewReader(b))
}
}()
return resp, nil
}
err = verifyServerResponse(req, resp)
func secWebSocketKey(rr io.Reader) (string, error) {
if rr == nil {
rr = rand.Reader
}
b := make([]byte, 16)
_, err := io.ReadFull(rr, b)
if err != nil {
return nil, resp, err
return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err)
}
return base64.StdEncoding.EncodeToString(b), nil
}
rwc, ok := resp.Body.(io.ReadWriteCloser)
if !ok {
return nil, resp, xerrors.Errorf("response body is not a io.ReadWriteCloser: %T", rwc)
func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) {
if resp.StatusCode != http.StatusSwitchingProtocols {
return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
}
c := &Conn{
subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"),
br: getBufioReader(rwc),
bw: getBufioWriter(rwc),
closer: rwc,
client: true,
if !headerContainsTokenIgnoreCase(resp.Header, "Connection", "Upgrade") {
return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
}
c.extractBufioWriterBuf(rwc)
c.init()
return c, resp, nil
}
if !headerContainsTokenIgnoreCase(resp.Header, "Upgrade", "WebSocket") {
return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
}
func verifyServerResponse(r *http.Request, resp *http.Response) error {
if resp.StatusCode != http.StatusSwitchingProtocols {
return xerrors.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) {
return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
resp.Header.Get("Sec-WebSocket-Accept"),
secWebSocketKey,
)
}
if !headerValuesContainsToken(resp.Header, "Connection", "Upgrade") {
return xerrors.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
err := verifySubprotocol(opts.Subprotocols, resp)
if err != nil {
return nil, err
}
if !headerValuesContainsToken(resp.Header, "Upgrade", "WebSocket") {
return xerrors.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
return verifyServerExtensions(copts, resp.Header)
}
func verifySubprotocol(subprotos []string, resp *http.Response) error {
proto := resp.Header.Get("Sec-WebSocket-Protocol")
if proto == "" {
return nil
}
if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")) {
return xerrors.Errorf("websocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
resp.Header.Get("Sec-WebSocket-Accept"),
r.Header.Get("Sec-WebSocket-Key"),
)
for _, sp2 := range subprotos {
if strings.EqualFold(sp2, proto) {
return nil
}
}
return nil
return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto)
}
// The below pools can only be used by the client because http.Hijacker will always
// have a bufio.Reader/Writer for us so it doesn't make sense to use a pool on top.
func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compressionOptions, error) {
exts := websocketExtensions(h)
if len(exts) == 0 {
return nil, nil
}
ext := exts[0]
if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil {
return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:])
}
_copts := *copts
copts = &_copts
for _, p := range ext.params {
switch p {
case "client_no_context_takeover":
copts.clientNoContextTakeover = true
continue
case "server_no_context_takeover":
copts.serverNoContextTakeover = true
continue
}
if strings.HasPrefix(p, "server_max_window_bits=") {
// We can't adjust the deflate window, but decoding with a larger window is acceptable.
continue
}
return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
}
var bufioReaderPool = sync.Pool{
New: func() interface{} {
return bufio.NewReader(nil)
},
return copts, nil
}
var bufioReaderPool sync.Pool
func getBufioReader(r io.Reader) *bufio.Reader {
br := bufioReaderPool.Get().(*bufio.Reader)
br, ok := bufioReaderPool.Get().(*bufio.Reader)
if !ok {
return bufio.NewReader(r)
}
br.Reset(r)
return br
}
func returnBufioReader(br *bufio.Reader) {
func putBufioReader(br *bufio.Reader) {
bufioReaderPool.Put(br)
}
var bufioWriterPool = sync.Pool{
New: func() interface{} {
return bufio.NewWriter(nil)
},
}
var bufioWriterPool sync.Pool
func getBufioWriter(w io.Writer) *bufio.Writer {
bw := bufioWriterPool.Get().(*bufio.Writer)
bw, ok := bufioWriterPool.Get().(*bufio.Writer)
if !ok {
return bufio.NewWriter(w)
}
bw.Reset(w)
return bw
}
func returnBufioWriter(bw *bufio.Writer) {
func putBufioWriter(bw *bufio.Writer) {
bufioWriterPool.Put(bw)
}
func makeSecWebSocketKey() string {
b := make([]byte, 16)
rand.Read(b)
return base64.StdEncoding.EncodeToString(b)
}
package websocket
//go:build !js
// +build !js
package websocket_test
import (
"bytes"
"context"
"crypto/rand"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/coder/websocket"
"github.com/coder/websocket/internal/test/assert"
"github.com/coder/websocket/internal/util"
"github.com/coder/websocket/internal/xsync"
)
func TestBadDials(t *testing.T) {
t.Parallel()
t.Run("badReq", func(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
url string
opts *websocket.DialOptions
rand util.ReaderFunc
nilCtx bool
}{
{
name: "badURL",
url: "://noscheme",
},
{
name: "badURLScheme",
url: "ftp://nhooyr.io",
},
{
name: "badTLS",
url: "wss://totallyfake.nhooyr.io",
},
{
name: "badReader",
rand: func(p []byte) (int, error) {
return 0, io.EOF
},
},
{
name: "nilContext",
url: "http://localhost",
nilCtx: true,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
var ctx context.Context
var cancel func()
if !tc.nilCtx {
ctx, cancel = context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
}
if tc.rand == nil {
tc.rand = rand.Reader.Read
}
_, _, err := websocket.ExportedDial(ctx, tc.url, tc.opts, tc.rand)
assert.Error(t, err)
})
}
})
t.Run("badResponse", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
_, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
HTTPClient: mockHTTPClient(func(*http.Request) (*http.Response, error) {
return &http.Response{
Body: io.NopCloser(strings.NewReader("hi")),
}, nil
}),
})
assert.Contains(t, err, "failed to WebSocket dial: expected handshake response status code 101 but got 0")
})
t.Run("badBody", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
rt := func(r *http.Request) (*http.Response, error) {
h := http.Header{}
h.Set("Connection", "Upgrade")
h.Set("Upgrade", "websocket")
h.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))
return &http.Response{
StatusCode: http.StatusSwitchingProtocols,
Header: h,
Body: io.NopCloser(strings.NewReader("hi")),
}, nil
}
_, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
HTTPClient: mockHTTPClient(rt),
})
assert.Contains(t, err, "response body is not a io.ReadWriteCloser")
})
}
func Test_verifyHostOverride(t *testing.T) {
testCases := []struct {
name string
host string
exp string
}{
{
name: "noOverride",
host: "",
exp: "example.com",
},
{
name: "hostOverride",
host: "example.net",
exp: "example.net",
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
rt := func(r *http.Request) (*http.Response, error) {
assert.Equal(t, "Host", tc.exp, r.Host)
h := http.Header{}
h.Set("Connection", "Upgrade")
h.Set("Upgrade", "websocket")
h.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))
return &http.Response{
StatusCode: http.StatusSwitchingProtocols,
Header: h,
Body: mockBody{bytes.NewBufferString("hi")},
}, nil
}
c, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
HTTPClient: mockHTTPClient(rt),
Host: tc.host,
})
assert.Success(t, err)
c.CloseNow()
})
}
}
type mockBody struct {
*bytes.Buffer
}
func (mb mockBody) Close() error {
return nil
}
func Test_verifyServerHandshake(t *testing.T) {
t.Parallel()
......@@ -48,6 +225,36 @@ func Test_verifyServerHandshake(t *testing.T) {
},
success: false,
},
{
name: "badSecWebSocketProtocol",
response: func(w http.ResponseWriter) {
w.Header().Set("Connection", "Upgrade")
w.Header().Set("Upgrade", "websocket")
w.Header().Set("Sec-WebSocket-Protocol", "xd")
w.WriteHeader(http.StatusSwitchingProtocols)
},
success: false,
},
{
name: "unsupportedExtension",
response: func(w http.ResponseWriter) {
w.Header().Set("Connection", "Upgrade")
w.Header().Set("Upgrade", "websocket")
w.Header().Set("Sec-WebSocket-Extensions", "meow")
w.WriteHeader(http.StatusSwitchingProtocols)
},
success: false,
},
{
name: "unsupportedDeflateParam",
response: func(w http.ResponseWriter) {
w.Header().Set("Connection", "Upgrade")
w.Header().Set("Upgrade", "websocket")
w.Header().Set("Sec-WebSocket-Extensions", "permessage-deflate; meow")
w.WriteHeader(http.StatusSwitchingProtocols)
},
success: false,
},
{
name: "success",
response: func(w http.ResponseWriter) {
......@@ -69,17 +276,145 @@ func Test_verifyServerHandshake(t *testing.T) {
resp := w.Result()
r := httptest.NewRequest("GET", "/", nil)
key := makeSecWebSocketKey()
key, err := websocket.SecWebSocketKey(rand.Reader)
assert.Success(t, err)
r.Header.Set("Sec-WebSocket-Key", key)
if resp.Header.Get("Sec-WebSocket-Accept") == "" {
resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
resp.Header.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(key))
}
err := verifyServerResponse(r, resp)
if (err == nil) != tc.success {
t.Fatalf("unexpected error: %+v", err)
opts := &websocket.DialOptions{
Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","),
}
_, err = websocket.VerifyServerResponse(opts, websocket.CompressionModeOpts(opts.CompressionMode), key, resp)
if tc.success {
assert.Success(t, err)
} else {
assert.Error(t, err)
}
})
}
}
func mockHTTPClient(fn roundTripperFunc) *http.Client {
return &http.Client{
Transport: fn,
}
}
type roundTripperFunc func(*http.Request) (*http.Response, error)
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return f(r)
}
func TestDialRedirect(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
_, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
HTTPClient: mockHTTPClient(func(r *http.Request) (*http.Response, error) {
resp := &http.Response{
Header: http.Header{},
}
if r.URL.Scheme != "https" {
resp.Header.Set("Location", "wss://example.com")
resp.StatusCode = http.StatusFound
return resp, nil
}
resp.Header.Set("Connection", "Upgrade")
resp.Header.Set("Upgrade", "meow")
resp.StatusCode = http.StatusSwitchingProtocols
return resp, nil
}),
})
assert.Contains(t, err, "failed to WebSocket dial: WebSocket protocol violation: Upgrade header \"meow\" does not contain websocket")
}
type forwardProxy struct {
hc *http.Client
}
func newForwardProxy() *forwardProxy {
return &forwardProxy{
hc: &http.Client{},
}
}
func (fc *forwardProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), time.Second*10)
defer cancel()
r = r.WithContext(ctx)
r.RequestURI = ""
resp, err := fc.hc.Do(r)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
defer resp.Body.Close()
for k, v := range resp.Header {
w.Header()[k] = v
}
w.Header().Set("PROXIED", "true")
w.WriteHeader(resp.StatusCode)
if resprw, ok := resp.Body.(io.ReadWriter); ok {
c, brw, err := w.(http.Hijacker).Hijack()
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
brw.Flush()
errc1 := xsync.Go(func() error {
_, err := io.Copy(c, resprw)
return err
})
errc2 := xsync.Go(func() error {
_, err := io.Copy(resprw, c)
return err
})
select {
case <-errc1:
case <-errc2:
case <-r.Context().Done():
}
} else {
io.Copy(w, resp.Body)
}
}
func TestDialViaProxy(t *testing.T) {
t.Parallel()
ps := httptest.NewServer(newForwardProxy())
defer ps.Close()
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := echoServer(w, r, nil)
assert.Success(t, err)
}))
defer s.Close()
psu, err := url.Parse(ps.URL)
assert.Success(t, err)
proxyTransport := http.DefaultTransport.(*http.Transport).Clone()
proxyTransport.Proxy = http.ProxyURL(psu)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
c, resp, err := websocket.Dial(ctx, s.URL, &websocket.DialOptions{
HTTPClient: &http.Client{
Transport: proxyTransport,
},
})
assert.Success(t, err)
assert.Equal(t, "", "true", resp.Header.Get("PROXIED"))
assertEcho(t, ctx, c)
assertClose(t, c)
}
// Package websocket is a minimal and idiomatic implementation of the WebSocket protocol.
//go:build !js
// +build !js
// Package websocket implements the RFC 6455 WebSocket protocol.
//
// See https://tools.ietf.org/html/rfc6455
// https://tools.ietf.org/html/rfc6455
//
// Conn, Dial, and Accept are the main entrypoints into this package. Use Dial to dial
// a WebSocket server, Accept to accept a WebSocket client dial and then Conn to interact
// with the resulting WebSocket connections.
// Use Dial to dial a WebSocket server.
//
// Use Accept to accept a WebSocket client.
//
// Conn represents the resulting WebSocket connection.
//
// The examples are the best way to understand how to correctly use the library.
//
// The wsjson and wspb subpackages contain helpers for JSON and ProtoBuf messages.
// The wsjson subpackage contain helpers for JSON and protobuf messages.
//
// More documentation at https://github.com/coder/websocket.
//
// # Wasm
//
// The client side supports compiling to Wasm.
// It wraps the WebSocket browser API.
//
// See https://developer.mozilla.org/en-US/docs/Web/API/WebSocket
//
// Please see https://nhooyr.io/websocket for more overview docs and a
// comparison with existing implementations.
// Some important caveats to be aware of:
//
// Please be sure to use the https://golang.org/x/xerrors package when inspecting returned errors.
package websocket
// - Accept always errors out
// - Conn.Ping is no-op
// - Conn.CloseNow is Close(StatusGoingAway, "")
// - HTTPClient, HTTPHeader and CompressionMode in DialOptions are no-op
// - *http.Response from Dial is &http.Response{} with a 101 status code on success
package websocket // import "github.com/coder/websocket"
* @nhooyr
# Contributing
## Issues
Please be as descriptive as possible with your description.
## Pull requests
Please split up changes into several small descriptive commits.
Please capitalize the first word in the commit message title.
The commit message title should use the verb tense + phrase that completes the blank in
> This change modifies websocket to ___________
Be sure to link to an existing issue if one exists. In general, try creating an issue
before making a PR to get some discussion going and to make sure you do not spend time
on a PR that may be rejected.
Run `ci/run.sh` to test your changes. You only need docker and bash to run the tests.
package websocket_test
import (
"context"
"fmt"
"io"
"log"
"net"
"net/http"
"time"
"golang.org/x/time/rate"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
"nhooyr.io/websocket/wsjson"
)
// This example starts a WebSocket echo server,
// dials the server and then sends 5 different messages
// and prints out the server's responses.
func Example_echo() {
// First we listen on port 0 which means the OS will
// assign us a random free port. This is the listener
// the server will serve on and the client will connect to.
l, err := net.Listen("tcp", "localhost:0")
if err != nil {
log.Fatalf("failed to listen: %v", err)
}
defer l.Close()
s := &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := echoServer(w, r)
if err != nil {
log.Printf("echo server: %v", err)
}
}),
ReadTimeout: time.Second * 15,
WriteTimeout: time.Second * 15,
}
defer s.Close()
// This starts the echo server on the listener.
go func() {
err := s.Serve(l)
if err != http.ErrServerClosed {
log.Fatalf("failed to listen and serve: %v", err)
}
}()
// Now we dial the server, send the messages and echo the responses.
err = client("ws://" + l.Addr().String())
if err != nil {
log.Fatalf("client failed: %v", err)
}
// Output:
// received: map[i:0]
// received: map[i:1]
// received: map[i:2]
// received: map[i:3]
// received: map[i:4]
}
// echoServer is the WebSocket echo server implementation.
// It ensures the client speaks the echo subprotocol and
// only allows one message every 100ms with a 10 message burst.
func echoServer(w http.ResponseWriter, r *http.Request) error {
log.Printf("serving %v", r.RemoteAddr)
c, err := websocket.Accept(w, r, websocket.AcceptOptions{
Subprotocols: []string{"echo"},
})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "the sky is falling")
if c.Subprotocol() != "echo" {
c.Close(websocket.StatusPolicyViolation, "client must speak the echo subprotocol")
return xerrors.Errorf("client does not speak echo sub protocol")
}
l := rate.NewLimiter(rate.Every(time.Millisecond*100), 10)
for {
err = echo(r.Context(), c, l)
if err != nil {
return xerrors.Errorf("failed to echo with %v: %w", r.RemoteAddr, err)
}
}
}
// echo reads from the websocket connection and then writes
// the received message back to it.
// The entire function has 10s to complete.
func echo(ctx context.Context, c *websocket.Conn, l *rate.Limiter) error {
ctx, cancel := context.WithTimeout(ctx, time.Second*10)
defer cancel()
err := l.Wait(ctx)
if err != nil {
return err
}
typ, r, err := c.Reader(ctx)
if err != nil {
return err
}
w, err := c.Writer(ctx, typ)
if err != nil {
return err
}
_, err = io.Copy(w, r)
if err != nil {
return xerrors.Errorf("failed to io.Copy: %w", err)
}
err = w.Close()
return err
}
// client dials the WebSocket echo server at the given url.
// It then sends it 5 different messages and echo's the server's
// response to each.
func client(url string) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
c, _, err := websocket.Dial(ctx, url, websocket.DialOptions{
Subprotocols: []string{"echo"},
})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "the sky is falling")
for i := 0; i < 5; i++ {
err = wsjson.Write(ctx, c, map[string]int{
"i": i,
})
if err != nil {
return err
}
v := map[string]int{}
err = wsjson.Read(ctx, c, &v)
if err != nil {
return err
}
fmt.Printf("received: %v\n", v)
}
c.Close(websocket.StatusNormalClosure, "")
return nil
}
......@@ -6,20 +6,21 @@ import (
"net/http"
"time"
"nhooyr.io/websocket"
"nhooyr.io/websocket/wsjson"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
)
// This example accepts a WebSocket connection, reads a single JSON
// message from the client and then closes the connection.
func ExampleAccept() {
// This handler accepts a WebSocket connection, reads a single JSON
// message from the client and then closes the connection.
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
c, err := websocket.Accept(w, r, nil)
if err != nil {
log.Println(err)
return
}
defer c.Close(websocket.StatusInternalError, "the sky is falling")
defer c.CloseNow()
ctx, cancel := context.WithTimeout(r.Context(), time.Second*10)
defer cancel()
......@@ -31,8 +32,6 @@ func ExampleAccept() {
return
}
log.Printf("received: %v", v)
c.Close(websocket.StatusNormalClosure, "")
})
......@@ -40,17 +39,18 @@ func ExampleAccept() {
log.Fatal(err)
}
// This example dials a server, writes a single JSON message and then
// closes the connection.
func ExampleDial() {
// Dials a server, writes a single JSON message and then
// closes the connection.
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
c, _, err := websocket.Dial(ctx, "ws://localhost:8080", websocket.DialOptions{})
c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil)
if err != nil {
log.Fatal(err)
}
defer c.Close(websocket.StatusInternalError, "the sky is falling")
defer c.CloseNow()
err = wsjson.Write(ctx, c, "hi")
if err != nil {
......@@ -60,16 +60,35 @@ func ExampleDial() {
c.Close(websocket.StatusNormalClosure, "")
}
// This example shows how to correctly handle a WebSocket connection
// on which you will only write and do not expect to read data messages.
func ExampleCloseStatus() {
// Dials a server and then expects to be disconnected with status code
// websocket.StatusNormalClosure.
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil)
if err != nil {
log.Fatal(err)
}
defer c.CloseNow()
_, _, err = c.Reader(ctx)
if websocket.CloseStatus(err) != websocket.StatusNormalClosure {
log.Fatalf("expected to be disconnected with StatusNormalClosure but got: %v", err)
}
}
func Example_writeOnly() {
// This handler demonstrates how to correctly handle a write only WebSocket connection.
// i.e you only expect to write messages and do not expect to read any messages.
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
c, err := websocket.Accept(w, r, nil)
if err != nil {
log.Println(err)
return
}
defer c.Close(websocket.StatusInternalError, "the sky is falling")
defer c.CloseNow()
ctx, cancel := context.WithTimeout(r.Context(), time.Minute*10)
defer cancel()
......@@ -97,3 +116,56 @@ func Example_writeOnly() {
err := http.ListenAndServe("localhost:8080", fn)
log.Fatal(err)
}
func Example_crossOrigin() {
// This handler demonstrates how to safely accept cross origin WebSockets
// from the origin example.com.
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := websocket.Accept(w, r, &websocket.AcceptOptions{
OriginPatterns: []string{"example.com"},
})
if err != nil {
log.Println(err)
return
}
c.Close(websocket.StatusNormalClosure, "cross origin WebSocket accepted")
})
err := http.ListenAndServe("localhost:8080", fn)
log.Fatal(err)
}
func ExampleConn_Ping() {
// Dials a server and pings it 5 times.
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil)
if err != nil {
log.Fatal(err)
}
defer c.CloseNow()
// Required to read the Pongs from the server.
ctx = c.CloseRead(ctx)
for i := 0; i < 5; i++ {
err = c.Ping(ctx)
if err != nil {
log.Fatal(err)
}
}
c.Close(websocket.StatusNormalClosure, "")
}
// This example demonstrates full stack chat with an automated test.
func Example_fullStackChat() {
// https://github.com/nhooyr/websocket/tree/master/internal/examples/chat
}
// This example demonstrates a echo server.
func Example_echo() {
// https://github.com/nhooyr/websocket/tree/master/internal/examples/echo
}
//go:build !js
// +build !js
package websocket
var Compute = handleSecWebSocketKey
import (
"net"
"github.com/coder/websocket/internal/util"
)
func (c *Conn) RecordBytesWritten() *int {
var bytesWritten int
c.bw.Reset(util.WriterFunc(func(p []byte) (int, error) {
bytesWritten += len(p)
return c.rwc.Write(p)
}))
return &bytesWritten
}
func (c *Conn) RecordBytesRead() *int {
var bytesRead int
c.br.Reset(util.ReaderFunc(func(p []byte) (int, error) {
n, err := c.rwc.Read(p)
bytesRead += n
return n, err
}))
return &bytesRead
}
var ErrClosed = net.ErrClosed
var ExportedDial = dial
var SecWebSocketAccept = secWebSocketAccept
var SecWebSocketKey = secWebSocketKey
var VerifyServerResponse = verifyServerResponse
var CompressionModeOpts = CompressionMode.opts