diff --git a/contrib/extension/subscription/client.go b/contrib/extension/subscription/client.go index 1ed1c27dab8ede40c76e679961790f7e7d0fde5d..3290126e3f4a13a848670761d1d0b49f4c32c052 100644 --- a/contrib/extension/subscription/client.go +++ b/contrib/extension/subscription/client.go @@ -50,7 +50,11 @@ func (c *WrapClient) Middleware(h jsonrpc.Handler) jsonrpc.Handler { clientSub, ok := c.subs[params.ID] c.mu.Unlock() if ok { - clientSub.onmsg <- params.Result + // this could deadlock if we waited on onmsg and the sub was done + select { + case clientSub.onmsg <- params.Result: + case <-clientSub.subdone: + } } }) } @@ -84,24 +88,29 @@ func (c *WrapClient) Subscribe(ctx context.Context, namespace string, channel an namespace: namespace, id: result, channel: chanVal, - // BUG: a worse is better solution... it means that when this fills, you might receive subscriptions in an undefined error - onmsg: make(chan json.RawMessage, 32), - subdone: make(chan struct{}), - readErr: make(chan error), + onmsg: make(chan json.RawMessage), + subdone: make(chan struct{}), + readErr: make(chan error), } // will get the type of the event etype := chanVal.Type().Elem() go func() { + defer func() { + // close if possible + if sub.done.CompareAndSwap(false, true) { + close(sub.subdone) + } + // we're done reading + close(sub.readErr) + }() for { select { case <-sub.subdone: - // sub is done, so close readErr - close(sub.readErr) + return case params, ok := <-sub.onmsg: if !ok { - close(sub.readErr) return } val := reflect.New(etype) @@ -111,9 +120,23 @@ func (c *WrapClient) Subscribe(ctx context.Context, namespace string, channel an return } // and now send the elem - sub.channel.Send(val.Elem()) + // this could deadlock if the client stopped waiting on the chan and unsubscribed + reflect.Select([]reflect.SelectCase{ + { + Dir: reflect.SelectSend, + Chan: sub.channel, + Send: val.Elem(), + }, + { + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(ctx.Done()), + }, + { + Dir: reflect.SelectRecv, + Chan: reflect.ValueOf(sub.subdone), + }, + }) case <-ctx.Done(): - close(sub.readErr) return } } diff --git a/contrib/extension/subscription/client_test.go b/contrib/extension/subscription/client_test.go new file mode 100644 index 0000000000000000000000000000000000000000..9e7a58e8a10d41a4594ad401a27c748af99048d4 --- /dev/null +++ b/contrib/extension/subscription/client_test.go @@ -0,0 +1,109 @@ +package subscription + +import ( + "context" + "log" + "net" + "net/http" + "testing" + "time" + + "gfx.cafe/open/jrpc" + "gfx.cafe/open/jrpc/contrib/codecs" + "gfx.cafe/open/jrpc/contrib/jmux" + "gfx.cafe/open/jrpc/pkg/jsonrpc" + "gfx.cafe/open/jrpc/pkg/server" +) + +func TestWrapClient(t *testing.T) { + engine := NewEngine() + r := jmux.NewRouter() + r.Use(engine.Middleware()) + r.HandleFunc("echo", func(w jsonrpc.ResponseWriter, r *jsonrpc.Request) { + _ = w.Send(r.Params, nil) + }) + // extremely fast subscription to fill buffers to get a higher chance that we receive another message while trying + // to unsubscribe + r.HandleFunc("test/subscribe", func(w jsonrpc.ResponseWriter, r *jsonrpc.Request) { + notifier, ok := NotifierFromContext(r.Context()) + if !ok { + _ = w.Send(nil, ErrNotificationsUnsupported) + return + } + go func() { + idx := 0 + for { + select { + case <-r.Context().Done(): + return + case <-notifier.Err(): + return + default: + } + _ = notifier.Notify(idx) + idx += 1 + } + }() + }) + srv := server.NewServer(r) + handler := codecs.WebsocketHandler(srv, []string{"*"}) + httpSrv := http.Server{ + Addr: ":8855", + Handler: handler, + } + listener, err := net.Listen("tcp", ":8855") + if err != nil { + t.Error(err) + return + } + go func() { + if err := httpSrv.Serve(listener); err != nil { + t.Error(err) + return + } + }() + + cl, err := UpgradeConn(jrpc.Dial("ws://localhost:8855")) + if err != nil { + t.Error(err) + return + } + + for i := 0; i < 10; i++ { + var res string + if err = cl.Do(context.Background(), &res, "echo", "test"); err != nil { + t.Error(err) + return + } + if res != "test" { + t.Errorf(`expected "test" but got %#v`, res) + return + } + + ch := make(chan int, 1) + sub, err := cl.Subscribe(context.Background(), "test", ch, nil) + if err != nil { + t.Error(err) + return + } + + go func() { + time.Sleep(2 * time.Second) + _ = sub.Unsubscribe() + }() + + func() { + for { + select { + case err, ok := <-sub.Err(): + if ok { + t.Errorf("sub errored: %v", err) + } + return + case v := <-ch: + log.Printf("%v", v) + } + } + }() + } +}