diff --git a/builder.go b/builder.go index 7c03191e2c1317480753f0374aff0b2338fb7438..f67adfc578e57111cad3b54c0a5b7cc36878cd51 100644 --- a/builder.go +++ b/builder.go @@ -34,7 +34,7 @@ type QuerySelector interface { Offset(int) QuerySelector QueryGetter - ResultIterator + Iterator fmt.Stringer } @@ -43,6 +43,7 @@ type QueryInserter interface { Columns(...string) QueryInserter QueryExecer + fmt.Stringer } type QueryDeleter interface { @@ -50,6 +51,7 @@ type QueryDeleter interface { Limit(int) QueryDeleter QueryExecer + fmt.Stringer } type QueryUpdater interface { @@ -58,6 +60,7 @@ type QueryUpdater interface { Limit(int) QueryUpdater QueryExecer + fmt.Stringer } type QueryExecer interface { @@ -69,9 +72,10 @@ type QueryGetter interface { QueryRow() (*sqlx.Row, error) } -type ResultIterator interface { +type Iterator interface { All(interface{}) error - Next(interface{}) error One(interface{}) error + Next(interface{}) bool + Err() error Close() error } diff --git a/postgresql/builder.go b/postgresql/builder.go index 715de14276843554a0d84d77f318d81708461f98..f981159645ce5a0ba25fb7bf64eafaf65b1fe6d1 100644 --- a/postgresql/builder.go +++ b/postgresql/builder.go @@ -28,64 +28,69 @@ type Builder struct { } func (b *Builder) SelectAllFrom(table string) db.QuerySelector { - return &QuerySelector{ + qs := &QuerySelector{ builder: b, table: table, } + + qs.stringer = &stringer{qs} + return qs } func (b *Builder) Select(columns ...interface{}) db.QuerySelector { f, err := columnFragments(columns) - return &QuerySelector{ + qs := &QuerySelector{ builder: b, columns: sqlgen.JoinColumns(f...), err: err, } + + qs.stringer = &stringer{qs} + return qs } func (b *Builder) InsertInto(table string) db.QueryInserter { - return &QueryInserter{ + qi := &QueryInserter{ builder: b, table: table, } + + qi.stringer = &stringer{qi} + return qi } func (b *Builder) DeleteFrom(table string) db.QueryDeleter { - return &QueryDeleter{ + qd := &QueryDeleter{ builder: b, table: table, } + + qd.stringer = &stringer{qd} + return qd } func (b *Builder) Update(table string) db.QueryUpdater { - return &QueryUpdater{ + qu := &QueryUpdater{ builder: b, table: table, } + + qu.stringer = &stringer{qu} + return qu } type QueryInserter struct { - builder *Builder - table string - values []*sqlgen.Values - columns []sqlgen.Fragment + *stringer + builder *Builder + table string + values []*sqlgen.Values + columns []sqlgen.Fragment + arguments []interface{} } func (qi *QueryInserter) Exec() (sql.Result, error) { - stmt := &sqlgen.Statement{ - Type: sqlgen.Insert, - Table: sqlgen.TableWithName(qi.table), - } - - if len(qi.values) > 0 { - stmt.Values = sqlgen.JoinValueGroups(qi.values...) - } - if len(qi.columns) > 0 { - stmt.Columns = sqlgen.JoinColumns(qi.columns...) - } - - return qi.builder.sess.Exec(stmt) + return qi.builder.sess.Exec(qi.statement()) } func (qi *QueryInserter) Columns(columns ...string) db.QueryInserter { @@ -99,31 +104,49 @@ func (qi *QueryInserter) Columns(columns ...string) db.QueryInserter { } func (qi *QueryInserter) Values(values ...interface{}) db.QueryInserter { - l := len(values) - f := make([]sqlgen.Fragment, l) - for i := 0; i < l; i++ { - if _, ok := values[i].(db.Raw); ok { - f[i] = sqlgen.NewValue(sqlgen.RawValue(fmt.Sprintf("%v", values[i]))) - } else { - f[i] = sqlgen.NewValue(values[i]) + if len(qi.columns) == 0 || len(values) == len(qi.columns) { + qi.arguments = append(qi.arguments, values...) + + l := len(values) + placeholders := make([]sqlgen.Fragment, l) + for i := 0; i < l; i++ { + placeholders[i] = sqlgen.RawValue(`?`) } + qi.values = append(qi.values, sqlgen.NewValueGroup(placeholders...)) } - qi.values = append(qi.values, sqlgen.NewValueGroup(f...)) + return qi } +func (qi *QueryInserter) statement() *sqlgen.Statement { + stmt := &sqlgen.Statement{ + Type: sqlgen.Insert, + Table: sqlgen.TableWithName(qi.table), + } + + if len(qi.values) > 0 { + stmt.Values = sqlgen.JoinValueGroups(qi.values...) + } + + if len(qi.columns) > 0 { + stmt.Columns = sqlgen.JoinColumns(qi.columns...) + } + return stmt +} + type QueryDeleter struct { - builder *Builder - table string - limit int - where *sqlgen.Where - args []interface{} + *stringer + builder *Builder + table string + limit int + where *sqlgen.Where + arguments []interface{} } func (qd *QueryDeleter) Where(terms ...interface{}) db.QueryDeleter { where, arguments := template.ToWhereWithArguments(terms) qd.where = &where - qd.args = append(qd.args, arguments...) + qd.arguments = append(qd.arguments, arguments...) return qd } @@ -133,6 +156,10 @@ func (qd *QueryDeleter) Limit(limit int) db.QueryDeleter { } func (qd *QueryDeleter) Exec() (sql.Result, error) { + return qd.builder.sess.Exec(qd.statement(), qd.arguments...) +} + +func (qd *QueryDeleter) statement() *sqlgen.Statement { stmt := &sqlgen.Statement{ Type: sqlgen.Delete, Table: sqlgen.TableWithName(qd.table), @@ -146,33 +173,43 @@ func (qd *QueryDeleter) Exec() (sql.Result, error) { stmt.Limit = sqlgen.Limit(qd.limit) } - return qd.builder.sess.Exec(stmt, qd.args...) + return stmt } type QueryUpdater struct { + *stringer builder *Builder table string columnValues *sqlgen.ColumnValues limit int where *sqlgen.Where - args []interface{} + arguments []interface{} } func (qu *QueryUpdater) Set(terms ...interface{}) db.QueryUpdater { - cv, args := template.ToColumnValues(terms) + cv, arguments := template.ToColumnValues(terms) qu.columnValues = &cv - qu.args = append(qu.args, args...) + qu.arguments = append(qu.arguments, arguments...) return qu } func (qu *QueryUpdater) Where(terms ...interface{}) db.QueryUpdater { where, arguments := template.ToWhereWithArguments(terms) qu.where = &where - qu.args = append(qu.args, arguments...) + qu.arguments = append(qu.arguments, arguments...) return qu } func (qu *QueryUpdater) Exec() (sql.Result, error) { + return qu.builder.sess.Exec(qu.statement(), qu.arguments...) +} + +func (qu *QueryUpdater) Limit(limit int) db.QueryUpdater { + qu.limit = limit + return qu +} + +func (qu *QueryUpdater) statement() *sqlgen.Statement { stmt := &sqlgen.Statement{ Type: sqlgen.Update, Table: sqlgen.TableWithName(qu.table), @@ -187,15 +224,11 @@ func (qu *QueryUpdater) Exec() (sql.Result, error) { stmt.Limit = sqlgen.Limit(qu.limit) } - return qu.builder.sess.Exec(stmt, qu.args...) -} - -func (qu *QueryUpdater) Limit(limit int) db.QueryUpdater { - qu.limit = limit - return qu + return stmt } type QuerySelector struct { + *stringer mode SelectMode cursor *sqlx.Rows // This is the main query cursor. It starts as a nil value. builder *Builder @@ -386,9 +419,6 @@ func (qs *QuerySelector) QueryRow() (*sqlx.Row, error) { } func (qs *QuerySelector) Close() (err error) { - if qs.err != nil { - return qs.err - } if qs.cursor != nil { err = qs.cursor.Close() qs.cursor = nil @@ -404,8 +434,6 @@ func (qs *QuerySelector) setCursor() (err error) { } func (qs *QuerySelector) One(dst interface{}) error { - var err error - if qs.err != nil { return qs.err } @@ -416,9 +444,11 @@ func (qs *QuerySelector) One(dst interface{}) error { defer qs.Close() - err = qs.Next(dst) + if !qs.Next(dst) { + return qs.Err() + } - return err + return nil } func (qs *QuerySelector) All(dst interface{}) error { @@ -446,28 +476,30 @@ func (qs *QuerySelector) All(dst interface{}) error { return err } -func (qs *QuerySelector) Next(dst interface{}) (err error) { +func (qs *QuerySelector) Err() (err error) { + return qs.err +} + +func (qs *QuerySelector) Next(dst interface{}) bool { + var err error + if qs.err != nil { - return qs.err + return false } if err = qs.setCursor(); err != nil { + qs.err = err qs.Close() - return err + return false } if err = sqlutil.FetchRow(qs.cursor, dst); err != nil { + qs.err = err qs.Close() - return err + return false } - return nil -} - -func (qs *QuerySelector) String() string { - q := compileAndReplacePlaceholders(qs.statement()) - q = reInvisibleChars.ReplaceAllString(q, ` `) - return strings.TrimSpace(q) + return true } func columnFragments(columns []interface{}) ([]sqlgen.Fragment, error) { @@ -477,7 +509,7 @@ func columnFragments(columns []interface{}) ([]sqlgen.Fragment, error) { for i := 0; i < l; i++ { switch v := columns[i].(type) { case db.Raw: - f[i] = sqlgen.RawValue(fmt.Sprintf("%v", v)) + f[i] = sqlgen.RawValue(fmt.Sprintf("%v", v.Value)) case sqlgen.Fragment: f[i] = v case string: @@ -491,3 +523,20 @@ func columnFragments(columns []interface{}) ([]sqlgen.Fragment, error) { return f, nil } + +type hasStatement interface { + statement() *sqlgen.Statement +} + +type stringer struct { + i hasStatement +} + +func (s *stringer) String() string { + if s != nil && s.i != nil { + q := compileAndReplacePlaceholders(s.i.statement()) + q = reInvisibleChars.ReplaceAllString(q, ` `) + return strings.TrimSpace(q) + } + return "" +} diff --git a/postgresql/database_test.go b/postgresql/database_test.go index b48ccb7fd76f14baa70d9bb1889fc063d7199523..fc0b5baca1c0881d6272d5596088dc6b516939de 100644 --- a/postgresql/database_test.go +++ b/postgresql/database_test.go @@ -1917,10 +1917,6 @@ func TestOptionTypeJsonbStruct(t *testing.T) { func TestQueryBuilder(t *testing.T) { var sess db.Database var err error - var sel db.QuerySelector - var artist artistType - - assert := assert.New(t) if sess, err = db.Open(Adapter, settings); err != nil { t.Fatal(err) @@ -1930,6 +1926,10 @@ func TestQueryBuilder(t *testing.T) { b := sess.Builder() + assert := assert.New(t) + + // Testing SELECT. + assert.Equal( `SELECT * FROM "artist"`, b.SelectAllFrom("artist").String(), @@ -2038,51 +2038,58 @@ func TestQueryBuilder(t *testing.T) { b.SelectAllFrom("artist").Join("publication").Using("id").String(), ) - // Should not work because we are using both "On()" and "Using()" - sel = b.Select().From("artist a").Join("publications p").On("p1.id = a.id").Using("id") - assert.Error(sel.One(&artist)) + assert.Equal( + `SELECT DATE()`, + b.Select(db.Raw{"DATE()"}).String(), + ) - // Should not work because a Join() is missing before On(). - sel = b.Select().From("artist a").On("p1.id = a.id") - assert.Error(sel.One(&artist)) + // Testing INSERT. - // INSERT INTO artist VALUES (10, 'Ryuichi Sakamoto'), (11, 'Alondra de la Parra') - if _, err = b.InsertInto("artist").Values(10, "Ryuichi Sakamoto").Values(11, "Alondra de la Parra").Exec(); err != nil { - t.Fatal(err) - } + assert.Equal( + `INSERT INTO "artist" VALUES ($1, $2), ($3, $4), ($5, $6)`, + b.InsertInto("artist"). + Values(10, "Ryuichi Sakamoto"). + Values(11, "Alondra de la Parra"). + Values(12, "Haruki Murakami"). + String(), + ) - // INSERT INTO artist COLUMNS("name") VALUES('Chavela Vargas') - if _, err = b.InsertInto("artist").Columns("name", "id").Values("Chavela Vargas", 12).Exec(); err != nil { - t.Fatal(err) - } + assert.Equal( + `INSERT INTO "artist" ("name", "id") VALUES ($1, $2)`, + b.InsertInto("artist").Columns("name", "id").Values("Chavela Vargas", 12).String(), + ) - // DELETE FROM artist WHERE name = 'Chavela Vargas' LIMIT 1 - if _, err = b.DeleteFrom("artist").Where("name = ?", "Chavela Vargas").Limit(1).Exec(); err != nil { - t.Fatal(err) - } + // Testing DELETE. - // DELETE FROM artist WHERE id > 5 - if _, err = b.DeleteFrom("artist").Where("id > 5").Exec(); err != nil { - t.Fatal(err) - } + assert.Equal( + `DELETE FROM "artist" WHERE (name = $1) LIMIT 1`, + b.DeleteFrom("artist").Where("name = ?", "Chavela Vargas").Limit(1).String(), + ) - // UPDATE artist SET name = ? - if _, err = b.Update("artist").Set("name", "Artist").Exec(); err != nil { - t.Fatal(err) - } + assert.Equal( + `DELETE FROM "artist" WHERE (id > 5)`, + b.DeleteFrom("artist").Where("id > 5").String(), + ) - // UPDATE artist SET name = ? WHERE id < 5 - if _, err = b.Update("artist").Set("name = ?", "Artist").Where("id < ?", 5).Exec(); err != nil { - t.Fatal(err) - } + // Testing UPDATE. - // UPDATE artist SET name = ? || ' ' || ? || id, id = id + ? WHERE id > ? - if _, err = b.Update("artist").Set( - "name = ? || ' ' || ? || id", "Artist", "#", - "id = id + ?", 10, - ).Where("id > ?", 0).Exec(); err != nil { - t.Fatal(err) - } + assert.Equal( + `UPDATE "artist" SET "name" = $1`, + b.Update("artist").Set("name", "Artist").String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`, + b.Update("artist").Set("name = ?", "Artist").Where("id <", 5).String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1 || ' ' || $2 || id, "id" = id + $3 WHERE (id > $4)`, + b.Update("artist").Set( + "name = ? || ' ' || ? || id", "Artist", "#", + "id = id + ?", 10, + ).Where("id > ?", 0).String(), + ) /* // INSERT INTO artist (name) VALUES(? || ?) @@ -2105,6 +2112,34 @@ func TestQueryBuilder(t *testing.T) { } */ + // Testing actual queries. + + var artist artistType + var artists []artistType + + err = b.SelectAllFrom("artist").All(&artists) + assert.NoError(err) + assert.True(len(artists) > 0) + + err = b.SelectAllFrom("artist").One(&artist) + assert.NoError(err) + assert.NotNil(artist) + + var qs db.QuerySelector + + qs = b.SelectAllFrom("artist") + for qs.Next(&artist) { + assert.Nil(qs.Err()) + assert.NotNil(artist) + } + + assert.Nil(qs.Close()) + + qs = b.Select().From("artist a").Join("publications p").On("p1.id = a.id").Using("id") + assert.Error(qs.One(&artist), `Should not work because it attempts to use both "On()" and "Using()" in the same JOIN.`) + + qs = b.Select().From("artist a").On("p1.id = a.id") + assert.Error(qs.One(&artist), `Should not work because it should put a "Join()" before "On()".`) } // TestExhaustConnections simulates a "too many connections" situation diff --git a/postgresql/template.go b/postgresql/template.go index f195e54ab575a2b35a7f526fe0f5c67d418aa148..59f40584766167c8693d2360bfd220a6b0293261 100644 --- a/postgresql/template.go +++ b/postgresql/template.go @@ -114,6 +114,13 @@ const ( DELETE FROM {{.Table}} {{.Where}} + {{if .Limit}} + LIMIT {{.Limit}} + {{end}} + + {{if .Offset}} + OFFSET {{.Offset}} + {{end}} ` adapterUpdateLayout = ` UPDATE