diff --git a/contrib/client/pooling.go b/contrib/client/pooling.go index 4c460be1fe097386c9c747c8be52a2ed4b838487..dea11946e46f39053dd26e16105297505e17db08 100644 --- a/contrib/client/pooling.go +++ b/contrib/client/pooling.go @@ -7,27 +7,59 @@ import ( "sync/atomic" "gfx.cafe/open/jrpc" + "gfx.cafe/open/jrpc/contrib/extension/subscription" "gfx.cafe/open/jrpc/pkg/codec" ) var _ codec.Conn = (*Pooling)(nil) +var _ subscription.Conn = (*Pooling)(nil) type Pooling struct { dialer func(ctx context.Context) (jrpc.Conn, error) conns chan codec.Conn - base codec.Conn + base subscription.Conn closed atomic.Bool middleware []codec.Middleware mu sync.Mutex } -func NewPooling(dialer func(ctx context.Context) (jrpc.Conn, error), max int) *Pooling { +func (p *Pooling) getBase(ctx context.Context) (subscription.Conn, error) { + p.mu.Lock() + defer p.mu.Unlock() + if p.base == nil { + conn, err := subscription.UpgradeConn(p.dialer(ctx)) + if err != nil { + return nil, err + } + p.base = conn + } + return p.base, nil +} + +func (p *Pooling) Notify(ctx context.Context, method string, params any) error { + base, err := p.getBase(ctx) + if err != nil { + return err + } + return base.Notify(ctx, method, params) +} + +func (p *Pooling) Subscribe(ctx context.Context, namespace string, channel any, args any) (subscription.ClientSubscription, error) { + base, err := p.getBase(ctx) + if err != nil { + return nil, err + } + return base.Subscribe(ctx, namespace, channel, args) +} + +func NewPooling(ctx context.Context, dialer func(ctx context.Context) (jrpc.Conn, error), max int) (*Pooling, error) { r := &Pooling{ dialer: dialer, conns: make(chan codec.Conn, max), } - return r + + return r, nil } func (r *Pooling) Do(ctx context.Context, result any, method string, params any) error {