diff --git a/lib/bouncer/bouncers/v2/bctx/context.go b/lib/bouncer/bouncers/v2/bctx/context.go index 94bfecad10222145baf2a167ab9904c50130a8f2..86368b39a894c964773ddd5fb5b938bf30bbe317 100644 --- a/lib/bouncer/bouncers/v2/bctx/context.go +++ b/lib/bouncer/bouncers/v2/bctx/context.go @@ -57,3 +57,19 @@ func (T *Context) ServerWrite(packet *zap.Packet) berr.Error { } return nil } + +func (T *Context) ClientWriteV(packets *zap.Packets) berr.Error { + err := T.client.WriteV(packets) + if err != nil { + return berr.MakeClient(err) + } + return nil +} + +func (T *Context) ServerWriteV(packets *zap.Packets) berr.Error { + err := T.server.WriteV(packets) + if err != nil { + return berr.MakeServer(err) + } + return nil +} diff --git a/lib/bouncer/bouncers/v2/bouncer.go b/lib/bouncer/bouncers/v2/bouncer.go index 6524238ea31368e26f5ed8f1f52a9e15957b83e0..3256a22bf0247b0338110b7fdc6ccf4b686bc18a 100644 --- a/lib/bouncer/bouncers/v2/bouncer.go +++ b/lib/bouncer/bouncers/v2/bouncer.go @@ -31,60 +31,63 @@ func serverRead(ctx *bctx.Context, packet *zap.Packet) berr.Error { } func copyIn(ctx *bctx.Context) berr.Error { - packet := zap.NewPacket() - defer packet.Done() + pkts := zap.NewPackets() + defer pkts.Done() for { + packet := zap.NewPacket() err := ctx.ClientRead(packet) if err != nil { + packet.Done() return err } switch packet.ReadType() { case packets.CopyData: - if err = ctx.ServerWrite(packet); err != nil { - return err - } + pkts.Append(packet) case packets.CopyDone, packets.CopyFail: - if err = ctx.ServerWrite(packet); err != nil { - return err - } + pkts.Append(packet) ctx.CopyIn = false - return nil + return ctx.ServerWriteV(pkts) default: + packet.Done() return berr.ClientUnexpectedPacket } } } func copyOut(ctx *bctx.Context) berr.Error { - packet := zap.NewPacket() - defer packet.Done() + pkts := zap.NewPackets() + defer pkts.Done() for { + packet := zap.NewPacket() err := serverRead(ctx, packet) if err != nil { + packet.Done() return err } switch packet.ReadType() { case packets.CopyData: - if err = ctx.ClientWrite(packet); err != nil { - return err - } + pkts.Append(packet) case packets.CopyDone, packets.ErrorResponse: + pkts.Append(packet) ctx.CopyOut = false - return ctx.ClientWrite(packet) + return ctx.ClientWriteV(pkts) default: + packet.Done() return berr.ServerUnexpectedPacket } } } func query(ctx *bctx.Context) berr.Error { - packet := zap.NewPacket() - defer packet.Done() + pkts := zap.NewPackets() + defer pkts.Done() for { + packet := zap.NewPacket() err := serverRead(ctx, packet) if err != nil { + packet.Done() return err } @@ -96,44 +99,50 @@ func query(ctx *bctx.Context) berr.Error { packets.DataRow, packets.EmptyQueryResponse, packets.ErrorResponse: - if err = ctx.ClientWrite(packet); err != nil { - return err - } + pkts.Append(packet) case packets.CopyInResponse: + pkts.Append(packet) ctx.CopyIn = true - if err = ctx.ClientWrite(packet); err != nil { + if err = ctx.ClientWriteV(pkts); err != nil { return err } + pkts.Clear() if err = copyIn(ctx); err != nil { return err } case packets.CopyOutResponse: + pkts.Append(packet) ctx.CopyOut = true - if err = ctx.ClientWrite(packet); err != nil { + if err = ctx.ClientWriteV(pkts); err != nil { return err } + pkts.Clear() if err = copyOut(ctx); err != nil { return err } case packets.ReadyForQuery: + pkts.Append(packet) ctx.Query = false var ok bool if ctx.TxState, ok = packets.ReadReadyForQuery(&read); !ok { return berr.ServerBadPacket } - return ctx.ClientWrite(packet) + return ctx.ClientWriteV(pkts) default: + packet.Done() return berr.ServerUnexpectedPacket } } } func functionCall(ctx *bctx.Context) berr.Error { - packet := zap.NewPacket() - defer packet.Done() + pkts := zap.NewPackets() + defer pkts.Done() for { + packet := zap.NewPacket() err := serverRead(ctx, packet) if err != nil { + packet.Done() return err } @@ -141,26 +150,30 @@ func functionCall(ctx *bctx.Context) berr.Error { switch read.ReadType() { case packets.ErrorResponse, packets.FunctionCallResponse: - if err = ctx.ClientWrite(packet); err != nil { - return err - } + pkts.Append(packet) case packets.ReadyForQuery: + pkts.Append(packet) ctx.FunctionCall = false var ok bool if ctx.TxState, ok = packets.ReadReadyForQuery(&read); !ok { return berr.ServerBadPacket } - return ctx.ClientWrite(packet) + return ctx.ClientWriteV(pkts) + default: + packet.Done() + return berr.ServerUnexpectedPacket } } } func sync(ctx *bctx.Context) berr.Error { - packet := zap.NewPacket() - defer packet.Done() + pkts := zap.NewPackets() + defer pkts.Done() for { + packet := zap.NewPacket() err := serverRead(ctx, packet) if err != nil { + packet.Done() return err } @@ -178,61 +191,67 @@ func sync(ctx *bctx.Context) berr.Error { packets.DataRow, packets.EmptyQueryResponse, packets.PortalSuspended: - err = ctx.ClientWrite(packet) - if err != nil { - return err - } + pkts.Append(packet) case packets.CopyInResponse: ctx.CopyIn = true - if err = ctx.ClientWrite(packet); err != nil { + pkts.Append(packet) + if err = ctx.ClientWriteV(pkts); err != nil { return err } + pkts.Clear() if err = copyIn(ctx); err != nil { return err } case packets.CopyOutResponse: ctx.CopyOut = true - if err = ctx.ClientWrite(packet); err != nil { + pkts.Append(packet) + if err = ctx.ClientWriteV(pkts); err != nil { return err } + pkts.Clear() if err = copyOut(ctx); err != nil { return err } case packets.ReadyForQuery: + pkts.Append(packet) ctx.Sync = false ctx.EQP = false var ok bool if ctx.TxState, ok = packets.ReadReadyForQuery(&read); !ok { return berr.ServerBadPacket } - return ctx.ClientWrite(packet) + return ctx.ClientWriteV(pkts) default: + packet.Done() return berr.ServerUnexpectedPacket } } } func eqp(ctx *bctx.Context) berr.Error { - packet := zap.NewPacket() - defer packet.Done() + pkts := zap.NewPackets() + defer pkts.Done() for { + packet := zap.NewPacket() err := ctx.ClientRead(packet) if err != nil { + packet.Done() return err } switch packet.ReadType() { case packets.Sync: - if err = ctx.ServerWrite(packet); err != nil { + pkts.Append(packet) + ctx.Sync = true + if err = ctx.ServerWriteV(pkts); err != nil { return err } - ctx.Sync = true + pkts.Clear() return sync(ctx) case packets.Parse, packets.Bind, packets.Close, packets.Describe, packets.Execute, packets.Flush: - if err = ctx.ServerWrite(packet); err != nil { - return err - } + pkts.Append(packet) default: + packet.Done() return berr.ClientUnexpectedPacket } } diff --git a/lib/zap/packet.go b/lib/zap/packet.go index 448d2104a4048b3406326cd3c5ce664a86221041..a71bad013b6f631105d5be10dfb68d79aea53bd6 100644 --- a/lib/zap/packet.go +++ b/lib/zap/packet.go @@ -109,6 +109,12 @@ func (T *Packets) Remove(i int) { T.order = T.order[:len(T.order)-1] } +func (T *Packets) Clear() { + T.order = T.order[:0] + T.packets = T.packets[:0] + T.untypedPackets = T.untypedPackets[:0] +} + func (T *Packets) Done() { // TODO(garet) }