diff --git a/accept.go b/accept.go index bde56e907db68d460757b5c66cae2b54ed183e9d..2dabdae3ccf21d5bec42890161728dd9b65e8db0 100644 --- a/accept.go +++ b/accept.go @@ -112,7 +112,7 @@ func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn hj, ok := w.(http.Hijacker) if !ok { err = xerrors.New("websocket: response writer does not implement http.Hijacker") - http.Error(w, err.Error(), http.StatusInternalServerError) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return nil, err } @@ -131,7 +131,7 @@ func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn netConn, brw, err := hj.Hijack() if err != nil { err = xerrors.Errorf("websocket: failed to hijack connection: %w", err) - http.Error(w, err.Error(), http.StatusInternalServerError) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return nil, err } diff --git a/example_test.go b/example_test.go index 5c2d2b23444a9140dd67db2644e588ba47947bf7..bee8e9277bfd0e2bf0b9028e37ba14165a6c01f3 100644 --- a/example_test.go +++ b/example_test.go @@ -71,7 +71,7 @@ func ExampleAccept() { log.Printf("server handshake failed: %v", err) return } - defer c.Close(websocket.StatusInternalError, "") + defer c.Close(websocket.StatusInternalError, "") // TODO returning internal is incorect if its a timeout error. jc := websocket.JSONConn{ Conn: c, diff --git a/header.go b/header.go index 8450e1428bcbde5ae2490adbca7101a290dd5704..276fa0c30b93f6120c18d064c96b1c5e05548f1d 100644 --- a/header.go +++ b/header.go @@ -30,8 +30,6 @@ type header struct { maskKey [4]byte } -// TODO bitwise helpers - // bytes returns the bytes of the header. // See https://tools.ietf.org/html/rfc6455#section-5.2 func marshalHeader(h header) []byte { diff --git a/header_test.go b/header_test.go index 65812997231684265b585e0dd854a3e97bd401fb..b4d0769fb2c33044fd841c04041f58a8d8328028 100644 --- a/header_test.go +++ b/header_test.go @@ -20,7 +20,7 @@ func randBool() bool { func TestHeader(t *testing.T) { t.Parallel() - t.Run("negative", func(t *testing.T) { + t.Run("readNegativeLength", func(t *testing.T) { t.Parallel() b := marshalHeader(header{ diff --git a/statuscode.go b/statuscode.go index 596f78bcbfb546afc07b63d9b5b88754eae759ee..2f4f2c0c735c0550621c7317c9f7457fefe882a3 100644 --- a/statuscode.go +++ b/statuscode.go @@ -24,7 +24,10 @@ const ( StatusUnsupportedData _ // 1004 is reserved. StatusNoStatusRcvd - StatusAbnormalClosure + // statusAbnormalClosure is unexported because it isn't necessary, at least until WASM. + // The error returned will indicate whether the connection was closed or not or what happened. + // It only makes sense for browser clients. + statusAbnormalClosure StatusInvalidFramePayloadData StatusPolicyViolation StatusMessageTooBig @@ -33,7 +36,10 @@ const ( StatusServiceRestart StatusTryAgainLater StatusBadGateway - StatusTLSHandshake + // statusTLSHandshake is unexported because we just return + // handshake error in dial. We do not return a conn + // so there is nothing to use this on. At least until WASM. + statusTLSHandshake ) // CloseError represents an error from a WebSocket close frame. @@ -43,68 +49,63 @@ type CloseError struct { Reason string } -func (e CloseError) Error() string { - return fmt.Sprintf("WebSocket closed with status = %v and reason = %q", e.Code, e.Reason) +func (ce CloseError) Error() string { + return fmt.Sprintf("WebSocket closed with status = %v and reason = %q", ce.Code, ce.Reason) } -func parseClosePayload(p []byte) (code StatusCode, reason string, err error) { +func parseClosePayload(p []byte) (CloseError, error) { if len(p) < 2 { - return 0, "", fmt.Errorf("close payload too small, cannot even contain the 2 byte status code") + return CloseError{}, fmt.Errorf("close payload too small, cannot even contain the 2 byte status code") } - code = StatusCode(binary.BigEndian.Uint16(p)) - reason = string(p[2:]) + ce := CloseError{ + Code: StatusCode(binary.BigEndian.Uint16(p)), + Reason: string(p[2:]), + } - if !utf8.ValidString(reason) { - return 0, "", xerrors.Errorf("invalid utf-8: %q", reason) + if !utf8.ValidString(ce.Reason) { + return CloseError{}, xerrors.Errorf("invalid utf-8: %q", ce.Reason) } - if !validCloseCode(code) { - return 0, "", xerrors.Errorf("invalid code %v", code) + if !validWireCloseCode(ce.Code) { + return CloseError{}, xerrors.Errorf("invalid code %v", ce.Code) } - return code, reason, nil + return ce, nil } // See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number // and https://tools.ietf.org/html/rfc6455#section-7.4.1 -var validReceivedCloseCodes = map[StatusCode]bool{ - StatusNormalClosure: true, - StatusGoingAway: true, - StatusProtocolError: true, - StatusUnsupportedData: true, - StatusNoStatusRcvd: false, - // TODO use - StatusAbnormalClosure: false, - StatusInvalidFramePayloadData: true, - StatusPolicyViolation: true, - StatusMessageTooBig: true, - StatusMandatoryExtension: true, - StatusInternalError: true, - StatusServiceRestart: true, - StatusTryAgainLater: true, - StatusTLSHandshake: false, -} +func validWireCloseCode(code StatusCode) bool { + if code >= StatusNormalClosure && code <= statusTLSHandshake { + switch code { + case 1004, StatusNoStatusRcvd, statusAbnormalClosure, statusTLSHandshake: + return false + default: + return true + } + } + if code >= 3000 && code <= 4999 { + return true + } -func validCloseCode(code StatusCode) bool { - return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999) + return false } const maxControlFramePayload = 125 -// TODO make method on CloseError -func closePayload(code StatusCode, reason string) ([]byte, error) { - if len(reason) > maxControlFramePayload-2 { - return nil, xerrors.Errorf("reason string max is %v but got %q with length %v", maxControlFramePayload-2, reason, len(reason)) +func (ce CloseError) bytes() ([]byte, error) { + if len(ce.Reason) > maxControlFramePayload-2 { + return nil, xerrors.Errorf("reason string max is %v but got %q with length %v", maxControlFramePayload-2, ce.Reason, len(ce.Reason)) } - if bits.Len(uint(code)) > 16 { + if bits.Len(uint(ce.Code)) > 16 { return nil, errors.New("status code is larger than 2 bytes") } - if !validCloseCode(code) { - return nil, fmt.Errorf("status code %v cannot be set", code) + if !validWireCloseCode(ce.Code) { + return nil, fmt.Errorf("status code %v cannot be set", ce.Code) } - buf := make([]byte, 2+len(reason)) - binary.BigEndian.PutUint16(buf[:], uint16(code)) - copy(buf[2:], reason) + buf := make([]byte, 2+len(ce.Reason)) + binary.BigEndian.PutUint16(buf[:], uint16(ce.Code)) + copy(buf[2:], ce.Reason) return buf, nil } diff --git a/statuscode_string.go b/statuscode_string.go index fc8cea0d6faa717bb71e2c3ce1f4426d635ffa76..11725e4dcf4a3d63e171511730cfc90904333b56 100644 --- a/statuscode_string.go +++ b/statuscode_string.go @@ -13,7 +13,7 @@ func _() { _ = x[StatusProtocolError-1002] _ = x[StatusUnsupportedData-1003] _ = x[StatusNoStatusRcvd-1005] - _ = x[StatusAbnormalClosure-1006] + _ = x[statusAbnormalClosure-1006] _ = x[StatusInvalidFramePayloadData-1007] _ = x[StatusPolicyViolation-1008] _ = x[StatusMessageTooBig-1009] @@ -22,12 +22,12 @@ func _() { _ = x[StatusServiceRestart-1012] _ = x[StatusTryAgainLater-1013] _ = x[StatusBadGateway-1014] - _ = x[StatusTLSHandshake-1015] + _ = x[statusTLSHandshake-1015] } const ( _StatusCode_name_0 = "StatusNormalClosureStatusGoingAwayStatusProtocolErrorStatusUnsupportedData" - _StatusCode_name_1 = "StatusNoStatusRcvdStatusAbnormalClosureStatusInvalidFramePayloadDataStatusPolicyViolationStatusMessageTooBigStatusMandatoryExtensionStatusInternalErrorStatusServiceRestartStatusTryAgainLaterStatusBadGatewayStatusTLSHandshake" + _StatusCode_name_1 = "StatusNoStatusRcvdstatusAbnormalClosureStatusInvalidFramePayloadDataStatusPolicyViolationStatusMessageTooBigStatusMandatoryExtensionStatusInternalErrorStatusServiceRestartStatusTryAgainLaterStatusBadGatewaystatusTLSHandshake" ) var ( diff --git a/statuscode_test.go b/statuscode_test.go new file mode 100644 index 0000000000000000000000000000000000000000..38ee4c3fdd597a46d0bc7dc763df57badba11301 --- /dev/null +++ b/statuscode_test.go @@ -0,0 +1,57 @@ +package websocket + +import ( + "math" + "strings" + "testing" +) + +func TestCloseError(t *testing.T) { + t.Parallel() + + // Other parts of close error are tested by websocket_test.go right now + // with the autobahn tests. + + testCases := []struct { + name string + ce CloseError + success bool + }{ + { + name: "normal", + ce: CloseError{ + Code: StatusNormalClosure, + Reason: strings.Repeat("x", maxControlFramePayload-2), + }, + success: true, + }, + { + name: "bigReason", + ce: CloseError{ + Code: StatusNormalClosure, + Reason: strings.Repeat("x", maxControlFramePayload-1), + }, + success: false, + }, + { + name: "bigCode", + ce: CloseError{ + Code: math.MaxUint16, + Reason: strings.Repeat("x", maxControlFramePayload-2), + }, + success: false, + }, + } + + for _, tc := range testCases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + _, err := tc.ce.bytes() + if (err == nil) != tc.success { + t.Fatalf("unexpected error value: %v", err) + } + }) + } +} diff --git a/websocket.go b/websocket.go index 09a94e7809ca6c8860093acb4ce46a34bef0e38a..717cc75550066688278e5cdaa05711bb029e44e3 100644 --- a/websocket.go +++ b/websocket.go @@ -23,7 +23,8 @@ type control struct { type Conn struct { subprotocol string br *bufio.Reader - // TODO Cannot use bufio writer because for compression we need to know how much is buffered and compress it if large. + // TODO switch to []byte for write buffering because for messages larger than buffers, there will always be 3 writes. One for the frame, one for the message, one for the fin. + // Also will help for compression. bw *bufio.Writer closer io.Closer client bool @@ -225,12 +226,12 @@ func (c *Conn) handleControl(h header) { case opPong: case opClose: if len(b) > 0 { - code, reason, err := parseClosePayload(b) + ce, err := parseClosePayload(b) if err != nil { c.close(xerrors.Errorf("read invalid close payload: %w", err)) return } - c.Close(code, reason) + c.Close(ce.Code, ce.Reason) } else { c.writeClose(nil, CloseError{ Code: StatusNoStatusRcvd, @@ -279,8 +280,7 @@ func (c *Conn) readLoop() { return } default: - // TODO send back protocol violation message or figure out what RFC wants. - c.close(xerrors.Errorf("unexpected opcode in header: %#v", h)) + c.Close(StatusProtocolError, fmt.Sprintf("unknown opcode %v", h.opcode)) return } @@ -338,18 +338,23 @@ func (c *Conn) writePong(p []byte) error { // Close closes the WebSocket connection with the given status code and reason. // It will write a WebSocket close frame with a timeout of 5 seconds. func (c *Conn) Close(code StatusCode, reason string) error { + ce := CloseError{ + Code: code, + Reason: reason, + } + // This function also will not wait for a close frame from the peer like the RFC // wants because that makes no sense and I don't think anyone actually follows that. // Definitely worth seeing what popular browsers do later. - p, err := closePayload(code, reason) + p, err := ce.bytes() if err != nil { - p, _ = closePayload(StatusInternalError, fmt.Sprintf("websocket: application tried to send code %v but code or reason was invalid", code)) + ce = CloseError{ + Code: StatusInternalError, + } + p, _ = ce.bytes() } - cerr := c.writeClose(p, CloseError{ - Code: code, - Reason: reason, - }) + cerr := c.writeClose(p, ce) if err != nil { return err } diff --git a/websocket_test.go b/websocket_test.go index 14dcc8c51fc1a9eb74f9448d7a33ae546611ff80..2133482fb44403f733dca3834acb5d2071a78c9d 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -7,7 +7,9 @@ import ( "io" "io/ioutil" "net/http" + "net/http/cookiejar" "net/http/httptest" + "net/url" "os" "os/exec" "reflect" @@ -216,6 +218,52 @@ func TestHandshake(t *testing.T) { return nil }, }, + { + name: "cookies", + server: func(w http.ResponseWriter, r *http.Request) error { + cookie, err := r.Cookie("mycookie") + if err != nil { + return xerrors.Errorf("request is missing mycookie: %w", err) + } + if cookie.Value != "myvalue" { + return xerrors.Errorf("expected %q but got %q", "myvalue", cookie.Value) + } + c, err := websocket.Accept(w, r) + if err != nil { + return err + } + c.Close(websocket.StatusInternalError, "") + return nil + }, + client: func(ctx context.Context, u string) error { + jar, err := cookiejar.New(nil) + if err != nil { + return xerrors.Errorf("failed to create cookie jar: %w", err) + } + parsedURL, err := url.Parse(u) + if err != nil { + return xerrors.Errorf("failed to parse url: %w", err) + } + parsedURL.Scheme = "http" + jar.SetCookies(parsedURL, []*http.Cookie{ + { + Name: "mycookie", + Value: "myvalue", + }, + }) + hc := &http.Client{ + Jar: jar, + } + c, _, err := websocket.Dial(ctx, u, + websocket.DialHTTPClient(hc), + ) + if err != nil { + return err + } + c.Close(websocket.StatusInternalError, "") + return nil + }, + }, } for _, tc := range testCases {