From 0f065e510a33d9f1b952f097ab33b94725e899fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net> Date: Tue, 6 Oct 2015 04:16:16 -0500 Subject: [PATCH] Moving collection's shared logic to an internal package. --- builder.go | 2 ++ builder/builder.go | 9 +++++++ internal/sqladapter/collection.go | 39 ++++++++++++++++++++++++++++ internal/sqladapter/database.go | 11 ++++++++ postgresql/collection.go | 43 ++++++++++--------------------- postgresql/database.go | 28 +++++++------------- 6 files changed, 83 insertions(+), 49 deletions(-) create mode 100644 internal/sqladapter/collection.go create mode 100644 internal/sqladapter/database.go diff --git a/builder.go b/builder.go index 3dc52af6..7dc52ba9 100644 --- a/builder.go +++ b/builder.go @@ -14,6 +14,8 @@ type QueryBuilder interface { InsertInto(table string) QueryInserter DeleteFrom(table string) QueryDeleter Update(table string) QueryUpdater + + Exec(query interface{}, args ...interface{}) (sql.Result, error) } type QuerySelector interface { diff --git a/builder/builder.go b/builder/builder.go index a895c84f..1a9d026c 100644 --- a/builder/builder.go +++ b/builder/builder.go @@ -63,6 +63,15 @@ type Builder struct { t *sqlutil.TemplateWithUtils } +func (b *Builder) Exec(query interface{}, args ...interface{}) (sql.Result, error) { + switch q := query.(type) { + case *sqlgen.Statement: + return b.sess.Exec(q, args...) + default: + return nil, errors.New("Unsupported query type.") + } +} + func (b *Builder) TruncateTable(table string) db.QueryTruncater { qs := &QueryTruncater{ builder: b, diff --git a/internal/sqladapter/collection.go b/internal/sqladapter/collection.go new file mode 100644 index 00000000..bd13ec55 --- /dev/null +++ b/internal/sqladapter/collection.go @@ -0,0 +1,39 @@ +package sqladapter + +import ( + "upper.io/db" + "upper.io/db/util/sqlutil/result" +) + +type Collection struct { + database Database + tableName string +} + +// NewCollection returns a collection with basic methods. +func NewCollection(d Database, tableName string) *Collection { + return &Collection{database: d, tableName: tableName} +} + +// Name returns the name of the table. +func (c *Collection) Name() string { + return c.tableName +} + +// Exists returns true if the collection exists. +func (c *Collection) Exists() bool { + if err := c.Database().TableExists(c.Name()); err != nil { + return false + } + return true +} + +// Find creates a result set with the given conditions. +func (c *Collection) Find(conds ...interface{}) db.Result { + return result.NewResult(c.Database().Builder(), c.Name(), conds) +} + +// Database returns the database session that backs the collection. +func (c *Collection) Database() Database { + return c.database +} diff --git a/internal/sqladapter/database.go b/internal/sqladapter/database.go new file mode 100644 index 00000000..b12869da --- /dev/null +++ b/internal/sqladapter/database.go @@ -0,0 +1,11 @@ +package sqladapter + +import ( + "upper.io/db" +) + +type Database interface { + db.Database + TableExists(name string) error + TablePrimaryKey(name string) ([]string, error) +} diff --git a/postgresql/collection.go b/postgresql/collection.go index d5a9fb37..14378bdf 100644 --- a/postgresql/collection.go +++ b/postgresql/collection.go @@ -28,22 +28,16 @@ import ( "upper.io/db" "upper.io/db/builder" + "upper.io/db/internal/sqladapter" "upper.io/db/util/sqlgen" - "upper.io/db/util/sqlutil/result" ) type table struct { - *database - name string + *sqladapter.Collection } var _ = db.Collection(&table{}) -// Find creates a result set with the given conditions. -func (t *table) Find(conds ...interface{}) db.Result { - return result.NewResult(t.database.Builder(), t.Name(), conds) -} - // Truncate deletes all rows from the table. func (t *table) Truncate() error { stmt := sqlgen.Statement{ @@ -52,7 +46,7 @@ func (t *table) Truncate() error { Extra: sqlgen.Extra("RESTART IDENTITY"), } - if _, err := t.database.Exec(&stmt); err != nil { + if _, err := t.Database().Builder().Exec(&stmt); err != nil { return err } return nil @@ -66,37 +60,35 @@ func (t *table) Append(item interface{}) (interface{}, error) { } var pKey []string - if pKey, err = t.database.getPrimaryKey(t.Name()); err != nil { + if pKey, err = t.Database().TablePrimaryKey(t.Name()); err != nil { if err != sql.ErrNoRows { - // Can't tell primary key. return nil, err } } - q := t.database.Builder().InsertInto(t.Name()). + q := t.Database().Builder().InsertInto(t.Name()). Columns(columnNames...). Values(columnValues...) - // No primary keys defined. if len(pKey) == 0 { - + // There is no primary key. var res sql.Result + if res, err = q.Exec(); err != nil { return nil, err } - // Attempt to use LastInsertId() (probably won't work, but the exec() - // succeeded, so the error from LastInsertId() is ignored). + // Attempt to use LastInsertId() (probably won't work, but the Exec() + // succeeded, so we can safely ignore the error from LastInsertId()). lastID, _ := res.LastInsertId() return lastID, nil } - // A primary key was found. + // Asking the database to return the primary key after insertion. q.Extra(fmt.Sprintf(`RETURNING "%s"`, strings.Join(pKey, `", "`))) var keyMap map[string]interface{} - if err = q.Iterator().One(&keyMap); err != nil { return nil, err } @@ -133,19 +125,10 @@ func (t *table) Append(item interface{}) (interface{}, error) { return id.(int64), nil } - // More than one key, no interface matched, let's return a map. + // This was a compound key and no interface matched it, let's return a map. return keyMap, nil } -// Exists returns true if the collection exists. -func (t *table) Exists() bool { - if err := t.database.tableExists(t.Name()); err != nil { - return false - } - return true -} - -// Name returns the name of the table or tables that form the collection. -func (t *table) Name() string { - return t.name +func newTable(d *database, name string) *table { + return &table{sqladapter.NewCollection(d, name)} } diff --git a/postgresql/database.go b/postgresql/database.go index 010c3507..2a786349 100644 --- a/postgresql/database.go +++ b/postgresql/database.go @@ -211,7 +211,7 @@ func (d *database) C(name string) db.Collection { return c } -// Collection returns a table by name. +// Collection returns the table that matches the given name. func (d *database) Collection(name string) (db.Collection, error) { if d.tx != nil { if d.tx.Done() { @@ -219,11 +219,11 @@ func (d *database) Collection(name string) (db.Collection, error) { } } - if err := d.tableExists(name); err != nil { + if err := d.TableExists(name); err != nil { return nil, err } - col := &table{database: d, name: name} + col := newTable(d, name) d.collectionsMu.Lock() d.collections[name] = col @@ -424,19 +424,13 @@ func (d *database) populateSchema() (err error) { return err } -func (d *database) tableExists(names ...string) error { - var row map[string]string - - for _, tableName := range names { - - if d.schema.HasTable(tableName) { - // We already know this table exists. - continue - } - +func (d *database) TableExists(name string) error { + if !d.schema.HasTable(name) { q := d.Builder().Select("table_name"). From("information_schema.tables"). - Where("table_catalog = ? AND table_name = ?", d.schema.Name, tableName) + Where("table_catalog = ? AND table_name = ?", d.schema.Name, name) + + var row map[string]string if err := q.Iterator().One(&row); err != nil { return db.ErrCollectionDoesNotExist @@ -447,10 +441,6 @@ func (d *database) tableExists(names ...string) error { } func (d *database) TableColumns(tableName string) ([]string, error) { - return d.tableColumns(tableName) -} - -func (d *database) tableColumns(tableName string) ([]string, error) { tableSchema := d.schema.Table(tableName) @@ -477,7 +467,7 @@ func (d *database) tableColumns(tableName string) ([]string, error) { return d.schema.TableInfo[tableName].Columns, nil } -func (d *database) getPrimaryKey(tableName string) ([]string, error) { +func (d *database) TablePrimaryKey(tableName string) ([]string, error) { tableSchema := d.schema.Table(tableName) if len(tableSchema.PrimaryKey) != 0 { -- GitLab