diff --git a/lib/sqlbuilder/builder_test.go b/lib/sqlbuilder/builder_test.go index 9120f4ee9bafb2056ea308b6653a68f330aa7453..315cd4f17e3512ef0beb581765e7f26ac6365838 100644 --- a/lib/sqlbuilder/builder_test.go +++ b/lib/sqlbuilder/builder_test.go @@ -254,6 +254,16 @@ func TestSelect(t *testing.T) { b.SelectFrom("artist a").Join("publication p").On("p.title LIKE ? OR p.title LIKE ?", "%Totoro%", "%Robot%").Where("a.id = ?", 2).Limit(1).String(), ) + assert.Equal( + `SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.title LIKE $1 OR p.title LIKE $2) WHERE (a.id = $3) LIMIT 1`, + b.SelectFrom("artist a").Join("publication p").On("p.title LIKE ? OR p.title LIKE ?", "%Totoro%", "%Robot%").Where("a.id = ?", 2).Where("a.id = ?", 3).Limit(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.title LIKE $1 OR p.title LIKE $2) WHERE (a.id = $3 AND a.id = $4) LIMIT 1`, + b.SelectFrom("artist a").Join("publication p").On("p.title LIKE ? OR p.title LIKE ?", "%Totoro%", "%Robot%").Where("a.id = ?", 2).And("a.id = ?", 3).Limit(1).String(), + ) + assert.Equal( `SELECT * FROM "artist" AS "a" LEFT JOIN "publication" AS "p1" ON (p1.id = a.id) RIGHT JOIN "publication" AS "p2" ON (p2.id = a.id)`, b.SelectFrom("artist a"). @@ -425,7 +435,7 @@ func TestSelect(t *testing.T) { } { - sel := b.SelectFrom("foo").Where("bar", 2).Where(db.Cond{"baz": 1}) + sel := b.SelectFrom("foo").Where("bar", 2).And(db.Cond{"baz": 1}) assert.Equal( `SELECT * FROM "foo" WHERE ("bar" = $1 AND "baz" = $2)`, sel.String(), @@ -436,6 +446,18 @@ func TestSelect(t *testing.T) { ) } + { + sel := b.SelectFrom("foo").Where("bar", 2).Where(db.Cond{"baz": 1}) + assert.Equal( + `SELECT * FROM "foo" WHERE ("baz" = $1)`, + sel.String(), + ) + assert.Equal( + []interface{}{1}, + sel.Arguments(), + ) + } + { sel := b.SelectFrom("foo").Where(db.Raw("bar->'baz' = ?", true)) assert.Equal( diff --git a/lib/sqlbuilder/interfaces.go b/lib/sqlbuilder/interfaces.go index d0d6de3dd7f4b681daf148a1b60b7eb1e22d3d5d..478179d20afcc098f9fc040186f4b3ae941d2883 100644 --- a/lib/sqlbuilder/interfaces.go +++ b/lib/sqlbuilder/interfaces.go @@ -175,9 +175,13 @@ type Selector interface { // the Marshaler interface, then with fmt.Stringer and finally, if the // argument does not satisfy any of those interfaces Where() will use // fmt.Sprintf("%v", arg) to transform the type into a string. + // + // Subsequent calls to Where() will overwrite previously set conditions, if + // you want these new conditions to be appended use And() instead. Where(conds ...interface{}) Selector - // And appends more arguments to the WHERE clause. + // And appends more constraints to the WHERE clause without overwriting + // conditions that have been already set. And(conds ...interface{}) Selector // GroupBy represents a GROUP BY statement. diff --git a/lib/sqlbuilder/select.go b/lib/sqlbuilder/select.go index 4a79c5017aa364e65ceb1ef50c3881ddd321220c..326e613e8f67602cfef64140359b94f88764844b 100644 --- a/lib/sqlbuilder/select.go +++ b/lib/sqlbuilder/select.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "strings" + "sync" "upper.io/db.v2" "upper.io/db.v2/internal/sqladapter/exql" @@ -19,91 +20,135 @@ const ( type selector struct { *stringer - mode selectMode - builder *sqlBuilder + + mode selectMode + builder *sqlBuilder + table *exql.Columns - as string + tableArgs []interface{} + + as string + where *exql.Where - groupBy *exql.GroupBy - orderBy exql.OrderBy - limit exql.Limit - offset exql.Offset - columns *exql.Columns + whereArgs []interface{} + + groupBy *exql.GroupBy + groupByArgs []interface{} + + orderBy exql.OrderBy + orderByArgs []interface{} + + limit exql.Limit + offset exql.Offset + + columns *exql.Columns + columnsArgs []interface{} + joins []*exql.Join - arguments []interface{} - err error + joinsArgs []interface{} + + mu sync.Mutex + + err error } func (qs *selector) From(tables ...interface{}) Selector { f, args, err := columnFragments(qs.builder.t, tables) if err != nil { - qs.err = err + qs.setErr(err) return qs } c := exql.JoinColumns(f...) + + qs.mu.Lock() qs.table = c + qs.tableArgs = args + qs.mu.Unlock() - qs.arguments = append(qs.arguments, args...) return qs } func (qs *selector) Columns(columns ...interface{}) Selector { f, args, err := columnFragments(qs.builder.t, columns) if err != nil { - qs.err = err + qs.setErr(err) return qs } c := exql.JoinColumns(f...) + + qs.mu.Lock() if qs.columns != nil { qs.columns.Append(c) } else { qs.columns = c } + qs.columnsArgs = append(qs.columnsArgs, args...) + qs.mu.Unlock() - qs.arguments = append(qs.arguments, args...) return qs } func (qs *selector) Distinct() Selector { + qs.mu.Lock() qs.mode = selectModeDistinct + qs.mu.Unlock() return qs } func (qs *selector) Where(terms ...interface{}) Selector { - if qs.where != nil { - return qs.And(terms...) - } - where, arguments := qs.builder.t.ToWhereWithArguments(terms) - qs.where = &where - qs.arguments = append(qs.arguments, arguments...) - return qs + qs.mu.Lock() + qs.where, qs.whereArgs = &exql.Where{}, []interface{}{} + qs.mu.Unlock() + return qs.And(terms...) } func (qs *selector) And(terms ...interface{}) Selector { - where, arguments := qs.builder.t.ToWhereWithArguments(terms) + where, whereArgs := qs.builder.t.ToWhereWithArguments(terms) + + qs.mu.Lock() if qs.where == nil { - qs.where = &exql.Where{} + qs.where, qs.whereArgs = &exql.Where{}, []interface{}{} } qs.where.Append(&where) - qs.arguments = append(qs.arguments, arguments...) + qs.whereArgs = append(qs.whereArgs, whereArgs...) + qs.mu.Unlock() + return qs } func (qs *selector) Arguments() []interface{} { - return qs.arguments + qs.mu.Lock() + defer qs.mu.Unlock() + + total := len(qs.tableArgs) + len(qs.columnsArgs) + len(qs.whereArgs) + len(qs.joinsArgs) + len(qs.groupByArgs) + len(qs.orderByArgs) + if total == 0 { + return nil + } + args := make([]interface{}, 0, total) + args = append(args, qs.tableArgs...) + args = append(args, qs.columnsArgs...) + args = append(args, qs.joinsArgs...) + args = append(args, qs.whereArgs...) + args = append(args, qs.groupByArgs...) + args = append(args, qs.orderByArgs...) + return args } func (qs *selector) GroupBy(columns ...interface{}) Selector { fragments, args, err := columnFragments(qs.builder.t, columns) if err != nil { - qs.err = err + qs.setErr(err) return qs } + + qs.mu.Lock() if fragments != nil { qs.groupBy = exql.GroupByColumns(fragments...) } - qs.arguments = append(qs.arguments, args...) + qs.groupByArgs = args + qs.mu.Unlock() + return qs } @@ -119,7 +164,9 @@ func (qs *selector) OrderBy(columns ...interface{}) Selector { sort = &exql.SortColumn{ Column: exql.RawValue(col), } - qs.arguments = append(qs.arguments, args...) + qs.mu.Lock() + qs.orderByArgs = args + qs.mu.Unlock() case db.Function: fnName, fnArgs := value.Name(), value.Arguments() if len(fnArgs) == 0 { @@ -131,7 +178,9 @@ func (qs *selector) OrderBy(columns ...interface{}) Selector { sort = &exql.SortColumn{ Column: exql.RawValue(expanded), } - qs.arguments = append(qs.arguments, fnArgs...) + qs.mu.Lock() + qs.orderByArgs = fnArgs + qs.mu.Unlock() case string: if strings.HasPrefix(value, "-") { sort = &exql.SortColumn{ @@ -152,57 +201,66 @@ func (qs *selector) OrderBy(columns ...interface{}) Selector { } } default: - qs.err = fmt.Errorf("Can't sort by type %T", value) + qs.setErr(fmt.Errorf("Can't sort by type %T", value)) return qs } sortColumns.Columns = append(sortColumns.Columns, sort) } + qs.mu.Lock() qs.orderBy.SortColumns = &sortColumns + qs.mu.Unlock() return qs } func (qs *selector) Using(columns ...interface{}) Selector { - if len(qs.joins) == 0 { - qs.err = errors.New(`Cannot use Using() without a preceding Join() expression.`) + qs.mu.Lock() + joins := len(qs.joins) + qs.mu.Unlock() + + if joins == 0 { + qs.setErr(errors.New(`Cannot use Using() without a preceding Join() expression.`)) return qs } - lastJoin := qs.joins[len(qs.joins)-1] - + lastJoin := qs.joins[joins-1] if lastJoin.On != nil { - qs.err = errors.New(`Cannot use Using() and On() with the same Join() expression.`) + qs.setErr(errors.New(`Cannot use Using() and On() with the same Join() expression.`)) return qs } fragments, args, err := columnFragments(qs.builder.t, columns) if err != nil { - qs.err = err + qs.setErr(err) return qs } - qs.arguments = append(qs.arguments, args...) + qs.mu.Lock() + qs.joinsArgs = append(qs.joinsArgs, args...) lastJoin.Using = exql.UsingColumns(fragments...) + qs.mu.Unlock() + return qs } func (qs *selector) pushJoin(t string, tables []interface{}) Selector { - if qs.joins == nil { - qs.joins = []*exql.Join{} - } - tableNames := make([]string, len(tables)) for i := range tables { tableNames[i] = fmt.Sprintf("%s", tables[i]) } + qs.mu.Lock() + if qs.joins == nil { + qs.joins = []*exql.Join{} + } qs.joins = append(qs.joins, &exql.Join{ Type: t, Table: exql.TableWithName(strings.Join(tableNames, ", ")), }, ) + qs.mu.Unlock() return qs } @@ -228,33 +286,44 @@ func (qs *selector) Join(tables ...interface{}) Selector { } func (qs *selector) On(terms ...interface{}) Selector { - if len(qs.joins) == 0 { - qs.err = errors.New(`Cannot use On() without a preceding Join() expression.`) + qs.mu.Lock() + joins := len(qs.joins) + qs.mu.Unlock() + + if joins == 0 { + qs.setErr(errors.New(`Cannot use On() without a preceding Join() expression.`)) return qs } - lastJoin := qs.joins[len(qs.joins)-1] - + lastJoin := qs.joins[joins-1] if lastJoin.On != nil { - qs.err = errors.New(`Cannot use Using() and On() with the same Join() expression.`) + qs.setErr(errors.New(`Cannot use Using() and On() with the same Join() expression.`)) return qs } w, a := qs.builder.t.ToWhereWithArguments(terms) o := exql.On(w) + lastJoin.On = &o - qs.arguments = append(qs.arguments, a...) + qs.mu.Lock() + qs.joinsArgs = append(qs.joinsArgs, a...) + qs.mu.Unlock() + return qs } func (qs *selector) Limit(n int) Selector { + qs.mu.Lock() qs.limit = exql.Limit(n) + qs.mu.Unlock() return qs } func (qs *selector) Offset(n int) Selector { + qs.mu.Lock() qs.offset = exql.Offset(n) + qs.mu.Unlock() return qs } @@ -273,12 +342,12 @@ func (qs *selector) statement() *exql.Statement { } func (qs *selector) Query() (*sql.Rows, error) { - return qs.builder.sess.StatementQuery(qs.statement(), qs.arguments...) + return qs.builder.sess.StatementQuery(qs.statement(), qs.Arguments()...) } func (qs *selector) As(alias string) Selector { if qs.table == nil { - qs.err = errors.New("Cannot use As() without a preceding From() expression") + qs.setErr(errors.New("Cannot use As() without a preceding From() expression")) return qs } last := len(qs.table.Columns) - 1 @@ -289,11 +358,11 @@ func (qs *selector) As(alias string) Selector { } func (qs *selector) QueryRow() (*sql.Row, error) { - return qs.builder.sess.StatementQueryRow(qs.statement(), qs.arguments...) + return qs.builder.sess.StatementQueryRow(qs.statement(), qs.Arguments()...) } func (qs *selector) Iterator() Iterator { - rows, err := qs.builder.sess.StatementQuery(qs.statement(), qs.arguments...) + rows, err := qs.builder.sess.StatementQuery(qs.statement(), qs.Arguments()...) return &iterator{rows, err} } @@ -304,3 +373,9 @@ func (qs *selector) All(destSlice interface{}) error { func (qs *selector) One(dest interface{}) error { return qs.Iterator().One(dest) } + +func (qs *selector) setErr(err error) { + qs.mu.Lock() + qs.err = err + qs.mu.Unlock() +}