diff --git a/contrib/extension/subscription/client.go b/contrib/extension/subscription/client.go index 3290126e3f4a13a848670761d1d0b49f4c32c052..d3bf8458183ed5a8459b385759beeb672cce55b3 100644 --- a/contrib/extension/subscription/client.go +++ b/contrib/extension/subscription/client.go @@ -184,15 +184,15 @@ func (c *clientSub) Err() <-chan error { } func (c *clientSub) Unsubscribe() error { + if c.done.CompareAndSwap(false, true) { + close(c.subdone) + } // TODO: dont use context background here... var result string err := c.conn.Do(context.Background(), &result, c.namespace+serviceMethodSeparator+unsubscribeMethodSuffix, nil) if err != nil { return err } - if c.done.CompareAndSwap(false, true) { - close(c.subdone) - } return nil } diff --git a/contrib/extension/subscription/client_test.go b/contrib/extension/subscription/client_test.go index 0f338c579772c703e47057ac55f3c82b5664eadc..bf4a7fa2ba9e6490e491bb24a4518afbbf219238 100644 --- a/contrib/extension/subscription/client_test.go +++ b/contrib/extension/subscription/client_test.go @@ -2,7 +2,6 @@ package subscription import ( "context" - "log" "net/http/httptest" _ "net/http/pprof" "strings" @@ -63,6 +62,44 @@ func TestSubscription(t *testing.T) { } } +func TestUnsubscribeNoRead(t *testing.T) { + engine := NewEngine() + r := jmux.NewRouter() + r.Use(engine.Middleware()) + r.HandleFunc("test/subscribe", func(w jsonrpc.ResponseWriter, r *jsonrpc.Request) { + notifier, ok := NotifierFromContext(r.Context()) + if !ok { + _ = w.Send(nil, ErrNotificationsUnsupported) + return + } + + for i := 0; i < 10; i++ { + if err := notifier.Notify(i); err != nil { + panic(err) + } + } + }) + + srv := server.NewServer(r) + handler := codecs.WebsocketHandler(srv, []string{"*"}) + httpSrv := httptest.NewServer(handler) + + wsURL := "ws:" + strings.TrimPrefix(httpSrv.URL, "http:") + cl, err := UpgradeConn(jrpc.Dial(wsURL)) + if err != nil { + t.Error(err) + return + } + + ch := make(chan int) + sub, err := cl.Subscribe(context.Background(), "test", ch, nil) + time.Sleep(time.Second) + if err = sub.Unsubscribe(); err != nil { + t.Error(err) + return + } +} + func TestWrapClient(t *testing.T) { engine := NewEngine() r := jmux.NewRouter() @@ -135,8 +172,7 @@ func TestWrapClient(t *testing.T) { t.Errorf("sub errored: %v", err) } return - case v := <-ch: - log.Printf("%v", v) + case <-ch: } } }()