From 41dba9e452b8c924817fe7babbcf9a81f385d934 Mon Sep 17 00:00:00 2001 From: Tom Guinther <tguinther@gfxlabs.io> Date: Tue, 13 Aug 2024 19:44:47 -0400 Subject: [PATCH] shockingly we have something that compiles --- lib/gat/handlers/discovery/module.go | 146 +++++++++--------- lib/gat/handlers/pgbouncer/module.go | 35 +++-- lib/gat/handlers/pool/module.go | 10 +- lib/gat/handlers/pool/pool.go | 10 +- lib/gat/handlers/pool/pools/basic/factory.go | 5 +- lib/gat/handlers/pool/pools/basic/pool.go | 26 ++-- lib/gat/handlers/pool/pools/hybrid/factory.go | 5 +- lib/gat/handlers/pool/pools/hybrid/pool.go | 38 ++--- lib/gat/handlers/pool/spool/kitchen/chef.go | 39 ++--- lib/gat/handlers/pool/spool/pool.go | 42 ++--- lib/gat/handlers/rewrite_password/module.go | 2 + 11 files changed, 183 insertions(+), 175 deletions(-) diff --git a/lib/gat/handlers/discovery/module.go b/lib/gat/handlers/discovery/module.go index d58735ea..acc1c582 100644 --- a/lib/gat/handlers/discovery/module.go +++ b/lib/gat/handlers/discovery/module.go @@ -98,10 +98,10 @@ func (T *Module) Provision(ctx caddy.Context) error { } T.closed = make(chan struct{}) - if err := T.reconcile(); err != nil { + if err := T.reconcile(ctx); err != nil { return err } - go T.discoverLoop() + go T.discoverLoop(ctx) return nil } @@ -123,9 +123,9 @@ func (T *Module) Cleanup() error { return nil } -func (T *Module) added(cluster Cluster) { +func (T *Module) added(ctx context.Context, cluster Cluster) { if prev, ok := T.clusters[cluster.ID]; ok { - T.updated(prev, cluster) + T.updated(ctx, prev, cluster) return } if T.clusters == nil { @@ -134,33 +134,33 @@ func (T *Module) added(cluster Cluster) { T.clusters[cluster.ID] = cluster for _, user := range cluster.Users { - T.addUser(cluster.Primary, cluster.Replicas, cluster.Databases, user) + T.addUser(ctx, cluster.Primary, cluster.Replicas, cluster.Databases, user) } } -func (T *Module) updated(prev, next Cluster) { +func (T *Module) updated(ctx context.Context, prev, next Cluster) { T.clusters[next.ID] = next // primary endpoints if prev.Primary != next.Primary { - T.replacePrimary(prev.Users, prev.Databases, next.Primary) + T.replacePrimary(ctx, prev.Users, prev.Databases, next.Primary) } // replica endpoints if len(prev.Replicas) != 0 && len(next.Replicas) == 0 { - T.removeReplicas(prev.Replicas, prev.Users, prev.Databases) + T.removeReplicas(ctx, prev.Replicas, prev.Users, prev.Databases) } else if len(prev.Replicas) == 0 && len(next.Replicas) != 0 { - T.addReplicas(next.Replicas, prev.Users, prev.Databases) + T.addReplicas(ctx, next.Replicas, prev.Users, prev.Databases) } else { // change # of replicas for id, nextReplica := range next.Replicas { prevReplica, ok := prev.Replicas[id] if !ok { - T.addReplica(prev.Users, prev.Databases, id, nextReplica) + T.addReplica(ctx, prev.Users, prev.Databases, id, nextReplica) } else if prevReplica != nextReplica { // don't need to remove, add will replace the recipe atomically - T.addReplica(prev.Users, prev.Databases, id, nextReplica) + T.addReplica(ctx, prev.Users, prev.Databases, id, nextReplica) } } for id := range prev.Replicas { @@ -169,7 +169,7 @@ func (T *Module) updated(prev, next Cluster) { continue // already handled } - T.removeReplica(prev.Users, prev.Databases, id) + T.removeReplica(ctx, prev.Users, prev.Databases, id) } } @@ -187,10 +187,10 @@ func (T *Module) updated(prev, next Cluster) { } if !ok { - T.addUser(next.Primary, next.Replicas, prev.Databases, nextUser) + T.addUser(ctx, next.Primary, next.Replicas, prev.Databases, nextUser) } else if nextUser.Password != prevUser.Password { - T.removeUser(next.Replicas, prev.Databases, nextUser.Username) - T.addUser(next.Primary, next.Replicas, prev.Databases, nextUser) + T.removeUser(ctx, next.Replicas, prev.Databases, nextUser.Username) + T.addUser(ctx, next.Primary, next.Replicas, prev.Databases, nextUser) } } outer: @@ -201,23 +201,23 @@ outer: } } - T.removeUser(next.Replicas, prev.Databases, prevUser.Username) + T.removeUser(ctx, next.Replicas, prev.Databases, prevUser.Username) } for _, nextDatabase := range next.Databases { if !slices.Contains(prev.Databases, nextDatabase) { - T.addDatabase(next.Primary, next.Replicas, next.Users, nextDatabase) + T.addDatabase(ctx, next.Primary, next.Replicas, next.Users, nextDatabase) } } for _, prevDatabase := range prev.Databases { if !slices.Contains(next.Databases, prevDatabase) { - T.removeDatabase(next.Replicas, next.Users, prevDatabase) + T.removeDatabase(ctx, next.Replicas, next.Users, prevDatabase) } } } -func (T *Module) addPrimaryNode(user User, database string, primary Node) { - p := T.getOrAddPool(user, database) +func (T *Module) addPrimaryNode(ctx context.Context, user User, database string, primary Node) { + p := T.getOrAddPool(ctx, user, database) d := pool.Recipe{ Dialer: pool.Dialer{ @@ -233,15 +233,15 @@ func (T *Module) addPrimaryNode(user User, database string, primary Node) { MinConnections: T.ServerMinConnections, MaxConnections: T.ServerMaxConnections, } - p.pool.AddRecipe("primary", &d) + p.pool.AddRecipe(ctx, "primary", &d) } func (T *Module) removePrimaryNode(username, database string) { T.removePool(username, database) } -func (T *Module) addReplicaNodes(user User, database string, replicas map[string]Node) { - p := T.getOrAddPool(user, database) +func (T *Module) addReplicaNodes(ctx context.Context, user User, database string, replicas map[string]Node) { + p := T.getOrAddPool(ctx, user, database) if rp, ok := p.pool.(pool.ReplicaPool); ok { for id, replica := range replicas { @@ -259,12 +259,12 @@ func (T *Module) addReplicaNodes(user User, database string, replicas map[string MinConnections: T.ServerMinConnections, MaxConnections: T.ServerMaxConnections, } - rp.AddReplicaRecipe(id, &d) + rp.AddReplicaRecipe(ctx, id, &d) } return } - rp := T.getOrAddReplicaPool(user, database) + rp := T.getOrAddReplicaPool(ctx, user, database) for id, replica := range replicas { d := pool.Recipe{ Dialer: pool.Dialer{ @@ -280,11 +280,11 @@ func (T *Module) addReplicaNodes(user User, database string, replicas map[string MinConnections: T.ServerMinConnections, MaxConnections: T.ServerMaxConnections, } - rp.pool.AddRecipe(id, &d) + rp.pool.AddRecipe(ctx, id, &d) } } -func (T *Module) removeReplicaNodes(username string, database string, replicas map[string]Node) { +func (T *Module) removeReplicaNodes(ctx context.Context, username string, database string, replicas map[string]Node) { p, ok := T.getPool(username, database) if !ok { return @@ -293,7 +293,7 @@ func (T *Module) removeReplicaNodes(username string, database string, replicas m // remove endpoints from replica pool if rp, ok := p.pool.(pool.ReplicaPool); ok { for key := range replicas { - rp.RemoveReplicaRecipe(key) + rp.RemoveReplicaRecipe(ctx, key) } return } @@ -302,8 +302,8 @@ func (T *Module) removeReplicaNodes(username string, database string, replicas m T.removeReplicaPool(username, database) } -func (T *Module) addReplicaNode(user User, database string, id string, replica Node) { - p := T.getOrAddPool(user, database) +func (T *Module) addReplicaNode(ctx context.Context, user User, database string, id string, replica Node) { + p := T.getOrAddPool(ctx, user, database) d := pool.Recipe{ Dialer: pool.Dialer{ @@ -321,15 +321,15 @@ func (T *Module) addReplicaNode(user User, database string, id string, replica N } if rp, ok := p.pool.(pool.ReplicaPool); ok { - rp.AddReplicaRecipe(id, &d) + rp.AddReplicaRecipe(ctx, id, &d) return } - rp := T.getOrAddReplicaPool(user, database) - rp.pool.AddRecipe(id, &d) + rp := T.getOrAddReplicaPool(ctx, user, database) + rp.pool.AddRecipe(ctx, id, &d) } -func (T *Module) removeReplicaNode(username string, database string, id string) { +func (T *Module) removeReplicaNode(ctx context.Context, username string, database string, id string) { p, ok := T.getPool(username, database) if !ok { return @@ -337,7 +337,7 @@ func (T *Module) removeReplicaNode(username string, database string, id string) // remove endpoints from replica pool if rp, ok := p.pool.(pool.ReplicaPool); ok { - rp.RemoveReplicaRecipe(id) + rp.RemoveReplicaRecipe(ctx, id) return } @@ -346,87 +346,87 @@ func (T *Module) removeReplicaNode(username string, database string, id string) if !ok { return } - rp.pool.RemoveRecipe(id) + rp.pool.RemoveRecipe(ctx, id) } // replacePrimary replaces the primary endpoint. -func (T *Module) replacePrimary(users []User, databases []string, primary Node) { +func (T *Module) replacePrimary(ctx context.Context, users []User, databases []string, primary Node) { for _, user := range users { for _, database := range databases { - T.addPrimaryNode(user, database, primary) + T.addPrimaryNode(ctx, user, database, primary) } } } // addReplicas adds multiple replicas. Other replicas must not exist. -func (T *Module) addReplicas(replicas map[string]Node, users []User, databases []string) { +func (T *Module) addReplicas(ctx context.Context, replicas map[string]Node, users []User, databases []string) { for _, user := range users { for _, database := range databases { - T.addReplicaNodes(user, database, replicas) + T.addReplicaNodes(ctx, user, database, replicas) } } } // removeReplicas removes all replicas. -func (T *Module) removeReplicas(replicas map[string]Node, users []User, databases []string) { +func (T *Module) removeReplicas(ctx context.Context, replicas map[string]Node, users []User, databases []string) { for _, user := range users { for _, database := range databases { - T.removeReplicaNodes(user.Username, database, replicas) + T.removeReplicaNodes(ctx, user.Username, database, replicas) } } } // addReplica adds a single replica. -func (T *Module) addReplica(users []User, databases []string, id string, replica Node) { +func (T *Module) addReplica(ctx context.Context, users []User, databases []string, id string, replica Node) { for _, user := range users { for _, database := range databases { - T.addReplicaNode(user, database, id, replica) + T.addReplicaNode(ctx, user, database, id, replica) } } } // removeReplica removes a single replica. -func (T *Module) removeReplica(users []User, databases []string, id string) { +func (T *Module) removeReplica(ctx context.Context, users []User, databases []string, id string) { for _, user := range users { for _, database := range databases { - T.removeReplicaNode(user.Username, database, id) + T.removeReplicaNode(ctx, user.Username, database, id) } } } // addUser adds a new user. -func (T *Module) addUser(primary Node, replicas map[string]Node, databases []string, user User) { +func (T *Module) addUser(ctx context.Context, primary Node, replicas map[string]Node, databases []string, user User) { for _, database := range databases { - T.addPrimaryNode(user, database, primary) - T.addReplicaNodes(user, database, replicas) + T.addPrimaryNode(ctx, user, database, primary) + T.addReplicaNodes(ctx, user, database, replicas) } } // removeUser removes a user. -func (T *Module) removeUser(replicas map[string]Node, databases []string, username string) { +func (T *Module) removeUser(ctx context.Context, replicas map[string]Node, databases []string, username string) { for _, database := range databases { - T.removeReplicaNodes(username, database, replicas) + T.removeReplicaNodes(ctx, username, database, replicas) T.removePrimaryNode(username, database) } } // addDatabase adds a new database. -func (T *Module) addDatabase(primary Node, replicas map[string]Node, users []User, database string) { +func (T *Module) addDatabase(ctx context.Context, primary Node, replicas map[string]Node, users []User, database string) { for _, user := range users { - T.addPrimaryNode(user, database, primary) - T.addReplicaNodes(user, database, replicas) + T.addPrimaryNode(ctx, user, database, primary) + T.addReplicaNodes(ctx, user, database, replicas) } } // removeDatabase removes a single database. -func (T *Module) removeDatabase(replicas map[string]Node, users []User, database string) { +func (T *Module) removeDatabase(ctx context.Context, replicas map[string]Node, users []User, database string) { for _, user := range users { - T.removeReplicaNodes(user.Username, database, replicas) + T.removeReplicaNodes(ctx, user.Username, database, replicas) T.removePrimaryNode(user.Username, database) } } -func (T *Module) removed(id string) { +func (T *Module) removed(ctx context.Context, id string) { cluster, ok := T.clusters[id] if !ok { return @@ -434,11 +434,11 @@ func (T *Module) removed(id string) { delete(T.clusters, id) for _, database := range cluster.Databases { - T.removeDatabase(cluster.Replicas, cluster.Users, database) + T.removeDatabase(ctx, cluster.Replicas, cluster.Users, database) } } -func (T *Module) reconcile() error { +func (T *Module) reconcile(ctx context.Context) error { clusters, err := T.discoverer.Clusters() if err != nil { return err @@ -447,9 +447,9 @@ func (T *Module) reconcile() error { for _, cluster := range clusters { prev, ok := T.clusters[cluster.ID] if !ok { - T.added(cluster) + T.added(ctx, cluster) } else { - T.updated(prev, cluster) + T.updated(ctx, prev, cluster) } } @@ -461,13 +461,13 @@ outer: continue outer } } - T.removed(id) + T.removed(ctx, id) } return nil } -func (T *Module) discoverLoop() { +func (T *Module) discoverLoop(ctx context.Context) { var reconcile <-chan time.Time if T.ReconcilePeriod != 0 { r := time.NewTicker(time.Duration(T.ReconcilePeriod)) @@ -478,11 +478,11 @@ func (T *Module) discoverLoop() { for { select { case cluster := <-T.discoverer.Added(): - T.added(cluster) + T.added(ctx,cluster) case id := <-T.discoverer.Removed(): - T.removed(id) + T.removed(ctx, id) case <-reconcile: - err := T.reconcile() + err := T.reconcile(ctx) if err != nil { T.log.Warn("failed to reconcile", zap.Error(err)) } @@ -513,7 +513,7 @@ func (T *Module) getCreds(user User) auth.Credentials { return creds } -func (T *Module) getOrAddPool(user User, database string) poolAndCredentials { +func (T *Module) getOrAddPool(ctx context.Context, user User, database string) poolAndCredentials { T.poolsMu.Lock() defer T.poolsMu.Unlock() if old, ok := T.pools.Load(user.Username, database); ok { @@ -522,7 +522,7 @@ func (T *Module) getOrAddPool(user User, database string) poolAndCredentials { creds := T.getCreds(user) p := poolAndCredentials{ - pool: T.poolFactory.NewPool(), + pool: T.poolFactory.NewPool(ctx), creds: creds, } T.pools.Store(user.Username, database, p) @@ -530,8 +530,8 @@ func (T *Module) getOrAddPool(user User, database string) poolAndCredentials { return p } -func (T *Module) getOrAddReplicaPool(user User, database string) poolAndCredentials { - return T.getOrAddPool(T.toReplicaUser(user), database) +func (T *Module) getOrAddReplicaPool(ctx context.Context, user User, database string) poolAndCredentials { + return T.getOrAddPool(ctx, T.toReplicaUser(user), database) } func (T *Module) getPool(user, database string) (poolAndCredentials, bool) { @@ -569,17 +569,17 @@ func (T *Module) ReadMetrics(metrics *metrics.Handler) { }) } -func (T *Module) Handle(ctx context.Context, conn *fed.Conn) error { +func (T *Module) Handle(conn *fed.Conn) error { p, ok := T.getPool(conn.User, conn.Database) if !ok { return nil } - if err := frontends.Authenticate(ctx, conn, p.creds); err != nil { + if err := frontends.Authenticate(context.Background(), conn, p.creds); err != nil { return err } - return p.pool.Serve(ctx, conn) + return p.pool.Serve(context.Background(), conn) } func (T *Module) Cancel(ctx context.Context, key fed.BackendKey) { diff --git a/lib/gat/handlers/pgbouncer/module.go b/lib/gat/handlers/pgbouncer/module.go index 057eea53..bb43a149 100644 --- a/lib/gat/handlers/pgbouncer/module.go +++ b/lib/gat/handlers/pgbouncer/module.go @@ -1,6 +1,7 @@ package pgbouncer import ( + "context" "crypto/tls" "errors" "fmt" @@ -80,7 +81,7 @@ func (T *Module) Cleanup() error { defer T.mu.Unlock() T.pools.Range(func(user string, database string, p poolAndCredentials) bool { - p.pool.Close() + p.pool.Close(context.Background()) T.pools.Delete(user, database) return true }) @@ -88,7 +89,7 @@ func (T *Module) Cleanup() error { return nil } -func (T *Module) getPassword(user, database string) (string, bool) { +func (T *Module) getPassword(ctx context.Context, user, database string) (string, bool) { // try to get password password, ok := T.Config.PgBouncer.AuthFile[user] if !ok { @@ -105,7 +106,7 @@ func (T *Module) getPassword(user, database string) (string, bool) { } } - authPool, ok := T.lookup(authUser, database) + authPool, ok := T.lookup(ctx, authUser, database) if !ok { return "", false } @@ -116,14 +117,14 @@ func (T *Module) getPassword(user, database string) (string, bool) { inward, outward, _, _ := gsql.NewPair() b.Queue(func() error { - if err := gsql.ExtendedQuery(inward, &result, T.Config.PgBouncer.AuthQuery, user); err != nil { + if err := gsql.ExtendedQuery(ctx, inward, &result, T.Config.PgBouncer.AuthQuery, user); err != nil { return err } - return inward.Close() + return inward.Close(ctx) }) b.Queue(func() error { - err := authPool.pool.Serve(outward) + err := authPool.pool.Serve(ctx, outward) if err != nil && !errors.Is(err, io.EOF) { return err } @@ -148,7 +149,7 @@ func (T *Module) getPassword(user, database string) (string, bool) { return password, true } -func (T *Module) tryCreate(user, database string) (poolAndCredentials, bool) { +func (T *Module) tryCreate(ctx context.Context, user, database string) (poolAndCredentials, bool) { db, ok := T.Config.Databases[database] if !ok { // try wildcard @@ -159,7 +160,7 @@ func (T *Module) tryCreate(user, database string) (poolAndCredentials, bool) { } // try to get password - password, ok := T.getPassword(user, database) + password, ok := T.getPassword(ctx, user, database) if !ok { return poolAndCredentials{}, false } @@ -275,19 +276,19 @@ func (T *Module) tryCreate(user, database string) (poolAndCredentials, bool) { r.MaxConnections = T.Config.PgBouncer.MaxDBConnections } - p.pool.AddRecipe("pgbouncer", &r) + p.pool.AddRecipe(ctx, "pgbouncer", &r) return p, true } -func (T *Module) lookup(user, database string) (poolAndCredentials, bool) { +func (T *Module) lookup(ctx context.Context, user, database string) (poolAndCredentials, bool) { p, ok := T.pools.Load(user, database) if ok { return p, true } // try to create pool - return T.tryCreate(user, database) + return T.tryCreate(ctx, user, database) } func (T *Module) Handle(conn *fed.Conn) error { @@ -327,16 +328,18 @@ func (T *Module) Handle(conn *fed.Conn) error { ) } - p, ok := T.lookup(conn.User, conn.Database) + ctx := context.Background() + + p, ok := T.lookup(ctx, conn.User, conn.Database) if !ok { return nil } - if err := frontends.Authenticate(conn, p.creds); err != nil { + if err := frontends.Authenticate(ctx, conn, p.creds); err != nil { return err } - return p.pool.Serve(conn) + return p.pool.Serve(ctx, conn) } func (T *Module) ReadMetrics(metrics *metrics.Handler) { @@ -348,11 +351,11 @@ func (T *Module) ReadMetrics(metrics *metrics.Handler) { }) } -func (T *Module) Cancel(key fed.BackendKey) { +func (T *Module) Cancel(ctx context.Context, key fed.BackendKey) { T.mu.RLock() defer T.mu.RUnlock() T.pools.Range(func(_ string, _ string, p poolAndCredentials) bool { - p.pool.Cancel(key) + p.pool.Cancel(ctx, key) return true }) } diff --git a/lib/gat/handlers/pool/module.go b/lib/gat/handlers/pool/module.go index 795201d5..5d2b6b94 100644 --- a/lib/gat/handlers/pool/module.go +++ b/lib/gat/handlers/pool/module.go @@ -40,22 +40,22 @@ func (T *Module) Provision(ctx caddy.Context) error { if err != nil { return err } - T.pool = raw.(PoolFactory).NewPool() + T.pool = raw.(PoolFactory).NewPool(ctx) if err = T.Recipe.Provision(ctx); err != nil { return err } - T.pool.AddRecipe("recipe", &T.Recipe) + T.pool.AddRecipe(ctx, "recipe", &T.Recipe) return nil } -func (T *Module) Handle(ctx context.Context, conn *fed.Conn) error { - if err := frontends.Authenticate(ctx, conn, nil); err != nil { +func (T *Module) Handle(conn *fed.Conn) error { + if err := frontends.Authenticate(context.Background(), conn, nil); err != nil { return err } - return T.pool.Serve(ctx, conn) + return T.pool.Serve(context.Background(), conn) } func (T *Module) ReadMetrics(metrics *metrics.Handler) { diff --git a/lib/gat/handlers/pool/pool.go b/lib/gat/handlers/pool/pool.go index 5cf3143f..e81efb17 100644 --- a/lib/gat/handlers/pool/pool.go +++ b/lib/gat/handlers/pool/pool.go @@ -9,9 +9,9 @@ import ( type Pool interface { // AddRecipe will add the recipe to the pool for use. The pool should delete any existing recipes with the same name // and scale the recipe to min. - AddRecipe(name string, recipe *Recipe) + AddRecipe(ctx context.Context, name string, recipe *Recipe) // RemoveRecipe will remove a recipe and disconnect all servers created by that recipe. - RemoveRecipe(name string) + RemoveRecipe(ctx context.Context, name string) Serve(ctx context.Context, conn *fed.Conn) error @@ -23,10 +23,10 @@ type Pool interface { type ReplicaPool interface { Pool - AddReplicaRecipe(name string, recipe *Recipe) - RemoveReplicaRecipe(name string) + AddReplicaRecipe(ctx context.Context, name string, recipe *Recipe) + RemoveReplicaRecipe(ctx context.Context, name string) } type PoolFactory interface { - NewPool() Pool + NewPool(ctx context.Context) Pool } diff --git a/lib/gat/handlers/pool/pools/basic/factory.go b/lib/gat/handlers/pool/pools/basic/factory.go index 44d58c95..eb90578d 100644 --- a/lib/gat/handlers/pool/pools/basic/factory.go +++ b/lib/gat/handlers/pool/pools/basic/factory.go @@ -1,6 +1,7 @@ package basic import ( + "context" "fmt" "github.com/caddyserver/caddy/v2" @@ -50,8 +51,8 @@ func (T *Factory) Provision(ctx caddy.Context) error { return nil } -func (T *Factory) NewPool() pool.Pool { - return NewPool(T.Config) +func (T *Factory) NewPool(ctx context.Context) pool.Pool { + return NewPool(ctx, T.Config) } var _ pool.PoolFactory = (*Factory)(nil) diff --git a/lib/gat/handlers/pool/pools/basic/pool.go b/lib/gat/handlers/pool/pools/basic/pool.go index 91c92ced..cc0f059d 100644 --- a/lib/gat/handlers/pool/pools/basic/pool.go +++ b/lib/gat/handlers/pool/pools/basic/pool.go @@ -28,21 +28,21 @@ type Pool struct { mu sync.RWMutex } -func NewPool(config Config) *Pool { +func NewPool(ctx context.Context, config Config) *Pool { p := &Pool{ config: config, servers: spool.MakePool(config.Spool()), } - go p.servers.ScaleLoop() + go p.servers.ScaleLoop(ctx) return p } -func (T *Pool) AddRecipe(name string, recipe *pool.Recipe) { - T.servers.AddRecipe(name, recipe) +func (T *Pool) AddRecipe(ctx context.Context, name string, recipe *pool.Recipe) { + T.servers.AddRecipe(ctx, name, recipe) } -func (T *Pool) RemoveRecipe(name string) { - T.servers.RemoveRecipe(name) +func (T *Pool) RemoveRecipe(ctx context.Context, name string) { + T.servers.RemoveRecipe(ctx, name) } func (T *Pool) SyncInitialParameters(ctx context.Context, client *Client, server *spool.Server) (err, serverErr error) { @@ -184,9 +184,9 @@ func (T *Pool) Serve(ctx context.Context, conn *fed.Conn) error { defer func() { if server != nil { if serverErr != nil { - T.servers.RemoveServer(server) + T.servers.RemoveServer(ctx, server) } else { - T.servers.Release(server) + T.servers.Release(ctx, server) } server = nil } @@ -209,7 +209,7 @@ func (T *Pool) Serve(ctx context.Context, conn *fed.Conn) error { } p := packets.ReadyForQuery('I') - err = client.Conn.WritePacket(&p) + err = client.Conn.WritePacket(ctx, &p) if err != nil { return err } @@ -220,7 +220,7 @@ func (T *Pool) Serve(ctx context.Context, conn *fed.Conn) error { for { if server != nil && T.config.ReleaseAfterTransaction { client.SetState(metrics.ConnStateIdle, nil) - T.servers.Release(server) + T.servers.Release(ctx, server) server = nil } @@ -275,7 +275,7 @@ func (T *Pool) Cancel(ctx context.Context, key fed.BackendKey) { return } - T.servers.Cancel(peer) + T.servers.Cancel(ctx, peer) } func (T *Pool) ReadMetrics(m *metrics.Pool) { @@ -294,8 +294,8 @@ func (T *Pool) ReadMetrics(m *metrics.Pool) { } } -func (T *Pool) Close() { - T.servers.Close() +func (T *Pool) Close(ctx context.Context) { + T.servers.Close(ctx) } var _ pool.Pool = (*Pool)(nil) diff --git a/lib/gat/handlers/pool/pools/hybrid/factory.go b/lib/gat/handlers/pool/pools/hybrid/factory.go index d189be6c..a15c6bbc 100644 --- a/lib/gat/handlers/pool/pools/hybrid/factory.go +++ b/lib/gat/handlers/pool/pools/hybrid/factory.go @@ -1,6 +1,7 @@ package hybrid import ( + "context" "fmt" "github.com/caddyserver/caddy/v2" @@ -44,8 +45,8 @@ func (T *Factory) Provision(ctx caddy.Context) error { return nil } -func (T *Factory) NewPool() pool.Pool { - return NewPool(T.Config) +func (T *Factory) NewPool(ctx context.Context) pool.Pool { + return NewPool(ctx, T.Config) } var _ pool.PoolFactory = (*Factory)(nil) diff --git a/lib/gat/handlers/pool/pools/hybrid/pool.go b/lib/gat/handlers/pool/pools/hybrid/pool.go index edfd45d7..f215be9d 100644 --- a/lib/gat/handlers/pool/pools/hybrid/pool.go +++ b/lib/gat/handlers/pool/pools/hybrid/pool.go @@ -29,7 +29,7 @@ type Pool struct { mu sync.RWMutex } -func NewPool(config Config) *Pool { +func NewPool(ctx context.Context, config Config) *Pool { c := config.Spool() p := &Pool{ @@ -38,25 +38,25 @@ func NewPool(config Config) *Pool { primary: spool.MakePool(c), replica: spool.MakePool(c), } - go p.primary.ScaleLoop() - go p.replica.ScaleLoop() + go p.primary.ScaleLoop(ctx) + go p.replica.ScaleLoop(ctx) return p } -func (T *Pool) AddReplicaRecipe(name string, recipe *pool.Recipe) { - T.replica.AddRecipe(name, recipe) +func (T *Pool) AddReplicaRecipe(ctx context.Context, name string, recipe *pool.Recipe) { + T.replica.AddRecipe(ctx, name, recipe) } -func (T *Pool) RemoveReplicaRecipe(name string) { - T.replica.RemoveRecipe(name) +func (T *Pool) RemoveReplicaRecipe(ctx context.Context, name string) { + T.replica.RemoveRecipe(ctx, name) } -func (T *Pool) AddRecipe(name string, recipe *pool.Recipe) { - T.primary.AddRecipe(name, recipe) +func (T *Pool) AddRecipe(ctx context.Context, name string, recipe *pool.Recipe) { + T.primary.AddRecipe(ctx, name, recipe) } -func (T *Pool) RemoveRecipe(name string) { - T.primary.RemoveRecipe(name) +func (T *Pool) RemoveRecipe(ctx context.Context, name string) { + T.primary.RemoveRecipe(ctx, name) } func (T *Pool) Pair(ctx context.Context, client *Client, server *spool.Server) (err, serverErr error) { @@ -145,7 +145,7 @@ func (T *Pool) serveRW(ctx context.Context, conn *fed.Conn) error { defer func() { if primary != nil { if serverErr != nil { - T.primary.RemoveServer(primary) + T.primary.RemoveServer(ctx, primary) } else { T.primary.Release(ctx, primary) } @@ -153,7 +153,7 @@ func (T *Pool) serveRW(ctx context.Context, conn *fed.Conn) error { } if replica != nil { if serverErr != nil { - T.replica.RemoveServer(replica) + T.replica.RemoveServer(ctx, replica) } else { T.replica.Release(ctx, replica) } @@ -339,7 +339,7 @@ func (T *Pool) serveOnly(ctx context.Context, conn *fed.Conn, write bool) error defer func() { if server != nil { if serverErr != nil { - sp.RemoveServer(server) + sp.RemoveServer(ctx, server) } else { sp.Release(ctx, server) } @@ -439,9 +439,9 @@ func (T *Pool) Cancel(ctx context.Context, key fed.BackendKey) { } if replica { - T.replica.Cancel(peer) + T.replica.Cancel(ctx, peer) } else { - T.primary.Cancel(peer) + T.primary.Cancel(ctx, peer) } } @@ -462,9 +462,9 @@ func (T *Pool) ReadMetrics(m *metrics.Pool) { } } -func (T *Pool) Close(_ context.Context) { - T.primary.Close() - T.replica.Close() +func (T *Pool) Close(ctx context.Context) { + T.primary.Close(ctx) + T.replica.Close(ctx) } var _ pool.Pool = (*Pool)(nil) diff --git a/lib/gat/handlers/pool/spool/kitchen/chef.go b/lib/gat/handlers/pool/spool/kitchen/chef.go index e4fa39b8..3eb2f873 100644 --- a/lib/gat/handlers/pool/spool/kitchen/chef.go +++ b/lib/gat/handlers/pool/spool/kitchen/chef.go @@ -1,6 +1,7 @@ package kitchen import ( + "context" "fmt" "math" "sort" @@ -36,7 +37,7 @@ func NewChef(config Config) *Chef { } // Learn will add a recipe to the kitchen. Returns initial removed and added conns -func (T *Chef) Learn(name string, recipe *pool.Recipe) (removed []*fed.Conn, added []*fed.Conn) { +func (T *Chef) Learn(ctx context.Context, name string, recipe *pool.Recipe) (removed []*fed.Conn, added []*fed.Conn) { n := recipe.AllocateInitial() added = make([]*fed.Conn, 0, n) for i := 0; i < n; i++ { @@ -56,7 +57,7 @@ func (T *Chef) Learn(name string, recipe *pool.Recipe) (removed []*fed.Conn, add T.mu.Lock() defer T.mu.Unlock() - removed = T.forget(name) + removed = T.forget(ctx, name) r := NewRecipe(recipe, added) @@ -77,7 +78,7 @@ func (T *Chef) Learn(name string, recipe *pool.Recipe) (removed []*fed.Conn, add return } -func (T *Chef) forget(name string) []*fed.Conn { +func (T *Chef) forget(ctx context.Context, name string) []*fed.Conn { r, ok := T.byName[name] if !ok { return nil @@ -88,7 +89,7 @@ func (T *Chef) forget(name string) []*fed.Conn { for conn := range r.conns { conns = append(conns, conn) - _ = conn.Close() + _ = conn.Close(ctx) r.recipe.Free() delete(T.byConn, conn) @@ -102,11 +103,11 @@ func (T *Chef) forget(name string) []*fed.Conn { // Forget will remove a recipe from the kitchen. All conn made with the recipe will be closed. Returns conns made with // recipe. -func (T *Chef) Forget(name string) []*fed.Conn { +func (T *Chef) Forget(ctx context.Context, name string) []*fed.Conn { T.mu.Lock() defer T.mu.Unlock() - return T.forget(name) + return T.forget(ctx, name) } func (T *Chef) Empty() bool { @@ -123,7 +124,7 @@ func (T *Chef) cook(r *Recipe) (*fed.Conn, error) { return r.recipe.Dial() } -func (T *Chef) score(r *Recipe) error { +func (T *Chef) score(ctx context.Context, r *Recipe) error { now := time.Now() r.ratings = slices.Resize(r.ratings, len(T.config.Critics)) @@ -164,7 +165,7 @@ func (T *Chef) score(r *Recipe) error { return err } defer func() { - _ = conn.Close() + _ = conn.Close(ctx) }() for i, critic := range critics { @@ -174,7 +175,7 @@ func (T *Chef) score(r *Recipe) error { var score int var validity time.Duration - score, validity, err = critic.Taste(conn) + score, validity, err = critic.Taste(ctx, conn) if err != nil { return err } @@ -211,12 +212,12 @@ func (T *Chef) score(r *Recipe) error { } // Cook will cook the best recipe -func (T *Chef) Cook() (*fed.Conn, error) { +func (T *Chef) Cook(ctx context.Context) (*fed.Conn, error) { T.mu.Lock() defer T.mu.Unlock() for _, r := range T.byName { - if err := T.score(r); err != nil { + if err := T.score(ctx, r); err != nil { r.score = math.MaxInt T.config.Logger.Error("failed to score recipe", zap.Error(err)) continue @@ -277,7 +278,7 @@ func (T *Chef) Cook() (*fed.Conn, error) { } // Burn forcefully closes conn and escorts it out of the kitchen. -func (T *Chef) Burn(conn *fed.Conn) { +func (T *Chef) Burn(ctx context.Context, conn *fed.Conn) { T.mu.Lock() defer T.mu.Unlock() @@ -286,14 +287,14 @@ func (T *Chef) Burn(conn *fed.Conn) { return } r.recipe.Free() - _ = conn.Close() + _ = conn.Close(ctx) delete(T.byConn, conn) delete(r.conns, conn) } // Ignite tries to Burn conn. If successful, conn is closed and returns true -func (T *Chef) Ignite(conn *fed.Conn) bool { +func (T *Chef) Ignite(ctx context.Context, conn *fed.Conn) bool { T.mu.Lock() defer T.mu.Unlock() @@ -304,14 +305,14 @@ func (T *Chef) Ignite(conn *fed.Conn) bool { if !r.recipe.TryFree() { return false } - _ = conn.Close() + _ = conn.Close(ctx) delete(T.byConn, conn) delete(r.conns, conn) return true } -func (T *Chef) Cancel(conn *fed.Conn) { +func (T *Chef) Cancel(ctx context.Context, conn *fed.Conn) { T.mu.Lock() defer T.mu.Unlock() @@ -320,10 +321,10 @@ func (T *Chef) Cancel(conn *fed.Conn) { return } - r.recipe.Cancel(conn.BackendKey) + r.recipe.Cancel(ctx, conn.BackendKey) } -func (T *Chef) Close() { +func (T *Chef) Close(ctx context.Context) { T.mu.Lock() defer T.mu.Unlock() @@ -331,7 +332,7 @@ func (T *Chef) Close() { T.order = T.order[:0] for conn, r := range T.byConn { r.recipe.Free() - _ = conn.Close() + _ = conn.Close(ctx) delete(T.byConn, conn) delete(r.conns, conn) diff --git a/lib/gat/handlers/pool/spool/pool.go b/lib/gat/handlers/pool/spool/pool.go index 285403dd..ef733f3d 100644 --- a/lib/gat/handlers/pool/spool/pool.go +++ b/lib/gat/handlers/pool/spool/pool.go @@ -47,9 +47,9 @@ func MakePool(config Config) Pool { } } -func NewPool(config Config) *Pool { +func NewPool(ctx context.Context, config Config) *Pool { p := MakePool(config) - go p.ScaleLoop() + go p.ScaleLoop(ctx) return &p } @@ -93,8 +93,8 @@ func (T *Pool) removeServer(conn *fed.Conn) { T.pooler.DeleteServer(server.ID) } -func (T *Pool) AddRecipe(name string, recipe *pool.Recipe) { - removed, added := T.chef.Learn(name, recipe) +func (T *Pool) AddRecipe(ctx context.Context, name string, recipe *pool.Recipe) { + removed, added := T.chef.Learn(ctx, name, recipe) if len(removed) == 0 && len(added) == 0 { return } @@ -111,8 +111,8 @@ func (T *Pool) AddRecipe(name string, recipe *pool.Recipe) { } } -func (T *Pool) RemoveRecipe(name string) { - servers := T.chef.Forget(name) +func (T *Pool) RemoveRecipe(ctx context.Context, name string) { + servers := T.chef.Forget(ctx, name) if len(servers) == 0 { return } @@ -129,8 +129,8 @@ func (T *Pool) Empty() bool { return T.chef.Empty() } -func (T *Pool) ScaleUp() error { - server, err := T.chef.Cook() +func (T *Pool) ScaleUp(ctx context.Context) error { + server, err := T.chef.Cook(ctx) if err != nil { return err } @@ -143,7 +143,7 @@ func (T *Pool) ScaleUp() error { return nil } -func (T *Pool) ScaleDown(now time.Time) time.Duration { +func (T *Pool) ScaleDown(ctx context.Context, now time.Time) time.Duration { T.mu.Lock() defer T.mu.Unlock() @@ -159,7 +159,7 @@ func (T *Pool) ScaleDown(now time.Time) time.Duration { idle := now.Sub(since) if idle > T.config.IdleTimeout { // try to free - if T.chef.Ignite(s.Conn) { + if T.chef.Ignite(ctx, s.Conn) { delete(T.serversByID, s.ID) delete(T.serversByConn, s.Conn) T.pooler.DeleteServer(s.ID) @@ -175,7 +175,7 @@ func (T *Pool) ScaleDown(now time.Time) time.Duration { return m } -func (T *Pool) ScaleLoop() { +func (T *Pool) ScaleLoop(ctx context.Context) { idle := new(time.Timer) if T.config.IdleTimeout != 0 { idle = time.NewTimer(T.config.IdleTimeout) @@ -199,7 +199,7 @@ func (T *Pool) ScaleLoop() { ok := true for T.pooler.Waiters() > 0 { - if err := T.ScaleUp(); err != nil { + if err := T.ScaleUp(ctx); err != nil { ok = false break } @@ -223,7 +223,7 @@ func (T *Pool) ScaleLoop() { ok := true for T.pooler.Waiters() > 0 { - if err := T.ScaleUp(); err != nil { + if err := T.ScaleUp(ctx); err != nil { ok = false break } @@ -240,7 +240,7 @@ func (T *Pool) ScaleLoop() { } case now := <-idle.C: // scale down - idle.Reset(T.ScaleDown(now)) + idle.Reset(T.ScaleDown(ctx, now)) } } } @@ -282,7 +282,7 @@ func (T *Pool) Release(ctx context.Context, server *Server) { if err, _ := backends.QueryString(ctx, server.Conn, nil, T.config.ResetQuery); err != nil { T.config.Logger.Error("failed to run reset query", zap.Error(err)) - T.RemoveServer(server) + T.RemoveServer(ctx,server) return } } @@ -292,8 +292,8 @@ func (T *Pool) Release(ctx context.Context, server *Server) { server.SetState(metrics.ConnStateIdle, uuid.Nil) } -func (T *Pool) RemoveServer(server *Server) { - T.chef.Burn(server.Conn) +func (T *Pool) RemoveServer(ctx context.Context, server *Server) { + T.chef.Burn(ctx, server.Conn) T.pooler.DeleteServer(server.ID) T.mu.Lock() @@ -304,8 +304,8 @@ func (T *Pool) RemoveServer(server *Server) { T.pooler.DeleteServer(server.ID) } -func (T *Pool) Cancel(server *Server) { - T.chef.Cancel(server.Conn) +func (T *Pool) Cancel(ctx context.Context, server *Server) { + T.chef.Cancel(ctx, server.Conn) } func (T *Pool) ReadMetrics(m *metrics.Pool) { @@ -322,10 +322,10 @@ func (T *Pool) ReadMetrics(m *metrics.Pool) { } } -func (T *Pool) Close() { +func (T *Pool) Close(ctx context.Context) { close(T.closed) - T.chef.Close() + T.chef.Close(ctx) T.pooler.Close() T.mu.Lock() diff --git a/lib/gat/handlers/rewrite_password/module.go b/lib/gat/handlers/rewrite_password/module.go index 0cec7729..7ee9f78b 100644 --- a/lib/gat/handlers/rewrite_password/module.go +++ b/lib/gat/handlers/rewrite_password/module.go @@ -1,6 +1,7 @@ package rewrite_password import ( + "context" "github.com/caddyserver/caddy/v2" "gfx.cafe/gfx/pggat/lib/auth/credentials" @@ -28,6 +29,7 @@ func (T *Module) CaddyModule() caddy.ModuleInfo { func (T *Module) Handle(conn *fed.Conn) error { return frontends.Authenticate( + context.Background(), conn, credentials.FromString(conn.User, T.Password), ) -- GitLab