From 34bfd3904fc9936ed49915d5e0b631f19c97939c Mon Sep 17 00:00:00 2001 From: Garet Halliday <me@garet.holiday> Date: Thu, 14 Sep 2023 19:56:39 -0500 Subject: [PATCH] small optimizations --- lib/bouncer/backends/v0/query.go | 8 ++++++- lib/gat/pool/conn.go | 8 +++++++ lib/gat/pool/flow.go | 6 ++--- lib/gat/pool/pool.go | 4 ++-- lib/util/strutil/escape.go | 38 ++++++++++++++++---------------- 5 files changed, 39 insertions(+), 25 deletions(-) diff --git a/lib/bouncer/backends/v0/query.go b/lib/bouncer/backends/v0/query.go index 44e57a79..b2ac4510 100644 --- a/lib/bouncer/backends/v0/query.go +++ b/lib/bouncer/backends/v0/query.go @@ -1,6 +1,8 @@ package backends import ( + "fmt" + "pggat/lib/fed" packets "pggat/lib/fed/packets/v3.0" "pggat/lib/util/strutil" @@ -106,7 +108,11 @@ func QueryString(ctx *Context, server fed.ReadWriter, query string) error { } func SetParameter(ctx *Context, server fed.ReadWriter, name strutil.CIString, value string) error { - return QueryString(ctx, server, `SET `+strutil.Escape(name.String(), `"`)+` = `+strutil.Escape(value, `'`)) + return QueryString( + ctx, + server, + fmt.Sprintf(`SET "%s" = '%s'`, strutil.Escape(name.String(), '"'), strutil.Escape(value, '\'')), + ) } func FunctionCall(ctx *Context, server fed.ReadWriter, packet fed.Packet) error { diff --git a/lib/gat/pool/conn.go b/lib/gat/pool/conn.go index bfd5284d..1a9e6b0f 100644 --- a/lib/gat/pool/conn.go +++ b/lib/gat/pool/conn.go @@ -16,6 +16,8 @@ type Conn struct { id uuid.UUID conn fed.Conn + // please someone fix runtime.convI2I + rw fed.ReadWriter initialParameters map[strutil.CIString]string backendKey [8]byte @@ -44,6 +46,7 @@ func MakeConn( return Conn{ id: id, conn: conn, + rw: conn, initialParameters: initialParameters, backendKey: backendKey, @@ -59,6 +62,11 @@ func (T *Conn) GetConn() fed.Conn { return T.conn } +// GetReadWriter is the exact same as GetConn but bypasses the runtime.convI2I +func (T *Conn) GetReadWriter() fed.ReadWriter { + return T.rw +} + func (T *Conn) GetInitialParameters() map[strutil.CIString]string { return T.initialParameters } diff --git a/lib/gat/pool/flow.go b/lib/gat/pool/flow.go index d051e5a9..677d0ddd 100644 --- a/lib/gat/pool/flow.go +++ b/lib/gat/pool/flow.go @@ -22,7 +22,7 @@ func Pair(options Options, client *Client, server *Server) (clientErr, serverErr switch options.ParameterStatusSync { case ParameterStatusSyncDynamic: - clientErr, serverErr = ps.Sync(options.TrackedParameters, client.GetConn(), client.GetPS(), server.GetConn(), server.GetPS()) + clientErr, serverErr = ps.Sync(options.TrackedParameters, client.GetReadWriter(), client.GetPS(), server.GetReadWriter(), server.GetPS()) case ParameterStatusSyncInitial: clientErr, serverErr = SyncInitialParameters(options, client, server) } @@ -32,7 +32,7 @@ func Pair(options Options, client *Client, server *Server) (clientErr, serverErr } if options.ExtendedQuerySync { - serverErr = eqp.Sync(client.GetEQP(), server.GetConn(), server.GetEQP()) + serverErr = eqp.Sync(client.GetEQP(), server.GetReadWriter(), server.GetEQP()) } return @@ -65,7 +65,7 @@ func SyncInitialParameters(options Options, client *Client, server *Server) (cli continue } - serverErr = backends.SetParameter(new(backends.Context), server.GetConn(), key, value) + serverErr = backends.SetParameter(new(backends.Context), server.GetReadWriter(), key, value) if serverErr != nil { return } diff --git a/lib/gat/pool/pool.go b/lib/gat/pool/pool.go index 2d0e6c5f..77bcb2e0 100644 --- a/lib/gat/pool/pool.go +++ b/lib/gat/pool/pool.go @@ -283,7 +283,7 @@ func (T *Pool) releaseServer(server *Server) { server.SetState(metrics.ConnStateRunningResetQuery, uuid.Nil) if T.options.ServerResetQuery != "" { - err := backends.QueryString(new(backends.Context), server.GetConn(), T.options.ServerResetQuery) + err := backends.QueryString(new(backends.Context), server.GetReadWriter(), T.options.ServerResetQuery) if err != nil { T.removeServer(server) return @@ -381,7 +381,7 @@ func (T *Pool) serve(client *Client, initialize bool) error { err, serverErr = Pair(T.options, client, server) } if err == nil && serverErr == nil { - err, serverErr = bouncers.Bounce(client.GetConn(), server.GetConn(), packet) + err, serverErr = bouncers.Bounce(client.GetReadWriter(), server.GetReadWriter(), packet) } if serverErr != nil { T.removeServer(server) diff --git a/lib/util/strutil/escape.go b/lib/util/strutil/escape.go index 0a32cf6f..50eabdf7 100644 --- a/lib/util/strutil/escape.go +++ b/lib/util/strutil/escape.go @@ -5,28 +5,28 @@ import ( "unicode/utf8" ) -func Escape(str, sequence string) string { - var b strings.Builder - b.WriteString(sequence) - for len(str) > 0 { - if strings.HasPrefix(str, sequence) { - b.WriteByte('\\') - b.WriteString(sequence) - str = str[len(sequence):] - continue - } - if strings.HasPrefix(str, "\\") { - b.WriteString("\\\\") - str = str[1:] - continue +func Escape(str string, char rune) string { + size := 0 + escape := false + // check if it has any bad characters + for _, r := range str { + size += utf8.RuneLen(r) + if r == char || r == '\\' { + size += 1 + escape = true } - r, size := utf8.DecodeRuneInString(str) - if r == utf8.RuneError { - return "" + } + if !escape { + return str + } + + var b strings.Builder + b.Grow(size) + for _, r := range str { + if char == r || r == '\\' { + b.WriteRune('\\') } b.WriteRune(r) - str = str[size:] } - b.WriteString(sequence) return b.String() } -- GitLab