diff --git a/contrib/extension/subscription/client.go b/contrib/extension/subscription/client.go
index f329e005d42ff50d6f629d7991a203e5172d7150..ff3934df7093ee0bd9eddf3ff7a2afd4199e1ab1 100644
--- a/contrib/extension/subscription/client.go
+++ b/contrib/extension/subscription/client.go
@@ -90,6 +90,17 @@ func (c *WrapClient) Subscribe(ctx context.Context, namespace string, channel an
 		readErr:   make(chan error, 1),
 	}
 
+	go func() {
+		defer func() {
+			_ = sub.Unsubscribe()
+		}()
+		select {
+		case <-c.Closed():
+		case <-ctx.Done():
+			sub.err(ctx.Err())
+		}
+	}()
+
 	c.mu.Lock()
 	c.subs[sub.id] = sub
 	c.mu.Unlock()
diff --git a/contrib/extension/subscription/client_test.go b/contrib/extension/subscription/client_test.go
index cc7a8eb778e005f8d513e1d23e1e3ff027bcdb15..b63d6e2fbf5ca5a9380927adbfb69cd622dedbc8 100644
--- a/contrib/extension/subscription/client_test.go
+++ b/contrib/extension/subscription/client_test.go
@@ -2,6 +2,7 @@ package subscription
 
 import (
 	"context"
+	"encoding/json"
 	"net/http/httptest"
 	_ "net/http/pprof"
 	"strings"
@@ -15,49 +16,78 @@ import (
 	"gfx.cafe/open/jrpc/pkg/server"
 )
 
-func TestSubscription(t *testing.T) {
-	const count = 100
-
+func newRouter(t *testing.T) jmux.Router {
 	engine := NewEngine()
 	r := jmux.NewRouter()
 	r.Use(engine.Middleware())
+	r.HandleFunc("echo", func(w jsonrpc.ResponseWriter, r *jsonrpc.Request) {
+		err := w.Send(r.Params, nil)
+		if err != nil {
+			t.Error(err)
+		}
+	})
+	// extremely fast subscription to fill buffers to get a higher chance that we receive another message while trying
+	// to unsubscribe
 	r.HandleFunc("test/subscribe", func(w jsonrpc.ResponseWriter, r *jsonrpc.Request) {
 		notifier, ok := NotifierFromContext(r.Context())
 		if !ok {
-			_ = w.Send(nil, ErrNotificationsUnsupported)
+			err := w.Send(nil, ErrNotificationsUnsupported)
+			if err != nil {
+				t.Error(err)
+			}
 			return
 		}
-
+		var count int
+		_ = json.Unmarshal(r.Params, &count)
 		go func() {
 			time.Sleep(10 * time.Millisecond)
-			for i := 0; i < count; i++ {
-				if err := notifier.Notify(i); err != nil {
-					panic(err)
+			for idx := 0; count == 0 || idx < count; idx++ {
+				select {
+				case <-r.Context().Done():
+					return
+				case <-notifier.Err():
+					return
+				default:
+				}
+				err := notifier.Notify(idx)
+				if err != nil {
+					t.Error(err)
 				}
 			}
 		}()
 	})
 
+	return r
+}
+
+func newServer(t *testing.T) (Conn, func()) {
+	r := newRouter(t)
 	srv := server.NewServer(r)
-	defer srv.Shutdown(context.Background())
 	handler := codecs.WebsocketHandler(srv, []string{"*"})
 	httpSrv := httptest.NewServer(handler)
-	defer httpSrv.Close()
 
 	wsURL := "ws:" + strings.TrimPrefix(httpSrv.URL, "http:")
 	cl, err := UpgradeConn(jrpc.Dial(wsURL))
 	if err != nil {
 		t.Error(err)
-		return
+		return nil, nil
 	}
-	defer func() {
-		if err = cl.Close(); err != nil {
-			t.Error(err)
-		}
-	}()
+
+	return cl, func() {
+		_ = cl.Close()
+		httpSrv.Close()
+		srv.Shutdown(context.Background())
+	}
+}
+
+func TestSubscription(t *testing.T) {
+	const count = 100
+
+	cl, done := newServer(t)
+	defer done()
 
 	ch := make(chan int, count)
-	sub, err := cl.Subscribe(context.Background(), "test", ch, nil)
+	sub, err := cl.Subscribe(context.Background(), "test", ch, count)
 	defer func() {
 		if err = sub.Unsubscribe(); err != nil {
 			t.Error(err)
@@ -73,46 +103,11 @@ func TestSubscription(t *testing.T) {
 }
 
 func TestUnsubscribeNoRead(t *testing.T) {
-	engine := NewEngine()
-	r := jmux.NewRouter()
-	r.Use(engine.Middleware())
-	r.HandleFunc("test/subscribe", func(w jsonrpc.ResponseWriter, r *jsonrpc.Request) {
-		notifier, ok := NotifierFromContext(r.Context())
-		if !ok {
-			_ = w.Send(nil, ErrNotificationsUnsupported)
-			return
-		}
-
-		go func() {
-			time.Sleep(10 * time.Millisecond)
-			for i := 0; i < 10; i++ {
-				if err := notifier.Notify(i); err != nil {
-					panic(err)
-				}
-			}
-		}()
-	})
-
-	srv := server.NewServer(r)
-	defer srv.Shutdown(context.Background())
-	handler := codecs.WebsocketHandler(srv, []string{"*"})
-	httpSrv := httptest.NewServer(handler)
-	defer httpSrv.Close()
-
-	wsURL := "ws:" + strings.TrimPrefix(httpSrv.URL, "http:")
-	cl, err := UpgradeConn(jrpc.Dial(wsURL))
-	if err != nil {
-		t.Error(err)
-		return
-	}
-	defer func() {
-		if err = cl.Close(); err != nil {
-			t.Error(err)
-		}
-	}()
+	cl, done := newServer(t)
+	defer done()
 
 	ch := make(chan int)
-	sub, err := cl.Subscribe(context.Background(), "test", ch, nil)
+	sub, err := cl.Subscribe(context.Background(), "test", ch, 10)
 	time.Sleep(time.Second)
 	if err = sub.Unsubscribe(); err != nil {
 		t.Error(err)
@@ -121,66 +116,12 @@ func TestUnsubscribeNoRead(t *testing.T) {
 }
 
 func TestWrapClient(t *testing.T) {
-	engine := NewEngine()
-	r := jmux.NewRouter()
-	r.Use(engine.Middleware())
-	r.HandleFunc("echo", func(w jsonrpc.ResponseWriter, r *jsonrpc.Request) {
-		err := w.Send(r.Params, nil)
-		if err != nil {
-			t.Error(err)
-		}
-	})
-	// extremely fast subscription to fill buffers to get a higher chance that we receive another message while trying
-	// to unsubscribe
-	r.HandleFunc("test/subscribe", func(w jsonrpc.ResponseWriter, r *jsonrpc.Request) {
-		notifier, ok := NotifierFromContext(r.Context())
-		if !ok {
-			err := w.Send(nil, ErrNotificationsUnsupported)
-			if err != nil {
-				t.Error(err)
-			}
-			return
-		}
-		go func() {
-			time.Sleep(10 * time.Millisecond)
-			idx := 0
-			for {
-				select {
-				case <-r.Context().Done():
-					return
-				case <-notifier.Err():
-					return
-				default:
-				}
-				err := notifier.Notify(idx)
-				if err != nil {
-					t.Error(err)
-				}
-				idx += 1
-			}
-		}()
-	})
-	srv := server.NewServer(r)
-	defer srv.Shutdown(context.Background())
-	handler := codecs.WebsocketHandler(srv, []string{"*"})
-	httpSrv := httptest.NewServer(handler)
-	defer httpSrv.Close()
-
-	wsURL := "ws:" + strings.TrimPrefix(httpSrv.URL, "http:")
-	cl, err := UpgradeConn(jrpc.Dial(wsURL))
-	if err != nil {
-		t.Error(err)
-		return
-	}
-	defer func() {
-		if err = cl.Close(); err != nil {
-			t.Error(err)
-		}
-	}()
+	cl, done := newServer(t)
+	defer done()
 
 	for i := 0; i < 10; i++ {
 		var res string
-		if err = cl.Do(context.Background(), &res, "echo", "test"); err != nil {
+		if err := cl.Do(context.Background(), &res, "echo", "test"); err != nil {
 			t.Error(err)
 			return
 		}
@@ -190,8 +131,7 @@ func TestWrapClient(t *testing.T) {
 		}
 
 		ch := make(chan int, 101)
-		var sub ClientSubscription
-		sub, err = cl.Subscribe(context.Background(), "test", ch, nil)
+		sub, err := cl.Subscribe(context.Background(), "test", ch, nil)
 		if err != nil {
 			t.Error(err)
 			return
@@ -220,3 +160,35 @@ func TestWrapClient(t *testing.T) {
 		}()
 	}
 }
+
+func TestCloseClient(t *testing.T) {
+	cl, done := newServer(t)
+	defer done()
+
+	ch := make(chan int)
+	sub, err := cl.Subscribe(context.Background(), "test", ch, nil)
+	if err != nil {
+		t.Error(err)
+		return
+	}
+
+	go func() {
+		if err := cl.Close(); err != nil {
+			t.Error(err)
+		}
+	}()
+
+	for {
+		select {
+		case err, ok := <-sub.Err():
+			if ok {
+				t.Errorf("sub errored: %v", err)
+			}
+			return
+		case _, ok := <-ch:
+			if !ok {
+				return
+			}
+		}
+	}
+}