From 94494aad9903b274bf81bcf406ebf7b3e0c310ee Mon Sep 17 00:00:00 2001 From: a <a@tuxpa.in> Date: Sun, 16 Jun 2024 17:49:51 -0500 Subject: [PATCH] noot --- lib/fed/codecs/netconncodec/codec.go | 97 ++++++++++++++++++++++++++++ lib/fed/conn.go | 28 ++------ lib/fed/interface.go | 17 +++++ 3 files changed, 120 insertions(+), 22 deletions(-) create mode 100644 lib/fed/codecs/netconncodec/codec.go create mode 100644 lib/fed/interface.go diff --git a/lib/fed/codecs/netconncodec/codec.go b/lib/fed/codecs/netconncodec/codec.go new file mode 100644 index 00000000..93d343c4 --- /dev/null +++ b/lib/fed/codecs/netconncodec/codec.go @@ -0,0 +1,97 @@ +package netconncodec + +import ( + "crypto/tls" + "errors" + "net" + "sync" + + "gfx.cafe/gfx/pggat/lib/fed" + "gfx.cafe/gfx/pggat/lib/util/decorator" +) + +type Codec struct { + noCopy decorator.NoCopy + + conn net.Conn + ssl bool + + encoder fed.Encoder + decoder fed.Decoder + + mu sync.RWMutex +} + +func NewCodec(rw net.Conn) fed.PacketCodec { + c := &Codec{ + conn: rw, + } + c.encoder.Reset(rw) + c.decoder.Reset(rw) + return c +} + +func (c *Codec) ReadPacket(typed bool) (fed.Packet, error) { + if err := c.decoder.Next(typed); err != nil { + return nil, err + } + return fed.PendingPacket{ + Decoder: &c.decoder, + }, nil +} + +func (c *Codec) WritePacket(packet fed.Packet) error { + err := c.encoder.Next(packet.Type(), packet.Length()) + if err != nil { + return err + } + + return packet.WriteTo(&c.encoder) +} + +func (c *Codec) Flush() error { + return c.encoder.Flush() +} + +func (c *Codec) Close() error { + if err := c.encoder.Flush(); err != nil { + return err + } + return c.conn.Close() +} + +func (c *Codec) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *Codec) SSL() bool { + return c.ssl +} + +func (c *Codec) EnableSSL(config *tls.Config, isClient bool) error { + c.mu.Lock() + defer c.mu.Unlock() + if c.ssl { + return errors.New("SSL is already enabled") + } + c.ssl = true + + // Flush buffers + if err := c.Flush(); err != nil { + return err + } + if c.decoder.Buffered() > 0 { + return errors.New("expected empty read buffer") + } + + var sslConn *tls.Conn + if isClient { + sslConn = tls.Client(c.conn, config) + } else { + sslConn = tls.Server(c.conn, config) + } + c.encoder.Reset(sslConn) + c.decoder.Reset(sslConn) + c.conn = sslConn + return sslConn.Handshake() +} diff --git a/lib/fed/conn.go b/lib/fed/conn.go index 9ea71495..e1b804c1 100644 --- a/lib/fed/conn.go +++ b/lib/fed/conn.go @@ -3,7 +3,6 @@ package fed import ( "crypto/tls" "errors" - "net" "gfx.cafe/gfx/pggat/lib/util/decorator" "gfx.cafe/gfx/pggat/lib/util/strutil" @@ -12,10 +11,7 @@ import ( type Conn struct { noCopy decorator.NoCopy - encoder Encoder - decoder Decoder - - NetConn net.Conn + codec PacketCodec Middleware []Middleware @@ -30,26 +26,19 @@ type Conn struct { Ready bool } -func NewConn(rw net.Conn) *Conn { +func NewConn(codec PacketCodec) *Conn { c := &Conn{ - NetConn: rw, + codec: codec, } - c.encoder.Reset(rw) - c.decoder.Reset(rw) return c } func (T *Conn) Flush() error { - return T.encoder.Flush() + return T.codec.Flush() } func (T *Conn) readPacket(typed bool) (Packet, error) { - if err := T.decoder.Next(typed); err != nil { - return nil, err - } - return PendingPacket{ - Decoder: &T.decoder, - }, nil + return T.codec.ReadPacket(typed) } func (T *Conn) ReadPacket(typed bool) (Packet, error) { @@ -107,12 +96,7 @@ func (T *Conn) ReadPacket(typed bool) (Packet, error) { } func (T *Conn) writePacket(packet Packet) error { - err := T.encoder.Next(packet.Type(), packet.Length()) - if err != nil { - return err - } - - return packet.WriteTo(&T.encoder) + return T.codec.WritePacket(packet) } func (T *Conn) WritePacket(packet Packet) error { diff --git a/lib/fed/interface.go b/lib/fed/interface.go new file mode 100644 index 00000000..00c4f660 --- /dev/null +++ b/lib/fed/interface.go @@ -0,0 +1,17 @@ +package fed + +import ( + "crypto/tls" + "net" +) + +type PacketCodec interface { + ReadPacket(typed bool) (Packet, error) + WritePacket(packet Packet) error + LocalAddr() net.Addr + Flush() error + Close() error + + SSL() bool + EnableSSL(config *tls.Config, isClient bool) error +} -- GitLab