diff --git a/cmd/cgat/main.go b/cmd/cgat/main.go index d22bb2882d183c8236d2ec57093f5e7d1b120128..21ea9abb7c4f723fde7b046bde0d1f84a8b1194e 100644 --- a/cmd/cgat/main.go +++ b/cmd/cgat/main.go @@ -7,8 +7,7 @@ import ( "time" "pggat2/lib/gat" - "pggat2/lib/gat/pools/transaction" - "pggat2/lib/rob" + "pggat2/lib/gat/pools/session" ) func main() { @@ -25,11 +24,24 @@ func main() { pooler.AddUser("postgres", postgres) // create pool - rawPool := transaction.NewPool() + { + rawPool := session.NewPool(false) + pool := gat.NewPool(rawPool, 15*time.Second) + postgres.AddPool("postgres", pool) + pool.AddRecipe("localhost", gat.TCPRecipe{ + Database: "postgres", + Address: "localhost:5432", + User: "postgres", + Password: "password", + MinConnections: 0, + MaxConnections: 5, + }) + } + rawPool := session.NewPool(false) pool := gat.NewPool(rawPool, 15*time.Second) - postgres.AddPool("uniswap", pool) + postgres.AddPool("regression", pool) pool.AddRecipe("localhost", gat.TCPRecipe{ - Database: "uniswap", + Database: "regression", Address: "localhost:5432", User: "postgres", Password: "password", @@ -37,26 +49,26 @@ func main() { MaxConnections: 5, }) - go func() { - var metrics rob.Metrics - - for { - time.Sleep(1 * time.Second) - rawPool.ReadSchedulerMetrics(&metrics) - log.Println(metrics.String()) - } - }() /* go func() { - var metrics session.Metrics + var metrics rob.Metrics for { time.Sleep(1 * time.Second) - rawPool.ReadMetrics(&metrics) + rawPool.ReadSchedulerMetrics(&metrics) log.Println(metrics.String()) } }() */ + go func() { + var metrics session.Metrics + + for { + time.Sleep(1 * time.Second) + rawPool.ReadMetrics(&metrics) + log.Println(metrics.String()) + } + }() log.Println("Listening on :6432") diff --git a/lib/bouncer/backends/v0/accept.go b/lib/bouncer/backends/v0/accept.go index 2d2c425b4f5a1a1ba2ebdfcd596dfb9dd848d762..af02dc0475ac4528802992629e688af6bd3137bd 100644 --- a/lib/bouncer/backends/v0/accept.go +++ b/lib/bouncer/backends/v0/accept.go @@ -182,7 +182,7 @@ func startup0(server zap.ReadWriter, username, password string) (done bool, err } } -func startup1(server zap.ReadWriter) (done bool, err error) { +func startup1(server zap.ReadWriter, parameterStatus map[string]string) (done bool, err error) { packet := zap.NewPacket() defer packet.Done() err = server.Read(packet) @@ -202,6 +202,12 @@ func startup1(server zap.ReadWriter) (done bool, err error) { // TODO(garet) put cancellation key somewhere return false, nil case packets.ParameterStatus: + key, value, ok := packets.ReadParameterStatus(&read) + if !ok { + err = ErrBadFormat + return + } + parameterStatus[key] = value return false, nil case packets.ReadyForQuery: return true, nil @@ -222,7 +228,9 @@ func startup1(server zap.ReadWriter) (done bool, err error) { } } -func Accept(server zap.ReadWriter, username, password, database string) error { +func Accept(server zap.ReadWriter, username, password, database string) (map[string]string, error) { + parameterStatus := make(map[string]string) + if database == "" { database = username } @@ -239,14 +247,14 @@ func Accept(server zap.ReadWriter, username, password, database string) error { err := server.WriteUntyped(packet) if err != nil { - return err + return nil, err } for { var done bool done, err = startup0(server, username, password) if err != nil { - return err + return nil, err } if done { break @@ -255,9 +263,9 @@ func Accept(server zap.ReadWriter, username, password, database string) error { for { var done bool - done, err = startup1(server) + done, err = startup1(server, parameterStatus) if err != nil { - return err + return nil, err } if done { break @@ -265,5 +273,5 @@ func Accept(server zap.ReadWriter, username, password, database string) error { } // startup complete, connection is ready for queries - return nil + return parameterStatus, nil } diff --git a/lib/bouncer/frontends/v0/accept.go b/lib/bouncer/frontends/v0/accept.go index 08511fcb8630d4cf8f33351aeb88d728f50687b2..a97a39ed138f5057b1eb2792ec79643e6ff339fe 100644 --- a/lib/bouncer/frontends/v0/accept.go +++ b/lib/bouncer/frontends/v0/accept.go @@ -130,7 +130,7 @@ func startup0(client zap.ReadWriter, startupParameters map[string]string) (user, // we don't support protocol extensions at the moment unsupportedOptions = append(unsupportedOptions, key) } else { - // TODO(garet) do something with this parameter + startupParameters[key] = value } } } @@ -311,6 +311,10 @@ func accept(client zap.ReadWriter, getPassword func(user, database string) (stri packets.WriteBackendKeyData(packet, cancellationKey) pkts.Append(packet) + updateParameter(pkts, "client_encoding", "UTF8") + updateParameter(pkts, "server_encoding", "UTF8") + updateParameter(pkts, "server_version", "14.5") + // send ready for query packet = zap.NewPacket() packets.WriteReadyForQuery(packet, 'I') diff --git a/lib/gat/pool.go b/lib/gat/pool.go index e38086e6d85b830abc7ac3db954513a0bab29962..ff3b1217fdc40d8ae003557685512deab3259541 100644 --- a/lib/gat/pool.go +++ b/lib/gat/pool.go @@ -20,7 +20,7 @@ type Context struct { type RawPool interface { Serve(ctx *Context, client zap.ReadWriter, startupParameters map[string]string) - AddServer(server zap.ReadWriter) uuid.UUID + AddServer(server zap.ReadWriter, startupParameters map[string]string) uuid.UUID GetServer(id uuid.UUID) zap.ReadWriter RemoveServer(id uuid.UUID) zap.ReadWriter @@ -36,18 +36,18 @@ type PoolRecipe struct { r Recipe } -func (T *PoolRecipe) connect() (zap.ReadWriter, error) { +func (T *PoolRecipe) connect() (zap.ReadWriter, map[string]string, error) { rw, err := T.r.Connect() if err != nil { - return nil, err + return nil, nil, err } - err = backends.Accept(rw, T.r.GetUser(), T.r.GetPassword(), T.r.GetDatabase()) + parameterStatus, err := backends.Accept(rw, T.r.GetUser(), T.r.GetPassword(), T.r.GetDatabase()) if err != nil { - return nil, err + return nil, nil, err } - return rw, nil + return rw, parameterStatus, nil } type Pool struct { @@ -120,13 +120,13 @@ func (T *Pool) tryAddServers(recipe *PoolRecipe, amount int) (remaining int) { max := maths.Min(recipe.r.GetMaxConnections()-j, amount) for i := 0; i < max; i++ { - conn, err := recipe.connect() + conn, ps, err := recipe.connect() if err != nil { log.Printf("error connecting to server: %v", err) continue } - id := T.raw.AddServer(conn) + id := T.raw.AddServer(conn, ps) recipe.servers = append(recipe.servers, id) remaining-- } @@ -141,13 +141,13 @@ func (T *Pool) addRecipe(recipe *PoolRecipe) { recipe.removed = false min := recipe.r.GetMinConnections() - len(recipe.servers) for i := 0; i < min; i++ { - conn, err := recipe.connect() + conn, ps, err := recipe.connect() if err != nil { log.Printf("error connecting to server: %v", err) continue } - id := T.raw.AddServer(conn) + id := T.raw.AddServer(conn, ps) recipe.servers = append(recipe.servers, id) } } diff --git a/lib/gat/pools/session/conn.go b/lib/gat/pools/session/conn.go new file mode 100644 index 0000000000000000000000000000000000000000..fafa27ca446e1543d7efbfaa5399ef9c11a50926 --- /dev/null +++ b/lib/gat/pools/session/conn.go @@ -0,0 +1,13 @@ +package session + +import ( + "github.com/google/uuid" + + "pggat2/lib/zap" +) + +type Conn struct { + id uuid.UUID + rw zap.ReadWriter + initialParameters map[string]string +} diff --git a/lib/gat/pools/session/pool.go b/lib/gat/pools/session/pool.go index 3400f28bf091f4a367619eaf3572fb0df6eb1dad..2ace4859ee3a799ba361cd7f0b7ff9af7101a110 100644 --- a/lib/gat/pools/session/pool.go +++ b/lib/gat/pools/session/pool.go @@ -1,7 +1,6 @@ package session import ( - "log" "sync" "time" @@ -14,6 +13,7 @@ import ( "pggat2/lib/util/maps" "pggat2/lib/util/ring" "pggat2/lib/zap" + packets "pggat2/lib/zap/packets/v3.0" ) type queueItem struct { @@ -26,7 +26,7 @@ type Pool struct { // use slice lifo for better perf queue ring.Ring[queueItem] - conns map[uuid.UUID]zap.ReadWriter + conns map[uuid.UUID]Conn ready sync.Cond qmu sync.Mutex } @@ -42,7 +42,7 @@ func NewPool(roundRobin bool) *Pool { return p } -func (T *Pool) acquire(ctx *gat.Context) (uuid.UUID, zap.ReadWriter) { +func (T *Pool) acquire(ctx *gat.Context) Conn { T.qmu.Lock() defer T.qmu.Unlock() for T.queue.Length() == 0 { @@ -56,7 +56,7 @@ func (T *Pool) acquire(ctx *gat.Context) (uuid.UUID, zap.ReadWriter) { } else { entry, _ = T.queue.PopBack() } - return entry.id, T.conns[entry.id] + return T.conns[entry.id] } func (T *Pool) _release(id uuid.UUID) { @@ -68,35 +68,59 @@ func (T *Pool) _release(id uuid.UUID) { T.ready.Signal() } -func (T *Pool) release(id uuid.UUID, server zap.ReadWriter) { +func (T *Pool) release(conn Conn) { // reset session state - err := backends.Query(server, "DISCARD ALL") + err := backends.Query(conn.rw, "DISCARD ALL") if err != nil { - _ = server.Close() + _ = conn.rw.Close() return } T.qmu.Lock() defer T.qmu.Unlock() - T._release(id) + T._release(conn.id) } func (T *Pool) Serve(ctx *gat.Context, client zap.ReadWriter, startupParameters map[string]string) { - id, server := T.acquire(ctx) + conn := T.acquire(ctx) + + pkts := zap.NewPackets() + for key, value := range conn.initialParameters { + if _, ok := startupParameters[key]; ok { + continue + } + packet := zap.NewPacket() + packets.WriteParameterStatus(packet, key, value) + pkts.Append(packet) + } + err := client.WriteV(pkts) + if err != nil { + pkts.Done() + _ = client.Close() + T.release(conn) + return + } + pkts.Done() - // TODO(garet) set startup parameters - log.Println(startupParameters) + for key, value := range startupParameters { + err = backends.Query(conn.rw, "SET "+key+" = '"+value+"'") + if err != nil { + _ = client.Close() + _ = conn.rw.Close() + return + } + } for { - clientErr, serverErr := bouncers.Bounce(client, server) + clientErr, serverErr := bouncers.Bounce(client, conn.rw) if clientErr != nil || serverErr != nil { _ = client.Close() if serverErr == nil { - T.release(id, server) + T.release(conn) } else { - _ = server.Close() + _ = conn.rw.Close() T.qmu.Lock() - delete(T.conns, id) + delete(T.conns, conn.id) T.qmu.Unlock() } break @@ -104,15 +128,19 @@ func (T *Pool) Serve(ctx *gat.Context, client zap.ReadWriter, startupParameters } } -func (T *Pool) AddServer(server zap.ReadWriter) uuid.UUID { +func (T *Pool) AddServer(server zap.ReadWriter, parameters map[string]string) uuid.UUID { T.qmu.Lock() defer T.qmu.Unlock() id := uuid.New() if T.conns == nil { - T.conns = make(map[uuid.UUID]zap.ReadWriter) + T.conns = make(map[uuid.UUID]Conn) + } + T.conns[id] = Conn{ + id: id, + rw: server, + initialParameters: parameters, } - T.conns[id] = server T._release(id) return id } @@ -121,7 +149,7 @@ func (T *Pool) GetServer(id uuid.UUID) zap.ReadWriter { T.qmu.Lock() defer T.qmu.Unlock() - return T.conns[id] + return T.conns[id].rw } func (T *Pool) RemoveServer(id uuid.UUID) zap.ReadWriter { @@ -133,7 +161,7 @@ func (T *Pool) RemoveServer(id uuid.UUID) zap.ReadWriter { return nil } delete(T.conns, id) - return conn + return conn.rw } func (T *Pool) ScaleDown(amount int) (remaining int) { @@ -154,7 +182,7 @@ func (T *Pool) ScaleDown(amount int) (remaining int) { } delete(T.conns, v.id) - _ = conn.Close() + _ = conn.rw.Close() remaining-- } diff --git a/lib/gat/pools/transaction/conn.go b/lib/gat/pools/transaction/conn.go index 8755024dacfc318e69c2ea38c84b4fd7e880ad35..65cebeedf9a17ac3d81463de6177143019acf001 100644 --- a/lib/gat/pools/transaction/conn.go +++ b/lib/gat/pools/transaction/conn.go @@ -16,7 +16,14 @@ type Conn struct { func (T *Conn) Do(ctx *rob.Context, work any) { job := work.(Work) - job.ps.SetServer(T.ps) + + // sync parameters + err := T.ps.Sync(job.rw, job.ps) + if err != nil { + _ = job.rw.Close() + return + } + T.eqp.SetClient(job.eqp) clientErr, serverErr := bouncers.Bounce(job.rw, T.rw) if clientErr != nil || serverErr != nil { diff --git a/lib/gat/pools/transaction/pool.go b/lib/gat/pools/transaction/pool.go index 7da2b33379db1ecd007837b16362c2e17337e251..7ddaaceedd2e40a18657418d8494859bec9af673 100644 --- a/lib/gat/pools/transaction/pool.go +++ b/lib/gat/pools/transaction/pool.go @@ -27,9 +27,9 @@ func NewPool() *Pool { return pool } -func (T *Pool) AddServer(server zap.ReadWriter) uuid.UUID { +func (T *Pool) AddServer(server zap.ReadWriter, parameters map[string]string) uuid.UUID { eqps := eqp.NewServer() - pss := ps.NewServer() + pss := ps.NewServer(parameters) mw := interceptor.NewInterceptor( server, eqps, diff --git a/lib/middleware/middlewares/ps/client.go b/lib/middleware/middlewares/ps/client.go index dd4e671655381df956e0a8618c483e2752d1251e..a78e7ef631889b288d500bcdbfeb106d9c9a5e49 100644 --- a/lib/middleware/middlewares/ps/client.go +++ b/lib/middleware/middlewares/ps/client.go @@ -11,9 +11,6 @@ import ( type Client struct { parameters map[string]string - peer *Server - dirty bool - middleware.Nil } @@ -23,60 +20,6 @@ func NewClient() *Client { } } -func (T *Client) SetServer(peer *Server) { - T.dirty = true - T.peer = peer -} - -func (T *Client) updateParameter0(ctx middleware.Context, name, value string) error { - packet := zap.NewPacket() - defer packet.Done() - packets.WriteParameterStatus(packet, name, value) - err := ctx.Write(packet) - if err != nil { - return err - } - - T.parameters[name] = value - - return nil -} - -func (T *Client) updateParameter(ctx middleware.Context, name, value string) error { - if T.parameters[name] == value { - return nil - } - - return T.updateParameter0(ctx, name, value) -} - -func (T *Client) sync(ctx middleware.Context) error { - if T.peer == nil || !T.dirty { - return nil - } - T.dirty = false - - for name, value := range T.parameters { - expected := T.peer.parameters[name] - if value == expected { - continue - } - err := T.updateParameter0(ctx, name, expected) - if err != nil { - return err - } - } - - for name, expected := range T.peer.parameters { - err := T.updateParameter(ctx, name, expected) - if err != nil { - return err - } - } - - return nil -} - func (T *Client) Send(ctx middleware.Context, packet *zap.Packet) error { read := packet.Read() switch read.ReadType() { @@ -92,7 +35,7 @@ func (T *Client) Send(ctx middleware.Context, packet *zap.Packet) error { } T.parameters[key] = value } - return T.sync(ctx) + return nil } var _ middleware.Middleware = (*Client)(nil) diff --git a/lib/middleware/middlewares/ps/server.go b/lib/middleware/middlewares/ps/server.go index 802ea5b21212012d6a1c6384586ed5c7ece1b24b..fd5cd0d44f0bb6e1a9ae43a44986daeb2edb0ebb 100644 --- a/lib/middleware/middlewares/ps/server.go +++ b/lib/middleware/middlewares/ps/server.go @@ -14,12 +14,44 @@ type Server struct { middleware.Nil } -func NewServer() *Server { +func NewServer(parameters map[string]string) *Server { return &Server{ - parameters: make(map[string]string), + parameters: parameters, } } +func (T *Server) syncParameter(pkts *zap.Packets, ps *Client, name, expected string) { + packet := zap.NewPacket() + packets.WriteParameterStatus(packet, name, expected) + pkts.Append(packet) + + ps.parameters[name] = expected +} + +func (T *Server) Sync(client zap.ReadWriter, ps *Client) error { + pkts := zap.NewPackets() + defer pkts.Done() + + for name, value := range ps.parameters { + expected := T.parameters[name] + if value == expected { + continue + } + + T.syncParameter(pkts, ps, name, expected) + } + + for name, expected := range T.parameters { + if T.parameters[name] == expected { + continue + } + + T.syncParameter(pkts, ps, name, expected) + } + + return client.WriteV(pkts) +} + func (T *Server) Read(_ middleware.Context, in *zap.Packet) error { read := in.Read() switch read.ReadType() { diff --git a/lib/rob/schedulers/v1/pool/pool.go b/lib/rob/schedulers/v1/pool/pool.go index 997e8634443bb12966301529baa63dffffc6e3b1..9052832126514ce92fadac406ce5c7b8993d9b97 100644 --- a/lib/rob/schedulers/v1/pool/pool.go +++ b/lib/rob/schedulers/v1/pool/pool.go @@ -113,7 +113,6 @@ func (T *Pool) AddWorker(constraints rob.Constraints, worker rob.Worker) uuid.UU s := sink.NewSink(id, constraints, worker) T.mu.Lock() - defer T.mu.Unlock() // if mu is locked, we don't need to lock bmu, because we are the only accessor T.sinks[id] = s i := 0 @@ -124,6 +123,9 @@ func (T *Pool) AddWorker(constraints rob.Constraints, worker rob.Worker) uuid.UU } } T.backlog = T.backlog[:i] + T.mu.Unlock() + + T.stealFor(id) return id }