diff --git a/lib/fed/middlewares/eqp/client.go b/lib/fed/middlewares/eqp/client.go index a2dcc633d9f1629742159faf9bc23cf4701d09b3..652d707369c1082da6f6fa4500acccbfd24e89e3 100644 --- a/lib/fed/middlewares/eqp/client.go +++ b/lib/fed/middlewares/eqp/client.go @@ -28,4 +28,8 @@ func (T *Client) PostWrite() (fed.Packet, error) { return nil, nil } +func (T *Client) Set(other *Client) { + T.state.Set(&other.state) +} + var _ fed.Middleware = (*Client)(nil) diff --git a/lib/fed/middlewares/eqp/state.go b/lib/fed/middlewares/eqp/state.go index f029e01d0bb26ebc510df7012404b891accad86d..c80723fc9729cd1b3c0f11df984d52c9689875b8 100644 --- a/lib/fed/middlewares/eqp/state.go +++ b/lib/fed/middlewares/eqp/state.go @@ -204,3 +204,35 @@ func (T *State) ReadyForQuery(packet fed.Packet) (fed.Packet, error) { return &p, nil } + +func (T *State) Set(other *State) { + maps.Clear(T.preparedStatements) + maps.Clear(T.portals) + + T.pendingPreparedStatements.Clear() + T.pendingPortals.Clear() + T.pendingCloses.Clear() + + if T.preparedStatements == nil { + T.preparedStatements = make(map[string]*packets.Parse) + } + if T.portals == nil { + T.portals = make(map[string]*packets.Bind) + } + + for k, v := range other.preparedStatements { + T.preparedStatements[k] = v + } + for k, v := range other.portals { + T.portals[k] = v + } + for i := 0; i < other.pendingPreparedStatements.Length(); i++ { + T.pendingPreparedStatements.PushBack(other.pendingPreparedStatements.Get(i)) + } + for i := 0; i < other.pendingPortals.Length(); i++ { + T.pendingPortals.PushBack(other.pendingPortals.Get(i)) + } + for i := 0; i < other.pendingCloses.Length(); i++ { + T.pendingCloses.PushBack(other.pendingCloses.Get(i)) + } +} diff --git a/lib/fed/middlewares/eqp/sync.go b/lib/fed/middlewares/eqp/sync.go index 18ec8fcc1479b814c88fdc16e3dd161efd605065..eaafdb0e7811ab2ff6c39e28129817d3b38abba5 100644 --- a/lib/fed/middlewares/eqp/sync.go +++ b/lib/fed/middlewares/eqp/sync.go @@ -20,11 +20,7 @@ func preparedStatementsEqual(a, b *packets.Parse) bool { return true } -func Sync(client, server *fed.Conn) error { - c, ok := fed.LookupMiddleware[*Client](client) - if !ok { - panic("middleware not found") - } +func SyncMiddleware(c *Client, server *fed.Conn) error { s, ok := fed.LookupMiddleware[*Server](server) if !ok { panic("middleware not found") @@ -35,10 +31,6 @@ func Sync(client, server *fed.Conn) error { // close all portals on server // we close all because there won't be any for the normal case anyway, and it's hard to tell // if a portal is accurate because the underlying prepared statement could have changed. - if len(s.state.portals) > 0 { - needsBackendSync = true - } - for name := range s.state.portals { p := packets.Close{ Which: 'P', @@ -47,6 +39,8 @@ func Sync(client, server *fed.Conn) error { if err := server.WritePacket(&p); err != nil { return err } + + needsBackendSync = true } // close all prepared statements that don't match client @@ -89,14 +83,12 @@ func Sync(client, server *fed.Conn) error { } // bind all portals - if len(c.state.portals) > 0 { - needsBackendSync = true - } - for _, portal := range c.state.portals { if err := server.WritePacket(portal); err != nil { return err } + + needsBackendSync = true } if needsBackendSync { @@ -107,3 +99,12 @@ func Sync(client, server *fed.Conn) error { return nil } + +func Sync(client, server *fed.Conn) error { + c, ok := fed.LookupMiddleware[*Client](client) + if !ok { + panic("middleware not found") + } + + return SyncMiddleware(c, server) +} diff --git a/lib/fed/middlewares/ps/client.go b/lib/fed/middlewares/ps/client.go index 36d1d4ef7bf81fa55a30ac8407ab0205097ac18a..94b5014ff1d39702527e388d09352c6a9fe53854 100644 --- a/lib/fed/middlewares/ps/client.go +++ b/lib/fed/middlewares/ps/client.go @@ -3,6 +3,7 @@ package ps import ( "gfx.cafe/gfx/pggat/lib/fed" packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" + "gfx.cafe/gfx/pggat/lib/util/maps" "gfx.cafe/gfx/pggat/lib/util/strutil" ) @@ -52,4 +53,16 @@ func (T *Client) PostWrite() (fed.Packet, error) { return nil, nil } +func (T *Client) Set(other *Client) { + T.synced = other.synced + + maps.Clear(T.parameters) + if T.parameters == nil { + T.parameters = make(map[strutil.CIString]string) + } + for k, v := range other.parameters { + T.parameters[k] = v + } +} + var _ fed.Middleware = (*Client)(nil) diff --git a/lib/fed/middlewares/ps/sync.go b/lib/fed/middlewares/ps/sync.go index dea000782b07a32a49aa1b96d669e87c20ec4e6d..0a101af41091346777cdb96a5fc7bbc0c3ab177d 100644 --- a/lib/fed/middlewares/ps/sync.go +++ b/lib/fed/middlewares/ps/sync.go @@ -8,29 +8,26 @@ import ( "gfx.cafe/gfx/pggat/lib/util/strutil" ) -func sync(tracking []strutil.CIString, client *fed.Conn, c *Client, server *fed.Conn, s *Server, name strutil.CIString) error { +func sync(tracking []strutil.CIString, client *fed.Conn, c *Client, server *fed.Conn, s *Server, name strutil.CIString) (clientErr, serverErr error) { value, hasValue := c.parameters[name] expected, hasExpected := s.parameters[name] if value == expected { - if !c.synced { + if client != nil && !c.synced { ps := packets.ParameterStatus{ Key: name.String(), Value: expected, } - if err := client.WritePacket(&ps); err != nil { - return err - } + clientErr = client.WritePacket(&ps) } - return nil + return } var doSet bool if hasValue && slices.Contains(tracking, name) { - var err error - if err, _ = backends.SetParameter(server, nil, name, value); err != nil { - return err + if serverErr, _ = backends.SetParameter(server, nil, name, value); serverErr != nil { + return } if s.parameters == nil { s.parameters = make(map[strutil.CIString]string) @@ -42,12 +39,36 @@ func sync(tracking []strutil.CIString, client *fed.Conn, c *Client, server *fed. doSet = true } - if doSet { + if client != nil && doSet { ps := packets.ParameterStatus{ Key: name.String(), Value: expected, } - if err := client.WritePacket(&ps); err != nil { + if clientErr = client.WritePacket(&ps); clientErr != nil { + return + } + } + + return +} + +func SyncMiddleware(tracking []strutil.CIString, c *Client, server *fed.Conn) error { + s, ok := fed.LookupMiddleware[*Server](server) + if !ok { + panic("middleware not found") + } + + for name := range c.parameters { + if _, err := sync(tracking, nil, c, server, s, name); err != nil { + return err + } + } + + for name := range s.parameters { + if _, ok = c.parameters[name]; ok { + continue + } + if _, err := sync(tracking, nil, c, server, s, name); err != nil { return err } } @@ -66,7 +87,7 @@ func Sync(tracking []strutil.CIString, client, server *fed.Conn) (clientErr, ser } for name := range c.parameters { - if serverErr = sync(tracking, client, c, server, s, name); serverErr != nil { + if clientErr, serverErr = sync(tracking, client, c, server, s, name); clientErr != nil || serverErr != nil { return } } @@ -75,7 +96,7 @@ func Sync(tracking []strutil.CIString, client, server *fed.Conn) (clientErr, ser if _, ok = c.parameters[name]; ok { continue } - if serverErr = sync(tracking, client, c, server, s, name); serverErr != nil { + if clientErr, serverErr = sync(tracking, client, c, server, s, name); clientErr != nil || serverErr != nil { return } } diff --git a/lib/gat/handlers/pool/pools/basic/pool.go b/lib/gat/handlers/pool/pools/basic/pool.go index 147357abfe7eddfa699dde2fb0e2a96ef2d3a2f7..a698727376eafa587885aa5979f22de429445d6f 100644 --- a/lib/gat/handlers/pool/pools/basic/pool.go +++ b/lib/gat/handlers/pool/pools/basic/pool.go @@ -191,6 +191,8 @@ func (T *Pool) Serve(conn *fed.Conn) error { }() if !client.Conn.Ready { + client.SetState(metrics.ConnStateAwaitingServer, nil) + server = T.servers.Acquire(client.ID) if server == nil { return pool.ErrClosed @@ -227,6 +229,8 @@ func (T *Pool) Serve(conn *fed.Conn) error { } if server == nil { + client.SetState(metrics.ConnStateAwaitingServer, nil) + server = T.servers.Acquire(client.ID) if server == nil { return pool.ErrClosed diff --git a/lib/gat/handlers/pool/pools/hybrid/pool.go b/lib/gat/handlers/pool/pools/hybrid/pool.go index 2f598856e677f1692acf4cd30b3ee5dcf8a49b61..306bade7d6d360f37c2c047d3cab29fc4974fa9c 100644 --- a/lib/gat/handlers/pool/pools/hybrid/pool.go +++ b/lib/gat/handlers/pool/pools/hybrid/pool.go @@ -57,7 +57,7 @@ func (T *Pool) RemoveRecipe(name string) { } func (T *Pool) Pair(client *Client, server *spool.Server) (err, serverErr error) { - client.SetState(metrics.ConnStatePairing, server, false) + client.SetState(metrics.ConnStatePairing, server, true) server.SetState(metrics.ConnStatePairing, client.ID) err, serverErr = ps.Sync(T.config.TrackedParameters, client.Conn, server.Conn) @@ -72,11 +72,27 @@ func (T *Pool) Pair(client *Client, server *spool.Server) (err, serverErr error) return } - client.SetState(metrics.ConnStateActive, server, false) + client.SetState(metrics.ConnStateActive, server, true) server.SetState(metrics.ConnStateActive, client.ID) return } +func (T *Pool) PairPrimary(client *Client, psc *ps.Client, eqpc *eqp.Client, server *spool.Server) error { + server.SetState(metrics.ConnStatePairing, client.ID) + + if err := ps.SyncMiddleware(T.config.TrackedParameters, psc, server.Conn); err != nil { + return err + } + + if err := eqp.SyncMiddleware(eqpc, server.Conn); err != nil { + return err + } + + client.SetState(metrics.ConnStateActive, server, false) + server.SetState(metrics.ConnStateActive, client.ID) + return nil +} + func (T *Pool) addClient(client *Client) { T.mu.Lock() defer T.mu.Unlock() @@ -96,11 +112,17 @@ func (T *Pool) removeClient(client *Client) { func (T *Pool) Serve(conn *fed.Conn) error { m := NewMiddleware() + + eqpa := eqp.NewClient() + eqpi := eqp.NewClient() + psa := ps.NewClient(conn.InitialParameters) + psi := ps.NewClient(nil) + conn.Middleware = append( conn.Middleware, unterminate.Unterminate, - ps.NewClient(conn.InitialParameters), - eqp.NewClient(), + psa, + eqpa, m, ) @@ -137,6 +159,8 @@ func (T *Pool) Serve(conn *fed.Conn) error { }() if !conn.Ready { + client.SetState(metrics.ConnStateAwaitingServer, nil, false) + replica = T.replica.Acquire(client.ID) if replica == nil { return pool.ErrClosed @@ -175,6 +199,8 @@ func (T *Pool) Serve(conn *fed.Conn) error { return err } + client.SetState(metrics.ConnStateAwaitingServer, nil, false) + replica = T.replica.Acquire(client.ID) if replica == nil { return pool.ErrClosed @@ -182,6 +208,9 @@ func (T *Pool) Serve(conn *fed.Conn) error { err, serverErr = T.Pair(client, replica) + psi.Set(psa) + eqpi.Set(eqpa) + if err == nil && serverErr == nil { err, serverErr = bouncers.Bounce(conn, replica.Conn, packet) } @@ -204,15 +233,19 @@ func (T *Pool) Serve(conn *fed.Conn) error { return err } + client.SetState(metrics.ConnStateAwaitingServer, nil, false) + // acquire primary primary = T.primary.Acquire(client.ID) if primary == nil { return pool.ErrClosed } - // TODO(garet) get primary in the same state replica was when the tx started + serverErr = T.PairPrimary(client, psi, eqpi, primary) - err, serverErr = bouncers.Bounce(conn, primary.Conn, packet) + if serverErr == nil { + err, serverErr = bouncers.Bounce(conn, primary.Conn, packet) + } if serverErr != nil { return serverErr } else {