diff --git a/lib/middleware/middlewares/eqp/state.go b/lib/middleware/middlewares/eqp/state.go index 0a6fc13f71f643ca21292ac1f494157e08be4cc3..cebb92fec52deedae588fb2e3daf73a9ebc73b9f 100644 --- a/lib/middleware/middlewares/eqp/state.go +++ b/lib/middleware/middlewares/eqp/state.go @@ -192,10 +192,10 @@ func (T *State) ReadyForQuery(packet fed.Packet) { } // all pending has failed - for _, ok := T.pendingPreparedStatements.PopBack(); ok; _, ok = T.pendingPortals.PopBack() { + for _, ok := T.pendingPreparedStatements.PopBack(); ok; _, ok = T.pendingPreparedStatements.PopBack() { } for _, ok := T.pendingPortals.PopBack(); ok; _, ok = T.pendingPortals.PopBack() { } - for _, ok := T.pendingCloses.PopBack(); ok; _, ok = T.pendingPortals.PopBack() { + for _, ok := T.pendingCloses.PopBack(); ok; _, ok = T.pendingCloses.PopBack() { } } diff --git a/lib/middleware/middlewares/eqp/sync.go b/lib/middleware/middlewares/eqp/sync.go index 1e8c4b85c28667e374d6a2736f621fb56f964c6a..848ec954ca9ee9446c52392636143ded302193c7 100644 --- a/lib/middleware/middlewares/eqp/sync.go +++ b/lib/middleware/middlewares/eqp/sync.go @@ -22,10 +22,16 @@ func Sync(c *Client, server fed.ReadWriter, s *Server) error { // close all prepared statements that don't match client for name, preparedStatement := range s.state.preparedStatements { - clientPreparedStatement, ok := c.state.preparedStatements[name] - if ok && (name == "" || preparedStatement.Hash == clientPreparedStatement.Hash) { - // match or unnamed prepared statement that will be bound over - continue + if clientPreparedStatement, ok := c.state.preparedStatements[name]; ok { + if preparedStatement.Hash == clientPreparedStatement.Hash { + // the same + continue + } + + if name == "" { + // will be overwritten + continue + } } p := packets.Close{ @@ -39,10 +45,11 @@ func Sync(c *Client, server fed.ReadWriter, s *Server) error { // parse all prepared statements that aren't on server for name, preparedStatement := range c.state.preparedStatements { - serverPreparedStatement, ok := s.state.preparedStatements[name] - if ok && preparedStatement.Hash == serverPreparedStatement.Hash { - // matched, don't need to set - continue + if serverPreparedStatement, ok := s.state.preparedStatements[name]; ok { + if preparedStatement.Hash == serverPreparedStatement.Hash { + // the same + continue + } } if err := server.WritePacket(preparedStatement.Packet); err != nil { diff --git a/lib/util/ring/ring.go b/lib/util/ring/ring.go index a55ce1d24583fbcc5f69a2caa4da9acd3907a652..2ba4003950e512e05703113315d4d4e859371cdb 100644 --- a/lib/util/ring/ring.go +++ b/lib/util/ring/ring.go @@ -1,7 +1,8 @@ package ring type Ring[T any] struct { - buffer []T + buf []T + // real head is head-1, like this so nil ring is valid head int tail int length int @@ -9,128 +10,138 @@ type Ring[T any] struct { func MakeRing[T any](length, capacity int) Ring[T] { if length > capacity { - // programmer error, panic - panic("length must be < capacity") - } - if capacity < 0 { - panic("capacity must be >= 0") - } - tail := length + 1 - if tail >= capacity { - tail -= capacity + panic("length must be less than capacity") } return Ring[T]{ - buffer: make([]T, capacity), - head: 0, + buf: make([]T, capacity), + tail: length, length: length, - tail: tail, } } func NewRing[T any](length, capacity int) *Ring[T] { - ring := MakeRing[T](length, capacity) - return &ring + r := MakeRing[T](length, capacity) + return &r } func (r *Ring[T]) grow() { - if cap(r.buffer) == 0 { - // special case, uninitialized - r.buffer = make([]T, 2) - r.head = 0 - r.tail = 1 - return + size := len(r.buf) * 2 + if size == 0 { + size = 2 } - // make new buffer with twice as much space - buf := make([]T, cap(r.buffer)*2) - - // copy from [head, end of buffer] into new buffer - copy(buf, r.buffer[r.head:]) - // copy from [beginning of buffer, tail) into new buffer - copy(buf[len(r.buffer)-r.head:], r.buffer[:r.tail]) - - r.tail = len(r.buffer) - r.head + r.tail + buf := make([]T, size) + copy(buf, r.buf[r.head:]) + copy(buf[len(r.buf[r.head:]):], r.buf[:r.head]) r.head = 0 - r.buffer = buf + r.tail = r.length + r.buf = buf } -func (r *Ring[T]) PushBack(value T) { - if r == nil { - panic("PushBack() on nil Ring") +func (r *Ring[T]) incHead() { + // resize + if r.length == 0 { + panic("smashing detected") } - if r.length == cap(r.buffer) { - r.grow() - } - r.buffer[r.tail] = value - r.tail++ - if r.tail >= len(r.buffer) { - r.tail -= len(r.buffer) + r.length-- + + r.head++ + if r.head == len(r.buf) { + r.head = 0 } - r.length++ } -func (r *Ring[T]) PushFront(value T) { - if r == nil { - panic("PushFront() on nil Ring") - } - if r.length == cap(r.buffer) { +func (r *Ring[T]) decHead() { + // resize + if r.length == len(r.buf) { r.grow() } - r.buffer[r.head] = value + r.length++ + r.head-- - if r.head < 0 { - r.head += len(r.buffer) + if r.head == -1 { + r.head = len(r.buf) - 1 + } +} + +func (r *Ring[T]) incTail() { + // resize + if r.length == len(r.buf) { + r.grow() } r.length++ + + r.tail++ + if r.tail == len(r.buf) { + r.tail = 0 + } } -func (r *Ring[T]) PopBack() (T, bool) { - if r == nil || r.length == 0 { - return *new(T), false +func (r *Ring[T]) decTail() { + // resize + if r.length == 0 { + panic("smashing detected") + } + r.length-- + + r.tail-- + if r.tail == -1 { + r.tail = len(r.buf) - 1 } +} + +func (r *Ring[T]) tailSub1() int { tail := r.tail - 1 - if tail < 0 { - tail += len(r.buffer) + if tail == -1 { + tail = len(r.buf) - 1 } - r.tail = tail - r.length-- - return r.buffer[tail], true + return tail } func (r *Ring[T]) PopFront() (T, bool) { - if r == nil || r.length == 0 { + if r.length == 0 { return *new(T), false } - head := r.head + 1 - if head >= len(r.buffer) { - head -= len(r.buffer) - } - r.head = head - r.length-- - return r.buffer[head], true + + front := r.buf[r.head] + r.incHead() + return front, true } -func (r *Ring[T]) Get(i int) (T, bool) { - if r == nil || i >= r.length || i < 0 { +func (r *Ring[T]) PopBack() (T, bool) { + if r.length == 0 { return *new(T), false } - ptr := r.head + 1 + i - if ptr >= len(r.buffer) { - ptr -= len(r.buffer) - } - return r.buffer[ptr], true + + r.decTail() + return r.buf[r.tail], true +} + +func (r *Ring[T]) PushFront(value T) { + r.decHead() + r.buf[r.head] = value +} + +func (r *Ring[T]) PushBack(value T) { + r.incTail() + r.buf[r.tailSub1()] = value } func (r *Ring[T]) Length() int { - if r == nil { - return 0 - } return r.length } func (r *Ring[T]) Capacity() int { - if r == nil { - return 0 + return len(r.buf) +} + +func (r *Ring[T]) Get(n int) T { + if n >= r.length { + panic("index out of range") + } + ptr := r.head + n + if ptr >= len(r.buf) { + ptr -= len(r.buf) } - return cap(r.buffer) + return r.buf[ptr] }