good morning!!!!

Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • github/nhooyr/websocket
  • open/websocket
2 results
Show changes
//go:build amd64 || arm64
package websocket
func mask(b []byte, key uint32) uint32 {
// TODO: Will enable in v1.9.0.
return maskGo(b, key)
/*
if len(b) > 0 {
return maskAsm(&b[0], len(b), key)
}
return key
*/
}
// @nhooyr: I am not confident that the amd64 or the arm64 implementations of this
// function are perfect. There are almost certainly missing optimizations or
// opportunities for simplification. I'm confident there are no bugs though.
// For example, the arm64 implementation doesn't align memory like the amd64.
// Or the amd64 implementation could use AVX512 instead of just AVX2.
// The AVX2 code I had to disable anyway as it wasn't performing as expected.
// See https://github.com/nhooyr/websocket/pull/326#issuecomment-1771138049
//
//go:noescape
//lint:ignore U1000 disabled till v1.9.0
func maskAsm(b *byte, len int, key uint32) uint32
//go:build amd64 || arm64
package websocket
import "testing"
func TestMaskASM(t *testing.T) {
t.Parallel()
testMask(t, "maskASM", mask)
}
//go:build !amd64 && !arm64 && !js
package websocket
func mask(b []byte, key uint32) uint32 {
return maskGo(b, key)
}
package websocket
import (
"bytes"
"crypto/rand"
"encoding/binary"
"math/big"
"math/bits"
"testing"
"github.com/coder/websocket/internal/test/assert"
)
func basicMask(b []byte, key uint32) uint32 {
for i := range b {
b[i] ^= byte(key)
key = bits.RotateLeft32(key, -8)
}
return key
}
func basicMask2(b []byte, key uint32) uint32 {
keyb := binary.LittleEndian.AppendUint32(nil, key)
pos := 0
for i := range b {
b[i] ^= keyb[pos&3]
pos++
}
return bits.RotateLeft32(key, (pos&3)*-8)
}
func TestMask(t *testing.T) {
t.Parallel()
testMask(t, "basicMask", basicMask)
testMask(t, "maskGo", maskGo)
testMask(t, "basicMask2", basicMask2)
}
func testMask(t *testing.T, name string, fn func(b []byte, key uint32) uint32) {
t.Run(name, func(t *testing.T) {
t.Parallel()
for i := 0; i < 9999; i++ {
keyb := make([]byte, 4)
_, err := rand.Read(keyb)
assert.Success(t, err)
key := binary.LittleEndian.Uint32(keyb)
n, err := rand.Int(rand.Reader, big.NewInt(1<<16))
assert.Success(t, err)
b := make([]byte, 1+n.Int64())
_, err = rand.Read(b)
assert.Success(t, err)
b2 := make([]byte, len(b))
copy(b2, b)
b3 := make([]byte, len(b))
copy(b3, b)
key2 := basicMask(b2, key)
key3 := fn(b3, key)
if key2 != key3 {
t.Errorf("expected key %X but got %X", key2, key3)
}
if !bytes.Equal(b2, b3) {
t.Error("bad bytes")
return
}
}
})
}
// This file contains *Conn symbols relevant to both
// Wasm and non Wasm builds.
package websocket
import (
......@@ -9,7 +6,6 @@ import (
"io"
"math"
"net"
"sync"
"sync/atomic"
"time"
)
......@@ -32,30 +28,64 @@ import (
//
// Close will close the *websocket.Conn with StatusNormalClosure.
//
// When a deadline is hit, the connection will be closed. This is
// different from most net.Conn implementations where only the
// reading/writing goroutines are interrupted but the connection is kept alive.
// When a deadline is hit and there is an active read or write goroutine, the
// connection will be closed. This is different from most net.Conn implementations
// where only the reading/writing goroutines are interrupted but the connection
// is kept alive.
//
// The Addr methods will return the real addresses for connections obtained
// from websocket.Accept. But for connections obtained from websocket.Dial, a mock net.Addr
// will be returned that gives "websocket" for Network() and "websocket/unknown-addr" for
// String(). This is because websocket.Dial only exposes a io.ReadWriteCloser instead of the
// full net.Conn to us.
//
// The Addr methods will return a mock net.Addr that returns "websocket" for Network
// and "websocket/unknown-addr" for String.
// When running as WASM, the Addr methods will always return the mock address described above.
//
// A received StatusNormalClosure or StatusGoingAway close frame will be translated to
// io.EOF when reading.
//
// Furthermore, the ReadLimit is set to -1 to disable it.
func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn {
c.SetReadLimit(-1)
nc := &netConn{
c: c,
msgType: msgType,
readMu: newMu(c),
writeMu: newMu(c),
}
var cancel context.CancelFunc
nc.writeContext, cancel = context.WithCancel(ctx)
nc.writeTimer = time.AfterFunc(math.MaxInt64, cancel)
nc.writeCtx, nc.writeCancel = context.WithCancel(ctx)
nc.readCtx, nc.readCancel = context.WithCancel(ctx)
nc.writeTimer = time.AfterFunc(math.MaxInt64, func() {
if !nc.writeMu.tryLock() {
// If the lock cannot be acquired, then there is an
// active write goroutine and so we should cancel the context.
nc.writeCancel()
return
}
defer nc.writeMu.unlock()
// Prevents future writes from writing until the deadline is reset.
nc.writeExpired.Store(1)
})
if !nc.writeTimer.Stop() {
<-nc.writeTimer.C
}
nc.readContext, cancel = context.WithCancel(ctx)
nc.readTimer = time.AfterFunc(math.MaxInt64, cancel)
nc.readTimer = time.AfterFunc(math.MaxInt64, func() {
if !nc.readMu.tryLock() {
// If the lock cannot be acquired, then there is an
// active read goroutine and so we should cancel the context.
nc.readCancel()
return
}
defer nc.readMu.unlock()
// Prevents future reads from reading until the deadline is reset.
nc.readExpired.Store(1)
})
if !nc.readTimer.Stop() {
<-nc.readTimer.C
}
......@@ -68,59 +98,91 @@ type netConn struct {
msgType MessageType
writeTimer *time.Timer
writeContext context.Context
writeMu *mu
writeExpired atomic.Int64
writeCtx context.Context
writeCancel context.CancelFunc
readTimer *time.Timer
readContext context.Context
readMu sync.Mutex
eofed bool
reader io.Reader
readMu *mu
readExpired atomic.Int64
readCtx context.Context
readCancel context.CancelFunc
readEOFed bool
reader io.Reader
}
var _ net.Conn = &netConn{}
func (c *netConn) Close() error {
return c.c.Close(StatusNormalClosure, "")
func (nc *netConn) Close() error {
nc.writeTimer.Stop()
nc.writeCancel()
nc.readTimer.Stop()
nc.readCancel()
return nc.c.Close(StatusNormalClosure, "")
}
func (c *netConn) Write(p []byte) (int, error) {
err := c.c.Write(c.writeContext, c.msgType, p)
func (nc *netConn) Write(p []byte) (int, error) {
nc.writeMu.forceLock()
defer nc.writeMu.unlock()
if nc.writeExpired.Load() == 1 {
return 0, fmt.Errorf("failed to write: %w", context.DeadlineExceeded)
}
err := nc.c.Write(nc.writeCtx, nc.msgType, p)
if err != nil {
return 0, err
}
return len(p), nil
}
func (c *netConn) Read(p []byte) (int, error) {
c.readMu.Lock()
defer c.readMu.Unlock()
func (nc *netConn) Read(p []byte) (int, error) {
nc.readMu.forceLock()
defer nc.readMu.unlock()
for {
n, err := nc.read(p)
if err != nil {
return n, err
}
if n == 0 {
continue
}
return n, nil
}
}
if c.eofed {
func (nc *netConn) read(p []byte) (int, error) {
if nc.readExpired.Load() == 1 {
return 0, fmt.Errorf("failed to read: %w", context.DeadlineExceeded)
}
if nc.readEOFed {
return 0, io.EOF
}
if c.reader == nil {
typ, r, err := c.c.Reader(c.readContext)
if nc.reader == nil {
typ, r, err := nc.c.Reader(nc.readCtx)
if err != nil {
switch CloseStatus(err) {
case StatusNormalClosure, StatusGoingAway:
c.eofed = true
nc.readEOFed = true
return 0, io.EOF
}
return 0, err
}
if typ != c.msgType {
err := fmt.Errorf("unexpected frame type read (expected %v): %v", c.msgType, typ)
c.c.Close(StatusUnsupportedData, err.Error())
if typ != nc.msgType {
err := fmt.Errorf("unexpected frame type read (expected %v): %v", nc.msgType, typ)
nc.c.Close(StatusUnsupportedData, err.Error())
return 0, err
}
c.reader = r
nc.reader = r
}
n, err := c.reader.Read(p)
n, err := nc.reader.Read(p)
if err == io.EOF {
c.reader = nil
nc.reader = nil
err = nil
}
return n, err
......@@ -137,109 +199,36 @@ func (a websocketAddr) String() string {
return "websocket/unknown-addr"
}
func (c *netConn) RemoteAddr() net.Addr {
return websocketAddr{}
}
func (c *netConn) LocalAddr() net.Addr {
return websocketAddr{}
}
func (c *netConn) SetDeadline(t time.Time) error {
c.SetWriteDeadline(t)
c.SetReadDeadline(t)
func (nc *netConn) SetDeadline(t time.Time) error {
nc.SetWriteDeadline(t)
nc.SetReadDeadline(t)
return nil
}
func (c *netConn) SetWriteDeadline(t time.Time) error {
func (nc *netConn) SetWriteDeadline(t time.Time) error {
nc.writeExpired.Store(0)
if t.IsZero() {
c.writeTimer.Stop()
nc.writeTimer.Stop()
} else {
c.writeTimer.Reset(t.Sub(time.Now()))
dur := time.Until(t)
if dur <= 0 {
dur = 1
}
nc.writeTimer.Reset(dur)
}
return nil
}
func (c *netConn) SetReadDeadline(t time.Time) error {
func (nc *netConn) SetReadDeadline(t time.Time) error {
nc.readExpired.Store(0)
if t.IsZero() {
c.readTimer.Stop()
nc.readTimer.Stop()
} else {
c.readTimer.Reset(t.Sub(time.Now()))
dur := time.Until(t)
if dur <= 0 {
dur = 1
}
nc.readTimer.Reset(dur)
}
return nil
}
// CloseRead will start a goroutine to read from the connection until it is closed or a data message
// is received. If a data message is received, the connection will be closed with StatusPolicyViolation.
// Since CloseRead reads from the connection, it will respond to ping, pong and close frames.
// After calling this method, you cannot read any data messages from the connection.
// The returned context will be cancelled when the connection is closed.
//
// Use this when you do not want to read data messages from the connection anymore but will
// want to write messages to it.
func (c *Conn) CloseRead(ctx context.Context) context.Context {
c.isReadClosed.Store(1)
ctx, cancel := context.WithCancel(ctx)
go func() {
defer cancel()
// We use the unexported reader method so that we don't get the read closed error.
c.reader(ctx, true)
// Either the connection is already closed since there was a read error
// or the context was cancelled or a message was read and we should close
// the connection.
c.Close(StatusPolicyViolation, "unexpected data message")
}()
return ctx
}
// SetReadLimit sets the max number of bytes to read for a single message.
// It applies to the Reader and Read methods.
//
// By default, the connection has a message read limit of 32768 bytes.
//
// When the limit is hit, the connection will be closed with StatusMessageTooBig.
func (c *Conn) SetReadLimit(n int64) {
c.msgReadLimit.Store(n)
}
func (c *Conn) setCloseErr(err error) {
c.closeErrOnce.Do(func() {
c.closeErr = fmt.Errorf("websocket closed: %w", err)
})
}
// See https://github.com/nhooyr/websocket/issues/153
type atomicInt64 struct {
v int64
}
func (v *atomicInt64) Load() int64 {
return atomic.LoadInt64(&v.v)
}
func (v *atomicInt64) Store(i int64) {
atomic.StoreInt64(&v.v, i)
}
func (v *atomicInt64) String() string {
return fmt.Sprint(v.Load())
}
// Increment increments the value and returns the new value.
func (v *atomicInt64) Increment(delta int64) int64 {
return atomic.AddInt64(&v.v, delta)
}
func (v *atomicInt64) CAS(old, new int64) (swapped bool) {
return atomic.CompareAndSwapInt64(&v.v, old, new)
}
func (c *Conn) isClosed() bool {
select {
case <-c.closed:
return true
default:
return false
}
}
package websocket
import "net"
func (nc *netConn) RemoteAddr() net.Addr {
return websocketAddr{}
}
func (nc *netConn) LocalAddr() net.Addr {
return websocketAddr{}
}
//go:build !js
// +build !js
package websocket
import "net"
func (nc *netConn) RemoteAddr() net.Addr {
if unc, ok := nc.c.rwc.(net.Conn); ok {
return unc.RemoteAddr()
}
return websocketAddr{}
}
func (nc *netConn) LocalAddr() net.Addr {
if unc, ok := nc.c.rwc.(net.Conn); ok {
return unc.LocalAddr()
}
return websocketAddr{}
}
//go:build !js
// +build !js
package websocket
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"net"
"strings"
"sync/atomic"
"time"
"github.com/coder/websocket/internal/errd"
"github.com/coder/websocket/internal/util"
)
// Reader reads from the connection until there is a WebSocket
// data message to be read. It will handle ping, pong and close frames as appropriate.
//
// It returns the type of the message and an io.Reader to read it.
// The passed context will also bound the reader.
// Ensure you read to EOF otherwise the connection will hang.
//
// Call CloseRead if you do not expect any data messages from the peer.
//
// Only one Reader may be open at a time.
//
// If you need a separate timeout on the Reader call and the Read itself,
// use time.AfterFunc to cancel the context passed in.
// See https://github.com/nhooyr/websocket/issues/87#issue-451703332
// Most users should not need this.
func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
return c.reader(ctx)
}
// Read is a convenience method around Reader to read a single message
// from the connection.
func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
typ, r, err := c.Reader(ctx)
if err != nil {
return 0, nil, err
}
b, err := io.ReadAll(r)
return typ, b, err
}
// CloseRead starts a goroutine to read from the connection until it is closed
// or a data message is received.
//
// Once CloseRead is called you cannot read any messages from the connection.
// The returned context will be cancelled when the connection is closed.
//
// If a data message is received, the connection will be closed with StatusPolicyViolation.
//
// Call CloseRead when you do not expect to read any more messages.
// Since it actively reads from the connection, it will ensure that ping, pong and close
// frames are responded to. This means c.Ping and c.Close will still work as expected.
//
// This function is idempotent.
func (c *Conn) CloseRead(ctx context.Context) context.Context {
c.closeReadMu.Lock()
ctx2 := c.closeReadCtx
if ctx2 != nil {
c.closeReadMu.Unlock()
return ctx2
}
ctx, cancel := context.WithCancel(ctx)
c.closeReadCtx = ctx
c.closeReadDone = make(chan struct{})
c.closeReadMu.Unlock()
go func() {
defer close(c.closeReadDone)
defer cancel()
defer c.close()
_, _, err := c.Reader(ctx)
if err == nil {
c.Close(StatusPolicyViolation, "unexpected data message")
}
}()
return ctx
}
// SetReadLimit sets the max number of bytes to read for a single message.
// It applies to the Reader and Read methods.
//
// By default, the connection has a message read limit of 32768 bytes.
//
// When the limit is hit, the connection will be closed with StatusMessageTooBig.
//
// Set to -1 to disable.
func (c *Conn) SetReadLimit(n int64) {
if n >= 0 {
// We read one more byte than the limit in case
// there is a fin frame that needs to be read.
n++
}
c.msgReader.limitReader.limit.Store(n)
}
const defaultReadLimit = 32768
func newMsgReader(c *Conn) *msgReader {
mr := &msgReader{
c: c,
fin: true,
}
mr.readFunc = mr.read
mr.limitReader = newLimitReader(c, mr.readFunc, defaultReadLimit+1)
return mr
}
func (mr *msgReader) resetFlate() {
if mr.flateContextTakeover() {
if mr.dict == nil {
mr.dict = &slidingWindow{}
}
mr.dict.init(32768)
}
if mr.flateBufio == nil {
mr.flateBufio = getBufioReader(mr.readFunc)
}
if mr.flateContextTakeover() {
mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf)
} else {
mr.flateReader = getFlateReader(mr.flateBufio, nil)
}
mr.limitReader.r = mr.flateReader
mr.flateTail.Reset(deflateMessageTail)
}
func (mr *msgReader) putFlateReader() {
if mr.flateReader != nil {
putFlateReader(mr.flateReader)
mr.flateReader = nil
}
}
func (mr *msgReader) close() {
mr.c.readMu.forceLock()
mr.putFlateReader()
if mr.dict != nil {
mr.dict.close()
mr.dict = nil
}
if mr.flateBufio != nil {
putBufioReader(mr.flateBufio)
}
if mr.c.client {
putBufioReader(mr.c.br)
mr.c.br = nil
}
}
func (mr *msgReader) flateContextTakeover() bool {
if mr.c.client {
return !mr.c.copts.serverNoContextTakeover
}
return !mr.c.copts.clientNoContextTakeover
}
func (c *Conn) readRSV1Illegal(h header) bool {
// If compression is disabled, rsv1 is illegal.
if !c.flate() {
return true
}
// rsv1 is only allowed on data frames beginning messages.
if h.opcode != opText && h.opcode != opBinary {
return true
}
return false
}
func (c *Conn) readLoop(ctx context.Context) (header, error) {
for {
h, err := c.readFrameHeader(ctx)
if err != nil {
return header{}, err
}
if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 {
err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)
c.writeError(StatusProtocolError, err)
return header{}, err
}
if !c.client && !h.masked {
return header{}, errors.New("received unmasked frame from client")
}
switch h.opcode {
case opClose, opPing, opPong:
err = c.handleControl(ctx, h)
if err != nil {
// Pass through CloseErrors when receiving a close frame.
if h.opcode == opClose && CloseStatus(err) != -1 {
return header{}, err
}
return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err)
}
case opContinuation, opText, opBinary:
return h, nil
default:
err := fmt.Errorf("received unknown opcode %v", h.opcode)
c.writeError(StatusProtocolError, err)
return header{}, err
}
}
}
// prepareRead sets the readTimeout context and returns a done function
// to be called after the read is done. It also returns an error if the
// connection is closed. The reference to the error is used to assign
// an error depending on if the connection closed or the context timed
// out during use. Typically the referenced error is a named return
// variable of the function calling this method.
func (c *Conn) prepareRead(ctx context.Context, err *error) (func(), error) {
select {
case <-c.closed:
return nil, net.ErrClosed
case c.readTimeout <- ctx:
}
done := func() {
select {
case <-c.closed:
if *err != nil {
*err = net.ErrClosed
}
case c.readTimeout <- context.Background():
}
if *err != nil && ctx.Err() != nil {
*err = ctx.Err()
}
}
c.closeStateMu.Lock()
closeReceivedErr := c.closeReceivedErr
c.closeStateMu.Unlock()
if closeReceivedErr != nil {
defer done()
return nil, closeReceivedErr
}
return done, nil
}
func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) {
readDone, err := c.prepareRead(ctx, &err)
if err != nil {
return header{}, err
}
defer readDone()
h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
if err != nil {
return header{}, err
}
return h, nil
}
func (c *Conn) readFramePayload(ctx context.Context, p []byte) (_ int, err error) {
readDone, err := c.prepareRead(ctx, &err)
if err != nil {
return 0, err
}
defer readDone()
n, err := io.ReadFull(c.br, p)
if err != nil {
return n, fmt.Errorf("failed to read frame payload: %w", err)
}
return n, err
}
func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
if h.payloadLength < 0 || h.payloadLength > maxControlPayload {
err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength)
c.writeError(StatusProtocolError, err)
return err
}
if !h.fin {
err := errors.New("received fragmented control frame")
c.writeError(StatusProtocolError, err)
return err
}
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
b := c.readControlBuf[:h.payloadLength]
_, err = c.readFramePayload(ctx, b)
if err != nil {
return err
}
if h.masked {
mask(b, h.maskKey)
}
switch h.opcode {
case opPing:
if c.onPingReceived != nil {
if !c.onPingReceived(ctx, b) {
return nil
}
}
return c.writeControl(ctx, opPong, b)
case opPong:
if c.onPongReceived != nil {
c.onPongReceived(ctx, b)
}
c.activePingsMu.Lock()
pong, ok := c.activePings[string(b)]
c.activePingsMu.Unlock()
if ok {
select {
case pong <- struct{}{}:
default:
}
}
return nil
}
// opClose
ce, err := parseClosePayload(b)
if err != nil {
err = fmt.Errorf("received invalid close payload: %w", err)
c.writeError(StatusProtocolError, err)
return err
}
err = fmt.Errorf("received close frame: %w", ce)
c.closeStateMu.Lock()
c.closeReceivedErr = err
closeSent := c.closeSentErr != nil
c.closeStateMu.Unlock()
// Only unlock readMu if this connection is being closed becaue
// c.close will try to acquire the readMu lock. We unlock for
// writeClose as well because it may also call c.close.
if !closeSent {
c.readMu.unlock()
_ = c.writeClose(ce.Code, ce.Reason)
}
if !c.casClosing() {
c.readMu.unlock()
_ = c.close()
}
return err
}
func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) {
defer errd.Wrap(&err, "failed to get reader")
err = c.readMu.lock(ctx)
if err != nil {
return 0, nil, err
}
defer c.readMu.unlock()
if !c.msgReader.fin {
return 0, nil, errors.New("previous message not read to completion")
}
h, err := c.readLoop(ctx)
if err != nil {
return 0, nil, err
}
if h.opcode == opContinuation {
err := errors.New("received continuation frame without text or binary frame")
c.writeError(StatusProtocolError, err)
return 0, nil, err
}
c.msgReader.reset(ctx, h)
return MessageType(h.opcode), c.msgReader, nil
}
type msgReader struct {
c *Conn
ctx context.Context
flate bool
flateReader io.Reader
flateBufio *bufio.Reader
flateTail strings.Reader
limitReader *limitReader
dict *slidingWindow
fin bool
payloadLength int64
maskKey uint32
// util.ReaderFunc(mr.Read) to avoid continuous allocations.
readFunc util.ReaderFunc
}
func (mr *msgReader) reset(ctx context.Context, h header) {
mr.ctx = ctx
mr.flate = h.rsv1
mr.limitReader.reset(mr.readFunc)
if mr.flate {
mr.resetFlate()
}
mr.setFrame(h)
}
func (mr *msgReader) setFrame(h header) {
mr.fin = h.fin
mr.payloadLength = h.payloadLength
mr.maskKey = h.maskKey
}
func (mr *msgReader) Read(p []byte) (n int, err error) {
err = mr.c.readMu.lock(mr.ctx)
if err != nil {
return 0, fmt.Errorf("failed to read: %w", err)
}
defer mr.c.readMu.unlock()
n, err = mr.limitReader.Read(p)
if mr.flate && mr.flateContextTakeover() {
p = p[:n]
mr.dict.write(p)
}
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
mr.putFlateReader()
return n, io.EOF
}
if err != nil {
return n, fmt.Errorf("failed to read: %w", err)
}
return n, nil
}
func (mr *msgReader) read(p []byte) (int, error) {
for {
if mr.payloadLength == 0 {
if mr.fin {
if mr.flate {
return mr.flateTail.Read(p)
}
return 0, io.EOF
}
h, err := mr.c.readLoop(mr.ctx)
if err != nil {
return 0, err
}
if h.opcode != opContinuation {
err := errors.New("received new data message without finishing the previous message")
mr.c.writeError(StatusProtocolError, err)
return 0, err
}
mr.setFrame(h)
continue
}
if int64(len(p)) > mr.payloadLength {
p = p[:mr.payloadLength]
}
n, err := mr.c.readFramePayload(mr.ctx, p)
if err != nil {
return n, err
}
mr.payloadLength -= int64(n)
if !mr.c.client {
mr.maskKey = mask(p, mr.maskKey)
}
return n, nil
}
}
type limitReader struct {
c *Conn
r io.Reader
limit atomic.Int64
n int64
}
func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader {
lr := &limitReader{
c: c,
}
lr.limit.Store(limit)
lr.reset(r)
return lr
}
func (lr *limitReader) reset(r io.Reader) {
lr.n = lr.limit.Load()
lr.r = r
}
func (lr *limitReader) Read(p []byte) (int, error) {
if lr.n < 0 {
return lr.r.Read(p)
}
if lr.n == 0 {
err := fmt.Errorf("read limited at %v bytes", lr.limit.Load())
lr.c.writeError(StatusMessageTooBig, err)
return 0, err
}
if int64(len(p)) > lr.n {
p = p[:lr.n]
}
n, err := lr.r.Read(p)
lr.n -= int64(n)
if lr.n < 0 {
lr.n = 0
}
return n, err
}
// Code generated by "stringer -type=opcode,MessageType,StatusCode -output=frame_stringer.go"; DO NOT EDIT.
// Code generated by "stringer -type=opcode,MessageType,StatusCode -output=stringer.go"; DO NOT EDIT.
package websocket
......
package websocket_test
import (
"context"
"net/http"
"os"
"testing"
"time"
"nhooyr.io/websocket"
"nhooyr.io/websocket/internal/assert"
)
func TestConn(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
c, resp, err := websocket.Dial(ctx, os.Getenv("WS_ECHO_SERVER_URL"), &websocket.DialOptions{
Subprotocols: []string{"echo"},
})
if err != nil {
t.Fatal(err)
}
defer c.Close(websocket.StatusInternalError, "")
err = assertSubprotocol(c, "echo")
if err != nil {
t.Fatal(err)
}
err = assert.Equalf(&http.Response{}, resp, "unexpected http response")
if err != nil {
t.Fatal(err)
}
err = assertJSONEcho(ctx, c, 1024)
if err != nil {
t.Fatal(err)
}
err = assertEcho(ctx, c, websocket.MessageBinary, 1024)
if err != nil {
t.Fatal(err)
}
err = c.Close(websocket.StatusNormalClosure, "")
if err != nil {
t.Fatal(err)
}
}
//go:build !js
// +build !js
package websocket
import (
"bufio"
"compress/flate"
"context"
"crypto/rand"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"time"
"github.com/coder/websocket/internal/errd"
"github.com/coder/websocket/internal/util"
)
// Writer returns a writer bounded by the context that will write
// a WebSocket message of type dataType to the connection.
//
// You must close the writer once you have written the entire message.
//
// Only one writer can be open at a time, multiple calls will block until the previous writer
// is closed.
func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
w, err := c.writer(ctx, typ)
if err != nil {
return nil, fmt.Errorf("failed to get writer: %w", err)
}
return w, nil
}
// Write writes a message to the connection.
//
// See the Writer method if you want to stream a message.
//
// If compression is disabled or the compression threshold is not met, then it
// will write the message in a single frame.
func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
_, err := c.write(ctx, typ, p)
if err != nil {
return fmt.Errorf("failed to write msg: %w", err)
}
return nil
}
type msgWriter struct {
c *Conn
mu *mu
writeMu *mu
closed bool
ctx context.Context
opcode opcode
flate bool
trimWriter *trimLastFourBytesWriter
flateWriter *flate.Writer
}
func newMsgWriter(c *Conn) *msgWriter {
mw := &msgWriter{
c: c,
mu: newMu(c),
writeMu: newMu(c),
}
return mw
}
func (mw *msgWriter) ensureFlate() {
if mw.trimWriter == nil {
mw.trimWriter = &trimLastFourBytesWriter{
w: util.WriterFunc(mw.write),
}
}
if mw.flateWriter == nil {
mw.flateWriter = getFlateWriter(mw.trimWriter)
}
mw.flate = true
}
func (mw *msgWriter) flateContextTakeover() bool {
if mw.c.client {
return !mw.c.copts.clientNoContextTakeover
}
return !mw.c.copts.serverNoContextTakeover
}
func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
err := c.msgWriter.reset(ctx, typ)
if err != nil {
return nil, err
}
return c.msgWriter, nil
}
func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
mw, err := c.writer(ctx, typ)
if err != nil {
return 0, err
}
if !c.flate() {
defer c.msgWriter.mu.unlock()
return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p)
}
n, err := mw.Write(p)
if err != nil {
return n, err
}
err = mw.Close()
return n, err
}
func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {
err := mw.mu.lock(ctx)
if err != nil {
return err
}
mw.ctx = ctx
mw.opcode = opcode(typ)
mw.flate = false
mw.closed = false
mw.trimWriter.reset()
return nil
}
func (mw *msgWriter) putFlateWriter() {
if mw.flateWriter != nil {
putFlateWriter(mw.flateWriter)
mw.flateWriter = nil
}
}
// Write writes the given bytes to the WebSocket connection.
func (mw *msgWriter) Write(p []byte) (_ int, err error) {
err = mw.writeMu.lock(mw.ctx)
if err != nil {
return 0, fmt.Errorf("failed to write: %w", err)
}
defer mw.writeMu.unlock()
if mw.closed {
return 0, errors.New("cannot use closed writer")
}
defer func() {
if err != nil {
err = fmt.Errorf("failed to write: %w", err)
}
}()
if mw.c.flate() {
// Only enables flate if the length crosses the
// threshold on the first frame
if mw.opcode != opContinuation && len(p) >= mw.c.flateThreshold {
mw.ensureFlate()
}
}
if mw.flate {
return mw.flateWriter.Write(p)
}
return mw.write(p)
}
func (mw *msgWriter) write(p []byte) (int, error) {
n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p)
if err != nil {
return n, fmt.Errorf("failed to write data frame: %w", err)
}
mw.opcode = opContinuation
return n, nil
}
// Close flushes the frame to the connection.
func (mw *msgWriter) Close() (err error) {
defer errd.Wrap(&err, "failed to close writer")
err = mw.writeMu.lock(mw.ctx)
if err != nil {
return err
}
defer mw.writeMu.unlock()
if mw.closed {
return errors.New("writer already closed")
}
mw.closed = true
if mw.flate {
err = mw.flateWriter.Flush()
if err != nil {
return fmt.Errorf("failed to flush flate: %w", err)
}
}
_, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil)
if err != nil {
return fmt.Errorf("failed to write fin frame: %w", err)
}
if mw.flate && !mw.flateContextTakeover() {
mw.putFlateWriter()
}
mw.mu.unlock()
return nil
}
func (mw *msgWriter) close() {
if mw.c.client {
mw.c.writeFrameMu.forceLock()
putBufioWriter(mw.c.bw)
}
mw.writeMu.forceLock()
mw.putFlateWriter()
}
func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error {
ctx, cancel := context.WithTimeout(ctx, time.Second*5)
defer cancel()
_, err := c.writeFrame(ctx, true, false, opcode, p)
if err != nil {
return fmt.Errorf("failed to write control frame %v: %w", opcode, err)
}
return nil
}
// writeFrame handles all writes to the connection.
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) {
err = c.writeFrameMu.lock(ctx)
if err != nil {
return 0, err
}
defer c.writeFrameMu.unlock()
defer func() {
if c.isClosed() && opcode == opClose {
err = nil
}
if err != nil {
if ctx.Err() != nil {
err = ctx.Err()
} else if c.isClosed() {
err = net.ErrClosed
}
err = fmt.Errorf("failed to write frame: %w", err)
}
}()
c.closeStateMu.Lock()
closeSentErr := c.closeSentErr
c.closeStateMu.Unlock()
if closeSentErr != nil {
return 0, net.ErrClosed
}
select {
case <-c.closed:
return 0, net.ErrClosed
case c.writeTimeout <- ctx:
}
defer func() {
select {
case <-c.closed:
case c.writeTimeout <- context.Background():
}
}()
c.writeHeader.fin = fin
c.writeHeader.opcode = opcode
c.writeHeader.payloadLength = int64(len(p))
if c.client {
c.writeHeader.masked = true
_, err = io.ReadFull(rand.Reader, c.writeHeaderBuf[:4])
if err != nil {
return 0, fmt.Errorf("failed to generate masking key: %w", err)
}
c.writeHeader.maskKey = binary.LittleEndian.Uint32(c.writeHeaderBuf[:])
}
c.writeHeader.rsv1 = false
if flate && (opcode == opText || opcode == opBinary) {
c.writeHeader.rsv1 = true
}
err = writeFrameHeader(c.writeHeader, c.bw, c.writeHeaderBuf[:])
if err != nil {
return 0, err
}
n, err := c.writeFramePayload(p)
if err != nil {
return n, err
}
if c.writeHeader.fin {
err = c.bw.Flush()
if err != nil {
return n, fmt.Errorf("failed to flush: %w", err)
}
}
if opcode == opClose {
c.closeStateMu.Lock()
c.closeSentErr = fmt.Errorf("sent close frame: %w", net.ErrClosed)
closeReceived := c.closeReceivedErr != nil
c.closeStateMu.Unlock()
if closeReceived && !c.casClosing() {
c.writeFrameMu.unlock()
_ = c.close()
}
}
return n, nil
}
func (c *Conn) writeFramePayload(p []byte) (n int, err error) {
defer errd.Wrap(&err, "failed to write frame payload")
if !c.writeHeader.masked {
return c.bw.Write(p)
}
maskKey := c.writeHeader.maskKey
for len(p) > 0 {
// If the buffer is full, we need to flush.
if c.bw.Available() == 0 {
err = c.bw.Flush()
if err != nil {
return n, err
}
}
// Start of next write in the buffer.
i := c.bw.Buffered()
j := len(p)
if j > c.bw.Available() {
j = c.bw.Available()
}
_, err := c.bw.Write(p[:j])
if err != nil {
return n, err
}
maskKey = mask(c.writeBuf[i:c.bw.Buffered()], maskKey)
p = p[j:]
n += j
}
return n, nil
}
// extractBufioWriterBuf grabs the []byte backing a *bufio.Writer
// and returns it.
func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte {
var writeBuf []byte
bw.Reset(util.WriterFunc(func(p2 []byte) (int, error) {
writeBuf = p2[:cap(p2)]
return len(p2), nil
}))
bw.WriteByte(0)
bw.Flush()
bw.Reset(w)
return writeBuf
}
func (c *Conn) writeError(code StatusCode, err error) {
c.writeClose(code, err.Error())
}
package websocket // import "nhooyr.io/websocket"
package websocket // import "github.com/coder/websocket"
import (
"bytes"
......@@ -6,25 +6,51 @@ import (
"errors"
"fmt"
"io"
"net"
"net/http"
"reflect"
"runtime"
"strings"
"sync"
"sync/atomic"
"syscall/js"
"nhooyr.io/websocket/internal/bpool"
"nhooyr.io/websocket/internal/wsjs"
"github.com/coder/websocket/internal/bpool"
"github.com/coder/websocket/internal/wsjs"
)
// opcode represents a WebSocket opcode.
type opcode int
// https://tools.ietf.org/html/rfc6455#section-11.8.
const (
opContinuation opcode = iota
opText
opBinary
// 3 - 7 are reserved for further non-control frames.
_
_
_
_
_
opClose
opPing
opPong
// 11-16 are reserved for further control frames.
)
// Conn provides a wrapper around the browser WebSocket API.
type Conn struct {
ws wsjs.WebSocket
noCopy noCopy
ws wsjs.WebSocket
// read limit for a message in bytes.
msgReadLimit *atomicInt64
msgReadLimit atomic.Int64
closeReadMu sync.Mutex
closeReadCtx context.Context
closingMu sync.Mutex
isReadClosed *atomicInt64
closeOnce sync.Once
closed chan struct{}
closeErrOnce sync.Once
......@@ -32,6 +58,7 @@ type Conn struct {
closeWasClean bool
releaseOnClose func()
releaseOnError func()
releaseOnMessage func()
readSignal chan struct{}
......@@ -56,22 +83,28 @@ func (c *Conn) init() {
c.closed = make(chan struct{})
c.readSignal = make(chan struct{}, 1)
c.msgReadLimit = &atomicInt64{}
c.msgReadLimit.Store(32768)
c.isReadClosed = &atomicInt64{}
c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) {
err := CloseError{
Code: StatusCode(e.Code),
Reason: e.Reason,
}
c.close(fmt.Errorf("received close: %w", err), e.WasClean)
// We do not know if we sent or received this close as
// its possible the browser triggered it without us
// explicitly sending it.
c.close(err, e.WasClean)
c.releaseOnClose()
c.releaseOnError()
c.releaseOnMessage()
})
c.releaseOnError = c.ws.OnError(func(v js.Value) {
c.setCloseErr(errors.New(v.Get("message").String()))
c.closeWithInternal()
})
c.releaseOnMessage = c.ws.OnMessage(func(e wsjs.MessageEvent) {
c.readBufMu.Lock()
defer c.readBufMu.Unlock()
......@@ -98,16 +131,20 @@ func (c *Conn) closeWithInternal() {
// Read attempts to read a message from the connection.
// The maximum time spent waiting is bounded by the context.
func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
if c.isReadClosed.Load() == 1 {
return 0, nil, fmt.Errorf("websocket connection read closed")
c.closeReadMu.Lock()
closedRead := c.closeReadCtx != nil
c.closeReadMu.Unlock()
if closedRead {
return 0, nil, errors.New("WebSocket connection read closed")
}
typ, p, err := c.read(ctx)
if err != nil {
return 0, nil, fmt.Errorf("failed to read: %w", err)
}
if int64(len(p)) > c.msgReadLimit.Load() {
err := fmt.Errorf("read limited at %v bytes", c.msgReadLimit)
readLimit := c.msgReadLimit.Load()
if readLimit >= 0 && int64(len(p)) > readLimit {
err := fmt.Errorf("read limited at %v bytes", c.msgReadLimit.Load())
c.Close(StatusMessageTooBig, err.Error())
return 0, nil, err
}
......@@ -121,7 +158,7 @@ func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) {
return 0, nil, ctx.Err()
case <-c.readSignal:
case <-c.closed:
return 0, nil, c.closeErr
return 0, nil, net.ErrClosed
}
c.readBufMu.Lock()
......@@ -151,6 +188,11 @@ func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) {
}
}
// Ping is mocked out for Wasm.
func (c *Conn) Ping(ctx context.Context) error {
return nil
}
// Write writes a message of the given type to the connection.
// Always non blocking.
func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
......@@ -170,7 +212,7 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error {
if c.isClosed() {
return c.closeErr
return net.ErrClosed
}
switch typ {
case MessageBinary:
......@@ -182,31 +224,40 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error {
}
}
// Close closes the websocket with the given code and reason.
// Close closes the WebSocket with the given code and reason.
// It will wait until the peer responds with a close frame
// or the connection is closed.
// It thus performs the full WebSocket close handshake.
func (c *Conn) Close(code StatusCode, reason string) error {
err := c.exportedClose(code, reason)
if err != nil {
return fmt.Errorf("failed to close websocket: %w", err)
return fmt.Errorf("failed to close WebSocket: %w", err)
}
return nil
}
// CloseNow closes the WebSocket connection without attempting a close handshake.
// Use when you do not want the overhead of the close handshake.
//
// note: No different from Close(StatusGoingAway, "") in WASM as there is no way to close
// a WebSocket without the close handshake.
func (c *Conn) CloseNow() error {
return c.Close(StatusGoingAway, "")
}
func (c *Conn) exportedClose(code StatusCode, reason string) error {
c.closingMu.Lock()
defer c.closingMu.Unlock()
if c.isClosed() {
return net.ErrClosed
}
ce := fmt.Errorf("sent close: %w", CloseError{
Code: code,
Reason: reason,
})
if c.isClosed() {
return fmt.Errorf("tried to close with %q but connection already closed: %w", ce, c.closeErr)
}
c.setCloseErr(ce)
err := c.ws.Close(int(code), reason)
if err != nil {
......@@ -234,12 +285,12 @@ type DialOptions struct {
// Dial creates a new WebSocket connection to the given url with the given options.
// The passed context bounds the maximum time spent waiting for the connection to open.
// The returned *http.Response is always nil or the zero value. It's only in the signature
// The returned *http.Response is always nil or a mock. It's only in the signature
// to match the core API.
func Dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) {
c, resp, err := dial(ctx, url, opts)
if err != nil {
return nil, resp, fmt.Errorf("failed to websocket dial %q: %w", url, err)
return nil, nil, fmt.Errorf("failed to WebSocket dial %q: %w", url, err)
}
return c, resp, nil
}
......@@ -249,6 +300,9 @@ func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Resp
opts = &DialOptions{}
}
url = strings.Replace(url, "http://", "ws://", 1)
url = strings.Replace(url, "https://", "wss://", 1)
ws, err := wsjs.New(url, opts.Subprotocols)
if err != nil {
return nil, nil, err
......@@ -270,12 +324,12 @@ func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Resp
c.Close(StatusPolicyViolation, "dial timed out")
return nil, nil, ctx.Err()
case <-opench:
return c, &http.Response{
StatusCode: http.StatusSwitchingProtocols,
}, nil
case <-c.closed:
return c, nil, c.closeErr
return nil, nil, net.ErrClosed
}
// Have to return a non nil response as the normal API does that.
return c, &http.Response{}, nil
}
// Reader attempts to read a message from the connection.
......@@ -288,16 +342,11 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
return typ, bytes.NewReader(p), nil
}
// Only implemented for use by *Conn.CloseRead in conn_common.go
func (c *Conn) reader(ctx context.Context, _ bool) {
c.read(ctx)
}
// Writer returns a writer to write a WebSocket data message to the connection.
// It buffers the entire message in memory and then sends it when the writer
// is closed.
func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
return writer{
return &writer{
c: c,
ctx: ctx,
typ: typ,
......@@ -315,7 +364,7 @@ type writer struct {
b *bytes.Buffer
}
func (w writer) Write(p []byte) (int, error) {
func (w *writer) Write(p []byte) (int, error) {
if w.closed {
return 0, errors.New("cannot write to closed writer")
}
......@@ -326,7 +375,7 @@ func (w writer) Write(p []byte) (int, error) {
return n, nil
}
func (w writer) Close() error {
func (w *writer) Close() error {
if w.closed {
return errors.New("cannot close closed writer")
}
......@@ -339,3 +388,211 @@ func (w writer) Close() error {
}
return nil
}
// CloseRead implements *Conn.CloseRead for wasm.
func (c *Conn) CloseRead(ctx context.Context) context.Context {
c.closeReadMu.Lock()
ctx2 := c.closeReadCtx
if ctx2 != nil {
c.closeReadMu.Unlock()
return ctx2
}
ctx, cancel := context.WithCancel(ctx)
c.closeReadCtx = ctx
c.closeReadMu.Unlock()
go func() {
defer cancel()
defer c.CloseNow()
_, _, err := c.read(ctx)
if err != nil {
c.Close(StatusPolicyViolation, "unexpected data message")
}
}()
return ctx
}
// SetReadLimit implements *Conn.SetReadLimit for wasm.
func (c *Conn) SetReadLimit(n int64) {
c.msgReadLimit.Store(n)
}
func (c *Conn) setCloseErr(err error) {
c.closeErrOnce.Do(func() {
c.closeErr = fmt.Errorf("WebSocket closed: %w", err)
})
}
func (c *Conn) isClosed() bool {
select {
case <-c.closed:
return true
default:
return false
}
}
// AcceptOptions represents Accept's options.
type AcceptOptions struct {
Subprotocols []string
InsecureSkipVerify bool
OriginPatterns []string
CompressionMode CompressionMode
CompressionThreshold int
}
// Accept is stubbed out for Wasm.
func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
return nil, errors.New("unimplemented")
}
// StatusCode represents a WebSocket status code.
// https://tools.ietf.org/html/rfc6455#section-7.4
type StatusCode int
// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
//
// These are only the status codes defined by the protocol.
//
// You can define custom codes in the 3000-4999 range.
// The 3000-3999 range is reserved for use by libraries, frameworks and applications.
// The 4000-4999 range is reserved for private use.
const (
StatusNormalClosure StatusCode = 1000
StatusGoingAway StatusCode = 1001
StatusProtocolError StatusCode = 1002
StatusUnsupportedData StatusCode = 1003
// 1004 is reserved and so unexported.
statusReserved StatusCode = 1004
// StatusNoStatusRcvd cannot be sent in a close message.
// It is reserved for when a close message is received without
// a status code.
StatusNoStatusRcvd StatusCode = 1005
// StatusAbnormalClosure is exported for use only with Wasm.
// In non Wasm Go, the returned error will indicate whether the
// connection was closed abnormally.
StatusAbnormalClosure StatusCode = 1006
StatusInvalidFramePayloadData StatusCode = 1007
StatusPolicyViolation StatusCode = 1008
StatusMessageTooBig StatusCode = 1009
StatusMandatoryExtension StatusCode = 1010
StatusInternalError StatusCode = 1011
StatusServiceRestart StatusCode = 1012
StatusTryAgainLater StatusCode = 1013
StatusBadGateway StatusCode = 1014
// StatusTLSHandshake is only exported for use with Wasm.
// In non Wasm Go, the returned error will indicate whether there was
// a TLS handshake failure.
StatusTLSHandshake StatusCode = 1015
)
// CloseError is returned when the connection is closed with a status and reason.
//
// Use Go 1.13's errors.As to check for this error.
// Also see the CloseStatus helper.
type CloseError struct {
Code StatusCode
Reason string
}
func (ce CloseError) Error() string {
return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason)
}
// CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab
// the status code from a CloseError.
//
// -1 will be returned if the passed error is nil or not a CloseError.
func CloseStatus(err error) StatusCode {
var ce CloseError
if errors.As(err, &ce) {
return ce.Code
}
return -1
}
// CompressionMode represents the modes available to the deflate extension.
// See https://tools.ietf.org/html/rfc7692
// Works in all browsers except Safari which does not implement the deflate extension.
type CompressionMode int
const (
// CompressionNoContextTakeover grabs a new flate.Reader and flate.Writer as needed
// for every message. This applies to both server and client side.
//
// This means less efficient compression as the sliding window from previous messages
// will not be used but the memory overhead will be lower if the connections
// are long lived and seldom used.
//
// The message will only be compressed if greater than 512 bytes.
CompressionNoContextTakeover CompressionMode = iota
// CompressionContextTakeover uses a flate.Reader and flate.Writer per connection.
// This enables reusing the sliding window from previous messages.
// As most WebSocket protocols are repetitive, this can be very efficient.
// It carries an overhead of 8 kB for every connection compared to CompressionNoContextTakeover.
//
// If the peer negotiates NoContextTakeover on the client or server side, it will be
// used instead as this is required by the RFC.
CompressionContextTakeover
// CompressionDisabled disables the deflate extension.
//
// Use this if you are using a predominantly binary protocol with very
// little duplication in between messages or CPU and memory are more
// important than bandwidth.
CompressionDisabled
)
// MessageType represents the type of a WebSocket message.
// See https://tools.ietf.org/html/rfc6455#section-5.6
type MessageType int
// MessageType constants.
const (
// MessageText is for UTF-8 encoded text messages like JSON.
MessageText MessageType = iota + 1
// MessageBinary is for binary messages like protobufs.
MessageBinary
)
type mu struct {
c *Conn
ch chan struct{}
}
func newMu(c *Conn) *mu {
return &mu{
c: c,
ch: make(chan struct{}, 1),
}
}
func (m *mu) forceLock() {
m.ch <- struct{}{}
}
func (m *mu) tryLock() bool {
select {
case m.ch <- struct{}{}:
return true
default:
return false
}
}
func (m *mu) unlock() {
select {
case <-m.ch:
default:
}
}
type noCopy struct{}
func (*noCopy) Lock() {}
package websocket_test
import (
"context"
"net/http"
"os"
"testing"
"time"
"github.com/coder/websocket"
"github.com/coder/websocket/internal/test/assert"
"github.com/coder/websocket/internal/test/wstest"
)
func TestWasm(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
c, resp, err := websocket.Dial(ctx, os.Getenv("WS_ECHO_SERVER_URL"), &websocket.DialOptions{
Subprotocols: []string{"echo"},
})
assert.Success(t, err)
defer c.Close(websocket.StatusInternalError, "")
assert.Equal(t, "subprotocol", "echo", c.Subprotocol())
assert.Equal(t, "response code", http.StatusSwitchingProtocols, resp.StatusCode)
c.SetReadLimit(65536)
for i := 0; i < 10; i++ {
err = wstest.Echo(ctx, c, 65536)
assert.Success(t, err)
}
err = c.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)
}
func TestWasmDialTimeout(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
beforeDial := time.Now()
_, _, err := websocket.Dial(ctx, "ws://example.com:9893", &websocket.DialOptions{
Subprotocols: []string{"echo"},
})
assert.Error(t, err)
if time.Since(beforeDial) >= time.Second {
t.Fatal("wasm context dial timeout is not working", time.Since(beforeDial))
}
}
// Package wsjson provides websocket helpers for JSON messages.
package wsjson // import "nhooyr.io/websocket/wsjson"
// Package wsjson provides helpers for reading and writing JSON messages.
package wsjson // import "github.com/coder/websocket/wsjson"
import (
"context"
"encoding/json"
"fmt"
"nhooyr.io/websocket"
"nhooyr.io/websocket/internal/bpool"
"github.com/coder/websocket"
"github.com/coder/websocket/internal/bpool"
"github.com/coder/websocket/internal/errd"
"github.com/coder/websocket/internal/util"
)
// Read reads a json message from c into v.
// It will reuse buffers to avoid allocations.
// Read reads a JSON message from c into v.
// It will reuse buffers in between calls to avoid allocations.
func Read(ctx context.Context, c *websocket.Conn, v interface{}) error {
err := read(ctx, c, v)
if err != nil {
return fmt.Errorf("failed to read json: %w", err)
}
return nil
return read(ctx, c, v)
}
func read(ctx context.Context, c *websocket.Conn, v interface{}) error {
typ, r, err := c.Reader(ctx)
func read(ctx context.Context, c *websocket.Conn, v interface{}) (err error) {
defer errd.Wrap(&err, "failed to read JSON message")
_, r, err := c.Reader(ctx)
if err != nil {
return err
}
if typ != websocket.MessageText {
c.Close(websocket.StatusUnsupportedData, "can only accept text messages")
return fmt.Errorf("unexpected frame type for json (expected %v): %v", websocket.MessageText, typ)
}
b := bpool.Get()
defer bpool.Put(b)
......@@ -42,39 +37,32 @@ func read(ctx context.Context, c *websocket.Conn, v interface{}) error {
err = json.Unmarshal(b.Bytes(), v)
if err != nil {
c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal JSON")
return fmt.Errorf("failed to unmarshal json: %w", err)
return fmt.Errorf("failed to unmarshal JSON: %w", err)
}
return nil
}
// Write writes the json message v to c.
// It will reuse buffers to avoid allocations.
// Write writes the JSON message v to c.
// It will reuse buffers in between calls to avoid allocations.
func Write(ctx context.Context, c *websocket.Conn, v interface{}) error {
err := write(ctx, c, v)
if err != nil {
return fmt.Errorf("failed to write json: %w", err)
}
return nil
return write(ctx, c, v)
}
func write(ctx context.Context, c *websocket.Conn, v interface{}) error {
w, err := c.Writer(ctx, websocket.MessageText)
if err != nil {
return err
}
// We use Encode because it automatically enables buffer reuse without us
// needing to do anything. Though see https://github.com/golang/go/issues/27735
e := json.NewEncoder(w)
err = e.Encode(v)
if err != nil {
return fmt.Errorf("failed to encode json: %w", err)
}
func write(ctx context.Context, c *websocket.Conn, v interface{}) (err error) {
defer errd.Wrap(&err, "failed to write JSON message")
err = w.Close()
// json.Marshal cannot reuse buffers between calls as it has to return
// a copy of the byte slice but Encoder does as it directly writes to w.
err = json.NewEncoder(util.WriterFunc(func(p []byte) (int, error) {
err := c.Write(ctx, websocket.MessageText, p)
if err != nil {
return 0, err
}
return len(p), nil
})).Encode(v)
if err != nil {
return err
return fmt.Errorf("failed to marshal JSON: %w", err)
}
return nil
}
package wsjson_test
import (
"encoding/json"
"io"
"strconv"
"testing"
"github.com/coder/websocket/internal/test/xrand"
)
func BenchmarkJSON(b *testing.B) {
sizes := []int{
8,
16,
32,
128,
256,
512,
1024,
2048,
4096,
8192,
16384,
}
b.Run("json.Encoder", func(b *testing.B) {
for _, size := range sizes {
b.Run(strconv.Itoa(size), func(b *testing.B) {
msg := xrand.String(size)
b.SetBytes(int64(size))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
json.NewEncoder(io.Discard).Encode(msg)
}
})
}
})
b.Run("json.Marshal", func(b *testing.B) {
for _, size := range sizes {
b.Run(strconv.Itoa(size), func(b *testing.B) {
msg := xrand.String(size)
b.SetBytes(int64(size))
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
json.Marshal(msg)
}
})
}
})
}
// Package wspb provides websocket helpers for protobuf messages.
package wspb // import "nhooyr.io/websocket/wspb"
import (
"bytes"
"context"
"fmt"
"github.com/golang/protobuf/proto"
"nhooyr.io/websocket"
"nhooyr.io/websocket/internal/bpool"
)
// Read reads a protobuf message from c into v.
// It will reuse buffers to avoid allocations.
func Read(ctx context.Context, c *websocket.Conn, v proto.Message) error {
err := read(ctx, c, v)
if err != nil {
return fmt.Errorf("failed to read protobuf: %w", err)
}
return nil
}
func read(ctx context.Context, c *websocket.Conn, v proto.Message) error {
typ, r, err := c.Reader(ctx)
if err != nil {
return err
}
if typ != websocket.MessageBinary {
c.Close(websocket.StatusUnsupportedData, "can only accept binary messages")
return fmt.Errorf("unexpected frame type for protobuf (expected %v): %v", websocket.MessageBinary, typ)
}
b := bpool.Get()
defer bpool.Put(b)
_, err = b.ReadFrom(r)
if err != nil {
return err
}
err = proto.Unmarshal(b.Bytes(), v)
if err != nil {
c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal protobuf")
return fmt.Errorf("failed to unmarshal protobuf: %w", err)
}
return nil
}
// Write writes the protobuf message v to c.
// It will reuse buffers to avoid allocations.
func Write(ctx context.Context, c *websocket.Conn, v proto.Message) error {
err := write(ctx, c, v)
if err != nil {
return fmt.Errorf("failed to write protobuf: %w", err)
}
return nil
}
func write(ctx context.Context, c *websocket.Conn, v proto.Message) error {
b := bpool.Get()
pb := proto.NewBuffer(b.Bytes())
defer func() {
bpool.Put(bytes.NewBuffer(pb.Bytes()))
}()
err := pb.Marshal(v)
if err != nil {
return fmt.Errorf("failed to marshal protobuf: %w", err)
}
return c.Write(ctx, websocket.MessageBinary, pb.Bytes())
}