diff --git a/cmd/cgat/main.go b/cmd/cgat/main.go index 7d8fd084085e803d31fd94eb64ddb9ff23e4c0f7..81eb84d5b3c15523388006947f5f740525d74482 100644 --- a/cmd/cgat/main.go +++ b/cmd/cgat/main.go @@ -6,7 +6,6 @@ import ( _ "net/http/pprof" "os" - "pggat2/lib/gat" "pggat2/lib/gat/configs/pgbouncer" ) @@ -26,9 +25,7 @@ func main() { panic(err) } - pooler := gat.NewPooler() - - err = conf.ListenAndServe(pooler) + err = conf.ListenAndServe() if err != nil { panic(err) } diff --git a/lib/bouncer/backends/v0/accept.go b/lib/bouncer/backends/v0/accept.go index 59a6d34a4424b90c4650dc57f55f8c71063a8026..729d65501545f61d4d2733c8fcc6133f483af116 100644 --- a/lib/bouncer/backends/v0/accept.go +++ b/lib/bouncer/backends/v0/accept.go @@ -4,6 +4,7 @@ import ( "errors" "pggat2/lib/auth" + "pggat2/lib/util/strutil" "pggat2/lib/zap" packets "pggat2/lib/zap/packets/v3.0" ) @@ -195,7 +196,7 @@ func startup0(server zap.ReadWriter, creds auth.Credentials) (done bool, err err } } -func startup1(server zap.ReadWriter, parameterStatus map[string]string) (done bool, err error) { +func startup1(server zap.ReadWriter, parameterStatus map[strutil.CIString]string) (done bool, err error) { packet := zap.NewPacket() defer packet.Done() err = server.Read(packet) @@ -220,7 +221,8 @@ func startup1(server zap.ReadWriter, parameterStatus map[string]string) (done bo err = ErrBadFormat return } - parameterStatus[key] = value + ikey := strutil.MakeCIString(key) + parameterStatus[ikey] = value return false, nil case packets.ReadyForQuery: return true, nil @@ -241,7 +243,7 @@ func startup1(server zap.ReadWriter, parameterStatus map[string]string) (done bo } } -func Accept(server zap.ReadWriter, creds auth.Credentials, database string, startupParameters map[string]string) error { +func Accept(server zap.ReadWriter, creds auth.Credentials, database string, startupParameters map[strutil.CIString]string) error { if database == "" { database = creds.GetUsername() } @@ -255,7 +257,7 @@ func Accept(server zap.ReadWriter, creds auth.Credentials, database string, star packet.WriteString("database") packet.WriteString(database) for key, value := range startupParameters { - packet.WriteString(key) + packet.WriteString(key.String()) packet.WriteString(value) } packet.WriteString("") diff --git a/lib/bouncer/frontends/v0/accept.go b/lib/bouncer/frontends/v0/accept.go index 7dcafbac503dd9a3df0d99afaa766ae94574fde9..1784a5a3900c7d2350b03ca3ffc682e96875bfba 100644 --- a/lib/bouncer/frontends/v0/accept.go +++ b/lib/bouncer/frontends/v0/accept.go @@ -3,15 +3,22 @@ package frontends import ( "crypto/rand" "errors" + "fmt" "strings" "pggat2/lib/auth" "pggat2/lib/perror" + "pggat2/lib/util/slices" + "pggat2/lib/util/strutil" "pggat2/lib/zap" "pggat2/lib/zap/packets/v3.0" ) -func startup0(client zap.ReadWriter, startupParameters map[string]string) (user, database string, done bool, err perror.Error) { +func startup0( + client zap.ReadWriter, + allowedStartupParameters []strutil.CIString, + startupParameters map[strutil.CIString]string, +) (user, database string, done bool, err perror.Error) { packet := zap.NewUntypedPacket() defer packet.Done() err = perror.Wrap(client.ReadUntyped(packet)) @@ -109,7 +116,17 @@ func startup0(client zap.ReadWriter, startupParameters map[string]string) (user, return } - startupParameters[key] = value + ikey := strutil.MakeCIString(key) + + if !slices.Contains(allowedStartupParameters, ikey) { + err = perror.New( + perror.FATAL, + perror.FeatureNotSupported, + fmt.Sprintf(`Startup parameter "%s" is not allowed`, key), + ) + return + } + startupParameters[ikey] = value default: err = perror.New( perror.FATAL, @@ -131,7 +148,18 @@ func startup0(client zap.ReadWriter, startupParameters map[string]string) (user, // we don't support protocol extensions at the moment unsupportedOptions = append(unsupportedOptions, key) } else { - startupParameters[key] = value + ikey := strutil.MakeCIString(key) + + if !slices.Contains(allowedStartupParameters, ikey) { + err = perror.New( + perror.FATAL, + perror.FeatureNotSupported, + fmt.Sprintf(`Startup parameter "%s" is not allowed`, key), + ) + return + } + + startupParameters[ikey] = value } } } @@ -269,12 +297,16 @@ func updateParameter(pkts *zap.Packets, name, value string) { pkts.Append(packet) } -func accept(client zap.ReadWriter, getCredentials func(user, database string) (auth.Credentials, bool)) (user string, database string, startupParameters map[string]string, err perror.Error) { - startupParameters = make(map[string]string) +func accept( + client zap.ReadWriter, + getCredentials func(user, database string) (auth.Credentials, bool), + allowedStartupParameters []strutil.CIString, +) (user string, database string, startupParameters map[strutil.CIString]string, err perror.Error) { + startupParameters = make(map[strutil.CIString]string) for { var done bool - user, database, done, err = startup0(client, startupParameters) + user, database, done, err = startup0(client, allowedStartupParameters, startupParameters) if err != nil { return } @@ -350,8 +382,12 @@ func fail(client zap.ReadWriter, err perror.Error) { _ = client.Write(packet) } -func Accept(client zap.ReadWriter, getCredentials func(user, database string) (auth.Credentials, bool)) (user, database string, startupParameters map[string]string, err perror.Error) { - user, database, startupParameters, err = accept(client, getCredentials) +func Accept( + client zap.ReadWriter, + getCredentials func(user, database string) (auth.Credentials, bool), + allowedStartupParameters []strutil.CIString, +) (user, database string, startupParameters map[strutil.CIString]string, err perror.Error) { + user, database, startupParameters, err = accept(client, getCredentials, allowedStartupParameters) if err != nil { fail(client, err) } diff --git a/lib/gat/configs/pgbouncer/config.go b/lib/gat/configs/pgbouncer/config.go index 3c77e7c57231cb4dd31880a0f127de7ea9bf80dc..1bfc0ff2a23ea81699018295c231f0701b7da702 100644 --- a/lib/gat/configs/pgbouncer/config.go +++ b/lib/gat/configs/pgbouncer/config.go @@ -15,8 +15,9 @@ import ( "pggat2/lib/gat" "pggat2/lib/gat/pools/session" "pggat2/lib/gat/pools/transaction" - ini2 "pggat2/lib/util/encoding/ini" + "pggat2/lib/util/encoding/ini" "pggat2/lib/util/encoding/userlist" + "pggat2/lib/util/strutil" ) type PoolMode string @@ -160,19 +161,19 @@ type PgBouncer struct { } type Database struct { - DBName string `ini:"dbname"` - Host string `ini:"host"` - Port int `ini:"port"` - User string `ini:"user"` - Password string `ini:"password"` - AuthUser string `ini:"auth_user"` - PoolSize int `ini:"pool_size"` - MinPoolSize int `ini:"min_pool_size"` - ReservePool int `ini:"reserve_pool"` - ConnectQuery string `ini:"connect_query"` - PoolMode PoolMode `ini:"pool_mode"` - MaxDBConnections int `ini:"max_db_connections"` - StartupParameters map[string]string `ini:"*"` + DBName string `ini:"dbname"` + Host string `ini:"host"` + Port int `ini:"port"` + User string `ini:"user"` + Password string `ini:"password"` + AuthUser string `ini:"auth_user"` + PoolSize int `ini:"pool_size"` + MinPoolSize int `ini:"min_pool_size"` + ReservePool int `ini:"reserve_pool"` + ConnectQuery string `ini:"connect_query"` + PoolMode PoolMode `ini:"pool_mode"` + MaxDBConnections int `ini:"max_db_connections"` + StartupParameters map[strutil.CIString]string `ini:"*"` } type User struct { @@ -253,17 +254,27 @@ var Default = Config{ } func Load(config string) (Config, error) { - conf, err := ini2.ReadFile(config) + conf, err := ini.ReadFile(config) if err != nil { return Config{}, err } var c = Default - err = ini2.Unmarshal(conf, &c) + err = ini.Unmarshal(conf, &c) return c, err } -func (T *Config) ListenAndServe(pooler *gat.Pooler) error { +func (T *Config) ListenAndServe() error { + pooler := gat.NewPooler(gat.PoolerConfig{ + AllowedStartupParameters: []strutil.CIString{ + strutil.MakeCIString("intervalstyle"), + strutil.MakeCIString("application_name"), + strutil.MakeCIString("client_encoding"), + strutil.MakeCIString("datestyle"), + strutil.MakeCIString("timezone"), + }, + }) + var authFile map[string]string if T.PgBouncer.AuthFile != "" { file, err := os.ReadFile(T.PgBouncer.AuthFile) @@ -286,19 +297,22 @@ func (T *Config) ListenAndServe(pooler *gat.Pooler) error { pooler.AddUser(name, u) for dbname, db := range T.Databases { + // filter out dbs specific to users if db.User != "" && db.User != name { continue } + // override dbname if db.DBName != "" { dbname = db.DBName } + // override poolmode var poolMode PoolMode - if user.PoolMode != "" { - poolMode = user.PoolMode - } else if db.PoolMode != "" { + if db.PoolMode != "" { poolMode = db.PoolMode + } else if user.PoolMode != "" { + poolMode = user.PoolMode } else { poolMode = T.PgBouncer.PoolMode } @@ -306,14 +320,18 @@ func (T *Config) ListenAndServe(pooler *gat.Pooler) error { var raw gat.RawPool switch poolMode { case PoolModeSession: - raw = session.NewPool(T.PgBouncer.ServerRoundRobin != 0) + raw = session.NewPool(session.Config{ + RoundRobin: T.PgBouncer.ServerRoundRobin != 0, + }) case PoolModeTransaction: raw = transaction.NewPool() default: return errors.New("unsupported pool mode") } - p := gat.NewPool(raw, time.Duration(T.PgBouncer.ServerIdleTimeout*float64(time.Second))) + p := gat.NewPool(raw, gat.PoolConfig{ + IdleTimeout: time.Duration(T.PgBouncer.ServerIdleTimeout * float64(time.Second)), + }) u.AddPool(dbname, p) if db.Host == "" { diff --git a/lib/gat/pool.go b/lib/gat/pool.go index faed75e1e7e728ab30e2454aa402b8fd01791da1..f198eb39cb965bfecf5fc99e68405e6653cf5f26 100644 --- a/lib/gat/pool.go +++ b/lib/gat/pool.go @@ -9,6 +9,7 @@ import ( "pggat2/lib/util/maps" "pggat2/lib/util/maths" + "pggat2/lib/util/strutil" "pggat2/lib/zap" ) @@ -17,9 +18,9 @@ type Context struct { } type RawPool interface { - Serve(ctx *Context, client zap.ReadWriter, startupParameters map[string]string) + Serve(ctx *Context, client zap.ReadWriter, startupParameters map[strutil.CIString]string) - AddServer(server zap.ReadWriter, startupParameters map[string]string) uuid.UUID + AddServer(server zap.ReadWriter, startupParameters map[strutil.CIString]string) uuid.UUID GetServer(id uuid.UUID) zap.ReadWriter RemoveServer(id uuid.UUID) zap.ReadWriter @@ -36,15 +37,23 @@ type PoolRecipe struct { } type Pool struct { + config PoolConfig + recipes maps.RWLocked[string, *PoolRecipe] ctx Context raw RawPool } -func NewPool(raw RawPool, idleTimeout time.Duration) *Pool { +type PoolConfig struct { + // IdleTimeout determines how long idle servers are kept in the pool + IdleTimeout time.Duration +} + +func NewPool(raw RawPool, config PoolConfig) *Pool { onWait := make(chan struct{}) pool := &Pool{ + config: config, ctx: Context{ OnWait: onWait, }, @@ -57,14 +66,14 @@ func NewPool(raw RawPool, idleTimeout time.Duration) *Pool { } }() - if idleTimeout != 0 { + if config.IdleTimeout != 0 { go func() { for { var wait time.Duration now := time.Now() idle := pool.IdleSince() - for now.Sub(idle) > idleTimeout { + for now.Sub(idle) > config.IdleTimeout { if idle == (time.Time{}) { break } @@ -73,9 +82,9 @@ func NewPool(raw RawPool, idleTimeout time.Duration) *Pool { } if idle == (time.Time{}) { - wait = idleTimeout + wait = config.IdleTimeout } else { - wait = now.Sub(idle.Add(idleTimeout)) + wait = now.Sub(idle.Add(config.IdleTimeout)) } time.Sleep(wait) @@ -185,6 +194,6 @@ func (T *Pool) RemoveRecipe(name string) { } } -func (T *Pool) Serve(client zap.ReadWriter, startupParameters map[string]string) { +func (T *Pool) Serve(client zap.ReadWriter, startupParameters map[strutil.CIString]string) { T.raw.Serve(&T.ctx, client, startupParameters) } diff --git a/lib/gat/pooler.go b/lib/gat/pooler.go index cf04f0bf9f6875a2da1e00f9997d95716d8f4a49..78f229d0780904efd7d24111dbab88b8c58e3415 100644 --- a/lib/gat/pooler.go +++ b/lib/gat/pooler.go @@ -8,15 +8,24 @@ import ( "pggat2/lib/middleware/interceptor" "pggat2/lib/middleware/middlewares/unterminate" "pggat2/lib/util/maps" + "pggat2/lib/util/strutil" "pggat2/lib/zap" ) type Pooler struct { + config PoolerConfig + users maps.RWLocked[string, *User] } -func NewPooler() *Pooler { - return &Pooler{} +type PoolerConfig struct { + AllowedStartupParameters []strutil.CIString +} + +func NewPooler(config PoolerConfig) *Pooler { + return &Pooler{ + config: config, + } } func (T *Pooler) AddUser(name string, user *User) { @@ -38,17 +47,21 @@ func (T *Pooler) Serve(client zap.ReadWriter) { unterminate.Unterminate, ) - username, database, startupParameters, err := frontends.Accept(client, func(username, database string) (auth.Credentials, bool) { - user := T.GetUser(username) - if user == nil { - return nil, false - } - pool := user.GetPool(database) - if pool == nil { - return nil, false - } - return user.GetCredentials(), true - }) + username, database, startupParameters, err := frontends.Accept( + client, + func(username, database string) (auth.Credentials, bool) { + user := T.GetUser(username) + if user == nil { + return nil, false + } + pool := user.GetPool(database) + if pool == nil { + return nil, false + } + return user.GetCredentials(), true + }, + T.config.AllowedStartupParameters, + ) if err != nil { _ = client.Close() return diff --git a/lib/gat/pools/session/config.go b/lib/gat/pools/session/config.go new file mode 100644 index 0000000000000000000000000000000000000000..0ffe00ffe94e83cafe8f6321c5ed22a1243f94cd --- /dev/null +++ b/lib/gat/pools/session/config.go @@ -0,0 +1,7 @@ +package session + +type Config struct { + // RoundRobin determines which order connections will be chosen. If false, connections are handled lifo, + // otherwise they are chosen fifo + RoundRobin bool +} diff --git a/lib/gat/pools/session/conn.go b/lib/gat/pools/session/conn.go index fafa27ca446e1543d7efbfaa5399ef9c11a50926..8c5e0b38914c343dbfe52a63b791271fb1777ac3 100644 --- a/lib/gat/pools/session/conn.go +++ b/lib/gat/pools/session/conn.go @@ -3,11 +3,12 @@ package session import ( "github.com/google/uuid" + "pggat2/lib/util/strutil" "pggat2/lib/zap" ) type Conn struct { id uuid.UUID rw zap.ReadWriter - initialParameters map[string]string + initialParameters map[strutil.CIString]string } diff --git a/lib/gat/pools/session/pool.go b/lib/gat/pools/session/pool.go index 404a8ff73a7a15e198d9b58c5e7bb64b3e45e3a1..0b80f5a0ea6bdb748c47babb6c18c9ccd3f838f4 100644 --- a/lib/gat/pools/session/pool.go +++ b/lib/gat/pools/session/pool.go @@ -12,6 +12,7 @@ import ( "pggat2/lib/util/chans" "pggat2/lib/util/maps" "pggat2/lib/util/ring" + "pggat2/lib/util/strutil" "pggat2/lib/zap" packets "pggat2/lib/zap/packets/v3.0" ) @@ -22,7 +23,7 @@ type queueItem struct { } type Pool struct { - roundRobin bool + config Config // use slice lifo for better perf queue ring.Ring[queueItem] @@ -32,11 +33,9 @@ type Pool struct { } // NewPool creates a new session pool. -// roundRobin determines which order connections will be chosen. If roundRobin = false, connections are handled lifo, -// otherwise they are chosen fifo -func NewPool(roundRobin bool) *Pool { +func NewPool(config Config) *Pool { p := &Pool{ - roundRobin: roundRobin, + config: config, } p.ready.L = &p.qmu return p @@ -51,7 +50,7 @@ func (T *Pool) acquire(ctx *gat.Context) Conn { } var entry queueItem - if T.roundRobin { + if T.config.RoundRobin { entry, _ = T.queue.PopFront() } else { entry, _ = T.queue.PopBack() @@ -89,7 +88,7 @@ func (T *Pool) release(conn Conn) { T._release(conn.id) } -func (T *Pool) Serve(ctx *gat.Context, client zap.ReadWriter, _ map[string]string) { +func (T *Pool) Serve(ctx *gat.Context, client zap.ReadWriter, ps map[strutil.CIString]string) { defer func() { _ = client.Close() }() @@ -104,12 +103,26 @@ func (T *Pool) Serve(ctx *gat.Context, client zap.ReadWriter, _ map[string]strin } }() + for key, value := range ps { + if conn.initialParameters[key] == value { + continue + } + if err := backends.QueryString(&backends.Context{}, conn.rw, `SET `+strutil.Escape(key.String(), `"`)+` = `+strutil.Escape(value, `'`)); err != nil { + connOk = false + return + } + } + if func() bool { pkts := zap.NewPackets() defer pkts.Done() for key, value := range conn.initialParameters { packet := zap.NewPacket() - packets.WriteParameterStatus(packet, key, value) + if val, ok := ps[key]; ok { + packets.WriteParameterStatus(packet, key.String(), val) + } else { + packets.WriteParameterStatus(packet, key.String(), value) + } pkts.Append(packet) } @@ -137,7 +150,7 @@ func (T *Pool) Serve(ctx *gat.Context, client zap.ReadWriter, _ map[string]strin } } -func (T *Pool) AddServer(server zap.ReadWriter, parameters map[string]string) uuid.UUID { +func (T *Pool) AddServer(server zap.ReadWriter, parameters map[strutil.CIString]string) uuid.UUID { T.qmu.Lock() defer T.qmu.Unlock() diff --git a/lib/gat/pools/transaction/pool.go b/lib/gat/pools/transaction/pool.go index cd490ab49d75e97f8aab999a04d9d69694920b85..96f4b4a7090e3bc1f9a5215e6d65e590042149be 100644 --- a/lib/gat/pools/transaction/pool.go +++ b/lib/gat/pools/transaction/pool.go @@ -11,6 +11,7 @@ import ( "pggat2/lib/middleware/middlewares/ps" "pggat2/lib/rob" "pggat2/lib/rob/schedulers/v1" + "pggat2/lib/util/strutil" "pggat2/lib/zap" ) @@ -26,7 +27,7 @@ func NewPool() *Pool { return pool } -func (T *Pool) AddServer(server zap.ReadWriter, parameters map[string]string) uuid.UUID { +func (T *Pool) AddServer(server zap.ReadWriter, parameters map[strutil.CIString]string) uuid.UUID { eqps := eqp.NewServer() pss := ps.NewServer(parameters) mw := interceptor.NewInterceptor( @@ -58,7 +59,7 @@ func (T *Pool) RemoveServer(id uuid.UUID) zap.ReadWriter { return conn.(*Conn).rw } -func (T *Pool) Serve(ctx *gat.Context, client zap.ReadWriter, _ map[string]string) { +func (T *Pool) Serve(ctx *gat.Context, client zap.ReadWriter, _ map[strutil.CIString]string) { source := T.s.NewSource() eqpc := eqp.NewClient() defer eqpc.Done() diff --git a/lib/gat/recipe.go b/lib/gat/recipe.go index 65438046a478e7cd3f06b5f407a4c61830cad925..cc4918e93b6944e6f28f131d8bdcd4694d199aa8 100644 --- a/lib/gat/recipe.go +++ b/lib/gat/recipe.go @@ -6,11 +6,12 @@ import ( "pggat2/lib/auth" "pggat2/lib/bouncer/backends/v0" "pggat2/lib/util/maps" + "pggat2/lib/util/strutil" "pggat2/lib/zap" ) type Recipe interface { - Connect() (zap.ReadWriter, map[string]string, error) + Connect() (zap.ReadWriter, map[strutil.CIString]string, error) GetMinConnections() int // GetMaxConnections returns the maximum amount of connections for this db @@ -26,10 +27,10 @@ type TCPRecipe struct { MinConnections int MaxConnections int - StartupParameters map[string]string + StartupParameters map[strutil.CIString]string } -func (T TCPRecipe) Connect() (zap.ReadWriter, map[string]string, error) { +func (T TCPRecipe) Connect() (zap.ReadWriter, map[strutil.CIString]string, error) { conn, err := net.Dial("tcp", T.Address) if err != nil { return nil, nil, err @@ -38,7 +39,7 @@ func (T TCPRecipe) Connect() (zap.ReadWriter, map[string]string, error) { parameterStatus := maps.Clone(T.StartupParameters) if parameterStatus == nil { - parameterStatus = make(map[string]string) + parameterStatus = make(map[strutil.CIString]string) } err = backends.Accept(rw, T.Credentials, T.Database, parameterStatus) diff --git a/lib/middleware/middlewares/ps/client.go b/lib/middleware/middlewares/ps/client.go index 563849c3704c096fe9e8a9db645f83611e5f5f22..1c07be90b1c1be0440cbc3ab93ad1cbed7b906fe 100644 --- a/lib/middleware/middlewares/ps/client.go +++ b/lib/middleware/middlewares/ps/client.go @@ -4,19 +4,20 @@ import ( "errors" "pggat2/lib/middleware" + "pggat2/lib/util/strutil" "pggat2/lib/zap" packets "pggat2/lib/zap/packets/v3.0" ) type Client struct { - parameters map[string]string + parameters map[strutil.CIString]string middleware.Nil } func NewClient() *Client { return &Client{ - parameters: make(map[string]string), + parameters: make(map[strutil.CIString]string), } } @@ -27,12 +28,13 @@ func (T *Client) Send(ctx middleware.Context, packet *zap.Packet) error { if !ok { return errors.New("bad packet format") } - if T.parameters[key] == value { + ikey := strutil.MakeCIString(key) + if T.parameters[ikey] == value { // already set ctx.Cancel() break } - T.parameters[key] = value + T.parameters[ikey] = value } return nil } diff --git a/lib/middleware/middlewares/ps/server.go b/lib/middleware/middlewares/ps/server.go index 971025ba9142e16b130c3b667b3ac0bd5d2413b0..c141230b0fffa2d7722a3ae3175f5f05000f732c 100644 --- a/lib/middleware/middlewares/ps/server.go +++ b/lib/middleware/middlewares/ps/server.go @@ -4,25 +4,26 @@ import ( "errors" "pggat2/lib/middleware" + "pggat2/lib/util/strutil" "pggat2/lib/zap" packets "pggat2/lib/zap/packets/v3.0" ) type Server struct { - parameters map[string]string + parameters map[strutil.CIString]string middleware.Nil } -func NewServer(parameters map[string]string) *Server { +func NewServer(parameters map[strutil.CIString]string) *Server { return &Server{ parameters: parameters, } } -func (T *Server) syncParameter(pkts *zap.Packets, ps *Client, name, expected string) { +func (T *Server) syncParameter(pkts *zap.Packets, ps *Client, name strutil.CIString, expected string) { packet := zap.NewPacket() - packets.WriteParameterStatus(packet, name, expected) + packets.WriteParameterStatus(packet, name.String(), expected) pkts.Append(packet) ps.parameters[name] = expected @@ -59,7 +60,8 @@ func (T *Server) Read(_ middleware.Context, in *zap.Packet) error { if !ok { return errors.New("bad packet format") } - T.parameters[key] = value + ikey := strutil.MakeCIString(key) + T.parameters[ikey] = value } return nil } diff --git a/lib/util/encoding/ini/unmarshal.go b/lib/util/encoding/ini/unmarshal.go index 7f227f60da7b138b478b357ac7635e3d77747e20..9e5a875a4de469ad74b49c3968899aa8524dbf8a 100644 --- a/lib/util/encoding/ini/unmarshal.go +++ b/lib/util/encoding/ini/unmarshal.go @@ -5,10 +5,17 @@ import ( "errors" "reflect" "strconv" - "strings" ) -func get(rv reflect.Value, key string, fn func(rv reflect.Value) error) error { +type Unmarshaller interface { + UnmarshalINI(bytes []byte) error +} + +var ( + unmarshaller = reflect.TypeOf((*Unmarshaller)(nil)).Elem() +) + +func get(rv reflect.Value, key []byte, fn func(rv reflect.Value) error) error { outer: for { switch rv.Kind() { @@ -26,6 +33,7 @@ outer: switch rv.Kind() { case reflect.Struct: + keystr := string(key) rt := rv.Type() numFields := rt.NumField() for i := 0; i < numFields; i++ { @@ -40,7 +48,7 @@ outer: if name == "*" { return get(rv.Field(i), key, fn) } - if name == key { + if name == keystr { return fn(rv.Field(i)) } } @@ -48,14 +56,13 @@ outer: case reflect.Map: rt := rv.Type() rtKey := rt.Key() - if rtKey.Kind() != reflect.String { - return nil - } if rv.IsNil() { rv.Set(reflect.MakeMap(rt)) } k := reflect.New(rtKey).Elem() - k.SetString(key) + if err := set(k, key); err != nil { + return err + } v := reflect.New(rt.Elem()).Elem() if err := fn(v); err != nil { return err @@ -67,7 +74,7 @@ outer: } } -func set(rv reflect.Value, value string) error { +func set(rv reflect.Value, value []byte) error { outer: for { switch rv.Kind() { @@ -89,11 +96,21 @@ outer: } } + rt := rv.Type() + if rt.Implements(unmarshaller) { + rvu := rv.Interface().(Unmarshaller) + return rvu.UnmarshalINI(value) + } + if rv.CanAddr() && reflect.PointerTo(rt).Implements(unmarshaller) { + rvu := rv.Addr().Interface().(Unmarshaller) + return rvu.UnmarshalINI(value) + } + switch rv.Kind() { case reflect.Struct, reflect.Map: - fields := strings.Fields(value) + fields := bytes.Fields(value) for _, field := range fields { - k, v, ok := strings.Cut(field, "=") + k, v, ok := bytes.Cut(field, []byte{'='}) if !ok { return errors.New("expected key=value") } @@ -105,45 +122,45 @@ outer: } return nil case reflect.Array: - items := strings.Split(value, ",") + items := bytes.Split(value, []byte{','}) if len(items) != rv.Len() { return errors.New("wrong length for array") } for i, item := range items { - if err := set(rv.Index(i), strings.TrimSpace(item)); err != nil { + if err := set(rv.Index(i), bytes.TrimSpace(item)); err != nil { return err } } return nil case reflect.Slice: - items := strings.Split(value, ",") - slice := reflect.MakeSlice(rv.Type().Elem(), len(items), len(items)) + items := bytes.Split(value, []byte{','}) + slice := reflect.MakeSlice(rt.Elem(), len(items), len(items)) for i, item := range items { - if err := set(slice.Index(i), strings.TrimSpace(item)); err != nil { + if err := set(slice.Index(i), bytes.TrimSpace(item)); err != nil { return err } } rv.Set(slice) return nil case reflect.String: - rv.SetString(value) + rv.SetString(string(value)) return nil case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - v, err := strconv.ParseInt(value, 10, 64) + v, err := strconv.ParseInt(string(value), 10, 64) if err != nil { return err } rv.SetInt(v) return nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - v, err := strconv.ParseUint(value, 10, 64) + v, err := strconv.ParseUint(string(value), 10, 64) if err != nil { return err } rv.SetUint(v) return nil case reflect.Float32, reflect.Float64: - v, err := strconv.ParseFloat(value, 64) + v, err := strconv.ParseFloat(string(value), 64) if err != nil { return err } @@ -154,8 +171,8 @@ outer: } } -func setpath(rv reflect.Value, section, key, value string) error { - if section == "" { +func setpath(rv reflect.Value, section, key, value []byte) error { + if len(section) == 0 { return get(rv, key, func(entry reflect.Value) error { return set(entry, value) }) @@ -174,7 +191,7 @@ func Unmarshal(data []byte, v any) error { } rv = rv.Elem() - var section string + var section []byte var line []byte for { @@ -199,7 +216,7 @@ func Unmarshal(data []byte, v any) error { // section if bytes.HasPrefix(line, []byte{'['}) && bytes.HasSuffix(line, []byte{']'}) { - section = string(line[1 : len(line)-1]) + section = line[1 : len(line)-1] continue } @@ -211,7 +228,7 @@ func Unmarshal(data []byte, v any) error { key = bytes.TrimSpace(key) value = bytes.TrimSpace(value) - if err := setpath(rv, section, string(key), string(value)); err != nil { + if err := setpath(rv, section, key, value); err != nil { return err } } diff --git a/lib/util/slices/contains.go b/lib/util/slices/contains.go new file mode 100644 index 0000000000000000000000000000000000000000..f4a2996c9691f45df2d79f6ef0a8a58ee8a6cdc7 --- /dev/null +++ b/lib/util/slices/contains.go @@ -0,0 +1,11 @@ +package slices + +func Contains[T comparable](haystack []T, needle T) bool { + for _, hay := range haystack { + if hay == needle { + return true + } + } + + return false +} diff --git a/lib/util/strutil/cistring.go b/lib/util/strutil/cistring.go new file mode 100644 index 0000000000000000000000000000000000000000..da1a5a68ac90ebb7e7b4e17eb8e38249c1c517b3 --- /dev/null +++ b/lib/util/strutil/cistring.go @@ -0,0 +1,41 @@ +package strutil + +import ( + "encoding/json" + "strings" + + "pggat2/lib/util/encoding/ini" +) + +// CIString is a case-insensitive string. +type CIString struct { + value string +} + +func MakeCIString(value string) CIString { + return CIString{ + strings.ToLower(value), + } +} + +func (T *CIString) String() string { + return T.value +} + +func (T *CIString) MarshalJSON() ([]byte, error) { + return json.Marshal(T.value) +} + +func (T *CIString) UnmarshalJSON(bytes []byte) error { + return json.Unmarshal(bytes, &T.value) +} + +var _ json.Marshaler = (*CIString)(nil) +var _ json.Unmarshaler = (*CIString)(nil) + +func (T *CIString) UnmarshalINI(bytes []byte) error { + T.value = string(bytes) + return nil +} + +var _ ini.Unmarshaller = (*CIString)(nil) diff --git a/lib/util/strings/escape.go b/lib/util/strutil/escape.go similarity index 96% rename from lib/util/strings/escape.go rename to lib/util/strutil/escape.go index c59a8f1397ab17a73b0811fa5b3c2648fd92ee4d..42928856f48e5d1e92c1edf4fd62a0c22767d36a 100644 --- a/lib/util/strings/escape.go +++ b/lib/util/strutil/escape.go @@ -1,4 +1,4 @@ -package strings +package strutil import ( "strings" diff --git a/pgbouncer.ini b/pgbouncer.ini index c0d71af9d41a708354103529c16aac6fec87c716..118c7214d648c51f71cf2af1af86ee01e04cb7ee 100644 --- a/pgbouncer.ini +++ b/pgbouncer.ini @@ -7,5 +7,5 @@ listen_addr = * postgres = [databases] -regression = host=localhost datestyle=Postgres,MDY intervalstyle=postgres_verbose timezone=PST8PDT +regression = host=localhost datestyle=Postgres,MDY timezone=PST8PDT postgres = host=localhost