From 94494aad9903b274bf81bcf406ebf7b3e0c310ee Mon Sep 17 00:00:00 2001
From: a <a@tuxpa.in>
Date: Sun, 16 Jun 2024 17:49:51 -0500
Subject: [PATCH] noot

---
 lib/fed/codecs/netconncodec/codec.go | 97 ++++++++++++++++++++++++++++
 lib/fed/conn.go                      | 28 ++------
 lib/fed/interface.go                 | 17 +++++
 3 files changed, 120 insertions(+), 22 deletions(-)
 create mode 100644 lib/fed/codecs/netconncodec/codec.go
 create mode 100644 lib/fed/interface.go

diff --git a/lib/fed/codecs/netconncodec/codec.go b/lib/fed/codecs/netconncodec/codec.go
new file mode 100644
index 00000000..93d343c4
--- /dev/null
+++ b/lib/fed/codecs/netconncodec/codec.go
@@ -0,0 +1,97 @@
+package netconncodec
+
+import (
+	"crypto/tls"
+	"errors"
+	"net"
+	"sync"
+
+	"gfx.cafe/gfx/pggat/lib/fed"
+	"gfx.cafe/gfx/pggat/lib/util/decorator"
+)
+
+type Codec struct {
+	noCopy decorator.NoCopy
+
+	conn net.Conn
+	ssl  bool
+
+	encoder fed.Encoder
+	decoder fed.Decoder
+
+	mu sync.RWMutex
+}
+
+func NewCodec(rw net.Conn) fed.PacketCodec {
+	c := &Codec{
+		conn: rw,
+	}
+	c.encoder.Reset(rw)
+	c.decoder.Reset(rw)
+	return c
+}
+
+func (c *Codec) ReadPacket(typed bool) (fed.Packet, error) {
+	if err := c.decoder.Next(typed); err != nil {
+		return nil, err
+	}
+	return fed.PendingPacket{
+		Decoder: &c.decoder,
+	}, nil
+}
+
+func (c *Codec) WritePacket(packet fed.Packet) error {
+	err := c.encoder.Next(packet.Type(), packet.Length())
+	if err != nil {
+		return err
+	}
+
+	return packet.WriteTo(&c.encoder)
+}
+
+func (c *Codec) Flush() error {
+	return c.encoder.Flush()
+}
+
+func (c *Codec) Close() error {
+	if err := c.encoder.Flush(); err != nil {
+		return err
+	}
+	return c.conn.Close()
+}
+
+func (c *Codec) LocalAddr() net.Addr {
+	return c.conn.LocalAddr()
+}
+
+func (c *Codec) SSL() bool {
+	return c.ssl
+}
+
+func (c *Codec) EnableSSL(config *tls.Config, isClient bool) error {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+	if c.ssl {
+		return errors.New("SSL is already enabled")
+	}
+	c.ssl = true
+
+	// Flush buffers
+	if err := c.Flush(); err != nil {
+		return err
+	}
+	if c.decoder.Buffered() > 0 {
+		return errors.New("expected empty read buffer")
+	}
+
+	var sslConn *tls.Conn
+	if isClient {
+		sslConn = tls.Client(c.conn, config)
+	} else {
+		sslConn = tls.Server(c.conn, config)
+	}
+	c.encoder.Reset(sslConn)
+	c.decoder.Reset(sslConn)
+	c.conn = sslConn
+	return sslConn.Handshake()
+}
diff --git a/lib/fed/conn.go b/lib/fed/conn.go
index 9ea71495..e1b804c1 100644
--- a/lib/fed/conn.go
+++ b/lib/fed/conn.go
@@ -3,7 +3,6 @@ package fed
 import (
 	"crypto/tls"
 	"errors"
-	"net"
 
 	"gfx.cafe/gfx/pggat/lib/util/decorator"
 	"gfx.cafe/gfx/pggat/lib/util/strutil"
@@ -12,10 +11,7 @@ import (
 type Conn struct {
 	noCopy decorator.NoCopy
 
-	encoder Encoder
-	decoder Decoder
-
-	NetConn net.Conn
+	codec PacketCodec
 
 	Middleware []Middleware
 
@@ -30,26 +26,19 @@ type Conn struct {
 	Ready         bool
 }
 
-func NewConn(rw net.Conn) *Conn {
+func NewConn(codec PacketCodec) *Conn {
 	c := &Conn{
-		NetConn: rw,
+		codec: codec,
 	}
-	c.encoder.Reset(rw)
-	c.decoder.Reset(rw)
 	return c
 }
 
 func (T *Conn) Flush() error {
-	return T.encoder.Flush()
+	return T.codec.Flush()
 }
 
 func (T *Conn) readPacket(typed bool) (Packet, error) {
-	if err := T.decoder.Next(typed); err != nil {
-		return nil, err
-	}
-	return PendingPacket{
-		Decoder: &T.decoder,
-	}, nil
+	return T.codec.ReadPacket(typed)
 }
 
 func (T *Conn) ReadPacket(typed bool) (Packet, error) {
@@ -107,12 +96,7 @@ func (T *Conn) ReadPacket(typed bool) (Packet, error) {
 }
 
 func (T *Conn) writePacket(packet Packet) error {
-	err := T.encoder.Next(packet.Type(), packet.Length())
-	if err != nil {
-		return err
-	}
-
-	return packet.WriteTo(&T.encoder)
+	return T.codec.WritePacket(packet)
 }
 
 func (T *Conn) WritePacket(packet Packet) error {
diff --git a/lib/fed/interface.go b/lib/fed/interface.go
new file mode 100644
index 00000000..00c4f660
--- /dev/null
+++ b/lib/fed/interface.go
@@ -0,0 +1,17 @@
+package fed
+
+import (
+	"crypto/tls"
+	"net"
+)
+
+type PacketCodec interface {
+	ReadPacket(typed bool) (Packet, error)
+	WritePacket(packet Packet) error
+	LocalAddr() net.Addr
+	Flush() error
+	Close() error
+
+	SSL() bool
+	EnableSSL(config *tls.Config, isClient bool) error
+}
-- 
GitLab