diff --git a/mysql/collection.go b/mysql/collection.go index ad4e2ae6d4f50ae5535db7ec4e80c54975053a29..d510642b72976d0a02435c5469ec4e54e08e6caf 100644 --- a/mysql/collection.go +++ b/mysql/collection.go @@ -40,8 +40,8 @@ var _ = db.Collection(&table{}) // Find creates a result set with the given conditions. func (t *table) Find(terms ...interface{}) db.Result { - where, arguments := sqlutil.ToWhereWithArguments(terms) - return result.NewResult(t, where, arguments) + where, arguments := template.ToWhereWithArguments(terms) + return result.NewResult(template, t, where, arguments) } // Truncate deletes all rows from the table. @@ -67,7 +67,7 @@ func (t *table) Append(item interface{}) (interface{}, error) { return nil, err } - sqlgenCols, sqlgenVals, sqlgenArgs, err := sqlutil.ToColumnsValuesAndArguments(columnNames, columnValues) + sqlgenCols, sqlgenVals, sqlgenArgs, err := template.ToColumnsValuesAndArguments(columnNames, columnValues) if err != nil { return nil, err diff --git a/mysql/database.go b/mysql/database.go index 389cb156146176f562ccfa47f7150f77c7bc664e..1714917c752cd0507fa3c784eb27b0d70bcd87e4 100644 --- a/mysql/database.go +++ b/mysql/database.go @@ -285,11 +285,11 @@ func (d *database) Transaction() (db.Tx, error) { var clone *database var sqlTx *sqlx.Tx - if sqlTx, err = d.session.Beginx(); err != nil { + if clone, err = d.clone(); err != nil { return nil, err } - if clone, err = d.clone(); err != nil { + if sqlTx, err = clone.session.Beginx(); err != nil { return nil, err } @@ -316,7 +316,7 @@ func (d *database) Exec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) if d.tx != nil { res, err = d.tx.Exec(query, args...) @@ -345,7 +345,7 @@ func (d *database) Query(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) if d.tx != nil { rows, err = d.tx.Queryx(query, args...) @@ -374,7 +374,7 @@ func (d *database) QueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.R return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) if d.tx != nil { row = d.tx.QueryRowx(query, args...) diff --git a/mysql/mysql.go b/mysql/mysql.go index c0141dadff36125a4483c0d2b51d340f062dd6d5..6cdf2d36b7d8aa56b163581a11b4c10f05d41ec8 100644 --- a/mysql/mysql.go +++ b/mysql/mysql.go @@ -25,15 +25,16 @@ import ( "upper.io/cache" "upper.io/db" "upper.io/db/util/sqlgen" + "upper.io/db/util/sqlutil" ) // Adapter is the public name of the adapter. const Adapter = `mysql` -var template *sqlgen.Template +var template *sqlutil.TemplateWithUtils func init() { - template = &sqlgen.Template{ + template = sqlutil.NewTemplateWithUtils(&sqlgen.Template{ ColumnSeparator: adapterColumnSeparator, IdentifierSeparator: adapterIdentifierSeparator, IdentifierQuote: adapterIdentifierQuote, @@ -45,6 +46,7 @@ func init() { DescKeyword: adapterDescKeyword, AscKeyword: adapterAscKeyword, DefaultOperator: adapterDefaultOperator, + AssignmentOperator: adapterAssignmentOperator, ClauseGroup: adapterClauseGroup, ClauseOperator: adapterClauseOperator, ColumnValue: adapterColumnValue, @@ -63,7 +65,7 @@ func init() { CountLayout: adapterSelectCountLayout, GroupByLayout: adapterGroupByLayout, Cache: cache.NewCache(), - } + }) db.Register(Adapter, &database{}) } diff --git a/mysql/template.go b/mysql/template.go index 97af9cd2ac398bfe64fa08a6d071aa6fa9e90cba..1b2d21dab87f67e0f0cc87ed872d9ba9f37b1112 100644 --- a/mysql/template.go +++ b/mysql/template.go @@ -33,6 +33,7 @@ const ( adapterDescKeyword = `DESC` adapterAscKeyword = `ASC` adapterDefaultOperator = `=` + adapterAssignmentOperator = `=` adapterClauseGroup = `({{.}})` adapterClauseOperator = ` {{.}} ` adapterColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` diff --git a/postgresql/collection.go b/postgresql/collection.go index 1f5d00e487d8c7efbb41202df9c67cf938dfacb8..0158c17dd6f8b646cc5a94e685a46dfbb4e725b0 100644 --- a/postgresql/collection.go +++ b/postgresql/collection.go @@ -43,8 +43,8 @@ var _ = db.Collection(&table{}) // Find creates a result set with the given conditions. func (t *table) Find(terms ...interface{}) db.Result { - where, arguments := sqlutil.ToWhereWithArguments(terms) - return result.NewResult(t, where, arguments) + where, arguments := template.ToWhereWithArguments(terms) + return result.NewResult(template, t, where, arguments) } // Truncate deletes all rows from the table. @@ -70,7 +70,7 @@ func (t *table) Append(item interface{}) (interface{}, error) { return nil, err } - sqlgenCols, sqlgenVals, sqlgenArgs, err := sqlutil.ToColumnsValuesAndArguments(columnNames, columnValues) + sqlgenCols, sqlgenVals, sqlgenArgs, err := template.ToColumnsValuesAndArguments(columnNames, columnValues) if err != nil { return nil, err diff --git a/postgresql/database.go b/postgresql/database.go index 0ece2adeec987213ebd43b6ac224911cb05412a7..f66ad5fc6a4baab675572c254c5eb0367d13af34 100644 --- a/postgresql/database.go +++ b/postgresql/database.go @@ -273,11 +273,11 @@ func (d *database) Transaction() (db.Tx, error) { var clone *database var sqlTx *sqlx.Tx - if sqlTx, err = d.session.Beginx(); err != nil { + if clone, err = d.clone(); err != nil { return nil, err } - if clone, err = d.clone(); err != nil { + if sqlTx, err = clone.session.Beginx(); err != nil { return nil, err } @@ -304,7 +304,7 @@ func (d *database) Exec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) l := len(args) for i := 0; i < l; i++ { @@ -338,7 +338,7 @@ func (d *database) Query(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) l := len(args) for i := 0; i < l; i++ { @@ -372,7 +372,7 @@ func (d *database) QueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.R return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) l := len(args) for i := 0; i < l; i++ { diff --git a/postgresql/postgresql.go b/postgresql/postgresql.go index f7fac8b1015eeb45d745640df141a62bf3c18a8b..7e8363cc3b1de831ce83948f0ed1baf59fadc51b 100644 --- a/postgresql/postgresql.go +++ b/postgresql/postgresql.go @@ -25,15 +25,16 @@ import ( "upper.io/cache" "upper.io/db" "upper.io/db/util/sqlgen" + "upper.io/db/util/sqlutil" ) // Adapter is the public name of the adapter. const Adapter = `postgresql` -var template *sqlgen.Template +var template *sqlutil.TemplateWithUtils func init() { - template = &sqlgen.Template{ + template = sqlutil.NewTemplateWithUtils(&sqlgen.Template{ ColumnSeparator: adapterColumnSeparator, IdentifierSeparator: adapterIdentifierSeparator, IdentifierQuote: adapterIdentifierQuote, @@ -45,6 +46,7 @@ func init() { DescKeyword: adapterDescKeyword, AscKeyword: adapterAscKeyword, DefaultOperator: adapterDefaultOperator, + AssignmentOperator: adapterAssignmentOperator, ClauseGroup: adapterClauseGroup, ClauseOperator: adapterClauseOperator, ColumnValue: adapterColumnValue, @@ -63,7 +65,7 @@ func init() { CountLayout: adapterSelectCountLayout, GroupByLayout: adapterGroupByLayout, Cache: cache.NewCache(), - } + }) db.Register(Adapter, &database{}) } diff --git a/postgresql/template.go b/postgresql/template.go index 6df9d1c46f8c361f4ee11f81082a72142997fea7..aa414ecbb85e1065fa52a2145eb403ac69e6c690 100644 --- a/postgresql/template.go +++ b/postgresql/template.go @@ -33,6 +33,7 @@ const ( adapterDescKeyword = `DESC` adapterAscKeyword = `ASC` adapterDefaultOperator = `=` + adapterAssignmentOperator = `=` adapterClauseGroup = `({{.}})` adapterClauseOperator = ` {{.}} ` adapterColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` diff --git a/sqlite/collection.go b/sqlite/collection.go index 199a56afcd63a047fe1aa4364034657ae400c429..0aadfcd010ce8ae64d1e4ce1bc6321b3d69bd9c4 100644 --- a/sqlite/collection.go +++ b/sqlite/collection.go @@ -41,8 +41,8 @@ var _ = db.Collection(&table{}) // Find creates a result set with the given conditions. func (t *table) Find(terms ...interface{}) db.Result { - where, arguments := sqlutil.ToWhereWithArguments(terms) - return result.NewResult(t, where, arguments) + where, arguments := template.ToWhereWithArguments(terms) + return result.NewResult(template, t, where, arguments) } // Truncate deletes all rows from the table. @@ -68,7 +68,7 @@ func (t *table) Append(item interface{}) (interface{}, error) { return nil, err } - sqlgenCols, sqlgenVals, sqlgenArgs, err := sqlutil.ToColumnsValuesAndArguments(columnNames, columnValues) + sqlgenCols, sqlgenVals, sqlgenArgs, err := template.ToColumnsValuesAndArguments(columnNames, columnValues) if err != nil { return nil, err diff --git a/sqlite/database.go b/sqlite/database.go index d1d575ffe0071eab7aec101a46ba116f70f9d70e..472e34f8c873272829abe66ade0fa8063d2542b7 100644 --- a/sqlite/database.go +++ b/sqlite/database.go @@ -273,11 +273,11 @@ func (d *database) Transaction() (db.Tx, error) { var clone *database var sqlTx *sqlx.Tx - if sqlTx, err = d.session.Beginx(); err != nil { + if clone, err = d.clone(); err != nil { return nil, err } - if clone, err = d.clone(); err != nil { + if sqlTx, err = clone.session.Beginx(); err != nil { return nil, err } @@ -304,7 +304,7 @@ func (d *database) Exec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) if d.tx != nil { res, err = d.tx.Exec(query, args...) @@ -333,7 +333,7 @@ func (d *database) Query(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) if d.tx != nil { rows, err = d.tx.Queryx(query, args...) @@ -362,7 +362,7 @@ func (d *database) QueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.R return nil, db.ErrNotConnected } - query = stmt.Compile(template) + query = stmt.Compile(template.Template) if d.tx != nil { row = d.tx.QueryRowx(query, args...) diff --git a/sqlite/sqlite.go b/sqlite/sqlite.go index eecef2d3099bbbf931a6f9050ecb0a75115020fe..ea66cafdedeb1a49e361414cda975acdfb47c827 100644 --- a/sqlite/sqlite.go +++ b/sqlite/sqlite.go @@ -25,15 +25,16 @@ import ( "upper.io/cache" "upper.io/db" "upper.io/db/util/sqlgen" + "upper.io/db/util/sqlutil" ) // Adapter is the public name of the adapter. const Adapter = `sqlite` -var template *sqlgen.Template +var template *sqlutil.TemplateWithUtils func init() { - template = &sqlgen.Template{ + template = sqlutil.NewTemplateWithUtils(&sqlgen.Template{ ColumnSeparator: adapterColumnSeparator, IdentifierSeparator: adapterIdentifierSeparator, IdentifierQuote: adapterIdentifierQuote, @@ -45,6 +46,7 @@ func init() { DescKeyword: adapterDescKeyword, AscKeyword: adapterAscKeyword, DefaultOperator: adapterDefaultOperator, + AssignmentOperator: adapterAssignmentOperator, ClauseGroup: adapterClauseGroup, ClauseOperator: adapterClauseOperator, ColumnValue: adapterColumnValue, @@ -63,7 +65,7 @@ func init() { CountLayout: adapterSelectCountLayout, GroupByLayout: adapterGroupByLayout, Cache: cache.NewCache(), - } + }) db.Register(Adapter, &database{}) } diff --git a/sqlite/template.go b/sqlite/template.go index 48c1ce6d4d6dda827b064461504e87ac959806a5..abea16614305df659ce23b2b33a84df02a01da90 100644 --- a/sqlite/template.go +++ b/sqlite/template.go @@ -33,6 +33,7 @@ const ( adapterDescKeyword = `DESC` adapterAscKeyword = `ASC` adapterDefaultOperator = `=` + adapterAssignmentOperator = `=` adapterClauseGroup = `({{.}})` adapterClauseOperator = ` {{.}} ` adapterColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` diff --git a/util/sqlgen/template.go b/util/sqlgen/template.go index 78663f54ba94edddf66c2f59b38b07d986c13e0f..bea487f39f26fb4061b9aeef3c81bd2973f7fba5 100644 --- a/util/sqlgen/template.go +++ b/util/sqlgen/template.go @@ -48,6 +48,7 @@ type Template struct { DescKeyword string AscKeyword string DefaultOperator string + AssignmentOperator string ClauseGroup string ClauseOperator string ColumnValue string diff --git a/util/sqlutil/convert.go b/util/sqlutil/convert.go index d88bfaca92dafd66c5fbca4198e087dbc9a27b98..d7bc224c8e75bf632986ed8ea87513644dc238e0 100644 --- a/util/sqlutil/convert.go +++ b/util/sqlutil/convert.go @@ -9,20 +9,27 @@ import ( ) var ( - sqlPlaceholder = sqlgen.RawValue(`?`) - sqlNull = sqlgen.RawValue(`NULL`) - sqlDefaultOperator = "=" + sqlPlaceholder = sqlgen.RawValue(`?`) + sqlNull = sqlgen.RawValue(`NULL`) ) +type TemplateWithUtils struct { + *sqlgen.Template +} + +func NewTemplateWithUtils(template *sqlgen.Template) *TemplateWithUtils { + return &TemplateWithUtils{template} +} + // ToWhereWithArguments converts the given db.Cond parameters into a sqlgen.Where // value. -func ToWhereWithArguments(term interface{}) (where sqlgen.Where, args []interface{}) { +func (tu *TemplateWithUtils) ToWhereWithArguments(term interface{}) (where sqlgen.Where, args []interface{}) { args = []interface{}{} switch t := term.(type) { case []interface{}: for i := range t { - w, v := ToWhereWithArguments(t[i]) + w, v := tu.ToWhereWithArguments(t[i]) args = append(args, v...) where.Conditions = append(where.Conditions, w.Conditions...) } @@ -30,7 +37,7 @@ func ToWhereWithArguments(term interface{}) (where sqlgen.Where, args []interfac case db.And: var op sqlgen.And for i := range t { - k, v := ToWhereWithArguments(t[i]) + k, v := tu.ToWhereWithArguments(t[i]) args = append(args, v...) op.Conditions = append(op.Conditions, k.Conditions...) } @@ -39,7 +46,7 @@ func ToWhereWithArguments(term interface{}) (where sqlgen.Where, args []interfac case db.Or: var op sqlgen.Or for i := range t { - w, v := ToWhereWithArguments(t[i]) + w, v := tu.ToWhereWithArguments(t[i]) args = append(args, v...) op.Conditions = append(op.Conditions, w.Conditions...) } @@ -51,14 +58,14 @@ func ToWhereWithArguments(term interface{}) (where sqlgen.Where, args []interfac } return case db.Cond: - cv, v := ToColumnValues(t) + cv, v := tu.ToColumnValues(t) args = append(args, v...) for i := range cv.ColumnValues { where.Conditions = append(where.Conditions, cv.ColumnValues[i]) } return case db.Constrainer: - cv, v := ToColumnValues(t.Constraint()) + cv, v := tu.ToColumnValues(t.Constraint()) args = append(args, v...) for i := range cv.ColumnValues { where.Conditions = append(where.Conditions, cv.ColumnValues[i]) @@ -70,7 +77,7 @@ func ToWhereWithArguments(term interface{}) (where sqlgen.Where, args []interfac } // ToInterfaceArguments converts the given value into an array of interfaces. -func ToInterfaceArguments(value interface{}) (args []interface{}) { +func (tu *TemplateWithUtils) ToInterfaceArguments(value interface{}) (args []interface{}) { if value == nil { return nil } @@ -100,7 +107,7 @@ func ToInterfaceArguments(value interface{}) (args []interface{}) { } // ToColumnValues converts the given db.Cond into a sqlgen.ColumnValues struct. -func ToColumnValues(cond db.Cond) (ToColumnValues sqlgen.ColumnValues, args []interface{}) { +func (tu *TemplateWithUtils) ToColumnValues(cond db.Cond) (ToColumnValues sqlgen.ColumnValues, args []interface{}) { args = []interface{}{} @@ -116,12 +123,12 @@ func ToColumnValues(cond db.Cond) (ToColumnValues sqlgen.ColumnValues, args []in if len(chunks) > 1 { columnValue.Operator = chunks[1] } else { - columnValue.Operator = sqlDefaultOperator + columnValue.Operator = tu.DefaultOperator } switch value := value.(type) { case db.Func: - v := ToInterfaceArguments(value.Args) + v := tu.ToInterfaceArguments(value.Args) columnValue.Operator = value.Name if v == nil { @@ -134,7 +141,7 @@ func ToColumnValues(cond db.Cond) (ToColumnValues sqlgen.ColumnValues, args []in args = append(args, v...) default: - v := ToInterfaceArguments(value) + v := tu.ToInterfaceArguments(value) l := len(v) if v == nil || l == 0 { @@ -160,7 +167,7 @@ func ToColumnValues(cond db.Cond) (ToColumnValues sqlgen.ColumnValues, args []in // ToColumnsValuesAndArguments maps the given columnNames and columnValues into // sqlgen's Columns and Values, it also extracts and returns query arguments. -func ToColumnsValuesAndArguments(columnNames []string, columnValues []interface{}) (*sqlgen.Columns, *sqlgen.Values, []interface{}, error) { +func (tu *TemplateWithUtils) ToColumnsValuesAndArguments(columnNames []string, columnValues []interface{}) (*sqlgen.Columns, *sqlgen.Values, []interface{}, error) { var arguments []interface{} columns := new(sqlgen.Columns) diff --git a/util/sqlutil/result/result.go b/util/sqlutil/result/result.go index 9616c9bb9d4b5ab214440883476445c6d885dec3..7badb0e750e2774ff0f511bb1018cccd8b624813 100644 --- a/util/sqlutil/result/result.go +++ b/util/sqlutil/result/result.go @@ -49,15 +49,17 @@ type Result struct { orderBy sqlgen.OrderBy groupBy sqlgen.GroupBy arguments []interface{} + template *sqlutil.TemplateWithUtils } // NewResult creates and results a new result set on the given table, this set // is limited by the given sqlgen.Where conditions. -func NewResult(p DataProvider, where sqlgen.Where, arguments []interface{}) *Result { +func NewResult(template *sqlutil.TemplateWithUtils, p DataProvider, where sqlgen.Where, arguments []interface{}) *Result { return &Result{ table: p, where: where, arguments: arguments, + template: template, } } @@ -82,7 +84,7 @@ func (r *Result) setCursor() error { // Sets conditions for reducing the working set. func (r *Result) Where(terms ...interface{}) db.Result { - r.where, r.arguments = sqlutil.ToWhereWithArguments(terms) + r.where, r.arguments = r.template.ToWhereWithArguments(terms) return r } @@ -166,7 +168,7 @@ func (r *Result) Select(fields ...interface{}) db.Result { var col sqlgen.Fragment switch value := fields[i].(type) { case db.Func: - v := sqlutil.ToInterfaceArguments(value.Args) + v := r.template.ToInterfaceArguments(value.Args) var s string if len(v) == 0 { s = fmt.Sprintf(`%s()`, value.Name) @@ -269,7 +271,7 @@ func (r *Result) Update(values interface{}) error { cvs := new(sqlgen.ColumnValues) for i := range ff { - cvs.ColumnValues = append(cvs.ColumnValues, &sqlgen.ColumnValue{Column: sqlgen.ColumnWithName(ff[i]), Operator: "=", Value: sqlPlaceholder}) + cvs.ColumnValues = append(cvs.ColumnValues, &sqlgen.ColumnValue{Column: sqlgen.ColumnWithName(ff[i]), Operator: r.template.AssignmentOperator, Value: sqlPlaceholder}) } vv = append(vv, r.arguments...)