diff --git a/internal/sqladapter/testing/adapter.go.tpl b/internal/sqladapter/testing/adapter.go.tpl index 933543a61585b372bdaf44aac8e1d79414e1bc3e..855f6048d067fe28b93d21df6eb44650d062d5f3 100644 --- a/internal/sqladapter/testing/adapter.go.tpl +++ b/internal/sqladapter/testing/adapter.go.tpl @@ -1091,16 +1091,15 @@ func TestBatchInsert(t *testing.T) { totalItems := int(rand.Int31n(21)) go func() { + defer batch.Done() for i := 0; i < totalItems; i++ { batch.Values(fmt.Sprintf("artist-%d", i)) } - batch.Done() }() - for q := range batch.Next() { - _, err = q.Exec() - assert.NoError(t, err) - } + err = batch.Wait() + assert.NoError(t, err) + assert.NoError(t, batch.Error()) c, err := sess.Collection("artist").Find().Count() assert.NoError(t, err) @@ -1130,17 +1129,16 @@ func TestBatchInsertReturningKeys(t *testing.T) { batch := sess.InsertInto("artist").Columns("name").Returning("id").NewBatch(batchSize) go func() { + defer batch.Done() for i := 0; i < totalItems; i++ { batch.Values(fmt.Sprintf("artist-%d", i)) } - batch.Done() }() - for q := range batch.Next() { - var keyMap []struct{ID int `db:"id"`} - err := q.Iterator().All(&keyMap) - assert.NoError(t, err) - + var keyMap []struct { + ID int `db:"id"` + } + for batch.NextResult(&keyMap) { // Each insertion must produce new keys. assert.True(t, len(keyMap) > 0) assert.True(t, len(keyMap) <= batchSize) @@ -1156,6 +1154,7 @@ func TestBatchInsertReturningKeys(t *testing.T) { assert.NoError(t, err) assert.Equal(t, uint64(len(keyMap)), c) } + assert.NoError(t, batch.Error()) // Count all new elements c, err := sess.Collection("artist").Find().Count() diff --git a/lib/sqlbuilder/batch.go b/lib/sqlbuilder/batch.go index f35bb048a7e2d4a4eea863f2915fef2e6d9470f7..79a3f8a4bb9009ea94adb1b430b66f04f310aceb 100644 --- a/lib/sqlbuilder/batch.go +++ b/lib/sqlbuilder/batch.go @@ -1,15 +1,10 @@ package sqlbuilder -import ( - "sync" -) - type BatchInserter struct { inserter *inserter size int - values [][]interface{} - next chan Inserter - mu sync.Mutex + values chan []interface{} + err error } func newBatchInserter(inserter *inserter, size int) *BatchInserter { @@ -19,51 +14,45 @@ func newBatchInserter(inserter *inserter, size int) *BatchInserter { b := &BatchInserter{ inserter: inserter, size: size, - next: make(chan Inserter), + values: make(chan []interface{}, size), } - b.reset() return b } -func (b *BatchInserter) reset() { - b.values = make([][]interface{}, 0, b.size) +// Values pushes column values to be inserted as part of the batch. +func (b *BatchInserter) Values(values ...interface{}) *BatchInserter { + b.values <- values + return b } -func (b *BatchInserter) flush() { - if len(b.values) > 0 { - clone := b.inserter.clone() - for i := range b.values { - clone.Values(b.values[i]...) +func (b *BatchInserter) NextResult(dst interface{}) bool { + clone := b.inserter.clone() + i := 0 + for values := range b.values { + i++ + clone.Values(values...) + if i == b.size { + break } - b.next <- clone - b.reset() } -} - -// Values pushes a value to be inserted as part of the batch. -func (b *BatchInserter) Values(values ...interface{}) *BatchInserter { - b.mu.Lock() - defer b.mu.Unlock() - - b.values = append(b.values, values) - if len(b.values) >= b.size { - b.flush() + if i == 0 { + return false } - return b + b.err = clone.Iterator().All(dst) + return (b.err == nil) } -// Next returns a channel that receives new q elements. -func (b *BatchInserter) Next() chan Inserter { - b.mu.Lock() - defer b.mu.Unlock() - - return b.next +func (b *BatchInserter) Done() { + close(b.values) } -func (b *BatchInserter) Done() { - b.mu.Lock() - defer b.mu.Unlock() +func (b *BatchInserter) Wait() error { + var nop []struct{} + for b.NextResult(&nop) { + } + return b.err +} - b.flush() - close(b.next) +func (b *BatchInserter) Error() error { + return b.err }