From f685c8d74181ad7f4c8023e736327c8bd55c5aa5 Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Wed, 17 Apr 2019 18:50:24 -0400 Subject: [PATCH] Improve speed and add a benchmark --- accept.go | 8 +- bench_test.go | 77 +++++++++++++ example_test.go | 7 +- json.go | 9 +- statuscode.go | 10 +- websocket.go | 268 ++++++++++++++++++++++++---------------------- websocket_test.go | 67 +++++++----- 7 files changed, 282 insertions(+), 164 deletions(-) create mode 100644 bench_test.go diff --git a/accept.go b/accept.go index 3120690..e0c31ef 100644 --- a/accept.go +++ b/accept.go @@ -53,19 +53,19 @@ func AcceptInsecureOrigin() AcceptOption { func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { if !headerValuesContainsToken(r.Header, "Connection", "Upgrade") { - err := xerrors.Errorf("websocket: protocol violation: Connection header does not contain Upgrade: %q", r.Header.Get("Connection")) + err := xerrors.Errorf("websocket: protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) http.Error(w, err.Error(), http.StatusBadRequest) return err } if !headerValuesContainsToken(r.Header, "Upgrade", "WebSocket") { - err := xerrors.Errorf("websocket: protocol violation: Upgrade header does not contain websocket: %q", r.Header.Get("Upgrade")) + err := xerrors.Errorf("websocket: protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) http.Error(w, err.Error(), http.StatusBadRequest) return err } if r.Method != "GET" { - err := xerrors.Errorf("websocket: protocol violation: handshake request method is not GET: %q", r.Method) + err := xerrors.Errorf("websocket: protocol violation: handshake request method %q is not GET", r.Method) http.Error(w, err.Error(), http.StatusBadRequest) return err } @@ -88,7 +88,7 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { // Accept accepts a WebSocket handshake from a client and upgrades the // the connection to WebSocket. // Accept will reject the handshake if the Origin is not the same as the Host unless -// InsecureAcceptOrigin is passed. +// the AcceptInsecureOrigin option is passed. // Accept uses w to write the handshake response so the timeouts on the http.Server apply. func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn, error) { var subprotocols []string diff --git a/bench_test.go b/bench_test.go new file mode 100644 index 0000000..f5b5b21 --- /dev/null +++ b/bench_test.go @@ -0,0 +1,77 @@ +package websocket_test + +import ( + "context" + "io" + "net/http" + "nhooyr.io/websocket" + "strings" + "testing" + "time" +) + +func BenchmarkConn(b *testing.B) { + b.StopTimer() + + s, closeFn := testServer(b, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, + websocket.AcceptSubprotocols("echo"), + ) + if err != nil { + b.Logf("server handshake failed: %+v", err) + return + } + echoLoop(r.Context(), c) + })) + defer closeFn() + + wsURL := strings.Replace(s.URL, "http", "ws", 1) + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute*5) + defer cancel() + + c, _, err := websocket.Dial(ctx, wsURL) + if err != nil { + b.Fatalf("failed to dial: %v", err) + } + defer c.Close(websocket.StatusInternalError, "") + + msg := strings.Repeat("2", 4096*16) + buf := make([]byte, len(msg)) + b.SetBytes(int64(len(msg))) + b.StartTimer() + for i := 0; i < b.N; i++ { + w, err := c.Write(ctx, websocket.MessageText) + if err != nil { + b.Fatal(err) + } + + _, err = io.WriteString(w, msg) + if err != nil { + b.Fatal(err) + } + + err = w.Close() + if err != nil { + b.Fatal(err) + } + + _, r, err := c.Read(ctx) + if err != nil { + b.Fatal(err, b.N) + } + + _, err = io.ReadFull(r, buf) + if err != nil { + b.Fatal(err) + } + + // TODO jank + _, err = r.Read(nil) + if err != io.EOF { + b.Fatalf("wtf %q", err) + } + } + b.StopTimer() + c.Close(websocket.StatusNormalClosure, "") +} diff --git a/example_test.go b/example_test.go index 702239b..85cd3aa 100644 --- a/example_test.go +++ b/example_test.go @@ -34,10 +34,13 @@ func ExampleAccept_echo() { if err != nil { return err } - r = io.LimitReader(r, 32768) - w := c.Write(ctx, typ) + w, err := c.Write(ctx, typ) + if err != nil { + return err + } + _, err = io.Copy(w, r) if err != nil { return err diff --git a/json.go b/json.go index 24e6f31..0d85a5d 100644 --- a/json.go +++ b/json.go @@ -22,7 +22,7 @@ func (jc JSONConn) Read(ctx context.Context, v interface{}) error { return nil } -func (jc *JSONConn) read(ctx context.Context, v interface{}) error { +func (jc JSONConn) read(ctx context.Context, v interface{}) error { typ, r, err := jc.Conn.Read(ctx) if err != nil { return err @@ -53,10 +53,13 @@ func (jc JSONConn) Write(ctx context.Context, v interface{}) error { } func (jc JSONConn) write(ctx context.Context, v interface{}) error { - w := jc.Conn.Write(ctx, MessageText) + w, err := jc.Conn.Write(ctx, MessageText) + if err != nil { + return xerrors.Errorf("failed to get message writer: %w", err) + } e := json.NewEncoder(w) - err := e.Encode(v) + err = e.Encode(v) if err != nil { return xerrors.Errorf("failed to encode json: %w", err) } diff --git a/statuscode.go b/statuscode.go index 2f4f2c0..d742195 100644 --- a/statuscode.go +++ b/statuscode.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "math/bits" - "unicode/utf8" "golang.org/x/xerrors" ) @@ -54,6 +53,12 @@ func (ce CloseError) Error() string { } func parseClosePayload(p []byte) (CloseError, error) { + if len(p) == 0 { + return CloseError{ + Code: StatusNoStatusRcvd, + }, nil + } + if len(p) < 2 { return CloseError{}, fmt.Errorf("close payload too small, cannot even contain the 2 byte status code") } @@ -63,9 +68,6 @@ func parseClosePayload(p []byte) (CloseError, error) { Reason: string(p[2:]), } - if !utf8.ValidString(ce.Reason) { - return CloseError{}, xerrors.Errorf("invalid utf-8: %q", ce.Reason) - } if !validWireCloseCode(ce.Code) { return CloseError{}, xerrors.Errorf("invalid code %v", ce.Code) } diff --git a/websocket.go b/websocket.go index 52b5d8d..52f42dc 100644 --- a/websocket.go +++ b/websocket.go @@ -5,8 +5,10 @@ import ( "context" "fmt" "io" + "log" "runtime" "sync" + "sync/atomic" "time" "golang.org/x/xerrors" @@ -34,6 +36,11 @@ type Conn struct { // Writers should send on write to begin sending // a message and then follow that up with some data // on writeBytes. + // Send on control to write a control message. + // writeDone will be sent back when the message is written + // Close writeBytes to flush the message and wait for a + // ping on writeDone. // TODO should I care about this allocation? + // writeDone will be closed if the data message write errors. write chan MessageType control chan control writeBytes chan []byte @@ -42,17 +49,17 @@ type Conn struct { // Readers should receive on read to begin reading a message. // Then send a byte slice to readBytes to read into it. // The n of bytes read will be sent on readDone once the read into a slice is complete. - // readDone will receive 0 when EOF is reached. - read chan opcode - readBytes chan []byte - readDone chan int - readerDone chan struct{} + // readDone will be closed if the read fails. + // readInProgress will be set to 0 on io.EOF. + activeReader int64 + inMsg bool + read chan opcode + readBytes chan []byte + readDone chan int } func (c *Conn) close(err error) { - if err != nil { - err = xerrors.Errorf("websocket: connection broken: %w", err) - } + err = xerrors.Errorf("websocket: connection broken: %w", err) c.closeOnce.Do(func() { runtime.SetFinalizer(c, nil) @@ -76,13 +83,14 @@ func (c *Conn) Subprotocol() string { func (c *Conn) init() { c.closed = make(chan struct{}) + c.write = make(chan MessageType) c.control = make(chan control) c.writeDone = make(chan struct{}) + c.read = make(chan opcode) - c.readDone = make(chan int) c.readBytes = make(chan []byte) - c.readerDone = make(chan struct{}) + c.readDone = make(chan int) runtime.SetFinalizer(c, func(c *Conn) { c.Close(StatusInternalError, "websocket: connection ended up being garbage collected") @@ -116,6 +124,8 @@ func (c *Conn) writeFrame(h header, p []byte) { } func (c *Conn) writeLoop() { + defer close(c.writeDone) + messageLoop: for { c.writeBytes = make(chan []byte) @@ -173,6 +183,10 @@ messageLoop: } firstSent = true + if c.client { + log.Printf("client %#v", h) + } + c.writeFrame(h, b) if !ok { @@ -225,17 +239,15 @@ func (c *Conn) handleControl(h header) { c.writePong(b) case opPong: case opClose: - if len(b) > 0 { - ce, err := parseClosePayload(b) - if err != nil { - c.close(xerrors.Errorf("read invalid close payload: %w", err)) - return - } - c.Close(ce.Code, ce.Reason) + ce, err := parseClosePayload(b) + if err != nil { + c.close(xerrors.Errorf("read invalid close payload: %w", err)) + return + } + if ce.Code == StatusNoStatusRcvd { + c.writeClose(nil, ce) } else { - c.writeClose(nil, CloseError{ - Code: StatusNoStatusRcvd, - }) + c.Close(ce.Code, ce.Reason) } default: panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h)) @@ -243,7 +255,8 @@ func (c *Conn) handleControl(h header) { } func (c *Conn) readLoop() { - var indata bool + defer close(c.readDone) + for { h, err := readHeader(c.br) if err != nil { @@ -251,6 +264,10 @@ func (c *Conn) readLoop() { return } + if !c.client { + log.Printf("%#v", h) + } + if h.rsv1 || h.rsv2 || h.rsv3 { c.Close(StatusProtocolError, fmt.Sprintf("read header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)) return @@ -263,19 +280,19 @@ func (c *Conn) readLoop() { switch h.opcode { case opBinary, opText: - if !indata { - select { - case <-c.closed: - return - case c.read <- h.opcode: - } - indata = true - } else { - c.Close(StatusProtocolError, "cannot send data frame when previous frame is not finished") + if c.inMsg { + c.Close(StatusProtocolError, "cannot read data frame when previous frame is not finished") + return + } + + select { + case <-c.closed: return + case c.read <- h.opcode: + c.inMsg = true } case opContinuation: - if !indata { + if !c.inMsg { c.Close(StatusProtocolError, "continuation frame not after data or text frame") return } @@ -284,47 +301,55 @@ func (c *Conn) readLoop() { return } - maskPos := 0 - left := h.payloadLength - firstRead := false - for left > 0 || !firstRead { - select { - case <-c.closed: - return - case b := <-c.readBytes: - if int64(len(b)) > left { - b = b[:left] - } + err = c.dataReadLoop(h) + if err != nil { + c.close(xerrors.Errorf("failed to read from connection: %w", err)) + return + } + } +} - _, err = io.ReadFull(c.br, b) - if err != nil { - c.close(xerrors.Errorf("failed to read from connection: %w", err)) - return - } - left -= int64(len(b)) +func (c *Conn) dataReadLoop(h header) (err error) { + maskPos := 0 + left := h.payloadLength + firstReadDone := false + for left > 0 || !firstReadDone { + select { + case <-c.closed: + return c.closeErr + case b := <-c.readBytes: + if int64(len(b)) > left { + b = b[:left] + } - if h.masked { - maskPos = mask(h.maskKey, maskPos, b) - } + _, err := io.ReadFull(c.br, b) + if err != nil { + return xerrors.Errorf("failed to read from connection: %w", err) + } + left -= int64(len(b)) - select { - case <-c.closed: - return - case c.readDone <- len(b): - firstRead = true - } + if h.masked { + maskPos = mask(h.maskKey, maskPos, b) + } + + // Must set this before we signal the read is done. + // The reader will use this to return io.EOF and + // c.Read will use it to check if the reader has been completed. + if left == 0 && h.fin { + atomic.StoreInt64(&c.activeReader, 0) + c.inMsg = false } - } - if h.fin { - indata = false select { case <-c.closed: - return - case c.readerDone <- struct{}{}: + return c.closeErr + case c.readDone <- len(b): + firstReadDone = true } } } + + return nil } func (c *Conn) writePong(p []byte) error { @@ -404,76 +429,57 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error } // Write returns a writer bounded by the context that will write -// a WebSocket data frame of type dataType to the connection. -// Ensure you close the messageWriter once you have written to entire message. -// Concurrent calls to messageWriter are ok. -func (c *Conn) Write(ctx context.Context, dataType MessageType) io.WriteCloser { - // TODO acquire write here, move state into Conn and make messageWriter allocation free. - return &messageWriter{ - c: c, - ctx: ctx, - datatype: dataType, +// a WebSocket message of type dataType to the connection. +// Ensure you close the writer once you have written the entire message. +// Concurrent calls to Write are ok. +func (c *Conn) Write(ctx context.Context, dataType MessageType) (io.WriteCloser, error) { + select { + case <-c.closed: + return nil, c.closeErr + case <-ctx.Done(): + return nil, ctx.Err() + case c.write <- dataType: + return messageWriter{ + ctx: ctx, + c: c, + }, nil } } // messageWriter enables writing to a WebSocket connection. -// Ensure you close the messageWriter once you have written to entire message. type messageWriter struct { - datatype MessageType - ctx context.Context - c *Conn - acquiredLock bool + ctx context.Context + c *Conn } // Write writes the given bytes to the WebSocket connection. // The frame will automatically be fragmented as appropriate // with the buffers obtained from http.Hijacker. // Please ensure you call Close once you have written the full message. -func (w *messageWriter) Write(p []byte) (int, error) { - err := w.acquire() - if err != nil { - return 0, err - } - +func (w messageWriter) Write(p []byte) (int, error) { select { case <-w.c.closed: return 0, w.c.closeErr case w.c.writeBytes <- p: select { - case <-w.c.closed: - return 0, w.c.closeErr - case <-w.c.writeDone: - return len(p), nil case <-w.ctx.Done(): + w.c.close(xerrors.Errorf("write timed out: %w", w.ctx.Err())) + <-w.c.readDone return 0, w.ctx.Err() + case _, ok := <-w.c.writeDone: + if !ok { + return 0, w.c.closeErr + } + return len(p), nil } case <-w.ctx.Done(): return 0, w.ctx.Err() } } -func (w *messageWriter) acquire() error { - if !w.acquiredLock { - select { - case <-w.c.closed: - return w.c.closeErr - case w.c.write <- w.datatype: - w.acquiredLock = true - case <-w.ctx.Done(): - return w.ctx.Err() - } - } - return nil -} - // Close flushes the frame to the connection. // This must be called for every messageWriter. -func (w *messageWriter) Close() error { - err := w.acquire() - if err != nil { - return err - } - +func (w messageWriter) Close() error { close(w.c.writeBytes) select { case <-w.c.closed: @@ -485,26 +491,28 @@ func (w *messageWriter) Close() error { } } -// ReadMessage will wait until there is a WebSocket data frame to read from the connection. -// It returns the type of the data, a reader to read it and also an error. -// Please use SetContext on the reader to bound the read operation. +// ReadMessage will wait until there is a WebSocket data message to read from the connection. +// It returns the type of the message and a reader to read it. +// The passed context will also bound the reader. // Your application must keep reading messages for the Conn to automatically respond to ping -// and close frames. +// and close frames and not become stuck waiting for a data message to be read. +// Please ensure to read the full message from io.Reader. +// You can only read a single message at a time. func (c *Conn) Read(ctx context.Context) (MessageType, io.Reader, error) { - // TODO error if the reader is not done + if !atomic.CompareAndSwapInt64(&c.activeReader, 0, 1) { + return 0, nil, xerrors.New("websocket: previous message not fully read") + } + select { - case <-c.readerDone: - // The previous reader just hit a io.EOF, we handle it for users - return c.Read(ctx) case <-c.closed: - return 0, nil, xerrors.Errorf("failed to read message: %w", c.closeErr) + return 0, nil, xerrors.Errorf("websocket: failed to read message: %w", c.closeErr) case opcode := <-c.read: return MessageType(opcode), messageReader{ ctx: ctx, c: c, }, nil case <-ctx.Done(): - return 0, nil, xerrors.Errorf("failed to read message: %w", ctx.Err()) + return 0, nil, xerrors.Errorf("websocket: failed to read message: %w", ctx.Err()) } } @@ -518,30 +526,38 @@ type messageReader struct { func (r messageReader) Read(p []byte) (int, error) { n, err := r.read(p) if err != nil { - // Have to return io.EOF directly for now. + // Have to return io.EOF directly for now, cannot wrap. if err == io.EOF { - return 0, io.EOF + return n, io.EOF } return n, xerrors.Errorf("failed to read: %w", err) } return n, nil } -func (r messageReader) read(p []byte) (int, error) { +func (r messageReader) read(p []byte) (_ int, err error) { + if atomic.LoadInt64(&r.c.activeReader) == 0 { + return 0, io.EOF + } + select { case <-r.c.closed: return 0, r.c.closeErr - case <-r.c.readerDone: - return 0, io.EOF case r.c.readBytes <- p: - // TODO this is potentially racey as if we return if the context is cancelled, or the conn is closed we don't know if the p is ok to use. we must close the connection and also ensure the readLoop is done before returning, likewise with writes. select { - case <-r.c.closed: - return 0, r.c.closeErr - case n := <-r.c.readDone: - return n, nil case <-r.ctx.Done(): + r.c.close(xerrors.Errorf("read timed out: %w", err)) + // Wait for readloop to complete so we know p is done. + <-r.c.readDone return 0, r.ctx.Err() + case n, ok := <-r.c.readDone: + if !ok { + return 0, r.c.closeErr + } + if atomic.LoadInt64(&r.c.activeReader) == 0 { + return n, io.EOF + } + return n, nil } case <-r.ctx.Done(): return 0, r.ctx.Err() diff --git a/websocket_test.go b/websocket_test.go index 868b69a..dba0182 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "io/ioutil" + "log" "net/http" "net/http/cookiejar" "net/http/httptest" @@ -292,29 +293,14 @@ func TestHandshake(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() - var conns int64 - s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - atomic.AddInt64(&conns, 1) - defer atomic.AddInt64(&conns, -1) - + s, closeFn := testServer(t, func(w http.ResponseWriter, r *http.Request) { err := tc.server(w, r) if err != nil { t.Errorf("server failed: %+v", err) return } - })) - defer func() { - s.Close() - - ctx, cancel := context.WithTimeout(context.Background(), time.Minute) - defer cancel() - - for atomic.LoadInt64(&conns) > 0 { - if ctx.Err() != nil { - t.Fatalf("waiting for server to come down timed out: %v", ctx.Err()) - } - } - }() + }) + defer closeFn() wsURL := strings.Replace(s.URL, "http", "ws", 1) @@ -329,6 +315,28 @@ func TestHandshake(t *testing.T) { } } +func testServer(tb testing.TB, fn http.HandlerFunc) (s *httptest.Server, closeFn func()) { + var conns int64 + s = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt64(&conns, 1) + defer atomic.AddInt64(&conns, -1) + + fn.ServeHTTP(w, r) + })) + return s, func() { + s.Close() + + ctx, cancel := context.WithTimeout(context.Background(), time.Minute) + defer cancel() + + for atomic.LoadInt64(&conns) > 0 { + if ctx.Err() != nil { + tb.Fatalf("waiting for server to come down timed out: %v", ctx.Err()) + } + } + } +} + // https://github.com/crossbario/autobahn-python/tree/master/wstest func TestAutobahnServer(t *testing.T) { t.Parallel() @@ -341,7 +349,7 @@ func TestAutobahnServer(t *testing.T) { t.Logf("server handshake failed: %+v", err) return } - echoLoop(r.Context(), c, t) + echoLoop(r.Context(), c) })) defer s.Close() @@ -354,7 +362,7 @@ func TestAutobahnServer(t *testing.T) { }, }, "cases": []string{"*"}, - "exclude-cases": []string{"6.*", "12.*", "13.*"}, + "exclude-cases": []string{"6.*", "7.5.1", "12.*", "13.*"}, } specFile, err := ioutil.TempFile("", "websocket_fuzzingclient.json") if err != nil { @@ -388,11 +396,11 @@ func TestAutobahnServer(t *testing.T) { checkWSTestIndex(t, "./wstest_reports/server/index.json") } -func echoLoop(ctx context.Context, c *websocket.Conn, t *testing.T) { +func echoLoop(ctx context.Context, c *websocket.Conn) { defer c.Close(websocket.StatusInternalError, "") echo := func() error { - ctx, cancel := context.WithTimeout(ctx, time.Second*30) + ctx, cancel := context.WithTimeout(ctx, time.Minute) defer cancel() typ, r, err := c.Read(ctx) @@ -400,7 +408,13 @@ func echoLoop(ctx context.Context, c *websocket.Conn, t *testing.T) { return err } - w := c.Write(ctx, typ) + w, err := c.Write(ctx, typ) + if err != nil { + return err + } + + b1, _ := ioutil.ReadAll(r) + log.Printf("%q", b1) _, err = io.Copy(w, r) if err != nil { @@ -415,11 +429,14 @@ func echoLoop(ctx context.Context, c *websocket.Conn, t *testing.T) { return nil } + var i int for { err := echo() if err != nil { + log.Println("WTF", err, i) return } + i++ } } @@ -431,7 +448,7 @@ func TestAutobahnClient(t *testing.T) { "url": "ws://localhost:9001", "outdir": "wstest_reports/client", "cases": []string{"*"}, - "exclude-cases": []string{"6.*", "12.*", "13.*"}, + "exclude-cases": []string{"6.*", "7.5.1", "12.*", "13.*"}, } specFile, err := ioutil.TempFile("", "websocket_fuzzingserver.json") if err != nil { @@ -507,7 +524,7 @@ func TestAutobahnClient(t *testing.T) { if err != nil { t.Fatalf("failed to dial: %v", err) } - echoLoop(ctx, c, t) + echoLoop(ctx, c) }() } -- GitLab