diff --git a/internal/sqladapter/testing/adapter.go.tpl b/internal/sqladapter/testing/adapter.go.tpl index 96ed05a6a5857d5249f55b8a945b9f25455e32e4..66004ea173a7b4d3992322799dd7d30f0310a4de 100644 --- a/internal/sqladapter/testing/adapter.go.tpl +++ b/internal/sqladapter/testing/adapter.go.tpl @@ -1078,6 +1078,90 @@ func TestDataTypes(t *testing.T) { assert.Equal(t, testValues, item) } +func TestBatchInsert(t *testing.T) { + sess := mustOpen() + defer sess.Close() + + for batchSize := 0; batchSize < 17; batchSize++ { + err := sess.Collection("artist").Truncate() + assert.NoError(t, err) + + batch := sess.InsertInto("artist").Columns("name").Batch(batchSize) + + totalItems := int(rand.Int31n(21)) + + go func() { + defer batch.Done() + for i := 0; i < totalItems; i++ { + batch.Values(fmt.Sprintf("artist-%d", i)) + } + }() + + err = batch.Wait() + assert.NoError(t, err) + assert.NoError(t, batch.Err()) + + c, err := sess.Collection("artist").Find().Count() + assert.NoError(t, err) + assert.Equal(t, uint64(totalItems), c) + + for i := 0; i < totalItems; i++ { + c, err := sess.Collection("artist").Find(db.Cond{"name": fmt.Sprintf("artist-%d", i)}).Count() + assert.NoError(t, err) + assert.Equal(t, uint64(1), c) + } + } +} + +func TestBatchInsertReturningKeys(t *testing.T) { + if Adapter != "postgresql" { + t.Skip("Currently not supported.") + } + + sess := mustOpen() + defer sess.Close() + + err := sess.Collection("artist").Truncate() + assert.NoError(t, err) + + batchSize, totalItems := 7, 12 + + batch := sess.InsertInto("artist").Columns("name").Returning("id").Batch(batchSize) + + go func() { + defer batch.Done() + for i := 0; i < totalItems; i++ { + batch.Values(fmt.Sprintf("artist-%d", i)) + } + }() + + var keyMap []struct { + ID int `db:"id"` + } + for batch.NextResult(&keyMap) { + // Each insertion must produce new keys. + assert.True(t, len(keyMap) > 0) + assert.True(t, len(keyMap) <= batchSize) + + // Find the elements we've just inserted + keys := make([]int, 0, len(keyMap)) + for i := range keyMap { + keys = append(keys, keyMap[i].ID) + } + + // Make sure count matches. + c, err := sess.Collection("artist").Find(db.Cond{"id": keys}).Count() + assert.NoError(t, err) + assert.Equal(t, uint64(len(keyMap)), c) + } + assert.NoError(t, batch.Err()) + + // Count all new elements + c, err := sess.Collection("artist").Find().Count() + assert.NoError(t, err) + assert.Equal(t, uint64(totalItems), c) +} + func TestBuilder(t *testing.T) { sess := mustOpen() defer sess.Close() diff --git a/lib/sqlbuilder/batch.go b/lib/sqlbuilder/batch.go new file mode 100644 index 0000000000000000000000000000000000000000..0177b8e686384797f07b1e82125b323a0d2d04df --- /dev/null +++ b/lib/sqlbuilder/batch.go @@ -0,0 +1,81 @@ +package sqlbuilder + +// BatchInserter provides a helper that can be used to do massive insertions in +// batches. +type BatchInserter struct { + inserter *inserter + size int + values chan []interface{} + err error +} + +func newBatchInserter(inserter *inserter, size int) *BatchInserter { + if size < 1 { + size = 1 + } + b := &BatchInserter{ + inserter: inserter, + size: size, + values: make(chan []interface{}, size), + } + return b +} + +// Values pushes column values to be inserted as part of the batch. +func (b *BatchInserter) Values(values ...interface{}) *BatchInserter { + b.values <- values + return b +} + +func (b *BatchInserter) nextQuery() *inserter { + clone := b.inserter.clone() + i := 0 + for values := range b.values { + i++ + clone.Values(values...) + if i == b.size { + break + } + } + if i == 0 { + return nil + } + return clone +} + +// NextResult is useful when using PostgreSQL and Returning(), it dumps the +// next slice of results to dst, which can mean having the IDs of all inserted +// elements in the batch. +func (b *BatchInserter) NextResult(dst interface{}) bool { + clone := b.nextQuery() + if clone == nil { + return false + } + b.err = clone.Iterator().All(dst) + return (b.err == nil) +} + +// Done means that no more elements are going to be added. +func (b *BatchInserter) Done() { + close(b.values) +} + +// Wait blocks until the whole batch is executed. +func (b *BatchInserter) Wait() error { + for { + q := b.nextQuery() + if q == nil { + break + } + if _, err := q.Exec(); err != nil { + b.err = err + break + } + } + return b.Err() +} + +// Err returns any error while executing the batch. +func (b *BatchInserter) Err() error { + return b.err +} diff --git a/lib/sqlbuilder/insert.go b/lib/sqlbuilder/insert.go index 0124ff7f12e4cb0032f904d5747f92a3de9b6c20..d86991a7c357aca6c52128db7c82637e4e3c3c5b 100644 --- a/lib/sqlbuilder/insert.go +++ b/lib/sqlbuilder/insert.go @@ -17,6 +17,16 @@ type inserter struct { extra string } +func (qi *inserter) clone() *inserter { + clone := &inserter{} + *clone = *qi + return clone +} + +func (qi *inserter) Batch(n int) *BatchInserter { + return newBatchInserter(qi.clone(), n) +} + func (qi *inserter) columnsToFragments(dst *[]exql.Fragment, columns []string) error { l := len(columns) f := make([]exql.Fragment, l) diff --git a/lib/sqlbuilder/interfaces.go b/lib/sqlbuilder/interfaces.go index b0c4382a109a809e4b56167871ba78401eda155c..602c1405a71f33501db9baccb4a74aa92dca9b12 100644 --- a/lib/sqlbuilder/interfaces.go +++ b/lib/sqlbuilder/interfaces.go @@ -323,6 +323,11 @@ type Inserter interface { // Inserter. This is only possible when using Returning(). Iterator() Iterator + // Batch provies a BatchInserter that can be used to insert many elements at + // once by issuing several calls to Values(). It accepts a size parameter + // which defines the batch size. If size is < 1, the batch size is set to 1. + Batch(size int) *BatchInserter + // Execer provides the Exec method. Execer diff --git a/lib/sqlbuilder/placeholder_test.go b/lib/sqlbuilder/placeholder_test.go new file mode 100644 index 0000000000000000000000000000000000000000..82f472cd26477c0dfee15e28483cf37b0ff4ca04 --- /dev/null +++ b/lib/sqlbuilder/placeholder_test.go @@ -0,0 +1,83 @@ +package sqlbuilder + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "upper.io/db.v2" +) + +func TestPlaceholderSimple(t *testing.T) { + { + ret, _ := expandPlaceholders("?", 1) + assert.Equal(t, "?", ret) + } + { + ret, _ := expandPlaceholders("?") + assert.Equal(t, "?", ret) + } +} + +func TestPlaceholderMany(t *testing.T) { + { + ret, _ := expandPlaceholders("?, ?, ?", 1, 2, 3) + assert.Equal(t, "?, ?, ?", ret) + } +} + +func TestPlaceholderArray(t *testing.T) { + { + ret, _ := expandPlaceholders("?, ?, ?", 1, 2, []interface{}{3, 4, 5}) + assert.Equal(t, "?, ?, (?, ?, ?)", ret) + } + + { + ret, _ := expandPlaceholders("?, ?, ?", []interface{}{1, 2, 3}, 4, 5) + assert.Equal(t, "(?, ?, ?), ?, ?", ret) + } + + { + ret, _ := expandPlaceholders("?, ?, ?", 1, []interface{}{2, 3, 4}, 5) + assert.Equal(t, "?, (?, ?, ?), ?", ret) + } + + { + ret, _ := expandPlaceholders("???", 1, []interface{}{2, 3, 4}, 5) + assert.Equal(t, "?(?, ?, ?)?", ret) + } + + { + ret, _ := expandPlaceholders("??", []interface{}{1, 2, 3}, []interface{}{}, []interface{}{4, 5}, []interface{}{}) + assert.Equal(t, "(?, ?, ?)?", ret) + } +} + +func TestPlaceholderArguments(t *testing.T) { + { + _, args := expandPlaceholders("?, ?, ?", 1, 2, []interface{}{3, 4, 5}) + assert.Equal(t, []interface{}{1, 2, 3, 4, 5}, args) + } + + { + _, args := expandPlaceholders("?, ?, ?", 1, []interface{}{2, 3, 4}, 5) + assert.Equal(t, []interface{}{1, 2, 3, 4, 5}, args) + } + + { + _, args := expandPlaceholders("?, ?, ?", []interface{}{1, 2, 3}, 4, 5) + assert.Equal(t, []interface{}{1, 2, 3, 4, 5}, args) + } + + { + _, args := expandPlaceholders("?, ?", []interface{}{1, 2, 3}, []interface{}{4, 5}) + assert.Equal(t, []interface{}{1, 2, 3, 4, 5}, args) + } +} + +func TestPlaceholderReplace(t *testing.T) { + { + ret, args := expandPlaceholders("?, ?, ?", 1, db.Raw("foo"), 3) + assert.Equal(t, "?, foo, ?", ret) + assert.Equal(t, []interface{}{1, 3}, args) + } +}