diff --git a/internal/util/util.go b/internal/util/util.go
new file mode 100644
index 0000000000000000000000000000000000000000..1ff25dace8cf6347f2b6335ad7c3256d00d8d03a
--- /dev/null
+++ b/internal/util/util.go
@@ -0,0 +1,7 @@
+package util
+
+type WriterFunc func(p []byte) (int, error)
+
+func (f WriterFunc) Write(p []byte) (int, error) {
+	return f(p)
+}
diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go
index 2000a77af8dc5f0296142cbdeefd5af780f442b8..c6b29ee1ab65d9fbfba8f0cf7a9d1539266a0f34 100644
--- a/wsjson/wsjson.go
+++ b/wsjson/wsjson.go
@@ -8,6 +8,7 @@ import (
 
 	"nhooyr.io/websocket"
 	"nhooyr.io/websocket/internal/bpool"
+	"nhooyr.io/websocket/internal/util"
 	"nhooyr.io/websocket/internal/errd"
 )
 
@@ -51,17 +52,17 @@ func Write(ctx context.Context, c *websocket.Conn, v interface{}) error {
 func write(ctx context.Context, c *websocket.Conn, v interface{}) (err error) {
 	defer errd.Wrap(&err, "failed to write JSON message")
 
-	w, err := c.Writer(ctx, websocket.MessageText)
-	if err != nil {
-		return err
-	}
-
 	// json.Marshal cannot reuse buffers between calls as it has to return
 	// a copy of the byte slice but Encoder does as it directly writes to w.
-	err = json.NewEncoder(w).Encode(v)
+	err = json.NewEncoder(util.WriterFunc(func(p []byte) (int, error) {
+		err := c.Write(ctx, websocket.MessageText, p)
+		if err != nil {
+			return 0, err
+		}
+		return len(p), nil
+	})).Encode(v)
 	if err != nil {
 		return fmt.Errorf("failed to marshal JSON: %w", err)
 	}
-
-	return w.Close()
+	return nil
 }