From ec7400c701e15fa29a244754dd9bf3ad9383ef17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net> Date: Thu, 8 Dec 2016 18:43:06 -0600 Subject: [PATCH] Make Selector immutable --- lib/sqlbuilder/builder.go | 8 +- lib/sqlbuilder/builder_test.go | 17 +- lib/sqlbuilder/convert.go | 216 ++++++++++++- lib/sqlbuilder/select.go | 538 +++++++++++++++++++-------------- 4 files changed, 541 insertions(+), 238 deletions(-) diff --git a/lib/sqlbuilder/builder.go b/lib/sqlbuilder/builder.go index b5eb8623..4c5cf387 100644 --- a/lib/sqlbuilder/builder.go +++ b/lib/sqlbuilder/builder.go @@ -26,6 +26,10 @@ var defaultMapOptions = MapOptions{ IncludeNil: false, } +type hasStringer interface { + Stringer() *stringer +} + type hasIsZero interface { IsZero() bool } @@ -312,7 +316,7 @@ func extractArguments(fragments []interface{}) []interface{} { return args } -func columnFragments(template *templateWithUtils, columns []interface{}) ([]exql.Fragment, []interface{}, error) { +func columnFragments(columns []interface{}) ([]exql.Fragment, []interface{}, error) { l := len(columns) f := make([]exql.Fragment, l) args := []interface{}{} @@ -320,7 +324,7 @@ func columnFragments(template *templateWithUtils, columns []interface{}) ([]exql for i := 0; i < l; i++ { switch v := columns[i].(type) { case *selector: - expanded, rawArgs := expandPlaceholders(v.statement().Compile(v.stringer.t), v.Arguments()...) + expanded, rawArgs := expandPlaceholders(v.Compile(), v.Arguments()...) f[i] = exql.RawValue(expanded) args = append(args, rawArgs...) case db.Function: diff --git a/lib/sqlbuilder/builder_test.go b/lib/sqlbuilder/builder_test.go index 22f5d0f4..44d5b60c 100644 --- a/lib/sqlbuilder/builder_test.go +++ b/lib/sqlbuilder/builder_test.go @@ -627,8 +627,22 @@ func TestSelect(t *testing.T) { From("user_access"). Where(db.Cond{"hub_id": 3}) + // Don't reassign sq.And(db.Cond{"role": []int{1, 2}}) + assert.Equal( + `SELECT "user_id" FROM "user_access" WHERE ("hub_id" = $1)`, + sq.String(), + ) + + assert.Equal( + []interface{}{3}, + sq.Arguments(), + ) + + // Reassign + sq = sq.And(db.Cond{"role": []int{1, 2}}) + assert.Equal( `SELECT "user_id" FROM "user_access" WHERE ("hub_id" = $1 AND "role" IN ($2, $3))`, sq.String(), @@ -652,7 +666,7 @@ func TestSelect(t *testing.T) { Where(cond) search := "word" - sel.And(db.Or( + sel = sel.And(db.Or( db.Raw("COALESCE(NULLIF(ml.name,''), a.name) ILIKE ?", fmt.Sprintf("%%%s%%", search)), db.Cond{"a.email ILIKE": fmt.Sprintf("%%%s%%", search)}, )) @@ -666,7 +680,6 @@ func TestSelect(t *testing.T) { []interface{}{3, 1, 2, 4, 5, 6, `%word%`, `%word%`}, sel.Arguments(), ) - } } diff --git a/lib/sqlbuilder/convert.go b/lib/sqlbuilder/convert.go index af03f883..8f512bdc 100644 --- a/lib/sqlbuilder/convert.go +++ b/lib/sqlbuilder/convert.go @@ -45,7 +45,7 @@ func expandPlaceholders(in string, args ...interface{}) (string, []interface{}) case db.RawValue: k, values = t.Raw(), nil case *selector: - k, values = `(`+t.statement().Compile(t.stringer.t)+`)`, t.Arguments() + k, values = `(`+t.Compile()+`)`, t.Arguments() } } else if len(values) == 0 { k = `NULL` @@ -369,3 +369,217 @@ func (tu *templateWithUtils) ToColumnsValuesAndArguments(columnNames []string, c return columns, values, arguments, nil } + +// ToWhereWithArguments converts the given parameters into a exql.Where +// value. +func toWhereWithArguments(term interface{}) (where exql.Where, args []interface{}) { + args = []interface{}{} + + switch t := term.(type) { + case []interface{}: + if len(t) > 0 { + if s, ok := t[0].(string); ok { + if strings.ContainsAny(s, "?") || len(t) == 1 { + s, args = expandPlaceholders(s, t[1:]...) + where.Conditions = []exql.Fragment{exql.RawValue(s)} + } else { + var val interface{} + key := s + if len(t) > 2 { + val = t[1:] + } else { + val = t[1] + } + cv, v := toColumnValues(db.NewConstraint(key, val)) + args = append(args, v...) + for i := range cv.ColumnValues { + where.Conditions = append(where.Conditions, cv.ColumnValues[i]) + } + } + return + } + } + for i := range t { + w, v := toWhereWithArguments(t[i]) + if len(w.Conditions) == 0 { + continue + } + args = append(args, v...) + where.Conditions = append(where.Conditions, w.Conditions...) + } + return + case db.RawValue: + r, v := expandPlaceholders(t.Raw(), t.Arguments()...) + where.Conditions = []exql.Fragment{exql.RawValue(r)} + args = append(args, v...) + return + case db.Constraints: + for _, c := range t.Constraints() { + w, v := toWhereWithArguments(c) + if len(w.Conditions) == 0 { + continue + } + args = append(args, v...) + where.Conditions = append(where.Conditions, w.Conditions...) + } + return + case db.Compound: + var cond exql.Where + + for _, c := range t.Sentences() { + w, v := toWhereWithArguments(c) + if len(w.Conditions) == 0 { + continue + } + args = append(args, v...) + cond.Conditions = append(cond.Conditions, w.Conditions...) + } + + if len(cond.Conditions) > 0 { + var frag exql.Fragment + switch t.Operator() { + case db.OperatorNone, db.OperatorAnd: + q := exql.And(cond) + frag = &q + case db.OperatorOr: + q := exql.Or(cond) + frag = &q + default: + panic(fmt.Sprintf("Unknown type %T", t)) + } + where.Conditions = append(where.Conditions, frag) + } + + return + case db.Constraint: + cv, v := toColumnValues(t) + args = append(args, v...) + where.Conditions = append(where.Conditions, cv.ColumnValues...) + return where, args + } + + panic(fmt.Sprintf("Unknown condition type %T", term)) +} + +func toColumnValues(term interface{}) (cv exql.ColumnValues, args []interface{}) { + args = []interface{}{} + + switch t := term.(type) { + case []interface{}: + l := len(t) + for i := 0; i < l; i++ { + column := t[i].(string) + + if !strings.ContainsAny(column, "=") { + column = fmt.Sprintf("%s = ?", column) + } + + chunks := strings.SplitN(column, "=", 2) + + column = chunks[0] + format := strings.TrimSpace(chunks[1]) + + columnValue := exql.ColumnValue{ + Column: exql.ColumnWithName(column), + Operator: "=", + Value: exql.RawValue(format), + } + + ps := strings.Count(format, "?") + if i+ps < l { + for j := 0; j < ps; j++ { + args = append(args, t[i+j+1]) + } + i = i + ps + } else { + panic(fmt.Sprintf("Format string %q has more placeholders than given arguments.", format)) + } + + cv.ColumnValues = append(cv.ColumnValues, &columnValue) + } + return cv, args + case db.Constraint: + columnValue := exql.ColumnValue{} + + // Guessing operator from input, or using a default one. + if column, ok := t.Key().(string); ok { + chunks := strings.SplitN(strings.TrimSpace(column), ` `, 2) + columnValue.Column = exql.ColumnWithName(chunks[0]) + if len(chunks) > 1 { + columnValue.Operator = chunks[1] + } + } else { + if rawValue, ok := t.Key().(db.RawValue); ok { + columnValue.Column = exql.RawValue(rawValue.Raw()) + args = append(args, rawValue.Arguments()...) + } else { + columnValue.Column = exql.RawValue(fmt.Sprintf("%v", t.Key())) + } + } + + switch value := t.Value().(type) { + case db.Function: + fnName, fnArgs := value.Name(), value.Arguments() + if len(fnArgs) == 0 { + // A function with no arguments. + fnName = fnName + "()" + } else { + // A function with one or more arguments. + fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")" + } + expanded, fnArgs := expandPlaceholders(fnName, fnArgs...) + columnValue.Value = exql.RawValue(expanded) + args = append(args, fnArgs...) + case db.RawValue: + expanded, rawArgs := expandPlaceholders(value.Raw(), value.Arguments()...) + columnValue.Value = exql.RawValue(expanded) + args = append(args, rawArgs...) + default: + v, isSlice := toInterfaceArguments(value) + + if isSlice { + if columnValue.Operator == "" { + columnValue.Operator = sqlInOperator + } + if len(v) > 0 { + // Array value given. + columnValue.Value = exql.RawValue(fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))) + } else { + // Single value given. + columnValue.Value = exql.RawValue(`(NULL)`) + } + args = append(args, v...) + } else { + if v == nil { + // Nil value given. + columnValue.Value = sqlNull + if columnValue.Operator == "" { + columnValue.Operator = sqlIsOperator + } + } else { + columnValue.Value = sqlPlaceholder + args = append(args, v...) + } + } + + } + + // Using guessed operator if no operator was given. + if columnValue.Operator == "" { + columnValue.Operator = sqlDefaultOperator + } + + cv.ColumnValues = append(cv.ColumnValues, &columnValue) + + return cv, args + case db.Constraints: + for _, c := range t.Constraints() { + p, q := toColumnValues(c) + cv.ColumnValues = append(cv.ColumnValues, p.ColumnValues...) + args = append(args, q...) + } + return cv, args + } + + panic(fmt.Sprintf("Unknown term type %T.", term)) +} diff --git a/lib/sqlbuilder/select.go b/lib/sqlbuilder/select.go index 695748ee..b207c081 100644 --- a/lib/sqlbuilder/select.go +++ b/lib/sqlbuilder/select.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "strings" - "sync" "upper.io/db.v2" "upper.io/db.v2/internal/sqladapter/exql" @@ -18,11 +17,8 @@ const ( selectModeDistinct ) -type selector struct { - *stringer - - mode selectMode - builder *sqlBuilder +type selectorQuery struct { + mode selectMode table *exql.Columns tableArgs []interface{} @@ -46,334 +42,410 @@ type selector struct { joins []*exql.Join joinsArgs []interface{} - - mu sync.Mutex - - err error } -func (qs *selector) From(tables ...interface{}) Selector { - f, args, err := columnFragments(qs.builder.t, tables) - if err != nil { - qs.setErr(err) - return qs - } - c := exql.JoinColumns(f...) +func (sq *selectorQuery) and(terms ...interface{}) error { + where, whereArgs := toWhereWithArguments(terms) - qs.mu.Lock() - qs.table = c - qs.tableArgs = args - qs.mu.Unlock() + if sq.where == nil { + sq.where, sq.whereArgs = &exql.Where{}, []interface{}{} + } + sq.where.Append(&where) + sq.whereArgs = append(sq.whereArgs, whereArgs...) - return qs + return nil } -func (qs *selector) Columns(columns ...interface{}) Selector { - f, args, err := columnFragments(qs.builder.t, columns) - if err != nil { - qs.setErr(err) - return qs - } +type selector struct { + builder *sqlBuilder + *stringer - c := exql.JoinColumns(f...) + fn func(*selectorQuery) error + prev *selector +} - qs.mu.Lock() - if qs.columns != nil { - qs.columns.Append(c) - } else { - qs.columns = c +func (sel *selector) Stringer() *stringer { + p := &sel + for { + if (*p).stringer != nil { + return (*p).stringer + } + if (*p).prev == nil { + return nil + } + p = &(*p).prev } - qs.columnsArgs = append(qs.columnsArgs, args...) - qs.mu.Unlock() +} - return qs +func (sel *selector) String() string { + query, err := sel.build() + if err != nil { + return "" + } + q := sel.Stringer().compileAndReplacePlaceholders(query.statement()) + q = reInvisibleChars.ReplaceAllString(q, ` `) + return strings.TrimSpace(q) } -func (qs *selector) Distinct() Selector { - qs.mu.Lock() - qs.mode = selectModeDistinct - qs.mu.Unlock() - return qs +func (sel *selector) frame(fn func(*selectorQuery) error) *selector { + return &selector{prev: sel, fn: fn} } -func (qs *selector) Where(terms ...interface{}) Selector { - qs.mu.Lock() - qs.where, qs.whereArgs = &exql.Where{}, []interface{}{} - qs.mu.Unlock() - return qs.And(terms...) +func (sel *selector) From(tables ...interface{}) Selector { + return sel.frame( + func(sq *selectorQuery) error { + f, args, err := columnFragments(tables) + if err != nil { + return err + } + sq.table = exql.JoinColumns(f...) + sq.tableArgs = args + return nil + }, + ) } -func (qs *selector) And(terms ...interface{}) Selector { - where, whereArgs := qs.builder.t.ToWhereWithArguments(terms) +func (sel *selector) Columns(columns ...interface{}) Selector { + return sel.frame( + func(sq *selectorQuery) error { + f, args, err := columnFragments(columns) + if err != nil { + return err + } - qs.mu.Lock() - if qs.where == nil { - qs.where, qs.whereArgs = &exql.Where{}, []interface{}{} - } - qs.where.Append(&where) - qs.whereArgs = append(qs.whereArgs, whereArgs...) - qs.mu.Unlock() + c := exql.JoinColumns(f...) + + if sq.columns != nil { + sq.columns.Append(c) + } else { + sq.columns = c + } + + sq.columnsArgs = append(sq.columnsArgs, args...) + return nil + }, + ) +} + +func (sel *selector) Distinct() Selector { + return sel.frame(func(sq *selectorQuery) error { + sq.mode = selectModeDistinct + return nil + }) +} - return qs +func (sel *selector) Where(terms ...interface{}) Selector { + return sel.frame(func(sq *selectorQuery) error { + sq.where, sq.whereArgs = &exql.Where{}, []interface{}{} + return sq.and(terms...) + }) } -func (qs *selector) Arguments() []interface{} { - qs.mu.Lock() - defer qs.mu.Unlock() +func (sel *selector) And(terms ...interface{}) Selector { + return sel.frame(func(sq *selectorQuery) error { + sq.and(terms...) + return nil + }) +} +func (sq *selectorQuery) arguments() []interface{} { return joinArguments( - qs.tableArgs, - qs.columnsArgs, - qs.joinsArgs, - qs.whereArgs, - qs.groupByArgs, - qs.orderByArgs, + sq.tableArgs, + sq.columnsArgs, + sq.joinsArgs, + sq.whereArgs, + sq.groupByArgs, + sq.orderByArgs, ) } -func (qs *selector) GroupBy(columns ...interface{}) Selector { - fragments, args, err := columnFragments(qs.builder.t, columns) +func (sel *selector) Arguments() []interface{} { + sq, err := sel.build() if err != nil { - qs.setErr(err) - return qs + return nil + } + return sq.arguments() +} + +func (sq *selectorQuery) statement() *exql.Statement { + stmt := &exql.Statement{ + Type: exql.Select, + Table: sq.table, + Columns: sq.columns, + Limit: sq.limit, + Offset: sq.offset, + Where: sq.where, + OrderBy: sq.orderBy, + GroupBy: sq.groupBy, } - qs.mu.Lock() - if fragments != nil { - qs.groupBy = exql.GroupByColumns(fragments...) + if len(sq.joins) > 0 { + stmt.Joins = exql.JoinConditions(sq.joins...) } - qs.groupByArgs = args - qs.mu.Unlock() - return qs + return stmt } -func (qs *selector) OrderBy(columns ...interface{}) Selector { - var sortColumns exql.SortColumns +func (sel *selector) GroupBy(columns ...interface{}) Selector { + return sel.frame(func(sq *selectorQuery) error { + fragments, args, err := columnFragments(columns) + if err != nil { + return err + } - for i := range columns { - var sort *exql.SortColumn + if fragments != nil { + sq.groupBy = exql.GroupByColumns(fragments...) + } + sq.groupByArgs = args - switch value := columns[i].(type) { - case db.RawValue: - col, args := expandPlaceholders(value.Raw(), value.Arguments()...) - sort = &exql.SortColumn{ - Column: exql.RawValue(col), - } - qs.mu.Lock() - qs.orderByArgs = append(qs.orderByArgs, args...) - qs.mu.Unlock() - case db.Function: - fnName, fnArgs := value.Name(), value.Arguments() - if len(fnArgs) == 0 { - fnName = fnName + "()" - } else { - fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")" - } - expanded, fnArgs := expandPlaceholders(fnName, fnArgs...) - sort = &exql.SortColumn{ - Column: exql.RawValue(expanded), - } - qs.mu.Lock() - qs.orderByArgs = append(qs.orderByArgs, fnArgs...) - qs.mu.Unlock() - case string: - if strings.HasPrefix(value, "-") { + return nil + }) +} + +func (sel *selector) OrderBy(columns ...interface{}) Selector { + return sel.frame(func(sq *selectorQuery) error { + var sortColumns exql.SortColumns + + for i := range columns { + var sort *exql.SortColumn + + switch value := columns[i].(type) { + case db.RawValue: + col, args := expandPlaceholders(value.Raw(), value.Arguments()...) sort = &exql.SortColumn{ - Column: exql.ColumnWithName(value[1:]), - Order: exql.Descendent, + Column: exql.RawValue(col), } - } else { - chunks := strings.SplitN(value, " ", 2) - - order := exql.Ascendent - if len(chunks) > 1 && strings.ToUpper(chunks[1]) == "DESC" { - order = exql.Descendent + sq.orderByArgs = append(sq.orderByArgs, args...) + case db.Function: + fnName, fnArgs := value.Name(), value.Arguments() + if len(fnArgs) == 0 { + fnName = fnName + "()" + } else { + fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")" } - + expanded, fnArgs := expandPlaceholders(fnName, fnArgs...) sort = &exql.SortColumn{ - Column: exql.ColumnWithName(chunks[0]), - Order: order, + Column: exql.RawValue(expanded), + } + sq.orderByArgs = append(sq.orderByArgs, fnArgs...) + case string: + if strings.HasPrefix(value, "-") { + sort = &exql.SortColumn{ + Column: exql.ColumnWithName(value[1:]), + Order: exql.Descendent, + } + } else { + chunks := strings.SplitN(value, " ", 2) + + order := exql.Ascendent + if len(chunks) > 1 && strings.ToUpper(chunks[1]) == "DESC" { + order = exql.Descendent + } + + sort = &exql.SortColumn{ + Column: exql.ColumnWithName(chunks[0]), + Order: order, + } } + default: + return fmt.Errorf("Can't sort by type %T", value) } - default: - qs.setErr(fmt.Errorf("Can't sort by type %T", value)) - return qs + sortColumns.Columns = append(sortColumns.Columns, sort) } - sortColumns.Columns = append(sortColumns.Columns, sort) - } - qs.mu.Lock() - qs.orderBy = &exql.OrderBy{ - SortColumns: &sortColumns, - } - qs.mu.Unlock() - - return qs + sq.orderBy = &exql.OrderBy{ + SortColumns: &sortColumns, + } + return nil + }) } -func (qs *selector) Using(columns ...interface{}) Selector { - qs.mu.Lock() - joins := len(qs.joins) - qs.mu.Unlock() +func (sel *selector) Using(columns ...interface{}) Selector { + return sel.frame(func(sq *selectorQuery) error { - if joins == 0 { - qs.setErr(errors.New(`Cannot use Using() without a preceding Join() expression.`)) - return qs - } + joins := len(sq.joins) - lastJoin := qs.joins[joins-1] - if lastJoin.On != nil { - qs.setErr(errors.New(`Cannot use Using() and On() with the same Join() expression.`)) - return qs - } + if joins == 0 { + return errors.New(`Cannot use Using() without a preceding Join() expression.`) + } - fragments, args, err := columnFragments(qs.builder.t, columns) - if err != nil { - qs.setErr(err) - return qs - } + lastJoin := sq.joins[joins-1] + if lastJoin.On != nil { + return errors.New(`Cannot use Using() and On() with the same Join() expression.`) + } - qs.mu.Lock() - qs.joinsArgs = append(qs.joinsArgs, args...) - lastJoin.Using = exql.UsingColumns(fragments...) - qs.mu.Unlock() + fragments, args, err := columnFragments(columns) + if err != nil { + return err + } + + sq.joinsArgs = append(sq.joinsArgs, args...) + lastJoin.Using = exql.UsingColumns(fragments...) - return qs + return nil + }) } -func (qs *selector) pushJoin(t string, tables []interface{}) Selector { +func (sq *selectorQuery) pushJoin(t string, tables []interface{}) error { tableNames := make([]string, len(tables)) for i := range tables { tableNames[i] = fmt.Sprintf("%s", tables[i]) } - qs.mu.Lock() - if qs.joins == nil { - qs.joins = []*exql.Join{} + if sq.joins == nil { + sq.joins = []*exql.Join{} } - qs.joins = append(qs.joins, + sq.joins = append(sq.joins, &exql.Join{ Type: t, Table: exql.TableWithName(strings.Join(tableNames, ", ")), }, ) - qs.mu.Unlock() - return qs + return nil } -func (qs *selector) FullJoin(tables ...interface{}) Selector { - return qs.pushJoin("FULL", tables) +func (sel *selector) FullJoin(tables ...interface{}) Selector { + return sel.frame(func(sq *selectorQuery) error { + return sq.pushJoin("FULL", tables) + }) } -func (qs *selector) CrossJoin(tables ...interface{}) Selector { - return qs.pushJoin("CROSS", tables) +func (sel *selector) CrossJoin(tables ...interface{}) Selector { + return sel.frame(func(sq *selectorQuery) error { + return sq.pushJoin("CROSS", tables) + }) } -func (qs *selector) RightJoin(tables ...interface{}) Selector { - return qs.pushJoin("RIGHT", tables) +func (sel *selector) RightJoin(tables ...interface{}) Selector { + return sel.frame(func(sq *selectorQuery) error { + return sq.pushJoin("RIGHT", tables) + }) } -func (qs *selector) LeftJoin(tables ...interface{}) Selector { - return qs.pushJoin("LEFT", tables) +func (sel *selector) LeftJoin(tables ...interface{}) Selector { + return sel.frame(func(sq *selectorQuery) error { + return sq.pushJoin("LEFT", tables) + }) } -func (qs *selector) Join(tables ...interface{}) Selector { - return qs.pushJoin("", tables) +func (sel *selector) Join(tables ...interface{}) Selector { + return sel.frame(func(sq *selectorQuery) error { + return sq.pushJoin("", tables) + }) } -func (qs *selector) On(terms ...interface{}) Selector { - qs.mu.Lock() - joins := len(qs.joins) - qs.mu.Unlock() +func (sel *selector) On(terms ...interface{}) Selector { + return sel.frame(func(sq *selectorQuery) error { + joins := len(sq.joins) - if joins == 0 { - qs.setErr(errors.New(`Cannot use On() without a preceding Join() expression.`)) - return qs - } + if joins == 0 { + return errors.New(`Cannot use On() without a preceding Join() expression.`) + } - lastJoin := qs.joins[joins-1] - if lastJoin.On != nil { - qs.setErr(errors.New(`Cannot use Using() and On() with the same Join() expression.`)) - return qs - } + lastJoin := sq.joins[joins-1] + if lastJoin.On != nil { + return errors.New(`Cannot use Using() and On() with the same Join() expression.`) + } - w, a := qs.builder.t.ToWhereWithArguments(terms) - o := exql.On(w) + w, a := toWhereWithArguments(terms) + o := exql.On(w) - lastJoin.On = &o + lastJoin.On = &o - qs.mu.Lock() - qs.joinsArgs = append(qs.joinsArgs, a...) - qs.mu.Unlock() + sq.joinsArgs = append(sq.joinsArgs, a...) - return qs + return nil + }) } -func (qs *selector) Limit(n int) Selector { - qs.mu.Lock() - qs.limit = exql.Limit(n) - qs.mu.Unlock() - return qs +func (sel *selector) Limit(n int) Selector { + return sel.frame(func(sq *selectorQuery) error { + sq.limit = exql.Limit(n) + return nil + }) } -func (qs *selector) Offset(n int) Selector { - qs.mu.Lock() - qs.offset = exql.Offset(n) - qs.mu.Unlock() - return qs +func (sel *selector) Offset(n int) Selector { + return sel.frame(func(sq *selectorQuery) error { + sq.offset = exql.Offset(n) + return nil + }) } -func (qs *selector) statement() *exql.Statement { - return &exql.Statement{ - Type: exql.Select, - Table: qs.table, - Columns: qs.columns, - Limit: qs.limit, - Offset: qs.offset, - Joins: exql.JoinConditions(qs.joins...), - Where: qs.where, - OrderBy: qs.orderBy, - GroupBy: qs.groupBy, - } +func (sel *selector) As(alias string) Selector { + return sel.frame(func(sq *selectorQuery) error { + if sq.table == nil { + return errors.New("Cannot use As() without a preceding From() expression") + } + last := len(sq.table.Columns) - 1 + if raw, ok := sq.table.Columns[last].(*exql.Raw); ok { + sq.table.Columns[last] = exql.RawValue("(" + raw.Value + ") AS " + exql.ColumnWithName(alias).Compile(sel.Stringer().t)) + } + return nil + }) } -func (qs *selector) Query() (*sql.Rows, error) { - return qs.builder.sess.StatementQuery(qs.statement(), qs.Arguments()...) +func (sel *selector) statement() *exql.Statement { + sq, _ := sel.build() + return sq.statement() } -func (qs *selector) As(alias string) Selector { - if qs.table == nil { - qs.setErr(errors.New("Cannot use As() without a preceding From() expression")) - return qs - } - last := len(qs.table.Columns) - 1 - if raw, ok := qs.table.Columns[last].(*exql.Raw); ok { - qs.table.Columns[last] = exql.RawValue("(" + raw.Value + ") AS " + exql.ColumnWithName(alias).Compile(qs.stringer.t)) +func (sel *selector) QueryRow() (*sql.Row, error) { + sq, err := sel.build() + if err != nil { + return nil, err } - return qs + + return sel.builder.sess.StatementQueryRow(sq.statement(), sq.arguments()...) } -func (qs *selector) QueryRow() (*sql.Row, error) { - return qs.builder.sess.StatementQueryRow(qs.statement(), qs.Arguments()...) +func (sel *selector) Query() (*sql.Rows, error) { + sq, err := sel.build() + if err != nil { + return nil, err + } + return sel.builder.sess.StatementQuery(sq.statement(), sq.arguments()...) } -func (qs *selector) Iterator() Iterator { - rows, err := qs.builder.sess.StatementQuery(qs.statement(), qs.Arguments()...) +func (sel *selector) Iterator() Iterator { + sq, err := sel.build() + if err != nil { + return &iterator{nil, err} + } + + rows, err := sel.builder.sess.StatementQuery(sq.statement(), sq.arguments()...) return &iterator{rows, err} } -func (qs *selector) All(destSlice interface{}) error { - return qs.Iterator().All(destSlice) +func (sel *selector) All(destSlice interface{}) error { + return sel.Iterator().All(destSlice) +} + +func (sel *selector) One(dest interface{}) error { + return sel.Iterator().One(dest) } -func (qs *selector) One(dest interface{}) error { - return qs.Iterator().One(dest) +func (sel *selector) build() (*selectorQuery, error) { + sq, err := selectorFastForward(&selectorQuery{}, sel) + if err != nil { + return nil, err + } + return sq, nil } -func (qs *selector) setErr(err error) { - qs.mu.Lock() - qs.err = err - qs.mu.Unlock() +func (sel *selector) Compile() string { + return sel.statement().Compile(sel.Stringer().t) +} + +func selectorFastForward(in *selectorQuery, curr *selector) (*selectorQuery, error) { + if curr == nil || curr.fn == nil { + return in, nil + } + in, err := selectorFastForward(in, curr.prev) + if err != nil { + return nil, err + } + err = curr.fn(in) + return in, err } -- GitLab