diff --git a/lib/bouncer/frontends/v0/accept.go b/lib/bouncer/frontends/v0/accept.go index 883cd71122bfa0f0bbe1e5a7ae3c2eb5a937fd14..ab061d866110cbf0f046ab6ce8b9c89113a60920 100644 --- a/lib/bouncer/frontends/v0/accept.go +++ b/lib/bouncer/frontends/v0/accept.go @@ -1,21 +1,20 @@ package frontends import ( - "fmt" + "crypto/tls" "io" "strings" "gfx.cafe/gfx/pggat/lib/fed" packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" "gfx.cafe/gfx/pggat/lib/perror" - "gfx.cafe/gfx/pggat/lib/util/slices" "gfx.cafe/gfx/pggat/lib/util/strutil" ) func startup0( - ctx *AcceptContext, - params *AcceptParams, -) (done bool, err perror.Error) { + ctx *acceptContext, + params *acceptParams, +) (cancelling bool, done bool, err perror.Error) { var err2 error ctx.Packet, err2 = ctx.Conn.ReadPacket(false, ctx.Packet) if err2 != nil { @@ -34,18 +33,7 @@ func startup0( case 5678: // Cancel p.ReadBytes(params.CancelKey[:]) - - if params.CancelKey == [8]byte{} { - // very rare that this would ever happen - // and it's ok if we don't honor cancel requests - err = perror.New( - perror.FATAL, - perror.ProtocolViolation, - "cancel key cannot be null", - ) - return - } - + cancelling = true done = true return case 5679: @@ -150,15 +138,6 @@ func startup0( ikey := strutil.MakeCIString(key) - if !slices.Contains(ctx.Options.AllowedStartupOptions, ikey) { - err = perror.New( - perror.FATAL, - perror.FeatureNotSupported, - fmt.Sprintf(`Startup parameter "%s" is not allowed`, key), - ) - return - } - if params.InitialParameters == nil { params.InitialParameters = make(map[strutil.CIString]string) } @@ -186,15 +165,6 @@ func startup0( } else { ikey := strutil.MakeCIString(key) - if !slices.Contains(ctx.Options.AllowedStartupOptions, ikey) { - err = perror.New( - perror.FATAL, - perror.FeatureNotSupported, - fmt.Sprintf(`Startup parameter "%s" is not allowed`, key), - ) - return - } - if params.InitialParameters == nil { params.InitialParameters = make(map[strutil.CIString]string) } @@ -232,12 +202,12 @@ func startup0( return } -func accept( - ctx *AcceptContext, -) (params AcceptParams, err perror.Error) { +func accept0( + ctx *acceptContext, +) (params acceptParams, err perror.Error) { for { var done bool - done, err = startup0(ctx, ¶ms) + params.IsCanceling, done, err = startup0(ctx, ¶ms) if err != nil { return } @@ -246,23 +216,10 @@ func accept( } } - if params.CancelKey != [8]byte{} { - return - } - - if ctx.Options.SSLRequired && !params.SSLEnabled { - err = perror.New( - perror.FATAL, - perror.InvalidPassword, - "SSL is required", - ) - return - } - return } -func fail(packet fed.Packet, client fed.Conn, err perror.Error) { +func fail(packet fed.Packet, client fed.ReadWriter, err perror.Error) { resp := packets.ErrorResponse{ Error: err, } @@ -270,11 +227,37 @@ func fail(packet fed.Packet, client fed.Conn, err perror.Error) { _ = client.WritePacket(packet) } -func Accept(ctx *AcceptContext) (AcceptParams, perror.Error) { - params, err := accept(ctx) +func accept(ctx *acceptContext) (acceptParams, perror.Error) { + params, err := accept0(ctx) if err != nil { fail(ctx.Packet, ctx.Conn, err) - return AcceptParams{}, err + return acceptParams{}, err } return params, nil } + +func Accept(conn fed.ReadWriter, tlsConfig *tls.Config) ( + cancelKey [8]byte, + isCanceling bool, + sslEnabled bool, + user string, + database string, + initialParameters map[strutil.CIString]string, + err perror.Error, +) { + ctx := acceptContext{ + Conn: conn, + Options: acceptOptions{ + SSLConfig: tlsConfig, + }, + } + var params acceptParams + params, err = accept(&ctx) + cancelKey = params.CancelKey + isCanceling = params.IsCanceling + sslEnabled = params.SSLEnabled + user = params.User + database = params.Database + initialParameters = params.InitialParameters + return +} diff --git a/lib/bouncer/frontends/v0/authenticate.go b/lib/bouncer/frontends/v0/authenticate.go index b81603e3f1d59fccb2779a1b9125732c3986ec1c..680002a1c354b4b51db15e82d0ba8c7da049f251 100644 --- a/lib/bouncer/frontends/v0/authenticate.go +++ b/lib/bouncer/frontends/v0/authenticate.go @@ -6,11 +6,12 @@ import ( "io" "gfx.cafe/gfx/pggat/lib/auth" + "gfx.cafe/gfx/pggat/lib/fed" packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" "gfx.cafe/gfx/pggat/lib/perror" ) -func authenticationSASLInitial(ctx *AuthenticateContext, creds auth.SASLServer) (tool auth.SASLVerifier, resp []byte, done bool, err perror.Error) { +func authenticationSASLInitial(ctx *authenticateContext, creds auth.SASLServer) (tool auth.SASLVerifier, resp []byte, done bool, err perror.Error) { // check which authentication method the client wants var err2 error ctx.Packet, err2 = ctx.Conn.ReadPacket(true, ctx.Packet) @@ -42,7 +43,7 @@ func authenticationSASLInitial(ctx *AuthenticateContext, creds auth.SASLServer) return } -func authenticationSASLContinue(ctx *AuthenticateContext, tool auth.SASLVerifier) (resp []byte, done bool, err perror.Error) { +func authenticationSASLContinue(ctx *authenticateContext, tool auth.SASLVerifier) (resp []byte, done bool, err perror.Error) { var err2 error ctx.Packet, err2 = ctx.Conn.ReadPacket(true, ctx.Packet) if err2 != nil { @@ -67,7 +68,7 @@ func authenticationSASLContinue(ctx *AuthenticateContext, tool auth.SASLVerifier return } -func authenticationSASL(ctx *AuthenticateContext, creds auth.SASLServer) perror.Error { +func authenticationSASL(ctx *authenticateContext, creds auth.SASLServer) perror.Error { saslInitial := packets.AuthenticationSASL{ Mechanisms: creds.SupportedSASLMechanisms(), } @@ -109,7 +110,7 @@ func authenticationSASL(ctx *AuthenticateContext, creds auth.SASLServer) perror. return nil } -func authenticationMD5(ctx *AuthenticateContext, creds auth.MD5Server) perror.Error { +func authenticationMD5(ctx *authenticateContext, creds auth.MD5Server) perror.Error { var salt [4]byte _, err := rand.Read(salt[:]) if err != nil { @@ -141,7 +142,7 @@ func authenticationMD5(ctx *AuthenticateContext, creds auth.MD5Server) perror.Er return nil } -func authenticate(ctx *AuthenticateContext) (params AuthenticateParams, err perror.Error) { +func authenticate0(ctx *authenticateContext) (params authenticateParams, err perror.Error) { if ctx.Options.Credentials != nil { if credsSASL, ok := ctx.Options.Credentials.(auth.SASLServer); ok { err = authenticationSASL(ctx, credsSASL) @@ -184,11 +185,24 @@ func authenticate(ctx *AuthenticateContext) (params AuthenticateParams, err perr return } -func Authenticate(ctx *AuthenticateContext) (AuthenticateParams, perror.Error) { - params, err := authenticate(ctx) +func authenticate(ctx *authenticateContext) (authenticateParams, perror.Error) { + params, err := authenticate0(ctx) if err != nil { fail(ctx.Packet, ctx.Conn, err) - return AuthenticateParams{}, err + return authenticateParams{}, err } return params, nil } + +func Authenticate(conn fed.ReadWriter, creds auth.Credentials) (backendKey [8]byte, err perror.Error) { + ctx := authenticateContext{ + Conn: conn, + Options: authenticateOptions{ + Credentials: creds, + }, + } + var params authenticateParams + params, err = authenticate(&ctx) + backendKey = params.BackendKey + return +} diff --git a/lib/bouncer/frontends/v0/context.go b/lib/bouncer/frontends/v0/context.go index 859532528081af9edecf9d1ca3bb9bfbb2adb3da..aca67272ab2f3888527a1c7b9f13abe67c23d283 100644 --- a/lib/bouncer/frontends/v0/context.go +++ b/lib/bouncer/frontends/v0/context.go @@ -2,14 +2,14 @@ package frontends import "gfx.cafe/gfx/pggat/lib/fed" -type AcceptContext struct { +type acceptContext struct { Packet fed.Packet - Conn fed.Conn - Options AcceptOptions + Conn fed.ReadWriter + Options acceptOptions } -type AuthenticateContext struct { +type authenticateContext struct { Packet fed.Packet - Conn fed.Conn - Options AuthenticateOptions + Conn fed.ReadWriter + Options authenticateOptions } diff --git a/lib/bouncer/frontends/v0/options.go b/lib/bouncer/frontends/v0/options.go index b06d9889b7cb9f7d13ebbfb5413e080fcb9d6eff..304f7b82256efcd4edf0dc1d7125c25491f93320 100644 --- a/lib/bouncer/frontends/v0/options.go +++ b/lib/bouncer/frontends/v0/options.go @@ -4,15 +4,12 @@ import ( "crypto/tls" "gfx.cafe/gfx/pggat/lib/auth" - "gfx.cafe/gfx/pggat/lib/util/strutil" ) -type AcceptOptions struct { - SSLRequired bool - SSLConfig *tls.Config - AllowedStartupOptions []strutil.CIString +type acceptOptions struct { + SSLConfig *tls.Config } -type AuthenticateOptions struct { +type authenticateOptions struct { Credentials auth.Credentials } diff --git a/lib/bouncer/frontends/v0/params.go b/lib/bouncer/frontends/v0/params.go index a182469535682589611815d51acd31be83e4acd0..2aba47cefdc5483d47cbd486e8c34fd9c0b949c7 100644 --- a/lib/bouncer/frontends/v0/params.go +++ b/lib/bouncer/frontends/v0/params.go @@ -2,8 +2,9 @@ package frontends import "gfx.cafe/gfx/pggat/lib/util/strutil" -type AcceptParams struct { - CancelKey [8]byte +type acceptParams struct { + CancelKey [8]byte + IsCanceling bool // or @@ -13,6 +14,6 @@ type AcceptParams struct { InitialParameters map[strutil.CIString]string } -type AuthenticateParams struct { +type authenticateParams struct { BackendKey [8]byte } diff --git a/lib/fed/conn.go b/lib/fed/conn.go index 394e451bdfa0b6e8c393900d4e3ee1f507aae460..532a6f10d5f0e5f37759923f36fe43c55c99b618 100644 --- a/lib/fed/conn.go +++ b/lib/fed/conn.go @@ -9,24 +9,38 @@ import ( "net" "gfx.cafe/gfx/pggat/lib/util/slices" + "gfx.cafe/gfx/pggat/lib/util/strutil" ) type Conn interface { ReadWriter + LocalAddr() net.Addr + RemoteAddr() net.Addr + + SSLEnabled() bool + User() string + Database() string + InitialParameters() map[strutil.CIString]string + Close() error } -type netConn struct { - conn net.Conn - writer bufio.Writer - reader bufio.Reader +type NetConn struct { + conn net.Conn + writer bufio.Writer + reader bufio.Reader + sslEnabled bool + + user string + database string + initialParameters map[strutil.CIString]string headerBuf [5]byte } -func WrapNetConn(conn net.Conn) Conn { - c := &netConn{ +func WrapNetConn(conn net.Conn) *NetConn { + c := &NetConn{ conn: conn, } c.writer.Reset(conn) @@ -34,7 +48,50 @@ func WrapNetConn(conn net.Conn) Conn { return c } -func (T *netConn) EnableSSLClient(config *tls.Config) error { +func (T *NetConn) LocalAddr() net.Addr { + return T.conn.LocalAddr() +} + +func (T *NetConn) RemoteAddr() net.Addr { + return T.conn.RemoteAddr() +} + +func (T *NetConn) SSLEnabled() bool { + return T.sslEnabled +} + +func (T *NetConn) User() string { + return T.user +} + +func (T *NetConn) SetUser(user string) { + T.user = user +} + +func (T *NetConn) Database() string { + return T.database +} + +func (T *NetConn) SetDatabase(database string) { + T.database = database +} + +func (T *NetConn) InitialParameters() map[strutil.CIString]string { + return T.initialParameters +} + +func (T *NetConn) SetInitialParameters(initialParameters map[strutil.CIString]string) { + T.initialParameters = initialParameters +} + +var errSSLAlreadyEnabled = errors.New("ssl is already enabled") + +func (T *NetConn) EnableSSLClient(config *tls.Config) error { + if T.sslEnabled { + return errSSLAlreadyEnabled + } + T.sslEnabled = true + if err := T.writer.Flush(); err != nil { return err } @@ -48,7 +105,12 @@ func (T *netConn) EnableSSLClient(config *tls.Config) error { return sslConn.Handshake() } -func (T *netConn) EnableSSLServer(config *tls.Config) error { +func (T *NetConn) EnableSSLServer(config *tls.Config) error { + if T.sslEnabled { + return errSSLAlreadyEnabled + } + T.sslEnabled = true + if err := T.writer.Flush(); err != nil { return err } @@ -62,14 +124,14 @@ func (T *netConn) EnableSSLServer(config *tls.Config) error { return sslConn.Handshake() } -func (T *netConn) ReadByte() (byte, error) { +func (T *NetConn) ReadByte() (byte, error) { if err := T.writer.Flush(); err != nil { return 0, err } return T.reader.ReadByte() } -func (T *netConn) ReadPacket(typed bool, buffer Packet) (packet Packet, err error) { +func (T *NetConn) ReadPacket(typed bool, buffer Packet) (packet Packet, err error) { packet = buffer if err = T.writer.Flush(); err != nil { @@ -100,24 +162,24 @@ func (T *netConn) ReadPacket(typed bool, buffer Packet) (packet Packet, err erro return } -func (T *netConn) WriteByte(b byte) error { +func (T *NetConn) WriteByte(b byte) error { return T.writer.WriteByte(b) } -func (T *netConn) WritePacket(packet Packet) error { +func (T *NetConn) WritePacket(packet Packet) error { _, err := T.writer.Write(packet.Bytes()) return err } -func (T *netConn) Close() error { +func (T *NetConn) Close() error { if err := T.writer.Flush(); err != nil { return err } return T.conn.Close() } -var _ Conn = (*netConn)(nil) -var _ SSLServer = (*netConn)(nil) -var _ SSLClient = (*netConn)(nil) -var _ io.ByteReader = (*netConn)(nil) -var _ io.ByteWriter = (*netConn)(nil) +var _ Conn = (*NetConn)(nil) +var _ SSLServer = (*NetConn)(nil) +var _ SSLClient = (*NetConn)(nil) +var _ io.ByteReader = (*NetConn)(nil) +var _ io.ByteWriter = (*NetConn)(nil) diff --git a/lib/gat/app.go b/lib/gat/app.go index f69db111b2339c3869a957eb3c23af63cfcc316a..eafd6e4128a65b8ea505c4b285216d7c309c6c65 100644 --- a/lib/gat/app.go +++ b/lib/gat/app.go @@ -1,10 +1,23 @@ package gat import ( - "github.com/caddyserver/caddy/v2" + "crypto/tls" + "errors" + "fmt" + "net" + "github.com/caddyserver/caddy/v2" + "tuxpa.in/a/zlog/log" + + "gfx.cafe/gfx/pggat/lib/bouncer/frontends/v0" + "gfx.cafe/gfx/pggat/lib/fed" + packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" + "gfx.cafe/gfx/pggat/lib/middleware/interceptor" + "gfx.cafe/gfx/pggat/lib/middleware/middlewares/unterminate" + "gfx.cafe/gfx/pggat/lib/perror" "gfx.cafe/gfx/pggat/lib/util/dur" "gfx.cafe/gfx/pggat/lib/util/maps" + "gfx.cafe/gfx/pggat/lib/util/slices" ) type Config struct { @@ -61,12 +74,124 @@ func (T *App) Provision(ctx caddy.Context) error { return nil } +func (T *App) cancel(key [8]byte) { + p, _ := T.keys.Load(key) + if p == nil { + return + } + + _ = p.Cancel(key) +} + +func (T *App) serve(server *Server, conn fed.Conn) { + initialParameters := conn.InitialParameters() + for key := range initialParameters { + if !slices.Contains(server.AllowedStartupParameters, key) { + errResp := packets.ErrorResponse{ + Error: perror.New( + perror.FATAL, + perror.FeatureNotSupported, + fmt.Sprintf(`Startup parameter "%s" is not allowed`, key), + ), + } + _ = conn.WritePacket(errResp.IntoPacket(nil)) + return + } + } + + p := server.lookup(conn) + if p == nil { + log.Printf("pool not found for client: user=%s database=%s", conn.User(), conn.Database()) + return + } + + backendKey, err := frontends.Authenticate(conn, p.Credentials()) + if err != nil { + log.Printf("error authenticating client: %v", err) + return + } + + T.keys.Store(backendKey, p) + defer T.keys.Delete(backendKey) + + if err2 := p.Serve(conn, backendKey); err2 != nil { + log.Printf("error serving client: %v", err2) + return + } +} + +func (T *App) accept(listener *Listener, conn *fed.NetConn) { + defer func() { + _ = conn.Close() + }() + + var tlsConfig *tls.Config + if listener.ssl != nil { + tlsConfig = listener.ssl.ServerTLSConfig() + } + + cancelKey, isCanceling, _, user, database, initialParameters, err := frontends.Accept(conn, tlsConfig) + if err != nil { + log.Printf("error accepting client: %v", err) + return + } + + if isCanceling { + T.cancel(cancelKey) + return + } + + conn.SetUser(user) + conn.SetDatabase(database) + conn.SetInitialParameters(initialParameters) + + for _, server := range T.servers { + if server.match == nil || server.match.Matches(conn) { + T.serve(server, interceptor.NewInterceptor(conn, unterminate.Unterminate)) + return + } + } + + log.Printf("server not found for client: user=%s database=%s", conn.User(), conn.Database()) + + errResp := packets.ErrorResponse{ + Error: perror.New( + perror.FATAL, + perror.InternalError, + "No server is available to handle your request", + ), + } + _ = conn.WritePacket(errResp.IntoPacket(nil)) +} + +func (T *App) acceptFrom(listener *Listener) bool { + conn, err := listener.accept() + if err != nil { + if errors.Is(err, net.ErrClosed) { + return false + } + log.Printf("error accepting client: %v", err) + return true + } + + go T.accept(listener, conn) + return true +} + func (T *App) Start() error { // start listeners for _, listener := range T.listen { if err := listener.Start(); err != nil { return err } + + go func(listener *Listener) { + for { + if !T.acceptFrom(listener) { + break + } + } + }(listener) } return nil diff --git a/lib/gat/listen.go b/lib/gat/listen.go index 024a7a7e6813c59d5d02838b9afee608f6c4258b..1a92f4ae563b58c15c847bcf4d75fa13f4e3182f 100644 --- a/lib/gat/listen.go +++ b/lib/gat/listen.go @@ -7,6 +7,8 @@ import ( "github.com/caddyserver/caddy/v2" "tuxpa.in/a/zlog/log" + + "gfx.cafe/gfx/pggat/lib/fed" ) type ListenerConfig struct { @@ -23,6 +25,14 @@ type Listener struct { listener net.Listener } +func (T *Listener) accept() (*fed.NetConn, error) { + raw, err := T.listener.Accept() + if err != nil { + return nil, err + } + return fed.WrapNetConn(raw), nil +} + func (T *Listener) Provision(ctx caddy.Context) error { if T.SSL != nil { val, err := ctx.LoadModule(T, "SSL") diff --git a/lib/gat/matcher.go b/lib/gat/matcher.go index 1f51a8a905f5616dc898e7c1fe332d99de723546..20c20618a96764dd4afd82bd961391e1a39c1ec5 100644 --- a/lib/gat/matcher.go +++ b/lib/gat/matcher.go @@ -1,4 +1,7 @@ package gat +import "gfx.cafe/gfx/pggat/lib/fed" + type Matcher interface { + Matches(conn fed.Conn) bool } diff --git a/lib/gat/matchers/and.go b/lib/gat/matchers/and.go index 35740607775ff38f1de12a9bfe41f1c8ae842685..7dd683b1187356729f7de5af58583979c07b0d2b 100644 --- a/lib/gat/matchers/and.go +++ b/lib/gat/matchers/and.go @@ -6,6 +6,7 @@ import ( "github.com/caddyserver/caddy/v2" + "gfx.cafe/gfx/pggat/lib/fed" "gfx.cafe/gfx/pggat/lib/gat" ) @@ -44,6 +45,15 @@ func (T *And) Provision(ctx caddy.Context) error { return nil } +func (T *And) Matches(conn fed.Conn) bool { + for _, matcher := range T.and { + if !matcher.Matches(conn) { + return false + } + } + return true +} + var _ gat.Matcher = (*And)(nil) var _ caddy.Module = (*And)(nil) var _ caddy.Provisioner = (*And)(nil) diff --git a/lib/gat/matchers/database.go b/lib/gat/matchers/database.go index 280105f4cf3b8ef47e09e9962cac76dcba890d9f..0133a90e4c3631a39f81ddca4e4108b2e9eb43f3 100644 --- a/lib/gat/matchers/database.go +++ b/lib/gat/matchers/database.go @@ -3,6 +3,7 @@ package matchers import ( "github.com/caddyserver/caddy/v2" + "gfx.cafe/gfx/pggat/lib/fed" "gfx.cafe/gfx/pggat/lib/gat" ) @@ -23,5 +24,9 @@ func (T *Database) CaddyModule() caddy.ModuleInfo { } } +func (T *Database) Matches(conn fed.Conn) bool { + return conn.Database() == T.Database +} + var _ gat.Matcher = (*Database)(nil) var _ caddy.Module = (*Database)(nil) diff --git a/lib/gat/matchers/localaddress.go b/lib/gat/matchers/localaddress.go index cf9d1d6d3b68a44d2d92f442cd75900ac70bac2e..a936c454c663b21861b4fbfdbb32317b65c620fc 100644 --- a/lib/gat/matchers/localaddress.go +++ b/lib/gat/matchers/localaddress.go @@ -3,6 +3,7 @@ package matchers import ( "github.com/caddyserver/caddy/v2" + "gfx.cafe/gfx/pggat/lib/fed" "gfx.cafe/gfx/pggat/lib/gat" ) @@ -23,5 +24,10 @@ func (T *LocalAddress) CaddyModule() caddy.ModuleInfo { } } +func (T *LocalAddress) Matches(conn fed.Conn) bool { + // TODO(garet) + return true +} + var _ gat.Matcher = (*LocalAddress)(nil) var _ caddy.Module = (*LocalAddress)(nil) diff --git a/lib/gat/matchers/network.go b/lib/gat/matchers/network.go index b1178225c7fd337434096313563759944882820a..fa638e3dffc99d267b36671d67dde25aa7a1e414 100644 --- a/lib/gat/matchers/network.go +++ b/lib/gat/matchers/network.go @@ -3,6 +3,7 @@ package matchers import ( "github.com/caddyserver/caddy/v2" + "gfx.cafe/gfx/pggat/lib/fed" "gfx.cafe/gfx/pggat/lib/gat" ) @@ -23,5 +24,9 @@ func (T *Network) CaddyModule() caddy.ModuleInfo { } } +func (T *Network) Matches(conn fed.Conn) bool { + return conn.LocalAddr().Network() == T.Network +} + var _ gat.Matcher = (*Network)(nil) var _ caddy.Module = (*Network)(nil) diff --git a/lib/gat/matchers/or.go b/lib/gat/matchers/or.go index 342c0dbad106226dc98ea307257b83f535eb09a8..a8d0632c7ca726bba0c1e76e5301e51927d22f55 100644 --- a/lib/gat/matchers/or.go +++ b/lib/gat/matchers/or.go @@ -6,6 +6,7 @@ import ( "github.com/caddyserver/caddy/v2" + "gfx.cafe/gfx/pggat/lib/fed" "gfx.cafe/gfx/pggat/lib/gat" ) @@ -44,6 +45,15 @@ func (T *Or) Provision(ctx caddy.Context) error { return nil } +func (T *Or) Matches(conn fed.Conn) bool { + for _, matcher := range T.or { + if matcher.Matches(conn) { + return true + } + } + return false +} + var _ gat.Matcher = (*Or)(nil) var _ caddy.Module = (*Or)(nil) var _ caddy.Provisioner = (*Or)(nil) diff --git a/lib/gat/matchers/ssl.go b/lib/gat/matchers/ssl.go index 9ad844bf24cbebca40488057d9fcaa790917bf3b..eacd4090e6329eafbde15b655ab747a414a92ecf 100644 --- a/lib/gat/matchers/ssl.go +++ b/lib/gat/matchers/ssl.go @@ -3,6 +3,7 @@ package matchers import ( "github.com/caddyserver/caddy/v2" + "gfx.cafe/gfx/pggat/lib/fed" "gfx.cafe/gfx/pggat/lib/gat" ) @@ -23,5 +24,9 @@ func (T *SSL) CaddyModule() caddy.ModuleInfo { } } +func (T *SSL) Matches(conn fed.Conn) bool { + return conn.SSLEnabled() == T.SSL +} + var _ gat.Matcher = (*SSL)(nil) var _ caddy.Module = (*SSL)(nil) diff --git a/lib/gat/matchers/startupparameters.go b/lib/gat/matchers/startupparameters.go index 7692d272554ccedcec5ca2c1050f7345ccbfcf41..addb92f48347ff802dcf56f311023c988cd6aeb7 100644 --- a/lib/gat/matchers/startupparameters.go +++ b/lib/gat/matchers/startupparameters.go @@ -3,7 +3,9 @@ package matchers import ( "github.com/caddyserver/caddy/v2" + "gfx.cafe/gfx/pggat/lib/fed" "gfx.cafe/gfx/pggat/lib/gat" + "gfx.cafe/gfx/pggat/lib/util/strutil" ) func init() { @@ -12,6 +14,8 @@ func init() { type StartupParameters struct { Parameters map[string]string `json:"startup_parameters"` + + parameters map[strutil.CIString]string } func (T *StartupParameters) CaddyModule() caddy.ModuleInfo { @@ -23,5 +27,25 @@ func (T *StartupParameters) CaddyModule() caddy.ModuleInfo { } } +func (T *StartupParameters) Provision(ctx caddy.Context) error { + T.parameters = make(map[strutil.CIString]string, len(T.Parameters)) + for key, value := range T.Parameters { + T.parameters[strutil.MakeCIString(key)] = value + } + + return nil +} + +func (T *StartupParameters) Matches(conn fed.Conn) bool { + initialParameters := conn.InitialParameters() + for key, value := range T.parameters { + if initialParameters[key] != value { + return false + } + } + return true +} + var _ gat.Matcher = (*StartupParameters)(nil) var _ caddy.Module = (*StartupParameters)(nil) +var _ caddy.Provisioner = (*StartupParameters)(nil) diff --git a/lib/gat/matchers/user.go b/lib/gat/matchers/user.go index 5985bc2a0041f1a5c3c23c6d920ac22e31b63138..a666a4bb12912a727b38dfc94ebdae23025b3a65 100644 --- a/lib/gat/matchers/user.go +++ b/lib/gat/matchers/user.go @@ -3,6 +3,7 @@ package matchers import ( "github.com/caddyserver/caddy/v2" + "gfx.cafe/gfx/pggat/lib/fed" "gfx.cafe/gfx/pggat/lib/gat" ) @@ -23,5 +24,9 @@ func (T *User) CaddyModule() caddy.ModuleInfo { } } +func (T *User) Matches(conn fed.Conn) bool { + return conn.User() == T.User +} + var _ gat.Matcher = (*User)(nil) var _ caddy.Module = (*User)(nil) diff --git a/lib/gat/pool/client.go b/lib/gat/pool/client.go index 2d687e10229b09be35905ad882daf27ac27eee44..6d4888e114244785291fbe88863fec8e1856e3b9 100644 --- a/lib/gat/pool/client.go +++ b/lib/gat/pool/client.go @@ -7,7 +7,6 @@ import ( "gfx.cafe/gfx/pggat/lib/middleware/middlewares/eqp" "gfx.cafe/gfx/pggat/lib/middleware/middlewares/ps" "gfx.cafe/gfx/pggat/lib/middleware/middlewares/unterminate" - "gfx.cafe/gfx/pggat/lib/util/strutil" ) type pooledClient struct { @@ -20,13 +19,14 @@ type pooledClient struct { func newClient( options Options, conn fed.Conn, - initialParameters map[strutil.CIString]string, backendKey [8]byte, ) *pooledClient { middlewares := []middleware.Middleware{ unterminate.Unterminate, } + initialParameters := conn.InitialParameters() + var psClient *ps.Client if options.ParameterStatusSync == ParameterStatusSyncDynamic { // add ps middleware diff --git a/lib/gat/pool/pool.go b/lib/gat/pool/pool.go index 3ad6b328f89e9307cf135e82a2d6dc48fad57795..3dd2e3fc1e9c7ce2738afecd10dbae0a1bf758bc 100644 --- a/lib/gat/pool/pool.go +++ b/lib/gat/pool/pool.go @@ -16,7 +16,6 @@ import ( packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" "gfx.cafe/gfx/pggat/lib/gat/metrics" "gfx.cafe/gfx/pggat/lib/util/slices" - "gfx.cafe/gfx/pggat/lib/util/strutil" ) type Pool struct { @@ -79,7 +78,7 @@ func (T *Pool) idlest() (server *pooledServer, at time.Time) { return } -func (T *Pool) GetCredentials() auth.Credentials { +func (T *Pool) Credentials() auth.Credentials { return T.options.Credentials } @@ -284,7 +283,6 @@ func (T *Pool) releaseServer(server *pooledServer) { func (T *Pool) Serve( conn fed.Conn, - initialParameters map[strutil.CIString]string, backendKey [8]byte, ) error { defer func() { @@ -294,7 +292,6 @@ func (T *Pool) Serve( client := newClient( T.options, conn, - initialParameters, backendKey, ) @@ -313,7 +310,6 @@ func (T *Pool) ServeBot( client := newClient( T.options, conn, - nil, [8]byte{}, ) diff --git a/lib/gat/provider.go b/lib/gat/provider.go index 719dfcb39fa78d1e7559b71bae5420098f92917c..127e16a76aa18deb3509bcc2e31eb6540fed7339 100644 --- a/lib/gat/provider.go +++ b/lib/gat/provider.go @@ -1,9 +1,12 @@ package gat -import "gfx.cafe/gfx/pggat/lib/gat/metrics" +import ( + "gfx.cafe/gfx/pggat/lib/fed" + "gfx.cafe/gfx/pggat/lib/gat/metrics" +) // Provider provides pool to the server type Provider interface { - Lookup(user, database string) *Pool + Lookup(conn fed.Conn) *Pool ReadMetrics(metrics *metrics.Pools) } diff --git a/lib/gat/providers/discovery/module.go b/lib/gat/providers/discovery/module.go index d3f37bf55adda372b6e5bb007fe0e0bb1bdb2493..b3c351b1e52cd41312e14385468e2d517e90113c 100644 --- a/lib/gat/providers/discovery/module.go +++ b/lib/gat/providers/discovery/module.go @@ -10,6 +10,7 @@ import ( "gfx.cafe/gfx/pggat/lib/auth" "gfx.cafe/gfx/pggat/lib/auth/credentials" + "gfx.cafe/gfx/pggat/lib/fed" "gfx.cafe/gfx/pggat/lib/gat" "gfx.cafe/gfx/pggat/lib/gat/metrics" "gfx.cafe/gfx/pggat/lib/gat/pool" @@ -230,7 +231,7 @@ func (T *Module) replacePrimary(users []User, databases []string, endpoint Endpo AcceptOptions: acceptOptions, } - p := T.Lookup(user.Username, database) + p := T.lookup(user.Username, database) if p == nil { continue } @@ -284,7 +285,7 @@ func (T *Module) addReplica(users []User, databases []string, id string, endpoin for _, database := range databases { acceptOptions := T.backendAcceptOptions(user.Username, primaryCreds, database) - p := T.Lookup(replicaUsername, database) + p := T.lookup(replicaUsername, database) if p == nil { continue } @@ -305,7 +306,7 @@ func (T *Module) removeReplica(users []User, databases []string, id string) { for _, user := range users { username := T.replicaUsername(user.Username) for _, database := range databases { - p := T.Lookup(username, database) + p := T.lookup(username, database) if p == nil { continue } @@ -508,13 +509,17 @@ func (T *Module) ReadMetrics(metrics *metrics.Pools) { }) } -func (T *Module) Lookup(user, database string) *gat.Pool { +func (T *Module) lookup(user, database string) *gat.Pool { T.mu.RLock() defer T.mu.RUnlock() p, _ := T.pools.Load(user, database) return p } +func (T *Module) Lookup(conn fed.Conn) *gat.Pool { + return T.lookup(conn.User(), conn.Database()) +} + var _ gat.Provider = (*Module)(nil) var _ caddy.Module = (*Module)(nil) var _ caddy.Provisioner = (*Module)(nil) diff --git a/lib/gat/providers/pgbouncer/module.go b/lib/gat/providers/pgbouncer/module.go index 51cfc37ecfe20865b14f5e48c134d7aba1bbfd25..eae73261a948cacd05a269cbbecab7036734ef87 100644 --- a/lib/gat/providers/pgbouncer/module.go +++ b/lib/gat/providers/pgbouncer/module.go @@ -13,6 +13,7 @@ import ( "github.com/caddyserver/caddy/v2" "tuxpa.in/a/zlog/log" + "gfx.cafe/gfx/pggat/lib/fed" "gfx.cafe/gfx/pggat/lib/gat/poolers/session" "gfx.cafe/gfx/pggat/lib/gat/poolers/transaction" "gfx.cafe/gfx/pggat/lib/util/dur" @@ -90,7 +91,7 @@ func (T *Module) getPassword(user, database string) (string, bool) { } } - authPool := T.Lookup(authUser, database) + authPool := T.lookup(authUser, database) if authPool == nil { return "", false } @@ -266,7 +267,7 @@ func (T *Module) tryCreate(user, database string) *gat.Pool { return p } -func (T *Module) Lookup(user, database string) *gat.Pool { +func (T *Module) lookup(user, database string) *gat.Pool { p, _ := T.pools.Load(user, database) if p != nil { return p @@ -276,6 +277,10 @@ func (T *Module) Lookup(user, database string) *gat.Pool { return T.tryCreate(user, database) } +func (T *Module) Lookup(conn fed.Conn) *gat.Pool { + return T.lookup(conn.User(), conn.Database()) +} + func (T *Module) ReadMetrics(metrics *metrics.Pools) { T.mu.RLock() defer T.mu.RUnlock() diff --git a/lib/gat/server.go b/lib/gat/server.go index 264fe6e53162c667bb0079736584fa9910e5fd23..23a01606feb046f8136c8e68ede24b6c7286747c 100644 --- a/lib/gat/server.go +++ b/lib/gat/server.go @@ -6,6 +6,8 @@ import ( "github.com/caddyserver/caddy/v2" + "gfx.cafe/gfx/pggat/lib/fed" + "gfx.cafe/gfx/pggat/lib/gat/pool" "gfx.cafe/gfx/pggat/lib/util/strutil" ) @@ -45,4 +47,19 @@ func (T *Server) Provision(ctx caddy.Context) error { return nil } +func (T *Server) lookup(conn fed.Conn) *pool.Pool { + for _, route := range T.routes { + if route.match != nil && !route.match.Matches(conn) { + continue + } + + p := route.provide.Lookup(conn) + if p != nil { + return p + } + } + + return nil +} + var _ caddy.Provisioner = (*Server)(nil) diff --git a/lib/gsql/addr.go b/lib/gsql/addr.go new file mode 100644 index 0000000000000000000000000000000000000000..61bcf5d2439817290ec46ef99a1ed926d9003004 --- /dev/null +++ b/lib/gsql/addr.go @@ -0,0 +1,15 @@ +package gsql + +import "net" + +type Addr struct{} + +func (Addr) Network() string { + return "gsql" +} + +func (Addr) String() string { + return "local gsql client" +} + +var _ net.Addr = Addr{} diff --git a/lib/gsql/client.go b/lib/gsql/client.go index 85b5d2d9f81bf72dd2a81b01a7c3dd34a32a9cdc..07d232fe801b9e0e5af75fe20c3f204936496b44 100644 --- a/lib/gsql/client.go +++ b/lib/gsql/client.go @@ -8,6 +8,7 @@ import ( "gfx.cafe/gfx/pggat/lib/fed" "gfx.cafe/gfx/pggat/lib/util/ring" "gfx.cafe/gfx/pggat/lib/util/slices" + "gfx.cafe/gfx/pggat/lib/util/strutil" ) type batch struct { @@ -143,4 +144,28 @@ func (T *Client) Close() error { return nil } +func (T *Client) LocalAddr() net.Addr { + return Addr{} +} + +func (T *Client) RemoteAddr() net.Addr { + return Addr{} +} + +func (T *Client) SSLEnabled() bool { + return false +} + +func (T *Client) User() string { + return "" +} + +func (T *Client) Database() string { + return "" +} + +func (T *Client) InitialParameters() map[strutil.CIString]string { + return nil +} + var _ fed.Conn = (*Client)(nil) diff --git a/lib/middleware/interceptor/interceptor.go b/lib/middleware/interceptor/interceptor.go index 3c13c928eca1e52bd615892d9c217e6006890c78..0a40d2588fff7bb4e120e72ca88859dd4850c391 100644 --- a/lib/middleware/interceptor/interceptor.go +++ b/lib/middleware/interceptor/interceptor.go @@ -1,25 +1,28 @@ package interceptor import ( + "net" + "gfx.cafe/gfx/pggat/lib/fed" "gfx.cafe/gfx/pggat/lib/middleware" + "gfx.cafe/gfx/pggat/lib/util/strutil" ) type Interceptor struct { middlewares []middleware.Middleware context Context - rw fed.Conn + conn fed.Conn } -func NewInterceptor(rw fed.Conn, middlewares ...middleware.Middleware) *Interceptor { - if v, ok := rw.(*Interceptor); ok { +func NewInterceptor(conn fed.Conn, middlewares ...middleware.Middleware) *Interceptor { + if v, ok := conn.(*Interceptor); ok { v.middlewares = append(v.middlewares, middlewares...) return v } return &Interceptor{ middlewares: middlewares, - context: makeContext(rw), - rw: rw, + context: makeContext(conn), + conn: conn, } } @@ -27,7 +30,7 @@ func (T *Interceptor) ReadPacket(typed bool, packet fed.Packet) (fed.Packet, err outer: for { var err error - packet, err = T.rw.ReadPacket(typed, packet) + packet, err = T.conn.ReadPacket(typed, packet) if err != nil { return packet, err } @@ -59,11 +62,35 @@ func (T *Interceptor) WritePacket(packet fed.Packet) error { } } - return T.rw.WritePacket(packet) + return T.conn.WritePacket(packet) +} + +func (T *Interceptor) LocalAddr() net.Addr { + return T.conn.LocalAddr() +} + +func (T *Interceptor) RemoteAddr() net.Addr { + return T.conn.RemoteAddr() +} + +func (T *Interceptor) SSLEnabled() bool { + return T.conn.SSLEnabled() +} + +func (T *Interceptor) User() string { + return T.conn.User() +} + +func (T *Interceptor) Database() string { + return T.conn.Database() +} + +func (T *Interceptor) InitialParameters() map[strutil.CIString]string { + return T.conn.InitialParameters() } func (T *Interceptor) Close() error { - return T.rw.Close() + return T.conn.Close() } var _ fed.Conn = (*Interceptor)(nil)