good morning!!!!

Skip to content
Snippets Groups Projects
Commit 93936726 authored by José Carlos Nieto's avatar José Carlos Nieto
Browse files

Updating postgresql wrapper.

parent e91e4786
Branches
Tags
No related merge requests found
...@@ -33,8 +33,12 @@ import ( ...@@ -33,8 +33,12 @@ import (
"regexp" "regexp"
"strconv" "strconv"
"strings" "strings"
"time"
) )
const dateFormat = "2006-01-02 15:04:05"
const timeFormat = "%d:%02d:%02d"
type pgQuery struct { type pgQuery struct {
Query []string Query []string
SqlArgs []string SqlArgs []string
...@@ -144,15 +148,17 @@ func (t *PostgresqlTable) pgFetchAll(rows sql.Rows) []db.Item { ...@@ -144,15 +148,17 @@ func (t *PostgresqlTable) pgFetchAll(rows sql.Rows) []db.Item {
return items return items
} }
func (pg *PostgresqlDataSource) pgExec(method string, terms ...interface{}) sql.Rows { func (pg *PostgresqlDataSource) pgExec(method string, terms ...interface{}) (sql.Rows, error) {
sn := reflect.ValueOf(pg.session) sn := reflect.ValueOf(pg.session)
fn := sn.MethodByName(method) fn := sn.MethodByName(method)
q := pgCompile(terms) q := pgCompile(terms)
//fmt.Printf("Q: %v\n", q.Query) /*
//fmt.Printf("A: %v\n", q.SqlArgs) fmt.Printf("Q: %v\n", q.Query)
fmt.Printf("A: %v\n", q.SqlArgs)
*/
qs := strings.Join(q.Query, " ") qs := strings.Join(q.Query, " ")
...@@ -168,10 +174,15 @@ func (pg *PostgresqlDataSource) pgExec(method string, terms ...interface{}) sql. ...@@ -168,10 +174,15 @@ func (pg *PostgresqlDataSource) pgExec(method string, terms ...interface{}) sql.
res := fn.Call(args) res := fn.Call(args)
if res[1].IsNil() == false { if res[1].IsNil() == false {
panic(res[1].Elem().Interface().(error)) return sql.Rows{}, res[1].Elem().Interface().(error)
} }
return res[0].Elem().Interface().(sql.Rows) switch res[0].Elem().Interface().(type) {
case sql.Rows:
return res[0].Elem().Interface().(sql.Rows), nil
}
return sql.Rows{}, nil
} }
// Represents a PostgreSQL table. // Represents a PostgreSQL table.
...@@ -245,12 +256,17 @@ func (pg *PostgresqlDataSource) Driver() interface{} { ...@@ -245,12 +256,17 @@ func (pg *PostgresqlDataSource) Driver() interface{} {
func (pg *PostgresqlDataSource) Collections() []string { func (pg *PostgresqlDataSource) Collections() []string {
var collections []string var collections []string
var collection string var collection string
rows, _ := pg.session.Query("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'")
rows, err := pg.session.Query("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'")
if err == nil {
for rows.Next() { for rows.Next() {
rows.Scan(&collection) rows.Scan(&collection)
collections = append(collections, collection) collections = append(collections, collection)
} }
} else {
panic(err)
}
return collections return collections
} }
...@@ -364,18 +380,18 @@ func (t *PostgresqlTable) marshal(where db.Cond) (string, []string) { ...@@ -364,18 +380,18 @@ func (t *PostgresqlTable) marshal(where db.Cond) (string, []string) {
} }
// Deletes all the rows in the table. // Deletes all the rows in the table.
func (t *PostgresqlTable) Truncate() bool { func (t *PostgresqlTable) Truncate() error {
t.parent.pgExec( _, err := t.parent.pgExec(
"Query", "Exec",
fmt.Sprintf("TRUNCATE TABLE %s", pgTable(t.name)), fmt.Sprintf("TRUNCATE TABLE %s", pgTable(t.name)),
) )
return false return err
} }
// Deletes all the rows in the table that match certain conditions. // Deletes all the rows in the table that match certain conditions.
func (t *PostgresqlTable) Remove(terms ...interface{}) bool { func (t *PostgresqlTable) Remove(terms ...interface{}) error {
conditions, cargs := t.compileConditions(terms) conditions, cargs := t.compileConditions(terms)
...@@ -383,17 +399,17 @@ func (t *PostgresqlTable) Remove(terms ...interface{}) bool { ...@@ -383,17 +399,17 @@ func (t *PostgresqlTable) Remove(terms ...interface{}) bool {
conditions = "1 = 1" conditions = "1 = 1"
} }
t.parent.pgExec( _, err := t.parent.pgExec(
"Query", "Exec",
fmt.Sprintf("DELETE FROM %s", pgTable(t.name)), fmt.Sprintf("DELETE FROM %s", pgTable(t.name)),
fmt.Sprintf("WHERE %s", conditions), cargs, fmt.Sprintf("WHERE %s", conditions), cargs,
) )
return true return err
} }
// Modifies all the rows in the table that match certain conditions. // Modifies all the rows in the table that match certain conditions.
func (t *PostgresqlTable) Update(terms ...interface{}) bool { func (t *PostgresqlTable) Update(terms ...interface{}) error {
var fields string var fields string
var fargs db.SqlArgs var fargs db.SqlArgs
...@@ -410,13 +426,13 @@ func (t *PostgresqlTable) Update(terms ...interface{}) bool { ...@@ -410,13 +426,13 @@ func (t *PostgresqlTable) Update(terms ...interface{}) bool {
conditions = "1 = 1" conditions = "1 = 1"
} }
t.parent.pgExec( _, err := t.parent.pgExec(
"Query", "Exec",
fmt.Sprintf("UPDATE %s SET %s", pgTable(t.name), fields), fargs, fmt.Sprintf("UPDATE %s SET %s", pgTable(t.name), fields), fargs,
fmt.Sprintf("WHERE %s", conditions), cargs, fmt.Sprintf("WHERE %s", conditions), cargs,
) )
return true return err
} }
// Returns all the rows in the table that match certain conditions. // Returns all the rows in the table that match certain conditions.
...@@ -457,13 +473,17 @@ func (t *PostgresqlTable) FindAll(terms ...interface{}) []db.Item { ...@@ -457,13 +473,17 @@ func (t *PostgresqlTable) FindAll(terms ...interface{}) []db.Item {
conditions = "1 = 1" conditions = "1 = 1"
} }
rows := t.parent.pgExec( rows, err := t.parent.pgExec(
"Query", "Query",
fmt.Sprintf("SELECT %s FROM %s", fields, pgTable(t.name)), fmt.Sprintf("SELECT %s FROM %s", fields, pgTable(t.name)),
fmt.Sprintf("WHERE %s", conditions), args, fmt.Sprintf("WHERE %s", conditions), args,
limit, offset, limit, offset,
) )
if err != nil {
panic(err)
}
result := t.pgFetchAll(rows) result := t.pgFetchAll(rows)
var relations []sugar.Tuple var relations []sugar.Tuple
...@@ -581,7 +601,7 @@ func (t *PostgresqlTable) FindAll(terms ...interface{}) []db.Item { ...@@ -581,7 +601,7 @@ func (t *PostgresqlTable) FindAll(terms ...interface{}) []db.Item {
} }
// Returns the number of rows in the current table that match certain conditions. // Returns the number of rows in the current table that match certain conditions.
func (t *PostgresqlTable) Count(terms ...interface{}) int { func (t *PostgresqlTable) Count(terms ...interface{}) (int, error) {
terms = append(terms, db.Fields{"COUNT(1) AS _total"}) terms = append(terms, db.Fields{"COUNT(1) AS _total"})
...@@ -591,11 +611,11 @@ func (t *PostgresqlTable) Count(terms ...interface{}) int { ...@@ -591,11 +611,11 @@ func (t *PostgresqlTable) Count(terms ...interface{}) int {
response := result[0].Interface().([]db.Item) response := result[0].Interface().([]db.Item)
if len(response) > 0 { if len(response) > 0 {
val, _ := strconv.Atoi(response[0]["_total"].(string)) val, _ := strconv.Atoi(response[0]["_total"].(string))
return val return val, nil
} }
} }
return 0 return 0, nil
} }
// Returns the first row in the table that matches certain conditions. // Returns the first row in the table that matches certain conditions.
...@@ -617,8 +637,29 @@ func (t *PostgresqlTable) Find(terms ...interface{}) db.Item { ...@@ -617,8 +637,29 @@ func (t *PostgresqlTable) Find(terms ...interface{}) db.Item {
return item return item
} }
func toInternal(val interface{}) string {
switch val.(type) {
case []byte:
return fmt.Sprintf("%s", string(val.([]byte)))
case time.Time:
return val.(time.Time).Format(dateFormat)
case time.Duration:
t := val.(time.Duration)
return fmt.Sprintf(timeFormat, int(t.Hours()), int(t.Minutes())%60, int(t.Seconds())%60)
case bool:
if val.(bool) == true {
return "1"
} else {
return "0"
}
}
return fmt.Sprintf("%v", val)
}
// Inserts rows into the currently active table. // Inserts rows into the currently active table.
func (t *PostgresqlTable) Append(items ...interface{}) bool { func (t *PostgresqlTable) Append(items ...interface{}) error {
itop := len(items) itop := len(items)
...@@ -631,10 +672,11 @@ func (t *PostgresqlTable) Append(items ...interface{}) bool { ...@@ -631,10 +672,11 @@ func (t *PostgresqlTable) Append(items ...interface{}) bool {
for field, value := range item.(db.Item) { for field, value := range item.(db.Item) {
fields = append(fields, field) fields = append(fields, field)
values = append(values, fmt.Sprintf("%v", value)) values = append(values, toInternal(value))
} }
t.parent.pgExec("Query", _, err := t.parent.pgExec(
"Exec",
"INSERT INTO", "INSERT INTO",
pgTable(t.name), pgTable(t.name),
pgFields(fields), pgFields(fields),
...@@ -642,9 +684,10 @@ func (t *PostgresqlTable) Append(items ...interface{}) bool { ...@@ -642,9 +684,10 @@ func (t *PostgresqlTable) Append(items ...interface{}) bool {
pgValues(values), pgValues(values),
) )
return err
} }
return true return nil
} }
// Returns a MySQL table structure by name. // Returns a MySQL table structure by name.
...@@ -661,11 +704,13 @@ func (pg *PostgresqlDataSource) Collection(name string) db.Collection { ...@@ -661,11 +704,13 @@ func (pg *PostgresqlDataSource) Collection(name string) db.Collection {
// Fetching table datatypes and mapping to internal gotypes. // Fetching table datatypes and mapping to internal gotypes.
rows := t.parent.pgExec( rows, err := t.parent.pgExec(
"Query", "Query",
"SELECT column_name, data_type FROM information_schema.columns WHERE table_name = ?", db.SqlArgs{t.name}, "SELECT column_name, data_type FROM information_schema.columns WHERE table_name = ?", db.SqlArgs{t.name},
) )
if err == nil {
columns := t.pgFetchAll(rows) columns := t.pgFetchAll(rows)
pattern, _ := regexp.Compile("^([a-z]+)\\(?([0-9,]+)?\\)?\\s?([a-z]*)?") pattern, _ := regexp.Compile("^([a-z]+)\\(?([0-9,]+)?\\)?\\s?([a-z]*)?")
...@@ -708,6 +753,9 @@ func (pg *PostgresqlDataSource) Collection(name string) db.Collection { ...@@ -708,6 +753,9 @@ func (pg *PostgresqlDataSource) Collection(name string) db.Collection {
} }
pg.collections[name] = t pg.collections[name] = t
} else {
panic(err)
}
return t return t
} }
...@@ -3,16 +3,60 @@ package postgresql ...@@ -3,16 +3,60 @@ package postgresql
import ( import (
"fmt" "fmt"
"github.com/gosexy/db" "github.com/gosexy/db"
"github.com/gosexy/sugar"
"github.com/kr/pretty" "github.com/kr/pretty"
"math/rand" "math/rand"
"testing" "testing"
"time"
) )
const pgHost = "10.0.0.11" const pgHost = "192.168.1.110"
const pgDatabase = "gotest" const pgDatabase = "gotest"
const pgUser = "gouser" const pgUser = "gouser"
const pgPassword = "gopass" const pgPassword = "gopass"
func getTestData() db.Item {
_time, _ := time.ParseDuration("17h20m")
data := db.Item{
"_uint": uint(1),
"_uintptr": uintptr(1),
"_uint8": uint8(1),
"_uint16": uint16(1),
"_uint32": uint32(1),
"_uint64": uint64(1),
"_int": int(-1),
"_int8": int8(-1),
"_int16": int16(-1),
"_int32": int32(-1),
"_int64": int64(-1),
"_float32": float32(1.0),
"_float64": float64(1.0),
//"_complex64": complex64(1),
//"_complex128": complex128(1),
"_byte": byte(1),
"_rune": rune(1),
"_bool": bool(true),
"_string": string("abc"),
"_bytea": []byte{'a', 'b', 'c'},
//"_list": sugar.List{1, 2, 3},
//"_map": sugar.Tuple{"a": 1, "b": 2, "c": 3},
"_date": time.Date(2012, 7, 28, 1, 2, 3, 0, time.UTC),
"_time": _time,
}
return data
}
func TestPgTruncate(t *testing.T) { func TestPgTruncate(t *testing.T) {
sess := Session(db.DataSource{Host: pgHost, Database: pgDatabase, User: pgUser, Password: pgPassword}) sess := Session(db.DataSource{Host: pgHost, Database: pgDatabase, User: pgUser, Password: pgPassword})
...@@ -29,7 +73,8 @@ func TestPgTruncate(t *testing.T) { ...@@ -29,7 +73,8 @@ func TestPgTruncate(t *testing.T) {
for _, name := range collections { for _, name := range collections {
col := sess.Collection(name) col := sess.Collection(name)
col.Truncate() col.Truncate()
if col.Count() != 0 { total, _ := col.Count()
if total != 0 {
t.Errorf("Could not truncate '%s'.", name) t.Errorf("Could not truncate '%s'.", name)
} }
} }
...@@ -57,7 +102,9 @@ func TestPgAppend(t *testing.T) { ...@@ -57,7 +102,9 @@ func TestPgAppend(t *testing.T) {
col.Append(db.Item{"name": names[i]}) col.Append(db.Item{"name": names[i]})
} }
if col.Count() != len(names) { total, _ := col.Count()
if total != len(names) {
t.Error("Could not append all items.") t.Error("Could not append all items.")
} }
...@@ -225,3 +272,126 @@ func TestPgRelation(t *testing.T) { ...@@ -225,3 +272,126 @@ func TestPgRelation(t *testing.T) {
fmt.Printf("%# v\n", pretty.Formatter(result)) fmt.Printf("%# v\n", pretty.Formatter(result))
} }
func TestDataTypes(t *testing.T) {
sess := Session(db.DataSource{Host: pgHost, Database: pgDatabase, User: pgUser, Password: pgPassword})
err := sess.Open()
if err == nil {
defer sess.Close()
}
col := sess.Collection("data_types")
col.Truncate()
data := getTestData()
err = col.Append(data)
if err != nil {
panic(err)
}
// Getting and reinserting.
item := col.Find()
err = col.Append(item)
if err == nil {
t.Errorf("Expecting duplicated-key error.")
}
delete(item, "id")
err = col.Append(item)
if err != nil {
t.Errorf("Could not append second element.")
}
// Testing rows
items := col.FindAll()
for i := 0; i < len(items); i++ {
item := items[i]
for key, _ := range item {
switch key {
// Signed integers.
case
"_int",
"_int8",
"_int16",
"_int32",
"_int64":
if item.GetInt(key) != int64(data["_int"].(int)) {
t.Errorf("Wrong datatype %v.", key)
}
// Unsigned integers.
case
"_uint",
"_uintptr",
"_uint8",
"_uint16",
"_uint32",
"_uint64",
"_byte",
"_rune":
if item.GetInt(key) != int64(data["_uint"].(uint)) {
t.Errorf("Wrong datatype %v.", key)
}
// Floating point.
case "_float32":
case "_float64":
if item.GetFloat(key) != data["_float64"].(float64) {
t.Errorf("Wrong datatype %v.", key)
}
// Boolean
case "_bool":
if item.GetBool(key) != data["_bool"].(bool) {
t.Errorf("Wrong datatype %v.", key)
}
// String
case "_string":
if item.GetString(key) != data["_string"].(string) {
t.Errorf("Wrong datatype %v.", key)
}
// Map
case "_map":
if item.GetTuple(key)["a"] != data["_map"].(sugar.Tuple)["a"] {
t.Errorf("Wrong datatype %v.", key)
}
// Array
case "_list":
if item.GetList(key)[0] != data["_list"].(sugar.List)[0] {
t.Errorf("Wrong datatype %v.", key)
}
// Time
case "_time":
if item.GetDuration(key).String() != data["_time"].(time.Duration).String() {
t.Errorf("Wrong datatype %v.", key)
}
// Date
case "_date":
if item.GetDate(key).Equal(data["_date"].(time.Time)) == false {
t.Errorf("Wrong datatype %v.", key)
}
}
}
}
}
...@@ -31,9 +31,9 @@ import ( ...@@ -31,9 +31,9 @@ import (
_ "github.com/xiam/gosqlite3" _ "github.com/xiam/gosqlite3"
"reflect" "reflect"
"regexp" "regexp"
"time"
"strconv" "strconv"
"strings" "strings"
"time"
) )
const dateFormat = "2006-01-02 15:04:05" const dateFormat = "2006-01-02 15:04:05"
......
...@@ -3,8 +3,8 @@ package sqlite ...@@ -3,8 +3,8 @@ package sqlite
import ( import (
"fmt" "fmt"
"github.com/gosexy/db" "github.com/gosexy/db"
"github.com/kr/pretty"
"github.com/gosexy/sugar" "github.com/gosexy/sugar"
"github.com/kr/pretty"
"math/rand" "math/rand"
"testing" "testing"
"time" "time"
...@@ -54,7 +54,6 @@ func getTestData() db.Item { ...@@ -54,7 +54,6 @@ func getTestData() db.Item {
return data return data
} }
func TestSqTruncate(t *testing.T) { func TestSqTruncate(t *testing.T) {
sess := SqliteSession(db.DataSource{Database: sqDatabase}) sess := SqliteSession(db.DataSource{Database: sqDatabase})
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment