diff --git a/lib/fed/middlewares/eqp/sync.go b/lib/fed/middlewares/eqp/sync.go index fbbe8f5767025ce435e9e706e652d4680a88ea31..abf345aae310a0444a965ecd402ea55e207fdf1f 100644 --- a/lib/fed/middlewares/eqp/sync.go +++ b/lib/fed/middlewares/eqp/sync.go @@ -20,7 +20,7 @@ func preparedStatementsEqual(a, b *packets.Parse) bool { return true } -func SyncMiddleware(ctx context.Context,c *Client, server *fed.Conn) error { +func SyncMiddleware(ctx context.Context, c *Client, server *fed.Conn) error { s, ok := fed.LookupMiddleware[*Server](server) if !ok { panic("middleware not found") @@ -36,7 +36,7 @@ func SyncMiddleware(ctx context.Context,c *Client, server *fed.Conn) error { Which: 'P', Name: name, } - if err := server.WritePacket(ctx,&p); err != nil { + if err := server.WritePacket(ctx, &p); err != nil { return err } @@ -60,7 +60,7 @@ func SyncMiddleware(ctx context.Context,c *Client, server *fed.Conn) error { Which: 'S', Name: name, } - if err := server.WritePacket(ctx,&p); err != nil { + if err := server.WritePacket(ctx, &p); err != nil { return err } @@ -75,7 +75,7 @@ func SyncMiddleware(ctx context.Context,c *Client, server *fed.Conn) error { } } - if err := server.WritePacket(ctx,preparedStatement); err != nil { + if err := server.WritePacket(ctx, preparedStatement); err != nil { return err } @@ -84,7 +84,7 @@ func SyncMiddleware(ctx context.Context,c *Client, server *fed.Conn) error { // bind all portals for _, portal := range c.state.portals { - if err := server.WritePacket(ctx,portal); err != nil { + if err := server.WritePacket(ctx, portal); err != nil { return err } @@ -93,18 +93,18 @@ func SyncMiddleware(ctx context.Context,c *Client, server *fed.Conn) error { if needsBackendSync { var err error - err, _ = backends.Sync(server, nil) + err, _ = backends.Sync(ctx, server, nil) return err } return nil } -func Sync(ctx context.Context,client, server *fed.Conn) error { +func Sync(ctx context.Context, client, server *fed.Conn) error { c, ok := fed.LookupMiddleware[*Client](client) if !ok { panic("middleware not found") } - return SyncMiddleware(ctx,c, server) + return SyncMiddleware(ctx, c, server) } diff --git a/lib/gat/app.go b/lib/gat/app.go index b91269605ab2847eb3b63839c20e39a454e92479..2c6abbad73aca55d9bbdba22c24a6555cae22092 100644 --- a/lib/gat/app.go +++ b/lib/gat/app.go @@ -1,6 +1,7 @@ package gat import ( + "context" "time" "github.com/caddyserver/caddy/v2" @@ -80,7 +81,7 @@ func (T *App) Start() error { } for _, server := range T.servers { - if err := server.Start(); err != nil { + if err := server.Start(context.Background()); err != nil { return err } } @@ -92,7 +93,7 @@ func (T *App) Stop() error { close(T.closed) for _, server := range T.servers { - if err := server.Stop(); err != nil { + if err := server.Stop(context.Background()); err != nil { return err } } diff --git a/lib/gat/handler.go b/lib/gat/handler.go index f6519e35385ff8e18f138e0e8d4678e891301aac..682cf657750c093f7674e8af32fd9ddb0c1a1979 100644 --- a/lib/gat/handler.go +++ b/lib/gat/handler.go @@ -10,7 +10,7 @@ import ( type Handler interface { // Handle will attempt to handle the Conn. Return io.EOF for normal disconnection or nil to continue to the next // handle. The error will be relayed to the client so there is no need to send it yourself. - Handle(conn *fed.Conn) error + Handle(ctx context.Context, conn *fed.Conn) error } type CancellableHandler interface { diff --git a/lib/gat/handlers/pool/module.go b/lib/gat/handlers/pool/module.go index ee60787c0a0dd6414c7ecbff49e0ad2de5415aa3..795201d55717f976dc1b1570b391be89b30918a0 100644 --- a/lib/gat/handlers/pool/module.go +++ b/lib/gat/handlers/pool/module.go @@ -1,6 +1,7 @@ package pool import ( + "context" "encoding/json" "github.com/caddyserver/caddy/v2" @@ -49,20 +50,20 @@ func (T *Module) Provision(ctx caddy.Context) error { return nil } -func (T *Module) Handle(conn *fed.Conn) error { - if err := frontends.Authenticate(conn, nil); err != nil { +func (T *Module) Handle(ctx context.Context, conn *fed.Conn) error { + if err := frontends.Authenticate(ctx, conn, nil); err != nil { return err } - return T.pool.Serve(conn) + return T.pool.Serve(ctx, conn) } func (T *Module) ReadMetrics(metrics *metrics.Handler) { T.pool.ReadMetrics(&metrics.Pool) } -func (T *Module) Cancel(key fed.BackendKey) { - T.pool.Cancel(key) +func (T *Module) Cancel(ctx context.Context, key fed.BackendKey) { + T.pool.Cancel(ctx, key) } var _ gat.Handler = (*Module)(nil) diff --git a/lib/gat/server.go b/lib/gat/server.go index 2fabc77f73069ef0891c5cfeb75001ec12493126..ac87b910f3e1ad4b0e45a367b113e7f0965ad11c 100644 --- a/lib/gat/server.go +++ b/lib/gat/server.go @@ -1,6 +1,7 @@ package gat import ( + "context" "crypto/tls" "errors" "fmt" @@ -66,7 +67,7 @@ func (T *Server) Provision(ctx caddy.Context) error { return nil } -func (T *Server) Start() error { +func (T *Server) Start(ctx context.Context) error { for _, listener := range T.listen { if err := listener.Start(); err != nil { return err @@ -74,7 +75,7 @@ func (T *Server) Start() error { go func(listener *Listener) { for { - if !T.acceptFrom(listener) { + if !T.acceptFrom(ctx, listener) { break } } @@ -84,7 +85,7 @@ func (T *Server) Start() error { return nil } -func (T *Server) Stop() error { +func (T *Server) Stop(ctx context.Context) error { for _, listen := range T.listen { if err := listen.Stop(); err != nil { return err @@ -94,9 +95,9 @@ func (T *Server) Stop() error { return nil } -func (T *Server) Cancel(key fed.BackendKey) { +func (T *Server) Cancel(ctx context.Context, key fed.BackendKey) { for _, cancellableHandler := range T.cancellableHandlers { - cancellableHandler.Cancel(key) + cancellableHandler.Cancel(ctx, key) } } @@ -106,7 +107,7 @@ func (T *Server) ReadMetrics(m *metrics.Server) { } } -func (T *Server) Serve(conn *fed.Conn) { +func (T *Server) Serve(ctx context.Context, conn *fed.Conn) { for _, route := range T.routes { if route.match != nil && !route.match.Matches(conn) { continue @@ -115,7 +116,7 @@ func (T *Server) Serve(conn *fed.Conn) { if route.handle == nil { continue } - err := route.handle.Handle(conn) + err := route.handle.Handle(ctx,conn) if err != nil { if errors.Is(err, io.EOF) { // normal closure @@ -123,7 +124,7 @@ func (T *Server) Serve(conn *fed.Conn) { } errResp := perror.ToPacket(perror.Wrap(err)) - _ = conn.WritePacket(errResp) + _ = conn.WritePacket(ctx, errResp) return } } @@ -136,13 +137,13 @@ func (T *Server) Serve(conn *fed.Conn) { fmt.Sprintf(`Database "%s" not found`, conn.Database), ), ) - _ = conn.WritePacket(errResp) + _ = conn.WritePacket(ctx, errResp) T.log.Warn("database not found", zap.String("user", conn.User), zap.String("database", conn.Database)) } -func (T *Server) accept(listener *Listener, conn *fed.Conn) { +func (T *Server) accept(ctx context.Context, listener *Listener, conn *fed.Conn) { defer func() { - _ = conn.Close() + _ = conn.Close(ctx) }() var tlsConfig *tls.Config @@ -162,7 +163,7 @@ func (T *Server) accept(listener *Listener, conn *fed.Conn) { } if isCanceling { - T.Cancel(cancelKey) + T.Cancel(ctx, cancelKey) return } @@ -171,6 +172,7 @@ func (T *Server) accept(listener *Listener, conn *fed.Conn) { if listener.MaxConnections != 0 && int(count) > listener.MaxConnections { _ = conn.WritePacket( + ctx, perror.ToPacket(perror.New( perror.FATAL, perror.TooManyConnections, @@ -180,12 +182,12 @@ func (T *Server) accept(listener *Listener, conn *fed.Conn) { return } - T.Serve(conn) + T.Serve(ctx, conn) } -func (T *Server) acceptFrom(listener *Listener) bool { +func (T *Server) acceptFrom(ctx context.Context, listener *Listener) bool { err := listener.listener.Accept(func(c *fed.Conn) { - T.accept(listener, c) + T.accept(ctx, listener, c) }) if err != nil { if errors.Is(err, net.ErrClosed) {