From b1b5b2b9672f2c564cdd224992daedc675740a21 Mon Sep 17 00:00:00 2001
From: Garet Halliday <ghalliday@gfxlabs.io>
Date: Fri, 23 Sep 2022 11:44:20 -0500
Subject: [PATCH] send first user options to server closes #3

---
 lib/gat/admin/admin.go                  |  2 +-
 lib/gat/gatling/client/client.go        | 27 +++++++++++++++++--------
 lib/gat/gatling/server/server.go        | 17 ++++++++++------
 lib/gat/gatling/server/server_test.go   |  2 +-
 lib/gat/interfaces.go                   |  6 ++++--
 lib/gat/pool/session/pool.go            | 10 ++++-----
 lib/gat/pool/transaction/pool.go        |  4 ++--
 lib/gat/pool/transaction/shard/shard.go |  9 +++++++--
 lib/gat/pool/transaction/worker.go      | 10 ++++-----
 lib/parse/parse.go                      |  2 +-
 10 files changed, 56 insertions(+), 33 deletions(-)

diff --git a/lib/gat/admin/admin.go b/lib/gat/admin/admin.go
index dd91b0c1..32823d98 100644
--- a/lib/gat/admin/admin.go
+++ b/lib/gat/admin/admin.go
@@ -567,7 +567,7 @@ func (c *Pool) GetUser() *config.User {
 	return getAdminUser(c.database.gat)
 }
 
-func (c *Pool) GetServerInfo() []*protocol.ParameterStatus {
+func (c *Pool) GetServerInfo(_ gat.Client) []*protocol.ParameterStatus {
 	return getServerInfo(c.database.gat)
 }
 
diff --git a/lib/gat/gatling/client/client.go b/lib/gat/gatling/client/client.go
index 79ab36d1..88763f86 100644
--- a/lib/gat/gatling/client/client.go
+++ b/lib/gat/gatling/client/client.go
@@ -68,6 +68,8 @@ type Client struct {
 
 	recv chan protocol.Packet
 
+	options []protocol.FieldsStartupMessageParameters
+
 	state gat.ClientState
 
 	pid       int32
@@ -97,6 +99,10 @@ type Client struct {
 	mu sync.Mutex
 }
 
+func (c *Client) GetOptions() []protocol.FieldsStartupMessageParameters {
+	return c.options
+}
+
 func (c *Client) GetState() gat.ClientState {
 	c.mu.Lock()
 	defer c.mu.Unlock()
@@ -257,14 +263,20 @@ func (c *Client) Accept(ctx context.Context) error {
 			}
 		}
 	}
-	params := make(map[string]string)
+	c.options = make([]protocol.FieldsStartupMessageParameters, 0, len(startup.Fields.Parameters))
 	for _, v := range startup.Fields.Parameters {
-		params[v.Name] = v.Value
+		switch v.Name {
+		case "":
+		case "database":
+			c.poolName = v.Value
+		case "user":
+			c.username = v.Value
+		default:
+			c.options = append(c.options, v)
+		}
 	}
 
-	var ok bool
-	c.poolName, ok = params["database"]
-	if !ok {
+	if c.poolName == "" {
 		return &pg_error.Error{
 			Severity: pg_error.Fatal,
 			Code:     pg_error.InvalidAuthorizationSpecification,
@@ -272,8 +284,7 @@ func (c *Client) Accept(ctx context.Context) error {
 		}
 	}
 
-	c.username, ok = params["user"]
-	if !ok {
+	if c.username == "" {
 		return &pg_error.Error{
 			Severity: pg_error.Fatal,
 			Code:     pg_error.InvalidAuthorizationSpecification,
@@ -372,7 +383,7 @@ func (c *Client) Accept(ctx context.Context) error {
 	}
 
 	//
-	info := c.server.GetServerInfo()
+	info := c.server.GetServerInfo(c)
 	for _, inf := range info {
 		err = c.Send(inf)
 		if err != nil {
diff --git a/lib/gat/gatling/server/server.go b/lib/gat/gatling/server/server.go
index fc0813d3..f3c5676f 100644
--- a/lib/gat/gatling/server/server.go
+++ b/lib/gat/gatling/server/server.go
@@ -33,6 +33,8 @@ type Server struct {
 	client gat.Client
 	state  gat.ConnectionState
 
+	options []protocol.FieldsStartupMessageParameters
+
 	serverInfo []*protocol.ParameterStatus
 
 	processId int32
@@ -57,13 +59,15 @@ type Server struct {
 	mu sync.Mutex
 }
 
-func Dial(ctx context.Context, user *config.User, shard *config.Shard, server *config.Server) (gat.Connection, error) {
+func Dial(ctx context.Context, options []protocol.FieldsStartupMessageParameters, user *config.User, shard *config.Shard, server *config.Server) (gat.Connection, error) {
 	s := &Server{
 		addr: server.Host,
 		port: server.Port,
 
 		state: gat.ConnectionNew,
 
+		options: options,
+
 		boundPreparedStatments: make(map[string]*protocol.Parse),
 		boundPortals:           make(map[string]*protocol.Bind),
 
@@ -230,17 +234,18 @@ func (s *Server) startup(ctx context.Context) error {
 	s.log.Debug().Msg("sending startup")
 	start := new(protocol.StartupMessage)
 	start.Fields.ProtocolVersionNumber = 196608
-	start.Fields.Parameters = []protocol.FieldsStartupMessageParameters{
-		{
+	start.Fields.Parameters = append(
+		s.options,
+		protocol.FieldsStartupMessageParameters{
 			Name:  "user",
 			Value: s.dbuser,
 		},
-		{
+		protocol.FieldsStartupMessageParameters{
 			Name:  "database",
 			Value: s.db,
 		},
-		{},
-	}
+		protocol.FieldsStartupMessageParameters{},
+	)
 	err := s.writePacket(start)
 	if err != nil {
 		return err
diff --git a/lib/gat/gatling/server/server_test.go b/lib/gat/gatling/server/server_test.go
index 3828290a..ab0463f4 100644
--- a/lib/gat/gatling/server/server_test.go
+++ b/lib/gat/gatling/server/server_test.go
@@ -23,7 +23,7 @@ var test_user = config.User{
 }
 
 func TestServerDial(t *testing.T) {
-	srv, err := Dial(context.TODO(), &test_user, &test_shard, &test_server)
+	srv, err := Dial(context.TODO(), nil, &test_user, &test_shard, &test_server)
 	if err != nil {
 		t.Error(err)
 	}
diff --git a/lib/gat/interfaces.go b/lib/gat/interfaces.go
index 6ded9dce..a8af8c35 100644
--- a/lib/gat/interfaces.go
+++ b/lib/gat/interfaces.go
@@ -23,6 +23,8 @@ const (
 type Client interface {
 	GetId() ClientID
 
+	GetOptions() []protocol.FieldsStartupMessageParameters
+
 	GetPreparedStatement(name string) *protocol.Parse
 	GetPortal(name string) *protocol.Bind
 	GetCurrentConn() Connection
@@ -78,7 +80,7 @@ type QueryRouter interface {
 
 type Pool interface {
 	GetUser() *config.User
-	GetServerInfo() []*protocol.ParameterStatus
+	GetServerInfo(client Client) []*protocol.ParameterStatus
 
 	GetDatabase() Database
 
@@ -106,7 +108,7 @@ const (
 	ConnectionNew                    = "new"
 )
 
-type Dialer = func(context.Context, *config.User, *config.Shard, *config.Server) (Connection, error)
+type Dialer = func(context.Context, []protocol.FieldsStartupMessageParameters, *config.User, *config.Shard, *config.Server) (Connection, error)
 
 type Connection interface {
 	GetServerInfo() []*protocol.ParameterStatus
diff --git a/lib/gat/pool/session/pool.go b/lib/gat/pool/session/pool.go
index 49354e30..24986a8a 100644
--- a/lib/gat/pool/session/pool.go
+++ b/lib/gat/pool/session/pool.go
@@ -35,13 +35,13 @@ func New(database gat.Database, dialer gat.Dialer, conf *config.Pool, user *conf
 	return p
 }
 
-func (p *Pool) getConnection() (gat.Connection, error) {
+func (p *Pool) getConnection(client gat.Client) (gat.Connection, error) {
 	select {
 	case c := <-p.servers:
 		return c, nil
 	default:
 		shard := p.c.Load().Shards[0]
-		return p.dialer(context.TODO(), p.user, shard, shard.Servers[0])
+		return p.dialer(context.TODO(), client.GetOptions(), p.user, shard, shard.Servers[0])
 	}
 }
 
@@ -53,7 +53,7 @@ func (p *Pool) getOrAssign(client gat.Client) (gat.Connection, error) {
 	cid := client.GetId()
 	c, ok := p.assigned.Load(cid)
 	if !ok {
-		get, err := p.getConnection()
+		get, err := p.getConnection(client)
 		if err != nil {
 			return nil, err
 		}
@@ -84,8 +84,8 @@ func (p *Pool) GetUser() *config.User {
 	return p.user
 }
 
-func (p *Pool) GetServerInfo() []*protocol.ParameterStatus {
-	c, err := p.getConnection()
+func (p *Pool) GetServerInfo(client gat.Client) []*protocol.ParameterStatus {
+	c, err := p.getConnection(client)
 	if err != nil {
 		return nil
 	}
diff --git a/lib/gat/pool/transaction/pool.go b/lib/gat/pool/transaction/pool.go
index ca14ae30..f3fe5776 100644
--- a/lib/gat/pool/transaction/pool.go
+++ b/lib/gat/pool/transaction/pool.go
@@ -72,8 +72,8 @@ func (c *Pool) GetUser() *config.User {
 	return c.user
 }
 
-func (c *Pool) GetServerInfo() []*protocol.ParameterStatus {
-	return c.getWorker().GetServerInfo()
+func (c *Pool) GetServerInfo(client gat.Client) []*protocol.ParameterStatus {
+	return c.getWorker().GetServerInfo(client)
 }
 
 func (c *Pool) Describe(ctx context.Context, client gat.Client, d *protocol.Describe) error {
diff --git a/lib/gat/pool/transaction/shard/shard.go b/lib/gat/pool/transaction/shard/shard.go
index 05f440a2..1acdae07 100644
--- a/lib/gat/pool/transaction/shard/shard.go
+++ b/lib/gat/pool/transaction/shard/shard.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"gfx.cafe/gfx/pggat/lib/config"
 	"gfx.cafe/gfx/pggat/lib/gat"
+	"gfx.cafe/gfx/pggat/lib/gat/protocol"
 	"math/rand"
 	"reflect"
 )
@@ -15,14 +16,18 @@ type Shard struct {
 	user *config.User
 	conf *config.Shard
 
+	options []protocol.FieldsStartupMessageParameters
+
 	dialer gat.Dialer
 }
 
-func FromConfig(dialer gat.Dialer, user *config.User, conf *config.Shard) *Shard {
+func FromConfig(dialer gat.Dialer, options []protocol.FieldsStartupMessageParameters, user *config.User, conf *config.Shard) *Shard {
 	out := &Shard{
 		user: user,
 		conf: conf,
 
+		options: options,
+
 		dialer: dialer,
 	}
 	out.init()
@@ -33,7 +38,7 @@ func (s *Shard) init() {
 	s.primary = nil
 	s.replicas = nil
 	for _, serv := range s.conf.Servers {
-		srv, err := s.dialer(context.TODO(), s.user, s.conf, serv)
+		srv, err := s.dialer(context.TODO(), s.options, s.user, s.conf, serv)
 		if err != nil {
 			continue
 		}
diff --git a/lib/gat/pool/transaction/worker.go b/lib/gat/pool/transaction/worker.go
index 92e5ae84..082e5a88 100644
--- a/lib/gat/pool/transaction/worker.go
+++ b/lib/gat/pool/transaction/worker.go
@@ -31,7 +31,7 @@ func (w *worker) ret() {
 }
 
 // attempt to connect to a new shard with this worker
-func (w *worker) fetchShard(n int) bool {
+func (w *worker) fetchShard(client gat.Client, n int) bool {
 	conf := w.w.c.Load()
 	if n < 0 || n >= len(conf.Shards) {
 		return false
@@ -41,7 +41,7 @@ func (w *worker) fetchShard(n int) bool {
 		w.shards = append(w.shards, nil)
 	}
 
-	w.shards[n] = shard.FromConfig(w.w.dialer, w.w.user, conf.Shards[n])
+	w.shards[n] = shard.FromConfig(w.w.dialer, client.GetOptions(), w.w.user, conf.Shards[n])
 	return true
 }
 
@@ -76,17 +76,17 @@ func (w *worker) chooseShard(client gat.Client) *shard.Shard {
 	}
 
 	// we need to fetch a shard
-	if w.fetchShard(preferred) {
+	if w.fetchShard(client, preferred) {
 		return w.shards[preferred]
 	}
 
 	return nil
 }
 
-func (w *worker) GetServerInfo() []*protocol.ParameterStatus {
+func (w *worker) GetServerInfo(client gat.Client) []*protocol.ParameterStatus {
 	defer w.ret()
 
-	s := w.chooseShard(nil)
+	s := w.chooseShard(client)
 	if s == nil {
 		return nil
 	}
diff --git a/lib/parse/parse.go b/lib/parse/parse.go
index 2af41fe8..ab766763 100644
--- a/lib/parse/parse.go
+++ b/lib/parse/parse.go
@@ -270,7 +270,7 @@ func (r *reader) nextCommand() (cmd Command, err error) {
 
 // Parse parses an sql query in a single pass (with no look aheads or look behinds).
 // Because all we really care about is the commands, this can be very fast
-// based on https://www.postgresql.org/docs/current/sql-syntax-lexical.html
+// based on https://www.postgresql.org/docs/14/sql-syntax-lexical.html
 func Parse(sql string) (cmds []Command, err error) {
 	r := reader{
 		v: sql,
-- 
GitLab