From 1529225d0fc28bb8cc3aa3ef963fbf5822ffb5eb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net>
Date: Mon, 7 Sep 2015 10:45:05 -0500
Subject: [PATCH] Query builder's first fully working CRUD test.

---
 builder.go                  |  20 ++++--
 postgresql/builder.go       |  97 +++++++++++++++++++++++++++
 postgresql/database_test.go |  30 ++++++++-
 util/sqlutil/convert.go     | 129 +++++++++++++++++++++++-------------
 4 files changed, 225 insertions(+), 51 deletions(-)

diff --git a/builder.go b/builder.go
index 94617c77..edc8bc8d 100644
--- a/builder.go
+++ b/builder.go
@@ -8,7 +8,8 @@ import (
 type QueryBuilder interface {
 	Select(fields ...interface{}) QuerySelector
 	InsertInto(table string) QueryInserter
-	//Update(table string) QueryUpdater
+	DeleteFrom(table string) QueryDeleter
+	Update(table string) QueryUpdater
 }
 
 type QuerySelector interface {
@@ -18,11 +19,22 @@ type QuerySelector interface {
 type QueryInserter interface {
 	Values(...interface{}) QueryInserter
 	Columns(...string) QueryInserter
-	Exec() (sql.Result, error)
+	QueryExecer
+}
+
+type QueryDeleter interface {
+	Where(...interface{}) QueryDeleter
+	Limit(int) QueryDeleter
+	QueryExecer
 }
 
 type QueryUpdater interface {
-	Set() QueryUpdater
+	Set(...interface{}) QueryUpdater
+	Where(...interface{}) QueryUpdater
+	Limit(int) QueryUpdater
+	QueryExecer
+}
 
-	Do() error
+type QueryExecer interface {
+	Exec() (sql.Result, error)
 }
diff --git a/postgresql/builder.go b/postgresql/builder.go
index 14d836ea..bfe125ac 100644
--- a/postgresql/builder.go
+++ b/postgresql/builder.go
@@ -25,6 +25,20 @@ func (b *Builder) InsertInto(table string) db.QueryInserter {
 	}
 }
 
+func (b *Builder) DeleteFrom(table string) db.QueryDeleter {
+	return &QueryDeleter{
+		builder: b,
+		table:   table,
+	}
+}
+
+func (b *Builder) Update(table string) db.QueryUpdater {
+	return &QueryUpdater{
+		builder: b,
+		table:   table,
+	}
+}
+
 type QuerySelector struct {
 	builder *Builder
 	fields  []interface{}
@@ -80,3 +94,86 @@ func (qi *QueryInserter) Values(values ...interface{}) db.QueryInserter {
 	qi.values = append(qi.values, sqlgen.NewValueGroup(f...))
 	return qi
 }
+
+type QueryDeleter struct {
+	builder *Builder
+	table   string
+	limit   int
+	where   *sqlgen.Where
+	args    []interface{}
+}
+
+func (qd *QueryDeleter) Where(terms ...interface{}) db.QueryDeleter {
+	where, arguments := template.ToWhereWithArguments(terms)
+	qd.where = &where
+	qd.args = append(qd.args, arguments...)
+	return qd
+}
+
+func (qd *QueryDeleter) Limit(limit int) db.QueryDeleter {
+	qd.limit = limit
+	return qd
+}
+
+func (qd *QueryDeleter) Exec() (sql.Result, error) {
+	stmt := &sqlgen.Statement{
+		Type:  sqlgen.Delete,
+		Table: sqlgen.TableWithName(qd.table),
+	}
+
+	if qd.Where != nil {
+		stmt.Where = qd.where
+	}
+
+	if qd.limit != 0 {
+		stmt.Limit = sqlgen.Limit(qd.limit)
+	}
+
+	return qd.builder.sess.Exec(stmt, qd.args...)
+}
+
+type QueryUpdater struct {
+	builder      *Builder
+	table        string
+	columnValues *sqlgen.ColumnValues
+	limit        int
+	where        *sqlgen.Where
+	args         []interface{}
+}
+
+func (qu *QueryUpdater) Set(terms ...interface{}) db.QueryUpdater {
+	cv, args := template.ToColumnValues(terms)
+	qu.columnValues = &cv
+	qu.args = append(qu.args, args...)
+	return qu
+}
+
+func (qu *QueryUpdater) Where(terms ...interface{}) db.QueryUpdater {
+	where, arguments := template.ToWhereWithArguments(terms)
+	qu.where = &where
+	qu.args = append(qu.args, arguments...)
+	return qu
+}
+
+func (qu *QueryUpdater) Exec() (sql.Result, error) {
+	stmt := &sqlgen.Statement{
+		Type:         sqlgen.Update,
+		Table:        sqlgen.TableWithName(qu.table),
+		ColumnValues: qu.columnValues,
+	}
+
+	if qu.Where != nil {
+		stmt.Where = qu.where
+	}
+
+	if qu.limit != 0 {
+		stmt.Limit = sqlgen.Limit(qu.limit)
+	}
+
+	return qu.builder.sess.Exec(stmt, qu.args...)
+}
+
+func (qu *QueryUpdater) Limit(limit int) db.QueryUpdater {
+	qu.limit = limit
+	return qu
+}
diff --git a/postgresql/database_test.go b/postgresql/database_test.go
index 3b1f189d..2503efe4 100644
--- a/postgresql/database_test.go
+++ b/postgresql/database_test.go
@@ -2080,9 +2080,37 @@ func TestQueryBuilder(t *testing.T) {
 		t.Fatal(err)
 	}
 
+	// DELETE FROM artist WHERE name = 'Chavela Vargas' LIMIT 1
+	if _, err = b.DeleteFrom("artist").Where("name = ?", "Chavela Vargas").Limit(1).Exec(); err != nil {
+		t.Fatal(err)
+	}
+
+	// DELETE FROM artist WHERE id > 5
+	if _, err = b.DeleteFrom("artist").Where("id > 5").Exec(); err != nil {
+		t.Fatal(err)
+	}
+
+	// UPDATE artist SET name = ?
+	if _, err = b.Update("artist").Set("name", "Artist").Exec(); err != nil {
+		t.Fatal(err)
+	}
+
+	// UPDATE artist SET name = ? WHERE id < 5
+	if _, err = b.Update("artist").Set("name = ?", "Artist").Where("id < ?", 5).Exec(); err != nil {
+		t.Fatal(err)
+	}
+
+	// UPDATE artist SET name = ? || ' ' || ? || id, id = id + ? WHERE id > ?
+	if _, err = b.Update("artist").Set(
+		"name = ? || ' ' || ? || id", "Artist", "#",
+		"id = id + ?", 10,
+	).Where("id > ?", 0).Exec(); err != nil {
+		t.Fatal(err)
+	}
+
 	/*
 		// INSERT INTO artist (name) VALUES(? || ?)
-		if err = b.InsertInto("artist").Columns("name").Values(db.Raw("(? || ' ' || ?)"), "Tom", "Yorke").Exec(); err != nil {
+		if err = b.InsertInto("artist").Columns("name").Values(db.Expr("? || ' ' || ?", "Tom", "Yorke")).Exec(); err != nil {
 			t.Fatal(err)
 		}
 		// INSERT INTO artist ("name") VALUES('Michael Jackson')
diff --git a/util/sqlutil/convert.go b/util/sqlutil/convert.go
index 4504742a..1033a6eb 100644
--- a/util/sqlutil/convert.go
+++ b/util/sqlutil/convert.go
@@ -148,74 +148,111 @@ func (tu *TemplateWithUtils) ToInterfaceArguments(value interface{}) (args []int
 }
 
 // ToColumnValues converts the given db.Cond into a sqlgen.ColumnValues struct.
-func (tu *TemplateWithUtils) ToColumnValues(cond db.Cond) (ToColumnValues sqlgen.ColumnValues, args []interface{}) {
-
+func (tu *TemplateWithUtils) ToColumnValues(term interface{}) (cv sqlgen.ColumnValues, args []interface{}) {
 	args = []interface{}{}
 
-	for column, value := range cond {
-		columnValue := sqlgen.ColumnValue{}
+	switch t := term.(type) {
+	case []interface{}:
+		l := len(t)
+		for i := 0; i < l; i++ {
+			column := t[i].(string)
 
-		// Guessing operator from input, or using a default one.
-		column := strings.TrimSpace(column)
-		chunks := strings.SplitN(column, ` `, 2)
+			if !strings.ContainsAny(column, "=") {
+				column = fmt.Sprintf("%s = ?", column)
+			}
 
-		columnValue.Column = sqlgen.ColumnWithName(chunks[0])
+			chunks := strings.SplitN(column, "=", 2)
 
-		if len(chunks) > 1 {
-			columnValue.Operator = chunks[1]
-		}
+			column = chunks[0]
+			format := strings.TrimSpace(chunks[1])
 
-		switch value := value.(type) {
-		case db.Func:
-			v := tu.ToInterfaceArguments(value.Args)
-			columnValue.Operator = value.Name
+			columnValue := sqlgen.ColumnValue{
+				Column:   sqlgen.ColumnWithName(column),
+				Operator: "=",
+				Value:    sqlgen.RawValue(format),
+			}
 
-			if v == nil {
-				// A function with no arguments.
-				columnValue.Value = sqlgen.RawValue(`()`)
+			ps := strings.Count(format, "?")
+			if i+ps < l {
+				for j := 0; j < ps; j++ {
+					args = append(args, t[i+j+1])
+				}
+				i = i + ps
 			} else {
-				// A function with one or more arguments.
-				columnValue.Value = sqlgen.RawValue(fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1)))
+				panic(fmt.Sprintf("Format string %q has more placeholders than given arguments.", format))
 			}
 
-			args = append(args, v...)
-		default:
-			v := tu.ToInterfaceArguments(value)
+			cv.ColumnValues = append(cv.ColumnValues, &columnValue)
+		}
+		return cv, args
+		// Return error.
+	case db.Cond:
+		for column, value := range t {
+			columnValue := sqlgen.ColumnValue{}
 
-			if v == nil {
-				// Nil value given.
-				columnValue.Value = sqlNull
-				if columnValue.Operator == "" {
-					columnValue.Operator = sqlIsOperator
-				}
-			} else {
-				if len(v) > 1 {
-					// Array value given.
+			// Guessing operator from input, or using a default one.
+			column := strings.TrimSpace(column)
+			chunks := strings.SplitN(column, ` `, 2)
+
+			columnValue.Column = sqlgen.ColumnWithName(chunks[0])
+
+			if len(chunks) > 1 {
+				columnValue.Operator = chunks[1]
+			}
+
+			switch value := value.(type) {
+			case db.Func:
+				v := tu.ToInterfaceArguments(value.Args)
+				columnValue.Operator = value.Name
+
+				if v == nil {
+					// A function with no arguments.
+					columnValue.Value = sqlgen.RawValue(`()`)
+				} else {
+					// A function with one or more arguments.
 					columnValue.Value = sqlgen.RawValue(fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1)))
+				}
+
+				args = append(args, v...)
+			default:
+				v := tu.ToInterfaceArguments(value)
+
+				if v == nil {
+					// Nil value given.
+					columnValue.Value = sqlNull
 					if columnValue.Operator == "" {
-						columnValue.Operator = sqlInOperator
+						columnValue.Operator = sqlIsOperator
 					}
 				} else {
-					// Single value given.
-					columnValue.Value = sqlPlaceholder
+					if len(v) > 1 {
+						// Array value given.
+						columnValue.Value = sqlgen.RawValue(fmt.Sprintf(`(?%s)`, strings.Repeat(`, ?`, len(v)-1)))
+						if columnValue.Operator == "" {
+							columnValue.Operator = sqlInOperator
+						}
+					} else {
+						// Single value given.
+						columnValue.Value = sqlPlaceholder
+					}
+					args = append(args, v...)
 				}
-				args = append(args, v...)
 			}
-		}
 
-		// Using guessed operator if no operator was given.
-		if columnValue.Operator == "" {
-			if tu.DefaultOperator != "" {
-				columnValue.Operator = tu.DefaultOperator
-			} else {
-				columnValue.Operator = sqlDefaultOperator
+			// Using guessed operator if no operator was given.
+			if columnValue.Operator == "" {
+				if tu.DefaultOperator != "" {
+					columnValue.Operator = tu.DefaultOperator
+				} else {
+					columnValue.Operator = sqlDefaultOperator
+				}
 			}
-		}
 
-		ToColumnValues.ColumnValues = append(ToColumnValues.ColumnValues, &columnValue)
+			cv.ColumnValues = append(cv.ColumnValues, &columnValue)
+		}
+		return cv, args
 	}
 
-	return ToColumnValues, args
+	panic("Unknown map type.")
 }
 
 // ToColumnsValuesAndArguments maps the given columnNames and columnValues into
-- 
GitLab