Newer
Older
"net/http"
"net/http/httptest"
"os"
"os/exec"
"nhooyr.io/websocket/internal/test/wstest"
"nhooyr.io/websocket/internal/test/xrand"
compressionMode := func() websocket.CompressionMode {
return websocket.CompressionMode(xrand.Int(int(websocket.CompressionContextTakeover) + 1))
tt, c1, c2 := newConnTest(t, &websocket.DialOptions{
CompressionMode: compressionMode(),
CompressionThreshold: xrand.Int(9999),
CompressionMode: compressionMode(),
CompressionThreshold: xrand.Int(9999),
err := c1.Close(websocket.StatusNormalClosure, "")
tt, c1, c2 := newConnTest(t, nil, nil)
c2.CloseRead(tt.ctx)
assert.Contains(t, err, "failed to marshal close frame: status code StatusCode(-1) cannot be set")
c1.CloseRead(tt.ctx)
c2.CloseRead(tt.ctx)
err := c1.Close(websocket.StatusNormalClosure, "")
ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100)
defer cancel()
err := c1.Ping(ctx)
assert.Contains(t, err, "failed to wait for pong")
})
t.Run("concurrentWrite", func(t *testing.T) {
msg := xrand.Bytes(xrand.Int(9999))
const count = 100
errs := make(chan error, count)
for i := 0; i < count; i++ {
go func() {
select {
case errs <- c1.Write(tt.ctx, websocket.MessageBinary, msg):
case <-tt.ctx.Done():
return
}
select {
case err := <-errs:
assert.Success(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
err := c1.Close(websocket.StatusNormalClosure, "")
t.Run("concurrentWriteError", func(t *testing.T) {
_, err := c1.Writer(tt.ctx, websocket.MessageText)
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
defer cancel()
err = c1.Write(ctx, websocket.MessageText, []byte("x"))
if !errors.Is(err, context.DeadlineExceeded) {
t.Fatalf("unexpected error: %#v", err)
}
n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary)
// Does not give any confidence but at least ensures no crashes.
assert.Equal(t, "remote addr", n1.RemoteAddr(), n1.LocalAddr())
assert.Equal(t, "remote addr string", "pipe", n1.RemoteAddr().String())
assert.Equal(t, "remote addr network", "pipe", n1.RemoteAddr().Network())
errs := xsync.Go(func() error {
_, err := n2.Write([]byte("hello"))
if err != nil {
return err
}
return n2.Close()
})
select {
case err := <-errs:
assert.Success(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageText)
errs := xsync.Go(func() error {
_, err := n2.Write([]byte("hello"))
assert.Contains(t, err, `unexpected frame type read (expected MessageBinary): MessageText`)
select {
case err := <-errs:
assert.Success(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
t.Run("netConn/readLimit", func(t *testing.T) {
tt, c1, c2 := newConnTest(t, nil, nil)
n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary)
errs := xsync.Go(func() error {
_, err := n2.Write([]byte(s))
if err != nil {
return err
}
return n2.Close()
})
assert.Success(t, err)
_, err = n1.Read(nil)
assert.Equal(t, "read error", err, io.EOF)
select {
case err := <-errs:
assert.Success(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
assert.Equal(t, "read msg", s, string(b))
})
assert.Success(t, err)
assert.Equal(t, "read msg", exp, act)
select {
case err := <-werr:
assert.Success(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
// HTTPClient.Timeout dials with an http.Client that has a Timeout set, then
// sends a random JSON string through an echo loop and checks it comes back
// unchanged before closing normally.
// NOTE(review): the round trip likely completes well within the 5s client
// timeout, so this alone may not prove the connection outlives that timeout.
t.Run("HTTPClient.Timeout", func(t *testing.T) {
	tt, c1, c2 := newConnTest(t, &websocket.DialOptions{
		HTTPClient: &http.Client{Timeout: time.Second * 5},
	}, nil)

	tt.goEchoLoop(c2)

	// Raise the read limit so the echoed message (up to 131072 characters
	// plus JSON quoting) can be read back — presumably above the default.
	c1.SetReadLimit(1 << 30)

	exp := xrand.String(xrand.Int(131072))

	// Write concurrently with the read below so neither side blocks the other.
	werr := xsync.Go(func() error {
		return wsjson.Write(tt.ctx, c1, exp)
	})

	var act interface{}
	err := wsjson.Read(tt.ctx, c1, &act)
	assert.Success(t, err)
	assert.Equal(t, "read msg", exp, act)

	select {
	case err := <-werr:
		assert.Success(t, err)
	case <-tt.ctx.Done():
		t.Fatal(tt.ctx.Err())
	}

	err = c1.Close(websocket.StatusNormalClosure, "")
	assert.Success(t, err)
})
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Subprotocols: []string{"echo"},
InsecureSkipVerify: true,
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", ".")
cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", s.URL))
b, err := cmd.CombinedOutput()
if err != nil {
t.Fatalf("wasm test binary failed: %v:\n%s", err, b)
func assertCloseStatus(exp websocket.StatusCode, err error) error {
if websocket.CloseStatus(err) == -1 {
return fmt.Errorf("expected websocket.CloseError: %T %v", err, err)
return fmt.Errorf("expected close status %v but got %v", exp, err)
func newConnTest(t testing.TB, dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (tt *connTest, c1, c2 *websocket.Conn) {
if t, ok := t.(*testing.T); ok {
t.Parallel()
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
c1, c2 = wstest.Pipe(dialOpts, acceptOpts)
if xrand.Bool() {
c1, c2 = c2, c1
}
t.Cleanup(func() {
// We don't actually care whether this succeeds so we just run it in a separate goroutine to avoid
// blocking the test shutting down.
go c2.Close(websocket.StatusInternalError, "")
go c1.Close(websocket.StatusInternalError, "")
func (tt *connTest) goEchoLoop(c *websocket.Conn) {
ctx, cancel := context.WithCancel(tt.ctx)
echoLoopErr := xsync.Go(func() error {
err := wstest.EchoLoop(ctx, c)
return assertCloseStatus(websocket.StatusNormalClosure, err)
})
cancel()
err := <-echoLoopErr
if err != nil {
tt.t.Errorf("echo loop error: %v", err)
}
})
}
func (tt *connTest) goDiscardLoop(c *websocket.Conn) {
ctx, cancel := context.WithCancel(tt.ctx)
discardLoopErr := xsync.Go(func() error {
defer c.Close(websocket.StatusInternalError, "")
for {
_, _, err := c.Read(ctx)
if err != nil {
return assertCloseStatus(websocket.StatusNormalClosure, err)
cancel()
err := <-discardLoopErr
if err != nil {
tt.t.Errorf("discard loop error: %v", err)
}
})
}
func BenchmarkConn(b *testing.B) {
var benchCases = []struct {
name string
mode websocket.CompressionMode
}{
{
mode: websocket.CompressionContextTakeover,
},
{
mode: websocket.CompressionNoContextTakeover,
},
}
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
bb, c1, c2 := newConnTest(b, &websocket.DialOptions{
bytesWritten := c1.RecordBytesWritten()
bytesRead := c1.RecordBytesRead()
msg := []byte(strings.Repeat("1234", 128))
readBuf := make([]byte, len(msg))
writes := make(chan struct{})
defer close(writes)
werrs := make(chan error)
go func() {
select {
case werrs <- c1.Write(bb.ctx, websocket.MessageText, msg):
case <-bb.ctx.Done():
return
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
select {
case writes <- struct{}{}:
case <-bb.ctx.Done():
b.Fatal(bb.ctx.Err())
}
typ, r, err := c1.Reader(bb.ctx)
if err != nil {
b.Fatal(err)
}
if websocket.MessageText != typ {
assert.Equal(b, "data type", websocket.MessageText, typ)
}
_, err = io.ReadFull(r, readBuf)
if err != nil {
b.Fatal(err)
}
n2, err := r.Read(readBuf)
if err != io.EOF {
assert.Equal(b, "read err", io.EOF, err)
}
if n2 != 0 {
assert.Equal(b, "n2", 0, n2)
}
if !bytes.Equal(msg, readBuf) {
assert.Equal(b, "msg", msg, readBuf)
select {
case err = <-werrs:
case <-bb.ctx.Done():
b.Fatal(bb.ctx.Err())
}
if err != nil {
b.Fatal(err)
}
}
b.StopTimer()
b.ReportMetric(float64(*bytesWritten/b.N), "written/op")
b.ReportMetric(float64(*bytesRead/b.N), "read/op")
err := c1.Close(websocket.StatusNormalClosure, "")
assert.Success(b, err)
})
}
}
// echoServer accepts a WebSocket upgrade on (w, r) using opts and runs
// wstest.EchoLoop on the resulting connection for the lifetime of the
// request context. It returns the accept error if the upgrade fails;
// otherwise it returns the result of checking that the echo loop ended
// with websocket.StatusNormalClosure. Any returned error is passed through
// errd.Wrap with the "echo server failed" message (presumably adding it as
// a prefix to non-nil errors — confirm against errd).
func echoServer(w http.ResponseWriter, r *http.Request, opts *websocket.AcceptOptions) (err error) {
	defer errd.Wrap(&err, "echo server failed")

	conn, acceptErr := websocket.Accept(w, r, opts)
	if acceptErr != nil {
		return acceptErr
	}
	// Close with StatusInternalError on the way out; the result of this
	// close is not checked.
	defer conn.Close(websocket.StatusInternalError, "")

	loopErr := wstest.EchoLoop(r.Context(), conn)
	return assertCloseStatus(websocket.StatusNormalClosure, loopErr)
}