diff --git a/lib/mw2/middlewares/eqp/client.go b/lib/mw2/middlewares/eqp/client.go index 9f0f84af19c2dea63a7158fd7d452a34118b99ac..0cf0d18bc2acb235c4637e07644c8960659258fc 100644 --- a/lib/mw2/middlewares/eqp/client.go +++ b/lib/mw2/middlewares/eqp/client.go @@ -13,6 +13,13 @@ type Client struct { portals map[string]Portal } +func MakeClient() Client { + return Client{ + preparedStatements: make(map[string]PreparedStatement), + portals: make(map[string]Portal), + } +} + func (T *Client) Send(_ mw2.Context, out zap.Out) error { in := zap.OutToIn(out) switch in.Type() { @@ -105,8 +112,32 @@ func (T *Client) Read(ctx mw2.Context, in zap.In) error { if err != nil { return err } - - // TODO(garet) we should read Describe and Execute to check if target exists + case packets.Describe: + // ensure target exists + which, target, ok := packets.ReadDescribe(in) + if !ok { + return errors.New("bad packet format") + } + switch which { + case 'S': + if _, ok := T.preparedStatements[target]; !ok { + return errors.New("prepared statement doesn't exist") + } + case 'P': + if _, ok := T.portals[target]; !ok { + return errors.New("portal doesn't exist") + } + default: + return errors.New("unknown describe target") + } + case packets.Execute: + target, _, ok := packets.ReadExecute(in) + if !ok { + return errors.New("bad packet format") + } + if _, ok := T.portals[target]; !ok { + return errors.New("portal doesn't exist") + } } return nil } diff --git a/lib/mw2/middlewares/eqp/portal.go b/lib/mw2/middlewares/eqp/portal.go index c55ff00f78c6f4de44814428cc0c49ae61b7759b..7ec58e4d2253cf47ef640977d3d40788420e57dc 100644 --- a/lib/mw2/middlewares/eqp/portal.go +++ b/lib/mw2/middlewares/eqp/portal.go @@ -1,8 +1,31 @@ package eqp +import "pggat2/lib/util/slices" + type Portal struct { Source string ParameterFormatCodes []int16 ParameterValues [][]byte ResultFormatCodes []int16 } + +func (T Portal) Equals(rhs Portal) bool { + if T.Source != rhs.Source { + return false + } + if !slices.Equal(T.ParameterFormatCodes, rhs.ParameterFormatCodes) { + return false + } + if len(T.ParameterValues) != len(rhs.ParameterValues) { + return false + } + for i := range T.ParameterValues { + if !slices.Equal(T.ParameterValues[i], rhs.ParameterValues[i]) { + return false + } + } + if !slices.Equal(T.ResultFormatCodes, rhs.ResultFormatCodes) { + return false + } + return true +} diff --git a/lib/mw2/middlewares/eqp/preparedStatement.go b/lib/mw2/middlewares/eqp/preparedStatement.go index 5f62b2433b6efb56130bd60aeb638ba7bfee73b3..a710b61b8139ba7b6a7c54a6bc51760bcc3cd5f3 100644 --- a/lib/mw2/middlewares/eqp/preparedStatement.go +++ b/lib/mw2/middlewares/eqp/preparedStatement.go @@ -1,6 +1,18 @@ package eqp +import "pggat2/lib/util/slices" + type PreparedStatement struct { Query string ParameterDataTypes []int32 } + +func (T PreparedStatement) Equals(rhs PreparedStatement) bool { + if T.Query != rhs.Query { + return false + } + if !slices.Equal(T.ParameterDataTypes, rhs.ParameterDataTypes) { + return false + } + return true +} diff --git a/lib/mw2/middlewares/eqp/server.go b/lib/mw2/middlewares/eqp/server.go index 8703e67d5d870c2b4b6ab6c242daa54884a497a9..73f25d6a3f1ed20332c51f8235a27cfee4568df3 100644 --- a/lib/mw2/middlewares/eqp/server.go +++ b/lib/mw2/middlewares/eqp/server.go @@ -9,21 +9,79 @@ import ( packets "pggat2/lib/zap/packets/v3.0" ) +type pendingClose interface { + pendingClose() +} + +type pendingClosePreparedStatement struct { + target string + preparedStatement PreparedStatement +} + +func (pendingClosePreparedStatement) pendingClose() {} + +type pendingClosePortal struct { + target string + portal Portal +} + +func (pendingClosePortal) pendingClose() {} + type Server struct { preparedStatements map[string]PreparedStatement portals map[string]Portal pendingPreparedStatements ring.Ring[string] pendingPortals ring.Ring[string] + pendingCloses ring.Ring[pendingClose] + + buf zap.Buf peer *Client } +func MakeServer() Server { + return Server{ + preparedStatements: make(map[string]PreparedStatement), + portals: make(map[string]Portal), + } +} + +func (T *Server) SetClient(client *Client) { + T.peer = client +} + func (T *Server) closePreparedStatement(ctx mw2.Context, target string) error { + out := T.buf.Write() + packets.WriteClose(out, 'S', target) + err := ctx.Send(out) + if err != nil { + return err + } + preparedStatement := T.preparedStatements[target] + delete(T.preparedStatements, target) + T.pendingCloses.PushBack(pendingClosePreparedStatement{ + target: target, + preparedStatement: preparedStatement, + }) + return nil } func (T *Server) closePortal(ctx mw2.Context, target string) error { + out := T.buf.Write() + packets.WriteClose(out, 'P', target) + err := ctx.Send(out) + if err != nil { + return err + } + portal := T.portals[target] + delete(T.portals, target) + T.pendingCloses.PushBack(pendingClosePortal{ + target: target, + portal: portal, + }) + return nil } func (T *Server) bindPreparedStatement(ctx mw2.Context, target string, preparedStatement PreparedStatement) error { @@ -33,6 +91,17 @@ func (T *Server) bindPreparedStatement(ctx mw2.Context, target string, preparedS return err } } + + out := T.buf.Write() + packets.WriteParse(out, target, preparedStatement.Query, preparedStatement.ParameterDataTypes) + err := ctx.Send(out) + if err != nil { + return err + } + + T.preparedStatements[target] = preparedStatement + T.pendingPreparedStatements.PushBack(target) + return nil } func (T *Server) bindPortal(ctx mw2.Context, target string, portal Portal) error { @@ -42,14 +111,36 @@ func (T *Server) bindPortal(ctx mw2.Context, target string, portal Portal) error return err } } + + out := T.buf.Write() + packets.WriteBind(out, target, portal.Source, portal.ParameterFormatCodes, portal.ParameterValues, portal.ResultFormatCodes) + err := ctx.Send(out) + if err != nil { + return err + } + + T.portals[target] = portal + T.pendingPortals.PushBack(target) + return nil } func (T *Server) syncPreparedStatement(ctx mw2.Context, target string) error { - + // we can assume client has the prepared statement because it should be checked by eqp.Client + expected := T.peer.preparedStatements[target] + actual, ok := T.preparedStatements[target] + if !ok || !expected.Equals(actual) { + return T.bindPreparedStatement(ctx, target, expected) + } + return nil } func (T *Server) syncPortal(ctx mw2.Context, target string) error { - + expected := T.peer.portals[target] + actual, ok := T.portals[target] + if !ok || !expected.Equals(actual) { + return T.bindPortal(ctx, target, expected) + } + return nil } func (T *Server) Send(ctx mw2.Context, out zap.Out) error { @@ -112,15 +203,25 @@ func (T *Server) Read(ctx mw2.Context, in zap.In) error { case packets.CloseComplete: ctx.Cancel() - // TODO(garet) Correctness: we could check this to make sure state is synced, but waiting for close is a pain + T.pendingCloses.PopFront() case packets.ReadyForQuery: // all pending failed - for pending, ok := T.pendingPreparedStatements.PopFront(); ok; pending, ok = T.pendingPreparedStatements.PopFront() { + for pending, ok := T.pendingPreparedStatements.PopBack(); ok; pending, ok = T.pendingPreparedStatements.PopBack() { delete(T.preparedStatements, pending) } - for pending, ok := T.pendingPortals.PopFront(); ok; pending, ok = T.pendingPortals.PopFront() { + for pending, ok := T.pendingPortals.PopBack(); ok; pending, ok = T.pendingPortals.PopBack() { delete(T.portals, pending) } + for pending, ok := T.pendingCloses.PopBack(); ok; pending, ok = T.pendingCloses.PopBack() { + switch p := pending.(type) { + case pendingClosePortal: + T.portals[p.target] = p.portal + case pendingClosePreparedStatement: + T.preparedStatements[p.target] = p.preparedStatement + default: + panic("what") + } + } } return nil }