diff --git a/lib/bouncer/backends/v0/query.go b/lib/bouncer/backends/v0/query.go index 73afaeb8a28926e86cbf583e1ad7d62229ea91f7..7abcdc5f6833856e6e7a10cada7b8a7c84e7a784 100644 --- a/lib/bouncer/backends/v0/query.go +++ b/lib/bouncer/backends/v0/query.go @@ -1,6 +1,7 @@ package backends import ( + "pggat2/lib/util/strutil" "pggat2/lib/zap" packets "pggat2/lib/zap/packets/v3.0" ) @@ -127,6 +128,10 @@ func QueryString(ctx *Context, server zap.ReadWriter, query string) error { return Query(ctx, server, packet) } +func SetParameter(ctx *Context, server zap.ReadWriter, name strutil.CIString, value string) error { + return QueryString(ctx, server, `SET `+strutil.Escape(name.String(), `"`)+` = `+strutil.Escape(value, `'`)) +} + func FunctionCall(ctx *Context, server zap.ReadWriter, packet *zap.Packet) error { if err := server.Write(packet); err != nil { return err diff --git a/lib/gat/pools/session/pool.go b/lib/gat/pools/session/pool.go index 732247caf8320ac004dfb0bc4172b3a24fda18c1..ba1f32b6c2652a939b7f9e02d651b0f6388b3b0c 100644 --- a/lib/gat/pools/session/pool.go +++ b/lib/gat/pools/session/pool.go @@ -133,7 +133,7 @@ func (T *Pool) Serve(ctx *gat.Context, client zap.ReadWriter, ps map[strutil.CIS packets.WriteParameterStatus(pkt, key.String(), value) pkts.Append(pkt) - if err := backends.QueryString(&backends.Context{}, conn.rw, `SET `+strutil.Escape(key.String(), `"`)+` = `+strutil.Escape(value, `'`)); err != nil { + if err := backends.SetParameter(&backends.Context{}, conn.rw, key, value); err != nil { connOk = false return true } diff --git a/lib/middleware/middlewares/ps/sync.go b/lib/middleware/middlewares/ps/sync.go index aa752f54e317fc6117c05b623af0d472e9ed9e50..e4547aec18b7ba39b318e21dd95fa9d52269dcd0 100644 --- a/lib/middleware/middlewares/ps/sync.go +++ b/lib/middleware/middlewares/ps/sync.go @@ -23,7 +23,7 @@ func sync(tracking []strutil.CIString, clientPackets *zap.Packets, c *Client, se } if slices.Contains(tracking, name) { - if err := backends.QueryString(&backends.Context{}, server, `SET `+strutil.Escape(name.String(), `"`)+` = `+strutil.Escape(value, `'`)); err != nil { + if err := backends.SetParameter(&backends.Context{}, server, name, value); err != nil { panic(err) // TODO(garet) } if s.parameters == nil { diff --git a/pgbouncer.ini b/pgbouncer.ini index 6dc489006c53ecee5f04c3e439677d79148f520f..118c7214d648c51f71cf2af1af86ee01e04cb7ee 100644 --- a/pgbouncer.ini +++ b/pgbouncer.ini @@ -1,5 +1,5 @@ [pgbouncer] -pool_mode = transaction +pool_mode = session auth_file = userlist.txt listen_addr = *