From 9213cc7bf60ddd78d0e6da16c2641eb83e0cf0f5 Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Sat, 30 Mar 2019 23:04:10 -0500 Subject: [PATCH] Significantly simplify core API and the godoc --- example_test.go | 9 ++----- json.go | 7 +++-- websocket.go | 65 +++++++++++++++-------------------------------- websocket_test.go | 5 +--- 4 files changed, 26 insertions(+), 60 deletions(-) diff --git a/example_test.go b/example_test.go index 5e6d072..b3ed2a5 100644 --- a/example_test.go +++ b/example_test.go @@ -34,14 +34,9 @@ func ExampleAccept_echo() { return err } - ctx, cancel = context.WithTimeout(ctx, time.Second*10) - defer cancel() - - r.SetContext(ctx) - r.Limit(32768) + r = io.LimitReader(r, 32768) - w := c.MessageWriter(typ) - w.SetContext(ctx) + w := c.MessageWriter(ctx, typ) _, err = io.Copy(w, r) if err != nil { return err diff --git a/json.go b/json.go index ebe0dfd..514be05 100644 --- a/json.go +++ b/json.go @@ -3,6 +3,7 @@ package websocket import ( "context" "encoding/json" + "io" "golang.org/x/xerrors" ) @@ -31,8 +32,7 @@ func (jc *JSONConn) read(ctx context.Context, v interface{}) error { return xerrors.Errorf("unexpected frame type for json (expected DataText): %v", typ) } - r.Limit(131072) - r.SetContext(ctx) + r = io.LimitReader(r, 131072) d := json.NewDecoder(r) err = d.Decode(v) @@ -52,8 +52,7 @@ func (jc JSONConn) Write(ctx context.Context, v interface{}) error { } func (jc JSONConn) write(ctx context.Context, v interface{}) error { - w := jc.Conn.MessageWriter(DataText) - w.SetContext(ctx) + w := jc.Conn.MessageWriter(ctx, DataText) e := json.NewEncoder(w) err := e.Encode(v) diff --git a/websocket.go b/websocket.go index 2ed4b5d..99d2855 100644 --- a/websocket.go +++ b/websocket.go @@ -403,19 +403,19 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error // MessageWriter returns a writer bounded by the context that will write // a WebSocket data frame of type dataType to the connection. -// Ensure you close the MessageWriter once you have written to entire message. -// Concurrent calls to MessageWriter are ok. -func (c *Conn) MessageWriter(dataType DataType) *MessageWriter { - return &MessageWriter{ +// Ensure you close the messageWriter once you have written to entire message. +// Concurrent calls to messageWriter are ok. +func (c *Conn) MessageWriter(ctx context.Context, dataType DataType) io.WriteCloser { + return &messageWriter{ c: c, - ctx: context.Background(), + ctx: ctx, datatype: dataType, } } -// MessageWriter enables writing to a WebSocket connection. -// Ensure you close the MessageWriter once you have written to entire message. -type MessageWriter struct { +// messageWriter enables writing to a WebSocket connection. +// Ensure you close the messageWriter once you have written to entire message. +type messageWriter struct { datatype DataType ctx context.Context c *Conn @@ -429,7 +429,7 @@ type MessageWriter struct { // The frame will automatically be fragmented as appropriate // with the buffers obtained from http.Hijacker. // Please ensure you call Close once you have written the full message. -func (w *MessageWriter) Write(p []byte) (int, error) { +func (w *messageWriter) Write(p []byte) (int, error) { if !w.acquiredLock { select { case <-w.c.closed: @@ -458,14 +458,9 @@ func (w *MessageWriter) Write(p []byte) (int, error) { } } -// SetContext bounds the writer to the context. -func (w *MessageWriter) SetContext(ctx context.Context) { - w.ctx = ctx -} - // Close flushes the frame to the connection. -// This must be called for every MessageWriter. -func (w *MessageWriter) Close() error { +// This must be called for every messageWriter. +func (w *messageWriter) Close() error { if !w.acquiredLock { select { case <-w.c.closed: @@ -492,13 +487,13 @@ func (w *MessageWriter) Close() error { // Please use SetContext on the reader to bound the read operation. // Your application must keep reading messages for the Conn to automatically respond to ping // and close frames. -func (c *Conn) ReadMessage(ctx context.Context) (DataType, *MessageReader, error) { +func (c *Conn) ReadMessage(ctx context.Context) (DataType, io.Reader, error) { select { case <-c.closed: return 0, nil, xerrors.Errorf("failed to read message: %w", c.getCloseErr()) case opcode := <-c.read: - return DataType(opcode), &MessageReader{ - ctx: context.Background(), + return DataType(opcode), &messageReader{ + ctx: ctx, c: c, }, nil case <-ctx.Done(): @@ -506,36 +501,21 @@ func (c *Conn) ReadMessage(ctx context.Context) (DataType, *MessageReader, error } } -// MessageReader enables reading a data frame from the WebSocket connection. -type MessageReader struct { - n int - limit int - c *Conn - ctx context.Context +// messageReader enables reading a data frame from the WebSocket connection. +type messageReader struct { + ctx context.Context + c *Conn } // SetContext bounds the read operation to the ctx. // By default, the context is the one passed to conn.ReadMessage. // You still almost always want a separate context for reading the message though. -func (r *MessageReader) SetContext(ctx context.Context) { +func (r *messageReader) SetContext(ctx context.Context) { r.ctx = ctx } -// Limit limits the number of bytes read by the reader. -// -// Why not use io.LimitReader? io.LimitReader returns a io.EOF -// after the limit bytes which means its not possible to tell -// whether the message has been read or a limit has been hit. -// This results in unclear error and log messages. -// This function will cause the connection to be closed if the limit is hit -// with a close reason explaining the error and also an error -// indicating the limit was hit. -func (r *MessageReader) Limit(bytes int) { - r.limit = bytes -} - // Read reads as many bytes as possible into p. -func (r *MessageReader) Read(p []byte) (n int, err error) { +func (r *messageReader) Read(p []byte) (n int, err error) { select { case <-r.c.closed: return 0, r.c.getCloseErr() @@ -546,11 +526,6 @@ func (r *MessageReader) Read(p []byte) (n int, err error) { case <-r.c.closed: return 0, r.c.getCloseErr() case n := <-r.c.readDone: - r.n += n - // TODO make this better later and inside readLoop to prevent the read from actually occuring if over limit. - if r.limit > 0 && r.n > r.limit { - return 0, xerrors.New("message too big") - } return n, nil case <-r.ctx.Done(): return 0, r.ctx.Err() diff --git a/websocket_test.go b/websocket_test.go index e91e5b2..e585082 100644 --- a/websocket_test.go +++ b/websocket_test.go @@ -365,10 +365,7 @@ func echoLoop(ctx context.Context, c *websocket.Conn, t *testing.T) { return err } - r.SetContext(ctx) - - w := c.MessageWriter(typ) - w.SetContext(ctx) + w := c.MessageWriter(ctx, typ) _, err = io.Copy(w, r) if err != nil { -- GitLab