diff --git a/lib/bouncer/backends/v0/query.go b/lib/bouncer/backends/v0/query.go index 7be9b631fc1b20ba8d7fd01b8e84f7b7b629ad7d..d86b46c19faec8c64321c1dc02b44849f6e9ab66 100644 --- a/lib/bouncer/backends/v0/query.go +++ b/lib/bouncer/backends/v0/query.go @@ -1,7 +1,7 @@ package backends import ( - "fmt" + "strings" "gfx.cafe/gfx/pggat/lib/fed" packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" @@ -117,10 +117,20 @@ func QueryString(server, peer *fed.Conn, query string) (err, peerError error) { } func SetParameter(server, peer *fed.Conn, name strutil.CIString, value string) (err, peerError error) { + var q strings.Builder + escapedName := strutil.Escape(name.String(), '"') + escapedValue := strutil.Escape(value, '\'') + q.Grow(len(`SET "" = ''`) + len(escapedName) + len(escapedValue)) + q.WriteString(`SET "`) + q.WriteString(escapedName) + q.WriteString(`" = '`) + q.WriteString(escapedValue) + q.WriteString(`'`) + return QueryString( server, peer, - fmt.Sprintf(`SET "%s" = '%s'`, strutil.Escape(name.String(), '"'), strutil.Escape(value, '\'')), + q.String(), ) } diff --git a/lib/fed/middlewares/eqp/sync.go b/lib/fed/middlewares/eqp/sync.go index 8a0e0cf34c588751bc1214442b3a3a9d79729ba4..05fc56e53b2860c246faf1e5549f87a8ce9715dc 100644 --- a/lib/fed/middlewares/eqp/sync.go +++ b/lib/fed/middlewares/eqp/sync.go @@ -1,11 +1,25 @@ package eqp import ( + "slices" + "gfx.cafe/gfx/pggat/lib/bouncer/backends/v0" "gfx.cafe/gfx/pggat/lib/fed" packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" ) +func preparedStatementsEqual(a, b *packets.Parse) bool { + if a.Query != b.Query { + return false + } + + if !slices.Equal(a.ParameterDataTypes, b.ParameterDataTypes) { + return false + } + + return true +} + func Sync(c *Client, server *fed.Conn, s *Server) error { var needsBackendSync bool @@ -29,9 +43,9 @@ func Sync(c *Client, server *fed.Conn, s *Server) error { // close all prepared statements that don't match client for name, preparedStatement := range s.state.preparedStatements { if clientPreparedStatement, ok := c.state.preparedStatements[name]; ok { - // TODO(garet) do not overwrite prepared statements that match - _ = preparedStatement - _ = clientPreparedStatement + if preparedStatementsEqual(preparedStatement, clientPreparedStatement) { + continue + } if name == "" { // will be overwritten @@ -53,9 +67,9 @@ func Sync(c *Client, server *fed.Conn, s *Server) error { // parse all prepared statements that aren't on server for name, preparedStatement := range c.state.preparedStatements { if serverPreparedStatement, ok := s.state.preparedStatements[name]; ok { - // TODO(garet) do not overwrite prepared statements that match - _ = preparedStatement - _ = serverPreparedStatement + if preparedStatementsEqual(preparedStatement, serverPreparedStatement) { + continue + } } if err := server.WritePacket(preparedStatement); err != nil {