From 69d6ba4d4cbca2d622339e199fc18b20196efd28 Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Tue, 5 Sep 2023 00:34:03 -0500
Subject: [PATCH] a

---
 lib/gat/pool/client.go |  67 ++++++++
 lib/gat/pool/pool.go   | 338 +++++++++++++++++------------------------
 lib/gat/pool/server.go | 101 ++++++++++++
 lib/util/maps/clear.go |   7 +
 4 files changed, 313 insertions(+), 200 deletions(-)
 create mode 100644 lib/gat/pool/client.go
 create mode 100644 lib/gat/pool/server.go
 create mode 100644 lib/util/maps/clear.go

diff --git a/lib/gat/pool/client.go b/lib/gat/pool/client.go
new file mode 100644
index 00000000..0ad34b6c
--- /dev/null
+++ b/lib/gat/pool/client.go
@@ -0,0 +1,67 @@
+package pool
+
+import (
+	"github.com/google/uuid"
+	"pggat2/lib/fed"
+	"sync"
+	"time"
+)
+
+type Client struct {
+	conn       fed.Conn
+	backendKey [8]byte
+
+	metrics ClientMetrics
+	mu      sync.RWMutex
+}
+
+func NewClient(
+	conn fed.Conn,
+	backendKey [8]byte,
+) *Client {
+	return &Client{
+		conn:       conn,
+		backendKey: backendKey,
+
+		metrics: MakeClientMetrics(),
+	}
+}
+
+func (T *Client) GetConn() fed.Conn {
+	return T.conn
+}
+
+func (T *Client) GetBackendKey() [8]byte {
+	return T.backendKey
+}
+
+// SetPeer replaces the peer. Returns the old peer
+func (T *Client) SetPeer(peer uuid.UUID) uuid.UUID {
+	T.mu.Lock()
+	defer T.mu.Unlock()
+
+	old := T.metrics.Peer
+	T.metrics.SetPeer(peer)
+	return old
+}
+
+func (T *Client) GetPeer() uuid.UUID {
+	T.mu.RLock()
+	defer T.mu.RUnlock()
+
+	return T.metrics.Peer
+}
+
+func (T *Client) GetConnection() (uuid.UUID, time.Time) {
+	T.mu.RLock()
+	defer T.mu.RUnlock()
+
+	return T.metrics.Peer, T.metrics.Since
+}
+
+func (T *Client) ReadMetrics(metrics *ClientMetrics) {
+	T.mu.RLock()
+	defer T.mu.RUnlock()
+
+	panic("TODO(garet)")
+}
diff --git a/lib/gat/pool/pool.go b/lib/gat/pool/pool.go
index d0ed7f1f..430476c8 100644
--- a/lib/gat/pool/pool.go
+++ b/lib/gat/pool/pool.go
@@ -1,7 +1,7 @@
 package pool
 
 import (
-	"sync"
+	"pggat2/lib/util/maps"
 	"sync/atomic"
 	"time"
 
@@ -23,39 +23,17 @@ import (
 	"pggat2/lib/util/strutil"
 )
 
-type poolServer struct {
-	conn   fed.Conn
-	accept backends.AcceptParams
-	recipe string
-
-	// middlewares
-	psServer  *ps.Server
-	eqpServer *eqp.Server
-
-	metrics ServerMetrics
-	mu      sync.Mutex
-}
-
 type poolRecipe struct {
 	recipe Recipe
 	count  atomic.Int64
 }
 
-type poolClient struct {
-	conn fed.Conn
-	key  [8]byte
-
-	metrics ClientMetrics
-	mu      sync.Mutex
-}
-
 type Pool struct {
 	options Options
 
-	recipes map[string]*poolRecipe
-	servers map[uuid.UUID]*poolServer
-	clients map[uuid.UUID]*poolClient
-	mu      sync.Mutex
+	recipes maps.RWLocked[string, *poolRecipe]
+	servers maps.RWLocked[uuid.UUID, *Server]
+	clients maps.RWLocked[uuid.UUID, *Client]
 }
 
 func NewPool(options Options) *Pool {
@@ -71,26 +49,20 @@ func NewPool(options Options) *Pool {
 }
 
 func (T *Pool) idlest() (idlest uuid.UUID, idle time.Time) {
-	T.mu.Lock()
-	defer T.mu.Unlock()
-
-	for serverID, server := range T.servers {
-		func() {
-			server.mu.Lock()
-			defer server.mu.Unlock()
-
-			if server.metrics.Peer != uuid.Nil {
-				return
-			}
+	T.servers.Range(func(serverID uuid.UUID, server *Server) bool {
+		peer, since := server.GetConnection()
+		if peer != uuid.Nil {
+			return true
+		}
 
-			if idle != (time.Time{}) && server.metrics.Since.After(idle) {
-				return
-			}
+		if idle != (time.Time{}) && since.After(idle) {
+			return true
+		}
 
-			idlest = serverID
-			idle = server.metrics.Since
-		}()
-	}
+		idlest = serverID
+		idle = since
+		return true
+	})
 
 	return
 }
@@ -121,7 +93,10 @@ func (T *Pool) GetCredentials() auth.Credentials {
 }
 
 func (T *Pool) _scaleUpRecipe(name string) {
-	r := T.recipes[name]
+	r, ok := T.recipes.Load(name)
+	if !ok {
+		return
+	}
 
 	server, params, err := r.recipe.Dialer.Dial()
 	if err != nil {
@@ -129,11 +104,6 @@ func (T *Pool) _scaleUpRecipe(name string) {
 		return
 	}
 
-	serverID := uuid.New()
-	if T.servers == nil {
-		T.servers = make(map[uuid.UUID]*poolServer)
-	}
-
 	var middlewares []middleware.Middleware
 
 	var psServer *ps.Server
@@ -157,28 +127,32 @@ func (T *Pool) _scaleUpRecipe(name string) {
 		)
 	}
 
-	T.servers[serverID] = &poolServer{
-		conn:   server,
-		accept: params,
-		recipe: name,
-
-		psServer:  psServer,
-		eqpServer: eqpServer,
-
-		metrics: MakeServerMetrics(),
-	}
+	r.count.Add(1)
+	serverID := uuid.New()
+	T.servers.Store(serverID, NewServer(
+		server,
+		params.BackendKey,
+		params.InitialParameters,
+		name,
+		psServer,
+		eqpServer,
+	))
 	T.options.Pooler.AddServer(serverID)
 }
 
 func (T *Pool) AddRecipe(name string, recipe Recipe) {
-	T.mu.Lock()
-	defer T.mu.Unlock()
-
-	if T.recipes == nil {
-		T.recipes = make(map[string]*poolRecipe)
-	}
-	T.recipes[name] = &poolRecipe{
+	_, hasOld := T.recipes.Swap(name, &poolRecipe{
 		recipe: recipe,
+	})
+	if hasOld {
+		T.servers.Range(func(serverID uuid.UUID, server *Server) bool {
+			if server.GetRecipe() == name {
+				_ = server.GetConn().Close()
+				T.options.Pooler.RemoveServer(serverID)
+				T.servers.Delete(serverID)
+			}
+			return true
+		})
 	}
 
 	for i := 0; i < recipe.MinConnections; i++ {
@@ -187,43 +161,40 @@ func (T *Pool) AddRecipe(name string, recipe Recipe) {
 }
 
 func (T *Pool) RemoveRecipe(name string) {
-	T.mu.Lock()
-	defer T.mu.Unlock()
-
-	delete(T.recipes, name)
+	T.recipes.Delete(name)
 
 	// close all servers with this recipe
-	for id, server := range T.servers {
-		if server.recipe == name {
-			_ = server.conn.Close()
-			T.options.Pooler.RemoveServer(id)
-			delete(T.servers, id)
+
+	T.servers.Range(func(serverID uuid.UUID, server *Server) bool {
+		if server.GetRecipe() == name {
+			_ = server.GetConn().Close()
+			T.options.Pooler.RemoveServer(serverID)
+			T.servers.Delete(serverID)
 		}
-	}
+		return true
+	})
 }
 
-func (T *Pool) scaleUp() {
-	T.mu.Lock()
-	defer T.mu.Unlock()
-
-	for name, r := range T.recipes {
+func (T *Pool) ScaleUp() {
+	T.recipes.Range(func(name string, r *poolRecipe) bool {
 		if r.recipe.MaxConnections == 0 || int(r.count.Load()) < r.recipe.MaxConnections {
 			T._scaleUpRecipe(name)
-			return
+			return false
 		}
-	}
 
-	log.Println("warning: tried to scale up pool but no space was available")
+		return true
+	})
 }
 
-func (T *Pool) syncInitialParameters(
+func syncInitialParameters(
+	trackedParameters []strutil.CIString,
 	client fed.Conn,
 	clientParams map[strutil.CIString]string,
 	server fed.Conn,
 	serverParams map[strutil.CIString]string,
 ) (clientErr, serverErr error) {
 	for key, value := range clientParams {
-		setServer := slices.Contains(T.options.TrackedParameters, key)
+		setServer := slices.Contains(trackedParameters, key)
 
 		// skip already set params
 		if serverParams[key] == value {
@@ -308,7 +279,7 @@ func (T *Pool) Serve(
 	defer T.removeClient(clientID)
 
 	var serverID uuid.UUID
-	var server *poolServer
+	var server *Server
 
 	defer func() {
 		if serverID != uuid.Nil {
@@ -328,17 +299,17 @@ func (T *Pool) Serve(
 
 			switch T.options.ParameterStatusSync {
 			case ParameterStatusSyncDynamic:
-				clientErr, serverErr = ps.Sync(T.options.TrackedParameters, client, psClient, server.conn, server.psServer)
+				clientErr, serverErr = ps.Sync(T.options.TrackedParameters, client, psClient, server.GetConn(), server.GetPSServer())
 			case ParameterStatusSyncInitial:
-				clientErr, serverErr = T.syncInitialParameters(client, accept.InitialParameters, server.conn, server.accept.InitialParameters)
+				clientErr, serverErr = syncInitialParameters(T.options.TrackedParameters, client, accept.InitialParameters, server.GetConn(), server.GetInitialParameters())
 			}
 
 			if T.options.ExtendedQuerySync {
-				server.eqpServer.SetClient(eqpClient)
+				server.GetEQPServer().SetClient(eqpClient)
 			}
 		}
 		if clientErr == nil && serverErr == nil {
-			clientErr, serverErr = bouncers.Bounce(client, server.conn, packet)
+			clientErr, serverErr = bouncers.Bounce(client, server.GetConn(), packet)
 		}
 		if serverErr != nil {
 			T.removeServer(serverID)
@@ -350,6 +321,8 @@ func (T *Pool) Serve(
 				T.releaseServer(serverID)
 				serverID = uuid.Nil
 				server = nil
+			} else {
+				T.transactionComplete(serverID)
 			}
 		}
 
@@ -360,167 +333,132 @@ func (T *Pool) Serve(
 }
 
 func (T *Pool) addClient(client fed.Conn, key [8]byte) uuid.UUID {
-	T.mu.Lock()
-	defer T.mu.Unlock()
-
 	clientID := uuid.New()
 
-	if T.clients == nil {
-		T.clients = make(map[uuid.UUID]*poolClient)
-	}
-	T.clients[clientID] = &poolClient{
-		conn: client,
-		key:  key,
-
-		metrics: MakeClientMetrics(),
-	}
+	T.clients.Store(clientID, NewClient(
+		client,
+		key,
+	))
 	T.options.Pooler.AddClient(clientID)
 	return clientID
 }
 
 func (T *Pool) removeClient(clientID uuid.UUID) {
-	T.mu.Lock()
-	defer T.mu.Unlock()
-
-	delete(T.clients, clientID)
+	T.clients.Delete(clientID)
 	T.options.Pooler.RemoveClient(clientID)
 }
 
-func (T *Pool) acquireServer(clientID uuid.UUID) (serverID uuid.UUID, server *poolServer) {
+func (T *Pool) acquireServer(clientID uuid.UUID) (serverID uuid.UUID, server *Server) {
 	serverID = T.options.Pooler.Acquire(clientID, SyncModeNonBlocking)
 	if serverID == uuid.Nil {
-		go T.scaleUp()
+		go T.ScaleUp()
 		serverID = T.options.Pooler.Acquire(clientID, SyncModeBlocking)
 	}
 
-	T.mu.Lock()
-	defer T.mu.Unlock()
-	server = T.servers[serverID]
-	client := T.clients[clientID]
+	server, _ = T.servers.Load(serverID)
+	client, _ := T.clients.Load(clientID)
 	if server != nil {
-		server.mu.Lock()
-		defer server.mu.Unlock()
-		server.metrics.SetPeer(clientID)
+		server.SetPeer(clientID)
 	}
 	if client != nil {
-		client.mu.Lock()
-		defer client.mu.Unlock()
-		client.metrics.SetPeer(serverID)
+		client.SetPeer(serverID)
 	}
 	return
 }
 
 func (T *Pool) releaseServer(serverID uuid.UUID) {
-	T.mu.Lock()
-	defer T.mu.Unlock()
-
-	server := T.servers[serverID]
+	server, _ := T.servers.Load(serverID)
 	if server == nil {
 		return
 	}
 
-	var clientID uuid.UUID
-
-	func() {
-		server.mu.Lock()
-		defer server.mu.Unlock()
-		clientID = server.metrics.Peer
-		server.metrics.SetPeer(uuid.Nil)
-	}()
+	clientID := server.SetPeer(uuid.Nil)
 
 	if clientID != uuid.Nil {
-		client := T.clients[clientID]
+		client, _ := T.clients.Load(clientID)
 		if client != nil {
-			func() {
-				client.mu.Lock()
-				defer client.mu.Unlock()
-				client.metrics.SetPeer(uuid.Nil)
-			}()
+			client.SetPeer(uuid.Nil)
 		}
 	}
 
 	if T.options.ServerResetQuery != "" {
-		err := backends.QueryString(new(backends.Context), server.conn, T.options.ServerResetQuery)
+		err := backends.QueryString(new(backends.Context), server.GetConn(), T.options.ServerResetQuery)
 		if err != nil {
-			T._removeServer(serverID)
+			T.removeServer(serverID)
 			return
 		}
 	}
 	T.options.Pooler.Release(serverID)
 }
 
-func (T *Pool) _removeServer(serverID uuid.UUID) {
-	if server, ok := T.servers[serverID]; ok {
-		_ = server.conn.Close()
-		delete(T.servers, serverID)
-		T.options.Pooler.RemoveServer(serverID)
-		r := T.recipes[server.recipe]
-		if r != nil {
-			r.count.Add(-1)
-		}
-	}
+func (T *Pool) transactionComplete(serverID uuid.UUID) {
+
 }
 
 func (T *Pool) removeServer(serverID uuid.UUID) {
-	T.mu.Lock()
-	defer T.mu.Unlock()
-
-	T._removeServer(serverID)
+	server, _ := T.servers.LoadAndDelete(serverID)
+	if server == nil {
+		return
+	}
+	_ = server.GetConn().Close()
+	T.options.Pooler.RemoveServer(serverID)
+	r, _ := T.recipes.Load(server.GetRecipe())
+	if r != nil {
+		r.count.Add(-1)
+	}
 }
 
 func (T *Pool) Cancel(key [8]byte) error {
-	dialer, backendKey := func() (Dialer, [8]byte) {
-		T.mu.Lock()
-		defer T.mu.Unlock()
-
-		var clientID uuid.UUID
-		for id, client := range T.clients {
-			if client.key == key {
-				clientID = id
-				break
-			}
-		}
-
-		if clientID == uuid.Nil {
-			return nil, [8]byte{}
-		}
-
-		// get peer
-		var recipe string
-		var serverKey [8]byte
-		var ok bool
-		for _, server := range T.servers {
-			func() {
-				server.mu.Lock()
-				defer server.mu.Unlock()
-
-				if server.metrics.Peer == clientID {
-					recipe = server.recipe
-					serverKey = server.accept.BackendKey
-					ok = true
-					return
-				}
-			}()
-			if ok {
-				break
-			}
+	var clientID uuid.UUID
+	T.clients.Range(func(id uuid.UUID, client *Client) bool {
+		if client.GetBackendKey() == key {
+			clientID = id
+			return false
 		}
+		return true
+	})
 
-		if !ok {
-			return nil, [8]byte{}
-		}
+	if clientID == uuid.Nil {
+		return nil
+	}
 
-		r, ok := T.recipes[recipe]
-		if !ok {
-			return nil, [8]byte{}
+	// get peer
+	var recipe string
+	var serverKey [8]byte
+	if T.servers.Range(func(_ uuid.UUID, server *Server) bool {
+		if server.GetPeer() == clientID {
+			recipe = server.GetRecipe()
+			serverKey = server.GetBackendKey()
+			return false
 		}
+		return true
+	}) {
+		return nil
+	}
 
-		return r.recipe.Dialer, serverKey
-	}()
-
-	if dialer == nil {
+	r, _ := T.recipes.Load(recipe)
+	if r == nil {
 		return nil
 	}
 
-	return dialer.Cancel(backendKey)
+	return r.recipe.Dialer.Cancel(serverKey)
+}
+
+func (T *Pool) ReadMetrics(metrics *Metrics) {
+	maps.Clear(metrics.Servers)
+	maps.Clear(metrics.Clients)
+
+	T.servers.Range(func(serverID uuid.UUID, server *Server) bool {
+		var m ServerMetrics
+		server.ReadMetrics(&m)
+		metrics.Servers[serverID] = m
+		return true
+	})
+
+	T.clients.Range(func(clientID uuid.UUID, client *Client) bool {
+		var m ClientMetrics
+		client.ReadMetrics(&m)
+		metrics.Clients[clientID] = m
+		return true
+	})
 }
diff --git a/lib/gat/pool/server.go b/lib/gat/pool/server.go
new file mode 100644
index 00000000..64fcce88
--- /dev/null
+++ b/lib/gat/pool/server.go
@@ -0,0 +1,101 @@
+package pool
+
+import (
+	"github.com/google/uuid"
+	"pggat2/lib/fed"
+	"pggat2/lib/middleware/middlewares/eqp"
+	"pggat2/lib/middleware/middlewares/ps"
+	"pggat2/lib/util/strutil"
+	"sync"
+	"time"
+)
+
+type Server struct {
+	conn              fed.Conn
+	backendKey        [8]byte
+	initialParameters map[strutil.CIString]string
+	recipe            string
+
+	psServer  *ps.Server
+	eqpServer *eqp.Server
+
+	metrics ServerMetrics
+	mu      sync.RWMutex
+}
+
+func NewServer(
+	conn fed.Conn,
+	backendKey [8]byte,
+	initialParameters map[strutil.CIString]string,
+	recipe string,
+
+	psServer *ps.Server,
+	eqpServer *eqp.Server,
+) *Server {
+	return &Server{
+		conn:              conn,
+		backendKey:        backendKey,
+		initialParameters: initialParameters,
+		recipe:            recipe,
+
+		psServer:  psServer,
+		eqpServer: eqpServer,
+
+		metrics: MakeServerMetrics(),
+	}
+}
+
+func (T *Server) GetConn() fed.Conn {
+	return T.conn
+}
+
+func (T *Server) GetBackendKey() [8]byte {
+	return T.backendKey
+}
+
+func (T *Server) GetInitialParameters() map[strutil.CIString]string {
+	return T.initialParameters
+}
+
+func (T *Server) GetRecipe() string {
+	return T.recipe
+}
+
+func (T *Server) GetPSServer() *ps.Server {
+	return T.psServer
+}
+
+func (T *Server) GetEQPServer() *eqp.Server {
+	return T.eqpServer
+}
+
+// SetPeer replaces the peer. Returns the old peer
+func (T *Server) SetPeer(peer uuid.UUID) uuid.UUID {
+	T.mu.Lock()
+	defer T.mu.Unlock()
+
+	old := T.metrics.Peer
+	T.metrics.SetPeer(peer)
+	return old
+}
+
+func (T *Server) GetPeer() uuid.UUID {
+	T.mu.RLock()
+	defer T.mu.RUnlock()
+
+	return T.metrics.Peer
+}
+
+func (T *Server) GetConnection() (uuid.UUID, time.Time) {
+	T.mu.RLock()
+	defer T.mu.RUnlock()
+
+	return T.metrics.Peer, T.metrics.Since
+}
+
+func (T *Server) ReadMetrics(metrics *ServerMetrics) {
+	T.mu.RLock()
+	defer T.mu.RUnlock()
+
+	panic("TODO(garet)")
+}
diff --git a/lib/util/maps/clear.go b/lib/util/maps/clear.go
new file mode 100644
index 00000000..11d592ea
--- /dev/null
+++ b/lib/util/maps/clear.go
@@ -0,0 +1,7 @@
+package maps
+
+func Clear[K comparable, V any](m map[K]V) {
+	for k := range m {
+		delete(m, k)
+	}
+}
-- 
GitLab