From 5f1ef933475c1f052c62a12d04caead5ca4049a4 Mon Sep 17 00:00:00 2001
From: a <a@a.a>
Date: Mon, 29 Aug 2022 21:34:21 -0500
Subject: [PATCH] commented out tests since we now are okay with bad content
 types

---
 client_test.go             | 16 ++++----
 http_test.go               | 10 +++--
 mux.go                     | 18 +++++++++
 callback.go => protocol.go | 53 ------------------------
 router.go                  |  9 ++++-
 server.go                  | 10 +----
 service.go                 | 83 ++++++++++++++++++++++++++++++--------
 testservice_test.go        |  4 +-
 websocket_test.go          |  2 +-
 9 files changed, 110 insertions(+), 95 deletions(-)
 rename callback.go => protocol.go (53%)

diff --git a/client_test.go b/client_test.go
index 848ed34..72b77a3 100644
--- a/client_test.go
+++ b/client_test.go
@@ -26,7 +26,6 @@ import (
 	"os"
 	"reflect"
 	"runtime"
-	"strings"
 	"sync"
 	"testing"
 	"time"
@@ -280,14 +279,15 @@ func TestClientSetHeader(t *testing.T) {
 		t.Fatal("client did not set custom header")
 	}
 
+	//NOTE: this test is removed because we accept invalid content types
 	// Check that Content-Type can be replaced.
-	client.SetHeader("content-type", "application/x-garbage")
-	_, err = client.SupportedModules()
-	if err == nil {
-		t.Fatal("no error for invalid content-type header")
-	} else if !strings.Contains(err.Error(), "Unsupported Media Type") {
-		t.Fatalf("error is not related to content-type: %q", err)
-	}
+	//client.SetHeader("content-type", "application/x-garbage")
+	//_, err = client.SupportedModules()
+	//if err == nil {
+	//	t.Fatal("no error for invalid content-type header")
+	//} else if !strings.Contains(err.Error(), "Unsupported Media Type") {
+	//	t.Fatalf("error is not related to content-type: %q", err)
+	//}
 }
 
 func TestClientHTTP(t *testing.T) {
diff --git a/http_test.go b/http_test.go
index 85850cd..c1242e5 100644
--- a/http_test.go
+++ b/http_test.go
@@ -67,9 +67,11 @@ func TestHTTPErrorResponseWithMaxContentLength(t *testing.T) {
 		http.MethodPost, contentType, string(body), http.StatusRequestEntityTooLarge)
 }
 
-func TestHTTPErrorResponseWithEmptyContentType(t *testing.T) {
-	confirmRequestValidationCode(t, http.MethodPost, "", "", http.StatusUnsupportedMediaType)
-}
+//NOTE: this test is not needed since we no longer check this
+//
+//func TestHTTPErrorResponseWithEmptyContentType(t *testing.T) {
+//	confirmRequestValidationCode(t, http.MethodPost, "", "", http.StatusUnsupportedMediaType)
+//}
 
 func TestHTTPErrorResponseWithValidRequest(t *testing.T) {
 	confirmRequestValidationCode(t, http.MethodPost, contentType, "", 0)
@@ -105,7 +107,7 @@ func TestHTTPRespBodyUnlimited(t *testing.T) {
 
 	s := NewServer()
 	defer s.Stop()
-	s.RegisterName("test", largeRespService{respLength})
+	s.Router().RegisterStruct("test", largeRespService{respLength})
 	ts := httptest.NewServer(s)
 	defer ts.Close()
 
diff --git a/mux.go b/mux.go
index c41452c..9186b95 100644
--- a/mux.go
+++ b/mux.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"reflect"
 	"sync"
 )
 
@@ -56,6 +57,23 @@ func NewMux() *Mux {
 	return mux
 }
 
+func (m *Mux) RegisterStruct(name string, rcvr any) error {
+	rcvrVal := reflect.ValueOf(rcvr)
+	if name == "" {
+		return fmt.Errorf("no service name for type %s", rcvrVal.Type().String())
+	}
+	callbacks := suitableCallbacks(rcvrVal)
+	if len(callbacks) == 0 {
+		return fmt.Errorf("service %T doesn't have any suitable methods/subscriptions to expose", rcvr)
+	}
+	m.Route(name, func(r Router) {
+		for nm, cb := range callbacks {
+			r.Handle(nm, cb)
+		}
+	})
+	return nil
+}
+
 // ServeRPC is the single method of the Handler interface that makes
 // Mux interoperable with the standard library. It uses a sync.Pool to get and
 // reuse routing contexts for each request.
diff --git a/callback.go b/protocol.go
similarity index 53%
rename from callback.go
rename to protocol.go
index eaff16a..4af0b6c 100644
--- a/callback.go
+++ b/protocol.go
@@ -4,10 +4,7 @@ import (
 	"context"
 	"encoding/json"
 	"io"
-	"reflect"
-	"runtime"
 
-	"git.tuxpa.in/a/zlog/log"
 	jsoniter "github.com/json-iterator/go"
 )
 
@@ -118,53 +115,3 @@ func (w *ResponseWriterMsg) Send(args any, e error) (err error) {
 	w.msg = cm.response(args)
 	return nil
 }
-
-// callback is a method callback which was registered in the server
-type callback struct {
-	fn       reflect.Value  // the function
-	rcvr     reflect.Value  // receiver object of method, set if fn is method
-	argTypes []reflect.Type // input argument types
-	hasCtx   bool           // method's first argument is a context (not included in argTypes)
-	errPos   int            // err return idx, of -1 when method cannot return error
-}
-
-// callback handler implements handler for the original receiver style that geth used
-func (e *callback) ServeRPC(w ResponseWriter, r *Request) {
-	argTypes := append([]reflect.Type{}, e.argTypes...)
-	args, err := parsePositionalArguments(r.msg.Params, argTypes)
-	if err != nil {
-		w.Send(nil, &invalidParamsError{err.Error()})
-		return
-	}
-	// Create the argument slice.
-	fullargs := make([]reflect.Value, 0, 2+len(args))
-	if e.rcvr.IsValid() {
-		fullargs = append(fullargs, e.rcvr)
-	}
-	if e.hasCtx {
-		fullargs = append(fullargs, reflect.ValueOf(r.ctx))
-	}
-	fullargs = append(fullargs, args...)
-	// Catch panic while running the callback.
-	defer func() {
-		if err := recover(); err != nil {
-			const size = 64 << 10
-			buf := make([]byte, size)
-			buf = buf[:runtime.Stack(buf, false)]
-			log.Error().Str("method", r.msg.Method).Interface("err", err).Hex("buf", buf).Msg("crashed")
-			//		errRes := errors.New("method handler crashed: " + fmt.Sprint(err))
-			w.Send(nil, nil)
-			return
-		}
-	}()
-	// Run the callback.
-	results := e.fn.Call(fullargs)
-	if e.errPos >= 0 && !results[e.errPos].IsNil() {
-		// Method has returned non-nil error value.
-		err := results[e.errPos].Interface().(error)
-		w.Send(nil, err)
-		return
-	}
-	w.Send(results[0].Interface(), nil)
-	return
-}
diff --git a/router.go b/router.go
index 7807101..f3b302d 100644
--- a/router.go
+++ b/router.go
@@ -5,11 +5,18 @@ func NewRouter() *Mux {
 	return NewMux()
 }
 
+type StructReflector interface {
+	// mimics the behavior of the handlers in the go-ethereum rpc package
+	// if you don't know how to use this, just use the chi-like interface instead.
+	RegisterStruct(pattern string, rcvr any) error
+}
+
 // Router consisting of the core routing methods used by chi's Mux,
-// using only the standard net/
+// adapted to fit json-rpc.
 type Router interface {
 	Handler
 	Routes
+	StructReflector
 
 	// Use appends one or more middlewares onto the Router stack.
 	Use(middlewares ...func(Handler) Handler)
diff --git a/server.go b/server.go
index 5608478..50d3560 100644
--- a/server.go
+++ b/server.go
@@ -37,7 +37,7 @@ func NewServer(r ...Router) *Server {
 	// Register the default service providing meta information about the RPC service such
 	// as the services and methods it offers.
 	rpcService := &RPCService{server}
-	server.RegisterName(MetadataApi, rpcService)
+	server.services.RegisterStruct(MetadataApi, rpcService)
 	return server
 }
 
@@ -45,14 +45,6 @@ func (s *Server) Router() Router {
 	return s.services
 }
 
-// RegisterName creates a service for the given receiver type under the given name. When no
-// methods on the given receiver match the criteria to be either a RPC method or a
-// subscription an error is returned. Otherwise a new service is created and added to the
-// service collection this server provides to clients.
-func (s *Server) RegisterName(name string, receiver any) error {
-	return RegisterStruct(s.services, name, receiver)
-}
-
 // ServeCodec reads incoming requests from codec, calls the appropriate callback and writes
 // the response back using the given codec. It will block until the codec is closed or the
 // server is stopped. In either case the codec is closed.
diff --git a/service.go b/service.go
index 42c2824..a5defae 100644
--- a/service.go
+++ b/service.go
@@ -19,7 +19,6 @@ package jrpc
 import (
 	"context"
 	"errors"
-	"fmt"
 	"reflect"
 	"runtime"
 	"unicode"
@@ -35,22 +34,22 @@ var (
 
 // A helper function that mimics the behavior of the handlers in the go-ethereum rpc package
 // if you don't know how to use this, just use the chi-like interface instead.
-func RegisterStruct(r Router, name string, rcvr any) error {
-	rcvrVal := reflect.ValueOf(rcvr)
-	if name == "" {
-		return fmt.Errorf("no service name for type %s", rcvrVal.Type().String())
-	}
-	callbacks := suitableCallbacks(rcvrVal)
-	if len(callbacks) == 0 {
-		return fmt.Errorf("service %T doesn't have any suitable methods/subscriptions to expose", rcvr)
-	}
-	r.Route(name, func(r Router) {
-		for nm, cb := range callbacks {
-			r.Handle(nm, cb)
-		}
-	})
-	return nil
-}
+//func RegisterStruct(r Router, name string, rcvr any) error {
+//	rcvrVal := reflect.ValueOf(rcvr)
+//	if name == "" {
+//		return fmt.Errorf("no service name for type %s", rcvrVal.Type().String())
+//	}
+//	callbacks := suitableCallbacks(rcvrVal)
+//	if len(callbacks) == 0 {
+//		return fmt.Errorf("service %T doesn't have any suitable methods/subscriptions to expose", rcvr)
+//	}
+//	r.Route(name, func(r Router) {
+//		for nm, cb := range callbacks {
+//			r.Handle(nm, cb)
+//		}
+//	})
+//	return nil
+//}
 
 // suitableCallbacks iterates over the methods of the given type. It determines if a method
 // satisfies the criteria for a RPC callback or a subscription callback and adds it to the
@@ -73,6 +72,56 @@ func suitableCallbacks(receiver reflect.Value) map[string]Handler {
 	return callbacks
 }
 
+// callback is a method callback which was registered in the server
+type callback struct {
+	fn       reflect.Value  // the function
+	rcvr     reflect.Value  // receiver object of method, set if fn is method
+	argTypes []reflect.Type // input argument types
+	hasCtx   bool           // method's first argument is a context (not included in argTypes)
+	errPos   int            // err return idx, of -1 when method cannot return error
+}
+
+// callback handler implements handler for the original receiver style that geth used
+func (e *callback) ServeRPC(w ResponseWriter, r *Request) {
+	argTypes := append([]reflect.Type{}, e.argTypes...)
+	args, err := parsePositionalArguments(r.msg.Params, argTypes)
+	if err != nil {
+		w.Send(nil, &invalidParamsError{err.Error()})
+		return
+	}
+	// Create the argument slice.
+	fullargs := make([]reflect.Value, 0, 2+len(args))
+	if e.rcvr.IsValid() {
+		fullargs = append(fullargs, e.rcvr)
+	}
+	if e.hasCtx {
+		fullargs = append(fullargs, reflect.ValueOf(r.ctx))
+	}
+	fullargs = append(fullargs, args...)
+	// Catch panic while running the callback.
+	defer func() {
+		if err := recover(); err != nil {
+			const size = 64 << 10
+			buf := make([]byte, size)
+			buf = buf[:runtime.Stack(buf, false)]
+			log.Error().Str("method", r.msg.Method).Interface("err", err).Hex("buf", buf).Msg("crashed")
+			//		errRes := errors.New("method handler crashed: " + fmt.Sprint(err))
+			w.Send(nil, nil)
+			return
+		}
+	}()
+	// Run the callback.
+	results := e.fn.Call(fullargs)
+	if e.errPos >= 0 && !results[e.errPos].IsNil() {
+		// Method has returned non-nil error value.
+		err := results[e.errPos].Interface().(error)
+		w.Send(nil, err)
+		return
+	}
+	w.Send(results[0].Interface(), nil)
+	return
+}
+
 // newCallback turns fn (a function) into a callback object. It returns nil if the function
 // is unsuitable as an RPC callback.
 func newCallback(receiver, fn reflect.Value) Handler {
diff --git a/testservice_test.go b/testservice_test.go
index 393fdd1..5678e58 100644
--- a/testservice_test.go
+++ b/testservice_test.go
@@ -25,10 +25,10 @@ import (
 
 func newTestServer() *Server {
 	server := NewServer()
-	if err := server.RegisterName("test", new(testService)); err != nil {
+	if err := server.Router().RegisterStruct("test", new(testService)); err != nil {
 		panic(err)
 	}
-	if err := server.RegisterName("nftest", new(notificationTestService)); err != nil {
+	if err := server.Router().RegisterStruct("nftest", new(notificationTestService)); err != nil {
 		panic(err)
 	}
 	return server
diff --git a/websocket_test.go b/websocket_test.go
index ac42551..ed25502 100644
--- a/websocket_test.go
+++ b/websocket_test.go
@@ -156,7 +156,7 @@ func TestClientWebsocketLargeMessage(t *testing.T) {
 	defer httpsrv.Close()
 
 	respLength := wsMessageSizeLimit - 50
-	srv.RegisterName("test", largeRespService{respLength})
+	srv.Router().RegisterStruct("test", largeRespService{respLength})
 
 	c, err := DialWebsocket(context.Background(), wsURL, "")
 	if err != nil {
-- 
GitLab