From 4199dc279d0379ed7b6d91ee1ef869f11e35a423 Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Wed, 19 Jul 2023 18:19:55 -0500
Subject: [PATCH] close and log

---
 lib/bouncer/frontends/v0/accept.go        | 6 +++---
 lib/gat/pooler.go                         | 9 ++++++++-
 lib/gat/pools/session/pool.go             | 6 ++++++
 lib/gat/pools/transaction/conn.go         | 1 +
 lib/gat/pools/transaction/pool.go         | 6 ++++++
 lib/middleware/interceptor/interceptor.go | 4 ++++
 lib/zap/reader.go                         | 6 ++++++
 lib/zap/writer.go                         | 6 ++++++
 8 files changed, 40 insertions(+), 4 deletions(-)

diff --git a/lib/bouncer/frontends/v0/accept.go b/lib/bouncer/frontends/v0/accept.go
index 79d5d091..5caf09b7 100644
--- a/lib/bouncer/frontends/v0/accept.go
+++ b/lib/bouncer/frontends/v0/accept.go
@@ -242,7 +242,7 @@ func updateParameter(pkts *zap.Packets, name, value string) {
 	pkts.Append(packet)
 }
 
-func accept(client zap.ReadWriter, getPassword func(user string) (string, bool), initialParameterStatus map[string]string) (user string, database string, err perror.Error) {
+func accept(client zap.ReadWriter, getPassword func(user, database string) (string, bool), initialParameterStatus map[string]string) (user string, database string, err perror.Error) {
 	for {
 		var done bool
 		user, database, done, err = startup0(client)
@@ -254,7 +254,7 @@ func accept(client zap.ReadWriter, getPassword func(user string) (string, bool),
 		}
 	}
 
-	password, ok := getPassword(user)
+	password, ok := getPassword(user, database)
 	if !ok {
 		err = perror.New(
 			perror.FATAL,
@@ -313,7 +313,7 @@ func fail(client zap.ReadWriter, err perror.Error) {
 	_ = client.Write(packet)
 }
 
-func Accept(client zap.ReadWriter, getPassword func(user string) (string, bool), initialParameterStatus map[string]string) (user, database string, err perror.Error) {
+func Accept(client zap.ReadWriter, getPassword func(user, database string) (string, bool), initialParameterStatus map[string]string) (user, database string, err perror.Error) {
 	user, database, err = accept(client, getPassword, initialParameterStatus)
 	if err != nil {
 		fail(client, err)
diff --git a/lib/gat/pooler.go b/lib/gat/pooler.go
index 561f70de..0c292d18 100644
--- a/lib/gat/pooler.go
+++ b/lib/gat/pooler.go
@@ -54,24 +54,31 @@ func (T *Pooler) Serve(client zap.ReadWriter) {
 		unterminate.Unterminate,
 	)
 
-	username, database, err := frontends.Accept(client, func(username string) (string, bool) {
+	username, database, err := frontends.Accept(client, func(username, database string) (string, bool) {
 		user := T.GetUser(username)
 		if user == nil {
 			return "", false
 		}
+		pool := user.GetPool(database)
+		if pool == nil {
+			return "", false
+		}
 		return user.GetPassword(), true
 	}, DefaultParameterStatus)
 	if err != nil {
+		_ = client.Close()
 		return
 	}
 
 	user := T.GetUser(username)
 	if user == nil {
+		_ = client.Close()
 		return
 	}
 
 	pool := user.GetPool(database)
 	if pool == nil {
+		_ = client.Close()
 		return
 	}
 
diff --git a/lib/gat/pools/session/pool.go b/lib/gat/pools/session/pool.go
index 81122999..1301891c 100644
--- a/lib/gat/pools/session/pool.go
+++ b/lib/gat/pools/session/pool.go
@@ -1,6 +1,7 @@
 package session
 
 import (
+	"log"
 	"net"
 	"sync"
 
@@ -54,6 +55,7 @@ func (T *Pool) Serve(client zap.ReadWriter) {
 	for {
 		clientErr, serverErr := bouncers.Bounce(client, server)
 		if clientErr != nil || serverErr != nil {
+			_ = client.Close()
 			if serverErr == nil {
 				T.release(server)
 			}
@@ -66,13 +68,17 @@ func (T *Pool) AddRecipe(name string, recipe gat.Recipe) {
 	for i := 0; i < recipe.MinConnections; i++ {
 		conn, err := net.Dial("tcp", recipe.Address)
 		if err != nil {
+			_ = conn.Close()
 			// TODO(garet) do something here
+			log.Printf("Failed to connect to %s: %v", recipe.Address, err)
 			continue
 		}
 		rw := zap.WrapIOReadWriter(conn)
 		err2 := backends.Accept(rw, recipe.User, recipe.Password, recipe.Database)
 		if err2 != nil {
+			_ = conn.Close()
 			// TODO(garet) do something here
+			log.Printf("Failed to connect to %s: %v", recipe.Address, err2)
 			continue
 		}
 		T.release(rw)
diff --git a/lib/gat/pools/transaction/conn.go b/lib/gat/pools/transaction/conn.go
index 43c4bd83..836d7b29 100644
--- a/lib/gat/pools/transaction/conn.go
+++ b/lib/gat/pools/transaction/conn.go
@@ -21,6 +21,7 @@ func (T Conn) Do(_ rob.Constraints, work any) {
 	_, backendError := bouncers.Bounce(job.rw, T.rw)
 	if backendError != nil {
 		// TODO(garet) remove from pool
+		panic(backendError)
 	}
 	return
 }
diff --git a/lib/gat/pools/transaction/pool.go b/lib/gat/pools/transaction/pool.go
index eec49aea..a2ad659b 100644
--- a/lib/gat/pools/transaction/pool.go
+++ b/lib/gat/pools/transaction/pool.go
@@ -1,6 +1,7 @@
 package transaction
 
 import (
+	"log"
 	"net"
 
 	"pggat2/lib/bouncer/backends/v0"
@@ -29,7 +30,9 @@ func (T *Pool) AddRecipe(name string, recipe gat.Recipe) {
 	for i := 0; i < recipe.MinConnections; i++ {
 		conn, err := net.Dial("tcp", recipe.Address)
 		if err != nil {
+			_ = conn.Close()
 			// TODO(garet) do something here
+			log.Printf("Failed to connect to %s: %v", recipe.Address, err)
 			continue
 		}
 		rw := zap.WrapIOReadWriter(conn)
@@ -42,7 +45,9 @@ func (T *Pool) AddRecipe(name string, recipe gat.Recipe) {
 		)
 		err2 := backends.Accept(mw, recipe.User, recipe.Password, recipe.Database)
 		if err2 != nil {
+			_ = conn.Close()
 			// TODO(garet) do something here
+			log.Printf("Failed to connect to %s: %v", recipe.Address, err2)
 			continue
 		}
 		T.s.AddSink(0, Conn{
@@ -72,6 +77,7 @@ func (T *Pool) Serve(client zap.ReadWriter) {
 	defer buffer.Done()
 	for {
 		if err := buffer.Buffer(); err != nil {
+			_ = client.Close()
 			break
 		}
 		source.Do(0, Work{
diff --git a/lib/middleware/interceptor/interceptor.go b/lib/middleware/interceptor/interceptor.go
index 8e0d280c..d19fdb1c 100644
--- a/lib/middleware/interceptor/interceptor.go
+++ b/lib/middleware/interceptor/interceptor.go
@@ -145,4 +145,8 @@ func (T *Interceptor) WriteV(packets *zap.Packets) error {
 	return T.rw.WriteV(packets)
 }
 
+func (T *Interceptor) Close() error {
+	return T.rw.Close()
+}
+
 var _ zap.ReadWriter = (*Interceptor)(nil)
diff --git a/lib/zap/reader.go b/lib/zap/reader.go
index 2ba2d2b5..36e002c3 100644
--- a/lib/zap/reader.go
+++ b/lib/zap/reader.go
@@ -6,6 +6,8 @@ type Reader interface {
 	ReadByte() (byte, error)
 	Read(*Packet) error
 	ReadUntyped(*UntypedPacket) error
+
+	Close() error
 }
 
 func WrapIOReader(readCloser io.ReadCloser) Reader {
@@ -39,4 +41,8 @@ func (T ioReader) ReadUntyped(packet *UntypedPacket) error {
 	return err
 }
 
+func (T ioReader) Close() error {
+	return T.closer.Close()
+}
+
 var _ Reader = ioReader{}
diff --git a/lib/zap/writer.go b/lib/zap/writer.go
index 24e571bc..667954fc 100644
--- a/lib/zap/writer.go
+++ b/lib/zap/writer.go
@@ -9,6 +9,8 @@ type Writer interface {
 	Write(*Packet) error
 	WriteUntyped(*UntypedPacket) error
 	WriteV(*Packets) error
+
+	Close() error
 }
 
 func WrapIOWriter(writeCloser io.WriteCloser) Writer {
@@ -43,4 +45,8 @@ func (T ioWriter) WriteV(packets *Packets) error {
 	return err
 }
 
+func (T ioWriter) Close() error {
+	return T.closer.Close()
+}
+
 var _ Writer = ioWriter{}
-- 
GitLab