From 81f0c4b5432380d788b93df48cae91023e2a0482 Mon Sep 17 00:00:00 2001
From: a <a@tuxpa.in>
Date: Sun, 4 Dec 2022 01:34:18 -0600
Subject: [PATCH] swag in progress

---
 errors.go          |   4 +-
 handler.go         | 100 ++++++++++++++++++++++-----------------------
 router_request.go  |  32 +++++++++++++++
 router_response.go |  66 +++++++++++++++++++++++++-----
 4 files changed, 141 insertions(+), 61 deletions(-)

diff --git a/errors.go b/errors.go
index 4b0ed13..158ea1c 100644
--- a/errors.go
+++ b/errors.go
@@ -41,7 +41,7 @@ type Error interface {
 
 // A DataError contains some data in addition to the error message.
 type DataError interface {
-	Error() string          // returns the message
+	Error() string  // returns the message
 	ErrorData() any // returns the error data
 }
 
@@ -58,6 +58,8 @@ var (
 
 const defaultErrorCode = -32000
 
+const applicationErrorCode = -32080
+
 type methodNotFoundError struct{ method string }
 
 func (e *methodNotFoundError) ErrorCode() int { return -32601 }
diff --git a/handler.go b/handler.go
index 22ac056..0d97b52 100644
--- a/handler.go
+++ b/handler.go
@@ -115,7 +115,8 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
 	h.startCallProc(func(cp *callProc) {
 		answers := make([]*jsonrpcMessage, 0, len(msgs))
 		for _, msg := range calls {
-			if answer := h.handleCallMsg(cp, msg); answer != nil {
+			r := NewMsgRequest(cp.ctx, h.peer, *msg)
+			if answer := h.handleCallMsg(cp, r); answer != nil {
 				answers = append(answers, answer)
 			}
 		}
@@ -135,7 +136,8 @@ func (h *handler) handleMsg(msg *jsonrpcMessage) {
 		return
 	}
 	h.startCallProc(func(cp *callProc) {
-		answer := h.handleCallMsg(cp, msg)
+		r := NewMsgRequest(cp.ctx, h.peer, *msg)
+		answer := h.handleCallMsg(cp, r)
 		h.addSubscriptions(cp.notifiers)
 		if answer != nil {
 			h.conn.WriteJSON(cp.ctx, answer)
@@ -238,6 +240,7 @@ func (h *handler) handleResponse(msg *jsonrpcMessage) {
 	}
 	delete(h.respWait, string(msg.ID.RawMessage()))
 	if op.sub == nil {
+		// not a sub, so just send the msg back
 		op.resp <- msg
 		return
 	}
@@ -257,65 +260,83 @@ func (h *handler) handleResponse(msg *jsonrpcMessage) {
 
 // handleCallMsg executes a call message and returns the answer.
 // TODO: export prometheus metrics maybe? also fix logging
-func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMessage {
+func (h *handler) handleCallMsg(ctx *callProc, r *Request) *jsonrpcMessage {
 	switch {
-	case msg.isNotification():
-		go h.handleCall(ctx, msg)
+	case r.isNotification():
+		go h.handleCall(ctx, r)
 		return nil
-	case msg.isCall():
-		resp := h.handleCall(ctx, msg)
+	case r.isCall():
+		resp := h.handleCall(ctx, r)
 		return resp
-	case msg.hasValidID():
-		return msg.errorResponse(&invalidRequestError{"invalid request"})
+	case r.hasValidID():
+		return r.makeError(&invalidRequestError{"invalid request"})
 	default:
 		return errorMessage(&invalidRequestError{"invalid request"})
 	}
 }
 
-// parseSubscriptionName extracts the subscription name from an encoded argument array.
-func parseSubscriptionName(rawArgs json.RawMessage) (string, error) {
-	dec := json.NewDecoder(bytes.NewReader(rawArgs))
-	if tok, _ := dec.Token(); tok != json.Delim('[') {
-		return "", errors.New("non-array args")
+func (h *handler) handleCall(cp *callProc, r *Request) *jsonrpcMessage {
+	callb := h.reg.Match(NewRouteContext(), r.Method)
+	mw := NewReaderResponseWriterMsg(r)
+	if r.isSubscribe() {
+		return h.handleSubscribe(cp, r)
 	}
-	v, _ := dec.Token()
-	method, ok := v.(string)
-	if !ok {
-		return "", errors.New("expected subscription name as first argument")
+	if r.isUnsubscribe() {
+		h.unsubscribeCb.ServeRPC(mw, r)
+		return mw.Msg()
 	}
-	return method, nil
+	// no method found
+	msg := r.Msg()
+	if !callb {
+		return msg.errorResponse(&methodNotFoundError{method: r.Method})
+	}
+	// now actually run the handler
+	h.reg.ServeRPC(mw, r)
+	return mw.Msg()
 }
 
 // handleSubscribe processes *_subscribe method calls.
-func (h *handler) handleSubscribe(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage {
+func (h *handler) handleSubscribe(cp *callProc, r *Request) *jsonrpcMessage {
 	switch h.peer.Transport {
 	case "http", "https":
-		return msg.errorResponse(ErrNotificationsUnsupported)
+		return r.makeError(ErrNotificationsUnsupported)
 	}
 
 	// Subscription method name is first argument.
-	name, err := parseSubscriptionName(msg.Params)
+	name, err := parseSubscriptionName(r.Params)
 	if err != nil {
-		return msg.errorResponse(&invalidParamsError{err.Error()})
+		return r.makeError(&invalidParamsError{err.Error()})
 	}
-	namespace := msg.namespace()
-	has := h.reg.Match(NewRouteContext(), msg.Method)
+	namespace := r.namespace()
+	has := h.reg.Match(NewRouteContext(), r.Method)
 	if !has {
-		return msg.errorResponse(&subscriptionNotFoundError{namespace, name})
+		return r.makeError(&subscriptionNotFoundError{namespace, name})
 	}
 	// Install notifier in context so the subscription handler can find it.
 	n := &Notifier{h: h, namespace: namespace, idgen: randomIDGenerator()}
 	cp.notifiers = append(cp.notifiers, n)
-	req := NewMsgRequest(cp.ctx, h.peer, *msg)
+	req := r.WithContext(cp.ctx)
 	// now actually run the handler
 	req = req.WithContext(
 		context.WithValue(req.ctx, notifierKey{}, n),
 	)
-
 	mw := NewReaderResponseWriterMsg(req)
 	h.reg.ServeRPC(mw, req)
+	return mw.Msg()
+}
 
-	return mw.msg
+// parseSubscriptionName extracts the subscription name from an encoded argument array.
+func parseSubscriptionName(rawArgs json.RawMessage) (string, error) {
+	dec := json.NewDecoder(bytes.NewReader(rawArgs))
+	if tok, _ := dec.Token(); tok != json.Delim('[') {
+		return "", errors.New("non-array args")
+	}
+	v, _ := dec.Token()
+	method, ok := v.(string)
+	if !ok {
+		return "", errors.New("expected subscription name as first argument")
+	}
+	return method, nil
 }
 
 func (h *handler) unsubscribe(ctx context.Context, id SubID) (bool, error) {
@@ -353,24 +374,3 @@ func (h *handler) addSubscriptions(nn []*Notifier) {
 		}
 	}
 }
-
-func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage {
-	callb := h.reg.Match(NewRouteContext(), msg.Method)
-	req := NewMsgRequest(cp.ctx, h.peer, *msg)
-	mw := NewReaderResponseWriterMsg(req)
-	if msg.isSubscribe() {
-		return h.handleSubscribe(cp, msg)
-	}
-	if msg.isUnsubscribe() {
-		h.unsubscribeCb.ServeRPC(mw, req)
-		return mw.msg
-	}
-	// no method found
-	if !callb {
-		return msg.errorResponse(&methodNotFoundError{method: msg.Method})
-	}
-	// now actually run the handler
-	h.reg.ServeRPC(mw, req)
-
-	return mw.msg
-}
diff --git a/router_request.go b/router_request.go
index 4c2208e..2baa773 100644
--- a/router_request.go
+++ b/router_request.go
@@ -2,6 +2,7 @@ package jrpc
 
 import (
 	"context"
+	"strings"
 
 	json "github.com/goccy/go-json"
 	jsoniter "github.com/json-iterator/go"
@@ -44,6 +45,37 @@ func (r *Request) ParamSlice() []any {
 	return params
 }
 
+func (r *Request) makeError(err error) *jsonrpcMessage {
+	m := r.Msg()
+	return m.errorResponse(err)
+}
+
+// DEPRECATED
+// TODO: use our router to do this? jrpc.Namespace(string) (string, string) maybe?
+func (r *Request) namespace() string {
+	elem := strings.SplitN(r.Method, serviceMethodSeparator, 2)
+	return elem[0]
+}
+
+func (r *Request) isSubscribe() bool {
+	return strings.HasSuffix(r.Method, subscribeMethodSuffix)
+}
+func (r *Request) isUnsubscribe() bool {
+	return strings.HasSuffix(r.Method, unsubscribeMethodSuffix)
+}
+func (r *Request) isNotification() bool {
+	return r.ID == nil && len(r.Method) > 0
+}
+func (r *Request) isCall() bool {
+	return r.hasValidID() && len(r.Method) > 0
+}
+func (r *Request) isResponse() bool {
+	return false
+}
+func (r *Request) hasValidID() bool {
+	return r.ID != nil && !r.ID.null
+}
+
 func (r *Request) ParamArray(a ...any) error {
 	var params []json.RawMessage
 	json.Unmarshal(r.Params, &params)
diff --git a/router_response.go b/router_response.go
index 863a920..7eb0e13 100644
--- a/router_response.go
+++ b/router_response.go
@@ -6,12 +6,30 @@ import (
 	"net/http"
 )
 
-type ResponseWriterMsg struct {
-	r *Request
-	n *Notifier
-	s *Subscription
+type Response struct {
+	ID      *ID             `json:"id,omitempty"`
+	Version version         `json:"jsonrpc,omitempty"`
+	Result  json.RawMessage `json:"result,omitempty"`
+	Error   *jsonError      `json:"error,omitempty"`
+}
+
+func (r *Response) Msg() *jsonrpcMessage {
+	out := &jsonrpcMessage{
+		ID: r.ID,
+	}
+	if r.Error != nil {
+		out.Error = r.Error
+		return out
+	}
+	out.Result = r.Result
+	return out
+}
 
-	msg *jsonrpcMessage
+type ResponseWriterMsg struct {
+	r    *Request
+	resp *Response
+	n    *Notifier
+	s    *Subscription
 
 	//TODO: add options
 	// currently there are no useful options so i havent added any
@@ -33,6 +51,9 @@ func UpgradeToSubscription(w ResponseWriter, r *Request) (*Subscription, error)
 func NewReaderResponseWriterMsg(r *Request) *ResponseWriterMsg {
 	rw := &ResponseWriterMsg{
 		r: r,
+		resp: &Response{
+			ID: r.ID,
+		},
 	}
 	rw.n, _ = NotifierFromContext(r.ctx)
 	return rw
@@ -50,9 +71,23 @@ func (w *ResponseWriterMsg) Option(k string, v any) {
 }
 
 func (w *ResponseWriterMsg) Send(args any, e error) (err error) {
-	cm := w.r.Msg()
 	if e != nil {
-		w.msg = cm.errorResponse(e)
+		if c, ok := e.(*jsonError); ok {
+			w.resp.Error = c
+		} else {
+			w.resp.Error = &jsonError{
+				Code:    applicationErrorCode,
+				Message: e.Error(),
+			}
+		}
+		ec, ok := e.(Error)
+		if ok {
+			w.resp.Error.Code = ec.ErrorCode()
+		}
+		de, ok := e.(DataError)
+		if ok {
+			w.resp.Error.Data = de.ErrorData()
+		}
 		return nil
 	}
 	switch c := args.(type) {
@@ -60,7 +95,14 @@ func (w *ResponseWriterMsg) Send(args any, e error) (err error) {
 		w.s = c
 	default:
 	}
-	w.msg = cm.response(args)
+	w.resp.Result, err = jzon.Marshal(args)
+	if err != nil {
+		w.resp.Error = &jsonError{
+			Code:    -32603,
+			Message: err.Error(),
+		}
+		return nil
+	}
 	return nil
 }
 
@@ -76,6 +118,10 @@ func (w *ResponseWriterMsg) Notify(args any) (err error) {
 	return nil
 }
 
-func (w *ResponseWriterMsg) Result() *jsonrpcMessage {
-	return w.msg
+func (w *ResponseWriterMsg) Response() *Response {
+	return w.resp
+}
+
+func (w *ResponseWriterMsg) Msg() *jsonrpcMessage {
+	return w.resp.Msg()
 }
-- 
GitLab