diff --git a/README.md b/README.md index ac41362c67b8db0a015aab36f8f130a108a4e3a9..b0bf159b40f30ce720aa0b47952d73bbb5722a0d 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,6 @@ Send each session to a new node. This mode supports all postgres features, but w One day these will maybe be supported - Cancelling in flight queries - Reserve pool (for serving long-stalled clients) -- SSL/TLS - Auth methods other than plaintext, MD5, and SASL-SCRAM-SHA256 - GSSAPI - Timeouts (other than transaction idle timeout) diff --git a/lib/bouncer/backends/v0/accept.go b/lib/bouncer/backends/v0/accept.go index 078493c91603b80cefbc00985e955d50d86b8b6d..de4f4d1dd2187cb1bbba11c97fa656393935f467 100644 --- a/lib/bouncer/backends/v0/accept.go +++ b/lib/bouncer/backends/v0/accept.go @@ -245,6 +245,33 @@ func startup1(conn *bouncer.Conn) (done bool, err error) { } } +func enableSSL(server zap.ReadWriter) (bool, error) { + packet := zap.NewUntypedPacket() + defer packet.Done() + packet.WriteUint16(1234) + packet.WriteUint16(5679) + if err := server.WriteUntyped(packet); err != nil { + return false, err + } + + // read byte to see if ssl is allowed + yn, err := server.ReadByte() + if err != nil { + return false, err + } + + if yn != 'S' { + // not supported + return false, nil + } + + if err = server.EnableSSL(true); err != nil { + return false, err + } + + return true, nil +} + func Accept(server zap.ReadWriter, options AcceptOptions) (bouncer.Conn, error) { username := options.Credentials.GetUsername() @@ -252,11 +279,21 @@ func Accept(server zap.ReadWriter, options AcceptOptions) (bouncer.Conn, error) options.Database = username } + if options.SSLMode.ShouldAttempt() { + ok, err := enableSSL(server) + if err != nil { + return bouncer.Conn{}, err + } + if !ok && options.SSLMode.IsRequired() { + return bouncer.Conn{}, errors.New("server rejected SSL encryption") + } + } + // we can re-use the memory for this pkt most of the way down because we don't pass this anywhere packet := zap.NewUntypedPacket() defer packet.Done() - packet.WriteInt16(3) - packet.WriteInt16(0) + packet.WriteUint16(3) + packet.WriteUint16(0) packet.WriteString("user") packet.WriteString(username) packet.WriteString("database") diff --git a/lib/bouncer/backends/v0/options.go b/lib/bouncer/backends/v0/options.go index 0361824523e5a0342ebf1a5d9fc80d6a0fdcbddc..88f073b49b06e46956e83608d22d2af9b5f9c5ff 100644 --- a/lib/bouncer/backends/v0/options.go +++ b/lib/bouncer/backends/v0/options.go @@ -2,10 +2,12 @@ package backends import ( "pggat2/lib/auth" + "pggat2/lib/bouncer" "pggat2/lib/util/strutil" ) type AcceptOptions struct { + SSLMode bouncer.SSLMode Credentials auth.Credentials Database string StartupParameters map[strutil.CIString]string diff --git a/lib/bouncer/frontends/v0/accept.go b/lib/bouncer/frontends/v0/accept.go index 1f58f594e90d18d25fb5b053e98ba28041f490c0..3756bc1aa381deb966802f1608f1ef53d52d417f 100644 --- a/lib/bouncer/frontends/v0/accept.go +++ b/lib/bouncer/frontends/v0/accept.go @@ -58,7 +58,12 @@ func startup0( return case 5679: // SSL is not supported yet - err = perror.Wrap(client.RW.WriteByte('N')) + if err = perror.Wrap(client.RW.WriteByte('S')); err != nil { + return + } + if err = perror.Wrap(client.RW.EnableSSL(false)); err != nil { + return + } return case 5680: // GSSAPI is not supported yet diff --git a/lib/bouncer/sslmode.go b/lib/bouncer/sslmode.go new file mode 100644 index 0000000000000000000000000000000000000000..d8c8bdf185b71b50bf878be6f1d0122d3ff3a2b2 --- /dev/null +++ b/lib/bouncer/sslmode.go @@ -0,0 +1,30 @@ +package bouncer + +type SSLMode string + +const ( + SSLModeDisable SSLMode = "disable" + SSLModeAllow SSLMode = "allow" + SSLModePrefer SSLMode = "prefer" + SSLModeRequire SSLMode = "require" + SSLModeVerifyCa SSLMode = "verify-ca" + SSLModeVerifyFull SSLMode = "verify-full" +) + +func (T SSLMode) ShouldAttempt() bool { + switch T { + case SSLModeDisable: + return false + default: + return true + } +} + +func (T SSLMode) IsRequired() bool { + switch T { + case SSLModeDisable, SSLModeAllow, SSLModeRequire: + return false + default: + return true + } +} diff --git a/lib/zap/conn.go b/lib/zap/conn.go index 4804e5ddf82f8a5549299e26c26cb7b182ec4329..41e08b1d34e262dcd82e96c8e7f81efae8222f8b 100644 --- a/lib/zap/conn.go +++ b/lib/zap/conn.go @@ -1,6 +1,7 @@ package zap import ( + "crypto/tls" "io" "net" ) @@ -16,6 +17,17 @@ func WrapNetConn(conn net.Conn) *Conn { } } +func (T *Conn) EnableSSL(client bool) error { + var sslConn *tls.Conn + if client { + sslConn = tls.Client(T.conn, nil) + } else { + sslConn = tls.Server(T.conn, nil) + } + T.conn = sslConn + return sslConn.Handshake() +} + func (T *Conn) ReadByte() (byte, error) { _, err := io.ReadFull(T.conn, T.buf[:]) if err != nil { diff --git a/lib/zap/reader.go b/lib/zap/reader.go deleted file mode 100644 index a8f5bee817892b5b55f1d4cf729e4c5fc3a153e5..0000000000000000000000000000000000000000 --- a/lib/zap/reader.go +++ /dev/null @@ -1,9 +0,0 @@ -package zap - -type Reader interface { - ReadByte() (byte, error) - Read(*Packet) error - ReadUntyped(*UntypedPacket) error - - Close() error -} diff --git a/lib/zap/readwriter.go b/lib/zap/readwriter.go index 1f686ed460f1727ed0e2cc9dc34775f7397f4056..7e602d847fd1ada468905d443b643208e5027d93 100644 --- a/lib/zap/readwriter.go +++ b/lib/zap/readwriter.go @@ -4,6 +4,14 @@ import "io" type ReadWriter interface { io.ByteReader - Reader - Writer + io.ByteWriter + io.Closer + + EnableSSL(client bool) error + + Read(*Packet) error + ReadUntyped(*UntypedPacket) error + Write(*Packet) error + WriteUntyped(*UntypedPacket) error + WriteV(*Packets) error } diff --git a/lib/zap/writer.go b/lib/zap/writer.go deleted file mode 100644 index 5cf89989f7701d8e741e823e40a5822bf447e5e2..0000000000000000000000000000000000000000 --- a/lib/zap/writer.go +++ /dev/null @@ -1,10 +0,0 @@ -package zap - -type Writer interface { - WriteByte(byte) error - Write(*Packet) error - WriteUntyped(*UntypedPacket) error - WriteV(*Packets) error - - Close() error -}