From 880f93ea6c1dda9a439295d8d12459483947f8f0 Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Mon, 31 Jul 2023 11:59:51 -0600
Subject: [PATCH] fix races

---
 lib/gat/pool.go | 57 +++++++++++++++++++++++++++++++++++--------------
 1 file changed, 41 insertions(+), 16 deletions(-)

diff --git a/lib/gat/pool.go b/lib/gat/pool.go
index 71ec57df..0c89faf7 100644
--- a/lib/gat/pool.go
+++ b/lib/gat/pool.go
@@ -26,12 +26,13 @@ type RawPool interface {
 type recipeWithConns struct {
 	recipe Recipe
 
-	conns []uuid.UUID
-	mu    sync.Mutex
+	removed bool
+	conns   []uuid.UUID
+	mu      sync.Mutex
 }
 
 func (T *recipeWithConns) scaleUp(pool *Pool, currentScale int) bool {
-	if currentScale >= T.recipe.GetMaxConnections() {
+	if currentScale >= T.recipe.GetMaxConnections() || T.removed {
 		return false
 	}
 
@@ -57,7 +58,7 @@ func (T *recipeWithConns) scaleUp(pool *Pool, currentScale int) bool {
 }
 
 func (T *recipeWithConns) scaleDown(pool *Pool, currentScale int) bool {
-	if currentScale <= T.recipe.GetMinConnections() {
+	if currentScale <= T.recipe.GetMinConnections() || T.removed {
 		return false
 	}
 
@@ -76,6 +77,10 @@ func (T *recipeWithConns) scaleDown(pool *Pool, currentScale int) bool {
 }
 
 func (T *recipeWithConns) scale(pool *Pool, currentScale int, amount int) int {
+	if T.removed {
+		return amount
+	}
+
 	if amount > 0 {
 		for amount > 0 {
 			if T.scaleUp(pool, currentScale) {
@@ -99,6 +104,10 @@ func (T *recipeWithConns) scale(pool *Pool, currentScale int, amount int) int {
 }
 
 func (T *recipeWithConns) currentScale(pool *Pool) int {
+	if T.removed {
+		return 0
+	}
+
 	i := 0
 	for j := 0; j < len(T.conns); j++ {
 		if pool.raw.GetServer(T.conns[j]) != nil {
@@ -122,14 +131,17 @@ func (T *recipeWithConns) Scale(pool *Pool, amount int) int {
 	T.mu.Lock()
 	defer T.mu.Unlock()
 
+	if T.removed {
+		return amount
+	}
 	currentScale := T.currentScale(pool)
 	return T.scale(pool, currentScale, amount)
 }
 
-func (T *recipeWithConns) SetScale(pool *Pool, scale int) {
-	T.mu.Lock()
-	defer T.mu.Unlock()
-
+func (T *recipeWithConns) setScale(pool *Pool, scale int) {
+	if T.removed {
+		return
+	}
 	target := maths.Clamp(scale, T.recipe.GetMinConnections(), T.recipe.GetMaxConnections())
 	currentScale := T.currentScale(pool)
 	target -= currentScale
@@ -137,14 +149,27 @@ func (T *recipeWithConns) SetScale(pool *Pool, scale int) {
 	T.scale(pool, currentScale, target)
 }
 
+func (T *recipeWithConns) SetScale(pool *Pool, scale int) {
+	T.mu.Lock()
+	defer T.mu.Unlock()
+
+	T.setScale(pool, scale)
+}
+
 func (T *recipeWithConns) Added(pool *Pool) {
-	T.SetScale(pool, 0)
+	T.mu.Lock()
+	defer T.mu.Unlock()
+
+	T.removed = false
+	T.setScale(pool, 0)
 }
 
 func (T *recipeWithConns) Removed(pool *Pool) {
 	T.mu.Lock()
 	defer T.mu.Unlock()
 
+	T.removed = true
+
 	for _, conn := range T.conns {
 		pool.raw.RemoveServer(conn)
 	}
@@ -186,23 +211,23 @@ func (T *Pool) Serve(client zap.ReadWriter) {
 
 func (T *Pool) CurrentScale() int {
 	T.mu.Lock()
-	recipes := make([]string, 0, len(T.recipes))
-	for recipe := range T.recipes {
+	recipes := make([]*recipeWithConns, 0, len(T.recipes))
+	for _, recipe := range T.recipes {
 		recipes = append(recipes, recipe)
 	}
 	T.mu.Unlock()
 
 	scale := 0
 	for _, recipe := range recipes {
-		scale += T.recipes[recipe].CurrentScale(T)
+		scale += recipe.CurrentScale(T)
 	}
 	return scale
 }
 
 func (T *Pool) Scale(amount int) {
 	T.mu.Lock()
-	recipes := make([]string, 0, len(T.recipes))
-	for recipe := range T.recipes {
+	recipes := make([]*recipeWithConns, 0, len(T.recipes))
+	for _, recipe := range T.recipes {
 		recipes = append(recipes, recipe)
 	}
 	T.mu.Unlock()
@@ -213,13 +238,13 @@ outer:
 		for i := 0; i < len(recipes); i++ {
 			recipe := recipes[i]
 			if amount > 0 {
-				if T.recipes[recipe].Scale(T, 1) == 0 {
+				if recipe.Scale(T, 1) == 0 {
 					amount--
 					recipes[j] = recipes[i]
 					j++
 				}
 			} else if amount < 0 {
-				if T.recipes[recipe].Scale(T, -1) == 0 {
+				if recipe.Scale(T, -1) == 0 {
 					amount++
 					recipes[j] = recipes[i]
 					j++
-- 
GitLab