good morning!!!!

Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • github/nhooyr/websocket
  • open/websocket
2 results
Show changes
#!/usr/bin/env bash
#!/bin/sh
set -eu
cd -- "$(dirname "$0")/.."
set -euo pipefail
cd "$(dirname "${0}")"
cd "$(git rev-parse --show-toplevel)"
(
cd ./internal/examples
go test "$@" ./...
)
(
cd ./internal/thirdparty
go test "$@" ./...
)
mkdir -p ci/out/websocket
testFlags=(
-race
"-vet=off"
"-bench=."
"-coverprofile=ci/out/coverage.prof"
"-coverpkg=./..."
(
GOARCH=arm64 go test -c -o ./ci/out/websocket-arm64.test "$@" .
if [ "$#" -eq 0 ]; then
if [ "${CI-}" ]; then
sudo apt-get update
sudo apt-get install -y qemu-user-static
ln -s /usr/bin/qemu-aarch64-static /usr/local/bin/qemu-aarch64
fi
qemu-aarch64 ./ci/out/websocket-arm64.test -test.run=TestMask
fi
)
# https://circleci.com/docs/2.0/collect-test-data/
go run gotest.tools/gotestsum \
--junitfile ci/out/websocket/testReport.xml \
--format=short-verbose \
-- "${testFlags[@]}"
go install github.com/agnivade/wasmbrowsertest@8be019f6c6dceae821467b4c589eb195c2b761ce
go test --race --bench=. --timeout=1h --covermode=atomic --coverprofile=ci/out/coverage.prof --coverpkg=./... "$@" ./...
sed -i.bak '/stringer\.go/d' ci/out/coverage.prof
sed -i.bak '/nhooyr.io\/websocket\/internal\/test/d' ci/out/coverage.prof
sed -i.bak '/examples/d' ci/out/coverage.prof
# Last line is the total coverage.
go tool cover -func ci/out/coverage.prof | tail -n1
go tool cover -html=ci/out/coverage.prof -o=ci/out/coverage.html
if [[ ${CI:-} ]]; then
bash <(curl -s https://codecov.io/bash) -f ci/out/coverage.prof
fi
//go:build !js
// +build !js
package websocket
import (
"context"
"encoding/binary"
"errors"
"fmt"
"net"
"time"
"github.com/coder/websocket/internal/errd"
)
// StatusCode represents a WebSocket status code.
// https://tools.ietf.org/html/rfc6455#section-7.4
type StatusCode int
// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
//
// These are only the status codes defined by the protocol.
//
// You can define custom codes in the 3000-4999 range.
// The 3000-3999 range is reserved for use by libraries, frameworks and applications.
// The 4000-4999 range is reserved for private use.
const (
StatusNormalClosure StatusCode = 1000
StatusGoingAway StatusCode = 1001
StatusProtocolError StatusCode = 1002
StatusUnsupportedData StatusCode = 1003
// 1004 is reserved and so unexported.
statusReserved StatusCode = 1004
// StatusNoStatusRcvd cannot be sent in a close message.
// It is reserved for when a close message is received without
// a status code.
StatusNoStatusRcvd StatusCode = 1005
// StatusAbnormalClosure is exported for use only with Wasm.
// In non Wasm Go, the returned error will indicate whether the
// connection was closed abnormally.
StatusAbnormalClosure StatusCode = 1006
StatusInvalidFramePayloadData StatusCode = 1007
StatusPolicyViolation StatusCode = 1008
StatusMessageTooBig StatusCode = 1009
StatusMandatoryExtension StatusCode = 1010
StatusInternalError StatusCode = 1011
StatusServiceRestart StatusCode = 1012
StatusTryAgainLater StatusCode = 1013
StatusBadGateway StatusCode = 1014
// StatusTLSHandshake is only exported for use with Wasm.
// In non Wasm Go, the returned error will indicate whether there was
// a TLS handshake failure.
StatusTLSHandshake StatusCode = 1015
)
// CloseError is returned when the connection is closed with a status and reason.
//
// Use Go 1.13's errors.As to check for this error.
// Also see the CloseStatus helper.
type CloseError struct {
Code StatusCode
Reason string
}
func (ce CloseError) Error() string {
return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason)
}
// CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab
// the status code from a CloseError.
//
// -1 will be returned if the passed error is nil or not a CloseError.
func CloseStatus(err error) StatusCode {
var ce CloseError
if errors.As(err, &ce) {
return ce.Code
}
return -1
}
// Close performs the WebSocket close handshake with the given status code and reason.
//
// It will write a WebSocket close frame with a timeout of 5s and then wait 5s for
// the peer to send a close frame.
// All data messages received from the peer during the close handshake will be discarded.
//
// The connection can only be closed once. Additional calls to Close
// are no-ops.
//
// The maximum length of reason must be 125 bytes. Avoid sending a dynamic reason.
//
// Close will unblock all goroutines interacting with the connection once
// complete.
func (c *Conn) Close(code StatusCode, reason string) (err error) {
defer errd.Wrap(&err, "failed to close WebSocket")
if c.casClosing() {
err = c.waitGoroutines()
if err != nil {
return err
}
return net.ErrClosed
}
defer func() {
if errors.Is(err, net.ErrClosed) {
err = nil
}
}()
err = c.closeHandshake(code, reason)
err2 := c.close()
if err == nil && err2 != nil {
err = err2
}
err2 = c.waitGoroutines()
if err == nil && err2 != nil {
err = err2
}
return err
}
// CloseNow closes the WebSocket connection without attempting a close handshake.
// Use when you do not want the overhead of the close handshake.
func (c *Conn) CloseNow() (err error) {
defer errd.Wrap(&err, "failed to immediately close WebSocket")
if c.casClosing() {
err = c.waitGoroutines()
if err != nil {
return err
}
return net.ErrClosed
}
defer func() {
if errors.Is(err, net.ErrClosed) {
err = nil
}
}()
err = c.close()
err2 := c.waitGoroutines()
if err == nil && err2 != nil {
err = err2
}
return err
}
func (c *Conn) closeHandshake(code StatusCode, reason string) error {
err := c.writeClose(code, reason)
if err != nil {
return err
}
err = c.waitCloseHandshake()
if CloseStatus(err) != code {
return err
}
return nil
}
func (c *Conn) writeClose(code StatusCode, reason string) error {
ce := CloseError{
Code: code,
Reason: reason,
}
var p []byte
var err error
if ce.Code != StatusNoStatusRcvd {
p, err = ce.bytes()
if err != nil {
return err
}
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
err = c.writeControl(ctx, opClose, p)
// If the connection closed as we're writing we ignore the error as we might
// have written the close frame, the peer responded and then someone else read it
// and closed the connection.
if err != nil && !errors.Is(err, net.ErrClosed) {
return err
}
return nil
}
func (c *Conn) waitCloseHandshake() error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
err := c.readMu.lock(ctx)
if err != nil {
return err
}
defer c.readMu.unlock()
for i := int64(0); i < c.msgReader.payloadLength; i++ {
_, err := c.br.ReadByte()
if err != nil {
return err
}
}
for {
h, err := c.readLoop(ctx)
if err != nil {
return err
}
for i := int64(0); i < h.payloadLength; i++ {
_, err := c.br.ReadByte()
if err != nil {
return err
}
}
}
}
func (c *Conn) waitGoroutines() error {
t := time.NewTimer(time.Second * 15)
defer t.Stop()
select {
case <-c.timeoutLoopDone:
case <-t.C:
return errors.New("failed to wait for timeoutLoop goroutine to exit")
}
c.closeReadMu.Lock()
closeRead := c.closeReadCtx != nil
c.closeReadMu.Unlock()
if closeRead {
select {
case <-c.closeReadDone:
case <-t.C:
return errors.New("failed to wait for close read goroutine to exit")
}
}
select {
case <-c.closed:
case <-t.C:
return errors.New("failed to wait for connection to be closed")
}
return nil
}
func parseClosePayload(p []byte) (CloseError, error) {
if len(p) == 0 {
return CloseError{
Code: StatusNoStatusRcvd,
}, nil
}
if len(p) < 2 {
return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p)
}
ce := CloseError{
Code: StatusCode(binary.BigEndian.Uint16(p)),
Reason: string(p[2:]),
}
if !validWireCloseCode(ce.Code) {
return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code)
}
return ce, nil
}
// See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
// and https://tools.ietf.org/html/rfc6455#section-7.4.1
func validWireCloseCode(code StatusCode) bool {
switch code {
case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake:
return false
}
if code >= StatusNormalClosure && code <= StatusBadGateway {
return true
}
if code >= 3000 && code <= 4999 {
return true
}
return false
}
func (ce CloseError) bytes() ([]byte, error) {
p, err := ce.bytesErr()
if err != nil {
err = fmt.Errorf("failed to marshal close frame: %w", err)
ce = CloseError{
Code: StatusInternalError,
}
p, _ = ce.bytesErr()
}
return p, err
}
const maxCloseReason = maxControlPayload - 2
func (ce CloseError) bytesErr() ([]byte, error) {
if len(ce.Reason) > maxCloseReason {
return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason))
}
if !validWireCloseCode(ce.Code) {
return nil, fmt.Errorf("status code %v cannot be set", ce.Code)
}
buf := make([]byte, 2+len(ce.Reason))
binary.BigEndian.PutUint16(buf, uint16(ce.Code))
copy(buf[2:], ce.Reason)
return buf, nil
}
func (c *Conn) casClosing() bool {
return c.closing.Swap(true)
}
func (c *Conn) isClosed() bool {
select {
case <-c.closed:
return true
default:
return false
}
}
//go:build !js
// +build !js
package websocket
import (
"io"
"math"
"strings"
"testing"
"github.com/coder/websocket/internal/test/assert"
)
func TestCloseError(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
ce CloseError
success bool
}{
{
name: "normal",
ce: CloseError{
Code: StatusNormalClosure,
Reason: strings.Repeat("x", maxCloseReason),
},
success: true,
},
{
name: "bigReason",
ce: CloseError{
Code: StatusNormalClosure,
Reason: strings.Repeat("x", maxCloseReason+1),
},
success: false,
},
{
name: "bigCode",
ce: CloseError{
Code: math.MaxUint16,
Reason: strings.Repeat("x", maxCloseReason),
},
success: false,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
_, err := tc.ce.bytesErr()
if tc.success {
assert.Success(t, err)
} else {
assert.Error(t, err)
}
})
}
t.Run("Error", func(t *testing.T) {
exp := `status = StatusInternalError and reason = "meow"`
act := CloseError{
Code: StatusInternalError,
Reason: "meow",
}.Error()
assert.Equal(t, "CloseError.Error()", exp, act)
})
}
func Test_parseClosePayload(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
p []byte
success bool
ce CloseError
}{
{
name: "normal",
p: append([]byte{0x3, 0xE8}, []byte("hello")...),
success: true,
ce: CloseError{
Code: StatusNormalClosure,
Reason: "hello",
},
},
{
name: "nothing",
success: true,
ce: CloseError{
Code: StatusNoStatusRcvd,
},
},
{
name: "oneByte",
p: []byte{0},
success: false,
},
{
name: "badStatusCode",
p: []byte{0x17, 0x70},
success: false,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ce, err := parseClosePayload(tc.p)
if tc.success {
assert.Success(t, err)
assert.Equal(t, "close payload", tc.ce, ce)
} else {
assert.Error(t, err)
}
})
}
}
func Test_validWireCloseCode(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
code StatusCode
valid bool
}{
{
name: "normal",
code: StatusNormalClosure,
valid: true,
},
{
name: "noStatus",
code: StatusNoStatusRcvd,
valid: false,
},
{
name: "3000",
code: 3000,
valid: true,
},
{
name: "4999",
code: 4999,
valid: true,
},
{
name: "unknown",
code: 5000,
valid: false,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
act := validWireCloseCode(tc.code)
assert.Equal(t, "wire close code", tc.valid, act)
})
}
}
func TestCloseStatus(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
in error
exp StatusCode
}{
{
name: "nil",
in: nil,
exp: -1,
},
{
name: "io.EOF",
in: io.EOF,
exp: -1,
},
{
name: "StatusInternalError",
in: CloseError{
Code: StatusInternalError,
},
exp: StatusInternalError,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
act := CloseStatus(tc.in)
assert.Equal(t, "close status", tc.exp, act)
})
}
}
//go:build !js
// +build !js
package websocket
import (
"compress/flate"
"io"
"sync"
)
// CompressionMode represents the modes available to the permessage-deflate extension.
// See https://tools.ietf.org/html/rfc7692
//
// Works in all modern browsers except Safari which does not implement the permessage-deflate extension.
//
// Compression is only used if the peer supports the mode selected.
type CompressionMode int
const (
// CompressionDisabled disables the negotiation of the permessage-deflate extension.
//
// This is the default. Do not enable compression without benchmarking for your particular use case first.
CompressionDisabled CompressionMode = iota
// CompressionContextTakeover compresses each message greater than 128 bytes reusing the 32 KB sliding window from
// previous messages. i.e compression context across messages is preserved.
//
// As most WebSocket protocols are text based and repetitive, this compression mode can be very efficient.
//
// The memory overhead is a fixed 32 KB sliding window, a fixed 1.2 MB flate.Writer and a sync.Pool of 40 KB flate.Reader's
// that are used when reading and then returned.
//
// Thus, it uses more memory than CompressionNoContextTakeover but compresses more efficiently.
//
// If the peer does not support CompressionContextTakeover then we will fall back to CompressionNoContextTakeover.
CompressionContextTakeover
// CompressionNoContextTakeover compresses each message greater than 512 bytes. Each message is compressed with
// a new 1.2 MB flate.Writer pulled from a sync.Pool. Each message is read with a 40 KB flate.Reader pulled from
// a sync.Pool.
//
// This means less efficient compression as the sliding window from previous messages will not be used but the
// memory overhead will be lower as there will be no fixed cost for the flate.Writer nor the 32 KB sliding window.
// Especially if the connections are long lived and seldom written to.
//
// Thus, it uses less memory than CompressionContextTakeover but compresses less efficiently.
//
// If the peer does not support CompressionNoContextTakeover then we will fall back to CompressionDisabled.
CompressionNoContextTakeover
)
func (m CompressionMode) opts() *compressionOptions {
return &compressionOptions{
clientNoContextTakeover: m == CompressionNoContextTakeover,
serverNoContextTakeover: m == CompressionNoContextTakeover,
}
}
type compressionOptions struct {
clientNoContextTakeover bool
serverNoContextTakeover bool
}
func (copts *compressionOptions) String() string {
s := "permessage-deflate"
if copts.clientNoContextTakeover {
s += "; client_no_context_takeover"
}
if copts.serverNoContextTakeover {
s += "; server_no_context_takeover"
}
return s
}
// These bytes are required to get flate.Reader to return.
// They are removed when sending to avoid the overhead as
// WebSocket framing tell's when the message has ended but then
// we need to add them back otherwise flate.Reader keeps
// trying to read more bytes.
const deflateMessageTail = "\x00\x00\xff\xff"
type trimLastFourBytesWriter struct {
w io.Writer
tail []byte
}
func (tw *trimLastFourBytesWriter) reset() {
if tw != nil && tw.tail != nil {
tw.tail = tw.tail[:0]
}
}
func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) {
if tw.tail == nil {
tw.tail = make([]byte, 0, 4)
}
extra := len(tw.tail) + len(p) - 4
if extra <= 0 {
tw.tail = append(tw.tail, p...)
return len(p), nil
}
// Now we need to write as many extra bytes as we can from the previous tail.
if extra > len(tw.tail) {
extra = len(tw.tail)
}
if extra > 0 {
_, err := tw.w.Write(tw.tail[:extra])
if err != nil {
return 0, err
}
// Shift remaining bytes in tail over.
n := copy(tw.tail, tw.tail[extra:])
tw.tail = tw.tail[:n]
}
// If p is less than or equal to 4 bytes,
// all of it is is part of the tail.
if len(p) <= 4 {
tw.tail = append(tw.tail, p...)
return len(p), nil
}
// Otherwise, only the last 4 bytes are.
tw.tail = append(tw.tail, p[len(p)-4:]...)
p = p[:len(p)-4]
n, err := tw.w.Write(p)
return n + 4, err
}
var flateReaderPool sync.Pool
func getFlateReader(r io.Reader, dict []byte) io.Reader {
fr, ok := flateReaderPool.Get().(io.Reader)
if !ok {
return flate.NewReaderDict(r, dict)
}
fr.(flate.Resetter).Reset(r, dict)
return fr
}
func putFlateReader(fr io.Reader) {
flateReaderPool.Put(fr)
}
var flateWriterPool sync.Pool
func getFlateWriter(w io.Writer) *flate.Writer {
fw, ok := flateWriterPool.Get().(*flate.Writer)
if !ok {
fw, _ = flate.NewWriter(w, flate.BestSpeed)
return fw
}
fw.Reset(w)
return fw
}
func putFlateWriter(w *flate.Writer) {
flateWriterPool.Put(w)
}
type slidingWindow struct {
buf []byte
}
var swPoolMu sync.RWMutex
var swPool = map[int]*sync.Pool{}
func slidingWindowPool(n int) *sync.Pool {
swPoolMu.RLock()
p, ok := swPool[n]
swPoolMu.RUnlock()
if ok {
return p
}
p = &sync.Pool{}
swPoolMu.Lock()
swPool[n] = p
swPoolMu.Unlock()
return p
}
func (sw *slidingWindow) init(n int) {
if sw.buf != nil {
return
}
if n == 0 {
n = 32768
}
p := slidingWindowPool(n)
sw2, ok := p.Get().(*slidingWindow)
if ok {
*sw = *sw2
} else {
sw.buf = make([]byte, 0, n)
}
}
func (sw *slidingWindow) close() {
sw.buf = sw.buf[:0]
swPoolMu.Lock()
swPool[cap(sw.buf)].Put(sw)
swPoolMu.Unlock()
}
func (sw *slidingWindow) write(p []byte) {
if len(p) >= cap(sw.buf) {
sw.buf = sw.buf[:cap(sw.buf)]
p = p[len(p)-cap(sw.buf):]
copy(sw.buf, p)
return
}
left := cap(sw.buf) - len(sw.buf)
if left < len(p) {
// We need to shift spaceNeeded bytes from the end to make room for p at the end.
spaceNeeded := len(p) - left
copy(sw.buf, sw.buf[spaceNeeded:])
sw.buf = sw.buf[:len(sw.buf)-spaceNeeded]
}
sw.buf = append(sw.buf, p...)
}
//go:build !js
// +build !js
package websocket
import (
"bytes"
"compress/flate"
"io"
"strings"
"testing"
"github.com/coder/websocket/internal/test/assert"
"github.com/coder/websocket/internal/test/xrand"
)
func Test_slidingWindow(t *testing.T) {
t.Parallel()
const testCount = 99
const maxWindow = 99999
for i := 0; i < testCount; i++ {
t.Run("", func(t *testing.T) {
t.Parallel()
input := xrand.String(maxWindow)
windowLength := xrand.Int(maxWindow)
var sw slidingWindow
sw.init(windowLength)
sw.write([]byte(input))
assert.Equal(t, "window length", windowLength, cap(sw.buf))
if !strings.HasSuffix(input, string(sw.buf)) {
t.Fatalf("r.buf is not a suffix of input: %q and %q", input, sw.buf)
}
})
}
}
func BenchmarkFlateWriter(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
w, _ := flate.NewWriter(io.Discard, flate.BestSpeed)
// We have to write a byte to get the writer to allocate to its full extent.
w.Write([]byte{'a'})
w.Flush()
}
}
func BenchmarkFlateReader(b *testing.B) {
b.ReportAllocs()
var buf bytes.Buffer
w, _ := flate.NewWriter(&buf, flate.BestSpeed)
w.Write([]byte{'a'})
w.Flush()
for i := 0; i < b.N; i++ {
r := flate.NewReader(bytes.NewReader(buf.Bytes()))
io.ReadAll(r)
}
}
//go:build !js
// +build !js
package websocket
import (
"bufio"
"context"
"fmt"
"io"
"net"
"runtime"
"strconv"
"sync"
"sync/atomic"
)
// MessageType represents the type of a WebSocket message.
// See https://tools.ietf.org/html/rfc6455#section-5.6
type MessageType int
// MessageType constants.
const (
// MessageText is for UTF-8 encoded text messages like JSON.
MessageText MessageType = iota + 1
// MessageBinary is for binary messages like protobufs.
MessageBinary
)
// Conn represents a WebSocket connection.
// All methods may be called concurrently except for Reader and Read.
//
// You must always read from the connection. Otherwise control
// frames will not be handled. See Reader and CloseRead.
//
// Be sure to call Close on the connection when you
// are finished with it to release associated resources.
//
// On any error from any method, the connection is closed
// with an appropriate reason.
//
// This applies to context expirations as well unfortunately.
// See https://github.com/nhooyr/websocket/issues/242#issuecomment-633182220
type Conn struct {
noCopy noCopy
subprotocol string
rwc io.ReadWriteCloser
client bool
copts *compressionOptions
flateThreshold int
br *bufio.Reader
bw *bufio.Writer
readTimeout chan context.Context
writeTimeout chan context.Context
timeoutLoopDone chan struct{}
// Read state.
readMu *mu
readHeaderBuf [8]byte
readControlBuf [maxControlPayload]byte
msgReader *msgReader
// Write state.
msgWriter *msgWriter
writeFrameMu *mu
writeBuf []byte
writeHeaderBuf [8]byte
writeHeader header
// Close handshake state.
closeStateMu sync.RWMutex
closeReceivedErr error
closeSentErr error
// CloseRead state.
closeReadMu sync.Mutex
closeReadCtx context.Context
closeReadDone chan struct{}
closing atomic.Bool
closeMu sync.Mutex // Protects following.
closed chan struct{}
pingCounter atomic.Int64
activePingsMu sync.Mutex
activePings map[string]chan<- struct{}
onPingReceived func(context.Context, []byte) bool
onPongReceived func(context.Context, []byte)
}
type connConfig struct {
subprotocol string
rwc io.ReadWriteCloser
client bool
copts *compressionOptions
flateThreshold int
onPingReceived func(context.Context, []byte) bool
onPongReceived func(context.Context, []byte)
br *bufio.Reader
bw *bufio.Writer
}
func newConn(cfg connConfig) *Conn {
c := &Conn{
subprotocol: cfg.subprotocol,
rwc: cfg.rwc,
client: cfg.client,
copts: cfg.copts,
flateThreshold: cfg.flateThreshold,
br: cfg.br,
bw: cfg.bw,
readTimeout: make(chan context.Context),
writeTimeout: make(chan context.Context),
timeoutLoopDone: make(chan struct{}),
closed: make(chan struct{}),
activePings: make(map[string]chan<- struct{}),
onPingReceived: cfg.onPingReceived,
onPongReceived: cfg.onPongReceived,
}
c.readMu = newMu(c)
c.writeFrameMu = newMu(c)
c.msgReader = newMsgReader(c)
c.msgWriter = newMsgWriter(c)
if c.client {
c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc)
}
if c.flate() && c.flateThreshold == 0 {
c.flateThreshold = 128
if !c.msgWriter.flateContextTakeover() {
c.flateThreshold = 512
}
}
runtime.SetFinalizer(c, func(c *Conn) {
c.close()
})
go c.timeoutLoop()
return c
}
// Subprotocol returns the negotiated subprotocol.
// An empty string means the default protocol.
func (c *Conn) Subprotocol() string {
return c.subprotocol
}
func (c *Conn) close() error {
c.closeMu.Lock()
defer c.closeMu.Unlock()
if c.isClosed() {
return net.ErrClosed
}
runtime.SetFinalizer(c, nil)
close(c.closed)
// Have to close after c.closed is closed to ensure any goroutine that wakes up
// from the connection being closed also sees that c.closed is closed and returns
// closeErr.
err := c.rwc.Close()
// With the close of rwc, these become safe to close.
c.msgWriter.close()
c.msgReader.close()
return err
}
func (c *Conn) timeoutLoop() {
defer close(c.timeoutLoopDone)
readCtx := context.Background()
writeCtx := context.Background()
for {
select {
case <-c.closed:
return
case writeCtx = <-c.writeTimeout:
case readCtx = <-c.readTimeout:
case <-readCtx.Done():
c.close()
return
case <-writeCtx.Done():
c.close()
return
}
}
}
func (c *Conn) flate() bool {
return c.copts != nil
}
// Ping sends a ping to the peer and waits for a pong.
// Use this to measure latency or ensure the peer is responsive.
// Ping must be called concurrently with Reader as it does
// not read from the connection but instead waits for a Reader call
// to read the pong.
//
// TCP Keepalives should suffice for most use cases.
func (c *Conn) Ping(ctx context.Context) error {
p := c.pingCounter.Add(1)
err := c.ping(ctx, strconv.FormatInt(p, 10))
if err != nil {
return fmt.Errorf("failed to ping: %w", err)
}
return nil
}
func (c *Conn) ping(ctx context.Context, p string) error {
pong := make(chan struct{}, 1)
c.activePingsMu.Lock()
c.activePings[p] = pong
c.activePingsMu.Unlock()
defer func() {
c.activePingsMu.Lock()
delete(c.activePings, p)
c.activePingsMu.Unlock()
}()
err := c.writeControl(ctx, opPing, []byte(p))
if err != nil {
return err
}
select {
case <-c.closed:
return net.ErrClosed
case <-ctx.Done():
return fmt.Errorf("failed to wait for pong: %w", ctx.Err())
case <-pong:
return nil
}
}
type mu struct {
c *Conn
ch chan struct{}
}
func newMu(c *Conn) *mu {
return &mu{
c: c,
ch: make(chan struct{}, 1),
}
}
func (m *mu) forceLock() {
m.ch <- struct{}{}
}
func (m *mu) tryLock() bool {
select {
case m.ch <- struct{}{}:
return true
default:
return false
}
}
func (m *mu) lock(ctx context.Context) error {
select {
case <-m.c.closed:
return net.ErrClosed
case <-ctx.Done():
return fmt.Errorf("failed to acquire lock: %w", ctx.Err())
case m.ch <- struct{}{}:
// To make sure the connection is certainly alive.
// As it's possible the send on m.ch was selected
// over the receive on closed.
select {
case <-m.c.closed:
// Make sure to release.
m.unlock()
return net.ErrClosed
default:
}
return nil
}
}
func (m *mu) unlock() {
select {
case <-m.ch:
default:
}
}
type noCopy struct{}
func (*noCopy) Lock() {}
This diff is collapsed.
//go:build !js
// +build !js
package websocket
import (
"bufio"
"bytes"
"context"
"crypto/rand"
"encoding/base64"
"fmt"
"io"
"io/ioutil"
"math/rand"
"net/http"
"net/url"
"strings"
"sync"
"time"
"golang.org/x/xerrors"
"github.com/coder/websocket/internal/errd"
)
// DialOptions represents the options available to pass to Dial.
// DialOptions represents Dial's options.
type DialOptions struct {
// HTTPClient is the http client used for the handshake.
// Its Transport must return writable bodies
// for WebSocket handshakes.
// http.Transport does this correctly beginning with Go 1.12.
// HTTPClient is used for the connection.
// Its Transport must return writable bodies for WebSocket handshakes.
// http.Transport does beginning with Go 1.12.
HTTPClient *http.Client
// HTTPHeader specifies the HTTP headers included in the handshake request.
HTTPHeader http.Header
// Subprotocols lists the subprotocols to negotiate with the server.
// Host optionally overrides the Host HTTP header to send. If empty, the value
// of URL.Host will be used.
Host string
// Subprotocols lists the WebSocket subprotocols to negotiate with the server.
Subprotocols []string
// CompressionMode controls the compression mode.
// Defaults to CompressionDisabled.
//
// See docs on CompressionMode for details.
CompressionMode CompressionMode
// CompressionThreshold controls the minimum size of a message before compression is applied.
//
// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
// for CompressionContextTakeover.
CompressionThreshold int
// OnPingReceived is an optional callback invoked synchronously when a ping frame is received.
//
// The payload contains the application data of the ping frame.
// If the callback returns false, the subsequent pong frame will not be sent.
// To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
OnPingReceived func(ctx context.Context, payload []byte) bool
// OnPongReceived is an optional callback invoked synchronously when a pong frame is received.
//
// The payload contains the application data of the pong frame.
// To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
//
// Unlike OnPingReceived, this callback does not return a value because a pong frame
// is a response to a ping and does not trigger any further frame transmission.
OnPongReceived func(ctx context.Context, payload []byte)
}
// Dial performs a WebSocket handshake on the given url with the given options.
func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) {
var cancel context.CancelFunc
var o DialOptions
if opts != nil {
o = *opts
}
if o.HTTPClient == nil {
o.HTTPClient = http.DefaultClient
}
if o.HTTPClient.Timeout > 0 {
ctx, cancel = context.WithTimeout(ctx, o.HTTPClient.Timeout)
newClient := *o.HTTPClient
newClient.Timeout = 0
o.HTTPClient = &newClient
}
if o.HTTPHeader == nil {
o.HTTPHeader = http.Header{}
}
newClient := *o.HTTPClient
oldCheckRedirect := o.HTTPClient.CheckRedirect
newClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
switch req.URL.Scheme {
case "ws":
req.URL.Scheme = "http"
case "wss":
req.URL.Scheme = "https"
}
if oldCheckRedirect != nil {
return oldCheckRedirect(req, via)
}
return nil
}
o.HTTPClient = &newClient
return ctx, cancel, &o
}
// Dial performs a WebSocket handshake on url.
//
// The response is the WebSocket handshake response from the server.
// If an error occurs, the returned response may be non nil. However, you can only
// read the first 1024 bytes of its body.
// You never need to close resp.Body yourself.
//
// If an error occurs, the returned response may be non nil.
// However, you can only read the first 1024 bytes of the body.
//
// You never need to close the resp.Body yourself.
// This function requires at least Go 1.12 as it uses a new feature
// in net/http to perform WebSocket handshakes.
// See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861
//
// This function requires at least Go 1.12 to succeed as it uses a new feature
// in net/http to perform WebSocket handshakes and get a writable body
// from the transport. See https://github.com/golang/go/issues/26937#issuecomment-415855861
func Dial(ctx context.Context, u string, opts DialOptions) (*Conn, *http.Response, error) {
c, r, err := dial(ctx, u, opts)
// URLs with http/https schemes will work and are interpreted as ws/wss.
func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) {
return dial(ctx, u, opts, nil)
}
func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) {
defer errd.Wrap(&err, "failed to WebSocket dial")
var cancel context.CancelFunc
ctx, cancel, opts = opts.cloneWithDefaults(ctx)
if cancel != nil {
defer cancel()
}
secWebSocketKey, err := secWebSocketKey(rand)
if err != nil {
return nil, r, xerrors.Errorf("failed to websocket dial: %w", err)
return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err)
}
return c, r, nil
}
func dial(ctx context.Context, u string, opts DialOptions) (_ *Conn, _ *http.Response, err error) {
if opts.HTTPClient == nil {
opts.HTTPClient = http.DefaultClient
var copts *compressionOptions
if opts.CompressionMode != CompressionDisabled {
copts = opts.CompressionMode.opts()
}
if opts.HTTPClient.Timeout > 0 {
return nil, nil, xerrors.Errorf("please use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67")
resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey)
if err != nil {
return nil, resp, err
}
respBody := resp.Body
resp.Body = nil
defer func() {
if err != nil {
// We read a bit of the body for easier debugging.
r := io.LimitReader(respBody, 1024)
timer := time.AfterFunc(time.Second*3, func() {
respBody.Close()
})
defer timer.Stop()
b, _ := io.ReadAll(r)
respBody.Close()
resp.Body = io.NopCloser(bytes.NewReader(b))
}
}()
copts, err = verifyServerResponse(opts, copts, secWebSocketKey, resp)
if err != nil {
return nil, resp, err
}
if opts.HTTPHeader == nil {
opts.HTTPHeader = http.Header{}
rwc, ok := respBody.(io.ReadWriteCloser)
if !ok {
return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", respBody)
}
parsedURL, err := url.Parse(u)
return newConn(connConfig{
subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"),
rwc: rwc,
client: true,
copts: copts,
flateThreshold: opts.CompressionThreshold,
onPingReceived: opts.OnPingReceived,
onPongReceived: opts.OnPongReceived,
br: getBufioReader(rwc),
bw: getBufioWriter(rwc),
}), resp, nil
}
func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) {
u, err := url.Parse(urls)
if err != nil {
return nil, nil, xerrors.Errorf("failed to parse url: %w", err)
return nil, fmt.Errorf("failed to parse url: %w", err)
}
switch parsedURL.Scheme {
switch u.Scheme {
case "ws":
parsedURL.Scheme = "http"
u.Scheme = "http"
case "wss":
parsedURL.Scheme = "https"
u.Scheme = "https"
case "http", "https":
default:
return nil, nil, xerrors.Errorf("unexpected url scheme: %q", parsedURL.Scheme)
return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme)
}
req, _ := http.NewRequest("GET", parsedURL.String(), nil)
req = req.WithContext(ctx)
req.Header = opts.HTTPHeader
req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
if err != nil {
return nil, fmt.Errorf("failed to create new http request: %w", err)
}
if len(opts.Host) > 0 {
req.Host = opts.Host
}
req.Header = opts.HTTPHeader.Clone()
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
req.Header.Set("Sec-WebSocket-Version", "13")
req.Header.Set("Sec-WebSocket-Key", makeSecWebSocketKey())
req.Header.Set("Sec-WebSocket-Key", secWebSocketKey)
if len(opts.Subprotocols) > 0 {
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
}
if copts != nil {
req.Header.Set("Sec-WebSocket-Extensions", copts.String())
}
resp, err := opts.HTTPClient.Do(req)
if err != nil {
return nil, nil, xerrors.Errorf("failed to send handshake request: %w", err)
return nil, fmt.Errorf("failed to send handshake request: %w", err)
}
defer func() {
if err != nil {
// We read a bit of the body for easier debugging.
r := io.LimitReader(resp.Body, 1024)
b, _ := ioutil.ReadAll(r)
resp.Body.Close()
resp.Body = ioutil.NopCloser(bytes.NewReader(b))
}
}()
return resp, nil
}
err = verifyServerResponse(req, resp)
func secWebSocketKey(rr io.Reader) (string, error) {
if rr == nil {
rr = rand.Reader
}
b := make([]byte, 16)
_, err := io.ReadFull(rr, b)
if err != nil {
return nil, resp, err
return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err)
}
return base64.StdEncoding.EncodeToString(b), nil
}
rwc, ok := resp.Body.(io.ReadWriteCloser)
if !ok {
return nil, resp, xerrors.Errorf("response body is not a io.ReadWriteCloser: %T", rwc)
func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) {
if resp.StatusCode != http.StatusSwitchingProtocols {
return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
}
c := &Conn{
subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"),
br: getBufioReader(rwc),
bw: getBufioWriter(rwc),
closer: rwc,
client: true,
if !headerContainsTokenIgnoreCase(resp.Header, "Connection", "Upgrade") {
return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
}
c.extractBufioWriterBuf(rwc)
c.init()
return c, resp, nil
}
if !headerContainsTokenIgnoreCase(resp.Header, "Upgrade", "WebSocket") {
return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
}
func verifyServerResponse(r *http.Request, resp *http.Response) error {
if resp.StatusCode != http.StatusSwitchingProtocols {
return xerrors.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) {
return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
resp.Header.Get("Sec-WebSocket-Accept"),
secWebSocketKey,
)
}
if !headerValuesContainsToken(resp.Header, "Connection", "Upgrade") {
return xerrors.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
err := verifySubprotocol(opts.Subprotocols, resp)
if err != nil {
return nil, err
}
if !headerValuesContainsToken(resp.Header, "Upgrade", "WebSocket") {
return xerrors.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
return verifyServerExtensions(copts, resp.Header)
}
func verifySubprotocol(subprotos []string, resp *http.Response) error {
proto := resp.Header.Get("Sec-WebSocket-Protocol")
if proto == "" {
return nil
}
if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")) {
return xerrors.Errorf("websocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
resp.Header.Get("Sec-WebSocket-Accept"),
r.Header.Get("Sec-WebSocket-Key"),
)
for _, sp2 := range subprotos {
if strings.EqualFold(sp2, proto) {
return nil
}
}
return nil
return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto)
}
// The below pools can only be used by the client because http.Hijacker will always
// have a bufio.Reader/Writer for us so it doesn't make sense to use a pool on top.
func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compressionOptions, error) {
exts := websocketExtensions(h)
if len(exts) == 0 {
return nil, nil
}
ext := exts[0]
if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil {
return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:])
}
_copts := *copts
copts = &_copts
for _, p := range ext.params {
switch p {
case "client_no_context_takeover":
copts.clientNoContextTakeover = true
continue
case "server_no_context_takeover":
copts.serverNoContextTakeover = true
continue
}
if strings.HasPrefix(p, "server_max_window_bits=") {
// We can't adjust the deflate window, but decoding with a larger window is acceptable.
continue
}
return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
}
var bufioReaderPool = sync.Pool{
New: func() interface{} {
return bufio.NewReader(nil)
},
return copts, nil
}
var bufioReaderPool sync.Pool
func getBufioReader(r io.Reader) *bufio.Reader {
br := bufioReaderPool.Get().(*bufio.Reader)
br, ok := bufioReaderPool.Get().(*bufio.Reader)
if !ok {
return bufio.NewReader(r)
}
br.Reset(r)
return br
}
func returnBufioReader(br *bufio.Reader) {
func putBufioReader(br *bufio.Reader) {
bufioReaderPool.Put(br)
}
var bufioWriterPool = sync.Pool{
New: func() interface{} {
return bufio.NewWriter(nil)
},
}
var bufioWriterPool sync.Pool
func getBufioWriter(w io.Writer) *bufio.Writer {
bw := bufioWriterPool.Get().(*bufio.Writer)
bw, ok := bufioWriterPool.Get().(*bufio.Writer)
if !ok {
return bufio.NewWriter(w)
}
bw.Reset(w)
return bw
}
func returnBufioWriter(bw *bufio.Writer) {
func putBufioWriter(bw *bufio.Writer) {
bufioWriterPool.Put(bw)
}
func makeSecWebSocketKey() string {
b := make([]byte, 16)
rand.Read(b)
return base64.StdEncoding.EncodeToString(b)
}
package websocket
//go:build !js
// +build !js
package websocket_test
import (
"bytes"
"context"
"crypto/rand"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/coder/websocket"
"github.com/coder/websocket/internal/test/assert"
"github.com/coder/websocket/internal/util"
"github.com/coder/websocket/internal/xsync"
)
func TestBadDials(t *testing.T) {
t.Parallel()
t.Run("badReq", func(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
url string
opts *websocket.DialOptions
rand util.ReaderFunc
nilCtx bool
}{
{
name: "badURL",
url: "://noscheme",
},
{
name: "badURLScheme",
url: "ftp://nhooyr.io",
},
{
name: "badTLS",
url: "wss://totallyfake.nhooyr.io",
},
{
name: "badReader",
rand: func(p []byte) (int, error) {
return 0, io.EOF
},
},
{
name: "nilContext",
url: "http://localhost",
nilCtx: true,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
var ctx context.Context
var cancel func()
if !tc.nilCtx {
ctx, cancel = context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
}
if tc.rand == nil {
tc.rand = rand.Reader.Read
}
_, _, err := websocket.ExportedDial(ctx, tc.url, tc.opts, tc.rand)
assert.Error(t, err)
})
}
})
t.Run("badResponse", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
_, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
HTTPClient: mockHTTPClient(func(*http.Request) (*http.Response, error) {
return &http.Response{
Body: io.NopCloser(strings.NewReader("hi")),
}, nil
}),
})
assert.Contains(t, err, "failed to WebSocket dial: expected handshake response status code 101 but got 0")
})
t.Run("badBody", func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
rt := func(r *http.Request) (*http.Response, error) {
h := http.Header{}
h.Set("Connection", "Upgrade")
h.Set("Upgrade", "websocket")
h.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))
return &http.Response{
StatusCode: http.StatusSwitchingProtocols,
Header: h,
Body: io.NopCloser(strings.NewReader("hi")),
}, nil
}
_, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
HTTPClient: mockHTTPClient(rt),
})
assert.Contains(t, err, "response body is not a io.ReadWriteCloser")
})
}
func Test_verifyHostOverride(t *testing.T) {
testCases := []struct {
name string
host string
exp string
}{
{
name: "noOverride",
host: "",
exp: "example.com",
},
{
name: "hostOverride",
host: "example.net",
exp: "example.net",
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
rt := func(r *http.Request) (*http.Response, error) {
assert.Equal(t, "Host", tc.exp, r.Host)
h := http.Header{}
h.Set("Connection", "Upgrade")
h.Set("Upgrade", "websocket")
h.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))
return &http.Response{
StatusCode: http.StatusSwitchingProtocols,
Header: h,
Body: mockBody{bytes.NewBufferString("hi")},
}, nil
}
c, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
HTTPClient: mockHTTPClient(rt),
Host: tc.host,
})
assert.Success(t, err)
c.CloseNow()
})
}
}
type mockBody struct {
*bytes.Buffer
}
func (mb mockBody) Close() error {
return nil
}
func Test_verifyServerHandshake(t *testing.T) {
t.Parallel()
......@@ -48,6 +225,36 @@ func Test_verifyServerHandshake(t *testing.T) {
},
success: false,
},
{
name: "badSecWebSocketProtocol",
response: func(w http.ResponseWriter) {
w.Header().Set("Connection", "Upgrade")
w.Header().Set("Upgrade", "websocket")
w.Header().Set("Sec-WebSocket-Protocol", "xd")
w.WriteHeader(http.StatusSwitchingProtocols)
},
success: false,
},
{
name: "unsupportedExtension",
response: func(w http.ResponseWriter) {
w.Header().Set("Connection", "Upgrade")
w.Header().Set("Upgrade", "websocket")
w.Header().Set("Sec-WebSocket-Extensions", "meow")
w.WriteHeader(http.StatusSwitchingProtocols)
},
success: false,
},
{
name: "unsupportedDeflateParam",
response: func(w http.ResponseWriter) {
w.Header().Set("Connection", "Upgrade")
w.Header().Set("Upgrade", "websocket")
w.Header().Set("Sec-WebSocket-Extensions", "permessage-deflate; meow")
w.WriteHeader(http.StatusSwitchingProtocols)
},
success: false,
},
{
name: "success",
response: func(w http.ResponseWriter) {
......@@ -69,17 +276,145 @@ func Test_verifyServerHandshake(t *testing.T) {
resp := w.Result()
r := httptest.NewRequest("GET", "/", nil)
key := makeSecWebSocketKey()
key, err := websocket.SecWebSocketKey(rand.Reader)
assert.Success(t, err)
r.Header.Set("Sec-WebSocket-Key", key)
if resp.Header.Get("Sec-WebSocket-Accept") == "" {
resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
resp.Header.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(key))
}
err := verifyServerResponse(r, resp)
if (err == nil) != tc.success {
t.Fatalf("unexpected error: %+v", err)
opts := &websocket.DialOptions{
Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","),
}
_, err = websocket.VerifyServerResponse(opts, websocket.CompressionModeOpts(opts.CompressionMode), key, resp)
if tc.success {
assert.Success(t, err)
} else {
assert.Error(t, err)
}
})
}
}
func mockHTTPClient(fn roundTripperFunc) *http.Client {
return &http.Client{
Transport: fn,
}
}
type roundTripperFunc func(*http.Request) (*http.Response, error)
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return f(r)
}
func TestDialRedirect(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
_, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
HTTPClient: mockHTTPClient(func(r *http.Request) (*http.Response, error) {
resp := &http.Response{
Header: http.Header{},
}
if r.URL.Scheme != "https" {
resp.Header.Set("Location", "wss://example.com")
resp.StatusCode = http.StatusFound
return resp, nil
}
resp.Header.Set("Connection", "Upgrade")
resp.Header.Set("Upgrade", "meow")
resp.StatusCode = http.StatusSwitchingProtocols
return resp, nil
}),
})
assert.Contains(t, err, "failed to WebSocket dial: WebSocket protocol violation: Upgrade header \"meow\" does not contain websocket")
}
type forwardProxy struct {
hc *http.Client
}
func newForwardProxy() *forwardProxy {
return &forwardProxy{
hc: &http.Client{},
}
}
func (fc *forwardProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), time.Second*10)
defer cancel()
r = r.WithContext(ctx)
r.RequestURI = ""
resp, err := fc.hc.Do(r)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
defer resp.Body.Close()
for k, v := range resp.Header {
w.Header()[k] = v
}
w.Header().Set("PROXIED", "true")
w.WriteHeader(resp.StatusCode)
if resprw, ok := resp.Body.(io.ReadWriter); ok {
c, brw, err := w.(http.Hijacker).Hijack()
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
brw.Flush()
errc1 := xsync.Go(func() error {
_, err := io.Copy(c, resprw)
return err
})
errc2 := xsync.Go(func() error {
_, err := io.Copy(resprw, c)
return err
})
select {
case <-errc1:
case <-errc2:
case <-r.Context().Done():
}
} else {
io.Copy(w, resp.Body)
}
}
func TestDialViaProxy(t *testing.T) {
t.Parallel()
ps := httptest.NewServer(newForwardProxy())
defer ps.Close()
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := echoServer(w, r, nil)
assert.Success(t, err)
}))
defer s.Close()
psu, err := url.Parse(ps.URL)
assert.Success(t, err)
proxyTransport := http.DefaultTransport.(*http.Transport).Clone()
proxyTransport.Proxy = http.ProxyURL(psu)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
c, resp, err := websocket.Dial(ctx, s.URL, &websocket.DialOptions{
HTTPClient: &http.Client{
Transport: proxyTransport,
},
})
assert.Success(t, err)
assert.Equal(t, "", "true", resp.Header.Get("PROXIED"))
assertEcho(t, ctx, c)
assertClose(t, c)
}
// Package websocket is a minimal and idiomatic implementation of the WebSocket protocol.
//go:build !js
// +build !js
// Package websocket implements the RFC 6455 WebSocket protocol.
//
// See https://tools.ietf.org/html/rfc6455
// https://tools.ietf.org/html/rfc6455
//
// Conn, Dial, and Accept are the main entrypoints into this package. Use Dial to dial
// a WebSocket server, Accept to accept a WebSocket client dial and then Conn to interact
// with the resulting WebSocket connections.
// Use Dial to dial a WebSocket server.
//
// Use Accept to accept a WebSocket client.
//
// Conn represents the resulting WebSocket connection.
//
// The examples are the best way to understand how to correctly use the library.
//
// The wsjson and wspb subpackages contain helpers for JSON and ProtoBuf messages.
// The wsjson subpackage contain helpers for JSON and protobuf messages.
//
// More documentation at https://github.com/coder/websocket.
//
// # Wasm
//
// The client side supports compiling to Wasm.
// It wraps the WebSocket browser API.
//
// See https://developer.mozilla.org/en-US/docs/Web/API/WebSocket
//
// Please see https://nhooyr.io/websocket for more overview docs and a
// comparison with existing implementations.
// Some important caveats to be aware of:
//
// Please be sure to use the https://golang.org/x/xerrors package when inspecting returned errors.
package websocket
// - Accept always errors out
// - Conn.Ping is no-op
// - Conn.CloseNow is Close(StatusGoingAway, "")
// - HTTPClient, HTTPHeader and CompressionMode in DialOptions are no-op
// - *http.Response from Dial is &http.Response{} with a 101 status code on success
package websocket // import "github.com/coder/websocket"
* @nhooyr
# Contributing
## Issues
Please be as descriptive as possible with your description.
Reproducible examples are key to fixing bugs and strongly encouraged.
## Pull requests
Good issues for first time contributors are marked as such. Please feel free to
reach out for clarification on what needs to be done.
Split up large changes into several small descriptive commits.
Capitalize the first word in the commit message title.
The commit message title should use the verb tense + phrase that completes the blank in
> This change modifies websocket to \_\_\_\_\_\_\_\_\_
Be sure to link to an existing issue if one exists. In general, create an issue
before a PR to get some discussion going and to make sure you do not spend time
on a PR that may be rejected.
You can run tests normally with `go test`.
You'll need the [Autobahn Test suite pip package](https://github.com/crossbario/autobahn-testsuite).
In the future this dependency will be removed. See [#117](https://github.com/nhooyr/websocket/issues/117).
Please ensure CI passes for your changes. If necessary, you may run CI locally.
The various steps are located in `ci/*.sh`.
- `ci/fmt.sh` requires node (specifically prettier).
- `ci/lint.sh` requires [shellcheck](https://github.com/koalaman/shellcheck#installing).
- `ci/test.sh` requires the [Autobahn Test suite pip package](https://github.com/crossbario/autobahn-testsuite).
- `ci/run.sh` runs everything in the above order and requires all of their dependencies.
See [../ci/image/Dockerfile](../ci/image/Dockerfile) for the installation of the CI dependencies on Ubuntu.
For CI coverage, you can use either [codecov](https://codecov.io/gh/nhooyr/websocket) or click the
`coverage.html` artifact on the test step in CI.
For coverage details locally, please see `ci/out/coverage.html` after running `ci/run.sh` or `ci/test.sh`.
<!-- Please read the contributing guidelines. -->
<!-- https://github.com/nhooyr/websocket/blob/master/docs/CONTRIBUTING.md#pull-requests -->
<!-- Please read the contributing guidelines. -->
<!-- https://github.com/nhooyr/websocket/blob/master/docs/CONTRIBUTING.md#issues -->
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
module nhooyr.io/websocket
module github.com/coder/websocket
go 1.12
require (
github.com/golang/protobuf v1.3.1
github.com/google/go-cmp v0.2.0
github.com/kr/pretty v0.1.0 // indirect
go.coder.com/go-tools v0.0.0-20190317003359-0c6a35b74a16
golang.org/x/lint v0.0.0-20190409202823-959b441ac422
golang.org/x/net v0.0.0-20190424112056-4829fb13d2c6
golang.org/x/text v0.3.2 // indirect
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4
golang.org/x/tools v0.0.0-20190429184909-35c670923e21
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522
gotest.tools/gotestsum v0.3.5
mvdan.cc/sh v2.6.4+incompatible
)
replace gotest.tools/gotestsum => github.com/nhooyr/gotestsum v0.3.6-0.20190821172136-aaabbb33254b
go 1.23