diff --git a/go.mod b/go.mod index fd75d65b028cfc2389d48d956dc928ff29eb334a..a2619c25dd6e8900eeab224602d851b900cfe271 100644 --- a/go.mod +++ b/go.mod @@ -33,11 +33,13 @@ require ( github.com/kr/pretty v0.2.0 // indirect github.com/kr/text v0.2.0 // indirect github.com/lib/pq v1.9.0 // indirect + github.com/looplab/fsm v0.3.0 // indirect github.com/mattn/go-colorable v0.1.12 // indirect github.com/mattn/go-isatty v0.0.14 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/sirupsen/logrus v1.6.0 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/theckman/go-fsm v0.0.2 // indirect github.com/xdg-go/pbkdf2 v1.0.0 // indirect github.com/xdg-go/stringprep v1.0.3 // indirect golang.org/x/exp v0.0.0-20220827204233-334a2380cb91 // indirect diff --git a/go.sum b/go.sum index 1efff19f0129c2cf2a74831221cbbfd9d239875e..5c5c65ead0e3f3110fd772aebe66d45edce88795 100644 --- a/go.sum +++ b/go.sum @@ -162,6 +162,8 @@ github.com/labstack/echo/v4 v4.1.11/go.mod h1:i541M3Fj6f76NZtHSj7TXnyM8n2gaodfvf github.com/labstack/gommon v0.3.0/go.mod h1:MULnywXg0yavhxWKc+lOruYdAhDwPK9wf0OL7NoOu+k= github.com/lib/pq v1.9.0 h1:L8nSXQQzAYByakOFMTwpjRoHsMJklur4Gi59b6VivR8= github.com/lib/pq v1.9.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/looplab/fsm v0.3.0 h1:kIgNS3Yyud1tyxhG8kDqh853B7QqwnlWdgL3TD2s3Sw= +github.com/looplab/fsm v0.3.0/go.mod h1:PmD3fFvQEIsjMEfvZdrCDZ6y8VwKTwWNjlpEr6IKPO4= github.com/magiconair/properties v1.8.0/go.mod h1:PppfXfuXeibc/6YijjN8zIbojt8czPbwD3XqdrwzmxQ= github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40= @@ -228,6 +230,8 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.2 h1:4jaiDzPyXQvSd7D0EjG45355tLlV3VOECpq10pLC+8s= +github.com/theckman/go-fsm v0.0.2 h1:KdFn95Si2ATAGWzExxCONuwZiY3caSPITlMBTw4Y3VI= +github.com/theckman/go-fsm v0.0.2/go.mod h1:hN13NqBn5Mf9MbGIw1ToT3dMtRa36//yr75uRI2vha8= github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc= github.com/ugorji/go/codec v0.0.0-20181204163529-d75b2dcb6bc8/go.mod h1:VFNgLljTbGfSG7qAOspJ7OScBnGdDN/yBr0sguwnwf0= github.com/urfave/negroni v1.0.0/go.mod h1:Meg73S6kFm/4PpbYdq35yYWoCZ9mS/YSx+lKnmiohz4= diff --git a/lib/util/cmux/cmux.go b/lib/util/cmux/cmux.go new file mode 100644 index 0000000000000000000000000000000000000000..bddefe836b26d210d1f1c937faa81398b4909602 --- /dev/null +++ b/lib/util/cmux/cmux.go @@ -0,0 +1,104 @@ +package cmux + +import ( + "strings" + "sync" + + "github.com/looplab/fsm" +) + +type Mux[T any] interface { + Register([]string, func([]string) T) + + Call([]string) T +} + +type funcSet[T any] struct { + Ref []string + Call func([]string) T +} + +type FsmMux[T any] struct { + f *fsm.FSM + funcs map[string]funcSet[T] + + sync.RWMutex +} + +func (f *FsmMux[T]) Register(path []string, fn func([]string) T) { + execkey := strings.Join(path, "|") + f.funcs[execkey] = funcSet[T]{ + Ref: path, + Call: fn, + } + f.construct() +} + +func (f *FsmMux[T]) construct() { + evts := fsm.Events{} + cbs := fsm.Callbacks{} + for _, fset := range f.funcs { + path := fset.Ref + lp := len(path) + switch lp { + case 0: + case 1: + evts = append(evts, fsm.EventDesc{ + Name: path[0], + Src: []string{}, + Dst: "", + }) + default: + for i := 0; i < (len(path) - 1); i++ { + evts = append(evts, fsm.EventDesc{ + Name: path[i], + Src: []string{}, + Dst: path[i+1], + }) + } + } + } + f.f = fsm.NewFSM("", evts, cbs) +} + +func (f *FsmMux[T]) Call(k []string) T { + fn := f.funcs[""].Call + args := k + path := k + lp := len(path) + switch lp { + case 0: + case 1: + args = args[1:] + fn = f.funcs[k[0]].Call + default: + f.Lock() + f.f.SetState(path[0]) + for i := 1; i < len(path); i++ { + if f.f.Can(path[i]) { + f.f.Event(path[i]) + } else { + key := strings.Join(path[:i], "|") + if mb, ok := f.funcs[key]; ok { + fn = mb.Call + args = args[i:] + break + } + } + } + f.Unlock() + } + return fn(args) +} + +func NewFsmMux[T any]() Mux[T] { + o := &FsmMux[T]{ + funcs: map[string]funcSet[T]{ + "": { + Ref: []string{}, + Call: func([]string) T { return *new(T) }, + }, + }, + } + return o +} diff --git a/lib/util/cmux/cmux_test.go b/lib/util/cmux/cmux_test.go new file mode 100644 index 0000000000000000000000000000000000000000..80af8e1680b8cf790e2178b2d4aeb776e08133ae --- /dev/null +++ b/lib/util/cmux/cmux_test.go @@ -0,0 +1,22 @@ +package cmux + +import ( + "log" + "testing" +) + +func TestFsm(t *testing.T) { + m := NewFsmMux[error]() + + m.Register([]string{"set", "shard", "to"}, func(s []string) error { + log.Println(s) + return nil + }) + m.Register([]string{"set", "sharding", "key", "to"}, func(s []string) error { + log.Println(s) + return nil + }) + + m.Call([]string{"set", "shard", "to", "doggo", "wow", "this", "works"}) + +} diff --git a/lib/util/fsm/fsm.go b/lib/util/fsm/fsm.go new file mode 100644 index 0000000000000000000000000000000000000000..225a120e8bef75c3ae784563e59bce51975afeb9 --- /dev/null +++ b/lib/util/fsm/fsm.go @@ -0,0 +1,231 @@ +package fsm + +import ( + "fmt" + "sync" +) + +// TransitionRuleSet is a set of allowed transitions. This uses map of struct{} +// to implement a set. +type TransitionRuleSet map[string]struct{} + +// Copy copies the TransitionRuleSet in to a different TransitionRuleSet. +func (trs TransitionRuleSet) Copy() TransitionRuleSet { + srt := make(TransitionRuleSet) + + for rule, value := range trs { + srt[rule] = value + } + + return srt +} + +// CallbackHandler is an interface type defining the interface for receiving callbacks. +type CallbackHandler interface { + StateTransitionCallback(string) error +} + +// Machine is the state machine. +type Machine struct { + state string + mu sync.RWMutex + + transitions map[string]TransitionRuleSet + + callback CallbackHandler + syncCallback bool +} + +func (m *Machine) Clone() *Machine { + return &Machine{ + state: m.state, + transitions: m.transitions, + callback: m.callback, + syncCallback: m.syncCallback, + } +} + +// CurrentState returns the machine's current state. If the State returned is +// "", then the machine has not been given an initial state. +func (m *Machine) CurrentState() string { + m.mu.RLock() + defer m.mu.RUnlock() + + return m.state +} + +// StateTransitionRules returns the allowed states for +func (m *Machine) StateTransitionRules(state string) (TransitionRuleSet, error) { + m.mu.RLock() + defer m.mu.RUnlock() + + if m.transitions == nil { + return nil, newErrorStruct("the machine has not been fully initialized", ErrorMachineNotInitialized) + } + + // ensure the state has been registered + if _, ok := m.transitions[state]; !ok { + return nil, newErrorStruct(fmt.Sprintf("state %s has not been registered", state), ErrorStateUndefined) + } + + return m.transitions[state].Copy(), nil +} + +// AddStateTransitionRules is a function for adding valid state transitions to the machine. +// This allows you to define which states any given state can be transitioned to. +func (m *Machine) AddStateTransitionRules(sourceState string, destinationStates ...string) error { + m.mu.Lock() + defer m.mu.Unlock() + + // if the transitions map is nil, we need to allocate it + if m.transitions == nil { + m.transitions = make(map[string]TransitionRuleSet) + } + + // if the map for the source state does not exist, allocate it + if m.transitions[sourceState] == nil { + m.transitions[sourceState] = make(TransitionRuleSet) + } + + // get a reference to the map we care about + // avoids doing the map lookup for each iteration + mp := m.transitions[sourceState] + + for _, dest := range destinationStates { + mp[dest] = struct{}{} + } + + return nil +} + +// SetStateTransitionCallback for the state transition. This is meant to send +// callbacks back to the consumer for state changes. The callback only sends the +// new state. The synchonous parameter indicates whether the callback is done +// synchronously with the StateTransition() call. +func (m *Machine) SetStateTransitionCallback(callback CallbackHandler, synchronous bool) error { + m.mu.Lock() + defer m.mu.Unlock() + + m.callback = callback + m.syncCallback = synchronous + + return nil +} + +// StateTransition triggers a transition to the toState. This function is also +// used to set the initial state of machine. +// +// Before you can transition to any state, even for the initial, you must define +// it with AddStateTransition(). If you are setting the initial state, and that +// state is not define, this will return an ErrInvalidInitialState error. +// +// When transitioning from a state, this function will return an error either +// if the state transition is not allowed, or if the destination state has +// not been defined. In both cases, it's seen as a non-permitted state transition. +func (m *Machine) StateTransition(toState string) error { + m.mu.Lock() + defer m.mu.Unlock() + + // if this is nil we cannot assume any state + if m.transitions == nil { + return newErrorStruct("the machine has no states added", ErrorMachineNotInitialized) + } + + // if the state is nothing, this is probably the initial state + if m.state == "" { + // if the state is not defined, it's invalid + if _, ok := m.transitions[toState]; !ok { + return newErrorStruct("the initial state has not been defined within the machine", ErrorStateUndefined) + } + + // set the state + m.state = toState + return nil + } + + // if we are not permitted to transition to this state... + if _, ok := m.transitions[m.state][toState]; !ok { + return newErrorStruct(fmt.Sprintf("transition from state %s to %s is not permitted", m.state, toState), ErrorTransitionNotPermitted) + } + + // if the destination state was not defined... + if _, ok := m.transitions[toState]; !ok { + return newErrorStruct(fmt.Sprintf("state %s has not been registered", toState), ErrorStateUndefined) + } + + m.state = toState + + if m.callback != nil { + if m.syncCallback { + // do not return the error + // this may be reconsidered + m.callback.StateTransitionCallback(toState) + } else { + // spin off the callback + go func() { m.callback.StateTransitionCallback(toState) }() + } + } + + return nil +} + +type ErrorCode uint + +func (e ErrorCode) String() string { + switch e { + case ErrorMachineNotInitialized: + return "MachineNotInitialized" + case ErrorTransitionNotPermitted: + return "TransitionNotPermitted" + case ErrorStateUndefined: + return "StateUndefined" + default: + return "Unknown" + } +} + +const ( + // ErrorUnknown is the default value + ErrorUnknown ErrorCode = iota + + // ErrorMachineNotInitialized is an error returned when actions are taken on + // a machine before it has been initialized. A machine is initialized by + // adding at least one state and setting it as the initial state. + ErrorMachineNotInitialized + + // ErrorTransitionNotPermitted is the error returned when trying to + // transition to an invalid state. In other words, the machine is not + // permitted to transition from the current state to the one requested. + ErrorTransitionNotPermitted + + // ErrorStateUndefined is the error returned when the requested state is + // not defined within the machine. + ErrorStateUndefined +) + +// Error is the struct representing internal errors. +// This implements the error interface +type Error struct { + message string + code ErrorCode +} + +// newErrorStruct uses messge and code to create an *Error struct. The *Error +// struct implements the 'error' interface, so it should be able to be used +// wherever 'error' is expected. +func newErrorStruct(message string, code ErrorCode) *Error { + return &Error{ + message: message, + code: code, + } +} + +// Message returns the error message. +func (e *Error) Message() string { return e.message } + +// Code returns the error code. +func (e *Error) Code() ErrorCode { return e.code } + +func (e *Error) Error() string { + return fmt.Sprintf("%s (%d): %s", e.code, e.code, e.message) +}