diff --git a/lib/sqlbuilder/builder.go b/lib/sqlbuilder/builder.go index 649b6670d770293fbccc429b7e3e2719fe42fbb0..f7e7ba90d8b28763b4bcef031616a1efe839053a 100644 --- a/lib/sqlbuilder/builder.go +++ b/lib/sqlbuilder/builder.go @@ -182,13 +182,11 @@ func (b *sqlBuilder) DeleteFrom(table string) Deleter { func (b *sqlBuilder) Update(table string) Updater { qu := &updater{ - builder: b, - table: table, - columnValues: &exql.ColumnValues{}, + builder: b, } qu.stringer = &stringer{qu, b.t.Template} - return qu + return qu.setTable(table) } // Map receives a pointer to map or struct and maps it to columns and values. diff --git a/lib/sqlbuilder/insert.go b/lib/sqlbuilder/insert.go index 19fffd66abdd15db4e40cd6eb04e8fcbf42b5d7c..da512533c31ad41197c4dcb14e78ccc9f7d095a2 100644 --- a/lib/sqlbuilder/insert.go +++ b/lib/sqlbuilder/insert.go @@ -79,7 +79,7 @@ func (ins *inserter) String() string { } func (ins *inserter) frame(fn func(*inserterQuery) error) *inserter { - return &inserter{prev: ins, fn: fn} + return &inserter{prev: ins, fn: fn, builder: ins.builder} } func (ins *inserter) clone() *inserter { diff --git a/lib/sqlbuilder/update.go b/lib/sqlbuilder/update.go index b7cfd83141d3838fba23ee3a47fe0576bd6a51e8..d58072a46cf9874e3295427c0f153c60cc056573 100644 --- a/lib/sqlbuilder/update.go +++ b/lib/sqlbuilder/update.go @@ -2,15 +2,13 @@ package sqlbuilder import ( "database/sql" - "sync" + "strings" "upper.io/db.v2/internal/sqladapter/exql" ) -type updater struct { - *stringer - builder *sqlBuilder - table string +type updaterQuery struct { + table string columnValues *exql.ColumnValues columnValuesArgs []interface{} @@ -19,81 +17,182 @@ type updater struct { where *exql.Where whereArgs []interface{} - - mu sync.Mutex } -func (qu *updater) Set(terms ...interface{}) Updater { - if len(terms) == 1 { - ff, vv, _ := Map(terms[0], nil) +func (uq *updaterQuery) statement() *exql.Statement { + stmt := &exql.Statement{ + Type: exql.Update, + Table: exql.TableWithName(uq.table), + ColumnValues: uq.columnValues, + } + + if uq.where != nil { + stmt.Where = uq.where + } - cvs := make([]exql.Fragment, 0, len(ff)) - args := make([]interface{}, 0, len(vv)) + if uq.limit != 0 { + stmt.Limit = exql.Limit(uq.limit) + } - for i := range ff { - cv := &exql.ColumnValue{ - Column: exql.ColumnWithName(ff[i]), - Operator: qu.builder.t.AssignmentOperator, - } + return stmt +} + +func (uq *updaterQuery) arguments() []interface{} { + return joinArguments( + uq.columnValuesArgs, + uq.whereArgs, + ) +} + +type updater struct { + *stringer + builder *sqlBuilder - var localArgs []interface{} - cv.Value, localArgs = qu.builder.t.PlaceholderValue(vv[i]) + fn func(*updaterQuery) error + prev *updater +} - args = append(args, localArgs...) - cvs = append(cvs, cv) +func (upd *updater) Builder() *sqlBuilder { + p := &upd + for { + if (*p).builder != nil { + return (*p).builder + } + if (*p).prev == nil { + return nil } + p = &(*p).prev + } +} - qu.columnValues.Insert(cvs...) - qu.columnValuesArgs = append(qu.columnValuesArgs, args...) - } else if len(terms) > 1 { - cv, arguments := toColumnValues(terms) - qu.columnValues.Insert(cv.ColumnValues...) - qu.columnValuesArgs = append(qu.columnValuesArgs, arguments...) +func (upd *updater) Stringer() *stringer { + p := &upd + for { + if (*p).stringer != nil { + return (*p).stringer + } + if (*p).prev == nil { + return nil + } + p = &(*p).prev } +} - return qu +func (upd *updater) String() string { + query, err := upd.build() + if err != nil { + return "" + } + q := upd.Stringer().compileAndReplacePlaceholders(query.statement()) + q = reInvisibleChars.ReplaceAllString(q, ` `) + return strings.TrimSpace(q) } -func (qu *updater) Arguments() []interface{} { - qu.mu.Lock() - defer qu.mu.Unlock() +func (upd *updater) setTable(table string) *updater { + return upd.frame(func(uq *updaterQuery) error { + uq.table = table + return nil + }) +} - return joinArguments( - qu.columnValuesArgs, - qu.whereArgs, - ) +func (upd *updater) frame(fn func(*updaterQuery) error) *updater { + return &updater{prev: upd, fn: fn} } -func (qu *updater) Where(terms ...interface{}) Updater { - where, arguments := toWhereWithArguments(terms) - qu.where = &where - qu.whereArgs = append(qu.whereArgs, arguments...) - return qu +func (upd *updater) Set(terms ...interface{}) Updater { + return upd.frame(func(uq *updaterQuery) error { + if uq.columnValues == nil { + uq.columnValues = &exql.ColumnValues{} + } + + if len(terms) == 1 { + ff, vv, _ := Map(terms[0], nil) + + cvs := make([]exql.Fragment, 0, len(ff)) + args := make([]interface{}, 0, len(vv)) + + for i := range ff { + cv := &exql.ColumnValue{ + Column: exql.ColumnWithName(ff[i]), + Operator: upd.Builder().t.AssignmentOperator, + } + + var localArgs []interface{} + cv.Value, localArgs = upd.Builder().t.PlaceholderValue(vv[i]) + + args = append(args, localArgs...) + cvs = append(cvs, cv) + } + + uq.columnValues.Insert(cvs...) + uq.columnValuesArgs = append(uq.columnValuesArgs, args...) + } else if len(terms) > 1 { + cv, arguments := toColumnValues(terms) + uq.columnValues.Insert(cv.ColumnValues...) + uq.columnValuesArgs = append(uq.columnValuesArgs, arguments...) + } + + return nil + }) } -func (qu *updater) Exec() (sql.Result, error) { - return qu.builder.sess.StatementExec(qu.statement(), qu.Arguments()...) +func (upd *updater) Arguments() []interface{} { + uq, err := upd.build() + if err != nil { + return nil + } + return uq.arguments() } -func (qu *updater) Limit(limit int) Updater { - qu.limit = limit - return qu +func (upd *updater) Where(terms ...interface{}) Updater { + return upd.frame(func(uq *updaterQuery) error { + where, arguments := toWhereWithArguments(terms) + uq.where = &where + uq.whereArgs = append(uq.whereArgs, arguments...) + return nil + }) } -func (qu *updater) statement() *exql.Statement { - stmt := &exql.Statement{ - Type: exql.Update, - Table: exql.TableWithName(qu.table), - ColumnValues: qu.columnValues, +func (upd *updater) Exec() (sql.Result, error) { + uq, err := upd.build() + if err != nil { + return nil, err } + return upd.builder.sess.StatementExec(uq.statement(), uq.arguments()...) +} - if qu.Where != nil { - stmt.Where = qu.where - } +func (upd *updater) Limit(limit int) Updater { + return upd.frame(func(uq *updaterQuery) error { + uq.limit = limit + return nil + }) +} + +func (upd *updater) statement() *exql.Statement { + iq, _ := upd.build() + return iq.statement() +} - if qu.limit != 0 { - stmt.Limit = exql.Limit(qu.limit) +func (upd *updater) build() (*updaterQuery, error) { + iq, err := updaterFastForward(&updaterQuery{}, upd) + if err != nil { + return nil, err } + return iq, nil +} - return stmt +func (upd *updater) Compile() string { + return upd.statement().Compile(upd.Stringer().t) +} + +func updaterFastForward(in *updaterQuery, curr *updater) (*updaterQuery, error) { + if curr == nil || curr.fn == nil { + return in, nil + } + in, err := updaterFastForward(in, curr.prev) + if err != nil { + return nil, err + } + err = curr.fn(in) + return in, err }