diff --git a/lib/bouncer/bouncers/v2/bctx/context.go b/lib/bouncer/bouncers/v2/bctx/context.go index 02cc0d282a227b7d5f2c5f158408cb6aca23f3ea..63d3e2439ebe69b39f312a809a6d88cceb0a299e 100644 --- a/lib/bouncer/bouncers/v2/bctx/context.go +++ b/lib/bouncer/bouncers/v2/bctx/context.go @@ -7,7 +7,15 @@ import ( type Context struct { client, server zap.ReadWriter - TxState byte + + // state (for flow and recovery) + CopyOut bool + CopyIn bool + Query bool + FunctionCall bool + Sync bool + EQP bool + TxState byte } func MakeContext(client, server zap.ReadWriter) Context { diff --git a/lib/bouncer/bouncers/v2/bouncer.go b/lib/bouncer/bouncers/v2/bouncer.go index eb315f5452066c42d86a73b6d50086de41254bc3..1ab7f3e5df4b342c7fb24b4fcb4f2a944db5a6d0 100644 --- a/lib/bouncer/bouncers/v2/bouncer.go +++ b/lib/bouncer/bouncers/v2/bouncer.go @@ -1,8 +1,6 @@ package bouncers import ( - "log" - "pggat2/lib/bouncer/bouncers/v2/bctx" "pggat2/lib/bouncer/bouncers/v2/berr" "pggat2/lib/bouncer/bouncers/v2/rserver" @@ -43,7 +41,11 @@ func copyIn(ctx *bctx.Context) berr.Error { return err } case packets.CopyDone, packets.CopyFail: - return ctx.ServerProxy(in) + if err = ctx.ServerProxy(in); err != nil { + return err + } + ctx.CopyIn = false + return nil default: return berr.ClientUnexpectedPacket } @@ -63,9 +65,9 @@ func copyOut(ctx *bctx.Context) berr.Error { return err } case packets.CopyDone, packets.ErrorResponse: + ctx.CopyOut = false return ctx.ClientProxy(in) default: - log.Println("a") return berr.ServerUnexpectedPacket } } @@ -88,6 +90,7 @@ func query(ctx *bctx.Context) berr.Error { return err } case packets.CopyInResponse: + ctx.CopyIn = true if err = ctx.ClientProxy(in); err != nil { return err } @@ -95,6 +98,7 @@ func query(ctx *bctx.Context) berr.Error { return err } case packets.CopyOutResponse: + ctx.CopyOut = true if err = ctx.ClientProxy(in); err != nil { return err } @@ -102,13 +106,13 @@ func query(ctx *bctx.Context) berr.Error { return err } case packets.ReadyForQuery: + ctx.Query = false var ok bool if ctx.TxState, ok = packets.ReadReadyForQuery(in); !ok { return berr.ServerBadPacket } return ctx.ClientProxy(in) default: - log.Println("b") return berr.ServerUnexpectedPacket } } @@ -127,6 +131,7 @@ func functionCall(ctx *bctx.Context) berr.Error { return err } case packets.ReadyForQuery: + ctx.FunctionCall = false var ok bool if ctx.TxState, ok = packets.ReadReadyForQuery(in); !ok { return berr.ServerBadPacket @@ -159,14 +164,31 @@ func sync(ctx *bctx.Context) berr.Error { if err != nil { return err } + case packets.CopyInResponse: + ctx.CopyIn = true + if err = ctx.ClientProxy(in); err != nil { + return err + } + if err = copyIn(ctx); err != nil { + return err + } + case packets.CopyOutResponse: + ctx.CopyOut = true + if err = ctx.ClientProxy(in); err != nil { + return err + } + if err = copyOut(ctx); err != nil { + return err + } case packets.ReadyForQuery: + ctx.Sync = false + ctx.EQP = false var ok bool if ctx.TxState, ok = packets.ReadReadyForQuery(in); !ok { return berr.ServerBadPacket } return ctx.ClientProxy(in) default: - log.Println("c", in.Type()) return berr.ServerUnexpectedPacket } } @@ -184,6 +206,7 @@ func eqp(ctx *bctx.Context) berr.Error { if err = ctx.ServerProxy(in); err != nil { return err } + ctx.Sync = true return sync(ctx) case packets.Parse, packets.Bind, packets.Close, packets.Describe, packets.Execute, packets.Flush: if err = ctx.ServerProxy(in); err != nil { @@ -207,6 +230,7 @@ func transaction(ctx *bctx.Context) berr.Error { if err = ctx.ServerProxy(in); err != nil { return err } + ctx.Query = true if err = query(ctx); err != nil { return err } @@ -214,6 +238,7 @@ func transaction(ctx *bctx.Context) berr.Error { if err = ctx.ServerProxy(in); err != nil { return err } + ctx.FunctionCall = true if err = functionCall(ctx); err != nil { return err } @@ -228,6 +253,7 @@ func transaction(ctx *bctx.Context) berr.Error { if err = ctx.ServerProxy(in); err != nil { return err } + ctx.EQP = true if err = eqp(ctx); err != nil { return err } diff --git a/lib/bouncer/bouncers/v2/rserver/recoverer.go b/lib/bouncer/bouncers/v2/rserver/recoverer.go index 8daa59a74f6486b731474240fc8c400bf3402fbc..dc5ab09fac432c7029b2818aa7718a27d0184f1b 100644 --- a/lib/bouncer/bouncers/v2/rserver/recoverer.go +++ b/lib/bouncer/bouncers/v2/rserver/recoverer.go @@ -2,29 +2,225 @@ package rserver import ( "pggat2/lib/bouncer/bouncers/v2/bctx" + "pggat2/lib/bouncer/bouncers/v2/berr" + "pggat2/lib/zap" + packets "pggat2/lib/zap/packets/v3.0" ) +func serverRead(ctx *bctx.Context) (zap.In, error) { + for { + in, err := ctx.ServerRead() + if err != nil { + return zap.In{}, err + } + switch in.Type() { + case packets.NoticeResponse, + packets.ParameterStatus, + packets.NotificationResponse: + continue + default: + return in, nil + } + } +} + +func copyIn(ctx *bctx.Context) error { + // send copy fail + out := ctx.ServerWrite() + out.Type(packets.CopyFail) + out.String("client failed") + if err := ctx.ServerSend(out); err != nil { + return err + } + ctx.CopyIn = false + return nil +} + +func copyOut(ctx *bctx.Context) error { + for { + in, err := serverRead(ctx) + if err != nil { + return err + } + + switch in.Type() { + case packets.CopyData: + continue + case packets.CopyDone, packets.ErrorResponse: + ctx.CopyOut = false + return nil + default: + return berr.ServerUnexpectedPacket + } + } +} + +func query(ctx *bctx.Context) error { + for { + in, err := serverRead(ctx) + if err != nil { + return err + } + + switch in.Type() { + case packets.CommandComplete, + packets.RowDescription, + packets.DataRow, + packets.EmptyQueryResponse, + packets.ErrorResponse: + continue + case packets.CopyInResponse: + ctx.CopyIn = true + if err = copyIn(ctx); err != nil { + return err + } + case packets.CopyOutResponse: + ctx.CopyOut = true + if err = copyOut(ctx); err != nil { + return err + } + case packets.ReadyForQuery: + ctx.Query = false + var ok bool + if ctx.TxState, ok = packets.ReadReadyForQuery(in); !ok { + return berr.ServerBadPacket + } + return nil + default: + return berr.ServerUnexpectedPacket + } + } +} + +func functionCall(ctx *bctx.Context) error { + for { + in, err := serverRead(ctx) + if err != nil { + return err + } + + switch in.Type() { + case packets.ErrorResponse, packets.FunctionCallResponse: + continue + case packets.ReadyForQuery: + ctx.FunctionCall = false + var ok bool + if ctx.TxState, ok = packets.ReadReadyForQuery(in); !ok { + return berr.ServerBadPacket + } + return nil + default: + return berr.ServerUnexpectedPacket + } + } +} + +func sync(ctx *bctx.Context) error { + for { + in, err := serverRead(ctx) + if err != nil { + return err + } + + switch in.Type() { + case packets.ParseComplete, + packets.BindComplete, + packets.ErrorResponse, + packets.RowDescription, + packets.NoData, + packets.ParameterDescription, + + packets.CommandComplete, + packets.DataRow, + packets.EmptyQueryResponse, + packets.PortalSuspended: + continue + case packets.CopyInResponse: + ctx.CopyIn = true + if err = copyIn(ctx); err != nil { + return err + } + case packets.CopyOutResponse: + ctx.CopyOut = true + if err = copyOut(ctx); err != nil { + return err + } + case packets.ReadyForQuery: + ctx.Sync = false + ctx.EQP = false + var ok bool + if ctx.TxState, ok = packets.ReadReadyForQuery(in); !ok { + return berr.ServerBadPacket + } + return nil + default: + return berr.ServerUnexpectedPacket + } + } +} + +func eqp(ctx *bctx.Context) error { + // send sync + out := ctx.ServerWrite() + out.Type(packets.Sync) + if err := ctx.ServerSend(out); err != nil { + return err + } + ctx.Sync = true + + // handle sync + return sync(ctx) +} + +func transaction(ctx *bctx.Context) error { + // write Query('ABORT;') + out := ctx.ServerWrite() + out.Type(packets.Query) + out.String("ABORT;") + if err := ctx.ServerSend(out); err != nil { + return err + } + ctx.Query = true + + // handle query + return query(ctx) +} + func Recover(ctx *bctx.Context) error { - if inCopyOut { - // TODO(garet) wait for CopyDone or ErrorResponse + if ctx.CopyOut { + if err := copyOut(ctx); err != nil { + return err + } } - if inCopyIn { - // TODO(garet) send CopyFail + if ctx.CopyIn { + if err := copyIn(ctx); err != nil { + return err + } } - if inQuery { - // TODO(garet) wait for ready for query, waiting for copyOut if it happens, failing copyIn if it happens + if ctx.Query { + if err := query(ctx); err != nil { + return err + } } - if inFunctionCall { - // TODO(garet) wait for ready for query + if ctx.FunctionCall { + if err := functionCall(ctx); err != nil { + return err + } } - if inSync { - // TODO(garet) wait for ready for query + if ctx.Sync { + if err := sync(ctx); err != nil { + return err + } } - if inEQP { - // TODO(garet) send sync and wait for ready for query + if ctx.EQP { + if err := eqp(ctx); err != nil { + return err + } } if ctx.TxState != 'I' { - // TODO(garet) send Query('ABORT;') and wait for ReadyForQuery + if err := transaction(ctx); err != nil { + return err + } } return nil }