From df7fbfc3b3bca126f27bb556734e39f670978dd5 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jos=C3=A9=20Carlos=20Nieto?= <jose.carlos@menteslibres.net>
Date: Fri, 16 Dec 2016 20:32:27 +0000
Subject: [PATCH] Basic support for query context.

---
 internal/sqladapter/database.go | 60 ++++++++++++++++++---------
 lib/sqlbuilder/builder.go       | 73 ++++++++++++++++++++-------------
 lib/sqlbuilder/delete.go        |  7 +++-
 lib/sqlbuilder/insert.go        | 25 +++++++++--
 lib/sqlbuilder/interfaces.go    | 39 ++++++++++++++++++
 lib/sqlbuilder/select.go        | 19 +++++++--
 lib/sqlbuilder/update.go        |  7 +++-
 7 files changed, 173 insertions(+), 57 deletions(-)

diff --git a/internal/sqladapter/database.go b/internal/sqladapter/database.go
index f38d9121..4c6cd6d8 100644
--- a/internal/sqladapter/database.go
+++ b/internal/sqladapter/database.go
@@ -1,6 +1,7 @@
 package sqladapter
 
 import (
+	"context"
 	"database/sql"
 	"math"
 	"strconv"
@@ -27,7 +28,7 @@ type HasCleanUp interface {
 
 // HasStatementExec allows the adapter to have its own exec statement.
 type HasStatementExec interface {
-	StatementExec(query string, args ...interface{}) (sql.Result, error)
+	StatementExec(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
 }
 
 // Database represents a SQL database.
@@ -67,6 +68,8 @@ type BaseDatabase interface {
 	Driver() interface{}
 
 	WaitForConnection(func() error) error
+	Context() context.Context
+	WithContext(context.Context) Database
 
 	BindSession(*sql.DB) error
 	Session() *sql.DB
@@ -97,6 +100,8 @@ type database struct {
 	PartialDatabase
 	baseTx BaseTx
 
+	ctx context.Context
+
 	collectionMu sync.Mutex
 	databaseMu   sync.Mutex
 
@@ -285,7 +290,7 @@ func (d *database) Collection(name string) db.Collection {
 
 // StatementExec compiles and executes a statement that does not return any
 // rows.
-func (d *database) StatementExec(stmt *exql.Statement, args ...interface{}) (res sql.Result, err error) {
+func (d *database) StatementExec(ctx context.Context, stmt *exql.Statement, args ...interface{}) (res sql.Result, err error) {
 	var query string
 
 	if db.Conf.LoggingEnabled() {
@@ -317,7 +322,7 @@ func (d *database) StatementExec(stmt *exql.Statement, args ...interface{}) (res
 
 	if execer, ok := d.PartialDatabase.(HasStatementExec); ok {
 		query, args = d.compileStatement(stmt, args)
-		res, err = execer.StatementExec(query, args...)
+		res, err = execer.StatementExec(ctx, query, args...)
 		return
 	}
 
@@ -325,27 +330,27 @@ func (d *database) StatementExec(stmt *exql.Statement, args ...interface{}) (res
 
 	if db.Conf.PreparedStatementCacheEnabled() && tx == nil {
 		var p *Stmt
-		if p, query, args, err = d.prepareStatement(stmt, args); err != nil {
+		if p, query, args, err = d.prepareStatement(ctx, stmt, args); err != nil {
 			return nil, err
 		}
 		defer p.Close()
 
-		res, err = p.Exec(args...)
+		res, err = p.ExecContext(ctx, args...)
 		return
 	}
 
 	query, args = d.compileStatement(stmt, args)
 	if tx != nil {
-		res, err = tx.(*sqlTx).Exec(query, args...)
+		res, err = tx.(*sqlTx).ExecContext(ctx, query, args...)
 		return
 	}
 
-	res, err = d.sess.Exec(query, args...)
+	res, err = d.sess.ExecContext(ctx, query, args...)
 	return
 }
 
 // StatementQuery compiles and executes a statement that returns rows.
-func (d *database) StatementQuery(stmt *exql.Statement, args ...interface{}) (rows *sql.Rows, err error) {
+func (d *database) StatementQuery(ctx context.Context, stmt *exql.Statement, args ...interface{}) (rows *sql.Rows, err error) {
 	var query string
 
 	if db.Conf.LoggingEnabled() {
@@ -366,29 +371,29 @@ func (d *database) StatementQuery(stmt *exql.Statement, args ...interface{}) (ro
 
 	if db.Conf.PreparedStatementCacheEnabled() && tx == nil {
 		var p *Stmt
-		if p, query, args, err = d.prepareStatement(stmt, args); err != nil {
+		if p, query, args, err = d.prepareStatement(ctx, stmt, args); err != nil {
 			return nil, err
 		}
 		defer p.Close()
 
-		rows, err = p.Query(args...)
+		rows, err = p.QueryContext(ctx, args...)
 		return
 	}
 
 	query, args = d.compileStatement(stmt, args)
 	if tx != nil {
-		rows, err = tx.(*sqlTx).Query(query, args...)
+		rows, err = tx.(*sqlTx).QueryContext(ctx, query, args...)
 		return
 	}
 
-	rows, err = d.sess.Query(query, args...)
+	rows, err = d.sess.QueryContext(ctx, query, args...)
 	return
 
 }
 
 // StatementQueryRow compiles and executes a statement that returns at most one
 // row.
-func (d *database) StatementQueryRow(stmt *exql.Statement, args ...interface{}) (row *sql.Row, err error) {
+func (d *database) StatementQueryRow(ctx context.Context, stmt *exql.Statement, args ...interface{}) (row *sql.Row, err error) {
 	var query string
 
 	if db.Conf.LoggingEnabled() {
@@ -409,22 +414,22 @@ func (d *database) StatementQueryRow(stmt *exql.Statement, args ...interface{})
 
 	if db.Conf.PreparedStatementCacheEnabled() && tx == nil {
 		var p *Stmt
-		if p, query, args, err = d.prepareStatement(stmt, args); err != nil {
+		if p, query, args, err = d.prepareStatement(ctx, stmt, args); err != nil {
 			return nil, err
 		}
 		defer p.Close()
 
-		row = p.QueryRow(args...)
+		row = p.QueryRowContext(ctx, args...)
 		return
 	}
 
 	query, args = d.compileStatement(stmt, args)
 	if tx != nil {
-		row = tx.(*sqlTx).QueryRow(query, args...)
+		row = tx.(*sqlTx).QueryRowContext(ctx, query, args...)
 		return
 	}
 
-	row = d.sess.QueryRow(query, args...)
+	row = d.sess.QueryRowContext(ctx, query, args...)
 	return
 }
 
@@ -437,6 +442,21 @@ func (d *database) Driver() interface{} {
 	return d.sess
 }
 
+func (d *database) Context() context.Context {
+	if d.ctx == nil {
+		return context.Background()
+	}
+	return d.ctx
+}
+
+func (d *database) WithContext(ctx context.Context) Database {
+	// TODO: Don't just copy this over.
+	var newDB *database
+	*newDB = *d
+	newDB.ctx = ctx
+	return newDB
+}
+
 // compileStatement compiles the given statement into a string.
 func (d *database) compileStatement(stmt *exql.Statement, args []interface{}) (string, []interface{}) {
 	return d.PartialDatabase.CompileStatement(stmt, args)
@@ -444,7 +464,7 @@ func (d *database) compileStatement(stmt *exql.Statement, args []interface{}) (s
 
 // prepareStatement compiles a query and tries to use previously generated
 // statement.
-func (d *database) prepareStatement(stmt *exql.Statement, args []interface{}) (*Stmt, string, []interface{}, error) {
+func (d *database) prepareStatement(ctx context.Context, stmt *exql.Statement, args []interface{}) (*Stmt, string, []interface{}, error) {
 	d.sessMu.Lock()
 	defer d.sessMu.Unlock()
 
@@ -466,9 +486,9 @@ func (d *database) prepareStatement(stmt *exql.Statement, args []interface{}) (*
 	query, args := d.compileStatement(stmt, args)
 	sqlStmt, err := func(query *string) (*sql.Stmt, error) {
 		if tx != nil {
-			return tx.(*sqlTx).Prepare(*query)
+			return tx.(*sqlTx).PrepareContext(ctx, *query)
 		}
-		return sess.Prepare(*query)
+		return sess.PrepareContext(ctx, *query)
 	}(&query)
 	if err != nil {
 		return nil, "", nil, err
diff --git a/lib/sqlbuilder/builder.go b/lib/sqlbuilder/builder.go
index fb922dfb..3786cde3 100644
--- a/lib/sqlbuilder/builder.go
+++ b/lib/sqlbuilder/builder.go
@@ -1,6 +1,7 @@
 package sqlbuilder
 
 import (
+	"context"
 	"database/sql"
 	"errors"
 	"fmt"
@@ -63,9 +64,11 @@ var (
 )
 
 type exprDB interface {
-	StatementQuery(stmt *exql.Statement, args ...interface{}) (*sql.Rows, error)
-	StatementQueryRow(stmt *exql.Statement, args ...interface{}) (*sql.Row, error)
-	StatementExec(stmt *exql.Statement, args ...interface{}) (sql.Result, error)
+	StatementQuery(ctx context.Context, stmt *exql.Statement, args ...interface{}) (*sql.Rows, error)
+	StatementQueryRow(ctx context.Context, stmt *exql.Statement, args ...interface{}) (*sql.Row, error)
+	StatementExec(ctx context.Context, stmt *exql.Statement, args ...interface{}) (sql.Result, error)
+
+	Context() context.Context
 }
 
 type sqlBuilder struct {
@@ -75,17 +78,11 @@ type sqlBuilder struct {
 
 // WithSession returns a query builder that is bound to the given database session.
 func WithSession(sess interface{}, t *exql.Template) (Builder, error) {
-	switch v := sess.(type) {
-	case *sql.DB:
-		sess = newSqlgenProxy(v, t)
-	case exprDB:
-		// OK!
-	default:
-		// There should be no way this error is ignored.
-		panic(fmt.Sprintf("Unkown source type: %T", sess))
+	if sqlDB, ok := sess.(*sql.DB); ok {
+		sess = sqlDB
 	}
 	return &sqlBuilder{
-		sess: sess.(exprDB),
+		sess: sess.(exprDB), // Let it panic, it will show the developer an informative error.
 		t:    newTemplateWithUtils(t),
 	}, nil
 }
@@ -103,44 +100,60 @@ func NewIterator(rows *sql.Rows) Iterator {
 }
 
 func (b *sqlBuilder) Iterator(query interface{}, args ...interface{}) Iterator {
-	rows, err := b.Query(query, args...)
+	return b.IteratorContext(b.sess.Context(), query, args...)
+}
+
+func (b *sqlBuilder) IteratorContext(ctx context.Context, query interface{}, args ...interface{}) Iterator {
+	rows, err := b.QueryContext(ctx, query, args...)
 	return &iterator{rows, err}
 }
 
 func (b *sqlBuilder) Exec(query interface{}, args ...interface{}) (sql.Result, error) {
+	return b.ExecContext(b.sess.Context(), query, args...)
+}
+
+func (b *sqlBuilder) ExecContext(ctx context.Context, query interface{}, args ...interface{}) (sql.Result, error) {
 	switch q := query.(type) {
 	case *exql.Statement:
-		return b.sess.StatementExec(q, args...)
+		return b.sess.StatementExec(ctx, q, args...)
 	case string:
-		return b.sess.StatementExec(exql.RawSQL(q), args...)
+		return b.sess.StatementExec(ctx, exql.RawSQL(q), args...)
 	case db.RawValue:
-		return b.Exec(q.Raw(), q.Arguments()...)
+		return b.ExecContext(ctx, q.Raw(), q.Arguments()...)
 	default:
 		return nil, fmt.Errorf("Unsupported query type %T.", query)
 	}
 }
 
 func (b *sqlBuilder) Query(query interface{}, args ...interface{}) (*sql.Rows, error) {
+	return b.QueryContext(b.sess.Context(), query, args...)
+}
+
+func (b *sqlBuilder) QueryContext(ctx context.Context, query interface{}, args ...interface{}) (*sql.Rows, error) {
 	switch q := query.(type) {
 	case *exql.Statement:
-		return b.sess.StatementQuery(q, args...)
+		return b.sess.StatementQuery(ctx, q, args...)
 	case string:
-		return b.sess.StatementQuery(exql.RawSQL(q), args...)
+		return b.sess.StatementQuery(ctx, exql.RawSQL(q), args...)
 	case db.RawValue:
-		return b.Query(q.Raw(), q.Arguments()...)
+		return b.QueryContext(ctx, q.Raw(), q.Arguments()...)
 	default:
 		return nil, fmt.Errorf("Unsupported query type %T.", query)
 	}
 }
 
 func (b *sqlBuilder) QueryRow(query interface{}, args ...interface{}) (*sql.Row, error) {
+	return b.QueryRowContext(b.sess.Context(), query, args...)
+}
+
+func (b *sqlBuilder) QueryRowContext(ctx context.Context, query interface{}, args ...interface{}) (*sql.Row, error) {
 	switch q := query.(type) {
 	case *exql.Statement:
-		return b.sess.StatementQueryRow(q, args...)
+		return b.sess.StatementQueryRow(ctx, q, args...)
 	case string:
-		return b.sess.StatementQueryRow(exql.RawSQL(q), args...)
+		return b.sess.StatementQueryRow(ctx, exql.RawSQL(q), args...)
 	case db.RawValue:
-		return b.QueryRow(q.Raw(), q.Arguments()...)
+		return b.QueryRowContext(ctx, q.Raw(), q.Arguments()...)
 	default:
 		return nil, fmt.Errorf("Unsupported query type %T.", query)
 	}
@@ -494,19 +507,23 @@ func newSqlgenProxy(db *sql.DB, t *exql.Template) *exprProxy {
 	return &exprProxy{db: db, t: t}
 }
 
-func (p *exprProxy) StatementExec(stmt *exql.Statement, args ...interface{}) (sql.Result, error) {
+func (p *exprProxy) Context() context.Context {
+	return context.Background()
+}
+
+func (p *exprProxy) StatementExec(ctx context.Context, stmt *exql.Statement, args ...interface{}) (sql.Result, error) {
 	s := stmt.Compile(p.t)
-	return p.db.Exec(s, args...)
+	return p.db.ExecContext(ctx, s, args...)
 }
 
-func (p *exprProxy) StatementQuery(stmt *exql.Statement, args ...interface{}) (*sql.Rows, error) {
+func (p *exprProxy) StatementQuery(ctx context.Context, stmt *exql.Statement, args ...interface{}) (*sql.Rows, error) {
 	s := stmt.Compile(p.t)
-	return p.db.Query(s, args...)
+	return p.db.QueryContext(ctx, s, args...)
 }
 
-func (p *exprProxy) StatementQueryRow(stmt *exql.Statement, args ...interface{}) (*sql.Row, error) {
+func (p *exprProxy) StatementQueryRow(ctx context.Context, stmt *exql.Statement, args ...interface{}) (*sql.Row, error) {
 	s := stmt.Compile(p.t)
-	return p.db.QueryRow(s, args...), nil
+	return p.db.QueryRowContext(ctx, s, args...), nil
 }
 
 var (
diff --git a/lib/sqlbuilder/delete.go b/lib/sqlbuilder/delete.go
index 764460e6..649fef29 100644
--- a/lib/sqlbuilder/delete.go
+++ b/lib/sqlbuilder/delete.go
@@ -1,6 +1,7 @@
 package sqlbuilder
 
 import (
+	"context"
 	"database/sql"
 
 	"upper.io/db.v3/internal/immutable"
@@ -91,11 +92,15 @@ func (del *deleter) Arguments() []interface{} {
 }
 
 func (del *deleter) Exec() (sql.Result, error) {
+	return del.ExecContext(del.Builder().sess.Context())
+}
+
+func (del *deleter) ExecContext(ctx context.Context) (sql.Result, error) {
 	dq, err := del.build()
 	if err != nil {
 		return nil, err
 	}
-	return del.Builder().sess.StatementExec(dq.statement(), dq.arguments...)
+	return del.Builder().sess.StatementExec(ctx, dq.statement(), dq.arguments...)
 }
 
 func (del *deleter) statement() *exql.Statement {
diff --git a/lib/sqlbuilder/insert.go b/lib/sqlbuilder/insert.go
index 17b8260e..30e98cc4 100644
--- a/lib/sqlbuilder/insert.go
+++ b/lib/sqlbuilder/insert.go
@@ -1,6 +1,7 @@
 package sqlbuilder
 
 import (
+	"context"
 	"database/sql"
 
 	"upper.io/db.v3/internal/immutable"
@@ -147,31 +148,47 @@ func (ins *inserter) Returning(columns ...string) Inserter {
 }
 
 func (ins *inserter) Exec() (sql.Result, error) {
+	return ins.ExecContext(ins.Builder().sess.Context())
+}
+
+func (ins *inserter) ExecContext(ctx context.Context) (sql.Result, error) {
 	iq, err := ins.build()
 	if err != nil {
 		return nil, err
 	}
-	return ins.Builder().sess.StatementExec(iq.statement(), iq.arguments...)
+	return ins.Builder().sess.StatementExec(ctx, iq.statement(), iq.arguments...)
 }
 
 func (ins *inserter) Query() (*sql.Rows, error) {
+	return ins.QueryContext(ins.Builder().sess.Context())
+}
+
+func (ins *inserter) QueryContext(ctx context.Context) (*sql.Rows, error) {
 	iq, err := ins.build()
 	if err != nil {
 		return nil, err
 	}
-	return ins.Builder().sess.StatementQuery(iq.statement(), iq.arguments...)
+	return ins.Builder().sess.StatementQuery(ctx, iq.statement(), iq.arguments...)
 }
 
 func (ins *inserter) QueryRow() (*sql.Row, error) {
+	return ins.QueryRowContext(ins.Builder().sess.Context())
+}
+
+func (ins *inserter) QueryRowContext(ctx context.Context) (*sql.Row, error) {
 	iq, err := ins.build()
 	if err != nil {
 		return nil, err
 	}
-	return ins.Builder().sess.StatementQueryRow(iq.statement(), iq.arguments...)
+	return ins.Builder().sess.StatementQueryRow(ctx, iq.statement(), iq.arguments...)
 }
 
 func (ins *inserter) Iterator() Iterator {
-	rows, err := ins.Query()
+	return ins.IteratorContext(ins.Builder().sess.Context())
+}
+
+func (ins *inserter) IteratorContext(ctx context.Context) Iterator {
+	rows, err := ins.QueryContext(ctx)
 	return &iterator{rows, err}
 }
 
diff --git a/lib/sqlbuilder/interfaces.go b/lib/sqlbuilder/interfaces.go
index 8684fd16..d7e74fde 100644
--- a/lib/sqlbuilder/interfaces.go
+++ b/lib/sqlbuilder/interfaces.go
@@ -22,6 +22,7 @@
 package sqlbuilder
 
 import (
+	"context"
 	"database/sql"
 	"fmt"
 )
@@ -82,6 +83,13 @@ type Builder interface {
 	//  sqlbuilder.Query(`SELECT * FROM people WHERE name = "Mateo"`)
 	Query(query interface{}, args ...interface{}) (*sql.Rows, error)
 
+	// QueryContext executes the given SQL query and returns *sql.Rows.
+	//
+	// Example:
+	//
+	//  sqlbuilder.QueryContext(ctx, `SELECT * FROM people WHERE name = "Mateo"`)
+	QueryContext(ctx context.Context, query interface{}, args ...interface{}) (*sql.Rows, error)
+
 	// QueryRow executes the given SQL query and returns *sql.Row.
 	//
 	// Example:
@@ -89,12 +97,26 @@ type Builder interface {
 	//  sqlbuilder.QueryRow(`SELECT * FROM people WHERE name = "Haruki" AND last_name = "Murakami" LIMIT 1`)
 	QueryRow(query interface{}, args ...interface{}) (*sql.Row, error)
 
+	// QueryRowContext executes the given SQL query and returns *sql.Row.
+	//
+	// Example:
+	//
+	//  sqlbuilder.QueryRowContext(ctx, `SELECT * FROM people WHERE name = "Haruki" AND last_name = "Murakami" LIMIT 1`)
+	QueryRowContext(ctx context.Context, query interface{}, args ...interface{}) (*sql.Row, error)
+
 	// Iterator executes the given SQL query and returns an Iterator.
 	//
 	// Example:
 	//
 	//  sqlbuilder.Iterator(`SELECT * FROM people WHERE name LIKE "M%"`)
 	Iterator(query interface{}, args ...interface{}) Iterator
+
+	// IteratorContext executes the given SQL query and returns an Iterator.
+	//
+	// Example:
+	//
+	//  sqlbuilder.IteratorContext(ctx, `SELECT * FROM people WHERE name LIKE "M%"`)
+	IteratorContext(ctx context.Context, query interface{}, args ...interface{}) Iterator
 }
 
 // Selector represents a SELECT statement.
@@ -283,6 +305,10 @@ type Selector interface {
 	// Selector.
 	Iterator() Iterator
 
+	// IteratorContext provides methods to iterate over the results returned by
+	// the Selector.
+	IteratorContext(ctx context.Context) Iterator
+
 	// Getter provides methods to compile and execute a query that returns
 	// results.
 	Getter
@@ -330,6 +356,10 @@ type Inserter interface {
 	// Inserter. This is only possible when using Returning().
 	Iterator() Iterator
 
+	// IteratorContext provides methods to iterate over the results returned by
+	// the Inserter. This is only possible when using Returning().
+	IteratorContext(ctx context.Context) Iterator
+
 	// Batch provies a BatchInserter that can be used to insert many elements at
 	// once by issuing several calls to Values(). It accepts a size parameter
 	// which defines the batch size. If size is < 1, the batch size is set to 1.
@@ -400,6 +430,9 @@ type Updater interface {
 type Execer interface {
 	// Exec executes a statement and returns sql.Result.
 	Exec() (sql.Result, error)
+
+	// ExecContext executes a statement and returns sql.Result.
+	ExecContext(context.Context) (sql.Result, error)
 }
 
 // Getter provides methods for executing statements that return results.
@@ -407,8 +440,14 @@ type Getter interface {
 	// Query returns *sql.Rows.
 	Query() (*sql.Rows, error)
 
+	// QueryContext returns *sql.Rows.
+	QueryContext(context.Context) (*sql.Rows, error)
+
 	// QueryRow returns only one row.
 	QueryRow() (*sql.Row, error)
+
+	// QueryRowContext returns only one row.
+	QueryRowContext(ctx context.Context) (*sql.Row, error)
 }
 
 // ResultMapper defined methods for a result mapper.
diff --git a/lib/sqlbuilder/select.go b/lib/sqlbuilder/select.go
index 9df30e27..dcb2c860 100644
--- a/lib/sqlbuilder/select.go
+++ b/lib/sqlbuilder/select.go
@@ -1,6 +1,7 @@
 package sqlbuilder
 
 import (
+	"context"
 	"database/sql"
 	"errors"
 	"fmt"
@@ -386,29 +387,41 @@ func (sel *selector) statement() *exql.Statement {
 }
 
 func (sel *selector) QueryRow() (*sql.Row, error) {
+	return sel.QueryRowContext(sel.Builder().sess.Context())
+}
+
+func (sel *selector) QueryRowContext(ctx context.Context) (*sql.Row, error) {
 	sq, err := sel.build()
 	if err != nil {
 		return nil, err
 	}
 
-	return sel.Builder().sess.StatementQueryRow(sq.statement(), sq.arguments()...)
+	return sel.Builder().sess.StatementQueryRow(ctx, sq.statement(), sq.arguments()...)
 }
 
 func (sel *selector) Query() (*sql.Rows, error) {
+	return sel.QueryContext(sel.Builder().sess.Context())
+}
+
+func (sel *selector) QueryContext(ctx context.Context) (*sql.Rows, error) {
 	sq, err := sel.build()
 	if err != nil {
 		return nil, err
 	}
-	return sel.Builder().sess.StatementQuery(sq.statement(), sq.arguments()...)
+	return sel.Builder().sess.StatementQuery(ctx, sq.statement(), sq.arguments()...)
 }
 
 func (sel *selector) Iterator() Iterator {
+	return sel.IteratorContext(sel.Builder().sess.Context())
+}
+
+func (sel *selector) IteratorContext(ctx context.Context) Iterator {
 	sq, err := sel.build()
 	if err != nil {
 		return &iterator{nil, err}
 	}
 
-	rows, err := sel.Builder().sess.StatementQuery(sq.statement(), sq.arguments()...)
+	rows, err := sel.Builder().sess.StatementQuery(ctx, sq.statement(), sq.arguments()...)
 	return &iterator{rows, err}
 }
 
diff --git a/lib/sqlbuilder/update.go b/lib/sqlbuilder/update.go
index 56c071ec..5e7c4244 100644
--- a/lib/sqlbuilder/update.go
+++ b/lib/sqlbuilder/update.go
@@ -1,6 +1,7 @@
 package sqlbuilder
 
 import (
+	"context"
 	"database/sql"
 
 	"upper.io/db.v3/internal/immutable"
@@ -134,11 +135,15 @@ func (upd *updater) Where(terms ...interface{}) Updater {
 }
 
 func (upd *updater) Exec() (sql.Result, error) {
+	return upd.ExecContext(upd.Builder().sess.Context())
+}
+
+func (upd *updater) ExecContext(ctx context.Context) (sql.Result, error) {
 	uq, err := upd.build()
 	if err != nil {
 		return nil, err
 	}
-	return upd.Builder().sess.StatementExec(uq.statement(), uq.arguments()...)
+	return upd.Builder().sess.StatementExec(ctx, uq.statement(), uq.arguments()...)
 }
 
 func (upd *updater) Limit(limit int) Updater {
-- 
GitLab