diff --git a/mysql/database.go b/mysql/database.go index 0dd71f85a1b12b4e04d2582784704456c77487f4..8c2f480e3959d97f951a37f7866a0d9e16ad6b6d 100644 --- a/mysql/database.go +++ b/mysql/database.go @@ -69,6 +69,13 @@ func debugEnabled() bool { return false } +func debugLog(query string, args []interface{}, err error) { + if debugEnabled() == true { + d := sqlutil.Debug{query, args, err} + d.Print() + } +} + func init() { template = &sqlgen.Template{ @@ -105,58 +112,75 @@ func init() { } func (self *Source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) { + var query string + var res sql.Result + var err error + + defer func() { + debugLog(query, args, err) + }() if self.session == nil { return nil, db.ErrNotConnected } - query := stmt.Compile(template) - - if debugEnabled() == true { - sqlutil.DebugQuery(query, args) - } + query = stmt.Compile(template) if self.tx != nil { - return self.tx.Exec(query, args...) + res, err = self.tx.Exec(query, args...) + } else { + res, err = self.session.Exec(query, args...) } - return self.session.Exec(query, args...) + return res, err } func (self *Source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sql.Rows, error) { + var rows *sql.Rows + var query string + var err error + + defer func() { + debugLog(query, args, err) + }() + if self.session == nil { return nil, db.ErrNotConnected } - query := stmt.Compile(template) - - if debugEnabled() == true { - sqlutil.DebugQuery(query, args) - } + query = stmt.Compile(template) if self.tx != nil { - return self.tx.Query(query, args...) + rows, err = self.tx.Query(query, args...) + } else { + rows, err = self.session.Query(query, args...) } - return self.session.Query(query, args...) + return rows, err } func (self *Source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sql.Row, error) { + var query string + var row *sql.Row + var err error + + defer func() { + debugLog(query, args, err) + }() + if self.session == nil { return nil, db.ErrNotConnected } - query := stmt.Compile(template) - - if debugEnabled() == true { - sqlutil.DebugQuery(query, args) - } + query = stmt.Compile(template) if self.tx != nil { - return self.tx.QueryRow(query, args...), nil + row = self.tx.QueryRow(query, args...) + } else { + row = self.session.QueryRow(query, args...) } - return self.session.QueryRow(query, args...), nil + return row, err } // Returns the string name of the database. diff --git a/mysql/result.go b/mysql/result.go index 614bfabfd2b3d11d2c5124d90ff77864b86aa97c..ca810196ebf524805581f98da41ea88b13e409de 100644 --- a/mysql/result.go +++ b/mysql/result.go @@ -23,6 +23,7 @@ package mysql import ( "database/sql" + "fmt" "strings" "upper.io/db" "upper.io/db/util/sqlgen" @@ -112,12 +113,17 @@ func (self *Result) Sort(fields ...string) db.Result { } // Retrieves only the given fields. -func (self *Result) Select(fields ...string) db.Result { +func (self *Result) Select(fields ...interface{}) db.Result { self.columns = make(sqlgen.Columns, 0, len(fields)) l := len(fields) for i := 0; i < l; i++ { - self.columns = append(self.columns, sqlgen.Column{fields[i]}) + switch value := fields[i].(type) { + case db.Raw: + self.columns = append(self.columns, sqlgen.Column{sqlgen.Raw{fmt.Sprintf(`%v`, value.Value)}}) + default: + self.columns = append(self.columns, sqlgen.Column{value}) + } } return self