diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 3c5bdba1e14f6a9eda5f0e9ec21e7cbd31f70fab..679df76496c472d298fd20d88a527505176ccb10 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -24,7 +24,7 @@ variables: test: image: postgres:alpine variables: - POSTGRES_PASSWORD: password + POSTGRES_PASSWORD: postgres stage: test extends: .go-cache script: @@ -56,7 +56,7 @@ coverage: stage: test image: postgres:alpine variables: - POSTGRES_PASSWORD: password + POSTGRES_PASSWORD: postgres coverage: '/\(statements\)(?:\s+)?(\d+(?:\.\d+)?%)/' extends: .go-cache script: diff --git a/lib/fed/codecs/netconncodec/codec.go b/lib/fed/codecs/netconncodec/codec.go new file mode 100644 index 0000000000000000000000000000000000000000..414943cae19a6254b120d4326c7ca786743a4f9e --- /dev/null +++ b/lib/fed/codecs/netconncodec/codec.go @@ -0,0 +1,113 @@ +package netconncodec + +import ( + "crypto/tls" + "errors" + "fmt" + "net" + "sync" + + "gfx.cafe/gfx/pggat/lib/fed" + "gfx.cafe/gfx/pggat/lib/util/decorator" +) + +type Codec struct { + noCopy decorator.NoCopy + + conn net.Conn + ssl bool + + encoder fed.Encoder + decoder fed.Decoder + + mu sync.RWMutex +} + +func NewCodec(rw net.Conn) fed.PacketCodec { + c := &Codec{ + conn: rw, + } + c.encoder.Reset(rw) + c.decoder.Reset(rw) + return c +} + +func (c *Codec) ReadPacket(typed bool) (fed.Packet, error) { + if err := c.decoder.Next(typed); err != nil { + return nil, err + } + return fed.PendingPacket{ + Decoder: &c.decoder, + }, nil +} + +func (c *Codec) WritePacket(packet fed.Packet) error { + err := c.encoder.Next(packet.Type(), packet.Length()) + if err != nil { + return err + } + + return packet.WriteTo(&c.encoder) +} +func (c *Codec) WriteByte(b byte) error { + return c.encoder.WriteByte(b) +} + +func (c *Codec) ReadByte() (byte, error) { + if err := c.Flush(); err != nil { + return 0, err + } + + return c.decoder.ReadByte() +} + +func (c *Codec) Flush() error { + return c.encoder.Flush() +} + +func (c *Codec) Close() error { + if err := c.encoder.Flush(); err != nil { + return err + } + return c.conn.Close() +} + +func (c *Codec) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +func (c *Codec) SSL() bool { + return c.ssl +} + +func (c *Codec) EnableSSL(config *tls.Config, isClient bool) error { + c.mu.Lock() + defer c.mu.Unlock() + if c.ssl { + return errors.New("SSL is already enabled") + } + c.ssl = true + + // Flush buffers + if err := c.Flush(); err != nil { + return err + } + if c.decoder.Buffered() > 0 { + return errors.New("expected empty read buffer") + } + + var sslConn *tls.Conn + if isClient { + sslConn = tls.Client(c.conn, config) + } else { + sslConn = tls.Server(c.conn, config) + } + c.encoder.Reset(sslConn) + c.decoder.Reset(sslConn) + c.conn = sslConn + err := sslConn.Handshake() + if err != nil { + return fmt.Errorf("ssl handshake fail client(%v): %w", isClient, err) + } + return nil +} diff --git a/lib/fed/conn.go b/lib/fed/conn.go index 9ea714957192e16f9db6db088fcfc077a04d77ab..b6bc0016b2e728e3224ba1fb58ed2fdacec28e31 100644 --- a/lib/fed/conn.go +++ b/lib/fed/conn.go @@ -2,7 +2,6 @@ package fed import ( "crypto/tls" - "errors" "net" "gfx.cafe/gfx/pggat/lib/util/decorator" @@ -12,10 +11,7 @@ import ( type Conn struct { noCopy decorator.NoCopy - encoder Encoder - decoder Decoder - - NetConn net.Conn + codec PacketCodec Middleware []Middleware @@ -30,26 +26,19 @@ type Conn struct { Ready bool } -func NewConn(rw net.Conn) *Conn { +func NewConn(codec PacketCodec) *Conn { c := &Conn{ - NetConn: rw, + codec: codec, } - c.encoder.Reset(rw) - c.decoder.Reset(rw) return c } func (T *Conn) Flush() error { - return T.encoder.Flush() + return T.codec.Flush() } func (T *Conn) readPacket(typed bool) (Packet, error) { - if err := T.decoder.Next(typed); err != nil { - return nil, err - } - return PendingPacket{ - Decoder: &T.decoder, - }, nil + return T.codec.ReadPacket(typed) } func (T *Conn) ReadPacket(typed bool) (Packet, error) { @@ -107,12 +96,7 @@ func (T *Conn) ReadPacket(typed bool) (Packet, error) { } func (T *Conn) writePacket(packet Packet) error { - err := T.encoder.Next(packet.Type(), packet.Length()) - if err != nil { - return err - } - - return packet.WriteTo(&T.encoder) + return T.codec.WritePacket(packet) } func (T *Conn) WritePacket(packet Packet) error { @@ -171,47 +155,22 @@ func (T *Conn) WritePacket(packet Packet) error { } func (T *Conn) WriteByte(b byte) error { - return T.encoder.WriteByte(b) + return T.codec.WriteByte(b) } -func (T *Conn) ReadByte() (byte, error) { - if err := T.Flush(); err != nil { - return 0, err - } +func (T *Conn) LocalAddr() net.Addr { + return T.codec.LocalAddr() - return T.decoder.ReadByte() } -func (T *Conn) EnableSSL(config *tls.Config, isClient bool) error { - if T.SSL { - return errors.New("SSL is already enabled") - } - T.SSL = true - - // Flush buffers - if err := T.Flush(); err != nil { - return err - } - if T.decoder.Buffered() > 0 { - return errors.New("expected empty read buffer") - } +func (T *Conn) ReadByte() (byte, error) { + return T.codec.ReadByte() +} - var sslConn *tls.Conn - if isClient { - sslConn = tls.Client(T.NetConn, config) - } else { - sslConn = tls.Server(T.NetConn, config) - } - T.encoder.Reset(sslConn) - T.decoder.Reset(sslConn) - T.NetConn = sslConn - return sslConn.Handshake() +func (T *Conn) EnableSSL(config *tls.Config, isClient bool) error { + return T.codec.EnableSSL(config, isClient) } func (T *Conn) Close() error { - if err := T.encoder.Flush(); err != nil { - return err - } - - return T.NetConn.Close() + return T.codec.Close() } diff --git a/lib/fed/interface.go b/lib/fed/interface.go new file mode 100644 index 0000000000000000000000000000000000000000..636d62a13fd321278543fcd52d62e95e580b0886 --- /dev/null +++ b/lib/fed/interface.go @@ -0,0 +1,20 @@ +package fed + +import ( + "crypto/tls" + "net" +) + +type PacketCodec interface { + ReadPacket(typed bool) (Packet, error) + WritePacket(packet Packet) error + WriteByte(b byte) error + ReadByte() (byte, error) + + LocalAddr() net.Addr + Flush() error + Close() error + + SSL() bool + EnableSSL(config *tls.Config, isClient bool) error +} diff --git a/lib/gat/handlers/discovery/discoverers/google_cloud_sql/discoverer.go b/lib/gat/handlers/discovery/discoverers/google_cloud_sql/discoverer.go index 68be3af337529c04288283777b99e9033a6da822..fb3a1304065b51c940c47db9d62cff31eefb1304 100644 --- a/lib/gat/handlers/discovery/discoverers/google_cloud_sql/discoverer.go +++ b/lib/gat/handlers/discovery/discoverers/google_cloud_sql/discoverer.go @@ -135,7 +135,7 @@ func (T *Discoverer) instanceToCluster(primary *sqladmin.DatabaseInstance, repli var result authQueryResult - inward, outward := gsql.NewPair() + inward, outward, _, _ := gsql.NewPair() var b flip.Bank b.Queue(func() error { diff --git a/lib/gat/handlers/pgbouncer/module.go b/lib/gat/handlers/pgbouncer/module.go index 6973ad29e587d4226aa5ad23c67bb82c75d104aa..057eea53be1d9edfe350dc694bc360adcfdbd72c 100644 --- a/lib/gat/handlers/pgbouncer/module.go +++ b/lib/gat/handlers/pgbouncer/module.go @@ -114,7 +114,7 @@ func (T *Module) getPassword(user, database string) (string, bool) { var b flip.Bank - inward, outward := gsql.NewPair() + inward, outward, _, _ := gsql.NewPair() b.Queue(func() error { if err := gsql.ExtendedQuery(inward, &result, T.Config.PgBouncer.AuthQuery, user); err != nil { return err diff --git a/lib/gat/handlers/pool/dialer.go b/lib/gat/handlers/pool/dialer.go index ba902c94dbfa1ed6197c599659d750ea41d40e65..1da5d9fea194640190cd085290e8c104ce083623 100644 --- a/lib/gat/handlers/pool/dialer.go +++ b/lib/gat/handlers/pool/dialer.go @@ -14,6 +14,7 @@ import ( "gfx.cafe/gfx/pggat/lib/bouncer" "gfx.cafe/gfx/pggat/lib/bouncer/backends/v0" "gfx.cafe/gfx/pggat/lib/fed" + "gfx.cafe/gfx/pggat/lib/fed/codecs/netconncodec" "gfx.cafe/gfx/pggat/lib/gat" "gfx.cafe/gfx/pggat/lib/util/strutil" ) @@ -65,7 +66,7 @@ func (T *Dialer) Dial() (*fed.Conn, error) { if err != nil { return nil, err } - conn := fed.NewConn(c) + conn := fed.NewConn(netconncodec.NewCodec(c)) conn.User = T.Username conn.Database = T.Database err = backends.Accept( @@ -89,7 +90,7 @@ func (T *Dialer) Cancel(key fed.BackendKey) { if err != nil { return } - conn := fed.NewConn(c) + conn := fed.NewConn(netconncodec.NewCodec(c)) defer func() { _ = conn.Close() }() diff --git a/lib/gat/listen.go b/lib/gat/listen.go index 9a4e22b83b87c0574658f69baead2bb288856b1b..fb47fd97442a9143d5b824f64b6fe525dbbf495d 100644 --- a/lib/gat/listen.go +++ b/lib/gat/listen.go @@ -15,6 +15,7 @@ import ( "go.uber.org/zap" "gfx.cafe/gfx/pggat/lib/fed" + "gfx.cafe/gfx/pggat/lib/fed/codecs/netconncodec" ) type ListenerConfig struct { @@ -40,7 +41,7 @@ func (T *Listener) accept() (*fed.Conn, error) { if err != nil { return nil, err } - return fed.NewConn(raw), nil + return fed.NewConn(netconncodec.NewCodec(raw)), nil } func (T *Listener) Provision(ctx caddy.Context) error { diff --git a/lib/gat/matchers/localaddress.go b/lib/gat/matchers/localaddress.go index 920d2852733269c18e18f044df48528c5276e5b3..b89e6699ddc522908d261711ef1f53909e667a8e 100644 --- a/lib/gat/matchers/localaddress.go +++ b/lib/gat/matchers/localaddress.go @@ -48,7 +48,7 @@ func (T *LocalAddress) Provision(ctx caddy.Context) error { } func (T *LocalAddress) Matches(conn *fed.Conn) bool { - switch addr := conn.NetConn.LocalAddr().(type) { + switch addr := conn.LocalAddr().(type) { case *net.TCPAddr: expected, ok := T.addr.(*net.TCPAddr) if !ok { diff --git a/lib/gsql/pair.go b/lib/gsql/pair.go index e6ceb79c563cc62fec49f976cb31b6ae70679526..2fcaff95d81d73f27868478ff58ab7c557e7aacb 100644 --- a/lib/gsql/pair.go +++ b/lib/gsql/pair.go @@ -1,16 +1,29 @@ package gsql import ( + "net" + "gfx.cafe/gfx/pggat/lib/fed" + "gfx.cafe/gfx/pggat/lib/fed/codecs/netconncodec" "gfx.cafe/gfx/pggat/lib/util/mio" ) -func NewPair() (*fed.Conn, *fed.Conn) { +func NewPair() (*fed.Conn, *fed.Conn, net.Conn, net.Conn) { conn := new(mio.Conn) - inward := fed.NewConn(mio.InwardConn{Conn: conn}) + in := mio.InwardConn{Conn: conn} + out := mio.OutwardConn{Conn: conn} + inward := fed.NewConn( + netconncodec.NewCodec( + in, + ), + ) inward.Ready = true - outward := fed.NewConn(mio.OutwardConn{Conn: conn}) + outward := fed.NewConn( + netconncodec.NewCodec( + out, + ), + ) outward.Ready = true - return inward, outward + return inward, outward, in, out } diff --git a/lib/gsql/query_test.go b/lib/gsql/query_test.go index 3839029d8ccc12b12a7d6e35c30d25220debe3a4..476748de548ed0ea65fccc8c245fed625cb828ba 100644 --- a/lib/gsql/query_test.go +++ b/lib/gsql/query_test.go @@ -1,6 +1,7 @@ package gsql_test import ( + "crypto/tls" "log" "net" "net/http" @@ -11,6 +12,7 @@ import ( "gfx.cafe/gfx/pggat/lib/bouncer/backends/v0" "gfx.cafe/gfx/pggat/lib/bouncer/bouncers/v2" "gfx.cafe/gfx/pggat/lib/fed" + "gfx.cafe/gfx/pggat/lib/fed/codecs/netconncodec" "gfx.cafe/gfx/pggat/lib/gsql" "gfx.cafe/gfx/pggat/lib/util/flip" ) @@ -30,15 +32,15 @@ func TestQuery(t *testing.T) { t.Error(err) return } - server := fed.NewConn(s) + server := fed.NewConn(netconncodec.NewCodec(s)) err = backends.Accept( server, - "", - nil, + "disable", + &tls.Config{}, "postgres", credentials.Cleartext{ Username: "postgres", - Password: "password", + Password: "postgres", }, "postgres", nil, @@ -48,7 +50,7 @@ func TestQuery(t *testing.T) { return } - inward, outward := gsql.NewPair() + inward, outward, _, _ := gsql.NewPair() var res Result diff --git a/test/runner.go b/test/runner.go index 5bbeb187d26be796aab2d648c4e6c6a25a6a373b..8b1c7e9683652584a7b5de40789c62fc8e565e53 100644 --- a/test/runner.go +++ b/test/runner.go @@ -76,7 +76,7 @@ func (T *Runner) runModeL1(dialer pool.Dialer, client *fed.Conn) error { } func (T *Runner) runModeOnce(dialer pool.Dialer) ([]byte, error) { - inward, outward := gsql.NewPair() + inward, outward, in, _ := gsql.NewPair() if err := T.prepare(inward, len(T.test.Packets)); err != nil { return nil, err } @@ -89,12 +89,12 @@ func (T *Runner) runModeOnce(dialer pool.Dialer) ([]byte, error) { return nil, err } - return io.ReadAll(inward.NetConn) + return io.ReadAll(in) } func (T *Runner) runModeFail(dialer pool.Dialer) error { for i := 1; i <= len(T.test.Packets); i++ { - inward, outward := gsql.NewPair() + inward, outward, _, _ := gsql.NewPair() if err := T.prepare(inward, i); err != nil { return err }