diff --git a/cmd/cgat/main.go b/cmd/cgat/main.go index c58bc7fbe16034a23b36873bbf76fc2996dc2a87..ebda22774f5b8d0390eff7dcda85546773adf225 100644 --- a/cmd/cgat/main.go +++ b/cmd/cgat/main.go @@ -6,12 +6,11 @@ import ( _ "net/http/pprof" "pggat2/lib/bouncer/backends/v0" - "pggat2/lib/bouncer/bouncers/v0" + "pggat2/lib/bouncer/bouncers/v1" "pggat2/lib/bouncer/frontends/v0" "pggat2/lib/middleware/middlewares/onebuffer" "pggat2/lib/mw2" "pggat2/lib/mw2/interceptor" - "pggat2/lib/mw2/middlewares/eqp" "pggat2/lib/mw2/middlewares/unterminate" "pggat2/lib/rob" "pggat2/lib/rob/schedulers/v2" @@ -20,7 +19,6 @@ import ( ) type job struct { - eqpc *eqp.Client client zap.ReadWriter done chan<- struct{} } @@ -31,19 +29,11 @@ func testServer(r rob.Scheduler) { panic(err) } server := zio.MakeReadWriter(conn) - eqps := eqp.MakeServer() - mw := interceptor.MakeInterceptor( - &server, - []mw2.Middleware{ - &eqps, - }, - ) - backends.Accept(&mw) + backends.Accept(&server) sink := r.NewSink(0) for { j := sink.Read().(job) - eqps.SetClient(j.eqpc) - bouncers.Bounce(j.client, &mw) + bouncers.Bounce(j.client, &server) select { case j.done <- struct{}{}: default: @@ -72,10 +62,8 @@ func main() { source := r.NewSource() client := zio.MakeReadWriter(conn) ob := onebuffer.MakeOnebuffer(&client) - eqpc := eqp.MakeClient() mw := interceptor.MakeInterceptor(&ob, []mw2.Middleware{ unterminate.Unterminate, - &eqpc, }) frontends.Accept(&mw) done := make(chan struct{}) @@ -86,7 +74,6 @@ func main() { break } source.Schedule(job{ - eqpc: &eqpc, client: &mw, done: done, }, 0) diff --git a/lib/bouncer/bouncers/v0/bouncer.go b/lib/bouncer/bouncers/v0/bouncer.go index 5af359f57c676e631c2f37dd9427efe757fb65a3..e3de9c24a62fb94a5c05eb010848b8e5ab10d89d 100644 --- a/lib/bouncer/bouncers/v0/bouncer.go +++ b/lib/bouncer/bouncers/v0/bouncer.go @@ -165,11 +165,19 @@ func query0(client, server zap.ReadWriter) (done bool, status Status) { } return false, Ok case packets.ReadyForQuery: + state, ok := packets.ReadReadyForQuery(in) + if !ok { + serverFail(server, errors.New("bad packet")) + return false, Fail + } err = client.Send(zap.InToOut(in)) if err != nil { clientFail(client, perror.Wrap(err)) return false, Fail } + if state != 'I' { + return true, transaction(client, server) + } return true, Ok default: serverFail(server, errors.New("protocol error")) @@ -214,11 +222,19 @@ func functionCall0(client, server zap.ReadWriter) (done bool, status Status) { } return false, Ok case packets.ReadyForQuery: + state, ok := packets.ReadReadyForQuery(in) + if !ok { + serverFail(server, errors.New("bad packet")) + return false, Fail + } err = client.Send(zap.InToOut(in)) if err != nil { clientFail(client, perror.Wrap(err)) return false, Fail } + if state != 'I' { + return true, transaction(client, server) + } return true, Ok default: serverFail(server, errors.New("protocol error")) @@ -275,11 +291,19 @@ func sync0(client, server zap.ReadWriter) (done bool, status Status) { } return false, Ok case packets.ReadyForQuery: + state, ok := packets.ReadReadyForQuery(in) + if !ok { + serverFail(server, errors.New("bad packet")) + return false, Fail + } err = client.Send(zap.InToOut(in)) if err != nil { clientFail(client, perror.Wrap(err)) return false, Fail } + if state != 'I' { + return true, transaction(client, server) + } return true, Ok default: log.Printf("operation %c", in.Type()) @@ -339,7 +363,7 @@ func eqp(client, server zap.ReadWriter, in zap.In) (status Status) { } } -func Bounce(client, server zap.ReadWriter) { +func transaction(client, server zap.ReadWriter) (status Status) { in, err := client.Read() if err != nil { clientFail(client, perror.Wrap(err)) @@ -348,11 +372,11 @@ func Bounce(client, server zap.ReadWriter) { switch in.Type() { case packets.Query: - query(client, server, in) + return query(client, server, in) case packets.FunctionCall: - functionCall(client, server, in) + return functionCall(client, server, in) case packets.Sync, packets.Parse, packets.Bind, packets.Describe, packets.Execute: - eqp(client, server, in) + return eqp(client, server, in) default: log.Printf("operation %c", in.Type()) clientFail(client, perror.New( @@ -360,6 +384,10 @@ func Bounce(client, server zap.ReadWriter) { perror.FeatureNotSupported, "unsupported operation", )) - return + return Fail } } + +func Bounce(client, server zap.ReadWriter) { + transaction(client, server) +} diff --git a/lib/bouncer/bouncers/v1/bouncer.go b/lib/bouncer/bouncers/v1/bouncer.go new file mode 100644 index 0000000000000000000000000000000000000000..6d1ed212a10bf0a22ee563bfc71a38da094d6b71 --- /dev/null +++ b/lib/bouncer/bouncers/v1/bouncer.go @@ -0,0 +1,172 @@ +package bouncers + +import ( + "errors" + "log" + + "pggat2/lib/perror" + "pggat2/lib/zap" + packets "pggat2/lib/zap/packets/v3.0" +) + +type queryContext struct { + *transactionContext + done bool +} + +func (T *queryContext) queryDone() { + T.done = true +} + +func query0(ctx *queryContext) Error { + in, err := ctx.readServer() + if err != nil { + return err + } + + switch in.Type() { + case packets.CommandComplete, + packets.RowDescription, + packets.DataRow, + packets.EmptyQueryResponse, + packets.ErrorResponse, + packets.NoticeResponse, + packets.ParameterStatus: + return ctx.sendClient(zap.InToOut(in)) + case packets.CopyInResponse: + // return copyIn(ctx, in) + return nil + case packets.CopyOutResponse: + // return copyOut(ctx, in) + return nil + case packets.ReadyForQuery: + state, ok := packets.ReadReadyForQuery(in) + if !ok { + return makeClientError(packets.ErrBadFormat) + } + err = ctx.sendClient(zap.InToOut(in)) + if err != nil { + return err + } + ctx.queryDone() + if state == 'I' { + ctx.transactionDone() + } + return nil + default: + return makeServerError(errors.New("protocol error")) + } +} + +func query(c *transactionContext, in zap.In) Error { + // send in (initial query) to server + err := c.sendServer(zap.InToOut(in)) + if err != nil { + return err + } + + ctx := queryContext{ + transactionContext: c, + } + for !ctx.done { + err = query0(&ctx) + if err != nil { + return err + } + } + return nil +} + +type transactionContext struct { + *context + done bool +} + +func (T *transactionContext) transactionDone() { + T.done = true +} + +func transaction0(ctx *transactionContext) Error { + in, err := ctx.readClient() + if err != nil { + return err + } + + switch in.Type() { + case packets.Query: + return query(ctx, in) + case packets.FunctionCall: + // return functionCall(ctx, in) + return nil + case packets.Sync, packets.Parse, packets.Bind, packets.Describe, packets.Execute: + // return eqp(ctx, in) + return nil + default: + return makeClientError(perror.New( + perror.ERROR, + perror.FeatureNotSupported, + "unsupported operation", + )) + } +} + +func transaction(c *context) Error { + ctx := transactionContext{ + context: c, + } + for !ctx.done { + err := transaction0(&ctx) + if err != nil { + return err + } + } + return nil +} + +type context struct { + client, server zap.ReadWriter +} + +func (T *context) readClient() (zap.In, Error) { + in, err := T.client.Read() + if err != nil { + return zap.In{}, wrapClientError(err) + } + return in, nil +} + +func (T *context) readServer() (zap.In, Error) { + in, err := T.server.Read() + if err != nil { + return zap.In{}, makeServerError(err) + } + return in, nil +} + +func (T *context) sendClient(out zap.Out) Error { + err := T.client.Send(out) + if err != nil { + return wrapClientError(err) + } + return nil +} + +func (T *context) sendServer(out zap.Out) Error { + err := T.server.Send(out) + if err != nil { + return makeServerError(err) + } + return nil +} + +func Bounce(client, server zap.ReadWriter) { + ctx := context{ + client: client, + server: server, + } + err := transaction(&ctx) + if err != nil { + // TODO(garet) handle error + log.Println(err) + } +} diff --git a/lib/bouncer/bouncers/v1/errors.go b/lib/bouncer/bouncers/v1/errors.go new file mode 100644 index 0000000000000000000000000000000000000000..35e89cc52eba39e8b6ce2d7d76b9c469b13d99fd --- /dev/null +++ b/lib/bouncer/bouncers/v1/errors.go @@ -0,0 +1,35 @@ +package bouncers + +import "pggat2/lib/perror" + +type Error interface { + bounceError() +} + +type ClientError struct { + Error perror.Error +} + +func makeClientError(err perror.Error) ClientError { + return ClientError{ + Error: err, + } +} + +func wrapClientError(err error) ClientError { + return makeClientError(perror.Wrap(err)) +} + +func (ClientError) bounceError() {} + +type ServerError struct { + Error error +} + +func makeServerError(err error) ServerError { + return ServerError{ + Error: err, + } +} + +func (ServerError) bounceError() {}