diff --git a/lib/gat/pool/flow.go b/lib/gat/pool/flow.go index 9a7eb2199562384cc77f7df36f84330ebaa35d5d..b5fb57e0a534095cfafdcf748c230b2175455874 100644 --- a/lib/gat/pool/flow.go +++ b/lib/gat/pool/flow.go @@ -4,6 +4,7 @@ import ( "pggat/lib/bouncer/backends/v0" packets "pggat/lib/fed/packets/v3.0" "pggat/lib/gat/metrics" + "pggat/lib/middleware/middlewares/eqp" "pggat/lib/middleware/middlewares/ps" "pggat/lib/util/slices" ) @@ -19,8 +20,12 @@ func Pair(options Options, client *Client, server *Server) (clientErr, serverErr clientErr, serverErr = SyncInitialParameters(options, client, server) } + if clientErr != nil || serverErr != nil { + return + } + if options.ExtendedQuerySync { - server.GetEQP().SetClient(client.GetEQP()) + serverErr = eqp.Sync(client.GetEQP(), server.GetConn(), server.GetEQP()) } return diff --git a/lib/middleware/middlewares/eqp/client.go b/lib/middleware/middlewares/eqp/client.go index b439d154908d33417696eb49a179e2adffcaff98..ec1814504515606cbec7911865d2b7b6d805c55a 100644 --- a/lib/middleware/middlewares/eqp/client.go +++ b/lib/middleware/middlewares/eqp/client.go @@ -1,140 +1,25 @@ package eqp import ( - "errors" - "pggat/lib/fed" - packets "pggat/lib/fed/packets/v3.0" "pggat/lib/middleware" ) type Client struct { - preparedStatements map[string]PreparedStatement - portals map[string]Portal + state State } func NewClient() *Client { - return &Client{ - preparedStatements: make(map[string]PreparedStatement), - portals: make(map[string]Portal), - } -} - -func (T *Client) deletePreparedStatement(name string) { - delete(T.preparedStatements, name) -} - -func (T *Client) deletePortal(name string) { - delete(T.portals, name) + return new(Client) } -func (T *Client) Done() { - for name := range T.preparedStatements { - T.deletePreparedStatement(name) - } - for name := range T.portals { - T.deletePortal(name) - } -} - -func (T *Client) Write(_ middleware.Context, packet fed.Packet) error { - switch packet.Type() { - case packets.TypeReadyForQuery: - var readyForQuery packets.ReadyForQuery - if !readyForQuery.ReadFromPacket(packet) { - return errors.New("bad packet format a") - } - if readyForQuery == 'I' { - // clobber all named portals - for name := range T.portals { - T.deletePortal(name) - } - } - case packets.TypeParseComplete, packets.TypeBindComplete, packets.TypeCloseComplete: - // should've been caught by eqp.Server - panic("unreachable") - } +func (T *Client) Read(_ middleware.Context, packet fed.Packet) error { + T.state.C2S(packet) return nil } -func (T *Client) Read(ctx middleware.Context, packet fed.Packet) error { - switch packet.Type() { - case packets.TypeQuery: - // clobber unnamed portal and unnamed prepared statement - T.deletePreparedStatement("") - T.deletePortal("") - case packets.TypeParse: - ctx.Cancel() - - destination, preparedStatement, ok := ReadParse(packet) - if !ok { - return errors.New("bad packet format b") - } - - T.preparedStatements[destination] = preparedStatement - - // send parse complete - packet = fed.NewPacket(packets.TypeParseComplete) - err := ctx.Write(packet) - if err != nil { - return err - } - case packets.TypeBind: - ctx.Cancel() - - destination, portal, ok := ReadBind(packet) - if !ok { - return errors.New("bad packet format c") - } - - T.portals[destination] = portal - - // send bind complete - packet = fed.NewPacket(packets.TypeParseComplete) - err := ctx.Write(packet) - if err != nil { - return err - } - case packets.TypeClose: - ctx.Cancel() - - var p packets.Close - if !p.ReadFromPacket(packet) { - return errors.New("bad packet format d") - } - switch p.Which { - case 'S': - T.deletePreparedStatement(p.Target) - case 'P': - T.deletePortal(p.Target) - default: - return errors.New("bad packet format e") - } - - // send close complete - packet = fed.NewPacket(packets.TypeCloseComplete) - err := ctx.Write(packet) - if err != nil { - return err - } - case packets.TypeDescribe: - // ensure target exists - var describe packets.Describe - if !describe.ReadFromPacket(packet) { - return errors.New("bad packet format f") - } - switch describe.Which { - case 'S', 'P': - // ok - default: - return errors.New("unknown describe target") - } - case packets.TypeExecute: - var execute packets.Execute - if !execute.ReadFromPacket(packet) { - return errors.New("bad packet format g") - } - } +func (T *Client) Write(_ middleware.Context, packet fed.Packet) error { + T.state.S2C(packet) return nil } diff --git a/lib/middleware/middlewares/eqp/close.go b/lib/middleware/middlewares/eqp/close.go deleted file mode 100644 index 679a38e7d3b39cde4b3fbed66959450034e8aec8..0000000000000000000000000000000000000000 --- a/lib/middleware/middlewares/eqp/close.go +++ /dev/null @@ -1,8 +0,0 @@ -package eqp - -type Close struct { - Which byte - Target string - Source string - Hash uint64 -} diff --git a/lib/middleware/middlewares/eqp/portal.go b/lib/middleware/middlewares/eqp/portal.go deleted file mode 100644 index a3c8a382bf2ccb310ee717c6f67ae3be2ec2606c..0000000000000000000000000000000000000000 --- a/lib/middleware/middlewares/eqp/portal.go +++ /dev/null @@ -1,27 +0,0 @@ -package eqp - -import ( - "hash/maphash" - - "pggat/lib/fed" - packets "pggat/lib/fed/packets/v3.0" -) - -type Portal struct { - source string - packet fed.Packet - hash uint64 -} - -func ReadBind(in fed.Packet) (destination string, portal Portal, ok bool) { - if in.Type() != packets.TypeBind { - return - } - p := in.ReadString(&destination) - p = p.ReadString(&portal.source) - - portal.packet = in - portal.hash = maphash.Bytes(seed, portal.packet.Payload()) - ok = true - return -} diff --git a/lib/middleware/middlewares/eqp/preparedStatement.go b/lib/middleware/middlewares/eqp/preparedStatement.go deleted file mode 100644 index c0bad80234d635913dd0509e06c54871ffec6d54..0000000000000000000000000000000000000000 --- a/lib/middleware/middlewares/eqp/preparedStatement.go +++ /dev/null @@ -1,26 +0,0 @@ -package eqp - -import ( - "hash/maphash" - - "pggat/lib/fed" - packets "pggat/lib/fed/packets/v3.0" -) - -type PreparedStatement struct { - packet fed.Packet - hash uint64 -} - -func ReadParse(packet fed.Packet) (destination string, preparedStatement PreparedStatement, ok bool) { - if packet.Type() != packets.TypeParse { - return - } - - packet.ReadString(&destination) - - preparedStatement.packet = packet - preparedStatement.hash = maphash.Bytes(seed, preparedStatement.packet.Payload()) - ok = true - return -} diff --git a/lib/middleware/middlewares/eqp/seed.go b/lib/middleware/middlewares/eqp/seed.go deleted file mode 100644 index c6025e65eff7dc3e269a7981f3e26f94115bdbad..0000000000000000000000000000000000000000 --- a/lib/middleware/middlewares/eqp/seed.go +++ /dev/null @@ -1,6 +0,0 @@ -package eqp - -import "hash/maphash" - -// seed for use in maphash -var seed = maphash.MakeSeed() diff --git a/lib/middleware/middlewares/eqp/server.go b/lib/middleware/middlewares/eqp/server.go index 10f00b560e8ca662991ef814931147df7a3af149..bc1244a2aa6e29040b17692f0990dbba639e3bb1 100644 --- a/lib/middleware/middlewares/eqp/server.go +++ b/lib/middleware/middlewares/eqp/server.go @@ -1,325 +1,26 @@ package eqp import ( - "errors" - "pggat/lib/fed" - packets "pggat/lib/fed/packets/v3.0" "pggat/lib/middleware" - "pggat/lib/util/ring" ) -type HashedPortal struct { - source string - hash uint64 -} - type Server struct { - preparedStatements map[string]uint64 - portals map[string]HashedPortal - - pendingPreparedStatements ring.Ring[string] - pendingPortals ring.Ring[string] - pendingCloses ring.Ring[Close] - - peer *Client + state State } func NewServer() *Server { - return &Server{ - preparedStatements: make(map[string]uint64), - portals: make(map[string]HashedPortal), - } -} - -func (T *Server) SetClient(client *Client) { - T.peer = client -} - -func (T *Server) deletePreparedStatement(target string) { - delete(T.preparedStatements, target) -} - -func (T *Server) deletePortal(target string) { - delete(T.portals, target) -} - -func (T *Server) closePreparedStatement(ctx middleware.Context, target string) error { - // no need to close unnamed prepared statement - if target == "" { - return nil - } - - hash, ok := T.preparedStatements[target] - if !ok { - // already closed - return nil - } - - // send close packet - c := packets.Close{ - Which: 'S', - Target: target, - } - err := ctx.Write(c.IntoPacket()) - if err != nil { - return err - } - - // add it to pending - delete(T.preparedStatements, target) - T.pendingCloses.PushBack(Close{ - Which: 'S', - Target: target, - Hash: hash, - }) - return nil + return new(Server) } -func (T *Server) closePortal(ctx middleware.Context, target string) error { - /* - DON'T DO THIS!! Even though the unnamed portal doesn't need to be closed if the portal is ok, binding over an - unrunnable portal will keep the portal in an unrunnable state. - - if target == "" { - return nil - } - */ - - hash, ok := T.portals[target] - if !ok { - // already closed - return nil - } - - // send close packet - c := packets.Close{ - Which: 'P', - Target: target, - } - err := ctx.Write(c.IntoPacket()) - if err != nil { - return err - } - - // add it to pending - delete(T.portals, target) - T.pendingCloses.PushBack(Close{ - Which: 'P', - Target: target, - Source: hash.source, - Hash: hash.hash, - }) +func (T *Server) Read(_ middleware.Context, packet fed.Packet) error { + T.state.S2C(packet) return nil } -func (T *Server) bindPreparedStatement( - ctx middleware.Context, - target string, - preparedStatement PreparedStatement, -) error { - err := T.closePreparedStatement(ctx, target) - if err != nil { - return err - } - - err = ctx.Write(preparedStatement.packet) - if err != nil { - return err - } - - T.deletePreparedStatement(target) - T.preparedStatements[target] = preparedStatement.hash - T.pendingPreparedStatements.PushBack(target) +func (T *Server) Write(_ middleware.Context, packet fed.Packet) error { + T.state.C2S(packet) return nil } -func (T *Server) bindPortal( - ctx middleware.Context, - target string, - portal Portal, -) error { - // check if we already have it bound - if old, ok := T.portals[target]; ok { - if old.hash == portal.hash { - return nil - } - } - - err := T.closePortal(ctx, target) - if err != nil { - return err - } - - err = ctx.Write(portal.packet) - if err != nil { - return err - } - - T.deletePortal(target) - T.portals[target] = HashedPortal{ - source: portal.source, - hash: portal.hash, - } - T.pendingPortals.PushBack(target) - return nil -} - -func (T *Server) syncPreparedStatement(ctx middleware.Context, target string) error { - expected, some := T.peer.preparedStatements[target] - if !some { - return T.closePreparedStatement(ctx, target) - } - - // check if we already have it bound - if old, ok := T.preparedStatements[target]; ok { - if old == expected.hash { - return nil - } - } - - // clear all portals that use this prepared statement - for name, portal := range T.portals { - if portal.source == target { - err := T.closePortal(ctx, name) - if err != nil { - return err - } - } - } - - return T.bindPreparedStatement(ctx, target, expected) -} - -func (T *Server) syncPortal(ctx middleware.Context, target string) error { - expected, some := T.peer.portals[target] - if !some { - return T.closePortal(ctx, target) - } - - err := T.syncPreparedStatement(ctx, expected.source) - if err != nil { - return err - } - - // check if we already have it bound - if old, ok := T.portals[target]; ok { - if old.hash == expected.hash { - return nil - } - } - - return T.bindPortal(ctx, target, expected) -} - -func (T *Server) Write(ctx middleware.Context, packet fed.Packet) error { - switch packet.Type() { - case packets.TypeQuery: - // clobber unnamed portal and unnamed prepared statement - T.deletePreparedStatement("") - T.deletePortal("") - case packets.TypeParse, packets.TypeBind, packets.TypeClose: - // should've been caught by eqp.Client - panic("unreachable") - case packets.TypeDescribe: - // ensure target exists - var describe packets.Describe - if !describe.ReadFromPacket(packet) { - // should've been caught by eqp.Client - panic("unreachable") - } - switch describe.Which { - case 'S': - // sync prepared statement - err := T.syncPreparedStatement(ctx, describe.Target) - if err != nil { - return err - } - case 'P': - // sync portal - err := T.syncPortal(ctx, describe.Target) - if err != nil { - return err - } - default: - panic("unknown describe target") - } - case packets.TypeExecute: - var execute packets.Execute - if !execute.ReadFromPacket(packet) { - // should've been caught by eqp.Client - panic("unreachable") - } - // sync portal - err := T.syncPortal(ctx, execute.Target) - if err != nil { - return err - } - } - - return nil -} - -func (T *Server) Read(ctx middleware.Context, packet fed.Packet) error { - switch packet.Type() { - case packets.TypeParseComplete: - ctx.Cancel() - - T.pendingPreparedStatements.PopFront() - case packets.TypeBindComplete: - ctx.Cancel() - - T.pendingPortals.PopFront() - case packets.TypeCloseComplete: - ctx.Cancel() - - T.pendingCloses.PopFront() - case packets.TypeReadyForQuery: - var state packets.ReadyForQuery - if !state.ReadFromPacket(packet) { - return errors.New("bad packet format h") - } - if state == 'I' { - // clobber all portals - for name := range T.portals { - T.deletePortal(name) - } - } - // all pending failed - for pending, ok := T.pendingPreparedStatements.PopBack(); ok; pending, ok = T.pendingPreparedStatements.PopBack() { - T.deletePreparedStatement(pending) - } - for pending, ok := T.pendingPortals.PopBack(); ok; pending, ok = T.pendingPortals.PopBack() { - T.deletePortal(pending) - } - for pending, ok := T.pendingCloses.PopBack(); ok; pending, ok = T.pendingCloses.PopBack() { - switch pending.Which { - case 'S': // prepared statement - T.deletePreparedStatement(pending.Target) - T.preparedStatements[pending.Target] = pending.Hash - case 'P': // portal - T.deletePortal(pending.Target) - T.portals[pending.Target] = HashedPortal{ - hash: pending.Hash, - source: pending.Source, - } - default: - panic("unreachable") - } - } - } - return nil -} - -func (T *Server) Done() { - for name := range T.preparedStatements { - T.deletePreparedStatement(name) - } - for name := range T.portals { - T.deletePortal(name) - } - for _, ok := T.pendingCloses.PopBack(); ok; _, ok = T.pendingCloses.PopBack() { - } -} - var _ middleware.Middleware = (*Server)(nil) diff --git a/lib/middleware/middlewares/eqp2/sync.go b/lib/middleware/middlewares/eqp/state.go similarity index 65% rename from lib/middleware/middlewares/eqp2/sync.go rename to lib/middleware/middlewares/eqp/state.go index 8e0a4a08bc11f7a28163542386b7ee9ccfc6be53..0a6fc13f71f643ca21292ac1f494157e08be4cc3 100644 --- a/lib/middleware/middlewares/eqp2/sync.go +++ b/lib/middleware/middlewares/eqp/state.go @@ -1,20 +1,32 @@ -package eqp2 +package eqp import ( + "hash/maphash" + "pggat/lib/fed" packets "pggat/lib/fed/packets/v3.0" "pggat/lib/util/ring" ) +var seed = maphash.MakeSeed() + type PreparedStatement struct { Packet fed.Packet Target string + Hash uint64 } func MakePreparedStatement(packet fed.Packet) PreparedStatement { if packet.Type() != packets.TypeParse { panic("unreachable") } + + var res PreparedStatement + packet.ReadString(&res.Target) + res.Packet = packet + res.Hash = maphash.Bytes(seed, packet.Payload()) + + return res } type Portal struct { @@ -26,6 +38,12 @@ func MakePortal(packet fed.Packet) Portal { if packet.Type() != packets.TypeBind { panic("unreachable") } + + var res Portal + packet.ReadString(&res.Target) + res.Packet = packet + + return res } type CloseVariant int @@ -40,7 +58,7 @@ type Close struct { Target string } -type Sync struct { +type State struct { preparedStatements map[string]PreparedStatement portals map[string]Portal @@ -49,8 +67,51 @@ type Sync struct { pendingCloses ring.Ring[Close] } +// C2S is client to server packets +func (T *State) C2S(packet fed.Packet) { + switch packet.Type() { + case packets.TypeClose: + T.Close(packet) + case packets.TypeParse: + T.Parse(packet) + case packets.TypeBind: + T.Bind(packet) + case packets.TypeQuery: + T.Query() + } +} + +// S2C is server to client packets +func (T *State) S2C(packet fed.Packet) { + switch packet.Type() { + case packets.TypeCloseComplete: + T.CloseComplete() + case packets.TypeParseComplete: + T.ParseComplete() + case packets.TypeBindComplete: + T.BindComplete() + case packets.TypeReadyForQuery: + T.ReadyForQuery(packet) + } +} + // Close is a pending close. Execute on Close C->S -func (T *Sync) Close(variant CloseVariant, target string) { +func (T *State) Close(packet fed.Packet) { + var which byte + p := packet.ReadUint8(&which) + var target string + p.ReadString(&target) + + var variant CloseVariant + switch which { + case 'S': + variant = CloseVariantPreparedStatement + case 'P': + variant = CloseVariantPortal + default: + return + } + T.pendingCloses.PushBack(Close{ Variant: variant, Target: target, @@ -58,7 +119,7 @@ func (T *Sync) Close(variant CloseVariant, target string) { } // CloseComplete notifies that a close was successful. Execute on CloseComplete S->C -func (T *Sync) CloseComplete() { +func (T *State) CloseComplete() { c, ok := T.pendingCloses.PopFront() if !ok { return @@ -75,13 +136,13 @@ func (T *Sync) CloseComplete() { } // Parse is a pending prepared statement. Execute on Parse C->S -func (T *Sync) Parse(packet fed.Packet) { +func (T *State) Parse(packet fed.Packet) { preparedStatement := MakePreparedStatement(packet) T.pendingPreparedStatements.PushBack(preparedStatement) } // ParseComplete notifies that a parse was successful. Execute on ParseComplete S->C -func (T *Sync) ParseComplete() { +func (T *State) ParseComplete() { preparedStatement, ok := T.pendingPreparedStatements.PopFront() if !ok { return @@ -94,13 +155,13 @@ func (T *Sync) ParseComplete() { } // Bind is a pending portal. Execute on Bind C->S -func (T *Sync) Bind(packet fed.Packet) { +func (T *State) Bind(packet fed.Packet) { portal := MakePortal(packet) T.pendingPortals.PushBack(portal) } // BindComplete notifies that a bind was successful. Execute on BindComplete S->C -func (T *Sync) BindComplete() { +func (T *State) BindComplete() { portal, ok := T.pendingPortals.PopFront() if !ok { return @@ -113,13 +174,16 @@ func (T *Sync) BindComplete() { } // Query clobbers the unnamed portal and unnamed prepared statement. Execute on Query C->S -func (T *Sync) Query() { +func (T *State) Query() { delete(T.portals, "") delete(T.preparedStatements, "") } // ReadyForQuery clobbers portals if state == 'I' and pending. Execute on ReadyForQuery S->C -func (T *Sync) ReadyForQuery(state byte) { +func (T *State) ReadyForQuery(packet fed.Packet) { + var state byte + packet.ReadUint8(&state) + if state == 'I' { // clobber all portals for name := range T.portals { diff --git a/lib/middleware/middlewares/eqp/sync.go b/lib/middleware/middlewares/eqp/sync.go new file mode 100644 index 0000000000000000000000000000000000000000..1e8c4b85c28667e374d6a2736f621fb56f964c6a --- /dev/null +++ b/lib/middleware/middlewares/eqp/sync.go @@ -0,0 +1,61 @@ +package eqp + +import ( + "pggat/lib/bouncer/backends/v0" + "pggat/lib/fed" + packets "pggat/lib/fed/packets/v3.0" +) + +func Sync(c *Client, server fed.ReadWriter, s *Server) error { + // close all portals on server + // we close all because there won't be any for the normal case anyway, and it's hard to tell + // if a portal is accurate because the underlying prepared statement could have changed. + for name := range s.state.portals { + p := packets.Close{ + Which: 'P', + Target: name, + } + if err := server.WritePacket(p.IntoPacket()); err != nil { + return err + } + } + + // close all prepared statements that don't match client + for name, preparedStatement := range s.state.preparedStatements { + clientPreparedStatement, ok := c.state.preparedStatements[name] + if ok && (name == "" || preparedStatement.Hash == clientPreparedStatement.Hash) { + // match or unnamed prepared statement that will be bound over + continue + } + + p := packets.Close{ + Which: 'S', + Target: name, + } + if err := server.WritePacket(p.IntoPacket()); err != nil { + return err + } + } + + // parse all prepared statements that aren't on server + for name, preparedStatement := range c.state.preparedStatements { + serverPreparedStatement, ok := s.state.preparedStatements[name] + if ok && preparedStatement.Hash == serverPreparedStatement.Hash { + // matched, don't need to set + continue + } + + if err := server.WritePacket(preparedStatement.Packet); err != nil { + return err + } + } + + // bind all portals + for _, portal := range c.state.portals { + if err := server.WritePacket(portal.Packet); err != nil { + return err + } + } + + return backends.Sync(new(backends.Context), server) +} diff --git a/lib/middleware/middlewares/eqp2/client.go b/lib/middleware/middlewares/eqp2/client.go deleted file mode 100644 index b15458cd9dcaf92ecd119bddff05cb42a2b7fdaf..0000000000000000000000000000000000000000 --- a/lib/middleware/middlewares/eqp2/client.go +++ /dev/null @@ -1,21 +0,0 @@ -package eqp2 - -import ( - "pggat/lib/fed" - "pggat/lib/middleware" -) - -type Client struct { -} - -func (T *Client) Read(ctx middleware.Context, packet fed.Packet) error { - // TODO implement me - panic("implement me") -} - -func (T *Client) Write(ctx middleware.Context, packet fed.Packet) error { - // TODO implement me - panic("implement me") -} - -var _ middleware.Middleware = (*Client)(nil)