diff --git a/internal/sqladapter/testing/adapter.go.tpl b/internal/sqladapter/testing/adapter.go.tpl index 748b03338d05fff788947b298f8d662757436aea..522f8aae9c508b653896e6c074f4f30c657ec462 100644 --- a/internal/sqladapter/testing/adapter.go.tpl +++ b/internal/sqladapter/testing/adapter.go.tpl @@ -1082,22 +1082,85 @@ func TestBatchInsert(t *testing.T) { sess := mustOpen() defer sess.Close() + for batchSize := 0; batchSize < 17; batchSize++ { + err := sess.Collection("artist").Truncate() + assert.NoError(t, err) + + batch := sess.InsertInto("artist").Columns("name").NewBatch(batchSize) + + totalItems := int(rand.Int31n(21)) + + go func() { + for i := 0; i < totalItems; i++ { + batch.Values(fmt.Sprintf("artist-%d", i)) + } + batch.Done() + }() + + for q := range batch.Next() { + _, err = q.Exec() + assert.NoError(t, err) + } + + c, err := sess.Collection("artist").Find().Count() + assert.NoError(t, err) + assert.Equal(t, uint64(totalItems), c) + + for i := 0; i < totalItems; i++ { + c, err := sess.Collection("artist").Find(db.Cond{"name": fmt.Sprintf("artist-%d", i)}).Count() + assert.NoError(t, err) + assert.Equal(t, uint64(1), c) + } + } +} + +func TestBatchInsertReturningKeys(t *testing.T) { + if Adapter == "ql" { + t.Skip("Currently not supported.") + } + + sess := mustOpen() + defer sess.Close() + err := sess.Collection("artist").Truncate() assert.NoError(t, err) - batch := sess.InsertInto("artist").Columns("name").NewBatch(5) + batchSize, totalItems := 7, 12 + + batch := sess.InsertInto("artist").Columns("name").Returning("id").NewBatch(batchSize) go func() { - for i := 0; i < 9; i++ { + for i := 0; i < totalItems; i++ { batch.Values(fmt.Sprintf("artist-%d", i)) } batch.Done() }() for q := range batch.Next() { - _, err = q.Exec() + var keyMap []struct{ID int `db:"id"`} + err := q.Iterator().All(&keyMap) assert.NoError(t, err) + + // Each insertion must produce new keys. + assert.True(t, len(keyMap) > 0) + assert.True(t, len(keyMap) <= batchSize) + + // Find the elements we've just inserted + keys := make([]int, len(keyMap)) + for i := range keyMap { + keys = append(keys, keyMap[i].ID) + } + + // Make sure count matches. + c, err := sess.Collection("artist").Find(db.Cond{"id": keys}).Count() + assert.NoError(t, err) + assert.Equal(t, uint64(len(keyMap)), c) } + + // Count all new elements + c, err := sess.Collection("artist").Find().Count() + assert.NoError(t, err) + assert.Equal(t, uint64(totalItems), c) } func TestBuilder(t *testing.T) { diff --git a/lib/sqlbuilder/batch.go b/lib/sqlbuilder/batch.go new file mode 100644 index 0000000000000000000000000000000000000000..f35bb048a7e2d4a4eea863f2915fef2e6d9470f7 --- /dev/null +++ b/lib/sqlbuilder/batch.go @@ -0,0 +1,69 @@ +package sqlbuilder + +import ( + "sync" +) + +type BatchInserter struct { + inserter *inserter + size int + values [][]interface{} + next chan Inserter + mu sync.Mutex +} + +func newBatchInserter(inserter *inserter, size int) *BatchInserter { + if size < 1 { + size = 1 + } + b := &BatchInserter{ + inserter: inserter, + size: size, + next: make(chan Inserter), + } + b.reset() + return b +} + +func (b *BatchInserter) reset() { + b.values = make([][]interface{}, 0, b.size) +} + +func (b *BatchInserter) flush() { + if len(b.values) > 0 { + clone := b.inserter.clone() + for i := range b.values { + clone.Values(b.values[i]...) + } + b.next <- clone + b.reset() + } +} + +// Values pushes a value to be inserted as part of the batch. +func (b *BatchInserter) Values(values ...interface{}) *BatchInserter { + b.mu.Lock() + defer b.mu.Unlock() + + b.values = append(b.values, values) + if len(b.values) >= b.size { + b.flush() + } + return b +} + +// Next returns a channel that receives new q elements. +func (b *BatchInserter) Next() chan Inserter { + b.mu.Lock() + defer b.mu.Unlock() + + return b.next +} + +func (b *BatchInserter) Done() { + b.mu.Lock() + defer b.mu.Unlock() + + b.flush() + close(b.next) +} diff --git a/lib/sqlbuilder/insert.go b/lib/sqlbuilder/insert.go index f7259de4c4cd7289e218e3048a4bddbce65756e7..a1247143c988adf44382cc28a1d205d5acb75565 100644 --- a/lib/sqlbuilder/insert.go +++ b/lib/sqlbuilder/insert.go @@ -24,7 +24,7 @@ func (qi *inserter) clone() *inserter { } func (qi *inserter) NewBatch(n int) *BatchInserter { - return &BatchInserter{inserter: qi.clone(), size: n} + return newBatchInserter(qi.clone(), n) } func (qi *inserter) columnsToFragments(dst *[]exql.Fragment, columns []string) error { diff --git a/lib/sqlbuilder/placeholder_test.go b/lib/sqlbuilder/placeholder_test.go new file mode 100644 index 0000000000000000000000000000000000000000..82f472cd26477c0dfee15e28483cf37b0ff4ca04 --- /dev/null +++ b/lib/sqlbuilder/placeholder_test.go @@ -0,0 +1,83 @@ +package sqlbuilder + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "upper.io/db.v2" +) + +func TestPlaceholderSimple(t *testing.T) { + { + ret, _ := expandPlaceholders("?", 1) + assert.Equal(t, "?", ret) + } + { + ret, _ := expandPlaceholders("?") + assert.Equal(t, "?", ret) + } +} + +func TestPlaceholderMany(t *testing.T) { + { + ret, _ := expandPlaceholders("?, ?, ?", 1, 2, 3) + assert.Equal(t, "?, ?, ?", ret) + } +} + +func TestPlaceholderArray(t *testing.T) { + { + ret, _ := expandPlaceholders("?, ?, ?", 1, 2, []interface{}{3, 4, 5}) + assert.Equal(t, "?, ?, (?, ?, ?)", ret) + } + + { + ret, _ := expandPlaceholders("?, ?, ?", []interface{}{1, 2, 3}, 4, 5) + assert.Equal(t, "(?, ?, ?), ?, ?", ret) + } + + { + ret, _ := expandPlaceholders("?, ?, ?", 1, []interface{}{2, 3, 4}, 5) + assert.Equal(t, "?, (?, ?, ?), ?", ret) + } + + { + ret, _ := expandPlaceholders("???", 1, []interface{}{2, 3, 4}, 5) + assert.Equal(t, "?(?, ?, ?)?", ret) + } + + { + ret, _ := expandPlaceholders("??", []interface{}{1, 2, 3}, []interface{}{}, []interface{}{4, 5}, []interface{}{}) + assert.Equal(t, "(?, ?, ?)?", ret) + } +} + +func TestPlaceholderArguments(t *testing.T) { + { + _, args := expandPlaceholders("?, ?, ?", 1, 2, []interface{}{3, 4, 5}) + assert.Equal(t, []interface{}{1, 2, 3, 4, 5}, args) + } + + { + _, args := expandPlaceholders("?, ?, ?", 1, []interface{}{2, 3, 4}, 5) + assert.Equal(t, []interface{}{1, 2, 3, 4, 5}, args) + } + + { + _, args := expandPlaceholders("?, ?, ?", []interface{}{1, 2, 3}, 4, 5) + assert.Equal(t, []interface{}{1, 2, 3, 4, 5}, args) + } + + { + _, args := expandPlaceholders("?, ?", []interface{}{1, 2, 3}, []interface{}{4, 5}) + assert.Equal(t, []interface{}{1, 2, 3, 4, 5}, args) + } +} + +func TestPlaceholderReplace(t *testing.T) { + { + ret, args := expandPlaceholders("?, ?, ?", 1, db.Raw("foo"), 3) + assert.Equal(t, "?, foo, ?", ret) + assert.Equal(t, []interface{}{1, 3}, args) + } +}