From d0a80496108cf7cdd4e20c24e4689cd5934b5b89 Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Mon, 18 Nov 2019 22:52:18 -0500 Subject: [PATCH] Rewrite core Too many improvements and changes to list. Will include a detailed changelog for release. --- accept.go | 63 +- assert_test.go | 14 + autobahn_test.go | 252 ++ close.go | 158 +- close_test.go | 9 +- compress.go | 86 + conn.go | 1133 +------- conn_export_test.go | 129 - conn_test.go | 2382 +---------------- dial.go | 78 +- dial_test.go | 2 +- example_echo_test.go | 3 +- internal/wsframe/mask.go => frame.go | 162 +- .../wsframe/mask_test.go => frame_test.go | 108 +- internal/assert/assert.go | 40 +- internal/atomicint/atomicint.go | 32 - internal/bufpool/buf.go | 6 +- internal/bufpool/bufio.go | 40 - internal/errd/errd.go | 11 + internal/wsecho/wsecho.go | 55 - internal/wsframe/frame.go | 194 -- internal/wsframe/frame_stringer.go | 91 - internal/wsframe/frame_test.go | 157 -- internal/wsgrace/wsgrace.go | 50 - js_test.go | 50 - read.go | 479 ++++ reader.go | 31 - write.go | 348 +++ writer.go | 5 - ws_js.go | 12 +- wsjson/wsjson.go | 2 + 31 files changed, 1844 insertions(+), 4338 deletions(-) create mode 100644 autobahn_test.go delete mode 100644 conn_export_test.go rename internal/wsframe/mask.go => frame.go (57%) rename internal/wsframe/mask_test.go => frame_test.go (51%) delete mode 100644 internal/atomicint/atomicint.go delete mode 100644 internal/bufpool/bufio.go create mode 100644 internal/errd/errd.go delete mode 100644 internal/wsecho/wsecho.go delete mode 100644 internal/wsframe/frame.go delete mode 100644 internal/wsframe/frame_stringer.go delete mode 100644 internal/wsframe/frame_test.go delete mode 100644 internal/wsgrace/wsgrace.go delete mode 100644 js_test.go create mode 100644 read.go delete mode 100644 reader.go create mode 100644 write.go delete mode 100644 writer.go diff --git a/accept.go b/accept.go index 5ff2ea4..2028d4b 100644 --- a/accept.go +++ b/accept.go @@ -60,10 +60,15 @@ func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, return c, nil } -func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { +func (opts *AcceptOptions) ensure() *AcceptOptions { if opts == nil { - opts = &AcceptOptions{} + return &AcceptOptions{} } + return opts +} + +func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { + opts = opts.ensure() err := verifyClientRequest(w, r) if err != nil { @@ -114,31 +119,14 @@ func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, b, _ := brw.Reader.Peek(brw.Reader.Buffered()) brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn)) - c := &Conn{ + return newConn(connConfig{ subprotocol: w.Header().Get("Sec-WebSocket-Protocol"), + rwc: netConn, + client: false, + copts: copts, br: brw.Reader, bw: brw.Writer, - closer: netConn, - copts: copts, - } - c.init() - - return c, nil -} - -func authenticateOrigin(r *http.Request) error { - origin := r.Header.Get("Origin") - if origin == "" { - return nil - } - u, err := url.Parse(origin) - if err != nil { - return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) - } - if !strings.EqualFold(u.Host, r.Host) { - return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) - } - return nil + }), nil } func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { @@ -181,15 +169,37 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { return nil } +func authenticateOrigin(r *http.Request) error { + origin := r.Header.Get("Origin") + if origin == "" { + return nil + } + u, err := url.Parse(origin) + if err != nil { + return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) + } + if !strings.EqualFold(u.Host, r.Host) { + return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) + } + return nil +} + func handleSecWebSocketKey(w http.ResponseWriter, r *http.Request) { key := r.Header.Get("Sec-WebSocket-Key") w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) } func selectSubprotocol(r *http.Request, subprotocols []string) string { + cps := headerTokens(r.Header, "Sec-WebSocket-Protocol") + if len(cps) == 0 { + return "" + } + for _, sp := range subprotocols { - if headerContainsToken(r.Header, "Sec-WebSocket-Protocol", sp) { - return sp + for _, cp := range cps { + if strings.EqualFold(sp, cp) { + return cp + } } } return "" @@ -266,7 +276,6 @@ func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode Com return copts, nil } - func headerContainsToken(h http.Header, key, token string) bool { token = strings.ToLower(token) diff --git a/assert_test.go b/assert_test.go index af30099..0cc9dfe 100644 --- a/assert_test.go +++ b/assert_test.go @@ -23,6 +23,8 @@ func randBytes(n int) []byte { } func assertJSONEcho(t *testing.T, ctx context.Context, c *websocket.Conn, n int) { + t.Helper() + exp := randString(n) err := wsjson.Write(ctx, c, exp) assert.Success(t, err) @@ -35,6 +37,8 @@ func assertJSONEcho(t *testing.T, ctx context.Context, c *websocket.Conn, n int) } func assertJSONRead(t *testing.T, ctx context.Context, c *websocket.Conn, exp interface{}) { + t.Helper() + var act interface{} err := wsjson.Read(ctx, c, &act) assert.Success(t, err) @@ -56,6 +60,8 @@ func randString(n int) string { } func assertEcho(t *testing.T, ctx context.Context, c *websocket.Conn, typ websocket.MessageType, n int) { + t.Helper() + p := randBytes(n) err := c.Write(ctx, typ, p) assert.Success(t, err) @@ -68,5 +74,13 @@ func assertEcho(t *testing.T, ctx context.Context, c *websocket.Conn, typ websoc } func assertSubprotocol(t *testing.T, c *websocket.Conn, exp string) { + t.Helper() + assert.Equalf(t, exp, c.Subprotocol(), "unexpected subprotocol") } + +func assertCloseStatus(t *testing.T, exp websocket.StatusCode, err error) { + t.Helper() + + assert.Equalf(t, exp, websocket.CloseStatus(err), "unexpected status code") +} diff --git a/autobahn_test.go b/autobahn_test.go new file mode 100644 index 0000000..27f8a1b --- /dev/null +++ b/autobahn_test.go @@ -0,0 +1,252 @@ +package websocket_test + +import ( + "context" + "encoding/json" + "fmt" + "io/ioutil" + "net" + "net/http" + "net/http/httptest" + "nhooyr.io/websocket" + "os" + "os/exec" + "strconv" + "strings" + "testing" + "time" +) + +func TestAutobahn(t *testing.T) { + // This test contains the old autobahn test suite tests that use the + // python binary. The approach is clunky and slow so new tests + // have been written in pure Go in websocket_test.go. + // These have been kept for correctness purposes and are occasionally ran. + if os.Getenv("AUTOBAHN") == "" { + t.Skip("Set $AUTOBAHN to run tests against the autobahn test suite") + } + + t.Run("server", testServerAutobahnPython) + t.Run("client", testClientAutobahnPython) +} + +// https://github.com/crossbario/autobahn-python/tree/master/wstest +func testServerAutobahnPython(t *testing.T) { + t.Parallel() + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"echo"}, + }) + if err != nil { + t.Logf("server handshake failed: %+v", err) + return + } + echoLoop(r.Context(), c) + })) + defer s.Close() + + spec := map[string]interface{}{ + "outdir": "ci/out/wstestServerReports", + "servers": []interface{}{ + map[string]interface{}{ + "agent": "main", + "url": strings.Replace(s.URL, "http", "ws", 1), + }, + }, + "cases": []string{"*"}, + // We skip the UTF-8 handling tests as there isn't any reason to reject invalid UTF-8, just + // more performance overhead. 7.5.1 is the same. + "exclude-cases": []string{"6.*", "7.5.1"}, + } + specFile, err := ioutil.TempFile("", "websocketFuzzingClient.json") + if err != nil { + t.Fatalf("failed to create temp file for fuzzingclient.json: %v", err) + } + defer specFile.Close() + + e := json.NewEncoder(specFile) + e.SetIndent("", "\t") + err = e.Encode(spec) + if err != nil { + t.Fatalf("failed to write spec: %v", err) + } + + err = specFile.Close() + if err != nil { + t.Fatalf("failed to close file: %v", err) + } + + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Minute*10) + defer cancel() + + args := []string{"--mode", "fuzzingclient", "--spec", specFile.Name()} + wstest := exec.CommandContext(ctx, "wstest", args...) + out, err := wstest.CombinedOutput() + if err != nil { + t.Fatalf("failed to run wstest: %v\nout:\n%s", err, out) + } + + checkWSTestIndex(t, "./ci/out/wstestServerReports/index.json") +} + +func unusedListenAddr() (string, error) { + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + return "", err + } + l.Close() + return l.Addr().String(), nil +} + +// https://github.com/crossbario/autobahn-python/blob/master/wstest/testee_client_aio.py +func testClientAutobahnPython(t *testing.T) { + t.Parallel() + + if os.Getenv("AUTOBAHN_PYTHON") == "" { + t.Skip("Set $AUTOBAHN_PYTHON to test against the python autobahn test suite") + } + + serverAddr, err := unusedListenAddr() + if err != nil { + t.Fatalf("failed to get unused listen addr for wstest: %v", err) + } + + wsServerURL := "ws://" + serverAddr + + spec := map[string]interface{}{ + "url": wsServerURL, + "outdir": "ci/out/wstestClientReports", + "cases": []string{"*"}, + // See TestAutobahnServer for the reasons why we exclude these. + "exclude-cases": []string{"6.*", "7.5.1"}, + } + specFile, err := ioutil.TempFile("", "websocketFuzzingServer.json") + if err != nil { + t.Fatalf("failed to create temp file for fuzzingserver.json: %v", err) + } + defer specFile.Close() + + e := json.NewEncoder(specFile) + e.SetIndent("", "\t") + err = e.Encode(spec) + if err != nil { + t.Fatalf("failed to write spec: %v", err) + } + + err = specFile.Close() + if err != nil { + t.Fatalf("failed to close file: %v", err) + } + + ctx := context.Background() + ctx, cancel := context.WithTimeout(ctx, time.Minute*10) + defer cancel() + + args := []string{"--mode", "fuzzingserver", "--spec", specFile.Name(), + // Disables some server that runs as part of fuzzingserver mode. + // See https://github.com/crossbario/autobahn-testsuite/blob/058db3a36b7c3a1edf68c282307c6b899ca4857f/autobahntestsuite/autobahntestsuite/wstest.py#L124 + "--webport=0", + } + wstest := exec.CommandContext(ctx, "wstest", args...) + err = wstest.Start() + if err != nil { + t.Fatal(err) + } + defer func() { + err := wstest.Process.Kill() + if err != nil { + t.Error(err) + } + }() + + // Let it come up. + time.Sleep(time.Second * 5) + + var cases int + func() { + c, _, err := websocket.Dial(ctx, wsServerURL+"/getCaseCount", nil) + if err != nil { + t.Fatal(err) + } + defer c.Close(websocket.StatusInternalError, "") + + _, r, err := c.Reader(ctx) + if err != nil { + t.Fatal(err) + } + b, err := ioutil.ReadAll(r) + if err != nil { + t.Fatal(err) + } + cases, err = strconv.Atoi(string(b)) + if err != nil { + t.Fatal(err) + } + + c.Close(websocket.StatusNormalClosure, "") + }() + + for i := 1; i <= cases; i++ { + func() { + ctx, cancel := context.WithTimeout(ctx, time.Second*45) + defer cancel() + + c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/runCase?case=%v&agent=main", i), nil) + if err != nil { + t.Fatal(err) + } + echoLoop(ctx, c) + }() + } + + c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/updateReports?agent=main"), nil) + if err != nil { + t.Fatal(err) + } + c.Close(websocket.StatusNormalClosure, "") + + checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json") +} + +func checkWSTestIndex(t *testing.T, path string) { + wstestOut, err := ioutil.ReadFile(path) + if err != nil { + t.Fatalf("failed to read index.json: %v", err) + } + + var indexJSON map[string]map[string]struct { + Behavior string `json:"behavior"` + BehaviorClose string `json:"behaviorClose"` + } + err = json.Unmarshal(wstestOut, &indexJSON) + if err != nil { + t.Fatalf("failed to unmarshal index.json: %v", err) + } + + var failed bool + for _, tests := range indexJSON { + for test, result := range tests { + switch result.Behavior { + case "OK", "NON-STRICT", "INFORMATIONAL": + default: + failed = true + t.Errorf("test %v failed", test) + } + switch result.BehaviorClose { + case "OK", "INFORMATIONAL": + default: + failed = true + t.Errorf("bad close behaviour for test %v", test) + } + } + } + + if failed { + path = strings.Replace(path, ".json", ".html", 1) + if os.Getenv("CI") == "" { + t.Errorf("wstest found failure, see %q (output as an artifact in CI)", path) + } + } +} diff --git a/close.go b/close.go index 4f48f1b..b1bc50e 100644 --- a/close.go +++ b/close.go @@ -5,7 +5,9 @@ import ( "encoding/binary" "errors" "fmt" - "nhooyr.io/websocket/internal/wsframe" + "log" + "nhooyr.io/websocket/internal/bufpool" + "time" ) // StatusCode represents a WebSocket status code. @@ -74,6 +76,87 @@ func CloseStatus(err error) StatusCode { return -1 } +// Close closes the WebSocket connection with the given status code and reason. +// +// It will write a WebSocket close frame with a timeout of 5s and then wait 5s for +// the peer to send a close frame. +// Thus, it implements the full WebSocket close handshake. +// All data messages received from the peer during the close handshake +// will be discarded. +// +// The connection can only be closed once. Additional calls to Close +// are no-ops. +// +// The maximum length of reason must be 125 bytes otherwise an internal +// error will be sent to the peer. For this reason, you should avoid +// sending a dynamic reason. +// +// Close will unblock all goroutines interacting with the connection once +// complete. +func (c *Conn) Close(code StatusCode, reason string) error { + err := c.closeHandshake(code, reason) + if err != nil { + return fmt.Errorf("failed to close websocket: %w", err) + } + return nil +} + +func (c *Conn) closeHandshake(code StatusCode, reason string) error { + err := c.cw.sendClose(code, reason) + if err != nil { + return err + } + + return c.cr.waitClose() +} + +func (cw *connWriter) error(code StatusCode, err error) { + cw.c.setCloseErr(err) + cw.sendClose(code, err.Error()) + cw.c.close(nil) +} + +func (cw *connWriter) sendClose(code StatusCode, reason string) error { + ce := CloseError{ + Code: code, + Reason: reason, + } + + cw.c.setCloseErr(fmt.Errorf("sent close frame: %w", ce)) + + var p []byte + if ce.Code != StatusNoStatusRcvd { + p = ce.bytes() + } + + return cw.control(context.Background(), opClose, p) +} + +func (cr *connReader) waitClose() error { + defer cr.c.close(nil) + + return nil + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + err := cr.mu.Lock(ctx) + if err != nil { + return err + } + defer cr.mu.Unlock() + + b := bufpool.Get() + buf := b.Bytes() + buf = buf[:cap(buf)] + defer bufpool.Put(b) + + for { + // TODO + return nil + } +} + func parseClosePayload(p []byte) (CloseError, error) { if len(p) == 0 { return CloseError{ @@ -81,14 +164,13 @@ func parseClosePayload(p []byte) (CloseError, error) { }, nil } - code, reason, err := wsframe.ParseClosePayload(p) - if err != nil { - return CloseError{}, err + if len(p) < 2 { + return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) } ce := CloseError{ - Code: StatusCode(code), - Reason: reason, + Code: StatusCode(binary.BigEndian.Uint16(p)), + Reason: string(p[2:]), } if !validWireCloseCode(ce.Code) { @@ -116,11 +198,25 @@ func validWireCloseCode(code StatusCode) bool { return false } -func (ce CloseError) bytes() ([]byte, error) { - // TODO move check into frame write - if len(ce.Reason) > wsframe.MaxControlFramePayload-2 { - return nil, fmt.Errorf("reason string max is %v but got %q with length %v", wsframe.MaxControlFramePayload-2, ce.Reason, len(ce.Reason)) +func (ce CloseError) bytes() []byte { + p, err := ce.bytesErr() + if err != nil { + log.Printf("websocket: failed to marshal close frame: %+v", err) + ce = CloseError{ + Code: StatusInternalError, + } + p, _ = ce.bytesErr() } + return p +} + +const maxCloseReason = maxControlPayload - 2 + +func (ce CloseError) bytesErr() ([]byte, error) { + if len(ce.Reason) > maxCloseReason { + return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason)) + } + if !validWireCloseCode(ce.Code) { return nil, fmt.Errorf("status code %v cannot be set", ce.Code) } @@ -131,44 +227,16 @@ func (ce CloseError) bytes() ([]byte, error) { return buf, nil } -// CloseRead will start a goroutine to read from the connection until it is closed or a data message -// is received. If a data message is received, the connection will be closed with StatusPolicyViolation. -// Since CloseRead reads from the connection, it will respond to ping, pong and close frames. -// After calling this method, you cannot read any data messages from the connection. -// The returned context will be cancelled when the connection is closed. -// -// Use this when you do not want to read data messages from the connection anymore but will -// want to write messages to it. -func (c *Conn) CloseRead(ctx context.Context) context.Context { - c.isReadClosed.Store(1) - - ctx, cancel := context.WithCancel(ctx) - go func() { - defer cancel() - // We use the unexported reader method so that we don't get the read closed error. - c.reader(ctx, true) - // Either the connection is already closed since there was a read error - // or the context was cancelled or a message was read and we should close - // the connection. - c.Close(StatusPolicyViolation, "unexpected data message") - }() - return ctx -} - -// SetReadLimit sets the max number of bytes to read for a single message. -// It applies to the Reader and Read methods. -// -// By default, the connection has a message read limit of 32768 bytes. -// -// When the limit is hit, the connection will be closed with StatusMessageTooBig. -func (c *Conn) SetReadLimit(n int64) { - c.msgReadLimit.Store(n) +func (c *Conn) setCloseErr(err error) { + c.closeMu.Lock() + c.setCloseErrNoLock(err) + c.closeMu.Unlock() } -func (c *Conn) setCloseErr(err error) { - c.closeErrOnce.Do(func() { +func (c *Conn) setCloseErrNoLock(err error) { + if c.closeErr == nil { c.closeErr = fmt.Errorf("websocket closed: %w", err) - }) + } } func (c *Conn) isClosed() bool { diff --git a/close_test.go b/close_test.go index 78096d7..ee10cd3 100644 --- a/close_test.go +++ b/close_test.go @@ -5,7 +5,6 @@ import ( "io" "math" "nhooyr.io/websocket/internal/assert" - "nhooyr.io/websocket/internal/wsframe" "strings" "testing" ) @@ -22,7 +21,7 @@ func TestCloseError(t *testing.T) { name: "normal", ce: CloseError{ Code: StatusNormalClosure, - Reason: strings.Repeat("x", wsframe.MaxControlFramePayload-2), + Reason: strings.Repeat("x", maxCloseReason), }, success: true, }, @@ -30,7 +29,7 @@ func TestCloseError(t *testing.T) { name: "bigReason", ce: CloseError{ Code: StatusNormalClosure, - Reason: strings.Repeat("x", wsframe.MaxControlFramePayload-1), + Reason: strings.Repeat("x", maxCloseReason+1), }, success: false, }, @@ -38,7 +37,7 @@ func TestCloseError(t *testing.T) { name: "bigCode", ce: CloseError{ Code: math.MaxUint16, - Reason: strings.Repeat("x", wsframe.MaxControlFramePayload-2), + Reason: strings.Repeat("x", maxCloseReason), }, success: false, }, @@ -49,7 +48,7 @@ func TestCloseError(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - _, err := tc.ce.bytes() + _, err := tc.ce.bytesErr() if (err == nil) != tc.success { t.Fatalf("unexpected error value: %+v", err) } diff --git a/compress.go b/compress.go index 5b5fdce..9e07543 100644 --- a/compress.go +++ b/compress.go @@ -3,7 +3,10 @@ package websocket import ( + "compress/flate" + "io" "net/http" + "sync" ) // CompressionMode controls the modes available RFC 7692's deflate extension. @@ -76,3 +79,86 @@ func (copts *compressionOptions) setHeader(h http.Header) { // we need to add them back otherwise flate.Reader keeps // trying to return more bytes. const deflateMessageTail = "\x00\x00\xff\xff" + +func (c *Conn) writeNoContextTakeOver() bool { + return c.client && c.copts.clientNoContextTakeover || !c.client && c.copts.serverNoContextTakeover +} + +func (c *Conn) readNoContextTakeOver() bool { + return !c.client && c.copts.clientNoContextTakeover || c.client && c.copts.serverNoContextTakeover +} + +type trimLastFourBytesWriter struct { + w io.Writer + tail []byte +} + +func (tw *trimLastFourBytesWriter) reset() { + tw.tail = tw.tail[:0] +} + +func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) { + extra := len(tw.tail) + len(p) - 4 + + if extra <= 0 { + tw.tail = append(tw.tail, p...) + return len(p), nil + } + + // Now we need to write as many extra bytes as we can from the previous tail. + if extra > len(tw.tail) { + extra = len(tw.tail) + } + if extra > 0 { + _, err := tw.w.Write(tw.tail[:extra]) + if err != nil { + return 0, err + } + tw.tail = tw.tail[extra:] + } + + // If p is less than or equal to 4 bytes, + // all of it is is part of the tail. + if len(p) <= 4 { + tw.tail = append(tw.tail, p...) + return len(p), nil + } + + // Otherwise, only the last 4 bytes are. + tw.tail = append(tw.tail, p[len(p)-4:]...) + + p = p[:len(p)-4] + n, err := tw.w.Write(p) + return n + 4, err +} + +var flateReaderPool sync.Pool + +func getFlateReader(r io.Reader) io.Reader { + fr, ok := flateReaderPool.Get().(io.Reader) + if !ok { + return flate.NewReader(r) + } + fr.(flate.Resetter).Reset(r, nil) + return fr +} + +func putFlateReader(fr io.Reader) { + flateReaderPool.Put(fr) +} + +var flateWriterPool sync.Pool + +func getFlateWriter(w io.Writer) *flate.Writer { + fw, ok := flateWriterPool.Get().(*flate.Writer) + if !ok { + fw, _ = flate.NewWriter(w, flate.BestSpeed) + return fw + } + fw.Reset(w) + return fw +} + +func putFlateWriter(w *flate.Writer) { + flateWriterPool.Put(w) +} diff --git a/conn.go b/conn.go index 791d9b4..e3f2417 100644 --- a/conn.go +++ b/conn.go @@ -4,25 +4,14 @@ package websocket import ( "bufio" - "compress/flate" "context" - "crypto/rand" - "encoding/binary" "errors" "fmt" "io" - "io/ioutil" - "log" - "nhooyr.io/websocket/internal/atomicint" - "nhooyr.io/websocket/internal/wsframe" "runtime" "strconv" - "strings" "sync" "sync/atomic" - "time" - - "nhooyr.io/websocket/internal/bufpool" ) // MessageType represents the type of a WebSocket message. @@ -51,91 +40,54 @@ const ( // This applies to the Read methods in the wsjson/wspb subpackages as well. type Conn struct { subprotocol string - fw *flate.Writer - bw *bufio.Writer - // writeBuf is used for masking, its the buffer in bufio.Writer. - // Only used by the client for masking the bytes in the buffer. - writeBuf []byte - closer io.Closer - client bool - copts *compressionOptions - - closeOnce sync.Once - closeErrOnce sync.Once - closeErr error - closed chan struct{} - closing *atomicint.Int64 - closeReceived error + rwc io.ReadWriteCloser + client bool + copts *compressionOptions - // messageWriter state. - // writeMsgLock is acquired to write a data message. - writeMsgLock chan struct{} - // writeFrameLock is acquired to write a single frame. - // Effectively meaning whoever holds it gets to write to bw. - writeFrameLock chan struct{} - writeHeaderBuf []byte - writeHeader *header - // read limit for a message in bytes. - msgReadLimit *atomicint.Int64 + cr connReader + cw connWriter - // Used to ensure a previous writer is not used after being closed. - activeWriter atomic.Value - // messageWriter state. - writeMsgOpcode opcode - writeMsgCtx context.Context + closed chan struct{} - setReadTimeout chan context.Context - setWriteTimeout chan context.Context + closeMu sync.Mutex + closeErr error + closeHandshakeErr error - pingCounter *atomicint.Int64 + pingCounter int32 activePingsMu sync.Mutex activePings map[string]chan<- struct{} - - logf func(format string, v ...interface{}) } -func (c *Conn) init() { - c.closed = make(chan struct{}) - c.closing = &atomicint.Int64{} - - c.msgReadLimit = &atomicint.Int64{} - c.msgReadLimit.Store(32768) +type connConfig struct { + subprotocol string + rwc io.ReadWriteCloser + client bool + copts *compressionOptions - c.writeMsgLock = make(chan struct{}, 1) - c.writeFrameLock = make(chan struct{}, 1) + bw *bufio.Writer + br *bufio.Reader +} - c.readFrameLock = make(chan struct{}, 1) - c.readLock = make(chan struct{}, 1) - c.payloadReader = framePayloadReader{c} +func newConn(cfg connConfig) *Conn { + c := &Conn{} + c.subprotocol = cfg.subprotocol + c.rwc = cfg.rwc + c.client = cfg.client + c.copts = cfg.copts - c.setReadTimeout = make(chan context.Context) - c.setWriteTimeout = make(chan context.Context) + c.cr.init(c, cfg.br) + c.cw.init(c, cfg.bw) - c.pingCounter = &atomicint.Int64{} + c.closed = make(chan struct{}) c.activePings = make(map[string]chan<- struct{}) - c.writeHeaderBuf = makeWriteHeaderBuf() - c.writeHeader = &header{} - c.readHeaderBuf = makeReadHeaderBuf() - c.isReadClosed = &atomicint.Int64{} - c.controlPayloadBuf = make([]byte, maxControlFramePayload) - runtime.SetFinalizer(c, func(c *Conn) { c.close(errors.New("connection garbage collected")) }) - c.logf = log.Printf - - if c.copts != nil { - if !c.readNoContextTakeOver() { - c.fr = getFlateReader(c.payloadReader) - } - if !c.writeNoContextTakeOver() { - c.fw = getFlateWriter(c.bw) - } - } - go c.timeoutLoop() + + return c } // Subprotocol returns the negotiated subprotocol. @@ -145,38 +97,25 @@ func (c *Conn) Subprotocol() string { } func (c *Conn) close(err error) { - c.closeOnce.Do(func() { - runtime.SetFinalizer(c, nil) + c.closeMu.Lock() + defer c.closeMu.Unlock() - c.setCloseErr(err) - close(c.closed) - - // Have to close after c.closed is closed to ensure any goroutine that wakes up - // from the connection being closed also sees that c.closed is closed and returns - // closeErr. - c.closer.Close() + if c.isClosed() { + return + } + close(c.closed) + runtime.SetFinalizer(c, nil) + c.setCloseErrNoLock(err) - // By acquiring the locks, we ensure no goroutine will touch the bufio reader or writer - // and we can safely return them. - // Whenever a caller holds this lock and calls close, it ensures to release the lock to prevent - // a deadlock. - // As of now, this is in writeFrame, readFramePayload and readHeader. - c.readFrameLock <- struct{}{} - if c.client { - returnBufioReader(c.br) - } - if c.fr != nil { - putFlateReader(c.fr) - } + // Have to close after c.closed is closed to ensure any goroutine that wakes up + // from the connection being closed also sees that c.closed is closed and returns + // closeErr. + c.rwc.Close() - c.writeFrameLock <- struct{}{} - if c.client { - returnBufioWriter(c.bw) - } - if c.fw != nil { - putFlateWriter(c.fw) - } - }) + go func() { + c.cr.close() + c.cw.close() + }() } func (c *Conn) timeoutLoop() { @@ -188,20 +127,13 @@ func (c *Conn) timeoutLoop() { case <-c.closed: return - case writeCtx = <-c.setWriteTimeout: - case readCtx = <-c.setReadTimeout: + case writeCtx = <-c.cw.timeout: + case readCtx = <-c.cr.timeout: case <-readCtx.Done(): c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err())) - // Guaranteed to eventually close the connection since we can only ever send - // one close frame. - go func() { - c.exportedClose(StatusPolicyViolation, "read timed out", true) - // Ensure the connection closes, i.e if we already sent a close frame and timed out - // to read the peer's close frame. - c.close(nil) - }() - readCtx = context.Background() + c.cw.error(StatusPolicyViolation, errors.New("timed out")) + return case <-writeCtx.Done(): c.close(fmt.Errorf("write timed out: %w", writeCtx.Err())) return @@ -209,843 +141,8 @@ func (c *Conn) timeoutLoop() { } } -func (c *Conn) acquireLock(ctx context.Context, lock chan struct{}) error { - select { - case <-ctx.Done(): - var err error - switch lock { - case c.writeFrameLock, c.writeMsgLock: - err = fmt.Errorf("could not acquire write lock: %v", ctx.Err()) - case c.readFrameLock, c.readLock: - err = fmt.Errorf("could not acquire read lock: %v", ctx.Err()) - default: - panic(fmt.Sprintf("websocket: failed to acquire unknown lock: %v", ctx.Err())) - } - c.close(err) - return ctx.Err() - case <-c.closed: - return c.closeErr - case lock <- struct{}{}: - return nil - } -} - -func (c *Conn) releaseLock(lock chan struct{}) { - // Allow multiple releases. - select { - case <-lock: - default: - } -} - -func (c *Conn) readTillMsg(ctx context.Context) (header, error) { - for { - h, err := c.readFrameHeader(ctx) - if err != nil { - return header{}, err - } - - if (h.rsv1 && (c.copts == nil || h.opcode.controlOp() || h.opcode == opContinuation)) || h.rsv2 || h.rsv3 { - err := fmt.Errorf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3) - c.exportedClose(StatusProtocolError, err.Error(), false) - return header{}, err - } - - if h.opcode.controlOp() { - err = c.handleControl(ctx, h) - if err != nil { - // Pass through CloseErrors when receiving a close frame. - if h.opcode == opClose && CloseStatus(err) != -1 { - return header{}, err - } - return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err) - } - continue - } - - switch h.opcode { - case opBinary, opText, opContinuation: - return h, nil - default: - err := fmt.Errorf("received unknown opcode %v", h.opcode) - c.exportedClose(StatusProtocolError, err.Error(), false) - return header{}, err - } - } -} - -func (c *Conn) readFrameHeader(ctx context.Context) (_ header, err error) { - wrap := func(err error) error { - return fmt.Errorf("failed to read frame header: %w", err) - } - defer func() { - if err != nil { - err = wrap(err) - } - }() - - err = c.acquireLock(ctx, c.readFrameLock) - if err != nil { - return header{}, err - } - defer c.releaseLock(c.readFrameLock) - - select { - case <-c.closed: - return header{}, c.closeErr - case c.setReadTimeout <- ctx: - } - - h, err := readHeader(c.readHeaderBuf, c.br) - if err != nil { - select { - case <-c.closed: - return header{}, c.closeErr - case <-ctx.Done(): - err = ctx.Err() - default: - } - c.releaseLock(c.readFrameLock) - c.close(wrap(err)) - return header{}, err - } - - select { - case <-c.closed: - return header{}, c.closeErr - case c.setReadTimeout <- context.Background(): - } - - return h, nil -} - -func (c *Conn) handleControl(ctx context.Context, h header) error { - if h.payloadLength > maxControlFramePayload { - err := fmt.Errorf("received too big control frame at %v bytes", h.payloadLength) - c.exportedClose(StatusProtocolError, err.Error(), false) - return err - } - - if !h.fin { - err := errors.New("received fragmented control frame") - c.exportedClose(StatusProtocolError, err.Error(), false) - return err - } - - ctx, cancel := context.WithTimeout(ctx, time.Second*5) - defer cancel() - - b := c.controlPayloadBuf[:h.payloadLength] - _, err := c.readFramePayload(ctx, b) - if err != nil { - return err - } - - if h.masked { - mask(h.maskKey, b) - } - - switch h.opcode { - case opPing: - return c.writeControl(ctx, opPong, b) - case opPong: - c.activePingsMu.Lock() - pong, ok := c.activePings[string(b)] - c.activePingsMu.Unlock() - if ok { - close(pong) - } - return nil - case opClose: - ce, err := parseClosePayload(b) - if err != nil { - err = fmt.Errorf("received invalid close payload: %w", err) - c.exportedClose(StatusProtocolError, err.Error(), false) - c.closeReceived = err - return err - } - - err = fmt.Errorf("received close: %w", ce) - c.closeReceived = err - c.writeClose(b, err, false) - - if ctx.Err() != nil { - // The above close probably has been returned by the peer in response - // to our read timing out so we have to return the read timed out error instead. - return fmt.Errorf("read timed out: %w", ctx.Err()) - } - - return err - default: - panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h)) - } -} - -// Reader waits until there is a WebSocket data message to read -// from the connection. -// It returns the type of the message and a reader to read it. -// The passed context will also bound the reader. -// Ensure you read to EOF otherwise the connection will hang. -// -// All returned errors will cause the connection -// to be closed so you do not need to write your own error message. -// This applies to the Read methods in the wsjson/wspb subpackages as well. -// -// You must read from the connection for control frames to be handled. -// Thus if you expect messages to take a long time to be responded to, -// you should handle such messages async to reading from the connection -// to ensure control frames are promptly handled. -// -// If you do not expect any data messages from the peer, call CloseRead. -// -// Only one Reader may be open at a time. -// -// If you need a separate timeout on the Reader call and then the message -// Read, use time.AfterFunc to cancel the context passed in early. -// See https://github.com/nhooyr/websocket/issues/87#issue-451703332 -// Most users should not need this. -func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { - if c.isReadClosed.Load() == 1 { - return 0, nil, errors.New("websocket connection read closed") - } - - typ, r, err := c.reader(ctx, true) - if err != nil { - return 0, nil, fmt.Errorf("failed to get reader: %w", err) - } - return typ, r, nil -} - -func (c *Conn) reader(ctx context.Context, lock bool) (MessageType, io.Reader, error) { - if lock { - err := c.acquireLock(ctx, c.readLock) - if err != nil { - return 0, nil, err - } - defer c.releaseLock(c.readLock) - } - - if c.activeReader != nil && !c.readerFrameEOF { - // The only way we know for sure the previous reader is not yet complete is - // if there is an active frame not yet fully read. - // Otherwise, a user may have read the last byte but not the EOF if the EOF - // is in the next frame so we check for that below. - return 0, nil, errors.New("previous message not read to completion") - } - - h, err := c.readTillMsg(ctx) - if err != nil { - return 0, nil, err - } - - if c.activeReader != nil && !c.activeReader.eof() { - if h.opcode != opContinuation { - err := errors.New("received new data message without finishing the previous message") - c.exportedClose(StatusProtocolError, err.Error(), false) - return 0, nil, err - } - - if !h.fin || h.payloadLength > 0 { - return 0, nil, fmt.Errorf("previous message not read to completion") - } - - c.activeReader = nil - - h, err = c.readTillMsg(ctx) - if err != nil { - return 0, nil, err - } - } else if h.opcode == opContinuation { - err := errors.New("received continuation frame not after data or text frame") - c.exportedClose(StatusProtocolError, err.Error(), false) - return 0, nil, err - } - - c.readerMsgCtx = ctx - c.readerMsgHeader = h - - c.readerPayloadCompressed = h.rsv1 - - if c.readerPayloadCompressed { - c.readerCompressTail.Reset(deflateMessageTail) - } - - c.readerFrameEOF = false - c.readerMaskKey = h.maskKey - c.readMsgLeft = c.msgReadLimit.Load() - - r := &messageReader{ - c: c, - } - c.activeReader = r - if c.readerPayloadCompressed && c.readNoContextTakeOver() { - c.fr = getFlateReader(c.payloadReader) - } - return MessageType(h.opcode), r, nil -} - -type framePayloadReader struct { - c *Conn -} - -func (r framePayloadReader) Read(p []byte) (int, error) { - if r.c.readerFrameEOF { - if r.c.readerPayloadCompressed && r.c.readerMsgHeader.fin { - n, _ := r.c.readerCompressTail.Read(p) - return n, nil - } - - h, err := r.c.readTillMsg(r.c.readerMsgCtx) - if err != nil { - return 0, err - } - - if h.opcode != opContinuation { - err := errors.New("received new data message without finishing the previous message") - r.c.exportedClose(StatusProtocolError, err.Error(), false) - return 0, err - } - - r.c.readerMsgHeader = h - r.c.readerFrameEOF = false - r.c.readerMaskKey = h.maskKey - } - - h := r.c.readerMsgHeader - if int64(len(p)) > h.payloadLength { - p = p[:h.payloadLength] - } - - n, err := r.c.readFramePayload(r.c.readerMsgCtx, p) - - h.payloadLength -= int64(n) - if h.masked { - r.c.readerMaskKey = mask(r.c.readerMaskKey, p) - } - r.c.readerMsgHeader = h - - if err != nil { - return n, err - } - - if h.payloadLength == 0 { - r.c.readerFrameEOF = true - - if h.fin && !r.c.readerPayloadCompressed { - return n, io.EOF - } - } - - return n, nil -} - -// messageReader enables reading a data frame from the WebSocket connection. -type messageReader struct { - c *Conn -} - -func (r *messageReader) eof() bool { - return r.c.activeReader != r -} - -// Read reads as many bytes as possible into p. -func (r *messageReader) Read(p []byte) (int, error) { - return r.exportedRead(p, true) -} - -func (r *messageReader) exportedRead(p []byte, lock bool) (int, error) { - n, err := r.read(p, lock) - if err != nil { - // Have to return io.EOF directly for now, we cannot wrap as errors.Is - // isn't used widely yet. - if errors.Is(err, io.EOF) { - return n, io.EOF - } - return n, fmt.Errorf("failed to read: %w", err) - } - return n, nil -} - -func (r *messageReader) readUnlocked(p []byte) (int, error) { - return r.exportedRead(p, false) -} - -func (r *messageReader) read(p []byte, lock bool) (int, error) { - if lock { - // If we cannot acquire the read lock, then - // there is either a concurrent read or the close handshake - // is proceeding. - select { - case r.c.readLock <- struct{}{}: - defer r.c.releaseLock(r.c.readLock) - default: - if r.c.closing.Load() == 1 { - <-r.c.closed - return 0, r.c.closeErr - } - return 0, errors.New("concurrent read detected") - } - } - - if r.eof() { - return 0, errors.New("cannot use EOFed reader") - } - - if r.c.readMsgLeft <= 0 { - err := fmt.Errorf("read limited at %v bytes", r.c.msgReadLimit) - r.c.exportedClose(StatusMessageTooBig, err.Error(), false) - return 0, err - } - - if int64(len(p)) > r.c.readMsgLeft { - p = p[:r.c.readMsgLeft] - } - - pr := io.Reader(r.c.payloadReader) - if r.c.readerPayloadCompressed { - pr = r.c.fr - } - - n, err := pr.Read(p) - - r.c.readMsgLeft -= int64(n) - - if r.c.readerFrameEOF && r.c.readerMsgHeader.fin { - if r.c.readerPayloadCompressed && r.c.readNoContextTakeOver() { - putFlateReader(r.c.fr) - r.c.fr = nil - } - r.c.activeReader = nil - if err == nil { - err = io.EOF - } - } - - return n, err -} - -func (c *Conn) readFramePayload(ctx context.Context, p []byte) (_ int, err error) { - wrap := func(err error) error { - return fmt.Errorf("failed to read frame payload: %w", err) - } - defer func() { - if err != nil { - err = wrap(err) - } - }() - - err = c.acquireLock(ctx, c.readFrameLock) - if err != nil { - return 0, err - } - defer c.releaseLock(c.readFrameLock) - - select { - case <-c.closed: - return 0, c.closeErr - case c.setReadTimeout <- ctx: - } - - n, err := io.ReadFull(c.br, p) - if err != nil { - select { - case <-c.closed: - return n, c.closeErr - case <-ctx.Done(): - err = ctx.Err() - default: - } - c.releaseLock(c.readFrameLock) - c.close(wrap(err)) - return n, err - } - - select { - case <-c.closed: - return n, c.closeErr - case c.setReadTimeout <- context.Background(): - } - - return n, err -} - -// Read is a convenience method to read a single message from the connection. -// -// See the Reader method if you want to be able to reuse buffers or want to stream a message. -// The docs on Reader apply to this method as well. -func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { - typ, r, err := c.Reader(ctx) - if err != nil { - return 0, nil, err - } - - b, err := ioutil.ReadAll(r) - return typ, b, err -} - -// Writer returns a writer bounded by the context that will write -// a WebSocket message of type dataType to the connection. -// -// You must close the writer once you have written the entire message. -// -// Only one writer can be open at a time, multiple calls will block until the previous writer -// is closed. -func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { - wc, err := c.writer(ctx, typ) - if err != nil { - return nil, fmt.Errorf("failed to get writer: %w", err) - } - return wc, nil -} - -func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { - err := c.acquireLock(ctx, c.writeMsgLock) - if err != nil { - return nil, err - } - c.writeMsgCtx = ctx - c.writeMsgOpcode = opcode(typ) - w := &messageWriter{ - c: c, - } - c.activeWriter.Store(w) - return w, nil -} - -// Write is a convenience method to write a message to the connection. -// -// See the Writer method if you want to stream a message. -func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { - _, err := c.write(ctx, typ, p) - if err != nil { - return fmt.Errorf("failed to write msg: %w", err) - } - return nil -} - -func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) { - err := c.acquireLock(ctx, c.writeMsgLock) - if err != nil { - return 0, err - } - defer c.releaseLock(c.writeMsgLock) - - n, err := c.writeFrame(ctx, true, opcode(typ), p) - return n, err -} - -// messageWriter enables writing to a WebSocket connection. -type messageWriter struct { - c *Conn -} - -func (w *messageWriter) closed() bool { - return w != w.c.activeWriter.Load() -} - -// Write writes the given bytes to the WebSocket connection. -func (w *messageWriter) Write(p []byte) (int, error) { - n, err := w.write(p) - if err != nil { - return n, fmt.Errorf("failed to write: %w", err) - } - return n, nil -} - -func (w *messageWriter) write(p []byte) (int, error) { - if w.closed() { - return 0, fmt.Errorf("cannot use closed writer") - } - n, err := w.c.writeFrame(w.c.writeMsgCtx, false, w.c.writeMsgOpcode, p) - if err != nil { - return n, fmt.Errorf("failed to write data frame: %w", err) - } - w.c.writeMsgOpcode = opContinuation - return n, nil -} - -// Close flushes the frame to the connection. -// This must be called for every messageWriter. -func (w *messageWriter) Close() error { - err := w.close() - if err != nil { - return fmt.Errorf("failed to close writer: %w", err) - } - return nil -} - -func (w *messageWriter) close() error { - if w.closed() { - return fmt.Errorf("cannot use closed writer") - } - w.c.activeWriter.Store((*messageWriter)(nil)) - - _, err := w.c.writeFrame(w.c.writeMsgCtx, true, w.c.writeMsgOpcode, nil) - if err != nil { - return fmt.Errorf("failed to write fin frame: %w", err) - } - - w.c.releaseLock(w.c.writeMsgLock) - return nil -} - -func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { - ctx, cancel := context.WithTimeout(ctx, time.Second*5) - defer cancel() - - _, err := c.writeFrame(ctx, true, opcode, p) - if err != nil { - return fmt.Errorf("failed to write control frame %v: %w", opcode, err) - } - return nil -} - -// writeFrame handles all writes to the connection. -func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) { - err := c.acquireLock(ctx, c.writeFrameLock) - if err != nil { - return 0, err - } - defer c.releaseLock(c.writeFrameLock) - - select { - case <-c.closed: - return 0, c.closeErr - case c.setWriteTimeout <- ctx: - } - - c.writeHeader.fin = fin - c.writeHeader.opcode = opcode - c.writeHeader.masked = c.client - c.writeHeader.payloadLength = int64(len(p)) - - if c.client { - err = binary.Read(rand.Reader, binary.LittleEndian, &c.writeHeader.maskKey) - if err != nil { - return 0, fmt.Errorf("failed to generate masking key: %w", err) - } - } - - n, err := c.realWriteFrame(ctx, *c.writeHeader, p) - if err != nil { - return n, err - } - - // We already finished writing, no need to potentially brick the connection if - // the context expires. - select { - case <-c.closed: - return n, c.closeErr - case c.setWriteTimeout <- context.Background(): - } - - return n, nil -} - -func (c *Conn) realWriteFrame(ctx context.Context, h header, p []byte) (n int, err error) { - defer func() { - if err != nil { - select { - case <-c.closed: - err = c.closeErr - case <-ctx.Done(): - err = ctx.Err() - default: - } - - err = fmt.Errorf("failed to write %v frame: %w", h.opcode, err) - // We need to release the lock first before closing the connection to ensure - // the lock can be acquired inside close to ensure no one can access c.bw. - c.releaseLock(c.writeFrameLock) - c.close(err) - } - }() - - headerBytes := writeHeader(c.writeHeaderBuf, h) - _, err = c.bw.Write(headerBytes) - if err != nil { - return 0, err - } - - if c.client { - maskKey := h.maskKey - for len(p) > 0 { - if c.bw.Available() == 0 { - err = c.bw.Flush() - if err != nil { - return n, err - } - } - - // Start of next write in the buffer. - i := c.bw.Buffered() - - p2 := p - if len(p) > c.bw.Available() { - p2 = p[:c.bw.Available()] - } - - n2, err := c.bw.Write(p2) - if err != nil { - return n, err - } - - maskKey = mask(maskKey, c.writeBuf[i:i+n2]) - - p = p[n2:] - n += n2 - } - } else { - n, err = c.bw.Write(p) - if err != nil { - return n, err - } - } - - if h.fin { - err = c.bw.Flush() - if err != nil { - return n, err - } - } - - return n, nil -} - -// Close closes the WebSocket connection with the given status code and reason. -// -// It will write a WebSocket close frame with a timeout of 5s and then wait 5s for -// the peer to send a close frame. -// Thus, it implements the full WebSocket close handshake. -// All data messages received from the peer during the close handshake -// will be discarded. -// -// The connection can only be closed once. Additional calls to Close -// are no-ops. -// -// The maximum length of reason must be 125 bytes otherwise an internal -// error will be sent to the peer. For this reason, you should avoid -// sending a dynamic reason. -// -// Close will unblock all goroutines interacting with the connection once -// complete. -func (c *Conn) Close(code StatusCode, reason string) error { - err := c.exportedClose(code, reason, true) - var ec errClosing - if errors.As(err, &ec) { - <-c.closed - // We wait until the connection closes. - // We use writeClose and not exportedClose to avoid a second failed to marshal close frame error. - err = c.writeClose(nil, ec.ce, true) - } - if err != nil { - return fmt.Errorf("failed to close websocket connection: %w", err) - } - return nil -} - -func (c *Conn) exportedClose(code StatusCode, reason string, handshake bool) error { - ce := CloseError{ - Code: code, - Reason: reason, - } - - // This function also will not wait for a close frame from the peer like the RFC - // wants because that makes no sense and I don't think anyone actually follows that. - // Definitely worth seeing what popular browsers do later. - p, err := ce.bytes() - if err != nil { - c.logf("websocket: failed to marshal close frame: %+v", err) - ce = CloseError{ - Code: StatusInternalError, - } - p, _ = ce.bytes() - } - - return c.writeClose(p, fmt.Errorf("sent close: %w", ce), handshake) -} - -type errClosing struct { - ce error -} - -func (e errClosing) Error() string { - return "already closing connection" -} - -func (c *Conn) writeClose(p []byte, ce error, handshake bool) error { - if c.isClosed() { - return fmt.Errorf("tried to close with %q but connection already closed: %w", ce, c.closeErr) - } - - if !c.closing.CAS(0, 1) { - // Normally, we would want to wait until the connection is closed, - // at least for when a user calls into Close, so we handle that case in - // the exported Close function. - // - // But for internal library usage, we always want to return early, e.g. - // if we are performing a close handshake and the peer sends their close frame, - // we do not want to block here waiting for c.closed to close because it won't, - // at least not until we return since the gorouine that will close it is this one. - return errClosing{ - ce: ce, - } - } - - // No matter what happens next, close error should be set. - c.setCloseErr(ce) - defer c.close(nil) - - err := c.writeControl(context.Background(), opClose, p) - if err != nil { - return err - } - - if handshake { - err = c.waitClose() - if CloseStatus(err) == -1 { - // waitClose exited not due to receiving a close frame. - return fmt.Errorf("failed to wait for peer close frame: %w", err) - } - } - - return nil -} - -func (c *Conn) waitClose() error { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) - defer cancel() - - err := c.acquireLock(ctx, c.readLock) - if err != nil { - return err - } - defer c.releaseLock(c.readLock) - - if c.closeReceived != nil { - // goroutine reading just received the close. - return c.closeReceived - } - - b := bufpool.Get() - buf := b.Bytes() - buf = buf[:cap(buf)] - defer bufpool.Put(b) - - for { - if c.activeReader == nil || c.readerFrameEOF { - _, _, err := c.reader(ctx, false) - if err != nil { - return fmt.Errorf("failed to get reader: %w", err) - } - } - - r := readerFunc(c.activeReader.readUnlocked) - _, err = io.CopyBuffer(ioutil.Discard, r, buf) - if err != nil { - return err - } - } +func (c *Conn) deflateNegotiated() bool { + return c.copts != nil } // Ping sends a ping to the peer and waits for a pong. @@ -1056,9 +153,9 @@ func (c *Conn) waitClose() error { // // TCP Keepalives should suffice for most use cases. func (c *Conn) Ping(ctx context.Context) error { - p := c.pingCounter.Increment(1) + p := atomic.AddInt32(&c.pingCounter, 1) - err := c.ping(ctx, strconv.FormatInt(p, 10)) + err := c.ping(ctx, strconv.Itoa(int(p))) if err != nil { return fmt.Errorf("failed to ping: %w", err) } @@ -1078,7 +175,7 @@ func (c *Conn) ping(ctx context.Context, p string) error { c.activePingsMu.Unlock() }() - err := c.writeControl(ctx, opPing, []byte(p)) + err := c.cw.control(ctx, opPing, []byte(p)) if err != nil { return err } @@ -1095,109 +192,37 @@ func (c *Conn) ping(ctx context.Context, p string) error { } } -type readerFunc func(p []byte) (int, error) - -func (f readerFunc) Read(p []byte) (int, error) { - return f(p) -} - -type writerFunc func(p []byte) (int, error) - -func (f writerFunc) Write(p []byte) (int, error) { - return f(p) -} - -// extractBufioWriterBuf grabs the []byte backing a *bufio.Writer -// and stores it in c.writeBuf. -func (c *Conn) extractBufioWriterBuf(w io.Writer) { - c.bw.Reset(writerFunc(func(p2 []byte) (int, error) { - c.writeBuf = p2[:cap(p2)] - return len(p2), nil - })) - - c.bw.WriteByte(0) - c.bw.Flush() - - c.bw.Reset(w) -} - -var flateWriterPool = &sync.Pool{ - New: func() interface{} { - w, _ := flate.NewWriter(nil, flate.BestSpeed) - return w - }, -} - -func getFlateWriter(w io.Writer) *flate.Writer { - fw := flateWriterPool.Get().(*flate.Writer) - fw.Reset(w) - return fw -} - -func putFlateWriter(w *flate.Writer) { - flateWriterPool.Put(w) +type mu struct { + once sync.Once + ch chan struct{} } -var flateReaderPool = &sync.Pool{ - New: func() interface{} { - return flate.NewReader(nil) - }, -} - -func getFlateReader(r io.Reader) io.Reader { - fr := flateReaderPool.Get().(io.Reader) - fr.(flate.Resetter).Reset(r, nil) - return fr -} - -func putFlateReader(fr io.Reader) { - flateReaderPool.Put(fr) -} - -func (c *Conn) writeNoContextTakeOver() bool { - return c.client && c.copts.clientNoContextTakeover || !c.client && c.copts.serverNoContextTakeover -} - -func (c *Conn) readNoContextTakeOver() bool { - return !c.client && c.copts.clientNoContextTakeover || c.client && c.copts.serverNoContextTakeover -} - -type trimLastFourBytesWriter struct { - w io.Writer - tail []byte +func (m *mu) init() { + m.once.Do(func() { + m.ch = make(chan struct{}, 1) + }) } -func (w *trimLastFourBytesWriter) Write(p []byte) (int, error) { - extra := len(w.tail) + len(p) - 4 - - if extra <= 0 { - w.tail = append(w.tail, p...) - return len(p), nil - } - - // Now we need to write as many extra bytes as we can from the previous tail. - if extra > len(w.tail) { - extra = len(w.tail) - } - if extra > 0 { - _, err := w.Write(w.tail[:extra]) - if err != nil { - return 0, err - } - w.tail = w.tail[extra:] +func (m *mu) Lock(ctx context.Context) error { + m.init() + select { + case <-ctx.Done(): + return ctx.Err() + case m.ch <- struct{}{}: + return nil } +} - // If p is less than or equal to 4 bytes, - // all of it is is part of the tail. - if len(p) <= 4 { - w.tail = append(w.tail, p...) - return len(p), nil +func (m *mu) TryLock() bool { + m.init() + select { + case m.ch <- struct{}{}: + return true + default: + return false } +} - // Otherwise, only the last 4 bytes are. - w.tail = append(w.tail, p[len(p)-4:]...) - - p = p[:len(p)-4] - n, err := w.w.Write(p) - return n + 4, err +func (m *mu) Unlock() { + <-m.ch } diff --git a/conn_export_test.go b/conn_export_test.go deleted file mode 100644 index d5f5aa2..0000000 --- a/conn_export_test.go +++ /dev/null @@ -1,129 +0,0 @@ -// +build !js - -package websocket - -import ( - "bufio" - "context" - "fmt" -) - -type ( - Addr = websocketAddr - OpCode int -) - -const ( - OpClose = OpCode(opClose) - OpBinary = OpCode(opBinary) - OpText = OpCode(opText) - OpPing = OpCode(opPing) - OpPong = OpCode(opPong) - OpContinuation = OpCode(opContinuation) -) - -func (c *Conn) SetLogf(fn func(format string, v ...interface{})) { - c.logf = fn -} - -func (c *Conn) ReadFrame(ctx context.Context) (OpCode, []byte, error) { - h, err := c.readFrameHeader(ctx) - if err != nil { - return 0, nil, err - } - b := make([]byte, h.payloadLength) - _, err = c.readFramePayload(ctx, b) - if err != nil { - return 0, nil, err - } - if h.masked { - mask(h.maskKey, b) - } - return OpCode(h.opcode), b, nil -} - -func (c *Conn) WriteFrame(ctx context.Context, fin bool, opc OpCode, p []byte) (int, error) { - return c.writeFrame(ctx, fin, opcode(opc), p) -} - -// header represents a WebSocket frame header. -// See https://tools.ietf.org/html/rfc6455#section-5.2 -type Header struct { - Fin bool - Rsv1 bool - Rsv2 bool - Rsv3 bool - OpCode OpCode - - PayloadLength int64 -} - -func (c *Conn) WriteHeader(ctx context.Context, h Header) error { - headerBytes := writeHeader(c.writeHeaderBuf, header{ - fin: h.Fin, - rsv1: h.Rsv1, - rsv2: h.Rsv2, - rsv3: h.Rsv3, - opcode: opcode(h.OpCode), - payloadLength: h.PayloadLength, - masked: c.client, - }) - _, err := c.bw.Write(headerBytes) - if err != nil { - return fmt.Errorf("failed to write header: %w", err) - } - if h.Fin { - err = c.Flush() - if err != nil { - return err - } - } - return nil -} - -func (c *Conn) PingWithPayload(ctx context.Context, p string) error { - return c.ping(ctx, p) -} - -func (c *Conn) WriteHalfFrame(ctx context.Context) (int, error) { - return c.realWriteFrame(ctx, header{ - fin: true, - opcode: opBinary, - payloadLength: 10, - }, make([]byte, 5)) -} - -func (c *Conn) CloseUnderlyingConn() { - c.closer.Close() -} - -func (c *Conn) Flush() error { - return c.bw.Flush() -} - -func (c CloseError) Bytes() ([]byte, error) { - return c.bytes() -} - -func (c *Conn) BW() *bufio.Writer { - return c.bw -} - -func (c *Conn) WriteClose(ctx context.Context, code StatusCode, reason string) ([]byte, error) { - b, err := CloseError{ - Code: code, - Reason: reason, - }.Bytes() - if err != nil { - return nil, err - } - _, err = c.WriteFrame(ctx, true, OpClose, b) - if err != nil { - return nil, err - } - return b, nil -} - -func ParseClosePayload(p []byte) (CloseError, error) { - return parseClosePayload(p) -} diff --git a/conn_test.go b/conn_test.go index d03a721..992c886 100644 --- a/conn_test.go +++ b/conn_test.go @@ -3,969 +3,28 @@ package websocket_test import ( - "bytes" "context" - "encoding/binary" - "encoding/json" - "errors" "fmt" "io" - "io/ioutil" - "math/rand" - "net" "net/http" - "net/http/cookiejar" "net/http/httptest" - "net/url" - "os" - "os/exec" - "reflect" - "strconv" + "nhooyr.io/websocket/internal/assert" "strings" + "sync/atomic" "testing" "time" - "github.com/golang/protobuf/proto" - "github.com/golang/protobuf/ptypes" - "github.com/golang/protobuf/ptypes/timestamp" - "go.uber.org/multierr" - "nhooyr.io/websocket" - "nhooyr.io/websocket/internal/assert" - "nhooyr.io/websocket/internal/wsecho" - "nhooyr.io/websocket/internal/wsgrace" - "nhooyr.io/websocket/wsjson" - "nhooyr.io/websocket/wspb" ) -func init() { - rand.Seed(time.Now().UnixNano()) -} - -func TestHandshake(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - client func(ctx context.Context, url string) error - server func(w http.ResponseWriter, r *http.Request) error - }{ - { - name: "badOrigin", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, nil) - if err == nil { - c.Close(websocket.StatusInternalError, "") - return errors.New("expected error regarding bad origin") - } - return assertErrorContains(err, "not authorized") - }, - client: func(ctx context.Context, u string) error { - h := http.Header{} - h.Set("Origin", "http://unauthorized.com") - c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ - HTTPHeader: h, - }) - if err == nil { - c.Close(websocket.StatusInternalError, "") - return errors.New("expected handshake failure") - } - return assertErrorContains(err, "403") - }, - }, - { - name: "acceptSecureOrigin", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, nil) - if err != nil { - return err - } - c.Close(websocket.StatusNormalClosure, "") - return nil - }, - client: func(ctx context.Context, u string) error { - h := http.Header{} - h.Set("Origin", u) - c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ - HTTPHeader: h, - }) - if err != nil { - return err - } - c.Close(websocket.StatusNormalClosure, "") - return nil - }, - }, - { - name: "acceptInsecureOrigin", - server: func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - InsecureSkipVerify: true, - }) - if err != nil { - return err - } - c.Close(websocket.StatusNormalClosure, "") - return nil - }, - client: func(ctx context.Context, u string) error { - h := http.Header{} - h.Set("Origin", "https://example.com") - c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ - HTTPHeader: h, - }) - if err != nil { - return err - } - c.Close(websocket.StatusNormalClosure, "") - return nil - }, - }, - { - name: "cookies", - server: func(w http.ResponseWriter, r *http.Request) error { - cookie, err := r.Cookie("mycookie") - if err != nil { - return fmt.Errorf("request is missing mycookie: %w", err) - } - err = assert.Equalf("myvalue", cookie.Value, "unexpected cookie value") - if err != nil { - return err - } - c, err := websocket.Accept(w, r, nil) - if err != nil { - return err - } - c.Close(websocket.StatusNormalClosure, "") - return nil - }, - client: func(ctx context.Context, u string) error { - jar, err := cookiejar.New(nil) - if err != nil { - return fmt.Errorf("failed to create cookie jar: %w", err) - } - parsedURL, err := url.Parse(u) - if err != nil { - return fmt.Errorf("failed to parse url: %w", err) - } - parsedURL.Scheme = "http" - jar.SetCookies(parsedURL, []*http.Cookie{ - { - Name: "mycookie", - Value: "myvalue", - }, - }) - hc := &http.Client{ - Jar: jar, - } - c, _, err := websocket.Dial(ctx, u, &websocket.DialOptions{ - HTTPClient: hc, - }) - if err != nil { - return err - } - c.Close(websocket.StatusNormalClosure, "") - return nil - }, - }, - } - - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - s, closeFn := testServer(t, tc.server, false) - defer closeFn() - - wsURL := strings.Replace(s.URL, "http", "ws", 1) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - - err := tc.client(ctx, wsURL) - if err != nil { - t.Fatalf("client failed: %+v", err) - } - }) - } -} - -func TestConn(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - - acceptOpts *websocket.AcceptOptions - server func(ctx context.Context, c *websocket.Conn) error - - dialOpts *websocket.DialOptions - response func(resp *http.Response) error - client func(ctx context.Context, c *websocket.Conn) error - }{ - { - name: "handshake", - acceptOpts: &websocket.AcceptOptions{ - Subprotocols: []string{"myproto"}, - }, - dialOpts: &websocket.DialOptions{ - Subprotocols: []string{"myproto"}, - }, - response: func(resp *http.Response) error { - headers := map[string]string{ - "Connection": "Upgrade", - "Upgrade": "websocket", - "Sec-WebSocket-Protocol": "myproto", - } - for h, exp := range headers { - value := resp.Header.Get(h) - err := assert.Equalf(exp, value, "unexpected value for header %v", h) - if err != nil { - return err - } - } - return nil - }, - }, - { - name: "handshake/defaultSubprotocol", - server: func(ctx context.Context, c *websocket.Conn) error { - return assertSubprotocol(c, "") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return assertSubprotocol(c, "") - }, - }, - { - name: "handshake/subprotocolPriority", - acceptOpts: &websocket.AcceptOptions{ - Subprotocols: []string{"echo", "lar"}, - }, - server: func(ctx context.Context, c *websocket.Conn) error { - return assertSubprotocol(c, "echo") - }, - dialOpts: &websocket.DialOptions{ - Subprotocols: []string{"poof", "echo"}, - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return assertSubprotocol(c, "echo") - }, - }, - { - name: "closeError", - server: func(ctx context.Context, c *websocket.Conn) error { - return wsjson.Write(ctx, c, "hello") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - err := assertJSONRead(ctx, c, "hello") - if err != nil { - return err - } - - _, _, err = c.Reader(ctx) - return assertCloseStatus(err, websocket.StatusInternalError) - }, - }, - { - name: "netConn", - server: func(ctx context.Context, c *websocket.Conn) error { - nc := websocket.NetConn(ctx, c, websocket.MessageBinary) - defer nc.Close() - - nc.SetWriteDeadline(time.Time{}) - time.Sleep(1) - nc.SetWriteDeadline(time.Now().Add(time.Second * 15)) - - err := assert.Equalf(websocket.Addr{}, nc.LocalAddr(), "net conn local address is not equal to websocket.Addr") - if err != nil { - return err - } - err = assert.Equalf(websocket.Addr{}, nc.RemoteAddr(), "net conn remote address is not equal to websocket.Addr") - if err != nil { - return err - } - - for i := 0; i < 3; i++ { - _, err := nc.Write([]byte("hello")) - if err != nil { - return err - } - } - - return nil - }, - client: func(ctx context.Context, c *websocket.Conn) error { - nc := websocket.NetConn(ctx, c, websocket.MessageBinary) - - nc.SetReadDeadline(time.Time{}) - time.Sleep(1) - nc.SetReadDeadline(time.Now().Add(time.Second * 15)) - - for i := 0; i < 3; i++ { - err := assertNetConnRead(nc, "hello") - if err != nil { - return err - } - } - - // Ensure the close frame is converted to an EOF and multiple read's after all return EOF. - err2 := assertNetConnRead(nc, "hello") - err := assert.Equalf(io.EOF, err2, "unexpected error") - if err != nil { - return err - } - - err2 = assertNetConnRead(nc, "hello") - return assert.Equalf(io.EOF, err2, "unexpected error") - }, - }, - { - name: "netConn/badReadMsgType", - server: func(ctx context.Context, c *websocket.Conn) error { - nc := websocket.NetConn(ctx, c, websocket.MessageBinary) - - nc.SetDeadline(time.Now().Add(time.Second * 15)) - - _, err := nc.Read(make([]byte, 1)) - return assertErrorContains(err, "unexpected frame type") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - err := wsjson.Write(ctx, c, "meow") - if err != nil { - return err - } - - _, _, err = c.Read(ctx) - return assertCloseStatus(err, websocket.StatusUnsupportedData) - }, - }, - { - name: "netConn/badRead", - server: func(ctx context.Context, c *websocket.Conn) error { - nc := websocket.NetConn(ctx, c, websocket.MessageBinary) - defer nc.Close() - - nc.SetDeadline(time.Now().Add(time.Second * 15)) - - _, err2 := nc.Read(make([]byte, 1)) - err := assertCloseStatus(err2, websocket.StatusBadGateway) - if err != nil { - return err - } - - _, err2 = nc.Write([]byte{0xff}) - return assertErrorContains(err2, "websocket closed") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return c.Close(websocket.StatusBadGateway, "") - }, - }, - { - name: "wsjson/echo", - server: func(ctx context.Context, c *websocket.Conn) error { - return wsjson.Write(ctx, c, "meow") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return assertJSONRead(ctx, c, "meow") - }, - }, - { - name: "protobuf/echo", - server: func(ctx context.Context, c *websocket.Conn) error { - return wspb.Write(ctx, c, ptypes.DurationProto(100)) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return assertProtobufRead(ctx, c, ptypes.DurationProto(100)) - }, - }, - { - name: "ping", - server: func(ctx context.Context, c *websocket.Conn) error { - ctx = c.CloseRead(ctx) - - err := c.Ping(ctx) - if err != nil { - return err - } - - err = wsjson.Write(ctx, c, "hi") - if err != nil { - return err - } - - <-ctx.Done() - err = c.Ping(context.Background()) - return assertCloseStatus(err, websocket.StatusNormalClosure) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - // We read a message from the connection and then keep reading until - // the Ping completes. - pingErrc := make(chan error, 1) - go func() { - pingErrc <- c.Ping(ctx) - }() - - // Once this completes successfully, that means they sent their ping and we responded to it. - err := assertJSONRead(ctx, c, "hi") - if err != nil { - return err - } - - // Now we need to ensure we're reading for their pong from our ping. - // Need new var to not race with above goroutine. - ctx2 := c.CloseRead(ctx) - - // Now we wait for our pong. - select { - case err = <-pingErrc: - return err - case <-ctx2.Done(): - return fmt.Errorf("failed to wait for pong: %w", ctx2.Err()) - } - }, - }, - { - name: "readLimit", - server: func(ctx context.Context, c *websocket.Conn) error { - _, _, err2 := c.Read(ctx) - return assertErrorContains(err2, "read limited at 32768 bytes") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - err := c.Write(ctx, websocket.MessageBinary, []byte(strings.Repeat("x", 32769))) - if err != nil { - return err - } - - _, _, err2 := c.Read(ctx) - return assertCloseStatus(err2, websocket.StatusMessageTooBig) - }, - }, - { - name: "wsjson/binary", - server: func(ctx context.Context, c *websocket.Conn) error { - var v interface{} - err2 := wsjson.Read(ctx, c, &v) - return assertErrorContains(err2, "unexpected frame type") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return wspb.Write(ctx, c, ptypes.DurationProto(100)) - }, - }, - { - name: "wsjson/badRead", - server: func(ctx context.Context, c *websocket.Conn) error { - var v interface{} - err2 := wsjson.Read(ctx, c, &v) - return assertErrorContains(err2, "failed to unmarshal json") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return c.Write(ctx, websocket.MessageText, []byte("notjson")) - }, - }, - { - name: "wsjson/badWrite", - server: func(ctx context.Context, c *websocket.Conn) error { - _, _, err2 := c.Read(ctx) - return assertCloseStatus(err2, websocket.StatusNormalClosure) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - err := wsjson.Write(ctx, c, fmt.Println) - return assertErrorContains(err, "failed to encode json") - }, - }, - { - name: "wspb/text", - server: func(ctx context.Context, c *websocket.Conn) error { - var v proto.Message - err := wspb.Read(ctx, c, v) - return assertErrorContains(err, "unexpected frame type") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return wsjson.Write(ctx, c, "hi") - }, - }, - { - name: "wspb/badRead", - server: func(ctx context.Context, c *websocket.Conn) error { - var v timestamp.Timestamp - err := wspb.Read(ctx, c, &v) - return assertErrorContains(err, "failed to unmarshal protobuf") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return c.Write(ctx, websocket.MessageBinary, []byte("notpb")) - }, - }, - { - name: "wspb/badWrite", - server: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - return assertCloseStatus(err, websocket.StatusNormalClosure) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - err := wspb.Write(ctx, c, nil) - return assertErrorIs(proto.ErrNil, err) - }, - }, - { - name: "badClose", - server: func(ctx context.Context, c *websocket.Conn) error { - return c.Close(9999, "") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - return assertCloseStatus(err, websocket.StatusInternalError) - }, - }, - { - name: "pingTimeout", - server: func(ctx context.Context, c *websocket.Conn) error { - ctx, cancel := context.WithTimeout(ctx, time.Second) - defer cancel() - err := c.Ping(ctx) - return assertErrorIs(context.DeadlineExceeded, err) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - err1 := assertErrorContains(err, "connection reset") - err2 := assertErrorIs(io.EOF, err) - if err1 != nil || err2 != nil { - return nil - } - return multierr.Combine(err1, err2) - }, - }, - { - name: "writeTimeout", - server: func(ctx context.Context, c *websocket.Conn) error { - c.Writer(ctx, websocket.MessageBinary) - - ctx, cancel := context.WithTimeout(ctx, time.Second) - defer cancel() - err := c.Write(ctx, websocket.MessageBinary, []byte("meow")) - return assertErrorIs(context.DeadlineExceeded, err) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - return assertErrorIs(io.EOF, err) - }, - }, - { - name: "readTimeout", - server: func(ctx context.Context, c *websocket.Conn) error { - ctx, cancel := context.WithTimeout(ctx, time.Second) - defer cancel() - _, _, err := c.Read(ctx) - return assertErrorIs(context.DeadlineExceeded, err) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - return assertErrorIs(websocket.CloseError{ - Code: websocket.StatusPolicyViolation, - Reason: "read timed out", - }, err) - }, - }, - { - name: "badOpCode", - server: func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, true, 13, []byte("meow")) - if err != nil { - return err - } - _, _, err = c.Read(ctx) - return assertErrorContains(err, "unknown opcode") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - return assertErrorContains(err, "unknown opcode") - }, - }, - { - name: "noRsv", - server: func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, true, 99, []byte("meow")) - if err != nil { - return err - } - _, _, err = c.Read(ctx) - return assertCloseStatus(err, websocket.StatusProtocolError) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - if err == nil || !strings.Contains(err.Error(), "rsv") { - return fmt.Errorf("expected error that contains rsv: %+v", err) - } - return nil - }, - }, - { - name: "largeControlFrame", - server: func(ctx context.Context, c *websocket.Conn) error { - err := c.WriteHeader(ctx, websocket.Header{ - Fin: true, - OpCode: websocket.OpClose, - PayloadLength: 4096, - }) - if err != nil { - return err - } - _, _, err = c.Read(ctx) - return assertCloseStatus(err, websocket.StatusProtocolError) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - return assertErrorContains(err, "too big") - }, - }, - { - name: "fragmentedControlFrame", - server: func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, false, websocket.OpPing, []byte(strings.Repeat("x", 32))) - if err != nil { - return err - } - err = c.Flush() - if err != nil { - return err - } - _, _, err = c.Read(ctx) - return assertCloseStatus(err, websocket.StatusProtocolError) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - return assertErrorContains(err, "fragmented") - }, - }, - { - name: "invalidClosePayload", - server: func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, true, websocket.OpClose, []byte{0x17, 0x70}) - if err != nil { - return err - } - _, _, err = c.Read(ctx) - return assertCloseStatus(err, websocket.StatusProtocolError) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - return assertErrorContains(err, "invalid status code") - }, - }, - { - name: "doubleReader", - server: func(ctx context.Context, c *websocket.Conn) error { - _, r, err := c.Reader(ctx) - if err != nil { - return err - } - p := make([]byte, 10) - _, err = io.ReadFull(r, p) - if err != nil { - return err - } - _, _, err = c.Reader(ctx) - return assertErrorContains(err, "previous message not read to completion") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - err := c.Write(ctx, websocket.MessageBinary, []byte(strings.Repeat("x", 11))) - if err != nil { - return err - } - _, _, err = c.Read(ctx) - return assertCloseStatus(err, websocket.StatusInternalError) - }, - }, - { - name: "doubleFragmentedReader", - server: func(ctx context.Context, c *websocket.Conn) error { - _, r, err := c.Reader(ctx) - if err != nil { - return err - } - p := make([]byte, 10) - _, err = io.ReadFull(r, p) - if err != nil { - return err - } - _, _, err = c.Reader(ctx) - return assertErrorContains(err, "previous message not read to completion") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - w, err := c.Writer(ctx, websocket.MessageBinary) - if err != nil { - return err - } - _, err = w.Write([]byte(strings.Repeat("x", 10))) - if err != nil { - return fmt.Errorf("expected non nil error") - } - err = c.Flush() - if err != nil { - return fmt.Errorf("failed to flush: %w", err) - } - _, err = w.Write([]byte(strings.Repeat("x", 10))) - if err != nil { - return fmt.Errorf("expected non nil error") - } - err = c.Flush() - if err != nil { - return fmt.Errorf("failed to flush: %w", err) - } - _, _, err = c.Read(ctx) - return assertCloseStatus(err, websocket.StatusInternalError) - }, - }, - { - name: "newMessageInFragmentedMessage", - server: func(ctx context.Context, c *websocket.Conn) error { - _, r, err := c.Reader(ctx) - if err != nil { - return err - } - p := make([]byte, 10) - _, err = io.ReadFull(r, p) - if err != nil { - return err - } - _, _, err = c.Reader(ctx) - return assertErrorContains(err, "received new data message without finishing") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - w, err := c.Writer(ctx, websocket.MessageBinary) - if err != nil { - return err - } - _, err = w.Write([]byte(strings.Repeat("x", 10))) - if err != nil { - return fmt.Errorf("expected non nil error") - } - err = c.Flush() - if err != nil { - return fmt.Errorf("failed to flush: %w", err) - } - _, err = c.WriteFrame(ctx, true, websocket.OpBinary, []byte(strings.Repeat("x", 10))) - if err != nil { - return fmt.Errorf("expected non nil error") - } - _, _, err = c.Read(ctx) - return assertErrorContains(err, "received new data message without finishing") - }, - }, - { - name: "continuationFrameWithoutDataFrame", - server: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Reader(ctx) - return assertErrorContains(err, "received continuation frame not after data") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, false, websocket.OpContinuation, []byte(strings.Repeat("x", 10))) - return err - }, - }, - { - name: "readBeforeEOF", - server: func(ctx context.Context, c *websocket.Conn) error { - _, r, err := c.Reader(ctx) - if err != nil { - return err - } - var v interface{} - d := json.NewDecoder(r) - err = d.Decode(&v) - if err != nil { - return err - } - err = assert.Equalf("hi", v, "unexpected JSON") - if err != nil { - return err - } - _, b, err := c.Read(ctx) - if err != nil { - return err - } - return assert.Equalf("hi", string(b), "unexpected JSON") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - err := wsjson.Write(ctx, c, "hi") - if err != nil { - return err - } - return c.Write(ctx, websocket.MessageText, []byte("hi")) - }, - }, - { - name: "newMessageInFragmentedMessage2", - server: func(ctx context.Context, c *websocket.Conn) error { - _, r, err := c.Reader(ctx) - if err != nil { - return err - } - p := make([]byte, 11) - _, err = io.ReadFull(r, p) - return assertErrorContains(err, "received new data message without finishing") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - w, err := c.Writer(ctx, websocket.MessageBinary) - if err != nil { - return err - } - _, err = w.Write([]byte(strings.Repeat("x", 10))) - if err != nil { - return fmt.Errorf("expected non nil error") - } - err = c.Flush() - if err != nil { - return fmt.Errorf("failed to flush: %w", err) - } - _, err = c.WriteFrame(ctx, true, websocket.OpBinary, []byte(strings.Repeat("x", 10))) - if err != nil { - return fmt.Errorf("expected non nil error") - } - _, _, err = c.Read(ctx) - return assertCloseStatus(err, websocket.StatusProtocolError) - }, - }, - { - name: "doubleRead", - server: func(ctx context.Context, c *websocket.Conn) error { - _, r, err := c.Reader(ctx) - if err != nil { - return err - } - _, err = ioutil.ReadAll(r) - if err != nil { - return err - } - _, err = r.Read(make([]byte, 1)) - return assertErrorContains(err, "cannot use EOFed reader") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return c.Write(ctx, websocket.MessageBinary, []byte("hi")) - }, - }, - { - name: "eofInPayload", - server: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - return assertErrorContains(err, "failed to read frame payload") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteHalfFrame(ctx) - if err != nil { - return err - } - c.CloseUnderlyingConn() - return nil - }, - }, - { - name: "closeHandshake", - server: func(ctx context.Context, c *websocket.Conn) error { - return c.Close(websocket.StatusNormalClosure, "") - }, - client: func(ctx context.Context, c *websocket.Conn) error { - return c.Close(websocket.StatusNormalClosure, "") - }, - }, - { - // Issue #164 - name: "closeHandshake_concurrentRead", - server: func(ctx context.Context, c *websocket.Conn) error { - _, _, err := c.Read(ctx) - return assertCloseStatus(err, websocket.StatusNormalClosure) - }, - client: func(ctx context.Context, c *websocket.Conn) error { - errc := make(chan error, 1) - go func() { - _, _, err := c.Read(ctx) - errc <- err - }() - - err := c.Close(websocket.StatusNormalClosure, "") - if err != nil { - return err - } - - err = <-errc - return assertCloseStatus(err, websocket.StatusNormalClosure) - }, - }, - } - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - // Run random tests over TLS. - tls := rand.Intn(2) == 1 - - s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, tc.acceptOpts) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - c.SetLogf(t.Logf) - if tc.server == nil { - return nil - } - return tc.server(r.Context(), c) - }, tls) - defer closeFn() - - wsURL := strings.Replace(s.URL, "http", "ws", 1) - - ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) - defer cancel() - - opts := tc.dialOpts - if tls { - if opts == nil { - opts = &websocket.DialOptions{} - } - opts.HTTPClient = s.Client() - } - - c, resp, err := websocket.Dial(ctx, wsURL, opts) - if err != nil { - t.Fatal(err) - } - defer c.Close(websocket.StatusInternalError, "") - c.SetLogf(t.Logf) - - if tc.response != nil { - err = tc.response(resp) - if err != nil { - t.Fatalf("response asserter failed: %+v", err) - } - } - - if tc.client != nil { - err = tc.client(ctx, c) - if err != nil { - t.Fatalf("client failed: %+v", err) - } - } - - c.Close(websocket.StatusNormalClosure, "") - }) - } -} - -func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request) error, tls bool) (s *httptest.Server, closeFn func()) { - h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - err := fn(w, r) - if err != nil { - tb.Errorf("server failed: %+v", err) - } - }) +func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request), tls bool) (s *httptest.Server, closeFn func()) { + h := http.HandlerFunc(fn) if tls { s = httptest.NewTLSServer(h) } else { s = httptest.NewServer(h) } - closeFn2 := wsgrace.Grace(s.Config) + closeFn2 := wsgrace(s.Config) return s, func() { err := closeFn2() if err != nil { @@ -974,1417 +33,112 @@ func testServer(tb testing.TB, fn func(w http.ResponseWriter, r *http.Request) e } } -func TestAutobahn(t *testing.T) { - t.Parallel() - - run := func(t *testing.T, name string, fn func(ctx context.Context, c *websocket.Conn) error) { - run2 := func(t *testing.T, testingClient bool) { - // Run random tests over TLS. - tls := rand.Intn(2) == 1 - - s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"echo"}, - }) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - - ctx := r.Context() - if testingClient { - err = wsecho.Loop(ctx, c) - if err != nil { - t.Logf("failed to wsecho: %+v", err) - } - return nil - } - - c.SetReadLimit(1 << 30) - err = fn(ctx, c) - if err != nil { - return err - } - c.Close(websocket.StatusNormalClosure, "") - return nil - }, tls) - defer closeFn() - - wsURL := strings.Replace(s.URL, "http", "ws", 1) - - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - - opts := &websocket.DialOptions{ - Subprotocols: []string{"echo"}, - } - if tls { - opts.HTTPClient = s.Client() - } - - c, _, err := websocket.Dial(ctx, wsURL, opts) - if err != nil { - t.Fatal(err) - } - defer c.Close(websocket.StatusInternalError, "") - - if testingClient { - c.SetReadLimit(1 << 30) - err = fn(ctx, c) - if err != nil { - t.Fatalf("client failed: %+v", err) - } - c.Close(websocket.StatusNormalClosure, "") - return - } - - err = wsecho.Loop(ctx, c) - if err != nil { - t.Logf("failed to wsecho: %+v", err) - } - } - t.Run(name, func(t *testing.T) { - t.Parallel() +// grace wraps s.Handler to gracefully shutdown WebSocket connections. +// The returned function must be used to close the server instead of s.Close. +func wsgrace(s *http.Server) (closeFn func() error) { + h := s.Handler + var conns int64 + s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&conns, 1) + defer atomic.AddInt64(&conns, -1) - run2(t, true) - }) - } - - // Section 1. - t.Run("echo", func(t *testing.T) { - t.Parallel() - - lengths := []int{ - 0, - 125, - 126, - 127, - 128, - 65535, - 65536, - 65536, - } - run := func(typ websocket.MessageType) { - for i, l := range lengths { - l := l - run(t, fmt.Sprintf("%v/%v", typ, l), func(ctx context.Context, c *websocket.Conn) error { - p := randBytes(l) - if i == len(lengths)-1 { - w, err := c.Writer(ctx, typ) - if err != nil { - return err - } - for i := 0; i < l; { - j := i + 997 - if j > l { - j = l - } - _, err = w.Write(p[i:j]) - if err != nil { - return err - } + ctx, cancel := context.WithTimeout(r.Context(), time.Second*5) + defer cancel() - i = j - } + r = r.WithContext(ctx) - err = w.Close() - if err != nil { - return err - } - } else { - err := c.Write(ctx, typ, p) - if err != nil { - return err - } - } - actTyp, p2, err := c.Read(ctx) - if err != nil { - return err - } - err = assert.Equalf(typ, actTyp, "unexpected message type") - if err != nil { - return err - } - return assert.Equalf(p, p2, "unexpected message") - }) - } - } - - run(websocket.MessageText) - run(websocket.MessageBinary) + h.ServeHTTP(w, r) }) - // Section 2. - t.Run("pingPong", func(t *testing.T) { - t.Parallel() - - run(t, "emptyPayload", func(ctx context.Context, c *websocket.Conn) error { - ctx = c.CloseRead(ctx) - return c.PingWithPayload(ctx, "") - }) - run(t, "smallTextPayload", func(ctx context.Context, c *websocket.Conn) error { - ctx = c.CloseRead(ctx) - return c.PingWithPayload(ctx, "hi") - }) - run(t, "smallBinaryPayload", func(ctx context.Context, c *websocket.Conn) error { - ctx = c.CloseRead(ctx) - p := bytes.Repeat([]byte{0xFE}, 16) - return c.PingWithPayload(ctx, string(p)) - }) - run(t, "largeBinaryPayload", func(ctx context.Context, c *websocket.Conn) error { - ctx = c.CloseRead(ctx) - p := bytes.Repeat([]byte{0xFE}, 125) - return c.PingWithPayload(ctx, string(p)) - }) - run(t, "tooLargeBinaryPayload", func(ctx context.Context, c *websocket.Conn) error { - c.CloseRead(ctx) - p := bytes.Repeat([]byte{0xFE}, 126) - err := c.PingWithPayload(ctx, string(p)) - return assertCloseStatus(err, websocket.StatusProtocolError) - }) - run(t, "streamPingPayload", func(ctx context.Context, c *websocket.Conn) error { - err := assertStreamPing(ctx, c, 125) - if err != nil { - return err - } - return c.Close(websocket.StatusNormalClosure, "") - }) - t.Run("unsolicitedPong", func(t *testing.T) { - t.Parallel() - - var testCases = []struct { - name string - pongPayload string - ping bool - }{ - { - name: "noPayload", - pongPayload: "", - }, - { - name: "payload", - pongPayload: "hi", - }, - { - name: "pongThenPing", - pongPayload: "hi", - ping: true, - }, - } - for _, tc := range testCases { - tc := tc - run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, true, websocket.OpPong, []byte(tc.pongPayload)) - if err != nil { - return err - } - if tc.ping { - _, err := c.WriteFrame(ctx, true, websocket.OpPing, []byte("meow")) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpPong, []byte("meow")) - if err != nil { - return err - } - } - return c.Close(websocket.StatusNormalClosure, "") - }) - } - }) - run(t, "tenPings", func(ctx context.Context, c *websocket.Conn) error { - ctx = c.CloseRead(ctx) - - for i := 0; i < 10; i++ { - err := c.Ping(ctx) - if err != nil { - return err - } - } + return func() error { + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() - _, err := c.WriteClose(ctx, websocket.StatusNormalClosure, "") - if err != nil { - return err - } - <-ctx.Done() - - err = c.Ping(context.Background()) - return assertCloseStatus(err, websocket.StatusNormalClosure) - }) - - run(t, "tenStreamedPings", func(ctx context.Context, c *websocket.Conn) error { - for i := 0; i < 10; i++ { - err := assertStreamPing(ctx, c, 125) - if err != nil { - return err - } - } - - return c.Close(websocket.StatusNormalClosure, "") - }) - }) - - // Section 3. - // We skip the per octet sending as it will add too much complexity. - t.Run("reserved", func(t *testing.T) { - t.Parallel() - - var testCases = []struct { - name string - header websocket.Header - }{ - { - name: "rsv1", - header: websocket.Header{ - Fin: true, - Rsv1: true, - OpCode: websocket.OpClose, - PayloadLength: 0, - }, - }, - { - name: "rsv2", - header: websocket.Header{ - Fin: true, - Rsv2: true, - OpCode: websocket.OpPong, - PayloadLength: 0, - }, - }, - { - name: "rsv3", - header: websocket.Header{ - Fin: true, - Rsv3: true, - OpCode: websocket.OpBinary, - PayloadLength: 0, - }, - }, - { - name: "rsvAll", - header: websocket.Header{ - Fin: true, - Rsv1: true, - Rsv2: true, - Rsv3: true, - OpCode: websocket.OpText, - PayloadLength: 0, - }, - }, - } - for _, tc := range testCases { - tc := tc - run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error { - err := assertEcho(ctx, c, websocket.MessageText, 4096) - if err != nil { - return err - } - err = c.WriteHeader(ctx, tc.header) - if err != nil { - return err - } - err = c.Flush() - if err != nil { - return err - } - _, err = c.WriteFrame(ctx, true, websocket.OpPing, []byte("wtf")) - if err != nil { - return err - } - return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) - }) - } - }) - - // Section 4. - t.Run("opcodes", func(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - opcode websocket.OpCode - payload bool - echo bool - ping bool - }{ - // Section 1. - { - name: "3", - opcode: 3, - }, - { - name: "4", - opcode: 4, - payload: true, - }, - { - name: "5", - opcode: 5, - echo: true, - ping: true, - }, - { - name: "6", - opcode: 6, - payload: true, - echo: true, - ping: true, - }, - { - name: "7", - opcode: 7, - payload: true, - echo: true, - ping: true, - }, - - // Section 2. - { - name: "11", - opcode: 11, - }, - { - name: "12", - opcode: 12, - payload: true, - }, - { - name: "13", - opcode: 13, - payload: true, - echo: true, - ping: true, - }, - { - name: "14", - opcode: 14, - payload: true, - echo: true, - ping: true, - }, - { - name: "15", - opcode: 15, - payload: true, - echo: true, - ping: true, - }, - } - for _, tc := range testCases { - tc := tc - run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error { - if tc.echo { - err := assertEcho(ctx, c, websocket.MessageText, 4096) - if err != nil { - return err - } - } - - p := []byte(nil) - if tc.payload { - p = randBytes(rand.Intn(4096) + 1) - } - _, err := c.WriteFrame(ctx, true, tc.opcode, p) - if err != nil { - return err - } - if tc.ping { - _, err = c.WriteFrame(ctx, true, websocket.OpPing, []byte("wtf")) - if err != nil { - return err - } - } - return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) - }) - } - }) - - // Section 5. - t.Run("fragmentation", func(t *testing.T) { - t.Parallel() - - // 5.1 to 5.8 - testCases := []struct { - name string - opcode websocket.OpCode - success bool - pingInBetween bool - }{ - { - name: "ping", - opcode: websocket.OpPing, - success: false, - }, - { - name: "pong", - opcode: websocket.OpPong, - success: false, - }, - { - name: "text", - opcode: websocket.OpText, - success: true, - }, - { - name: "textPing", - opcode: websocket.OpText, - success: true, - pingInBetween: true, - }, - } - for _, tc := range testCases { - tc := tc - run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error { - p1 := randBytes(16) - _, err := c.WriteFrame(ctx, false, tc.opcode, p1) - if err != nil { - return err - } - err = c.BW().Flush() - if err != nil { - return err - } - if !tc.success { - _, _, err = c.Read(ctx) - return assertCloseStatus(err, websocket.StatusProtocolError) - } - - if tc.pingInBetween { - _, err = c.WriteFrame(ctx, true, websocket.OpPing, p1) - if err != nil { - return err - } - } - - p2 := randBytes(16) - _, err = c.WriteFrame(ctx, true, websocket.OpContinuation, p2) - if err != nil { - return err - } - - err = assertReadFrame(ctx, c, tc.opcode, p1) - if err != nil { - return err - } - - if tc.pingInBetween { - err = assertReadFrame(ctx, c, websocket.OpPong, p1) - if err != nil { - return err - } - } - - return assertReadFrame(ctx, c, websocket.OpContinuation, p2) - }) + err := s.Shutdown(ctx) + if err != nil { + return fmt.Errorf("server shutdown failed: %v", err) } - t.Run("unexpectedContinuation", func(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - fin bool - textFirst bool - }{ - { - name: "fin", - fin: true, - }, - { - name: "noFin", - fin: false, - }, - { - name: "echoFirst", - fin: false, - textFirst: true, - }, - // The rest of the tests in this section get complicated and do not inspire much confidence. - } - - for _, tc := range testCases { - tc := tc - run(t, tc.name, func(ctx context.Context, c *websocket.Conn) error { - if tc.textFirst { - w, err := c.Writer(ctx, websocket.MessageText) - if err != nil { - return err - } - p1 := randBytes(32) - _, err = w.Write(p1) - if err != nil { - return err - } - p2 := randBytes(32) - _, err = w.Write(p2) - if err != nil { - return err - } - err = w.Close() - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpText, p1) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpContinuation, p2) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpContinuation, []byte{}) - if err != nil { - return err - } - } - - _, err := c.WriteFrame(ctx, tc.fin, websocket.OpContinuation, randBytes(32)) - if err != nil { - return err - } - err = c.BW().Flush() - if err != nil { - return err - } - - return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) - }) - } - - run(t, "doubleText", func(ctx context.Context, c *websocket.Conn) error { - p1 := randBytes(32) - _, err := c.WriteFrame(ctx, false, websocket.OpText, p1) - if err != nil { - return err - } - _, err = c.WriteFrame(ctx, true, websocket.OpText, randBytes(32)) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpText, p1) - if err != nil { - return err - } - return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) - }) - - run(t, "5.19", func(ctx context.Context, c *websocket.Conn) error { - p1 := randBytes(32) - p2 := randBytes(32) - p3 := randBytes(32) - p4 := randBytes(32) - p5 := randBytes(32) - - _, err := c.WriteFrame(ctx, false, websocket.OpText, p1) - if err != nil { - return err - } - _, err = c.WriteFrame(ctx, false, websocket.OpContinuation, p2) - if err != nil { - return err - } - - _, err = c.WriteFrame(ctx, true, websocket.OpPing, p1) - if err != nil { - return err - } - - time.Sleep(time.Second) - - _, err = c.WriteFrame(ctx, false, websocket.OpContinuation, p3) - if err != nil { - return err - } - _, err = c.WriteFrame(ctx, false, websocket.OpContinuation, p4) - if err != nil { - return err - } - - _, err = c.WriteFrame(ctx, true, websocket.OpPing, p1) - if err != nil { - return err - } - - _, err = c.WriteFrame(ctx, true, websocket.OpContinuation, p5) - if err != nil { - return err - } - - err = assertReadFrame(ctx, c, websocket.OpText, p1) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpContinuation, p2) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpPong, p1) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpContinuation, p3) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpContinuation, p4) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpPong, p1) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpContinuation, p5) - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpContinuation, []byte{}) - if err != nil { - return err - } - return c.Close(websocket.StatusNormalClosure, "") - }) - }) - }) - - // Section 7 - t.Run("closeHandling", func(t *testing.T) { - t.Parallel() - - // 1.1 - 1.4 is useless. - run(t, "1.5", func(ctx context.Context, c *websocket.Conn) error { - p1 := randBytes(32) - _, err := c.WriteFrame(ctx, false, websocket.OpText, p1) - if err != nil { - return err - } - err = c.Flush() - if err != nil { - return err - } - _, err = c.WriteClose(ctx, websocket.StatusNormalClosure, "") - if err != nil { - return err - } - err = assertReadFrame(ctx, c, websocket.OpText, p1) - if err != nil { - return err - } - return assertReadCloseFrame(ctx, c, websocket.StatusNormalClosure) - }) - - run(t, "1.6", func(ctx context.Context, c *websocket.Conn) error { - // 262144 bytes. - p1 := randBytes(1 << 18) - err := c.Write(ctx, websocket.MessageText, p1) - if err != nil { - return err - } - _, err = c.WriteClose(ctx, websocket.StatusNormalClosure, "") - if err != nil { - return err - } - err = assertReadMessage(ctx, c, websocket.MessageText, p1) - if err != nil { - return err - } - return assertReadCloseFrame(ctx, c, websocket.StatusNormalClosure) - }) - - run(t, "emptyClose", func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, true, websocket.OpClose, nil) - if err != nil { - return err - } - return assertReadFrame(ctx, c, websocket.OpClose, []byte{}) - }) - - run(t, "badClose", func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, true, websocket.OpClose, []byte{1}) - if err != nil { - return err - } - return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) - }) - - run(t, "noReason", func(ctx context.Context, c *websocket.Conn) error { - return c.Close(websocket.StatusNormalClosure, "") - }) - - run(t, "simpleReason", func(ctx context.Context, c *websocket.Conn) error { - return c.Close(websocket.StatusNormalClosure, randString(16)) - }) - - run(t, "maxReason", func(ctx context.Context, c *websocket.Conn) error { - return c.Close(websocket.StatusNormalClosure, randString(123)) - }) - - run(t, "tooBigReason", func(ctx context.Context, c *websocket.Conn) error { - _, err := c.WriteFrame(ctx, true, websocket.OpClose, - append([]byte{0x03, 0xE8}, randString(124)...), - ) - if err != nil { - return err - } - return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) - }) - - t.Run("validCloses", func(t *testing.T) { - t.Parallel() - - codes := [...]websocket.StatusCode{ - 1000, - 1001, - 1002, - 1003, - 1007, - 1008, - 1009, - 1010, - 1011, - 3000, - 3999, - 4000, - 4999, - } - for _, code := range codes { - run(t, strconv.Itoa(int(code)), func(ctx context.Context, c *websocket.Conn) error { - return c.Close(code, randString(32)) - }) - } - }) - - t.Run("invalidCloseCodes", func(t *testing.T) { - t.Parallel() - - codes := []websocket.StatusCode{ - 0, - 999, - 1004, - 1005, - 1006, - 1016, - 1100, - 2000, - 2999, - 5000, - 65535, - } - for _, code := range codes { - run(t, strconv.Itoa(int(code)), func(ctx context.Context, c *websocket.Conn) error { - p := make([]byte, 2) - binary.BigEndian.PutUint16(p, uint16(code)) - p = append(p, randBytes(32)...) - _, err := c.WriteFrame(ctx, true, websocket.OpClose, p) - if err != nil { - return err - } - return assertReadCloseFrame(ctx, c, websocket.StatusProtocolError) - }) - } - }) - }) - - // Section 9. - t.Run("limits", func(t *testing.T) { - t.Parallel() - - t.Run("unfragmentedEcho", func(t *testing.T) { - t.Parallel() - - lengths := []int{ - 1 << 16, - 1 << 18, - // Anything higher is completely unnecessary. - } - - for _, l := range lengths { - l := l - run(t, strconv.Itoa(l), func(ctx context.Context, c *websocket.Conn) error { - return assertEcho(ctx, c, websocket.MessageBinary, l) - }) - } - }) - - t.Run("fragmentedEcho", func(t *testing.T) { - t.Parallel() - - fragments := []int{ - 64, - 256, - 1 << 10, - 1 << 12, - 1 << 14, - 1 << 16, - } - - for _, l := range fragments { - fragmentLength := l - run(t, strconv.Itoa(fragmentLength), func(ctx context.Context, c *websocket.Conn) error { - w, err := c.Writer(ctx, websocket.MessageText) - if err != nil { - return err - } - b := randBytes(1 << 16) - for i := 0; i < len(b); { - j := i + fragmentLength - if j > len(b) { - j = len(b) - } - - _, err = w.Write(b[i:j]) - if err != nil { - return err - } - - i = j - } - err = w.Close() - if err != nil { - return err - } - - err = assertReadMessage(ctx, c, websocket.MessageText, b) - if err != nil { - return err - } - return c.Close(websocket.StatusNormalClosure, "") - }) - } - }) - - t.Run("latencyEcho", func(t *testing.T) { - t.Parallel() - - lengths := []int{ - 0, - 16, - } - - for _, l := range lengths { - l := l - run(t, strconv.Itoa(l), func(ctx context.Context, c *websocket.Conn) error { - for i := 0; i < 1000; i++ { - err := assertEcho(ctx, c, websocket.MessageBinary, l) - if err != nil { - return err - } - } + t := time.NewTicker(time.Millisecond * 10) + defer t.Stop() + for { + select { + case <-t.C: + if atomic.LoadInt64(&conns) == 0 { return nil - }) - } - }) - }) -} - -func assertCloseStatus(err error, code websocket.StatusCode) error { - var cerr websocket.CloseError - if !errors.As(err, &cerr) { - return fmt.Errorf("no websocket close error in error chain: %+v", err) - } - return assert.Equalf(code, cerr.Code, "unexpected status code") -} - -func assertProtobufRead(ctx context.Context, c *websocket.Conn, exp interface{}) error { - expType := reflect.TypeOf(exp) - actv := reflect.New(expType.Elem()) - act := actv.Interface().(proto.Message) - err := wspb.Read(ctx, c, act) - if err != nil { - return err - } - - return assert.Equalf(exp, act, "unexpected protobuf") -} - -func assertNetConnRead(r io.Reader, exp string) error { - act := make([]byte, len(exp)) - _, err := r.Read(act) - if err != nil { - return err - } - return assert.Equalf(exp, string(act), "unexpected net conn read") -} - -func assertErrorContains(err error, exp string) error { - if err == nil || !strings.Contains(err.Error(), exp) { - return fmt.Errorf("expected error that contains %q but got: %+v", exp, err) - } - return nil -} - -func assertErrorIs(exp, act error) error { - if !errors.Is(act, exp) { - return fmt.Errorf("expected error %+v to be in %+v", exp, act) - } - return nil -} - -func assertReadFrame(ctx context.Context, c *websocket.Conn, opcode websocket.OpCode, p []byte) error { - actOpcode, actP, err := c.ReadFrame(ctx) - if err != nil { - return err - } - err = assert.Equalf(opcode, actOpcode, "unexpected frame opcode with payload %q", actP) - if err != nil { - return err - } - return assert.Equalf(p, actP, "unexpected frame %v payload", opcode) -} - -func assertReadCloseFrame(ctx context.Context, c *websocket.Conn, code websocket.StatusCode) error { - actOpcode, actP, err := c.ReadFrame(ctx) - if err != nil { - return err - } - err = assert.Equalf(websocket.OpClose, actOpcode, "unexpected frame opcode with payload %q", actP) - if err != nil { - return err - } - ce, err := websocket.ParseClosePayload(actP) - if err != nil { - return fmt.Errorf("failed to parse close frame payload: %w", err) - } - return assert.Equalf(ce.Code, code, "unexpected frame close frame code with payload %q", actP) -} - -func assertStreamPing(ctx context.Context, c *websocket.Conn, l int) error { - err := c.WriteHeader(ctx, websocket.Header{ - Fin: true, - OpCode: websocket.OpPing, - PayloadLength: int64(l), - }) - if err != nil { - return err - } - for i := 0; i < l; i++ { - err = c.BW().WriteByte(0xFE) - if err != nil { - return fmt.Errorf("failed to write byte %d: %w", i, err) - } - if i%32 == 0 { - err = c.BW().Flush() - if err != nil { - return fmt.Errorf("failed to flush at byte %d: %w", i, err) + } + case <-ctx.Done(): + return fmt.Errorf("failed to wait for WebSocket connections: %v", ctx.Err()) } } } - err = c.BW().Flush() - if err != nil { - return fmt.Errorf("failed to flush: %v", err) - } - return assertReadFrame(ctx, c, websocket.OpPong, bytes.Repeat([]byte{0xFE}, l)) -} - -func assertReadMessage(ctx context.Context, c *websocket.Conn, typ websocket.MessageType, p []byte) error { - actTyp, actP, err := c.Read(ctx) - if err != nil { - return err - } - err = assert.Equalf(websocket.MessageText, actTyp, "unexpected frame opcode with payload %q", actP) - if err != nil { - return err - } - return assert.Equalf(p, actP, "unexpected frame %v payload", actTyp) -} - -func BenchmarkConn(b *testing.B) { - sizes := []int{ - 2, - 16, - 32, - 512, - 4096, - 16384, - } - - b.Run("write", func(b *testing.B) { - for _, size := range sizes { - b.Run(strconv.Itoa(size), func(b *testing.B) { - b.Run("stream", func(b *testing.B) { - benchConn(b, false, true, size) - }) - b.Run("buffer", func(b *testing.B) { - benchConn(b, false, false, size) - }) - }) - } - }) - - b.Run("echo", func(b *testing.B) { - for _, size := range sizes { - b.Run(strconv.Itoa(size), func(b *testing.B) { - benchConn(b, true, true, size) - }) - } - }) } -func benchConn(b *testing.B, echo, stream bool, size int) { - s, closeFn := testServer(b, func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, nil) - if err != nil { - return err - } - if echo { - wsecho.Loop(r.Context(), c) - } else { - discardLoop(r.Context(), c) - } - return nil - }, false) - defer closeFn() - - wsURL := strings.Replace(s.URL, "http", "ws", 1) - - ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) - defer cancel() - - c, _, err := websocket.Dial(ctx, wsURL, nil) - if err != nil { - b.Fatal(err) - } +// echoLoop echos every msg received from c until an error +// occurs or the context expires. +// The read limit is set to 1 << 30. +func echoLoop(ctx context.Context, c *websocket.Conn) error { defer c.Close(websocket.StatusInternalError, "") - msg := []byte(strings.Repeat("2", size)) - readBuf := make([]byte, len(msg)) - b.SetBytes(int64(len(msg))) - b.ReportAllocs() - b.ResetTimer() - for i := 0; i < b.N; i++ { - if stream { - w, err := c.Writer(ctx, websocket.MessageText) - if err != nil { - b.Fatal(err) - } - - _, err = w.Write(msg) - if err != nil { - b.Fatal(err) - } - - err = w.Close() - if err != nil { - b.Fatal(err) - } - } else { - err = c.Write(ctx, websocket.MessageText, msg) - if err != nil { - b.Fatal(err) - } - } - - if echo { - _, r, err := c.Reader(ctx) - if err != nil { - b.Fatal(err) - } - - _, err = io.ReadFull(r, readBuf) - if err != nil { - b.Fatal(err) - } - } - } - b.StopTimer() - - c.Close(websocket.StatusNormalClosure, "") -} - -func discardLoop(ctx context.Context, c *websocket.Conn) { - defer c.Close(websocket.StatusInternalError, "") + c.SetReadLimit(1 << 30) ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() - b := make([]byte, 32768) - echo := func() error { - _, r, err := c.Reader(ctx) + b := make([]byte, 32<<10) + for { + typ, r, err := c.Reader(ctx) if err != nil { return err } - _, err = io.CopyBuffer(ioutil.Discard, r, b) + w, err := c.Writer(ctx, typ) if err != nil { return err } - return nil - } - for { - err := echo() + _, err = io.CopyBuffer(w, r, b) if err != nil { - return + return err } - } -} - -func TestAutobahnPython(t *testing.T) { - // This test contains the old autobahn test suite tests that use the - // python binary. The approach is clunky and slow so new tests - // have been written in pure Go in websocket_test.go. - // These have been kept for correctness purposes and are occasionally ran. - if os.Getenv("AUTOBAHN_PYTHON") == "" { - t.Skip("Set $AUTOBAHN_PYTHON to run tests against the python autobahn test suite") - } - - t.Run("server", testServerAutobahnPython) - t.Run("client", testClientAutobahnPython) -} - -// https://github.com/crossbario/autobahn-python/tree/master/wstest -func testServerAutobahnPython(t *testing.T) { - t.Parallel() - s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"echo"}, - }) + err = w.Close() if err != nil { - t.Logf("server handshake failed: %+v", err) - return + return err } - wsecho.Loop(r.Context(), c) - })) - defer s.Close() - - spec := map[string]interface{}{ - "outdir": "ci/out/wstestServerReports", - "servers": []interface{}{ - map[string]interface{}{ - "agent": "main", - "url": strings.Replace(s.URL, "http", "ws", 1), - }, - }, - "cases": []string{"*"}, - // We skip the UTF-8 handling tests as there isn't any reason to reject invalid UTF-8, just - // more performance overhead. 7.5.1 is the same. - // 12.* and 13.* as we do not support compression. - "exclude-cases": []string{"6.*", "7.5.1", "12.*", "13.*"}, - } - specFile, err := ioutil.TempFile("", "websocketFuzzingClient.json") - if err != nil { - t.Fatalf("failed to create temp file for fuzzingclient.json: %v", err) - } - defer specFile.Close() - - e := json.NewEncoder(specFile) - e.SetIndent("", "\t") - err = e.Encode(spec) - if err != nil { - t.Fatalf("failed to write spec: %v", err) - } - - err = specFile.Close() - if err != nil { - t.Fatalf("failed to close file: %v", err) - } - - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, time.Minute*10) - defer cancel() - - args := []string{"--mode", "fuzzingclient", "--spec", specFile.Name()} - wstest := exec.CommandContext(ctx, "wstest", args...) - out, err := wstest.CombinedOutput() - if err != nil { - t.Fatalf("failed to run wstest: %v\nout:\n%s", err, out) } - - checkWSTestIndex(t, "./ci/out/wstestServerReports/index.json") } -func unusedListenAddr() (string, error) { - l, err := net.Listen("tcp", "localhost:0") - if err != nil { - return "", err - } - l.Close() - return l.Addr().String(), nil -} - -// https://github.com/crossbario/autobahn-python/blob/master/wstest/testee_client_aio.py -func testClientAutobahnPython(t *testing.T) { +func TestConn(t *testing.T) { t.Parallel() - if os.Getenv("AUTOBAHN_PYTHON") == "" { - t.Skip("Set $AUTOBAHN_PYTHON to test against the python autobahn test suite") - } - - serverAddr, err := unusedListenAddr() - if err != nil { - t.Fatalf("failed to get unused listen addr for wstest: %v", err) - } - - wsServerURL := "ws://" + serverAddr - - spec := map[string]interface{}{ - "url": wsServerURL, - "outdir": "ci/out/wstestClientReports", - "cases": []string{"*"}, - // See TestAutobahnServer for the reasons why we exclude these. - "exclude-cases": []string{"6.*", "7.5.1", "12.*", "13.*"}, - } - specFile, err := ioutil.TempFile("", "websocketFuzzingServer.json") - if err != nil { - t.Fatalf("failed to create temp file for fuzzingserver.json: %v", err) - } - defer specFile.Close() - - e := json.NewEncoder(specFile) - e.SetIndent("", "\t") - err = e.Encode(spec) - if err != nil { - t.Fatalf("failed to write spec: %v", err) - } - - err = specFile.Close() - if err != nil { - t.Fatalf("failed to close file: %v", err) - } - - ctx := context.Background() - ctx, cancel := context.WithTimeout(ctx, time.Minute*10) - defer cancel() - - args := []string{"--mode", "fuzzingserver", "--spec", specFile.Name(), - // Disables some server that runs as part of fuzzingserver mode. - // See https://github.com/crossbario/autobahn-testsuite/blob/058db3a36b7c3a1edf68c282307c6b899ca4857f/autobahntestsuite/autobahntestsuite/wstest.py#L124 - "--webport=0", - } - wstest := exec.CommandContext(ctx, "wstest", args...) - err = wstest.Start() - if err != nil { - t.Fatal(err) - } - defer func() { - err := wstest.Process.Kill() - if err != nil { - t.Error(err) - } - }() - - // Let it come up. - time.Sleep(time.Second * 5) - - var cases int - func() { - c, _, err := websocket.Dial(ctx, wsServerURL+"/getCaseCount", nil) - if err != nil { - t.Fatal(err) - } - defer c.Close(websocket.StatusInternalError, "") - - _, r, err := c.Reader(ctx) - if err != nil { - t.Fatal(err) - } - b, err := ioutil.ReadAll(r) - if err != nil { - t.Fatal(err) - } - cases, err = strconv.Atoi(string(b)) - if err != nil { - t.Fatal(err) - } - - c.Close(websocket.StatusNormalClosure, "") - }() - - for i := 1; i <= cases; i++ { - func() { - ctx, cancel := context.WithTimeout(ctx, time.Second*45) - defer cancel() - - c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/runCase?case=%v&agent=main", i), nil) - if err != nil { - t.Fatal(err) - } - wsecho.Loop(ctx, c) - }() - } - - c, _, err := websocket.Dial(ctx, fmt.Sprintf(wsServerURL+"/updateReports?agent=main"), nil) - if err != nil { - t.Fatal(err) - } - c.Close(websocket.StatusNormalClosure, "") + t.Run("json", func(t *testing.T) { + s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ + Subprotocols: []string{"echo"}, + InsecureSkipVerify: true, + }) + assert.Success(t, err) + defer c.Close(websocket.StatusInternalError, "") - checkWSTestIndex(t, "./ci/out/wstestClientReports/index.json") -} + err = echoLoop(r.Context(), c) + assertCloseStatus(t, websocket.StatusNormalClosure, err) + }, false) + defer closeFn() -func checkWSTestIndex(t *testing.T, path string) { - wstestOut, err := ioutil.ReadFile(path) - if err != nil { - t.Fatalf("failed to read index.json: %v", err) - } + wsURL := strings.Replace(s.URL, "http", "ws", 1) - var indexJSON map[string]map[string]struct { - Behavior string `json:"behavior"` - BehaviorClose string `json:"behaviorClose"` - } - err = json.Unmarshal(wstestOut, &indexJSON) - if err != nil { - t.Fatalf("failed to unmarshal index.json: %v", err) - } + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() - var failed bool - for _, tests := range indexJSON { - for test, result := range tests { - switch result.Behavior { - case "OK", "NON-STRICT", "INFORMATIONAL": - default: - failed = true - t.Errorf("test %v failed", test) - } - switch result.BehaviorClose { - case "OK", "INFORMATIONAL": - default: - failed = true - t.Errorf("bad close behaviour for test %v", test) - } - } - } - - if failed { - path = strings.Replace(path, ".json", ".html", 1) - if os.Getenv("CI") == "" { - t.Errorf("wstest found failure, see %q (output as an artifact in CI)", path) - } - } -} - -func TestWASM(t *testing.T) { - t.Parallel() - - s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"echo"}, - InsecureSkipVerify: true, - }) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - - err = wsecho.Loop(r.Context(), c) - if websocket.CloseStatus(err) != websocket.StatusNormalClosure { - return err + opts := &websocket.DialOptions{ + Subprotocols: []string{"echo"}, } - return nil - }, false) - defer closeFn() + opts.HTTPClient = s.Client() - wsURL := strings.Replace(s.URL, "http", "ws", 1) - - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() + c, _, err := websocket.Dial(ctx, wsURL, opts) + assert.Success(t, err) - cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", "./...") - cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", wsURL)) - - b, err := cmd.CombinedOutput() - if err != nil { - t.Fatalf("wasm test binary failed: %v:\n%s", err, b) - } + assertJSONEcho(t, ctx, c, 2) + }) } diff --git a/dial.go b/dial.go index 1008868..8fa0f7a 100644 --- a/dial.go +++ b/dial.go @@ -1,17 +1,19 @@ package websocket import ( + "bufio" "bytes" "context" "crypto/rand" "encoding/base64" + "errors" "fmt" "io" "io/ioutil" "net/http" "net/url" - "nhooyr.io/websocket/internal/bufpool" "strings" + "sync" ) // DialOptions represents the options available to pass to Dial. @@ -50,7 +52,7 @@ func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Respon return c, r, nil } -func (opts *DialOptions) fill() (*DialOptions, error) { +func (opts *DialOptions) ensure() *DialOptions { if opts == nil { opts = &DialOptions{} } else { @@ -60,20 +62,18 @@ func (opts *DialOptions) fill() (*DialOptions, error) { if opts.HTTPClient == nil { opts.HTTPClient = http.DefaultClient } - if opts.HTTPClient.Timeout > 0 { - return nil, fmt.Errorf("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67") - } if opts.HTTPHeader == nil { opts.HTTPHeader = http.Header{} } - return opts, nil + return opts } func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Response, err error) { - opts, err = opts.fill() - if err != nil { - return nil, nil, err + opts = opts.ensure() + + if opts.HTTPClient.Timeout > 0 { + return nil, nil, errors.New("use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67") } parsedURL, err := url.Parse(u) @@ -104,8 +104,10 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re if len(opts.Subprotocols) > 0 { req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) } - copts := opts.CompressionMode.opts() - copts.setHeader(req.Header) + if opts.CompressionMode != CompressionDisabled { + copts := opts.CompressionMode.opts() + copts.setHeader(req.Header) + } resp, err := opts.HTTPClient.Do(req) if err != nil { @@ -121,7 +123,7 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re } }() - copts, err = verifyServerResponse(req, resp, opts) + copts, err := verifyServerResponse(req, resp) if err != nil { return nil, resp, err } @@ -131,18 +133,14 @@ func dial(ctx context.Context, u string, opts *DialOptions) (_ *Conn, _ *http.Re return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", rwc) } - c := &Conn{ + return newConn(connConfig{ subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"), - br: bufpool.GetReader(rwc), - bw: bufpool.GetWriter(rwc), - closer: rwc, + rwc: rwc, client: true, copts: copts, - } - c.extractBufioWriterBuf(rwc) - c.init() - - return c, resp, nil + br: getBufioReader(rwc), + bw: getBufioWriter(rwc), + }), resp, nil } func secWebSocketKey() (string, error) { @@ -154,7 +152,7 @@ func secWebSocketKey() (string, error) { return base64.StdEncoding.EncodeToString(b), nil } -func verifyServerResponse(r *http.Request, resp *http.Response, opts *DialOptions) (*compressionOptions, error) { +func verifyServerResponse(r *http.Request, resp *http.Response) (*compressionOptions, error) { if resp.StatusCode != http.StatusSwitchingProtocols { return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) } @@ -178,7 +176,7 @@ func verifyServerResponse(r *http.Request, resp *http.Response, opts *DialOption return nil, fmt.Errorf("websocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) } - copts, err := verifyServerExtensions(resp.Header, opts.CompressionMode) + copts, err := verifyServerExtensions(resp.Header) if err != nil { return nil, err } @@ -186,7 +184,7 @@ func verifyServerResponse(r *http.Request, resp *http.Response, opts *DialOption return copts, nil } -func verifyServerExtensions(h http.Header, mode CompressionMode) (*compressionOptions, error) { +func verifyServerExtensions(h http.Header) (*compressionOptions, error) { exts := websocketExtensions(h) if len(exts) == 0 { return nil, nil @@ -201,7 +199,7 @@ func verifyServerExtensions(h http.Header, mode CompressionMode) (*compressionOp return nil, fmt.Errorf("unexpected extra extensions from server: %+v", exts[1:]) } - copts := mode.opts() + copts := &compressionOptions{} for _, p := range ext.params { switch p { case "client_no_context_takeover": @@ -217,3 +215,33 @@ func verifyServerExtensions(h http.Header, mode CompressionMode) (*compressionOp return copts, nil } + +var readerPool sync.Pool + +func getBufioReader(r io.Reader) *bufio.Reader { + br, ok := readerPool.Get().(*bufio.Reader) + if !ok { + return bufio.NewReader(r) + } + br.Reset(r) + return br +} + +func putBufioReader(br *bufio.Reader) { + readerPool.Put(br) +} + +var writerPool sync.Pool + +func getBufioWriter(w io.Writer) *bufio.Writer { + bw, ok := writerPool.Get().(*bufio.Writer) + if !ok { + return bufio.NewWriter(w) + } + bw.Reset(w) + return bw +} + +func putBufioWriter(bw *bufio.Writer) { + writerPool.Put(bw) +} diff --git a/dial_test.go b/dial_test.go index 391aa1c..5eeb904 100644 --- a/dial_test.go +++ b/dial_test.go @@ -140,7 +140,7 @@ func Test_verifyServerHandshake(t *testing.T) { resp.Header.Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) } - _, err = verifyServerResponse(r, resp, &DialOptions{}) + _, err = verifyServerResponse(r, resp) if (err == nil) != tc.success { t.Fatalf("unexpected error: %+v", err) } diff --git a/example_echo_test.go b/example_echo_test.go index ecc9b97..16d003d 100644 --- a/example_echo_test.go +++ b/example_echo_test.go @@ -4,6 +4,7 @@ package websocket_test import ( "context" + "errors" "fmt" "io" "log" @@ -77,7 +78,7 @@ func echoServer(w http.ResponseWriter, r *http.Request) error { if c.Subprotocol() != "echo" { c.Close(websocket.StatusPolicyViolation, "client must speak the echo subprotocol") - return fmt.Errorf("client does not speak echo sub protocol") + return errors.New("client does not speak echo sub protocol") } l := rate.NewLimiter(rate.Every(time.Millisecond*100), 10) diff --git a/internal/wsframe/mask.go b/frame.go similarity index 57% rename from internal/wsframe/mask.go rename to frame.go index 2da4c11..0f10d55 100644 --- a/internal/wsframe/mask.go +++ b/frame.go @@ -1,11 +1,167 @@ -package wsframe +package websocket import ( + "bufio" "encoding/binary" + "math" "math/bits" + "nhooyr.io/websocket/internal/errd" ) -// Mask applies the WebSocket masking algorithm to p +// opcode represents a WebSocket opcode. +type opcode int + +// List at https://tools.ietf.org/html/rfc6455#section-11.8. +const ( + opContinuation opcode = iota + opText + opBinary + // 3 - 7 are reserved for further non-control frames. + _ + _ + _ + _ + _ + opClose + opPing + opPong + // 11-16 are reserved for further control frames. +) + +// header represents a WebSocket frame header. +// See https://tools.ietf.org/html/rfc6455#section-5.2. +type header struct { + fin bool + rsv1 bool + rsv2 bool + rsv3 bool + opcode opcode + + payloadLength int64 + + masked bool + maskKey uint32 +} + +// readFrameHeader reads a header from the reader. +// See https://tools.ietf.org/html/rfc6455#section-5.2. +func readFrameHeader(r *bufio.Reader) (_ header, err error) { + defer errd.Wrap(&err, "failed to read frame header") + + b, err := r.ReadByte() + if err != nil { + return header{}, err + } + + var h header + h.fin = b&(1<<7) != 0 + h.rsv1 = b&(1<<6) != 0 + h.rsv2 = b&(1<<5) != 0 + h.rsv3 = b&(1<<4) != 0 + + h.opcode = opcode(b & 0xf) + + b, err = r.ReadByte() + if err != nil { + return header{}, err + } + + h.masked = b&(1<<7) != 0 + + payloadLength := b &^ (1 << 7) + switch { + case payloadLength < 126: + h.payloadLength = int64(payloadLength) + case payloadLength == 126: + var pl uint16 + err = binary.Read(r, binary.BigEndian, &pl) + h.payloadLength = int64(pl) + case payloadLength == 127: + err = binary.Read(r, binary.BigEndian, &h.payloadLength) + } + if err != nil { + return header{}, err + } + + if h.masked { + err = binary.Read(r, binary.LittleEndian, &h.maskKey) + if err != nil { + return header{}, err + } + } + + return h, nil +} + +// maxControlPayload is the maximum length of a control frame payload. +// See https://tools.ietf.org/html/rfc6455#section-5.5. +const maxControlPayload = 125 + +// writeFrameHeader writes the bytes of the header to w. +// See https://tools.ietf.org/html/rfc6455#section-5.2 +func writeFrameHeader(h header, w *bufio.Writer) (err error) { + defer errd.Wrap(&err, "failed to write frame header") + + var b byte + if h.fin { + b |= 1 << 7 + } + if h.rsv1 { + b |= 1 << 6 + } + if h.rsv2 { + b |= 1 << 5 + } + if h.rsv3 { + b |= 1 << 4 + } + + b |= byte(h.opcode) + + err = w.WriteByte(b) + if err != nil { + return err + } + + lengthByte := byte(0) + if h.masked { + lengthByte |= 1 << 7 + } + + switch { + case h.payloadLength > math.MaxUint16: + lengthByte |= 127 + case h.payloadLength > 125: + lengthByte |= 126 + case h.payloadLength >= 0: + lengthByte |= byte(h.payloadLength) + } + err = w.WriteByte(lengthByte) + if err != nil { + return err + } + + switch { + case h.payloadLength > math.MaxUint16: + err = binary.Write(w, binary.BigEndian, h.payloadLength) + case h.payloadLength > 125: + err = binary.Write(w, binary.BigEndian, uint16(h.payloadLength)) + } + if err != nil { + return err + } + + if h.masked { + err = binary.Write(w, binary.LittleEndian, h.maskKey) + if err != nil { + return err + } + } + + return nil +} + +// mask applies the WebSocket masking algorithm to p // with the given key. // See https://tools.ietf.org/html/rfc6455#section-5.3 // @@ -16,7 +172,7 @@ import ( // to be in little endian. // // See https://github.com/golang/go/issues/31586 -func Mask(key uint32, b []byte) uint32 { +func mask(key uint32, b []byte) uint32 { if len(b) >= 8 { key64 := uint64(key)<<32 | uint64(key) diff --git a/internal/wsframe/mask_test.go b/frame_test.go similarity index 51% rename from internal/wsframe/mask_test.go rename to frame_test.go index fbd2989..0ed14ae 100644 --- a/internal/wsframe/mask_test.go +++ b/frame_test.go @@ -1,32 +1,108 @@ -package wsframe_test +// +build !js + +package websocket import ( - "crypto/rand" + "bufio" + "bytes" "encoding/binary" - "github.com/gobwas/ws" - "github.com/google/go-cmp/cmp" "math/bits" - "nhooyr.io/websocket/internal/wsframe" + "nhooyr.io/websocket/internal/assert" "strconv" "testing" + "time" _ "unsafe" + + "github.com/gobwas/ws" + _ "github.com/gorilla/websocket" + "math/rand" ) +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func TestHeader(t *testing.T) { + t.Parallel() + + t.Run("lengths", func(t *testing.T) { + t.Parallel() + + lengths := []int{ + 124, + 125, + 126, + 127, + + 65534, + 65535, + 65536, + 65537, + } + + for _, n := range lengths { + n := n + t.Run(strconv.Itoa(n), func(t *testing.T) { + t.Parallel() + + testHeader(t, header{ + payloadLength: int64(n), + }) + }) + } + }) + + t.Run("fuzz", func(t *testing.T) { + t.Parallel() + + randBool := func() bool { + return rand.Intn(1) == 0 + } + + for i := 0; i < 10000; i++ { + h := header{ + fin: randBool(), + rsv1: randBool(), + rsv2: randBool(), + rsv3: randBool(), + opcode: opcode(rand.Intn(16)), + + masked: randBool(), + maskKey: rand.Uint32(), + payloadLength: rand.Int63(), + } + + testHeader(t, h) + } + }) +} + +func testHeader(t *testing.T, h header) { + b := &bytes.Buffer{} + w := bufio.NewWriter(b) + r := bufio.NewReader(b) + + err := writeFrameHeader(h, w) + assert.Success(t, err) + err = w.Flush() + assert.Success(t, err) + + h2, err := readFrameHeader(r) + assert.Success(t, err) + + assert.Equalf(t, h, h2, "written and read headers differ") +} + func Test_mask(t *testing.T) { t.Parallel() key := []byte{0xa, 0xb, 0xc, 0xff} key32 := binary.LittleEndian.Uint32(key) p := []byte{0xa, 0xb, 0xc, 0xf2, 0xc} - gotKey32 := wsframe.Mask(key32, p) + gotKey32 := mask(key32, p) - if exp := []byte{0, 0, 0, 0x0d, 0x6}; !cmp.Equal(exp, p) { - t.Fatalf("unexpected mask: %v", cmp.Diff(exp, p)) - } - - if exp := bits.RotateLeft32(key32, -8); !cmp.Equal(exp, gotKey32) { - t.Fatalf("unexpected mask key: %v", cmp.Diff(exp, gotKey32)) - } + assert.Equalf(t, []byte{0, 0, 0, 0x0d, 0x6}, p, "unexpected mask") + assert.Equalf(t, bits.RotateLeft32(key32, -8), gotKey32, "unexpected mask key") } func basicMask(maskKey [4]byte, pos int, b []byte) int { @@ -74,7 +150,7 @@ func Benchmark_mask(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - wsframe.Mask(key32, p) + mask(key32, p) } }, }, @@ -98,9 +174,7 @@ func Benchmark_mask(b *testing.B) { var key [4]byte _, err := rand.Read(key[:]) - if err != nil { - b.Fatalf("failed to populate mask key: %v", err) - } + assert.Success(b, err) for _, size := range sizes { p := make([]byte, size) diff --git a/internal/assert/assert.go b/internal/assert/assert.go index 372d546..1d9aece 100644 --- a/internal/assert/assert.go +++ b/internal/assert/assert.go @@ -2,6 +2,7 @@ package assert import ( "reflect" + "strings" "testing" "github.com/google/go-cmp/cmp" @@ -53,7 +54,7 @@ func structTypes(v reflect.Value, m map[reflect.Type]struct{}) { } } -func Equalf(t *testing.T, exp, act interface{}, f string, v ...interface{}) { +func Equalf(t testing.TB, exp, act interface{}, f string, v ...interface{}) { t.Helper() diff := cmpDiff(exp, act) if diff != "" { @@ -61,7 +62,40 @@ func Equalf(t *testing.T, exp, act interface{}, f string, v ...interface{}) { } } -func Success(t *testing.T, err error) { +func NotEqualf(t testing.TB, exp, act interface{}, f string, v ...interface{}) { t.Helper() - Equalf(t, error(nil), err, "unexpected failure") + diff := cmpDiff(exp, act) + if diff == "" { + t.Fatalf(f+": %v", append(v, diff)...) + } +} + +func Success(t testing.TB, err error) { + t.Helper() + if err != nil { + t.Fatalf("unexpected error: %+v", err) + } +} + +func Error(t testing.TB, err error) { + t.Helper() + if err == nil { + t.Fatal("expected error") + } +} + +func ErrorContains(t testing.TB, err error, sub string) { + t.Helper() + Error(t, err) + errs := err.Error() + if !strings.Contains(errs, sub) { + t.Fatalf("error string %q does not contain %q", errs, sub) + } +} + +func Panicf(t testing.TB, f string, v ...interface{}) { + r := recover() + if r == nil { + t.Fatalf(f, v...) + } } diff --git a/internal/atomicint/atomicint.go b/internal/atomicint/atomicint.go deleted file mode 100644 index 668b3b4..0000000 --- a/internal/atomicint/atomicint.go +++ /dev/null @@ -1,32 +0,0 @@ -package atomicint - -import ( - "fmt" - "sync/atomic" -) - -// See https://github.com/nhooyr/websocket/issues/153 -type Int64 struct { - v int64 -} - -func (v *Int64) Load() int64 { - return atomic.LoadInt64(&v.v) -} - -func (v *Int64) Store(i int64) { - atomic.StoreInt64(&v.v, i) -} - -func (v *Int64) String() string { - return fmt.Sprint(v.Load()) -} - -// Increment increments the value and returns the new value. -func (v *Int64) Increment(delta int64) int64 { - return atomic.AddInt64(&v.v, delta) -} - -func (v *Int64) CAS(old, new int64) (swapped bool) { - return atomic.CompareAndSwapInt64(&v.v, old, new) -} diff --git a/internal/bufpool/buf.go b/internal/bufpool/buf.go index 324a17e..0f7d976 100644 --- a/internal/bufpool/buf.go +++ b/internal/bufpool/buf.go @@ -5,12 +5,12 @@ import ( "sync" ) -var bpool sync.Pool +var pool sync.Pool // Get returns a buffer from the pool or creates a new one if // the pool is empty. func Get() *bytes.Buffer { - b, ok := bpool.Get().(*bytes.Buffer) + b, ok := pool.Get().(*bytes.Buffer) if !ok { b = &bytes.Buffer{} } @@ -20,5 +20,5 @@ func Get() *bytes.Buffer { // Put returns a buffer into the pool. func Put(b *bytes.Buffer) { b.Reset() - bpool.Put(b) + pool.Put(b) } diff --git a/internal/bufpool/bufio.go b/internal/bufpool/bufio.go deleted file mode 100644 index 875bbf4..0000000 --- a/internal/bufpool/bufio.go +++ /dev/null @@ -1,40 +0,0 @@ -package bufpool - -import ( - "bufio" - "io" - "sync" -) - -var readerPool = sync.Pool{ - New: func() interface{} { - return bufio.NewReader(nil) - }, -} - -func GetReader(r io.Reader) *bufio.Reader { - br := readerPool.Get().(*bufio.Reader) - br.Reset(r) - return br -} - -func PutReader(br *bufio.Reader) { - readerPool.Put(br) -} - -var writerPool = sync.Pool{ - New: func() interface{} { - return bufio.NewWriter(nil) - }, -} - -func GetWriter(w io.Writer) *bufio.Writer { - bw := writerPool.Get().(*bufio.Writer) - bw.Reset(w) - return bw -} - -func PutWriter(bw *bufio.Writer) { - writerPool.Put(bw) -} - diff --git a/internal/errd/errd.go b/internal/errd/errd.go new file mode 100644 index 0000000..51b7b4f --- /dev/null +++ b/internal/errd/errd.go @@ -0,0 +1,11 @@ +package errd + +import ( + "fmt" +) + +func Wrap(err *error, f string, v ...interface{}) { + if *err != nil { + *err = fmt.Errorf(f+ ": %w", append(v, *err)...) + } +} diff --git a/internal/wsecho/wsecho.go b/internal/wsecho/wsecho.go deleted file mode 100644 index c408f07..0000000 --- a/internal/wsecho/wsecho.go +++ /dev/null @@ -1,55 +0,0 @@ -// +build !js - -package wsecho - -import ( - "context" - "io" - "time" - - "nhooyr.io/websocket" -) - -// Loop echos every msg received from c until an error -// occurs or the context expires. -// The read limit is set to 1 << 30. -func Loop(ctx context.Context, c *websocket.Conn) error { - defer c.Close(websocket.StatusInternalError, "") - - c.SetReadLimit(1 << 30) - - ctx, cancel := context.WithTimeout(ctx, time.Minute) - defer cancel() - - b := make([]byte, 32<<10) - echo := func() error { - typ, r, err := c.Reader(ctx) - if err != nil { - return err - } - - w, err := c.Writer(ctx, typ) - if err != nil { - return err - } - - _, err = io.CopyBuffer(w, r, b) - if err != nil { - return err - } - - err = w.Close() - if err != nil { - return err - } - - return nil - } - - for { - err := echo() - if err != nil { - return err - } - } -} diff --git a/internal/wsframe/frame.go b/internal/wsframe/frame.go deleted file mode 100644 index 50ff8c1..0000000 --- a/internal/wsframe/frame.go +++ /dev/null @@ -1,194 +0,0 @@ -package wsframe - -import ( - "encoding/binary" - "fmt" - "io" - "math" -) - -// Opcode represents a WebSocket Opcode. -type Opcode int - -// Opcode constants. -const ( - OpContinuation Opcode = iota - OpText - OpBinary - // 3 - 7 are reserved for further non-control frames. - _ - _ - _ - _ - _ - OpClose - OpPing - OpPong - // 11-16 are reserved for further control frames. -) - -func (o Opcode) Control() bool { - switch o { - case OpClose, OpPing, OpPong: - return true - } - return false -} - -func (o Opcode) Data() bool { - switch o { - case OpText, OpBinary: - return true - } - return false -} - -// First byte contains fin, rsv1, rsv2, rsv3. -// Second byte contains mask flag and payload len. -// Next 8 bytes are the maximum extended payload length. -// Last 4 bytes are the mask key. -// https://tools.ietf.org/html/rfc6455#section-5.2 -const maxHeaderSize = 1 + 1 + 8 + 4 - -// Header represents a WebSocket frame Header. -// See https://tools.ietf.org/html/rfc6455#section-5.2 -type Header struct { - Fin bool - RSV1 bool - RSV2 bool - RSV3 bool - Opcode Opcode - - PayloadLength int64 - - Masked bool - MaskKey uint32 -} - -// bytes returns the bytes of the Header. -// See https://tools.ietf.org/html/rfc6455#section-5.2 -func (h Header) Bytes(b []byte) []byte { - if b == nil { - b = make([]byte, maxHeaderSize) - } - - b = b[:2] - b[0] = 0 - - if h.Fin { - b[0] |= 1 << 7 - } - if h.RSV1 { - b[0] |= 1 << 6 - } - if h.RSV2 { - b[0] |= 1 << 5 - } - if h.RSV3 { - b[0] |= 1 << 4 - } - - b[0] |= byte(h.Opcode) - - switch { - case h.PayloadLength < 0: - panic(fmt.Sprintf("websocket: invalid Header: negative length: %v", h.PayloadLength)) - case h.PayloadLength <= 125: - b[1] = byte(h.PayloadLength) - case h.PayloadLength <= math.MaxUint16: - b[1] = 126 - b = b[:len(b)+2] - binary.BigEndian.PutUint16(b[len(b)-2:], uint16(h.PayloadLength)) - default: - b[1] = 127 - b = b[:len(b)+8] - binary.BigEndian.PutUint64(b[len(b)-8:], uint64(h.PayloadLength)) - } - - if h.Masked { - b[1] |= 1 << 7 - b = b[:len(b)+4] - binary.LittleEndian.PutUint32(b[len(b)-4:], h.MaskKey) - } - - return b -} - -func MakeReadHeaderBuf() []byte { - return make([]byte, maxHeaderSize-2) -} - -// ReadHeader reads a Header from the reader. -// See https://tools.ietf.org/html/rfc6455#section-5.2 -func ReadHeader(r io.Reader, b []byte) (Header, error) { - // We read the first two bytes first so that we know - // exactly how long the Header is. - b = b[:2] - _, err := io.ReadFull(r, b) - if err != nil { - return Header{}, err - } - - var h Header - h.Fin = b[0]&(1<<7) != 0 - h.RSV1 = b[0]&(1<<6) != 0 - h.RSV2 = b[0]&(1<<5) != 0 - h.RSV3 = b[0]&(1<<4) != 0 - - h.Opcode = Opcode(b[0] & 0xf) - - var extra int - - h.Masked = b[1]&(1<<7) != 0 - if h.Masked { - extra += 4 - } - - payloadLength := b[1] &^ (1 << 7) - switch { - case payloadLength < 126: - h.PayloadLength = int64(payloadLength) - case payloadLength == 126: - extra += 2 - case payloadLength == 127: - extra += 8 - } - - if extra == 0 { - return h, nil - } - - b = b[:extra] - _, err = io.ReadFull(r, b) - if err != nil { - return Header{}, err - } - - switch { - case payloadLength == 126: - h.PayloadLength = int64(binary.BigEndian.Uint16(b)) - b = b[2:] - case payloadLength == 127: - h.PayloadLength = int64(binary.BigEndian.Uint64(b)) - if h.PayloadLength < 0 { - return Header{}, fmt.Errorf("Header with negative payload length: %v", h.PayloadLength) - } - b = b[8:] - } - - if h.Masked { - h.MaskKey = binary.LittleEndian.Uint32(b) - } - - return h, nil -} - -const MaxControlFramePayload = 125 - -func ParseClosePayload(p []byte) (uint16, string, error) { - if len(p) < 2 { - return 0, "", fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) - } - - return binary.BigEndian.Uint16(p), string(p[2:]), nil -} diff --git a/internal/wsframe/frame_stringer.go b/internal/wsframe/frame_stringer.go deleted file mode 100644 index b2e7f42..0000000 --- a/internal/wsframe/frame_stringer.go +++ /dev/null @@ -1,91 +0,0 @@ -// Code generated by "stringer -type=Opcode,MessageType,StatusCode -output=frame_stringer.go"; DO NOT EDIT. - -package wsframe - -import "strconv" - -func _() { - // An "invalid array index" compiler error signifies that the constant values have changed. - // Re-run the stringer command to generate them again. - var x [1]struct{} - _ = x[OpContinuation-0] - _ = x[OpText-1] - _ = x[OpBinary-2] - _ = x[OpClose-8] - _ = x[OpPing-9] - _ = x[OpPong-10] -} - -const ( - _opcode_name_0 = "opContinuationopTextopBinary" - _opcode_name_1 = "opCloseopPingopPong" -) - -var ( - _opcode_index_0 = [...]uint8{0, 14, 20, 28} - _opcode_index_1 = [...]uint8{0, 7, 13, 19} -) - -func (i Opcode) String() string { - switch { - case 0 <= i && i <= 2: - return _opcode_name_0[_opcode_index_0[i]:_opcode_index_0[i+1]] - case 8 <= i && i <= 10: - i -= 8 - return _opcode_name_1[_opcode_index_1[i]:_opcode_index_1[i+1]] - default: - return "Opcode(" + strconv.FormatInt(int64(i), 10) + ")" - } -} -func _() { - // An "invalid array index" compiler error signifies that the constant values have changed. - // Re-run the stringer command to generate them again. - var x [1]struct{} - _ = x[MessageText-1] - _ = x[MessageBinary-2] -} - -const _MessageType_name = "MessageTextMessageBinary" - -var _MessageType_index = [...]uint8{0, 11, 24} - -func (i MessageType) String() string { - i -= 1 - if i < 0 || i >= MessageType(len(_MessageType_index)-1) { - return "MessageType(" + strconv.FormatInt(int64(i+1), 10) + ")" - } - return _MessageType_name[_MessageType_index[i]:_MessageType_index[i+1]] -} -func _() { - // An "invalid array index" compiler error signifies that the constant values have changed. - // Re-run the stringer command to generate them again. - var x [1]struct{} - _ = x[StatusNormalClosure-1000] - _ = x[StatusGoingAway-1001] - _ = x[StatusProtocolError-1002] - _ = x[StatusUnsupportedData-1003] - _ = x[statusReserved-1004] - _ = x[StatusNoStatusRcvd-1005] - _ = x[StatusAbnormalClosure-1006] - _ = x[StatusInvalidFramePayloadData-1007] - _ = x[StatusPolicyViolation-1008] - _ = x[StatusMessageTooBig-1009] - _ = x[StatusMandatoryExtension-1010] - _ = x[StatusInternalError-1011] - _ = x[StatusServiceRestart-1012] - _ = x[StatusTryAgainLater-1013] - _ = x[StatusBadGateway-1014] - _ = x[StatusTLSHandshake-1015] -} - -const _StatusCode_name = "StatusNormalClosureStatusGoingAwayStatusProtocolErrorStatusUnsupportedDatastatusReservedStatusNoStatusRcvdStatusAbnormalClosureStatusInvalidFramePayloadDataStatusPolicyViolationStatusMessageTooBigStatusMandatoryExtensionStatusInternalErrorStatusServiceRestartStatusTryAgainLaterStatusBadGatewayStatusTLSHandshake" - -var _StatusCode_index = [...]uint16{0, 19, 34, 53, 74, 88, 106, 127, 156, 177, 196, 220, 239, 259, 278, 294, 312} - -func (i StatusCode) String() string { - i -= 1000 - if i < 0 || i >= StatusCode(len(_StatusCode_index)-1) { - return "StatusCode(" + strconv.FormatInt(int64(i+1000), 10) + ")" - } - return _StatusCode_name[_StatusCode_index[i]:_StatusCode_index[i+1]] -} diff --git a/internal/wsframe/frame_test.go b/internal/wsframe/frame_test.go deleted file mode 100644 index d6b66e7..0000000 --- a/internal/wsframe/frame_test.go +++ /dev/null @@ -1,157 +0,0 @@ -// +build !js - -package wsframe - -import ( - "bytes" - "io" - "math/rand" - "strconv" - "testing" - "time" - _ "unsafe" - - "github.com/google/go-cmp/cmp" - _ "github.com/gorilla/websocket" -) - -func init() { - rand.Seed(time.Now().UnixNano()) -} - -func randBool() bool { - return rand.Intn(1) == 0 -} - -func TestHeader(t *testing.T) { - t.Parallel() - - t.Run("eof", func(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - bytes []byte - }{ - { - "start", - []byte{0xff}, - }, - { - "middle", - []byte{0xff, 0xff, 0xff}, - }, - } - for _, tc := range testCases { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - b := bytes.NewBuffer(tc.bytes) - _, err := ReadHeader(nil, b) - if io.ErrUnexpectedEOF != err { - t.Fatalf("expected %v but got: %v", io.ErrUnexpectedEOF, err) - } - }) - } - }) - - t.Run("writeNegativeLength", func(t *testing.T) { - t.Parallel() - - defer func() { - r := recover() - if r == nil { - t.Fatal("failed to induce panic in writeHeader with negative payload length") - } - }() - - Header{ - PayloadLength: -1, - }.Bytes(nil) - }) - - t.Run("readNegativeLength", func(t *testing.T) { - t.Parallel() - - b := Header{ - PayloadLength: 1<<16 + 1, - }.Bytes(nil) - - // Make length negative - b[2] |= 1 << 7 - - r := bytes.NewReader(b) - _, err := ReadHeader(nil, r) - if err == nil { - t.Fatalf("unexpected error value: %+v", err) - } - }) - - t.Run("lengths", func(t *testing.T) { - t.Parallel() - - lengths := []int{ - 124, - 125, - 126, - 4096, - 16384, - 65535, - 65536, - 65537, - 131072, - } - - for _, n := range lengths { - n := n - t.Run(strconv.Itoa(n), func(t *testing.T) { - t.Parallel() - - testHeader(t, Header{ - PayloadLength: int64(n), - }) - }) - } - }) - - t.Run("fuzz", func(t *testing.T) { - t.Parallel() - - for i := 0; i < 10000; i++ { - h := Header{ - Fin: randBool(), - RSV1: randBool(), - RSV2: randBool(), - RSV3: randBool(), - Opcode: Opcode(rand.Intn(1 << 4)), - - Masked: randBool(), - PayloadLength: rand.Int63(), - } - - if h.Masked { - h.MaskKey = rand.Uint32() - } - - testHeader(t, h) - } - }) -} - -func testHeader(t *testing.T, h Header) { - b := h.Bytes(nil) - r := bytes.NewReader(b) - h2, err := ReadHeader(r, nil) - if err != nil { - t.Logf("Header: %#v", h) - t.Logf("bytes: %b", b) - t.Fatalf("failed to read Header: %v", err) - } - - if !cmp.Equal(h, h2, cmp.AllowUnexported(Header{})) { - t.Logf("Header: %#v", h) - t.Logf("bytes: %b", b) - t.Fatalf("parsed and read Header differ: %v", cmp.Diff(h, h2, cmp.AllowUnexported(Header{}))) - } -} diff --git a/internal/wsgrace/wsgrace.go b/internal/wsgrace/wsgrace.go deleted file mode 100644 index 513af1f..0000000 --- a/internal/wsgrace/wsgrace.go +++ /dev/null @@ -1,50 +0,0 @@ -package wsgrace - -import ( - "context" - "fmt" - "net/http" - "sync/atomic" - "time" -) - -// Grace wraps s.Handler to gracefully shutdown WebSocket connections. -// The returned function must be used to close the server instead of s.Close. -func Grace(s *http.Server) (closeFn func() error) { - h := s.Handler - var conns int64 - s.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - atomic.AddInt64(&conns, 1) - defer atomic.AddInt64(&conns, -1) - - ctx, cancel := context.WithTimeout(r.Context(), time.Minute) - defer cancel() - - r = r.WithContext(ctx) - - h.ServeHTTP(w, r) - }) - - return func() error { - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - - err := s.Shutdown(ctx) - if err != nil { - return fmt.Errorf("server shutdown failed: %v", err) - } - - t := time.NewTicker(time.Millisecond * 10) - defer t.Stop() - for { - select { - case <-t.C: - if atomic.LoadInt64(&conns) == 0 { - return nil - } - case <-ctx.Done(): - return fmt.Errorf("failed to wait for WebSocket connections: %v", ctx.Err()) - } - } - } -} diff --git a/js_test.go b/js_test.go deleted file mode 100644 index 80af789..0000000 --- a/js_test.go +++ /dev/null @@ -1,50 +0,0 @@ -package websocket_test - -import ( - "context" - "fmt" - "net/http" - "nhooyr.io/websocket/internal/wsecho" - "os" - "os/exec" - "strings" - "testing" - "time" - - "nhooyr.io/websocket" -) - -func TestJS(t *testing.T) { - t.Parallel() - - s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) error { - c, err := websocket.Accept(w, r, &websocket.AcceptOptions{ - Subprotocols: []string{"echo"}, - InsecureSkipVerify: true, - }) - if err != nil { - return err - } - defer c.Close(websocket.StatusInternalError, "") - - err = wsecho.Loop(r.Context(), c) - if websocket.CloseStatus(err) != websocket.StatusNormalClosure { - return err - } - return nil - }, false) - defer closeFn() - - wsURL := strings.Replace(s.URL, "http", "ws", 1) - - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - - cmd := exec.CommandContext(ctx, "go", "test", "-exec=wasmbrowsertest", "./...") - cmd.Env = append(os.Environ(), "GOOS=js", "GOARCH=wasm", fmt.Sprintf("WS_ECHO_SERVER_URL=%v", wsURL)) - - b, err := cmd.CombinedOutput() - if err != nil { - t.Fatalf("wasm test binary failed: %v:\n%s", err, b) - } -} diff --git a/read.go b/read.go new file mode 100644 index 0000000..97096f7 --- /dev/null +++ b/read.go @@ -0,0 +1,479 @@ +package websocket + +import ( + "bufio" + "context" + "errors" + "fmt" + "io" + "io/ioutil" + "log" + "nhooyr.io/websocket/internal/errd" + "strings" + "sync/atomic" + "time" +) + +// Reader waits until there is a WebSocket data message to read +// from the connection. +// It returns the type of the message and a reader to read it. +// The passed context will also bound the reader. +// Ensure you read to EOF otherwise the connection will hang. +// +// All returned errors will cause the connection +// to be closed so you do not need to write your own error message. +// This applies to the Read methods in the wsjson/wspb subpackages as well. +// +// You must read from the connection for control frames to be handled. +// Thus if you expect messages to take a long time to be responded to, +// you should handle such messages async to reading from the connection +// to ensure control frames are promptly handled. +// +// If you do not expect any data messages from the peer, call CloseRead. +// +// Only one Reader may be open at a time. +// +// If you need a separate timeout on the Reader call and then the message +// Read, use time.AfterFunc to cancel the context passed in early. +// See https://github.com/nhooyr/websocket/issues/87#issue-451703332 +// Most users should not need this. +func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { + typ, r, err := c.cr.reader(ctx) + if err != nil { + return 0, nil, fmt.Errorf("failed to get reader: %w", err) + } + return typ, r, nil +} + +// Read is a convenience method to read a single message from the connection. +// +// See the Reader method to reuse buffers or for streaming. +// The docs on Reader apply to this method as well. +func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { + typ, r, err := c.Reader(ctx) + if err != nil { + return 0, nil, err + } + + b, err := ioutil.ReadAll(r) + return typ, b, err +} + +// CloseRead will start a goroutine to read from the connection until it is closed or a data message +// is received. If a data message is received, the connection will be closed with StatusPolicyViolation. +// Since CloseRead reads from the connection, it will respond to ping, pong and close frames. +// After calling this method, you cannot read any data messages from the connection. +// The returned context will be cancelled when the connection is closed. +// +// Use this when you do not want to read data messages from the connection anymore but will +// want to write messages to it. +func (c *Conn) CloseRead(ctx context.Context) context.Context { + ctx, cancel := context.WithCancel(ctx) + go func() { + defer cancel() + c.Reader(ctx) + c.Close(StatusPolicyViolation, "unexpected data message") + }() + return ctx +} + +// SetReadLimit sets the max number of bytes to read for a single message. +// It applies to the Reader and Read methods. +// +// By default, the connection has a message read limit of 32768 bytes. +// +// When the limit is hit, the connection will be closed with StatusMessageTooBig. +func (c *Conn) SetReadLimit(n int64) { + c.cr.mr.lr.limit.Store(n) +} + +type connReader struct { + c *Conn + br *bufio.Reader + timeout chan context.Context + + mu mu + controlPayloadBuf [maxControlPayload]byte + mr *msgReader +} + +func (cr *connReader) init(c *Conn, br *bufio.Reader) { + cr.c = c + cr.br = br + cr.timeout = make(chan context.Context) + + cr.mr = &msgReader{ + cr: cr, + fin: true, + } + + cr.mr.lr = newLimitReader(c, readerFunc(cr.mr.read), 32768) + if c.deflateNegotiated() && cr.contextTakeover() { + cr.ensureFlateReader() + } +} + +func (cr *connReader) ensureFlateReader() { + cr.mr.fr = getFlateReader(readerFunc(cr.mr.read)) + cr.mr.lr.reset(cr.mr.fr) +} + +func (cr *connReader) close() { + cr.mu.Lock(context.Background()) + if cr.c.client { + putBufioReader(cr.br) + } + if cr.c.deflateNegotiated() && cr.contextTakeover() { + putFlateReader(cr.mr.fr) + } +} + +func (cr *connReader) contextTakeover() bool { + if cr.c.client { + return cr.c.copts.serverNoContextTakeover + } + return cr.c.copts.clientNoContextTakeover +} + +func (cr *connReader) rsv1Illegal(h header) bool { + // If compression is enabled, rsv1 is always illegal. + if !cr.c.deflateNegotiated() { + return true + } + // rsv1 is only allowed on data frames beginning messages. + if h.opcode != opText && h.opcode != opBinary { + return true + } + return false +} + +func (cr *connReader) loop(ctx context.Context) (header, error) { + for { + h, err := cr.frameHeader(ctx) + if err != nil { + return header{}, err + } + + if h.rsv1 && cr.rsv1Illegal(h) || h.rsv2 || h.rsv3 { + err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3) + cr.c.cw.error(StatusProtocolError, err) + return header{}, err + } + + if !cr.c.client && !h.masked { + return header{}, errors.New("received unmasked frame from client") + } + + switch h.opcode { + case opClose, opPing, opPong: + err = cr.control(ctx, h) + if err != nil { + // Pass through CloseErrors when receiving a close frame. + if h.opcode == opClose && CloseStatus(err) != -1 { + return header{}, err + } + return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err) + } + case opContinuation, opText, opBinary: + return h, nil + default: + err := fmt.Errorf("received unknown opcode %v", h.opcode) + cr.c.cw.error(StatusProtocolError, err) + return header{}, err + } + } +} + +func (cr *connReader) frameHeader(ctx context.Context) (header, error) { + select { + case <-cr.c.closed: + return header{}, cr.c.closeErr + case cr.timeout <- ctx: + } + + h, err := readFrameHeader(cr.br) + if err != nil { + select { + case <-cr.c.closed: + return header{}, cr.c.closeErr + case <-ctx.Done(): + return header{}, ctx.Err() + default: + cr.c.close(err) + return header{}, err + } + } + + select { + case <-cr.c.closed: + return header{}, cr.c.closeErr + case cr.timeout <- context.Background(): + } + + return h, nil +} + +func (cr *connReader) framePayload(ctx context.Context, p []byte) (int, error) { + select { + case <-cr.c.closed: + return 0, cr.c.closeErr + case cr.timeout <- ctx: + } + + n, err := io.ReadFull(cr.br, p) + if err != nil { + select { + case <-cr.c.closed: + return n, cr.c.closeErr + case <-ctx.Done(): + return n, ctx.Err() + default: + err = fmt.Errorf("failed to read frame payload: %w", err) + cr.c.close(err) + return n, err + } + } + + select { + case <-cr.c.closed: + return n, cr.c.closeErr + case cr.timeout <- context.Background(): + } + + return n, err +} + +func (cr *connReader) control(ctx context.Context, h header) error { + if h.payloadLength < 0 { + err := fmt.Errorf("received header with negative payload length: %v", h.payloadLength) + cr.c.cw.error(StatusProtocolError, err) + return err + } + + if h.payloadLength > maxControlPayload { + err := fmt.Errorf("received too big control frame at %v bytes", h.payloadLength) + cr.c.cw.error(StatusProtocolError, err) + return err + } + + if !h.fin { + err := errors.New("received fragmented control frame") + cr.c.cw.error(StatusProtocolError, err) + return err + } + + ctx, cancel := context.WithTimeout(ctx, time.Second*5) + defer cancel() + + b := cr.controlPayloadBuf[:h.payloadLength] + _, err := cr.framePayload(ctx, b) + if err != nil { + return err + } + + if h.masked { + mask(h.maskKey, b) + } + + switch h.opcode { + case opPing: + return cr.c.cw.control(ctx, opPong, b) + case opPong: + cr.c.activePingsMu.Lock() + pong, ok := cr.c.activePings[string(b)] + cr.c.activePingsMu.Unlock() + if ok { + close(pong) + } + return nil + } + + ce, err := parseClosePayload(b) + if err != nil { + err = fmt.Errorf("received invalid close payload: %w", err) + cr.c.cw.error(StatusProtocolError, err) + return err + } + + err = fmt.Errorf("received close frame: %w", ce) + cr.c.setCloseErr(err) + cr.c.cw.control(context.Background(), opClose, ce.bytes()) + return err +} + +func (cr *connReader) reader(ctx context.Context) (MessageType, io.Reader, error) { + err := cr.mu.Lock(ctx) + if err != nil { + return 0, nil, err + } + defer cr.mu.Unlock() + + if !cr.mr.fin { + return 0, nil, errors.New("previous message not read to completion") + } + + h, err := cr.loop(ctx) + if err != nil { + return 0, nil, err + } + + if h.opcode == opContinuation { + err := errors.New("received continuation frame without text or binary frame") + cr.c.cw.error(StatusProtocolError, err) + return 0, nil, err + } + + cr.mr.reset(ctx, h) + + return MessageType(h.opcode), cr.mr, nil +} + +type msgReader struct { + cr *connReader + fr io.Reader + lr *limitReader + + ctx context.Context + + deflate bool + deflateTail strings.Reader + + payloadLength int64 + maskKey uint32 + fin bool +} + +func (mr *msgReader) reset(ctx context.Context, h header) { + mr.ctx = ctx + mr.deflate = h.rsv1 + if mr.deflate { + mr.deflateTail.Reset(deflateMessageTail) + if !mr.cr.contextTakeover() { + mr.cr.ensureFlateReader() + } + } + mr.setFrame(h) + mr.fin = false +} + +func (mr *msgReader) setFrame(h header) { + mr.payloadLength = h.payloadLength + mr.maskKey = h.maskKey + mr.fin = h.fin +} + +func (mr *msgReader) Read(p []byte) (_ int, err error) { + defer func() { + errd.Wrap(&err, "failed to read") + if errors.Is(err, io.EOF) { + err = io.EOF + } + }() + + err = mr.cr.mu.Lock(mr.ctx) + if err != nil { + return 0, err + } + defer mr.cr.mu.Unlock() + + if mr.payloadLength == 0 && mr.fin { + if mr.cr.c.deflateNegotiated() && !mr.cr.contextTakeover() { + if mr.fr != nil { + putFlateReader(mr.fr) + mr.fr = nil + } + } + return 0, io.EOF + } + + return mr.lr.Read(p) +} + +func (mr *msgReader) read(p []byte) (int, error) { + log.Println("compress", mr.deflate) + + if mr.payloadLength == 0 { + h, err := mr.cr.loop(mr.ctx) + if err != nil { + return 0, err + } + if h.opcode != opContinuation { + err := errors.New("received new data message without finishing the previous message") + mr.cr.c.cw.error(StatusProtocolError, err) + return 0, err + } + mr.setFrame(h) + } + + if int64(len(p)) > mr.payloadLength { + p = p[:mr.payloadLength] + } + + n, err := mr.cr.framePayload(mr.ctx, p) + if err != nil { + return n, err + } + + mr.payloadLength -= int64(n) + + if !mr.cr.c.client { + mr.maskKey = mask(mr.maskKey, p) + } + + return n, nil +} + +type limitReader struct { + c *Conn + r io.Reader + limit atomicInt64 + n int64 +} + +func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader { + lr := &limitReader{ + c: c, + } + lr.limit.Store(limit) + lr.reset(r) + return lr +} + +func (lr *limitReader) reset(r io.Reader) { + lr.n = lr.limit.Load() + lr.r = r +} + +func (lr *limitReader) Read(p []byte) (int, error) { + if lr.n <= 0 { + err := fmt.Errorf("read limited at %v bytes", lr.limit.Load()) + lr.c.cw.error(StatusMessageTooBig, err) + return 0, err + } + + if int64(len(p)) > lr.n { + p = p[:lr.n] + } + n, err := lr.r.Read(p) + lr.n -= int64(n) + return n, err +} + +type atomicInt64 struct { + i atomic.Value +} + +func (v *atomicInt64) Load() int64 { + i, _ := v.i.Load().(int64) + return i +} + +func (v *atomicInt64) Store(i int64) { + v.i.Store(i) +} + +type readerFunc func(p []byte) (int, error) + +func (f readerFunc) Read(p []byte) (int, error) { + return f(p) +} diff --git a/reader.go b/reader.go deleted file mode 100644 index fe71656..0000000 --- a/reader.go +++ /dev/null @@ -1,31 +0,0 @@ -package websocket - -import ( - "bufio" - "context" - "io" - "nhooyr.io/websocket/internal/atomicint" - "nhooyr.io/websocket/internal/wsframe" - "strings" -) - -type reader struct { - // Acquired before performing any sort of read operation. - readLock chan struct{} - - c *Conn - - deflateReader io.Reader - br *bufio.Reader - - readClosed *atomicint.Int64 - readHeaderBuf []byte - controlPayloadBuf []byte - - msgCtx context.Context - msgCompressed bool - frameHeader wsframe.Header - frameMaskKey uint32 - frameEOF bool - deflateTail strings.Reader -} diff --git a/write.go b/write.go new file mode 100644 index 0000000..5bb489b --- /dev/null +++ b/write.go @@ -0,0 +1,348 @@ +package websocket + +import ( + "bufio" + "compress/flate" + "context" + "crypto/rand" + "encoding/binary" + "errors" + "fmt" + "io" + "nhooyr.io/websocket/internal/errd" + "time" +) + +// Writer returns a writer bounded by the context that will write +// a WebSocket message of type dataType to the connection. +// +// You must close the writer once you have written the entire message. +// +// Only one writer can be open at a time, multiple calls will block until the previous writer +// is closed. +// +// Never close the returned writer twice. +func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { + w, err := c.cw.writer(ctx, typ) + if err != nil { + return nil, fmt.Errorf("failed to get writer: %w", err) + } + return w, nil +} + +// Write writes a message to the connection. +// +// See the Writer method if you want to stream a message. +// +// If compression is disabled, then it is guaranteed to write the message +// in a single frame. +func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { + _, err := c.cw.write(ctx, typ, p) + if err != nil { + return fmt.Errorf("failed to write msg: %w", err) + } + return nil +} + +type connWriter struct { + c *Conn + bw *bufio.Writer + + writeBuf []byte + + mw *messageWriter + frameMu mu + h header + + timeout chan context.Context +} + +func (cw *connWriter) init(c *Conn, bw *bufio.Writer) { + cw.c = c + cw.bw = bw + + if cw.c.client { + cw.writeBuf = extractBufioWriterBuf(cw.bw, c.rwc) + } + + cw.timeout = make(chan context.Context) + + cw.mw = &messageWriter{ + cw: cw, + } + cw.mw.tw = &trimLastFourBytesWriter{ + w: writerFunc(cw.mw.write), + } + if cw.c.deflateNegotiated() && cw.mw.contextTakeover() { + cw.mw.ensureFlateWriter() + } +} + +func (mw *messageWriter) ensureFlateWriter() { + mw.fw = getFlateWriter(mw.tw) +} + +func (cw *connWriter) close() { + if cw.c.client { + cw.frameMu.Lock(context.Background()) + putBufioWriter(cw.bw) + } + if cw.c.deflateNegotiated() && cw.mw.contextTakeover() { + cw.mw.mu.Lock(context.Background()) + putFlateWriter(cw.mw.fw) + } +} + +func (mw *messageWriter) contextTakeover() bool { + if mw.cw.c.client { + return mw.cw.c.copts.clientNoContextTakeover + } + return mw.cw.c.copts.serverNoContextTakeover +} + +func (cw *connWriter) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { + err := cw.mw.reset(ctx, typ) + if err != nil { + return nil, err + } + return cw.mw, nil +} + +func (cw *connWriter) write(ctx context.Context, typ MessageType, p []byte) (int, error) { + ww, err := cw.writer(ctx, typ) + if err != nil { + return 0, err + } + + if !cw.c.deflateNegotiated() { + // Fast single frame path. + defer cw.mw.mu.Unlock() + return cw.frame(ctx, true, cw.mw.opcode, p) + } + + n, err := ww.Write(p) + if err != nil { + return n, err + } + + err = ww.Close() + return n, err +} + +type messageWriter struct { + cw *connWriter + + mu mu + compress bool + tw *trimLastFourBytesWriter + fw *flate.Writer + ctx context.Context + opcode opcode + closed bool +} + +func (mw *messageWriter) reset(ctx context.Context, typ MessageType) error { + err := mw.mu.Lock(ctx) + if err != nil { + return err + } + + mw.closed = false + mw.ctx = ctx + mw.opcode = opcode(typ) + return nil +} + +// Write writes the given bytes to the WebSocket connection. +func (mw *messageWriter) Write(p []byte) (_ int, err error) { + defer errd.Wrap(&err, "failed to write") + + if mw.closed { + return 0, errors.New("cannot use closed writer") + } + + if mw.cw.c.deflateNegotiated() { + if !mw.compress { + if !mw.contextTakeover() { + mw.ensureFlateWriter() + } + mw.tw.reset() + mw.compress = true + } + + return mw.fw.Write(p) + } + + return mw.write(p) +} + +func (mw *messageWriter) write(p []byte) (int, error) { + n, err := mw.cw.frame(mw.ctx, false, mw.opcode, p) + if err != nil { + return n, fmt.Errorf("failed to write data frame: %w", err) + } + mw.opcode = opContinuation + return n, nil +} + +// Close flushes the frame to the connection. +// This must be called for every messageWriter. +func (mw *messageWriter) Close() (err error) { + defer errd.Wrap(&err, "failed to close writer") + + if mw.closed { + return errors.New("cannot use closed writer") + } + mw.closed = true + + if mw.cw.c.deflateNegotiated() { + err = mw.fw.Flush() + if err != nil { + return fmt.Errorf("failed to flush flate writer: %w", err) + } + } + + _, err = mw.cw.frame(mw.ctx, true, mw.opcode, nil) + if err != nil { + return fmt.Errorf("failed to write fin frame: %w", err) + } + + if mw.compress && !mw.contextTakeover() { + putFlateWriter(mw.fw) + mw.compress = false + } + + mw.mu.Unlock() + return nil +} + +func (cw *connWriter) control(ctx context.Context, opcode opcode, p []byte) error { + ctx, cancel := context.WithTimeout(ctx, time.Second*5) + defer cancel() + + _, err := cw.frame(ctx, true, opcode, p) + if err != nil { + return fmt.Errorf("failed to write control frame %v: %w", opcode, err) + } + return nil +} + +// frame handles all writes to the connection. +func (cw *connWriter) frame(ctx context.Context, fin bool, opcode opcode, p []byte) (int, error) { + err := cw.frameMu.Lock(ctx) + if err != nil { + return 0, err + } + defer cw.frameMu.Unlock() + + select { + case <-cw.c.closed: + return 0, cw.c.closeErr + case cw.timeout <- ctx: + } + + cw.h.fin = fin + cw.h.opcode = opcode + cw.h.masked = cw.c.client + cw.h.payloadLength = int64(len(p)) + + cw.h.rsv1 = false + if cw.mw.compress && (opcode == opText || opcode == opBinary) { + cw.h.rsv1 = true + } + + if cw.h.masked { + err = binary.Read(rand.Reader, binary.LittleEndian, &cw.h.maskKey) + if err != nil { + return 0, fmt.Errorf("failed to generate masking key: %w", err) + } + } + + err = writeFrameHeader(cw.h, cw.bw) + if err != nil { + return 0, err + } + + n, err := cw.framePayload(p) + if err != nil { + return n, err + } + + if cw.h.fin { + err = cw.bw.Flush() + if err != nil { + return n, fmt.Errorf("failed to flush: %w", err) + } + } + + select { + case <-cw.c.closed: + return n, cw.c.closeErr + case cw.timeout <- context.Background(): + } + + return n, nil +} + +func (cw *connWriter) framePayload(p []byte) (_ int, err error) { + defer errd.Wrap(&err, "failed to write frame payload") + + if !cw.h.masked { + return cw.bw.Write(p) + } + + var n int + maskKey := cw.h.maskKey + for len(p) > 0 { + // If the buffer is full, we need to flush. + if cw.bw.Available() == 0 { + err = cw.bw.Flush() + if err != nil { + return n, err + } + } + + // Start of next write in the buffer. + i := cw.bw.Buffered() + + j := len(p) + if j > cw.bw.Available() { + j = cw.bw.Available() + } + + _, err := cw.bw.Write(p[:j]) + if err != nil { + return n, err + } + + maskKey = mask(maskKey, cw.writeBuf[i:cw.bw.Buffered()]) + + p = p[j:] + n += j + } + + return n, nil +} + +type writerFunc func(p []byte) (int, error) + +func (f writerFunc) Write(p []byte) (int, error) { + return f(p) +} + +// extractBufioWriterBuf grabs the []byte backing a *bufio.Writer +// and returns it. +func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte { + var writeBuf []byte + bw.Reset(writerFunc(func(p2 []byte) (int, error) { + writeBuf = p2[:cap(p2)] + return len(p2), nil + })) + + bw.WriteByte(0) + bw.Flush() + + bw.Reset(w) + + return writeBuf +} diff --git a/writer.go b/writer.go deleted file mode 100644 index b31d57a..0000000 --- a/writer.go +++ /dev/null @@ -1,5 +0,0 @@ -package websocket - -type writer struct { - -} diff --git a/ws_js.go b/ws_js.go index 4c06743..10ce0da 100644 --- a/ws_js.go +++ b/ws_js.go @@ -9,7 +9,7 @@ import ( "fmt" "io" "net/http" - "nhooyr.io/websocket/internal/atomicint" + "nhooyr.io/websocket/internal/wssync" "reflect" "runtime" "sync" @@ -24,10 +24,10 @@ type Conn struct { ws wsjs.WebSocket // read limit for a message in bytes. - msgReadLimit *atomicint.Int64 + msgReadLimit *wssync.Int64 closingMu sync.Mutex - isReadClosed *atomicint.Int64 + isReadClosed *wssync.Int64 closeOnce sync.Once closed chan struct{} closeErrOnce sync.Once @@ -59,10 +59,10 @@ func (c *Conn) init() { c.closed = make(chan struct{}) c.readSignal = make(chan struct{}, 1) - c.msgReadLimit = &atomicint.Int64{} + c.msgReadLimit = &wssync.Int64{} c.msgReadLimit.Store(32768) - c.isReadClosed = &atomicint.Int64{} + c.isReadClosed = &wssync.Int64{} c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) { err := CloseError{ @@ -105,7 +105,7 @@ func (c *Conn) closeWithInternal() { // The maximum time spent waiting is bounded by the context. func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { if c.isReadClosed.Load() == 1 { - return 0, nil, fmt.Errorf("websocket connection read closed") + return 0, nil, errors.New("websocket connection read closed") } typ, p, err := c.read(ctx) diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go index 9fa8b54..e818805 100644 --- a/wsjson/wsjson.go +++ b/wsjson/wsjson.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "fmt" + "log" "nhooyr.io/websocket" "nhooyr.io/websocket/internal/bufpool" ) @@ -41,6 +42,7 @@ func read(ctx context.Context, c *websocket.Conn, v interface{}) error { err = json.Unmarshal(b.Bytes(), v) if err != nil { c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal JSON") + log.Printf("%X", b.Bytes()) return fmt.Errorf("failed to unmarshal json: %w", err) } -- GitLab