From b1b5b2b9672f2c564cdd224992daedc675740a21 Mon Sep 17 00:00:00 2001 From: Garet Halliday <ghalliday@gfxlabs.io> Date: Fri, 23 Sep 2022 11:44:20 -0500 Subject: [PATCH] send first user options to server closes #3 --- lib/gat/admin/admin.go | 2 +- lib/gat/gatling/client/client.go | 27 +++++++++++++++++-------- lib/gat/gatling/server/server.go | 17 ++++++++++------ lib/gat/gatling/server/server_test.go | 2 +- lib/gat/interfaces.go | 6 ++++-- lib/gat/pool/session/pool.go | 10 ++++----- lib/gat/pool/transaction/pool.go | 4 ++-- lib/gat/pool/transaction/shard/shard.go | 9 +++++++-- lib/gat/pool/transaction/worker.go | 10 ++++----- lib/parse/parse.go | 2 +- 10 files changed, 56 insertions(+), 33 deletions(-) diff --git a/lib/gat/admin/admin.go b/lib/gat/admin/admin.go index dd91b0c1..32823d98 100644 --- a/lib/gat/admin/admin.go +++ b/lib/gat/admin/admin.go @@ -567,7 +567,7 @@ func (c *Pool) GetUser() *config.User { return getAdminUser(c.database.gat) } -func (c *Pool) GetServerInfo() []*protocol.ParameterStatus { +func (c *Pool) GetServerInfo(_ gat.Client) []*protocol.ParameterStatus { return getServerInfo(c.database.gat) } diff --git a/lib/gat/gatling/client/client.go b/lib/gat/gatling/client/client.go index 79ab36d1..88763f86 100644 --- a/lib/gat/gatling/client/client.go +++ b/lib/gat/gatling/client/client.go @@ -68,6 +68,8 @@ type Client struct { recv chan protocol.Packet + options []protocol.FieldsStartupMessageParameters + state gat.ClientState pid int32 @@ -97,6 +99,10 @@ type Client struct { mu sync.Mutex } +func (c *Client) GetOptions() []protocol.FieldsStartupMessageParameters { + return c.options +} + func (c *Client) GetState() gat.ClientState { c.mu.Lock() defer c.mu.Unlock() @@ -257,14 +263,20 @@ func (c *Client) Accept(ctx context.Context) error { } } } - params := make(map[string]string) + c.options = make([]protocol.FieldsStartupMessageParameters, 0, len(startup.Fields.Parameters)) for _, v := range startup.Fields.Parameters { - params[v.Name] = v.Value + switch v.Name { + case "": + case "database": + c.poolName = v.Value + case "user": + c.username = v.Value + default: + c.options = append(c.options, v) + } } - var ok bool - c.poolName, ok = params["database"] - if !ok { + if c.poolName == "" { return &pg_error.Error{ Severity: pg_error.Fatal, Code: pg_error.InvalidAuthorizationSpecification, @@ -272,8 +284,7 @@ func (c *Client) Accept(ctx context.Context) error { } } - c.username, ok = params["user"] - if !ok { + if c.username == "" { return &pg_error.Error{ Severity: pg_error.Fatal, Code: pg_error.InvalidAuthorizationSpecification, @@ -372,7 +383,7 @@ func (c *Client) Accept(ctx context.Context) error { } // - info := c.server.GetServerInfo() + info := c.server.GetServerInfo(c) for _, inf := range info { err = c.Send(inf) if err != nil { diff --git a/lib/gat/gatling/server/server.go b/lib/gat/gatling/server/server.go index fc0813d3..f3c5676f 100644 --- a/lib/gat/gatling/server/server.go +++ b/lib/gat/gatling/server/server.go @@ -33,6 +33,8 @@ type Server struct { client gat.Client state gat.ConnectionState + options []protocol.FieldsStartupMessageParameters + serverInfo []*protocol.ParameterStatus processId int32 @@ -57,13 +59,15 @@ type Server struct { mu sync.Mutex } -func Dial(ctx context.Context, user *config.User, shard *config.Shard, server *config.Server) (gat.Connection, error) { +func Dial(ctx context.Context, options []protocol.FieldsStartupMessageParameters, user *config.User, shard *config.Shard, server *config.Server) (gat.Connection, error) { s := &Server{ addr: server.Host, port: server.Port, state: gat.ConnectionNew, + options: options, + boundPreparedStatments: make(map[string]*protocol.Parse), boundPortals: make(map[string]*protocol.Bind), @@ -230,17 +234,18 @@ func (s *Server) startup(ctx context.Context) error { s.log.Debug().Msg("sending startup") start := new(protocol.StartupMessage) start.Fields.ProtocolVersionNumber = 196608 - start.Fields.Parameters = []protocol.FieldsStartupMessageParameters{ - { + start.Fields.Parameters = append( + s.options, + protocol.FieldsStartupMessageParameters{ Name: "user", Value: s.dbuser, }, - { + protocol.FieldsStartupMessageParameters{ Name: "database", Value: s.db, }, - {}, - } + protocol.FieldsStartupMessageParameters{}, + ) err := s.writePacket(start) if err != nil { return err diff --git a/lib/gat/gatling/server/server_test.go b/lib/gat/gatling/server/server_test.go index 3828290a..ab0463f4 100644 --- a/lib/gat/gatling/server/server_test.go +++ b/lib/gat/gatling/server/server_test.go @@ -23,7 +23,7 @@ var test_user = config.User{ } func TestServerDial(t *testing.T) { - srv, err := Dial(context.TODO(), &test_user, &test_shard, &test_server) + srv, err := Dial(context.TODO(), nil, &test_user, &test_shard, &test_server) if err != nil { t.Error(err) } diff --git a/lib/gat/interfaces.go b/lib/gat/interfaces.go index 6ded9dce..a8af8c35 100644 --- a/lib/gat/interfaces.go +++ b/lib/gat/interfaces.go @@ -23,6 +23,8 @@ const ( type Client interface { GetId() ClientID + GetOptions() []protocol.FieldsStartupMessageParameters + GetPreparedStatement(name string) *protocol.Parse GetPortal(name string) *protocol.Bind GetCurrentConn() Connection @@ -78,7 +80,7 @@ type QueryRouter interface { type Pool interface { GetUser() *config.User - GetServerInfo() []*protocol.ParameterStatus + GetServerInfo(client Client) []*protocol.ParameterStatus GetDatabase() Database @@ -106,7 +108,7 @@ const ( ConnectionNew = "new" ) -type Dialer = func(context.Context, *config.User, *config.Shard, *config.Server) (Connection, error) +type Dialer = func(context.Context, []protocol.FieldsStartupMessageParameters, *config.User, *config.Shard, *config.Server) (Connection, error) type Connection interface { GetServerInfo() []*protocol.ParameterStatus diff --git a/lib/gat/pool/session/pool.go b/lib/gat/pool/session/pool.go index 49354e30..24986a8a 100644 --- a/lib/gat/pool/session/pool.go +++ b/lib/gat/pool/session/pool.go @@ -35,13 +35,13 @@ func New(database gat.Database, dialer gat.Dialer, conf *config.Pool, user *conf return p } -func (p *Pool) getConnection() (gat.Connection, error) { +func (p *Pool) getConnection(client gat.Client) (gat.Connection, error) { select { case c := <-p.servers: return c, nil default: shard := p.c.Load().Shards[0] - return p.dialer(context.TODO(), p.user, shard, shard.Servers[0]) + return p.dialer(context.TODO(), client.GetOptions(), p.user, shard, shard.Servers[0]) } } @@ -53,7 +53,7 @@ func (p *Pool) getOrAssign(client gat.Client) (gat.Connection, error) { cid := client.GetId() c, ok := p.assigned.Load(cid) if !ok { - get, err := p.getConnection() + get, err := p.getConnection(client) if err != nil { return nil, err } @@ -84,8 +84,8 @@ func (p *Pool) GetUser() *config.User { return p.user } -func (p *Pool) GetServerInfo() []*protocol.ParameterStatus { - c, err := p.getConnection() +func (p *Pool) GetServerInfo(client gat.Client) []*protocol.ParameterStatus { + c, err := p.getConnection(client) if err != nil { return nil } diff --git a/lib/gat/pool/transaction/pool.go b/lib/gat/pool/transaction/pool.go index ca14ae30..f3fe5776 100644 --- a/lib/gat/pool/transaction/pool.go +++ b/lib/gat/pool/transaction/pool.go @@ -72,8 +72,8 @@ func (c *Pool) GetUser() *config.User { return c.user } -func (c *Pool) GetServerInfo() []*protocol.ParameterStatus { - return c.getWorker().GetServerInfo() +func (c *Pool) GetServerInfo(client gat.Client) []*protocol.ParameterStatus { + return c.getWorker().GetServerInfo(client) } func (c *Pool) Describe(ctx context.Context, client gat.Client, d *protocol.Describe) error { diff --git a/lib/gat/pool/transaction/shard/shard.go b/lib/gat/pool/transaction/shard/shard.go index 05f440a2..1acdae07 100644 --- a/lib/gat/pool/transaction/shard/shard.go +++ b/lib/gat/pool/transaction/shard/shard.go @@ -4,6 +4,7 @@ import ( "context" "gfx.cafe/gfx/pggat/lib/config" "gfx.cafe/gfx/pggat/lib/gat" + "gfx.cafe/gfx/pggat/lib/gat/protocol" "math/rand" "reflect" ) @@ -15,14 +16,18 @@ type Shard struct { user *config.User conf *config.Shard + options []protocol.FieldsStartupMessageParameters + dialer gat.Dialer } -func FromConfig(dialer gat.Dialer, user *config.User, conf *config.Shard) *Shard { +func FromConfig(dialer gat.Dialer, options []protocol.FieldsStartupMessageParameters, user *config.User, conf *config.Shard) *Shard { out := &Shard{ user: user, conf: conf, + options: options, + dialer: dialer, } out.init() @@ -33,7 +38,7 @@ func (s *Shard) init() { s.primary = nil s.replicas = nil for _, serv := range s.conf.Servers { - srv, err := s.dialer(context.TODO(), s.user, s.conf, serv) + srv, err := s.dialer(context.TODO(), s.options, s.user, s.conf, serv) if err != nil { continue } diff --git a/lib/gat/pool/transaction/worker.go b/lib/gat/pool/transaction/worker.go index 92e5ae84..082e5a88 100644 --- a/lib/gat/pool/transaction/worker.go +++ b/lib/gat/pool/transaction/worker.go @@ -31,7 +31,7 @@ func (w *worker) ret() { } // attempt to connect to a new shard with this worker -func (w *worker) fetchShard(n int) bool { +func (w *worker) fetchShard(client gat.Client, n int) bool { conf := w.w.c.Load() if n < 0 || n >= len(conf.Shards) { return false @@ -41,7 +41,7 @@ func (w *worker) fetchShard(n int) bool { w.shards = append(w.shards, nil) } - w.shards[n] = shard.FromConfig(w.w.dialer, w.w.user, conf.Shards[n]) + w.shards[n] = shard.FromConfig(w.w.dialer, client.GetOptions(), w.w.user, conf.Shards[n]) return true } @@ -76,17 +76,17 @@ func (w *worker) chooseShard(client gat.Client) *shard.Shard { } // we need to fetch a shard - if w.fetchShard(preferred) { + if w.fetchShard(client, preferred) { return w.shards[preferred] } return nil } -func (w *worker) GetServerInfo() []*protocol.ParameterStatus { +func (w *worker) GetServerInfo(client gat.Client) []*protocol.ParameterStatus { defer w.ret() - s := w.chooseShard(nil) + s := w.chooseShard(client) if s == nil { return nil } diff --git a/lib/parse/parse.go b/lib/parse/parse.go index 2af41fe8..ab766763 100644 --- a/lib/parse/parse.go +++ b/lib/parse/parse.go @@ -270,7 +270,7 @@ func (r *reader) nextCommand() (cmd Command, err error) { // Parse parses an sql query in a single pass (with no look aheads or look behinds). // Because all we really care about is the commands, this can be very fast -// based on https://www.postgresql.org/docs/current/sql-syntax-lexical.html +// based on https://www.postgresql.org/docs/14/sql-syntax-lexical.html func Parse(sql string) (cmds []Command, err error) { r := reader{ v: sql, -- GitLab