diff --git a/lib/bouncer/backends/v0/query.go b/lib/bouncer/backends/v0/query.go index 89df7fbb87436cdac4c4bce893ce2b43f27329ac..44e57a79cbd5e0d9727d3ffbc1ed942190db6116 100644 --- a/lib/bouncer/backends/v0/query.go +++ b/lib/bouncer/backends/v0/query.go @@ -142,15 +142,15 @@ func FunctionCall(ctx *Context, server fed.ReadWriter, packet fed.Packet) error } } -func Sync(ctx *Context, server fed.ReadWriter) error { +func Sync(ctx *Context, server fed.ReadWriter) (bool, error) { if err := server.WritePacket(fed.NewPacket(packets.TypeSync)); err != nil { - return err + return false, err } for { packet, err := server.ReadPacket(true) if err != nil { - return err + return false, err } switch packet.Type() { @@ -173,22 +173,24 @@ func Sync(ctx *Context, server fed.ReadWriter) error { ctx.PeerWrite(packet) case packets.TypeCopyInResponse: if err = CopyIn(ctx, server, packet); err != nil { - return err + return false, err } + // why + return false, nil case packets.TypeCopyOutResponse: if err = CopyOut(ctx, server, packet); err != nil { - return err + return false, err } case packets.TypeReadyForQuery: var txState packets.ReadyForQuery if !txState.ReadFromPacket(packet) { - return ErrBadFormat + return false, ErrBadFormat } ctx.TxState = byte(txState) ctx.PeerWrite(packet) - return nil + return true, nil default: - return ErrUnexpectedPacket + return false, ErrUnexpectedPacket } } } @@ -201,12 +203,26 @@ func EQP(ctx *Context, server fed.ReadWriter, packet fed.Packet) error { for { packet = ctx.PeerRead() if packet == nil { - return Sync(ctx, server) + for { + ok, err := Sync(ctx, server) + if err != nil { + return err + } + if ok { + return nil + } + } } switch packet.Type() { case packets.TypeSync: - return Sync(ctx, server) + ok, err := Sync(ctx, server) + if err != nil { + return err + } + if ok { + return nil + } case packets.TypeParse, packets.TypeBind, packets.TypeClose, packets.TypeDescribe, packets.TypeExecute, packets.TypeFlush: if err := server.WritePacket(packet); err != nil { return err diff --git a/lib/middleware/middlewares/eqp/sync.go b/lib/middleware/middlewares/eqp/sync.go index 848ec954ca9ee9446c52392636143ded302193c7..57b818396db2faa79e0b92cfbaf180d5c9b3e9a6 100644 --- a/lib/middleware/middlewares/eqp/sync.go +++ b/lib/middleware/middlewares/eqp/sync.go @@ -64,5 +64,6 @@ func Sync(c *Client, server fed.ReadWriter, s *Server) error { } } - return backends.Sync(new(backends.Context), server) + _, err := backends.Sync(new(backends.Context), server) + return err } diff --git a/test/tester_test.go b/test/tester_test.go index a10017679573ca400fa17652a4d1b5e396e2730d..0365a5bcd433ca2a7fa80bf351276bc476599aa5 100644 --- a/test/tester_test.go +++ b/test/tester_test.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "fmt" "net" + _ "net/http/pprof" "strconv" "testing" diff --git a/test/tests/copy_in.go b/test/tests/copy_in.go index 18f85261b5c46b411ec285233427e84d027d2f18..1bbe5bec23f902dc0c36421e231c4dd1a146b696 100644 --- a/test/tests/copy_in.go +++ b/test/tests/copy_in.go @@ -34,6 +34,7 @@ var CopyIn1 = test.Test{ inst.CopyData{49, 50, 51, 9, 104, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100, 10}, inst.CopyData{45, 51, 50, 52, 9, 103, 97, 114, 101, 116, 32, 119, 97, 115, 32, 104, 101, 114, 101, 10}, inst.CopyDone{}, + inst.Sync{}, inst.SimpleQuery("DROP TABLE test"), }, }