From ff267ce2b7dfeaef2ebfaf13538b1573ed347530 Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Wed, 11 Oct 2023 17:11:28 -0500
Subject: [PATCH] hmmm

---
 lib/bouncer/backends/v0/accept.go        |  2 +-
 lib/bouncer/frontends/v0/accept.go       |  6 +++---
 lib/bouncer/frontends/v0/authenticate.go |  2 ++
 lib/fed/conn.go                          | 12 ++++++++++++
 4 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/lib/bouncer/backends/v0/accept.go b/lib/bouncer/backends/v0/accept.go
index 4247f09f..b74bb1ad 100644
--- a/lib/bouncer/backends/v0/accept.go
+++ b/lib/bouncer/backends/v0/accept.go
@@ -252,7 +252,7 @@ func enableSSL(ctx *acceptContext) (bool, error) {
 	}
 
 	// read byte to see if ssl is allowed
-	yn, err := ctx.Conn.Decoder.Uint8()
+	yn, err := ctx.Conn.ReadByte()
 	if err != nil {
 		return false, err
 	}
diff --git a/lib/bouncer/frontends/v0/accept.go b/lib/bouncer/frontends/v0/accept.go
index 3b1b1994..50e51356 100644
--- a/lib/bouncer/frontends/v0/accept.go
+++ b/lib/bouncer/frontends/v0/accept.go
@@ -39,12 +39,12 @@ func startup0(
 		case *packets.StartupPayloadControlPayloadSSL:
 			// ssl is not enabled
 			if ctx.Options.SSLConfig == nil {
-				err = ctx.Conn.Encoder.Uint8('N')
+				err = ctx.Conn.WriteByte('N')
 				return
 			}
 
 			// do ssl
-			if err = ctx.Conn.Encoder.Uint8('S'); err != nil {
+			if err = ctx.Conn.WriteByte('S'); err != nil {
 				return
 			}
 			if err = ctx.Conn.EnableSSLServer(ctx.Options.SSLConfig); err != nil {
@@ -53,7 +53,7 @@ func startup0(
 			return
 		case *packets.StartupPayloadControlPayloadGSSAPI:
 			// GSSAPI is not supported yet
-			err = ctx.Conn.Encoder.Uint8('N')
+			err = ctx.Conn.WriteByte('N')
 			return
 		default:
 			err = perror.New(
diff --git a/lib/bouncer/frontends/v0/authenticate.go b/lib/bouncer/frontends/v0/authenticate.go
index 50639355..4c534ee2 100644
--- a/lib/bouncer/frontends/v0/authenticate.go
+++ b/lib/bouncer/frontends/v0/authenticate.go
@@ -34,6 +34,7 @@ func authenticationSASLInitial(ctx *authenticateContext, creds auth.SASLServer)
 	if err != nil {
 		if errors.Is(err, io.EOF) {
 			done = true
+			err = nil
 			return
 		}
 		return
@@ -57,6 +58,7 @@ func authenticationSASLContinue(ctx *authenticateContext, tool auth.SASLVerifier
 	if err != nil {
 		if errors.Is(err, io.EOF) {
 			done = true
+			err = nil
 			return
 		}
 		return
diff --git a/lib/fed/conn.go b/lib/fed/conn.go
index 2b57d9df..20fbe57d 100644
--- a/lib/fed/conn.go
+++ b/lib/fed/conn.go
@@ -90,6 +90,18 @@ func (T *Conn) WritePacket(packet Packet) error {
 	return packet.WriteTo(&T.Encoder)
 }
 
+func (T *Conn) WriteByte(b byte) error {
+	return T.Encoder.Uint8(b)
+}
+
+func (T *Conn) ReadByte() (byte, error) {
+	if err := T.Flush(); err != nil {
+		return 0, err
+	}
+
+	return T.Decoder.Uint8()
+}
+
 func (T *Conn) EnableSSLClient(config *tls.Config) error {
 	// TODO(garet)
 	panic("TODO")
-- 
GitLab