From bf8a04e3b2b0575890a3a026b3b2f1e9b32e81e8 Mon Sep 17 00:00:00 2001 From: Garet Halliday <me@garet.holiday> Date: Wed, 11 Oct 2023 18:01:09 -0500 Subject: [PATCH] pretty much working --- lib/bouncer/backends/v0/accept.go | 2 +- lib/bouncer/frontends/v0/accept.go | 2 +- lib/fed/conn.go | 68 ++++++++++++++++++++---------- lib/fed/decoder.go | 21 +++++++++ lib/fed/encoder.go | 15 +++++++ 5 files changed, 84 insertions(+), 24 deletions(-) diff --git a/lib/bouncer/backends/v0/accept.go b/lib/bouncer/backends/v0/accept.go index b74bb1ad..6fe9254b 100644 --- a/lib/bouncer/backends/v0/accept.go +++ b/lib/bouncer/backends/v0/accept.go @@ -262,7 +262,7 @@ func enableSSL(ctx *acceptContext) (bool, error) { return false, nil } - if err = ctx.Conn.EnableSSLClient(ctx.Options.SSLConfig); err != nil { + if err = ctx.Conn.EnableSSL(ctx.Options.SSLConfig, true); err != nil { return false, err } diff --git a/lib/bouncer/frontends/v0/accept.go b/lib/bouncer/frontends/v0/accept.go index 50e51356..00036e1a 100644 --- a/lib/bouncer/frontends/v0/accept.go +++ b/lib/bouncer/frontends/v0/accept.go @@ -47,7 +47,7 @@ func startup0( if err = ctx.Conn.WriteByte('S'); err != nil { return } - if err = ctx.Conn.EnableSSLServer(ctx.Options.SSLConfig); err != nil { + if err = ctx.Conn.EnableSSL(ctx.Options.SSLConfig, false); err != nil { return } return diff --git a/lib/fed/conn.go b/lib/fed/conn.go index 20fbe57d..905a32d2 100644 --- a/lib/fed/conn.go +++ b/lib/fed/conn.go @@ -2,7 +2,9 @@ package fed import ( "crypto/tls" + "errors" "io" + "net" "gfx.cafe/gfx/pggat/lib/util/decorator" "gfx.cafe/gfx/pggat/lib/util/strutil" @@ -11,9 +13,10 @@ import ( type Conn struct { noCopy decorator.NoCopy + encoder Encoder + decoder Decoder + ReadWriter io.ReadWriteCloser - Encoder Encoder - Decoder Decoder Middleware []Middleware SSL bool @@ -29,27 +32,27 @@ func NewConn(rw io.ReadWriteCloser) *Conn { c := &Conn{ ReadWriter: rw, } - c.Encoder.Writer.Reset(rw) - c.Decoder.Reader.Reset(rw) + c.encoder.Writer.Reset(rw) + c.decoder.Reader.Reset(rw) return c } -func (T *Conn) Flush() error { - return T.Encoder.Flush() +func (T *Conn) flush() error { + return T.encoder.Flush() } func (T *Conn) ReadPacket(typed bool) (Packet, error) { - if err := T.Flush(); err != nil { + if err := T.flush(); err != nil { return nil, err } for { - if err := T.Decoder.Next(typed); err != nil { + if err := T.decoder.Next(typed); err != nil { return nil, err } var packet Packet packet = PendingPacket{ - Decoder: &T.Decoder, + Decoder: &T.decoder, } for _, middleware := range T.Middleware { var err error @@ -82,38 +85,59 @@ func (T *Conn) WritePacket(packet Packet) error { return nil } - err := T.Encoder.Next(packet.Type(), packet.Length()) + err := T.encoder.Next(packet.Type(), packet.Length()) if err != nil { return err } - return packet.WriteTo(&T.Encoder) + return packet.WriteTo(&T.encoder) } func (T *Conn) WriteByte(b byte) error { - return T.Encoder.Uint8(b) + return T.encoder.WriteByte(b) } func (T *Conn) ReadByte() (byte, error) { - if err := T.Flush(); err != nil { + if err := T.flush(); err != nil { return 0, err } - return T.Decoder.Uint8() + return T.decoder.ReadByte() } -func (T *Conn) EnableSSLClient(config *tls.Config) error { - // TODO(garet) - panic("TODO") -} +func (T *Conn) EnableSSL(config *tls.Config, isClient bool) error { + if T.SSL { + return errors.New("SSL is already enabled") + } + T.SSL = true + + // flush buffers + if err := T.flush(); err != nil { + return err + } + if T.decoder.Reader.Buffered() > 0 { + return errors.New("expected empty read buffer") + } -func (T *Conn) EnableSSLServer(config *tls.Config) error { - // TODO(garet) - panic("TODO") + conn, ok := T.ReadWriter.(net.Conn) + if !ok { + return errors.New("ssl not supported for this read writer") + } + + var sslConn *tls.Conn + if isClient { + sslConn = tls.Client(conn, config) + } else { + sslConn = tls.Server(conn, config) + } + T.encoder.Writer.Reset(sslConn) + T.decoder.Reader.Reset(sslConn) + T.ReadWriter = sslConn + return sslConn.Handshake() } func (T *Conn) Close() error { - if err := T.Encoder.Flush(); err != nil { + if err := T.encoder.Flush(); err != nil { return err } diff --git a/lib/fed/decoder.go b/lib/fed/decoder.go index ae53d304..6aa011b0 100644 --- a/lib/fed/decoder.go +++ b/lib/fed/decoder.go @@ -28,7 +28,28 @@ func NewDecoder(r io.Reader) *Decoder { return d } +func (T *Decoder) ReadByte() (byte, error) { + if T.pos != T.len { + _, err := T.Reader.Discard(T.len - T.pos) + if err != nil { + return 0, err + } + } + + T.typ = 0 + T.len = 0 + T.pos = 0 + return T.Reader.ReadByte() +} + func (T *Decoder) Next(typed bool) error { + if T.pos != T.len { + _, err := T.Reader.Discard(T.len - T.pos) + if err != nil { + return err + } + } + var err error if typed { _, err = io.ReadFull(&T.Reader, T.buf[:5]) diff --git a/lib/fed/encoder.go b/lib/fed/encoder.go index 247982be..44d257b9 100644 --- a/lib/fed/encoder.go +++ b/lib/fed/encoder.go @@ -31,7 +31,22 @@ func (T *Encoder) Flush() error { return T.Writer.Flush() } +func (T *Encoder) WriteByte(b byte) error { + if T.pos != T.len { + panic("wrong number of bytes written") + } + + T.typ = 0 + T.len = 0 + T.pos = 0 + return T.Writer.WriteByte(b) +} + func (T *Encoder) Next(typ Type, length int) error { + if T.pos != T.len { + panic("wrong number of bytes written") + } + if typ != 0 { if err := T.Writer.WriteByte(byte(typ)); err != nil { return err -- GitLab