From d5b7cc348a1dcdfe38683daabfa69ff1bf158102 Mon Sep 17 00:00:00 2001 From: Tom Guinther <tguinther@gfxlabs.io> Date: Tue, 13 Aug 2024 16:57:13 -0400 Subject: [PATCH] adding context as 1st param to most significant funcs re-factoring things that have context as or in the name move some type defs closer to code that uses them (deleting type file) --- lib/bouncer/backends/v0/accept.go | 99 +++++----- lib/bouncer/backends/v0/cancel.go | 5 +- lib/bouncer/backends/v0/context.go | 32 ++-- lib/bouncer/backends/v0/query.go | 178 +++++++++--------- lib/bouncer/bouncers/v2/bouncer.go | 5 +- lib/bouncer/frontends/v0/accept.go | 86 +++++---- lib/bouncer/frontends/v0/authenticate.go | 65 ++++--- lib/bouncer/frontends/v0/context.go | 13 -- lib/bouncer/frontends/v0/options.go | 2 +- lib/bouncer/frontends/v0/params.go | 8 - lib/fed/codecs/netconncodec/codec.go | 19 +- lib/fed/conn.go | 53 +++--- lib/fed/interface.go | 20 -- lib/fed/middleware.go | 10 +- lib/fed/middlewares/eqp/client.go | 11 +- lib/fed/middlewares/eqp/server.go | 9 +- lib/fed/middlewares/eqp/sync.go | 15 +- lib/fed/middlewares/ps/client.go | 11 +- lib/fed/middlewares/ps/server.go | 9 +- lib/fed/middlewares/ps/sync.go | 21 ++- .../middlewares/unterminate/unterminate.go | 9 +- lib/fed/packetCodec.go | 21 +++ lib/gat/handler.go | 3 +- .../handlers/pool/critics/latency/critic.go | 5 +- lib/gat/handlers/pool/dialer.go | 10 +- lib/gat/handlers/pool/penalty.go | 3 +- lib/gat/handlers/pool/pool.go | 7 +- lib/gat/handlers/pool/pools/basic/pool.go | 31 +-- .../handlers/pool/pools/hybrid/middleware.go | 9 +- lib/gat/handlers/pool/pools/hybrid/pool.go | 81 ++++---- lib/gat/handlers/pool/spool/pool.go | 5 +- lib/gat/server.go | 2 +- lib/gsql/eq.go | 19 +- lib/gsql/query.go | 13 +- lib/gsql/query_test.go | 9 +- lib/gsql/row.go | 5 +- 36 files changed, 469 insertions(+), 434 deletions(-) delete mode 100644 lib/bouncer/frontends/v0/context.go delete mode 100644 lib/bouncer/frontends/v0/params.go delete mode 100644 lib/fed/interface.go create mode 100644 lib/fed/packetCodec.go diff --git a/lib/bouncer/backends/v0/accept.go b/lib/bouncer/backends/v0/accept.go index 086de2e0..fbea82af 100644 --- a/lib/bouncer/backends/v0/accept.go +++ b/lib/bouncer/backends/v0/accept.go @@ -1,6 +1,7 @@ package backends import ( + "context" "crypto/tls" "errors" "io" @@ -13,9 +14,14 @@ import ( "gfx.cafe/gfx/pggat/lib/util/strutil" ) -func authenticationSASLChallenge(ctx *acceptContext, encoder auth.SASLEncoder) (done bool, err error) { +type acceptParams struct { + Conn *fed.Conn + Options acceptOptions +} + +func authenticationSASLChallenge(ctx context.Context, params *acceptParams, encoder auth.SASLEncoder) (done bool, err error) { var packet fed.Packet - packet, err = ctx.Conn.ReadPacket(true) + packet, err = params.Conn.ReadPacket(ctx, true) if err != nil { return } @@ -41,7 +47,7 @@ func authenticationSASLChallenge(ctx *acceptContext, encoder auth.SASLEncoder) ( } resp := packets.SASLResponse(response) - err = ctx.Conn.WritePacket(&resp) + err = params.Conn.WritePacket(ctx, &resp) return case *packets.AuthenticationPayloadSASLFinal: // finish @@ -60,7 +66,7 @@ func authenticationSASLChallenge(ctx *acceptContext, encoder auth.SASLEncoder) ( } } -func authenticationSASL(ctx *acceptContext, mechanisms []string, creds auth.SASLClient) error { +func authenticationSASL(ctx context.Context, params *acceptParams, mechanisms []string, creds auth.SASLClient) error { mechanism, encoder, err := creds.EncodeSASL(mechanisms) if err != nil { return err @@ -74,7 +80,7 @@ func authenticationSASL(ctx *acceptContext, mechanisms []string, creds auth.SASL Mechanism: mechanism, InitialClientResponse: initialResponse, } - err = ctx.Conn.WritePacket(&saslInitialResponse) + err = params.Conn.WritePacket(ctx, &saslInitialResponse) if err != nil { return err } @@ -82,7 +88,7 @@ func authenticationSASL(ctx *acceptContext, mechanisms []string, creds auth.SASL // challenge loop for { var done bool - done, err = authenticationSASLChallenge(ctx, encoder) + done, err = authenticationSASLChallenge(ctx, params, encoder) if err != nil { return err } @@ -94,46 +100,46 @@ func authenticationSASL(ctx *acceptContext, mechanisms []string, creds auth.SASL return nil } -func authenticationMD5(ctx *acceptContext, salt [4]byte, creds auth.MD5Client) error { +func authenticationMD5(ctx context.Context, params *acceptParams, salt [4]byte, creds auth.MD5Client) error { pw := packets.PasswordMessage(creds.EncodeMD5(salt)) - err := ctx.Conn.WritePacket(&pw) + err := params.Conn.WritePacket(ctx, &pw) if err != nil { return err } return nil } -func authenticationCleartext(ctx *acceptContext, creds auth.CleartextClient) error { +func authenticationCleartext(ctx context.Context, params *acceptParams, creds auth.CleartextClient) error { pw := packets.PasswordMessage(creds.EncodeCleartext()) - err := ctx.Conn.WritePacket(&pw) + err := params.Conn.WritePacket(ctx, &pw) if err != nil { return err } return nil } -func authentication(ctx *acceptContext, p *packets.Authentication) (done bool, err error) { +func authentication(ctx context.Context, params *acceptParams, p *packets.Authentication) (done bool, err error) { // they have more authentication methods than there are pokemon switch mode := p.Mode.(type) { case *packets.AuthenticationPayloadOk: // we're good to go, that was easy - ctx.Conn.Authenticated = true + params.Conn.Authenticated = true return true, nil case *packets.AuthenticationPayloadKerberosV5: err = errors.New("kerberos v5 is not supported") return case *packets.AuthenticationPayloadCleartextPassword: - c, ok := ctx.Options.Credentials.(auth.CleartextClient) + c, ok := params.Options.Credentials.(auth.CleartextClient) if !ok { return false, auth.ErrMethodNotSupported } - return false, authenticationCleartext(ctx, c) + return false, authenticationCleartext(ctx, params, c) case *packets.AuthenticationPayloadMD5Password: - c, ok := ctx.Options.Credentials.(auth.MD5Client) + c, ok := params.Options.Credentials.(auth.MD5Client) if !ok { return false, auth.ErrMethodNotSupported } - return false, authenticationMD5(ctx, *mode, c) + return false, authenticationMD5(ctx, params, *mode, c) case *packets.AuthenticationPayloadGSS: err = errors.New("gss is not supported") return @@ -141,7 +147,7 @@ func authentication(ctx *acceptContext, p *packets.Authentication) (done bool, e err = errors.New("sspi is not supported") return case *packets.AuthenticationPayloadSASL: - c, ok := ctx.Options.Credentials.(auth.SASLClient) + c, ok := params.Options.Credentials.(auth.SASLClient) if !ok { return false, auth.ErrMethodNotSupported } @@ -151,16 +157,16 @@ func authentication(ctx *acceptContext, p *packets.Authentication) (done bool, e mechanisms = append(mechanisms, m.Method) } - return false, authenticationSASL(ctx, mechanisms, c) + return false, authenticationSASL(ctx, params, mechanisms, c) default: err = errors.New("unknown authentication method") return } } -func startup0(ctx *acceptContext) (done bool, err error) { +func startup0(ctx context.Context, params *acceptParams) (done bool, err error) { var packet fed.Packet - packet, err = ctx.Conn.ReadPacket(true) + packet, err = params.Conn.ReadPacket(ctx, true) if err != nil { return } @@ -180,7 +186,7 @@ func startup0(ctx *acceptContext) (done bool, err error) { if err != nil { return } - return authentication(ctx, &p) + return authentication(ctx, params, &p) case packets.TypeNegotiateProtocolVersion: // we only support protocol 3.0 for now err = errors.New("server wanted to negotiate protocol version") @@ -191,9 +197,9 @@ func startup0(ctx *acceptContext) (done bool, err error) { } } -func startup1(ctx *acceptContext) (done bool, err error) { +func startup1(ctx context.Context, params *acceptParams) (done bool, err error) { var packet fed.Packet - packet, err = ctx.Conn.ReadPacket(true) + packet, err = params.Conn.ReadPacket(ctx, true) if err != nil { return } @@ -205,8 +211,8 @@ func startup1(ctx *acceptContext) (done bool, err error) { if err != nil { return } - ctx.Conn.BackendKey.SecretKey = p.SecretKey - ctx.Conn.BackendKey.ProcessID = p.ProcessID + params.Conn.BackendKey.SecretKey = p.SecretKey + params.Conn.BackendKey.ProcessID = p.ProcessID return false, nil case packets.TypeParameterStatus: @@ -216,10 +222,10 @@ func startup1(ctx *acceptContext) (done bool, err error) { return } ikey := strutil.MakeCIString(p.Key) - if ctx.Conn.InitialParameters == nil { - ctx.Conn.InitialParameters = make(map[strutil.CIString]string) + if params.Conn.InitialParameters == nil { + params.Conn.InitialParameters = make(map[strutil.CIString]string) } - ctx.Conn.InitialParameters[ikey] = p.Value + params.Conn.InitialParameters[ikey] = p.Value return false, nil case packets.TypeReadyForQuery: return true, nil @@ -240,19 +246,19 @@ func startup1(ctx *acceptContext) (done bool, err error) { } } -func enableSSL(ctx *acceptContext) (bool, error) { +func enableSSL(ctx context.Context, params *acceptParams) (bool, error) { p := packets.Startup{ Mode: &packets.StartupPayloadControl{ Mode: &packets.StartupPayloadControlPayloadSSL{}, }, } - if err := ctx.Conn.WritePacket(&p); err != nil { + if err := params.Conn.WritePacket(ctx, &p); err != nil { return false, err } // read byte to see if ssl is allowed - yn, err := ctx.Conn.ReadByte() + yn, err := params.Conn.ReadByte(ctx) if err != nil { return false, err } @@ -262,26 +268,26 @@ func enableSSL(ctx *acceptContext) (bool, error) { return false, nil } - if err = ctx.Conn.EnableSSL(ctx.Options.SSLConfig, true); err != nil { + if err = params.Conn.EnableSSL(ctx, params.Options.SSLConfig, true); err != nil { return false, err } return true, nil } -func accept(ctx *acceptContext) error { - username := ctx.Options.Username +func accept(ctx context.Context, params *acceptParams) error { + username := params.Options.Username - if ctx.Options.Database == "" { - ctx.Options.Database = username + if params.Options.Database == "" { + params.Options.Database = username } - if ctx.Options.SSLMode.ShouldAttempt() { - sslEnabled, err := enableSSL(ctx) + if params.Options.SSLMode.ShouldAttempt() { + sslEnabled, err := enableSSL(ctx, params) if err != nil { return err } - if !sslEnabled && ctx.Options.SSLMode.IsRequired() { + if !sslEnabled && params.Options.SSLMode.IsRequired() { return errors.New("server rejected SSL encryption") } } @@ -295,12 +301,12 @@ func accept(ctx *acceptContext) error { }, { Key: "database", - Value: ctx.Options.Database, + Value: params.Options.Database, }, }, } - for key, value := range ctx.Options.StartupParameters { + for key, value := range params.Options.StartupParameters { m.Parameters = append(m.Parameters, packets.StartupPayloadVersion3PayloadParameter{ Key: key.String(), Value: value, @@ -311,14 +317,14 @@ func accept(ctx *acceptContext) error { Mode: &m, } - err := ctx.Conn.WritePacket(&p) + err := params.Conn.WritePacket(ctx, &p) if err != nil { return err } for { var done bool - done, err = startup0(ctx) + done, err = startup0(ctx, params) if err != nil { return err } @@ -329,7 +335,7 @@ func accept(ctx *acceptContext) error { for { var done bool - done, err = startup1(ctx) + done, err = startup1(ctx, params) if err != nil { return err } @@ -343,6 +349,7 @@ func accept(ctx *acceptContext) error { } func Accept( + ctx context.Context, conn *fed.Conn, sslMode bouncer.SSLMode, sslConfig *tls.Config, @@ -351,7 +358,7 @@ func Accept( database string, startupParameters map[strutil.CIString]string, ) error { - ctx := acceptContext{ + params := acceptParams{ Conn: conn, Options: acceptOptions{ SSLMode: sslMode, @@ -362,5 +369,5 @@ func Accept( StartupParameters: startupParameters, }, } - return accept(&ctx) + return accept(ctx, ¶ms) } diff --git a/lib/bouncer/backends/v0/cancel.go b/lib/bouncer/backends/v0/cancel.go index 48ee2b66..ab665da0 100644 --- a/lib/bouncer/backends/v0/cancel.go +++ b/lib/bouncer/backends/v0/cancel.go @@ -1,11 +1,12 @@ package backends import ( + "context" "gfx.cafe/gfx/pggat/lib/fed" packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" ) -func Cancel(server *fed.Conn, key fed.BackendKey) error { +func Cancel(ctx context.Context, server *fed.Conn, key fed.BackendKey) error { p := packets.Startup{ Mode: &packets.StartupPayloadControl{ Mode: &packets.StartupPayloadControlPayloadCancel{ @@ -14,5 +15,5 @@ func Cancel(server *fed.Conn, key fed.BackendKey) error { }, }, } - return server.WritePacket(&p) + return server.WritePacket(ctx, &p) } diff --git a/lib/bouncer/backends/v0/context.go b/lib/bouncer/backends/v0/context.go index 8288ab18..ab1d35ec 100644 --- a/lib/bouncer/backends/v0/context.go +++ b/lib/bouncer/backends/v0/context.go @@ -1,44 +1,40 @@ package backends import ( + "context" "gfx.cafe/gfx/pggat/lib/fed" ) -type acceptContext struct { - Conn *fed.Conn - Options acceptOptions -} - -type context struct { +type serverToPeerBinding struct { Server *fed.Conn - Packet fed.Packet Peer *fed.Conn + Packet fed.Packet PeerError error TxState byte } -func (T *context) ErrUnexpectedPacket() error { +func (T *serverToPeerBinding) ErrUnexpectedPacket() error { return ErrUnexpectedPacket(T.Packet.Type()) } -func (T *context) ServerRead() error { +func (T *serverToPeerBinding) ServerRead(ctx context.Context) error { var err error - T.Packet, err = T.Server.ReadPacket(true) + T.Packet, err = T.Server.ReadPacket(ctx, true) return err } -func (T *context) ServerWrite() error { - return T.Server.WritePacket(T.Packet) +func (T *serverToPeerBinding) ServerWrite(ctx context.Context) error { + return T.Server.WritePacket(ctx, T.Packet) } -func (T *context) PeerOK() bool { +func (T *serverToPeerBinding) PeerOK() bool { if T == nil { return false } return T.Peer != nil && T.PeerError == nil } -func (T *context) PeerFail(err error) { +func (T *serverToPeerBinding) PeerFail(err error) { if T == nil { return } @@ -46,7 +42,7 @@ func (T *context) PeerFail(err error) { T.PeerError = err } -func (T *context) PeerRead() bool { +func (T *serverToPeerBinding) PeerRead(ctx context.Context) bool { if T == nil { return false } @@ -54,7 +50,7 @@ func (T *context) PeerRead() bool { return false } var err error - T.Packet, err = T.Peer.ReadPacket(true) + T.Packet, err = T.Peer.ReadPacket(ctx, true) if err != nil { T.PeerFail(err) return false @@ -62,14 +58,14 @@ func (T *context) PeerRead() bool { return true } -func (T *context) PeerWrite() { +func (T *serverToPeerBinding) PeerWrite(ctx context.Context) { if T == nil { return } if !T.PeerOK() { return } - err := T.Peer.WritePacket(T.Packet) + err := T.Peer.WritePacket(ctx, T.Packet) if err != nil { T.PeerFail(err) } diff --git a/lib/bouncer/backends/v0/query.go b/lib/bouncer/backends/v0/query.go index ebd466d2..7d187bd4 100644 --- a/lib/bouncer/backends/v0/query.go +++ b/lib/bouncer/backends/v0/query.go @@ -1,6 +1,7 @@ package backends import ( + "context" "strings" "gfx.cafe/gfx/pggat/lib/fed" @@ -8,65 +9,65 @@ import ( "gfx.cafe/gfx/pggat/lib/util/strutil" ) -func copyIn(ctx *context) error { - ctx.PeerWrite() +func copyIn(ctx context.Context, binding *serverToPeerBinding) error { + binding.PeerWrite(ctx) for { - if !ctx.PeerRead() { + if !binding.PeerRead(ctx) { copyFail := packets.CopyFail("peer failed") - ctx.Packet = ©Fail - return ctx.ServerWrite() + binding.Packet = ©Fail + return binding.ServerWrite(ctx) } - switch ctx.Packet.Type() { + switch binding.Packet.Type() { case packets.TypeCopyData: - if err := ctx.ServerWrite(); err != nil { + if err := binding.ServerWrite(ctx); err != nil { return err } case packets.TypeCopyDone, packets.TypeCopyFail: - return ctx.ServerWrite() + return binding.ServerWrite(ctx) default: - ctx.PeerFail(ctx.ErrUnexpectedPacket()) + binding.PeerFail(binding.ErrUnexpectedPacket()) } } } -func copyOut(ctx *context) error { - ctx.PeerWrite() +func copyOut(ctx context.Context, binding *serverToPeerBinding) error { + binding.PeerWrite(ctx) for { - err := ctx.ServerRead() + err := binding.ServerRead(ctx) if err != nil { return err } - switch ctx.Packet.Type() { + switch binding.Packet.Type() { case packets.TypeCopyData, packets.TypeNoticeResponse, packets.TypeParameterStatus, packets.TypeNotificationResponse: - ctx.PeerWrite() + binding.PeerWrite(ctx) case packets.TypeCopyDone, packets.TypeErrorResponse: - ctx.PeerWrite() + binding.PeerWrite(ctx) return nil default: - return ctx.ErrUnexpectedPacket() + return binding.ErrUnexpectedPacket() } } } -func query(ctx *context) error { - if err := ctx.ServerWrite(); err != nil { +func query(ctx context.Context, binding *serverToPeerBinding) error { + if err := binding.ServerWrite(ctx); err != nil { return err } for { - err := ctx.ServerRead() + err := binding.ServerRead(ctx) if err != nil { return err } - switch ctx.Packet.Type() { + switch binding.Packet.Type() { case packets.TypeCommandComplete, packets.TypeRowDescription, packets.TypeDataRow, @@ -75,48 +76,48 @@ func query(ctx *context) error { packets.TypeNoticeResponse, packets.TypeParameterStatus, packets.TypeNotificationResponse: - ctx.PeerWrite() + binding.PeerWrite(ctx) case packets.TypeCopyInResponse: - if err = copyIn(ctx); err != nil { + if err = copyIn(ctx, binding); err != nil { return err } case packets.TypeCopyOutResponse: - if err = copyOut(ctx); err != nil { + if err = copyOut(ctx, binding); err != nil { return err } case packets.TypeReadyForQuery: var p packets.ReadyForQuery - err = fed.ToConcrete(&p, ctx.Packet) + err = fed.ToConcrete(&p, binding.Packet) if err != nil { return err } - ctx.Packet = &p - ctx.TxState = byte(p) - ctx.PeerWrite() + binding.Packet = &p + binding.TxState = byte(p) + binding.PeerWrite(ctx) return nil default: - return ctx.ErrUnexpectedPacket() + return binding.ErrUnexpectedPacket() } } } -func queryString(ctx *context, q string) error { +func queryString(ctx context.Context, binding *serverToPeerBinding, q string) error { qq := packets.Query(q) - ctx.Packet = &qq - return query(ctx) + binding.Packet = &qq + return query(ctx, binding) } -func QueryString(server, peer *fed.Conn, query string) (err, peerError error) { - ctx := context{ +func QueryString(ctx context.Context, server, peer *fed.Conn, query string) (err, peerError error) { + binding := serverToPeerBinding{ Server: server, Peer: peer, } - err = queryString(&ctx, query) - peerError = ctx.PeerError + err = queryString(ctx, &binding, query) + peerError = binding.PeerError return } -func SetParameter(server, peer *fed.Conn, name strutil.CIString, value string) (err, peerError error) { +func SetParameter(ctx context.Context, server, peer *fed.Conn, name strutil.CIString, value string) (err, peerError error) { var q strings.Builder escapedName := strutil.Escape(name.String(), '"') escapedValue := strutil.Escape(value, '\'') @@ -128,58 +129,59 @@ func SetParameter(server, peer *fed.Conn, name strutil.CIString, value string) ( q.WriteString(`'`) return QueryString( + ctx, server, peer, q.String(), ) } -func functionCall(ctx *context) error { - if err := ctx.ServerWrite(); err != nil { +func functionCall(ctx context.Context, binding *serverToPeerBinding) error { + if err := binding.ServerWrite(ctx); err != nil { return err } for { - err := ctx.ServerRead() + err := binding.ServerRead(ctx) if err != nil { return err } - switch ctx.Packet.Type() { + switch binding.Packet.Type() { case packets.TypeErrorResponse, packets.TypeFunctionCallResponse, packets.TypeNoticeResponse, packets.TypeParameterStatus, packets.TypeNotificationResponse: - ctx.PeerWrite() + binding.PeerWrite(ctx) case packets.TypeReadyForQuery: var p packets.ReadyForQuery - err = fed.ToConcrete(&p, ctx.Packet) + err = fed.ToConcrete(&p, binding.Packet) if err != nil { return err } - ctx.Packet = &p - ctx.TxState = byte(p) - ctx.PeerWrite() + binding.Packet = &p + binding.TxState = byte(p) + binding.PeerWrite(ctx) return nil default: - return ctx.ErrUnexpectedPacket() + return binding.ErrUnexpectedPacket() } } } -func sync(ctx *context) (bool, error) { - if err := ctx.ServerWrite(); err != nil { +func sync(ctx context.Context, binding *serverToPeerBinding) (bool, error) { + if err := binding.ServerWrite(ctx); err != nil { return false, err } for { - err := ctx.ServerRead() + err := binding.ServerRead(ctx) if err != nil { return false, err } - switch ctx.Packet.Type() { + switch binding.Packet.Type() { case packets.TypeParseComplete, packets.TypeBindComplete, packets.TypeCloseComplete, @@ -196,54 +198,54 @@ func sync(ctx *context) (bool, error) { packets.TypeNoticeResponse, packets.TypeParameterStatus, packets.TypeNotificationResponse: - ctx.PeerWrite() + binding.PeerWrite(ctx) case packets.TypeCopyInResponse: - if err = copyIn(ctx); err != nil { + if err = copyIn(ctx, binding); err != nil { return false, err } // why return false, nil case packets.TypeCopyOutResponse: - if err = copyOut(ctx); err != nil { + if err = copyOut(ctx, binding); err != nil { return false, err } case packets.TypeReadyForQuery: var p packets.ReadyForQuery - err = fed.ToConcrete(&p, ctx.Packet) + err = fed.ToConcrete(&p, binding.Packet) if err != nil { return false, err } - ctx.Packet = &p - ctx.TxState = byte(p) - ctx.PeerWrite() + binding.Packet = &p + binding.TxState = byte(p) + binding.PeerWrite(ctx) return true, nil default: - return false, ctx.ErrUnexpectedPacket() + return false, binding.ErrUnexpectedPacket() } } } -func Sync(server, peer *fed.Conn) (err, peerErr error) { - ctx := context{ +func Sync(ctx context.Context, server, peer *fed.Conn) (err, peerErr error) { + binding := serverToPeerBinding{ Server: server, Peer: peer, Packet: &packets.Sync{}, } - _, err = sync(&ctx) - peerErr = ctx.PeerError + _, err = sync(ctx, &binding) + peerErr = binding.PeerError return } -func eqp(ctx *context) error { - if err := ctx.ServerWrite(); err != nil { +func eqp(ctx context.Context, binding *serverToPeerBinding) error { + if err := binding.ServerWrite(ctx); err != nil { return err } for { - if !ctx.PeerRead() { + if !binding.PeerRead(ctx) { for { - ctx.Packet = &packets.Sync{} - ok, err := sync(ctx) + binding.Packet = &packets.Sync{} + ok, err := sync(ctx, binding) if err != nil { return err } @@ -253,9 +255,9 @@ func eqp(ctx *context) error { } } - switch ctx.Packet.Type() { + switch binding.Packet.Type() { case packets.TypeSync: - ok, err := sync(ctx) + ok, err := sync(ctx, binding) if err != nil { return err } @@ -263,51 +265,51 @@ func eqp(ctx *context) error { return nil } case packets.TypeParse, packets.TypeBind, packets.TypeClose, packets.TypeDescribe, packets.TypeExecute, packets.TypeFlush: - if err := ctx.ServerWrite(); err != nil { + if err := binding.ServerWrite(ctx); err != nil { return err } default: - ctx.PeerFail(ctx.ErrUnexpectedPacket()) + binding.PeerFail(binding.ErrUnexpectedPacket()) } } } -func transaction(ctx *context) error { +func transaction(ctx context.Context, binding *serverToPeerBinding) error { for { - switch ctx.Packet.Type() { + switch binding.Packet.Type() { case packets.TypeQuery: - if err := query(ctx); err != nil { + if err := query(ctx, binding); err != nil { return err } case packets.TypeFunctionCall: - if err := functionCall(ctx); err != nil { + if err := functionCall(ctx, binding); err != nil { return err } case packets.TypeSync: // phony sync call, we can just reply with a fake ReadyForQuery(TxState) - rfq := packets.ReadyForQuery(ctx.TxState) - ctx.Packet = &rfq - ctx.PeerWrite() + rfq := packets.ReadyForQuery(binding.TxState) + binding.Packet = &rfq + binding.PeerWrite(ctx) case packets.TypeParse, packets.TypeBind, packets.TypeClose, packets.TypeDescribe, packets.TypeExecute, packets.TypeFlush: - if err := eqp(ctx); err != nil { + if err := eqp(ctx, binding); err != nil { return err } default: - ctx.PeerFail(ctx.ErrUnexpectedPacket()) + binding.PeerFail(binding.ErrUnexpectedPacket()) } - if ctx.TxState == 'I' { + if binding.TxState == 'I' { return nil } - if !ctx.PeerRead() { + if !binding.PeerRead(ctx) { // abort tx - err := queryString(ctx, "ABORT;") + err := queryString(ctx, binding, "ABORT;") if err != nil { return err } - if ctx.TxState != 'I' { + if binding.TxState != 'I' { return ErrExpectedIdle } return nil @@ -315,13 +317,13 @@ func transaction(ctx *context) error { } } -func Transaction(server, peer *fed.Conn, initialPacket fed.Packet) (err, peerError error) { - ctx := context{ +func Transaction(ctx context.Context, server, peer *fed.Conn, initialPacket fed.Packet) (err, peerError error) { + pgState := serverToPeerBinding{ Server: server, Peer: peer, Packet: initialPacket, } - err = transaction(&ctx) - peerError = ctx.PeerError + err = transaction(ctx, &pgState) + peerError = pgState.PeerError return } diff --git a/lib/bouncer/bouncers/v2/bouncer.go b/lib/bouncer/bouncers/v2/bouncer.go index ec8b4c10..b825cb18 100644 --- a/lib/bouncer/bouncers/v2/bouncer.go +++ b/lib/bouncer/bouncers/v2/bouncer.go @@ -1,11 +1,12 @@ package bouncers import ( + "context" "gfx.cafe/gfx/pggat/lib/bouncer/backends/v0" "gfx.cafe/gfx/pggat/lib/fed" ) -func Bounce(client, server *fed.Conn, initialPacket fed.Packet) (clientError error, serverError error) { - serverError, clientError = backends.Transaction(server, client, initialPacket) +func Bounce(ctx context.Context, client, server *fed.Conn, initialPacket fed.Packet) (clientError error, serverError error) { + serverError, clientError = backends.Transaction(ctx, server, client, initialPacket) return } diff --git a/lib/bouncer/frontends/v0/accept.go b/lib/bouncer/frontends/v0/accept.go index 00036e1a..ab55c407 100644 --- a/lib/bouncer/frontends/v0/accept.go +++ b/lib/bouncer/frontends/v0/accept.go @@ -1,6 +1,7 @@ package frontends import ( + "context" "crypto/tls" "strings" @@ -10,12 +11,23 @@ import ( "gfx.cafe/gfx/pggat/lib/util/strutil" ) +type acceptParams struct { + Conn *fed.Conn + Options acceptOptions +} + +type acceptResult struct { + CancelKey fed.BackendKey + IsCanceling bool +} + func startup0( - ctx *acceptContext, + ctx context.Context, params *acceptParams, + result *acceptResult, ) (cancelling bool, done bool, err error) { var packet fed.Packet - packet, err = ctx.Conn.ReadPacket(false) + packet, err = params.Conn.ReadPacket(ctx, false) if err != nil { return } @@ -31,29 +43,29 @@ func startup0( switch control := mode.Mode.(type) { case *packets.StartupPayloadControlPayloadCancel: // Cancel - params.CancelKey.ProcessID = control.ProcessID - params.CancelKey.SecretKey = control.SecretKey + result.CancelKey.ProcessID = control.ProcessID + result.CancelKey.SecretKey = control.SecretKey cancelling = true done = true return case *packets.StartupPayloadControlPayloadSSL: // ssl is not enabled - if ctx.Options.SSLConfig == nil { - err = ctx.Conn.WriteByte('N') + if params.Options.SSLConfig == nil { + err = params.Conn.WriteByte(ctx, 'N') return } // do ssl - if err = ctx.Conn.WriteByte('S'); err != nil { + if err = params.Conn.WriteByte(ctx, 'S'); err != nil { return } - if err = ctx.Conn.EnableSSL(ctx.Options.SSLConfig, false); err != nil { + if err = params.Conn.EnableSSL(ctx, params.Options.SSLConfig, false); err != nil { return } return case *packets.StartupPayloadControlPayloadGSSAPI: // GSSAPI is not supported yet - err = ctx.Conn.WriteByte('N') + err = params.Conn.WriteByte(ctx, 'N') return default: err = perror.New( @@ -69,9 +81,9 @@ func startup0( for _, parameter := range mode.Parameters { switch parameter.Key { case "user": - ctx.Conn.User = parameter.Value + params.Conn.User = parameter.Value case "database": - ctx.Conn.Database = parameter.Value + params.Conn.Database = parameter.Value case "options": fields := strings.Fields(parameter.Value) for i := 0; i < len(fields); i++ { @@ -91,10 +103,10 @@ func startup0( ikey := strutil.MakeCIString(key) - if ctx.Conn.InitialParameters == nil { - ctx.Conn.InitialParameters = make(map[strutil.CIString]string) + if params.Conn.InitialParameters == nil { + params.Conn.InitialParameters = make(map[strutil.CIString]string) } - ctx.Conn.InitialParameters[ikey] = value + params.Conn.InitialParameters[ikey] = value default: err = perror.New( perror.FATAL, @@ -118,10 +130,10 @@ func startup0( } else { ikey := strutil.MakeCIString(parameter.Key) - if ctx.Conn.InitialParameters == nil { - ctx.Conn.InitialParameters = make(map[strutil.CIString]string) + if params.Conn.InitialParameters == nil { + params.Conn.InitialParameters = make(map[strutil.CIString]string) } - ctx.Conn.InitialParameters[ikey] = parameter.Value + params.Conn.InitialParameters[ikey] = parameter.Value } } } @@ -132,13 +144,13 @@ func startup0( MinorProtocolVersion: 0, UnrecognizedProtocolOptions: unsupportedOptions, } - err = ctx.Conn.WritePacket(&uopts) + err = params.Conn.WritePacket(ctx, &uopts) if err != nil { return } } - if ctx.Conn.User == "" { + if params.Conn.User == "" { err = perror.New( perror.FATAL, perror.InvalidAuthorizationSpecification, @@ -146,8 +158,8 @@ func startup0( ) return } - if ctx.Conn.Database == "" { - ctx.Conn.Database = ctx.Conn.User + if params.Conn.Database == "" { + params.Conn.Database = params.Conn.User } done = true @@ -163,11 +175,12 @@ func startup0( } func accept0( - ctx *acceptContext, -) (params acceptParams, err error) { + ctx context.Context, + params *acceptParams, +) (result acceptResult, err error) { for { var done bool - params.IsCanceling, done, err = startup0(ctx, ¶ms) + result.IsCanceling, done, err = startup0(ctx, params, &result) if err != nil { return } @@ -179,18 +192,18 @@ func accept0( return } -func fail(client *fed.Conn, err error) { +func fail(ctx context.Context, client *fed.Conn, err error) { resp := perror.ToPacket(perror.Wrap(err)) - _ = client.WritePacket(resp) + _ = client.WritePacket(ctx, resp) } -func accept(ctx *acceptContext) (acceptParams, error) { - params, err := accept0(ctx) +func accept(ctx context.Context, params *acceptParams) (acceptResult, error) { + result, err := accept0(ctx, params) if err != nil { - fail(ctx.Conn, err) - return acceptParams{}, err + fail(ctx, params.Conn, err) + return acceptResult{}, err } - return params, nil + return result, nil } func Accept(conn *fed.Conn, tlsConfig *tls.Config) ( @@ -198,15 +211,16 @@ func Accept(conn *fed.Conn, tlsConfig *tls.Config) ( isCanceling bool, err error, ) { - ctx := acceptContext{ + params := acceptParams{ Conn: conn, Options: acceptOptions{ SSLConfig: tlsConfig, }, } - var params acceptParams - params, err = accept(&ctx) - cancelKey = params.CancelKey - isCanceling = params.IsCanceling + var result acceptResult + if result, err = accept(context.Background(), ¶ms); err == nil { + cancelKey = result.CancelKey + isCanceling = result.IsCanceling + } return } diff --git a/lib/bouncer/frontends/v0/authenticate.go b/lib/bouncer/frontends/v0/authenticate.go index 498dec39..709dedda 100644 --- a/lib/bouncer/frontends/v0/authenticate.go +++ b/lib/bouncer/frontends/v0/authenticate.go @@ -1,6 +1,7 @@ package frontends import ( + "context" "crypto/rand" "encoding/binary" "errors" @@ -13,10 +14,15 @@ import ( "gfx.cafe/gfx/pggat/lib/perror" ) -func authenticationSASLInitial(ctx *authenticateContext, creds auth.SASLServer) (tool auth.SASLVerifier, resp []byte, done bool, err error) { +type authParams struct { + Conn *fed.Conn + Options authOptions +} + +func authenticationSASLInitial(ctx context.Context, params *authParams, creds auth.SASLServer) (tool auth.SASLVerifier, resp []byte, done bool, err error) { // check which authentication method the client wants var packet fed.Packet - packet, err = ctx.Conn.ReadPacket(true) + packet, err = params.Conn.ReadPacket(ctx, true) if err != nil { return } @@ -43,9 +49,9 @@ func authenticationSASLInitial(ctx *authenticateContext, creds auth.SASLServer) return } -func authenticationSASLContinue(ctx *authenticateContext, tool auth.SASLVerifier) (resp []byte, done bool, err error) { +func authenticationSASLContinue(ctx context.Context, params *authParams, tool auth.SASLVerifier) (resp []byte, done bool, err error) { var packet fed.Packet - packet, err = ctx.Conn.ReadPacket(true) + packet, err = params.Conn.ReadPacket(ctx, true) if err != nil { return } @@ -67,7 +73,7 @@ func authenticationSASLContinue(ctx *authenticateContext, tool auth.SASLVerifier return } -func authenticationSASL(ctx *authenticateContext, creds auth.SASLServer) error { +func authenticationSASL(ctx context.Context, params *authParams, creds auth.SASLServer) error { var mode packets.AuthenticationPayloadSASL mechanisms := creds.SupportedSASLMechanisms() for _, mechanism := range mechanisms { @@ -79,12 +85,12 @@ func authenticationSASL(ctx *authenticateContext, creds auth.SASLServer) error { saslInitial := packets.Authentication{ Mode: &mode, } - err := ctx.Conn.WritePacket(&saslInitial) + err := params.Conn.WritePacket(ctx, &saslInitial) if err != nil { return err } - tool, resp, done, err := authenticationSASLInitial(ctx, creds) + tool, resp, done, err := authenticationSASLInitial(ctx, params, creds) if err != nil { return err } @@ -95,7 +101,7 @@ func authenticationSASL(ctx *authenticateContext, creds auth.SASLServer) error { final := packets.Authentication{ Mode: &m, } - err = ctx.Conn.WritePacket(&final) + err = params.Conn.WritePacket(ctx, &final) if err != nil { return err } @@ -105,13 +111,13 @@ func authenticationSASL(ctx *authenticateContext, creds auth.SASLServer) error { cont := packets.Authentication{ Mode: &m, } - err = ctx.Conn.WritePacket(&cont) + err = params.Conn.WritePacket(ctx, &cont) if err != nil { return err } } - resp, done, err = authenticationSASLContinue(ctx, tool) + resp, done, err = authenticationSASLContinue(ctx, params, tool) if err != nil { return err } @@ -120,7 +126,7 @@ func authenticationSASL(ctx *authenticateContext, creds auth.SASLServer) error { return nil } -func authenticationMD5(ctx *authenticateContext, creds auth.MD5Server) error { +func authenticationMD5(ctx context.Context, params *authParams, creds auth.MD5Server) error { var salt [4]byte _, err := rand.Read(salt[:]) if err != nil { @@ -130,13 +136,14 @@ func authenticationMD5(ctx *authenticateContext, creds auth.MD5Server) error { md5Initial := packets.Authentication{ Mode: &mode, } - err = ctx.Conn.WritePacket(&md5Initial) + + err = params.Conn.WritePacket(ctx, &md5Initial) if err != nil { return err } var packet fed.Packet - packet, err = ctx.Conn.ReadPacket(true) + packet, err = params.Conn.ReadPacket(ctx, true) if err != nil { return err } @@ -154,12 +161,12 @@ func authenticationMD5(ctx *authenticateContext, creds auth.MD5Server) error { return nil } -func authenticate(ctx *authenticateContext) (err error) { - if ctx.Options.Credentials != nil { - if credsSASL, ok := ctx.Options.Credentials.(auth.SASLServer); ok { - err = authenticationSASL(ctx, credsSASL) - } else if credsMD5, ok := ctx.Options.Credentials.(auth.MD5Server); ok { - err = authenticationMD5(ctx, credsMD5) +func authenticate(ctx context.Context, params *authParams) (err error) { + if params.Options.Credentials != nil { + if credsSASL, ok := params.Options.Credentials.(auth.SASLServer); ok { + err = authenticationSASL(ctx, params, credsSASL) + } else if credsMD5, ok := params.Options.Credentials.(auth.MD5Server); ok { + err = authenticationMD5(ctx, params, credsMD5) } else { err = perror.New( perror.FATAL, @@ -176,10 +183,10 @@ func authenticate(ctx *authenticateContext) (err error) { authOk := packets.Authentication{ Mode: &packets.AuthenticationPayloadOk{}, } - if err = ctx.Conn.WritePacket(&authOk); err != nil { + if err = params.Conn.WritePacket(ctx, &authOk); err != nil { return } - ctx.Conn.Authenticated = true + params.Conn.Authenticated = true // send backend key data var processID [4]byte @@ -190,35 +197,35 @@ func authenticate(ctx *authenticateContext) (err error) { if _, err = rand.Reader.Read(backendKey[:]); err != nil { return } - ctx.Conn.BackendKey = fed.BackendKey{ + params.Conn.BackendKey = fed.BackendKey{ ProcessID: int32(binary.BigEndian.Uint32(processID[:])), SecretKey: int32(binary.BigEndian.Uint32(backendKey[:])), } keyData := packets.BackendKeyData{ - ProcessID: ctx.Conn.BackendKey.ProcessID, - SecretKey: ctx.Conn.BackendKey.SecretKey, + ProcessID: params.Conn.BackendKey.ProcessID, + SecretKey: params.Conn.BackendKey.SecretKey, } - if err = ctx.Conn.WritePacket(&keyData); err != nil { + if err = params.Conn.WritePacket(ctx, &keyData); err != nil { return } return } -func Authenticate(conn *fed.Conn, creds auth.Credentials) (err error) { +func Authenticate(ctx context.Context, conn *fed.Conn, creds auth.Credentials) (err error) { if conn.Authenticated { // already authenticated return } - ctx := authenticateContext{ + params := authParams{ Conn: conn, - Options: authenticateOptions{ + Options: authOptions{ Credentials: creds, }, } - err = authenticate(&ctx) + err = authenticate(ctx, ¶ms) if err != nil { // sleep after incorrect password time.Sleep(250 * time.Millisecond) diff --git a/lib/bouncer/frontends/v0/context.go b/lib/bouncer/frontends/v0/context.go deleted file mode 100644 index 993073ff..00000000 --- a/lib/bouncer/frontends/v0/context.go +++ /dev/null @@ -1,13 +0,0 @@ -package frontends - -import "gfx.cafe/gfx/pggat/lib/fed" - -type acceptContext struct { - Conn *fed.Conn - Options acceptOptions -} - -type authenticateContext struct { - Conn *fed.Conn - Options authenticateOptions -} diff --git a/lib/bouncer/frontends/v0/options.go b/lib/bouncer/frontends/v0/options.go index 304f7b82..af243b93 100644 --- a/lib/bouncer/frontends/v0/options.go +++ b/lib/bouncer/frontends/v0/options.go @@ -10,6 +10,6 @@ type acceptOptions struct { SSLConfig *tls.Config } -type authenticateOptions struct { +type authOptions struct { Credentials auth.Credentials } diff --git a/lib/bouncer/frontends/v0/params.go b/lib/bouncer/frontends/v0/params.go deleted file mode 100644 index 4f28f834..00000000 --- a/lib/bouncer/frontends/v0/params.go +++ /dev/null @@ -1,8 +0,0 @@ -package frontends - -import "gfx.cafe/gfx/pggat/lib/fed" - -type acceptParams struct { - CancelKey fed.BackendKey - IsCanceling bool -} diff --git a/lib/fed/codecs/netconncodec/codec.go b/lib/fed/codecs/netconncodec/codec.go index 414943ca..3929f1ad 100644 --- a/lib/fed/codecs/netconncodec/codec.go +++ b/lib/fed/codecs/netconncodec/codec.go @@ -1,6 +1,7 @@ package netconncodec import ( + "context" "crypto/tls" "errors" "fmt" @@ -32,7 +33,7 @@ func NewCodec(rw net.Conn) fed.PacketCodec { return c } -func (c *Codec) ReadPacket(typed bool) (fed.Packet, error) { +func (c *Codec) ReadPacket(ctx context.Context, typed bool) (fed.Packet, error) { if err := c.decoder.Next(typed); err != nil { return nil, err } @@ -41,7 +42,7 @@ func (c *Codec) ReadPacket(typed bool) (fed.Packet, error) { }, nil } -func (c *Codec) WritePacket(packet fed.Packet) error { +func (c *Codec) WritePacket(ctx context.Context, packet fed.Packet) error { err := c.encoder.Next(packet.Type(), packet.Length()) if err != nil { return err @@ -49,23 +50,23 @@ func (c *Codec) WritePacket(packet fed.Packet) error { return packet.WriteTo(&c.encoder) } -func (c *Codec) WriteByte(b byte) error { +func (c *Codec) WriteByte(ctx context.Context, b byte) error { return c.encoder.WriteByte(b) } -func (c *Codec) ReadByte() (byte, error) { - if err := c.Flush(); err != nil { +func (c *Codec) ReadByte(ctx context.Context) (byte, error) { + if err := c.Flush(ctx); err != nil { return 0, err } return c.decoder.ReadByte() } -func (c *Codec) Flush() error { +func (c *Codec) Flush(ctx context.Context) error { return c.encoder.Flush() } -func (c *Codec) Close() error { +func (c *Codec) Close(ctx context.Context) error { if err := c.encoder.Flush(); err != nil { return err } @@ -80,7 +81,7 @@ func (c *Codec) SSL() bool { return c.ssl } -func (c *Codec) EnableSSL(config *tls.Config, isClient bool) error { +func (c *Codec) EnableSSL(ctx context.Context, config *tls.Config, isClient bool) error { c.mu.Lock() defer c.mu.Unlock() if c.ssl { @@ -89,7 +90,7 @@ func (c *Codec) EnableSSL(config *tls.Config, isClient bool) error { c.ssl = true // Flush buffers - if err := c.Flush(); err != nil { + if err := c.Flush(ctx); err != nil { return err } if c.decoder.Buffered() > 0 { diff --git a/lib/fed/conn.go b/lib/fed/conn.go index a32336e7..d9e0ad9e 100644 --- a/lib/fed/conn.go +++ b/lib/fed/conn.go @@ -1,6 +1,7 @@ package fed import ( + "context" "crypto/tls" "io" "net" @@ -39,16 +40,16 @@ func NewConn(codec PacketCodec) *Conn { return c } -func (T *Conn) Flush() error { - return T.codec.Flush() +func (T *Conn) Flush(ctx context.Context) error { + return T.codec.Flush(ctx) } -func (T *Conn) readPacket(typed bool) (Packet, error) { - return T.codec.ReadPacket(typed) +func (T *Conn) readPacket(ctx context.Context, typed bool) (Packet, error) { + return T.codec.ReadPacket(ctx, typed) } -func (T *Conn) ReadPacket(typed bool) (Packet, error) { - if err := T.Flush(); err != nil { +func (T *Conn) ReadPacket(ctx context.Context, typed bool) (Packet, error) { + if err := T.Flush(ctx); err != nil { return nil, err } @@ -57,7 +58,7 @@ func (T *Conn) ReadPacket(typed bool) (Packet, error) { for i := 0; i < len(T.Middleware); i++ { middleware := T.Middleware[i] for { - packet, err := middleware.PreRead(typed) + packet, err := middleware.PreRead(ctx, typed) if err != nil { return nil, err } @@ -67,7 +68,7 @@ func (T *Conn) ReadPacket(typed bool) (Packet, error) { } for j := i; j < len(T.Middleware); j++ { - packet, err = T.Middleware[j].ReadPacket(packet) + packet, err = T.Middleware[j].ReadPacket(ctx, packet) if err != nil { return nil, err } @@ -82,12 +83,12 @@ func (T *Conn) ReadPacket(typed bool) (Packet, error) { } } - packet, err := T.readPacket(typed) + packet, err := T.readPacket(ctx, typed) if err != nil { return nil, err } for _, middleware := range T.Middleware { - packet, err = middleware.ReadPacket(packet) + packet, err = middleware.ReadPacket(ctx, packet) if err != nil { return nil, err } @@ -101,16 +102,16 @@ func (T *Conn) ReadPacket(typed bool) (Packet, error) { } } -func (T *Conn) writePacket(packet Packet) error { - return T.codec.WritePacket(packet) +func (T *Conn) writePacket(ctx context.Context, packet Packet) error { + return T.codec.WritePacket(ctx, packet) } -func (T *Conn) WritePacket(packet Packet) error { +func (T *Conn) WritePacket(ctx context.Context, packet Packet) error { for i := len(T.Middleware) - 1; i >= 0; i-- { middleware := T.Middleware[i] var err error - packet, err = middleware.WritePacket(packet) + packet, err = middleware.WritePacket(ctx, packet) if err != nil { return err } @@ -119,7 +120,7 @@ func (T *Conn) WritePacket(packet Packet) error { } } if packet != nil { - if err := T.writePacket(packet); err != nil { + if err := T.writePacket(ctx, packet); err != nil { return err } } @@ -130,7 +131,7 @@ func (T *Conn) WritePacket(packet Packet) error { for { var err error - packet, err = middleware.PostWrite() + packet, err = middleware.PostWrite(ctx) if err != nil { return err } @@ -140,7 +141,7 @@ func (T *Conn) WritePacket(packet Packet) error { } for j := i; j >= 0; j-- { - packet, err = T.Middleware[j].WritePacket(packet) + packet, err = T.Middleware[j].WritePacket(ctx, packet) if err != nil { return err } @@ -150,7 +151,7 @@ func (T *Conn) WritePacket(packet Packet) error { } if packet != nil { - if err = T.writePacket(packet); err != nil { + if err = T.writePacket(ctx, packet); err != nil { return err } } @@ -160,8 +161,8 @@ func (T *Conn) WritePacket(packet Packet) error { return nil } -func (T *Conn) WriteByte(b byte) error { - return T.codec.WriteByte(b) +func (T *Conn) WriteByte(ctx context.Context, b byte) error { + return T.codec.WriteByte(ctx, b) } func (T *Conn) LocalAddr() net.Addr { @@ -169,14 +170,14 @@ func (T *Conn) LocalAddr() net.Addr { } -func (T *Conn) ReadByte() (byte, error) { - return T.codec.ReadByte() +func (T *Conn) ReadByte(ctx context.Context) (byte, error) { + return T.codec.ReadByte(ctx) } -func (T *Conn) EnableSSL(config *tls.Config, isClient bool) error { - return T.codec.EnableSSL(config, isClient) +func (T *Conn) EnableSSL(ctx context.Context, config *tls.Config, isClient bool) error { + return T.codec.EnableSSL(ctx, config, isClient) } -func (T *Conn) Close() error { - return T.codec.Close() +func (T *Conn) Close(ctx context.Context) error { + return T.codec.Close(ctx) } diff --git a/lib/fed/interface.go b/lib/fed/interface.go deleted file mode 100644 index 636d62a1..00000000 --- a/lib/fed/interface.go +++ /dev/null @@ -1,20 +0,0 @@ -package fed - -import ( - "crypto/tls" - "net" -) - -type PacketCodec interface { - ReadPacket(typed bool) (Packet, error) - WritePacket(packet Packet) error - WriteByte(b byte) error - ReadByte() (byte, error) - - LocalAddr() net.Addr - Flush() error - Close() error - - SSL() bool - EnableSSL(config *tls.Config, isClient bool) error -} diff --git a/lib/fed/middleware.go b/lib/fed/middleware.go index 92e31bd8..605d00ed 100644 --- a/lib/fed/middleware.go +++ b/lib/fed/middleware.go @@ -1,12 +1,14 @@ package fed +import "context" + // Middleware intercepts packets and possibly changes them. Return a 0 length packet to cancel. type Middleware interface { - PreRead(typed bool) (Packet, error) - ReadPacket(packet Packet) (Packet, error) + PreRead(ctx context.Context,typed bool) (Packet, error) + ReadPacket(ctx context.Context,packet Packet) (Packet, error) - WritePacket(packet Packet) (Packet, error) - PostWrite() (Packet, error) + WritePacket(ctx context.Context,packet Packet) (Packet, error) + PostWrite(ctx context.Context,) (Packet, error) } func LookupMiddleware[T Middleware](conn *Conn) (T, bool) { diff --git a/lib/fed/middlewares/eqp/client.go b/lib/fed/middlewares/eqp/client.go index 652d7073..2af9b781 100644 --- a/lib/fed/middlewares/eqp/client.go +++ b/lib/fed/middlewares/eqp/client.go @@ -1,6 +1,7 @@ package eqp import ( + "context" "gfx.cafe/gfx/pggat/lib/fed" ) @@ -12,23 +13,23 @@ func NewClient() *Client { return new(Client) } -func (T *Client) PreRead(_ bool) (fed.Packet, error) { +func (T *Client) PreRead(ctx context.Context, _ bool) (fed.Packet, error) { return nil, nil } -func (T *Client) ReadPacket(packet fed.Packet) (fed.Packet, error) { +func (T *Client) ReadPacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) { return T.state.C2S(packet) } -func (T *Client) WritePacket(packet fed.Packet) (fed.Packet, error) { +func (T *Client) WritePacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) { return T.state.S2C(packet) } -func (T *Client) PostWrite() (fed.Packet, error) { +func (T *Client) PostWrite(ctx context.Context) (fed.Packet, error) { return nil, nil } -func (T *Client) Set(other *Client) { +func (T *Client) Set(ctx context.Context, other *Client) { T.state.Set(&other.state) } diff --git a/lib/fed/middlewares/eqp/server.go b/lib/fed/middlewares/eqp/server.go index f847a9fa..c24e472a 100644 --- a/lib/fed/middlewares/eqp/server.go +++ b/lib/fed/middlewares/eqp/server.go @@ -1,6 +1,7 @@ package eqp import ( + "context" "gfx.cafe/gfx/pggat/lib/fed" ) @@ -12,19 +13,19 @@ func NewServer() *Server { return new(Server) } -func (T *Server) PreRead(_ bool) (fed.Packet, error) { +func (T *Server) PreRead(ctx context.Context, _ bool) (fed.Packet, error) { return nil, nil } -func (T *Server) ReadPacket(packet fed.Packet) (fed.Packet, error) { +func (T *Server) ReadPacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) { return T.state.S2C(packet) } -func (T *Server) WritePacket(packet fed.Packet) (fed.Packet, error) { +func (T *Server) WritePacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) { return T.state.C2S(packet) } -func (T *Server) PostWrite() (fed.Packet, error) { +func (T *Server) PostWrite(ctx context.Context) (fed.Packet, error) { return nil, nil } diff --git a/lib/fed/middlewares/eqp/sync.go b/lib/fed/middlewares/eqp/sync.go index 2cc1be7d..fbbe8f57 100644 --- a/lib/fed/middlewares/eqp/sync.go +++ b/lib/fed/middlewares/eqp/sync.go @@ -1,6 +1,7 @@ package eqp import ( + "context" "gfx.cafe/gfx/pggat/lib/bouncer/backends/v0" "gfx.cafe/gfx/pggat/lib/fed" packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" @@ -19,7 +20,7 @@ func preparedStatementsEqual(a, b *packets.Parse) bool { return true } -func SyncMiddleware(c *Client, server *fed.Conn) error { +func SyncMiddleware(ctx context.Context,c *Client, server *fed.Conn) error { s, ok := fed.LookupMiddleware[*Server](server) if !ok { panic("middleware not found") @@ -35,7 +36,7 @@ func SyncMiddleware(c *Client, server *fed.Conn) error { Which: 'P', Name: name, } - if err := server.WritePacket(&p); err != nil { + if err := server.WritePacket(ctx,&p); err != nil { return err } @@ -59,7 +60,7 @@ func SyncMiddleware(c *Client, server *fed.Conn) error { Which: 'S', Name: name, } - if err := server.WritePacket(&p); err != nil { + if err := server.WritePacket(ctx,&p); err != nil { return err } @@ -74,7 +75,7 @@ func SyncMiddleware(c *Client, server *fed.Conn) error { } } - if err := server.WritePacket(preparedStatement); err != nil { + if err := server.WritePacket(ctx,preparedStatement); err != nil { return err } @@ -83,7 +84,7 @@ func SyncMiddleware(c *Client, server *fed.Conn) error { // bind all portals for _, portal := range c.state.portals { - if err := server.WritePacket(portal); err != nil { + if err := server.WritePacket(ctx,portal); err != nil { return err } @@ -99,11 +100,11 @@ func SyncMiddleware(c *Client, server *fed.Conn) error { return nil } -func Sync(client, server *fed.Conn) error { +func Sync(ctx context.Context,client, server *fed.Conn) error { c, ok := fed.LookupMiddleware[*Client](client) if !ok { panic("middleware not found") } - return SyncMiddleware(c, server) + return SyncMiddleware(ctx,c, server) } diff --git a/lib/fed/middlewares/ps/client.go b/lib/fed/middlewares/ps/client.go index 94b5014f..f0ac38b5 100644 --- a/lib/fed/middlewares/ps/client.go +++ b/lib/fed/middlewares/ps/client.go @@ -1,6 +1,7 @@ package ps import ( + "context" "gfx.cafe/gfx/pggat/lib/fed" packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" "gfx.cafe/gfx/pggat/lib/util/maps" @@ -18,15 +19,15 @@ func NewClient(parameters map[strutil.CIString]string) *Client { } } -func (T *Client) PreRead(_ bool) (fed.Packet, error) { +func (T *Client) PreRead(ctx context.Context, _ bool) (fed.Packet, error) { return nil, nil } -func (T *Client) ReadPacket(packet fed.Packet) (fed.Packet, error) { +func (T *Client) ReadPacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) { return packet, nil } -func (T *Client) WritePacket(packet fed.Packet) (fed.Packet, error) { +func (T *Client) WritePacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) { switch packet.Type() { case packets.TypeParameterStatus: var p packets.ParameterStatus @@ -49,11 +50,11 @@ func (T *Client) WritePacket(packet fed.Packet) (fed.Packet, error) { } } -func (T *Client) PostWrite() (fed.Packet, error) { +func (T *Client) PostWrite(ctx context.Context) (fed.Packet, error) { return nil, nil } -func (T *Client) Set(other *Client) { +func (T *Client) Set(ctx context.Context, other *Client) { T.synced = other.synced maps.Clear(T.parameters) diff --git a/lib/fed/middlewares/ps/server.go b/lib/fed/middlewares/ps/server.go index 137b34df..7ce9c15d 100644 --- a/lib/fed/middlewares/ps/server.go +++ b/lib/fed/middlewares/ps/server.go @@ -1,6 +1,7 @@ package ps import ( + "context" "gfx.cafe/gfx/pggat/lib/fed" packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" "gfx.cafe/gfx/pggat/lib/util/strutil" @@ -16,11 +17,11 @@ func NewServer(parameters map[strutil.CIString]string) *Server { } } -func (T *Server) PreRead(_ bool) (fed.Packet, error) { +func (T *Server) PreRead(ctx context.Context, _ bool) (fed.Packet, error) { return nil, nil } -func (T *Server) ReadPacket(packet fed.Packet) (fed.Packet, error) { +func (T *Server) ReadPacket(ctx context.Context,packet fed.Packet) (fed.Packet, error) { switch packet.Type() { case packets.TypeParameterStatus: var p packets.ParameterStatus @@ -39,11 +40,11 @@ func (T *Server) ReadPacket(packet fed.Packet) (fed.Packet, error) { } } -func (T *Server) WritePacket(packet fed.Packet) (fed.Packet, error) { +func (T *Server) WritePacket(ctx context.Context,packet fed.Packet) (fed.Packet, error) { return packet, nil } -func (T *Server) PostWrite() (fed.Packet, error) { +func (T *Server) PostWrite(ctx context.Context,) (fed.Packet, error) { return nil, nil } diff --git a/lib/fed/middlewares/ps/sync.go b/lib/fed/middlewares/ps/sync.go index 152e86b0..0655694a 100644 --- a/lib/fed/middlewares/ps/sync.go +++ b/lib/fed/middlewares/ps/sync.go @@ -1,6 +1,7 @@ package ps import ( + "context" "gfx.cafe/gfx/pggat/lib/bouncer/backends/v0" "gfx.cafe/gfx/pggat/lib/fed" packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" @@ -8,7 +9,7 @@ import ( "gfx.cafe/gfx/pggat/lib/util/strutil" ) -func sync(tracking []strutil.CIString, client *fed.Conn, c *Client, server *fed.Conn, s *Server, name strutil.CIString) (clientErr, serverErr error) { +func sync(ctx context.Context, tracking []strutil.CIString, client *fed.Conn, c *Client, server *fed.Conn, s *Server, name strutil.CIString) (clientErr, serverErr error) { value, hasValue := c.parameters[name] expected, hasExpected := s.parameters[name] @@ -18,7 +19,7 @@ func sync(tracking []strutil.CIString, client *fed.Conn, c *Client, server *fed. Key: name.String(), Value: expected, } - clientErr = client.WritePacket(&ps) + clientErr = client.WritePacket(ctx, &ps) } return } @@ -26,7 +27,7 @@ func sync(tracking []strutil.CIString, client *fed.Conn, c *Client, server *fed. var doSet bool if hasValue && slices.Contains(tracking, name) { - if serverErr, _ = backends.SetParameter(server, nil, name, value); serverErr != nil { + if serverErr, _ = backends.SetParameter(ctx, server, nil, name, value); serverErr != nil { return } if s.parameters == nil { @@ -47,7 +48,7 @@ func sync(tracking []strutil.CIString, client *fed.Conn, c *Client, server *fed. Key: name.String(), Value: expected, } - if clientErr = client.WritePacket(&ps); clientErr != nil { + if clientErr = client.WritePacket(ctx, &ps); clientErr != nil { return } } @@ -55,14 +56,14 @@ func sync(tracking []strutil.CIString, client *fed.Conn, c *Client, server *fed. return } -func SyncMiddleware(tracking []strutil.CIString, c *Client, server *fed.Conn) error { +func SyncMiddleware(ctx context.Context, tracking []strutil.CIString, c *Client, server *fed.Conn) error { s, ok := fed.LookupMiddleware[*Server](server) if !ok { panic("middleware not found") } for name := range c.parameters { - if _, err := sync(tracking, nil, c, server, s, name); err != nil { + if _, err := sync(ctx, tracking, nil, c, server, s, name); err != nil { return err } } @@ -71,7 +72,7 @@ func SyncMiddleware(tracking []strutil.CIString, c *Client, server *fed.Conn) er if _, ok = c.parameters[name]; ok { continue } - if _, err := sync(tracking, nil, c, server, s, name); err != nil { + if _, err := sync(ctx, tracking, nil, c, server, s, name); err != nil { return err } } @@ -79,7 +80,7 @@ func SyncMiddleware(tracking []strutil.CIString, c *Client, server *fed.Conn) er return nil } -func Sync(tracking []strutil.CIString, client, server *fed.Conn) (clientErr, serverErr error) { +func Sync(ctx context.Context, tracking []strutil.CIString, client, server *fed.Conn) (clientErr, serverErr error) { c, ok := fed.LookupMiddleware[*Client](client) if !ok { panic("middleware not found") @@ -90,7 +91,7 @@ func Sync(tracking []strutil.CIString, client, server *fed.Conn) (clientErr, ser } for name := range c.parameters { - if clientErr, serverErr = sync(tracking, client, c, server, s, name); clientErr != nil || serverErr != nil { + if clientErr, serverErr = sync(ctx, tracking, client, c, server, s, name); clientErr != nil || serverErr != nil { return } } @@ -99,7 +100,7 @@ func Sync(tracking []strutil.CIString, client, server *fed.Conn) (clientErr, ser if _, ok = c.parameters[name]; ok { continue } - if clientErr, serverErr = sync(tracking, client, c, server, s, name); clientErr != nil || serverErr != nil { + if clientErr, serverErr = sync(ctx, tracking, client, c, server, s, name); clientErr != nil || serverErr != nil { return } } diff --git a/lib/fed/middlewares/unterminate/unterminate.go b/lib/fed/middlewares/unterminate/unterminate.go index 073dbe10..6218b110 100644 --- a/lib/fed/middlewares/unterminate/unterminate.go +++ b/lib/fed/middlewares/unterminate/unterminate.go @@ -1,6 +1,7 @@ package unterminate import ( + "context" "io" "gfx.cafe/gfx/pggat/lib/fed" @@ -13,22 +14,22 @@ var Unterminate = unterm{} type unterm struct{} -func (unterm) PreRead(_ bool) (fed.Packet, error) { +func (unterm) PreRead(_ context.Context, _ bool) (fed.Packet, error) { return nil, nil } -func (unterm) ReadPacket(packet fed.Packet) (fed.Packet, error) { +func (unterm) ReadPacket(_ context.Context, packet fed.Packet) (fed.Packet, error) { if packet.Type() == packets.TypeTerminate { return packet, io.EOF } return packet, nil } -func (unterm) WritePacket(packet fed.Packet) (fed.Packet, error) { +func (unterm) WritePacket(_ context.Context, packet fed.Packet) (fed.Packet, error) { return packet, nil } -func (unterm) PostWrite() (fed.Packet, error) { +func (unterm) PostWrite(_ context.Context) (fed.Packet, error) { return nil, nil } diff --git a/lib/fed/packetCodec.go b/lib/fed/packetCodec.go new file mode 100644 index 00000000..03c0f8d7 --- /dev/null +++ b/lib/fed/packetCodec.go @@ -0,0 +1,21 @@ +package fed + +import ( + "context" + "crypto/tls" + "net" +) + +type PacketCodec interface { + ReadPacket(ctx context.Context,typed bool) (Packet, error) + WritePacket(ctx context.Context,packet Packet) error + WriteByte(ctx context.Context,b byte) error + ReadByte(ctx context.Context,) (byte, error) + + LocalAddr() net.Addr + Flush(ctx context.Context,) error + Close(ctx context.Context,) error + + SSL() bool + EnableSSL(ctx context.Context,config *tls.Config, isClient bool) error +} diff --git a/lib/gat/handler.go b/lib/gat/handler.go index b0d2d339..f6519e35 100644 --- a/lib/gat/handler.go +++ b/lib/gat/handler.go @@ -1,6 +1,7 @@ package gat import ( + "context" "gfx.cafe/gfx/pggat/lib/fed" "gfx.cafe/gfx/pggat/lib/gat/metrics" ) @@ -15,7 +16,7 @@ type Handler interface { type CancellableHandler interface { Handler - Cancel(key fed.BackendKey) + Cancel(ctx context.Context, key fed.BackendKey) } type MetricsHandler interface { diff --git a/lib/gat/handlers/pool/critics/latency/critic.go b/lib/gat/handlers/pool/critics/latency/critic.go index abeda6c5..67190627 100644 --- a/lib/gat/handlers/pool/critics/latency/critic.go +++ b/lib/gat/handlers/pool/critics/latency/critic.go @@ -1,6 +1,7 @@ package latency import ( + "context" "time" "github.com/caddyserver/caddy/v2" @@ -28,9 +29,9 @@ func (T *Critic) CaddyModule() caddy.ModuleInfo { } } -func (T *Critic) Taste(conn *fed.Conn) (int, time.Duration, error) { +func (T *Critic) Taste(ctx context.Context, conn *fed.Conn) (int, time.Duration, error) { start := time.Now() - err, _ := backends.QueryString(conn, nil, "select 0") + err, _ := backends.QueryString(ctx, conn, nil, "select 0") if err != nil { return 0, time.Duration(T.Validity), err } diff --git a/lib/gat/handlers/pool/dialer.go b/lib/gat/handlers/pool/dialer.go index 1da5d9fe..107154e2 100644 --- a/lib/gat/handlers/pool/dialer.go +++ b/lib/gat/handlers/pool/dialer.go @@ -1,6 +1,7 @@ package pool import ( + "context" "crypto/tls" "encoding/json" "fmt" @@ -70,6 +71,7 @@ func (T *Dialer) Dial() (*fed.Conn, error) { conn.User = T.Username conn.Database = T.Database err = backends.Accept( + context.Background(), conn, T.SSLMode, T.SSLConfig, @@ -85,21 +87,21 @@ func (T *Dialer) Dial() (*fed.Conn, error) { return conn, nil } -func (T *Dialer) Cancel(key fed.BackendKey) { +func (T *Dialer) Cancel(ctx context.Context, key fed.BackendKey) { c, err := T.dial() if err != nil { return } conn := fed.NewConn(netconncodec.NewCodec(c)) defer func() { - _ = conn.Close() + _ = conn.Close(ctx) }() - if err = backends.Cancel(conn, key); err != nil { + if err = backends.Cancel(ctx, conn, key); err != nil { return } // wait for server to close the connection, this means that the server received it ok - _, _ = conn.ReadPacket(true) + _, _ = conn.ReadPacket(ctx, true) } var _ caddy.Provisioner = (*gat.Listener)(nil) diff --git a/lib/gat/handlers/pool/penalty.go b/lib/gat/handlers/pool/penalty.go index 7b7f9243..9d7b66ad 100644 --- a/lib/gat/handlers/pool/penalty.go +++ b/lib/gat/handlers/pool/penalty.go @@ -1,6 +1,7 @@ package pool import ( + "context" "time" "gfx.cafe/gfx/pggat/lib/fed" @@ -8,5 +9,5 @@ import ( type Critic interface { // Taste calculates how much conn should be penalized. Lower is better - Taste(conn *fed.Conn) (score int, validity time.Duration, err error) + Taste(ctx context.Context, conn *fed.Conn) (score int, validity time.Duration, err error) } diff --git a/lib/gat/handlers/pool/pool.go b/lib/gat/handlers/pool/pool.go index e6792eba..5cf3143f 100644 --- a/lib/gat/handlers/pool/pool.go +++ b/lib/gat/handlers/pool/pool.go @@ -1,6 +1,7 @@ package pool import ( + "context" "gfx.cafe/gfx/pggat/lib/fed" "gfx.cafe/gfx/pggat/lib/gat/metrics" ) @@ -12,11 +13,11 @@ type Pool interface { // RemoveRecipe will remove a recipe and disconnect all servers created by that recipe. RemoveRecipe(name string) - Serve(conn *fed.Conn) error + Serve(ctx context.Context, conn *fed.Conn) error - Cancel(key fed.BackendKey) + Cancel(ctx context.Context, key fed.BackendKey) ReadMetrics(m *metrics.Pool) - Close() + Close(ctx context.Context) } type ReplicaPool interface { diff --git a/lib/gat/handlers/pool/pools/basic/pool.go b/lib/gat/handlers/pool/pools/basic/pool.go index bfbb6c10..91c92ced 100644 --- a/lib/gat/handlers/pool/pools/basic/pool.go +++ b/lib/gat/handlers/pool/pools/basic/pool.go @@ -1,6 +1,7 @@ package basic import ( + "context" "fmt" "sync" @@ -44,7 +45,7 @@ func (T *Pool) RemoveRecipe(name string) { T.servers.RemoveRecipe(name) } -func (T *Pool) SyncInitialParameters(client *Client, server *spool.Server) (err, serverErr error) { +func (T *Pool) SyncInitialParameters(ctx context.Context, client *Client, server *spool.Server) (err, serverErr error) { clientParams := client.Conn.InitialParameters serverParams := server.Conn.InitialParameters @@ -55,7 +56,7 @@ func (T *Pool) SyncInitialParameters(client *Client, server *spool.Server) (err, Key: key.String(), Value: serverParams[key], } - err = client.Conn.WritePacket(&p) + err = client.Conn.WritePacket(ctx, &p) if err != nil { return } @@ -72,7 +73,7 @@ func (T *Pool) SyncInitialParameters(client *Client, server *spool.Server) (err, Key: key.String(), Value: value, } - err = client.Conn.WritePacket(&p) + err = client.Conn.WritePacket(ctx, &p) if err != nil { return } @@ -81,7 +82,7 @@ func (T *Pool) SyncInitialParameters(client *Client, server *spool.Server) (err, continue } - serverErr, _ = backends.SetParameter(server.Conn, nil, key, value) + serverErr, _ = backends.SetParameter(ctx, server.Conn, nil, key, value) if serverErr != nil { return } @@ -99,7 +100,7 @@ func (T *Pool) SyncInitialParameters(client *Client, server *spool.Server) (err, Key: key.String(), Value: value, } - err = client.Conn.WritePacket(&p) + err = client.Conn.WritePacket(ctx, &p) if err != nil { return } @@ -108,16 +109,16 @@ func (T *Pool) SyncInitialParameters(client *Client, server *spool.Server) (err, return } -func (T *Pool) Pair(client *Client, server *spool.Server) (err, serverErr error) { +func (T *Pool) Pair(ctx context.Context, client *Client, server *spool.Server) (err, serverErr error) { if T.config.ParameterStatusSync != ParameterStatusSyncNone || T.config.ExtendedQuerySync { client.SetState(metrics.ConnStatePairing, server) server.SetState(metrics.ConnStatePairing, client.ID) switch T.config.ParameterStatusSync { case ParameterStatusSyncDynamic: - err, serverErr = ps.Sync(T.config.TrackedParameters, client.Conn, server.Conn) + err, serverErr = ps.Sync(ctx, T.config.TrackedParameters, client.Conn, server.Conn) case ParameterStatusSyncInitial: - err, serverErr = T.SyncInitialParameters(client, server) + err, serverErr = T.SyncInitialParameters(ctx, client, server) } if err != nil || serverErr != nil { @@ -125,7 +126,7 @@ func (T *Pool) Pair(client *Client, server *spool.Server) (err, serverErr error) } if T.config.ExtendedQuerySync { - serverErr = eqp.Sync(client.Conn, server.Conn) + serverErr = eqp.Sync(ctx, client.Conn, server.Conn) } if serverErr != nil { @@ -155,7 +156,7 @@ func (T *Pool) removeClient(client *Client) { delete(T.clients, client.Conn.BackendKey) } -func (T *Pool) Serve(conn *fed.Conn) error { +func (T *Pool) Serve(ctx context.Context, conn *fed.Conn) error { if T.config.ParameterStatusSync == ParameterStatusSyncDynamic { conn.Middleware = append( conn.Middleware, @@ -199,7 +200,7 @@ func (T *Pool) Serve(conn *fed.Conn) error { return pool.ErrFailedToAcquirePeer } - err, serverErr = T.Pair(client, server) + err, serverErr = T.Pair(ctx, client, server) if serverErr != nil { return serverErr } @@ -224,7 +225,7 @@ func (T *Pool) Serve(conn *fed.Conn) error { } var packet fed.Packet - packet, err = client.Conn.ReadPacket(true) + packet, err = client.Conn.ReadPacket(ctx, true) if err != nil { return err } @@ -237,10 +238,10 @@ func (T *Pool) Serve(conn *fed.Conn) error { return pool.ErrFailedToAcquirePeer } - err, serverErr = T.Pair(client, server) + err, serverErr = T.Pair(ctx, client, server) } if err == nil && serverErr == nil { - err, serverErr = bouncers.Bounce(client.Conn, server.Conn, packet) + err, serverErr = bouncers.Bounce(ctx, client.Conn, server.Conn, packet) } if serverErr != nil { @@ -256,7 +257,7 @@ func (T *Pool) Serve(conn *fed.Conn) error { } } -func (T *Pool) Cancel(key fed.BackendKey) { +func (T *Pool) Cancel(ctx context.Context, key fed.BackendKey) { peer := func() *spool.Server { T.mu.RLock() defer T.mu.RUnlock() diff --git a/lib/gat/handlers/pool/pools/hybrid/middleware.go b/lib/gat/handlers/pool/pools/hybrid/middleware.go index 230ab054..ac6ff340 100644 --- a/lib/gat/handlers/pool/pools/hybrid/middleware.go +++ b/lib/gat/handlers/pool/pools/hybrid/middleware.go @@ -1,6 +1,7 @@ package hybrid import ( + "context" "gfx.cafe/gfx/pggat/lib/fed" packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" "gfx.cafe/gfx/pggat/lib/perror" @@ -20,7 +21,7 @@ func NewMiddleware() *Middleware { return m } -func (T *Middleware) PreRead(typed bool) (fed.Packet, error) { +func (T *Middleware) PreRead(ctx context.Context, typed bool) (fed.Packet, error) { if !T.primary { return nil, nil } @@ -37,7 +38,7 @@ func (T *Middleware) PreRead(typed bool) (fed.Packet, error) { }, nil } -func (T *Middleware) ReadPacket(packet fed.Packet) (fed.Packet, error) { +func (T *Middleware) ReadPacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) { if T.primary { return packet, nil } @@ -60,7 +61,7 @@ func (T *Middleware) ReadPacket(packet fed.Packet) (fed.Packet, error) { return p, nil } -func (T *Middleware) WritePacket(packet fed.Packet) (fed.Packet, error) { +func (T *Middleware) WritePacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) { if T.primary && (T.buf.Buffered() > 0 || T.bufDec.Buffered() > 0) { return nil, nil } @@ -84,7 +85,7 @@ func (T *Middleware) WritePacket(packet fed.Packet) (fed.Packet, error) { return packet, nil } -func (T *Middleware) PostWrite() (fed.Packet, error) { +func (T *Middleware) PostWrite(ctx context.Context) (fed.Packet, error) { return nil, nil } diff --git a/lib/gat/handlers/pool/pools/hybrid/pool.go b/lib/gat/handlers/pool/pools/hybrid/pool.go index 6bf3aaf1..edfd45d7 100644 --- a/lib/gat/handlers/pool/pools/hybrid/pool.go +++ b/lib/gat/handlers/pool/pools/hybrid/pool.go @@ -1,6 +1,7 @@ package hybrid import ( + "context" "fmt" "sync" @@ -58,17 +59,17 @@ func (T *Pool) RemoveRecipe(name string) { T.primary.RemoveRecipe(name) } -func (T *Pool) Pair(client *Client, server *spool.Server) (err, serverErr error) { +func (T *Pool) Pair(ctx context.Context, client *Client, server *spool.Server) (err, serverErr error) { client.SetState(metrics.ConnStatePairing, server, true) server.SetState(metrics.ConnStatePairing, client.ID) - err, serverErr = ps.Sync(T.config.TrackedParameters, client.Conn, server.Conn) + err, serverErr = ps.Sync(ctx, T.config.TrackedParameters, client.Conn, server.Conn) if err != nil || serverErr != nil { return } - serverErr = eqp.Sync(client.Conn, server.Conn) + serverErr = eqp.Sync(ctx, client.Conn, server.Conn) if serverErr != nil { return @@ -79,14 +80,14 @@ func (T *Pool) Pair(client *Client, server *spool.Server) (err, serverErr error) return } -func (T *Pool) PairPrimary(client *Client, psc *ps.Client, eqpc *eqp.Client, server *spool.Server) error { +func (T *Pool) PairPrimary(ctx context.Context, client *Client, psc *ps.Client, eqpc *eqp.Client, server *spool.Server) error { server.SetState(metrics.ConnStatePairing, client.ID) - if err := ps.SyncMiddleware(T.config.TrackedParameters, psc, server.Conn); err != nil { + if err := ps.SyncMiddleware(ctx, T.config.TrackedParameters, psc, server.Conn); err != nil { return err } - if err := eqp.SyncMiddleware(eqpc, server.Conn); err != nil { + if err := eqp.SyncMiddleware(ctx, eqpc, server.Conn); err != nil { return err } @@ -112,7 +113,7 @@ func (T *Pool) removeClient(client *Client) { delete(T.clients, client.Conn.BackendKey) } -func (T *Pool) serveRW(conn *fed.Conn) error { +func (T *Pool) serveRW(ctx context.Context, conn *fed.Conn) error { m := NewMiddleware() eqpa := eqp.NewClient() @@ -146,7 +147,7 @@ func (T *Pool) serveRW(conn *fed.Conn) error { if serverErr != nil { T.primary.RemoveServer(primary) } else { - T.primary.Release(primary) + T.primary.Release(ctx, primary) } primary = nil } @@ -154,7 +155,7 @@ func (T *Pool) serveRW(conn *fed.Conn) error { if serverErr != nil { T.replica.RemoveServer(replica) } else { - T.replica.Release(replica) + T.replica.Release(ctx, replica) } replica = nil } @@ -169,7 +170,7 @@ func (T *Pool) serveRW(conn *fed.Conn) error { return pool.ErrFailedToAcquirePeer } - err, serverErr = T.Pair(client, replica) + err, serverErr = T.Pair(ctx, client, replica) if serverErr != nil { return serverErr } @@ -184,7 +185,7 @@ func (T *Pool) serveRW(conn *fed.Conn) error { return pool.ErrFailedToAcquirePeer } - err, serverErr = T.Pair(client, primary) + err, serverErr = T.Pair(ctx, client, primary) if serverErr != nil { return serverErr } @@ -194,7 +195,7 @@ func (T *Pool) serveRW(conn *fed.Conn) error { } p := packets.ReadyForQuery('I') - if err = conn.WritePacket(&p); err != nil { + if err = conn.WritePacket(ctx, &p); err != nil { return err } @@ -203,17 +204,17 @@ func (T *Pool) serveRW(conn *fed.Conn) error { for { if primary != nil { - T.primary.Release(primary) + T.primary.Release(ctx, primary) primary = nil } if replica != nil { - T.replica.Release(replica) + T.replica.Release(ctx, replica) replica = nil } client.SetState(metrics.ConnStateIdle, nil, false) var packet fed.Packet - packet, err = conn.ReadPacket(true) + packet, err = conn.ReadPacket(ctx, true) if err != nil { return err } @@ -227,13 +228,13 @@ func (T *Pool) serveRW(conn *fed.Conn) error { return pool.ErrFailedToAcquirePeer } - err, serverErr = T.Pair(client, replica) + err, serverErr = T.Pair(ctx, client, replica) - psi.Set(psa) - eqpi.Set(eqpa) + psi.Set(ctx, psa) + eqpi.Set(ctx, eqpa) if err == nil && serverErr == nil { - err, serverErr = bouncers.Bounce(conn, replica.Conn, packet) + err, serverErr = bouncers.Bounce(ctx, conn, replica.Conn, packet) } if serverErr != nil { return fmt.Errorf("server error: %w", serverErr) @@ -245,10 +246,10 @@ func (T *Pool) serveRW(conn *fed.Conn) error { if err == (ErrReadOnly{}) { m.Primary() - T.replica.Release(replica) + T.replica.Release(ctx, replica) replica = nil - packet, err = conn.ReadPacket(true) + packet, err = conn.ReadPacket(ctx, true) if err != nil { return err } @@ -261,10 +262,10 @@ func (T *Pool) serveRW(conn *fed.Conn) error { return pool.ErrFailedToAcquirePeer } - serverErr = T.PairPrimary(client, psi, eqpi, primary) + serverErr = T.PairPrimary(ctx, client, psi, eqpi, primary) if serverErr == nil { - err, serverErr = bouncers.Bounce(conn, primary.Conn, packet) + err, serverErr = bouncers.Bounce(ctx, conn, primary.Conn, packet) } if serverErr != nil { return fmt.Errorf("server error: %w", serverErr) @@ -276,7 +277,7 @@ func (T *Pool) serveRW(conn *fed.Conn) error { // straight to primary m.Primary() - packet, err = conn.ReadPacket(true) + packet, err = conn.ReadPacket(ctx, true) if err != nil { return err } @@ -289,10 +290,10 @@ func (T *Pool) serveRW(conn *fed.Conn) error { return pool.ErrFailedToAcquirePeer } - err, serverErr = T.Pair(client, primary) + err, serverErr = T.Pair(ctx, client, primary) if err == nil && serverErr == nil { - err, serverErr = bouncers.Bounce(conn, primary.Conn, packet) + err, serverErr = bouncers.Bounce(ctx, conn, primary.Conn, packet) } if serverErr != nil { return fmt.Errorf("server error: %w", serverErr) @@ -309,7 +310,7 @@ func (T *Pool) serveRW(conn *fed.Conn) error { } } -func (T *Pool) serveOnly(conn *fed.Conn, write bool) error { +func (T *Pool) serveOnly(ctx context.Context, conn *fed.Conn, write bool) error { var sp *spool.Pool if write { sp = &T.primary @@ -340,7 +341,7 @@ func (T *Pool) serveOnly(conn *fed.Conn, write bool) error { if serverErr != nil { sp.RemoveServer(server) } else { - sp.Release(server) + sp.Release(ctx, server) } server = nil } @@ -354,7 +355,7 @@ func (T *Pool) serveOnly(conn *fed.Conn, write bool) error { return pool.ErrFailedToAcquirePeer } - err, serverErr = T.Pair(client, server) + err, serverErr = T.Pair(ctx, client, server) if serverErr != nil { return serverErr } @@ -363,7 +364,7 @@ func (T *Pool) serveOnly(conn *fed.Conn, write bool) error { } p := packets.ReadyForQuery('I') - if err = conn.WritePacket(&p); err != nil { + if err = conn.WritePacket(ctx, &p); err != nil { return err } @@ -372,13 +373,13 @@ func (T *Pool) serveOnly(conn *fed.Conn, write bool) error { for { if server != nil { - sp.Release(server) + sp.Release(ctx, server) server = nil } client.SetState(metrics.ConnStateIdle, nil, true) var packet fed.Packet - packet, err = conn.ReadPacket(true) + packet, err = conn.ReadPacket(ctx, true) if err != nil { return err } @@ -390,10 +391,10 @@ func (T *Pool) serveOnly(conn *fed.Conn, write bool) error { return pool.ErrFailedToAcquirePeer } - err, serverErr = T.Pair(client, server) + err, serverErr = T.Pair(ctx, client, server) if err == nil && serverErr == nil { - err, serverErr = bouncers.Bounce(conn, server.Conn, packet) + err, serverErr = bouncers.Bounce(ctx, conn, server.Conn, packet) } if serverErr != nil { return fmt.Errorf("server error: %w", serverErr) @@ -408,18 +409,18 @@ func (T *Pool) serveOnly(conn *fed.Conn, write bool) error { } } -func (T *Pool) Serve(conn *fed.Conn) error { +func (T *Pool) Serve(ctx context.Context, conn *fed.Conn) error { switch conn.InitialParameters[strutil.MakeCIString("hybrid.mode")] { case "ro": - return T.serveOnly(conn, false) + return T.serveOnly(ctx, conn, false) case "wo": - return T.serveOnly(conn, true) + return T.serveOnly(ctx, conn, true) default: - return T.serveRW(conn) + return T.serveRW(ctx, conn) } } -func (T *Pool) Cancel(key fed.BackendKey) { +func (T *Pool) Cancel(ctx context.Context, key fed.BackendKey) { peer, replica := func() (*spool.Server, bool) { T.mu.RLock() defer T.mu.RUnlock() @@ -461,7 +462,7 @@ func (T *Pool) ReadMetrics(m *metrics.Pool) { } } -func (T *Pool) Close() { +func (T *Pool) Close(_ context.Context) { T.primary.Close() T.replica.Close() } diff --git a/lib/gat/handlers/pool/spool/pool.go b/lib/gat/handlers/pool/spool/pool.go index 7589ae16..285403dd 100644 --- a/lib/gat/handlers/pool/spool/pool.go +++ b/lib/gat/handlers/pool/spool/pool.go @@ -1,6 +1,7 @@ package spool import ( + "context" "sync" "time" @@ -275,11 +276,11 @@ func (T *Pool) Acquire(client uuid.UUID) *Server { } } -func (T *Pool) Release(server *Server) { +func (T *Pool) Release(ctx context.Context, server *Server) { if T.config.ResetQuery != "" { server.SetState(metrics.ConnStateRunningResetQuery, uuid.Nil) - if err, _ := backends.QueryString(server.Conn, nil, T.config.ResetQuery); err != nil { + if err, _ := backends.QueryString(ctx, server.Conn, nil, T.config.ResetQuery); err != nil { T.config.Logger.Error("failed to run reset query", zap.Error(err)) T.RemoveServer(server) return diff --git a/lib/gat/server.go b/lib/gat/server.go index ad68e5f3..2fabc77f 100644 --- a/lib/gat/server.go +++ b/lib/gat/server.go @@ -132,7 +132,7 @@ func (T *Server) Serve(conn *fed.Conn) { errResp := perror.ToPacket( perror.New( perror.FATAL, - perror.InvalidPassword, + perror.InvalidCatalogName, fmt.Sprintf(`Database "%s" not found`, conn.Database), ), ) diff --git a/lib/gsql/eq.go b/lib/gsql/eq.go index 3cfc8d49..6e207b63 100644 --- a/lib/gsql/eq.go +++ b/lib/gsql/eq.go @@ -1,6 +1,7 @@ package gsql import ( + "context" "reflect" "strconv" @@ -8,16 +9,16 @@ import ( packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" ) -func ExtendedQuery(client *fed.Conn, result any, query string, args ...any) error { +func ExtendedQuery(ctx context.Context, client *fed.Conn, result any, query string, args ...any) error { if len(args) == 0 { - return Query(client, []any{result}, query) + return Query(ctx, client, []any{result}, query) } // parse parse := packets.Parse{ Query: query, } - if err := client.WritePacket(&parse); err != nil { + if err := client.WritePacket(ctx, &parse); err != nil { return err } @@ -59,7 +60,7 @@ outer: bind := packets.Bind{ Parameters: params, } - if err := client.WritePacket(&bind); err != nil { + if err := client.WritePacket(ctx, &bind); err != nil { return err } @@ -67,29 +68,29 @@ outer: describe := packets.Describe{ Which: 'P', } - if err := client.WritePacket(&describe); err != nil { + if err := client.WritePacket(ctx, &describe); err != nil { return err } // execute execute := packets.Execute{} - if err := client.WritePacket(&execute); err != nil { + if err := client.WritePacket(ctx, &execute); err != nil { return err } // sync sync := packets.Sync{} - if err := client.WritePacket(&sync); err != nil { + if err := client.WritePacket(ctx, &sync); err != nil { return err } // result - if err := readQueryResults(client, result); err != nil { + if err := readQueryResults(ctx, client, result); err != nil { return err } // make sure we receive ready for query - packet, err := client.ReadPacket(true) + packet, err := client.ReadPacket(ctx, true) if err != nil { return err } diff --git a/lib/gsql/query.go b/lib/gsql/query.go index 73e236a3..f5c86549 100644 --- a/lib/gsql/query.go +++ b/lib/gsql/query.go @@ -1,22 +1,23 @@ package gsql import ( + "context" "gfx.cafe/gfx/pggat/lib/fed" packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" ) -func Query(client *fed.Conn, results []any, query string) error { +func Query(ctx context.Context, client *fed.Conn, results []any, query string) error { var q = packets.Query(query) - if err := client.WritePacket(&q); err != nil { + if err := client.WritePacket(ctx, &q); err != nil { return err } - if err := readQueryResults(client, results...); err != nil { + if err := readQueryResults(ctx, client, results...); err != nil { return err } // make sure we receive ready for query - packet, err := client.ReadPacket(true) + packet, err := client.ReadPacket(ctx, true) if err != nil { return err } @@ -28,9 +29,9 @@ func Query(client *fed.Conn, results []any, query string) error { return nil } -func readQueryResults(client *fed.Conn, results ...any) error { +func readQueryResults(ctx context.Context, client *fed.Conn, results ...any) error { for _, result := range results { - if err := readRows(client, result); err != nil { + if err := readRows(ctx, client, result); err != nil { return err } } diff --git a/lib/gsql/query_test.go b/lib/gsql/query_test.go index 476748de..602061aa 100644 --- a/lib/gsql/query_test.go +++ b/lib/gsql/query_test.go @@ -1,6 +1,7 @@ package gsql_test import ( + "context" "crypto/tls" "log" "net" @@ -32,8 +33,10 @@ func TestQuery(t *testing.T) { t.Error(err) return } + ctx := context.Background() server := fed.NewConn(netconncodec.NewCodec(s)) err = backends.Accept( + ctx, server, "disable", &tls.Config{}, @@ -63,18 +66,18 @@ func TestQuery(t *testing.T) { }) b.Queue(func() error { - initial, err := outward.ReadPacket(true) + initial, err := outward.ReadPacket(ctx, true) if err != nil { return err } - clientErr, serverErr := bouncers.Bounce(outward, server, initial) + clientErr, serverErr := bouncers.Bounce(ctx, outward, server, initial) if clientErr != nil { return clientErr } if serverErr != nil { return serverErr } - if err := outward.Close(); err != nil { + if err := outward.Close(ctx); err != nil { return err } return nil diff --git a/lib/gsql/row.go b/lib/gsql/row.go index 2e907694..ef1e5110 100644 --- a/lib/gsql/row.go +++ b/lib/gsql/row.go @@ -1,6 +1,7 @@ package gsql import ( + "context" "reflect" "strconv" @@ -9,13 +10,13 @@ import ( "gfx.cafe/gfx/pggat/lib/perror" ) -func readRows(client *fed.Conn, result any) error { +func readRows(ctx context.Context,client *fed.Conn, result any) error { res := reflect.ValueOf(result) row := 0 var rd packets.RowDescription for { - packet, err := client.ReadPacket(true) + packet, err := client.ReadPacket(ctx,true) if err != nil { return err } -- GitLab