From 9a8381a9bb64a732f87c7829cdf7a016b713a087 Mon Sep 17 00:00:00 2001 From: Tom Guinther <tguinther@gfxlabs.io> Date: Tue, 13 Aug 2024 17:10:38 -0400 Subject: [PATCH] more ctx params --- lib/fed/middlewares/eqp/sync.go | 16 ++++++++-------- lib/gat/app.go | 5 +++-- lib/gat/handler.go | 2 +- lib/gat/handlers/pool/module.go | 11 ++++++----- lib/gat/server.go | 32 +++++++++++++++++--------------- 5 files changed, 35 insertions(+), 31 deletions(-) diff --git a/lib/fed/middlewares/eqp/sync.go b/lib/fed/middlewares/eqp/sync.go index fbbe8f57..abf345aa 100644 --- a/lib/fed/middlewares/eqp/sync.go +++ b/lib/fed/middlewares/eqp/sync.go @@ -20,7 +20,7 @@ func preparedStatementsEqual(a, b *packets.Parse) bool { return true } -func SyncMiddleware(ctx context.Context,c *Client, server *fed.Conn) error { +func SyncMiddleware(ctx context.Context, c *Client, server *fed.Conn) error { s, ok := fed.LookupMiddleware[*Server](server) if !ok { panic("middleware not found") @@ -36,7 +36,7 @@ func SyncMiddleware(ctx context.Context,c *Client, server *fed.Conn) error { Which: 'P', Name: name, } - if err := server.WritePacket(ctx,&p); err != nil { + if err := server.WritePacket(ctx, &p); err != nil { return err } @@ -60,7 +60,7 @@ func SyncMiddleware(ctx context.Context,c *Client, server *fed.Conn) error { Which: 'S', Name: name, } - if err := server.WritePacket(ctx,&p); err != nil { + if err := server.WritePacket(ctx, &p); err != nil { return err } @@ -75,7 +75,7 @@ func SyncMiddleware(ctx context.Context,c *Client, server *fed.Conn) error { } } - if err := server.WritePacket(ctx,preparedStatement); err != nil { + if err := server.WritePacket(ctx, preparedStatement); err != nil { return err } @@ -84,7 +84,7 @@ func SyncMiddleware(ctx context.Context,c *Client, server *fed.Conn) error { // bind all portals for _, portal := range c.state.portals { - if err := server.WritePacket(ctx,portal); err != nil { + if err := server.WritePacket(ctx, portal); err != nil { return err } @@ -93,18 +93,18 @@ func SyncMiddleware(ctx context.Context,c *Client, server *fed.Conn) error { if needsBackendSync { var err error - err, _ = backends.Sync(server, nil) + err, _ = backends.Sync(ctx, server, nil) return err } return nil } -func Sync(ctx context.Context,client, server *fed.Conn) error { +func Sync(ctx context.Context, client, server *fed.Conn) error { c, ok := fed.LookupMiddleware[*Client](client) if !ok { panic("middleware not found") } - return SyncMiddleware(ctx,c, server) + return SyncMiddleware(ctx, c, server) } diff --git a/lib/gat/app.go b/lib/gat/app.go index b9126960..2c6abbad 100644 --- a/lib/gat/app.go +++ b/lib/gat/app.go @@ -1,6 +1,7 @@ package gat import ( + "context" "time" "github.com/caddyserver/caddy/v2" @@ -80,7 +81,7 @@ func (T *App) Start() error { } for _, server := range T.servers { - if err := server.Start(); err != nil { + if err := server.Start(context.Background()); err != nil { return err } } @@ -92,7 +93,7 @@ func (T *App) Stop() error { close(T.closed) for _, server := range T.servers { - if err := server.Stop(); err != nil { + if err := server.Stop(context.Background()); err != nil { return err } } diff --git a/lib/gat/handler.go b/lib/gat/handler.go index f6519e35..682cf657 100644 --- a/lib/gat/handler.go +++ b/lib/gat/handler.go @@ -10,7 +10,7 @@ import ( type Handler interface { // Handle will attempt to handle the Conn. Return io.EOF for normal disconnection or nil to continue to the next // handle. The error will be relayed to the client so there is no need to send it yourself. - Handle(conn *fed.Conn) error + Handle(ctx context.Context, conn *fed.Conn) error } type CancellableHandler interface { diff --git a/lib/gat/handlers/pool/module.go b/lib/gat/handlers/pool/module.go index ee60787c..795201d5 100644 --- a/lib/gat/handlers/pool/module.go +++ b/lib/gat/handlers/pool/module.go @@ -1,6 +1,7 @@ package pool import ( + "context" "encoding/json" "github.com/caddyserver/caddy/v2" @@ -49,20 +50,20 @@ func (T *Module) Provision(ctx caddy.Context) error { return nil } -func (T *Module) Handle(conn *fed.Conn) error { - if err := frontends.Authenticate(conn, nil); err != nil { +func (T *Module) Handle(ctx context.Context, conn *fed.Conn) error { + if err := frontends.Authenticate(ctx, conn, nil); err != nil { return err } - return T.pool.Serve(conn) + return T.pool.Serve(ctx, conn) } func (T *Module) ReadMetrics(metrics *metrics.Handler) { T.pool.ReadMetrics(&metrics.Pool) } -func (T *Module) Cancel(key fed.BackendKey) { - T.pool.Cancel(key) +func (T *Module) Cancel(ctx context.Context, key fed.BackendKey) { + T.pool.Cancel(ctx, key) } var _ gat.Handler = (*Module)(nil) diff --git a/lib/gat/server.go b/lib/gat/server.go index 2fabc77f..ac87b910 100644 --- a/lib/gat/server.go +++ b/lib/gat/server.go @@ -1,6 +1,7 @@ package gat import ( + "context" "crypto/tls" "errors" "fmt" @@ -66,7 +67,7 @@ func (T *Server) Provision(ctx caddy.Context) error { return nil } -func (T *Server) Start() error { +func (T *Server) Start(ctx context.Context) error { for _, listener := range T.listen { if err := listener.Start(); err != nil { return err @@ -74,7 +75,7 @@ func (T *Server) Start() error { go func(listener *Listener) { for { - if !T.acceptFrom(listener) { + if !T.acceptFrom(ctx, listener) { break } } @@ -84,7 +85,7 @@ func (T *Server) Start() error { return nil } -func (T *Server) Stop() error { +func (T *Server) Stop(ctx context.Context) error { for _, listen := range T.listen { if err := listen.Stop(); err != nil { return err @@ -94,9 +95,9 @@ func (T *Server) Stop() error { return nil } -func (T *Server) Cancel(key fed.BackendKey) { +func (T *Server) Cancel(ctx context.Context, key fed.BackendKey) { for _, cancellableHandler := range T.cancellableHandlers { - cancellableHandler.Cancel(key) + cancellableHandler.Cancel(ctx, key) } } @@ -106,7 +107,7 @@ func (T *Server) ReadMetrics(m *metrics.Server) { } } -func (T *Server) Serve(conn *fed.Conn) { +func (T *Server) Serve(ctx context.Context, conn *fed.Conn) { for _, route := range T.routes { if route.match != nil && !route.match.Matches(conn) { continue @@ -115,7 +116,7 @@ func (T *Server) Serve(conn *fed.Conn) { if route.handle == nil { continue } - err := route.handle.Handle(conn) + err := route.handle.Handle(ctx,conn) if err != nil { if errors.Is(err, io.EOF) { // normal closure @@ -123,7 +124,7 @@ func (T *Server) Serve(conn *fed.Conn) { } errResp := perror.ToPacket(perror.Wrap(err)) - _ = conn.WritePacket(errResp) + _ = conn.WritePacket(ctx, errResp) return } } @@ -136,13 +137,13 @@ func (T *Server) Serve(conn *fed.Conn) { fmt.Sprintf(`Database "%s" not found`, conn.Database), ), ) - _ = conn.WritePacket(errResp) + _ = conn.WritePacket(ctx, errResp) T.log.Warn("database not found", zap.String("user", conn.User), zap.String("database", conn.Database)) } -func (T *Server) accept(listener *Listener, conn *fed.Conn) { +func (T *Server) accept(ctx context.Context, listener *Listener, conn *fed.Conn) { defer func() { - _ = conn.Close() + _ = conn.Close(ctx) }() var tlsConfig *tls.Config @@ -162,7 +163,7 @@ func (T *Server) accept(listener *Listener, conn *fed.Conn) { } if isCanceling { - T.Cancel(cancelKey) + T.Cancel(ctx, cancelKey) return } @@ -171,6 +172,7 @@ func (T *Server) accept(listener *Listener, conn *fed.Conn) { if listener.MaxConnections != 0 && int(count) > listener.MaxConnections { _ = conn.WritePacket( + ctx, perror.ToPacket(perror.New( perror.FATAL, perror.TooManyConnections, @@ -180,12 +182,12 @@ func (T *Server) accept(listener *Listener, conn *fed.Conn) { return } - T.Serve(conn) + T.Serve(ctx, conn) } -func (T *Server) acceptFrom(listener *Listener) bool { +func (T *Server) acceptFrom(ctx context.Context, listener *Listener) bool { err := listener.listener.Accept(func(c *fed.Conn) { - T.accept(listener, c) + T.accept(ctx, listener, c) }) if err != nil { if errors.Is(err, net.ErrClosed) { -- GitLab