diff --git a/Makefile b/Makefile index 7181dd7260cd4fb9de24760eccb6a5b078be84ab..b72e16627f8d13f612b250da464c4e889c8b2c32 100644 --- a/Makefile +++ b/Makefile @@ -5,11 +5,9 @@ DB_HOST ?= 127.0.0.1 export DB_HOST test: - go test -v -benchtime=500ms -bench=. ./lib/... & \ - go test -v -benchtime=500ms -bench=. ./internal/... & \ - wait && \ + go test -v -benchtime=500ms -bench=. ./lib/... && \ + go test -v -benchtime=500ms -bench=. ./internal/... && \ for ADAPTER in postgresql mysql sqlite ql mongo; do \ - $(MAKE) -C $$ADAPTER test & \ + $(MAKE) -C $$ADAPTER test; \ done && \ - wait && \ go test -v diff --git a/lib/sqlbuilder/builder.go b/lib/sqlbuilder/builder.go index 97d99bc26ebb5abc29d93b9c6e89c4f00489736d..bd6a350e30ba229768b8afb16cda2247a9e89667 100644 --- a/lib/sqlbuilder/builder.go +++ b/lib/sqlbuilder/builder.go @@ -326,7 +326,7 @@ func columnFragments(template *templateWithUtils, columns []interface{}) ([]exql for i := 0; i < l; i++ { switch v := columns[i].(type) { case *selector: - expanded, rawArgs := expandPlaceholders(v.statement().Compile(v.stringer.t), v.Arguments()) + expanded, rawArgs := Preprocess(v.statement().Compile(v.stringer.t), v.Arguments()) f[i] = exql.RawValue(expanded) args = append(args, rawArgs...) case db.Function: @@ -336,11 +336,11 @@ func columnFragments(template *templateWithUtils, columns []interface{}) ([]exql } else { fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")" } - expanded, fnArgs := expandPlaceholders(fnName, fnArgs) + expanded, fnArgs := Preprocess(fnName, fnArgs) f[i] = exql.RawValue(expanded) args = append(args, fnArgs...) case db.RawValue: - expanded, rawArgs := expandPlaceholders(v.Raw(), v.Arguments()) + expanded, rawArgs := Preprocess(v.Raw(), v.Arguments()) f[i] = exql.RawValue(expanded) args = append(args, rawArgs...) case exql.Fragment: diff --git a/lib/sqlbuilder/builder_test.go b/lib/sqlbuilder/builder_test.go index 22f5d0f4653bfe3088bb634da4c7d845a3ff6177..07c69f9c5f12c488173d2dd894dc8b8a3df1e403 100644 --- a/lib/sqlbuilder/builder_test.go +++ b/lib/sqlbuilder/builder_test.go @@ -746,6 +746,89 @@ func TestInsert(t *testing.T) { ) } + { + type artistStruct struct { + ID int `db:"id,omitempty"` + Name string `db:"name,omitempty"` + } + + assert.Equal( + `INSERT INTO "artist" ("name") VALUES ($1)`, + b.InsertInto("artist"). + Values(artistStruct{Name: "Chavela Vargas"}). + String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("id") VALUES ($1)`, + b.InsertInto("artist"). + Values(artistStruct{ID: 1}). + String(), + ) + } + + { + type artistStruct struct { + ID int `db:"id,omitempty"` + Name string `db:"name,omitempty"` + } + + { + q := b.InsertInto("artist").Values(artistStruct{Name: "Chavela Vargas"}) + + assert.Equal( + `INSERT INTO "artist" ("name") VALUES ($1)`, + q.String(), + ) + assert.Equal( + []interface{}{"Chavela Vargas"}, + q.Arguments(), + ) + } + + { + q := b.InsertInto("artist").Values(artistStruct{Name: "Chavela Vargas"}).Values(artistStruct{Name: "Alondra de la Parra"}) + + assert.Equal( + `INSERT INTO "artist" ("name") VALUES ($1), ($2)`, + q.String(), + ) + assert.Equal( + []interface{}{"Chavela Vargas", "Alondra de la Parra"}, + q.Arguments(), + ) + } + + { + q := b.InsertInto("artist").Values(artistStruct{ID: 1}) + + assert.Equal( + `INSERT INTO "artist" ("id") VALUES ($1)`, + q.String(), + ) + + assert.Equal( + []interface{}{1}, + q.Arguments(), + ) + } + + { + q := b.InsertInto("artist").Values(artistStruct{ID: 1}).Values(artistStruct{ID: 2}) + + assert.Equal( + `INSERT INTO "artist" ("id") VALUES ($1), ($2)`, + q.String(), + ) + + assert.Equal( + []interface{}{1, 2}, + q.Arguments(), + ) + } + + } + { intRef := func(i int) *int { if i == 0 { diff --git a/lib/sqlbuilder/convert.go b/lib/sqlbuilder/convert.go index 6c999aeb66b87b14e9b14decac834ca7b66bc7df..db8a592a25f29559b98f6cf101e81c7dbfeaeb14 100644 --- a/lib/sqlbuilder/convert.go +++ b/lib/sqlbuilder/convert.go @@ -78,11 +78,6 @@ func Preprocess(in string, args []interface{}) (string, []interface{}) { return expandQuery(in, args, preprocessFn) } -func expandPlaceholders(in string, args []interface{}) (string, []interface{}) { - // TODO: Remove after immutable query builder - return in, args -} - // ToWhereWithArguments converts the given parameters into a exql.Where // value. func (tu *templateWithUtils) ToWhereWithArguments(term interface{}) (where exql.Where, args []interface{}) { @@ -93,7 +88,7 @@ func (tu *templateWithUtils) ToWhereWithArguments(term interface{}) (where exql. if len(t) > 0 { if s, ok := t[0].(string); ok { if strings.ContainsAny(s, "?") || len(t) == 1 { - s, args = expandPlaceholders(s, t[1:]) + s, args = Preprocess(s, t[1:]) where.Conditions = []exql.Fragment{exql.RawValue(s)} } else { var val interface{} @@ -122,7 +117,7 @@ func (tu *templateWithUtils) ToWhereWithArguments(term interface{}) (where exql. } return case db.RawValue: - r, v := expandPlaceholders(t.Raw(), t.Arguments()) + r, v := Preprocess(t.Raw(), t.Arguments()) where.Conditions = []exql.Fragment{exql.RawValue(r)} args = append(args, v...) return @@ -294,11 +289,11 @@ func (tu *templateWithUtils) ToColumnValues(term interface{}) (cv exql.ColumnVal // A function with one or more arguments. fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")" } - expanded, fnArgs := expandPlaceholders(fnName, fnArgs) + expanded, fnArgs := Preprocess(fnName, fnArgs) columnValue.Value = exql.RawValue(expanded) args = append(args, fnArgs...) case db.RawValue: - expanded, rawArgs := expandPlaceholders(value.Raw(), value.Arguments()) + expanded, rawArgs := Preprocess(value.Raw(), value.Arguments()) columnValue.Value = exql.RawValue(expanded) args = append(args, rawArgs...) default: diff --git a/lib/sqlbuilder/insert.go b/lib/sqlbuilder/insert.go index bef31ae34663812afa32a795a4a4f67784694a87..e5a3bb6aae6ac3a164f13a11ed3357f341a3b08b 100644 --- a/lib/sqlbuilder/insert.go +++ b/lib/sqlbuilder/insert.go @@ -2,15 +2,19 @@ package sqlbuilder import ( "database/sql" + "sync" "upper.io/db.v2/internal/sqladapter/exql" ) type inserter struct { *stringer - builder *sqlBuilder - table string - values []*exql.Values + builder *sqlBuilder + table string + + enqueuedValues [][]interface{} + mu sync.Mutex + returning []exql.Fragment columns []exql.Fragment arguments []interface{} @@ -28,6 +32,7 @@ func (qi *inserter) Batch(n int) *BatchInserter { } func (qi *inserter) Arguments() []interface{} { + _ = qi.statement() return qi.arguments } @@ -69,34 +74,77 @@ func (qi *inserter) Columns(columns ...string) Inserter { } func (qi *inserter) Values(values ...interface{}) Inserter { - if len(values) == 1 { - ff, vv, err := Map(values[0], &MapOptions{IncludeZeroed: true, IncludeNil: true}) - if err == nil { - columns, vals, arguments, _ := qi.builder.t.ToColumnsValuesAndArguments(ff, vv) - - qi.arguments = append(qi.arguments, arguments...) - qi.values = append(qi.values, vals) - if len(qi.columns) == 0 { - for _, c := range columns.Columns { - qi.columns = append(qi.columns, c) + qi.mu.Lock() + defer qi.mu.Unlock() + + if qi.enqueuedValues == nil { + qi.enqueuedValues = [][]interface{}{} + } + qi.enqueuedValues = append(qi.enqueuedValues, values) + return qi +} + +func (qi *inserter) processValues() (values []*exql.Values, arguments []interface{}) { + // TODO: simplify with immutable queries + var insertNils bool + + for _, enqueuedValue := range qi.enqueuedValues { + if len(enqueuedValue) == 1 { + ff, vv, err := Map(enqueuedValue[0], nil) + if err == nil { + columns, vals, args, _ := qi.builder.t.ToColumnsValuesAndArguments(ff, vv) + + values, arguments = append(values, vals), append(arguments, args...) + + if len(qi.columns) == 0 { + for _, c := range columns.Columns { + qi.columns = append(qi.columns, c) + } + } else { + if len(qi.columns) != len(columns.Columns) { + insertNils = true + break + } } + continue } - return qi } - } - if len(qi.columns) == 0 || len(values) == len(qi.columns) { - qi.arguments = append(qi.arguments, values...) + if len(qi.columns) == 0 || len(enqueuedValue) == len(qi.columns) { + arguments = append(arguments, enqueuedValue...) - l := len(values) - placeholders := make([]exql.Fragment, l) - for i := 0; i < l; i++ { - placeholders[i] = exql.RawValue(`?`) + l := len(enqueuedValue) + placeholders := make([]exql.Fragment, l) + for i := 0; i < l; i++ { + placeholders[i] = exql.RawValue(`?`) + } + values = append(values, exql.NewValueGroup(placeholders...)) } - qi.values = append(qi.values, exql.NewValueGroup(placeholders...)) } - return qi + if insertNils { + values, arguments = values[0:0], arguments[0:0] + + for _, enqueuedValue := range qi.enqueuedValues { + if len(enqueuedValue) == 1 { + ff, vv, err := Map(enqueuedValue[0], &MapOptions{IncludeZeroed: true, IncludeNil: true}) + if err == nil { + columns, vals, args, _ := qi.builder.t.ToColumnsValuesAndArguments(ff, vv) + values, arguments = append(values, vals), append(arguments, args...) + + if len(qi.columns) != len(columns.Columns) { + qi.columns = qi.columns[0:0] + for _, c := range columns.Columns { + qi.columns = append(qi.columns, c) + } + } + } + continue + } + } + } + + return } func (qi *inserter) statement() *exql.Statement { @@ -105,14 +153,18 @@ func (qi *inserter) statement() *exql.Statement { Table: exql.TableWithName(qi.table), } - if len(qi.values) > 0 { - stmt.Values = exql.JoinValueGroups(qi.values...) - } + values, arguments := qi.processValues() + + qi.arguments = arguments if len(qi.columns) > 0 { stmt.Columns = exql.JoinColumns(qi.columns...) } + if len(values) > 0 { + stmt.Values = exql.JoinValueGroups(values...) + } + if len(qi.returning) > 0 { stmt.Returning = exql.ReturningColumns(qi.returning...) } diff --git a/lib/sqlbuilder/select.go b/lib/sqlbuilder/select.go index 7df9f2727237162c61d0a00e94dbf02639ede7e7..8be3a88b5a7719b164f140dd36d7bae8aecb000c 100644 --- a/lib/sqlbuilder/select.go +++ b/lib/sqlbuilder/select.go @@ -156,7 +156,7 @@ func (qs *selector) OrderBy(columns ...interface{}) Selector { switch value := columns[i].(type) { case db.RawValue: - col, args := expandPlaceholders(value.Raw(), value.Arguments()) + col, args := Preprocess(value.Raw(), value.Arguments()) sort = &exql.SortColumn{ Column: exql.RawValue(col), } @@ -170,7 +170,7 @@ func (qs *selector) OrderBy(columns ...interface{}) Selector { } else { fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")" } - expanded, fnArgs := expandPlaceholders(fnName, fnArgs) + expanded, fnArgs := Preprocess(fnName, fnArgs) sort = &exql.SortColumn{ Column: exql.RawValue(expanded), }