diff --git a/builder.go b/builder.go index 3dc52af6984d326f0d52a9e56467407447a9d9b7..7dc52ba9ce801e507e27ad86509b8ae38406207e 100644 --- a/builder.go +++ b/builder.go @@ -14,6 +14,8 @@ type QueryBuilder interface { InsertInto(table string) QueryInserter DeleteFrom(table string) QueryDeleter Update(table string) QueryUpdater + + Exec(query interface{}, args ...interface{}) (sql.Result, error) } type QuerySelector interface { diff --git a/builder/builder.go b/builder/builder.go index a895c84f79aaa76857a457068840cd5970951a28..1a9d026cd1f5c5a76f00b38cbe17d79c54d46e45 100644 --- a/builder/builder.go +++ b/builder/builder.go @@ -63,6 +63,15 @@ type Builder struct { t *sqlutil.TemplateWithUtils } +func (b *Builder) Exec(query interface{}, args ...interface{}) (sql.Result, error) { + switch q := query.(type) { + case *sqlgen.Statement: + return b.sess.Exec(q, args...) + default: + return nil, errors.New("Unsupported query type.") + } +} + func (b *Builder) TruncateTable(table string) db.QueryTruncater { qs := &QueryTruncater{ builder: b, diff --git a/internal/sqladapter/collection.go b/internal/sqladapter/collection.go new file mode 100644 index 0000000000000000000000000000000000000000..bd13ec55d82981c4b50f3e5169089f83536a7e32 --- /dev/null +++ b/internal/sqladapter/collection.go @@ -0,0 +1,39 @@ +package sqladapter + +import ( + "upper.io/db" + "upper.io/db/util/sqlutil/result" +) + +type Collection struct { + database Database + tableName string +} + +// NewCollection returns a collection with basic methods. +func NewCollection(d Database, tableName string) *Collection { + return &Collection{database: d, tableName: tableName} +} + +// Name returns the name of the table. +func (c *Collection) Name() string { + return c.tableName +} + +// Exists returns true if the collection exists. +func (c *Collection) Exists() bool { + if err := c.Database().TableExists(c.Name()); err != nil { + return false + } + return true +} + +// Find creates a result set with the given conditions. +func (c *Collection) Find(conds ...interface{}) db.Result { + return result.NewResult(c.Database().Builder(), c.Name(), conds) +} + +// Database returns the database session that backs the collection. +func (c *Collection) Database() Database { + return c.database +} diff --git a/internal/sqladapter/database.go b/internal/sqladapter/database.go new file mode 100644 index 0000000000000000000000000000000000000000..b12869dacaade927fc3e1bc93ba90b289ef438b5 --- /dev/null +++ b/internal/sqladapter/database.go @@ -0,0 +1,11 @@ +package sqladapter + +import ( + "upper.io/db" +) + +type Database interface { + db.Database + TableExists(name string) error + TablePrimaryKey(name string) ([]string, error) +} diff --git a/postgresql/collection.go b/postgresql/collection.go index d5a9fb377de1195c0b11f519d0cfe7b0ded8731e..14378bdf360f188f1f817e6354fac16ccfc680d6 100644 --- a/postgresql/collection.go +++ b/postgresql/collection.go @@ -28,22 +28,16 @@ import ( "upper.io/db" "upper.io/db/builder" + "upper.io/db/internal/sqladapter" "upper.io/db/util/sqlgen" - "upper.io/db/util/sqlutil/result" ) type table struct { - *database - name string + *sqladapter.Collection } var _ = db.Collection(&table{}) -// Find creates a result set with the given conditions. -func (t *table) Find(conds ...interface{}) db.Result { - return result.NewResult(t.database.Builder(), t.Name(), conds) -} - // Truncate deletes all rows from the table. func (t *table) Truncate() error { stmt := sqlgen.Statement{ @@ -52,7 +46,7 @@ func (t *table) Truncate() error { Extra: sqlgen.Extra("RESTART IDENTITY"), } - if _, err := t.database.Exec(&stmt); err != nil { + if _, err := t.Database().Builder().Exec(&stmt); err != nil { return err } return nil @@ -66,37 +60,35 @@ func (t *table) Append(item interface{}) (interface{}, error) { } var pKey []string - if pKey, err = t.database.getPrimaryKey(t.Name()); err != nil { + if pKey, err = t.Database().TablePrimaryKey(t.Name()); err != nil { if err != sql.ErrNoRows { - // Can't tell primary key. return nil, err } } - q := t.database.Builder().InsertInto(t.Name()). + q := t.Database().Builder().InsertInto(t.Name()). Columns(columnNames...). Values(columnValues...) - // No primary keys defined. if len(pKey) == 0 { - + // There is no primary key. var res sql.Result + if res, err = q.Exec(); err != nil { return nil, err } - // Attempt to use LastInsertId() (probably won't work, but the exec() - // succeeded, so the error from LastInsertId() is ignored). + // Attempt to use LastInsertId() (probably won't work, but the Exec() + // succeeded, so we can safely ignore the error from LastInsertId()). lastID, _ := res.LastInsertId() return lastID, nil } - // A primary key was found. + // Asking the database to return the primary key after insertion. q.Extra(fmt.Sprintf(`RETURNING "%s"`, strings.Join(pKey, `", "`))) var keyMap map[string]interface{} - if err = q.Iterator().One(&keyMap); err != nil { return nil, err } @@ -133,19 +125,10 @@ func (t *table) Append(item interface{}) (interface{}, error) { return id.(int64), nil } - // More than one key, no interface matched, let's return a map. + // This was a compound key and no interface matched it, let's return a map. return keyMap, nil } -// Exists returns true if the collection exists. -func (t *table) Exists() bool { - if err := t.database.tableExists(t.Name()); err != nil { - return false - } - return true -} - -// Name returns the name of the table or tables that form the collection. -func (t *table) Name() string { - return t.name +func newTable(d *database, name string) *table { + return &table{sqladapter.NewCollection(d, name)} } diff --git a/postgresql/database.go b/postgresql/database.go index 010c3507505e72515f5d91c57a41e4641441177a..2a7863495e6926e63f6a4f8820d02534c3364bcb 100644 --- a/postgresql/database.go +++ b/postgresql/database.go @@ -211,7 +211,7 @@ func (d *database) C(name string) db.Collection { return c } -// Collection returns a table by name. +// Collection returns the table that matches the given name. func (d *database) Collection(name string) (db.Collection, error) { if d.tx != nil { if d.tx.Done() { @@ -219,11 +219,11 @@ func (d *database) Collection(name string) (db.Collection, error) { } } - if err := d.tableExists(name); err != nil { + if err := d.TableExists(name); err != nil { return nil, err } - col := &table{database: d, name: name} + col := newTable(d, name) d.collectionsMu.Lock() d.collections[name] = col @@ -424,19 +424,13 @@ func (d *database) populateSchema() (err error) { return err } -func (d *database) tableExists(names ...string) error { - var row map[string]string - - for _, tableName := range names { - - if d.schema.HasTable(tableName) { - // We already know this table exists. - continue - } - +func (d *database) TableExists(name string) error { + if !d.schema.HasTable(name) { q := d.Builder().Select("table_name"). From("information_schema.tables"). - Where("table_catalog = ? AND table_name = ?", d.schema.Name, tableName) + Where("table_catalog = ? AND table_name = ?", d.schema.Name, name) + + var row map[string]string if err := q.Iterator().One(&row); err != nil { return db.ErrCollectionDoesNotExist @@ -447,10 +441,6 @@ func (d *database) tableExists(names ...string) error { } func (d *database) TableColumns(tableName string) ([]string, error) { - return d.tableColumns(tableName) -} - -func (d *database) tableColumns(tableName string) ([]string, error) { tableSchema := d.schema.Table(tableName) @@ -477,7 +467,7 @@ func (d *database) tableColumns(tableName string) ([]string, error) { return d.schema.TableInfo[tableName].Columns, nil } -func (d *database) getPrimaryKey(tableName string) ([]string, error) { +func (d *database) TablePrimaryKey(tableName string) ([]string, error) { tableSchema := d.schema.Table(tableName) if len(tableSchema.PrimaryKey) != 0 {