From be9f7a4a7fe3714d274504fa4aa0dca6d9c8f062 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net> Date: Tue, 29 Sep 2015 12:33:04 +0200 Subject: [PATCH] Drafting query selector (wip). --- builder.go | 41 +++++- postgresql/builder.go | 336 ++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 364 insertions(+), 13 deletions(-) diff --git a/builder.go b/builder.go index edc8bc8d..7c03191e 100644 --- a/builder.go +++ b/builder.go @@ -2,29 +2,53 @@ package db import ( "database/sql" + "fmt" + "github.com/jmoiron/sqlx" ) // QueryBuilder is an experimental interface. type QueryBuilder interface { - Select(fields ...interface{}) QuerySelector + Select(columns ...interface{}) QuerySelector + SelectAllFrom(table string) QuerySelector + InsertInto(table string) QueryInserter DeleteFrom(table string) QueryDeleter Update(table string) QueryUpdater } type QuerySelector interface { - From(table ...string) Result + From(tables ...string) QuerySelector + Distinct() QuerySelector + Where(...interface{}) QuerySelector + GroupBy(...interface{}) QuerySelector + //Having(...interface{}) QuerySelector + OrderBy(...interface{}) QuerySelector + Using(...interface{}) QuerySelector + FullJoin(...interface{}) QuerySelector + CrossJoin(...interface{}) QuerySelector + RightJoin(...interface{}) QuerySelector + LeftJoin(...interface{}) QuerySelector + Join(...interface{}) QuerySelector + On(...interface{}) QuerySelector + Limit(int) QuerySelector + Offset(int) QuerySelector + + QueryGetter + ResultIterator + fmt.Stringer } type QueryInserter interface { Values(...interface{}) QueryInserter Columns(...string) QueryInserter + QueryExecer } type QueryDeleter interface { Where(...interface{}) QueryDeleter Limit(int) QueryDeleter + QueryExecer } @@ -32,9 +56,22 @@ type QueryUpdater interface { Set(...interface{}) QueryUpdater Where(...interface{}) QueryUpdater Limit(int) QueryUpdater + QueryExecer } type QueryExecer interface { Exec() (sql.Result, error) } + +type QueryGetter interface { + Query() (*sqlx.Rows, error) + QueryRow() (*sqlx.Row, error) +} + +type ResultIterator interface { + All(interface{}) error + Next(interface{}) error + One(interface{}) error + Close() error +} diff --git a/postgresql/builder.go b/postgresql/builder.go index bfe125ac..715de142 100644 --- a/postgresql/builder.go +++ b/postgresql/builder.go @@ -2,19 +2,45 @@ package postgresql import ( "database/sql" + "errors" "fmt" + "github.com/jmoiron/sqlx" + "regexp" + "strings" "upper.io/db" "upper.io/db/util/sqlgen" + "upper.io/db/util/sqlutil" +) + +type SelectMode uint8 + +var ( + reInvisibleChars = regexp.MustCompile(`[\s\r\n\t]+`) +) + +const ( + selectModeAll SelectMode = iota + selectModeDistinct ) type Builder struct { sess *database } -func (b *Builder) Select(fields ...interface{}) db.QuerySelector { +func (b *Builder) SelectAllFrom(table string) db.QuerySelector { return &QuerySelector{ builder: b, - fields: fields, + table: table, + } +} + +func (b *Builder) Select(columns ...interface{}) db.QuerySelector { + f, err := columnFragments(columns) + + return &QuerySelector{ + builder: b, + columns: sqlgen.JoinColumns(f...), + err: err, } } @@ -39,15 +65,6 @@ func (b *Builder) Update(table string) db.QueryUpdater { } } -type QuerySelector struct { - builder *Builder - fields []interface{} -} - -func (qs *QuerySelector) From(table ...string) db.Result { - return qs.builder.sess.C(table...).Find().Select(qs.fields...) -} - type QueryInserter struct { builder *Builder table string @@ -177,3 +194,300 @@ func (qu *QueryUpdater) Limit(limit int) db.QueryUpdater { qu.limit = limit return qu } + +type QuerySelector struct { + mode SelectMode + cursor *sqlx.Rows // This is the main query cursor. It starts as a nil value. + builder *Builder + table string + where *sqlgen.Where + groupBy *sqlgen.GroupBy + orderBy sqlgen.OrderBy + limit sqlgen.Limit + offset sqlgen.Offset + columns *sqlgen.Columns + joins []*sqlgen.Join + arguments []interface{} + err error +} + +func (qs *QuerySelector) From(tables ...string) db.QuerySelector { + qs.table = strings.Join(tables, ",") + return qs +} + +func (qs *QuerySelector) Distinct() db.QuerySelector { + qs.mode = selectModeDistinct + return qs +} + +func (qs *QuerySelector) Where(terms ...interface{}) db.QuerySelector { + where, arguments := template.ToWhereWithArguments(terms) + qs.where = &where + qs.arguments = append(qs.arguments, arguments...) + return qs +} + +func (qs *QuerySelector) GroupBy(columns ...interface{}) db.QuerySelector { + var fragments []sqlgen.Fragment + fragments, qs.err = columnFragments(columns) + if fragments != nil { + qs.groupBy = sqlgen.GroupByColumns(fragments...) + } + return qs +} + +func (qs *QuerySelector) OrderBy(columns ...interface{}) db.QuerySelector { + var sortColumns sqlgen.SortColumns + + for i := range columns { + var sort *sqlgen.SortColumn + + switch value := columns[i].(type) { + case db.Raw: + sort = &sqlgen.SortColumn{ + Column: sqlgen.RawValue(fmt.Sprintf(`%v`, value.Value)), + } + case string: + if strings.HasPrefix(value, `-`) { + sort = &sqlgen.SortColumn{ + Column: sqlgen.ColumnWithName(value[1:]), + Order: sqlgen.Descendent, + } + } else { + sort = &sqlgen.SortColumn{ + Column: sqlgen.ColumnWithName(value), + Order: sqlgen.Ascendent, + } + } + } + sortColumns.Columns = append(sortColumns.Columns, sort) + } + + qs.orderBy.SortColumns = &sortColumns + + return qs +} + +func (qs *QuerySelector) Using(columns ...interface{}) db.QuerySelector { + if len(qs.joins) == 0 { + qs.err = errors.New(`Cannot use Using() without a preceding Join() expression.`) + return qs + } + + lastJoin := qs.joins[len(qs.joins)-1] + + if lastJoin.On != nil { + qs.err = errors.New(`Cannot use Using() and On() with the same Join() expression.`) + return qs + } + + fragments, err := columnFragments(columns) + if err != nil { + qs.err = err + return qs + } + + lastJoin.Using = sqlgen.UsingColumns(fragments...) + return qs +} + +func (qs *QuerySelector) pushJoin(t string, tables []interface{}) db.QuerySelector { + if qs.joins == nil { + qs.joins = []*sqlgen.Join{} + } + + tableNames := make([]string, len(tables)) + for i := range tables { + tableNames[i] = fmt.Sprintf("%s", tables[i]) + } + + qs.joins = append(qs.joins, + &sqlgen.Join{ + Type: t, + Table: sqlgen.TableWithName(strings.Join(tableNames, ", ")), + }, + ) + + return qs +} + +func (qs *QuerySelector) FullJoin(tables ...interface{}) db.QuerySelector { + return qs.pushJoin("FULL", tables) +} + +func (qs *QuerySelector) CrossJoin(tables ...interface{}) db.QuerySelector { + return qs.pushJoin("CROSS", tables) +} + +func (qs *QuerySelector) RightJoin(tables ...interface{}) db.QuerySelector { + return qs.pushJoin("RIGHT", tables) +} + +func (qs *QuerySelector) LeftJoin(tables ...interface{}) db.QuerySelector { + return qs.pushJoin("LEFT", tables) +} + +func (qs *QuerySelector) Join(tables ...interface{}) db.QuerySelector { + return qs.pushJoin("", tables) +} + +func (qs *QuerySelector) On(terms ...interface{}) db.QuerySelector { + if len(qs.joins) == 0 { + qs.err = errors.New(`Cannot use On() without a preceding Join() expression.`) + return qs + } + + lastJoin := qs.joins[len(qs.joins)-1] + + if lastJoin.On != nil { + qs.err = errors.New(`Cannot use Using() and On() with the same Join() expression.`) + return qs + } + + w, a := template.ToWhereWithArguments(terms) + o := sqlgen.On(w) + lastJoin.On = &o + + qs.arguments = append(qs.arguments, a...) + return qs +} + +func (qs *QuerySelector) Limit(n int) db.QuerySelector { + qs.limit = sqlgen.Limit(n) + return qs +} + +func (qs *QuerySelector) Offset(n int) db.QuerySelector { + qs.offset = sqlgen.Offset(n) + return qs +} + +func (qs *QuerySelector) statement() *sqlgen.Statement { + return &sqlgen.Statement{ + Type: sqlgen.Select, + Table: sqlgen.TableWithName(qs.table), + Columns: qs.columns, + Limit: qs.limit, + Offset: qs.offset, + Joins: sqlgen.JoinConditions(qs.joins...), + Where: qs.where, + OrderBy: &qs.orderBy, + GroupBy: qs.groupBy, + } +} + +func (qs *QuerySelector) Query() (*sqlx.Rows, error) { + return qs.builder.sess.Query(qs.statement(), qs.arguments...) +} + +func (qs *QuerySelector) QueryRow() (*sqlx.Row, error) { + return qs.builder.sess.QueryRow(qs.statement(), qs.arguments...) +} + +func (qs *QuerySelector) Close() (err error) { + if qs.err != nil { + return qs.err + } + if qs.cursor != nil { + err = qs.cursor.Close() + qs.cursor = nil + } + return err +} + +func (qs *QuerySelector) setCursor() (err error) { + if qs.cursor == nil { + qs.cursor, err = qs.builder.sess.Query(qs.statement(), qs.arguments...) + } + return err +} + +func (qs *QuerySelector) One(dst interface{}) error { + var err error + + if qs.err != nil { + return qs.err + } + + if qs.cursor != nil { + return db.ErrQueryIsPending + } + + defer qs.Close() + + err = qs.Next(dst) + + return err +} + +func (qs *QuerySelector) All(dst interface{}) error { + var err error + + if qs.err != nil { + return qs.err + } + + if qs.cursor != nil { + return db.ErrQueryIsPending + } + + err = qs.setCursor() + + if err != nil { + return err + } + + defer qs.Close() + + // Fetching all results within the cursor. + err = sqlutil.FetchRows(qs.cursor, dst) + + return err +} + +func (qs *QuerySelector) Next(dst interface{}) (err error) { + if qs.err != nil { + return qs.err + } + + if err = qs.setCursor(); err != nil { + qs.Close() + return err + } + + if err = sqlutil.FetchRow(qs.cursor, dst); err != nil { + qs.Close() + return err + } + + return nil +} + +func (qs *QuerySelector) String() string { + q := compileAndReplacePlaceholders(qs.statement()) + q = reInvisibleChars.ReplaceAllString(q, ` `) + return strings.TrimSpace(q) +} + +func columnFragments(columns []interface{}) ([]sqlgen.Fragment, error) { + l := len(columns) + f := make([]sqlgen.Fragment, l) + + for i := 0; i < l; i++ { + switch v := columns[i].(type) { + case db.Raw: + f[i] = sqlgen.RawValue(fmt.Sprintf("%v", v)) + case sqlgen.Fragment: + f[i] = v + case string: + f[i] = sqlgen.ColumnWithName(v) + case interface{}: + f[i] = sqlgen.ColumnWithName(fmt.Sprintf("%v", v)) + default: + return nil, fmt.Errorf("Unexpected argument type %T for Select() argument.", v) + } + } + + return f, nil +} -- GitLab