From ea9f2c44a8b533732574b98ac63d5e4ae7fdd427 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net> Date: Fri, 19 Aug 2016 16:54:26 -0500 Subject: [PATCH] Support more special cases on SQLbuilder. --- db.go | 19 +- internal/sqladapter/exql/column.go | 17 +- internal/sqladapter/exql/column_test.go | 2 +- internal/sqladapter/exql/column_value_test.go | 4 +- internal/sqladapter/exql/columns.go | 5 + internal/sqladapter/exql/statement.go | 3 +- internal/sqladapter/exql/where.go | 10 +- internal/sqladapter/testing/adapter.go.tpl | 4 +- lib/sqlbuilder/builder.go | 55 ++-- lib/sqlbuilder/builder_test.go | 246 +++++++++++++++++- lib/sqlbuilder/convert.go | 95 ++++--- lib/sqlbuilder/interfaces.go | 13 +- lib/sqlbuilder/select.go | 69 ++++- 13 files changed, 445 insertions(+), 97 deletions(-) diff --git a/db.go b/db.go index 0ea6467b..2ed02e17 100644 --- a/db.go +++ b/db.go @@ -117,12 +117,13 @@ type RawValue interface { fmt.Stringer Compound Raw() string + Arguments() []interface{} } // Function interface defines methods for representing database functions. type Function interface { - Arguments() []interface{} Name() string + Arguments() []interface{} } // Marshaler is the interface implemented by struct fields that can marshal @@ -194,6 +195,14 @@ func (c Cond) Empty() bool { type rawValue struct { v string + a *[]interface{} // This may look ugly but allows us to use db.Raw() as keys for db.Cond{}. +} + +func (r rawValue) Arguments() []interface{} { + if r.a != nil { + return *r.a + } + return nil } func (r rawValue) Raw() string { @@ -389,8 +398,12 @@ func Or(conds ...Compound) *Union { // // // SOUNDEX('Hello') // Raw("SOUNDEX('Hello')") -func Raw(s string) RawValue { - return rawValue{v: s} +func Raw(value string, args ...interface{}) RawValue { + r := rawValue{v: value, a: nil} + if len(args) > 0 { + r.a = &args + } + return r } // Database is an interface that defines methods that must be satisfied by diff --git a/internal/sqladapter/exql/column.go b/internal/sqladapter/exql/column.go index b73abf2f..70626078 100644 --- a/internal/sqladapter/exql/column.go +++ b/internal/sqladapter/exql/column.go @@ -12,8 +12,9 @@ type columnT struct { // Column represents a SQL column. type Column struct { - Name interface{} - hash hash + Name interface{} + Alias string + hash hash } // ColumnWithName creates and returns a Column with the given name. @@ -32,6 +33,8 @@ func (c *Column) Compile(layout *Template) (compiled string) { return z } + alias := c.Alias + switch value := c.Name.(type) { case string: input := trimString(value) @@ -51,22 +54,22 @@ func (c *Column) Compile(layout *Template) (compiled string) { nameChunks[i] = mustParse(layout.IdentifierQuote, Raw{Value: nameChunks[i]}) } - name = strings.Join(nameChunks, layout.ColumnSeparator) - - var alias string + compiled = strings.Join(nameChunks, layout.ColumnSeparator) if len(chunks) > 1 { alias = trimString(chunks[1]) alias = mustParse(layout.IdentifierQuote, Raw{Value: alias}) } - - compiled = mustParse(layout.ColumnAliasLayout, columnT{name, alias}) case Raw: compiled = value.String() default: compiled = fmt.Sprintf("%v", c.Name) } + if alias != "" { + compiled = mustParse(layout.ColumnAliasLayout, columnT{compiled, alias}) + } + layout.Write(c, compiled) return diff --git a/internal/sqladapter/exql/column_test.go b/internal/sqladapter/exql/column_test.go index 6852f105..15ad63a5 100644 --- a/internal/sqladapter/exql/column_test.go +++ b/internal/sqladapter/exql/column_test.go @@ -10,7 +10,7 @@ func TestColumnHash(t *testing.T) { column := Column{Name: "role.name"} s = column.Hash() - e = "*exql.Column:5574933406985810060" + e = "*exql.Column:5663680925324531495" if s != e { t.Fatalf("Got: %s, Expecting: %s", s, e) diff --git a/internal/sqladapter/exql/column_value_test.go b/internal/sqladapter/exql/column_value_test.go index 942099ee..b71db7ed 100644 --- a/internal/sqladapter/exql/column_value_test.go +++ b/internal/sqladapter/exql/column_value_test.go @@ -10,7 +10,7 @@ func TestColumnValueHash(t *testing.T) { c := &ColumnValue{Column: ColumnWithName("id"), Operator: "=", Value: NewValue(1)} s = c.Hash() - e = `*exql.ColumnValue:7841113954072405845` + e = `*exql.ColumnValue:4950005282640920683` if s != e { t.Fatalf("Got: %s, Expecting: %s", s, e) @@ -26,7 +26,7 @@ func TestColumnValuesHash(t *testing.T) { ) s = c.Hash() - e = `*exql.ColumnValues:12182225587466517135` + e = `*exql.ColumnValues:8728513848368010747` if s != e { t.Fatalf("Got: %s, Expecting: %s", s, e) diff --git a/internal/sqladapter/exql/columns.go b/internal/sqladapter/exql/columns.go index 10e768f7..ebb0ed83 100644 --- a/internal/sqladapter/exql/columns.go +++ b/internal/sqladapter/exql/columns.go @@ -30,6 +30,11 @@ func UsingColumns(columns ...Fragment) *Using { return &Using{Columns: columns} } +func (c *Columns) Append(a *Columns) *Columns { + c.Columns = append(c.Columns, a.Columns...) + return c +} + // Compile transforms the Columns into an equivalent SQL representation. func (c *Columns) Compile(layout *Template) (compiled string) { diff --git a/internal/sqladapter/exql/statement.go b/internal/sqladapter/exql/statement.go index 1665b136..6453c860 100644 --- a/internal/sqladapter/exql/statement.go +++ b/internal/sqladapter/exql/statement.go @@ -2,6 +2,7 @@ package exql import ( "reflect" + "strings" "upper.io/db.v2/internal/cache" ) @@ -109,8 +110,8 @@ func (s *Statement) Compile(layout *Template) (compiled string) { panic("Unknown template type.") } + compiled = strings.TrimSpace(compiled) layout.Write(s, compiled) - return compiled } diff --git a/internal/sqladapter/exql/where.go b/internal/sqladapter/exql/where.go index a080f4af..ad99317c 100644 --- a/internal/sqladapter/exql/where.go +++ b/internal/sqladapter/exql/where.go @@ -40,6 +40,14 @@ func (w *Where) Hash() string { return w.hash.Hash(w) } +// Appends adds the conditions to the ones that already exist. +func (w *Where) Append(a *Where) *Where { + if a != nil { + w.Conditions = append(w.Conditions, a.Conditions...) + } + return w +} + // Hash returns a unique identifier. func (o *Or) Hash() string { w := Where(*o) @@ -49,7 +57,7 @@ func (o *Or) Hash() string { // Hash returns a unique identifier. func (a *And) Hash() string { w := Where(*a) - return `Or(` + w.Hash() + `)` + return `And(` + w.Hash() + `)` } // Compile transforms the Or into an equivalent SQL representation. diff --git a/internal/sqladapter/testing/adapter.go.tpl b/internal/sqladapter/testing/adapter.go.tpl index 88d431a6..96ed05a6 100644 --- a/internal/sqladapter/testing/adapter.go.tpl +++ b/internal/sqladapter/testing/adapter.go.tpl @@ -678,9 +678,9 @@ func TestFunction(t *testing.T) { assert.Equal(t, uint64(4), total) // Testing conditions - cond = db.Cond{"id NOT": db.Func("IN", 0, -1)} + cond = db.Cond{"id NOT IN": []interface{}{0, -1}} if Adapter == "ql" { - cond = db.Cond{"id() NOT": db.Func("IN", 0, -1)} + cond = db.Cond{"id() NOT IN": []interface{}{0, -1}} } res = artist.Find(cond) diff --git a/lib/sqlbuilder/builder.go b/lib/sqlbuilder/builder.go index 1edfd366..51aee331 100644 --- a/lib/sqlbuilder/builder.go +++ b/lib/sqlbuilder/builder.go @@ -15,6 +15,10 @@ import ( "upper.io/db.v2/lib/reflectx" ) +type hasArguments interface { + Arguments() []interface{} +} + type hasStatement interface { statement() *exql.Statement } @@ -121,14 +125,12 @@ func (b *sqlBuilder) QueryRow(query interface{}, args ...interface{}) (*sql.Row, } } -func (b *sqlBuilder) SelectFrom(table string) Selector { +func (b *sqlBuilder) SelectFrom(table ...interface{}) Selector { qs := &selector{ builder: b, - table: table, } - qs.stringer = &stringer{qs, b.t.Template} - return qs + return qs.From(table...) } func (b *sqlBuilder) Select(columns ...interface{}) Selector { @@ -252,27 +254,43 @@ func Map(item interface{}) ([]string, []interface{}, error) { return fv.fields, fv.values, nil } -func columnFragments(template *templateWithUtils, columns []interface{}) ([]exql.Fragment, error) { +func extractArguments(fragments []interface{}) []interface{} { + args := []interface{}{} + l := len(fragments) + for i := 0; i < l; i++ { + switch v := fragments[i].(type) { + case hasArguments: // TODO: use this on other places where we want to extract arguments. + args = append(args, v.Arguments()...) + } + } + return args +} + +func columnFragments(template *templateWithUtils, columns []interface{}) ([]exql.Fragment, []interface{}, error) { l := len(columns) f := make([]exql.Fragment, l) + args := []interface{}{} for i := 0; i < l; i++ { switch v := columns[i].(type) { + case *selector: + expanded, rawArgs := expandPlaceholders(v.statement().Compile(v.stringer.t), v.Arguments()...) + f[i] = exql.RawValue(expanded) + args = append(args, rawArgs...) case db.Function: - var s string - a := template.ToInterfaceArguments(v.Arguments()) - if len(a) == 0 { - s = fmt.Sprintf(`%s()`, v.Name()) + fnName, fnArgs := v.Name(), v.Arguments() + if len(fnArgs) == 0 { + fnName = fnName + "()" } else { - ss := make([]string, 0, len(a)) - for j := range a { - ss = append(ss, fmt.Sprintf(`%v`, a[j])) - } - s = fmt.Sprintf(`%s(%s)`, v.Name(), strings.Join(ss, `, `)) + fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")" } - f[i] = exql.RawValue(s) + expanded, fnArgs := expandPlaceholders(fnName, fnArgs) + f[i] = exql.RawValue(expanded) + args = append(args, fnArgs...) case db.RawValue: - f[i] = exql.RawValue(v.String()) + expanded, rawArgs := expandPlaceholders(v.Raw(), v.Arguments()...) + f[i] = exql.RawValue(expanded) + args = append(args, rawArgs...) case exql.Fragment: f[i] = v case string: @@ -280,11 +298,10 @@ func columnFragments(template *templateWithUtils, columns []interface{}) ([]exql case interface{}: f[i] = exql.ColumnWithName(fmt.Sprintf("%v", v)) default: - return nil, fmt.Errorf("Unexpected argument type %T for Select() argument.", v) + return nil, nil, fmt.Errorf("Unexpected argument type %T for Select() argument.", v) } } - - return f, nil + return f, args, nil } func (s *stringer) String() string { diff --git a/lib/sqlbuilder/builder_test.go b/lib/sqlbuilder/builder_test.go index 6deac459..bea0ef66 100644 --- a/lib/sqlbuilder/builder_test.go +++ b/lib/sqlbuilder/builder_test.go @@ -22,23 +22,32 @@ func TestSelect(t *testing.T) { b.SelectFrom("artist").String(), ) - assert.Equal( - `SELECT DISTINCT(name) FROM "artist"`, - b.Select(db.Func("DISTINCT", "name")).From("artist").String(), - ) + { + sel := b.Select(db.Func("DISTINCT", "name")).From("artist") + assert.Equal( + `SELECT DISTINCT($1) FROM "artist"`, + sel.String(), + ) + assert.Equal( + []interface{}{"name"}, + sel.Arguments(), + ) + } assert.Equal( `SELECT * FROM "artist" WHERE (1 = $1)`, b.Select().From("artist").Where(db.Cond{1: 1}).String(), ) - // TODO: handle this case - /* - assert.Equal( - `SELECT * FROM "artist" WHERE ($1 = ANY("column"))`, // search_term = ANY(name) - b.Select().From("artist").Where(db.Cond{1: db.Func("ANY", db.Raw("column"))}).String(), // ?? - ) - */ + assert.Equal( + `SELECT * FROM "artist" WHERE (1 = ANY($1))`, + b.Select().From("artist").Where(db.Cond{1: db.Func("ANY", "name")}).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE (1 = ANY(column))`, + b.Select().From("artist").Where(db.Cond{1: db.Func("ANY", db.Raw("column"))}).String(), + ) assert.Equal( `SELECT * FROM "artist" WHERE ("id" NOT IN ($1, $2))`, @@ -181,6 +190,11 @@ func TestSelect(t *testing.T) { b.SelectFrom("artist").Where("id IN", []int{1, 9, 8, 7}).String(), ) + assert.Equal( + `SELECT * FROM "artist" WHERE (id IN ($1, $2, $3, $4) AND foo = $5 AND bar IN ($6, $7, $8))`, + b.SelectFrom("artist").Where("id IN ? AND foo = ? AND bar IN ?", []int{1, 9, 8, 7}, 28, []int{1, 2, 3}).String(), + ) + assert.Equal( `SELECT * FROM "artist" WHERE (name IS NOT NULL)`, b.SelectFrom("artist").Where("name IS NOT NULL").String(), @@ -253,6 +267,216 @@ func TestSelect(t *testing.T) { `SELECT DATE()`, b.Select(db.Raw("DATE()")).String(), ) + + { + sel := b.Select(db.Raw("CONCAT(?, ?)", "foo", "bar")) + assert.Equal( + `SELECT CONCAT($1, $2)`, + sel.String(), + ) + assert.Equal( + []interface{}{"foo", "bar"}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where(db.Cond{"bar": db.Raw("1")}) + assert.Equal( + `SELECT * FROM "foo" WHERE ("bar" = 1)`, + sel.String(), + ) + assert.Equal( + []interface{}(nil), + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where(db.Cond{db.Raw("1"): 1}) + assert.Equal( + `SELECT * FROM "foo" WHERE (1 = $1)`, + sel.String(), + ) + assert.Equal( + []interface{}{1}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where(db.Cond{db.Raw("1"): db.Raw("1")}) + assert.Equal( + `SELECT * FROM "foo" WHERE (1 = 1)`, + sel.String(), + ) + assert.Equal( + []interface{}(nil), + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where(db.Raw("1 = 1")) + assert.Equal( + `SELECT * FROM "foo" WHERE (1 = 1)`, + sel.String(), + ) + assert.Equal( + []interface{}(nil), + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where(db.Cond{"bar": 1}, db.Cond{"baz": db.Raw("CONCAT(?, ?)", "foo", "bar")}) + assert.Equal( + `SELECT * FROM "foo" WHERE ("bar" = $1 AND "baz" = CONCAT($2, $3))`, + sel.String(), + ) + assert.Equal( + []interface{}{1, "foo", "bar"}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where(db.Cond{"bar": 1}, db.Raw("? = ANY(col)", "name")) + assert.Equal( + `SELECT * FROM "foo" WHERE ("bar" = $1 AND $2 = ANY(col))`, + sel.String(), + ) + assert.Equal( + []interface{}{1, "name"}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where(db.Cond{"bar": 1}, db.Cond{"name": db.Raw("ANY(col)")}) + assert.Equal( + `SELECT * FROM "foo" WHERE ("bar" = $1 AND "name" = ANY(col))`, + sel.String(), + ) + assert.Equal( + []interface{}{1}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where(db.Cond{"bar": 1}, db.Cond{db.Raw("CONCAT(?, ?)", "a", "b"): db.Raw("ANY(col)")}) + assert.Equal( + `SELECT * FROM "foo" WHERE ("bar" = $1 AND CONCAT($2, $3) = ANY(col))`, + sel.String(), + ) + assert.Equal( + []interface{}{1, "a", "b"}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where("bar", 2).And(db.Cond{"baz": 1}) + assert.Equal( + `SELECT * FROM "foo" WHERE ("bar" = $1 AND "baz" = $2)`, + sel.String(), + ) + assert.Equal( + []interface{}{2, 1}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").And(db.Cond{"bar": 1}) + assert.Equal( + `SELECT * FROM "foo" WHERE ("bar" = $1)`, + sel.String(), + ) + assert.Equal( + []interface{}{1}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where("bar", 2).Where(db.Cond{"baz": 1}) + assert.Equal( + `SELECT * FROM "foo" WHERE ("bar" = $1 AND "baz" = $2)`, + sel.String(), + ) + assert.Equal( + []interface{}{2, 1}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where(db.Raw("bar->'baz' = ?", true)) + assert.Equal( + `SELECT * FROM "foo" WHERE (bar->'baz' = $1)`, + sel.String(), + ) + assert.Equal( + []interface{}{true}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where(db.Cond{}).And(db.Cond{}) + assert.Equal( + `SELECT * FROM "foo"`, + sel.String(), + ) + assert.Equal( + []interface{}(nil), + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where("bar = 1").And(db.Or( + db.Raw("fieldA ILIKE ?", `%a%`), + db.Raw("fieldB ILIKE ?", `%b%`), + )) + assert.Equal( + `SELECT * FROM "foo" WHERE (bar = 1 AND (fieldA ILIKE $1 OR fieldB ILIKE $2))`, + sel.String(), + ) + assert.Equal( + []interface{}{`%a%`, `%b%`}, + sel.Arguments(), + ) + } + + { + s := `SUM(CASE WHEN foo in ? THEN 1 ELSE 0 END) AS _sum` + sel := b.Select("c1").Columns(db.Raw(s, []int{5, 4, 3, 2})).From("foo").Where("bar = ?", 1) + assert.Equal( + `SELECT "c1", SUM(CASE WHEN foo in ($1, $2, $3, $4) THEN 1 ELSE 0 END) AS _sum FROM "foo" WHERE (bar = $5)`, + sel.String(), + ) + assert.Equal( + []interface{}{5, 4, 3, 2, 1}, + sel.Arguments(), + ) + } + + { + s := `SUM(CASE WHEN foo in ? THEN 1 ELSE 0 END) AS _sum` + sel := b.Select("c1").Columns(db.Raw(s, []int{5, 4, 3, 2})).From("foo").Where("bar = ?", 1) + sel2 := b.SelectFrom(sel).As("subquery").Where(db.Cond{"foo": "bar"}).OrderBy("subquery.seq") + assert.Equal( + `SELECT * FROM (SELECT "c1", SUM(CASE WHEN foo in ($1, $2, $3, $4) THEN 1 ELSE 0 END) AS _sum FROM "foo" WHERE (bar = $5)) AS "subquery" WHERE ("foo" = $6) ORDER BY "subquery"."seq" ASC`, + sel2.String(), + ) + assert.Equal( + []interface{}{5, 4, 3, 2, 1, "bar"}, + sel2.Arguments(), + ) + } } func TestInsert(t *testing.T) { diff --git a/lib/sqlbuilder/convert.go b/lib/sqlbuilder/convert.go index 416a01a4..e1ecd446 100644 --- a/lib/sqlbuilder/convert.go +++ b/lib/sqlbuilder/convert.go @@ -24,6 +24,37 @@ func newTemplateWithUtils(template *exql.Template) *templateWithUtils { return &templateWithUtils{template} } +func expandPlaceholders(in string, args ...interface{}) (string, []interface{}) { + argn := 0 + for i := 0; i < len(in); i++ { + if in[i] == '?' { + if len(args) > argn { // we have arguments to match. + u := toInterfaceArguments(args[argn]) + k := `?` + + if len(u) > 1 { + // An array of arguments + k = `(?` + strings.Repeat(`, ?`, len(u)-1) + `)` + } else if len(u) == 1 { + if rawValue, ok := u[0].(db.RawValue); ok { + k = rawValue.Raw() + u = []interface{}{} + } + } + + lk := len(k) + if lk > 1 { + in = in[:i] + k + in[i+1:] + i += len(k) - 1 + } + args = append(args[:argn], append(u, args[argn+1:]...)...) + argn += len(u) + } + } + } + return in, args +} + // ToWhereWithArguments converts the given parameters into a exql.Where // value. func (tu *templateWithUtils) ToWhereWithArguments(term interface{}) (where exql.Where, args []interface{}) { @@ -34,38 +65,17 @@ func (tu *templateWithUtils) ToWhereWithArguments(term interface{}) (where exql. if len(t) > 0 { if s, ok := t[0].(string); ok { if strings.ContainsAny(s, "?") || len(t) == 1 { - var j int - - vv := t[1:] - - for i := 0; i < len(s); i++ { - if s[i] == '?' { - if len(vv) > j { - u := tu.ToInterfaceArguments(vv[j]) - args = append(args, u...) - j = j + 1 - if len(u) > 1 { - k := "(?" + strings.Repeat(", ?", len(u)-1) + ")" - s = s[:i] + k + s[i+1:] - i = i - 1 + len(k) - } - } - } - } - + s, args = expandPlaceholders(s, t[1:]...) where.Conditions = []exql.Fragment{exql.RawValue(s)} } else { var val interface{} key := s - if len(t) > 2 { val = t[1:] } else { val = t[1] } - cv, v := tu.ToColumnValues(db.NewConstraint(key, val)) - args = append(args, v...) for i := range cv.ColumnValues { where.Conditions = append(where.Conditions, cv.ColumnValues[i]) @@ -84,7 +94,9 @@ func (tu *templateWithUtils) ToWhereWithArguments(term interface{}) (where exql. } return case db.RawValue: - where.Conditions = []exql.Fragment{exql.RawValue(t.String())} + r, v := expandPlaceholders(t.Raw(), t.Arguments()...) + where.Conditions = []exql.Fragment{exql.RawValue(r)} + args = append(args, v...) return case db.Constraints: for _, c := range t.Constraints() { @@ -142,7 +154,7 @@ func (tu *templateWithUtils) PlaceholderValue(in interface{}) (exql.Fragment, [] fnName := t.Name() fnArgs := []interface{}{} - args := tu.ToInterfaceArguments(t.Arguments()) + args := toInterfaceArguments(t.Arguments()) fragments := []string{} for i := range args { frag, args := tu.PlaceholderValue(args[i]) @@ -156,8 +168,8 @@ func (tu *templateWithUtils) PlaceholderValue(in interface{}) (exql.Fragment, [] } } -// ToInterfaceArguments converts the given value into an array of interfaces. -func (tu *templateWithUtils) ToInterfaceArguments(value interface{}) (args []interface{}) { +// toInterfaceArguments converts the given value into an array of interfaces. +func toInterfaceArguments(value interface{}) (args []interface{}) { if value == nil { return nil } @@ -167,19 +179,15 @@ func (tu *templateWithUtils) ToInterfaceArguments(value interface{}) (args []int switch v.Type().Kind() { case reflect.Slice: var i, total int - if v.Type().Elem().Kind() == reflect.Uint8 { return []interface{}{string(value.([]byte))} } - total = v.Len() if total > 0 { args = make([]interface{}, total) - for i = 0; i < total; i++ { args[i] = v.Index(i).Interface() } - return args } return nil @@ -239,24 +247,33 @@ func (tu *templateWithUtils) ToColumnValues(term interface{}) (cv exql.ColumnVal columnValue.Operator = chunks[1] } } else { - columnValue.Column = exql.RawValue(fmt.Sprintf("%v", t.Key())) + if rawValue, ok := t.Key().(db.RawValue); ok { + columnValue.Column = exql.RawValue(rawValue.Raw()) + args = append(args, rawValue.Arguments()...) + } else { + columnValue.Column = exql.RawValue(fmt.Sprintf("%v", t.Key())) + } } switch value := t.Value().(type) { case db.Function: - v := tu.ToInterfaceArguments(value.Arguments()) - - if v == nil { + fnName, fnArgs := value.Name(), value.Arguments() + if len(fnArgs) == 0 { // A function with no arguments. - columnValue.Value = exql.RawValue(fmt.Sprintf(`%s()`, value.Name())) + fnName = fnName + "()" } else { // A function with one or more arguments. - columnValue.Value = exql.RawValue(fmt.Sprintf(`%s(?%s)`, value.Name(), strings.Repeat(`, ?`, len(v)-1))) + fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")" } - - args = append(args, v...) + expanded, fnArgs := expandPlaceholders(fnName, fnArgs) + columnValue.Value = exql.RawValue(expanded) + args = append(args, fnArgs...) + case db.RawValue: + expanded, rawArgs := expandPlaceholders(value.Raw(), value.Arguments()...) + columnValue.Value = exql.RawValue(expanded) + args = append(args, rawArgs...) default: - v := tu.ToInterfaceArguments(value) + v := toInterfaceArguments(value) if v == nil { // Nil value given. diff --git a/lib/sqlbuilder/interfaces.go b/lib/sqlbuilder/interfaces.go index 094e8639..b0c4382a 100644 --- a/lib/sqlbuilder/interfaces.go +++ b/lib/sqlbuilder/interfaces.go @@ -45,7 +45,7 @@ type Builder interface { // Example: // // q := sqlbuilder.SelectFrom("people").Where(...) - SelectFrom(table string) Selector + SelectFrom(table ...interface{}) Selector // InsertInto prepares an returns a Inserter that points at the given table. // @@ -139,7 +139,7 @@ type Selector interface { // Or with the shortcut: // // s.Columns(...).From("people p").Where("p.name = ?", ...) - From(tables ...string) Selector + From(tables ...interface{}) Selector // Distict represents a DISCTING clause. // @@ -147,6 +147,9 @@ type Selector interface { // different. Distinct() Selector + // As defines an alias for a table. + As(string) Selector + // Where specifies the conditions that columns must match in order to be // retrieved. // @@ -174,6 +177,9 @@ type Selector interface { // fmt.Sprintf("%v", arg) to transform the type into a string. Where(conds ...interface{}) Selector + // And appends more arguments to the WHERE clause. + And(conds ...interface{}) Selector + // GroupBy represents a GROUP BY statement. // // GROUP BY defines which columns should be used to aggregate and group @@ -283,6 +289,9 @@ type Selector interface { // fmt.Stringer provides `String() string`, you can use `String()` to compile // the `Selector` into a string. fmt.Stringer + + // Arguments returns the arguments that are prepared for this query. + Arguments() []interface{} } // Inserter represents an INSERT statement. diff --git a/lib/sqlbuilder/select.go b/lib/sqlbuilder/select.go index 3483b91f..854d54b0 100644 --- a/lib/sqlbuilder/select.go +++ b/lib/sqlbuilder/select.go @@ -21,7 +21,8 @@ type selector struct { *stringer mode selectMode builder *sqlBuilder - table string + table *exql.Columns + as string where *exql.Where groupBy *exql.GroupBy orderBy exql.OrderBy @@ -33,18 +34,34 @@ type selector struct { err error } -func (qs *selector) From(tables ...string) Selector { - qs.table = strings.Join(tables, ",") +func (qs *selector) From(tables ...interface{}) Selector { + f, args, err := columnFragments(qs.builder.t, tables) + if err != nil { + qs.err = err + return qs + } + c := exql.JoinColumns(f...) + qs.table = c + + qs.arguments = append(qs.arguments, args...) return qs } func (qs *selector) Columns(columns ...interface{}) Selector { - f, err := columnFragments(qs.builder.t, columns) + f, args, err := columnFragments(qs.builder.t, columns) if err != nil { qs.err = err return qs } - qs.columns = exql.JoinColumns(f...) + + c := exql.JoinColumns(f...) + if qs.columns != nil { + qs.columns.Append(c) + } else { + qs.columns = c + } + + qs.arguments = append(qs.arguments, args...) return qs } @@ -54,18 +71,39 @@ func (qs *selector) Distinct() Selector { } func (qs *selector) Where(terms ...interface{}) Selector { + if qs.where != nil { + return qs.And(terms...) + } where, arguments := qs.builder.t.ToWhereWithArguments(terms) qs.where = &where qs.arguments = append(qs.arguments, arguments...) return qs } +func (qs *selector) And(terms ...interface{}) Selector { + where, arguments := qs.builder.t.ToWhereWithArguments(terms) + if qs.where == nil { + qs.where = &exql.Where{} + } + qs.where.Append(&where) + qs.arguments = append(qs.arguments, arguments...) + return qs +} + +func (qs *selector) Arguments() []interface{} { + return qs.arguments +} + func (qs *selector) GroupBy(columns ...interface{}) Selector { - var fragments []exql.Fragment - fragments, qs.err = columnFragments(qs.builder.t, columns) + fragments, args, err := columnFragments(qs.builder.t, columns) + if err != nil { + qs.err = err + return qs + } if fragments != nil { qs.groupBy = exql.GroupByColumns(fragments...) } + qs.arguments = append(qs.arguments, args...) return qs } @@ -121,11 +159,12 @@ func (qs *selector) Using(columns ...interface{}) Selector { return qs } - fragments, err := columnFragments(qs.builder.t, columns) + fragments, args, err := columnFragments(qs.builder.t, columns) if err != nil { qs.err = err return qs } + qs.arguments = append(qs.arguments, args...) lastJoin.Using = exql.UsingColumns(fragments...) return qs @@ -205,7 +244,7 @@ func (qs *selector) Offset(n int) Selector { func (qs *selector) statement() *exql.Statement { return &exql.Statement{ Type: exql.Select, - Table: exql.TableWithName(qs.table), + Table: qs.table, Columns: qs.columns, Limit: qs.limit, Offset: qs.offset, @@ -220,6 +259,18 @@ func (qs *selector) Query() (*sql.Rows, error) { return qs.builder.sess.StatementQuery(qs.statement(), qs.arguments...) } +func (qs *selector) As(alias string) Selector { + if qs.table == nil { + qs.err = errors.New("Cannot use As() without a preceding From() expression") + return qs + } + last := len(qs.table.Columns) - 1 + if raw, ok := qs.table.Columns[last].(*exql.Raw); ok { + qs.table.Columns[last] = exql.RawValue("(" + raw.Value + ") AS " + exql.ColumnWithName(alias).Compile(qs.stringer.t)) + } + return qs +} + func (qs *selector) QueryRow() (*sql.Row, error) { return qs.builder.sess.StatementQueryRow(qs.statement(), qs.arguments...) } -- GitLab