Newer
Older
"net/http"
"net/http/httptest"
"os"
"os/exec"
"github.com/golang/protobuf/ptypes"
"github.com/golang/protobuf/ptypes/duration"
"nhooyr.io/websocket/internal/test/wstest"
"nhooyr.io/websocket/internal/test/xrand"
compressionMode := func() websocket.CompressionMode {
return websocket.CompressionMode(xrand.Int(int(websocket.CompressionContextTakeover) + 1))
tt, c1, c2 := newConnTest(t, &websocket.DialOptions{
CompressionMode: compressionMode(),
CompressionThreshold: xrand.Int(9999),
CompressionMode: compressionMode(),
CompressionThreshold: xrand.Int(9999),
err := c1.Close(websocket.StatusNormalClosure, "")
tt, c1, c2 := newConnTest(t, nil, nil)
c2.CloseRead(tt.ctx)
assert.Contains(t, err, "failed to marshal close frame: status code StatusCode(-1) cannot be set")
c1.CloseRead(tt.ctx)
c2.CloseRead(tt.ctx)
err := c1.Close(websocket.StatusNormalClosure, "")
ctx, cancel := context.WithTimeout(tt.ctx, time.Millisecond*100)
defer cancel()
err := c1.Ping(ctx)
assert.Contains(t, err, "failed to wait for pong")
})
t.Run("concurrentWrite", func(t *testing.T) {
msg := xrand.Bytes(xrand.Int(9999))
const count = 100
errs := make(chan error, count)
for i := 0; i < count; i++ {
go func() {
select {
case errs <- c1.Write(tt.ctx, websocket.MessageBinary, msg):
case <-tt.ctx.Done():
return
}
select {
case err := <-errs:
assert.Success(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
err := c1.Close(websocket.StatusNormalClosure, "")
t.Run("concurrentWriteError", func(t *testing.T) {
_, err := c1.Writer(tt.ctx, websocket.MessageText)
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
defer cancel()
err = c1.Write(ctx, websocket.MessageText, []byte("x"))
assert.Equal(t, "write error", context.DeadlineExceeded, err)
n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary)
// Does not give any confidence but at least ensures no crashes.
assert.Equal(t, "remote addr", n1.RemoteAddr(), n1.LocalAddr())
assert.Equal(t, "remote addr string", "websocket/unknown-addr", n1.RemoteAddr().String())
assert.Equal(t, "remote addr network", "websocket", n1.RemoteAddr().Network())
errs := xsync.Go(func() error {
_, err := n2.Write([]byte("hello"))
if err != nil {
return err
}
return n2.Close()
})
b, err := ioutil.ReadAll(n1)
select {
case err := <-errs:
assert.Success(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageText)
errs := xsync.Go(func() error {
_, err := n2.Write([]byte("hello"))
assert.Contains(t, err, `unexpected frame type read (expected MessageBinary): MessageText`)
select {
case err := <-errs:
assert.Success(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
t.Run("netConn/readLimit", func(t *testing.T) {
	tt, c1, c2 := newConnTest(t, nil, nil)

	n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
	n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary)

	// One 4 MiB payload ("papa" repeated 1<<20 times), written from n2 and
	// read back in full through n1.
	// NOTE(review): the size is presumably chosen to exceed the Conn's default
	// read limit, so this checks NetConn reads are not capped by it — confirm.
	s := strings.Repeat("papa", 1<<20)
	errs := xsync.Go(func() error {
		_, err := n2.Write([]byte(s))
		if err != nil {
			return err
		}
		return n2.Close()
	})

	b, err := ioutil.ReadAll(n1)
	assert.Success(t, err)

	// After the stream has been drained, a further Read is expected to
	// report io.EOF. Arguments follow the (name, expected, actual) order
	// used by every other assert.Equal call in this file.
	_, err = n1.Read(nil)
	assert.Equal(t, "read error", io.EOF, err)

	select {
	case err := <-errs:
		assert.Success(t, err)
	case <-tt.ctx.Done():
		t.Fatal(tt.ctx.Err())
	}

	assert.Equal(t, "read msg", s, string(b))
})
assert.Success(t, err)
assert.Equal(t, "read msg", exp, act)
select {
case err := <-werr:
assert.Success(t, err)
case <-tt.ctx.Done():
t.Fatal(tt.ctx.Err())
}
assert.Success(t, err)
assert.Equal(t, "read msg", exp, act)
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Subprotocols: []string{"echo"},
InsecureSkipVerify: true,
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", ".")
cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", s.URL))
b, err := cmd.CombinedOutput()
if err != nil {
t.Fatalf("wasm test binary failed: %v:\n%s", err, b)
func assertCloseStatus(exp websocket.StatusCode, err error) error {
if websocket.CloseStatus(err) == -1 {
return fmt.Errorf("expected websocket.CloseError: %T %v", err, err)
return fmt.Errorf("expected close status %v but got %v", exp, err)
func newConnTest(t testing.TB, dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (tt *connTest, c1, c2 *websocket.Conn) {
if t, ok := t.(*testing.T); ok {
t.Parallel()
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
c1, c2 = wstest.Pipe(dialOpts, acceptOpts)
if xrand.Bool() {
c1, c2 = c2, c1
}
t.Cleanup(func() {
// We don't actually care whether this succeeds so we just run it in a separate goroutine to avoid
// blocking the test shutting down.
go c2.Close(websocket.StatusInternalError, "")
go c1.Close(websocket.StatusInternalError, "")
func (tt *connTest) goEchoLoop(c *websocket.Conn) {
ctx, cancel := context.WithCancel(tt.ctx)
echoLoopErr := xsync.Go(func() error {
err := wstest.EchoLoop(ctx, c)
return assertCloseStatus(websocket.StatusNormalClosure, err)
})
cancel()
err := <-echoLoopErr
if err != nil {
tt.t.Errorf("echo loop error: %v", err)
}
})
}
func (tt *connTest) goDiscardLoop(c *websocket.Conn) {
ctx, cancel := context.WithCancel(tt.ctx)
discardLoopErr := xsync.Go(func() error {
defer c.Close(websocket.StatusInternalError, "")
for {
_, _, err := c.Read(ctx)
if err != nil {
return assertCloseStatus(websocket.StatusNormalClosure, err)
cancel()
err := <-discardLoopErr
if err != nil {
tt.t.Errorf("discard loop error: %v", err)
}
})
}
func BenchmarkConn(b *testing.B) {
var benchCases = []struct {
name string
mode websocket.CompressionMode
}{
{
mode: websocket.CompressionContextTakeover,
},
{
mode: websocket.CompressionNoContextTakeover,
},
}
for _, bc := range benchCases {
b.Run(bc.name, func(b *testing.B) {
bb, c1, c2 := newConnTest(b, &websocket.DialOptions{
bytesWritten := c1.RecordBytesWritten()
bytesRead := c1.RecordBytesRead()
msg := []byte(strings.Repeat("1234", 128))
readBuf := make([]byte, len(msg))
writes := make(chan struct{})
defer close(writes)
werrs := make(chan error)
go func() {
select {
case werrs <- c1.Write(bb.ctx, websocket.MessageText, msg):
case <-bb.ctx.Done():
return
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
select {
case writes <- struct{}{}:
case <-bb.ctx.Done():
b.Fatal(bb.ctx.Err())
}
typ, r, err := c1.Reader(bb.ctx)
if err != nil {
b.Fatal(err)
}
if websocket.MessageText != typ {
assert.Equal(b, "data type", websocket.MessageText, typ)
}
_, err = io.ReadFull(r, readBuf)
if err != nil {
b.Fatal(err)
}
n2, err := r.Read(readBuf)
if err != io.EOF {
assert.Equal(b, "read err", io.EOF, err)
}
if n2 != 0 {
assert.Equal(b, "n2", 0, n2)
}
if !bytes.Equal(msg, readBuf) {
assert.Equal(b, "msg", msg, readBuf)
select {
case err = <-werrs:
case <-bb.ctx.Done():
b.Fatal(bb.ctx.Err())
}
if err != nil {
b.Fatal(err)
}
}
b.StopTimer()
b.ReportMetric(float64(*bytesWritten/b.N), "written/op")
b.ReportMetric(float64(*bytesRead/b.N), "read/op")
err := c1.Close(websocket.StatusNormalClosure, "")
assert.Success(b, err)
})
}
}
// echoServer accepts a WebSocket handshake on w/r using opts and runs
// wstest.EchoLoop on the resulting connection with the request's context.
// The loop's result is checked by assertCloseStatus against
// StatusNormalClosure, and whatever that returns becomes the result. Errors
// are passed through errd.Wrap with the message "echo server failed".
func echoServer(w http.ResponseWriter, r *http.Request, opts *websocket.AcceptOptions) (err error) {
	defer errd.Wrap(&err, "echo server failed")

	conn, err := websocket.Accept(w, r, opts)
	if err != nil {
		return err
	}
	// Deferred close; its error is intentionally ignored.
	defer conn.Close(websocket.StatusInternalError, "")

	loopErr := wstest.EchoLoop(r.Context(), conn)
	return assertCloseStatus(websocket.StatusNormalClosure, loopErr)
}
func TestGin(t *testing.T) {
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
gin.SetMode(gin.ReleaseMode)
r := gin.New()
r.GET("/", func(ginCtx *gin.Context) {
err := echoServer(ginCtx.Writer, ginCtx.Request, nil)
if err != nil {
t.Error(err)
}
})
s := httptest.NewServer(r)
defer s.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
defer cancel()
c, _, err := websocket.Dial(ctx, s.URL, nil)
assert.Success(t, err)
defer c.Close(websocket.StatusInternalError, "")
err = wsjson.Write(ctx, c, "hello")
assert.Success(t, err)
var v interface{}
err = wsjson.Read(ctx, c, &v)
assert.Success(t, err)
assert.Equal(t, "read msg", "hello", v)
err = c.Close(websocket.StatusNormalClosure, "")
assert.Success(t, err)