diff --git a/internal/util/util.go b/internal/util/util.go new file mode 100644 index 0000000000000000000000000000000000000000..1ff25dace8cf6347f2b6335ad7c3256d00d8d03a --- /dev/null +++ b/internal/util/util.go @@ -0,0 +1,7 @@ +package util + +type WriterFunc func(p []byte) (int, error) + +func (f WriterFunc) Write(p []byte) (int, error) { + return f(p) +} diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go index 2000a77af8dc5f0296142cbdeefd5af780f442b8..c6b29ee1ab65d9fbfba8f0cf7a9d1539266a0f34 100644 --- a/wsjson/wsjson.go +++ b/wsjson/wsjson.go @@ -8,6 +8,7 @@ import ( "nhooyr.io/websocket" "nhooyr.io/websocket/internal/bpool" + "nhooyr.io/websocket/internal/util" "nhooyr.io/websocket/internal/errd" ) @@ -51,17 +52,17 @@ func Write(ctx context.Context, c *websocket.Conn, v interface{}) error { func write(ctx context.Context, c *websocket.Conn, v interface{}) (err error) { defer errd.Wrap(&err, "failed to write JSON message") - w, err := c.Writer(ctx, websocket.MessageText) - if err != nil { - return err - } - // json.Marshal cannot reuse buffers between calls as it has to return // a copy of the byte slice but Encoder does as it directly writes to w. - err = json.NewEncoder(w).Encode(v) + err = json.NewEncoder(util.WriterFunc(func(p []byte) (int, error) { + err := c.Write(ctx, websocket.MessageText, p) + if err != nil { + return 0, err + } + return len(p), nil + })).Encode(v) if err != nil { return fmt.Errorf("failed to marshal JSON: %w", err) } - - return w.Close() + return nil }