From b53f306c00debd46e5ed5debd2f9594ee8889f5c Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Sun, 9 Feb 2020 00:47:34 -0500 Subject: [PATCH] Get Wasm tests working --- accept.go | 7 +- accept_test.go | 55 +++++--- autobahn_test.go | 33 +++-- close.go | 193 ------------------------- close_notjs.go | 199 ++++++++++++++++++++++++++ compress.go | 155 -------------------- compress_notjs.go | 156 +++++++++++++++++++++ compress_test.go | 29 ++-- conn.go | 261 ---------------------------------- conn_notjs.go | 264 +++++++++++++++++++++++++++++++++++ conn_test.go | 131 +++++++---------- dial.go | 7 +- dial_test.go | 2 +- example_test.go | 3 +- frame.go | 2 - internal/test/cmp/cmp.go | 9 ++ internal/test/wstest/echo.go | 90 ++++++++++++ internal/test/wstest/pipe.go | 3 + internal/test/wstest/url.go | 11 ++ internal/xsync/go.go | 25 ++++ internal/xsync/go_test.go | 20 +++ internal/xsync/int64.go | 23 +++ read.go | 19 +-- ws_js.go | 30 ++-- ws_js_test.go | 40 ++++-- 25 files changed, 985 insertions(+), 782 deletions(-) create mode 100644 close_notjs.go create mode 100644 compress_notjs.go create mode 100644 conn_notjs.go create mode 100644 internal/test/wstest/echo.go create mode 100644 internal/test/wstest/url.go create mode 100644 internal/xsync/go.go create mode 100644 internal/xsync/go_test.go create mode 100644 internal/xsync/int64.go diff --git a/accept.go b/accept.go index 0394fa6..31f104b 100644 --- a/accept.go +++ b/accept.go @@ -39,7 +39,7 @@ type AcceptOptions struct { // CompressionOptions controls the compression options. // See docs on the CompressionOptions type. - CompressionOptions CompressionOptions + CompressionOptions *CompressionOptions } // Accept accepts a WebSocket handshake from a client and upgrades the @@ -59,6 +59,11 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Con if opts == nil { opts = &AcceptOptions{} } + opts = &*opts + + if opts.CompressionOptions == nil { + opts.CompressionOptions = &CompressionOptions{} + } err = verifyClientRequest(r) if err != nil { diff --git a/accept_test.go b/accept_test.go index 3e8b1f4..18302da 100644 --- a/accept_test.go +++ b/accept_test.go @@ -10,8 +10,9 @@ import ( "strings" "testing" - "cdr.dev/slog/sloggers/slogtest/assert" "golang.org/x/xerrors" + + "nhooyr.io/websocket/internal/test/cmp" ) func TestAccept(t *testing.T) { @@ -24,7 +25,9 @@ func TestAccept(t *testing.T) { r := httptest.NewRequest("GET", "/", nil) _, err := Accept(w, r, nil) - assert.ErrorContains(t, "Accept", err, "protocol violation") + if !cmp.ErrorContains(err, "protocol violation") { + t.Fatal(err) + } }) t.Run("badOrigin", func(t *testing.T) { @@ -39,7 +42,9 @@ func TestAccept(t *testing.T) { r.Header.Set("Origin", "harhar.com") _, err := Accept(w, r, nil) - assert.ErrorContains(t, "Accept", err, "request Origin \"harhar.com\" is not authorized for Host") + if !cmp.ErrorContains(err, `request Origin "harhar.com" is not authorized for Host`) { + t.Fatal(err) + } }) t.Run("badCompression", func(t *testing.T) { @@ -56,7 +61,9 @@ func TestAccept(t *testing.T) { r.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; harharhar") _, err := Accept(w, r, nil) - assert.ErrorContains(t, "Accept", err, "unsupported permessage-deflate parameter") + if !cmp.ErrorContains(err, `unsupported permessage-deflate parameter`) { + t.Fatal(err) + } }) t.Run("requireHttpHijacker", func(t *testing.T) { @@ -70,7 +77,9 @@ func TestAccept(t *testing.T) { r.Header.Set("Sec-WebSocket-Key", "meow123") _, err := Accept(w, r, nil) - assert.ErrorContains(t, "Accept", err, "http.ResponseWriter does not implement http.Hijacker") + if !cmp.ErrorContains(err, `http.ResponseWriter does not implement http.Hijacker`) { + t.Fatal(err) + } }) t.Run("badHijack", func(t *testing.T) { @@ -90,7 +99,9 @@ func TestAccept(t *testing.T) { r.Header.Set("Sec-WebSocket-Key", "meow123") _, err := Accept(w, r, nil) - assert.ErrorContains(t, "Accept", err, "failed to hijack connection") + if !cmp.ErrorContains(err, `failed to hijack connection`) { + t.Fatal(err) + } }) } @@ -182,10 +193,8 @@ func Test_verifyClientHandshake(t *testing.T) { } err := verifyClientRequest(r) - if tc.success { - assert.Success(t, "verifyClientRequest", err) - } else { - assert.Error(t, "verifyClientRequest", err) + if tc.success != (err == nil) { + t.Fatalf("unexpected error value: %v", err) } }) } @@ -235,7 +244,9 @@ func Test_selectSubprotocol(t *testing.T) { r.Header.Set("Sec-WebSocket-Protocol", strings.Join(tc.clientProtocols, ",")) negotiated := selectSubprotocol(r, tc.serverProtocols) - assert.Equal(t, "negotiated", tc.negotiated, negotiated) + if !cmp.Equal(tc.negotiated, negotiated) { + t.Fatalf("unexpected negotiated: %v", cmp.Diff(tc.negotiated, negotiated)) + } }) } } @@ -289,10 +300,8 @@ func Test_authenticateOrigin(t *testing.T) { r.Header.Set("Origin", tc.origin) err := authenticateOrigin(r) - if tc.success { - assert.Success(t, "authenticateOrigin", err) - } else { - assert.Error(t, "authenticateOrigin", err) + if tc.success != (err == nil) { + t.Fatalf("unexpected error value: %v", err) } }) } @@ -364,13 +373,21 @@ func Test_acceptCompression(t *testing.T) { w := httptest.NewRecorder() copts, err := acceptCompression(r, w, tc.mode) if tc.error { - assert.Error(t, "acceptCompression", err) + if err == nil { + t.Fatalf("expected error: %v", copts) + } return } - assert.Success(t, "acceptCompression", err) - assert.Equal(t, "compresssionOpts", tc.expCopts, copts) - assert.Equal(t, "respHeader", tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions")) + if err != nil { + t.Fatal(err) + } + if !cmp.Equal(tc.expCopts, copts) { + t.Fatalf("unexpected compression options: %v", cmp.Diff(tc.expCopts, copts)) + } + if !cmp.Equal(tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions")) { + t.Fatalf("unexpected respHeader: %v", cmp.Diff(tc.respSecWebSocketExtensions, w.Header().Get("Sec-WebSocket-Extensions"))) + } }) } } diff --git a/autobahn_test.go b/autobahn_test.go index d730cf4..4d0bd1b 100644 --- a/autobahn_test.go +++ b/autobahn_test.go @@ -15,11 +15,11 @@ import ( "testing" "time" - "cdr.dev/slog/sloggers/slogtest/assert" "golang.org/x/xerrors" "nhooyr.io/websocket" "nhooyr.io/websocket/internal/errd" + "nhooyr.io/websocket/internal/test/wstest" ) var excludedAutobahnCases = []string{ @@ -45,14 +45,20 @@ func TestAutobahn(t *testing.T) { defer cancel() wstestURL, closeFn, err := wstestClientServer(ctx) - assert.Success(t, "wstestClient", err) + if err != nil { + t.Fatal(err) + } defer closeFn() err = waitWS(ctx, wstestURL) - assert.Success(t, "waitWS", err) + if err != nil { + t.Fatal(err) + } cases, err := wstestCaseCount(ctx, wstestURL) - assert.Success(t, "wstestCaseCount", err) + if err != nil { + t.Fatal(err) + } t.Run("cases", func(t *testing.T) { for i := 1; i <= cases; i++ { @@ -62,16 +68,19 @@ func TestAutobahn(t *testing.T) { defer cancel() c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/runCase?case=%v&agent=main", i), nil) - assert.Success(t, "autobahn dial", err) - - err = echoLoop(ctx, c) + if err != nil { + t.Fatal(err) + } + err = wstest.EchoLoop(ctx, c) t.Logf("echoLoop: %v", err) }) } }) c, _, err := websocket.Dial(ctx, fmt.Sprintf(wstestURL+"/updateReports?agent=main"), nil) - assert.Success(t, "dial", err) + if err != nil { + t.Fatal(err) + } c.Close(websocket.StatusNormalClosure, "") checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json") @@ -163,14 +172,18 @@ func wstestCaseCount(ctx context.Context, url string) (cases int, err error) { func checkWSTestIndex(t *testing.T, path string) { wstestOut, err := ioutil.ReadFile(path) - assert.Success(t, "ioutil.ReadFile", err) + if err != nil { + t.Fatal(err) + } var indexJSON map[string]map[string]struct { Behavior string `json:"behavior"` BehaviorClose string `json:"behaviorClose"` } err = json.Unmarshal(wstestOut, &indexJSON) - assert.Success(t, "json.Unmarshal", err) + if err != nil { + t.Fatal(err) + } for _, tests := range indexJSON { for test, result := range tests { diff --git a/close.go b/close.go index 931160e..2007323 100644 --- a/close.go +++ b/close.go @@ -1,17 +1,9 @@ -// +build !js - package websocket import ( - "context" - "encoding/binary" "fmt" - "log" - "time" "golang.org/x/xerrors" - - "nhooyr.io/websocket/internal/errd" ) // StatusCode represents a WebSocket status code. @@ -83,188 +75,3 @@ func CloseStatus(err error) StatusCode { } return -1 } - -// Close performs the WebSocket close handshake with the given status code and reason. -// -// It will write a WebSocket close frame with a timeout of 5s and then wait 5s for -// the peer to send a close frame. -// All data messages received from the peer during the close handshake will be discarded. -// -// The connection can only be closed once. Additional calls to Close -// are no-ops. -// -// The maximum length of reason must be 125 bytes. Avoid -// sending a dynamic reason. -// -// Close will unblock all goroutines interacting with the connection once -// complete. -func (c *Conn) Close(code StatusCode, reason string) error { - return c.closeHandshake(code, reason) -} - -func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { - defer errd.Wrap(&err, "failed to close WebSocket") - - err = c.writeClose(code, reason) - if err != nil { - return err - } - - err = c.waitCloseHandshake() - if CloseStatus(err) == -1 { - return err - } - return nil -} - -func (c *Conn) writeError(code StatusCode, err error) { - c.setCloseErr(err) - c.writeClose(code, err.Error()) - c.close(nil) -} - -func (c *Conn) writeClose(code StatusCode, reason string) error { - c.closeMu.Lock() - closing := c.wroteClose - c.wroteClose = true - c.closeMu.Unlock() - if closing { - return xerrors.New("already wrote close") - } - - ce := CloseError{ - Code: code, - Reason: reason, - } - - c.setCloseErr(xerrors.Errorf("sent close frame: %w", ce)) - - var p []byte - if ce.Code != StatusNoStatusRcvd { - p = ce.bytes() - } - - return c.writeControl(context.Background(), opClose, p) -} - -func (c *Conn) waitCloseHandshake() error { - defer c.close(nil) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - - err := c.readMu.Lock(ctx) - if err != nil { - return err - } - defer c.readMu.Unlock() - - if c.readCloseFrameErr != nil { - return c.readCloseFrameErr - } - - for { - h, err := c.readLoop(ctx) - if err != nil { - return err - } - - for i := int64(0); i < h.payloadLength; i++ { - _, err := c.br.ReadByte() - if err != nil { - return err - } - } - } -} - -func parseClosePayload(p []byte) (CloseError, error) { - if len(p) == 0 { - return CloseError{ - Code: StatusNoStatusRcvd, - }, nil - } - - if len(p) < 2 { - return CloseError{}, xerrors.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) - } - - ce := CloseError{ - Code: StatusCode(binary.BigEndian.Uint16(p)), - Reason: string(p[2:]), - } - - if !validWireCloseCode(ce.Code) { - return CloseError{}, xerrors.Errorf("invalid status code %v", ce.Code) - } - - return ce, nil -} - -// See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number -// and https://tools.ietf.org/html/rfc6455#section-7.4.1 -func validWireCloseCode(code StatusCode) bool { - switch code { - case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake: - return false - } - - if code >= StatusNormalClosure && code <= StatusBadGateway { - return true - } - if code >= 3000 && code <= 4999 { - return true - } - - return false -} - -func (ce CloseError) bytes() []byte { - p, err := ce.bytesErr() - if err != nil { - log.Printf("websocket: failed to marshal close frame: %+v", err) - ce = CloseError{ - Code: StatusInternalError, - } - p, _ = ce.bytesErr() - } - return p -} - -const maxCloseReason = maxControlPayload - 2 - -func (ce CloseError) bytesErr() ([]byte, error) { - if len(ce.Reason) > maxCloseReason { - return nil, xerrors.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason)) - } - - if !validWireCloseCode(ce.Code) { - return nil, xerrors.Errorf("status code %v cannot be set", ce.Code) - } - - buf := make([]byte, 2+len(ce.Reason)) - binary.BigEndian.PutUint16(buf, uint16(ce.Code)) - copy(buf[2:], ce.Reason) - return buf, nil -} - -func (c *Conn) setCloseErr(err error) { - c.closeMu.Lock() - c.setCloseErrLocked(err) - c.closeMu.Unlock() -} - -func (c *Conn) setCloseErrLocked(err error) { - if c.closeErr == nil { - c.closeErr = xerrors.Errorf("WebSocket closed: %w", err) - } -} - -func (c *Conn) isClosed() bool { - select { - case <-c.closed: - return true - default: - return false - } -} diff --git a/close_notjs.go b/close_notjs.go new file mode 100644 index 0000000..dd1b0e0 --- /dev/null +++ b/close_notjs.go @@ -0,0 +1,199 @@ +// +build !js + +package websocket + +import ( + "context" + "encoding/binary" + "log" + "time" + + "golang.org/x/xerrors" + + "nhooyr.io/websocket/internal/errd" +) + +// Close performs the WebSocket close handshake with the given status code and reason. +// +// It will write a WebSocket close frame with a timeout of 5s and then wait 5s for +// the peer to send a close frame. +// All data messages received from the peer during the close handshake will be discarded. +// +// The connection can only be closed once. Additional calls to Close +// are no-ops. +// +// The maximum length of reason must be 125 bytes. Avoid +// sending a dynamic reason. +// +// Close will unblock all goroutines interacting with the connection once +// complete. +func (c *Conn) Close(code StatusCode, reason string) error { + return c.closeHandshake(code, reason) +} + +func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { + defer errd.Wrap(&err, "failed to close WebSocket") + + err = c.writeClose(code, reason) + if err != nil { + return err + } + + err = c.waitCloseHandshake() + if CloseStatus(err) == -1 { + return err + } + return nil +} + +func (c *Conn) writeError(code StatusCode, err error) { + c.setCloseErr(err) + c.writeClose(code, err.Error()) + c.close(nil) +} + +func (c *Conn) writeClose(code StatusCode, reason string) error { + c.closeMu.Lock() + closing := c.wroteClose + c.wroteClose = true + c.closeMu.Unlock() + if closing { + return xerrors.New("already wrote close") + } + + ce := CloseError{ + Code: code, + Reason: reason, + } + + c.setCloseErr(xerrors.Errorf("sent close frame: %w", ce)) + + var p []byte + if ce.Code != StatusNoStatusRcvd { + p = ce.bytes() + } + + return c.writeControl(context.Background(), opClose, p) +} + +func (c *Conn) waitCloseHandshake() error { + defer c.close(nil) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + err := c.readMu.Lock(ctx) + if err != nil { + return err + } + defer c.readMu.Unlock() + + if c.readCloseFrameErr != nil { + return c.readCloseFrameErr + } + + for { + h, err := c.readLoop(ctx) + if err != nil { + return err + } + + for i := int64(0); i < h.payloadLength; i++ { + _, err := c.br.ReadByte() + if err != nil { + return err + } + } + } +} + +func parseClosePayload(p []byte) (CloseError, error) { + if len(p) == 0 { + return CloseError{ + Code: StatusNoStatusRcvd, + }, nil + } + + if len(p) < 2 { + return CloseError{}, xerrors.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) + } + + ce := CloseError{ + Code: StatusCode(binary.BigEndian.Uint16(p)), + Reason: string(p[2:]), + } + + if !validWireCloseCode(ce.Code) { + return CloseError{}, xerrors.Errorf("invalid status code %v", ce.Code) + } + + return ce, nil +} + +// See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number +// and https://tools.ietf.org/html/rfc6455#section-7.4.1 +func validWireCloseCode(code StatusCode) bool { + switch code { + case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake: + return false + } + + if code >= StatusNormalClosure && code <= StatusBadGateway { + return true + } + if code >= 3000 && code <= 4999 { + return true + } + + return false +} + +func (ce CloseError) bytes() []byte { + p, err := ce.bytesErr() + if err != nil { + log.Printf("websocket: failed to marshal close frame: %v", err) + ce = CloseError{ + Code: StatusInternalError, + } + p, _ = ce.bytesErr() + } + return p +} + +const maxCloseReason = maxControlPayload - 2 + +func (ce CloseError) bytesErr() ([]byte, error) { + if len(ce.Reason) > maxCloseReason { + return nil, xerrors.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason)) + } + + if !validWireCloseCode(ce.Code) { + return nil, xerrors.Errorf("status code %v cannot be set", ce.Code) + } + + buf := make([]byte, 2+len(ce.Reason)) + binary.BigEndian.PutUint16(buf, uint16(ce.Code)) + copy(buf[2:], ce.Reason) + return buf, nil +} + +func (c *Conn) setCloseErr(err error) { + c.closeMu.Lock() + c.setCloseErrLocked(err) + c.closeMu.Unlock() +} + +func (c *Conn) setCloseErrLocked(err error) { + if c.closeErr == nil { + c.closeErr = xerrors.Errorf("WebSocket closed: %w", err) + } +} + +func (c *Conn) isClosed() bool { + select { + case <-c.closed: + return true + default: + return false + } +} diff --git a/compress.go b/compress.go index efd89b3..918b3b4 100644 --- a/compress.go +++ b/compress.go @@ -1,14 +1,5 @@ -// +build !js - package websocket -import ( - "compress/flate" - "io" - "net/http" - "sync" -) - // CompressionOptions represents the available deflate extension options. // See https://tools.ietf.org/html/rfc7692 type CompressionOptions struct { @@ -60,149 +51,3 @@ const ( // important than bandwidth. CompressionDisabled ) - -func (m CompressionMode) opts() *compressionOptions { - if m == CompressionDisabled { - return nil - } - return &compressionOptions{ - clientNoContextTakeover: m == CompressionNoContextTakeover, - serverNoContextTakeover: m == CompressionNoContextTakeover, - } -} - -type compressionOptions struct { - clientNoContextTakeover bool - serverNoContextTakeover bool -} - -func (copts *compressionOptions) setHeader(h http.Header) { - s := "permessage-deflate" - if copts.clientNoContextTakeover { - s += "; client_no_context_takeover" - } - if copts.serverNoContextTakeover { - s += "; server_no_context_takeover" - } - h.Set("Sec-WebSocket-Extensions", s) -} - -// These bytes are required to get flate.Reader to return. -// They are removed when sending to avoid the overhead as -// WebSocket framing tell's when the message has ended but then -// we need to add them back otherwise flate.Reader keeps -// trying to return more bytes. -const deflateMessageTail = "\x00\x00\xff\xff" - -func (c *Conn) writeNoContextTakeOver() bool { - return c.client && c.copts.clientNoContextTakeover || !c.client && c.copts.serverNoContextTakeover -} - -func (c *Conn) readNoContextTakeOver() bool { - return !c.client && c.copts.clientNoContextTakeover || c.client && c.copts.serverNoContextTakeover -} - -type trimLastFourBytesWriter struct { - w io.Writer - tail []byte -} - -func (tw *trimLastFourBytesWriter) reset() { - tw.tail = tw.tail[:0] -} - -func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) { - extra := len(tw.tail) + len(p) - 4 - - if extra <= 0 { - tw.tail = append(tw.tail, p...) - return len(p), nil - } - - // Now we need to write as many extra bytes as we can from the previous tail. - if extra > len(tw.tail) { - extra = len(tw.tail) - } - if extra > 0 { - _, err := tw.w.Write(tw.tail[:extra]) - if err != nil { - return 0, err - } - tw.tail = tw.tail[extra:] - } - - // If p is less than or equal to 4 bytes, - // all of it is is part of the tail. - if len(p) <= 4 { - tw.tail = append(tw.tail, p...) - return len(p), nil - } - - // Otherwise, only the last 4 bytes are. - tw.tail = append(tw.tail, p[len(p)-4:]...) - - p = p[:len(p)-4] - n, err := tw.w.Write(p) - return n + 4, err -} - -var flateReaderPool sync.Pool - -func getFlateReader(r io.Reader, dict []byte) io.Reader { - fr, ok := flateReaderPool.Get().(io.Reader) - if !ok { - return flate.NewReaderDict(r, dict) - } - fr.(flate.Resetter).Reset(r, dict) - return fr -} - -func putFlateReader(fr io.Reader) { - flateReaderPool.Put(fr) -} - -var flateWriterPool sync.Pool - -func getFlateWriter(w io.Writer) *flate.Writer { - fw, ok := flateWriterPool.Get().(*flate.Writer) - if !ok { - fw, _ = flate.NewWriter(w, flate.BestSpeed) - return fw - } - fw.Reset(w) - return fw -} - -func putFlateWriter(w *flate.Writer) { - flateWriterPool.Put(w) -} - -type slidingWindow struct { - r io.Reader - buf []byte -} - -func newSlidingWindow(n int) *slidingWindow { - return &slidingWindow{ - buf: make([]byte, 0, n), - } -} - -func (w *slidingWindow) write(p []byte) { - if len(p) >= cap(w.buf) { - w.buf = w.buf[:cap(w.buf)] - p = p[len(p)-cap(w.buf):] - copy(w.buf, p) - return - } - - left := cap(w.buf) - len(w.buf) - if left < len(p) { - // We need to shift spaceNeeded bytes from the end to make room for p at the end. - spaceNeeded := len(p) - left - copy(w.buf, w.buf[spaceNeeded:]) - w.buf = w.buf[:len(w.buf)-spaceNeeded] - } - - w.buf = append(w.buf, p...) -} diff --git a/compress_notjs.go b/compress_notjs.go new file mode 100644 index 0000000..8bc2f87 --- /dev/null +++ b/compress_notjs.go @@ -0,0 +1,156 @@ +// +build !js + +package websocket + +import ( + "compress/flate" + "io" + "net/http" + "sync" +) + +func (m CompressionMode) opts() *compressionOptions { + if m == CompressionDisabled { + return nil + } + return &compressionOptions{ + clientNoContextTakeover: m == CompressionNoContextTakeover, + serverNoContextTakeover: m == CompressionNoContextTakeover, + } +} + +type compressionOptions struct { + clientNoContextTakeover bool + serverNoContextTakeover bool +} + +func (copts *compressionOptions) setHeader(h http.Header) { + s := "permessage-deflate" + if copts.clientNoContextTakeover { + s += "; client_no_context_takeover" + } + if copts.serverNoContextTakeover { + s += "; server_no_context_takeover" + } + h.Set("Sec-WebSocket-Extensions", s) +} + +// These bytes are required to get flate.Reader to return. +// They are removed when sending to avoid the overhead as +// WebSocket framing tell's when the message has ended but then +// we need to add them back otherwise flate.Reader keeps +// trying to return more bytes. +const deflateMessageTail = "\x00\x00\xff\xff" + +func (c *Conn) writeNoContextTakeOver() bool { + return c.client && c.copts.clientNoContextTakeover || !c.client && c.copts.serverNoContextTakeover +} + +func (c *Conn) readNoContextTakeOver() bool { + return !c.client && c.copts.clientNoContextTakeover || c.client && c.copts.serverNoContextTakeover +} + +type trimLastFourBytesWriter struct { + w io.Writer + tail []byte +} + +func (tw *trimLastFourBytesWriter) reset() { + tw.tail = tw.tail[:0] +} + +func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) { + extra := len(tw.tail) + len(p) - 4 + + if extra <= 0 { + tw.tail = append(tw.tail, p...) + return len(p), nil + } + + // Now we need to write as many extra bytes as we can from the previous tail. + if extra > len(tw.tail) { + extra = len(tw.tail) + } + if extra > 0 { + _, err := tw.w.Write(tw.tail[:extra]) + if err != nil { + return 0, err + } + tw.tail = tw.tail[extra:] + } + + // If p is less than or equal to 4 bytes, + // all of it is is part of the tail. + if len(p) <= 4 { + tw.tail = append(tw.tail, p...) + return len(p), nil + } + + // Otherwise, only the last 4 bytes are. + tw.tail = append(tw.tail, p[len(p)-4:]...) + + p = p[:len(p)-4] + n, err := tw.w.Write(p) + return n + 4, err +} + +var flateReaderPool sync.Pool + +func getFlateReader(r io.Reader, dict []byte) io.Reader { + fr, ok := flateReaderPool.Get().(io.Reader) + if !ok { + return flate.NewReaderDict(r, dict) + } + fr.(flate.Resetter).Reset(r, dict) + return fr +} + +func putFlateReader(fr io.Reader) { + flateReaderPool.Put(fr) +} + +var flateWriterPool sync.Pool + +func getFlateWriter(w io.Writer) *flate.Writer { + fw, ok := flateWriterPool.Get().(*flate.Writer) + if !ok { + fw, _ = flate.NewWriter(w, flate.BestSpeed) + return fw + } + fw.Reset(w) + return fw +} + +func putFlateWriter(w *flate.Writer) { + flateWriterPool.Put(w) +} + +type slidingWindow struct { + r io.Reader + buf []byte +} + +func newSlidingWindow(n int) *slidingWindow { + return &slidingWindow{ + buf: make([]byte, 0, n), + } +} + +func (w *slidingWindow) write(p []byte) { + if len(p) >= cap(w.buf) { + w.buf = w.buf[:cap(w.buf)] + p = p[len(p)-cap(w.buf):] + copy(w.buf, p) + return + } + + left := cap(w.buf) - len(w.buf) + if left < len(p) { + // We need to shift spaceNeeded bytes from the end to make room for p at the end. + spaceNeeded := len(p) - left + copy(w.buf, w.buf[spaceNeeded:]) + w.buf = w.buf[:len(w.buf)-spaceNeeded] + } + + w.buf = append(w.buf, p...) +} diff --git a/compress_test.go b/compress_test.go index 15d334d..51f658c 100644 --- a/compress_test.go +++ b/compress_test.go @@ -1,11 +1,11 @@ +// +build !js + package websocket import ( "strings" "testing" - "cdr.dev/slog/sloggers/slogtest/assert" - "nhooyr.io/websocket/internal/test/xrand" ) @@ -15,14 +15,21 @@ func Test_slidingWindow(t *testing.T) { const testCount = 99 const maxWindow = 99999 for i := 0; i < testCount; i++ { - input := xrand.String(maxWindow) - windowLength := xrand.Int(maxWindow) - r := newSlidingWindow(windowLength) - r.write([]byte(input)) - - if cap(r.buf) != windowLength { - t.Fatalf("sliding window length changed somehow: %q and windowLength %d", input, windowLength) - } - assert.True(t, "hasSuffix", strings.HasSuffix(input, string(r.buf))) + t.Run("", func(t *testing.T) { + t.Parallel() + + input := xrand.String(maxWindow) + windowLength := xrand.Int(maxWindow) + r := newSlidingWindow(windowLength) + r.write([]byte(input)) + + if cap(r.buf) != windowLength { + t.Fatalf("sliding window length changed somehow: %q and windowLength %d", input, windowLength) + } + + if !strings.HasSuffix(input, string(r.buf)) { + t.Fatalf("r.buf is not a suffix of input: %q and %q", input, r.buf) + } + }) } } diff --git a/conn.go b/conn.go index 163802b..e58a874 100644 --- a/conn.go +++ b/conn.go @@ -2,18 +2,6 @@ package websocket -import ( - "bufio" - "context" - "io" - "runtime" - "strconv" - "sync" - "sync/atomic" - - "golang.org/x/xerrors" -) - // MessageType represents the type of a WebSocket message. // See https://tools.ietf.org/html/rfc6455#section-5.6 type MessageType int @@ -25,252 +13,3 @@ const ( // MessageBinary is for binary messages like protobufs. MessageBinary ) - -// Conn represents a WebSocket connection. -// All methods may be called concurrently except for Reader and Read. -// -// You must always read from the connection. Otherwise control -// frames will not be handled. See Reader and CloseRead. -// -// Be sure to call Close on the connection when you -// are finished with it to release associated resources. -// -// On any error from any method, the connection is closed -// with an appropriate reason. -type Conn struct { - subprotocol string - rwc io.ReadWriteCloser - client bool - copts *compressionOptions - flateThreshold int - br *bufio.Reader - bw *bufio.Writer - - readTimeout chan context.Context - writeTimeout chan context.Context - - // Read state. - readMu *mu - readControlBuf [maxControlPayload]byte - msgReader *msgReader - readCloseFrameErr error - - // Write state. - msgWriter *msgWriter - writeFrameMu *mu - writeBuf []byte - writeHeader header - - closed chan struct{} - closeMu sync.Mutex - closeErr error - wroteClose bool - - pingCounter int32 - activePingsMu sync.Mutex - activePings map[string]chan<- struct{} -} - -type connConfig struct { - subprotocol string - rwc io.ReadWriteCloser - client bool - copts *compressionOptions - flateThreshold int - - br *bufio.Reader - bw *bufio.Writer -} - -func newConn(cfg connConfig) *Conn { - c := &Conn{ - subprotocol: cfg.subprotocol, - rwc: cfg.rwc, - client: cfg.client, - copts: cfg.copts, - flateThreshold: cfg.flateThreshold, - - br: cfg.br, - bw: cfg.bw, - - readTimeout: make(chan context.Context), - writeTimeout: make(chan context.Context), - - closed: make(chan struct{}), - activePings: make(map[string]chan<- struct{}), - } - if c.flate() && c.flateThreshold == 0 { - c.flateThreshold = 256 - if c.writeNoContextTakeOver() { - c.flateThreshold = 512 - } - } - - c.readMu = newMu(c) - c.writeFrameMu = newMu(c) - - c.msgReader = newMsgReader(c) - - c.msgWriter = newMsgWriter(c) - if c.client { - c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc) - } - - runtime.SetFinalizer(c, func(c *Conn) { - c.close(xerrors.New("connection garbage collected")) - }) - - go c.timeoutLoop() - - return c -} - -// Subprotocol returns the negotiated subprotocol. -// An empty string means the default protocol. -func (c *Conn) Subprotocol() string { - return c.subprotocol -} - -func (c *Conn) close(err error) { - c.closeMu.Lock() - defer c.closeMu.Unlock() - - if c.isClosed() { - return - } - close(c.closed) - runtime.SetFinalizer(c, nil) - c.setCloseErrLocked(err) - - // Have to close after c.closed is closed to ensure any goroutine that wakes up - // from the connection being closed also sees that c.closed is closed and returns - // closeErr. - c.rwc.Close() - - go func() { - if c.client { - c.writeFrameMu.Lock(context.Background()) - putBufioWriter(c.bw) - } - c.msgWriter.close() - - c.msgReader.close() - if c.client { - putBufioReader(c.br) - } - }() -} - -func (c *Conn) timeoutLoop() { - readCtx := context.Background() - writeCtx := context.Background() - - for { - select { - case <-c.closed: - return - - case writeCtx = <-c.writeTimeout: - case readCtx = <-c.readTimeout: - - case <-readCtx.Done(): - c.setCloseErr(xerrors.Errorf("read timed out: %w", readCtx.Err())) - go c.writeError(StatusPolicyViolation, xerrors.New("timed out")) - case <-writeCtx.Done(): - c.close(xerrors.Errorf("write timed out: %w", writeCtx.Err())) - return - } - } -} - -func (c *Conn) flate() bool { - return c.copts != nil -} - -// Ping sends a ping to the peer and waits for a pong. -// Use this to measure latency or ensure the peer is responsive. -// Ping must be called concurrently with Reader as it does -// not read from the connection but instead waits for a Reader call -// to read the pong. -// -// TCP Keepalives should suffice for most use cases. -func (c *Conn) Ping(ctx context.Context) error { - p := atomic.AddInt32(&c.pingCounter, 1) - - err := c.ping(ctx, strconv.Itoa(int(p))) - if err != nil { - return xerrors.Errorf("failed to ping: %w", err) - } - return nil -} - -func (c *Conn) ping(ctx context.Context, p string) error { - pong := make(chan struct{}) - - c.activePingsMu.Lock() - c.activePings[p] = pong - c.activePingsMu.Unlock() - - defer func() { - c.activePingsMu.Lock() - delete(c.activePings, p) - c.activePingsMu.Unlock() - }() - - err := c.writeControl(ctx, opPing, []byte(p)) - if err != nil { - return err - } - - select { - case <-c.closed: - return c.closeErr - case <-ctx.Done(): - err := xerrors.Errorf("failed to wait for pong: %w", ctx.Err()) - c.close(err) - return err - case <-pong: - return nil - } -} - -type mu struct { - c *Conn - ch chan struct{} -} - -func newMu(c *Conn) *mu { - return &mu{ - c: c, - ch: make(chan struct{}, 1), - } -} - -func (m *mu) Lock(ctx context.Context) error { - select { - case <-m.c.closed: - return m.c.closeErr - case <-ctx.Done(): - err := xerrors.Errorf("failed to acquire lock: %w", ctx.Err()) - m.c.close(err) - return err - case m.ch <- struct{}{}: - return nil - } -} - -func (m *mu) TryLock() bool { - select { - case m.ch <- struct{}{}: - return true - default: - return false - } -} - -func (m *mu) Unlock() { - select { - case <-m.ch: - default: - } -} diff --git a/conn_notjs.go b/conn_notjs.go new file mode 100644 index 0000000..d2fea4d --- /dev/null +++ b/conn_notjs.go @@ -0,0 +1,264 @@ +// +build !js + +package websocket + +import ( + "bufio" + "context" + "io" + "runtime" + "strconv" + "sync" + "sync/atomic" + + "golang.org/x/xerrors" +) + +// Conn represents a WebSocket connection. +// All methods may be called concurrently except for Reader and Read. +// +// You must always read from the connection. Otherwise control +// frames will not be handled. See Reader and CloseRead. +// +// Be sure to call Close on the connection when you +// are finished with it to release associated resources. +// +// On any error from any method, the connection is closed +// with an appropriate reason. +type Conn struct { + subprotocol string + rwc io.ReadWriteCloser + client bool + copts *compressionOptions + flateThreshold int + br *bufio.Reader + bw *bufio.Writer + + readTimeout chan context.Context + writeTimeout chan context.Context + + // Read state. + readMu *mu + readControlBuf [maxControlPayload]byte + msgReader *msgReader + readCloseFrameErr error + + // Write state. + msgWriter *msgWriter + writeFrameMu *mu + writeBuf []byte + writeHeader header + + closed chan struct{} + closeMu sync.Mutex + closeErr error + wroteClose bool + + pingCounter int32 + activePingsMu sync.Mutex + activePings map[string]chan<- struct{} +} + +type connConfig struct { + subprotocol string + rwc io.ReadWriteCloser + client bool + copts *compressionOptions + flateThreshold int + + br *bufio.Reader + bw *bufio.Writer +} + +func newConn(cfg connConfig) *Conn { + c := &Conn{ + subprotocol: cfg.subprotocol, + rwc: cfg.rwc, + client: cfg.client, + copts: cfg.copts, + flateThreshold: cfg.flateThreshold, + + br: cfg.br, + bw: cfg.bw, + + readTimeout: make(chan context.Context), + writeTimeout: make(chan context.Context), + + closed: make(chan struct{}), + activePings: make(map[string]chan<- struct{}), + } + if c.flate() && c.flateThreshold == 0 { + c.flateThreshold = 256 + if c.writeNoContextTakeOver() { + c.flateThreshold = 512 + } + } + + c.readMu = newMu(c) + c.writeFrameMu = newMu(c) + + c.msgReader = newMsgReader(c) + + c.msgWriter = newMsgWriter(c) + if c.client { + c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc) + } + + runtime.SetFinalizer(c, func(c *Conn) { + c.close(xerrors.New("connection garbage collected")) + }) + + go c.timeoutLoop() + + return c +} + +// Subprotocol returns the negotiated subprotocol. +// An empty string means the default protocol. +func (c *Conn) Subprotocol() string { + return c.subprotocol +} + +func (c *Conn) close(err error) { + c.closeMu.Lock() + defer c.closeMu.Unlock() + + if c.isClosed() { + return + } + close(c.closed) + runtime.SetFinalizer(c, nil) + c.setCloseErrLocked(err) + + // Have to close after c.closed is closed to ensure any goroutine that wakes up + // from the connection being closed also sees that c.closed is closed and returns + // closeErr. + c.rwc.Close() + + go func() { + if c.client { + c.writeFrameMu.Lock(context.Background()) + putBufioWriter(c.bw) + } + c.msgWriter.close() + + c.msgReader.close() + if c.client { + putBufioReader(c.br) + } + }() +} + +func (c *Conn) timeoutLoop() { + readCtx := context.Background() + writeCtx := context.Background() + + for { + select { + case <-c.closed: + return + + case writeCtx = <-c.writeTimeout: + case readCtx = <-c.readTimeout: + + case <-readCtx.Done(): + c.setCloseErr(xerrors.Errorf("read timed out: %w", readCtx.Err())) + go c.writeError(StatusPolicyViolation, xerrors.New("timed out")) + case <-writeCtx.Done(): + c.close(xerrors.Errorf("write timed out: %w", writeCtx.Err())) + return + } + } +} + +func (c *Conn) flate() bool { + return c.copts != nil +} + +// Ping sends a ping to the peer and waits for a pong. +// Use this to measure latency or ensure the peer is responsive. +// Ping must be called concurrently with Reader as it does +// not read from the connection but instead waits for a Reader call +// to read the pong. +// +// TCP Keepalives should suffice for most use cases. +func (c *Conn) Ping(ctx context.Context) error { + p := atomic.AddInt32(&c.pingCounter, 1) + + err := c.ping(ctx, strconv.Itoa(int(p))) + if err != nil { + return xerrors.Errorf("failed to ping: %w", err) + } + return nil +} + +func (c *Conn) ping(ctx context.Context, p string) error { + pong := make(chan struct{}) + + c.activePingsMu.Lock() + c.activePings[p] = pong + c.activePingsMu.Unlock() + + defer func() { + c.activePingsMu.Lock() + delete(c.activePings, p) + c.activePingsMu.Unlock() + }() + + err := c.writeControl(ctx, opPing, []byte(p)) + if err != nil { + return err + } + + select { + case <-c.closed: + return c.closeErr + case <-ctx.Done(): + err := xerrors.Errorf("failed to wait for pong: %w", ctx.Err()) + c.close(err) + return err + case <-pong: + return nil + } +} + +type mu struct { + c *Conn + ch chan struct{} +} + +func newMu(c *Conn) *mu { + return &mu{ + c: c, + ch: make(chan struct{}, 1), + } +} + +func (m *mu) Lock(ctx context.Context) error { + select { + case <-m.c.closed: + return m.c.closeErr + case <-ctx.Done(): + err := xerrors.Errorf("failed to acquire lock: %w", ctx.Err()) + m.c.close(err) + return err + case m.ch <- struct{}{}: + return nil + } +} + +func (m *mu) TryLock() bool { + select { + case m.ch <- struct{}{}: + return true + default: + return false + } +} + +func (m *mu) Unlock() { + select { + case <-m.ch: + default: + } +} diff --git a/conn_test.go b/conn_test.go index 02606ef..5c817a2 100644 --- a/conn_test.go +++ b/conn_test.go @@ -4,33 +4,23 @@ package websocket_test import ( "context" - "io" + "fmt" + "net/http" + "net/http/httptest" + "os" + "os/exec" + "sync" "testing" "time" "golang.org/x/xerrors" "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/test/cmp" "nhooyr.io/websocket/internal/test/wstest" "nhooyr.io/websocket/internal/test/xrand" + "nhooyr.io/websocket/internal/xsync" ) -func goFn(fn func() error) chan error { - errs := make(chan error) - go func() { - defer func() { - r := recover() - if r != nil { - errs <- xerrors.Errorf("panic in gofn: %v", r) - } - }() - errs <- fn() - }() - - return errs -} - func TestConn(t *testing.T) { t.Parallel() @@ -44,7 +34,7 @@ func TestConn(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() - copts := websocket.CompressionOptions{ + copts := &websocket.CompressionOptions{ Mode: websocket.CompressionMode(xrand.Int(int(websocket.CompressionDisabled) + 1)), Threshold: xrand.Int(9999), } @@ -60,8 +50,8 @@ func TestConn(t *testing.T) { defer c1.Close(websocket.StatusInternalError, "") defer c2.Close(websocket.StatusInternalError, "") - echoLoopErr := goFn(func() error { - err := echoLoop(ctx, c1) + echoLoopErr := xsync.Go(func() error { + err := wstest.EchoLoop(ctx, c1) return assertCloseStatus(websocket.StatusNormalClosure, err) }) defer func() { @@ -72,39 +62,13 @@ func TestConn(t *testing.T) { }() defer cancel() - c2.SetReadLimit(1 << 30) + c2.SetReadLimit(131072) for i := 0; i < 5; i++ { - n := xrand.Int(131_072) - - msg := xrand.Bytes(n) - - expType := websocket.MessageBinary - if xrand.Bool() { - expType = websocket.MessageText - } - - writeErr := goFn(func() error { - return c2.Write(ctx, expType, msg) - }) - - actType, act, err := c2.Read(ctx) - if err != nil { - t.Fatal(err) - } - - err = <-writeErr + err := wstest.Echo(ctx, c2, 131072) if err != nil { t.Fatal(err) } - - if expType != actType { - t.Fatalf("unexpected message typ (%v): %v", expType, actType) - } - - if !cmp.Equal(msg, act) { - t.Fatalf("unexpected msg read: %v", cmp.Diff(msg, act)) - } } c2.Close(websocket.StatusNormalClosure, "") @@ -113,47 +77,50 @@ func TestConn(t *testing.T) { }) } -func assertCloseStatus(exp websocket.StatusCode, err error) error { - if websocket.CloseStatus(err) == -1 { - return xerrors.Errorf("expected websocket.CloseError: %T %v", err, err) - } - if websocket.CloseStatus(err) != exp { - return xerrors.Errorf("unexpected close status (%v):%v", exp, err) - } - return nil -} - -// echoLoop echos every msg received from c until an error -// occurs or the context expires. -// The read limit is set to 1 << 30. -func echoLoop(ctx context.Context, c *websocket.Conn) error { - defer c.Close(websocket.StatusInternalError, "") - - c.SetReadLimit(1 << 30) +func TestWasm(t *testing.T) { + t.Parallel() - ctx, cancel := context.WithTimeout(ctx, time.Minute) - defer cancel() + var wg sync.WaitGroup + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wg.Add(1) + defer wg.Done() - b := make([]byte, 32<<10) - for { - typ, r, err := c.Reader(ctx) + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"echo"}, + InsecureSkipVerify: true, + }) if err != nil { - return err + t.Error(err) + return } + defer c.Close(websocket.StatusInternalError, "") - w, err := c.Writer(ctx, typ) - if err != nil { - return err + err = wstest.EchoLoop(r.Context(), c) + if websocket.CloseStatus(err) != websocket.StatusNormalClosure { + t.Errorf("echoLoop: %v", err) } + })) + defer wg.Wait() + defer s.Close() - _, err = io.CopyBuffer(w, r, b) - if err != nil { - return err - } + ctx, cancel := context.WithTimeout(context.Background(), time.Second*20) + defer cancel() - err = w.Close() - if err != nil { - return err - } + cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", "./...") + cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", wstest.URL(s))) + + b, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("wasm test binary failed: %v:\n%s", err, b) } } + +func assertCloseStatus(exp websocket.StatusCode, err error) error { + if websocket.CloseStatus(err) == -1 { + return xerrors.Errorf("expected websocket.CloseError: %T %v", err, err) + } + if websocket.CloseStatus(err) != exp { + return xerrors.Errorf("unexpected close status (%v):%v", exp, err) + } + return nil +} diff --git a/dial.go b/dial.go index a1509ab..3e2042e 100644 --- a/dial.go +++ b/dial.go @@ -35,8 +35,7 @@ type DialOptions struct { // CompressionOptions controls the compression options. // See docs on the CompressionOptions type. - // TODO make * - CompressionOptions CompressionOptions + CompressionOptions *CompressionOptions } // Dial performs a WebSocket handshake on url. @@ -60,6 +59,7 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( if opts == nil { opts = &DialOptions{} } + opts = &*opts if opts.HTTPClient == nil { opts.HTTPClient = http.DefaultClient @@ -67,6 +67,9 @@ func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) ( if opts.HTTPHeader == nil { opts.HTTPHeader = http.Header{} } + if opts.CompressionOptions == nil { + opts.CompressionOptions = &CompressionOptions{} + } secWebSocketKey, err := secWebSocketKey(rand) if err != nil { diff --git a/dial_test.go b/dial_test.go index 3be5220..e38e8f1 100644 --- a/dial_test.go +++ b/dial_test.go @@ -223,7 +223,7 @@ func Test_verifyServerHandshake(t *testing.T) { } _, err = verifyServerResponse(opts, key, resp) if (err == nil) != tc.success { - t.Fatalf("unexpected error: %+v", err) + t.Fatalf("unexpected error: %v", err) } }) } diff --git a/example_test.go b/example_test.go index 1842b76..075107b 100644 --- a/example_test.go +++ b/example_test.go @@ -74,8 +74,7 @@ func ExampleCloseStatus() { _, _, err = c.Reader(ctx) if websocket.CloseStatus(err) != websocket.StatusNormalClosure { - log.Fatalf("expected to be disconnected with StatusNormalClosure but got: %+v", err) - return + log.Fatalf("expected to be disconnected with StatusNormalClosure but got: %v", err) } } diff --git a/frame.go b/frame.go index 47ff40f..0257835 100644 --- a/frame.go +++ b/frame.go @@ -1,5 +1,3 @@ -// +build !js - package websocket import ( diff --git a/internal/test/cmp/cmp.go b/internal/test/cmp/cmp.go index d0eee6d..cdbadf7 100644 --- a/internal/test/cmp/cmp.go +++ b/internal/test/cmp/cmp.go @@ -2,6 +2,7 @@ package cmp import ( "reflect" + "strings" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" @@ -20,3 +21,11 @@ func Diff(v1, v2 interface{}) string { return true })) } + +// ErrorContains returns whether err.Error() contains sub. +func ErrorContains(err error, sub string) bool { + if err == nil { + return false + } + return strings.Contains(err.Error(), sub) +} diff --git a/internal/test/wstest/echo.go b/internal/test/wstest/echo.go new file mode 100644 index 0000000..70b2ba5 --- /dev/null +++ b/internal/test/wstest/echo.go @@ -0,0 +1,90 @@ +package wstest + +import ( + "context" + "io" + "time" + + "golang.org/x/xerrors" + + "nhooyr.io/websocket" + "nhooyr.io/websocket/internal/test/cmp" + "nhooyr.io/websocket/internal/test/xrand" + "nhooyr.io/websocket/internal/xsync" +) + +// EchoLoop echos every msg received from c until an error +// occurs or the context expires. +// The read limit is set to 1 << 30. +func EchoLoop(ctx context.Context, c *websocket.Conn) error { + defer c.Close(websocket.StatusInternalError, "") + + c.SetReadLimit(1 << 30) + + ctx, cancel := context.WithTimeout(ctx, time.Minute) + defer cancel() + + b := make([]byte, 32<<10) + for { + typ, r, err := c.Reader(ctx) + if err != nil { + return err + } + + w, err := c.Writer(ctx, typ) + if err != nil { + return err + } + + _, err = io.CopyBuffer(w, r, b) + if err != nil { + return err + } + + err = w.Close() + if err != nil { + return err + } + } +} + +// Echo writes a message and ensures the same is sent back on c. +func Echo(ctx context.Context, c *websocket.Conn, max int) error { + expType := websocket.MessageBinary + if xrand.Bool() { + expType = websocket.MessageText + } + + msg := randMessage(expType, xrand.Int(max)) + + writeErr := xsync.Go(func() error { + return c.Write(ctx, expType, msg) + }) + + actType, act, err := c.Read(ctx) + if err != nil { + return err + } + + err = <-writeErr + if err != nil { + return err + } + + if expType != actType { + return xerrors.Errorf("unexpected message typ (%v): %v", expType, actType) + } + + if !cmp.Equal(msg, act) { + return xerrors.Errorf("unexpected msg read: %v", cmp.Diff(msg, act)) + } + + return nil +} + +func randMessage(typ websocket.MessageType, n int) []byte { + if typ == websocket.MessageBinary { + return xrand.Bytes(n) + } + return []byte(xrand.String(n)) +} diff --git a/internal/test/wstest/pipe.go b/internal/test/wstest/pipe.go index e958aea..81705a8 100644 --- a/internal/test/wstest/pipe.go +++ b/internal/test/wstest/pipe.go @@ -1,3 +1,5 @@ +// +build !js + package wstest import ( @@ -30,6 +32,7 @@ func Pipe(dialOpts *websocket.DialOptions, acceptOpts *websocket.AcceptOptions) if dialOpts == nil { dialOpts = &websocket.DialOptions{} } + dialOpts = &*dialOpts dialOpts.HTTPClient = &http.Client{ Transport: tt, } diff --git a/internal/test/wstest/url.go b/internal/test/wstest/url.go new file mode 100644 index 0000000..a11c61b --- /dev/null +++ b/internal/test/wstest/url.go @@ -0,0 +1,11 @@ +package wstest + +import ( + "net/http/httptest" + "strings" +) + +// URL returns the ws url for s. +func URL(s *httptest.Server) string { + return strings.Replace(s.URL, "http", "ws", 1) +} diff --git a/internal/xsync/go.go b/internal/xsync/go.go new file mode 100644 index 0000000..96cf810 --- /dev/null +++ b/internal/xsync/go.go @@ -0,0 +1,25 @@ +package xsync + +import ( + "golang.org/x/xerrors" +) + +// Go allows running a function in another goroutine +// and waiting for its error. +func Go(fn func() error) chan error { + errs := make(chan error, 1) + go func() { + defer func() { + r := recover() + if r != nil { + select { + case errs <- xerrors.Errorf("panic in go fn: %v", r): + default: + } + } + }() + errs <- fn() + }() + + return errs +} diff --git a/internal/xsync/go_test.go b/internal/xsync/go_test.go new file mode 100644 index 0000000..c0613e6 --- /dev/null +++ b/internal/xsync/go_test.go @@ -0,0 +1,20 @@ +package xsync + +import ( + "testing" + + "nhooyr.io/websocket/internal/test/cmp" +) + +func TestGoRecover(t *testing.T) { + t.Parallel() + + errs := Go(func() error { + panic("anmol") + }) + + err := <-errs + if !cmp.ErrorContains(err, "anmol") { + t.Fatalf("unexpected err: %v", err) + } +} diff --git a/internal/xsync/int64.go b/internal/xsync/int64.go new file mode 100644 index 0000000..a0c4020 --- /dev/null +++ b/internal/xsync/int64.go @@ -0,0 +1,23 @@ +package xsync + +import ( + "sync/atomic" +) + +// Int64 represents an atomic int64. +type Int64 struct { + // We do not use atomic.Load/StoreInt64 since it does not + // work on 32 bit computers but we need 64 bit integers. + i atomic.Value +} + +// Load loads the int64. +func (v *Int64) Load() int64 { + i, _ := v.i.Load().(int64) + return i +} + +// Store stores the int64. +func (v *Int64) Store(i int64) { + v.i.Store(i) +} diff --git a/read.go b/read.go index b681a94..e723ef3 100644 --- a/read.go +++ b/read.go @@ -7,12 +7,12 @@ import ( "io" "io/ioutil" "strings" - "sync/atomic" "time" "golang.org/x/xerrors" "nhooyr.io/websocket/internal/errd" + "nhooyr.io/websocket/internal/xsync" ) // Reader reads from the connection until until there is a WebSocket @@ -415,7 +415,7 @@ func (mr *msgReader) read(p []byte) (int, error) { type limitReader struct { c *Conn r io.Reader - limit atomicInt64 + limit xsync.Int64 n int64 } @@ -448,21 +448,6 @@ func (lr *limitReader) Read(p []byte) (int, error) { return n, err } -type atomicInt64 struct { - // We do not use atomic.Load/StoreInt64 since it does not - // work on 32 bit computers but we need 64 bit integers. - i atomic.Value -} - -func (v *atomicInt64) Load() int64 { - i, _ := v.i.Load().(int64) - return i -} - -func (v *atomicInt64) Store(i int64) { - v.i.Store(i) -} - type readerFunc func(p []byte) (int, error) func (f readerFunc) Read(p []byte) (int, error) { diff --git a/ws_js.go b/ws_js.go index 3ce6f34..de76afa 100644 --- a/ws_js.go +++ b/ws_js.go @@ -4,6 +4,7 @@ import ( "bytes" "context" "io" + "net/http" "reflect" "runtime" "sync" @@ -13,6 +14,7 @@ import ( "nhooyr.io/websocket/internal/bpool" "nhooyr.io/websocket/internal/wsjs" + "nhooyr.io/websocket/internal/xsync" ) // MessageType represents the type of a WebSocket message. @@ -32,10 +34,10 @@ type Conn struct { ws wsjs.WebSocket // read limit for a message in bytes. - msgReadLimit atomicInt64 + msgReadLimit xsync.Int64 closingMu sync.Mutex - isReadClosed atomicInt64 + isReadClosed xsync.Int64 closeOnce sync.Once closed chan struct{} closeErrOnce sync.Once @@ -67,11 +69,8 @@ func (c *Conn) init() { c.closed = make(chan struct{}) c.readSignal = make(chan struct{}, 1) - c.msgReadLimit = &wssync.Int64{} c.msgReadLimit.Store(32768) - c.isReadClosed = &wssync.Int64{} - c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) { err := CloseError{ Code: StatusCode(e.Code), @@ -121,7 +120,7 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { return 0, nil, xerrors.Errorf("failed to read: %w", err) } if int64(len(p)) > c.msgReadLimit.Load() { - err := xerrors.Errorf("read limited at %v bytes", c.msgReadLimit) + err := xerrors.Errorf("read limited at %v bytes", c.msgReadLimit.Load()) c.Close(StatusMessageTooBig, err.Error()) return 0, nil, err } @@ -248,17 +247,17 @@ type DialOptions struct { // Dial creates a new WebSocket connection to the given url with the given options. // The passed context bounds the maximum time spent waiting for the connection to open. -// The returned *http.Response is always nil or the zero value. It's only in the signature +// The returned *http.Response is always nil or a mock. It's only in the signature // to match the core API. -func Dial(ctx context.Context, url string, opts *DialOptions) (*Conn, error) { - c, err := dial(ctx, url, opts) +func Dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) { + c, resp, err := dial(ctx, url, opts) if err != nil { - return nil, resp, xerrors.Errorf("failed to WebSocket dial %q: %w", url, err) + return nil, nil, xerrors.Errorf("failed to WebSocket dial %q: %w", url, err) } - return c, nil + return c, resp, nil } -func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, error) { +func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) { if opts == nil { opts = &DialOptions{} } @@ -284,11 +283,12 @@ func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, error) { c.Close(StatusPolicyViolation, "dial timed out") return nil, nil, ctx.Err() case <-opench: + return c, &http.Response{ + StatusCode: http.StatusSwitchingProtocols, + }, nil case <-c.closed: - return c, nil, c.closeErr + return nil, nil, c.closeErr } - - return c, nil } // Reader attempts to read a message from the connection. diff --git a/ws_js_test.go b/ws_js_test.go index 65309bf..8d49af6 100644 --- a/ws_js_test.go +++ b/ws_js_test.go @@ -1,4 +1,4 @@ -package websocket +package websocket_test import ( "context" @@ -6,25 +6,43 @@ import ( "os" "testing" "time" + + "nhooyr.io/websocket" + "nhooyr.io/websocket/internal/test/cmp" + "nhooyr.io/websocket/internal/test/wstest" ) -func TestEcho(t *testing.T) { +func TestWasm(t *testing.T) { t.Parallel() ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() - c, resp, err := Dial(ctx, os.Getenv("WS_ECHO_SERVER_URL"), &DialOptions{ + c, resp, err := websocket.Dial(ctx, os.Getenv("WS_ECHO_SERVER_URL"), &websocket.DialOptions{ Subprotocols: []string{"echo"}, }) - assert.Success(t, err) - defer c.Close(StatusInternalError, "") + if err != nil { + t.Fatal(err) + } + defer c.Close(websocket.StatusInternalError, "") + + if !cmp.Equal("echo", c.Subprotocol()) { + t.Fatalf("unexpected subprotocol: %v", cmp.Diff("echo", c.Subprotocol())) + } + if !cmp.Equal(http.StatusSwitchingProtocols, resp.StatusCode) { + t.Fatalf("unexpected status code: %v", cmp.Diff(http.StatusSwitchingProtocols, resp.StatusCode)) + } - assertSubprotocol(t, c, "echo") - assert.Equalf(t, &http.Response{}, resp, "http.Response") - echoJSON(t, ctx, c, 1024) - assertEcho(t, ctx, c, MessageBinary, 1024) + c.SetReadLimit(65536) + for i := 0; i < 10; i++ { + err = wstest.Echo(ctx, c, 65536) + if err != nil { + t.Fatal(err) + } + } - err = c.Close(StatusNormalClosure, "") - assert.Success(t, err) + err = c.Close(websocket.StatusNormalClosure, "") + if err != nil { + t.Fatal(err) + } } -- GitLab