good morning!!!!

Skip to content
Snippets Groups Projects
Unverified Commit 1dbc1412 authored by Anmol Sethi's avatar Anmol Sethi
Browse files

write: Zero alloc writes with Writer

Closes #354
parent a975390c
Branches
Tags
No related merge requests found
websocket.test
...@@ -2,8 +2,8 @@ ...@@ -2,8 +2,8 @@
set -eu set -eu
cd -- "$(dirname "$0")/.." cd -- "$(dirname "$0")/.."
go test --run=^$ --bench=. "$@" ./... go test --run=^$ --bench=. --benchmem --memprofile ci/out/prof.mem --cpuprofile ci/out/prof.cpu -o ci/out/websocket.test "$@" .
( (
cd ./internal/thirdparty cd ./internal/thirdparty
go test --run=^$ --bench=. "$@" ./... go test --run=^$ --bench=. --benchmem --memprofile ../../ci/out/prof-thirdparty.mem --cpuprofile ../../ci/out/prof-thirdparty.cpu -o ../../ci/out/thirdparty.test "$@" .
) )
...@@ -63,7 +63,7 @@ type Conn struct { ...@@ -63,7 +63,7 @@ type Conn struct {
readCloseFrameErr error readCloseFrameErr error
// Write state. // Write state.
msgWriterState *msgWriterState msgWriter *msgWriter
writeFrameMu *mu writeFrameMu *mu
writeBuf []byte writeBuf []byte
writeHeaderBuf [8]byte writeHeaderBuf [8]byte
...@@ -113,14 +113,14 @@ func newConn(cfg connConfig) *Conn { ...@@ -113,14 +113,14 @@ func newConn(cfg connConfig) *Conn {
c.msgReader = newMsgReader(c) c.msgReader = newMsgReader(c)
c.msgWriterState = newMsgWriterState(c) c.msgWriter = newMsgWriter(c)
if c.client { if c.client {
c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc) c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc)
} }
if c.flate() && c.flateThreshold == 0 { if c.flate() && c.flateThreshold == 0 {
c.flateThreshold = 128 c.flateThreshold = 128
if !c.msgWriterState.flateContextTakeover() { if !c.msgWriter.flateContextTakeover() {
c.flateThreshold = 512 c.flateThreshold = 512
} }
} }
...@@ -157,8 +157,7 @@ func (c *Conn) close(err error) { ...@@ -157,8 +157,7 @@ func (c *Conn) close(err error) {
c.rwc.Close() c.rwc.Close()
go func() { go func() {
c.msgWriterState.close() c.msgWriter.close()
c.msgReader.close() c.msgReader.close()
}() }()
} }
......
...@@ -49,30 +49,11 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { ...@@ -49,30 +49,11 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
} }
type msgWriter struct { type msgWriter struct {
mw *msgWriterState
closed bool
}
func (mw *msgWriter) Write(p []byte) (int, error) {
if mw.closed {
return 0, errors.New("cannot use closed writer")
}
return mw.mw.Write(p)
}
func (mw *msgWriter) Close() error {
if mw.closed {
return errors.New("cannot use closed writer")
}
mw.closed = true
return mw.mw.Close()
}
type msgWriterState struct {
c *Conn c *Conn
mu *mu mu *mu
writeMu *mu writeMu *mu
closed bool
ctx context.Context ctx context.Context
opcode opcode opcode opcode
...@@ -82,8 +63,8 @@ type msgWriterState struct { ...@@ -82,8 +63,8 @@ type msgWriterState struct {
flateWriter *flate.Writer flateWriter *flate.Writer
} }
func newMsgWriterState(c *Conn) *msgWriterState { func newMsgWriter(c *Conn) *msgWriter {
mw := &msgWriterState{ mw := &msgWriter{
c: c, c: c,
mu: newMu(c), mu: newMu(c),
writeMu: newMu(c), writeMu: newMu(c),
...@@ -91,7 +72,7 @@ func newMsgWriterState(c *Conn) *msgWriterState { ...@@ -91,7 +72,7 @@ func newMsgWriterState(c *Conn) *msgWriterState {
return mw return mw
} }
func (mw *msgWriterState) ensureFlate() { func (mw *msgWriter) ensureFlate() {
if mw.trimWriter == nil { if mw.trimWriter == nil {
mw.trimWriter = &trimLastFourBytesWriter{ mw.trimWriter = &trimLastFourBytesWriter{
w: util.WriterFunc(mw.write), w: util.WriterFunc(mw.write),
...@@ -104,7 +85,7 @@ func (mw *msgWriterState) ensureFlate() { ...@@ -104,7 +85,7 @@ func (mw *msgWriterState) ensureFlate() {
mw.flate = true mw.flate = true
} }
func (mw *msgWriterState) flateContextTakeover() bool { func (mw *msgWriter) flateContextTakeover() bool {
if mw.c.client { if mw.c.client {
return !mw.c.copts.clientNoContextTakeover return !mw.c.copts.clientNoContextTakeover
} }
...@@ -112,14 +93,11 @@ func (mw *msgWriterState) flateContextTakeover() bool { ...@@ -112,14 +93,11 @@ func (mw *msgWriterState) flateContextTakeover() bool {
} }
func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
err := c.msgWriterState.reset(ctx, typ) err := c.msgWriter.reset(ctx, typ)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &msgWriter{ return c.msgWriter, nil
mw: c.msgWriterState,
closed: false,
}, nil
} }
func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) { func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) {
...@@ -129,8 +107,8 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error ...@@ -129,8 +107,8 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error
} }
if !c.flate() { if !c.flate() {
defer c.msgWriterState.mu.unlock() defer c.msgWriter.mu.unlock()
return c.writeFrame(ctx, true, false, c.msgWriterState.opcode, p) return c.writeFrame(ctx, true, false, c.msgWriter.opcode, p)
} }
n, err := mw.Write(p) n, err := mw.Write(p)
...@@ -142,7 +120,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error ...@@ -142,7 +120,7 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error
return n, err return n, err
} }
func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error { func (mw *msgWriter) reset(ctx context.Context, typ MessageType) error {
err := mw.mu.lock(ctx) err := mw.mu.lock(ctx)
if err != nil { if err != nil {
return err return err
...@@ -151,13 +129,14 @@ func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error { ...@@ -151,13 +129,14 @@ func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error {
mw.ctx = ctx mw.ctx = ctx
mw.opcode = opcode(typ) mw.opcode = opcode(typ)
mw.flate = false mw.flate = false
mw.closed = false
mw.trimWriter.reset() mw.trimWriter.reset()
return nil return nil
} }
func (mw *msgWriterState) putFlateWriter() { func (mw *msgWriter) putFlateWriter() {
if mw.flateWriter != nil { if mw.flateWriter != nil {
putFlateWriter(mw.flateWriter) putFlateWriter(mw.flateWriter)
mw.flateWriter = nil mw.flateWriter = nil
...@@ -165,7 +144,11 @@ func (mw *msgWriterState) putFlateWriter() { ...@@ -165,7 +144,11 @@ func (mw *msgWriterState) putFlateWriter() {
} }
// Write writes the given bytes to the WebSocket connection. // Write writes the given bytes to the WebSocket connection.
func (mw *msgWriterState) Write(p []byte) (_ int, err error) { func (mw *msgWriter) Write(p []byte) (_ int, err error) {
if mw.closed {
return 0, errors.New("cannot use closed writer")
}
err = mw.writeMu.lock(mw.ctx) err = mw.writeMu.lock(mw.ctx)
if err != nil { if err != nil {
return 0, fmt.Errorf("failed to write: %w", err) return 0, fmt.Errorf("failed to write: %w", err)
...@@ -194,7 +177,7 @@ func (mw *msgWriterState) Write(p []byte) (_ int, err error) { ...@@ -194,7 +177,7 @@ func (mw *msgWriterState) Write(p []byte) (_ int, err error) {
return mw.write(p) return mw.write(p)
} }
func (mw *msgWriterState) write(p []byte) (int, error) { func (mw *msgWriter) write(p []byte) (int, error) {
n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p) n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p)
if err != nil { if err != nil {
return n, fmt.Errorf("failed to write data frame: %w", err) return n, fmt.Errorf("failed to write data frame: %w", err)
...@@ -204,9 +187,14 @@ func (mw *msgWriterState) write(p []byte) (int, error) { ...@@ -204,9 +187,14 @@ func (mw *msgWriterState) write(p []byte) (int, error) {
} }
// Close flushes the frame to the connection. // Close flushes the frame to the connection.
func (mw *msgWriterState) Close() (err error) { func (mw *msgWriter) Close() (err error) {
defer errd.Wrap(&err, "failed to close writer") defer errd.Wrap(&err, "failed to close writer")
if mw.closed {
return errors.New("writer already closed")
}
mw.closed = true
err = mw.writeMu.lock(mw.ctx) err = mw.writeMu.lock(mw.ctx)
if err != nil { if err != nil {
return err return err
...@@ -232,7 +220,7 @@ func (mw *msgWriterState) Close() (err error) { ...@@ -232,7 +220,7 @@ func (mw *msgWriterState) Close() (err error) {
return nil return nil
} }
func (mw *msgWriterState) close() { func (mw *msgWriter) close() {
if mw.c.client { if mw.c.client {
mw.c.writeFrameMu.forceLock() mw.c.writeFrameMu.forceLock()
putBufioWriter(mw.c.bw) putBufioWriter(mw.c.bw)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment