diff --git a/service.go b/service.go index bd1bf1156ffa8345d63fd90917881a8cc8dede47..0f5e210347f5771a1f3ea318a5f1348652572c32 100644 --- a/service.go +++ b/service.go @@ -33,9 +33,8 @@ var ( stringType = reflect.TypeOf("") ) -// TODO: this entire thing is a fucking mess -// TODO: we need to redo this router style, basically it's like URLs, but "_" instead of "/". Maybe look at the chi code and see how they did it -// TODO: also we need to add support for things like default handlers, prefix handlers, etc +// Supported for legacy reasons. +//TODO: we should redo our tests such that we no longer need this function. func registerStruct(r Router, name string, rcvr any) error { rcvrVal := reflect.ValueOf(rcvr) if name == "" { @@ -124,7 +123,6 @@ func (c *callback) makeArgTypes() { } // call invokes the callback. -// NOTE: this is done with some sorta awkward reflection. I wonder if there is a neater way to do this. func (c *callback) call(ctx context.Context, method string, args []reflect.Value) (res any, errRes error) { // Create the argument slice. fullargs := make([]reflect.Value, 0, 2+len(args)) @@ -160,14 +158,6 @@ func (c *callback) call(ctx context.Context, method string, args []reflect.Value return results[0].Interface(), nil } -// Is t context.Context or *context.Context? -func isContextType(t reflect.Type) bool { - for t.Kind() == reflect.Ptr { - t = t.Elem() - } - return t == contextType -} - // Does t satisfy the error interface? func isErrorType(t reflect.Type) bool { for t.Kind() == reflect.Ptr { diff --git a/websocket.go b/websocket.go index aba89745f9967d112f45d7322d874f2fc2eb17eb..052662b6b80410294c6ab18cf9e97ea5d44a0006 100644 --- a/websocket.go +++ b/websocket.go @@ -19,16 +19,12 @@ package jrpc import ( "context" "encoding/base64" - "fmt" "net/http" "net/url" - "os" - "strings" "sync" "time" "git.tuxpa.in/a/zlog/log" - mapset "github.com/deckarep/golang-set" "nhooyr.io/websocket" "nhooyr.io/websocket/wsjson" ) @@ -64,50 +60,6 @@ func (s *Server) WebsocketHandler(allowedOrigins []string) http.Handler { }) } -// wsHandshakeValidator returns a handler that verifies the origin during the -// websocket upgrade process. When a '*' is specified as an allowed origins all -// connections are accepted. -func wsHandshakeValidator(allowedOrigins []string) func(*http.Request) bool { - origins := mapset.NewSet() - allowAllOrigins := false - - for _, origin := range allowedOrigins { - if origin == "*" { - allowAllOrigins = true - } - if origin != "" { - origins.Add(origin) - } - } - // allow localhost if no allowedOrigins are specified. - if len(origins.ToSlice()) == 0 { - origins.Add("http://localhost*") - if hostname, err := os.Hostname(); err == nil { - origins.Add("http://" + hostname + "*") - } - } - log.Debug().Msg(fmt.Sprintf("Allowed origin(s) for WS RPC interface %v", origins.ToSlice())) - - f := func(req *http.Request) bool { - // Skip origin verification if no Origin header is present. The origin check - // is supposed to protect against browser based attacks. Browsers always set - // Origin. Non-browser software can put anything in origin and checking it doesn't - // provide additional security. - if _, ok := req.Header["Origin"]; !ok { - return true - } - // Verify origin against allow list. - origin := strings.ToLower(req.Header.Get("Origin")) - if allowAllOrigins || originIsAllowed(origins, origin) { - return true - } - log.Warn().Str("origin", origin).Msg("Rejected WebSocket connection") - return false - } - - return f -} - type wsHandshakeError struct { err error status string @@ -121,66 +73,6 @@ func (e wsHandshakeError) Error() string { return s } -func originIsAllowed(allowedOrigins mapset.Set, browserOrigin string) bool { - it := allowedOrigins.Iterator() - for origin := range it.C { - if ruleAllowsOrigin(origin.(string), browserOrigin) { - return true - } - } - return false -} - -func ruleAllowsOrigin(allowedOrigin string, browserOrigin string) bool { - var ( - allowedScheme, allowedHostname, allowedPort string - browserScheme, browserHostname, browserPort string - err error - ) - allowedScheme, allowedHostname, allowedPort, err = parseOriginURL(allowedOrigin) - if err != nil { - log.Warn().Str("spec", allowedOrigin).Err(err).Msg("Error parsing allowed origin specification") - return false - } - browserScheme, browserHostname, browserPort, err = parseOriginURL(browserOrigin) - if err != nil { - log.Warn().Str("Origin", browserOrigin).Err(err).Msg("Error parsing browser 'Origin' field") - return false - } - if allowedScheme != "" && allowedScheme != browserScheme { - return false - } - if allowedHostname != "" && allowedHostname != browserHostname { - return false - } - if allowedPort != "" && allowedPort != browserPort { - return false - } - return true -} - -func parseOriginURL(origin string) (string, string, string, error) { - parsedURL, err := url.Parse(strings.ToLower(origin)) - if err != nil { - return "", "", "", err - } - var scheme, hostname, port string - if strings.Contains(origin, "://") { - scheme = parsedURL.Scheme - hostname = parsedURL.Hostname() - port = parsedURL.Port() - } else { - scheme = "" - hostname = parsedURL.Scheme - port = parsedURL.Opaque - if hostname == "" { - hostname = origin - } - } - return scheme, hostname, port, nil -} - -// DialWebsocketWithDialer creates a new RPC client that communicates with a JSON-RPC server // that is listening on the given endpoint using the provided dialer. func DialWebsocketWithDialer(ctx context.Context, endpoint, origin string, opts *websocket.DialOptions) (*Client, error) { endpoint, header, err := wsClientHeaders(endpoint, origin)