good morning!!!!

Skip to content
Snippets Groups Projects
Commit c7822c33 authored by José Carlos Nieto's avatar José Carlos Nieto
Browse files

Migrating SQLite3 adapter and tests to sqlx.

parent c04f3df4
Branches
Tags
No related merge requests found
...@@ -47,10 +47,10 @@ CREATE TABLE data_types ( ...@@ -47,10 +47,10 @@ CREATE TABLE data_types (
_rune integer, _rune integer,
_bool integer, _bool integer,
_string text, _string text,
_date text, _date datetime,
_nildate text, _nildate datetime,
_ptrdate text, _ptrdate datetime,
_bytea text, _defaultdate datetime default current_timestamp,
_time text _time text
); );
......
...@@ -25,11 +25,9 @@ import ( ...@@ -25,11 +25,9 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"strings" "strings"
"time"
"database/sql" "database/sql"
"menteslibres.net/gosexy/to"
"upper.io/db" "upper.io/db"
"upper.io/db/util/sqlgen" "upper.io/db/util/sqlgen"
"upper.io/db/util/sqlutil" "upper.io/db/util/sqlutil"
...@@ -96,7 +94,6 @@ func whereValues(term interface{}) (where sqlgen.Where, args []interface{}) { ...@@ -96,7 +94,6 @@ func whereValues(term interface{}) (where sqlgen.Where, args []interface{}) {
} }
func interfaceArgs(value interface{}) (args []interface{}) { func interfaceArgs(value interface{}) (args []interface{}) {
if value == nil { if value == nil {
return nil return nil
} }
...@@ -112,14 +109,14 @@ func interfaceArgs(value interface{}) (args []interface{}) { ...@@ -112,14 +109,14 @@ func interfaceArgs(value interface{}) (args []interface{}) {
args = make([]interface{}, total) args = make([]interface{}, total)
for i = 0; i < total; i++ { for i = 0; i < total; i++ {
args[i] = toInternal(v.Index(i).Interface()) args[i] = v.Index(i).Interface()
} }
return args return args
} }
return nil return nil
default: default:
args = []interface{}{toInternal(value)} args = []interface{}{value}
} }
return args return args
...@@ -229,7 +226,7 @@ func (c *table) Append(item interface{}) (interface{}, error) { ...@@ -229,7 +226,7 @@ func (c *table) Append(item interface{}) (interface{}, error) {
var values sqlgen.Values var values sqlgen.Values
var arguments []interface{} var arguments []interface{}
cols, vals, err := c.FieldValues(item, toInternal) cols, vals, err := c.FieldValues(item)
// Error ocurred, stop appending. // Error ocurred, stop appending.
if err != nil { if err != nil {
...@@ -334,61 +331,3 @@ func (c *table) Exists() bool { ...@@ -334,61 +331,3 @@ func (c *table) Exists() bool {
func (c *table) Name() string { func (c *table) Name() string {
return strings.Join(c.names, `, `) return strings.Join(c.names, `, `)
} }
// Converts a Go value into internal database representation.
func toInternal(val interface{}) interface{} {
switch t := val.(type) {
case db.Marshaler:
return t
case []byte:
return string(t)
case *time.Time:
if t == nil || t.IsZero() {
return sqlgen.Value{sqlgen.Raw{sqlNull}}
}
return t.Format(DateFormat)
case time.Time:
if t.IsZero() {
return sqlgen.Value{sqlgen.Raw{sqlNull}}
}
return t.Format(DateFormat)
case time.Duration:
return fmt.Sprintf(TimeFormat, int(t/time.Hour), int(t/time.Minute%60), int(t/time.Second%60), t%time.Second/time.Millisecond)
case sql.NullBool:
if t.Valid {
if t.Bool {
return toInternal(t.Bool)
}
return false
}
return sqlgen.Value{sqlgen.Raw{sqlNull}}
case sql.NullFloat64:
if t.Valid {
if t.Float64 != 0.0 {
return toInternal(t.Float64)
}
return float64(0)
}
return sqlgen.Value{sqlgen.Raw{sqlNull}}
case sql.NullInt64:
if t.Valid {
if t.Int64 != 0 {
return toInternal(t.Int64)
}
return 0
}
return sqlgen.Value{sqlgen.Raw{sqlNull}}
case sql.NullString:
if t.Valid {
return toInternal(t.String)
}
return sqlgen.Value{sqlgen.Raw{sqlNull}}
case bool:
if t == true {
return `1`
}
return `0`
}
return to.String(val)
}
...@@ -27,7 +27,9 @@ import ( ...@@ -27,7 +27,9 @@ import (
"os" "os"
"strings" "strings"
"time" "time"
// Importing SQLite3 driver. // Importing SQLite3 driver.
"github.com/jmoiron/sqlx"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"upper.io/cache" "upper.io/cache"
"upper.io/db" "upper.io/db"
...@@ -41,13 +43,6 @@ const ( ...@@ -41,13 +43,6 @@ const (
Adapter = `sqlite` Adapter = `sqlite`
) )
var (
// DateFormat defines the format used for storing dates.
DateFormat = "2006-01-02 15:04:05"
// TimeFormat defines the format used for storing time values.
TimeFormat = "%d:%02d:%02d.%d"
)
var template *sqlgen.Template var template *sqlgen.Template
var ( var (
...@@ -56,7 +51,7 @@ var ( ...@@ -56,7 +51,7 @@ var (
type source struct { type source struct {
connURL db.ConnectionURL connURL db.ConnectionURL
session *sql.DB session *sqlx.DB
tx *tx tx *tx
schema *schema.DatabaseSchema schema *schema.DatabaseSchema
// columns property was introduced so we could query PRAGMA data only once // columns property was introduced so we could query PRAGMA data only once
...@@ -177,8 +172,8 @@ func (s *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result, ...@@ -177,8 +172,8 @@ func (s *source) doExec(stmt sqlgen.Statement, args ...interface{}) (sql.Result,
return res, err return res, err
} }
func (s *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sql.Rows, error) { func (s *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Rows, error) {
var rows *sql.Rows var rows *sqlx.Rows
var query string var query string
var err error var err error
var start, end int64 var start, end int64
...@@ -197,17 +192,17 @@ func (s *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sql.Rows, ...@@ -197,17 +192,17 @@ func (s *source) doQuery(stmt sqlgen.Statement, args ...interface{}) (*sql.Rows,
query = stmt.Compile(template) query = stmt.Compile(template)
if s.tx != nil { if s.tx != nil {
rows, err = s.tx.sqlTx.Query(query, args...) rows, err = s.tx.sqlTx.Queryx(query, args...)
} else { } else {
rows, err = s.session.Query(query, args...) rows, err = s.session.Queryx(query, args...)
} }
return rows, err return rows, err
} }
func (s *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sql.Row, error) { func (s *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sqlx.Row, error) {
var query string var query string
var row *sql.Row var row *sqlx.Row
var err error var err error
var start, end int64 var start, end int64
...@@ -225,16 +220,16 @@ func (s *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sql.Ro ...@@ -225,16 +220,16 @@ func (s *source) doQueryRow(stmt sqlgen.Statement, args ...interface{}) (*sql.Ro
query = stmt.Compile(template) query = stmt.Compile(template)
if s.tx != nil { if s.tx != nil {
row = s.tx.sqlTx.QueryRow(query, args...) row = s.tx.sqlTx.QueryRowx(query, args...)
} else { } else {
row = s.session.QueryRow(query, args...) row = s.session.QueryRowx(query, args...)
} }
return row, err return row, err
} }
func (s *source) doRawQuery(query string, args ...interface{}) (*sql.Rows, error) { func (s *source) doRawQuery(query string, args ...interface{}) (*sqlx.Rows, error) {
var rows *sql.Rows var rows *sqlx.Rows
var err error var err error
var start, end int64 var start, end int64
...@@ -250,9 +245,9 @@ func (s *source) doRawQuery(query string, args ...interface{}) (*sql.Rows, error ...@@ -250,9 +245,9 @@ func (s *source) doRawQuery(query string, args ...interface{}) (*sql.Rows, error
} }
if s.tx != nil { if s.tx != nil {
rows, err = s.tx.sqlTx.Query(query, args...) rows, err = s.tx.sqlTx.Queryx(query, args...)
} else { } else {
rows, err = s.session.Query(query, args...) rows, err = s.session.Queryx(query, args...)
} }
return rows, err return rows, err
...@@ -287,9 +282,9 @@ func (s *source) Clone() (db.Database, error) { ...@@ -287,9 +282,9 @@ func (s *source) Clone() (db.Database, error) {
func (s *source) Transaction() (db.Tx, error) { func (s *source) Transaction() (db.Tx, error) {
var err error var err error
var clone *source var clone *source
var sqlTx *sql.Tx var sqlTx *sqlx.Tx
if sqlTx, err = s.session.Begin(); err != nil { if sqlTx, err = s.session.Beginx(); err != nil {
return nil, err return nil, err
} }
...@@ -310,7 +305,7 @@ func (s *source) Setup(conn db.ConnectionURL) error { ...@@ -310,7 +305,7 @@ func (s *source) Setup(conn db.ConnectionURL) error {
return s.Open() return s.Open()
} }
// Returns the underlying *sql.DB instance. // Returns the underlying *sqlx.DB instance.
func (s *source) Driver() interface{} { func (s *source) Driver() interface{} {
return s.session return s.session
} }
...@@ -334,10 +329,12 @@ func (s *source) Open() error { ...@@ -334,10 +329,12 @@ func (s *source) Open() error {
s.connURL = conn s.connURL = conn
} }
if s.session, err = sql.Open(`sqlite3`, s.connURL.String()); err != nil { if s.session, err = sqlx.Open(`sqlite3`, s.connURL.String()); err != nil {
return err return err
} }
s.session.Mapper = sqlutil.NewMapper()
if err = s.populateSchema(); err != nil { if err = s.populateSchema(); err != nil {
return err return err
} }
...@@ -404,7 +401,7 @@ func (s *source) Collections() (collections []string, err error) { ...@@ -404,7 +401,7 @@ func (s *source) Collections() (collections []string, err error) {
} }
// Executing statement. // Executing statement.
var rows *sql.Rows var rows *sqlx.Rows
if rows, err = s.doQuery(stmt); err != nil { if rows, err = s.doQuery(stmt); err != nil {
return nil, err return nil, err
} }
...@@ -434,7 +431,7 @@ func (s *source) Collections() (collections []string, err error) { ...@@ -434,7 +431,7 @@ func (s *source) Collections() (collections []string, err error) {
func (s *source) tableExists(names ...string) error { func (s *source) tableExists(names ...string) error {
var stmt sqlgen.Statement var stmt sqlgen.Statement
var err error var err error
var rows *sql.Rows var rows *sqlx.Rows
for i := range names { for i := range names {
......
...@@ -31,6 +31,7 @@ package sqlite ...@@ -31,6 +31,7 @@ package sqlite
import ( import (
"database/sql" "database/sql"
"errors" "errors"
"fmt"
"math/rand" "math/rand"
"os" "os"
"reflect" "reflect"
...@@ -39,6 +40,7 @@ import ( ...@@ -39,6 +40,7 @@ import (
"testing" "testing"
"time" "time"
"github.com/jmoiron/sqlx"
"menteslibres.net/gosexy/to" "menteslibres.net/gosexy/to"
"upper.io/db" "upper.io/db"
"upper.io/db/util/sqlutil" "upper.io/db/util/sqlutil"
...@@ -48,34 +50,39 @@ const ( ...@@ -48,34 +50,39 @@ const (
database = `_dumps/gotest.sqlite3.db` database = `_dumps/gotest.sqlite3.db`
) )
const (
testTimeZone = "Canada/Eastern"
)
var settings = ConnectionURL{ var settings = ConnectionURL{
Database: database, Database: database,
} }
// Structure for testing conversions and datatypes. // Structure for testing conversions and datatypes.
type testValuesStruct struct { type testValuesStruct struct {
Uint uint `field:"_uint"` Uint uint `db:"_uint"`
Uint8 uint8 `field:"_uint8"` Uint8 uint8 `db:"_uint8"`
Uint16 uint16 `field:"_uint16"` Uint16 uint16 `db:"_uint16"`
Uint32 uint32 `field:"_uint32"` Uint32 uint32 `db:"_uint32"`
Uint64 uint64 `field:"_uint64"` Uint64 uint64 `db:"_uint64"`
Int int `field:"_int"` Int int `db:"_int"`
Int8 int8 `field:"_int8"` Int8 int8 `db:"_int8"`
Int16 int16 `field:"_int16"` Int16 int16 `db:"_int16"`
Int32 int32 `field:"_int32"` Int32 int32 `db:"_int32"`
Int64 int64 `field:"_int64"` Int64 int64 `db:"_int64"`
Float32 float32 `field:"_float32"` Float32 float32 `db:"_float32"`
Float64 float64 `field:"_float64"` Float64 float64 `db:"_float64"`
Bool bool `field:"_bool"` Bool bool `db:"_bool"`
String string `field:"_string"` String string `db:"_string"`
Date time.Time `field:"_date"` Date time.Time `db:"_date"`
DateN *time.Time `field:"_nildate"` DateN *time.Time `db:"_nildate"`
DateP *time.Time `field:"_ptrdate"` DateP *time.Time `db:"_ptrdate"`
Time time.Duration `field:"_time"` DateD *time.Time `db:"_defaultdate,omitempty"`
Time int64 `db:"_time"`
} }
type artistWithInt64Key struct { type artistWithInt64Key struct {
...@@ -114,7 +121,14 @@ func (item *itemWithKey) SetID(keys map[string]interface{}) error { ...@@ -114,7 +121,14 @@ func (item *itemWithKey) SetID(keys map[string]interface{}) error {
var testValues testValuesStruct var testValues testValuesStruct
func init() { func init() {
t := time.Date(2012, 7, 28, 1, 2, 3, 0, time.Local) loc, err := time.LoadLocation(testTimeZone)
if err != nil {
panic(err.Error())
}
t := time.Date(2011, 7, 28, 1, 2, 3, 0, loc)
tnz := time.Date(2012, 7, 28, 1, 2, 3, 0, time.Local)
testValues = testValuesStruct{ testValues = testValuesStruct{
1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
...@@ -124,8 +138,9 @@ func init() { ...@@ -124,8 +138,9 @@ func init() {
"Hello world!", "Hello world!",
t, t,
nil, nil,
&t, &tnz,
time.Second * time.Duration(7331), nil,
int64(time.Second * time.Duration(7331)),
} }
} }
...@@ -564,8 +579,8 @@ func TestResultFetch(t *testing.T) { ...@@ -564,8 +579,8 @@ func TestResultFetch(t *testing.T) {
// Dumping into a tagged struct. // Dumping into a tagged struct.
rowStruct2 := struct { rowStruct2 := struct {
Value1 uint64 `field:"id"` Value1 uint64 `db:"id"`
Value2 string `field:"name"` Value2 string `db:"name"`
}{} }{}
res = artist.Find() res = artist.Find()
...@@ -633,8 +648,8 @@ func TestResultFetch(t *testing.T) { ...@@ -633,8 +648,8 @@ func TestResultFetch(t *testing.T) {
// Dumping into an slice of tagged structs. // Dumping into an slice of tagged structs.
allRowsStruct2 := []struct { allRowsStruct2 := []struct {
Value1 uint64 `field:"id"` Value1 uint64 `db:"id"`
Value2 string `field:"name"` Value2 string `db:"name"`
}{} }{}
res = artist.Find() res = artist.Find()
...@@ -1036,9 +1051,9 @@ func TestRawRelations(t *testing.T) { ...@@ -1036,9 +1051,9 @@ func TestRawRelations(t *testing.T) {
func TestRawQuery(t *testing.T) { func TestRawQuery(t *testing.T) {
var sess db.Database var sess db.Database
var rows *sql.Rows var rows *sqlx.Rows
var err error var err error
var drv *sql.DB var drv *sqlx.DB
type publicationType struct { type publicationType struct {
ID int64 `db:"id,omitempty"` ID int64 `db:"id,omitempty"`
...@@ -1052,9 +1067,9 @@ func TestRawQuery(t *testing.T) { ...@@ -1052,9 +1067,9 @@ func TestRawQuery(t *testing.T) {
defer sess.Close() defer sess.Close()
drv = sess.Driver().(*sql.DB) drv = sess.Driver().(*sqlx.DB)
rows, err = drv.Query(` rows, err = drv.Queryx(`
SELECT SELECT
p.id, p.id,
p.title AS publication_title, p.title AS publication_title,
...@@ -1349,10 +1364,24 @@ func TestDataTypes(t *testing.T) { ...@@ -1349,10 +1364,24 @@ func TestDataTypes(t *testing.T) {
// Trying to dump the subject into an empty structure of the same type. // Trying to dump the subject into an empty structure of the same type.
var item testValuesStruct var item testValuesStruct
res.One(&item) if err = res.One(&item); err != nil {
t.Fatal(err)
}
if item.DateD == nil {
t.Fatal("Expecting default date to have been set on append")
}
// Copy the default date (this value is set by the database)
testValues.DateD = item.DateD
loc, _ := time.LoadLocation(testTimeZone)
item.Date = item.Date.In(loc)
// The original value and the test subject must match. // The original value and the test subject must match.
if reflect.DeepEqual(item, testValues) == false { if reflect.DeepEqual(item, testValues) == false {
fmt.Printf("item1: %v\n", item)
fmt.Printf("test2: %v\n", testValues)
t.Fatalf("Struct is different.") t.Fatalf("Struct is different.")
} }
} }
......
...@@ -22,10 +22,10 @@ ...@@ -22,10 +22,10 @@
package sqlite package sqlite
import ( import (
"database/sql"
"fmt" "fmt"
"strings" "strings"
"github.com/jmoiron/sqlx"
"upper.io/db" "upper.io/db"
"upper.io/db/util/sqlgen" "upper.io/db/util/sqlgen"
"upper.io/db/util/sqlutil" "upper.io/db/util/sqlutil"
...@@ -37,7 +37,7 @@ type counter struct { ...@@ -37,7 +37,7 @@ type counter struct {
type result struct { type result struct {
table *table table *table
cursor *sql.Rows // This is the main query cursor. It starts as a nil value. cursor *sqlx.Rows // This is the main query cursor. It starts as a nil value.
limit sqlgen.Limit limit sqlgen.Limit
offset sqlgen.Offset offset sqlgen.Offset
columns sqlgen.Columns columns sqlgen.Columns
...@@ -92,6 +92,7 @@ func (r *result) Group(fields ...interface{}) db.Result { ...@@ -92,6 +92,7 @@ func (r *result) Group(fields ...interface{}) db.Result {
groupByColumns := make(sqlgen.GroupBy, 0, len(fields)) groupByColumns := make(sqlgen.GroupBy, 0, len(fields))
l := len(fields) l := len(fields)
for i := 0; i < l; i++ { for i := 0; i < l; i++ {
switch value := fields[i].(type) { switch value := fields[i].(type) {
// Maybe other types? // Maybe other types?
...@@ -217,35 +218,31 @@ func (r *result) One(dst interface{}) error { ...@@ -217,35 +218,31 @@ func (r *result) One(dst interface{}) error {
} }
// Fetches the next result from the resultset. // Fetches the next result from the resultset.
func (r *result) Next(dst interface{}) error { func (r *result) Next(dst interface{}) (err error) {
var err error
// Current cursor.
err = r.setCursor()
if err != nil { if err = r.setCursor(); err != nil {
r.Close() r.Close()
return err
} }
// Fetching the next result from the cursor. if err = sqlutil.FetchRow(r.cursor, dst); err != nil {
err = sqlutil.FetchRow(r.cursor, dst)
if err != nil {
r.Close() r.Close()
return err
} }
return err return nil
} }
// Removes the matching items from the collection. // Removes the matching items from the collection.
func (r *result) Remove() error { func (r *result) Remove() error {
var err error var err error
_, err = r.table.source.doExec(sqlgen.Statement{ _, err = r.table.source.doExec(sqlgen.Statement{
Type: sqlgen.SqlDelete, Type: sqlgen.SqlDelete,
Table: sqlgen.Table{r.table.Name()}, Table: sqlgen.Table{r.table.Name()},
Where: r.where, Where: r.where,
}, r.arguments...) }, r.arguments...)
return err return err
} }
...@@ -254,7 +251,10 @@ func (r *result) Remove() error { ...@@ -254,7 +251,10 @@ func (r *result) Remove() error {
// struct. // struct.
func (r *result) Update(values interface{}) error { func (r *result) Update(values interface{}) error {
ff, vv, err := r.table.FieldValues(values, toInternal) ff, vv, err := r.table.FieldValues(values)
if err != nil {
return err
}
total := len(ff) total := len(ff)
...@@ -277,8 +277,7 @@ func (r *result) Update(values interface{}) error { ...@@ -277,8 +277,7 @@ func (r *result) Update(values interface{}) error {
} }
// Closes the result set. // Closes the result set.
func (r *result) Close() error { func (r *result) Close() (err error) {
var err error
if r.cursor != nil { if r.cursor != nil {
err = r.cursor.Close() err = r.cursor.Close()
r.cursor = nil r.cursor = nil
...@@ -286,11 +285,11 @@ func (r *result) Close() error { ...@@ -286,11 +285,11 @@ func (r *result) Close() error {
return err return err
} }
// Counting the elements that will be returned. // Counts the elements within the main conditions of the set.
func (r *result) Count() (uint64, error) { func (r *result) Count() (uint64, error) {
var count counter var count counter
rows, err := r.table.source.doQuery(sqlgen.Statement{ row, err := r.table.source.doQueryRow(sqlgen.Statement{
Type: sqlgen.SqlSelectCount, Type: sqlgen.SqlSelectCount,
Table: sqlgen.Table{r.table.Name()}, Table: sqlgen.Table{r.table.Name()},
Where: r.where, Where: r.where,
...@@ -300,8 +299,8 @@ func (r *result) Count() (uint64, error) { ...@@ -300,8 +299,8 @@ func (r *result) Count() (uint64, error) {
return 0, err return 0, err
} }
defer rows.Close() err = row.Scan(&count.Total)
if err = sqlutil.FetchRow(rows, &count); err != nil { if err != nil {
return 0, err return 0, err
} }
......
...@@ -22,12 +22,12 @@ ...@@ -22,12 +22,12 @@
package sqlite package sqlite
import ( import (
"database/sql" "github.com/jmoiron/sqlx"
) )
type tx struct { type tx struct {
*source *source
sqlTx *sql.Tx sqlTx *sqlx.Tx
done bool done bool
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment