diff --git a/db.go b/db.go index 2d88627576855dd25ab53638a5201f2f598bf1d0..8bd3d2a4e43130ad3210b81dcfcea7de7597f9ea 100644 --- a/db.go +++ b/db.go @@ -80,6 +80,8 @@ import ( "fmt" "reflect" "time" + + "upper.io/db.v2/internal/immutable" ) // Constraint interface represents a single condition, like "a = 1". where `a` @@ -262,30 +264,25 @@ func (r rawValue) Empty() bool { type compound struct { prev *compound - fn func() []Compound + fn func(*[]Compound) error } func newCompound(conds ...Compound) *compound { c := &compound{} - if len(conds) > 0 { - c.fn = func() []Compound { - return conds - } - } - return c + return c.frame(func(in *[]Compound) error { + *in = append(*in, conds...) + return nil + }) } -func defaultJoin(in ...Compound) []Compound { - for i := range in { - if cond, ok := in[i].(Cond); ok && len(cond) > 1 { - in[i] = And(cond) - } - } - return in -} +var _ = immutable.Immutable(&compound{}) func (c *compound) Sentences() []Compound { - return compoundFastForward(c) + conds, err := immutable.FastForward(c) + if err == nil { + return *(conds.(*[]Compound)) + } + return nil } func (c *compound) Operator() CompoundOperator { @@ -299,20 +296,35 @@ func (c *compound) Empty() bool { return true } -func (c *compound) frame(a []Compound) *compound { - if len(a) == 0 { - return c +func (c *compound) frame(fn func(*[]Compound) error) *compound { + return &compound{prev: c, fn: fn} +} + +func (c *compound) Prev() immutable.Immutable { + if c == nil { + return nil } - nc := newCompound(a...) - nc.prev = c - return nc + return c.prev } -func compoundFastForward(curr *compound) []Compound { - if curr == nil || curr.fn == nil { - return []Compound{} +func (c *compound) Fn(in interface{}) error { + if c.fn == nil { + return nil } - return append(compoundFastForward(curr.prev), curr.fn()...) + return c.fn(in.(*[]Compound)) +} + +func (c *compound) Base() interface{} { + return &[]Compound{} +} + +func defaultJoin(in ...Compound) []Compound { + for i := range in { + if cond, ok := in[i].(Cond); ok && len(cond) > 1 { + in[i] = And(cond) + } + } + return in } // Union represents a compound joined by OR. @@ -321,8 +333,11 @@ type Union struct { } // Or adds more terms to the compound. -func (o *Union) Or(conds ...Compound) *Union { - return &Union{o.compound.frame(conds)} +func (o *Union) Or(orConds ...Compound) *Union { + return &Union{o.compound.frame(func(in *[]Compound) error { + *in = append(*in, orConds...) + return nil + })} } // Operator returns the OR operator. @@ -336,8 +351,11 @@ func (o *Union) Empty() bool { } // And adds more terms to the compound. -func (a *Intersection) And(conds ...Compound) *Intersection { - return &Intersection{a.compound.frame(conds)} +func (a *Intersection) And(andConds ...Compound) *Intersection { + return &Intersection{a.compound.frame(func(in *[]Compound) error { + *in = append(*in, andConds...) + return nil + })} } // Empty returns false if this struct holds no conditions. diff --git a/internal/immutable/immutable.go b/internal/immutable/immutable.go new file mode 100644 index 0000000000000000000000000000000000000000..5969c34b779ca40bcd7d286c402e13d1d68e18e2 --- /dev/null +++ b/internal/immutable/immutable.go @@ -0,0 +1,22 @@ +package immutable + +// Immutable represents immutable chains +type Immutable interface { + Prev() Immutable + Fn(interface{}) error + Base() interface{} +} + +// FastForward applies all Fn methods in order on the given new Base. +func FastForward(curr Immutable) (interface{}, error) { + prev := curr.Prev() + if prev == nil { + return curr.Base(), nil + } + in, err := FastForward(prev) + if err != nil { + return nil, err + } + err = curr.Fn(in) + return in, err +} diff --git a/internal/sqladapter/result.go b/internal/sqladapter/result.go index fafb6c2a1e19ea3bed1a49e7f3c48f80bc30ce8f..d27525f701f2fb08c41daabd508f3f00df84b8cb 100644 --- a/internal/sqladapter/result.go +++ b/internal/sqladapter/result.go @@ -26,11 +26,12 @@ import ( "sync/atomic" "upper.io/db.v2" + "upper.io/db.v2/internal/immutable" "upper.io/db.v2/lib/sqlbuilder" ) type Result struct { - b sqlbuilder.Builder + builder sqlbuilder.Builder err atomic.Value @@ -59,9 +60,9 @@ func filter(conds []interface{}) []interface{} { // NewResult creates and Results a new Result set on the given table, this set // is limited by the given exql.Where conditions. -func NewResult(b sqlbuilder.Builder, table string, conds []interface{}) *Result { +func NewResult(builder sqlbuilder.Builder, table string, conds []interface{}) *Result { r := &Result{ - b: b, + builder: builder, } return r.from(table).where(conds) } @@ -70,17 +71,11 @@ func (r *Result) frame(fn func(*result) error) *Result { return &Result{prev: r, fn: fn} } -func (r *Result) builder() sqlbuilder.Builder { - p := &r - for { - if (*p).b != nil { - return (*p).b - } - if (*p).prev == nil { - return nil - } - p = &(*p).prev +func (r *Result) Builder() sqlbuilder.Builder { + if r.prev == nil { + return r.builder } + return r.prev.Builder() } func (r *Result) from(table string) *Result { @@ -274,9 +269,12 @@ func (r *Result) Count() (uint64, error) { } func (r *Result) buildSelect() (sqlbuilder.Selector, error) { - res, err := resultFastForward(&result{}, r) + res, err := r.fastForward() + if err != nil { + return nil, err + } - sel := r.builder().Select(res.fields...). + sel := r.Builder().Select(res.fields...). From(res.table). Where(filter(res.conds)...). Limit(res.limit). @@ -284,50 +282,75 @@ func (r *Result) buildSelect() (sqlbuilder.Selector, error) { GroupBy(res.groupBy...). OrderBy(res.orderBy...) - return sel, err + return sel, nil } func (r *Result) buildDelete() (sqlbuilder.Deleter, error) { - res, err := resultFastForward(&result{}, r) + res, err := r.fastForward() + if err != nil { + return nil, err + } - del := r.builder().DeleteFrom(res.table). + del := r.Builder().DeleteFrom(res.table). Where(filter(res.conds)...). Limit(res.limit) - return del, err + return del, nil } func (r *Result) buildUpdate(values interface{}) (sqlbuilder.Updater, error) { - res, err := resultFastForward(&result{}, r) + res, err := r.fastForward() + if err != nil { + return nil, err + } - upd := r.builder().Update(res.table). + upd := r.Builder().Update(res.table). Set(values). Where(filter(res.conds)...). Limit(res.limit) - return upd, err + return upd, nil +} + +func (r *Result) fastForward() (*result, error) { + ff, err := immutable.FastForward(r) + if err != nil { + return nil, err + } + return ff.(*result), nil } func (r *Result) buildCount() (sqlbuilder.Selector, error) { - res, err := resultFastForward(&result{}, r) + res, err := r.fastForward() + if err != nil { + return nil, err + } - sel := r.builder().Select(db.Raw("count(1) AS _t")). + sel := r.Builder().Select(db.Raw("count(1) AS _t")). From(res.table). Where(filter(res.conds)...). GroupBy(res.groupBy...). Limit(1) - return sel, err + return sel, nil } -func resultFastForward(in *result, curr *Result) (*result, error) { - if curr == nil || curr.fn == nil { - return in, nil +func (r *Result) Prev() immutable.Immutable { + if r == nil { + return nil } - in, err := resultFastForward(in, curr.prev) - if err != nil { - return nil, err + return r.prev +} + +func (r *Result) Fn(in interface{}) error { + if r.fn == nil { + return nil } - err = curr.fn(in) - return in, err + return r.fn(in.(*result)) } + +func (r *Result) Base() interface{} { + return &result{} +} + +var _ = immutable.Immutable(&Result{}) diff --git a/lib/sqlbuilder/batch.go b/lib/sqlbuilder/batch.go index 1ed3836534d6596399441e98aa1412ae409217e3..a8a8b144df5e07b3d1a2e6acfd5e2117270f00b0 100644 --- a/lib/sqlbuilder/batch.go +++ b/lib/sqlbuilder/batch.go @@ -28,11 +28,12 @@ func (b *BatchInserter) Values(values ...interface{}) *BatchInserter { } func (b *BatchInserter) nextQuery() *inserter { - clone := b.inserter.clone() + ins := &inserter{} + *ins = *b.inserter i := 0 for values := range b.values { i++ - clone = clone.Values(values...).(*inserter) + ins = ins.Values(values...).(*inserter) if i == b.size { break } @@ -40,7 +41,7 @@ func (b *BatchInserter) nextQuery() *inserter { if i == 0 { return nil } - return clone + return ins } // NextResult is useful when using PostgreSQL and Returning(), it dumps the diff --git a/lib/sqlbuilder/builder.go b/lib/sqlbuilder/builder.go index 31ea07361389f9a8ef6cce7e601430648e5fdcaf..be5eaa8491f1e0721341468ff5a44b5a5a7bc41d 100644 --- a/lib/sqlbuilder/builder.go +++ b/lib/sqlbuilder/builder.go @@ -31,10 +31,6 @@ type compilable interface { Arguments() []interface{} } -type hasStringer interface { - Stringer() *stringer -} - type hasIsZero interface { IsZero() bool } @@ -47,11 +43,6 @@ type hasStatement interface { statement() *exql.Statement } -type stringer struct { - i hasStatement - t *exql.Template -} - type iterator struct { cursor *sql.Rows // This is the main query cursor. It starts as a nil value. err error @@ -159,7 +150,6 @@ func (b *sqlBuilder) SelectFrom(table ...interface{}) Selector { qs := &selector{ builder: b, } - qs.stringer = &stringer{qs, b.t.Template} return qs.From(table...) } @@ -167,8 +157,6 @@ func (b *sqlBuilder) Select(columns ...interface{}) Selector { qs := &selector{ builder: b, } - - qs.stringer = &stringer{qs, b.t.Template} return qs.Columns(columns...) } @@ -176,8 +164,6 @@ func (b *sqlBuilder) InsertInto(table string) Inserter { qi := &inserter{ builder: b, } - - qi.stringer = &stringer{qi, b.t.Template} return qi.Into(table) } @@ -185,8 +171,6 @@ func (b *sqlBuilder) DeleteFrom(table string) Deleter { qd := &deleter{ builder: b, } - - qd.stringer = &stringer{qd, b.t.Template} return qd.setTable(table) } @@ -194,8 +178,6 @@ func (b *sqlBuilder) Update(table string) Updater { qu := &updater{ builder: b, } - - qu.stringer = &stringer{qu, b.t.Template} return qu.setTable(table) } @@ -361,29 +343,19 @@ func columnFragments(columns []interface{}) ([]exql.Fragment, []interface{}, err return f, args, nil } -func (s *stringer) String() string { - if s != nil && s.i != nil { - q := s.compileAndReplacePlaceholders(s.i.statement()) - q = reInvisibleChars.ReplaceAllString(q, ` `) - return strings.TrimSpace(q) - } - return "" -} - -func (s *stringer) compileAndReplacePlaceholders(stmt *exql.Statement) (query string) { - buf := stmt.Compile(s.t) - +func prepareQueryForDisplay(in string) (out string) { j := 1 - for i := range buf { - if buf[i] == '?' { - query = query + "$" + strconv.Itoa(j) + for i := range in { + if in[i] == '?' { + out = out + "$" + strconv.Itoa(j) j++ } else { - query = query + string(buf[i]) + out = out + string(in[i]) } } - return query + out = reInvisibleChars.ReplaceAllString(out, ` `) + return strings.TrimSpace(out) } func (iter *iterator) NextScan(dst ...interface{}) error { diff --git a/lib/sqlbuilder/delete.go b/lib/sqlbuilder/delete.go index 15f78746ec8761020ad8e99f5d2d5cb10e364e23..076162c641affd7bd851e4db0e5d89d4808cb8af 100644 --- a/lib/sqlbuilder/delete.go +++ b/lib/sqlbuilder/delete.go @@ -2,8 +2,8 @@ package sqlbuilder import ( "database/sql" - "strings" + "upper.io/db.v2/internal/immutable" "upper.io/db.v2/internal/sqladapter/exql" ) @@ -32,47 +32,27 @@ func (dq *deleterQuery) statement() *exql.Statement { } type deleter struct { - *stringer builder *sqlBuilder fn func(*deleterQuery) error prev *deleter } +var _ = immutable.Immutable(&deleter{}) + func (del *deleter) Builder() *sqlBuilder { - p := &del - for { - if (*p).builder != nil { - return (*p).builder - } - if (*p).prev == nil { - return nil - } - p = &(*p).prev + if del.prev == nil { + return del.builder } + return del.prev.Builder() } -func (del *deleter) Stringer() *stringer { - p := &del - for { - if (*p).stringer != nil { - return (*p).stringer - } - if (*p).prev == nil { - return nil - } - p = &(*p).prev - } +func (del *deleter) template() *exql.Template { + return del.Builder().t.Template } func (del *deleter) String() string { - query, err := del.build() - if err != nil { - return "" - } - q := del.Stringer().compileAndReplacePlaceholders(query.statement()) - q = reInvisibleChars.ReplaceAllString(q, ` `) - return strings.TrimSpace(q) + return prepareQueryForDisplay(del.Compile()) } func (del *deleter) setTable(table string) *deleter { @@ -124,25 +104,31 @@ func (del *deleter) statement() *exql.Statement { } func (del *deleter) build() (*deleterQuery, error) { - iq, err := deleterFastForward(&deleterQuery{}, del) + dq, err := immutable.FastForward(del) if err != nil { return nil, err } - return iq, nil + return dq.(*deleterQuery), nil } func (del *deleter) Compile() string { - return del.statement().Compile(del.Stringer().t) + return del.statement().Compile(del.template()) } -func deleterFastForward(in *deleterQuery, curr *deleter) (*deleterQuery, error) { - if curr == nil || curr.fn == nil { - return in, nil +func (del *deleter) Prev() immutable.Immutable { + if del == nil { + return nil } - in, err := deleterFastForward(in, curr.prev) - if err != nil { - return nil, err + return del.prev +} + +func (del *deleter) Fn(in interface{}) error { + if del.fn == nil { + return nil } - err = curr.fn(in) - return in, err + return del.fn(in.(*deleterQuery)) +} + +func (del *deleter) Base() interface{} { + return &deleterQuery{} } diff --git a/lib/sqlbuilder/insert.go b/lib/sqlbuilder/insert.go index 0a29b607137f7a99174bfad1638551f569542bfd..9cdee4d2dd11d64bc5c49747f9a7159fcea3aae4 100644 --- a/lib/sqlbuilder/insert.go +++ b/lib/sqlbuilder/insert.go @@ -2,8 +2,8 @@ package sqlbuilder import ( "database/sql" - "strings" + "upper.io/db.v2/internal/immutable" "upper.io/db.v2/internal/sqladapter/exql" ) @@ -99,24 +99,15 @@ func (iq *inserterQuery) statement() *exql.Statement { return stmt } -func columnsToFragments(dst *[]exql.Fragment, columns []string) error { - l := len(columns) - f := make([]exql.Fragment, l) - for i := 0; i < l; i++ { - f[i] = exql.ColumnWithName(columns[i]) - } - *dst = append(*dst, f...) - return nil -} - type inserter struct { builder *sqlBuilder - *stringer fn func(*inserterQuery) error prev *inserter } +var _ = immutable.Immutable(&inserter{}) + func (ins *inserter) Builder() *sqlBuilder { if ins.prev == nil { return ins.builder @@ -124,35 +115,20 @@ func (ins *inserter) Builder() *sqlBuilder { return ins.prev.Builder() } -func (ins *inserter) Stringer() *stringer { - if ins.prev == nil { - return ins.stringer - } - return ins.prev.Stringer() +func (ins *inserter) template() *exql.Template { + return ins.Builder().t.Template } func (ins *inserter) String() string { - query, err := ins.build() - if err != nil { - return "" - } - q := ins.Stringer().compileAndReplacePlaceholders(query.statement()) - q = reInvisibleChars.ReplaceAllString(q, ` `) - return strings.TrimSpace(q) + return prepareQueryForDisplay(ins.Compile()) } func (ins *inserter) frame(fn func(*inserterQuery) error) *inserter { return &inserter{prev: ins, fn: fn} } -func (ins *inserter) clone() *inserter { - clone := &inserter{} - *clone = *ins - return clone -} - func (ins *inserter) Batch(n int) *BatchInserter { - return newBatchInserter(ins.clone(), n) + return newBatchInserter(ins, n) } func (ins *inserter) Arguments() []interface{} { @@ -226,26 +202,43 @@ func (ins *inserter) statement() *exql.Statement { } func (ins *inserter) build() (*inserterQuery, error) { - iq, err := inserterFastForward(&inserterQuery{}, ins) + iq, err := immutable.FastForward(ins) if err != nil { return nil, err } - iq.values, iq.arguments = iq.processValues() - return iq, nil + ret := iq.(*inserterQuery) + ret.values, ret.arguments = ret.processValues() + return ret, nil } func (ins *inserter) Compile() string { - return ins.statement().Compile(ins.Stringer().t) + return ins.statement().Compile(ins.template()) } -func inserterFastForward(in *inserterQuery, curr *inserter) (*inserterQuery, error) { - if curr == nil || curr.fn == nil { - return in, nil +func (ins *inserter) Prev() immutable.Immutable { + if ins == nil { + return nil } - in, err := inserterFastForward(in, curr.prev) - if err != nil { - return nil, err + return ins.prev +} + +func (ins *inserter) Fn(in interface{}) error { + if ins.fn == nil { + return nil } - err = curr.fn(in) - return in, err + return ins.fn(in.(*inserterQuery)) +} + +func (ins *inserter) Base() interface{} { + return &inserterQuery{} +} + +func columnsToFragments(dst *[]exql.Fragment, columns []string) error { + l := len(columns) + f := make([]exql.Fragment, l) + for i := 0; i < l; i++ { + f[i] = exql.ColumnWithName(columns[i]) + } + *dst = append(*dst, f...) + return nil } diff --git a/lib/sqlbuilder/select.go b/lib/sqlbuilder/select.go index d560fb85ade49e7ff98f8a0a0d286cd8acc5fec4..4ebffd80daf49f399d930923411b825e7f34b5ec 100644 --- a/lib/sqlbuilder/select.go +++ b/lib/sqlbuilder/select.go @@ -7,6 +7,7 @@ import ( "strings" "upper.io/db.v2" + "upper.io/db.v2/internal/immutable" "upper.io/db.v2/internal/sqladapter/exql" ) @@ -107,46 +108,22 @@ func (sq *selectorQuery) pushJoin(t string, tables []interface{}) error { type selector struct { builder *sqlBuilder - *stringer fn func(*selectorQuery) error prev *selector } -func (sel *selector) Builder() *sqlBuilder { - p := &sel - for { - if (*p).builder != nil { - return (*p).builder - } - if (*p).prev == nil { - return nil - } - p = &(*p).prev - } -} +var _ = immutable.Immutable(&inserter{}) -func (sel *selector) Stringer() *stringer { - p := &sel - for { - if (*p).stringer != nil { - return (*p).stringer - } - if (*p).prev == nil { - return nil - } - p = &(*p).prev +func (sel *selector) Builder() *sqlBuilder { + if sel.prev == nil { + return sel.builder } + return sel.prev.Builder() } func (sel *selector) String() string { - query, err := sel.build() - if err != nil { - return "" - } - q := sel.Stringer().compileAndReplacePlaceholders(query.statement()) - q = reInvisibleChars.ReplaceAllString(q, ` `) - return strings.TrimSpace(q) + return prepareQueryForDisplay(sel.Compile()) } func (sel *selector) frame(fn func(*selectorQuery) error) *selector { @@ -386,6 +363,10 @@ func (sel *selector) Offset(n int) Selector { }) } +func (sel *selector) template() *exql.Template { + return sel.Builder().t.Template +} + func (sel *selector) As(alias string) Selector { return sel.frame(func(sq *selectorQuery) error { if sq.table == nil { @@ -393,7 +374,7 @@ func (sel *selector) As(alias string) Selector { } last := len(sq.table.Columns) - 1 if raw, ok := sq.table.Columns[last].(*exql.Raw); ok { - sq.table.Columns[last] = exql.RawValue("(" + raw.Value + ") AS " + exql.ColumnWithName(alias).Compile(sel.Stringer().t)) + sq.table.Columns[last] = exql.RawValue("(" + raw.Value + ") AS " + exql.ColumnWithName(alias).Compile(sel.template())) } return nil }) @@ -440,25 +421,31 @@ func (sel *selector) One(dest interface{}) error { } func (sel *selector) build() (*selectorQuery, error) { - sq, err := selectorFastForward(&selectorQuery{}, sel) + sq, err := immutable.FastForward(sel) if err != nil { return nil, err } - return sq, nil + return sq.(*selectorQuery), nil } func (sel *selector) Compile() string { - return sel.statement().Compile(sel.Stringer().t) + return sel.statement().Compile(sel.template()) } -func selectorFastForward(in *selectorQuery, curr *selector) (*selectorQuery, error) { - if curr == nil || curr.fn == nil { - return in, nil +func (sel *selector) Prev() immutable.Immutable { + if sel == nil { + return nil } - in, err := selectorFastForward(in, curr.prev) - if err != nil { - return nil, err + return sel.prev +} + +func (sel *selector) Fn(in interface{}) error { + if sel.fn == nil { + return nil } - err = curr.fn(in) - return in, err + return sel.fn(in.(*selectorQuery)) +} + +func (sel *selector) Base() interface{} { + return &selectorQuery{} } diff --git a/lib/sqlbuilder/update.go b/lib/sqlbuilder/update.go index f3bb9f1799e2e0b15bb6a63e7aa732c7a4757f01..5e831e1acadf3f11ca19b905e383bb1a587a9eab 100644 --- a/lib/sqlbuilder/update.go +++ b/lib/sqlbuilder/update.go @@ -2,8 +2,8 @@ package sqlbuilder import ( "database/sql" - "strings" + "upper.io/db.v2/internal/immutable" "upper.io/db.v2/internal/sqladapter/exql" ) @@ -45,47 +45,27 @@ func (uq *updaterQuery) arguments() []interface{} { } type updater struct { - *stringer builder *sqlBuilder fn func(*updaterQuery) error prev *updater } +var _ = immutable.Immutable(&updater{}) + func (upd *updater) Builder() *sqlBuilder { - p := &upd - for { - if (*p).builder != nil { - return (*p).builder - } - if (*p).prev == nil { - return nil - } - p = &(*p).prev + if upd.prev == nil { + return upd.builder } + return upd.prev.Builder() } -func (upd *updater) Stringer() *stringer { - p := &upd - for { - if (*p).stringer != nil { - return (*p).stringer - } - if (*p).prev == nil { - return nil - } - p = &(*p).prev - } +func (upd *updater) template() *exql.Template { + return upd.Builder().t.Template } func (upd *updater) String() string { - query, err := upd.build() - if err != nil { - return "" - } - q := upd.Stringer().compileAndReplacePlaceholders(query.statement()) - q = reInvisibleChars.ReplaceAllString(q, ` `) - return strings.TrimSpace(q) + return prepareQueryForDisplay(upd.Compile()) } func (upd *updater) setTable(table string) *updater { @@ -174,25 +154,31 @@ func (upd *updater) statement() *exql.Statement { } func (upd *updater) build() (*updaterQuery, error) { - iq, err := updaterFastForward(&updaterQuery{}, upd) + uq, err := immutable.FastForward(upd) if err != nil { return nil, err } - return iq, nil + return uq.(*updaterQuery), nil } func (upd *updater) Compile() string { - return upd.statement().Compile(upd.Stringer().t) + return upd.statement().Compile(upd.template()) } -func updaterFastForward(in *updaterQuery, curr *updater) (*updaterQuery, error) { - if curr == nil || curr.fn == nil { - return in, nil +func (upd *updater) Prev() immutable.Immutable { + if upd == nil { + return nil } - in, err := updaterFastForward(in, curr.prev) - if err != nil { - return nil, err + return upd.prev +} + +func (upd *updater) Fn(in interface{}) error { + if upd.fn == nil { + return nil } - err = curr.fn(in) - return in, err + return upd.fn(in.(*updaterQuery)) +} + +func (upd *updater) Base() interface{} { + return &updaterQuery{} }