From 81f0c4b5432380d788b93df48cae91023e2a0482 Mon Sep 17 00:00:00 2001 From: a <a@tuxpa.in> Date: Sun, 4 Dec 2022 01:34:18 -0600 Subject: [PATCH] swag in progress --- errors.go | 4 +- handler.go | 100 ++++++++++++++++++++++----------------------- router_request.go | 32 +++++++++++++++ router_response.go | 66 +++++++++++++++++++++++++----- 4 files changed, 141 insertions(+), 61 deletions(-) diff --git a/errors.go b/errors.go index 4b0ed13..158ea1c 100644 --- a/errors.go +++ b/errors.go @@ -41,7 +41,7 @@ type Error interface { // A DataError contains some data in addition to the error message. type DataError interface { - Error() string // returns the message + Error() string // returns the message ErrorData() any // returns the error data } @@ -58,6 +58,8 @@ var ( const defaultErrorCode = -32000 +const applicationErrorCode = -32080 + type methodNotFoundError struct{ method string } func (e *methodNotFoundError) ErrorCode() int { return -32601 } diff --git a/handler.go b/handler.go index 22ac056..0d97b52 100644 --- a/handler.go +++ b/handler.go @@ -115,7 +115,8 @@ func (h *handler) handleBatch(msgs []*jsonrpcMessage) { h.startCallProc(func(cp *callProc) { answers := make([]*jsonrpcMessage, 0, len(msgs)) for _, msg := range calls { - if answer := h.handleCallMsg(cp, msg); answer != nil { + r := NewMsgRequest(cp.ctx, h.peer, *msg) + if answer := h.handleCallMsg(cp, r); answer != nil { answers = append(answers, answer) } } @@ -135,7 +136,8 @@ func (h *handler) handleMsg(msg *jsonrpcMessage) { return } h.startCallProc(func(cp *callProc) { - answer := h.handleCallMsg(cp, msg) + r := NewMsgRequest(cp.ctx, h.peer, *msg) + answer := h.handleCallMsg(cp, r) h.addSubscriptions(cp.notifiers) if answer != nil { h.conn.WriteJSON(cp.ctx, answer) @@ -238,6 +240,7 @@ func (h *handler) handleResponse(msg *jsonrpcMessage) { } delete(h.respWait, string(msg.ID.RawMessage())) if op.sub == nil { + // not a sub, so just send the msg back op.resp <- msg return } @@ -257,65 +260,83 @@ func (h *handler) handleResponse(msg *jsonrpcMessage) { // handleCallMsg executes a call message and returns the answer. // TODO: export prometheus metrics maybe? also fix logging -func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMessage { +func (h *handler) handleCallMsg(ctx *callProc, r *Request) *jsonrpcMessage { switch { - case msg.isNotification(): - go h.handleCall(ctx, msg) + case r.isNotification(): + go h.handleCall(ctx, r) return nil - case msg.isCall(): - resp := h.handleCall(ctx, msg) + case r.isCall(): + resp := h.handleCall(ctx, r) return resp - case msg.hasValidID(): - return msg.errorResponse(&invalidRequestError{"invalid request"}) + case r.hasValidID(): + return r.makeError(&invalidRequestError{"invalid request"}) default: return errorMessage(&invalidRequestError{"invalid request"}) } } -// parseSubscriptionName extracts the subscription name from an encoded argument array. -func parseSubscriptionName(rawArgs json.RawMessage) (string, error) { - dec := json.NewDecoder(bytes.NewReader(rawArgs)) - if tok, _ := dec.Token(); tok != json.Delim('[') { - return "", errors.New("non-array args") +func (h *handler) handleCall(cp *callProc, r *Request) *jsonrpcMessage { + callb := h.reg.Match(NewRouteContext(), r.Method) + mw := NewReaderResponseWriterMsg(r) + if r.isSubscribe() { + return h.handleSubscribe(cp, r) } - v, _ := dec.Token() - method, ok := v.(string) - if !ok { - return "", errors.New("expected subscription name as first argument") + if r.isUnsubscribe() { + h.unsubscribeCb.ServeRPC(mw, r) + return mw.Msg() } - return method, nil + // no method found + msg := r.Msg() + if !callb { + return msg.errorResponse(&methodNotFoundError{method: r.Method}) + } + // now actually run the handler + h.reg.ServeRPC(mw, r) + return mw.Msg() } // handleSubscribe processes *_subscribe method calls. -func (h *handler) handleSubscribe(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage { +func (h *handler) handleSubscribe(cp *callProc, r *Request) *jsonrpcMessage { switch h.peer.Transport { case "http", "https": - return msg.errorResponse(ErrNotificationsUnsupported) + return r.makeError(ErrNotificationsUnsupported) } // Subscription method name is first argument. - name, err := parseSubscriptionName(msg.Params) + name, err := parseSubscriptionName(r.Params) if err != nil { - return msg.errorResponse(&invalidParamsError{err.Error()}) + return r.makeError(&invalidParamsError{err.Error()}) } - namespace := msg.namespace() - has := h.reg.Match(NewRouteContext(), msg.Method) + namespace := r.namespace() + has := h.reg.Match(NewRouteContext(), r.Method) if !has { - return msg.errorResponse(&subscriptionNotFoundError{namespace, name}) + return r.makeError(&subscriptionNotFoundError{namespace, name}) } // Install notifier in context so the subscription handler can find it. n := &Notifier{h: h, namespace: namespace, idgen: randomIDGenerator()} cp.notifiers = append(cp.notifiers, n) - req := NewMsgRequest(cp.ctx, h.peer, *msg) + req := r.WithContext(cp.ctx) // now actually run the handler req = req.WithContext( context.WithValue(req.ctx, notifierKey{}, n), ) - mw := NewReaderResponseWriterMsg(req) h.reg.ServeRPC(mw, req) + return mw.Msg() +} - return mw.msg +// parseSubscriptionName extracts the subscription name from an encoded argument array. +func parseSubscriptionName(rawArgs json.RawMessage) (string, error) { + dec := json.NewDecoder(bytes.NewReader(rawArgs)) + if tok, _ := dec.Token(); tok != json.Delim('[') { + return "", errors.New("non-array args") + } + v, _ := dec.Token() + method, ok := v.(string) + if !ok { + return "", errors.New("expected subscription name as first argument") + } + return method, nil } func (h *handler) unsubscribe(ctx context.Context, id SubID) (bool, error) { @@ -353,24 +374,3 @@ func (h *handler) addSubscriptions(nn []*Notifier) { } } } - -func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage { - callb := h.reg.Match(NewRouteContext(), msg.Method) - req := NewMsgRequest(cp.ctx, h.peer, *msg) - mw := NewReaderResponseWriterMsg(req) - if msg.isSubscribe() { - return h.handleSubscribe(cp, msg) - } - if msg.isUnsubscribe() { - h.unsubscribeCb.ServeRPC(mw, req) - return mw.msg - } - // no method found - if !callb { - return msg.errorResponse(&methodNotFoundError{method: msg.Method}) - } - // now actually run the handler - h.reg.ServeRPC(mw, req) - - return mw.msg -} diff --git a/router_request.go b/router_request.go index 4c2208e..2baa773 100644 --- a/router_request.go +++ b/router_request.go @@ -2,6 +2,7 @@ package jrpc import ( "context" + "strings" json "github.com/goccy/go-json" jsoniter "github.com/json-iterator/go" @@ -44,6 +45,37 @@ func (r *Request) ParamSlice() []any { return params } +func (r *Request) makeError(err error) *jsonrpcMessage { + m := r.Msg() + return m.errorResponse(err) +} + +// DEPRECATED +// TODO: use our router to do this? jrpc.Namespace(string) (string, string) maybe? +func (r *Request) namespace() string { + elem := strings.SplitN(r.Method, serviceMethodSeparator, 2) + return elem[0] +} + +func (r *Request) isSubscribe() bool { + return strings.HasSuffix(r.Method, subscribeMethodSuffix) +} +func (r *Request) isUnsubscribe() bool { + return strings.HasSuffix(r.Method, unsubscribeMethodSuffix) +} +func (r *Request) isNotification() bool { + return r.ID == nil && len(r.Method) > 0 +} +func (r *Request) isCall() bool { + return r.hasValidID() && len(r.Method) > 0 +} +func (r *Request) isResponse() bool { + return false +} +func (r *Request) hasValidID() bool { + return r.ID != nil && !r.ID.null +} + func (r *Request) ParamArray(a ...any) error { var params []json.RawMessage json.Unmarshal(r.Params, ¶ms) diff --git a/router_response.go b/router_response.go index 863a920..7eb0e13 100644 --- a/router_response.go +++ b/router_response.go @@ -6,12 +6,30 @@ import ( "net/http" ) -type ResponseWriterMsg struct { - r *Request - n *Notifier - s *Subscription +type Response struct { + ID *ID `json:"id,omitempty"` + Version version `json:"jsonrpc,omitempty"` + Result json.RawMessage `json:"result,omitempty"` + Error *jsonError `json:"error,omitempty"` +} + +func (r *Response) Msg() *jsonrpcMessage { + out := &jsonrpcMessage{ + ID: r.ID, + } + if r.Error != nil { + out.Error = r.Error + return out + } + out.Result = r.Result + return out +} - msg *jsonrpcMessage +type ResponseWriterMsg struct { + r *Request + resp *Response + n *Notifier + s *Subscription //TODO: add options // currently there are no useful options so i havent added any @@ -33,6 +51,9 @@ func UpgradeToSubscription(w ResponseWriter, r *Request) (*Subscription, error) func NewReaderResponseWriterMsg(r *Request) *ResponseWriterMsg { rw := &ResponseWriterMsg{ r: r, + resp: &Response{ + ID: r.ID, + }, } rw.n, _ = NotifierFromContext(r.ctx) return rw @@ -50,9 +71,23 @@ func (w *ResponseWriterMsg) Option(k string, v any) { } func (w *ResponseWriterMsg) Send(args any, e error) (err error) { - cm := w.r.Msg() if e != nil { - w.msg = cm.errorResponse(e) + if c, ok := e.(*jsonError); ok { + w.resp.Error = c + } else { + w.resp.Error = &jsonError{ + Code: applicationErrorCode, + Message: e.Error(), + } + } + ec, ok := e.(Error) + if ok { + w.resp.Error.Code = ec.ErrorCode() + } + de, ok := e.(DataError) + if ok { + w.resp.Error.Data = de.ErrorData() + } return nil } switch c := args.(type) { @@ -60,7 +95,14 @@ func (w *ResponseWriterMsg) Send(args any, e error) (err error) { w.s = c default: } - w.msg = cm.response(args) + w.resp.Result, err = jzon.Marshal(args) + if err != nil { + w.resp.Error = &jsonError{ + Code: -32603, + Message: err.Error(), + } + return nil + } return nil } @@ -76,6 +118,10 @@ func (w *ResponseWriterMsg) Notify(args any) (err error) { return nil } -func (w *ResponseWriterMsg) Result() *jsonrpcMessage { - return w.msg +func (w *ResponseWriterMsg) Response() *Response { + return w.resp +} + +func (w *ResponseWriterMsg) Msg() *jsonrpcMessage { + return w.resp.Msg() } -- GitLab