diff --git a/lib/fed/conn.go b/lib/fed/conn.go
index e9b640bc63f31b0f620f0dd905d98f7b7b5893ce..713f47a82cf86ddeecf5cae9fbba643e06ce7709 100644
--- a/lib/fed/conn.go
+++ b/lib/fed/conn.go
@@ -1,8 +1,10 @@
package fed
import (
+ "bufio"
"crypto/tls"
"encoding/binary"
+ "errors"
"io"
"net"
)
@@ -13,112 +15,69 @@ type Conn interface {
Close() error
}
-const pktBufSize = 4096
-
type netConn struct {
- conn net.Conn
- w io.Writer
-
- writeBuf net.Buffers
-
- pktBuf [pktBufSize]byte
- readBuf []byte
+ conn net.Conn
+ writer bufio.Writer
+ reader bufio.Reader
headerBuf [5]byte
}
func WrapNetConn(conn net.Conn) Conn {
- return &netConn{
+ c := &netConn{
conn: conn,
- w: conn,
}
+ c.writer.Reset(conn)
+ c.reader.Reset(conn)
+ return c
}
func (T *netConn) EnableSSLClient(config *tls.Config) error {
- if err := T.flush(); err != nil {
+ if err := T.writer.Flush(); err != nil {
return err
}
+ if T.reader.Buffered() > 0 {
+ return errors.New("expected empty read buffer")
+ }
sslConn := tls.Client(T.conn, config)
+ T.writer.Reset(sslConn)
+ T.reader.Reset(sslConn)
T.conn = sslConn
- T.w = sslConn
return sslConn.Handshake()
}
func (T *netConn) EnableSSLServer(config *tls.Config) error {
- if err := T.flush(); err != nil {
+ if err := T.writer.Flush(); err != nil {
return err
}
+ if T.reader.Buffered() > 0 {
+ return errors.New("expected empty read buffer")
+ }
sslConn := tls.Server(T.conn, config)
+ T.writer.Reset(sslConn)
+ T.reader.Reset(sslConn)
T.conn = sslConn
- T.w = sslConn
return sslConn.Handshake()
}
-func (T *netConn) flush() error {
- if len(T.writeBuf) == 0 {
- return nil
- }
-
- _, err := T.writeBuf.WriteTo(T.w)
- T.writeBuf = T.writeBuf[0:]
- return err
-}
-
-func (T *netConn) read(buf []byte) (n int, err error) {
- for {
- if len(T.readBuf) > 0 {
- cn := copy(buf, T.readBuf)
- buf = buf[cn:]
- T.readBuf = T.readBuf[cn:]
- n += cn
- }
-
- if len(buf) == 0 {
- return
- }
-
- if len(buf) > len(T.pktBuf) {
- var rn int
- rn, err = T.conn.Read(buf)
- n += rn
- if err != nil {
- return
- }
- buf = buf[rn:]
- } else {
- var rn int
- rn, err = T.conn.Read(T.pktBuf[:])
- if err != nil {
- return
- }
- T.readBuf = T.pktBuf[:rn]
- }
- }
-}
-
func (T *netConn) ReadByte() (byte, error) {
- if err := T.flush(); err != nil {
+ if err := T.writer.Flush(); err != nil {
return 0, err
}
- var b [1]byte
- _, err := T.read(b[:])
- if err != nil {
- return 0, err
- }
- return b[0], nil
+ return T.reader.ReadByte()
}
func (T *netConn) ReadPacket(typed bool) (Packet, error) {
- if err := T.flush(); err != nil {
+ if err := T.writer.Flush(); err != nil {
return nil, err
}
if typed {
- _, err := T.read(T.headerBuf[:])
+ _, err := T.reader.Read(T.headerBuf[:])
if err != nil {
return nil, err
}
} else {
- _, err := T.read(T.headerBuf[1:])
+ _, err := T.reader.Read(T.headerBuf[1:])
if err != nil {
return nil, err
}
@@ -130,7 +89,7 @@ func (T *netConn) ReadPacket(typed bool) (Packet, error) {
copy(p, T.headerBuf[:])
packet := Packet(p)
- _, err := T.read(packet.Payload())
+ _, err := T.reader.Read(packet.Payload())
if err != nil {
return nil, err
}
@@ -138,17 +97,16 @@ func (T *netConn) ReadPacket(typed bool) (Packet, error) {
}
func (T *netConn) WriteByte(b byte) error {
- T.writeBuf = append(T.writeBuf, []byte{b})
- return nil
+ return T.writer.WriteByte(b)
}
func (T *netConn) WritePacket(packet Packet) error {
- T.writeBuf = append(T.writeBuf, packet.Bytes())
- return nil
+ _, err := T.writer.Write(packet.Bytes())
+ return err
}
func (T *netConn) Close() error {
- if err := T.flush(); err != nil {
+ if err := T.writer.Flush(); err != nil {
return err
}
return T.conn.Close()
diff --git a/lib/gat/pool/pool.go b/lib/gat/pool/pool.go
index de8fd0bae5aa458776f0e989db037d42765d19c6..4f53b811b8a741d212f92e47984870ea93373cf0 100644
--- a/lib/gat/pool/pool.go
+++ b/lib/gat/pool/pool.go
@@ -219,7 +219,7 @@ func (T *Pool) removeServerL1(server *Server) {
T.pooler.DeleteServer(server.GetID())
_ = server.GetConn().Close()
if T.serversByRecipe != nil {
- T.serversByRecipe[server.GetRecipe()] = slices.Remove(T.serversByRecipe[server.GetRecipe()], server)
+ T.serversByRecipe[server.GetRecipe()] = slices.Delete(T.serversByRecipe[server.GetRecipe()], server)
}
}
diff --git a/lib/util/rbtree/rbtree.go b/lib/util/rbtree/rbtree.go
index 873d23d2313535c07194a5e19ce59e720806bdd8..a2f4a7e2f9b56aaa1a9117c5007708897645952f 100644
--- a/lib/util/rbtree/rbtree.go
+++ b/lib/util/rbtree/rbtree.go
@@ -14,6 +14,7 @@ type RBTree[K order, V any] struct {
}
func (T *RBTree[K, V]) free(n *node[K, V]) {
+ *n = node[K, V]{}
T.pool = append(T.pool, n)
}
@@ -21,7 +22,6 @@ func (T *RBTree[K, V]) alloc() *node[K, V] {
if len(T.pool) > 0 {
v := T.pool[len(T.pool)-1]
T.pool = T.pool[:len(T.pool)-1]
- *v = node[K, V]{}
return v
}
return new(node[K, V])
diff --git a/lib/util/ring/ring.go b/lib/util/ring/ring.go
index 1753b3d999d8d1d7e392d705771ca7459d4d9042..51aeb9e22da4433b20a20b3e098fbb6f1d4509b3 100644
--- a/lib/util/ring/ring.go
+++ b/lib/util/ring/ring.go
@@ -104,6 +104,7 @@ func (r *Ring[T]) PopFront() (T, bool) {
}
front := r.buf[r.head]
+ r.buf[r.head] = *new(T)
r.incHead()
return front, true
}
@@ -114,7 +115,9 @@ func (r *Ring[T]) PopBack() (T, bool) {
}
r.decTail()
- return r.buf[r.tail], true
+ back := r.buf[r.tail]
+ r.buf[r.tail] = *new(T)
+ return back, true
}
func (r *Ring[T]) Clear() {
diff --git a/lib/util/slices/remove.go b/lib/util/slices/remove.go
index 7e38812ac877ffc90d67bc98e6de764f2c8c9309..e0c239a653295225d72f58619285524a12bd6131 100644
--- a/lib/util/slices/remove.go
+++ b/lib/util/slices/remove.go
@@ -14,3 +14,16 @@ func Remove[T comparable](slice []T, item T) []T {
return slice
}
+
+// Delete is similar to Remove but leaves a *new(T) in the old slice, allowing the value to be GC'd
+func Delete[T comparable](slice []T, item T) []T {
+ for i, s := range slice {
+ if s == item {
+ copy(slice[i:], slice[i+1:])
+ slice[len(slice)-1] = *new(T)
+ return slice[:len(slice)-1]
+ }
+ }
+
+ return slice
+}