good morning!!!!

Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • github/nhooyr/websocket
  • open/websocket
2 results
Show changes
......@@ -9,42 +9,44 @@ import (
"sync"
)
// CompressionMode represents the modes available to the deflate extension.
// CompressionMode represents the modes available to the permessage-deflate extension.
// See https://tools.ietf.org/html/rfc7692
//
// Works in all browsers except Safari which does not implement the deflate extension.
// Works in all modern browsers except Safari which does not implement the permessage-deflate extension.
//
// Compression is only used if the peer supports the mode selected.
type CompressionMode int
const (
// CompressionDisabled disables the deflate extension.
//
// Use this if you are using a predominantly binary protocol with very
// little duplication in between messages or CPU and memory are more
// important than bandwidth.
// CompressionDisabled disables the negotiation of the permessage-deflate extension.
//
// This is the default.
// This is the default. Do not enable compression without benchmarking for your particular use case first.
CompressionDisabled CompressionMode = iota
// CompressionContextTakeover uses a 32 kB sliding window and flate.Writer per connection.
// It reuses the sliding window from previous messages.
// As most WebSocket protocols are repetitive, this can be very efficient.
// It carries an overhead of 32 kB + 1.2 MB for every connection compared to CompressionNoContextTakeover.
// CompressionContextTakeover compresses each message greater than 128 bytes reusing the 32 KB sliding window from
// previous messages. i.e compression context across messages is preserved.
//
// As most WebSocket protocols are text based and repetitive, this compression mode can be very efficient.
//
// Sometime in the future it will carry 65 kB overhead instead once https://github.com/golang/go/issues/36919
// is fixed.
// The memory overhead is a fixed 32 KB sliding window, a fixed 1.2 MB flate.Writer and a sync.Pool of 40 KB flate.Reader's
// that are used when reading and then returned.
//
// If the peer negotiates NoContextTakeover on the client or server side, it will be
// used instead as this is required by the RFC.
// Thus, it uses more memory than CompressionNoContextTakeover but compresses more efficiently.
//
// If the peer does not support CompressionContextTakeover then we will fall back to CompressionNoContextTakeover.
CompressionContextTakeover
// CompressionNoContextTakeover grabs a new flate.Reader and flate.Writer as needed
// for every message. This applies to both server and client side.
// CompressionNoContextTakeover compresses each message greater than 512 bytes. Each message is compressed with
// a new 1.2 MB flate.Writer pulled from a sync.Pool. Each message is read with a 40 KB flate.Reader pulled from
// a sync.Pool.
//
// This means less efficient compression as the sliding window from previous messages will not be used but the
// memory overhead will be lower as there will be no fixed cost for the flate.Writer nor the 32 KB sliding window.
// Especially if the connections are long lived and seldom written to.
//
// This means less efficient compression as the sliding window from previous messages
// will not be used but the memory overhead will be lower if the connections
// are long lived and seldom used.
// Thus, it uses less memory than CompressionContextTakeover but compresses less efficiently.
//
// The message will only be compressed if greater than 512 bytes.
// If the peer does not support CompressionNoContextTakeover then we will fall back to CompressionDisabled.
CompressionNoContextTakeover
)
......
......@@ -4,11 +4,14 @@
package websocket
import (
"bytes"
"compress/flate"
"io"
"strings"
"testing"
"nhooyr.io/websocket/internal/test/assert"
"nhooyr.io/websocket/internal/test/xrand"
"github.com/coder/websocket/internal/test/assert"
"github.com/coder/websocket/internal/test/xrand"
)
func Test_slidingWindow(t *testing.T) {
......@@ -33,3 +36,27 @@ func Test_slidingWindow(t *testing.T) {
})
}
}
func BenchmarkFlateWriter(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
w, _ := flate.NewWriter(io.Discard, flate.BestSpeed)
// We have to write a byte to get the writer to allocate to its full extent.
w.Write([]byte{'a'})
w.Flush()
}
}
func BenchmarkFlateReader(b *testing.B) {
b.ReportAllocs()
var buf bytes.Buffer
w, _ := flate.NewWriter(&buf, flate.BestSpeed)
w.Write([]byte{'a'})
w.Flush()
for i := 0; i < b.N; i++ {
r := flate.NewReader(bytes.NewReader(buf.Bytes()))
io.ReadAll(r)
}
}
......@@ -6,9 +6,9 @@ package websocket
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"net"
"runtime"
"strconv"
"sync"
......@@ -42,7 +42,7 @@ const (
// This applies to context expirations as well unfortunately.
// See https://github.com/nhooyr/websocket/issues/242#issuecomment-633182220
type Conn struct {
noCopy
noCopy noCopy
subprotocol string
rwc io.ReadWriteCloser
......@@ -52,31 +52,42 @@ type Conn struct {
br *bufio.Reader
bw *bufio.Writer
readTimeout chan context.Context
writeTimeout chan context.Context
readTimeout chan context.Context
writeTimeout chan context.Context
timeoutLoopDone chan struct{}
// Read state.
readMu *mu
readHeaderBuf [8]byte
readControlBuf [maxControlPayload]byte
msgReader *msgReader
readCloseFrameErr error
readMu *mu
readHeaderBuf [8]byte
readControlBuf [maxControlPayload]byte
msgReader *msgReader
// Write state.
msgWriterState *msgWriterState
msgWriter *msgWriter
writeFrameMu *mu
writeBuf []byte
writeHeaderBuf [8]byte
writeHeader header
closed chan struct{}
closeMu sync.Mutex
closeErr error
wroteClose bool
pingCounter int32
activePingsMu sync.Mutex
activePings map[string]chan<- struct{}
// Close handshake state.
closeStateMu sync.RWMutex
closeReceivedErr error
closeSentErr error
// CloseRead state.
closeReadMu sync.Mutex
closeReadCtx context.Context
closeReadDone chan struct{}
closing atomic.Bool
closeMu sync.Mutex // Protects following.
closed chan struct{}
pingCounter atomic.Int64
activePingsMu sync.Mutex
activePings map[string]chan<- struct{}
onPingReceived func(context.Context, []byte) bool
onPongReceived func(context.Context, []byte)
}
type connConfig struct {
......@@ -85,6 +96,8 @@ type connConfig struct {
client bool
copts *compressionOptions
flateThreshold int
onPingReceived func(context.Context, []byte) bool
onPongReceived func(context.Context, []byte)
br *bufio.Reader
bw *bufio.Writer
......@@ -101,11 +114,14 @@ func newConn(cfg connConfig) *Conn {
br: cfg.br,
bw: cfg.bw,
readTimeout: make(chan context.Context),
writeTimeout: make(chan context.Context),
readTimeout: make(chan context.Context),
writeTimeout: make(chan context.Context),
timeoutLoopDone: make(chan struct{}),
closed: make(chan struct{}),
activePings: make(map[string]chan<- struct{}),
closed: make(chan struct{}),
activePings: make(map[string]chan<- struct{}),
onPingReceived: cfg.onPingReceived,
onPongReceived: cfg.onPongReceived,
}
c.readMu = newMu(c)
......@@ -113,20 +129,20 @@ func newConn(cfg connConfig) *Conn {
c.msgReader = newMsgReader(c)
c.msgWriterState = newMsgWriterState(c)
c.msgWriter = newMsgWriter(c)
if c.client {
c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc)
}
if c.flate() && c.flateThreshold == 0 {
c.flateThreshold = 128
if !c.msgWriterState.flateContextTakeover() {
if !c.msgWriter.flateContextTakeover() {
c.flateThreshold = 512
}
}
runtime.SetFinalizer(c, func(c *Conn) {
c.close(errors.New("connection garbage collected"))
c.close()
})
go c.timeoutLoop()
......@@ -140,30 +156,29 @@ func (c *Conn) Subprotocol() string {
return c.subprotocol
}
func (c *Conn) close(err error) {
func (c *Conn) close() error {
c.closeMu.Lock()
defer c.closeMu.Unlock()
if c.isClosed() {
return
return net.ErrClosed
}
c.setCloseErrLocked(err)
close(c.closed)
runtime.SetFinalizer(c, nil)
close(c.closed)
// Have to close after c.closed is closed to ensure any goroutine that wakes up
// from the connection being closed also sees that c.closed is closed and returns
// closeErr.
c.rwc.Close()
go func() {
c.msgWriterState.close()
c.msgReader.close()
}()
err := c.rwc.Close()
// With the close of rwc, these become safe to close.
c.msgWriter.close()
c.msgReader.close()
return err
}
func (c *Conn) timeoutLoop() {
defer close(c.timeoutLoopDone)
readCtx := context.Background()
writeCtx := context.Background()
......@@ -176,10 +191,10 @@ func (c *Conn) timeoutLoop() {
case readCtx = <-c.readTimeout:
case <-readCtx.Done():
c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err()))
go c.writeError(StatusPolicyViolation, errors.New("timed out"))
c.close()
return
case <-writeCtx.Done():
c.close(fmt.Errorf("write timed out: %w", writeCtx.Err()))
c.close()
return
}
}
......@@ -197,9 +212,9 @@ func (c *Conn) flate() bool {
//
// TCP Keepalives should suffice for most use cases.
func (c *Conn) Ping(ctx context.Context) error {
p := atomic.AddInt32(&c.pingCounter, 1)
p := c.pingCounter.Add(1)
err := c.ping(ctx, strconv.Itoa(int(p)))
err := c.ping(ctx, strconv.FormatInt(p, 10))
if err != nil {
return fmt.Errorf("failed to ping: %w", err)
}
......@@ -226,11 +241,9 @@ func (c *Conn) ping(ctx context.Context, p string) error {
select {
case <-c.closed:
return c.closeErr
return net.ErrClosed
case <-ctx.Done():
err := fmt.Errorf("failed to wait for pong: %w", ctx.Err())
c.close(err)
return err
return fmt.Errorf("failed to wait for pong: %w", ctx.Err())
case <-pong:
return nil
}
......@@ -264,11 +277,9 @@ func (m *mu) tryLock() bool {
func (m *mu) lock(ctx context.Context) error {
select {
case <-m.c.closed:
return m.c.closeErr
return net.ErrClosed
case <-ctx.Done():
err := fmt.Errorf("failed to acquire lock: %w", ctx.Err())
m.c.close(err)
return err
return fmt.Errorf("failed to acquire lock: %w", ctx.Err())
case m.ch <- struct{}{}:
// To make sure the connection is certainly alive.
// As it's possible the send on m.ch was selected
......@@ -277,7 +288,7 @@ func (m *mu) lock(ctx context.Context) error {
case <-m.c.closed:
// Make sure to release.
m.unlock()
return m.c.closeErr
return net.ErrClosed
default:
}
return nil
......
......@@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"io"
"net"
"net/http"
"net/http/httptest"
"os"
......@@ -16,13 +17,13 @@ import (
"testing"
"time"
"nhooyr.io/websocket"
"nhooyr.io/websocket/internal/errd"
"nhooyr.io/websocket/internal/test/assert"
"nhooyr.io/websocket/internal/test/wstest"
"nhooyr.io/websocket/internal/test/xrand"
"nhooyr.io/websocket/internal/xsync"
"nhooyr.io/websocket/wsjson"
"github.com/coder/websocket"
"github.com/coder/websocket/internal/errd"
"github.com/coder/websocket/internal/test/assert"
"github.com/coder/websocket/internal/test/wstest"
"github.com/coder/websocket/internal/test/xrand"
"github.com/coder/websocket/internal/xsync"
"github.com/coder/websocket/wsjson"
)
func TestConn(t *testing.T) {
......@@ -96,6 +97,85 @@ func TestConn(t *testing.T) {
assert.Contains(t, err, "failed to wait for pong")
})
t.Run("pingReceivedPongReceived", func(t *testing.T) {
var pingReceived1, pongReceived1 bool
var pingReceived2, pongReceived2 bool
tt, c1, c2 := newConnTest(t,
&websocket.DialOptions{
OnPingReceived: func(ctx context.Context, payload []byte) bool {
pingReceived1 = true
return true
},
OnPongReceived: func(ctx context.Context, payload []byte) {
pongReceived1 = true
},
}, &websocket.AcceptOptions{
OnPingReceived: func(ctx context.Context, payload []byte) bool {
pingReceived2 = true
return true
},
OnPongReceived: func(ctx context.Context, payload []byte) {
pongReceived2 = true
},
},
)
c1.CloseRead(tt.ctx)
c2.CloseRead(tt.ctx)
ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100)
defer cancel()
err := c1.Ping(ctx)
assert.Success(t, err)
c1.CloseNow()
c2.CloseNow()
assert.Equal(t, "only one side receives the ping", false, pingReceived1 && pingReceived2)
assert.Equal(t, "only one side receives the pong", false, pongReceived1 && pongReceived2)
assert.Equal(t, "ping and pong received", true, (pingReceived1 && pongReceived2) || (pingReceived2 && pongReceived1))
})
t.Run("pingReceivedPongNotReceived", func(t *testing.T) {
var pingReceived1, pongReceived1 bool
var pingReceived2, pongReceived2 bool
tt, c1, c2 := newConnTest(t,
&websocket.DialOptions{
OnPingReceived: func(ctx context.Context, payload []byte) bool {
pingReceived1 = true
return false
},
OnPongReceived: func(ctx context.Context, payload []byte) {
pongReceived1 = true
},
}, &websocket.AcceptOptions{
OnPingReceived: func(ctx context.Context, payload []byte) bool {
pingReceived2 = true
return false
},
OnPongReceived: func(ctx context.Context, payload []byte) {
pongReceived2 = true
},
},
)
c1.CloseRead(tt.ctx)
c2.CloseRead(tt.ctx)
ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100)
defer cancel()
err := c1.Ping(ctx)
assert.Contains(t, err, "failed to wait for pong")
c1.CloseNow()
c2.CloseNow()
assert.Equal(t, "only one side receives the ping", false, pingReceived1 && pingReceived2)
assert.Equal(t, "ping received and pong not received", true, (pingReceived1 && !pongReceived2) || (pingReceived2 && !pongReceived1))
})
t.Run("concurrentWrite", func(t *testing.T) {
tt, c1, c2 := newConnTest(t, nil, nil)
......@@ -236,6 +316,18 @@ func TestConn(t *testing.T) {
assert.Equal(t, "read msg", s, string(b))
})
t.Run("netConn/pastDeadline", func(t *testing.T) {
tt, c1, c2 := newConnTest(t, nil, nil)
n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary)
n1.SetDeadline(time.Now().Add(-time.Minute))
n2.SetDeadline(time.Now().Add(-time.Minute))
// No panic we're good.
})
t.Run("wsjson", func(t *testing.T) {
tt, c1, c2 := newConnTest(t, nil, nil)
......@@ -295,10 +387,47 @@ func TestConn(t *testing.T) {
err = c1.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)
})
t.Run("CloseNow", func(t *testing.T) {
_, c1, c2 := newConnTest(t, nil, nil)
err1 := c1.CloseNow()
err2 := c2.CloseNow()
assert.Success(t, err1)
assert.Success(t, err2)
err1 = c1.CloseNow()
err2 = c2.CloseNow()
assert.ErrorIs(t, websocket.ErrClosed, err1)
assert.ErrorIs(t, websocket.ErrClosed, err2)
})
t.Run("MidReadClose", func(t *testing.T) {
tt, c1, c2 := newConnTest(t, nil, nil)
tt.goEchoLoop(c2)
c1.SetReadLimit(131072)
for i := 0; i < 5; i++ {
err := wstest.Echo(tt.ctx, c1, 131072)
assert.Success(t, err)
}
err := wsjson.Write(tt.ctx, c1, "four")
assert.Success(t, err)
_, _, err = c1.Reader(tt.ctx)
assert.Success(t, err)
err = c1.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)
})
}
func TestWasm(t *testing.T) {
t.Parallel()
if os.Getenv("CI") == "" {
t.SkipNow()
}
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := echoServer(w, r, &websocket.AcceptOptions{
......@@ -314,8 +443,8 @@ func TestWasm(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", ".")
cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", s.URL))
cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", ".", "-v")
cmd.Env = append(cleanEnv(os.Environ()), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", s.URL))
b, err := cmd.CombinedOutput()
if err != nil {
......@@ -323,6 +452,18 @@ func TestWasm(t *testing.T) {
}
}
func cleanEnv(env []string) (out []string) {
for _, e := range env {
// Filter out GITHUB envs and anything with token in it,
// especially GITHUB_TOKEN in CI as it breaks TestWasm.
if strings.HasPrefix(e, "GITHUB") || strings.Contains(e, "TOKEN") {
continue
}
out = append(out, e)
}
return out
}
func assertCloseStatus(exp websocket.StatusCode, err error) error {
if websocket.CloseStatus(err) == -1 {
return fmt.Errorf("expected websocket.CloseError: %T %v", err, err)
......@@ -353,10 +494,8 @@ func newConnTest(t testing.TB, dialOpts *websocket.DialOptions, acceptOpts *webs
c1, c2 = c2, c1
}
t.Cleanup(func() {
// We don't actually care whether this succeeds so we just run it in a separate goroutine to avoid
// blocking the test shutting down.
go c2.Close(websocket.StatusInternalError, "")
go c1.Close(websocket.StatusInternalError, "")
c2.CloseNow()
c1.CloseNow()
})
return tt, c1, c2
......@@ -401,7 +540,7 @@ func (tt *connTest) goDiscardLoop(c *websocket.Conn) {
}
func BenchmarkConn(b *testing.B) {
var benchCases = []struct {
benchCases := []struct {
name string
mode websocket.CompressionMode
}{
......@@ -513,3 +652,202 @@ func echoServer(w http.ResponseWriter, r *http.Request, opts *websocket.AcceptOp
err = wstest.EchoLoop(r.Context(), c)
return assertCloseStatus(websocket.StatusNormalClosure, err)
}
func assertEcho(tb testing.TB, ctx context.Context, c *websocket.Conn) {
exp := xrand.String(xrand.Int(131072))
werr := xsync.Go(func() error {
return wsjson.Write(ctx, c, exp)
})
var act interface{}
c.SetReadLimit(1 << 30)
err := wsjson.Read(ctx, c, &act)
assert.Success(tb, err)
assert.Equal(tb, "read msg", exp, act)
select {
case err := <-werr:
assert.Success(tb, err)
case <-ctx.Done():
tb.Fatal(ctx.Err())
}
}
func assertClose(tb testing.TB, c *websocket.Conn) {
tb.Helper()
err := c.Close(websocket.StatusNormalClosure, "")
assert.Success(tb, err)
}
func TestConcurrentClosePing(t *testing.T) {
t.Parallel()
for i := 0; i < 64; i++ {
func() {
c1, c2 := wstest.Pipe(nil, nil)
defer c1.CloseNow()
defer c2.CloseNow()
c1.CloseRead(context.Background())
c2.CloseRead(context.Background())
errc := xsync.Go(func() error {
for range time.Tick(time.Millisecond) {
err := c1.Ping(context.Background())
if err != nil {
return err
}
}
panic("unreachable")
})
time.Sleep(10 * time.Millisecond)
assert.Success(t, c1.Close(websocket.StatusNormalClosure, ""))
<-errc
}()
}
}
func TestConnClosePropagation(t *testing.T) {
t.Parallel()
want := []byte("hello")
keepWriting := func(c *websocket.Conn) <-chan error {
return xsync.Go(func() error {
for {
err := c.Write(context.Background(), websocket.MessageText, want)
if err != nil {
return err
}
}
})
}
keepReading := func(c *websocket.Conn) <-chan error {
return xsync.Go(func() error {
for {
_, got, err := c.Read(context.Background())
if err != nil {
return err
}
if !bytes.Equal(want, got) {
return fmt.Errorf("unexpected message: want %q, got %q", want, got)
}
}
})
}
checkReadErr := func(t *testing.T, err error) {
// Check read error (output depends on when read is called in relation to connection closure).
var ce websocket.CloseError
if errors.As(err, &ce) {
assert.Equal(t, "", websocket.StatusNormalClosure, ce.Code)
} else {
assert.ErrorIs(t, net.ErrClosed, err)
}
}
checkConnErrs := func(t *testing.T, conn ...*websocket.Conn) {
for _, c := range conn {
// Check write error.
err := c.Write(context.Background(), websocket.MessageText, want)
assert.ErrorIs(t, net.ErrClosed, err)
_, _, err = c.Read(context.Background())
checkReadErr(t, err)
}
}
t.Run("CloseOtherSideDuringWrite", func(t *testing.T) {
tt, this, other := newConnTest(t, nil, nil)
_ = this.CloseRead(tt.ctx)
thisWriteErr := keepWriting(this)
_, got, err := other.Read(tt.ctx)
assert.Success(t, err)
assert.Equal(t, "msg", want, got)
err = other.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)
select {
case err := <-thisWriteErr:
assert.ErrorIs(t, net.ErrClosed, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
checkConnErrs(t, this, other)
})
t.Run("CloseThisSideDuringWrite", func(t *testing.T) {
tt, this, other := newConnTest(t, nil, nil)
_ = this.CloseRead(tt.ctx)
thisWriteErr := keepWriting(this)
otherReadErr := keepReading(other)
err := this.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)
select {
case err := <-thisWriteErr:
assert.ErrorIs(t, net.ErrClosed, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
select {
case err := <-otherReadErr:
checkReadErr(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
checkConnErrs(t, this, other)
})
t.Run("CloseOtherSideDuringRead", func(t *testing.T) {
tt, this, other := newConnTest(t, nil, nil)
_ = other.CloseRead(tt.ctx)
errs := keepReading(this)
err := other.Write(tt.ctx, websocket.MessageText, want)
assert.Success(t, err)
err = other.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)
select {
case err := <-errs:
checkReadErr(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
checkConnErrs(t, this, other)
})
t.Run("CloseThisSideDuringRead", func(t *testing.T) {
tt, this, other := newConnTest(t, nil, nil)
thisReadErr := keepReading(this)
otherReadErr := keepReading(other)
err := other.Write(tt.ctx, websocket.MessageText, want)
assert.Success(t, err)
err = this.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)
select {
case err := <-thisReadErr:
checkReadErr(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
select {
case err := <-otherReadErr:
checkReadErr(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
checkConnErrs(t, this, other)
})
}
......@@ -17,7 +17,7 @@ import (
"sync"
"time"
"nhooyr.io/websocket/internal/errd"
"github.com/coder/websocket/internal/errd"
)
// DialOptions represents Dial's options.
......@@ -48,6 +48,22 @@ type DialOptions struct {
// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
// for CompressionContextTakeover.
CompressionThreshold int
// OnPingReceived is an optional callback invoked synchronously when a ping frame is received.
//
// The payload contains the application data of the ping frame.
// If the callback returns false, the subsequent pong frame will not be sent.
// To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
OnPingReceived func(ctx context.Context, payload []byte) bool
// OnPongReceived is an optional callback invoked synchronously when a pong frame is received.
//
// The payload contains the application data of the pong frame.
// To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
//
// Unlike OnPingReceived, this callback does not return a value because a pong frame
// is a response to a ping and does not trigger any further frame transmission.
OnPongReceived func(ctx context.Context, payload []byte)
}
func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) {
......@@ -70,6 +86,21 @@ func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context
if o.HTTPHeader == nil {
o.HTTPHeader = http.Header{}
}
newClient := *o.HTTPClient
oldCheckRedirect := o.HTTPClient.CheckRedirect
newClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
switch req.URL.Scheme {
case "ws":
req.URL.Scheme = "http"
case "wss":
req.URL.Scheme = "https"
}
if oldCheckRedirect != nil {
return oldCheckRedirect(req, via)
}
return nil
}
o.HTTPClient = &newClient
return ctx, cancel, &o
}
......@@ -148,6 +179,8 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
client: true,
copts: copts,
flateThreshold: opts.CompressionThreshold,
onPingReceived: opts.OnPingReceived,
onPongReceived: opts.OnPongReceived,
br: getBufioReader(rwc),
bw: getBufioWriter(rwc),
}), resp, nil
......
//go:build !js
// +build !js
package websocket
package websocket_test
import (
"bytes"
......@@ -10,12 +10,15 @@ import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"nhooyr.io/websocket/internal/test/assert"
"nhooyr.io/websocket/internal/util"
"github.com/coder/websocket"
"github.com/coder/websocket/internal/test/assert"
"github.com/coder/websocket/internal/util"
"github.com/coder/websocket/internal/xsync"
)
func TestBadDials(t *testing.T) {
......@@ -27,7 +30,7 @@ func TestBadDials(t *testing.T) {
testCases := []struct {
name string
url string
opts *DialOptions
opts *websocket.DialOptions
rand util.ReaderFunc
nilCtx bool
}{
......@@ -72,7 +75,7 @@ func TestBadDials(t *testing.T) {
tc.rand = rand.Reader.Read
}
_, _, err := dial(ctx, tc.url, tc.opts, tc.rand)
_, _, err := websocket.ExportedDial(ctx, tc.url, tc.opts, tc.rand)
assert.Error(t, err)
})
}
......@@ -84,7 +87,7 @@ func TestBadDials(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
_, _, err := Dial(ctx, "ws://example.com", &DialOptions{
_, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
HTTPClient: mockHTTPClient(func(*http.Request) (*http.Response, error) {
return &http.Response{
Body: io.NopCloser(strings.NewReader("hi")),
......@@ -104,7 +107,7 @@ func TestBadDials(t *testing.T) {
h := http.Header{}
h.Set("Connection", "Upgrade")
h.Set("Upgrade", "websocket")
h.Set("Sec-WebSocket-Accept", secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))
h.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))
return &http.Response{
StatusCode: http.StatusSwitchingProtocols,
......@@ -113,7 +116,7 @@ func TestBadDials(t *testing.T) {
}, nil
}
_, _, err := Dial(ctx, "ws://example.com", &DialOptions{
_, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
HTTPClient: mockHTTPClient(rt),
})
assert.Contains(t, err, "response body is not a io.ReadWriteCloser")
......@@ -152,7 +155,7 @@ func Test_verifyHostOverride(t *testing.T) {
h := http.Header{}
h.Set("Connection", "Upgrade")
h.Set("Upgrade", "websocket")
h.Set("Sec-WebSocket-Accept", secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))
h.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))
return &http.Response{
StatusCode: http.StatusSwitchingProtocols,
......@@ -161,11 +164,12 @@ func Test_verifyHostOverride(t *testing.T) {
}, nil
}
_, _, err := Dial(ctx, "ws://example.com", &DialOptions{
c, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
HTTPClient: mockHTTPClient(rt),
Host: tc.host,
})
assert.Success(t, err)
c.CloseNow()
})
}
......@@ -272,18 +276,18 @@ func Test_verifyServerHandshake(t *testing.T) {
resp := w.Result()
r := httptest.NewRequest("GET", "/", nil)
key, err := secWebSocketKey(rand.Reader)
key, err := websocket.SecWebSocketKey(rand.Reader)
assert.Success(t, err)
r.Header.Set("Sec-WebSocket-Key", key)
if resp.Header.Get("Sec-WebSocket-Accept") == "" {
resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
resp.Header.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(key))
}
opts := &DialOptions{
opts := &websocket.DialOptions{
Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","),
}
_, err = verifyServerResponse(opts, opts.CompressionMode.opts(), key, resp)
_, err = websocket.VerifyServerResponse(opts, websocket.CompressionModeOpts(opts.CompressionMode), key, resp)
if tc.success {
assert.Success(t, err)
} else {
......@@ -304,3 +308,113 @@ type roundTripperFunc func(*http.Request) (*http.Response, error)
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return f(r)
}
func TestDialRedirect(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
_, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
HTTPClient: mockHTTPClient(func(r *http.Request) (*http.Response, error) {
resp := &http.Response{
Header: http.Header{},
}
if r.URL.Scheme != "https" {
resp.Header.Set("Location", "wss://example.com")
resp.StatusCode = http.StatusFound
return resp, nil
}
resp.Header.Set("Connection", "Upgrade")
resp.Header.Set("Upgrade", "meow")
resp.StatusCode = http.StatusSwitchingProtocols
return resp, nil
}),
})
assert.Contains(t, err, "failed to WebSocket dial: WebSocket protocol violation: Upgrade header \"meow\" does not contain websocket")
}
type forwardProxy struct {
hc *http.Client
}
func newForwardProxy() *forwardProxy {
return &forwardProxy{
hc: &http.Client{},
}
}
func (fc *forwardProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), time.Second*10)
defer cancel()
r = r.WithContext(ctx)
r.RequestURI = ""
resp, err := fc.hc.Do(r)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
defer resp.Body.Close()
for k, v := range resp.Header {
w.Header()[k] = v
}
w.Header().Set("PROXIED", "true")
w.WriteHeader(resp.StatusCode)
if resprw, ok := resp.Body.(io.ReadWriter); ok {
c, brw, err := w.(http.Hijacker).Hijack()
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
brw.Flush()
errc1 := xsync.Go(func() error {
_, err := io.Copy(c, resprw)
return err
})
errc2 := xsync.Go(func() error {
_, err := io.Copy(resprw, c)
return err
})
select {
case <-errc1:
case <-errc2:
case <-r.Context().Done():
}
} else {
io.Copy(w, resp.Body)
}
}
func TestDialViaProxy(t *testing.T) {
t.Parallel()
ps := httptest.NewServer(newForwardProxy())
defer ps.Close()
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := echoServer(w, r, nil)
assert.Success(t, err)
}))
defer s.Close()
psu, err := url.Parse(ps.URL)
assert.Success(t, err)
proxyTransport := http.DefaultTransport.(*http.Transport).Clone()
proxyTransport.Proxy = http.ProxyURL(psu)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
c, resp, err := websocket.Dial(ctx, s.URL, &websocket.DialOptions{
HTTPClient: &http.Client{
Transport: proxyTransport,
},
})
assert.Success(t, err)
assert.Equal(t, "", "true", resp.Header.Get("PROXIED"))
assertEcho(t, ctx, c)
assertClose(t, c)
}
......@@ -13,9 +13,9 @@
//
// The examples are the best way to understand how to correctly use the library.
//
// The wsjson and wspb subpackages contain helpers for JSON and protobuf messages.
// The wsjson subpackage contain helpers for JSON and protobuf messages.
//
// More documentation at https://nhooyr.io/websocket.
// More documentation at https://github.com/coder/websocket.
//
// # Wasm
//
......@@ -28,6 +28,7 @@
//
// - Accept always errors out
// - Conn.Ping is no-op
// - Conn.CloseNow is Close(StatusGoingAway, "")
// - HTTPClient, HTTPHeader and CompressionMode in DialOptions are no-op
// - *http.Response from Dial is &http.Response{} with a 101 status code on success
package websocket // import "nhooyr.io/websocket"
package websocket // import "github.com/coder/websocket"
......@@ -6,8 +6,8 @@ import (
"net/http"
"time"
"nhooyr.io/websocket"
"nhooyr.io/websocket/wsjson"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
)
func ExampleAccept() {
......@@ -20,7 +20,7 @@ func ExampleAccept() {
log.Println(err)
return
}
defer c.Close(websocket.StatusInternalError, "the sky is falling")
defer c.CloseNow()
ctx, cancel := context.WithTimeout(r.Context(), time.Second*10)
defer cancel()
......@@ -50,7 +50,7 @@ func ExampleDial() {
if err != nil {
log.Fatal(err)
}
defer c.Close(websocket.StatusInternalError, "the sky is falling")
defer c.CloseNow()
err = wsjson.Write(ctx, c, "hi")
if err != nil {
......@@ -71,7 +71,7 @@ func ExampleCloseStatus() {
if err != nil {
log.Fatal(err)
}
defer c.Close(websocket.StatusInternalError, "the sky is falling")
defer c.CloseNow()
_, _, err = c.Reader(ctx)
if websocket.CloseStatus(err) != websocket.StatusNormalClosure {
......@@ -88,7 +88,7 @@ func Example_writeOnly() {
log.Println(err)
return
}
defer c.Close(websocket.StatusInternalError, "the sky is falling")
defer c.CloseNow()
ctx, cancel := context.WithTimeout(r.Context(), time.Minute*10)
defer cancel()
......@@ -145,7 +145,7 @@ func ExampleConn_Ping() {
if err != nil {
log.Fatal(err)
}
defer c.Close(websocket.StatusInternalError, "the sky is falling")
defer c.CloseNow()
// Required to read the Pongs from the server.
ctx = c.CloseRead(ctx)
......@@ -162,10 +162,10 @@ func ExampleConn_Ping() {
// This example demonstrates full stack chat with an automated test.
func Example_fullStackChat() {
// https://github.com/nhooyr/websocket/tree/master/examples/chat
// https://github.com/nhooyr/websocket/tree/master/internal/examples/chat
}
// This example demonstrates a echo server.
func Example_echo() {
// https://github.com/nhooyr/websocket/tree/master/examples/echo
// https://github.com/nhooyr/websocket/tree/master/internal/examples/echo
}
......@@ -3,7 +3,11 @@
package websocket
import "nhooyr.io/websocket/internal/util"
import (
"net"
"github.com/coder/websocket/internal/util"
)
func (c *Conn) RecordBytesWritten() *int {
var bytesWritten int
......@@ -23,3 +27,12 @@ func (c *Conn) RecordBytesRead() *int {
}))
return &bytesRead
}
var ErrClosed = net.ErrClosed
var ExportedDial = dial
var SecWebSocketAccept = secWebSocketAccept
var SecWebSocketKey = secWebSocketKey
var VerifyServerResponse = verifyServerResponse
var CompressionModeOpts = CompressionMode.opts
......@@ -8,9 +8,8 @@ import (
"fmt"
"io"
"math"
"math/bits"
"nhooyr.io/websocket/internal/errd"
"github.com/coder/websocket/internal/errd"
)
// opcode represents a WebSocket opcode.
......@@ -172,125 +171,3 @@ func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) {
return nil
}
// mask applies the WebSocket masking algorithm to p
// with the given key.
// See https://tools.ietf.org/html/rfc6455#section-5.3
//
// The returned value is the correctly rotated key to
// to continue to mask/unmask the message.
//
// It is optimized for LittleEndian and expects the key
// to be in little endian.
//
// See https://github.com/golang/go/issues/31586
func mask(key uint32, b []byte) uint32 {
if len(b) >= 8 {
key64 := uint64(key)<<32 | uint64(key)
// At some point in the future we can clean these unrolled loops up.
// See https://github.com/golang/go/issues/31586#issuecomment-487436401
// Then we xor until b is less than 128 bytes.
for len(b) >= 128 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
v = binary.LittleEndian.Uint64(b[32:40])
binary.LittleEndian.PutUint64(b[32:40], v^key64)
v = binary.LittleEndian.Uint64(b[40:48])
binary.LittleEndian.PutUint64(b[40:48], v^key64)
v = binary.LittleEndian.Uint64(b[48:56])
binary.LittleEndian.PutUint64(b[48:56], v^key64)
v = binary.LittleEndian.Uint64(b[56:64])
binary.LittleEndian.PutUint64(b[56:64], v^key64)
v = binary.LittleEndian.Uint64(b[64:72])
binary.LittleEndian.PutUint64(b[64:72], v^key64)
v = binary.LittleEndian.Uint64(b[72:80])
binary.LittleEndian.PutUint64(b[72:80], v^key64)
v = binary.LittleEndian.Uint64(b[80:88])
binary.LittleEndian.PutUint64(b[80:88], v^key64)
v = binary.LittleEndian.Uint64(b[88:96])
binary.LittleEndian.PutUint64(b[88:96], v^key64)
v = binary.LittleEndian.Uint64(b[96:104])
binary.LittleEndian.PutUint64(b[96:104], v^key64)
v = binary.LittleEndian.Uint64(b[104:112])
binary.LittleEndian.PutUint64(b[104:112], v^key64)
v = binary.LittleEndian.Uint64(b[112:120])
binary.LittleEndian.PutUint64(b[112:120], v^key64)
v = binary.LittleEndian.Uint64(b[120:128])
binary.LittleEndian.PutUint64(b[120:128], v^key64)
b = b[128:]
}
// Then we xor until b is less than 64 bytes.
for len(b) >= 64 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
v = binary.LittleEndian.Uint64(b[32:40])
binary.LittleEndian.PutUint64(b[32:40], v^key64)
v = binary.LittleEndian.Uint64(b[40:48])
binary.LittleEndian.PutUint64(b[40:48], v^key64)
v = binary.LittleEndian.Uint64(b[48:56])
binary.LittleEndian.PutUint64(b[48:56], v^key64)
v = binary.LittleEndian.Uint64(b[56:64])
binary.LittleEndian.PutUint64(b[56:64], v^key64)
b = b[64:]
}
// Then we xor until b is less than 32 bytes.
for len(b) >= 32 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
b = b[32:]
}
// Then we xor until b is less than 16 bytes.
for len(b) >= 16 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
b = b[16:]
}
// Then we xor until b is less than 8 bytes.
for len(b) >= 8 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
b = b[8:]
}
}
// Then we xor until b is less than 4 bytes.
for len(b) >= 4 {
v := binary.LittleEndian.Uint32(b)
binary.LittleEndian.PutUint32(b, v^key)
b = b[4:]
}
// xor remaining bytes.
for i := range b {
b[i] ^= byte(key)
key = bits.RotateLeft32(key, -8)
}
return key
}
......@@ -13,7 +13,7 @@ import (
"testing"
"time"
"nhooyr.io/websocket/internal/test/assert"
"github.com/coder/websocket/internal/test/assert"
)
func TestHeader(t *testing.T) {
......@@ -97,7 +97,7 @@ func Test_mask(t *testing.T) {
key := []byte{0xa, 0xb, 0xc, 0xff}
key32 := binary.LittleEndian.Uint32(key)
p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc}
gotKey32 := mask(key32, p)
gotKey32 := mask(p, key32)
expP := []byte{0, 0, 0, 0x0d, 0x6}
assert.Equal(t, "p", expP, p)
......
module nhooyr.io/websocket
module github.com/coder/websocket
go 1.19
go 1.23
//go:build !js
package websocket
import (
"net/http"
)
type rwUnwrapper interface {
Unwrap() http.ResponseWriter
}
// hijacker returns the Hijacker interface of the http.ResponseWriter.
// It follows the Unwrap method of the http.ResponseWriter if available,
// matching the behavior of http.ResponseController. If the Hijacker
// interface is not found, it returns false.
//
// Since the http.ResponseController is not available in Go 1.19, and
// does not support checking the presence of the Hijacker interface,
// this function is used to provide a consistent way to check for the
// Hijacker interface across Go versions.
func hijacker(rw http.ResponseWriter) (http.Hijacker, bool) {
for {
switch t := rw.(type) {
case http.Hijacker:
return t, true
case rwUnwrapper:
rw = t.Unwrap()
default:
return nil, false
}
}
}
//go:build !js && go1.20
package websocket
import (
"bufio"
"errors"
"net"
"net/http"
"net/http/httptest"
"testing"
"github.com/coder/websocket/internal/test/assert"
)
func Test_hijackerHTTPResponseControllerCompatibility(t *testing.T) {
t.Parallel()
rr := httptest.NewRecorder()
w := mockUnwrapper{
ResponseWriter: rr,
unwrap: func() http.ResponseWriter {
return mockHijacker{
ResponseWriter: rr,
hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) {
return nil, nil, errors.New("haha")
},
}
},
}
_, _, err := http.NewResponseController(w).Hijack()
assert.Contains(t, err, "haha")
hj, ok := hijacker(w)
assert.Equal(t, "hijacker found", ok, true)
_, _, err = hj.Hijack()
assert.Contains(t, err, "haha")
}
......@@ -5,15 +5,16 @@ import (
"sync"
)
var bpool sync.Pool
var bpool = sync.Pool{
New: func() any {
return &bytes.Buffer{}
},
}
// Get returns a buffer from the pool or creates a new one if
// the pool is empty.
func Get() *bytes.Buffer {
b := bpool.Get()
if b == nil {
return &bytes.Buffer{}
}
return b.(*bytes.Buffer)
}
......
# Chat Example
This directory contains a full stack example of a simple chat webapp using nhooyr.io/websocket.
This directory contains a full stack example of a simple chat webapp using github.com/coder/websocket.
```bash
$ cd examples/chat
$ go run . localhost:0
listening on http://127.0.0.1:51055
listening on ws://127.0.0.1:51055
```
Visit the printed URL to submit and view broadcasted messages in a browser.
......
......@@ -5,13 +5,14 @@ import (
"errors"
"io"
"log"
"net"
"net/http"
"sync"
"time"
"golang.org/x/time/rate"
"nhooyr.io/websocket"
"github.com/coder/websocket"
)
// chatServer enables broadcasting to a set of subscribers.
......@@ -69,14 +70,7 @@ func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// subscribeHandler accepts the WebSocket connection and then subscribes
// it to all future messages.
func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) {
c, err := websocket.Accept(w, r, nil)
if err != nil {
cs.logf("%v", err)
return
}
defer c.Close(websocket.StatusInternalError, "")
err = cs.subscribe(r.Context(), c)
err := cs.subscribe(w, r)
if errors.Is(err, context.Canceled) {
return
}
......@@ -117,18 +111,39 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) {
//
// It uses CloseRead to keep reading from the connection to process control
// messages and cancel the context if the connection drops.
func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error {
ctx = c.CloseRead(ctx)
func (cs *chatServer) subscribe(w http.ResponseWriter, r *http.Request) error {
var mu sync.Mutex
var c *websocket.Conn
var closed bool
s := &subscriber{
msgs: make(chan []byte, cs.subscriberMessageBuffer),
closeSlow: func() {
c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages")
mu.Lock()
defer mu.Unlock()
closed = true
if c != nil {
c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages")
}
},
}
cs.addSubscriber(s)
defer cs.deleteSubscriber(s)
c2, err := websocket.Accept(w, r, nil)
if err != nil {
return err
}
mu.Lock()
if closed {
mu.Unlock()
return net.ErrClosed
}
c = c2
mu.Unlock()
defer c.CloseNow()
ctx := c.CloseRead(context.Background())
for {
select {
case msg := <-s.msgs:
......
......@@ -14,7 +14,7 @@ import (
"golang.org/x/time/rate"
"nhooyr.io/websocket"
"github.com/coder/websocket"
)
func Test_chatServer(t *testing.T) {
......@@ -52,7 +52,7 @@ func Test_chatServer(t *testing.T) {
// 10 clients are started that send 128 different
// messages of max 128 bytes concurrently.
//
// The test verifies that every message is seen by ever client
// The test verifies that every message is seen by every client
// and no errors occur anywhere.
t.Run("concurrency", func(t *testing.T) {
t.Parallel()
......
......@@ -2,7 +2,7 @@
<html lang="en-CA">
<head>
<meta charset="UTF-8" />
<title>nhooyr.io/websocket - Chat Example</title>
<title>github.com/coder/websocket - Chat Example</title>
<meta name="viewport" content="width=device-width" />
<link href="https://unpkg.com/sanitize.css" rel="stylesheet" />
......
......@@ -31,7 +31,7 @@ func run() error {
if err != nil {
return err
}
log.Printf("listening on http://%v", l.Addr())
log.Printf("listening on ws://%v", l.Addr())
cs := newChatServer()
s := &http.Server{
......