From 1529225d0fc28bb8cc3aa3ef963fbf5822ffb5eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net> Date: Mon, 7 Sep 2015 10:45:05 -0500 Subject: [PATCH] Query builder's first fully working CRUD test. --- builder.go | 20 ++++-- postgresql/builder.go | 97 +++++++++++++++++++++++++++ postgresql/database_test.go | 30 ++++++++- util/sqlutil/convert.go | 129 +++++++++++++++++++++++------------- 4 files changed, 225 insertions(+), 51 deletions(-) diff --git a/builder.go b/builder.go index 94617c77..edc8bc8d 100644 --- a/builder.go +++ b/builder.go @@ -8,7 +8,8 @@ import ( type QueryBuilder interface { Select(fields ...interface{}) QuerySelector InsertInto(table string) QueryInserter - //Update(table string) QueryUpdater + DeleteFrom(table string) QueryDeleter + Update(table string) QueryUpdater } type QuerySelector interface { @@ -18,11 +19,22 @@ type QuerySelector interface { type QueryInserter interface { Values(...interface{}) QueryInserter Columns(...string) QueryInserter - Exec() (sql.Result, error) + QueryExecer +} + +type QueryDeleter interface { + Where(...interface{}) QueryDeleter + Limit(int) QueryDeleter + QueryExecer } type QueryUpdater interface { - Set() QueryUpdater + Set(...interface{}) QueryUpdater + Where(...interface{}) QueryUpdater + Limit(int) QueryUpdater + QueryExecer +} - Do() error +type QueryExecer interface { + Exec() (sql.Result, error) } diff --git a/postgresql/builder.go b/postgresql/builder.go index 14d836ea..bfe125ac 100644 --- a/postgresql/builder.go +++ b/postgresql/builder.go @@ -25,6 +25,20 @@ func (b *Builder) InsertInto(table string) db.QueryInserter { } } +func (b *Builder) DeleteFrom(table string) db.QueryDeleter { + return &QueryDeleter{ + builder: b, + table: table, + } +} + +func (b *Builder) Update(table string) db.QueryUpdater { + return &QueryUpdater{ + builder: b, + table: table, + } +} + type QuerySelector struct { builder *Builder fields []interface{} @@ -80,3 +94,86 @@ func (qi *QueryInserter) Values(values ...interface{}) db.QueryInserter { qi.values = append(qi.values, sqlgen.NewValueGroup(f...)) return qi } + +type QueryDeleter struct { + builder *Builder + table string + limit int + where *sqlgen.Where + args []interface{} +} + +func (qd *QueryDeleter) Where(terms ...interface{}) db.QueryDeleter { + where, arguments := template.ToWhereWithArguments(terms) + qd.where = &where + qd.args = append(qd.args, arguments...) + return qd +} + +func (qd *QueryDeleter) Limit(limit int) db.QueryDeleter { + qd.limit = limit + return qd +} + +func (qd *QueryDeleter) Exec() (sql.Result, error) { + stmt := &sqlgen.Statement{ + Type: sqlgen.Delete, + Table: sqlgen.TableWithName(qd.table), + } + + if qd.Where != nil { + stmt.Where = qd.where + } + + if qd.limit != 0 { + stmt.Limit = sqlgen.Limit(qd.limit) + } + + return qd.builder.sess.Exec(stmt, qd.args...) +} + +type QueryUpdater struct { + builder *Builder + table string + columnValues *sqlgen.ColumnValues + limit int + where *sqlgen.Where + args []interface{} +} + +func (qu *QueryUpdater) Set(terms ...interface{}) db.QueryUpdater { + cv, args := template.ToColumnValues(terms) + qu.columnValues = &cv + qu.args = append(qu.args, args...) + return qu +} + +func (qu *QueryUpdater) Where(terms ...interface{}) db.QueryUpdater { + where, arguments := template.ToWhereWithArguments(terms) + qu.where = &where + qu.args = append(qu.args, arguments...) + return qu +} + +func (qu *QueryUpdater) Exec() (sql.Result, error) { + stmt := &sqlgen.Statement{ + Type: sqlgen.Update, + Table: sqlgen.TableWithName(qu.table), + ColumnValues: qu.columnValues, + } + + if qu.Where != nil { + stmt.Where = qu.where + } + + if qu.limit != 0 { + stmt.Limit = sqlgen.Limit(qu.limit) + } + + return qu.builder.sess.Exec(stmt, qu.args...) +} + +func (qu *QueryUpdater) Limit(limit int) db.QueryUpdater { + qu.limit = limit + return qu +} diff --git a/postgresql/database_test.go b/postgresql/database_test.go index 3b1f189d..2503efe4 100644 --- a/postgresql/database_test.go +++ b/postgresql/database_test.go @@ -2080,9 +2080,37 @@ func TestQueryBuilder(t *testing.T) { t.Fatal(err) } + // DELETE FROM artist WHERE name = 'Chavela Vargas' LIMIT 1 + if _, err = b.DeleteFrom("artist").Where("name = ?", "Chavela Vargas").Limit(1).Exec(); err != nil { + t.Fatal(err) + } + + // DELETE FROM artist WHERE id > 5 + if _, err = b.DeleteFrom("artist").Where("id > 5").Exec(); err != nil { + t.Fatal(err) + } + + // UPDATE artist SET name = ? + if _, err = b.Update("artist").Set("name", "Artist").Exec(); err != nil { + t.Fatal(err) + } + + // UPDATE artist SET name = ? WHERE id < 5 + if _, err = b.Update("artist").Set("name = ?", "Artist").Where("id < ?", 5).Exec(); err != nil { + t.Fatal(err) + } + + // UPDATE artist SET name = ? || ' ' || ? || id, id = id + ? WHERE id > ? + if _, err = b.Update("artist").Set( + "name = ? || ' ' || ? || id", "Artist", "#", + "id = id + ?", 10, + ).Where("id > ?", 0).Exec(); err != nil { + t.Fatal(err) + } + /* // INSERT INTO artist (name) VALUES(? || ?) - if err = b.InsertInto("artist").Columns("name").Values(db.Raw("(? || ' ' || ?)"), "Tom", "Yorke").Exec(); err != nil { + if err = b.InsertInto("artist").Columns("name").Values(db.Expr("? || ' ' || ?", "Tom", "Yorke")).Exec(); err != nil { t.Fatal(err) } // INSERT INTO artist ("name") VALUES('Michael Jackson') diff --git a/util/sqlutil/convert.go b/util/sqlutil/convert.go index 4504742a..1033a6eb 100644 --- a/util/sqlutil/convert.go +++ b/util/sqlutil/convert.go @@ -148,74 +148,111 @@ func (tu *TemplateWithUtils) ToInterfaceArguments(value interface{}) (args []int } // ToColumnValues converts the given db.Cond into a sqlgen.ColumnValues struct. -func (tu *TemplateWithUtils) ToColumnValues(cond db.Cond) (ToColumnValues sqlgen.ColumnValues, args []interface{}) { - +func (tu *TemplateWithUtils) ToColumnValues(term interface{}) (cv sqlgen.ColumnValues, args []interface{}) { args = []interface{}{} - for column, value := range cond { - columnValue := sqlgen.ColumnValue{} + switch t := term.(type) { + case []interface{}: + l := len(t) + for i := 0; i < l; i++ { + column := t[i].(string) - // Guessing operator from input, or using a default one. - column := strings.TrimSpace(column) - chunks := strings.SplitN(column, ` `, 2) + if !strings.ContainsAny(column, "=") { + column = fmt.Sprintf("%s = ?", column) + } - columnValue.Column = sqlgen.ColumnWithName(chunks[0]) + chunks := strings.SplitN(column, "=", 2) - if len(chunks) > 1 { - columnValue.Operator = chunks[1] - } + column = chunks[0] + format := strings.TrimSpace(chunks[1]) - switch value := value.(type) { - case db.Func: - v := tu.ToInterfaceArguments(value.Args) - columnValue.Operator = value.Name + columnValue := sqlgen.ColumnValue{ + Column: sqlgen.ColumnWithName(column), + Operator: "=", + Value: sqlgen.RawValue(format), + } - if v == nil { - // A function with no arguments. - columnValue.Value = sqlgen.RawValue(`()`) + ps := strings.Count(format, "?") + if i+ps < l { + for j := 0; j < ps; j++ { + args = append(args, t[i+j+1]) + } + i = i + ps } else { - // A function with one or more arguments. - columnValue.Value = sqlgen.RawValue(fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))) + panic(fmt.Sprintf("Format string %q has more placeholders than given arguments.", format)) } - args = append(args, v...) - default: - v := tu.ToInterfaceArguments(value) + cv.ColumnValues = append(cv.ColumnValues, &columnValue) + } + return cv, args + // Return error. + case db.Cond: + for column, value := range t { + columnValue := sqlgen.ColumnValue{} - if v == nil { - // Nil value given. - columnValue.Value = sqlNull - if columnValue.Operator == "" { - columnValue.Operator = sqlIsOperator - } - } else { - if len(v) > 1 { - // Array value given. + // Guessing operator from input, or using a default one. + column := strings.TrimSpace(column) + chunks := strings.SplitN(column, ` `, 2) + + columnValue.Column = sqlgen.ColumnWithName(chunks[0]) + + if len(chunks) > 1 { + columnValue.Operator = chunks[1] + } + + switch value := value.(type) { + case db.Func: + v := tu.ToInterfaceArguments(value.Args) + columnValue.Operator = value.Name + + if v == nil { + // A function with no arguments. + columnValue.Value = sqlgen.RawValue(`()`) + } else { + // A function with one or more arguments. columnValue.Value = sqlgen.RawValue(fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))) + } + + args = append(args, v...) + default: + v := tu.ToInterfaceArguments(value) + + if v == nil { + // Nil value given. + columnValue.Value = sqlNull if columnValue.Operator == "" { - columnValue.Operator = sqlInOperator + columnValue.Operator = sqlIsOperator } } else { - // Single value given. - columnValue.Value = sqlPlaceholder + if len(v) > 1 { + // Array value given. + columnValue.Value = sqlgen.RawValue(fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1))) + if columnValue.Operator == "" { + columnValue.Operator = sqlInOperator + } + } else { + // Single value given. + columnValue.Value = sqlPlaceholder + } + args = append(args, v...) } - args = append(args, v...) } - } - // Using guessed operator if no operator was given. - if columnValue.Operator == "" { - if tu.DefaultOperator != "" { - columnValue.Operator = tu.DefaultOperator - } else { - columnValue.Operator = sqlDefaultOperator + // Using guessed operator if no operator was given. + if columnValue.Operator == "" { + if tu.DefaultOperator != "" { + columnValue.Operator = tu.DefaultOperator + } else { + columnValue.Operator = sqlDefaultOperator + } } - } - ToColumnValues.ColumnValues = append(ToColumnValues.ColumnValues, &columnValue) + cv.ColumnValues = append(cv.ColumnValues, &columnValue) + } + return cv, args } - return ToColumnValues, args + panic("Unknown map type.") } // ToColumnsValuesAndArguments maps the given columnNames and columnValues into -- GitLab