From 88e26e820c81b5818d567d393480bcce2e9a48fb Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Wed, 3 May 2023 16:33:48 -0500
Subject: [PATCH] clean up errors and disconnection

---
 lib/frontend/frontends/v0/client.go   | 155 ++++++++++++++------------
 lib/frontend/frontends/v0/frontend.go |  15 +--
 2 files changed, 92 insertions(+), 78 deletions(-)

diff --git a/lib/frontend/frontends/v0/client.go b/lib/frontend/frontends/v0/client.go
index 54ea4272..3b0a259e 100644
--- a/lib/frontend/frontends/v0/client.go
+++ b/lib/frontend/frontends/v0/client.go
@@ -24,6 +24,17 @@ var ErrProtocolError = perror.New(
 	"Expected a different packet",
 )
 
+func WrapError(err error) perror.Error {
+	if err == nil {
+		return nil
+	}
+	return perror.New(
+		perror.FATAL,
+		perror.InternalError,
+		err.Error(),
+	)
+}
+
 type Client struct {
 	conn net.Conn
 
@@ -37,7 +48,7 @@ type Client struct {
 	cancellationKey [8]byte
 }
 
-func NewClient(conn net.Conn) (*Client, error) {
+func NewClient(conn net.Conn) *Client {
 	client := &Client{
 		conn:   conn,
 		Reader: pnet.MakeReader(conn),
@@ -45,25 +56,26 @@ func NewClient(conn net.Conn) (*Client, error) {
 	}
 	err := client.accept()
 	if err != nil {
-		return nil, err
+		client.Close(err)
+		return nil
 	}
-	return client, nil
+	return client
 }
 
-func (T *Client) startup0() (bool, error) {
+func (T *Client) startup0() (bool, perror.Error) {
 	startup, err := T.ReadUntyped()
 	if err != nil {
-		return false, err
+		return false, WrapError(err)
 	}
 	reader := packet.MakeReader(startup)
 
 	majorVersion, ok := reader.Uint16()
 	if !ok {
-		return false, T.Error(ErrBadPacketFormat)
+		return false, ErrBadPacketFormat
 	}
 	minorVersion, ok := reader.Uint16()
 	if !ok {
-		return false, T.Error(ErrBadPacketFormat)
+		return false, ErrBadPacketFormat
 	}
 
 	if majorVersion == 1234 {
@@ -71,36 +83,34 @@ func (T *Client) startup0() (bool, error) {
 		switch minorVersion {
 		case 5678:
 			// Cancel
-			err = T.Error(perror.New(
+			return false, perror.New(
 				perror.FATAL,
 				perror.FeatureNotSupported,
 				"Cancel is not supported yet",
-			))
-			return false, err
+			)
 		case 5679:
 			// SSL is not supported yet
 			err = T.WriteByte('N')
-			return false, err
+			return false, WrapError(err)
 		case 5680:
 			// GSSAPI is not supported yet
 			err = T.WriteByte('N')
-			return false, err
+			return false, WrapError(err)
 		default:
-			err = T.Error(perror.New(
+			return false, perror.New(
 				perror.FATAL,
 				perror.ProtocolViolation,
 				"Unknown request code",
-			))
-			return false, err
+			)
 		}
 	}
 
 	if majorVersion != 3 {
-		err = T.Error(perror.New(
+		return false, perror.New(
 			perror.FATAL,
 			perror.ProtocolViolation,
 			"Unsupported protocol version",
-		))
+		)
 	}
 
 	var unsupportedOptions []string
@@ -108,7 +118,7 @@ func (T *Client) startup0() (bool, error) {
 	for {
 		key, ok := reader.String()
 		if !ok {
-			return false, T.Error(ErrBadPacketFormat)
+			return false, ErrBadPacketFormat
 		}
 		if key == "" {
 			break
@@ -116,7 +126,7 @@ func (T *Client) startup0() (bool, error) {
 
 		value, ok := reader.String()
 		if !ok {
-			return false, T.Error(ErrBadPacketFormat)
+			return false, ErrBadPacketFormat
 		}
 
 		switch key {
@@ -125,17 +135,17 @@ func (T *Client) startup0() (bool, error) {
 		case "database":
 			T.database = value
 		case "options":
-			return false, T.Error(perror.New(
+			return false, perror.New(
 				perror.FATAL,
 				perror.FeatureNotSupported,
 				"Startup options are not supported yet",
-			))
+			)
 		case "replication":
-			return false, T.Error(perror.New(
+			return false, perror.New(
 				perror.FATAL,
 				perror.FeatureNotSupported,
 				"Replication mode is not supported yet",
-			))
+			)
 		default:
 			unsupportedOptions = append(unsupportedOptions, key)
 		}
@@ -153,16 +163,16 @@ func (T *Client) startup0() (bool, error) {
 
 		err = T.Write(builder.Raw())
 		if err != nil {
-			return false, err
+			return false, WrapError(err)
 		}
 	}
 
 	if T.user == "" {
-		return false, T.Error(perror.New(
+		return false, perror.New(
 			perror.FATAL,
 			perror.InvalidAuthorizationSpecification,
 			"User is required",
-		))
+		)
 	}
 	if T.database == "" {
 		T.database = T.user
@@ -171,7 +181,7 @@ func (T *Client) startup0() (bool, error) {
 	return true, nil
 }
 
-func (T *Client) authenticationSASL(username, password string) error {
+func (T *Client) authenticationSASL(username, password string) perror.Error {
 	var builder packet.Builder
 	builder.Type(packet.Authentication)
 	builder.Int32(10)
@@ -182,37 +192,37 @@ func (T *Client) authenticationSASL(username, password string) error {
 
 	err := T.Write(builder.Raw())
 	if err != nil {
-		return err
+		return WrapError(err)
 	}
 
 	// check which authentication method the client wants
 	pkt, err := T.Read()
 	if err != nil {
-		return err
+		return WrapError(err)
 	}
 	if pkt.Type != packet.AuthenticationResponse {
-		return T.Error(ErrBadPacketFormat)
+		return ErrBadPacketFormat
 	}
 
 	reader := packet.MakeReader(pkt)
 	mechanism, ok := reader.String()
 	if !ok {
-		return T.Error(ErrBadPacketFormat)
+		return ErrBadPacketFormat
 	}
 	tool, err := sasl.NewServer(mechanism, username, password)
 	if err != nil {
-		return err
+		return WrapError(err)
 	}
 	_, ok = reader.Int32()
 	if !ok {
-		return T.Error(ErrBadPacketFormat)
+		return ErrBadPacketFormat
 	}
 
 	resp, done, err := tool.InitialResponse(reader.Remaining())
 
 	for {
 		if err != nil {
-			return err
+			return WrapError(err)
 		}
 		if done {
 			builder = packet.Builder{}
@@ -221,7 +231,7 @@ func (T *Client) authenticationSASL(username, password string) error {
 			builder.Bytes(resp)
 			err = T.Write(builder.Raw())
 			if err != nil {
-				return err
+				return WrapError(err)
 			}
 			break
 		} else {
@@ -231,16 +241,16 @@ func (T *Client) authenticationSASL(username, password string) error {
 			builder.Bytes(resp)
 			err = T.Write(builder.Raw())
 			if err != nil {
-				return err
+				return WrapError(err)
 			}
 		}
 
 		pkt, err = T.Read()
 		if err != nil {
-			return err
+			return WrapError(err)
 		}
 		if pkt.Type != packet.AuthenticationResponse {
-			return T.Error(ErrProtocolError)
+			return ErrProtocolError
 		}
 
 		resp, done, err = tool.Continue(pkt.Payload)
@@ -249,11 +259,11 @@ func (T *Client) authenticationSASL(username, password string) error {
 	return nil
 }
 
-func (T *Client) authenticationMD5(username, password string) error {
+func (T *Client) authenticationMD5(username, password string) perror.Error {
 	var salt [4]byte
 	_, err := rand.Read(salt[:])
 	if err != nil {
-		return err
+		return WrapError(err)
 	}
 
 	// password time
@@ -265,41 +275,41 @@ func (T *Client) authenticationMD5(username, password string) error {
 
 	err = T.Write(builder.Raw())
 	if err != nil {
-		return err
+		return WrapError(err)
 	}
 
 	// read password
 	pkt, err := T.Read()
 	if err != nil {
-		return err
+		return WrapError(err)
 	}
 
 	reader := packet.MakeReader(pkt)
 	if reader.Type() != packet.AuthenticationResponse {
-		return T.Error(perror.New(
+		return perror.New(
 			perror.FATAL,
 			perror.ProtocolViolation,
 			"Expected password",
-		))
+		)
 	}
 
 	pw, ok := reader.String()
 	if !ok {
-		return T.Error(ErrBadPacketFormat)
+		return ErrBadPacketFormat
 	}
 
 	if !md5.Check(username, password, salt, pw) {
-		return T.Error(perror.New(
+		return perror.New(
 			perror.FATAL,
 			perror.InvalidPassword,
 			"Invalid password",
-		))
+		)
 	}
 
 	return nil
 }
 
-func (T *Client) accept() error {
+func (T *Client) accept() perror.Error {
 	for {
 		done, err := T.startup0()
 		if err != nil {
@@ -311,9 +321,9 @@ func (T *Client) accept() error {
 	}
 
 	// TODO(garet) don't hardcode username and password
-	err := T.authenticationSASL("test", "password")
-	if err != nil {
-		return err
+	perr := T.authenticationSASL("test", "password")
+	if perr != nil {
+		return perr
 	}
 
 	// send auth ok
@@ -321,15 +331,15 @@ func (T *Client) accept() error {
 	builder.Type(packet.Authentication)
 	builder.Uint32(0)
 
-	err = T.Write(builder.Raw())
+	err := T.Write(builder.Raw())
 	if err != nil {
-		return err
+		return WrapError(err)
 	}
 
 	// send backend key data
 	_, err = rand.Read(T.cancellationKey[:])
 	if err != nil {
-		return err
+		return WrapError(err)
 	}
 	builder = packet.Builder{}
 	builder.Type(packet.BackendKeyData)
@@ -337,7 +347,7 @@ func (T *Client) accept() error {
 
 	err = T.Write(builder.Raw())
 	if err != nil {
-		return err
+		return WrapError(err)
 	}
 
 	// send ready for query
@@ -347,33 +357,36 @@ func (T *Client) accept() error {
 
 	err = T.Write(builder.Raw())
 	if err != nil {
-		return err
+		return WrapError(err)
 	}
 
 	return nil
 }
 
-func (T *Client) Error(err perror.Error) error {
-	var builder packet.Builder
-	builder.Type(packet.ErrorResponse)
+func (T *Client) Close(err perror.Error) {
+	if err != nil {
+		var builder packet.Builder
+		builder.Type(packet.ErrorResponse)
 
-	builder.Uint8('S')
-	builder.String(string(err.Severity()))
+		builder.Uint8('S')
+		builder.String(string(err.Severity()))
 
-	builder.Uint8('C')
-	builder.String(string(err.Code()))
+		builder.Uint8('C')
+		builder.String(string(err.Code()))
 
-	builder.Uint8('M')
-	builder.String(err.Message())
+		builder.Uint8('M')
+		builder.String(err.Message())
 
-	for _, field := range err.Extra() {
-		builder.Uint8(uint8(field.Type))
-		builder.String(field.Value)
-	}
+		for _, field := range err.Extra() {
+			builder.Uint8(uint8(field.Type))
+			builder.String(field.Value)
+		}
 
-	builder.Uint8(0)
+		builder.Uint8(0)
 
-	return T.Write(builder.Raw())
+		_ = T.Write(builder.Raw())
+	}
+	_ = T.conn.Close()
 }
 
 var _ frontend.Client = (*Client)(nil)
diff --git a/lib/frontend/frontends/v0/frontend.go b/lib/frontend/frontends/v0/frontend.go
index ac982a73..6373a4df 100644
--- a/lib/frontend/frontends/v0/frontend.go
+++ b/lib/frontend/frontends/v0/frontend.go
@@ -1,7 +1,6 @@
 package frontends
 
 import (
-	"log"
 	"net"
 
 	"pggat2/lib/frontend"
@@ -22,6 +21,13 @@ func NewFrontend() (*Frontend, error) {
 	}, nil
 }
 
+func (T *Frontend) accept(conn net.Conn) {
+	client := NewClient(conn)
+	if client != nil {
+		T.clients = append(T.clients, client)
+	}
+}
+
 func (T *Frontend) Run() error {
 	for {
 		conn, err := T.listener.Accept()
@@ -29,12 +35,7 @@ func (T *Frontend) Run() error {
 			return err
 		}
 
-		client, err := NewClient(conn)
-		if err != nil {
-			log.Println("rejected client:", err)
-		} else {
-			T.clients = append(T.clients, client)
-		}
+		go T.accept(conn)
 	}
 }
 
-- 
GitLab