From e4f94d9e5e4ba4be8b8944bb1cafdfec63934940 Mon Sep 17 00:00:00 2001 From: a <a@tuxpa.in> Date: Mon, 17 Jun 2024 17:52:05 -0500 Subject: [PATCH] noot --- lib/gat/handler.go | 17 +++- .../allowed_startup_parameters/module.go | 23 +++--- lib/gat/handlers/discovery/module.go | 22 +++--- lib/gat/handlers/error/module.go | 14 ++-- lib/gat/handlers/pgbouncer/module.go | 79 ++++++++++--------- lib/gat/handlers/pool/module.go | 14 ++-- lib/gat/handlers/require_ssl/module.go | 32 ++++---- lib/gat/handlers/rewrite_database/module.go | 23 +++--- lib/gat/handlers/rewrite_parameter/module.go | 16 ++-- lib/gat/handlers/rewrite_password/module.go | 12 +-- lib/gat/handlers/rewrite_user/module.go | 23 +++--- lib/gat/server.go | 58 +++++++------- lib/instrumentation/prom/metrics.go | 8 +- 13 files changed, 184 insertions(+), 157 deletions(-) diff --git a/lib/gat/handler.go b/lib/gat/handler.go index b0d2d339..41d1ed1f 100644 --- a/lib/gat/handler.go +++ b/lib/gat/handler.go @@ -9,7 +9,22 @@ import ( type Handler interface { // Handle will attempt to handle the Conn. Return io.EOF for normal disconnection or nil to continue to the next // handle. The error will be relayed to the client so there is no need to send it yourself. - Handle(conn *fed.Conn) error + Handle(Router) Router +} + +type HandlerFunc func(Router) Router + +func (H HandlerFunc) Handle(next Router) Router { + return H(next) +} + +type Router interface { + Route(conn *fed.Conn) error +} +type RouterFunc func(conn *fed.Conn) error + +func (R RouterFunc) Route(conn *fed.Conn) error { + return R(conn) } type CancellableHandler interface { diff --git a/lib/gat/handlers/allowed_startup_parameters/module.go b/lib/gat/handlers/allowed_startup_parameters/module.go index d797381b..03fc3f8e 100644 --- a/lib/gat/handlers/allowed_startup_parameters/module.go +++ b/lib/gat/handlers/allowed_startup_parameters/module.go @@ -29,18 +29,19 @@ func (T *Module) CaddyModule() caddy.ModuleInfo { } } -func (T *Module) Handle(conn *fed.Conn) error { - for parameter := range conn.InitialParameters { - if !slices.Contains(T.Parameters, parameter) { - return perror.New( - perror.FATAL, - perror.FeatureNotSupported, - fmt.Sprintf(`Startup parameter "%s" is not supported`, parameter.String()), - ) +func (T *Module) Handle(next gat.Router) gat.Router { + return gat.RouterFunc(func(conn *fed.Conn) error { + for parameter := range conn.InitialParameters { + if !slices.Contains(T.Parameters, parameter) { + return perror.New( + perror.FATAL, + perror.FeatureNotSupported, + fmt.Sprintf(`Startup parameter "%s" is not supported`, parameter.String()), + ) + } } - } - - return nil + return next.Route(conn) + }) } var _ gat.Handler = (*Module)(nil) diff --git a/lib/gat/handlers/discovery/module.go b/lib/gat/handlers/discovery/module.go index edbe3d39..193f0c60 100644 --- a/lib/gat/handlers/discovery/module.go +++ b/lib/gat/handlers/discovery/module.go @@ -568,17 +568,17 @@ func (T *Module) ReadMetrics(metrics *metrics.Handler) { }) } -func (T *Module) Handle(conn *fed.Conn) error { - p, ok := T.getPool(conn.User, conn.Database) - if !ok { - return nil - } - - if err := frontends.Authenticate(conn, p.creds); err != nil { - return err - } - - return p.pool.Serve(conn) +func (T *Module) Handle(next gat.Router) gat.Router { + return gat.RouterFunc(func(conn *fed.Conn) error { + p, ok := T.getPool(conn.User, conn.Database) + if !ok { + return next.Route(conn) + } + if err := frontends.Authenticate(conn, p.creds); err != nil { + return err + } + return p.pool.Serve(conn) + }) } func (T *Module) Cancel(key fed.BackendKey) { diff --git a/lib/gat/handlers/error/module.go b/lib/gat/handlers/error/module.go index b8eedbf7..42c7b088 100644 --- a/lib/gat/handlers/error/module.go +++ b/lib/gat/handlers/error/module.go @@ -24,12 +24,14 @@ func (T *Module) CaddyModule() caddy.ModuleInfo { } } -func (T *Module) Handle(_ *fed.Conn) error { - return perror.New( - perror.FATAL, - perror.InternalError, - T.Message, - ) +func (T *Module) Handle(gat.Router) gat.Router { + return gat.RouterFunc(func(c *fed.Conn) error { + return perror.New( + perror.FATAL, + perror.InternalError, + T.Message, + ) + }) } var _ gat.Handler = (*Module)(nil) diff --git a/lib/gat/handlers/pgbouncer/module.go b/lib/gat/handlers/pgbouncer/module.go index 057eea53..49b54d9a 100644 --- a/lib/gat/handlers/pgbouncer/module.go +++ b/lib/gat/handlers/pgbouncer/module.go @@ -290,53 +290,54 @@ func (T *Module) lookup(user, database string) (poolAndCredentials, bool) { return T.tryCreate(user, database) } -func (T *Module) Handle(conn *fed.Conn) error { - // check ssl - if T.Config.PgBouncer.ClientTLSSSLMode.IsRequired() { - if !conn.SSL { +func (T *Module) Handle(next gat.Router) gat.Router { + return gat.RouterFunc(func(conn *fed.Conn) error { + // check ssl + if T.Config.PgBouncer.ClientTLSSSLMode.IsRequired() { + if !conn.SSL { + return perror.New( + perror.FATAL, + perror.InvalidPassword, + "SSL is required", + ) + } + } + // check startup parameters + for key := range conn.InitialParameters { + if slices.Contains([]strutil.CIString{ + strutil.MakeCIString("client_encoding"), + strutil.MakeCIString("datestyle"), + strutil.MakeCIString("timezone"), + strutil.MakeCIString("standard_conforming_strings"), + strutil.MakeCIString("application_name"), + }, key) { + continue + } + if slices.Contains(T.Config.PgBouncer.TrackExtraParameters, key) { + continue + } + if slices.Contains(T.Config.PgBouncer.IgnoreStartupParameters, key) { + continue + } + return perror.New( perror.FATAL, - perror.InvalidPassword, - "SSL is required", + perror.FeatureNotSupported, + fmt.Sprintf(`Startup parameter "%s" is not supported`, key.String()), ) } - } - // check startup parameters - for key := range conn.InitialParameters { - if slices.Contains([]strutil.CIString{ - strutil.MakeCIString("client_encoding"), - strutil.MakeCIString("datestyle"), - strutil.MakeCIString("timezone"), - strutil.MakeCIString("standard_conforming_strings"), - strutil.MakeCIString("application_name"), - }, key) { - continue - } - if slices.Contains(T.Config.PgBouncer.TrackExtraParameters, key) { - continue - } - if slices.Contains(T.Config.PgBouncer.IgnoreStartupParameters, key) { - continue + p, ok := T.lookup(conn.User, conn.Database) + if !ok { + return next.Route(conn) } - return perror.New( - perror.FATAL, - perror.FeatureNotSupported, - fmt.Sprintf(`Startup parameter "%s" is not supported`, key.String()), - ) - } - - p, ok := T.lookup(conn.User, conn.Database) - if !ok { - return nil - } - - if err := frontends.Authenticate(conn, p.creds); err != nil { - return err - } + if err := frontends.Authenticate(conn, p.creds); err != nil { + return err + } - return p.pool.Serve(conn) + return p.pool.Serve(conn) + }) } func (T *Module) ReadMetrics(metrics *metrics.Handler) { diff --git a/lib/gat/handlers/pool/module.go b/lib/gat/handlers/pool/module.go index ee60787c..6f44c1a6 100644 --- a/lib/gat/handlers/pool/module.go +++ b/lib/gat/handlers/pool/module.go @@ -49,12 +49,14 @@ func (T *Module) Provision(ctx caddy.Context) error { return nil } -func (T *Module) Handle(conn *fed.Conn) error { - if err := frontends.Authenticate(conn, nil); err != nil { - return err - } - - return T.pool.Serve(conn) +func (T *Module) Handle(next gat.Router) gat.Router { + return gat.RouterFunc(func(c *fed.Conn) error { + if err := frontends.Authenticate(c, nil); err != nil { + return err + } + + return T.pool.Serve(c) + }) } func (T *Module) ReadMetrics(metrics *metrics.Handler) { diff --git a/lib/gat/handlers/require_ssl/module.go b/lib/gat/handlers/require_ssl/module.go index 476a8285..bd73c154 100644 --- a/lib/gat/handlers/require_ssl/module.go +++ b/lib/gat/handlers/require_ssl/module.go @@ -25,26 +25,28 @@ func (T *Module) CaddyModule() caddy.ModuleInfo { } } -func (T *Module) Handle(conn *fed.Conn) error { - if T.SSL { - if !conn.SSL { +func (T *Module) Handle(next gat.Router) gat.Router { + return gat.RouterFunc(func(conn *fed.Conn) error { + if T.SSL { + if !conn.SSL { + return perror.New( + perror.FATAL, + perror.InvalidPassword, + "SSL is required", + ) + } + return next.Route(conn) + } + + if conn.SSL { return perror.New( perror.FATAL, perror.InvalidPassword, - "SSL is required", + "SSL is not allowed", ) } - return nil - } - - if conn.SSL { - return perror.New( - perror.FATAL, - perror.InvalidPassword, - "SSL is not allowed", - ) - } - return nil + return next.Route(conn) + }) } var _ gat.Handler = (*Module)(nil) diff --git a/lib/gat/handlers/rewrite_database/module.go b/lib/gat/handlers/rewrite_database/module.go index e10d5e52..f9ff0840 100644 --- a/lib/gat/handlers/rewrite_database/module.go +++ b/lib/gat/handlers/rewrite_database/module.go @@ -37,17 +37,18 @@ func (T *Module) Validate() error { } } -func (T *Module) Handle(conn *fed.Conn) error { - switch T.Mode { - case "strip_prefix": - conn.Database = strings.TrimPrefix(conn.Database, T.Database) - case "strip_suffix": - conn.Database = strings.TrimSuffix(conn.Database, T.Database) - default: - conn.Database = T.Database - } - - return nil +func (T *Module) Handle(next gat.Router) gat.Router { + return gat.RouterFunc(func(conn *fed.Conn) error { + switch T.Mode { + case "strip_prefix": + conn.Database = strings.TrimPrefix(conn.Database, T.Database) + case "strip_suffix": + conn.Database = strings.TrimSuffix(conn.Database, T.Database) + default: + conn.Database = T.Database + } + return next.Route(conn) + }) } var _ gat.Handler = (*Module)(nil) diff --git a/lib/gat/handlers/rewrite_parameter/module.go b/lib/gat/handlers/rewrite_parameter/module.go index 89b056fb..57975fbb 100644 --- a/lib/gat/handlers/rewrite_parameter/module.go +++ b/lib/gat/handlers/rewrite_parameter/module.go @@ -26,13 +26,15 @@ func (T *Module) CaddyModule() caddy.ModuleInfo { } } -func (T *Module) Handle(conn *fed.Conn) error { - if conn.InitialParameters == nil { - conn.InitialParameters = make(map[strutil.CIString]string) - } - conn.InitialParameters[T.Key] = T.Value - - return nil +func (T *Module) Handle(next gat.Router) gat.Router { + return gat.RouterFunc(func(conn *fed.Conn) error { + if conn.InitialParameters == nil { + conn.InitialParameters = make(map[strutil.CIString]string) + } + conn.InitialParameters[T.Key] = T.Value + + return next.Route(conn) + }) } var _ gat.Handler = (*Module)(nil) diff --git a/lib/gat/handlers/rewrite_password/module.go b/lib/gat/handlers/rewrite_password/module.go index 0cec7729..4eace0da 100644 --- a/lib/gat/handlers/rewrite_password/module.go +++ b/lib/gat/handlers/rewrite_password/module.go @@ -26,11 +26,13 @@ func (T *Module) CaddyModule() caddy.ModuleInfo { } } -func (T *Module) Handle(conn *fed.Conn) error { - return frontends.Authenticate( - conn, - credentials.FromString(conn.User, T.Password), - ) +func (T *Module) Handle(next gat.Router) gat.Router { + return gat.RouterFunc(func(conn *fed.Conn) error { + return frontends.Authenticate( + conn, + credentials.FromString(conn.User, T.Password), + ) + }) } var _ gat.Handler = (*Module)(nil) diff --git a/lib/gat/handlers/rewrite_user/module.go b/lib/gat/handlers/rewrite_user/module.go index 22d16a09..18c4da8a 100644 --- a/lib/gat/handlers/rewrite_user/module.go +++ b/lib/gat/handlers/rewrite_user/module.go @@ -37,17 +37,18 @@ func (T *Module) Validate() error { } } -func (T *Module) Handle(conn *fed.Conn) error { - switch T.Mode { - case "strip_prefix": - conn.User = strings.TrimPrefix(conn.User, T.User) - case "strip_suffix": - conn.User = strings.TrimSuffix(conn.User, T.User) - default: - conn.User = T.User - } - - return nil +func (T *Module) Handle(next gat.Router) gat.Router { + return gat.RouterFunc(func(conn *fed.Conn) error { + switch T.Mode { + case "strip_prefix": + conn.User = strings.TrimPrefix(conn.User, T.User) + case "strip_suffix": + conn.User = strings.TrimSuffix(conn.User, T.User) + default: + conn.User = T.User + } + return next.Route(conn) + }) } var _ gat.Handler = (*Module)(nil) diff --git a/lib/gat/server.go b/lib/gat/server.go index fc0db871..0dd5f243 100644 --- a/lib/gat/server.go +++ b/lib/gat/server.go @@ -108,37 +108,37 @@ func (T *Server) ReadMetrics(m *metrics.Server) { } func (T *Server) Serve(conn *fed.Conn) { - for _, route := range T.routes { + composed := Router(RouterFunc(func(conn *fed.Conn) error { + // database not found + errResp := perror.ToPacket( + perror.New( + perror.FATAL, + perror.InvalidPassword, + fmt.Sprintf(`Database "%s" not found`, conn.Database), + ), + ) + _ = conn.WritePacket(errResp) + T.log.Warn("database not found", zap.String("user", conn.User), zap.String("database", conn.Database)) + return nil + })) + for j := 0; j < len(T.routes); j++ { + route := T.routes[j] if route.match != nil && !route.match.Matches(conn) { continue } - - if route.handle == nil { - continue - } - err := route.handle.Handle(conn) - if err != nil { - if errors.Is(err, io.EOF) { - // normal closure - return - } - - errResp := perror.ToPacket(perror.Wrap(err)) - _ = conn.WritePacket(errResp) + composed = route.handle.Handle(composed) + } + err := composed.Route(conn) + if err != nil { + if errors.Is(err, io.EOF) { + // normal closure return } - } - // database not found - errResp := perror.ToPacket( - perror.New( - perror.FATAL, - perror.InvalidPassword, - fmt.Sprintf(`Database "%s" not found`, conn.Database), - ), - ) - _ = conn.WritePacket(errResp) - T.log.Warn("database not found", zap.String("user", conn.User), zap.String("database", conn.Database)) + errResp := perror.ToPacket(perror.Wrap(err)) + _ = conn.WritePacket(errResp) + return + } } func (T *Server) accept(listener *Listener, conn *fed.Conn) { @@ -169,11 +169,11 @@ func (T *Server) accept(listener *Listener, conn *fed.Conn) { } count := listener.open.Add(1) - prom.Listener.ClientConnections(labels).Inc() - prom.Listener.IncomingConnections(labels).Inc() + prom.Listener.Client(labels).Inc() + prom.Listener.Incoming(labels).Inc() defer func() { listener.open.Add(-1) - prom.Listener.ClientConnections(labels).Dec() + prom.Listener.Client(labels).Dec() }() if listener.MaxConnections != 0 && int(count) > listener.MaxConnections { @@ -186,7 +186,7 @@ func (T *Server) accept(listener *Listener, conn *fed.Conn) { ) return } - prom.Listener.AcceptedConnections(labels).Inc() + prom.Listener.Accepted(labels).Inc() T.Serve(conn) } diff --git a/lib/instrumentation/prom/metrics.go b/lib/instrumentation/prom/metrics.go index 30a55be7..02a7e96a 100644 --- a/lib/instrumentation/prom/metrics.go +++ b/lib/instrumentation/prom/metrics.go @@ -10,23 +10,21 @@ type ListenerLabels struct { } var Listener struct { - Incoming func(ListenerLabels) prometheus.Counter `name:"incoming"` - Accepted func(ListenerLabels) prometheus.Counter `name:"accepted"` - Client func(ListenerLabels) prometheus.Gauge `name:"client"` + Incoming func(ListenerLabels) prometheus.Counter `name:"incoming" help:"incoming connections"` + Accepted func(ListenerLabels) prometheus.Counter `name:"accepted" help:"accepted connetions"` + Client func(ListenerLabels) prometheus.Gauge `name:"client" help:"current clients"` } type ServingLabels struct { } var Serving struct { - Route func() `name:""` } type InstanceLabels struct { } var Instance struct { - Route func() `name:""` } func init() { -- GitLab