good morning!!!!

Skip to content
Snippets Groups Projects
Unverified Commit 9c5bfabc authored by Anmol Sethi's avatar Anmol Sethi
Browse files

Simplifications of conn_test.go

parent b33d48cb
Branches
Tags
No related merge requests found
......@@ -15,7 +15,6 @@ import (
"testing"
"time"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes"
"github.com/golang/protobuf/ptypes/duration"
"golang.org/x/xerrors"
......@@ -37,153 +36,86 @@ func TestConn(t *testing.T) {
for i := 0; i < 5; i++ {
t.Run("", func(t *testing.T) {
t.Parallel()
tt := newTest(t)
defer tt.done()
copts := &websocket.CompressionOptions{
dialCopts := &websocket.CompressionOptions{
Mode: websocket.CompressionMode(xrand.Int(int(websocket.CompressionDisabled) + 1)),
Threshold: xrand.Int(9999),
}
c1, c2, err := wstest.Pipe(&websocket.DialOptions{
CompressionOptions: copts,
}, &websocket.AcceptOptions{
CompressionOptions: copts,
})
if err != nil {
t.Fatal(err)
acceptCopts := &websocket.CompressionOptions{
Mode: websocket.CompressionMode(xrand.Int(int(websocket.CompressionDisabled) + 1)),
Threshold: xrand.Int(9999),
}
defer c2.Close(websocket.StatusInternalError, "")
defer c1.Close(websocket.StatusInternalError, "")
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
echoLoopErr := xsync.Go(func() error {
err := wstest.EchoLoop(ctx, c2)
return assertCloseStatus(websocket.StatusNormalClosure, err)
c1, c2 := tt.pipe(&websocket.DialOptions{
CompressionOptions: dialCopts,
}, &websocket.AcceptOptions{
CompressionOptions: acceptCopts,
})
defer func() {
err := <-echoLoopErr
if err != nil {
t.Errorf("echo loop error: %v", err)
}
}()
defer cancel()
tt.goEchoLoop(c2)
c1.SetReadLimit(131072)
for i := 0; i < 5; i++ {
err := wstest.Echo(ctx, c1, 131072)
if err != nil {
t.Fatal(err)
}
err := wstest.Echo(tt.ctx, c1, 131072)
tt.success(err)
}
err = c1.Close(websocket.StatusNormalClosure, "")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
err := c1.Close(websocket.StatusNormalClosure, "")
tt.success(err)
})
}
})
t.Run("badClose", func(t *testing.T) {
t.Parallel()
tt := newTest(t)
defer tt.done()
c1, c2, err := wstest.Pipe(nil, nil)
if err != nil {
t.Fatal(err)
}
defer c1.Close(websocket.StatusInternalError, "")
defer c2.Close(websocket.StatusInternalError, "")
c1, _ := tt.pipe(nil, nil)
err = c1.Close(-1, "")
if !cmp.ErrorContains(err, "failed to marshal close frame: status code StatusCode(-1) cannot be set") {
t.Fatalf("unexpected error: %v", err)
}
err := c1.Close(-1, "")
tt.errContains(err, "failed to marshal close frame: status code StatusCode(-1) cannot be set")
})
t.Run("ping", func(t *testing.T) {
t.Parallel()
c1, c2, err := wstest.Pipe(nil, nil)
if err != nil {
t.Fatal(err)
}
defer c1.Close(websocket.StatusInternalError, "")
defer c2.Close(websocket.StatusInternalError, "")
tt := newTest(t)
defer tt.done()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
defer cancel()
c1, c2 := tt.pipe(nil, nil)
c2.CloseRead(ctx)
c1.CloseRead(ctx)
c1.CloseRead(tt.ctx)
c2.CloseRead(tt.ctx)
for i := 0; i < 10; i++ {
err = c1.Ping(ctx)
if err != nil {
t.Fatal(err)
}
err := c1.Ping(tt.ctx)
tt.success(err)
}
err = c1.Close(websocket.StatusNormalClosure, "")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
err := c1.Close(websocket.StatusNormalClosure, "")
tt.success(err)
})
t.Run("badPing", func(t *testing.T) {
t.Parallel()
c1, c2, err := wstest.Pipe(nil, nil)
if err != nil {
t.Fatal(err)
}
defer c1.Close(websocket.StatusInternalError, "")
defer c2.Close(websocket.StatusInternalError, "")
tt := newTest(t)
defer tt.done()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
defer cancel()
c1, c2 := tt.pipe(nil, nil)
c2.CloseRead(ctx)
c2.CloseRead(tt.ctx)
err = c1.Ping(ctx)
if !cmp.ErrorContains(err, "failed to wait for pong") {
t.Fatalf("unexpected error: %v", err)
}
err := c1.Ping(tt.ctx)
tt.errContains(err, "failed to wait for pong")
})
t.Run("concurrentWrite", func(t *testing.T) {
t.Parallel()
c1, c2, err := wstest.Pipe(nil, nil)
if err != nil {
t.Fatal(err)
}
defer c2.Close(websocket.StatusInternalError, "")
defer c1.Close(websocket.StatusInternalError, "")
ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
defer cancel()
tt := newTest(t)
defer tt.done()
discardLoopErr := xsync.Go(func() error {
for {
_, _, err := c2.Read(ctx)
if websocket.CloseStatus(err) == websocket.StatusNormalClosure {
return nil
}
if err != nil {
return err
}
}
})
defer func() {
err := <-discardLoopErr
if err != nil {
t.Errorf("discard loop error: %v", err)
}
}()
defer cancel()
c1, c2 := tt.pipe(nil, nil)
tt.goDiscardLoop(c2)
msg := xrand.Bytes(xrand.Int(9999))
const count = 100
......@@ -191,74 +123,52 @@ func TestConn(t *testing.T) {
for i := 0; i < count; i++ {
go func() {
errs <- c1.Write(ctx, websocket.MessageBinary, msg)
errs <- c1.Write(tt.ctx, websocket.MessageBinary, msg)
}()
}
for i := 0; i < count; i++ {
err := <-errs
if err != nil {
t.Fatal(err)
}
tt.success(err)
}
err = c1.Close(websocket.StatusNormalClosure, "")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
err := c1.Close(websocket.StatusNormalClosure, "")
tt.success(err)
})
t.Run("concurrentWriteError", func(t *testing.T) {
t.Parallel()
tt := newTest(t)
defer tt.done()
c1, c2, err := wstest.Pipe(nil, nil)
if err != nil {
t.Fatal(err)
}
defer c2.Close(websocket.StatusInternalError, "")
defer c1.Close(websocket.StatusInternalError, "")
c1, _ := tt.pipe(nil, nil)
_, err = c1.Writer(context.Background(), websocket.MessageText)
if err != nil {
t.Fatal(err)
}
_, err := c1.Writer(tt.ctx, websocket.MessageText)
tt.success(err)
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*100)
defer cancel()
err = c1.Write(ctx, websocket.MessageText, []byte("x"))
if !xerrors.Is(err, context.DeadlineExceeded) {
t.Fatal(err)
}
tt.eq(context.DeadlineExceeded, err)
})
t.Run("netConn", func(t *testing.T) {
t.Parallel()
c1, c2, err := wstest.Pipe(nil, nil)
if err != nil {
t.Fatal(err)
}
defer c2.Close(websocket.StatusInternalError, "")
defer c1.Close(websocket.StatusInternalError, "")
tt := newTest(t)
defer tt.done()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
defer cancel()
c1, c2 := tt.pipe(nil, nil)
n1 := websocket.NetConn(ctx, c1, websocket.MessageBinary)
n2 := websocket.NetConn(ctx, c2, websocket.MessageBinary)
n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary)
// Does not give any confidence but at least ensures no crashes.
d, _ := ctx.Deadline()
d, _ := tt.ctx.Deadline()
n1.SetDeadline(d)
n1.SetDeadline(time.Time{})
if n1.RemoteAddr() != n1.LocalAddr() {
t.Fatal()
}
if n1.RemoteAddr().String() != "websocket/unknown-addr" || n1.RemoteAddr().Network() != "websocket" {
t.Fatal(n1.RemoteAddr())
}
tt.eq(n1.RemoteAddr(), n1.LocalAddr())
tt.eq("websocket/unknown-addr", n1.RemoteAddr().String())
tt.eq("websocket", n1.RemoteAddr().Network())
errs := xsync.Go(func() error {
_, err := n2.Write([]byte("hello"))
......@@ -269,40 +179,25 @@ func TestConn(t *testing.T) {
})
b, err := ioutil.ReadAll(n1)
if err != nil {
t.Fatal(err)
}
tt.success(err)
_, err = n1.Read(nil)
if err != io.EOF {
t.Fatalf("expected EOF: %v", err)
}
tt.eq(err, io.EOF)
err = <-errs
if err != nil {
t.Fatal(err)
}
tt.success(err)
if !cmp.Equal([]byte("hello"), b) {
t.Fatalf("unexpected msg: %v", cmp.Diff([]byte("hello"), b))
}
tt.eq([]byte("hello"), b)
})
t.Run("netConn/BadMsg", func(t *testing.T) {
t.Parallel()
c1, c2, err := wstest.Pipe(nil, nil)
if err != nil {
t.Fatal(err)
}
defer c2.Close(websocket.StatusInternalError, "")
defer c1.Close(websocket.StatusInternalError, "")
tt := newTest(t)
defer tt.done()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
defer cancel()
c1, c2 := tt.pipe(nil, nil)
n1 := websocket.NetConn(ctx, c1, websocket.MessageBinary)
n2 := websocket.NetConn(ctx, c2, websocket.MessageText)
n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageText)
errs := xsync.Go(func() error {
_, err := n2.Write([]byte("hello"))
......@@ -312,114 +207,60 @@ func TestConn(t *testing.T) {
return nil
})
_, err = ioutil.ReadAll(n1)
if !cmp.ErrorContains(err, `unexpected frame type read (expected MessageBinary): MessageText`) {
t.Fatal(err)
}
_, err := ioutil.ReadAll(n1)
tt.errContains(err, `unexpected frame type read (expected MessageBinary): MessageText`)
err = <-errs
if err != nil {
t.Fatal(err)
}
tt.success(err)
})
t.Run("wsjson", func(t *testing.T) {
t.Parallel()
c1, c2, err := wstest.Pipe(nil, nil)
if err != nil {
t.Fatal(err)
}
defer c2.Close(websocket.StatusInternalError, "")
defer c1.Close(websocket.StatusInternalError, "")
tt := newTest(t)
defer tt.done()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
defer cancel()
c1, c2 := tt.pipe(nil, nil)
echoLoopErr := xsync.Go(func() error {
err := wstest.EchoLoop(ctx, c2)
return assertCloseStatus(websocket.StatusNormalClosure, err)
})
defer func() {
err := <-echoLoopErr
if err != nil {
t.Errorf("echo loop error: %v", err)
}
}()
defer cancel()
tt.goEchoLoop(c2)
c1.SetReadLimit(1 << 30)
exp := xrand.String(xrand.Int(131072))
werr := xsync.Go(func() error {
return wsjson.Write(ctx, c1, exp)
return wsjson.Write(tt.ctx, c1, exp)
})
var act interface{}
err = wsjson.Read(ctx, c1, &act)
if err != nil {
t.Fatal(err)
}
if exp != act {
t.Fatal(cmp.Diff(exp, act))
}
err := wsjson.Read(tt.ctx, c1, &act)
tt.success(err)
tt.eq(exp, act)
err = <-werr
if err != nil {
t.Fatal(err)
}
tt.success(err)
err = c1.Close(websocket.StatusNormalClosure, "")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
tt.success(err)
})
t.Run("wspb", func(t *testing.T) {
t.Parallel()
c1, c2, err := wstest.Pipe(nil, nil)
if err != nil {
t.Fatal(err)
}
defer c2.Close(websocket.StatusInternalError, "")
defer c1.Close(websocket.StatusInternalError, "")
tt := newTest(t)
defer tt.done()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*15)
defer cancel()
c1, c2 := tt.pipe(nil, nil)
echoLoopErr := xsync.Go(func() error {
err := wstest.EchoLoop(ctx, c2)
return assertCloseStatus(websocket.StatusNormalClosure, err)
})
defer func() {
err := <-echoLoopErr
if err != nil {
t.Errorf("echo loop error: %v", err)
}
}()
defer cancel()
tt.goEchoLoop(c2)
exp := ptypes.DurationProto(100)
err = wspb.Write(ctx, c1, exp)
if err != nil {
t.Fatal(err)
}
err := wspb.Write(tt.ctx, c1, exp)
tt.success(err)
act := &duration.Duration{}
err = wspb.Read(ctx, c1, act)
if err != nil {
t.Fatal(err)
}
if !proto.Equal(exp, act) {
t.Fatal(cmp.Diff(exp, act))
}
err = wspb.Read(tt.ctx, c1, act)
tt.success(err)
tt.eq(exp, act)
err = c1.Close(websocket.StatusNormalClosure, "")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
tt.success(err)
})
}
......@@ -443,7 +284,7 @@ func TestWasm(t *testing.T) {
err = wstest.EchoLoop(r.Context(), c)
if websocket.CloseStatus(err) != websocket.StatusNormalClosure {
t.Errorf("echoLoop: %v", err)
t.Errorf("echoLoop failed: %v", err)
}
}))
defer wg.Wait()
......@@ -470,3 +311,103 @@ func assertCloseStatus(exp websocket.StatusCode, err error) error {
}
return nil
}
type test struct {
t *testing.T
ctx context.Context
doneFuncs []func()
}
func newTest(t *testing.T) *test {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
tt := &test{t: t, ctx: ctx}
tt.appendDone(cancel)
return tt
}
func (tt *test) appendDone(f func()) {
tt.doneFuncs = append(tt.doneFuncs, f)
}
func (tt *test) done() {
for i := len(tt.doneFuncs) - 1; i >= 0; i-- {
tt.doneFuncs[i]()
}
}
func (tt *test) goEchoLoop(c *websocket.Conn) {
ctx, cancel := context.WithCancel(tt.ctx)
echoLoopErr := xsync.Go(func() error {
err := wstest.EchoLoop(ctx, c)
return assertCloseStatus(websocket.StatusNormalClosure, err)
})
tt.appendDone(func() {
cancel()
err := <-echoLoopErr
if err != nil {
tt.t.Errorf("echo loop error: %v", err)
}
})
}
func (tt *test) goDiscardLoop(c *websocket.Conn) {
ctx, cancel := context.WithCancel(tt.ctx)
discardLoopErr := xsync.Go(func() error {
for {
_, _, err := c.Read(ctx)
if websocket.CloseStatus(err) == websocket.StatusNormalClosure {
return nil
}
if err != nil {
return err
}
}
})
tt.appendDone(func() {
cancel()
err := <-discardLoopErr
if err != nil {
tt.t.Errorf("discard loop error: %v", err)
}
})
}
func (tt *test) pipe(dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (c1, c2 *websocket.Conn) {
tt.t.Helper()
c1, c2, err := wstest.Pipe(dialOpts, acceptOpts)
if err != nil {
tt.t.Fatal(err)
}
tt.appendDone(func() {
c2.Close(websocket.StatusInternalError, "")
c1.Close(websocket.StatusInternalError, "")
})
return c1, c2
}
func (tt *test) success(err error) {
tt.t.Helper()
if err != nil {
tt.t.Fatal(err)
}
}
func (tt *test) errContains(err error, sub string) {
tt.t.Helper()
if !cmp.ErrorContains(err, sub) {
tt.t.Fatalf("error does not contain %q: %v", sub, err)
}
}
func (tt *test) eq(exp, act interface{}) {
tt.t.Helper()
if !cmp.Equal(exp, act) {
tt.t.Fatalf(cmp.Diff(exp, act))
}
}
......@@ -4,6 +4,7 @@ import (
"reflect"
"strings"
"github.com/golang/protobuf/proto"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
)
......@@ -12,7 +13,7 @@ import (
func Equal(v1, v2 interface{}) bool {
return cmp.Equal(v1, v2, cmpopts.EquateErrors(), cmp.Exporter(func(r reflect.Type) bool {
return true
}))
}), cmp.Comparer(proto.Equal))
}
// Diff returns a human readable diff between v1 and v2
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment