diff --git a/postgresql/database.go b/postgresql/database.go index f5d5a8abf0fe8a837f1e24c083b86ecfc674ee3e..6668eba535c919703e311eb98baf50760ab5d1ee 100644 --- a/postgresql/database.go +++ b/postgresql/database.go @@ -56,7 +56,7 @@ var ( type source struct { config db.Settings session *sql.DB - tx *sql.Tx + tx *tx schema *schema.DatabaseSchema } @@ -162,7 +162,7 @@ func (self *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Resu } if self.tx != nil { - res, err = self.tx.Exec(query, args...) + res, err = self.tx.sqlTx.Exec(query, args...) } else { res, err = self.session.Exec(query, args...) } @@ -195,7 +195,7 @@ func (self *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sql.Ro } if self.tx != nil { - rows, err = self.tx.Query(query, args...) + rows, err = self.tx.sqlTx.Query(query, args...) } else { rows, err = self.session.Query(query, args...) } @@ -228,7 +228,7 @@ func (self *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sql } if self.tx != nil { - row = self.tx.QueryRow(query, args...) + row = self.tx.sqlTx.QueryRow(query, args...) } else { row = self.session.QueryRow(query, args...) } @@ -275,9 +275,9 @@ func (self *source) Transaction() (db.Tx, error) { return nil, err } - tx := &tx{clone} + tx := &tx{source: clone, sqlTx: sqlTx} - clone.tx = sqlTx + clone.tx = tx return tx, nil } @@ -433,6 +433,11 @@ func (self *source) tableExists(names ...string) error { for i := range names { + if self.schema.HasTable(names[i]) { + // We already know this table exists. + continue + } + stmt = sqlgen.Statement{ Type: sqlgen.SqlSelect, Table: sqlgen.Table{`information_schema.tables`}, @@ -521,6 +526,12 @@ func (self *source) Collection(names ...string) (db.Collection, error) { return nil, db.ErrMissingCollectionName } + if self.tx != nil { + if self.tx.done { + return nil, sql.ErrTxDone + } + } + col := &table{ source: self, names: names, diff --git a/postgresql/tx.go b/postgresql/tx.go index 66f405819567906b7bb227e3900e6d13a761f768..964b2e1160194f3c71f23593406f87b9cf326d2a 100644 --- a/postgresql/tx.go +++ b/postgresql/tx.go @@ -21,14 +21,24 @@ package postgresql +import ( + "database/sql" +) + type tx struct { *source + sqlTx *sql.Tx + done bool } -func (self *tx) Commit() error { - return self.source.tx.Commit() +func (self *tx) Commit() (err error) { + err = self.sqlTx.Commit() + if err == nil { + self.done = true + } + return err } func (self *tx) Rollback() error { - return self.source.tx.Rollback() + return self.sqlTx.Rollback() }