diff --git a/postgresql/collection.go b/postgresql/collection.go index bb4ecdea2b8aba7aeff2e841d17fad8182538cbc..927df7834aeb4d718b95174201d34f9f20e88bed 100644 --- a/postgresql/collection.go +++ b/postgresql/collection.go @@ -38,8 +38,9 @@ const defaultOperator = `=` type table struct { sqlutil.T - source *source - names []string + source *source + primaryKey string + names []string } func whereValues(term interface{}) (where sqlgen.Where, args []interface{}) { @@ -214,11 +215,12 @@ func (self *table) Truncate() error { // Appends an item (map or struct) into the collection. func (self *table) Append(item interface{}) (interface{}, error) { - - cols, vals, err := self.FieldValues(item, toInternal) - + var pKey string var columns sqlgen.Columns var values sqlgen.Values + var id int64 + + cols, vals, err := self.FieldValues(item, toInternal) for _, col := range cols { columns = append(columns, sqlgen.Column{col}) @@ -233,36 +235,54 @@ func (self *table) Append(item interface{}) (interface{}, error) { return nil, err } - var extra string - - //if _, ok := self.ColumnTypes[self.PrimaryKey]; ok == true { - // extra = fmt.Sprintf(`RETURNING %s`, self.PrimaryKey) - //} + if pKey, err = self.source.getPrimaryKey(self.tableN(0)); err != nil { + if err != sql.ErrNoRows { + // Can't tell primary key. + return nil, err + } + } - row, err := self.source.doQueryRow(sqlgen.Statement{ + stmt := sqlgen.Statement{ Type: sqlgen.SqlInsert, Table: sqlgen.Table{self.tableN(0)}, Columns: columns, Values: values, - Extra: sqlgen.Extra(extra), - }, vals...) - - if err != nil { - return nil, err } - var id int64 + if pKey == "" { + // No primary key found. + var res sql.Result + if res, err = self.source.doExec(stmt, vals...); err != nil { + return nil, err + } + + // Attempt to use LastInsertId() (probably won't work, but the exec() + // succeeded, so the error from LastInsertId() is ignored). + id, _ = res.LastInsertId() - if err = row.Scan(&id); err != nil { - if err == sql.ErrNoRows { - // Can't tell the row's id. Maybe there isn't any? - return nil, nil + return id, nil + } else { + var row *sql.Row + + // A primary key was found. + stmt.Extra = sqlgen.Extra(fmt.Sprintf(`RETURNING %s`, pKey)) + if row, err = self.source.doQueryRow(stmt, vals...); err != nil { + return nil, err } - // Other kind of error. - return nil, err + + // Retrieving key value. + if err = row.Scan(&id); err != nil { + if err == sql.ErrNoRows { + // Can't tell the row's id. Maybe there isn't any? + return nil, nil + } + // Other kind of error. + return nil, err + } + return id, nil } - return id, nil + return nil, nil } // Returns true if the collection exists. diff --git a/postgresql/database.go b/postgresql/database.go index f983970150ed6e477c487eaec36452b718bc06f4..cff0f7ef9cd9922f46bb78f9b4ce7cecfea76d7b 100644 --- a/postgresql/database.go +++ b/postgresql/database.go @@ -56,6 +56,7 @@ type source struct { session *sql.DB collections map[string]db.Collection tx *sql.Tx + primaryKeys map[string]string } type columnSchema_t struct { @@ -210,7 +211,7 @@ func (self *source) Ping() error { } func (self *source) clone() (*source, error) { - src := &source{} + src := new(source) src.Setup(self.config) if err := src.Open(); err != nil { @@ -386,6 +387,8 @@ func (self *source) tableExists(names ...string) error { // Returns a collection instance by name. func (self *source) Collection(names ...string) (db.Collection, error) { + var rows *sql.Rows + var err error if len(names) == 0 { return nil, db.ErrMissingCollectionName @@ -409,7 +412,8 @@ func (self *source) Collection(names ...string) (db.Collection, error) { return nil, err } - rows, err := self.doQuery(sqlgen.Statement{ + // Getting columns + rows, err = self.doQuery(sqlgen.Statement{ Type: sqlgen.SqlSelect, Table: sqlgen.Table{`information_schema.columns`}, Columns: sqlgen.Columns{ @@ -435,9 +439,55 @@ func (self *source) Collection(names ...string) (db.Collection, error) { for _, column := range columns_t { col.Columns = append(col.Columns, strings.ToLower(column.Name)) } - } } return col, nil } + +func (self *source) getPrimaryKey(tableName string) (string, error) { + var row *sql.Row + var err error + var pKey string + + if self.primaryKeys == nil { + self.primaryKeys = make(map[string]string) + } + + if pKey, ok := self.primaryKeys[tableName]; ok { + // Retrieving cached key name. + return pKey, nil + } + + // Getting primary key. See https://github.com/upper/db/issues/24. + row, err = self.doQueryRow(sqlgen.Statement{ + Type: sqlgen.SqlSelect, + Table: sqlgen.Table{`pg_index, pg_class, pg_attribute`}, + Columns: sqlgen.Columns{ + {`pg_attribute.attname`}, + }, + Where: sqlgen.Where{ + sqlgen.ColumnValue{sqlgen.Column{`pg_class.oid`}, `=`, sqlgen.Value{sqlgen.Raw{`?::regclass`}}}, + sqlgen.ColumnValue{sqlgen.Column{`indrelid`}, `=`, sqlgen.Value{sqlgen.Raw{`pg_class.oid`}}}, + sqlgen.ColumnValue{sqlgen.Column{`pg_attribute.attrelid`}, `=`, sqlgen.Value{sqlgen.Raw{`pg_class.oid`}}}, + sqlgen.ColumnValue{sqlgen.Column{`pg_attribute.attnum`}, `=`, sqlgen.Value{sqlgen.Raw{`any(pg_index.indkey)`}}}, + sqlgen.Raw{`indisprimary`}, + }, + Limit: 1, + }, tableName) + + if err != nil { + return "", err + } + + if err = row.Scan(&pKey); err != nil { + return "", err + } + + // Caching key name. + // TODO: There is currently no policy for cache life and no cache-cleaning + // methods are provided. + self.primaryKeys[tableName] = pKey + + return pKey, nil +}