From b66d7c9539c6f03fa6e83e435076b3aa34a3b9d4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net>
Date: Tue, 17 Jun 2014 16:50:59 -0500
Subject: [PATCH] Adding support for debugLog().

---
 ql/database.go | 81 +++++++++++++++++++++++++++++---------------------
 ql/result.go   | 10 +++++--
 2 files changed, 55 insertions(+), 36 deletions(-)

diff --git a/ql/database.go b/ql/database.go
index 5f36733e..bf909d8f 100644
--- a/ql/database.go
+++ b/ql/database.go
@@ -69,6 +69,13 @@ func debugEnabled() bool {
 	return false
 }
 
+func debugLog(query string, args []interface{}, err error) {
+	if debugEnabled() == true {
+		d := sqlutil.Debug{query, args, err}
+		d.Print()
+	}
+}
+
 func init() {
 
 	template = &sqlgen.Template{
@@ -106,25 +113,29 @@ func init() {
 
 func (self *Source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) {
 
+	var query string
+	var res sql.Result
+	var err error
+
+	defer func() {
+		debugLog(query, args, err)
+	}()
+
 	if self.session == nil {
 		return nil, db.ErrNotConnected
 	}
 
-	query := stmt.Compile(template)
+	query = stmt.Compile(template)
 
 	l := len(args)
 	for i := 0; i < l; i++ {
 		query = strings.Replace(query, `?`, fmt.Sprintf(`$%d`, i+1), 1)
 	}
 
-	if debugEnabled() == true {
-		sqlutil.DebugQuery(query, args)
-	}
-
-	if self.tx == nil {
+	if self.tx != nil {
+		res, err = self.tx.Exec(query, args...)
+	} else {
 		var tx *sql.Tx
-		var err error
-		var res sql.Result
 
 		if tx, err = self.session.Begin(); err != nil {
 			return nil, err
@@ -137,33 +148,35 @@ func (self *Source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Resu
 		if err = tx.Commit(); err != nil {
 			return nil, err
 		}
-
-		return res, nil
 	}
 
-	return self.tx.Exec(query, args...)
+	return res, err
 }
 
 func (self *Source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sql.Rows, error) {
+	var rows *sql.Rows
+	var query string
+	var err error
+
+	defer func() {
+		debugLog(query, args, err)
+	}()
+
 	if self.session == nil {
 		return nil, db.ErrNotConnected
 	}
 
-	query := stmt.Compile(template)
+	query = stmt.Compile(template)
 
 	l := len(args)
 	for i := 0; i < l; i++ {
 		query = strings.Replace(query, `?`, fmt.Sprintf(`$%d`, i+1), 1)
 	}
 
-	if debugEnabled() == true {
-		sqlutil.DebugQuery(query, args)
-	}
-
-	if self.tx == nil {
+	if self.tx != nil {
+		rows, err = self.tx.Query(query, args...)
+	} else {
 		var tx *sql.Tx
-		var err error
-		var rows *sql.Rows
 
 		if tx, err = self.session.Begin(); err != nil {
 			return nil, err
@@ -176,33 +189,35 @@ func (self *Source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sql.Ro
 		if err = tx.Commit(); err != nil {
 			return nil, err
 		}
-
-		return rows, nil
 	}
 
-	return self.tx.Query(query, args...)
+	return rows, err
 }
 
 func (self *Source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sql.Row, error) {
+	var query string
+	var row *sql.Row
+	var err error
+
+	defer func() {
+		debugLog(query, args, err)
+	}()
+
 	if self.session == nil {
 		return nil, db.ErrNotConnected
 	}
 
-	query := stmt.Compile(template)
+	query = stmt.Compile(template)
 
 	l := len(args)
 	for i := 0; i < l; i++ {
 		query = strings.Replace(query, `?`, fmt.Sprintf(`$%d`, i+1), 1)
 	}
 
-	if debugEnabled() == true {
-		sqlutil.DebugQuery(query, args)
-	}
-
-	if self.tx == nil {
+	if self.tx != nil {
+		row = self.tx.QueryRow(query, args...)
+	} else {
 		var tx *sql.Tx
-		var err error
-		var row *sql.Row
 
 		if tx, err = self.session.Begin(); err != nil {
 			return nil, err
@@ -215,11 +230,9 @@ func (self *Source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sql
 		if err = tx.Commit(); err != nil {
 			return nil, err
 		}
-
-		return row, nil
-	} else {
-		return self.tx.QueryRow(query, args...), nil
 	}
+
+	return row, err
 }
 
 // Returns the string name of the database.
diff --git a/ql/result.go b/ql/result.go
index 2c88d749..97fe1a24 100644
--- a/ql/result.go
+++ b/ql/result.go
@@ -23,6 +23,7 @@ package ql
 
 import (
 	"database/sql"
+	"fmt"
 	"strings"
 	"upper.io/db"
 	"upper.io/db/util/sqlgen"
@@ -113,12 +114,17 @@ func (self *Result) Sort(fields ...string) db.Result {
 }
 
 // Retrieves only the given fields.
-func (self *Result) Select(fields ...string) db.Result {
+func (self *Result) Select(fields ...interface{}) db.Result {
 	self.columns = make(sqlgen.Columns, 0, len(fields))
 
 	l := len(fields)
 	for i := 0; i < l; i++ {
-		self.columns = append(self.columns, sqlgen.Column{fields[i]})
+		switch value := fields[i].(type) {
+		case db.Raw:
+			self.columns = append(self.columns, sqlgen.Column{sqlgen.Raw{fmt.Sprintf(`%v`, value.Value)}})
+		default:
+			self.columns = append(self.columns, sqlgen.Column{value})
+		}
 	}
 
 	return self
-- 
GitLab