diff --git a/conn.go b/conn.go index 8690fb3b46804bcab5c82f868aec9c5d3dccf421..d7434a9d55cf94008bc87dfbe00f66720230c326 100644 --- a/conn.go +++ b/conn.go @@ -77,7 +77,7 @@ type Conn struct { closeMu sync.Mutex closing bool - pingCounter int32 + pingCounter atomic.Int64 activePingsMu sync.Mutex activePings map[string]chan<- struct{} } @@ -200,9 +200,9 @@ func (c *Conn) flate() bool { // // TCP Keepalives should suffice for most use cases. func (c *Conn) Ping(ctx context.Context) error { - p := atomic.AddInt32(&c.pingCounter, 1) + p := c.pingCounter.Add(1) - err := c.ping(ctx, strconv.Itoa(int(p))) + err := c.ping(ctx, strconv.FormatInt(p, 10)) if err != nil { return fmt.Errorf("failed to ping: %w", err) } diff --git a/internal/xsync/int64.go b/internal/xsync/int64.go deleted file mode 100644 index a0c402041568e8b8119809d5306e456e2beedda5..0000000000000000000000000000000000000000 --- a/internal/xsync/int64.go +++ /dev/null @@ -1,23 +0,0 @@ -package xsync - -import ( - "sync/atomic" -) - -// Int64 represents an atomic int64. -type Int64 struct { - // We do not use atomic.Load/StoreInt64 since it does not - // work on 32 bit computers but we need 64 bit integers. - i atomic.Value -} - -// Load loads the int64. -func (v *Int64) Load() int64 { - i, _ := v.i.Load().(int64) - return i -} - -// Store stores the int64. -func (v *Int64) Store(i int64) { - v.i.Store(i) -} diff --git a/netconn.go b/netconn.go index 86f7dadb58e0502e0e8b0b415c6a1c2a3f115522..b118e4d3e653ac2d142fc466cc516e33f850b3a5 100644 --- a/netconn.go +++ b/netconn.go @@ -68,7 +68,7 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { defer nc.writeMu.unlock() // Prevents future writes from writing until the deadline is reset. - atomic.StoreInt64(&nc.writeExpired, 1) + nc.writeExpired.Store(1) }) if !nc.writeTimer.Stop() { <-nc.writeTimer.C @@ -84,7 +84,7 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { defer nc.readMu.unlock() // Prevents future reads from reading until the deadline is reset. - atomic.StoreInt64(&nc.readExpired, 1) + nc.readExpired.Store(1) }) if !nc.readTimer.Stop() { <-nc.readTimer.C @@ -94,25 +94,22 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { } type netConn struct { - // These must be first to be aligned on 32 bit platforms. - // https://github.com/nhooyr/websocket/pull/438 - readExpired int64 - writeExpired int64 - c *Conn msgType MessageType - writeTimer *time.Timer - writeMu *mu - writeCtx context.Context - writeCancel context.CancelFunc - - readTimer *time.Timer - readMu *mu - readCtx context.Context - readCancel context.CancelFunc - readEOFed bool - reader io.Reader + writeTimer *time.Timer + writeMu *mu + writeExpired atomic.Int64 + writeCtx context.Context + writeCancel context.CancelFunc + + readTimer *time.Timer + readMu *mu + readExpired atomic.Int64 + readCtx context.Context + readCancel context.CancelFunc + readEOFed bool + reader io.Reader } var _ net.Conn = &netConn{} @@ -129,7 +126,7 @@ func (nc *netConn) Write(p []byte) (int, error) { nc.writeMu.forceLock() defer nc.writeMu.unlock() - if atomic.LoadInt64(&nc.writeExpired) == 1 { + if nc.writeExpired.Load() == 1 { return 0, fmt.Errorf("failed to write: %w", context.DeadlineExceeded) } @@ -157,7 +154,7 @@ func (nc *netConn) Read(p []byte) (int, error) { } func (nc *netConn) read(p []byte) (int, error) { - if atomic.LoadInt64(&nc.readExpired) == 1 { + if nc.readExpired.Load() == 1 { return 0, fmt.Errorf("failed to read: %w", context.DeadlineExceeded) } @@ -209,7 +206,7 @@ func (nc *netConn) SetDeadline(t time.Time) error { } func (nc *netConn) SetWriteDeadline(t time.Time) error { - atomic.StoreInt64(&nc.writeExpired, 0) + nc.writeExpired.Store(0) if t.IsZero() { nc.writeTimer.Stop() } else { @@ -223,7 +220,7 @@ func (nc *netConn) SetWriteDeadline(t time.Time) error { } func (nc *netConn) SetReadDeadline(t time.Time) error { - atomic.StoreInt64(&nc.readExpired, 0) + nc.readExpired.Store(0) if t.IsZero() { nc.readTimer.Stop() } else { diff --git a/read.go b/read.go index a59e71d9b9fde0d341b9dde9d2307030f8a1684e..20ed940881e0fcfbf0dc687371158f0bd55ed968 100644 --- a/read.go +++ b/read.go @@ -11,11 +11,11 @@ import ( "io" "net" "strings" + "sync/atomic" "time" "nhooyr.io/websocket/internal/errd" "nhooyr.io/websocket/internal/util" - "nhooyr.io/websocket/internal/xsync" ) // Reader reads from the connection until there is a WebSocket @@ -465,7 +465,7 @@ func (mr *msgReader) read(p []byte) (int, error) { type limitReader struct { c *Conn r io.Reader - limit xsync.Int64 + limit atomic.Int64 n int64 } diff --git a/ws_js.go b/ws_js.go index 02d61f28c13e6ddacf36126227ad2a3d48209ed3..6e58329e0197641b3745204057985be9a4d1ed1f 100644 --- a/ws_js.go +++ b/ws_js.go @@ -12,11 +12,11 @@ import ( "runtime" "strings" "sync" + "sync/atomic" "syscall/js" "nhooyr.io/websocket/internal/bpool" "nhooyr.io/websocket/internal/wsjs" - "nhooyr.io/websocket/internal/xsync" ) // opcode represents a WebSocket opcode. @@ -45,7 +45,7 @@ type Conn struct { ws wsjs.WebSocket // read limit for a message in bytes. - msgReadLimit xsync.Int64 + msgReadLimit atomic.Int64 closeReadMu sync.Mutex closeReadCtx context.Context