diff --git a/header.go b/header.go index b1aa2554950d36ef4cbff23668306acd69f5b08d..16ab6474e69c46b10bea83df98828c35482da032 100644 --- a/header.go +++ b/header.go @@ -31,10 +31,19 @@ type header struct { maskKey [4]byte } +func makeWriteHeaderBuf() []byte { + return make([]byte, maxHeaderSize) +} + // bytes returns the bytes of the header. // See https://tools.ietf.org/html/rfc6455#section-5.2 -func marshalHeader(h header) []byte { - b := make([]byte, 2, maxHeaderSize) +func writeHeader(b []byte, h header) []byte { + if b == nil { + b = makeWriteHeaderBuf() + } + + b = b[:2] + b[0] = 0 if h.fin { b[0] |= 1 << 7 @@ -75,7 +84,7 @@ func marshalHeader(h header) []byte { return b } -func makeHeaderBuf() []byte { +func makeReadHeaderBuf() []byte { return make([]byte, maxHeaderSize-2) } @@ -83,8 +92,9 @@ func makeHeaderBuf() []byte { // See https://tools.ietf.org/html/rfc6455#section-5.2 func readHeader(b []byte, r io.Reader) (header, error) { if b == nil { - b = makeHeaderBuf() + b = makeReadHeaderBuf() } + // We read the first two bytes first so that we know // exactly how long the header is. b = b[:2] diff --git a/header_test.go b/header_test.go index 78d618999c1b392ed3c1509cb92b33174b1321d1..b45854eaf4777753cdaa80af8d5b2763628b1874 100644 --- a/header_test.go +++ b/header_test.go @@ -24,7 +24,7 @@ func TestHeader(t *testing.T) { t.Run("readNegativeLength", func(t *testing.T) { t.Parallel() - b := marshalHeader(header{ + b := writeHeader(nil, header{ payloadLength: 1<<16 + 1, }) @@ -90,7 +90,7 @@ func TestHeader(t *testing.T) { } func testHeader(t *testing.T, h header) { - b := marshalHeader(h) + b := writeHeader(nil, h) r := bytes.NewReader(b) h2, err := readHeader(nil, r) if err != nil { diff --git a/websocket.go b/websocket.go index ebe1259774bef25885986582dac6b030181d3bc2..375685e75c9c865d434c9aae76d77c4c734bfb63 100644 --- a/websocket.go +++ b/websocket.go @@ -45,21 +45,21 @@ type Conn struct { // writeFrameLock is acquired to write a single frame. // Effectively meaning whoever holds it gets to write to bw. writeFrameLock chan struct{} + writeHeaderBuf []byte // Used to ensure the previous reader is read till EOF before allowing // a new one. previousReader *messageReader // readFrameLock is acquired to read from bw. - readFrameLock chan struct{} + readFrameLock chan struct{} + readHeaderBuf []byte + controlPayloadBuf []byte setReadTimeout chan context.Context setWriteTimeout chan context.Context activePingsMu sync.Mutex activePings map[string]chan<- struct{} - - headerBuf []byte - controlPayloadBuf []byte } func (c *Conn) init() { @@ -77,7 +77,8 @@ func (c *Conn) init() { c.activePings = make(map[string]chan<- struct{}) - c.headerBuf = makeHeaderBuf() + c.writeHeaderBuf = makeWriteHeaderBuf() + c.readHeaderBuf = makeReadHeaderBuf() c.controlPayloadBuf = make([]byte, maxControlFramePayload) runtime.SetFinalizer(c, func(c *Conn) { @@ -215,7 +216,7 @@ func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { case c.setReadTimeout <- ctx: } - h, err := readHeader(c.headerBuf, c.br) + h, err := readHeader(c.readHeaderBuf, c.br) if err != nil { select { case <-c.closed: @@ -628,7 +629,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte } } - b2 := marshalHeader(h) + headerBytes := writeHeader(c.writeHeaderBuf, h) err := c.acquireLock(ctx, c.writeFrameLock) if err != nil { @@ -651,7 +652,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte default: } - err = xerrors.Errorf("failed to write frame: %w", err) + err = xerrors.Errorf("failed to write %v frame: %w", h.opcode, err) // We need to release the lock first before closing the connection to ensure // the lock can be acquired inside close to ensure no one can access c.bw. c.releaseLock(c.writeFrameLock) @@ -660,7 +661,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte return err } - _, err = c.bw.Write(b2) + _, err = c.bw.Write(headerBytes) if err != nil { return 0, writeErr(err) }