From 519e970aad9fa123333c3f2597e806a7b8ff7300 Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Sat, 13 Apr 2019 17:12:03 -0500
Subject: [PATCH] Cookie and CloseError unit tests

---
 accept.go            |  4 +--
 example_test.go      |  2 +-
 header.go            |  2 --
 header_test.go       |  2 +-
 statuscode.go        | 85 ++++++++++++++++++++++----------------------
 statuscode_string.go |  6 ++--
 statuscode_test.go   | 57 +++++++++++++++++++++++++++++
 websocket.go         | 27 ++++++++------
 websocket_test.go    | 48 +++++++++++++++++++++++++
 9 files changed, 171 insertions(+), 62 deletions(-)
 create mode 100644 statuscode_test.go

diff --git a/accept.go b/accept.go
index bde56e9..2dabdae 100644
--- a/accept.go
+++ b/accept.go
@@ -112,7 +112,7 @@ func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn
 	hj, ok := w.(http.Hijacker)
 	if !ok {
 		err = xerrors.New("websocket: response writer does not implement http.Hijacker")
-		http.Error(w, err.Error(), http.StatusInternalServerError)
+		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
 		return nil, err
 	}
 
@@ -131,7 +131,7 @@ func Accept(w http.ResponseWriter, r *http.Request, opts ...AcceptOption) (*Conn
 	netConn, brw, err := hj.Hijack()
 	if err != nil {
 		err = xerrors.Errorf("websocket: failed to hijack connection: %w", err)
-		http.Error(w, err.Error(), http.StatusInternalServerError)
+		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
 		return nil, err
 	}
 
diff --git a/example_test.go b/example_test.go
index 5c2d2b2..bee8e92 100644
--- a/example_test.go
+++ b/example_test.go
@@ -71,7 +71,7 @@ func ExampleAccept() {
 			log.Printf("server handshake failed: %v", err)
 			return
 		}
-		defer c.Close(websocket.StatusInternalError, "")
+		defer c.Close(websocket.StatusInternalError, "") // TODO returning internal is incorect if its a timeout error.
 
 		jc := websocket.JSONConn{
 			Conn: c,
diff --git a/header.go b/header.go
index 8450e14..276fa0c 100644
--- a/header.go
+++ b/header.go
@@ -30,8 +30,6 @@ type header struct {
 	maskKey [4]byte
 }
 
-// TODO bitwise helpers
-
 // bytes returns the bytes of the header.
 // See https://tools.ietf.org/html/rfc6455#section-5.2
 func marshalHeader(h header) []byte {
diff --git a/header_test.go b/header_test.go
index 6581299..b4d0769 100644
--- a/header_test.go
+++ b/header_test.go
@@ -20,7 +20,7 @@ func randBool() bool {
 func TestHeader(t *testing.T) {
 	t.Parallel()
 
-	t.Run("negative", func(t *testing.T) {
+	t.Run("readNegativeLength", func(t *testing.T) {
 		t.Parallel()
 
 		b := marshalHeader(header{
diff --git a/statuscode.go b/statuscode.go
index 596f78b..2f4f2c0 100644
--- a/statuscode.go
+++ b/statuscode.go
@@ -24,7 +24,10 @@ const (
 	StatusUnsupportedData
 	_ // 1004 is reserved.
 	StatusNoStatusRcvd
-	StatusAbnormalClosure
+	// statusAbnormalClosure is unexported because it isn't necessary, at least until WASM.
+	// The error returned will indicate whether the connection was closed or not or what happened.
+	// It only makes sense for browser clients.
+	statusAbnormalClosure
 	StatusInvalidFramePayloadData
 	StatusPolicyViolation
 	StatusMessageTooBig
@@ -33,7 +36,10 @@ const (
 	StatusServiceRestart
 	StatusTryAgainLater
 	StatusBadGateway
-	StatusTLSHandshake
+	// statusTLSHandshake is unexported because we just return
+	// handshake error in dial. We do not return a conn
+	// so there is nothing to use this on. At least until WASM.
+	statusTLSHandshake
 )
 
 // CloseError represents an error from a WebSocket close frame.
@@ -43,68 +49,63 @@ type CloseError struct {
 	Reason string
 }
 
-func (e CloseError) Error() string {
-	return fmt.Sprintf("WebSocket closed with status = %v and reason = %q", e.Code, e.Reason)
+func (ce CloseError) Error() string {
+	return fmt.Sprintf("WebSocket closed with status = %v and reason = %q", ce.Code, ce.Reason)
 }
 
-func parseClosePayload(p []byte) (code StatusCode, reason string, err error) {
+func parseClosePayload(p []byte) (CloseError, error) {
 	if len(p) < 2 {
-		return 0, "", fmt.Errorf("close payload too small, cannot even contain the 2 byte status code")
+		return CloseError{}, fmt.Errorf("close payload too small, cannot even contain the 2 byte status code")
 	}
 
-	code = StatusCode(binary.BigEndian.Uint16(p))
-	reason = string(p[2:])
+	ce := CloseError{
+		Code:   StatusCode(binary.BigEndian.Uint16(p)),
+		Reason: string(p[2:]),
+	}
 
-	if !utf8.ValidString(reason) {
-		return 0, "", xerrors.Errorf("invalid utf-8: %q", reason)
+	if !utf8.ValidString(ce.Reason) {
+		return CloseError{}, xerrors.Errorf("invalid utf-8: %q", ce.Reason)
 	}
-	if !validCloseCode(code) {
-		return 0, "", xerrors.Errorf("invalid code %v", code)
+	if !validWireCloseCode(ce.Code) {
+		return CloseError{}, xerrors.Errorf("invalid code %v", ce.Code)
 	}
 
-	return code, reason, nil
+	return ce, nil
 }
 
 // See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
 // and https://tools.ietf.org/html/rfc6455#section-7.4.1
-var validReceivedCloseCodes = map[StatusCode]bool{
-	StatusNormalClosure:   true,
-	StatusGoingAway:       true,
-	StatusProtocolError:   true,
-	StatusUnsupportedData: true,
-	StatusNoStatusRcvd:    false,
-	// TODO use
-	StatusAbnormalClosure:         false,
-	StatusInvalidFramePayloadData: true,
-	StatusPolicyViolation:         true,
-	StatusMessageTooBig:           true,
-	StatusMandatoryExtension:      true,
-	StatusInternalError:           true,
-	StatusServiceRestart:          true,
-	StatusTryAgainLater:           true,
-	StatusTLSHandshake:            false,
-}
+func validWireCloseCode(code StatusCode) bool {
+	if code >= StatusNormalClosure && code <= statusTLSHandshake {
+		switch code {
+		case 1004, StatusNoStatusRcvd, statusAbnormalClosure, statusTLSHandshake:
+			return false
+		default:
+			return true
+		}
+	}
+	if code >= 3000 && code <= 4999 {
+		return true
+	}
 
-func validCloseCode(code StatusCode) bool {
-	return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999)
+	return false
 }
 
 const maxControlFramePayload = 125
 
-// TODO make method on CloseError
-func closePayload(code StatusCode, reason string) ([]byte, error) {
-	if len(reason) > maxControlFramePayload-2 {
-		return nil, xerrors.Errorf("reason string max is %v but got %q with length %v", maxControlFramePayload-2, reason, len(reason))
+func (ce CloseError) bytes() ([]byte, error) {
+	if len(ce.Reason) > maxControlFramePayload-2 {
+		return nil, xerrors.Errorf("reason string max is %v but got %q with length %v", maxControlFramePayload-2, ce.Reason, len(ce.Reason))
 	}
-	if bits.Len(uint(code)) > 16 {
+	if bits.Len(uint(ce.Code)) > 16 {
 		return nil, errors.New("status code is larger than 2 bytes")
 	}
-	if !validCloseCode(code) {
-		return nil, fmt.Errorf("status code %v cannot be set", code)
+	if !validWireCloseCode(ce.Code) {
+		return nil, fmt.Errorf("status code %v cannot be set", ce.Code)
 	}
 
-	buf := make([]byte, 2+len(reason))
-	binary.BigEndian.PutUint16(buf[:], uint16(code))
-	copy(buf[2:], reason)
+	buf := make([]byte, 2+len(ce.Reason))
+	binary.BigEndian.PutUint16(buf[:], uint16(ce.Code))
+	copy(buf[2:], ce.Reason)
 	return buf, nil
 }
diff --git a/statuscode_string.go b/statuscode_string.go
index fc8cea0..11725e4 100644
--- a/statuscode_string.go
+++ b/statuscode_string.go
@@ -13,7 +13,7 @@ func _() {
 	_ = x[StatusProtocolError-1002]
 	_ = x[StatusUnsupportedData-1003]
 	_ = x[StatusNoStatusRcvd-1005]
-	_ = x[StatusAbnormalClosure-1006]
+	_ = x[statusAbnormalClosure-1006]
 	_ = x[StatusInvalidFramePayloadData-1007]
 	_ = x[StatusPolicyViolation-1008]
 	_ = x[StatusMessageTooBig-1009]
@@ -22,12 +22,12 @@ func _() {
 	_ = x[StatusServiceRestart-1012]
 	_ = x[StatusTryAgainLater-1013]
 	_ = x[StatusBadGateway-1014]
-	_ = x[StatusTLSHandshake-1015]
+	_ = x[statusTLSHandshake-1015]
 }
 
 const (
 	_StatusCode_name_0 = "StatusNormalClosureStatusGoingAwayStatusProtocolErrorStatusUnsupportedData"
-	_StatusCode_name_1 = "StatusNoStatusRcvdStatusAbnormalClosureStatusInvalidFramePayloadDataStatusPolicyViolationStatusMessageTooBigStatusMandatoryExtensionStatusInternalErrorStatusServiceRestartStatusTryAgainLaterStatusBadGatewayStatusTLSHandshake"
+	_StatusCode_name_1 = "StatusNoStatusRcvdstatusAbnormalClosureStatusInvalidFramePayloadDataStatusPolicyViolationStatusMessageTooBigStatusMandatoryExtensionStatusInternalErrorStatusServiceRestartStatusTryAgainLaterStatusBadGatewaystatusTLSHandshake"
 )
 
 var (
diff --git a/statuscode_test.go b/statuscode_test.go
new file mode 100644
index 0000000..38ee4c3
--- /dev/null
+++ b/statuscode_test.go
@@ -0,0 +1,57 @@
+package websocket
+
+import (
+	"math"
+	"strings"
+	"testing"
+)
+
+func TestCloseError(t *testing.T) {
+	t.Parallel()
+
+	// Other parts of close error are tested by websocket_test.go right now
+	// with the autobahn tests.
+
+	testCases := []struct {
+		name    string
+		ce      CloseError
+		success bool
+	}{
+		{
+			name: "normal",
+			ce: CloseError{
+				Code:   StatusNormalClosure,
+				Reason: strings.Repeat("x", maxControlFramePayload-2),
+			},
+			success: true,
+		},
+		{
+			name: "bigReason",
+			ce: CloseError{
+				Code:   StatusNormalClosure,
+				Reason: strings.Repeat("x", maxControlFramePayload-1),
+			},
+			success: false,
+		},
+		{
+			name: "bigCode",
+			ce: CloseError{
+				Code:   math.MaxUint16,
+				Reason: strings.Repeat("x", maxControlFramePayload-2),
+			},
+			success: false,
+		},
+	}
+
+	for _, tc := range testCases {
+		tc := tc
+		t.Run(tc.name, func(t *testing.T) {
+			t.Parallel()
+
+			_, err := tc.ce.bytes()
+			if (err == nil) != tc.success {
+				t.Fatalf("unexpected error value: %v", err)
+			}
+		})
+	}
+}
diff --git a/websocket.go b/websocket.go
index 09a94e7..717cc75 100644
--- a/websocket.go
+++ b/websocket.go
@@ -23,7 +23,8 @@ type control struct {
 type Conn struct {
 	subprotocol string
 	br          *bufio.Reader
-	// TODO Cannot use bufio writer because for compression we need to know how much is buffered and compress it if large.
+	// TODO switch to []byte for write buffering because for messages larger than buffers, there will always be 3 writes. One for the frame, one for the message, one for the fin.
+	// Also will help for compression.
 	bw     *bufio.Writer
 	closer io.Closer
 	client bool
@@ -225,12 +226,12 @@ func (c *Conn) handleControl(h header) {
 	case opPong:
 	case opClose:
 		if len(b) > 0 {
-			code, reason, err := parseClosePayload(b)
+			ce, err := parseClosePayload(b)
 			if err != nil {
 				c.close(xerrors.Errorf("read invalid close payload: %w", err))
 				return
 			}
-			c.Close(code, reason)
+			c.Close(ce.Code, ce.Reason)
 		} else {
 			c.writeClose(nil, CloseError{
 				Code: StatusNoStatusRcvd,
@@ -279,8 +280,7 @@ func (c *Conn) readLoop() {
 				return
 			}
 		default:
-			// TODO send back protocol violation message or figure out what RFC wants.
-			c.close(xerrors.Errorf("unexpected opcode in header: %#v", h))
+			c.Close(StatusProtocolError, fmt.Sprintf("unknown opcode %v", h.opcode))
 			return
 		}
 
@@ -338,18 +338,23 @@ func (c *Conn) writePong(p []byte) error {
 // Close closes the WebSocket connection with the given status code and reason.
 // It will write a WebSocket close frame with a timeout of 5 seconds.
 func (c *Conn) Close(code StatusCode, reason string) error {
+	ce := CloseError{
+		Code:   code,
+		Reason: reason,
+	}
+
 	// This function also will not wait for a close frame from the peer like the RFC
 	// wants because that makes no sense and I don't think anyone actually follows that.
 	// Definitely worth seeing what popular browsers do later.
-	p, err := closePayload(code, reason)
+	p, err := ce.bytes()
 	if err != nil {
-		p, _ = closePayload(StatusInternalError, fmt.Sprintf("websocket: application tried to send code %v but code or reason was invalid", code))
+		ce = CloseError{
+			Code: StatusInternalError,
+		}
+		p, _ = ce.bytes()
 	}
 
-	cerr := c.writeClose(p, CloseError{
-		Code:   code,
-		Reason: reason,
-	})
+	cerr := c.writeClose(p, ce)
 	if err != nil {
 		return err
 	}
diff --git a/websocket_test.go b/websocket_test.go
index 14dcc8c..2133482 100644
--- a/websocket_test.go
+++ b/websocket_test.go
@@ -7,7 +7,9 @@ import (
 	"io"
 	"io/ioutil"
 	"net/http"
+	"net/http/cookiejar"
 	"net/http/httptest"
+	"net/url"
 	"os"
 	"os/exec"
 	"reflect"
@@ -216,6 +218,52 @@ func TestHandshake(t *testing.T) {
 				return nil
 			},
 		},
+		{
+			name: "cookies",
+			server: func(w http.ResponseWriter, r *http.Request) error {
+				cookie, err := r.Cookie("mycookie")
+				if err != nil {
+					return xerrors.Errorf("request is missing mycookie: %w", err)
+				}
+				if cookie.Value != "myvalue" {
+					return xerrors.Errorf("expected %q but got %q", "myvalue", cookie.Value)
+				}
+				c, err := websocket.Accept(w, r)
+				if err != nil {
+					return err
+				}
+				c.Close(websocket.StatusInternalError, "")
+				return nil
+			},
+			client: func(ctx context.Context, u string) error {
+				jar, err := cookiejar.New(nil)
+				if err != nil {
+					return xerrors.Errorf("failed to create cookie jar: %w", err)
+				}
+				parsedURL, err := url.Parse(u)
+				if err != nil {
+					return xerrors.Errorf("failed to parse url: %w", err)
+				}
+				parsedURL.Scheme = "http"
+				jar.SetCookies(parsedURL, []*http.Cookie{
+					{
+						Name:  "mycookie",
+						Value: "myvalue",
+					},
+				})
+				hc := &http.Client{
+					Jar: jar,
+				}
+				c, _, err := websocket.Dial(ctx, u,
+					websocket.DialHTTPClient(hc),
+				)
+				if err != nil {
+					return err
+				}
+				c.Close(websocket.StatusInternalError, "")
+				return nil
+			},
+		},
 	}
 
 	for _, tc := range testCases {
-- 
GitLab