diff --git a/lib/bouncer/backends/v0/accept.go b/lib/bouncer/backends/v0/accept.go index a1036acc1eb45e2583a80a41fa964622641fc703..c3281b934f7bbbdbf92eb87ba5bf3160a83ec008 100644 --- a/lib/bouncer/backends/v0/accept.go +++ b/lib/bouncer/backends/v0/accept.go @@ -199,7 +199,7 @@ func startup0(ctx *acceptContext) (done bool, err error) { } } -func startup1(ctx *acceptContext, params *acceptParams) (done bool, err error) { +func startup1(ctx *acceptContext) (done bool, err error) { ctx.Packet, err = ctx.Conn.ReadPacket(true, ctx.Packet) if err != nil { return @@ -207,7 +207,7 @@ func startup1(ctx *acceptContext, params *acceptParams) (done bool, err error) { switch ctx.Packet.Type() { case packets.TypeBackendKeyData: - ctx.Packet.ReadBytes(params.BackendKey[:]) + ctx.Packet.ReadBytes(ctx.Conn.BackendKey[:]) return false, nil case packets.TypeParameterStatus: var ps packets.ParameterStatus @@ -216,10 +216,10 @@ func startup1(ctx *acceptContext, params *acceptParams) (done bool, err error) { return } ikey := strutil.MakeCIString(ps.Key) - if params.InitialParameters == nil { - params.InitialParameters = make(map[strutil.CIString]string) + if ctx.Conn.InitialParameters == nil { + ctx.Conn.InitialParameters = make(map[strutil.CIString]string) } - params.InitialParameters[ikey] = ps.Value + ctx.Conn.InitialParameters[ikey] = ps.Value return false, nil case packets.TypeReadyForQuery: return true, nil @@ -248,7 +248,7 @@ func enableSSL(ctx *acceptContext) (bool, error) { return false, err } - byteReader, ok := ctx.Conn.(io.ByteReader) + byteReader, ok := ctx.Conn.ReadWriteCloser.(io.ByteReader) if !ok { return false, errors.New("server must be io.ByteReader to enable ssl") } @@ -264,7 +264,7 @@ func enableSSL(ctx *acceptContext) (bool, error) { return false, nil } - sslClient, ok := ctx.Conn.(fed.SSLClient) + sslClient, ok := ctx.Conn.ReadWriteCloser.(fed.SSLClient) if !ok { return false, errors.New("server must be fed.SSLClient to enable ssl") } @@ -276,23 +276,20 @@ func enableSSL(ctx *acceptContext) (bool, error) { return true, nil } -func accept(ctx *acceptContext) (acceptParams, error) { +func accept(ctx *acceptContext) error { username := ctx.Options.Username if ctx.Options.Database == "" { ctx.Options.Database = username } - var params acceptParams - if ctx.Options.SSLMode.ShouldAttempt() { - var err error - params.SSLEnabled, err = enableSSL(ctx) + sslEnabled, err := enableSSL(ctx) if err != nil { - return acceptParams{}, err + return err } - if !params.SSLEnabled && ctx.Options.SSLMode.IsRequired() { - return acceptParams{}, errors.New("server rejected SSL encryption") + if !sslEnabled && ctx.Options.SSLMode.IsRequired() { + return errors.New("server rejected SSL encryption") } } @@ -317,14 +314,14 @@ func accept(ctx *acceptContext) (acceptParams, error) { err := ctx.Conn.WritePacket(ctx.Packet) if err != nil { - return acceptParams{}, err + return err } for { var done bool done, err = startup0(ctx) if err != nil { - return acceptParams{}, err + return err } if done { break @@ -333,9 +330,9 @@ func accept(ctx *acceptContext) (acceptParams, error) { for { var done bool - done, err = startup1(ctx, ¶ms) + done, err = startup1(ctx) if err != nil { - return acceptParams{}, err + return err } if done { break @@ -343,23 +340,18 @@ func accept(ctx *acceptContext) (acceptParams, error) { } // startup complete, connection is ready for queries - return params, nil + return nil } func Accept( - conn fed.ReadWriter, + conn *fed.Conn, sslMode bouncer.SSLMode, sslConfig *tls.Config, username string, credentials auth.Credentials, database string, startupParameters map[strutil.CIString]string, -) ( - sslEnabled bool, - initialParameters map[strutil.CIString]string, - backendKey [8]byte, - err error, -) { +) error { ctx := acceptContext{ Conn: conn, Options: acceptOptions{ @@ -371,10 +363,5 @@ func Accept( StartupParameters: startupParameters, }, } - var params acceptParams - params, err = accept(&ctx) - sslEnabled = params.SSLEnabled - initialParameters = params.InitialParameters - backendKey = params.BackendKey - return + return accept(&ctx) } diff --git a/lib/bouncer/backends/v0/cancel.go b/lib/bouncer/backends/v0/cancel.go index 1ff2e957ab15fc15b14429e55f7656a40e1cf4cf..23769b8dbed161c791b0e64ee450820a24d8bb9d 100644 --- a/lib/bouncer/backends/v0/cancel.go +++ b/lib/bouncer/backends/v0/cancel.go @@ -2,7 +2,7 @@ package backends import "gfx.cafe/gfx/pggat/lib/fed" -func Cancel(server fed.ReadWriter, key [8]byte) error { +func Cancel(server *fed.Conn, key [8]byte) error { packet := fed.NewPacket(0, 12) packet = packet.AppendUint16(1234) packet = packet.AppendUint16(5678) diff --git a/lib/bouncer/backends/v0/context.go b/lib/bouncer/backends/v0/context.go index d74ffdeabd8a6219c38883ce11f50ec6cff6becb..01945b65e47630a4bdfa299abf446d1bbdc7f05c 100644 --- a/lib/bouncer/backends/v0/context.go +++ b/lib/bouncer/backends/v0/context.go @@ -6,14 +6,14 @@ import ( type acceptContext struct { Packet fed.Packet - Conn fed.ReadWriter + Conn *fed.Conn Options acceptOptions } type context struct { - Server fed.ReadWriter + Server *fed.Conn Packet fed.Packet - Peer fed.ReadWriter + Peer *fed.Conn PeerError error TxState byte } diff --git a/lib/bouncer/backends/v0/params.go b/lib/bouncer/backends/v0/params.go deleted file mode 100644 index d9ac2a63a4802c2443f73bac16ee377cd864103a..0000000000000000000000000000000000000000 --- a/lib/bouncer/backends/v0/params.go +++ /dev/null @@ -1,9 +0,0 @@ -package backends - -import "gfx.cafe/gfx/pggat/lib/util/strutil" - -type acceptParams struct { - SSLEnabled bool - InitialParameters map[strutil.CIString]string - BackendKey [8]byte -} diff --git a/lib/bouncer/backends/v0/query.go b/lib/bouncer/backends/v0/query.go index 118e8511aff8cf50888b860e9797b956b9168633..89a8bd3f82d758a7f924f14d7944e0285e4376ef 100644 --- a/lib/bouncer/backends/v0/query.go +++ b/lib/bouncer/backends/v0/query.go @@ -106,7 +106,7 @@ func queryString(ctx *context, q string) error { return query(ctx) } -func QueryString(server, peer fed.ReadWriter, buffer fed.Packet, query string) (err, peerError error, packet fed.Packet) { +func QueryString(server, peer *fed.Conn, buffer fed.Packet, query string) (err, peerError error, packet fed.Packet) { ctx := context{ Server: server, Peer: peer, @@ -118,7 +118,7 @@ func QueryString(server, peer fed.ReadWriter, buffer fed.Packet, query string) ( return } -func SetParameter(server, peer fed.ReadWriter, buffer fed.Packet, name strutil.CIString, value string) (err, peerError error, packet fed.Packet) { +func SetParameter(server, peer *fed.Conn, buffer fed.Packet, name strutil.CIString, value string) (err, peerError error, packet fed.Packet) { return QueryString( server, peer, @@ -213,7 +213,7 @@ func sync(ctx *context) (bool, error) { } } -func Sync(server, peer fed.ReadWriter, buffer fed.Packet) (err, peerErr error, packet fed.Packet) { +func Sync(server, peer *fed.Conn, buffer fed.Packet) (err, peerErr error, packet fed.Packet) { ctx := context{ Server: server, Peer: peer, @@ -305,7 +305,7 @@ func transaction(ctx *context) error { } } -func Transaction(server, peer fed.ReadWriter, initialPacket fed.Packet) (err, peerError error, packet fed.Packet) { +func Transaction(server, peer *fed.Conn, initialPacket fed.Packet) (err, peerError error, packet fed.Packet) { ctx := context{ Server: server, Peer: peer, diff --git a/lib/bouncer/bouncers/v2/bouncer.go b/lib/bouncer/bouncers/v2/bouncer.go index f39c52c9a4d0914b370ac7145ca8ec6596735534..573b8cc28a3ac92de6dcf1c5a03cfa5ace827877 100644 --- a/lib/bouncer/bouncers/v2/bouncer.go +++ b/lib/bouncer/bouncers/v2/bouncer.go @@ -7,7 +7,7 @@ import ( "gfx.cafe/gfx/pggat/lib/perror" ) -func clientFail(packet fed.Packet, client fed.ReadWriter, err perror.Error) fed.Packet { +func clientFail(packet fed.Packet, client *fed.Conn, err perror.Error) fed.Packet { // send fatal error to client resp := packets.ErrorResponse{ Error: err, @@ -17,7 +17,7 @@ func clientFail(packet fed.Packet, client fed.ReadWriter, err perror.Error) fed. return packet } -func Bounce(client, server fed.ReadWriter, initialPacket fed.Packet) (packet fed.Packet, clientError error, serverError error) { +func Bounce(client, server *fed.Conn, initialPacket fed.Packet) (packet fed.Packet, clientError error, serverError error) { serverError, clientError, packet = backends.Transaction(server, client, initialPacket) if clientError != nil { diff --git a/lib/bouncer/frontends/v0/accept.go b/lib/bouncer/frontends/v0/accept.go index ab061d866110cbf0f046ab6ce8b9c89113a60920..5fadb9e3e51ba004a9e06fa86f103944bb7b09fd 100644 --- a/lib/bouncer/frontends/v0/accept.go +++ b/lib/bouncer/frontends/v0/accept.go @@ -37,7 +37,7 @@ func startup0( done = true return case 5679: - byteWriter, ok := ctx.Conn.(io.ByteWriter) + byteWriter, ok := ctx.Conn.ReadWriteCloser.(io.ByteWriter) if !ok { err = perror.New( perror.FATAL, @@ -53,7 +53,7 @@ func startup0( return } - sslServer, ok := ctx.Conn.(fed.SSLServer) + sslServer, ok := ctx.Conn.ReadWriteCloser.(fed.SSLServer) if !ok { err = perror.Wrap(byteWriter.WriteByte('N')) return @@ -66,10 +66,9 @@ func startup0( if err = perror.Wrap(sslServer.EnableSSLServer(ctx.Options.SSLConfig)); err != nil { return } - params.SSLEnabled = true return case 5680: - byteWriter, ok := ctx.Conn.(io.ByteWriter) + byteWriter, ok := ctx.Conn.ReadWriteCloser.(io.ByteWriter) if !ok { err = perror.New( perror.FATAL, @@ -115,9 +114,9 @@ func startup0( switch key { case "user": - params.User = value + ctx.Conn.User = value case "database": - params.Database = value + ctx.Conn.Database = value case "options": fields := strings.Fields(value) for i := 0; i < len(fields); i++ { @@ -138,10 +137,10 @@ func startup0( ikey := strutil.MakeCIString(key) - if params.InitialParameters == nil { - params.InitialParameters = make(map[strutil.CIString]string) + if ctx.Conn.InitialParameters == nil { + ctx.Conn.InitialParameters = make(map[strutil.CIString]string) } - params.InitialParameters[ikey] = value + ctx.Conn.InitialParameters[ikey] = value default: err = perror.New( perror.FATAL, @@ -165,10 +164,10 @@ func startup0( } else { ikey := strutil.MakeCIString(key) - if params.InitialParameters == nil { - params.InitialParameters = make(map[strutil.CIString]string) + if ctx.Conn.InitialParameters == nil { + ctx.Conn.InitialParameters = make(map[strutil.CIString]string) } - params.InitialParameters[ikey] = value + ctx.Conn.InitialParameters[ikey] = value } } } @@ -186,7 +185,7 @@ func startup0( } } - if params.User == "" { + if ctx.Conn.User == "" { err = perror.New( perror.FATAL, perror.InvalidAuthorizationSpecification, @@ -194,8 +193,8 @@ func startup0( ) return } - if params.Database == "" { - params.Database = params.User + if ctx.Conn.Database == "" { + ctx.Conn.Database = ctx.Conn.User } done = true @@ -236,13 +235,9 @@ func accept(ctx *acceptContext) (acceptParams, perror.Error) { return params, nil } -func Accept(conn fed.ReadWriter, tlsConfig *tls.Config) ( +func Accept(conn *fed.Conn, tlsConfig *tls.Config) ( cancelKey [8]byte, isCanceling bool, - sslEnabled bool, - user string, - database string, - initialParameters map[strutil.CIString]string, err perror.Error, ) { ctx := acceptContext{ @@ -255,9 +250,5 @@ func Accept(conn fed.ReadWriter, tlsConfig *tls.Config) ( params, err = accept(&ctx) cancelKey = params.CancelKey isCanceling = params.IsCanceling - sslEnabled = params.SSLEnabled - user = params.User - database = params.Database - initialParameters = params.InitialParameters return } diff --git a/lib/bouncer/frontends/v0/authenticate.go b/lib/bouncer/frontends/v0/authenticate.go index 680002a1c354b4b51db15e82d0ba8c7da049f251..b5feb148d85e75b7ccf2975aee8ba74b36591289 100644 --- a/lib/bouncer/frontends/v0/authenticate.go +++ b/lib/bouncer/frontends/v0/authenticate.go @@ -142,7 +142,7 @@ func authenticationMD5(ctx *authenticateContext, creds auth.MD5Server) perror.Er return nil } -func authenticate0(ctx *authenticateContext) (params authenticateParams, err perror.Error) { +func authenticate0(ctx *authenticateContext) (err perror.Error) { if ctx.Options.Credentials != nil { if credsSASL, ok := ctx.Options.Credentials.(auth.SASLServer); ok { err = authenticationSASL(ctx, credsSASL) @@ -168,14 +168,14 @@ func authenticate0(ctx *authenticateContext) (params authenticateParams, err per } // send backend key data - _, err2 := rand.Read(params.BackendKey[:]) + _, err2 := rand.Read(ctx.Conn.BackendKey[:]) if err2 != nil { err = perror.Wrap(err2) return } keyData := packets.BackendKeyData{ - CancellationKey: params.BackendKey, + CancellationKey: ctx.Conn.BackendKey, } ctx.Packet = keyData.IntoPacket(ctx.Packet) if err = perror.Wrap(ctx.Conn.WritePacket(ctx.Packet)); err != nil { @@ -185,24 +185,22 @@ func authenticate0(ctx *authenticateContext) (params authenticateParams, err per return } -func authenticate(ctx *authenticateContext) (authenticateParams, perror.Error) { - params, err := authenticate0(ctx) +func authenticate(ctx *authenticateContext) perror.Error { + err := authenticate0(ctx) if err != nil { fail(ctx.Packet, ctx.Conn, err) - return authenticateParams{}, err + return err } - return params, nil + return nil } -func Authenticate(conn fed.ReadWriter, creds auth.Credentials) (backendKey [8]byte, err perror.Error) { +func Authenticate(conn *fed.Conn, creds auth.Credentials) (err perror.Error) { ctx := authenticateContext{ Conn: conn, Options: authenticateOptions{ Credentials: creds, }, } - var params authenticateParams - params, err = authenticate(&ctx) - backendKey = params.BackendKey + err = authenticate(&ctx) return } diff --git a/lib/bouncer/frontends/v0/context.go b/lib/bouncer/frontends/v0/context.go index aca67272ab2f3888527a1c7b9f13abe67c23d283..e0668cef9a24a5c7287da9562012e47f27cda543 100644 --- a/lib/bouncer/frontends/v0/context.go +++ b/lib/bouncer/frontends/v0/context.go @@ -4,12 +4,12 @@ import "gfx.cafe/gfx/pggat/lib/fed" type acceptContext struct { Packet fed.Packet - Conn fed.ReadWriter + Conn *fed.Conn Options acceptOptions } type authenticateContext struct { Packet fed.Packet - Conn fed.ReadWriter + Conn *fed.Conn Options authenticateOptions } diff --git a/lib/bouncer/frontends/v0/params.go b/lib/bouncer/frontends/v0/params.go index 2aba47cefdc5483d47cbd486e8c34fd9c0b949c7..0d960dbf3edbad11e492c22e899dd71e08824172 100644 --- a/lib/bouncer/frontends/v0/params.go +++ b/lib/bouncer/frontends/v0/params.go @@ -1,19 +1,6 @@ package frontends -import "gfx.cafe/gfx/pggat/lib/util/strutil" - type acceptParams struct { CancelKey [8]byte IsCanceling bool - - // or - - SSLEnabled bool - User string - Database string - InitialParameters map[strutil.CIString]string -} - -type authenticateParams struct { - BackendKey [8]byte } diff --git a/lib/fed/middlewares/eqp/sync.go b/lib/fed/middlewares/eqp/sync.go index 41f1e21eea165aa73e8275c33c1ec5a9aa06b8a9..150d8e0d9ef2824a7cbbdc4f9029dc92d29abf7d 100644 --- a/lib/fed/middlewares/eqp/sync.go +++ b/lib/fed/middlewares/eqp/sync.go @@ -6,7 +6,7 @@ import ( packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" ) -func Sync(c *Client, server fed.ReadWriter, s *Server) error { +func Sync(c *Client, server *fed.Conn, s *Server) error { var needsBackendSync bool // close all portals on server diff --git a/lib/fed/middlewares/ps/sync.go b/lib/fed/middlewares/ps/sync.go index e5a1cd4c8abd1048bd7c18ae07464c0493f2e6e7..be296b9c3421570d64d5d46460fdf623419c5325 100644 --- a/lib/fed/middlewares/ps/sync.go +++ b/lib/fed/middlewares/ps/sync.go @@ -8,7 +8,7 @@ import ( "gfx.cafe/gfx/pggat/lib/util/strutil" ) -func sync(tracking []strutil.CIString, client fed.ReadWriter, c *Client, server fed.ReadWriter, s *Server, name strutil.CIString) error { +func sync(tracking []strutil.CIString, client *fed.Conn, c *Client, server *fed.Conn, s *Server, name strutil.CIString) error { value, hasValue := c.parameters[name] expected, hasExpected := s.parameters[name] @@ -59,7 +59,7 @@ func sync(tracking []strutil.CIString, client fed.ReadWriter, c *Client, server return nil } -func Sync(tracking []strutil.CIString, client fed.ReadWriter, c *Client, server fed.ReadWriter, s *Server) (clientErr, serverErr error) { +func Sync(tracking []strutil.CIString, client *fed.Conn, c *Client, server *fed.Conn, s *Server) (clientErr, serverErr error) { for name := range c.parameters { if serverErr = sync(tracking, client, c, server, s, name); serverErr != nil { return diff --git a/lib/gat/app.go b/lib/gat/app.go index b8be69410f8340a22b6f9ee43f69a6233406d55a..a57b484b9788b00dea9f9c007bbc3acc9560684e 100644 --- a/lib/gat/app.go +++ b/lib/gat/app.go @@ -111,8 +111,7 @@ func (T *App) serve(server *Server, conn *fed.Conn) { return } - var err error - conn.BackendKey, err = frontends.Authenticate(conn.ReadWriteCloser, p.Credentials()) + err := frontends.Authenticate(conn, p.Credentials()) if err != nil { T.log.Warn("error authenticating client", zap.Error(err)) return @@ -141,7 +140,7 @@ func (T *App) accept(listener *Listener, conn *fed.Conn) { var cancelKey [8]byte var isCanceling bool var err error - cancelKey, isCanceling, _, conn.User, conn.Database, conn.InitialParameters, err = frontends.Accept(conn.ReadWriteCloser, tlsConfig) + cancelKey, isCanceling, err = frontends.Accept(conn, tlsConfig) if err != nil { T.log.Warn("error accepting client", zap.Error(err)) return diff --git a/lib/gat/pool/recipe/dialer.go b/lib/gat/pool/recipe/dialer.go index e9392c118f25220fafe6691f409af7aaf6ddf004..974e0bb3e0da38400f61b6119b413ea05508f3e3 100644 --- a/lib/gat/pool/recipe/dialer.go +++ b/lib/gat/pool/recipe/dialer.go @@ -35,8 +35,8 @@ func (T Dialer) Dial() (*fed.Conn, error) { ) conn.User = T.Username conn.Database = T.Database - _, conn.InitialParameters, conn.BackendKey, err = backends.Accept( - conn.ReadWriteCloser, + err = backends.Accept( + conn, T.SSLMode, T.SSLConfig, T.Username, @@ -55,7 +55,9 @@ func (T Dialer) Cancel(key [8]byte) error { if err != nil { return err } - conn := fed.NewNetConn(c) + conn := fed.NewConn( + fed.NewNetConn(c), + ) defer func() { _ = conn.Close() }() diff --git a/lib/gat/providers/discovery/discoverers/google_cloud_sql/discoverer.go b/lib/gat/providers/discovery/discoverers/google_cloud_sql/discoverer.go index 41ba08e46bf421edc240e9fc9c4c106edb3e88f7..d4fa4020c66aa4f546da7983b7a8d028611ee375 100644 --- a/lib/gat/providers/discovery/discoverers/google_cloud_sql/discoverer.go +++ b/lib/gat/providers/discovery/discoverers/google_cloud_sql/discoverer.go @@ -147,7 +147,7 @@ func (T *Discoverer) instanceToCluster(primary *sqladmin.DatabaseInstance, repli if err != nil { return discovery.Cluster{}, err } - _, err, err2 := bouncers.Bounce(client, admin, initialPacket) + _, err, err2 := bouncers.Bounce(fed.NewConn(client), admin, initialPacket) if err != nil { return discovery.Cluster{}, err }