diff --git a/lib/gat/pool/pool.go b/lib/gat/pool/pool.go index 646a4511aa7e8c71804c09cf4892f2c2437eda8c..9d75d2dd4cd0cd304c2294a3e253ba3cb92f1a90 100644 --- a/lib/gat/pool/pool.go +++ b/lib/gat/pool/pool.go @@ -3,6 +3,7 @@ package pool import ( "errors" "sync" + "sync/atomic" "time" "github.com/google/uuid" @@ -24,6 +25,9 @@ type Pool struct { closed chan struct{} + pendingCount atomic.Int64 + pending chan struct{} + recipes map[string]*recipe.Recipe clients map[uuid.UUID]*Client clientsByKey map[[8]byte]*Client @@ -35,12 +39,12 @@ type Pool struct { func NewPool(options Options) *Pool { p := &Pool{ closed: make(chan struct{}), + pending: make(chan struct{}, 1), options: options, } - if options.ServerIdleTimeout != 0 { - go p.idleLoop() - } + s := NewScaler(p) + go s.Run() return p } @@ -64,33 +68,6 @@ func (T *Pool) idlest() (server *Server, at time.Time) { return } -func (T *Pool) idleLoop() { - for { - select { - case <-T.closed: - return - default: - } - - var wait time.Duration - - now := time.Now() - var idlest *Server - var idle time.Time - for idlest, idle = T.idlest(); idlest != nil && now.Sub(idle) > T.options.ServerIdleTimeout; idlest, idle = T.idlest() { - T.removeServer(idlest) - } - - if idlest == nil { - wait = T.options.ServerIdleTimeout - } else { - wait = idle.Add(T.options.ServerIdleTimeout).Sub(now) - } - - time.Sleep(wait) - } -} - func (T *Pool) GetCredentials() auth.Credentials { return T.options.Credentials } @@ -143,56 +120,20 @@ func (T *Pool) removeRecipe(name string) { } } -func (T *Pool) scaleUp() { - backoff := T.options.ServerReconnectInitialTime - - for { - select { - case <-T.closed: - return - default: - } - - name, r := func() (string, *recipe.Recipe) { - T.mu.RLock() - defer T.mu.RUnlock() - - for name, r := range T.recipes { - if r.Allocate() { - return name, r - } - } - - if len(T.servers) > 0 { - // don't retry this, there are other servers available - backoff = 0 - } - return "", nil - }() - - if r != nil { - err := 
T.scaleUpL1(name, r) - if err == nil { - return - } - - log.Printf("failed to dial server: %v", err) - } +func (T *Pool) scaleUpL0() (string, *recipe.Recipe) { + T.mu.RLock() + defer T.mu.RUnlock() - if backoff == 0 { - // no backoff - return + for name, r := range T.recipes { + if r.Allocate() { + return name, r } + } - log.Printf("failed to dial server. trying again in %v", backoff) - - time.Sleep(backoff) - - backoff *= 2 - if T.options.ServerReconnectMaxTime != 0 && backoff > T.options.ServerReconnectMaxTime { - backoff = T.options.ServerReconnectMaxTime - } + // NOTE(review): pre-refactor scaleUp used a len(T.servers) > 0 check here + // to skip retry backoff when other servers were available; both branches + // now returned ("", nil) identically, so the dead branch is removed. + return "", nil } func (T *Pool) scaleUpL1(name string, r *recipe.Recipe) error { @@ -240,6 +181,21 @@ func (T *Pool) scaleUpL1(name string, r *recipe.Recipe) error { return nil } +func (T *Pool) scaleUp() bool { + name, r := T.scaleUpL0() + if r == nil { + return false + } + + err := T.scaleUpL1(name, r) + if err != nil { + log.Printf("failed to dial server: %v", err) + return false + } + + return true +} + func (T *Pool) removeServer(server *Server) { T.mu.Lock() defer T.mu.Unlock() @@ -262,15 +218,20 @@ func (T *Pool) acquireServer(client *Client) *Server { for { serverID := T.options.Pooler.Acquire(client.GetID(), SyncModeNonBlocking) if serverID == uuid.Nil { - // TODO(garet) can this be run on same thread and only create a goroutine if scaling is possible?
- go T.scaleUp() + T.pendingCount.Add(1) + select { + case T.pending <- struct{}{}: + default: + } serverID = T.options.Pooler.Acquire(client.GetID(), SyncModeBlocking) + T.pendingCount.Add(-1) } T.mu.RLock() server, ok := T.servers[serverID] T.mu.RUnlock() if !ok { + // NOTE(review): removed leftover debug log.Println("here"); server was removed concurrently — purge stale ID and retry T.options.Pooler.DeleteServer(serverID) continue } diff --git a/lib/gat/pool/scaler.go b/lib/gat/pool/scaler.go new file mode 100644 index 0000000000000000000000000000000000000000..224c37fbcedfed63b74be006e0fd864a444f58a5 --- /dev/null +++ b/lib/gat/pool/scaler.go @@ -0,0 +1,114 @@ +package pool + +import ( + "time" + "tuxpa.in/a/zlog/log" +) + +type Scaler struct { + pool *Pool + + backingOff bool + backoff time.Duration + + // timers + idle *time.Timer + pending *time.Timer +} + +func NewScaler(pool *Pool) *Scaler { + s := &Scaler{ + pool: pool, + backoff: pool.options.ServerReconnectInitialTime, + } + + if pool.options.ServerIdleTimeout != 0 { + s.idle = time.NewTimer(pool.options.ServerIdleTimeout) + } + + return s +} + +func (T *Scaler) idleTimeout(now time.Time) { + // idle loop for scaling down + var wait time.Duration + + var idlest *Server + var idleStart time.Time + for idlest, idleStart = T.pool.idlest(); idlest != nil && now.Sub(idleStart) > T.pool.options.ServerIdleTimeout; idlest, idleStart = T.pool.idlest() { + T.pool.removeServer(idlest) + } + + if idlest == nil { + wait = T.pool.options.ServerIdleTimeout + } else { + wait = idleStart.Add(T.pool.options.ServerIdleTimeout).Sub(now) + } + + T.idle.Reset(wait) +} + +func (T *Scaler) pendingTimeout() { + if T.backingOff { + T.backoff *= 2 + if T.pool.options.ServerReconnectMaxTime != 0 && T.backoff > T.pool.options.ServerReconnectMaxTime { + T.backoff = T.pool.options.ServerReconnectMaxTime + } + } + + for T.pool.pendingCount.Load() > 0 { + // pending loop for scaling up + if T.pool.scaleUp() { + // scale up successful, see if we need to scale up more + T.backoff = T.pool.options.ServerReconnectInitialTime + T.backingOff = false
+ continue + } + + if T.backoff == 0 { + // no backoff + T.backoff = T.pool.options.ServerReconnectInitialTime + T.backingOff = false + continue + } + + T.backingOff = true + if T.pending == nil { + T.pending = time.NewTimer(T.backoff) + } else { + T.pending.Reset(T.backoff) + } + + log.Printf("failed to dial server. trying again in %v", T.backoff) + + return + } +} + +func (T *Scaler) Run() { + for { + var idle <-chan time.Time + if T.idle != nil { + idle = T.idle.C + } + + var pending1 <-chan struct{} + var pending2 <-chan time.Time + if T.backingOff { + pending2 = T.pending.C + } else { + pending1 = T.pool.pending + } + + select { + case t := <-idle: + T.idleTimeout(t) + case <-pending1: + T.pendingTimeout() + case <-pending2: + T.pendingTimeout() + case <-T.pool.closed: + return + } + } +} diff --git a/test/tester_test.go b/test/tester_test.go index aaf15beeae5433843c14ea94db7cf37aefbc9ebb..3555fa5c9066c487205dea478054a536a5bbf3b5 100644 --- a/test/tester_test.go +++ b/test/tester_test.go @@ -6,9 +6,6 @@ import ( "fmt" "net" _ "net/http/pprof" - "strconv" - "testing" - "pggat/lib/auth" "pggat/lib/auth/credentials" "pggat/lib/bouncer/backends/v0" @@ -21,6 +18,8 @@ import ( "pggat/lib/gat/pool/recipe" "pggat/test" "pggat/test/tests" + "strconv" + "testing" ) func daisyChain(creds auth.Credentials, control dialer.Net, n int) (dialer.Net, error) {