From df87b7cfcbb74f7def6e809842cec828a51966df Mon Sep 17 00:00:00 2001 From: Tom Guinther <tguinther@gfxlabs.io> Date: Fri, 26 Jul 2024 11:49:51 -0400 Subject: [PATCH] add context --- lib/fed/listeners/netconnlistener/listener.go | 3 ++- lib/fed/middleware.go | 10 ++++++---- lib/fed/middlewares/eqp/client.go | 9 +++++---- lib/fed/middlewares/eqp/server.go | 9 +++++---- lib/fed/middlewares/ps/client.go | 9 +++++---- lib/fed/middlewares/ps/server.go | 9 +++++---- lib/fed/middlewares/unterminate/unterminate.go | 9 +++++---- lib/gat/handlers/pool/dialer.go | 5 +++-- lib/gat/handlers/pool/pools/hybrid/middleware.go | 9 +++++---- lib/gsql/pair.go | 3 +++ lib/gsql/query_test.go | 3 ++- 11 files changed, 46 insertions(+), 32 deletions(-) diff --git a/lib/fed/listeners/netconnlistener/listener.go b/lib/fed/listeners/netconnlistener/listener.go index 0d6c0c55..6ce55f35 100644 --- a/lib/fed/listeners/netconnlistener/listener.go +++ b/lib/fed/listeners/netconnlistener/listener.go @@ -1,6 +1,7 @@ package netconnlistener import ( + "context" "net" "gfx.cafe/gfx/pggat/lib/fed" @@ -16,7 +17,7 @@ func (listener *Listener) Accept(fn func(*fed.Conn)) error { if err != nil { return err } - fedConn := fed.NewConn(netconncodec.NewCodec(raw)) + fedConn := fed.NewConn(context.Background(), netconncodec.NewCodec(raw)) go func() { fn(fedConn) }() diff --git a/lib/fed/middleware.go b/lib/fed/middleware.go index 92e31bd8..1d9668c2 100644 --- a/lib/fed/middleware.go +++ b/lib/fed/middleware.go @@ -1,12 +1,14 @@ package fed +import "context" + // Middleware intercepts packets and possibly changes them. Return a 0 length packet to cancel. type Middleware interface { - PreRead(typed bool) (Packet, error) - ReadPacket(packet Packet) (Packet, error) + PreRead(ctx context.Context, typed bool) (Packet, error) + ReadPacket(ctx context.Context, packet Packet) (Packet, error) - WritePacket(packet Packet) (Packet, error) - PostWrite() (Packet, error) + WritePacket(ctx context.Context, packet Packet) (Packet, error) + PostWrite(ctx context.Context) (Packet, error) } func LookupMiddleware[T Middleware](conn *Conn) (T, bool) { diff --git a/lib/fed/middlewares/eqp/client.go b/lib/fed/middlewares/eqp/client.go index 652d7073..0476f1b9 100644 --- a/lib/fed/middlewares/eqp/client.go +++ b/lib/fed/middlewares/eqp/client.go @@ -1,6 +1,7 @@ package eqp import ( + "context" "gfx.cafe/gfx/pggat/lib/fed" ) @@ -12,19 +13,19 @@ func NewClient() *Client { return new(Client) } -func (T *Client) PreRead(_ bool) (fed.Packet, error) { +func (T *Client) PreRead(ctx context.Context, _ bool) (fed.Packet, error) { return nil, nil } -func (T *Client) ReadPacket(packet fed.Packet) (fed.Packet, error) { +func (T *Client) ReadPacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) { return T.state.C2S(packet) } -func (T *Client) WritePacket(packet fed.Packet) (fed.Packet, error) { +func (T *Client) WritePacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) { return T.state.S2C(packet) } -func (T *Client) PostWrite() (fed.Packet, error) { +func (T *Client) PostWrite(ctx context.Context) (fed.Packet, error) { return nil, nil } diff --git a/lib/fed/middlewares/eqp/server.go b/lib/fed/middlewares/eqp/server.go index f847a9fa..676fe166 100644 --- a/lib/fed/middlewares/eqp/server.go +++ b/lib/fed/middlewares/eqp/server.go @@ -1,6 +1,7 @@ package eqp import ( + "context" "gfx.cafe/gfx/pggat/lib/fed" ) @@ -12,19 +13,19 @@ func NewServer() *Server { return new(Server) } -func (T *Server) PreRead(_ bool) (fed.Packet, error) { +func (T *Server) PreRead(ctx context.Context, _ bool) (fed.Packet, error) { return nil, nil } -func (T *Server) ReadPacket(packet fed.Packet) (fed.Packet, error) { +func (T *Server) ReadPacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) { return T.state.S2C(packet) } -func (T *Server) WritePacket(packet fed.Packet) (fed.Packet, error) { +func (T *Server) WritePacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) { return T.state.C2S(packet) } -func (T *Server) PostWrite() (fed.Packet, error) { +func (T *Server) PostWrite(ctx context.Context, ) (fed.Packet, error) { return nil, nil } diff --git a/lib/fed/middlewares/ps/client.go b/lib/fed/middlewares/ps/client.go index 94b5014f..d7f383b6 100644 --- a/lib/fed/middlewares/ps/client.go +++ b/lib/fed/middlewares/ps/client.go @@ -1,6 +1,7 @@ package ps import ( + "context" "gfx.cafe/gfx/pggat/lib/fed" packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" "gfx.cafe/gfx/pggat/lib/util/maps" @@ -18,15 +19,15 @@ func NewClient(parameters map[strutil.CIString]string) *Client { } } -func (T *Client) PreRead(_ bool) (fed.Packet, error) { +func (T *Client) PreRead(ctx context.Context, _ bool) (fed.Packet, error) { return nil, nil } -func (T *Client) ReadPacket(packet fed.Packet) (fed.Packet, error) { +func (T *Client) ReadPacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) { return packet, nil } -func (T *Client) WritePacket(packet fed.Packet) (fed.Packet, error) { +func (T *Client) WritePacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) { switch packet.Type() { case packets.TypeParameterStatus: var p packets.ParameterStatus @@ -49,7 +50,7 @@ func (T *Client) WritePacket(packet fed.Packet) (fed.Packet, error) { } } -func (T *Client) PostWrite() (fed.Packet, error) { +func (T *Client) PostWrite(ctx context.Context) (fed.Packet, error) { return nil, nil } diff --git a/lib/fed/middlewares/ps/server.go b/lib/fed/middlewares/ps/server.go index 137b34df..2f922850 100644 --- a/lib/fed/middlewares/ps/server.go +++ b/lib/fed/middlewares/ps/server.go @@ -1,6 +1,7 @@ package ps import ( + "context" "gfx.cafe/gfx/pggat/lib/fed" packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" "gfx.cafe/gfx/pggat/lib/util/strutil" @@ -16,11 +17,11 @@ func NewServer(parameters map[strutil.CIString]string) *Server { } } -func (T *Server) PreRead(_ bool) (fed.Packet, error) { +func (T *Server) PreRead(ctx context.Context, _ bool) (fed.Packet, error) { return nil, nil } -func (T *Server) ReadPacket(packet fed.Packet) (fed.Packet, error) { +func (T *Server) ReadPacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) { switch packet.Type() { case packets.TypeParameterStatus: var p packets.ParameterStatus @@ -39,11 +40,11 @@ func (T *Server) ReadPacket(packet fed.Packet) (fed.Packet, error) { } } -func (T *Server) WritePacket(packet fed.Packet) (fed.Packet, error) { +func (T *Server) WritePacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) { return packet, nil } -func (T *Server) PostWrite() (fed.Packet, error) { +func (T *Server) PostWrite(ctx context.Context) (fed.Packet, error) { return nil, nil } diff --git a/lib/fed/middlewares/unterminate/unterminate.go b/lib/fed/middlewares/unterminate/unterminate.go index 073dbe10..633de635 100644 --- a/lib/fed/middlewares/unterminate/unterminate.go +++ b/lib/fed/middlewares/unterminate/unterminate.go @@ -1,6 +1,7 @@ package unterminate import ( + "context" "io" "gfx.cafe/gfx/pggat/lib/fed" @@ -13,22 +14,22 @@ var Unterminate = unterm{} type unterm struct{} -func (unterm) PreRead(_ bool) (fed.Packet, error) { +func (unterm) PreRead(ctx context.Context, _ bool) (fed.Packet, error) { return nil, nil } -func (unterm) ReadPacket(packet fed.Packet) (fed.Packet, error) { +func (unterm) ReadPacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) { if packet.Type() == packets.TypeTerminate { return packet, io.EOF } return packet, nil } -func (unterm) WritePacket(packet fed.Packet) (fed.Packet, error) { +func (unterm) WritePacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) { return packet, nil } -func (unterm) PostWrite() (fed.Packet, error) { +func (unterm) PostWrite(ctx context.Context, ) (fed.Packet, error) { return nil, nil } diff --git a/lib/gat/handlers/pool/dialer.go b/lib/gat/handlers/pool/dialer.go index 1da5d9fe..49ec224a 100644 --- a/lib/gat/handlers/pool/dialer.go +++ b/lib/gat/handlers/pool/dialer.go @@ -1,6 +1,7 @@ package pool import ( + "context" "crypto/tls" "encoding/json" "fmt" @@ -66,7 +67,7 @@ func (T *Dialer) Dial() (*fed.Conn, error) { if err != nil { return nil, err } - conn := fed.NewConn(netconncodec.NewCodec(c)) + conn := fed.NewConn(context.Background(), netconncodec.NewCodec(c)) conn.User = T.Username conn.Database = T.Database err = backends.Accept( @@ -90,7 +91,7 @@ func (T *Dialer) Cancel(key fed.BackendKey) { if err != nil { return } - conn := fed.NewConn(netconncodec.NewCodec(c)) + conn := fed.NewConn(context.Background(), netconncodec.NewCodec(c)) defer func() { _ = conn.Close() }() diff --git a/lib/gat/handlers/pool/pools/hybrid/middleware.go b/lib/gat/handlers/pool/pools/hybrid/middleware.go index 230ab054..ac6ff340 100644 --- a/lib/gat/handlers/pool/pools/hybrid/middleware.go +++ b/lib/gat/handlers/pool/pools/hybrid/middleware.go @@ -1,6 +1,7 @@ package hybrid import ( + "context" "gfx.cafe/gfx/pggat/lib/fed" packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" "gfx.cafe/gfx/pggat/lib/perror" @@ -20,7 +21,7 @@ func NewMiddleware() *Middleware { return m } -func (T *Middleware) PreRead(typed bool) (fed.Packet, error) { +func (T *Middleware) PreRead(ctx context.Context, typed bool) (fed.Packet, error) { if !T.primary { return nil, nil } @@ -37,7 +38,7 @@ func (T *Middleware) PreRead(typed bool) (fed.Packet, error) { }, nil } -func (T *Middleware) ReadPacket(packet fed.Packet) (fed.Packet, error) { +func (T *Middleware) ReadPacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) { if T.primary { return packet, nil } @@ -60,7 +61,7 @@ func (T *Middleware) ReadPacket(packet fed.Packet) (fed.Packet, error) { return p, nil } -func (T *Middleware) WritePacket(packet fed.Packet) (fed.Packet, error) { +func (T *Middleware) WritePacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) { if T.primary && (T.buf.Buffered() > 0 || T.bufDec.Buffered() > 0) { return nil, nil } @@ -84,7 +85,7 @@ func (T *Middleware) WritePacket(packet fed.Packet) (fed.Packet, error) { return packet, nil } -func (T *Middleware) PostWrite() (fed.Packet, error) { +func (T *Middleware) PostWrite(ctx context.Context) (fed.Packet, error) { return nil, nil } diff --git a/lib/gsql/pair.go b/lib/gsql/pair.go index 2fcaff95..857f8fe1 100644 --- a/lib/gsql/pair.go +++ b/lib/gsql/pair.go @@ -1,6 +1,7 @@ package gsql import ( + "context" "net" "gfx.cafe/gfx/pggat/lib/fed" @@ -13,12 +14,14 @@ func NewPair() (*fed.Conn, *fed.Conn, net.Conn, net.Conn) { in := mio.InwardConn{Conn: conn} out := mio.OutwardConn{Conn: conn} inward := fed.NewConn( + context.Background(), netconncodec.NewCodec( in, ), ) inward.Ready = true outward := fed.NewConn( + context.Background(), netconncodec.NewCodec( out, ), diff --git a/lib/gsql/query_test.go b/lib/gsql/query_test.go index 476748de..0859a6ba 100644 --- a/lib/gsql/query_test.go +++ b/lib/gsql/query_test.go @@ -1,6 +1,7 @@ package gsql_test import ( + "context" "crypto/tls" "log" "net" @@ -32,7 +33,7 @@ func TestQuery(t *testing.T) { t.Error(err) return } - server := fed.NewConn(netconncodec.NewCodec(s)) + server := fed.NewConn(context.Background(), netconncodec.NewCodec(s)) err = backends.Accept( server, "disable", -- GitLab