good morning!!!!

Skip to content
Snippets Groups Projects

fix pool hanging on close

3 files
+ 113
67
Compare changes
  • Side-by-side
  • Inline

Files

+ 110
66
@@ -4,7 +4,6 @@ import (
"context"
"net"
"sync"
"sync/atomic"
"gfx.cafe/open/jrpc"
"gfx.cafe/open/jrpc/contrib/extension/subscription"
@@ -16,17 +15,37 @@ var _ subscription.Conn = (*Pooling)(nil)
type Pooling struct {
dialer func(ctx context.Context) (jrpc.Conn, error)
conns chan codec.Conn
base subscription.Conn
closed atomic.Bool
middleware []codec.Middleware
mu sync.Mutex
closed chan struct{}
base subscription.Conn
baseClosed bool
baseMu sync.Mutex
conns chan codec.Conn
connsCount int
connsClosed bool
connsMu sync.Mutex
}
func NewPooling(ctx context.Context, dialer func(ctx context.Context) (jrpc.Conn, error), max int) (*Pooling, error) {
r := &Pooling{
dialer: dialer,
closed: make(chan struct{}),
conns: make(chan codec.Conn, max),
}
return r, nil
}
func (p *Pooling) getBase(ctx context.Context) (subscription.Conn, error) {
p.mu.Lock()
defer p.mu.Unlock()
p.baseMu.Lock()
defer p.baseMu.Unlock()
if p.baseClosed {
return nil, net.ErrClosed
}
if p.base == nil {
conn, err := subscription.UpgradeConn(p.dialer(ctx))
if err != nil {
@@ -53,90 +72,115 @@ func (p *Pooling) Subscribe(ctx context.Context, namespace string, channel any,
return base.Subscribe(ctx, namespace, channel, args)
}
func NewPooling(ctx context.Context, dialer func(ctx context.Context) (jrpc.Conn, error), max int) (*Pooling, error) {
r := &Pooling{
dialer: dialer,
conns: make(chan codec.Conn, max),
func (p *Pooling) Do(ctx context.Context, result any, method string, params any) error {
conn, err := p.getClient(ctx)
if err != nil {
return err
}
defer p.putClient(conn)
return conn.Do(ctx, result, method, params)
}
return r, nil
func (p *Pooling) BatchCall(ctx context.Context, b ...*codec.BatchElem) error {
conn, err := p.getClient(ctx)
if err != nil {
return err
}
defer p.putClient(conn)
return conn.BatchCall(ctx, b...)
}
func (p *Pooling) Mount(m codec.Middleware) {
p.middleware = append(p.middleware, m)
}
func (r *Pooling) Do(ctx context.Context, result any, method string, params any) error {
if r.closed.Load() {
return net.ErrClosed
func (p *Pooling) Close() error {
select {
case <-p.closed:
return nil
default:
close(p.closed)
}
errChan := make(chan error)
go func() {
conn, err := r.getClient(ctx)
if err != nil {
errChan <- err
func() {
p.connsMu.Lock()
defer p.connsMu.Unlock()
if p.connsClosed {
return
}
defer r.putClient(conn)
errChan <- conn.Do(ctx, result, method, params)
p.connsClosed = true
close(p.conns)
for conn := range p.conns {
p.connsCount--
_ = conn.Close()
}
}()
return <-errChan
}
func (r *Pooling) BatchCall(ctx context.Context, b ...*codec.BatchElem) error {
if r.closed.Load() {
return net.ErrClosed
}
errChan := make(chan error)
go func() {
conn, err := r.getClient(ctx)
if err != nil {
errChan <- err
func() {
p.baseMu.Lock()
defer p.baseMu.Unlock()
if p.baseClosed {
return
}
defer r.putClient(conn)
errChan <- conn.BatchCall(ctx, b...)
p.baseClosed = true
if p.base != nil {
_ = p.base.Close()
}
}()
return <-errChan
return nil
}
func (p *Pooling) Mount(m codec.Middleware) {
p.middleware = append(p.middleware, m)
func (p *Pooling) Closed() <-chan struct{} {
return p.closed
}
func (p *Pooling) Close() error {
if p.closed.CompareAndSwap(false, true) {
for k := range p.conns {
k.Close()
func (p *Pooling) getClient(ctx context.Context) (jrpc.Conn, error) {
create := func() bool {
p.connsMu.Lock()
defer p.connsMu.Unlock()
if p.connsCount < cap(p.conns) {
p.connsCount++
return true
}
}
return nil
}
func (p *Pooling) Closed() <-chan struct{} {
return make(<-chan struct{})
}
return false
}()
func (r *Pooling) getClient(ctx context.Context) (jrpc.Conn, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-r.conns:
default:
// wait for conn
if !create {
select {
case conn, ok := <-p.conns:
// if p.conns is closed, the pool is closed
if !ok {
return nil, net.ErrClosed
}
return conn, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
conn, err := r.dialer(ctx)
// dial new conn
conn, err := p.dialer(ctx)
if err != nil {
p.connsMu.Lock()
defer p.connsMu.Unlock()
p.connsCount--
return nil, err
}
return conn, nil
}
func (r *Pooling) putClient(conn jrpc.Conn) {
if r.closed.Load() {
func (p *Pooling) putClient(conn jrpc.Conn) {
p.connsMu.Lock()
defer p.connsMu.Unlock()
if p.connsClosed {
p.connsCount--
_ = conn.Close()
return
}
select {
case <-conn.Closed():
default:
}
select {
case r.conns <- conn:
default:
conn.Close()
}
// there should always be space
p.conns <- conn
}
Loading