diff --git a/cmd/cgat/main.go b/cmd/cgat/main.go index db2951e63cf7cb93b0bde8be5957c94d2fa106a7..3b82b28eba24ad1d7e88f449f3367589ce6c8a6c 100644 --- a/cmd/cgat/main.go +++ b/cmd/cgat/main.go @@ -1,14 +1,18 @@ package main import ( + "crypto/tls" "net/http" _ "net/http/pprof" - "os" "tuxpa.in/a/zlog/log" - "pggat2/lib/gat/configs/pgbouncer" - "pggat2/lib/gat/configs/zalando" + "pggat2/lib/auth/credentials" + "pggat2/lib/bouncer" + "pggat2/lib/bouncer/backends/v0" + "pggat2/lib/bouncer/frontends/v0" + "pggat2/lib/gat" + "pggat2/lib/gat/pools/session" ) func main() { @@ -18,9 +22,56 @@ func main() { log.Printf("Starting pggat...") - if len(os.Args) == 2 { - log.Printf("running in pgbouncer compatibility mode") - conf, err := pgbouncer.Load(os.Args[1]) + g := new(gat.Gat) + g.TestPool = session.NewPool(gat.PoolOptions{ + Credentials: credentials.Cleartext{ + Username: "postgres", + Password: "password", + }, + }) + g.TestPool.AddRecipe("test", gat.Recipe{ + Dialer: gat.NetDialer{ + Network: "tcp", + Address: "localhost:5432", + + AcceptOptions: backends.AcceptOptions{ + SSLMode: bouncer.SSLModeAllow, + SSLConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + Credentials: credentials.Cleartext{ + Username: "postgres", + Password: "password", + }, + Database: "postgres", + }, + }, + MinConnections: 1, + MaxConnections: 1, + }) + err := gat.ListenAndServe("tcp", ":6432", frontends.AcceptOptions{}, g) + if err != nil { + panic(err) + } + + /* + if len(os.Args) == 2 { + log.Printf("running in pgbouncer compatibility mode") + conf, err := pgbouncer.Load(os.Args[1]) + if err != nil { + panic(err) + } + + err = conf.ListenAndServe() + if err != nil { + panic(err) + } + return + } + + log.Printf("running in zalando compatibility mode") + + conf, err := zalando.Load() if err != nil { panic(err) } @@ -29,18 +80,5 @@ func main() { if err != nil { panic(err) } - return - } - - log.Printf("running in zalando compatibility mode") - - conf, err := zalando.Load() - if err != nil { - panic(err) - } - - err = conf.ListenAndServe() - if err != nil { - panic(err) - } + */ } diff --git a/lib/bouncer/backends/v0/accept.go b/lib/bouncer/backends/v0/accept.go index 9bf5be32cea8c6aa032a734a29d43e2e5c193438..741b66b889482a3a7c9dc4f4fed4363a0e48ec39 100644 --- a/lib/bouncer/backends/v0/accept.go +++ b/lib/bouncer/backends/v0/accept.go @@ -5,7 +5,6 @@ import ( "errors" "pggat2/lib/auth" - "pggat2/lib/bouncer" "pggat2/lib/util/strutil" "pggat2/lib/zap" packets "pggat2/lib/zap/packets/v3.0" @@ -188,16 +187,16 @@ func startup0(server zap.Conn, creds auth.Credentials) (done bool, err error) { } } -func startup1(conn *bouncer.Conn) (done bool, err error) { +func startup1(conn zap.Conn, params *AcceptParams) (done bool, err error) { var packet zap.Packet - packet, err = conn.RW.ReadPacket(true) + packet, err = conn.ReadPacket(true) if err != nil { return } switch packet.Type() { case packets.TypeBackendKeyData: - packet.ReadBytes(conn.BackendKey[:]) + packet.ReadBytes(params.BackendKey[:]) return false, nil case packets.TypeParameterStatus: var ps packets.ParameterStatus @@ -206,10 +205,10 @@ func startup1(conn *bouncer.Conn) (done bool, err error) { return } ikey := strutil.MakeCIString(ps.Key) - if conn.InitialParameters == nil { - conn.InitialParameters = make(map[strutil.CIString]string) + if params.InitialParameters == nil { + params.InitialParameters = make(map[strutil.CIString]string) } - conn.InitialParameters[ikey] = ps.Value + params.InitialParameters[ikey] = ps.Value return false, nil case packets.TypeReadyForQuery: return true, nil @@ -256,27 +255,23 @@ func enableSSL(server zap.Conn, config *tls.Config) (bool, error) { return true, nil } -func Accept(server zap.Conn, options AcceptOptions) (bouncer.Conn, error) { +func Accept(server zap.Conn, options AcceptOptions) (AcceptParams, error) { username := options.Credentials.GetUsername() if options.Database == "" { options.Database = username } - conn := bouncer.Conn{ - RW: server, - User: username, - Database: options.Database, - } + var params AcceptParams if options.SSLMode.ShouldAttempt() { var err error - conn.SSLEnabled, err = enableSSL(server, options.SSLConfig) + params.SSLEnabled, err = enableSSL(server, options.SSLConfig) if err != nil { - return bouncer.Conn{}, err + return AcceptParams{}, err } - if !conn.SSLEnabled && options.SSLMode.IsRequired() { - return bouncer.Conn{}, errors.New("server rejected SSL encryption") + if !params.SSLEnabled && options.SSLMode.IsRequired() { + return AcceptParams{}, errors.New("server rejected SSL encryption") } } @@ -301,14 +296,14 @@ func Accept(server zap.Conn, options AcceptOptions) (bouncer.Conn, error) { err := server.WritePacket(packet) if err != nil { - return bouncer.Conn{}, err + return AcceptParams{}, err } for { var done bool done, err = startup0(server, options.Credentials) if err != nil { - return bouncer.Conn{}, err + return AcceptParams{}, err } if done { break @@ -317,9 +312,9 @@ func Accept(server zap.Conn, options AcceptOptions) (bouncer.Conn, error) { for { var done bool - done, err = startup1(&conn) + done, err = startup1(server, ¶ms) if err != nil { - return bouncer.Conn{}, err + return AcceptParams{}, err } if done { break @@ -327,5 +322,5 @@ func Accept(server zap.Conn, options AcceptOptions) (bouncer.Conn, error) { } // startup complete, connection is ready for queries - return conn, nil + return params, nil } diff --git a/lib/bouncer/backends/v0/params.go b/lib/bouncer/backends/v0/params.go new file mode 100644 index 0000000000000000000000000000000000000000..9c71ca2d64ecbe2a5e1e7e73cc1a5edc59d9ba47 --- /dev/null +++ b/lib/bouncer/backends/v0/params.go @@ -0,0 +1,9 @@ +package backends + +import "pggat2/lib/util/strutil" + +type AcceptParams struct { + SSLEnabled bool + InitialParameters map[strutil.CIString]string + BackendKey [8]byte +} diff --git a/lib/bouncer/conn.go b/lib/bouncer/conn.go deleted file mode 100644 index 3768b503c2a0e38cad8b03d1ee6caa7123d5a819..0000000000000000000000000000000000000000 --- a/lib/bouncer/conn.go +++ /dev/null @@ -1,16 +0,0 @@ -package bouncer - -import ( - "pggat2/lib/util/strutil" - "pggat2/lib/zap" -) - -type Conn struct { - RW zap.Conn - - SSLEnabled bool - User string - Database string - InitialParameters map[strutil.CIString]string - BackendKey [8]byte -} diff --git a/lib/bouncer/frontends/v0/accept.go b/lib/bouncer/frontends/v0/accept.go index f1f02eb96d0736c660d8b8be71307cbe100faed6..fa08f9b27cf783c7d4765e5e1b032d1a090f06c6 100644 --- a/lib/bouncer/frontends/v0/accept.go +++ b/lib/bouncer/frontends/v0/accept.go @@ -1,13 +1,9 @@ package frontends import ( - "crypto/rand" - "errors" "fmt" "strings" - "pggat2/lib/auth" - "pggat2/lib/bouncer" "pggat2/lib/perror" "pggat2/lib/util/slices" "pggat2/lib/util/strutil" @@ -16,10 +12,11 @@ import ( ) func startup0( - client *bouncer.Conn, + conn zap.Conn, + params *AcceptParams, options AcceptOptions, ) (done bool, err perror.Error) { - packet, err2 := client.RW.ReadPacket(false) + packet, err2 := conn.ReadPacket(false) if err2 != nil { err = perror.Wrap(err2) return @@ -35,35 +32,40 @@ func startup0( switch minorVersion { case 5678: // Cancel - p.ReadBytes(client.BackendKey[:]) - - options.Pooler.Cancel(client.BackendKey) + p.ReadBytes(params.CancelKey[:]) + + if params.CancelKey == [8]byte{} { + // very rare that this would ever happen + // and it's ok if we don't honor cancel requests + err = perror.New( + perror.FATAL, + perror.ProtocolViolation, + "cancel key cannot be null", + ) + return + } - err = perror.New( - perror.FATAL, - perror.ProtocolViolation, - "Expected client to disconnect", - ) + done = true return case 5679: // ssl is not enabled if options.SSLConfig == nil { - err = perror.Wrap(client.RW.WriteByte('N')) + err = perror.Wrap(conn.WriteByte('N')) return } // do ssl - if err = perror.Wrap(client.RW.WriteByte('S')); err != nil { + if err = perror.Wrap(conn.WriteByte('S')); err != nil { return } - if err = perror.Wrap(client.RW.EnableSSLServer(options.SSLConfig)); err != nil { + if err = perror.Wrap(conn.EnableSSLServer(options.SSLConfig)); err != nil { return } - client.SSLEnabled = true + params.SSLEnabled = true return case 5680: // GSSAPI is not supported yet - err = perror.Wrap(client.RW.WriteByte('N')) + err = perror.Wrap(conn.WriteByte('N')) return default: err = perror.New( @@ -98,9 +100,9 @@ func startup0( switch key { case "user": - client.User = value + params.User = value case "database": - client.Database = value + params.Database = value case "options": fields := strings.Fields(value) for i := 0; i < len(fields); i++ { @@ -130,10 +132,10 @@ func startup0( return } - if client.InitialParameters == nil { - client.InitialParameters = make(map[strutil.CIString]string) + if params.InitialParameters == nil { + params.InitialParameters = make(map[strutil.CIString]string) } - client.InitialParameters[ikey] = value + params.InitialParameters[ikey] = value default: err = perror.New( perror.FATAL, @@ -166,10 +168,10 @@ func startup0( return } - if client.InitialParameters == nil { - client.InitialParameters = make(map[strutil.CIString]string) + if params.InitialParameters == nil { + params.InitialParameters = make(map[strutil.CIString]string) } - client.InitialParameters[ikey] = value + params.InitialParameters[ikey] = value } } } @@ -181,13 +183,13 @@ func startup0( UnrecognizedOptions: unsupportedOptions, } - err = perror.Wrap(client.RW.WritePacket(uopts.IntoPacket())) + err = perror.Wrap(conn.WritePacket(uopts.IntoPacket())) if err != nil { return } } - if client.User == "" { + if params.User == "" { err = perror.New( perror.FATAL, perror.InvalidAuthorizationSpecification, @@ -195,125 +197,21 @@ func startup0( ) return } - if client.Database == "" { - client.Database = client.User + if params.Database == "" { + params.Database = params.User } done = true return } -func authenticationSASLInitial(client zap.Conn, creds auth.SASL) (tool auth.SASLVerifier, resp []byte, done bool, err perror.Error) { - // check which authentication method the client wants - packet, err2 := client.ReadPacket(true) - if err2 != nil { - err = perror.Wrap(err2) - return - } - var initialResponse packets.SASLInitialResponse - if !initialResponse.ReadFromPacket(packet) { - err = packets.ErrBadFormat - return - } - - tool, err2 = creds.VerifySASL(initialResponse.Mechanism) - if err2 != nil { - err = perror.Wrap(err2) - return - } - - resp, err2 = tool.Write(initialResponse.InitialResponse) - if err2 != nil { - if errors.Is(err2, auth.ErrSASLComplete) { - done = true - return - } - err = perror.Wrap(err2) - return - } - return -} - -func authenticationSASLContinue(client zap.Conn, tool auth.SASLVerifier) (resp []byte, done bool, err perror.Error) { - packet, err2 := client.ReadPacket(true) - if err2 != nil { - err = perror.Wrap(err2) - return - } - var authResp packets.AuthenticationResponse - if !authResp.ReadFromPacket(packet) { - err = packets.ErrBadFormat - return - } - - resp, err2 = tool.Write(authResp) - if err2 != nil { - if errors.Is(err2, auth.ErrSASLComplete) { - done = true - return - } - err = perror.Wrap(err2) - return - } - return -} - -func authenticationSASL(client zap.Conn, creds auth.SASL) perror.Error { - saslInitial := packets.AuthenticationSASL{ - Mechanisms: creds.SupportedSASLMechanisms(), - } - err := perror.Wrap(client.WritePacket(saslInitial.IntoPacket())) - if err != nil { - return err - } - - tool, resp, done, err := authenticationSASLInitial(client, creds) - if err != nil { - return err - } - - for { - if done { - final := packets.AuthenticationSASLFinal(resp) - err = perror.Wrap(client.WritePacket(final.IntoPacket())) - if err != nil { - return err - } - break - } else { - cont := packets.AuthenticationSASLContinue(resp) - err = perror.Wrap(client.WritePacket(cont.IntoPacket())) - if err != nil { - return err - } - } - - resp, done, err = authenticationSASLContinue(client, tool) - if err != nil { - return err - } - } - - return nil -} - -func updateParameter(client zap.Conn, name, value string) perror.Error { - ps := packets.ParameterStatus{ - Key: name, - Value: value, - } - return perror.Wrap(client.WritePacket(ps.IntoPacket())) -} - func accept( client zap.Conn, options AcceptOptions, -) (conn bouncer.Conn, err perror.Error) { - conn.RW = client - +) (params AcceptParams, err perror.Error) { for { var done bool - done, err = startup0(&conn, options) + done, err = startup0(client, ¶ms, options) if err != nil { return } @@ -322,70 +220,16 @@ func accept( } } - if options.SSLRequired && !conn.SSLEnabled { - err = perror.New( - perror.FATAL, - perror.InvalidPassword, - "SSL is required", - ) + if params.CancelKey != [8]byte{} { return } - creds := options.Pooler.GetUserCredentials(conn.User, conn.Database) - if creds == nil { + if options.SSLRequired && !params.SSLEnabled { err = perror.New( perror.FATAL, perror.InvalidPassword, - "User or database not found", - ) - return - } - if credsSASL, ok := creds.(auth.SASL); ok { - err = authenticationSASL(client, credsSASL) - } else { - err = perror.New( - perror.FATAL, - perror.InternalError, - "Auth method not supported", + "SSL is required", ) - } - if err != nil { - return - } - - // send auth Ok - authOk := packets.AuthenticationOk{} - if err = perror.Wrap(client.WritePacket(authOk.IntoPacket())); err != nil { - return - } - - // send backend key data - _, err2 := rand.Read(conn.BackendKey[:]) - if err2 != nil { - err = perror.Wrap(err2) - return - } - - keyData := packets.BackendKeyData{ - CancellationKey: conn.BackendKey, - } - if err = perror.Wrap(client.WritePacket(keyData.IntoPacket())); err != nil { - return - } - - if err = updateParameter(client, "client_encoding", "UTF8"); err != nil { - return - } - if err = updateParameter(client, "server_encoding", "UTF8"); err != nil { - return - } - if err = updateParameter(client, "server_version", "14.5"); err != nil { - return - } - - // send ready for query - rfq := packets.ReadyForQuery('I') - if err = perror.Wrap(client.WritePacket(rfq.IntoPacket())); err != nil { return } @@ -399,11 +243,11 @@ func fail(client zap.Conn, err perror.Error) { _ = client.WritePacket(resp.IntoPacket()) } -func Accept(client zap.Conn, options AcceptOptions) (bouncer.Conn, perror.Error) { - conn, err := accept(client, options) +func Accept(client zap.Conn, options AcceptOptions) (AcceptParams, perror.Error) { + params, err := accept(client, options) if err != nil { fail(client, err) - return bouncer.Conn{}, err + return AcceptParams{}, err } - return conn, nil + return params, nil } diff --git a/lib/bouncer/frontends/v0/authenticate.go b/lib/bouncer/frontends/v0/authenticate.go new file mode 100644 index 0000000000000000000000000000000000000000..580a2eefecff94f96dbf8951780c31df96d0f583 --- /dev/null +++ b/lib/bouncer/frontends/v0/authenticate.go @@ -0,0 +1,183 @@ +package frontends + +import ( + "crypto/rand" + "errors" + + "pggat2/lib/auth" + "pggat2/lib/perror" + "pggat2/lib/zap" + packets "pggat2/lib/zap/packets/v3.0" +) + +func authenticationSASLInitial(client zap.Conn, creds auth.SASL) (tool auth.SASLVerifier, resp []byte, done bool, err perror.Error) { + // check which authentication method the client wants + packet, err2 := client.ReadPacket(true) + if err2 != nil { + err = perror.Wrap(err2) + return + } + var initialResponse packets.SASLInitialResponse + if !initialResponse.ReadFromPacket(packet) { + err = packets.ErrBadFormat + return + } + + tool, err2 = creds.VerifySASL(initialResponse.Mechanism) + if err2 != nil { + err = perror.Wrap(err2) + return + } + + resp, err2 = tool.Write(initialResponse.InitialResponse) + if err2 != nil { + if errors.Is(err2, auth.ErrSASLComplete) { + done = true + return + } + err = perror.Wrap(err2) + return + } + return +} + +func authenticationSASLContinue(client zap.Conn, tool auth.SASLVerifier) (resp []byte, done bool, err perror.Error) { + packet, err2 := client.ReadPacket(true) + if err2 != nil { + err = perror.Wrap(err2) + return + } + var authResp packets.AuthenticationResponse + if !authResp.ReadFromPacket(packet) { + err = packets.ErrBadFormat + return + } + + resp, err2 = tool.Write(authResp) + if err2 != nil { + if errors.Is(err2, auth.ErrSASLComplete) { + done = true + return + } + err = perror.Wrap(err2) + return + } + return +} + +func authenticationSASL(client zap.Conn, creds auth.SASL) perror.Error { + saslInitial := packets.AuthenticationSASL{ + Mechanisms: creds.SupportedSASLMechanisms(), + } + err := perror.Wrap(client.WritePacket(saslInitial.IntoPacket())) + if err != nil { + return err + } + + tool, resp, done, err := authenticationSASLInitial(client, creds) + if err != nil { + return err + } + + for { + if done { + final := packets.AuthenticationSASLFinal(resp) + err = perror.Wrap(client.WritePacket(final.IntoPacket())) + if err != nil { + return err + } + break + } else { + cont := packets.AuthenticationSASLContinue(resp) + err = perror.Wrap(client.WritePacket(cont.IntoPacket())) + if err != nil { + return err + } + } + + resp, done, err = authenticationSASLContinue(client, tool) + if err != nil { + return err + } + } + + return nil +} + +func updateParameter(client zap.Conn, name, value string) perror.Error { + ps := packets.ParameterStatus{ + Key: name, + Value: value, + } + return perror.Wrap(client.WritePacket(ps.IntoPacket())) +} + +func authenticate(client zap.Conn, options AuthenticateOptions) (params AuthenticateParams, err perror.Error) { + if options.Credentials == nil { + err = perror.New( + perror.FATAL, + perror.InvalidPassword, + "User or database not found", + ) + return + } + if credsSASL, ok := options.Credentials.(auth.SASL); ok { + err = authenticationSASL(client, credsSASL) + } else { + err = perror.New( + perror.FATAL, + perror.InternalError, + "Auth method not supported", + ) + } + if err != nil { + return + } + + // send auth Ok + authOk := packets.AuthenticationOk{} + if err = perror.Wrap(client.WritePacket(authOk.IntoPacket())); err != nil { + return + } + + // send backend key data + _, err2 := rand.Read(params.BackendKey[:]) + if err2 != nil { + err = perror.Wrap(err2) + return + } + + keyData := packets.BackendKeyData{ + CancellationKey: params.BackendKey, + } + if err = perror.Wrap(client.WritePacket(keyData.IntoPacket())); err != nil { + return + } + + if err = updateParameter(client, "client_encoding", "UTF8"); err != nil { + return + } + if err = updateParameter(client, "server_encoding", "UTF8"); err != nil { + return + } + if err = updateParameter(client, "server_version", "14.5"); err != nil { + return + } + + // send ready for query + rfq := packets.ReadyForQuery('I') + if err = perror.Wrap(client.WritePacket(rfq.IntoPacket())); err != nil { + return + } + + return +} + +func Authenticate(client zap.Conn, options AuthenticateOptions) (AuthenticateParams, perror.Error) { + params, err := authenticate(client, options) + if err != nil { + fail(client, err) + return AuthenticateParams{}, err + } + return params, nil +} diff --git a/lib/bouncer/frontends/v0/options.go b/lib/bouncer/frontends/v0/options.go index 26a77ace46cbcc1ac466b76ce34dbafb3b592afc..ef7dad7b9d388eea9eeec5ae4588c4b6674db928 100644 --- a/lib/bouncer/frontends/v0/options.go +++ b/lib/bouncer/frontends/v0/options.go @@ -3,13 +3,16 @@ package frontends import ( "crypto/tls" - "pggat2/lib/bouncer" + "pggat2/lib/auth" "pggat2/lib/util/strutil" ) type AcceptOptions struct { SSLRequired bool SSLConfig *tls.Config - Pooler bouncer.Pooler AllowedStartupOptions []strutil.CIString } + +type AuthenticateOptions struct { + Credentials auth.Credentials +} diff --git a/lib/bouncer/frontends/v0/params.go b/lib/bouncer/frontends/v0/params.go new file mode 100644 index 0000000000000000000000000000000000000000..0909c5f9311ac79c7eb5d61a51c082c630742918 --- /dev/null +++ b/lib/bouncer/frontends/v0/params.go @@ -0,0 +1,18 @@ +package frontends + +import "pggat2/lib/util/strutil" + +type AcceptParams struct { + CancelKey [8]byte + + // or + + SSLEnabled bool + User string + Database string + InitialParameters map[strutil.CIString]string +} + +type AuthenticateParams struct { + BackendKey [8]byte +} diff --git a/lib/bouncer/pooler.go b/lib/bouncer/pooler.go deleted file mode 100644 index 18c3c5c492328716aa75a0d0d18ebd32ee0de65b..0000000000000000000000000000000000000000 --- a/lib/bouncer/pooler.go +++ /dev/null @@ -1,10 +0,0 @@ -package bouncer - -import ( - "pggat2/lib/auth" -) - -type Pooler interface { - GetUserCredentials(user, database string) auth.Credentials - Cancel(cancellationKey [8]byte) -} diff --git a/lib/gat/acceptor.go b/lib/gat/acceptor.go new file mode 100644 index 0000000000000000000000000000000000000000..7dfa49b3eb20e9fe7ad6001264adfac32445a584 --- /dev/null +++ b/lib/gat/acceptor.go @@ -0,0 +1,58 @@ +package gat + +import ( + "net" + + "pggat2/lib/bouncer/frontends/v0" + "pggat2/lib/zap" +) + +type Acceptor struct { + Listener net.Listener + Options frontends.AcceptOptions +} + +func (T Acceptor) Accept() (zap.Conn, frontends.AcceptParams, error) { + netConn, err := T.Listener.Accept() + if err != nil { + return nil, frontends.AcceptParams{}, err + } + conn := zap.WrapNetConn(netConn) + params, err := frontends.Accept(conn, T.Options) + if err != nil { + _ = conn.Close() + return nil, frontends.AcceptParams{}, err + } + return conn, params, nil +} + +func Listen(network, address string, options frontends.AcceptOptions) (Acceptor, error) { + listener, err := net.Listen(network, address) + if err != nil { + return Acceptor{}, err + } + return Acceptor{ + Listener: listener, + Options: options, + }, nil +} + +func Serve(acceptor Acceptor, gat *Gat) error { + for { + conn, params, err := acceptor.Accept() + if err != nil { + continue + } + go func() { + _ = gat.Serve(conn, params) + }() + } +} + +func ListenAndServe(network, address string, options frontends.AcceptOptions, gat *Gat) error { + listener, err := Listen(network, address, options) + if err != nil { + return err + } + return Serve(listener, gat) +} diff --git a/lib/gat/configs/pgbouncer/config.go b/lib/gat/configs/pgbouncer/config.go deleted file mode 100644 index 27cbdce65297e6dbe5c205ee5758a237306b70e7..0000000000000000000000000000000000000000 --- a/lib/gat/configs/pgbouncer/config.go +++ /dev/null @@ -1,429 +0,0 @@ -package pgbouncer - -import ( - "errors" - "net" - "os" - "strconv" - "strings" - "time" - - "tuxpa.in/a/zlog/log" - - "pggat2/lib/auth/credentials" - "pggat2/lib/gat" - "pggat2/lib/gat/pools/session" - "pggat2/lib/gat/pools/transaction" - "pggat2/lib/util/encoding/ini" - "pggat2/lib/util/encoding/userlist" - "pggat2/lib/util/flip" - "pggat2/lib/util/strutil" -) - -type PoolMode string - -const ( - PoolModeSession PoolMode = "session" - PoolModeTransaction PoolMode = "transaction" - PoolModeStatement PoolMode = "statement" -) - -type AuthType string - -const ( - AuthTypeCert AuthType = "cert" - AuthTypeMd5 AuthType = "md5" - AuthTypeScramSha256 AuthType = "scram-sha-256" - AuthTypePlain AuthType = "plain" - AuthTypeTrust AuthType = "trust" - AuthTypeAny AuthType = "any" - AuthTypeHba AuthType = "hba" - AuthTypePam AuthType = "pam" -) - -type SSLMode string - -const ( - SSLModeDisable SSLMode = "disable" - SSLModeAllow SSLMode = "allow" - SSLModePrefer SSLMode = "prefer" - SSLModeRequire SSLMode = "require" - SSLModeVerifyCa SSLMode = "verify-ca" - SSLModeVerifyFull SSLMode = "verify-full" -) - -type TLSProtocol string - -const ( - TLSProtocolV1_0 TLSProtocol = "tlsv1.0" - TLSProtocolV1_1 TLSProtocol = "tlsv1.1" - TLSProtocolV1_2 TLSProtocol = "tlsv1.2" - TLSProtocolV1_3 TLSProtocol = "tlsv1.3" - TLSProtocolAll TLSProtocol = "all" - TLSProtocolSecure TLSProtocol = "secure" - TLSProtocolLegacy TLSProtocol = "legacy" -) - -type TLSCipher string - -type TLSECDHCurve string - -type TLSDHEParams string - -type PgBouncer struct { - LogFile string `ini:"logfile"` - PidFile string `ini:"pidfile"` - ListenAddr string `ini:"listen_addr"` - ListenPort int `ini:"listen_port"` - UnixSocketDir string `ini:"unix_socket_dir"` - UnixSocketMode string `ini:"unix_socket_mode"` - UnixSocketGroup string `ini:"unix_socket_group"` - User string `ini:"user"` - PoolMode PoolMode `ini:"pool_mode"` - MaxClientConn int `ini:"max_client_conn"` - DefaultPoolSize int `ini:"default_pool_size"` - MinPoolSize int `ini:"min_pool_size"` - ReservePoolSize int `ini:"reserve_pool_size"` - ReservePoolTimeout float64 `ini:"reserve_pool_timeout"` - MaxDBConnections int `ini:"max_db_connections"` - MaxUserConnections int `ini:"max_user_connections"` - ServerRoundRobin int `ini:"server_round_robin"` - TrackExtraParameters []strutil.CIString `ini:"track_extra_parameters"` - IgnoreStartupParameters []strutil.CIString `ini:"ignore_startup_parameters"` - PeerID int `ini:"peer_id"` - DisablePQExec int `ini:"disable_pqexec"` - ApplicationNameAddHost int `ini:"application_name_add_host"` - ConfFile string `ini:"conffile"` - ServiceName string `ini:"service_name"` - StatsPeriod int `ini:"stats_period"` - AuthType string `ini:"auth_type"` - AuthHbaFile string `ini:"auth_hba_file"` - AuthFile string `ini:"auth_file"` - AuthUser string `ini:"auth_user"` - AuthQuery string `ini:"auth_query"` - AuthDbname string `ini:"auth_dbname"` - Syslog string `ini:"syslog"` - SyslogIdent string `ini:"syslog_ident"` - SyslogFacility string `ini:"syslog_facility"` - LogConnections int `ini:"log_connections"` - LogDisconnections int `ini:"log_disconnections"` - LogPoolerErrors int `ini:"log_pooler_errors"` - LogStats int `ini:"log_stats"` - Verbose int `ini:"verbose"` - AdminUsers []string `ini:"auth_users"` - StatsUsers []string `ini:"stats_users"` - ServerResetQuery string `ini:"server_reset_query"` - ServerResetQueryAlways int `ini:"server_reset_query_always"` - ServerCheckDelay float64 `ini:"server_check_delay"` - ServerCheckQuery string `ini:"server_check_query"` - ServerFastClose int `ini:"server_fast_close"` - ServerLifetime float64 `ini:"server_lifetime"` - ServerIdleTimeout float64 `ini:"server_idle_timeout"` - ServerConnectTimeout float64 `ini:"server_connect_timeout"` - ServerLoginRetry float64 `ini:"server_login_retry"` - ClientLoginTimeout float64 `ini:"client_login_timeout"` - AutodbIdleTimeout float64 `ini:"autodb_idle_timeout"` - DnsMaxTtl float64 `ini:"dns_max_ttl"` - DnsNxdomainTtl float64 `ini:"dns_nxdomain_ttl"` - DnsZoneCheckPeriod float64 `ini:"dns_zone_check_period"` - ResolvConf string `ini:"resolv.conf"` - ClientTLSSSLMode SSLMode `ini:"client_tls_sslmode"` - ClientTLSKeyFile string `ini:"client_tls_key_file"` - ClientTLSCertFile string `ini:"client_tls_cert_file"` - ClientTLSCaFile string `ini:"client_tls_ca_file"` - ClientTLSProtocols []TLSProtocol `ini:"client_tls_protocols"` - ClientTLSCiphers []TLSCipher `ini:"client_tls_ciphers"` - ClientTLSECDHCurve TLSECDHCurve `ini:"client_tls_ecdhcurve"` - ClientTLSDHEParams TLSDHEParams `ini:"client_tls_dheparams"` - ServerTLSSSLMode SSLMode `ini:"server_tls_sslmode"` - ServerTLSCaFile string `ini:"server_tls_ca_file"` - ServerTLSKeyFile string `ini:"server_tls_key_file"` - ServerTLSCertFile string `ini:"server_tls_cert_file"` - ServerTLSProtocols []TLSProtocol `ini:"server_tls_protocols"` - ServerTLSCiphers []TLSCipher `ini:"server_tls_ciphers"` - QueryTimeout float64 `ini:"query_timeout"` - QueryWaitTimeout float64 `ini:"query_wait_timeout"` - CancelWaitTimeout float64 `ini:"cancel_wait_timeout"` - ClientIdleTimeout float64 `ini:"client_idle_timeout"` - IdleTransactionTimeout float64 `ini:"idle_transaction_timeout"` - SuspendTimeout float64 `ini:"suspend_timeout"` - PktBuf int `ini:"pkt_buf"` - MaxPacketSize int `ini:"max_packet_size"` - ListenBacklog int `ini:"listen_backlog"` - SbufLoopcnt int `ini:"sbuf_loopcnt"` - SoReuseport int `ini:"so_reuseport"` - TcpDeferAccept int `ini:"tcp_defer_accept"` - TcpSocketBuffer int `ini:"tcp_socket_buffer"` - TcpKeepalive int `ini:"tcp_keepalive"` - TcpKeepidle int `ini:"tcp_keepidle"` - TcpKeepintvl int `ini:"tcp_keepintvl"` - TcpUserTimeout int `ini:"tcp_user_timeout"` -} - -type Database struct { - DBName string `ini:"dbname"` - Host string `ini:"host"` - Port int `ini:"port"` - User string `ini:"user"` - Password string `ini:"password"` - AuthUser string `ini:"auth_user"` - PoolSize int `ini:"pool_size"` - MinPoolSize int `ini:"min_pool_size"` - ReservePool int `ini:"reserve_pool"` - ConnectQuery string `ini:"connect_query"` - PoolMode PoolMode `ini:"pool_mode"` - MaxDBConnections int `ini:"max_db_connections"` - StartupParameters map[strutil.CIString]string `ini:"*"` -} - -type User struct { - PoolMode PoolMode `ini:"pool_mode"` - MaxUserConnections int `ini:"max_user_connections"` -} - -type Peer struct { - Host string `ini:"host"` - Port int `ini:"port"` - PoolSize int `ini:"pool_size"` -} - -type Config struct { - PgBouncer PgBouncer `ini:"pgbouncer"` - Databases map[string]Database `ini:"databases"` - Users map[string]User `ini:"users"` - Peers map[string]Peer `ini:"peers"` -} - -var Default = Config{ - PgBouncer: PgBouncer{ - ListenPort: 6432, - UnixSocketDir: "/tmp", - UnixSocketMode: "0777", - PoolMode: PoolModeSession, - MaxClientConn: 100, - DefaultPoolSize: 20, - ReservePoolTimeout: 5.0, - TrackExtraParameters: []strutil.CIString{ - strutil.MakeCIString("IntervalStyle"), - }, - ServiceName: "pgbouncer", - StatsPeriod: 60, - AuthQuery: "SELECT usename, passwd FROM pg_shadow WHERE usename=$1", - SyslogIdent: "pgbouncer", - SyslogFacility: "daemon", - LogConnections: 1, - LogDisconnections: 1, - LogPoolerErrors: 1, - LogStats: 1, - ServerResetQuery: "DISCARD ALL", - ServerCheckDelay: 30.0, - ServerCheckQuery: "select 1", - ServerLifetime: 3600.0, - ServerIdleTimeout: 600.0, - ServerConnectTimeout: 15.0, - ServerLoginRetry: 15.0, - ClientLoginTimeout: 60.0, - AutodbIdleTimeout: 3600.0, - DnsMaxTtl: 15.0, - DnsNxdomainTtl: 15.0, - ClientTLSSSLMode: SSLModeDisable, - ClientTLSProtocols: []TLSProtocol{ - TLSProtocolSecure, - }, - ClientTLSCiphers: []TLSCipher{ - "fast", - }, - ClientTLSECDHCurve: "auto", - ServerTLSSSLMode: SSLModePrefer, - ServerTLSProtocols: []TLSProtocol{ - TLSProtocolSecure, - }, - ServerTLSCiphers: []TLSCipher{ - "fast", - }, - QueryWaitTimeout: 120.0, - CancelWaitTimeout: 10.0, - SuspendTimeout: 10.0, - PktBuf: 4096, - MaxPacketSize: 2147483647, - ListenBacklog: 128, - SbufLoopcnt: 5, - TcpDeferAccept: 1, - TcpKeepalive: 1, - }, -} - -func Load(config string) (Config, error) { - conf, err := ini.ReadFile(config) - if err != nil { - return Config{}, err - } - - var c = Default - err = ini.Unmarshal(conf, &c) - return c, err -} - -func (T *Config) ListenAndServe() error { - trackedParameters := append([]strutil.CIString{ - strutil.MakeCIString("client_encoding"), - strutil.MakeCIString("datestyle"), - strutil.MakeCIString("timezone"), - strutil.MakeCIString("standard_conforming_strings"), - strutil.MakeCIString("application_name"), - }, T.PgBouncer.TrackExtraParameters...) - - ignoreStartupParameters := append(trackedParameters, T.PgBouncer.IgnoreStartupParameters...) - - pooler := gat.NewPooler(gat.PoolerConfig{ - AllowedStartupParameters: ignoreStartupParameters, - }) - - var authFile map[string]string - if T.PgBouncer.AuthFile != "" { - file, err := os.ReadFile(T.PgBouncer.AuthFile) - if err != nil { - return err - } - - authFile, err = userlist.Unmarshal(file) - if err != nil { - return err - } - } - - for name, user := range T.Users { - creds := credentials.Cleartext{ - Username: name, - Password: authFile[name], // TODO(garet) md5 and sasl - } - u := gat.NewUser(creds) - pooler.AddUser(u) - - for dbname, db := range T.Databases { - // filter out dbs specific to users - if db.User != "" && db.User != name { - continue - } - - // override dbname - if db.DBName != "" { - dbname = db.DBName - } - - // override poolmode - var poolMode PoolMode - if db.PoolMode != "" { - poolMode = db.PoolMode - } else if user.PoolMode != "" { - poolMode = user.PoolMode - } else { - poolMode = T.PgBouncer.PoolMode - } - - rawPoolConfig := gat.BaseRawPoolConfig{ - TrackedParameters: trackedParameters, - } - - var raw gat.RawPool - switch poolMode { - case PoolModeSession: - raw = session.NewPool(session.Config{ - RoundRobin: T.PgBouncer.ServerRoundRobin != 0, - BaseRawPoolConfig: rawPoolConfig, - }) - case PoolModeTransaction: - raw = transaction.NewPool(transaction.Config{ - BaseRawPoolConfig: rawPoolConfig, - }) - default: - return errors.New("unsupported pool mode") - } - - p := gat.NewPool(raw, gat.PoolConfig{ - IdleTimeout: time.Duration(T.PgBouncer.ServerIdleTimeout * float64(time.Second)), - }) - u.AddPool(dbname, p) - - if db.Host == "" { - // connect over unix socket - // TODO(garet) - } else { - var address string - if db.Port == 0 { - address = net.JoinHostPort(db.Host, "5432") - } else { - address = net.JoinHostPort(db.Host, strconv.Itoa(db.Port)) - } - - creds := creds - if db.Password != "" { - // lookup password - creds.Password = db.Password - } - - // connect over tcp - recipe := gat.TCPRecipe{ - Database: dbname, - Address: address, - Credentials: creds, - MinConnections: db.MinPoolSize, - MaxConnections: db.MaxDBConnections, - StartupParameters: db.StartupParameters, - } - if recipe.MinConnections == 0 { - recipe.MinConnections = T.PgBouncer.MinPoolSize - } - if recipe.MaxConnections == 0 { - recipe.MaxConnections = T.PgBouncer.MaxDBConnections - } - - p.AddRecipe("pgbouncer", recipe) - } - } - } - - var bank flip.Bank - - if T.PgBouncer.ListenAddr != "" { - bank.Queue(func() error { - listenAddr := T.PgBouncer.ListenAddr - if listenAddr == "*" { - listenAddr = "" - } - - listen := net.JoinHostPort(listenAddr, strconv.Itoa(T.PgBouncer.ListenPort)) - - listener, err := net.Listen("tcp", listen) - if err != nil { - return err - } - - log.Printf("listening on %s", listen) - - return pooler.ListenAndServe(listener) - }) - } - - // listen on unix socket - bank.Queue(func() error { - dir := T.PgBouncer.UnixSocketDir - port := T.PgBouncer.ListenPort - - if !strings.HasSuffix(dir, "/") { - dir = dir + "/" - } - dir = dir + ".s.PGSQL." + strconv.Itoa(port) - - listener, err := net.Listen("unix", dir) - if err != nil { - return err - } - - log.Printf("listening on unix:%s", dir) - - return pooler.ListenAndServe(listener) - }) - - return bank.Wait() -} diff --git a/lib/gat/configs/zalando/config.go b/lib/gat/configs/zalando/config.go deleted file mode 100644 index 52a0d858696dce38826a006311405b14420b25d4..0000000000000000000000000000000000000000 --- a/lib/gat/configs/zalando/config.go +++ /dev/null @@ -1,90 +0,0 @@ -package zalando - -import ( - "errors" - "fmt" - "net" - "strconv" - - "tuxpa.in/a/zlog/log" - - "gfx.cafe/util/go/gun" - - "pggat2/lib/auth/credentials" - "pggat2/lib/gat" - "pggat2/lib/gat/pools/session" - "pggat2/lib/gat/pools/transaction" - "pggat2/lib/util/flip" -) - -type Config struct { - PGHost string `env:"PGHOST"` - PGPort int `env:"PGPORT"` - PGUser string `env:"PGUSER"` - PGSchema string `env:"PGSCHEMA"` - PGPassword string `env:"PGPASSWORD"` - PoolerPort int `env:"CONNECTION_POOLER_PORT"` - PoolerMode string `env:"CONNECTION_POOLER_MODE"` - PoolerDefaultSize int `env:"CONNECTION_POOLER_DEFAULT_SIZE"` - PoolerMinSize int `env:"CONNECTION_POOLER_MIN_SIZE"` - PoolerReserveSize int `env:"CONNECTION_POOLER_RESERVE_SIZE"` - PoolerMaxClientConn int `env:"CONNECTION_POOLER_MAX_CLIENT_CONN"` - PoolerMaxDBConn int `env:"CONNECTION_POOLER_MAX_DB_CONN"` -} - -func Load() (Config, error) { - var conf Config - gun.Load(&conf) - if conf.PoolerMode == "" { - return Config{}, errors.New("expected pooler mode") - } - - return conf, nil -} - -func (T *Config) ListenAndServe() error { - pooler := gat.NewPooler(gat.PoolerConfig{}) - - creds := credentials.Cleartext{ - Username: T.PGUser, - Password: T.PGPassword, - } - - user := gat.NewUser(creds) - pooler.AddUser(user) - - var rawPool gat.RawPool - if T.PoolerMode == "transaction" { - rawPool = transaction.NewPool(transaction.Config{}) - } else { - rawPool = session.NewPool(session.Config{}) - } - - pool := gat.NewPool(rawPool, gat.PoolConfig{}) - user.AddPool("test", pool) - - pool.AddRecipe("zalando", gat.TCPRecipe{ - Address: net.JoinHostPort(T.PGHost, strconv.Itoa(T.PGPort)), - Credentials: creds, - MinConnections: T.PoolerMinSize, - MaxConnections: T.PoolerMaxDBConn, - Database: "test", - }) - - var bank flip.Bank - - bank.Queue(func() error { - listen := fmt.Sprintf(":%d", T.PoolerPort) - - listener, err := net.Listen("tcp", listen) - if err != nil { - return err - } - - log.Printf("listening on %s", listen) - - return pooler.ListenAndServe(listener) - }) - - return bank.Wait() -} diff --git a/lib/gat/dialer.go b/lib/gat/dialer.go new file mode 100644 index 0000000000000000000000000000000000000000..6df34123ba57e2bbc9d38ef125e4f1585a8fb48d --- /dev/null +++ b/lib/gat/dialer.go @@ -0,0 +1,33 @@ +package gat + +import ( + "net" + + "pggat2/lib/bouncer/backends/v0" + "pggat2/lib/zap" +) + +type Dialer interface { + Dial() (zap.Conn, backends.AcceptParams, error) +} + +type NetDialer struct { + Network string + Address string + + AcceptOptions backends.AcceptOptions +} + +func (T NetDialer) Dial() (zap.Conn, backends.AcceptParams, error) { + c, err := net.Dial(T.Network, T.Address) + if err != nil { + return nil, backends.AcceptParams{}, err + } + conn := zap.WrapNetConn(c) + params, err := backends.Accept(conn, T.AcceptOptions) + if err != nil { + return nil, backends.AcceptParams{}, err + } + + return conn, params, nil +} diff --git a/lib/gat/gat.go b/lib/gat/gat.go new file mode 100644 index 0000000000000000000000000000000000000000..4d254dd93eafee4d3996d6f5a42d013354cecdf1 --- /dev/null +++ b/lib/gat/gat.go @@ -0,0 +1,50 @@ +package gat + +import ( + "pggat2/lib/auth" + "pggat2/lib/bouncer/frontends/v0" + "pggat2/lib/zap" +) + +type Gat struct { + TestPool *Pool +} + +func (T *Gat) Serve(client zap.Conn, acceptParams frontends.AcceptParams) error { + defer func() { + _ = client.Close() + }() + + if acceptParams.CancelKey != [8]byte{} { + // TODO(garet) execute cancel + return nil + } + + pool, err := T.GetPool(acceptParams.User, acceptParams.Database) + if err != nil { + return err + } + + var credentials auth.Credentials + if pool != nil { + credentials = pool.GetCredentials() + } + + authParams, err := frontends.Authenticate(client, frontends.AuthenticateOptions{ + Credentials: credentials, + }) + if err != nil { + return err + } + + if pool == nil { + return nil + } + + return pool.Serve(client, acceptParams, authParams) +} + +func (T *Gat) GetPool(user, database string) (*Pool, error) { + return T.TestPool, nil + return nil, nil // TODO(garet) +} diff --git a/lib/gat/pool.go b/lib/gat/pool.go index 48da8ad453929ee3a70ec84a5aa83117460e433d..a1a85058232a08b868b05182db5c292a2518de68 100644 --- a/lib/gat/pool.go +++ b/lib/gat/pool.go @@ -1,228 +1,241 @@ package gat import ( - "sync" - "time" - - "tuxpa.in/a/zlog/log" - "github.com/google/uuid" + "tuxpa.in/a/zlog/log" - "pggat2/lib/bouncer" + "pggat2/lib/auth" "pggat2/lib/bouncer/backends/v0" - "pggat2/lib/util/maps" + "pggat2/lib/bouncer/bouncers/v2" + "pggat2/lib/bouncer/frontends/v0" + "pggat2/lib/middleware/interceptor" + "pggat2/lib/middleware/middlewares/unterminate" "pggat2/lib/util/maths" - "pggat2/lib/util/slices" - "pggat2/lib/util/strutil" + "pggat2/lib/zap" ) -type Context struct { - OnWait chan<- struct{} +type poolRecipe struct { + recipe Recipe + servers map[uuid.UUID]struct{} } -type RawPool interface { - Serve(ctx *Context, client bouncer.Conn) - - AddServer(server bouncer.Conn) uuid.UUID - GetServer(id uuid.UUID) bouncer.Conn - RemoveServer(id uuid.UUID) bouncer.Conn - - // LookupCorresponding finds the corresponding server and key for a particular client - LookupCorresponding(key [8]byte) (uuid.UUID, [8]byte, bool) +type Pool struct { + options PoolOptions - ScaleDown(amount int) (remaining int) - IdleSince() time.Time -} + recipes map[string]*poolRecipe -type BaseRawPoolConfig struct { - TrackedParameters []strutil.CIString + servers map[uuid.UUID]zap.Conn + clients map[uuid.UUID]zap.Conn } -type PoolRecipe struct { - removed bool - servers []uuid.UUID - mu sync.Mutex - - r Recipe +type PoolOptions struct { + Credentials auth.Credentials + Pooler Pooler + ServerResetQuery string } -type Pool struct { - config PoolConfig - - recipes maps.RWLocked[string, *PoolRecipe] - - ctx Context - raw RawPool +func NewPool(options PoolOptions) *Pool { + return &Pool{ + options: options, + } } -type PoolConfig struct { - // IdleTimeout determines how long idle servers are kept in the pool - IdleTimeout time.Duration +func (T *Pool) GetCredentials() auth.Credentials { + return T.options.Credentials } -func NewPool(raw RawPool, config PoolConfig) *Pool { - onWait := make(chan struct{}) - pool := &Pool{ - config: config, - ctx: Context{ - OnWait: onWait, - }, - raw: raw, +func (T *Pool) scale(name string, amount int) { + recipe := T.recipes[name] + if recipe == nil { + return } - go func() { - for range onWait { - pool.ScaleUp(1) - } - }() + target := maths.Clamp(len(recipe.servers)+amount, recipe.recipe.MinConnections, recipe.recipe.MaxConnections) + diff := target - len(recipe.servers) - if config.IdleTimeout != 0 { - go func() { - for { - var wait time.Duration - - now := time.Now() - idle := pool.IdleSince() - for now.Sub(idle) > config.IdleTimeout { - if idle == (time.Time{}) { - break - } - pool.ScaleDown(1) - idle = pool.IdleSince() - } - - if idle == (time.Time{}) { - wait = config.IdleTimeout - } else { - wait = now.Sub(idle.Add(config.IdleTimeout)) - } - - time.Sleep(wait) - } - }() - } + for diff > 0 { + diff-- - return pool -} + // add server + server, params, err := recipe.recipe.Dialer.Dial() + if err != nil { + log.Printf("failed to connect to server: %v", err) + continue + } -func (T *Pool) _tryAddServers(recipe *PoolRecipe, amount int) (remaining int) { - remaining = amount + _ = params // TODO(garet) - if recipe.removed { - return + serverID := T.addServer(server) + if recipe.servers == nil { + recipe.servers = make(map[uuid.UUID]struct{}) + } + recipe.servers[serverID] = struct{}{} } - j := 0 - for i := 0; i < len(recipe.servers); i++ { - if T.raw.GetServer(recipe.servers[i]).RW != nil { - recipe.servers[j] = recipe.servers[i] - j++ + for diff < 0 { + diff++ + + // remove server + for s := range recipe.servers { + T.removeServer(s) + break } } - recipe.servers = recipe.servers[:j] +} - var max = amount - maxConnections := recipe.r.GetMaxConnections() - if maxConnections != 0 { - max = maths.Min(maxConnections-j, max) +func (T *Pool) AddRecipe(name string, recipe Recipe) { + if T.recipes == nil { + T.recipes = make(map[string]*poolRecipe) } - for i := 0; i < max; i++ { - conn, err := recipe.r.Connect() - if err != nil { - log.Printf("error connecting to server: %v", err) - continue - } - id := T.raw.AddServer(conn) - recipe.servers = append(recipe.servers, id) - remaining-- + T.recipes[name] = &poolRecipe{ + recipe: recipe, } - return + T.scale(name, 0) } -func (T *Pool) tryAddServers(recipe *PoolRecipe, amount int) (remaining int) { - recipe.mu.Lock() - defer recipe.mu.Unlock() - - return T._tryAddServers(recipe, amount) -} - -func (T *Pool) addRecipe(recipe *PoolRecipe) { - recipe.mu.Lock() - defer recipe.mu.Unlock() - - recipe.removed = false - min := recipe.r.GetMinConnections() - len(recipe.servers) - T._tryAddServers(recipe, min) +func (T *Pool) RemoveRecipe(name string) { + if recipe, ok := T.recipes[name]; ok { + recipe.recipe.MaxConnections = 0 + T.scale(name, 0) + delete(T.recipes, name) + } } -func (T *Pool) removeRecipe(recipe *PoolRecipe) { - recipe.mu.Lock() - defer recipe.mu.Unlock() +func (T *Pool) addClient( + client zap.Conn, +) uuid.UUID { + clientID := uuid.New() + T.options.Pooler.AddClient(clientID) - recipe.removed = true - for _, id := range recipe.servers { - if conn := T.raw.RemoveServer(id); conn.RW != nil { - _ = conn.RW.Close() - } + if T.clients == nil { + T.clients = make(map[uuid.UUID]zap.Conn) } - - recipe.servers = recipe.servers[:0] + T.clients[clientID] = client + return clientID } -func (T *Pool) ScaleUp(amount int) (remaining int) { - remaining = amount - T.recipes.Range(func(_ string, r *PoolRecipe) bool { - remaining = T.tryAddServers(r, remaining) - return remaining != 0 - }) - return remaining +func (T *Pool) removeClient( + clientID uuid.UUID, +) { + T.options.Pooler.RemoveClient(clientID) + if client, ok := T.clients[clientID]; ok { + _ = client.Close() + delete(T.clients, clientID) + } } -func (T *Pool) ScaleDown(amount int) (remaining int) { - return T.raw.ScaleDown(amount) -} +func (T *Pool) addServer( + server zap.Conn, +) uuid.UUID { + serverID := uuid.New() + T.options.Pooler.AddServer(serverID) -func (T *Pool) IdleSince() time.Time { - return T.raw.IdleSince() + if T.servers == nil { + T.servers = make(map[uuid.UUID]zap.Conn) + } + T.servers[serverID] = server + return serverID } -func (T *Pool) AddRecipe(name string, recipe Recipe) { - r := &PoolRecipe{ - r: recipe, - } - T.addRecipe(r) - if old, ok := T.recipes.Swap(name, r); ok { - T.removeRecipe(old) +func (T *Pool) acquireServer( + clientID uuid.UUID, +) (serverID uuid.UUID, server zap.Conn) { + serverID = T.options.Pooler.AcquireConcurrent(clientID) + if serverID == uuid.Nil { + // TODO(garet) scale up + serverID = T.options.Pooler.AcquireAsync(clientID) } + + server = T.servers[serverID] + return } -func (T *Pool) RemoveRecipe(name string) { - if r, ok := T.recipes.LoadAndDelete(name); ok { - T.removeRecipe(r) +func (T *Pool) removeServer( + serverID uuid.UUID, +) { + T.options.Pooler.RemoveServer(serverID) + if server, ok := T.servers[serverID]; ok { + _ = server.Close() + delete(T.servers, serverID) } } -func (T *Pool) Serve(conn bouncer.Conn) { - T.raw.Serve(&T.ctx, conn) +func (T *Pool) tryReleaseServer( + serverID uuid.UUID, +) bool { + if !T.options.Pooler.CanRelease(serverID) { + return false + } + T.releaseServer(serverID) + return true } -func (T *Pool) Cancel(key [8]byte) { - server, cancelKey, ok := T.raw.LookupCorresponding(key) - if !ok { - return +func (T *Pool) releaseServer( + serverID uuid.UUID, +) { + if T.options.ServerResetQuery != "" { + server := T.servers[serverID] + err := backends.QueryString(new(backends.Context), server, T.options.ServerResetQuery) + if err != nil { + T.removeServer(serverID) + return + } } - T.recipes.Range(func(_ string, recipe *PoolRecipe) bool { - if slices.Contains(recipe.servers, server) { - rw, err := recipe.r.Dial() - if err != nil { - return false + T.options.Pooler.Release(serverID) +} + +func (T *Pool) Serve( + client zap.Conn, + acceptParams frontends.AcceptParams, + authParams frontends.AuthenticateParams, +) error { + client = interceptor.NewInterceptor( + client, + unterminate.Unterminate, + // TODO(garet) add middlewares based on Pool.options + ) + + defer func() { + _ = client.Close() + }() + + clientID := T.addClient(client) + + var serverID uuid.UUID + var server zap.Conn + + defer func() { + if serverID != uuid.Nil { + T.releaseServer(serverID) + } + }() + + for { + packet, err := client.ReadPacket(true) + if err != nil { + return err + } + + if serverID == uuid.Nil { + serverID, server = T.acquireServer(clientID) + } + clientErr, serverErr := bouncers.Bounce(client, server, packet) + if serverErr != nil { + T.removeServer(serverID) + serverID = uuid.Nil + server = nil + return serverErr + } else { + if T.tryReleaseServer(serverID) { + serverID = uuid.Nil + server = nil } - // error doesn't matter - _ = backends.Cancel(rw, cancelKey) - return false } - return true - }) + + if clientErr != nil { + return clientErr + } + } } diff --git a/lib/gat/pooler.go b/lib/gat/pooler.go index 8f5072dd12ac7044e80f2a36211d9550b9740058..747096089404a4e7780f03574a502ed714ac6045 100644 --- a/lib/gat/pooler.go +++ b/lib/gat/pooler.go @@ -1,124 +1,27 @@ package gat -import ( - "net" +import "github.com/google/uuid" - "pggat2/lib/auth" - "pggat2/lib/bouncer" - "pggat2/lib/bouncer/frontends/v0" - "pggat2/lib/middleware/interceptor" - "pggat2/lib/middleware/middlewares/unterminate" - "pggat2/lib/util/maps" - "pggat2/lib/util/slices" - "pggat2/lib/util/strutil" - "pggat2/lib/zap" -) +type Pooler interface { + AddClient(client uuid.UUID) + RemoveClient(client uuid.UUID) -type Pooler struct { - config PoolerConfig + AddServer(server uuid.UUID) + RemoveServer(server uuid.UUID) - // key -> pool for cancellation - keys maps.RWLocked[[8]byte, *Pool] + // AcquireConcurrent tries to acquire a peer for the client without stalling. + // Returns uuid.Nil if no peer can be acquired + AcquireConcurrent(client uuid.UUID) uuid.UUID - users maps.RWLocked[string, *User] -} - -type PoolerConfig struct { - AllowedStartupParameters []strutil.CIString - SSLMode bouncer.SSLMode -} - -func NewPooler(config PoolerConfig) *Pooler { - return &Pooler{ - config: config, - } -} - -func (T *Pooler) AddUser(user *User) { - T.users.Store(user.GetCredentials().GetUsername(), user) -} - -func (T *Pooler) RemoveUser(name string) { - T.users.Delete(name) -} - -func (T *Pooler) GetUser(name string) *User { - user, _ := T.users.Load(name) - return user -} - -func (T *Pooler) GetUserCredentials(user, database string) auth.Credentials { - u := T.GetUser(user) - if u == nil { - return nil - } - d := u.GetPool(database) - if d == nil { - return nil - } - return u.GetCredentials() -} + // AcquireAsync will stall until a peer is available. + AcquireAsync(client uuid.UUID) uuid.UUID -func (T *Pooler) Cancel(key [8]byte) { - pool, ok := T.keys.Load(key) - if !ok { - return - } + // CanRelease will check if a server can be released after a transaction. + // Some poolers (such as session poolers) do not release servers after each transaction. + // Returns true if Release could be called. + CanRelease(server uuid.UUID) bool - pool.Cancel(key) + // Release will force release the server. + // This should be called when the paired client has disconnected, or after CanRelease returns true. + Release(server uuid.UUID) } - -func (T *Pooler) IsStartupParameterAllowed(parameter strutil.CIString) bool { - return slices.Contains(T.config.AllowedStartupParameters, parameter) -} - -func (T *Pooler) Serve(client zap.Conn) { - defer func() { - _ = client.Close() - }() - - client = interceptor.NewInterceptor( - client, - unterminate.Unterminate, - ) - - conn, err := frontends.Accept( - client, - frontends.AcceptOptions{ - SSLRequired: T.config.SSLMode.IsRequired(), - // TODO(garet) SSL Config - Pooler: T, - AllowedStartupOptions: T.config.AllowedStartupParameters, - }, - ) - if err != nil { - return - } - - user := T.GetUser(conn.User) - if user == nil { - return - } - - pool := user.GetPool(conn.Database) - if pool == nil { - return - } - - T.keys.Store(conn.BackendKey, pool) - defer T.keys.Delete(conn.BackendKey) - - pool.Serve(conn) -} - -func (T *Pooler) ListenAndServe(listener net.Listener) error { - for { - conn, err := listener.Accept() - if err != nil { - return err - } - go T.Serve(zap.WrapNetConn(conn)) - } -} - -var _ bouncer.Pooler = (*Pooler)(nil) diff --git a/lib/gat/pools/session/config.go b/lib/gat/pools/session/config.go deleted file mode 100644 index ae442655dc9a8b4df0bfce2c147938be80bab8ad..0000000000000000000000000000000000000000 --- a/lib/gat/pools/session/config.go +++ /dev/null @@ -1,11 +0,0 @@ -package session - -import "pggat2/lib/gat" - -type Config struct { - gat.BaseRawPoolConfig - - // RoundRobin determines which order connections will be chosen. If false, connections are handled lifo, - // otherwise they are chosen fifo - RoundRobin bool -} diff --git a/lib/gat/pools/session/metrics.go b/lib/gat/pools/session/metrics.go deleted file mode 100644 index 3058166f2ab41e6205efc82f014f073f797a1e8b..0000000000000000000000000000000000000000 --- a/lib/gat/pools/session/metrics.go +++ /dev/null @@ -1,30 +0,0 @@ -package session - -import ( - "fmt" - "time" - - "github.com/google/uuid" -) - -type WorkerMetrics struct { - LastActive time.Time -} - -type Metrics struct { - Workers map[uuid.UUID]WorkerMetrics -} - -func (T *Metrics) InUse() int { - var used int - for _, worker := range T.Workers { - if worker.LastActive == (time.Time{}) { - used++ - } - } - return used -} - -func (T *Metrics) String() string { - return fmt.Sprintf("%d in use / %d total", T.InUse(), len(T.Workers)) -} diff --git a/lib/gat/pools/session/pool.go b/lib/gat/pools/session/pool.go index c180f7f24ceac5730e3114cf16ad4074a28d9128..1e33ce8c078c32a8dd447935ba4dd1daacd4494a 100644 --- a/lib/gat/pools/session/pool.go +++ b/lib/gat/pools/session/pool.go @@ -1,275 +1,8 @@ package session -import ( - "sync" - "time" +import "pggat2/lib/gat" - "github.com/google/uuid" - - "pggat2/lib/bouncer" - "pggat2/lib/bouncer/backends/v0" - "pggat2/lib/bouncer/bouncers/v2" - "pggat2/lib/gat" - "pggat2/lib/util/chans" - "pggat2/lib/util/maps" - "pggat2/lib/util/ring" - "pggat2/lib/util/slices" - "pggat2/lib/util/strutil" - packets "pggat2/lib/zap/packets/v3.0" -) - -type queueItem struct { - added time.Time - id uuid.UUID -} - -type Pool struct { - config Config - - // use slice lifo for better perf - queue ring.Ring[queueItem] - conns map[uuid.UUID]bouncer.Conn - ready sync.Cond - qmu sync.Mutex -} - -// NewPool creates a new session pool. -func NewPool(config Config) *Pool { - p := &Pool{ - config: config, - } - p.ready.L = &p.qmu - return p -} - -func (T *Pool) acquire(ctx *gat.Context) (uuid.UUID, bouncer.Conn) { - T.qmu.Lock() - defer T.qmu.Unlock() - for T.queue.Length() == 0 { - chans.TrySend(ctx.OnWait, struct{}{}) - T.ready.Wait() - } - - var entry queueItem - if T.config.RoundRobin { - entry, _ = T.queue.PopFront() - } else { - entry, _ = T.queue.PopBack() - } - return entry.id, T.conns[entry.id] -} - -func (T *Pool) _release(id uuid.UUID) { - T.queue.PushBack(queueItem{ - added: time.Now(), - id: id, - }) - - T.ready.Signal() -} - -func (T *Pool) close(id uuid.UUID, conn bouncer.Conn) { - _ = conn.RW.Close() - T.qmu.Lock() - defer T.qmu.Unlock() - - delete(T.conns, id) -} - -func (T *Pool) release(id uuid.UUID, conn bouncer.Conn) { - // reset session state - err := backends.QueryString(&backends.Context{}, conn.RW, "DISCARD ALL") - if err != nil { - T.close(id, conn) - return - } - - T.qmu.Lock() - defer T.qmu.Unlock() - T._release(id) -} - -func (T *Pool) Serve(ctx *gat.Context, client bouncer.Conn) { - defer func() { - _ = client.RW.Close() - }() - - serverOK := true - serverID, server := T.acquire(ctx) - defer func() { - if serverOK { - T.release(serverID, server) - } else { - T.close(serverID, server) - } - }() - - if func() bool { - add := func(key strutil.CIString) error { - if value, ok := server.InitialParameters[key]; ok { - ps := packets.ParameterStatus{ - Key: key.String(), - Value: value, - } - - if err := client.RW.WritePacket(ps.IntoPacket()); err != nil { - return err - } - } - return nil - } - - for key, value := range client.InitialParameters { - // skip already set params - if server.InitialParameters[key] == value { - if err := add(key); err != nil { - return true - } - continue - } - - // only set tracking params - if !slices.Contains(T.config.TrackedParameters, key) { - if err := add(key); err != nil { - return true - } - continue - } - - ps := packets.ParameterStatus{ - Key: key.String(), - Value: value, - } - if err := client.RW.WritePacket(ps.IntoPacket()); err != nil { - return true - } - - if err := backends.SetParameter(&backends.Context{}, server.RW, key, value); err != nil { - serverOK = false - return true - } - } - - for key := range server.InitialParameters { - if _, ok := client.InitialParameters[key]; ok { - continue - } - - if err := add(key); err != nil { - return true - } - } - - return false - }() { - return - } - - for { - packet, err := client.RW.ReadPacket(true) - if err != nil { - break - } - clientErr, serverErr := bouncers.Bounce(client.RW, server.RW, packet) - if clientErr != nil || serverErr != nil { - serverOK = serverErr == nil - break - } - } -} - -func (T *Pool) LookupCorresponding(key [8]byte) (uuid.UUID, [8]byte, bool) { - // TODO(garet) - return uuid.Nil, [8]byte{}, false -} - -func (T *Pool) AddServer(server bouncer.Conn) uuid.UUID { - T.qmu.Lock() - defer T.qmu.Unlock() - - id := uuid.New() - if T.conns == nil { - T.conns = make(map[uuid.UUID]bouncer.Conn) - } - T.conns[id] = server - T._release(id) - return id -} - -func (T *Pool) GetServer(id uuid.UUID) bouncer.Conn { - T.qmu.Lock() - defer T.qmu.Unlock() - - return T.conns[id] -} - -func (T *Pool) RemoveServer(id uuid.UUID) bouncer.Conn { - T.qmu.Lock() - defer T.qmu.Unlock() - - conn, ok := T.conns[id] - if !ok { - return bouncer.Conn{} - } - delete(T.conns, id) - return conn -} - -func (T *Pool) ScaleDown(amount int) (remaining int) { - remaining = amount - - T.qmu.Lock() - defer T.qmu.Unlock() - - for i := 0; i < amount; i++ { - v, ok := T.queue.PopFront() - if !ok { - break - } - - conn, ok := T.conns[v.id] - if !ok { - continue - } - delete(T.conns, v.id) - - _ = conn.RW.Close() - remaining-- - } - - return -} - -func (T *Pool) IdleSince() time.Time { - T.qmu.Lock() - defer T.qmu.Unlock() - - v, _ := T.queue.Get(0) - return v.added +func NewPool(options gat.PoolOptions) *gat.Pool { + options.Pooler = new(Pooler) + return gat.NewPool(options) } - -func (T *Pool) ReadMetrics(metrics *Metrics) { - maps.Clear(metrics.Workers) - - if metrics.Workers == nil { - metrics.Workers = make(map[uuid.UUID]WorkerMetrics) - } - - T.qmu.Lock() - defer T.qmu.Unlock() - - for i := 0; i < T.queue.Length(); i++ { - item, _ := T.queue.Get(i) - metrics.Workers[item.id] = WorkerMetrics{ - LastActive: item.added, - } - } - - for id := range T.conns { - if _, ok := metrics.Workers[id]; !ok { - metrics.Workers[id] = WorkerMetrics{} - } - } -} - -var _ gat.RawPool = (*Pool)(nil) diff --git a/lib/gat/pools/session/pooler.go b/lib/gat/pools/session/pooler.go new file mode 100644 index 0000000000000000000000000000000000000000..d3c9e08870150805b4a0373fd559375b9bfb722f --- /dev/null +++ b/lib/gat/pools/session/pooler.go @@ -0,0 +1,96 @@ +package session + +import ( + "sync" + + "github.com/google/uuid" + + "pggat2/lib/gat" + "pggat2/lib/util/slices" +) + +type Pooler struct { + queue []uuid.UUID + servers map[uuid.UUID]struct{} + ready *sync.Cond + mu sync.Mutex +} + +func (*Pooler) AddClient(_ uuid.UUID) { + // nothing to do +} + +func (*Pooler) RemoveClient(_ uuid.UUID) { + // nothing to do +} + +func (T *Pooler) AddServer(server uuid.UUID) { + T.mu.Lock() + defer T.mu.Unlock() + + T.queue = append(T.queue, server) + + if T.servers == nil { + T.servers = make(map[uuid.UUID]struct{}) + } + T.servers[server] = struct{}{} + + if T.ready != nil { + T.ready.Signal() + } +} + +func (T *Pooler) RemoveServer(server uuid.UUID) { + T.mu.Lock() + defer T.mu.Unlock() + + // remove server from queue + T.queue = slices.Remove(T.queue, server) + + delete(T.servers, server) +} + +func (T *Pooler) AcquireConcurrent(_ uuid.UUID) uuid.UUID { + T.mu.Lock() + defer T.mu.Unlock() + + if len(T.queue) == 0 { + return uuid.Nil + } + + server := T.queue[len(T.queue)-1] + T.queue = T.queue[:len(T.queue)-1] + return server +} + +func (T *Pooler) AcquireAsync(_ uuid.UUID) uuid.UUID { + T.mu.Lock() + defer T.mu.Unlock() + + for len(T.queue) == 0 { + if T.ready == nil { + T.ready = sync.NewCond(&T.mu) + } + T.ready.Wait() + } + + server := T.queue[len(T.queue)-1] + T.queue = T.queue[:len(T.queue)-1] + return server +} + +func (*Pooler) CanRelease(_ uuid.UUID) bool { + // servers are released when the client is removed + return false +} + +func (T *Pooler) Release(server uuid.UUID) { + // check if server was removed + if _, ok := T.servers[server]; !ok { + return + } + + T.queue = append(T.queue, server) +} + +var _ gat.Pooler = (*Pooler)(nil) diff --git a/lib/gat/pools/transaction/config.go b/lib/gat/pools/transaction/config.go deleted file mode 100644 index 23b6cb1895c6924bbeed7452421ba70cbb47b8c2..0000000000000000000000000000000000000000 --- a/lib/gat/pools/transaction/config.go +++ /dev/null @@ -1,7 +0,0 @@ -package transaction - -import "pggat2/lib/gat" - -type Config struct { - gat.BaseRawPoolConfig -} diff --git a/lib/gat/pools/transaction/conn.go b/lib/gat/pools/transaction/conn.go deleted file mode 100644 index b7e1a54b9d23dd27bdefd070de57161cb0d092fe..0000000000000000000000000000000000000000 --- a/lib/gat/pools/transaction/conn.go +++ /dev/null @@ -1,46 +0,0 @@ -package transaction - -import ( - "pggat2/lib/bouncer" - "pggat2/lib/bouncer/bouncers/v2" - "pggat2/lib/middleware/middlewares/eqp" - "pggat2/lib/middleware/middlewares/ps" - "pggat2/lib/rob" -) - -type Conn struct { - b bouncer.Conn - eqp *eqp.Server - ps *ps.Server -} - -func (T *Conn) Do(ctx *rob.Context, work any) { - job := work.(Work) - - var clientErr, serverErr error - - defer func() { - if clientErr != nil || serverErr != nil { - _ = job.rw.Close() - if serverErr != nil { - _ = T.b.RW.Close() - ctx.Remove() - } - } - }() - - // sync parameters - clientErr, serverErr = ps.Sync(job.trackedParameters, job.rw, job.ps, T.b.RW, T.ps) - if clientErr != nil || serverErr != nil { - return - } - - T.eqp.SetClient(job.eqp) - clientErr, serverErr = bouncers.Bounce(job.rw, T.b.RW, job.initialPacket) - if clientErr != nil || serverErr != nil { - return - } - return -} - -var _ rob.Worker = (*Conn)(nil) diff --git a/lib/gat/pools/transaction/pool.go b/lib/gat/pools/transaction/pool.go deleted file mode 100644 index 2dee5b1688308f4bd45436b841914489f9f6f392..0000000000000000000000000000000000000000 --- a/lib/gat/pools/transaction/pool.go +++ /dev/null @@ -1,129 +0,0 @@ -package transaction - -import ( - "time" - - "github.com/google/uuid" - - "pggat2/lib/bouncer" - "pggat2/lib/gat" - "pggat2/lib/middleware/interceptor" - "pggat2/lib/middleware/middlewares/eqp" - "pggat2/lib/middleware/middlewares/ps" - "pggat2/lib/rob" - "pggat2/lib/rob/schedulers/v1" -) - -type Pool struct { - config Config - s schedulers.Scheduler -} - -func NewPool(config Config) *Pool { - pool := &Pool{ - config: config, - s: schedulers.MakeScheduler(), - } - - return pool -} - -func (T *Pool) AddServer(server bouncer.Conn) uuid.UUID { - eqps := eqp.NewServer() - pss := ps.NewServer(server.InitialParameters) - server.RW = interceptor.NewInterceptor( - server.RW, - eqps, - pss, - ) - sink := &Conn{ - b: server, - eqp: eqps, - ps: pss, - } - return T.s.AddWorker(0, sink) -} - -func (T *Pool) GetServer(id uuid.UUID) bouncer.Conn { - conn := T.s.GetWorker(id) - if conn == nil { - return bouncer.Conn{} - } - return conn.(*Conn).b -} - -func (T *Pool) RemoveServer(id uuid.UUID) bouncer.Conn { - conn := T.s.RemoveWorker(id) - if conn == nil { - return bouncer.Conn{} - } - return conn.(*Conn).b -} - -func (T *Pool) Serve(ctx *gat.Context, client bouncer.Conn) { - source := T.s.NewSource() - eqpc := eqp.NewClient() - defer eqpc.Done() - psc := ps.NewClient(client.InitialParameters) - c := interceptor.NewInterceptor( - client.RW, - eqpc, - psc, - ) - robCtx := rob.Context{ - OnWait: ctx.OnWait, - } - - for { - packet, err := c.ReadPacket(true) - if err != nil { - break - } - - source.Do(&robCtx, Work{ - rw: c, - initialPacket: packet, - eqp: eqpc, - ps: psc, - trackedParameters: T.config.TrackedParameters, - }) - } - _ = c.Close() -} - -func (T *Pool) LookupCorresponding(key [8]byte) (uuid.UUID, [8]byte, bool) { - // TODO(garet) - return uuid.Nil, [8]byte{}, false -} - -func (T *Pool) ScaleDown(amount int) (remaining int) { - remaining = amount - - for i := 0; i < amount; i++ { - id, idle := T.s.GetIdleWorker() - if id == uuid.Nil || idle == (time.Time{}) { - break - } - worker := T.s.RemoveWorker(id) - if worker == nil { - i-- - continue - } - conn := worker.(*Conn) - _ = conn.b.RW.Close() - remaining-- - } - - return -} - -func (T *Pool) IdleSince() time.Time { - _, idle := T.s.GetIdleWorker() - return idle -} - -func (T *Pool) ReadSchedulerMetrics(metrics *rob.Metrics) { - T.s.ReadMetrics(metrics) -} - -var _ gat.RawPool = (*Pool)(nil) diff --git a/lib/gat/pools/transaction/work.go b/lib/gat/pools/transaction/work.go deleted file mode 100644 index 053958730d8b5c3d4ec3bd8123da0cd3e92b9259..0000000000000000000000000000000000000000 --- a/lib/gat/pools/transaction/work.go +++ /dev/null @@ -1,16 +0,0 @@ -package transaction - -import ( - "pggat2/lib/middleware/middlewares/eqp" - "pggat2/lib/middleware/middlewares/ps" - "pggat2/lib/util/strutil" - "pggat2/lib/zap" -) - -type Work struct { - rw zap.Conn - initialPacket zap.Packet - eqp *eqp.Client - ps *ps.Client - trackedParameters []strutil.CIString -} diff --git a/lib/gat/recipe.go b/lib/gat/recipe.go index 5c9918c5fcaa58957b570354d6124dfb839d0b01..487f4433f171afe948ce58071727eef903c268b2 100644 --- a/lib/gat/recipe.go +++ b/lib/gat/recipe.go @@ -1,75 +1,7 @@ package gat -import ( - "crypto/tls" - "net" - - "pggat2/lib/auth" - "pggat2/lib/bouncer" - "pggat2/lib/bouncer/backends/v0" - "pggat2/lib/util/strutil" - "pggat2/lib/zap" -) - -type Recipe interface { - Dial() (zap.Conn, error) - Connect() (bouncer.Conn, error) - - GetMinConnections() int - // GetMaxConnections returns the maximum amount of connections for this db - // Return 0 for unlimited connections - GetMaxConnections() int -} - -type TCPRecipe struct { - Database string - Address string - Credentials auth.Credentials - +type Recipe struct { + Dialer Dialer MinConnections int MaxConnections int - - SSLMode bouncer.SSLMode - - StartupParameters map[strutil.CIString]string -} - -func (T TCPRecipe) Dial() (zap.Conn, error) { - conn, err := net.Dial("tcp", T.Address) - if err != nil { - return nil, err - } - rw := zap.WrapNetConn(conn) - return rw, nil -} - -func (T TCPRecipe) Connect() (bouncer.Conn, error) { - rw, err := T.Dial() - if err != nil { - return bouncer.Conn{}, err - } - - server, err := backends.Accept(rw, backends.AcceptOptions{ - SSLMode: T.SSLMode, - SSLConfig: &tls.Config{ - // TODO(garet) SSL certificates if they need to be verified - InsecureSkipVerify: !T.SSLMode.VerifyCertificates(), - }, - Credentials: T.Credentials, - Database: T.Database, - StartupParameters: T.StartupParameters, - }) - if err != nil { - return bouncer.Conn{}, err - } - - return server, nil -} - -func (T TCPRecipe) GetMinConnections() int { - return T.MinConnections -} - -func (T TCPRecipe) GetMaxConnections() int { - return T.MaxConnections } diff --git a/lib/gat/user.go b/lib/gat/user.go deleted file mode 100644 index f34bac32ef57c5ee4f2371dad218170ef969d8fd..0000000000000000000000000000000000000000 --- a/lib/gat/user.go +++ /dev/null @@ -1,35 +0,0 @@ -package gat - -import ( - "pggat2/lib/auth" - "pggat2/lib/util/maps" -) - -type User struct { - credentials auth.Credentials - - pools maps.RWLocked[string, *Pool] -} - -func NewUser(credentials auth.Credentials) *User { - return &User{ - credentials: credentials, - } -} - -func (T *User) GetCredentials() auth.Credentials { - return T.credentials -} - -func (T *User) AddPool(name string, pool *Pool) { - T.pools.Store(name, pool) -} - -func (T *User) RemovePool(name string) { - T.pools.Delete(name) -} - -func (T *User) GetPool(name string) *Pool { - pool, _ := T.pools.Load(name) - return pool -} diff --git a/lib/perror/error.go b/lib/perror/error.go index fb30f94c15bd69616bff2b98acb2ca0fde7229e9..cd3542c7b9032d6df65651a685975cdc56c60ebe 100644 --- a/lib/perror/error.go +++ b/lib/perror/error.go @@ -6,4 +6,5 @@ type Error interface { Message() string Extra() []ExtraField String() string + Error() string } diff --git a/lib/perror/new.go b/lib/perror/new.go index c2abeeed9969711f7e1f7a14f5b7eeddb1516cfe..6b3f4e064130308b9d81c85ce87b58c85c36613f 100644 --- a/lib/perror/new.go +++ b/lib/perror/new.go @@ -36,4 +36,8 @@ func (T err) String() string { return string(T.severity) + ": " + T.message } +func (T err) Error() string { + return T.String() +} + var _ Error = err{} diff --git a/lib/util/slices/remove.go b/lib/util/slices/remove.go new file mode 100644 index 0000000000000000000000000000000000000000..6b62ffaf67601321d1052e0cbb17c15c0af053b0 --- /dev/null +++ b/lib/util/slices/remove.go @@ -0,0 +1,12 @@ +package slices + +func Remove[T comparable](slice []T, item T) []T { + for i, s := range slice { + if s == item { + copy(slice[i:], slice[i+1:]) + return slice[:len(slice)-1] + } + } + + return slice +}