diff --git a/mysql/collection.go b/mysql/collection.go index 6b53facbfe082b6f296ffdd802719cef1801101d..7858e04f1640c39f030ac7421f6f95b337a90e0b 100644 --- a/mysql/collection.go +++ b/mysql/collection.go @@ -67,7 +67,7 @@ func (t *table) Append(item interface{}) (interface{}, error) { return nil, err } - sqlgenCols, sqlgenVals, sqlgenArgs, err := t.ColumnsValuesAndArguments(columnNames, columnValues) + sqlgenCols, sqlgenVals, sqlgenArgs, err := sqlutil.ToColumnsValuesAndArguments(columnNames, columnValues) if err != nil { return nil, err diff --git a/postgresql/collection.go b/postgresql/collection.go index 19ecda73b6424b7186091b5454581210e770cafe..1f5d00e487d8c7efbb41202df9c67cf938dfacb8 100644 --- a/postgresql/collection.go +++ b/postgresql/collection.go @@ -64,38 +64,16 @@ func (t *table) Truncate() error { // Append inserts an item (map or struct) into the collection. func (t *table) Append(item interface{}) (interface{}, error) { - cols, vals, err := t.FieldValues(item) + columnNames, columnValues, err := t.FieldValues(item) if err != nil { return nil, err } - columns := new(sqlgen.Columns) + sqlgenCols, sqlgenVals, sqlgenArgs, err := sqlutil.ToColumnsValuesAndArguments(columnNames, columnValues) - columns.Columns = make([]sqlgen.Fragment, 0, len(cols)) - for i := range cols { - columns.Columns = append(columns.Columns, sqlgen.ColumnWithName(cols[i])) - } - - values := new(sqlgen.Values) - var arguments []interface{} - - arguments = make([]interface{}, 0, len(vals)) - values.Values = make([]sqlgen.Fragment, 0, len(vals)) - - for i := range vals { - switch v := vals[i].(type) { - case *sqlgen.Value: - // Adding value. - values.Values = append(values.Values, v) - case sqlgen.Value: - // Adding value. - values.Values = append(values.Values, &v) - default: - // Adding both value and placeholder. - values.Values = append(values.Values, sqlPlaceholder) - arguments = append(arguments, v) - } + if err != nil { + return nil, err } var pKey []string @@ -110,15 +88,15 @@ func (t *table) Append(item interface{}) (interface{}, error) { stmt := sqlgen.Statement{ Type: sqlgen.Insert, Table: sqlgen.TableWithName(t.MainTableName()), - Columns: columns, - Values: values, + Columns: sqlgenCols, + Values: sqlgenVals, } // No primary keys defined. if len(pKey) == 0 { var res sql.Result - if res, err = t.database.Exec(stmt, arguments...); err != nil { + if res, err = t.database.Exec(stmt, sqlgenArgs...); err != nil { return nil, err } @@ -133,7 +111,7 @@ func (t *table) Append(item interface{}) (interface{}, error) { // A primary key was found. stmt.Extra = sqlgen.Extra(fmt.Sprintf(`RETURNING "%s"`, strings.Join(pKey, `", "`))) - if rows, err = t.database.Query(stmt, arguments...); err != nil { + if rows, err = t.database.Query(stmt, sqlgenArgs...); err != nil { return nil, err } diff --git a/util/sqlutil/convert.go b/util/sqlutil/convert.go index 1944626836fd1c62805332a020e42b9b831eaf80..d88bfaca92dafd66c5fbca4198e087dbc9a27b98 100644 --- a/util/sqlutil/convert.go +++ b/util/sqlutil/convert.go @@ -157,3 +157,38 @@ func ToColumnValues(cond db.Cond) (ToColumnValues sqlgen.ColumnValues, args []in return ToColumnValues, args } + +// ToColumnsValuesAndArguments maps the given columnNames and columnValues into +// sqlgen's Columns and Values, it also extracts and returns query arguments. +func ToColumnsValuesAndArguments(columnNames []string, columnValues []interface{}) (*sqlgen.Columns, *sqlgen.Values, []interface{}, error) { + var arguments []interface{} + + columns := new(sqlgen.Columns) + + columns.Columns = make([]sqlgen.Fragment, 0, len(columnNames)) + for i := range columnNames { + columns.Columns = append(columns.Columns, sqlgen.ColumnWithName(columnNames[i])) + } + + values := new(sqlgen.Values) + + arguments = make([]interface{}, 0, len(columnValues)) + values.Values = make([]sqlgen.Fragment, 0, len(columnValues)) + + for i := range columnValues { + switch v := columnValues[i].(type) { + case *sqlgen.Value: + // Adding value. + values.Values = append(values.Values, v) + case sqlgen.Value: + // Adding value. + values.Values = append(values.Values, &v) + default: + // Adding both value and placeholder. + values.Values = append(values.Values, sqlPlaceholder) + arguments = append(arguments, v) + } + } + + return columns, values, arguments, nil +} diff --git a/util/sqlutil/main.go b/util/sqlutil/main.go index 11e5e2669d0a6913342506a6389d824b45c8a8e1..3c1546a9d8cb7c6956cf9435233337476f3dcc44 100644 --- a/util/sqlutil/main.go +++ b/util/sqlutil/main.go @@ -33,7 +33,6 @@ import ( "upper.io/db" "upper.io/db/util" - "upper.io/db/util/sqlgen" ) var ( @@ -228,36 +227,3 @@ func (t *T) NthTableName(i int) string { } return "" } - -func (t *T) ColumnsValuesAndArguments(columnNames []string, columnValues []interface{}) (*sqlgen.Columns, *sqlgen.Values, []interface{}, error) { - var arguments []interface{} - - columns := new(sqlgen.Columns) - - columns.Columns = make([]sqlgen.Fragment, 0, len(columnNames)) - for i := range columnNames { - columns.Columns = append(columns.Columns, sqlgen.ColumnWithName(columnNames[i])) - } - - values := new(sqlgen.Values) - - arguments = make([]interface{}, 0, len(columnValues)) - values.Values = make([]sqlgen.Fragment, 0, len(columnValues)) - - for i := range columnValues { - switch v := columnValues[i].(type) { - case *sqlgen.Value: - // Adding value. - values.Values = append(values.Values, v) - case sqlgen.Value: - // Adding value. - values.Values = append(values.Values, &v) - default: - // Adding both value and placeholder. - values.Values = append(values.Values, sqlPlaceholder) - arguments = append(arguments, v) - } - } - - return columns, values, arguments, nil -}