From 00bc133d059889d6163fd5389fe32a5506ae72bf Mon Sep 17 00:00:00 2001
From: "Vojtech Vitek (V-Teq)" <vojtech.vitek@pressly.com>
Date: Wed, 24 Aug 2016 13:04:45 -0400
Subject: [PATCH] Simplify BatchInserter & Batch API

---
 internal/sqladapter/testing/adapter.go.tpl | 23 ++++----
 lib/sqlbuilder/batch.go                    | 65 ++++++++--------------
 2 files changed, 34 insertions(+), 54 deletions(-)

diff --git a/internal/sqladapter/testing/adapter.go.tpl b/internal/sqladapter/testing/adapter.go.tpl
index 933543a6..32717996 100644
--- a/internal/sqladapter/testing/adapter.go.tpl
+++ b/internal/sqladapter/testing/adapter.go.tpl
@@ -1092,15 +1092,14 @@ func TestBatchInsert(t *testing.T) {
 
 		go func() {
 			for i := 0; i < totalItems; i++ {
-				batch.Values(fmt.Sprintf("artist-%d", i))
+				batch.Values <- fmt.Sprintf("artist-%d", i)
 			}
-			batch.Done()
+			close(batch.Values)
 		}()
 
-		for q := range batch.Next() {
-			_, err = q.Exec()
-			assert.NoError(t, err)
-		}
+		err = batch.Exec()
+		assert.NoError(t, err)
+		assert.NoError(t, batch.Error())
 
 		c, err := sess.Collection("artist").Find().Count()
 		assert.NoError(t, err)
@@ -1131,14 +1130,15 @@ func TestBatchInsertReturningKeys(t *testing.T) {
 
 	go func() {
 		for i := 0; i < totalItems; i++ {
-			batch.Values(fmt.Sprintf("artist-%d", i))
+			batch.Values <- fmt.Sprintf("artist-%d", i)
 		}
-		batch.Done()
+		close(batch.Values)
 	}()
 
-	for q := range batch.Next() {
-		var keyMap []struct{ID int `db:"id"`}
-		err := q.Iterator().All(&keyMap)
+	var keyMap []struct {
+		ID int `db:"id"`
+	}
+	for batch.Next(&keyMap) {
 		assert.NoError(t, err)
 
 		// Each insertion must produce new keys.
@@ -1156,6 +1156,7 @@ func TestBatchInsertReturningKeys(t *testing.T) {
 		assert.NoError(t, err)
 		assert.Equal(t, uint64(len(keyMap)), c)
 	}
+	assert.NoError(t, batch.Error())
 
 	// Count all new elements
 	c, err := sess.Collection("artist").Find().Count()
diff --git a/lib/sqlbuilder/batch.go b/lib/sqlbuilder/batch.go
index f35bb048..1148c8a7 100644
--- a/lib/sqlbuilder/batch.go
+++ b/lib/sqlbuilder/batch.go
@@ -1,15 +1,10 @@
 package sqlbuilder
 
-import (
-	"sync"
-)
-
 type BatchInserter struct {
 	inserter *inserter
 	size     int
-	values   [][]interface{}
-	next     chan Inserter
-	mu       sync.Mutex
+	Values   chan interface{}
+	err      error
 }
 
 func newBatchInserter(inserter *inserter, size int) *BatchInserter {
@@ -19,51 +14,35 @@ func newBatchInserter(inserter *inserter, size int) *BatchInserter {
 	b := &BatchInserter{
 		inserter: inserter,
 		size:     size,
-		next:     make(chan Inserter),
+		Values:   make(chan interface{}, size),
 	}
-	b.reset()
 	return b
 }
 
-func (b *BatchInserter) reset() {
-	b.values = make([][]interface{}, 0, b.size)
-}
-
-func (b *BatchInserter) flush() {
-	if len(b.values) > 0 {
-		clone := b.inserter.clone()
-		for i := range b.values {
-			clone.Values(b.values[i]...)
+func (b *BatchInserter) Next(dst interface{}) bool {
+	clone := b.inserter.clone()
+	i := 0
+	for value := range b.Values {
+		i++
+		clone.Values(value)
+		if b.size == i {
+			break
 		}
-		b.next <- clone
-		b.reset()
 	}
-}
-
-// Values pushes a value to be inserted as part of the batch.
-func (b *BatchInserter) Values(values ...interface{}) *BatchInserter {
-	b.mu.Lock()
-	defer b.mu.Unlock()
-
-	b.values = append(b.values, values)
-	if len(b.values) >= b.size {
-		b.flush()
+	if i == 0 {
+		return false
 	}
-	return b
+	b.err = clone.Iterator().All(dst)
+	return (b.err == nil)
 }
 
-// Next returns a channel that receives new q elements.
-func (b *BatchInserter) Next() chan Inserter {
-	b.mu.Lock()
-	defer b.mu.Unlock()
-
-	return b.next
+func (b *BatchInserter) Exec() error {
+	var nop []struct{}
+	for b.Next(&nop) {
+	}
+	return b.err
 }
 
-func (b *BatchInserter) Done() {
-	b.mu.Lock()
-	defer b.mu.Unlock()
-
-	b.flush()
-	close(b.next)
+func (b *BatchInserter) Error() error {
+	return b.err
 }
-- 
GitLab