From c9da34336eeb2f38e9fabbe84f7b5a8a48d01f3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net> Date: Tue, 13 Dec 2016 15:27:27 +0000 Subject: [PATCH] Make db.Result immutable --- internal/sqladapter/result.go | 233 +++++++++++++++++++++++++--------- 1 file changed, 172 insertions(+), 61 deletions(-) diff --git a/internal/sqladapter/result.go b/internal/sqladapter/result.go index 108ec25b..fafb6c2a 100644 --- a/internal/sqladapter/result.go +++ b/internal/sqladapter/result.go @@ -23,16 +23,27 @@ package sqladapter import ( "sync" + "sync/atomic" "upper.io/db.v2" "upper.io/db.v2/lib/sqlbuilder" ) -// Result represents a delimited set of items bound by a condition. type Result struct { - b sqlbuilder.Builder + b sqlbuilder.Builder + + err atomic.Value + + iter sqlbuilder.Iterator + iterMu sync.Mutex + + prev *Result + fn func(*result) error +} + +// result represents a delimited set of items bound by a condition. +type result struct { table string - iter sqlbuilder.Iterator limit int offset int fields []interface{} @@ -40,9 +51,6 @@ type Result struct { orderBy []interface{} groupBy []interface{} conds []interface{} - err error - errMu sync.RWMutex - iterMu sync.Mutex } func filter(conds []interface{}) []interface{} { @@ -52,93 +60,140 @@ func filter(conds []interface{}) []interface{} { // NewResult creates and Results a new Result set on the given table, this set // is limited by the given exql.Where conditions. func NewResult(b sqlbuilder.Builder, table string, conds []interface{}) *Result { - return &Result{ - b: b, - table: table, - conds: conds, + r := &Result{ + b: b, + } + return r.from(table).where(conds) +} + +func (r *Result) frame(fn func(*result) error) *Result { + return &Result{prev: r, fn: fn} +} + +func (r *Result) builder() sqlbuilder.Builder { + p := &r + for { + if (*p).b != nil { + return (*p).b + } + if (*p).prev == nil { + return nil + } + p = &(*p).prev } } +func (r *Result) from(table string) *Result { + return r.frame(func(res *result) error { + res.table = table + return nil + }) +} + +func (r *Result) where(conds []interface{}) *Result { + return r.frame(func(res *result) error { + res.conds = conds + return nil + }) +} + func (r *Result) setErr(err error) error { if err == nil { return nil } - - r.errMu.Lock() - defer r.errMu.Unlock() - - r.err = err + r.err.Store(err) return err } // Err returns the last error that has happened with the result set, // nil otherwise func (r *Result) Err() error { - r.errMu.RLock() - defer r.errMu.RUnlock() - return r.err + if errV := r.err.Load(); errV != nil { + return errV.(error) + } + return nil } // Where sets conditions for the result set. func (r *Result) Where(conds ...interface{}) db.Result { - r.conds = conds - return r + return r.where(conds) } // And adds more conditions on top of the existing ones. func (r *Result) And(conds ...interface{}) db.Result { - r.conds = append(r.conds, conds...) - return r + return r.frame(func(res *result) error { + res.conds = append(res.conds, conds...) + return nil + }) } // Limit determines the maximum limit of Results to be returned. func (r *Result) Limit(n int) db.Result { - r.limit = n - return r + return r.frame(func(res *result) error { + res.limit = n + return nil + }) } // Offset determines how many documents will be skipped before starting to grab // Results. func (r *Result) Offset(n int) db.Result { - r.offset = n - return r + return r.frame(func(res *result) error { + res.offset = n + return nil + }) } // Group is used to group Results that have the same value in the same column // or columns. func (r *Result) Group(fields ...interface{}) db.Result { - r.groupBy = fields - return r + return r.frame(func(res *result) error { + res.groupBy = fields + return nil + }) } // OrderBy determines sorting of Results according to the provided names. Fields // may be prefixed by - (minus) which means descending order, ascending order // would be used otherwise. func (r *Result) OrderBy(fields ...interface{}) db.Result { - r.orderBy = fields - return r + return r.frame(func(res *result) error { + res.orderBy = fields + return nil + }) } // Select determines which fields to return. func (r *Result) Select(fields ...interface{}) db.Result { - r.fields = fields - return r + return r.frame(func(res *result) error { + res.fields = fields + return nil + }) } // String satisfies fmt.Stringer func (r *Result) String() string { - return r.buildSelect().String() + query, _ := r.buildSelect() + return query.String() } // All dumps all Results into a pointer to an slice of structs or maps. func (r *Result) All(dst interface{}) error { - err := r.buildSelect().Iterator().All(dst) + query, err := r.buildSelect() + if err != nil { + return r.setErr(err) + } + err = query.Iterator().All(dst) return r.setErr(err) } // One fetches only one Result from the set. func (r *Result) One(dst interface{}) error { - err := r.buildSelect().Iterator().One(dst) + query, err := r.buildSelect() + if err != nil { + return r.setErr(err) + } + err = query.Iterator().One(dst) return r.setErr(err) } @@ -148,24 +203,33 @@ func (r *Result) Next(dst interface{}) bool { defer r.iterMu.Unlock() if r.iter == nil { - r.iter = r.buildSelect().Iterator() + query, err := r.buildSelect() + if err != nil { + r.setErr(err) + return false + } + r.iter = query.Iterator() } + if r.iter.Next(dst) { return true } + if err := r.iter.Err(); err != db.ErrNoMoreRows { r.setErr(err) } + return false } // Delete deletes all matching items from the collection. func (r *Result) Delete() error { - q := r.b.DeleteFrom(r.table). - Where(filter(r.conds)...). - Limit(r.limit) + query, err := r.buildDelete() + if err != nil { + return r.setErr(err) + } - _, err := q.Exec() + _, err = query.Exec() return r.setErr(err) } @@ -180,28 +244,26 @@ func (r *Result) Close() error { // Update updates matching items from the collection with values of the given // map or struct. func (r *Result) Update(values interface{}) error { - q := r.b.Update(r.table). - Set(values). - Where(filter(r.conds)...). - Limit(r.limit) + query, err := r.buildUpdate(values) + if err != nil { + return r.setErr(err) + } - _, err := q.Exec() + _, err = query.Exec() return r.setErr(err) } // Count counts the elements on the set. func (r *Result) Count() (uint64, error) { + query, err := r.buildCount() + if err != nil { + return 0, r.setErr(err) + } + counter := struct { Count uint64 `db:"_t"` }{} - - q := r.b.Select(db.Raw("count(1) AS _t")). - From(r.table). - Where(filter(r.conds)...). - GroupBy(r.groupBy...). - Limit(1) - - if err := q.Iterator().One(&counter); err != nil { + if err := query.Iterator().One(&counter); err != nil { if err == db.ErrNoMoreRows { return 0, nil } @@ -211,12 +273,61 @@ func (r *Result) Count() (uint64, error) { return counter.Count, nil } -func (r *Result) buildSelect() sqlbuilder.Selector { - return r.b.Select(r.fields...). - From(r.table). - Where(filter(r.conds)...). - Limit(r.limit). - Offset(r.offset). - GroupBy(r.groupBy...). - OrderBy(r.orderBy...) +func (r *Result) buildSelect() (sqlbuilder.Selector, error) { + res, err := resultFastForward(&result{}, r) + + sel := r.builder().Select(res.fields...). + From(res.table). + Where(filter(res.conds)...). + Limit(res.limit). + Offset(res.offset). + GroupBy(res.groupBy...). + OrderBy(res.orderBy...) + + return sel, err +} + +func (r *Result) buildDelete() (sqlbuilder.Deleter, error) { + res, err := resultFastForward(&result{}, r) + + del := r.builder().DeleteFrom(res.table). + Where(filter(res.conds)...). + Limit(res.limit) + + return del, err +} + +func (r *Result) buildUpdate(values interface{}) (sqlbuilder.Updater, error) { + res, err := resultFastForward(&result{}, r) + + upd := r.builder().Update(res.table). + Set(values). + Where(filter(res.conds)...). + Limit(res.limit) + + return upd, err +} + +func (r *Result) buildCount() (sqlbuilder.Selector, error) { + res, err := resultFastForward(&result{}, r) + + sel := r.builder().Select(db.Raw("count(1) AS _t")). + From(res.table). + Where(filter(res.conds)...). + GroupBy(res.groupBy...). + Limit(1) + + return sel, err +} + +func resultFastForward(in *result, curr *Result) (*result, error) { + if curr == nil || curr.fn == nil { + return in, nil + } + in, err := resultFastForward(in, curr.prev) + if err != nil { + return nil, err + } + err = curr.fn(in) + return in, err } -- GitLab