diff --git a/ql/collection.go b/ql/collection.go index c2a64590adaa445609913bd68f6cb59ad112ea22..b7d799562776b5db8d1d3c01ce97d7309a0cf443 100644 --- a/ql/collection.go +++ b/ql/collection.go @@ -25,8 +25,8 @@ import ( "database/sql" "upper.io/db.v2" - "upper.io/db.v2/sqlbuilder" "upper.io/db.v2/internal/sqladapter" + "upper.io/db.v2/sqlbuilder" ) // table is the actual implementation of a collection. @@ -63,7 +63,7 @@ func (r *resultProxy) Select(fields ...interface{}) db.Result { var columns []struct { Name string `db:"Name"` } - err := r.t.d.Builder.Select("Name"). + err := r.t.d.Select("Name"). From("__Column"). Where("TableName", r.t.Name()). Iterator().All(&columns) diff --git a/ql/database.go b/ql/database.go index 541d4e5cd683d979603f2f54d045107487257d54..9d7b4a30f4671d4ec5850f43b56192fee5466058 100644 --- a/ql/database.go +++ b/ql/database.go @@ -24,29 +24,23 @@ package ql import ( "database/sql" "errors" + "sync" "sync/atomic" _ "github.com/cznic/ql/driver" // QL driver "upper.io/db.v2" + "upper.io/db.v2/internal/sqladapter" "upper.io/db.v2/sqlbuilder" "upper.io/db.v2/sqlbuilder/exql" - "upper.io/db.v2/internal/sqladapter" ) -// Database represents a SQL database. -type Database interface { - db.Database - builder.Builder - - NewTransaction() (Tx, error) -} - // database is the actual implementation of Database type database struct { sqladapter.BaseDatabase // Leveraged by sqladapter - builder.Builder + db.SQLBuilder connURL db.ConnectionURL + txMu sync.Mutex } var ( @@ -56,8 +50,7 @@ var ( ) var ( - _ = sqladapter.Database(&database{}) - _ = db.Database(&database{}) + _ = db.SQLDatabase(&database{}) ) // newDatabase binds *database with sqladapter and the SQL builer. @@ -69,7 +62,7 @@ func newDatabase(settings db.ConnectionURL) (*database, error) { } // Open stablishes a new connection to a SQL server. -func Open(settings db.ConnectionURL) (Database, error) { +func Open(settings db.ConnectionURL) (db.SQLDatabase, error) { d, err := newDatabase(settings) if err != nil { return nil, err @@ -93,7 +86,7 @@ func (d *database) ConnectionURL() db.ConnectionURL { return d.connURL } -// Open attempts to open a connection to the database server. +// Open stablishes a new connection with the SQL server. func (d *database) Open(connURL db.ConnectionURL) error { if connURL == nil { return db.ErrMissingConnURL @@ -102,18 +95,66 @@ func (d *database) Open(connURL db.ConnectionURL) error { return d.open() } -// NewTransaction starts a transaction block. -func (d *database) NewTransaction() (Tx, error) { +// NewTx returns a transaction session. +func NewTx(sqlTx *sql.Tx) (db.SQLTx, error) { + d, err := newDatabase(nil) + if err != nil { + return nil, err + } + + // Binding with sqladapter's logic. + d.BaseDatabase = sqladapter.NewBaseDatabase(d) + + // Binding with builder. + b, err := builder.New(d.BaseDatabase, template) + if err != nil { + return nil, err + } + d.SQLBuilder = b + + if err := d.BaseDatabase.BindTx(sqlTx); err != nil { + return nil, err + } + + newTx := sqladapter.NewTx(d) + return &tx{DatabaseTx: newTx}, nil +} + +// New wraps the given *sql.DB session and creates a new db session. +func New(sess *sql.DB) (db.SQLDatabase, error) { + d, err := newDatabase(nil) + if err != nil { + return nil, err + } + + // Binding with sqladapter's logic. + d.BaseDatabase = sqladapter.NewBaseDatabase(d) + + // Binding with builder. + b, err := builder.New(d.BaseDatabase, template) + if err != nil { + return nil, err + } + d.SQLBuilder = b + + if err := d.BaseDatabase.BindSession(sess); err != nil { + return nil, err + } + return d, nil +} + +// NewTx starts a transaction block. +func (d *database) NewTx() (db.SQLTx, error) { nTx, err := d.NewLocalTransaction() if err != nil { return nil, err } - return &tx{Tx: nTx}, nil + return &tx{DatabaseTx: nTx}, nil } // Collections returns a list of non-system tables from the database. func (d *database) Collections() (collections []string, err error) { - q := d.Builder.Select("Name"). + q := d.Select("Name"). From("__Table") iter := q.Iterator() @@ -139,7 +180,7 @@ func (d *database) open() error { if err != nil { return err } - d.Builder = b + d.SQLBuilder = b openFn := func() error { openFiles := atomic.LoadInt32(&fileOpenCount) @@ -195,7 +236,7 @@ func (d *database) Err(err error) error { // StatementExec wraps the statement to execute around a transaction. func (d *database) StatementExec(stmt *sql.Stmt, args ...interface{}) (sql.Result, error) { - if d.BaseDatabase.Tx() == nil { + if d.BaseDatabase.Transaction() == nil { var tx *sql.Tx var res sql.Result var err error @@ -224,13 +265,22 @@ func (d *database) NewLocalCollection(name string) db.Collection { return newTable(d, name) } +// Tx creates a transaction and passes it to the given function, if if the +// function returns no error then the transaction is commited. +func (d *database) Tx(fn func(tx db.SQLTx) error) error { + return sqladapter.RunTx(d, fn) +} + // NewLocalTransaction allows sqladapter start a transaction block. -func (d *database) NewLocalTransaction() (sqladapter.Tx, error) { +func (d *database) NewLocalTransaction() (sqladapter.DatabaseTx, error) { clone, err := d.clone() if err != nil { return nil, err } + clone.txMu.Lock() + defer clone.txMu.Unlock() + openFn := func() error { sqlTx, err := clone.BaseDatabase.Session().Begin() if err == nil { @@ -258,7 +308,7 @@ func (d *database) FindDatabaseName() (string, error) { // TableExists allows sqladapter check whether a table exists and returns an // error in case it doesn't. func (d *database) TableExists(name string) error { - q := d.Builder.Select("Name"). + q := d.SQLBuilder.Select("Name"). From("__Table"). Where("Name == ?", name) diff --git a/ql/ql.go b/ql/ql.go index db3a11458729e03663899194209ba1c076a71c4a..0622dc9b6a25d9556957ab15da33d241e99a3733 100644 --- a/ql/ql.go +++ b/ql/ql.go @@ -31,5 +31,9 @@ const sqlDriver = `ql` const Adapter = sqlDriver func init() { - db.Register(Adapter, &database{}) + db.RegisterSQLAdapter(Adapter, &db.SQLAdapterFuncMap{ + New: New, + NewTx: NewTx, + Open: Open, + }) } diff --git a/ql/template_test.go b/ql/template_test.go index 4aecd12e9261371d54dcf1c6a5eed20238b49b81..59b5bb815df069c22f65d2cf0c93f2d04e1861b4 100644 --- a/ql/template_test.go +++ b/ql/template_test.go @@ -9,7 +9,7 @@ import ( ) func TestTemplateSelect(t *testing.T) { - b := builder.NewBuilderWithTemplate(template) + b := builder.NewSQLBuilder(template) assert := assert.New(t) assert.Equal( @@ -147,7 +147,7 @@ func TestTemplateSelect(t *testing.T) { } func TestTemplateInsert(t *testing.T) { - b := builder.NewBuilderWithTemplate(template) + b := builder.NewSQLBuilder(template) assert := assert.New(t) assert.Equal( @@ -189,7 +189,7 @@ func TestTemplateInsert(t *testing.T) { } func TestTemplateUpdate(t *testing.T) { - b := builder.NewBuilderWithTemplate(template) + b := builder.NewSQLBuilder(template) assert := assert.New(t) assert.Equal( @@ -231,7 +231,7 @@ func TestTemplateUpdate(t *testing.T) { } func TestTemplateDelete(t *testing.T) { - b := builder.NewBuilderWithTemplate(template) + b := builder.NewSQLBuilder(template) assert := assert.New(t) assert.Equal( diff --git a/ql/tx.go b/ql/tx.go index 36acd816cb67da28f47eadf214243bf4ddc4de7c..00a42fdf0bf0a403505cdedf0c3ab8410ca59943 100644 --- a/ql/tx.go +++ b/ql/tx.go @@ -26,22 +26,10 @@ import ( "upper.io/db.v2/internal/sqladapter" ) -// Tx represents a transaction. -type Tx interface { - Database - - Commit() error - Rollback() error -} - type tx struct { - sqladapter.Tx + sqladapter.DatabaseTx } var ( - _ = db.Tx(&tx{}) + _ = db.SQLTx(&tx{}) ) - -func (t *tx) NewTransaction() (Tx, error) { - return t, db.ErrAlreadyWithinTransaction -}