diff --git a/Gatfile b/Gatfile index 03f55c04864a8191085f63cd1b55c6a9ac9ce810..5ef6aac1024a923ee6406466006e0a2a9981aa00 100644 --- a/Gatfile +++ b/Gatfile @@ -1,3 +1,23 @@ :5433 { - error "server is not configured" + ssl self_signed + + pool /base { + pool basic session + + address localhost:5432 + + username postgres + password postgres + database base + } + + pool /pgbench { + pool basic session + + address localhost:5432 + + username postgres + password postgres + database test + } } diff --git a/lib/bouncer/backends/v0/accept.go b/lib/bouncer/backends/v0/accept.go index 6a6578aa095d566aa152478f4672616e6423755e..c0dfa247f56aaa78043b2d70fa10f41b7bf11aa6 100644 --- a/lib/bouncer/backends/v0/accept.go +++ b/lib/bouncer/backends/v0/accept.go @@ -1,6 +1,7 @@ package backends import ( + "context" "crypto/tls" "errors" "io" @@ -13,9 +14,14 @@ import ( "gfx.cafe/gfx/pggat/lib/util/strutil" ) -func authenticationSASLChallenge(ctx *acceptContext, encoder auth.SASLEncoder) (done bool, err error) { +type acceptParams struct { + Conn *fed.Conn + Options acceptOptions +} + +func authenticationSASLChallenge(ctx context.Context, params *acceptParams, encoder auth.SASLEncoder) (done bool, err error) { var packet fed.Packet - packet, err = ctx.Conn.ReadPacket(true) + packet, err = params.Conn.ReadPacket(ctx, true) if err != nil { return } @@ -41,7 +47,7 @@ func authenticationSASLChallenge(ctx *acceptContext, encoder auth.SASLEncoder) ( } resp := packets.SASLResponse(response) - err = ctx.Conn.WritePacket(&resp) + err = params.Conn.WritePacket(ctx, &resp) return case *packets.AuthenticationPayloadSASLFinal: // finish @@ -60,7 +66,7 @@ func authenticationSASLChallenge(ctx *acceptContext, encoder auth.SASLEncoder) ( } } -func authenticationSASL(ctx *acceptContext, mechanisms []string, creds auth.SASLClient) error { +func authenticationSASL(ctx context.Context, params *acceptParams, mechanisms []string, creds auth.SASLClient) error { mechanism, encoder, err := creds.EncodeSASL(mechanisms) if err != nil { return err @@ -74,7 +80,7 @@ func authenticationSASL(ctx *acceptContext, mechanisms []string, creds auth.SASL Mechanism: mechanism, InitialClientResponse: initialResponse, } - err = ctx.Conn.WritePacket(&saslInitialResponse) + err = params.Conn.WritePacket(ctx, &saslInitialResponse) if err != nil { return err } @@ -82,7 +88,7 @@ func authenticationSASL(ctx *acceptContext, mechanisms []string, creds auth.SASL // challenge loop for { var done bool - done, err = authenticationSASLChallenge(ctx, encoder) + done, err = authenticationSASLChallenge(ctx, params, encoder) if err != nil { return err } @@ -94,46 +100,46 @@ func authenticationSASL(ctx *acceptContext, mechanisms []string, creds auth.SASL return nil } -func authenticationMD5(ctx *acceptContext, salt [4]byte, creds auth.MD5Client) error { +func authenticationMD5(ctx context.Context, params *acceptParams, salt [4]byte, creds auth.MD5Client) error { pw := packets.PasswordMessage(creds.EncodeMD5(salt)) - err := ctx.Conn.WritePacket(&pw) + err := params.Conn.WritePacket(ctx, &pw) if err != nil { return err } return nil } -func authenticationCleartext(ctx *acceptContext, creds auth.CleartextClient) error { +func authenticationCleartext(ctx context.Context, params *acceptParams, creds auth.CleartextClient) error { pw := packets.PasswordMessage(creds.EncodeCleartext()) - err := ctx.Conn.WritePacket(&pw) + err := params.Conn.WritePacket(ctx, &pw) if err != nil { return err } return nil } -func authentication(ctx *acceptContext, p *packets.Authentication) (done bool, err error) { +func authentication(ctx context.Context, params *acceptParams, p *packets.Authentication) (done bool, err error) { // they have more authentication methods than there are pokemon switch mode := p.Mode.(type) { case *packets.AuthenticationPayloadOk: // we're good to go, that was easy - ctx.Conn.Authenticated = true + params.Conn.Authenticated = true return true, nil case *packets.AuthenticationPayloadKerberosV5: err = errors.New("kerberos v5 is not supported") return case *packets.AuthenticationPayloadCleartextPassword: - c, ok := ctx.Options.Credentials.(auth.CleartextClient) + c, ok := params.Options.Credentials.(auth.CleartextClient) if !ok { return false, auth.ErrMethodNotSupported } - return false, authenticationCleartext(ctx, c) + return false, authenticationCleartext(ctx, params, c) case *packets.AuthenticationPayloadMD5Password: - c, ok := ctx.Options.Credentials.(auth.MD5Client) + c, ok := params.Options.Credentials.(auth.MD5Client) if !ok { return false, auth.ErrMethodNotSupported } - return false, authenticationMD5(ctx, *mode, c) + return false, authenticationMD5(ctx, params, *mode, c) case *packets.AuthenticationPayloadGSS: err = errors.New("gss is not supported") return @@ -141,7 +147,7 @@ func authentication(ctx *acceptContext, p *packets.Authentication) (done bool, e err = errors.New("sspi is not supported") return case *packets.AuthenticationPayloadSASL: - c, ok := ctx.Options.Credentials.(auth.SASLClient) + c, ok := params.Options.Credentials.(auth.SASLClient) if !ok { return false, auth.ErrMethodNotSupported } @@ -151,16 +157,16 @@ func authentication(ctx *acceptContext, p *packets.Authentication) (done bool, e mechanisms = append(mechanisms, m.Method) } - return false, authenticationSASL(ctx, mechanisms, c) + return false, authenticationSASL(ctx, params, mechanisms, c) default: err = errors.New("unknown authentication method") return } } -func startup0(ctx *acceptContext) (done bool, err error) { +func startup0(ctx context.Context, params *acceptParams) (done bool, err error) { var packet fed.Packet - packet, err = ctx.Conn.ReadPacket(true) + packet, err = params.Conn.ReadPacket(ctx, true) if err != nil { return } @@ -180,7 +186,7 @@ func startup0(ctx *acceptContext) (done bool, err error) { if err != nil { return } - return authentication(ctx, &p) + return authentication(ctx, params, &p) case packets.TypeNegotiateProtocolVersion: // we only support protocol 3.0 for now err = errors.New("server wanted to negotiate protocol version") @@ -191,9 +197,9 @@ func startup0(ctx *acceptContext) (done bool, err error) { } } -func startup1(ctx *acceptContext) (done bool, err error) { +func startup1(ctx context.Context, params *acceptParams) (done bool, err error) { var packet fed.Packet - packet, err = ctx.Conn.ReadPacket(true) + packet, err = params.Conn.ReadPacket(ctx, true) if err != nil { return } @@ -205,8 +211,8 @@ func startup1(ctx *acceptContext) (done bool, err error) { if err != nil { return } - ctx.Conn.BackendKey.SecretKey = p.SecretKey - ctx.Conn.BackendKey.ProcessID = p.ProcessID + params.Conn.BackendKey.SecretKey = p.SecretKey + params.Conn.BackendKey.ProcessID = p.ProcessID return false, nil case packets.TypeParameterStatus: @@ -216,10 +222,10 @@ func startup1(ctx *acceptContext) (done bool, err error) { return } ikey := strutil.MakeCIString(p.Key) - if ctx.Conn.InitialParameters == nil { - ctx.Conn.InitialParameters = make(map[strutil.CIString]string) + if params.Conn.InitialParameters == nil { + params.Conn.InitialParameters = make(map[strutil.CIString]string) } - ctx.Conn.InitialParameters[ikey] = p.Value + params.Conn.InitialParameters[ikey] = p.Value return false, nil case packets.TypeReadyForQuery: return true, nil @@ -240,19 +246,19 @@ func startup1(ctx *acceptContext) (done bool, err error) { } } -func enableSSL(ctx *acceptContext) (bool, error) { +func enableSSL(ctx context.Context, params *acceptParams) (bool, error) { p := packets.Startup{ Mode: &packets.StartupPayloadControl{ Mode: &packets.StartupPayloadControlPayloadSSL{}, }, } - if err := ctx.Conn.WritePacket(&p); err != nil { + if err := params.Conn.WritePacket(ctx, &p); err != nil { return false, err } // read byte to see if ssl is allowed - yn, err := ctx.Conn.ReadByte() + yn, err := params.Conn.ReadByte(ctx) if err != nil { return false, err } @@ -262,26 +268,26 @@ func enableSSL(ctx *acceptContext) (bool, error) { return false, nil } - if err = ctx.Conn.EnableSSL(ctx.Options.SSLConfig, true); err != nil { + if err = params.Conn.EnableSSL(ctx, params.Options.SSLConfig, true); err != nil { return false, err } return true, nil } -func accept(ctx *acceptContext) error { - username := ctx.Options.Username +func accept(ctx context.Context, params *acceptParams) error { + username := params.Options.Username - if ctx.Options.Database == "" { - ctx.Options.Database = username + if params.Options.Database == "" { + params.Options.Database = username } - if ctx.Options.SSLMode.ShouldAttempt() { - sslEnabled, err := enableSSL(ctx) + if params.Options.SSLMode.ShouldAttempt() { + sslEnabled, err := enableSSL(ctx, params) if err != nil { return err } - if !sslEnabled && ctx.Options.SSLMode.IsRequired() { + if !sslEnabled && params.Options.SSLMode.IsRequired() { return errors.New("server rejected SSL encryption") } } @@ -295,12 +301,12 @@ func accept(ctx *acceptContext) error { }, { Key: "database", - Value: ctx.Options.Database, + Value: params.Options.Database, }, }, } - for key, value := range ctx.Options.StartupParameters { + for key, value := range params.Options.StartupParameters { m.Parameters = append(m.Parameters, packets.StartupPayloadVersion3PayloadParameter{ Key: key.String(), Value: value, @@ -311,14 +317,14 @@ func accept(ctx *acceptContext) error { Mode: &m, } - err := ctx.Conn.WritePacket(&p) + err := params.Conn.WritePacket(ctx, &p) if err != nil { return err } for { var done bool - done, err = startup0(ctx) + done, err = startup0(ctx, params) if err != nil { return err } @@ -329,7 +335,7 @@ func accept(ctx *acceptContext) error { for { var done bool - done, err = startup1(ctx) + done, err = startup1(ctx, params) if err != nil { return err } @@ -343,6 +349,7 @@ func accept(ctx *acceptContext) error { } func Accept( + ctx context.Context, conn *fed.Conn, sslMode bouncer.SSLMode, sslConfig *tls.Config, @@ -351,7 +358,7 @@ func Accept( database string, startupParameters map[strutil.CIString]string, ) error { - ctx := acceptContext{ + params := acceptParams{ Conn: conn, Options: acceptOptions{ SSLMode: sslMode, @@ -362,5 +369,5 @@ func Accept( StartupParameters: startupParameters, }, } - return accept(&ctx) + return accept(ctx, ¶ms) } diff --git a/lib/bouncer/backends/v0/cancel.go b/lib/bouncer/backends/v0/cancel.go index 48ee2b66461df144c76de97208c8b00bd642d2bf..ab665da05bd1db2fc6d8799890ddd7cadc279587 100644 --- a/lib/bouncer/backends/v0/cancel.go +++ b/lib/bouncer/backends/v0/cancel.go @@ -1,11 +1,12 @@ package backends import ( + "context" "gfx.cafe/gfx/pggat/lib/fed" packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" ) -func Cancel(server *fed.Conn, key fed.BackendKey) error { +func Cancel(ctx context.Context, server *fed.Conn, key fed.BackendKey) error { p := packets.Startup{ Mode: &packets.StartupPayloadControl{ Mode: &packets.StartupPayloadControlPayloadCancel{ @@ -14,5 +15,5 @@ func Cancel(server *fed.Conn, key fed.BackendKey) error { }, }, } - return server.WritePacket(&p) + return server.WritePacket(ctx, &p) } diff --git a/lib/bouncer/backends/v0/context.go b/lib/bouncer/backends/v0/context.go index 8288ab1812d3b20bb069ef4d78f2d27961ddd7a9..ab1d35eca29df53506ad778c930304997350cc34 100644 --- a/lib/bouncer/backends/v0/context.go +++ b/lib/bouncer/backends/v0/context.go @@ -1,44 +1,40 @@ package backends import ( + "context" "gfx.cafe/gfx/pggat/lib/fed" ) -type acceptContext struct { - Conn *fed.Conn - Options acceptOptions -} - -type context struct { +type serverToPeerBinding struct { Server *fed.Conn - Packet fed.Packet Peer *fed.Conn + Packet fed.Packet PeerError error TxState byte } -func (T *context) ErrUnexpectedPacket() error { +func (T *serverToPeerBinding) ErrUnexpectedPacket() error { return ErrUnexpectedPacket(T.Packet.Type()) } -func (T *context) ServerRead() error { +func (T *serverToPeerBinding) ServerRead(ctx context.Context) error { var err error - T.Packet, err = T.Server.ReadPacket(true) + T.Packet, err = T.Server.ReadPacket(ctx, true) return err } -func (T *context) ServerWrite() error { - return T.Server.WritePacket(T.Packet) +func (T *serverToPeerBinding) ServerWrite(ctx context.Context) error { + return T.Server.WritePacket(ctx, T.Packet) } -func (T *context) PeerOK() bool { +func (T *serverToPeerBinding) PeerOK() bool { if T == nil { return false } return T.Peer != nil && T.PeerError == nil } -func (T *context) PeerFail(err error) { +func (T *serverToPeerBinding) PeerFail(err error) { if T == nil { return } @@ -46,7 +42,7 @@ func (T *context) PeerFail(err error) { T.PeerError = err } -func (T *context) PeerRead() bool { +func (T *serverToPeerBinding) PeerRead(ctx context.Context) bool { if T == nil { return false } @@ -54,7 +50,7 @@ func (T *context) PeerRead() bool { return false } var err error - T.Packet, err = T.Peer.ReadPacket(true) + T.Packet, err = T.Peer.ReadPacket(ctx, true) if err != nil { T.PeerFail(err) return false @@ -62,14 +58,14 @@ func (T *context) PeerRead() bool { return true } -func (T *context) PeerWrite() { +func (T *serverToPeerBinding) PeerWrite(ctx context.Context) { if T == nil { return } if !T.PeerOK() { return } - err := T.Peer.WritePacket(T.Packet) + err := T.Peer.WritePacket(ctx, T.Packet) if err != nil { T.PeerFail(err) } diff --git a/lib/bouncer/backends/v0/query.go b/lib/bouncer/backends/v0/query.go index c1e7b52df8ffe14ab7c55f2f0be08ffb166f9fa0..d7c0e1fa53a3e8ccd3e8370cfaad572f776d0835 100644 --- a/lib/bouncer/backends/v0/query.go +++ b/lib/bouncer/backends/v0/query.go @@ -1,6 +1,7 @@ package backends import ( + "context" "strings" "gfx.cafe/gfx/pggat/lib/fed" @@ -8,65 +9,65 @@ import ( "gfx.cafe/gfx/pggat/lib/util/strutil" ) -func copyIn(ctx *context) error { - ctx.PeerWrite() +func copyIn(ctx context.Context, binding *serverToPeerBinding) error { + binding.PeerWrite(ctx) for { - if !ctx.PeerRead() { + if !binding.PeerRead(ctx) { copyFail := packets.CopyFail("peer failed") - ctx.Packet = ©Fail - return ctx.ServerWrite() + binding.Packet = ©Fail + return binding.ServerWrite(ctx) } - switch ctx.Packet.Type() { + switch binding.Packet.Type() { case packets.TypeCopyData: - if err := ctx.ServerWrite(); err != nil { + if err := binding.ServerWrite(ctx); err != nil { return err } case packets.TypeCopyDone, packets.TypeCopyFail: - return ctx.ServerWrite() + return binding.ServerWrite(ctx) default: - ctx.PeerFail(ctx.ErrUnexpectedPacket()) + binding.PeerFail(binding.ErrUnexpectedPacket()) } } } -func copyOut(ctx *context) error { - ctx.PeerWrite() +func copyOut(ctx context.Context, binding *serverToPeerBinding) error { + binding.PeerWrite(ctx) for { - err := ctx.ServerRead() + err := binding.ServerRead(ctx) if err != nil { return err } - switch ctx.Packet.Type() { + switch binding.Packet.Type() { case packets.TypeCopyData, packets.TypeNoticeResponse, packets.TypeParameterStatus, packets.TypeNotificationResponse: - ctx.PeerWrite() + binding.PeerWrite(ctx) case packets.TypeCopyDone, packets.TypeMarkiplierResponse: - ctx.PeerWrite() + binding.PeerWrite(ctx) return nil default: - return ctx.ErrUnexpectedPacket() + return binding.ErrUnexpectedPacket() } } } -func query(ctx *context) error { - if err := ctx.ServerWrite(); err != nil { +func query(ctx context.Context, binding *serverToPeerBinding) error { + if err := binding.ServerWrite(ctx); err != nil { return err } for { - err := ctx.ServerRead() + err := binding.ServerRead(ctx) if err != nil { return err } - switch ctx.Packet.Type() { + switch binding.Packet.Type() { case packets.TypeCommandComplete, packets.TypeRowDescription, packets.TypeDataRow, @@ -75,48 +76,48 @@ func query(ctx *context) error { packets.TypeNoticeResponse, packets.TypeParameterStatus, packets.TypeNotificationResponse: - ctx.PeerWrite() + binding.PeerWrite(ctx) case packets.TypeCopyInResponse: - if err = copyIn(ctx); err != nil { + if err = copyIn(ctx, binding); err != nil { return err } case packets.TypeCopyOutResponse: - if err = copyOut(ctx); err != nil { + if err = copyOut(ctx, binding); err != nil { return err } case packets.TypeReadyForQuery: var p packets.ReadyForQuery - err = fed.ToConcrete(&p, ctx.Packet) + err = fed.ToConcrete(&p, binding.Packet) if err != nil { return err } - ctx.Packet = &p - ctx.TxState = byte(p) - ctx.PeerWrite() + binding.Packet = &p + binding.TxState = byte(p) + binding.PeerWrite(ctx) return nil default: - return ctx.ErrUnexpectedPacket() + return binding.ErrUnexpectedPacket() } } } -func queryString(ctx *context, q string) error { +func queryString(ctx context.Context, binding *serverToPeerBinding, q string) error { qq := packets.Query(q) - ctx.Packet = &qq - return query(ctx) + binding.Packet = &qq + return query(ctx, binding) } -func QueryString(server, peer *fed.Conn, query string) (err, peerError error) { - ctx := context{ +func QueryString(ctx context.Context, server, peer *fed.Conn, query string) (err, peerError error) { + binding := serverToPeerBinding{ Server: server, Peer: peer, } - err = queryString(&ctx, query) - peerError = ctx.PeerError + err = queryString(ctx, &binding, query) + peerError = binding.PeerError return } -func SetParameter(server, peer *fed.Conn, name strutil.CIString, value string) (err, peerError error) { +func SetParameter(ctx context.Context, server, peer *fed.Conn, name strutil.CIString, value string) (err, peerError error) { var q strings.Builder escapedName := strutil.Escape(name.String(), '"') escapedValue := strutil.Escape(value, '\'') @@ -128,58 +129,59 @@ func SetParameter(server, peer *fed.Conn, name strutil.CIString, value string) ( q.WriteString(`'`) return QueryString( + ctx, server, peer, q.String(), ) } -func functionCall(ctx *context) error { - if err := ctx.ServerWrite(); err != nil { +func functionCall(ctx context.Context, binding *serverToPeerBinding) error { + if err := binding.ServerWrite(ctx); err != nil { return err } for { - err := ctx.ServerRead() + err := binding.ServerRead(ctx) if err != nil { return err } - switch ctx.Packet.Type() { + switch binding.Packet.Type() { case packets.TypeMarkiplierResponse, packets.TypeFunctionCallResponse, packets.TypeNoticeResponse, packets.TypeParameterStatus, packets.TypeNotificationResponse: - ctx.PeerWrite() + binding.PeerWrite(ctx) case packets.TypeReadyForQuery: var p packets.ReadyForQuery - err = fed.ToConcrete(&p, ctx.Packet) + err = fed.ToConcrete(&p, binding.Packet) if err != nil { return err } - ctx.Packet = &p - ctx.TxState = byte(p) - ctx.PeerWrite() + binding.Packet = &p + binding.TxState = byte(p) + binding.PeerWrite(ctx) return nil default: - return ctx.ErrUnexpectedPacket() + return binding.ErrUnexpectedPacket() } } } -func sync(ctx *context) (bool, error) { - if err := ctx.ServerWrite(); err != nil { +func sync(ctx context.Context, binding *serverToPeerBinding) (bool, error) { + if err := binding.ServerWrite(ctx); err != nil { return false, err } for { - err := ctx.ServerRead() + err := binding.ServerRead(ctx) if err != nil { return false, err } - switch ctx.Packet.Type() { + switch binding.Packet.Type() { case packets.TypeParseComplete, packets.TypeBindComplete, packets.TypeCloseComplete, @@ -196,54 +198,54 @@ func sync(ctx *context) (bool, error) { packets.TypeNoticeResponse, packets.TypeParameterStatus, packets.TypeNotificationResponse: - ctx.PeerWrite() + binding.PeerWrite(ctx) case packets.TypeCopyInResponse: - if err = copyIn(ctx); err != nil { + if err = copyIn(ctx, binding); err != nil { return false, err } // why return false, nil case packets.TypeCopyOutResponse: - if err = copyOut(ctx); err != nil { + if err = copyOut(ctx, binding); err != nil { return false, err } case packets.TypeReadyForQuery: var p packets.ReadyForQuery - err = fed.ToConcrete(&p, ctx.Packet) + err = fed.ToConcrete(&p, binding.Packet) if err != nil { return false, err } - ctx.Packet = &p - ctx.TxState = byte(p) - ctx.PeerWrite() + binding.Packet = &p + binding.TxState = byte(p) + binding.PeerWrite(ctx) return true, nil default: - return false, ctx.ErrUnexpectedPacket() + return false, binding.ErrUnexpectedPacket() } } } -func Sync(server, peer *fed.Conn) (err, peerErr error) { - ctx := context{ +func Sync(ctx context.Context, server, peer *fed.Conn) (err, peerErr error) { + binding := serverToPeerBinding{ Server: server, Peer: peer, Packet: &packets.Sync{}, } - _, err = sync(&ctx) - peerErr = ctx.PeerError + _, err = sync(ctx, &binding) + peerErr = binding.PeerError return } -func eqp(ctx *context) error { - if err := ctx.ServerWrite(); err != nil { +func eqp(ctx context.Context, binding *serverToPeerBinding) error { + if err := binding.ServerWrite(ctx); err != nil { return err } for { - if !ctx.PeerRead() { + if !binding.PeerRead(ctx) { for { - ctx.Packet = &packets.Sync{} - ok, err := sync(ctx) + binding.Packet = &packets.Sync{} + ok, err := sync(ctx, binding) if err != nil { return err } @@ -253,9 +255,9 @@ func eqp(ctx *context) error { } } - switch ctx.Packet.Type() { + switch binding.Packet.Type() { case packets.TypeSync: - ok, err := sync(ctx) + ok, err := sync(ctx, binding) if err != nil { return err } @@ -263,51 +265,51 @@ func eqp(ctx *context) error { return nil } case packets.TypeParse, packets.TypeBind, packets.TypeClose, packets.TypeDescribe, packets.TypeExecute, packets.TypeFlush: - if err := ctx.ServerWrite(); err != nil { + if err := binding.ServerWrite(ctx); err != nil { return err } default: - ctx.PeerFail(ctx.ErrUnexpectedPacket()) + binding.PeerFail(binding.ErrUnexpectedPacket()) } } } -func transaction(ctx *context) error { +func transaction(ctx context.Context, binding *serverToPeerBinding) error { for { - switch ctx.Packet.Type() { + switch binding.Packet.Type() { case packets.TypeQuery: - if err := query(ctx); err != nil { + if err := query(ctx, binding); err != nil { return err } case packets.TypeFunctionCall: - if err := functionCall(ctx); err != nil { + if err := functionCall(ctx, binding); err != nil { return err } case packets.TypeSync: // phony sync call, we can just reply with a fake ReadyForQuery(TxState) - rfq := packets.ReadyForQuery(ctx.TxState) - ctx.Packet = &rfq - ctx.PeerWrite() + rfq := packets.ReadyForQuery(binding.TxState) + binding.Packet = &rfq + binding.PeerWrite(ctx) case packets.TypeParse, packets.TypeBind, packets.TypeClose, packets.TypeDescribe, packets.TypeExecute, packets.TypeFlush: - if err := eqp(ctx); err != nil { + if err := eqp(ctx, binding); err != nil { return err } default: - ctx.PeerFail(ctx.ErrUnexpectedPacket()) + binding.PeerFail(binding.ErrUnexpectedPacket()) } - if ctx.TxState == 'I' { + if binding.TxState == 'I' { return nil } - if !ctx.PeerRead() { + if !binding.PeerRead(ctx) { // abort tx - err := queryString(ctx, "ABORT;") + err := queryString(ctx, binding, "ABORT;") if err != nil { return err } - if ctx.TxState != 'I' { + if binding.TxState != 'I' { return ErrExpectedIdle } return nil @@ -315,13 +317,13 @@ func transaction(ctx *context) error { } } -func Transaction(server, peer *fed.Conn, initialPacket fed.Packet) (err, peerError error) { - ctx := context{ +func Transaction(ctx context.Context, server, peer *fed.Conn, initialPacket fed.Packet) (err, peerError error) { + pgState := serverToPeerBinding{ Server: server, Peer: peer, Packet: initialPacket, } - err = transaction(&ctx) - peerError = ctx.PeerError + err = transaction(ctx, &pgState) + peerError = pgState.PeerError return } diff --git a/lib/bouncer/bouncers/v2/bouncer.go b/lib/bouncer/bouncers/v2/bouncer.go index ec8b4c10d615653cd13d968ae6fe49d68cdc3b51..b825cb1822b0e4c9952a8d14e0d0fcb01e06fcf9 100644 --- a/lib/bouncer/bouncers/v2/bouncer.go +++ b/lib/bouncer/bouncers/v2/bouncer.go @@ -1,11 +1,12 @@ package bouncers import ( + "context" "gfx.cafe/gfx/pggat/lib/bouncer/backends/v0" "gfx.cafe/gfx/pggat/lib/fed" ) -func Bounce(client, server *fed.Conn, initialPacket fed.Packet) (clientError error, serverError error) { - serverError, clientError = backends.Transaction(server, client, initialPacket) +func Bounce(ctx context.Context, client, server *fed.Conn, initialPacket fed.Packet) (clientError error, serverError error) { + serverError, clientError = backends.Transaction(ctx, server, client, initialPacket) return } diff --git a/lib/bouncer/frontends/v0/accept.go b/lib/bouncer/frontends/v0/accept.go index 00036e1abc5a8672eac16afe582f83d69b29b431..ab55c407ef9c2bfb16fea9f99c1f61f0ebb50d06 100644 --- a/lib/bouncer/frontends/v0/accept.go +++ b/lib/bouncer/frontends/v0/accept.go @@ -1,6 +1,7 @@ package frontends import ( + "context" "crypto/tls" "strings" @@ -10,12 +11,23 @@ import ( "gfx.cafe/gfx/pggat/lib/util/strutil" ) +type acceptParams struct { + Conn *fed.Conn + Options acceptOptions +} + +type acceptResult struct { + CancelKey fed.BackendKey + IsCanceling bool +} + func startup0( - ctx *acceptContext, + ctx context.Context, params *acceptParams, + result *acceptResult, ) (cancelling bool, done bool, err error) { var packet fed.Packet - packet, err = ctx.Conn.ReadPacket(false) + packet, err = params.Conn.ReadPacket(ctx, false) if err != nil { return } @@ -31,29 +43,29 @@ func startup0( switch control := mode.Mode.(type) { case *packets.StartupPayloadControlPayloadCancel: // Cancel - params.CancelKey.ProcessID = control.ProcessID - params.CancelKey.SecretKey = control.SecretKey + result.CancelKey.ProcessID = control.ProcessID + result.CancelKey.SecretKey = control.SecretKey cancelling = true done = true return case *packets.StartupPayloadControlPayloadSSL: // ssl is not enabled - if ctx.Options.SSLConfig == nil { - err = ctx.Conn.WriteByte('N') + if params.Options.SSLConfig == nil { + err = params.Conn.WriteByte(ctx, 'N') return } // do ssl - if err = ctx.Conn.WriteByte('S'); err != nil { + if err = params.Conn.WriteByte(ctx, 'S'); err != nil { return } - if err = ctx.Conn.EnableSSL(ctx.Options.SSLConfig, false); err != nil { + if err = params.Conn.EnableSSL(ctx, params.Options.SSLConfig, false); err != nil { return } return case *packets.StartupPayloadControlPayloadGSSAPI: // GSSAPI is not supported yet - err = ctx.Conn.WriteByte('N') + err = params.Conn.WriteByte(ctx, 'N') return default: err = perror.New( @@ -69,9 +81,9 @@ func startup0( for _, parameter := range mode.Parameters { switch parameter.Key { case "user": - ctx.Conn.User = parameter.Value + params.Conn.User = parameter.Value case "database": - ctx.Conn.Database = parameter.Value + params.Conn.Database = parameter.Value case "options": fields := strings.Fields(parameter.Value) for i := 0; i < len(fields); i++ { @@ -91,10 +103,10 @@ func startup0( ikey := strutil.MakeCIString(key) - if ctx.Conn.InitialParameters == nil { - ctx.Conn.InitialParameters = make(map[strutil.CIString]string) + if params.Conn.InitialParameters == nil { + params.Conn.InitialParameters = make(map[strutil.CIString]string) } - ctx.Conn.InitialParameters[ikey] = value + params.Conn.InitialParameters[ikey] = value default: err = perror.New( perror.FATAL, @@ -118,10 +130,10 @@ func startup0( } else { ikey := strutil.MakeCIString(parameter.Key) - if ctx.Conn.InitialParameters == nil { - ctx.Conn.InitialParameters = make(map[strutil.CIString]string) + if params.Conn.InitialParameters == nil { + params.Conn.InitialParameters = make(map[strutil.CIString]string) } - ctx.Conn.InitialParameters[ikey] = parameter.Value + params.Conn.InitialParameters[ikey] = parameter.Value } } } @@ -132,13 +144,13 @@ func startup0( MinorProtocolVersion: 0, UnrecognizedProtocolOptions: unsupportedOptions, } - err = ctx.Conn.WritePacket(&uopts) + err = params.Conn.WritePacket(ctx, &uopts) if err != nil { return } } - if ctx.Conn.User == "" { + if params.Conn.User == "" { err = perror.New( perror.FATAL, perror.InvalidAuthorizationSpecification, @@ -146,8 +158,8 @@ func startup0( ) return } - if ctx.Conn.Database == "" { - ctx.Conn.Database = ctx.Conn.User + if params.Conn.Database == "" { + params.Conn.Database = params.Conn.User } done = true @@ -163,11 +175,12 @@ func startup0( } func accept0( - ctx *acceptContext, -) (params acceptParams, err error) { + ctx context.Context, + params *acceptParams, +) (result acceptResult, err error) { for { var done bool - params.IsCanceling, done, err = startup0(ctx, ¶ms) + result.IsCanceling, done, err = startup0(ctx, params, &result) if err != nil { return } @@ -179,18 +192,18 @@ func accept0( return } -func fail(client *fed.Conn, err error) { +func fail(ctx context.Context, client *fed.Conn, err error) { resp := perror.ToPacket(perror.Wrap(err)) - _ = client.WritePacket(resp) + _ = client.WritePacket(ctx, resp) } -func accept(ctx *acceptContext) (acceptParams, error) { - params, err := accept0(ctx) +func accept(ctx context.Context, params *acceptParams) (acceptResult, error) { + result, err := accept0(ctx, params) if err != nil { - fail(ctx.Conn, err) - return acceptParams{}, err + fail(ctx, params.Conn, err) + return acceptResult{}, err } - return params, nil + return result, nil } func Accept(conn *fed.Conn, tlsConfig *tls.Config) ( @@ -198,15 +211,16 @@ func Accept(conn *fed.Conn, tlsConfig *tls.Config) ( isCanceling bool, err error, ) { - ctx := acceptContext{ + params := acceptParams{ Conn: conn, Options: acceptOptions{ SSLConfig: tlsConfig, }, } - var params acceptParams - params, err = accept(&ctx) - cancelKey = params.CancelKey - isCanceling = params.IsCanceling + var result acceptResult + if result, err = accept(context.Background(), ¶ms); err == nil { + cancelKey = result.CancelKey + isCanceling = result.IsCanceling + } return } diff --git a/lib/bouncer/frontends/v0/authenticate.go b/lib/bouncer/frontends/v0/authenticate.go index 498dec39e44b7712f55a5f1807230373fcd7b9ae..709dedda15bc0b3ece9eafe35f357e589862f4be 100644 --- a/lib/bouncer/frontends/v0/authenticate.go +++ b/lib/bouncer/frontends/v0/authenticate.go @@ -1,6 +1,7 @@ package frontends import ( + "context" "crypto/rand" "encoding/binary" "errors" @@ -13,10 +14,15 @@ import ( "gfx.cafe/gfx/pggat/lib/perror" ) -func authenticationSASLInitial(ctx *authenticateContext, creds auth.SASLServer) (tool auth.SASLVerifier, resp []byte, done bool, err error) { +type authParams struct { + Conn *fed.Conn + Options authOptions +} + +func authenticationSASLInitial(ctx context.Context, params *authParams, creds auth.SASLServer) (tool auth.SASLVerifier, resp []byte, done bool, err error) { // check which authentication method the client wants var packet fed.Packet - packet, err = ctx.Conn.ReadPacket(true) + packet, err = params.Conn.ReadPacket(ctx, true) if err != nil { return } @@ -43,9 +49,9 @@ func authenticationSASLInitial(ctx *authenticateContext, creds auth.SASLServer) return } -func authenticationSASLContinue(ctx *authenticateContext, tool auth.SASLVerifier) (resp []byte, done bool, err error) { +func authenticationSASLContinue(ctx context.Context, params *authParams, tool auth.SASLVerifier) (resp []byte, done bool, err error) { var packet fed.Packet - packet, err = ctx.Conn.ReadPacket(true) + packet, err = params.Conn.ReadPacket(ctx, true) if err != nil { return } @@ -67,7 +73,7 @@ func authenticationSASLContinue(ctx *authenticateContext, tool auth.SASLVerifier return } -func authenticationSASL(ctx *authenticateContext, creds auth.SASLServer) error { +func authenticationSASL(ctx context.Context, params *authParams, creds auth.SASLServer) error { var mode packets.AuthenticationPayloadSASL mechanisms := creds.SupportedSASLMechanisms() for _, mechanism := range mechanisms { @@ -79,12 +85,12 @@ func authenticationSASL(ctx *authenticateContext, creds auth.SASLServer) error { saslInitial := packets.Authentication{ Mode: &mode, } - err := ctx.Conn.WritePacket(&saslInitial) + err := params.Conn.WritePacket(ctx, &saslInitial) if err != nil { return err } - tool, resp, done, err := authenticationSASLInitial(ctx, creds) + tool, resp, done, err := authenticationSASLInitial(ctx, params, creds) if err != nil { return err } @@ -95,7 +101,7 @@ func authenticationSASL(ctx *authenticateContext, creds auth.SASLServer) error { final := packets.Authentication{ Mode: &m, } - err = ctx.Conn.WritePacket(&final) + err = params.Conn.WritePacket(ctx, &final) if err != nil { return err } @@ -105,13 +111,13 @@ func authenticationSASL(ctx *authenticateContext, creds auth.SASLServer) error { cont := packets.Authentication{ Mode: &m, } - err = ctx.Conn.WritePacket(&cont) + err = params.Conn.WritePacket(ctx, &cont) if err != nil { return err } } - resp, done, err = authenticationSASLContinue(ctx, tool) + resp, done, err = authenticationSASLContinue(ctx, params, tool) if err != nil { return err } @@ -120,7 +126,7 @@ func authenticationSASL(ctx *authenticateContext, creds auth.SASLServer) error { return nil } -func authenticationMD5(ctx *authenticateContext, creds auth.MD5Server) error { +func authenticationMD5(ctx context.Context, params *authParams, creds auth.MD5Server) error { var salt [4]byte _, err := rand.Read(salt[:]) if err != nil { @@ -130,13 +136,14 @@ func authenticationMD5(ctx *authenticateContext, creds auth.MD5Server) error { md5Initial := packets.Authentication{ Mode: &mode, } - err = ctx.Conn.WritePacket(&md5Initial) + + err = params.Conn.WritePacket(ctx, &md5Initial) if err != nil { return err } var packet fed.Packet - packet, err = ctx.Conn.ReadPacket(true) + packet, err = params.Conn.ReadPacket(ctx, true) if err != nil { return err } @@ -154,12 +161,12 @@ func authenticationMD5(ctx *authenticateContext, creds auth.MD5Server) error { return nil } -func authenticate(ctx *authenticateContext) (err error) { - if ctx.Options.Credentials != nil { - if credsSASL, ok := ctx.Options.Credentials.(auth.SASLServer); ok { - err = authenticationSASL(ctx, credsSASL) - } else if credsMD5, ok := ctx.Options.Credentials.(auth.MD5Server); ok { - err = authenticationMD5(ctx, credsMD5) +func authenticate(ctx context.Context, params *authParams) (err error) { + if params.Options.Credentials != nil { + if credsSASL, ok := params.Options.Credentials.(auth.SASLServer); ok { + err = authenticationSASL(ctx, params, credsSASL) + } else if credsMD5, ok := params.Options.Credentials.(auth.MD5Server); ok { + err = authenticationMD5(ctx, params, credsMD5) } else { err = perror.New( perror.FATAL, @@ -176,10 +183,10 @@ func authenticate(ctx *authenticateContext) (err error) { authOk := packets.Authentication{ Mode: &packets.AuthenticationPayloadOk{}, } - if err = ctx.Conn.WritePacket(&authOk); err != nil { + if err = params.Conn.WritePacket(ctx, &authOk); err != nil { return } - ctx.Conn.Authenticated = true + params.Conn.Authenticated = true // send backend key data var processID [4]byte @@ -190,35 +197,35 @@ func authenticate(ctx *authenticateContext) (err error) { if _, err = rand.Reader.Read(backendKey[:]); err != nil { return } - ctx.Conn.BackendKey = fed.BackendKey{ + params.Conn.BackendKey = fed.BackendKey{ ProcessID: int32(binary.BigEndian.Uint32(processID[:])), SecretKey: int32(binary.BigEndian.Uint32(backendKey[:])), } keyData := packets.BackendKeyData{ - ProcessID: ctx.Conn.BackendKey.ProcessID, - SecretKey: ctx.Conn.BackendKey.SecretKey, + ProcessID: params.Conn.BackendKey.ProcessID, + SecretKey: params.Conn.BackendKey.SecretKey, } - if err = ctx.Conn.WritePacket(&keyData); err != nil { + if err = params.Conn.WritePacket(ctx, &keyData); err != nil { return } return } -func Authenticate(conn *fed.Conn, creds auth.Credentials) (err error) { +func Authenticate(ctx context.Context, conn *fed.Conn, creds auth.Credentials) (err error) { if conn.Authenticated { // already authenticated return } - ctx := authenticateContext{ + params := authParams{ Conn: conn, - Options: authenticateOptions{ + Options: authOptions{ Credentials: creds, }, } - err = authenticate(&ctx) + err = authenticate(ctx, ¶ms) if err != nil { // sleep after incorrect password time.Sleep(250 * time.Millisecond) diff --git a/lib/bouncer/frontends/v0/context.go b/lib/bouncer/frontends/v0/context.go deleted file mode 100644 index 993073ffab07b6d371176eff956b7627fa814a75..0000000000000000000000000000000000000000 --- a/lib/bouncer/frontends/v0/context.go +++ /dev/null @@ -1,13 +0,0 @@ -package frontends - -import "gfx.cafe/gfx/pggat/lib/fed" - -type acceptContext struct { - Conn *fed.Conn - Options acceptOptions -} - -type authenticateContext struct { - Conn *fed.Conn - Options authenticateOptions -} diff --git a/lib/bouncer/frontends/v0/options.go b/lib/bouncer/frontends/v0/options.go index 304f7b82256efcd4edf0dc1d7125c25491f93320..af243b93b6b65f30841fc18d838a630d679cd745 100644 --- a/lib/bouncer/frontends/v0/options.go +++ b/lib/bouncer/frontends/v0/options.go @@ -10,6 +10,6 @@ type acceptOptions struct { SSLConfig *tls.Config } -type authenticateOptions struct { +type authOptions struct { Credentials auth.Credentials } diff --git a/lib/bouncer/frontends/v0/params.go b/lib/bouncer/frontends/v0/params.go deleted file mode 100644 index 4f28f8347cb111f03ade898b28fde91364420d3b..0000000000000000000000000000000000000000 --- a/lib/bouncer/frontends/v0/params.go +++ /dev/null @@ -1,8 +0,0 @@ -package frontends - -import "gfx.cafe/gfx/pggat/lib/fed" - -type acceptParams struct { - CancelKey fed.BackendKey - IsCanceling bool -} diff --git a/lib/fed/codecs/netconncodec/codec.go b/lib/fed/codecs/netconncodec/codec.go index 414943cae19a6254b120d4326c7ca786743a4f9e..3929f1ad0cd6b76e06104a3bf699d8c340883da8 100644 --- a/lib/fed/codecs/netconncodec/codec.go +++ b/lib/fed/codecs/netconncodec/codec.go @@ -1,6 +1,7 @@ package netconncodec import ( + "context" "crypto/tls" "errors" "fmt" @@ -32,7 +33,7 @@ func NewCodec(rw net.Conn) fed.PacketCodec { return c } -func (c *Codec) ReadPacket(typed bool) (fed.Packet, error) { +func (c *Codec) ReadPacket(ctx context.Context, typed bool) (fed.Packet, error) { if err := c.decoder.Next(typed); err != nil { return nil, err } @@ -41,7 +42,7 @@ func (c *Codec) ReadPacket(typed bool) (fed.Packet, error) { }, nil } -func (c *Codec) WritePacket(packet fed.Packet) error { +func (c *Codec) WritePacket(ctx context.Context, packet fed.Packet) error { err := c.encoder.Next(packet.Type(), packet.Length()) if err != nil { return err @@ -49,23 +50,23 @@ func (c *Codec) WritePacket(packet fed.Packet) error { return packet.WriteTo(&c.encoder) } -func (c *Codec) WriteByte(b byte) error { +func (c *Codec) WriteByte(ctx context.Context, b byte) error { return c.encoder.WriteByte(b) } -func (c *Codec) ReadByte() (byte, error) { - if err := c.Flush(); err != nil { +func (c *Codec) ReadByte(ctx context.Context) (byte, error) { + if err := c.Flush(ctx); err != nil { return 0, err } return c.decoder.ReadByte() } -func (c *Codec) Flush() error { +func (c *Codec) Flush(ctx context.Context) error { return c.encoder.Flush() } -func (c *Codec) Close() error { +func (c *Codec) Close(ctx context.Context) error { if err := c.encoder.Flush(); err != nil { return err } @@ -80,7 +81,7 @@ func (c *Codec) SSL() bool { return c.ssl } -func (c *Codec) EnableSSL(config *tls.Config, isClient bool) error { +func (c *Codec) EnableSSL(ctx context.Context, config *tls.Config, isClient bool) error { c.mu.Lock() defer c.mu.Unlock() if c.ssl { @@ -89,7 +90,7 @@ func (c *Codec) EnableSSL(config *tls.Config, isClient bool) error { c.ssl = true // Flush buffers - if err := c.Flush(); err != nil { + if err := c.Flush(ctx); err != nil { return err } if c.decoder.Buffered() > 0 { diff --git a/lib/fed/conn.go b/lib/fed/conn.go index a32336e7bf480c26947e76bc6681c46334bfe8b7..d9e0ad9e8967ce04fb9aa2f21103bfe0ceca7f9d 100644 --- a/lib/fed/conn.go +++ b/lib/fed/conn.go @@ -1,6 +1,7 @@ package fed import ( + "context" "crypto/tls" "io" "net" @@ -39,16 +40,16 @@ func NewConn(codec PacketCodec) *Conn { return c } -func (T *Conn) Flush() error { - return T.codec.Flush() +func (T *Conn) Flush(ctx context.Context) error { + return T.codec.Flush(ctx) } -func (T *Conn) readPacket(typed bool) (Packet, error) { - return T.codec.ReadPacket(typed) +func (T *Conn) readPacket(ctx context.Context, typed bool) (Packet, error) { + return T.codec.ReadPacket(ctx, typed) } -func (T *Conn) ReadPacket(typed bool) (Packet, error) { - if err := T.Flush(); err != nil { +func (T *Conn) ReadPacket(ctx context.Context, typed bool) (Packet, error) { + if err := T.Flush(ctx); err != nil { return nil, err } @@ -57,7 +58,7 @@ func (T *Conn) ReadPacket(typed bool) (Packet, error) { for i := 0; i < len(T.Middleware); i++ { middleware := T.Middleware[i] for { - packet, err := middleware.PreRead(typed) + packet, err := middleware.PreRead(ctx, typed) if err != nil { return nil, err } @@ -67,7 +68,7 @@ func (T *Conn) ReadPacket(typed bool) (Packet, error) { } for j := i; j < len(T.Middleware); j++ { - packet, err = T.Middleware[j].ReadPacket(packet) + packet, err = T.Middleware[j].ReadPacket(ctx, packet) if err != nil { return nil, err } @@ -82,12 +83,12 @@ func (T *Conn) ReadPacket(typed bool) (Packet, error) { } } - packet, err := T.readPacket(typed) + packet, err := T.readPacket(ctx, typed) if err != nil { return nil, err } for _, middleware := range T.Middleware { - packet, err = middleware.ReadPacket(packet) + packet, err = middleware.ReadPacket(ctx, packet) if err != nil { return nil, err } @@ -101,16 +102,16 @@ func (T *Conn) ReadPacket(typed bool) (Packet, error) { } } -func (T *Conn) writePacket(packet Packet) error { - return T.codec.WritePacket(packet) +func (T *Conn) writePacket(ctx context.Context, packet Packet) error { + return T.codec.WritePacket(ctx, packet) } -func (T *Conn) WritePacket(packet Packet) error { +func (T *Conn) WritePacket(ctx context.Context, packet Packet) error { for i := len(T.Middleware) - 1; i >= 0; i-- { middleware := T.Middleware[i] var err error - packet, err = middleware.WritePacket(packet) + packet, err = middleware.WritePacket(ctx, packet) if err != nil { return err } @@ -119,7 +120,7 @@ func (T *Conn) WritePacket(packet Packet) error { } } if packet != nil { - if err := T.writePacket(packet); err != nil { + if err := T.writePacket(ctx, packet); err != nil { return err } } @@ -130,7 +131,7 @@ func (T *Conn) WritePacket(packet Packet) error { for { var err error - packet, err = middleware.PostWrite() + packet, err = middleware.PostWrite(ctx) if err != nil { return err } @@ -140,7 +141,7 @@ func (T *Conn) WritePacket(packet Packet) error { } for j := i; j >= 0; j-- { - packet, err = T.Middleware[j].WritePacket(packet) + packet, err = T.Middleware[j].WritePacket(ctx, packet) if err != nil { return err } @@ -150,7 +151,7 @@ func (T *Conn) WritePacket(packet Packet) error { } if packet != nil { - if err = T.writePacket(packet); err != nil { + if err = T.writePacket(ctx, packet); err != nil { return err } } @@ -160,8 +161,8 @@ func (T *Conn) WritePacket(packet Packet) error { return nil } -func (T *Conn) WriteByte(b byte) error { - return T.codec.WriteByte(b) +func (T *Conn) WriteByte(ctx context.Context, b byte) error { + return T.codec.WriteByte(ctx, b) } func (T *Conn) LocalAddr() net.Addr { @@ -169,14 +170,14 @@ func (T *Conn) LocalAddr() net.Addr { } -func (T *Conn) ReadByte() (byte, error) { - return T.codec.ReadByte() +func (T *Conn) ReadByte(ctx context.Context) (byte, error) { + return T.codec.ReadByte(ctx) } -func (T *Conn) EnableSSL(config *tls.Config, isClient bool) error { - return T.codec.EnableSSL(config, isClient) +func (T *Conn) EnableSSL(ctx context.Context, config *tls.Config, isClient bool) error { + return T.codec.EnableSSL(ctx, config, isClient) } -func (T *Conn) Close() error { - return T.codec.Close() +func (T *Conn) Close(ctx context.Context) error { + return T.codec.Close(ctx) } diff --git a/lib/fed/interface.go b/lib/fed/interface.go deleted file mode 100644 index 636d62a13fd321278543fcd52d62e95e580b0886..0000000000000000000000000000000000000000 --- a/lib/fed/interface.go +++ /dev/null @@ -1,20 +0,0 @@ -package fed - -import ( - "crypto/tls" - "net" -) - -type PacketCodec interface { - ReadPacket(typed bool) (Packet, error) - WritePacket(packet Packet) error - WriteByte(b byte) error - ReadByte() (byte, error) - - LocalAddr() net.Addr - Flush() error - Close() error - - SSL() bool - EnableSSL(config *tls.Config, isClient bool) error -} diff --git a/lib/fed/middleware.go b/lib/fed/middleware.go index 92e31bd889d7fd87adda72761f6256dc665dc0c8..605d00ed7a6af81070ba8ea1c9b8dce683921a63 100644 --- a/lib/fed/middleware.go +++ b/lib/fed/middleware.go @@ -1,12 +1,14 @@ package fed +import "context" + // Middleware intercepts packets and possibly changes them. Return a 0 length packet to cancel. type Middleware interface { - PreRead(typed bool) (Packet, error) - ReadPacket(packet Packet) (Packet, error) + PreRead(ctx context.Context,typed bool) (Packet, error) + ReadPacket(ctx context.Context,packet Packet) (Packet, error) - WritePacket(packet Packet) (Packet, error) - PostWrite() (Packet, error) + WritePacket(ctx context.Context,packet Packet) (Packet, error) + PostWrite(ctx context.Context,) (Packet, error) } func LookupMiddleware[T Middleware](conn *Conn) (T, bool) { diff --git a/lib/fed/middlewares/eqp/client.go b/lib/fed/middlewares/eqp/client.go index 652d707369c1082da6f6fa4500acccbfd24e89e3..2af9b781b4ea59180b481461274433091e4dce1e 100644 --- a/lib/fed/middlewares/eqp/client.go +++ b/lib/fed/middlewares/eqp/client.go @@ -1,6 +1,7 @@ package eqp import ( + "context" "gfx.cafe/gfx/pggat/lib/fed" ) @@ -12,23 +13,23 @@ func NewClient() *Client { return new(Client) } -func (T *Client) PreRead(_ bool) (fed.Packet, error) { +func (T *Client) PreRead(ctx context.Context, _ bool) (fed.Packet, error) { return nil, nil } -func (T *Client) ReadPacket(packet fed.Packet) (fed.Packet, error) { +func (T *Client) ReadPacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) { return T.state.C2S(packet) } -func (T *Client) WritePacket(packet fed.Packet) (fed.Packet, error) { +func (T *Client) WritePacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) { return T.state.S2C(packet) } -func (T *Client) PostWrite() (fed.Packet, error) { +func (T *Client) PostWrite(ctx context.Context) (fed.Packet, error) { return nil, nil } -func (T *Client) Set(other *Client) { +func (T *Client) Set(ctx context.Context, other *Client) { T.state.Set(&other.state) } diff --git a/lib/fed/middlewares/eqp/server.go b/lib/fed/middlewares/eqp/server.go index f847a9fa664d12f76162806a91a6c8911aa9e5e8..c24e472a7767e7291cf0a76c713dce26dbe307f9 100644 --- a/lib/fed/middlewares/eqp/server.go +++ b/lib/fed/middlewares/eqp/server.go @@ -1,6 +1,7 @@ package eqp import ( + "context" "gfx.cafe/gfx/pggat/lib/fed" ) @@ -12,19 +13,19 @@ func NewServer() *Server { return new(Server) } -func (T *Server) PreRead(_ bool) (fed.Packet, error) { +func (T *Server) PreRead(ctx context.Context, _ bool) (fed.Packet, error) { return nil, nil } -func (T *Server) ReadPacket(packet fed.Packet) (fed.Packet, error) { +func (T *Server) ReadPacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) { return T.state.S2C(packet) } -func (T *Server) WritePacket(packet fed.Packet) (fed.Packet, error) { +func (T *Server) WritePacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) { return T.state.C2S(packet) } -func (T *Server) PostWrite() (fed.Packet, error) { +func (T *Server) PostWrite(ctx context.Context) (fed.Packet, error) { return nil, nil } diff --git a/lib/fed/middlewares/eqp/sync.go b/lib/fed/middlewares/eqp/sync.go index 2cc1be7dfc5ce6fa34d3a3be78c5ebe4a8c0e71b..abf345aae310a0444a965ecd402ea55e207fdf1f 100644 --- a/lib/fed/middlewares/eqp/sync.go +++ b/lib/fed/middlewares/eqp/sync.go @@ -1,6 +1,7 @@ package eqp import ( + "context" "gfx.cafe/gfx/pggat/lib/bouncer/backends/v0" "gfx.cafe/gfx/pggat/lib/fed" packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" @@ -19,7 +20,7 @@ func preparedStatementsEqual(a, b *packets.Parse) bool { return true } -func SyncMiddleware(c *Client, server *fed.Conn) error { +func SyncMiddleware(ctx context.Context, c *Client, server *fed.Conn) error { s, ok := fed.LookupMiddleware[*Server](server) if !ok { panic("middleware not found") @@ -35,7 +36,7 @@ func SyncMiddleware(c *Client, server *fed.Conn) error { Which: 'P', Name: name, } - if err := server.WritePacket(&p); err != nil { + if err := server.WritePacket(ctx, &p); err != nil { return err } @@ -59,7 +60,7 @@ func SyncMiddleware(c *Client, server *fed.Conn) error { Which: 'S', Name: name, } - if err := server.WritePacket(&p); err != nil { + if err := server.WritePacket(ctx, &p); err != nil { return err } @@ -74,7 +75,7 @@ func SyncMiddleware(c *Client, server *fed.Conn) error { } } - if err := server.WritePacket(preparedStatement); err != nil { + if err := server.WritePacket(ctx, preparedStatement); err != nil { return err } @@ -83,7 +84,7 @@ func SyncMiddleware(c *Client, server *fed.Conn) error { // bind all portals for _, portal := range c.state.portals { - if err := server.WritePacket(portal); err != nil { + if err := server.WritePacket(ctx, portal); err != nil { return err } @@ -92,18 +93,18 @@ func SyncMiddleware(c *Client, server *fed.Conn) error { if needsBackendSync { var err error - err, _ = backends.Sync(server, nil) + err, _ = backends.Sync(ctx, server, nil) return err } return nil } -func Sync(client, server *fed.Conn) error { +func Sync(ctx context.Context, client, server *fed.Conn) error { c, ok := fed.LookupMiddleware[*Client](client) if !ok { panic("middleware not found") } - return SyncMiddleware(c, server) + return SyncMiddleware(ctx, c, server) } diff --git a/lib/fed/middlewares/ps/client.go b/lib/fed/middlewares/ps/client.go index 94b5014ff1d39702527e388d09352c6a9fe53854..f0ac38b5048e7efe4271a6f786b5c358ef4769d3 100644 --- a/lib/fed/middlewares/ps/client.go +++ b/lib/fed/middlewares/ps/client.go @@ -1,6 +1,7 @@ package ps import ( + "context" "gfx.cafe/gfx/pggat/lib/fed" packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" "gfx.cafe/gfx/pggat/lib/util/maps" @@ -18,15 +19,15 @@ func NewClient(parameters map[strutil.CIString]string) *Client { } } -func (T *Client) PreRead(_ bool) (fed.Packet, error) { +func (T *Client) PreRead(ctx context.Context, _ bool) (fed.Packet, error) { return nil, nil } -func (T *Client) ReadPacket(packet fed.Packet) (fed.Packet, error) { +func (T *Client) ReadPacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) { return packet, nil } -func (T *Client) WritePacket(packet fed.Packet) (fed.Packet, error) { +func (T *Client) WritePacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) { switch packet.Type() { case packets.TypeParameterStatus: var p packets.ParameterStatus @@ -49,11 +50,11 @@ func (T *Client) WritePacket(packet fed.Packet) (fed.Packet, error) { } } -func (T *Client) PostWrite() (fed.Packet, error) { +func (T *Client) PostWrite(ctx context.Context) (fed.Packet, error) { return nil, nil } -func (T *Client) Set(other *Client) { +func (T *Client) Set(ctx context.Context, other *Client) { T.synced = other.synced maps.Clear(T.parameters) diff --git a/lib/fed/middlewares/ps/server.go b/lib/fed/middlewares/ps/server.go index 137b34df9fa2331758a91a3bab5bd7eab8ce93db..7ce9c15da2b488a26c6dddb81baf43fb6121c5c1 100644 --- a/lib/fed/middlewares/ps/server.go +++ b/lib/fed/middlewares/ps/server.go @@ -1,6 +1,7 @@ package ps import ( + "context" "gfx.cafe/gfx/pggat/lib/fed" packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" "gfx.cafe/gfx/pggat/lib/util/strutil" @@ -16,11 +17,11 @@ func NewServer(parameters map[strutil.CIString]string) *Server { } } -func (T *Server) PreRead(_ bool) (fed.Packet, error) { +func (T *Server) PreRead(ctx context.Context, _ bool) (fed.Packet, error) { return nil, nil } -func (T *Server) ReadPacket(packet fed.Packet) (fed.Packet, error) { +func (T *Server) ReadPacket(ctx context.Context,packet fed.Packet) (fed.Packet, error) { switch packet.Type() { case packets.TypeParameterStatus: var p packets.ParameterStatus @@ -39,11 +40,11 @@ func (T *Server) ReadPacket(packet fed.Packet) (fed.Packet, error) { } } -func (T *Server) WritePacket(packet fed.Packet) (fed.Packet, error) { +func (T *Server) WritePacket(ctx context.Context,packet fed.Packet) (fed.Packet, error) { return packet, nil } -func (T *Server) PostWrite() (fed.Packet, error) { +func (T *Server) PostWrite(ctx context.Context,) (fed.Packet, error) { return nil, nil } diff --git a/lib/fed/middlewares/ps/sync.go b/lib/fed/middlewares/ps/sync.go index 152e86b0480990e44585ec006bc07ff56715ac7e..0655694a0d05e4e4f89358e02ab175bd66d3f337 100644 --- a/lib/fed/middlewares/ps/sync.go +++ b/lib/fed/middlewares/ps/sync.go @@ -1,6 +1,7 @@ package ps import ( + "context" "gfx.cafe/gfx/pggat/lib/bouncer/backends/v0" "gfx.cafe/gfx/pggat/lib/fed" packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" @@ -8,7 +9,7 @@ import ( "gfx.cafe/gfx/pggat/lib/util/strutil" ) -func sync(tracking []strutil.CIString, client *fed.Conn, c *Client, server *fed.Conn, s *Server, name strutil.CIString) (clientErr, serverErr error) { +func sync(ctx context.Context, tracking []strutil.CIString, client *fed.Conn, c *Client, server *fed.Conn, s *Server, name strutil.CIString) (clientErr, serverErr error) { value, hasValue := c.parameters[name] expected, hasExpected := s.parameters[name] @@ -18,7 +19,7 @@ func sync(tracking []strutil.CIString, client *fed.Conn, c *Client, server *fed. Key: name.String(), Value: expected, } - clientErr = client.WritePacket(&ps) + clientErr = client.WritePacket(ctx, &ps) } return } @@ -26,7 +27,7 @@ func sync(tracking []strutil.CIString, client *fed.Conn, c *Client, server *fed. var doSet bool if hasValue && slices.Contains(tracking, name) { - if serverErr, _ = backends.SetParameter(server, nil, name, value); serverErr != nil { + if serverErr, _ = backends.SetParameter(ctx, server, nil, name, value); serverErr != nil { return } if s.parameters == nil { @@ -47,7 +48,7 @@ func sync(tracking []strutil.CIString, client *fed.Conn, c *Client, server *fed. Key: name.String(), Value: expected, } - if clientErr = client.WritePacket(&ps); clientErr != nil { + if clientErr = client.WritePacket(ctx, &ps); clientErr != nil { return } } @@ -55,14 +56,14 @@ func sync(tracking []strutil.CIString, client *fed.Conn, c *Client, server *fed. return } -func SyncMiddleware(tracking []strutil.CIString, c *Client, server *fed.Conn) error { +func SyncMiddleware(ctx context.Context, tracking []strutil.CIString, c *Client, server *fed.Conn) error { s, ok := fed.LookupMiddleware[*Server](server) if !ok { panic("middleware not found") } for name := range c.parameters { - if _, err := sync(tracking, nil, c, server, s, name); err != nil { + if _, err := sync(ctx, tracking, nil, c, server, s, name); err != nil { return err } } @@ -71,7 +72,7 @@ func SyncMiddleware(tracking []strutil.CIString, c *Client, server *fed.Conn) er if _, ok = c.parameters[name]; ok { continue } - if _, err := sync(tracking, nil, c, server, s, name); err != nil { + if _, err := sync(ctx, tracking, nil, c, server, s, name); err != nil { return err } } @@ -79,7 +80,7 @@ func SyncMiddleware(tracking []strutil.CIString, c *Client, server *fed.Conn) er return nil } -func Sync(tracking []strutil.CIString, client, server *fed.Conn) (clientErr, serverErr error) { +func Sync(ctx context.Context, tracking []strutil.CIString, client, server *fed.Conn) (clientErr, serverErr error) { c, ok := fed.LookupMiddleware[*Client](client) if !ok { panic("middleware not found") @@ -90,7 +91,7 @@ func Sync(tracking []strutil.CIString, client, server *fed.Conn) (clientErr, ser } for name := range c.parameters { - if clientErr, serverErr = sync(tracking, client, c, server, s, name); clientErr != nil || serverErr != nil { + if clientErr, serverErr = sync(ctx, tracking, client, c, server, s, name); clientErr != nil || serverErr != nil { return } } @@ -99,7 +100,7 @@ func Sync(tracking []strutil.CIString, client, server *fed.Conn) (clientErr, ser if _, ok = c.parameters[name]; ok { continue } - if clientErr, serverErr = sync(tracking, client, c, server, s, name); clientErr != nil || serverErr != nil { + if clientErr, serverErr = sync(ctx, tracking, client, c, server, s, name); clientErr != nil || serverErr != nil { return } } diff --git a/lib/fed/middlewares/unterminate/unterminate.go b/lib/fed/middlewares/unterminate/unterminate.go index 073dbe1084707bc202ef70324d1cbef4a37c7c78..6218b110ab02d6e90e276f60ef17f76c28175f61 100644 --- a/lib/fed/middlewares/unterminate/unterminate.go +++ b/lib/fed/middlewares/unterminate/unterminate.go @@ -1,6 +1,7 @@ package unterminate import ( + "context" "io" "gfx.cafe/gfx/pggat/lib/fed" @@ -13,22 +14,22 @@ var Unterminate = unterm{} type unterm struct{} -func (unterm) PreRead(_ bool) (fed.Packet, error) { +func (unterm) PreRead(_ context.Context, _ bool) (fed.Packet, error) { return nil, nil } -func (unterm) ReadPacket(packet fed.Packet) (fed.Packet, error) { +func (unterm) ReadPacket(_ context.Context, packet fed.Packet) (fed.Packet, error) { if packet.Type() == packets.TypeTerminate { return packet, io.EOF } return packet, nil } -func (unterm) WritePacket(packet fed.Packet) (fed.Packet, error) { +func (unterm) WritePacket(_ context.Context, packet fed.Packet) (fed.Packet, error) { return packet, nil } -func (unterm) PostWrite() (fed.Packet, error) { +func (unterm) PostWrite(_ context.Context) (fed.Packet, error) { return nil, nil } diff --git a/lib/fed/packetCodec.go b/lib/fed/packetCodec.go new file mode 100644 index 0000000000000000000000000000000000000000..03c0f8d7a19e30a49602db8ec32a600808ec831e --- /dev/null +++ b/lib/fed/packetCodec.go @@ -0,0 +1,21 @@ +package fed + +import ( + "context" + "crypto/tls" + "net" +) + +type PacketCodec interface { + ReadPacket(ctx context.Context,typed bool) (Packet, error) + WritePacket(ctx context.Context,packet Packet) error + WriteByte(ctx context.Context,b byte) error + ReadByte(ctx context.Context,) (byte, error) + + LocalAddr() net.Addr + Flush(ctx context.Context,) error + Close(ctx context.Context,) error + + SSL() bool + EnableSSL(ctx context.Context,config *tls.Config, isClient bool) error +} diff --git a/lib/gat/app.go b/lib/gat/app.go index b91269605ab2847eb3b63839c20e39a454e92479..2c6abbad73aca55d9bbdba22c24a6555cae22092 100644 --- a/lib/gat/app.go +++ b/lib/gat/app.go @@ -1,6 +1,7 @@ package gat import ( + "context" "time" "github.com/caddyserver/caddy/v2" @@ -80,7 +81,7 @@ func (T *App) Start() error { } for _, server := range T.servers { - if err := server.Start(); err != nil { + if err := server.Start(context.Background()); err != nil { return err } } @@ -92,7 +93,7 @@ func (T *App) Stop() error { close(T.closed) for _, server := range T.servers { - if err := server.Stop(); err != nil { + if err := server.Stop(context.Background()); err != nil { return err } } diff --git a/lib/gat/handler.go b/lib/gat/handler.go index 41d1ed1f3d66ba3a6df1120b4b4ee0602a345e95..c29b9a56886e402bb30e4b131bef2ebe5e0fb10f 100644 --- a/lib/gat/handler.go +++ b/lib/gat/handler.go @@ -1,6 +1,7 @@ package gat import ( + "context" "gfx.cafe/gfx/pggat/lib/fed" "gfx.cafe/gfx/pggat/lib/gat/metrics" ) @@ -30,7 +31,7 @@ func (R RouterFunc) Route(conn *fed.Conn) error { type CancellableHandler interface { Handler - Cancel(key fed.BackendKey) + Cancel(ctx context.Context, key fed.BackendKey) } type MetricsHandler interface { diff --git a/lib/gat/handlers/discovery/discoverers/google_cloud_sql/discoverer.go b/lib/gat/handlers/discovery/discoverers/google_cloud_sql/discoverer.go index fb3a1304065b51c940c47db9d62cff31eefb1304..eaf22ac7464bcf70fb42ef616033b427d3b1f090 100644 --- a/lib/gat/handlers/discovery/discoverers/google_cloud_sql/discoverer.go +++ b/lib/gat/handlers/discovery/discoverers/google_cloud_sql/discoverer.go @@ -1,6 +1,7 @@ package google_cloud_sql import ( + "context" "crypto/tls" "fmt" "net" @@ -101,7 +102,7 @@ func (T *Discoverer) instanceToCluster(primary *sqladmin.DatabaseInstance, repli var admin *fed.Conn defer func() { if admin != nil { - _ = admin.Close() + _ = admin.Close(context.Background()) } }() @@ -137,24 +138,26 @@ func (T *Discoverer) instanceToCluster(primary *sqladmin.DatabaseInstance, repli inward, outward, _, _ := gsql.NewPair() + ctx := context.Background() + var b flip.Bank b.Queue(func() error { - return gsql.ExtendedQuery(inward, &result, "SELECT usename, passwd FROM pg_shadow WHERE usename=$1", user.Name) + return gsql.ExtendedQuery(ctx, inward, &result, "SELECT usename, passwd FROM pg_shadow WHERE usename=$1", user.Name) }) b.Queue(func() error { - initialPacket, err := outward.ReadPacket(true) + initialPacket, err := outward.ReadPacket(ctx, true) if err != nil { return err } - err, err2 := bouncers.Bounce(outward, admin, initialPacket) + err, err2 := bouncers.Bounce(ctx, outward, admin, initialPacket) if err != nil { return err } if err2 != nil { return err2 } - return outward.Close() + return outward.Close(ctx) }) if err = b.Wait(); err != nil { diff --git a/lib/gat/handlers/discovery/module.go b/lib/gat/handlers/discovery/module.go index 193f0c603a7a21e2a48592e580f301571fcad0fb..86edfabd7928b8b689aea1a7946d3df07a42f08f 100644 --- a/lib/gat/handlers/discovery/module.go +++ b/lib/gat/handlers/discovery/module.go @@ -1,6 +1,7 @@ package discovery import ( + "context" "crypto/tls" "fmt" "sync" @@ -97,10 +98,10 @@ func (T *Module) Provision(ctx caddy.Context) error { } T.closed = make(chan struct{}) - if err := T.reconcile(); err != nil { + if err := T.reconcile(ctx); err != nil { return err } - go T.discoverLoop() + go T.discoverLoop(ctx) return nil } @@ -115,16 +116,16 @@ func (T *Module) Cleanup() error { T.poolsMu.Lock() defer T.poolsMu.Unlock() T.pools.Range(func(user string, database string, p poolAndCredentials) bool { - p.pool.Close() + p.pool.Close(context.Background()) T.pools.Delete(user, database) return true }) return nil } -func (T *Module) added(cluster Cluster) { +func (T *Module) added(ctx context.Context, cluster Cluster) { if prev, ok := T.clusters[cluster.ID]; ok { - T.updated(prev, cluster) + T.updated(ctx, prev, cluster) return } if T.clusters == nil { @@ -133,33 +134,33 @@ func (T *Module) added(cluster Cluster) { T.clusters[cluster.ID] = cluster for _, user := range cluster.Users { - T.addUser(cluster.Primary, cluster.Replicas, cluster.Databases, user) + T.addUser(ctx, cluster.Primary, cluster.Replicas, cluster.Databases, user) } } -func (T *Module) updated(prev, next Cluster) { +func (T *Module) updated(ctx context.Context, prev, next Cluster) { T.clusters[next.ID] = next // primary endpoints if prev.Primary != next.Primary { - T.replacePrimary(prev.Users, prev.Databases, next.Primary) + T.replacePrimary(ctx, prev.Users, prev.Databases, next.Primary) } // replica endpoints if len(prev.Replicas) != 0 && len(next.Replicas) == 0 { - T.removeReplicas(prev.Replicas, prev.Users, prev.Databases) + T.removeReplicas(ctx, prev.Replicas, prev.Users, prev.Databases) } else if len(prev.Replicas) == 0 && len(next.Replicas) != 0 { - T.addReplicas(next.Replicas, prev.Users, prev.Databases) + T.addReplicas(ctx, next.Replicas, prev.Users, prev.Databases) } else { // change # of replicas for id, nextReplica := range next.Replicas { prevReplica, ok := prev.Replicas[id] if !ok { - T.addReplica(prev.Users, prev.Databases, id, nextReplica) + T.addReplica(ctx, prev.Users, prev.Databases, id, nextReplica) } else if prevReplica != nextReplica { // don't need to remove, add will replace the recipe atomically - T.addReplica(prev.Users, prev.Databases, id, nextReplica) + T.addReplica(ctx, prev.Users, prev.Databases, id, nextReplica) } } for id := range prev.Replicas { @@ -168,7 +169,7 @@ func (T *Module) updated(prev, next Cluster) { continue // already handled } - T.removeReplica(prev.Users, prev.Databases, id) + T.removeReplica(ctx, prev.Users, prev.Databases, id) } } @@ -186,10 +187,10 @@ func (T *Module) updated(prev, next Cluster) { } if !ok { - T.addUser(next.Primary, next.Replicas, prev.Databases, nextUser) + T.addUser(ctx, next.Primary, next.Replicas, prev.Databases, nextUser) } else if nextUser.Password != prevUser.Password { - T.removeUser(next.Replicas, prev.Databases, nextUser.Username) - T.addUser(next.Primary, next.Replicas, prev.Databases, nextUser) + T.removeUser(ctx, next.Replicas, prev.Databases, nextUser.Username) + T.addUser(ctx, next.Primary, next.Replicas, prev.Databases, nextUser) } } outer: @@ -200,23 +201,23 @@ outer: } } - T.removeUser(next.Replicas, prev.Databases, prevUser.Username) + T.removeUser(ctx, next.Replicas, prev.Databases, prevUser.Username) } for _, nextDatabase := range next.Databases { if !slices.Contains(prev.Databases, nextDatabase) { - T.addDatabase(next.Primary, next.Replicas, next.Users, nextDatabase) + T.addDatabase(ctx, next.Primary, next.Replicas, next.Users, nextDatabase) } } for _, prevDatabase := range prev.Databases { if !slices.Contains(next.Databases, prevDatabase) { - T.removeDatabase(next.Replicas, next.Users, prevDatabase) + T.removeDatabase(ctx, next.Replicas, next.Users, prevDatabase) } } } -func (T *Module) addPrimaryNode(user User, database string, primary Node) { - p := T.getOrAddPool(user, database) +func (T *Module) addPrimaryNode(ctx context.Context, user User, database string, primary Node) { + p := T.getOrAddPool(ctx, user, database) d := pool.Recipe{ Dialer: pool.Dialer{ @@ -232,15 +233,15 @@ func (T *Module) addPrimaryNode(user User, database string, primary Node) { MinConnections: T.ServerMinConnections, MaxConnections: T.ServerMaxConnections, } - p.pool.AddRecipe("primary", &d) + p.pool.AddRecipe(ctx, "primary", &d) } func (T *Module) removePrimaryNode(username, database string) { T.removePool(username, database) } -func (T *Module) addReplicaNodes(user User, database string, replicas map[string]Node) { - p := T.getOrAddPool(user, database) +func (T *Module) addReplicaNodes(ctx context.Context, user User, database string, replicas map[string]Node) { + p := T.getOrAddPool(ctx, user, database) if rp, ok := p.pool.(pool.ReplicaPool); ok { for id, replica := range replicas { @@ -258,12 +259,12 @@ func (T *Module) addReplicaNodes(user User, database string, replicas map[string MinConnections: T.ServerMinConnections, MaxConnections: T.ServerMaxConnections, } - rp.AddReplicaRecipe(id, &d) + rp.AddReplicaRecipe(ctx, id, &d) } return } - rp := T.getOrAddReplicaPool(user, database) + rp := T.getOrAddReplicaPool(ctx, user, database) for id, replica := range replicas { d := pool.Recipe{ Dialer: pool.Dialer{ @@ -279,11 +280,11 @@ func (T *Module) addReplicaNodes(user User, database string, replicas map[string MinConnections: T.ServerMinConnections, MaxConnections: T.ServerMaxConnections, } - rp.pool.AddRecipe(id, &d) + rp.pool.AddRecipe(ctx, id, &d) } } -func (T *Module) removeReplicaNodes(username string, database string, replicas map[string]Node) { +func (T *Module) removeReplicaNodes(ctx context.Context, username string, database string, replicas map[string]Node) { p, ok := T.getPool(username, database) if !ok { return @@ -292,7 +293,7 @@ func (T *Module) removeReplicaNodes(username string, database string, replicas m // remove endpoints from replica pool if rp, ok := p.pool.(pool.ReplicaPool); ok { for key := range replicas { - rp.RemoveReplicaRecipe(key) + rp.RemoveReplicaRecipe(ctx, key) } return } @@ -301,8 +302,8 @@ func (T *Module) removeReplicaNodes(username string, database string, replicas m T.removeReplicaPool(username, database) } -func (T *Module) addReplicaNode(user User, database string, id string, replica Node) { - p := T.getOrAddPool(user, database) +func (T *Module) addReplicaNode(ctx context.Context, user User, database string, id string, replica Node) { + p := T.getOrAddPool(ctx, user, database) d := pool.Recipe{ Dialer: pool.Dialer{ @@ -320,15 +321,15 @@ func (T *Module) addReplicaNode(user User, database string, id string, replica N } if rp, ok := p.pool.(pool.ReplicaPool); ok { - rp.AddReplicaRecipe(id, &d) + rp.AddReplicaRecipe(ctx, id, &d) return } - rp := T.getOrAddReplicaPool(user, database) - rp.pool.AddRecipe(id, &d) + rp := T.getOrAddReplicaPool(ctx, user, database) + rp.pool.AddRecipe(ctx, id, &d) } -func (T *Module) removeReplicaNode(username string, database string, id string) { +func (T *Module) removeReplicaNode(ctx context.Context, username string, database string, id string) { p, ok := T.getPool(username, database) if !ok { return @@ -336,7 +337,7 @@ func (T *Module) removeReplicaNode(username string, database string, id string) // remove endpoints from replica pool if rp, ok := p.pool.(pool.ReplicaPool); ok { - rp.RemoveReplicaRecipe(id) + rp.RemoveReplicaRecipe(ctx, id) return } @@ -345,87 +346,87 @@ func (T *Module) removeReplicaNode(username string, database string, id string) if !ok { return } - rp.pool.RemoveRecipe(id) + rp.pool.RemoveRecipe(ctx, id) } // replacePrimary replaces the primary endpoint. -func (T *Module) replacePrimary(users []User, databases []string, primary Node) { +func (T *Module) replacePrimary(ctx context.Context, users []User, databases []string, primary Node) { for _, user := range users { for _, database := range databases { - T.addPrimaryNode(user, database, primary) + T.addPrimaryNode(ctx, user, database, primary) } } } // addReplicas adds multiple replicas. Other replicas must not exist. -func (T *Module) addReplicas(replicas map[string]Node, users []User, databases []string) { +func (T *Module) addReplicas(ctx context.Context, replicas map[string]Node, users []User, databases []string) { for _, user := range users { for _, database := range databases { - T.addReplicaNodes(user, database, replicas) + T.addReplicaNodes(ctx, user, database, replicas) } } } // removeReplicas removes all replicas. -func (T *Module) removeReplicas(replicas map[string]Node, users []User, databases []string) { +func (T *Module) removeReplicas(ctx context.Context, replicas map[string]Node, users []User, databases []string) { for _, user := range users { for _, database := range databases { - T.removeReplicaNodes(user.Username, database, replicas) + T.removeReplicaNodes(ctx, user.Username, database, replicas) } } } // addReplica adds a single replica. -func (T *Module) addReplica(users []User, databases []string, id string, replica Node) { +func (T *Module) addReplica(ctx context.Context, users []User, databases []string, id string, replica Node) { for _, user := range users { for _, database := range databases { - T.addReplicaNode(user, database, id, replica) + T.addReplicaNode(ctx, user, database, id, replica) } } } // removeReplica removes a single replica. -func (T *Module) removeReplica(users []User, databases []string, id string) { +func (T *Module) removeReplica(ctx context.Context, users []User, databases []string, id string) { for _, user := range users { for _, database := range databases { - T.removeReplicaNode(user.Username, database, id) + T.removeReplicaNode(ctx, user.Username, database, id) } } } // addUser adds a new user. -func (T *Module) addUser(primary Node, replicas map[string]Node, databases []string, user User) { +func (T *Module) addUser(ctx context.Context, primary Node, replicas map[string]Node, databases []string, user User) { for _, database := range databases { - T.addPrimaryNode(user, database, primary) - T.addReplicaNodes(user, database, replicas) + T.addPrimaryNode(ctx, user, database, primary) + T.addReplicaNodes(ctx, user, database, replicas) } } // removeUser removes a user. -func (T *Module) removeUser(replicas map[string]Node, databases []string, username string) { +func (T *Module) removeUser(ctx context.Context, replicas map[string]Node, databases []string, username string) { for _, database := range databases { - T.removeReplicaNodes(username, database, replicas) + T.removeReplicaNodes(ctx, username, database, replicas) T.removePrimaryNode(username, database) } } // addDatabase adds a new database. -func (T *Module) addDatabase(primary Node, replicas map[string]Node, users []User, database string) { +func (T *Module) addDatabase(ctx context.Context, primary Node, replicas map[string]Node, users []User, database string) { for _, user := range users { - T.addPrimaryNode(user, database, primary) - T.addReplicaNodes(user, database, replicas) + T.addPrimaryNode(ctx, user, database, primary) + T.addReplicaNodes(ctx, user, database, replicas) } } // removeDatabase removes a single database. -func (T *Module) removeDatabase(replicas map[string]Node, users []User, database string) { +func (T *Module) removeDatabase(ctx context.Context, replicas map[string]Node, users []User, database string) { for _, user := range users { - T.removeReplicaNodes(user.Username, database, replicas) + T.removeReplicaNodes(ctx, user.Username, database, replicas) T.removePrimaryNode(user.Username, database) } } -func (T *Module) removed(id string) { +func (T *Module) removed(ctx context.Context, id string) { cluster, ok := T.clusters[id] if !ok { return @@ -433,11 +434,11 @@ func (T *Module) removed(id string) { delete(T.clusters, id) for _, database := range cluster.Databases { - T.removeDatabase(cluster.Replicas, cluster.Users, database) + T.removeDatabase(ctx, cluster.Replicas, cluster.Users, database) } } -func (T *Module) reconcile() error { +func (T *Module) reconcile(ctx context.Context) error { clusters, err := T.discoverer.Clusters() if err != nil { return err @@ -446,9 +447,9 @@ func (T *Module) reconcile() error { for _, cluster := range clusters { prev, ok := T.clusters[cluster.ID] if !ok { - T.added(cluster) + T.added(ctx, cluster) } else { - T.updated(prev, cluster) + T.updated(ctx, prev, cluster) } } @@ -460,13 +461,13 @@ outer: continue outer } } - T.removed(id) + T.removed(ctx, id) } return nil } -func (T *Module) discoverLoop() { +func (T *Module) discoverLoop(ctx context.Context) { var reconcile <-chan time.Time if T.ReconcilePeriod != 0 { r := time.NewTicker(time.Duration(T.ReconcilePeriod)) @@ -477,11 +478,11 @@ func (T *Module) discoverLoop() { for { select { case cluster := <-T.discoverer.Added(): - T.added(cluster) + T.added(ctx, cluster) case id := <-T.discoverer.Removed(): - T.removed(id) + T.removed(ctx, id) case <-reconcile: - err := T.reconcile() + err := T.reconcile(ctx) if err != nil { T.log.Warn("failed to reconcile", zap.Error(err)) } @@ -512,7 +513,7 @@ func (T *Module) getCreds(user User) auth.Credentials { return creds } -func (T *Module) getOrAddPool(user User, database string) poolAndCredentials { +func (T *Module) getOrAddPool(ctx context.Context, user User, database string) poolAndCredentials { T.poolsMu.Lock() defer T.poolsMu.Unlock() if old, ok := T.pools.Load(user.Username, database); ok { @@ -521,7 +522,7 @@ func (T *Module) getOrAddPool(user User, database string) poolAndCredentials { creds := T.getCreds(user) p := poolAndCredentials{ - pool: T.poolFactory.NewPool(), + pool: T.poolFactory.NewPool(ctx), creds: creds, } T.pools.Store(user.Username, database, p) @@ -529,8 +530,8 @@ func (T *Module) getOrAddPool(user User, database string) poolAndCredentials { return p } -func (T *Module) getOrAddReplicaPool(user User, database string) poolAndCredentials { - return T.getOrAddPool(T.toReplicaUser(user), database) +func (T *Module) getOrAddReplicaPool(ctx context.Context, user User, database string) poolAndCredentials { + return T.getOrAddPool(ctx, T.toReplicaUser(user), database) } func (T *Module) getPool(user, database string) (poolAndCredentials, bool) { @@ -550,7 +551,7 @@ func (T *Module) removePool(user, database string) { if !ok { return } - p.pool.Close() + p.pool.Close(context.Background()) T.log.Info("removed pool", zap.String("user", user), zap.String("database", database)) T.pools.Delete(user, database) } @@ -570,22 +571,24 @@ func (T *Module) ReadMetrics(metrics *metrics.Handler) { func (T *Module) Handle(next gat.Router) gat.Router { return gat.RouterFunc(func(conn *fed.Conn) error { + ctx := context.Background() + p, ok := T.getPool(conn.User, conn.Database) if !ok { return next.Route(conn) } - if err := frontends.Authenticate(conn, p.creds); err != nil { + if err := frontends.Authenticate(ctx, conn, p.creds); err != nil { return err } - return p.pool.Serve(conn) + return p.pool.Serve(ctx, conn) }) } -func (T *Module) Cancel(key fed.BackendKey) { +func (T *Module) Cancel(ctx context.Context, key fed.BackendKey) { T.poolsMu.RLock() defer T.poolsMu.RUnlock() T.pools.Range(func(_ string, _ string, p poolAndCredentials) bool { - p.pool.Cancel(key) + p.pool.Cancel(ctx, key) return true }) } diff --git a/lib/gat/handlers/pgbouncer/module.go b/lib/gat/handlers/pgbouncer/module.go index 49b54d9a130e11730020f48f115263210fdcdcbc..42b64cdf3b15f7bb24b4b84a6c0054badbdf7e43 100644 --- a/lib/gat/handlers/pgbouncer/module.go +++ b/lib/gat/handlers/pgbouncer/module.go @@ -1,6 +1,7 @@ package pgbouncer import ( + "context" "crypto/tls" "errors" "fmt" @@ -80,7 +81,7 @@ func (T *Module) Cleanup() error { defer T.mu.Unlock() T.pools.Range(func(user string, database string, p poolAndCredentials) bool { - p.pool.Close() + p.pool.Close(context.Background()) T.pools.Delete(user, database) return true }) @@ -88,7 +89,7 @@ func (T *Module) Cleanup() error { return nil } -func (T *Module) getPassword(user, database string) (string, bool) { +func (T *Module) getPassword(ctx context.Context, user, database string) (string, bool) { // try to get password password, ok := T.Config.PgBouncer.AuthFile[user] if !ok { @@ -105,7 +106,7 @@ func (T *Module) getPassword(user, database string) (string, bool) { } } - authPool, ok := T.lookup(authUser, database) + authPool, ok := T.lookup(ctx, authUser, database) if !ok { return "", false } @@ -116,14 +117,14 @@ func (T *Module) getPassword(user, database string) (string, bool) { inward, outward, _, _ := gsql.NewPair() b.Queue(func() error { - if err := gsql.ExtendedQuery(inward, &result, T.Config.PgBouncer.AuthQuery, user); err != nil { + if err := gsql.ExtendedQuery(ctx, inward, &result, T.Config.PgBouncer.AuthQuery, user); err != nil { return err } - return inward.Close() + return inward.Close(ctx) }) b.Queue(func() error { - err := authPool.pool.Serve(outward) + err := authPool.pool.Serve(ctx, outward) if err != nil && !errors.Is(err, io.EOF) { return err } @@ -148,7 +149,7 @@ func (T *Module) getPassword(user, database string) (string, bool) { return password, true } -func (T *Module) tryCreate(user, database string) (poolAndCredentials, bool) { +func (T *Module) tryCreate(ctx context.Context, user, database string) (poolAndCredentials, bool) { db, ok := T.Config.Databases[database] if !ok { // try wildcard @@ -159,7 +160,7 @@ func (T *Module) tryCreate(user, database string) (poolAndCredentials, bool) { } // try to get password - password, ok := T.getPassword(user, database) + password, ok := T.getPassword(ctx, user, database) if !ok { return poolAndCredentials{}, false } @@ -275,23 +276,25 @@ func (T *Module) tryCreate(user, database string) (poolAndCredentials, bool) { r.MaxConnections = T.Config.PgBouncer.MaxDBConnections } - p.pool.AddRecipe("pgbouncer", &r) + p.pool.AddRecipe(ctx, "pgbouncer", &r) return p, true } -func (T *Module) lookup(user, database string) (poolAndCredentials, bool) { +func (T *Module) lookup(ctx context.Context, user, database string) (poolAndCredentials, bool) { p, ok := T.pools.Load(user, database) if ok { return p, true } // try to create pool - return T.tryCreate(user, database) + return T.tryCreate(ctx, user, database) } func (T *Module) Handle(next gat.Router) gat.Router { return gat.RouterFunc(func(conn *fed.Conn) error { + ctx := context.Background() + // check ssl if T.Config.PgBouncer.ClientTLSSSLMode.IsRequired() { if !conn.SSL { @@ -327,16 +330,16 @@ func (T *Module) Handle(next gat.Router) gat.Router { ) } - p, ok := T.lookup(conn.User, conn.Database) + p, ok := T.lookup(ctx, conn.User, conn.Database) if !ok { return next.Route(conn) } - if err := frontends.Authenticate(conn, p.creds); err != nil { + if err := frontends.Authenticate(ctx, conn, p.creds); err != nil { return err } - return p.pool.Serve(conn) + return p.pool.Serve(ctx, conn) }) } @@ -349,11 +352,11 @@ func (T *Module) ReadMetrics(metrics *metrics.Handler) { }) } -func (T *Module) Cancel(key fed.BackendKey) { +func (T *Module) Cancel(ctx context.Context, key fed.BackendKey) { T.mu.RLock() defer T.mu.RUnlock() T.pools.Range(func(_ string, _ string, p poolAndCredentials) bool { - p.pool.Cancel(key) + p.pool.Cancel(ctx, key) return true }) } diff --git a/lib/gat/handlers/pool/critics/latency/critic.go b/lib/gat/handlers/pool/critics/latency/critic.go index abeda6c5a34af794db3d7fc8ee6b9db5cc77fce8..671906270998daee087cf8949eaf3d355897f424 100644 --- a/lib/gat/handlers/pool/critics/latency/critic.go +++ b/lib/gat/handlers/pool/critics/latency/critic.go @@ -1,6 +1,7 @@ package latency import ( + "context" "time" "github.com/caddyserver/caddy/v2" @@ -28,9 +29,9 @@ func (T *Critic) CaddyModule() caddy.ModuleInfo { } } -func (T *Critic) Taste(conn *fed.Conn) (int, time.Duration, error) { +func (T *Critic) Taste(ctx context.Context, conn *fed.Conn) (int, time.Duration, error) { start := time.Now() - err, _ := backends.QueryString(conn, nil, "select 0") + err, _ := backends.QueryString(ctx, conn, nil, "select 0") if err != nil { return 0, time.Duration(T.Validity), err } diff --git a/lib/gat/handlers/pool/dialer.go b/lib/gat/handlers/pool/dialer.go index 1da5d9fea194640190cd085290e8c104ce083623..107154e217fc61bee78ed6b5635a23cd4c801c43 100644 --- a/lib/gat/handlers/pool/dialer.go +++ b/lib/gat/handlers/pool/dialer.go @@ -1,6 +1,7 @@ package pool import ( + "context" "crypto/tls" "encoding/json" "fmt" @@ -70,6 +71,7 @@ func (T *Dialer) Dial() (*fed.Conn, error) { conn.User = T.Username conn.Database = T.Database err = backends.Accept( + context.Background(), conn, T.SSLMode, T.SSLConfig, @@ -85,21 +87,21 @@ func (T *Dialer) Dial() (*fed.Conn, error) { return conn, nil } -func (T *Dialer) Cancel(key fed.BackendKey) { +func (T *Dialer) Cancel(ctx context.Context, key fed.BackendKey) { c, err := T.dial() if err != nil { return } conn := fed.NewConn(netconncodec.NewCodec(c)) defer func() { - _ = conn.Close() + _ = conn.Close(ctx) }() - if err = backends.Cancel(conn, key); err != nil { + if err = backends.Cancel(ctx, conn, key); err != nil { return } // wait for server to close the connection, this means that the server received it ok - _, _ = conn.ReadPacket(true) + _, _ = conn.ReadPacket(ctx, true) } var _ caddy.Provisioner = (*gat.Listener)(nil) diff --git a/lib/gat/handlers/pool/module.go b/lib/gat/handlers/pool/module.go index 6f44c1a6e704eb76a7121f01aa64e6b1e5ed75f6..01f8d24cdad6d3a3007f810836c0f12008e4fb88 100644 --- a/lib/gat/handlers/pool/module.go +++ b/lib/gat/handlers/pool/module.go @@ -1,6 +1,7 @@ package pool import ( + "context" "encoding/json" "github.com/caddyserver/caddy/v2" @@ -39,23 +40,24 @@ func (T *Module) Provision(ctx caddy.Context) error { if err != nil { return err } - T.pool = raw.(PoolFactory).NewPool() + T.pool = raw.(PoolFactory).NewPool(ctx) if err = T.Recipe.Provision(ctx); err != nil { return err } - T.pool.AddRecipe("recipe", &T.Recipe) + T.pool.AddRecipe(ctx, "recipe", &T.Recipe) return nil } func (T *Module) Handle(next gat.Router) gat.Router { return gat.RouterFunc(func(c *fed.Conn) error { - if err := frontends.Authenticate(c, nil); err != nil { + ctx := context.Background() + if err := frontends.Authenticate(ctx, c, nil); err != nil { return err } - return T.pool.Serve(c) + return T.pool.Serve(ctx, c) }) } @@ -63,8 +65,8 @@ func (T *Module) ReadMetrics(metrics *metrics.Handler) { T.pool.ReadMetrics(&metrics.Pool) } -func (T *Module) Cancel(key fed.BackendKey) { - T.pool.Cancel(key) +func (T *Module) Cancel(ctx context.Context, key fed.BackendKey) { + T.pool.Cancel(ctx, key) } var _ gat.Handler = (*Module)(nil) diff --git a/lib/gat/handlers/pool/penalty.go b/lib/gat/handlers/pool/penalty.go index 7b7f92438788c9fb602eeb24c8ef75a0e290d8da..9d7b66adb51612bf5758b3f1fb6e25e0137d618e 100644 --- a/lib/gat/handlers/pool/penalty.go +++ b/lib/gat/handlers/pool/penalty.go @@ -1,6 +1,7 @@ package pool import ( + "context" "time" "gfx.cafe/gfx/pggat/lib/fed" @@ -8,5 +9,5 @@ import ( type Critic interface { // Taste calculates how much conn should be penalized. Lower is better - Taste(conn *fed.Conn) (score int, validity time.Duration, err error) + Taste(ctx context.Context, conn *fed.Conn) (score int, validity time.Duration, err error) } diff --git a/lib/gat/handlers/pool/pool.go b/lib/gat/handlers/pool/pool.go index e6792eba5aeeea9a939fe07c67d054e403926034..e81efb17edc849a1a5fbc1b98bf10f22f98a78b3 100644 --- a/lib/gat/handlers/pool/pool.go +++ b/lib/gat/handlers/pool/pool.go @@ -1,6 +1,7 @@ package pool import ( + "context" "gfx.cafe/gfx/pggat/lib/fed" "gfx.cafe/gfx/pggat/lib/gat/metrics" ) @@ -8,24 +9,24 @@ import ( type Pool interface { // AddRecipe will add the recipe to the pool for use. The pool should delete any existing recipes with the same name // and scale the recipe to min. - AddRecipe(name string, recipe *Recipe) + AddRecipe(ctx context.Context, name string, recipe *Recipe) // RemoveRecipe will remove a recipe and disconnect all servers created by that recipe. - RemoveRecipe(name string) + RemoveRecipe(ctx context.Context, name string) - Serve(conn *fed.Conn) error + Serve(ctx context.Context, conn *fed.Conn) error - Cancel(key fed.BackendKey) + Cancel(ctx context.Context, key fed.BackendKey) ReadMetrics(m *metrics.Pool) - Close() + Close(ctx context.Context) } type ReplicaPool interface { Pool - AddReplicaRecipe(name string, recipe *Recipe) - RemoveReplicaRecipe(name string) + AddReplicaRecipe(ctx context.Context, name string, recipe *Recipe) + RemoveReplicaRecipe(ctx context.Context, name string) } type PoolFactory interface { - NewPool() Pool + NewPool(ctx context.Context) Pool } diff --git a/lib/gat/handlers/pool/pools/basic/factory.go b/lib/gat/handlers/pool/pools/basic/factory.go index 44d58c95954b33d12d7a4c53e419eb06e884aa1a..eb90578d0ebe982e809f6ae8dea0fc18ece66195 100644 --- a/lib/gat/handlers/pool/pools/basic/factory.go +++ b/lib/gat/handlers/pool/pools/basic/factory.go @@ -1,6 +1,7 @@ package basic import ( + "context" "fmt" "github.com/caddyserver/caddy/v2" @@ -50,8 +51,8 @@ func (T *Factory) Provision(ctx caddy.Context) error { return nil } -func (T *Factory) NewPool() pool.Pool { - return NewPool(T.Config) +func (T *Factory) NewPool(ctx context.Context) pool.Pool { + return NewPool(ctx, T.Config) } var _ pool.PoolFactory = (*Factory)(nil) diff --git a/lib/gat/handlers/pool/pools/basic/pool.go b/lib/gat/handlers/pool/pools/basic/pool.go index b69d7ce30caaf45666dd6c4771bd1c464de11687..6a72da312474425380ed11ca00739939f6c31cd5 100644 --- a/lib/gat/handlers/pool/pools/basic/pool.go +++ b/lib/gat/handlers/pool/pools/basic/pool.go @@ -1,6 +1,7 @@ package basic import ( + "context" "fmt" "sync" "time" @@ -29,24 +30,24 @@ type Pool struct { mu sync.RWMutex } -func NewPool(config Config) *Pool { +func NewPool(ctx context.Context, config Config) *Pool { p := &Pool{ config: config, servers: spool.MakePool(config.Spool()), } - go p.servers.ScaleLoop() + go p.servers.ScaleLoop(ctx) return p } -func (T *Pool) AddRecipe(name string, recipe *pool.Recipe) { - T.servers.AddRecipe(name, recipe) +func (T *Pool) AddRecipe(ctx context.Context, name string, recipe *pool.Recipe) { + T.servers.AddRecipe(ctx, name, recipe) } -func (T *Pool) RemoveRecipe(name string) { - T.servers.RemoveRecipe(name) +func (T *Pool) RemoveRecipe(ctx context.Context, name string) { + T.servers.RemoveRecipe(ctx, name) } -func (T *Pool) SyncInitialParameters(client *Client, server *spool.Server) (err, serverErr error) { +func (T *Pool) SyncInitialParameters(ctx context.Context, client *Client, server *spool.Server) (err, serverErr error) { clientParams := client.Conn.InitialParameters serverParams := server.Conn.InitialParameters @@ -57,7 +58,7 @@ func (T *Pool) SyncInitialParameters(client *Client, server *spool.Server) (err, Key: key.String(), Value: serverParams[key], } - err = client.Conn.WritePacket(&p) + err = client.Conn.WritePacket(ctx, &p) if err != nil { return } @@ -74,7 +75,7 @@ func (T *Pool) SyncInitialParameters(client *Client, server *spool.Server) (err, Key: key.String(), Value: value, } - err = client.Conn.WritePacket(&p) + err = client.Conn.WritePacket(ctx, &p) if err != nil { return } @@ -83,7 +84,7 @@ func (T *Pool) SyncInitialParameters(client *Client, server *spool.Server) (err, continue } - serverErr, _ = backends.SetParameter(server.Conn, nil, key, value) + serverErr, _ = backends.SetParameter(ctx, server.Conn, nil, key, value) if serverErr != nil { return } @@ -101,7 +102,7 @@ func (T *Pool) SyncInitialParameters(client *Client, server *spool.Server) (err, Key: key.String(), Value: value, } - err = client.Conn.WritePacket(&p) + err = client.Conn.WritePacket(ctx, &p) if err != nil { return } @@ -110,16 +111,16 @@ func (T *Pool) SyncInitialParameters(client *Client, server *spool.Server) (err, return } -func (T *Pool) Pair(client *Client, server *spool.Server) (err, serverErr error) { +func (T *Pool) Pair(ctx context.Context, client *Client, server *spool.Server) (err, serverErr error) { if T.config.ParameterStatusSync != ParameterStatusSyncNone || T.config.ExtendedQuerySync { client.SetState(metrics.ConnStatePairing, server) server.SetState(metrics.ConnStatePairing, client.ID) switch T.config.ParameterStatusSync { case ParameterStatusSyncDynamic: - err, serverErr = ps.Sync(T.config.TrackedParameters, client.Conn, server.Conn) + err, serverErr = ps.Sync(ctx, T.config.TrackedParameters, client.Conn, server.Conn) case ParameterStatusSyncInitial: - err, serverErr = T.SyncInitialParameters(client, server) + err, serverErr = T.SyncInitialParameters(ctx, client, server) } if err != nil || serverErr != nil { @@ -127,7 +128,7 @@ func (T *Pool) Pair(client *Client, server *spool.Server) (err, serverErr error) } if T.config.ExtendedQuerySync { - serverErr = eqp.Sync(client.Conn, server.Conn) + serverErr = eqp.Sync(ctx, client.Conn, server.Conn) } if serverErr != nil { @@ -157,7 +158,7 @@ func (T *Pool) removeClient(client *Client) { delete(T.clients, client.Conn.BackendKey) } -func (T *Pool) Serve(conn *fed.Conn) error { +func (T *Pool) Serve(ctx context.Context, conn *fed.Conn) error { if T.config.ParameterStatusSync == ParameterStatusSyncDynamic { conn.Middleware = append( conn.Middleware, @@ -185,9 +186,9 @@ func (T *Pool) Serve(conn *fed.Conn) error { defer func() { if server != nil { if serverErr != nil { - T.servers.RemoveServer(server) + T.servers.RemoveServer(ctx, server) } else { - T.servers.Release(server) + T.servers.Release(ctx, server) } server = nil } @@ -201,7 +202,7 @@ func (T *Pool) Serve(conn *fed.Conn) error { return pool.ErrFailedToAcquirePeer } - err, serverErr = T.Pair(client, server) + err, serverErr = T.Pair(ctx, client, server) if serverErr != nil { return serverErr } @@ -210,7 +211,7 @@ func (T *Pool) Serve(conn *fed.Conn) error { } p := packets.ReadyForQuery('I') - err = client.Conn.WritePacket(&p) + err = client.Conn.WritePacket(ctx, &p) if err != nil { return err } @@ -236,12 +237,12 @@ func (T *Pool) Serve(conn *fed.Conn) error { for { if server != nil && T.config.ReleaseAfterTransaction { client.SetState(metrics.ConnStateIdle, nil) - T.servers.Release(server) + T.servers.Release(ctx, server) server = nil } var packet fed.Packet - packet, err = client.Conn.ReadPacket(true) + packet, err = client.Conn.ReadPacket(ctx, true) if err != nil { return err } @@ -255,7 +256,7 @@ func (T *Pool) Serve(conn *fed.Conn) error { return pool.ErrFailedToAcquirePeer } - err, serverErr = T.Pair(client, server) + err, serverErr = T.Pair(ctx, client, server) dur := time.Since(start) if err == nil && serverErr == nil { prom.OperationSimple.Acquire(opLabels).Observe(float64(dur) / float64(time.Millisecond)) @@ -264,7 +265,7 @@ func (T *Pool) Serve(conn *fed.Conn) error { if err == nil && serverErr == nil { { start := time.Now() - err, serverErr = bouncers.Bounce(client.Conn, server.Conn, packet) + err, serverErr = bouncers.Bounce(ctx, client.Conn, server.Conn, packet) if serverErr == nil { dur := time.Since(start) prom.OperationSimple.Execution(opLabels).Observe(float64(dur) / float64(time.Millisecond)) @@ -285,7 +286,7 @@ func (T *Pool) Serve(conn *fed.Conn) error { } } -func (T *Pool) Cancel(key fed.BackendKey) { +func (T *Pool) Cancel(ctx context.Context, key fed.BackendKey) { peer := func() *spool.Server { T.mu.RLock() defer T.mu.RUnlock() @@ -303,7 +304,7 @@ func (T *Pool) Cancel(key fed.BackendKey) { return } - T.servers.Cancel(peer) + T.servers.Cancel(ctx, peer) } func (T *Pool) ReadMetrics(m *metrics.Pool) { @@ -322,8 +323,8 @@ func (T *Pool) ReadMetrics(m *metrics.Pool) { } } -func (T *Pool) Close() { - T.servers.Close() +func (T *Pool) Close(ctx context.Context) { + T.servers.Close(ctx) } var _ pool.Pool = (*Pool)(nil) diff --git a/lib/gat/handlers/pool/pools/hybrid/factory.go b/lib/gat/handlers/pool/pools/hybrid/factory.go index d189be6cc6a1e410ab0928e03e0359e927082f25..a15c6bbce381d12908d67bc9f1e26f62060ca56a 100644 --- a/lib/gat/handlers/pool/pools/hybrid/factory.go +++ b/lib/gat/handlers/pool/pools/hybrid/factory.go @@ -1,6 +1,7 @@ package hybrid import ( + "context" "fmt" "github.com/caddyserver/caddy/v2" @@ -44,8 +45,8 @@ func (T *Factory) Provision(ctx caddy.Context) error { return nil } -func (T *Factory) NewPool() pool.Pool { - return NewPool(T.Config) +func (T *Factory) NewPool(ctx context.Context) pool.Pool { + return NewPool(ctx, T.Config) } var _ pool.PoolFactory = (*Factory)(nil) diff --git a/lib/gat/handlers/pool/pools/hybrid/middleware.go b/lib/gat/handlers/pool/pools/hybrid/middleware.go index 6912f9d2304389ea7dd648f9febd933c1074a6b9..18e05bac63ac1f51145e9399b4ddcec2f73872ab 100644 --- a/lib/gat/handlers/pool/pools/hybrid/middleware.go +++ b/lib/gat/handlers/pool/pools/hybrid/middleware.go @@ -1,6 +1,7 @@ package hybrid import ( + "context" "gfx.cafe/gfx/pggat/lib/fed" packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" "gfx.cafe/gfx/pggat/lib/perror" @@ -20,7 +21,7 @@ func NewMiddleware() *Middleware { return m } -func (T *Middleware) PreRead(typed bool) (fed.Packet, error) { +func (T *Middleware) PreRead(ctx context.Context, typed bool) (fed.Packet, error) { if !T.primary { return nil, nil } @@ -37,7 +38,7 @@ func (T *Middleware) PreRead(typed bool) (fed.Packet, error) { }, nil } -func (T *Middleware) ReadPacket(packet fed.Packet) (fed.Packet, error) { +func (T *Middleware) ReadPacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) { if T.primary { return packet, nil } @@ -60,7 +61,7 @@ func (T *Middleware) ReadPacket(packet fed.Packet) (fed.Packet, error) { return p, nil } -func (T *Middleware) WritePacket(packet fed.Packet) (fed.Packet, error) { +func (T *Middleware) WritePacket(ctx context.Context, packet fed.Packet) (fed.Packet, error) { if T.primary && (T.buf.Buffered() > 0 || T.bufDec.Buffered() > 0) { return nil, nil } @@ -84,7 +85,7 @@ func (T *Middleware) WritePacket(packet fed.Packet) (fed.Packet, error) { return packet, nil } -func (T *Middleware) PostWrite() (fed.Packet, error) { +func (T *Middleware) PostWrite(ctx context.Context) (fed.Packet, error) { return nil, nil } diff --git a/lib/gat/handlers/pool/pools/hybrid/pool.go b/lib/gat/handlers/pool/pools/hybrid/pool.go index fb87e467c904abd89d21259db112626613f1df0f..a2990df2f3591c356351c2dc27128f9175b9a75d 100644 --- a/lib/gat/handlers/pool/pools/hybrid/pool.go +++ b/lib/gat/handlers/pool/pools/hybrid/pool.go @@ -1,6 +1,7 @@ package hybrid import ( + "context" "fmt" "sync" "time" @@ -30,7 +31,7 @@ type Pool struct { mu sync.RWMutex } -func NewPool(config Config) *Pool { +func NewPool(ctx context.Context, config Config) *Pool { c := config.Spool() p := &Pool{ @@ -39,38 +40,38 @@ func NewPool(config Config) *Pool { primary: spool.MakePool(c), replica: spool.MakePool(c), } - go p.primary.ScaleLoop() - go p.replica.ScaleLoop() + go p.primary.ScaleLoop(ctx) + go p.replica.ScaleLoop(ctx) return p } -func (T *Pool) AddReplicaRecipe(name string, recipe *pool.Recipe) { - T.replica.AddRecipe(name, recipe) +func (T *Pool) AddReplicaRecipe(ctx context.Context, name string, recipe *pool.Recipe) { + T.replica.AddRecipe(ctx, name, recipe) } -func (T *Pool) RemoveReplicaRecipe(name string) { - T.replica.RemoveRecipe(name) +func (T *Pool) RemoveReplicaRecipe(ctx context.Context, name string) { + T.replica.RemoveRecipe(ctx, name) } -func (T *Pool) AddRecipe(name string, recipe *pool.Recipe) { - T.primary.AddRecipe(name, recipe) +func (T *Pool) AddRecipe(ctx context.Context, name string, recipe *pool.Recipe) { + T.primary.AddRecipe(ctx, name, recipe) } -func (T *Pool) RemoveRecipe(name string) { - T.primary.RemoveRecipe(name) +func (T *Pool) RemoveRecipe(ctx context.Context, name string) { + T.primary.RemoveRecipe(ctx, name) } -func (T *Pool) Pair(client *Client, server *spool.Server) (err, serverErr error) { +func (T *Pool) Pair(ctx context.Context, client *Client, server *spool.Server) (err, serverErr error) { client.SetState(metrics.ConnStatePairing, server, true) server.SetState(metrics.ConnStatePairing, client.ID) - err, serverErr = ps.Sync(T.config.TrackedParameters, client.Conn, server.Conn) + err, serverErr = ps.Sync(ctx, T.config.TrackedParameters, client.Conn, server.Conn) if err != nil || serverErr != nil { return } - serverErr = eqp.Sync(client.Conn, server.Conn) + serverErr = eqp.Sync(ctx, client.Conn, server.Conn) if serverErr != nil { return @@ -81,14 +82,14 @@ func (T *Pool) Pair(client *Client, server *spool.Server) (err, serverErr error) return } -func (T *Pool) PairPrimary(client *Client, psc *ps.Client, eqpc *eqp.Client, server *spool.Server) error { +func (T *Pool) PairPrimary(ctx context.Context, client *Client, psc *ps.Client, eqpc *eqp.Client, server *spool.Server) error { server.SetState(metrics.ConnStatePairing, client.ID) - if err := ps.SyncMiddleware(T.config.TrackedParameters, psc, server.Conn); err != nil { + if err := ps.SyncMiddleware(ctx, T.config.TrackedParameters, psc, server.Conn); err != nil { return err } - if err := eqp.SyncMiddleware(eqpc, server.Conn); err != nil { + if err := eqp.SyncMiddleware(ctx, eqpc, server.Conn); err != nil { return err } @@ -114,7 +115,7 @@ func (T *Pool) removeClient(client *Client) { delete(T.clients, client.Conn.BackendKey) } -func (T *Pool) serveRW(l prom.PoolHybridLabels, conn *fed.Conn) error { +func (T *Pool) serveRW(ctx context.Context, l prom.PoolHybridLabels, conn *fed.Conn) error { m := NewMiddleware() eqpa := eqp.NewClient() @@ -146,17 +147,17 @@ func (T *Pool) serveRW(l prom.PoolHybridLabels, conn *fed.Conn) error { defer func() { if primary != nil { if serverErr != nil { - T.primary.RemoveServer(primary) + T.primary.RemoveServer(ctx, primary) } else { - T.primary.Release(primary) + T.primary.Release(ctx, primary) } primary = nil } if replica != nil { if serverErr != nil { - T.replica.RemoveServer(replica) + T.replica.RemoveServer(ctx, replica) } else { - T.replica.Release(replica) + T.replica.Release(ctx, replica) } replica = nil } @@ -171,7 +172,7 @@ func (T *Pool) serveRW(l prom.PoolHybridLabels, conn *fed.Conn) error { return pool.ErrFailedToAcquirePeer } - err, serverErr = T.Pair(client, replica) + err, serverErr = T.Pair(ctx, client, replica) if serverErr != nil { return serverErr } @@ -186,7 +187,7 @@ func (T *Pool) serveRW(l prom.PoolHybridLabels, conn *fed.Conn) error { return pool.ErrFailedToAcquirePeer } - err, serverErr = T.Pair(client, primary) + err, serverErr = T.Pair(ctx, client, primary) if serverErr != nil { return serverErr } @@ -196,7 +197,7 @@ func (T *Pool) serveRW(l prom.PoolHybridLabels, conn *fed.Conn) error { } p := packets.ReadyForQuery('I') - if err = conn.WritePacket(&p); err != nil { + if err = conn.WritePacket(ctx, &p); err != nil { return err } @@ -205,17 +206,17 @@ func (T *Pool) serveRW(l prom.PoolHybridLabels, conn *fed.Conn) error { for { if primary != nil { - T.primary.Release(primary) + T.primary.Release(ctx, primary) primary = nil } if replica != nil { - T.replica.Release(replica) + T.replica.Release(ctx, replica) replica = nil } client.SetState(metrics.ConnStateIdle, nil, false) var packet fed.Packet - packet, err = conn.ReadPacket(true) + packet, err = conn.ReadPacket(ctx, true) if err != nil { return err } @@ -230,16 +231,16 @@ func (T *Pool) serveRW(l prom.PoolHybridLabels, conn *fed.Conn) error { return pool.ErrFailedToAcquirePeer } - err, serverErr = T.Pair(client, replica) + err, serverErr = T.Pair(ctx, client, replica) dur := time.Since(start) - psi.Set(psa) - eqpi.Set(eqpa) + psi.Set(ctx, psa) + eqpi.Set(ctx, eqpa) if err == nil && serverErr == nil { prom.OperationHybrid.Acquire(l.ToOperation("replica")).Observe(float64(dur) / float64(time.Millisecond)) start := time.Now() - err, serverErr = bouncers.Bounce(conn, replica.Conn, packet) + err, serverErr = bouncers.Bounce(ctx, conn, replica.Conn, packet) if serverErr == nil { dur := time.Since(start) prom.OperationHybrid.Execution(l.ToOperation("replica")).Observe(float64(dur) / float64(time.Millisecond)) @@ -253,13 +254,12 @@ func (T *Pool) serveRW(l prom.PoolHybridLabels, conn *fed.Conn) error { // fallback to primary if err == (ErrReadOnly{}) { - prom.OperationHybrid.Miss(l.ToOperation("replica")).Inc() m.Primary() - T.replica.Release(replica) + T.replica.Release(ctx, replica) replica = nil - packet, err = conn.ReadPacket(true) + packet, err = conn.ReadPacket(ctx, true) if err != nil { return err } @@ -273,13 +273,13 @@ func (T *Pool) serveRW(l prom.PoolHybridLabels, conn *fed.Conn) error { return pool.ErrFailedToAcquirePeer } - serverErr = T.PairPrimary(client, psi, eqpi, primary) + serverErr = T.PairPrimary(ctx, client, psi, eqpi, primary) dur := time.Since(start) if serverErr == nil { prom.OperationHybrid.Acquire(l.ToOperation("primary")).Observe(float64(dur) / float64(time.Millisecond)) start := time.Now() - err, serverErr = bouncers.Bounce(conn, primary.Conn, packet) + err, serverErr = bouncers.Bounce(ctx, conn, primary.Conn, packet) dur := time.Since(start) prom.OperationHybrid.Execution(l.ToOperation("primary")).Observe(float64(dur) / float64(time.Millisecond)) } @@ -295,7 +295,7 @@ func (T *Pool) serveRW(l prom.PoolHybridLabels, conn *fed.Conn) error { // straight to primary m.Primary() - packet, err = conn.ReadPacket(true) + packet, err = conn.ReadPacket(ctx, true) if err != nil { return err } @@ -309,14 +309,14 @@ func (T *Pool) serveRW(l prom.PoolHybridLabels, conn *fed.Conn) error { return pool.ErrFailedToAcquirePeer } - err, serverErr = T.Pair(client, primary) + err, serverErr = T.Pair(ctx, client, primary) dur := time.Since(start) if err == nil && serverErr == nil { prom.OperationHybrid.Acquire(l.ToOperation("primary")).Observe(float64(dur) / float64(time.Millisecond)) start := time.Now() - err, serverErr = bouncers.Bounce(conn, primary.Conn, packet) + err, serverErr = bouncers.Bounce(ctx, conn, primary.Conn, packet) if serverErr == nil { dur := time.Since(start) prom.OperationHybrid.Execution(l.ToOperation("primary")).Observe(float64(dur) / float64(time.Millisecond)) @@ -337,7 +337,7 @@ func (T *Pool) serveRW(l prom.PoolHybridLabels, conn *fed.Conn) error { } } -func (T *Pool) serveOnly(l prom.PoolHybridLabels, conn *fed.Conn, write bool) error { +func (T *Pool) serveOnly(ctx context.Context, l prom.PoolHybridLabels, conn *fed.Conn, write bool) error { var sp *spool.Pool if write { sp = &T.primary @@ -366,9 +366,9 @@ func (T *Pool) serveOnly(l prom.PoolHybridLabels, conn *fed.Conn, write bool) er defer func() { if server != nil { if serverErr != nil { - sp.RemoveServer(server) + sp.RemoveServer(ctx, server) } else { - sp.Release(server) + sp.Release(ctx, server) } server = nil } @@ -382,7 +382,7 @@ func (T *Pool) serveOnly(l prom.PoolHybridLabels, conn *fed.Conn, write bool) er return pool.ErrFailedToAcquirePeer } - err, serverErr = T.Pair(client, server) + err, serverErr = T.Pair(ctx, client, server) if serverErr != nil { return serverErr } @@ -391,7 +391,7 @@ func (T *Pool) serveOnly(l prom.PoolHybridLabels, conn *fed.Conn, write bool) er } p := packets.ReadyForQuery('I') - if err = conn.WritePacket(&p); err != nil { + if err = conn.WritePacket(ctx, &p); err != nil { return err } @@ -407,13 +407,13 @@ func (T *Pool) serveOnly(l prom.PoolHybridLabels, conn *fed.Conn, write bool) er for { if server != nil { - sp.Release(server) + sp.Release(ctx, server) server = nil } client.SetState(metrics.ConnStateIdle, nil, true) var packet fed.Packet - packet, err = conn.ReadPacket(true) + packet, err = conn.ReadPacket(ctx, true) if err != nil { return err } @@ -425,12 +425,12 @@ func (T *Pool) serveOnly(l prom.PoolHybridLabels, conn *fed.Conn, write bool) er if server == nil { return pool.ErrFailedToAcquirePeer } - err, serverErr = T.Pair(client, server) + err, serverErr = T.Pair(ctx, client, server) dur := time.Since(start) if err == nil && serverErr == nil { prom.OperationHybrid.Acquire(opL).Observe(float64(dur) / float64(time.Millisecond)) start := time.Now() - err, serverErr = bouncers.Bounce(conn, server.Conn, packet) + err, serverErr = bouncers.Bounce(ctx, conn, server.Conn, packet) if serverErr == nil { dur := time.Since(start) prom.OperationHybrid.Execution(opL).Observe(float64(dur) / float64(time.Millisecond)) @@ -449,7 +449,7 @@ func (T *Pool) serveOnly(l prom.PoolHybridLabels, conn *fed.Conn, write bool) er } } -func (T *Pool) Serve(conn *fed.Conn) error { +func (T *Pool) Serve(ctx context.Context, conn *fed.Conn) error { labels := prom.PoolHybridLabels{ Database: conn.Database, User: conn.User, @@ -467,17 +467,17 @@ func (T *Pool) Serve(conn *fed.Conn) error { defer prom.PoolHybrid.Current(labels).Dec() switch labels.Mode { case "ro": - return T.serveOnly(labels, conn, false) + return T.serveOnly(ctx, labels, conn, false) case "wo": - return T.serveOnly(labels, conn, true) + return T.serveOnly(ctx, labels, conn, true) case "rw": - return T.serveRW(labels, conn) + return T.serveRW(ctx, labels, conn) default: panic("impossible") } } -func (T *Pool) Cancel(key fed.BackendKey) { +func (T *Pool) Cancel(ctx context.Context, key fed.BackendKey) { peer, replica := func() (*spool.Server, bool) { T.mu.RLock() defer T.mu.RUnlock() @@ -496,9 +496,9 @@ func (T *Pool) Cancel(key fed.BackendKey) { } if replica { - T.replica.Cancel(peer) + T.replica.Cancel(ctx, peer) } else { - T.primary.Cancel(peer) + T.primary.Cancel(ctx, peer) } } @@ -519,9 +519,9 @@ func (T *Pool) ReadMetrics(m *metrics.Pool) { } } -func (T *Pool) Close() { - T.primary.Close() - T.replica.Close() +func (T *Pool) Close(ctx context.Context) { + T.primary.Close(ctx) + T.replica.Close(ctx) } var _ pool.Pool = (*Pool)(nil) diff --git a/lib/gat/handlers/pool/spool/kitchen/chef.go b/lib/gat/handlers/pool/spool/kitchen/chef.go index e4fa39b81f91d6e027f58609b7157dc09eaa7230..3eb2f873fdc750c89cf96d8cfad4e176fd26d007 100644 --- a/lib/gat/handlers/pool/spool/kitchen/chef.go +++ b/lib/gat/handlers/pool/spool/kitchen/chef.go @@ -1,6 +1,7 @@ package kitchen import ( + "context" "fmt" "math" "sort" @@ -36,7 +37,7 @@ func NewChef(config Config) *Chef { } // Learn will add a recipe to the kitchen. Returns initial removed and added conns -func (T *Chef) Learn(name string, recipe *pool.Recipe) (removed []*fed.Conn, added []*fed.Conn) { +func (T *Chef) Learn(ctx context.Context, name string, recipe *pool.Recipe) (removed []*fed.Conn, added []*fed.Conn) { n := recipe.AllocateInitial() added = make([]*fed.Conn, 0, n) for i := 0; i < n; i++ { @@ -56,7 +57,7 @@ func (T *Chef) Learn(name string, recipe *pool.Recipe) (removed []*fed.Conn, add T.mu.Lock() defer T.mu.Unlock() - removed = T.forget(name) + removed = T.forget(ctx, name) r := NewRecipe(recipe, added) @@ -77,7 +78,7 @@ func (T *Chef) Learn(name string, recipe *pool.Recipe) (removed []*fed.Conn, add return } -func (T *Chef) forget(name string) []*fed.Conn { +func (T *Chef) forget(ctx context.Context, name string) []*fed.Conn { r, ok := T.byName[name] if !ok { return nil @@ -88,7 +89,7 @@ func (T *Chef) forget(name string) []*fed.Conn { for conn := range r.conns { conns = append(conns, conn) - _ = conn.Close() + _ = conn.Close(ctx) r.recipe.Free() delete(T.byConn, conn) @@ -102,11 +103,11 @@ func (T *Chef) forget(name string) []*fed.Conn { // Forget will remove a recipe from the kitchen. All conn made with the recipe will be closed. Returns conns made with // recipe. -func (T *Chef) Forget(name string) []*fed.Conn { +func (T *Chef) Forget(ctx context.Context, name string) []*fed.Conn { T.mu.Lock() defer T.mu.Unlock() - return T.forget(name) + return T.forget(ctx, name) } func (T *Chef) Empty() bool { @@ -123,7 +124,7 @@ func (T *Chef) cook(r *Recipe) (*fed.Conn, error) { return r.recipe.Dial() } -func (T *Chef) score(r *Recipe) error { +func (T *Chef) score(ctx context.Context, r *Recipe) error { now := time.Now() r.ratings = slices.Resize(r.ratings, len(T.config.Critics)) @@ -164,7 +165,7 @@ func (T *Chef) score(r *Recipe) error { return err } defer func() { - _ = conn.Close() + _ = conn.Close(ctx) }() for i, critic := range critics { @@ -174,7 +175,7 @@ func (T *Chef) score(r *Recipe) error { var score int var validity time.Duration - score, validity, err = critic.Taste(conn) + score, validity, err = critic.Taste(ctx, conn) if err != nil { return err } @@ -211,12 +212,12 @@ func (T *Chef) score(r *Recipe) error { } // Cook will cook the best recipe -func (T *Chef) Cook() (*fed.Conn, error) { +func (T *Chef) Cook(ctx context.Context) (*fed.Conn, error) { T.mu.Lock() defer T.mu.Unlock() for _, r := range T.byName { - if err := T.score(r); err != nil { + if err := T.score(ctx, r); err != nil { r.score = math.MaxInt T.config.Logger.Error("failed to score recipe", zap.Error(err)) continue @@ -277,7 +278,7 @@ func (T *Chef) Cook() (*fed.Conn, error) { } // Burn forcefully closes conn and escorts it out of the kitchen. -func (T *Chef) Burn(conn *fed.Conn) { +func (T *Chef) Burn(ctx context.Context, conn *fed.Conn) { T.mu.Lock() defer T.mu.Unlock() @@ -286,14 +287,14 @@ func (T *Chef) Burn(conn *fed.Conn) { return } r.recipe.Free() - _ = conn.Close() + _ = conn.Close(ctx) delete(T.byConn, conn) delete(r.conns, conn) } // Ignite tries to Burn conn. If successful, conn is closed and returns true -func (T *Chef) Ignite(conn *fed.Conn) bool { +func (T *Chef) Ignite(ctx context.Context, conn *fed.Conn) bool { T.mu.Lock() defer T.mu.Unlock() @@ -304,14 +305,14 @@ func (T *Chef) Ignite(conn *fed.Conn) bool { if !r.recipe.TryFree() { return false } - _ = conn.Close() + _ = conn.Close(ctx) delete(T.byConn, conn) delete(r.conns, conn) return true } -func (T *Chef) Cancel(conn *fed.Conn) { +func (T *Chef) Cancel(ctx context.Context, conn *fed.Conn) { T.mu.Lock() defer T.mu.Unlock() @@ -320,10 +321,10 @@ func (T *Chef) Cancel(conn *fed.Conn) { return } - r.recipe.Cancel(conn.BackendKey) + r.recipe.Cancel(ctx, conn.BackendKey) } -func (T *Chef) Close() { +func (T *Chef) Close(ctx context.Context) { T.mu.Lock() defer T.mu.Unlock() @@ -331,7 +332,7 @@ func (T *Chef) Close() { T.order = T.order[:0] for conn, r := range T.byConn { r.recipe.Free() - _ = conn.Close() + _ = conn.Close(ctx) delete(T.byConn, conn) delete(r.conns, conn) diff --git a/lib/gat/handlers/pool/spool/pool.go b/lib/gat/handlers/pool/spool/pool.go index 7589ae16ea073585b4a480993c909d517d9d82f6..ef733f3d50ffffb151a0a659c3fa39325217d305 100644 --- a/lib/gat/handlers/pool/spool/pool.go +++ b/lib/gat/handlers/pool/spool/pool.go @@ -1,6 +1,7 @@ package spool import ( + "context" "sync" "time" @@ -46,9 +47,9 @@ func MakePool(config Config) Pool { } } -func NewPool(config Config) *Pool { +func NewPool(ctx context.Context, config Config) *Pool { p := MakePool(config) - go p.ScaleLoop() + go p.ScaleLoop(ctx) return &p } @@ -92,8 +93,8 @@ func (T *Pool) removeServer(conn *fed.Conn) { T.pooler.DeleteServer(server.ID) } -func (T *Pool) AddRecipe(name string, recipe *pool.Recipe) { - removed, added := T.chef.Learn(name, recipe) +func (T *Pool) AddRecipe(ctx context.Context, name string, recipe *pool.Recipe) { + removed, added := T.chef.Learn(ctx, name, recipe) if len(removed) == 0 && len(added) == 0 { return } @@ -110,8 +111,8 @@ func (T *Pool) AddRecipe(name string, recipe *pool.Recipe) { } } -func (T *Pool) RemoveRecipe(name string) { - servers := T.chef.Forget(name) +func (T *Pool) RemoveRecipe(ctx context.Context, name string) { + servers := T.chef.Forget(ctx, name) if len(servers) == 0 { return } @@ -128,8 +129,8 @@ func (T *Pool) Empty() bool { return T.chef.Empty() } -func (T *Pool) ScaleUp() error { - server, err := T.chef.Cook() +func (T *Pool) ScaleUp(ctx context.Context) error { + server, err := T.chef.Cook(ctx) if err != nil { return err } @@ -142,7 +143,7 @@ func (T *Pool) ScaleUp() error { return nil } -func (T *Pool) ScaleDown(now time.Time) time.Duration { +func (T *Pool) ScaleDown(ctx context.Context, now time.Time) time.Duration { T.mu.Lock() defer T.mu.Unlock() @@ -158,7 +159,7 @@ func (T *Pool) ScaleDown(now time.Time) time.Duration { idle := now.Sub(since) if idle > T.config.IdleTimeout { // try to free - if T.chef.Ignite(s.Conn) { + if T.chef.Ignite(ctx, s.Conn) { delete(T.serversByID, s.ID) delete(T.serversByConn, s.Conn) T.pooler.DeleteServer(s.ID) @@ -174,7 +175,7 @@ func (T *Pool) ScaleDown(now time.Time) time.Duration { return m } -func (T *Pool) ScaleLoop() { +func (T *Pool) ScaleLoop(ctx context.Context) { idle := new(time.Timer) if T.config.IdleTimeout != 0 { idle = time.NewTimer(T.config.IdleTimeout) @@ -198,7 +199,7 @@ func (T *Pool) ScaleLoop() { ok := true for T.pooler.Waiters() > 0 { - if err := T.ScaleUp(); err != nil { + if err := T.ScaleUp(ctx); err != nil { ok = false break } @@ -222,7 +223,7 @@ func (T *Pool) ScaleLoop() { ok := true for T.pooler.Waiters() > 0 { - if err := T.ScaleUp(); err != nil { + if err := T.ScaleUp(ctx); err != nil { ok = false break } @@ -239,7 +240,7 @@ func (T *Pool) ScaleLoop() { } case now := <-idle.C: // scale down - idle.Reset(T.ScaleDown(now)) + idle.Reset(T.ScaleDown(ctx, now)) } } } @@ -275,13 +276,13 @@ func (T *Pool) Acquire(client uuid.UUID) *Server { } } -func (T *Pool) Release(server *Server) { +func (T *Pool) Release(ctx context.Context, server *Server) { if T.config.ResetQuery != "" { server.SetState(metrics.ConnStateRunningResetQuery, uuid.Nil) - if err, _ := backends.QueryString(server.Conn, nil, T.config.ResetQuery); err != nil { + if err, _ := backends.QueryString(ctx, server.Conn, nil, T.config.ResetQuery); err != nil { T.config.Logger.Error("failed to run reset query", zap.Error(err)) - T.RemoveServer(server) + T.RemoveServer(ctx,server) return } } @@ -291,8 +292,8 @@ func (T *Pool) Release(server *Server) { server.SetState(metrics.ConnStateIdle, uuid.Nil) } -func (T *Pool) RemoveServer(server *Server) { - T.chef.Burn(server.Conn) +func (T *Pool) RemoveServer(ctx context.Context, server *Server) { + T.chef.Burn(ctx, server.Conn) T.pooler.DeleteServer(server.ID) T.mu.Lock() @@ -303,8 +304,8 @@ func (T *Pool) RemoveServer(server *Server) { T.pooler.DeleteServer(server.ID) } -func (T *Pool) Cancel(server *Server) { - T.chef.Cancel(server.Conn) +func (T *Pool) Cancel(ctx context.Context, server *Server) { + T.chef.Cancel(ctx, server.Conn) } func (T *Pool) ReadMetrics(m *metrics.Pool) { @@ -321,10 +322,10 @@ func (T *Pool) ReadMetrics(m *metrics.Pool) { } } -func (T *Pool) Close() { +func (T *Pool) Close(ctx context.Context) { close(T.closed) - T.chef.Close() + T.chef.Close(ctx) T.pooler.Close() T.mu.Lock() diff --git a/lib/gat/handlers/rewrite_password/module.go b/lib/gat/handlers/rewrite_password/module.go index 7c2345733cd5dad4cc5ccf9b7fb49058c7a6daf7..4f8b3f8e8c92a39a5060df9d0b3208af9f44813e 100644 --- a/lib/gat/handlers/rewrite_password/module.go +++ b/lib/gat/handlers/rewrite_password/module.go @@ -1,6 +1,7 @@ package rewrite_password import ( + "context" "github.com/caddyserver/caddy/v2" "gfx.cafe/gfx/pggat/lib/auth/credentials" @@ -29,6 +30,7 @@ func (T *Module) CaddyModule() caddy.ModuleInfo { func (T *Module) Handle(next gat.Router) gat.Router { return gat.RouterFunc(func(conn *fed.Conn) error { if err := frontends.Authenticate( + context.Background(), conn, credentials.FromString(conn.User, T.Password), ); err != nil { diff --git a/lib/gat/server.go b/lib/gat/server.go index 81fb5dce0c4d916f6b2fd896dbaa60f1d39b3a52..553edadc9f9adc8747af650d92f77b1d40503e25 100644 --- a/lib/gat/server.go +++ b/lib/gat/server.go @@ -1,6 +1,7 @@ package gat import ( + "context" "crypto/tls" "errors" "fmt" @@ -67,7 +68,7 @@ func (T *Server) Provision(ctx caddy.Context) error { return nil } -func (T *Server) Start() error { +func (T *Server) Start(ctx context.Context) error { for _, listener := range T.listen { if err := listener.Start(); err != nil { return err @@ -75,7 +76,7 @@ func (T *Server) Start() error { go func(listener *Listener) { for { - if !T.acceptFrom(listener) { + if !T.acceptFrom(ctx, listener) { break } } @@ -85,7 +86,7 @@ func (T *Server) Start() error { return nil } -func (T *Server) Stop() error { +func (T *Server) Stop(ctx context.Context) error { for _, listen := range T.listen { if err := listen.Stop(); err != nil { return err @@ -95,9 +96,9 @@ func (T *Server) Stop() error { return nil } -func (T *Server) Cancel(key fed.BackendKey) { +func (T *Server) Cancel(ctx context.Context, key fed.BackendKey) { for _, cancellableHandler := range T.cancellableHandlers { - cancellableHandler.Cancel(key) + cancellableHandler.Cancel(ctx, key) } } @@ -108,6 +109,8 @@ func (T *Server) ReadMetrics(m *metrics.Server) { } func (T *Server) Serve(conn *fed.Conn) { + ctx := context.Background() + composed := Router(RouterFunc(func(conn *fed.Conn) error { // database not found errResp := perror.ToPacket( @@ -117,7 +120,7 @@ func (T *Server) Serve(conn *fed.Conn) { fmt.Sprintf(`Database "%s" not found`, conn.Database), ), ) - _ = conn.WritePacket(errResp) + _ = conn.WritePacket(ctx, errResp) T.log.Warn("database not found", zap.String("user", conn.User), zap.String("database", conn.Database)) return nil })) @@ -139,14 +142,14 @@ func (T *Server) Serve(conn *fed.Conn) { } errResp := perror.ToPacket(perror.Wrap(err)) - _ = conn.WritePacket(errResp) + _ = conn.WritePacket(ctx, errResp) return } } -func (T *Server) accept(listener *Listener, conn *fed.Conn) { +func (T *Server) accept(ctx context.Context, listener *Listener, conn *fed.Conn) { defer func() { - _ = conn.Close() + _ = conn.Close(ctx) }() labels := prom.ListenerLabels{ListenAddr: listener.networkAddress.String()} @@ -167,7 +170,7 @@ func (T *Server) accept(listener *Listener, conn *fed.Conn) { } if isCanceling { - T.Cancel(cancelKey) + T.Cancel(ctx, cancelKey) return } @@ -181,6 +184,7 @@ func (T *Server) accept(listener *Listener, conn *fed.Conn) { if listener.MaxConnections != 0 && int(count) > listener.MaxConnections { _ = conn.WritePacket( + ctx, perror.ToPacket(perror.New( perror.FATAL, perror.TooManyConnections, @@ -193,9 +197,9 @@ func (T *Server) accept(listener *Listener, conn *fed.Conn) { T.Serve(conn) } -func (T *Server) acceptFrom(listener *Listener) bool { +func (T *Server) acceptFrom(ctx context.Context, listener *Listener) bool { err := listener.listener.Accept(func(c *fed.Conn) { - T.accept(listener, c) + T.accept(ctx, listener, c) }) if err != nil { if errors.Is(err, net.ErrClosed) { diff --git a/lib/gsql/eq.go b/lib/gsql/eq.go index 3cfc8d49ecec61e2a664687bdb6124f475610a69..6e207b63af1162a6331a4f4af9426bb0fad309d6 100644 --- a/lib/gsql/eq.go +++ b/lib/gsql/eq.go @@ -1,6 +1,7 @@ package gsql import ( + "context" "reflect" "strconv" @@ -8,16 +9,16 @@ import ( packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" ) -func ExtendedQuery(client *fed.Conn, result any, query string, args ...any) error { +func ExtendedQuery(ctx context.Context, client *fed.Conn, result any, query string, args ...any) error { if len(args) == 0 { - return Query(client, []any{result}, query) + return Query(ctx, client, []any{result}, query) } // parse parse := packets.Parse{ Query: query, } - if err := client.WritePacket(&parse); err != nil { + if err := client.WritePacket(ctx, &parse); err != nil { return err } @@ -59,7 +60,7 @@ outer: bind := packets.Bind{ Parameters: params, } - if err := client.WritePacket(&bind); err != nil { + if err := client.WritePacket(ctx, &bind); err != nil { return err } @@ -67,29 +68,29 @@ outer: describe := packets.Describe{ Which: 'P', } - if err := client.WritePacket(&describe); err != nil { + if err := client.WritePacket(ctx, &describe); err != nil { return err } // execute execute := packets.Execute{} - if err := client.WritePacket(&execute); err != nil { + if err := client.WritePacket(ctx, &execute); err != nil { return err } // sync sync := packets.Sync{} - if err := client.WritePacket(&sync); err != nil { + if err := client.WritePacket(ctx, &sync); err != nil { return err } // result - if err := readQueryResults(client, result); err != nil { + if err := readQueryResults(ctx, client, result); err != nil { return err } // make sure we receive ready for query - packet, err := client.ReadPacket(true) + packet, err := client.ReadPacket(ctx, true) if err != nil { return err } diff --git a/lib/gsql/query.go b/lib/gsql/query.go index 73e236a386bbd023e719cb68197bcc2e0703b699..f5c8654964fc6bf30c4569868d5dc6a6e16aea21 100644 --- a/lib/gsql/query.go +++ b/lib/gsql/query.go @@ -1,22 +1,23 @@ package gsql import ( + "context" "gfx.cafe/gfx/pggat/lib/fed" packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" ) -func Query(client *fed.Conn, results []any, query string) error { +func Query(ctx context.Context, client *fed.Conn, results []any, query string) error { var q = packets.Query(query) - if err := client.WritePacket(&q); err != nil { + if err := client.WritePacket(ctx, &q); err != nil { return err } - if err := readQueryResults(client, results...); err != nil { + if err := readQueryResults(ctx, client, results...); err != nil { return err } // make sure we receive ready for query - packet, err := client.ReadPacket(true) + packet, err := client.ReadPacket(ctx, true) if err != nil { return err } @@ -28,9 +29,9 @@ func Query(client *fed.Conn, results []any, query string) error { return nil } -func readQueryResults(client *fed.Conn, results ...any) error { +func readQueryResults(ctx context.Context, client *fed.Conn, results ...any) error { for _, result := range results { - if err := readRows(client, result); err != nil { + if err := readRows(ctx, client, result); err != nil { return err } } diff --git a/lib/gsql/query_test.go b/lib/gsql/query_test.go index 476748de548ed0ea65fccc8c245fed625cb828ba..1fe9fa90112ec18f71be077cff4fcd53343751bc 100644 --- a/lib/gsql/query_test.go +++ b/lib/gsql/query_test.go @@ -1,6 +1,7 @@ package gsql_test import ( + "context" "crypto/tls" "log" "net" @@ -32,8 +33,10 @@ func TestQuery(t *testing.T) { t.Error(err) return } + ctx := context.Background() server := fed.NewConn(netconncodec.NewCodec(s)) err = backends.Accept( + ctx, server, "disable", &tls.Config{}, @@ -56,25 +59,25 @@ func TestQuery(t *testing.T) { var b flip.Bank b.Queue(func() error { - if err := gsql.ExtendedQuery(inward, &res, "SELECT usename, passwd FROM pg_shadow WHERE usename=$1", "postgres"); err != nil { + if err := gsql.ExtendedQuery(context.Background(), inward, &res, "SELECT usename, passwd FROM pg_shadow WHERE usename=$1", "postgres"); err != nil { return err } return nil }) b.Queue(func() error { - initial, err := outward.ReadPacket(true) + initial, err := outward.ReadPacket(ctx, true) if err != nil { return err } - clientErr, serverErr := bouncers.Bounce(outward, server, initial) + clientErr, serverErr := bouncers.Bounce(ctx, outward, server, initial) if clientErr != nil { return clientErr } if serverErr != nil { return serverErr } - if err := outward.Close(); err != nil { + if err := outward.Close(ctx); err != nil { return err } return nil diff --git a/lib/gsql/row.go b/lib/gsql/row.go index be8f0acaed768534a99d553911c8cf9bb68cc2e6..f14ba3197591b08af86d2a064934e2954f1e8de3 100644 --- a/lib/gsql/row.go +++ b/lib/gsql/row.go @@ -1,6 +1,7 @@ package gsql import ( + "context" "reflect" "strconv" @@ -9,13 +10,13 @@ import ( "gfx.cafe/gfx/pggat/lib/perror" ) -func readRows(client *fed.Conn, result any) error { +func readRows(ctx context.Context,client *fed.Conn, result any) error { res := reflect.ValueOf(result) row := 0 var rd packets.RowDescription for { - packet, err := client.ReadPacket(true) + packet, err := client.ReadPacket(ctx,true) if err != nil { return err } diff --git a/test/runner.go b/test/runner.go index 8b1c7e9683652584a7b5de40789c62fc8e565e53..53ba36a2bf51d74751585b95094952de667d4ab7 100644 --- a/test/runner.go +++ b/test/runner.go @@ -2,6 +2,7 @@ package test import ( "bytes" + "context" "errors" "fmt" "io" @@ -30,16 +31,16 @@ func MakeRunner(config Config, test Test) Runner { func (T *Runner) prepare(client *fed.Conn, until int) error { for i := 0; i < until; i++ { x := T.test.Packets[i] - if err := client.WritePacket(x); err != nil { + if err := client.WritePacket(context.Background(), x); err != nil { return err } } - if err := client.WritePacket(&packets.Terminate{}); err != nil { + if err := client.WritePacket(context.Background(), &packets.Terminate{}); err != nil { return err } - return client.Flush() + return client.Flush(context.Background()) } func (T *Runner) runModeL1(dialer pool.Dialer, client *fed.Conn) error { @@ -48,14 +49,14 @@ func (T *Runner) runModeL1(dialer pool.Dialer, client *fed.Conn) error { return err } defer func() { - _ = server.Close() + _ = server.Close(context.Background()) }() client.Middleware = append(client.Middleware, unterminate.Unterminate) for { var p fed.Packet - p, err = client.ReadPacket(true) + p, err = client.ReadPacket(context.Background(), true) if err != nil { if errors.Is(err, io.EOF) { break @@ -63,7 +64,7 @@ func (T *Runner) runModeL1(dialer pool.Dialer, client *fed.Conn) error { return err } - clientErr, serverErr := bouncers.Bounce(client, server, p) + clientErr, serverErr := bouncers.Bounce(context.Background(), client, server, p) if clientErr != nil { return clientErr } @@ -85,7 +86,7 @@ func (T *Runner) runModeOnce(dialer pool.Dialer) ([]byte, error) { return nil, err } - if err := inward.Close(); err != nil { + if err := inward.Close(context.Background()); err != nil { return nil, err }