diff --git a/lib/gat/modes/pgbouncer/config.go b/lib/gat/modes/pgbouncer/config.go index d33887251411cbcc84e2179cc6328552a775f870..3d74dd0820c672534f5ab79e73e315c6da5ccf52 100644 --- a/lib/gat/modes/pgbouncer/config.go +++ b/lib/gat/modes/pgbouncer/config.go @@ -1,26 +1,16 @@ package pgbouncer import ( - "errors" "net" - "os" "strconv" "strings" - "time" "tuxpa.in/a/zlog/log" "pggat2/lib/bouncer" - "pggat2/lib/bouncer/backends/v0" "pggat2/lib/bouncer/frontends/v0" - "pggat2/lib/gat/pool" - "pggat2/lib/gat/pool/pools/session" - "pggat2/lib/gat/pool/pools/transaction" - - "pggat2/lib/auth/credentials" "pggat2/lib/gat" "pggat2/lib/util/encoding/ini" - "pggat2/lib/util/encoding/userlist" "pggat2/lib/util/flip" "pggat2/lib/util/strutil" ) @@ -167,6 +157,7 @@ type Database struct { ConnectQuery string `ini:"connect_query"` PoolMode PoolMode `ini:"pool_mode"` MaxDBConnections int `ini:"max_db_connections"` + AuthDBName string `ini:"auth_dbname"` StartupParameters map[strutil.CIString]string `ini:"*"` } @@ -273,112 +264,9 @@ func (T *Config) ListenAndServe() error { AllowedStartupOptions: allowedStartupParameters, } - pools := new(gat.PoolsMap) - - var authFile map[string]string - if T.PgBouncer.AuthFile != "" { - file, err := os.ReadFile(T.PgBouncer.AuthFile) - if err != nil { - return err - } - - authFile, err = userlist.Unmarshal(file) - if err != nil { - return err - } - } - - for name, user := range T.Users { - creds := credentials.Cleartext{ - Username: name, - Password: authFile[name], // TODO(garet) md5 and sasl - } - - for dbname, db := range T.Databases { - // filter out dbs specific to users - if db.User != "" && db.User != name { - continue - } - - // override dbname - if db.DBName != "" { - dbname = db.DBName - } - - // override poolmode - var poolMode PoolMode - if db.PoolMode != "" { - poolMode = db.PoolMode - } else if user.PoolMode != "" { - poolMode = user.PoolMode - } else { - poolMode = T.PgBouncer.PoolMode - } - - poolOptions := pool.Options{ - Credentials: creds, - TrackedParameters: trackedParameters, - ServerResetQuery: T.PgBouncer.ServerResetQuery, - ServerIdleTimeout: time.Duration(T.PgBouncer.ServerIdleTimeout * float64(time.Second)), - } - - var p *pool.Pool - switch poolMode { - case PoolModeSession: - p = session.NewPool(poolOptions) - case PoolModeTransaction: - if T.PgBouncer.ServerResetQueryAlways == 0 { - poolOptions.ServerResetQuery = "" - } - p = transaction.NewPool(poolOptions) - default: - return errors.New("unsupported pool mode") - } - - pools.Add(name, dbname, p) - - if db.Host == "" { - // connect over unix socket - // TODO(garet) - } else { - var address string - if db.Port == 0 { - address = net.JoinHostPort(db.Host, "5432") - } else { - address = net.JoinHostPort(db.Host, strconv.Itoa(db.Port)) - } - - creds := creds - if db.Password != "" { - // lookup password - creds.Password = db.Password - } - - // connect over tcp - dialer := pool.NetDialer{ - Network: "tcp", - Address: address, - AcceptOptions: backends.AcceptOptions{ - Credentials: creds, - Database: dbname, - StartupParameters: db.StartupParameters, - }, - } - recipe := pool.Recipe{ - Dialer: dialer, - MinConnections: db.MinPoolSize, - MaxConnections: db.MaxDBConnections, - } - if recipe.MinConnections == 0 { - recipe.MinConnections = T.PgBouncer.MinPoolSize - } - if recipe.MaxConnections == 0 { - recipe.MaxConnections = T.PgBouncer.MaxDBConnections - } - - p.AddRecipe("pgbouncer", recipe) - } - } + pools, err := NewPools(T) + if err != nil { + return err } var bank flip.Bank diff --git a/lib/gat/modes/pgbouncer/pools.go b/lib/gat/modes/pgbouncer/pools.go new file mode 100644 index 0000000000000000000000000000000000000000..faa0ebdf3ee399df5f8eb86b43ab84a28b4ab72d --- /dev/null +++ b/lib/gat/modes/pgbouncer/pools.go @@ -0,0 +1,243 @@ +package pgbouncer + +import ( + "net" + "os" + "strconv" + "time" + + "tuxpa.in/a/zlog/log" + + "pggat2/lib/auth/credentials" + "pggat2/lib/bouncer/backends/v0" + "pggat2/lib/gat" + "pggat2/lib/gat/pool" + "pggat2/lib/gat/pool/pools/session" + "pggat2/lib/gat/pool/pools/transaction" + "pggat2/lib/psql" + "pggat2/lib/util/encoding/userlist" + "pggat2/lib/util/strutil" + "pggat2/lib/zap" +) + +type authQueryResult struct { + Username string `ini:"usename"` + Password *string `ini:"passwd"` +} + +type poolKey struct { + User string + Database string +} + +type Pools struct { + Config *Config + + AuthFile map[string]string + + pools map[poolKey]*pool.Pool + keys map[[8]byte]*pool.Pool +} + +func NewPools(config *Config) (*Pools, error) { + pools := &Pools{ + Config: config, + } + + if config.PgBouncer.AuthFile != "" { + file, err := os.ReadFile(config.PgBouncer.AuthFile) + if err != nil { + return nil, err + } + + pools.AuthFile, err = userlist.Unmarshal(file) + if err != nil { + return nil, err + } + } + + return pools, nil +} + +func (T *Pools) Lookup(user, database string) *pool.Pool { + key := poolKey{ + User: user, + Database: database, + } + p := T.pools[key] + if p != nil { + return p + } + + // create pool + db, ok := T.Config.Databases[database] + if !ok { + // try wildcard + db, ok = T.Config.Databases["*"] + if !ok { + return nil + } + } + + password, ok := T.AuthFile[user] + if !ok { + // try auth query + authUser := db.AuthUser + if authUser == "" { + authUser = T.Config.PgBouncer.AuthUser + } + + if authUser == "" { + // user not present in auth file + return nil + } + + if T.Config.PgBouncer.AuthQuery == "" { + // no auth query + return nil + } + + // auth user should be in auth file + if authUser == user { + return nil + } + + authPool := T.Lookup(authUser, database) + if authPool == nil { + return nil + } + + var result authQueryResult + err := authPool.Do(func(server zap.Conn) error { + return psql.Query(server, &result, T.Config.PgBouncer.AuthQuery, user) + }) + if err != nil { + log.Println("auth query failed:", err) + return nil + } + + if result.Password != nil { + password = *result.Password + } + } + + creds := credentials.Cleartext{ + Username: user, + Password: password, // TODO(garet) md5 and sasl + } + + backendDatabase := database + if db.DBName != "" { + backendDatabase = db.DBName + } + + configUser := T.Config.Users[user] + + poolMode := db.PoolMode + if poolMode == "" { + poolMode = configUser.PoolMode + } + if poolMode == "" { + poolMode = T.Config.PgBouncer.PoolMode + } + + trackedParameters := append([]strutil.CIString{ + strutil.MakeCIString("client_encoding"), + strutil.MakeCIString("datestyle"), + strutil.MakeCIString("timezone"), + strutil.MakeCIString("standard_conforming_strings"), + strutil.MakeCIString("application_name"), + }, T.Config.PgBouncer.TrackExtraParameters...) + + poolOptions := pool.Options{ + Credentials: creds, + TrackedParameters: trackedParameters, + ServerResetQuery: T.Config.PgBouncer.ServerResetQuery, + ServerIdleTimeout: time.Duration(T.Config.PgBouncer.ServerIdleTimeout * float64(time.Second)), + } + + switch poolMode { + case PoolModeSession: + p = session.NewPool(poolOptions) + case PoolModeTransaction: + if T.Config.PgBouncer.ServerResetQueryAlways == 0 { + poolOptions.ServerResetQuery = "" + } + p = transaction.NewPool(poolOptions) + default: + return nil + } + + if T.pools == nil { + T.pools = make(map[poolKey]*pool.Pool) + } + T.pools[poolKey{ + User: user, + Database: database, + }] = p + + if db.Host == "" { + // connect over unix socket + // TODO(garet) + } else { + var address string + if db.Port == 0 { + address = net.JoinHostPort(db.Host, "5432") + } else { + address = net.JoinHostPort(db.Host, strconv.Itoa(db.Port)) + } + + creds := creds + if db.Password != "" { + // lookup password + creds.Password = db.Password + } + + // connect over tcp + dialer := pool.NetDialer{ + Network: "tcp", + Address: address, + AcceptOptions: backends.AcceptOptions{ + Credentials: creds, + Database: backendDatabase, + StartupParameters: db.StartupParameters, + }, + } + recipe := pool.Recipe{ + Dialer: dialer, + MinConnections: db.MinPoolSize, + MaxConnections: db.MaxDBConnections, + } + if recipe.MinConnections == 0 { + recipe.MinConnections = T.Config.PgBouncer.MinPoolSize + } + if recipe.MaxConnections == 0 { + recipe.MaxConnections = T.Config.PgBouncer.MaxDBConnections + } + + p.AddRecipe("pgbouncer", recipe) + } + + return p +} + +func (T *Pools) RegisterKey(key [8]byte, user, database string) { + p := T.Lookup(user, database) + if p == nil { + return + } + if T.keys == nil { + T.keys = make(map[[8]byte]*pool.Pool) + } + T.keys[key] = p +} + +func (T *Pools) UnregisterKey(key [8]byte) { + delete(T.keys, key) +} + +func (T *Pools) LookupKey(key [8]byte) *pool.Pool { + return T.keys[key] +} + +var _ gat.Pools = (*Pools)(nil) diff --git a/lib/gat/modes/zalando/config.go b/lib/gat/modes/zalando/config.go index bca45eb93b84ff9a534c9c7e5353f04c3fb64036..55670ff07041c83509415936531bc3dc7a14e17c 100644 --- a/lib/gat/modes/zalando/config.go +++ b/lib/gat/modes/zalando/config.go @@ -3,22 +3,12 @@ package zalando import ( "errors" "fmt" - "net" - "strconv" - - "tuxpa.in/a/zlog/log" "gfx.cafe/util/go/gun" - "pggat2/lib/bouncer/backends/v0" - "pggat2/lib/bouncer/frontends/v0" - "pggat2/lib/gat/pool/pools/session" - "pggat2/lib/gat/pool/pools/transaction" - - "pggat2/lib/auth/credentials" - "pggat2/lib/gat" - "pggat2/lib/gat/pool" - "pggat2/lib/util/flip" + "pggat2/lib/bouncer" + "pggat2/lib/gat/modes/pgbouncer" + "pggat2/lib/util/strutil" ) type Config struct { @@ -47,44 +37,48 @@ func Load() (Config, error) { } func (T *Config) ListenAndServe() error { - pools := new(gat.PoolsMap) - - creds := credentials.Cleartext{ - Username: T.PGUser, - Password: T.PGPassword, + pgb := pgbouncer.Default + if pgb.Databases == nil { + pgb.Databases = make(map[string]pgbouncer.Database) } - - var p *pool.Pool - if T.PoolerMode == "transaction" { - p = transaction.NewPool(pool.Options{}) - } else { - p = session.NewPool(pool.Options{}) + pgb.Databases["*"] = pgbouncer.Database{ + Host: T.PGHost, + Port: T.PGPort, + AuthUser: T.PGUser, + } + pgb.PgBouncer.PoolMode = pgbouncer.PoolMode(T.PoolerMode) + pgb.PgBouncer.ListenPort = T.PoolerPort + pgb.PgBouncer.ListenAddr = "*" + pgb.PgBouncer.AuthType = "md5" + pgb.PgBouncer.AuthFile = "/etc/pgbouncer/auth_file.txt" + pgb.PgBouncer.AdminUsers = []string{T.PGUser} + pgb.PgBouncer.AuthQuery = fmt.Sprintf("SELECT * FROM %s.user_lookup($1)", T.PGSchema) + pgb.PgBouncer.LogFile = "/var/olg/pgbouncer/pgbouncer.log" + pgb.PgBouncer.PidFile = "/var/run/pgbouncer/pgbouncer.pid" + + pgb.PgBouncer.ServerTLSSSLMode = bouncer.SSLModeRequire + pgb.PgBouncer.ServerTLSCaFile = "/etc/ssl/certs/pgbouncer.crt" + pgb.PgBouncer.ServerTLSProtocols = []pgbouncer.TLSProtocol{ + pgbouncer.TLSProtocolSecure, + } + pgb.PgBouncer.ClientTLSSSLMode = bouncer.SSLModeRequire + pgb.PgBouncer.ClientTLSKeyFile = "/etc/ssl/certs/pgbouncer.key" + pgb.PgBouncer.ClientTLSCertFile = "/etc/ssl/certs/pgbouncer.crt" + + pgb.PgBouncer.LogConnections = 0 + pgb.PgBouncer.LogDisconnections = 0 + + pgb.PgBouncer.DefaultPoolSize = T.PoolerDefaultSize + pgb.PgBouncer.ReservePoolSize = T.PoolerReserveSize + pgb.PgBouncer.MaxClientConn = T.PoolerMaxClientConn + pgb.PgBouncer.MaxDBConnections = T.PoolerMaxDBConn + pgb.PgBouncer.IdleTransactionTimeout = 600 + pgb.PgBouncer.ServerLoginRetry = 5 + + pgb.PgBouncer.IgnoreStartupParameters = []strutil.CIString{ + strutil.MakeCIString("extra_float_digits"), + strutil.MakeCIString("options"), } - pools.Add(T.PGUser, "test", p) - - p.AddRecipe("zalando", pool.Recipe{ - Dialer: pool.NetDialer{ - Network: "tcp", - Address: net.JoinHostPort(T.PGHost, strconv.Itoa(T.PGPort)), - AcceptOptions: backends.AcceptOptions{ - Credentials: creds, - Database: "test", - }, - }, - MinConnections: T.PoolerMinSize, - MaxConnections: T.PoolerMaxDBConn, - }) - - var bank flip.Bank - - bank.Queue(func() error { - listen := fmt.Sprintf(":%d", T.PoolerPort) - - log.Printf("listening on %s", listen) - - return gat.ListenAndServe("tcp", listen, frontends.AcceptOptions{}, pools) - }) - - return bank.Wait() + return pgb.ListenAndServe() } diff --git a/lib/gat/pool/pool.go b/lib/gat/pool/pool.go index f0c2ec3d9df8343b9de26d6a154c6db3af618d3d..8348f67e836ec51bc138cb4930c0622b0c40b5ed 100644 --- a/lib/gat/pool/pool.go +++ b/lib/gat/pool/pool.go @@ -267,6 +267,16 @@ func (T *Pool) syncInitialParameters( return } +func (T *Pool) Do(fn func(zap.Conn) error) error { + id := T.addClient(nil, [8]byte{}) + defer T.removeClient(id) + + serverID, server := T.acquireServer(id) + defer T.releaseServer(serverID) + + return fn(server.conn) +} + func (T *Pool) Serve( client zap.Conn, accept frontends.AcceptParams, @@ -300,6 +310,7 @@ func (T *Pool) Serve( ) clientID := T.addClient(client, auth.BackendKey) + defer T.removeClient(clientID) var serverID uuid.UUID var server *poolServer @@ -370,6 +381,14 @@ func (T *Pool) addClient(client zap.Conn, key [8]byte) uuid.UUID { return clientID } +func (T *Pool) removeClient(clientID uuid.UUID) { + T.mu.Lock() + defer T.mu.Unlock() + + delete(T.clients, clientID) + T.options.Pooler.RemoveClient(clientID) +} + func (T *Pool) acquireServer(clientID uuid.UUID) (serverID uuid.UUID, server *poolServer) { serverID = T.options.Pooler.AcquireConcurrent(clientID) if serverID == uuid.Nil {