From b37fe718b407656f6a84816f2b575a0961b60a36 Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Wed, 16 Aug 2023 20:07:57 -0500
Subject: [PATCH] afgfdg

---
 lib/bouncer/backends/v0/query.go        |  4 ++++
 lib/middleware/middlewares/ps/server.go |  2 ++
 lib/middleware/middlewares/ps/sync.go   | 27 ++++++++++++++++---------
 lib/util/encoding/ini/unmarshal.go      |  2 +-
 lib/util/strutil/cistring.go            |  5 +++--
 lib/util/strutil/escape.go              |  2 ++
 pgbouncer.ini                           |  1 +
 7 files changed, 31 insertions(+), 12 deletions(-)

diff --git a/lib/bouncer/backends/v0/query.go b/lib/bouncer/backends/v0/query.go
index 7abcdc5f..fe589a5a 100644
--- a/lib/bouncer/backends/v0/query.go
+++ b/lib/bouncer/backends/v0/query.go
@@ -132,6 +132,10 @@ func SetParameter(ctx *Context, server zap.ReadWriter, name strutil.CIString, va
 	return QueryString(ctx, server, `SET `+strutil.Escape(name.String(), `"`)+` = `+strutil.Escape(value, `'`))
 }
 
+func ResetParameter(ctx *Context, server zap.ReadWriter, name strutil.CIString) error {
+	return QueryString(ctx, server, `RESET `+strutil.Escape(name.String(), `"`))
+}
+
 func FunctionCall(ctx *Context, server zap.ReadWriter, packet *zap.Packet) error {
 	if err := server.Write(packet); err != nil {
 		return err
diff --git a/lib/middleware/middlewares/ps/server.go b/lib/middleware/middlewares/ps/server.go
index 565c2915..37cd60c8 100644
--- a/lib/middleware/middlewares/ps/server.go
+++ b/lib/middleware/middlewares/ps/server.go
@@ -2,6 +2,7 @@ package ps
 
 import (
 	"errors"
+	"log"
 
 	"pggat2/lib/middleware"
 	"pggat2/lib/util/strutil"
@@ -29,6 +30,7 @@ func (T *Server) Read(_ middleware.Context, in *zap.Packet) error {
 			return errors.New("bad packet format")
 		}
 		ikey := strutil.MakeCIString(key)
+		log.Printf("backend updated %s = %s", ikey.String(), value)
 		if T.parameters == nil {
 			T.parameters = make(map[strutil.CIString]string)
 		}
diff --git a/lib/middleware/middlewares/ps/sync.go b/lib/middleware/middlewares/ps/sync.go
index c6686af0..e7ae97aa 100644
--- a/lib/middleware/middlewares/ps/sync.go
+++ b/lib/middleware/middlewares/ps/sync.go
@@ -1,6 +1,8 @@
 package ps
 
 import (
+	"log"
+
 	"pggat2/lib/bouncer/backends/v0"
 	"pggat2/lib/util/slices"
 	"pggat2/lib/util/strutil"
@@ -21,20 +23,27 @@ func sync(tracking []strutil.CIString, clientPackets *zap.Packets, c *Client, se
 		return
 	}
 
-	if hasValue && slices.Contains(tracking, name) {
-		if err := backends.SetParameter(&backends.Context{}, server, name, value); err != nil {
-			panic(err) // TODO(garet)
-		}
-		if s.parameters == nil {
-			s.parameters = make(map[strutil.CIString]string)
+	if slices.Contains(tracking, name) {
+		if hasValue {
+			log.Printf("backend set %s = %s", name.String(), value)
+			if err := backends.SetParameter(&backends.Context{}, server, name, value); err != nil {
+				panic(err) // TODO(garet)
+			}
+			if s.parameters == nil {
+				s.parameters = make(map[strutil.CIString]string)
+			}
+			s.parameters[name] = value
+		} else {
+			log.Printf("backend reset %s", name.String())
+			if err := backends.ResetParameter(&backends.Context{}, server, name); err != nil {
+				panic(err) // TODO(garet)
+			}
+			delete(s.parameters, name)
 		}
 	} else if hasExpected {
 		pkt := zap.NewPacket()
 		packets.WriteParameterStatus(pkt, name.String(), expected)
 		clientPackets.Append(pkt)
-		if c.parameters == nil {
-			c.parameters = make(map[strutil.CIString]string)
-		}
 	}
 }
 
diff --git a/lib/util/encoding/ini/unmarshal.go b/lib/util/encoding/ini/unmarshal.go
index 9e5a875a..b41956d6 100644
--- a/lib/util/encoding/ini/unmarshal.go
+++ b/lib/util/encoding/ini/unmarshal.go
@@ -134,7 +134,7 @@ outer:
 		return nil
 	case reflect.Slice:
 		items := bytes.Split(value, []byte{','})
-		slice := reflect.MakeSlice(rt.Elem(), len(items), len(items))
+		slice := reflect.MakeSlice(rt, len(items), len(items))
 		for i, item := range items {
 			if err := set(slice.Index(i), bytes.TrimSpace(item)); err != nil {
 				return err
diff --git a/lib/util/strutil/cistring.go b/lib/util/strutil/cistring.go
index 05dc7b83..096bb366 100644
--- a/lib/util/strutil/cistring.go
+++ b/lib/util/strutil/cistring.go
@@ -1,6 +1,7 @@
 package strutil
 
 import (
+	"bytes"
 	"encoding/json"
 	"strings"
 
@@ -37,8 +38,8 @@ func (T *CIString) UnmarshalJSON(bytes []byte) error {
 var _ json.Marshaler = (*CIString)(nil)
 var _ json.Unmarshaler = (*CIString)(nil)
 
-func (T *CIString) UnmarshalINI(bytes []byte) error {
-	T.value = strings.ToLower(string(bytes))
+func (T *CIString) UnmarshalINI(b []byte) error {
+	T.value = string(bytes.ToLower(b))
 	return nil
 }
 
diff --git a/lib/util/strutil/escape.go b/lib/util/strutil/escape.go
index 42928856..0a32cf6f 100644
--- a/lib/util/strutil/escape.go
+++ b/lib/util/strutil/escape.go
@@ -7,6 +7,7 @@ import (
 
 func Escape(str, sequence string) string {
 	var b strings.Builder
+	b.WriteString(sequence)
 	for len(str) > 0 {
 		if strings.HasPrefix(str, sequence) {
 			b.WriteByte('\\')
@@ -26,5 +27,6 @@ func Escape(str, sequence string) string {
 		b.WriteRune(r)
 		str = str[size:]
 	}
+	b.WriteString(sequence)
 	return b.String()
 }
diff --git a/pgbouncer.ini b/pgbouncer.ini
index 6dc48900..f6f8ec66 100644
--- a/pgbouncer.ini
+++ b/pgbouncer.ini
@@ -2,6 +2,7 @@
 pool_mode = transaction
 auth_file = userlist.txt
 listen_addr = *
+track_extra_parameters = IntervalStyle, session_authorization, default_transaction_read_only
 
 [users]
 postgres =
-- 
GitLab