From 542fd0eb1177b7f1ef97efe13d11d9bae7182999 Mon Sep 17 00:00:00 2001 From: Garet Halliday <me@garet.holiday> Date: Tue, 16 May 2023 20:21:08 -0500 Subject: [PATCH] man --- lib/middleware/middlewares/eqp2/pooler.go | 237 ++++++++++++++++++ lib/middleware/middlewares/eqp2/portal.go | 34 +++ .../middlewares/eqp2/preparedStatement.go | 21 ++ lib/util/pools/locked.go | 11 +- lib/util/pools/pool.go | 16 ++ lib/util/slices/cloneinto.go | 9 + lib/zap/packets/v3.0/bind.go | 5 +- 7 files changed, 324 insertions(+), 9 deletions(-) create mode 100644 lib/middleware/middlewares/eqp2/pooler.go create mode 100644 lib/middleware/middlewares/eqp2/portal.go create mode 100644 lib/middleware/middlewares/eqp2/preparedStatement.go create mode 100644 lib/util/pools/pool.go create mode 100644 lib/util/slices/cloneinto.go diff --git a/lib/middleware/middlewares/eqp2/pooler.go b/lib/middleware/middlewares/eqp2/pooler.go new file mode 100644 index 00000000..944ab96a --- /dev/null +++ b/lib/middleware/middlewares/eqp2/pooler.go @@ -0,0 +1,237 @@ +package eqp2 + +import ( + "pggat2/lib/util/pools" + "pggat2/lib/util/slices" + "pggat2/lib/zap" + packets "pggat2/lib/zap/packets/v3.0" +) + +type Pooler struct { + uint8Slice pools.Pool[[]byte] + uint8SliceSlice pools.Pool[[][]byte] + int16Slice pools.Pool[[]int16] + int32Slice pools.Pool[[]int32] + portal pools.Pool[*Portal] + preparedStatement pools.Pool[*PreparedStatement] +} + +func (T *Pooler) PutUint8Slice(v []byte) { + if v == nil { + return + } + T.uint8Slice.Put(v[:0]) +} + +func (T *Pooler) PutUint8SliceSlice(v [][]byte) { + if v == nil { + return + } + for _, b := range v { + T.PutUint8Slice(b) + } + T.uint8SliceSlice.Put(v[:0]) +} + +func (T *Pooler) PutInt16Slice(v []int16) { + if v == nil { + return + } + T.int16Slice.Put(v[:0]) +} + +func (T *Pooler) PutInt32Slice(v []int32) { + if v == nil { + return + } + T.int32Slice.Put(v[:0]) +} + +func (T *Pooler) PutPortal(portal *Portal) { + if portal == nil { + return + } + T.PutInt16Slice(portal.ParameterFormatCodes) + T.PutUint8SliceSlice(portal.ParameterValues) + T.PutInt16Slice(portal.ResultFormatCodes) + *portal = Portal{} + T.portal.Put(portal) +} + +func (T *Pooler) PutPreparedStatement(preparedStatement *PreparedStatement) { + if preparedStatement == nil { + return + } + T.PutUint8Slice(preparedStatement.Query) + T.PutInt32Slice(preparedStatement.ParameterDataTypes) + *preparedStatement = PreparedStatement{} + T.preparedStatement.Put(preparedStatement) +} + +func (T *Pooler) GetUint8Slice() []byte { + v, _ := T.uint8Slice.Get() + return v +} + +func (T *Pooler) GetUint8SliceSlice() [][]byte { + v, _ := T.uint8SliceSlice.Get() + return v +} + +func (T *Pooler) GetInt16Slice() []int16 { + v, _ := T.int16Slice.Get() + return v +} + +func (T *Pooler) GetInt32Slice() []int32 { + v, _ := T.int32Slice.Get() + return v +} + +func (T *Pooler) GetPortal() *Portal { + v, ok := T.portal.Get() + if !ok { + v = &Portal{} + } + return v +} + +func (T *Pooler) GetPreparedStatement() *PreparedStatement { + v, ok := T.preparedStatement.Get() + if !ok { + v = &PreparedStatement{} + } + return v +} + +func (T *Pooler) ClonePortal(portal *Portal) *Portal { + clone := T.GetPortal() + clone.Source = portal.Source + clone.ParameterFormatCodes = slices.CloneInto(T.GetInt16Slice(), portal.ParameterFormatCodes) + clone.ParameterValues = slices.Resize(T.GetUint8SliceSlice(), len(portal.ParameterValues)) + for i, v := range portal.ParameterValues { + clone.ParameterValues[i] = slices.CloneInto(T.GetUint8Slice(), v) + } + clone.ResultFormatCodes = slices.CloneInto(T.GetInt16Slice(), portal.ResultFormatCodes) + return clone +} + +func (T *Pooler) ClonePreparedStatement(preparedStatement *PreparedStatement) *PreparedStatement { + clone := T.GetPreparedStatement() + clone.Query = slices.CloneInto(T.GetUint8Slice(), preparedStatement.Query) + clone.ParameterDataTypes = slices.CloneInto(T.GetInt32Slice(), preparedStatement.ParameterDataTypes) + return clone +} + +func (T *Pooler) ReadBind(in zap.In) (destination string, portal *Portal, ok bool) { + in.Reset() + if in.Type() != packets.Bind { + return + } + destination, ok = in.String() + if !ok { + return + } + portal = T.GetPortal() + portal.Source, ok = in.String() + if !ok { + T.PutPortal(portal) + portal = nil + return + } + var parameterFormatCodesLength uint16 + parameterFormatCodesLength, ok = in.Uint16() + if !ok { + T.PutPortal(portal) + portal = nil + return + } + portal.ParameterFormatCodes = slices.Resize(T.GetInt16Slice(), int(parameterFormatCodesLength)) + for i := 0; i < int(parameterFormatCodesLength); i++ { + portal.ParameterFormatCodes[i], ok = in.Int16() + if !ok { + T.PutPortal(portal) + portal = nil + return + } + } + var parameterValuesLength uint16 + parameterValuesLength, ok = in.Uint16() + if !ok { + T.PutPortal(portal) + portal = nil + return + } + portal.ParameterValues = slices.Resize(T.GetUint8SliceSlice(), int(parameterValuesLength)) + for i := 0; i < int(parameterValuesLength); i++ { + var parameterValueLength int32 + parameterValueLength, ok = in.Int32() + if !ok { + T.PutPortal(portal) + portal = nil + return + } + if parameterValueLength >= 0 { + portal.ParameterValues[i] = slices.Resize(T.GetUint8Slice(), int(parameterValueLength)) + ok = in.Bytes(portal.ParameterValues[i]) + if !ok { + T.PutPortal(portal) + portal = nil + return + } + } + } + var resultFormatCodesLength uint16 + resultFormatCodesLength, ok = in.Uint16() + if !ok { + T.PutPortal(portal) + portal = nil + return + } + portal.ResultFormatCodes = slices.Resize(T.GetInt16Slice(), int(resultFormatCodesLength)) + for i := 0; i < int(resultFormatCodesLength); i++ { + portal.ResultFormatCodes[i], ok = in.Int16() + if !ok { + T.PutPortal(portal) + portal = nil + return + } + } + return +} + +func (T *Pooler) ReadParse(in zap.In) (destination string, preparedStatement *PreparedStatement, ok bool) { + in.Reset() + if in.Type() != packets.Parse { + return "", nil, false + } + + destination, ok = in.String() + if !ok { + return + } + preparedStatement = T.GetPreparedStatement() + preparedStatement.Query, ok = in.StringBytes(T.GetUint8Slice()) + if !ok { + T.PutPreparedStatement(preparedStatement) + preparedStatement = nil + return + } + var parameterDataTypesCount int16 + parameterDataTypesCount, ok = in.Int16() + if !ok { + T.PutPreparedStatement(preparedStatement) + preparedStatement = nil + return + } + preparedStatement.ParameterDataTypes = slices.Resize(T.GetInt32Slice(), int(parameterDataTypesCount)) + for i := 0; i < int(parameterDataTypesCount); i++ { + preparedStatement.ParameterDataTypes[i], ok = in.Int32() + if !ok { + T.PutPreparedStatement(preparedStatement) + preparedStatement = nil + return + } + } + return +} diff --git a/lib/middleware/middlewares/eqp2/portal.go b/lib/middleware/middlewares/eqp2/portal.go new file mode 100644 index 00000000..62e4d62d --- /dev/null +++ b/lib/middleware/middlewares/eqp2/portal.go @@ -0,0 +1,34 @@ +package eqp2 + +import "pggat2/lib/util/slices" + +type Portal struct { + Source string + ParameterFormatCodes []int16 + ParameterValues [][]byte + ResultFormatCodes []int16 +} + +func (T *Portal) Equals(rhs *Portal) bool { + if T == rhs { + return true + } + if T.Source != rhs.Source { + return false + } + if !slices.Equal(T.ParameterFormatCodes, rhs.ParameterFormatCodes) { + return false + } + if len(T.ParameterValues) != len(rhs.ParameterValues) { + return false + } + for i := range T.ParameterValues { + if !slices.Equal(T.ParameterValues[i], rhs.ParameterValues[i]) { + return false + } + } + if !slices.Equal(T.ResultFormatCodes, rhs.ResultFormatCodes) { + return false + } + return true +} diff --git a/lib/middleware/middlewares/eqp2/preparedStatement.go b/lib/middleware/middlewares/eqp2/preparedStatement.go new file mode 100644 index 00000000..88858099 --- /dev/null +++ b/lib/middleware/middlewares/eqp2/preparedStatement.go @@ -0,0 +1,21 @@ +package eqp2 + +import "pggat2/lib/util/slices" + +type PreparedStatement struct { + Query []byte + ParameterDataTypes []int32 +} + +func (T *PreparedStatement) Equals(rhs *PreparedStatement) bool { + if T == rhs { + return true + } + if !slices.Equal(T.Query, rhs.Query) { + return false + } + if !slices.Equal(T.ParameterDataTypes, rhs.ParameterDataTypes) { + return false + } + return true +} diff --git a/lib/util/pools/locked.go b/lib/util/pools/locked.go index 122aa742..54ace323 100644 --- a/lib/util/pools/locked.go +++ b/lib/util/pools/locked.go @@ -3,23 +3,18 @@ package pools import "sync" type Locked[T any] struct { - inner []T + inner Pool[T] mu sync.Mutex } func (L *Locked[T]) Get() (T, bool) { L.mu.Lock() defer L.mu.Unlock() - if len(L.inner) == 0 { - return *new(T), false - } - v := L.inner[len(L.inner)-1] - L.inner = L.inner[:len(L.inner)-1] - return v, true + return L.inner.Get() } func (L *Locked[T]) Put(v T) { L.mu.Lock() defer L.mu.Unlock() - L.inner = append(L.inner, v) + L.inner.Put(v) } diff --git a/lib/util/pools/pool.go b/lib/util/pools/pool.go new file mode 100644 index 00000000..390ef4e6 --- /dev/null +++ b/lib/util/pools/pool.go @@ -0,0 +1,16 @@ +package pools + +type Pool[T any] []T + +func (P *Pool[T]) Get() (T, bool) { + if len(*P) == 0 { + return *new(T), false + } + v := (*P)[len(*P)-1] + *P = (*P)[:len(*P)-1] + return v, true +} + +func (P *Pool[T]) Put(v T) { + *P = append(*P, v) +} diff --git a/lib/util/slices/cloneinto.go b/lib/util/slices/cloneinto.go new file mode 100644 index 00000000..b544cc38 --- /dev/null +++ b/lib/util/slices/cloneinto.go @@ -0,0 +1,9 @@ +package slices + +func CloneInto[T any](dst, src []T) []T { + dst = Resize(dst, len(src)) + for i, v := range src { + dst[i] = v + } + return dst +} diff --git a/lib/zap/packets/v3.0/bind.go b/lib/zap/packets/v3.0/bind.go index 07c526a6..b6effa44 100644 --- a/lib/zap/packets/v3.0/bind.go +++ b/lib/zap/packets/v3.0/bind.go @@ -46,7 +46,10 @@ func ReadBind(in zap.In) (destination string, source string, parameterFormatCode var parameterValue []byte if parameterValueLength >= 0 { parameterValue = make([]byte, int(parameterValueLength)) - in.Bytes(parameterValue) + ok = in.Bytes(parameterValue) + if !ok { + return + } } parameterValues = append(parameterValues, parameterValue) } -- GitLab