diff --git a/assert_test.go b/assert_test.go index 5307ee8e4b44a71b4b7631ca5c4c03999f863c18..6cfd926432f1d9a55cd8361801ab6c6e9003c97e 100644 --- a/assert_test.go +++ b/assert_test.go @@ -45,6 +45,7 @@ func assertJSONRead(t *testing.T, ctx context.Context, c *websocket.Conn, exp in func randString(t *testing.T, n int) string { s := strings.ToValidUTF8(string(randBytes(t, n)), "_") + s = strings.ReplaceAll(s, "\x00", "_") if len(s) > n { return s[:n] } diff --git a/ci/image/Dockerfile b/ci/image/Dockerfile index 070c50e66985a8e29b853d520ac06365127deccf..88c965028349d1ef3acc747729617964428ff137 100644 --- a/ci/image/Dockerfile +++ b/ci/image/Dockerfile @@ -1,7 +1,7 @@ FROM golang:1 RUN apt-get update -RUN apt-get install -y chromium +RUN apt-get install -y chromium npm ENV GOFLAGS="-mod=readonly" ENV PAGER=cat diff --git a/conn_test.go b/conn_test.go index 7186da8a7bd7ef6c425cec59c5c89f5284aebe42..4720cba929a43062ed812360b332fe373e8928df 100644 --- a/conn_test.go +++ b/conn_test.go @@ -18,45 +18,71 @@ import ( "nhooyr.io/websocket" ) +func TestFuzz(t *testing.T) { + t.Parallel() + + s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + CompressionOptions: websocket.CompressionOptions{ + Mode: websocket.CompressionContextTakeover, + }, + }) + assert.Success(t, "accept", err) + defer c.Close(websocket.StatusInternalError, "") + + err = echoLoop(r.Context(), c) + assertCloseStatus(t, websocket.StatusNormalClosure, err) + }, false) + defer closeFn() + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + opts := &websocket.DialOptions{ + CompressionOptions: websocket.CompressionOptions{ + Mode: websocket.CompressionContextTakeover, + }, + } + opts.HTTPClient = s.Client() + + c, _, err := websocket.Dial(ctx, wsURL(s), opts) + assert.Success(t, "dial", err) + assertJSONEcho(t, ctx, c, 8393) +} + func TestConn(t *testing.T) { t.Parallel() t.Run("json", func(t *testing.T) { s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"echo"}, - InsecureSkipVerify: true, + Subprotocols: []string{"echo"}, CompressionOptions: websocket.CompressionOptions{ - Mode: websocket.CompressionContextTakeover, - Threshold: 1, + Mode: websocket.CompressionContextTakeover, }, }) assert.Success(t, "accept", err) defer c.Close(websocket.StatusInternalError, "") err = echoLoop(r.Context(), c) - t.Logf("server: %v", err) assertCloseStatus(t, websocket.StatusNormalClosure, err) }, false) defer closeFn() - wsURL := strings.Replace(s.URL, "http", "ws", 1) - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() opts := &websocket.DialOptions{ Subprotocols: []string{"echo"}, CompressionOptions: websocket.CompressionOptions{ - Mode: websocket.CompressionContextTakeover, - Threshold: 1, + Mode: websocket.CompressionContextTakeover, }, } opts.HTTPClient = s.Client() - c, _, err := websocket.Dial(ctx, wsURL, opts) + c, _, err := websocket.Dial(ctx, wsURL(s), opts) assert.Success(t, "dial", err) - assertJSONEcho(t, ctx, c, 2) + assertJSONEcho(t, ctx, c, 8393) }) } @@ -149,3 +175,7 @@ func echoLoop(ctx context.Context, c *websocket.Conn) error { } } } + +func wsURL(s *httptest.Server) string { + return strings.Replace(s.URL, "http", "ws", 1) +} diff --git a/internal/errd/wrap.go b/internal/errd/wrap.go index 20de77430ddac87d1187c38625e2be73daefc99d..ed0b775447d005dce617f8a5b961dc6a3684cf3c 100644 --- a/internal/errd/wrap.go +++ b/internal/errd/wrap.go @@ -1,12 +1,42 @@ package errd -import "golang.org/x/xerrors" +import ( + "fmt" + + "golang.org/x/xerrors" +) + +type wrapError struct { + msg string + err error + frame xerrors.Frame +} + +func (e *wrapError) Error() string { + return fmt.Sprint(e) +} + +func (e *wrapError) Format(s fmt.State, v rune) { xerrors.FormatError(e, s, v) } + +func (e *wrapError) FormatError(p xerrors.Printer) (next error) { + p.Print(e.msg) + e.frame.Format(p) + return e.err +} + +func (e *wrapError) Unwrap() error { + return e.err +} // Wrap wraps err with xerrors.Errorf if err is non nil. // Intended for use with defer and a named error return. // Inspired by https://github.com/golang/go/issues/32676. func Wrap(err *error, f string, v ...interface{}) { if *err != nil { - *err = xerrors.Errorf(f+": %w", append(v, *err)...) + *err = &wrapError{ + msg: fmt.Sprintf(f, v...), + err: *err, + frame: xerrors.Caller(1), + } } } diff --git a/write.go b/write.go index a7fa5f5a62bcf76aa141239289d167d446bb8ffe..4a756fa9dee2cb9108b7aa5903985076c3bba563 100644 --- a/write.go +++ b/write.go @@ -145,7 +145,7 @@ func (mw *msgWriter) Write(p []byte) (_ int, err error) { return 0, xerrors.New("cannot use closed writer") } - // TODO can make threshold detection robust across writes by writing to buffer + // TODO can make threshold detection robust across writes by writing to bufio writer if mw.flate || mw.c.flate() && len(p) >= mw.c.flateThreshold { mw.ensureFlate()