From 07465830d901a15fe81a157849f6b9e259434e07 Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Mon, 15 Apr 2019 13:35:14 -0500
Subject: [PATCH] Rename DataType to MessageType

- Make messageReader a value type to avoid allocation
- Add a bunch of important TODOs
---
 accept.go             |  1 +
 datatype.go           | 12 -----------
 datatype_string.go    | 25 ----------------------
 example_test.go       |  7 ++++++-
 json.go               |  5 +++--
 messagetype.go        | 12 +++++++++++
 messagetype_string.go | 25 ++++++++++++++++++++++
 websocket.go          | 49 ++++++++++++++++++++++++++++++-------------
 8 files changed, 81 insertions(+), 55 deletions(-)
 delete mode 100644 datatype.go
 delete mode 100644 datatype_string.go
 create mode 100644 messagetype.go
 create mode 100644 messagetype_string.go

diff --git a/accept.go b/accept.go
index 2dabdae..f505c03 100644
--- a/accept.go
+++ b/accept.go
@@ -41,6 +41,7 @@ func (o acceptOrigins) acceptOption() {}
 // Use this option with caution to avoid exposing your WebSocket
 // server to a CSRF attack.
 // See https://stackoverflow.com/a/37837709/4283659
+// TODO remove in favour of AcceptInsecureOrigin
 func AcceptOrigins(origins ...string) AcceptOption {
 	return acceptOrigins(origins)
 }
diff --git a/datatype.go b/datatype.go
deleted file mode 100644
index a1d8d57..0000000
--- a/datatype.go
+++ /dev/null
@@ -1,12 +0,0 @@
-package websocket
-
-// DataType represents the Opcode of a WebSocket data frame.
-type DataType int
-
-//go:generate go run golang.org/x/tools/cmd/stringer -type=DataType
-
-// DataType constants.
-const (
-	DataText   DataType = DataType(opText)
-	DataBinary DataType = DataType(opBinary)
-)
diff --git a/datatype_string.go b/datatype_string.go
deleted file mode 100644
index 1b4aaba..0000000
--- a/datatype_string.go
+++ /dev/null
@@ -1,25 +0,0 @@
-// Code generated by "stringer -type=DataType"; DO NOT EDIT.
-
-package websocket
-
-import "strconv"
-
-func _() {
-	// An "invalid array index" compiler error signifies that the constant values have changed.
-	// Re-run the stringer command to generate them again.
-	var x [1]struct{}
-	_ = x[DataText-1]
-	_ = x[DataBinary-2]
-}
-
-const _DataType_name = "DataTextDataBinary"
-
-var _DataType_index = [...]uint8{0, 8, 18}
-
-func (i DataType) String() string {
-	i -= 1
-	if i < 0 || i >= DataType(len(_DataType_index)-1) {
-		return "DataType(" + strconv.FormatInt(int64(i+1), 10) + ")"
-	}
-	return _DataType_name[_DataType_index[i]:_DataType_index[i+1]]
-}
diff --git a/example_test.go b/example_test.go
index bee8e92..702239b 100644
--- a/example_test.go
+++ b/example_test.go
@@ -14,13 +14,18 @@ import (
 
 func ExampleAccept_echo() {
 	fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-		c, err := websocket.Accept(w, r)
+		c, err := websocket.Accept(w, r, websocket.AcceptSubprotocols("echo"))
 		if err != nil {
 			log.Printf("server handshake failed: %v", err)
 			return
 		}
 		defer c.Close(websocket.StatusInternalError, "")
 
+		if c.Subprotocol() == "" {
+			c.Close(websocket.StatusPolicyViolation, "cannot communicate with the default protocol")
+			return
+		}
+
 		echo := func() error {
 			ctx, cancel := context.WithTimeout(r.Context(), time.Minute)
 			defer cancel()
diff --git a/json.go b/json.go
index 53869b5..24e6f31 100644
--- a/json.go
+++ b/json.go
@@ -28,7 +28,7 @@ func (jc *JSONConn) read(ctx context.Context, v interface{}) error {
 		return err
 	}
 
-	if typ != DataText {
+	if typ != MessageText {
 		return xerrors.Errorf("unexpected frame type for json (expected DataText): %v", typ)
 	}
 
@@ -39,6 +39,7 @@ func (jc *JSONConn) read(ctx context.Context, v interface{}) error {
 	if err != nil {
 		return xerrors.Errorf("failed to decode json: %w", err)
 	}
+
 	return nil
 }
 
@@ -52,7 +53,7 @@ func (jc JSONConn) Write(ctx context.Context, v interface{}) error {
 }
 
 func (jc JSONConn) write(ctx context.Context, v interface{}) error {
-	w := jc.Conn.Write(ctx, DataText)
+	w := jc.Conn.Write(ctx, MessageText)
 
 	e := json.NewEncoder(w)
 	err := e.Encode(v)
diff --git a/messagetype.go b/messagetype.go
new file mode 100644
index 0000000..54276b3
--- /dev/null
+++ b/messagetype.go
@@ -0,0 +1,12 @@
+package websocket
+
+// MessageType represents the Opcode of a WebSocket data frame.
+type MessageType int
+
+//go:generate go run golang.org/x/tools/cmd/stringer -type=MessageType
+
+// MessageType constants.
+const (
+	MessageText   MessageType = MessageType(opText)
+	MessageBinary MessageType = MessageType(opBinary)
+)
diff --git a/messagetype_string.go b/messagetype_string.go
new file mode 100644
index 0000000..bc62db9
--- /dev/null
+++ b/messagetype_string.go
@@ -0,0 +1,25 @@
+// Code generated by "stringer -type=MessageType"; DO NOT EDIT.
+
+package websocket
+
+import "strconv"
+
+func _() {
+	// An "invalid array index" compiler error signifies that the constant values have changed.
+	// Re-run the stringer command to generate them again.
+	var x [1]struct{}
+	_ = x[MessageText-1]
+	_ = x[MessageBinary-2]
+}
+
+const _MessageType_name = "MessageTextMessageBinary"
+
+var _MessageType_index = [...]uint8{0, 11, 24}
+
+func (i MessageType) String() string {
+	i -= 1
+	if i < 0 || i >= MessageType(len(_MessageType_index)-1) {
+		return "MessageType(" + strconv.FormatInt(int64(i+1), 10) + ")"
+	}
+	return _MessageType_name[_MessageType_index[i]:_MessageType_index[i+1]]
+}
diff --git a/websocket.go b/websocket.go
index 717cc75..18efb18 100644
--- a/websocket.go
+++ b/websocket.go
@@ -23,8 +23,7 @@ type control struct {
 type Conn struct {
 	subprotocol string
 	br          *bufio.Reader
-	// TODO switch to []byte for write buffering because for messages larger than buffers, there will always be 3 writes. One for the frame, one for the message, one for the fin.
-	// Also will help for compression.
+	// TODO switch to []byte for write buffering for predicting compression in memory maybe
 	bw     *bufio.Writer
 	closer io.Closer
 	client bool
@@ -36,7 +35,7 @@ type Conn struct {
 	// Writers should send on write to begin sending
 	// a message and then follow that up with some data
 	// on writeBytes.
-	write      chan DataType
+	write      chan MessageType
 	control    chan control
 	writeBytes chan []byte
 	writeDone  chan struct{}
@@ -45,9 +44,10 @@ type Conn struct {
 	// Then send a byte slice to readBytes to read into it.
 	// The n of bytes read will be sent on readDone once the read into a slice is complete.
 	// readDone will receive 0 when EOF is reached.
-	read      chan opcode
-	readBytes chan []byte
-	readDone  chan int
+	read       chan opcode
+	readBytes  chan []byte
+	readDone   chan int
+	readerDone chan struct{}
 }
 
 func (c *Conn) close(err error) {
@@ -77,12 +77,13 @@ func (c *Conn) Subprotocol() string {
 
 func (c *Conn) init() {
 	c.closed = make(chan struct{})
-	c.write = make(chan DataType)
+	c.write = make(chan MessageType)
 	c.control = make(chan control)
 	c.writeDone = make(chan struct{})
 	c.read = make(chan opcode)
 	c.readDone = make(chan int)
 	c.readBytes = make(chan []byte)
+	c.readerDone = make(chan struct{})
 
 	runtime.SetFinalizer(c, func(c *Conn) {
 		c.Close(StatusInternalError, "websocket: connection ended up being garbage collected")
@@ -120,7 +121,7 @@ messageLoop:
 	for {
 		c.writeBytes = make(chan []byte)
 
-		var dataType DataType
+		var dataType MessageType
 		select {
 		case <-c.closed:
 			return
@@ -321,7 +322,7 @@ func (c *Conn) readLoop() {
 			select {
 			case <-c.closed:
 				return
-			case c.readDone <- 0:
+			case c.readerDone <- struct{}{}:
 			}
 		}
 	}
@@ -407,7 +408,8 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error
 // a WebSocket data frame of type dataType to the connection.
 // Ensure you close the messageWriter once you have written to entire message.
 // Concurrent calls to messageWriter are ok.
-func (c *Conn) Write(ctx context.Context, dataType DataType) io.WriteCloser {
+func (c *Conn) Write(ctx context.Context, dataType MessageType) io.WriteCloser {
+	// TODO acquire write here, move state into Conn and make messageWriter allocation free.
 	return &messageWriter{
 		c:        c,
 		ctx:      ctx,
@@ -418,7 +420,7 @@ func (c *Conn) Write(ctx context.Context, dataType DataType) io.WriteCloser {
 // messageWriter enables writing to a WebSocket connection.
 // Ensure you close the messageWriter once you have written to entire message.
 type messageWriter struct {
-	datatype     DataType
+	datatype     MessageType
 	ctx          context.Context
 	c            *Conn
 	acquiredLock bool
@@ -489,12 +491,16 @@ func (w *messageWriter) Close() error {
 // Please use SetContext on the reader to bound the read operation.
 // Your application must keep reading messages for the Conn to automatically respond to ping
 // and close frames.
-func (c *Conn) Read(ctx context.Context) (DataType, io.Reader, error) {
+func (c *Conn) Read(ctx context.Context) (MessageType, io.Reader, error) {
+	// TODO error if the reader is not done
 	select {
+	case <-c.readerDone:
+		// The previous reader just hit a io.EOF, we handle it for users
+		return c.Read(ctx)
 	case <-c.closed:
 		return 0, nil, xerrors.Errorf("failed to read message: %w", c.closeErr)
 	case opcode := <-c.read:
-		return DataType(opcode), &messageReader{
+		return MessageType(opcode), messageReader{
 			ctx: ctx,
 			c:   c,
 		}, nil
@@ -510,13 +516,26 @@ type messageReader struct {
 }
 
 // Read reads as many bytes as possible into p.
-func (r *messageReader) Read(p []byte) (n int, err error) {
+func (r messageReader) Read(p []byte) (int, error) {
+	n, err := r.read(p)
+	if err != nil {
+		// Have to return io.EOF directly for now.
+		if err == io.EOF {
+			return 0, io.EOF
+		}
+		return n, xerrors.Errorf("failed to read: %w", err)
+	}
+	return n, nil
+}
+
+func (r messageReader) read(p []byte) (int, error) {
 	select {
 	case <-r.c.closed:
 		return 0, r.c.closeErr
-	case <-r.c.readDone:
+	case <-r.c.readerDone:
 		return 0, io.EOF
 	case r.c.readBytes <- p:
+		// TODO this is potentially racey as if we return if the context is cancelled, or the conn is closed we don't know if the p is ok to use. we must close the connection and also ensure the readLoop is done before returning, likewise with writes.
 		select {
 		case <-r.c.closed:
 			return 0, r.c.closeErr
-- 
GitLab