From 26f34be85aa8e4a499ed14d9c8052399d2d25a2d Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Fri, 26 Apr 2019 19:49:18 -0400 Subject: [PATCH] Improve docs and fix error wrapping Closes #69 Closes #45 Closes #47 --- README.md | 4 +- accept.go | 8 +-- dial.go | 19 +++---- docs/contributing.md | 5 +- example_echo_test.go | 22 ++++---- example_test.go | 8 ++- statuscode.go | 15 ++---- websocket.go | 124 +++++++++++++++++++++++++++---------------- wsjson/wsjson.go | 3 +- wspb/wspb.go | 3 +- 10 files changed, 120 insertions(+), 91 deletions(-) diff --git a/README.md b/README.md index b5adb80..550304e 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { log.Printf("received: %v", v) - c.Close(websocket.StatusNormalClosure, "success") + c.Close(websocket.StatusNormalClosure, "") }) ``` @@ -77,7 +77,7 @@ if err != nil { // ... } -c.Close(websocket.StatusNormalClosure, "done") +c.Close(websocket.StatusNormalClosure, "") ``` ## Design considerations diff --git a/accept.go b/accept.go index 9cf546f..2cf1dc0 100644 --- a/accept.go +++ b/accept.go @@ -23,7 +23,7 @@ type AcceptOptions struct { // behaviour. By default Accept only allows the handshake to // succeed if the javascript that is initiating the handshake // is on the same domain as the server. This is to prevent CSRF - // when secure data is stored in a cookie as there is no same + // attacks when secure data is stored in a cookie as there is no same // origin policy for WebSockets. In other words, javascript from // any domain can perform a WebSocket dial on an arbitrary server. // This dial will include cookies which means the arbitrary javascript @@ -53,13 +53,13 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { } if r.Method != "GET" { - err := xerrors.Errorf("websocket protocol violation: handshake request method %q is not GET", r.Method) + err := xerrors.Errorf("websocket protocol violation: handshake request method is not GET but %q", r.Method) http.Error(w, err.Error(), http.StatusBadRequest) return err } if r.Header.Get("Sec-WebSocket-Version") != "13" { - err := xerrors.Errorf("unsupported websocket protocol version: %q", r.Header.Get("Sec-WebSocket-Version")) + err := xerrors.Errorf("unsupported websocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version")) http.Error(w, err.Error(), http.StatusBadRequest) return err } @@ -75,7 +75,7 @@ func verifyClientRequest(w http.ResponseWriter, r *http.Request) error { // Accept accepts a WebSocket handshake from a client and upgrades the // the connection to WebSocket. -// Accept will reject the handshake if the Origin is not the same as the Host unless +// Accept will reject the handshake if the Origin domain is not the same as the Host unless // the InsecureSkipVerify option is set. func Accept(w http.ResponseWriter, r *http.Request, opts AcceptOptions) (*Conn, error) { c, err := accept(w, r, opts) diff --git a/dial.go b/dial.go index eee40dd..36c12c4 100644 --- a/dial.go +++ b/dial.go @@ -19,7 +19,7 @@ type DialOptions struct { // HTTPClient is the http client used for the handshake. // Its Transport must use HTTP/1.1 and must return writable bodies // for WebSocket handshakes. This was introduced in Go 1.12. - // http.Transport does this correctly. + // http.Transport does this all correctly. HTTPClient *http.Client // HTTPHeader specifies the HTTP headers included in the handshake request. @@ -35,6 +35,9 @@ type DialOptions struct { var secWebSocketKey = base64.StdEncoding.EncodeToString(make([]byte, 16)) // Dial performs a WebSocket handshake on the given url with the given options. +// The response is the WebSocket handshake response from the server. +// If an error occurs, the returned response may be non nil. However, you can only +// read the first 1024 bytes of its body. func Dial(ctx context.Context, u string, opts DialOptions) (*Conn, *http.Response, error) { c, r, err := dial(ctx, u, opts) if err != nil { @@ -48,7 +51,7 @@ func dial(ctx context.Context, u string, opts DialOptions) (_ *Conn, _ *http.Res opts.HTTPClient = http.DefaultClient } if opts.HTTPClient.Timeout > 0 { - return nil, nil, xerrors.Errorf("please use context for cancellation instead of http.Client.Timeout; see issue nhooyr.io/websocket#67") + return nil, nil, xerrors.Errorf("please use context for cancellation instead of http.Client.Timeout; see https://github.com/nhooyr/websocket/issues/67") } if opts.HTTPHeader == nil { opts.HTTPHeader = http.Header{} @@ -65,7 +68,7 @@ func dial(ctx context.Context, u string, opts DialOptions) (_ *Conn, _ *http.Res case "wss": parsedURL.Scheme = "https" default: - return nil, nil, xerrors.Errorf("unexpected url scheme scheme: %q", parsedURL.Scheme) + return nil, nil, xerrors.Errorf("unexpected url scheme: %q", parsedURL.Scheme) } req, _ := http.NewRequest("GET", parsedURL.String(), nil) @@ -84,13 +87,12 @@ func dial(ctx context.Context, u string, opts DialOptions) (_ *Conn, _ *http.Res return nil, nil, xerrors.Errorf("failed to send handshake request: %w", err) } defer func() { - respBody := resp.Body if err != nil { - // We read a bit of the body for better debugging. + // We read a bit of the body for easier debugging. r := io.LimitReader(resp.Body, 1024) b, _ := ioutil.ReadAll(r) + resp.Body.Close() resp.Body = ioutil.NopCloser(bytes.NewReader(b)) - respBody.Close() } }() @@ -104,7 +106,6 @@ func dial(ctx context.Context, u string, opts DialOptions) (_ *Conn, _ *http.Res return nil, resp, xerrors.Errorf("response body is not a read write closer: %T", rwc) } - // TODO pool bufio c := &Conn{ subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"), br: bufio.NewReader(rwc), @@ -123,11 +124,11 @@ func verifyServerResponse(resp *http.Response) error { } if !headerValuesContainsToken(resp.Header, "Connection", "Upgrade") { - return xerrors.Errorf("websocket protocol violation: Connection header does not contain Upgrade: %q", resp.Header.Get("Connection")) + return xerrors.Errorf("websocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) } if !headerValuesContainsToken(resp.Header, "Upgrade", "WebSocket") { - return xerrors.Errorf("websocket protocol violation: Upgrade header does not contain websocket: %q", resp.Header.Get("Upgrade")) + return xerrors.Errorf("websocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) } // We do not care about Sec-WebSocket-Accept because it does not matter. diff --git a/docs/contributing.md b/docs/contributing.md index 3f267ef..7214915 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -8,10 +8,9 @@ Please be as descriptive as possible with your description. Please split up changes into several small descriptive commits. -Please capitalize the first word in the commit message and ensure it is -descriptive. +Please capitalize the first word in the commit message title. -The commit message should use the verb tense + phrase that completes the blank in +The commit message title should use the verb tense + phrase that completes the blank in > This change modifies websocket to ___________ diff --git a/example_echo_test.go b/example_echo_test.go index 358f5a2..a90257b 100644 --- a/example_echo_test.go +++ b/example_echo_test.go @@ -16,7 +16,7 @@ import ( "nhooyr.io/websocket/wsjson" ) -// main starts a WebSocket echo server and +// This example starts a WebSocket echo server and // then dials the server and sends 5 different messages // and prints out the server's responses. func Example_echo() { @@ -26,7 +26,6 @@ func Example_echo() { l, err := net.Listen("tcp", "localhost:0") if err != nil { log.Fatalf("failed to listen: %v", err) - return } defer l.Close() @@ -55,7 +54,6 @@ func Example_echo() { if err != nil { log.Fatalf("client failed: %v", err) } - // Output: // received: map[i:0] // received: map[i:1] @@ -76,16 +74,14 @@ func echoServer(w http.ResponseWriter, r *http.Request) error { } defer c.Close(websocket.StatusInternalError, "the sky is falling") - if c.Subprotocol() == "" { - c.Close(websocket.StatusPolicyViolation, "cannot communicate with the default protocol") + if c.Subprotocol() != "echo" { + c.Close(websocket.StatusPolicyViolation, "client must speak the echo subprotocol") return xerrors.Errorf("client does not speak echo sub protocol") } - ctx := r.Context() l := rate.NewLimiter(rate.Every(time.Millisecond*100), 10) - for { - err = echo(ctx, c, l) + err = echo(r.Context(), c, l) if err != nil { return xerrors.Errorf("failed to echo: %w", err) } @@ -94,10 +90,10 @@ func echoServer(w http.ResponseWriter, r *http.Request) error { // echo reads from the websocket connection and then writes // the received message back to it. -// It only waits 1 minute to read and write the message and -// limits the received message to 32768 bytes. +// The entire function has 10s to complete. +// The received message is limited to 32768 bytes. func echo(ctx context.Context, c *websocket.Conn, l *rate.Limiter) error { - ctx, cancel := context.WithTimeout(ctx, time.Minute) + ctx, cancel := context.WithTimeout(ctx, time.Second*10) defer cancel() err := l.Wait(ctx) @@ -118,7 +114,7 @@ func echo(ctx context.Context, c *websocket.Conn, l *rate.Limiter) error { _, err = io.Copy(w, r) if err != nil { - return err + return xerrors.Errorf("failed to io.Copy: %w", err) } err = w.Close() @@ -157,6 +153,6 @@ func client(url string) error { fmt.Printf("received: %v\n", v) } - c.Close(websocket.StatusNormalClosure, "done") + c.Close(websocket.StatusNormalClosure, "") return nil } diff --git a/example_test.go b/example_test.go index f5c92bb..7a0528c 100644 --- a/example_test.go +++ b/example_test.go @@ -10,6 +10,8 @@ import ( "nhooyr.io/websocket/wsjson" ) +// This example accepts a WebSocket connection, reads a single JSON +// message from the client and then closes the connection. func ExampleAccept() { fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { c, err := websocket.Accept(w, r, websocket.AcceptOptions{}) @@ -31,12 +33,14 @@ func ExampleAccept() { log.Printf("received: %v", v) - c.Close(websocket.StatusNormalClosure, "success") + c.Close(websocket.StatusNormalClosure, "") }) http.ListenAndServe("localhost:8080", fn) } +// This example dials a server, writes a single JSON message and then +// closes the connection. func ExampleDial() { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() @@ -54,5 +58,5 @@ func ExampleDial() { return } - c.Close(websocket.StatusNormalClosure, "done") + c.Close(websocket.StatusNormalClosure, "") } diff --git a/statuscode.go b/statuscode.go index 7ac424e..69b015c 100644 --- a/statuscode.go +++ b/statuscode.go @@ -2,9 +2,7 @@ package websocket import ( "encoding/binary" - "errors" "fmt" - "math/bits" "golang.org/x/xerrors" ) @@ -50,7 +48,7 @@ type CloseError struct { } func (ce CloseError) Error() string { - return fmt.Sprintf("WebSocket closed with status = %v and reason = %q", ce.Code, ce.Reason) + return fmt.Sprintf("websocket closed with status = %v and reason = %q", ce.Code, ce.Reason) } func parseClosePayload(p []byte) (CloseError, error) { @@ -61,7 +59,7 @@ func parseClosePayload(p []byte) (CloseError, error) { } if len(p) < 2 { - return CloseError{}, fmt.Errorf("close payload too small, cannot even contain the 2 byte status code") + return CloseError{}, xerrors.Errorf("close payload too small, cannot even contain the 2 byte status code") } ce := CloseError{ @@ -70,7 +68,7 @@ func parseClosePayload(p []byte) (CloseError, error) { } if !validWireCloseCode(ce.Code) { - return CloseError{}, xerrors.Errorf("invalid code %v", ce.Code) + return CloseError{}, xerrors.Errorf("invalid status code %v", ce.Code) } return ce, nil @@ -100,15 +98,12 @@ func (ce CloseError) bytes() ([]byte, error) { if len(ce.Reason) > maxControlFramePayload-2 { return nil, xerrors.Errorf("reason string max is %v but got %q with length %v", maxControlFramePayload-2, ce.Reason, len(ce.Reason)) } - if bits.Len(uint(ce.Code)) > 16 { - return nil, errors.New("status code is larger than 2 bytes") - } if !validWireCloseCode(ce.Code) { - return nil, fmt.Errorf("status code %v cannot be set", ce.Code) + return nil, xerrors.Errorf("status code %v cannot be set", ce.Code) } buf := make([]byte, 2+len(ce.Reason)) - binary.BigEndian.PutUint16(buf[:], uint16(ce.Code)) + binary.BigEndian.PutUint16(buf, uint16(ce.Code)) copy(buf[2:], ce.Reason) return buf, nil } diff --git a/websocket.go b/websocket.go index 2f324d3..c5e3bf5 100644 --- a/websocket.go +++ b/websocket.go @@ -60,18 +60,18 @@ type Conn struct { } func (c *Conn) close(err error) { - err = xerrors.Errorf("connection broken: %w", err) + err = xerrors.Errorf("websocket closed: %w", err) c.closeOnce.Do(func() { runtime.SetFinalizer(c, nil) - c.closeErr = err - cerr := c.closer.Close() - if c.closeErr == nil { - c.closeErr = cerr + if err != nil { + cerr = err } + c.closeErr = cerr + close(c.closed) }) } @@ -126,6 +126,20 @@ func (c *Conn) writeFrame(h header, p []byte) { } } +func (c *Conn) writeLoopControl(control control) { + h := header{ + fin: true, + opcode: control.opcode, + payloadLength: int64(len(control.payload)), + masked: c.client, + } + c.writeFrame(h, control.payload) + select { + case <-c.closed: + case c.writeDone <- struct{}{}: + } +} + func (c *Conn) writeLoop() { defer close(c.writeDone) @@ -136,19 +150,7 @@ messageLoop: case <-c.closed: return case control := <-c.control: - h := header{ - fin: true, - opcode: control.opcode, - payloadLength: int64(len(control.payload)), - masked: c.client, - } - c.writeFrame(h, control.payload) - select { - case <-c.closed: - return - case c.writeDone <- struct{}{}: - continue - } + c.writeLoopControl(control) case dataType = <-c.write: } @@ -158,19 +160,7 @@ messageLoop: case <-c.closed: return case control := <-c.control: - h := header{ - fin: true, - opcode: control.opcode, - payloadLength: int64(len(control.payload)), - masked: c.client, - } - c.writeFrame(h, control.payload) - select { - case <-c.closed: - return - case c.writeDone <- struct{}{}: - continue - } + c.writeLoopControl(control) case b := <-c.writeBytes: h := header{ fin: false, @@ -190,7 +180,6 @@ messageLoop: case <-c.closed: return case c.writeDone <- struct{}{}: - continue } case <-c.writeFlush: h := header{ @@ -265,7 +254,7 @@ func (c *Conn) readLoop() { } if h.rsv1 || h.rsv2 || h.rsv3 { - c.Close(StatusProtocolError, fmt.Sprintf("read header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)) + c.Close(StatusProtocolError, fmt.Sprintf("received header with rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)) return } @@ -277,7 +266,7 @@ func (c *Conn) readLoop() { switch h.opcode { case opBinary, opText: if c.inMsg { - c.Close(StatusProtocolError, "cannot read data frame when previous frame is not finished") + c.Close(StatusProtocolError, "cannot read new data frame when previous frame is not finished") return } @@ -360,6 +349,14 @@ func (c *Conn) writePong(p []byte) error { // It will write a WebSocket close frame with a timeout of 5 seconds. // Concurrent calls to Close are ok. func (c *Conn) Close(code StatusCode, reason string) error { + err := c.exportedClose(code, reason) + if err != nil { + return xerrors.Errorf("failed to close connection: %w", err) + } + return nil +} + +func (c *Conn) exportedClose(code StatusCode, reason string) error { ce := CloseError{ Code: code, Reason: reason, @@ -411,8 +408,9 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error payload: p, }: case <-ctx.Done(): - c.close(xerrors.New("force closed: close frame write timed out")) - return c.closeErr + err := xerrors.Errorf("control frame write timed out: %w", ctx.Err()) + c.close(err) + return err } select { @@ -427,11 +425,20 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error // Writer returns a writer bounded by the context that will write // a WebSocket message of type dataType to the connection. +// // Ensure you close the writer once you have written the entire message. // Concurrent calls to Writer are ok. -// Writer will block if there is another goroutine with an open writer -// until writer is closed. +// Only one writer can be open at a time so Writer will block if there is +// another goroutine with an open writer until that writer is closed. func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { + wc, err := c.writer(ctx, typ) + if err != nil { + return nil, xerrors.Errorf("failed to get writer: %w", err) + } + return wc, nil +} + +func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { select { case <-c.closed: return nil, c.closeErr @@ -453,14 +460,23 @@ type messageWriter struct { // Write writes the given bytes to the WebSocket connection. func (w messageWriter) Write(p []byte) (int, error) { + n, err := w.Write(p) + if err != nil { + return n, xerrors.Errorf("failed to write: %w", err) + } + return n, nil +} + +func (w messageWriter) write(p []byte) (int, error) { select { case <-w.c.closed: return 0, w.c.closeErr case w.c.writeBytes <- p: select { case <-w.ctx.Done(): - w.c.close(xerrors.Errorf("write timed out: %w", w.ctx.Err())) - <-w.c.readDone + w.c.close(xerrors.Errorf("data write timed out: %w", w.ctx.Err())) + // Wait for writeLoop to complete so we know p is done. + <-w.c.writeDone return 0, w.ctx.Err() case _, ok := <-w.c.writeDone: if !ok { @@ -476,6 +492,14 @@ func (w messageWriter) Write(p []byte) (int, error) { // Close flushes the frame to the connection. // This must be called for every messageWriter. func (w messageWriter) Close() error { + err := w.close() + if err != nil { + return xerrors.Errorf("failed to close writer: %w", err) + } + return nil +} + +func (w messageWriter) close() error { select { case <-w.c.closed: return w.c.closeErr @@ -498,6 +522,14 @@ func (w messageWriter) Close() error { // You can only read a single message at a time so do not call this method // concurrently. func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { + typ, r, err := c.Reader(ctx) + if err != nil { + return 0, nil, xerrors.Errorf("failed to get reader: %w", err) + } + return typ, r, nil +} + +func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { for !atomic.CompareAndSwapInt64(&c.activeReader, 0, 1) { select { case <-c.closed: @@ -511,7 +543,7 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { return 0, nil, c.closeErr } if atomic.LoadInt64(&c.activeReader) == 1 { - return 0, nil, xerrors.New("websocket: previous message not fully read") + return 0, nil, xerrors.New("previous message not fully read") } } case <-ctx.Done(): @@ -521,14 +553,14 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { select { case <-c.closed: - return 0, nil, xerrors.Errorf("websocket: failed to read message: %w", c.closeErr) + return 0, nil, c.closeErr case opcode := <-c.read: return MessageType(opcode), messageReader{ ctx: ctx, c: c, }, nil case <-ctx.Done(): - return 0, nil, xerrors.Errorf("websocket: failed to read message: %w", ctx.Err()) + return 0, nil, ctx.Err() } } @@ -546,7 +578,7 @@ func (r messageReader) Read(p []byte) (int, error) { if err == io.EOF { return n, io.EOF } - return n, xerrors.Errorf("websocket: failed to read: %w", err) + return n, xerrors.Errorf("failed to read: %w", err) } return n, nil } @@ -562,8 +594,8 @@ func (r messageReader) read(p []byte) (_ int, err error) { case r.c.readBytes <- p: select { case <-r.ctx.Done(): - r.c.close(xerrors.Errorf("read timed out: %w", r.ctx.Err())) - // Wait for readloop to complete so we know p is done. + r.c.close(xerrors.Errorf("data read timed out: %w", r.ctx.Err())) + // Wait for readLoop to complete so we know p is done. <-r.c.readDone return 0, r.ctx.Err() case n, ok := <-r.c.readDone: diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go index df67cf9..4d315d1 100644 --- a/wsjson/wsjson.go +++ b/wsjson/wsjson.go @@ -12,7 +12,8 @@ import ( ) // Read reads a json message from c into v. -// It will read a message up to 32768 bytes in length. +// For security reasons, it will not read messages +// larger than 32768 bytes. func Read(ctx context.Context, c *websocket.Conn, v interface{}) error { err := read(ctx, c, v) if err != nil { diff --git a/wspb/wspb.go b/wspb/wspb.go index 159e92d..953128e 100644 --- a/wspb/wspb.go +++ b/wspb/wspb.go @@ -13,7 +13,8 @@ import ( ) // Read reads a protobuf message from c into v. -// It will read a message up to 32768 bytes in length. +// For security reasons, it will not read messages +// larger than 32768 bytes. func Read(ctx context.Context, c *websocket.Conn, v proto.Message) error { err := read(ctx, c, v) if err != nil { -- GitLab