diff --git a/lib/bouncer/backends/v0/accept.go b/lib/bouncer/backends/v0/accept.go index 984fd4de07d6fe9f7051b590e915591db23fdb6f..078493c91603b80cefbc00985e955d50d86b8b6d 100644 --- a/lib/bouncer/backends/v0/accept.go +++ b/lib/bouncer/backends/v0/accept.go @@ -208,7 +208,7 @@ func startup1(conn *bouncer.Conn) (done bool, err error) { switch packet.ReadType() { case packets.BackendKeyData: read := packet.Read() - ok := read.ReadBytes(conn.CancellationKey[:]) + ok := read.ReadBytes(conn.BackendKey[:]) if !ok { err = ErrBadFormat return diff --git a/lib/bouncer/conn.go b/lib/bouncer/conn.go index 5cf06892c7b7e50dcc24b0cef605fd7d26b4229d..b469df5dae922ba69d1ec9b330d02a20fb2a1ec4 100644 --- a/lib/bouncer/conn.go +++ b/lib/bouncer/conn.go @@ -11,5 +11,5 @@ type Conn struct { User string Database string InitialParameters map[strutil.CIString]string - CancellationKey [8]byte + BackendKey [8]byte } diff --git a/lib/bouncer/frontends/v0/accept.go b/lib/bouncer/frontends/v0/accept.go index 654539b408c953476fc097028fb120ef8da8bbab..1f58f594e90d18d25fb5b053e98ba28041f490c0 100644 --- a/lib/bouncer/frontends/v0/accept.go +++ b/lib/bouncer/frontends/v0/accept.go @@ -43,11 +43,13 @@ func startup0( switch minorVersion { case 5678: // Cancel - if !read.ReadBytes(client.CancellationKey[:]) { + if !read.ReadBytes(client.BackendKey[:]) { err = packets.ErrBadFormat return } + options.Pooler.Cancel(client.BackendKey) + err = perror.New( perror.FATAL, perror.ProtocolViolation, @@ -357,14 +359,14 @@ func accept( pkts.Append(packet) // send backend key data - _, err2 := rand.Read(conn.CancellationKey[:]) + _, err2 := rand.Read(conn.BackendKey[:]) if err2 != nil { err = perror.Wrap(err2) return } packet = zap.NewPacket() - packets.WriteBackendKeyData(packet, conn.CancellationKey) + packets.WriteBackendKeyData(packet, conn.BackendKey) pkts.Append(packet) if conn.InitialParameters == nil { diff --git a/lib/bouncer/pooler.go b/lib/bouncer/pooler.go index 73603cb8375995c0a8902c2176e39a4fa65b5e43..18c3c5c492328716aa75a0d0d18ebd32ee0de65b 100644 --- a/lib/bouncer/pooler.go +++ b/lib/bouncer/pooler.go @@ -6,4 +6,5 @@ import ( type Pooler interface { GetUserCredentials(user, database string) auth.Credentials + Cancel(cancellationKey [8]byte) } diff --git a/lib/gat/pool.go b/lib/gat/pool.go index 348a8146c79600f021c51543cf43a30627e86cdf..1999342ffae71f1817cb0ee2d08ba65398174693 100644 --- a/lib/gat/pool.go +++ b/lib/gat/pool.go @@ -202,3 +202,8 @@ func (T *Pool) RemoveRecipe(name string) { func (T *Pool) Serve(conn bouncer.Conn) { T.raw.Serve(&T.ctx, conn) } + +func (T *Pool) Cancel(key [8]byte) { + log.Println("cancel in pool", T, key) + // TODO(garet) +} diff --git a/lib/gat/pooler.go b/lib/gat/pooler.go index fa78ac74e46c88202d74cccefc1bbd50ab02e644..6e16d21a6783a720edaf7d2b9db992a0c7be11c6 100644 --- a/lib/gat/pooler.go +++ b/lib/gat/pooler.go @@ -17,6 +17,9 @@ import ( type Pooler struct { config PoolerConfig + // key -> pool for cancellation + keys maps.RWLocked[[8]byte, *Pool] + users maps.RWLocked[string, *User] } @@ -55,6 +58,15 @@ func (T *Pooler) GetUserCredentials(user, database string) auth.Credentials { return u.GetCredentials() } +func (T *Pooler) Cancel(key [8]byte) { + pool, ok := T.keys.Load(key) + if !ok { + return + } + + pool.Cancel(key) +} + func (T *Pooler) IsStartupParameterAllowed(parameter strutil.CIString) bool { return slices.Contains(T.config.AllowedStartupParameters, parameter) } @@ -90,6 +102,9 @@ func (T *Pooler) Serve(client zap.ReadWriter) { return } + T.keys.Store(conn.BackendKey, pool) + defer T.keys.Delete(conn.BackendKey) + pool.Serve(conn) }