From dc26a3c448906546fc6311369e92b31e9197eaa1 Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Tue, 2 May 2023 12:39:19 -0500
Subject: [PATCH] safety push before I mess everything up

---
 lib/rob/schedulers/v2/job/job.go       | 13 ++++
 lib/rob/schedulers/v2/pool/pool.go     | 56 ++++++------------
 lib/rob/schedulers/v2/sink/sink.go     | 82 ++++++++++++++++----------
 lib/rob/schedulers/v2/source/source.go |  7 ++-
 lib/util/rbtree/rbtree.go              | 33 ++++++++++-
 lib/util/rbtree/rbtree_test.go         | 29 +++++++++
 6 files changed, 148 insertions(+), 72 deletions(-)
 create mode 100644 lib/rob/schedulers/v2/job/job.go

diff --git a/lib/rob/schedulers/v2/job/job.go b/lib/rob/schedulers/v2/job/job.go
new file mode 100644
index 00000000..2d5dae90
--- /dev/null
+++ b/lib/rob/schedulers/v2/job/job.go
@@ -0,0 +1,13 @@
+package job
+
+import (
+	"github.com/google/uuid"
+
+	"pggat2/lib/rob"
+)
+
+type Job struct { // Job is one unit of work handed from a source to the pool and its sinks.
+	Source      uuid.UUID       // id of the submitting source; keys pool affinity and sink stride/pending state
+	Work        any             // opaque payload; a sink's Read hands it back unchanged
+	Constraints rob.Constraints // requirements a sink's constraints must satisfy to take this job
+}
diff --git a/lib/rob/schedulers/v2/pool/pool.go b/lib/rob/schedulers/v2/pool/pool.go
index 96d9bbf5..46d4fed9 100644
--- a/lib/rob/schedulers/v2/pool/pool.go
+++ b/lib/rob/schedulers/v2/pool/pool.go
@@ -7,24 +7,14 @@ import (
 	"github.com/google/uuid"
 
 	"pggat2/lib/rob"
+	"pggat2/lib/rob/schedulers/v2/job"
 	"pggat2/lib/rob/schedulers/v2/sink"
 )
 
-type sinkAndConstraints struct {
-	sink        *sink.Sink
-	constraints rob.Constraints
-}
-
-type job struct {
-	source      uuid.UUID
-	work        any
-	constraints rob.Constraints
-}
-
 type Pool struct {
 	affinity  map[uuid.UUID]int
-	sinks     []sinkAndConstraints
-	backorder []job
+	sinks     []*sink.Sink // every sink created through NewSink; affinity values index into this slice
+	backorder []job.Job    // jobs no current sink could take; NewSink re-offers them to each new sink
 	mu        sync.Mutex
 }
 
@@ -35,20 +25,17 @@ func MakePool() Pool {
 }
 
 func (T *Pool) NewSink(constraints rob.Constraints) *sink.Sink {
-	snk := sink.NewSink()
+	snk := sink.NewSink(constraints)
 
 	T.mu.Lock()
 	defer T.mu.Unlock()
 
-	T.sinks = append(T.sinks, sinkAndConstraints{
-		sink:        snk,
-		constraints: constraints,
-	})
+	T.sinks = append(T.sinks, snk)
 
 	i := 0
 	for _, j := range T.backorder {
-		if constraints.Satisfies(j.constraints) {
-			snk.Queue(j.source, j.work, j.constraints)
+		if constraints.Satisfies(j.Constraints) {
+			snk.Queue(j)
 		} else {
 			T.backorder[i] = j
 			i++
@@ -59,51 +46,44 @@ func (T *Pool) NewSink(constraints rob.Constraints) *sink.Sink {
 	return snk
 }
 
-func (T *Pool) Schedule(source uuid.UUID, work any, constraints rob.Constraints) {
+func (T *Pool) Schedule(work job.Job) { // queue work on a sink that satisfies it, or backorder it if none can
 	T.mu.Lock()
 	defer T.mu.Unlock()
 
 	if len(T.sinks) == 0 {
-		T.backorder = append(T.backorder, job{
-			source:      source,
-			work:        work,
-			constraints: constraints,
-		})
+		T.backorder = append(T.backorder, work) // no sinks yet: hold the job until NewSink can take it
 		return
 	}
 
-	affinity, ok := T.affinity[source]
+	affinity, ok := T.affinity[work.Source]
 	if !ok {
 		affinity = rand.Intn(len(T.sinks))
-		T.affinity[source] = affinity
+		T.affinity[work.Source] = affinity // first job from this source: start on a random sink
 	}
 
 	snk := T.sinks[affinity]
-	if !snk.sink.Idle() || !snk.constraints.Satisfies(constraints) {
+	if !snk.Constraints().Satisfies(work.Constraints) || !snk.Idle() { // lock-free constraint check runs before Idle, which takes the sink lock
 		// choose a new affinity that satisfies constraints
 		ok = false
 		for id, s := range T.sinks {
-			if s.constraints.Satisfies(constraints) {
+			if s.Constraints().Satisfies(work.Constraints) {
+				current := id == affinity // NOTE(review): affinity is reassigned two lines down, so this only matches when no earlier sink satisfied; confirm intent
 				snk = s
 				affinity = id
 				ok = true
-				if s.sink.Idle() {
+				if !current && s.Idle() {
 					// prefer idle core, if not idle try to see if we can find one that is
 					break
 				}
 			}
 		}
 		if !ok {
-			T.backorder = append(T.backorder, job{
-				source:      source,
-				work:        work,
-				constraints: constraints,
-			})
+			T.backorder = append(T.backorder, work) // no sink satisfies the job's constraints
 			return
 		}
-		T.affinity[source] = affinity
+		T.affinity[work.Source] = affinity
 	}
 
 	// yay, queued
-	snk.sink.Queue(source, work, constraints)
+	snk.Queue(work)
 }
diff --git a/lib/rob/schedulers/v2/sink/sink.go b/lib/rob/schedulers/v2/sink/sink.go
index 4f0a63d7..615f19f0 100644
--- a/lib/rob/schedulers/v2/sink/sink.go
+++ b/lib/rob/schedulers/v2/sink/sink.go
@@ -7,38 +7,40 @@ import (
 	"github.com/google/uuid"
 
 	"pggat2/lib/rob"
+	"pggat2/lib/rob/schedulers/v2/job"
 	"pggat2/lib/util/rbtree"
 	"pggat2/lib/util/ring"
 )
 
-type job struct {
-	source uuid.UUID
-	work   any
-	// need to keep track of constraints for work stealing
+type Sink struct { // Sink queues jobs and hands their work out through Read.
 	constraints rob.Constraints
-}
 
-type Sink struct {
 	active uuid.UUID
 	start  time.Time
 
 	floor time.Duration
 
 	stride    map[uuid.UUID]time.Duration
-	pending   map[uuid.UUID]*ring.Ring[job]
-	scheduled rbtree.RBTree[time.Duration, job]
+	pending   map[uuid.UUID]*ring.Ring[job.Job]     // per-source ring of jobs that could not be placed in scheduled yet
+	scheduled rbtree.RBTree[time.Duration, job.Job] // jobs Read hands out, keyed by the stride they were scheduled at
 	signal    chan struct{}
 	mu        sync.Mutex
 }
 
-func NewSink() *Sink {
+func NewSink(constraints rob.Constraints) *Sink { // NewSink returns an empty Sink that reports the given constraints.
 	return &Sink{
-		stride:  make(map[uuid.UUID]time.Duration),
-		pending: make(map[uuid.UUID]*ring.Ring[job]),
-		signal:  make(chan struct{}),
+		constraints: constraints, // returned as-is by Constraints()
+		stride:      make(map[uuid.UUID]time.Duration),
+		pending:     make(map[uuid.UUID]*ring.Ring[job.Job]),
+		signal:      make(chan struct{}),
 	}
 }
 
+func (T *Sink) Constraints() rob.Constraints { // Constraints reports the constraints passed to NewSink.
+	// no lock needed: the field is assigned in NewSink and nothing shown here writes it afterwards
+	return T.constraints
+}
+
 func (T *Sink) Idle() bool {
 	T.mu.Lock()
 	defer T.mu.Unlock()
@@ -46,27 +48,47 @@ func (T *Sink) Idle() bool {
 	return T.active == uuid.Nil
 }
 
-func (T *Sink) Queue(source uuid.UUID, work any, constraints rob.Constraints) {
+func (T *Sink) Queue(work job.Job) { // Queue schedules work now if possible, else parks it in its source's pending ring.
 	T.mu.Lock()
 	defer T.mu.Unlock()
 
-	j := job{
-		source:      source,
-		work:        work,
-		constraints: constraints,
-	}
-
 	// try to schedule right away
-	if ok := T.scheduleWork(j); ok {
+	if ok := T.scheduleWork(work); ok {
 		return
 	}
 
 	// add to pending queue
-	if _, ok := T.pending[source]; !ok {
-		T.pending[source] = new(ring.Ring[job])
+	if _, ok := T.pending[work.Source]; !ok {
+		T.pending[work.Source] = new(ring.Ring[job.Job]) // first waiting job for this source: create its ring
+	}
+
+	T.pending[work.Source].PushBack(work)
+}
+
+// Steal removes the first scheduled job, in stride order, that the given
+// constraints satisfy and returns it with its source's pending ring, detached
+// from this sink (nil if nothing is pending). The bool is false on no match.
+func (T *Sink) Steal(constraints rob.Constraints) (job.Job, *ring.Ring[job.Job], bool) {
+	T.mu.Lock()
+	defer T.mu.Unlock()
+
+	iter := T.scheduled.Iter()
+	for stride, work, ok := iter(); ok; stride, work, ok = iter() {
+		if constraints.Satisfies(work.Constraints) {
+			T.scheduled.Delete(stride)
+
+			// Queue only creates a ring once a job has to wait, so the lookup can be nil
+			pending := T.pending[work.Source]
+			if pending == nil || pending.Length() == 0 {
+				return work, nil, true
+			}
+			delete(T.pending, work.Source)
+			return work, pending, true
+		}
 	}
 
-	T.pending[source].PushBack(j)
+	// no stealable work
+	return job.Job{}, nil, false
 }
 
 // schedule the next work for source
@@ -85,21 +107,21 @@ func (T *Sink) schedule(source uuid.UUID) {
 	pending.PopFront()
 }
 
-func (T *Sink) scheduleWork(work job) bool {
-	if T.active == work.source {
+func (T *Sink) scheduleWork(work job.Job) bool {
+	if T.active == work.Source {
 		return false
 	}
 
-	stride := T.stride[work.source]
+	stride := T.stride[work.Source]
 	if stride < T.floor {
 		stride = T.floor
-		T.stride[work.source] = stride
+		T.stride[work.Source] = stride
 	}
 
 	for {
 		// find unique stride to schedule on
 		if j, ok := T.scheduled.Get(stride); ok {
-			if j.source == work.source {
+			if j.Source == work.Source {
 				return false
 			}
 			stride += 1
@@ -149,9 +171,9 @@ func (T *Sink) Read() any {
 		T.scheduled.Delete(stride)
 		T.floor = stride
 
-		T.active = j.source
+		T.active = j.Source
 		T.start = time.Now()
-		return j.work
+		return j.Work
 	}
 }
 
diff --git a/lib/rob/schedulers/v2/source/source.go b/lib/rob/schedulers/v2/source/source.go
index d6103023..bf73a1a7 100644
--- a/lib/rob/schedulers/v2/source/source.go
+++ b/lib/rob/schedulers/v2/source/source.go
@@ -4,6 +4,7 @@ import (
 	"github.com/google/uuid"
 
 	"pggat2/lib/rob"
+	"pggat2/lib/rob/schedulers/v2/job"
 	"pggat2/lib/rob/schedulers/v2/pool"
 )
 
@@ -20,7 +21,11 @@ func NewSource(p *pool.Pool) *Source {
 }
 
 func (T *Source) Schedule(work any, constraints rob.Constraints) {
-	T.pool.Schedule(T.uuid, work, constraints)
+	T.pool.Schedule(job.Job{
+		Source:      T.uuid, // tags the job with this source's id, which the pool and sinks key per-source state on
+		Work:        work,
+		Constraints: constraints,
+	})
 }
 
 var _ rob.Source = (*Source)(nil)
diff --git a/lib/util/rbtree/rbtree.go b/lib/util/rbtree/rbtree.go
index 10f69cf6..592252fe 100644
--- a/lib/util/rbtree/rbtree.go
+++ b/lib/util/rbtree/rbtree.go
@@ -184,11 +184,38 @@ func (T *RBTree[K, V]) min(n *node[K, V]) *node[K, V] {
 	return T.min(n.left)
 }
 
-type color int
+// Iter returns a function that yields the tree's entries one call at a time in
+// in-order (left, node, right) sequence; its bool is false once all are consumed.
+func (T *RBTree[K, V]) Iter() func() (K, V, bool) {
+	// TODO(garet) make this not allocate
+	remaining := T.all(T.root, nil)
+	return func() (key K, value V, ok bool) {
+		if len(remaining) == 0 {
+			// exhausted: key and value still hold their zero values
+			return key, value, false
+		}
+		next := remaining[0]
+		remaining = remaining[1:]
+		return next.key, next.value, true
+	}
+}
+
+// all appends the subtree rooted at n to slice in in-order sequence (left
+// subtree, n, right subtree) and returns the grown slice.
+func (T *RBTree[K, V]) all(n *node[K, V], slice []*node[K, V]) []*node[K, V] {
+	if n == nil {
+		// an empty subtree contributes nothing
+		return slice
+	}
+	withLeft := append(T.all(n.left, slice), n)
+	return T.all(n.right, withLeft)
+}
+
+type color bool // color of a tree node; the bool zero value (false) is black
 
 const (
-	black color = 0
-	red   color = 1
+	black color = false
+	red   color = true
 )
 
 func (T color) opposite() color {
diff --git a/lib/util/rbtree/rbtree_test.go b/lib/util/rbtree/rbtree_test.go
index 9abd4e9c..c1cf8188 100644
--- a/lib/util/rbtree/rbtree_test.go
+++ b/lib/util/rbtree/rbtree_test.go
@@ -35,6 +35,23 @@ func assertMin[K order, V comparable](t *testing.T, tree *RBTree[K, V], key K, v
 	}
 }
 
+func assertIterSome[K order, V comparable](t *testing.T, iter func() (K, V, bool), key K, value V) {
+	gotKey, gotValue, more := iter() // pull exactly one entry from the iterator
+	if !more {
+		t.Error("expected iterator to have values")
+	}
+	if gotKey != key || gotValue != value {
+		t.Error("expected key, value to be", key, value, "but got", gotKey, gotValue)
+	}
+}
+
+func assertIterNone[K order, V any](t *testing.T, iter func() (K, V, bool)) {
+	// one more call must report that the iterator is exhausted
+	if gotKey, gotValue, more := iter(); more {
+		t.Error("expected no items but got", gotKey, gotValue)
+	}
+}
+
 func TestRBTree_Insert(t *testing.T) {
 	tree := new(RBTree[int, int])
 	tree.Set(1, 2)
@@ -68,6 +85,18 @@ func TestRBTree_Min(t *testing.T) {
 	assertMin(t, tree, 5, 6)
 }
 
+func TestRBTree_Iter(t *testing.T) {
+	tree := new(RBTree[int, int])
+	for _, kv := range [][2]int{{1, 2}, {5, 6}, {3, 4}} { // inserted out of key order
+		tree.Set(kv[0], kv[1])
+	}
+	next := tree.Iter()
+	for _, want := range [][2]int{{1, 2}, {3, 4}, {5, 6}} { // expected back in ascending key order
+		assertIterSome(t, next, want[0], want[1])
+	}
+	assertIterNone(t, next)
+}
+
 func TestRBTree_Stress(t *testing.T) {
 	const n = 1000000
 
-- 
GitLab