diff --git a/contrib/extension/subscription/client.go b/contrib/extension/subscription/client.go index 3290126e3f4a13a848670761d1d0b49f4c32c052..d3bf8458183ed5a8459b385759beeb672cce55b3 100644 --- a/contrib/extension/subscription/client.go +++ b/contrib/extension/subscription/client.go @@ -184,15 +184,15 @@ func (c *clientSub) Err() <-chan error { } func (c *clientSub) Unsubscribe() error { + if c.done.CompareAndSwap(false, true) { + close(c.subdone) + } // TODO: dont use context background here... var result string err := c.conn.Do(context.Background(), &result, c.namespace+serviceMethodSeparator+unsubscribeMethodSuffix, nil) if err != nil { return err } - if c.done.CompareAndSwap(false, true) { - close(c.subdone) - } return nil } diff --git a/contrib/extension/subscription/client_test.go b/contrib/extension/subscription/client_test.go index 0f338c579772c703e47057ac55f3c82b5664eadc..e94ca28c96d408d4a2fefa81b6a02ab8fbcc1d6c 100644 --- a/contrib/extension/subscription/client_test.go +++ b/contrib/extension/subscription/client_test.go @@ -3,6 +3,7 @@ package subscription import ( "context" "log" + "net/http" "net/http/httptest" _ "net/http/pprof" "strings" @@ -63,6 +64,48 @@ func TestSubscription(t *testing.T) { } } +func TestUnsubscribeNoRead(t *testing.T) { + go func() { + panic(http.ListenAndServe(":6060", nil)) + }() + + engine := NewEngine() + r := jmux.NewRouter() + r.Use(engine.Middleware()) + r.HandleFunc("test/subscribe", func(w jsonrpc.ResponseWriter, r *jsonrpc.Request) { + notifier, ok := NotifierFromContext(r.Context()) + if !ok { + _ = w.Send(nil, ErrNotificationsUnsupported) + return + } + + for i := 0; i < 10; i++ { + if err := notifier.Notify(i); err != nil { + panic(err) + } + } + }) + + srv := server.NewServer(r) + handler := codecs.WebsocketHandler(srv, []string{"*"}) + httpSrv := httptest.NewServer(handler) + + wsURL := "ws:" + strings.TrimPrefix(httpSrv.URL, "http:") + cl, err := UpgradeConn(jrpc.Dial(wsURL)) + if err != nil { + t.Error(err) + return + } + + ch := make(chan int) + sub, err := cl.Subscribe(context.Background(), "test", ch, nil) + time.Sleep(time.Second) + if err = sub.Unsubscribe(); err != nil { + t.Error(err) + return + } +} + func TestWrapClient(t *testing.T) { engine := NewEngine() r := jmux.NewRouter()