good morning!!!!

Skip to content
Snippets Groups Projects
database.go 7.12 KiB
Newer Older
	"database/sql"
	"sync"
	"time"

	"github.com/jmoiron/sqlx"
	sqlbuilder "upper.io/builder"
	"upper.io/builder/meta"
	"upper.io/builder/sqlgen"
	"upper.io/db/internal/debug"
	"upper.io/db/internal/schema"
	"upper.io/db/internal/sqlutil/tx"
// HasExecStatement is an optional extension of PartialDatabase. When the
// adapter implements it, BaseDatabase.Exec hands the prepared statement and
// its arguments to the adapter's Exec instead of executing the statement
// itself.
type HasExecStatement interface {
	// Exec runs the already-prepared stmt with args.
	Exec(stmt *sqlx.Stmt, args ...interface{}) (sql.Result, error)
}

// PartialDatabase is the adapter-specific half of a database. BaseDatabase
// implements the shared logic and calls back into this interface for
// everything that depends on the concrete SQL dialect or driver.
type PartialDatabase interface {
	// PopulateSchema is invoked while binding a session when no schema has
	// been set yet.
	PopulateSchema() error
	// TableExists returns a non-nil error when the named table cannot be
	// used; BaseDatabase.Collection passes that error to its caller.
	TableExists(name string) error
	// TablePrimaryKey reports the primary-key column names of a table.
	TablePrimaryKey(name string) ([]string, error)
	// NewTable builds the collection handle for the named table.
	NewTable(name string) db.Collection
	// CompileAndReplacePlaceholders turns stmt into dialect-specific SQL
	// text ready to be prepared.
	CompileAndReplacePlaceholders(stmt *sqlgen.Statement) (query string)
	// Err translates a driver error; its result is compared against
	// db.ErrTooManyClients when deciding whether to retry a connection.
	Err(in error) (out error)
}

// Database is a db.Database that can additionally answer table-level
// introspection questions.
type Database interface {
	db.Database
	// TableExists returns a non-nil error when the named table cannot be used.
	TableExists(name string) error
	// TablePrimaryKey reports the primary-key column names of a table.
	TablePrimaryKey(name string) ([]string, error)
}

// BaseDatabase implements the adapter-independent parts of a SQL database:
// session and transaction binding, prepared-statement caching, collection
// lookup and debug logging. Adapter-specific work is delegated to partial.
type BaseDatabase struct {
	// partial supplies the adapter-specific operations.
	partial PartialDatabase

	// sess is the connection pool installed by Bind; nil until then.
	sess *sqlx.DB

	// tx is the transaction installed by BindTx; nil when the database is
	// not bound to a transaction. Every method that prepares statements or
	// looks up collections consults it.
	tx *sqltx.Tx

	connURL          db.ConnectionURL
	schema           *schema.DatabaseSchema
	cachedStatements *cache.Cache

	// collections caches collection handles by name; Collection writes it
	// while holding collectionsMu.
	collections   map[string]db.Collection
	collectionsMu sync.Mutex

	builder  builder.QueryBuilder
	template *sqlgen.Template
}

// cachedStatement is the value stored in the prepared-statement cache. It
// keeps the compiled SQL text next to the prepared statement so that a
// cache hit can still report the query (e.g. to the debug logger).
type cachedStatement struct {
	*sqlx.Stmt        // the prepared statement itself
	query      string // SQL text the statement was prepared from
}

// NewDatabase returns a BaseDatabase that delegates adapter-specific work to
// partial, remembers connURL, and compiles statements with template. The
// result has a query builder and an empty prepared-statement cache, but no
// session: call Bind before running queries.
func NewDatabase(partial PartialDatabase, connURL db.ConnectionURL, template *sqlgen.Template) *BaseDatabase {
	d := new(BaseDatabase)
	d.partial = partial
	d.connURL = connURL
	d.template = template

	// The builder receives d itself so the queries it builds run through
	// this database.
	d.builder = sqlbuilder.NewBuilder(d, template)
	d.cachedStatements = cache.NewCache()

	return d
}

// Session returns the connection pool installed by Bind, or nil if Bind has
// not been called.
func (d *BaseDatabase) Session() *sqlx.DB {
	sess := d.sess
	return sess
}

// Template returns the SQL template this database was constructed with.
func (d *BaseDatabase) Template() *sqlgen.Template {
	tpl := d.template
	return tpl
}

// BindTx wraps tx with sqltx.New and makes it this database's current
// transaction. From then on, statements are prepared on that transaction
// rather than on the session.
func (d *BaseDatabase) BindTx(tx *sqlx.Tx) {
	wrapped := sqltx.New(tx)
	d.tx = wrapped
}

// Tx returns the transaction installed by BindTx, or nil if there is none.
func (d *BaseDatabase) Tx() *sqltx.Tx {
	current := d.tx
	return current
}

// NewSchema replaces the current schema with a fresh one from
// schema.NewDatabaseSchema.
func (d *BaseDatabase) NewSchema() {
	fresh := schema.NewDatabaseSchema()
	d.schema = fresh
}

// Schema returns the current database schema. It may be nil if no schema
// has been set yet.
func (d *BaseDatabase) Schema() *schema.DatabaseSchema {
	current := d.schema
	return current
}

// Bind installs sess as the underlying connection pool, then resets the
// collection cache and asks the adapter to populate the schema if none is
// set. It returns any error from that population step.
func (d *BaseDatabase) Bind(sess *sqlx.DB) error {
	d.sess = sess
	err := d.populate()
	return err
}

// populate starts a fresh collection cache and, only when no schema has been
// set yet, delegates schema loading to the adapter's PopulateSchema.
func (d *BaseDatabase) populate() error {
	d.collections = make(map[string]db.Collection)

	// A schema may already be present, e.g. one shared by Clone.
	if d.schema != nil {
		return nil
	}
	return d.partial.PopulateSchema()
}

// Clone builds a new BaseDatabase for partial that reuses this database's
// connection URL, template and schema. The clone gets its own statement
// cache and has no session until Bind is called on it.
func (d *BaseDatabase) Clone(partial PartialDatabase) *BaseDatabase {
	c := NewDatabase(partial, d.connURL, d.template)
	// Sharing the schema means populate will not call PopulateSchema again
	// when the clone is bound.
	c.schema = d.schema
	return c
}

// Ping checks whether a connection to the database is still alive by pinging
// it, establishing a connection if necessary.
//
// It returns db.ErrNotConnected when no session has been bound, matching the
// behavior of the query methods, instead of panicking on a nil session.
func (d *BaseDatabase) Ping() error {
	if d.sess == nil {
		return db.ErrNotConnected
	}
	return d.sess.Ping()
}

// Close terminates the current database session. It rolls back a bound
// transaction that has not finished yet, clears the prepared-statement cache
// and closes the session. Closing a database with no session is a no-op.
func (d *BaseDatabase) Close() error {
	if d.sess == nil {
		return nil
	}

	pendingTx := d.tx != nil && !d.tx.Done()
	if pendingTx {
		// Best-effort: the rollback result is not reported to the caller.
		d.tx.Rollback()
	}

	d.cachedStatements.Clear()
	return d.sess.Close()
}

// C returns a collection interface.
//
// A handle previously cached by Collection is reused. Otherwise the
// collection is looked up via Collection, and any lookup error is wrapped in
// an adapter.NonExistentCollection so the result is never nil.
func (d *BaseDatabase) C(name string) db.Collection {
	// Collection writes d.collections under collectionsMu, so the read must
	// hold the same lock; an unlocked read racing with that write is a data
	// race on the map.
	d.collectionsMu.Lock()
	c, ok := d.collections[name]
	d.collectionsMu.Unlock()
	if ok {
		return c
	}

	c, err := d.Collection(name)
	if err != nil {
		return &adapter.NonExistentCollection{Err: err}
	}
	return c
}

// Collection returns the table that matches the given name.
//
// It fails with sql.ErrTxDone when the database is bound to a transaction
// that has already finished. It also returns the adapter's TableExists error
// unchanged when the table cannot be used. On success the handle is cached
// by name.
func (d *BaseDatabase) Collection(name string) (db.Collection, error) {
	if d.tx != nil && d.tx.Done() {
		return nil, sql.ErrTxDone
	}

	if err := d.partial.TableExists(name); err != nil {
		return nil, err
	}

	col := d.partial.NewTable(name)

	d.collectionsMu.Lock()
	// The cache is only allocated by populate; allocate it lazily so a call
	// made before Bind cannot panic on a write to a nil map.
	if d.collections == nil {
		d.collections = make(map[string]db.Collection)
	}
	d.collections[name] = col
	d.collectionsMu.Unlock()

	return col, nil
}

// ConnectionURL returns the connection URL this database was created with.
func (d *BaseDatabase) ConnectionURL() db.ConnectionURL {
	url := d.connURL
	return url
}

// Name returns the name of the database as recorded in its schema, or ""
// when no schema has been set yet (instead of dereferencing a nil schema).
func (d *BaseDatabase) Name() string {
	if d.schema == nil {
		return ""
	}
	return d.schema.Name
}

// Exec compiles and executes a statement that does not return any rows.
//
// The prepared statement comes from prepareStatement, which reuses cached
// statements. If the adapter implements HasExecStatement, execution is
// delegated to it; otherwise the prepared statement is executed directly.
// When db.Debug is set, the query, args, resulting error and timing are
// logged once the call returns.
func (d *BaseDatabase) Exec(stmt *sqlgen.Statement, args ...interface{}) (sql.Result, error) {
	var (
		query string
		res   sql.Result
		err   error
	)

	if db.Debug {
		start := time.Now().UnixNano()
		// The closure reads err when the function returns. Every path below
		// must therefore assign to err, not return a fresh value, or
		// execution failures would be logged as successes.
		defer func() {
			debug.Log(query, args, err, start, time.Now().UnixNano())
		}()
	}

	var p *sqlx.Stmt
	if p, query, err = d.prepareStatement(stmt); err != nil {
		return nil, err
	}

	if execer, ok := d.partial.(HasExecStatement); ok {
		res, err = execer.Exec(p, args...)
	} else {
		res, err = p.Exec(args...)
	}
	return res, err
}

// Query compiles and executes a statement that returns rows.
//
// The prepared statement comes from prepareStatement, which reuses cached
// statements. When db.Debug is set, the query, args, resulting error and
// timing are logged once the call returns. The caller owns the returned rows
// and must close them.
func (d *BaseDatabase) Query(stmt *sqlgen.Statement, args ...interface{}) (*sqlx.Rows, error) {
	var (
		query string
		rows  *sqlx.Rows
		err   error
	)

	if db.Debug {
		start := time.Now().UnixNano()
		// err is read when the function returns, so the Queryx result must
		// be assigned to it for failures to show up in the log.
		defer func() {
			debug.Log(query, args, err, start, time.Now().UnixNano())
		}()
	}

	var p *sqlx.Stmt
	if p, query, err = d.prepareStatement(stmt); err != nil {
		return nil, err
	}

	rows, err = p.Queryx(args...)
	return rows, err
}

// QueryRow compiles and executes a statement that returns at most one row.
//
// Only preparation failures are returned as an error here. QueryRowx reports
// execution errors through the returned row when it is scanned. When
// db.Debug is set, the query, args, preparation error and timing are logged
// once the call returns.
func (d *BaseDatabase) QueryRow(stmt *sqlgen.Statement, args ...interface{}) (*sqlx.Row, error) {
	var (
		query string
		err   error
	)

	if db.Debug {
		start := time.Now().UnixNano()
		defer func() {
			debug.Log(query, args, err, start, time.Now().UnixNano())
		}()
	}

	var prepared *sqlx.Stmt
	prepared, query, err = d.prepareStatement(stmt)
	if err != nil {
		return nil, err
	}

	return prepared.QueryRowx(args...), nil
}

// Builder returns the query builder created for this database by
// NewDatabase.
func (d *BaseDatabase) Builder() builder.QueryBuilder {
	qb := d.builder
	return qb
}

// Driver returns the underlying driver handle. This is the bound
// transaction's Tx when a transaction is set, otherwise the *sqlx.DB
// session (which may be nil before Bind).
func (d *BaseDatabase) Driver() interface{} {
	if d.tx == nil {
		return d.sess
	}
	return d.tx.Tx
}

// prepareStatement returns a prepared statement for stmt together with the
// SQL text it was compiled to.
//
// Results are memoized in d.cachedStatements, keyed by stmt. On a miss the
// adapter compiles the statement. It is then prepared on the bound
// transaction if there is one, otherwise on the session. Fails with
// db.ErrNotConnected when no session is bound; on a preparation failure the
// compiled query is still returned alongside the error.
func (d *BaseDatabase) prepareStatement(stmt *sqlgen.Statement) (p *sqlx.Stmt, query string, err error) {
	if d.sess == nil {
		return nil, "", db.ErrNotConnected
	}

	if raw, hit := d.cachedStatements.ReadRaw(stmt); hit {
		entry := raw.(*cachedStatement)
		return entry.Stmt, entry.query, nil
	}

	query = d.partial.CompileAndReplacePlaceholders(stmt)

	if d.tx == nil {
		p, err = d.sess.Preparex(query)
	} else {
		// NOTE(review): database/sql closes statements prepared on a
		// transaction when it commits or rolls back, yet they are cached
		// here like session statements. Confirm a tx-bound database is
		// never reused after its transaction ends.
		p, err = d.tx.Preparex(query)
	}
	if err != nil {
		return nil, query, err
	}

	d.cachedStatements.Write(stmt, &cachedStatement{Stmt: p, query: query})
	return p, query, nil
}

// waitForConnMu serializes WaitForConnection calls across every
// BaseDatabase in the process, so connection attempts are made one at a
// time.
var waitForConnMu = sync.Mutex{}

// WaitForConnection calls connectFn until it succeeds or the retry window
// closes.
//
// It retries only while the adapter translates the failure to
// db.ErrTooManyClients; any other error is returned immediately. Between
// attempts it sleeps with exponential backoff, starting at 10ms and doubling
// until the delay reaches at least 500ms. New attempts are started for up to
// 5 seconds after the global lock is acquired; after that it returns
// db.ErrGivingUpTryingToConnect.
//
// All callers in the process are serialized by waitForConnMu. This gives
// first-come, first-served ordering and limits how many connection attempts
// (and file descriptors) are open at once.
func (d *BaseDatabase) WaitForConnection(connectFn func() error) error {
	const (
		retryWindow  = 5 * time.Second
		firstBackoff = 10 * time.Millisecond
		backoffLimit = 500 * time.Millisecond
	)

	waitForConnMu.Lock()
	defer waitForConnMu.Unlock()

	backoff := firstBackoff
	for start := time.Now(); time.Since(start) < retryWindow; {
		err := connectFn()
		if err == nil {
			return nil
		}
		// Only a "too many clients" reply from the server is worth waiting
		// out; every other failure is reported to the caller right away.
		if d.partial.Err(err) != db.ErrTooManyClients {
			return err
		}
		time.Sleep(backoff)
		if backoff < backoffLimit {
			backoff *= 2
		}
	}

	return db.ErrGivingUpTryingToConnect
}