From 9a8381a9bb64a732f87c7829cdf7a016b713a087 Mon Sep 17 00:00:00 2001
From: Tom Guinther <tguinther@gfxlabs.io>
Date: Tue, 13 Aug 2024 17:10:38 -0400
Subject: [PATCH] more ctx params

---
 lib/fed/middlewares/eqp/sync.go | 16 ++++++++--------
 lib/gat/app.go                  |  5 +++--
 lib/gat/handler.go              |  2 +-
 lib/gat/handlers/pool/module.go | 11 ++++++-----
 lib/gat/server.go               | 32 +++++++++++++++++---------------
 5 files changed, 35 insertions(+), 31 deletions(-)

diff --git a/lib/fed/middlewares/eqp/sync.go b/lib/fed/middlewares/eqp/sync.go
index fbbe8f57..abf345aa 100644
--- a/lib/fed/middlewares/eqp/sync.go
+++ b/lib/fed/middlewares/eqp/sync.go
@@ -20,7 +20,7 @@ func preparedStatementsEqual(a, b *packets.Parse) bool {
 	return true
 }
 
-func SyncMiddleware(ctx context.Context,c *Client, server *fed.Conn) error {
+func SyncMiddleware(ctx context.Context, c *Client, server *fed.Conn) error {
 	s, ok := fed.LookupMiddleware[*Server](server)
 	if !ok {
 		panic("middleware not found")
@@ -36,7 +36,7 @@ func SyncMiddleware(ctx context.Context,c *Client, server *fed.Conn) error {
 			Which: 'P',
 			Name:  name,
 		}
-		if err := server.WritePacket(ctx,&p); err != nil {
+		if err := server.WritePacket(ctx, &p); err != nil {
 			return err
 		}
 
@@ -60,7 +60,7 @@ func SyncMiddleware(ctx context.Context,c *Client, server *fed.Conn) error {
 			Which: 'S',
 			Name:  name,
 		}
-		if err := server.WritePacket(ctx,&p); err != nil {
+		if err := server.WritePacket(ctx, &p); err != nil {
 			return err
 		}
 
@@ -75,7 +75,7 @@ func SyncMiddleware(ctx context.Context,c *Client, server *fed.Conn) error {
 			}
 		}
 
-		if err := server.WritePacket(ctx,preparedStatement); err != nil {
+		if err := server.WritePacket(ctx, preparedStatement); err != nil {
 			return err
 		}
 
@@ -84,7 +84,7 @@ func SyncMiddleware(ctx context.Context,c *Client, server *fed.Conn) error {
 
 	// bind all portals
 	for _, portal := range c.state.portals {
-		if err := server.WritePacket(ctx,portal); err != nil {
+		if err := server.WritePacket(ctx, portal); err != nil {
 			return err
 		}
 
@@ -93,18 +93,18 @@ func SyncMiddleware(ctx context.Context,c *Client, server *fed.Conn) error {
 
 	if needsBackendSync {
 		var err error
-		err, _ = backends.Sync(server, nil)
+		err, _ = backends.Sync(ctx, server, nil)
 		return err
 	}
 
 	return nil
 }
 
-func Sync(ctx context.Context,client, server *fed.Conn) error {
+func Sync(ctx context.Context, client, server *fed.Conn) error {
 	c, ok := fed.LookupMiddleware[*Client](client)
 	if !ok {
 		panic("middleware not found")
 	}
 
-	return SyncMiddleware(ctx,c, server)
+	return SyncMiddleware(ctx, c, server)
 }
diff --git a/lib/gat/app.go b/lib/gat/app.go
index b9126960..2c6abbad 100644
--- a/lib/gat/app.go
+++ b/lib/gat/app.go
@@ -1,6 +1,7 @@
 package gat
 
 import (
+	"context"
 	"time"
 
 	"github.com/caddyserver/caddy/v2"
@@ -80,7 +81,7 @@ func (T *App) Start() error {
 	}
 
 	for _, server := range T.servers {
-		if err := server.Start(); err != nil {
+		if err := server.Start(context.Background()); err != nil {
 			return err
 		}
 	}
@@ -92,7 +93,7 @@ func (T *App) Stop() error {
 	close(T.closed)
 
 	for _, server := range T.servers {
-		if err := server.Stop(); err != nil {
+		if err := server.Stop(context.Background()); err != nil {
 			return err
 		}
 	}
diff --git a/lib/gat/handler.go b/lib/gat/handler.go
index f6519e35..682cf657 100644
--- a/lib/gat/handler.go
+++ b/lib/gat/handler.go
@@ -10,7 +10,7 @@ import (
 type Handler interface {
 	// Handle will attempt to handle the Conn. Return io.EOF for normal disconnection or nil to continue to the next
 	// handle. The error will be relayed to the client so there is no need to send it yourself.
-	Handle(conn *fed.Conn) error
+	Handle(ctx context.Context, conn *fed.Conn) error
 }
 
 type CancellableHandler interface {
diff --git a/lib/gat/handlers/pool/module.go b/lib/gat/handlers/pool/module.go
index ee60787c..795201d5 100644
--- a/lib/gat/handlers/pool/module.go
+++ b/lib/gat/handlers/pool/module.go
@@ -1,6 +1,7 @@
 package pool
 
 import (
+	"context"
 	"encoding/json"
 
 	"github.com/caddyserver/caddy/v2"
@@ -49,20 +50,20 @@ func (T *Module) Provision(ctx caddy.Context) error {
 	return nil
 }
 
-func (T *Module) Handle(conn *fed.Conn) error {
-	if err := frontends.Authenticate(conn, nil); err != nil {
+func (T *Module) Handle(ctx context.Context, conn *fed.Conn) error {
+	if err := frontends.Authenticate(ctx, conn, nil); err != nil {
 		return err
 	}
 
-	return T.pool.Serve(conn)
+	return T.pool.Serve(ctx, conn)
 }
 
 func (T *Module) ReadMetrics(metrics *metrics.Handler) {
 	T.pool.ReadMetrics(&metrics.Pool)
 }
 
-func (T *Module) Cancel(key fed.BackendKey) {
-	T.pool.Cancel(key)
+func (T *Module) Cancel(ctx context.Context, key fed.BackendKey) {
+	T.pool.Cancel(ctx, key)
 }
 
 var _ gat.Handler = (*Module)(nil)
diff --git a/lib/gat/server.go b/lib/gat/server.go
index 2fabc77f..ac87b910 100644
--- a/lib/gat/server.go
+++ b/lib/gat/server.go
@@ -1,6 +1,7 @@
 package gat
 
 import (
+	"context"
 	"crypto/tls"
 	"errors"
 	"fmt"
@@ -66,7 +67,7 @@ func (T *Server) Provision(ctx caddy.Context) error {
 	return nil
 }
 
-func (T *Server) Start() error {
+func (T *Server) Start(ctx context.Context) error {
 	for _, listener := range T.listen {
 		if err := listener.Start(); err != nil {
 			return err
@@ -74,7 +75,7 @@ func (T *Server) Start() error {
 
 		go func(listener *Listener) {
 			for {
-				if !T.acceptFrom(listener) {
+				if !T.acceptFrom(ctx, listener) {
 					break
 				}
 			}
@@ -84,7 +85,7 @@ func (T *Server) Start() error {
 	return nil
 }
 
-func (T *Server) Stop() error {
+func (T *Server) Stop(ctx context.Context) error {
 	for _, listen := range T.listen {
 		if err := listen.Stop(); err != nil {
 			return err
@@ -94,9 +95,9 @@ func (T *Server) Stop() error {
 	return nil
 }
 
-func (T *Server) Cancel(key fed.BackendKey) {
+func (T *Server) Cancel(ctx context.Context, key fed.BackendKey) {
 	for _, cancellableHandler := range T.cancellableHandlers {
-		cancellableHandler.Cancel(key)
+		cancellableHandler.Cancel(ctx, key)
 	}
 }
 
@@ -106,7 +107,7 @@ func (T *Server) ReadMetrics(m *metrics.Server) {
 	}
 }
 
-func (T *Server) Serve(conn *fed.Conn) {
+func (T *Server) Serve(ctx context.Context, conn *fed.Conn) {
 	for _, route := range T.routes {
 		if route.match != nil && !route.match.Matches(conn) {
 			continue
@@ -115,7 +116,7 @@ func (T *Server) Serve(conn *fed.Conn) {
 		if route.handle == nil {
 			continue
 		}
-		err := route.handle.Handle(conn)
+		err := route.handle.Handle(ctx,conn)
 		if err != nil {
 			if errors.Is(err, io.EOF) {
 				// normal closure
@@ -123,7 +124,7 @@ func (T *Server) Serve(conn *fed.Conn) {
 			}
 
 			errResp := perror.ToPacket(perror.Wrap(err))
-			_ = conn.WritePacket(errResp)
+			_ = conn.WritePacket(ctx, errResp)
 			return
 		}
 	}
@@ -136,13 +137,13 @@ func (T *Server) Serve(conn *fed.Conn) {
 			fmt.Sprintf(`Database "%s" not found`, conn.Database),
 		),
 	)
-	_ = conn.WritePacket(errResp)
+	_ = conn.WritePacket(ctx, errResp)
 	T.log.Warn("database not found", zap.String("user", conn.User), zap.String("database", conn.Database))
 }
 
-func (T *Server) accept(listener *Listener, conn *fed.Conn) {
+func (T *Server) accept(ctx context.Context, listener *Listener, conn *fed.Conn) {
 	defer func() {
-		_ = conn.Close()
+		_ = conn.Close(ctx)
 	}()
 
 	var tlsConfig *tls.Config
@@ -162,7 +163,7 @@ func (T *Server) accept(listener *Listener, conn *fed.Conn) {
 	}
 
 	if isCanceling {
-		T.Cancel(cancelKey)
+		T.Cancel(ctx, cancelKey)
 		return
 	}
 
@@ -171,6 +172,7 @@ func (T *Server) accept(listener *Listener, conn *fed.Conn) {
 
 	if listener.MaxConnections != 0 && int(count) > listener.MaxConnections {
 		_ = conn.WritePacket(
+			ctx,
 			perror.ToPacket(perror.New(
 				perror.FATAL,
 				perror.TooManyConnections,
@@ -180,12 +182,12 @@ func (T *Server) accept(listener *Listener, conn *fed.Conn) {
 		return
 	}
 
-	T.Serve(conn)
+	T.Serve(ctx, conn)
 }
 
-func (T *Server) acceptFrom(listener *Listener) bool {
+func (T *Server) acceptFrom(ctx context.Context, listener *Listener) bool {
 	err := listener.listener.Accept(func(c *fed.Conn) {
-		T.accept(listener, c)
+		T.accept(ctx, listener, c)
 	})
 	if err != nil {
 		if errors.Is(err, net.ErrClosed) {
-- 
GitLab