From e4f94d9e5e4ba4be8b8944bb1cafdfec63934940 Mon Sep 17 00:00:00 2001
From: a <a@tuxpa.in>
Date: Mon, 17 Jun 2024 17:52:05 -0500
Subject: [PATCH] noot

---
 lib/gat/handler.go                            | 17 +++-
 .../allowed_startup_parameters/module.go      | 23 +++---
 lib/gat/handlers/discovery/module.go          | 22 +++---
 lib/gat/handlers/error/module.go              | 14 ++--
 lib/gat/handlers/pgbouncer/module.go          | 79 ++++++++++---------
 lib/gat/handlers/pool/module.go               | 14 ++--
 lib/gat/handlers/require_ssl/module.go        | 32 ++++----
 lib/gat/handlers/rewrite_database/module.go   | 23 +++---
 lib/gat/handlers/rewrite_parameter/module.go  | 16 ++--
 lib/gat/handlers/rewrite_password/module.go   | 12 +--
 lib/gat/handlers/rewrite_user/module.go       | 23 +++---
 lib/gat/server.go                             | 58 +++++++-------
 lib/instrumentation/prom/metrics.go           |  8 +-
 13 files changed, 184 insertions(+), 157 deletions(-)

diff --git a/lib/gat/handler.go b/lib/gat/handler.go
index b0d2d339..41d1ed1f 100644
--- a/lib/gat/handler.go
+++ b/lib/gat/handler.go
@@ -9,7 +9,22 @@ import (
 type Handler interface {
 	// Handle will attempt to handle the Conn. Return io.EOF for normal disconnection or nil to continue to the next
 	// handle. The error will be relayed to the client so there is no need to send it yourself.
-	Handle(conn *fed.Conn) error
+	Handle(Router) Router
+}
+
+type HandlerFunc func(Router) Router
+
+func (H HandlerFunc) Handle(next Router) Router {
+	return H(next)
+}
+
+type Router interface {
+	Route(conn *fed.Conn) error
+}
+type RouterFunc func(conn *fed.Conn) error
+
+func (R RouterFunc) Route(conn *fed.Conn) error {
+	return R(conn)
 }
 
 type CancellableHandler interface {
diff --git a/lib/gat/handlers/allowed_startup_parameters/module.go b/lib/gat/handlers/allowed_startup_parameters/module.go
index d797381b..03fc3f8e 100644
--- a/lib/gat/handlers/allowed_startup_parameters/module.go
+++ b/lib/gat/handlers/allowed_startup_parameters/module.go
@@ -29,18 +29,19 @@ func (T *Module) CaddyModule() caddy.ModuleInfo {
 	}
 }
 
-func (T *Module) Handle(conn *fed.Conn) error {
-	for parameter := range conn.InitialParameters {
-		if !slices.Contains(T.Parameters, parameter) {
-			return perror.New(
-				perror.FATAL,
-				perror.FeatureNotSupported,
-				fmt.Sprintf(`Startup parameter "%s" is not supported`, parameter.String()),
-			)
+func (T *Module) Handle(next gat.Router) gat.Router {
+	return gat.RouterFunc(func(conn *fed.Conn) error {
+		for parameter := range conn.InitialParameters {
+			if !slices.Contains(T.Parameters, parameter) {
+				return perror.New(
+					perror.FATAL,
+					perror.FeatureNotSupported,
+					fmt.Sprintf(`Startup parameter "%s" is not supported`, parameter.String()),
+				)
+			}
 		}
-	}
-
-	return nil
+		return next.Route(conn)
+	})
 }
 
 var _ gat.Handler = (*Module)(nil)
diff --git a/lib/gat/handlers/discovery/module.go b/lib/gat/handlers/discovery/module.go
index edbe3d39..193f0c60 100644
--- a/lib/gat/handlers/discovery/module.go
+++ b/lib/gat/handlers/discovery/module.go
@@ -568,17 +568,17 @@ func (T *Module) ReadMetrics(metrics *metrics.Handler) {
 	})
 }
 
-func (T *Module) Handle(conn *fed.Conn) error {
-	p, ok := T.getPool(conn.User, conn.Database)
-	if !ok {
-		return nil
-	}
-
-	if err := frontends.Authenticate(conn, p.creds); err != nil {
-		return err
-	}
-
-	return p.pool.Serve(conn)
+func (T *Module) Handle(next gat.Router) gat.Router {
+	return gat.RouterFunc(func(conn *fed.Conn) error {
+		p, ok := T.getPool(conn.User, conn.Database)
+		if !ok {
+			return next.Route(conn)
+		}
+		if err := frontends.Authenticate(conn, p.creds); err != nil {
+			return err
+		}
+		return p.pool.Serve(conn)
+	})
 }
 
 func (T *Module) Cancel(key fed.BackendKey) {
diff --git a/lib/gat/handlers/error/module.go b/lib/gat/handlers/error/module.go
index b8eedbf7..42c7b088 100644
--- a/lib/gat/handlers/error/module.go
+++ b/lib/gat/handlers/error/module.go
@@ -24,12 +24,14 @@ func (T *Module) CaddyModule() caddy.ModuleInfo {
 	}
 }
 
-func (T *Module) Handle(_ *fed.Conn) error {
-	return perror.New(
-		perror.FATAL,
-		perror.InternalError,
-		T.Message,
-	)
+func (T *Module) Handle(gat.Router) gat.Router {
+	return gat.RouterFunc(func(c *fed.Conn) error {
+		return perror.New(
+			perror.FATAL,
+			perror.InternalError,
+			T.Message,
+		)
+	})
 }
 
 var _ gat.Handler = (*Module)(nil)
diff --git a/lib/gat/handlers/pgbouncer/module.go b/lib/gat/handlers/pgbouncer/module.go
index 057eea53..49b54d9a 100644
--- a/lib/gat/handlers/pgbouncer/module.go
+++ b/lib/gat/handlers/pgbouncer/module.go
@@ -290,53 +290,54 @@ func (T *Module) lookup(user, database string) (poolAndCredentials, bool) {
 	return T.tryCreate(user, database)
 }
 
-func (T *Module) Handle(conn *fed.Conn) error {
-	// check ssl
-	if T.Config.PgBouncer.ClientTLSSSLMode.IsRequired() {
-		if !conn.SSL {
+func (T *Module) Handle(next gat.Router) gat.Router {
+	return gat.RouterFunc(func(conn *fed.Conn) error {
+		// check ssl
+		if T.Config.PgBouncer.ClientTLSSSLMode.IsRequired() {
+			if !conn.SSL {
+				return perror.New(
+					perror.FATAL,
+					perror.InvalidPassword,
+					"SSL is required",
+				)
+			}
+		}
+		// check startup parameters
+		for key := range conn.InitialParameters {
+			if slices.Contains([]strutil.CIString{
+				strutil.MakeCIString("client_encoding"),
+				strutil.MakeCIString("datestyle"),
+				strutil.MakeCIString("timezone"),
+				strutil.MakeCIString("standard_conforming_strings"),
+				strutil.MakeCIString("application_name"),
+			}, key) {
+				continue
+			}
+			if slices.Contains(T.Config.PgBouncer.TrackExtraParameters, key) {
+				continue
+			}
+			if slices.Contains(T.Config.PgBouncer.IgnoreStartupParameters, key) {
+				continue
+			}
+
 			return perror.New(
 				perror.FATAL,
-				perror.InvalidPassword,
-				"SSL is required",
+				perror.FeatureNotSupported,
+				fmt.Sprintf(`Startup parameter "%s" is not supported`, key.String()),
 			)
 		}
-	}
 
-	// check startup parameters
-	for key := range conn.InitialParameters {
-		if slices.Contains([]strutil.CIString{
-			strutil.MakeCIString("client_encoding"),
-			strutil.MakeCIString("datestyle"),
-			strutil.MakeCIString("timezone"),
-			strutil.MakeCIString("standard_conforming_strings"),
-			strutil.MakeCIString("application_name"),
-		}, key) {
-			continue
-		}
-		if slices.Contains(T.Config.PgBouncer.TrackExtraParameters, key) {
-			continue
-		}
-		if slices.Contains(T.Config.PgBouncer.IgnoreStartupParameters, key) {
-			continue
+		p, ok := T.lookup(conn.User, conn.Database)
+		if !ok {
+			return next.Route(conn)
 		}
 
-		return perror.New(
-			perror.FATAL,
-			perror.FeatureNotSupported,
-			fmt.Sprintf(`Startup parameter "%s" is not supported`, key.String()),
-		)
-	}
-
-	p, ok := T.lookup(conn.User, conn.Database)
-	if !ok {
-		return nil
-	}
-
-	if err := frontends.Authenticate(conn, p.creds); err != nil {
-		return err
-	}
+		if err := frontends.Authenticate(conn, p.creds); err != nil {
+			return err
+		}
 
-	return p.pool.Serve(conn)
+		return p.pool.Serve(conn)
+	})
 }
 
 func (T *Module) ReadMetrics(metrics *metrics.Handler) {
diff --git a/lib/gat/handlers/pool/module.go b/lib/gat/handlers/pool/module.go
index ee60787c..6f44c1a6 100644
--- a/lib/gat/handlers/pool/module.go
+++ b/lib/gat/handlers/pool/module.go
@@ -49,12 +49,14 @@ func (T *Module) Provision(ctx caddy.Context) error {
 	return nil
 }
 
-func (T *Module) Handle(conn *fed.Conn) error {
-	if err := frontends.Authenticate(conn, nil); err != nil {
-		return err
-	}
-
-	return T.pool.Serve(conn)
+func (T *Module) Handle(next gat.Router) gat.Router {
+	return gat.RouterFunc(func(c *fed.Conn) error {
+		if err := frontends.Authenticate(c, nil); err != nil {
+			return err
+		}
+
+		return T.pool.Serve(c)
+	})
 }
 
 func (T *Module) ReadMetrics(metrics *metrics.Handler) {
diff --git a/lib/gat/handlers/require_ssl/module.go b/lib/gat/handlers/require_ssl/module.go
index 476a8285..bd73c154 100644
--- a/lib/gat/handlers/require_ssl/module.go
+++ b/lib/gat/handlers/require_ssl/module.go
@@ -25,26 +25,28 @@ func (T *Module) CaddyModule() caddy.ModuleInfo {
 	}
 }
 
-func (T *Module) Handle(conn *fed.Conn) error {
-	if T.SSL {
-		if !conn.SSL {
+func (T *Module) Handle(next gat.Router) gat.Router {
+	return gat.RouterFunc(func(conn *fed.Conn) error {
+		if T.SSL {
+			if !conn.SSL {
+				return perror.New(
+					perror.FATAL,
+					perror.InvalidPassword,
+					"SSL is required",
+				)
+			}
+			return next.Route(conn)
+		}
+
+		if conn.SSL {
 			return perror.New(
 				perror.FATAL,
 				perror.InvalidPassword,
-				"SSL is required",
+				"SSL is not allowed",
 			)
 		}
-		return nil
-	}
-
-	if conn.SSL {
-		return perror.New(
-			perror.FATAL,
-			perror.InvalidPassword,
-			"SSL is not allowed",
-		)
-	}
-	return nil
+		return next.Route(conn)
+	})
 }
 
 var _ gat.Handler = (*Module)(nil)
diff --git a/lib/gat/handlers/rewrite_database/module.go b/lib/gat/handlers/rewrite_database/module.go
index e10d5e52..f9ff0840 100644
--- a/lib/gat/handlers/rewrite_database/module.go
+++ b/lib/gat/handlers/rewrite_database/module.go
@@ -37,17 +37,18 @@ func (T *Module) Validate() error {
 	}
 }
 
-func (T *Module) Handle(conn *fed.Conn) error {
-	switch T.Mode {
-	case "strip_prefix":
-		conn.Database = strings.TrimPrefix(conn.Database, T.Database)
-	case "strip_suffix":
-		conn.Database = strings.TrimSuffix(conn.Database, T.Database)
-	default:
-		conn.Database = T.Database
-	}
-
-	return nil
+func (T *Module) Handle(next gat.Router) gat.Router {
+	return gat.RouterFunc(func(conn *fed.Conn) error {
+		switch T.Mode {
+		case "strip_prefix":
+			conn.Database = strings.TrimPrefix(conn.Database, T.Database)
+		case "strip_suffix":
+			conn.Database = strings.TrimSuffix(conn.Database, T.Database)
+		default:
+			conn.Database = T.Database
+		}
+		return next.Route(conn)
+	})
 }
 
 var _ gat.Handler = (*Module)(nil)
diff --git a/lib/gat/handlers/rewrite_parameter/module.go b/lib/gat/handlers/rewrite_parameter/module.go
index 89b056fb..57975fbb 100644
--- a/lib/gat/handlers/rewrite_parameter/module.go
+++ b/lib/gat/handlers/rewrite_parameter/module.go
@@ -26,13 +26,15 @@ func (T *Module) CaddyModule() caddy.ModuleInfo {
 	}
 }
 
-func (T *Module) Handle(conn *fed.Conn) error {
-	if conn.InitialParameters == nil {
-		conn.InitialParameters = make(map[strutil.CIString]string)
-	}
-	conn.InitialParameters[T.Key] = T.Value
-
-	return nil
+func (T *Module) Handle(next gat.Router) gat.Router {
+	return gat.RouterFunc(func(conn *fed.Conn) error {
+		if conn.InitialParameters == nil {
+			conn.InitialParameters = make(map[strutil.CIString]string)
+		}
+		conn.InitialParameters[T.Key] = T.Value
+
+		return next.Route(conn)
+	})
 }
 
 var _ gat.Handler = (*Module)(nil)
diff --git a/lib/gat/handlers/rewrite_password/module.go b/lib/gat/handlers/rewrite_password/module.go
index 0cec7729..4eace0da 100644
--- a/lib/gat/handlers/rewrite_password/module.go
+++ b/lib/gat/handlers/rewrite_password/module.go
@@ -26,11 +26,13 @@ func (T *Module) CaddyModule() caddy.ModuleInfo {
 	}
 }
 
-func (T *Module) Handle(conn *fed.Conn) error {
-	return frontends.Authenticate(
-		conn,
-		credentials.FromString(conn.User, T.Password),
-	)
+func (T *Module) Handle(next gat.Router) gat.Router {
+	return gat.RouterFunc(func(conn *fed.Conn) error {
+		return frontends.Authenticate(
+			conn,
+			credentials.FromString(conn.User, T.Password),
+		)
+	})
 }
 
 var _ gat.Handler = (*Module)(nil)
diff --git a/lib/gat/handlers/rewrite_user/module.go b/lib/gat/handlers/rewrite_user/module.go
index 22d16a09..18c4da8a 100644
--- a/lib/gat/handlers/rewrite_user/module.go
+++ b/lib/gat/handlers/rewrite_user/module.go
@@ -37,17 +37,18 @@ func (T *Module) Validate() error {
 	}
 }
 
-func (T *Module) Handle(conn *fed.Conn) error {
-	switch T.Mode {
-	case "strip_prefix":
-		conn.User = strings.TrimPrefix(conn.User, T.User)
-	case "strip_suffix":
-		conn.User = strings.TrimSuffix(conn.User, T.User)
-	default:
-		conn.User = T.User
-	}
-
-	return nil
+func (T *Module) Handle(next gat.Router) gat.Router {
+	return gat.RouterFunc(func(conn *fed.Conn) error {
+		switch T.Mode {
+		case "strip_prefix":
+			conn.User = strings.TrimPrefix(conn.User, T.User)
+		case "strip_suffix":
+			conn.User = strings.TrimSuffix(conn.User, T.User)
+		default:
+			conn.User = T.User
+		}
+		return next.Route(conn)
+	})
 }
 
 var _ gat.Handler = (*Module)(nil)
diff --git a/lib/gat/server.go b/lib/gat/server.go
index fc0db871..0dd5f243 100644
--- a/lib/gat/server.go
+++ b/lib/gat/server.go
@@ -108,37 +108,37 @@ func (T *Server) ReadMetrics(m *metrics.Server) {
 }
 
 func (T *Server) Serve(conn *fed.Conn) {
-	for _, route := range T.routes {
+	composed := Router(RouterFunc(func(conn *fed.Conn) error {
+		// database not found
+		errResp := perror.ToPacket(
+			perror.New(
+				perror.FATAL,
+				perror.InvalidPassword,
+				fmt.Sprintf(`Database "%s" not found`, conn.Database),
+			),
+		)
+		_ = conn.WritePacket(errResp)
+		T.log.Warn("database not found", zap.String("user", conn.User), zap.String("database", conn.Database))
+		return nil
+	}))
+	for j := 0; j < len(T.routes); j++ {
+		route := T.routes[j]
 		if route.match != nil && !route.match.Matches(conn) {
 			continue
 		}
-
-		if route.handle == nil {
-			continue
-		}
-		err := route.handle.Handle(conn)
-		if err != nil {
-			if errors.Is(err, io.EOF) {
-				// normal closure
-				return
-			}
-
-			errResp := perror.ToPacket(perror.Wrap(err))
-			_ = conn.WritePacket(errResp)
+		composed = route.handle.Handle(composed)
+	}
+	err := composed.Route(conn)
+	if err != nil {
+		if errors.Is(err, io.EOF) {
+			// normal closure
 			return
 		}
-	}
 
-	// database not found
-	errResp := perror.ToPacket(
-		perror.New(
-			perror.FATAL,
-			perror.InvalidPassword,
-			fmt.Sprintf(`Database "%s" not found`, conn.Database),
-		),
-	)
-	_ = conn.WritePacket(errResp)
-	T.log.Warn("database not found", zap.String("user", conn.User), zap.String("database", conn.Database))
+		errResp := perror.ToPacket(perror.Wrap(err))
+		_ = conn.WritePacket(errResp)
+		return
+	}
 }
 
 func (T *Server) accept(listener *Listener, conn *fed.Conn) {
@@ -169,11 +169,11 @@ func (T *Server) accept(listener *Listener, conn *fed.Conn) {
 	}
 
 	count := listener.open.Add(1)
-	prom.Listener.ClientConnections(labels).Inc()
-	prom.Listener.IncomingConnections(labels).Inc()
+	prom.Listener.Client(labels).Inc()
+	prom.Listener.Incoming(labels).Inc()
 	defer func() {
 		listener.open.Add(-1)
-		prom.Listener.ClientConnections(labels).Dec()
+		prom.Listener.Client(labels).Dec()
 	}()
 
 	if listener.MaxConnections != 0 && int(count) > listener.MaxConnections {
@@ -186,7 +186,7 @@ func (T *Server) accept(listener *Listener, conn *fed.Conn) {
 		)
 		return
 	}
-	prom.Listener.AcceptedConnections(labels).Inc()
+	prom.Listener.Accepted(labels).Inc()
 	T.Serve(conn)
 }
 
diff --git a/lib/instrumentation/prom/metrics.go b/lib/instrumentation/prom/metrics.go
index 30a55be7..02a7e96a 100644
--- a/lib/instrumentation/prom/metrics.go
+++ b/lib/instrumentation/prom/metrics.go
@@ -10,23 +10,21 @@ type ListenerLabels struct {
 }
 
 var Listener struct {
-	Incoming func(ListenerLabels) prometheus.Counter `name:"incoming"`
-	Accepted func(ListenerLabels) prometheus.Counter `name:"accepted"`
-	Client   func(ListenerLabels) prometheus.Gauge   `name:"client"`
+	Incoming func(ListenerLabels) prometheus.Counter `name:"incoming" help:"incoming connections"`
+	Accepted func(ListenerLabels) prometheus.Counter `name:"accepted" help:"accepted connetions"`
+	Client   func(ListenerLabels) prometheus.Gauge   `name:"client" help:"current clients"`
 }
 
 type ServingLabels struct {
 }
 
 var Serving struct {
-	Route func() `name:""`
 }
 
 type InstanceLabels struct {
 }
 
 var Instance struct {
-	Route func() `name:""`
 }
 
 func init() {
-- 
GitLab