From df0200abc38721e1f90f4ff9c46370a36be51c85 Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Sun, 17 Sep 2023 22:16:41 -0500
Subject: [PATCH] change api so prev was valid

---
 .../modes/digitalocean_discovery/config.go    | 17 ++++-------
 lib/gat/pool/options.go                       |  2 +-
 lib/gat/pool/pool.go                          | 29 +++++++++++++------
 lib/gat/pool/pools/session/apply.go           |  2 +-
 lib/gat/pool/pools/session/pooler.go          |  4 +++
 lib/gat/pool/pools/transaction/apply.go       |  2 +-
 lib/gat/pool/pools/transaction/pooler.go      |  4 +++
 lib/rob/schedulers/v2/scheduler_test.go       | 13 ++++++---
 8 files changed, 45 insertions(+), 28 deletions(-)

diff --git a/lib/gat/modes/digitalocean_discovery/config.go b/lib/gat/modes/digitalocean_discovery/config.go
index 6fd93685..f09a0472 100644
--- a/lib/gat/modes/digitalocean_discovery/config.go
+++ b/lib/gat/modes/digitalocean_discovery/config.go
@@ -95,7 +95,7 @@ func (T *Config) ListenAndServe() error {
 			}
 
 			for _, dbname := range cluster.DBNames {
-				baseOptions := pool.Options{
+				poolOptions := pool.Options{
 					Credentials:                creds,
 					ServerReconnectInitialTime: 5 * time.Second,
 					ServerReconnectMaxTime:     5 * time.Second,
@@ -108,12 +108,11 @@ func (T *Config) ListenAndServe() error {
 						strutil.MakeCIString("application_name"),
 					},
 				}
-				var poolOptions pool.Options
 				if T.PoolMode == "session" {
-					baseOptions.ServerResetQuery = "DISCARD ALL"
-					poolOptions = session.Apply(baseOptions)
+					poolOptions.ServerResetQuery = "DISCARD ALL"
+					poolOptions = session.Apply(poolOptions)
 				} else {
-					poolOptions = transaction.Apply(baseOptions)
+					poolOptions = transaction.Apply(poolOptions)
 				}
 
 				p := pool.NewPool(poolOptions)
@@ -150,15 +149,9 @@ func (T *Config) ListenAndServe() error {
 					// change pool credentials
 					creds2 := creds
 					creds2.Username = user.Name + "_ro"
-					poolOptions2 := baseOptions
+					poolOptions2 := poolOptions
 					poolOptions2.Credentials = creds2
 
-					if T.PoolMode == "session" {
-						poolOptions2 = session.Apply(baseOptions)
-					} else {
-						poolOptions2 = transaction.Apply(baseOptions)
-					}
-
 					p2 := pool.NewPool(poolOptions2)
 
 					for _, replica := range replicas {
diff --git a/lib/gat/pool/options.go b/lib/gat/pool/options.go
index 675ce84c..f3716f9c 100644
--- a/lib/gat/pool/options.go
+++ b/lib/gat/pool/options.go
@@ -23,7 +23,7 @@ const (
 type Options struct {
 	Credentials auth.Credentials
 
-	Pooler Pooler
+	NewPooler func() Pooler
 	// ReleaseAfterTransaction toggles whether servers should be released and re acquired after each transaction.
 	// Use false for lower latency
 	// Use true for better balancing
diff --git a/lib/gat/pool/pool.go b/lib/gat/pool/pool.go
index cc6c5a21..de8fd0ba 100644
--- a/lib/gat/pool/pool.go
+++ b/lib/gat/pool/pool.go
@@ -22,6 +22,7 @@ import (
 
 type Pool struct {
 	options Options
+	pooler  Pooler
 
 	closed chan struct{}
 
@@ -37,10 +38,20 @@ type Pool struct {
 }
 
 func NewPool(options Options) *Pool {
+	if options.NewPooler == nil {
+		panic("expected new pooler func")
+	}
+	pooler := options.NewPooler()
+	if pooler == nil {
+		panic("expected pooler")
+	}
+
 	p := &Pool{
+		options: options,
+		pooler:  pooler,
+
 		closed:  make(chan struct{}),
 		pending: make(chan struct{}, 1),
-		options: options,
 	}
 
 	s := NewScaler(p)
@@ -177,7 +188,7 @@ func (T *Pool) scaleUpL1(name string, r *recipe.Recipe) error {
 		return err
 	}
 
-	T.options.Pooler.AddServer(server.GetID())
+	T.pooler.AddServer(server.GetID())
 	return nil
 }
 
@@ -205,7 +216,7 @@ func (T *Pool) removeServer(server *Server) {
 
 func (T *Pool) removeServerL1(server *Server) {
 	delete(T.servers, server.GetID())
-	T.options.Pooler.DeleteServer(server.GetID())
+	T.pooler.DeleteServer(server.GetID())
 	_ = server.GetConn().Close()
 	if T.serversByRecipe != nil {
 		T.serversByRecipe[server.GetRecipe()] = slices.Remove(T.serversByRecipe[server.GetRecipe()], server)
@@ -216,14 +227,14 @@ func (T *Pool) acquireServer(client *Client) *Server {
 	client.SetState(metrics.ConnStateAwaitingServer, uuid.Nil)
 
 	for {
-		serverID := T.options.Pooler.Acquire(client.GetID(), SyncModeNonBlocking)
+		serverID := T.pooler.Acquire(client.GetID(), SyncModeNonBlocking)
 		if serverID == uuid.Nil {
 			T.pendingCount.Add(1)
 			select {
 			case T.pending <- struct{}{}:
 			default:
 			}
-			serverID = T.options.Pooler.Acquire(client.GetID(), SyncModeBlocking)
+			serverID = T.pooler.Acquire(client.GetID(), SyncModeBlocking)
 			T.pendingCount.Add(-1)
 		}
 
@@ -231,7 +242,7 @@ func (T *Pool) acquireServer(client *Client) *Server {
 		server, ok := T.servers[serverID]
 		T.mu.RUnlock()
 		if !ok {
-			T.options.Pooler.DeleteServer(serverID)
+			T.pooler.DeleteServer(serverID)
 			continue
 		}
 		return server
@@ -251,7 +262,7 @@ func (T *Pool) releaseServer(server *Server) {
 
 	server.SetState(metrics.ConnStateIdle, uuid.Nil)
 
-	T.options.Pooler.Release(server.GetID())
+	T.pooler.Release(server.GetID())
 }
 
 func (T *Pool) Serve(
@@ -374,7 +385,7 @@ func (T *Pool) addClient(client *Client) {
 		T.clientsByKey = make(map[[8]byte]*Client)
 	}
 	T.clientsByKey[client.GetBackendKey()] = client
-	T.options.Pooler.AddClient(client.GetID())
+	T.pooler.AddClient(client.GetID())
 }
 
 func (T *Pool) removeClient(client *Client) {
@@ -385,7 +396,7 @@ func (T *Pool) removeClient(client *Client) {
 }
 
 func (T *Pool) removeClientL1(client *Client) {
-	T.options.Pooler.DeleteClient(client.GetID())
+	T.pooler.DeleteClient(client.GetID())
 	_ = client.conn.Close()
 	delete(T.clients, client.GetID())
 	delete(T.clientsByKey, client.GetBackendKey())
diff --git a/lib/gat/pool/pools/session/apply.go b/lib/gat/pool/pools/session/apply.go
index f4ddee0e..2ca42754 100644
--- a/lib/gat/pool/pools/session/apply.go
+++ b/lib/gat/pool/pools/session/apply.go
@@ -5,7 +5,7 @@ import (
 )
 
 func Apply(options pool.Options) pool.Options {
-	options.Pooler = new(Pooler)
+	options.NewPooler = NewPooler
 	options.ParameterStatusSync = pool.ParameterStatusSyncInitial
 	options.ExtendedQuerySync = false
 	return options
diff --git a/lib/gat/pool/pools/session/pooler.go b/lib/gat/pool/pools/session/pooler.go
index f322a828..6024249a 100644
--- a/lib/gat/pool/pools/session/pooler.go
+++ b/lib/gat/pool/pools/session/pooler.go
@@ -16,6 +16,10 @@ type Pooler struct {
 	mu      sync.Mutex
 }
 
+func NewPooler() pool.Pooler {
+	return new(Pooler)
+}
+
 func (*Pooler) AddClient(_ uuid.UUID) {}
 
 func (*Pooler) DeleteClient(_ uuid.UUID) {
diff --git a/lib/gat/pool/pools/transaction/apply.go b/lib/gat/pool/pools/transaction/apply.go
index 637585b0..bfe41820 100644
--- a/lib/gat/pool/pools/transaction/apply.go
+++ b/lib/gat/pool/pools/transaction/apply.go
@@ -3,7 +3,7 @@ package transaction
 import "pggat/lib/gat/pool"
 
 func Apply(options pool.Options) pool.Options {
-	options.Pooler = new(Pooler)
+	options.NewPooler = NewPooler
 	options.ParameterStatusSync = pool.ParameterStatusSyncDynamic
 	options.ExtendedQuerySync = true
 	options.ReleaseAfterTransaction = true
diff --git a/lib/gat/pool/pools/transaction/pooler.go b/lib/gat/pool/pools/transaction/pooler.go
index 5774a26d..95011ba2 100644
--- a/lib/gat/pool/pools/transaction/pooler.go
+++ b/lib/gat/pool/pools/transaction/pooler.go
@@ -12,6 +12,10 @@ type Pooler struct {
 	s schedulers.Scheduler
 }
 
+func NewPooler() pool.Pooler {
+	return new(Pooler)
+}
+
 func (T *Pooler) AddClient(client uuid.UUID) {
 	T.s.AddUser(client)
 }
diff --git a/lib/rob/schedulers/v2/scheduler_test.go b/lib/rob/schedulers/v2/scheduler_test.go
index d7546402..1efed1cd 100644
--- a/lib/rob/schedulers/v2/scheduler_test.go
+++ b/lib/rob/schedulers/v2/scheduler_test.go
@@ -34,11 +34,14 @@ func (T *ShareTable) Get(user int) int {
 }
 
 func testSink(sched *Scheduler) uuid.UUID {
-	return sched.NewWorker()
+	id := uuid.New()
+	sched.AddWorker(id)
+	return id
 }
 
 func testSource(sched *Scheduler, tab *ShareTable, id int, dur time.Duration) {
-	source := sched.NewUser()
+	source := uuid.New()
+	sched.AddUser(source)
 	for {
 		sink := sched.Acquire(source, rob.SyncModeTryNonBlocking)
 		start := time.Now()
@@ -50,7 +53,8 @@ func testSource(sched *Scheduler, tab *ShareTable, id int, dur time.Duration) {
 }
 
 func testMultiSource(sched *Scheduler, tab *ShareTable, id int, dur time.Duration, num int) {
-	source := sched.NewUser()
+	source := uuid.New()
+	sched.AddUser(source)
 	for i := 0; i < num; i++ {
 		go func() {
 			for {
@@ -68,7 +72,8 @@ func testMultiSource(sched *Scheduler, tab *ShareTable, id int, dur time.Duratio
 func testStarver(sched *Scheduler, tab *ShareTable, id int, dur time.Duration) {
 	for {
 		func() {
-			source := sched.NewUser()
+			source := uuid.New()
+			sched.AddUser(source)
 			defer sched.DeleteUser(source)
 
 			sink := sched.Acquire(source, rob.SyncModeTryNonBlocking)
-- 
GitLab