diff --git a/lib/bouncer/frontends/v0/accept.go b/lib/bouncer/frontends/v0/accept.go index dcbb0c2dfc5732065244956cf71c68681af56aa3..32f545b5d03a04aaa9770c522938893c03aff71d 100644 --- a/lib/bouncer/frontends/v0/accept.go +++ b/lib/bouncer/frontends/v0/accept.go @@ -322,16 +322,15 @@ func accept( ) return } - credsSASL, ok := creds.(auth.SASL) - if !ok { + if credsSASL, ok := creds.(auth.SASL); ok { + err = authenticationSASL(client, credsSASL) + } else { err = perror.New( perror.FATAL, perror.InternalError, "Auth method not supported", ) } - - err = authenticationSASL(client, credsSASL) if err != nil { return } diff --git a/lib/gat/configs/pgbouncer/config.go b/lib/gat/configs/pgbouncer/config.go index 1bfc0ff2a23ea81699018295c231f0701b7da702..83f83c0bbaa3ad6faa33280622539a08ad4031b9 100644 --- a/lib/gat/configs/pgbouncer/config.go +++ b/lib/gat/configs/pgbouncer/config.go @@ -71,93 +71,93 @@ type TLSECDHCurve string type TLSDHEParams string type PgBouncer struct { - LogFile string `ini:"logfile"` - PidFile string `ini:"pidfile"` - ListenAddr string `ini:"listen_addr"` - ListenPort int `ini:"listen_port"` - UnixSocketDir string `ini:"unix_socket_dir"` - UnixSocketMode string `ini:"unix_socket_mode"` - UnixSocketGroup string `ini:"unix_socket_group"` - User string `ini:"user"` - PoolMode PoolMode `ini:"pool_mode"` - MaxClientConn int `ini:"max_client_conn"` - DefaultPoolSize int `ini:"default_pool_size"` - MinPoolSize int `ini:"min_pool_size"` - ReservePoolSize int `ini:"reserve_pool_size"` - ReservePoolTimeout float64 `ini:"reserve_pool_timeout"` - MaxDBConnections int `ini:"max_db_connections"` - MaxUserConnections int `ini:"max_user_connections"` - ServerRoundRobin int `ini:"server_round_robin"` - TrackExtraParameters []string `ini:"track_extra_parameters"` - IgnoreStartupParameters []string `ini:"ignore_startup_parameters"` - PeerID int `ini:"peer_id"` - DisablePQExec int `ini:"disable_pqexec"` - ApplicationNameAddHost int `ini:"application_name_add_host"` - ConfFile string `ini:"conffile"` - ServiceName string `ini:"service_name"` - StatsPeriod int `ini:"stats_period"` - AuthType string `ini:"auth_type"` - AuthHbaFile string `ini:"auth_hba_file"` - AuthFile string `ini:"auth_file"` - AuthUser string `ini:"auth_user"` - AuthQuery string `ini:"auth_query"` - AuthDbname string `ini:"auth_dbname"` - Syslog string `ini:"syslog"` - SyslogIdent string `ini:"syslog_ident"` - SyslogFacility string `ini:"syslog_facility"` - LogConnections int `ini:"log_connections"` - LogDisconnections int `ini:"log_disconnections"` - LogPoolerErrors int `ini:"log_pooler_errors"` - LogStats int `ini:"log_stats"` - Verbose int `ini:"verbose"` - AdminUsers []string `ini:"auth_users"` - StatsUsers []string `ini:"stats_users"` - ServerResetQuery string `ini:"server_reset_query"` - ServerResetQueryAlways int `ini:"server_reset_query_always"` - ServerCheckDelay float64 `ini:"server_check_delay"` - ServerCheckQuery string `ini:"server_check_query"` - ServerFastClose int `ini:"server_fast_close"` - ServerLifetime float64 `ini:"server_lifetime"` - ServerIdleTimeout float64 `ini:"server_idle_timeout"` - ServerConnectTimeout float64 `ini:"server_connect_timeout"` - ServerLoginRetry float64 `ini:"server_login_retry"` - ClientLoginTimeout float64 `ini:"client_login_timeout"` - AutodbIdleTimeout float64 `ini:"autodb_idle_timeout"` - DnsMaxTtl float64 `ini:"dns_max_ttl"` - DnsNxdomainTtl float64 `ini:"dns_nxdomain_ttl"` - DnsZoneCheckPeriod float64 `ini:"dns_zone_check_period"` - ResolvConf string `ini:"resolv.conf"` - ClientTLSSSLMode SSLMode `ini:"client_tls_sslmode"` - ClientTLSKeyFile string `ini:"client_tls_key_file"` - ClientTLSCertFile string `ini:"client_tls_cert_file"` - ClientTLSCaFile string `ini:"client_tls_ca_file"` - ClientTLSProtocols []TLSProtocol `ini:"client_tls_protocols"` - ClientTLSCiphers []TLSCipher `ini:"client_tls_ciphers"` - ClientTLSECDHCurve TLSECDHCurve `ini:"client_tls_ecdhcurve"` - ClientTLSDHEParams TLSDHEParams `ini:"client_tls_dheparams"` - ServerTLSSSLMode SSLMode `ini:"server_tls_sslmode"` - ServerTLSCaFile string `ini:"server_tls_ca_file"` - ServerTLSKeyFile string `ini:"server_tls_key_file"` - ServerTLSCertFile string `ini:"server_tls_cert_file"` - ServerTLSProtocols []TLSProtocol `ini:"server_tls_protocols"` - ServerTLSCiphers []TLSCipher `ini:"server_tls_ciphers"` - QueryTimeout float64 `ini:"query_timeout"` - QueryWaitTimeout float64 `ini:"query_wait_timeout"` - CancelWaitTimeout float64 `ini:"cancel_wait_timeout"` - ClientIdleTimeout float64 `ini:"client_idle_timeout"` - IdleTransactionTimeout float64 `ini:"idle_transaction_timeout"` - SuspendTimeout float64 `ini:"suspend_timeout"` - PktBuf int `ini:"pkt_buf"` - MaxPacketSize int `ini:"max_packet_size"` - ListenBacklog int `ini:"listen_backlog"` - SbufLoopcnt int `ini:"sbuf_loopcnt"` - SoReuseport int `ini:"so_reuseport"` - TcpDeferAccept int `ini:"tcp_defer_accept"` - TcpSocketBuffer int `ini:"tcp_socket_buffer"` - TcpKeepalive int `ini:"tcp_keepalive"` - TcpKeepidle int `ini:"tcp_keepidle"` - TcpKeepintvl int `ini:"tcp_keepintvl"` - TcpUserTimeout int `ini:"tcp_user_timeout"` + LogFile string `ini:"logfile"` + PidFile string `ini:"pidfile"` + ListenAddr string `ini:"listen_addr"` + ListenPort int `ini:"listen_port"` + UnixSocketDir string `ini:"unix_socket_dir"` + UnixSocketMode string `ini:"unix_socket_mode"` + UnixSocketGroup string `ini:"unix_socket_group"` + User string `ini:"user"` + PoolMode PoolMode `ini:"pool_mode"` + MaxClientConn int `ini:"max_client_conn"` + DefaultPoolSize int `ini:"default_pool_size"` + MinPoolSize int `ini:"min_pool_size"` + ReservePoolSize int `ini:"reserve_pool_size"` + ReservePoolTimeout float64 `ini:"reserve_pool_timeout"` + MaxDBConnections int `ini:"max_db_connections"` + MaxUserConnections int `ini:"max_user_connections"` + ServerRoundRobin int `ini:"server_round_robin"` + TrackExtraParameters []strutil.CIString `ini:"track_extra_parameters"` + IgnoreStartupParameters []strutil.CIString `ini:"ignore_startup_parameters"` + PeerID int `ini:"peer_id"` + DisablePQExec int `ini:"disable_pqexec"` + ApplicationNameAddHost int `ini:"application_name_add_host"` + ConfFile string `ini:"conffile"` + ServiceName string `ini:"service_name"` + StatsPeriod int `ini:"stats_period"` + AuthType string `ini:"auth_type"` + AuthHbaFile string `ini:"auth_hba_file"` + AuthFile string `ini:"auth_file"` + AuthUser string `ini:"auth_user"` + AuthQuery string `ini:"auth_query"` + AuthDbname string `ini:"auth_dbname"` + Syslog string `ini:"syslog"` + SyslogIdent string `ini:"syslog_ident"` + SyslogFacility string `ini:"syslog_facility"` + LogConnections int `ini:"log_connections"` + LogDisconnections int `ini:"log_disconnections"` + LogPoolerErrors int `ini:"log_pooler_errors"` + LogStats int `ini:"log_stats"` + Verbose int `ini:"verbose"` + AdminUsers []string `ini:"auth_users"` + StatsUsers []string `ini:"stats_users"` + ServerResetQuery string `ini:"server_reset_query"` + ServerResetQueryAlways int `ini:"server_reset_query_always"` + ServerCheckDelay float64 `ini:"server_check_delay"` + ServerCheckQuery string `ini:"server_check_query"` + ServerFastClose int `ini:"server_fast_close"` + ServerLifetime float64 `ini:"server_lifetime"` + ServerIdleTimeout float64 `ini:"server_idle_timeout"` + ServerConnectTimeout float64 `ini:"server_connect_timeout"` + ServerLoginRetry float64 `ini:"server_login_retry"` + ClientLoginTimeout float64 `ini:"client_login_timeout"` + AutodbIdleTimeout float64 `ini:"autodb_idle_timeout"` + DnsMaxTtl float64 `ini:"dns_max_ttl"` + DnsNxdomainTtl float64 `ini:"dns_nxdomain_ttl"` + DnsZoneCheckPeriod float64 `ini:"dns_zone_check_period"` + ResolvConf string `ini:"resolv.conf"` + ClientTLSSSLMode SSLMode `ini:"client_tls_sslmode"` + ClientTLSKeyFile string `ini:"client_tls_key_file"` + ClientTLSCertFile string `ini:"client_tls_cert_file"` + ClientTLSCaFile string `ini:"client_tls_ca_file"` + ClientTLSProtocols []TLSProtocol `ini:"client_tls_protocols"` + ClientTLSCiphers []TLSCipher `ini:"client_tls_ciphers"` + ClientTLSECDHCurve TLSECDHCurve `ini:"client_tls_ecdhcurve"` + ClientTLSDHEParams TLSDHEParams `ini:"client_tls_dheparams"` + ServerTLSSSLMode SSLMode `ini:"server_tls_sslmode"` + ServerTLSCaFile string `ini:"server_tls_ca_file"` + ServerTLSKeyFile string `ini:"server_tls_key_file"` + ServerTLSCertFile string `ini:"server_tls_cert_file"` + ServerTLSProtocols []TLSProtocol `ini:"server_tls_protocols"` + ServerTLSCiphers []TLSCipher `ini:"server_tls_ciphers"` + QueryTimeout float64 `ini:"query_timeout"` + QueryWaitTimeout float64 `ini:"query_wait_timeout"` + CancelWaitTimeout float64 `ini:"cancel_wait_timeout"` + ClientIdleTimeout float64 `ini:"client_idle_timeout"` + IdleTransactionTimeout float64 `ini:"idle_transaction_timeout"` + SuspendTimeout float64 `ini:"suspend_timeout"` + PktBuf int `ini:"pkt_buf"` + MaxPacketSize int `ini:"max_packet_size"` + ListenBacklog int `ini:"listen_backlog"` + SbufLoopcnt int `ini:"sbuf_loopcnt"` + SoReuseport int `ini:"so_reuseport"` + TcpDeferAccept int `ini:"tcp_defer_accept"` + TcpSocketBuffer int `ini:"tcp_socket_buffer"` + TcpKeepalive int `ini:"tcp_keepalive"` + TcpKeepidle int `ini:"tcp_keepidle"` + TcpKeepintvl int `ini:"tcp_keepintvl"` + TcpUserTimeout int `ini:"tcp_user_timeout"` } type Database struct { @@ -203,8 +203,8 @@ var Default = Config{ MaxClientConn: 100, DefaultPoolSize: 20, ReservePoolTimeout: 5.0, - TrackExtraParameters: []string{ - "IntervalStyle", + TrackExtraParameters: []strutil.CIString{ + strutil.MakeCIString("IntervalStyle"), }, ServiceName: "pgbouncer", StatsPeriod: 60, @@ -265,14 +265,18 @@ func Load(config string) (Config, error) { } func (T *Config) ListenAndServe() error { + trackedParameters := append([]strutil.CIString{ + strutil.MakeCIString("client_encoding"), + strutil.MakeCIString("datestyle"), + strutil.MakeCIString("timezone"), + strutil.MakeCIString("standard_conforming_strings"), + strutil.MakeCIString("application_name"), + }, T.PgBouncer.TrackExtraParameters...) + + ignoreStartupParameters := append(trackedParameters, T.PgBouncer.IgnoreStartupParameters...) + pooler := gat.NewPooler(gat.PoolerConfig{ - AllowedStartupParameters: []strutil.CIString{ - strutil.MakeCIString("intervalstyle"), - strutil.MakeCIString("application_name"), - strutil.MakeCIString("client_encoding"), - strutil.MakeCIString("datestyle"), - strutil.MakeCIString("timezone"), - }, + AllowedStartupParameters: ignoreStartupParameters, }) var authFile map[string]string @@ -317,14 +321,21 @@ func (T *Config) ListenAndServe() error { poolMode = T.PgBouncer.PoolMode } + rawPoolConfig := gat.BaseRawPoolConfig{ + TrackedParameters: trackedParameters, + } + var raw gat.RawPool switch poolMode { case PoolModeSession: raw = session.NewPool(session.Config{ - RoundRobin: T.PgBouncer.ServerRoundRobin != 0, + RoundRobin: T.PgBouncer.ServerRoundRobin != 0, + BaseRawPoolConfig: rawPoolConfig, }) case PoolModeTransaction: - raw = transaction.NewPool() + raw = transaction.NewPool(transaction.Config{ + BaseRawPoolConfig: rawPoolConfig, + }) default: return errors.New("unsupported pool mode") } diff --git a/lib/gat/pool.go b/lib/gat/pool.go index f198eb39cb965bfecf5fc99e68405e6653cf5f26..c79df6b34d843d9b8d0b3a2331d0e9ae4702bf6d 100644 --- a/lib/gat/pool.go +++ b/lib/gat/pool.go @@ -28,6 +28,10 @@ type RawPool interface { IdleSince() time.Time } +type BaseRawPoolConfig struct { + TrackedParameters []strutil.CIString +} + type PoolRecipe struct { removed bool servers []uuid.UUID diff --git a/lib/gat/pools/session/config.go b/lib/gat/pools/session/config.go index 0ffe00ffe94e83cafe8f6321c5ed22a1243f94cd..ae442655dc9a8b4df0bfce2c147938be80bab8ad 100644 --- a/lib/gat/pools/session/config.go +++ b/lib/gat/pools/session/config.go @@ -1,6 +1,10 @@ package session +import "pggat2/lib/gat" + type Config struct { + gat.BaseRawPoolConfig + // RoundRobin determines which order connections will be chosen. If false, connections are handled lifo, // otherwise they are chosen fifo RoundRobin bool diff --git a/lib/gat/pools/session/pool.go b/lib/gat/pools/session/pool.go index 0b80f5a0ea6bdb748c47babb6c18c9ccd3f838f4..732247caf8320ac004dfb0bc4172b3a24fda18c1 100644 --- a/lib/gat/pools/session/pool.go +++ b/lib/gat/pools/session/pool.go @@ -12,6 +12,7 @@ import ( "pggat2/lib/util/chans" "pggat2/lib/util/maps" "pggat2/lib/util/ring" + "pggat2/lib/util/slices" "pggat2/lib/util/strutil" "pggat2/lib/zap" packets "pggat2/lib/zap/packets/v3.0" @@ -103,27 +104,47 @@ func (T *Pool) Serve(ctx *gat.Context, client zap.ReadWriter, ps map[strutil.CIS } }() - for key, value := range ps { - if conn.initialParameters[key] == value { - continue - } - if err := backends.QueryString(&backends.Context{}, conn.rw, `SET `+strutil.Escape(key.String(), `"`)+` = `+strutil.Escape(value, `'`)); err != nil { - connOk = false - return - } - } - if func() bool { pkts := zap.NewPackets() defer pkts.Done() - for key, value := range conn.initialParameters { - packet := zap.NewPacket() - if val, ok := ps[key]; ok { - packets.WriteParameterStatus(packet, key.String(), val) - } else { - packets.WriteParameterStatus(packet, key.String(), value) + + add := func(key strutil.CIString) { + if value, ok := conn.initialParameters[key]; ok { + pkt := zap.NewPacket() + packets.WriteParameterStatus(pkt, key.String(), value) + pkts.Append(pkt) + } + } + + for key, value := range ps { + // skip already set params + if conn.initialParameters[key] == value { + add(key) + continue } - pkts.Append(packet) + + // only set tracking params + if !slices.Contains(T.config.TrackedParameters, key) { + add(key) + continue + } + + pkt := zap.NewPacket() + packets.WriteParameterStatus(pkt, key.String(), value) + pkts.Append(pkt) + + if err := backends.QueryString(&backends.Context{}, conn.rw, `SET `+strutil.Escape(key.String(), `"`)+` = `+strutil.Escape(value, `'`)); err != nil { + connOk = false + return true + } + } + + for key := range conn.initialParameters { + if _, ok := ps[key]; ok { + continue + } + + add(key) } err := client.WriteV(pkts) diff --git a/lib/gat/pools/transaction/config.go b/lib/gat/pools/transaction/config.go new file mode 100644 index 0000000000000000000000000000000000000000..23b6cb1895c6924bbeed7452421ba70cbb47b8c2 --- /dev/null +++ b/lib/gat/pools/transaction/config.go @@ -0,0 +1,7 @@ +package transaction + +import "pggat2/lib/gat" + +type Config struct { + gat.BaseRawPoolConfig +} diff --git a/lib/gat/pools/transaction/conn.go b/lib/gat/pools/transaction/conn.go index c1aa72fc1915277f752e64323512e9ec7ef5b473..2090c69ad5829cd800f44305eb940f510bad551a 100644 --- a/lib/gat/pools/transaction/conn.go +++ b/lib/gat/pools/transaction/conn.go @@ -18,11 +18,7 @@ func (T *Conn) Do(ctx *rob.Context, work any) { job := work.(Work) // sync parameters - err := T.ps.Sync(job.rw, job.ps) - if err != nil { - _ = job.rw.Close() - return - } + ps.Sync(job.trackedParameters, job.rw, job.ps, T.rw, T.ps) T.eqp.SetClient(job.eqp) clientErr, serverErr := bouncers.Bounce(job.rw, T.rw, job.initialPacket) diff --git a/lib/gat/pools/transaction/pool.go b/lib/gat/pools/transaction/pool.go index 96f4b4a7090e3bc1f9a5215e6d65e590042149be..82d3cb34c525dc6f07d00c6b528be698ea1d0345 100644 --- a/lib/gat/pools/transaction/pool.go +++ b/lib/gat/pools/transaction/pool.go @@ -16,12 +16,14 @@ import ( ) type Pool struct { - s schedulers.Scheduler + config Config + s schedulers.Scheduler } -func NewPool() *Pool { +func NewPool(config Config) *Pool { pool := &Pool{ - s: schedulers.MakeScheduler(), + config: config, + s: schedulers.MakeScheduler(), } return pool @@ -59,11 +61,11 @@ func (T *Pool) RemoveServer(id uuid.UUID) zap.ReadWriter { return conn.(*Conn).rw } -func (T *Pool) Serve(ctx *gat.Context, client zap.ReadWriter, _ map[strutil.CIString]string) { +func (T *Pool) Serve(ctx *gat.Context, client zap.ReadWriter, parameters map[strutil.CIString]string) { source := T.s.NewSource() eqpc := eqp.NewClient() defer eqpc.Done() - psc := ps.NewClient() + psc := ps.NewClient(parameters) client = interceptor.NewInterceptor( client, eqpc, @@ -82,10 +84,11 @@ func (T *Pool) Serve(ctx *gat.Context, client zap.ReadWriter, _ map[strutil.CISt } source.Do(&robCtx, Work{ - rw: client, - initialPacket: packet, - eqp: eqpc, - ps: psc, + rw: client, + initialPacket: packet, + eqp: eqpc, + ps: psc, + trackedParameters: T.config.TrackedParameters, }) } _ = client.Close() diff --git a/lib/gat/pools/transaction/work.go b/lib/gat/pools/transaction/work.go index e3eb55184fb4c9974ddfdb407da2b5abb70b9837..5b755fb901991bd51d03c1b4be656a169c0aee2b 100644 --- a/lib/gat/pools/transaction/work.go +++ b/lib/gat/pools/transaction/work.go @@ -3,12 +3,14 @@ package transaction import ( "pggat2/lib/middleware/middlewares/eqp" "pggat2/lib/middleware/middlewares/ps" + "pggat2/lib/util/strutil" "pggat2/lib/zap" ) type Work struct { - rw zap.ReadWriter - initialPacket *zap.Packet - eqp *eqp.Client - ps *ps.Client + rw zap.ReadWriter + initialPacket *zap.Packet + eqp *eqp.Client + ps *ps.Client + trackedParameters []strutil.CIString } diff --git a/lib/middleware/middlewares/ps/client.go b/lib/middleware/middlewares/ps/client.go index 1c07be90b1c1be0440cbc3ab93ad1cbed7b906fe..f3429c9d37e29d387aafe18fa756762eac6c36da 100644 --- a/lib/middleware/middlewares/ps/client.go +++ b/lib/middleware/middlewares/ps/client.go @@ -10,14 +10,15 @@ import ( ) type Client struct { + synced bool parameters map[strutil.CIString]string middleware.Nil } -func NewClient() *Client { +func NewClient(parameters map[strutil.CIString]string) *Client { return &Client{ - parameters: make(map[strutil.CIString]string), + parameters: parameters, } } @@ -34,6 +35,9 @@ func (T *Client) Send(ctx middleware.Context, packet *zap.Packet) error { ctx.Cancel() break } + if T.parameters == nil { + T.parameters = make(map[strutil.CIString]string) + } T.parameters[ikey] = value } return nil diff --git a/lib/middleware/middlewares/ps/server.go b/lib/middleware/middlewares/ps/server.go index c141230b0fffa2d7722a3ae3175f5f05000f732c..565c2915db38fae51ff14e3e51dc3f17c40e8039 100644 --- a/lib/middleware/middlewares/ps/server.go +++ b/lib/middleware/middlewares/ps/server.go @@ -21,38 +21,6 @@ func NewServer(parameters map[strutil.CIString]string) *Server { } } -func (T *Server) syncParameter(pkts *zap.Packets, ps *Client, name strutil.CIString, expected string) { - packet := zap.NewPacket() - packets.WriteParameterStatus(packet, name.String(), expected) - pkts.Append(packet) - - ps.parameters[name] = expected -} - -func (T *Server) Sync(client zap.ReadWriter, ps *Client) error { - pkts := zap.NewPackets() - defer pkts.Done() - - for name, value := range ps.parameters { - expected := T.parameters[name] - if value == expected { - continue - } - - T.syncParameter(pkts, ps, name, expected) - } - - for name, expected := range T.parameters { - if T.parameters[name] == expected { - continue - } - - T.syncParameter(pkts, ps, name, expected) - } - - return client.WriteV(pkts) -} - func (T *Server) Read(_ middleware.Context, in *zap.Packet) error { switch in.ReadType() { case packets.ParameterStatus: @@ -61,6 +29,9 @@ func (T *Server) Read(_ middleware.Context, in *zap.Packet) error { return errors.New("bad packet format") } ikey := strutil.MakeCIString(key) + if T.parameters == nil { + T.parameters = make(map[strutil.CIString]string) + } T.parameters[ikey] = value } return nil diff --git a/lib/middleware/middlewares/ps/sync.go b/lib/middleware/middlewares/ps/sync.go new file mode 100644 index 0000000000000000000000000000000000000000..aa752f54e317fc6117c05b623af0d472e9ed9e50 --- /dev/null +++ b/lib/middleware/middlewares/ps/sync.go @@ -0,0 +1,61 @@ +package ps + +import ( + "pggat2/lib/bouncer/backends/v0" + "pggat2/lib/util/slices" + "pggat2/lib/util/strutil" + "pggat2/lib/zap" + packets "pggat2/lib/zap/packets/v3.0" +) + +func sync(tracking []strutil.CIString, clientPackets *zap.Packets, c *Client, server zap.ReadWriter, s *Server, name strutil.CIString) { + value := c.parameters[name] + expected := s.parameters[name] + + if value == expected { + // TODO(garet) this will send twice if both server and client have it + if !c.synced { + pkt := zap.NewPacket() + packets.WriteParameterStatus(pkt, name.String(), expected) + clientPackets.Append(pkt) + } + return + } + + if slices.Contains(tracking, name) { + if err := backends.QueryString(&backends.Context{}, server, `SET `+strutil.Escape(name.String(), `"`)+` = `+strutil.Escape(value, `'`)); err != nil { + panic(err) // TODO(garet) + } + if s.parameters == nil { + s.parameters = make(map[strutil.CIString]string) + } + s.parameters[name] = value + } else { + pkt := zap.NewPacket() + packets.WriteParameterStatus(pkt, name.String(), expected) + clientPackets.Append(pkt) + if c.parameters == nil { + c.parameters = make(map[strutil.CIString]string) + } + c.parameters[name] = value + } +} + +func Sync(tracking []strutil.CIString, client zap.ReadWriter, c *Client, server zap.ReadWriter, s *Server) { + pkts := zap.NewPackets() + defer pkts.Done() + + for name := range c.parameters { + sync(tracking, pkts, c, server, s, name) + } + + for name := range s.parameters { + sync(tracking, pkts, c, server, s, name) + } + + c.synced = true + + if err := client.WriteV(pkts); err != nil { + panic(err) // TODO(garet) + } +} diff --git a/lib/zap/packet.go b/lib/zap/packet.go index a71bad013b6f631105d5be10dfb68d79aea53bd6..51573d3f7c7042b8d589a3187976cc6c8eda42d5 100644 --- a/lib/zap/packet.go +++ b/lib/zap/packet.go @@ -25,6 +25,10 @@ func NewPackets() *Packets { } func (T *Packets) WriteTo(w io.Writer) (int64, error) { + if len(T.order) == 0 { + return 0, nil + } + buffers := make(net.Buffers, 0, len(T.order)) for _, order := range T.order { diff --git a/pgbouncer.ini b/pgbouncer.ini index 118c7214d648c51f71cf2af1af86ee01e04cb7ee..6dc489006c53ecee5f04c3e439677d79148f520f 100644 --- a/pgbouncer.ini +++ b/pgbouncer.ini @@ -1,5 +1,5 @@ [pgbouncer] -pool_mode = session +pool_mode = transaction auth_file = userlist.txt listen_addr = *