From 00bc133d059889d6163fd5389fe32a5506ae72bf Mon Sep 17 00:00:00 2001 From: "Vojtech Vitek (V-Teq)" <vojtech.vitek@pressly.com> Date: Wed, 24 Aug 2016 13:04:45 -0400 Subject: [PATCH] Simplify BatchInserter & Batch API --- internal/sqladapter/testing/adapter.go.tpl | 23 ++++---- lib/sqlbuilder/batch.go | 65 ++++++++-------------- 2 files changed, 34 insertions(+), 54 deletions(-) diff --git a/internal/sqladapter/testing/adapter.go.tpl b/internal/sqladapter/testing/adapter.go.tpl index 933543a6..32717996 100644 --- a/internal/sqladapter/testing/adapter.go.tpl +++ b/internal/sqladapter/testing/adapter.go.tpl @@ -1092,15 +1092,14 @@ func TestBatchInsert(t *testing.T) { go func() { for i := 0; i < totalItems; i++ { - batch.Values(fmt.Sprintf("artist-%d", i)) + batch.Values <- fmt.Sprintf("artist-%d", i) } - batch.Done() + close(batch.Values) }() - for q := range batch.Next() { - _, err = q.Exec() - assert.NoError(t, err) - } + err = batch.Exec() + assert.NoError(t, err) + assert.NoError(t, batch.Error()) c, err := sess.Collection("artist").Find().Count() assert.NoError(t, err) @@ -1131,14 +1130,15 @@ func TestBatchInsertReturningKeys(t *testing.T) { go func() { for i := 0; i < totalItems; i++ { - batch.Values(fmt.Sprintf("artist-%d", i)) + batch.Values <- fmt.Sprintf("artist-%d", i) } - batch.Done() + close(batch.Values) }() - for q := range batch.Next() { - var keyMap []struct{ID int `db:"id"`} - err := q.Iterator().All(&keyMap) + var keyMap []struct { + ID int `db:"id"` + } + for batch.Next(&keyMap) { assert.NoError(t, err) // Each insertion must produce new keys. @@ -1156,6 +1156,7 @@ func TestBatchInsertReturningKeys(t *testing.T) { assert.NoError(t, err) assert.Equal(t, uint64(len(keyMap)), c) } + assert.NoError(t, batch.Error()) // Count all new elements c, err := sess.Collection("artist").Find().Count() diff --git a/lib/sqlbuilder/batch.go b/lib/sqlbuilder/batch.go index f35bb048..1148c8a7 100644 --- a/lib/sqlbuilder/batch.go +++ b/lib/sqlbuilder/batch.go @@ -1,15 +1,10 @@ package sqlbuilder -import ( - "sync" -) - type BatchInserter struct { inserter *inserter size int - values [][]interface{} - next chan Inserter - mu sync.Mutex + Values chan interface{} + err error } func newBatchInserter(inserter *inserter, size int) *BatchInserter { @@ -19,51 +14,35 @@ func newBatchInserter(inserter *inserter, size int) *BatchInserter { b := &BatchInserter{ inserter: inserter, size: size, - next: make(chan Inserter), + Values: make(chan interface{}, size), } - b.reset() return b } -func (b *BatchInserter) reset() { - b.values = make([][]interface{}, 0, b.size) -} - -func (b *BatchInserter) flush() { - if len(b.values) > 0 { - clone := b.inserter.clone() - for i := range b.values { - clone.Values(b.values[i]...) +func (b *BatchInserter) Next(dst interface{}) bool { + clone := b.inserter.clone() + i := 0 + for value := range b.Values { + i++ + clone.Values(value) + if b.size == i { + break } - b.next <- clone - b.reset() } -} - -// Values pushes a value to be inserted as part of the batch. -func (b *BatchInserter) Values(values ...interface{}) *BatchInserter { - b.mu.Lock() - defer b.mu.Unlock() - - b.values = append(b.values, values) - if len(b.values) >= b.size { - b.flush() + if i == 0 { + return false } - return b + b.err = clone.Iterator().All(dst) + return (b.err == nil) } -// Next returns a channel that receives new q elements. -func (b *BatchInserter) Next() chan Inserter { - b.mu.Lock() - defer b.mu.Unlock() - - return b.next +func (b *BatchInserter) Exec() error { + var nop []struct{} + for b.Next(&nop) { + } + return b.err } -func (b *BatchInserter) Done() { - b.mu.Lock() - defer b.mu.Unlock() - - b.flush() - close(b.next) +func (b *BatchInserter) Error() error { + return b.err } -- GitLab