From 34bfd3904fc9936ed49915d5e0b631f19c97939c Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Thu, 14 Sep 2023 19:56:39 -0500
Subject: [PATCH] small optimizations

---
 lib/bouncer/backends/v0/query.go |  8 ++++++-
 lib/gat/pool/conn.go             |  8 +++++++
 lib/gat/pool/flow.go             |  6 ++---
 lib/gat/pool/pool.go             |  4 ++--
 lib/util/strutil/escape.go       | 38 ++++++++++++++++----------------
 5 files changed, 39 insertions(+), 25 deletions(-)

diff --git a/lib/bouncer/backends/v0/query.go b/lib/bouncer/backends/v0/query.go
index 44e57a79..b2ac4510 100644
--- a/lib/bouncer/backends/v0/query.go
+++ b/lib/bouncer/backends/v0/query.go
@@ -1,6 +1,8 @@
 package backends
 
 import (
+	"fmt"
+
 	"pggat/lib/fed"
 	packets "pggat/lib/fed/packets/v3.0"
 	"pggat/lib/util/strutil"
@@ -106,7 +108,11 @@ func QueryString(ctx *Context, server fed.ReadWriter, query string) error {
 }
 
 func SetParameter(ctx *Context, server fed.ReadWriter, name strutil.CIString, value string) error {
-	return QueryString(ctx, server, `SET `+strutil.Escape(name.String(), `"`)+` = `+strutil.Escape(value, `'`))
+	return QueryString(
+		ctx,
+		server,
+		fmt.Sprintf(`SET "%s" = '%s'`, strutil.Escape(name.String(), '"'), strutil.Escape(value, '\'')),
+	)
 }
 
 func FunctionCall(ctx *Context, server fed.ReadWriter, packet fed.Packet) error {
diff --git a/lib/gat/pool/conn.go b/lib/gat/pool/conn.go
index bfd5284d..1a9e6b0f 100644
--- a/lib/gat/pool/conn.go
+++ b/lib/gat/pool/conn.go
@@ -16,6 +16,8 @@ type Conn struct {
 	id uuid.UUID
 
 	conn fed.Conn
+	// please someone fix runtime.convI2I
+	rw fed.ReadWriter
 
 	initialParameters map[strutil.CIString]string
 	backendKey        [8]byte
@@ -44,6 +46,7 @@ func MakeConn(
 	return Conn{
 		id:                id,
 		conn:              conn,
+		rw:                conn,
 		initialParameters: initialParameters,
 		backendKey:        backendKey,
 
@@ -59,6 +62,11 @@ func (T *Conn) GetConn() fed.Conn {
 	return T.conn
 }
 
+// GetReadWriter is the exact same as GetConn but bypasses the runtime.convI2I
+func (T *Conn) GetReadWriter() fed.ReadWriter {
+	return T.rw
+}
+
 func (T *Conn) GetInitialParameters() map[strutil.CIString]string {
 	return T.initialParameters
 }
diff --git a/lib/gat/pool/flow.go b/lib/gat/pool/flow.go
index d051e5a9..677d0ddd 100644
--- a/lib/gat/pool/flow.go
+++ b/lib/gat/pool/flow.go
@@ -22,7 +22,7 @@ func Pair(options Options, client *Client, server *Server) (clientErr, serverErr
 
 	switch options.ParameterStatusSync {
 	case ParameterStatusSyncDynamic:
-		clientErr, serverErr = ps.Sync(options.TrackedParameters, client.GetConn(), client.GetPS(), server.GetConn(), server.GetPS())
+		clientErr, serverErr = ps.Sync(options.TrackedParameters, client.GetReadWriter(), client.GetPS(), server.GetReadWriter(), server.GetPS())
 	case ParameterStatusSyncInitial:
 		clientErr, serverErr = SyncInitialParameters(options, client, server)
 	}
@@ -32,7 +32,7 @@ func Pair(options Options, client *Client, server *Server) (clientErr, serverErr
 	}
 
 	if options.ExtendedQuerySync {
-		serverErr = eqp.Sync(client.GetEQP(), server.GetConn(), server.GetEQP())
+		serverErr = eqp.Sync(client.GetEQP(), server.GetReadWriter(), server.GetEQP())
 	}
 
 	return
@@ -65,7 +65,7 @@ func SyncInitialParameters(options Options, client *Client, server *Server) (cli
 			continue
 		}
 
-		serverErr = backends.SetParameter(new(backends.Context), server.GetConn(), key, value)
+		serverErr = backends.SetParameter(new(backends.Context), server.GetReadWriter(), key, value)
 		if serverErr != nil {
 			return
 		}
diff --git a/lib/gat/pool/pool.go b/lib/gat/pool/pool.go
index 2d0e6c5f..77bcb2e0 100644
--- a/lib/gat/pool/pool.go
+++ b/lib/gat/pool/pool.go
@@ -283,7 +283,7 @@ func (T *Pool) releaseServer(server *Server) {
 	server.SetState(metrics.ConnStateRunningResetQuery, uuid.Nil)
 
 	if T.options.ServerResetQuery != "" {
-		err := backends.QueryString(new(backends.Context), server.GetConn(), T.options.ServerResetQuery)
+		err := backends.QueryString(new(backends.Context), server.GetReadWriter(), T.options.ServerResetQuery)
 		if err != nil {
 			T.removeServer(server)
 			return
@@ -381,7 +381,7 @@ func (T *Pool) serve(client *Client, initialize bool) error {
 			err, serverErr = Pair(T.options, client, server)
 		}
 		if err == nil && serverErr == nil {
-			err, serverErr = bouncers.Bounce(client.GetConn(), server.GetConn(), packet)
+			err, serverErr = bouncers.Bounce(client.GetReadWriter(), server.GetReadWriter(), packet)
 		}
 		if serverErr != nil {
 			T.removeServer(server)
diff --git a/lib/util/strutil/escape.go b/lib/util/strutil/escape.go
index 0a32cf6f..50eabdf7 100644
--- a/lib/util/strutil/escape.go
+++ b/lib/util/strutil/escape.go
@@ -5,28 +5,28 @@ import (
 	"unicode/utf8"
 )
 
-func Escape(str, sequence string) string {
-	var b strings.Builder
-	b.WriteString(sequence)
-	for len(str) > 0 {
-		if strings.HasPrefix(str, sequence) {
-			b.WriteByte('\\')
-			b.WriteString(sequence)
-			str = str[len(sequence):]
-			continue
-		}
-		if strings.HasPrefix(str, "\\") {
-			b.WriteString("\\\\")
-			str = str[1:]
-			continue
+func Escape(str string, char rune) string {
+	size := 0
+	escape := false
+	// check if it has any bad characters
+	for _, r := range str {
+		size += utf8.RuneLen(r)
+		if r == char || r == '\\' {
+			size += 1
+			escape = true
 		}
-		r, size := utf8.DecodeRuneInString(str)
-		if r == utf8.RuneError {
-			return ""
+	}
+	if !escape {
+		return str
+	}
+
+	var b strings.Builder
+	b.Grow(size)
+	for _, r := range str {
+		if char == r || r == '\\' {
+			b.WriteRune('\\')
 		}
 		b.WriteRune(r)
-		str = str[size:]
 	}
-	b.WriteString(sequence)
 	return b.String()
 }
-- 
GitLab