good morning!!!!

Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • github/nhooyr/websocket
  • open/websocket
2 results
Show changes
......@@ -6,9 +6,9 @@ package websocket
import (
"bufio"
"context"
"errors"
"fmt"
"io"
"net"
"runtime"
"strconv"
"sync"
......@@ -42,6 +42,8 @@ const (
// This applies to context expirations as well unfortunately.
// See https://github.com/nhooyr/websocket/issues/242#issuecomment-633182220
type Conn struct {
noCopy noCopy
subprotocol string
rwc io.ReadWriteCloser
client bool
......@@ -50,31 +52,42 @@ type Conn struct {
br *bufio.Reader
bw *bufio.Writer
readTimeout chan context.Context
writeTimeout chan context.Context
readTimeout chan context.Context
writeTimeout chan context.Context
timeoutLoopDone chan struct{}
// Read state.
readMu *mu
readHeaderBuf [8]byte
readControlBuf [maxControlPayload]byte
msgReader *msgReader
readCloseFrameErr error
readMu *mu
readHeaderBuf [8]byte
readControlBuf [maxControlPayload]byte
msgReader *msgReader
// Write state.
msgWriterState *msgWriterState
msgWriter *msgWriter
writeFrameMu *mu
writeBuf []byte
writeHeaderBuf [8]byte
writeHeader header
closed chan struct{}
closeMu sync.Mutex
closeErr error
wroteClose bool
pingCounter int32
activePingsMu sync.Mutex
activePings map[string]chan<- struct{}
// Close handshake state.
closeStateMu sync.RWMutex
closeReceivedErr error
closeSentErr error
// CloseRead state.
closeReadMu sync.Mutex
closeReadCtx context.Context
closeReadDone chan struct{}
closing atomic.Bool
closeMu sync.Mutex // Protects following.
closed chan struct{}
pingCounter atomic.Int64
activePingsMu sync.Mutex
activePings map[string]chan<- struct{}
onPingReceived func(context.Context, []byte) bool
onPongReceived func(context.Context, []byte)
}
type connConfig struct {
......@@ -83,6 +96,8 @@ type connConfig struct {
client bool
copts *compressionOptions
flateThreshold int
onPingReceived func(context.Context, []byte) bool
onPongReceived func(context.Context, []byte)
br *bufio.Reader
bw *bufio.Writer
......@@ -99,11 +114,14 @@ func newConn(cfg connConfig) *Conn {
br: cfg.br,
bw: cfg.bw,
readTimeout: make(chan context.Context),
writeTimeout: make(chan context.Context),
readTimeout: make(chan context.Context),
writeTimeout: make(chan context.Context),
timeoutLoopDone: make(chan struct{}),
closed: make(chan struct{}),
activePings: make(map[string]chan<- struct{}),
closed: make(chan struct{}),
activePings: make(map[string]chan<- struct{}),
onPingReceived: cfg.onPingReceived,
onPongReceived: cfg.onPongReceived,
}
c.readMu = newMu(c)
......@@ -111,20 +129,20 @@ func newConn(cfg connConfig) *Conn {
c.msgReader = newMsgReader(c)
c.msgWriterState = newMsgWriterState(c)
c.msgWriter = newMsgWriter(c)
if c.client {
c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc)
}
if c.flate() && c.flateThreshold == 0 {
c.flateThreshold = 128
if !c.msgWriterState.flateContextTakeover() {
if !c.msgWriter.flateContextTakeover() {
c.flateThreshold = 512
}
}
runtime.SetFinalizer(c, func(c *Conn) {
c.close(errors.New("connection garbage collected"))
c.close()
})
go c.timeoutLoop()
......@@ -138,30 +156,29 @@ func (c *Conn) Subprotocol() string {
return c.subprotocol
}
func (c *Conn) close(err error) {
func (c *Conn) close() error {
c.closeMu.Lock()
defer c.closeMu.Unlock()
if c.isClosed() {
return
return net.ErrClosed
}
c.setCloseErrLocked(err)
close(c.closed)
runtime.SetFinalizer(c, nil)
close(c.closed)
// Have to close after c.closed is closed to ensure any goroutine that wakes up
// from the connection being closed also sees that c.closed is closed and returns
// closeErr.
c.rwc.Close()
go func() {
c.msgWriterState.close()
c.msgReader.close()
}()
err := c.rwc.Close()
// With the close of rwc, these become safe to close.
c.msgWriter.close()
c.msgReader.close()
return err
}
func (c *Conn) timeoutLoop() {
defer close(c.timeoutLoopDone)
readCtx := context.Background()
writeCtx := context.Background()
......@@ -174,10 +191,10 @@ func (c *Conn) timeoutLoop() {
case readCtx = <-c.readTimeout:
case <-readCtx.Done():
c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err()))
go c.writeError(StatusPolicyViolation, errors.New("timed out"))
c.close()
return
case <-writeCtx.Done():
c.close(fmt.Errorf("write timed out: %w", writeCtx.Err()))
c.close()
return
}
}
......@@ -195,9 +212,9 @@ func (c *Conn) flate() bool {
//
// TCP Keepalives should suffice for most use cases.
func (c *Conn) Ping(ctx context.Context) error {
p := atomic.AddInt32(&c.pingCounter, 1)
p := c.pingCounter.Add(1)
err := c.ping(ctx, strconv.Itoa(int(p)))
err := c.ping(ctx, strconv.FormatInt(p, 10))
if err != nil {
return fmt.Errorf("failed to ping: %w", err)
}
......@@ -224,11 +241,9 @@ func (c *Conn) ping(ctx context.Context, p string) error {
select {
case <-c.closed:
return c.closeErr
return net.ErrClosed
case <-ctx.Done():
err := fmt.Errorf("failed to wait for pong: %w", ctx.Err())
c.close(err)
return err
return fmt.Errorf("failed to wait for pong: %w", ctx.Err())
case <-pong:
return nil
}
......@@ -262,11 +277,9 @@ func (m *mu) tryLock() bool {
func (m *mu) lock(ctx context.Context) error {
select {
case <-m.c.closed:
return m.c.closeErr
return net.ErrClosed
case <-ctx.Done():
err := fmt.Errorf("failed to acquire lock: %w", ctx.Err())
m.c.close(err)
return err
return fmt.Errorf("failed to acquire lock: %w", ctx.Err())
case m.ch <- struct{}{}:
// To make sure the connection is certainly alive.
// As it's possible the send on m.ch was selected
......@@ -275,7 +288,7 @@ func (m *mu) lock(ctx context.Context) error {
case <-m.c.closed:
// Make sure to release.
m.unlock()
return m.c.closeErr
return net.ErrClosed
default:
}
return nil
......@@ -288,3 +301,7 @@ func (m *mu) unlock() {
default:
}
}
type noCopy struct{}
func (*noCopy) Lock() {}
//go:build !js
// +build !js
package websocket_test
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"os"
......@@ -17,18 +17,13 @@ import (
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/golang/protobuf/ptypes"
"github.com/golang/protobuf/ptypes/duration"
"nhooyr.io/websocket"
"nhooyr.io/websocket/internal/errd"
"nhooyr.io/websocket/internal/test/assert"
"nhooyr.io/websocket/internal/test/wstest"
"nhooyr.io/websocket/internal/test/xrand"
"nhooyr.io/websocket/internal/xsync"
"nhooyr.io/websocket/wsjson"
"nhooyr.io/websocket/wspb"
"github.com/coder/websocket"
"github.com/coder/websocket/internal/errd"
"github.com/coder/websocket/internal/test/assert"
"github.com/coder/websocket/internal/test/wstest"
"github.com/coder/websocket/internal/test/xrand"
"github.com/coder/websocket/internal/xsync"
"github.com/coder/websocket/wsjson"
)
func TestConn(t *testing.T) {
......@@ -102,6 +97,85 @@ func TestConn(t *testing.T) {
assert.Contains(t, err, "failed to wait for pong")
})
t.Run("pingReceivedPongReceived", func(t *testing.T) {
var pingReceived1, pongReceived1 bool
var pingReceived2, pongReceived2 bool
tt, c1, c2 := newConnTest(t,
&websocket.DialOptions{
OnPingReceived: func(ctx context.Context, payload []byte) bool {
pingReceived1 = true
return true
},
OnPongReceived: func(ctx context.Context, payload []byte) {
pongReceived1 = true
},
}, &websocket.AcceptOptions{
OnPingReceived: func(ctx context.Context, payload []byte) bool {
pingReceived2 = true
return true
},
OnPongReceived: func(ctx context.Context, payload []byte) {
pongReceived2 = true
},
},
)
c1.CloseRead(tt.ctx)
c2.CloseRead(tt.ctx)
ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100)
defer cancel()
err := c1.Ping(ctx)
assert.Success(t, err)
c1.CloseNow()
c2.CloseNow()
assert.Equal(t, "only one side receives the ping", false, pingReceived1 && pingReceived2)
assert.Equal(t, "only one side receives the pong", false, pongReceived1 && pongReceived2)
assert.Equal(t, "ping and pong received", true, (pingReceived1 && pongReceived2) || (pingReceived2 && pongReceived1))
})
t.Run("pingReceivedPongNotReceived", func(t *testing.T) {
var pingReceived1, pongReceived1 bool
var pingReceived2, pongReceived2 bool
tt, c1, c2 := newConnTest(t,
&websocket.DialOptions{
OnPingReceived: func(ctx context.Context, payload []byte) bool {
pingReceived1 = true
return false
},
OnPongReceived: func(ctx context.Context, payload []byte) {
pongReceived1 = true
},
}, &websocket.AcceptOptions{
OnPingReceived: func(ctx context.Context, payload []byte) bool {
pingReceived2 = true
return false
},
OnPongReceived: func(ctx context.Context, payload []byte) {
pongReceived2 = true
},
},
)
c1.CloseRead(tt.ctx)
c2.CloseRead(tt.ctx)
ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100)
defer cancel()
err := c1.Ping(ctx)
assert.Contains(t, err, "failed to wait for pong")
c1.CloseNow()
c2.CloseNow()
assert.Equal(t, "only one side receives the ping", false, pingReceived1 && pingReceived2)
assert.Equal(t, "ping received and pong not received", true, (pingReceived1 && !pongReceived2) || (pingReceived2 && !pongReceived1))
})
t.Run("concurrentWrite", func(t *testing.T) {
tt, c1, c2 := newConnTest(t, nil, nil)
......@@ -144,7 +218,9 @@ func TestConn(t *testing.T) {
defer cancel()
err = c1.Write(ctx, websocket.MessageText, []byte("x"))
assert.Equal(t, "write error", context.DeadlineExceeded, err)
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("unexpected error: %#v", err)
}
})
t.Run("netConn", func(t *testing.T) {
......@@ -159,8 +235,8 @@ func TestConn(t *testing.T) {
n1.SetDeadline(time.Time{})
assert.Equal(t, "remote addr", n1.RemoteAddr(), n1.LocalAddr())
assert.Equal(t, "remote addr string", "websocket/unknown-addr", n1.RemoteAddr().String())
assert.Equal(t, "remote addr network", "websocket", n1.RemoteAddr().Network())
assert.Equal(t, "remote addr string", "pipe", n1.RemoteAddr().String())
assert.Equal(t, "remote addr network", "pipe", n1.RemoteAddr().Network())
errs := xsync.Go(func() error {
_, err := n2.Write([]byte("hello"))
......@@ -170,7 +246,7 @@ func TestConn(t *testing.T) {
return n2.Close()
})
b, err := ioutil.ReadAll(n1)
b, err := io.ReadAll(n1)
assert.Success(t, err)
_, err = n1.Read(nil)
......@@ -198,7 +274,7 @@ func TestConn(t *testing.T) {
return err
})
_, err := ioutil.ReadAll(n1)
_, err := io.ReadAll(n1)
assert.Contains(t, err, `unexpected frame type read (expected MessageBinary): MessageText`)
select {
......@@ -224,7 +300,7 @@ func TestConn(t *testing.T) {
return n2.Close()
})
b, err := ioutil.ReadAll(n1)
b, err := io.ReadAll(n1)
assert.Success(t, err)
_, err = n1.Read(nil)
......@@ -240,6 +316,18 @@ func TestConn(t *testing.T) {
assert.Equal(t, "read msg", s, string(b))
})
t.Run("netConn/pastDeadline", func(t *testing.T) {
tt, c1, c2 := newConnTest(t, nil, nil)
n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary)
n1.SetDeadline(time.Now().Add(-time.Minute))
n2.SetDeadline(time.Now().Add(-time.Minute))
// No panic we're good.
})
t.Run("wsjson", func(t *testing.T) {
tt, c1, c2 := newConnTest(t, nil, nil)
......@@ -269,20 +357,67 @@ func TestConn(t *testing.T) {
assert.Success(t, err)
})
t.Run("wspb", func(t *testing.T) {
tt, c1, c2 := newConnTest(t, nil, nil)
t.Run("HTTPClient.Timeout", func(t *testing.T) {
tt, c1, c2 := newConnTest(t, &websocket.DialOptions{
HTTPClient: &http.Client{Timeout: time.Second * 5},
}, nil)
tt.goEchoLoop(c2)
exp := ptypes.DurationProto(100)
err := wspb.Write(tt.ctx, c1, exp)
assert.Success(t, err)
c1.SetReadLimit(1 << 30)
exp := xrand.String(xrand.Int(131072))
act := &duration.Duration{}
err = wspb.Read(tt.ctx, c1, act)
werr := xsync.Go(func() error {
return wsjson.Write(tt.ctx, c1, exp)
})
var act interface{}
err := wsjson.Read(tt.ctx, c1, &act)
assert.Success(t, err)
assert.Equal(t, "read msg", exp, act)
select {
case err := <-werr:
assert.Success(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
err = c1.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)
})
t.Run("CloseNow", func(t *testing.T) {
_, c1, c2 := newConnTest(t, nil, nil)
err1 := c1.CloseNow()
err2 := c2.CloseNow()
assert.Success(t, err1)
assert.Success(t, err2)
err1 = c1.CloseNow()
err2 = c2.CloseNow()
assert.ErrorIs(t, websocket.ErrClosed, err1)
assert.ErrorIs(t, websocket.ErrClosed, err2)
})
t.Run("MidReadClose", func(t *testing.T) {
tt, c1, c2 := newConnTest(t, nil, nil)
tt.goEchoLoop(c2)
c1.SetReadLimit(131072)
for i := 0; i < 5; i++ {
err := wstest.Echo(tt.ctx, c1, 131072)
assert.Success(t, err)
}
err := wsjson.Write(tt.ctx, c1, "four")
assert.Success(t, err)
_, _, err = c1.Reader(tt.ctx)
assert.Success(t, err)
err = c1.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)
})
......@@ -290,6 +425,9 @@ func TestConn(t *testing.T) {
func TestWasm(t *testing.T) {
t.Parallel()
if os.Getenv("CI") == "" {
t.SkipNow()
}
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := echoServer(w, r, &websocket.AcceptOptions{
......@@ -305,8 +443,8 @@ func TestWasm(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", ".")
cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", s.URL))
cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", ".", "-v")
cmd.Env = append(cleanEnv(os.Environ()), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", s.URL))
b, err := cmd.CombinedOutput()
if err != nil {
......@@ -314,6 +452,18 @@ func TestWasm(t *testing.T) {
}
}
func cleanEnv(env []string) (out []string) {
for _, e := range env {
// Filter out GITHUB envs and anything with token in it,
// especially GITHUB_TOKEN in CI as it breaks TestWasm.
if strings.HasPrefix(e, "GITHUB") || strings.Contains(e, "TOKEN") {
continue
}
out = append(out, e)
}
return out
}
func assertCloseStatus(exp websocket.StatusCode, err error) error {
if websocket.CloseStatus(err) == -1 {
return fmt.Errorf("expected websocket.CloseError: %T %v", err, err)
......@@ -344,10 +494,8 @@ func newConnTest(t testing.TB, dialOpts *websocket.DialOptions, acceptOpts *webs
c1, c2 = c2, c1
}
t.Cleanup(func() {
// We don't actually care whether this succeeds so we just run it in a separate goroutine to avoid
// blocking the test shutting down.
go c2.Close(websocket.StatusInternalError, "")
go c1.Close(websocket.StatusInternalError, "")
c2.CloseNow()
c1.CloseNow()
})
return tt, c1, c2
......@@ -392,7 +540,7 @@ func (tt *connTest) goDiscardLoop(c *websocket.Conn) {
}
func BenchmarkConn(b *testing.B) {
var benchCases = []struct {
benchCases := []struct {
name string
mode websocket.CompressionMode
}{
......@@ -449,7 +597,7 @@ func BenchmarkConn(b *testing.B) {
typ, r, err := c1.Reader(bb.ctx)
if err != nil {
b.Fatal(err)
b.Fatal(i, err)
}
if websocket.MessageText != typ {
assert.Equal(b, "data type", websocket.MessageText, typ)
......@@ -505,36 +653,201 @@ func echoServer(w http.ResponseWriter, r *http.Request, opts *websocket.AcceptOp
return assertCloseStatus(websocket.StatusNormalClosure, err)
}
func TestGin(t *testing.T) {
func assertEcho(tb testing.TB, ctx context.Context, c *websocket.Conn) {
exp := xrand.String(xrand.Int(131072))
werr := xsync.Go(func() error {
return wsjson.Write(ctx, c, exp)
})
var act interface{}
c.SetReadLimit(1 << 30)
err := wsjson.Read(ctx, c, &act)
assert.Success(tb, err)
assert.Equal(tb, "read msg", exp, act)
select {
case err := <-werr:
assert.Success(tb, err)
case <-ctx.Done():
tb.Fatal(ctx.Err())
}
}
func assertClose(tb testing.TB, c *websocket.Conn) {
tb.Helper()
err := c.Close(websocket.StatusNormalClosure, "")
assert.Success(tb, err)
}
func TestConcurrentClosePing(t *testing.T) {
t.Parallel()
for i := 0; i < 64; i++ {
func() {
c1, c2 := wstest.Pipe(nil, nil)
defer c1.CloseNow()
defer c2.CloseNow()
c1.CloseRead(context.Background())
c2.CloseRead(context.Background())
errc := xsync.Go(func() error {
for range time.Tick(time.Millisecond) {
err := c1.Ping(context.Background())
if err != nil {
return err
}
}
panic("unreachable")
})
gin.SetMode(gin.ReleaseMode)
r := gin.New()
r.GET("/", func(ginCtx *gin.Context) {
err := echoServer(ginCtx.Writer, ginCtx.Request, nil)
if err != nil {
t.Error(err)
time.Sleep(10 * time.Millisecond)
assert.Success(t, c1.Close(websocket.StatusNormalClosure, ""))
<-errc
}()
}
}
func TestConnClosePropagation(t *testing.T) {
t.Parallel()
want := []byte("hello")
keepWriting := func(c *websocket.Conn) <-chan error {
return xsync.Go(func() error {
for {
err := c.Write(context.Background(), websocket.MessageText, want)
if err != nil {
return err
}
}
})
}
keepReading := func(c *websocket.Conn) <-chan error {
return xsync.Go(func() error {
for {
_, got, err := c.Read(context.Background())
if err != nil {
return err
}
if !bytes.Equal(want, got) {
return fmt.Errorf("unexpected message: want %q, got %q", want, got)
}
}
})
}
checkReadErr := func(t *testing.T, err error) {
// Check read error (output depends on when read is called in relation to connection closure).
var ce websocket.CloseError
if errors.As(err, &ce) {
assert.Equal(t, "", websocket.StatusNormalClosure, ce.Code)
} else {
assert.ErrorIs(t, net.ErrClosed, err)
}
}
checkConnErrs := func(t *testing.T, conn ...*websocket.Conn) {
for _, c := range conn {
// Check write error.
err := c.Write(context.Background(), websocket.MessageText, want)
assert.ErrorIs(t, net.ErrClosed, err)
_, _, err = c.Read(context.Background())
checkReadErr(t, err)
}
}
t.Run("CloseOtherSideDuringWrite", func(t *testing.T) {
tt, this, other := newConnTest(t, nil, nil)
_ = this.CloseRead(tt.ctx)
thisWriteErr := keepWriting(this)
_, got, err := other.Read(tt.ctx)
assert.Success(t, err)
assert.Equal(t, "msg", want, got)
err = other.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)
select {
case err := <-thisWriteErr:
assert.ErrorIs(t, net.ErrClosed, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
checkConnErrs(t, this, other)
})
t.Run("CloseThisSideDuringWrite", func(t *testing.T) {
tt, this, other := newConnTest(t, nil, nil)
s := httptest.NewServer(r)
defer s.Close()
_ = this.CloseRead(tt.ctx)
thisWriteErr := keepWriting(this)
otherReadErr := keepReading(other)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
defer cancel()
err := this.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)
c, _, err := websocket.Dial(ctx, s.URL, nil)
assert.Success(t, err)
defer c.Close(websocket.StatusInternalError, "")
select {
case err := <-thisWriteErr:
assert.ErrorIs(t, net.ErrClosed, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
err = wsjson.Write(ctx, c, "hello")
assert.Success(t, err)
select {
case err := <-otherReadErr:
checkReadErr(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
var v interface{}
err = wsjson.Read(ctx, c, &v)
assert.Success(t, err)
assert.Equal(t, "read msg", "hello", v)
checkConnErrs(t, this, other)
})
t.Run("CloseOtherSideDuringRead", func(t *testing.T) {
tt, this, other := newConnTest(t, nil, nil)
_ = other.CloseRead(tt.ctx)
errs := keepReading(this)
err := other.Write(tt.ctx, websocket.MessageText, want)
assert.Success(t, err)
err = other.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)
err = c.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)
select {
case err := <-errs:
checkReadErr(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
checkConnErrs(t, this, other)
})
t.Run("CloseThisSideDuringRead", func(t *testing.T) {
tt, this, other := newConnTest(t, nil, nil)
thisReadErr := keepReading(this)
otherReadErr := keepReading(other)
err := other.Write(tt.ctx, websocket.MessageText, want)
assert.Success(t, err)
err = this.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)
select {
case err := <-thisReadErr:
checkReadErr(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
select {
case err := <-otherReadErr:
checkReadErr(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
checkConnErrs(t, this, other)
})
}
......@@ -11,14 +11,13 @@ import (
"encoding/base64"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
"sync"
"time"
"nhooyr.io/websocket/internal/errd"
"github.com/coder/websocket/internal/errd"
)
// DialOptions represents Dial's options.
......@@ -31,6 +30,10 @@ type DialOptions struct {
// HTTPHeader specifies the HTTP headers included in the handshake request.
HTTPHeader http.Header
// Host optionally overrides the Host HTTP header to send. If empty, the value
// of URL.Host will be used.
Host string
// Subprotocols lists the WebSocket subprotocols to negotiate with the server.
Subprotocols []string
......@@ -45,6 +48,22 @@ type DialOptions struct {
// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
// for CompressionContextTakeover.
CompressionThreshold int
// OnPingReceived is an optional callback invoked synchronously when a ping frame is received.
//
// The payload contains the application data of the ping frame.
// If the callback returns false, the subsequent pong frame will not be sent.
// To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
OnPingReceived func(ctx context.Context, payload []byte) bool
// OnPongReceived is an optional callback invoked synchronously when a pong frame is received.
//
// The payload contains the application data of the pong frame.
// To avoid blocking, any expensive processing should be performed asynchronously using a goroutine.
//
// Unlike OnPingReceived, this callback does not return a value because a pong frame
// is a response to a ping and does not trigger any further frame transmission.
OnPongReceived func(ctx context.Context, payload []byte)
}
func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) {
......@@ -56,16 +75,32 @@ func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context
}
if o.HTTPClient == nil {
o.HTTPClient = http.DefaultClient
} else if opts.HTTPClient.Timeout > 0 {
ctx, cancel = context.WithTimeout(ctx, opts.HTTPClient.Timeout)
}
if o.HTTPClient.Timeout > 0 {
ctx, cancel = context.WithTimeout(ctx, o.HTTPClient.Timeout)
newClient := *opts.HTTPClient
newClient := *o.HTTPClient
newClient.Timeout = 0
opts.HTTPClient = &newClient
o.HTTPClient = &newClient
}
if o.HTTPHeader == nil {
o.HTTPHeader = http.Header{}
}
newClient := *o.HTTPClient
oldCheckRedirect := o.HTTPClient.CheckRedirect
newClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
switch req.URL.Scheme {
case "ws":
req.URL.Scheme = "http"
case "wss":
req.URL.Scheme = "https"
}
if oldCheckRedirect != nil {
return oldCheckRedirect(req, via)
}
return nil
}
o.HTTPClient = &newClient
return ctx, cancel, &o
}
......@@ -122,9 +157,9 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
})
defer timer.Stop()
b, _ := ioutil.ReadAll(r)
b, _ := io.ReadAll(r)
respBody.Close()
resp.Body = ioutil.NopCloser(bytes.NewReader(b))
resp.Body = io.NopCloser(bytes.NewReader(b))
}
}()
......@@ -144,6 +179,8 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (
client: true,
copts: copts,
flateThreshold: opts.CompressionThreshold,
onPingReceived: opts.OnPingReceived,
onPongReceived: opts.OnPongReceived,
br: getBufioReader(rwc),
bw: getBufioWriter(rwc),
}), resp, nil
......@@ -165,7 +202,13 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts
return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme)
}
req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
if err != nil {
return nil, fmt.Errorf("failed to create new http request: %w", err)
}
if len(opts.Host) > 0 {
req.Host = opts.Host
}
req.Header = opts.HTTPHeader.Clone()
req.Header.Set("Connection", "Upgrade")
req.Header.Set("Upgrade", "websocket")
......@@ -175,7 +218,7 @@ func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts
req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
}
if copts != nil {
copts.setHeader(req.Header)
req.Header.Set("Sec-WebSocket-Extensions", copts.String())
}
resp, err := opts.HTTPClient.Do(req)
......@@ -263,6 +306,10 @@ func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compress
copts.serverNoContextTakeover = true
continue
}
if strings.HasPrefix(p, "server_max_window_bits=") {
// We can't adjust the deflate window, but decoding with a larger window is acceptable.
continue
}
return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
}
......
//go:build !js
// +build !js
package websocket
package websocket_test
import (
"bytes"
"context"
"crypto/rand"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"nhooyr.io/websocket/internal/test/assert"
"github.com/coder/websocket"
"github.com/coder/websocket/internal/test/assert"
"github.com/coder/websocket/internal/util"
"github.com/coder/websocket/internal/xsync"
)
func TestBadDials(t *testing.T) {
......@@ -24,10 +28,11 @@ func TestBadDials(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
url string
opts *DialOptions
rand readerFunc
name string
url string
opts *websocket.DialOptions
rand util.ReaderFunc
nilCtx bool
}{
{
name: "badURL",
......@@ -47,6 +52,11 @@ func TestBadDials(t *testing.T) {
return 0, io.EOF
},
},
{
name: "nilContext",
url: "http://localhost",
nilCtx: true,
},
}
for _, tc := range testCases {
......@@ -54,14 +64,18 @@ func TestBadDials(t *testing.T) {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
var ctx context.Context
var cancel func()
if !tc.nilCtx {
ctx, cancel = context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
}
if tc.rand == nil {
tc.rand = rand.Reader.Read
}
_, _, err := dial(ctx, tc.url, tc.opts, tc.rand)
_, _, err := websocket.ExportedDial(ctx, tc.url, tc.opts, tc.rand)
assert.Error(t, err)
})
}
......@@ -73,10 +87,10 @@ func TestBadDials(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
_, _, err := Dial(ctx, "ws://example.com", &DialOptions{
_, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
HTTPClient: mockHTTPClient(func(*http.Request) (*http.Response, error) {
return &http.Response{
Body: ioutil.NopCloser(strings.NewReader("hi")),
Body: io.NopCloser(strings.NewReader("hi")),
}, nil
}),
})
......@@ -93,22 +107,82 @@ func TestBadDials(t *testing.T) {
h := http.Header{}
h.Set("Connection", "Upgrade")
h.Set("Upgrade", "websocket")
h.Set("Sec-WebSocket-Accept", secWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))
h.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))
return &http.Response{
StatusCode: http.StatusSwitchingProtocols,
Header: h,
Body: ioutil.NopCloser(strings.NewReader("hi")),
Body: io.NopCloser(strings.NewReader("hi")),
}, nil
}
_, _, err := Dial(ctx, "ws://example.com", &DialOptions{
_, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
HTTPClient: mockHTTPClient(rt),
})
assert.Contains(t, err, "response body is not a io.ReadWriteCloser")
})
}
func Test_verifyHostOverride(t *testing.T) {
testCases := []struct {
name string
host string
exp string
}{
{
name: "noOverride",
host: "",
exp: "example.com",
},
{
name: "hostOverride",
host: "example.net",
exp: "example.net",
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
rt := func(r *http.Request) (*http.Response, error) {
assert.Equal(t, "Host", tc.exp, r.Host)
h := http.Header{}
h.Set("Connection", "Upgrade")
h.Set("Upgrade", "websocket")
h.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(r.Header.Get("Sec-WebSocket-Key")))
return &http.Response{
StatusCode: http.StatusSwitchingProtocols,
Header: h,
Body: mockBody{bytes.NewBufferString("hi")},
}, nil
}
c, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
HTTPClient: mockHTTPClient(rt),
Host: tc.host,
})
assert.Success(t, err)
c.CloseNow()
})
}
}
type mockBody struct {
*bytes.Buffer
}
func (mb mockBody) Close() error {
return nil
}
func Test_verifyServerHandshake(t *testing.T) {
t.Parallel()
......@@ -202,18 +276,18 @@ func Test_verifyServerHandshake(t *testing.T) {
resp := w.Result()
r := httptest.NewRequest("GET", "/", nil)
key, err := secWebSocketKey(rand.Reader)
key, err := websocket.SecWebSocketKey(rand.Reader)
assert.Success(t, err)
r.Header.Set("Sec-WebSocket-Key", key)
if resp.Header.Get("Sec-WebSocket-Accept") == "" {
resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
resp.Header.Set("Sec-WebSocket-Accept", websocket.SecWebSocketAccept(key))
}
opts := &DialOptions{
opts := &websocket.DialOptions{
Subprotocols: strings.Split(r.Header.Get("Sec-WebSocket-Protocol"), ","),
}
_, err = verifyServerResponse(opts, opts.CompressionMode.opts(), key, resp)
_, err = websocket.VerifyServerResponse(opts, websocket.CompressionModeOpts(opts.CompressionMode), key, resp)
if tc.success {
assert.Success(t, err)
} else {
......@@ -234,3 +308,113 @@ type roundTripperFunc func(*http.Request) (*http.Response, error)
func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
return f(r)
}
func TestDialRedirect(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
_, _, err := websocket.Dial(ctx, "ws://example.com", &websocket.DialOptions{
HTTPClient: mockHTTPClient(func(r *http.Request) (*http.Response, error) {
resp := &http.Response{
Header: http.Header{},
}
if r.URL.Scheme != "https" {
resp.Header.Set("Location", "wss://example.com")
resp.StatusCode = http.StatusFound
return resp, nil
}
resp.Header.Set("Connection", "Upgrade")
resp.Header.Set("Upgrade", "meow")
resp.StatusCode = http.StatusSwitchingProtocols
return resp, nil
}),
})
assert.Contains(t, err, "failed to WebSocket dial: WebSocket protocol violation: Upgrade header \"meow\" does not contain websocket")
}
type forwardProxy struct {
hc *http.Client
}
func newForwardProxy() *forwardProxy {
return &forwardProxy{
hc: &http.Client{},
}
}
func (fc *forwardProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), time.Second*10)
defer cancel()
r = r.WithContext(ctx)
r.RequestURI = ""
resp, err := fc.hc.Do(r)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
defer resp.Body.Close()
for k, v := range resp.Header {
w.Header()[k] = v
}
w.Header().Set("PROXIED", "true")
w.WriteHeader(resp.StatusCode)
if resprw, ok := resp.Body.(io.ReadWriter); ok {
c, brw, err := w.(http.Hijacker).Hijack()
if err != nil {
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
brw.Flush()
errc1 := xsync.Go(func() error {
_, err := io.Copy(c, resprw)
return err
})
errc2 := xsync.Go(func() error {
_, err := io.Copy(resprw, c)
return err
})
select {
case <-errc1:
case <-errc2:
case <-r.Context().Done():
}
} else {
io.Copy(w, resp.Body)
}
}
func TestDialViaProxy(t *testing.T) {
t.Parallel()
ps := httptest.NewServer(newForwardProxy())
defer ps.Close()
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
err := echoServer(w, r, nil)
assert.Success(t, err)
}))
defer s.Close()
psu, err := url.Parse(ps.URL)
assert.Success(t, err)
proxyTransport := http.DefaultTransport.(*http.Transport).Clone()
proxyTransport.Proxy = http.ProxyURL(psu)
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
c, resp, err := websocket.Dial(ctx, s.URL, &websocket.DialOptions{
HTTPClient: &http.Client{
Transport: proxyTransport,
},
})
assert.Success(t, err)
assert.Equal(t, "", "true", resp.Header.Get("PROXIED"))
assertEcho(t, ctx, c)
assertClose(t, c)
}
......@@ -13,9 +13,9 @@
//
// The examples are the best way to understand how to correctly use the library.
//
// The wsjson and wspb subpackages contain helpers for JSON and protobuf messages.
// The wsjson subpackage contain helpers for JSON and protobuf messages.
//
// More documentation at https://nhooyr.io/websocket.
// More documentation at https://github.com/coder/websocket.
//
// # Wasm
//
......@@ -28,6 +28,7 @@
//
// - Accept always errors out
// - Conn.Ping is no-op
// - Conn.CloseNow is Close(StatusGoingAway, "")
// - HTTPClient, HTTPHeader and CompressionMode in DialOptions are no-op
// - *http.Response from Dial is &http.Response{} with a 101 status code on success
package websocket // import "nhooyr.io/websocket"
package websocket // import "github.com/coder/websocket"
......@@ -6,8 +6,8 @@ import (
"net/http"
"time"
"nhooyr.io/websocket"
"nhooyr.io/websocket/wsjson"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
)
func ExampleAccept() {
......@@ -20,7 +20,7 @@ func ExampleAccept() {
log.Println(err)
return
}
defer c.Close(websocket.StatusInternalError, "the sky is falling")
defer c.CloseNow()
ctx, cancel := context.WithTimeout(r.Context(), time.Second*10)
defer cancel()
......@@ -50,7 +50,7 @@ func ExampleDial() {
if err != nil {
log.Fatal(err)
}
defer c.Close(websocket.StatusInternalError, "the sky is falling")
defer c.CloseNow()
err = wsjson.Write(ctx, c, "hi")
if err != nil {
......@@ -71,7 +71,7 @@ func ExampleCloseStatus() {
if err != nil {
log.Fatal(err)
}
defer c.Close(websocket.StatusInternalError, "the sky is falling")
defer c.CloseNow()
_, _, err = c.Reader(ctx)
if websocket.CloseStatus(err) != websocket.StatusNormalClosure {
......@@ -88,7 +88,7 @@ func Example_writeOnly() {
log.Println(err)
return
}
defer c.Close(websocket.StatusInternalError, "the sky is falling")
defer c.CloseNow()
ctx, cancel := context.WithTimeout(r.Context(), time.Minute*10)
defer cancel()
......@@ -145,7 +145,7 @@ func ExampleConn_Ping() {
if err != nil {
log.Fatal(err)
}
defer c.Close(websocket.StatusInternalError, "the sky is falling")
defer c.CloseNow()
// Required to read the Pongs from the server.
ctx = c.CloseRead(ctx)
......@@ -162,10 +162,10 @@ func ExampleConn_Ping() {
// This example demonstrates full stack chat with an automated test.
func Example_fullStackChat() {
// https://github.com/nhooyr/websocket/tree/master/examples/chat
// https://github.com/nhooyr/websocket/tree/master/internal/examples/chat
}
// This example demonstrates a echo server.
func Example_echo() {
// https://github.com/nhooyr/websocket/tree/master/examples/echo
// https://github.com/nhooyr/websocket/tree/master/internal/examples/echo
}
......@@ -3,9 +3,15 @@
package websocket
import (
"net"
"github.com/coder/websocket/internal/util"
)
func (c *Conn) RecordBytesWritten() *int {
var bytesWritten int
c.bw.Reset(writerFunc(func(p []byte) (int, error) {
c.bw.Reset(util.WriterFunc(func(p []byte) (int, error) {
bytesWritten += len(p)
return c.rwc.Write(p)
}))
......@@ -14,10 +20,19 @@ func (c *Conn) RecordBytesWritten() *int {
func (c *Conn) RecordBytesRead() *int {
var bytesRead int
c.br.Reset(readerFunc(func(p []byte) (int, error) {
c.br.Reset(util.ReaderFunc(func(p []byte) (int, error) {
n, err := c.rwc.Read(p)
bytesRead += n
return n, err
}))
return &bytesRead
}
var ErrClosed = net.ErrClosed
var ExportedDial = dial
var SecWebSocketAccept = secWebSocketAccept
var SecWebSocketKey = secWebSocketKey
var VerifyServerResponse = verifyServerResponse
var CompressionModeOpts = CompressionMode.opts
//go:build !js
package websocket
import (
......@@ -6,9 +8,8 @@ import (
"fmt"
"io"
"math"
"math/bits"
"nhooyr.io/websocket/internal/errd"
"github.com/coder/websocket/internal/errd"
)
// opcode represents a WebSocket opcode.
......@@ -170,125 +171,3 @@ func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) {
return nil
}
// mask applies the WebSocket masking algorithm to p
// with the given key.
// See https://tools.ietf.org/html/rfc6455#section-5.3
//
// The returned value is the correctly rotated key to
// to continue to mask/unmask the message.
//
// It is optimized for LittleEndian and expects the key
// to be in little endian.
//
// See https://github.com/golang/go/issues/31586
func mask(key uint32, b []byte) uint32 {
if len(b) >= 8 {
key64 := uint64(key)<<32 | uint64(key)
// At some point in the future we can clean these unrolled loops up.
// See https://github.com/golang/go/issues/31586#issuecomment-487436401
// Then we xor until b is less than 128 bytes.
for len(b) >= 128 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
v = binary.LittleEndian.Uint64(b[32:40])
binary.LittleEndian.PutUint64(b[32:40], v^key64)
v = binary.LittleEndian.Uint64(b[40:48])
binary.LittleEndian.PutUint64(b[40:48], v^key64)
v = binary.LittleEndian.Uint64(b[48:56])
binary.LittleEndian.PutUint64(b[48:56], v^key64)
v = binary.LittleEndian.Uint64(b[56:64])
binary.LittleEndian.PutUint64(b[56:64], v^key64)
v = binary.LittleEndian.Uint64(b[64:72])
binary.LittleEndian.PutUint64(b[64:72], v^key64)
v = binary.LittleEndian.Uint64(b[72:80])
binary.LittleEndian.PutUint64(b[72:80], v^key64)
v = binary.LittleEndian.Uint64(b[80:88])
binary.LittleEndian.PutUint64(b[80:88], v^key64)
v = binary.LittleEndian.Uint64(b[88:96])
binary.LittleEndian.PutUint64(b[88:96], v^key64)
v = binary.LittleEndian.Uint64(b[96:104])
binary.LittleEndian.PutUint64(b[96:104], v^key64)
v = binary.LittleEndian.Uint64(b[104:112])
binary.LittleEndian.PutUint64(b[104:112], v^key64)
v = binary.LittleEndian.Uint64(b[112:120])
binary.LittleEndian.PutUint64(b[112:120], v^key64)
v = binary.LittleEndian.Uint64(b[120:128])
binary.LittleEndian.PutUint64(b[120:128], v^key64)
b = b[128:]
}
// Then we xor until b is less than 64 bytes.
for len(b) >= 64 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
v = binary.LittleEndian.Uint64(b[32:40])
binary.LittleEndian.PutUint64(b[32:40], v^key64)
v = binary.LittleEndian.Uint64(b[40:48])
binary.LittleEndian.PutUint64(b[40:48], v^key64)
v = binary.LittleEndian.Uint64(b[48:56])
binary.LittleEndian.PutUint64(b[48:56], v^key64)
v = binary.LittleEndian.Uint64(b[56:64])
binary.LittleEndian.PutUint64(b[56:64], v^key64)
b = b[64:]
}
// Then we xor until b is less than 32 bytes.
for len(b) >= 32 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
v = binary.LittleEndian.Uint64(b[16:24])
binary.LittleEndian.PutUint64(b[16:24], v^key64)
v = binary.LittleEndian.Uint64(b[24:32])
binary.LittleEndian.PutUint64(b[24:32], v^key64)
b = b[32:]
}
// Then we xor until b is less than 16 bytes.
for len(b) >= 16 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
v = binary.LittleEndian.Uint64(b[8:16])
binary.LittleEndian.PutUint64(b[8:16], v^key64)
b = b[16:]
}
// Then we xor until b is less than 8 bytes.
for len(b) >= 8 {
v := binary.LittleEndian.Uint64(b)
binary.LittleEndian.PutUint64(b, v^key64)
b = b[8:]
}
}
// Then we xor until b is less than 4 bytes.
for len(b) >= 4 {
v := binary.LittleEndian.Uint32(b)
binary.LittleEndian.PutUint32(b, v^key)
b = b[4:]
}
// xor remaining bytes.
for i := range b {
b[i] ^= byte(key)
key = bits.RotateLeft32(key, -8)
}
return key
}
......@@ -12,12 +12,8 @@ import (
"strconv"
"testing"
"time"
_ "unsafe"
"github.com/gobwas/ws"
_ "github.com/gorilla/websocket"
"nhooyr.io/websocket/internal/test/assert"
"github.com/coder/websocket/internal/test/assert"
)
func TestHeader(t *testing.T) {
......@@ -55,7 +51,7 @@ func TestHeader(t *testing.T) {
r := rand.New(rand.NewSource(time.Now().UnixNano()))
randBool := func() bool {
return r.Intn(1) == 0
return r.Intn(2) == 0
}
for i := 0; i < 10000; i++ {
......@@ -67,9 +63,11 @@ func TestHeader(t *testing.T) {
opcode: opcode(r.Intn(16)),
masked: randBool(),
maskKey: r.Uint32(),
payloadLength: r.Int63(),
}
if h.masked {
h.maskKey = r.Uint32()
}
testHeader(t, h)
}
......@@ -99,7 +97,7 @@ func Test_mask(t *testing.T) {
key := []byte{0xa, 0xb, 0xc, 0xff}
key32 := binary.LittleEndian.Uint32(key)
p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc}
gotKey32 := mask(key32, p)
gotKey32 := mask(p, key32)
expP := []byte{0, 0, 0, 0x0d, 0x6}
assert.Equal(t, "p", expP, p)
......@@ -107,87 +105,3 @@ func Test_mask(t *testing.T) {
expKey32 := bits.RotateLeft32(key32, -8)
assert.Equal(t, "key32", expKey32, gotKey32)
}
func basicMask(maskKey [4]byte, pos int, b []byte) int {
for i := range b {
b[i] ^= maskKey[pos&3]
pos++
}
return pos & 3
}
//go:linkname gorillaMaskBytes github.com/gorilla/websocket.maskBytes
func gorillaMaskBytes(key [4]byte, pos int, b []byte) int
func Benchmark_mask(b *testing.B) {
sizes := []int{
2,
3,
4,
8,
16,
32,
128,
512,
4096,
16384,
}
fns := []struct {
name string
fn func(b *testing.B, key [4]byte, p []byte)
}{
{
name: "basic",
fn: func(b *testing.B, key [4]byte, p []byte) {
for i := 0; i < b.N; i++ {
basicMask(key, 0, p)
}
},
},
{
name: "nhooyr",
fn: func(b *testing.B, key [4]byte, p []byte) {
key32 := binary.LittleEndian.Uint32(key[:])
b.ResetTimer()
for i := 0; i < b.N; i++ {
mask(key32, p)
}
},
},
{
name: "gorilla",
fn: func(b *testing.B, key [4]byte, p []byte) {
for i := 0; i < b.N; i++ {
gorillaMaskBytes(key, 0, p)
}
},
},
{
name: "gobwas",
fn: func(b *testing.B, key [4]byte, p []byte) {
for i := 0; i < b.N; i++ {
ws.Cipher(p, key, 0)
}
},
},
}
key := [4]byte{1, 2, 3, 4}
for _, size := range sizes {
p := make([]byte, size)
b.Run(strconv.Itoa(size), func(b *testing.B) {
for _, fn := range fns {
b.Run(fn.name, func(b *testing.B) {
b.SetBytes(int64(size))
fn.fn(b, key, p)
})
}
})
}
}
module nhooyr.io/websocket
module github.com/coder/websocket
go 1.19
require (
github.com/gin-gonic/gin v1.9.1
github.com/gobwas/ws v1.3.0
github.com/golang/protobuf v1.5.3
github.com/google/go-cmp v0.5.9
github.com/gorilla/websocket v1.5.0
golang.org/x/time v0.3.0
)
require (
github.com/bytedance/sonic v1.9.1 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.14.0 // indirect
github.com/gobwas/httphead v0.1.0 // indirect
github.com/gobwas/pool v0.2.1 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/crypto v0.9.0 // indirect
golang.org/x/net v0.10.0 // indirect
golang.org/x/sys v0.8.0 // indirect
golang.org/x/text v0.9.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
go 1.23
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams=
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg=
github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js=
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU=
github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM=
github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og=
github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
github.com/gobwas/ws v1.3.0 h1:sbeU3Y4Qzlb+MOzIe6mQGf7QR4Hkv6ZD0qhGkBFL2O0=
github.com/gobwas/ws v1.3.0/go.mod h1:hRKAFb8wOxFROYNsT1bqfWnhX+b5MFeJM9r2ZSwg/KY=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q=
github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY=
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU=
github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.9.0 h1:LF6fAI+IutBocDJ2OT0Q1g8plpYljMZ4+lty+dsqw3g=
golang.org/x/crypto v0.9.0/go.mod h1:yrmDGqONDYtNj3tH8X9dzUun2m2lzPa9ngI6/RUPGR0=
golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.8.0 h1:EBmGv8NaZBZTWvrbjNoL6HVt+IVy3QDQpJs7VRIw3tU=
golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8=
golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4=
golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng=
google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
//go:build !js
package websocket
import (
"net/http"
)
type rwUnwrapper interface {
Unwrap() http.ResponseWriter
}
// hijacker returns the Hijacker interface of the http.ResponseWriter.
// It follows the Unwrap method of the http.ResponseWriter if available,
// matching the behavior of http.ResponseController. If the Hijacker
// interface is not found, it returns false.
//
// Since the http.ResponseController is not available in Go 1.19, and
// does not support checking the presence of the Hijacker interface,
// this function is used to provide a consistent way to check for the
// Hijacker interface across Go versions.
func hijacker(rw http.ResponseWriter) (http.Hijacker, bool) {
for {
switch t := rw.(type) {
case http.Hijacker:
return t, true
case rwUnwrapper:
rw = t.Unwrap()
default:
return nil, false
}
}
}
//go:build !js && go1.20
package websocket
import (
"bufio"
"errors"
"net"
"net/http"
"net/http/httptest"
"testing"
"github.com/coder/websocket/internal/test/assert"
)
func Test_hijackerHTTPResponseControllerCompatibility(t *testing.T) {
t.Parallel()
rr := httptest.NewRecorder()
w := mockUnwrapper{
ResponseWriter: rr,
unwrap: func() http.ResponseWriter {
return mockHijacker{
ResponseWriter: rr,
hijack: func() (conn net.Conn, writer *bufio.ReadWriter, err error) {
return nil, nil, errors.New("haha")
},
}
},
}
_, _, err := http.NewResponseController(w).Hijack()
assert.Contains(t, err, "haha")
hj, ok := hijacker(w)
assert.Equal(t, "hijacker found", ok, true)
_, _, err = hj.Hijack()
assert.Contains(t, err, "haha")
}
......@@ -5,15 +5,16 @@ import (
"sync"
)
var bpool sync.Pool
var bpool = sync.Pool{
New: func() any {
return &bytes.Buffer{}
},
}
// Get returns a buffer from the pool or creates a new one if
// the pool is empty.
func Get() *bytes.Buffer {
b := bpool.Get()
if b == nil {
return &bytes.Buffer{}
}
return b.(*bytes.Buffer)
}
......
File moved
# Chat Example
This directory contains a full stack example of a simple chat webapp using nhooyr.io/websocket.
This directory contains a full stack example of a simple chat webapp using github.com/coder/websocket.
```bash
$ cd examples/chat
$ go run . localhost:0
listening on http://127.0.0.1:51055
listening on ws://127.0.0.1:51055
```
Visit the printed URL to submit and view broadcasted messages in a browser.
......
......@@ -3,15 +3,16 @@ package main
import (
"context"
"errors"
"io/ioutil"
"io"
"log"
"net"
"net/http"
"sync"
"time"
"golang.org/x/time/rate"
"nhooyr.io/websocket"
"github.com/coder/websocket"
)
// chatServer enables broadcasting to a set of subscribers.
......@@ -69,14 +70,7 @@ func (cs *chatServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// subscribeHandler accepts the WebSocket connection and then subscribes
// it to all future messages.
func (cs *chatServer) subscribeHandler(w http.ResponseWriter, r *http.Request) {
c, err := websocket.Accept(w, r, nil)
if err != nil {
cs.logf("%v", err)
return
}
defer c.Close(websocket.StatusInternalError, "")
err = cs.subscribe(r.Context(), c)
err := cs.subscribe(w, r)
if errors.Is(err, context.Canceled) {
return
}
......@@ -98,7 +92,7 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) {
return
}
body := http.MaxBytesReader(w, r.Body, 8192)
msg, err := ioutil.ReadAll(body)
msg, err := io.ReadAll(body)
if err != nil {
http.Error(w, http.StatusText(http.StatusRequestEntityTooLarge), http.StatusRequestEntityTooLarge)
return
......@@ -117,18 +111,39 @@ func (cs *chatServer) publishHandler(w http.ResponseWriter, r *http.Request) {
//
// It uses CloseRead to keep reading from the connection to process control
// messages and cancel the context if the connection drops.
func (cs *chatServer) subscribe(ctx context.Context, c *websocket.Conn) error {
ctx = c.CloseRead(ctx)
func (cs *chatServer) subscribe(w http.ResponseWriter, r *http.Request) error {
var mu sync.Mutex
var c *websocket.Conn
var closed bool
s := &subscriber{
msgs: make(chan []byte, cs.subscriberMessageBuffer),
closeSlow: func() {
c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages")
mu.Lock()
defer mu.Unlock()
closed = true
if c != nil {
c.Close(websocket.StatusPolicyViolation, "connection too slow to keep up with messages")
}
},
}
cs.addSubscriber(s)
defer cs.deleteSubscriber(s)
c2, err := websocket.Accept(w, r, nil)
if err != nil {
return err
}
mu.Lock()
if closed {
mu.Unlock()
return net.ErrClosed
}
c = c2
mu.Unlock()
defer c.CloseNow()
ctx := c.CloseRead(context.Background())
for {
select {
case msg := <-s.msgs:
......
......@@ -14,7 +14,7 @@ import (
"golang.org/x/time/rate"
"nhooyr.io/websocket"
"github.com/coder/websocket"
)
func Test_chatServer(t *testing.T) {
......@@ -52,7 +52,7 @@ func Test_chatServer(t *testing.T) {
// 10 clients are started that send 128 different
// messages of max 128 bytes concurrently.
//
// The test verifies that every message is seen by ever client
// The test verifies that every message is seen by every client
// and no errors occur anywhere.
t.Run("concurrency", func(t *testing.T) {
t.Parallel()
......
......@@ -2,7 +2,7 @@
<html lang="en-CA">
<head>
<meta charset="UTF-8" />
<title>nhooyr.io/websocket - Chat Example</title>
<title>github.com/coder/websocket - Chat Example</title>
<meta name="viewport" content="width=device-width" />
<link href="https://unpkg.com/sanitize.css" rel="stylesheet" />
......