From 0bea1c40c1d1a0ac3a3740d1293d4a98eb5621c7 Mon Sep 17 00:00:00 2001 From: a <a@a.a> Date: Fri, 21 Oct 2022 23:21:46 -0500 Subject: [PATCH] start on subscriptions --- benchmark_test.go | 3 +- client.go | 35 ++++ example/subscription/main.go | 70 ++++++++ handler.go | 170 +++++++++++++++--- http.go | 15 +- json.go | 13 ++ protocol.go | 49 +++--- service.go | 7 +- subscription.go | 329 +++++++++++++++++++++++++++++++++++ testservice_test.go | 64 ++++++- wsjson/wsjson.go | 7 +- 11 files changed, 711 insertions(+), 51 deletions(-) create mode 100644 example/subscription/main.go create mode 100644 subscription.go diff --git a/benchmark_test.go b/benchmark_test.go index 24402b3..49fa35c 100644 --- a/benchmark_test.go +++ b/benchmark_test.go @@ -67,12 +67,13 @@ func BenchmarkClientWebsocketEcho(b *testing.B) { "on": map[string]any{"two": "three"}, } + payload := []any{1, 2, 3, 4, 56, 6, wantBack, wantBack, wantBack} b.StartTimer() for n := 0; n < b.N; n++ { eg := &errgroup.Group{} for i := 0; i < 1000; i++ { eg.Go(func() error { - return client.Call(nil, "test_echoAny", []any{1, 2, 3, 4, 56, 6, wantBack, wantBack, wantBack}) + return client.Call(nil, "test_echoAny", payload) }) } eg.Wait() diff --git a/client.go b/client.go index e406bed..8a9a145 100644 --- a/client.go +++ b/client.go @@ -121,6 +121,8 @@ type requestOp struct { ids []json.RawMessage err error resp chan *jsonrpcMessage // receives up to len(ids) responses + + sub *ClientSubscription } func (op *requestOp) wait(ctx context.Context, c *Client) (*jsonrpcMessage, error) { @@ -403,6 +405,39 @@ func (c *Client) DoNotify(ctx context.Context, method string, args any) error { return c.send(ctx, op, msg) } +func (c *Client) Subscribe(ctx context.Context, namespace string, channel interface{}, args ...interface{}) (*ClientSubscription, error) { + // Check type of channel first. + chanVal := reflect.ValueOf(channel) + if chanVal.Kind() != reflect.Chan || chanVal.Type().ChanDir()&reflect.SendDir == 0 { + panic("first argument to Subscribe must be a writable channel") + } + if chanVal.IsNil() { + panic("channel given to Subscribe must not be nil") + } + if c.isHTTP { + return nil, ErrNotificationsUnsupported + } + msg, err := c.newMessage(namespace+subscribeMethodSuffix, args...) + if err != nil { + return nil, err + } + op := &requestOp{ + ids: []json.RawMessage{msg.ID.RawMessage()}, + resp: make(chan *jsonrpcMessage), + sub: newClientSubscription(c, namespace, chanVal), + } + + // Send the subscription request. + // The arrival and validity of the response is signaled on sub.quit. + if err := c.send(ctx, op, msg); err != nil { + return nil, err + } + if _, err := op.wait(ctx, c); err != nil { + return nil, err + } + return op.sub, nil +} + func (c *Client) newMessage(method string, paramsIn ...any) (*jsonrpcMessage, error) { msg := &jsonrpcMessage{ID: c.nextID(), Method: method} if paramsIn != nil { // prevent sending "params":null diff --git a/example/subscription/main.go b/example/subscription/main.go new file mode 100644 index 0000000..7f3c5da --- /dev/null +++ b/example/subscription/main.go @@ -0,0 +1,70 @@ +package main + +import ( + "context" + "log" + "net/http" + "time" + + "gfx.cafe/open/jrpc" + "gfx.cafe/open/jrpc/middleware" +) + +func main() { + + r := jrpc.NewRouter() + r.Use(middleware.Logger) + srv := jrpc.NewServer(r) + + r.HandleFunc("echo", func(w jrpc.ResponseWriter, r *jrpc.Request) { + w.Send(r.Params(), nil) + }) + + r.HandleFunc("testservice_subscribe", func(w jrpc.ResponseWriter, r *jrpc.Request) { + sub, err := jrpc.UpgradeToSubscription(w, r) + w.Send(sub, err) + if err != nil { + return + } + go func() { + idx := 0 + for { + log.Println("sending:", idx) + err := w.Notify(idx) + if err != nil { + return + } + time.Sleep(1 * time.Second) + idx = idx + 1 + } + }() + }) + + go func() { + err := client() + if err != nil { + panic(err) + } + }() + log.Println("running on 8855") + log.Println(http.ListenAndServe(":8855", srv.ServeHTTPWithWss(nil))) +} + +func client() error { + cl, err := jrpc.Dial("ws://localhost:8855") + if err != nil { + return err + } + out := make(chan int, 1) + jcs, err := cl.Subscribe(context.TODO(), "testservice", out, "swag") + if err != nil { + return err + } + go func() { + log.Println(<-jcs.Err()) + }() + defer jcs.Unsubscribe() + for { + log.Println("receiving", <-out) + } +} diff --git a/handler.go b/handler.go index 2a69ba8..b1afde1 100644 --- a/handler.go +++ b/handler.go @@ -17,7 +17,12 @@ package jrpc import ( + "bytes" "context" + "encoding/json" + "errors" + "reflect" + "strings" "sync" "git.tuxpa.in/a/zlog" @@ -52,11 +57,17 @@ type handler struct { conn jsonWriter // where responses will be sent log *zlog.Logger + subLock sync.RWMutex + clientSubs map[string]*ClientSubscription // active client subscriptions + serverSubs map[SubID]*Subscription + unsubscribeCb *callback + peer PeerInfo } type callProc struct { - ctx context.Context + ctx context.Context + notifiers []*Notifier } func newHandler(connCtx context.Context, conn jsonWriter, reg Router) *handler { @@ -66,6 +77,8 @@ func newHandler(connCtx context.Context, conn jsonWriter, reg Router) *handler { reg: reg, conn: conn, respWait: make(map[string]*requestOp), + clientSubs: map[string]*ClientSubscription{}, + serverSubs: map[SubID]*Subscription{}, rootCtx: rootCtx, cancelRoot: cancelRoot, log: zlog.Ctx(connCtx), @@ -74,6 +87,7 @@ func newHandler(connCtx context.Context, conn jsonWriter, reg Router) *handler { cl := h.log.With().Str("conn", conn.remoteAddr()).Logger() h.log = &cl } + h.unsubscribeCb = newCallback(reflect.Value{}, reflect.ValueOf(h.unsubscribe)) return h } @@ -104,9 +118,13 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) { answers = append(answers, answer) } } + h.addSubscriptions(cp.notifiers) if len(answers) > 0 { h.conn.writeJSON(cp.ctx, answers) } + for _, n := range cp.notifiers { + n.activate() + } }) } @@ -117,9 +135,13 @@ func (h *handler) handleMsg(msg *jsonrpcMessage) { } h.startCallProc(func(cp *callProc) { answer := h.handleCallMsg(cp, msg) + h.addSubscriptions(cp.notifiers) if answer != nil { h.conn.writeJSON(cp.ctx, answer) } + for _, n := range cp.notifiers { + n.activate() + } }) } @@ -129,6 +151,7 @@ func (h *handler) close(err error, inflightReq *requestOp) { h.cancelAllRequests(err, inflightReq) h.callWG.Wait() h.cancelRoot() + h.cancelServerSubscriptions(err) } // addRequestOp registers a request operation. @@ -180,7 +203,11 @@ func (h *handler) handleImmediate(msg *jsonrpcMessage) bool { start := NewTimer() switch { case msg.isNotification(): - return true + if strings.HasSuffix(msg.Method, notificationMethodSuffix) { + h.handleSubscriptionResult(msg) + return true + } + return false case msg.isResponse(): h.handleResponse(msg) h.log.Trace().Str("reqid", string(msg.ID.RawMessage())).Dur("duration", start.Since(start)).Msg("Handled RPC response") @@ -190,6 +217,17 @@ func (h *handler) handleImmediate(msg *jsonrpcMessage) bool { } } +func (h *handler) handleSubscriptionResult(msg *jsonrpcMessage) { + var result subscriptionResult + if err := json.Unmarshal(msg.Params, &result); err != nil { + h.log.Trace().Msg("Dropping invalid subscription message") + return + } + if h.clientSubs[result.ID] != nil { + h.clientSubs[result.ID].deliver(result.Result) + } +} + // handleResponse processes method call responses. func (h *handler) handleResponse(msg *jsonrpcMessage) { op := h.respWait[string(msg.ID.RawMessage())] @@ -198,7 +236,22 @@ func (h *handler) handleResponse(msg *jsonrpcMessage) { return } delete(h.respWait, string(msg.ID.RawMessage())) - op.resp <- msg + if op.sub == nil { + op.resp <- msg + return + } + // For subscription responses, start the subscription if the server + // indicates success. EthSubscribe gets unblocked in either case through + // the op.resp channel. + defer close(op.resp) + if msg.Error != nil { + op.err = msg.Error + return + } + if op.err = json.Unmarshal(msg.Result, &op.sub.subid); op.err == nil { + go op.sub.start() + h.clientSubs[op.sub.subid] = op.sub + } } // handleCallMsg executes a call message and returns the answer. @@ -218,34 +271,105 @@ func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMess } } +// parseSubscriptionName extracts the subscription name from an encoded argument array. +func parseSubscriptionName(rawArgs json.RawMessage) (string, error) { + dec := json.NewDecoder(bytes.NewReader(rawArgs)) + if tok, _ := dec.Token(); tok != json.Delim('[') { + return "", errors.New("non-array args") + } + v, _ := dec.Token() + method, ok := v.(string) + if !ok { + return "", errors.New("expected subscription name as first argument") + } + return method, nil +} + +// handleSubscribe processes *_subscribe method calls. +func (h *handler) handleSubscribe(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage { + switch h.peer.Transport { + case "http", "https": + return msg.errorResponse(ErrNotificationsUnsupported) + } + + // Subscription method name is first argument. + name, err := parseSubscriptionName(msg.Params) + if err != nil { + return msg.errorResponse(&invalidParamsError{err.Error()}) + } + namespace := msg.namespace() + has := h.reg.Match(NewRouteContext(), msg.Method) + if !has { + return msg.errorResponse(&subscriptionNotFoundError{namespace, name}) + } + // Install notifier in context so the subscription handler can find it. + n := &Notifier{h: h, namespace: namespace, idgen: randomIDGenerator()} + cp.notifiers = append(cp.notifiers, n) + req := &Request{ctx: cp.ctx, msg: *msg, peer: h.peer} + // now actually run the handler + req = req.WithContext( + context.WithValue(req.ctx, notifierKey{}, n), + ) + + mw := NewReaderResponseWriterMsg(req) + h.reg.ServeRPC(mw, req) + + return mw.msg +} + +func (h *handler) unsubscribe(ctx context.Context, id SubID) (bool, error) { + h.subLock.Lock() + defer h.subLock.Unlock() + + s := h.serverSubs[id] + if s == nil { + return false, ErrSubscriptionNotFound + } + close(s.err) + delete(h.serverSubs, id) + return true, nil +} + +// cancelServerSubscriptions removes all subscriptions and closes their error channels. +func (h *handler) cancelServerSubscriptions(err error) { + h.subLock.Lock() + defer h.subLock.Unlock() + + for id, s := range h.serverSubs { + s.err <- err + close(s.err) + delete(h.serverSubs, id) + } +} + +func (h *handler) addSubscriptions(nn []*Notifier) { + h.subLock.Lock() + defer h.subLock.Unlock() + + for _, n := range nn { + if sub := n.takeSubscription(); sub != nil { + h.serverSubs[sub.ID] = sub + } + } +} + func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage { callb := h.reg.Match(NewRouteContext(), msg.Method) + req := &Request{ctx: cp.ctx, msg: *msg, peer: h.peer} + mw := NewReaderResponseWriterMsg(req) + if msg.isSubscribe() { + return h.handleSubscribe(cp, msg) + } + if msg.isUnsubscribe() { + h.unsubscribeCb.ServeRPC(mw, req) + return mw.msg + } // no method found if !callb { return msg.errorResponse(&methodNotFoundError{method: msg.Method}) } - req := &Request{ctx: cp.ctx, msg: *msg, peer: h.peer} - mw := NewReaderResponseWriterMsg(req) // now actually run the handler h.reg.ServeRPC(mw, req) - //TODO: notifications - //if mw.notifications != nil { - // go func() { - // for { - // val, more := <-mw.notifications - // if !more { - // break - // } - // err := h.conn.writeJSON(cp.ctx, val) - // if err != nil { - // if mw.notifications != nil { - // close(mw.notifications) - // } - // log.Println("error in notification", err) - // } - // } - // }() - //} return mw.msg } diff --git a/http.go b/http.go index 996410c..c2cd825 100644 --- a/http.go +++ b/http.go @@ -43,7 +43,7 @@ const ( var acceptedContentTypes = []string{ // https://www.jsonrpc.org/historical/json-rpc-over-http.html#id13 contentType, "application/json-rpc", "application/jsonrequest", - // these are added because they make sense + // these are added because they make sense, fight me! "application/jsonrpc2", "application/json-rpc2", "application/jrpc", } @@ -317,6 +317,19 @@ func isWebsocket(r *http.Request) bool { strings.Contains(strings.ToLower(r.Header.Get("Connection")), "upgrade") } +func (s *Server) ServeHTTPWithWss(cb func(r *http.Request)) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if isWebsocket(r) { + if cb != nil { + cb(r) + } + s.WebsocketHandler([]string{"*"}).ServeHTTP(w, r) + return + } + s.ServeHTTP(w, r) + }) +} + // ServeHTTP serves JSON-RPC requests over HTTP. func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Permit dumb empty requests for remote health-checks (AWS) diff --git a/json.go b/json.go index 79c28bc..1dde4e0 100644 --- a/json.go +++ b/json.go @@ -25,6 +25,7 @@ import ( "io" "reflect" "strconv" + "strings" "sync" "time" @@ -77,6 +78,18 @@ func (msg *jsonrpcMessage) isResponse() bool { func (msg *jsonrpcMessage) hasValidID() bool { return msg.ID != nil && !msg.ID.null } +func (msg *jsonrpcMessage) isSubscribe() bool { + return strings.HasSuffix(msg.Method, subscribeMethodSuffix) +} + +func (msg *jsonrpcMessage) isUnsubscribe() bool { + return strings.HasSuffix(msg.Method, unsubscribeMethodSuffix) +} + +func (msg *jsonrpcMessage) namespace() string { + elem := strings.SplitN(msg.Method, serviceMethodSeparator, 2) + return elem[0] +} func (msg *jsonrpcMessage) String() string { b, _ := jzon.Marshal(msg) diff --git a/protocol.go b/protocol.go index 7c8feed..6ba5e03 100644 --- a/protocol.go +++ b/protocol.go @@ -3,6 +3,7 @@ package jrpc import ( "context" "encoding/json" + "errors" "net/http" ) @@ -15,8 +16,9 @@ type Handler interface { type ResponseWriter interface { Send(v any, err error) error Option(k string, v any) - Notify(v any) error Header() http.Header + + Notify(v any) error } func (fn HandlerFunc) ServeRPC(w ResponseWriter, r *Request) { @@ -104,9 +106,11 @@ func (r *Request) Msg() jsonrpcMessage { } type ResponseWriterMsg struct { - r *Request - msg *jsonrpcMessage - notifications chan *jsonrpcMessage + r *Request + n *Notifier + s *Subscription + + msg *jsonrpcMessage options options } @@ -115,15 +119,19 @@ type options struct { sorted bool } +func UpgradeToSubscription(w ResponseWriter, r *Request) (*Subscription, error) { + not, ok := NotifierFromContext(r.ctx) + if !ok || not == nil { + return nil, errors.New("subscription not supported") + } + return not.CreateSubscription(), nil +} + func NewReaderResponseWriterMsg(r *Request) *ResponseWriterMsg { rw := &ResponseWriterMsg{ r: r, } - switch r.Peer().Transport { - case "http": - default: - rw.notifications = make(chan *jsonrpcMessage, 128) - } + rw.n, _ = NotifierFromContext(r.ctx) return rw } @@ -147,25 +155,24 @@ func (w *ResponseWriterMsg) Send(args any, e error) (err error) { w.msg = cm.errorResponse(e) return nil } - w.msg = cm.response(args) - if w.notifications != nil { - close(w.notifications) + switch c := args.(type) { + case *Subscription: + w.s = c + default: } + w.msg = cm.response(args) w.msg.sortKeys = w.options.sorted return nil } func (w *ResponseWriterMsg) Notify(args any) (err error) { - if w.notifications == nil { - return nil + if w.s == nil || w.n == nil { + return ErrSubscriptionNotFound } - cm := w.r.Msg() - nf := cm.response(args) - nf.ID = nil - nf.sortKeys = w.options.sorted - select { - case w.notifications <- nf: - default: + bts, _ := jzon.Marshal(args) + err = w.n.send(w.s, bts) + if err != nil { + return err } return nil } diff --git a/service.go b/service.go index d062149..59efd0a 100644 --- a/service.go +++ b/service.go @@ -26,6 +26,7 @@ import ( var ( contextType = reflect.TypeOf((*context.Context)(nil)).Elem() errorType = reflect.TypeOf((*error)(nil)).Elem() + stringType = reflect.TypeOf("") ) // A helper function that mimics the behavior of the handlers in the go-ethereum rpc package @@ -117,9 +118,13 @@ func (e *callback) ServeRPC(w ResponseWriter, r *Request) { w.Send(results[0].Interface(), nil) } +func NewCallback(receiver, fn reflect.Value) Handler { + return newCallback(receiver, fn) +} + // newCallback turns fn (a function) into a callback object. It returns nil if the function // is unsuitable as an RPC callback. -func newCallback(receiver, fn reflect.Value) Handler { +func newCallback(receiver, fn reflect.Value) *callback { fntype := fn.Type() c := &callback{fn: fn, rcvr: receiver, errPos: -1} // Determine parameter types. They must all be exported or builtin types. diff --git a/subscription.go b/subscription.go new file mode 100644 index 0000000..fc652b2 --- /dev/null +++ b/subscription.go @@ -0,0 +1,329 @@ +package jrpc + +import ( + "container/list" + "context" + crand "crypto/rand" + "encoding/binary" + "encoding/hex" + "encoding/json" + "errors" + "math/rand" + "reflect" + "strings" + "sync" + "time" +) + +const ( + subscribeMethodSuffix = "_subscribe" + notificationMethodSuffix = "_subscription" + unsubscribeMethodSuffix = "_unsubscribe" + serviceMethodSeparator = "_" + + maxClientSubscriptionBuffer = 12800 +) + +var ( + // ErrNotificationsUnsupported is returned when the connection doesn't support notifications + ErrNotificationsUnsupported = errors.New("notifications not supported") + // ErrNotificationNotFound is returned when the notification for the given id is not found + ErrSubscriptionNotFound = errors.New("subscription not found") +) + +var globalGen = randomIDGenerator() + +type SubID string + +// NewID returns a new, random ID. +func NewID() SubID { + return globalGen() +} + +// randomIDGenerator returns a function generates a random IDs. +func randomIDGenerator() func() SubID { + var buf = make([]byte, 8) + var seed int64 + if _, err := crand.Read(buf); err == nil { + seed = int64(binary.BigEndian.Uint64(buf)) + } else { + seed = int64(time.Now().Nanosecond()) + } + + var ( + mu sync.Mutex + rng = rand.New(rand.NewSource(seed)) // nolint: gosec + ) + return func() SubID { + mu.Lock() + defer mu.Unlock() + id := make([]byte, 16) + rng.Read(id) + return encodeSubID(id) + } +} + +func encodeSubID(b []byte) SubID { + id := hex.EncodeToString(b) + id = strings.TrimLeft(id, "0") + if id == "" { + id = "0" // ID's are RPC quantities, no leading zero's and 0 is 0x0. + } + return SubID("0x" + id) +} + +type subscriptionResult struct { + ID string `json:"subscription"` + Result json.RawMessage `json:"result,omitempty"` +} + +type notifierKey struct{} + +// NotifierFromContext returns the Notifier value stored in ctx, if any. +func NotifierFromContext(ctx context.Context) (*Notifier, bool) { + n, ok := ctx.Value(notifierKey{}).(*Notifier) + return n, ok +} + +// Notifier is tied to a RPC connection that supports subscriptions. +// Server callbacks use the notifier to send notifications. +type Notifier struct { + h *handler + namespace string + + mu sync.Mutex + sub *Subscription + buffer []json.RawMessage + callReturned bool + activated bool + + idgen func() SubID +} + +// CreateSubscription returns a new subscription that is coupled to the +// RPC connection. By default subscriptions are inactive and notifications +// are dropped until the subscription is marked as active. This is done +// by the RPC server after the subscription ID is send to the client. +func (n *Notifier) CreateSubscription() *Subscription { + n.mu.Lock() + defer n.mu.Unlock() + if n.sub != nil { + panic("can't create multiple subscriptions with Notifier") + } else if n.callReturned { + panic("can't create subscription after subscribe call has returned") + } + n.sub = &Subscription{ + ID: n.idgen(), + namespace: n.namespace, + err: make(chan error, 1), + } + return n.sub +} + +// Notify sends a notification to the client with the given data as payload. +// If an error occurs the RPC connection is closed and the error is returned. +func (n *Notifier) Notify(id SubID, data interface{}) error { + enc, err := jzon.Marshal(data) + if err != nil { + return err + } + n.mu.Lock() + defer n.mu.Unlock() + if n.sub == nil { + panic("can't Notify before subscription is created") + } else if n.sub.ID != id { + panic("Notify with wrong ID") + } + if n.activated { + return n.send(n.sub, enc) + } + n.buffer = append(n.buffer, enc) + return nil +} + +// Closed returns a channel that is closed when the RPC connection is closed. +// Deprecated: use subscription error channel +func (n *Notifier) Closed() <-chan interface{} { + return n.h.conn.closed() +} + +// takeSubscription returns the subscription (if one has been created). No subscription can +// be created after this call. +func (n *Notifier) takeSubscription() *Subscription { + n.mu.Lock() + defer n.mu.Unlock() + n.callReturned = true + return n.sub +} + +// activate is called after the subscription ID was sent to client. Notifications are +// buffered before activation. This prevents notifications being sent to the client before +// the subscription ID is sent to the client. +func (n *Notifier) activate() error { + n.mu.Lock() + defer n.mu.Unlock() + + for _, data := range n.buffer { + if err := n.send(n.sub, data); err != nil { + return err + } + } + n.activated = true + return nil +} + +func (n *Notifier) send(sub *Subscription, data json.RawMessage) error { + params, _ := jzon.Marshal(&subscriptionResult{ID: string(sub.ID), Result: data}) + ctx := context.Background() + return n.h.conn.writeJSON(ctx, &jsonrpcMessage{ + Method: n.namespace + notificationMethodSuffix, + Params: params, + }) +} + +// A Subscription is created by a notifier and tied to that notifier. The client can use +// this subscription to wait for an unsubscribe request for the client, see Err(). +type Subscription struct { + ID SubID + namespace string + err chan error // closed on unsubscribe +} + +// Err returns a channel that is closed when the client send an unsubscribe request. +func (s *Subscription) Err() <-chan error { + return s.err +} + +// MarshalJSON marshals a subscription as its ID. +func (s *Subscription) MarshalJSON() ([]byte, error) { + return jzon.Marshal(s.ID) +} + +// ClientSubscription is a subscription established through the Client's Subscribe or +// EthSubscribe methods. +type ClientSubscription struct { + client *Client + etype reflect.Type + channel reflect.Value + namespace string + subid string + in chan json.RawMessage + + quitOnce sync.Once // ensures quit is closed once + quit chan struct{} // quit is closed when the subscription exits + errOnce sync.Once // ensures err is closed once + err chan error +} + +func newClientSubscription(c *Client, namespace string, channel reflect.Value) *ClientSubscription { + sub := &ClientSubscription{ + client: c, + namespace: namespace, + etype: channel.Type().Elem(), + channel: channel, + quit: make(chan struct{}), + err: make(chan error, 1), + in: make(chan json.RawMessage), + } + return sub +} + +// Err returns the subscription error channel. The intended use of Err is to schedule +// resubscription when the client connection is closed unexpectedly. +// +// The error channel receives a value when the subscription has ended due +// to an error. The received error is nil if Close has been called +// on the underlying client and no other error has occurred. +// +// The error channel is closed when Unsubscribe is called on the subscription. +func (sub *ClientSubscription) Err() <-chan error { + return sub.err +} + +// Unsubscribe unsubscribes the notification and closes the error channel. +// It can safely be called more than once. +func (sub *ClientSubscription) Unsubscribe() { + sub.quitWithError(true, nil) + sub.errOnce.Do(func() { close(sub.err) }) +} + +func (sub *ClientSubscription) quitWithError(unsubscribeServer bool, err error) { + sub.quitOnce.Do(func() { + // The dispatch loop won't be able to execute the unsubscribe call + // if it is blocked on deliver. Close sub.quit first because it + // unblocks deliver. + close(sub.quit) + if unsubscribeServer { + sub.requestUnsubscribe() + } + if err != nil { + if err == ErrClientQuit { + err = nil // Adhere to subscription semantics. + } + sub.err <- err + } + }) +} + +func (sub *ClientSubscription) deliver(result json.RawMessage) (ok bool) { + select { + case sub.in <- result: + return true + case <-sub.quit: + return false + } +} + +func (sub *ClientSubscription) start() { + sub.quitWithError(sub.forward()) +} + +func (sub *ClientSubscription) forward() (unsubscribeServer bool, err error) { + cases := []reflect.SelectCase{ + {Dir: reflect.SelectRecv, Chan: reflect.ValueOf(sub.quit)}, + {Dir: reflect.SelectRecv, Chan: reflect.ValueOf(sub.in)}, + {Dir: reflect.SelectSend, Chan: sub.channel}, + } + buffer := list.New() + defer buffer.Init() + for { + var chosen int + var recv reflect.Value + if buffer.Len() == 0 { + // Idle, omit send case. + chosen, recv, _ = reflect.Select(cases[:2]) + } else { + // Non-empty buffer, send the first queued item. + cases[2].Send = reflect.ValueOf(buffer.Front().Value) + chosen, recv, _ = reflect.Select(cases) + } + + switch chosen { + case 0: // <-sub.quit + return false, nil + case 1: // <-sub.in + val, err := sub.unmarshal(recv.Interface().(json.RawMessage)) + if err != nil { + return true, err + } + if buffer.Len() == maxClientSubscriptionBuffer { + return true, ErrSubscriptionQueueOverflow + } + buffer.PushBack(val) + case 2: // sub.channel<- + cases[2].Send = reflect.Value{} // Don't hold onto the value. + buffer.Remove(buffer.Front()) + } + } +} + +func (sub *ClientSubscription) unmarshal(result json.RawMessage) (interface{}, error) { + val := reflect.New(sub.etype) + err := jzon.Unmarshal(result, val.Interface()) + return val.Elem().Interface(), err +} + +func (sub *ClientSubscription) requestUnsubscribe() error { + var result interface{} + return sub.client.Call(&result, sub.namespace+unsubscribeMethodSuffix, sub.subid) +} diff --git a/testservice_test.go b/testservice_test.go index e7c93f1..ac24fca 100644 --- a/testservice_test.go +++ b/testservice_test.go @@ -19,12 +19,29 @@ package jrpc import ( "context" "errors" + "log" "strings" "time" ) func newTestServer() *Server { server := NewServer() + server.Router().HandleFunc("testservice_subscribe", func(w ResponseWriter, r *Request) { + log.Println(r.Params()) + sub, err := UpgradeToSubscription(w, r) + w.Send(sub, err) + if err != nil { + return + } + idx := 0 + for { + err := w.Notify(idx) + if err != nil { + return + } + idx = idx + 1 + } + }) if err := server.Router().RegisterStruct("test", new(testService)); err != nil { panic(err) } @@ -124,7 +141,9 @@ func (s *testService) CallMeBackLater(ctx context.Context, method string, args [ } type notificationTestService struct { - unsubscribed chan string + unsubscribed chan string + gotHangSubscriptionReq chan struct{} + unblockHangSubscription chan struct{} } func (s *notificationTestService) Echo(i int) int { @@ -137,6 +156,49 @@ func (s *notificationTestService) Unsubscribe(subid string) { } } +func (s *notificationTestService) SomeSubscription(ctx context.Context, n, val int) (*Subscription, error) { + notifier, supported := NotifierFromContext(ctx) + if !supported { + return nil, ErrNotificationsUnsupported + } + + // By explicitly creating an subscription we make sure that the subscription id is send + // back to the client before the first subscription.Notify is called. Otherwise the + // events might be send before the response for the *_subscribe method. + subscription := notifier.CreateSubscription() + go func() { + for i := 0; i < n; i++ { + if err := notifier.Notify(subscription.ID, val+i); err != nil { + return + } + } + select { + case <-notifier.Closed(): + case <-subscription.Err(): + } + if s.unsubscribed != nil { + s.unsubscribed <- string(subscription.ID) + } + }() + return subscription, nil +} + +// HangSubscription blocks on s.unblockHangSubscription before sending anything. +func (s *notificationTestService) HangSubscription(ctx context.Context, val int) (*Subscription, error) { + notifier, supported := NotifierFromContext(ctx) + if !supported { + return nil, ErrNotificationsUnsupported + } + s.gotHangSubscriptionReq <- struct{}{} + <-s.unblockHangSubscription + subscription := notifier.CreateSubscription() + + go func() { + notifier.Notify(subscription.ID, val) + }() + return subscription, nil +} + // largeRespService generates arbitrary-size JSON responses. type largeRespService struct { length int diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go index 59a9100..b58a220 100644 --- a/wsjson/wsjson.go +++ b/wsjson/wsjson.go @@ -69,14 +69,15 @@ func Write(ctx context.Context, c *websocket.Conn, v interface{}) error { return write(ctx, c, v) } +var jpool = jsoniter.NewStream(jzon, nil, 0).Pool() + func write(ctx context.Context, c *websocket.Conn, v interface{}) (err error) { w, err := c.Writer(ctx, websocket.MessageText) if err != nil { return err } - // json.Marshal cannot reuse buffers between calls as it has to return - // a copy of the byte slice but Encoder does as it directly writes to w. - st := jsoniter.NewStream(jzon, w, 1024) + st := jpool.BorrowStream(w) + defer jpool.ReturnStream(st) st.WriteVal(v) err = st.Flush() if err != nil { -- GitLab