good morning!!!!

Skip to content
Snippets Groups Projects
router_tree.go 16.8 KiB
Newer Older
  • Learn to ignore specific revisions
  • a's avatar
    a committed
    package jmux
    
    a's avatar
    rpc
    a committed
    
    // Radix tree implementation below is based on chi
    // https://github.com/go-chi/chi/blob/master/tree.go
    
    import (
    	"fmt"
    	"regexp"
    	"sort"
    	"strings"
    
    a's avatar
    a committed
    
    
    a's avatar
    a committed
    	"gfx.cafe/open/jrpc/pkg/jsonrpc"
    
    a's avatar
    rpc
    a committed
    )
    
    // nodeTyp identifies which kind of pattern segment a tree node represents.
    type nodeTyp uint8
    
    // Node kinds. The values double as indexes into node.children, so their
    // numeric order is also the order in which child groups are stored and ranged over.
    const (
    	ntStatic   nodeTyp = iota // /home
    	ntRegexp                  // /{id:[0-9]+}
    	ntParam                   // /{user}
    	ntCatchAll                // /api/v1/*
    )
    
    // node is a single vertex of the routing radix tree.
    type node struct {
    	// subroutes on the leaf node
    	subroutes Routes
    
    	// regexp matcher for regexp nodes (compiled in addChild)
    	rex *regexp.Regexp
    
    	// handler endpoint registered on this node; nil when the node is not a leaf
    	endpoint *endpoint
    
    	// prefix is the static text matched by this node; for regexp nodes
    	// addChild stores the regexp source here instead
    	prefix string
    
    	// child nodes should be stored in-order for iteration,
    	// in groups of the node type.
    	children [ntCatchAll + 1]nodes
    
    	// tail is the delimiter byte that follows a param/regexp segment
    	// (as returned by patNextSegment); zero for plain static nodes
    	tail byte
    
    	// node type: static, regexp, param, catchAll
    	typ nodeTyp
    
    	// first byte of the prefix
    	label byte
    }
    
    type endpoint struct {
    	// endpoint handler
    
    a's avatar
    a committed
    	handler jsonrpc.Handler
    
    a's avatar
    rpc
    a committed
    
    	// pattern is the routing pattern for handler nodes
    	pattern string
    
    	// parameter keys recorded on handler nodes
    	paramKeys []string
    }
    
    
    a's avatar
    a committed
    func (n *node) InsertRoute(pattern string, handler jsonrpc.Handler) *node {
    
    a's avatar
    rpc
    a committed
    	var parent *node
    	search := pattern
    	for {
    		// Handle key exhaustion
    		if len(search) == 0 {
    			// Insert or update the node's leaf handler
    			n.setEndpoint(handler, pattern)
    			return n
    		}
    
    		// We're going to be searching for a wild node next,
    		// in this case, we need to get the tail
    		var label = search[0]
    		var segTail byte
    		var segEndIdx int
    		var segTyp nodeTyp
    		var segRexpat string
    		if label == '{' || label == '*' {
    			segTyp, _, segRexpat, segTail, _, segEndIdx = patNextSegment(search)
    		}
    
    		var prefix string
    		if segTyp == ntRegexp {
    			prefix = segRexpat
    		}
    
    		// Look for the edge to attach to
    		parent = n
    		n = n.getEdge(segTyp, label, segTail, prefix)
    
    		// No edge, create one
    		if n == nil {
    			child := &node{label: label, tail: segTail, prefix: search}
    			hn := parent.addChild(child, search)
    			hn.setEndpoint(handler, pattern)
    			return hn
    		}
    
    		// Found an edge to match the pattern
    
    		if n.typ > ntStatic {
    			// We found a param node, trim the param from the search path and continue.
    			// This param/wild pattern segment would already be on the tree from a previous
    			// call to addChild when creating a new node.
    			search = search[segEndIdx:]
    			continue
    		}
    
    		// Static nodes fall below here.
    		// Determine longest prefix of the search key on match.
    		commonPrefix := longestPrefix(search, n.prefix)
    		if commonPrefix == len(n.prefix) {
    			// the common prefix is as long as the current node's prefix we're attempting to insert.
    			// keep the search going.
    			search = search[commonPrefix:]
    			continue
    		}
    
    		// Split the node
    		child := &node{
    			typ:    ntStatic,
    			prefix: search[:commonPrefix],
    		}
    		parent.replaceChild(search[0], segTail, child)
    
    		// Restore the existing node
    		n.label = n.prefix[commonPrefix]
    		n.prefix = n.prefix[commonPrefix:]
    		child.addChild(n, n.prefix)
    
    		// If the new key is a subset, set the method/handler on this node and finish.
    		search = search[commonPrefix:]
    		if len(search) == 0 {
    			child.setEndpoint(handler, pattern)
    			return child
    		}
    
    		// Create a new edge for the node
    		subchild := &node{
    			typ:    ntStatic,
    			label:  search[0],
    			prefix: search,
    		}
    		hn := child.addChild(subchild, search)
    		hn.setEndpoint(handler, pattern)
    		return hn
    	}
    }
    
    // addChild appends the new `child` node to the tree using `prefix` as the trie key
    // and returns the node that ends up holding the final segment of `prefix` (the
    // handler leaf), which is `child` itself unless the key had to be split.
    // For a URL router like chi's, we split the static, param, regexp and wildcard segments
    // into different nodes. In addition, addChild will recursively call itself until every
    // pattern segment is added to the url pattern tree as individual nodes, depending on type.
    // It panics when a regexp segment does not compile.
    func (n *node) addChild(child *node, prefix string) *node {
    
    	search := prefix
    
    	// handler leaf node added to the tree is the child.
    	// this may be overridden later down the flow
    	hn := child
    
    	// Parse next segment
    	segTyp, _, segRexpat, segTail, segStartIdx, segEndIdx := patNextSegment(search)
    
    	// Add child depending on next up segment
    	switch segTyp {
    
    	case ntStatic:
    		// Search prefix is all static (that is, has no params in path)
    		// noop
    
    	case ntRegexp:
    		rex, err := regexp.Compile(segRexpat)
    		if err != nil {
    			panic(fmt.Sprintf("rpc: invalid regexp pattern '%s' in route param", segRexpat))
    		}
    		// regexp nodes keep the expression source as their prefix so getEdge can tell them apart
    		child.prefix = segRexpat
    		child.rex = rex
    
    		fallthrough
    	default:
    		// Search prefix contains a param, regexp or wildcard
    
    		if segStartIdx == 0 {
    			// Route starts with a param
    			child.typ = segTyp
    
    			// a catch-all consumes the rest of the key; other params end at segEndIdx
    			if segTyp == ntCatchAll {
    				segStartIdx = -1
    			} else {
    				segStartIdx = segEndIdx
    			}
    			if segStartIdx < 0 {
    				segStartIdx = len(search)
    			}
    			child.tail = segTail // for params, we set the tail
    
    			if segStartIdx != len(search) {
    				// add static edge for the remaining part, split the end.
    				// its not possible to have adjacent param nodes, so its certainly
    				// going to be a static node next.
    
    				search = search[segStartIdx:] // advance search position
    
    				nn := &node{
    					typ:    ntStatic,
    					label:  search[0],
    					prefix: search,
    				}
    				hn = child.addChild(nn, search)
    			}
    
    		} else if segStartIdx > 0 {
    			// Route has some param
    
    			// starts with a static segment
    			child.typ = ntStatic
    			child.prefix = search[:segStartIdx]
    			child.rex = nil
    
    			// add the param edge node
    			search = search[segStartIdx:]
    
    			nn := &node{
    				typ:   segTyp,
    				label: search[0],
    				tail:  segTail,
    			}
    			hn = child.addChild(nn, search)
    
    		}
    	}
    
    	// attach child to the group matching its final type and keep that group ordered
    	n.children[child.typ] = append(n.children[child.typ], child)
    	n.children[child.typ].Sort()
    	return hn
    }
    
    // replaceChild swaps the existing child that has the given label and tail
    // (searched in the group of child.typ) for child, stamping label and tail
    // onto the replacement. It panics when no such child exists.
    func (n *node) replaceChild(label, tail byte, child *node) {
    	siblings := n.children[child.typ]
    	for i, cur := range siblings {
    		if cur.label != label || cur.tail != tail {
    			continue
    		}
    		child.label = label
    		child.tail = tail
    		siblings[i] = child
    		return
    	}
    	panic("rpc: replacing missing child")
    }
    
    // getEdge returns the first child of kind ntyp whose label and tail match,
    // or nil when there is none. Regexp children must also carry the same
    // expression source in their prefix.
    func (n *node) getEdge(ntyp nodeTyp, label, tail byte, prefix string) *node {
    	for _, cand := range n.children[ntyp] {
    		if cand.label != label || cand.tail != tail {
    			continue
    		}
    		if ntyp == ntRegexp && cand.prefix != prefix {
    			continue
    		}
    		return cand
    	}
    	return nil
    }
    
    
    a's avatar
    a committed
    func (n *node) setEndpoint(handler jsonrpc.Handler, pattern string) {
    
    a's avatar
    rpc
    a committed
    	paramKeys := patParamKeys(pattern)
    	n.endpoint = &endpoint{
    		handler:   handler,
    		pattern:   pattern,
    		paramKeys: paramKeys,
    	}
    }
    
    
    a's avatar
    a committed
    func (n *node) FindRoute(rctx *Context, path string) (*node, *endpoint, jsonrpc.Handler) {
    
    a's avatar
    rpc
    a committed
    	// Reset the context routing pattern and params
    	rctx.routePattern = ""
    	rctx.routeParams.Keys = rctx.routeParams.Keys[:0]
    	rctx.routeParams.Values = rctx.routeParams.Values[:0]
    
    
    a's avatar
    a committed
    	if !strings.HasPrefix(path, "/") {
    		path = "/" + path
    	}
    
    
    a's avatar
    rpc
    a committed
    	// Find the routing handlers for the path
    	rn := n.findRoute(rctx, path)
    	if rn == nil {
    		return nil, nil, nil
    	}
    
    	// Record the routing params in the request lifecycle
    	rctx.MethodParams.Keys = append(rctx.MethodParams.Keys, rctx.routeParams.Keys...)
    	rctx.MethodParams.Values = append(rctx.MethodParams.Values, rctx.routeParams.Values...)
    
    	// Record the routing pattern in the request lifecycle
    	if rn.endpoint.pattern != "" {
    		rctx.routePattern = rn.endpoint.pattern
    		rctx.RoutePatterns = append(rctx.RoutePatterns, rctx.routePattern)
    	}
    
    	return rn, rn.endpoint, rn.endpoint.handler
    }
    
    // Recursive edge traversal by checking all nodeTyp groups along the way.
    // It's like searching through a multi-dimensional radix trie.
    func (n *node) findRoute(rctx *Context, path string) *node {
    	nn := n
    	search := path
    
    	for t, nds := range nn.children {
    		ntyp := nodeTyp(t)
    		if len(nds) == 0 {
    			continue
    		}
    
    		var xn *node
    		xsearch := search
    
    		var label byte
    		if search != "" {
    			label = search[0]
    		}
    
    		switch ntyp {
    		case ntStatic:
    			xn = nds.findEdge(label)
    			if xn == nil || !strings.HasPrefix(xsearch, xn.prefix) {
    				continue
    			}
    			xsearch = xsearch[len(xn.prefix):]
    
    		case ntParam, ntRegexp:
    			// short-circuit and return no matching route for empty param values
    			if xsearch == "" {
    				continue
    			}
    
    			// serially loop through each node grouped by the tail delimiter
    			for idx := 0; idx < len(nds); idx++ {
    				xn = nds[idx]
    
    				// label for param nodes is the delimiter byte
    				p := strings.IndexByte(xsearch, xn.tail)
    
    				if p < 0 {
    
    a's avatar
    a committed
    					if xn.tail == sepRune {
    
    a's avatar
    rpc
    a committed
    						p = len(xsearch)
    					} else {
    						continue
    					}
    				} else if ntyp == ntRegexp && p == 0 {
    					continue
    				}
    
    				if ntyp == ntRegexp && xn.rex != nil {
    					if !xn.rex.MatchString(xsearch[:p]) {
    						continue
    					}
    
    a's avatar
    a committed
    				} else if strings.IndexByte(xsearch[:p], sepRune) != -1 {
    
    a's avatar
    rpc
    a committed
    					// avoid a match across path segments
    					continue
    				}
    
    				prevlen := len(rctx.routeParams.Values)
    				rctx.routeParams.Values = append(rctx.routeParams.Values, xsearch[:p])
    				xsearch = xsearch[p:]
    
    				if len(xsearch) == 0 {
    					if xn.isLeaf() {
    						h := xn.endpoint
    						if h != nil && h.handler != nil {
    							rctx.routeParams.Keys = append(rctx.routeParams.Keys, h.paramKeys...)
    							return xn
    						}
    
    						// flag that the routing context found a route, but not a corresponding
    						// supported method
    						rctx.methodNotAllowed = true
    					}
    				}
    
    				// recursively find the next node on this branch
    				fin := xn.findRoute(rctx, xsearch)
    				if fin != nil {
    					return fin
    				}
    
    				// not found on this branch, reset vars
    				rctx.routeParams.Values = rctx.routeParams.Values[:prevlen]
    				xsearch = search
    			}
    
    			rctx.routeParams.Values = append(rctx.routeParams.Values, "")
    
    		default:
    			// catch-all nodes
    			rctx.routeParams.Values = append(rctx.routeParams.Values, search)
    			xn = nds[0]
    			xsearch = ""
    		}
    
    		if xn == nil {
    			continue
    		}
    
    		// did we find it yet?
    		if len(xsearch) == 0 {
    			if xn.isLeaf() {
    				h := xn.endpoint
    				if h != nil && h.handler != nil {
    					rctx.routeParams.Keys = append(rctx.routeParams.Keys, h.paramKeys...)
    					return xn
    				}
    				// flag that the routing context found a route, but not a corresponding
    				// supported method
    				rctx.methodNotAllowed = true
    			}
    		}
    
    		// recursively find the next node..
    		fin := xn.findRoute(rctx, xsearch)
    		if fin != nil {
    			return fin
    		}
    
    		// Did not find final handler, let's remove the param here if it was set
    		if xn.typ > ntStatic {
    			if len(rctx.routeParams.Values) > 0 {
    				rctx.routeParams.Values = rctx.routeParams.Values[:len(rctx.routeParams.Values)-1]
    			}
    		}
    
    	}
    
    	return nil
    }
    
    func (n *node) findEdge(ntyp nodeTyp, label byte) *node {
    	nds := n.children[ntyp]
    	num := len(nds)
    	idx := 0
    
    	switch ntyp {
    	case ntStatic, ntParam, ntRegexp:
    		i, j := 0, num-1
    		for i <= j {
    			idx = i + (j-i)/2
    
    a's avatar
    a committed
    			switch {
    			case label > nds[idx].label:
    
    a's avatar
    rpc
    a committed
    				i = idx + 1
    
    a's avatar
    a committed
    			case label < nds[idx].label:
    
    a's avatar
    rpc
    a committed
    				j = idx - 1
    
    a's avatar
    a committed
    			default:
    
    a's avatar
    rpc
    a committed
    				i = num // breaks cond
    			}
    		}
    		if nds[idx].label != label {
    			return nil
    		}
    		return nds[idx]
    
    	default: // catch all
    		return nds[idx]
    	}
    }
    
    // isLeaf reports whether an endpoint has been set on n.
    func (n *node) isLeaf() bool {
    	return n.endpoint != nil
    }
    
    // findPattern reports whether pattern can be followed edge by edge through
    // the tree below n. Only the first child group that yields a matching edge is
    // descended into. pattern must be non-empty.
    func (n *node) findPattern(pattern string) bool {
    	for _, group := range n.children {
    		if len(group) == 0 {
    			continue
    		}
    
    		edge := n.findEdge(group[0].typ, pattern[0])
    		if edge == nil {
    			continue
    		}
    
    		// consumed is how many leading bytes of pattern this edge accounts for
    		var consumed int
    		switch edge.typ {
    		case ntStatic:
    			consumed = longestPrefix(pattern, edge.prefix)
    			if consumed < len(edge.prefix) {
    				// the edge's prefix is not fully contained in pattern; try the next group
    				continue
    			}
    
    		case ntParam, ntRegexp:
    			consumed = strings.IndexByte(pattern, '}') + 1
    
    		case ntCatchAll:
    			consumed = longestPrefix(pattern, "*")
    
    		default:
    			panic("rpc: unknown node type")
    		}
    
    		rest := pattern[consumed:]
    		if rest == "" {
    			return true
    		}
    		return edge.findPattern(rest)
    	}
    	return false
    }
    
    // routes collects one Route per node below (and including) n that carries an
    // endpoint or subroutes. The result is never nil.
    func (n *node) routes() []Route {
    	rts := []Route{}
    
    	n.walk(func(eps *endpoint, subroutes Routes) bool {
    		// walk also visits nodes that only have subroutes set; eps is nil
    		// there, so the handler and pattern are left at their zero values.
    		rt := Route{SubRoutes: subroutes}
    		if eps != nil {
    			rt.Handler = eps.handler
    			rt.Pattern = eps.pattern
    		}
    		rts = append(rts, rt)
    
    		return false
    	})
    
    	return rts
    }
    
    // walk visits n and then its descendants depth-first, group by group, calling
    // fn for every node that has an endpoint or subroutes. The traversal stops as
    // soon as fn returns true, and walk then returns true as well.
    func (n *node) walk(fn func(eps *endpoint, subroutes Routes) bool) bool {
    	// Visit this node first when it carries anything worth reporting
    	carriesLeaf := n.endpoint != nil || n.subroutes != nil
    	if carriesLeaf && fn(n.endpoint, n.subroutes) {
    		return true
    	}
    
    	// Then descend into every child of every node-type group
    	for _, group := range n.children {
    		for _, kid := range group {
    			if stop := kid.walk(fn); stop {
    				return true
    			}
    		}
    	}
    	return false
    }
    
    // patNextSegment returns the next segment details from a pattern:
    // node type, param key, regexp string, param tail byte, param starting index, param ending index
    func patNextSegment(pattern string) (nodeTyp, string, string, byte, int, int) {
    	ps := strings.Index(pattern, "{")
    	ws := strings.Index(pattern, "*")
    
    	if ps < 0 && ws < 0 {
    		return ntStatic, "", "", 0, 0, len(pattern) // we return the entire thing
    	}
    
    	// Sanity check
    	if ps >= 0 && ws >= 0 && ws < ps {
    		panic("rpc: wildcard '*' must be the last pattern in a route, otherwise use a '{param}'")
    	}
    
    
    a's avatar
    a committed
    	var tail byte = sepRune // Default endpoint tail to _ byte
    
    a's avatar
    rpc
    a committed
    
    	if ps >= 0 {
    		// Param/Regexp pattern is next
    		nt := ntParam
    
    		// Read to closing } taking into account opens and closes in curl count (cc)
    		cc := 0
    		pe := ps
    		for i, c := range pattern[ps:] {
    			if c == '{' {
    				cc++
    			} else if c == '}' {
    				cc--
    				if cc == 0 {
    					pe = ps + i
    					break
    				}
    			}
    		}
    		if pe == ps {
    			panic("rpc: route param closing delimiter '}' is missing")
    		}
    
    		key := pattern[ps+1 : pe]
    		pe++ // set end to next position
    
    		if pe < len(pattern) {
    			tail = pattern[pe]
    		}
    
    		var rexpat string
    		if idx := strings.Index(key, ":"); idx >= 0 {
    			nt = ntRegexp
    			rexpat = key[idx+1:]
    			key = key[:idx]
    		}
    
    		if len(rexpat) > 0 {
    			if rexpat[0] != '^' {
    				rexpat = "^" + rexpat
    			}
    			if rexpat[len(rexpat)-1] != '$' {
    				rexpat += "$"
    			}
    		}
    
    		return nt, key, rexpat, tail, ps, pe
    	}
    
    	// Wildcard pattern as finale
    	if ws < len(pattern)-1 {
    		panic("rpc: wildcard '*' must be the last value in a route. trim trailing text or use a '{param}' instead")
    	}
    	return ntCatchAll, "*", "", 0, ws, len(pattern)
    }
    
    // patParamKeys returns the param keys of pattern in order of appearance,
    // with a catch-all contributing the key "*". The result is never nil.
    // It panics when the same key occurs twice in the pattern.
    func patParamKeys(pattern string) []string {
    	keys := []string{}
    	for rest := pattern; ; {
    		typ, key, _, _, _, end := patNextSegment(rest)
    		if typ == ntStatic {
    			// nothing but static text left
    			return keys
    		}
    		for _, seen := range keys {
    			if seen == key {
    				panic(fmt.Sprintf("rpc: routing pattern '%s' contains duplicate param key, '%s'", pattern, key))
    			}
    		}
    		keys = append(keys, key)
    		rest = rest[end:]
    	}
    }
    
    // longestPrefix returns how many leading bytes k1 and k2 have in common.
    func longestPrefix(k1, k2 string) int {
    	// only the overlap of the two strings can match
    	limit := len(k2)
    	if len(k1) < limit {
    		limit = len(k1)
    	}
    	for i := 0; i < limit; i++ {
    		if k1[i] != k2[i] {
    			return i
    		}
    	}
    	return limit
    }
    
    // nodes is a list of sibling nodes; it implements sort.Interface keyed on label.
    type nodes []*node
    
    // Sort orders the list by label and then lets tailSort adjust the result.
    func (ns nodes) Sort() {
    	sort.Sort(ns)
    	ns.tailSort()
    }
    
    // Len returns the number of nodes in the list.
    func (ns nodes) Len() int {
    	return len(ns)
    }
    
    // Swap exchanges the nodes at positions i and j.
    func (ns nodes) Swap(i, j int) {
    	ns[j], ns[i] = ns[i], ns[j]
    }
    
    // Less reports whether the label at i sorts before the label at j.
    func (ns nodes) Less(i, j int) bool {
    	a, b := ns[i], ns[j]
    	return a.label < b.label
    }
    
    
    a's avatar
    a committed
    // tailSort pushes nodes with sepRune as the tail to the end of the list for param nodes.
    
    a's avatar
    rpc
    a committed
    // The list order determines the traversal order.
    func (ns nodes) tailSort() {
    	for i := len(ns) - 1; i >= 0; i-- {
    
    a's avatar
    a committed
    		if ns[i].typ > ntStatic && ns[i].tail == sepRune {
    
    a's avatar
    rpc
    a committed
    			ns.Swap(i, len(ns)-1)
    			return
    		}
    	}
    }
    
    func (ns nodes) findEdge(label byte) *node {
    	num := len(ns)
    	idx := 0
    	i, j := 0, num-1
    	for i <= j {
    		idx = i + (j-i)/2
    
    a's avatar
    a committed
    		switch {
    		case label > ns[idx].label:
    
    a's avatar
    rpc
    a committed
    			i = idx + 1
    
    a's avatar
    a committed
    		case label < ns[idx].label:
    
    a's avatar
    rpc
    a committed
    			j = idx - 1
    
    a's avatar
    a committed
    		default:
    
    a's avatar
    rpc
    a committed
    			i = num // breaks cond
    		}
    	}
    	if ns[idx].label != label {
    		return nil
    	}
    	return ns[idx]
    }
    
    // Route describes the details of a routing handler.
    type Route struct {
    	SubRoutes Routes
    
    a's avatar
    a committed
    	Handler   jsonrpc.Handler
    
    a's avatar
    rpc
    a committed
    	Pattern   string
    }
    
    // WalkFunc is the type of the function called for each method and route visited by Walk.
    
    a's avatar
    a committed
    type WalkFunc func(route string, handler jsonrpc.Handler, middlewares ...func(jsonrpc.Handler) jsonrpc.Handler) error
    
    a's avatar
    rpc
    a committed
    
    // Walk walks any router tree that implements Routes interface, calling walkFn
    // for each route. It starts with an empty parent route and returns the first
    // error produced by walkFn, if any.
    func Walk(r Routes, walkFn WalkFunc) error {
    	return walk(r, walkFn, "")
    }
    
    
    a's avatar
    a committed
    func walk(r Routes, walkFn WalkFunc, parentRoute string, parentMw ...func(jsonrpc.Handler) jsonrpc.Handler) error {
    
    a's avatar
    rpc
    a committed
    	for _, route := range r.Routes() {
    
    a's avatar
    a committed
    		mws := make([]func(jsonrpc.Handler) jsonrpc.Handler, len(parentMw))
    
    a's avatar
    rpc
    a committed
    		copy(mws, parentMw)
    		mws = append(mws, r.Middlewares()...)
    
    		if route.SubRoutes != nil {
    			if err := walk(route.SubRoutes, walkFn, parentRoute+route.Pattern, mws...); err != nil {
    				return err
    			}
    			continue
    		}
    		handler := route.Handler
    
    
    a's avatar
    a committed
    		fullRoute := parentRoute + sepString + route.Pattern
    
    a's avatar
    a committed
    		fullRoute = strings.ReplaceAll(fullRoute, sepString+"*"+sepString, sepString)
    
    a's avatar
    rpc
    a committed
    
    		if chain, ok := handler.(*ChainHandler); ok {
    			if err := walkFn(fullRoute, chain.Endpoint, append(mws, chain.Middlewares...)...); err != nil {
    				return err
    			}
    		} else {
    			if err := walkFn(fullRoute, handler, mws...); err != nil {
    				return err
    			}
    		}
    	}
    
    	return nil
    }