From e6ed3f7215e04ca9e26c0a6d9ec06608dda3a9e3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net> Date: Sat, 30 May 2015 10:43:10 -0500 Subject: [PATCH] Adding template utility methods. --- mysql/collection.go | 6 +++--- mysql/database.go | 10 +++++----- mysql/mysql.go | 8 +++++--- mysql/template.go | 1 + postgresql/collection.go | 6 +++--- postgresql/database.go | 10 +++++----- postgresql/postgresql.go | 8 +++++--- postgresql/template.go | 1 + sqlite/collection.go | 6 +++--- sqlite/database.go | 10 +++++----- sqlite/sqlite.go | 8 +++++--- sqlite/template.go | 1 + util/sqlgen/template.go | 1 + util/sqlutil/convert.go | 37 +++++++++++++++++++++-------------- util/sqlutil/result/result.go | 10 ++++++---- 15 files changed, 71 insertions(+), 52 deletions(-) diff --git a/mysql/collection.go b/mysql/collection.go index ad4e2ae6..d510642b 100644 --- a/mysql/collection.go +++ b/mysql/collection.go @@ -40,8 +40,8 @@ var _ = db.Collection(&table{}) // Find creates a result set with the given conditions. func (t *table) Find(terms ...interface{}) db.Result { - where, arguments := sqlutil.ToWhereWithArguments(terms) - return result.NewResult(t, where, arguments) + where, arguments := template.ToWhereWithArguments(terms) + return result.NewResult(template, t, where, arguments) } // Truncate deletes all rows from the table. @@ -67,7 +67,7 @@ func (t *table) Append(item interface{}) (interface{}, error) { return nil, err } - sqlgenCols, sqlgenVals, sqlgenArgs, err := sqlutil.ToColumnsValuesAndArguments(columnNames, columnValues) + sqlgenCols, sqlgenVals, sqlgenArgs, err := template.ToColumnsValuesAndArguments(columnNames, columnValues) if err != nil { return nil, err diff --git a/mysql/database.go b/mysql/database.go index 389cb156..1714917c 100644 --- a/mysql/database.go +++ b/mysql/database.go @@ -285,11 +285,11 @@ func (d *database) Transaction() (db.Tx, error) { var clone *database var sqlTx *sqlx.Tx - if sqlTx, err = d.session.Beginx(); err != nil { + if clone, err = d.clone(); err != nil { return nil, err } - if clone, err = d.clone(); err != nil { + if sqlTx, err = clone.session.Beginx(); err != nil { return nil, err } @@ -316,7 +316,7 @@ func (d *database) Exec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) if d.tx != nil { res, err = d.tx.Exec(query, args...) @@ -345,7 +345,7 @@ func (d *database) Query(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) if d.tx != nil { rows, err = d.tx.Queryx(query, args...) @@ -374,7 +374,7 @@ func (d *database) QueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.R return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) if d.tx != nil { row = d.tx.QueryRowx(query, args...) diff --git a/mysql/mysql.go b/mysql/mysql.go index c0141dad..6cdf2d36 100644 --- a/mysql/mysql.go +++ b/mysql/mysql.go @@ -25,15 +25,16 @@ import ( "upper.io/cache" "upper.io/db" "upper.io/db/util/sqlgen" + "upper.io/db/util/sqlutil" ) // Adapter is the public name of the adapter. const Adapter = `mysql` -var template *sqlgen.Template +var template *sqlutil.TemplateWithUtils func init() { - template = &sqlgen.Template{ + template = sqlutil.NewTemplateWithUtils(&sqlgen.Template{ ColumnSeparator: adapterColumnSeparator, IdentifierSeparator: adapterIdentifierSeparator, IdentifierQuote: adapterIdentifierQuote, @@ -45,6 +46,7 @@ func init() { DescKeyword: adapterDescKeyword, AscKeyword: adapterAscKeyword, DefaultOperator: adapterDefaultOperator, + AssignmentOperator: adapterAssignmentOperator, ClauseGroup: adapterClauseGroup, ClauseOperator: adapterClauseOperator, ColumnValue: adapterColumnValue, @@ -63,7 +65,7 @@ func init() { CountLayout: adapterSelectCountLayout, GroupByLayout: adapterGroupByLayout, Cache: cache.NewCache(), - } + }) db.Register(Adapter, &database{}) } diff --git a/mysql/template.go b/mysql/template.go index 97af9cd2..1b2d21da 100644 --- a/mysql/template.go +++ b/mysql/template.go @@ -33,6 +33,7 @@ const ( adapterDescKeyword = `DESC` adapterAscKeyword = `ASC` adapterDefaultOperator = `=` + adapterAssignmentOperator = `=` adapterClauseGroup = `({{.}})` adapterClauseOperator = ` {{.}} ` adapterColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` diff --git a/postgresql/collection.go b/postgresql/collection.go index 1f5d00e4..0158c17d 100644 --- a/postgresql/collection.go +++ b/postgresql/collection.go @@ -43,8 +43,8 @@ var _ = db.Collection(&table{}) // Find creates a result set with the given conditions. func (t *table) Find(terms ...interface{}) db.Result { - where, arguments := sqlutil.ToWhereWithArguments(terms) - return result.NewResult(t, where, arguments) + where, arguments := template.ToWhereWithArguments(terms) + return result.NewResult(template, t, where, arguments) } // Truncate deletes all rows from the table. @@ -70,7 +70,7 @@ func (t *table) Append(item interface{}) (interface{}, error) { return nil, err } - sqlgenCols, sqlgenVals, sqlgenArgs, err := sqlutil.ToColumnsValuesAndArguments(columnNames, columnValues) + sqlgenCols, sqlgenVals, sqlgenArgs, err := template.ToColumnsValuesAndArguments(columnNames, columnValues) if err != nil { return nil, err diff --git a/postgresql/database.go b/postgresql/database.go index 0ece2ade..f66ad5fc 100644 --- a/postgresql/database.go +++ b/postgresql/database.go @@ -273,11 +273,11 @@ func (d *database) Transaction() (db.Tx, error) { var clone *database var sqlTx *sqlx.Tx - if sqlTx, err = d.session.Beginx(); err != nil { + if clone, err = d.clone(); err != nil { return nil, err } - if clone, err = d.clone(); err != nil { + if sqlTx, err = clone.session.Beginx(); err != nil { return nil, err } @@ -304,7 +304,7 @@ func (d *database) Exec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) l := len(args) for i := 0; i < l; i++ { @@ -338,7 +338,7 @@ func (d *database) Query(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) l := len(args) for i := 0; i < l; i++ { @@ -372,7 +372,7 @@ func (d *database) QueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.R return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) l := len(args) for i := 0; i < l; i++ { diff --git a/postgresql/postgresql.go b/postgresql/postgresql.go index f7fac8b1..7e8363cc 100644 --- a/postgresql/postgresql.go +++ b/postgresql/postgresql.go @@ -25,15 +25,16 @@ import ( "upper.io/cache" "upper.io/db" "upper.io/db/util/sqlgen" + "upper.io/db/util/sqlutil" ) // Adapter is the public name of the adapter. const Adapter = `postgresql` -var template *sqlgen.Template +var template *sqlutil.TemplateWithUtils func init() { - template = &sqlgen.Template{ + template = sqlutil.NewTemplateWithUtils(&sqlgen.Template{ ColumnSeparator: adapterColumnSeparator, IdentifierSeparator: adapterIdentifierSeparator, IdentifierQuote: adapterIdentifierQuote, @@ -45,6 +46,7 @@ func init() { DescKeyword: adapterDescKeyword, AscKeyword: adapterAscKeyword, DefaultOperator: adapterDefaultOperator, + AssignmentOperator: adapterAssignmentOperator, ClauseGroup: adapterClauseGroup, ClauseOperator: adapterClauseOperator, ColumnValue: adapterColumnValue, @@ -63,7 +65,7 @@ func init() { CountLayout: adapterSelectCountLayout, GroupByLayout: adapterGroupByLayout, Cache: cache.NewCache(), - } + }) db.Register(Adapter, &database{}) } diff --git a/postgresql/template.go b/postgresql/template.go index 6df9d1c4..aa414ecb 100644 --- a/postgresql/template.go +++ b/postgresql/template.go @@ -33,6 +33,7 @@ const ( adapterDescKeyword = `DESC` adapterAscKeyword = `ASC` adapterDefaultOperator = `=` + adapterAssignmentOperator = `=` adapterClauseGroup = `({{.}})` adapterClauseOperator = ` {{.}} ` adapterColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` diff --git a/sqlite/collection.go b/sqlite/collection.go index 199a56af..0aadfcd0 100644 --- a/sqlite/collection.go +++ b/sqlite/collection.go @@ -41,8 +41,8 @@ var _ = db.Collection(&table{}) // Find creates a result set with the given conditions. func (t *table) Find(terms ...interface{}) db.Result { - where, arguments := sqlutil.ToWhereWithArguments(terms) - return result.NewResult(t, where, arguments) + where, arguments := template.ToWhereWithArguments(terms) + return result.NewResult(template, t, where, arguments) } // Truncate deletes all rows from the table. @@ -68,7 +68,7 @@ func (t *table) Append(item interface{}) (interface{}, error) { return nil, err } - sqlgenCols, sqlgenVals, sqlgenArgs, err := sqlutil.ToColumnsValuesAndArguments(columnNames, columnValues) + sqlgenCols, sqlgenVals, sqlgenArgs, err := template.ToColumnsValuesAndArguments(columnNames, columnValues) if err != nil { return nil, err diff --git a/sqlite/database.go b/sqlite/database.go index d1d575ff..472e34f8 100644 --- a/sqlite/database.go +++ b/sqlite/database.go @@ -273,11 +273,11 @@ func (d *database) Transaction() (db.Tx, error) { var clone *database var sqlTx *sqlx.Tx - if sqlTx, err = d.session.Beginx(); err != nil { + if clone, err = d.clone(); err != nil { return nil, err } - if clone, err = d.clone(); err != nil { + if sqlTx, err = clone.session.Beginx(); err != nil { return nil, err } @@ -304,7 +304,7 @@ func (d *database) Exec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) if d.tx != nil { res, err = d.tx.Exec(query, args...) @@ -333,7 +333,7 @@ func (d *database) Query(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) if d.tx != nil { rows, err = d.tx.Queryx(query, args...) @@ -362,7 +362,7 @@ func (d *database) QueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.R return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) if d.tx != nil { row = d.tx.QueryRowx(query, args...) diff --git a/sqlite/sqlite.go b/sqlite/sqlite.go index eecef2d3..ea66cafd 100644 --- a/sqlite/sqlite.go +++ b/sqlite/sqlite.go @@ -25,15 +25,16 @@ import ( "upper.io/cache" "upper.io/db" "upper.io/db/util/sqlgen" + "upper.io/db/util/sqlutil" ) // Adapter is the public name of the adapter. const Adapter = `sqlite` -var template *sqlgen.Template +var template *sqlutil.TemplateWithUtils func init() { - template = &sqlgen.Template{ + template = sqlutil.NewTemplateWithUtils(&sqlgen.Template{ ColumnSeparator: adapterColumnSeparator, IdentifierSeparator: adapterIdentifierSeparator, IdentifierQuote: adapterIdentifierQuote, @@ -45,6 +46,7 @@ func init() { DescKeyword: adapterDescKeyword, AscKeyword: adapterAscKeyword, DefaultOperator: adapterDefaultOperator, + AssignmentOperator: adapterAssignmentOperator, ClauseGroup: adapterClauseGroup, ClauseOperator: adapterClauseOperator, ColumnValue: adapterColumnValue, @@ -63,7 +65,7 @@ func init() { CountLayout: adapterSelectCountLayout, GroupByLayout: adapterGroupByLayout, Cache: cache.NewCache(), - } + }) db.Register(Adapter, &database{}) } diff --git a/sqlite/template.go b/sqlite/template.go index 48c1ce6d..abea1661 100644 --- a/sqlite/template.go +++ b/sqlite/template.go @@ -33,6 +33,7 @@ const ( adapterDescKeyword = `DESC` adapterAscKeyword = `ASC` adapterDefaultOperator = `=` + adapterAssignmentOperator = `=` adapterClauseGroup = `({{.}})` adapterClauseOperator = ` {{.}} ` adapterColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` diff --git a/util/sqlgen/template.go b/util/sqlgen/template.go index 78663f54..bea487f3 100644 --- a/util/sqlgen/template.go +++ b/util/sqlgen/template.go @@ -48,6 +48,7 @@ type Template struct { DescKeyword string AscKeyword string DefaultOperator string + AssignmentOperator string ClauseGroup string ClauseOperator string ColumnValue string diff --git a/util/sqlutil/convert.go b/util/sqlutil/convert.go index d88bfaca..d7bc224c 100644 --- a/util/sqlutil/convert.go +++ b/util/sqlutil/convert.go @@ -9,20 +9,27 @@ import ( ) var ( - sqlPlaceholder = sqlgen.RawValue(`?`) - sqlNull = sqlgen.RawValue(`NULL`) - sqlDefaultOperator = "=" + sqlPlaceholder = sqlgen.RawValue(`?`) + sqlNull = sqlgen.RawValue(`NULL`) ) +type TemplateWithUtils struct { + *sqlgen.Template +} + +func NewTemplateWithUtils(template *sqlgen.Template) *TemplateWithUtils { + return &TemplateWithUtils{template} +} + // ToWhereWithArguments converts the given db.Cond parameters into a sqlgen.Where // value. -func ToWhereWithArguments(term interface{}) (where sqlgen.Where, args []interface{}) { +func (tu *TemplateWithUtils) ToWhereWithArguments(term interface{}) (where sqlgen.Where, args []interface{}) { args = []interface{}{} switch t := term.(type) { case []interface{}: for i := range t { - w, v := ToWhereWithArguments(t[i]) + w, v := tu.ToWhereWithArguments(t[i]) args = append(args, v...) where.Conditions = append(where.Conditions, w.Conditions...) } @@ -30,7 +37,7 @@ func ToWhereWithArguments(term interface{}) (where sqlgen.Where, args []interfac case db.And: var op sqlgen.And for i := range t { - k, v := ToWhereWithArguments(t[i]) + k, v := tu.ToWhereWithArguments(t[i]) args = append(args, v...) op.Conditions = append(op.Conditions, k.Conditions...) } @@ -39,7 +46,7 @@ func ToWhereWithArguments(term interface{}) (where sqlgen.Where, args []interfac case db.Or: var op sqlgen.Or for i := range t { - w, v := ToWhereWithArguments(t[i]) + w, v := tu.ToWhereWithArguments(t[i]) args = append(args, v...) op.Conditions = append(op.Conditions, w.Conditions...) } @@ -51,14 +58,14 @@ func ToWhereWithArguments(term interface{}) (where sqlgen.Where, args []interfac } return case db.Cond: - cv, v := ToColumnValues(t) + cv, v := tu.ToColumnValues(t) args = append(args, v...) for i := range cv.ColumnValues { where.Conditions = append(where.Conditions, cv.ColumnValues[i]) } return case db.Constrainer: - cv, v := ToColumnValues(t.Constraint()) + cv, v := tu.ToColumnValues(t.Constraint()) args = append(args, v...) for i := range cv.ColumnValues { where.Conditions = append(where.Conditions, cv.ColumnValues[i]) @@ -70,7 +77,7 @@ func ToWhereWithArguments(term interface{}) (where sqlgen.Where, args []interfac } // ToInterfaceArguments converts the given value into an array of interfaces. -func ToInterfaceArguments(value interface{}) (args []interface{}) { +func (tu *TemplateWithUtils) ToInterfaceArguments(value interface{}) (args []interface{}) { if value == nil { return nil } @@ -100,7 +107,7 @@ func ToInterfaceArguments(value interface{}) (args []interface{}) { } // ToColumnValues converts the given db.Cond into a sqlgen.ColumnValues struct. -func ToColumnValues(cond db.Cond) (ToColumnValues sqlgen.ColumnValues, args []interface{}) { +func (tu *TemplateWithUtils) ToColumnValues(cond db.Cond) (ToColumnValues sqlgen.ColumnValues, args []interface{}) { args = []interface{}{} @@ -116,12 +123,12 @@ func ToColumnValues(cond db.Cond) (ToColumnValues sqlgen.ColumnValues, args []in if len(chunks) > 1 { columnValue.Operator = chunks[1] } else { - columnValue.Operator = sqlDefaultOperator + columnValue.Operator = tu.DefaultOperator } switch value := value.(type) { case db.Func: - v := ToInterfaceArguments(value.Args) + v := tu.ToInterfaceArguments(value.Args) columnValue.Operator = value.Name if v == nil { @@ -134,7 +141,7 @@ func ToColumnValues(cond db.Cond) (ToColumnValues sqlgen.ColumnValues, args []in args = append(args, v...) default: - v := ToInterfaceArguments(value) + v := tu.ToInterfaceArguments(value) l := len(v) if v == nil || l == 0 { @@ -160,7 +167,7 @@ func ToColumnValues(cond db.Cond) (ToColumnValues sqlgen.ColumnValues, args []in // ToColumnsValuesAndArguments maps the given columnNames and columnValues into // sqlgen's Columns and Values, it also extracts and returns query arguments. -func ToColumnsValuesAndArguments(columnNames []string, columnValues []interface{}) (*sqlgen.Columns, *sqlgen.Values, []interface{}, error) { +func (tu *TemplateWithUtils) ToColumnsValuesAndArguments(columnNames []string, columnValues []interface{}) (*sqlgen.Columns, *sqlgen.Values, []interface{}, error) { var arguments []interface{} columns := new(sqlgen.Columns) diff --git a/util/sqlutil/result/result.go b/util/sqlutil/result/result.go index 9616c9bb..7badb0e7 100644 --- a/util/sqlutil/result/result.go +++ b/util/sqlutil/result/result.go @@ -49,15 +49,17 @@ type Result struct { orderBy sqlgen.OrderBy groupBy sqlgen.GroupBy arguments []interface{} + template *sqlutil.TemplateWithUtils } // NewResult creates and results a new result set on the given table, this set // is limited by the given sqlgen.Where conditions. -func NewResult(p DataProvider, where sqlgen.Where, arguments []interface{}) *Result { +func NewResult(template *sqlutil.TemplateWithUtils, p DataProvider, where sqlgen.Where, arguments []interface{}) *Result { return &Result{ table: p, where: where, arguments: arguments, + template: template, } } @@ -82,7 +84,7 @@ func (r *Result) setCursor() error { // Sets conditions for reducing the working set. func (r *Result) Where(terms ...interface{}) db.Result { - r.where, r.arguments = sqlutil.ToWhereWithArguments(terms) + r.where, r.arguments = r.template.ToWhereWithArguments(terms) return r } @@ -166,7 +168,7 @@ func (r *Result) Select(fields ...interface{}) db.Result { var col sqlgen.Fragment switch value := fields[i].(type) { case db.Func: - v := sqlutil.ToInterfaceArguments(value.Args) + v := r.template.ToInterfaceArguments(value.Args) var s string if len(v) == 0 { s = fmt.Sprintf(`%s()`, value.Name) @@ -269,7 +271,7 @@ func (r *Result) Update(values interface{}) error { cvs := new(sqlgen.ColumnValues) for i := range ff { - cvs.ColumnValues = append(cvs.ColumnValues, &sqlgen.ColumnValue{Column: sqlgen.ColumnWithName(ff[i]), Operator: "=", Value: sqlPlaceholder}) + cvs.ColumnValues = append(cvs.ColumnValues, &sqlgen.ColumnValue{Column: sqlgen.ColumnWithName(ff[i]), Operator: r.template.AssignmentOperator, Value: sqlPlaceholder}) } vv = append(vv, r.arguments...) -- GitLab