good morning!!!!

Skip to content
Snippets Groups Projects
Unverified Commit 679ddb82 authored by Anmol Sethi's avatar Anmol Sethi
Browse files

Drastically improve non autobahn test coverage

Also simplified and refactored the Conn tests.

More changes soon.
parent a3a891bf
Branches
Tags
No related merge requests found
...@@ -6,6 +6,39 @@ import ( ...@@ -6,6 +6,39 @@ import (
"testing" "testing"
) )
func TestAccept(t *testing.T) {
t.Parallel()
t.Run("badClientHandshake", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
_, err := Accept(w, r, AcceptOptions{})
if err == nil {
t.Fatalf("unexpected error value: %v", err)
}
})
t.Run("requireHttpHijacker", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
r := httptest.NewRequest("GET", "/", nil)
r.Header.Set("Connection", "Upgrade")
r.Header.Set("Upgrade", "websocket")
r.Header.Set("Sec-WebSocket-Version", "13")
r.Header.Set("Sec-WebSocket-Key", "meow123")
_, err := Accept(w, r, AcceptOptions{})
if err == nil || !strings.Contains(err.Error(), "http.Hijacker") {
t.Fatalf("unexpected error value: %v", err)
}
})
}
func Test_verifyClientHandshake(t *testing.T) { func Test_verifyClientHandshake(t *testing.T) {
t.Parallel() t.Parallel()
......
...@@ -4,19 +4,34 @@ set -euo pipefail ...@@ -4,19 +4,34 @@ set -euo pipefail
cd "$(dirname "${0}")" cd "$(dirname "${0}")"
cd "$(git rev-parse --show-toplevel)" cd "$(git rev-parse --show-toplevel)"
mkdir -p ci/out/websocket argv=(
testFlags=( go run gotest.tools/gotestsum
# https://circleci.com/docs/2.0/collect-test-data/
"--junitfile=ci/out/websocket/testReport.xml"
"--format=short-verbose"
--
-race -race
"-vet=off" "-vet=off"
# "-bench=." "-bench=."
)
# Interactive usage probably does not want to enable benchmarks, race detection
# turn off vet or use gotestsum by default.
if [[ $# -gt 0 ]]; then
argv=(go test "$@")
fi
# We always want coverage.
argv+=(
"-coverprofile=ci/out/coverage.prof" "-coverprofile=ci/out/coverage.prof"
"-coverpkg=./..." "-coverpkg=./..."
) )
# https://circleci.com/docs/2.0/collect-test-data/
go run gotest.tools/gotestsum \ mkdir -p ci/out/websocket
--junitfile ci/out/websocket/testReport.xml \ "${argv[@]}"
--format=short-verbose \
-- "${testFlags[@]}" # Removes coverage of generated files.
grep -v _string.go < ci/out/coverage.prof > ci/out/coverage2.prof
mv ci/out/coverage2.prof ci/out/coverage.prof
go tool cover -html=ci/out/coverage.prof -o=ci/out/coverage.html go tool cover -html=ci/out/coverage.prof -o=ci/out/coverage.html
if [[ ${CI:-} ]]; then if [[ ${CI:-} ]]; then
......
...@@ -33,6 +33,10 @@ func TestBadDials(t *testing.T) { ...@@ -33,6 +33,10 @@ func TestBadDials(t *testing.T) {
}, },
}, },
}, },
{
name: "badTLS",
url: "wss://totallyfake.nhooyr.io",
},
} }
for _, tc := range testCases { for _, tc := range testCases {
...@@ -40,7 +44,10 @@ func TestBadDials(t *testing.T) { ...@@ -40,7 +44,10 @@ func TestBadDials(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()
_, _, err := Dial(context.Background(), tc.url, tc.opts) ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
_, _, err := Dial(ctx, tc.url, tc.opts)
if err == nil { if err == nil {
t.Fatalf("expected non nil error: %+v", err) t.Fatalf("expected non nil error: %+v", err)
} }
......
package websocket package websocket
var Compute = handleSecWebSocketKey import (
"context"
)
type Addr = websocketAddr
type Header = header
func (c *Conn) WriteFrame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) {
return c.writeFrame(ctx, fin, opcode, p)
}
...@@ -2,6 +2,7 @@ package websocket ...@@ -2,6 +2,7 @@ package websocket
import ( import (
"bytes" "bytes"
"io"
"math/rand" "math/rand"
"strconv" "strconv"
"testing" "testing"
...@@ -21,6 +22,36 @@ func randBool() bool { ...@@ -21,6 +22,36 @@ func randBool() bool {
func TestHeader(t *testing.T) { func TestHeader(t *testing.T) {
t.Parallel() t.Parallel()
t.Run("eof", func(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
bytes []byte
}{
{
"start",
[]byte{0xff},
},
{
"middle",
[]byte{0xff, 0xff, 0xff},
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
b := bytes.NewBuffer(tc.bytes)
_, err := readHeader(nil, b)
if io.ErrUnexpectedEOF != err {
t.Fatalf("expected %v but got: %v", io.ErrUnexpectedEOF, err)
}
})
}
})
t.Run("writeNegativeLength", func(t *testing.T) { t.Run("writeNegativeLength", func(t *testing.T) {
t.Parallel() t.Parallel()
......
...@@ -101,8 +101,8 @@ func (c *netConn) Read(p []byte) (int, error) { ...@@ -101,8 +101,8 @@ func (c *netConn) Read(p []byte) (int, error) {
return 0, err return 0, err
} }
if typ != c.msgType { if typ != c.msgType {
c.c.Close(StatusUnsupportedData, fmt.Sprintf("can only accept %v messages", c.msgType)) c.c.Close(StatusUnsupportedData, fmt.Sprintf("unexpected frame type read (expected %v): %v", c.msgType, typ))
return 0, xerrors.Errorf("unexpected frame type read for net conn adapter (expected %v): %v", c.msgType, typ) return 0, c.c.closeErr
} }
c.reader = r c.reader = r
} }
......
...@@ -35,7 +35,7 @@ const ( ...@@ -35,7 +35,7 @@ const (
StatusTryAgainLater StatusTryAgainLater
StatusBadGateway StatusBadGateway
// statusTLSHandshake is unexported because we just return // statusTLSHandshake is unexported because we just return
// handshake error in dial. We do not return a conn // the handshake error in dial. We do not return a conn
// so there is nothing to use this on. At least until WASM. // so there is nothing to use this on. At least until WASM.
statusTLSHandshake statusTLSHandshake
) )
......
...@@ -4,14 +4,13 @@ import ( ...@@ -4,14 +4,13 @@ import (
"math" "math"
"strings" "strings"
"testing" "testing"
"github.com/google/go-cmp/cmp"
) )
func TestCloseError(t *testing.T) { func TestCloseError(t *testing.T) {
t.Parallel() t.Parallel()
// Other parts of close error are tested by websocket_test.go right now
// with the autobahn tests.
testCases := []struct { testCases := []struct {
name string name string
ce CloseError ce CloseError
...@@ -50,7 +49,108 @@ func TestCloseError(t *testing.T) { ...@@ -50,7 +49,108 @@ func TestCloseError(t *testing.T) {
_, err := tc.ce.bytes() _, err := tc.ce.bytes()
if (err == nil) != tc.success { if (err == nil) != tc.success {
t.Fatalf("unexpected error value: %v", err) t.Fatalf("unexpected error value: %+v", err)
}
})
}
}
func Test_parseClosePayload(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
p []byte
success bool
ce CloseError
}{
{
name: "normal",
p: append([]byte{0x3, 0xE8}, []byte("hello")...),
success: true,
ce: CloseError{
Code: StatusNormalClosure,
Reason: "hello",
},
},
{
name: "nothing",
success: true,
ce: CloseError{
Code: StatusNoStatusRcvd,
},
},
{
name: "oneByte",
p: []byte{0},
success: false,
},
{
name: "badStatusCode",
p: []byte{0x17, 0x70},
success: false,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ce, err := parseClosePayload(tc.p)
if (err == nil) != tc.success {
t.Fatalf("unexpected expected error value: %+v", err)
}
if tc.success && tc.ce != ce {
t.Fatalf("unexpected close error: %v", cmp.Diff(tc.ce, ce))
}
})
}
}
func Test_validWireCloseCode(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
code StatusCode
valid bool
}{
{
name: "normal",
code: StatusNormalClosure,
valid: true,
},
{
name: "noStatus",
code: StatusNoStatusRcvd,
valid: false,
},
{
name: "3000",
code: 3000,
valid: true,
},
{
name: "4999",
code: 4999,
valid: true,
},
{
name: "unknown",
code: 5000,
valid: false,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
if valid := validWireCloseCode(tc.code); tc.valid != valid {
t.Fatalf("expected %v for %v but got %v", tc.valid, tc.code, valid)
} }
}) })
} }
......
...@@ -7,8 +7,8 @@ import ( ...@@ -7,8 +7,8 @@ import (
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
"log"
"math/rand" "math/rand"
"os"
"runtime" "runtime"
"strconv" "strconv"
"sync" "sync"
...@@ -210,9 +210,8 @@ func (c *Conn) readTillMsg(ctx context.Context) (header, error) { ...@@ -210,9 +210,8 @@ func (c *Conn) readTillMsg(ctx context.Context) (header, error) {
} }
if h.rsv1 || h.rsv2 || h.rsv3 { if h.rsv1 || h.rsv2 || h.rsv3 {
err := xerrors.Errorf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3) c.Close(StatusProtocolError, fmt.Sprintf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3))
c.Close(StatusProtocolError, err.Error()) return header{}, c.closeErr
return header{}, err
} }
if h.opcode.controlOp() { if h.opcode.controlOp() {
...@@ -227,9 +226,8 @@ func (c *Conn) readTillMsg(ctx context.Context) (header, error) { ...@@ -227,9 +226,8 @@ func (c *Conn) readTillMsg(ctx context.Context) (header, error) {
case opBinary, opText, opContinuation: case opBinary, opText, opContinuation:
return h, nil return h, nil
default: default:
err := xerrors.Errorf("received unknown opcode %v", h.opcode) c.Close(StatusProtocolError, fmt.Sprintf("received unknown opcode %v", h.opcode))
c.Close(StatusProtocolError, err.Error()) return header{}, c.closeErr
return header{}, err
} }
} }
} }
...@@ -273,15 +271,13 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { ...@@ -273,15 +271,13 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
func (c *Conn) handleControl(ctx context.Context, h header) error { func (c *Conn) handleControl(ctx context.Context, h header) error {
if h.payloadLength > maxControlFramePayload { if h.payloadLength > maxControlFramePayload {
err := xerrors.Errorf("control frame too large at %v bytes", h.payloadLength) c.Close(StatusProtocolError, fmt.Sprintf("control frame too large at %v bytes", h.payloadLength))
c.Close(StatusProtocolError, err.Error()) return c.closeErr
return err
} }
if !h.fin { if !h.fin {
err := xerrors.Errorf("received fragmented control frame") c.Close(StatusProtocolError, "received fragmented control frame")
c.Close(StatusProtocolError, err.Error()) return c.closeErr
return err
} }
ctx, cancel := context.WithTimeout(ctx, time.Second*5) ctx, cancel := context.WithTimeout(ctx, time.Second*5)
...@@ -311,8 +307,9 @@ func (c *Conn) handleControl(ctx context.Context, h header) error { ...@@ -311,8 +307,9 @@ func (c *Conn) handleControl(ctx context.Context, h header) error {
case opClose: case opClose:
ce, err := parseClosePayload(b) ce, err := parseClosePayload(b)
if err != nil { if err != nil {
c.Close(StatusProtocolError, "received invalid close payload") err = xerrors.Errorf("received invalid close payload: %w", err)
return xerrors.Errorf("received invalid close payload: %w", err) c.Close(StatusProtocolError, err.Error())
return c.closeErr
} }
// This ensures the closeErr of the Conn is always the received CloseError // This ensures the closeErr of the Conn is always the received CloseError
// in case the echo close frame write fails. // in case the echo close frame write fails.
...@@ -376,9 +373,8 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { ...@@ -376,9 +373,8 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
if c.activeReader != nil && !c.activeReader.eof() { if c.activeReader != nil && !c.activeReader.eof() {
if h.opcode != opContinuation { if h.opcode != opContinuation {
err := xerrors.Errorf("received new data message without finishing the previous message") c.Close(StatusProtocolError, "received new data message without finishing the previous message")
c.Close(StatusProtocolError, err.Error()) return 0, nil, c.closeErr
return 0, nil, err
} }
if !h.fin || h.payloadLength > 0 { if !h.fin || h.payloadLength > 0 {
...@@ -392,9 +388,8 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { ...@@ -392,9 +388,8 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
return 0, nil, err return 0, nil, err
} }
} else if h.opcode == opContinuation { } else if h.opcode == opContinuation {
err := xerrors.Errorf("received continuation frame not after data or text frame") c.Close(StatusProtocolError, "received continuation frame not after data or text frame")
c.Close(StatusProtocolError, err.Error()) return 0, nil, c.closeErr
return 0, nil, err
} }
c.readerMsgCtx = ctx c.readerMsgCtx = ctx
...@@ -460,9 +455,8 @@ func (r *messageReader) read(p []byte) (int, error) { ...@@ -460,9 +455,8 @@ func (r *messageReader) read(p []byte) (int, error) {
} }
if r.c.readMsgLeft <= 0 { if r.c.readMsgLeft <= 0 {
err := xerrors.Errorf("read limited at %v bytes", r.c.msgReadLimit) r.c.Close(StatusMessageTooBig, fmt.Sprintf("read limited at %v bytes", r.c.msgReadLimit))
r.c.Close(StatusMessageTooBig, err.Error()) return 0, r.c.closeErr
return 0, err
} }
if int64(len(p)) > r.c.readMsgLeft { if int64(len(p)) > r.c.readMsgLeft {
...@@ -476,9 +470,8 @@ func (r *messageReader) read(p []byte) (int, error) { ...@@ -476,9 +470,8 @@ func (r *messageReader) read(p []byte) (int, error) {
} }
if h.opcode != opContinuation { if h.opcode != opContinuation {
err := xerrors.Errorf("received new data message without finishing the previous message") r.c.Close(StatusProtocolError, "received new data message without finishing the previous message")
r.c.Close(StatusProtocolError, err.Error()) return 0, r.c.closeErr
return 0, err
} }
r.c.readerMsgHeader = h r.c.readerMsgHeader = h
...@@ -828,7 +821,7 @@ func (c *Conn) writePong(p []byte) error { ...@@ -828,7 +821,7 @@ func (c *Conn) writePong(p []byte) error {
func (c *Conn) Close(code StatusCode, reason string) error { func (c *Conn) Close(code StatusCode, reason string) error {
err := c.exportedClose(code, reason) err := c.exportedClose(code, reason)
if err != nil { if err != nil {
return xerrors.Errorf("failed to close connection: %w", err) return xerrors.Errorf("failed to close websocket connection: %w", err)
} }
return nil return nil
} }
...@@ -844,7 +837,7 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error { ...@@ -844,7 +837,7 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error {
// Definitely worth seeing what popular browsers do later. // Definitely worth seeing what popular browsers do later.
p, err := ce.bytes() p, err := ce.bytes()
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "websocket: failed to marshal close frame: %v\n", err) log.Printf("websocket: failed to marshal close frame: %+v", err)
ce = CloseError{ ce = CloseError{
Code: StatusInternalError, Code: StatusInternalError,
} }
...@@ -853,12 +846,13 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error { ...@@ -853,12 +846,13 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error {
// CloseErrors sent are made opaque to prevent applications from thinking // CloseErrors sent are made opaque to prevent applications from thinking
// they received a given status. // they received a given status.
err = c.writeClose(p, xerrors.Errorf("sent close frame: %v", ce)) sentErr := xerrors.Errorf("sent close frame: %v", ce)
err = c.writeClose(p, sentErr)
if err != nil { if err != nil {
return err return err
} }
if !xerrors.Is(c.closeErr, ce) { if !xerrors.Is(c.closeErr, sentErr) {
return c.closeErr return c.closeErr
} }
......
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment