From d09268649e33ce5b3afde49006d39508a28cbe12 Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Sat, 8 Feb 2020 15:29:08 -0500 Subject: [PATCH] Autobahn tests fully pass :) --- assert_test.go | 15 ---- autobahn_test.go | 76 ++------------------ conn.go | 2 +- conn_test.go | 177 ++++++++++++++++++++++++++--------------------- read.go | 6 +- write.go | 31 +++++---- 6 files changed, 127 insertions(+), 180 deletions(-) diff --git a/assert_test.go b/assert_test.go index 3727d99..22814e3 100644 --- a/assert_test.go +++ b/assert_test.go @@ -5,7 +5,6 @@ import ( "crypto/rand" "fmt" "net/http" - "net/http/httptest" "strings" "testing" "time" @@ -108,20 +107,6 @@ func acceptWebSocket(t testing.TB, r *http.Request, w http.ResponseWriter, opts return c } -func dialWebSocket(t testing.TB, s *httptest.Server, opts *websocket.DialOptions) (*websocket.Conn, *http.Response) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - - if opts == nil { - opts = &websocket.DialOptions{} - } - opts.HTTPClient = s.Client() - - c, resp, err := websocket.Dial(ctx, wsURL(s), opts) - assert.Success(t, "websocket.Dial", err) - return c, resp -} - func slogType(v interface{}) slog.Field { return slog.F("type", fmt.Sprintf("%T", v)) } diff --git a/autobahn_test.go b/autobahn_test.go index dd9887f..71d22be 100644 --- a/autobahn_test.go +++ b/autobahn_test.go @@ -8,9 +8,6 @@ import ( "fmt" "io/ioutil" "net" - "net/http" - "net/http/httptest" - "os" "os/exec" "strconv" "strings" @@ -32,69 +29,14 @@ var excludedAutobahnCases = []string{ // We skip the tests related to requestMaxWindowBits as that is unimplemented due // to limitations in compress/flate. See https://github.com/golang/go/issues/3155 "13.3.*", "13.4.*", "13.5.*", "13.6.*", - - "12.*", - "13.*", } var autobahnCases = []string{"*"} -// https://github.com/crossbario/autobahn-python/tree/master/wstest func TestAutobahn(t *testing.T) { t.Parallel() - if os.Getenv("AUTOBAHN") == "" { - t.Skip("Set $AUTOBAHN to run tests against the autobahn test suite") - } - - t.Run("server", testServerAutobahn) - t.Run("client", testClientAutobahn) -} - -func testServerAutobahn(t *testing.T) { - t.Parallel() - - s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c := acceptWebSocket(t, r, w, &websocket.AcceptOptions{ - Subprotocols: []string{"echo"}, - }) - err := echoLoop(r.Context(), c) - assertCloseStatus(t, websocket.StatusNormalClosure, err) - })) - closeFn := wsgrace(s.Config) - defer func() { - err := closeFn() - assert.Success(t, "closeFn", err) - }() - - specFile, err := tempJSONFile(map[string]interface{}{ - "outdir": "ci/out/wstestServerReports", - "servers": []interface{}{ - map[string]interface{}{ - "agent": "main", - "url": strings.Replace(s.URL, "http", "ws", 1), - }, - }, - "cases": autobahnCases, - "exclude-cases": excludedAutobahnCases, - }) - assert.Success(t, "tempJSONFile", err) - - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*10) - defer cancel() - - args := []string{"--mode", "fuzzingclient", "--spec", specFile} - wstest := exec.CommandContext(ctx, "wstest", args...) - _, err = wstest.CombinedOutput() - assert.Success(t, "wstest", err) - - checkWSTestIndex(t, "./ci/out/wstestServerReports/index.json") -} - -func testClientAutobahn(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*15) defer cancel() wstestURL, closeFn, err := wstestClientServer(ctx) @@ -108,27 +50,17 @@ func testClientAutobahn(t *testing.T) { assert.Success(t, "wstestCaseCount", err) t.Run("cases", func(t *testing.T) { - // Max 8 cases running at a time. - mu := make(chan struct{}, 8) - for i := 1; i <= cases; i++ { i := i t.Run("", func(t *testing.T) { - t.Parallel() - - mu <- struct{}{} - defer func() { - <-mu - }() - - ctx, cancel := context.WithTimeout(ctx, time.Second*45) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) defer cancel() c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/runCase?case=%v&agent=main", i), nil) assert.Success(t, "autobahn dial", err) err = echoLoop(ctx, c) - t.Logf("echoLoop: %+v", err) + t.Logf("echoLoop: %v", err) }) } }) @@ -174,7 +106,7 @@ func wstestClientServer(ctx context.Context) (url string, closeFn func(), err er return "", nil, xerrors.Errorf("failed to write spec: %w", err) } - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*15) defer func() { if err != nil { cancel() diff --git a/conn.go b/conn.go index 2d36123..163802b 100644 --- a/conn.go +++ b/conn.go @@ -99,7 +99,7 @@ func newConn(cfg connConfig) *Conn { closed: make(chan struct{}), activePings: make(map[string]chan<- struct{}), } - if c.flateThreshold == 0 { + if c.flate() && c.flateThreshold == 0 { c.flateThreshold = 256 if c.writeNoContextTakeOver() { c.flateThreshold = 512 diff --git a/conn_test.go b/conn_test.go index 6f6b8d5..aceac3f 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3,99 +3,70 @@ package websocket_test import ( + "bufio" "context" "crypto/rand" "io" "math/big" + "net" "net/http" "net/http/httptest" - "strings" - "sync/atomic" "testing" "time" "cdr.dev/slog/sloggers/slogtest/assert" - "golang.org/x/xerrors" "nhooyr.io/websocket" ) +func goFn(fn func()) func() { + done := make(chan struct{}) + go func() { + defer close(done) + fn() + }() + + return func() { + <-done + } +} + func TestConn(t *testing.T) { t.Parallel() t.Run("json", func(t *testing.T) { t.Parallel() - s, closeFn := testEchoLoop(t) - defer closeFn() + for i := 0; i < 1; i++ { + t.Run("", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() - c, _ := dialWebSocket(t, s, nil) - defer c.Close(websocket.StatusInternalError, "") - - c.SetReadLimit(1 << 30) - - for i := 0; i < 10; i++ { - n := randInt(t, 1_048_576) - echoJSON(t, c, n) - } + c1, c2 := websocketPipe(t) - c.Close(websocket.StatusNormalClosure, "") - }) -} - -func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request)) (s *httptest.Server, closeFn func()) { - h := http.HandlerFunc(fn) - if randInt(tb, 2) == 1 { - s = httptest.NewTLSServer(h) - } else { - s = httptest.NewServer(h) - } - closeFn2 := wsgrace(s.Config) - return s, func() { - err := closeFn2() - assert.Success(tb, "closeFn", err) - } -} + wait := goFn(func() { + err := echoLoop(ctx, c1) + assertCloseStatus(t, websocket.StatusNormalClosure, err) + }) + defer wait() -// grace wraps s.Handler to gracefully shutdown WebSocket connections. -// The returned function must be used to close the server instead of s.Close. -func wsgrace(s *http.Server) (closeFn func() error) { - h := s.Handler - var conns int64 - s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - atomic.AddInt64(&conns, 1) - defer atomic.AddInt64(&conns, -1) + c2.SetReadLimit(1 << 30) - ctx, cancel := context.WithTimeout(r.Context(), time.Second*5) - defer cancel() - - r = r.WithContext(ctx) + for i := 0; i < 10; i++ { + n := randInt(t, 131_072) + echoJSON(t, c2, n) + } - h.ServeHTTP(w, r) + c2.Close(websocket.StatusNormalClosure, "") + }) + } }) +} - return func() error { - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - - err := s.Shutdown(ctx) - if err != nil { - return xerrors.Errorf("server shutdown failed: %v", err) - } +type writerFunc func(p []byte) (int, error) - t := time.NewTicker(time.Millisecond * 10) - defer t.Stop() - for { - select { - case <-t.C: - if atomic.LoadInt64(&conns) == 0 { - return nil - } - case <-ctx.Done(): - return xerrors.Errorf("failed to wait for WebSocket connections: %v", ctx.Err()) - } - } - } +func (f writerFunc) Write(p []byte) (int, error) { + return f(p) } // echoLoop echos every msg received from c until an error @@ -133,18 +104,8 @@ func echoLoop(ctx context.Context, c *websocket.Conn) error { } } -func wsURL(s *httptest.Server) string { - return strings.Replace(s.URL, "http", "ws", 1) -} - -func testEchoLoop(t testing.TB) (*httptest.Server, func()) { - return testServer(t, func(w http.ResponseWriter, r *http.Request) { - c := acceptWebSocket(t, r, w, nil) - defer c.Close(websocket.StatusInternalError, "") - - err := echoLoop(r.Context(), c) - assertCloseStatus(t, websocket.StatusNormalClosure, err) - }) +func randBool(t testing.TB) bool { + return randInt(t, 2) == 1 } func randInt(t testing.TB, max int) int { @@ -152,3 +113,65 @@ func randInt(t testing.TB, max int) int { assert.Success(t, "rand.Int", err) return int(x.Int64()) } + +type testHijacker struct { + *httptest.ResponseRecorder + serverConn net.Conn + hijacked chan struct{} +} + +var _ http.Hijacker = testHijacker{} + +func (hj testHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { + close(hj.hijacked) + return hj.serverConn, bufio.NewReadWriter(bufio.NewReader(hj.serverConn), bufio.NewWriter(hj.serverConn)), nil +} + +func websocketPipe(t *testing.T) (*websocket.Conn, *websocket.Conn) { + var serverConn *websocket.Conn + tt := testTransport{ + h: func(w http.ResponseWriter, r *http.Request) { + serverConn = acceptWebSocket(t, r, w, nil) + }, + } + + dialOpts := &websocket.DialOptions{ + HTTPClient: &http.Client{ + Transport: tt, + }, + } + + clientConn, _, err := websocket.Dial(context.Background(), "ws://example.com", dialOpts) + assert.Success(t, "websocket.Dial", err) + + if randBool(t) { + return serverConn, clientConn + } + return clientConn, serverConn +} + +type testTransport struct { + h http.HandlerFunc +} + +func (t testTransport) RoundTrip(r *http.Request) (*http.Response, error) { + clientConn, serverConn := net.Pipe() + + hj := testHijacker{ + ResponseRecorder: httptest.NewRecorder(), + serverConn: serverConn, + hijacked: make(chan struct{}), + } + + done := make(chan struct{}) + t.h.ServeHTTP(hj, r) + + select { + case <-hj.hijacked: + resp := hj.ResponseRecorder.Result() + resp.Body = clientConn + return resp, nil + case <-done: + return hj.ResponseRecorder.Result(), nil + } +} diff --git a/read.go b/read.go index 7e74894..b681a94 100644 --- a/read.go +++ b/read.go @@ -84,7 +84,7 @@ func newMsgReader(c *Conn) *msgReader { return mr } -func (mr *msgReader) ensureFlate() { +func (mr *msgReader) resetFlate() { if mr.flateContextTakeover() && mr.dict == nil { mr.dict = newSlidingWindow(32768) } @@ -332,7 +332,7 @@ func (mr *msgReader) reset(ctx context.Context, h header) { mr.limitReader.reset(readerFunc(mr.read)) if mr.flate { - mr.ensureFlate() + mr.resetFlate() } mr.setFrame(h) @@ -362,7 +362,7 @@ func (mr *msgReader) Read(p []byte) (n int, err error) { defer mr.c.readMu.Unlock() n, err = mr.limitReader.Read(p) - if mr.flateContextTakeover() { + if mr.flate && mr.flateContextTakeover() { p = p[:n] mr.dict.write(p) } diff --git a/write.go b/write.go index 3454348..70656b9 100644 --- a/write.go +++ b/write.go @@ -70,17 +70,17 @@ func newMsgWriter(c *Conn) *msgWriter { } func (mw *msgWriter) ensureFlate() { - if mw.flateWriter == nil { - if mw.trimWriter == nil { - mw.trimWriter = &trimLastFourBytesWriter{ - w: writerFunc(mw.write), - } + if mw.trimWriter == nil { + mw.trimWriter = &trimLastFourBytesWriter{ + w: writerFunc(mw.write), } - mw.trimWriter.reset() + } + if mw.flateWriter == nil { mw.flateWriter = getFlateWriter(mw.trimWriter) - mw.flate = true } + + mw.flate = true } func (mw *msgWriter) flateContextTakeover() bool { @@ -128,6 +128,11 @@ func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error { mw.ctx = ctx mw.opcode = opcode(typ) mw.flate = false + + if mw.trimWriter != nil { + mw.trimWriter.reset() + } + return nil } @@ -146,9 +151,8 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) { return 0, xerrors.New("cannot use closed writer") } - // TODO can make threshold detection robust across writes by writing to bufio writer - if mw.flate || - mw.c.flate() && len(p) >= mw.c.flateThreshold { + // TODO Write to buffer to detect whether to enable flate or not for this message. + if mw.c.flate() { mw.ensureFlate() return mw.flateWriter.Write(p) } @@ -172,7 +176,6 @@ func (mw *msgWriter) Close() (err error) { if mw.closed { return xerrors.New("cannot use closed writer") } - mw.closed = true if mw.flate { err = mw.flateWriter.Flush() @@ -181,12 +184,16 @@ func (mw *msgWriter) Close() (err error) { } } + // We set closed after flushing the flate writer to ensure Write + // can succeed. + mw.closed = true + _, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil) if err != nil { return xerrors.Errorf("failed to write fin frame: %w", err) } - if mw.c.flate() && !mw.flateContextTakeover() { + if mw.flate && !mw.flateContextTakeover() { mw.returnFlateWriter() } mw.mu.Unlock() -- GitLab