diff --git a/lib/rob/schedulers/v2/pool/pool.go b/lib/rob/schedulers/v2/pool/pool.go
index 62f3124fd1368b3a9544a8a35df7586843161473..3c73b10968296c598b881508597296eee39e75c7 100644
--- a/lib/rob/schedulers/v2/pool/pool.go
+++ b/lib/rob/schedulers/v2/pool/pool.go
@@ -106,11 +106,16 @@ func (T *Pool) stealFor(id uuid.UUID) {
 			continue
 		}
 		works, ok := s.Steal(constraints)
+		if !ok {
+			continue
+		}
+		if len(works) > 0 {
+			source := works[0].Source
+			T.affinity[source] = id
+		}
 		for _, work := range works {
 			snk.Queue(work)
 		}
-		if ok {
-			break
-		}
+		break
 	}
 }
diff --git a/lib/rob/schedulers/v2/scheduler_test.go b/lib/rob/schedulers/v2/scheduler_test.go
index cb1445b849f75d0c839280b5ba336dfc5c3b7f9a..48ba7689b5b988ff71278ef4030327fcae8318f4 100644
--- a/lib/rob/schedulers/v2/scheduler_test.go
+++ b/lib/rob/schedulers/v2/scheduler_test.go
@@ -5,12 +5,15 @@ import (
 	"sync"
 	"testing"
 	"time"
+
+	"pggat2/lib/rob"
 )
 
 type Work struct {
-	Sender   int
-	Duration time.Duration
-	Done     chan<- struct{}
+	Sender      int
+	Duration    time.Duration
+	Done        chan<- struct{}
+	Constraints rob.Constraints
 }
 
 type ShareTable struct {
@@ -36,12 +39,15 @@ func (T *ShareTable) Get(user int) int {
 	return v
 }
 
-func testSink(sched *Scheduler, table *ShareTable) {
-	sink := sched.NewSink(0)
+func testSink(sched *Scheduler, table *ShareTable, constraints rob.Constraints) {
+	sink := sched.NewSink(constraints)
 	for {
 		w := sink.Read()
 		switch v := w.(type) {
 		case Work:
+			if !constraints.Satisfies(v.Constraints) {
+				panic("Scheduler did not obey constraints")
+			}
 			// dummy load
 			start := time.Now()
 			for time.Since(start) < v.Duration {
@@ -52,16 +58,17 @@ func testSink(sched *Scheduler, table *ShareTable) {
 	}
 }
 
-func testSource(sched *Scheduler, id int, dur time.Duration) {
+func testSource(sched *Scheduler, id int, dur time.Duration, constraints rob.Constraints) {
 	source := sched.NewSource()
 	for {
 		done := make(chan struct{})
 		w := Work{
-			Sender:   id,
-			Duration: dur,
-			Done:     done,
+			Sender:      id,
+			Duration:    dur,
+			Done:        done,
+			Constraints: constraints,
 		}
-		source.Schedule(w, 0)
+		source.Schedule(w, constraints)
 		<-done
 	}
 }
@@ -109,12 +116,12 @@ func allStacks() []byte {
 func TestScheduler(t *testing.T) {
 	var table ShareTable
 	sched := NewScheduler()
-	go testSink(sched, &table)
+	go testSink(sched, &table, 0)
 
-	go testSource(sched, 0, 10*time.Millisecond)
-	go testSource(sched, 1, 10*time.Millisecond)
-	go testSource(sched, 2, 50*time.Millisecond)
-	go testSource(sched, 3, 100*time.Millisecond)
+	go testSource(sched, 0, 10*time.Millisecond, 0)
+	go testSource(sched, 1, 10*time.Millisecond, 0)
+	go testSource(sched, 2, 50*time.Millisecond, 0)
+	go testSource(sched, 3, 100*time.Millisecond, 0)
 
 	time.Sleep(20 * time.Second)
 	t0 := table.Get(0)
@@ -149,15 +156,15 @@ func TestScheduler(t *testing.T) {
 func TestScheduler_Late(t *testing.T) {
 	var table ShareTable
 	sched := NewScheduler()
-	go testSink(sched, &table)
+	go testSink(sched, &table, 0)
 
-	go testSource(sched, 0, 10*time.Millisecond)
-	go testSource(sched, 1, 10*time.Millisecond)
+	go testSource(sched, 0, 10*time.Millisecond, 0)
+	go testSource(sched, 1, 10*time.Millisecond, 0)
 
 	time.Sleep(10 * time.Second)
 
-	go testSource(sched, 2, 10*time.Millisecond)
-	go testSource(sched, 3, 10*time.Millisecond)
+	go testSource(sched, 2, 10*time.Millisecond, 0)
+	go testSource(sched, 3, 10*time.Millisecond, 0)
 
 	time.Sleep(10 * time.Second)
 	t0 := table.Get(0)
@@ -193,13 +200,13 @@ func TestScheduler_Late(t *testing.T) {
 func TestScheduler_StealBalanced(t *testing.T) {
 	var table ShareTable
 	sched := NewScheduler()
-	go testSink(sched, &table)
-	go testSink(sched, &table)
+	go testSink(sched, &table, 0)
+	go testSink(sched, &table, 0)
 
-	go testSource(sched, 0, 10*time.Millisecond)
-	go testSource(sched, 1, 10*time.Millisecond)
-	go testSource(sched, 2, 10*time.Millisecond)
-	go testSource(sched, 3, 10*time.Millisecond)
+	go testSource(sched, 0, 10*time.Millisecond, 0)
+	go testSource(sched, 1, 10*time.Millisecond, 0)
+	go testSource(sched, 2, 10*time.Millisecond, 0)
+	go testSource(sched, 3, 10*time.Millisecond, 0)
 
 	time.Sleep(20 * time.Second)
 	t0 := table.Get(0)
@@ -230,12 +237,12 @@ func TestScheduler_StealBalanced(t *testing.T) {
 func TestScheduler_StealUnbalanced(t *testing.T) {
 	var table ShareTable
 	sched := NewScheduler()
-	go testSink(sched, &table)
-	go testSink(sched, &table)
+	go testSink(sched, &table, 0)
+	go testSink(sched, &table, 0)
 
-	go testSource(sched, 0, 10*time.Millisecond)
-	go testSource(sched, 1, 10*time.Millisecond)
-	go testSource(sched, 2, 10*time.Millisecond)
+	go testSource(sched, 0, 10*time.Millisecond, 0)
+	go testSource(sched, 1, 10*time.Millisecond, 0)
+	go testSource(sched, 2, 10*time.Millisecond, 0)
 
 	time.Sleep(20 * time.Second)
 	t0 := table.Get(0)
@@ -260,3 +267,45 @@ func TestScheduler_StealUnbalanced(t *testing.T) {
 		t.Errorf("%s", allStacks())
 	}
 }
+
+func TestScheduler_Constraints(t *testing.T) {
+	const (
+		ConstraintA rob.Constraints = 1 << iota
+		ConstraintB
+	)
+
+	var table ShareTable
+	sched := NewScheduler()
+
+	go testSink(sched, &table, rob.Constraints.All(ConstraintA, ConstraintB))
+	go testSink(sched, &table, ConstraintA)
+	go testSink(sched, &table, ConstraintB)
+
+	go testSource(sched, 0, 10*time.Millisecond, rob.Constraints.All(ConstraintA, ConstraintB))
+	go testSource(sched, 1, 10*time.Millisecond, rob.Constraints.All(ConstraintA, ConstraintB))
+	go testSource(sched, 2, 10*time.Millisecond, ConstraintA)
+	go testSource(sched, 3, 10*time.Millisecond, ConstraintA)
+	go testSource(sched, 4, 10*time.Millisecond, ConstraintB)
+	go testSource(sched, 5, 10*time.Millisecond, ConstraintB)
+
+	time.Sleep(20 * time.Second)
+	t0 := table.Get(0)
+	t1 := table.Get(1)
+	t2 := table.Get(2)
+	t3 := table.Get(3)
+	t4 := table.Get(4)
+	t5 := table.Get(5)
+
+	/*
+		Expectations:
+		- all users should get similar # of executions (shares of 0 and 1 may be less because they have less sinks they can use: 2 vs 4)
+		- all constraints should be honored
+	*/
+
+	t.Log("share of 0:", t0)
+	t.Log("share of 1:", t1)
+	t.Log("share of 2:", t2)
+	t.Log("share of 3:", t3)
+	t.Log("share of 4:", t4)
+	t.Log("share of 5:", t5)
+}