From c9da34336eeb2f38e9fabbe84f7b5a8a48d01f3e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net>
Date: Tue, 13 Dec 2016 15:27:27 +0000
Subject: [PATCH] Make db.Result immutable

---
 internal/sqladapter/result.go | 233 +++++++++++++++++++++++++---------
 1 file changed, 172 insertions(+), 61 deletions(-)

diff --git a/internal/sqladapter/result.go b/internal/sqladapter/result.go
index 108ec25b..fafb6c2a 100644
--- a/internal/sqladapter/result.go
+++ b/internal/sqladapter/result.go
@@ -23,16 +23,27 @@ package sqladapter
 
 import (
 	"sync"
+	"sync/atomic"
 
 	"upper.io/db.v2"
 	"upper.io/db.v2/lib/sqlbuilder"
 )
 
-// Result represents a delimited set of items bound by a condition.
 type Result struct {
-	b       sqlbuilder.Builder
+	b sqlbuilder.Builder
+
+	err atomic.Value
+
+	iter   sqlbuilder.Iterator
+	iterMu sync.Mutex
+
+	prev *Result
+	fn   func(*result) error
+}
+
+// result represents a delimited set of items bound by a condition.
+type result struct {
 	table   string
-	iter    sqlbuilder.Iterator
 	limit   int
 	offset  int
 	fields  []interface{}
@@ -40,9 +51,6 @@ type Result struct {
 	orderBy []interface{}
 	groupBy []interface{}
 	conds   []interface{}
-	err     error
-	errMu   sync.RWMutex
-	iterMu  sync.Mutex
 }
 
 func filter(conds []interface{}) []interface{} {
@@ -52,93 +60,140 @@ func filter(conds []interface{}) []interface{} {
 // NewResult creates and Results a new Result set on the given table, this set
 // is limited by the given exql.Where conditions.
 func NewResult(b sqlbuilder.Builder, table string, conds []interface{}) *Result {
-	return &Result{
-		b:     b,
-		table: table,
-		conds: conds,
+	r := &Result{
+		b: b,
+	}
+	return r.from(table).where(conds)
+}
+
+func (r *Result) frame(fn func(*result) error) *Result {
+	return &Result{prev: r, fn: fn}
+}
+
+func (r *Result) builder() sqlbuilder.Builder {
+	p := &r
+	for {
+		if (*p).b != nil {
+			return (*p).b
+		}
+		if (*p).prev == nil {
+			return nil
+		}
+		p = &(*p).prev
 	}
 }
 
+func (r *Result) from(table string) *Result {
+	return r.frame(func(res *result) error {
+		res.table = table
+		return nil
+	})
+}
+
+func (r *Result) where(conds []interface{}) *Result {
+	return r.frame(func(res *result) error {
+		res.conds = conds
+		return nil
+	})
+}
+
 func (r *Result) setErr(err error) error {
 	if err == nil {
 		return nil
 	}
-
-	r.errMu.Lock()
-	defer r.errMu.Unlock()
-
-	r.err = err
+	r.err.Store(err)
 	return err
 }
 
 // Err returns the last error that has happened with the result set,
 // nil otherwise
 func (r *Result) Err() error {
-	r.errMu.RLock()
-	defer r.errMu.RUnlock()
-	return r.err
+	if errV := r.err.Load(); errV != nil {
+		return errV.(error)
+	}
+	return nil
 }
 
 // Where sets conditions for the result set.
 func (r *Result) Where(conds ...interface{}) db.Result {
-	r.conds = conds
-	return r
+	return r.where(conds)
 }
 
 // And adds more conditions on top of the existing ones.
 func (r *Result) And(conds ...interface{}) db.Result {
-	r.conds = append(r.conds, conds...)
-	return r
+	return r.frame(func(res *result) error {
+		res.conds = append(res.conds, conds...)
+		return nil
+	})
 }
 
 // Limit determines the maximum limit of Results to be returned.
 func (r *Result) Limit(n int) db.Result {
-	r.limit = n
-	return r
+	return r.frame(func(res *result) error {
+		res.limit = n
+		return nil
+	})
 }
 
 // Offset determines how many documents will be skipped before starting to grab
 // Results.
 func (r *Result) Offset(n int) db.Result {
-	r.offset = n
-	return r
+	return r.frame(func(res *result) error {
+		res.offset = n
+		return nil
+	})
 }
 
 // Group is used to group Results that have the same value in the same column
 // or columns.
 func (r *Result) Group(fields ...interface{}) db.Result {
-	r.groupBy = fields
-	return r
+	return r.frame(func(res *result) error {
+		res.groupBy = fields
+		return nil
+	})
 }
 
 // OrderBy determines sorting of Results according to the provided names. Fields
 // may be prefixed by - (minus) which means descending order, ascending order
 // would be used otherwise.
 func (r *Result) OrderBy(fields ...interface{}) db.Result {
-	r.orderBy = fields
-	return r
+	return r.frame(func(res *result) error {
+		res.orderBy = fields
+		return nil
+	})
 }
 
 // Select determines which fields to return.
 func (r *Result) Select(fields ...interface{}) db.Result {
-	r.fields = fields
-	return r
+	return r.frame(func(res *result) error {
+		res.fields = fields
+		return nil
+	})
 }
 
 // String satisfies fmt.Stringer
 func (r *Result) String() string {
-	return r.buildSelect().String()
+	query, _ := r.buildSelect()
+	return query.String()
 }
 
 // All dumps all Results into a pointer to an slice of structs or maps.
 func (r *Result) All(dst interface{}) error {
-	err := r.buildSelect().Iterator().All(dst)
+	query, err := r.buildSelect()
+	if err != nil {
+		return r.setErr(err)
+	}
+	err = query.Iterator().All(dst)
 	return r.setErr(err)
 }
 
 // One fetches only one Result from the set.
 func (r *Result) One(dst interface{}) error {
-	err := r.buildSelect().Iterator().One(dst)
+	query, err := r.buildSelect()
+	if err != nil {
+		return r.setErr(err)
+	}
+	err = query.Iterator().One(dst)
 	return r.setErr(err)
 }
 
@@ -148,24 +203,33 @@ func (r *Result) Next(dst interface{}) bool {
 	defer r.iterMu.Unlock()
 
 	if r.iter == nil {
-		r.iter = r.buildSelect().Iterator()
+		query, err := r.buildSelect()
+		if err != nil {
+			r.setErr(err)
+			return false
+		}
+		r.iter = query.Iterator()
 	}
+
 	if r.iter.Next(dst) {
 		return true
 	}
+
 	if err := r.iter.Err(); err != db.ErrNoMoreRows {
 		r.setErr(err)
 	}
+
 	return false
 }
 
 // Delete deletes all matching items from the collection.
 func (r *Result) Delete() error {
-	q := r.b.DeleteFrom(r.table).
-		Where(filter(r.conds)...).
-		Limit(r.limit)
+	query, err := r.buildDelete()
+	if err != nil {
+		return r.setErr(err)
+	}
 
-	_, err := q.Exec()
+	_, err = query.Exec()
 	return r.setErr(err)
 }
 
@@ -180,28 +244,26 @@ func (r *Result) Close() error {
 // Update updates matching items from the collection with values of the given
 // map or struct.
 func (r *Result) Update(values interface{}) error {
-	q := r.b.Update(r.table).
-		Set(values).
-		Where(filter(r.conds)...).
-		Limit(r.limit)
+	query, err := r.buildUpdate(values)
+	if err != nil {
+		return r.setErr(err)
+	}
 
-	_, err := q.Exec()
+	_, err = query.Exec()
 	return r.setErr(err)
 }
 
 // Count counts the elements on the set.
 func (r *Result) Count() (uint64, error) {
+	query, err := r.buildCount()
+	if err != nil {
+		return 0, r.setErr(err)
+	}
+
 	counter := struct {
 		Count uint64 `db:"_t"`
 	}{}
-
-	q := r.b.Select(db.Raw("count(1) AS _t")).
-		From(r.table).
-		Where(filter(r.conds)...).
-		GroupBy(r.groupBy...).
-		Limit(1)
-
-	if err := q.Iterator().One(&counter); err != nil {
+	if err := query.Iterator().One(&counter); err != nil {
 		if err == db.ErrNoMoreRows {
 			return 0, nil
 		}
@@ -211,12 +273,61 @@ func (r *Result) Count() (uint64, error) {
 	return counter.Count, nil
 }
 
-func (r *Result) buildSelect() sqlbuilder.Selector {
-	return r.b.Select(r.fields...).
-		From(r.table).
-		Where(filter(r.conds)...).
-		Limit(r.limit).
-		Offset(r.offset).
-		GroupBy(r.groupBy...).
-		OrderBy(r.orderBy...)
+func (r *Result) buildSelect() (sqlbuilder.Selector, error) {
+	res, err := resultFastForward(&result{}, r)
+
+	sel := r.builder().Select(res.fields...).
+		From(res.table).
+		Where(filter(res.conds)...).
+		Limit(res.limit).
+		Offset(res.offset).
+		GroupBy(res.groupBy...).
+		OrderBy(res.orderBy...)
+
+	return sel, err
+}
+
+func (r *Result) buildDelete() (sqlbuilder.Deleter, error) {
+	res, err := resultFastForward(&result{}, r)
+
+	del := r.builder().DeleteFrom(res.table).
+		Where(filter(res.conds)...).
+		Limit(res.limit)
+
+	return del, err
+}
+
+func (r *Result) buildUpdate(values interface{}) (sqlbuilder.Updater, error) {
+	res, err := resultFastForward(&result{}, r)
+
+	upd := r.builder().Update(res.table).
+		Set(values).
+		Where(filter(res.conds)...).
+		Limit(res.limit)
+
+	return upd, err
+}
+
+func (r *Result) buildCount() (sqlbuilder.Selector, error) {
+	res, err := resultFastForward(&result{}, r)
+
+	sel := r.builder().Select(db.Raw("count(1) AS _t")).
+		From(res.table).
+		Where(filter(res.conds)...).
+		GroupBy(res.groupBy...).
+		Limit(1)
+
+	return sel, err
+}
+
+func resultFastForward(in *result, curr *Result) (*result, error) {
+	if curr == nil || curr.fn == nil {
+		return in, nil
+	}
+	in, err := resultFastForward(in, curr.prev)
+	if err != nil {
+		return nil, err
+	}
+	err = curr.fn(in)
+	return in, err
 }
-- 
GitLab