diff --git a/lib/auth/credentials/credentials_test.go b/lib/auth/credentials/credentials_test.go new file mode 100644 index 0000000000000000000000000000000000000000..aab8a815dd54f1f4f236e2a384dafbd4b945201e --- /dev/null +++ b/lib/auth/credentials/credentials_test.go @@ -0,0 +1,28 @@ +package credentials + +import ( + "crypto/rand" + "testing" + + "pggat2/lib/auth" +) + +func TestMD5(t *testing.T) { + pw := FromString("bob", "jNKuKKlBDO48qbLiVw7IuoaamZ1SmHAUdQ9PKH7qRzsyJVF0BNPSFMbHTQwxe0HJ") + md5 := FromString("bob", "md5e20510fd38e1c0fd99db13da5c29bd95") + + pwMD5 := pw.(auth.MD5) + md5MD5 := md5.(auth.MD5) + + var salt [4]byte + _, err := rand.Read(salt[:]) + if err != nil { + t.Error(err) + return + } + + err = md5MD5.VerifyMD5(salt, pwMD5.EncodeMD5(salt)) + if err != nil { + t.Error(err) + } +} diff --git a/lib/auth/credentials/string.go b/lib/auth/credentials/string.go new file mode 100644 index 0000000000000000000000000000000000000000..78b57243173b8d21a8f182e1b5410c6932f69e93 --- /dev/null +++ b/lib/auth/credentials/string.go @@ -0,0 +1,32 @@ +package credentials + +import ( + "encoding/hex" + "strings" + + "pggat2/lib/auth" +) + +func FromString(user, password string) auth.Credentials { + if password == "" { + return nil + } else if strings.HasPrefix(password, "md5") { + hexHash := strings.TrimPrefix(password, "md5") + hash, err := hex.DecodeString(hexHash) + if err != nil { + return Cleartext{ + Username: user, + Password: password, + } + } + return MD5{ + Username: user, + Hash: hash, + } + } else { + return Cleartext{ + Username: user, + Password: password, // TODO(garet) sasl + } + } +} diff --git a/lib/bouncer/frontends/v0/authenticate.go b/lib/bouncer/frontends/v0/authenticate.go index f63f35e179bebe0bc2717ad9f58d5a7db611af20..ef55450a9dc7d020ab59c574605dde82484c586c 100644 --- a/lib/bouncer/frontends/v0/authenticate.go +++ b/lib/bouncer/frontends/v0/authenticate.go @@ -104,6 +104,38 @@ func authenticationSASL(client fed.Conn, creds auth.SASL) perror.Error { return nil } +func authenticationMD5(client fed.Conn, creds auth.MD5) perror.Error { + var salt [4]byte + _, err := rand.Read(salt[:]) + if err != nil { + return perror.Wrap(err) + } + md5Initial := packets.AuthenticationMD5{ + Salt: salt, + } + err = client.WritePacket(md5Initial.IntoPacket()) + if err != nil { + return perror.Wrap(err) + } + + var packet fed.Packet + packet, err = client.ReadPacket(true) + if err != nil { + return perror.Wrap(err) + } + + var pw packets.PasswordMessage + if !pw.ReadFromPacket(packet) { + return packets.ErrUnexpectedPacket + } + + if err = creds.VerifyMD5(salt, pw.Password); err != nil { + return perror.Wrap(err) + } + + return nil +} + func updateParameter(client fed.Conn, name, value string) perror.Error { ps := packets.ParameterStatus{ Key: name, @@ -123,6 +155,8 @@ func authenticate(client fed.Conn, options AuthenticateOptions) (params Authenti } if credsSASL, ok := options.Credentials.(auth.SASL); ok { err = authenticationSASL(client, credsSASL) + } else if credsMD5, ok := options.Credentials.(auth.MD5); ok { + err = authenticationMD5(client, credsMD5) } else { err = perror.New( perror.FATAL, diff --git a/lib/gat/acceptor.go b/lib/gat/acceptor.go index 55c05b1770dcdca5ed39aa99ddd61b85ad3859fb..c41574ba2c098144ca3b2d2529b4440800d4177b 100644 --- a/lib/gat/acceptor.go +++ b/lib/gat/acceptor.go @@ -3,7 +3,6 @@ package gat import ( "net" - "pggat2/lib/auth" "pggat2/lib/bouncer/frontends/v0" "pggat2/lib/fed" ) @@ -54,22 +53,17 @@ func serve(client fed.Conn, acceptParams frontends.AcceptParams, pools Pools) er p := pools.Lookup(acceptParams.User, acceptParams.Database) - var credentials auth.Credentials - if p != nil { - credentials = p.GetCredentials() + if p == nil { + return nil } authParams, err := frontends.Authenticate(client, frontends.AuthenticateOptions{ - Credentials: credentials, + Credentials: p.GetCredentials(), }) if err != nil { return err } - if p == nil { - return nil - } - pools.RegisterKey(authParams.BackendKey, acceptParams.User, acceptParams.Database) defer pools.UnregisterKey(authParams.BackendKey) diff --git a/lib/gat/modes/pgbouncer/pools.go b/lib/gat/modes/pgbouncer/pools.go index 6f34e8b44d54565b5d0bddfad615cb87a322230a..8fe62276e886acf112849e89c35eee699271f8e2 100644 --- a/lib/gat/modes/pgbouncer/pools.go +++ b/lib/gat/modes/pgbouncer/pools.go @@ -2,6 +2,7 @@ package pgbouncer import ( "crypto/tls" + "errors" "net" "strconv" "time" @@ -21,8 +22,8 @@ import ( ) type authQueryResult struct { - Username string `ini:"usename"` - Password *string `ini:"passwd"` + Username string `sql:"0"` + Password *string `sql:"1"` } type poolKey struct { @@ -106,20 +107,23 @@ func (T *Pools) Lookup(user, database string) *pool.Pool { return nil } err = authPool.Serve(client, frontends.AcceptParams{}, frontends.AuthenticateParams{}) - if err != nil { + if err != nil && !errors.Is(err, net.ErrClosed) { log.Println("auth query failed:", err) return nil } + if result.Username != user { + // user not found + return nil + } + if result.Password != nil { password = *result.Password + ok = true } } - creds := credentials.Cleartext{ - Username: user, - Password: password, // TODO(garet) md5 and sasl - } + creds := credentials.FromString(user, password) backendDatabase := db.DBName if backendDatabase == "" { @@ -182,7 +186,7 @@ func (T *Pools) Lookup(user, database string) *pool.Pool { creds := creds if db.Password != "" { // lookup password - creds.Password = db.Password + creds = credentials.FromString(user, db.Password) } // connect over tcp diff --git a/lib/gsql/query_test.go b/lib/gsql/query_test.go index 21b9ad0a334039e550a16c2711b8c4cb232a6eab..ca5947e5d3062960c6eba28373590c48de5e3ee5 100644 --- a/lib/gsql/query_test.go +++ b/lib/gsql/query_test.go @@ -13,8 +13,8 @@ import ( ) type Result struct { - Username string `sql:"usename"` - Password *string `sql:"passwd"` + Username string `sql:"0"` + Password *string `sql:"1"` } func TestQuery(t *testing.T) { @@ -39,7 +39,7 @@ func TestQuery(t *testing.T) { var res Result client := new(Client) - err = client.ExtendedQuery(&res, "SELECT $1 as usename, $2 as passwd", "username", "test") + err = client.ExtendedQuery(&res, "SELECT usename, passwd FROM pg_shadow WHERE usename=$1", "bob") if err != nil { t.Error(err) return diff --git a/lib/gsql/row.go b/lib/gsql/row.go index 1ab445fdb3d8d42800e7537480bbc9198a90f70e..9ca877fb67f0129470ea552e8e4a5a96b764b441 100644 --- a/lib/gsql/row.go +++ b/lib/gsql/row.go @@ -108,6 +108,15 @@ outer2: result = result.Field(j) break outer2 } + + // handle `sql:"3"` + sqlNameIndex, err := strconv.Atoi(sqlName) + if err == nil { + if sqlNameIndex == i { + result = result.Field(j) + break outer2 + } + } } // ignore field diff --git a/pgbouncer.ini b/pgbouncer.ini index 8a6df6f093193bf7747a47a9a6957b413fe41c1d..ee698d0f7ecb85bd8a095e52562ef7924fbcffb5 100644 --- a/pgbouncer.ini +++ b/pgbouncer.ini @@ -3,9 +3,7 @@ pool_mode = transaction auth_file = userlist.txt listen_addr = * track_extra_parameters = IntervalStyle, session_authorization, default_transaction_read_only, search_path - -[users] -postgres = +auth_user = postgres [databases] * = host=localhost datestyle=Postgres,MDY timezone=PST8PDT