diff --git a/mysql/collection.go b/mysql/collection.go index d1e66776fa4d03b2de4129154537389af7712f7f..785a88963129fd8e9cc1a2d592bcd423096e031f 100644 --- a/mysql/collection.go +++ b/mysql/collection.go @@ -27,6 +27,7 @@ import ( "fmt" _ "github.com/go-sql-driver/mysql" "menteslibres.net/gosexy/to" + "reflect" "strings" "time" "upper.io/db" @@ -112,27 +113,46 @@ func (self *Table) compileConditions(term interface{}) (string, []interface{}) { return "", args } -func (self *Table) compileStatement(where db.Cond) (string, []interface{}) { +func (self *Table) compileStatement(cond db.Cond) (string, []interface{}) { - str := make([]string, len(where)) - arg := make([]interface{}, len(where)) + total := len(cond) - i := 0 + str := make([]string, 0, total) + arg := make([]interface{}, 0, total) - for key, _ := range where { - key = strings.Trim(key, ` `) - chunks := strings.SplitN(key, ` `, 2) + // Walking over conditions + for field, value := range cond { + // Removing leading or trailing spaces. + field = strings.TrimSpace(field) + chunks := strings.SplitN(field, ` `, 2) + + // Default operator. op := `=` if len(chunks) > 1 { + // User has defined a different operator. op = chunks[1] } - str[i] = fmt.Sprintf(`%s %s ?`, chunks[0], op) - arg[i] = toInternal(where[key]) - - i++ + switch value := value.(type) { + case db.Func: + value_i := interfaceArgs(value.Args) + if value_i == nil { + str = append(str, fmt.Sprintf(`%s %s ()`, chunks[0], value.Name)) + } else { + str = append(str, fmt.Sprintf(`%s %s (?%s)`, chunks[0], value.Name, strings.Repeat(`,?`, len(value_i)-1))) + arg = append(arg, value_i...) + } + default: + value_i := interfaceArgs(value) + if value_i == nil { + str = append(str, fmt.Sprintf(`%s %s ()`, chunks[0], op)) + } else { + str = append(str, fmt.Sprintf(`%s %s (?%s)`, chunks[0], op, strings.Repeat(`,?`, len(value_i)-1))) + arg = append(arg, value_i...) + } + } } switch len(str) { @@ -235,3 +255,34 @@ func toInternal(val interface{}) interface{} { func toNative(val interface{}) interface{} { return val } + +func interfaceArgs(value interface{}) (args []interface{}) { + + if value == nil { + return nil + } + + value_v := reflect.ValueOf(value) + + switch value_v.Type().Kind() { + case reflect.Slice: + var i, total int + + total = value_v.Len() + if total > 0 { + args = make([]interface{}, total) + + for i = 0; i < total; i++ { + args[i] = toInternal(value_v.Index(i).Interface()) + } + + return args + } else { + return nil + } + default: + args = []interface{}{toInternal(value)} + } + + return args +} diff --git a/mysql/database_test.go b/mysql/database_test.go index cb610582ed708c31c361eaa4ff9d46c163179f90..7362110870e515ca8f42d1137ae376b5ad07edd0 100644 --- a/mysql/database_test.go +++ b/mysql/database_test.go @@ -513,6 +513,48 @@ func TestUpdate(t *testing.T) { } +// Test database functions +func TestFunction(t *testing.T) { + var err error + var res db.Result + + // Opening database. + sess, err := db.Open(wrapperName, settings) + + if err != nil { + t.Fatalf(err.Error()) + } + + // We should close the database when it's no longer in use. + defer sess.Close() + + // Getting a pointer to the "artist" collection. + artist, err := sess.Collection("artist") + + if err != nil { + t.Fatalf(err.Error()) + } + + row_s := struct { + Id uint64 + Name string + }{} + + res = artist.Find(db.Cond{"id NOT IN": []int{0, -1}}) + + if err = res.One(&row_s); err != nil { + t.Fatalf("One: %q", err) + } + + res = artist.Find(db.Cond{"id": db.Func{"NOT IN", []int{0, -1}}}) + + if err = res.One(&row_s); err != nil { + t.Fatalf("One: %q", err) + } + + res.Close() +} + // This test tries to remove some previously added rows. func TestRemove(t *testing.T) {