diff --git a/lib/bouncer/frontends/v0/authenticate.go b/lib/bouncer/frontends/v0/authenticate.go index 38791993da287e88a3a4a4fdb0a84d4e16840a50..aeabbad83b16e944c49df68194090cec17d62150 100644 --- a/lib/bouncer/frontends/v0/authenticate.go +++ b/lib/bouncer/frontends/v0/authenticate.go @@ -136,14 +136,6 @@ func authenticationMD5(client fed.Conn, creds auth.MD5) perror.Error { return nil } -func updateParameter(client fed.Conn, name, value string) perror.Error { - ps := packets.ParameterStatus{ - Key: name, - Value: value, - } - return perror.Wrap(client.WritePacket(ps.IntoPacket())) -} - func authenticate(client fed.Conn, options AuthenticateOptions) (params AuthenticateParams, err perror.Error) { if options.Credentials == nil { err = perror.New( @@ -188,22 +180,6 @@ func authenticate(client fed.Conn, options AuthenticateOptions) (params Authenti return } - if err = updateParameter(client, "client_encoding", "UTF8"); err != nil { - return - } - if err = updateParameter(client, "server_encoding", "UTF8"); err != nil { - return - } - if err = updateParameter(client, "server_version", "14.5"); err != nil { - return - } - - // send ready for query - rfq := packets.ReadyForQuery('I') - if err = perror.Wrap(client.WritePacket(rfq.IntoPacket())); err != nil { - return - } - return } diff --git a/lib/gat/modes/pgbouncer/pools.go b/lib/gat/modes/pgbouncer/pools.go index 0080800e41e29d5d9fb05fa1bd548211fd121d86..eff75604e61a286d47a0df3fd4ba83afec4dbd1b 100644 --- a/lib/gat/modes/pgbouncer/pools.go +++ b/lib/gat/modes/pgbouncer/pools.go @@ -109,7 +109,7 @@ func (T *Pools) Lookup(user, database string) *pool.Pool { log.Println("auth query failed:", err) return nil } - err = authPool.Serve(client, nil, [8]byte{}) + err = authPool.ServeBot(client) if err != nil && !errors.Is(err, net.ErrClosed) { log.Println("auth query failed:", err) return nil diff --git a/lib/gat/pool/pool.go b/lib/gat/pool/pool.go index 11a23a69e25c2cedc1ed853d33dfd585daafec73..118e224883ee6c92963b6eda757abdcb2d611ae9 100644 --- a/lib/gat/pool/pool.go +++ b/lib/gat/pool/pool.go @@ -11,6 +11,7 @@ import ( "pggat/lib/bouncer/backends/v0" "pggat/lib/bouncer/bouncers/v2" "pggat/lib/fed" + packets "pggat/lib/fed/packets/v3.0" "pggat/lib/gat/metrics" "pggat/lib/gat/pool/recipe" "pggat/lib/util/slices" @@ -256,16 +257,61 @@ func (T *Pool) Serve( backendKey, ) - return T.serve(client) + return T.serve(client, false) } -func (T *Pool) serve(client *Client) error { +// ServeBot is for clients that don't need initial parameters, cancelling queries, and are ready now. Use Serve for +// real clients +func (T *Pool) ServeBot( + conn fed.Conn, +) error { + defer func() { + _ = conn.Close() + }() + + client := NewClient( + T.options, + conn, + nil, + [8]byte{}, + ) + + return T.serve(client, true) +} + +func (T *Pool) serve(client *Client, initialize bool) error { T.addClient(client) defer T.removeClient(client) var server *Server + if !initialize { + server = T.acquireServer(client) + + err, serverErr := Pair(T.options, client, server) + if serverErr != nil { + T.removeServer(server) + return serverErr + } + if err != nil { + T.releaseServer(server) + return err + } + + p := packets.ReadyForQuery('I') + err = client.GetConn().WritePacket(p.IntoPacket()) + if err != nil { + T.releaseServer(server) + return err + } + } for { + if server != nil && T.options.ReleaseAfterTransaction { + client.SetState(metrics.ConnStateIdle, uuid.Nil) + go T.releaseServer(server) // TODO(garet) does this need to be a goroutine + server = nil + } + packet, err := client.GetConn().ReadPacket(true) if err != nil { if server != nil { @@ -288,17 +334,11 @@ func (T *Pool) serve(client *Client) error { return serverErr } else { TransactionComplete(client, server) - if T.options.ReleaseAfterTransaction { - client.SetState(metrics.ConnStateIdle, uuid.Nil) - go T.releaseServer(server) // TODO(garet) does this need to be a goroutine - server = nil - } + } if err != nil { - if server != nil { - T.releaseServer(server) - } + T.releaseServer(server) return err } } diff --git a/lib/middleware/middlewares/ps/sync.go b/lib/middleware/middlewares/ps/sync.go index c55e36c1fb72e55bdf9e0bd10ee80f27e2407c70..3dc6ca50318a76506d5138b3d673e6fff4bdc773 100644 --- a/lib/middleware/middlewares/ps/sync.go +++ b/lib/middleware/middlewares/ps/sync.go @@ -25,22 +25,23 @@ func sync(tracking []strutil.CIString, client fed.ReadWriter, c *Client, server return nil } - if slices.Contains(tracking, name) { - if hasValue { - if err := backends.SetParameter(&backends.Context{}, server, name, value); err != nil { - return err - } - if s.parameters == nil { - s.parameters = make(map[strutil.CIString]string) - } - s.parameters[name] = value - } else { - if err := backends.ResetParameter(&backends.Context{}, server, name); err != nil { - return err - } - delete(s.parameters, name) + var doSet bool + + if hasValue && slices.Contains(tracking, name) { + if err := backends.SetParameter(&backends.Context{}, server, name, value); err != nil { + return err + } + if s.parameters == nil { + s.parameters = make(map[strutil.CIString]string) } + s.parameters[name] = value + + doSet = true } else if hasExpected { + doSet = true + } + + if doSet { ps := packets.ParameterStatus{ Key: name.String(), Value: expected, diff --git a/test/runner.go b/test/runner.go index 59a2063fe53e0eed5d38918cddf7bce76065bc60..07689d9aed949f4f4645cc97d4ed14f986e0a090 100644 --- a/test/runner.go +++ b/test/runner.go @@ -138,7 +138,7 @@ func (T *Runner) runMode(options pool.Options) ([]Capturer, error) { return nil, err } - if err := p.Serve(&client, nil, [8]byte{}); err != nil && !errors.Is(err, io.EOF) { + if err := p.ServeBot(&client); err != nil && !errors.Is(err, io.EOF) { return nil, err }