From d8983723bae46af8259e6bdbc99c83f89a1b6dcd Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Mon, 2 Oct 2023 18:05:00 -0500
Subject: [PATCH] close pooler to prevent clients from hanging forever if pool
 closes while awaiting server

---
 lib/gat/pool/errors.go                |  5 +++
 lib/gat/pool/pool.go                  | 10 ++++++
 lib/gat/pool/pooler.go                |  2 ++
 lib/gat/poolers/session/pooler.go     | 23 +++++++++++++
 lib/gat/poolers/transaction/pooler.go |  4 +++
 lib/rob/scheduler.go                  |  2 ++
 lib/rob/schedulers/v2/scheduler.go    | 48 ++++++++++++++++++++++++---
 7 files changed, 90 insertions(+), 4 deletions(-)
 create mode 100644 lib/gat/pool/errors.go

diff --git a/lib/gat/pool/errors.go b/lib/gat/pool/errors.go
new file mode 100644
index 00000000..b0a18cec
--- /dev/null
+++ b/lib/gat/pool/errors.go
@@ -0,0 +1,5 @@
+package pool
+
+import "errors"
+
+var ErrClosed = errors.New("pool closed")
diff --git a/lib/gat/pool/pool.go b/lib/gat/pool/pool.go
index 01221ae1..8399e39b 100644
--- a/lib/gat/pool/pool.go
+++ b/lib/gat/pool/pool.go
@@ -245,6 +245,9 @@ func (T *Pool) acquireServer(client *pooledClient) *pooledServer {
 			}
 			serverID = T.pooler.Acquire(client.GetID(), SyncModeBlocking)
 			T.pendingCount.Add(-1)
+			if serverID == uuid.Nil {
+				return nil
+			}
 		}
 
 		T.mu.RLock()
@@ -331,6 +334,9 @@ func (T *Pool) serve(client *pooledClient, initialized bool) error {
 
 	if !initialized {
 		server = T.acquireServer(client)
+		if server == nil {
+			return ErrClosed
+		}
 
 		err, serverErr = pair(T.config, client, server)
 		if serverErr != nil {
@@ -362,6 +368,9 @@ func (T *Pool) serve(client *pooledClient, initialized bool) error {
 
 		if server == nil {
 			server = T.acquireServer(client)
+			if server == nil {
+				return ErrClosed
+			}
 
 			err, serverErr = pair(T.config, client, server)
 		}
@@ -471,6 +480,7 @@ func (T *Pool) ReadMetrics(m *metrics.Pool) {
 
 func (T *Pool) Close() {
 	close(T.closed)
+	T.pooler.Close()
 
 	T.mu.Lock()
 	defer T.mu.Unlock()
diff --git a/lib/gat/pool/pooler.go b/lib/gat/pool/pooler.go
index 2133ce11..06bba5a4 100644
--- a/lib/gat/pool/pooler.go
+++ b/lib/gat/pool/pooler.go
@@ -18,4 +18,6 @@ type Pooler interface {
 
 	Acquire(client uuid.UUID, sync SyncMode) (server uuid.UUID)
 	Release(server uuid.UUID)
+
+	Close()
 }
diff --git a/lib/gat/poolers/session/pooler.go b/lib/gat/poolers/session/pooler.go
index 9827727e..8f0dd004 100644
--- a/lib/gat/poolers/session/pooler.go
+++ b/lib/gat/poolers/session/pooler.go
@@ -13,6 +13,7 @@ type Pooler struct {
 	queue   []uuid.UUID
 	servers map[uuid.UUID]struct{}
 	ready   *sync.Cond
+	closed  bool
 	mu      sync.Mutex
 }
 
@@ -56,6 +57,10 @@ func (T *Pooler) TryAcquire() uuid.UUID {
 	T.mu.Lock()
 	defer T.mu.Unlock()
 
+	if T.closed {
+		return uuid.Nil
+	}
+
 	if len(T.queue) == 0 {
 		return uuid.Nil
 	}
@@ -69,6 +74,10 @@ func (T *Pooler) AcquireBlocking() uuid.UUID {
 	T.mu.Lock()
 	defer T.mu.Unlock()
 
+	if T.closed {
+		return uuid.Nil
+	}
+
 	for len(T.queue) == 0 {
 		if T.ready == nil {
 			T.ready = sync.NewCond(&T.mu)
@@ -76,6 +85,10 @@ func (T *Pooler) AcquireBlocking() uuid.UUID {
 		T.ready.Wait()
 	}
 
+	if T.closed {
+		return uuid.Nil
+	}
+
 	server := T.queue[len(T.queue)-1]
 	T.queue = T.queue[:len(T.queue)-1]
 	return server
@@ -104,4 +117,14 @@ func (T *Pooler) Release(server uuid.UUID) {
 	T.queue = append(T.queue, server)
 }
 
+func (T *Pooler) Close() {
+	T.mu.Lock()
+	defer T.mu.Unlock()
+
+	T.closed = true
+	if T.ready != nil {
+		T.ready.Broadcast()
+	}
+}
+
 var _ pool.Pooler = (*Pooler)(nil)
diff --git a/lib/gat/poolers/transaction/pooler.go b/lib/gat/poolers/transaction/pooler.go
index 939718a2..3eb1b9e1 100644
--- a/lib/gat/poolers/transaction/pooler.go
+++ b/lib/gat/poolers/transaction/pooler.go
@@ -47,4 +47,8 @@ func (T *Pooler) Release(server uuid.UUID) {
 	T.s.Release(server)
 }
 
+func (T *Pooler) Close() {
+	T.s.Close()
+}
+
 var _ pool.Pooler = (*Pooler)(nil)
diff --git a/lib/rob/scheduler.go b/lib/rob/scheduler.go
index 1cef414e..97fcadd7 100644
--- a/lib/rob/scheduler.go
+++ b/lib/rob/scheduler.go
@@ -28,4 +28,6 @@ type Scheduler interface {
 	// Release will release a worker.
 	// This should be called after acquire unless the worker is removed with RemoveWorker
 	Release(worker uuid.UUID)
+
+	Close()
 }
diff --git a/lib/rob/schedulers/v2/scheduler.go b/lib/rob/schedulers/v2/scheduler.go
index 2a5d72e6..01443c79 100644
--- a/lib/rob/schedulers/v2/scheduler.go
+++ b/lib/rob/schedulers/v2/scheduler.go
@@ -1,13 +1,15 @@
 package schedulers
 
 import (
+	"sync"
+
+	"github.com/google/uuid"
+
 	"gfx.cafe/gfx/pggat/lib/rob"
 	"gfx.cafe/gfx/pggat/lib/rob/schedulers/v2/job"
 	"gfx.cafe/gfx/pggat/lib/rob/schedulers/v2/sink"
 	"gfx.cafe/gfx/pggat/lib/util/maps"
 	"gfx.cafe/gfx/pggat/lib/util/pools"
-	"github.com/google/uuid"
-	"sync"
 )
 
 type Scheduler struct {
@@ -20,6 +22,7 @@ type Scheduler struct {
 	backlog []job.Stalled
 	bmu     sync.Mutex
 	sinks   map[uuid.UUID]*sink.Sink
+	closed  bool
 	mu      sync.RWMutex
 }
 
@@ -102,6 +105,10 @@ func (T *Scheduler) TryAcquire(j job.Concurrent) uuid.UUID {
 	T.mu.RLock()
 	defer T.mu.RUnlock()
 
+	if T.closed {
+		return uuid.Nil
+	}
+
 	return T.tryAcquire(j)
 }
 
@@ -130,6 +137,13 @@ func (T *Scheduler) Enqueue(j ...job.Stalled) {
 	T.mu.RLock()
 	defer T.mu.RUnlock()
 
+	if T.closed {
+		for _, jj := range j {
+			close(jj.Ready)
+			return
+		}
+	}
+
 	for _, jj := range j {
 		T.enqueue(jj)
 	}
@@ -146,7 +160,6 @@ func (T *Scheduler) Acquire(user uuid.UUID, mode rob.SyncMode) uuid.UUID {
 		if !ok {
 			ready = make(chan uuid.UUID, 1)
 		}
-		defer T.ready.Put(ready)
 
 		j := job.Stalled{
 			Concurrent: job.Concurrent{
@@ -156,7 +169,11 @@ func (T *Scheduler) Acquire(user uuid.UUID, mode rob.SyncMode) uuid.UUID {
 		}
 		T.Enqueue(j)
 
-		return <-ready
+		s, ok := <-ready
+		if ok {
+			T.ready.Put(ready)
+		}
+		return s
 	case rob.SyncModeTryNonBlocking:
 		if id := T.Acquire(user, rob.SyncModeNonBlocking); id != uuid.Nil {
 			return id
@@ -201,4 +218,27 @@ func (T *Scheduler) stealFor(worker uuid.UUID) {
 	}
 }
 
+func (T *Scheduler) Close() {
+	T.mu.Lock()
+	defer T.mu.Unlock()
+
+	T.closed = true
+
+	for worker, s := range T.sinks {
+		delete(T.sinks, worker)
+
+		// now we need to reschedule all the work that was scheduled to s (stalled only).
+		jobs := s.StealAll()
+
+		for _, j := range jobs {
+			close(j.Ready)
+		}
+	}
+
+	for _, j := range T.backlog {
+		close(j.Ready)
+	}
+	T.backlog = T.backlog[:0]
+}
+
 var _ rob.Scheduler = (*Scheduler)(nil)
-- 
GitLab