diff --git a/README.md b/README.md index b0bf159b40f30ce720aa0b47952d73bbb5722a0d..513d2e752d0056ce3b6bcb8f3ded339a9adbe010 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,6 @@ Send each session to a new node. This mode supports all postgres features, but w ## Unsupported features One day these will maybe be supported -- Cancelling in flight queries - Reserve pool (for serving long-stalled clients) - Auth methods other than plaintext, MD5, and SASL-SCRAM-SHA256 - GSSAPI diff --git a/lib/gat/acceptor.go b/lib/gat/acceptor.go index 53c718e08263cf0e531e385e40f1bc927fcb2b6f..bccea393089ccb0c932659cff481ed742ba4be49 100644 --- a/lib/gat/acceptor.go +++ b/lib/gat/acceptor.go @@ -48,6 +48,7 @@ func serve(client zap.Conn, acceptParams frontends.AcceptParams, pools Pools) er if p == nil { return nil } + return p.Cancel(acceptParams.CancelKey) } diff --git a/lib/gat/pool/dialer.go b/lib/gat/pool/dialer.go index c8779dc7465724608e9d0a85375f044c4fb5d6f2..2d8ecdb69474e05dc18f611f2b8ac7e23bc4ae6c 100644 --- a/lib/gat/pool/dialer.go +++ b/lib/gat/pool/dialer.go @@ -9,6 +9,7 @@ import ( type Dialer interface { Dial() (zap.Conn, backends.AcceptParams, error) + Cancel(cancelKey [8]byte) error } type NetDialer struct { @@ -31,3 +32,15 @@ func (T NetDialer) Dial() (zap.Conn, backends.AcceptParams, error) { return conn, params, nil } + +func (T NetDialer) Cancel(cancelKey [8]byte) error { + c, err := net.Dial(T.Network, T.Address) + if err != nil { + return err + } + conn := zap.WrapNetConn(c) + defer func() { + _ = conn.Close() + }() + return backends.Cancel(conn, cancelKey) +} diff --git a/lib/gat/pool/pool.go b/lib/gat/pool/pool.go index 6a673039b09173a7f6e8f80ec247d4567b353340..1eafe35255e07689b341f66f5617820ab22b85fc 100644 --- a/lib/gat/pool/pool.go +++ b/lib/gat/pool/pool.go @@ -44,12 +44,17 @@ type poolRecipe struct { count atomic.Int64 } +type poolClient struct { + conn zap.Conn + key [8]byte +} + type Pool struct { options Options recipes map[string]*poolRecipe servers map[uuid.UUID]*poolServer - clients map[uuid.UUID]zap.Conn + clients map[uuid.UUID]poolClient mu sync.Mutex } @@ -287,7 +292,7 @@ func (T *Pool) Serve( middlewares..., ) - clientID := T.addClient(client) + clientID := T.addClient(client, auth.BackendKey) var serverID uuid.UUID var server *poolServer @@ -341,16 +346,19 @@ func (T *Pool) Serve( } } -func (T *Pool) addClient(client zap.Conn) uuid.UUID { +func (T *Pool) addClient(client zap.Conn, key [8]byte) uuid.UUID { T.mu.Lock() defer T.mu.Unlock() clientID := uuid.New() if T.clients == nil { - T.clients = make(map[uuid.UUID]zap.Conn) + T.clients = make(map[uuid.UUID]poolClient) + } + T.clients[clientID] = poolClient{ + conn: client, + key: key, } - T.clients[clientID] = client T.options.Pooler.AddClient(clientID) return clientID } @@ -420,6 +428,50 @@ func (T *Pool) removeServer(serverID uuid.UUID) { } func (T *Pool) Cancel(key [8]byte) error { - // TODO(garet) implement cancel - return nil + dialer, backendKey := func() (Dialer, [8]byte) { + T.mu.Lock() + defer T.mu.Unlock() + + var clientID uuid.UUID + for id, client := range T.clients { + if client.key == key { + clientID = id + break + } + } + + if clientID == uuid.Nil { + return nil, [8]byte{} + } + + // get peer + var recipe string + var serverKey [8]byte + var ok bool + for _, server := range T.servers { + if server.peer == clientID { + recipe = server.recipe + serverKey = server.accept.BackendKey + ok = true + break + } + } + + if !ok { + return nil, [8]byte{} + } + + r, ok := T.recipes[recipe] + if !ok { + return nil, [8]byte{} + } + + return r.recipe.Dialer, serverKey + }() + + if dialer == nil { + return nil + } + + return dialer.Cancel(backendKey) } diff --git a/pgbouncer.ini b/pgbouncer.ini index f9ea6434c88c7e0deaa6cefba971d83db8da5032..6c175d6f9eb355246d436605ac5020519668c5f8 100644 --- a/pgbouncer.ini +++ b/pgbouncer.ini @@ -3,7 +3,6 @@ pool_mode = transaction auth_file = userlist.txt listen_addr = * track_extra_parameters = IntervalStyle, session_authorization, default_transaction_read_only, search_path -server_idle_timeout = 10 [users] postgres =