diff --git a/cmd/cgat/main.go b/cmd/cgat/main.go index f3b3444f4e4f81f872b52ecb8800479087006c91..c58bc7fbe16034a23b36873bbf76fc2996dc2a87 100644 --- a/cmd/cgat/main.go +++ b/cmd/cgat/main.go @@ -11,6 +11,7 @@ import ( "pggat2/lib/middleware/middlewares/onebuffer" "pggat2/lib/mw2" "pggat2/lib/mw2/interceptor" + "pggat2/lib/mw2/middlewares/eqp" "pggat2/lib/mw2/middlewares/unterminate" "pggat2/lib/rob" "pggat2/lib/rob/schedulers/v2" @@ -19,6 +20,7 @@ import ( ) type job struct { + eqpc *eqp.Client client zap.ReadWriter done chan<- struct{} } @@ -29,11 +31,19 @@ func testServer(r rob.Scheduler) { panic(err) } server := zio.MakeReadWriter(conn) - backends.Accept(&server) + eqps := eqp.MakeServer() + mw := interceptor.MakeInterceptor( + &server, + []mw2.Middleware{ + &eqps, + }, + ) + backends.Accept(&mw) sink := r.NewSink(0) for { j := sink.Read().(job) - bouncers.Bounce(j.client, &server) + eqps.SetClient(j.eqpc) + bouncers.Bounce(j.client, &mw) select { case j.done <- struct{}{}: default: @@ -62,8 +72,10 @@ func main() { source := r.NewSource() client := zio.MakeReadWriter(conn) ob := onebuffer.MakeOnebuffer(&client) + eqpc := eqp.MakeClient() mw := interceptor.MakeInterceptor(&ob, []mw2.Middleware{ unterminate.Unterminate, + &eqpc, }) frontends.Accept(&mw) done := make(chan struct{}) @@ -74,6 +86,7 @@ func main() { break } source.Schedule(job{ + eqpc: &eqpc, client: &mw, done: done, }, 0) diff --git a/lib/mw2/middlewares/eqp/client.go b/lib/mw2/middlewares/eqp/client.go index 0cf0d18bc2acb235c4637e07644c8960659258fc..70abce4651393ebe56b869ba9bc71b1ec74cbb60 100644 --- a/lib/mw2/middlewares/eqp/client.go +++ b/lib/mw2/middlewares/eqp/client.go @@ -23,6 +23,10 @@ func MakeClient() Client { func (T *Client) Send(_ mw2.Context, out zap.Out) error { in := zap.OutToIn(out) switch in.Type() { + case packets.ReadyForQuery: + // clobber unnamed + delete(T.preparedStatements, "") + delete(T.portals, "") case packets.ParseComplete, packets.BindComplete, packets.CloseComplete: // should've been caught by eqp.Server panic("unreachable") diff --git a/lib/mw2/middlewares/eqp/server.go b/lib/mw2/middlewares/eqp/server.go index 73f25d6a3f1ed20332c51f8235a27cfee4568df3..108f105f51e97c8aa40fbdd6dcd93ff0ac771eac 100644 --- a/lib/mw2/middlewares/eqp/server.go +++ b/lib/mw2/middlewares/eqp/server.go @@ -85,10 +85,12 @@ func (T *Server) closePortal(ctx mw2.Context, target string) error { } func (T *Server) bindPreparedStatement(ctx mw2.Context, target string, preparedStatement PreparedStatement) error { - if _, ok := T.preparedStatements[target]; ok { - err := T.closePreparedStatement(ctx, target) - if err != nil { - return err + if target != "" { + if _, ok := T.preparedStatements[target]; ok { + err := T.closePreparedStatement(ctx, target) + if err != nil { + return err + } } } @@ -105,10 +107,12 @@ func (T *Server) bindPreparedStatement(ctx mw2.Context, target string, preparedS } func (T *Server) bindPortal(ctx mw2.Context, target string, portal Portal) error { - if _, ok := T.portals[target]; ok { - err := T.closePortal(ctx, target) - if err != nil { - return err + if target != "" { + if _, ok := T.portals[target]; ok { + err := T.closePortal(ctx, target) + if err != nil { + return err + } } } @@ -205,6 +209,9 @@ func (T *Server) Read(ctx mw2.Context, in zap.In) error { T.pendingCloses.PopFront() case packets.ReadyForQuery: + // clobber unnamed + delete(T.preparedStatements, "") + delete(T.portals, "") // all pending failed for pending, ok := T.pendingPreparedStatements.PopBack(); ok; pending, ok = T.pendingPreparedStatements.PopBack() { delete(T.preparedStatements, pending) diff --git a/lib/zap/packets/v3.0/bind.go b/lib/zap/packets/v3.0/bind.go index caa0d871620b1d1789c171e1ebc5eb311194d340..07c526a6ba97e1a84cc1e3a5e196db36a6f13a3c 100644 --- a/lib/zap/packets/v3.0/bind.go +++ b/lib/zap/packets/v3.0/bind.go @@ -17,8 +17,8 @@ func ReadBind(in zap.In) (destination string, source string, parameterFormatCode if !ok { return } - var parameterFormatCodesLength int16 - parameterFormatCodesLength, ok = in.Int16() + var parameterFormatCodesLength uint16 + parameterFormatCodesLength, ok = in.Uint16() if !ok { return } @@ -31,8 +31,8 @@ func ReadBind(in zap.In) (destination string, source string, parameterFormatCode } parameterFormatCodes = append(parameterFormatCodes, parameterFormatCode) } - var parameterValuesLength int16 - parameterValuesLength, ok = in.Int16() + var parameterValuesLength uint16 + parameterValuesLength, ok = in.Uint16() if !ok { return } @@ -50,8 +50,8 @@ func ReadBind(in zap.In) (destination string, source string, parameterFormatCode } parameterValues = append(parameterValues, parameterValue) } - var resultFormatCodesLength int16 - resultFormatCodesLength, ok = in.Int16() + var resultFormatCodesLength uint16 + resultFormatCodesLength, ok = in.Uint16() if !ok { return } @@ -72,11 +72,11 @@ func WriteBind(out zap.Out, destination, source string, parameterFormatCodes []i out.Type(Bind) out.String(destination) out.String(source) - out.Int16(int16(len(parameterFormatCodes))) + out.Uint16(uint16(len(parameterFormatCodes))) for _, v := range parameterFormatCodes { out.Int16(v) } - out.Int16(int16(len(parameterValues))) + out.Uint16(uint16(len(parameterValues))) for _, v := range parameterValues { if v == nil { out.Int32(-1) @@ -85,7 +85,7 @@ func WriteBind(out zap.Out, destination, source string, parameterFormatCodes []i out.Int32(int32(len(v))) out.Bytes(v) } - out.Int16(int16(len(resultFormatCodes))) + out.Uint16(uint16(len(resultFormatCodes))) for _, v := range resultFormatCodes { out.Int16(v) }