From 35ecf39581e2863321eed1330d12a9fe16a005d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net> Date: Thu, 28 May 2015 06:00:17 -0500 Subject: [PATCH] Moving more shared logic to sqlutil. --- mysql/collection.go | 65 +++++++++---------------------------- mysql/database.go | 6 ++-- mysql/database_test.go | 6 ++-- postgresql/collection.go | 22 +++---------- postgresql/database.go | 28 ++++------------ postgresql/database_test.go | 2 +- postgresql/template.go | 2 +- util/sqlutil/debug.go | 17 ++++++++++ util/sqlutil/main.go | 51 +++++++++++++++++++++++++++++ 9 files changed, 103 insertions(+), 96 deletions(-) diff --git a/mysql/collection.go b/mysql/collection.go index 95a11462..6b53facb 100644 --- a/mysql/collection.go +++ b/mysql/collection.go @@ -34,22 +34,10 @@ import ( type table struct { sqlutil.T *database - names []string } var _ = db.Collection(&table{}) -// tableN returns the nth name provided to the table. -func (t *table) tableN(i int) string { - if len(t.names) > i { - chunks := strings.SplitN(t.names[i], " ", 2) - if len(chunks) > 0 { - return chunks[0] - } - } - return "" -} - // Find creates a result set with the given conditions. func (t *table) Find(terms ...interface{}) db.Result { where, arguments := sqlutil.ToWhereWithArguments(terms) @@ -60,7 +48,7 @@ func (t *table) Find(terms ...interface{}) db.Result { func (t *table) Truncate() error { _, err := t.database.Exec(sqlgen.Statement{ Type: sqlgen.Truncate, - Table: sqlgen.TableWithName(t.tableN(0)), + Table: sqlgen.TableWithName(t.MainTableName()), }) if err != nil { @@ -73,41 +61,19 @@ func (t *table) Truncate() error { func (t *table) Append(item interface{}) (interface{}, error) { var pKey []string - cols, vals, err := t.FieldValues(item) + columnNames, columnValues, err := t.FieldValues(item) if err != nil { return nil, err } - columns := new(sqlgen.Columns) + sqlgenCols, sqlgenVals, sqlgenArgs, err := t.ColumnsValuesAndArguments(columnNames, columnValues) - columns.Columns = make([]sqlgen.Fragment, 0, len(cols)) - for i := range cols { - columns.Columns = append(columns.Columns, sqlgen.ColumnWithName(cols[i])) - } - - values := new(sqlgen.Values) - var arguments []interface{} - - arguments = make([]interface{}, 0, len(vals)) - values.Values = make([]sqlgen.Fragment, 0, len(vals)) - - for i := range vals { - switch v := vals[i].(type) { - case *sqlgen.Value: - // Adding value. - values.Values = append(values.Values, v) - case sqlgen.Value: - // Adding value. - values.Values = append(values.Values, &v) - default: - // Adding both value and placeholder. - values.Values = append(values.Values, sqlPlaceholder) - arguments = append(arguments, v) - } + if err != nil { + return nil, err } - if pKey, err = t.database.getPrimaryKey(t.tableN(0)); err != nil { + if pKey, err = t.database.getPrimaryKey(t.MainTableName()); err != nil { if err != sql.ErrNoRows { // Can't tell primary key. return nil, err @@ -116,13 +82,13 @@ func (t *table) Append(item interface{}) (interface{}, error) { stmt := sqlgen.Statement{ Type: sqlgen.Insert, - Table: sqlgen.TableWithName(t.tableN(0)), - Columns: columns, - Values: values, + Table: sqlgen.TableWithName(t.MainTableName()), + Columns: sqlgenCols, + Values: sqlgenVals, } var res sql.Result - if res, err = t.database.Exec(stmt, arguments...); err != nil { + if res, err = t.database.Exec(stmt, sqlgenArgs...); err != nil { return nil, err } @@ -149,10 +115,10 @@ func (t *table) Append(item interface{}) (interface{}, error) { // were given for constructing the composite key. keyMap := make(map[string]interface{}) - for i := range cols { + for i := range columnNames { for j := 0; j < len(pKey); j++ { - if pKey[j] == cols[i] { - keyMap[pKey[j]] = vals[i] + if pKey[j] == columnNames[i] { + keyMap[pKey[j]] = columnValues[i] } } } @@ -177,12 +143,13 @@ func (t *table) Append(item interface{}) (interface{}, error) { // Returns true if the collection exists. func (t *table) Exists() bool { - if err := t.database.tableExists(t.names...); err != nil { + if err := t.database.tableExists(t.Tables...); err != nil { return false } return true } +// Name returns the name of the table or tables that form the collection. func (t *table) Name() string { - return strings.Join(t.names, `, `) + return strings.Join(t.Tables, `, `) } diff --git a/mysql/database.go b/mysql/database.go index 11e5e5de..389cb156 100644 --- a/mysql/database.go +++ b/mysql/database.go @@ -163,10 +163,8 @@ func (d *database) Collection(names ...string) (db.Collection, error) { } } - col := &table{ - database: d, - names: names, - } + col := &table{database: d} + col.Tables = names for _, name := range names { chunks := strings.SplitN(name, ` `, 2) diff --git a/mysql/database_test.go b/mysql/database_test.go index e0e4d3ae..624dcc20 100644 --- a/mysql/database_test.go +++ b/mysql/database_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the @@ -1470,7 +1470,7 @@ func BenchmarkAppendRawSQL(b *testing.B) { defer sess.Close() - driver := sess.Driver().(*sql.DB) + driver := sess.Driver().(*sqlx.DB) if _, err = driver.Exec("TRUNCATE TABLE `artist`"); err != nil { b.Fatal(err) @@ -1524,7 +1524,7 @@ func BenchmarkAppendTxRawSQL(b *testing.B) { defer sess.Close() - driver := sess.Driver().(*sql.DB) + driver := sess.Driver().(*sqlx.DB) if tx, err = driver.Begin(); err != nil { b.Fatal(err) diff --git a/postgresql/collection.go b/postgresql/collection.go index abdaeb55..19ecda73 100644 --- a/postgresql/collection.go +++ b/postgresql/collection.go @@ -37,22 +37,10 @@ type table struct { sqlutil.T *database primaryKey string - names []string } var _ = db.Collection(&table{}) -// tableN returns the nth name provided to the table. -func (t *table) tableN(i int) string { - if len(t.names) > i { - chunks := strings.SplitN(t.names[i], " ", 2) - if len(chunks) > 0 { - return chunks[0] - } - } - return "" -} - // Find creates a result set with the given conditions. func (t *table) Find(terms ...interface{}) db.Result { where, arguments := sqlutil.ToWhereWithArguments(terms) @@ -63,7 +51,7 @@ func (t *table) Find(terms ...interface{}) db.Result { func (t *table) Truncate() error { _, err := t.database.Exec(sqlgen.Statement{ Type: sqlgen.Truncate, - Table: sqlgen.TableWithName(t.tableN(0)), + Table: sqlgen.TableWithName(t.MainTableName()), }) if err != nil { @@ -112,7 +100,7 @@ func (t *table) Append(item interface{}) (interface{}, error) { var pKey []string - if pKey, err = t.database.getPrimaryKey(t.tableN(0)); err != nil { + if pKey, err = t.database.getPrimaryKey(t.MainTableName()); err != nil { if err != sql.ErrNoRows { // Can't tell primary key. return nil, err @@ -121,7 +109,7 @@ func (t *table) Append(item interface{}) (interface{}, error) { stmt := sqlgen.Statement{ Type: sqlgen.Insert, - Table: sqlgen.TableWithName(t.tableN(0)), + Table: sqlgen.TableWithName(t.MainTableName()), Columns: columns, Values: values, } @@ -194,7 +182,7 @@ func (t *table) Append(item interface{}) (interface{}, error) { // Exists returns true if the collection exists. func (t *table) Exists() bool { - if err := t.database.tableExists(t.names...); err != nil { + if err := t.database.tableExists(t.Tables...); err != nil { return false } return true @@ -202,5 +190,5 @@ func (t *table) Exists() bool { // Name returns the name of the table or tables that form the collection. func (t *table) Name() string { - return strings.Join(t.names, `, `) + return strings.Join(t.Tables, `, `) } diff --git a/postgresql/database.go b/postgresql/database.go index 18a79d41..0ece2ade 100644 --- a/postgresql/database.go +++ b/postgresql/database.go @@ -24,13 +24,12 @@ package postgresql import ( "database/sql" "fmt" - "os" "strconv" "strings" "time" "github.com/jmoiron/sqlx" - _ "github.com/lib/pq" // Go PostgreSQL driver. + _ "github.com/lib/pq" // PostgreSQL driver. "upper.io/db" "upper.io/db/util/schema" "upper.io/db/util/sqlgen" @@ -64,26 +63,12 @@ type columnSchemaT struct { DataType string `db:"data_type"` } -func debugEnabled() bool { - if os.Getenv(db.EnvEnableDebug) != "" { - return true - } - return false -} - -func debugLog(query string, args []interface{}, err error, start int64, end int64) { - if debugEnabled() == true { - d := sqlutil.Debug{query, args, err, start, end} - d.Print() - } -} - // Driver returns the underlying *sqlx.DB instance. func (d *database) Driver() interface{} { return d.session } -// Open attempts to connect to the PostgreSQL server using already stored settings. +// Open attempts to connect to the database server using already stored settings. func (d *database) Open() error { var err error @@ -164,9 +149,10 @@ func (d *database) Collection(names ...string) (db.Collection, error) { col := &table{ database: d, - names: names, } + col.Tables = names + for _, name := range names { chunks := strings.SplitN(name, ` `, 2) @@ -311,7 +297,7 @@ func (d *database) Exec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, defer func() { end = time.Now().UnixNano() - debugLog(query, args, err, start, end) + sqlutil.Log(query, args, err, start, end) }() if d.session == nil { @@ -345,7 +331,7 @@ func (d *database) Query(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows defer func() { end = time.Now().UnixNano() - debugLog(query, args, err, start, end) + sqlutil.Log(query, args, err, start, end) }() if d.session == nil { @@ -379,7 +365,7 @@ func (d *database) QueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.R defer func() { end = time.Now().UnixNano() - debugLog(query, args, err, start, end) + sqlutil.Log(query, args, err, start, end) }() if d.session == nil { diff --git a/postgresql/database_test.go b/postgresql/database_test.go index 0e7e9350..090779c4 100644 --- a/postgresql/database_test.go +++ b/postgresql/database_test.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the diff --git a/postgresql/template.go b/postgresql/template.go index 7c8f13a4..6df9d1c4 100644 --- a/postgresql/template.go +++ b/postgresql/template.go @@ -1,4 +1,4 @@ -// Copyright (c) 2012-2014 José Carlos Nieto, https://menteslibres.net/xiam +// Copyright (c) 2012-2015 José Carlos Nieto, https://menteslibres.net/xiam // // Permission is hereby granted, free of charge, to any person obtaining // a copy of this software and associated documentation files (the diff --git a/util/sqlutil/debug.go b/util/sqlutil/debug.go index 08d8ebe9..f82a9857 100644 --- a/util/sqlutil/debug.go +++ b/util/sqlutil/debug.go @@ -24,7 +24,10 @@ package sqlutil import ( "fmt" "log" + "os" "strings" + + "upper.io/db" ) // Debug is used for printing SQL queries and arguments. @@ -59,3 +62,17 @@ func (d *Debug) Print() { log.Printf("\n\t%s\n\n", strings.Join(s, "\n\t")) } + +func IsDebugEnabled() bool { + if os.Getenv(db.EnvEnableDebug) != "" { + return true + } + return false +} + +func Log(query string, args []interface{}, err error, start int64, end int64) { + if IsDebugEnabled() { + d := Debug{query, args, err, start, end} + d.Print() + } +} diff --git a/util/sqlutil/main.go b/util/sqlutil/main.go index 77cc0637..11e5e266 100644 --- a/util/sqlutil/main.go +++ b/util/sqlutil/main.go @@ -33,6 +33,7 @@ import ( "upper.io/db" "upper.io/db/util" + "upper.io/db/util/sqlgen" ) var ( @@ -50,6 +51,7 @@ var ( // using FieldValues() type T struct { Columns []string + Tables []string } func (t *T) columnLike(s string) string { @@ -210,3 +212,52 @@ func NewMapper() *reflectx.Mapper { return reflectx.NewMapperTagFunc("db", mapFunc, tagFunc) } + +// MainTableName returns the name of the first table. +func (t *T) MainTableName() string { + return t.NthTableName(0) +} + +// NthTableName returns the table name at index i. +func (t *T) NthTableName(i int) string { + if len(t.Tables) > i { + chunks := strings.SplitN(t.Tables[i], " ", 2) + if len(chunks) > 0 { + return chunks[0] + } + } + return "" +} + +func (t *T) ColumnsValuesAndArguments(columnNames []string, columnValues []interface{}) (*sqlgen.Columns, *sqlgen.Values, []interface{}, error) { + var arguments []interface{} + + columns := new(sqlgen.Columns) + + columns.Columns = make([]sqlgen.Fragment, 0, len(columnNames)) + for i := range columnNames { + columns.Columns = append(columns.Columns, sqlgen.ColumnWithName(columnNames[i])) + } + + values := new(sqlgen.Values) + + arguments = make([]interface{}, 0, len(columnValues)) + values.Values = make([]sqlgen.Fragment, 0, len(columnValues)) + + for i := range columnValues { + switch v := columnValues[i].(type) { + case *sqlgen.Value: + // Adding value. + values.Values = append(values.Values, v) + case sqlgen.Value: + // Adding value. + values.Values = append(values.Values, &v) + default: + // Adding both value and placeholder. + values.Values = append(values.Values, sqlPlaceholder) + arguments = append(arguments, v) + } + } + + return columns, values, arguments, nil +} -- GitLab