From b39ca873380498fd7ac2bb1d9fa221404cf90da8 Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Wed, 29 May 2019 23:21:55 -0400 Subject: [PATCH] Fix bugs and improve docs --- README.md | 6 +-- accept.go | 9 +++- dial.go | 43 ++++++++++++++++-- example_echo_test.go | 5 ++- go.mod | 2 +- go.sum | 2 + statuscode.go | 2 +- websocket.go | 101 +++++++++++++++++++++++-------------------- websocket_test.go | 14 +++--- wsjson/wsjson.go | 14 +++++- 10 files changed, 128 insertions(+), 70 deletions(-) diff --git a/README.md b/README.md index 3f42742..1b9af61 100644 --- a/README.md +++ b/README.md @@ -22,7 +22,7 @@ go get nhooyr.io/websocket@v0.2.0 - Zero dependencies outside of the stdlib for the core library - JSON and ProtoBuf helpers in the wsjson and wspb subpackages - High performance -- Concurrent writes +- Concurrent reads and writes out of the box ## Roadmap @@ -122,8 +122,8 @@ also uses net/http's Client and ResponseWriter directly for WebSocket handshakes gorilla/websocket writes its handshakes to the underlying net.Conn which means it has to reinvent hooks for TLS and proxies and prevents support of HTTP/2. -Some more advantages of nhooyr/websocket are that it supports concurrent writes and makes it -very easy to close the connection with a status code and reason. +Some more advantages of nhooyr/websocket are that it supports concurrent reads, +writes and makes it very easy to close the connection with a status code and reason. In terms of performance, there is no significant difference between the two. Will update with benchmarks soon ([#75](https://github.com/nhooyr/websocket/issues/75)). diff --git a/accept.go b/accept.go index 207ecc7..17016d2 100644 --- a/accept.go +++ b/accept.go @@ -1,8 +1,10 @@ package websocket import ( + "bytes" "crypto/sha1" "encoding/base64" + "io" "net/http" "net/textproto" "net/url" @@ -78,6 +80,9 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { // // Accept will reject the handshake if the Origin domain is not the same as the Host unless // the InsecureSkipVerify option is set. +// +// The returned connection will be bound by r.Context(). Use c.Context() to change +// the bounding context. func Accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, error) { c, err := accept(w, r, opts) if err != nil { @@ -126,6 +131,9 @@ func accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, return nil, err } + b, _ := brw.Reader.Peek(brw.Reader.Buffered()) + brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn)) + c := &Conn{ subprotocol: w.Header().Get("Sec-WebSocket-Protocol"), br: brw.Reader, @@ -133,7 +141,6 @@ func accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, closer: netConn, } c.init() - // TODO document. c.Context(r.Context()) return c, nil diff --git a/dial.go b/dial.go index 3c7e71d..f1ad725 100644 --- a/dial.go +++ b/dial.go @@ -5,13 +5,13 @@ import ( "bytes" "context" "encoding/base64" + "golang.org/x/xerrors" "io" "io/ioutil" "net/http" "net/url" "strings" - - "golang.org/x/xerrors" + "sync" ) // DialOptions represents the options available to pass to Dial. @@ -112,8 +112,8 @@ func dial(ctx context.Context, u string, opts DialOptions) (_ *Conn, _ *http.Res c := &Conn{ subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"), - br: bufio.NewReader(rwc), - bw: bufio.NewWriter(rwc), + br: getBufioReader(rwc), + bw: getBufioWriter(rwc), closer: rwc, client: true, } @@ -140,3 +140,38 @@ func verifyServerResponse(resp *http.Response) error { return nil } + +// The below pools can only be used by the client because http.Hijacker will always +// have a bufio.Reader/Writer for us so it doesn't make sense to use a pool on top. + +var bufioReaderPool = sync.Pool{ + New: func() interface{} { + return bufio.NewReader(nil) + }, +} + +func getBufioReader(r io.Reader) *bufio.Reader { + br := bufioReaderPool.Get().(*bufio.Reader) + br.Reset(r) + return br +} + +func returnBufioReader(br *bufio.Reader) { + bufioReaderPool.Put(br) +} + +var bufioWriterPool = sync.Pool{ + New: func() interface{} { + return bufio.NewWriter(nil) + }, +} + +func getBufioWriter(w io.Writer) *bufio.Writer { + bw := bufioWriterPool.Get().(*bufio.Writer) + bw.Reset(w) + return bw +} + +func returnBufioWriter(bw *bufio.Writer) { + bufioWriterPool.Put(bw) +} diff --git a/example_echo_test.go b/example_echo_test.go index ab0e8e7..a86d5b8 100644 --- a/example_echo_test.go +++ b/example_echo_test.go @@ -51,6 +51,7 @@ func Example_echo() { // Now we dial the server, send the messages and echo the responses. err = client("ws://" + l.Addr().String()) + time.Sleep(time.Second) if err != nil { log.Fatalf("client failed: %v", err) } @@ -66,6 +67,8 @@ func Example_echo() { // It ensures the client speaks the echo subprotocol and // only allows one message every 100ms with a 10 message burst. func echoServer(w http.ResponseWriter, r *http.Request) error { + log.Printf("serving %v", r.RemoteAddr) + c, err := websocket.Accept(w, r, websocket.AcceptOptions{ Subprotocols: []string{"echo"}, }) @@ -83,7 +86,7 @@ func echoServer(w http.ResponseWriter, r *http.Request) error { for { err = echo(r.Context(), c, l) if err != nil { - return xerrors.Errorf("failed to echo: %w", err) + return xerrors.Errorf("failed to echo with %v: %w", r.RemoteAddr, err) } } } diff --git a/go.mod b/go.mod index f747eec..cc9a865 100644 --- a/go.mod +++ b/go.mod @@ -12,6 +12,6 @@ require ( golang.org/x/text v0.3.2 // indirect golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 golang.org/x/tools v0.0.0-20190429184909-35c670923e21 - golang.org/x/xerrors v0.0.0-20190315151331-d61658bd2e18 + golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 mvdan.cc/sh v2.6.4+incompatible ) diff --git a/go.sum b/go.sum index 63aaa2a..90c9346 100644 --- a/go.sum +++ b/go.sum @@ -30,5 +30,7 @@ golang.org/x/tools v0.0.0-20190429184909-35c670923e21 h1:Kjcw+D2LTzLmxOHrMK9uvYP golang.org/x/tools v0.0.0-20190429184909-35c670923e21/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/xerrors v0.0.0-20190315151331-d61658bd2e18 h1:1AGvnywFL1aB5KLRxyLseWJI6aSYPo3oF7HSpXdWQdU= golang.org/x/xerrors v0.0.0-20190315151331-d61658bd2e18/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522 h1:bhOzK9QyoD0ogCnFro1m2mz41+Ib0oOhfJnBp5MR4K4= +golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= mvdan.cc/sh v2.6.4+incompatible h1:eD6tDeh0pw+/TOTI1BBEryZ02rD2nMcFsgcvde7jffM= mvdan.cc/sh v2.6.4+incompatible/go.mod h1:IeeQbZq+x2SUGBensq/jge5lLQbS3XT2ktyp3wrt4x8= diff --git a/statuscode.go b/statuscode.go index d422374..c7b2036 100644 --- a/statuscode.go +++ b/statuscode.go @@ -49,7 +49,7 @@ type CloseError struct { } func (ce CloseError) Error() string { - return fmt.Sprintf("websocket closed with status = %v and reason = %q", ce.Code, ce.Reason) + return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason) } func parseClosePayload(p []byte) (CloseError, error) { diff --git a/websocket.go b/websocket.go index 6e35281..0b77966 100644 --- a/websocket.go +++ b/websocket.go @@ -14,7 +14,8 @@ import ( ) // Conn represents a WebSocket connection. -// All methods except Reader can be used concurrently. +// All methods may be called concurrently. +// // Please be sure to call Close on the connection when you // are finished with it to release resources. type Conn struct { @@ -31,8 +32,10 @@ type Conn struct { writeDataLock chan struct{} writeFrameLock chan struct{} - readData chan header - readDone chan struct{} + readDataLock chan struct{} + readData chan header + readDone chan struct{} + readLoopDone chan struct{} setReadTimeout chan context.Context setWriteTimeout chan context.Context @@ -44,7 +47,7 @@ type Conn struct { // when the connection is closed. // If the parent context is cancelled, the connection will be closed. // -// This is an experimental API meaning it may be remove in the future. +// This is an experimental API that may be remove in the future. // Please let me know how you feel about it. func (c *Conn) Context(parent context.Context) context.Context { select { @@ -77,6 +80,18 @@ func (c *Conn) close(err error) { c.closeErr = xerrors.Errorf("websocket closed: %w", cerr) close(c.closed) + + // See comment in dial.go + if c.client { + go func() { + <-c.readLoopDone + c.readDataLock <- struct{}{} + c.writeFrameLock <- struct{}{} + + returnBufioReader(c.br) + returnBufioWriter(c.bw) + }() + } }) } @@ -94,6 +109,8 @@ func (c *Conn) init() { c.readData = make(chan header) c.readDone = make(chan struct{}) + c.readDataLock = make(chan struct{}, 1) + c.readLoopDone = make(chan struct{}) c.setReadTimeout = make(chan context.Context) c.setWriteTimeout = make(chan context.Context) @@ -174,8 +191,8 @@ func (c *Conn) timeoutLoop() { select { case <-c.closed: return - case readCtx = <-c.setWriteTimeout: - case writeCtx = <-c.setReadTimeout: + case writeCtx = <-c.setWriteTimeout: + case readCtx = <-c.setReadTimeout: case <-readCtx.Done(): c.close(xerrors.Errorf("data read timed out: %w", readCtx.Err())) case <-writeCtx.Done(): @@ -276,6 +293,8 @@ func (c *Conn) readTillData() (header, error) { } func (c *Conn) readLoop() { + defer close(c.readLoopDone) + for { h, err := c.readTillData() if err != nil { @@ -487,8 +506,7 @@ func (w *messageWriter) close() error { // // Your application must keep reading messages for the Conn to automatically respond to ping // and close frames and not become stuck waiting for a data message to be read. -// Please ensure to read the full message from io.Reader. If you do not read till -// io.EOF, the connection will break unless the next read would have yielded io.EOF. +// Please ensure to read the full message from io.Reader. // // You can only read a single message at a time so do not call this method // concurrently. @@ -500,30 +518,10 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { return typ, r, nil } -func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { - // if !atomic.CompareAndSwapInt64(&c.activeReader, 0, 1) { - // // If the next read yields io.EOF we are good to go. - // r := messageReader{ - // ctx: ctx, - // c: c, - // } - // _, err := r.Read(nil) - // if err == nil { - // return 0, nil, xerrors.New("previous message not fully read") - // } - // if !xerrors.Is(err, io.EOF) { - // return 0, nil, xerrors.Errorf("failed to check if last message at io.EOF: %w", err) - // } - // - // atomic.StoreInt64(&c.activeReader, 1) - // } - - select { - case <-c.closed: - return 0, nil, c.closeErr - case <-ctx.Done(): - return 0, nil, ctx.Err() - case c.setReadTimeout <- ctx: +func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) { + err = c.acquireLock(ctx, c.readDataLock) + if err != nil { + return 0, nil, err } select { @@ -533,25 +531,24 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { return 0, nil, ctx.Err() case h := <-c.readData: if h.opcode == opContinuation { - if h.fin && h.payloadLength == 0 { - select { - case <-c.closed: - return 0, nil, c.closeErr - case c.readDone <- struct{}{}: - return c.reader(ctx) - } + ce := CloseError{ + Code: StatusProtocolError, + Reason: "continuation frame not after data or text frame", } - return 0, nil, xerrors.Errorf("previous reader was not read to EOF") + c.Close(ce.Code, ce.Reason) + return 0, nil, ce } return MessageType(h.opcode), &messageReader{ - h: &h, - c: c, + ctx: ctx, + h: &h, + c: c, }, nil } } // messageReader enables reading a data frame from the WebSocket connection. type messageReader struct { + ctx context.Context maskPos int h *header c *Conn @@ -598,8 +595,20 @@ func (r *messageReader) read(p []byte) (int, error) { p = p[:r.h.payloadLength] } + select { + case <-r.c.closed: + return 0, r.c.closeErr + case r.c.setReadTimeout <- r.ctx: + } + n, err := io.ReadFull(r.c.br, p) + select { + case <-r.c.closed: + return 0, r.c.closeErr + case r.c.setReadTimeout <- context.Background(): + } + r.h.payloadLength -= int64(n) if r.h.masked { r.maskPos = fastXOR(r.h.maskKey, r.maskPos, p) @@ -618,12 +627,8 @@ func (r *messageReader) read(p []byte) (int, error) { } if r.h.fin { r.eofed = true - select { - case <-r.c.closed: - return n, r.c.closeErr - case r.c.setReadTimeout <- context.Background(): - return n, io.EOF - } + r.c.releaseLock(r.c.readDataLock) + return n, io.EOF } r.maskPos = 0 r.h = nil diff --git a/websocket_test.go b/websocket_test.go index 8d18c73..0ac0557 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -293,10 +293,6 @@ func TestHandshake(t *testing.T) { if err != nil { return err } - err = write() - if err != nil { - return err - } c.Close(websocket.StatusNormalClosure, "") return nil @@ -329,11 +325,6 @@ func TestHandshake(t *testing.T) { if err != nil { return err } - // Read twice to ensure the un EOFed previous reader works correctly. - err = read() - if err != nil { - return err - } c.Close(websocket.StatusNormalClosure, "") return nil @@ -766,6 +757,11 @@ func benchConn(b *testing.B, echo, stream bool, size int) { if err != nil { b.Fatal(err) } + + _, err = r.Read(nil) + if !xerrors.Is(err, io.EOF) { + b.Fatalf("more data in reader than needed") + } } } b.StopTimer() diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go index 9dd61bd..853369e 100644 --- a/wsjson/wsjson.go +++ b/wsjson/wsjson.go @@ -4,9 +4,8 @@ package wsjson import ( "context" "encoding/json" - "io" - "golang.org/x/xerrors" + "io" "nhooyr.io/websocket" ) @@ -41,6 +40,17 @@ func read(ctx context.Context, c *websocket.Conn, v interface{}) error { return xerrors.Errorf("failed to decode json: %w", err) } + // Have to ensure we read till EOF. + // Unfortunate but necessary evil for now. Can improve later. + // The code to do this automatically gets complicated fast because + // we support concurrent reading. + // So the Reader has to synchronize with Read somehow. + // Maybe its best to bring back the old readLoop? + _, err = r.Read(nil) + if !xerrors.Is(err, io.EOF) { + return xerrors.Errorf("more data than needed in reader") + } + return nil } -- GitLab