diff --git a/contrib/codecs/rdwr/client.go b/contrib/codecs/rdwr/client.go index d36b04ca05c5a218fec0a5a5f229a1d005a38ba2..4d035b7c8a361309296ec3b18e483364c09688e3 100644 --- a/contrib/codecs/rdwr/client.go +++ b/contrib/codecs/rdwr/client.go @@ -66,7 +66,9 @@ func (c *Client) Mount(h jsonrpc.Middleware) { func (c *Client) listen() error { var msg json.RawMessage - defer c.cn() + defer func() { + _ = c.Close() + }() dec := json.NewDecoder(bufio.NewReader(c.rd)) for { err := dec.Decode(&msg) @@ -157,7 +159,7 @@ func (c *Client) SetHeader(key string, value string) { func (c *Client) Close() error { c.cn() - return nil + return c.p.Close() } func (c *Client) writeContext(ctx context.Context, xs []byte) error { diff --git a/pkg/clientutil/idreply.go b/pkg/clientutil/idreply.go index 46ebcc1504aa37e8c669a421027644087a02a633..2033c474ee4cfe72cdd19a515deba416600b5eb0 100644 --- a/pkg/clientutil/idreply.go +++ b/pkg/clientutil/idreply.go @@ -3,6 +3,7 @@ package clientutil import ( "context" "io" + "net" "sync" "sync/atomic" @@ -12,6 +13,8 @@ import ( type IdReply struct { id atomic.Int64 + closed chan struct{} + chs map[string]chan msgOrError mu sync.Mutex } @@ -23,7 +26,8 @@ type msgOrError struct { func NewIdReply() *IdReply { return &IdReply{ - chs: make(map[string]chan msgOrError, 1), + closed: make(chan struct{}), + chs: make(map[string]chan msgOrError, 1), } } @@ -94,5 +98,16 @@ func (i *IdReply) Ask(ctx context.Context, id []byte) (io.ReadCloser, error) { case <-ctx.Done(): i.remove(id) return nil, ctx.Err() + case <-i.closed: + return nil, net.ErrClosed } } + +func (i *IdReply) Closed() <-chan struct{} { + return i.closed +} + +func (i *IdReply) Close() error { + close(i.closed) + return nil +} diff --git a/pkg/jrpctest/suites.go b/pkg/jrpctest/suites.go index b175b6fad3cd050371762f4bb25a6ee31506463e..c92e6a0448a8e6619c5383037ef192b2c9a97c21 100644 --- a/pkg/jrpctest/suites.go +++ b/pkg/jrpctest/suites.go @@ -3,7 +3,9 @@ package jrpctest import ( "context" "embed" + "errors" "math/rand" + "net" "reflect" "sync" "testing" @@ -188,6 +190,16 @@ func RunBasicTestSuite(t *testing.T, args BasicTestSuiteArgs) { wg.Wait() }) + makeTest("close", func(t *testing.T, server *server.Server, client jsonrpc.Conn) { + go func() { + _ = client.Close() + }() + err := jsonrpc.CallInto(context.Background(), client, nil, "test_block") + if !errors.Is(err, net.ErrClosed) { + t.Errorf("expected close error but got %v", err) + } + }) + makeTest("", func(t *testing.T, server *server.Server, client jsonrpc.Conn) { }) }