From 2f492c813c3eba5fb9073ad4d8a0aab23634c4fc Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Mon, 4 Dec 2023 17:39:17 -0600
Subject: [PATCH] unsubscribe without reading

---
 contrib/extension/subscription/client.go      |  6 +--
 contrib/extension/subscription/client_test.go | 43 +++++++++++++++++++
 2 files changed, 46 insertions(+), 3 deletions(-)

diff --git a/contrib/extension/subscription/client.go b/contrib/extension/subscription/client.go
index 3290126..d3bf845 100644
--- a/contrib/extension/subscription/client.go
+++ b/contrib/extension/subscription/client.go
@@ -184,15 +184,15 @@ func (c *clientSub) Err() <-chan error {
 }
 
 func (c *clientSub) Unsubscribe() error {
+	if c.done.CompareAndSwap(false, true) {
+		close(c.subdone)
+	}
 	// TODO: dont use context background here...
 	var result string
 	err := c.conn.Do(context.Background(), &result, c.namespace+serviceMethodSeparator+unsubscribeMethodSuffix, nil)
 	if err != nil {
 		return err
 	}
-	if c.done.CompareAndSwap(false, true) {
-		close(c.subdone)
-	}
 	return nil
 }
 
diff --git a/contrib/extension/subscription/client_test.go b/contrib/extension/subscription/client_test.go
index 0f338c5..e94ca28 100644
--- a/contrib/extension/subscription/client_test.go
+++ b/contrib/extension/subscription/client_test.go
@@ -3,6 +3,7 @@ package subscription
 import (
 	"context"
 	"log"
+	"net/http"
 	"net/http/httptest"
 	_ "net/http/pprof"
 	"strings"
@@ -63,6 +64,48 @@ func TestSubscription(t *testing.T) {
 	}
 }
 
+func TestUnsubscribeNoRead(t *testing.T) {
+	go func() {
+		panic(http.ListenAndServe(":6060", nil))
+	}()
+
+	engine := NewEngine()
+	r := jmux.NewRouter()
+	r.Use(engine.Middleware())
+	r.HandleFunc("test/subscribe", func(w jsonrpc.ResponseWriter, r *jsonrpc.Request) {
+		notifier, ok := NotifierFromContext(r.Context())
+		if !ok {
+			_ = w.Send(nil, ErrNotificationsUnsupported)
+			return
+		}
+
+		for i := 0; i < 10; i++ {
+			if err := notifier.Notify(i); err != nil {
+				panic(err)
+			}
+		}
+	})
+
+	srv := server.NewServer(r)
+	handler := codecs.WebsocketHandler(srv, []string{"*"})
+	httpSrv := httptest.NewServer(handler)
+
+	wsURL := "ws:" + strings.TrimPrefix(httpSrv.URL, "http:")
+	cl, err := UpgradeConn(jrpc.Dial(wsURL))
+	if err != nil {
+		t.Error(err)
+		return
+	}
+
+	ch := make(chan int)
+	sub, err := cl.Subscribe(context.Background(), "test", ch, nil)
+	time.Sleep(time.Second)
+	if err = sub.Unsubscribe(); err != nil {
+		t.Error(err)
+		return
+	}
+}
+
 func TestWrapClient(t *testing.T) {
 	engine := NewEngine()
 	r := jmux.NewRouter()
-- 
GitLab