From c95cf4050f60f362d774e63bdc06bea9f5eb3b8d Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Wed, 3 May 2023 14:29:49 -0500
Subject: [PATCH] backend ready

---
 go.mod                              |   8 +-
 go.sum                              |  31 ++++
 lib/auth/{ => md5}/md5.go           |   8 +-
 lib/auth/{ => md5}/md5_test.go      |   6 +-
 lib/auth/sasl/client.go             |  26 ++++
 lib/auth/sasl/scram/client.go       |  59 ++++++++
 lib/auth/sasl/scram/server.go       |   1 +
 lib/backend/backends/v0/server.go   | 211 ++++++++++++++++++++++++----
 lib/frontend/frontends/v0/client.go |   4 +-
 lib/pnet/packet/reader.go           |   4 +
 10 files changed, 321 insertions(+), 37 deletions(-)
 rename lib/auth/{ => md5}/md5.go (74%)
 rename lib/auth/{ => md5}/md5_test.go (88%)
 create mode 100644 lib/auth/sasl/client.go
 create mode 100644 lib/auth/sasl/scram/client.go
 create mode 100644 lib/auth/sasl/scram/server.go

diff --git a/go.mod b/go.mod
index 3f39040e..d07289d0 100644
--- a/go.mod
+++ b/go.mod
@@ -2,4 +2,10 @@ module pggat2
 
 go 1.20
 
-require github.com/google/uuid v1.3.0 // indirect
+require (
+	github.com/google/uuid v1.3.0 // indirect
+	github.com/xdg-go/pbkdf2 v1.0.0 // indirect
+	github.com/xdg-go/scram v1.1.2 // indirect
+	github.com/xdg-go/stringprep v1.0.4 // indirect
+	golang.org/x/text v0.3.8 // indirect
+)
diff --git a/go.sum b/go.sum
index 3dfe1c9f..4831f69c 100644
--- a/go.sum
+++ b/go.sum
@@ -1,2 +1,33 @@
 github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
 github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c=
+github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
+github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY=
+github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4=
+github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8=
+github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM=
+github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
+golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
+golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
+golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
+golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
+golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
+golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
+golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
+golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
+golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
+golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
+golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
+golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
+golang.org/x/text v0.3.8 h1:nAL+RVCQ9uMn3vJZbV+MRnydTJFPf8qqY42YiA6MrqY=
+golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
+golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
+golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
+golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
diff --git a/lib/auth/md5.go b/lib/auth/md5/md5.go
similarity index 74%
rename from lib/auth/md5.go
rename to lib/auth/md5/md5.go
index 5ec1f581..a721bd5c 100644
--- a/lib/auth/md5.go
+++ b/lib/auth/md5/md5.go
@@ -1,4 +1,4 @@
-package auth
+package md5
 
 import (
 	"crypto/md5"
@@ -6,7 +6,7 @@ import (
 	"strings"
 )
 
-func EncodeMD5(username, password string, salt [4]byte) string {
+func Encode(username, password string, salt [4]byte) string {
 	hash := md5.New()
 	hash.Write([]byte(password))
 	hash.Write([]byte(username))
@@ -31,6 +31,6 @@ func EncodeMD5(username, password string, salt [4]byte) string {
 	return out.String()
 }
 
-func CheckMD5(username, password string, salt [4]byte, encoded string) bool {
-	return EncodeMD5(username, password, salt) == encoded
+func Check(username, password string, salt [4]byte, encoded string) bool {
+	return Encode(username, password, salt) == encoded
 }
diff --git a/lib/auth/md5_test.go b/lib/auth/md5/md5_test.go
similarity index 88%
rename from lib/auth/md5_test.go
rename to lib/auth/md5/md5_test.go
index c67f3274..16212d4a 100644
--- a/lib/auth/md5_test.go
+++ b/lib/auth/md5/md5_test.go
@@ -1,4 +1,4 @@
-package auth
+package md5
 
 import "testing"
 
@@ -38,7 +38,7 @@ var Cases = []TestCase{
 
 func TestEncodeMD5(t *testing.T) {
 	for _, c := range Cases {
-		encoded := EncodeMD5(c.Username, c.Password, c.Salt)
+		encoded := Encode(c.Username, c.Password, c.Salt)
 		if encoded != c.Encoded {
 			t.Error("encoding failed! expected", c.Encoded, "but got", encoded)
 		}
@@ -47,7 +47,7 @@ func TestEncodeMD5(t *testing.T) {
 
 func TestCheckMD5(t *testing.T) {
 	for _, c := range Cases {
-		if !CheckMD5(c.Username, c.Password, c.Salt, c.Encoded) {
+		if !Check(c.Username, c.Password, c.Salt, c.Encoded) {
 			t.Error("check failed!")
 		}
 	}
diff --git a/lib/auth/sasl/client.go b/lib/auth/sasl/client.go
new file mode 100644
index 00000000..211fb7f6
--- /dev/null
+++ b/lib/auth/sasl/client.go
@@ -0,0 +1,26 @@
+package sasl
+
+import (
+	"errors"
+
+	"pggat2/lib/auth/sasl/scram"
+)
+
+var ErrMechanismsNotSupported = errors.New("SASL mechanisms not supported")
+
+type Client interface {
+	Name() string
+	InitialResponse() []byte
+	Continue([]byte) ([]byte, error)
+	Final([]byte) error
+}
+
+func NewClient(mechanisms []string, username, password string) (Client, error) {
+	for _, mechanism := range mechanisms {
+		switch mechanism {
+		case scram.SHA256:
+			return scram.NewClient(mechanism, username, password)
+		}
+	}
+	return nil, ErrMechanismsNotSupported
+}
diff --git a/lib/auth/sasl/scram/client.go b/lib/auth/sasl/scram/client.go
new file mode 100644
index 00000000..1f7e2a28
--- /dev/null
+++ b/lib/auth/sasl/scram/client.go
@@ -0,0 +1,59 @@
+package scram
+
+import (
+	"errors"
+
+	"github.com/xdg-go/scram"
+)
+
+var ErrUnsupportedMethod = errors.New("unsupported SCRAM method")
+
+const (
+	SHA256 = "SCRAM-SHA-256"
+)
+
+type Client struct {
+	name         string
+	conversation *scram.ClientConversation
+}
+
+func NewClient(method, username, password string) (*Client, error) {
+	var client *scram.Client
+
+	switch method {
+	case SHA256:
+		var err error
+		client, err = scram.SHA256.NewClient(username, password, "")
+		if err != nil {
+			return nil, err
+		}
+	default:
+		return nil, ErrUnsupportedMethod
+	}
+
+	return &Client{
+		name:         method,
+		conversation: client.NewConversation(),
+	}, nil
+}
+
+func (T *Client) Name() string {
+	return T.name
+}
+
+func (T *Client) InitialResponse() []byte {
+	return nil
+}
+
+func (T *Client) Continue(bytes []byte) ([]byte, error) {
+	msg, err := T.conversation.Step(string(bytes))
+	if err != nil {
+		return nil, err
+	}
+	return []byte(msg), nil
+}
+
+func (T *Client) Final(bytes []byte) error {
+	_, err := T.Continue(bytes)
+	return err
+}
diff --git a/lib/auth/sasl/scram/server.go b/lib/auth/sasl/scram/server.go
new file mode 100644
index 00000000..3fbc1d27
--- /dev/null
+++ b/lib/auth/sasl/scram/server.go
@@ -0,0 +1 @@
+package scram
diff --git a/lib/backend/backends/v0/server.go b/lib/backend/backends/v0/server.go
index 7db35173..f07d7e46 100644
--- a/lib/backend/backends/v0/server.go
+++ b/lib/backend/backends/v0/server.go
@@ -4,23 +4,31 @@ import (
 	"errors"
 	"net"
 
+	"pggat2/lib/auth/sasl"
 	"pggat2/lib/backend"
 	"pggat2/lib/pnet"
 	"pggat2/lib/pnet/packet"
 )
 
+var ErrBadPacketFormat = errors.New("bad packet format")
+var ErrProtocolError = errors.New("server sent unexpected packet")
+
 type Server struct {
 	conn net.Conn
 
 	pnet.Reader
 	pnet.Writer
+
+	cancellationKey [8]byte
+	parameters      map[string]string
 }
 
 func NewServer(conn net.Conn) (*Server, error) {
 	server := &Server{
-		conn:   conn,
-		Reader: pnet.MakeReader(conn),
-		Writer: pnet.MakeWriter(conn),
+		conn:       conn,
+		Reader:     pnet.MakeReader(conn),
+		Writer:     pnet.MakeWriter(conn),
+		parameters: make(map[string]string),
 	}
 	err := server.accept()
 	if err != nil {
@@ -29,58 +37,207 @@ func NewServer(conn net.Conn) (*Server, error) {
 	return server, nil
 }
 
-func (T *Server) accept() error {
-	var builder packet.Builder
-	builder.Int16(3)
-	builder.Int16(0)
-	builder.String("user")
-	builder.String("postgres")
-	builder.String("")
-
-	err := T.WriteUntyped(builder.Raw())
+func (T *Server) authenticationSASL(mechanisms []string) error {
+	mechanism, err := sasl.NewClient(mechanisms, "test", "password")
 	if err != nil {
 		return err
 	}
 
-	auth, err := T.Read()
+	builder := packet.Builder{}
+	builder.Type(packet.AuthenticationResponse)
+	builder.String(mechanism.Name())
+	initialResponse := mechanism.InitialResponse()
+	if initialResponse == nil {
+		builder.Int32(-1)
+	} else {
+		builder.Int32(int32(len(initialResponse)))
+		builder.Bytes(initialResponse)
+	}
+	err = T.Write(builder.Raw())
 	if err != nil {
 		return err
 	}
 
-	reader := packet.MakeReader(auth)
+	// challenge loop
+outer:
+	for {
+		challenge, err := T.Read()
+		if err != nil {
+			return err
+		}
+
+		reader := packet.MakeReader(challenge)
+		if reader.Type() != packet.Authentication {
+			return ErrProtocolError
+		}
+
+		method, ok := reader.Int32()
+		if !ok {
+			return ErrBadPacketFormat
+		}
+
+		switch method {
+		case 11:
+			// challenge
+			response, err := mechanism.Continue(reader.Remaining())
+			if err != nil {
+				return err
+			}
+
+			builder = packet.Builder{}
+			builder.Type(packet.AuthenticationResponse)
+			builder.Bytes(response)
+
+			err = T.Write(builder.Raw())
+			if err != nil {
+				return err
+			}
+		case 12:
+			// finish
+			err = mechanism.Final(reader.Remaining())
+			if err != nil {
+				return err
+			}
+
+			break outer
+		default:
+			return ErrProtocolError
+		}
+	}
+
+	return nil
+}
+
+func (T *Server) startup0() (bool, error) {
+	pkt, err := T.Read()
+	if err != nil {
+		return false, err
+	}
+
+	reader := packet.MakeReader(pkt)
 	switch reader.Type() {
+	case packet.ErrorResponse:
+		return false, errors.New("received error response")
 	case packet.Authentication:
 		method, ok := reader.Int32()
 		if !ok {
-			return errors.New("expected authentication method")
+			return false, ErrBadPacketFormat
 		}
 		// they have more authentication methods than there are pokemon
 		switch method {
 		case 0:
 			// we're good to go, that was easy
+			return true, nil
 		case 2:
-			return errors.New("kerberos v5 is not supported")
+			return false, errors.New("kerberos v5 is not supported")
 		case 3:
-			return errors.New("cleartext is not supported")
+			return false, errors.New("cleartext is not supported")
 		case 5:
-			return errors.New("md5 password is not supported")
+			return false, errors.New("md5 password is not supported")
 		case 6:
-			return errors.New("scm credential is not supported")
+			return false, errors.New("scm credential is not supported")
 		case 7:
-			return errors.New("gss is not supported")
+			return false, errors.New("gss is not supported")
 		case 9:
-			return errors.New("sspi is not supported")
+			return false, errors.New("sspi is not supported")
 		case 10:
-			return errors.New("sasl is not supported")
+			// read list of mechanisms
+			var mechanisms []string
+			for {
+				mechanism, ok := reader.String()
+				if !ok {
+					return false, ErrBadPacketFormat
+				}
+				if mechanism == "" {
+					break
+				}
+				mechanisms = append(mechanisms, mechanism)
+			}
+
+			return false, T.authenticationSASL(mechanisms)
 		default:
-			return errors.New("unknown authentication method")
+			// we only support protocol 3.0 for now
+			return false, errors.New("unknown authentication method")
 		}
-	case packet.ErrorResponse:
-		return errors.New("backend errored")
 	case packet.NegotiateProtocolVersion:
-		// we only support 3.0 as of now
-		return errors.New("unsupported protocol version")
+		return false, errors.New("server wanted to negotiate protocol version")
+	default:
+		return false, ErrProtocolError
 	}
+}
+
+func (T *Server) startup1() (bool, error) {
+	pkt, err := T.Read()
+	if err != nil {
+		return false, err
+	}
+
+	reader := packet.MakeReader(pkt)
+	switch reader.Type() {
+	case packet.BackendKeyData:
+		cancellationKey, ok := reader.Bytes(8)
+		if !ok {
+			return false, ErrBadPacketFormat
+		}
+		copy(T.cancellationKey[:], cancellationKey)
+		return false, nil
+	case packet.ParameterStatus:
+		parameter, ok := reader.String()
+		if !ok {
+			return false, ErrBadPacketFormat
+		}
+		value, ok := reader.String()
+		if !ok {
+			return false, ErrBadPacketFormat
+		}
+		T.parameters[parameter] = value
+		return false, nil
+	case packet.ReadyForQuery:
+		return true, nil
+	case packet.ErrorResponse:
+		return false, errors.New("received error response")
+	case packet.NoticeResponse:
+		// TODO(garet) do something with notice
+		return false, nil
+	default:
+		return false, ErrProtocolError
+	}
+}
+
+func (T *Server) accept() error {
+	var builder packet.Builder
+	builder.Int16(3)
+	builder.Int16(0)
+	builder.String("user")
+	builder.String("postgres")
+	builder.String("")
+
+	err := T.WriteUntyped(builder.Raw())
+	if err != nil {
+		return err
+	}
+
+	for {
+		done, err := T.startup0()
+		if err != nil {
+			return err
+		}
+		if done {
+			break
+		}
+	}
+
+	for {
+		done, err := T.startup1()
+		if err != nil {
+			return err
+		}
+		if done {
+			break
+		}
+	}
+
+	// startup complete, connection is ready for queries
 
 	return nil
 }
diff --git a/lib/frontend/frontends/v0/client.go b/lib/frontend/frontends/v0/client.go
index 1ab21a57..8213b61d 100644
--- a/lib/frontend/frontends/v0/client.go
+++ b/lib/frontend/frontends/v0/client.go
@@ -4,7 +4,7 @@ import (
 	"crypto/rand"
 	"net"
 
-	"pggat2/lib/auth"
+	"pggat2/lib/auth/md5"
 	"pggat2/lib/frontend"
 	"pggat2/lib/perror"
 	"pggat2/lib/pnet"
@@ -215,7 +215,7 @@ func (T *Client) accept() error {
 		return T.Error(ErrBadPacketFormat)
 	}
 
-	if !auth.CheckMD5("test", "password", salt, pw) {
+	if !md5.Check("test", "password", salt, pw) {
 		return T.Error(perror.New(
 			perror.FATAL,
 			perror.InvalidPassword,
diff --git a/lib/pnet/packet/reader.go b/lib/pnet/packet/reader.go
index fe85f394..bbb34a7d 100644
--- a/lib/pnet/packet/reader.go
+++ b/lib/pnet/packet/reader.go
@@ -107,3 +107,7 @@ func (T *Reader) Bytes(count int) ([]byte, bool) {
 	T.bytes = T.bytes[count:]
 	return v, true
 }
+
+func (T *Reader) Remaining() []byte {
+	return T.bytes
+}
-- 
GitLab