diff --git a/lib/bouncer/backends/v0/accept.go b/lib/bouncer/backends/v0/accept.go index 6fe9254b8cc8184cc7889f48cf61c2a2757884e8..086de2e09e98a658d2e0efcad830fb2c83320a6e 100644 --- a/lib/bouncer/backends/v0/accept.go +++ b/lib/bouncer/backends/v0/accept.go @@ -21,7 +21,7 @@ func authenticationSASLChallenge(ctx *acceptContext, encoder auth.SASLEncoder) ( } if packet.Type() != packets.TypeAuthentication { - err = ErrUnexpectedPacket + err = ErrUnexpectedPacket(packet.Type()) return } @@ -55,7 +55,7 @@ func authenticationSASLChallenge(ctx *acceptContext, encoder auth.SASLEncoder) ( return true, nil default: - err = ErrUnexpectedPacket + err = ErrUnexpectedAuthenticationResponse return } } @@ -186,7 +186,7 @@ func startup0(ctx *acceptContext) (done bool, err error) { err = errors.New("server wanted to negotiate protocol version") return default: - err = ErrUnexpectedPacket + err = ErrUnexpectedPacket(packet.Type()) return } } @@ -235,7 +235,7 @@ func startup1(ctx *acceptContext) (done bool, err error) { // TODO(garet) do something with notice return false, nil default: - err = ErrUnexpectedPacket + err = ErrUnexpectedPacket(packet.Type()) return } } diff --git a/lib/bouncer/backends/v0/context.go b/lib/bouncer/backends/v0/context.go index 3554fc2bbbb9be1721fd4a233611dc9de63f0807..8288ab1812d3b20bb069ef4d78f2d27961ddd7a9 100644 --- a/lib/bouncer/backends/v0/context.go +++ b/lib/bouncer/backends/v0/context.go @@ -17,6 +17,10 @@ type context struct { TxState byte } +func (T *context) ErrUnexpectedPacket() error { + return ErrUnexpectedPacket(T.Packet.Type()) +} + func (T *context) ServerRead() error { var err error T.Packet, err = T.Server.ReadPacket(true) diff --git a/lib/bouncer/backends/v0/errors.go b/lib/bouncer/backends/v0/errors.go index b7a7ad7fff182f2192aacdfe9664cb22f9c64873..234793864b605e5a07eab9f17b9927460f08aa51 100644 --- a/lib/bouncer/backends/v0/errors.go +++ b/lib/bouncer/backends/v0/errors.go @@ -1,8 +1,17 @@ package backends -import "errors" +import ( + "errors" + "fmt" + + "gfx.cafe/gfx/pggat/lib/fed" +) + +func ErrUnexpectedPacket(typ fed.Type) error { + return fmt.Errorf("unexpected packet: %c", typ) +} var ( - ErrBadFormat = errors.New("bad packet format") - ErrUnexpectedPacket = errors.New("unexpected packet") + ErrExpectedIdle = errors.New("expected server to return ReadyForQuery(IDLE)") + ErrUnexpectedAuthenticationResponse = errors.New("unexpected authentication response") ) diff --git a/lib/bouncer/backends/v0/query.go b/lib/bouncer/backends/v0/query.go index d86b46c19faec8c64321c1dc02b44849f6e9ab66..ebd466d24c5c45f8d4871951a20de18e6716ef45 100644 --- a/lib/bouncer/backends/v0/query.go +++ b/lib/bouncer/backends/v0/query.go @@ -26,7 +26,7 @@ func copyIn(ctx *context) error { case packets.TypeCopyDone, packets.TypeCopyFail: return ctx.ServerWrite() default: - ctx.PeerFail(ErrUnexpectedPacket) + ctx.PeerFail(ctx.ErrUnexpectedPacket()) } } } @@ -50,7 +50,7 @@ func copyOut(ctx *context) error { ctx.PeerWrite() return nil default: - return ErrUnexpectedPacket + return ctx.ErrUnexpectedPacket() } } } @@ -95,7 +95,7 @@ func query(ctx *context) error { ctx.PeerWrite() return nil default: - return ErrUnexpectedPacket + return ctx.ErrUnexpectedPacket() } } } @@ -163,7 +163,7 @@ func functionCall(ctx *context) error { ctx.PeerWrite() return nil default: - return ErrUnexpectedPacket + return ctx.ErrUnexpectedPacket() } } } @@ -218,7 +218,7 @@ func sync(ctx *context) (bool, error) { ctx.PeerWrite() return true, nil default: - return false, ErrUnexpectedPacket + return false, ctx.ErrUnexpectedPacket() } } } @@ -267,7 +267,7 @@ func eqp(ctx *context) error { return err } default: - ctx.PeerFail(ErrUnexpectedPacket) + ctx.PeerFail(ctx.ErrUnexpectedPacket()) } } } @@ -293,7 +293,7 @@ func transaction(ctx *context) error { return err } default: - ctx.PeerFail(ErrUnexpectedPacket) + ctx.PeerFail(ctx.ErrUnexpectedPacket()) } if ctx.TxState == 'I' { @@ -308,7 +308,7 @@ func transaction(ctx *context) error { } if ctx.TxState != 'I' { - return ErrUnexpectedPacket + return ErrExpectedIdle } return nil } diff --git a/lib/gat/handlers/pool/pools/basic/pool.go b/lib/gat/handlers/pool/pools/basic/pool.go index a698727376eafa587885aa5979f22de429445d6f..d3023407259a83876a0a78c6f0a7c11fc2d4fd8c 100644 --- a/lib/gat/handlers/pool/pools/basic/pool.go +++ b/lib/gat/handlers/pool/pools/basic/pool.go @@ -1,6 +1,7 @@ package basic import ( + "fmt" "sync" "github.com/google/uuid" @@ -243,7 +244,7 @@ func (T *Pool) Serve(conn *fed.Conn) error { } if serverErr != nil { - return serverErr + return fmt.Errorf("server error: %w", serverErr) } else { client.TransactionComplete() server.TransactionComplete() diff --git a/lib/gat/handlers/pool/pools/hybrid/pool.go b/lib/gat/handlers/pool/pools/hybrid/pool.go index 2e5055fb58f55255c3ff3121bfe4cb519ac9c079..3e28d8643d2f986e8a488c13a15f57f4a7994f9e 100644 --- a/lib/gat/handlers/pool/pools/hybrid/pool.go +++ b/lib/gat/handlers/pool/pools/hybrid/pool.go @@ -1,6 +1,7 @@ package hybrid import ( + "fmt" "sync" "github.com/google/uuid" @@ -235,7 +236,7 @@ func (T *Pool) serveRW(conn *fed.Conn) error { err, serverErr = bouncers.Bounce(conn, replica.Conn, packet) } if serverErr != nil { - return serverErr + return fmt.Errorf("server error: %w", serverErr) } else { replica.TransactionComplete() } @@ -266,7 +267,7 @@ func (T *Pool) serveRW(conn *fed.Conn) error { err, serverErr = bouncers.Bounce(conn, primary.Conn, packet) } if serverErr != nil { - return serverErr + return fmt.Errorf("server error: %w", serverErr) } else { primary.TransactionComplete() } @@ -294,7 +295,7 @@ func (T *Pool) serveRW(conn *fed.Conn) error { err, serverErr = bouncers.Bounce(conn, primary.Conn, packet) } if serverErr != nil { - return serverErr + return fmt.Errorf("server error: %w", serverErr) } else { primary.TransactionComplete() } @@ -413,7 +414,7 @@ func (T *Pool) serveOnly(conn *fed.Conn, write bool) error { err, serverErr = bouncers.Bounce(conn, server.Conn, packet) } if serverErr != nil { - return serverErr + return fmt.Errorf("server error: %w", serverErr) } else { server.TransactionComplete() client.TransactionComplete() diff --git a/lib/gat/handlers/pool/spool/pool.go b/lib/gat/handlers/pool/spool/pool.go index feaf0401bd0db4839a02abadb798e551895003ba..700585aaa43053c143456c3a526aaa3a9d320072 100644 --- a/lib/gat/handlers/pool/spool/pool.go +++ b/lib/gat/handlers/pool/spool/pool.go @@ -14,6 +14,7 @@ import ( "gfx.cafe/gfx/pggat/lib/gat/handlers/pool" "gfx.cafe/gfx/pggat/lib/gat/handlers/pool/spool/kitchen" "gfx.cafe/gfx/pggat/lib/gat/metrics" + "gfx.cafe/gfx/pggat/lib/util/maps" ) type Pool struct { @@ -316,5 +317,11 @@ func (T *Pool) ReadMetrics(m *metrics.Pool) { func (T *Pool) Close() { close(T.closed) + T.oven.Close() T.pooler.Close() + + T.mu.Lock() + defer T.mu.Unlock() + maps.Clear(T.serversByID) + maps.Clear(T.serversByConn) }