From 0fa48a57a1e1c5d8d82ec7eedc5cc164e6d80a35 Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Wed, 12 Jun 2019 12:06:01 -0400 Subject: [PATCH] Reduce Reader allocation by 8 bytes Both Reader and Writer now will only ever allocate 16 bytes for their entire usage :) This is the minimum possible while still preventing misuse of a EOFed Reader or a closed Writer. --- websocket.go | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/websocket.go b/websocket.go index 2efc485..c5a8f68 100644 --- a/websocket.go +++ b/websocket.go @@ -36,9 +36,6 @@ type Conn struct { closer io.Closer client bool - // read limit for a message in bytes. - msgReadLimit int64 - closeOnce sync.Once closeErr error closed chan struct{} @@ -50,10 +47,13 @@ type Conn struct { writeFrameLock chan struct{} writeHeaderBuf []byte writeHeader *header + // read limit for a message in bytes. + msgReadLimit int64 // messageWriter state. writeMsgOpcode opcode writeMsgCtx context.Context + readMsgLeft int64 // Used to ensure the previous reader is read till EOF before allowing // a new one. @@ -371,10 +371,10 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { c.readMsgHeader = h c.readFrameEOF = false c.readMaskPos = 0 + c.readMsgLeft = c.msgReadLimit r := &messageReader{ - c: c, - left: c.msgReadLimit, + c: c, } c.previousReader = r return MessageType(h.opcode), r, nil @@ -382,9 +382,8 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { // messageReader enables reading a data frame from the WebSocket connection. type messageReader struct { - c *Conn - left int64 - eof bool + c *Conn + eof bool } // Read reads as many bytes as possible into p. @@ -406,14 +405,14 @@ func (r *messageReader) read(p []byte) (int, error) { return 0, xerrors.Errorf("cannot use EOFed reader") } - if r.left <= 0 { + if r.c.readMsgLeft <= 0 { err := xerrors.Errorf("read limited at %v bytes", r.c.msgReadLimit) r.c.Close(StatusMessageTooBig, err.Error()) return 0, err } - if int64(len(p)) > r.left { - p = p[:r.left] + if int64(len(p)) > r.c.readMsgLeft { + p = p[:r.c.readMsgLeft] } if r.c.readFrameEOF { @@ -441,7 +440,7 @@ func (r *messageReader) read(p []byte) (int, error) { n, err := r.c.readFramePayload(r.c.readMsgCtx, p) h.payloadLength -= int64(n) - r.left -= int64(n) + r.c.readMsgLeft -= int64(n) if h.masked { r.c.readMaskPos = fastXOR(h.maskKey, r.c.readMaskPos, p) } -- GitLab