diff --git a/lib/sqlbuilder/builder.go b/lib/sqlbuilder/builder.go index 4c5cf387f13cfef80e71b26a42d3449fe2e88e85..649b6670d770293fbccc429b7e3e2719fe42fbb0 100644 --- a/lib/sqlbuilder/builder.go +++ b/lib/sqlbuilder/builder.go @@ -164,11 +164,10 @@ func (b *sqlBuilder) Select(columns ...interface{}) Selector { func (b *sqlBuilder) InsertInto(table string) Inserter { qi := &inserter{ builder: b, - table: table, } qi.stringer = &stringer{qi, b.t.Template} - return qi + return qi.Into(table) } func (b *sqlBuilder) DeleteFrom(table string) Deleter { diff --git a/lib/sqlbuilder/convert.go b/lib/sqlbuilder/convert.go index 8f512bdcb5adf33538731ea383a8c0f23ce78bc6..f4345aeaca21fcca92c4dcdfb9c345bff80e6649 100644 --- a/lib/sqlbuilder/convert.go +++ b/lib/sqlbuilder/convert.go @@ -67,97 +67,6 @@ func expandPlaceholders(in string, args ...interface{}) (string, []interface{}) return in, argx } -// ToWhereWithArguments converts the given parameters into a exql.Where -// value. -func (tu *templateWithUtils) ToWhereWithArguments(term interface{}) (where exql.Where, args []interface{}) { - args = []interface{}{} - - switch t := term.(type) { - case []interface{}: - if len(t) > 0 { - if s, ok := t[0].(string); ok { - if strings.ContainsAny(s, "?") || len(t) == 1 { - s, args = expandPlaceholders(s, t[1:]...) - where.Conditions = []exql.Fragment{exql.RawValue(s)} - } else { - var val interface{} - key := s - if len(t) > 2 { - val = t[1:] - } else { - val = t[1] - } - cv, v := tu.ToColumnValues(db.NewConstraint(key, val)) - args = append(args, v...) - for i := range cv.ColumnValues { - where.Conditions = append(where.Conditions, cv.ColumnValues[i]) - } - } - return - } - } - for i := range t { - w, v := tu.ToWhereWithArguments(t[i]) - if len(w.Conditions) == 0 { - continue - } - args = append(args, v...) - where.Conditions = append(where.Conditions, w.Conditions...) - } - return - case db.RawValue: - r, v := expandPlaceholders(t.Raw(), t.Arguments()...) - where.Conditions = []exql.Fragment{exql.RawValue(r)} - args = append(args, v...) - return - case db.Constraints: - for _, c := range t.Constraints() { - w, v := tu.ToWhereWithArguments(c) - if len(w.Conditions) == 0 { - continue - } - args = append(args, v...) - where.Conditions = append(where.Conditions, w.Conditions...) - } - return - case db.Compound: - var cond exql.Where - - for _, c := range t.Sentences() { - w, v := tu.ToWhereWithArguments(c) - if len(w.Conditions) == 0 { - continue - } - args = append(args, v...) - cond.Conditions = append(cond.Conditions, w.Conditions...) - } - - if len(cond.Conditions) > 0 { - var frag exql.Fragment - switch t.Operator() { - case db.OperatorNone, db.OperatorAnd: - q := exql.And(cond) - frag = &q - case db.OperatorOr: - q := exql.Or(cond) - frag = &q - default: - panic(fmt.Sprintf("Unknown type %T", t)) - } - where.Conditions = append(where.Conditions, frag) - } - - return - case db.Constraint: - cv, v := tu.ToColumnValues(t) - args = append(args, v...) - where.Conditions = append(where.Conditions, cv.ColumnValues...) - return where, args - } - - panic(fmt.Sprintf("Unknown condition type %T", term)) -} - func (tu *templateWithUtils) PlaceholderValue(in interface{}) (exql.Fragment, []interface{}) { switch t := in.(type) { case db.RawValue: @@ -207,137 +116,9 @@ func toInterfaceArguments(value interface{}) (args []interface{}, isSlice bool) return []interface{}{value}, false } -// ToColumnValues converts the given conditions into a exql.ColumnValues struct. -func (tu *templateWithUtils) ToColumnValues(term interface{}) (cv exql.ColumnValues, args []interface{}) { - args = []interface{}{} - - switch t := term.(type) { - case []interface{}: - l := len(t) - for i := 0; i < l; i++ { - column := t[i].(string) - - if !strings.ContainsAny(column, "=") { - column = fmt.Sprintf("%s = ?", column) - } - - chunks := strings.SplitN(column, "=", 2) - - column = chunks[0] - format := strings.TrimSpace(chunks[1]) - - columnValue := exql.ColumnValue{ - Column: exql.ColumnWithName(column), - Operator: "=", - Value: exql.RawValue(format), - } - - ps := strings.Count(format, "?") - if i+ps < l { - for j := 0; j < ps; j++ { - args = append(args, t[i+j+1]) - } - i = i + ps - } else { - panic(fmt.Sprintf("Format string %q has more placeholders than given arguments.", format)) - } - - cv.ColumnValues = append(cv.ColumnValues, &columnValue) - } - return cv, args - case db.Constraint: - columnValue := exql.ColumnValue{} - - // Guessing operator from input, or using a default one. - if column, ok := t.Key().(string); ok { - chunks := strings.SplitN(strings.TrimSpace(column), ` `, 2) - columnValue.Column = exql.ColumnWithName(chunks[0]) - if len(chunks) > 1 { - columnValue.Operator = chunks[1] - } - } else { - if rawValue, ok := t.Key().(db.RawValue); ok { - columnValue.Column = exql.RawValue(rawValue.Raw()) - args = append(args, rawValue.Arguments()...) - } else { - columnValue.Column = exql.RawValue(fmt.Sprintf("%v", t.Key())) - } - } - - switch value := t.Value().(type) { - case db.Function: - fnName, fnArgs := value.Name(), value.Arguments() - if len(fnArgs) == 0 { - // A function with no arguments. - fnName = fnName + "()" - } else { - // A function with one or more arguments. - fnName = fnName + "(?" + strings.Repeat("?, ", len(fnArgs)-1) + ")" - } - expanded, fnArgs := expandPlaceholders(fnName, fnArgs...) - columnValue.Value = exql.RawValue(expanded) - args = append(args, fnArgs...) - case db.RawValue: - expanded, rawArgs := expandPlaceholders(value.Raw(), value.Arguments()...) - columnValue.Value = exql.RawValue(expanded) - args = append(args, rawArgs...) - default: - v, isSlice := toInterfaceArguments(value) - - if isSlice { - if columnValue.Operator == "" { - columnValue.Operator = sqlInOperator - } - if len(v) > 0 { - // Array value given. - columnValue.Value = exql.RawValue(fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))) - } else { - // Single value given. - columnValue.Value = exql.RawValue(`(NULL)`) - } - args = append(args, v...) - } else { - if v == nil { - // Nil value given. - columnValue.Value = sqlNull - if columnValue.Operator == "" { - columnValue.Operator = sqlIsOperator - } - } else { - columnValue.Value = sqlPlaceholder - args = append(args, v...) - } - } - - } - - // Using guessed operator if no operator was given. - if columnValue.Operator == "" { - if tu.DefaultOperator != "" { - columnValue.Operator = tu.DefaultOperator - } else { - columnValue.Operator = sqlDefaultOperator - } - } - - cv.ColumnValues = append(cv.ColumnValues, &columnValue) - - return cv, args - case db.Constraints: - for _, c := range t.Constraints() { - p, q := tu.ToColumnValues(c) - cv.ColumnValues = append(cv.ColumnValues, p.ColumnValues...) - args = append(args, q...) - } - return cv, args - } - - panic(fmt.Sprintf("Unknown term type %T.", term)) -} - -// ToColumnsValuesAndArguments maps the given columnNames and columnValues into +// toColumnsValuesAndArguments maps the given columnNames and columnValues into // expr's Columns and Values, it also extracts and returns query arguments. -func (tu *templateWithUtils) ToColumnsValuesAndArguments(columnNames []string, columnValues []interface{}) (*exql.Columns, *exql.Values, []interface{}, error) { +func toColumnsValuesAndArguments(columnNames []string, columnValues []interface{}) (*exql.Columns, *exql.Values, []interface{}, error) { var arguments []interface{} columns := new(exql.Columns) @@ -370,7 +151,7 @@ func (tu *templateWithUtils) ToColumnsValuesAndArguments(columnNames []string, c return columns, values, arguments, nil } -// ToWhereWithArguments converts the given parameters into a exql.Where +// toWhereWithArguments converts the given parameters into a exql.Where // value. func toWhereWithArguments(term interface{}) (where exql.Where, args []interface{}) { args = []interface{}{} diff --git a/lib/sqlbuilder/delete.go b/lib/sqlbuilder/delete.go index d417784d552ae87d69d2ea43b8b9ba16500389a9..779262e28c24e0135e84333d14e951c5ed4c7957 100644 --- a/lib/sqlbuilder/delete.go +++ b/lib/sqlbuilder/delete.go @@ -16,7 +16,7 @@ type deleter struct { } func (qd *deleter) Where(terms ...interface{}) Deleter { - where, arguments := qd.builder.t.ToWhereWithArguments(terms) + where, arguments := toWhereWithArguments(terms) qd.where = &where qd.arguments = append(qd.arguments, arguments...) return qd diff --git a/lib/sqlbuilder/insert.go b/lib/sqlbuilder/insert.go index bef31ae34663812afa32a795a4a4f67784694a87..19fffd66abdd15db4e40cd6eb04e8fcbf42b5d7c 100644 --- a/lib/sqlbuilder/insert.go +++ b/lib/sqlbuilder/insert.go @@ -2,13 +2,12 @@ package sqlbuilder import ( "database/sql" + "strings" "upper.io/db.v2/internal/sqladapter/exql" ) -type inserter struct { - *stringer - builder *sqlBuilder +type inserterQuery struct { table string values []*exql.Values returning []exql.Fragment @@ -17,21 +16,28 @@ type inserter struct { extra string } -func (qi *inserter) clone() *inserter { - clone := &inserter{} - *clone = *qi - return clone -} +func (iq *inserterQuery) statement() *exql.Statement { + stmt := &exql.Statement{ + Type: exql.Insert, + Table: exql.TableWithName(iq.table), + } -func (qi *inserter) Batch(n int) *BatchInserter { - return newBatchInserter(qi.clone(), n) -} + if len(iq.values) > 0 { + stmt.Values = exql.JoinValueGroups(iq.values...) + } + + if len(iq.columns) > 0 { + stmt.Columns = exql.JoinColumns(iq.columns...) + } + + if len(iq.returning) > 0 { + stmt.Returning = exql.ReturningColumns(iq.returning...) + } -func (qi *inserter) Arguments() []interface{} { - return qi.arguments + return stmt } -func (qi *inserter) columnsToFragments(dst *[]exql.Fragment, columns []string) error { +func columnsToFragments(dst *[]exql.Fragment, columns []string) error { l := len(columns) f := make([]exql.Fragment, l) for i := 0; i < l; i++ { @@ -41,81 +47,167 @@ func (qi *inserter) columnsToFragments(dst *[]exql.Fragment, columns []string) e return nil } -func (qi *inserter) Returning(columns ...string) Inserter { - qi.columnsToFragments(&qi.returning, columns) - return qi +type inserter struct { + builder *sqlBuilder + *stringer + + fn func(*inserterQuery) error + prev *inserter +} + +func (ins *inserter) Stringer() *stringer { + p := &ins + for { + if (*p).stringer != nil { + return (*p).stringer + } + if (*p).prev == nil { + return nil + } + p = &(*p).prev + } +} + +func (ins *inserter) String() string { + query, err := ins.build() + if err != nil { + return "" + } + q := ins.Stringer().compileAndReplacePlaceholders(query.statement()) + q = reInvisibleChars.ReplaceAllString(q, ` `) + return strings.TrimSpace(q) +} + +func (ins *inserter) frame(fn func(*inserterQuery) error) *inserter { + return &inserter{prev: ins, fn: fn} } -func (qi *inserter) Exec() (sql.Result, error) { - return qi.builder.sess.StatementExec(qi.statement(), qi.arguments...) +func (ins *inserter) clone() *inserter { + clone := &inserter{} + *clone = *ins + return clone +} + +func (ins *inserter) Batch(n int) *BatchInserter { + return newBatchInserter(ins.clone(), n) +} + +func (ins *inserter) Arguments() []interface{} { + iq, err := ins.build() + if err != nil { + return nil + } + return iq.arguments } -func (qi *inserter) Query() (*sql.Rows, error) { - return qi.builder.sess.StatementQuery(qi.statement(), qi.arguments...) +func (ins *inserter) Returning(columns ...string) Inserter { + return ins.frame(func(iq *inserterQuery) error { + columnsToFragments(&iq.returning, columns) + return nil + }) } -func (qi *inserter) QueryRow() (*sql.Row, error) { - return qi.builder.sess.StatementQueryRow(qi.statement(), qi.arguments...) +func (ins *inserter) Exec() (sql.Result, error) { + iq, err := ins.build() + if err != nil { + return nil, err + } + return ins.builder.sess.StatementExec(iq.statement(), iq.arguments...) +} + +func (ins *inserter) Query() (*sql.Rows, error) { + iq, err := ins.build() + if err != nil { + return nil, err + } + return ins.builder.sess.StatementQuery(iq.statement(), iq.arguments...) +} + +func (ins *inserter) QueryRow() (*sql.Row, error) { + iq, err := ins.build() + if err != nil { + return nil, err + } + return ins.builder.sess.StatementQueryRow(iq.statement(), iq.arguments...) } -func (qi *inserter) Iterator() Iterator { - rows, err := qi.builder.sess.StatementQuery(qi.statement(), qi.arguments...) +func (ins *inserter) Iterator() Iterator { + rows, err := ins.Query() return &iterator{rows, err} } -func (qi *inserter) Columns(columns ...string) Inserter { - qi.columnsToFragments(&qi.columns, columns) - return qi +func (ins *inserter) Into(table string) Inserter { + return ins.frame(func(iq *inserterQuery) error { + iq.table = table + return nil + }) } -func (qi *inserter) Values(values ...interface{}) Inserter { - if len(values) == 1 { - ff, vv, err := Map(values[0], &MapOptions{IncludeZeroed: true, IncludeNil: true}) - if err == nil { - columns, vals, arguments, _ := qi.builder.t.ToColumnsValuesAndArguments(ff, vv) +func (ins *inserter) Columns(columns ...string) Inserter { + return ins.frame(func(iq *inserterQuery) error { + columnsToFragments(&iq.columns, columns) + return nil + }) +} - qi.arguments = append(qi.arguments, arguments...) - qi.values = append(qi.values, vals) - if len(qi.columns) == 0 { - for _, c := range columns.Columns { - qi.columns = append(qi.columns, c) +func (ins *inserter) Values(values ...interface{}) Inserter { + return ins.frame(func(iq *inserterQuery) error { + if len(values) == 1 { + ff, vv, err := Map(values[0], &MapOptions{IncludeZeroed: true, IncludeNil: true}) + if err == nil { + columns, vals, arguments, _ := toColumnsValuesAndArguments(ff, vv) + + iq.arguments = append(iq.arguments, arguments...) + iq.values = append(iq.values, vals) + if len(iq.columns) == 0 { + for _, c := range columns.Columns { + iq.columns = append(iq.columns, c) + } } + return nil } - return qi } - } - if len(qi.columns) == 0 || len(values) == len(qi.columns) { - qi.arguments = append(qi.arguments, values...) + if len(iq.columns) == 0 || len(values) == len(iq.columns) { + iq.arguments = append(iq.arguments, values...) - l := len(values) - placeholders := make([]exql.Fragment, l) - for i := 0; i < l; i++ { - placeholders[i] = exql.RawValue(`?`) + l := len(values) + placeholders := make([]exql.Fragment, l) + for i := 0; i < l; i++ { + placeholders[i] = exql.RawValue(`?`) + } + iq.values = append(iq.values, exql.NewValueGroup(placeholders...)) } - qi.values = append(qi.values, exql.NewValueGroup(placeholders...)) - } - return qi + return nil + }) } -func (qi *inserter) statement() *exql.Statement { - stmt := &exql.Statement{ - Type: exql.Insert, - Table: exql.TableWithName(qi.table), - } +func (ins *inserter) statement() *exql.Statement { + iq, _ := ins.build() + return iq.statement() +} - if len(qi.values) > 0 { - stmt.Values = exql.JoinValueGroups(qi.values...) +func (ins *inserter) build() (*inserterQuery, error) { + iq, err := inserterFastForward(&inserterQuery{}, ins) + if err != nil { + return nil, err } + return iq, nil +} - if len(qi.columns) > 0 { - stmt.Columns = exql.JoinColumns(qi.columns...) - } +func (ins *inserter) Compile() string { + return ins.statement().Compile(ins.Stringer().t) +} - if len(qi.returning) > 0 { - stmt.Returning = exql.ReturningColumns(qi.returning...) +func inserterFastForward(in *inserterQuery, curr *inserter) (*inserterQuery, error) { + if curr == nil || curr.fn == nil { + return in, nil } - - return stmt + in, err := inserterFastForward(in, curr.prev) + if err != nil { + return nil, err + } + err = curr.fn(in) + return in, err } diff --git a/lib/sqlbuilder/select.go b/lib/sqlbuilder/select.go index b207c08112125c8b3a4bd87e5244f52c1c60665b..270e5081bdbfb05bfdc33f1dbf5a6897964df639 100644 --- a/lib/sqlbuilder/select.go +++ b/lib/sqlbuilder/select.go @@ -56,6 +56,55 @@ func (sq *selectorQuery) and(terms ...interface{}) error { return nil } +func (sq *selectorQuery) arguments() []interface{} { + return joinArguments( + sq.tableArgs, + sq.columnsArgs, + sq.joinsArgs, + sq.whereArgs, + sq.groupByArgs, + sq.orderByArgs, + ) +} + +func (sq *selectorQuery) statement() *exql.Statement { + stmt := &exql.Statement{ + Type: exql.Select, + Table: sq.table, + Columns: sq.columns, + Limit: sq.limit, + Offset: sq.offset, + Where: sq.where, + OrderBy: sq.orderBy, + GroupBy: sq.groupBy, + } + + if len(sq.joins) > 0 { + stmt.Joins = exql.JoinConditions(sq.joins...) + } + + return stmt +} + +func (sq *selectorQuery) pushJoin(t string, tables []interface{}) error { + tableNames := make([]string, len(tables)) + for i := range tables { + tableNames[i] = fmt.Sprintf("%s", tables[i]) + } + + if sq.joins == nil { + sq.joins = []*exql.Join{} + } + sq.joins = append(sq.joins, + &exql.Join{ + Type: t, + Table: exql.TableWithName(strings.Join(tableNames, ", ")), + }, + ) + + return nil +} + type selector struct { builder *sqlBuilder *stringer @@ -148,17 +197,6 @@ func (sel *selector) And(terms ...interface{}) Selector { }) } -func (sq *selectorQuery) arguments() []interface{} { - return joinArguments( - sq.tableArgs, - sq.columnsArgs, - sq.joinsArgs, - sq.whereArgs, - sq.groupByArgs, - sq.orderByArgs, - ) -} - func (sel *selector) Arguments() []interface{} { sq, err := sel.build() if err != nil { @@ -167,25 +205,6 @@ func (sel *selector) Arguments() []interface{} { return sq.arguments() } -func (sq *selectorQuery) statement() *exql.Statement { - stmt := &exql.Statement{ - Type: exql.Select, - Table: sq.table, - Columns: sq.columns, - Limit: sq.limit, - Offset: sq.offset, - Where: sq.where, - OrderBy: sq.orderBy, - GroupBy: sq.groupBy, - } - - if len(sq.joins) > 0 { - stmt.Joins = exql.JoinConditions(sq.joins...) - } - - return stmt -} - func (sel *selector) GroupBy(columns ...interface{}) Selector { return sel.frame(func(sq *selectorQuery) error { fragments, args, err := columnFragments(columns) @@ -286,25 +305,6 @@ func (sel *selector) Using(columns ...interface{}) Selector { }) } -func (sq *selectorQuery) pushJoin(t string, tables []interface{}) error { - tableNames := make([]string, len(tables)) - for i := range tables { - tableNames[i] = fmt.Sprintf("%s", tables[i]) - } - - if sq.joins == nil { - sq.joins = []*exql.Join{} - } - sq.joins = append(sq.joins, - &exql.Join{ - Type: t, - Table: exql.TableWithName(strings.Join(tableNames, ", ")), - }, - ) - - return nil -} - func (sel *selector) FullJoin(tables ...interface{}) Selector { return sel.frame(func(sq *selectorQuery) error { return sq.pushJoin("FULL", tables) diff --git a/lib/sqlbuilder/update.go b/lib/sqlbuilder/update.go index cb1b09bd3bc74113063c2062b78ccbdae7528e3f..b7cfd83141d3838fba23ee3a47fe0576bd6a51e8 100644 --- a/lib/sqlbuilder/update.go +++ b/lib/sqlbuilder/update.go @@ -46,7 +46,7 @@ func (qu *updater) Set(terms ...interface{}) Updater { qu.columnValues.Insert(cvs...) qu.columnValuesArgs = append(qu.columnValuesArgs, args...) } else if len(terms) > 1 { - cv, arguments := qu.builder.t.ToColumnValues(terms) + cv, arguments := toColumnValues(terms) qu.columnValues.Insert(cv.ColumnValues...) qu.columnValuesArgs = append(qu.columnValuesArgs, arguments...) } @@ -65,7 +65,7 @@ func (qu *updater) Arguments() []interface{} { } func (qu *updater) Where(terms ...interface{}) Updater { - where, arguments := qu.builder.t.ToWhereWithArguments(terms) + where, arguments := toWhereWithArguments(terms) qu.where = &where qu.whereArgs = append(qu.whereArgs, arguments...) return qu