diff --git a/cmd/cgat/main.go b/cmd/cgat/main.go index 6c0b50450e933e80d9c6bb5f17200125f8a5b462..db2951e63cf7cb93b0bde8be5957c94d2fa106a7 100644 --- a/cmd/cgat/main.go +++ b/cmd/cgat/main.go @@ -1,11 +1,12 @@ package main import ( - "log" "net/http" _ "net/http/pprof" "os" + "tuxpa.in/a/zlog/log" + "pggat2/lib/gat/configs/pgbouncer" "pggat2/lib/gat/configs/zalando" ) @@ -15,10 +16,10 @@ func main() { panic(http.ListenAndServe(":8080", nil)) }() - log.Println("Starting pggat...") + log.Printf("Starting pggat...") if len(os.Args) == 2 { - log.Println("running in pgbouncer compatibility mode") + log.Printf("running in pgbouncer compatibility mode") conf, err := pgbouncer.Load(os.Args[1]) if err != nil { panic(err) @@ -31,7 +32,7 @@ func main() { return } - log.Println("running in zalando compatibility mode") + log.Printf("running in zalando compatibility mode") conf, err := zalando.Load() if err != nil { diff --git a/go.mod b/go.mod index 44b7c46e060b0a7932034e5a1e44d394395ef654..d42c941c36abfc8bc8008344453b1df2a2e32695 100644 --- a/go.mod +++ b/go.mod @@ -3,9 +3,9 @@ module pggat2 go 1.20 require ( + gfx.cafe/util/go/gun v0.0.0-20230721185457-c559e86c829c github.com/google/uuid v1.3.0 github.com/xdg-go/scram v1.1.2 - golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 k8s.io/api v0.27.4 k8s.io/apimachinery v0.27.4 k8s.io/client-go v0.27.4 @@ -13,8 +13,6 @@ require ( ) require ( - gfx.cafe/util/go v0.0.0-20230721185457-c559e86c829c // indirect - gfx.cafe/util/go/gun v0.0.0-20230721185457-c559e86c829c // indirect github.com/cristalhq/aconfig v0.18.3 // indirect github.com/cristalhq/aconfig/aconfigdotenv v0.17.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect diff --git a/go.sum b/go.sum index 671345352b55d4c261df593142a1437374e27a4c..879607a8943d40ec0ac221f1ade31b2848942666 100644 --- a/go.sum +++ b/go.sum @@ -31,8 +31,6 @@ cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohl cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= -gfx.cafe/util/go v0.0.0-20230721185457-c559e86c829c h1:oEwP7BRlmwX0blcWuKVTy2L/aY6uJYmHcqSo05PVOIU= -gfx.cafe/util/go v0.0.0-20230721185457-c559e86c829c/go.mod h1:G98xT1KTC97UZGgq/q4gRskGfAC8syLf69tQQ3T40Rs= gfx.cafe/util/go/gun v0.0.0-20230721185457-c559e86c829c h1:4XxKaHfYPam36FibTiy1Te7ycfW4+ys08WYyDih5VmU= gfx.cafe/util/go/gun v0.0.0-20230721185457-c559e86c829c/go.mod h1:zxq7FdmfdrI4oGeze0MPJt9WqdkFj3BDDSAWRuB63JQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= @@ -295,7 +293,6 @@ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= diff --git a/lib/gat/configs/pgbouncer/config.go b/lib/gat/configs/pgbouncer/config.go index b04b0315731b0d517dd3f5a40b8dbfdd676598a2..27cbdce65297e6dbe5c205ee5758a237306b70e7 100644 --- a/lib/gat/configs/pgbouncer/config.go +++ b/lib/gat/configs/pgbouncer/config.go @@ -2,13 +2,14 @@ package pgbouncer import ( "errors" - "log" "net" "os" "strconv" "strings" "time" + "tuxpa.in/a/zlog/log" + "pggat2/lib/auth/credentials" "pggat2/lib/gat" "pggat2/lib/gat/pools/session" @@ -398,7 +399,7 @@ func (T *Config) ListenAndServe() error { return err } - log.Println("listening on", listen) + log.Printf("listening on %s", listen) return pooler.ListenAndServe(listener) }) diff --git a/lib/gat/configs/zalando/config.go b/lib/gat/configs/zalando/config.go index 859340fffca8618f43b3b5536a6fee3342114193..52a0d858696dce38826a006311405b14420b25d4 100644 --- a/lib/gat/configs/zalando/config.go +++ b/lib/gat/configs/zalando/config.go @@ -3,10 +3,11 @@ package zalando import ( "errors" "fmt" - "log" "net" "strconv" + "tuxpa.in/a/zlog/log" + "gfx.cafe/util/go/gun" "pggat2/lib/auth/credentials" @@ -80,7 +81,7 @@ func (T *Config) ListenAndServe() error { return err } - log.Println("listening on", listen) + log.Printf("listening on %s", listen) return pooler.ListenAndServe(listener) }) diff --git a/lib/gat/pool.go b/lib/gat/pool.go index df5b27b158bba673466fc6671d00fe559a5724ab..48da8ad453929ee3a70ec84a5aa83117460e433d 100644 --- a/lib/gat/pool.go +++ b/lib/gat/pool.go @@ -1,10 +1,11 @@ package gat import ( - "log" "sync" "time" + "tuxpa.in/a/zlog/log" + "github.com/google/uuid" "pggat2/lib/bouncer" diff --git a/lib/psql/errors.go b/lib/psql/errors.go new file mode 100644 index 0000000000000000000000000000000000000000..53f75c6f8a9d3d42a6bf8e73bd40a862f5830c29 --- /dev/null +++ b/lib/psql/errors.go @@ -0,0 +1,10 @@ +package psql + +import "errors" + +var ( + ErrResultTooBig = errors.New("got too many rows for result") + ErrExtraFields = errors.New("received unexpected fields") + ErrResultMustBeNonNil = errors.New("result must be non nil") + ErrUnexpectedType = errors.New("unexpected result type") +) diff --git a/lib/psql/query.go b/lib/psql/query.go index 1abc85957c46680987759f7724e7c3bef73d2c3a..88cdac2d349c168fa390984b94b640914d1a7292 100644 --- a/lib/psql/query.go +++ b/lib/psql/query.go @@ -4,14 +4,18 @@ import ( "crypto/tls" "errors" "io" - "log" + "reflect" "pggat2/lib/bouncer/backends/v0" "pggat2/lib/zap" packets "pggat2/lib/zap/packets/v3.0" ) -type resultReader struct{} +type resultReader struct { + result reflect.Value + rd packets.RowDescription + row int +} func (T *resultReader) EnableSSLClient(_ *tls.Config) error { return errors.New("ssl not supported") @@ -33,20 +37,170 @@ func (T *resultReader) WriteByte(_ byte) error { return nil } +func (T *resultReader) set(i int, row []byte) error { + if i >= len(T.rd.Fields) { + return ErrExtraFields + } + desc := T.rd.Fields[i] + + result := T.result + + // unptr + for result.Kind() == reflect.Pointer { + if result.IsNil() { + if !result.CanSet() { + return ErrResultMustBeNonNil + } + result.Set(reflect.New(result.Type().Elem())) + } + result = result.Elem() + } + + // get row +outer: + for { + kind := result.Kind() + switch kind { + case reflect.Array: + if T.row >= result.Len() { + return ErrResultTooBig + } + result = result.Index(T.row) + break outer + case reflect.Slice: + for T.row >= result.Len() { + if !result.CanSet() { + return ErrResultTooBig + } + result.Set(reflect.Append(result, reflect.Zero(result.Type().Elem()))) + } + result = result.Index(T.row) + break outer + case reflect.Struct, reflect.Map: + if T.row != 0 { + return ErrResultTooBig + } + break outer + default: + return ErrUnexpectedType + } + } + + // unptr + for result.Kind() == reflect.Pointer { + if result.IsNil() { + if !result.CanSet() { + return ErrResultMustBeNonNil + } + result.Set(reflect.New(result.Type().Elem())) + } + result = result.Elem() + } + + // get field + kind := result.Kind() + typ := result.Type() +outer2: + switch kind { + case reflect.Struct: + for j := 0; j < typ.NumField(); j++ { + field := typ.Field(j) + if !field.IsExported() { + continue + } + + sqlName, hasSQLName := field.Tag.Lookup("sql") + if !hasSQLName { + sqlName = field.Name + } + + if sqlName == desc.Name { + result = result.Field(j) + break outer2 + } + } + + // ignore field + return nil + case reflect.Map: + key := typ.Key() + if key.Kind() != reflect.String { + return ErrUnexpectedType + } + + if result.IsNil() { + if !result.CanSet() { + return ErrResultMustBeNonNil + } + + result.Set(reflect.MakeMap(typ)) + } + + k := reflect.New(key).Elem() + k.SetString(desc.Name) + value := typ.Elem() + v := reflect.New(value).Elem() + m := result + result = v + defer func() { + m.SetMapIndex(k, v) + }() + default: + return ErrUnexpectedType + } + + if !result.CanSet() { + return ErrUnexpectedType + } + + if row == nil { + if result.Kind() == reflect.Pointer { + result.Set(reflect.Zero(result.Type())) + return nil + } else { + return ErrUnexpectedType + } + } + + // unptr + for result.Kind() == reflect.Pointer { + if result.IsNil() { + if !result.CanSet() { + return ErrResultMustBeNonNil + } + result.Set(reflect.New(result.Type().Elem())) + } + result = result.Elem() + } + + kind = result.Kind() + typ = result.Type() + switch kind { + case reflect.String: + result.SetString(string(row)) + return nil + default: + return ErrUnexpectedType + } +} + func (T *resultReader) WritePacket(packet zap.Packet) error { switch packet.Type() { case packets.TypeRowDescription: - var rd packets.RowDescription - if !rd.ReadFromPacket(packet) { + if !T.rd.ReadFromPacket(packet) { return errors.New("invalid format") } - log.Printf("row description: %#v", rd) case packets.TypeDataRow: var dr packets.DataRow if !dr.ReadFromPacket(packet) { return errors.New("invalid format") } - log.Printf("data row: %#v", dr) + for i, row := range dr.Columns { + if err := T.set(i, row); err != nil { + return err + } + } + T.row += 1 } return nil } @@ -57,8 +211,10 @@ func (T *resultReader) Close() error { var _ zap.ReadWriter = (*resultReader)(nil) -func Query(server zap.ReadWriter, query string) error { - var res resultReader +func Query(server zap.ReadWriter, query string, result any) error { + res := resultReader{ + result: reflect.ValueOf(result), + } ctx := backends.Context{ Peer: &res, } diff --git a/lib/psql/query_test.go b/lib/psql/query_test.go index 4a573c258404f9a7cfe278e850d72069391586d8..271b203ef340fe8fecda2fd3d070ac5e45e324b6 100644 --- a/lib/psql/query_test.go +++ b/lib/psql/query_test.go @@ -4,11 +4,18 @@ import ( "net" "testing" + "tuxpa.in/a/zlog/log" + "pggat2/lib/auth/credentials" "pggat2/lib/bouncer/backends/v0" "pggat2/lib/zap" ) +type Result struct { + Username string `sql:"usename"` + Password *string `sql:"passwd"` +} + func TestQuery(t *testing.T) { // open server s, err := net.Dial("tcp", "localhost:5432") @@ -29,9 +36,13 @@ func TestQuery(t *testing.T) { return } - err = Query(server, "SELECT usename, passwd FROM pg_shadow WHERE usename='postgres'") + var res Result + + err = Query(server, "SELECT usename, passwd FROM pg_shadow WHERE usename='postgres'", &res) if err != nil { t.Error(err) return } + + log.Printf("%#v", res) }