diff --git a/lib/sqlbuilder/builder.go b/lib/sqlbuilder/builder.go index 51aee331431c51cba6328978c88d145e09f50575..fd0d9375b60079492378b2e2fa42f52207c6b687 100644 --- a/lib/sqlbuilder/builder.go +++ b/lib/sqlbuilder/builder.go @@ -284,7 +284,7 @@ func columnFragments(template *templateWithUtils, columns []interface{}) ([]exql } else { fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")" } - expanded, fnArgs := expandPlaceholders(fnName, fnArgs) + expanded, fnArgs := expandPlaceholders(fnName, fnArgs...) f[i] = exql.RawValue(expanded) args = append(args, fnArgs...) case db.RawValue: diff --git a/lib/sqlbuilder/builder_test.go b/lib/sqlbuilder/builder_test.go index bea0ef66f524ce94d32030084b5f6785c28a9f5e..0be5715244feee81bfd704893c5aefb4d991df06 100644 --- a/lib/sqlbuilder/builder_test.go +++ b/lib/sqlbuilder/builder_test.go @@ -125,6 +125,30 @@ func TestSelect(t *testing.T) { b.Select().From("artist").OrderBy("name DESC").String(), ) + { + sel := b.Select().From("artist").OrderBy(db.Raw("id = ?", 1), "name DESC") + assert.Equal( + `SELECT * FROM "artist" ORDER BY id = $1 , "name" DESC`, + sel.String(), + ) + assert.Equal( + []interface{}{1}, + sel.Arguments(), + ) + } + + { + sel := b.Select().From("artist").OrderBy(db.Func("RAND")) + assert.Equal( + `SELECT * FROM "artist" ORDER BY RAND()`, + sel.String(), + ) + assert.Equal( + []interface{}(nil), + sel.Arguments(), + ) + } + assert.Equal( `SELECT * FROM "artist" ORDER BY RAND()`, b.Select().From("artist").OrderBy(db.Raw("RAND()")).String(), diff --git a/lib/sqlbuilder/select.go b/lib/sqlbuilder/select.go index 854d54b020275f53b36adc3ffc6fc9c60189e5e8..29aaddbabbc4c9536603aa4260cc39ce1a3e8d0b 100644 --- a/lib/sqlbuilder/select.go +++ b/lib/sqlbuilder/select.go @@ -115,9 +115,23 @@ func (qs *selector) OrderBy(columns ...interface{}) Selector { switch value := columns[i].(type) { case db.RawValue: + col, args := expandPlaceholders(value.Raw(), value.Arguments()...) sort = &exql.SortColumn{ - Column: exql.RawValue(value.String()), + Column: exql.RawValue(col), } + qs.arguments = append(qs.arguments, args...) + case db.Function: + fnName, fnArgs := value.Name(), value.Arguments() + if len(fnArgs) == 0 { + fnName = fnName + "()" + } else { + fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")" + } + expanded, fnArgs := expandPlaceholders(fnName, fnArgs...) + sort = &exql.SortColumn{ + Column: exql.RawValue(expanded), + } + qs.arguments = append(qs.arguments, fnArgs...) case string: if strings.HasPrefix(value, "-") { sort = &exql.SortColumn{ @@ -137,6 +151,9 @@ func (qs *selector) OrderBy(columns ...interface{}) Selector { Order: order, } } + default: + qs.err = fmt.Errorf("Can't sort by type %T", value) + return qs } sortColumns.Columns = append(sortColumns.Columns, sort) }