From 202092d8350d93e67cee9b1f07ad9f64e51eef92 Mon Sep 17 00:00:00 2001 From: Garet Halliday <me@garet.holiday> Date: Wed, 23 Aug 2023 15:03:05 -0500 Subject: [PATCH] more ssl --- lib/bouncer/backends/v0/accept.go | 22 ++++++++++++---------- lib/bouncer/backends/v0/options.go | 3 +++ lib/bouncer/conn.go | 1 + lib/bouncer/frontends/v0/accept.go | 19 ++++++++++++++++++- lib/bouncer/frontends/v0/options.go | 4 ++++ lib/bouncer/sslmode.go | 11 ++++++++++- lib/gat/pooler.go | 3 +++ lib/gat/recipe.go | 8 ++++++++ lib/middleware/interceptor/interceptor.go | 10 ++++++++-- lib/zap/conn.go | 15 ++++++++------- lib/zap/readwriter.go | 8 ++++++-- 11 files changed, 81 insertions(+), 23 deletions(-) diff --git a/lib/bouncer/backends/v0/accept.go b/lib/bouncer/backends/v0/accept.go index de4f4d1d..b1e084c0 100644 --- a/lib/bouncer/backends/v0/accept.go +++ b/lib/bouncer/backends/v0/accept.go @@ -1,6 +1,7 @@ package backends import ( + "crypto/tls" "errors" "pggat2/lib/auth" @@ -245,7 +246,7 @@ func startup1(conn *bouncer.Conn) (done bool, err error) { } } -func enableSSL(server zap.ReadWriter) (bool, error) { +func enableSSL(server zap.ReadWriter, config *tls.Config) (bool, error) { packet := zap.NewUntypedPacket() defer packet.Done() packet.WriteUint16(1234) @@ -265,7 +266,7 @@ func enableSSL(server zap.ReadWriter) (bool, error) { return false, nil } - if err = server.EnableSSL(true); err != nil { + if err = server.EnableSSLClient(config); err != nil { return false, err } @@ -279,12 +280,19 @@ func Accept(server zap.ReadWriter, options AcceptOptions) (bouncer.Conn, error) options.Database = username } + conn := bouncer.Conn{ + RW: server, + User: username, + Database: options.Database, + } + if options.SSLMode.ShouldAttempt() { - ok, err := enableSSL(server) + var err error + conn.SSLEnabled, err = enableSSL(server, options.SSLConfig) if err != nil { return bouncer.Conn{}, err } - if !ok && options.SSLMode.IsRequired() { + if !conn.SSLEnabled && options.SSLMode.IsRequired() { return bouncer.Conn{}, errors.New("server rejected SSL encryption") } } @@ -320,12 +328,6 @@ func Accept(server zap.ReadWriter, options AcceptOptions) (bouncer.Conn, error) } } - conn := bouncer.Conn{ - RW: server, - User: username, - Database: options.Database, - } - for { var done bool done, err = startup1(&conn) diff --git a/lib/bouncer/backends/v0/options.go b/lib/bouncer/backends/v0/options.go index 88f073b4..7d2b7619 100644 --- a/lib/bouncer/backends/v0/options.go +++ b/lib/bouncer/backends/v0/options.go @@ -1,6 +1,8 @@ package backends import ( + "crypto/tls" + "pggat2/lib/auth" "pggat2/lib/bouncer" "pggat2/lib/util/strutil" @@ -8,6 +10,7 @@ import ( type AcceptOptions struct { SSLMode bouncer.SSLMode + SSLConfig *tls.Config Credentials auth.Credentials Database string StartupParameters map[strutil.CIString]string diff --git a/lib/bouncer/conn.go b/lib/bouncer/conn.go index b469df5d..6ccff857 100644 --- a/lib/bouncer/conn.go +++ b/lib/bouncer/conn.go @@ -8,6 +8,7 @@ import ( type Conn struct { RW zap.ReadWriter + SSLEnabled bool User string Database string InitialParameters map[strutil.CIString]string diff --git a/lib/bouncer/frontends/v0/accept.go b/lib/bouncer/frontends/v0/accept.go index be247203..9e364366 100644 --- a/lib/bouncer/frontends/v0/accept.go +++ b/lib/bouncer/frontends/v0/accept.go @@ -57,12 +57,20 @@ func startup0( ) return case 5679: + // ssl is not enabled + if options.SSLConfig == nil { + err = perror.Wrap(client.RW.WriteByte('N')) + return + } + + // do ssl if err = perror.Wrap(client.RW.WriteByte('S')); err != nil { return } - if err = perror.Wrap(client.RW.EnableSSL(false)); err != nil { + if err = perror.Wrap(client.RW.EnableSSLServer(options.SSLConfig)); err != nil { return } + client.SSLEnabled = true return case 5680: // GSSAPI is not supported yet @@ -332,6 +340,15 @@ func accept( } } + if options.SSLRequired && !conn.SSLEnabled { + err = perror.New( + perror.FATAL, + perror.InvalidPassword, + "SSL is required", + ) + return + } + creds := options.Pooler.GetUserCredentials(conn.User, conn.Database) if creds == nil { err = perror.New( diff --git a/lib/bouncer/frontends/v0/options.go b/lib/bouncer/frontends/v0/options.go index f3241fd5..26a77ace 100644 --- a/lib/bouncer/frontends/v0/options.go +++ b/lib/bouncer/frontends/v0/options.go @@ -1,11 +1,15 @@ package frontends import ( + "crypto/tls" + "pggat2/lib/bouncer" "pggat2/lib/util/strutil" ) type AcceptOptions struct { + SSLRequired bool + SSLConfig *tls.Config Pooler bouncer.Pooler AllowedStartupOptions []strutil.CIString } diff --git a/lib/bouncer/sslmode.go b/lib/bouncer/sslmode.go index d8c8bdf1..abcd2be9 100644 --- a/lib/bouncer/sslmode.go +++ b/lib/bouncer/sslmode.go @@ -22,9 +22,18 @@ func (T SSLMode) ShouldAttempt() bool { func (T SSLMode) IsRequired() bool { switch T { - case SSLModeDisable, SSLModeAllow, SSLModeRequire: + case SSLModeDisable, SSLModeAllow, SSLModeRequire, "": return false default: return true } } + +func (T SSLMode) VerifyCertificates() bool { + switch T { + case SSLModeVerifyCa, SSLModeVerifyFull: + return true + default: + return false + } +} diff --git a/lib/gat/pooler.go b/lib/gat/pooler.go index 1788bbc9..ee3a1480 100644 --- a/lib/gat/pooler.go +++ b/lib/gat/pooler.go @@ -25,6 +25,7 @@ type Pooler struct { type PoolerConfig struct { AllowedStartupParameters []strutil.CIString + SSLMode bouncer.SSLMode } func NewPooler(config PoolerConfig) *Pooler { @@ -84,6 +85,8 @@ func (T *Pooler) Serve(client zap.ReadWriter) { conn, err := frontends.Accept( client, frontends.AcceptOptions{ + SSLRequired: T.config.SSLMode.IsRequired(), + // TODO(garet) SSL Config Pooler: T, AllowedStartupOptions: T.config.AllowedStartupParameters, }, diff --git a/lib/gat/recipe.go b/lib/gat/recipe.go index 650bd603..58f41b88 100644 --- a/lib/gat/recipe.go +++ b/lib/gat/recipe.go @@ -1,6 +1,7 @@ package gat import ( + "crypto/tls" "net" "pggat2/lib/auth" @@ -28,6 +29,8 @@ type TCPRecipe struct { MinConnections int MaxConnections int + SSLMode bouncer.SSLMode + StartupParameters map[strutil.CIString]string } @@ -47,6 +50,11 @@ func (T TCPRecipe) Connect() (bouncer.Conn, error) { } server, err := backends.Accept(rw, backends.AcceptOptions{ + SSLMode: T.SSLMode, + SSLConfig: &tls.Config{ + // TODO(garet) SSL certificates if they need to be verified + InsecureSkipVerify: !T.SSLMode.VerifyCertificates(), + }, Credentials: T.Credentials, Database: T.Database, StartupParameters: T.StartupParameters, diff --git a/lib/middleware/interceptor/interceptor.go b/lib/middleware/interceptor/interceptor.go index b6bef931..bca6ad34 100644 --- a/lib/middleware/interceptor/interceptor.go +++ b/lib/middleware/interceptor/interceptor.go @@ -1,6 +1,8 @@ package interceptor import ( + "crypto/tls" + "pggat2/lib/middleware" "pggat2/lib/zap" ) @@ -23,8 +25,12 @@ func NewInterceptor(rw zap.ReadWriter, middlewares ...middleware.Middleware) *In } } -func (T *Interceptor) EnableSSL(client bool) error { - return T.rw.EnableSSL(client) +func (T *Interceptor) EnableSSLClient(config *tls.Config) error { + return T.rw.EnableSSLClient(config) +} + +func (T *Interceptor) EnableSSLServer(config *tls.Config) error { + return T.rw.EnableSSLServer(config) } func (T *Interceptor) ReadByte() (byte, error) { diff --git a/lib/zap/conn.go b/lib/zap/conn.go index 41e08b1d..515900ff 100644 --- a/lib/zap/conn.go +++ b/lib/zap/conn.go @@ -17,13 +17,14 @@ func WrapNetConn(conn net.Conn) *Conn { } } -func (T *Conn) EnableSSL(client bool) error { - var sslConn *tls.Conn - if client { - sslConn = tls.Client(T.conn, nil) - } else { - sslConn = tls.Server(T.conn, nil) - } +func (T *Conn) EnableSSLClient(config *tls.Config) error { + sslConn := tls.Client(T.conn, config) + T.conn = sslConn + return sslConn.Handshake() +} + +func (T *Conn) EnableSSLServer(config *tls.Config) error { + sslConn := tls.Server(T.conn, config) T.conn = sslConn return sslConn.Handshake() } diff --git a/lib/zap/readwriter.go b/lib/zap/readwriter.go index 7e602d84..74350fa4 100644 --- a/lib/zap/readwriter.go +++ b/lib/zap/readwriter.go @@ -1,13 +1,17 @@ package zap -import "io" +import ( + "crypto/tls" + "io" +) type ReadWriter interface { io.ByteReader io.ByteWriter io.Closer - EnableSSL(client bool) error + EnableSSLClient(config *tls.Config) error + EnableSSLServer(config *tls.Config) error Read(*Packet) error ReadUntyped(*UntypedPacket) error -- GitLab