diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..2321a827301decc1553a6bda94efcd6da917882f --- /dev/null +++ b/Dockerfile @@ -0,0 +1,15 @@ +# syntax=docker/dockerfile:1 +FROM golang:1.21-alpine as GOBUILDER +RUN apk add build-base git +WORKDIR /src +COPY . . + +RUN go mod tidy +RUN go build -o pggat ./cmd/cgat + +FROM alpine:latest +WORKDIR /bin +COPY --from=GOBUILDER /src/pggat pgbouncer + +# use these so it works with zalando/postgres-operator +ENTRYPOINT ["/bin/pgbouncer", "/etc/pgbouncer/pgbouncer.ini"] \ No newline at end of file diff --git a/lib/bouncer/backends/v0/cancel.go b/lib/bouncer/backends/v0/cancel.go new file mode 100644 index 0000000000000000000000000000000000000000..e7e111abd5c12f95cf8ef9048ebcad6bb781ea4b --- /dev/null +++ b/lib/bouncer/backends/v0/cancel.go @@ -0,0 +1,12 @@ +package backends + +import "pggat2/lib/zap" + +func Cancel(server zap.ReadWriter, key [8]byte) error { + packet := zap.NewUntypedPacket() + defer packet.Done() + packet.WriteUint16(1234) + packet.WriteUint16(5678) + packet.WriteBytes(key[:]) + return server.WriteUntyped(packet) +} diff --git a/lib/gat/pool.go b/lib/gat/pool.go index 1999342ffae71f1817cb0ee2d08ba65398174693..df5b27b158bba673466fc6671d00fe559a5724ab 100644 --- a/lib/gat/pool.go +++ b/lib/gat/pool.go @@ -8,10 +8,11 @@ import ( "github.com/google/uuid" "pggat2/lib/bouncer" + "pggat2/lib/bouncer/backends/v0" "pggat2/lib/util/maps" "pggat2/lib/util/maths" + "pggat2/lib/util/slices" "pggat2/lib/util/strutil" - "pggat2/lib/zap" ) type Context struct { @@ -22,8 +23,11 @@ type RawPool interface { Serve(ctx *Context, client bouncer.Conn) AddServer(server bouncer.Conn) uuid.UUID - GetServer(id uuid.UUID) zap.ReadWriter - RemoveServer(id uuid.UUID) zap.ReadWriter + GetServer(id uuid.UUID) bouncer.Conn + RemoveServer(id uuid.UUID) bouncer.Conn + + // LookupCorresponding finds the corresponding server and key for a particular client + LookupCorresponding(key [8]byte) (uuid.UUID, [8]byte, bool) ScaleDown(amount int) (remaining int) IdleSince() time.Time @@ -109,7 +113,7 @@ func (T *Pool) _tryAddServers(recipe *PoolRecipe, amount int) (remaining int) { j := 0 for i := 0; i < len(recipe.servers); i++ { - if T.raw.GetServer(recipe.servers[i]) != nil { + if T.raw.GetServer(recipe.servers[i]).RW != nil { recipe.servers[j] = recipe.servers[i] j++ } @@ -158,8 +162,8 @@ func (T *Pool) removeRecipe(recipe *PoolRecipe) { recipe.removed = true for _, id := range recipe.servers { - if conn := T.raw.RemoveServer(id); conn != nil { - _ = conn.Close() + if conn := T.raw.RemoveServer(id); conn.RW != nil { + _ = conn.RW.Close() } } @@ -204,6 +208,20 @@ func (T *Pool) Serve(conn bouncer.Conn) { } func (T *Pool) Cancel(key [8]byte) { - log.Println("cancel in pool", T, key) - // TODO(garet) + server, cancelKey, ok := T.raw.LookupCorresponding(key) + if !ok { + return + } + T.recipes.Range(func(_ string, recipe *PoolRecipe) bool { + if slices.Contains(recipe.servers, server) { + rw, err := recipe.r.Dial() + if err != nil { + return false + } + // error doesn't matter + _ = backends.Cancel(rw, cancelKey) + return false + } + return true + }) } diff --git a/lib/gat/pools/session/conn.go b/lib/gat/pools/session/conn.go deleted file mode 100644 index 8c5e0b38914c343dbfe52a63b791271fb1777ac3..0000000000000000000000000000000000000000 --- a/lib/gat/pools/session/conn.go +++ /dev/null @@ -1,14 +0,0 @@ -package session - -import ( - "github.com/google/uuid" - - "pggat2/lib/util/strutil" - "pggat2/lib/zap" -) - -type Conn struct { - id uuid.UUID - rw zap.ReadWriter - initialParameters map[strutil.CIString]string -} diff --git a/lib/gat/pools/session/pool.go b/lib/gat/pools/session/pool.go index 93cc2404d4538aecffda0387249fb5fcbaf11f9c..6b8d98c72134914720a4eb0b93e7770d3353d82a 100644 --- a/lib/gat/pools/session/pool.go +++ b/lib/gat/pools/session/pool.go @@ -29,7 +29,7 @@ type Pool struct { // use slice lifo for better perf queue ring.Ring[queueItem] - conns map[uuid.UUID]Conn + conns map[uuid.UUID]bouncer.Conn ready sync.Cond qmu sync.Mutex } @@ -43,7 +43,7 @@ func NewPool(config Config) *Pool { return p } -func (T *Pool) acquire(ctx *gat.Context) Conn { +func (T *Pool) acquire(ctx *gat.Context) (uuid.UUID, bouncer.Conn) { T.qmu.Lock() defer T.qmu.Unlock() for T.queue.Length() == 0 { @@ -57,7 +57,7 @@ func (T *Pool) acquire(ctx *gat.Context) Conn { } else { entry, _ = T.queue.PopBack() } - return T.conns[entry.id] + return entry.id, T.conns[entry.id] } func (T *Pool) _release(id uuid.UUID) { @@ -69,25 +69,25 @@ func (T *Pool) _release(id uuid.UUID) { T.ready.Signal() } -func (T *Pool) close(conn Conn) { - _ = conn.rw.Close() +func (T *Pool) close(id uuid.UUID, conn bouncer.Conn) { + _ = conn.RW.Close() T.qmu.Lock() defer T.qmu.Unlock() - delete(T.conns, conn.id) + delete(T.conns, id) } -func (T *Pool) release(conn Conn) { +func (T *Pool) release(id uuid.UUID, conn bouncer.Conn) { // reset session state - err := backends.QueryString(&backends.Context{}, conn.rw, "DISCARD ALL") + err := backends.QueryString(&backends.Context{}, conn.RW, "DISCARD ALL") if err != nil { - T.close(conn) + T.close(id, conn) return } T.qmu.Lock() defer T.qmu.Unlock() - T._release(conn.id) + T._release(id) } func (T *Pool) Serve(ctx *gat.Context, client bouncer.Conn) { @@ -95,13 +95,13 @@ func (T *Pool) Serve(ctx *gat.Context, client bouncer.Conn) { _ = client.RW.Close() }() - connOk := true - conn := T.acquire(ctx) + serverOK := true + serverID, server := T.acquire(ctx) defer func() { - if connOk { - T.release(conn) + if serverOK { + T.release(serverID, server) } else { - T.close(conn) + T.close(serverID, server) } }() @@ -110,7 +110,7 @@ func (T *Pool) Serve(ctx *gat.Context, client bouncer.Conn) { defer pkts.Done() add := func(key strutil.CIString) { - if value, ok := conn.initialParameters[key]; ok { + if value, ok := server.InitialParameters[key]; ok { pkt := zap.NewPacket() packets.WriteParameterStatus(pkt, key.String(), value) pkts.Append(pkt) @@ -119,7 +119,7 @@ func (T *Pool) Serve(ctx *gat.Context, client bouncer.Conn) { for key, value := range client.InitialParameters { // skip already set params - if conn.initialParameters[key] == value { + if server.InitialParameters[key] == value { add(key) continue } @@ -134,13 +134,13 @@ func (T *Pool) Serve(ctx *gat.Context, client bouncer.Conn) { packets.WriteParameterStatus(pkt, key.String(), value) pkts.Append(pkt) - if err := backends.SetParameter(&backends.Context{}, conn.rw, key, value); err != nil { - connOk = false + if err := backends.SetParameter(&backends.Context{}, server.RW, key, value); err != nil { + serverOK = false return true } } - for key := range conn.initialParameters { + for key := range server.InitialParameters { if _, ok := client.InitialParameters[key]; ok { continue } @@ -164,48 +164,49 @@ func (T *Pool) Serve(ctx *gat.Context, client bouncer.Conn) { if err := client.RW.Read(packet); err != nil { break } - clientErr, serverErr := bouncers.Bounce(client.RW, conn.rw, packet) + clientErr, serverErr := bouncers.Bounce(client.RW, server.RW, packet) if clientErr != nil || serverErr != nil { - connOk = serverErr == nil + serverOK = serverErr == nil break } } } +func (T *Pool) LookupCorresponding(key [8]byte) (uuid.UUID, [8]byte, bool) { + // TODO(garet) + return uuid.Nil, [8]byte{}, false +} + func (T *Pool) AddServer(server bouncer.Conn) uuid.UUID { T.qmu.Lock() defer T.qmu.Unlock() id := uuid.New() if T.conns == nil { - T.conns = make(map[uuid.UUID]Conn) - } - T.conns[id] = Conn{ - id: id, - rw: server.RW, - initialParameters: server.InitialParameters, + T.conns = make(map[uuid.UUID]bouncer.Conn) } + T.conns[id] = server T._release(id) return id } -func (T *Pool) GetServer(id uuid.UUID) zap.ReadWriter { +func (T *Pool) GetServer(id uuid.UUID) bouncer.Conn { T.qmu.Lock() defer T.qmu.Unlock() - return T.conns[id].rw + return T.conns[id] } -func (T *Pool) RemoveServer(id uuid.UUID) zap.ReadWriter { +func (T *Pool) RemoveServer(id uuid.UUID) bouncer.Conn { T.qmu.Lock() defer T.qmu.Unlock() conn, ok := T.conns[id] if !ok { - return nil + return bouncer.Conn{} } delete(T.conns, id) - return conn.rw + return conn } func (T *Pool) ScaleDown(amount int) (remaining int) { @@ -226,7 +227,7 @@ func (T *Pool) ScaleDown(amount int) (remaining int) { } delete(T.conns, v.id) - _ = conn.rw.Close() + _ = conn.RW.Close() remaining-- } diff --git a/lib/gat/pools/transaction/conn.go b/lib/gat/pools/transaction/conn.go index 33ab6ab04b89bd67a621b446dcf521a28284a850..b7e1a54b9d23dd27bdefd070de57161cb0d092fe 100644 --- a/lib/gat/pools/transaction/conn.go +++ b/lib/gat/pools/transaction/conn.go @@ -1,15 +1,15 @@ package transaction import ( + "pggat2/lib/bouncer" "pggat2/lib/bouncer/bouncers/v2" "pggat2/lib/middleware/middlewares/eqp" "pggat2/lib/middleware/middlewares/ps" "pggat2/lib/rob" - "pggat2/lib/zap" ) type Conn struct { - rw zap.ReadWriter + b bouncer.Conn eqp *eqp.Server ps *ps.Server } @@ -23,20 +23,20 @@ func (T *Conn) Do(ctx *rob.Context, work any) { if clientErr != nil || serverErr != nil { _ = job.rw.Close() if serverErr != nil { - _ = T.rw.Close() + _ = T.b.RW.Close() ctx.Remove() } } }() // sync parameters - clientErr, serverErr = ps.Sync(job.trackedParameters, job.rw, job.ps, T.rw, T.ps) + clientErr, serverErr = ps.Sync(job.trackedParameters, job.rw, job.ps, T.b.RW, T.ps) if clientErr != nil || serverErr != nil { return } T.eqp.SetClient(job.eqp) - clientErr, serverErr = bouncers.Bounce(job.rw, T.rw, job.initialPacket) + clientErr, serverErr = bouncers.Bounce(job.rw, T.b.RW, job.initialPacket) if clientErr != nil || serverErr != nil { return } diff --git a/lib/gat/pools/transaction/pool.go b/lib/gat/pools/transaction/pool.go index 86d731760b7af7747114567d46febfb1de4a4f1b..b6a36e89ebb26ef7eff63d5827dfa7b3ff230204 100644 --- a/lib/gat/pools/transaction/pool.go +++ b/lib/gat/pools/transaction/pool.go @@ -32,33 +32,33 @@ func NewPool(config Config) *Pool { func (T *Pool) AddServer(server bouncer.Conn) uuid.UUID { eqps := eqp.NewServer() pss := ps.NewServer(server.InitialParameters) - mw := interceptor.NewInterceptor( + server.RW = interceptor.NewInterceptor( server.RW, eqps, pss, ) sink := &Conn{ - rw: mw, + b: server, eqp: eqps, ps: pss, } return T.s.AddWorker(0, sink) } -func (T *Pool) GetServer(id uuid.UUID) zap.ReadWriter { +func (T *Pool) GetServer(id uuid.UUID) bouncer.Conn { conn := T.s.GetWorker(id) if conn == nil { - return nil + return bouncer.Conn{} } - return conn.(*Conn).rw + return conn.(*Conn).b } -func (T *Pool) RemoveServer(id uuid.UUID) zap.ReadWriter { +func (T *Pool) RemoveServer(id uuid.UUID) bouncer.Conn { conn := T.s.RemoveWorker(id) if conn == nil { - return nil + return bouncer.Conn{} } - return conn.(*Conn).rw + return conn.(*Conn).b } func (T *Pool) Serve(ctx *gat.Context, client bouncer.Conn) { @@ -94,6 +94,11 @@ func (T *Pool) Serve(ctx *gat.Context, client bouncer.Conn) { _ = c.Close() } +func (T *Pool) LookupCorresponding(key [8]byte) (uuid.UUID, [8]byte, bool) { + // TODO(garet) + return uuid.Nil, [8]byte{}, false +} + func (T *Pool) ScaleDown(amount int) (remaining int) { remaining = amount @@ -108,7 +113,7 @@ func (T *Pool) ScaleDown(amount int) (remaining int) { continue } conn := worker.(*Conn) - _ = conn.rw.Close() + _ = conn.b.RW.Close() remaining-- } diff --git a/lib/gat/recipe.go b/lib/gat/recipe.go index bdf42095b66f53f5eb45de97d288ac7d8ea5f364..9d42198e56369e479285622de12ae623add85b06 100644 --- a/lib/gat/recipe.go +++ b/lib/gat/recipe.go @@ -11,6 +11,7 @@ import ( ) type Recipe interface { + Dial() (zap.ReadWriter, error) Connect() (bouncer.Conn, error) GetMinConnections() int @@ -30,12 +31,20 @@ type TCPRecipe struct { StartupParameters map[strutil.CIString]string } -func (T TCPRecipe) Connect() (bouncer.Conn, error) { +func (T TCPRecipe) Dial() (zap.ReadWriter, error) { conn, err := net.Dial("tcp", T.Address) if err != nil { - return bouncer.Conn{}, err + return nil, err } rw := zap.WrapIOReadWriter(conn) + return rw, nil +} + +func (T TCPRecipe) Connect() (bouncer.Conn, error) { + rw, err := T.Dial() + if err != nil { + return bouncer.Conn{}, err + } server, err := backends.Accept(rw, backends.AcceptOptions{ Credentials: T.Credentials, diff --git a/pgbouncer.ini b/pgbouncer.ini index 6c175d6f9eb355246d436605ac5020519668c5f8..5e2ff5aa099ad513dbf8c73daae4856b246d337c 100644 --- a/pgbouncer.ini +++ b/pgbouncer.ini @@ -1,5 +1,5 @@ [pgbouncer] -pool_mode = transaction +pool_mode = session auth_file = userlist.txt listen_addr = * track_extra_parameters = IntervalStyle, session_authorization, default_transaction_read_only, search_path