diff --git a/cmd/cgat/main.go b/cmd/cgat/main.go index 3b82b28eba24ad1d7e88f449f3367589ce6c8a6c..42981009bfaac1d419169bcf754fff943984fb0b 100644 --- a/cmd/cgat/main.go +++ b/cmd/cgat/main.go @@ -12,7 +12,8 @@ import ( "pggat2/lib/bouncer/backends/v0" "pggat2/lib/bouncer/frontends/v0" "pggat2/lib/gat" - "pggat2/lib/gat/pools/session" + "pggat2/lib/gat/pool" + "pggat2/lib/gat/pool/pools/session" ) func main() { @@ -23,14 +24,14 @@ func main() { log.Printf("Starting pggat...") g := new(gat.Gat) - g.TestPool = session.NewPool(gat.PoolOptions{ + g.TestPool = session.NewPool(pool.Options{ Credentials: credentials.Cleartext{ Username: "postgres", Password: "password", }, }) - g.TestPool.AddRecipe("test", gat.Recipe{ - Dialer: gat.NetDialer{ + g.TestPool.AddRecipe("test", pool.Recipe{ + Dialer: pool.NetDialer{ Network: "tcp", Address: "localhost:5432", diff --git a/lib/gat/gat.go b/lib/gat/gat.go index 4d254dd93eafee4d3996d6f5a42d013354cecdf1..8cadbb01ac1fb4459db5e0f89c40f49d738fbccd 100644 --- a/lib/gat/gat.go +++ b/lib/gat/gat.go @@ -3,11 +3,12 @@ package gat import ( "pggat2/lib/auth" "pggat2/lib/bouncer/frontends/v0" + "pggat2/lib/gat/pool" "pggat2/lib/zap" ) type Gat struct { - TestPool *Pool + TestPool *pool.Pool } func (T *Gat) Serve(client zap.Conn, acceptParams frontends.AcceptParams) error { @@ -20,14 +21,14 @@ func (T *Gat) Serve(client zap.Conn, acceptParams frontends.AcceptParams) error return nil } - pool, err := T.GetPool(acceptParams.User, acceptParams.Database) + p, err := T.GetPool(acceptParams.User, acceptParams.Database) if err != nil { return err } var credentials auth.Credentials - if pool != nil { - credentials = pool.GetCredentials() + if p != nil { + credentials = p.GetCredentials() } authParams, err := frontends.Authenticate(client, frontends.AuthenticateOptions{ @@ -37,14 +38,14 @@ func (T *Gat) Serve(client zap.Conn, acceptParams frontends.AcceptParams) error return err } - if pool == nil { + if p == nil { return nil } - return pool.Serve(client, acceptParams, authParams) + return p.Serve(client, acceptParams, authParams) } -func (T *Gat) GetPool(user, database string) (*Pool, error) { +func (T *Gat) GetPool(user, database string) (*pool.Pool, error) { return T.TestPool, nil return nil, nil // TODO(garet) } diff --git a/lib/gat/modes/pgbouncer/config.go b/lib/gat/modes/pgbouncer/config.go new file mode 100644 index 0000000000000000000000000000000000000000..aad103ffd65e61ded8b91c224b4a85fc7107dada --- /dev/null +++ b/lib/gat/modes/pgbouncer/config.go @@ -0,0 +1,425 @@ +package pgbouncer + +import ( + "errors" + "net" + "os" + "strconv" + "strings" + "time" + + "tuxpa.in/a/zlog/log" + + "pggat2/lib/bouncer/backends/v0" + "pggat2/lib/bouncer/frontends/v0" + "pggat2/lib/gat/pool" + "pggat2/lib/gat/pool/pools/session" + + "pggat2/lib/auth/credentials" + "pggat2/lib/gat" + "pggat2/lib/util/encoding/ini" + "pggat2/lib/util/encoding/userlist" + "pggat2/lib/util/flip" + "pggat2/lib/util/strutil" +) + +type PoolMode string + +const ( + PoolModeSession PoolMode = "session" + PoolModeTransaction PoolMode = "transaction" + PoolModeStatement PoolMode = "statement" +) + +type AuthType string + +const ( + AuthTypeCert AuthType = "cert" + AuthTypeMd5 AuthType = "md5" + AuthTypeScramSha256 AuthType = "scram-sha-256" + AuthTypePlain AuthType = "plain" + AuthTypeTrust AuthType = "trust" + AuthTypeAny AuthType = "any" + AuthTypeHba AuthType = "hba" + AuthTypePam AuthType = "pam" +) + +type SSLMode string + +const ( + SSLModeDisable SSLMode = "disable" + SSLModeAllow SSLMode = "allow" + SSLModePrefer SSLMode = "prefer" + SSLModeRequire SSLMode = "require" + SSLModeVerifyCa SSLMode = "verify-ca" + SSLModeVerifyFull SSLMode = "verify-full" +) + +type TLSProtocol string + +const ( + TLSProtocolV1_0 TLSProtocol = "tlsv1.0" + TLSProtocolV1_1 TLSProtocol = "tlsv1.1" + TLSProtocolV1_2 TLSProtocol = "tlsv1.2" + TLSProtocolV1_3 TLSProtocol = "tlsv1.3" + TLSProtocolAll TLSProtocol = "all" + TLSProtocolSecure TLSProtocol = "secure" + TLSProtocolLegacy TLSProtocol = "legacy" +) + +type TLSCipher string + +type TLSECDHCurve string + +type TLSDHEParams string + +type PgBouncer struct { + LogFile string `ini:"logfile"` + PidFile string `ini:"pidfile"` + ListenAddr string `ini:"listen_addr"` + ListenPort int `ini:"listen_port"` + UnixSocketDir string `ini:"unix_socket_dir"` + UnixSocketMode string `ini:"unix_socket_mode"` + UnixSocketGroup string `ini:"unix_socket_group"` + User string `ini:"user"` + PoolMode PoolMode `ini:"pool_mode"` + MaxClientConn int `ini:"max_client_conn"` + DefaultPoolSize int `ini:"default_pool_size"` + MinPoolSize int `ini:"min_pool_size"` + ReservePoolSize int `ini:"reserve_pool_size"` + ReservePoolTimeout float64 `ini:"reserve_pool_timeout"` + MaxDBConnections int `ini:"max_db_connections"` + MaxUserConnections int `ini:"max_user_connections"` + ServerRoundRobin int `ini:"server_round_robin"` + TrackExtraParameters []strutil.CIString `ini:"track_extra_parameters"` + IgnoreStartupParameters []strutil.CIString `ini:"ignore_startup_parameters"` + PeerID int `ini:"peer_id"` + DisablePQExec int `ini:"disable_pqexec"` + ApplicationNameAddHost int `ini:"application_name_add_host"` + ConfFile string `ini:"conffile"` + ServiceName string `ini:"service_name"` + StatsPeriod int `ini:"stats_period"` + AuthType string `ini:"auth_type"` + AuthHbaFile string `ini:"auth_hba_file"` + AuthFile string `ini:"auth_file"` + AuthUser string `ini:"auth_user"` + AuthQuery string `ini:"auth_query"` + AuthDbname string `ini:"auth_dbname"` + Syslog string `ini:"syslog"` + SyslogIdent string `ini:"syslog_ident"` + SyslogFacility string `ini:"syslog_facility"` + LogConnections int `ini:"log_connections"` + LogDisconnections int `ini:"log_disconnections"` + LogPoolerErrors int `ini:"log_pooler_errors"` + LogStats int `ini:"log_stats"` + Verbose int `ini:"verbose"` + AdminUsers []string `ini:"auth_users"` + StatsUsers []string `ini:"stats_users"` + ServerResetQuery string `ini:"server_reset_query"` + ServerResetQueryAlways int `ini:"server_reset_query_always"` + ServerCheckDelay float64 `ini:"server_check_delay"` + ServerCheckQuery string `ini:"server_check_query"` + ServerFastClose int `ini:"server_fast_close"` + ServerLifetime float64 `ini:"server_lifetime"` + ServerIdleTimeout float64 `ini:"server_idle_timeout"` + ServerConnectTimeout float64 `ini:"server_connect_timeout"` + ServerLoginRetry float64 `ini:"server_login_retry"` + ClientLoginTimeout float64 `ini:"client_login_timeout"` + AutodbIdleTimeout float64 `ini:"autodb_idle_timeout"` + DnsMaxTtl float64 `ini:"dns_max_ttl"` + DnsNxdomainTtl float64 `ini:"dns_nxdomain_ttl"` + DnsZoneCheckPeriod float64 `ini:"dns_zone_check_period"` + ResolvConf string `ini:"resolv.conf"` + ClientTLSSSLMode SSLMode `ini:"client_tls_sslmode"` + ClientTLSKeyFile string `ini:"client_tls_key_file"` + ClientTLSCertFile string `ini:"client_tls_cert_file"` + ClientTLSCaFile string `ini:"client_tls_ca_file"` + ClientTLSProtocols []TLSProtocol `ini:"client_tls_protocols"` + ClientTLSCiphers []TLSCipher `ini:"client_tls_ciphers"` + ClientTLSECDHCurve TLSECDHCurve `ini:"client_tls_ecdhcurve"` + ClientTLSDHEParams TLSDHEParams `ini:"client_tls_dheparams"` + ServerTLSSSLMode SSLMode `ini:"server_tls_sslmode"` + ServerTLSCaFile string `ini:"server_tls_ca_file"` + ServerTLSKeyFile string `ini:"server_tls_key_file"` + ServerTLSCertFile string `ini:"server_tls_cert_file"` + ServerTLSProtocols []TLSProtocol `ini:"server_tls_protocols"` + ServerTLSCiphers []TLSCipher `ini:"server_tls_ciphers"` + QueryTimeout float64 `ini:"query_timeout"` + QueryWaitTimeout float64 `ini:"query_wait_timeout"` + CancelWaitTimeout float64 `ini:"cancel_wait_timeout"` + ClientIdleTimeout float64 `ini:"client_idle_timeout"` + IdleTransactionTimeout float64 `ini:"idle_transaction_timeout"` + SuspendTimeout float64 `ini:"suspend_timeout"` + PktBuf int `ini:"pkt_buf"` + MaxPacketSize int `ini:"max_packet_size"` + ListenBacklog int `ini:"listen_backlog"` + SbufLoopcnt int `ini:"sbuf_loopcnt"` + SoReuseport int `ini:"so_reuseport"` + TcpDeferAccept int `ini:"tcp_defer_accept"` + TcpSocketBuffer int `ini:"tcp_socket_buffer"` + TcpKeepalive int `ini:"tcp_keepalive"` + TcpKeepidle int `ini:"tcp_keepidle"` + TcpKeepintvl int `ini:"tcp_keepintvl"` + TcpUserTimeout int `ini:"tcp_user_timeout"` +} + +type Database struct { + DBName string `ini:"dbname"` + Host string `ini:"host"` + Port int `ini:"port"` + User string `ini:"user"` + Password string `ini:"password"` + AuthUser string `ini:"auth_user"` + PoolSize int `ini:"pool_size"` + MinPoolSize int `ini:"min_pool_size"` + ReservePool int `ini:"reserve_pool"` + ConnectQuery string `ini:"connect_query"` + PoolMode PoolMode `ini:"pool_mode"` + MaxDBConnections int `ini:"max_db_connections"` + StartupParameters map[strutil.CIString]string `ini:"*"` +} + +type User struct { + PoolMode PoolMode `ini:"pool_mode"` + MaxUserConnections int `ini:"max_user_connections"` +} + +type Peer struct { + Host string `ini:"host"` + Port int `ini:"port"` + PoolSize int `ini:"pool_size"` +} + +type Config struct { + PgBouncer PgBouncer `ini:"pgbouncer"` + Databases map[string]Database `ini:"databases"` + Users map[string]User `ini:"users"` + Peers map[string]Peer `ini:"peers"` +} + +var Default = Config{ + PgBouncer: PgBouncer{ + ListenPort: 6432, + UnixSocketDir: "/tmp", + UnixSocketMode: "0777", + PoolMode: PoolModeSession, + MaxClientConn: 100, + DefaultPoolSize: 20, + ReservePoolTimeout: 5.0, + TrackExtraParameters: []strutil.CIString{ + strutil.MakeCIString("IntervalStyle"), + }, + ServiceName: "pgbouncer", + StatsPeriod: 60, + AuthQuery: "SELECT usename, passwd FROM pg_shadow WHERE usename=$1", + SyslogIdent: "pgbouncer", + SyslogFacility: "daemon", + LogConnections: 1, + LogDisconnections: 1, + LogPoolerErrors: 1, + LogStats: 1, + ServerResetQuery: "DISCARD ALL", + ServerCheckDelay: 30.0, + ServerCheckQuery: "select 1", + ServerLifetime: 3600.0, + ServerIdleTimeout: 600.0, + ServerConnectTimeout: 15.0, + ServerLoginRetry: 15.0, + ClientLoginTimeout: 60.0, + AutodbIdleTimeout: 3600.0, + DnsMaxTtl: 15.0, + DnsNxdomainTtl: 15.0, + ClientTLSSSLMode: SSLModeDisable, + ClientTLSProtocols: []TLSProtocol{ + TLSProtocolSecure, + }, + ClientTLSCiphers: []TLSCipher{ + "fast", + }, + ClientTLSECDHCurve: "auto", + ServerTLSSSLMode: SSLModePrefer, + ServerTLSProtocols: []TLSProtocol{ + TLSProtocolSecure, + }, + ServerTLSCiphers: []TLSCipher{ + "fast", + }, + QueryWaitTimeout: 120.0, + CancelWaitTimeout: 10.0, + SuspendTimeout: 10.0, + PktBuf: 4096, + MaxPacketSize: 2147483647, + ListenBacklog: 128, + SbufLoopcnt: 5, + TcpDeferAccept: 1, + TcpKeepalive: 1, + }, +} + +func Load(config string) (Config, error) { + conf, err := ini.ReadFile(config) + if err != nil { + return Config{}, err + } + + var c = Default + err = ini.Unmarshal(conf, &c) + return c, err +} + +func (T *Config) ListenAndServe() error { + trackedParameters := append([]strutil.CIString{ + strutil.MakeCIString("client_encoding"), + strutil.MakeCIString("datestyle"), + strutil.MakeCIString("timezone"), + strutil.MakeCIString("standard_conforming_strings"), + strutil.MakeCIString("application_name"), + }, T.PgBouncer.TrackExtraParameters...) + + allowedStartupParameters := append(trackedParameters, T.PgBouncer.IgnoreStartupParameters...) + + acceptOptions := frontends.AcceptOptions{ + AllowedStartupOptions: allowedStartupParameters, + } + + g := new(gat.Gat) + + var authFile map[string]string + if T.PgBouncer.AuthFile != "" { + file, err := os.ReadFile(T.PgBouncer.AuthFile) + if err != nil { + return err + } + + authFile, err = userlist.Unmarshal(file) + if err != nil { + return err + } + } + + for name, user := range T.Users { + creds := credentials.Cleartext{ + Username: name, + Password: authFile[name], // TODO(garet) md5 and sasl + } + /* TODO(garet) + u := gat.NewUser(creds) + g.AddUser(u) + */ + + for dbname, db := range T.Databases { + // filter out dbs specific to users + if db.User != "" && db.User != name { + continue + } + + // override dbname + if db.DBName != "" { + dbname = db.DBName + } + + // override poolmode + var poolMode PoolMode + if db.PoolMode != "" { + poolMode = db.PoolMode + } else if user.PoolMode != "" { + poolMode = user.PoolMode + } else { + poolMode = T.PgBouncer.PoolMode + } + + poolOptions := pool.Options{ + TrackedParameters: trackedParameters, + ServerIdleTimeout: time.Duration(T.PgBouncer.ServerIdleTimeout * float64(time.Second)), + } + + var p *pool.Pool + switch poolMode { + case PoolModeSession: + p = session.NewPool(poolOptions) + case PoolModeTransaction: + // TODO(garet) + default: + return errors.New("unsupported pool mode") + } + + // TODO(garet) add to gat + + if db.Host == "" { + // connect over unix socket + // TODO(garet) + } else { + var address string + if db.Port == 0 { + address = net.JoinHostPort(db.Host, "5432") + } else { + address = net.JoinHostPort(db.Host, strconv.Itoa(db.Port)) + } + + creds := creds + if db.Password != "" { + // lookup password + creds.Password = db.Password + } + + // connect over tcp + dialer := pool.NetDialer{ + Network: "tcp", + Address: address, + AcceptOptions: backends.AcceptOptions{ + Credentials: creds, + Database: dbname, + StartupParameters: db.StartupParameters, + }, + } + recipe := pool.Recipe{ + Dialer: dialer, + MinConnections: db.MinPoolSize, + MaxConnections: db.MaxDBConnections, + } + if recipe.MinConnections == 0 { + recipe.MinConnections = T.PgBouncer.MinPoolSize + } + if recipe.MaxConnections == 0 { + recipe.MaxConnections = T.PgBouncer.MaxDBConnections + } + + p.AddRecipe("pgbouncer", recipe) + } + } + } + + var bank flip.Bank + + if T.PgBouncer.ListenAddr != "" { + bank.Queue(func() error { + listenAddr := T.PgBouncer.ListenAddr + if listenAddr == "*" { + listenAddr = "" + } + + listen := net.JoinHostPort(listenAddr, strconv.Itoa(T.PgBouncer.ListenPort)) + + log.Printf("listening on %s", listen) + + return gat.ListenAndServe("tcp", listen, acceptOptions, g) + }) + } + + // listen on unix socket + bank.Queue(func() error { + dir := T.PgBouncer.UnixSocketDir + port := T.PgBouncer.ListenPort + + if !strings.HasSuffix(dir, "/") { + dir = dir + "/" + } + dir = dir + ".s.PGSQL." + strconv.Itoa(port) + + log.Printf("listening on unix:%s", dir) + + return gat.ListenAndServe("unix", dir, acceptOptions, g) + }) + + return bank.Wait() +} diff --git a/lib/gat/modes/zalando/config.go b/lib/gat/modes/zalando/config.go new file mode 100644 index 0000000000000000000000000000000000000000..806d43ec61414b489c36369aee364e2600cbecda --- /dev/null +++ b/lib/gat/modes/zalando/config.go @@ -0,0 +1,94 @@ +package zalando + +import ( + "errors" + "fmt" + "net" + "strconv" + + "tuxpa.in/a/zlog/log" + + "gfx.cafe/util/go/gun" + + "pggat2/lib/bouncer/backends/v0" + "pggat2/lib/bouncer/frontends/v0" + "pggat2/lib/gat/pool/pools/session" + + "pggat2/lib/auth/credentials" + "pggat2/lib/gat" + "pggat2/lib/gat/pool" + "pggat2/lib/util/flip" +) + +type Config struct { + PGHost string `env:"PGHOST"` + PGPort int `env:"PGPORT"` + PGUser string `env:"PGUSER"` + PGSchema string `env:"PGSCHEMA"` + PGPassword string `env:"PGPASSWORD"` + PoolerPort int `env:"CONNECTION_POOLER_PORT"` + PoolerMode string `env:"CONNECTION_POOLER_MODE"` + PoolerDefaultSize int `env:"CONNECTION_POOLER_DEFAULT_SIZE"` + PoolerMinSize int `env:"CONNECTION_POOLER_MIN_SIZE"` + PoolerReserveSize int `env:"CONNECTION_POOLER_RESERVE_SIZE"` + PoolerMaxClientConn int `env:"CONNECTION_POOLER_MAX_CLIENT_CONN"` + PoolerMaxDBConn int `env:"CONNECTION_POOLER_MAX_DB_CONN"` +} + +func Load() (Config, error) { + var conf Config + gun.Load(&conf) + if conf.PoolerMode == "" { + return Config{}, errors.New("expected pooler mode") + } + + return conf, nil +} + +func (T *Config) ListenAndServe() error { + g := new(gat.Gat) + + creds := credentials.Cleartext{ + Username: T.PGUser, + Password: T.PGPassword, + } + + /* TODO(garet) + user := gat.NewUser(creds) + g.AddUser(user) + */ + + var p *pool.Pool + if T.PoolerMode == "transaction" { + // p = transaction.NewPool(pool.Options{}) + } else { + p = session.NewPool(pool.Options{}) + } + + // TODO(garet) add to gat + + p.AddRecipe("zalando", pool.Recipe{ + Dialer: pool.NetDialer{ + Network: "tcp", + Address: net.JoinHostPort(T.PGHost, strconv.Itoa(T.PGPort)), + AcceptOptions: backends.AcceptOptions{ + Credentials: creds, + Database: "test", + }, + }, + MinConnections: T.PoolerMinSize, + MaxConnections: T.PoolerMaxDBConn, + }) + + var bank flip.Bank + + bank.Queue(func() error { + listen := fmt.Sprintf(":%d", T.PoolerPort) + + log.Printf("listening on %s", listen) + + return gat.ListenAndServe("tcp", listen, frontends.AcceptOptions{}, g) + }) + + return bank.Wait() +} diff --git a/lib/gat/pool.go b/lib/gat/pool.go deleted file mode 100644 index a1a85058232a08b868b05182db5c292a2518de68..0000000000000000000000000000000000000000 --- a/lib/gat/pool.go +++ /dev/null @@ -1,241 +0,0 @@ -package gat - -import ( - "github.com/google/uuid" - "tuxpa.in/a/zlog/log" - - "pggat2/lib/auth" - "pggat2/lib/bouncer/backends/v0" - "pggat2/lib/bouncer/bouncers/v2" - "pggat2/lib/bouncer/frontends/v0" - "pggat2/lib/middleware/interceptor" - "pggat2/lib/middleware/middlewares/unterminate" - "pggat2/lib/util/maths" - "pggat2/lib/zap" -) - -type poolRecipe struct { - recipe Recipe - servers map[uuid.UUID]struct{} -} - -type Pool struct { - options PoolOptions - - recipes map[string]*poolRecipe - - servers map[uuid.UUID]zap.Conn - clients map[uuid.UUID]zap.Conn -} - -type PoolOptions struct { - Credentials auth.Credentials - Pooler Pooler - ServerResetQuery string -} - -func NewPool(options PoolOptions) *Pool { - return &Pool{ - options: options, - } -} - -func (T *Pool) GetCredentials() auth.Credentials { - return T.options.Credentials -} - -func (T *Pool) scale(name string, amount int) { - recipe := T.recipes[name] - if recipe == nil { - return - } - - target := maths.Clamp(len(recipe.servers)+amount, recipe.recipe.MinConnections, recipe.recipe.MaxConnections) - diff := target - len(recipe.servers) - - for diff > 0 { - diff-- - - // add server - server, params, err := recipe.recipe.Dialer.Dial() - if err != nil { - log.Printf("failed to connect to server: %v", err) - continue - } - - _ = params // TODO(garet) - - serverID := T.addServer(server) - if recipe.servers == nil { - recipe.servers = make(map[uuid.UUID]struct{}) - } - recipe.servers[serverID] = struct{}{} - } - - for diff < 0 { - diff++ - - // remove server - for s := range recipe.servers { - T.removeServer(s) - break - } - } -} - -func (T *Pool) AddRecipe(name string, recipe Recipe) { - if T.recipes == nil { - T.recipes = make(map[string]*poolRecipe) - } - - T.recipes[name] = &poolRecipe{ - recipe: recipe, - } - - T.scale(name, 0) -} - -func (T *Pool) RemoveRecipe(name string) { - if recipe, ok := T.recipes[name]; ok { - recipe.recipe.MaxConnections = 0 - T.scale(name, 0) - delete(T.recipes, name) - } -} - -func (T *Pool) addClient( - client zap.Conn, -) uuid.UUID { - clientID := uuid.New() - T.options.Pooler.AddClient(clientID) - - if T.clients == nil { - T.clients = make(map[uuid.UUID]zap.Conn) - } - T.clients[clientID] = client - return clientID -} - -func (T *Pool) removeClient( - clientID uuid.UUID, -) { - T.options.Pooler.RemoveClient(clientID) - if client, ok := T.clients[clientID]; ok { - _ = client.Close() - delete(T.clients, clientID) - } -} - -func (T *Pool) addServer( - server zap.Conn, -) uuid.UUID { - serverID := uuid.New() - T.options.Pooler.AddServer(serverID) - - if T.servers == nil { - T.servers = make(map[uuid.UUID]zap.Conn) - } - T.servers[serverID] = server - return serverID -} - -func (T *Pool) acquireServer( - clientID uuid.UUID, -) (serverID uuid.UUID, server zap.Conn) { - serverID = T.options.Pooler.AcquireConcurrent(clientID) - if serverID == uuid.Nil { - // TODO(garet) scale up - serverID = T.options.Pooler.AcquireAsync(clientID) - } - - server = T.servers[serverID] - return -} - -func (T *Pool) removeServer( - serverID uuid.UUID, -) { - T.options.Pooler.RemoveServer(serverID) - if server, ok := T.servers[serverID]; ok { - _ = server.Close() - delete(T.servers, serverID) - } -} - -func (T *Pool) tryReleaseServer( - serverID uuid.UUID, -) bool { - if !T.options.Pooler.CanRelease(serverID) { - return false - } - T.releaseServer(serverID) - return true -} - -func (T *Pool) releaseServer( - serverID uuid.UUID, -) { - if T.options.ServerResetQuery != "" { - server := T.servers[serverID] - err := backends.QueryString(new(backends.Context), server, T.options.ServerResetQuery) - if err != nil { - T.removeServer(serverID) - return - } - } - T.options.Pooler.Release(serverID) -} - -func (T *Pool) Serve( - client zap.Conn, - acceptParams frontends.AcceptParams, - authParams frontends.AuthenticateParams, -) error { - client = interceptor.NewInterceptor( - client, - unterminate.Unterminate, - // TODO(garet) add middlewares based on Pool.options - ) - - defer func() { - _ = client.Close() - }() - - clientID := T.addClient(client) - - var serverID uuid.UUID - var server zap.Conn - - defer func() { - if serverID != uuid.Nil { - T.releaseServer(serverID) - } - }() - - for { - packet, err := client.ReadPacket(true) - if err != nil { - return err - } - - if serverID == uuid.Nil { - serverID, server = T.acquireServer(clientID) - } - clientErr, serverErr := bouncers.Bounce(client, server, packet) - if serverErr != nil { - T.removeServer(serverID) - serverID = uuid.Nil - server = nil - return serverErr - } else { - if T.tryReleaseServer(serverID) { - serverID = uuid.Nil - server = nil - } - } - - if clientErr != nil { - return clientErr - } - } -} diff --git a/lib/gat/dialer.go b/lib/gat/pool/dialer.go similarity index 97% rename from lib/gat/dialer.go rename to lib/gat/pool/dialer.go index 6df34123ba57e2bbc9d38ef125e4f1585a8fb48d..c8779dc7465724608e9d0a85375f044c4fb5d6f2 100644 --- a/lib/gat/dialer.go +++ b/lib/gat/pool/dialer.go @@ -1,4 +1,4 @@ -package gat +package pool import ( "net" diff --git a/lib/gat/pool/options.go b/lib/gat/pool/options.go new file mode 100644 index 0000000000000000000000000000000000000000..732cafd16fd2d402ff7d0dc530dd28a489f09f61 --- /dev/null +++ b/lib/gat/pool/options.go @@ -0,0 +1,39 @@ +package pool + +import ( + "time" + + "pggat2/lib/auth" + "pggat2/lib/util/strutil" +) + +type ParameterStatusSync int + +const ( + // ParameterStatusSyncNone does not attempt to sync parameter status. + ParameterStatusSyncNone ParameterStatusSync = iota + // ParameterStatusSyncInitial assumes both client and server have their initial status before syncing. + // Use in session pooling for lower latency + ParameterStatusSyncInitial + // ParameterStatusSyncDynamic will track parameter status and ensure they are synced correctly. + // Use in transaction pooling + ParameterStatusSyncDynamic +) + +type Options struct { + Credentials auth.Credentials + Pooler Pooler + ServerResetQuery string + // ServerIdleTimeout defines how long a server may be idle before it is disconnected + ServerIdleTimeout time.Duration + + // ParameterStatusSync is the parameter syncing mode + ParameterStatusSync ParameterStatusSync + // TrackedParameters are parameters which should be synced by updating the server, not the client. + TrackedParameters []strutil.CIString + + // ExtendedQuerySync controls whether prepared statements and portals should be tracked and synced before use. + // Use false for lower latency + // Use true for transaction pooling + ExtendedQuerySync bool +} diff --git a/lib/gat/pool/pool.go b/lib/gat/pool/pool.go new file mode 100644 index 0000000000000000000000000000000000000000..0654b39f990e35c4121ac973d727161ba0bc5c2c --- /dev/null +++ b/lib/gat/pool/pool.go @@ -0,0 +1,354 @@ +package pool + +import ( + "sync" + + "github.com/google/uuid" + "tuxpa.in/a/zlog/log" + + "pggat2/lib/auth" + "pggat2/lib/bouncer/backends/v0" + "pggat2/lib/bouncer/bouncers/v2" + "pggat2/lib/bouncer/frontends/v0" + "pggat2/lib/middleware" + "pggat2/lib/middleware/interceptor" + "pggat2/lib/middleware/middlewares/eqp" + "pggat2/lib/middleware/middlewares/ps" + "pggat2/lib/middleware/middlewares/unterminate" + "pggat2/lib/util/slices" + "pggat2/lib/util/strutil" + "pggat2/lib/zap" + packets "pggat2/lib/zap/packets/v3.0" +) + +type poolServer struct { + conn zap.Conn + accept backends.AcceptParams + recipe string + + // middlewares + psServer *ps.Server + eqpServer *eqp.Server +} + +type poolRecipe struct { + recipe Recipe + count int +} + +type Pool struct { + options Options + + maxServers int + recipes map[string]*poolRecipe + servers map[uuid.UUID]poolServer + clients map[uuid.UUID]zap.Conn + mu sync.Mutex +} + +func NewPool(options Options) *Pool { + return &Pool{ + options: options, + } +} + +func (T *Pool) GetCredentials() auth.Credentials { + return T.options.Credentials +} + +func (T *Pool) _scaleUpRecipe(name string) { + r := T.recipes[name] + + server, params, err := r.recipe.Dialer.Dial() + if err != nil { + log.Printf("failed to dial server: %v", err) + } + + serverID := uuid.New() + if T.servers == nil { + T.servers = make(map[uuid.UUID]poolServer) + } + + var middlewares []middleware.Middleware + + var psServer *ps.Server + if T.options.ParameterStatusSync == ParameterStatusSyncDynamic { + // add ps middleware + psServer = ps.NewServer(params.InitialParameters) + middlewares = append(middlewares, psServer) + } + + var eqpServer *eqp.Server + if T.options.ExtendedQuerySync { + // add eqp middleware + eqpServer = eqp.NewServer() + middlewares = append(middlewares, eqpServer) + } + + T.servers[serverID] = poolServer{ + conn: server, + accept: params, + recipe: name, + + psServer: psServer, + eqpServer: eqpServer, + } + T.options.Pooler.AddServer(serverID) +} + +func (T *Pool) AddRecipe(name string, recipe Recipe) { + T.mu.Lock() + defer T.mu.Unlock() + + if T.recipes == nil { + T.recipes = make(map[string]*poolRecipe) + } + T.maxServers += recipe.MaxConnections + T.recipes[name] = &poolRecipe{ + recipe: recipe, + count: 0, + } + + for i := 0; i < recipe.MinConnections; i++ { + T._scaleUpRecipe(name) + } +} + +func (T *Pool) RemoveRecipe(name string) { + T.mu.Lock() + defer T.mu.Unlock() + + if r, ok := T.recipes[name]; ok { + T.maxServers -= r.count + } + delete(T.recipes, name) + + // close all servers with this recipe + for id, server := range T.servers { + if server.recipe == name { + _ = server.conn.Close() + T.options.Pooler.RemoveServer(id) + delete(T.servers, id) + } + } +} + +func (T *Pool) scaleUp() { + T.mu.Lock() + defer T.mu.Unlock() + + for name, r := range T.recipes { + if r.count < r.recipe.MaxConnections { + T._scaleUpRecipe(name) + return + } + } +} + +func (T *Pool) syncInitialParameters( + client zap.Conn, + clientParams map[strutil.CIString]string, + server zap.Conn, + serverParams map[strutil.CIString]string, +) (clientErr, serverErr error) { + for key, value := range clientParams { + setServer := slices.Contains(T.options.TrackedParameters, key) + + // skip already set params + if serverParams[key] == value { + setServer = false + } else if !setServer { + value = serverParams[key] + } + + p := packets.ParameterStatus{ + Key: key.String(), + Value: serverParams[key], + } + clientErr = client.WritePacket(p.IntoPacket()) + if clientErr != nil { + return + } + + if !setServer { + continue + } + + serverErr = backends.SetParameter(new(backends.Context), server, key, value) + if serverErr != nil { + return + } + } + + for key, value := range serverParams { + if _, ok := clientParams[key]; ok { + continue + } + + if slices.Contains(T.options.TrackedParameters, key) { + serverErr = backends.ResetParameter(new(backends.Context), server, key) + if serverErr != nil { + return + } + } else { + // send to client + p := packets.ParameterStatus{ + Key: key.String(), + Value: value, + } + clientErr = client.WritePacket(p.IntoPacket()) + if clientErr != nil { + return + } + } + } + + return +} + +func (T *Pool) Serve( + client zap.Conn, + accept frontends.AcceptParams, + auth frontends.AuthenticateParams, +) error { + defer func() { + _ = client.Close() + }() + + middlewares := []middleware.Middleware{ + unterminate.Unterminate, + } + + var psClient *ps.Client + if T.options.ParameterStatusSync == ParameterStatusSyncDynamic { + // add ps middleware + psClient = ps.NewClient(accept.InitialParameters) + middlewares = append(middlewares, psClient) + } + + var eqpClient *eqp.Client + if T.options.ExtendedQuerySync { + // add eqp middleware + eqpClient = eqp.NewClient() + middlewares = append(middlewares, eqpClient) + } + + client = interceptor.NewInterceptor( + client, + middlewares..., + ) + + clientID := T.addClient(client) + + var serverID uuid.UUID + var server poolServer + + defer func() { + if serverID != uuid.Nil { + T.releaseServer(serverID) + } + }() + + for { + packet, err := client.ReadPacket(true) + if err != nil { + return err + } + + var clientErr, serverErr error + if serverID == uuid.Nil { + serverID, server = T.acquireServer(clientID) + + switch T.options.ParameterStatusSync { + case ParameterStatusSyncDynamic: + clientErr, serverErr = ps.Sync(T.options.TrackedParameters, client, psClient, server.conn, server.psServer) + case ParameterStatusSyncInitial: + clientErr, serverErr = T.syncInitialParameters(client, accept.InitialParameters, server.conn, server.accept.InitialParameters) + } + + if T.options.ExtendedQuerySync { + server.eqpServer.SetClient(eqpClient) + } + } + if clientErr != nil && serverErr != nil { + clientErr, serverErr = bouncers.Bounce(client, server.conn, packet) + } + if serverErr != nil { + T.removeServer(serverID) + serverID = uuid.Nil + server = poolServer{} + return serverErr + } else { + if T.options.Pooler.ReleaseAfterTransaction() { + T.releaseServer(serverID) + serverID = uuid.Nil + server = poolServer{} + } + } + + if clientErr != nil { + return clientErr + } + } +} + +func (T *Pool) addClient(client zap.Conn) uuid.UUID { + T.mu.Lock() + defer T.mu.Unlock() + + clientID := uuid.New() + + if T.clients == nil { + T.clients = make(map[uuid.UUID]zap.Conn) + } + T.clients[clientID] = client + T.options.Pooler.AddClient(clientID) + return clientID +} + +func (T *Pool) acquireServer(clientID uuid.UUID) (serverID uuid.UUID, server poolServer) { + serverID = T.options.Pooler.AcquireConcurrent(clientID) + if serverID == uuid.Nil { + go T.scaleUp() + serverID = T.options.Pooler.AcquireAsync(clientID) + } + + T.mu.Lock() + defer T.mu.Unlock() + server = T.servers[serverID] + return +} + +func (T *Pool) releaseServer(serverID uuid.UUID) { + T.mu.Lock() + defer T.mu.Unlock() + + if T.options.ServerResetQuery != "" { + server := T.servers[serverID].conn + err := backends.QueryString(new(backends.Context), server, T.options.ServerResetQuery) + if err != nil { + T._removeServer(serverID) + return + } + } + T.options.Pooler.Release(serverID) +} + +func (T *Pool) _removeServer(serverID uuid.UUID) { + if server, ok := T.servers[serverID]; ok { + _ = server.conn.Close() + delete(T.servers, serverID) + T.options.Pooler.RemoveServer(serverID) + r := T.recipes[server.recipe] + if r != nil { + r.count-- + } + } +} + +func (T *Pool) removeServer(serverID uuid.UUID) { + T.mu.Lock() + defer T.mu.Unlock() + + T._removeServer(serverID) +} diff --git a/lib/gat/pooler.go b/lib/gat/pool/pooler.go similarity index 70% rename from lib/gat/pooler.go rename to lib/gat/pool/pooler.go index 747096089404a4e7780f03574a502ed714ac6045..27e39a7e9f86b19be63cac445f0ae51a3fcb7b0f 100644 --- a/lib/gat/pooler.go +++ b/lib/gat/pool/pooler.go @@ -1,4 +1,4 @@ -package gat +package pool import "github.com/google/uuid" @@ -16,10 +16,8 @@ type Pooler interface { // AcquireAsync will stall until a peer is available. AcquireAsync(client uuid.UUID) uuid.UUID - // CanRelease will check if a server can be released after a transaction. - // Some poolers (such as session poolers) do not release servers after each transaction. - // Returns true if Release could be called. - CanRelease(server uuid.UUID) bool + // ReleaseAfterTransaction queries whether servers should be immediately released after a transaction is completed. + ReleaseAfterTransaction() bool // Release will force release the server. // This should be called when the paired client has disconnected, or after CanRelease returns true. diff --git a/lib/gat/pool/pools/session/pool.go b/lib/gat/pool/pools/session/pool.go new file mode 100644 index 0000000000000000000000000000000000000000..196a92b69e3e4cd45a4405d0d26174de25871f50 --- /dev/null +++ b/lib/gat/pool/pools/session/pool.go @@ -0,0 +1,12 @@ +package session + +import ( + "pggat2/lib/gat/pool" +) + +func NewPool(options pool.Options) *pool.Pool { + options.Pooler = new(Pooler) + options.ParameterStatusSync = pool.ParameterStatusSyncInitial + options.ExtendedQuerySync = false + return pool.NewPool(options) +} diff --git a/lib/gat/pools/session/pooler.go b/lib/gat/pool/pools/session/pooler.go similarity index 93% rename from lib/gat/pools/session/pooler.go rename to lib/gat/pool/pools/session/pooler.go index d3c9e08870150805b4a0373fd559375b9bfb722f..7d8bb9daed7d424e4582edcb25baa2406140d8e1 100644 --- a/lib/gat/pools/session/pooler.go +++ b/lib/gat/pool/pools/session/pooler.go @@ -5,7 +5,7 @@ import ( "github.com/google/uuid" - "pggat2/lib/gat" + "pggat2/lib/gat/pool" "pggat2/lib/util/slices" ) @@ -79,7 +79,7 @@ func (T *Pooler) AcquireAsync(_ uuid.UUID) uuid.UUID { return server } -func (*Pooler) CanRelease(_ uuid.UUID) bool { +func (*Pooler) ReleaseAfterTransaction() bool { // servers are released when the client is removed return false } @@ -93,4 +93,4 @@ func (T *Pooler) Release(server uuid.UUID) { T.queue = append(T.queue, server) } -var _ gat.Pooler = (*Pooler)(nil) +var _ pool.Pooler = (*Pooler)(nil) diff --git a/lib/gat/recipe.go b/lib/gat/pool/recipe.go similarity index 87% rename from lib/gat/recipe.go rename to lib/gat/pool/recipe.go index 487f4433f171afe948ce58071727eef903c268b2..66260a8f9c6dac25e65f1576b3e90f475fcebc50 100644 --- a/lib/gat/recipe.go +++ b/lib/gat/pool/recipe.go @@ -1,4 +1,4 @@ -package gat +package pool type Recipe struct { Dialer Dialer diff --git a/lib/gat/pools/session/pool.go b/lib/gat/pools/session/pool.go deleted file mode 100644 index 1e33ce8c078c32a8dd447935ba4dd1daacd4494a..0000000000000000000000000000000000000000 --- a/lib/gat/pools/session/pool.go +++ /dev/null @@ -1,8 +0,0 @@ -package session - -import "pggat2/lib/gat" - -func NewPool(options gat.PoolOptions) *gat.Pool { - options.Pooler = new(Pooler) - return gat.NewPool(options) -} diff --git a/lib/middleware/middlewares/unterminate/unterminate.go b/lib/middleware/middlewares/unterminate/unterminate.go index 7885d54298c51f231d626cfceaf8a220f37dfe34..87951ba534879e0ec9b3142503828933ab7f8f3b 100644 --- a/lib/middleware/middlewares/unterminate/unterminate.go +++ b/lib/middleware/middlewares/unterminate/unterminate.go @@ -8,6 +8,8 @@ import ( packets "pggat2/lib/zap/packets/v3.0" ) +// Unterminate catches the Terminate packet and returns io.EOF instead. +// Useful if you don't want to forward to the server and close the connection. var Unterminate = unterm{} type unterm struct {