From ddbf174fc57a0f0b65461cfcafc939ecfa1b60c8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net>
Date: Tue, 16 Sep 2014 08:19:18 -0500
Subject: [PATCH] PostgreSQL: An error was missing after trying to use
 tx.Collection() after the transaction was completed.

---
 postgresql/database.go | 23 +++++++++++++++++------
 postgresql/tx.go       | 16 +++++++++++++---
 2 files changed, 30 insertions(+), 9 deletions(-)

diff --git a/postgresql/database.go b/postgresql/database.go
index f5d5a8ab..6668eba5 100644
--- a/postgresql/database.go
+++ b/postgresql/database.go
@@ -56,7 +56,7 @@ var (
 type source struct {
 	config  db.Settings
 	session *sql.DB
-	tx      *sql.Tx
+	tx      *tx
 	schema  *schema.DatabaseSchema
 }
 
@@ -162,7 +162,7 @@ func (self *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Resu
 	}
 
 	if self.tx != nil {
-		res, err = self.tx.Exec(query, args...)
+		res, err = self.tx.sqlTx.Exec(query, args...)
 	} else {
 		res, err = self.session.Exec(query, args...)
 	}
@@ -195,7 +195,7 @@ func (self *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sql.Ro
 	}
 
 	if self.tx != nil {
-		rows, err = self.tx.Query(query, args...)
+		rows, err = self.tx.sqlTx.Query(query, args...)
 	} else {
 		rows, err = self.session.Query(query, args...)
 	}
@@ -228,7 +228,7 @@ func (self *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sql
 	}
 
 	if self.tx != nil {
-		row = self.tx.QueryRow(query, args...)
+		row = self.tx.sqlTx.QueryRow(query, args...)
 	} else {
 		row = self.session.QueryRow(query, args...)
 	}
@@ -275,9 +275,9 @@ func (self *source) Transaction() (db.Tx, error) {
 		return nil, err
 	}
 
-	tx := &tx{clone}
+	tx := &tx{source: clone, sqlTx: sqlTx}
 
-	clone.tx = sqlTx
+	clone.tx = tx
 
 	return tx, nil
 }
@@ -433,6 +433,11 @@ func (self *source) tableExists(names ...string) error {
 
 	for i := range names {
 
+		if self.schema.HasTable(names[i]) {
+			// We already know this table exists.
+			continue
+		}
+
 		stmt = sqlgen.Statement{
 			Type:  sqlgen.SqlSelect,
 			Table: sqlgen.Table{`information_schema.tables`},
@@ -521,6 +526,12 @@ func (self *source) Collection(names ...string) (db.Collection, error) {
 		return nil, db.ErrMissingCollectionName
 	}
 
+	if self.tx != nil {
+		if self.tx.done {
+			return nil, sql.ErrTxDone
+		}
+	}
+
 	col := &table{
 		source: self,
 		names:  names,
diff --git a/postgresql/tx.go b/postgresql/tx.go
index 66f40581..964b2e11 100644
--- a/postgresql/tx.go
+++ b/postgresql/tx.go
@@ -21,14 +21,24 @@
 
 package postgresql
 
+import (
+	"database/sql"
+)
+
 type tx struct {
 	*source
+	sqlTx *sql.Tx
+	done  bool
 }
 
-func (self *tx) Commit() error {
-	return self.source.tx.Commit()
+func (self *tx) Commit() (err error) {
+	err = self.sqlTx.Commit()
+	if err == nil {
+		self.done = true
+	}
+	return err
 }
 
 func (self *tx) Rollback() error {
-	return self.source.tx.Rollback()
+	return self.sqlTx.Rollback()
 }
-- 
GitLab