good morning!!!!

Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • github/nhooyr/websocket
  • open/websocket
2 results
Show changes
......@@ -2,12 +2,12 @@ package websocket
import (
"context"
"fmt"
"io"
"math"
"net"
"sync/atomic"
"time"
"golang.org/x/xerrors"
)
// NetConn converts a *websocket.Conn into a net.Conn.
......@@ -17,33 +17,75 @@ import (
// correctly and so provided in the library.
// See https://github.com/nhooyr/websocket/issues/100.
//
// Every Write to the net.Conn will correspond to a binary message
// write on *webscoket.Conn.
// Every Write to the net.Conn will correspond to a message write of
// the given type on *websocket.Conn.
//
// The passed ctx bounds the lifetime of the net.Conn. If cancelled,
// all reads and writes on the net.Conn will be cancelled.
//
// If a message is read that is not of the correct type, the connection
// will be closed with StatusUnsupportedData and an error will be returned.
//
// Close will close the *websocket.Conn with StatusNormalClosure.
//
// When a deadline is hit, the connection will be closed. This is
// different from most net.Conn implementations where only the
// reading/writing goroutines are interrupted but the connection is kept alive.
// When a deadline is hit and there is an active read or write goroutine, the
// connection will be closed. This is different from most net.Conn implementations
// where only the reading/writing goroutines are interrupted but the connection
// is kept alive.
//
// The Addr methods will return the real addresses for connections obtained
// from websocket.Accept. But for connections obtained from websocket.Dial, a mock net.Addr
// will be returned that gives "websocket" for Network() and "websocket/unknown-addr" for
// String(). This is because websocket.Dial only exposes a io.ReadWriteCloser instead of the
// full net.Conn to us.
//
// When running as WASM, the Addr methods will always return the mock address described above.
//
// The Addr methods will return a mock net.Addr that returns "websocket" for Network
// and "websocket/unknown-addr" for String.
// A received StatusNormalClosure or StatusGoingAway close frame will be translated to
// io.EOF when reading.
//
// A received StatusNormalClosure close frame will be translated to EOF when reading.
func NetConn(c *Conn) net.Conn {
// Furthermore, the ReadLimit is set to -1 to disable it.
func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn {
c.SetReadLimit(-1)
nc := &netConn{
c: c,
c: c,
msgType: msgType,
readMu: newMu(c),
writeMu: newMu(c),
}
var cancel context.CancelFunc
nc.writeContext, cancel = context.WithCancel(context.Background())
nc.writeTimer = time.AfterFunc(math.MaxInt64, cancel)
nc.writeCtx, nc.writeCancel = context.WithCancel(ctx)
nc.readCtx, nc.readCancel = context.WithCancel(ctx)
nc.writeTimer = time.AfterFunc(math.MaxInt64, func() {
if !nc.writeMu.tryLock() {
// If the lock cannot be acquired, then there is an
// active write goroutine and so we should cancel the context.
nc.writeCancel()
return
}
defer nc.writeMu.unlock()
// Prevents future writes from writing until the deadline is reset.
nc.writeExpired.Store(1)
})
if !nc.writeTimer.Stop() {
<-nc.writeTimer.C
}
nc.readContext, cancel = context.WithCancel(context.Background())
nc.readTimer = time.AfterFunc(math.MaxInt64, cancel)
nc.readTimer = time.AfterFunc(math.MaxInt64, func() {
if !nc.readMu.tryLock() {
// If the lock cannot be acquired, then there is an
// active read goroutine and so we should cancel the context.
nc.readCancel()
return
}
defer nc.readMu.unlock()
// Prevents future reads from reading until the deadline is reset.
nc.readExpired.Store(1)
})
if !nc.readTimer.Stop() {
<-nc.readTimer.C
}
......@@ -52,57 +94,95 @@ func NetConn(c *Conn) net.Conn {
}
type netConn struct {
c *Conn
c *Conn
msgType MessageType
writeTimer *time.Timer
writeContext context.Context
writeMu *mu
writeExpired atomic.Int64
writeCtx context.Context
writeCancel context.CancelFunc
readTimer *time.Timer
readContext context.Context
eofed bool
reader io.Reader
readMu *mu
readExpired atomic.Int64
readCtx context.Context
readCancel context.CancelFunc
readEOFed bool
reader io.Reader
}
var _ net.Conn = &netConn{}
func (c *netConn) Close() error {
return c.c.Close(StatusNormalClosure, "")
func (nc *netConn) Close() error {
nc.writeTimer.Stop()
nc.writeCancel()
nc.readTimer.Stop()
nc.readCancel()
return nc.c.Close(StatusNormalClosure, "")
}
func (c *netConn) Write(p []byte) (int, error) {
err := c.c.Write(c.writeContext, MessageBinary, p)
func (nc *netConn) Write(p []byte) (int, error) {
nc.writeMu.forceLock()
defer nc.writeMu.unlock()
if nc.writeExpired.Load() == 1 {
return 0, fmt.Errorf("failed to write: %w", context.DeadlineExceeded)
}
err := nc.c.Write(nc.writeCtx, nc.msgType, p)
if err != nil {
return 0, err
}
return len(p), nil
}
func (c *netConn) Read(p []byte) (int, error) {
if c.eofed {
func (nc *netConn) Read(p []byte) (int, error) {
nc.readMu.forceLock()
defer nc.readMu.unlock()
for {
n, err := nc.read(p)
if err != nil {
return n, err
}
if n == 0 {
continue
}
return n, nil
}
}
func (nc *netConn) read(p []byte) (int, error) {
if nc.readExpired.Load() == 1 {
return 0, fmt.Errorf("failed to read: %w", context.DeadlineExceeded)
}
if nc.readEOFed {
return 0, io.EOF
}
if c.reader == nil {
typ, r, err := c.c.Reader(c.readContext)
if nc.reader == nil {
typ, r, err := nc.c.Reader(nc.readCtx)
if err != nil {
var ce CloseError
if xerrors.As(err, &ce) && (ce.Code == StatusNormalClosure) {
c.eofed = true
switch CloseStatus(err) {
case StatusNormalClosure, StatusGoingAway:
nc.readEOFed = true
return 0, io.EOF
}
return 0, err
}
if typ != MessageBinary {
c.c.Close(StatusUnsupportedData, "can only accept binary messages")
return 0, xerrors.Errorf("unexpected frame type read for net conn adapter (expected %v): %v", MessageBinary, typ)
if typ != nc.msgType {
err := fmt.Errorf("unexpected frame type read (expected %v): %v", nc.msgType, typ)
nc.c.Close(StatusUnsupportedData, err.Error())
return 0, err
}
c.reader = r
nc.reader = r
}
n, err := c.reader.Read(p)
n, err := nc.reader.Read(p)
if err == io.EOF {
c.reader = nil
nc.reader = nil
err = nil
}
return n, err
......@@ -119,26 +199,36 @@ func (a websocketAddr) String() string {
return "websocket/unknown-addr"
}
func (c *netConn) RemoteAddr() net.Addr {
return websocketAddr{}
}
func (c *netConn) LocalAddr() net.Addr {
return websocketAddr{}
}
func (c *netConn) SetDeadline(t time.Time) error {
c.SetWriteDeadline(t)
c.SetReadDeadline(t)
func (nc *netConn) SetDeadline(t time.Time) error {
nc.SetWriteDeadline(t)
nc.SetReadDeadline(t)
return nil
}
func (c *netConn) SetWriteDeadline(t time.Time) error {
c.writeTimer.Reset(t.Sub(time.Now()))
func (nc *netConn) SetWriteDeadline(t time.Time) error {
nc.writeExpired.Store(0)
if t.IsZero() {
nc.writeTimer.Stop()
} else {
dur := time.Until(t)
if dur <= 0 {
dur = 1
}
nc.writeTimer.Reset(dur)
}
return nil
}
func (c *netConn) SetReadDeadline(t time.Time) error {
c.readTimer.Reset(t.Sub(time.Now()))
func (nc *netConn) SetReadDeadline(t time.Time) error {
nc.readExpired.Store(0)
if t.IsZero() {
nc.readTimer.Stop()
} else {
dur := time.Until(t)
if dur <= 0 {
dur = 1
}
nc.readTimer.Reset(dur)
}
return nil
}
package websocket
import "net"
func (nc *netConn) RemoteAddr() net.Addr {
return websocketAddr{}
}
func (nc *netConn) LocalAddr() net.Addr {
return websocketAddr{}
}
//go:build !js
// +build !js
package websocket
import "net"
func (nc *netConn) RemoteAddr() net.Addr {
if unc, ok := nc.c.rwc.(net.Conn); ok {
return unc.RemoteAddr()
}
return websocketAddr{}
}
func (nc *netConn) LocalAddr() net.Addr {
if unc, ok := nc.c.rwc.(net.Conn); ok {
return unc.LocalAddr()
}
return websocketAddr{}
}
package websocket
// opcode represents a WebSocket Opcode.
type opcode int
//go:generate go run golang.org/x/tools/cmd/stringer -type=opcode
// opcode constants.
const (
opContinuation opcode = iota
opText
opBinary
// 3 - 7 are reserved for further non-control frames.
_
_
_
_
_
opClose
opPing
opPong
// 11-16 are reserved for further control frames.
)
func (o opcode) controlOp() bool {
switch o {
case opClose, opPing, opPong:
return true
}
return false
}
// Code generated by "stringer -type=opcode"; DO NOT EDIT.
package websocket
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[opContinuation-0]
_ = x[opText-1]
_ = x[opBinary-2]
_ = x[opClose-8]
_ = x[opPing-9]
_ = x[opPong-10]
}
const (
_opcode_name_0 = "opContinuationopTextopBinary"
_opcode_name_1 = "opCloseopPingopPong"
)
var (
_opcode_index_0 = [...]uint8{0, 14, 20, 28}
_opcode_index_1 = [...]uint8{0, 7, 13, 19}
)
func (i opcode) String() string {
switch {
case 0 <= i && i <= 2:
return _opcode_name_0[_opcode_index_0[i]:_opcode_index_0[i+1]]
case 8 <= i && i <= 10:
i -= 8
return _opcode_name_1[_opcode_index_1[i]:_opcode_index_1[i+1]]
default:
return "opcode(" + strconv.FormatInt(int64(i), 10) + ")"
}
}
//go:build !js
// +build !js
package websocket
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"net"
"strings"
"sync/atomic"
"time"
"github.com/coder/websocket/internal/errd"
"github.com/coder/websocket/internal/util"
)
// Reader reads from the connection until there is a WebSocket
// data message to be read. It will handle ping, pong and close frames as appropriate.
//
// It returns the type of the message and an io.Reader to read it.
// The passed context will also bound the reader.
// Ensure you read to EOF otherwise the connection will hang.
//
// Call CloseRead if you do not expect any data messages from the peer.
//
// Only one Reader may be open at a time.
//
// If you need a separate timeout on the Reader call and the Read itself,
// use time.AfterFunc to cancel the context passed in.
// See https://github.com/nhooyr/websocket/issues/87#issue-451703332
// Most users should not need this.
func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
return c.reader(ctx)
}
// Read is a convenience method around Reader to read a single message
// from the connection.
func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
typ, r, err := c.Reader(ctx)
if err != nil {
return 0, nil, err
}
b, err := io.ReadAll(r)
return typ, b, err
}
// CloseRead starts a goroutine to read from the connection until it is closed
// or a data message is received.
//
// Once CloseRead is called you cannot read any messages from the connection.
// The returned context will be cancelled when the connection is closed.
//
// If a data message is received, the connection will be closed with StatusPolicyViolation.
//
// Call CloseRead when you do not expect to read any more messages.
// Since it actively reads from the connection, it will ensure that ping, pong and close
// frames are responded to. This means c.Ping and c.Close will still work as expected.
//
// This function is idempotent.
func (c *Conn) CloseRead(ctx context.Context) context.Context {
c.closeReadMu.Lock()
ctx2 := c.closeReadCtx
if ctx2 != nil {
c.closeReadMu.Unlock()
return ctx2
}
ctx, cancel := context.WithCancel(ctx)
c.closeReadCtx = ctx
c.closeReadDone = make(chan struct{})
c.closeReadMu.Unlock()
go func() {
defer close(c.closeReadDone)
defer cancel()
defer c.close()
_, _, err := c.Reader(ctx)
if err == nil {
c.Close(StatusPolicyViolation, "unexpected data message")
}
}()
return ctx
}
// SetReadLimit sets the max number of bytes to read for a single message.
// It applies to the Reader and Read methods.
//
// By default, the connection has a message read limit of 32768 bytes.
//
// When the limit is hit, the connection will be closed with StatusMessageTooBig.
//
// Set to -1 to disable.
func (c *Conn) SetReadLimit(n int64) {
if n >= 0 {
// We read one more byte than the limit in case
// there is a fin frame that needs to be read.
n++
}
c.msgReader.limitReader.limit.Store(n)
}
const defaultReadLimit = 32768
func newMsgReader(c *Conn) *msgReader {
mr := &msgReader{
c: c,
fin: true,
}
mr.readFunc = mr.read
mr.limitReader = newLimitReader(c, mr.readFunc, defaultReadLimit+1)
return mr
}
func (mr *msgReader) resetFlate() {
if mr.flateContextTakeover() {
if mr.dict == nil {
mr.dict = &slidingWindow{}
}
mr.dict.init(32768)
}
if mr.flateBufio == nil {
mr.flateBufio = getBufioReader(mr.readFunc)
}
if mr.flateContextTakeover() {
mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf)
} else {
mr.flateReader = getFlateReader(mr.flateBufio, nil)
}
mr.limitReader.r = mr.flateReader
mr.flateTail.Reset(deflateMessageTail)
}
func (mr *msgReader) putFlateReader() {
if mr.flateReader != nil {
putFlateReader(mr.flateReader)
mr.flateReader = nil
}
}
func (mr *msgReader) close() {
mr.c.readMu.forceLock()
mr.putFlateReader()
if mr.dict != nil {
mr.dict.close()
mr.dict = nil
}
if mr.flateBufio != nil {
putBufioReader(mr.flateBufio)
}
if mr.c.client {
putBufioReader(mr.c.br)
mr.c.br = nil
}
}
func (mr *msgReader) flateContextTakeover() bool {
if mr.c.client {
return !mr.c.copts.serverNoContextTakeover
}
return !mr.c.copts.clientNoContextTakeover
}
func (c *Conn) readRSV1Illegal(h header) bool {
// If compression is disabled, rsv1 is illegal.
if !c.flate() {
return true
}
// rsv1 is only allowed on data frames beginning messages.
if h.opcode != opText && h.opcode != opBinary {
return true
}
return false
}
func (c *Conn) readLoop(ctx context.Context) (header, error) {
for {
h, err := c.readFrameHeader(ctx)
if err != nil {
return header{}, err
}
if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 {
err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)
c.writeError(StatusProtocolError, err)
return header{}, err
}
if !c.client && !h.masked {
return header{}, errors.New("received unmasked frame from client")
}
switch h.opcode {
case opClose, opPing, opPong:
err = c.handleControl(ctx, h)
if err != nil {
// Pass through CloseErrors when receiving a close frame.
if h.opcode == opClose && CloseStatus(err) != -1 {
return header{}, err
}
return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err)
}
case opContinuation, opText, opBinary:
return h, nil
default:
err := fmt.Errorf("received unknown opcode %v", h.opcode)
c.writeError(StatusProtocolError, err)
return header{}, err
}
}
}
// prepareRead sets the readTimeout context and returns a done function
// to be called after the read is done. It also returns an error if the
// connection is closed. The reference to the error is used to assign
// an error depending on if the connection closed or the context timed
// out during use. Typically the referenced error is a named return
// variable of the function calling this method.
func (c *Conn) prepareRead(ctx context.Context, err *error) (func(), error) {
select {
case <-c.closed:
return nil, net.ErrClosed
case c.readTimeout <- ctx:
}
done := func() {
select {
case <-c.closed:
if *err != nil {
*err = net.ErrClosed
}
case c.readTimeout <- context.Background():
}
if *err != nil && ctx.Err() != nil {
*err = ctx.Err()
}
}
c.closeStateMu.Lock()
closeReceivedErr := c.closeReceivedErr
c.closeStateMu.Unlock()
if closeReceivedErr != nil {
defer done()
return nil, closeReceivedErr
}
return done, nil
}
func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) {
readDone, err := c.prepareRead(ctx, &err)
if err != nil {
return header{}, err
}
defer readDone()
h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
if err != nil {
return header{}, err
}
return h, nil
}
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (_ int, err error) {
readDone, err := c.prepareRead(ctx, &err)
if err != nil {
return 0, err
}
defer readDone()
n, err := io.ReadFull(c.br, p)
if err != nil {
return n, fmt.Errorf("failed to read frame payload: %w", err)
}
return n, err
}
func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
if h.payloadLength < 0 || h.payloadLength > maxControlPayload {
err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength)
c.writeError(StatusProtocolError, err)
return err
}
if !h.fin {
err := errors.New("received fragmented control frame")
c.writeError(StatusProtocolError, err)
return err
}
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
b := c.readControlBuf[:h.payloadLength]
_, err = c.readFramePayload(ctx, b)
if err != nil {
return err
}
if h.masked {
mask(b, h.maskKey)
}
switch h.opcode {
case opPing:
if c.onPingReceived != nil {
if !c.onPingReceived(ctx, b) {
return nil
}
}
return c.writeControl(ctx, opPong, b)
case opPong:
if c.onPongReceived != nil {
c.onPongReceived(ctx, b)
}
c.activePingsMu.Lock()
pong, ok := c.activePings[string(b)]
c.activePingsMu.Unlock()
if ok {
select {
case pong <- struct{}{}:
default:
}
}
return nil
}
// opClose
ce, err := parseClosePayload(b)
if err != nil {
err = fmt.Errorf("received invalid close payload: %w", err)
c.writeError(StatusProtocolError, err)
return err
}
err = fmt.Errorf("received close frame: %w", ce)
c.closeStateMu.Lock()
c.closeReceivedErr = err
closeSent := c.closeSentErr != nil
c.closeStateMu.Unlock()
// Only unlock readMu if this connection is being closed becaue
// c.close will try to acquire the readMu lock. We unlock for
// writeClose as well because it may also call c.close.
if !closeSent {
c.readMu.unlock()
_ = c.writeClose(ce.Code, ce.Reason)
}
if !c.casClosing() {
c.readMu.unlock()
_ = c.close()
}
return err
}
func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) {
defer errd.Wrap(&err, "failed to get reader")
err = c.readMu.lock(ctx)
if err != nil {
return 0, nil, err
}
defer c.readMu.unlock()
if !c.msgReader.fin {
return 0, nil, errors.New("previous message not read to completion")
}
h, err := c.readLoop(ctx)
if err != nil {
return 0, nil, err
}
if h.opcode == opContinuation {
err := errors.New("received continuation frame without text or binary frame")
c.writeError(StatusProtocolError, err)
return 0, nil, err
}
c.msgReader.reset(ctx, h)
return MessageType(h.opcode), c.msgReader, nil
}
type msgReader struct {
c *Conn
ctx context.Context
flate bool
flateReader io.Reader
flateBufio *bufio.Reader
flateTail strings.Reader
limitReader *limitReader
dict *slidingWindow
fin bool
payloadLength int64
maskKey uint32
// util.ReaderFunc(mr.Read) to avoid continuous allocations.
readFunc util.ReaderFunc
}
func (mr *msgReader) reset(ctx context.Context, h header) {
mr.ctx = ctx
mr.flate = h.rsv1
mr.limitReader.reset(mr.readFunc)
if mr.flate {
mr.resetFlate()
}
mr.setFrame(h)
}
func (mr *msgReader) setFrame(h header) {
mr.fin = h.fin
mr.payloadLength = h.payloadLength
mr.maskKey = h.maskKey
}
func (mr *msgReader) Read(p []byte) (n int, err error) {
err = mr.c.readMu.lock(mr.ctx)
if err != nil {
return 0, fmt.Errorf("failed to read: %w", err)
}
defer mr.c.readMu.unlock()
n, err = mr.limitReader.Read(p)
if mr.flate && mr.flateContextTakeover() {
p = p[:n]
mr.dict.write(p)
}
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
mr.putFlateReader()
return n, io.EOF
}
if err != nil {
return n, fmt.Errorf("failed to read: %w", err)
}
return n, nil
}
func (mr *msgReader) read(p []byte) (int, error) {
for {
if mr.payloadLength == 0 {
if mr.fin {
if mr.flate {
return mr.flateTail.Read(p)
}
return 0, io.EOF
}
h, err := mr.c.readLoop(mr.ctx)
if err != nil {
return 0, err
}
if h.opcode != opContinuation {
err := errors.New("received new data message without finishing the previous message")
mr.c.writeError(StatusProtocolError, err)
return 0, err
}
mr.setFrame(h)
continue
}
if int64(len(p)) > mr.payloadLength {
p = p[:mr.payloadLength]
}
n, err := mr.c.readFramePayload(mr.ctx, p)
if err != nil {
return n, err
}
mr.payloadLength -= int64(n)
if !mr.c.client {
mr.maskKey = mask(p, mr.maskKey)
}
return n, nil
}
}
type limitReader struct {
c *Conn
r io.Reader
limit atomic.Int64
n int64
}
func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader {
lr := &limitReader{
c: c,
}
lr.limit.Store(limit)
lr.reset(r)
return lr
}
func (lr *limitReader) reset(r io.Reader) {
lr.n = lr.limit.Load()
lr.r = r
}
func (lr *limitReader) Read(p []byte) (int, error) {
if lr.n < 0 {
return lr.r.Read(p)
}
if lr.n == 0 {
err := fmt.Errorf("read limited at %v bytes", lr.limit.Load())
lr.c.writeError(StatusMessageTooBig, err)
return 0, err
}
if int64(len(p)) > lr.n {
p = p[:lr.n]
}
n, err := lr.r.Read(p)
lr.n -= int64(n)
if lr.n < 0 {
lr.n = 0
}
return n, err
}
package websocket
import (
"encoding/binary"
"fmt"
"golang.org/x/xerrors"
)
// StatusCode represents a WebSocket status code.
// https://tools.ietf.org/html/rfc6455#section-7.4
type StatusCode int
//go:generate go run golang.org/x/tools/cmd/stringer -type=StatusCode
// These codes were retrieved from:
// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
const (
StatusNormalClosure StatusCode = 1000 + iota
StatusGoingAway
StatusProtocolError
StatusUnsupportedData
_ // 1004 is reserved.
StatusNoStatusRcvd
// statusAbnormalClosure is unexported because it isn't necessary, at least until WASM.
// The error returned will indicate whether the connection was closed or not or what happened.
// It only makes sense for browser clients.
statusAbnormalClosure
StatusInvalidFramePayloadData
StatusPolicyViolation
StatusMessageTooBig
StatusMandatoryExtension
StatusInternalError
StatusServiceRestart
StatusTryAgainLater
StatusBadGateway
// statusTLSHandshake is unexported because we just return
// handshake error in dial. We do not return a conn
// so there is nothing to use this on. At least until WASM.
statusTLSHandshake
)
// CloseError represents a WebSocket close frame.
// It is returned by Conn's methods when a WebSocket close frame is received from
// the peer.
// You will need to use https://golang.org/x/xerrors to check for this error.
type CloseError struct {
Code StatusCode
Reason string
}
func (ce CloseError) Error() string {
return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason)
}
func parseClosePayload(p []byte) (CloseError, error) {
if len(p) == 0 {
return CloseError{
Code: StatusNoStatusRcvd,
}, nil
}
if len(p) < 2 {
return CloseError{}, xerrors.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p)
}
ce := CloseError{
Code: StatusCode(binary.BigEndian.Uint16(p)),
Reason: string(p[2:]),
}
if !validWireCloseCode(ce.Code) {
return CloseError{}, xerrors.Errorf("invalid status code %v", ce.Code)
}
return ce, nil
}
// See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
// and https://tools.ietf.org/html/rfc6455#section-7.4.1
func validWireCloseCode(code StatusCode) bool {
switch code {
case 1004, StatusNoStatusRcvd, statusAbnormalClosure, statusTLSHandshake:
return false
}
if code >= StatusNormalClosure && code <= StatusBadGateway {
return true
}
if code >= 3000 && code <= 4999 {
return true
}
return false
}
const maxControlFramePayload = 125
func (ce CloseError) bytes() ([]byte, error) {
if len(ce.Reason) > maxControlFramePayload-2 {
return nil, xerrors.Errorf("reason string max is %v but got %q with length %v", maxControlFramePayload-2, ce.Reason, len(ce.Reason))
}
if !validWireCloseCode(ce.Code) {
return nil, xerrors.Errorf("status code %v cannot be set", ce.Code)
}
buf := make([]byte, 2+len(ce.Reason))
binary.BigEndian.PutUint16(buf, uint16(ce.Code))
copy(buf[2:], ce.Reason)
return buf, nil
}
// Code generated by "stringer -type=StatusCode"; DO NOT EDIT.
package websocket
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[StatusNormalClosure-1000]
_ = x[StatusGoingAway-1001]
_ = x[StatusProtocolError-1002]
_ = x[StatusUnsupportedData-1003]
_ = x[StatusNoStatusRcvd-1005]
_ = x[statusAbnormalClosure-1006]
_ = x[StatusInvalidFramePayloadData-1007]
_ = x[StatusPolicyViolation-1008]
_ = x[StatusMessageTooBig-1009]
_ = x[StatusMandatoryExtension-1010]
_ = x[StatusInternalError-1011]
_ = x[StatusServiceRestart-1012]
_ = x[StatusTryAgainLater-1013]
_ = x[StatusBadGateway-1014]
_ = x[statusTLSHandshake-1015]
}
const (
_StatusCode_name_0 = "StatusNormalClosureStatusGoingAwayStatusProtocolErrorStatusUnsupportedData"
_StatusCode_name_1 = "StatusNoStatusRcvdstatusAbnormalClosureStatusInvalidFramePayloadDataStatusPolicyViolationStatusMessageTooBigStatusMandatoryExtensionStatusInternalErrorStatusServiceRestartStatusTryAgainLaterStatusBadGatewaystatusTLSHandshake"
)
var (
_StatusCode_index_0 = [...]uint8{0, 19, 34, 53, 74}
_StatusCode_index_1 = [...]uint8{0, 18, 39, 68, 89, 108, 132, 151, 171, 190, 206, 224}
)
func (i StatusCode) String() string {
switch {
case 1000 <= i && i <= 1003:
i -= 1000
return _StatusCode_name_0[_StatusCode_index_0[i]:_StatusCode_index_0[i+1]]
case 1005 <= i && i <= 1015:
i -= 1005
return _StatusCode_name_1[_StatusCode_index_1[i]:_StatusCode_index_1[i+1]]
default:
return "StatusCode(" + strconv.FormatInt(int64(i), 10) + ")"
}
}
package websocket
import (
"math"
"strings"
"testing"
)
func TestCloseError(t *testing.T) {
t.Parallel()
// Other parts of close error are tested by websocket_test.go right now
// with the autobahn tests.
testCases := []struct {
name string
ce CloseError
success bool
}{
{
name: "normal",
ce: CloseError{
Code: StatusNormalClosure,
Reason: strings.Repeat("x", maxControlFramePayload-2),
},
success: true,
},
{
name: "bigReason",
ce: CloseError{
Code: StatusNormalClosure,
Reason: strings.Repeat("x", maxControlFramePayload-1),
},
success: false,
},
{
name: "bigCode",
ce: CloseError{
Code: math.MaxUint16,
Reason: strings.Repeat("x", maxControlFramePayload-2),
},
success: false,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
_, err := tc.ce.bytes()
if (err == nil) != tc.success {
t.Fatalf("unexpected error value: %v", err)
}
})
}
}
// Code generated by "stringer -type=opcode,MessageType,StatusCode -output=stringer.go"; DO NOT EDIT.
package websocket
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[opContinuation-0]
_ = x[opText-1]
_ = x[opBinary-2]
_ = x[opClose-8]
_ = x[opPing-9]
_ = x[opPong-10]
}
const (
_opcode_name_0 = "opContinuationopTextopBinary"
_opcode_name_1 = "opCloseopPingopPong"
)
var (
_opcode_index_0 = [...]uint8{0, 14, 20, 28}
_opcode_index_1 = [...]uint8{0, 7, 13, 19}
)
func (i opcode) String() string {
switch {
case 0 <= i && i <= 2:
return _opcode_name_0[_opcode_index_0[i]:_opcode_index_0[i+1]]
case 8 <= i && i <= 10:
i -= 8
return _opcode_name_1[_opcode_index_1[i]:_opcode_index_1[i+1]]
default:
return "opcode(" + strconv.FormatInt(int64(i), 10) + ")"
}
}
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[MessageText-1]
_ = x[MessageBinary-2]
}
const _MessageType_name = "MessageTextMessageBinary"
var _MessageType_index = [...]uint8{0, 11, 24}
func (i MessageType) String() string {
i -= 1
if i < 0 || i >= MessageType(len(_MessageType_index)-1) {
return "MessageType(" + strconv.FormatInt(int64(i+1), 10) + ")"
}
return _MessageType_name[_MessageType_index[i]:_MessageType_index[i+1]]
}
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[StatusNormalClosure-1000]
_ = x[StatusGoingAway-1001]
_ = x[StatusProtocolError-1002]
_ = x[StatusUnsupportedData-1003]
_ = x[statusReserved-1004]
_ = x[StatusNoStatusRcvd-1005]
_ = x[StatusAbnormalClosure-1006]
_ = x[StatusInvalidFramePayloadData-1007]
_ = x[StatusPolicyViolation-1008]
_ = x[StatusMessageTooBig-1009]
_ = x[StatusMandatoryExtension-1010]
_ = x[StatusInternalError-1011]
_ = x[StatusServiceRestart-1012]
_ = x[StatusTryAgainLater-1013]
_ = x[StatusBadGateway-1014]
_ = x[StatusTLSHandshake-1015]
}
const _StatusCode_name = "StatusNormalClosureStatusGoingAwayStatusProtocolErrorStatusUnsupportedDatastatusReservedStatusNoStatusRcvdStatusAbnormalClosureStatusInvalidFramePayloadDataStatusPolicyViolationStatusMessageTooBigStatusMandatoryExtensionStatusInternalErrorStatusServiceRestartStatusTryAgainLaterStatusBadGatewayStatusTLSHandshake"
var _StatusCode_index = [...]uint16{0, 19, 34, 53, 74, 88, 106, 127, 156, 177, 196, 220, 239, 259, 278, 294, 312}
func (i StatusCode) String() string {
i -= 1000
if i < 0 || i >= StatusCode(len(_StatusCode_index)-1) {
return "StatusCode(" + strconv.FormatInt(int64(i+1000), 10) + ")"
}
return _StatusCode_name[_StatusCode_index[i]:_StatusCode_index[i+1]]
}
// +build tools
package tools
// See https://github.com/go-modules-by-example/index/blob/master/010_tools/README.md
import (
_ "go.coder.com/go-tools/cmd/goimports"
_ "golang.org/x/lint/golint"
_ "golang.org/x/tools/cmd/stringer"
_ "mvdan.cc/sh/cmd/shfmt"
)
package websocket
import (
"bufio"
"context"
cryptorand "crypto/rand"
"fmt"
"io"
"io/ioutil"
"math/rand"
"os"
"runtime"
"strconv"
"sync"
"sync/atomic"
"time"
"golang.org/x/xerrors"
)
// Conn represents a WebSocket connection.
// All methods may be called concurrently except for Reader, Read
// and SetReadLimit.
//
// You must always read from the connection. Otherwise control
// frames will not be handled. See the docs on Reader and CloseRead.
//
// Please be sure to call Close on the connection when you
// are finished with it to release the associated resources.
//
// Every error from Read or Reader will cause the connection
// to be closed so you do not need to write your own error message.
// This applies to the Read methods in the wsjson/wspb subpackages as well.
type Conn struct {
subprotocol string
br *bufio.Reader
bw *bufio.Writer
// writeBuf is used for masking, its the buffer in bufio.Writer.
// Only used by the client for masking the bytes in the buffer.
writeBuf []byte
closer io.Closer
client bool
closeOnce sync.Once
closeErr error
closed chan struct{}
// writeMsgLock is acquired to write a data message.
writeMsgLock chan struct{}
// writeFrameLock is acquired to write a single frame.
// Effectively meaning whoever holds it gets to write to bw.
writeFrameLock chan struct{}
writeHeaderBuf []byte
writeHeader *header
// read limit for a message in bytes.
msgReadLimit int64
// messageWriter state.
writeMsgOpcode opcode
writeMsgCtx context.Context
readMsgLeft int64
// Used to ensure the previous reader is read till EOF before allowing
// a new one.
previousReader *messageReader
// readFrameLock is acquired to read from bw.
readFrameLock chan struct{}
readClosed int64
readHeaderBuf []byte
controlPayloadBuf []byte
// messageReader state
readMsgCtx context.Context
readMsgHeader header
readFrameEOF bool
readMaskPos int
setReadTimeout chan context.Context
setWriteTimeout chan context.Context
activePingsMu sync.Mutex
activePings map[string]chan<- struct{}
}
func (c *Conn) init() {
c.closed = make(chan struct{})
c.msgReadLimit = 32768
c.writeMsgLock = make(chan struct{}, 1)
c.writeFrameLock = make(chan struct{}, 1)
c.readFrameLock = make(chan struct{}, 1)
c.setReadTimeout = make(chan context.Context)
c.setWriteTimeout = make(chan context.Context)
c.activePings = make(map[string]chan<- struct{})
c.writeHeaderBuf = makeWriteHeaderBuf()
c.writeHeader = &header{}
c.readHeaderBuf = makeReadHeaderBuf()
c.controlPayloadBuf = make([]byte, maxControlFramePayload)
runtime.SetFinalizer(c, func(c *Conn) {
c.close(xerrors.New("connection garbage collected"))
})
go c.timeoutLoop()
}
// Subprotocol returns the negotiated subprotocol.
// An empty string means the default protocol.
func (c *Conn) Subprotocol() string {
return c.subprotocol
}
func (c *Conn) close(err error) {
c.closeOnce.Do(func() {
runtime.SetFinalizer(c, nil)
c.closeErr = xerrors.Errorf("websocket closed: %w", err)
close(c.closed)
// Have to close after c.closed is closed to ensure any goroutine that wakes up
// from the connection being closed also sees that c.closed is closed and returns
// closeErr.
c.closer.Close()
// See comment in dial.go
if c.client {
// By acquiring the locks, we ensure no goroutine will touch the bufio reader or writer
// and we can safely return them.
// Whenever a caller holds this lock and calls close, it ensures to release the lock to prevent
// a deadlock.
// As of now, this is in writeFrame, readFramePayload and readHeader.
c.readFrameLock <- struct{}{}
returnBufioReader(c.br)
c.writeFrameLock <- struct{}{}
returnBufioWriter(c.bw)
}
})
}
func (c *Conn) timeoutLoop() {
readCtx := context.Background()
writeCtx := context.Background()
for {
select {
case <-c.closed:
return
case writeCtx = <-c.setWriteTimeout:
case readCtx = <-c.setReadTimeout:
case <-readCtx.Done():
c.close(xerrors.Errorf("data read timed out: %w", readCtx.Err()))
case <-writeCtx.Done():
c.close(xerrors.Errorf("data write timed out: %w", writeCtx.Err()))
}
}
}
func (c *Conn) acquireLock(ctx context.Context, lock chan struct{}) error {
select {
case <-ctx.Done():
var err error
switch lock {
case c.writeFrameLock, c.writeMsgLock:
err = xerrors.Errorf("could not acquire write lock: %v", ctx.Err())
case c.readFrameLock:
err = xerrors.Errorf("could not acquire read lock: %v", ctx.Err())
default:
panic(fmt.Sprintf("websocket: failed to acquire unknown lock: %v", ctx.Err()))
}
c.close(err)
return ctx.Err()
case <-c.closed:
return c.closeErr
case lock <- struct{}{}:
return nil
}
}
func (c *Conn) releaseLock(lock chan struct{}) {
// Allow multiple releases.
select {
case <-lock:
default:
}
}
func (c *Conn) readTillMsg(ctx context.Context) (header, error) {
for {
h, err := c.readFrameHeader(ctx)
if err != nil {
return header{}, err
}
if h.rsv1 || h.rsv2 || h.rsv3 {
err := xerrors.Errorf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)
c.Close(StatusProtocolError, err.Error())
return header{}, err
}
if h.opcode.controlOp() {
err = c.handleControl(ctx, h)
if err != nil {
return header{}, xerrors.Errorf("failed to handle control frame: %w", err)
}
continue
}
switch h.opcode {
case opBinary, opText, opContinuation:
return h, nil
default:
err := xerrors.Errorf("received unknown opcode %v", h.opcode)
c.Close(StatusProtocolError, err.Error())
return header{}, err
}
}
}
func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
err := c.acquireLock(context.Background(), c.readFrameLock)
if err != nil {
return header{}, err
}
defer c.releaseLock(c.readFrameLock)
select {
case <-c.closed:
return header{}, c.closeErr
case c.setReadTimeout <- ctx:
}
h, err := readHeader(c.readHeaderBuf, c.br)
if err != nil {
select {
case <-c.closed:
return header{}, c.closeErr
case <-ctx.Done():
err = ctx.Err()
default:
}
err := xerrors.Errorf("failed to read header: %w", err)
c.releaseLock(c.readFrameLock)
c.close(err)
return header{}, err
}
select {
case <-c.closed:
return header{}, c.closeErr
case c.setReadTimeout <- context.Background():
}
return h, nil
}
func (c *Conn) handleControl(ctx context.Context, h header) error {
if h.payloadLength > maxControlFramePayload {
err := xerrors.Errorf("control frame too large at %v bytes", h.payloadLength)
c.Close(StatusProtocolError, err.Error())
return err
}
if !h.fin {
err := xerrors.Errorf("received fragmented control frame")
c.Close(StatusProtocolError, err.Error())
return err
}
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
b := c.controlPayloadBuf[:h.payloadLength]
_, err := c.readFramePayload(ctx, b)
if err != nil {
return err
}
if h.masked {
fastXOR(h.maskKey, 0, b)
}
switch h.opcode {
case opPing:
return c.writePong(b)
case opPong:
c.activePingsMu.Lock()
pong, ok := c.activePings[string(b)]
c.activePingsMu.Unlock()
if ok {
close(pong)
}
return nil
case opClose:
ce, err := parseClosePayload(b)
if err != nil {
c.Close(StatusProtocolError, "received invalid close payload")
return xerrors.Errorf("received invalid close payload: %w", err)
}
c.writeClose(b, xerrors.Errorf("received close frame: %w", ce))
return c.closeErr
default:
panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h))
}
}
// Reader waits until there is a WebSocket data message to read
// from the connection.
// It returns the type of the message and a reader to read it.
// The passed context will also bound the reader.
// Ensure you read to EOF otherwise the connection will hang.
//
// All returned errors will cause the connection
// to be closed so you do not need to write your own error message.
// This applies to the Read methods in the wsjson/wspb subpackages as well.
//
// You must read from the connection for control frames to be handled.
// If you do not expect any data messages from the peer, call CloseRead.
//
// Only one Reader may be open at a time.
//
// If you need a separate timeout on the Reader call and then the message
// Read, use time.AfterFunc to cancel the context passed in early.
// See https://github.com/nhooyr/websocket/issues/87#issue-451703332
// Most users should not need this.
func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
if atomic.LoadInt64(&c.readClosed) == 1 {
return 0, nil, xerrors.Errorf("websocket connection read closed")
}
typ, r, err := c.reader(ctx)
if err != nil {
return 0, nil, xerrors.Errorf("failed to get reader: %w", err)
}
return typ, r, nil
}
func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
if c.previousReader != nil && !c.readFrameEOF {
// The only way we know for sure the previous reader is not yet complete is
// if there is an active frame not yet fully read.
// Otherwise, a user may have read the last byte but not the EOF if the EOF
// is in the next frame so we check for that below.
return 0, nil, xerrors.Errorf("previous message not read to completion")
}
h, err := c.readTillMsg(ctx)
if err != nil {
return 0, nil, err
}
if c.previousReader != nil && !c.previousReader.eof {
if h.opcode != opContinuation {
err := xerrors.Errorf("received new data message without finishing the previous message")
c.Close(StatusProtocolError, err.Error())
return 0, nil, err
}
if !h.fin || h.payloadLength > 0 {
return 0, nil, xerrors.Errorf("previous message not read to completion")
}
c.previousReader.eof = true
h, err = c.readTillMsg(ctx)
if err != nil {
return 0, nil, err
}
} else if h.opcode == opContinuation {
err := xerrors.Errorf("received continuation frame not after data or text frame")
c.Close(StatusProtocolError, err.Error())
return 0, nil, err
}
c.readMsgCtx = ctx
c.readMsgHeader = h
c.readFrameEOF = false
c.readMaskPos = 0
c.readMsgLeft = c.msgReadLimit
r := &messageReader{
c: c,
}
c.previousReader = r
return MessageType(h.opcode), r, nil
}
// CloseRead will start a goroutine to read from the connection until it is closed or a data message
// is received. If a data message is received, the connection will be closed with StatusPolicyViolation.
// Since CloseRead reads from the connection, it will respond to ping, pong and close frames.
// After calling this method, you cannot read any data messages from the connection.
// The returned context will be cancelled when the connection is closed.
//
// Use this when you do not want to read data messages from the connection anymore but will
// want to write messages to it.
func (c *Conn) CloseRead(ctx context.Context) context.Context {
atomic.StoreInt64(&c.readClosed, 1)
ctx, cancel := context.WithCancel(ctx)
go func() {
defer cancel()
// We use the unexported reader so that we don't get the read closed error.
c.reader(ctx)
c.Close(StatusPolicyViolation, "unexpected data message")
}()
return ctx
}
// messageReader enables reading a data frame from the WebSocket connection.
type messageReader struct {
c *Conn
eof bool
}
// Read reads as many bytes as possible into p.
func (r *messageReader) Read(p []byte) (int, error) {
n, err := r.read(p)
if err != nil {
// Have to return io.EOF directly for now, we cannot wrap as xerrors
// isn't used in stdlib.
if xerrors.Is(err, io.EOF) {
return n, io.EOF
}
return n, xerrors.Errorf("failed to read: %w", err)
}
return n, nil
}
func (r *messageReader) read(p []byte) (int, error) {
if r.eof {
return 0, xerrors.Errorf("cannot use EOFed reader")
}
if r.c.readMsgLeft <= 0 {
err := xerrors.Errorf("read limited at %v bytes", r.c.msgReadLimit)
r.c.Close(StatusMessageTooBig, err.Error())
return 0, err
}
if int64(len(p)) > r.c.readMsgLeft {
p = p[:r.c.readMsgLeft]
}
if r.c.readFrameEOF {
h, err := r.c.readTillMsg(r.c.readMsgCtx)
if err != nil {
return 0, err
}
if h.opcode != opContinuation {
err := xerrors.Errorf("received new data message without finishing the previous message")
r.c.Close(StatusProtocolError, err.Error())
return 0, err
}
r.c.readMsgHeader = h
r.c.readFrameEOF = false
r.c.readMaskPos = 0
}
h := r.c.readMsgHeader
if int64(len(p)) > h.payloadLength {
p = p[:h.payloadLength]
}
n, err := r.c.readFramePayload(r.c.readMsgCtx, p)
h.payloadLength -= int64(n)
r.c.readMsgLeft -= int64(n)
if h.masked {
r.c.readMaskPos = fastXOR(h.maskKey, r.c.readMaskPos, p)
}
r.c.readMsgHeader = h
if err != nil {
return n, err
}
if h.payloadLength == 0 {
r.c.readFrameEOF = true
if h.fin {
r.eof = true
return n, io.EOF
}
}
return n, nil
}
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
err := c.acquireLock(ctx, c.readFrameLock)
if err != nil {
return 0, err
}
defer c.releaseLock(c.readFrameLock)
select {
case <-c.closed:
return 0, c.closeErr
case c.setReadTimeout <- ctx:
}
n, err := io.ReadFull(c.br, p)
if err != nil {
select {
case <-c.closed:
return n, c.closeErr
case <-ctx.Done():
err = ctx.Err()
default:
}
err = xerrors.Errorf("failed to read frame payload: %w", err)
c.releaseLock(c.readFrameLock)
c.close(err)
return n, err
}
select {
case <-c.closed:
return n, c.closeErr
case c.setReadTimeout <- context.Background():
}
return n, err
}
// SetReadLimit sets the max number of bytes to read for a single message.
// It applies to the Reader and Read methods.
//
// By default, the connection has a message read limit of 32768 bytes.
//
// When the limit is hit, the connection will be closed with StatusMessageTooBig.
func (c *Conn) SetReadLimit(n int64) {
c.msgReadLimit = n
}
// Read is a convenience method to read a single message from the connection.
//
// See the Reader method if you want to be able to reuse buffers or want to stream a message.
// The docs on Reader apply to this method as well.
func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
typ, r, err := c.Reader(ctx)
if err != nil {
return 0, nil, err
}
b, err := ioutil.ReadAll(r)
return typ, b, err
}
// Writer returns a writer bounded by the context that will write
// a WebSocket message of type dataType to the connection.
//
// You must close the writer once you have written the entire message.
//
// Only one writer can be open at a time, multiple calls will block until the previous writer
// is closed.
func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
wc, err := c.writer(ctx, typ)
if err != nil {
return nil, xerrors.Errorf("failed to get writer: %w", err)
}
return wc, nil
}
func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
err := c.acquireLock(ctx, c.writeMsgLock)
if err != nil {
return nil, err
}
c.writeMsgCtx = ctx
c.writeMsgOpcode = opcode(typ)
return &messageWriter{
c: c,
}, nil
}
// Write is a convenience method to write a message to the connection.
//
// See the Writer method if you want to stream a message. The docs on Writer
// regarding concurrency also apply to this method.
func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
_, err := c.write(ctx, typ, p)
if err != nil {
return xerrors.Errorf("failed to write msg: %w", err)
}
return nil
}
func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
err := c.acquireLock(ctx, c.writeMsgLock)
if err != nil {
return 0, err
}
defer c.releaseLock(c.writeMsgLock)
n, err := c.writeFrame(ctx, true, opcode(typ), p)
return n, err
}
// messageWriter enables writing to a WebSocket connection.
type messageWriter struct {
c *Conn
closed bool
}
// Write writes the given bytes to the WebSocket connection.
func (w *messageWriter) Write(p []byte) (int, error) {
n, err := w.write(p)
if err != nil {
return n, xerrors.Errorf("failed to write: %w", err)
}
return n, nil
}
func (w *messageWriter) write(p []byte) (int, error) {
if w.closed {
return 0, xerrors.Errorf("cannot use closed writer")
}
n, err := w.c.writeFrame(w.c.writeMsgCtx, false, w.c.writeMsgOpcode, p)
if err != nil {
return n, xerrors.Errorf("failed to write data frame: %w", err)
}
w.c.writeMsgOpcode = opContinuation
return n, nil
}
// Close flushes the frame to the connection.
// This must be called for every messageWriter.
func (w *messageWriter) Close() error {
err := w.close()
if err != nil {
return xerrors.Errorf("failed to close writer: %w", err)
}
return nil
}
func (w *messageWriter) close() error {
if w.closed {
return xerrors.Errorf("cannot use closed writer")
}
w.closed = true
_, err := w.c.writeFrame(w.c.writeMsgCtx, true, w.c.writeMsgOpcode, nil)
if err != nil {
return xerrors.Errorf("failed to write fin frame: %w", err)
}
w.c.releaseLock(w.c.writeMsgLock)
return nil
}
func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
_, err := c.writeFrame(ctx, true, opcode, p)
if err != nil {
return xerrors.Errorf("failed to write control frame: %w", err)
}
return nil
}
// writeFrame handles all writes to the connection.
func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) {
err := c.acquireLock(ctx, c.writeFrameLock)
if err != nil {
return 0, err
}
defer c.releaseLock(c.writeFrameLock)
select {
case <-c.closed:
return 0, c.closeErr
case c.setWriteTimeout <- ctx:
}
c.writeHeader.fin = fin
c.writeHeader.opcode = opcode
c.writeHeader.masked = c.client
c.writeHeader.payloadLength = int64(len(p))
if c.client {
_, err := io.ReadFull(cryptorand.Reader, c.writeHeader.maskKey[:])
if err != nil {
return 0, xerrors.Errorf("failed to generate masking key: %w", err)
}
}
n, err := c.realWriteFrame(ctx, *c.writeHeader, p)
if err != nil {
return n, err
}
// We already finished writing, no need to potentially brick the connection if
// the context expires.
select {
case <-c.closed:
return n, c.closeErr
case c.setWriteTimeout <- context.Background():
}
return n, nil
}
func (c *Conn) realWriteFrame(ctx context.Context, h header, p []byte) (n int, err error) {
defer func() {
if err != nil {
select {
case <-c.closed:
err = c.closeErr
case <-ctx.Done():
err = ctx.Err()
default:
}
err = xerrors.Errorf("failed to write %v frame: %w", h.opcode, err)
// We need to release the lock first before closing the connection to ensure
// the lock can be acquired inside close to ensure no one can access c.bw.
c.releaseLock(c.writeFrameLock)
c.close(err)
}
}()
headerBytes := writeHeader(c.writeHeaderBuf, h)
_, err = c.bw.Write(headerBytes)
if err != nil {
return 0, err
}
if c.client {
var keypos int
for len(p) > 0 {
if c.bw.Available() == 0 {
err = c.bw.Flush()
if err != nil {
return n, err
}
}
// Start of next write in the buffer.
i := c.bw.Buffered()
p2 := p
if len(p) > c.bw.Available() {
p2 = p[:c.bw.Available()]
}
n2, err := c.bw.Write(p2)
if err != nil {
return n, err
}
keypos = fastXOR(h.maskKey, keypos, c.writeBuf[i:i+n2])
p = p[n2:]
n += n2
}
} else {
n, err = c.bw.Write(p)
if err != nil {
return n, err
}
}
if h.fin {
err = c.bw.Flush()
if err != nil {
return n, err
}
}
return n, nil
}
func (c *Conn) writePong(p []byte) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
err := c.writeControl(ctx, opPong, p)
return err
}
// Close closes the WebSocket connection with the given status code and reason.
//
// It will write a WebSocket close frame with a timeout of 5 seconds.
// The connection can only be closed once. Additional calls to Close
// are no-ops.
//
// This does not perform a WebSocket close handshake.
// See https://github.com/nhooyr/websocket/issues/103 for details on why.
//
// The maximum length of reason must be 125 bytes otherwise an internal
// error will be sent to the peer. For this reason, you should avoid
// sending a dynamic reason.
//
// Close will unblock all goroutines interacting with the connection.
func (c *Conn) Close(code StatusCode, reason string) error {
err := c.exportedClose(code, reason)
if err != nil {
return xerrors.Errorf("failed to close connection: %w", err)
}
return nil
}
func (c *Conn) exportedClose(code StatusCode, reason string) error {
ce := CloseError{
Code: code,
Reason: reason,
}
// This function also will not wait for a close frame from the peer like the RFC
// wants because that makes no sense and I don't think anyone actually follows that.
// Definitely worth seeing what popular browsers do later.
p, err := ce.bytes()
if err != nil {
fmt.Fprintf(os.Stderr, "websocket: failed to marshal close frame: %v\n", err)
ce = CloseError{
Code: StatusInternalError,
}
p, _ = ce.bytes()
}
// CloseErrors sent are made opaque to prevent applications from thinking
// they received a given status.
err = c.writeClose(p, xerrors.Errorf("sent close frame: %v", ce))
if err != nil {
return err
}
if !xerrors.Is(c.closeErr, ce) {
return c.closeErr
}
return nil
}
func (c *Conn) writeClose(p []byte, cerr error) error {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
// If this fails, the connection had to have died.
err := c.writeControl(ctx, opClose, p)
if err != nil {
return err
}
c.close(cerr)
return nil
}
func init() {
rand.Seed(time.Now().UnixNano())
}
// Ping sends a ping to the peer and waits for a pong.
// Use this to measure latency or ensure the peer is responsive.
// Ping must be called concurrently with Reader as otherwise it does
// not read from the connection and relies on Reader to unblock
// when the pong arrives.
//
// TCP Keepalives should suffice for most use cases.
func (c *Conn) Ping(ctx context.Context) error {
err := c.ping(ctx)
if err != nil {
return xerrors.Errorf("failed to ping: %w", err)
}
return nil
}
func (c *Conn) ping(ctx context.Context) error {
id := rand.Uint64()
p := strconv.FormatUint(id, 10)
pong := make(chan struct{})
c.activePingsMu.Lock()
c.activePings[p] = pong
c.activePingsMu.Unlock()
defer func() {
c.activePingsMu.Lock()
delete(c.activePings, p)
c.activePingsMu.Unlock()
}()
err := c.writeControl(ctx, opPing, []byte(p))
if err != nil {
return err
}
select {
case <-c.closed:
return c.closeErr
case <-ctx.Done():
err := xerrors.Errorf("failed to wait for pong: %w", ctx.Err())
c.close(err)
return err
case <-pong:
return nil
}
}
type writerFunc func(p []byte) (int, error)
func (f writerFunc) Write(p []byte) (int, error) {
return f(p)
}
// extractBufioWriterBuf grabs the []byte backing a *bufio.Writer
// and stores it in c.writeBuf.
func (c *Conn) extractBufioWriterBuf(w io.Writer) {
c.bw.Reset(writerFunc(func(p2 []byte) (int, error) {
c.writeBuf = p2[:cap(p2)]
return len(p2), nil
}))
c.bw.WriteByte(0)
c.bw.Flush()
c.bw.Reset(w)
}
package websocket_test
import (
"context"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/cookiejar"
"net/http/httptest"
"net/url"
"os"
"os/exec"
"reflect"
"strconv"
"strings"
"sync/atomic"
"testing"
"time"
"github.com/golang/protobuf/ptypes"
"github.com/golang/protobuf/ptypes/duration"
"github.com/google/go-cmp/cmp"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
"nhooyr.io/websocket/wsjson"
"nhooyr.io/websocket/wspb"
)
func TestHandshake(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
client func(ctx context.Context, url string) error
server func(w http.ResponseWriter, r *http.Request) error
}{
{
name: "handshake",
server: func(w http.ResponseWriter, r *http.Request) error {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{
Subprotocols: []string{"myproto"},
})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
return nil
},
client: func(ctx context.Context, u string) error {
c, resp, err := websocket.Dial(ctx, u, websocket.DialOptions{
Subprotocols: []string{"myproto"},
})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
checkHeader := func(h, exp string) {
t.Helper()
value := resp.Header.Get(h)
if exp != value {
t.Errorf("expected different value for header %v: %v", h, cmp.Diff(exp, value))
}
}
checkHeader("Connection", "Upgrade")
checkHeader("Upgrade", "websocket")
checkHeader("Sec-WebSocket-Protocol", "myproto")
c.Close(websocket.StatusNormalClosure, "")
return nil
},
},
{
name: "closeError",
server: func(w http.ResponseWriter, r *http.Request) error {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
err = wsjson.Write(r.Context(), c, "hello")
if err != nil {
return err
}
return nil
},
client: func(ctx context.Context, u string) error {
c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
Subprotocols: []string{"meow"},
})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
var m string
err = wsjson.Read(ctx, c, &m)
if err != nil {
return err
}
if m != "hello" {
return xerrors.Errorf("recieved unexpected msg but expected hello: %+v", m)
}
_, _, err = c.Reader(ctx)
var cerr websocket.CloseError
if !xerrors.As(err, &cerr) || cerr.Code != websocket.StatusInternalError {
return xerrors.Errorf("unexpected error: %+v", err)
}
return nil
},
},
{
name: "netConn",
server: func(w http.ResponseWriter, r *http.Request) error {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
nc := websocket.NetConn(c)
defer nc.Close()
nc.SetWriteDeadline(time.Now().Add(time.Second * 15))
for i := 0; i < 3; i++ {
_, err = nc.Write([]byte("hello"))
if err != nil {
return err
}
}
return nil
},
client: func(ctx context.Context, u string) error {
c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
Subprotocols: []string{"meow"},
})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
nc := websocket.NetConn(c)
defer nc.Close()
nc.SetReadDeadline(time.Now().Add(time.Second * 15))
read := func() error {
p := make([]byte, len("hello"))
// We do not use io.ReadFull here as it masks EOFs.
// See https://github.com/nhooyr/websocket/issues/100#issuecomment-508148024
_, err = nc.Read(p)
if err != nil {
return err
}
if string(p) != "hello" {
return xerrors.Errorf("unexpected payload %q received", string(p))
}
return nil
}
for i := 0; i < 3; i++ {
err = read()
if err != nil {
return err
}
}
// Ensure the close frame is converted to an EOF and multiple read's after all return EOF.
err = read()
if err != io.EOF {
return err
}
err = read()
if err != io.EOF {
return err
}
return nil
},
},
{
name: "defaultSubprotocol",
server: func(w http.ResponseWriter, r *http.Request) error {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
if c.Subprotocol() != "" {
return xerrors.Errorf("unexpected subprotocol: %v", c.Subprotocol())
}
return nil
},
client: func(ctx context.Context, u string) error {
c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
Subprotocols: []string{"meow"},
})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
if c.Subprotocol() != "" {
return xerrors.Errorf("unexpected subprotocol: %v", c.Subprotocol())
}
return nil
},
},
{
name: "subprotocol",
server: func(w http.ResponseWriter, r *http.Request) error {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{
Subprotocols: []string{"echo", "lar"},
})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
if c.Subprotocol() != "echo" {
return xerrors.Errorf("unexpected subprotocol: %q", c.Subprotocol())
}
return nil
},
client: func(ctx context.Context, u string) error {
c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
Subprotocols: []string{"poof", "echo"},
})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
if c.Subprotocol() != "echo" {
return xerrors.Errorf("unexpected subprotocol: %q", c.Subprotocol())
}
return nil
},
},
{
name: "badOrigin",
server: func(w http.ResponseWriter, r *http.Request) error {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
if err == nil {
c.Close(websocket.StatusInternalError, "")
return xerrors.New("expected error regarding bad origin")
}
return nil
},
client: func(ctx context.Context, u string) error {
h := http.Header{}
h.Set("Origin", "http://unauthorized.com")
c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
HTTPHeader: h,
})
if err == nil {
c.Close(websocket.StatusInternalError, "")
return xerrors.New("expected handshake failure")
}
return nil
},
},
{
name: "acceptSecureOrigin",
server: func(w http.ResponseWriter, r *http.Request) error {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
return nil
},
client: func(ctx context.Context, u string) error {
h := http.Header{}
h.Set("Origin", u)
c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
HTTPHeader: h,
})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
return nil
},
},
{
name: "acceptInsecureOrigin",
server: func(w http.ResponseWriter, r *http.Request) error {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{
InsecureSkipVerify: true,
})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
return nil
},
client: func(ctx context.Context, u string) error {
h := http.Header{}
h.Set("Origin", "https://example.com")
c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
HTTPHeader: h,
})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
return nil
},
},
{
name: "jsonEcho",
server: func(w http.ResponseWriter, r *http.Request) error {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
ctx, cancel := context.WithTimeout(r.Context(), time.Second*5)
defer cancel()
write := func() error {
v := map[string]interface{}{
"anmol": "wowow",
}
err := wsjson.Write(ctx, c, v)
return err
}
err = write()
if err != nil {
return err
}
err = write()
if err != nil {
return err
}
c.Close(websocket.StatusNormalClosure, "")
return nil
},
client: func(ctx context.Context, u string) error {
c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
read := func() error {
var v interface{}
err := wsjson.Read(ctx, c, &v)
if err != nil {
return err
}
exp := map[string]interface{}{
"anmol": "wowow",
}
if !reflect.DeepEqual(exp, v) {
return xerrors.Errorf("expected %v but got %v", exp, v)
}
return nil
}
err = read()
if err != nil {
return err
}
err = read()
if err != nil {
return err
}
c.Close(websocket.StatusNormalClosure, "")
return nil
},
},
{
name: "protobufEcho",
server: func(w http.ResponseWriter, r *http.Request) error {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
ctx, cancel := context.WithTimeout(r.Context(), time.Second*5)
defer cancel()
write := func() error {
err := wspb.Write(ctx, c, ptypes.DurationProto(100))
return err
}
err = write()
if err != nil {
return err
}
c.Close(websocket.StatusNormalClosure, "")
return nil
},
client: func(ctx context.Context, u string) error {
c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
read := func() error {
var v duration.Duration
err := wspb.Read(ctx, c, &v)
if err != nil {
return err
}
d, err := ptypes.Duration(&v)
if err != nil {
return xerrors.Errorf("failed to convert duration.Duration to time.Duration: %w", err)
}
const exp = time.Duration(100)
if !reflect.DeepEqual(exp, d) {
return xerrors.Errorf("expected %v but got %v", exp, d)
}
return nil
}
err = read()
if err != nil {
return err
}
c.Close(websocket.StatusNormalClosure, "")
return nil
},
},
{
name: "cookies",
server: func(w http.ResponseWriter, r *http.Request) error {
cookie, err := r.Cookie("mycookie")
if err != nil {
return xerrors.Errorf("request is missing mycookie: %w", err)
}
if cookie.Value != "myvalue" {
return xerrors.Errorf("expected %q but got %q", "myvalue", cookie.Value)
}
c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
if err != nil {
return err
}
c.Close(websocket.StatusInternalError, "")
return nil
},
client: func(ctx context.Context, u string) error {
jar, err := cookiejar.New(nil)
if err != nil {
return xerrors.Errorf("failed to create cookie jar: %w", err)
}
parsedURL, err := url.Parse(u)
if err != nil {
return xerrors.Errorf("failed to parse url: %w", err)
}
parsedURL.Scheme = "http"
jar.SetCookies(parsedURL, []*http.Cookie{
{
Name: "mycookie",
Value: "myvalue",
},
})
hc := &http.Client{
Jar: jar,
}
c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{
HTTPClient: hc,
})
if err != nil {
return err
}
c.Close(websocket.StatusInternalError, "")
return nil
},
},
{
name: "ping",
server: func(w http.ResponseWriter, r *http.Request) error {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
errc := make(chan error, 1)
go func() {
_, _, err2 := c.Read(r.Context())
errc <- err2
}()
err = c.Ping(r.Context())
if err != nil {
return err
}
err = c.Write(r.Context(), websocket.MessageText, []byte("hi"))
if err != nil {
return err
}
err = <-errc
var ce websocket.CloseError
if xerrors.As(err, &ce) && ce.Code == websocket.StatusNormalClosure {
return nil
}
return xerrors.Errorf("unexpected error: %w", err)
},
client: func(ctx context.Context, u string) error {
c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
// We read a message from the connection and then keep reading until
// the Ping completes.
done := make(chan struct{})
go func() {
_, _, err := c.Read(ctx)
if err != nil {
c.Close(websocket.StatusInternalError, err.Error())
return
}
close(done)
c.Read(ctx)
}()
err = c.Ping(ctx)
if err != nil {
return err
}
<-done
c.Close(websocket.StatusNormalClosure, "")
return nil
},
},
{
name: "readLimit",
server: func(w http.ResponseWriter, r *http.Request) error {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
_, _, err = c.Read(r.Context())
if err == nil {
return xerrors.Errorf("expected error but got nil")
}
return nil
},
client: func(ctx context.Context, u string) error {
c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{})
if err != nil {
return err
}
defer c.Close(websocket.StatusInternalError, "")
go c.Reader(ctx)
err = c.Write(ctx, websocket.MessageBinary, []byte(strings.Repeat("x", 32769)))
if err != nil {
return err
}
err = c.Ping(ctx)
var ce websocket.CloseError
if !xerrors.As(err, &ce) || ce.Code != websocket.StatusMessageTooBig {
return xerrors.Errorf("unexpected error: %w", err)
}
return nil
},
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) {
err := tc.server(w, r)
if err != nil {
t.Errorf("server failed: %+v", err)
return
}
})
defer closeFn()
wsURL := strings.Replace(s.URL, "http", "ws", 1)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
err := tc.client(ctx, wsURL)
if err != nil {
t.Fatalf("client failed: %+v", err)
}
})
}
}
func testServer(tb testing.TB, fn http.HandlerFunc) (s *httptest.Server, closeFn func()) {
var conns int64
s = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt64(&conns, 1)
defer atomic.AddInt64(&conns, -1)
fn.ServeHTTP(w, r)
}))
return s, func() {
s.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
for atomic.LoadInt64(&conns) > 0 {
if ctx.Err() != nil {
tb.Fatalf("waiting for server to come down timed out: %v", ctx.Err())
}
}
}
}
// https://github.com/crossbario/autobahn-python/tree/master/wstest
func TestAutobahnServer(t *testing.T) {
t.Parallel()
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{
Subprotocols: []string{"echo"},
})
if err != nil {
t.Logf("server handshake failed: %+v", err)
return
}
echoLoop(r.Context(), c)
}))
defer s.Close()
spec := map[string]interface{}{
"outdir": "ci/out/wstestServerReports",
"servers": []interface{}{
map[string]interface{}{
"agent": "main",
"url": strings.Replace(s.URL, "http", "ws", 1),
},
},
"cases": []string{"*"},
// We skip the UTF-8 handling tests as there isn't any reason to reject invalid UTF-8, just
// more performance overhead. 7.5.1 is the same.
// 12.* and 13.* as we do not support compression.
"exclude-cases": []string{"6.*", "7.5.1", "12.*", "13.*"},
}
specFile, err := ioutil.TempFile("", "websocketFuzzingClient.json")
if err != nil {
t.Fatalf("failed to create temp file for fuzzingclient.json: %v", err)
}
defer specFile.Close()
e := json.NewEncoder(specFile)
e.SetIndent("", "\t")
err = e.Encode(spec)
if err != nil {
t.Fatalf("failed to write spec: %v", err)
}
err = specFile.Close()
if err != nil {
t.Fatalf("failed to close file: %v", err)
}
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, time.Minute*10)
defer cancel()
args := []string{"--mode", "fuzzingclient", "--spec", specFile.Name()}
wstest := exec.CommandContext(ctx, "wstest", args...)
out, err := wstest.CombinedOutput()
if err != nil {
t.Fatalf("failed to run wstest: %v\nout:\n%s", err, out)
}
checkWSTestIndex(t, "./ci/out/wstestServerReports/index.json")
}
func echoLoop(ctx context.Context, c *websocket.Conn) {
defer c.Close(websocket.StatusInternalError, "")
c.SetReadLimit(1 << 40)
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
b := make([]byte, 32768)
echo := func() error {
typ, r, err := c.Reader(ctx)
if err != nil {
return err
}
w, err := c.Writer(ctx, typ)
if err != nil {
return err
}
_, err = io.CopyBuffer(w, r, b)
if err != nil {
return err
}
err = w.Close()
if err != nil {
return err
}
return nil
}
for {
err := echo()
if err != nil {
return
}
}
}
func discardLoop(ctx context.Context, c *websocket.Conn) {
defer c.Close(websocket.StatusInternalError, "")
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
b := make([]byte, 32768)
echo := func() error {
_, r, err := c.Reader(ctx)
if err != nil {
return err
}
_, err = io.CopyBuffer(ioutil.Discard, r, b)
if err != nil {
return err
}
return nil
}
for {
err := echo()
if err != nil {
return
}
}
}
// https://github.com/crossbario/autobahn-python/blob/master/wstest/testee_client_aio.py
func TestAutobahnClient(t *testing.T) {
t.Parallel()
spec := map[string]interface{}{
"url": "ws://localhost:9001",
"outdir": "ci/out/wstestClientReports",
"cases": []string{"*"},
// See TestAutobahnServer for the reasons why we exclude these.
"exclude-cases": []string{"6.*", "7.5.1", "12.*", "13.*"},
}
specFile, err := ioutil.TempFile("", "websocketFuzzingServer.json")
if err != nil {
t.Fatalf("failed to create temp file for fuzzingserver.json: %v", err)
}
defer specFile.Close()
e := json.NewEncoder(specFile)
e.SetIndent("", "\t")
err = e.Encode(spec)
if err != nil {
t.Fatalf("failed to write spec: %v", err)
}
err = specFile.Close()
if err != nil {
t.Fatalf("failed to close file: %v", err)
}
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, time.Minute*10)
defer cancel()
args := []string{"--mode", "fuzzingserver", "--spec", specFile.Name()}
if os.Getenv("CI") == "" {
args = append([]string{"--debug"}, args...)
}
wstest := exec.CommandContext(ctx, "wstest", args...)
err = wstest.Start()
if err != nil {
t.Fatal(err)
}
defer func() {
err := wstest.Process.Kill()
if err != nil {
t.Error(err)
}
}()
// Let it come up.
time.Sleep(time.Second * 5)
var cases int
func() {
c, _, err := websocket.Dial(ctx, "ws://localhost:9001/getCaseCount", websocket.DialOptions{})
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
defer c.Close(websocket.StatusInternalError, "")
_, r, err := c.Reader(ctx)
if err != nil {
t.Fatal(err)
}
b, err := ioutil.ReadAll(r)
if err != nil {
t.Fatal(err)
}
cases, err = strconv.Atoi(string(b))
if err != nil {
t.Fatal(err)
}
c.Close(websocket.StatusNormalClosure, "")
}()
for i := 1; i <= cases; i++ {
func() {
ctx, cancel := context.WithTimeout(ctx, time.Second*45)
defer cancel()
c, _, err := websocket.Dial(ctx, fmt.Sprintf("ws://localhost:9001/runCase?case=%v&agent=main", i), websocket.DialOptions{})
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
echoLoop(ctx, c)
}()
}
c, _, err := websocket.Dial(ctx, fmt.Sprintf("ws://localhost:9001/updateReports?agent=main"), websocket.DialOptions{})
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
c.Close(websocket.StatusNormalClosure, "")
checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json")
}
func checkWSTestIndex(t *testing.T, path string) {
wstestOut, err := ioutil.ReadFile(path)
if err != nil {
t.Fatalf("failed to read index.json: %v", err)
}
var indexJSON map[string]map[string]struct {
Behavior string `json:"behavior"`
BehaviorClose string `json:"behaviorClose"`
}
err = json.Unmarshal(wstestOut, &indexJSON)
if err != nil {
t.Fatalf("failed to unmarshal index.json: %v", err)
}
var failed bool
for _, tests := range indexJSON {
for test, result := range tests {
switch result.Behavior {
case "OK", "NON-STRICT", "INFORMATIONAL":
default:
failed = true
t.Errorf("test %v failed", test)
}
switch result.BehaviorClose {
case "OK", "INFORMATIONAL":
default:
failed = true
t.Errorf("bad close behaviour for test %v", test)
}
}
}
if failed {
path = strings.Replace(path, ".json", ".html", 1)
if os.Getenv("CI") == "" {
t.Errorf("wstest found failure, please see %q", path)
} else {
t.Errorf("wstest found failure, please run test.sh locally to see %q", path)
}
}
}
func benchConn(b *testing.B, echo, stream bool, size int) {
s, closeFn := testServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c, err := websocket.Accept(w, r, websocket.AcceptOptions{})
if err != nil {
b.Logf("server handshake failed: %+v", err)
return
}
if echo {
echoLoop(r.Context(), c)
} else {
discardLoop(r.Context(), c)
}
}))
defer closeFn()
wsURL := strings.Replace(s.URL, "http", "ws", 1)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
defer cancel()
c, _, err := websocket.Dial(ctx, wsURL, websocket.DialOptions{})
if err != nil {
b.Fatalf("failed to dial: %v", err)
}
defer c.Close(websocket.StatusInternalError, "")
msg := []byte(strings.Repeat("2", size))
readBuf := make([]byte, len(msg))
b.SetBytes(int64(len(msg)))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
if stream {
w, err := c.Writer(ctx, websocket.MessageText)
if err != nil {
b.Fatal(err)
}
_, err = w.Write(msg)
if err != nil {
b.Fatal(err)
}
err = w.Close()
if err != nil {
b.Fatal(err)
}
} else {
err = c.Write(ctx, websocket.MessageText, msg)
if err != nil {
b.Fatal(err)
}
}
if echo {
_, r, err := c.Reader(ctx)
if err != nil {
b.Fatal(err)
}
_, err = io.ReadFull(r, readBuf)
if err != nil {
b.Fatal(err)
}
}
}
b.StopTimer()
c.Close(websocket.StatusNormalClosure, "")
}
func BenchmarkConn(b *testing.B) {
sizes := []int{
2,
16,
32,
512,
4096,
16384,
}
b.Run("write", func(b *testing.B) {
for _, size := range sizes {
b.Run(strconv.Itoa(size), func(b *testing.B) {
b.Run("stream", func(b *testing.B) {
benchConn(b, false, true, size)
})
b.Run("buffer", func(b *testing.B) {
benchConn(b, false, false, size)
})
})
}
})
b.Run("echo", func(b *testing.B) {
for _, size := range sizes {
b.Run(strconv.Itoa(size), func(b *testing.B) {
benchConn(b, false, false, size)
})
}
})
}
//go:build !js
// +build !js
package websocket
import (
"bufio"
"compress/flate"
"context"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"time"
"github.com/coder/websocket/internal/errd"
"github.com/coder/websocket/internal/util"
)
// Writer returns a writer bounded by the context that will write
// a WebSocket message of type dataType to the connection.
//
// You must close the writer once you have written the entire message.
//
// Only one writer can be open at a time, multiple calls will block until the previous writer
// is closed.
func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
w, err := c.writer(ctx, typ)
if err != nil {
return nil, fmt.Errorf("failed to get writer: %w", err)
}
return w, nil
}
// Write writes a message to the connection.
//
// See the Writer method if you want to stream a message.
//
// If compression is disabled or the compression threshold is not met, then it
// will write the message in a single frame.
func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
_, err := c.write(ctx, typ, p)
if err != nil {
return fmt.Errorf("failed to write msg: %w", err)
}
return nil
}
type msgWriter struct {
c *Conn
mu *mu
writeMu *mu
closed bool
ctx context.Context
opcode opcode
flate bool
trimWriter *trimLastFourBytesWriter
flateWriter *flate.Writer
}
func newMsgWriter(c *Conn) *msgWriter {
mw := &msgWriter{
c: c,
mu: newMu(c),
writeMu: newMu(c),
}
return mw
}
func (mw *msgWriter) ensureFlate() {
if mw.trimWriter == nil {
mw.trimWriter = &trimLastFourBytesWriter{
w: util.WriterFunc(mw.write),
}
}
if mw.flateWriter == nil {
mw.flateWriter = getFlateWriter(mw.trimWriter)
}
mw.flate = true
}
func (mw *msgWriter) flateContextTakeover() bool {
if mw.c.client {
return !mw.c.copts.clientNoContextTakeover
}
return !mw.c.copts.serverNoContextTakeover
}
func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
err := c.msgWriter.reset(ctx, typ)
if err != nil {
return nil, err
}
return c.msgWriter, nil
}
func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
mw, err := c.writer(ctx, typ)
if err != nil {
return 0, err
}
if !c.flate() {
defer c.msgWriter.mu.unlock()
return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p)
}
n, err := mw.Write(p)
if err != nil {
return n, err
}
err = mw.Close()
return n, err
}
func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {
err := mw.mu.lock(ctx)
if err != nil {
return err
}
mw.ctx = ctx
mw.opcode = opcode(typ)
mw.flate = false
mw.closed = false
mw.trimWriter.reset()
return nil
}
func (mw *msgWriter) putFlateWriter() {
if mw.flateWriter != nil {
putFlateWriter(mw.flateWriter)
mw.flateWriter = nil
}
}
// Write writes the given bytes to the WebSocket connection.
func (mw *msgWriter) Write(p []byte) (_ int, err error) {
err = mw.writeMu.lock(mw.ctx)
if err != nil {
return 0, fmt.Errorf("failed to write: %w", err)
}
defer mw.writeMu.unlock()
if mw.closed {
return 0, errors.New("cannot use closed writer")
}
defer func() {
if err != nil {
err = fmt.Errorf("failed to write: %w", err)
}
}()
if mw.c.flate() {
// Only enables flate if the length crosses the
// threshold on the first frame
if mw.opcode != opContinuation && len(p) >= mw.c.flateThreshold {
mw.ensureFlate()
}
}
if mw.flate {
return mw.flateWriter.Write(p)
}
return mw.write(p)
}
func (mw *msgWriter) write(p []byte) (int, error) {
n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p)
if err != nil {
return n, fmt.Errorf("failed to write data frame: %w", err)
}
mw.opcode = opContinuation
return n, nil
}
// Close flushes the frame to the connection.
func (mw *msgWriter) Close() (err error) {
defer errd.Wrap(&err, "failed to close writer")
err = mw.writeMu.lock(mw.ctx)
if err != nil {
return err
}
defer mw.writeMu.unlock()
if mw.closed {
return errors.New("writer already closed")
}
mw.closed = true
if mw.flate {
err = mw.flateWriter.Flush()
if err != nil {
return fmt.Errorf("failed to flush flate: %w", err)
}
}
_, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil)
if err != nil {
return fmt.Errorf("failed to write fin frame: %w", err)
}
if mw.flate && !mw.flateContextTakeover() {
mw.putFlateWriter()
}
mw.mu.unlock()
return nil
}
func (mw *msgWriter) close() {
if mw.c.client {
mw.c.writeFrameMu.forceLock()
putBufioWriter(mw.c.bw)
}
mw.writeMu.forceLock()
mw.putFlateWriter()
}
func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
_, err := c.writeFrame(ctx, true, false, opcode, p)
if err != nil {
return fmt.Errorf("failed to write control frame %v: %w", opcode, err)
}
return nil
}
// writeFrame handles all writes to the connection.
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) {
err = c.writeFrameMu.lock(ctx)
if err != nil {
return 0, err
}
defer c.writeFrameMu.unlock()
defer func() {
if c.isClosed() && opcode == opClose {
err = nil
}
if err != nil {
if ctx.Err() != nil {
err = ctx.Err()
} else if c.isClosed() {
err = net.ErrClosed
}
err = fmt.Errorf("failed to write frame: %w", err)
}
}()
c.closeStateMu.Lock()
closeSentErr := c.closeSentErr
c.closeStateMu.Unlock()
if closeSentErr != nil {
return 0, net.ErrClosed
}
select {
case <-c.closed:
return 0, net.ErrClosed
case c.writeTimeout <- ctx:
}
defer func() {
select {
case <-c.closed:
case c.writeTimeout <- context.Background():
}
}()
c.writeHeader.fin = fin
c.writeHeader.opcode = opcode
c.writeHeader.payloadLength = int64(len(p))
if c.client {
c.writeHeader.masked = true
_, err = io.ReadFull(rand.Reader, c.writeHeaderBuf[:4])
if err != nil {
return 0, fmt.Errorf("failed to generate masking key: %w", err)
}
c.writeHeader.maskKey = binary.LittleEndian.Uint32(c.writeHeaderBuf[:])
}
c.writeHeader.rsv1 = false
if flate && (opcode == opText || opcode == opBinary) {
c.writeHeader.rsv1 = true
}
err = writeFrameHeader(c.writeHeader, c.bw, c.writeHeaderBuf[:])
if err != nil {
return 0, err
}
n, err := c.writeFramePayload(p)
if err != nil {
return n, err
}
if c.writeHeader.fin {
err = c.bw.Flush()
if err != nil {
return n, fmt.Errorf("failed to flush: %w", err)
}
}
if opcode == opClose {
c.closeStateMu.Lock()
c.closeSentErr = fmt.Errorf("sent close frame: %w", net.ErrClosed)
closeReceived := c.closeReceivedErr != nil
c.closeStateMu.Unlock()
if closeReceived && !c.casClosing() {
c.writeFrameMu.unlock()
_ = c.close()
}
}
return n, nil
}
func (c *Conn) writeFramePayload(p []byte) (n int, err error) {
defer errd.Wrap(&err, "failed to write frame payload")
if !c.writeHeader.masked {
return c.bw.Write(p)
}
maskKey := c.writeHeader.maskKey
for len(p) > 0 {
// If the buffer is full, we need to flush.
if c.bw.Available() == 0 {
err = c.bw.Flush()
if err != nil {
return n, err
}
}
// Start of next write in the buffer.
i := c.bw.Buffered()
j := len(p)
if j > c.bw.Available() {
j = c.bw.Available()
}
_, err := c.bw.Write(p[:j])
if err != nil {
return n, err
}
maskKey = mask(c.writeBuf[i:c.bw.Buffered()], maskKey)
p = p[j:]
n += j
}
return n, nil
}
// extractBufioWriterBuf grabs the []byte backing a *bufio.Writer
// and returns it.
func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte {
var writeBuf []byte
bw.Reset(util.WriterFunc(func(p2 []byte) (int, error) {
writeBuf = p2[:cap(p2)]
return len(p2), nil
}))
bw.WriteByte(0)
bw.Flush()
bw.Reset(w)
return writeBuf
}
func (c *Conn) writeError(code StatusCode, err error) {
c.writeClose(code, err.Error())
}
package websocket // import "github.com/coder/websocket"
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"net/http"
"reflect"
"runtime"
"strings"
"sync"
"sync/atomic"
"syscall/js"
"github.com/coder/websocket/internal/bpool"
"github.com/coder/websocket/internal/wsjs"
)
// opcode represents a WebSocket opcode.
type opcode int
// https://tools.ietf.org/html/rfc6455#section-11.8.
const (
opContinuation opcode = iota
opText
opBinary
// 3 - 7 are reserved for further non-control frames.
_
_
_
_
_
opClose
opPing
opPong
// 11-16 are reserved for further control frames.
)
// Conn provides a wrapper around the browser WebSocket API.
type Conn struct {
noCopy noCopy
ws wsjs.WebSocket
// read limit for a message in bytes.
msgReadLimit atomic.Int64
closeReadMu sync.Mutex
closeReadCtx context.Context
closingMu sync.Mutex
closeOnce sync.Once
closed chan struct{}
closeErrOnce sync.Once
closeErr error
closeWasClean bool
releaseOnClose func()
releaseOnError func()
releaseOnMessage func()
readSignal chan struct{}
readBufMu sync.Mutex
readBuf []wsjs.MessageEvent
}
func (c *Conn) close(err error, wasClean bool) {
c.closeOnce.Do(func() {
runtime.SetFinalizer(c, nil)
if !wasClean {
err = fmt.Errorf("unclean connection close: %w", err)
}
c.setCloseErr(err)
c.closeWasClean = wasClean
close(c.closed)
})
}
func (c *Conn) init() {
c.closed = make(chan struct{})
c.readSignal = make(chan struct{}, 1)
c.msgReadLimit.Store(32768)
c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) {
err := CloseError{
Code: StatusCode(e.Code),
Reason: e.Reason,
}
// We do not know if we sent or received this close as
// its possible the browser triggered it without us
// explicitly sending it.
c.close(err, e.WasClean)
c.releaseOnClose()
c.releaseOnError()
c.releaseOnMessage()
})
c.releaseOnError = c.ws.OnError(func(v js.Value) {
c.setCloseErr(errors.New(v.Get("message").String()))
c.closeWithInternal()
})
c.releaseOnMessage = c.ws.OnMessage(func(e wsjs.MessageEvent) {
c.readBufMu.Lock()
defer c.readBufMu.Unlock()
c.readBuf = append(c.readBuf, e)
// Lets the read goroutine know there is definitely something in readBuf.
select {
case c.readSignal <- struct{}{}:
default:
}
})
runtime.SetFinalizer(c, func(c *Conn) {
c.setCloseErr(errors.New("connection garbage collected"))
c.closeWithInternal()
})
}
func (c *Conn) closeWithInternal() {
c.Close(StatusInternalError, "something went wrong")
}
// Read attempts to read a message from the connection.
// The maximum time spent waiting is bounded by the context.
func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
c.closeReadMu.Lock()
closedRead := c.closeReadCtx != nil
c.closeReadMu.Unlock()
if closedRead {
return 0, nil, errors.New("WebSocket connection read closed")
}
typ, p, err := c.read(ctx)
if err != nil {
return 0, nil, fmt.Errorf("failed to read: %w", err)
}
readLimit := c.msgReadLimit.Load()
if readLimit >= 0 && int64(len(p)) > readLimit {
err := fmt.Errorf("read limited at %v bytes", c.msgReadLimit.Load())
c.Close(StatusMessageTooBig, err.Error())
return 0, nil, err
}
return typ, p, nil
}
func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) {
select {
case <-ctx.Done():
c.Close(StatusPolicyViolation, "read timed out")
return 0, nil, ctx.Err()
case <-c.readSignal:
case <-c.closed:
return 0, nil, net.ErrClosed
}
c.readBufMu.Lock()
defer c.readBufMu.Unlock()
me := c.readBuf[0]
// We copy the messages forward and decrease the size
// of the slice to avoid reallocating.
copy(c.readBuf, c.readBuf[1:])
c.readBuf = c.readBuf[:len(c.readBuf)-1]
if len(c.readBuf) > 0 {
// Next time we read, we'll grab the message.
select {
case c.readSignal <- struct{}{}:
default:
}
}
switch p := me.Data.(type) {
case string:
return MessageText, []byte(p), nil
case []byte:
return MessageBinary, p, nil
default:
panic("websocket: unexpected data type from wsjs OnMessage: " + reflect.TypeOf(me.Data).String())
}
}
// Ping is mocked out for Wasm.
func (c *Conn) Ping(ctx context.Context) error {
return nil
}
// Write writes a message of the given type to the connection.
// Always non blocking.
func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
err := c.write(ctx, typ, p)
if err != nil {
// Have to ensure the WebSocket is closed after a write error
// to match the Go API. It can only error if the message type
// is unexpected or the passed bytes contain invalid UTF-8 for
// MessageText.
err := fmt.Errorf("failed to write: %w", err)
c.setCloseErr(err)
c.closeWithInternal()
return err
}
return nil
}
func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error {
if c.isClosed() {
return net.ErrClosed
}
switch typ {
case MessageBinary:
return c.ws.SendBytes(p)
case MessageText:
return c.ws.SendText(string(p))
default:
return fmt.Errorf("unexpected message type: %v", typ)
}
}
// Close closes the WebSocket with the given code and reason.
// It will wait until the peer responds with a close frame
// or the connection is closed.
// It thus performs the full WebSocket close handshake.
func (c *Conn) Close(code StatusCode, reason string) error {
err := c.exportedClose(code, reason)
if err != nil {
return fmt.Errorf("failed to close WebSocket: %w", err)
}
return nil
}
// CloseNow closes the WebSocket connection without attempting a close handshake.
// Use when you do not want the overhead of the close handshake.
//
// note: No different from Close(StatusGoingAway, "") in WASM as there is no way to close
// a WebSocket without the close handshake.
func (c *Conn) CloseNow() error {
return c.Close(StatusGoingAway, "")
}
func (c *Conn) exportedClose(code StatusCode, reason string) error {
c.closingMu.Lock()
defer c.closingMu.Unlock()
if c.isClosed() {
return net.ErrClosed
}
ce := fmt.Errorf("sent close: %w", CloseError{
Code: code,
Reason: reason,
})
c.setCloseErr(ce)
err := c.ws.Close(int(code), reason)
if err != nil {
return err
}
<-c.closed
if !c.closeWasClean {
return c.closeErr
}
return nil
}
// Subprotocol returns the negotiated subprotocol.
// An empty string means the default protocol.
func (c *Conn) Subprotocol() string {
return c.ws.Subprotocol()
}
// DialOptions represents the options available to pass to Dial.
type DialOptions struct {
// Subprotocols lists the subprotocols to negotiate with the server.
Subprotocols []string
}
// Dial creates a new WebSocket connection to the given url with the given options.
// The passed context bounds the maximum time spent waiting for the connection to open.
// The returned *http.Response is always nil or a mock. It's only in the signature
// to match the core API.
func Dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) {
c, resp, err := dial(ctx, url, opts)
if err != nil {
return nil, nil, fmt.Errorf("failed to WebSocket dial %q: %w", url, err)
}
return c, resp, nil
}
func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) {
if opts == nil {
opts = &DialOptions{}
}
url = strings.Replace(url, "http://", "ws://", 1)
url = strings.Replace(url, "https://", "wss://", 1)
ws, err := wsjs.New(url, opts.Subprotocols)
if err != nil {
return nil, nil, err
}
c := &Conn{
ws: ws,
}
c.init()
opench := make(chan struct{})
releaseOpen := ws.OnOpen(func(e js.Value) {
close(opench)
})
defer releaseOpen()
select {
case <-ctx.Done():
c.Close(StatusPolicyViolation, "dial timed out")
return nil, nil, ctx.Err()
case <-opench:
return c, &http.Response{
StatusCode: http.StatusSwitchingProtocols,
}, nil
case <-c.closed:
return nil, nil, net.ErrClosed
}
}
// Reader attempts to read a message from the connection.
// The maximum time spent waiting is bounded by the context.
func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
typ, p, err := c.Read(ctx)
if err != nil {
return 0, nil, err
}
return typ, bytes.NewReader(p), nil
}
// Writer returns a writer to write a WebSocket data message to the connection.
// It buffers the entire message in memory and then sends it when the writer
// is closed.
func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
return &writer{
c: c,
ctx: ctx,
typ: typ,
b: bpool.Get(),
}, nil
}
type writer struct {
closed bool
c *Conn
ctx context.Context
typ MessageType
b *bytes.Buffer
}
func (w *writer) Write(p []byte) (int, error) {
if w.closed {
return 0, errors.New("cannot write to closed writer")
}
n, err := w.b.Write(p)
if err != nil {
return n, fmt.Errorf("failed to write message: %w", err)
}
return n, nil
}
func (w *writer) Close() error {
if w.closed {
return errors.New("cannot close closed writer")
}
w.closed = true
defer bpool.Put(w.b)
err := w.c.Write(w.ctx, w.typ, w.b.Bytes())
if err != nil {
return fmt.Errorf("failed to close writer: %w", err)
}
return nil
}
// CloseRead implements *Conn.CloseRead for wasm.
func (c *Conn) CloseRead(ctx context.Context) context.Context {
c.closeReadMu.Lock()
ctx2 := c.closeReadCtx
if ctx2 != nil {
c.closeReadMu.Unlock()
return ctx2
}
ctx, cancel := context.WithCancel(ctx)
c.closeReadCtx = ctx
c.closeReadMu.Unlock()
go func() {
defer cancel()
defer c.CloseNow()
_, _, err := c.read(ctx)
if err != nil {
c.Close(StatusPolicyViolation, "unexpected data message")
}
}()
return ctx
}
// SetReadLimit implements *Conn.SetReadLimit for wasm.
func (c *Conn) SetReadLimit(n int64) {
c.msgReadLimit.Store(n)
}
func (c *Conn) setCloseErr(err error) {
c.closeErrOnce.Do(func() {
c.closeErr = fmt.Errorf("WebSocket closed: %w", err)
})
}
func (c *Conn) isClosed() bool {
select {
case <-c.closed:
return true
default:
return false
}
}
// AcceptOptions represents Accept's options.
type AcceptOptions struct {
Subprotocols []string
InsecureSkipVerify bool
OriginPatterns []string
CompressionMode CompressionMode
CompressionThreshold int
}
// Accept is stubbed out for Wasm.
func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
return nil, errors.New("unimplemented")
}
// StatusCode represents a WebSocket status code.
// https://tools.ietf.org/html/rfc6455#section-7.4
type StatusCode int
// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
//
// These are only the status codes defined by the protocol.
//
// You can define custom codes in the 3000-4999 range.
// The 3000-3999 range is reserved for use by libraries, frameworks and applications.
// The 4000-4999 range is reserved for private use.
const (
StatusNormalClosure StatusCode = 1000
StatusGoingAway StatusCode = 1001
StatusProtocolError StatusCode = 1002
StatusUnsupportedData StatusCode = 1003
// 1004 is reserved and so unexported.
statusReserved StatusCode = 1004
// StatusNoStatusRcvd cannot be sent in a close message.
// It is reserved for when a close message is received without
// a status code.
StatusNoStatusRcvd StatusCode = 1005
// StatusAbnormalClosure is exported for use only with Wasm.
// In non Wasm Go, the returned error will indicate whether the
// connection was closed abnormally.
StatusAbnormalClosure StatusCode = 1006
StatusInvalidFramePayloadData StatusCode = 1007
StatusPolicyViolation StatusCode = 1008
StatusMessageTooBig StatusCode = 1009
StatusMandatoryExtension StatusCode = 1010
StatusInternalError StatusCode = 1011
StatusServiceRestart StatusCode = 1012
StatusTryAgainLater StatusCode = 1013
StatusBadGateway StatusCode = 1014
// StatusTLSHandshake is only exported for use with Wasm.
// In non Wasm Go, the returned error will indicate whether there was
// a TLS handshake failure.
StatusTLSHandshake StatusCode = 1015
)
// CloseError is returned when the connection is closed with a status and reason.
//
// Use Go 1.13's errors.As to check for this error.
// Also see the CloseStatus helper.
type CloseError struct {
Code StatusCode
Reason string
}
func (ce CloseError) Error() string {
return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason)
}
// CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab
// the status code from a CloseError.
//
// -1 will be returned if the passed error is nil or not a CloseError.
func CloseStatus(err error) StatusCode {
var ce CloseError
if errors.As(err, &ce) {
return ce.Code
}
return -1
}
// CompressionMode represents the modes available to the deflate extension.
// See https://tools.ietf.org/html/rfc7692
// Works in all browsers except Safari which does not implement the deflate extension.
type CompressionMode int
const (
// CompressionNoContextTakeover grabs a new flate.Reader and flate.Writer as needed
// for every message. This applies to both server and client side.
//
// This means less efficient compression as the sliding window from previous messages
// will not be used but the memory overhead will be lower if the connections
// are long lived and seldom used.
//
// The message will only be compressed if greater than 512 bytes.
CompressionNoContextTakeover CompressionMode = iota
// CompressionContextTakeover uses a flate.Reader and flate.Writer per connection.
// This enables reusing the sliding window from previous messages.
// As most WebSocket protocols are repetitive, this can be very efficient.
// It carries an overhead of 8 kB for every connection compared to CompressionNoContextTakeover.
//
// If the peer negotiates NoContextTakeover on the client or server side, it will be
// used instead as this is required by the RFC.
CompressionContextTakeover
// CompressionDisabled disables the deflate extension.
//
// Use this if you are using a predominantly binary protocol with very
// little duplication in between messages or CPU and memory are more
// important than bandwidth.
CompressionDisabled
)
// MessageType represents the type of a WebSocket message.
// See https://tools.ietf.org/html/rfc6455#section-5.6
type MessageType int
// MessageType constants.
const (
// MessageText is for UTF-8 encoded text messages like JSON.
MessageText MessageType = iota + 1
// MessageBinary is for binary messages like protobufs.
MessageBinary
)
type mu struct {
c *Conn
ch chan struct{}
}
func newMu(c *Conn) *mu {
return &mu{
c: c,
ch: make(chan struct{}, 1),
}
}
func (m *mu) forceLock() {
m.ch <- struct{}{}
}
func (m *mu) tryLock() bool {
select {
case m.ch <- struct{}{}:
return true
default:
return false
}
}
func (m *mu) unlock() {
select {
case <-m.ch:
default:
}
}
type noCopy struct{}
func (*noCopy) Lock() {}
package websocket_test
import (
"context"
"net/http"
"os"
"testing"
"time"
"github.com/coder/websocket"
"github.com/coder/websocket/internal/test/assert"
"github.com/coder/websocket/internal/test/wstest"
)
func TestWasm(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
c, resp, err := websocket.Dial(ctx, os.Getenv("WS_ECHO_SERVER_URL"), &websocket.DialOptions{
Subprotocols: []string{"echo"},
})
assert.Success(t, err)
defer c.Close(websocket.StatusInternalError, "")
assert.Equal(t, "subprotocol", "echo", c.Subprotocol())
assert.Equal(t, "response code", http.StatusSwitchingProtocols, resp.StatusCode)
c.SetReadLimit(65536)
for i := 0; i < 10; i++ {
err = wstest.Echo(ctx, c, 65536)
assert.Success(t, err)
}
err = c.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)
}
func TestWasmDialTimeout(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
beforeDial := time.Now()
_, _, err := websocket.Dial(ctx, "ws://example.com:9893", &websocket.DialOptions{
Subprotocols: []string{"echo"},
})
assert.Error(t, err)
if time.Since(beforeDial) >= time.Second {
t.Fatal("wasm context dial timeout is not working", time.Since(beforeDial))
}
}
// Package wsjson provides websocket helpers for JSON messages.
package wsjson
// Package wsjson provides helpers for reading and writing JSON messages.
package wsjson // import "github.com/coder/websocket/wsjson"
import (
"context"
"encoding/json"
"fmt"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
"nhooyr.io/websocket/internal/bpool"
"github.com/coder/websocket"
"github.com/coder/websocket/internal/bpool"
"github.com/coder/websocket/internal/errd"
"github.com/coder/websocket/internal/util"
)
// Read reads a json message from c into v.
// It will reuse buffers to avoid allocations.
// Read reads a JSON message from c into v.
// It will reuse buffers in between calls to avoid allocations.
func Read(ctx context.Context, c *websocket.Conn, v interface{}) error {
err := read(ctx, c, v)
if err != nil {
return xerrors.Errorf("failed to read json: %w", err)
}
return nil
return read(ctx, c, v)
}
func read(ctx context.Context, c *websocket.Conn, v interface{}) error {
typ, r, err := c.Reader(ctx)
func read(ctx context.Context, c *websocket.Conn, v interface{}) (err error) {
defer errd.Wrap(&err, "failed to read JSON message")
_, r, err := c.Reader(ctx)
if err != nil {
return err
}
if typ != websocket.MessageText {
c.Close(websocket.StatusUnsupportedData, "can only accept text messages")
return xerrors.Errorf("unexpected frame type for json (expected %v): %v", websocket.MessageText, typ)
}
b := bpool.Get()
defer func() {
bpool.Put(b)
}()
defer bpool.Put(b)
_, err = b.ReadFrom(r)
if err != nil {
......@@ -45,39 +37,32 @@ func read(ctx context.Context, c *websocket.Conn, v interface{}) error {
err = json.Unmarshal(b.Bytes(), v)
if err != nil {
c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal JSON")
return xerrors.Errorf("failed to unmarshal json: %w", err)
return fmt.Errorf("failed to unmarshal JSON: %w", err)
}
return nil
}
// Write writes the json message v to c.
// It will reuse buffers to avoid allocations.
// Write writes the JSON message v to c.
// It will reuse buffers in between calls to avoid allocations.
func Write(ctx context.Context, c *websocket.Conn, v interface{}) error {
err := write(ctx, c, v)
if err != nil {
return xerrors.Errorf("failed to write json: %w", err)
}
return nil
return write(ctx, c, v)
}
func write(ctx context.Context, c *websocket.Conn, v interface{}) error {
w, err := c.Writer(ctx, websocket.MessageText)
if err != nil {
return err
}
// We use Encode because it automatically enables buffer reuse without us
// needing to do anything. Though see https://github.com/golang/go/issues/27735
e := json.NewEncoder(w)
err = e.Encode(v)
if err != nil {
return xerrors.Errorf("failed to encode json: %w", err)
}
func write(ctx context.Context, c *websocket.Conn, v interface{}) (err error) {
defer errd.Wrap(&err, "failed to write JSON message")
err = w.Close()
// json.Marshal cannot reuse buffers between calls as it has to return
// a copy of the byte slice but Encoder does as it directly writes to w.
err = json.NewEncoder(util.WriterFunc(func(p []byte) (int, error) {
err := c.Write(ctx, websocket.MessageText, p)
if err != nil {
return 0, err
}
return len(p), nil
})).Encode(v)
if err != nil {
return err
return fmt.Errorf("failed to marshal JSON: %w", err)
}
return nil
}
package wsjson_test
import (
"encoding/json"
"io"
"strconv"
"testing"
"github.com/coder/websocket/internal/test/xrand"
)
func BenchmarkJSON(b *testing.B) {
sizes := []int{
8,
16,
32,
128,
256,
512,
1024,
2048,
4096,
8192,
16384,
}
b.Run("json.Encoder", func(b *testing.B) {
for _, size := range sizes {
b.Run(strconv.Itoa(size), func(b *testing.B) {
msg := xrand.String(size)
b.SetBytes(int64(size))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
json.NewEncoder(io.Discard).Encode(msg)
}
})
}
})
b.Run("json.Marshal", func(b *testing.B) {
for _, size := range sizes {
b.Run(strconv.Itoa(size), func(b *testing.B) {
msg := xrand.String(size)
b.SetBytes(int64(size))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
json.Marshal(msg)
}
})
}
})
}
// Package wspb provides websocket helpers for protobuf messages.
package wspb
import (
"bytes"
"context"
"sync"
"github.com/golang/protobuf/proto"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
"nhooyr.io/websocket/internal/bpool"
)
// Read reads a protobuf message from c into v.
// It will reuse buffers to avoid allocations.
func Read(ctx context.Context, c *websocket.Conn, v proto.Message) error {
err := read(ctx, c, v)
if err != nil {
return xerrors.Errorf("failed to read protobuf: %w", err)
}
return nil
}
func read(ctx context.Context, c *websocket.Conn, v proto.Message) error {
typ, r, err := c.Reader(ctx)
if err != nil {
return err
}
if typ != websocket.MessageBinary {
c.Close(websocket.StatusUnsupportedData, "can only accept binary messages")
return xerrors.Errorf("unexpected frame type for protobuf (expected %v): %v", websocket.MessageBinary, typ)
}
b := bpool.Get()
defer func() {
bpool.Put(b)
}()
_, err = b.ReadFrom(r)
if err != nil {
return err
}
err = proto.Unmarshal(b.Bytes(), v)
if err != nil {
c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal protobuf")
return xerrors.Errorf("failed to unmarshal protobuf: %w", err)
}
return nil
}
// Write writes the protobuf message v to c.
// It will reuse buffers to avoid allocations.
func Write(ctx context.Context, c *websocket.Conn, v proto.Message) error {
err := write(ctx, c, v)
if err != nil {
return xerrors.Errorf("failed to write protobuf: %w", err)
}
return nil
}
var writeBufPool sync.Pool
func write(ctx context.Context, c *websocket.Conn, v proto.Message) error {
b := bpool.Get()
pb := proto.NewBuffer(b.Bytes())
defer func() {
bpool.Put(bytes.NewBuffer(pb.Bytes()))
}()
err := pb.Marshal(v)
if err != nil {
return xerrors.Errorf("failed to marshal protobuf: %w", err)
}
return c.Write(ctx, websocket.MessageBinary, pb.Bytes())
}
package websocket
import (
"encoding/binary"
)
// xor applies the WebSocket masking algorithm to p
// with the given key where the first 3 bits of pos
// are the starting position in the key.
// See https://tools.ietf.org/html/rfc6455#section-5.3
//
// The returned value is the position of the next byte
// to be used for masking in the key. This is so that
// unmasking can be performed without the entire frame.
func fastXOR(key [4]byte, keyPos int, b []byte) int {
// If the payload is greater than or equal to 16 bytes, then it's worth
// masking 8 bytes at a time.
// Optimization from https://github.com/golang/go/issues/31586#issuecomment-485530859
if len(b) >= 16 {
// We first create a key that is 8 bytes long
// and is aligned on the position correctly.
var alignedKey [8]byte
for i := range alignedKey {
alignedKey[i] = key[(i+keyPos)&3]
}
k := binary.LittleEndian.Uint64(alignedKey[:])
// At some point in the future we can clean these unrolled loops up.
// See https://github.com/golang/go/issues/31586#issuecomment-487436401
// Then we xor until b is less than 128 bytes.
for len(b) >= 128 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^k)
v = binary.LittleEndian.Uint64(b[8:])
binary.LittleEndian.PutUint64(b[8:], v^k)
v = binary.LittleEndian.Uint64(b[16:])
binary.LittleEndian.PutUint64(b[16:], v^k)
v = binary.LittleEndian.Uint64(b[24:])
binary.LittleEndian.PutUint64(b[24:], v^k)
v = binary.LittleEndian.Uint64(b[32:])
binary.LittleEndian.PutUint64(b[32:], v^k)
v = binary.LittleEndian.Uint64(b[40:])
binary.LittleEndian.PutUint64(b[40:], v^k)
v = binary.LittleEndian.Uint64(b[48:])
binary.LittleEndian.PutUint64(b[48:], v^k)
v = binary.LittleEndian.Uint64(b[56:])
binary.LittleEndian.PutUint64(b[56:], v^k)
v = binary.LittleEndian.Uint64(b[64:])
binary.LittleEndian.PutUint64(b[64:], v^k)
v = binary.LittleEndian.Uint64(b[72:])
binary.LittleEndian.PutUint64(b[72:], v^k)
v = binary.LittleEndian.Uint64(b[80:])
binary.LittleEndian.PutUint64(b[80:], v^k)
v = binary.LittleEndian.Uint64(b[88:])
binary.LittleEndian.PutUint64(b[88:], v^k)
v = binary.LittleEndian.Uint64(b[96:])
binary.LittleEndian.PutUint64(b[96:], v^k)
v = binary.LittleEndian.Uint64(b[104:])
binary.LittleEndian.PutUint64(b[104:], v^k)
v = binary.LittleEndian.Uint64(b[112:])
binary.LittleEndian.PutUint64(b[112:], v^k)
v = binary.LittleEndian.Uint64(b[120:])
binary.LittleEndian.PutUint64(b[120:], v^k)
b = b[128:]
}
// Then we xor until b is less than 64 bytes.
for len(b) >= 64 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^k)
v = binary.LittleEndian.Uint64(b[8:])
binary.LittleEndian.PutUint64(b[8:], v^k)
v = binary.LittleEndian.Uint64(b[16:])
binary.LittleEndian.PutUint64(b[16:], v^k)
v = binary.LittleEndian.Uint64(b[24:])
binary.LittleEndian.PutUint64(b[24:], v^k)
v = binary.LittleEndian.Uint64(b[32:])
binary.LittleEndian.PutUint64(b[32:], v^k)
v = binary.LittleEndian.Uint64(b[40:])
binary.LittleEndian.PutUint64(b[40:], v^k)
v = binary.LittleEndian.Uint64(b[48:])
binary.LittleEndian.PutUint64(b[48:], v^k)
v = binary.LittleEndian.Uint64(b[56:])
binary.LittleEndian.PutUint64(b[56:], v^k)
b = b[64:]
}
// Then we xor until b is less than 32 bytes.
for len(b) >= 32 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^k)
v = binary.LittleEndian.Uint64(b[8:])
binary.LittleEndian.PutUint64(b[8:], v^k)
v = binary.LittleEndian.Uint64(b[16:])
binary.LittleEndian.PutUint64(b[16:], v^k)
v = binary.LittleEndian.Uint64(b[24:])
binary.LittleEndian.PutUint64(b[24:], v^k)
b = b[32:]
}
// Then we xor until b is less than 16 bytes.
for len(b) >= 16 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^k)
v = binary.LittleEndian.Uint64(b[8:])
binary.LittleEndian.PutUint64(b[8:], v^k)
b = b[16:]
}
// Then we xor until b is less than 8 bytes.
for len(b) >= 8 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^k)
b = b[8:]
}
}
// xor remaining bytes.
for i := range b {
b[i] ^= key[keyPos&3]
keyPos++
}
return keyPos & 3
}