From afe94af9aa98d974da157da151c15817bc9d85d0 Mon Sep 17 00:00:00 2001 From: Jacob Date: Tue, 9 Apr 2024 20:29:44 +0200 Subject: [PATCH 1/2] Use new atomic types from Go 1.19 This is a cleaner solution for the fix in #438 thanks to the fact that Go 1.19 now is the default and the atomic.Int64 types are automatically aligned correctly on 32 bit systems. Using this also means that xsync.Int64 can be removed. The new atomic.Int64 type solves the issue and should be quite a lot faster as it avoids the interface conversion. --- conn.go | 4 ++-- internal/xsync/int64.go | 23 ----------------------- netconn.go | 41 +++++++++++++++++++---------------------- read.go | 4 ++-- ws_js.go | 4 ++-- 5 files changed, 25 insertions(+), 51 deletions(-) delete mode 100644 internal/xsync/int64.go diff --git a/conn.go b/conn.go index 8690fb3b..48bc510a 100644 --- a/conn.go +++ b/conn.go @@ -77,7 +77,7 @@ type Conn struct { closeMu sync.Mutex closing bool - pingCounter int32 + pingCounter atomic.Int32 activePingsMu sync.Mutex activePings map[string]chan<- struct{} } @@ -200,7 +200,7 @@ func (c *Conn) flate() bool { // // TCP Keepalives should suffice for most use cases. func (c *Conn) Ping(ctx context.Context) error { - p := atomic.AddInt32(&c.pingCounter, 1) + p := c.pingCounter.Add(1) err := c.ping(ctx, strconv.Itoa(int(p))) if err != nil { diff --git a/internal/xsync/int64.go b/internal/xsync/int64.go deleted file mode 100644 index a0c40204..00000000 --- a/internal/xsync/int64.go +++ /dev/null @@ -1,23 +0,0 @@ -package xsync - -import ( - "sync/atomic" -) - -// Int64 represents an atomic int64. -type Int64 struct { - // We do not use atomic.Load/StoreInt64 since it does not - // work on 32 bit computers but we need 64 bit integers. - i atomic.Value -} - -// Load loads the int64. -func (v *Int64) Load() int64 { - i, _ := v.i.Load().(int64) - return i -} - -// Store stores the int64. -func (v *Int64) Store(i int64) { - v.i.Store(i) -} diff --git a/netconn.go b/netconn.go index 86f7dadb..b118e4d3 100644 --- a/netconn.go +++ b/netconn.go @@ -68,7 +68,7 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { defer nc.writeMu.unlock() // Prevents future writes from writing until the deadline is reset. - atomic.StoreInt64(&nc.writeExpired, 1) + nc.writeExpired.Store(1) }) if !nc.writeTimer.Stop() { <-nc.writeTimer.C @@ -84,7 +84,7 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { defer nc.readMu.unlock() // Prevents future reads from reading until the deadline is reset. - atomic.StoreInt64(&nc.readExpired, 1) + nc.readExpired.Store(1) }) if !nc.readTimer.Stop() { <-nc.readTimer.C @@ -94,25 +94,22 @@ func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { } type netConn struct { - // These must be first to be aligned on 32 bit platforms. - // https://github.com/nhooyr/websocket/pull/438 - readExpired int64 - writeExpired int64 - c *Conn msgType MessageType - writeTimer *time.Timer - writeMu *mu - writeCtx context.Context - writeCancel context.CancelFunc - - readTimer *time.Timer - readMu *mu - readCtx context.Context - readCancel context.CancelFunc - readEOFed bool - reader io.Reader + writeTimer *time.Timer + writeMu *mu + writeExpired atomic.Int64 + writeCtx context.Context + writeCancel context.CancelFunc + + readTimer *time.Timer + readMu *mu + readExpired atomic.Int64 + readCtx context.Context + readCancel context.CancelFunc + readEOFed bool + reader io.Reader } var _ net.Conn = &netConn{} @@ -129,7 +126,7 @@ func (nc *netConn) Write(p []byte) (int, error) { nc.writeMu.forceLock() defer nc.writeMu.unlock() - if atomic.LoadInt64(&nc.writeExpired) == 1 { + if nc.writeExpired.Load() == 1 { return 0, fmt.Errorf("failed to write: %w", context.DeadlineExceeded) } @@ -157,7 +154,7 @@ func (nc *netConn) Read(p []byte) (int, error) { } func (nc *netConn) read(p []byte) (int, error) { - if atomic.LoadInt64(&nc.readExpired) == 1 { + if nc.readExpired.Load() == 1 { return 0, fmt.Errorf("failed to read: %w", context.DeadlineExceeded) } @@ -209,7 +206,7 @@ func (nc *netConn) SetDeadline(t time.Time) error { } func (nc *netConn) SetWriteDeadline(t time.Time) error { - atomic.StoreInt64(&nc.writeExpired, 0) + nc.writeExpired.Store(0) if t.IsZero() { nc.writeTimer.Stop() } else { @@ -223,7 +220,7 @@ func (nc *netConn) SetWriteDeadline(t time.Time) error { } func (nc *netConn) SetReadDeadline(t time.Time) error { - atomic.StoreInt64(&nc.readExpired, 0) + nc.readExpired.Store(0) if t.IsZero() { nc.readTimer.Stop() } else { diff --git a/read.go b/read.go index a59e71d9..20ed9408 100644 --- a/read.go +++ b/read.go @@ -11,11 +11,11 @@ import ( "io" "net" "strings" + "sync/atomic" "time" "nhooyr.io/websocket/internal/errd" "nhooyr.io/websocket/internal/util" - "nhooyr.io/websocket/internal/xsync" ) // Reader reads from the connection until there is a WebSocket @@ -465,7 +465,7 @@ func (mr *msgReader) read(p []byte) (int, error) { type limitReader struct { c *Conn r io.Reader - limit xsync.Int64 + limit atomic.Int64 n int64 } diff --git a/ws_js.go b/ws_js.go index 02d61f28..6e58329e 100644 --- a/ws_js.go +++ b/ws_js.go @@ -12,11 +12,11 @@ import ( "runtime" "strings" "sync" + "sync/atomic" "syscall/js" "nhooyr.io/websocket/internal/bpool" "nhooyr.io/websocket/internal/wsjs" - "nhooyr.io/websocket/internal/xsync" ) // opcode represents a WebSocket opcode. @@ -45,7 +45,7 @@ type Conn struct { ws wsjs.WebSocket // read limit for a message in bytes. - msgReadLimit xsync.Int64 + msgReadLimit atomic.Int64 closeReadMu sync.Mutex closeReadCtx context.Context From cfde4a5ebfd40869983e926c9098e12f82761740 Mon Sep 17 00:00:00 2001 From: Jacob Date: Thu, 11 Apr 2024 10:35:07 +0200 Subject: [PATCH 2/2] Use Int64 instead of Int32 for counting pings --- conn.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/conn.go b/conn.go index 48bc510a..d7434a9d 100644 --- a/conn.go +++ b/conn.go @@ -77,7 +77,7 @@ type Conn struct { closeMu sync.Mutex closing bool - pingCounter atomic.Int32 + pingCounter atomic.Int64 activePingsMu sync.Mutex activePings map[string]chan<- struct{} } @@ -202,7 +202,7 @@ func (c *Conn) flate() bool { func (c *Conn) Ping(ctx context.Context) error { p := c.pingCounter.Add(1) - err := c.ping(ctx, strconv.Itoa(int(p))) + err := c.ping(ctx, strconv.FormatInt(p, 10)) if err != nil { return fmt.Errorf("failed to ping: %w", err) }