From 5347bf6b971ce4cf493ce911c86de2b653435c0e Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Fri, 8 Dec 2023 13:41:27 -0600
Subject: [PATCH] don't hang in idreply if conn closes

---
 contrib/codecs/rdwr/client.go |  6 ++++--
 pkg/clientutil/idreply.go     | 17 ++++++++++++++++-
 pkg/jrpctest/suites.go        | 12 ++++++++++++
 3 files changed, 32 insertions(+), 3 deletions(-)

diff --git a/contrib/codecs/rdwr/client.go b/contrib/codecs/rdwr/client.go
index d36b04c..4d035b7 100644
--- a/contrib/codecs/rdwr/client.go
+++ b/contrib/codecs/rdwr/client.go
@@ -66,7 +66,9 @@ func (c *Client) Mount(h jsonrpc.Middleware) {
 
 func (c *Client) listen() error {
 	var msg json.RawMessage
-	defer c.cn()
+	defer func() {
+		_ = c.Close()
+	}()
 	dec := json.NewDecoder(bufio.NewReader(c.rd))
 	for {
 		err := dec.Decode(&msg)
@@ -157,7 +159,7 @@ func (c *Client) SetHeader(key string, value string) {
 
 func (c *Client) Close() error {
 	c.cn()
-	return nil
+	return c.p.Close()
 }
 
 func (c *Client) writeContext(ctx context.Context, xs []byte) error {
diff --git a/pkg/clientutil/idreply.go b/pkg/clientutil/idreply.go
index 46ebcc1..2033c47 100644
--- a/pkg/clientutil/idreply.go
+++ b/pkg/clientutil/idreply.go
@@ -3,6 +3,7 @@ package clientutil
 import (
 	"context"
 	"io"
+	"net"
 	"sync"
 	"sync/atomic"
 
@@ -12,6 +13,8 @@ import (
 type IdReply struct {
 	id atomic.Int64
 
+	closed chan struct{}
+
 	chs map[string]chan msgOrError
 	mu  sync.Mutex
 }
@@ -23,7 +26,8 @@ type msgOrError struct {
 
 func NewIdReply() *IdReply {
 	return &IdReply{
-		chs: make(map[string]chan msgOrError, 1),
+		closed: make(chan struct{}),
+		chs:    make(map[string]chan msgOrError, 1),
 	}
 }
 
@@ -94,5 +98,16 @@ func (i *IdReply) Ask(ctx context.Context, id []byte) (io.ReadCloser, error) {
 	case <-ctx.Done():
 		i.remove(id)
 		return nil, ctx.Err()
+	case <-i.closed:
+		return nil, net.ErrClosed
 	}
 }
+
+func (i *IdReply) Closed() <-chan struct{} {
+	return i.closed
+}
+
+func (i *IdReply) Close() error {
+	close(i.closed)
+	return nil
+}
diff --git a/pkg/jrpctest/suites.go b/pkg/jrpctest/suites.go
index b175b6f..c92e6a0 100644
--- a/pkg/jrpctest/suites.go
+++ b/pkg/jrpctest/suites.go
@@ -3,7 +3,9 @@ package jrpctest
 import (
 	"context"
 	"embed"
+	"errors"
 	"math/rand"
+	"net"
 	"reflect"
 	"sync"
 	"testing"
@@ -188,6 +190,16 @@ func RunBasicTestSuite(t *testing.T, args BasicTestSuiteArgs) {
 		wg.Wait()
 	})
 
+	makeTest("close", func(t *testing.T, server *server.Server, client jsonrpc.Conn) {
+		go func() {
+			_ = client.Close()
+		}()
+		err := jsonrpc.CallInto(context.Background(), client, nil, "test_block")
+		if !errors.Is(err, net.ErrClosed) {
+			t.Errorf("expected close error but got %v", err)
+		}
+	})
+
 	makeTest("", func(t *testing.T, server *server.Server, client jsonrpc.Conn) {
 	})
 }
-- 
GitLab