From c95cf4050f60f362d774e63bdc06bea9f5eb3b8d Mon Sep 17 00:00:00 2001 From: Garet Halliday <me@garet.holiday> Date: Wed, 3 May 2023 14:29:49 -0500 Subject: [PATCH] backend ready --- go.mod | 8 +- go.sum | 31 ++++ lib/auth/{ => md5}/md5.go | 8 +- lib/auth/{ => md5}/md5_test.go | 6 +- lib/auth/sasl/client.go | 26 ++++ lib/auth/sasl/scram/client.go | 59 ++++++++ lib/auth/sasl/scram/server.go | 1 + lib/backend/backends/v0/server.go | 211 ++++++++++++++++++++++++---- lib/frontend/frontends/v0/client.go | 4 +- lib/pnet/packet/reader.go | 4 + 10 files changed, 321 insertions(+), 37 deletions(-) rename lib/auth/{ => md5}/md5.go (74%) rename lib/auth/{ => md5}/md5_test.go (88%) create mode 100644 lib/auth/sasl/client.go create mode 100644 lib/auth/sasl/scram/client.go create mode 100644 lib/auth/sasl/scram/server.go diff --git a/go.mod b/go.mod index 3f39040e..d07289d0 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,10 @@ module pggat2 go 1.20 -require github.com/google/uuid v1.3.0 // indirect +require ( + github.com/google/uuid v1.3.0 // indirect + github.com/xdg-go/pbkdf2 v1.0.0 // indirect + github.com/xdg-go/scram v1.1.2 // indirect + github.com/xdg-go/stringprep v1.0.4 // indirect + golang.org/x/text v0.3.8 // indirect +) diff --git a/go.sum b/go.sum index 3dfe1c9f..4831f69c 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,33 @@ github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= +github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4= +github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= +github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/lib/auth/md5.go b/lib/auth/md5/md5.go similarity index 74% rename from lib/auth/md5.go rename to lib/auth/md5/md5.go index 5ec1f581..a721bd5c 100644 --- a/lib/auth/md5.go +++ b/lib/auth/md5/md5.go @@ -1,4 +1,4 @@ -package auth +package md5 import ( "crypto/md5" @@ -6,7 +6,7 @@ import ( "strings" ) -func EncodeMD5(username, password string, salt [4]byte) string { +func Encode(username, password string, salt [4]byte) string { hash := md5.New() hash.Write([]byte(password)) hash.Write([]byte(username)) @@ -31,6 +31,6 @@ func EncodeMD5(username, password string, salt [4]byte) string { return out.String() } -func CheckMD5(username, password string, salt [4]byte, encoded string) bool { - return EncodeMD5(username, password, salt) == encoded +func Check(username, password string, salt [4]byte, encoded string) bool { + return Encode(username, password, salt) == encoded } diff --git a/lib/auth/md5_test.go b/lib/auth/md5/md5_test.go similarity index 88% rename from lib/auth/md5_test.go rename to lib/auth/md5/md5_test.go index c67f3274..16212d4a 100644 --- a/lib/auth/md5_test.go +++ b/lib/auth/md5/md5_test.go @@ -1,4 +1,4 @@ -package auth +package md5 import "testing" @@ -38,7 +38,7 @@ var Cases = []TestCase{ func TestEncodeMD5(t *testing.T) { for _, c := range Cases { - encoded := EncodeMD5(c.Username, c.Password, c.Salt) + encoded := Encode(c.Username, c.Password, c.Salt) if encoded != c.Encoded { t.Error("encoding failed! expected", c.Encoded, "but got", encoded) } @@ -47,7 +47,7 @@ func TestEncodeMD5(t *testing.T) { func TestCheckMD5(t *testing.T) { for _, c := range Cases { - if !CheckMD5(c.Username, c.Password, c.Salt, c.Encoded) { + if !Check(c.Username, c.Password, c.Salt, c.Encoded) { t.Error("check failed!") } } diff --git a/lib/auth/sasl/client.go b/lib/auth/sasl/client.go new file mode 100644 index 00000000..211fb7f6 --- /dev/null +++ b/lib/auth/sasl/client.go @@ -0,0 +1,26 @@ +package sasl + +import ( + "errors" + + "pggat2/lib/auth/sasl/scram" +) + +var ErrMechanismsNotSupported = errors.New("SASL mechanisms not supported") + +type Client interface { + Name() string + InitialResponse() []byte + Continue([]byte) ([]byte, error) + Final([]byte) error +} + +func NewClient(mechanisms []string, username, password string) (Client, error) { + for _, mechanism := range mechanisms { + switch mechanism { + case scram.SHA256: + return scram.NewClient(mechanism, username, password) + } + } + return nil, ErrMechanismsNotSupported +} diff --git a/lib/auth/sasl/scram/client.go b/lib/auth/sasl/scram/client.go new file mode 100644 index 00000000..1f7e2a28 --- /dev/null +++ b/lib/auth/sasl/scram/client.go @@ -0,0 +1,59 @@ +package scram + +import ( + "errors" + + "github.com/xdg-go/scram" +) + +var ErrUnsupportedMethod = errors.New("unsupported SCRAM method") + +const ( + SHA256 = "SCRAM-SHA-256" +) + +type Client struct { + name string + conversation *scram.ClientConversation +} + +func NewClient(method, username, password string) (*Client, error) { + var client *scram.Client + + switch method { + case SHA256: + var err error + client, err = scram.SHA256.NewClient(username, password, "") + if err != nil { + return nil, err + } + default: + return nil, ErrUnsupportedMethod + } + + return &Client{ + name: method, + conversation: client.NewConversation(), + }, nil +} + +func (T *Client) Name() string { + return T.name +} + +func (T *Client) InitialResponse() []byte { + return nil +} + +func (T *Client) Continue(bytes []byte) ([]byte, error) { + msg, err := T.conversation.Step(string(bytes)) + if err != nil { + return nil, err + } + return []byte(msg), nil +} + +func (T *Client) Final(bytes []byte) error { + _, err := T.Continue(bytes) + return err +} diff --git a/lib/auth/sasl/scram/server.go b/lib/auth/sasl/scram/server.go new file mode 100644 index 00000000..3fbc1d27 --- /dev/null +++ b/lib/auth/sasl/scram/server.go @@ -0,0 +1 @@ +package scram diff --git a/lib/backend/backends/v0/server.go b/lib/backend/backends/v0/server.go index 7db35173..f07d7e46 100644 --- a/lib/backend/backends/v0/server.go +++ b/lib/backend/backends/v0/server.go @@ -4,23 +4,31 @@ import ( "errors" "net" + "pggat2/lib/auth/sasl" "pggat2/lib/backend" "pggat2/lib/pnet" "pggat2/lib/pnet/packet" ) +var ErrBadPacketFormat = errors.New("bad packet format") +var ErrProtocolError = errors.New("server sent unexpected packet") + type Server struct { conn net.Conn pnet.Reader pnet.Writer + + cancellationKey [8]byte + parameters map[string]string } func NewServer(conn net.Conn) (*Server, error) { server := &Server{ - conn: conn, - Reader: pnet.MakeReader(conn), - Writer: pnet.MakeWriter(conn), + conn: conn, + Reader: pnet.MakeReader(conn), + Writer: pnet.MakeWriter(conn), + parameters: make(map[string]string), } err := server.accept() if err != nil { @@ -29,58 +37,207 @@ func NewServer(conn net.Conn) (*Server, error) { return server, nil } -func (T *Server) accept() error { - var builder packet.Builder - builder.Int16(3) - builder.Int16(0) - builder.String("user") - builder.String("postgres") - builder.String("") - - err := T.WriteUntyped(builder.Raw()) +func (T *Server) authenticationSASL(mechanisms []string) error { + mechanism, err := sasl.NewClient(mechanisms, "test", "password") if err != nil { return err } - auth, err := T.Read() + builder := packet.Builder{} + builder.Type(packet.AuthenticationResponse) + builder.String(mechanism.Name()) + initialResponse := mechanism.InitialResponse() + if initialResponse == nil { + builder.Int32(-1) + } else { + builder.Int32(int32(len(initialResponse))) + builder.Bytes(initialResponse) + } + err = T.Write(builder.Raw()) if err != nil { return err } - reader := packet.MakeReader(auth) + // challenge loop +outer: + for { + challenge, err := T.Read() + if err != nil { + return err + } + + reader := packet.MakeReader(challenge) + if reader.Type() != packet.Authentication { + return ErrProtocolError + } + + method, ok := reader.Int32() + if !ok { + return ErrBadPacketFormat + } + + switch method { + case 11: + // challenge + response, err := mechanism.Continue(reader.Remaining()) + if err != nil { + return err + } + + builder = packet.Builder{} + builder.Type(packet.AuthenticationResponse) + builder.Bytes(response) + + err = T.Write(builder.Raw()) + if err != nil { + return err + } + case 12: + // finish + err = mechanism.Final(reader.Remaining()) + if err != nil { + return err + } + + break outer + default: + return ErrProtocolError + } + } + + return nil +} + +func (T *Server) startup0() (bool, error) { + pkt, err := T.Read() + if err != nil { + return false, err + } + + reader := packet.MakeReader(pkt) switch reader.Type() { + case packet.ErrorResponse: + return false, errors.New("received error response") case packet.Authentication: method, ok := reader.Int32() if !ok { - return errors.New("expected authentication method") + return false, ErrBadPacketFormat } // they have more authentication methods than there are pokemon switch method { case 0: // we're good to go, that was easy + return true, nil case 2: - return errors.New("kerberos v5 is not supported") + return false, errors.New("kerberos v5 is not supported") case 3: - return errors.New("cleartext is not supported") + return false, errors.New("cleartext is not supported") case 5: - return errors.New("md5 password is not supported") + return false, errors.New("md5 password is not supported") case 6: - return errors.New("scm credential is not supported") + return false, errors.New("scm credential is not supported") case 7: - return errors.New("gss is not supported") + return false, errors.New("gss is not supported") case 9: - return errors.New("sspi is not supported") + return false, errors.New("sspi is not supported") case 10: - return errors.New("sasl is not supported") + // read list of mechanisms + var mechanisms []string + for { + mechanism, ok := reader.String() + if !ok { + return false, ErrBadPacketFormat + } + if mechanism == "" { + break + } + mechanisms = append(mechanisms, mechanism) + } + + return false, T.authenticationSASL(mechanisms) default: - return errors.New("unknown authentication method") + // we only support protocol 3.0 for now + return false, errors.New("unknown authentication method") } - case packet.ErrorResponse: - return errors.New("backend errored") case packet.NegotiateProtocolVersion: - // we only support 3.0 as of now - return errors.New("unsupported protocol version") + return false, errors.New("server wanted to negotiate protocol version") + default: + return false, ErrProtocolError } +} + +func (T *Server) startup1() (bool, error) { + pkt, err := T.Read() + if err != nil { + return false, err + } + + reader := packet.MakeReader(pkt) + switch reader.Type() { + case packet.BackendKeyData: + cancellationKey, ok := reader.Bytes(8) + if !ok { + return false, ErrBadPacketFormat + } + copy(T.cancellationKey[:], cancellationKey) + return false, nil + case packet.ParameterStatus: + parameter, ok := reader.String() + if !ok { + return false, ErrBadPacketFormat + } + value, ok := reader.String() + if !ok { + return false, ErrBadPacketFormat + } + T.parameters[parameter] = value + return false, nil + case packet.ReadyForQuery: + return true, nil + case packet.ErrorResponse: + return false, errors.New("received error response") + case packet.NoticeResponse: + // TODO(garet) do something with notice + return false, nil + default: + return false, ErrProtocolError + } +} + +func (T *Server) accept() error { + var builder packet.Builder + builder.Int16(3) + builder.Int16(0) + builder.String("user") + builder.String("postgres") + builder.String("") + + err := T.WriteUntyped(builder.Raw()) + if err != nil { + return err + } + + for { + done, err := T.startup0() + if err != nil { + return err + } + if done { + break + } + } + + for { + done, err := T.startup1() + if err != nil { + return err + } + if done { + break + } + } + + // startup complete, connection is ready for queries return nil } diff --git a/lib/frontend/frontends/v0/client.go b/lib/frontend/frontends/v0/client.go index 1ab21a57..8213b61d 100644 --- a/lib/frontend/frontends/v0/client.go +++ b/lib/frontend/frontends/v0/client.go @@ -4,7 +4,7 @@ import ( "crypto/rand" "net" - "pggat2/lib/auth" + "pggat2/lib/auth/md5" "pggat2/lib/frontend" "pggat2/lib/perror" "pggat2/lib/pnet" @@ -215,7 +215,7 @@ func (T *Client) accept() error { return T.Error(ErrBadPacketFormat) } - if !auth.CheckMD5("test", "password", salt, pw) { + if !md5.Check("test", "password", salt, pw) { return T.Error(perror.New( perror.FATAL, perror.InvalidPassword, diff --git a/lib/pnet/packet/reader.go b/lib/pnet/packet/reader.go index fe85f394..bbb34a7d 100644 --- a/lib/pnet/packet/reader.go +++ b/lib/pnet/packet/reader.go @@ -107,3 +107,7 @@ func (T *Reader) Bytes(count int) ([]byte, bool) { T.bytes = T.bytes[count:] return v, true } + +func (T *Reader) Remaining() []byte { + return T.bytes +} -- GitLab