From 880f93ea6c1dda9a439295d8d12459483947f8f0 Mon Sep 17 00:00:00 2001 From: Garet Halliday <me@garet.holiday> Date: Mon, 31 Jul 2023 11:59:51 -0600 Subject: [PATCH] fix races --- lib/gat/pool.go | 57 +++++++++++++++++++++++++++++++++++-------------- 1 file changed, 41 insertions(+), 16 deletions(-) diff --git a/lib/gat/pool.go b/lib/gat/pool.go index 71ec57df..0c89faf7 100644 --- a/lib/gat/pool.go +++ b/lib/gat/pool.go @@ -26,12 +26,13 @@ type RawPool interface { type recipeWithConns struct { recipe Recipe - conns []uuid.UUID - mu sync.Mutex + removed bool + conns []uuid.UUID + mu sync.Mutex } func (T *recipeWithConns) scaleUp(pool *Pool, currentScale int) bool { - if currentScale >= T.recipe.GetMaxConnections() { + if currentScale >= T.recipe.GetMaxConnections() || T.removed { return false } @@ -57,7 +58,7 @@ func (T *recipeWithConns) scaleUp(pool *Pool, currentScale int) bool { } func (T *recipeWithConns) scaleDown(pool *Pool, currentScale int) bool { - if currentScale <= T.recipe.GetMinConnections() { + if currentScale <= T.recipe.GetMinConnections() || T.removed { return false } @@ -76,6 +77,10 @@ func (T *recipeWithConns) scaleDown(pool *Pool, currentScale int) bool { } func (T *recipeWithConns) scale(pool *Pool, currentScale int, amount int) int { + if T.removed { + return amount + } + if amount > 0 { for amount > 0 { if T.scaleUp(pool, currentScale) { @@ -99,6 +104,10 @@ func (T *recipeWithConns) scale(pool *Pool, currentScale int, amount int) int { } func (T *recipeWithConns) currentScale(pool *Pool) int { + if T.removed { + return 0 + } + i := 0 for j := 0; j < len(T.conns); j++ { if pool.raw.GetServer(T.conns[j]) != nil { @@ -122,14 +131,17 @@ func (T *recipeWithConns) Scale(pool *Pool, amount int) int { T.mu.Lock() defer T.mu.Unlock() + if T.removed { + return amount + } currentScale := T.currentScale(pool) return T.scale(pool, currentScale, amount) } -func (T *recipeWithConns) SetScale(pool *Pool, scale int) { - T.mu.Lock() - defer T.mu.Unlock() - +func (T *recipeWithConns) setScale(pool *Pool, scale int) { + if T.removed { + return + } target := maths.Clamp(scale, T.recipe.GetMinConnections(), T.recipe.GetMaxConnections()) currentScale := T.currentScale(pool) target -= currentScale @@ -137,14 +149,27 @@ func (T *recipeWithConns) SetScale(pool *Pool, scale int) { T.scale(pool, currentScale, target) } +func (T *recipeWithConns) SetScale(pool *Pool, scale int) { + T.mu.Lock() + defer T.mu.Unlock() + + T.setScale(pool, scale) +} + func (T *recipeWithConns) Added(pool *Pool) { - T.SetScale(pool, 0) + T.mu.Lock() + defer T.mu.Unlock() + + T.removed = false + T.setScale(pool, 0) } func (T *recipeWithConns) Removed(pool *Pool) { T.mu.Lock() defer T.mu.Unlock() + T.removed = true + for _, conn := range T.conns { pool.raw.RemoveServer(conn) } @@ -186,23 +211,23 @@ func (T *Pool) Serve(client zap.ReadWriter) { func (T *Pool) CurrentScale() int { T.mu.Lock() - recipes := make([]string, 0, len(T.recipes)) - for recipe := range T.recipes { + recipes := make([]*recipeWithConns, 0, len(T.recipes)) + for _, recipe := range T.recipes { recipes = append(recipes, recipe) } T.mu.Unlock() scale := 0 for _, recipe := range recipes { - scale += T.recipes[recipe].CurrentScale(T) + scale += recipe.CurrentScale(T) } return scale } func (T *Pool) Scale(amount int) { T.mu.Lock() - recipes := make([]string, 0, len(T.recipes)) - for recipe := range T.recipes { + recipes := make([]*recipeWithConns, 0, len(T.recipes)) + for _, recipe := range T.recipes { recipes = append(recipes, recipe) } T.mu.Unlock() @@ -213,13 +238,13 @@ outer: for i := 0; i < len(recipes); i++ { recipe := recipes[i] if amount > 0 { - if T.recipes[recipe].Scale(T, 1) == 0 { + if recipe.Scale(T, 1) == 0 { amount-- recipes[j] = recipes[i] j++ } } else if amount < 0 { - if T.recipes[recipe].Scale(T, -1) == 0 { + if recipe.Scale(T, -1) == 0 { amount++ recipes[j] = recipes[i] j++ -- GitLab