From ee1f3c601b22b62b9f82c2a2646c9611c1ea838e Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Fri, 7 Jun 2019 18:14:56 -0400 Subject: [PATCH] Reuse write and read header buffers Next is reusing the header structures. --- header.go | 18 ++++++++++++++---- header_test.go | 4 ++-- websocket.go | 19 ++++++++++--------- 3 files changed, 26 insertions(+), 15 deletions(-) diff --git a/header.go b/header.go index b1aa255..16ab647 100644 --- a/header.go +++ b/header.go @@ -31,10 +31,19 @@ type header struct { maskKey [4]byte } +func makeWriteHeaderBuf() []byte { + return make([]byte, maxHeaderSize) +} + // bytes returns the bytes of the header. // See https://tools.ietf.org/html/rfc6455#section-5.2 -func marshalHeader(h header) []byte { - b := make([]byte, 2, maxHeaderSize) +func writeHeader(b []byte, h header) []byte { + if b == nil { + b = makeWriteHeaderBuf() + } + + b = b[:2] + b[0] = 0 if h.fin { b[0] |= 1 << 7 @@ -75,7 +84,7 @@ func marshalHeader(h header) []byte { return b } -func makeHeaderBuf() []byte { +func makeReadHeaderBuf() []byte { return make([]byte, maxHeaderSize-2) } @@ -83,8 +92,9 @@ func makeHeaderBuf() []byte { // See https://tools.ietf.org/html/rfc6455#section-5.2 func readHeader(b []byte, r io.Reader) (header, error) { if b == nil { - b = makeHeaderBuf() + b = makeReadHeaderBuf() } + // We read the first two bytes first so that we know // exactly how long the header is. b = b[:2] diff --git a/header_test.go b/header_test.go index 78d6189..b45854e 100644 --- a/header_test.go +++ b/header_test.go @@ -24,7 +24,7 @@ func TestHeader(t *testing.T) { t.Run("readNegativeLength", func(t *testing.T) { t.Parallel() - b := marshalHeader(header{ + b := writeHeader(nil, header{ payloadLength: 1<<16 + 1, }) @@ -90,7 +90,7 @@ func TestHeader(t *testing.T) { } func testHeader(t *testing.T, h header) { - b := marshalHeader(h) + b := writeHeader(nil, h) r := bytes.NewReader(b) h2, err := readHeader(nil, r) if err != nil { diff --git a/websocket.go b/websocket.go index ebe1259..375685e 100644 --- a/websocket.go +++ b/websocket.go @@ -45,21 +45,21 @@ type Conn struct { // writeFrameLock is acquired to write a single frame. // Effectively meaning whoever holds it gets to write to bw. writeFrameLock chan struct{} + writeHeaderBuf []byte // Used to ensure the previous reader is read till EOF before allowing // a new one. previousReader *messageReader // readFrameLock is acquired to read from bw. - readFrameLock chan struct{} + readFrameLock chan struct{} + readHeaderBuf []byte + controlPayloadBuf []byte setReadTimeout chan context.Context setWriteTimeout chan context.Context activePingsMu sync.Mutex activePings map[string]chan<- struct{} - - headerBuf []byte - controlPayloadBuf []byte } func (c *Conn) init() { @@ -77,7 +77,8 @@ func (c *Conn) init() { c.activePings = make(map[string]chan<- struct{}) - c.headerBuf = makeHeaderBuf() + c.writeHeaderBuf = makeWriteHeaderBuf() + c.readHeaderBuf = makeReadHeaderBuf() c.controlPayloadBuf = make([]byte, maxControlFramePayload) runtime.SetFinalizer(c, func(c *Conn) { @@ -215,7 +216,7 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { case c.setReadTimeout <- ctx: } - h, err := readHeader(c.headerBuf, c.br) + h, err := readHeader(c.readHeaderBuf, c.br) if err != nil { select { case <-c.closed: @@ -628,7 +629,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte } } - b2 := marshalHeader(h) + headerBytes := writeHeader(c.writeHeaderBuf, h) err := c.acquireLock(ctx, c.writeFrameLock) if err != nil { @@ -651,7 +652,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte default: } - err = xerrors.Errorf("failed to write frame: %w", err) + err = xerrors.Errorf("failed to write %v frame: %w", h.opcode, err) // We need to release the lock first before closing the connection to ensure // the lock can be acquired inside close to ensure no one can access c.bw. c.releaseLock(c.writeFrameLock) @@ -660,7 +661,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte return err } - _, err = c.bw.Write(b2) + _, err = c.bw.Write(headerBytes) if err != nil { return 0, writeErr(err) } -- GitLab