From 9fc9f7ab6742008fb936186272696c9933d9c51b Mon Sep 17 00:00:00 2001 From: Anmol Sethi <hi@nhooyr.io> Date: Mon, 23 Sep 2019 16:53:37 -0500 Subject: [PATCH] Ensure message order with a buffer --- conn.go | 6 ----- conn_common.go | 6 +++++ websocket_js.go | 62 ++++++++++++++++++++++++++++++++++++++----------- 3 files changed, 54 insertions(+), 20 deletions(-) diff --git a/conn.go b/conn.go index 20dbece..3d7d574 100644 --- a/conn.go +++ b/conn.go @@ -120,12 +120,6 @@ func (c *Conn) Subprotocol() string { return c.subprotocol } -func (c *Conn) setCloseErr(err error) { - c.closeErrOnce.Do(func() { - c.closeErr = fmt.Errorf("websocket closed: %w", err) - }) -} - func (c *Conn) close(err error) { c.closeOnce.Do(func() { runtime.SetFinalizer(c, nil) diff --git a/conn_common.go b/conn_common.go index 1429b47..ae0fe55 100644 --- a/conn_common.go +++ b/conn_common.go @@ -202,3 +202,9 @@ func (c *Conn) CloseRead(ctx context.Context) context.Context { func (c *Conn) SetReadLimit(n int64) { c.msgReadLimit = n } + +func (c *Conn) setCloseErr(err error) { + c.closeErrOnce.Do(func() { + c.closeErr = fmt.Errorf("websocket closed: %w", err) + }) +} diff --git a/websocket_js.go b/websocket_js.go index 4ed49d9..3822797 100644 --- a/websocket_js.go +++ b/websocket_js.go @@ -23,29 +23,32 @@ type Conn struct { msgReadLimit int64 - readClosed int64 - closeOnce sync.Once - closed chan struct{} - closeErr error + readClosed int64 + closeOnce sync.Once + closed chan struct{} + closeErrOnce sync.Once + closeErr error releaseOnClose func() releaseOnMessage func() - readch chan wsjs.MessageEvent + readSignal chan struct{} + readBufMu sync.Mutex + readBuf []wsjs.MessageEvent } func (c *Conn) close(err error) { c.closeOnce.Do(func() { runtime.SetFinalizer(c, nil) - c.closeErr = fmt.Errorf("websocket closed: %w", err) + c.setCloseErr(err) close(c.closed) }) } func (c *Conn) init() { c.closed = make(chan struct{}) - c.readch = make(chan wsjs.MessageEvent, 1) + c.readSignal = make(chan struct{}, 1) c.msgReadLimit = 32768 c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) { @@ -61,15 +64,28 @@ func (c *Conn) init() { }) c.releaseOnMessage = c.ws.OnMessage(func(e wsjs.MessageEvent) { - c.readch <- e + c.readBufMu.Lock() + defer c.readBufMu.Unlock() + + c.readBuf = append(c.readBuf, e) + + // Lets the read goroutine know there is definitely something in readBuf. + select { + case c.readSignal <- struct{}{}: + default: + } }) runtime.SetFinalizer(c, func(c *Conn) { - c.ws.Close(int(StatusInternalError), "") - c.close(errors.New("connection garbage collected")) + c.setCloseErr(errors.New("connection garbage collected")) + c.closeWithInternal() }) } +func (c *Conn) closeWithInternal() { + c.Close(StatusInternalError, "something went wrong") +} + // Read attempts to read a message from the connection. // The maximum time spent waiting is bounded by the context. func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { @@ -89,16 +105,32 @@ func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { } func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) { - var me wsjs.MessageEvent select { case <-ctx.Done(): c.Close(StatusPolicyViolation, "read timed out") return 0, nil, ctx.Err() - case me = <-c.readch: + case <-c.readSignal: case <-c.closed: return 0, nil, c.closeErr } + c.readBufMu.Lock() + defer c.readBufMu.Unlock() + + me := c.readBuf[0] + // We copy the messages forward and decrease the size + // of the slice to avoid reallocating. + copy(c.readBuf, c.readBuf[1:]) + c.readBuf = c.readBuf[:len(c.readBuf)-1] + + if len(c.readBuf) > 0 { + // Next time we read, we'll grab the message. + select { + case c.readSignal <- struct{}{}: + default: + } + } + switch p := me.Data.(type) { case string: return MessageText, []byte(p), nil @@ -118,8 +150,10 @@ func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { // to match the Go API. It can only error if the message type // is unexpected or the passed bytes contain invalid UTF-8 for // MessageText. - c.Close(StatusInternalError, "something went wrong") - return fmt.Errorf("failed to write: %w", err) + err := fmt.Errorf("failed to write: %w", err) + c.setCloseErr(err) + c.closeWithInternal() + return err } return nil } -- GitLab