From 03193b47035437244ff86226b5a82af228b5e97d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net> Date: Thu, 1 Dec 2016 16:56:08 -0600 Subject: [PATCH] Interpolate subqueries with ? --- lib/sqlbuilder/builder_test.go | 49 ++++++++++++++++++++++++++++++++++ lib/sqlbuilder/convert.go | 7 +++-- 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/lib/sqlbuilder/builder_test.go b/lib/sqlbuilder/builder_test.go index e96f80ac..22f5d0f4 100644 --- a/lib/sqlbuilder/builder_test.go +++ b/lib/sqlbuilder/builder_test.go @@ -1,6 +1,7 @@ package sqlbuilder import ( + "fmt" "testing" "github.com/stretchr/testify/assert" @@ -619,6 +620,54 @@ func TestSelect(t *testing.T) { sel2.Arguments(), ) } + + { + sq := b. + Select("user_id"). + From("user_access"). + Where(db.Cond{"hub_id": 3}) + + sq.And(db.Cond{"role": []int{1, 2}}) + + assert.Equal( + `SELECT "user_id" FROM "user_access" WHERE ("hub_id" = $1 AND "role" IN ($2, $3))`, + sq.String(), + ) + + assert.Equal( + []interface{}{3, 1, 2}, + sq.Arguments(), + ) + + cond := db.Or( + db.Raw("a.id IN ?", sq), + ) + + cond.Or(db.Cond{"ml.mailing_list_id": []int{4, 5, 6}}) + + sel := b. + Select(db.Raw("DISTINCT ON(a.id) a.id"), db.Raw("COALESCE(NULLIF(ml.name,''), a.name) as name"), "a.email"). + From("mailing_list_recipients ml"). + FullJoin("accounts a").On("a.id = ml.user_id"). + Where(cond) + + search := "word" + sel.And(db.Or( + db.Raw("COALESCE(NULLIF(ml.name,''), a.name) ILIKE ?", fmt.Sprintf("%%%s%%", search)), + db.Cond{"a.email ILIKE": fmt.Sprintf("%%%s%%", search)}, + )) + + assert.Equal( + `SELECT DISTINCT ON(a.id) a.id, COALESCE(NULLIF(ml.name,''), a.name) as name, "a"."email" FROM "mailing_list_recipients" AS "ml" FULL JOIN "accounts" AS "a" ON (a.id = ml.user_id) WHERE ((a.id IN (SELECT "user_id" FROM "user_access" WHERE ("hub_id" = $1 AND "role" IN ($2, $3))) OR "ml"."mailing_list_id" IN ($4, $5, $6)) AND (COALESCE(NULLIF(ml.name,''), a.name) ILIKE $7 OR "a"."email" ILIKE $8))`, + sel.String(), + ) + + assert.Equal( + []interface{}{3, 1, 2, 4, 5, 6, `%word%`, `%word%`}, + sel.Arguments(), + ) + + } } func TestInsert(t *testing.T) { diff --git a/lib/sqlbuilder/convert.go b/lib/sqlbuilder/convert.go index b410281f..af03f883 100644 --- a/lib/sqlbuilder/convert.go +++ b/lib/sqlbuilder/convert.go @@ -41,8 +41,11 @@ func expandPlaceholders(in string, args ...interface{}) (string, []interface{}) } } else { if len(values) == 1 { - if rawValue, ok := values[0].(db.RawValue); ok { - k, values = rawValue.Raw(), nil + switch t := values[0].(type) { + case db.RawValue: + k, values = t.Raw(), nil + case *selector: + k, values = `(`+t.statement().Compile(t.stringer.t)+`)`, t.Arguments() } } else if len(values) == 0 { k = `NULL` -- GitLab