good morning!!!!

Skip to content
Snippets Groups Projects
Commit 2a97ab9c authored by Garet Halliday's avatar Garet Halliday
Browse files

almost working

parent b42ed0c9
Branches
Tags
No related merge requests found
Showing
with 371 additions and 107 deletions
package frontends package frontends
import ( import (
"fmt" "crypto/tls"
"io" "io"
"strings" "strings"
"gfx.cafe/gfx/pggat/lib/fed" "gfx.cafe/gfx/pggat/lib/fed"
packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0"
"gfx.cafe/gfx/pggat/lib/perror" "gfx.cafe/gfx/pggat/lib/perror"
"gfx.cafe/gfx/pggat/lib/util/slices"
"gfx.cafe/gfx/pggat/lib/util/strutil" "gfx.cafe/gfx/pggat/lib/util/strutil"
) )
func startup0( func startup0(
ctx *AcceptContext, ctx *acceptContext,
params *AcceptParams, params *acceptParams,
) (done bool, err perror.Error) { ) (cancelling bool, done bool, err perror.Error) {
var err2 error var err2 error
ctx.Packet, err2 = ctx.Conn.ReadPacket(false, ctx.Packet) ctx.Packet, err2 = ctx.Conn.ReadPacket(false, ctx.Packet)
if err2 != nil { if err2 != nil {
...@@ -34,18 +33,7 @@ func startup0( ...@@ -34,18 +33,7 @@ func startup0(
case 5678: case 5678:
// Cancel // Cancel
p.ReadBytes(params.CancelKey[:]) p.ReadBytes(params.CancelKey[:])
cancelling = true
if params.CancelKey == [8]byte{} {
// very rare that this would ever happen
// and it's ok if we don't honor cancel requests
err = perror.New(
perror.FATAL,
perror.ProtocolViolation,
"cancel key cannot be null",
)
return
}
done = true done = true
return return
case 5679: case 5679:
...@@ -150,15 +138,6 @@ func startup0( ...@@ -150,15 +138,6 @@ func startup0(
ikey := strutil.MakeCIString(key) ikey := strutil.MakeCIString(key)
if !slices.Contains(ctx.Options.AllowedStartupOptions, ikey) {
err = perror.New(
perror.FATAL,
perror.FeatureNotSupported,
fmt.Sprintf(`Startup parameter "%s" is not allowed`, key),
)
return
}
if params.InitialParameters == nil { if params.InitialParameters == nil {
params.InitialParameters = make(map[strutil.CIString]string) params.InitialParameters = make(map[strutil.CIString]string)
} }
...@@ -186,15 +165,6 @@ func startup0( ...@@ -186,15 +165,6 @@ func startup0(
} else { } else {
ikey := strutil.MakeCIString(key) ikey := strutil.MakeCIString(key)
if !slices.Contains(ctx.Options.AllowedStartupOptions, ikey) {
err = perror.New(
perror.FATAL,
perror.FeatureNotSupported,
fmt.Sprintf(`Startup parameter "%s" is not allowed`, key),
)
return
}
if params.InitialParameters == nil { if params.InitialParameters == nil {
params.InitialParameters = make(map[strutil.CIString]string) params.InitialParameters = make(map[strutil.CIString]string)
} }
...@@ -232,12 +202,12 @@ func startup0( ...@@ -232,12 +202,12 @@ func startup0(
return return
} }
func accept( func accept0(
ctx *AcceptContext, ctx *acceptContext,
) (params AcceptParams, err perror.Error) { ) (params acceptParams, err perror.Error) {
for { for {
var done bool var done bool
done, err = startup0(ctx, &params) params.IsCanceling, done, err = startup0(ctx, &params)
if err != nil { if err != nil {
return return
} }
...@@ -246,23 +216,10 @@ func accept( ...@@ -246,23 +216,10 @@ func accept(
} }
} }
if params.CancelKey != [8]byte{} {
return return
} }
if ctx.Options.SSLRequired && !params.SSLEnabled { func fail(packet fed.Packet, client fed.ReadWriter, err perror.Error) {
err = perror.New(
perror.FATAL,
perror.InvalidPassword,
"SSL is required",
)
return
}
return
}
func fail(packet fed.Packet, client fed.Conn, err perror.Error) {
resp := packets.ErrorResponse{ resp := packets.ErrorResponse{
Error: err, Error: err,
} }
...@@ -270,11 +227,37 @@ func fail(packet fed.Packet, client fed.Conn, err perror.Error) { ...@@ -270,11 +227,37 @@ func fail(packet fed.Packet, client fed.Conn, err perror.Error) {
_ = client.WritePacket(packet) _ = client.WritePacket(packet)
} }
func Accept(ctx *AcceptContext) (AcceptParams, perror.Error) { func accept(ctx *acceptContext) (acceptParams, perror.Error) {
params, err := accept(ctx) params, err := accept0(ctx)
if err != nil { if err != nil {
fail(ctx.Packet, ctx.Conn, err) fail(ctx.Packet, ctx.Conn, err)
return AcceptParams{}, err return acceptParams{}, err
} }
return params, nil return params, nil
} }
func Accept(conn fed.ReadWriter, tlsConfig *tls.Config) (
cancelKey [8]byte,
isCanceling bool,
sslEnabled bool,
user string,
database string,
initialParameters map[strutil.CIString]string,
err perror.Error,
) {
ctx := acceptContext{
Conn: conn,
Options: acceptOptions{
SSLConfig: tlsConfig,
},
}
var params acceptParams
params, err = accept(&ctx)
cancelKey = params.CancelKey
isCanceling = params.IsCanceling
sslEnabled = params.SSLEnabled
user = params.User
database = params.Database
initialParameters = params.InitialParameters
return
}
...@@ -6,11 +6,12 @@ import ( ...@@ -6,11 +6,12 @@ import (
"io" "io"
"gfx.cafe/gfx/pggat/lib/auth" "gfx.cafe/gfx/pggat/lib/auth"
"gfx.cafe/gfx/pggat/lib/fed"
packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0"
"gfx.cafe/gfx/pggat/lib/perror" "gfx.cafe/gfx/pggat/lib/perror"
) )
func authenticationSASLInitial(ctx *AuthenticateContext, creds auth.SASLServer) (tool auth.SASLVerifier, resp []byte, done bool, err perror.Error) { func authenticationSASLInitial(ctx *authenticateContext, creds auth.SASLServer) (tool auth.SASLVerifier, resp []byte, done bool, err perror.Error) {
// check which authentication method the client wants // check which authentication method the client wants
var err2 error var err2 error
ctx.Packet, err2 = ctx.Conn.ReadPacket(true, ctx.Packet) ctx.Packet, err2 = ctx.Conn.ReadPacket(true, ctx.Packet)
...@@ -42,7 +43,7 @@ func authenticationSASLInitial(ctx *AuthenticateContext, creds auth.SASLServer) ...@@ -42,7 +43,7 @@ func authenticationSASLInitial(ctx *AuthenticateContext, creds auth.SASLServer)
return return
} }
func authenticationSASLContinue(ctx *AuthenticateContext, tool auth.SASLVerifier) (resp []byte, done bool, err perror.Error) { func authenticationSASLContinue(ctx *authenticateContext, tool auth.SASLVerifier) (resp []byte, done bool, err perror.Error) {
var err2 error var err2 error
ctx.Packet, err2 = ctx.Conn.ReadPacket(true, ctx.Packet) ctx.Packet, err2 = ctx.Conn.ReadPacket(true, ctx.Packet)
if err2 != nil { if err2 != nil {
...@@ -67,7 +68,7 @@ func authenticationSASLContinue(ctx *AuthenticateContext, tool auth.SASLVerifier ...@@ -67,7 +68,7 @@ func authenticationSASLContinue(ctx *AuthenticateContext, tool auth.SASLVerifier
return return
} }
func authenticationSASL(ctx *AuthenticateContext, creds auth.SASLServer) perror.Error { func authenticationSASL(ctx *authenticateContext, creds auth.SASLServer) perror.Error {
saslInitial := packets.AuthenticationSASL{ saslInitial := packets.AuthenticationSASL{
Mechanisms: creds.SupportedSASLMechanisms(), Mechanisms: creds.SupportedSASLMechanisms(),
} }
...@@ -109,7 +110,7 @@ func authenticationSASL(ctx *AuthenticateContext, creds auth.SASLServer) perror. ...@@ -109,7 +110,7 @@ func authenticationSASL(ctx *AuthenticateContext, creds auth.SASLServer) perror.
return nil return nil
} }
func authenticationMD5(ctx *AuthenticateContext, creds auth.MD5Server) perror.Error { func authenticationMD5(ctx *authenticateContext, creds auth.MD5Server) perror.Error {
var salt [4]byte var salt [4]byte
_, err := rand.Read(salt[:]) _, err := rand.Read(salt[:])
if err != nil { if err != nil {
...@@ -141,7 +142,7 @@ func authenticationMD5(ctx *AuthenticateContext, creds auth.MD5Server) perror.Er ...@@ -141,7 +142,7 @@ func authenticationMD5(ctx *AuthenticateContext, creds auth.MD5Server) perror.Er
return nil return nil
} }
func authenticate(ctx *AuthenticateContext) (params AuthenticateParams, err perror.Error) { func authenticate0(ctx *authenticateContext) (params authenticateParams, err perror.Error) {
if ctx.Options.Credentials != nil { if ctx.Options.Credentials != nil {
if credsSASL, ok := ctx.Options.Credentials.(auth.SASLServer); ok { if credsSASL, ok := ctx.Options.Credentials.(auth.SASLServer); ok {
err = authenticationSASL(ctx, credsSASL) err = authenticationSASL(ctx, credsSASL)
...@@ -184,11 +185,24 @@ func authenticate(ctx *AuthenticateContext) (params AuthenticateParams, err perr ...@@ -184,11 +185,24 @@ func authenticate(ctx *AuthenticateContext) (params AuthenticateParams, err perr
return return
} }
func Authenticate(ctx *AuthenticateContext) (AuthenticateParams, perror.Error) { func authenticate(ctx *authenticateContext) (authenticateParams, perror.Error) {
params, err := authenticate(ctx) params, err := authenticate0(ctx)
if err != nil { if err != nil {
fail(ctx.Packet, ctx.Conn, err) fail(ctx.Packet, ctx.Conn, err)
return AuthenticateParams{}, err return authenticateParams{}, err
} }
return params, nil return params, nil
} }
func Authenticate(conn fed.ReadWriter, creds auth.Credentials) (backendKey [8]byte, err perror.Error) {
ctx := authenticateContext{
Conn: conn,
Options: authenticateOptions{
Credentials: creds,
},
}
var params authenticateParams
params, err = authenticate(&ctx)
backendKey = params.BackendKey
return
}
...@@ -2,14 +2,14 @@ package frontends ...@@ -2,14 +2,14 @@ package frontends
import "gfx.cafe/gfx/pggat/lib/fed" import "gfx.cafe/gfx/pggat/lib/fed"
type AcceptContext struct { type acceptContext struct {
Packet fed.Packet Packet fed.Packet
Conn fed.Conn Conn fed.ReadWriter
Options AcceptOptions Options acceptOptions
} }
type AuthenticateContext struct { type authenticateContext struct {
Packet fed.Packet Packet fed.Packet
Conn fed.Conn Conn fed.ReadWriter
Options AuthenticateOptions Options authenticateOptions
} }
...@@ -4,15 +4,12 @@ import ( ...@@ -4,15 +4,12 @@ import (
"crypto/tls" "crypto/tls"
"gfx.cafe/gfx/pggat/lib/auth" "gfx.cafe/gfx/pggat/lib/auth"
"gfx.cafe/gfx/pggat/lib/util/strutil"
) )
type AcceptOptions struct { type acceptOptions struct {
SSLRequired bool
SSLConfig *tls.Config SSLConfig *tls.Config
AllowedStartupOptions []strutil.CIString
} }
type AuthenticateOptions struct { type authenticateOptions struct {
Credentials auth.Credentials Credentials auth.Credentials
} }
...@@ -2,8 +2,9 @@ package frontends ...@@ -2,8 +2,9 @@ package frontends
import "gfx.cafe/gfx/pggat/lib/util/strutil" import "gfx.cafe/gfx/pggat/lib/util/strutil"
type AcceptParams struct { type acceptParams struct {
CancelKey [8]byte CancelKey [8]byte
IsCanceling bool
// or // or
...@@ -13,6 +14,6 @@ type AcceptParams struct { ...@@ -13,6 +14,6 @@ type AcceptParams struct {
InitialParameters map[strutil.CIString]string InitialParameters map[strutil.CIString]string
} }
type AuthenticateParams struct { type authenticateParams struct {
BackendKey [8]byte BackendKey [8]byte
} }
...@@ -9,24 +9,38 @@ import ( ...@@ -9,24 +9,38 @@ import (
"net" "net"
"gfx.cafe/gfx/pggat/lib/util/slices" "gfx.cafe/gfx/pggat/lib/util/slices"
"gfx.cafe/gfx/pggat/lib/util/strutil"
) )
type Conn interface { type Conn interface {
ReadWriter ReadWriter
LocalAddr() net.Addr
RemoteAddr() net.Addr
SSLEnabled() bool
User() string
Database() string
InitialParameters() map[strutil.CIString]string
Close() error Close() error
} }
type netConn struct { type NetConn struct {
conn net.Conn conn net.Conn
writer bufio.Writer writer bufio.Writer
reader bufio.Reader reader bufio.Reader
sslEnabled bool
user string
database string
initialParameters map[strutil.CIString]string
headerBuf [5]byte headerBuf [5]byte
} }
func WrapNetConn(conn net.Conn) Conn { func WrapNetConn(conn net.Conn) *NetConn {
c := &netConn{ c := &NetConn{
conn: conn, conn: conn,
} }
c.writer.Reset(conn) c.writer.Reset(conn)
...@@ -34,7 +48,50 @@ func WrapNetConn(conn net.Conn) Conn { ...@@ -34,7 +48,50 @@ func WrapNetConn(conn net.Conn) Conn {
return c return c
} }
func (T *netConn) EnableSSLClient(config *tls.Config) error { func (T *NetConn) LocalAddr() net.Addr {
return T.conn.LocalAddr()
}
func (T *NetConn) RemoteAddr() net.Addr {
return T.conn.RemoteAddr()
}
func (T *NetConn) SSLEnabled() bool {
return T.sslEnabled
}
func (T *NetConn) User() string {
return T.user
}
func (T *NetConn) SetUser(user string) {
T.user = user
}
func (T *NetConn) Database() string {
return T.database
}
func (T *NetConn) SetDatabase(database string) {
T.database = database
}
func (T *NetConn) InitialParameters() map[strutil.CIString]string {
return T.initialParameters
}
func (T *NetConn) SetInitialParameters(initialParameters map[strutil.CIString]string) {
T.initialParameters = initialParameters
}
var errSSLAlreadyEnabled = errors.New("ssl is already enabled")
func (T *NetConn) EnableSSLClient(config *tls.Config) error {
if T.sslEnabled {
return errSSLAlreadyEnabled
}
T.sslEnabled = true
if err := T.writer.Flush(); err != nil { if err := T.writer.Flush(); err != nil {
return err return err
} }
...@@ -48,7 +105,12 @@ func (T *netConn) EnableSSLClient(config *tls.Config) error { ...@@ -48,7 +105,12 @@ func (T *netConn) EnableSSLClient(config *tls.Config) error {
return sslConn.Handshake() return sslConn.Handshake()
} }
func (T *netConn) EnableSSLServer(config *tls.Config) error { func (T *NetConn) EnableSSLServer(config *tls.Config) error {
if T.sslEnabled {
return errSSLAlreadyEnabled
}
T.sslEnabled = true
if err := T.writer.Flush(); err != nil { if err := T.writer.Flush(); err != nil {
return err return err
} }
...@@ -62,14 +124,14 @@ func (T *netConn) EnableSSLServer(config *tls.Config) error { ...@@ -62,14 +124,14 @@ func (T *netConn) EnableSSLServer(config *tls.Config) error {
return sslConn.Handshake() return sslConn.Handshake()
} }
func (T *netConn) ReadByte() (byte, error) { func (T *NetConn) ReadByte() (byte, error) {
if err := T.writer.Flush(); err != nil { if err := T.writer.Flush(); err != nil {
return 0, err return 0, err
} }
return T.reader.ReadByte() return T.reader.ReadByte()
} }
func (T *netConn) ReadPacket(typed bool, buffer Packet) (packet Packet, err error) { func (T *NetConn) ReadPacket(typed bool, buffer Packet) (packet Packet, err error) {
packet = buffer packet = buffer
if err = T.writer.Flush(); err != nil { if err = T.writer.Flush(); err != nil {
...@@ -100,24 +162,24 @@ func (T *netConn) ReadPacket(typed bool, buffer Packet) (packet Packet, err erro ...@@ -100,24 +162,24 @@ func (T *netConn) ReadPacket(typed bool, buffer Packet) (packet Packet, err erro
return return
} }
func (T *netConn) WriteByte(b byte) error { func (T *NetConn) WriteByte(b byte) error {
return T.writer.WriteByte(b) return T.writer.WriteByte(b)
} }
func (T *netConn) WritePacket(packet Packet) error { func (T *NetConn) WritePacket(packet Packet) error {
_, err := T.writer.Write(packet.Bytes()) _, err := T.writer.Write(packet.Bytes())
return err return err
} }
func (T *netConn) Close() error { func (T *NetConn) Close() error {
if err := T.writer.Flush(); err != nil { if err := T.writer.Flush(); err != nil {
return err return err
} }
return T.conn.Close() return T.conn.Close()
} }
var _ Conn = (*netConn)(nil) var _ Conn = (*NetConn)(nil)
var _ SSLServer = (*netConn)(nil) var _ SSLServer = (*NetConn)(nil)
var _ SSLClient = (*netConn)(nil) var _ SSLClient = (*NetConn)(nil)
var _ io.ByteReader = (*netConn)(nil) var _ io.ByteReader = (*NetConn)(nil)
var _ io.ByteWriter = (*netConn)(nil) var _ io.ByteWriter = (*NetConn)(nil)
package gat package gat
import ( import (
"github.com/caddyserver/caddy/v2" "crypto/tls"
"errors"
"fmt"
"net"
"github.com/caddyserver/caddy/v2"
"tuxpa.in/a/zlog/log"
"gfx.cafe/gfx/pggat/lib/bouncer/frontends/v0"
"gfx.cafe/gfx/pggat/lib/fed"
packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0"
"gfx.cafe/gfx/pggat/lib/middleware/interceptor"
"gfx.cafe/gfx/pggat/lib/middleware/middlewares/unterminate"
"gfx.cafe/gfx/pggat/lib/perror"
"gfx.cafe/gfx/pggat/lib/util/dur" "gfx.cafe/gfx/pggat/lib/util/dur"
"gfx.cafe/gfx/pggat/lib/util/maps" "gfx.cafe/gfx/pggat/lib/util/maps"
"gfx.cafe/gfx/pggat/lib/util/slices"
) )
type Config struct { type Config struct {
...@@ -61,12 +74,124 @@ func (T *App) Provision(ctx caddy.Context) error { ...@@ -61,12 +74,124 @@ func (T *App) Provision(ctx caddy.Context) error {
return nil return nil
} }
func (T *App) cancel(key [8]byte) {
p, _ := T.keys.Load(key)
if p == nil {
return
}
_ = p.Cancel(key)
}
func (T *App) serve(server *Server, conn fed.Conn) {
initialParameters := conn.InitialParameters()
for key := range initialParameters {
if !slices.Contains(server.AllowedStartupParameters, key) {
errResp := packets.ErrorResponse{
Error: perror.New(
perror.FATAL,
perror.FeatureNotSupported,
fmt.Sprintf(`Startup parameter "%s" is not allowed`, key),
),
}
_ = conn.WritePacket(errResp.IntoPacket(nil))
return
}
}
p := server.lookup(conn)
if p == nil {
log.Printf("pool not found for client: user=%s database=%s", conn.User(), conn.Database())
return
}
backendKey, err := frontends.Authenticate(conn, p.Credentials())
if err != nil {
log.Printf("error authenticating client: %v", err)
return
}
T.keys.Store(backendKey, p)
defer T.keys.Delete(backendKey)
if err2 := p.Serve(conn, backendKey); err2 != nil {
log.Printf("error serving client: %v", err2)
return
}
}
func (T *App) accept(listener *Listener, conn *fed.NetConn) {
defer func() {
_ = conn.Close()
}()
var tlsConfig *tls.Config
if listener.ssl != nil {
tlsConfig = listener.ssl.ServerTLSConfig()
}
cancelKey, isCanceling, _, user, database, initialParameters, err := frontends.Accept(conn, tlsConfig)
if err != nil {
log.Printf("error accepting client: %v", err)
return
}
if isCanceling {
T.cancel(cancelKey)
return
}
conn.SetUser(user)
conn.SetDatabase(database)
conn.SetInitialParameters(initialParameters)
for _, server := range T.servers {
if server.match == nil || server.match.Matches(conn) {
T.serve(server, interceptor.NewInterceptor(conn, unterminate.Unterminate))
return
}
}
log.Printf("server not found for client: user=%s database=%s", conn.User(), conn.Database())
errResp := packets.ErrorResponse{
Error: perror.New(
perror.FATAL,
perror.InternalError,
"No server is available to handle your request",
),
}
_ = conn.WritePacket(errResp.IntoPacket(nil))
}
func (T *App) acceptFrom(listener *Listener) bool {
conn, err := listener.accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return false
}
log.Printf("error accepting client: %v", err)
return true
}
go T.accept(listener, conn)
return true
}
func (T *App) Start() error { func (T *App) Start() error {
// start listeners // start listeners
for _, listener := range T.listen { for _, listener := range T.listen {
if err := listener.Start(); err != nil { if err := listener.Start(); err != nil {
return err return err
} }
go func(listener *Listener) {
for {
if !T.acceptFrom(listener) {
break
}
}
}(listener)
} }
return nil return nil
......
...@@ -7,6 +7,8 @@ import ( ...@@ -7,6 +7,8 @@ import (
"github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2"
"tuxpa.in/a/zlog/log" "tuxpa.in/a/zlog/log"
"gfx.cafe/gfx/pggat/lib/fed"
) )
type ListenerConfig struct { type ListenerConfig struct {
...@@ -23,6 +25,14 @@ type Listener struct { ...@@ -23,6 +25,14 @@ type Listener struct {
listener net.Listener listener net.Listener
} }
func (T *Listener) accept() (*fed.NetConn, error) {
raw, err := T.listener.Accept()
if err != nil {
return nil, err
}
return fed.WrapNetConn(raw), nil
}
func (T *Listener) Provision(ctx caddy.Context) error { func (T *Listener) Provision(ctx caddy.Context) error {
if T.SSL != nil { if T.SSL != nil {
val, err := ctx.LoadModule(T, "SSL") val, err := ctx.LoadModule(T, "SSL")
......
package gat package gat
import "gfx.cafe/gfx/pggat/lib/fed"
type Matcher interface { type Matcher interface {
Matches(conn fed.Conn) bool
} }
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2"
"gfx.cafe/gfx/pggat/lib/fed"
"gfx.cafe/gfx/pggat/lib/gat" "gfx.cafe/gfx/pggat/lib/gat"
) )
...@@ -44,6 +45,15 @@ func (T *And) Provision(ctx caddy.Context) error { ...@@ -44,6 +45,15 @@ func (T *And) Provision(ctx caddy.Context) error {
return nil return nil
} }
func (T *And) Matches(conn fed.Conn) bool {
for _, matcher := range T.and {
if !matcher.Matches(conn) {
return false
}
}
return true
}
var _ gat.Matcher = (*And)(nil) var _ gat.Matcher = (*And)(nil)
var _ caddy.Module = (*And)(nil) var _ caddy.Module = (*And)(nil)
var _ caddy.Provisioner = (*And)(nil) var _ caddy.Provisioner = (*And)(nil)
...@@ -3,6 +3,7 @@ package matchers ...@@ -3,6 +3,7 @@ package matchers
import ( import (
"github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2"
"gfx.cafe/gfx/pggat/lib/fed"
"gfx.cafe/gfx/pggat/lib/gat" "gfx.cafe/gfx/pggat/lib/gat"
) )
...@@ -23,5 +24,9 @@ func (T *Database) CaddyModule() caddy.ModuleInfo { ...@@ -23,5 +24,9 @@ func (T *Database) CaddyModule() caddy.ModuleInfo {
} }
} }
func (T *Database) Matches(conn fed.Conn) bool {
return conn.Database() == T.Database
}
var _ gat.Matcher = (*Database)(nil) var _ gat.Matcher = (*Database)(nil)
var _ caddy.Module = (*Database)(nil) var _ caddy.Module = (*Database)(nil)
...@@ -3,6 +3,7 @@ package matchers ...@@ -3,6 +3,7 @@ package matchers
import ( import (
"github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2"
"gfx.cafe/gfx/pggat/lib/fed"
"gfx.cafe/gfx/pggat/lib/gat" "gfx.cafe/gfx/pggat/lib/gat"
) )
...@@ -23,5 +24,10 @@ func (T *LocalAddress) CaddyModule() caddy.ModuleInfo { ...@@ -23,5 +24,10 @@ func (T *LocalAddress) CaddyModule() caddy.ModuleInfo {
} }
} }
func (T *LocalAddress) Matches(conn fed.Conn) bool {
// TODO(garet)
return true
}
var _ gat.Matcher = (*LocalAddress)(nil) var _ gat.Matcher = (*LocalAddress)(nil)
var _ caddy.Module = (*LocalAddress)(nil) var _ caddy.Module = (*LocalAddress)(nil)
...@@ -3,6 +3,7 @@ package matchers ...@@ -3,6 +3,7 @@ package matchers
import ( import (
"github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2"
"gfx.cafe/gfx/pggat/lib/fed"
"gfx.cafe/gfx/pggat/lib/gat" "gfx.cafe/gfx/pggat/lib/gat"
) )
...@@ -23,5 +24,9 @@ func (T *Network) CaddyModule() caddy.ModuleInfo { ...@@ -23,5 +24,9 @@ func (T *Network) CaddyModule() caddy.ModuleInfo {
} }
} }
func (T *Network) Matches(conn fed.Conn) bool {
return conn.LocalAddr().Network() == T.Network
}
var _ gat.Matcher = (*Network)(nil) var _ gat.Matcher = (*Network)(nil)
var _ caddy.Module = (*Network)(nil) var _ caddy.Module = (*Network)(nil)
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2"
"gfx.cafe/gfx/pggat/lib/fed"
"gfx.cafe/gfx/pggat/lib/gat" "gfx.cafe/gfx/pggat/lib/gat"
) )
...@@ -44,6 +45,15 @@ func (T *Or) Provision(ctx caddy.Context) error { ...@@ -44,6 +45,15 @@ func (T *Or) Provision(ctx caddy.Context) error {
return nil return nil
} }
func (T *Or) Matches(conn fed.Conn) bool {
for _, matcher := range T.or {
if matcher.Matches(conn) {
return true
}
}
return false
}
var _ gat.Matcher = (*Or)(nil) var _ gat.Matcher = (*Or)(nil)
var _ caddy.Module = (*Or)(nil) var _ caddy.Module = (*Or)(nil)
var _ caddy.Provisioner = (*Or)(nil) var _ caddy.Provisioner = (*Or)(nil)
...@@ -3,6 +3,7 @@ package matchers ...@@ -3,6 +3,7 @@ package matchers
import ( import (
"github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2"
"gfx.cafe/gfx/pggat/lib/fed"
"gfx.cafe/gfx/pggat/lib/gat" "gfx.cafe/gfx/pggat/lib/gat"
) )
...@@ -23,5 +24,9 @@ func (T *SSL) CaddyModule() caddy.ModuleInfo { ...@@ -23,5 +24,9 @@ func (T *SSL) CaddyModule() caddy.ModuleInfo {
} }
} }
func (T *SSL) Matches(conn fed.Conn) bool {
return conn.SSLEnabled() == T.SSL
}
var _ gat.Matcher = (*SSL)(nil) var _ gat.Matcher = (*SSL)(nil)
var _ caddy.Module = (*SSL)(nil) var _ caddy.Module = (*SSL)(nil)
...@@ -3,7 +3,9 @@ package matchers ...@@ -3,7 +3,9 @@ package matchers
import ( import (
"github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2"
"gfx.cafe/gfx/pggat/lib/fed"
"gfx.cafe/gfx/pggat/lib/gat" "gfx.cafe/gfx/pggat/lib/gat"
"gfx.cafe/gfx/pggat/lib/util/strutil"
) )
func init() { func init() {
...@@ -12,6 +14,8 @@ func init() { ...@@ -12,6 +14,8 @@ func init() {
type StartupParameters struct { type StartupParameters struct {
Parameters map[string]string `json:"startup_parameters"` Parameters map[string]string `json:"startup_parameters"`
parameters map[strutil.CIString]string
} }
func (T *StartupParameters) CaddyModule() caddy.ModuleInfo { func (T *StartupParameters) CaddyModule() caddy.ModuleInfo {
...@@ -23,5 +27,25 @@ func (T *StartupParameters) CaddyModule() caddy.ModuleInfo { ...@@ -23,5 +27,25 @@ func (T *StartupParameters) CaddyModule() caddy.ModuleInfo {
} }
} }
func (T *StartupParameters) Provision(ctx caddy.Context) error {
T.parameters = make(map[strutil.CIString]string, len(T.Parameters))
for key, value := range T.Parameters {
T.parameters[strutil.MakeCIString(key)] = value
}
return nil
}
func (T *StartupParameters) Matches(conn fed.Conn) bool {
initialParameters := conn.InitialParameters()
for key, value := range T.parameters {
if initialParameters[key] != value {
return false
}
}
return true
}
var _ gat.Matcher = (*StartupParameters)(nil) var _ gat.Matcher = (*StartupParameters)(nil)
var _ caddy.Module = (*StartupParameters)(nil) var _ caddy.Module = (*StartupParameters)(nil)
var _ caddy.Provisioner = (*StartupParameters)(nil)
...@@ -3,6 +3,7 @@ package matchers ...@@ -3,6 +3,7 @@ package matchers
import ( import (
"github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2"
"gfx.cafe/gfx/pggat/lib/fed"
"gfx.cafe/gfx/pggat/lib/gat" "gfx.cafe/gfx/pggat/lib/gat"
) )
...@@ -23,5 +24,9 @@ func (T *User) CaddyModule() caddy.ModuleInfo { ...@@ -23,5 +24,9 @@ func (T *User) CaddyModule() caddy.ModuleInfo {
} }
} }
func (T *User) Matches(conn fed.Conn) bool {
return conn.User() == T.User
}
var _ gat.Matcher = (*User)(nil) var _ gat.Matcher = (*User)(nil)
var _ caddy.Module = (*User)(nil) var _ caddy.Module = (*User)(nil)
...@@ -7,7 +7,6 @@ import ( ...@@ -7,7 +7,6 @@ import (
"gfx.cafe/gfx/pggat/lib/middleware/middlewares/eqp" "gfx.cafe/gfx/pggat/lib/middleware/middlewares/eqp"
"gfx.cafe/gfx/pggat/lib/middleware/middlewares/ps" "gfx.cafe/gfx/pggat/lib/middleware/middlewares/ps"
"gfx.cafe/gfx/pggat/lib/middleware/middlewares/unterminate" "gfx.cafe/gfx/pggat/lib/middleware/middlewares/unterminate"
"gfx.cafe/gfx/pggat/lib/util/strutil"
) )
type pooledClient struct { type pooledClient struct {
...@@ -20,13 +19,14 @@ type pooledClient struct { ...@@ -20,13 +19,14 @@ type pooledClient struct {
func newClient( func newClient(
options Options, options Options,
conn fed.Conn, conn fed.Conn,
initialParameters map[strutil.CIString]string,
backendKey [8]byte, backendKey [8]byte,
) *pooledClient { ) *pooledClient {
middlewares := []middleware.Middleware{ middlewares := []middleware.Middleware{
unterminate.Unterminate, unterminate.Unterminate,
} }
initialParameters := conn.InitialParameters()
var psClient *ps.Client var psClient *ps.Client
if options.ParameterStatusSync == ParameterStatusSyncDynamic { if options.ParameterStatusSync == ParameterStatusSyncDynamic {
// add ps middleware // add ps middleware
......
...@@ -16,7 +16,6 @@ import ( ...@@ -16,7 +16,6 @@ import (
packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0"
"gfx.cafe/gfx/pggat/lib/gat/metrics" "gfx.cafe/gfx/pggat/lib/gat/metrics"
"gfx.cafe/gfx/pggat/lib/util/slices" "gfx.cafe/gfx/pggat/lib/util/slices"
"gfx.cafe/gfx/pggat/lib/util/strutil"
) )
type Pool struct { type Pool struct {
...@@ -79,7 +78,7 @@ func (T *Pool) idlest() (server *pooledServer, at time.Time) { ...@@ -79,7 +78,7 @@ func (T *Pool) idlest() (server *pooledServer, at time.Time) {
return return
} }
func (T *Pool) GetCredentials() auth.Credentials { func (T *Pool) Credentials() auth.Credentials {
return T.options.Credentials return T.options.Credentials
} }
...@@ -284,7 +283,6 @@ func (T *Pool) releaseServer(server *pooledServer) { ...@@ -284,7 +283,6 @@ func (T *Pool) releaseServer(server *pooledServer) {
func (T *Pool) Serve( func (T *Pool) Serve(
conn fed.Conn, conn fed.Conn,
initialParameters map[strutil.CIString]string,
backendKey [8]byte, backendKey [8]byte,
) error { ) error {
defer func() { defer func() {
...@@ -294,7 +292,6 @@ func (T *Pool) Serve( ...@@ -294,7 +292,6 @@ func (T *Pool) Serve(
client := newClient( client := newClient(
T.options, T.options,
conn, conn,
initialParameters,
backendKey, backendKey,
) )
...@@ -313,7 +310,6 @@ func (T *Pool) ServeBot( ...@@ -313,7 +310,6 @@ func (T *Pool) ServeBot(
client := newClient( client := newClient(
T.options, T.options,
conn, conn,
nil,
[8]byte{}, [8]byte{},
) )
......
package gat package gat
import "gfx.cafe/gfx/pggat/lib/gat/metrics" import (
"gfx.cafe/gfx/pggat/lib/fed"
"gfx.cafe/gfx/pggat/lib/gat/metrics"
)
// Provider provides pool to the server // Provider provides pool to the server
type Provider interface { type Provider interface {
Lookup(user, database string) *Pool Lookup(conn fed.Conn) *Pool
ReadMetrics(metrics *metrics.Pools) ReadMetrics(metrics *metrics.Pools)
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment