diff --git a/lib/bouncer/backends/v0/accept.go b/lib/bouncer/backends/v0/accept.go index de4f4d1dd2187cb1bbba11c97fa656393935f467..b1e084c08b0c8049dbcb9def8da098a06467abc0 100644 --- a/lib/bouncer/backends/v0/accept.go +++ b/lib/bouncer/backends/v0/accept.go @@ -1,6 +1,7 @@ package backends import ( + "crypto/tls" "errors" "pggat2/lib/auth" @@ -245,7 +246,7 @@ func startup1(conn *bouncer.Conn) (done bool, err error) { } } -func enableSSL(server zap.ReadWriter) (bool, error) { +func enableSSL(server zap.ReadWriter, config *tls.Config) (bool, error) { packet := zap.NewUntypedPacket() defer packet.Done() packet.WriteUint16(1234) @@ -265,7 +266,7 @@ func enableSSL(server zap.ReadWriter) (bool, error) { return false, nil } - if err = server.EnableSSL(true); err != nil { + if err = server.EnableSSLClient(config); err != nil { return false, err } @@ -279,12 +280,19 @@ func Accept(server zap.ReadWriter, options AcceptOptions) (bouncer.Conn, error) options.Database = username } + conn := bouncer.Conn{ + RW: server, + User: username, + Database: options.Database, + } + if options.SSLMode.ShouldAttempt() { - ok, err := enableSSL(server) + var err error + conn.SSLEnabled, err = enableSSL(server, options.SSLConfig) if err != nil { return bouncer.Conn{}, err } - if !ok && options.SSLMode.IsRequired() { + if !conn.SSLEnabled && options.SSLMode.IsRequired() { return bouncer.Conn{}, errors.New("server rejected SSL encryption") } } @@ -320,12 +328,6 @@ func Accept(server zap.ReadWriter, options AcceptOptions) (bouncer.Conn, error) } } - conn := bouncer.Conn{ - RW: server, - User: username, - Database: options.Database, - } - for { var done bool done, err = startup1(&conn) diff --git a/lib/bouncer/backends/v0/options.go b/lib/bouncer/backends/v0/options.go index 88f073b49b06e46956e83608d22d2af9b5f9c5ff..7d2b7619d4d51fcad535b3f275d71fe86fbc6108 100644 --- a/lib/bouncer/backends/v0/options.go +++ b/lib/bouncer/backends/v0/options.go @@ -1,6 +1,8 @@ package backends import ( + "crypto/tls" + "pggat2/lib/auth" "pggat2/lib/bouncer" "pggat2/lib/util/strutil" @@ -8,6 +10,7 @@ import ( type AcceptOptions struct { SSLMode bouncer.SSLMode + SSLConfig *tls.Config Credentials auth.Credentials Database string StartupParameters map[strutil.CIString]string diff --git a/lib/bouncer/conn.go b/lib/bouncer/conn.go index b469df5dae922ba69d1ec9b330d02a20fb2a1ec4..6ccff8573ac6f05b2be1ebbc231ad9e8a97b3c3e 100644 --- a/lib/bouncer/conn.go +++ b/lib/bouncer/conn.go @@ -8,6 +8,7 @@ import ( type Conn struct { RW zap.ReadWriter + SSLEnabled bool User string Database string InitialParameters map[strutil.CIString]string diff --git a/lib/bouncer/frontends/v0/accept.go b/lib/bouncer/frontends/v0/accept.go index be24720330af8b0ebd18ede68660aefaac7bf2f6..9e364366c23a2bb4f82ebfaa88f1222d0f51115b 100644 --- a/lib/bouncer/frontends/v0/accept.go +++ b/lib/bouncer/frontends/v0/accept.go @@ -57,12 +57,20 @@ func startup0( ) return case 5679: + // ssl is not enabled + if options.SSLConfig == nil { + err = perror.Wrap(client.RW.WriteByte('N')) + return + } + + // do ssl if err = perror.Wrap(client.RW.WriteByte('S')); err != nil { return } - if err = perror.Wrap(client.RW.EnableSSL(false)); err != nil { + if err = perror.Wrap(client.RW.EnableSSLServer(options.SSLConfig)); err != nil { return } + client.SSLEnabled = true return case 5680: // GSSAPI is not supported yet @@ -332,6 +340,15 @@ func accept( } } + if options.SSLRequired && !conn.SSLEnabled { + err = perror.New( + perror.FATAL, + perror.InvalidPassword, + "SSL is required", + ) + return + } + creds := options.Pooler.GetUserCredentials(conn.User, conn.Database) if creds == nil { err = perror.New( diff --git a/lib/bouncer/frontends/v0/options.go b/lib/bouncer/frontends/v0/options.go index f3241fd5b70a1672adf2405f58c91ca410f93bf3..26a77ace46cbcc1ac466b76ce34dbafb3b592afc 100644 --- a/lib/bouncer/frontends/v0/options.go +++ b/lib/bouncer/frontends/v0/options.go @@ -1,11 +1,15 @@ package frontends import ( + "crypto/tls" + "pggat2/lib/bouncer" "pggat2/lib/util/strutil" ) type AcceptOptions struct { + SSLRequired bool + SSLConfig *tls.Config Pooler bouncer.Pooler AllowedStartupOptions []strutil.CIString } diff --git a/lib/bouncer/sslmode.go b/lib/bouncer/sslmode.go index d8c8bdf185b71b50bf878be6f1d0122d3ff3a2b2..abcd2be924b698a35f4581bd1d55399c2ff4b9c0 100644 --- a/lib/bouncer/sslmode.go +++ b/lib/bouncer/sslmode.go @@ -22,9 +22,18 @@ func (T SSLMode) ShouldAttempt() bool { func (T SSLMode) IsRequired() bool { switch T { - case SSLModeDisable, SSLModeAllow, SSLModeRequire: + case SSLModeDisable, SSLModeAllow, SSLModeRequire, "": return false default: return true } } + +func (T SSLMode) VerifyCertificates() bool { + switch T { + case SSLModeVerifyCa, SSLModeVerifyFull: + return true + default: + return false + } +} diff --git a/lib/gat/pooler.go b/lib/gat/pooler.go index 1788bbc9d72ae7f06b7e32c0dd7f830452601bdc..ee3a14806dbf0c11e1260f0d5a01b19a94925ef6 100644 --- a/lib/gat/pooler.go +++ b/lib/gat/pooler.go @@ -25,6 +25,7 @@ type Pooler struct { type PoolerConfig struct { AllowedStartupParameters []strutil.CIString + SSLMode bouncer.SSLMode } func NewPooler(config PoolerConfig) *Pooler { @@ -84,6 +85,8 @@ func (T *Pooler) Serve(client zap.ReadWriter) { conn, err := frontends.Accept( client, frontends.AcceptOptions{ + SSLRequired: T.config.SSLMode.IsRequired(), + // TODO(garet) SSL Config Pooler: T, AllowedStartupOptions: T.config.AllowedStartupParameters, }, diff --git a/lib/gat/recipe.go b/lib/gat/recipe.go index 650bd6034ad4f6a519333ceaa4d8f4c2c86eff88..58f41b88eb02efaf95179b0291665b7e485349a0 100644 --- a/lib/gat/recipe.go +++ b/lib/gat/recipe.go @@ -1,6 +1,7 @@ package gat import ( + "crypto/tls" "net" "pggat2/lib/auth" @@ -28,6 +29,8 @@ type TCPRecipe struct { MinConnections int MaxConnections int + SSLMode bouncer.SSLMode + StartupParameters map[strutil.CIString]string } @@ -47,6 +50,11 @@ func (T TCPRecipe) Connect() (bouncer.Conn, error) { } server, err := backends.Accept(rw, backends.AcceptOptions{ + SSLMode: T.SSLMode, + SSLConfig: &tls.Config{ + // TODO(garet) SSL certificates if they need to be verified + InsecureSkipVerify: !T.SSLMode.VerifyCertificates(), + }, Credentials: T.Credentials, Database: T.Database, StartupParameters: T.StartupParameters, diff --git a/lib/middleware/interceptor/interceptor.go b/lib/middleware/interceptor/interceptor.go index b6bef93153baac0c9f88087cccd0b59ac6a24a19..bca6ad34592b2cf077ae12293d769147c66f71f7 100644 --- a/lib/middleware/interceptor/interceptor.go +++ b/lib/middleware/interceptor/interceptor.go @@ -1,6 +1,8 @@ package interceptor import ( + "crypto/tls" + "pggat2/lib/middleware" "pggat2/lib/zap" ) @@ -23,8 +25,12 @@ func NewInterceptor(rw zap.ReadWriter, middlewares ...middleware.Middleware) *In } } -func (T *Interceptor) EnableSSL(client bool) error { - return T.rw.EnableSSL(client) +func (T *Interceptor) EnableSSLClient(config *tls.Config) error { + return T.rw.EnableSSLClient(config) +} + +func (T *Interceptor) EnableSSLServer(config *tls.Config) error { + return T.rw.EnableSSLServer(config) } func (T *Interceptor) ReadByte() (byte, error) { diff --git a/lib/zap/conn.go b/lib/zap/conn.go index 41e08b1d34e262dcd82e96c8e7f81efae8222f8b..515900ff4ad923fa47929bc1a9516920be79ae0c 100644 --- a/lib/zap/conn.go +++ b/lib/zap/conn.go @@ -17,13 +17,14 @@ func WrapNetConn(conn net.Conn) *Conn { } } -func (T *Conn) EnableSSL(client bool) error { - var sslConn *tls.Conn - if client { - sslConn = tls.Client(T.conn, nil) - } else { - sslConn = tls.Server(T.conn, nil) - } +func (T *Conn) EnableSSLClient(config *tls.Config) error { + sslConn := tls.Client(T.conn, config) + T.conn = sslConn + return sslConn.Handshake() +} + +func (T *Conn) EnableSSLServer(config *tls.Config) error { + sslConn := tls.Server(T.conn, config) T.conn = sslConn return sslConn.Handshake() } diff --git a/lib/zap/readwriter.go b/lib/zap/readwriter.go index 7e602d847fd1ada468905d443b643208e5027d93..74350fa4c59fcddf671282b22a638a8810ef1673 100644 --- a/lib/zap/readwriter.go +++ b/lib/zap/readwriter.go @@ -1,13 +1,17 @@ package zap -import "io" +import ( + "crypto/tls" + "io" +) type ReadWriter interface { io.ByteReader io.ByteWriter io.Closer - EnableSSL(client bool) error + EnableSSLClient(config *tls.Config) error + EnableSSLServer(config *tls.Config) error Read(*Packet) error ReadUntyped(*UntypedPacket) error