diff --git a/cmd/cgat/main.go b/cmd/cgat/main.go index 142d2613005ff29b9f8219ee9df69a7c937a6afe..44a36239fe32e10b1ca3d49dd4e8d4a49f472737 100644 --- a/cmd/cgat/main.go +++ b/cmd/cgat/main.go @@ -6,6 +6,7 @@ import ( "net/http" _ "net/http/pprof" + "pggat2/lib/middleware/middlewares/eqp" "pggat2/lib/rob/schedulers/v0" "pggat2/lib/zap/onebuffer" @@ -20,13 +21,20 @@ import ( "pggat2/lib/zap/zio" ) +type work struct { + rw zap.ReadWriter + eqpc *eqp.Client +} + type server struct { - rw zap.ReadWriter + rw zap.ReadWriter + eqps *eqp.Server } -func (T server) Do(_ rob.Constraints, work any) { - client := work.(zap.ReadWriter) - bouncers.Bounce(client, T.rw) +func (T server) Do(_ rob.Constraints, w any) { + job := w.(work) + T.eqps.SetClient(job.eqpc) + bouncers.Bounce(job.rw, T.rw) } var _ rob.Worker = server{} @@ -37,9 +45,14 @@ func testServer(r rob.Scheduler) { panic(err) } rw := zio.MakeReadWriter(conn) - backends.Accept(&rw) + eqps := eqp.MakeServer() + mw := interceptor.MakeInterceptor(&rw, []middleware.Middleware{ + &eqps, + }) + backends.Accept(&mw) r.AddSink(0, server{ - rw: &rw, + rw: &mw, + eqps: &eqps, }) } @@ -70,8 +83,10 @@ func main() { source := r.NewSource() client := zio.MakeReadWriter(conn) ob := onebuffer.MakeOnebuffer(&client) + eqpc := eqp.MakeClient() mw := interceptor.MakeInterceptor(&ob, []middleware.Middleware{ unterminate.Unterminate, + &eqpc, }) frontends.Accept(&mw) for { @@ -79,7 +94,10 @@ func main() { if err != nil { break } - source.Do(0, &mw) + source.Do(0, work{ + rw: &mw, + eqpc: &eqpc, + }) } }() } diff --git a/lib/bouncer/bouncers/v1/bctx/context.go b/lib/bouncer/bouncers/v1/bctx/context.go index c96bc521eccec193d4f641c7cda83721cb81ebc4..04a0ddbcef0350f8866f9787a5eec6df8ba48345 100644 --- a/lib/bouncer/bouncers/v1/bctx/context.go +++ b/lib/bouncer/bouncers/v1/bctx/context.go @@ -35,6 +35,9 @@ func MakeContext(client, server zap.ReadWriter, clientIdleTimeout time.Duration) } func (T *Context) Done() { + if T.clientIdleTimeout == 0 { + return + } // if it fails, it's not my problem - Garet, May 12, 2023 _ = T.client.SetReadDeadline(time.Time{}) } diff --git a/lib/bouncer/bouncers/v1/bouncer.go b/lib/bouncer/bouncers/v1/bouncer.go index 7d848fead1b8917efec791fc6c008d59e3deb573..50ec65fa7f5ea4016932c31407cafb60117ae794 100644 --- a/lib/bouncer/bouncers/v1/bouncer.go +++ b/lib/bouncer/bouncers/v1/bouncer.go @@ -2,7 +2,6 @@ package bouncers import ( "log" - "time" "pggat2/lib/bouncer/bouncers/v1/bctx" "pggat2/lib/bouncer/bouncers/v1/berr" @@ -267,7 +266,7 @@ func transaction(ctx *bctx.Context) berr.Error { } func Bounce(client, server zap.ReadWriter) { - ctx := bctx.MakeContext(client, server, 1*time.Second) // TODO(garet) make this configurable + ctx := bctx.MakeContext(client, server, 0) // TODO(garet) make this configurable defer ctx.Done() err := transaction(&ctx) if err != nil { diff --git a/lib/bouncer/frontends/v0/accept.go b/lib/bouncer/frontends/v0/accept.go index 6ce8d78a753aab084a9505df88f7689e984edab0..f1a5306f9f660faf5ca6f795f45d52900366e106 100644 --- a/lib/bouncer/frontends/v0/accept.go +++ b/lib/bouncer/frontends/v0/accept.go @@ -272,7 +272,7 @@ func Accept(client zap.ReadWriter) { } } - status := authenticationSASL(client, "test", "password") + status := authenticationSASL(client, "test", "pw") if status != Ok { return } diff --git a/lib/middleware/middlewares/eqp/client.go b/lib/middleware/middlewares/eqp/client.go index 927fb6fee482b11cd249323e11dbb41d219f8a11..b1a059573312ce77e5f01b23ddb80b608fd37446 100644 --- a/lib/middleware/middlewares/eqp/client.go +++ b/lib/middleware/middlewares/eqp/client.go @@ -24,9 +24,16 @@ func (T *Client) Send(_ middleware.Context, out zap.Out) error { in := zap.OutToIn(out) switch in.Type() { case packets.ReadyForQuery: - // clobber unnamed - delete(T.preparedStatements, "") - delete(T.portals, "") + state, ok := packets.ReadReadyForQuery(in) + if !ok { + return errors.New("bad packet format") + } + if state == 'I' { + // clobber all portals + for name := range T.portals { + delete(T.portals, name) + } + } case packets.ParseComplete, packets.BindComplete, packets.CloseComplete: // should've been caught by eqp.Server panic("unreachable") @@ -77,6 +84,9 @@ func (T *Client) Read(ctx middleware.Context, in zap.In) error { return errors.New("portal already exists") } } + if _, ok = T.preparedStatements[source]; !ok { + return errors.New("prepared statement does not exist") + } T.portals[destination] = Portal{ Source: source, ParameterFormatCodes: parameterFormatCodes, diff --git a/lib/middleware/middlewares/eqp/server.go b/lib/middleware/middlewares/eqp/server.go index 8f5aec86d9e146d180ce3c2e8fd3ac56d4394a22..1c5b92fcfde237fa51d917ff9bce87b202b787f4 100644 --- a/lib/middleware/middlewares/eqp/server.go +++ b/lib/middleware/middlewares/eqp/server.go @@ -133,6 +133,15 @@ func (T *Server) syncPreparedStatement(ctx middleware.Context, target string) er expected := T.peer.preparedStatements[target] actual, ok := T.preparedStatements[target] if !ok || !expected.Equals(actual) { + // clear all portals that use this prepared statement + for name, portal := range T.portals { + if portal.Source == target { + err := T.closePortal(ctx, name) + if err != nil { + return err + } + } + } return T.bindPreparedStatement(ctx, target, expected) } return nil @@ -140,6 +149,10 @@ func (T *Server) syncPreparedStatement(ctx middleware.Context, target string) er func (T *Server) syncPortal(ctx middleware.Context, target string) error { expected := T.peer.portals[target] + err := T.syncPreparedStatement(ctx, expected.Source) + if err != nil { + return err + } actual, ok := T.portals[target] if !ok || !expected.Equals(actual) { return T.bindPortal(ctx, target, expected) @@ -209,9 +222,16 @@ func (T *Server) Read(ctx middleware.Context, in zap.In) error { T.pendingCloses.PopFront() case packets.ReadyForQuery: - // clobber unnamed - delete(T.preparedStatements, "") - delete(T.portals, "") + state, ok := packets.ReadReadyForQuery(in) + if !ok { + return errors.New("bad packet format") + } + if state == 'I' { + // clobber all portals + for name := range T.portals { + delete(T.portals, name) + } + } // all pending failed for pending, ok := T.pendingPreparedStatements.PopBack(); ok; pending, ok = T.pendingPreparedStatements.PopBack() { delete(T.preparedStatements, pending)