From bf8a04e3b2b0575890a3a026b3b2f1e9b32e81e8 Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Wed, 11 Oct 2023 18:01:09 -0500
Subject: [PATCH] pretty much working

---
 lib/bouncer/backends/v0/accept.go  |  2 +-
 lib/bouncer/frontends/v0/accept.go |  2 +-
 lib/fed/conn.go                    | 68 ++++++++++++++++++++----------
 lib/fed/decoder.go                 | 21 +++++++++
 lib/fed/encoder.go                 | 15 +++++++
 5 files changed, 84 insertions(+), 24 deletions(-)

diff --git a/lib/bouncer/backends/v0/accept.go b/lib/bouncer/backends/v0/accept.go
index b74bb1ad..6fe9254b 100644
--- a/lib/bouncer/backends/v0/accept.go
+++ b/lib/bouncer/backends/v0/accept.go
@@ -262,7 +262,7 @@ func enableSSL(ctx *acceptContext) (bool, error) {
 		return false, nil
 	}
 
-	if err = ctx.Conn.EnableSSLClient(ctx.Options.SSLConfig); err != nil {
+	if err = ctx.Conn.EnableSSL(ctx.Options.SSLConfig, true); err != nil {
 		return false, err
 	}
 
diff --git a/lib/bouncer/frontends/v0/accept.go b/lib/bouncer/frontends/v0/accept.go
index 50e51356..00036e1a 100644
--- a/lib/bouncer/frontends/v0/accept.go
+++ b/lib/bouncer/frontends/v0/accept.go
@@ -47,7 +47,7 @@ func startup0(
 			if err = ctx.Conn.WriteByte('S'); err != nil {
 				return
 			}
-			if err = ctx.Conn.EnableSSLServer(ctx.Options.SSLConfig); err != nil {
+			if err = ctx.Conn.EnableSSL(ctx.Options.SSLConfig, false); err != nil {
 				return
 			}
 			return
diff --git a/lib/fed/conn.go b/lib/fed/conn.go
index 20fbe57d..905a32d2 100644
--- a/lib/fed/conn.go
+++ b/lib/fed/conn.go
@@ -2,7 +2,9 @@ package fed
 
 import (
 	"crypto/tls"
+	"errors"
 	"io"
+	"net"
 
 	"gfx.cafe/gfx/pggat/lib/util/decorator"
 	"gfx.cafe/gfx/pggat/lib/util/strutil"
@@ -11,9 +13,10 @@ import (
 type Conn struct {
 	noCopy decorator.NoCopy
 
+	encoder Encoder
+	decoder Decoder
+
 	ReadWriter io.ReadWriteCloser
-	Encoder    Encoder
-	Decoder    Decoder
 
 	Middleware []Middleware
 	SSL        bool
@@ -29,27 +32,27 @@ func NewConn(rw io.ReadWriteCloser) *Conn {
 	c := &Conn{
 		ReadWriter: rw,
 	}
-	c.Encoder.Writer.Reset(rw)
-	c.Decoder.Reader.Reset(rw)
+	c.encoder.Writer.Reset(rw)
+	c.decoder.Reader.Reset(rw)
 	return c
 }
 
-func (T *Conn) Flush() error {
-	return T.Encoder.Flush()
+func (T *Conn) flush() error {
+	return T.encoder.Flush()
 }
 
 func (T *Conn) ReadPacket(typed bool) (Packet, error) {
-	if err := T.Flush(); err != nil {
+	if err := T.flush(); err != nil {
 		return nil, err
 	}
 
 	for {
-		if err := T.Decoder.Next(typed); err != nil {
+		if err := T.decoder.Next(typed); err != nil {
 			return nil, err
 		}
 		var packet Packet
 		packet = PendingPacket{
-			Decoder: &T.Decoder,
+			Decoder: &T.decoder,
 		}
 		for _, middleware := range T.Middleware {
 			var err error
@@ -82,38 +85,59 @@ func (T *Conn) WritePacket(packet Packet) error {
 		return nil
 	}
 
-	err := T.Encoder.Next(packet.Type(), packet.Length())
+	err := T.encoder.Next(packet.Type(), packet.Length())
 	if err != nil {
 		return err
 	}
 
-	return packet.WriteTo(&T.Encoder)
+	return packet.WriteTo(&T.encoder)
 }
 
 func (T *Conn) WriteByte(b byte) error {
-	return T.Encoder.Uint8(b)
+	return T.encoder.WriteByte(b)
 }
 
 func (T *Conn) ReadByte() (byte, error) {
-	if err := T.Flush(); err != nil {
+	if err := T.flush(); err != nil {
 		return 0, err
 	}
 
-	return T.Decoder.Uint8()
+	return T.decoder.ReadByte()
 }
 
-func (T *Conn) EnableSSLClient(config *tls.Config) error {
-	// TODO(garet)
-	panic("TODO")
-}
+func (T *Conn) EnableSSL(config *tls.Config, isClient bool) error {
+	if T.SSL {
+		return errors.New("SSL is already enabled")
+	}
+	T.SSL = true
+
+	// flush buffers
+	if err := T.flush(); err != nil {
+		return err
+	}
+	if T.decoder.Reader.Buffered() > 0 {
+		return errors.New("expected empty read buffer")
+	}
 
-func (T *Conn) EnableSSLServer(config *tls.Config) error {
-	// TODO(garet)
-	panic("TODO")
+	conn, ok := T.ReadWriter.(net.Conn)
+	if !ok {
+		return errors.New("ssl not supported for this read writer")
+	}
+
+	var sslConn *tls.Conn
+	if isClient {
+		sslConn = tls.Client(conn, config)
+	} else {
+		sslConn = tls.Server(conn, config)
+	}
+	T.encoder.Writer.Reset(sslConn)
+	T.decoder.Reader.Reset(sslConn)
+	T.ReadWriter = sslConn
+	return sslConn.Handshake()
 }
 
 func (T *Conn) Close() error {
-	if err := T.Encoder.Flush(); err != nil {
+	if err := T.encoder.Flush(); err != nil {
 		return err
 	}
 
diff --git a/lib/fed/decoder.go b/lib/fed/decoder.go
index ae53d304..6aa011b0 100644
--- a/lib/fed/decoder.go
+++ b/lib/fed/decoder.go
@@ -28,7 +28,28 @@ func NewDecoder(r io.Reader) *Decoder {
 	return d
 }
 
+func (T *Decoder) ReadByte() (byte, error) {
+	if T.pos != T.len {
+		_, err := T.Reader.Discard(T.len - T.pos)
+		if err != nil {
+			return 0, err
+		}
+	}
+
+	T.typ = 0
+	T.len = 0
+	T.pos = 0
+	return T.Reader.ReadByte()
+}
+
 func (T *Decoder) Next(typed bool) error {
+	if T.pos != T.len {
+		_, err := T.Reader.Discard(T.len - T.pos)
+		if err != nil {
+			return err
+		}
+	}
+
 	var err error
 	if typed {
 		_, err = io.ReadFull(&T.Reader, T.buf[:5])
diff --git a/lib/fed/encoder.go b/lib/fed/encoder.go
index 247982be..44d257b9 100644
--- a/lib/fed/encoder.go
+++ b/lib/fed/encoder.go
@@ -31,7 +31,22 @@ func (T *Encoder) Flush() error {
 	return T.Writer.Flush()
 }
 
+func (T *Encoder) WriteByte(b byte) error {
+	if T.pos != T.len {
+		panic("wrong number of bytes written")
+	}
+
+	T.typ = 0
+	T.len = 0
+	T.pos = 0
+	return T.Writer.WriteByte(b)
+}
+
 func (T *Encoder) Next(typ Type, length int) error {
+	if T.pos != T.len {
+		panic("wrong number of bytes written")
+	}
+
 	if typ != 0 {
 		if err := T.Writer.WriteByte(byte(typ)); err != nil {
 			return err
-- 
GitLab