From d09268649e33ce5b3afde49006d39508a28cbe12 Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Sat, 8 Feb 2020 15:29:08 -0500
Subject: [PATCH] Autobahn tests fully pass :)

---
 assert_test.go   |  15 ----
 autobahn_test.go |  76 ++------------------
 conn.go          |   2 +-
 conn_test.go     | 177 ++++++++++++++++++++++++++---------------------
 read.go          |   6 +-
 write.go         |  31 +++++----
 6 files changed, 127 insertions(+), 180 deletions(-)

diff --git a/assert_test.go b/assert_test.go
index 3727d99..22814e3 100644
--- a/assert_test.go
+++ b/assert_test.go
@@ -5,7 +5,6 @@ import (
 	"crypto/rand"
 	"fmt"
 	"net/http"
-	"net/http/httptest"
 	"strings"
 	"testing"
 	"time"
@@ -108,20 +107,6 @@ func acceptWebSocket(t testing.TB, r *http.Request, w http.ResponseWriter, opts
 	return c
 }
 
-func dialWebSocket(t testing.TB, s *httptest.Server, opts *websocket.DialOptions) (*websocket.Conn, *http.Response) {
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
-	defer cancel()
-
-	if opts == nil {
-		opts = &websocket.DialOptions{}
-	}
-	opts.HTTPClient = s.Client()
-
-	c, resp, err := websocket.Dial(ctx, wsURL(s), opts)
-	assert.Success(t, "websocket.Dial", err)
-	return c, resp
-}
-
 func slogType(v interface{}) slog.Field {
 	return slog.F("type", fmt.Sprintf("%T", v))
 }
diff --git a/autobahn_test.go b/autobahn_test.go
index dd9887f..71d22be 100644
--- a/autobahn_test.go
+++ b/autobahn_test.go
@@ -8,9 +8,6 @@ import (
 	"fmt"
 	"io/ioutil"
 	"net"
-	"net/http"
-	"net/http/httptest"
-	"os"
 	"os/exec"
 	"strconv"
 	"strings"
@@ -32,69 +29,14 @@ var excludedAutobahnCases = []string{
 	// We skip the tests related to requestMaxWindowBits as that is unimplemented due
 	// to limitations in compress/flate. See https://github.com/golang/go/issues/3155
 	"13.3.*", "13.4.*", "13.5.*", "13.6.*",
-
-	"12.*",
-	"13.*",
 }
 
 var autobahnCases = []string{"*"}
 
-// https://github.com/crossbario/autobahn-python/tree/master/wstest
 func TestAutobahn(t *testing.T) {
 	t.Parallel()
 
-	if os.Getenv("AUTOBAHN") == "" {
-		t.Skip("Set $AUTOBAHN to run tests against the autobahn test suite")
-	}
-
-	t.Run("server", testServerAutobahn)
-	t.Run("client", testClientAutobahn)
-}
-
-func testServerAutobahn(t *testing.T) {
-	t.Parallel()
-
-	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		c := acceptWebSocket(t, r, w, &websocket.AcceptOptions{
-			Subprotocols: []string{"echo"},
-		})
-		err := echoLoop(r.Context(), c)
-		assertCloseStatus(t, websocket.StatusNormalClosure, err)
-	}))
-	closeFn := wsgrace(s.Config)
-	defer func() {
-		err := closeFn()
-		assert.Success(t, "closeFn", err)
-	}()
-
-	specFile, err := tempJSONFile(map[string]interface{}{
-		"outdir": "ci/out/wstestServerReports",
-		"servers": []interface{}{
-			map[string]interface{}{
-				"agent": "main",
-				"url":   strings.Replace(s.URL, "http", "ws", 1),
-			},
-		},
-		"cases":         autobahnCases,
-		"exclude-cases": excludedAutobahnCases,
-	})
-	assert.Success(t, "tempJSONFile", err)
-
-	ctx, cancel := context.WithTimeout(context.Background(), time.Minute*10)
-	defer cancel()
-
-	args := []string{"--mode", "fuzzingclient", "--spec", specFile}
-	wstest := exec.CommandContext(ctx, "wstest", args...)
-	_, err = wstest.CombinedOutput()
-	assert.Success(t, "wstest", err)
-
-	checkWSTestIndex(t, "./ci/out/wstestServerReports/index.json")
-}
-
-func testClientAutobahn(t *testing.T) {
-	t.Parallel()
-
-	ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
+	ctx, cancel := context.WithTimeout(context.Background(), time.Minute*15)
 	defer cancel()
 
 	wstestURL, closeFn, err := wstestClientServer(ctx)
@@ -108,27 +50,17 @@ func testClientAutobahn(t *testing.T) {
 	assert.Success(t, "wstestCaseCount", err)
 
 	t.Run("cases", func(t *testing.T) {
-		// Max 8 cases running at a time.
-		mu := make(chan struct{}, 8)
-
 		for i := 1; i <= cases; i++ {
 			i := i
 			t.Run("", func(t *testing.T) {
-				t.Parallel()
-
-				mu <- struct{}{}
-				defer func() {
-					<-mu
-				}()
-
-				ctx, cancel := context.WithTimeout(ctx, time.Second*45)
+				ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
 				defer cancel()
 
 				c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/runCase?case=%v&agent=main", i), nil)
 				assert.Success(t, "autobahn dial", err)
 
 				err = echoLoop(ctx, c)
-				t.Logf("echoLoop: %+v", err)
+				t.Logf("echoLoop: %v", err)
 			})
 		}
 	})
@@ -174,7 +106,7 @@ func wstestClientServer(ctx context.Context) (url string, closeFn func(), err er
 		return "", nil, xerrors.Errorf("failed to write spec: %w", err)
 	}
 
-	ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
+	ctx, cancel := context.WithTimeout(context.Background(), time.Minute*15)
 	defer func() {
 		if err != nil {
 			cancel()
diff --git a/conn.go b/conn.go
index 2d36123..163802b 100644
--- a/conn.go
+++ b/conn.go
@@ -99,7 +99,7 @@ func newConn(cfg connConfig) *Conn {
 		closed:      make(chan struct{}),
 		activePings: make(map[string]chan<- struct{}),
 	}
-	if c.flateThreshold == 0 {
+	if c.flate() && c.flateThreshold == 0 {
 		c.flateThreshold = 256
 		if c.writeNoContextTakeOver() {
 			c.flateThreshold = 512
diff --git a/conn_test.go b/conn_test.go
index 6f6b8d5..aceac3f 100644
--- a/conn_test.go
+++ b/conn_test.go
@@ -3,99 +3,70 @@
 package websocket_test
 
 import (
+	"bufio"
 	"context"
 	"crypto/rand"
 	"io"
 	"math/big"
+	"net"
 	"net/http"
 	"net/http/httptest"
-	"strings"
-	"sync/atomic"
 	"testing"
 	"time"
 
 	"cdr.dev/slog/sloggers/slogtest/assert"
-	"golang.org/x/xerrors"
 
 	"nhooyr.io/websocket"
 )
 
+func goFn(fn func()) func() {
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		fn()
+	}()
+
+	return func() {
+		<-done
+	}
+}
+
 func TestConn(t *testing.T) {
 	t.Parallel()
 
 	t.Run("json", func(t *testing.T) {
 		t.Parallel()
 
-		s, closeFn := testEchoLoop(t)
-		defer closeFn()
+		for i := 0; i < 1; i++ {
+			t.Run("", func(t *testing.T) {
+				ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+				defer cancel()
 
-		c, _ := dialWebSocket(t, s, nil)
-		defer c.Close(websocket.StatusInternalError, "")
-
-		c.SetReadLimit(1 << 30)
-
-		for i := 0; i < 10; i++ {
-			n := randInt(t, 1_048_576)
-			echoJSON(t, c, n)
-		}
+				c1, c2 := websocketPipe(t)
 
-		c.Close(websocket.StatusNormalClosure, "")
-	})
-}
-
-func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request)) (s *httptest.Server, closeFn func()) {
-	h := http.HandlerFunc(fn)
-	if randInt(tb, 2) == 1 {
-		s = httptest.NewTLSServer(h)
-	} else {
-		s = httptest.NewServer(h)
-	}
-	closeFn2 := wsgrace(s.Config)
-	return s, func() {
-		err := closeFn2()
-		assert.Success(tb, "closeFn", err)
-	}
-}
+				wait := goFn(func() {
+					err := echoLoop(ctx, c1)
+					assertCloseStatus(t, websocket.StatusNormalClosure, err)
+				})
+				defer wait()
 
-// grace wraps s.Handler to gracefully shutdown WebSocket connections.
-// The returned function must be used to close the server instead of s.Close.
-func wsgrace(s *http.Server) (closeFn func() error) {
-	h := s.Handler
-	var conns int64
-	s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		atomic.AddInt64(&conns, 1)
-		defer atomic.AddInt64(&conns, -1)
+				c2.SetReadLimit(1 << 30)
 
-		ctx, cancel := context.WithTimeout(r.Context(), time.Second*5)
-		defer cancel()
-
-		r = r.WithContext(ctx)
+				for i := 0; i < 10; i++ {
+					n := randInt(t, 131_072)
+					echoJSON(t, c2, n)
+				}
 
-		h.ServeHTTP(w, r)
+				c2.Close(websocket.StatusNormalClosure, "")
+			})
+		}
 	})
+}
 
-	return func() error {
-		ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
-		defer cancel()
-
-		err := s.Shutdown(ctx)
-		if err != nil {
-			return xerrors.Errorf("server shutdown failed: %v", err)
-		}
+type writerFunc func(p []byte) (int, error)
 
-		t := time.NewTicker(time.Millisecond * 10)
-		defer t.Stop()
-		for {
-			select {
-			case <-t.C:
-				if atomic.LoadInt64(&conns) == 0 {
-					return nil
-				}
-			case <-ctx.Done():
-				return xerrors.Errorf("failed to wait for WebSocket connections: %v", ctx.Err())
-			}
-		}
-	}
+func (f writerFunc) Write(p []byte) (int, error) {
+	return f(p)
 }
 
 // echoLoop echos every msg received from c until an error
@@ -133,18 +104,8 @@ func echoLoop(ctx context.Context, c *websocket.Conn) error {
 	}
 }
 
-func wsURL(s *httptest.Server) string {
-	return strings.Replace(s.URL, "http", "ws", 1)
-}
-
-func testEchoLoop(t testing.TB) (*httptest.Server, func()) {
-	return testServer(t, func(w http.ResponseWriter, r *http.Request) {
-		c := acceptWebSocket(t, r, w, nil)
-		defer c.Close(websocket.StatusInternalError, "")
-
-		err := echoLoop(r.Context(), c)
-		assertCloseStatus(t, websocket.StatusNormalClosure, err)
-	})
+func randBool(t testing.TB) bool  {
+	return randInt(t, 2) == 1
 }
 
 func randInt(t testing.TB, max int) int {
@@ -152,3 +113,65 @@ func randInt(t testing.TB, max int) int {
 	assert.Success(t, "rand.Int", err)
 	return int(x.Int64())
 }
+
+type testHijacker struct {
+	*httptest.ResponseRecorder
+	serverConn net.Conn
+	hijacked chan struct{}
+}
+
+var _ http.Hijacker = testHijacker{}
+
+func (hj testHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
+	close(hj.hijacked)
+	return hj.serverConn, bufio.NewReadWriter(bufio.NewReader(hj.serverConn), bufio.NewWriter(hj.serverConn)), nil
+}
+
+func websocketPipe(t *testing.T) (*websocket.Conn, *websocket.Conn) {
+	var serverConn *websocket.Conn
+	tt := testTransport{
+		h: func(w http.ResponseWriter, r *http.Request) {
+			serverConn = acceptWebSocket(t, r, w, nil)
+		},
+	}
+
+	dialOpts := &websocket.DialOptions{
+		HTTPClient: &http.Client{
+			Transport: tt,
+		},
+	}
+
+	clientConn, _, err := websocket.Dial(context.Background(), "ws://example.com", dialOpts)
+	assert.Success(t, "websocket.Dial", err)
+
+	if randBool(t) {
+		return serverConn, clientConn
+	}
+	return clientConn, serverConn
+}
+
+type testTransport struct {
+	h http.HandlerFunc
+}
+
+func (t testTransport) RoundTrip(r *http.Request) (*http.Response, error)  {
+	clientConn, serverConn := net.Pipe()
+
+	hj := testHijacker{
+		ResponseRecorder: httptest.NewRecorder(),
+		serverConn:       serverConn,
+		hijacked:         make(chan struct{}),
+	}
+
+	done := make(chan struct{})
+	t.h.ServeHTTP(hj, r)
+
+	select {
+	case <-hj.hijacked:
+		resp := hj.ResponseRecorder.Result()
+		resp.Body = clientConn
+		return resp, nil
+	case <-done:
+		return hj.ResponseRecorder.Result(), nil
+	}
+}
diff --git a/read.go b/read.go
index 7e74894..b681a94 100644
--- a/read.go
+++ b/read.go
@@ -84,7 +84,7 @@ func newMsgReader(c *Conn) *msgReader {
 	return mr
 }
 
-func (mr *msgReader) ensureFlate() {
+func (mr *msgReader) resetFlate() {
 	if mr.flateContextTakeover() && mr.dict == nil {
 		mr.dict = newSlidingWindow(32768)
 	}
@@ -332,7 +332,7 @@ func (mr *msgReader) reset(ctx context.Context, h header) {
 	mr.limitReader.reset(readerFunc(mr.read))
 
 	if mr.flate {
-		mr.ensureFlate()
+		mr.resetFlate()
 	}
 
 	mr.setFrame(h)
@@ -362,7 +362,7 @@ func (mr *msgReader) Read(p []byte) (n int, err error) {
 	defer mr.c.readMu.Unlock()
 
 	n, err = mr.limitReader.Read(p)
-	if mr.flateContextTakeover() {
+	if mr.flate && mr.flateContextTakeover() {
 		p = p[:n]
 		mr.dict.write(p)
 	}
diff --git a/write.go b/write.go
index 3454348..70656b9 100644
--- a/write.go
+++ b/write.go
@@ -70,17 +70,17 @@ func newMsgWriter(c *Conn) *msgWriter {
 }
 
 func (mw *msgWriter) ensureFlate() {
-	if mw.flateWriter == nil {
-		if mw.trimWriter == nil {
-			mw.trimWriter = &trimLastFourBytesWriter{
-				w: writerFunc(mw.write),
-			}
+	if mw.trimWriter == nil {
+		mw.trimWriter = &trimLastFourBytesWriter{
+			w: writerFunc(mw.write),
 		}
-		mw.trimWriter.reset()
+	}
 
+	if mw.flateWriter == nil {
 		mw.flateWriter = getFlateWriter(mw.trimWriter)
-		mw.flate = true
 	}
+
+	mw.flate = true
 }
 
 func (mw *msgWriter) flateContextTakeover() bool {
@@ -128,6 +128,11 @@ func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {
 	mw.ctx = ctx
 	mw.opcode = opcode(typ)
 	mw.flate = false
+
+	if mw.trimWriter != nil {
+		mw.trimWriter.reset()
+	}
+
 	return nil
 }
 
@@ -146,9 +151,8 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) {
 		return 0, xerrors.New("cannot use closed writer")
 	}
 
-	// TODO can make threshold detection robust across writes by writing to bufio writer
-	if mw.flate ||
-		mw.c.flate() && len(p) >= mw.c.flateThreshold {
+	// TODO Write to buffer to detect whether to enable flate or not for this message.
+	if mw.c.flate() {
 		mw.ensureFlate()
 		return mw.flateWriter.Write(p)
 	}
@@ -172,7 +176,6 @@ func (mw *msgWriter) Close() (err error) {
 	if mw.closed {
 		return xerrors.New("cannot use closed writer")
 	}
-	mw.closed = true
 
 	if mw.flate {
 		err = mw.flateWriter.Flush()
@@ -181,12 +184,16 @@ func (mw *msgWriter) Close() (err error) {
 		}
 	}
 
+	// We set closed after flushing the flate writer to ensure Write
+	// can succeed.
+	mw.closed = true
+
 	_, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil)
 	if err != nil {
 		return xerrors.Errorf("failed to write fin frame: %w", err)
 	}
 
-	if mw.c.flate() && !mw.flateContextTakeover() {
+	if mw.flate && !mw.flateContextTakeover() {
 		mw.returnFlateWriter()
 	}
 	mw.mu.Unlock()
-- 
GitLab