diff --git a/cmd/pggat/pgbouncer.go b/cmd/pggat/pgbouncer.go index 4fe2f957a4c5b74a0040857b730b2c0720245d11..52dcf980bd7379b09a877e5ff1371d542debc612 100644 --- a/cmd/pggat/pgbouncer.go +++ b/cmd/pggat/pgbouncer.go @@ -6,7 +6,8 @@ import ( "github.com/caddyserver/caddy/v2" "github.com/caddyserver/caddy/v2/caddyconfig" - caddycmd "github.com/caddyserver/caddy/v2/cmd" + + caddycmd "gfx.cafe/gfx/pggat/cmd" "gfx.cafe/gfx/pggat/lib/gat" "gfx.cafe/gfx/pggat/lib/gat/handlers/pgbouncer" diff --git a/lib/bouncer/backends/v0/accept.go b/lib/bouncer/backends/v0/accept.go index 4f69c3d08faa014ae1b117eb10c89dc5ce403bd5..bdadd4efa5cb923eaf98c3d244f7c7d09ef28ff6 100644 --- a/lib/bouncer/backends/v0/accept.go +++ b/lib/bouncer/backends/v0/accept.go @@ -9,39 +9,43 @@ import ( "gfx.cafe/gfx/pggat/lib/bouncer" "gfx.cafe/gfx/pggat/lib/fed" packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" + "gfx.cafe/gfx/pggat/lib/perror" "gfx.cafe/gfx/pggat/lib/util/strutil" ) func authenticationSASLChallenge(ctx *acceptContext, encoder auth.SASLEncoder) (done bool, err error) { - ctx.Packet, err = ctx.Conn.ReadPacket(true, ctx.Packet) + var packet fed.Packet + packet, err = ctx.Conn.ReadPacket(true) if err != nil { return } - if ctx.Packet.Type() != packets.TypeAuthentication { + if packet.Type() != packets.TypeAuthentication { err = ErrUnexpectedPacket return } - var method int32 - p := ctx.Packet.ReadInt32(&method) + var p *packets.Authentication + p, err = fed.ToConcrete[*packets.Authentication](packet) + if err != nil { + return + } - switch method { - case 11: + switch mode := p.Mode.(type) { + case *packets.AuthenticationPayloadSASLContinue: // challenge var response []byte - response, err = encoder.Write(p) + response, err = encoder.Write(*mode) if err != nil { return } - resp := packets.AuthenticationResponse(response) - ctx.Packet = resp.IntoPacket(ctx.Packet) - err = ctx.Conn.WritePacket(ctx.Packet) + resp := packets.SASLResponse(response) + err = ctx.Conn.WritePacket(&resp) return - case 12: + case *packets.AuthenticationPayloadSASLFinal: // finish - _, err = encoder.Write(p) + _, err = encoder.Write(*mode) if err != io.EOF { if err == nil { err = errors.New("expected EOF") @@ -67,11 +71,10 @@ func authenticationSASL(ctx *acceptContext, mechanisms []string, creds auth.SASL } saslInitialResponse := packets.SASLInitialResponse{ - Mechanism: mechanism, - InitialResponse: initialResponse, + Mechanism: mechanism, + InitialClientResponse: initialResponse, } - ctx.Packet = saslInitialResponse.IntoPacket(ctx.Packet) - err = ctx.Conn.WritePacket(ctx.Packet) + err = ctx.Conn.WritePacket(&saslInitialResponse) if err != nil { return err } @@ -92,11 +95,8 @@ func authenticationSASL(ctx *acceptContext, mechanisms []string, creds auth.SASL } func authenticationMD5(ctx *acceptContext, salt [4]byte, creds auth.MD5Client) error { - pw := packets.PasswordMessage{ - Password: creds.EncodeMD5(salt), - } - ctx.Packet = pw.IntoPacket(ctx.Packet) - err := ctx.Conn.WritePacket(ctx.Packet) + pw := packets.PasswordMessage(creds.EncodeMD5(salt)) + err := ctx.Conn.WritePacket(&pw) if err != nil { return err } @@ -104,69 +104,54 @@ func authenticationMD5(ctx *acceptContext, salt [4]byte, creds auth.MD5Client) e } func authenticationCleartext(ctx *acceptContext, creds auth.CleartextClient) error { - pw := packets.PasswordMessage{ - Password: creds.EncodeCleartext(), - } - ctx.Packet = pw.IntoPacket(ctx.Packet) - err := ctx.Conn.WritePacket(ctx.Packet) + pw := packets.PasswordMessage(creds.EncodeCleartext()) + err := ctx.Conn.WritePacket(&pw) if err != nil { return err } return nil } -func authentication(ctx *acceptContext) (done bool, err error) { - var method int32 - ctx.Packet.ReadInt32(&method) +func authentication(ctx *acceptContext, p *packets.Authentication) (done bool, err error) { // they have more authentication methods than there are pokemon - switch method { - case 0: + switch mode := p.Mode.(type) { + case *packets.AuthenticationPayloadOk: // we're good to go, that was easy ctx.Conn.Authenticated = true return true, nil - case 2: + case *packets.AuthenticationPayloadKerberosV5: err = errors.New("kerberos v5 is not supported") return - case 3: + case *packets.AuthenticationPayloadCleartextPassword: c, ok := ctx.Options.Credentials.(auth.CleartextClient) if !ok { return false, auth.ErrMethodNotSupported } return false, authenticationCleartext(ctx, c) - case 5: - var md5 packets.AuthenticationMD5 - if !md5.ReadFromPacket(ctx.Packet) { - err = ErrBadFormat - return - } - + case *packets.AuthenticationPayloadMD5Password: c, ok := ctx.Options.Credentials.(auth.MD5Client) if !ok { return false, auth.ErrMethodNotSupported } - return false, authenticationMD5(ctx, md5.Salt, c) - case 6: - err = errors.New("scm credential is not supported") - return - case 7: + return false, authenticationMD5(ctx, *mode, c) + case *packets.AuthenticationPayloadGSS: err = errors.New("gss is not supported") return - case 9: + case *packets.AuthenticationPayloadSSPI: err = errors.New("sspi is not supported") return - case 10: - // read list of mechanisms - var sasl packets.AuthenticationSASL - if !sasl.ReadFrom(ctx.Packet) { - err = ErrBadFormat - return - } - + case *packets.AuthenticationPayloadSASL: c, ok := ctx.Options.Credentials.(auth.SASLClient) if !ok { return false, auth.ErrMethodNotSupported } - return false, authenticationSASL(ctx, sasl.Mechanisms, c) + + var mechanisms = make([]string, 0, len(*mode)) + for _, m := range *mode { + mechanisms = append(mechanisms, m.Method) + } + + return false, authenticationSASL(ctx, mechanisms, c) default: err = errors.New("unknown authentication method") return @@ -174,22 +159,28 @@ func authentication(ctx *acceptContext) (done bool, err error) { } func startup0(ctx *acceptContext) (done bool, err error) { - ctx.Packet, err = ctx.Conn.ReadPacket(true, ctx.Packet) + var packet fed.Packet + packet, err = ctx.Conn.ReadPacket(true) if err != nil { return } - switch ctx.Packet.Type() { + switch packet.Type() { case packets.TypeErrorResponse: - var err2 packets.ErrorResponse - if !err2.ReadFromPacket(ctx.Packet) { - err = ErrBadFormat - } else { - err = errors.New(err2.Error.String()) + var p *packets.ErrorResponse + p, err = fed.ToConcrete[*packets.ErrorResponse](packet) + if err != nil { + return } + err = perror.FromPacket(p) return case packets.TypeAuthentication: - return authentication(ctx) + var p *packets.Authentication + p, err = fed.ToConcrete[*packets.Authentication](packet) + if err != nil { + return + } + return authentication(ctx, p) case packets.TypeNegotiateProtocolVersion: // we only support protocol 3.0 for now err = errors.New("server wanted to negotiate protocol version") @@ -201,36 +192,44 @@ func startup0(ctx *acceptContext) (done bool, err error) { } func startup1(ctx *acceptContext) (done bool, err error) { - ctx.Packet, err = ctx.Conn.ReadPacket(true, ctx.Packet) + var packet fed.Packet + packet, err = ctx.Conn.ReadPacket(true) if err != nil { return } - switch ctx.Packet.Type() { + switch packet.Type() { case packets.TypeBackendKeyData: - ctx.Packet.ReadBytes(ctx.Conn.BackendKey[:]) + var p *packets.BackendKeyData + p, err = fed.ToConcrete[*packets.BackendKeyData](packet) + if err != nil { + return + } + ctx.Conn.BackendKey.SecretKey = p.SecretKey + ctx.Conn.BackendKey.ProcessID = p.ProcessID + return false, nil case packets.TypeParameterStatus: - var ps packets.ParameterStatus - if !ps.ReadFromPacket(ctx.Packet) { - err = ErrBadFormat + var p *packets.ParameterStatus + p, err = fed.ToConcrete[*packets.ParameterStatus](packet) + if err != nil { return } - ikey := strutil.MakeCIString(ps.Key) + ikey := strutil.MakeCIString(p.Key) if ctx.Conn.InitialParameters == nil { ctx.Conn.InitialParameters = make(map[strutil.CIString]string) } - ctx.Conn.InitialParameters[ikey] = ps.Value + ctx.Conn.InitialParameters[ikey] = p.Value return false, nil case packets.TypeReadyForQuery: return true, nil case packets.TypeErrorResponse: - var err2 packets.ErrorResponse - if !err2.ReadFromPacket(ctx.Packet) { - err = ErrBadFormat - } else { - err = errors.New(err2.Error.String()) + var p *packets.ErrorResponse + p, err = fed.ToConcrete[*packets.ErrorResponse](packet) + if err != nil { + return } + err = perror.FromPacket(p) return case packets.TypeNoticeResponse: // TODO(garet) do something with notice @@ -242,20 +241,18 @@ func startup1(ctx *acceptContext) (done bool, err error) { } func enableSSL(ctx *acceptContext) (bool, error) { - ctx.Packet = ctx.Packet.Reset(0, 4) - ctx.Packet = ctx.Packet.AppendUint16(1234) - ctx.Packet = ctx.Packet.AppendUint16(5679) - if err := ctx.Conn.WritePacket(ctx.Packet); err != nil { - return false, err + p := packets.Startup{ + Mode: &packets.StartupPayloadControl{ + Mode: &packets.StartupPayloadControlPayloadSSL{}, + }, } - byteReader, ok := ctx.Conn.ReadWriteCloser.(io.ByteReader) - if !ok { - return false, errors.New("server must be io.ByteReader to enable ssl") + if err := ctx.Conn.WritePacket(&p); err != nil { + return false, err } // read byte to see if ssl is allowed - yn, err := byteReader.ReadByte() + yn, err := ctx.Conn.Decoder.Uint8() if err != nil { return false, err } @@ -265,12 +262,7 @@ func enableSSL(ctx *acceptContext) (bool, error) { return false, nil } - sslClient, ok := ctx.Conn.ReadWriteCloser.(fed.SSLClient) - if !ok { - return false, errors.New("server must be fed.SSLClient to enable ssl") - } - - if err = sslClient.EnableSSLClient(ctx.Options.SSLConfig); err != nil { + if err = ctx.Conn.EnableSSLClient(ctx.Options.SSLConfig); err != nil { return false, err } @@ -294,26 +286,32 @@ func accept(ctx *acceptContext) error { } } - size := 4 + len("user") + 1 + len(username) + 1 + len("database") + 1 + len(ctx.Options.Database) + 1 - for key, value := range ctx.Options.StartupParameters { - size += len(key.String()) + len(value) + 2 + m := packets.StartupPayloadVersion3{ + MinorVersion: 0, + Parameters: []packets.StartupPayloadVersion3PayloadParameter{ + { + Key: "user", + Value: username, + }, + { + Key: "database", + Value: ctx.Options.Database, + }, + }, } - size += 1 - - ctx.Packet = ctx.Packet.Reset(0, size) - ctx.Packet = ctx.Packet.AppendUint16(3) - ctx.Packet = ctx.Packet.AppendUint16(0) - ctx.Packet = ctx.Packet.AppendString("user") - ctx.Packet = ctx.Packet.AppendString(username) - ctx.Packet = ctx.Packet.AppendString("database") - ctx.Packet = ctx.Packet.AppendString(ctx.Options.Database) + for key, value := range ctx.Options.StartupParameters { - ctx.Packet = ctx.Packet.AppendString(key.String()) - ctx.Packet = ctx.Packet.AppendString(value) + m.Parameters = append(m.Parameters, packets.StartupPayloadVersion3PayloadParameter{ + Key: key.String(), + Value: value, + }) + } + + p := packets.Startup{ + Mode: &m, } - ctx.Packet = ctx.Packet.AppendUint8(0) - err := ctx.Conn.WritePacket(ctx.Packet) + err := ctx.Conn.WritePacket(&p) if err != nil { return err } diff --git a/lib/bouncer/backends/v0/cancel.go b/lib/bouncer/backends/v0/cancel.go index 23769b8dbed161c791b0e64ee450820a24d8bb9d..48ee2b66461df144c76de97208c8b00bd642d2bf 100644 --- a/lib/bouncer/backends/v0/cancel.go +++ b/lib/bouncer/backends/v0/cancel.go @@ -1,11 +1,18 @@ package backends -import "gfx.cafe/gfx/pggat/lib/fed" +import ( + "gfx.cafe/gfx/pggat/lib/fed" + packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" +) -func Cancel(server *fed.Conn, key [8]byte) error { - packet := fed.NewPacket(0, 12) - packet = packet.AppendUint16(1234) - packet = packet.AppendUint16(5678) - packet = packet.AppendBytes(key[:]) - return server.WritePacket(packet) +func Cancel(server *fed.Conn, key fed.BackendKey) error { + p := packets.Startup{ + Mode: &packets.StartupPayloadControl{ + Mode: &packets.StartupPayloadControlPayloadCancel{ + ProcessID: key.ProcessID, + SecretKey: key.SecretKey, + }, + }, + } + return server.WritePacket(&p) } diff --git a/lib/bouncer/backends/v0/context.go b/lib/bouncer/backends/v0/context.go index 4f4d73b30608359668d502c2a8f83fa832cb6c2a..3554fc2bbbb9be1721fd4a233611dc9de63f0807 100644 --- a/lib/bouncer/backends/v0/context.go +++ b/lib/bouncer/backends/v0/context.go @@ -5,7 +5,6 @@ import ( ) type acceptContext struct { - Packet fed.Packet Conn *fed.Conn Options acceptOptions } diff --git a/lib/bouncer/frontends/v0/accept.go b/lib/bouncer/frontends/v0/accept.go index 5fadb9e3e51ba004a9e06fa86f103944bb7b09fd..d69f0cc04a6b9cadb6d7f267f78ce2b34aa11aff 100644 --- a/lib/bouncer/frontends/v0/accept.go +++ b/lib/bouncer/frontends/v0/accept.go @@ -2,7 +2,6 @@ package frontends import ( "crypto/tls" - "io" "strings" "gfx.cafe/gfx/pggat/lib/fed" @@ -14,72 +13,47 @@ import ( func startup0( ctx *acceptContext, params *acceptParams, -) (cancelling bool, done bool, err perror.Error) { - var err2 error - ctx.Packet, err2 = ctx.Conn.ReadPacket(false, ctx.Packet) - if err2 != nil { - err = perror.Wrap(err2) +) (cancelling bool, done bool, err error) { + var packet fed.Packet + packet, err = ctx.Conn.ReadPacket(false) + if err != nil { return } - var majorVersion uint16 - var minorVersion uint16 - p := ctx.Packet.ReadUint16(&majorVersion) - p = p.ReadUint16(&minorVersion) + var p *packets.Startup + p, err = fed.ToConcrete[*packets.Startup](packet) + if err != nil { + return + } - if majorVersion == 1234 { - // Cancel or SSL - switch minorVersion { - case 5678: + switch mode := p.Mode.(type) { + case *packets.StartupPayloadControl: + switch control := mode.Mode.(type) { + case *packets.StartupPayloadControlPayloadCancel: // Cancel - p.ReadBytes(params.CancelKey[:]) + params.CancelKey.ProcessID = control.ProcessID + params.CancelKey.SecretKey = control.SecretKey cancelling = true done = true return - case 5679: - byteWriter, ok := ctx.Conn.ReadWriteCloser.(io.ByteWriter) - if !ok { - err = perror.New( - perror.FATAL, - perror.FeatureNotSupported, - "SSL is not supported", - ) - return - } - + case *packets.StartupPayloadControlPayloadSSL: // ssl is not enabled if ctx.Options.SSLConfig == nil { - err = perror.Wrap(byteWriter.WriteByte('N')) - return - } - - sslServer, ok := ctx.Conn.ReadWriteCloser.(fed.SSLServer) - if !ok { - err = perror.Wrap(byteWriter.WriteByte('N')) + err = ctx.Conn.Encoder.Uint8('N') return } // do ssl - if err = perror.Wrap(byteWriter.WriteByte('S')); err != nil { + if err = ctx.Conn.Encoder.Uint8('S'); err != nil { return } - if err = perror.Wrap(sslServer.EnableSSLServer(ctx.Options.SSLConfig)); err != nil { + if err = ctx.Conn.EnableSSLServer(ctx.Options.SSLConfig); err != nil { return } return - case 5680: - byteWriter, ok := ctx.Conn.ReadWriteCloser.(io.ByteWriter) - if !ok { - err = perror.New( - perror.FATAL, - perror.FeatureNotSupported, - "GSSAPI is not supported", - ) - return - } - + case *packets.StartupPayloadControlPayloadGSSAPI: // GSSAPI is not supported yet - err = perror.Wrap(byteWriter.WriteByte('N')) + err = ctx.Conn.Encoder.Uint8('N') return default: err = perror.New( @@ -89,121 +63,108 @@ func startup0( ) return } - } - - if majorVersion != 3 { - err = perror.New( - perror.FATAL, - perror.ProtocolViolation, - "Unsupported protocol version", - ) - return - } - - var unsupportedOptions []string - - for { - var key string - p = p.ReadString(&key) - if key == "" { - break - } - - var value string - p = p.ReadString(&value) - - switch key { - case "user": - ctx.Conn.User = value - case "database": - ctx.Conn.Database = value - case "options": - fields := strings.Fields(value) - for i := 0; i < len(fields); i++ { - switch fields[i] { - case "-c": - i++ - set := fields[i] - var ok bool - key, value, ok = strings.Cut(set, "=") - if !ok { + case *packets.StartupPayloadVersion3: + var unsupportedOptions []string + + for _, parameter := range mode.Parameters { + switch parameter.Key { + case "user": + ctx.Conn.User = parameter.Value + case "database": + ctx.Conn.Database = parameter.Value + case "options": + fields := strings.Fields(parameter.Value) + for i := 0; i < len(fields); i++ { + switch fields[i] { + case "-c": + i++ + set := fields[i] + key, value, ok := strings.Cut(set, "=") + if !ok { + err = perror.New( + perror.FATAL, + perror.ProtocolViolation, + "Expected key=value", + ) + return + } + + ikey := strutil.MakeCIString(key) + + if ctx.Conn.InitialParameters == nil { + ctx.Conn.InitialParameters = make(map[strutil.CIString]string) + } + ctx.Conn.InitialParameters[ikey] = value + default: err = perror.New( perror.FATAL, - perror.ProtocolViolation, - "Expected key=value", + perror.FeatureNotSupported, + "Flag not supported, sorry", ) return } - - ikey := strutil.MakeCIString(key) + } + case "replication": + err = perror.New( + perror.FATAL, + perror.FeatureNotSupported, + "Replication mode is not supported yet", + ) + return + default: + if strings.HasPrefix(parameter.Key, "_pq_.") { + // we don't support protocol extensions at the moment + unsupportedOptions = append(unsupportedOptions, parameter.Key) + } else { + ikey := strutil.MakeCIString(parameter.Key) if ctx.Conn.InitialParameters == nil { ctx.Conn.InitialParameters = make(map[strutil.CIString]string) } - ctx.Conn.InitialParameters[ikey] = value - default: - err = perror.New( - perror.FATAL, - perror.FeatureNotSupported, - "Flag not supported, sorry", - ) - return + ctx.Conn.InitialParameters[ikey] = parameter.Value } } - case "replication": + } + + if mode.MinorVersion != 0 || len(unsupportedOptions) > 0 { + // negotiate protocol + uopts := packets.NegotiateProtocolVersion{ + MinorProtocolVersion: 0, + UnrecognizedProtocolOptions: unsupportedOptions, + } + err = ctx.Conn.WritePacket(&uopts) + if err != nil { + return + } + } + + if ctx.Conn.User == "" { err = perror.New( perror.FATAL, - perror.FeatureNotSupported, - "Replication mode is not supported yet", + perror.InvalidAuthorizationSpecification, + "User is required", ) return - default: - if strings.HasPrefix(key, "_pq_.") { - // we don't support protocol extensions at the moment - unsupportedOptions = append(unsupportedOptions, key) - } else { - ikey := strutil.MakeCIString(key) - - if ctx.Conn.InitialParameters == nil { - ctx.Conn.InitialParameters = make(map[strutil.CIString]string) - } - ctx.Conn.InitialParameters[ikey] = value - } - } - } - - if minorVersion != 0 || len(unsupportedOptions) > 0 { - // negotiate protocol - uopts := packets.NegotiateProtocolVersion{ - MinorProtocolVersion: 0, - UnrecognizedOptions: unsupportedOptions, } - ctx.Packet = uopts.IntoPacket(ctx.Packet) - err = perror.Wrap(ctx.Conn.WritePacket(ctx.Packet)) - if err != nil { - return + if ctx.Conn.Database == "" { + ctx.Conn.Database = ctx.Conn.User } - } - if ctx.Conn.User == "" { + done = true + return + default: err = perror.New( perror.FATAL, - perror.InvalidAuthorizationSpecification, - "User is required", + perror.ProtocolViolation, + "Unsupported protocol version", ) return } - if ctx.Conn.Database == "" { - ctx.Conn.Database = ctx.Conn.User - } - - done = true - return } func accept0( ctx *acceptContext, -) (params acceptParams, err perror.Error) { +) (params acceptParams, err error) { for { var done bool params.IsCanceling, done, err = startup0(ctx, ¶ms) @@ -218,27 +179,24 @@ func accept0( return } -func fail(packet fed.Packet, client fed.ReadWriter, err perror.Error) { - resp := packets.ErrorResponse{ - Error: err, - } - packet = resp.IntoPacket(packet) - _ = client.WritePacket(packet) +func fail(client *fed.Conn, err error) { + resp := perror.ToPacket(perror.Wrap(err)) + _ = client.WritePacket(resp) } -func accept(ctx *acceptContext) (acceptParams, perror.Error) { +func accept(ctx *acceptContext) (acceptParams, error) { params, err := accept0(ctx) if err != nil { - fail(ctx.Packet, ctx.Conn, err) + fail(ctx.Conn, err) return acceptParams{}, err } return params, nil } func Accept(conn *fed.Conn, tlsConfig *tls.Config) ( - cancelKey [8]byte, + cancelKey fed.BackendKey, isCanceling bool, - err perror.Error, + err error, ) { ctx := acceptContext{ Conn: conn, diff --git a/lib/bouncer/frontends/v0/authenticate.go b/lib/bouncer/frontends/v0/authenticate.go index ff2e50a1c3c86ca6ff0c1ceef169f148a8c5fe5b..a7e1120e353f25955587d66c3ec3ea7a1d554612 100644 --- a/lib/bouncer/frontends/v0/authenticate.go +++ b/lib/bouncer/frontends/v0/authenticate.go @@ -2,6 +2,7 @@ package frontends import ( "crypto/rand" + "encoding/binary" "errors" "io" @@ -11,69 +12,71 @@ import ( "gfx.cafe/gfx/pggat/lib/perror" ) -func authenticationSASLInitial(ctx *authenticateContext, creds auth.SASLServer) (tool auth.SASLVerifier, resp []byte, done bool, err perror.Error) { +func authenticationSASLInitial(ctx *authenticateContext, creds auth.SASLServer) (tool auth.SASLVerifier, resp []byte, done bool, err error) { // check which authentication method the client wants - var err2 error - ctx.Packet, err2 = ctx.Conn.ReadPacket(true, ctx.Packet) - if err2 != nil { - err = perror.Wrap(err2) + var packet fed.Packet + packet, err = ctx.Conn.ReadPacket(true) + if err != nil { return } - var initialResponse packets.SASLInitialResponse - if !initialResponse.ReadFromPacket(ctx.Packet) { - err = packets.ErrBadFormat + var p *packets.SASLInitialResponse + p, err = fed.ToConcrete[*packets.SASLInitialResponse](packet) + if err != nil { return } - tool, err2 = creds.VerifySASL(initialResponse.Mechanism) - if err2 != nil { - err = perror.Wrap(err2) + tool, err = creds.VerifySASL(p.Mechanism) + if err != nil { return } - resp, err2 = tool.Write(initialResponse.InitialResponse) - if err2 != nil { - if errors.Is(err2, io.EOF) { + resp, err = tool.Write(p.InitialClientResponse) + if err != nil { + if errors.Is(err, io.EOF) { done = true return } - err = perror.Wrap(err2) return } return } -func authenticationSASLContinue(ctx *authenticateContext, tool auth.SASLVerifier) (resp []byte, done bool, err perror.Error) { - var err2 error - ctx.Packet, err2 = ctx.Conn.ReadPacket(true, ctx.Packet) - if err2 != nil { - err = perror.Wrap(err2) +func authenticationSASLContinue(ctx *authenticateContext, tool auth.SASLVerifier) (resp []byte, done bool, err error) { + var packet fed.Packet + packet, err = ctx.Conn.ReadPacket(true) + if err != nil { return } - var authResp packets.AuthenticationResponse - if !authResp.ReadFrom(ctx.Packet) { - err = packets.ErrBadFormat + var p *packets.SASLResponse + p, err = fed.ToConcrete[*packets.SASLResponse](packet) + if err != nil { return } - resp, err2 = tool.Write(authResp) - if err2 != nil { - if errors.Is(err2, io.EOF) { + resp, err = tool.Write(*p) + if err != nil { + if errors.Is(err, io.EOF) { done = true return } - err = perror.Wrap(err2) return } return } -func authenticationSASL(ctx *authenticateContext, creds auth.SASLServer) perror.Error { - saslInitial := packets.AuthenticationSASL{ - Mechanisms: creds.SupportedSASLMechanisms(), +func authenticationSASL(ctx *authenticateContext, creds auth.SASLServer) error { + var mode packets.AuthenticationPayloadSASL + mechanisms := creds.SupportedSASLMechanisms() + for _, mechanism := range mechanisms { + mode = append(mode, packets.AuthenticationPayloadSASLMethod{ + Method: mechanism, + }) + } + + saslInitial := packets.Authentication{ + Mode: &mode, } - ctx.Packet = saslInitial.IntoPacket(ctx.Packet) - err := perror.Wrap(ctx.Conn.WritePacket(ctx.Packet)) + err := ctx.Conn.WritePacket(&saslInitial) if err != nil { return err } @@ -85,17 +88,21 @@ func authenticationSASL(ctx *authenticateContext, creds auth.SASLServer) perror. for { if done { - final := packets.AuthenticationSASLFinal(resp) - ctx.Packet = final.IntoPacket(ctx.Packet) - err = perror.Wrap(ctx.Conn.WritePacket(ctx.Packet)) + m := packets.AuthenticationPayloadSASLFinal(resp) + final := packets.Authentication{ + Mode: &m, + } + err = ctx.Conn.WritePacket(&final) if err != nil { return err } break } else { - cont := packets.AuthenticationSASLContinue(resp) - ctx.Packet = cont.IntoPacket(ctx.Packet) - err = perror.Wrap(ctx.Conn.WritePacket(ctx.Packet)) + m := packets.AuthenticationPayloadSASLContinue(resp) + cont := packets.Authentication{ + Mode: &m, + } + err = ctx.Conn.WritePacket(&cont) if err != nil { return err } @@ -110,39 +117,41 @@ func authenticationSASL(ctx *authenticateContext, creds auth.SASLServer) perror. return nil } -func authenticationMD5(ctx *authenticateContext, creds auth.MD5Server) perror.Error { +func authenticationMD5(ctx *authenticateContext, creds auth.MD5Server) error { var salt [4]byte _, err := rand.Read(salt[:]) if err != nil { - return perror.Wrap(err) + return err } - md5Initial := packets.AuthenticationMD5{ - Salt: salt, + mode := packets.AuthenticationPayloadMD5Password(salt) + md5Initial := packets.Authentication{ + Mode: &mode, } - ctx.Packet = md5Initial.IntoPacket(ctx.Packet) - err = ctx.Conn.WritePacket(ctx.Packet) + err = ctx.Conn.WritePacket(&md5Initial) if err != nil { - return perror.Wrap(err) + return err } - ctx.Packet, err = ctx.Conn.ReadPacket(true, ctx.Packet) + var packet fed.Packet + packet, err = ctx.Conn.ReadPacket(true) if err != nil { - return perror.Wrap(err) + return err } - var pw packets.PasswordMessage - if !pw.ReadFromPacket(ctx.Packet) { - return packets.ErrUnexpectedPacket + var pw *packets.PasswordMessage + pw, err = fed.ToConcrete[*packets.PasswordMessage](packet) + if err != nil { + return err } - if err = creds.VerifyMD5(salt, pw.Password); err != nil { - return perror.Wrap(err) + if err = creds.VerifyMD5(salt, string(*pw)); err != nil { + return err } return nil } -func authenticate(ctx *authenticateContext) (err perror.Error) { +func authenticate(ctx *authenticateContext) (err error) { if ctx.Options.Credentials != nil { if credsSASL, ok := ctx.Options.Credentials.(auth.SASLServer); ok { err = authenticationSASL(ctx, credsSASL) @@ -161,32 +170,40 @@ func authenticate(ctx *authenticateContext) (err perror.Error) { } // send auth Ok - authOk := packets.AuthenticationOk{} - ctx.Packet = authOk.IntoPacket(ctx.Packet) - if err = perror.Wrap(ctx.Conn.WritePacket(ctx.Packet)); err != nil { + authOk := packets.Authentication{ + Mode: &packets.AuthenticationPayloadOk{}, + } + if err = ctx.Conn.WritePacket(&authOk); err != nil { return } ctx.Conn.Authenticated = true // send backend key data - _, err2 := rand.Read(ctx.Conn.BackendKey[:]) - if err2 != nil { - err = perror.Wrap(err2) + var processID [4]byte + if _, err = rand.Reader.Read(processID[:]); err != nil { return } + var backendKey [4]byte + if _, err = rand.Reader.Read(backendKey[:]); err != nil { + return + } + ctx.Conn.BackendKey = fed.BackendKey{ + ProcessID: int32(binary.BigEndian.Uint32(processID[:])), + SecretKey: int32(binary.BigEndian.Uint32(backendKey[:])), + } keyData := packets.BackendKeyData{ - CancellationKey: ctx.Conn.BackendKey, + ProcessID: ctx.Conn.BackendKey.ProcessID, + SecretKey: ctx.Conn.BackendKey.SecretKey, } - ctx.Packet = keyData.IntoPacket(ctx.Packet) - if err = perror.Wrap(ctx.Conn.WritePacket(ctx.Packet)); err != nil { + if err = ctx.Conn.WritePacket(&keyData); err != nil { return } return } -func Authenticate(conn *fed.Conn, creds auth.Credentials) (err perror.Error) { +func Authenticate(conn *fed.Conn, creds auth.Credentials) (err error) { if conn.Authenticated { // already authenticated return diff --git a/lib/bouncer/frontends/v0/context.go b/lib/bouncer/frontends/v0/context.go index e0668cef9a24a5c7287da9562012e47f27cda543..993073ffab07b6d371176eff956b7627fa814a75 100644 --- a/lib/bouncer/frontends/v0/context.go +++ b/lib/bouncer/frontends/v0/context.go @@ -3,13 +3,11 @@ package frontends import "gfx.cafe/gfx/pggat/lib/fed" type acceptContext struct { - Packet fed.Packet Conn *fed.Conn Options acceptOptions } type authenticateContext struct { - Packet fed.Packet Conn *fed.Conn Options authenticateOptions } diff --git a/lib/bouncer/frontends/v0/params.go b/lib/bouncer/frontends/v0/params.go index 0d960dbf3edbad11e492c22e899dd71e08824172..4f28f8347cb111f03ade898b28fde91364420d3b 100644 --- a/lib/bouncer/frontends/v0/params.go +++ b/lib/bouncer/frontends/v0/params.go @@ -1,6 +1,8 @@ package frontends +import "gfx.cafe/gfx/pggat/lib/fed" + type acceptParams struct { - CancelKey [8]byte + CancelKey fed.BackendKey IsCanceling bool } diff --git a/lib/fed/backendkey.go b/lib/fed/backendkey.go new file mode 100644 index 0000000000000000000000000000000000000000..e611366bbcf9db31d64ddf330c0a0d455f9ec5ca --- /dev/null +++ b/lib/fed/backendkey.go @@ -0,0 +1,6 @@ +package fed + +type BackendKey struct { + ProcessID int32 + SecretKey int32 +} diff --git a/lib/fed/conn.go b/lib/fed/conn.go index e8534bfed14c3aa1b3246064d4fb8a4b26e7ca39..2b57d9dfd77d36a7ca5320cf6d297fb6b2dc3465 100644 --- a/lib/fed/conn.go +++ b/lib/fed/conn.go @@ -1,6 +1,7 @@ package fed import ( + "crypto/tls" "io" "gfx.cafe/gfx/pggat/lib/util/decorator" @@ -15,12 +16,13 @@ type Conn struct { Decoder Decoder Middleware []Middleware + SSL bool User string Database string InitialParameters map[strutil.CIString]string Authenticated bool - BackendKey [8]byte + BackendKey BackendKey } func NewConn(rw io.ReadWriteCloser) *Conn { @@ -88,6 +90,16 @@ func (T *Conn) WritePacket(packet Packet) error { return packet.WriteTo(&T.Encoder) } +func (T *Conn) EnableSSLClient(config *tls.Config) error { + // TODO(garet) + panic("TODO") +} + +func (T *Conn) EnableSSLServer(config *tls.Config) error { + // TODO(garet) + panic("TODO") +} + func (T *Conn) Close() error { if err := T.Encoder.Flush(); err != nil { return err diff --git a/lib/fed/decoder.go b/lib/fed/decoder.go index b27331f38ed8e0f5a6f6eb4aba4101b436c33ee0..ae53d30455e2f1a6d1c46f8b4b667f8c596568c1 100644 --- a/lib/fed/decoder.go +++ b/lib/fed/decoder.go @@ -39,7 +39,7 @@ func (T *Decoder) Next(typed bool) error { if err != nil { return err } - T.typ = T.buf[0] + T.typ = Type(T.buf[0]) T.len = int(binary.BigEndian.Uint32(T.buf[1:5])) - 4 T.pos = 0 return nil diff --git a/lib/fed/encoder.go b/lib/fed/encoder.go index 69145d92dedaa606a8d98869cdf32fe8828352e0..247982be73283008eac0879ae7cde9a83958f2a4 100644 --- a/lib/fed/encoder.go +++ b/lib/fed/encoder.go @@ -33,7 +33,7 @@ func (T *Encoder) Flush() error { func (T *Encoder) Next(typ Type, length int) error { if typ != 0 { - if err := T.Writer.WriteByte(typ); err != nil { + if err := T.Writer.WriteByte(byte(typ)); err != nil { return err } } diff --git a/lib/fed/middlewares/ps/sync.go b/lib/fed/middlewares/ps/sync.go index be296b9c3421570d64d5d46460fdf623419c5325..8f9f19267e23c7918527a56c781338c06742a2be 100644 --- a/lib/fed/middlewares/ps/sync.go +++ b/lib/fed/middlewares/ps/sync.go @@ -12,16 +12,13 @@ func sync(tracking []strutil.CIString, client *fed.Conn, c *Client, server *fed. value, hasValue := c.parameters[name] expected, hasExpected := s.parameters[name] - var packet fed.Packet - if value == expected { if !c.synced { ps := packets.ParameterStatus{ Key: name.String(), Value: expected, } - packet = ps.IntoPacket(packet) - if err := client.WritePacket(packet); err != nil { + if err := client.WritePacket(&ps); err != nil { return err } } @@ -32,7 +29,7 @@ func sync(tracking []strutil.CIString, client *fed.Conn, c *Client, server *fed. if hasValue && slices.Contains(tracking, name) { var err error - if err, _, packet = backends.SetParameter(server, nil, packet, name, value); err != nil { + if err, _ = backends.SetParameter(server, nil, name, value); err != nil { return err } if s.parameters == nil { @@ -50,8 +47,7 @@ func sync(tracking []strutil.CIString, client *fed.Conn, c *Client, server *fed. Key: name.String(), Value: expected, } - packet = ps.IntoPacket(packet) - if err := client.WritePacket(packet); err != nil { + if err := client.WritePacket(&ps); err != nil { return err } } diff --git a/lib/gat/handler.go b/lib/gat/handler.go index 6026c1e6778260a9f25d5a2ce66e5b14677c051f..b0d2d339c14e119b9e56ddd5bacf46bce0319f8b 100644 --- a/lib/gat/handler.go +++ b/lib/gat/handler.go @@ -15,7 +15,7 @@ type Handler interface { type CancellableHandler interface { Handler - Cancel(key [8]byte) + Cancel(key fed.BackendKey) } type MetricsHandler interface { diff --git a/lib/gat/handlers/discovery/discoverers/google_cloud_sql/discoverer.go b/lib/gat/handlers/discovery/discoverers/google_cloud_sql/discoverer.go index 6bb9351b023cd1c14533778c68fe9c1a575b239d..6ccdd35f86c431df43cea2a14417ab62f802d2b0 100644 --- a/lib/gat/handlers/discovery/discoverers/google_cloud_sql/discoverer.go +++ b/lib/gat/handlers/discovery/discoverers/google_cloud_sql/discoverer.go @@ -135,7 +135,7 @@ func (T *Discoverer) instanceToCluster(primary *sqladmin.DatabaseInstance, repli var result authQueryResult client := new(gsql.Client) - err := gsql.ExtendedQuery(client, &result, "SELECT usename, passwd FROM pg_shadow WHERE usename=$1", user.Name) + err = gsql.ExtendedQuery(client, &result, "SELECT usename, passwd FROM pg_shadow WHERE usename=$1", user.Name) if err != nil { return discovery.Cluster{}, err } @@ -144,11 +144,13 @@ func (T *Discoverer) instanceToCluster(primary *sqladmin.DatabaseInstance, repli return discovery.Cluster{}, err } - initialPacket, err := client.ReadPacket(true, nil) + clientConn := fed.NewConn(client) + + initialPacket, err := clientConn.ReadPacket(true) if err != nil { return discovery.Cluster{}, err } - _, err, err2 := bouncers.Bounce(fed.NewConn(client), admin, initialPacket) + err, err2 := bouncers.Bounce(clientConn, admin, initialPacket) if err != nil { return discovery.Cluster{}, err } diff --git a/lib/gat/handlers/discovery/module.go b/lib/gat/handlers/discovery/module.go index bd6614d2e3f9350789e84dd0cb2637f060d85bbd..3652a482255fbb723c260cf9b57de265e4123248 100644 --- a/lib/gat/handlers/discovery/module.go +++ b/lib/gat/handlers/discovery/module.go @@ -556,7 +556,7 @@ func (T *Module) Handle(conn *fed.Conn) error { return p.Serve(conn) } -func (T *Module) Cancel(key [8]byte) { +func (T *Module) Cancel(key fed.BackendKey) { T.mu.RLock() defer T.mu.RUnlock() T.pools.Range(func(_ string, _ string, p pool.WithCredentials) bool { diff --git a/lib/gat/handlers/pgbouncer/module.go b/lib/gat/handlers/pgbouncer/module.go index a5a8f025fcdc70e1c491edd94bd21d854f08dc38..fd44332c91c87632879649bad63c68f205cc5571 100644 --- a/lib/gat/handlers/pgbouncer/module.go +++ b/lib/gat/handlers/pgbouncer/module.go @@ -118,7 +118,7 @@ func (T *Module) getPassword(user, database string) (string, bool) { T.log.Warn("auth query failed", zap.Error(err)) return "", false } - err = authPool.ServeBot(client) + err = authPool.ServeBot(fed.NewConn(client)) if err != nil && !errors.Is(err, io.EOF) { T.log.Warn("auth query failed", zap.Error(err)) return "", false @@ -285,13 +285,7 @@ func (T *Module) lookup(user, database string) (pool.WithCredentials, bool) { func (T *Module) Handle(conn *fed.Conn) error { // check ssl if T.Config.PgBouncer.ClientTLSSSLMode.IsRequired() { - var ssl bool - netConn, ok := conn.ReadWriteCloser.(*fed.NetConn) - if ok { - ssl = netConn.SSL() - } - - if !ssl { + if !conn.SSL { return perror.New( perror.FATAL, perror.InvalidPassword, @@ -346,7 +340,7 @@ func (T *Module) ReadMetrics(metrics *metrics.Handler) { }) } -func (T *Module) Cancel(key [8]byte) { +func (T *Module) Cancel(key fed.BackendKey) { T.mu.RLock() defer T.mu.RUnlock() T.pools.Range(func(_ string, _ string, p pool.WithCredentials) bool { diff --git a/lib/gat/handlers/pool/module.go b/lib/gat/handlers/pool/module.go index 61b0085a7abd477aac62cd69d0d83d37df226374..ef57d02e42edbf73219a9b66155faaffd6fa7249 100644 --- a/lib/gat/handlers/pool/module.go +++ b/lib/gat/handlers/pool/module.go @@ -97,7 +97,7 @@ func (T *Module) Handle(conn *fed.Conn) error { return T.pool.Serve(conn) } -func (T *Module) Cancel(key [8]byte) { +func (T *Module) Cancel(key fed.BackendKey) { T.pool.Cancel(key) } diff --git a/lib/gat/handlers/require_ssl/module.go b/lib/gat/handlers/require_ssl/module.go index 47f7cb8e63bb3f6ce31fd810505154aa1817a941..476a8285e12db50108b43d5fb435aa1aa8f97bca 100644 --- a/lib/gat/handlers/require_ssl/module.go +++ b/lib/gat/handlers/require_ssl/module.go @@ -26,15 +26,8 @@ func (T *Module) CaddyModule() caddy.ModuleInfo { } func (T *Module) Handle(conn *fed.Conn) error { - var ssl bool - - sslConn, ok := conn.ReadWriteCloser.(fed.SSL) - if ok { - ssl = sslConn.SSL() - } - if T.SSL { - if !ssl { + if !conn.SSL { return perror.New( perror.FATAL, perror.InvalidPassword, @@ -44,7 +37,7 @@ func (T *Module) Handle(conn *fed.Conn) error { return nil } - if ssl { + if conn.SSL { return perror.New( perror.FATAL, perror.InvalidPassword, diff --git a/lib/gat/listen.go b/lib/gat/listen.go index baa1090cef04d4ea7c474b0e98d3807af28d2c0f..7204a61222fa347492060a8c196ee5d1b02be5f8 100644 --- a/lib/gat/listen.go +++ b/lib/gat/listen.go @@ -35,9 +35,7 @@ func (T *Listener) accept() (*fed.Conn, error) { if err != nil { return nil, err } - return fed.NewConn( - fed.NewNetConn(raw), - ), nil + return fed.NewConn(raw), nil } func (T *Listener) Provision(ctx caddy.Context) error { diff --git a/lib/gat/matchers/localaddress.go b/lib/gat/matchers/localaddress.go index c9ac3093f475bf7f3b6ddd3a9c6f128ff68f3933..2219d9d01422c686f237dd654243ec876c0e9af6 100644 --- a/lib/gat/matchers/localaddress.go +++ b/lib/gat/matchers/localaddress.go @@ -48,7 +48,7 @@ func (T *LocalAddress) Provision(ctx caddy.Context) error { } func (T *LocalAddress) Matches(conn *fed.Conn) bool { - netConn, ok := conn.ReadWriteCloser.(*fed.NetConn) + netConn, ok := conn.ReadWriter.(net.Conn) if !ok { return false } diff --git a/lib/gat/matchers/ssl.go b/lib/gat/matchers/ssl.go index 57ae1a0acb907587d5c30070051af0a5a30d4017..255bde81dd008ff70c925f00fa9d3f76f9c722cb 100644 --- a/lib/gat/matchers/ssl.go +++ b/lib/gat/matchers/ssl.go @@ -25,11 +25,7 @@ func (T *SSL) CaddyModule() caddy.ModuleInfo { } func (T *SSL) Matches(conn *fed.Conn) bool { - sslConn, ok := conn.ReadWriteCloser.(fed.SSL) - if !ok { - return T.SSL == false - } - return sslConn.SSL() == T.SSL + return conn.SSL == T.SSL } var _ gat.Matcher = (*SSL)(nil) diff --git a/lib/gat/pool/conn.go b/lib/gat/pool/conn.go index 065a645895951e6693e756a2054e03f45f89b031..c1ea8981e07e6487e41bc72d0c49f126a1cb26d8 100644 --- a/lib/gat/pool/conn.go +++ b/lib/gat/pool/conn.go @@ -55,7 +55,7 @@ func (T *pooledConn) GetInitialParameters() map[strutil.CIString]string { return T.conn.InitialParameters } -func (T *pooledConn) GetBackendKey() [8]byte { +func (T *pooledConn) GetBackendKey() fed.BackendKey { return T.conn.BackendKey } diff --git a/lib/gat/pool/flow.go b/lib/gat/pool/flow.go index e12636210e444859639229cf308bb4e48c2b6cec..712255fcf206e81e5cfe29ee250290a0faa6d922 100644 --- a/lib/gat/pool/flow.go +++ b/lib/gat/pool/flow.go @@ -2,7 +2,6 @@ package pool import ( "gfx.cafe/gfx/pggat/lib/bouncer/backends/v0" - "gfx.cafe/gfx/pggat/lib/fed" "gfx.cafe/gfx/pggat/lib/fed/middlewares/eqp" "gfx.cafe/gfx/pggat/lib/fed/middlewares/ps" packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" @@ -43,8 +42,6 @@ func syncInitialParameters(options Config, client *pooledClient, server *pooledS clientParams := client.GetInitialParameters() serverParams := server.GetInitialParameters() - var packet fed.Packet - for key, value := range clientParams { // skip already set params if serverParams[key] == value { @@ -52,8 +49,7 @@ func syncInitialParameters(options Config, client *pooledClient, server *pooledS Key: key.String(), Value: serverParams[key], } - packet = p.IntoPacket(packet) - clientErr = client.GetConn().WritePacket(packet) + clientErr = client.GetConn().WritePacket(&p) if clientErr != nil { return } @@ -70,8 +66,7 @@ func syncInitialParameters(options Config, client *pooledClient, server *pooledS Key: key.String(), Value: value, } - packet = p.IntoPacket(packet) - clientErr = client.GetConn().WritePacket(packet) + clientErr = client.GetConn().WritePacket(&p) if clientErr != nil { return } @@ -80,7 +75,7 @@ func syncInitialParameters(options Config, client *pooledClient, server *pooledS continue } - serverErr, _, packet = backends.SetParameter(server.GetConn(), nil, packet, key, value) + serverErr, _ = backends.SetParameter(server.GetConn(), nil, key, value) if serverErr != nil { return } @@ -98,8 +93,7 @@ func syncInitialParameters(options Config, client *pooledClient, server *pooledS Key: key.String(), Value: value, } - packet = p.IntoPacket(packet) - clientErr = client.GetConn().WritePacket(packet) + clientErr = client.GetConn().WritePacket(&p) if clientErr != nil { return } diff --git a/lib/gat/pool/pool.go b/lib/gat/pool/pool.go index ac7606342a44596777bdc40ae10fed1ceda7389e..258ced2197dc5d878d51b13bbd5770168ed4c767 100644 --- a/lib/gat/pool/pool.go +++ b/lib/gat/pool/pool.go @@ -29,7 +29,7 @@ type Pool struct { recipes map[string]*Recipe recipeScaleOrder slices.Sorted[string] clients map[uuid.UUID]*pooledClient - clientsByKey map[[8]byte]*pooledClient + clientsByKey map[fed.BackendKey]*pooledClient servers map[uuid.UUID]*pooledServer serversByRecipe map[string][]*pooledServer mu sync.RWMutex @@ -265,7 +265,7 @@ func (T *Pool) releaseServer(server *pooledServer) { if T.config.ServerResetQuery != "" { server.SetState(metrics.ConnStateRunningResetQuery, uuid.Nil) - err, _, _ := backends.QueryString(server.GetConn(), nil, nil, T.config.ServerResetQuery) + err, _ := backends.QueryString(server.GetConn(), nil, T.config.ServerResetQuery) if err != nil { T.removeServer(server) return @@ -295,7 +295,7 @@ func (T *Pool) Serve( // ServeBot is for clients that don't need initial parameters, cancelling queries, and are ready now. Use Serve for // real clients func (T *Pool) ServeBot( - conn fed.ReadWriteCloser, + conn *fed.Conn, ) error { defer func() { _ = conn.Close() @@ -303,9 +303,7 @@ func (T *Pool) ServeBot( client := newClient( T.config, - &fed.Conn{ - ReadWriteCloser: conn, - }, + conn, ) return T.serve(client, true) @@ -330,8 +328,6 @@ func (T *Pool) serve(client *pooledClient, initialized bool) error { } }() - var packet fed.Packet - if !initialized { server = T.acquireServer(client) if server == nil { @@ -347,8 +343,7 @@ func (T *Pool) serve(client *pooledClient, initialized bool) error { } p := packets.ReadyForQuery('I') - packet = p.IntoPacket(packet) - err = client.GetConn().WritePacket(packet) + err = client.GetConn().WritePacket(&p) if err != nil { return err } @@ -361,7 +356,8 @@ func (T *Pool) serve(client *pooledClient, initialized bool) error { server = nil } - packet, err = client.GetConn().ReadPacket(true, packet) + var packet fed.Packet + packet, err = client.GetConn().ReadPacket(true) if err != nil { return err } @@ -375,7 +371,7 @@ func (T *Pool) serve(client *pooledClient, initialized bool) error { err, serverErr = pair(T.config, client, server) } if err == nil && serverErr == nil { - packet, err, serverErr = bouncers.Bounce(client.GetConn(), server.GetConn(), packet) + err, serverErr = bouncers.Bounce(client.GetConn(), server.GetConn(), packet) } if serverErr != nil { @@ -399,7 +395,7 @@ func (T *Pool) addClient(client *pooledClient) { } T.clients[client.GetID()] = client if T.clientsByKey == nil { - T.clientsByKey = make(map[[8]byte]*pooledClient) + T.clientsByKey = make(map[fed.BackendKey]*pooledClient) } T.clientsByKey[client.GetBackendKey()] = client T.pooler.AddClient(client.GetID()) @@ -419,7 +415,7 @@ func (T *Pool) removeClientL1(client *pooledClient) { delete(T.clientsByKey, client.GetBackendKey()) } -func (T *Pool) Cancel(key [8]byte) { +func (T *Pool) Cancel(key fed.BackendKey) { T.mu.RLock() defer T.mu.RUnlock() diff --git a/lib/gat/pool/recipe/dialer.go b/lib/gat/pool/recipe/dialer.go index 8202c2b9a6b1633fb835e7c092d9e8a84757ee69..b11ac3a98efea8fea1c4ea03079f4104705ad94d 100644 --- a/lib/gat/pool/recipe/dialer.go +++ b/lib/gat/pool/recipe/dialer.go @@ -28,9 +28,7 @@ func (T Dialer) Dial() (*fed.Conn, error) { if err != nil { return nil, err } - conn := fed.NewConn( - fed.NewNetConn(c), - ) + conn := fed.NewConn(c) conn.User = T.Username conn.Database = T.Database err = backends.Accept( @@ -48,14 +46,12 @@ func (T Dialer) Dial() (*fed.Conn, error) { return conn, nil } -func (T Dialer) Cancel(key [8]byte) { +func (T Dialer) Cancel(key fed.BackendKey) { c, err := net.Dial(T.Network, T.Address) if err != nil { return } - conn := fed.NewConn( - fed.NewNetConn(c), - ) + conn := fed.NewConn(c) defer func() { _ = conn.Close() }() @@ -64,5 +60,5 @@ func (T Dialer) Cancel(key [8]byte) { } // wait for server to close the connection, this means that the server received it ok - _, _ = conn.ReadPacket(true, nil) + _, _ = conn.ReadPacket(true) } diff --git a/lib/gat/pool/recipe/recipe.go b/lib/gat/pool/recipe/recipe.go index aeb25e76eb3a4a81b6c4061df7a615274ffe8b7e..78a79391ad159a1e286ac6b0cd69b7aa7f86ea45 100644 --- a/lib/gat/pool/recipe/recipe.go +++ b/lib/gat/pool/recipe/recipe.go @@ -70,6 +70,6 @@ func (T *Recipe) Dial() (*fed.Conn, error) { return T.config.Dialer.Dial() } -func (T *Recipe) Cancel(key [8]byte) { +func (T *Recipe) Cancel(key fed.BackendKey) { T.config.Dialer.Cancel(key) } diff --git a/lib/gat/server.go b/lib/gat/server.go index d31a0f6d593725e6214bbc6a9c98772a21d16a9d..81a1cd58e605f87d638697f36e5e08aedf3d2f2c 100644 --- a/lib/gat/server.go +++ b/lib/gat/server.go @@ -12,7 +12,6 @@ import ( "gfx.cafe/gfx/pggat/lib/bouncer/frontends/v0" "gfx.cafe/gfx/pggat/lib/fed" - packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" "gfx.cafe/gfx/pggat/lib/gat/metrics" "gfx.cafe/gfx/pggat/lib/perror" ) @@ -95,7 +94,7 @@ func (T *Server) Stop() error { return nil } -func (T *Server) Cancel(key [8]byte) { +func (T *Server) Cancel(key fed.BackendKey) { for _, cancellableHandler := range T.cancellableHandlers { cancellableHandler.Cancel(key) } @@ -123,23 +122,21 @@ func (T *Server) Serve(conn *fed.Conn) { return } - errResp := packets.ErrorResponse{ - Error: perror.Wrap(err), - } - _ = conn.WritePacket(errResp.IntoPacket(nil)) + errResp := perror.ToPacket(perror.Wrap(err)) + _ = conn.WritePacket(errResp) return } } // database not found - errResp := packets.ErrorResponse{ - Error: perror.New( + errResp := perror.ToPacket( + perror.New( perror.FATAL, perror.InvalidPassword, fmt.Sprintf(`Database "%s" not found`, conn.Database), ), - } - _ = conn.WritePacket(errResp.IntoPacket(nil)) + ) + _ = conn.WritePacket(errResp) T.log.Warn("database not found", zap.String("user", conn.User), zap.String("database", conn.Database)) } @@ -153,7 +150,7 @@ func (T *Server) accept(listener *Listener, conn *fed.Conn) { tlsConfig = listener.ssl.ServerTLSConfig() } - var cancelKey [8]byte + var cancelKey fed.BackendKey var isCanceling bool var err error cancelKey, isCanceling, err = frontends.Accept(conn, tlsConfig) diff --git a/lib/gsql/client.go b/lib/gsql/client.go index de0a0630725fab6bf7b678fe490196a813a992f3..ae0935d5ba52c812c032972ab2a608a3f9ce91fc 100644 --- a/lib/gsql/client.go +++ b/lib/gsql/client.go @@ -1,13 +1,11 @@ package gsql import ( - "io" "net" "sync" "gfx.cafe/gfx/pggat/lib/fed" "gfx.cafe/gfx/pggat/lib/util/ring" - "gfx.cafe/gfx/pggat/lib/util/slices" ) type batch struct { @@ -58,70 +56,12 @@ func (T *Client) queueNext() bool { return false } -func (T *Client) ReadPacket(typed bool, buffer fed.Packet) (packet fed.Packet, err error) { - packet = buffer - - T.mu.Lock() - defer T.mu.Unlock() - - var p fed.Packet - for { - var ok bool - p, ok = T.read.PopFront() - if ok { - break - } - - // try to add next in queue - if T.queueNext() { - continue - } - - if T.closed { - err = io.EOF - return - } - - if T.readC == nil { - T.readC = sync.NewCond(&T.mu) - } - T.readC.Wait() - } - - if (p.Type() == 0 && typed) || (p.Type() != 0 && !typed) { - err = ErrTypedMismatch - return - } - - packet = slices.Resize(packet, len(p)) - copy(packet, p) - return +func (T *Client) Read(b []byte) (int, error) { + panic("TODO") } -func (T *Client) WritePacket(packet fed.Packet) error { - T.mu.Lock() - defer T.mu.Unlock() - - for T.write == nil { - if T.read.Length() == 0 && T.queueNext() { - continue - } - - if T.closed { - return io.EOF - } - - if T.writeC == nil { - T.writeC = sync.NewCond(&T.mu) - } - T.writeC.Wait() - } - - if err := T.write.WritePacket(packet); err != nil { - return err - } - - return nil +func (T *Client) Write(b []byte) (int, error) { + panic("TODO") } func (T *Client) Close() error { @@ -142,5 +82,3 @@ func (T *Client) Close() error { } return nil } - -var _ fed.ReadWriteCloser = (*Client)(nil) diff --git a/lib/gsql/eq.go b/lib/gsql/eq.go index 7908d0a819d20fe147ff594113385a2d5a1acd22..d323787e6cb21b31eaad45755109ea3634897df7 100644 --- a/lib/gsql/eq.go +++ b/lib/gsql/eq.go @@ -20,7 +20,7 @@ func ExtendedQuery(client *Client, result any, query string, args ...any) error parse := packets.Parse{ Query: query, } - pkts = append(pkts, parse.IntoPacket(nil)) + pkts = append(pkts, &parse) // bind params := make([][]byte, 0, len(args)) @@ -58,23 +58,23 @@ outer: params = append(params, value) } bind := packets.Bind{ - ParameterValues: params, + Parameters: params, } - pkts = append(pkts, bind.IntoPacket(nil)) + pkts = append(pkts, &bind) // describe describe := packets.Describe{ Which: 'P', } - pkts = append(pkts, describe.IntoPacket(nil)) + pkts = append(pkts, &describe) // execute execute := packets.Execute{} - pkts = append(pkts, execute.IntoPacket(nil)) + pkts = append(pkts, &execute) // sync - sync := fed.NewPacket(packets.TypeSync) - pkts = append(pkts, sync) + sync := packets.Sync{} + pkts = append(pkts, &sync) // result client.Do(NewQueryWriter(result), pkts...) diff --git a/lib/gsql/query.go b/lib/gsql/query.go index c4c8f7f4bade875322000a62aadea58dbe3d0b60..155f1dd82b4ba08402f86b6a1875100d84f6dd95 100644 --- a/lib/gsql/query.go +++ b/lib/gsql/query.go @@ -8,7 +8,7 @@ import ( func Query(client *Client, results []any, query string) { var q = packets.Query(query) - client.Do(NewQueryWriter(results...), q.IntoPacket(nil)) + client.Do(NewQueryWriter(results...), &q) } type QueryWriter struct { diff --git a/lib/gsql/result.go b/lib/gsql/result.go index 11c084500333e529c726cf7bdd05a63b97eef2e7..520d7b520ca74542aea1e49869e0338f56346c3e 100644 --- a/lib/gsql/result.go +++ b/lib/gsql/result.go @@ -3,5 +3,5 @@ package gsql import "gfx.cafe/gfx/pggat/lib/fed" type ResultWriter interface { - fed.Writer + WritePacket(fed.Packet) error } diff --git a/lib/gsql/row.go b/lib/gsql/row.go index 574e8e7738cc52e972109effc9411636fd60ddf6..d6e4b1afced56c3d5de7845dc5b9f9896ed93dc0 100644 --- a/lib/gsql/row.go +++ b/lib/gsql/row.go @@ -1,12 +1,12 @@ package gsql import ( - "errors" "reflect" "strconv" "gfx.cafe/gfx/pggat/lib/fed" packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" + "gfx.cafe/gfx/pggat/lib/perror" ) type RowWriter struct { @@ -28,10 +28,10 @@ func NewRowWriter(result any) *RowWriter { } func (T *RowWriter) set(i int, col []byte) error { - if i >= len(T.rd.Fields) { + if i >= len(T.rd) { return ErrExtraFields } - desc := T.rd.Fields[i] + desc := T.rd[i] result := T.result @@ -228,26 +228,28 @@ outer2: func (T *RowWriter) WritePacket(packet fed.Packet) error { switch packet.Type() { case packets.TypeRowDescription: - if !T.rd.ReadFromPacket(packet) { - return errors.New("invalid format") + rd, err := fed.ToConcrete[*packets.RowDescription](packet) + if err != nil { + return err } + T.rd = *rd case packets.TypeDataRow: - var dr packets.DataRow - if !dr.ReadFromPacket(packet) { - return errors.New("invalid format") + dr, err := fed.ToConcrete[*packets.DataRow](packet) + if err != nil { + return err } - for i, col := range dr.Columns { - if err := T.set(i, col); err != nil { + for i, col := range *dr { + if err = T.set(i, col); err != nil { return err } } T.row += 1 case packets.TypeErrorResponse: - var err packets.ErrorResponse - if !err.ReadFromPacket(packet) { - return errors.New("invalid format") + p, err := fed.ToConcrete[*packets.ErrorResponse](packet) + if err != nil { + return err } - return err.Error + return perror.FromPacket(p) case packets.TypeCommandComplete: T.done = true return nil diff --git a/lib/perror/packet.go b/lib/perror/packet.go index e025ce9c0cea2236192363d0e5cc1cd8d9c1e55b..606c1071ee62e6fd38263b74ec8e96e9dcc1cd6e 100644 --- a/lib/perror/packet.go +++ b/lib/perror/packet.go @@ -2,6 +2,36 @@ package perror import packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" +func FromPacket(packet *packets.ErrorResponse) Error { + var severity Severity + var code Code + var message string + var extra []ExtraField + + for _, field := range *packet { + switch field.Code { + case 'S': + severity = Severity(field.Value) + case 'C': + code = Code(field.Value) + case 'M': + message = field.Value + default: + extra = append(extra, ExtraField{ + Type: Extra(field.Code), + Value: field.Value, + }) + } + } + + return New( + severity, + code, + message, + extra..., + ) +} + func ToPacket(err Error) *packets.ErrorResponse { var resp packets.ErrorResponse resp = append(