From b37fe718b407656f6a84816f2b575a0961b60a36 Mon Sep 17 00:00:00 2001 From: Garet Halliday <me@garet.holiday> Date: Wed, 16 Aug 2023 20:07:57 -0500 Subject: [PATCH] afgfdg --- lib/bouncer/backends/v0/query.go | 4 ++++ lib/middleware/middlewares/ps/server.go | 2 ++ lib/middleware/middlewares/ps/sync.go | 27 ++++++++++++++++--------- lib/util/encoding/ini/unmarshal.go | 2 +- lib/util/strutil/cistring.go | 5 +++-- lib/util/strutil/escape.go | 2 ++ pgbouncer.ini | 1 + 7 files changed, 31 insertions(+), 12 deletions(-) diff --git a/lib/bouncer/backends/v0/query.go b/lib/bouncer/backends/v0/query.go index 7abcdc5f..fe589a5a 100644 --- a/lib/bouncer/backends/v0/query.go +++ b/lib/bouncer/backends/v0/query.go @@ -132,6 +132,10 @@ func SetParameter(ctx *Context, server zap.ReadWriter, name strutil.CIString, va return QueryString(ctx, server, `SET `+strutil.Escape(name.String(), `"`)+` = `+strutil.Escape(value, `'`)) } +func ResetParameter(ctx *Context, server zap.ReadWriter, name strutil.CIString) error { + return QueryString(ctx, server, `RESET `+strutil.Escape(name.String(), `"`)) +} + func FunctionCall(ctx *Context, server zap.ReadWriter, packet *zap.Packet) error { if err := server.Write(packet); err != nil { return err diff --git a/lib/middleware/middlewares/ps/server.go b/lib/middleware/middlewares/ps/server.go index 565c2915..37cd60c8 100644 --- a/lib/middleware/middlewares/ps/server.go +++ b/lib/middleware/middlewares/ps/server.go @@ -2,6 +2,7 @@ package ps import ( "errors" + "log" "pggat2/lib/middleware" "pggat2/lib/util/strutil" @@ -29,6 +30,7 @@ func (T *Server) Read(_ middleware.Context, in *zap.Packet) error { return errors.New("bad packet format") } ikey := strutil.MakeCIString(key) + log.Printf("backend updated %s = %s", ikey.String(), value) if T.parameters == nil { T.parameters = make(map[strutil.CIString]string) } diff --git a/lib/middleware/middlewares/ps/sync.go b/lib/middleware/middlewares/ps/sync.go index c6686af0..e7ae97aa 100644 --- a/lib/middleware/middlewares/ps/sync.go +++ b/lib/middleware/middlewares/ps/sync.go @@ -1,6 +1,8 @@ package ps import ( + "log" + "pggat2/lib/bouncer/backends/v0" "pggat2/lib/util/slices" "pggat2/lib/util/strutil" @@ -21,20 +23,27 @@ func sync(tracking []strutil.CIString, clientPackets *zap.Packets, c *Client, se return } - if hasValue && slices.Contains(tracking, name) { - if err := backends.SetParameter(&backends.Context{}, server, name, value); err != nil { - panic(err) // TODO(garet) - } - if s.parameters == nil { - s.parameters = make(map[strutil.CIString]string) + if slices.Contains(tracking, name) { + if hasValue { + log.Printf("backend set %s = %s", name.String(), value) + if err := backends.SetParameter(&backends.Context{}, server, name, value); err != nil { + panic(err) // TODO(garet) + } + if s.parameters == nil { + s.parameters = make(map[strutil.CIString]string) + } + s.parameters[name] = value + } else { + log.Printf("backend reset %s", name.String()) + if err := backends.ResetParameter(&backends.Context{}, server, name); err != nil { + panic(err) // TODO(garet) + } + delete(s.parameters, name) } } else if hasExpected { pkt := zap.NewPacket() packets.WriteParameterStatus(pkt, name.String(), expected) clientPackets.Append(pkt) - if c.parameters == nil { - c.parameters = make(map[strutil.CIString]string) - } } } diff --git a/lib/util/encoding/ini/unmarshal.go b/lib/util/encoding/ini/unmarshal.go index 9e5a875a..b41956d6 100644 --- a/lib/util/encoding/ini/unmarshal.go +++ b/lib/util/encoding/ini/unmarshal.go @@ -134,7 +134,7 @@ outer: return nil case reflect.Slice: items := bytes.Split(value, []byte{','}) - slice := reflect.MakeSlice(rt.Elem(), len(items), len(items)) + slice := reflect.MakeSlice(rt, len(items), len(items)) for i, item := range items { if err := set(slice.Index(i), bytes.TrimSpace(item)); err != nil { return err diff --git a/lib/util/strutil/cistring.go b/lib/util/strutil/cistring.go index 05dc7b83..096bb366 100644 --- a/lib/util/strutil/cistring.go +++ b/lib/util/strutil/cistring.go @@ -1,6 +1,7 @@ package strutil import ( + "bytes" "encoding/json" "strings" @@ -37,8 +38,8 @@ func (T *CIString) UnmarshalJSON(bytes []byte) error { var _ json.Marshaler = (*CIString)(nil) var _ json.Unmarshaler = (*CIString)(nil) -func (T *CIString) UnmarshalINI(bytes []byte) error { - T.value = strings.ToLower(string(bytes)) +func (T *CIString) UnmarshalINI(b []byte) error { + T.value = string(bytes.ToLower(b)) return nil } diff --git a/lib/util/strutil/escape.go b/lib/util/strutil/escape.go index 42928856..0a32cf6f 100644 --- a/lib/util/strutil/escape.go +++ b/lib/util/strutil/escape.go @@ -7,6 +7,7 @@ import ( func Escape(str, sequence string) string { var b strings.Builder + b.WriteString(sequence) for len(str) > 0 { if strings.HasPrefix(str, sequence) { b.WriteByte('\\') @@ -26,5 +27,6 @@ func Escape(str, sequence string) string { b.WriteRune(r) str = str[size:] } + b.WriteString(sequence) return b.String() } diff --git a/pgbouncer.ini b/pgbouncer.ini index 6dc48900..f6f8ec66 100644 --- a/pgbouncer.ini +++ b/pgbouncer.ini @@ -2,6 +2,7 @@ pool_mode = transaction auth_file = userlist.txt listen_addr = * +track_extra_parameters = IntervalStyle, session_authorization, default_transaction_read_only [users] postgres = -- GitLab