From de8e29bdb753bc55c8f742c664adb44833afbc50 Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Mon, 18 May 2020 04:25:52 -0400
Subject: [PATCH] Fix tests taking too long and switch to t.Cleanup

---
 autobahn_test.go |  7 ++++++-
 conn_test.go     | 47 +++++++++++++----------------------------------
 2 files changed, 19 insertions(+), 35 deletions(-)

diff --git a/autobahn_test.go b/autobahn_test.go
index d53159a..5bf0062 100644
--- a/autobahn_test.go
+++ b/autobahn_test.go
@@ -28,7 +28,6 @@ var excludedAutobahnCases = []string{
 
 	// We skip the tests related to requestMaxWindowBits as that is unimplemented due
 	// to limitations in compress/flate. See https://github.com/golang/go/issues/3155
-	// Same with klauspost/compress which doesn't allow adjusting the sliding window size.
 	"13.3.*", "13.4.*", "13.5.*", "13.6.*",
 }
 
@@ -41,6 +40,12 @@ func TestAutobahn(t *testing.T) {
 		t.SkipNow()
 	}
 
+	if os.Getenv("AUTOBAHN_FAST") != "" {
+		excludedAutobahnCases = append(excludedAutobahnCases,
+			"9.*", "13.*", "12.*",
+		)
+	}
+
 	ctx, cancel := context.WithTimeout(context.Background(), time.Minute*15)
 	defer cancel()
 
diff --git a/conn_test.go b/conn_test.go
index 4bab5ad..9c85459 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -49,7 +49,6 @@ func TestConn(t *testing.T) {
 					CompressionMode:      compressionMode(),
 					CompressionThreshold: xrand.Int(9999),
 				})
-				defer tt.cleanup()
 
 				tt.goEchoLoop(c2)
 
@@ -67,8 +66,9 @@ func TestConn(t *testing.T) {
 	})
 
 	t.Run("badClose", func(t *testing.T) {
-		tt, c1, _ := newConnTest(t, nil, nil)
-		defer tt.cleanup()
+		tt, c1, c2 := newConnTest(t, nil, nil)
+
+		c2.CloseRead(tt.ctx)
 
 		err := c1.Close(-1, "")
 		assert.Contains(t, err, "failed to marshal close frame: status code StatusCode(-1) cannot be set")
@@ -76,7 +76,6 @@ func TestConn(t *testing.T) {
 
 	t.Run("ping", func(t *testing.T) {
 		tt, c1, c2 := newConnTest(t, nil, nil)
-		defer tt.cleanup()
 
 		c1.CloseRead(tt.ctx)
 		c2.CloseRead(tt.ctx)
@@ -92,7 +91,6 @@ func TestConn(t *testing.T) {
 
 	t.Run("badPing", func(t *testing.T) {
 		tt, c1, c2 := newConnTest(t, nil, nil)
-		defer tt.cleanup()
 
 		c2.CloseRead(tt.ctx)
 
@@ -105,7 +103,6 @@ func TestConn(t *testing.T) {
 
 	t.Run("concurrentWrite", func(t *testing.T) {
 		tt, c1, c2 := newConnTest(t, nil, nil)
-		defer tt.cleanup()
 
 		tt.goDiscardLoop(c2)
 
@@ -138,7 +135,6 @@ func TestConn(t *testing.T) {
 
 	t.Run("concurrentWriteError", func(t *testing.T) {
 		tt, c1, _ := newConnTest(t, nil, nil)
-		defer tt.cleanup()
 
 		_, err := c1.Writer(tt.ctx, websocket.MessageText)
 		assert.Success(t, err)
@@ -152,7 +148,6 @@ func TestConn(t *testing.T) {
 
 	t.Run("netConn", func(t *testing.T) {
 		tt, c1, c2 := newConnTest(t, nil, nil)
-		defer tt.cleanup()
 
 		n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
 		n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageBinary)
@@ -192,17 +187,14 @@ func TestConn(t *testing.T) {
 
 	t.Run("netConn/BadMsg", func(t *testing.T) {
 		tt, c1, c2 := newConnTest(t, nil, nil)
-		defer tt.cleanup()
 
 		n1 := websocket.NetConn(tt.ctx, c1, websocket.MessageBinary)
 		n2 := websocket.NetConn(tt.ctx, c2, websocket.MessageText)
 
+		c2.CloseRead(tt.ctx)
 		errs := xsync.Go(func() error {
 			_, err := n2.Write([]byte("hello"))
-			if err != nil {
-				return err
-			}
-			return nil
+			return err
 		})
 
 		_, err := ioutil.ReadAll(n1)
@@ -218,7 +210,6 @@ func TestConn(t *testing.T) {
 
 	t.Run("wsjson", func(t *testing.T) {
 		tt, c1, c2 := newConnTest(t, nil, nil)
-		defer tt.cleanup()
 
 		tt.goEchoLoop(c2)
 
@@ -248,7 +239,6 @@ func TestConn(t *testing.T) {
 
 	t.Run("wspb", func(t *testing.T) {
 		tt, c1, c2 := newConnTest(t, nil, nil)
-		defer tt.cleanup()
 
 		tt.goEchoLoop(c2)
 
@@ -305,8 +295,6 @@ func assertCloseStatus(exp websocket.StatusCode, err error) error {
 type connTest struct {
 	t   testing.TB
 	ctx context.Context
-
-	doneFuncs []func()
 }
 
 func newConnTest(t testing.TB, dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) (tt *connTest, c1, c2 *websocket.Conn) {
@@ -317,30 +305,22 @@ func newConnTest(t testing.TB, dialOpts *websocket.DialOptions, acceptOpts *webs
 
 	ctx, cancel := context.WithTimeout(context.Background(), time.Second*30)
 	tt = &connTest{t: t, ctx: ctx}
-	tt.appendDone(cancel)
+	t.Cleanup(cancel)
 
 	c1, c2 = wstest.Pipe(dialOpts, acceptOpts)
 	if xrand.Bool() {
 		c1, c2 = c2, c1
 	}
-	tt.appendDone(func() {
-		c2.Close(websocket.StatusInternalError, "")
-		c1.Close(websocket.StatusInternalError, "")
+	t.Cleanup(func() {
+		// We don't actually care whether this succeeds so we just run it in a separate goroutine to avoid
+		// blocking the test shutting down.
+		go c2.Close(websocket.StatusInternalError, "")
+		go c1.Close(websocket.StatusInternalError, "")
 	})
 
 	return tt, c1, c2
 }
 
-func (tt *connTest) appendDone(f func()) {
-	tt.doneFuncs = append(tt.doneFuncs, f)
-}
-
-func (tt *connTest) cleanup() {
-	for i := len(tt.doneFuncs) - 1; i >= 0; i-- {
-		tt.doneFuncs[i]()
-	}
-}
-
 func (tt *connTest) goEchoLoop(c *websocket.Conn) {
 	ctx, cancel := context.WithCancel(tt.ctx)
 
@@ -348,7 +328,7 @@ func (tt *connTest) goEchoLoop(c *websocket.Conn) {
 		err := wstest.EchoLoop(ctx, c)
 		return assertCloseStatus(websocket.StatusNormalClosure, err)
 	})
-	tt.appendDone(func() {
+	tt.t.Cleanup(func() {
 		cancel()
 		err := <-echoLoopErr
 		if err != nil {
@@ -370,7 +350,7 @@ func (tt *connTest) goDiscardLoop(c *websocket.Conn) {
 			}
 		}
 	})
-	tt.appendDone(func() {
+	tt.t.Cleanup(func() {
 		cancel()
 		err := <-discardLoopErr
 		if err != nil {
@@ -404,7 +384,6 @@ func BenchmarkConn(b *testing.B) {
 			}, &websocket.AcceptOptions{
 				CompressionMode: bc.mode,
 			})
-			defer bb.cleanup()
 
 			bb.goEchoLoop(c2)
 
-- 
GitLab