diff --git a/db.go b/db.go index 47bd2e71137179b30d16147bce078a5a7835b75f..e35228c2a5d4f742bdf3744b158ea8881fc4e9d9 100644 --- a/db.go +++ b/db.go @@ -45,6 +45,14 @@ import ( */ type Cond map[string]interface{} +/* + The db.Func expression is used to represent database functions. +*/ +type Func struct { + Name string + Args interface{} +} + /* The db.And() expression is used to glue two or more expressions under logical conjunction, it accepts db.Cond{}, db.Or() and other db.And() expressions. diff --git a/mongo/collection.go b/mongo/collection.go index 97bf700a53ad759373998086b1303289ad4fcd46..f3fb164a145369636ff64138bc7b87d6f1acb7ae 100644 --- a/mongo/collection.go +++ b/mongo/collection.go @@ -70,15 +70,19 @@ func (self *Collection) Find(terms ...interface{}) db.Result { } // Transforms conditions into something *mgo.Session can understand. -func compileStatement(where db.Cond) bson.M { +func compileStatement(cond db.Cond) bson.M { conds := bson.M{} - for key, val := range where { - key = strings.Trim(key, ` `) - chunks := strings.SplitN(key, ` `, 2) + // Walking over conditions + for field, value := range cond { + // Removing leading or trailing spaces. + field = strings.TrimSpace(field) + + chunks := strings.SplitN(field, ` `, 2) + + var op string if len(chunks) > 1 { - op := "" switch chunks[1] { case `>`: op = `$gt` @@ -91,11 +95,17 @@ func compileStatement(where db.Cond) bson.M { default: op = chunks[1] } - //conds[chunks[0]] = bson.M{op: toInternal(val)} - conds[chunks[0]] = bson.M{op: val} - } else { - //conds[key] = toInternal(val) - conds[key] = val + } + + switch value := value.(type) { + case db.Func: + conds[chunks[0]] = bson.M{value.Name: value.Args} + default: + if op == "" { + conds[chunks[0]] = value + } else { + conds[chunks[0]] = bson.M{op: value} + } } } diff --git a/mongo/database_test.go b/mongo/database_test.go index a071612f7bb042a752124438b7a3c2237d13e89a..12bc45350f9f420aa0dba5cd93139f660eea45f8 100644 --- a/mongo/database_test.go +++ b/mongo/database_test.go @@ -532,6 +532,48 @@ func TestUpdate(t *testing.T) { } +// Test database functions +func TestFunction(t *testing.T) { + var err error + var res db.Result + + // Opening database. + sess, err := db.Open(wrapperName, settings) + + if err != nil { + t.Fatalf(err.Error()) + } + + // We should close the database when it's no longer in use. + defer sess.Close() + + // Getting a pointer to the "artist" collection. + artist, err := sess.Collection("artist") + + if err != nil { + t.Fatalf(err.Error()) + } + + row_s := struct { + Id uint64 + Name string + }{} + + res = artist.Find(db.Cond{"_id $nin": []int{0, -1}}) + + if err = res.One(&row_s); err != nil { + t.Fatalf("One: %q", err) + } + + res = artist.Find(db.Cond{"_id": db.Func{"$nin", []int{0, -1}}}) + + if err = res.One(&row_s); err != nil { + t.Fatalf("One: %q", err) + } + + res.Close() +} + // This test tries to remove some previously added rows. func TestRemove(t *testing.T) { diff --git a/mongo/result.go b/mongo/result.go index 473f7640d4b030e31d4ad1e2d66641c32b26a9be..abc6487c9cc7f7131e3ad1394f31cf6cc9775f36 100644 --- a/mongo/result.go +++ b/mongo/result.go @@ -126,7 +126,11 @@ func (self *Result) Next(dst interface{}) error { success := self.iter.Next(dst) if success == false { - return db.ErrNoMoreRows + err := self.iter.Err() + if err == nil { + return db.ErrNoMoreRows + } + return err } return nil diff --git a/mysql/collection.go b/mysql/collection.go index d1e66776fa4d03b2de4129154537389af7712f7f..785a88963129fd8e9cc1a2d592bcd423096e031f 100644 --- a/mysql/collection.go +++ b/mysql/collection.go @@ -27,6 +27,7 @@ import ( "fmt" _ "github.com/go-sql-driver/mysql" "menteslibres.net/gosexy/to" + "reflect" "strings" "time" "upper.io/db" @@ -112,27 +113,46 @@ func (self *Table) compileConditions(term interface{}) (string, []interface{}) { return "", args } -func (self *Table) compileStatement(where db.Cond) (string, []interface{}) { +func (self *Table) compileStatement(cond db.Cond) (string, []interface{}) { - str := make([]string, len(where)) - arg := make([]interface{}, len(where)) + total := len(cond) - i := 0 + str := make([]string, 0, total) + arg := make([]interface{}, 0, total) - for key, _ := range where { - key = strings.Trim(key, ` `) - chunks := strings.SplitN(key, ` `, 2) + // Walking over conditions + for field, value := range cond { + // Removing leading or trailing spaces. + field = strings.TrimSpace(field) + chunks := strings.SplitN(field, ` `, 2) + + // Default operator. op := `=` if len(chunks) > 1 { + // User has defined a different operator. op = chunks[1] } - str[i] = fmt.Sprintf(`%s %s ?`, chunks[0], op) - arg[i] = toInternal(where[key]) - - i++ + switch value := value.(type) { + case db.Func: + value_i := interfaceArgs(value.Args) + if value_i == nil { + str = append(str, fmt.Sprintf(`%s %s ()`, chunks[0], value.Name)) + } else { + str = append(str, fmt.Sprintf(`%s %s (?%s)`, chunks[0], value.Name, strings.Repeat(`,?`, len(value_i)-1))) + arg = append(arg, value_i...) + } + default: + value_i := interfaceArgs(value) + if value_i == nil { + str = append(str, fmt.Sprintf(`%s %s ()`, chunks[0], op)) + } else { + str = append(str, fmt.Sprintf(`%s %s (?%s)`, chunks[0], op, strings.Repeat(`,?`, len(value_i)-1))) + arg = append(arg, value_i...) + } + } } switch len(str) { @@ -235,3 +255,34 @@ func toInternal(val interface{}) interface{} { func toNative(val interface{}) interface{} { return val } + +func interfaceArgs(value interface{}) (args []interface{}) { + + if value == nil { + return nil + } + + value_v := reflect.ValueOf(value) + + switch value_v.Type().Kind() { + case reflect.Slice: + var i, total int + + total = value_v.Len() + if total > 0 { + args = make([]interface{}, total) + + for i = 0; i < total; i++ { + args[i] = toInternal(value_v.Index(i).Interface()) + } + + return args + } else { + return nil + } + default: + args = []interface{}{toInternal(value)} + } + + return args +} diff --git a/mysql/database_test.go b/mysql/database_test.go index cb610582ed708c31c361eaa4ff9d46c163179f90..7362110870e515ca8f42d1137ae376b5ad07edd0 100644 --- a/mysql/database_test.go +++ b/mysql/database_test.go @@ -513,6 +513,48 @@ func TestUpdate(t *testing.T) { } +// Test database functions +func TestFunction(t *testing.T) { + var err error + var res db.Result + + // Opening database. + sess, err := db.Open(wrapperName, settings) + + if err != nil { + t.Fatalf(err.Error()) + } + + // We should close the database when it's no longer in use. + defer sess.Close() + + // Getting a pointer to the "artist" collection. + artist, err := sess.Collection("artist") + + if err != nil { + t.Fatalf(err.Error()) + } + + row_s := struct { + Id uint64 + Name string + }{} + + res = artist.Find(db.Cond{"id NOT IN": []int{0, -1}}) + + if err = res.One(&row_s); err != nil { + t.Fatalf("One: %q", err) + } + + res = artist.Find(db.Cond{"id": db.Func{"NOT IN", []int{0, -1}}}) + + if err = res.One(&row_s); err != nil { + t.Fatalf("One: %q", err) + } + + res.Close() +} + // This test tries to remove some previously added rows. func TestRemove(t *testing.T) { diff --git a/postgresql/collection.go b/postgresql/collection.go index 41164c6783f244418261d4df62b241c691092a6e..6d9ea81f5afc0bffc3ce06288a8bafdd286bfe5a 100644 --- a/postgresql/collection.go +++ b/postgresql/collection.go @@ -26,6 +26,7 @@ package postgresql import ( "fmt" "menteslibres.net/gosexy/to" + "reflect" "strings" "time" "upper.io/db" @@ -110,27 +111,46 @@ func (self *Table) compileConditions(term interface{}) (string, []interface{}) { return "", args } -func (self *Table) compileStatement(where db.Cond) (string, []interface{}) { +func (self *Table) compileStatement(cond db.Cond) (string, []interface{}) { - str := make([]string, len(where)) - arg := make([]interface{}, len(where)) + total := len(cond) - i := 0 + str := make([]string, 0, total) + arg := make([]interface{}, 0, total) - for key, _ := range where { - key = strings.Trim(key, ` `) - chunks := strings.SplitN(key, ` `, 2) + // Walking over conditions + for field, value := range cond { + // Removing leading or trailing spaces. + field = strings.TrimSpace(field) + chunks := strings.SplitN(field, ` `, 2) + + // Default operator. op := `=` if len(chunks) > 1 { + // User has defined a different operator. op = chunks[1] } - str[i] = fmt.Sprintf(`%s %s ?`, chunks[0], op) - arg[i] = toInternal(where[key]) - - i++ + switch value := value.(type) { + case db.Func: + value_i := interfaceArgs(value.Args) + if value_i == nil { + str = append(str, fmt.Sprintf(`%s %s ()`, chunks[0], value.Name)) + } else { + str = append(str, fmt.Sprintf(`%s %s (?%s)`, chunks[0], value.Name, strings.Repeat(`,?`, len(value_i)-1))) + arg = append(arg, value_i...) + } + default: + value_i := interfaceArgs(value) + if value_i == nil { + str = append(str, fmt.Sprintf(`%s %s ()`, chunks[0], op)) + } else { + str = append(str, fmt.Sprintf(`%s %s (?%s)`, chunks[0], op, strings.Repeat(`,?`, len(value_i)-1))) + arg = append(arg, value_i...) + } + } } switch len(str) { @@ -219,7 +239,6 @@ func toInternalInterface(val interface{}) interface{} { // Converts a Go value into internal database representation. func toInternal(val interface{}) interface{} { - switch t := val.(type) { case []byte: return string(t) @@ -234,7 +253,6 @@ func toInternal(val interface{}) interface{} { return `0` } } - return to.String(val) } @@ -242,3 +260,34 @@ func toInternal(val interface{}) interface{} { func toNative(val interface{}) interface{} { return val } + +func interfaceArgs(value interface{}) (args []interface{}) { + + if value == nil { + return nil + } + + value_v := reflect.ValueOf(value) + + switch value_v.Type().Kind() { + case reflect.Slice: + var i, total int + + total = value_v.Len() + if total > 0 { + args = make([]interface{}, total) + + for i = 0; i < total; i++ { + args[i] = toInternal(value_v.Index(i).Interface()) + } + + return args + } else { + return nil + } + default: + args = []interface{}{toInternal(value)} + } + + return args +} diff --git a/postgresql/database_test.go b/postgresql/database_test.go index cdf595215a9fa812441b140e8355f794c58ad3d6..80231380466522759a92f170a819f510205bce83 100644 --- a/postgresql/database_test.go +++ b/postgresql/database_test.go @@ -507,6 +507,48 @@ func TestUpdate(t *testing.T) { } +// Test database functions +func TestFunction(t *testing.T) { + var err error + var res db.Result + + // Opening database. + sess, err := db.Open(wrapperName, settings) + + if err != nil { + t.Fatalf(err.Error()) + } + + // We should close the database when it's no longer in use. + defer sess.Close() + + // Getting a pointer to the "artist" collection. + artist, err := sess.Collection("artist") + + if err != nil { + t.Fatalf(err.Error()) + } + + row_s := struct { + Id uint64 + Name string + }{} + + res = artist.Find(db.Cond{"id NOT IN": []int{0, -1}}) + + if err = res.One(&row_s); err != nil { + t.Fatalf("One: %q", err) + } + + res = artist.Find(db.Cond{"id": db.Func{"NOT IN", []int{0, -1}}}) + + if err = res.One(&row_s); err != nil { + t.Fatalf("One: %q", err) + } + + res.Close() +} + // This test tries to remove some previously added rows. func TestRemove(t *testing.T) { diff --git a/ql/collection.go b/ql/collection.go index bcdc3d43f9a64eabdff8c54f68a29b27c27a6b06..4adcc8a5776a78b71f567b7aa8bb419a60f50e1d 100644 --- a/ql/collection.go +++ b/ql/collection.go @@ -25,6 +25,7 @@ package ql import ( "fmt" + "reflect" "strings" "upper.io/db" "upper.io/db/util/sqlutil" @@ -113,37 +114,56 @@ func (self *Table) compileConditions(term interface{}) (string, []interface{}) { return "", args } -func (self *Table) compileStatement(where db.Cond) (string, []interface{}) { +func (self *Table) compileStatement(cond db.Cond) (string, []interface{}) { - str := make([]string, len(where)) - arg := make([]interface{}, len(where)) + total := len(cond) - i := 0 + str := make([]string, 0, total) + arg := make([]interface{}, 0, total) - for key, _ := range where { - key = strings.Trim(key, ` `) - chunks := strings.SplitN(key, ` `, 2) + // Walking over conditions + for field, value := range cond { + // Removing leading or trailing spaces. + field = strings.TrimSpace(field) + chunks := strings.SplitN(field, ` `, 2) + + // Default operator. op := `==` if len(chunks) > 1 { + // User has defined a different operator. op = chunks[1] } - str[i] = fmt.Sprintf(`%s %s ?`, chunks[0], op) - arg[i] = where[key] - - i++ + switch value := value.(type) { + case db.Func: + value_i := interfaceArgs(value.Args) + if value_i == nil { + str = append(str, fmt.Sprintf(`%s %s ()`, chunks[0], value.Name)) + } else { + str = append(str, fmt.Sprintf(`%s %s (?%s)`, chunks[0], value.Name, strings.Repeat(`,?`, len(value_i)-1))) + arg = append(arg, value_i...) + } + default: + value_i := interfaceArgs(value) + if value_i == nil { + str = append(str, fmt.Sprintf(`%s %s ()`, chunks[0], op)) + } else { + str = append(str, fmt.Sprintf(`%s %s (?%s)`, chunks[0], op, strings.Repeat(`,?`, len(value_i)-1))) + arg = append(arg, value_i...) + } + } } switch len(str) { case 1: return str[0], arg case 0: - return "", nil + return "", []interface{}{} } - return `(` + strings.Join(str, ` && `) + `)`, arg + return `(` + strings.Join(str, ` AND `) + `)`, arg } // Deletes all the rows within the collection. @@ -206,3 +226,34 @@ func (self *Table) Exists() bool { return rows.Next() } + +func interfaceArgs(value interface{}) (args []interface{}) { + + if value == nil { + return nil + } + + value_v := reflect.ValueOf(value) + + switch value_v.Type().Kind() { + case reflect.Slice: + var i, total int + + total = value_v.Len() + if total > 0 { + args = make([]interface{}, total) + + for i = 0; i < total; i++ { + args[i] = value_v.Index(i).Interface() + } + + return args + } else { + return nil + } + default: + args = []interface{}{value} + } + + return args +} diff --git a/ql/database_test.go b/ql/database_test.go index c3af4b82ff04d3f1d6712783f8d1e5a51be3721b..844620b1d769f0e3e8ef0567126b8f7bd09f338f 100644 --- a/ql/database_test.go +++ b/ql/database_test.go @@ -505,6 +505,48 @@ func TestUpdate(t *testing.T) { } +// Test database functions +func TestFunction(t *testing.T) { + var err error + var res db.Result + + // Opening database. + sess, err := db.Open(wrapperName, settings) + + if err != nil { + t.Fatalf(err.Error()) + } + + // We should close the database when it's no longer in use. + defer sess.Close() + + // Getting a pointer to the "artist" collection. + artist, err := sess.Collection("artist") + + if err != nil { + t.Fatalf(err.Error()) + } + + row_s := struct { + Id uint64 + Name string + }{} + + res = artist.Find(db.Cond{"id() NOT IN": []int{0, -1}}) + + if err = res.One(&row_s); err != nil { + t.Fatalf("One: %q", err) + } + + res = artist.Find(db.Cond{"id()": db.Func{"NOT IN", []int{0, -1}}}) + + if err = res.One(&row_s); err != nil { + t.Fatalf("One: %q", err) + } + + res.Close() +} + // This test tries to remove some previously added rows. func TestRemove(t *testing.T) { diff --git a/sqlite/collection.go b/sqlite/collection.go index dfac77a5a4fde7f870335381d8e48415560b4f01..02b26a56dbd2e03bd84fb061f694e3b92a5e8bd3 100644 --- a/sqlite/collection.go +++ b/sqlite/collection.go @@ -26,6 +26,7 @@ package sqlite import ( "fmt" "menteslibres.net/gosexy/to" + "reflect" "strings" "time" "upper.io/db" @@ -111,27 +112,46 @@ func (self *Table) compileConditions(term interface{}) (string, []interface{}) { return "", args } -func (self *Table) compileStatement(where db.Cond) (string, []interface{}) { +func (self *Table) compileStatement(cond db.Cond) (string, []interface{}) { - str := make([]string, len(where)) - arg := make([]interface{}, len(where)) + total := len(cond) - i := 0 + str := make([]string, 0, total) + arg := make([]interface{}, 0, total) - for key, _ := range where { - key = strings.Trim(key, ` `) - chunks := strings.SplitN(key, ` `, 2) + // Walking over conditions + for field, value := range cond { + // Removing leading or trailing spaces. + field = strings.TrimSpace(field) + chunks := strings.SplitN(field, ` `, 2) + + // Default operator. op := `=` if len(chunks) > 1 { + // User has defined a different operator. op = chunks[1] } - str[i] = fmt.Sprintf(`%s %s ?`, chunks[0], op) - arg[i] = toInternal(where[key]) - - i++ + switch value := value.(type) { + case db.Func: + value_i := interfaceArgs(value.Args) + if value_i == nil { + str = append(str, fmt.Sprintf(`%s %s ()`, chunks[0], value.Name)) + } else { + str = append(str, fmt.Sprintf(`%s %s (?%s)`, chunks[0], value.Name, strings.Repeat(`,?`, len(value_i)-1))) + arg = append(arg, value_i...) + } + default: + value_i := interfaceArgs(value) + if value_i == nil { + str = append(str, fmt.Sprintf(`%s %s ()`, chunks[0], op)) + } else { + str = append(str, fmt.Sprintf(`%s %s (?%s)`, chunks[0], op, strings.Repeat(`,?`, len(value_i)-1))) + arg = append(arg, value_i...) + } + } } switch len(str) { @@ -232,3 +252,34 @@ func toInternal(val interface{}) interface{} { func toNative(val interface{}) interface{} { return val } + +func interfaceArgs(value interface{}) (args []interface{}) { + + if value == nil { + return nil + } + + value_v := reflect.ValueOf(value) + + switch value_v.Type().Kind() { + case reflect.Slice: + var i, total int + + total = value_v.Len() + if total > 0 { + args = make([]interface{}, total) + + for i = 0; i < total; i++ { + args[i] = toInternal(value_v.Index(i).Interface()) + } + + return args + } else { + return nil + } + default: + args = []interface{}{toInternal(value)} + } + + return args +} diff --git a/sqlite/database_test.go b/sqlite/database_test.go index 4614f005a41b601bdbdeb8a5c6f94005955fbdf9..611a12e0cf484f25ff6c4627b1eb40fd959aea51 100644 --- a/sqlite/database_test.go +++ b/sqlite/database_test.go @@ -505,6 +505,48 @@ func TestUpdate(t *testing.T) { } +// Test database functions +func TestFunction(t *testing.T) { + var err error + var res db.Result + + // Opening database. + sess, err := db.Open(wrapperName, settings) + + if err != nil { + t.Fatalf(err.Error()) + } + + // We should close the database when it's no longer in use. + defer sess.Close() + + // Getting a pointer to the "artist" collection. + artist, err := sess.Collection("artist") + + if err != nil { + t.Fatalf(err.Error()) + } + + row_s := struct { + Id uint64 + Name string + }{} + + res = artist.Find(db.Cond{"id NOT IN": []int{0, -1}}) + + if err = res.One(&row_s); err != nil { + t.Fatalf("One: %q", err) + } + + res = artist.Find(db.Cond{"id": db.Func{"NOT IN", []int{0, -1}}}) + + if err = res.One(&row_s); err != nil { + t.Fatalf("One: %q", err) + } + + res.Close() +} + // This test tries to remove some previously added rows. func TestRemove(t *testing.T) {