diff --git a/internal/sqladapter/database.go b/internal/sqladapter/database.go index f38d9121e82e144a36a1bd5bade5b8e4d440ea89..4c6cd6d88806df48337f62f7f99ff7e09aea03dc 100644 --- a/internal/sqladapter/database.go +++ b/internal/sqladapter/database.go @@ -1,6 +1,7 @@ package sqladapter import ( + "context" "database/sql" "math" "strconv" @@ -27,7 +28,7 @@ type HasCleanUp interface { // HasStatementExec allows the adapter to have its own exec statement. type HasStatementExec interface { - StatementExec(query string, args ...interface{}) (sql.Result, error) + StatementExec(ctx context.Context, query string, args ...interface{}) (sql.Result, error) } // Database represents a SQL database. @@ -67,6 +68,8 @@ type BaseDatabase interface { Driver() interface{} WaitForConnection(func() error) error + Context() context.Context + WithContext(context.Context) Database BindSession(*sql.DB) error Session() *sql.DB @@ -97,6 +100,8 @@ type database struct { PartialDatabase baseTx BaseTx + ctx context.Context + collectionMu sync.Mutex databaseMu sync.Mutex @@ -285,7 +290,7 @@ func (d *database) Collection(name string) db.Collection { // StatementExec compiles and executes a statement that does not return any // rows. -func (d *database) StatementExec(stmt *exql.Statement, args ...interface{}) (res sql.Result, err error) { +func (d *database) StatementExec(ctx context.Context, stmt *exql.Statement, args ...interface{}) (res sql.Result, err error) { var query string if db.Conf.LoggingEnabled() { @@ -317,7 +322,7 @@ func (d *database) StatementExec(stmt *exql.Statement, args ...interface{}) (res if execer, ok := d.PartialDatabase.(HasStatementExec); ok { query, args = d.compileStatement(stmt, args) - res, err = execer.StatementExec(query, args...) + res, err = execer.StatementExec(ctx, query, args...) return } @@ -325,27 +330,27 @@ func (d *database) StatementExec(stmt *exql.Statement, args ...interface{}) (res if db.Conf.PreparedStatementCacheEnabled() && tx == nil { var p *Stmt - if p, query, args, err = d.prepareStatement(stmt, args); err != nil { + if p, query, args, err = d.prepareStatement(ctx, stmt, args); err != nil { return nil, err } defer p.Close() - res, err = p.Exec(args...) + res, err = p.ExecContext(ctx, args...) return } query, args = d.compileStatement(stmt, args) if tx != nil { - res, err = tx.(*sqlTx).Exec(query, args...) + res, err = tx.(*sqlTx).ExecContext(ctx, query, args...) return } - res, err = d.sess.Exec(query, args...) + res, err = d.sess.ExecContext(ctx, query, args...) return } // StatementQuery compiles and executes a statement that returns rows. -func (d *database) StatementQuery(stmt *exql.Statement, args ...interface{}) (rows *sql.Rows, err error) { +func (d *database) StatementQuery(ctx context.Context, stmt *exql.Statement, args ...interface{}) (rows *sql.Rows, err error) { var query string if db.Conf.LoggingEnabled() { @@ -366,29 +371,29 @@ func (d *database) StatementQuery(stmt *exql.Statement, args ...interface{}) (ro if db.Conf.PreparedStatementCacheEnabled() && tx == nil { var p *Stmt - if p, query, args, err = d.prepareStatement(stmt, args); err != nil { + if p, query, args, err = d.prepareStatement(ctx, stmt, args); err != nil { return nil, err } defer p.Close() - rows, err = p.Query(args...) + rows, err = p.QueryContext(ctx, args...) return } query, args = d.compileStatement(stmt, args) if tx != nil { - rows, err = tx.(*sqlTx).Query(query, args...) + rows, err = tx.(*sqlTx).QueryContext(ctx, query, args...) return } - rows, err = d.sess.Query(query, args...) + rows, err = d.sess.QueryContext(ctx, query, args...) return } // StatementQueryRow compiles and executes a statement that returns at most one // row. -func (d *database) StatementQueryRow(stmt *exql.Statement, args ...interface{}) (row *sql.Row, err error) { +func (d *database) StatementQueryRow(ctx context.Context, stmt *exql.Statement, args ...interface{}) (row *sql.Row, err error) { var query string if db.Conf.LoggingEnabled() { @@ -409,22 +414,22 @@ func (d *database) StatementQueryRow(stmt *exql.Statement, args ...interface{}) if db.Conf.PreparedStatementCacheEnabled() && tx == nil { var p *Stmt - if p, query, args, err = d.prepareStatement(stmt, args); err != nil { + if p, query, args, err = d.prepareStatement(ctx, stmt, args); err != nil { return nil, err } defer p.Close() - row = p.QueryRow(args...) + row = p.QueryRowContext(ctx, args...) return } query, args = d.compileStatement(stmt, args) if tx != nil { - row = tx.(*sqlTx).QueryRow(query, args...) + row = tx.(*sqlTx).QueryRowContext(ctx, query, args...) return } - row = d.sess.QueryRow(query, args...) + row = d.sess.QueryRowContext(ctx, query, args...) return } @@ -437,6 +442,21 @@ func (d *database) Driver() interface{} { return d.sess } +func (d *database) Context() context.Context { + if d.ctx == nil { + return context.Background() + } + return d.ctx +} + +func (d *database) WithContext(ctx context.Context) Database { + // TODO: Don't just copy this over. + var newDB *database + *newDB = *d + newDB.ctx = ctx + return newDB +} + // compileStatement compiles the given statement into a string. func (d *database) compileStatement(stmt *exql.Statement, args []interface{}) (string, []interface{}) { return d.PartialDatabase.CompileStatement(stmt, args) @@ -444,7 +464,7 @@ func (d *database) compileStatement(stmt *exql.Statement, args []interface{}) (s // prepareStatement compiles a query and tries to use previously generated // statement. -func (d *database) prepareStatement(stmt *exql.Statement, args []interface{}) (*Stmt, string, []interface{}, error) { +func (d *database) prepareStatement(ctx context.Context, stmt *exql.Statement, args []interface{}) (*Stmt, string, []interface{}, error) { d.sessMu.Lock() defer d.sessMu.Unlock() @@ -466,9 +486,9 @@ func (d *database) prepareStatement(stmt *exql.Statement, args []interface{}) (* query, args := d.compileStatement(stmt, args) sqlStmt, err := func(query *string) (*sql.Stmt, error) { if tx != nil { - return tx.(*sqlTx).Prepare(*query) + return tx.(*sqlTx).PrepareContext(ctx, *query) } - return sess.Prepare(*query) + return sess.PrepareContext(ctx, *query) }(&query) if err != nil { return nil, "", nil, err diff --git a/lib/sqlbuilder/builder.go b/lib/sqlbuilder/builder.go index fb922dfb331142b5bf3ad7125b94c8ce9e3bf6bb..3786cde3f709a827886d8171ae0a7f794fb7bdd1 100644 --- a/lib/sqlbuilder/builder.go +++ b/lib/sqlbuilder/builder.go @@ -1,6 +1,7 @@ package sqlbuilder import ( + "context" "database/sql" "errors" "fmt" @@ -63,9 +64,11 @@ var ( ) type exprDB interface { - StatementQuery(stmt *exql.Statement, args ...interface{}) (*sql.Rows, error) - StatementQueryRow(stmt *exql.Statement, args ...interface{}) (*sql.Row, error) - StatementExec(stmt *exql.Statement, args ...interface{}) (sql.Result, error) + StatementQuery(ctx context.Context, stmt *exql.Statement, args ...interface{}) (*sql.Rows, error) + StatementQueryRow(ctx context.Context, stmt *exql.Statement, args ...interface{}) (*sql.Row, error) + StatementExec(ctx context.Context, stmt *exql.Statement, args ...interface{}) (sql.Result, error) + + Context() context.Context } type sqlBuilder struct { @@ -75,17 +78,11 @@ type sqlBuilder struct { // WithSession returns a query builder that is bound to the given database session. func WithSession(sess interface{}, t *exql.Template) (Builder, error) { - switch v := sess.(type) { - case *sql.DB: - sess = newSqlgenProxy(v, t) - case exprDB: - // OK! - default: - // There should be no way this error is ignored. - panic(fmt.Sprintf("Unkown source type: %T", sess)) + if sqlDB, ok := sess.(*sql.DB); ok { + sess = sqlDB } return &sqlBuilder{ - sess: sess.(exprDB), + sess: sess.(exprDB), // Let it panic, it will show the developer an informative error. t: newTemplateWithUtils(t), }, nil } @@ -103,44 +100,60 @@ func NewIterator(rows *sql.Rows) Iterator { } func (b *sqlBuilder) Iterator(query interface{}, args ...interface{}) Iterator { - rows, err := b.Query(query, args...) + return b.IteratorContext(b.sess.Context(), query, args...) +} + +func (b *sqlBuilder) IteratorContext(ctx context.Context, query interface{}, args ...interface{}) Iterator { + rows, err := b.QueryContext(ctx, query, args...) return &iterator{rows, err} } func (b *sqlBuilder) Exec(query interface{}, args ...interface{}) (sql.Result, error) { + return b.ExecContext(b.sess.Context(), query, args...) +} + +func (b *sqlBuilder) ExecContext(ctx context.Context, query interface{}, args ...interface{}) (sql.Result, error) { switch q := query.(type) { case *exql.Statement: - return b.sess.StatementExec(q, args...) + return b.sess.StatementExec(ctx, q, args...) case string: - return b.sess.StatementExec(exql.RawSQL(q), args...) + return b.sess.StatementExec(ctx, exql.RawSQL(q), args...) case db.RawValue: - return b.Exec(q.Raw(), q.Arguments()...) + return b.ExecContext(ctx, q.Raw(), q.Arguments()...) default: return nil, fmt.Errorf("Unsupported query type %T.", query) } } func (b *sqlBuilder) Query(query interface{}, args ...interface{}) (*sql.Rows, error) { + return b.QueryContext(b.sess.Context(), query, args...) +} + +func (b *sqlBuilder) QueryContext(ctx context.Context, query interface{}, args ...interface{}) (*sql.Rows, error) { switch q := query.(type) { case *exql.Statement: - return b.sess.StatementQuery(q, args...) + return b.sess.StatementQuery(ctx, q, args...) case string: - return b.sess.StatementQuery(exql.RawSQL(q), args...) + return b.sess.StatementQuery(ctx, exql.RawSQL(q), args...) case db.RawValue: - return b.Query(q.Raw(), q.Arguments()...) + return b.QueryContext(ctx, q.Raw(), q.Arguments()...) default: return nil, fmt.Errorf("Unsupported query type %T.", query) } } func (b *sqlBuilder) QueryRow(query interface{}, args ...interface{}) (*sql.Row, error) { + return b.QueryRowContext(b.sess.Context(), query, args...) +} + +func (b *sqlBuilder) QueryRowContext(ctx context.Context, query interface{}, args ...interface{}) (*sql.Row, error) { switch q := query.(type) { case *exql.Statement: - return b.sess.StatementQueryRow(q, args...) + return b.sess.StatementQueryRow(ctx, q, args...) case string: - return b.sess.StatementQueryRow(exql.RawSQL(q), args...) + return b.sess.StatementQueryRow(ctx, exql.RawSQL(q), args...) case db.RawValue: - return b.QueryRow(q.Raw(), q.Arguments()...) + return b.QueryRowContext(ctx, q.Raw(), q.Arguments()...) default: return nil, fmt.Errorf("Unsupported query type %T.", query) } @@ -494,19 +507,23 @@ func newSqlgenProxy(db *sql.DB, t *exql.Template) *exprProxy { return &exprProxy{db: db, t: t} } -func (p *exprProxy) StatementExec(stmt *exql.Statement, args ...interface{}) (sql.Result, error) { +func (p *exprProxy) Context() context.Context { + return context.Background() +} + +func (p *exprProxy) StatementExec(ctx context.Context, stmt *exql.Statement, args ...interface{}) (sql.Result, error) { s := stmt.Compile(p.t) - return p.db.Exec(s, args...) + return p.db.ExecContext(ctx, s, args...) } -func (p *exprProxy) StatementQuery(stmt *exql.Statement, args ...interface{}) (*sql.Rows, error) { +func (p *exprProxy) StatementQuery(ctx context.Context, stmt *exql.Statement, args ...interface{}) (*sql.Rows, error) { s := stmt.Compile(p.t) - return p.db.Query(s, args...) + return p.db.QueryContext(ctx, s, args...) } -func (p *exprProxy) StatementQueryRow(stmt *exql.Statement, args ...interface{}) (*sql.Row, error) { +func (p *exprProxy) StatementQueryRow(ctx context.Context, stmt *exql.Statement, args ...interface{}) (*sql.Row, error) { s := stmt.Compile(p.t) - return p.db.QueryRow(s, args...), nil + return p.db.QueryRowContext(ctx, s, args...), nil } var ( diff --git a/lib/sqlbuilder/delete.go b/lib/sqlbuilder/delete.go index 764460e645bfc37648145fa4bcee65cbd8464f8a..649fef29927e156728e3a3ba8fe4c31c42fe1f98 100644 --- a/lib/sqlbuilder/delete.go +++ b/lib/sqlbuilder/delete.go @@ -1,6 +1,7 @@ package sqlbuilder import ( + "context" "database/sql" "upper.io/db.v3/internal/immutable" @@ -91,11 +92,15 @@ func (del *deleter) Arguments() []interface{} { } func (del *deleter) Exec() (sql.Result, error) { + return del.ExecContext(del.Builder().sess.Context()) +} + +func (del *deleter) ExecContext(ctx context.Context) (sql.Result, error) { dq, err := del.build() if err != nil { return nil, err } - return del.Builder().sess.StatementExec(dq.statement(), dq.arguments...) + return del.Builder().sess.StatementExec(ctx, dq.statement(), dq.arguments...) } func (del *deleter) statement() *exql.Statement { diff --git a/lib/sqlbuilder/insert.go b/lib/sqlbuilder/insert.go index 17b8260e6ec687fc1829633edfbd04d9e1e624bb..30e98cc43ca88b06ba91b0feca4d8fa9d3f72e5c 100644 --- a/lib/sqlbuilder/insert.go +++ b/lib/sqlbuilder/insert.go @@ -1,6 +1,7 @@ package sqlbuilder import ( + "context" "database/sql" "upper.io/db.v3/internal/immutable" @@ -147,31 +148,47 @@ func (ins *inserter) Returning(columns ...string) Inserter { } func (ins *inserter) Exec() (sql.Result, error) { + return ins.ExecContext(ins.Builder().sess.Context()) +} + +func (ins *inserter) ExecContext(ctx context.Context) (sql.Result, error) { iq, err := ins.build() if err != nil { return nil, err } - return ins.Builder().sess.StatementExec(iq.statement(), iq.arguments...) + return ins.Builder().sess.StatementExec(ctx, iq.statement(), iq.arguments...) } func (ins *inserter) Query() (*sql.Rows, error) { + return ins.QueryContext(ins.Builder().sess.Context()) +} + +func (ins *inserter) QueryContext(ctx context.Context) (*sql.Rows, error) { iq, err := ins.build() if err != nil { return nil, err } - return ins.Builder().sess.StatementQuery(iq.statement(), iq.arguments...) + return ins.Builder().sess.StatementQuery(ctx, iq.statement(), iq.arguments...) } func (ins *inserter) QueryRow() (*sql.Row, error) { + return ins.QueryRowContext(ins.Builder().sess.Context()) +} + +func (ins *inserter) QueryRowContext(ctx context.Context) (*sql.Row, error) { iq, err := ins.build() if err != nil { return nil, err } - return ins.Builder().sess.StatementQueryRow(iq.statement(), iq.arguments...) + return ins.Builder().sess.StatementQueryRow(ctx, iq.statement(), iq.arguments...) } func (ins *inserter) Iterator() Iterator { - rows, err := ins.Query() + return ins.IteratorContext(ins.Builder().sess.Context()) +} + +func (ins *inserter) IteratorContext(ctx context.Context) Iterator { + rows, err := ins.QueryContext(ctx) return &iterator{rows, err} } diff --git a/lib/sqlbuilder/interfaces.go b/lib/sqlbuilder/interfaces.go index 8684fd16f412aaed29030c277968b37d632d4d5a..d7e74fde7137dbefcff90c3ba83e704e3bb6e378 100644 --- a/lib/sqlbuilder/interfaces.go +++ b/lib/sqlbuilder/interfaces.go @@ -22,6 +22,7 @@ package sqlbuilder import ( + "context" "database/sql" "fmt" ) @@ -82,6 +83,13 @@ type Builder interface { // sqlbuilder.Query(`SELECT * FROM people WHERE name = "Mateo"`) Query(query interface{}, args ...interface{}) (*sql.Rows, error) + // QueryContext executes the given SQL query and returns *sql.Rows. + // + // Example: + // + // sqlbuilder.QueryContext(ctx, `SELECT * FROM people WHERE name = "Mateo"`) + QueryContext(ctx context.Context, query interface{}, args ...interface{}) (*sql.Rows, error) + // QueryRow executes the given SQL query and returns *sql.Row. // // Example: @@ -89,12 +97,26 @@ type Builder interface { // sqlbuilder.QueryRow(`SELECT * FROM people WHERE name = "Haruki" AND last_name = "Murakami" LIMIT 1`) QueryRow(query interface{}, args ...interface{}) (*sql.Row, error) + // QueryRowContext executes the given SQL query and returns *sql.Row. + // + // Example: + // + // sqlbuilder.QueryRowContext(ctx, `SELECT * FROM people WHERE name = "Haruki" AND last_name = "Murakami" LIMIT 1`) + QueryRowContext(ctx context.Context, query interface{}, args ...interface{}) (*sql.Row, error) + // Iterator executes the given SQL query and returns an Iterator. // // Example: // // sqlbuilder.Iterator(`SELECT * FROM people WHERE name LIKE "M%"`) Iterator(query interface{}, args ...interface{}) Iterator + + // IteratorContext executes the given SQL query and returns an Iterator. + // + // Example: + // + // sqlbuilder.IteratorContext(ctx, `SELECT * FROM people WHERE name LIKE "M%"`) + IteratorContext(ctx context.Context, query interface{}, args ...interface{}) Iterator } // Selector represents a SELECT statement. @@ -283,6 +305,10 @@ type Selector interface { // Selector. Iterator() Iterator + // IteratorContext provides methods to iterate over the results returned by + // the Selector. + IteratorContext(ctx context.Context) Iterator + // Getter provides methods to compile and execute a query that returns // results. Getter @@ -330,6 +356,10 @@ type Inserter interface { // Inserter. This is only possible when using Returning(). Iterator() Iterator + // IteratorContext provides methods to iterate over the results returned by + // the Inserter. This is only possible when using Returning(). + IteratorContext(ctx context.Context) Iterator + // Batch provies a BatchInserter that can be used to insert many elements at // once by issuing several calls to Values(). It accepts a size parameter // which defines the batch size. If size is < 1, the batch size is set to 1. @@ -400,6 +430,9 @@ type Updater interface { type Execer interface { // Exec executes a statement and returns sql.Result. Exec() (sql.Result, error) + + // ExecContext executes a statement and returns sql.Result. + ExecContext(context.Context) (sql.Result, error) } // Getter provides methods for executing statements that return results. @@ -407,8 +440,14 @@ type Getter interface { // Query returns *sql.Rows. Query() (*sql.Rows, error) + // QueryContext returns *sql.Rows. + QueryContext(context.Context) (*sql.Rows, error) + // QueryRow returns only one row. QueryRow() (*sql.Row, error) + + // QueryRowContext returns only one row. + QueryRowContext(ctx context.Context) (*sql.Row, error) } // ResultMapper defined methods for a result mapper. diff --git a/lib/sqlbuilder/select.go b/lib/sqlbuilder/select.go index 9df30e276f8dbb1cd7b7cb8680bf1753eec4078e..dcb2c860601e3bfed0186f9556e0048330bc3fdd 100644 --- a/lib/sqlbuilder/select.go +++ b/lib/sqlbuilder/select.go @@ -1,6 +1,7 @@ package sqlbuilder import ( + "context" "database/sql" "errors" "fmt" @@ -386,29 +387,41 @@ func (sel *selector) statement() *exql.Statement { } func (sel *selector) QueryRow() (*sql.Row, error) { + return sel.QueryRowContext(sel.Builder().sess.Context()) +} + +func (sel *selector) QueryRowContext(ctx context.Context) (*sql.Row, error) { sq, err := sel.build() if err != nil { return nil, err } - return sel.Builder().sess.StatementQueryRow(sq.statement(), sq.arguments()...) + return sel.Builder().sess.StatementQueryRow(ctx, sq.statement(), sq.arguments()...) } func (sel *selector) Query() (*sql.Rows, error) { + return sel.QueryContext(sel.Builder().sess.Context()) +} + +func (sel *selector) QueryContext(ctx context.Context) (*sql.Rows, error) { sq, err := sel.build() if err != nil { return nil, err } - return sel.Builder().sess.StatementQuery(sq.statement(), sq.arguments()...) + return sel.Builder().sess.StatementQuery(ctx, sq.statement(), sq.arguments()...) } func (sel *selector) Iterator() Iterator { + return sel.IteratorContext(sel.Builder().sess.Context()) +} + +func (sel *selector) IteratorContext(ctx context.Context) Iterator { sq, err := sel.build() if err != nil { return &iterator{nil, err} } - rows, err := sel.Builder().sess.StatementQuery(sq.statement(), sq.arguments()...) + rows, err := sel.Builder().sess.StatementQuery(ctx, sq.statement(), sq.arguments()...) return &iterator{rows, err} } diff --git a/lib/sqlbuilder/update.go b/lib/sqlbuilder/update.go index 56c071ec38bc4cc275667ac0d56c70137d4f02be..5e7c42443fa9a275156f4b5122d74e0b6888123b 100644 --- a/lib/sqlbuilder/update.go +++ b/lib/sqlbuilder/update.go @@ -1,6 +1,7 @@ package sqlbuilder import ( + "context" "database/sql" "upper.io/db.v3/internal/immutable" @@ -134,11 +135,15 @@ func (upd *updater) Where(terms ...interface{}) Updater { } func (upd *updater) Exec() (sql.Result, error) { + return upd.ExecContext(upd.Builder().sess.Context()) +} + +func (upd *updater) ExecContext(ctx context.Context) (sql.Result, error) { uq, err := upd.build() if err != nil { return nil, err } - return upd.Builder().sess.StatementExec(uq.statement(), uq.arguments()...) + return upd.Builder().sess.StatementExec(ctx, uq.statement(), uq.arguments()...) } func (upd *updater) Limit(limit int) Updater {