diff --git a/lib/bouncer/bouncers/v1/bouncer.go b/lib/bouncer/bouncers/v1/bouncer.go index a17626185906838b4164364d35c0d0e53c45844a..43ae7ba5d8fe763305b9de179619892f97dad861 100644 --- a/lib/bouncer/bouncers/v1/bouncer.go +++ b/lib/bouncer/bouncers/v1/bouncer.go @@ -61,27 +61,30 @@ func copyIn0(ctx *bctx.Context) berr.Error { } func copyInRecoverServer(ctx *bctx.Context, err berr.Error) { + if !ctx.InCopyIn() { + return + } // send copyFail to server, will stop server copy out := ctx.ServerWrite() out.Type(packets.CopyFail) out.String(fmt.Sprintf("client error: %s", err.String())) _ = ctx.ServerSend(out) - - ctx.EndCopyIn() } func copyInRecoverClient(ctx *bctx.Context, err berr.Error) { + if !ctx.InCopyIn() { + return + } // send error to client, will stop client copy out := ctx.ClientWrite() packets.WriteErrorResponse(out, err.PError()) _ = ctx.ClientSend(out) - - ctx.EndCopyIn() } func copyInRecover(ctx *bctx.Context, err berr.Error) { - copyInRecoverServer(ctx, err) copyInRecoverClient(ctx, err) + copyInRecoverServer(ctx, err) + ctx.EndCopyIn() } func copyIn(ctx *bctx.Context) { @@ -109,9 +112,7 @@ func copyOut0(ctx *bctx.Context) berr.Error { ctx.EndCopyOut() return ctx.ClientProxy(in) default: - log.Printf("unexpected packet %c\n", in.Type()) panic("unexpected packet from server") - return berr.ServerProtocolError } } @@ -133,24 +134,31 @@ func copyOutRecoverServer0(ctx *bctx.Context) { } func copyOutRecoverServer(ctx *bctx.Context, _ berr.Error) { + if !ctx.InCopyOut() { + return + } // read until server is done with its copy for ctx.InCopyOut() { copyOutRecoverServer0(ctx) } + ctx.BeginCopyOut() } func copyOutRecoverClient(ctx *bctx.Context, err berr.Error) { + if !ctx.InCopyIn() { + return + } // send error to client, will stop client copy out := ctx.ClientWrite() packets.WriteErrorResponse(out, err.PError()) _ = ctx.ClientSend(out) - - ctx.EndCopyOut() } func copyOutRecover(ctx *bctx.Context, err berr.Error) { - copyOutRecoverServer(ctx, err) + log.Println("recover from copyOut") copyOutRecoverClient(ctx, err) + copyOutRecoverServer(ctx, err) + ctx.EndCopyOut() } func copyOut(ctx *bctx.Context) { @@ -200,9 +208,7 @@ func query0(ctx *bctx.Context) berr.Error { } return readyForQuery(ctx, in) default: - log.Printf("unexpected packet %c\n", in.Type()) panic("unexpected packet from server") - return berr.ServerProtocolError } } @@ -222,9 +228,11 @@ func queryRecoverServer0(ctx *bctx.Context, err berr.Error) { case packets.CopyInResponse: ctx.BeginCopyIn() copyInRecoverServer(ctx, err) + ctx.EndCopyIn() case packets.CopyOutResponse: ctx.BeginCopyOut() copyOutRecoverServer(ctx, err) + ctx.EndCopyOut() case packets.ReadyForQuery: ctx.EndQuery() readyForQuery(ctx, in) @@ -235,6 +243,10 @@ func queryRecoverServer0(ctx *bctx.Context, err berr.Error) { // serverTransactionFail ensures the server is in a failed txn block func serverTransactionFail(ctx *bctx.Context, err berr.Error) { + if !ctx.InTransaction() { + return + } + log.Println("fail transaction") // we need to change this to a failed transaction block, write a simple query that will fail out := ctx.ServerWrite() out.Type(packets.Query) @@ -247,15 +259,20 @@ func serverTransactionFail(ctx *bctx.Context, err berr.Error) { } func queryRecoverServer(ctx *bctx.Context, err berr.Error) { + if !ctx.InQuery() { + return + } for ctx.InQuery() { queryRecoverServer0(ctx, err) } - if ctx.InTransaction() { - serverTransactionFail(ctx, err) - } + serverTransactionFail(ctx, err) + ctx.BeginQuery() } func queryRecoverClient(ctx *bctx.Context, err berr.Error) { + if !ctx.InQuery() { + return + } // send error to client followed by ready for query out := ctx.ClientWrite() packets.WriteErrorResponse(out, err.PError()) @@ -267,13 +284,13 @@ func queryRecoverClient(ctx *bctx.Context, err berr.Error) { packets.WriteReadyForQuery(out, 'I') } _ = ctx.ClientSend(out) - - ctx.EndQuery() } func queryRecover(ctx *bctx.Context, err berr.Error) { - queryRecoverServer(ctx, err) + log.Println("recover from query") queryRecoverClient(ctx, err) + queryRecoverServer(ctx, err) + ctx.EndQuery() } func query(ctx *bctx.Context) { @@ -307,9 +324,7 @@ func functionCall0(ctx *bctx.Context) berr.Error { } return readyForQuery(ctx, in) default: - log.Printf("unexpected packet %c\n", in.Type()) panic("unexpected packet from server") - return berr.ServerProtocolError } } @@ -331,15 +346,20 @@ func functionCallRecoverServer0(ctx *bctx.Context) { } func functionCallRecoverServer(ctx *bctx.Context, err berr.Error) { + if !ctx.InFunctionCall() { + return + } for ctx.InFunctionCall() { functionCallRecoverServer0(ctx) } - if ctx.InTransaction() { - serverTransactionFail(ctx, err) - } + serverTransactionFail(ctx, err) + ctx.BeginFunctionCall() } func functionCallRecoverClient(ctx *bctx.Context, err berr.Error) { + if !ctx.InFunctionCall() { + return + } // send error to client followed by ready for query, will stop client function call out := ctx.ClientWrite() packets.WriteErrorResponse(out, err.PError()) @@ -351,13 +371,13 @@ func functionCallRecoverClient(ctx *bctx.Context, err berr.Error) { packets.WriteReadyForQuery(out, 'I') } _ = ctx.ClientSend(out) - - ctx.EndFunctionCall() } func functionCallRecover(ctx *bctx.Context, err berr.Error) { - functionCallRecoverServer(ctx, err) + log.Println("recover from functionCall") functionCallRecoverClient(ctx, err) + functionCallRecoverServer(ctx, err) + ctx.EndFunctionCall() } func functionCall(ctx *bctx.Context) { @@ -399,9 +419,7 @@ func sync0(ctx *bctx.Context) berr.Error { } return readyForQuery(ctx, in) default: - log.Printf("unexpected packet %c\n", in.Type()) panic("unexpected packet from server") - return berr.ServerProtocolError } } @@ -434,15 +452,20 @@ func syncRecoverServer0(ctx *bctx.Context, _ berr.Error) { } func syncRecoverServer(ctx *bctx.Context, err berr.Error) { + if !ctx.InSync() { + return + } for ctx.InSync() { syncRecoverServer0(ctx, err) } - if ctx.InTransaction() { - serverTransactionFail(ctx, err) - } + serverTransactionFail(ctx, err) + ctx.BeginSync() } func syncRecoverClient(ctx *bctx.Context, err berr.Error) { + if !ctx.InSync() { + return + } // send error to client followed by ready for query out := ctx.ClientWrite() packets.WriteErrorResponse(out, err.PError()) @@ -454,13 +477,13 @@ func syncRecoverClient(ctx *bctx.Context, err berr.Error) { packets.WriteReadyForQuery(out, 'I') } _ = ctx.ClientSend(out) - - ctx.EndSync() } func syncRecover(ctx *bctx.Context, err berr.Error) { - syncRecoverServer(ctx, err) + log.Println("recover from sync") syncRecoverClient(ctx, err) + syncRecoverServer(ctx, err) + ctx.EndSync() } func sync(ctx *bctx.Context) { @@ -527,8 +550,12 @@ func transactionRecoverServer(ctx *bctx.Context, err berr.Error) { } ctx.BeginSync() syncRecoverServer(ctx, err) + ctx.EndSync() + ctx.BeginEQP() } if ctx.InTransaction() { + // we need to fail this transaction + serverTransactionFail(ctx, err) // send END to break out of transaction and wait for ready for query out := ctx.ServerWrite() out.Type(packets.Query) @@ -539,24 +566,29 @@ func transactionRecoverServer(ctx *bctx.Context, err berr.Error) { } ctx.BeginQuery() queryRecoverServer(ctx, err) + ctx.EndQuery() + ctx.BeginTransaction() } } func transactionRecoverClient(ctx *bctx.Context, err berr.Error) { + if !ctx.InTransaction() && !ctx.InEQP() { + return + } out := ctx.ClientWrite() packets.WriteErrorResponse(out, err.PError()) _ = ctx.ClientSend(out) out = ctx.ClientWrite() packets.WriteReadyForQuery(out, 'I') _ = ctx.ClientSend(out) - - ctx.EndEQP() - ctx.EndTransaction() } func transactionRecover(ctx *bctx.Context, err berr.Error) { - transactionRecoverServer(ctx, err) + log.Println("recover from transaction") transactionRecoverClient(ctx, err) + transactionRecoverServer(ctx, err) + ctx.EndEQP() + ctx.EndTransaction() } func transaction(ctx *bctx.Context) {