diff --git a/sqlite/collection.go b/sqlite/collection.go index ac770f6612f7c54a34499e3ed1151fd490a3c98e..5139885e6ab0bcb6f1bed2e38c57f2e2b4988173 100644 --- a/sqlite/collection.go +++ b/sqlite/collection.go @@ -35,9 +35,9 @@ import ( const defaultOperator = `=` -type Table struct { +type table struct { sqlutil.T - source *Source + source *source names []string } @@ -106,9 +106,8 @@ func interfaceArgs(value interface{}) (args []interface{}) { } return args - } else { - return nil } + return nil default: args = []interface{}{toInternal(value)} } @@ -175,10 +174,10 @@ func conditionValues(cond db.Cond) (columnValues sqlgen.ColumnValues, args []int return columnValues, args } -func (self *Table) Find(terms ...interface{}) db.Result { +func (self *table) Find(terms ...interface{}) db.Result { where, arguments := whereValues(terms) - result := &Result{ + result := &result{ table: self, where: where, arguments: arguments, @@ -187,7 +186,7 @@ func (self *Table) Find(terms ...interface{}) db.Result { return result } -func (self *Table) tableN(i int) string { +func (self *table) tableN(i int) string { if len(self.names) > i { chunks := strings.SplitN(self.names[i], " ", 2) if len(chunks) > 0 { @@ -198,7 +197,7 @@ func (self *Table) tableN(i int) string { } // Deletes all the rows within the collection. -func (self *Table) Truncate() error { +func (self *table) Truncate() error { _, err := self.source.doExec(sqlgen.Statement{ Type: sqlgen.SqlTruncate, @@ -213,7 +212,7 @@ func (self *Table) Truncate() error { } // Appends an item (map or struct) into the collection. -func (self *Table) Append(item interface{}) (interface{}, error) { +func (self *table) Append(item interface{}) (interface{}, error) { cols, vals, err := self.FieldValues(item, toInternal) @@ -251,14 +250,14 @@ func (self *Table) Append(item interface{}) (interface{}, error) { } // Returns true if the collection exists. -func (self *Table) Exists() bool { +func (self *table) Exists() bool { if err := self.source.tableExists(self.names...); err != nil { return false } return true } -func (self *Table) Name() string { +func (self *table) Name() string { return strings.Join(self.names, `, `) } @@ -274,9 +273,8 @@ func toInternal(val interface{}) interface{} { case bool: if t == true { return `1` - } else { - return `0` } + return `0` } return to.String(val) } diff --git a/sqlite/database.go b/sqlite/database.go index 400314f9a8994c7c95573bccd5449f9e47ff7fc4..554e40a114dda54748eeaa93971cb4e650e9b019 100644 --- a/sqlite/database.go +++ b/sqlite/database.go @@ -25,8 +25,6 @@ import ( "database/sql" "fmt" "os" - "reflect" - "regexp" "strings" _ "github.com/mattn/go-sqlite3" @@ -35,6 +33,7 @@ import ( "upper.io/db/util/sqlutil" ) +// Public adapters name under which this adapter registers itself. const Adapter = `sqlite` var ( @@ -42,17 +41,15 @@ var ( DateFormat = "2006-01-02 15:04:05" // Format for saving times. TimeFormat = "%d:%02d:%02d.%d" - SSLMode = "disable" ) var template *sqlgen.Template var ( - columnPattern = regexp.MustCompile(`^([a-zA-Z]+)\(?([0-9,]+)?\)?\s?([a-zA-Z]*)?`) sqlPlaceholder = sqlgen.Value{sqlgen.Raw{`?`}} ) -type Source struct { +type source struct { config db.Settings session *sql.DB collections map[string]db.Collection @@ -60,8 +57,7 @@ type Source struct { } type columnSchema_t struct { - ColumnName string `db:"name"` - DataType string `db:"type"` + Name string `db:"name"` } func debugEnabled() bool { @@ -110,10 +106,10 @@ func init() { sqlSelectCountLayout, } - db.Register(Adapter, &Source{}) + db.Register(Adapter, &source{}) } -func (self *Source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) { +func (self *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, error) { var query string var res sql.Result var err error @@ -137,7 +133,7 @@ func (self *Source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Resu return res, err } -func (self *Source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sql.Rows, error) { +func (self *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sql.Rows, error) { var rows *sql.Rows var query string var err error @@ -161,7 +157,7 @@ func (self *Source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sql.Ro return rows, err } -func (self *Source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sql.Row, error) { +func (self *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sql.Row, error) { var query string var row *sql.Row var err error @@ -185,7 +181,7 @@ func (self *Source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sql return row, err } -func (self *Source) doRawQuery(query string, args ...interface{}) (*sql.Rows, error) { +func (self *source) doRawQuery(query string, args ...interface{}) (*sql.Rows, error) { var rows *sql.Rows var err error @@ -207,18 +203,18 @@ func (self *Source) doRawQuery(query string, args ...interface{}) (*sql.Rows, er } // Returns the string name of the database. -func (self *Source) Name() string { +func (self *source) Name() string { return self.config.Database } // Ping verifies a connection to the database is still alive, // establishing a connection if necessary. -func (self *Source) Ping() error { +func (self *source) Ping() error { return self.session.Ping() } -func (self *Source) clone() (*Source, error) { - src := &Source{} +func (self *source) clone() (*source, error) { + src := &source{} src.Setup(self.config) if err := src.Open(); err != nil { @@ -228,13 +224,13 @@ func (self *Source) clone() (*Source, error) { return src, nil } -func (self *Source) Clone() (db.Database, error) { +func (self *source) Clone() (db.Database, error) { return self.clone() } -func (self *Source) Transaction() (db.Tx, error) { +func (self *source) Transaction() (db.Tx, error) { var err error - var clone *Source + var clone *source var sqlTx *sql.Tx if sqlTx, err = self.session.Begin(); err != nil { @@ -245,7 +241,7 @@ func (self *Source) Transaction() (db.Tx, error) { return nil, err } - tx := &Tx{clone} + tx := &tx{clone} clone.tx = sqlTx @@ -253,19 +249,19 @@ func (self *Source) Transaction() (db.Tx, error) { } // Stores database settings. -func (self *Source) Setup(config db.Settings) error { +func (self *source) Setup(config db.Settings) error { self.config = config self.collections = make(map[string]db.Collection) return self.Open() } // Returns the underlying *sql.DB instance. -func (self *Source) Driver() interface{} { +func (self *source) Driver() interface{} { return self.session } // Attempts to connect to a database using the stored settings. -func (self *Source) Open() error { +func (self *source) Open() error { var err error if self.config.Database == "" { @@ -280,7 +276,7 @@ func (self *Source) Open() error { } // Closes the current database session. -func (self *Source) Close() error { +func (self *source) Close() error { if self.session != nil { return self.session.Close() } @@ -288,13 +284,13 @@ func (self *Source) Close() error { } // Changes the active database. -func (self *Source) Use(database string) error { +func (self *source) Use(database string) error { self.config.Database = database return self.Open() } // Drops the currently active database. -func (self *Source) Drop() error { +func (self *source) Drop() error { _, err := self.doQuery(sqlgen.Statement{ Type: sqlgen.SqlDropDatabase, @@ -305,7 +301,7 @@ func (self *Source) Drop() error { } // Returns a list of all tables within the currently active database. -func (self *Source) Collections() ([]string, error) { +func (self *source) Collections() ([]string, error) { var collections []string var collection string @@ -334,7 +330,7 @@ func (self *Source) Collections() ([]string, error) { return collections, nil } -func (self *Source) tableExists(names ...string) error { +func (self *source) tableExists(names ...string) error { for _, name := range names { rows, err := self.doQuery(sqlgen.Statement{ @@ -364,20 +360,18 @@ func (self *Source) tableExists(names ...string) error { } // Returns a collection instance by name. -func (self *Source) Collection(names ...string) (db.Collection, error) { +func (self *source) Collection(names ...string) (db.Collection, error) { if len(names) == 0 { return nil, db.ErrMissingCollectionName } - col := &Table{ + col := &table{ source: self, names: names, } - col.PrimaryKey = `id` - - columns_t := []columnSchema_t{} + var columns_t []columnSchema_t for _, name := range names { chunks := strings.SplitN(name, " ", 2) @@ -396,48 +390,15 @@ func (self *Source) Collection(names ...string) (db.Collection, error) { return nil, err } - if err = col.FetchRows(&columns_t, rows); err != nil { + if err = sqlutil.FetchRows(rows, &columns_t); err != nil { return nil, err } - col.ColumnTypes = make(map[string]reflect.Kind, len(columns_t)) + col.Columns = make([]string, len(columns_t)) for _, column := range columns_t { - - column.ColumnName = strings.ToLower(column.ColumnName) - column.DataType = strings.ToLower(column.DataType) - - results := columnPattern.FindStringSubmatch(column.DataType) - - // Default properties. - dextra := `` - dtype := `text` - - dtype = results[1] - - if len(results) > 3 { - dextra = results[3] - } - - ctype := reflect.String - - // Guessing datatypes. - switch dtype { - case `integer`: - if dextra == `unsigned` { - ctype = reflect.Uint64 - } else { - ctype = reflect.Int64 - } - case `real`, `numeric`: - ctype = reflect.Float64 - default: - ctype = reflect.String - } - - col.ColumnTypes[column.ColumnName] = ctype + col.Columns = append(col.Columns, strings.ToLower(column.Name)) } - } } diff --git a/sqlite/database_test.go b/sqlite/database_test.go index ff20eaf700483478ef131c42614b8eaab76cf530..73f6c7ccc96395d9f970e6c549475f6c7039c837 100644 --- a/sqlite/database_test.go +++ b/sqlite/database_test.go @@ -38,6 +38,7 @@ import ( "menteslibres.net/gosexy/to" "upper.io/db" + "upper.io/db/util/sqlutil" ) const ( @@ -752,7 +753,53 @@ func TestRawRelations(t *testing.T) { if len(all) != 9 { t.Fatalf("Expecting some rows.") } +} + +func TestRawQuery(t *testing.T) { + var sess db.Database + var rows *sql.Rows + var err error + var drv *sql.DB + type publication_t struct { + Id int64 `db:"id,omitempty"` + Title string `db:"title"` + AuthorId int64 `db:"author_id"` + } + + if sess, err = db.Open(Adapter, settings); err != nil { + t.Fatal(err) + } + + defer sess.Close() + + drv = sess.Driver().(*sql.DB) + + rows, err = drv.Query(` + SELECT + p.id, + p.title AS publication_title, + a.name AS artist_name + FROM + artist AS a, + publication AS p + WHERE + a.id = p.author_id + `) + + if err != nil { + t.Fatal(err) + } + + var all []publication_t + + if err = sqlutil.FetchRows(rows, &all); err != nil { + t.Fatal(err) + } + + if len(all) != 9 { + t.Fatalf("Expecting some rows.") + } } // Attempts to test database transactions. diff --git a/sqlite/result.go b/sqlite/result.go index 0c2fc850bc95ced169c0d96ad6e101b84f1b09ae..9f54902a0876b9d6b948e379c4386393a13a5c32 100644 --- a/sqlite/result.go +++ b/sqlite/result.go @@ -28,14 +28,15 @@ import ( "upper.io/db" "upper.io/db/util/sqlgen" + "upper.io/db/util/sqlutil" ) type counter_t struct { Total uint64 `db:"_t"` } -type Result struct { - table *Table +type result struct { + table *table cursor *sql.Rows // This is the main query cursor. It starts as a nil value. limit sqlgen.Limit offset sqlgen.Offset @@ -46,7 +47,7 @@ type Result struct { } // Executes a SELECT statement that can feed Next(), All() or One(). -func (self *Result) setCursor() error { +func (self *result) setCursor() error { var err error // We need a cursor, if the cursor does not exists yet then we create one. if self.cursor == nil { @@ -63,20 +64,20 @@ func (self *Result) setCursor() error { } // Sets conditions for reducing the working set. -func (self *Result) Where(terms ...interface{}) db.Result { +func (self *result) Where(terms ...interface{}) db.Result { self.where, self.arguments = whereValues(terms) return self } // Determines the maximum limit of results to be returned. -func (self *Result) Limit(n uint) db.Result { +func (self *result) Limit(n uint) db.Result { self.limit = sqlgen.Limit(n) return self } // Determines how many documents will be skipped before starting to grab // results. -func (self *Result) Skip(n uint) db.Result { +func (self *result) Skip(n uint) db.Result { self.offset = sqlgen.Offset(n) return self } @@ -84,7 +85,7 @@ func (self *Result) Skip(n uint) db.Result { // Determines sorting of results according to the provided names. Fields may be // prefixed by - (minus) which means descending order, ascending order would be // used otherwise. -func (self *Result) Sort(fields ...string) db.Result { +func (self *result) Sort(fields ...string) db.Result { sortColumns := make(sqlgen.SortColumns, 0, len(fields)) @@ -114,7 +115,7 @@ func (self *Result) Sort(fields ...string) db.Result { } // Retrieves only the given fields. -func (self *Result) Select(fields ...interface{}) db.Result { +func (self *result) Select(fields ...interface{}) db.Result { self.columns = make(sqlgen.Columns, 0, len(fields)) l := len(fields) @@ -131,7 +132,7 @@ func (self *Result) Select(fields ...interface{}) db.Result { } // Dumps all results into a pointer to an slice of structs or maps. -func (self *Result) All(dst interface{}) error { +func (self *result) All(dst interface{}) error { var err error if self.cursor != nil { @@ -148,13 +149,13 @@ func (self *Result) All(dst interface{}) error { defer self.Close() // Fetching all results within the cursor. - err = self.table.T.FetchRows(dst, self.cursor) + err = sqlutil.FetchRows(self.cursor, dst) return err } // Fetches only one result from the resultset. -func (self *Result) One(dst interface{}) error { +func (self *result) One(dst interface{}) error { var err error if self.cursor != nil { @@ -169,7 +170,7 @@ func (self *Result) One(dst interface{}) error { } // Fetches the next result from the resultset. -func (self *Result) Next(dst interface{}) error { +func (self *result) Next(dst interface{}) error { var err error @@ -181,7 +182,7 @@ func (self *Result) Next(dst interface{}) error { } // Fetching the next result from the cursor. - err = self.table.T.FetchRow(dst, self.cursor) + err = sqlutil.FetchRow(self.cursor, dst) if err != nil { self.Close() @@ -191,7 +192,7 @@ func (self *Result) Next(dst interface{}) error { } // Removes the matching items from the collection. -func (self *Result) Remove() error { +func (self *result) Remove() error { var err error _, err = self.table.source.doExec(sqlgen.Statement{ Type: sqlgen.SqlDelete, @@ -204,7 +205,7 @@ func (self *Result) Remove() error { // Updates matching items from the collection with values of the given map or // struct. -func (self *Result) Update(values interface{}) error { +func (self *result) Update(values interface{}) error { ff, vv, err := self.table.FieldValues(values, toInternal) @@ -229,7 +230,7 @@ func (self *Result) Update(values interface{}) error { } // Closes the result set. -func (self *Result) Close() error { +func (self *result) Close() error { var err error if self.cursor != nil { err = self.cursor.Close() @@ -239,7 +240,7 @@ func (self *Result) Close() error { } // Counting the elements that will be returned. -func (self *Result) Count() (uint64, error) { +func (self *result) Count() (uint64, error) { rows, err := self.table.source.doQuery(sqlgen.Statement{ Type: sqlgen.SqlSelectCount, @@ -256,7 +257,10 @@ func (self *Result) Count() (uint64, error) { defer rows.Close() dst := counter_t{} - self.table.T.FetchRow(&dst, rows) + + if err = sqlutil.FetchRow(rows, &dst); err != nil { + return 0, err + } return dst.Total, nil } diff --git a/sqlite/tx.go b/sqlite/tx.go index 75faadf00524d09cfecd1940b630cced0cba35e8..4386d4043a73f65fe8032ef244a88792b9a07a19 100644 --- a/sqlite/tx.go +++ b/sqlite/tx.go @@ -21,14 +21,14 @@ package sqlite -type Tx struct { - *Source +type tx struct { + *source } -func (self *Tx) Commit() error { - return self.Source.tx.Commit() +func (self *tx) Commit() error { + return self.source.tx.Commit() } -func (self *Tx) Rollback() error { - return self.Source.tx.Rollback() +func (self *tx) Rollback() error { + return self.source.tx.Rollback() }