From 5cd9b4a8c4b5e64f4314047b9ad65dd7d9e81eaf Mon Sep 17 00:00:00 2001 From: "Vojtech Vitek (V-Teq)" <vojtech.vitek@pressly.com> Date: Wed, 24 Aug 2016 15:15:27 -0400 Subject: [PATCH] Batch API improvements vol.2 --- internal/sqladapter/testing/adapter.go.tpl | 12 +++++------ lib/sqlbuilder/batch.go | 24 +++++++++++++++------- 2 files changed, 23 insertions(+), 13 deletions(-) diff --git a/internal/sqladapter/testing/adapter.go.tpl b/internal/sqladapter/testing/adapter.go.tpl index d41012ea..855f6048 100644 --- a/internal/sqladapter/testing/adapter.go.tpl +++ b/internal/sqladapter/testing/adapter.go.tpl @@ -1091,13 +1091,13 @@ func TestBatchInsert(t *testing.T) { totalItems := int(rand.Int31n(21)) go func() { + defer batch.Done() for i := 0; i < totalItems; i++ { - batch.Values <- fmt.Sprintf("artist-%d", i) + batch.Values(fmt.Sprintf("artist-%d", i)) } - close(batch.Values) }() - err = batch.Exec() + err = batch.Wait() assert.NoError(t, err) assert.NoError(t, batch.Error()) @@ -1129,16 +1129,16 @@ func TestBatchInsertReturningKeys(t *testing.T) { batch := sess.InsertInto("artist").Columns("name").Returning("id").NewBatch(batchSize) go func() { + defer batch.Done() for i := 0; i < totalItems; i++ { - batch.Values <- fmt.Sprintf("artist-%d", i) + batch.Values(fmt.Sprintf("artist-%d", i)) } - close(batch.Values) }() var keyMap []struct { ID int `db:"id"` } - for batch.Next(&keyMap) { + for batch.NextResult(&keyMap) { // Each insertion must produce new keys. assert.True(t, len(keyMap) > 0) assert.True(t, len(keyMap) <= batchSize) diff --git a/lib/sqlbuilder/batch.go b/lib/sqlbuilder/batch.go index c61b3510..79a3f8a4 100644 --- a/lib/sqlbuilder/batch.go +++ b/lib/sqlbuilder/batch.go @@ -3,7 +3,7 @@ package sqlbuilder type BatchInserter struct { inserter *inserter size int - Values chan interface{} + values chan []interface{} err error } @@ -14,17 +14,23 @@ func newBatchInserter(inserter *inserter, size int) *BatchInserter { b := &BatchInserter{ inserter: inserter, size: size, - Values: make(chan interface{}, size), + values: make(chan []interface{}, size), } return b } -func (b *BatchInserter) Next(dst interface{}) bool { +// Values pushes column values to be inserted as part of the batch. +func (b *BatchInserter) Values(values ...interface{}) *BatchInserter { + b.values <- values + return b +} + +func (b *BatchInserter) NextResult(dst interface{}) bool { clone := b.inserter.clone() i := 0 - for value := range b.Values { + for values := range b.values { i++ - clone.Values(value) + clone.Values(values...) if i == b.size { break } @@ -36,9 +42,13 @@ func (b *BatchInserter) Next(dst interface{}) bool { return (b.err == nil) } -func (b *BatchInserter) Exec() error { +func (b *BatchInserter) Done() { + close(b.values) +} + +func (b *BatchInserter) Wait() error { var nop []struct{} - for b.Next(&nop) { + for b.NextResult(&nop) { } return b.err } -- GitLab