diff --git a/cmd/cgat/main.go b/cmd/cgat/main.go index 6a36d596002387f037a496c34bd1bbdf7822f7ff..78fe305a39f7fe69a701f01d99f97685ef68e4a6 100644 --- a/cmd/cgat/main.go +++ b/cmd/cgat/main.go @@ -8,6 +8,7 @@ import ( "pggat2/lib/bouncer/backends/v0" "pggat2/lib/bouncer/bouncers/v0" "pggat2/lib/bouncer/frontends/v0" + "pggat2/lib/middleware/middlewares/eqp" "pggat2/lib/middleware/middlewares/unread" "pggat2/lib/middleware/middlewares/unterminate" "pggat2/lib/pnet" @@ -27,10 +28,11 @@ func testServer(r rob.Scheduler) { } server := pnet.MakeIOReadWriter(conn) backends.Accept(&server) + consumer := eqp.MakeConsumer(&server) sink := r.NewSink(0) for { j := sink.Read().(job) - bouncers.Bounce(j.client, &server) + bouncers.Bounce(j.client, &consumer) select { case j.done <- struct{}{}: default: @@ -60,10 +62,11 @@ func main() { client := pnet.MakeIOReadWriter(conn) ut := unterminate.MakeUnterminate(&client) frontends.Accept(ut) + creator := eqp.MakeCreator(ut) done := make(chan struct{}) defer close(done) for { - u, err := unread.NewUnread(ut) + u, err := unread.NewUnread(&creator) if err != nil { break } diff --git a/lib/bouncer/backends/v0/accept.go b/lib/bouncer/backends/v0/accept.go index 7687d1da1de99e9da20a379cdce3394b46462bf8..a1d5d0a6f043170023c07b17f20e25309d439828 100644 --- a/lib/bouncer/backends/v0/accept.go +++ b/lib/bouncer/backends/v0/accept.go @@ -251,6 +251,8 @@ func Accept(server pnet.ReadWriter) { // TODO(garet) don't hardcode username and password out.String("user") out.String("postgres") + out.String("database") + out.String("uniswap") out.String("") err := server.Send(out.Finish()) diff --git a/lib/bouncer/bouncers/v0/bouncer.go b/lib/bouncer/bouncers/v0/bouncer.go index 24b2a430b7defc44b8f0ff291053f76ea943d0f2..250e90173863496ad7a13ccaf1b7e5ce3a3de875 100644 --- a/lib/bouncer/bouncers/v0/bouncer.go +++ b/lib/bouncer/bouncers/v0/bouncer.go @@ -2,10 +2,13 @@ package bouncers import ( "errors" + "log" + "runtime/debug" "pggat2/lib/perror" "pggat2/lib/pnet" "pggat2/lib/pnet/packet" + packets "pggat2/lib/pnet/packet/packets/v3.0" ) type Status int @@ -16,7 +19,13 @@ const ( ) func clientFail(client pnet.ReadWriter, err perror.Error) { - panic(err) + // DEBUG(garet) + log.Println("client fail", err) + debug.PrintStack() + + out := client.Write() + packets.WriteErrorResponse(out, err) + _ = client.Send(out.Finish()) } func serverFail(server pnet.ReadWriter, err error) { @@ -239,6 +248,98 @@ func functionCall(client, server pnet.ReadWriter, in packet.In) (status Status) return Ok } +func sync0(client, server pnet.ReadWriter) (done bool, status Status) { + in, err := server.Read() + if err != nil { + serverFail(server, err) + return false, Fail + } + + switch in.Type() { + case packet.ParseComplete, + packet.BindComplete, + packet.ErrorResponse, + packet.RowDescription, + packet.NoData, + packet.ParameterDescription, + + packet.CommandComplete, + packet.DataRow, + packet.EmptyQueryResponse, + packet.NoticeResponse, + packet.ParameterStatus, + packet.PortalSuspended: + err = pnet.ProxyPacket(client, in) + if err != nil { + clientFail(client, perror.Wrap(err)) + return false, Fail + } + return false, Ok + case packet.ReadyForQuery: + err = pnet.ProxyPacket(client, in) + if err != nil { + clientFail(client, perror.Wrap(err)) + return false, Fail + } + return true, Ok + default: + log.Printf("operation %c", in.Type()) + serverFail(server, errors.New("protocol error")) + return false, Fail + } +} + +func sync(client, server pnet.ReadWriter, in packet.In) (status Status) { + // send initial (sync) to server + err := pnet.ProxyPacket(server, in) + if err != nil { + serverFail(server, err) + return Fail + } + + // relay everything until ready for query + for { + var done bool + done, status = sync0(client, server) + if status != Ok { + return + } + if done { + break + } + } + return Ok +} + +func eqp(client, server pnet.ReadWriter, in packet.In) (status Status) { + for { + switch in.Type() { + case packet.Sync: + return sync(client, server, in) + case packet.Parse, packet.Bind, packet.Describe, packet.Execute: + err := pnet.ProxyPacket(server, in) + if err != nil { + serverFail(server, err) + return Fail + } + default: + log.Printf("operation %c", in.Type()) + clientFail(client, perror.New( + perror.ERROR, + perror.FeatureNotSupported, + "unsupported operation", + )) + return Fail + } + var err error + in, err = client.Read() + if err != nil { + clientFail(client, perror.Wrap(err)) + return Fail + } + } +} + func Bounce(client, server pnet.ReadWriter) { in, err := client.Read() if err != nil { @@ -251,7 +352,10 @@ func Bounce(client, server pnet.ReadWriter) { query(client, server, in) case packet.FunctionCall: functionCall(client, server, in) + case packet.Sync, packet.Parse, packet.Bind, packet.Describe, packet.Execute: + eqp(client, server, in) default: + log.Printf("operation %c", in.Type()) clientFail(client, perror.New( perror.ERROR, perror.FeatureNotSupported, diff --git a/lib/bouncer/frontends/v0/accept.go b/lib/bouncer/frontends/v0/accept.go index a1e051e5f9fb8e5a020f7bdb2cad3982c7e50323..a160b66a7bd58bac9e5c9e06d85072a8b00b1ee6 100644 --- a/lib/bouncer/frontends/v0/accept.go +++ b/lib/bouncer/frontends/v0/accept.go @@ -2,6 +2,8 @@ package frontends import ( "crypto/rand" + "log" + "runtime/debug" "strings" "pggat2/lib/auth/sasl" @@ -18,10 +20,13 @@ const ( ) func fail(client pnet.ReadWriter, err perror.Error) { + // DEBUG(garet) + log.Println("client fail", err) + debug.PrintStack() + out := client.Write() packets.WriteErrorResponse(out, err) _ = client.Send(out.Finish()) - panic(err) } func startup0(client pnet.ReadWriter) (done bool, status Status) { diff --git a/lib/middleware/middlewares/eqp/consumer.go b/lib/middleware/middlewares/eqp/consumer.go index b8db773e8cdd65a407793f84e1b4b91a344fb386..bd364d3f6ff27a300e8e8f3828a2c48b3af22a3a 100644 --- a/lib/middleware/middlewares/eqp/consumer.go +++ b/lib/middleware/middlewares/eqp/consumer.go @@ -4,12 +4,18 @@ import ( "pggat2/lib/pnet" "pggat2/lib/pnet/packet" packets "pggat2/lib/pnet/packet/packets/v3.0" + "pggat2/lib/util/decorator" + "pggat2/lib/util/ring" ) type Consumer struct { - preparedStatements map[string]PreparedStatement - portals map[string]Portal - inner pnet.ReadWriter + noCopy decorator.NoCopy + + preparedStatements map[string]PreparedStatement + portals map[string]Portal + pendingPreparedStatements ring.Ring[string] + pendingPortals ring.Ring[string] + inner pnet.ReadWriter } func MakeConsumer(inner pnet.ReadWriter) Consumer { @@ -20,23 +26,46 @@ func MakeConsumer(inner pnet.ReadWriter) Consumer { } } -func (T Consumer) Read() (packet.In, error) { - return T.inner.Read() +func NewConsumer(inner pnet.ReadWriter) *Consumer { + c := MakeConsumer(inner) + return &c +} + +func (T *Consumer) Read() (packet.In, error) { + in, err := T.inner.Read() + if err != nil { + return packet.In{}, err + } + switch in.Type() { + case packet.ParseComplete: + T.pendingPreparedStatements.PopFront() + case packet.BindComplete: + T.pendingPortals.PopFront() + case packet.ReadyForQuery: + // remove all pending, they were not added. + for pending, ok := T.pendingPreparedStatements.PopFront(); ok; pending, ok = T.pendingPreparedStatements.PopFront() { + delete(T.preparedStatements, pending) + } + for pending, ok := T.pendingPortals.PopFront(); ok; pending, ok = T.pendingPortals.PopFront() { + delete(T.portals, pending) + } + } + return in, nil } -func (T Consumer) ReadUntyped() (packet.In, error) { +func (T *Consumer) ReadUntyped() (packet.In, error) { return T.inner.ReadUntyped() } -func (T Consumer) Write() packet.Out { +func (T *Consumer) Write() packet.Out { return T.inner.Write() } -func (T Consumer) WriteByte(b byte) error { +func (T *Consumer) WriteByte(b byte) error { return T.inner.WriteByte(b) } -func (T Consumer) Send(typ packet.Type, bytes []byte) error { +func (T *Consumer) Send(typ packet.Type, bytes []byte) error { buf := packet.MakeInBuf(typ, bytes) in := packet.MakeIn(&buf) switch typ { @@ -58,6 +87,7 @@ func (T Consumer) Send(typ packet.Type, bytes []byte) error { Query: query, ParameterDataTypes: parameterDataTypes, } + T.pendingPreparedStatements.PushBack(destination) case packet.Bind: destination, source, parameterFormatCodes, parameterValues, resultFormatCodes, ok := packets.ReadBind(in) if !ok { @@ -74,6 +104,7 @@ func (T Consumer) Send(typ packet.Type, bytes []byte) error { ParameterValues: parameterValues, ResultFormatCodes: resultFormatCodes, } + T.pendingPortals.PushBack(destination) case packet.Close: which, target, ok := packets.ReadClose(in) if !ok { @@ -91,4 +122,4 @@ func (T Consumer) Send(typ packet.Type, bytes []byte) error { return T.inner.Send(typ, bytes) } -var _ pnet.ReadWriter = Consumer{} +var _ pnet.ReadWriter = (*Consumer)(nil) diff --git a/lib/middleware/middlewares/eqp/creator.go b/lib/middleware/middlewares/eqp/creator.go index b10b72df3fbdffc7efff785c800e24eea47e8684..6c1baa692685e1a9420e518fefcd03cc0866a9e1 100644 --- a/lib/middleware/middlewares/eqp/creator.go +++ b/lib/middleware/middlewares/eqp/creator.go @@ -4,12 +4,18 @@ import ( "pggat2/lib/pnet" "pggat2/lib/pnet/packet" packets "pggat2/lib/pnet/packet/packets/v3.0" + "pggat2/lib/util/decorator" + "pggat2/lib/util/ring" ) type Creator struct { - preparedStatements map[string]PreparedStatement - portals map[string]Portal - inner pnet.ReadWriter + noCopy decorator.NoCopy + + preparedStatements map[string]PreparedStatement + portals map[string]Portal + pendingPreparedStatements ring.Ring[string] + pendingPortals ring.Ring[string] + inner pnet.ReadWriter } func MakeCreator(inner pnet.ReadWriter) Creator { @@ -20,7 +26,12 @@ func MakeCreator(inner pnet.ReadWriter) Creator { } } -func (T Creator) Read() (packet.In, error) { +func NewCreator(inner pnet.ReadWriter) *Creator { + c := MakeCreator(inner) + return &c +} + +func (T *Creator) Read() (packet.In, error) { for { in, err := T.inner.Read() if err != nil { @@ -37,10 +48,16 @@ func (T Creator) Read() (packet.In, error) { if !ok { return packet.In{}, ErrBadPacketFormat } + if destination != "" { + if _, ok = T.preparedStatements[destination]; ok { + return packet.In{}, ErrPreparedStatementExists + } + } T.preparedStatements[destination] = PreparedStatement{ Query: query, ParameterDataTypes: parameterDataTypes, } + T.pendingPreparedStatements.PushBack(destination) // send parse complete out := T.inner.Write() @@ -54,12 +71,18 @@ func (T Creator) Read() (packet.In, error) { if !ok { return packet.In{}, ErrBadPacketFormat } + if destination != "" { + if _, ok = T.portals[destination]; ok { + return packet.In{}, ErrPortalExists + } + } T.portals[destination] = Portal{ Source: source, ParameterFormatCodes: parameterFormatCodes, ParameterValues: parameterValues, ResultFormatCodes: resultFormatCodes, } + T.pendingPortals.PushBack(destination) // send bind complete out := T.inner.Write() @@ -87,20 +110,34 @@ func (T Creator) Read() (packet.In, error) { } } -func (T Creator) ReadUntyped() (packet.In, error) { +func (T *Creator) ReadUntyped() (packet.In, error) { return T.inner.ReadUntyped() } -func (T Creator) Write() packet.Out { +func (T *Creator) Write() packet.Out { return T.inner.Write() } -func (T Creator) WriteByte(b byte) error { +func (T *Creator) WriteByte(b byte) error { return T.inner.WriteByte(b) } -func (T Creator) Send(typ packet.Type, payload []byte) error { +func (T *Creator) Send(typ packet.Type, payload []byte) error { + switch typ { + case packet.ParseComplete: + T.pendingPreparedStatements.PopFront() + case packet.BindComplete: + T.pendingPortals.PopFront() + case packet.ReadyForQuery: + // remove all pending, they were not added. + for pending, ok := T.pendingPreparedStatements.PopFront(); ok; pending, ok = T.pendingPreparedStatements.PopFront() { + delete(T.preparedStatements, pending) + } + for pending, ok := T.pendingPortals.PopFront(); ok; pending, ok = T.pendingPortals.PopFront() { + delete(T.portals, pending) + } + } return T.inner.Send(typ, payload) } -var _ pnet.ReadWriter = Creator{} +var _ pnet.ReadWriter = (*Creator)(nil) diff --git a/lib/middleware/middlewares/eqp/stealer.go b/lib/middleware/middlewares/eqp/stealer.go index ee25753042a9b58a79662e68b71b2500c58b1106..3d59d7829825bedc2a2d0b23b9f6035ce46a78ea 100644 --- a/lib/middleware/middlewares/eqp/stealer.go +++ b/lib/middleware/middlewares/eqp/stealer.go @@ -10,14 +10,14 @@ import ( // Stealer wraps a Consumer and duplicates the underlying Consumer's portals and prepared statements on use. type Stealer struct { - creator Creator - consumer Consumer + creator *Creator + consumer *Consumer // need a second buf because we cannot use the underlying Consumer's buf (or it would overwrite the outgoing packet) buf packet.OutBuf } -func NewStealer(consumer Consumer, creator Creator) *Stealer { +func NewStealer(consumer *Consumer, creator *Creator) *Stealer { return &Stealer{ creator: creator, consumer: consumer, @@ -25,7 +25,21 @@ func NewStealer(consumer Consumer, creator Creator) *Stealer { } func (T *Stealer) Read() (packet.In, error) { - return T.consumer.Read() + for { + in, err := T.consumer.Read() + if err != nil { + return packet.In{}, err + } + switch in.Type() { + case packet.ParseComplete: + // previous parse was successful + case packet.BindComplete: + // previous bind was successful + default: + // forward + return in, nil + } + } } func (T *Stealer) ReadUntyped() (packet.In, error) { @@ -54,22 +68,64 @@ func (T *Stealer) bindPortal(target string, portal Portal) error { return T.consumer.Send(out.Finish()) } +func (T *Stealer) closePreparedStatement(target string) error { + if _, ok := T.consumer.preparedStatements[target]; !ok { + // doesn't exist + return nil + } + T.buf.Reset() + out := packet.MakeOut(&T.buf) + packets.WriteClose(out, 'S', target) + return T.consumer.Send(out.Finish()) +} + +func (T *Stealer) closePortal(target string) error { + if _, ok := T.consumer.portals[target]; !ok { + // doesn't exist + return nil + } + T.buf.Reset() + out := packet.MakeOut(&T.buf) + packets.WriteClose(out, 'P', target) + return T.consumer.Send(out.Finish()) +} + func (T *Stealer) syncPreparedStatement(target string) error { - creatorStatement := T.creator.preparedStatements[target] - consumerStatement := T.consumer.preparedStatements[target] + creatorStatement, ok := T.creator.preparedStatements[target] + if !ok { + return T.closePreparedStatement(target) + } + consumerStatement, prev := T.consumer.preparedStatements[target] if creatorStatement.Equals(consumerStatement) { return nil } + // clean up prev + if prev { + err := T.closePreparedStatement(target) + if err != nil { + return err + } + } // send prepared statement return T.bindPreparedStatement(target, creatorStatement) } func (T *Stealer) syncPortal(target string) error { - creatorPortal := T.creator.portals[target] - consumerPortal := T.consumer.portals[target] + creatorPortal, ok := T.creator.portals[target] + if !ok { + return T.closePortal(target) + } + consumerPortal, prev := T.consumer.portals[target] if creatorPortal.Equals(consumerPortal) { return nil } + // clean up prev + if prev { + err := T.closePortal(target) + if err != nil { + return err + } + } // send portal return T.bindPortal(target, creatorPortal) }