diff --git a/lib/bouncer/frontends/v0/accept.go b/lib/bouncer/frontends/v0/accept.go index 79d5d091d9965163841014c0b440381eeadd2833..5caf09b7e03d194370204653f815997f67b7d28c 100644 --- a/lib/bouncer/frontends/v0/accept.go +++ b/lib/bouncer/frontends/v0/accept.go @@ -242,7 +242,7 @@ func updateParameter(pkts *zap.Packets, name, value string) { pkts.Append(packet) } -func accept(client zap.ReadWriter, getPassword func(user string) (string, bool), initialParameterStatus map[string]string) (user string, database string, err perror.Error) { +func accept(client zap.ReadWriter, getPassword func(user, database string) (string, bool), initialParameterStatus map[string]string) (user string, database string, err perror.Error) { for { var done bool user, database, done, err = startup0(client) @@ -254,7 +254,7 @@ func accept(client zap.ReadWriter, getPassword func(user string) (string, bool), } } - password, ok := getPassword(user) + password, ok := getPassword(user, database) if !ok { err = perror.New( perror.FATAL, @@ -313,7 +313,7 @@ func fail(client zap.ReadWriter, err perror.Error) { _ = client.Write(packet) } -func Accept(client zap.ReadWriter, getPassword func(user string) (string, bool), initialParameterStatus map[string]string) (user, database string, err perror.Error) { +func Accept(client zap.ReadWriter, getPassword func(user, database string) (string, bool), initialParameterStatus map[string]string) (user, database string, err perror.Error) { user, database, err = accept(client, getPassword, initialParameterStatus) if err != nil { fail(client, err) diff --git a/lib/gat/pooler.go b/lib/gat/pooler.go index 561f70deccf19c0b476bde2cb1a60d35d496e1ad..0c292d1877eb2e2c0f6b135f442b61449210fdbe 100644 --- a/lib/gat/pooler.go +++ b/lib/gat/pooler.go @@ -54,24 +54,31 @@ func (T *Pooler) Serve(client zap.ReadWriter) { unterminate.Unterminate, ) - username, database, err := frontends.Accept(client, func(username string) (string, bool) { + username, database, err := frontends.Accept(client, func(username, database string) (string, bool) { user := T.GetUser(username) if user == nil { return "", false } + pool := user.GetPool(database) + if pool == nil { + return "", false + } return user.GetPassword(), true }, DefaultParameterStatus) if err != nil { + _ = client.Close() return } user := T.GetUser(username) if user == nil { + _ = client.Close() return } pool := user.GetPool(database) if pool == nil { + _ = client.Close() return } diff --git a/lib/gat/pools/session/pool.go b/lib/gat/pools/session/pool.go index 811229994aee93e0ff1e7e7ee3150d7230525592..1301891c52338aa5d74f76e2fae772201c25f736 100644 --- a/lib/gat/pools/session/pool.go +++ b/lib/gat/pools/session/pool.go @@ -1,6 +1,7 @@ package session import ( + "log" "net" "sync" @@ -54,6 +55,7 @@ func (T *Pool) Serve(client zap.ReadWriter) { for { clientErr, serverErr := bouncers.Bounce(client, server) if clientErr != nil || serverErr != nil { + _ = client.Close() if serverErr == nil { T.release(server) } @@ -66,13 +68,17 @@ func (T *Pool) AddRecipe(name string, recipe gat.Recipe) { for i := 0; i < recipe.MinConnections; i++ { conn, err := net.Dial("tcp", recipe.Address) if err != nil { + _ = conn.Close() // TODO(garet) do something here + log.Printf("Failed to connect to %s: %v", recipe.Address, err) continue } rw := zap.WrapIOReadWriter(conn) err2 := backends.Accept(rw, recipe.User, recipe.Password, recipe.Database) if err2 != nil { + _ = conn.Close() // TODO(garet) do something here + log.Printf("Failed to connect to %s: %v", recipe.Address, err2) continue } T.release(rw) diff --git a/lib/gat/pools/transaction/conn.go b/lib/gat/pools/transaction/conn.go index 43c4bd83c0955502ae26febeb68bb02eb24098c9..836d7b299b83183c8f10c5fac45ba2bd2a29bec3 100644 --- a/lib/gat/pools/transaction/conn.go +++ b/lib/gat/pools/transaction/conn.go @@ -21,6 +21,7 @@ func (T Conn) Do(_ rob.Constraints, work any) { _, backendError := bouncers.Bounce(job.rw, T.rw) if backendError != nil { // TODO(garet) remove from pool + panic(backendError) } return } diff --git a/lib/gat/pools/transaction/pool.go b/lib/gat/pools/transaction/pool.go index eec49aeac1c1354fc3847bb1996798ee570264b8..a2ad659b431cf167cdcd90c0e3e8be4eef9491b6 100644 --- a/lib/gat/pools/transaction/pool.go +++ b/lib/gat/pools/transaction/pool.go @@ -1,6 +1,7 @@ package transaction import ( + "log" "net" "pggat2/lib/bouncer/backends/v0" @@ -29,7 +30,9 @@ func (T *Pool) AddRecipe(name string, recipe gat.Recipe) { for i := 0; i < recipe.MinConnections; i++ { conn, err := net.Dial("tcp", recipe.Address) if err != nil { + _ = conn.Close() // TODO(garet) do something here + log.Printf("Failed to connect to %s: %v", recipe.Address, err) continue } rw := zap.WrapIOReadWriter(conn) @@ -42,7 +45,9 @@ func (T *Pool) AddRecipe(name string, recipe gat.Recipe) { ) err2 := backends.Accept(mw, recipe.User, recipe.Password, recipe.Database) if err2 != nil { + _ = conn.Close() // TODO(garet) do something here + log.Printf("Failed to connect to %s: %v", recipe.Address, err2) continue } T.s.AddSink(0, Conn{ @@ -72,6 +77,7 @@ func (T *Pool) Serve(client zap.ReadWriter) { defer buffer.Done() for { if err := buffer.Buffer(); err != nil { + _ = client.Close() break } source.Do(0, Work{ diff --git a/lib/middleware/interceptor/interceptor.go b/lib/middleware/interceptor/interceptor.go index 8e0d280c1246154b08eb23d4a3c792e517ef7ad6..d19fdb1cecb4fa57120cec4a2cb83e8905b180fd 100644 --- a/lib/middleware/interceptor/interceptor.go +++ b/lib/middleware/interceptor/interceptor.go @@ -145,4 +145,8 @@ func (T *Interceptor) WriteV(packets *zap.Packets) error { return T.rw.WriteV(packets) } +func (T *Interceptor) Close() error { + return T.rw.Close() +} + var _ zap.ReadWriter = (*Interceptor)(nil) diff --git a/lib/zap/reader.go b/lib/zap/reader.go index 2ba2d2b52f8fb6bcc52ba39a57369f8e6143c355..36e002c36a99a6714445fab5a3930716887ae1e5 100644 --- a/lib/zap/reader.go +++ b/lib/zap/reader.go @@ -6,6 +6,8 @@ type Reader interface { ReadByte() (byte, error) Read(*Packet) error ReadUntyped(*UntypedPacket) error + + Close() error } func WrapIOReader(readCloser io.ReadCloser) Reader { @@ -39,4 +41,8 @@ func (T ioReader) ReadUntyped(packet *UntypedPacket) error { return err } +func (T ioReader) Close() error { + return T.closer.Close() +} + var _ Reader = ioReader{} diff --git a/lib/zap/writer.go b/lib/zap/writer.go index 24e571bc1b98fb31d3b339faec4428a6b464d8df..667954fcf0cc7ab8c89e5bd4808d0c557b6122eb 100644 --- a/lib/zap/writer.go +++ b/lib/zap/writer.go @@ -9,6 +9,8 @@ type Writer interface { Write(*Packet) error WriteUntyped(*UntypedPacket) error WriteV(*Packets) error + + Close() error } func WrapIOWriter(writeCloser io.WriteCloser) Writer { @@ -43,4 +45,8 @@ func (T ioWriter) WriteV(packets *Packets) error { return err } +func (T ioWriter) Close() error { + return T.closer.Close() +} + var _ Writer = ioWriter{}