From 0bea1c40c1d1a0ac3a3740d1293d4a98eb5621c7 Mon Sep 17 00:00:00 2001
From: a <a@a.a>
Date: Fri, 21 Oct 2022 23:21:46 -0500
Subject: [PATCH] start on subscriptions

---
 benchmark_test.go            |   3 +-
 client.go                    |  35 ++++
 example/subscription/main.go |  70 ++++++++
 handler.go                   | 170 +++++++++++++++---
 http.go                      |  15 +-
 json.go                      |  13 ++
 protocol.go                  |  49 +++---
 service.go                   |   7 +-
 subscription.go              | 329 +++++++++++++++++++++++++++++++++++
 testservice_test.go          |  64 ++++++-
 wsjson/wsjson.go             |   7 +-
 11 files changed, 711 insertions(+), 51 deletions(-)
 create mode 100644 example/subscription/main.go
 create mode 100644 subscription.go

diff --git a/benchmark_test.go b/benchmark_test.go
index 24402b3..49fa35c 100644
--- a/benchmark_test.go
+++ b/benchmark_test.go
@@ -67,12 +67,13 @@ func BenchmarkClientWebsocketEcho(b *testing.B) {
 		"on":  map[string]any{"two": "three"},
 	}
 
+	payload := []any{1, 2, 3, 4, 56, 6, wantBack, wantBack, wantBack}
 	b.StartTimer()
 	for n := 0; n < b.N; n++ {
 		eg := &errgroup.Group{}
 		for i := 0; i < 1000; i++ {
 			eg.Go(func() error {
-				return client.Call(nil, "test_echoAny", []any{1, 2, 3, 4, 56, 6, wantBack, wantBack, wantBack})
+				return client.Call(nil, "test_echoAny", payload)
 			})
 		}
 		eg.Wait()
diff --git a/client.go b/client.go
index e406bed..8a9a145 100644
--- a/client.go
+++ b/client.go
@@ -121,6 +121,8 @@ type requestOp struct {
 	ids  []json.RawMessage
 	err  error
 	resp chan *jsonrpcMessage // receives up to len(ids) responses
+
+	sub *ClientSubscription
 }
 
 func (op *requestOp) wait(ctx context.Context, c *Client) (*jsonrpcMessage, error) {
@@ -403,6 +405,39 @@ func (c *Client) DoNotify(ctx context.Context, method string, args any) error {
 	return c.send(ctx, op, msg)
 }
 
+func (c *Client) Subscribe(ctx context.Context, namespace string, channel interface{}, args ...interface{}) (*ClientSubscription, error) {
+	// Check type of channel first.
+	chanVal := reflect.ValueOf(channel)
+	if chanVal.Kind() != reflect.Chan || chanVal.Type().ChanDir()&reflect.SendDir == 0 {
+		panic("first argument to Subscribe must be a writable channel")
+	}
+	if chanVal.IsNil() {
+		panic("channel given to Subscribe must not be nil")
+	}
+	if c.isHTTP {
+		return nil, ErrNotificationsUnsupported
+	}
+	msg, err := c.newMessage(namespace+subscribeMethodSuffix, args...)
+	if err != nil {
+		return nil, err
+	}
+	op := &requestOp{
+		ids:  []json.RawMessage{msg.ID.RawMessage()},
+		resp: make(chan *jsonrpcMessage),
+		sub:  newClientSubscription(c, namespace, chanVal),
+	}
+
+	// Send the subscription request.
+	// The arrival and validity of the response is signaled on sub.quit.
+	if err := c.send(ctx, op, msg); err != nil {
+		return nil, err
+	}
+	if _, err := op.wait(ctx, c); err != nil {
+		return nil, err
+	}
+	return op.sub, nil
+}
+
 func (c *Client) newMessage(method string, paramsIn ...any) (*jsonrpcMessage, error) {
 	msg := &jsonrpcMessage{ID: c.nextID(), Method: method}
 	if paramsIn != nil { // prevent sending "params":null
diff --git a/example/subscription/main.go b/example/subscription/main.go
new file mode 100644
index 0000000..7f3c5da
--- /dev/null
+++ b/example/subscription/main.go
@@ -0,0 +1,70 @@
+package main
+
+import (
+	"context"
+	"log"
+	"net/http"
+	"time"
+
+	"gfx.cafe/open/jrpc"
+	"gfx.cafe/open/jrpc/middleware"
+)
+
+func main() {
+
+	r := jrpc.NewRouter()
+	r.Use(middleware.Logger)
+	srv := jrpc.NewServer(r)
+
+	r.HandleFunc("echo", func(w jrpc.ResponseWriter, r *jrpc.Request) {
+		w.Send(r.Params(), nil)
+	})
+
+	r.HandleFunc("testservice_subscribe", func(w jrpc.ResponseWriter, r *jrpc.Request) {
+		sub, err := jrpc.UpgradeToSubscription(w, r)
+		w.Send(sub, err)
+		if err != nil {
+			return
+		}
+		go func() {
+			idx := 0
+			for {
+				log.Println("sending:", idx)
+				err := w.Notify(idx)
+				if err != nil {
+					return
+				}
+				time.Sleep(1 * time.Second)
+				idx = idx + 1
+			}
+		}()
+	})
+
+	go func() {
+		err := client()
+		if err != nil {
+			panic(err)
+		}
+	}()
+	log.Println("running on 8855")
+	log.Println(http.ListenAndServe(":8855", srv.ServeHTTPWithWss(nil)))
+}
+
+func client() error {
+	cl, err := jrpc.Dial("ws://localhost:8855")
+	if err != nil {
+		return err
+	}
+	out := make(chan int, 1)
+	jcs, err := cl.Subscribe(context.TODO(), "testservice", out, "swag")
+	if err != nil {
+		return err
+	}
+	go func() {
+		log.Println(<-jcs.Err())
+	}()
+	defer jcs.Unsubscribe()
+	for {
+		log.Println("receiving", <-out)
+	}
+}
diff --git a/handler.go b/handler.go
index 2a69ba8..b1afde1 100644
--- a/handler.go
+++ b/handler.go
@@ -17,7 +17,12 @@
 package jrpc
 
 import (
+	"bytes"
 	"context"
+	"encoding/json"
+	"errors"
+	"reflect"
+	"strings"
 	"sync"
 
 	"git.tuxpa.in/a/zlog"
@@ -52,11 +57,17 @@ type handler struct {
 	conn       jsonWriter            // where responses will be sent
 	log        *zlog.Logger
 
+	subLock       sync.RWMutex
+	clientSubs    map[string]*ClientSubscription // active client subscriptions
+	serverSubs    map[SubID]*Subscription
+	unsubscribeCb *callback
+
 	peer PeerInfo
 }
 
 type callProc struct {
-	ctx context.Context
+	ctx       context.Context
+	notifiers []*Notifier
 }
 
 func newHandler(connCtx context.Context, conn jsonWriter, reg Router) *handler {
@@ -66,6 +77,8 @@ func newHandler(connCtx context.Context, conn jsonWriter, reg Router) *handler {
 		reg:        reg,
 		conn:       conn,
 		respWait:   make(map[string]*requestOp),
+		clientSubs: map[string]*ClientSubscription{},
+		serverSubs: map[SubID]*Subscription{},
 		rootCtx:    rootCtx,
 		cancelRoot: cancelRoot,
 		log:        zlog.Ctx(connCtx),
@@ -74,6 +87,7 @@ func newHandler(connCtx context.Context, conn jsonWriter, reg Router) *handler {
 		cl := h.log.With().Str("conn", conn.remoteAddr()).Logger()
 		h.log = &cl
 	}
+	h.unsubscribeCb = newCallback(reflect.Value{}, reflect.ValueOf(h.unsubscribe))
 	return h
 }
 
@@ -104,9 +118,13 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
 				answers = append(answers, answer)
 			}
 		}
+		h.addSubscriptions(cp.notifiers)
 		if len(answers) > 0 {
 			h.conn.writeJSON(cp.ctx, answers)
 		}
+		for _, n := range cp.notifiers {
+			n.activate()
+		}
 	})
 }
 
@@ -117,9 +135,13 @@ func (h *handler) handleMsg(msg *jsonrpcMessage) {
 	}
 	h.startCallProc(func(cp *callProc) {
 		answer := h.handleCallMsg(cp, msg)
+		h.addSubscriptions(cp.notifiers)
 		if answer != nil {
 			h.conn.writeJSON(cp.ctx, answer)
 		}
+		for _, n := range cp.notifiers {
+			n.activate()
+		}
 	})
 }
 
@@ -129,6 +151,7 @@ func (h *handler) close(err error, inflightReq *requestOp) {
 	h.cancelAllRequests(err, inflightReq)
 	h.callWG.Wait()
 	h.cancelRoot()
+	h.cancelServerSubscriptions(err)
 }
 
 // addRequestOp registers a request operation.
@@ -180,7 +203,11 @@ func (h *handler) handleImmediate(msg *jsonrpcMessage) bool {
 	start := NewTimer()
 	switch {
 	case msg.isNotification():
-		return true
+		if strings.HasSuffix(msg.Method, notificationMethodSuffix) {
+			h.handleSubscriptionResult(msg)
+			return true
+		}
+		return false
 	case msg.isResponse():
 		h.handleResponse(msg)
 		h.log.Trace().Str("reqid", string(msg.ID.RawMessage())).Dur("duration", start.Since(start)).Msg("Handled RPC response")
@@ -190,6 +217,17 @@ func (h *handler) handleImmediate(msg *jsonrpcMessage) bool {
 	}
 }
 
+func (h *handler) handleSubscriptionResult(msg *jsonrpcMessage) {
+	var result subscriptionResult
+	if err := json.Unmarshal(msg.Params, &result); err != nil {
+		h.log.Trace().Msg("Dropping invalid subscription message")
+		return
+	}
+	if h.clientSubs[result.ID] != nil {
+		h.clientSubs[result.ID].deliver(result.Result)
+	}
+}
+
 // handleResponse processes method call responses.
 func (h *handler) handleResponse(msg *jsonrpcMessage) {
 	op := h.respWait[string(msg.ID.RawMessage())]
@@ -198,7 +236,22 @@ func (h *handler) handleResponse(msg *jsonrpcMessage) {
 		return
 	}
 	delete(h.respWait, string(msg.ID.RawMessage()))
-	op.resp <- msg
+	if op.sub == nil {
+		op.resp <- msg
+		return
+	}
+	// For subscription responses, start the subscription if the server
+	// indicates success. EthSubscribe gets unblocked in either case through
+	// the op.resp channel.
+	defer close(op.resp)
+	if msg.Error != nil {
+		op.err = msg.Error
+		return
+	}
+	if op.err = json.Unmarshal(msg.Result, &op.sub.subid); op.err == nil {
+		go op.sub.start()
+		h.clientSubs[op.sub.subid] = op.sub
+	}
 }
 
 // handleCallMsg executes a call message and returns the answer.
@@ -218,34 +271,105 @@ func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMess
 	}
 }
 
+// parseSubscriptionName extracts the subscription name from an encoded argument array.
+func parseSubscriptionName(rawArgs json.RawMessage) (string, error) {
+	dec := json.NewDecoder(bytes.NewReader(rawArgs))
+	if tok, _ := dec.Token(); tok != json.Delim('[') {
+		return "", errors.New("non-array args")
+	}
+	v, _ := dec.Token()
+	method, ok := v.(string)
+	if !ok {
+		return "", errors.New("expected subscription name as first argument")
+	}
+	return method, nil
+}
+
+// handleSubscribe processes *_subscribe method calls.
+func (h *handler) handleSubscribe(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage {
+	switch h.peer.Transport {
+	case "http", "https":
+		return msg.errorResponse(ErrNotificationsUnsupported)
+	}
+
+	// Subscription method name is first argument.
+	name, err := parseSubscriptionName(msg.Params)
+	if err != nil {
+		return msg.errorResponse(&invalidParamsError{err.Error()})
+	}
+	namespace := msg.namespace()
+	has := h.reg.Match(NewRouteContext(), msg.Method)
+	if !has {
+		return msg.errorResponse(&subscriptionNotFoundError{namespace, name})
+	}
+	// Install notifier in context so the subscription handler can find it.
+	n := &Notifier{h: h, namespace: namespace, idgen: randomIDGenerator()}
+	cp.notifiers = append(cp.notifiers, n)
+	req := &Request{ctx: cp.ctx, msg: *msg, peer: h.peer}
+	// now actually run the handler
+	req = req.WithContext(
+		context.WithValue(req.ctx, notifierKey{}, n),
+	)
+
+	mw := NewReaderResponseWriterMsg(req)
+	h.reg.ServeRPC(mw, req)
+
+	return mw.msg
+}
+
+func (h *handler) unsubscribe(ctx context.Context, id SubID) (bool, error) {
+	h.subLock.Lock()
+	defer h.subLock.Unlock()
+
+	s := h.serverSubs[id]
+	if s == nil {
+		return false, ErrSubscriptionNotFound
+	}
+	close(s.err)
+	delete(h.serverSubs, id)
+	return true, nil
+}
+
+// cancelServerSubscriptions removes all subscriptions and closes their error channels.
+func (h *handler) cancelServerSubscriptions(err error) {
+	h.subLock.Lock()
+	defer h.subLock.Unlock()
+
+	for id, s := range h.serverSubs {
+		s.err <- err
+		close(s.err)
+		delete(h.serverSubs, id)
+	}
+}
+
+func (h *handler) addSubscriptions(nn []*Notifier) {
+	h.subLock.Lock()
+	defer h.subLock.Unlock()
+
+	for _, n := range nn {
+		if sub := n.takeSubscription(); sub != nil {
+			h.serverSubs[sub.ID] = sub
+		}
+	}
+}
+
 func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage {
 	callb := h.reg.Match(NewRouteContext(), msg.Method)
+	req := &Request{ctx: cp.ctx, msg: *msg, peer: h.peer}
+	mw := NewReaderResponseWriterMsg(req)
+	if msg.isSubscribe() {
+		return h.handleSubscribe(cp, msg)
+	}
+	if msg.isUnsubscribe() {
+		h.unsubscribeCb.ServeRPC(mw, req)
+		return mw.msg
+	}
 	// no method found
 	if !callb {
 		return msg.errorResponse(&methodNotFoundError{method: msg.Method})
 	}
-	req := &Request{ctx: cp.ctx, msg: *msg, peer: h.peer}
-	mw := NewReaderResponseWriterMsg(req)
 	// now actually run the handler
 	h.reg.ServeRPC(mw, req)
 
-	//TODO: notifications
-	//if mw.notifications != nil {
-	//	go func() {
-	//		for {
-	//			val, more := <-mw.notifications
-	//			if !more {
-	//				break
-	//			}
-	//			err := h.conn.writeJSON(cp.ctx, val)
-	//			if err != nil {
-	//				if mw.notifications != nil {
-	//					close(mw.notifications)
-	//				}
-	//				log.Println("error in notification", err)
-	//			}
-	//		}
-	//	}()
-	//}
 	return mw.msg
 }
diff --git a/http.go b/http.go
index 996410c..c2cd825 100644
--- a/http.go
+++ b/http.go
@@ -43,7 +43,7 @@ const (
 var acceptedContentTypes = []string{
 	// https://www.jsonrpc.org/historical/json-rpc-over-http.html#id13
 	contentType, "application/json-rpc", "application/jsonrequest",
-	// these are added because they make sense
+	// these are added because they make sense, fight me!
 	"application/jsonrpc2", "application/json-rpc2", "application/jrpc",
 }
 
@@ -317,6 +317,19 @@ func isWebsocket(r *http.Request) bool {
 		strings.Contains(strings.ToLower(r.Header.Get("Connection")), "upgrade")
 }
 
+func (s *Server) ServeHTTPWithWss(cb func(r *http.Request)) http.Handler {
+	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		if isWebsocket(r) {
+			if cb != nil {
+				cb(r)
+			}
+			s.WebsocketHandler([]string{"*"}).ServeHTTP(w, r)
+			return
+		}
+		s.ServeHTTP(w, r)
+	})
+}
+
 // ServeHTTP serves JSON-RPC requests over HTTP.
 func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	// Permit dumb empty requests for remote health-checks (AWS)
diff --git a/json.go b/json.go
index 79c28bc..1dde4e0 100644
--- a/json.go
+++ b/json.go
@@ -25,6 +25,7 @@ import (
 	"io"
 	"reflect"
 	"strconv"
+	"strings"
 	"sync"
 	"time"
 
@@ -77,6 +78,18 @@ func (msg *jsonrpcMessage) isResponse() bool {
 func (msg *jsonrpcMessage) hasValidID() bool {
 	return msg.ID != nil && !msg.ID.null
 }
+func (msg *jsonrpcMessage) isSubscribe() bool {
+	return strings.HasSuffix(msg.Method, subscribeMethodSuffix)
+}
+
+func (msg *jsonrpcMessage) isUnsubscribe() bool {
+	return strings.HasSuffix(msg.Method, unsubscribeMethodSuffix)
+}
+
+func (msg *jsonrpcMessage) namespace() string {
+	elem := strings.SplitN(msg.Method, serviceMethodSeparator, 2)
+	return elem[0]
+}
 
 func (msg *jsonrpcMessage) String() string {
 	b, _ := jzon.Marshal(msg)
diff --git a/protocol.go b/protocol.go
index 7c8feed..6ba5e03 100644
--- a/protocol.go
+++ b/protocol.go
@@ -3,6 +3,7 @@ package jrpc
 import (
 	"context"
 	"encoding/json"
+	"errors"
 	"net/http"
 )
 
@@ -15,8 +16,9 @@ type Handler interface {
 type ResponseWriter interface {
 	Send(v any, err error) error
 	Option(k string, v any)
-	Notify(v any) error
 	Header() http.Header
+
+	Notify(v any) error
 }
 
 func (fn HandlerFunc) ServeRPC(w ResponseWriter, r *Request) {
@@ -104,9 +106,11 @@ func (r *Request) Msg() jsonrpcMessage {
 }
 
 type ResponseWriterMsg struct {
-	r             *Request
-	msg           *jsonrpcMessage
-	notifications chan *jsonrpcMessage
+	r *Request
+	n *Notifier
+	s *Subscription
+
+	msg *jsonrpcMessage
 
 	options options
 }
@@ -115,15 +119,19 @@ type options struct {
 	sorted bool
 }
 
+func UpgradeToSubscription(w ResponseWriter, r *Request) (*Subscription, error) {
+	not, ok := NotifierFromContext(r.ctx)
+	if !ok || not == nil {
+		return nil, errors.New("subscription not supported")
+	}
+	return not.CreateSubscription(), nil
+}
+
 func NewReaderResponseWriterMsg(r *Request) *ResponseWriterMsg {
 	rw := &ResponseWriterMsg{
 		r: r,
 	}
-	switch r.Peer().Transport {
-	case "http":
-	default:
-		rw.notifications = make(chan *jsonrpcMessage, 128)
-	}
+	rw.n, _ = NotifierFromContext(r.ctx)
 	return rw
 }
 
@@ -147,25 +155,24 @@ func (w *ResponseWriterMsg) Send(args any, e error) (err error) {
 		w.msg = cm.errorResponse(e)
 		return nil
 	}
-	w.msg = cm.response(args)
-	if w.notifications != nil {
-		close(w.notifications)
+	switch c := args.(type) {
+	case *Subscription:
+		w.s = c
+	default:
 	}
+	w.msg = cm.response(args)
 	w.msg.sortKeys = w.options.sorted
 	return nil
 }
 
 func (w *ResponseWriterMsg) Notify(args any) (err error) {
-	if w.notifications == nil {
-		return nil
+	if w.s == nil || w.n == nil {
+		return ErrSubscriptionNotFound
 	}
-	cm := w.r.Msg()
-	nf := cm.response(args)
-	nf.ID = nil
-	nf.sortKeys = w.options.sorted
-	select {
-	case w.notifications <- nf:
-	default:
+	bts, _ := jzon.Marshal(args)
+	err = w.n.send(w.s, bts)
+	if err != nil {
+		return err
 	}
 	return nil
 }
diff --git a/service.go b/service.go
index d062149..59efd0a 100644
--- a/service.go
+++ b/service.go
@@ -26,6 +26,7 @@ import (
 var (
 	contextType = reflect.TypeOf((*context.Context)(nil)).Elem()
 	errorType   = reflect.TypeOf((*error)(nil)).Elem()
+	stringType  = reflect.TypeOf("")
 )
 
 // A helper function that mimics the behavior of the handlers in the go-ethereum rpc package
@@ -117,9 +118,13 @@ func (e *callback) ServeRPC(w ResponseWriter, r *Request) {
 	w.Send(results[0].Interface(), nil)
 }
 
+func NewCallback(receiver, fn reflect.Value) Handler {
+	return newCallback(receiver, fn)
+}
+
 // newCallback turns fn (a function) into a callback object. It returns nil if the function
 // is unsuitable as an RPC callback.
-func newCallback(receiver, fn reflect.Value) Handler {
+func newCallback(receiver, fn reflect.Value) *callback {
 	fntype := fn.Type()
 	c := &callback{fn: fn, rcvr: receiver, errPos: -1}
 	// Determine parameter types. They must all be exported or builtin types.
diff --git a/subscription.go b/subscription.go
new file mode 100644
index 0000000..fc652b2
--- /dev/null
+++ b/subscription.go
@@ -0,0 +1,329 @@
+package jrpc
+
+import (
+	"container/list"
+	"context"
+	crand "crypto/rand"
+	"encoding/binary"
+	"encoding/hex"
+	"encoding/json"
+	"errors"
+	"math/rand"
+	"reflect"
+	"strings"
+	"sync"
+	"time"
+)
+
+const (
+	subscribeMethodSuffix    = "_subscribe"
+	notificationMethodSuffix = "_subscription"
+	unsubscribeMethodSuffix  = "_unsubscribe"
+	serviceMethodSeparator   = "_"
+
+	maxClientSubscriptionBuffer = 12800
+)
+
+var (
+	// ErrNotificationsUnsupported is returned when the connection doesn't support notifications
+	ErrNotificationsUnsupported = errors.New("notifications not supported")
+	// ErrNotificationNotFound is returned when the notification for the given id is not found
+	ErrSubscriptionNotFound = errors.New("subscription not found")
+)
+
+var globalGen = randomIDGenerator()
+
+type SubID string
+
+// NewID returns a new, random ID.
+func NewID() SubID {
+	return globalGen()
+}
+
+// randomIDGenerator returns a function generates a random IDs.
+func randomIDGenerator() func() SubID {
+	var buf = make([]byte, 8)
+	var seed int64
+	if _, err := crand.Read(buf); err == nil {
+		seed = int64(binary.BigEndian.Uint64(buf))
+	} else {
+		seed = int64(time.Now().Nanosecond())
+	}
+
+	var (
+		mu  sync.Mutex
+		rng = rand.New(rand.NewSource(seed)) // nolint: gosec
+	)
+	return func() SubID {
+		mu.Lock()
+		defer mu.Unlock()
+		id := make([]byte, 16)
+		rng.Read(id)
+		return encodeSubID(id)
+	}
+}
+
+func encodeSubID(b []byte) SubID {
+	id := hex.EncodeToString(b)
+	id = strings.TrimLeft(id, "0")
+	if id == "" {
+		id = "0" // ID's are RPC quantities, no leading zero's and 0 is 0x0.
+	}
+	return SubID("0x" + id)
+}
+
+type subscriptionResult struct {
+	ID     string          `json:"subscription"`
+	Result json.RawMessage `json:"result,omitempty"`
+}
+
+type notifierKey struct{}
+
+// NotifierFromContext returns the Notifier value stored in ctx, if any.
+func NotifierFromContext(ctx context.Context) (*Notifier, bool) {
+	n, ok := ctx.Value(notifierKey{}).(*Notifier)
+	return n, ok
+}
+
+// Notifier is tied to a RPC connection that supports subscriptions.
+// Server callbacks use the notifier to send notifications.
+type Notifier struct {
+	h         *handler
+	namespace string
+
+	mu           sync.Mutex
+	sub          *Subscription
+	buffer       []json.RawMessage
+	callReturned bool
+	activated    bool
+
+	idgen func() SubID
+}
+
+// CreateSubscription returns a new subscription that is coupled to the
+// RPC connection. By default subscriptions are inactive and notifications
+// are dropped until the subscription is marked as active. This is done
+// by the RPC server after the subscription ID is send to the client.
+func (n *Notifier) CreateSubscription() *Subscription {
+	n.mu.Lock()
+	defer n.mu.Unlock()
+	if n.sub != nil {
+		panic("can't create multiple subscriptions with Notifier")
+	} else if n.callReturned {
+		panic("can't create subscription after subscribe call has returned")
+	}
+	n.sub = &Subscription{
+		ID:        n.idgen(),
+		namespace: n.namespace,
+		err:       make(chan error, 1),
+	}
+	return n.sub
+}
+
+// Notify sends a notification to the client with the given data as payload.
+// If an error occurs the RPC connection is closed and the error is returned.
+func (n *Notifier) Notify(id SubID, data interface{}) error {
+	enc, err := jzon.Marshal(data)
+	if err != nil {
+		return err
+	}
+	n.mu.Lock()
+	defer n.mu.Unlock()
+	if n.sub == nil {
+		panic("can't Notify before subscription is created")
+	} else if n.sub.ID != id {
+		panic("Notify with wrong ID")
+	}
+	if n.activated {
+		return n.send(n.sub, enc)
+	}
+	n.buffer = append(n.buffer, enc)
+	return nil
+}
+
+// Closed returns a channel that is closed when the RPC connection is closed.
+// Deprecated: use subscription error channel
+func (n *Notifier) Closed() <-chan interface{} {
+	return n.h.conn.closed()
+}
+
+// takeSubscription returns the subscription (if one has been created). No subscription can
+// be created after this call.
+func (n *Notifier) takeSubscription() *Subscription {
+	n.mu.Lock()
+	defer n.mu.Unlock()
+	n.callReturned = true
+	return n.sub
+}
+
+// activate is called after the subscription ID was sent to client. Notifications are
+// buffered before activation. This prevents notifications being sent to the client before
+// the subscription ID is sent to the client.
+func (n *Notifier) activate() error {
+	n.mu.Lock()
+	defer n.mu.Unlock()
+
+	for _, data := range n.buffer {
+		if err := n.send(n.sub, data); err != nil {
+			return err
+		}
+	}
+	n.activated = true
+	return nil
+}
+
+func (n *Notifier) send(sub *Subscription, data json.RawMessage) error {
+	params, _ := jzon.Marshal(&subscriptionResult{ID: string(sub.ID), Result: data})
+	ctx := context.Background()
+	return n.h.conn.writeJSON(ctx, &jsonrpcMessage{
+		Method: n.namespace + notificationMethodSuffix,
+		Params: params,
+	})
+}
+
+// A Subscription is created by a notifier and tied to that notifier. The client can use
+// this subscription to wait for an unsubscribe request for the client, see Err().
+type Subscription struct {
+	ID        SubID
+	namespace string
+	err       chan error // closed on unsubscribe
+}
+
+// Err returns a channel that is closed when the client send an unsubscribe request.
+func (s *Subscription) Err() <-chan error {
+	return s.err
+}
+
+// MarshalJSON marshals a subscription as its ID.
+func (s *Subscription) MarshalJSON() ([]byte, error) {
+	return jzon.Marshal(s.ID)
+}
+
+// ClientSubscription is a subscription established through the Client's Subscribe or
+// EthSubscribe methods.
+type ClientSubscription struct {
+	client    *Client
+	etype     reflect.Type
+	channel   reflect.Value
+	namespace string
+	subid     string
+	in        chan json.RawMessage
+
+	quitOnce sync.Once     // ensures quit is closed once
+	quit     chan struct{} // quit is closed when the subscription exits
+	errOnce  sync.Once     // ensures err is closed once
+	err      chan error
+}
+
+func newClientSubscription(c *Client, namespace string, channel reflect.Value) *ClientSubscription {
+	sub := &ClientSubscription{
+		client:    c,
+		namespace: namespace,
+		etype:     channel.Type().Elem(),
+		channel:   channel,
+		quit:      make(chan struct{}),
+		err:       make(chan error, 1),
+		in:        make(chan json.RawMessage),
+	}
+	return sub
+}
+
+// Err returns the subscription error channel. The intended use of Err is to schedule
+// resubscription when the client connection is closed unexpectedly.
+//
+// The error channel receives a value when the subscription has ended due
+// to an error. The received error is nil if Close has been called
+// on the underlying client and no other error has occurred.
+//
+// The error channel is closed when Unsubscribe is called on the subscription.
+func (sub *ClientSubscription) Err() <-chan error {
+	return sub.err
+}
+
+// Unsubscribe unsubscribes the notification and closes the error channel.
+// It can safely be called more than once.
+func (sub *ClientSubscription) Unsubscribe() {
+	sub.quitWithError(true, nil)
+	sub.errOnce.Do(func() { close(sub.err) })
+}
+
+func (sub *ClientSubscription) quitWithError(unsubscribeServer bool, err error) {
+	sub.quitOnce.Do(func() {
+		// The dispatch loop won't be able to execute the unsubscribe call
+		// if it is blocked on deliver. Close sub.quit first because it
+		// unblocks deliver.
+		close(sub.quit)
+		if unsubscribeServer {
+			sub.requestUnsubscribe()
+		}
+		if err != nil {
+			if err == ErrClientQuit {
+				err = nil // Adhere to subscription semantics.
+			}
+			sub.err <- err
+		}
+	})
+}
+
+func (sub *ClientSubscription) deliver(result json.RawMessage) (ok bool) {
+	select {
+	case sub.in <- result:
+		return true
+	case <-sub.quit:
+		return false
+	}
+}
+
+func (sub *ClientSubscription) start() {
+	sub.quitWithError(sub.forward())
+}
+
+func (sub *ClientSubscription) forward() (unsubscribeServer bool, err error) {
+	cases := []reflect.SelectCase{
+		{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(sub.quit)},
+		{Dir: reflect.SelectRecv, Chan: reflect.ValueOf(sub.in)},
+		{Dir: reflect.SelectSend, Chan: sub.channel},
+	}
+	buffer := list.New()
+	defer buffer.Init()
+	for {
+		var chosen int
+		var recv reflect.Value
+		if buffer.Len() == 0 {
+			// Idle, omit send case.
+			chosen, recv, _ = reflect.Select(cases[:2])
+		} else {
+			// Non-empty buffer, send the first queued item.
+			cases[2].Send = reflect.ValueOf(buffer.Front().Value)
+			chosen, recv, _ = reflect.Select(cases)
+		}
+
+		switch chosen {
+		case 0: // <-sub.quit
+			return false, nil
+		case 1: // <-sub.in
+			val, err := sub.unmarshal(recv.Interface().(json.RawMessage))
+			if err != nil {
+				return true, err
+			}
+			if buffer.Len() == maxClientSubscriptionBuffer {
+				return true, ErrSubscriptionQueueOverflow
+			}
+			buffer.PushBack(val)
+		case 2: // sub.channel<-
+			cases[2].Send = reflect.Value{} // Don't hold onto the value.
+			buffer.Remove(buffer.Front())
+		}
+	}
+}
+
+func (sub *ClientSubscription) unmarshal(result json.RawMessage) (interface{}, error) {
+	val := reflect.New(sub.etype)
+	err := jzon.Unmarshal(result, val.Interface())
+	return val.Elem().Interface(), err
+}
+
+func (sub *ClientSubscription) requestUnsubscribe() error {
+	var result interface{}
+	return sub.client.Call(&result, sub.namespace+unsubscribeMethodSuffix, sub.subid)
+}
diff --git a/testservice_test.go b/testservice_test.go
index e7c93f1..ac24fca 100644
--- a/testservice_test.go
+++ b/testservice_test.go
@@ -19,12 +19,29 @@ package jrpc
 import (
 	"context"
 	"errors"
+	"log"
 	"strings"
 	"time"
 )
 
 func newTestServer() *Server {
 	server := NewServer()
+	server.Router().HandleFunc("testservice_subscribe", func(w ResponseWriter, r *Request) {
+		log.Println(r.Params())
+		sub, err := UpgradeToSubscription(w, r)
+		w.Send(sub, err)
+		if err != nil {
+			return
+		}
+		idx := 0
+		for {
+			err := w.Notify(idx)
+			if err != nil {
+				return
+			}
+			idx = idx + 1
+		}
+	})
 	if err := server.Router().RegisterStruct("test", new(testService)); err != nil {
 		panic(err)
 	}
@@ -124,7 +141,9 @@ func (s *testService) CallMeBackLater(ctx context.Context, method string, args [
 }
 
 type notificationTestService struct {
-	unsubscribed chan string
+	unsubscribed            chan string
+	gotHangSubscriptionReq  chan struct{}
+	unblockHangSubscription chan struct{}
 }
 
 func (s *notificationTestService) Echo(i int) int {
@@ -137,6 +156,49 @@ func (s *notificationTestService) Unsubscribe(subid string) {
 	}
 }
 
+func (s *notificationTestService) SomeSubscription(ctx context.Context, n, val int) (*Subscription, error) {
+	notifier, supported := NotifierFromContext(ctx)
+	if !supported {
+		return nil, ErrNotificationsUnsupported
+	}
+
+	// By explicitly creating an subscription we make sure that the subscription id is send
+	// back to the client before the first subscription.Notify is called. Otherwise the
+	// events might be send before the response for the *_subscribe method.
+	subscription := notifier.CreateSubscription()
+	go func() {
+		for i := 0; i < n; i++ {
+			if err := notifier.Notify(subscription.ID, val+i); err != nil {
+				return
+			}
+		}
+		select {
+		case <-notifier.Closed():
+		case <-subscription.Err():
+		}
+		if s.unsubscribed != nil {
+			s.unsubscribed <- string(subscription.ID)
+		}
+	}()
+	return subscription, nil
+}
+
+// HangSubscription blocks on s.unblockHangSubscription before sending anything.
+func (s *notificationTestService) HangSubscription(ctx context.Context, val int) (*Subscription, error) {
+	notifier, supported := NotifierFromContext(ctx)
+	if !supported {
+		return nil, ErrNotificationsUnsupported
+	}
+	s.gotHangSubscriptionReq <- struct{}{}
+	<-s.unblockHangSubscription
+	subscription := notifier.CreateSubscription()
+
+	go func() {
+		notifier.Notify(subscription.ID, val)
+	}()
+	return subscription, nil
+}
+
 // largeRespService generates arbitrary-size JSON responses.
 type largeRespService struct {
 	length int
diff --git a/wsjson/wsjson.go b/wsjson/wsjson.go
index 59a9100..b58a220 100644
--- a/wsjson/wsjson.go
+++ b/wsjson/wsjson.go
@@ -69,14 +69,15 @@ func Write(ctx context.Context, c *websocket.Conn, v interface{}) error {
 	return write(ctx, c, v)
 }
 
+var jpool = jsoniter.NewStream(jzon, nil, 0).Pool()
+
 func write(ctx context.Context, c *websocket.Conn, v interface{}) (err error) {
 	w, err := c.Writer(ctx, websocket.MessageText)
 	if err != nil {
 		return err
 	}
-	// json.Marshal cannot reuse buffers between calls as it has to return
-	// a copy of the byte slice but Encoder does as it directly writes to w.
-	st := jsoniter.NewStream(jzon, w, 1024)
+	st := jpool.BorrowStream(w)
+	defer jpool.ReturnStream(st)
 	st.WriteVal(v)
 	err = st.Flush()
 	if err != nil {
-- 
GitLab