diff --git a/contrib/codecs/websocket/codec.go b/contrib/codecs/websocket/codec.go index 325c4452120d908ae5cc510f28f8909ae8666572..68c0517632d821a7c3920fd97be547f16c8b71d7 100644 --- a/contrib/codecs/websocket/codec.go +++ b/contrib/codecs/websocket/codec.go @@ -5,13 +5,12 @@ import ( "encoding/json" "io" "net/http" + _ "net/http/pprof" "sync" "time" "gfx.cafe/open/websocket" - _ "net/http/pprof" - "gfx.cafe/open/jrpc/pkg/jjson" "gfx.cafe/open/jrpc/pkg/jsonrpc" "gfx.cafe/open/jrpc/pkg/serverutil" diff --git a/contrib/extension/subscription/client_test.go b/contrib/extension/subscription/client_test.go index 9e7a58e8a10d41a4594ad401a27c748af99048d4..4bf7f100afd5cac6b5b77c8d2a9e5d23d2ff20f0 100644 --- a/contrib/extension/subscription/client_test.go +++ b/contrib/extension/subscription/client_test.go @@ -5,6 +5,7 @@ import ( "log" "net" "net/http" + _ "net/http/pprof" "testing" "time" @@ -15,6 +16,70 @@ import ( "gfx.cafe/open/jrpc/pkg/server" ) +func TestSubscription(t *testing.T) { + go func() { + t.Error(http.ListenAndServe(":6060", nil)) + }() + + const count = 100 + + engine := NewEngine() + r := jmux.NewRouter() + r.Use(engine.Middleware()) + r.HandleFunc("test/subscribe", func(w jsonrpc.ResponseWriter, r *jsonrpc.Request) { + notifier, ok := NotifierFromContext(r.Context()) + if !ok { + _ = w.Send(nil, ErrNotificationsUnsupported) + return + } + + for i := 0; i < count; i++ { + if err := notifier.Notify(i); err != nil { + panic(err) + } + } + }) + + srv := server.NewServer(r) + handler := codecs.WebsocketHandler(srv, []string{"*"}) + httpSrv := http.Server{ + Addr: ":8855", + Handler: handler, + } + listener, err := net.Listen("tcp", ":8855") + if err != nil { + t.Error(err) + return + } + go func() { + if err := httpSrv.Serve(listener); err != nil { + t.Error(err) + return + } + }() + + cl, err := UpgradeConn(jrpc.Dial("ws://localhost:8855")) + if err != nil { + t.Error(err) + return + } + + ch := make(chan int, count) + sub, err := cl.Subscribe(context.Background(), "test", ch, nil) + defer func() { + if err = sub.Unsubscribe(); err != nil { + t.Error(err) + } + }() + + for i := 0; i < count; i++ { + v := <-ch + if v != i { + t.Errorf("expected %d but got %d", i, v) + } + } +} + func TestWrapClient(t *testing.T) { engine := NewEngine() r := jmux.NewRouter() diff --git a/pkg/jsonrpc/message.go b/pkg/jsonrpc/message.go index f6b96a398503739dd5a7fe006aaee8307e2f7a83..d23de9cd133b2b07c3b7e8f1104f20a67d1049eb 100644 --- a/pkg/jsonrpc/message.go +++ b/pkg/jsonrpc/message.go @@ -104,6 +104,15 @@ func (m *MessageWriter) Result() (io.Writer, error) { return &ResultWriter{w: m.w}, nil } +// Params returns a writer that writes to a params field +func (m *MessageWriter) Params() (io.Writer, error) { + _, err := m.w.Write([]byte(`,"params":`)) + if err != nil { + return nil, err + } + return &ResultWriter{w: m.w}, nil +} + type BatchWriter struct { w io.Writer mu *semaphore.Weighted diff --git a/pkg/server/server.go b/pkg/server/server.go index 0fbbf71c9dbadb88defc9af6f38f150266bfeed8..04aedf1fde12c62cf60e3b17c50c9c2d396eec24 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -325,7 +325,7 @@ func (c *callResponder) notify(env *notifyEnv, s *jsonrpc.MessageWriter) (err er return err } // if there is no error, we try to marshal the result - wr, err := s.Result() + wr, err := s.Params() if err != nil { return err }