From 35b20ee89f9c6a25d3ea0bccbf2fce623e647392 Mon Sep 17 00:00:00 2001 From: Garet Halliday <me@garet.holiday> Date: Wed, 17 May 2023 15:29:19 -0500 Subject: [PATCH] Pool eqp allocations, leave portals and preparedStatements in packet format closes #3 --- cmd/cgat/main.go | 1 + lib/middleware/middlewares/eqp/client.go | 76 +++--- lib/middleware/middlewares/eqp/close.go | 32 +++ lib/middleware/middlewares/eqp/portal.go | 59 +++-- .../middlewares/eqp/preparedStatement.go | 44 +++- lib/middleware/middlewares/eqp/server.go | 216 +++++++++++------- lib/zap/buf.go | 37 ++- lib/zap/zio/reader.go | 4 + lib/zap/zio/readwriter.go | 4 + lib/zap/zio/writer.go | 4 + 10 files changed, 328 insertions(+), 149 deletions(-) create mode 100644 lib/middleware/middlewares/eqp/close.go diff --git a/cmd/cgat/main.go b/cmd/cgat/main.go index 8785296f..a89c83b7 100644 --- a/cmd/cgat/main.go +++ b/cmd/cgat/main.go @@ -83,6 +83,7 @@ func main() { go func() { source := r.NewSource() client := zio.MakeReadWriter(conn) + defer client.Done() ob := onebuffer.MakeOnebuffer(&client) eqpc := eqp.MakeClient() mw := interceptor.MakeInterceptor(&ob, []middleware.Middleware{ diff --git a/lib/middleware/middlewares/eqp/client.go b/lib/middleware/middlewares/eqp/client.go index b1a05957..8ae17e53 100644 --- a/lib/middleware/middlewares/eqp/client.go +++ b/lib/middleware/middlewares/eqp/client.go @@ -20,6 +20,33 @@ func MakeClient() Client { } } +func (T *Client) deletePreparedStatement(name string) { + preparedStatement, ok := T.preparedStatements[name] + if !ok { + return + } + preparedStatement.Done() + delete(T.preparedStatements, name) +} + +func (T *Client) deletePortal(name string) { + portal, ok := T.portals[name] + if !ok { + return + } + portal.Done() + delete(T.portals, name) +} + +func (T *Client) Done() { + for name := range T.preparedStatements { + T.deletePreparedStatement(name) + } + for name := range T.portals { + T.deletePortal(name) + } +} + func (T *Client) Send(_ middleware.Context, out zap.Out) error { in := zap.OutToIn(out) switch in.Type() { @@ -29,9 +56,9 @@ func (T *Client) Send(_ middleware.Context, out zap.Out) error { return errors.New("bad packet format") } if state == 'I' { - // clobber all portals + // clobber all named portals for name := range T.portals { - delete(T.portals, name) + T.deletePortal(name) } } case packets.ParseComplete, packets.BindComplete, packets.CloseComplete: @@ -45,24 +72,17 @@ func (T *Client) Read(ctx middleware.Context, in zap.In) error { switch in.Type() { case packets.Query: // clobber unnamed portal and unnamed prepared statement - delete(T.preparedStatements, "") - delete(T.portals, "") + T.deletePreparedStatement("") + T.deletePortal("") case packets.Parse: ctx.Cancel() - destination, query, parameterDataTypes, ok := packets.ReadParse(in) + destination, preparedStatement, ok := ReadParse(in) if !ok { return errors.New("bad packet format") } - if destination != "" { - if _, ok = T.preparedStatements[destination]; ok { - return errors.New("prepared statement already exists") - } - } - T.preparedStatements[destination] = PreparedStatement{ - Query: query, - ParameterDataTypes: parameterDataTypes, - } + + T.preparedStatements[destination] = preparedStatement // send parse complete out := zap.InToOut(in) @@ -75,24 +95,12 @@ func (T *Client) Read(ctx middleware.Context, in zap.In) error { case packets.Bind: ctx.Cancel() - destination, source, parameterFormatCodes, parameterValues, resultFormatCodes, ok := packets.ReadBind(in) + destination, portal, ok := ReadBind(in) if !ok { return errors.New("bad packet format") } - if destination != "" { - if _, ok = T.portals[destination]; ok { - return errors.New("portal already exists") - } - } - if _, ok = T.preparedStatements[source]; !ok { - return errors.New("prepared statement does not exist") - } - T.portals[destination] = Portal{ - Source: source, - ParameterFormatCodes: parameterFormatCodes, - ParameterValues: parameterValues, - ResultFormatCodes: resultFormatCodes, - } + + T.portals[destination] = portal // send bind complete out := zap.InToOut(in) @@ -111,9 +119,9 @@ func (T *Client) Read(ctx middleware.Context, in zap.In) error { } switch which { case 'S': - delete(T.preparedStatements, target) + T.deletePreparedStatement(target) case 'P': - delete(T.portals, target) + T.deletePortal(target) default: return errors.New("bad packet format") } @@ -134,11 +142,11 @@ func (T *Client) Read(ctx middleware.Context, in zap.In) error { } switch which { case 'S': - if _, ok := T.preparedStatements[target]; !ok { + if _, ok = T.preparedStatements[target]; !ok { return errors.New("prepared statement doesn't exist") } case 'P': - if _, ok := T.portals[target]; !ok { + if _, ok = T.portals[target]; !ok { return errors.New("portal doesn't exist") } default: @@ -149,7 +157,7 @@ func (T *Client) Read(ctx middleware.Context, in zap.In) error { if !ok { return errors.New("bad packet format") } - if _, ok := T.portals[target]; !ok { + if _, ok = T.portals[target]; !ok { return errors.New("portal doesn't exist") } } diff --git a/lib/middleware/middlewares/eqp/close.go b/lib/middleware/middlewares/eqp/close.go new file mode 100644 index 00000000..d87883de --- /dev/null +++ b/lib/middleware/middlewares/eqp/close.go @@ -0,0 +1,32 @@ +package eqp + +type Close interface { + Done() + close() +} + +type ClosePortal struct { + target string + portal Portal +} + +func (T ClosePortal) Done() { + T.portal.Done() +} + +func (ClosePortal) close() {} + +var _ Close = ClosePortal{} + +type ClosePreparedStatement struct { + target string + preparedStatement PreparedStatement +} + +func (T ClosePreparedStatement) Done() { + T.preparedStatement.Done() +} + +func (ClosePreparedStatement) close() {} + +var _ Close = ClosePreparedStatement{} diff --git a/lib/middleware/middlewares/eqp/portal.go b/lib/middleware/middlewares/eqp/portal.go index 7ec58e4d..14cda124 100644 --- a/lib/middleware/middlewares/eqp/portal.go +++ b/lib/middleware/middlewares/eqp/portal.go @@ -1,31 +1,50 @@ package eqp -import "pggat2/lib/util/slices" +import ( + "pggat2/lib/global" + "pggat2/lib/util/slices" + "pggat2/lib/zap" + packets "pggat2/lib/zap/packets/v3.0" +) type Portal struct { - Source string - ParameterFormatCodes []int16 - ParameterValues [][]byte - ResultFormatCodes []int16 + source string + raw []byte } -func (T Portal) Equals(rhs Portal) bool { - if T.Source != rhs.Source { - return false +func ReadBind(in zap.In) (destination string, portal Portal, ok bool) { + in.Reset() + if in.Type() != packets.Bind { + return } - if !slices.Equal(T.ParameterFormatCodes, rhs.ParameterFormatCodes) { - return false + destination, ok = in.String() + if !ok { + return } - if len(T.ParameterValues) != len(rhs.ParameterValues) { - return false + portal.source, ok = in.String() + if !ok { + return } - for i := range T.ParameterValues { - if !slices.Equal(T.ParameterValues[i], rhs.ParameterValues[i]) { - return false - } - } - if !slices.Equal(T.ResultFormatCodes, rhs.ResultFormatCodes) { - return false + full := zap.InToOut(in).Full() + portal.raw = global.GetBytes(int32(len(full))) + copy(portal.raw, full) + return +} + +func (T *Portal) Done() { + global.PutBytes(T.raw) + T.raw = nil +} + +func (T *Portal) Equal(rhs *Portal) bool { + return slices.Equal(T.raw, rhs.raw) +} + +func (T *Portal) Clone() Portal { + raw := global.GetBytes(int32(len(T.raw))) + copy(raw, T.raw) + return Portal{ + source: T.source, + raw: raw, } - return true } diff --git a/lib/middleware/middlewares/eqp/preparedStatement.go b/lib/middleware/middlewares/eqp/preparedStatement.go index a710b61b..a156cb43 100644 --- a/lib/middleware/middlewares/eqp/preparedStatement.go +++ b/lib/middleware/middlewares/eqp/preparedStatement.go @@ -1,18 +1,44 @@ package eqp -import "pggat2/lib/util/slices" +import ( + "pggat2/lib/global" + "pggat2/lib/util/slices" + "pggat2/lib/zap" + packets "pggat2/lib/zap/packets/v3.0" +) type PreparedStatement struct { - Query string - ParameterDataTypes []int32 + raw []byte } -func (T PreparedStatement) Equals(rhs PreparedStatement) bool { - if T.Query != rhs.Query { - return false +func ReadParse(in zap.In) (destination string, preparedStatement PreparedStatement, ok bool) { + in.Reset() + if in.Type() != packets.Parse { + return } - if !slices.Equal(T.ParameterDataTypes, rhs.ParameterDataTypes) { - return false + destination, ok = in.String() + if !ok { + return + } + full := zap.InToOut(in).Full() + preparedStatement.raw = global.GetBytes(int32(len(full))) + copy(preparedStatement.raw, full) + return +} + +func (T *PreparedStatement) Done() { + global.PutBytes(T.raw) + T.raw = nil +} + +func (T *PreparedStatement) Equal(rhs *PreparedStatement) bool { + return slices.Equal(T.raw, rhs.raw) +} + +func (T *PreparedStatement) Clone() PreparedStatement { + raw := global.GetBytes(int32(len(T.raw))) + copy(raw, T.raw) + return PreparedStatement{ + raw: raw, } - return true } diff --git a/lib/middleware/middlewares/eqp/server.go b/lib/middleware/middlewares/eqp/server.go index 1c5b92fc..339f5a8d 100644 --- a/lib/middleware/middlewares/eqp/server.go +++ b/lib/middleware/middlewares/eqp/server.go @@ -9,30 +9,13 @@ import ( packets "pggat2/lib/zap/packets/v3.0" ) -type pendingClose interface { - pendingClose() -} - -type pendingClosePreparedStatement struct { - target string - preparedStatement PreparedStatement -} - -func (pendingClosePreparedStatement) pendingClose() {} - -type pendingClosePortal struct { - target string - portal Portal -} - -func (pendingClosePortal) pendingClose() {} - type Server struct { - preparedStatements map[string]PreparedStatement - portals map[string]Portal + preparedStatements map[string]PreparedStatement + portals map[string]Portal + pendingPreparedStatements ring.Ring[string] pendingPortals ring.Ring[string] - pendingCloses ring.Ring[pendingClose] + pendingCloses ring.Ring[Close] buf zap.Buf @@ -50,7 +33,37 @@ func (T *Server) SetClient(client *Client) { T.peer = client } +func (T *Server) deletePreparedStatement(target string) { + v, ok := T.preparedStatements[target] + if !ok { + return + } + v.Done() + delete(T.preparedStatements, target) +} + +func (T *Server) deletePortal(target string) { + v, ok := T.portals[target] + if !ok { + return + } + v.Done() + delete(T.portals, target) +} + func (T *Server) closePreparedStatement(ctx middleware.Context, target string) error { + // no need to close unnamed prepared statement + if target == "" { + return nil + } + + preparedStatement, ok := T.preparedStatements[target] + if !ok { + // already closed + return nil + } + + // send close packet out := T.buf.Write() packets.WriteClose(out, 'S', target) err := ctx.Send(out) @@ -58,9 +71,9 @@ func (T *Server) closePreparedStatement(ctx middleware.Context, target string) e return err } - preparedStatement := T.preparedStatements[target] + // add it to pending delete(T.preparedStatements, target) - T.pendingCloses.PushBack(pendingClosePreparedStatement{ + T.pendingCloses.PushBack(ClosePreparedStatement{ target: target, preparedStatement: preparedStatement, }) @@ -68,6 +81,18 @@ func (T *Server) closePreparedStatement(ctx middleware.Context, target string) e } func (T *Server) closePortal(ctx middleware.Context, target string) error { + // no need to close unnamed portal + if target == "" { + return nil + } + + portal, ok := T.portals[target] + if !ok { + // already closed + return nil + } + + // send close packet out := T.buf.Write() packets.WriteClose(out, 'P', target) err := ctx.Send(out) @@ -75,89 +100,105 @@ func (T *Server) closePortal(ctx middleware.Context, target string) error { return err } - portal := T.portals[target] + // add it to pending delete(T.portals, target) - T.pendingCloses.PushBack(pendingClosePortal{ + T.pendingCloses.PushBack(ClosePortal{ target: target, portal: portal, }) return nil } -func (T *Server) bindPreparedStatement(ctx middleware.Context, target string, preparedStatement PreparedStatement) error { - if target != "" { - if _, ok := T.preparedStatements[target]; ok { - err := T.closePreparedStatement(ctx, target) - if err != nil { - return err - } - } +func (T *Server) bindPreparedStatement( + ctx middleware.Context, + target string, + preparedStatement PreparedStatement, +) error { + err := T.closePreparedStatement(ctx, target) + if err != nil { + return err } - out := T.buf.Write() - packets.WriteParse(out, target, preparedStatement.Query, preparedStatement.ParameterDataTypes) - err := ctx.Send(out) + buf := zap.MakeBuf(preparedStatement.raw) + err = ctx.Send(buf.Out()) if err != nil { return err } - T.preparedStatements[target] = preparedStatement + T.deletePreparedStatement(target) + T.preparedStatements[target] = preparedStatement.Clone() T.pendingPreparedStatements.PushBack(target) return nil } -func (T *Server) bindPortal(ctx middleware.Context, target string, portal Portal) error { - if target != "" { - if _, ok := T.portals[target]; ok { - err := T.closePortal(ctx, target) - if err != nil { - return err - } +func (T *Server) bindPortal( + ctx middleware.Context, + target string, + portal Portal, +) error { + // check if we already have it bound + if old, ok := T.portals[target]; ok { + if old.Equal(&portal) { + return nil } } - out := T.buf.Write() - packets.WriteBind(out, target, portal.Source, portal.ParameterFormatCodes, portal.ParameterValues, portal.ResultFormatCodes) - err := ctx.Send(out) + err := T.closePortal(ctx, target) + if err != nil { + return err + } + + buf := zap.MakeBuf(portal.raw) + err = ctx.Send(buf.Out()) if err != nil { return err } - T.portals[target] = portal + T.deletePortal(target) + T.portals[target] = portal.Clone() T.pendingPortals.PushBack(target) return nil } func (T *Server) syncPreparedStatement(ctx middleware.Context, target string) error { - // we can assume client has the prepared statement because it should be checked by eqp.Client expected := T.peer.preparedStatements[target] - actual, ok := T.preparedStatements[target] - if !ok || !expected.Equals(actual) { - // clear all portals that use this prepared statement - for name, portal := range T.portals { - if portal.Source == target { - err := T.closePortal(ctx, name) - if err != nil { - return err - } + + // check if we already have it bound + if old, ok := T.preparedStatements[target]; ok { + if old.Equal(&expected) { + return nil + } + } + + // clear all portals that use this prepared statement + for name, portal := range T.portals { + if portal.source == target { + err := T.closePortal(ctx, name) + if err != nil { + return err } } - return T.bindPreparedStatement(ctx, target, expected) } - return nil + + return T.bindPreparedStatement(ctx, target, expected) } func (T *Server) syncPortal(ctx middleware.Context, target string) error { expected := T.peer.portals[target] - err := T.syncPreparedStatement(ctx, expected.Source) + + err := T.syncPreparedStatement(ctx, expected.source) if err != nil { return err } - actual, ok := T.portals[target] - if !ok || !expected.Equals(actual) { - return T.bindPortal(ctx, target, expected) + + // check if we already have it bound + if old, ok := T.portals[target]; ok { + if old.Equal(&expected) { + return nil + } } - return nil + + return T.bindPortal(ctx, target, expected) } func (T *Server) Send(ctx middleware.Context, out zap.Out) error { @@ -165,8 +206,8 @@ func (T *Server) Send(ctx middleware.Context, out zap.Out) error { switch in.Type() { case packets.Query: // clobber unnamed portal and unnamed prepared statement - delete(T.preparedStatements, "") - delete(T.portals, "") + T.deletePreparedStatement("") + T.deletePortal("") case packets.Parse, packets.Bind, packets.Close: // should've been caught by eqp.Client panic("unreachable") @@ -174,7 +215,8 @@ func (T *Server) Send(ctx middleware.Context, out zap.Out) error { // ensure target exists which, target, ok := packets.ReadDescribe(in) if !ok { - return errors.New("bad packet format") + // should've been caught by eqp.Client + panic("unreachable") } switch which { case 'S': @@ -190,12 +232,13 @@ func (T *Server) Send(ctx middleware.Context, out zap.Out) error { return err } default: - return errors.New("unknown describe target") + panic("unknown describe target") } case packets.Execute: target, _, ok := packets.ReadExecute(in) if !ok { - return errors.New("bad packet format") + // should've been caught by eqp.Client + panic("unreachable") } // sync portal err := T.syncPortal(ctx, target) @@ -220,7 +263,9 @@ func (T *Server) Read(ctx middleware.Context, in zap.In) error { case packets.CloseComplete: ctx.Cancel() - T.pendingCloses.PopFront() + if c, ok := T.pendingCloses.PopFront(); ok { + c.Done() + } case packets.ReadyForQuery: state, ok := packets.ReadReadyForQuery(in) if !ok { @@ -229,28 +274,43 @@ func (T *Server) Read(ctx middleware.Context, in zap.In) error { if state == 'I' { // clobber all portals for name := range T.portals { - delete(T.portals, name) + T.deletePortal(name) } } // all pending failed for pending, ok := T.pendingPreparedStatements.PopBack(); ok; pending, ok = T.pendingPreparedStatements.PopBack() { - delete(T.preparedStatements, pending) + T.deletePreparedStatement(pending) } for pending, ok := T.pendingPortals.PopBack(); ok; pending, ok = T.pendingPortals.PopBack() { - delete(T.portals, pending) + T.deletePortal(pending) } for pending, ok := T.pendingCloses.PopBack(); ok; pending, ok = T.pendingCloses.PopBack() { switch p := pending.(type) { - case pendingClosePortal: - T.portals[p.target] = p.portal - case pendingClosePreparedStatement: + case ClosePreparedStatement: + T.deletePreparedStatement(p.target) T.preparedStatements[p.target] = p.preparedStatement + case ClosePortal: + T.deletePortal(p.target) + T.portals[p.target] = p.portal default: - panic("what") + panic("unreachable") } } } return nil } +func (T *Server) Done() { + T.buf.Done() + for name := range T.preparedStatements { + T.deletePreparedStatement(name) + } + for name := range T.portals { + T.deletePortal(name) + } + for pending, ok := T.pendingCloses.PopBack(); ok; pending, ok = T.pendingCloses.PopBack() { + pending.Done() + } +} + var _ middleware.Middleware = (*Server)(nil) diff --git a/lib/zap/buf.go b/lib/zap/buf.go index 582a717b..4d02f4af 100644 --- a/lib/zap/buf.go +++ b/lib/zap/buf.go @@ -18,6 +18,12 @@ type Buf struct { rev int } +func MakeBuf(buf []byte) Buf { + return Buf{ + buf: buf, + } +} + func (T *Buf) assertRev(rev int) { // this check can be turned off when in production mode (for dev, this is helpful though) if T.rev != rev { @@ -46,6 +52,20 @@ func (T *Buf) ensureBufExtra(extra int) { } } +func (T *Buf) In() In { + return In{ + buf: T, + rev: T.rev, + } +} + +func (T *Buf) Out() Out { + return Out{ + buf: T, + rev: T.rev, + } +} + func (T *Buf) ReadByte(reader io.Reader) (byte, error) { T.rev++ T.pos = 0 @@ -84,10 +104,7 @@ func (T *Buf) Read(reader io.Reader, typed bool) (In, error) { return In{}, err } - return In{ - buf: T, - rev: T.rev, - }, nil + return T.In(), nil } func (T *Buf) WriteByte(writer io.Writer, b byte) error { @@ -107,10 +124,14 @@ func (T *Buf) Write() Out { T.setBufLen(5) T.buf[0] = 0 - return Out{ - buf: T, - rev: T.rev, - } + return T.Out() +} + +func (T *Buf) Done() { + T.rev++ + T.pos = 0 + global.PutBytes(T.buf) + T.buf = nil } func (T *Buf) full() []byte { diff --git a/lib/zap/zio/reader.go b/lib/zap/zio/reader.go index 3f0cbcfc..c405d25e 100644 --- a/lib/zap/zio/reader.go +++ b/lib/zap/zio/reader.go @@ -37,4 +37,8 @@ func (T *Reader) ReadUntyped() (zap.In, error) { return T.buf.Read(T.r, false) } +func (T *Reader) Done() { + T.buf.Done() +} + var _ zap.Reader = (*Reader)(nil) diff --git a/lib/zap/zio/readwriter.go b/lib/zap/zio/readwriter.go index 0f09732e..d4fedd45 100644 --- a/lib/zap/zio/readwriter.go +++ b/lib/zap/zio/readwriter.go @@ -61,4 +61,8 @@ func (T *ReadWriter) Send(out zap.Out) error { return err } +func (T *ReadWriter) Done() { + T.buf.Done() +} + var _ zap.ReadWriter = (*ReadWriter)(nil) diff --git a/lib/zap/zio/writer.go b/lib/zap/zio/writer.go index 1c283b70..8d5b6c10 100644 --- a/lib/zap/zio/writer.go +++ b/lib/zap/zio/writer.go @@ -38,4 +38,8 @@ func (T *Writer) Send(out zap.Out) error { return err } +func (T *Writer) Done() { + T.buf.Done() +} + var _ zap.Writer = (*Writer)(nil) -- GitLab