diff --git a/postgresql/builder.go b/builder/builder.go similarity index 84% rename from postgresql/builder.go rename to builder/builder.go index f981159645ce5a0ba25fb7bf64eafaf65b1fe6d1..568bd6b485b1c37c720d7c05ff73fed14cbf7ab5 100644 --- a/postgresql/builder.go +++ b/builder/builder.go @@ -1,4 +1,4 @@ -package postgresql +package builder import ( "database/sql" @@ -6,6 +6,7 @@ import ( "fmt" "github.com/jmoiron/sqlx" "regexp" + "strconv" "strings" "upper.io/db" "upper.io/db/util/sqlgen" @@ -23,8 +24,15 @@ const ( selectModeDistinct ) +type sqlDatabase interface { + Query(stmt *sqlgen.Statement, args ...interface{}) (*sqlx.Rows, error) + QueryRow(stmt *sqlgen.Statement, args ...interface{}) (*sqlx.Row, error) + Exec(stmt *sqlgen.Statement, args ...interface{}) (sql.Result, error) +} + type Builder struct { - sess *database + sess sqlDatabase + t *sqlutil.TemplateWithUtils } func (b *Builder) SelectAllFrom(table string) db.QuerySelector { @@ -33,12 +41,12 @@ func (b *Builder) SelectAllFrom(table string) db.QuerySelector { table: table, } - qs.stringer = &stringer{qs} + qs.stringer = &stringer{qs, b.t.Template} return qs } func (b *Builder) Select(columns ...interface{}) db.QuerySelector { - f, err := columnFragments(columns) + f, err := columnFragments(b.t, columns) qs := &QuerySelector{ builder: b, @@ -46,7 +54,7 @@ func (b *Builder) Select(columns ...interface{}) db.QuerySelector { err: err, } - qs.stringer = &stringer{qs} + qs.stringer = &stringer{qs, b.t.Template} return qs } @@ -56,7 +64,7 @@ func (b *Builder) InsertInto(table string) db.QueryInserter { table: table, } - qi.stringer = &stringer{qi} + qi.stringer = &stringer{qi, b.t.Template} return qi } @@ -66,7 +74,7 @@ func (b *Builder) DeleteFrom(table string) db.QueryDeleter { table: table, } - qd.stringer = &stringer{qd} + qd.stringer = &stringer{qd, b.t.Template} return qd } @@ -76,7 +84,7 @@ func (b *Builder) Update(table string) db.QueryUpdater { table: table, } - qu.stringer = &stringer{qu} + qu.stringer = &stringer{qu, b.t.Template} return qu } @@ -144,7 +152,7 @@ type QueryDeleter struct { } func (qd *QueryDeleter) Where(terms ...interface{}) db.QueryDeleter { - where, arguments := template.ToWhereWithArguments(terms) + where, arguments := qd.builder.t.ToWhereWithArguments(terms) qd.where = &where qd.arguments = append(qd.arguments, arguments...) return qd @@ -187,14 +195,14 @@ type QueryUpdater struct { } func (qu *QueryUpdater) Set(terms ...interface{}) db.QueryUpdater { - cv, arguments := template.ToColumnValues(terms) + cv, arguments := qu.builder.t.ToColumnValues(terms) qu.columnValues = &cv qu.arguments = append(qu.arguments, arguments...) return qu } func (qu *QueryUpdater) Where(terms ...interface{}) db.QueryUpdater { - where, arguments := template.ToWhereWithArguments(terms) + where, arguments := qu.builder.t.ToWhereWithArguments(terms) qu.where = &where qu.arguments = append(qu.arguments, arguments...) return qu @@ -255,7 +263,7 @@ func (qs *QuerySelector) Distinct() db.QuerySelector { } func (qs *QuerySelector) Where(terms ...interface{}) db.QuerySelector { - where, arguments := template.ToWhereWithArguments(terms) + where, arguments := qs.builder.t.ToWhereWithArguments(terms) qs.where = &where qs.arguments = append(qs.arguments, arguments...) return qs @@ -263,7 +271,7 @@ func (qs *QuerySelector) Where(terms ...interface{}) db.QuerySelector { func (qs *QuerySelector) GroupBy(columns ...interface{}) db.QuerySelector { var fragments []sqlgen.Fragment - fragments, qs.err = columnFragments(columns) + fragments, qs.err = columnFragments(qs.builder.t, columns) if fragments != nil { qs.groupBy = sqlgen.GroupByColumns(fragments...) } @@ -315,7 +323,7 @@ func (qs *QuerySelector) Using(columns ...interface{}) db.QuerySelector { return qs } - fragments, err := columnFragments(columns) + fragments, err := columnFragments(qs.builder.t, columns) if err != nil { qs.err = err return qs @@ -378,7 +386,7 @@ func (qs *QuerySelector) On(terms ...interface{}) db.QuerySelector { return qs } - w, a := template.ToWhereWithArguments(terms) + w, a := qs.builder.t.ToWhereWithArguments(terms) o := sqlgen.On(w) lastJoin.On = &o @@ -502,12 +510,25 @@ func (qs *QuerySelector) Next(dst interface{}) bool { return true } -func columnFragments(columns []interface{}) ([]sqlgen.Fragment, error) { +func columnFragments(template *sqlutil.TemplateWithUtils, columns []interface{}) ([]sqlgen.Fragment, error) { l := len(columns) f := make([]sqlgen.Fragment, l) for i := 0; i < l; i++ { switch v := columns[i].(type) { + case db.Func: + var s string + a := template.ToInterfaceArguments(v.Args) + if len(a) == 0 { + s = fmt.Sprintf(`%s()`, v.Name) + } else { + ss := make([]string, 0, len(a)) + for j := range a { + ss = append(ss, fmt.Sprintf(`%v`, a[j])) + } + s = fmt.Sprintf(`%s(%s)`, v.Name, strings.Join(ss, `, `)) + } + f[i] = sqlgen.RawValue(s) case db.Raw: f[i] = sqlgen.RawValue(fmt.Sprintf("%v", v.Value)) case sqlgen.Fragment: @@ -530,13 +551,37 @@ type hasStatement interface { type stringer struct { i hasStatement + t *sqlgen.Template } func (s *stringer) String() string { if s != nil && s.i != nil { - q := compileAndReplacePlaceholders(s.i.statement()) + q := s.compileAndReplacePlaceholders(s.i.statement()) q = reInvisibleChars.ReplaceAllString(q, ` `) return strings.TrimSpace(q) } return "" } + +func (s *stringer) compileAndReplacePlaceholders(stmt *sqlgen.Statement) (query string) { + buf := stmt.Compile(s.t) + + j := 1 + for i := range buf { + if buf[i] == '?' { + query = query + "$" + strconv.Itoa(j) + j++ + } else { + query = query + string(buf[i]) + } + } + + return query +} + +func NewBuilder(sess sqlDatabase, t *sqlutil.TemplateWithUtils) *Builder { + return &Builder{ + sess: sess, + t: t, + } +} diff --git a/postgresql/database.go b/postgresql/database.go index 73b3888eb32a19b6dbd3de868f6657a43c4b5f71..90a2e45ee76dd1d2640440fad97b6963e1ef7040 100644 --- a/postgresql/database.go +++ b/postgresql/database.go @@ -32,6 +32,7 @@ import ( _ "github.com/lib/pq" // PostgreSQL driver. "upper.io/cache" "upper.io/db" + "upper.io/db/builder" "upper.io/db/util/adapter" "upper.io/db/util/schema" "upper.io/db/util/sqlgen" @@ -636,7 +637,7 @@ func (d *database) getPrimaryKey(tableName string) ([]string, error) { // Builder returns a custom query builder. func (d *database) Builder() db.QueryBuilder { - return &Builder{sess: d} + return builder.NewBuilder(d, template) } // waitForConnection tries to execute the connectFn function, if connectFn