diff --git a/mysql/collection.go b/mysql/collection.go index b5a56fde2a6b3860620a089c4bdb2def6a37486f..d1e66776fa4d03b2de4129154537389af7712f7f 100644 --- a/mysql/collection.go +++ b/mysql/collection.go @@ -67,9 +67,9 @@ func (self *Table) Find(terms ...interface{}) db.Result { } // Transforms conditions into arguments for sql.Exec/sql.Query -func (self *Table) compileConditions(term interface{}) (string, []string) { +func (self *Table) compileConditions(term interface{}) (string, []interface{}) { sql := []string{} - args := []string{} + args := []interface{}{} switch t := term.(type) { case []interface{}: @@ -112,10 +112,10 @@ func (self *Table) compileConditions(term interface{}) (string, []string) { return "", args } -func (self *Table) compileStatement(where db.Cond) (string, []string) { +func (self *Table) compileStatement(where db.Cond) (string, []interface{}) { str := make([]string, len(where)) - arg := make([]string, len(where)) + arg := make([]interface{}, len(where)) i := 0 @@ -139,7 +139,7 @@ func (self *Table) compileStatement(where db.Cond) (string, []string) { case 1: return str[0], arg case 0: - return "", []string{} + return "", []interface{}{} } return `(` + strings.Join(str, ` AND `) + `)`, arg @@ -211,7 +211,7 @@ func toInternalInterface(val interface{}) interface{} { } // Converts a Go value into internal database representation. -func toInternal(val interface{}) string { +func toInternal(val interface{}) interface{} { switch t := val.(type) { case []byte: diff --git a/mysql/database.go b/mysql/database.go index 8a09ff40441f0dc943d9a9f34a029aee132762ac..0ae9125529d7a6fbe86bed4f809574e86387dd68 100644 --- a/mysql/database.go +++ b/mysql/database.go @@ -49,7 +49,7 @@ func init() { db.Register(driverName, &Source{}) } -type sqlValues_t []string +type sqlValues_t []interface{} type Source struct { session *sql.DB @@ -69,12 +69,6 @@ func sqlCompile(terms []interface{}) *sqlQuery { for _, term := range terms { switch t := term.(type) { - case string: - q.Query = append(q.Query, t) - case []string: - for _, arg := range t { - q.Args = append(q.Args, arg) - } case sqlValues_t: args := make([]string, len(t)) for i, arg := range t { @@ -82,6 +76,17 @@ func sqlCompile(terms []interface{}) *sqlQuery { q.Args = append(q.Args, arg) } q.Query = append(q.Query, `(`+strings.Join(args, `, `)+`)`) + case string: + q.Query = append(q.Query, t) + default: + if reflect.TypeOf(t).Kind() == reflect.Slice { + var v = reflect.ValueOf(t) + for i := 0; i < v.Len(); i++ { + q.Args = append(q.Args, v.Index(i).Interface()) + } + } else { + q.Args = append(q.Args, t) + } } } @@ -95,7 +100,7 @@ func sqlFields(names []string) string { return "(`" + strings.Join(names, "`, `") + "`)" } -func sqlValues(values []string) sqlValues_t { +func sqlValues(values []interface{}) sqlValues_t { ret := make(sqlValues_t, len(values)) for i, _ := range values { ret[i] = values[i] diff --git a/mysql/result.go b/mysql/result.go index ffe27f0e1bfc1c833e286601475db34d0e629e24..518b7923a130e7067f6aa39a30d4bb2270cce0a3 100644 --- a/mysql/result.go +++ b/mysql/result.go @@ -195,7 +195,7 @@ func (self *Result) Update(values interface{}) error { total := len(ff) updateFields := make([]string, total) - updateArgs := make([]string, total) + updateArgs := make([]interface{}, total) for i := 0; i < total; i++ { updateFields[i] = fmt.Sprintf(`%s = ?`, ff[i])