From 202092d8350d93e67cee9b1f07ad9f64e51eef92 Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Wed, 23 Aug 2023 15:03:05 -0500
Subject: [PATCH] more ssl

---
 lib/bouncer/backends/v0/accept.go         | 22 ++++++++++++----------
 lib/bouncer/backends/v0/options.go        |  3 +++
 lib/bouncer/conn.go                       |  1 +
 lib/bouncer/frontends/v0/accept.go        | 19 ++++++++++++++++++-
 lib/bouncer/frontends/v0/options.go       |  4 ++++
 lib/bouncer/sslmode.go                    | 11 ++++++++++-
 lib/gat/pooler.go                         |  3 +++
 lib/gat/recipe.go                         |  8 ++++++++
 lib/middleware/interceptor/interceptor.go | 10 ++++++++--
 lib/zap/conn.go                           | 15 ++++++++-------
 lib/zap/readwriter.go                     |  8 ++++++--
 11 files changed, 81 insertions(+), 23 deletions(-)

diff --git a/lib/bouncer/backends/v0/accept.go b/lib/bouncer/backends/v0/accept.go
index de4f4d1d..b1e084c0 100644
--- a/lib/bouncer/backends/v0/accept.go
+++ b/lib/bouncer/backends/v0/accept.go
@@ -1,6 +1,7 @@
 package backends
 
 import (
+	"crypto/tls"
 	"errors"
 
 	"pggat2/lib/auth"
@@ -245,7 +246,7 @@ func startup1(conn *bouncer.Conn) (done bool, err error) {
 	}
 }
 
-func enableSSL(server zap.ReadWriter) (bool, error) {
+func enableSSL(server zap.ReadWriter, config *tls.Config) (bool, error) {
 	packet := zap.NewUntypedPacket()
 	defer packet.Done()
 	packet.WriteUint16(1234)
@@ -265,7 +266,7 @@ func enableSSL(server zap.ReadWriter) (bool, error) {
 		return false, nil
 	}
 
-	if err = server.EnableSSL(true); err != nil {
+	if err = server.EnableSSLClient(config); err != nil {
 		return false, err
 	}
 
@@ -279,12 +280,19 @@ func Accept(server zap.ReadWriter, options AcceptOptions) (bouncer.Conn, error)
 		options.Database = username
 	}
 
+	conn := bouncer.Conn{
+		RW:       server,
+		User:     username,
+		Database: options.Database,
+	}
+
 	if options.SSLMode.ShouldAttempt() {
-		ok, err := enableSSL(server)
+		var err error
+		conn.SSLEnabled, err = enableSSL(server, options.SSLConfig)
 		if err != nil {
 			return bouncer.Conn{}, err
 		}
-		if !ok && options.SSLMode.IsRequired() {
+		if !conn.SSLEnabled && options.SSLMode.IsRequired() {
 			return bouncer.Conn{}, errors.New("server rejected SSL encryption")
 		}
 	}
@@ -320,12 +328,6 @@ func Accept(server zap.ReadWriter, options AcceptOptions) (bouncer.Conn, error)
 		}
 	}
 
-	conn := bouncer.Conn{
-		RW:       server,
-		User:     username,
-		Database: options.Database,
-	}
-
 	for {
 		var done bool
 		done, err = startup1(&conn)
diff --git a/lib/bouncer/backends/v0/options.go b/lib/bouncer/backends/v0/options.go
index 88f073b4..7d2b7619 100644
--- a/lib/bouncer/backends/v0/options.go
+++ b/lib/bouncer/backends/v0/options.go
@@ -1,6 +1,8 @@
 package backends
 
 import (
+	"crypto/tls"
+
 	"pggat2/lib/auth"
 	"pggat2/lib/bouncer"
 	"pggat2/lib/util/strutil"
@@ -8,6 +10,7 @@ import (
 
 type AcceptOptions struct {
 	SSLMode           bouncer.SSLMode
+	SSLConfig         *tls.Config
 	Credentials       auth.Credentials
 	Database          string
 	StartupParameters map[strutil.CIString]string
diff --git a/lib/bouncer/conn.go b/lib/bouncer/conn.go
index b469df5d..6ccff857 100644
--- a/lib/bouncer/conn.go
+++ b/lib/bouncer/conn.go
@@ -8,6 +8,7 @@ import (
 type Conn struct {
 	RW zap.ReadWriter
 
+	SSLEnabled        bool
 	User              string
 	Database          string
 	InitialParameters map[strutil.CIString]string
diff --git a/lib/bouncer/frontends/v0/accept.go b/lib/bouncer/frontends/v0/accept.go
index be247203..9e364366 100644
--- a/lib/bouncer/frontends/v0/accept.go
+++ b/lib/bouncer/frontends/v0/accept.go
@@ -57,12 +57,20 @@ func startup0(
 			)
 			return
 		case 5679:
+			// ssl is not enabled
+			if options.SSLConfig == nil {
+				err = perror.Wrap(client.RW.WriteByte('N'))
+				return
+			}
+
+			// do ssl
 			if err = perror.Wrap(client.RW.WriteByte('S')); err != nil {
 				return
 			}
-			if err = perror.Wrap(client.RW.EnableSSL(false)); err != nil {
+			if err = perror.Wrap(client.RW.EnableSSLServer(options.SSLConfig)); err != nil {
 				return
 			}
+			client.SSLEnabled = true
 			return
 		case 5680:
 			// GSSAPI is not supported yet
@@ -332,6 +340,15 @@ func accept(
 		}
 	}
 
+	if options.SSLRequired && !conn.SSLEnabled {
+		err = perror.New(
+			perror.FATAL,
+			perror.InvalidPassword,
+			"SSL is required",
+		)
+		return
+	}
+
 	creds := options.Pooler.GetUserCredentials(conn.User, conn.Database)
 	if creds == nil {
 		err = perror.New(
diff --git a/lib/bouncer/frontends/v0/options.go b/lib/bouncer/frontends/v0/options.go
index f3241fd5..26a77ace 100644
--- a/lib/bouncer/frontends/v0/options.go
+++ b/lib/bouncer/frontends/v0/options.go
@@ -1,11 +1,15 @@
 package frontends
 
 import (
+	"crypto/tls"
+
 	"pggat2/lib/bouncer"
 	"pggat2/lib/util/strutil"
 )
 
 type AcceptOptions struct {
+	SSLRequired           bool
+	SSLConfig             *tls.Config
 	Pooler                bouncer.Pooler
 	AllowedStartupOptions []strutil.CIString
 }
diff --git a/lib/bouncer/sslmode.go b/lib/bouncer/sslmode.go
index d8c8bdf1..abcd2be9 100644
--- a/lib/bouncer/sslmode.go
+++ b/lib/bouncer/sslmode.go
@@ -22,9 +22,18 @@ func (T SSLMode) ShouldAttempt() bool {
 
 func (T SSLMode) IsRequired() bool {
 	switch T {
-	case SSLModeDisable, SSLModeAllow, SSLModeRequire:
+	case SSLModeDisable, SSLModeAllow, SSLModeRequire, "":
 		return false
 	default:
 		return true
 	}
 }
+
+func (T SSLMode) VerifyCertificates() bool {
+	switch T {
+	case SSLModeVerifyCa, SSLModeVerifyFull:
+		return true
+	default:
+		return false
+	}
+}
diff --git a/lib/gat/pooler.go b/lib/gat/pooler.go
index 1788bbc9..ee3a1480 100644
--- a/lib/gat/pooler.go
+++ b/lib/gat/pooler.go
@@ -25,6 +25,7 @@ type Pooler struct {
 
 type PoolerConfig struct {
 	AllowedStartupParameters []strutil.CIString
+	SSLMode                  bouncer.SSLMode
 }
 
 func NewPooler(config PoolerConfig) *Pooler {
@@ -84,6 +85,8 @@ func (T *Pooler) Serve(client zap.ReadWriter) {
 	conn, err := frontends.Accept(
 		client,
 		frontends.AcceptOptions{
+			SSLRequired: T.config.SSLMode.IsRequired(),
+			// TODO(garet) SSL Config
 			Pooler:                T,
 			AllowedStartupOptions: T.config.AllowedStartupParameters,
 		},
diff --git a/lib/gat/recipe.go b/lib/gat/recipe.go
index 650bd603..58f41b88 100644
--- a/lib/gat/recipe.go
+++ b/lib/gat/recipe.go
@@ -1,6 +1,7 @@
 package gat
 
 import (
+	"crypto/tls"
 	"net"
 
 	"pggat2/lib/auth"
@@ -28,6 +29,8 @@ type TCPRecipe struct {
 	MinConnections int
 	MaxConnections int
 
+	SSLMode bouncer.SSLMode
+
 	StartupParameters map[strutil.CIString]string
 }
 
@@ -47,6 +50,11 @@ func (T TCPRecipe) Connect() (bouncer.Conn, error) {
 	}
 
 	server, err := backends.Accept(rw, backends.AcceptOptions{
+		SSLMode: T.SSLMode,
+		SSLConfig: &tls.Config{
+			// TODO(garet) SSL certificates if they need to be verified
+			InsecureSkipVerify: !T.SSLMode.VerifyCertificates(),
+		},
 		Credentials:       T.Credentials,
 		Database:          T.Database,
 		StartupParameters: T.StartupParameters,
diff --git a/lib/middleware/interceptor/interceptor.go b/lib/middleware/interceptor/interceptor.go
index b6bef931..bca6ad34 100644
--- a/lib/middleware/interceptor/interceptor.go
+++ b/lib/middleware/interceptor/interceptor.go
@@ -1,6 +1,8 @@
 package interceptor
 
 import (
+	"crypto/tls"
+
 	"pggat2/lib/middleware"
 	"pggat2/lib/zap"
 )
@@ -23,8 +25,12 @@ func NewInterceptor(rw zap.ReadWriter, middlewares ...middleware.Middleware) *In
 	}
 }
 
-func (T *Interceptor) EnableSSL(client bool) error {
-	return T.rw.EnableSSL(client)
+func (T *Interceptor) EnableSSLClient(config *tls.Config) error {
+	return T.rw.EnableSSLClient(config)
+}
+
+func (T *Interceptor) EnableSSLServer(config *tls.Config) error {
+	return T.rw.EnableSSLServer(config)
 }
 
 func (T *Interceptor) ReadByte() (byte, error) {
diff --git a/lib/zap/conn.go b/lib/zap/conn.go
index 41e08b1d..515900ff 100644
--- a/lib/zap/conn.go
+++ b/lib/zap/conn.go
@@ -17,13 +17,14 @@ func WrapNetConn(conn net.Conn) *Conn {
 	}
 }
 
-func (T *Conn) EnableSSL(client bool) error {
-	var sslConn *tls.Conn
-	if client {
-		sslConn = tls.Client(T.conn, nil)
-	} else {
-		sslConn = tls.Server(T.conn, nil)
-	}
+func (T *Conn) EnableSSLClient(config *tls.Config) error {
+	sslConn := tls.Client(T.conn, config)
+	T.conn = sslConn
+	return sslConn.Handshake()
+}
+
+func (T *Conn) EnableSSLServer(config *tls.Config) error {
+	sslConn := tls.Server(T.conn, config)
 	T.conn = sslConn
 	return sslConn.Handshake()
 }
diff --git a/lib/zap/readwriter.go b/lib/zap/readwriter.go
index 7e602d84..74350fa4 100644
--- a/lib/zap/readwriter.go
+++ b/lib/zap/readwriter.go
@@ -1,13 +1,17 @@
 package zap
 
-import "io"
+import (
+	"crypto/tls"
+	"io"
+)
 
 type ReadWriter interface {
 	io.ByteReader
 	io.ByteWriter
 	io.Closer
 
-	EnableSSL(client bool) error
+	EnableSSLClient(config *tls.Config) error
+	EnableSSLServer(config *tls.Config) error
 
 	Read(*Packet) error
 	ReadUntyped(*UntypedPacket) error
-- 
GitLab