From 63f27e246b5bad3c2d0cc674be91fa03bae40d36 Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Tue, 20 Aug 2019 17:46:00 -0400 Subject: [PATCH] Reduce Reader/Writer allocations Closes #116 --- websocket.go | 43 +++++++++++++++++++++++++++---------------- 1 file changed, 27 insertions(+), 16 deletions(-) diff --git a/websocket.go b/websocket.go index ee61f54..5942b68 100644 --- a/websocket.go +++ b/websocket.go @@ -56,6 +56,8 @@ type Conn struct { // read limit for a message in bytes. msgReadLimit int64 + // Used to ensure a previous writer is not used after being closed. + activeWriter *messageWriter // messageWriter state. writeMsgOpcode opcode writeMsgCtx context.Context @@ -63,7 +65,7 @@ type Conn struct { // Used to ensure the previous reader is read till EOF before allowing // a new one. - previousReader *messageReader + activeReader *messageReader // readFrameLock is acquired to read from bw. readFrameLock chan struct{} readClosed int64 @@ -358,7 +360,7 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { } func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { - if c.previousReader != nil && !c.readFrameEOF { + if c.activeReader != nil && !c.readFrameEOF { // The only way we know for sure the previous reader is not yet complete is // if there is an active frame not yet fully read. // Otherwise, a user may have read the last byte but not the EOF if the EOF @@ -371,7 +373,7 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { return 0, nil, err } - if c.previousReader != nil && !c.previousReader.eof { + if c.activeReader != nil && !c.activeReader.eof() { if h.opcode != opContinuation { err := xerrors.Errorf("received new data message without finishing the previous message") c.Close(StatusProtocolError, err.Error()) @@ -382,7 +384,7 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { return 0, nil, xerrors.Errorf("previous message not read to completion") } - c.previousReader.eof = true + c.activeReader = nil h, err = c.readTillMsg(ctx) if err != nil { @@ -403,7 +405,7 @@ func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) { r := &messageReader{ c: c, } - c.previousReader = r + c.activeReader = r return MessageType(h.opcode), r, nil } @@ -430,8 +432,11 @@ func (c *Conn) CloseRead(ctx context.Context) context.Context { // messageReader enables reading a data frame from the WebSocket connection. type messageReader struct { - c *Conn - eof bool + c *Conn +} + +func (r *messageReader) eof() bool { + return r.c.activeReader != r } // Read reads as many bytes as possible into p. @@ -449,7 +454,7 @@ func (r *messageReader) Read(p []byte) (int, error) { } func (r *messageReader) read(p []byte) (int, error) { - if r.eof { + if r.eof() { return 0, xerrors.Errorf("cannot use EOFed reader") } @@ -502,7 +507,7 @@ func (r *messageReader) read(p []byte) (int, error) { r.c.readFrameEOF = true if h.fin { - r.eof = true + r.c.activeReader = nil return n, io.EOF } } @@ -593,9 +598,11 @@ func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, err } c.writeMsgCtx = ctx c.writeMsgOpcode = opcode(typ) - return &messageWriter{ + w := &messageWriter{ c: c, - }, nil + } + c.activeWriter = w + return w, nil } // Write is a convenience method to write a message to the connection. @@ -622,8 +629,11 @@ func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error // messageWriter enables writing to a WebSocket connection. type messageWriter struct { - c *Conn - closed bool + c *Conn +} + +func (w *messageWriter) closed() bool { + return w != w.c.activeWriter } // Write writes the given bytes to the WebSocket connection. @@ -636,7 +646,7 @@ func (w *messageWriter) Write(p []byte) (int, error) { } func (w *messageWriter) write(p []byte) (int, error) { - if w.closed { + if w.closed() { return 0, xerrors.Errorf("cannot use closed writer") } n, err := w.c.writeFrame(w.c.writeMsgCtx, false, w.c.writeMsgOpcode, p) @@ -658,16 +668,17 @@ func (w *messageWriter) Close() error { } func (w *messageWriter) close() error { - if w.closed { + if w.closed() { return xerrors.Errorf("cannot use closed writer") } - w.closed = true + w.closed() _, err := w.c.writeFrame(w.c.writeMsgCtx, true, w.c.writeMsgOpcode, nil) if err != nil { return xerrors.Errorf("failed to write fin frame: %w", err) } + w.c.activeWriter = nil w.c.releaseLock(w.c.writeMsgLock) return nil } -- GitLab