diff --git a/lib/middleware/middleware.go b/lib/middleware/middleware.go deleted file mode 100644 index 80213fb41c3125100969f49c5b92ec16c2ddf2cd..0000000000000000000000000000000000000000 --- a/lib/middleware/middleware.go +++ /dev/null @@ -1,8 +0,0 @@ -package middleware - -import "pggat2/lib/pnet/packet" - -type Middleware interface { - Write(in packet.In) (forward bool, err error) - Read(in packet.In) (forward bool, err error) -} diff --git a/lib/mw2/context.go b/lib/mw2/context.go new file mode 100644 index 0000000000000000000000000000000000000000..06acdbd89b2d05f333f14f9ca6002f9b87af129a --- /dev/null +++ b/lib/mw2/context.go @@ -0,0 +1,11 @@ +package mw2 + +import "pggat2/lib/zap" + +type Context interface { + // Cancel the current packet + Cancel() + + // Send to underlying writer + Send(out zap.Out) error +} diff --git a/lib/mw2/middleware.go b/lib/mw2/middleware.go new file mode 100644 index 0000000000000000000000000000000000000000..f1afde8e20e910c737a7ce75e232003eef9ef397 --- /dev/null +++ b/lib/mw2/middleware.go @@ -0,0 +1,8 @@ +package mw2 + +import "pggat2/lib/zap" + +type Middleware interface { + Send(ctx Context, out zap.Out) error + Read(ctx Context, in zap.In) error +} diff --git a/lib/mw2/middlewares/eqp/client.go b/lib/mw2/middlewares/eqp/client.go new file mode 100644 index 0000000000000000000000000000000000000000..9f0f84af19c2dea63a7158fd7d452a34118b99ac --- /dev/null +++ b/lib/mw2/middlewares/eqp/client.go @@ -0,0 +1,114 @@ +package eqp + +import ( + "errors" + + "pggat2/lib/mw2" + "pggat2/lib/zap" + packets "pggat2/lib/zap/packets/v3.0" +) + +type Client struct { + preparedStatements map[string]PreparedStatement + portals map[string]Portal +} + +func (T *Client) Send(_ mw2.Context, out zap.Out) error { + in := zap.OutToIn(out) + switch in.Type() { + case packets.ParseComplete, packets.BindComplete, packets.CloseComplete: + // should've been caught by eqp.Server + panic("unreachable") + } + return nil +} + +func (T *Client) Read(ctx mw2.Context, in zap.In) error { + switch in.Type() { + case packets.Query: + // clobber unnamed portal and unnamed prepared statement + delete(T.preparedStatements, "") + delete(T.portals, "") + case packets.Parse: + ctx.Cancel() + + destination, query, parameterDataTypes, ok := packets.ReadParse(in) + if !ok { + return errors.New("bad packet format") + } + if destination != "" { + if _, ok = T.preparedStatements[destination]; ok { + return errors.New("prepared statement already exists") + } + } + T.preparedStatements[destination] = PreparedStatement{ + Query: query, + ParameterDataTypes: parameterDataTypes, + } + + // send parse complete + out := zap.InToOut(in) + out.Reset() + out.Type(packets.ParseComplete) + err := ctx.Send(out) + if err != nil { + return err + } + case packets.Bind: + ctx.Cancel() + + destination, source, parameterFormatCodes, parameterValues, resultFormatCodes, ok := packets.ReadBind(in) + if !ok { + return errors.New("bad packet format") + } + if destination != "" { + if _, ok = T.portals[destination]; ok { + return errors.New("portal already exists") + } + } + T.portals[destination] = Portal{ + Source: source, + ParameterFormatCodes: parameterFormatCodes, + ParameterValues: parameterValues, + ResultFormatCodes: resultFormatCodes, + } + + // send bind complete + out := zap.InToOut(in) + out.Reset() + out.Type(packets.BindComplete) + err := ctx.Send(out) + if err != nil { + return err + } + case packets.Close: + ctx.Cancel() + + which, target, ok := packets.ReadClose(in) + if !ok { + return errors.New("bad packet format") + } + switch which { + case 'S': + delete(T.preparedStatements, target) + case 'P': + delete(T.portals, target) + default: + return errors.New("bad packet format") + } + + // send close complete + out := zap.InToOut(in) + out.Reset() + out.Type(packets.CloseComplete) + err := ctx.Send(out) + if err != nil { + return err + } + + // TODO(garet) we should read Describe and Execute to check if target exists + } + return nil +} + +var _ mw2.Middleware = (*Client)(nil) diff --git a/lib/mw2/middlewares/eqp/portal.go b/lib/mw2/middlewares/eqp/portal.go new file mode 100644 index 0000000000000000000000000000000000000000..c55ff00f78c6f4de44814428cc0c49ae61b7759b --- /dev/null +++ b/lib/mw2/middlewares/eqp/portal.go @@ -0,0 +1,8 @@ +package eqp + +type Portal struct { + Source string + ParameterFormatCodes []int16 + ParameterValues [][]byte + ResultFormatCodes []int16 +} diff --git a/lib/mw2/middlewares/eqp/preparedStatement.go b/lib/mw2/middlewares/eqp/preparedStatement.go new file mode 100644 index 0000000000000000000000000000000000000000..5f62b2433b6efb56130bd60aeb638ba7bfee73b3 --- /dev/null +++ b/lib/mw2/middlewares/eqp/preparedStatement.go @@ -0,0 +1,6 @@ +package eqp + +type PreparedStatement struct { + Query string + ParameterDataTypes []int32 +} diff --git a/lib/mw2/middlewares/eqp/server.go b/lib/mw2/middlewares/eqp/server.go new file mode 100644 index 0000000000000000000000000000000000000000..8703e67d5d870c2b4b6ab6c242daa54884a497a9 --- /dev/null +++ b/lib/mw2/middlewares/eqp/server.go @@ -0,0 +1,128 @@ +package eqp + +import ( + "errors" + + "pggat2/lib/mw2" + "pggat2/lib/util/ring" + "pggat2/lib/zap" + packets "pggat2/lib/zap/packets/v3.0" +) + +type Server struct { + preparedStatements map[string]PreparedStatement + portals map[string]Portal + pendingPreparedStatements ring.Ring[string] + pendingPortals ring.Ring[string] + + peer *Client +} + +func (T *Server) closePreparedStatement(ctx mw2.Context, target string) error { + +} + +func (T *Server) closePortal(ctx mw2.Context, target string) error { + +} + +func (T *Server) bindPreparedStatement(ctx mw2.Context, target string, preparedStatement PreparedStatement) error { + if _, ok := T.preparedStatements[target]; ok { + err := T.closePreparedStatement(ctx, target) + if err != nil { + return err + } + } +} + +func (T *Server) bindPortal(ctx mw2.Context, target string, portal Portal) error { + if _, ok := T.portals[target]; ok { + err := T.closePortal(ctx, target) + if err != nil { + return err + } + } +} + +func (T *Server) syncPreparedStatement(ctx mw2.Context, target string) error { + +} + +func (T *Server) syncPortal(ctx mw2.Context, target string) error { + +} + +func (T *Server) Send(ctx mw2.Context, out zap.Out) error { + in := zap.OutToIn(out) + switch in.Type() { + case packets.Query: + // clobber unnamed portal and unnamed prepared statement + delete(T.preparedStatements, "") + delete(T.portals, "") + case packets.Parse, packets.Bind, packets.Close: + // should've been caught by eqp.Client + panic("unreachable") + case packets.Describe: + // ensure target exists + which, target, ok := packets.ReadDescribe(in) + if !ok { + return errors.New("bad packet format") + } + switch which { + case 'S': + // sync prepared statement + err := T.syncPreparedStatement(ctx, target) + if err != nil { + return err + } + case 'P': + // sync portal + err := T.syncPortal(ctx, target) + if err != nil { + return err + } + default: + return errors.New("unknown describe target") + } + case packets.Execute: + target, _, ok := packets.ReadExecute(in) + if !ok { + return errors.New("bad packet format") + } + // sync portal + err := T.syncPortal(ctx, target) + if err != nil { + return err + } + } + + return nil +} + +func (T *Server) Read(ctx mw2.Context, in zap.In) error { + switch in.Type() { + case packets.ParseComplete: + ctx.Cancel() + + T.pendingPreparedStatements.PopFront() + case packets.BindComplete: + ctx.Cancel() + + T.pendingPortals.PopFront() + case packets.CloseComplete: + ctx.Cancel() + + // TODO(garet) Correctness: we could check this to make sure state is synced, but waiting for close is a pain + case packets.ReadyForQuery: + // all pending failed + for pending, ok := T.pendingPreparedStatements.PopFront(); ok; pending, ok = T.pendingPreparedStatements.PopFront() { + delete(T.preparedStatements, pending) + } + for pending, ok := T.pendingPortals.PopFront(); ok; pending, ok = T.pendingPortals.PopFront() { + delete(T.portals, pending) + } + } + return nil +} + +var _ mw2.Middleware = (*Server)(nil) diff --git a/lib/mw2/nil.go b/lib/mw2/nil.go new file mode 100644 index 0000000000000000000000000000000000000000..1d3d5fdeba835cc546c798fe319be328097bcbc1 --- /dev/null +++ b/lib/mw2/nil.go @@ -0,0 +1,15 @@ +package mw2 + +import "pggat2/lib/zap" + +type Nil struct{} + +func (Nil) Send(_ Context, _ zap.Out) error { + return nil +} + +func (Nil) Read(_ Context, _ zap.In) error { + return nil +} + +var _ Middleware = Nil{}