From 173d12eedd6a33f4683a6dd0d9c45b210b27d658 Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Fri, 15 Sep 2023 18:39:35 -0500
Subject: [PATCH] why

---
 lib/gat/pool/pool.go | 52 ++++++++++++++++++++++++++------------------
 1 file changed, 31 insertions(+), 21 deletions(-)

diff --git a/lib/gat/pool/pool.go b/lib/gat/pool/pool.go
index 847567c0..3f8e4b4d 100644
--- a/lib/gat/pool/pool.go
+++ b/lib/gat/pool/pool.go
@@ -267,16 +267,22 @@ func (T *Pool) removeServerL1(server *Server) {
 func (T *Pool) acquireServer(client *Client) *Server {
 	client.SetState(metrics.ConnStateAwaitingServer, uuid.Nil)
 
-	serverID := T.options.Pooler.Acquire(client.GetID(), SyncModeNonBlocking)
-	if serverID == uuid.Nil {
-		// TODO(garet) can this be run on same thread and only create a goroutine if scaling is possible?
-		go T.scaleUp()
-		serverID = T.options.Pooler.Acquire(client.GetID(), SyncModeBlocking)
-	}
+	for {
+		serverID := T.options.Pooler.Acquire(client.GetID(), SyncModeNonBlocking)
+		if serverID == uuid.Nil {
+			// TODO(garet) can this be run on same thread and only create a goroutine if scaling is possible?
+			go T.scaleUp()
+			serverID = T.options.Pooler.Acquire(client.GetID(), SyncModeBlocking)
+		}
 
-	T.mu.RLock()
-	defer T.mu.RUnlock()
-	return T.servers[serverID]
+		T.mu.RLock()
+		server, ok := T.servers[serverID]
+		T.mu.RUnlock()
+		if !ok {
+			continue
+		}
+		return server
+	}
 }
 
 func (T *Pool) releaseServer(server *Server) {
@@ -337,24 +343,34 @@ func (T *Pool) serve(client *Client, initialize bool) error {
 	T.addClient(client)
 	defer T.removeClient(client)
 
+	var err error
+	var serverErr error
+
 	var server *Server
+	defer func() {
+		if server != nil {
+			if serverErr != nil {
+				T.removeServer(server)
+			} else {
+				T.releaseServer(server)
+			}
+		}
+	}()
+
 	if !initialize {
 		server = T.acquireServer(client)
 
-		err, serverErr := Pair(T.options, client, server)
+		err, serverErr = Pair(T.options, client, server)
 		if serverErr != nil {
-			T.removeServer(server)
 			return serverErr
 		}
 		if err != nil {
-			T.releaseServer(server)
 			return err
 		}
 
 		p := packets.ReadyForQuery('I')
 		err = client.GetConn().WritePacket(p.IntoPacket())
 		if err != nil {
-			T.releaseServer(server)
 			return err
 		}
 	}
@@ -366,15 +382,12 @@ func (T *Pool) serve(client *Client, initialize bool) error {
 			server = nil
 		}
 
-		packet, err := client.GetConn().ReadPacket(true)
+		var packet fed.Packet
+		packet, err = client.GetConn().ReadPacket(true)
 		if err != nil {
-			if server != nil {
-				T.releaseServer(server)
-			}
 			return err
 		}
 
-		var serverErr error
 		if server == nil {
 			server = T.acquireServer(client)
 
@@ -384,15 +397,12 @@ func (T *Pool) serve(client *Client, initialize bool) error {
 			err, serverErr = bouncers.Bounce(client.GetReadWriter(), server.GetReadWriter(), packet)
 		}
 		if serverErr != nil {
-			T.removeServer(server)
 			return serverErr
 		} else {
 			TransactionComplete(client, server)
-
 		}
 
 		if err != nil {
-			T.releaseServer(server)
 			return err
 		}
 	}
-- 
GitLab