From 88e26e820c81b5818d567d393480bcce2e9a48fb Mon Sep 17 00:00:00 2001 From: Garet Halliday <me@garet.holiday> Date: Wed, 3 May 2023 16:33:48 -0500 Subject: [PATCH] clean up errors and disconnection --- lib/frontend/frontends/v0/client.go | 155 ++++++++++++++------------ lib/frontend/frontends/v0/frontend.go | 15 +-- 2 files changed, 92 insertions(+), 78 deletions(-) diff --git a/lib/frontend/frontends/v0/client.go b/lib/frontend/frontends/v0/client.go index 54ea4272..3b0a259e 100644 --- a/lib/frontend/frontends/v0/client.go +++ b/lib/frontend/frontends/v0/client.go @@ -24,6 +24,17 @@ var ErrProtocolError = perror.New( "Expected a different packet", ) +func WrapError(err error) perror.Error { + if err == nil { + return nil + } + return perror.New( + perror.FATAL, + perror.InternalError, + err.Error(), + ) +} + type Client struct { conn net.Conn @@ -37,7 +48,7 @@ type Client struct { cancellationKey [8]byte } -func NewClient(conn net.Conn) (*Client, error) { +func NewClient(conn net.Conn) *Client { client := &Client{ conn: conn, Reader: pnet.MakeReader(conn), @@ -45,25 +56,26 @@ func NewClient(conn net.Conn) (*Client, error) { } err := client.accept() if err != nil { - return nil, err + client.Close(err) + return nil } - return client, nil + return client } -func (T *Client) startup0() (bool, error) { +func (T *Client) startup0() (bool, perror.Error) { startup, err := T.ReadUntyped() if err != nil { - return false, err + return false, WrapError(err) } reader := packet.MakeReader(startup) majorVersion, ok := reader.Uint16() if !ok { - return false, T.Error(ErrBadPacketFormat) + return false, ErrBadPacketFormat } minorVersion, ok := reader.Uint16() if !ok { - return false, T.Error(ErrBadPacketFormat) + return false, ErrBadPacketFormat } if majorVersion == 1234 { @@ -71,36 +83,34 @@ func (T *Client) startup0() (bool, error) { switch minorVersion { case 5678: // Cancel - err = T.Error(perror.New( + return false, perror.New( perror.FATAL, perror.FeatureNotSupported, "Cancel is not supported yet", - )) - return false, err + ) case 5679: // SSL is not supported yet err = T.WriteByte('N') - return false, err + return false, WrapError(err) case 5680: // GSSAPI is not supported yet err = T.WriteByte('N') - return false, err + return false, WrapError(err) default: - err = T.Error(perror.New( + return false, perror.New( perror.FATAL, perror.ProtocolViolation, "Unknown request code", - )) - return false, err + ) } } if majorVersion != 3 { - err = T.Error(perror.New( + return false, perror.New( perror.FATAL, perror.ProtocolViolation, "Unsupported protocol version", - )) + ) } var unsupportedOptions []string @@ -108,7 +118,7 @@ func (T *Client) startup0() (bool, error) { for { key, ok := reader.String() if !ok { - return false, T.Error(ErrBadPacketFormat) + return false, ErrBadPacketFormat } if key == "" { break @@ -116,7 +126,7 @@ func (T *Client) startup0() (bool, error) { value, ok := reader.String() if !ok { - return false, T.Error(ErrBadPacketFormat) + return false, ErrBadPacketFormat } switch key { @@ -125,17 +135,17 @@ func (T *Client) startup0() (bool, error) { case "database": T.database = value case "options": - return false, T.Error(perror.New( + return false, perror.New( perror.FATAL, perror.FeatureNotSupported, "Startup options are not supported yet", - )) + ) case "replication": - return false, T.Error(perror.New( + return false, perror.New( perror.FATAL, perror.FeatureNotSupported, "Replication mode is not supported yet", - )) + ) default: unsupportedOptions = append(unsupportedOptions, key) } @@ -153,16 +163,16 @@ func (T *Client) startup0() (bool, error) { err = T.Write(builder.Raw()) if err != nil { - return false, err + return false, WrapError(err) } } if T.user == "" { - return false, T.Error(perror.New( + return false, perror.New( perror.FATAL, perror.InvalidAuthorizationSpecification, "User is required", - )) + ) } if T.database == "" { T.database = T.user @@ -171,7 +181,7 @@ func (T *Client) startup0() (bool, error) { return true, nil } -func (T *Client) authenticationSASL(username, password string) error { +func (T *Client) authenticationSASL(username, password string) perror.Error { var builder packet.Builder builder.Type(packet.Authentication) builder.Int32(10) @@ -182,37 +192,37 @@ func (T *Client) authenticationSASL(username, password string) error { err := T.Write(builder.Raw()) if err != nil { - return err + return WrapError(err) } // check which authentication method the client wants pkt, err := T.Read() if err != nil { - return err + return WrapError(err) } if pkt.Type != packet.AuthenticationResponse { - return T.Error(ErrBadPacketFormat) + return ErrBadPacketFormat } reader := packet.MakeReader(pkt) mechanism, ok := reader.String() if !ok { - return T.Error(ErrBadPacketFormat) + return ErrBadPacketFormat } tool, err := sasl.NewServer(mechanism, username, password) if err != nil { - return err + return WrapError(err) } _, ok = reader.Int32() if !ok { - return T.Error(ErrBadPacketFormat) + return ErrBadPacketFormat } resp, done, err := tool.InitialResponse(reader.Remaining()) for { if err != nil { - return err + return WrapError(err) } if done { builder = packet.Builder{} @@ -221,7 +231,7 @@ func (T *Client) authenticationSASL(username, password string) error { builder.Bytes(resp) err = T.Write(builder.Raw()) if err != nil { - return err + return WrapError(err) } break } else { @@ -231,16 +241,16 @@ func (T *Client) authenticationSASL(username, password string) error { builder.Bytes(resp) err = T.Write(builder.Raw()) if err != nil { - return err + return WrapError(err) } } pkt, err = T.Read() if err != nil { - return err + return WrapError(err) } if pkt.Type != packet.AuthenticationResponse { - return T.Error(ErrProtocolError) + return ErrProtocolError } resp, done, err = tool.Continue(pkt.Payload) @@ -249,11 +259,11 @@ func (T *Client) authenticationSASL(username, password string) error { return nil } -func (T *Client) authenticationMD5(username, password string) error { +func (T *Client) authenticationMD5(username, password string) perror.Error { var salt [4]byte _, err := rand.Read(salt[:]) if err != nil { - return err + return WrapError(err) } // password time @@ -265,41 +275,41 @@ func (T *Client) authenticationMD5(username, password string) error { err = T.Write(builder.Raw()) if err != nil { - return err + return WrapError(err) } // read password pkt, err := T.Read() if err != nil { - return err + return WrapError(err) } reader := packet.MakeReader(pkt) if reader.Type() != packet.AuthenticationResponse { - return T.Error(perror.New( + return perror.New( perror.FATAL, perror.ProtocolViolation, "Expected password", - )) + ) } pw, ok := reader.String() if !ok { - return T.Error(ErrBadPacketFormat) + return ErrBadPacketFormat } if !md5.Check(username, password, salt, pw) { - return T.Error(perror.New( + return perror.New( perror.FATAL, perror.InvalidPassword, "Invalid password", - )) + ) } return nil } -func (T *Client) accept() error { +func (T *Client) accept() perror.Error { for { done, err := T.startup0() if err != nil { @@ -311,9 +321,9 @@ func (T *Client) accept() error { } // TODO(garet) don't hardcode username and password - err := T.authenticationSASL("test", "password") - if err != nil { - return err + perr := T.authenticationSASL("test", "password") + if perr != nil { + return perr } // send auth ok @@ -321,15 +331,15 @@ func (T *Client) accept() error { builder.Type(packet.Authentication) builder.Uint32(0) - err = T.Write(builder.Raw()) + err := T.Write(builder.Raw()) if err != nil { - return err + return WrapError(err) } // send backend key data _, err = rand.Read(T.cancellationKey[:]) if err != nil { - return err + return WrapError(err) } builder = packet.Builder{} builder.Type(packet.BackendKeyData) @@ -337,7 +347,7 @@ func (T *Client) accept() error { err = T.Write(builder.Raw()) if err != nil { - return err + return WrapError(err) } // send ready for query @@ -347,33 +357,36 @@ func (T *Client) accept() error { err = T.Write(builder.Raw()) if err != nil { - return err + return WrapError(err) } return nil } -func (T *Client) Error(err perror.Error) error { - var builder packet.Builder - builder.Type(packet.ErrorResponse) +func (T *Client) Close(err perror.Error) { + if err != nil { + var builder packet.Builder + builder.Type(packet.ErrorResponse) - builder.Uint8('S') - builder.String(string(err.Severity())) + builder.Uint8('S') + builder.String(string(err.Severity())) - builder.Uint8('C') - builder.String(string(err.Code())) + builder.Uint8('C') + builder.String(string(err.Code())) - builder.Uint8('M') - builder.String(err.Message()) + builder.Uint8('M') + builder.String(err.Message()) - for _, field := range err.Extra() { - builder.Uint8(uint8(field.Type)) - builder.String(field.Value) - } + for _, field := range err.Extra() { + builder.Uint8(uint8(field.Type)) + builder.String(field.Value) + } - builder.Uint8(0) + builder.Uint8(0) - return T.Write(builder.Raw()) + _ = T.Write(builder.Raw()) + } + _ = T.conn.Close() } var _ frontend.Client = (*Client)(nil) diff --git a/lib/frontend/frontends/v0/frontend.go b/lib/frontend/frontends/v0/frontend.go index ac982a73..6373a4df 100644 --- a/lib/frontend/frontends/v0/frontend.go +++ b/lib/frontend/frontends/v0/frontend.go @@ -1,7 +1,6 @@ package frontends import ( - "log" "net" "pggat2/lib/frontend" @@ -22,6 +21,13 @@ func NewFrontend() (*Frontend, error) { }, nil } +func (T *Frontend) accept(conn net.Conn) { + client := NewClient(conn) + if client != nil { + T.clients = append(T.clients, client) + } +} + func (T *Frontend) Run() error { for { conn, err := T.listener.Accept() @@ -29,12 +35,7 @@ func (T *Frontend) Run() error { return err } - client, err := NewClient(conn) - if err != nil { - log.Println("rejected client:", err) - } else { - T.clients = append(T.clients, client) - } + go T.accept(conn) } } -- GitLab