diff --git a/contrib/extension/subscription/client.go b/contrib/extension/subscription/client.go index f329e005d42ff50d6f629d7991a203e5172d7150..ff3934df7093ee0bd9eddf3ff7a2afd4199e1ab1 100644 --- a/contrib/extension/subscription/client.go +++ b/contrib/extension/subscription/client.go @@ -90,6 +90,17 @@ func (c *WrapClient) Subscribe(ctx context.Context, namespace string, channel an readErr: make(chan error, 1), } + go func() { + defer func() { + _ = sub.Unsubscribe() + }() + select { + case <-c.Closed(): + case <-ctx.Done(): + sub.err(ctx.Err()) + } + }() + c.mu.Lock() c.subs[sub.id] = sub c.mu.Unlock() diff --git a/contrib/extension/subscription/client_test.go b/contrib/extension/subscription/client_test.go index cc7a8eb778e005f8d513e1d23e1e3ff027bcdb15..b63d6e2fbf5ca5a9380927adbfb69cd622dedbc8 100644 --- a/contrib/extension/subscription/client_test.go +++ b/contrib/extension/subscription/client_test.go @@ -2,6 +2,7 @@ package subscription import ( "context" + "encoding/json" "net/http/httptest" _ "net/http/pprof" "strings" @@ -15,49 +16,78 @@ import ( "gfx.cafe/open/jrpc/pkg/server" ) -func TestSubscription(t *testing.T) { - const count = 100 - +func newRouter(t *testing.T) jmux.Router { engine := NewEngine() r := jmux.NewRouter() r.Use(engine.Middleware()) + r.HandleFunc("echo", func(w jsonrpc.ResponseWriter, r *jsonrpc.Request) { + err := w.Send(r.Params, nil) + if err != nil { + t.Error(err) + } + }) + // extremely fast subscription to fill buffers to get a higher chance that we receive another message while trying + // to unsubscribe r.HandleFunc("test/subscribe", func(w jsonrpc.ResponseWriter, r *jsonrpc.Request) { notifier, ok := NotifierFromContext(r.Context()) if !ok { - _ = w.Send(nil, ErrNotificationsUnsupported) + err := w.Send(nil, ErrNotificationsUnsupported) + if err != nil { + t.Error(err) + } return } - + var count int + _ = json.Unmarshal(r.Params, &count) go func() { time.Sleep(10 * time.Millisecond) - for i := 0; i < count; i++ { - if err := notifier.Notify(i); err != nil { - panic(err) + for idx := 0; count == 0 || idx < count; idx++ { + select { + case <-r.Context().Done(): + return + case <-notifier.Err(): + return + default: + } + err := notifier.Notify(idx) + if err != nil { + t.Error(err) } } }() }) + return r +} + +func newServer(t *testing.T) (Conn, func()) { + r := newRouter(t) srv := server.NewServer(r) - defer srv.Shutdown(context.Background()) handler := codecs.WebsocketHandler(srv, []string{"*"}) httpSrv := httptest.NewServer(handler) - defer httpSrv.Close() wsURL := "ws:" + strings.TrimPrefix(httpSrv.URL, "http:") cl, err := UpgradeConn(jrpc.Dial(wsURL)) if err != nil { t.Error(err) - return + return nil, nil } - defer func() { - if err = cl.Close(); err != nil { - t.Error(err) - } - }() + + return cl, func() { + _ = cl.Close() + httpSrv.Close() + srv.Shutdown(context.Background()) + } +} + +func TestSubscription(t *testing.T) { + const count = 100 + + cl, done := newServer(t) + defer done() ch := make(chan int, count) - sub, err := cl.Subscribe(context.Background(), "test", ch, nil) + sub, err := cl.Subscribe(context.Background(), "test", ch, count) defer func() { if err = sub.Unsubscribe(); err != nil { t.Error(err) @@ -73,46 +103,11 @@ func TestSubscription(t *testing.T) { } func TestUnsubscribeNoRead(t *testing.T) { - engine := NewEngine() - r := jmux.NewRouter() - r.Use(engine.Middleware()) - r.HandleFunc("test/subscribe", func(w jsonrpc.ResponseWriter, r *jsonrpc.Request) { - notifier, ok := NotifierFromContext(r.Context()) - if !ok { - _ = w.Send(nil, ErrNotificationsUnsupported) - return - } - - go func() { - time.Sleep(10 * time.Millisecond) - for i := 0; i < 10; i++ { - if err := notifier.Notify(i); err != nil { - panic(err) - } - } - }() - }) - - srv := server.NewServer(r) - defer srv.Shutdown(context.Background()) - handler := codecs.WebsocketHandler(srv, []string{"*"}) - httpSrv := httptest.NewServer(handler) - defer httpSrv.Close() - - wsURL := "ws:" + strings.TrimPrefix(httpSrv.URL, "http:") - cl, err := UpgradeConn(jrpc.Dial(wsURL)) - if err != nil { - t.Error(err) - return - } - defer func() { - if err = cl.Close(); err != nil { - t.Error(err) - } - }() + cl, done := newServer(t) + defer done() ch := make(chan int) - sub, err := cl.Subscribe(context.Background(), "test", ch, nil) + sub, err := cl.Subscribe(context.Background(), "test", ch, 10) time.Sleep(time.Second) if err = sub.Unsubscribe(); err != nil { t.Error(err) @@ -121,66 +116,12 @@ func TestUnsubscribeNoRead(t *testing.T) { } func TestWrapClient(t *testing.T) { - engine := NewEngine() - r := jmux.NewRouter() - r.Use(engine.Middleware()) - r.HandleFunc("echo", func(w jsonrpc.ResponseWriter, r *jsonrpc.Request) { - err := w.Send(r.Params, nil) - if err != nil { - t.Error(err) - } - }) - // extremely fast subscription to fill buffers to get a higher chance that we receive another message while trying - // to unsubscribe - r.HandleFunc("test/subscribe", func(w jsonrpc.ResponseWriter, r *jsonrpc.Request) { - notifier, ok := NotifierFromContext(r.Context()) - if !ok { - err := w.Send(nil, ErrNotificationsUnsupported) - if err != nil { - t.Error(err) - } - return - } - go func() { - time.Sleep(10 * time.Millisecond) - idx := 0 - for { - select { - case <-r.Context().Done(): - return - case <-notifier.Err(): - return - default: - } - err := notifier.Notify(idx) - if err != nil { - t.Error(err) - } - idx += 1 - } - }() - }) - srv := server.NewServer(r) - defer srv.Shutdown(context.Background()) - handler := codecs.WebsocketHandler(srv, []string{"*"}) - httpSrv := httptest.NewServer(handler) - defer httpSrv.Close() - - wsURL := "ws:" + strings.TrimPrefix(httpSrv.URL, "http:") - cl, err := UpgradeConn(jrpc.Dial(wsURL)) - if err != nil { - t.Error(err) - return - } - defer func() { - if err = cl.Close(); err != nil { - t.Error(err) - } - }() + cl, done := newServer(t) + defer done() for i := 0; i < 10; i++ { var res string - if err = cl.Do(context.Background(), &res, "echo", "test"); err != nil { + if err := cl.Do(context.Background(), &res, "echo", "test"); err != nil { t.Error(err) return } @@ -190,8 +131,7 @@ func TestWrapClient(t *testing.T) { } ch := make(chan int, 101) - var sub ClientSubscription - sub, err = cl.Subscribe(context.Background(), "test", ch, nil) + sub, err := cl.Subscribe(context.Background(), "test", ch, nil) if err != nil { t.Error(err) return @@ -220,3 +160,35 @@ func TestWrapClient(t *testing.T) { }() } } + +func TestCloseClient(t *testing.T) { + cl, done := newServer(t) + defer done() + + ch := make(chan int) + sub, err := cl.Subscribe(context.Background(), "test", ch, nil) + if err != nil { + t.Error(err) + return + } + + go func() { + if err := cl.Close(); err != nil { + t.Error(err) + } + }() + + for { + select { + case err, ok := <-sub.Err(): + if ok { + t.Errorf("sub errored: %v", err) + } + return + case _, ok := <-ch: + if !ok { + return + } + } + } +}