From 9213cc7bf60ddd78d0e6da16c2641eb83e0cf0f5 Mon Sep 17 00:00:00 2001
From: Anmol Sethi <hi@nhooyr.io>
Date: Sat, 30 Mar 2019 23:04:10 -0500
Subject: [PATCH] Significantly simplify core API and the godoc

---
 example_test.go   |  9 ++-----
 json.go           |  7 +++--
 websocket.go      | 65 +++++++++++++++--------------------------------
 websocket_test.go |  5 +---
 4 files changed, 26 insertions(+), 60 deletions(-)

diff --git a/example_test.go b/example_test.go
index 5e6d072..b3ed2a5 100644
--- a/example_test.go
+++ b/example_test.go
@@ -34,14 +34,9 @@ func ExampleAccept_echo() {
 				return err
 			}
 
-			ctx, cancel = context.WithTimeout(ctx, time.Second*10)
-			defer cancel()
-
-			r.SetContext(ctx)
-			r.Limit(32768)
+			r = io.LimitReader(r, 32768)
 
-			w := c.MessageWriter(typ)
-			w.SetContext(ctx)
+			w := c.MessageWriter(ctx, typ)
 			_, err = io.Copy(w, r)
 			if err != nil {
 				return err
diff --git a/json.go b/json.go
index ebe0dfd..514be05 100644
--- a/json.go
+++ b/json.go
@@ -3,6 +3,7 @@ package websocket
 import (
 	"context"
 	"encoding/json"
+	"io"
 
 	"golang.org/x/xerrors"
 )
@@ -31,8 +32,7 @@ func (jc *JSONConn) read(ctx context.Context, v interface{}) error {
 		return xerrors.Errorf("unexpected frame type for json (expected DataText): %v", typ)
 	}
 
-	r.Limit(131072)
-	r.SetContext(ctx)
+	r = io.LimitReader(r, 131072)
 
 	d := json.NewDecoder(r)
 	err = d.Decode(v)
@@ -52,8 +52,7 @@ func (jc JSONConn) Write(ctx context.Context, v interface{}) error {
 }
 
 func (jc JSONConn) write(ctx context.Context, v interface{}) error {
-	w := jc.Conn.MessageWriter(DataText)
-	w.SetContext(ctx)
+	w := jc.Conn.MessageWriter(ctx, DataText)
 
 	e := json.NewEncoder(w)
 	err := e.Encode(v)
diff --git a/websocket.go b/websocket.go
index 2ed4b5d..99d2855 100644
--- a/websocket.go
+++ b/websocket.go
@@ -403,19 +403,19 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error
 
 // MessageWriter returns a writer bounded by the context that will write
 // a WebSocket data frame of type dataType to the connection.
-// Ensure you close the MessageWriter once you have written to entire message.
-// Concurrent calls to MessageWriter are ok.
-func (c *Conn) MessageWriter(dataType DataType) *MessageWriter {
-	return &MessageWriter{
+// Ensure you close the messageWriter once you have written to entire message.
+// Concurrent calls to messageWriter are ok.
+func (c *Conn) MessageWriter(ctx context.Context, dataType DataType) io.WriteCloser {
+	return &messageWriter{
 		c:        c,
-		ctx:      context.Background(),
+		ctx:      ctx,
 		datatype: dataType,
 	}
 }
 
-// MessageWriter enables writing to a WebSocket connection.
-// Ensure you close the MessageWriter once you have written to entire message.
-type MessageWriter struct {
+// messageWriter enables writing to a WebSocket connection.
+// Ensure you close the messageWriter once you have written to entire message.
+type messageWriter struct {
 	datatype     DataType
 	ctx          context.Context
 	c            *Conn
@@ -429,7 +429,7 @@ type MessageWriter struct {
 // The frame will automatically be fragmented as appropriate
 // with the buffers obtained from http.Hijacker.
 // Please ensure you call Close once you have written the full message.
-func (w *MessageWriter) Write(p []byte) (int, error) {
+func (w *messageWriter) Write(p []byte) (int, error) {
 	if !w.acquiredLock {
 		select {
 		case <-w.c.closed:
@@ -458,14 +458,9 @@ func (w *MessageWriter) Write(p []byte) (int, error) {
 	}
 }
 
-// SetContext bounds the writer to the context.
-func (w *MessageWriter) SetContext(ctx context.Context) {
-	w.ctx = ctx
-}
-
 // Close flushes the frame to the connection.
-// This must be called for every MessageWriter.
-func (w *MessageWriter) Close() error {
+// This must be called for every messageWriter.
+func (w *messageWriter) Close() error {
 	if !w.acquiredLock {
 		select {
 		case <-w.c.closed:
@@ -492,13 +487,13 @@ func (w *MessageWriter) Close() error {
 // Please use SetContext on the reader to bound the read operation.
 // Your application must keep reading messages for the Conn to automatically respond to ping
 // and close frames.
-func (c *Conn) ReadMessage(ctx context.Context) (DataType, *MessageReader, error) {
+func (c *Conn) ReadMessage(ctx context.Context) (DataType, io.Reader, error) {
 	select {
 	case <-c.closed:
 		return 0, nil, xerrors.Errorf("failed to read message: %w", c.getCloseErr())
 	case opcode := <-c.read:
-		return DataType(opcode), &MessageReader{
-			ctx: context.Background(),
+		return DataType(opcode), &messageReader{
+			ctx: ctx,
 			c:   c,
 		}, nil
 	case <-ctx.Done():
@@ -506,36 +501,21 @@ func (c *Conn) ReadMessage(ctx context.Context) (DataType, *MessageReader, error
 	}
 }
 
-// MessageReader enables reading a data frame from the WebSocket connection.
-type MessageReader struct {
-	n     int
-	limit int
-	c     *Conn
-	ctx   context.Context
+// messageReader enables reading a data frame from the WebSocket connection.
+type messageReader struct {
+	ctx context.Context
+	c   *Conn
 }
 
 // SetContext bounds the read operation to the ctx.
 // By default, the context is the one passed to conn.ReadMessage.
 // You still almost always want a separate context for reading the message though.
-func (r *MessageReader) SetContext(ctx context.Context) {
+func (r *messageReader) SetContext(ctx context.Context) {
 	r.ctx = ctx
 }
 
-// Limit limits the number of bytes read by the reader.
-//
-// Why not use io.LimitReader? io.LimitReader returns a io.EOF
-// after the limit bytes which means its not possible to tell
-// whether the message has been read or a limit has been hit.
-// This results in unclear error and log messages.
-// This function will cause the connection to be closed if the limit is hit
-// with a close reason explaining the error and also an error
-// indicating the limit was hit.
-func (r *MessageReader) Limit(bytes int) {
-	r.limit = bytes
-}
-
 // Read reads as many bytes as possible into p.
-func (r *MessageReader) Read(p []byte) (n int, err error) {
+func (r *messageReader) Read(p []byte) (n int, err error) {
 	select {
 	case <-r.c.closed:
 		return 0, r.c.getCloseErr()
@@ -546,11 +526,6 @@ func (r *MessageReader) Read(p []byte) (n int, err error) {
 		case <-r.c.closed:
 			return 0, r.c.getCloseErr()
 		case n := <-r.c.readDone:
-			r.n += n
-			// TODO make this better later and inside readLoop to prevent the read from actually occuring if over limit.
-			if r.limit > 0 && r.n > r.limit {
-				return 0, xerrors.New("message too big")
-			}
 			return n, nil
 		case <-r.ctx.Done():
 			return 0, r.ctx.Err()
diff --git a/websocket_test.go b/websocket_test.go
index e91e5b2..e585082 100644
--- a/websocket_test.go
+++ b/websocket_test.go
@@ -365,10 +365,7 @@ func echoLoop(ctx context.Context, c *websocket.Conn, t *testing.T) {
 			return err
 		}
 
-		r.SetContext(ctx)
-
-		w := c.MessageWriter(typ)
-		w.SetContext(ctx)
+		w := c.MessageWriter(ctx, typ)
 
 		_, err = io.Copy(w, r)
 		if err != nil {
-- 
GitLab