From 088e1789a5b765ec61949ce55312ff220ef96d4f Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Mon, 21 Aug 2023 17:01:39 -0500
Subject: [PATCH] a

---
 lib/bouncer/backends/v0/accept.go  |  2 +-
 lib/bouncer/conn.go                |  2 +-
 lib/bouncer/frontends/v0/accept.go |  8 +++++---
 lib/bouncer/pooler.go              |  1 +
 lib/gat/pool.go                    |  5 +++++
 lib/gat/pooler.go                  | 15 +++++++++++++++
 6 files changed, 28 insertions(+), 5 deletions(-)

diff --git a/lib/bouncer/backends/v0/accept.go b/lib/bouncer/backends/v0/accept.go
index 984fd4de..078493c9 100644
--- a/lib/bouncer/backends/v0/accept.go
+++ b/lib/bouncer/backends/v0/accept.go
@@ -208,7 +208,7 @@ func startup1(conn *bouncer.Conn) (done bool, err error) {
 	switch packet.ReadType() {
 	case packets.BackendKeyData:
 		read := packet.Read()
-		ok := read.ReadBytes(conn.CancellationKey[:])
+		ok := read.ReadBytes(conn.BackendKey[:])
 		if !ok {
 			err = ErrBadFormat
 			return
diff --git a/lib/bouncer/conn.go b/lib/bouncer/conn.go
index 5cf06892..b469df5d 100644
--- a/lib/bouncer/conn.go
+++ b/lib/bouncer/conn.go
@@ -11,5 +11,5 @@ type Conn struct {
 	User              string
 	Database          string
 	InitialParameters map[strutil.CIString]string
-	CancellationKey   [8]byte
+	BackendKey        [8]byte
 }
diff --git a/lib/bouncer/frontends/v0/accept.go b/lib/bouncer/frontends/v0/accept.go
index 654539b4..1f58f594 100644
--- a/lib/bouncer/frontends/v0/accept.go
+++ b/lib/bouncer/frontends/v0/accept.go
@@ -43,11 +43,13 @@ func startup0(
 		switch minorVersion {
 		case 5678:
 			// Cancel
-			if !read.ReadBytes(client.CancellationKey[:]) {
+			if !read.ReadBytes(client.BackendKey[:]) {
 				err = packets.ErrBadFormat
 				return
 			}
 
+			options.Pooler.Cancel(client.BackendKey)
+
 			err = perror.New(
 				perror.FATAL,
 				perror.ProtocolViolation,
@@ -357,14 +359,14 @@ func accept(
 	pkts.Append(packet)
 
 	// send backend key data
-	_, err2 := rand.Read(conn.CancellationKey[:])
+	_, err2 := rand.Read(conn.BackendKey[:])
 	if err2 != nil {
 		err = perror.Wrap(err2)
 		return
 	}
 
 	packet = zap.NewPacket()
-	packets.WriteBackendKeyData(packet, conn.CancellationKey)
+	packets.WriteBackendKeyData(packet, conn.BackendKey)
 	pkts.Append(packet)
 
 	if conn.InitialParameters == nil {
diff --git a/lib/bouncer/pooler.go b/lib/bouncer/pooler.go
index 73603cb8..18c3c5c4 100644
--- a/lib/bouncer/pooler.go
+++ b/lib/bouncer/pooler.go
@@ -6,4 +6,5 @@ import (
 
 type Pooler interface {
 	GetUserCredentials(user, database string) auth.Credentials
+	Cancel(cancellationKey [8]byte)
 }
diff --git a/lib/gat/pool.go b/lib/gat/pool.go
index 348a8146..1999342f 100644
--- a/lib/gat/pool.go
+++ b/lib/gat/pool.go
@@ -202,3 +202,8 @@ func (T *Pool) RemoveRecipe(name string) {
 func (T *Pool) Serve(conn bouncer.Conn) {
 	T.raw.Serve(&T.ctx, conn)
 }
+
+func (T *Pool) Cancel(key [8]byte) {
+	log.Println("cancel in pool", T, key)
+	// TODO(garet)
+}
diff --git a/lib/gat/pooler.go b/lib/gat/pooler.go
index fa78ac74..6e16d21a 100644
--- a/lib/gat/pooler.go
+++ b/lib/gat/pooler.go
@@ -17,6 +17,9 @@ import (
 type Pooler struct {
 	config PoolerConfig
 
+	// key -> pool for cancellation
+	keys maps.RWLocked[[8]byte, *Pool]
+
 	users maps.RWLocked[string, *User]
 }
 
@@ -55,6 +58,15 @@ func (T *Pooler) GetUserCredentials(user, database string) auth.Credentials {
 	return u.GetCredentials()
 }
 
+func (T *Pooler) Cancel(key [8]byte) {
+	pool, ok := T.keys.Load(key)
+	if !ok {
+		return
+	}
+
+	pool.Cancel(key)
+}
+
 func (T *Pooler) IsStartupParameterAllowed(parameter strutil.CIString) bool {
 	return slices.Contains(T.config.AllowedStartupParameters, parameter)
 }
@@ -90,6 +102,9 @@ func (T *Pooler) Serve(client zap.ReadWriter) {
 		return
 	}
 
+	T.keys.Store(conn.BackendKey, pool)
+	defer T.keys.Delete(conn.BackendKey)
+
 	pool.Serve(conn)
 }
 
-- 
GitLab