From 75e47c596512d8081f94731f9da804258d42c082 Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Tue, 29 Aug 2023 19:58:07 -0500
Subject: [PATCH] idle timeout

---
 lib/gat/pool/pool.go | 97 ++++++++++++++++++++++++++++++++++++--------
 pgbouncer.ini        |  1 +
 2 files changed, 82 insertions(+), 16 deletions(-)

diff --git a/lib/gat/pool/pool.go b/lib/gat/pool/pool.go
index a7a2313b..6a673039 100644
--- a/lib/gat/pool/pool.go
+++ b/lib/gat/pool/pool.go
@@ -2,6 +2,8 @@ package pool
 
 import (
 	"sync"
+	"sync/atomic"
+	"time"
 
 	"github.com/google/uuid"
 	"tuxpa.in/a/zlog/log"
@@ -29,18 +31,24 @@ type poolServer struct {
 	// middlewares
 	psServer  *ps.Server
 	eqpServer *eqp.Server
+
+	// peer is uuid.Nil if idle, and the client id otherwise
+	peer uuid.UUID
+	// since is when the current state started
+	since time.Time
+	mu    sync.Mutex
 }
 
 type poolRecipe struct {
 	recipe Recipe
-	count  int
+	count  atomic.Int64
 }
 
 type Pool struct {
 	options Options
 
 	recipes map[string]*poolRecipe
-	servers map[uuid.UUID]poolServer
+	servers map[uuid.UUID]*poolServer
 	clients map[uuid.UUID]zap.Conn
 	mu      sync.Mutex
 }
@@ -51,14 +59,53 @@ func NewPool(options Options) *Pool {
 	}
 
 	if options.ServerIdleTimeout != 0 {
-		go func() {
-			// TODO(garet) check pool for idle servers
-		}()
+		go p.idleTimeoutLoop()
 	}
 
 	return p
 }
 
+func (T *Pool) idlest() (idlest uuid.UUID, idle time.Time) {
+	T.mu.Lock()
+	defer T.mu.Unlock()
+
+	for serverID, server := range T.servers {
+		if server.peer != uuid.Nil {
+			continue
+		}
+
+		if idle != (time.Time{}) && server.since.After(idle) {
+			continue
+		}
+
+		idlest = serverID
+		idle = server.since
+	}
+
+	return
+}
+
+func (T *Pool) idleTimeoutLoop() {
+	for {
+		var wait time.Duration
+
+		now := time.Now()
+		var idlest uuid.UUID
+		var idle time.Time
+		for idlest, idle = T.idlest(); idlest != uuid.Nil && now.Sub(idle) > T.options.ServerIdleTimeout; idlest, idle = T.idlest() {
+			T.removeServer(idlest)
+		}
+
+		if idlest == uuid.Nil {
+			wait = T.options.ServerIdleTimeout
+		} else {
+			wait = idle.Add(T.options.ServerIdleTimeout).Sub(now)
+		}
+
+		time.Sleep(wait)
+	}
+}
+
 func (T *Pool) GetCredentials() auth.Credentials {
 	return T.options.Credentials
 }
@@ -74,7 +121,7 @@ func (T *Pool) _scaleUpRecipe(name string) {
 
 	serverID := uuid.New()
 	if T.servers == nil {
-		T.servers = make(map[uuid.UUID]poolServer)
+		T.servers = make(map[uuid.UUID]*poolServer)
 	}
 
 	var middlewares []middleware.Middleware
@@ -93,13 +140,15 @@ func (T *Pool) _scaleUpRecipe(name string) {
 		middlewares = append(middlewares, eqpServer)
 	}
 
-	T.servers[serverID] = poolServer{
+	T.servers[serverID] = &poolServer{
 		conn:   server,
 		accept: params,
 		recipe: name,
 
 		psServer:  psServer,
 		eqpServer: eqpServer,
+
+		since: time.Now(),
 	}
 	T.options.Pooler.AddServer(serverID)
 }
@@ -113,7 +162,6 @@ func (T *Pool) AddRecipe(name string, recipe Recipe) {
 	}
 	T.recipes[name] = &poolRecipe{
 		recipe: recipe,
-		count:  0,
 	}
 
 	for i := 0; i < recipe.MinConnections; i++ {
@@ -142,7 +190,7 @@ func (T *Pool) scaleUp() {
 	defer T.mu.Unlock()
 
 	for name, r := range T.recipes {
-		if r.recipe.MaxConnections == 0 || r.count < r.recipe.MaxConnections {
+		if r.recipe.MaxConnections == 0 || int(r.count.Load()) < r.recipe.MaxConnections {
 			T._scaleUpRecipe(name)
 			return
 		}
@@ -242,7 +290,7 @@ func (T *Pool) Serve(
 	clientID := T.addClient(client)
 
 	var serverID uuid.UUID
-	var server poolServer
+	var server *poolServer
 
 	defer func() {
 		if serverID != uuid.Nil {
@@ -277,13 +325,13 @@ func (T *Pool) Serve(
 		if serverErr != nil {
 			T.removeServer(serverID)
 			serverID = uuid.Nil
-			server = poolServer{}
+			server = nil
 			return serverErr
 		} else {
 			if T.options.Pooler.ReleaseAfterTransaction() {
 				T.releaseServer(serverID)
 				serverID = uuid.Nil
-				server = poolServer{}
+				server = nil
 			}
 		}
 
@@ -307,7 +355,7 @@ func (T *Pool) addClient(client zap.Conn) uuid.UUID {
 	return clientID
 }
 
-func (T *Pool) acquireServer(clientID uuid.UUID) (serverID uuid.UUID, server poolServer) {
+func (T *Pool) acquireServer(clientID uuid.UUID) (serverID uuid.UUID, server *poolServer) {
 	serverID = T.options.Pooler.AcquireConcurrent(clientID)
 	if serverID == uuid.Nil {
 		go T.scaleUp()
@@ -317,6 +365,12 @@ func (T *Pool) acquireServer(clientID uuid.UUID) (serverID uuid.UUID, server poo
 	T.mu.Lock()
 	defer T.mu.Unlock()
 	server = T.servers[serverID]
+	if server != nil {
+		server.mu.Lock()
+		defer server.mu.Unlock()
+		server.peer = clientID
+		server.since = time.Now()
+	}
 	return
 }
 
@@ -324,9 +378,20 @@ func (T *Pool) releaseServer(serverID uuid.UUID) {
 	T.mu.Lock()
 	defer T.mu.Unlock()
 
+	server := T.servers[serverID]
+	if server == nil {
+		return
+	}
+
+	func() {
+		server.mu.Lock()
+		defer server.mu.Unlock()
+		server.peer = uuid.Nil
+		server.since = time.Now()
+	}()
+
 	if T.options.ServerResetQuery != "" {
-		server := T.servers[serverID].conn
-		err := backends.QueryString(new(backends.Context), server, T.options.ServerResetQuery)
+		err := backends.QueryString(new(backends.Context), server.conn, T.options.ServerResetQuery)
 		if err != nil {
 			T._removeServer(serverID)
 			return
@@ -342,7 +407,7 @@ func (T *Pool) _removeServer(serverID uuid.UUID) {
 		T.options.Pooler.RemoveServer(serverID)
 		r := T.recipes[server.recipe]
 		if r != nil {
-			r.count--
+			r.count.Add(-1)
 		}
 	}
 }
diff --git a/pgbouncer.ini b/pgbouncer.ini
index 6c175d6f..f9ea6434 100644
--- a/pgbouncer.ini
+++ b/pgbouncer.ini
@@ -3,6 +3,7 @@ pool_mode = transaction
 auth_file = userlist.txt
 listen_addr = *
 track_extra_parameters = IntervalStyle, session_authorization, default_transaction_read_only, search_path
+server_idle_timeout = 10
 
 [users]
 postgres =
-- 
GitLab