From a10e92b19a1ee5989dd9da6c88b1127b1ab8fc8e Mon Sep 17 00:00:00 2001 From: Garet Halliday <me@garet.holiday> Date: Tue, 29 Aug 2023 16:33:58 -0500 Subject: [PATCH] session pool working fine --- cmd/cgat/main.go | 79 ++++++++----------------------- lib/gat/acceptor.go | 48 +++++++++++++++++-- lib/gat/gat.go | 51 -------------------- lib/gat/modes/pgbouncer/config.go | 38 ++++++--------- lib/gat/modes/zalando/config.go | 14 ++---- lib/gat/pool/pool.go | 60 ++++++++++++----------- lib/gat/pool/recipe.go | 2 + lib/gat/pools.go | 75 +++++++++++++++++++++++++++++ 8 files changed, 193 insertions(+), 174 deletions(-) delete mode 100644 lib/gat/gat.go create mode 100644 lib/gat/pools.go diff --git a/cmd/cgat/main.go b/cmd/cgat/main.go index 42981009..e3235210 100644 --- a/cmd/cgat/main.go +++ b/cmd/cgat/main.go @@ -1,19 +1,14 @@ package main import ( - "crypto/tls" "net/http" _ "net/http/pprof" + "os" "tuxpa.in/a/zlog/log" - "pggat2/lib/auth/credentials" - "pggat2/lib/bouncer" - "pggat2/lib/bouncer/backends/v0" - "pggat2/lib/bouncer/frontends/v0" - "pggat2/lib/gat" - "pggat2/lib/gat/pool" - "pggat2/lib/gat/pool/pools/session" + "pggat2/lib/gat/modes/pgbouncer" + "pggat2/lib/gat/modes/zalando" ) func main() { @@ -23,56 +18,9 @@ func main() { log.Printf("Starting pggat...") - g := new(gat.Gat) - g.TestPool = session.NewPool(pool.Options{ - Credentials: credentials.Cleartext{ - Username: "postgres", - Password: "password", - }, - }) - g.TestPool.AddRecipe("test", pool.Recipe{ - Dialer: pool.NetDialer{ - Network: "tcp", - Address: "localhost:5432", - - AcceptOptions: backends.AcceptOptions{ - SSLMode: bouncer.SSLModeAllow, - SSLConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - Credentials: credentials.Cleartext{ - Username: "postgres", - Password: "password", - }, - Database: "postgres", - }, - }, - MinConnections: 1, - MaxConnections: 1, - }) - err := gat.ListenAndServe("tcp", ":6432", frontends.AcceptOptions{}, g) - if err != nil { - panic(err) - } - - /* - if len(os.Args) == 2 { - log.Printf("running in pgbouncer compatibility mode") - conf, err := pgbouncer.Load(os.Args[1]) - if err != nil { - panic(err) - } - - err = conf.ListenAndServe() - if err != nil { - panic(err) - } - return - } - - log.Printf("running in zalando compatibility mode") - - conf, err := zalando.Load() + if len(os.Args) == 2 { + log.Printf("running in pgbouncer compatibility mode") + conf, err := pgbouncer.Load(os.Args[1]) if err != nil { panic(err) } @@ -81,5 +29,18 @@ func main() { if err != nil { panic(err) } - */ + return + } + + log.Printf("running in zalando compatibility mode") + + conf, err := zalando.Load() + if err != nil { + panic(err) + } + + err = conf.ListenAndServe() + if err != nil { + panic(err) + } } diff --git a/lib/gat/acceptor.go b/lib/gat/acceptor.go index 7dfa49b3..53c718e0 100644 --- a/lib/gat/acceptor.go +++ b/lib/gat/acceptor.go @@ -3,6 +3,7 @@ package gat import ( "net" + "pggat2/lib/auth" "pggat2/lib/bouncer/frontends/v0" "pggat2/lib/zap" ) @@ -37,22 +38,59 @@ func Listen(network, address string, options frontends.AcceptOptions) (Acceptor, }, nil } -func Serve(acceptor Acceptor, gat *Gat) error { +func serve(client zap.Conn, acceptParams frontends.AcceptParams, pools Pools) error { + defer func() { + _ = client.Close() + }() + + if acceptParams.CancelKey != [8]byte{} { + p := pools.LookupKey(acceptParams.CancelKey) + if p == nil { + return nil + } + return p.Cancel(acceptParams.CancelKey) + } + + p := pools.Lookup(acceptParams.User, acceptParams.Database) + + var credentials auth.Credentials + if p != nil { + credentials = p.GetCredentials() + } + + authParams, err := frontends.Authenticate(client, frontends.AuthenticateOptions{ + Credentials: credentials, + }) + if err != nil { + return err + } + + if p == nil { + return nil + } + + pools.RegisterKey(authParams.BackendKey, acceptParams.User, acceptParams.Database) + defer pools.UnregisterKey(authParams.BackendKey) + + return p.Serve(client, acceptParams, authParams) +} + +func Serve(acceptor Acceptor, pools Pools) error { for { - conn, params, err := acceptor.Accept() + conn, acceptParams, err := acceptor.Accept() if err != nil { continue } go func() { - _ = gat.Serve(conn, params) + _ = serve(conn, acceptParams, pools) }() } } -func ListenAndServe(network, address string, options frontends.AcceptOptions, gat *Gat) error { +func ListenAndServe(network, address string, options frontends.AcceptOptions, pools Pools) error { listener, err := Listen(network, address, options) if err != nil { return err } - return Serve(listener, gat) + return Serve(listener, pools) } diff --git a/lib/gat/gat.go b/lib/gat/gat.go deleted file mode 100644 index 8cadbb01..00000000 --- a/lib/gat/gat.go +++ /dev/null @@ -1,51 +0,0 @@ -package gat - -import ( - "pggat2/lib/auth" - "pggat2/lib/bouncer/frontends/v0" - "pggat2/lib/gat/pool" - "pggat2/lib/zap" -) - -type Gat struct { - TestPool *pool.Pool -} - -func (T *Gat) Serve(client zap.Conn, acceptParams frontends.AcceptParams) error { - defer func() { - _ = client.Close() - }() - - if acceptParams.CancelKey != [8]byte{} { - // TODO(garet) execute cancel - return nil - } - - p, err := T.GetPool(acceptParams.User, acceptParams.Database) - if err != nil { - return err - } - - var credentials auth.Credentials - if p != nil { - credentials = p.GetCredentials() - } - - authParams, err := frontends.Authenticate(client, frontends.AuthenticateOptions{ - Credentials: credentials, - }) - if err != nil { - return err - } - - if p == nil { - return nil - } - - return p.Serve(client, acceptParams, authParams) -} - -func (T *Gat) GetPool(user, database string) (*pool.Pool, error) { - return T.TestPool, nil - return nil, nil // TODO(garet) -} diff --git a/lib/gat/modes/pgbouncer/config.go b/lib/gat/modes/pgbouncer/config.go index aad103ff..dfd7eca9 100644 --- a/lib/gat/modes/pgbouncer/config.go +++ b/lib/gat/modes/pgbouncer/config.go @@ -10,6 +10,7 @@ import ( "tuxpa.in/a/zlog/log" + "pggat2/lib/bouncer" "pggat2/lib/bouncer/backends/v0" "pggat2/lib/bouncer/frontends/v0" "pggat2/lib/gat/pool" @@ -44,17 +45,6 @@ const ( AuthTypePam AuthType = "pam" ) -type SSLMode string - -const ( - SSLModeDisable SSLMode = "disable" - SSLModeAllow SSLMode = "allow" - SSLModePrefer SSLMode = "prefer" - SSLModeRequire SSLMode = "require" - SSLModeVerifyCa SSLMode = "verify-ca" - SSLModeVerifyFull SSLMode = "verify-full" -) - type TLSProtocol string const ( @@ -130,7 +120,7 @@ type PgBouncer struct { DnsNxdomainTtl float64 `ini:"dns_nxdomain_ttl"` DnsZoneCheckPeriod float64 `ini:"dns_zone_check_period"` ResolvConf string `ini:"resolv.conf"` - ClientTLSSSLMode SSLMode `ini:"client_tls_sslmode"` + ClientTLSSSLMode bouncer.SSLMode `ini:"client_tls_sslmode"` ClientTLSKeyFile string `ini:"client_tls_key_file"` ClientTLSCertFile string `ini:"client_tls_cert_file"` ClientTLSCaFile string `ini:"client_tls_ca_file"` @@ -138,7 +128,7 @@ type PgBouncer struct { ClientTLSCiphers []TLSCipher `ini:"client_tls_ciphers"` ClientTLSECDHCurve TLSECDHCurve `ini:"client_tls_ecdhcurve"` ClientTLSDHEParams TLSDHEParams `ini:"client_tls_dheparams"` - ServerTLSSSLMode SSLMode `ini:"server_tls_sslmode"` + ServerTLSSSLMode bouncer.SSLMode `ini:"server_tls_sslmode"` ServerTLSCaFile string `ini:"server_tls_ca_file"` ServerTLSKeyFile string `ini:"server_tls_key_file"` ServerTLSCertFile string `ini:"server_tls_cert_file"` @@ -229,7 +219,7 @@ var Default = Config{ AutodbIdleTimeout: 3600.0, DnsMaxTtl: 15.0, DnsNxdomainTtl: 15.0, - ClientTLSSSLMode: SSLModeDisable, + ClientTLSSSLMode: bouncer.SSLModeDisable, ClientTLSProtocols: []TLSProtocol{ TLSProtocolSecure, }, @@ -237,7 +227,7 @@ var Default = Config{ "fast", }, ClientTLSECDHCurve: "auto", - ServerTLSSSLMode: SSLModePrefer, + ServerTLSSSLMode: bouncer.SSLModePrefer, ServerTLSProtocols: []TLSProtocol{ TLSProtocolSecure, }, @@ -282,7 +272,7 @@ func (T *Config) ListenAndServe() error { AllowedStartupOptions: allowedStartupParameters, } - g := new(gat.Gat) + pools := new(gat.PoolsMap) var authFile map[string]string if T.PgBouncer.AuthFile != "" { @@ -302,10 +292,6 @@ func (T *Config) ListenAndServe() error { Username: name, Password: authFile[name], // TODO(garet) md5 and sasl } - /* TODO(garet) - u := gat.NewUser(creds) - g.AddUser(u) - */ for dbname, db := range T.Databases { // filter out dbs specific to users @@ -329,7 +315,9 @@ func (T *Config) ListenAndServe() error { } poolOptions := pool.Options{ + Credentials: creds, TrackedParameters: trackedParameters, + ServerResetQuery: T.PgBouncer.ServerResetQuery, ServerIdleTimeout: time.Duration(T.PgBouncer.ServerIdleTimeout * float64(time.Second)), } @@ -338,12 +326,16 @@ func (T *Config) ListenAndServe() error { case PoolModeSession: p = session.NewPool(poolOptions) case PoolModeTransaction: + if T.PgBouncer.ServerResetQueryAlways == 0 { + poolOptions.ServerResetQuery = "" + } + panic("transaction mode not implemented yet") // TODO(garet) default: return errors.New("unsupported pool mode") } - // TODO(garet) add to gat + pools.Add(name, dbname, p) if db.Host == "" { // connect over unix socket @@ -402,7 +394,7 @@ func (T *Config) ListenAndServe() error { log.Printf("listening on %s", listen) - return gat.ListenAndServe("tcp", listen, acceptOptions, g) + return gat.ListenAndServe("tcp", listen, acceptOptions, pools) }) } @@ -418,7 +410,7 @@ func (T *Config) ListenAndServe() error { log.Printf("listening on unix:%s", dir) - return gat.ListenAndServe("unix", dir, acceptOptions, g) + return gat.ListenAndServe("unix", dir, acceptOptions, pools) }) return bank.Wait() diff --git a/lib/gat/modes/zalando/config.go b/lib/gat/modes/zalando/config.go index 806d43ec..1a2a2739 100644 --- a/lib/gat/modes/zalando/config.go +++ b/lib/gat/modes/zalando/config.go @@ -46,26 +46,22 @@ func Load() (Config, error) { } func (T *Config) ListenAndServe() error { - g := new(gat.Gat) + pools := new(gat.PoolsMap) creds := credentials.Cleartext{ Username: T.PGUser, Password: T.PGPassword, } - /* TODO(garet) - user := gat.NewUser(creds) - g.AddUser(user) - */ - var p *pool.Pool if T.PoolerMode == "transaction" { - // p = transaction.NewPool(pool.Options{}) + panic("transaction mode not implemented yet") + // TODO(garet) p = transaction.NewPool(pool.Options{}) } else { p = session.NewPool(pool.Options{}) } - // TODO(garet) add to gat + pools.Add(T.PGUser, "test", p) p.AddRecipe("zalando", pool.Recipe{ Dialer: pool.NetDialer{ @@ -87,7 +83,7 @@ func (T *Config) ListenAndServe() error { log.Printf("listening on %s", listen) - return gat.ListenAndServe("tcp", listen, frontends.AcceptOptions{}, g) + return gat.ListenAndServe("tcp", listen, frontends.AcceptOptions{}, pools) }) return bank.Wait() diff --git a/lib/gat/pool/pool.go b/lib/gat/pool/pool.go index 0654b39f..a7a2313b 100644 --- a/lib/gat/pool/pool.go +++ b/lib/gat/pool/pool.go @@ -39,17 +39,24 @@ type poolRecipe struct { type Pool struct { options Options - maxServers int - recipes map[string]*poolRecipe - servers map[uuid.UUID]poolServer - clients map[uuid.UUID]zap.Conn - mu sync.Mutex + recipes map[string]*poolRecipe + servers map[uuid.UUID]poolServer + clients map[uuid.UUID]zap.Conn + mu sync.Mutex } func NewPool(options Options) *Pool { - return &Pool{ + p := &Pool{ options: options, } + + if options.ServerIdleTimeout != 0 { + go func() { + // TODO(garet) check pool for idle servers + }() + } + + return p } func (T *Pool) GetCredentials() auth.Credentials { @@ -62,6 +69,7 @@ func (T *Pool) _scaleUpRecipe(name string) { server, params, err := r.recipe.Dialer.Dial() if err != nil { log.Printf("failed to dial server: %v", err) + return } serverID := uuid.New() @@ -103,7 +111,6 @@ func (T *Pool) AddRecipe(name string, recipe Recipe) { if T.recipes == nil { T.recipes = make(map[string]*poolRecipe) } - T.maxServers += recipe.MaxConnections T.recipes[name] = &poolRecipe{ recipe: recipe, count: 0, @@ -118,9 +125,6 @@ func (T *Pool) RemoveRecipe(name string) { T.mu.Lock() defer T.mu.Unlock() - if r, ok := T.recipes[name]; ok { - T.maxServers -= r.count - } delete(T.recipes, name) // close all servers with this recipe @@ -138,11 +142,13 @@ func (T *Pool) scaleUp() { defer T.mu.Unlock() for name, r := range T.recipes { - if r.count < r.recipe.MaxConnections { + if r.recipe.MaxConnections == 0 || r.count < r.recipe.MaxConnections { T._scaleUpRecipe(name) return } } + + log.Println("warning: tried to scale up pool but no space was available") } func (T *Pool) syncInitialParameters( @@ -185,21 +191,16 @@ func (T *Pool) syncInitialParameters( continue } - if slices.Contains(T.options.TrackedParameters, key) { - serverErr = backends.ResetParameter(new(backends.Context), server, key) - if serverErr != nil { - return - } - } else { - // send to client - p := packets.ParameterStatus{ - Key: key.String(), - Value: value, - } - clientErr = client.WritePacket(p.IntoPacket()) - if clientErr != nil { - return - } + // Don't need to run reset on server because it will reset it to the initial value + + // send to client + p := packets.ParameterStatus{ + Key: key.String(), + Value: value, + } + clientErr = client.WritePacket(p.IntoPacket()) + if clientErr != nil { + return } } @@ -270,7 +271,7 @@ func (T *Pool) Serve( server.eqpServer.SetClient(eqpClient) } } - if clientErr != nil && serverErr != nil { + if clientErr == nil && serverErr == nil { clientErr, serverErr = bouncers.Bounce(client, server.conn, packet) } if serverErr != nil { @@ -352,3 +353,8 @@ func (T *Pool) removeServer(serverID uuid.UUID) { T._removeServer(serverID) } + +func (T *Pool) Cancel(key [8]byte) error { + // TODO(garet) implement cancel + return nil +} diff --git a/lib/gat/pool/recipe.go b/lib/gat/pool/recipe.go index 66260a8f..e12c8a24 100644 --- a/lib/gat/pool/recipe.go +++ b/lib/gat/pool/recipe.go @@ -3,5 +3,7 @@ package pool type Recipe struct { Dialer Dialer MinConnections int + // MaxConnections is the max number of active server connections for this recipe. + // 0 = unlimited MaxConnections int } diff --git a/lib/gat/pools.go b/lib/gat/pools.go new file mode 100644 index 00000000..6cb9e834 --- /dev/null +++ b/lib/gat/pools.go @@ -0,0 +1,75 @@ +package gat + +import ( + "pggat2/lib/gat/pool" + "pggat2/lib/util/maps" +) + +type Pools interface { + Lookup(user, database string) *pool.Pool + + // Key based lookup functions (for cancellation) + + RegisterKey(key [8]byte, user, database string) + UnregisterKey(key [8]byte) + + LookupKey(key [8]byte) *pool.Pool +} + +type mapKey struct { + User string + Database string +} + +type PoolsMap struct { + pools maps.RWLocked[mapKey, *pool.Pool] + keys maps.RWLocked[[8]byte, mapKey] +} + +func (T *PoolsMap) Add(user, database string, pool *pool.Pool) { + T.pools.Store(mapKey{ + User: user, + Database: database, + }, pool) +} + +func (T *PoolsMap) Remove(user, database string) { + T.pools.Delete(mapKey{ + User: user, + Database: database, + }) +} + +func (T *PoolsMap) Lookup(user, database string) *pool.Pool { + p, _ := T.pools.Load(mapKey{ + User: user, + Database: database, + }) + return p +} + +// key based lookup funcs + +func (T *PoolsMap) RegisterKey(key [8]byte, user, database string) { + T.keys.Store(key, mapKey{ + User: user, + Database: database, + }) +} + +func (T *PoolsMap) UnregisterKey(key [8]byte) { + T.keys.Delete(key) +} + +func (T *PoolsMap) LookupKey(key [8]byte) *pool.Pool { + m, ok := T.keys.Load(key) + if !ok { + return nil + } + p, ok := T.pools.Load(m) + if !ok { + T.keys.Delete(key) + return nil + } + return p +} -- GitLab