From 07465830d901a15fe81a157849f6b9e259434e07 Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Mon, 15 Apr 2019 13:35:14 -0500 Subject: [PATCH] Rename DataType to MessageType - Make messageReader a value type to avoid allocation - Add a bunch of important TODOs --- accept.go | 1 + datatype.go | 12 ----------- datatype_string.go | 25 ---------------------- example_test.go | 7 ++++++- json.go | 5 +++-- messagetype.go | 12 +++++++++++ messagetype_string.go | 25 ++++++++++++++++++++++ websocket.go | 49 ++++++++++++++++++++++++++++++------------- 8 files changed, 81 insertions(+), 55 deletions(-) delete mode 100644 datatype.go delete mode 100644 datatype_string.go create mode 100644 messagetype.go create mode 100644 messagetype_string.go diff --git a/accept.go b/accept.go index 2dabdae..f505c03 100644 --- a/accept.go +++ b/accept.go @@ -41,6 +41,7 @@ func (o acceptOrigins) acceptOption() {} // Use this option with caution to avoid exposing your WebSocket // server to a CSRF attack. // See https://stackoverflow.com/a/37837709/4283659 +// TODO remove in favour of AcceptInsecureOrigin func AcceptOrigins(origins ...string) AcceptOption { return acceptOrigins(origins) } diff --git a/datatype.go b/datatype.go deleted file mode 100644 index a1d8d57..0000000 --- a/datatype.go +++ /dev/null @@ -1,12 +0,0 @@ -package websocket - -// DataType represents the Opcode of a WebSocket data frame. -type DataType int - -//go:generate go run golang.org/x/tools/cmd/stringer -type=DataType - -// DataType constants. -const ( - DataText DataType = DataType(opText) - DataBinary DataType = DataType(opBinary) -) diff --git a/datatype_string.go b/datatype_string.go deleted file mode 100644 index 1b4aaba..0000000 --- a/datatype_string.go +++ /dev/null @@ -1,25 +0,0 @@ -// Code generated by "stringer -type=DataType"; DO NOT EDIT. - -package websocket - -import "strconv" - -func _() { - // An "invalid array index" compiler error signifies that the constant values have changed. - // Re-run the stringer command to generate them again. - var x [1]struct{} - _ = x[DataText-1] - _ = x[DataBinary-2] -} - -const _DataType_name = "DataTextDataBinary" - -var _DataType_index = [...]uint8{0, 8, 18} - -func (i DataType) String() string { - i -= 1 - if i < 0 || i >= DataType(len(_DataType_index)-1) { - return "DataType(" + strconv.FormatInt(int64(i+1), 10) + ")" - } - return _DataType_name[_DataType_index[i]:_DataType_index[i+1]] -} diff --git a/example_test.go b/example_test.go index bee8e92..702239b 100644 --- a/example_test.go +++ b/example_test.go @@ -14,13 +14,18 @@ import ( func ExampleAccept_echo() { fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - c, err := websocket.Accept(w, r) + c, err := websocket.Accept(w, r, websocket.AcceptSubprotocols("echo")) if err != nil { log.Printf("server handshake failed: %v", err) return } defer c.Close(websocket.StatusInternalError, "") + if c.Subprotocol() == "" { + c.Close(websocket.StatusPolicyViolation, "cannot communicate with the default protocol") + return + } + echo := func() error { ctx, cancel := context.WithTimeout(r.Context(), time.Minute) defer cancel() diff --git a/json.go b/json.go index 53869b5..24e6f31 100644 --- a/json.go +++ b/json.go @@ -28,7 +28,7 @@ func (jc *JSONConn) read(ctx context.Context, v interface{}) error { return err } - if typ != DataText { + if typ != MessageText { return xerrors.Errorf("unexpected frame type for json (expected DataText): %v", typ) } @@ -39,6 +39,7 @@ func (jc *JSONConn) read(ctx context.Context, v interface{}) error { if err != nil { return xerrors.Errorf("failed to decode json: %w", err) } + return nil } @@ -52,7 +53,7 @@ func (jc JSONConn) Write(ctx context.Context, v interface{}) error { } func (jc JSONConn) write(ctx context.Context, v interface{}) error { - w := jc.Conn.Write(ctx, DataText) + w := jc.Conn.Write(ctx, MessageText) e := json.NewEncoder(w) err := e.Encode(v) diff --git a/messagetype.go b/messagetype.go new file mode 100644 index 0000000..54276b3 --- /dev/null +++ b/messagetype.go @@ -0,0 +1,12 @@ +package websocket + +// MessageType represents the Opcode of a WebSocket data frame. +type MessageType int + +//go:generate go run golang.org/x/tools/cmd/stringer -type=MessageType + +// MessageType constants. +const ( + MessageText MessageType = MessageType(opText) + MessageBinary MessageType = MessageType(opBinary) +) diff --git a/messagetype_string.go b/messagetype_string.go new file mode 100644 index 0000000..bc62db9 --- /dev/null +++ b/messagetype_string.go @@ -0,0 +1,25 @@ +// Code generated by "stringer -type=MessageType"; DO NOT EDIT. + +package websocket + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[MessageText-1] + _ = x[MessageBinary-2] +} + +const _MessageType_name = "MessageTextMessageBinary" + +var _MessageType_index = [...]uint8{0, 11, 24} + +func (i MessageType) String() string { + i -= 1 + if i < 0 || i >= MessageType(len(_MessageType_index)-1) { + return "MessageType(" + strconv.FormatInt(int64(i+1), 10) + ")" + } + return _MessageType_name[_MessageType_index[i]:_MessageType_index[i+1]] +} diff --git a/websocket.go b/websocket.go index 717cc75..18efb18 100644 --- a/websocket.go +++ b/websocket.go @@ -23,8 +23,7 @@ type control struct { type Conn struct { subprotocol string br *bufio.Reader - // TODO switch to []byte for write buffering because for messages larger than buffers, there will always be 3 writes. One for the frame, one for the message, one for the fin. - // Also will help for compression. + // TODO switch to []byte for write buffering for predicting compression in memory maybe bw *bufio.Writer closer io.Closer client bool @@ -36,7 +35,7 @@ type Conn struct { // Writers should send on write to begin sending // a message and then follow that up with some data // on writeBytes. - write chan DataType + write chan MessageType control chan control writeBytes chan []byte writeDone chan struct{} @@ -45,9 +44,10 @@ type Conn struct { // Then send a byte slice to readBytes to read into it. // The n of bytes read will be sent on readDone once the read into a slice is complete. // readDone will receive 0 when EOF is reached. - read chan opcode - readBytes chan []byte - readDone chan int + read chan opcode + readBytes chan []byte + readDone chan int + readerDone chan struct{} } func (c *Conn) close(err error) { @@ -77,12 +77,13 @@ func (c *Conn) Subprotocol() string { func (c *Conn) init() { c.closed = make(chan struct{}) - c.write = make(chan DataType) + c.write = make(chan MessageType) c.control = make(chan control) c.writeDone = make(chan struct{}) c.read = make(chan opcode) c.readDone = make(chan int) c.readBytes = make(chan []byte) + c.readerDone = make(chan struct{}) runtime.SetFinalizer(c, func(c *Conn) { c.Close(StatusInternalError, "websocket: connection ended up being garbage collected") @@ -120,7 +121,7 @@ messageLoop: for { c.writeBytes = make(chan []byte) - var dataType DataType + var dataType MessageType select { case <-c.closed: return @@ -321,7 +322,7 @@ func (c *Conn) readLoop() { select { case <-c.closed: return - case c.readDone <- 0: + case c.readerDone <- struct{}{}: } } } @@ -407,7 +408,8 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error // a WebSocket data frame of type dataType to the connection. // Ensure you close the messageWriter once you have written to entire message. // Concurrent calls to messageWriter are ok. -func (c *Conn) Write(ctx context.Context, dataType DataType) io.WriteCloser { +func (c *Conn) Write(ctx context.Context, dataType MessageType) io.WriteCloser { + // TODO acquire write here, move state into Conn and make messageWriter allocation free. return &messageWriter{ c: c, ctx: ctx, @@ -418,7 +420,7 @@ func (c *Conn) Write(ctx context.Context, dataType DataType) io.WriteCloser { // messageWriter enables writing to a WebSocket connection. // Ensure you close the messageWriter once you have written to entire message. type messageWriter struct { - datatype DataType + datatype MessageType ctx context.Context c *Conn acquiredLock bool @@ -489,12 +491,16 @@ func (w *messageWriter) Close() error { // Please use SetContext on the reader to bound the read operation. // Your application must keep reading messages for the Conn to automatically respond to ping // and close frames. -func (c *Conn) Read(ctx context.Context) (DataType, io.Reader, error) { +func (c *Conn) Read(ctx context.Context) (MessageType, io.Reader, error) { + // TODO error if the reader is not done select { + case <-c.readerDone: + // The previous reader just hit a io.EOF, we handle it for users + return c.Read(ctx) case <-c.closed: return 0, nil, xerrors.Errorf("failed to read message: %w", c.closeErr) case opcode := <-c.read: - return DataType(opcode), &messageReader{ + return MessageType(opcode), messageReader{ ctx: ctx, c: c, }, nil @@ -510,13 +516,26 @@ type messageReader struct { } // Read reads as many bytes as possible into p. -func (r *messageReader) Read(p []byte) (n int, err error) { +func (r messageReader) Read(p []byte) (int, error) { + n, err := r.read(p) + if err != nil { + // Have to return io.EOF directly for now. + if err == io.EOF { + return 0, io.EOF + } + return n, xerrors.Errorf("failed to read: %w", err) + } + return n, nil +} + +func (r messageReader) read(p []byte) (int, error) { select { case <-r.c.closed: return 0, r.c.closeErr - case <-r.c.readDone: + case <-r.c.readerDone: return 0, io.EOF case r.c.readBytes <- p: + // TODO this is potentially racey as if we return if the context is cancelled, or the conn is closed we don't know if the p is ok to use. we must close the connection and also ensure the readLoop is done before returning, likewise with writes. select { case <-r.c.closed: return 0, r.c.closeErr -- GitLab