From b66d7c9539c6f03fa6e83e435076b3aa34a3b9d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net> Date: Tue, 17 Jun 2014 16:50:59 -0500 Subject: [PATCH] Adding support for debugLog(). --- ql/database.go | 81 +++++++++++++++++++++++++++++--------------------- ql/result.go | 10 +++++-- 2 files changed, 55 insertions(+), 36 deletions(-) diff --git a/ql/database.go b/ql/database.go index 5f36733e..bf909d8f 100644 --- a/ql/database.go +++ b/ql/database.go @@ -69,6 +69,13 @@ func debugEnabled() bool { return false } +func debugLog(query string, args []interface{}, err error) { + if debugEnabled() == true { + d := sqlutil.Debug{query, args, err} + d.Print() + } +} + func init() { template = &sqlgen.Template{ @@ -106,25 +113,29 @@ func init() { func (self *Source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) { + var query string + var res sql.Result + var err error + + defer func() { + debugLog(query, args, err) + }() + if self.session == nil { return nil, db.ErrNotConnected } - query := stmt.Compile(template) + query = stmt.Compile(template) l := len(args) for i := 0; i < l; i++ { query = strings.Replace(query, `?`, fmt.Sprintf(`$%d`, i+1), 1) } - if debugEnabled() == true { - sqlutil.DebugQuery(query, args) - } - - if self.tx == nil { + if self.tx != nil { + res, err = self.tx.Exec(query, args...) + } else { var tx *sql.Tx - var err error - var res sql.Result if tx, err = self.session.Begin(); err != nil { return nil, err @@ -137,33 +148,35 @@ func (self *Source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Resu if err = tx.Commit(); err != nil { return nil, err } - - return res, nil } - return self.tx.Exec(query, args...) + return res, err } func (self *Source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sql.Rows, error) { + var rows *sql.Rows + var query string + var err error + + defer func() { + debugLog(query, args, err) + }() + if self.session == nil { return nil, db.ErrNotConnected } - query := stmt.Compile(template) + query = stmt.Compile(template) l := len(args) for i := 0; i < l; i++ { query = strings.Replace(query, `?`, fmt.Sprintf(`$%d`, i+1), 1) } - if debugEnabled() == true { - sqlutil.DebugQuery(query, args) - } - - if self.tx == nil { + if self.tx != nil { + rows, err = self.tx.Query(query, args...) + } else { var tx *sql.Tx - var err error - var rows *sql.Rows if tx, err = self.session.Begin(); err != nil { return nil, err @@ -176,33 +189,35 @@ func (self *Source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sql.Ro if err = tx.Commit(); err != nil { return nil, err } - - return rows, nil } - return self.tx.Query(query, args...) + return rows, err } func (self *Source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sql.Row, error) { + var query string + var row *sql.Row + var err error + + defer func() { + debugLog(query, args, err) + }() + if self.session == nil { return nil, db.ErrNotConnected } - query := stmt.Compile(template) + query = stmt.Compile(template) l := len(args) for i := 0; i < l; i++ { query = strings.Replace(query, `?`, fmt.Sprintf(`$%d`, i+1), 1) } - if debugEnabled() == true { - sqlutil.DebugQuery(query, args) - } - - if self.tx == nil { + if self.tx != nil { + row = self.tx.QueryRow(query, args...) + } else { var tx *sql.Tx - var err error - var row *sql.Row if tx, err = self.session.Begin(); err != nil { return nil, err @@ -215,11 +230,9 @@ func (self *Source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sql if err = tx.Commit(); err != nil { return nil, err } - - return row, nil - } else { - return self.tx.QueryRow(query, args...), nil } + + return row, err } // Returns the string name of the database. diff --git a/ql/result.go b/ql/result.go index 2c88d749..97fe1a24 100644 --- a/ql/result.go +++ b/ql/result.go @@ -23,6 +23,7 @@ package ql import ( "database/sql" + "fmt" "strings" "upper.io/db" "upper.io/db/util/sqlgen" @@ -113,12 +114,17 @@ func (self *Result) Sort(fields ...string) db.Result { } // Retrieves only the given fields. -func (self *Result) Select(fields ...string) db.Result { +func (self *Result) Select(fields ...interface{}) db.Result { self.columns = make(sqlgen.Columns, 0, len(fields)) l := len(fields) for i := 0; i < l; i++ { - self.columns = append(self.columns, sqlgen.Column{fields[i]}) + switch value := fields[i].(type) { + case db.Raw: + self.columns = append(self.columns, sqlgen.Column{sqlgen.Raw{fmt.Sprintf(`%v`, value.Value)}}) + default: + self.columns = append(self.columns, sqlgen.Column{value}) + } } return self -- GitLab