diff --git a/lib/config/config.go b/lib/config/config.go index 778ada2ed9c47a35e2bfafaa8d4b3e06c510e67f..71be384900afd34f5bf668c3e1bc79e4adcfa628 100644 --- a/lib/config/config.go +++ b/lib/config/config.go @@ -64,7 +64,7 @@ type Pool struct { } type User struct { - Name string `toml:"name" yaml:"name" json:"name"` + Name string `toml:"username" yaml:"name" json:"name"` Password string `toml:"password" yaml:"password" json:"password"` PoolSize int `toml:"pool_size" yaml:"pool_size" json:"pool_size"` StatementTimeout int `toml:"statement_timeout" yaml:"statement_timeout" json:"statement_timeout"` diff --git a/lib/gat/admin.go b/lib/gat/admin.go index bbaae3134b3ee93cfb7538038ecb0b7b0d240e99..b98aa4b380aab6159ad29299207b5b856fea61e5 100644 --- a/lib/gat/admin.go +++ b/lib/gat/admin.go @@ -1,17 +1,50 @@ package gat -import "bytes" +import ( + "gfx.cafe/gfx/pggat/lib/gat/protocol" +) const SERVER_VERSION = "0.0.1" -func AdminServerInfo() []byte { - buf := new(bytes.Buffer) - buf.Write(ServerParameterMessage("application_name", "")) - buf.Write(ServerParameterMessage("client_encoding", "UTF8")) - buf.Write(ServerParameterMessage("server_encoding", "UTF8")) - buf.Write(ServerParameterMessage("server_version", SERVER_VERSION)) - buf.Write(ServerParameterMessage("DateStyle", "ISO, MDY")) - return buf.Bytes() +func AdminServerInfo() []*protocol.ParameterStatus { + return []*protocol.ParameterStatus{ + { + Fields: protocol.FieldsParameterStatus{ + Parameter: "application_name", + Value: "", + }, + }, + { + Fields: protocol.FieldsParameterStatus{ + Parameter: "client_encoding", + Value: "UTF8", + }, + }, + { + Fields: protocol.FieldsParameterStatus{ + Parameter: "server_encoding", + Value: "UTF8", + }, + }, + { + Fields: protocol.FieldsParameterStatus{ + Parameter: "server_encoding", + Value: "UTF8", + }, + }, + { + Fields: protocol.FieldsParameterStatus{ + Parameter: "server_version", + Value: SERVER_VERSION, + }, + }, + { + Fields: protocol.FieldsParameterStatus{ + Parameter: "DataStyle", + Value: "ISO, MDY", + }, + }, + } } ///// Handle admin client. diff --git a/lib/gat/client.go b/lib/gat/client.go index 60ed907015f5d202310ec0a8ffeadd1f3a3f5e98..05cee2ce9c35684279e120ce49e2676fa7597f04 100644 --- a/lib/gat/client.go +++ b/lib/gat/client.go @@ -7,6 +7,7 @@ import ( "crypto/rand" "crypto/tls" "fmt" + "gfx.cafe/gfx/pggat/lib/util/maps" "io" "math/big" "net" @@ -69,7 +70,7 @@ type Client struct { stats any // TODO: Reporter admin bool - server_info []byte + server_info []*protocol.ParameterStatus last_addr_id int last_srv_id int @@ -237,15 +238,26 @@ func (c *Client) Accept(ctx context.Context) error { } } } else { - // TODO: actually get a server pool c.server_info = AdminServerInfo() - pool := ServerPool{ - user: config.User{ - Name: "postgres", - Password: "postgres", - }, + pool, ok := c.conf.Pools[c.pool_name] + if !ok { + return &PostgresError{ + Severity: Fatal, + Code: InvalidAuthorizationSpecification, + Message: "no such pool", + } + } + _, user, ok := maps.FirstWhere(pool.Users, func(_ string, user config.User) bool { + return user.Name == c.username + }) + if !ok { + return &PostgresError{ + Severity: Fatal, + Code: InvalidPassword, + Message: "invalid password", + } } - pw_hash := Md5HashPassword(c.username, pool.user.Password, salt[:]) + pw_hash := Md5HashPassword(c.username, user.Password, salt[:]) if !reflect.DeepEqual(pw_hash, passwordResponse) { return &PostgresError{ Severity: Fatal, @@ -263,9 +275,11 @@ func (c *Client) Accept(ctx context.Context) error { } // - _, err = c.wr.Write(c.server_info) - if err != nil { - return err + for _, inf := range c.server_info { + _, err = inf.Write(c.wr) + if err != nil { + return err + } } backendKeyData := new(protocol.BackendKeyData) backendKeyData.Fields.ProcessID = c.pid @@ -303,7 +317,7 @@ func (c *Client) tick(ctx context.Context) (bool, error) { if err != nil { return true, err } - log.Printf("%T %+v", rsp, rsp) + log.Printf("%#v", rsp, rsp) switch cast := rsp.(type) { case *protocol.Describe: case *protocol.Query: diff --git a/lib/gat/gatling.go b/lib/gat/gatling.go index 52bcccc854a31ce05ad99aadcedb952981645563..eaa015a36b811dcc94de7ca75a32b2678ad1f22f 100644 --- a/lib/gat/gatling.go +++ b/lib/gat/gatling.go @@ -39,7 +39,8 @@ func (g *Gatling) ListenAndServe(ctx context.Context) error { return err } for { - c, err := ln.Accept() + var c net.Conn + c, err = ln.Accept() if err != nil { return err } diff --git a/lib/gat/messages.go b/lib/gat/messages.go index 860d3d5106561f7f11e03aba342667878cca6693..88b56701446e9b9921dd7cbaf94d596ffc10270a 100644 --- a/lib/gat/messages.go +++ b/lib/gat/messages.go @@ -1,7 +1,6 @@ package gat import ( - "bytes" "crypto/md5" "crypto/rand" "encoding/hex" @@ -315,12 +314,3 @@ func Md5HashPassword(user string, password string, salt []byte) []byte { // // Ok(bytes) // } -func ServerParameterMessage(key, value string) []byte { - buf := new(bytes.Buffer) - pkt := new(protocol.ParameterStatus) - pkt.Fields.Parameter = key - pkt.Fields.Value = value - _, _ = pkt.Write(buf) - - return buf.Bytes() -} diff --git a/lib/gat/server.go b/lib/gat/server.go index f59a3f370411e84d6980407fa9719531e7d44074..9146bfc66e80846df6c5e03431147d0840b88101 100644 --- a/lib/gat/server.go +++ b/lib/gat/server.go @@ -6,7 +6,7 @@ import ( "encoding/binary" "fmt" "gfx.cafe/gfx/pggat/lib/gat/protocol" - "gfx.cafe/gfx/pggat/lib/util" + "gfx.cafe/gfx/pggat/lib/util/slices" "gfx.cafe/util/go/bufpool" "io" "net" @@ -111,7 +111,6 @@ func (s *Server) connect(ctx context.Context) error { if err != nil { return err } - log.Printf("pkt %T %+v", pkt, pkt) switch p := pkt.(type) { case *protocol.Authentication: switch p.Fields.Code { @@ -120,7 +119,7 @@ func (s *Server) connect(ctx context.Context) error { case 0: // AUTH SUCCESS case 10: // SASL s.log.Debug().Msg("starting sasl auth") - if util.Contains(p.Fields.SASLMechanism, scram.SHA256.Name()) { + if slices.Contains(p.Fields.SASLMechanism, scram.SHA256.Name()) { s.log.Debug().Str("method", "scram256").Msg("valid protocol") } else { return fmt.Errorf("unsupported scram version: %s", p.Fields.SASLMechanism) diff --git a/lib/util/maps/maps.go b/lib/util/maps/maps.go new file mode 100644 index 0000000000000000000000000000000000000000..95e2f7953cd9a7c640f97fffd3a467ef14f0ab63 --- /dev/null +++ b/lib/util/maps/maps.go @@ -0,0 +1,12 @@ +package maps + +func FirstWhere[K comparable, V any](haystack map[K]V, predicate func(K, V) bool) (K, V, bool) { + for k, v := range haystack { + if predicate(k, v) { + return k, v, true + } + } + var k K + var v V + return k, v, false +} diff --git a/lib/util/slice.go b/lib/util/slices/slices.go similarity index 90% rename from lib/util/slice.go rename to lib/util/slices/slices.go index 45ac74d55e77dadea1602d0869d69f01bbbd45d1..50c2902d81f023c9e0527b8e778703d91404703c 100644 --- a/lib/util/slice.go +++ b/lib/util/slices/slices.go @@ -1,4 +1,4 @@ -package util +package slices func Contains[T comparable](haystack []T, needle T) bool { for _, v := range haystack {