diff --git a/util/sqlgen/column.go b/util/sqlgen/column.go index 47d095eac3336cbdeb4c57e3ebc12b352a6e5c9b..d529b5ca639efb62f4eafbe2e209a78c6570a47e 100644 --- a/util/sqlgen/column.go +++ b/util/sqlgen/column.go @@ -12,7 +12,7 @@ func (self Column) String() string { chunks := strings.Split(self.v, sqlColumnSeparator) for i := range chunks { - chunks[i] = mustParse(sqlEscape, Raw{chunks[i]}) + chunks[i] = mustParse(sqlIdentifierQuote, Raw{chunks[i]}) } return strings.Join(chunks, sqlColumnSeparator) diff --git a/util/sqlgen/column_value.go b/util/sqlgen/column_value.go index 8cbdc6d27f3fe20fddfdea64dba820730b768f1d..4a18a2912a023c385e2c9fec31f399e8105cdb6b 100644 --- a/util/sqlgen/column_value.go +++ b/util/sqlgen/column_value.go @@ -25,5 +25,5 @@ func (self ColumnValues) String() string { out[i] = self[i].String() } - return strings.Join(out, sqlColumnComma) + return strings.Join(out, sqlIdentifierSeparator) } diff --git a/util/sqlgen/column_value_test.go b/util/sqlgen/column_value_test.go index 538c9ac29ea145b002e26d81bccd86957d8309bc..2716388e85796bd3e40b6419332fe65001d4ff9d 100644 --- a/util/sqlgen/column_value_test.go +++ b/util/sqlgen/column_value_test.go @@ -11,7 +11,7 @@ func TestColumnValue(t *testing.T) { cv = ColumnValue{Column{"id"}, "=", Value{1}} s = cv.String() - e = `"id" = "1"` + e = `"id" = '1'` if s != e { t.Fatalf("Got: %s, Expecting: %s", s, e) @@ -40,7 +40,7 @@ func TestColumnValues(t *testing.T) { } s = cvs.String() - e = `"id" > "8", "other"."id" < 100, "name" = "Haruki Murakami", "created" >= NOW(), "modified" <= NOW()` + e = `"id" > '8', "other"."id" < 100, "name" = 'Haruki Murakami', "created" >= NOW(), "modified" <= NOW()` if s != e { t.Fatalf("Got: %s, Expecting: %s", s, e) diff --git a/util/sqlgen/columns.go b/util/sqlgen/columns.go index 4177996830105e23ad8b289593e6a2559d75cc10..a661b6a006a711bdd3a741370a4baa61ce1101b3 100644 --- a/util/sqlgen/columns.go +++ b/util/sqlgen/columns.go @@ -16,7 +16,7 @@ func (self Columns) String() string { out[i] = self[i].String() } - return strings.Join(out, sqlColumnComma) + return strings.Join(out, sqlIdentifierSeparator) } return "" } diff --git a/util/sqlgen/database.go b/util/sqlgen/database.go index 281bd904f73d45cb49c7830e56252c41d903df32..8ea234fe676a90d8f7e6f27bd0caacc372b24a7c 100644 --- a/util/sqlgen/database.go +++ b/util/sqlgen/database.go @@ -9,5 +9,5 @@ type Database struct { } func (self Database) String() string { - return mustParse(sqlEscape, Raw{fmt.Sprintf(`%v`, self.v)}) + return mustParse(sqlIdentifierQuote, Raw{fmt.Sprintf(`%v`, self.v)}) } diff --git a/util/sqlgen/main.go b/util/sqlgen/main.go index 2d610f2bd85334043d2fbdb262c07402c2b51d23..7076931372854f23efb1ecb426c8b8569a721956 100644 --- a/util/sqlgen/main.go +++ b/util/sqlgen/main.go @@ -2,21 +2,38 @@ package sqlgen import ( "bytes" - "fmt" "text/template" ) const ( - sqlColumnSeparator = `.` - sqlColumnComma = `, ` - sqlValueComma = `, ` - sqlEscape = `"{{.Raw}}"` + sqlColumnSeparator = `.` + sqlIdentifierSeparator = `, ` + sqlIdentifierQuote = `"{{.Raw}}"` + sqlValueSeparator = `, ` + sqlValueQuote = `'{{.}}'` + + sqlAndKeyword = `AND` + sqlOrKeyword = `OR` + sqlNotKeyword = `NOT` + sqlDescKeyword = `DESC` + sqlAscKeyword = `ASC` + sqlDefaultOperator = `=` + sqlClauseGroup = `({{.}})` + sqlClauseOperator = ` {{.}} ` + sqlColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` sqlOrderByLayout = ` {{if .Columns}} ORDER BY {{.Columns}} {{.Sort}} {{end}} ` + + sqlWhereLayout = ` + {{if .Conds}} + WHERE {{.Conds}} + {{end}} + ` + sqlSelectLayout = ` SELECT @@ -28,9 +45,7 @@ const ( FROM {{.Table}} - {{if .Where}} - WHERE {{.Where}} - {{end}} + {{.Where}} {{.OrderBy}} @@ -45,26 +60,20 @@ const ( sqlDeleteLayout = ` DELETE FROM {{.Table}} - {{if .Where}} - WHERE {{.Where}} - {{end}} + {{.Where}} ` sqlUpdateLayout = ` UPDATE {{.Table}} SET {{.ColumnValues}} - {{if .Where}} - WHERE {{.Where}} - {{end}} + {{ .Where }} ` sqlSelectCountLayout = ` SELECT COUNT(1) AS _t FROM {{.Table}} - {{if .Where}} - WHERE {{.Where}} - {{end}} + {{.Where}} ` sqlInsertLayout = ` @@ -85,15 +94,6 @@ const ( sqlDropTableLayout = ` DROP TABLE {{.Table}} ` - - sqlAndKeyword = `AND` - sqlOrKeyword = `OR` - sqlDescKeyword = `DESC` - sqlAscKeyword = `ASC` - sqlDefaultOperator = `=` - sqlConditionGroup = `({{.}})` - - sqlColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` ) type Type uint @@ -120,7 +120,6 @@ func mustParse(text string, data interface{}) string { t := template.Must(template.New("").Parse(text)) if err := t.Execute(&b, data); err != nil { - fmt.Printf("data: %v\n", data) panic("t.Execute: " + err.Error()) } diff --git a/util/sqlgen/main_test.go b/util/sqlgen/main_test.go index 87892f4112d6cca3b4799bf676cc90b658d1be7d..c3431854e006ba31424552405f4344612b3f8f02 100644 --- a/util/sqlgen/main_test.go +++ b/util/sqlgen/main_test.go @@ -110,6 +110,26 @@ func TestSelectStarFrom(t *testing.T) { } } +func TestSelectArtistNameFrom(t *testing.T) { + var s, e string + var stmt Statement + + stmt = Statement{ + Type: SqlSelect, + Table: Table{"artist"}, + Columns: Columns{ + Column{"artist.name"}, + }, + } + + s = trim(stmt.Compile()) + e = `SELECT "artist"."name" FROM "artist"` + + if s != e { + t.Fatalf("Got: %s, Expecting: %s", s, e) + } +} + func TestSelectFieldsFrom(t *testing.T) { var s, e string var stmt Statement @@ -289,7 +309,7 @@ func TestSelectFieldsFromWhere(t *testing.T) { } s = trim(stmt.Compile()) - e = `SELECT "foo", "bar", "baz" FROM "table name" WHERE ("baz" = "99")` + e = `SELECT "foo", "bar", "baz" FROM "table name" WHERE ("baz" = '99')` if s != e { t.Fatalf("Got: %s, Expecting: %s", s, e) @@ -316,7 +336,7 @@ func TestSelectFieldsFromWhereLimitOffset(t *testing.T) { } s = trim(stmt.Compile()) - e = `SELECT "foo", "bar", "baz" FROM "table name" WHERE ("baz" = "99") LIMIT 10 OFFSET 23` + e = `SELECT "foo", "bar", "baz" FROM "table name" WHERE ("baz" = '99') LIMIT 10 OFFSET 23` if s != e { t.Fatalf("Got: %s, Expecting: %s", s, e) @@ -336,7 +356,7 @@ func TestDelete(t *testing.T) { } s = trim(stmt.Compile()) - e = `DELETE FROM "table name" WHERE ("baz" = "99")` + e = `DELETE FROM "table name" WHERE ("baz" = '99')` if s != e { t.Fatalf("Got: %s, Expecting: %s", s, e) @@ -359,7 +379,7 @@ func TestUpdate(t *testing.T) { } s = trim(stmt.Compile()) - e = `UPDATE "table name" SET "foo" = "76" WHERE ("baz" = "99")` + e = `UPDATE "table name" SET "foo" = '76' WHERE ("baz" = '99')` if s != e { t.Fatalf("Got: %s, Expecting: %s", s, e) @@ -378,7 +398,7 @@ func TestUpdate(t *testing.T) { } s = trim(stmt.Compile()) - e = `UPDATE "table name" SET "foo" = "76", "bar" = 88 WHERE ("baz" = "99")` + e = `UPDATE "table name" SET "foo" = '76', "bar" = 88 WHERE ("baz" = '99')` if s != e { t.Fatalf("Got: %s, Expecting: %s", s, e) @@ -405,7 +425,7 @@ func TestInsert(t *testing.T) { } s = trim(stmt.Compile()) - e = `INSERT INTO "table name" ("foo", "bar", "baz") VALUES ("1", "2", 3)` + e = `INSERT INTO "table name" ("foo", "bar", "baz") VALUES ('1', '2', 3)` if s != e { t.Fatalf("Got: %s, Expecting: %s", s, e) diff --git a/util/sqlgen/table.go b/util/sqlgen/table.go index 46720770c03e0e1a7d05b88e8813fcdfc1a63851..dc1db4799a473880b365abc8ad6228459791c160 100644 --- a/util/sqlgen/table.go +++ b/util/sqlgen/table.go @@ -9,5 +9,5 @@ type Table struct { } func (self Table) String() string { - return mustParse(sqlEscape, Raw{fmt.Sprintf(`%v`, self.v)}) + return mustParse(sqlIdentifierQuote, Raw{fmt.Sprintf(`%v`, self.v)}) } diff --git a/util/sqlgen/value.go b/util/sqlgen/value.go index 11be7586a031066f35073e38765e6354575266a8..f4d839eeeade25f961578f137a783d0fa322e990 100644 --- a/util/sqlgen/value.go +++ b/util/sqlgen/value.go @@ -15,8 +15,7 @@ func (self Value) String() string { if raw, ok := self.v.(Raw); ok { return raw.Raw } - - return mustParse(sqlEscape, Raw{fmt.Sprintf(`%v`, self.v)}) + return mustParse(sqlValueQuote, Raw{fmt.Sprintf(`%v`, self.v)}) } func (self Values) String() string { @@ -29,7 +28,7 @@ func (self Values) String() string { chunks = append(chunks, self[i].String()) } - return strings.Join(chunks, sqlValueComma) + return strings.Join(chunks, sqlValueSeparator) } return "" diff --git a/util/sqlgen/value_test.go b/util/sqlgen/value_test.go index 525e7eb061b52e3742837b95bc4dff4843fe40b4..bc97bd1e341dc2659c7e74cd45c07cb25e432dbc 100644 --- a/util/sqlgen/value_test.go +++ b/util/sqlgen/value_test.go @@ -11,7 +11,7 @@ func TestValue(t *testing.T) { val = Value{1} s = val.String() - e = `"1"` + e = `'1'` if s != e { t.Fatalf("Got: %s, Expecting: %s", s, e) @@ -38,7 +38,7 @@ func TestValues(t *testing.T) { } s = val.String() - e = `1, 2, "3"` + e = `1, 2, '3'` if s != e { t.Fatalf("Got: %s, Expecting: %s", s, e) diff --git a/util/sqlgen/where.go b/util/sqlgen/where.go index 830c02b0405a2b53a147798d9abfe91fa605175c..99dfae74bbcf1566991a13bac10f148a2da25d1b 100644 --- a/util/sqlgen/where.go +++ b/util/sqlgen/where.go @@ -10,16 +10,24 @@ type ( Where []interface{} ) +type conds struct { + Conds string +} + func (self Or) String() string { - return groupCondition(self, sqlOrKeyword) + return groupCondition(self, mustParse(sqlClauseOperator, sqlOrKeyword)) } func (self And) String() string { - return groupCondition(self, sqlAndKeyword) + return groupCondition(self, mustParse(sqlClauseOperator, sqlAndKeyword)) } func (self Where) String() string { - return groupCondition(self, sqlAndKeyword) + grouped := groupCondition(self, mustParse(sqlClauseOperator, sqlAndKeyword)) + if grouped != "" { + return mustParse(sqlWhereLayout, conds{grouped}) + } + return "" } func groupCondition(terms []interface{}, joinKeyword string) string { @@ -44,7 +52,7 @@ func groupCondition(terms []interface{}, joinKeyword string) string { } if len(chunks) > 0 { - return mustParse(sqlConditionGroup, strings.Join(chunks, " "+joinKeyword+" ")) + return mustParse(sqlClauseGroup, strings.Join(chunks, joinKeyword)) } return "" diff --git a/util/sqlgen/where_test.go b/util/sqlgen/where_test.go index 3605fc92cd21d38b95ee4ea40131d9276f799acd..df86ea32a5d2cf60394022156118a9ac79b426ab 100644 --- a/util/sqlgen/where_test.go +++ b/util/sqlgen/where_test.go @@ -15,7 +15,7 @@ func TestWhereAnd(t *testing.T) { } s = and.String() - e = `("id" > 8 AND "id" < 99 AND "name" = "John")` + e = `("id" > 8 AND "id" < 99 AND "name" = 'John')` if s != e { t.Fatalf("Got: %s, Expecting: %s", s, e) @@ -54,7 +54,7 @@ func TestWhereAndOr(t *testing.T) { } s = and.String() - e = `("id" > 8 AND "id" < 99 AND "name" = "John" AND ("last_name" = "Smith" OR "last_name" = "Reyes"))` + e = `("id" > 8 AND "id" < 99 AND "name" = 'John' AND ("last_name" = 'Smith' OR "last_name" = 'Reyes'))` if s != e { t.Fatalf("Got: %s, Expecting: %s", s, e) @@ -82,8 +82,8 @@ func TestWhereAndRawOrAnd(t *testing.T) { }, } - s = where.String() - e = `(("id" > 8 AND "id" < 99) AND "name" = "John" AND city_id = 728 AND ("last_name" = "Smith" OR "last_name" = "Reyes") AND ("age" > 18 AND "age" < 41))` + s = trim(where.String()) + e = `WHERE (("id" > 8 AND "id" < 99) AND "name" = 'John' AND city_id = 728 AND ("last_name" = 'Smith' OR "last_name" = 'Reyes') AND ("age" > 18 AND "age" < 41))` if s != e { t.Fatalf("Got: %s, Expecting: %s", s, e)