diff --git a/cmd/cgat/main.go b/cmd/cgat/main.go index 16ca9b111ff1635a4cbe7c740f0ac54e6e506328..1ac0cc230b9e7b3b8f47554e28b81164fea1afb8 100644 --- a/cmd/cgat/main.go +++ b/cmd/cgat/main.go @@ -2,85 +2,13 @@ package main import ( "log" - "net" "net/http" _ "net/http/pprof" - "pggat2/lib/bouncer/backends/v0" - "pggat2/lib/bouncer/bouncers/v2" - "pggat2/lib/bouncer/frontends/v0" - "pggat2/lib/middleware/interceptor" - "pggat2/lib/middleware/middlewares/eqp" - "pggat2/lib/middleware/middlewares/ps" - "pggat2/lib/middleware/middlewares/unterminate" - "pggat2/lib/rob" - "pggat2/lib/rob/schedulers/v1" - "pggat2/lib/zap" + "pggat2/lib/gat" + "pggat2/lib/gat/pools/transaction" ) -type work struct { - rw zap.ReadWriter - eqpc *eqp.Client - psc *ps.Client -} - -type server struct { - rw zap.ReadWriter - eqps *eqp.Server - pss *ps.Server -} - -func (T server) Do(_ rob.Constraints, w any) { - job := w.(work) - job.psc.SetServer(T.pss) - T.eqps.SetClient(job.eqpc) - bouncers.Bounce(job.rw, T.rw) - return -} - -var _ rob.Worker = server{} - -func testServer(r rob.Scheduler) { - conn, err := net.Dial("tcp", "localhost:5432") - if err != nil { - panic(err) - } - rw := zap.CombinedReadWriter{ - Reader: zap.IOReader{Reader: conn}, - Writer: zap.IOWriter{Writer: conn}, - } - eqps := eqp.NewServer() - pss := ps.NewServer() - mw := interceptor.NewInterceptor( - rw, - eqps, - pss, - ) - backends.Accept(mw, "postgres", "password", "uniswap") - r.AddSink(0, server{ - rw: mw, - eqps: eqps, - pss: pss, - }) -} - -var DefaultParameterStatus = map[string]string{ - // TODO(garet) we should just get these from the first server connection - "DateStyle": "ISO, MDY", - "IntervalStyle": "postgres", - "TimeZone": "America/Chicago", - "application_name": "", - "client_encoding": "UTF8", - "default_transaction_read_only": "off", - "in_hot_standby": "off", - "integer_datetimes": "on", - "is_superuser": "on", - "server_encoding": "UTF8", - "server_version": "14.5", - "session_authorization": "postgres", - "standard_conforming_strings": "on", -} - func main() { go func() { panic(http.ListenAndServe(":8080", nil)) @@ -88,50 +16,13 @@ func main() { log.Println("Starting pggat...") - r := schedulers.MakeScheduler() - for i := 0; i < 5; i++ { - go testServer(&r) - } + pooler := gat.NewPooler() + pooler.Mount("uniswap", transaction.NewPool()) + + log.Println("Listening on :6432") - listener, err := net.Listen("tcp", "0.0.0.0:6432") // TODO(garet) make this configurable + err := pooler.ListenAndServe(":6432") if err != nil { panic(err) } - - log.Println("Listening on 0.0.0.0:6432") - - for { - conn, err := listener.Accept() - if err != nil { - panic(err) - } - go func() { - source := r.NewSource() - client := zap.CombinedReadWriter{ - Reader: zap.IOReader{Reader: conn}, - Writer: zap.IOWriter{Writer: conn}, - } - eqpc := eqp.NewClient() - defer eqpc.Done() - psc := ps.NewClient() - mw := interceptor.NewInterceptor( - client, - unterminate.Unterminate, - eqpc, - psc, - ) - frontends.Accept(mw, DefaultParameterStatus) - for { - _, err := conn.Read([]byte{}) - if err != nil { - break - } - source.Do(0, work{ - rw: mw, - eqpc: eqpc, - psc: psc, - }) - } - }() - } } diff --git a/lib/bouncer/frontends/v0/accept.go b/lib/bouncer/frontends/v0/accept.go index 00bf9c2eb9809f0d9df958919aa9d4105b86e01d..66b5e2bbaa3e33d32bb900cddcc3f10070c5f28f 100644 --- a/lib/bouncer/frontends/v0/accept.go +++ b/lib/bouncer/frontends/v0/accept.go @@ -24,25 +24,25 @@ func fail(client zap.ReadWriter, err perror.Error) { _ = client.Write(packet) } -func startup0(client zap.ReadWriter) (done bool, status Status) { +func startup0(client zap.ReadWriter) (user, database string, done bool, status Status) { packet := zap.NewUntypedPacket() defer packet.Done() err := client.ReadUntyped(packet) if err != nil { fail(client, perror.Wrap(err)) - return false, Fail + return } read := packet.Read() majorVersion, ok := read.ReadUint16() if !ok { fail(client, packets.ErrBadFormat) - return false, Fail + return } minorVersion, ok := read.ReadUint16() if !ok { fail(client, packets.ErrBadFormat) - return false, Fail + return } if majorVersion == 1234 { @@ -55,30 +55,32 @@ func startup0(client zap.ReadWriter) (done bool, status Status) { perror.FeatureNotSupported, "Cancel is not supported yet", )) - return false, Fail + return case 5679: // SSL is not supported yet err = client.WriteByte('N') if err != nil { fail(client, perror.Wrap(err)) - return false, Fail + return } - return false, Ok + status = Ok + return case 5680: // GSSAPI is not supported yet err = client.WriteByte('N') if err != nil { fail(client, perror.Wrap(err)) - return false, Fail + return } - return false, Ok + status = Ok + return default: fail(client, perror.New( perror.FATAL, perror.ProtocolViolation, "Unknown request code", )) - return false, Fail + return } } @@ -92,14 +94,11 @@ func startup0(client zap.ReadWriter) (done bool, status Status) { var unsupportedOptions []string - var user string - var database string - for { key, ok := read.ReadString() if !ok { fail(client, packets.ErrBadFormat) - return false, Fail + return } if key == "" { break @@ -108,7 +107,7 @@ func startup0(client zap.ReadWriter) (done bool, status Status) { value, ok := read.ReadString() if !ok { fail(client, packets.ErrBadFormat) - return false, Fail + return } switch key { @@ -122,14 +121,14 @@ func startup0(client zap.ReadWriter) (done bool, status Status) { perror.FeatureNotSupported, "Startup options are not supported yet", )) - return false, Fail + return case "replication": fail(client, perror.New( perror.FATAL, perror.FeatureNotSupported, "Replication mode is not supported yet", )) - return false, Fail + return default: if strings.HasPrefix(key, "_pq_.") { // we don't support protocol extensions at the moment @@ -149,7 +148,7 @@ func startup0(client zap.ReadWriter) (done bool, status Status) { err = client.Write(packet) if err != nil { fail(client, perror.Wrap(err)) - return false, Fail + return } } @@ -159,13 +158,15 @@ func startup0(client zap.ReadWriter) (done bool, status Status) { perror.InvalidAuthorizationSpecification, "User is required", )) - return false, Fail + return } if database == "" { database = user } - return true, Ok + status = Ok + done = true + return } func authenticationSASLInitial(client zap.ReadWriter, username, password string) (server sasl.Server, resp []byte, done bool, status Status) { @@ -268,9 +269,11 @@ func updateParameter(pkts *zap.Packets, name, value string) Status { return Ok } -func Accept(client zap.ReadWriter, initialParameterStatus map[string]string) { +func Accept(client zap.ReadWriter, getPassword func(user string, database string) string, initialParameterStatus map[string]string) (user string, database string, ok bool) { for { - done, status := startup0(client) + var done bool + var status Status + user, database, done, status = startup0(client) if status != Ok { return } @@ -279,7 +282,7 @@ func Accept(client zap.ReadWriter, initialParameterStatus map[string]string) { } } - status := authenticationSASL(client, "test", "pw") + status := authenticationSASL(client, user, getPassword(user, database)) if status != Ok { return } @@ -321,4 +324,7 @@ func Accept(client zap.ReadWriter, initialParameterStatus map[string]string) { fail(client, perror.Wrap(err)) return } + + ok = true + return } diff --git a/lib/gat/pool.go b/lib/gat/pool.go new file mode 100644 index 0000000000000000000000000000000000000000..d7a355dba10e0cd4c806d7ff888e28c71092d69c --- /dev/null +++ b/lib/gat/pool.go @@ -0,0 +1,7 @@ +package gat + +import "pggat2/lib/zap" + +type Pool interface { + Serve(client zap.ReadWriter) +} diff --git a/lib/gat/pooler.go b/lib/gat/pooler.go new file mode 100644 index 0000000000000000000000000000000000000000..d007373c2537f321c5dde2be99f0525f46ba6efb --- /dev/null +++ b/lib/gat/pooler.go @@ -0,0 +1,96 @@ +package gat + +import ( + "net" + "sync" + + "pggat2/lib/bouncer/frontends/v0" + "pggat2/lib/middleware/interceptor" + "pggat2/lib/middleware/middlewares/unterminate" + "pggat2/lib/zap" +) + +var DefaultParameterStatus = map[string]string{ + // TODO(garet) we should just get these from the first server connection + "DateStyle": "ISO, MDY", + "IntervalStyle": "postgres", + "TimeZone": "America/Chicago", + "application_name": "", + "client_encoding": "UTF8", + "default_transaction_read_only": "off", + "in_hot_standby": "off", + "integer_datetimes": "on", + "is_superuser": "on", + "server_encoding": "UTF8", + "server_version": "14.5", + "session_authorization": "postgres", + "standard_conforming_strings": "on", +} + +type Pooler struct { + pools map[string]Pool + mu sync.RWMutex +} + +func NewPooler() *Pooler { + return &Pooler{ + pools: make(map[string]Pool), + } +} + +func (T *Pooler) Mount(name string, pool Pool) { + T.mu.Lock() + defer T.mu.Unlock() + T.pools[name] = pool +} + +func (T *Pooler) Unmount(name string) { + T.mu.Lock() + defer T.mu.Unlock() + delete(T.pools, name) +} + +func (T *Pooler) getPool(name string) Pool { + T.mu.RLock() + defer T.mu.RUnlock() + return T.pools[name] +} + +func (T *Pooler) Serve(client zap.ReadWriter) { + client = interceptor.NewInterceptor( + client, + unterminate.Unterminate, + ) + + _, database, ok := frontends.Accept(client, func(user string, database string) string { + return "pw" + }, DefaultParameterStatus) + if !ok { + return + } + + pool := T.getPool(database) + if pool == nil { + return + } + + pool.Serve(client) +} + +func (T *Pooler) ListenAndServe(address string) error { + listener, err := net.Listen("tcp", address) + if err != nil { + return err + } + + for { + conn, err := listener.Accept() + if err != nil { + return err + } + go T.Serve(zap.CombinedReadWriter{ + Reader: zap.IOReader{Reader: conn}, + Writer: zap.IOWriter{Writer: conn}, + }) + } +} diff --git a/lib/gat/pools/session/pool.go b/lib/gat/pools/session/pool.go new file mode 100644 index 0000000000000000000000000000000000000000..b435ac1807874f63db2573ffa8dd89442b658b1b --- /dev/null +++ b/lib/gat/pools/session/pool.go @@ -0,0 +1,16 @@ +package session + +import ( + "pggat2/lib/gat" + "pggat2/lib/zap" +) + +type Pool struct { +} + +func (T *Pool) Serve(client zap.ReadWriter) { + // TODO implement me + panic("implement me") +} + +var _ gat.Pool = (*Pool)(nil) diff --git a/lib/gat/pools/transaction/conn.go b/lib/gat/pools/transaction/conn.go new file mode 100644 index 0000000000000000000000000000000000000000..33b8ca1627e4382d2bcc032dd71052fe9f4bde4a --- /dev/null +++ b/lib/gat/pools/transaction/conn.go @@ -0,0 +1,25 @@ +package transaction + +import ( + "pggat2/lib/bouncer/bouncers/v2" + "pggat2/lib/middleware/middlewares/eqp" + "pggat2/lib/middleware/middlewares/ps" + "pggat2/lib/rob" + "pggat2/lib/zap" +) + +type Conn struct { + rw zap.ReadWriter + eqp *eqp.Server + ps *ps.Server +} + +func (T Conn) Do(_ rob.Constraints, work any) { + job := work.(Work) + job.ps.SetServer(T.ps) + T.eqp.SetClient(job.eqp) + bouncers.Bounce(job.rw, T.rw) + return +} + +var _ rob.Worker = Conn{} diff --git a/lib/gat/pools/transaction/pool.go b/lib/gat/pools/transaction/pool.go new file mode 100644 index 0000000000000000000000000000000000000000..236b33da66e268dedbc3c276e50e07b4ffe12438 --- /dev/null +++ b/lib/gat/pools/transaction/pool.go @@ -0,0 +1,74 @@ +package transaction + +import ( + "net" + + "pggat2/lib/bouncer/backends/v0" + "pggat2/lib/gat" + "pggat2/lib/middleware/interceptor" + "pggat2/lib/middleware/middlewares/eqp" + "pggat2/lib/middleware/middlewares/ps" + "pggat2/lib/rob/schedulers/v1" + "pggat2/lib/zap" +) + +type Pool struct { + s schedulers.Scheduler +} + +func NewPool() *Pool { + pool := &Pool{ + s: schedulers.MakeScheduler(), + } + + for i := 0; i < 5; i++ { + conn, err := net.Dial("tcp", "localhost:5432") + if err != nil { + panic(err) + } + rw := zap.CombinedReadWriter{ + Reader: zap.IOReader{Reader: conn}, + Writer: zap.IOWriter{Writer: conn}, + } + eqps := eqp.NewServer() + pss := ps.NewServer() + mw := interceptor.NewInterceptor( + rw, + eqps, + pss, + ) + backends.Accept(mw, "postgres", "password", "uniswap") + pool.s.AddSink(0, Conn{ + rw: mw, + eqp: eqps, + ps: pss, + }) + } + + return pool +} + +func (T *Pool) Serve(client zap.ReadWriter) { + source := T.s.NewSource() + eqpc := eqp.NewClient() + defer eqpc.Done() + psc := ps.NewClient() + client = interceptor.NewInterceptor( + client, + eqpc, + psc, + ) + for { + err := client.Poll() + if err != nil { + break + } + source.Do(0, Work{ + rw: client, + eqp: eqpc, + ps: psc, + }) + } +} + +var _ gat.Pool = (*Pool)(nil) diff --git a/lib/gat/pools/transaction/work.go b/lib/gat/pools/transaction/work.go new file mode 100644 index 0000000000000000000000000000000000000000..f6366986d19ab0718f49bf7ed6cd9fecf1b64993 --- /dev/null +++ b/lib/gat/pools/transaction/work.go @@ -0,0 +1,13 @@ +package transaction + +import ( + "pggat2/lib/middleware/middlewares/eqp" + "pggat2/lib/middleware/middlewares/ps" + "pggat2/lib/zap" +) + +type Work struct { + rw zap.ReadWriter + eqp *eqp.Client + ps *ps.Client +} diff --git a/lib/middleware/interceptor/interceptor.go b/lib/middleware/interceptor/interceptor.go index f58d2253d2313cc255cdb2555f92887cd66d0442..49f2c6e2872a466ac77675b297636db9f6f75ce1 100644 --- a/lib/middleware/interceptor/interceptor.go +++ b/lib/middleware/interceptor/interceptor.go @@ -19,6 +19,10 @@ func NewInterceptor(rw zap.ReadWriter, middlewares ...middleware.Middleware) *In } } +func (T *Interceptor) Poll() error { + return T.rw.Poll() +} + func (T *Interceptor) ReadByte() (byte, error) { return T.rw.ReadByte() } diff --git a/lib/zap/reader.go b/lib/zap/reader.go index 61bc8de86c31ebec23263472c27205ed09bdc0fd..76effb413659f52a50ce638c2f424b535012ee40 100644 --- a/lib/zap/reader.go +++ b/lib/zap/reader.go @@ -3,6 +3,7 @@ package zap import "io" type Reader interface { + Poll() error ReadByte() (byte, error) Read(*Packet) error ReadUntyped(*UntypedPacket) error @@ -12,6 +13,11 @@ type IOReader struct { Reader io.Reader } +func (T IOReader) Poll() error { + _, err := T.Reader.Read([]byte{}) + return err +} + func (T IOReader) ReadByte() (byte, error) { var res = []byte{0} _, err := io.ReadFull(T.Reader, res)