From 679ddb825d5cd5ce4cc7136734fff5effe3a2910 Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Thu, 29 Aug 2019 15:37:26 -0500 Subject: [PATCH] Drastically improve non autobahn test coverage Also simplified and refactored the Conn tests. More changes soon. --- accept_test.go | 33 ++ ci/test.sh | 31 +- dial_test.go | 9 +- export_test.go | 12 +- header_test.go | 31 ++ netconn.go | 4 +- statuscode.go | 2 +- statuscode_test.go | 108 ++++++- websocket.go | 56 ++-- websocket_test.go | 781 ++++++++++++++++++++++++++++++--------------- 10 files changed, 761 insertions(+), 306 deletions(-) diff --git a/accept_test.go b/accept_test.go index 6f5c3fb..8634066 100644 --- a/accept_test.go +++ b/accept_test.go @@ -6,6 +6,39 @@ import ( "testing" ) +func TestAccept(t *testing.T) { + t.Parallel() + + t.Run("badClientHandshake", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + + _, err := Accept(w, r, AcceptOptions{}) + if err == nil { + t.Fatalf("unexpected error value: %v", err) + } + + }) + + t.Run("requireHttpHijacker", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + r := httptest.NewRequest("GET", "/", nil) + r.Header.Set("Connection", "Upgrade") + r.Header.Set("Upgrade", "websocket") + r.Header.Set("Sec-WebSocket-Version", "13") + r.Header.Set("Sec-WebSocket-Key", "meow123") + + _, err := Accept(w, r, AcceptOptions{}) + if err == nil || !strings.Contains(err.Error(), "http.Hijacker") { + t.Fatalf("unexpected error value: %v", err) + } + }) +} + func Test_verifyClientHandshake(t *testing.T) { t.Parallel() diff --git a/ci/test.sh b/ci/test.sh index 875216f..1d4a8b0 100755 --- a/ci/test.sh +++ b/ci/test.sh @@ -4,19 +4,34 @@ set -euo pipefail cd "$(dirname "${0}")" cd "$(git rev-parse --show-toplevel)" -mkdir -p ci/out/websocket -testFlags=( +argv=( + go run gotest.tools/gotestsum + # https://circleci.com/docs/2.0/collect-test-data/ + "--junitfile=ci/out/websocket/testReport.xml" + "--format=short-verbose" + -- -race "-vet=off" - # "-bench=." + "-bench=." +) +# Interactive usage probably does not want to enable benchmarks, race detection +# turn off vet or use gotestsum by default. +if [[ $# -gt 0 ]]; then + argv=(go test "$@") +fi + +# We always want coverage. +argv+=( "-coverprofile=ci/out/coverage.prof" "-coverpkg=./..." ) -# https://circleci.com/docs/2.0/collect-test-data/ -go run gotest.tools/gotestsum \ - --junitfile ci/out/websocket/testReport.xml \ - --format=short-verbose \ - -- "${testFlags[@]}" + +mkdir -p ci/out/websocket +"${argv[@]}" + +# Removes coverage of generated files. +grep -v _string.go < ci/out/coverage.prof > ci/out/coverage2.prof +mv ci/out/coverage2.prof ci/out/coverage.prof go tool cover -html=ci/out/coverage.prof -o=ci/out/coverage.html if [[ ${CI:-} ]]; then diff --git a/dial_test.go b/dial_test.go index 6400c22..4607493 100644 --- a/dial_test.go +++ b/dial_test.go @@ -33,6 +33,10 @@ func TestBadDials(t *testing.T) { }, }, }, + { + name: "badTLS", + url: "wss://totallyfake.nhooyr.io", + }, } for _, tc := range testCases { @@ -40,7 +44,10 @@ func TestBadDials(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - _, _, err := Dial(context.Background(), tc.url, tc.opts) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + _, _, err := Dial(ctx, tc.url, tc.opts) if err == nil { t.Fatalf("expected non nil error: %+v", err) } diff --git a/export_test.go b/export_test.go index 22ad76f..ab766f1 100644 --- a/export_test.go +++ b/export_test.go @@ -1,3 +1,13 @@ package websocket -var Compute = handleSecWebSocketKey +import ( + "context" +) + +type Addr = websocketAddr + +type Header = header + +func (c *Conn) WriteFrame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) { + return c.writeFrame(ctx, fin, opcode, p) +} diff --git a/header_test.go b/header_test.go index 4457c35..45d0535 100644 --- a/header_test.go +++ b/header_test.go @@ -2,6 +2,7 @@ package websocket import ( "bytes" + "io" "math/rand" "strconv" "testing" @@ -21,6 +22,36 @@ func randBool() bool { func TestHeader(t *testing.T) { t.Parallel() + t.Run("eof", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + bytes []byte + }{ + { + "start", + []byte{0xff}, + }, + { + "middle", + []byte{0xff, 0xff, 0xff}, + }, + } + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + b := bytes.NewBuffer(tc.bytes) + _, err := readHeader(nil, b) + if io.ErrUnexpectedEOF != err { + t.Fatalf("expected %v but got: %v", io.ErrUnexpectedEOF, err) + } + }) + } + }) + t.Run("writeNegativeLength", func(t *testing.T) { t.Parallel() diff --git a/netconn.go b/netconn.go index d28eeb8..a6f902d 100644 --- a/netconn.go +++ b/netconn.go @@ -101,8 +101,8 @@ func (c *netConn) Read(p []byte) (int, error) { return 0, err } if typ != c.msgType { - c.c.Close(StatusUnsupportedData, fmt.Sprintf("can only accept %v messages", c.msgType)) - return 0, xerrors.Errorf("unexpected frame type read for net conn adapter (expected %v): %v", c.msgType, typ) + c.c.Close(StatusUnsupportedData, fmt.Sprintf("unexpected frame type read (expected %v): %v", c.msgType, typ)) + return 0, c.c.closeErr } c.reader = r } diff --git a/statuscode.go b/statuscode.go index 42ae40c..498437d 100644 --- a/statuscode.go +++ b/statuscode.go @@ -35,7 +35,7 @@ const ( StatusTryAgainLater StatusBadGateway // statusTLSHandshake is unexported because we just return - // handshake error in dial. We do not return a conn + // the handshake error in dial. We do not return a conn // so there is nothing to use this on. At least until WASM. statusTLSHandshake ) diff --git a/statuscode_test.go b/statuscode_test.go index 38ee4c3..b963786 100644 --- a/statuscode_test.go +++ b/statuscode_test.go @@ -4,14 +4,13 @@ import ( "math" "strings" "testing" + + "github.com/google/go-cmp/cmp" ) func TestCloseError(t *testing.T) { t.Parallel() - // Other parts of close error are tested by websocket_test.go right now - // with the autobahn tests. - testCases := []struct { name string ce CloseError @@ -50,7 +49,108 @@ func TestCloseError(t *testing.T) { _, err := tc.ce.bytes() if (err == nil) != tc.success { - t.Fatalf("unexpected error value: %v", err) + t.Fatalf("unexpected error value: %+v", err) + } + }) + } +} + +func Test_parseClosePayload(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + p []byte + success bool + ce CloseError + }{ + { + name: "normal", + p: append([]byte{0x3, 0xE8}, []byte("hello")...), + success: true, + ce: CloseError{ + Code: StatusNormalClosure, + Reason: "hello", + }, + }, + { + name: "nothing", + success: true, + ce: CloseError{ + Code: StatusNoStatusRcvd, + }, + }, + { + name: "oneByte", + p: []byte{0}, + success: false, + }, + { + name: "badStatusCode", + p: []byte{0x17, 0x70}, + success: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + ce, err := parseClosePayload(tc.p) + if (err == nil) != tc.success { + t.Fatalf("unexpected expected error value: %+v", err) + } + + if tc.success && tc.ce != ce { + t.Fatalf("unexpected close error: %v", cmp.Diff(tc.ce, ce)) + } + }) + } +} + +func Test_validWireCloseCode(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + code StatusCode + valid bool + }{ + { + name: "normal", + code: StatusNormalClosure, + valid: true, + }, + { + name: "noStatus", + code: StatusNoStatusRcvd, + valid: false, + }, + { + name: "3000", + code: 3000, + valid: true, + }, + { + name: "4999", + code: 4999, + valid: true, + }, + { + name: "unknown", + code: 5000, + valid: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + if valid := validWireCloseCode(tc.code); tc.valid != valid { + t.Fatalf("expected %v for %v but got %v", tc.valid, tc.code, valid) } }) } diff --git a/websocket.go b/websocket.go index 393ea54..833c120 100644 --- a/websocket.go +++ b/websocket.go @@ -7,8 +7,8 @@ import ( "fmt" "io" "io/ioutil" + "log" "math/rand" - "os" "runtime" "strconv" "sync" @@ -210,9 +210,8 @@ func (c *Conn) readTillMsg(ctx context.Context) (header, error) { } if h.rsv1 || h.rsv2 || h.rsv3 { - err := xerrors.Errorf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3) - c.Close(StatusProtocolError, err.Error()) - return header{}, err + c.Close(StatusProtocolError, fmt.Sprintf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)) + return header{}, c.closeErr } if h.opcode.controlOp() { @@ -227,9 +226,8 @@ func (c *Conn) readTillMsg(ctx context.Context) (header, error) { case opBinary, opText, opContinuation: return h, nil default: - err := xerrors.Errorf("received unknown opcode %v", h.opcode) - c.Close(StatusProtocolError, err.Error()) - return header{}, err + c.Close(StatusProtocolError, fmt.Sprintf("received unknown opcode %v", h.opcode)) + return header{}, c.closeErr } } } @@ -273,15 +271,13 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { func (c *Conn) handleControl(ctx context.Context, h header) error { if h.payloadLength > maxControlFramePayload { - err := xerrors.Errorf("control frame too large at %v bytes", h.payloadLength) - c.Close(StatusProtocolError, err.Error()) - return err + c.Close(StatusProtocolError, fmt.Sprintf("control frame too large at %v bytes", h.payloadLength)) + return c.closeErr } if !h.fin { - err := xerrors.Errorf("received fragmented control frame") - c.Close(StatusProtocolError, err.Error()) - return err + c.Close(StatusProtocolError, "received fragmented control frame") + return c.closeErr } ctx, cancel := context.WithTimeout(ctx, time.Second*5) @@ -311,8 +307,9 @@ func (c *Conn) handleControl(ctx context.Context, h header) error { case opClose: ce, err := parseClosePayload(b) if err != nil { - c.Close(StatusProtocolError, "received invalid close payload") - return xerrors.Errorf("received invalid close payload: %w", err) + err = xerrors.Errorf("received invalid close payload: %w", err) + c.Close(StatusProtocolError, err.Error()) + return c.closeErr } // This ensures the closeErr of the Conn is always the received CloseError // in case the echo close frame write fails. @@ -376,9 +373,8 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { if c.activeReader != nil && !c.activeReader.eof() { if h.opcode != opContinuation { - err := xerrors.Errorf("received new data message without finishing the previous message") - c.Close(StatusProtocolError, err.Error()) - return 0, nil, err + c.Close(StatusProtocolError, "received new data message without finishing the previous message") + return 0, nil, c.closeErr } if !h.fin || h.payloadLength > 0 { @@ -392,9 +388,8 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { return 0, nil, err } } else if h.opcode == opContinuation { - err := xerrors.Errorf("received continuation frame not after data or text frame") - c.Close(StatusProtocolError, err.Error()) - return 0, nil, err + c.Close(StatusProtocolError, "received continuation frame not after data or text frame") + return 0, nil, c.closeErr } c.readerMsgCtx = ctx @@ -460,9 +455,8 @@ func (r *messageReader) read(p []byte) (int, error) { } if r.c.readMsgLeft <= 0 { - err := xerrors.Errorf("read limited at %v bytes", r.c.msgReadLimit) - r.c.Close(StatusMessageTooBig, err.Error()) - return 0, err + r.c.Close(StatusMessageTooBig, fmt.Sprintf("read limited at %v bytes", r.c.msgReadLimit)) + return 0, r.c.closeErr } if int64(len(p)) > r.c.readMsgLeft { @@ -476,9 +470,8 @@ func (r *messageReader) read(p []byte) (int, error) { } if h.opcode != opContinuation { - err := xerrors.Errorf("received new data message without finishing the previous message") - r.c.Close(StatusProtocolError, err.Error()) - return 0, err + r.c.Close(StatusProtocolError, "received new data message without finishing the previous message") + return 0, r.c.closeErr } r.c.readerMsgHeader = h @@ -828,7 +821,7 @@ func (c *Conn) writePong(p []byte) error { func (c *Conn) Close(code StatusCode, reason string) error { err := c.exportedClose(code, reason) if err != nil { - return xerrors.Errorf("failed to close connection: %w", err) + return xerrors.Errorf("failed to close websocket connection: %w", err) } return nil } @@ -844,7 +837,7 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error { // Definitely worth seeing what popular browsers do later. p, err := ce.bytes() if err != nil { - fmt.Fprintf(os.Stderr, "websocket: failed to marshal close frame: %v\n", err) + log.Printf("websocket: failed to marshal close frame: %+v", err) ce = CloseError{ Code: StatusInternalError, } @@ -853,12 +846,13 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error { // CloseErrors sent are made opaque to prevent applications from thinking // they received a given status. - err = c.writeClose(p, xerrors.Errorf("sent close frame: %v", ce)) + sentErr := xerrors.Errorf("sent close frame: %v", ce) + err = c.writeClose(p, sentErr) if err != nil { return err } - if !xerrors.Is(c.closeErr, ce) { + if !xerrors.Is(c.closeErr, sentErr) { return c.closeErr } diff --git a/websocket_test.go b/websocket_test.go index 2ef25cd..b45f024 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -4,8 +4,11 @@ import ( "context" "encoding/json" "fmt" + "github.com/golang/protobuf/proto" + "github.com/golang/protobuf/ptypes/timestamp" "io" "io/ioutil" + "math/rand" "net" "net/http" "net/http/cookiejar" @@ -75,127 +78,6 @@ func TestHandshake(t *testing.T) { return nil }, }, - { - name: "closeError", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - - err = wsjson.Write(r.Context(), c, "hello") - if err != nil { - return err - } - - return nil - }, - client: func(ctx context.Context, u string) error { - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ - Subprotocols: []string{"meow"}, - }) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - - var m string - err = wsjson.Read(ctx, c, &m) - if err != nil { - return err - } - - if m != "hello" { - return xerrors.Errorf("recieved unexpected msg but expected hello: %+v", m) - } - - _, _, err = c.Reader(ctx) - var cerr websocket.CloseError - if !xerrors.As(err, &cerr) || cerr.Code != websocket.StatusInternalError { - return xerrors.Errorf("unexpected error: %+v", err) - } - - return nil - }, - }, - { - name: "netConn", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - - nc := websocket.NetConn(c, websocket.MessageBinary) - defer nc.Close() - - nc.SetWriteDeadline(time.Time{}) - time.Sleep(1) - nc.SetWriteDeadline(time.Now().Add(time.Second * 15)) - - for i := 0; i < 3; i++ { - _, err = nc.Write([]byte("hello")) - if err != nil { - return err - } - } - - return nil - }, - client: func(ctx context.Context, u string) error { - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ - Subprotocols: []string{"meow"}, - }) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - - nc := websocket.NetConn(c, websocket.MessageBinary) - defer nc.Close() - - nc.SetReadDeadline(time.Time{}) - time.Sleep(1) - nc.SetReadDeadline(time.Now().Add(time.Second * 15)) - - read := func() error { - p := make([]byte, len("hello")) - // We do not use io.ReadFull here as it masks EOFs. - // See https://github.com/nhooyr/websocket/issues/100#issuecomment-508148024 - _, err = nc.Read(p) - if err != nil { - return err - } - - if string(p) != "hello" { - return xerrors.Errorf("unexpected payload %q received", string(p)) - } - return nil - } - - for i := 0; i < 3; i++ { - err = read() - if err != nil { - return err - } - } - - // Ensure the close frame is converted to an EOF and multiple read's after all return EOF. - err = read() - if err != io.EOF { - return err - } - - err = read() - if err != io.EOF { - return err - } - - return nil - }, - }, { name: "defaultSubprotocol", server: func(w http.ResponseWriter, r *http.Request) error { @@ -323,22 +205,240 @@ func TestHandshake(t *testing.T) { if err != nil { return err } - defer c.Close(websocket.StatusInternalError, "") + defer c.Close(websocket.StatusInternalError, "") + return nil + }, + }, + { + name: "cookies", + server: func(w http.ResponseWriter, r *http.Request) error { + cookie, err := r.Cookie("mycookie") + if err != nil { + return xerrors.Errorf("request is missing mycookie: %w", err) + } + if cookie.Value != "myvalue" { + return xerrors.Errorf("expected %q but got %q", "myvalue", cookie.Value) + } + c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) + if err != nil { + return err + } + c.Close(websocket.StatusInternalError, "") + return nil + }, + client: func(ctx context.Context, u string) error { + jar, err := cookiejar.New(nil) + if err != nil { + return xerrors.Errorf("failed to create cookie jar: %w", err) + } + parsedURL, err := url.Parse(u) + if err != nil { + return xerrors.Errorf("failed to parse url: %w", err) + } + parsedURL.Scheme = "http" + jar.SetCookies(parsedURL, []*http.Cookie{ + { + Name: "mycookie", + Value: "myvalue", + }, + }) + hc := &http.Client{ + Jar: jar, + } + c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ + HTTPClient: hc, + }) + if err != nil { + return err + } + c.Close(websocket.StatusInternalError, "") + return nil + }, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + s, closeFn := testServer(t, tc.server, false) + defer closeFn() + + wsURL := strings.Replace(s.URL, "http", "ws", 1) + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + err := tc.client(ctx, wsURL) + if err != nil { + t.Fatalf("client failed: %+v", err) + } + }) + } +} + +func TestConn(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + client func(ctx context.Context, c *websocket.Conn) error + server func(ctx context.Context, c *websocket.Conn) error + }{ + { + name: "closeError", + server: func(ctx context.Context, c *websocket.Conn) error { + return wsjson.Write(ctx, c, "hello") + }, + client: func(ctx context.Context, c *websocket.Conn) error { + var m string + err := wsjson.Read(ctx, c, &m) + if err != nil { + return err + } + + if m != "hello" { + return xerrors.Errorf("recieved unexpected msg but expected hello: %+v", m) + } + + _, _, err = c.Reader(ctx) + var cerr websocket.CloseError + if !xerrors.As(err, &cerr) || cerr.Code != websocket.StatusInternalError { + return xerrors.Errorf("unexpected error: %+v", err) + } + + return nil + }, + }, + { + name: "netConn", + server: func(ctx context.Context, c *websocket.Conn) error { + nc := websocket.NetConn(c, websocket.MessageBinary) + defer nc.Close() + + nc.SetWriteDeadline(time.Time{}) + time.Sleep(1) + nc.SetWriteDeadline(time.Now().Add(time.Second * 15)) + + if nc.LocalAddr() != (websocket.Addr{}) { + return xerrors.Errorf("net conn local address is not equal to websocket.Addr") + } + if nc.RemoteAddr() != (websocket.Addr{}) { + return xerrors.Errorf("net conn remote address is not equal to websocket.Addr") + } + + for i := 0; i < 3; i++ { + _, err := nc.Write([]byte("hello")) + if err != nil { + return err + } + } + + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + nc := websocket.NetConn(c, websocket.MessageBinary) + defer nc.Close() + + nc.SetReadDeadline(time.Time{}) + time.Sleep(1) + nc.SetReadDeadline(time.Now().Add(time.Second * 15)) + + read := func() error { + p := make([]byte, len("hello")) + // We do not use io.ReadFull here as it masks EOFs. + // See https://github.com/nhooyr/websocket/issues/100#issuecomment-508148024 + _, err := nc.Read(p) + if err != nil { + return err + } + + if string(p) != "hello" { + return xerrors.Errorf("unexpected payload %q received", string(p)) + } + return nil + } + + for i := 0; i < 3; i++ { + err := read() + if err != nil { + return err + } + } + + // Ensure the close frame is converted to an EOF and multiple read's after all return EOF. + err := read() + if err != io.EOF { + return err + } + + err = read() + if err != io.EOF { + return err + } + + return nil + }, + }, + { + name: "netConn/badReadMsgType", + server: func(ctx context.Context, c *websocket.Conn) error { + nc := websocket.NetConn(c, websocket.MessageBinary) + defer nc.Close() + + nc.SetDeadline(time.Now().Add(time.Second * 15)) + + _, err := nc.Read(make([]byte, 1)) + if err == nil { + return xerrors.Errorf("expected error") + } + + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + err := wsjson.Write(ctx, c, "meow") + if err != nil { + return err + } + + _, _, err = c.Read(ctx) + cerr := &websocket.CloseError{} + if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusUnsupportedData { + return xerrors.Errorf("expected close error with code StatusUnsupportedData: %+v", err) + } + + return nil + }, + }, + { + name: "netConn/badRead", + server: func(ctx context.Context, c *websocket.Conn) error { + nc := websocket.NetConn(c, websocket.MessageBinary) + defer nc.Close() + + nc.SetDeadline(time.Now().Add(time.Second * 15)) + + _, err := nc.Read(make([]byte, 1)) + cerr := &websocket.CloseError{} + if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusBadGateway { + return xerrors.Errorf("expected close error with code StatusBadGateway: %+v", err) + } + + _, err = nc.Write([]byte{0xff}) + if err == nil { + return xerrors.Errorf("expected writes to fail after reading a close frame: %v", err) + } + return nil }, + client: func(ctx context.Context, c *websocket.Conn) error { + return c.Close(websocket.StatusBadGateway, "") + }, }, { name: "jsonEcho", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - - ctx, cancel := context.WithTimeout(r.Context(), time.Second*5) - defer cancel() - + server: func(ctx context.Context, c *websocket.Conn) error { write := func() error { v := map[string]interface{}{ "anmol": "wowow", @@ -346,7 +446,7 @@ func TestHandshake(t *testing.T) { err := wsjson.Write(ctx, c, v) return err } - err = write() + err := write() if err != nil { return err } @@ -358,13 +458,7 @@ func TestHandshake(t *testing.T) { c.Close(websocket.StatusNormalClosure, "") return nil }, - client: func(ctx context.Context, u string) error { - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{}) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - + client: func(ctx context.Context, c *websocket.Conn) error { read := func() error { var v interface{} err := wsjson.Read(ctx, c, &v) @@ -380,7 +474,7 @@ func TestHandshake(t *testing.T) { } return nil } - err = read() + err := read() if err != nil { return err } @@ -395,21 +489,12 @@ func TestHandshake(t *testing.T) { }, { name: "protobufEcho", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - - ctx, cancel := context.WithTimeout(r.Context(), time.Second*5) - defer cancel() - + server: func(ctx context.Context, c *websocket.Conn) error { write := func() error { err := wspb.Write(ctx, c, ptypes.DurationProto(100)) return err } - err = write() + err := write() if err != nil { return err } @@ -417,13 +502,7 @@ func TestHandshake(t *testing.T) { c.Close(websocket.StatusNormalClosure, "") return nil }, - client: func(ctx context.Context, u string) error { - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{}) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - + client: func(ctx context.Context, c *websocket.Conn) error { read := func() error { var v duration.Duration err := wspb.Read(ctx, c, &v) @@ -441,7 +520,7 @@ func TestHandshake(t *testing.T) { } return nil } - err = read() + err := read() if err != nil { return err } @@ -450,73 +529,21 @@ func TestHandshake(t *testing.T) { return nil }, }, - { - name: "cookies", - server: func(w http.ResponseWriter, r *http.Request) error { - cookie, err := r.Cookie("mycookie") - if err != nil { - return xerrors.Errorf("request is missing mycookie: %w", err) - } - if cookie.Value != "myvalue" { - return xerrors.Errorf("expected %q but got %q", "myvalue", cookie.Value) - } - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) - if err != nil { - return err - } - c.Close(websocket.StatusInternalError, "") - return nil - }, - client: func(ctx context.Context, u string) error { - jar, err := cookiejar.New(nil) - if err != nil { - return xerrors.Errorf("failed to create cookie jar: %w", err) - } - parsedURL, err := url.Parse(u) - if err != nil { - return xerrors.Errorf("failed to parse url: %w", err) - } - parsedURL.Scheme = "http" - jar.SetCookies(parsedURL, []*http.Cookie{ - { - Name: "mycookie", - Value: "myvalue", - }, - }) - hc := &http.Client{ - Jar: jar, - } - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{ - HTTPClient: hc, - }) - if err != nil { - return err - } - c.Close(websocket.StatusInternalError, "") - return nil - }, - }, { name: "ping", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - + server: func(ctx context.Context, c *websocket.Conn) error { errc := make(chan error, 1) go func() { - _, _, err2 := c.Read(r.Context()) + _, _, err2 := c.Read(ctx) errc <- err2 }() - err = c.Ping(r.Context()) + err := c.Ping(ctx) if err != nil { return err } - err = c.Write(r.Context(), websocket.MessageText, []byte("hi")) + err = c.Write(ctx, websocket.MessageText, []byte("hi")) if err != nil { return err } @@ -528,13 +555,7 @@ func TestHandshake(t *testing.T) { } return xerrors.Errorf("unexpected error: %w", err) }, - client: func(ctx context.Context, u string) error { - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{}) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - + client: func(ctx context.Context, c *websocket.Conn) error { // We read a message from the connection and then keep reading until // the Ping completes. done := make(chan struct{}) @@ -550,7 +571,7 @@ func TestHandshake(t *testing.T) { c.Read(ctx) }() - err = c.Ping(ctx) + err := c.Ping(ctx) if err != nil { return err } @@ -563,29 +584,17 @@ func TestHandshake(t *testing.T) { }, { name: "readLimit", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - - _, _, err = c.Read(r.Context()) + server: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Read(ctx) if err == nil { return xerrors.Errorf("expected error but got nil") } return nil }, - client: func(ctx context.Context, u string) error { - c, _, err := websocket.Dial(ctx, u, websocket.DialOptions{}) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") + client: func(ctx context.Context, c *websocket.Conn) error { + go c.CloseRead(ctx) - go c.Reader(ctx) - - err = c.Write(ctx, websocket.MessageBinary, []byte(strings.Repeat("x", 32769))) + err := c.Write(ctx, websocket.MessageBinary, []byte(strings.Repeat("x", 32769))) if err != nil { return err } @@ -600,20 +609,244 @@ func TestHandshake(t *testing.T) { return nil }, }, - } + { + name: "wsjson/binary", + server: func(ctx context.Context, c *websocket.Conn) error { + var v interface{} + err := wsjson.Read(ctx, c, &v) + if err == nil { + return xerrors.Errorf("expected error: %v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + return wspb.Write(ctx, c, ptypes.DurationProto(100)) + }, + }, + { + name: "wsjson/badRead", + server: func(ctx context.Context, c *websocket.Conn) error { + var v interface{} + err := wsjson.Read(ctx, c, &v) + if err == nil { + return xerrors.Errorf("expected error: %v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + return c.Write(ctx, websocket.MessageText, []byte("notjson")) + }, + }, + { + name: "wsjson/badWrite", + server: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Read(ctx) + if err == nil { + return xerrors.Errorf("expected error: %v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + err := wsjson.Write(ctx, c, fmt.Println) + if err == nil { + return xerrors.Errorf("expected error: %v", err) + } + return nil + }, + }, + { + name: "wspb/text", + server: func(ctx context.Context, c *websocket.Conn) error { + var v proto.Message + err := wspb.Read(ctx, c, v) + if err == nil { + return xerrors.Errorf("expected error: %v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + return wsjson.Write(ctx, c, "hi") + }, + }, + { + name: "wspb/badRead", + server: func(ctx context.Context, c *websocket.Conn) error { + var v timestamp.Timestamp + err := wspb.Read(ctx, c, &v) + if err == nil { + return xerrors.Errorf("expected error: %v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + return c.Write(ctx, websocket.MessageBinary, []byte("notpb")) + }, + }, + { + name: "wspb/badWrite", + server: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Read(ctx) + if err == nil { + return xerrors.Errorf("expected error: %v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + err := wspb.Write(ctx, c, nil) + if err == nil { + return xerrors.Errorf("expected error: %v", err) + } + return nil + }, + }, + { + name: "wspb/badWrite", + server: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Read(ctx) + if err == nil { + return xerrors.Errorf("expected error: %v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + err := wspb.Write(ctx, c, nil) + if err == nil { + return xerrors.Errorf("expected error: %v", err) + } + return nil + }, + }, + { + name: "badClose", + server: func(ctx context.Context, c *websocket.Conn) error { + return c.Close(9999, "") + }, + client: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Read(ctx) + cerr := &websocket.CloseError{} + if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusInternalError { + return xerrors.Errorf("expected close error with StatusInternalError: %+v", err) + } + return nil + }, + }, + { + name: "pingTimeout", + server: func(ctx context.Context, c *websocket.Conn) error { + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + err := c.Ping(ctx) + if err == nil { + return xerrors.Errorf("expected nil error: %+v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + time.Sleep(time.Second) + return nil + }, + }, + { + name: "writeTimeout", + server: func(ctx context.Context, c *websocket.Conn) error { + c.Writer(ctx, websocket.MessageBinary) + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + err := c.Write(ctx, websocket.MessageBinary, []byte("meow")) + if !xerrors.Is(err, context.DeadlineExceeded) { + return xerrors.Errorf("expected deadline exceeded error: %+v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + time.Sleep(time.Second) + return nil + }, + }, + { + name: "readTimeout", + server: func(ctx context.Context, c *websocket.Conn) error { + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + _, r, err := c.Reader(ctx) + if err != nil { + return err + } + <-ctx.Done() + _, err = r.Read(make([]byte, 1)) + if !xerrors.Is(err, context.DeadlineExceeded){ + return xerrors.Errorf("expected deadline exceeded error: %+v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + time.Sleep(time.Second) + return nil + }, + }, + { + name: "badOpCode", + server: func(ctx context.Context, c *websocket.Conn) error { + _, err := c.WriteFrame(ctx, true, 13, []byte("meow")) + if err != nil { + return err + } + _, _, err = c.Read(ctx) + cerr := &websocket.CloseError{} + if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusProtocolError { + return xerrors.Errorf("expected close error with StatusProtocolError: %+v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Read(ctx) + if err == nil || strings.Contains(err.Error(), "opcode") { + return xerrors.Errorf("expected error that contains opcode: %+v", err) + } + return nil + }, + }, + { + name: "noRsv", + server: func(ctx context.Context, c *websocket.Conn) error { + _, err := c.WriteFrame(ctx, true, 99, []byte("meow")) + if err != nil { + return err + } + _, _, err = c.Read(ctx) + cerr := &websocket.CloseError{} + if !xerrors.As(err, cerr) || cerr.Code != websocket.StatusProtocolError { + return xerrors.Errorf("expected close error with StatusProtocolError: %+v", err) + } + return nil + }, + client: func(ctx context.Context, c *websocket.Conn) error { + _, _, err := c.Read(ctx) + if err == nil || !strings.Contains(err.Error(), "rsv") { + return xerrors.Errorf("expected error that contains rsv: %+v", err) + } + return nil + }, + }, + } for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) { - err := tc.server(w, r) + // Run random tests over TLS. + tls := rand.Intn(2) == 1 + + s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) error { + c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) if err != nil { - t.Errorf("server failed: %+v", err) - return + return err } - }) + defer c.Close(websocket.StatusInternalError, "") + tc.server(r.Context(), c) + return nil + }, tls) defer closeFn() wsURL := strings.Replace(s.URL, "http", "ws", 1) @@ -621,7 +854,18 @@ func TestHandshake(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() - err := tc.client(ctx, wsURL) + opts := websocket.DialOptions{} + if tls { + opts.HTTPClient = s.Client() + } + + c, _, err := websocket.Dial(ctx, wsURL, opts) + if err != nil { + t.Fatal(err) + } + defer c.Close(websocket.StatusInternalError, "") + + err = tc.client(ctx, c) if err != nil { t.Fatalf("client failed: %+v", err) } @@ -629,14 +873,31 @@ func TestHandshake(t *testing.T) { } } -func testServer(tb testing.TB, fn http.HandlerFunc) (s *httptest.Server, closeFn func()) { +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request) error, tls bool) (s *httptest.Server, closeFn func()) { var conns int64 - s = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { atomic.AddInt64(&conns, 1) defer atomic.AddInt64(&conns, -1) - fn.ServeHTTP(w, r) - })) + ctx, cancel := context.WithTimeout(r.Context(), time.Second*30) + defer cancel() + + r = r.WithContext(ctx) + + err := fn(w, r) + if err != nil { + tb.Errorf("server failed: %+v", err) + } + }) + if tls { + s = httptest.NewTLSServer(h) + } else { + s = httptest.NewServer(h) + } return s, func() { s.Close() @@ -654,7 +915,9 @@ func testServer(tb testing.TB, fn http.HandlerFunc) (s *httptest.Server, closeFn // https://github.com/crossbario/autobahn-python/tree/master/wstest func TestAutobahnServer(t *testing.T) { t.Parallel() - t.Skip() + if os.Getenv("AUTOBAHN") == "" { + t.Skip("Set $AUTOBAHN to run the autobahn test suite.") + } s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, websocket.AcceptOptions{ @@ -795,7 +1058,9 @@ func unusedListenAddr() (string, error) { // https://github.com/crossbario/autobahn-python/blob/master/wstest/testee_client_aio.py func TestAutobahnClient(t *testing.T) { t.Parallel() - t.Skip() + if os.Getenv("AUTOBAHN") == "" { + t.Skip("Set $AUTOBAHN to run the autobahn test suite.") + } serverAddr, err := unusedListenAddr() if err != nil { @@ -941,18 +1206,18 @@ func checkWSTestIndex(t *testing.T, path string) { } func benchConn(b *testing.B, echo, stream bool, size int) { - s, closeFn := testServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + s, closeFn := testServer(b, func(w http.ResponseWriter, r *http.Request) error { c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) if err != nil { - b.Logf("server handshake failed: %+v", err) - return + return err } if echo { echoLoop(r.Context(), c) } else { discardLoop(r.Context(), c) } - })) + return nil + }, false) defer closeFn() wsURL := strings.Replace(s.URL, "http", "ws", 1) -- GitLab