From 5cd9b4a8c4b5e64f4314047b9ad65dd7d9e81eaf Mon Sep 17 00:00:00 2001
From: "Vojtech Vitek (V-Teq)" <vojtech.vitek@pressly.com>
Date: Wed, 24 Aug 2016 15:15:27 -0400
Subject: [PATCH] Batch API improvements vol.2

---
 internal/sqladapter/testing/adapter.go.tpl | 12 +++++------
 lib/sqlbuilder/batch.go                    | 24 +++++++++++++++-------
 2 files changed, 23 insertions(+), 13 deletions(-)

diff --git a/internal/sqladapter/testing/adapter.go.tpl b/internal/sqladapter/testing/adapter.go.tpl
index d41012ea..855f6048 100644
--- a/internal/sqladapter/testing/adapter.go.tpl
+++ b/internal/sqladapter/testing/adapter.go.tpl
@@ -1091,13 +1091,13 @@ func TestBatchInsert(t *testing.T) {
 		totalItems := int(rand.Int31n(21))
 
 		go func() {
+			defer batch.Done()
 			for i := 0; i < totalItems; i++ {
-				batch.Values <- fmt.Sprintf("artist-%d", i)
+				batch.Values(fmt.Sprintf("artist-%d", i))
 			}
-			close(batch.Values)
 		}()
 
-		err = batch.Exec()
+		err = batch.Wait()
 		assert.NoError(t, err)
 		assert.NoError(t, batch.Error())
 
@@ -1129,16 +1129,16 @@ func TestBatchInsertReturningKeys(t *testing.T) {
 	batch := sess.InsertInto("artist").Columns("name").Returning("id").NewBatch(batchSize)
 
 	go func() {
+		defer batch.Done()
 		for i := 0; i < totalItems; i++ {
-			batch.Values <- fmt.Sprintf("artist-%d", i)
+			batch.Values(fmt.Sprintf("artist-%d", i))
 		}
-		close(batch.Values)
 	}()
 
 	var keyMap []struct {
 		ID int `db:"id"`
 	}
-	for batch.Next(&keyMap) {
+	for batch.NextResult(&keyMap) {
 		// Each insertion must produce new keys.
 		assert.True(t, len(keyMap) > 0)
 		assert.True(t, len(keyMap) <= batchSize)
diff --git a/lib/sqlbuilder/batch.go b/lib/sqlbuilder/batch.go
index c61b3510..79a3f8a4 100644
--- a/lib/sqlbuilder/batch.go
+++ b/lib/sqlbuilder/batch.go
@@ -3,7 +3,7 @@ package sqlbuilder
 type BatchInserter struct {
 	inserter *inserter
 	size     int
-	Values   chan interface{}
+	values   chan []interface{}
 	err      error
 }
 
@@ -14,17 +14,23 @@ func newBatchInserter(inserter *inserter, size int) *BatchInserter {
 	b := &BatchInserter{
 		inserter: inserter,
 		size:     size,
-		Values:   make(chan interface{}, size),
+		values:   make(chan []interface{}, size),
 	}
 	return b
 }
 
-func (b *BatchInserter) Next(dst interface{}) bool {
+// Values pushes column values to be inserted as part of the batch.
+func (b *BatchInserter) Values(values ...interface{}) *BatchInserter {
+	b.values <- values
+	return b
+}
+
+func (b *BatchInserter) NextResult(dst interface{}) bool {
 	clone := b.inserter.clone()
 	i := 0
-	for value := range b.Values {
+	for values := range b.values {
 		i++
-		clone.Values(value)
+		clone.Values(values...)
 		if i == b.size {
 			break
 		}
@@ -36,9 +42,13 @@ func (b *BatchInserter) Next(dst interface{}) bool {
 	return (b.err == nil)
 }
 
-func (b *BatchInserter) Exec() error {
+func (b *BatchInserter) Done() {
+	close(b.values)
+}
+
+func (b *BatchInserter) Wait() error {
 	var nop []struct{}
-	for b.Next(&nop) {
+	for b.NextResult(&nop) {
 	}
 	return b.err
 }
-- 
GitLab