diff --git a/internal/sqladapter/exql/order_by_test.go b/internal/sqladapter/exql/order_by_test.go index 8a4ff471da34f585e505faf170ad536e09834d02..f800c58b638eff9878fc96e0f3c8f2005f095144 100644 --- a/internal/sqladapter/exql/order_by_test.go +++ b/internal/sqladapter/exql/order_by_test.go @@ -19,6 +19,21 @@ func TestOrderBy(t *testing.T) { } } +func TestOrderByRaw(t *testing.T) { + o := JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: RawValue("CASE WHEN id IN ? THEN 0 ELSE 1 END")}, + ), + ) + + s := o.Compile(defaultTemplate) + e := `ORDER BY CASE WHEN id IN ? THEN 0 ELSE 1 END` + + if trim(s) != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + func TestOrderByDesc(t *testing.T) { o := JoinWithOrderBy( JoinSortColumns( diff --git a/lib/sqlbuilder/builder_test.go b/lib/sqlbuilder/builder_test.go index 4ad9d8c600b526b807d798171b52a8144ae4fbed..1a2e5dbabb354b56461e4fdc86fe92cc2fdd9cd6 100644 --- a/lib/sqlbuilder/builder_test.go +++ b/lib/sqlbuilder/builder_test.go @@ -22,6 +22,71 @@ func TestSelect(t *testing.T) { b.SelectFrom("artist").String(), ) + { + rawCase := db.Raw("CASE WHEN id IN ? THEN 0 ELSE 1 END", []int{1000, 2000}) + sel := b.SelectFrom("artist").OrderBy(rawCase) + assert.Equal( + `SELECT * FROM "artist" ORDER BY CASE WHEN id IN ($1, $2) THEN 0 ELSE 1 END`, + sel.String(), + ) + assert.Equal( + []interface{}{1000, 2000}, + sel.Arguments(), + ) + } + + { + rawCase := db.Raw("CASE WHEN id IN ? THEN 0 ELSE 1 END", []int{1000}) + sel := b.SelectFrom("artist").OrderBy(rawCase) + assert.Equal( + `SELECT * FROM "artist" ORDER BY CASE WHEN id IN ($1) THEN 0 ELSE 1 END`, + sel.String(), + ) + assert.Equal( + []interface{}{1000}, + sel.Arguments(), + ) + } + + { + rawCase := db.Raw("CASE WHEN id IN ? THEN 0 ELSE 1 END", []int{}) + sel := b.SelectFrom("artist").OrderBy(rawCase) + assert.Equal( + `SELECT * FROM "artist" ORDER BY CASE WHEN id IN (NULL) THEN 0 ELSE 1 END`, + sel.String(), + ) + assert.Equal( + []interface{}(nil), + sel.Arguments(), + ) + } + + { + rawCase := db.Raw("CASE WHEN id IN (NULL) THEN 0 ELSE 1 END") + sel := b.SelectFrom("artist").OrderBy(rawCase) + assert.Equal( + `SELECT * FROM "artist" ORDER BY CASE WHEN id IN (NULL) THEN 0 ELSE 1 END`, + sel.String(), + ) + assert.Equal( + []interface{}(nil), + rawCase.Arguments(), + ) + } + + { + rawCase := db.Raw("CASE WHEN id IN (?, ?) THEN 0 ELSE 1 END", 1000, 2000) + sel := b.SelectFrom("artist").OrderBy(rawCase) + assert.Equal( + `SELECT * FROM "artist" ORDER BY CASE WHEN id IN ($1, $2) THEN 0 ELSE 1 END`, + sel.String(), + ) + assert.Equal( + []interface{}{1000, 2000}, + rawCase.Arguments(), + ) + } + { sel := b.Select(db.Func("DISTINCT", "name")).From("artist") assert.Equal( @@ -49,15 +114,29 @@ func TestSelect(t *testing.T) { b.Select().From("artist").Where(db.Cond{1: db.Func("ANY", db.Raw("column"))}).String(), ) - assert.Equal( - `SELECT * FROM "artist" WHERE ("id" NOT IN ($1, $2))`, - b.Select().From("artist").Where(db.Cond{"id NOT IN": []int{0, -1}}).String(), - ) + { + q := b.Select().From("artist").Where(db.Cond{"id NOT IN": []int{0, -1}}) + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" NOT IN ($1, $2))`, + q.String(), + ) + assert.Equal( + []interface{}{0, -1}, + q.Arguments(), + ) + } - assert.Equal( - `SELECT * FROM "artist" WHERE ("id" NOT IN ($1))`, - b.Select().From("artist").Where(db.Cond{"id NOT IN": []int{-1}}).String(), - ) + { + q := b.Select().From("artist").Where(db.Cond{"id NOT IN": []int{-1}}) + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" NOT IN ($1))`, + q.String(), + ) + assert.Equal( + []interface{}{-1}, + q.Arguments(), + ) + } assert.Equal( `SELECT * FROM "artist" WHERE ("id" IN ($1, $2))`, @@ -288,7 +367,7 @@ func TestSelect(t *testing.T) { ) assert.Equal( - `SELECT * FROM "artist" WHERE ("id" IS NULL)`, + `SELECT * FROM "artist" WHERE ("id" IN (NULL))`, b.SelectFrom("artist").Where(db.Cond{"id": []int64{}}).String(), ) @@ -671,7 +750,7 @@ func TestUpdate(t *testing.T) { idSlice := []int64{} q := b.Update("artist").Set(db.Cond{"some_column": 10}).Where(db.Cond{"id": 1}, db.Cond{"another_val": idSlice}) assert.Equal( - `UPDATE "artist" SET "some_column" = $1 WHERE ("id" = $2 AND "another_val" IS NULL)`, + `UPDATE "artist" SET "some_column" = $1 WHERE ("id" = $2 AND "another_val" IN (NULL))`, q.String(), ) assert.Equal( @@ -684,7 +763,7 @@ func TestUpdate(t *testing.T) { idSlice := []int64{} q := b.Update("artist").Where(db.Cond{"id": 1}, db.Cond{"another_val": idSlice}).Set(db.Cond{"some_column": 10}) assert.Equal( - `UPDATE "artist" SET "some_column" = $1 WHERE ("id" = $2 AND "another_val" IS NULL)`, + `UPDATE "artist" SET "some_column" = $1 WHERE ("id" = $2 AND "another_val" IN (NULL))`, q.String(), ) assert.Equal( diff --git a/lib/sqlbuilder/convert.go b/lib/sqlbuilder/convert.go index e1ecd446476b10332aff30b03a17fb7b0a2168c1..b410281f4e1ceffa469ec548546881b158e6d048 100644 --- a/lib/sqlbuilder/convert.go +++ b/lib/sqlbuilder/convert.go @@ -26,33 +26,42 @@ func newTemplateWithUtils(template *exql.Template) *templateWithUtils { func expandPlaceholders(in string, args ...interface{}) (string, []interface{}) { argn := 0 + argx := make([]interface{}, 0, len(args)) for i := 0; i < len(in); i++ { if in[i] == '?' { - if len(args) > argn { // we have arguments to match. - u := toInterfaceArguments(args[argn]) + if len(args) > argn { k := `?` - if len(u) > 1 { - // An array of arguments - k = `(?` + strings.Repeat(`, ?`, len(u)-1) + `)` - } else if len(u) == 1 { - if rawValue, ok := u[0].(db.RawValue); ok { - k = rawValue.Raw() - u = []interface{}{} + values, isSlice := toInterfaceArguments(args[argn]) + if isSlice { + if len(values) == 0 { + k = `(NULL)` + } else { + k = `(?` + strings.Repeat(`, ?`, len(values)-1) + `)` + } + } else { + if len(values) == 1 { + if rawValue, ok := values[0].(db.RawValue); ok { + k, values = rawValue.Raw(), nil + } + } else if len(values) == 0 { + k = `NULL` } } - lk := len(k) - if lk > 1 { + if k != `?` { in = in[:i] + k + in[i+1:] i += len(k) - 1 } - args = append(args[:argn], append(u, args[argn+1:]...)...) - argn += len(u) + + if len(values) > 0 { + argx = append(argx, values...) + } + argn++ } } } - return in, args + return in, argx } // ToWhereWithArguments converts the given parameters into a exql.Where @@ -154,7 +163,7 @@ func (tu *templateWithUtils) PlaceholderValue(in interface{}) (exql.Fragment, [] fnName := t.Name() fnArgs := []interface{}{} - args := toInterfaceArguments(t.Arguments()) + args, _ := toInterfaceArguments(t.Arguments()) fragments := []string{} for i := range args { frag, args := tu.PlaceholderValue(args[i]) @@ -169,33 +178,30 @@ func (tu *templateWithUtils) PlaceholderValue(in interface{}) (exql.Fragment, [] } // toInterfaceArguments converts the given value into an array of interfaces. -func toInterfaceArguments(value interface{}) (args []interface{}) { +func toInterfaceArguments(value interface{}) (args []interface{}, isSlice bool) { + v := reflect.ValueOf(value) + if value == nil { - return nil + return nil, false } - v := reflect.ValueOf(value) - - switch v.Type().Kind() { - case reflect.Slice: + if v.Type().Kind() == reflect.Slice { var i, total int + + // Byte slice gets transformed into a string. if v.Type().Elem().Kind() == reflect.Uint8 { - return []interface{}{string(value.([]byte))} + return []interface{}{string(value.([]byte))}, false } + total = v.Len() - if total > 0 { - args = make([]interface{}, total) - for i = 0; i < total; i++ { - args[i] = v.Index(i).Interface() - } - return args + args = make([]interface{}, total) + for i = 0; i < total; i++ { + args[i] = v.Index(i).Interface() } - return nil - default: - args = []interface{}{value} + return args, true } - return args + return []interface{}{value}, false } // ToColumnValues converts the given conditions into a exql.ColumnValues struct. @@ -265,7 +271,7 @@ func (tu *templateWithUtils) ToColumnValues(term interface{}) (cv exql.ColumnVal // A function with one or more arguments. fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")" } - expanded, fnArgs := expandPlaceholders(fnName, fnArgs) + expanded, fnArgs := expandPlaceholders(fnName, fnArgs...) columnValue.Value = exql.RawValue(expanded) args = append(args, fnArgs...) case db.RawValue: @@ -273,27 +279,33 @@ func (tu *templateWithUtils) ToColumnValues(term interface{}) (cv exql.ColumnVal columnValue.Value = exql.RawValue(expanded) args = append(args, rawArgs...) default: - v := toInterfaceArguments(value) + v, isSlice := toInterfaceArguments(value) - if v == nil { - // Nil value given. - columnValue.Value = sqlNull + if isSlice { if columnValue.Operator == "" { - columnValue.Operator = sqlIsOperator + columnValue.Operator = sqlInOperator } - } else { - if len(v) > 1 || reflect.TypeOf(value).Kind() == reflect.Slice { + if len(v) > 0 { // Array value given. columnValue.Value = exql.RawValue(fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))) + } else { + // Single value given. + columnValue.Value = exql.RawValue(`(NULL)`) + } + args = append(args, v...) + } else { + if v == nil { + // Nil value given. + columnValue.Value = sqlNull if columnValue.Operator == "" { - columnValue.Operator = sqlInOperator + columnValue.Operator = sqlIsOperator } } else { - // Single value given. columnValue.Value = sqlPlaceholder + args = append(args, v...) } - args = append(args, v...) } + } // Using guessed operator if no operator was given. diff --git a/lib/sqlbuilder/placeholder_test.go b/lib/sqlbuilder/placeholder_test.go index 82f472cd26477c0dfee15e28483cf37b0ff4ca04..3f05da3912e09c504e46f3775f576717f170caec 100644 --- a/lib/sqlbuilder/placeholder_test.go +++ b/lib/sqlbuilder/placeholder_test.go @@ -48,7 +48,7 @@ func TestPlaceholderArray(t *testing.T) { { ret, _ := expandPlaceholders("??", []interface{}{1, 2, 3}, []interface{}{}, []interface{}{4, 5}, []interface{}{}) - assert.Equal(t, "(?, ?, ?)?", ret) + assert.Equal(t, "(?, ?, ?)(NULL)", ret) } } diff --git a/lib/sqlbuilder/select.go b/lib/sqlbuilder/select.go index 131911dfd46ed9b11759a33bba7166ebd2cb50dc..695748ee38c3669e038501f2963282cbda5cec61 100644 --- a/lib/sqlbuilder/select.go +++ b/lib/sqlbuilder/select.go @@ -35,7 +35,7 @@ type selector struct { groupBy *exql.GroupBy groupByArgs []interface{} - orderBy exql.OrderBy + orderBy *exql.OrderBy orderByArgs []interface{} limit exql.Limit @@ -161,7 +161,7 @@ func (qs *selector) OrderBy(columns ...interface{}) Selector { Column: exql.RawValue(col), } qs.mu.Lock() - qs.orderByArgs = args + qs.orderByArgs = append(qs.orderByArgs, args...) qs.mu.Unlock() case db.Function: fnName, fnArgs := value.Name(), value.Arguments() @@ -175,7 +175,7 @@ func (qs *selector) OrderBy(columns ...interface{}) Selector { Column: exql.RawValue(expanded), } qs.mu.Lock() - qs.orderByArgs = fnArgs + qs.orderByArgs = append(qs.orderByArgs, fnArgs...) qs.mu.Unlock() case string: if strings.HasPrefix(value, "-") { @@ -204,7 +204,9 @@ func (qs *selector) OrderBy(columns ...interface{}) Selector { } qs.mu.Lock() - qs.orderBy.SortColumns = &sortColumns + qs.orderBy = &exql.OrderBy{ + SortColumns: &sortColumns, + } qs.mu.Unlock() return qs @@ -332,7 +334,7 @@ func (qs *selector) statement() *exql.Statement { Offset: qs.offset, Joins: exql.JoinConditions(qs.joins...), Where: qs.where, - OrderBy: &qs.orderBy, + OrderBy: qs.orderBy, GroupBy: qs.groupBy, } }