diff --git a/cmd/cgat/main.go b/cmd/cgat/main.go index 546d5ca13905edb36a1e2aef6c2166d5edbe8f38..e5b383b97317d85522c4e412616421236166cef3 100644 --- a/cmd/cgat/main.go +++ b/cmd/cgat/main.go @@ -8,6 +8,7 @@ import ( "tuxpa.in/a/zlog/log" + "pggat/lib/gat/modes/cloud_sql_discovery" "pggat/lib/gat/modes/digitalocean_discovery" "pggat/lib/gat/modes/pgbouncer" "pggat/lib/gat/modes/zalando" @@ -50,6 +51,18 @@ func main() { return } + if os.Getenv("PGGAT_GC_PROJECT") != "" { + conf, err := cloud_sql_discovery.Load() + if err != nil { + panic(err) + } + err = conf.ListenAndServe() + if err != nil { + panic(err) + } + return + } + if os.Getenv("PGGAT_DO_API_KEY") != "" { log.Printf("running in digitalocean discovery mode") diff --git a/go.mod b/go.mod index 17e750bd2d8c8c9cb6f5bb43bdf91170267f317f..a0de8e296d445229c91580798263140b5ba5c563 100644 --- a/go.mod +++ b/go.mod @@ -3,17 +3,19 @@ module pggat go 1.20 require ( + gfx.cafe/ghalliday1/scram v0.0.2 gfx.cafe/util/go/gun v0.0.0-20230721185457-c559e86c829c github.com/digitalocean/godo v1.102.1 github.com/google/uuid v1.3.0 - github.com/xdg-go/scram v1.1.2 - github.com/zalando/postgres-operator v1.10.1 + github.com/zalando/postgres-operator v1.8.2 + google.golang.org/api v0.30.0 k8s.io/apimachinery v0.27.4 k8s.io/client-go v0.27.4 tuxpa.in/a/zlog v1.61.0 ) require ( + cloud.google.com/go v0.65.0 // indirect github.com/cristalhq/aconfig v0.18.3 // indirect github.com/cristalhq/aconfig/aconfigdotenv v0.17.1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect @@ -23,11 +25,13 @@ require ( github.com/go-openapi/jsonreference v0.20.1 // indirect github.com/go-openapi/swag v0.22.3 // indirect github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/google/gnostic v0.5.7-v3refs // indirect github.com/google/go-cmp v0.5.9 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/google/gofuzz v1.1.0 // indirect + github.com/googleapis/gax-go/v2 v2.0.5 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/go-retryablehttp v0.7.4 // indirect github.com/imdario/mergo v0.3.6 // indirect @@ -45,22 +49,24 @@ require ( github.com/rs/zerolog v1.28.0 // indirect github.com/sirupsen/logrus v1.9.0 // indirect github.com/spf13/pflag v1.0.5 // indirect - github.com/xdg-go/pbkdf2 v1.0.0 // indirect github.com/xdg-go/stringprep v1.0.4 // indirect + go.opencensus.io v0.22.4 // indirect golang.org/x/crypto v0.8.0 // indirect golang.org/x/net v0.9.0 // indirect golang.org/x/oauth2 v0.0.0-20220411215720-9780585627b5 // indirect golang.org/x/sys v0.7.0 // indirect golang.org/x/term v0.7.0 // indirect - golang.org/x/text v0.9.0 // indirect + golang.org/x/text v0.13.0 // indirect golang.org/x/time v0.0.0-20220922220347-f3bd1da661af // indirect google.golang.org/appengine v1.6.7 // indirect + google.golang.org/genproto v0.0.0-20220502173005-c8bf987b8c21 // indirect + google.golang.org/grpc v1.47.0 // indirect google.golang.org/protobuf v1.28.1 // indirect gopkg.in/inf.v0 v0.9.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect k8s.io/api v0.27.4 // indirect - k8s.io/apiextensions-apiserver v0.25.9 // indirect + k8s.io/apiextensions-apiserver v0.25.0-rc.1 // indirect k8s.io/klog/v2 v2.90.1 // indirect k8s.io/kube-openapi v0.0.0-20230501164219-8b0f38b5fd1f // indirect k8s.io/utils v0.0.0-20230209194617-a36077c30491 // indirect diff --git a/go.sum b/go.sum index 2cab442610268da01c534c02fc398442c067962a..f571d8eb76a8e8bf72cda03eb9b80b95fb2901d9 100644 --- a/go.sum +++ b/go.sum @@ -12,6 +12,7 @@ cloud.google.com/go v0.54.0/go.mod h1:1rq2OEkV3YMf6n/9ZvGWI3GWw0VoqH/1x2nd8Is/bP cloud.google.com/go v0.56.0/go.mod h1:jr7tqZxxKOVYizybht9+26Z/gUq7tiRzu+ACVAMbKVk= cloud.google.com/go v0.57.0/go.mod h1:oXiQ6Rzq3RAkkY7N6t3TcE6jE+CIBBbA36lwQ1JyzZs= cloud.google.com/go v0.62.0/go.mod h1:jmCYTdRCQuc1PHIIJ/maLInMho30T/Y0M4hTdTShOYc= +cloud.google.com/go v0.65.0 h1:Dg9iHVQfrhq82rUNu9ZxUDrJLaxFUe/HlCVaLyRruq8= cloud.google.com/go v0.65.0/go.mod h1:O5N8zS7uWy9vkA9vayVHs65eM1ubvY4h553ofrNHObY= cloud.google.com/go/bigquery v1.0.1/go.mod h1:i/xbL2UlR5RvWAURpBYZTtm/cXjCha9lbfbpx4poX+o= cloud.google.com/go/bigquery v1.3.0/go.mod h1:PjpwJnslEMmckchkHFfq+HTD2DmtT67aNFKH1/VBDHE= @@ -31,16 +32,33 @@ cloud.google.com/go/storage v1.6.0/go.mod h1:N7U0C8pVQ/+NIKOBQyamJIeKQKkZ+mxpohl cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RXyy7KQOVs= cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0= dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU= +gfx.cafe/ghalliday1/scram v0.0.1 h1:CUhXNgquodVk+8hIUYGzqQg585T+yQ4cRs2RU6wG2tA= +gfx.cafe/ghalliday1/scram v0.0.1/go.mod h1:qt6+WJoFcX2id67G5w+/dktBJ56Se0sZAa7WjqfNNu0= +gfx.cafe/ghalliday1/scram v0.0.2 h1:uuScaL7DUEP/lKWJnA5kKHq5RJev26q8DMbP3gKviAg= +gfx.cafe/ghalliday1/scram v0.0.2/go.mod h1:qt6+WJoFcX2id67G5w+/dktBJ56Se0sZAa7WjqfNNu0= gfx.cafe/util/go/gun v0.0.0-20230721185457-c559e86c829c h1:4XxKaHfYPam36FibTiy1Te7ycfW4+ys08WYyDih5VmU= gfx.cafe/util/go/gun v0.0.0-20230721185457-c559e86c829c/go.mod h1:zxq7FdmfdrI4oGeze0MPJt9WqdkFj3BDDSAWRuB63JQ= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= +github.com/NYTimes/gziphandler v1.1.1/go.mod h1:n/CVRwUEOgIxrgPvAQhUUr9oeUtvrhMomdKFjzJNB0c= +github.com/PuerkitoBio/purell v1.1.1/go.mod h1:c11w/QuzBsJSee3cPx9rAFu61PvFxuPbtSwDGJws/X0= +github.com/PuerkitoBio/urlesc v0.0.0-20170810143723-de5bf2ad4578/go.mod h1:uGdkoq3SwY9Y+13GIhn11/XLaGBb4BfwItxLd5jeuXE= +github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= +github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY= +github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= +github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= +github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnhtbo57+VJT5O0JRkEi1Wbu0z5j0R8u5Hbk= +github.com/cncf/udpa/go v0.0.0-20210930031921-04548b0d99d4/go.mod h1:6pvJx4me5XPnfI9Z40ddWsdw2W/uZgQLFXToKeRcDiI= +github.com/cncf/xds/go v0.0.0-20210922020428-25de7278fc84/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= +github.com/cncf/xds/go v0.0.0-20211001041855-01bcc9b48dfe/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= +github.com/cncf/xds/go v0.0.0-20211011173535-cb28da3451f1/go.mod h1:eXthEFrGJvWHgFFCl3hGmgk+/aYT6PnTQLykKQRLhEs= +github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk= github.com/coreos/go-systemd/v22 v22.3.3-0.20220203105225-a9a7ef127534/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= @@ -55,12 +73,20 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/digitalocean/godo v1.102.1 h1:BrNePwIXjQWjOJXVTBqkURMjm70BRR0qXbRKfHNBF24= github.com/digitalocean/godo v1.102.1/go.mod h1:SaUYccN7r+CO1QtsbXGypAsgobDrmSfVMJESEfXgoEg= github.com/docopt/docopt-go v0.0.0-20180111231733-ee0de3bc6815/go.mod h1:WwZ+bS3ebgob9U8Nd0kOddGdZWjyMGR8Wziv+TBNwSE= +github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/emicklei/go-restful v2.9.5+incompatible/go.mod h1:otzb+WCGbkyDHkqmQmT5YD2WR4BBwUdeQoFo8l/7tVs= github.com/emicklei/go-restful/v3 v3.9.0 h1:XwGDlfxEnQZzuopoqxwSEllNcCOM9DhhFyhFIIGKwxE= github.com/emicklei/go-restful/v3 v3.9.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= +github.com/envoyproxy/go-control-plane v0.9.9-0.20201210154907-fd9021fe5dad/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= +github.com/envoyproxy/go-control-plane v0.10.2-0.20220325020618-49ff273808a1/go.mod h1:KJwIaB5Mv44NWtYuAOFCVOjcI94vtpEz2JU/D2v6IjE= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= +github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/form3tech-oss/jwt-go v3.2.3+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= @@ -81,6 +107,8 @@ github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfU github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20200121045136-8c9f03a8e57e/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= +github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/mock v1.3.1/go.mod h1:sBzyDLLjw3U8JLTeZvSv8jJB+tU5PVekmnlKIyFUx0Y= @@ -101,12 +129,14 @@ github.com/golang/protobuf v1.4.0-rc.4.0.20200313231945-b860323f09d0/go.mod h1:W github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/google/btree v1.0.1/go.mod h1:xXMiIv4Fb/0kKde4SpL7qlzvu5cMJDRkFDxJfI9uaxA= github.com/google/gnostic v0.5.7-v3refs h1:FhTMOKj2VhjpouxvWJAV1TL304uMlb9zcDqkl6cEI54= github.com/google/gnostic v0.5.7-v3refs/go.mod h1:73MKFl6jIHelAJNaBGFzt3SPtZULs9dYrGFt8OiIsHQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= @@ -118,6 +148,7 @@ github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= @@ -136,10 +167,18 @@ github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hf github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= github.com/google/pprof v0.0.0-20210720184732-4bb14d4b1be1 h1:K6RDEckDVWvDI9JAJYCmNdQXq6neHJOYx3V6jnqNEec= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= +github.com/googleapis/gax-go/v2 v2.0.5 h1:sjZBwGj9Jlw33ImPtvFviGYvseOtDM7hkSKB7+Tv3SM= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= +github.com/googleapis/gnostic v0.5.5/go.mod h1:7+EbHbldMins07ALC74bsA81Ovc97DwqyJO1AENw9kA= +github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= +github.com/grpc-ecosystem/go-grpc-middleware v1.3.0/go.mod h1:z0ButlSOZa5vEBq9m2m2hlwIgKw+rp3sdCBRoJY+30Y= +github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= +github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw= github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= github.com/hashicorp/go-hclog v0.9.2 h1:CG6TE5H9/JXsFWJCfoIVpKFIkFe6ysEuHirp4DxCsHI= @@ -151,8 +190,10 @@ github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/imdario/mergo v0.3.6 h1:xTNEAn+kxVO7dTZGu0CegyqKZmoWFI0rF8UxjlB2d28= github.com/imdario/mergo v0.3.6/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJh5FfA= +github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= github.com/joho/godotenv v1.4.0 h1:3l4+N6zfMWnkbPEXKng2o2/MR5mSwTrBih4ZEkkz1lg= github.com/joho/godotenv v1.4.0/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= +github.com/jonboulle/clockwork v0.2.2/go.mod h1:Pkfl5aHPm1nk2H9h0bjmnJD/BcgbGXUBGnn1kMkgxc8= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= @@ -175,6 +216,8 @@ github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZb github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= github.com/mattn/go-isatty v0.0.14 h1:yVuAays6BHfxijgZPzw+3Zlu5yQgKGP2/hcQbHb7S9Y= github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= +github.com/moby/spdystream v0.2.0/go.mod h1:f7i0iNDQJ059oMTcWxx8MA/zKFIuD/lY+0GqbN2Wy8c= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -184,12 +227,16 @@ github.com/motomux/pretty v0.0.0-20161209205251-b2aad2c9a95d h1:LznySqW8MqVeFh+p github.com/motomux/pretty v0.0.0-20161209205251-b2aad2c9a95d/go.mod h1:u3hJ0kqCQu/cPpsu3RbCOPZ0d7V3IjPjv1adNRleM9I= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+o7JKHSa8/e818NopupXU1YMK5fe1lsApnBw= github.com/onsi/ginkgo/v2 v2.9.1 h1:zie5Ly042PD3bsCvsSOPvRnFwyo3rKe64TJlD6nu0mk= github.com/onsi/gomega v1.27.4 h1:Z2AnStgsdSayCMDiCU42qIz+HLqEPcgiOCXjAU/w+8E= +github.com/peterbourgon/diskv v2.0.1+incompatible/go.mod h1:uqqh8zWWbv1HBMNONnaR/tNboyR3/BZd58JJSHlUSCU= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/prometheus/client_model v0.2.0/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= +github.com/rogpeppe/fastuuid v1.2.0/go.mod h1:jVj6XXZzXRy/MSR5jhDC/2q6DgLz+nrA6LYCDYWNEvQ= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= @@ -197,6 +244,7 @@ github.com/rs/zerolog v1.28.0 h1:MirSo27VyNi7RJYP3078AA1+Cyzd2GB66qy3aUHvsWY= github.com/rs/zerolog v1.28.0/go.mod h1:NILgTygv/Uej1ra5XxGf82ZFSLk58MFGAUS2o6usyD0= github.com/sirupsen/logrus v1.9.0 h1:trlNQbNUG3OdDrDil03MCb1H2o9nJ1x4/5LYw7byDE0= github.com/sirupsen/logrus v1.9.0/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/soheilhy/cmux v0.1.5/go.mod h1:T7TcVDs9LWfQgPlPsdngu6I6QIoyIFZDDC6sNE1GqG0= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stoewer/go-strcase v1.2.0/go.mod h1:IBiWB2sKIp3wVVQ3Y035++gc+knqhUQag1KpM8ahLw8= @@ -212,24 +260,37 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= -github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= -github.com/xdg-go/scram v1.1.2 h1:FHX5I5B4i4hKRVRBCFRxq1iQRej7WO3hhBuJf+UUySY= -github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4= +github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/xdg-go/stringprep v1.0.4 h1:XLI/Ng3O1Atzq0oBs3TWm+5ZVgkq2aqdlvP9JtoZ6c8= github.com/xdg-go/stringprep v1.0.4/go.mod h1:mPGuuIYwz7CmR2bT9j4GbQqutWS1zV24gijq1dTyGkM= +github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.1.32/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -github.com/zalando/postgres-operator v1.10.1 h1:2QAZam6e3dhK8D64Hc9m4eul29f1yggGMAH3ff20etw= -github.com/zalando/postgres-operator v1.10.1/go.mod h1:UYVdslgiYgsKSuU24Mne2qO67nuWTJwWiT1WQDurROs= +github.com/zalando/postgres-operator v1.8.2 h1:3FW3j2gXua1MSeE+NiSvB8cxM7k7fyoun46G1v++CCA= +github.com/zalando/postgres-operator v1.8.2/go.mod h1:f7AXk8LO/tWFdW4myPJZCwMueGg6fI4RqTuOA0BefZE= +go.etcd.io/bbolt v1.3.6/go.mod h1:qXsaaIqmgQH0T+OPdb99Bf+PKfBBQVAdyD6TY9G8XM4= go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU= go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8= go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= +go.opencensus.io v0.22.4 h1:LYy1Hy3MJdrCdMwwzxA/dRok4ejH+RwNGbuoD9fCjto= go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= +go.opentelemetry.io/contrib v0.20.0/go.mod h1:G/EtFaa6qaN7+LxqfIAT3GiZa7Wv5DTBUzl5H4LY0Kc= +go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.20.0/go.mod h1:oVGt1LRbBOBq1A5BQLlUg9UaU/54aiHw8cgjV3aWZ/E= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.20.0/go.mod h1:2AboqHi0CiIZU0qwhtUfCYD1GeUzvvIXWNkhDt7ZMG4= +go.opentelemetry.io/otel v0.20.0/go.mod h1:Y3ugLH2oa81t5QO+Lty+zXf8zC9L26ax4Nzoxm/dooo= +go.opentelemetry.io/otel/exporters/otlp v0.20.0/go.mod h1:YIieizyaN77rtLJra0buKiNBOm9XQfkPEKBeuhoMwAM= +go.opentelemetry.io/otel/metric v0.20.0/go.mod h1:598I5tYlH1vzBjn+BTuhzTCSb/9debfNp6R3s7Pr1eU= +go.opentelemetry.io/otel/sdk v0.20.0/go.mod h1:g/IcepuwNsoiX5Byy2nNV0ySUF1em498m7hBWC279Yc= +go.opentelemetry.io/otel/sdk/export/metric v0.20.0/go.mod h1:h7RBNMsDJ5pmI1zExLi+bJK+Dr8NQCh0qGhm1KDnNlE= +go.opentelemetry.io/otel/sdk/metric v0.20.0/go.mod h1:knxiS8Xd4E/N+ZqKmUPf3gTTZ4/0TjTXukfxjzSTpHE= +go.opentelemetry.io/otel/trace v0.20.0/go.mod h1:6GjCW8zgDjwGHGa6GkyeB8+/5vjT16gUEi0Nf1iBdgw= +go.opentelemetry.io/proto/otlp v0.7.0/go.mod h1:PqfVotwruBrMGOCsRd/89rSnXhoiJIqeYNgFYFoEGnI= +go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= @@ -297,6 +358,7 @@ golang.org/x/net v0.0.0-20200707034311-ab3426394381/go.mod h1:/O7V0waA8r7cgGh81R golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= golang.org/x/net v0.0.0-20220127200216-cd36cc0744dd/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.9.0 h1:aWJ/m6xSmxWBx+V0XRHTlrYrPG56jKsLdTFmsSsCzOM= @@ -345,6 +407,9 @@ golang.org/x/sys v0.0.0-20200523222454-059865788121/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -364,10 +429,11 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= -golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE= -golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= +golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -436,6 +502,7 @@ google.golang.org/api v0.22.0/go.mod h1:BwFmGc8tA3vsd7r/7kR8DY7iEEGSU04BFxCo5jP/ google.golang.org/api v0.24.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE= google.golang.org/api v0.28.0/go.mod h1:lIXQywCXRcnZPGlsd8NbLnOjtAoL6em04bJ9+z0MncE= google.golang.org/api v0.29.0/go.mod h1:Lcubydp8VUV7KeIHD9z2Bys/sm/vGKnG1UHuDBSrHWM= +google.golang.org/api v0.30.0 h1:yfrXXP61wVuLb0vBcG6qaOoIoqYEzOQS8jum51jkv2w= google.golang.org/api v0.30.0/go.mod h1:QGmEvQ87FHZNiUVJkT14jQNYJ4ZJjdRF23ZXz5138Fc= google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= @@ -468,6 +535,7 @@ google.golang.org/genproto v0.0.0-20200312145019-da6875a35672/go.mod h1:55QSHmfG google.golang.org/genproto v0.0.0-20200331122359-1ee6d9798940/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200430143042-b979b6f78d84/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200511104702-f5ebc3bea380/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= +google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200515170657-fc4c6c6a6587/go.mod h1:YsZOwe1myG/8QRHRsmBRE1LrgQY60beZKjly0O1fX9U= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/genproto v0.0.0-20200618031413-b414f8b61790/go.mod h1:jDfRM7FcilCzHH/e9qn6dsT145K34l5v+OpcnNgKAAA= @@ -475,6 +543,8 @@ google.golang.org/genproto v0.0.0-20200729003335-053ba62fc06f/go.mod h1:FWY/as6D google.golang.org/genproto v0.0.0-20200804131852-c06518451d9c/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20200825200019-8632dd797987/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20201019141844-1ed22bb0c154/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto v0.0.0-20220502173005-c8bf987b8c21 h1:hrbNEivu7Zn1pxvHk6MBrq9iE22woVILTHqexqBxe6I= +google.golang.org/genproto v0.0.0-20220502173005-c8bf987b8c21/go.mod h1:RAyBrSAP7Fh3Nc84ghnVLDPuV51xc9agzmm4Ph6i0Q4= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.20.1/go.mod h1:10oTOabMzJvdu6/UiuZezV6QK5dSlG84ov/aaiqXj38= google.golang.org/grpc v1.21.1/go.mod h1:oYelfM1adQP15Ek0mdvEgi9Df8B9CZIaU1084ijfRaM= @@ -487,6 +557,11 @@ google.golang.org/grpc v1.28.0/go.mod h1:rpkK4SK4GF4Ach/+MFLZUBavHOvF2JJB5uozKKa google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk= google.golang.org/grpc v1.30.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= google.golang.org/grpc v1.31.0/go.mod h1:N36X2cJ7JwdamYAgDz+s+rVMFjt3numwzf/HckM8pak= +google.golang.org/grpc v1.33.1/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTpR3n0= +google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAGRRjU= +google.golang.org/grpc v1.46.0/go.mod h1:vN9eftEi1UMyUsIF80+uQXhHjbXYbm0uXoFCACuMGWk= +google.golang.org/grpc v1.47.0 h1:9n77onPX5F3qfFCqjy9dhn8PbNQsIKeVU04J9G7umt8= +google.golang.org/grpc v1.47.0/go.mod h1:vN9eftEi1UMyUsIF80+uQXhHjbXYbm0uXoFCACuMGWk= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= @@ -499,6 +574,8 @@ google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGj google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= @@ -509,7 +586,9 @@ gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EV gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc= gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw= +gopkg.in/natefinch/lumberjack.v2 v2.0.0/go.mod h1:l0ndWWf7gzL7RNwBG7wST/UCcT4T24xpD6X8LsfU/+k= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.3/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= @@ -526,8 +605,8 @@ honnef.co/go/tools v0.0.1-2020.1.3/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9 honnef.co/go/tools v0.0.1-2020.1.4/go.mod h1:X/FiERA/W4tHapMX5mGpAtMSVEeEUOyHaw9vFzvIQ3k= k8s.io/api v0.27.4 h1:0pCo/AN9hONazBKlNUdhQymmnfLRbSZjd5H5H3f0bSs= k8s.io/api v0.27.4/go.mod h1:O3smaaX15NfxjzILfiln1D8Z3+gEYpjEpiNA/1EVK1Y= -k8s.io/apiextensions-apiserver v0.25.9 h1:Pycd6lm2auABp9wKQHCFSEPG+NPdFSTJXPST6NJFzB8= -k8s.io/apiextensions-apiserver v0.25.9/go.mod h1:ijGxmSG1GLOEaWhTuaEr0M7KUeia3mWCZa6FFQqpt1M= +k8s.io/apiextensions-apiserver v0.25.0-rc.1 h1:VtIahSlfaERJvFNJPfcBGYtkz/X0zHaV2CEUfrlzTS4= +k8s.io/apiextensions-apiserver v0.25.0-rc.1/go.mod h1:Cuej1N2xyvuVAoUx3GidzheWE/oHF2r55VjsqUfBJKk= k8s.io/apimachinery v0.27.4 h1:CdxflD4AF61yewuid0fLl6bM4a3q04jWel0IlP+aYjs= k8s.io/apimachinery v0.27.4/go.mod h1:XNfZ6xklnMCOGGFNqXG7bUrQCoR04dh/E7FprV6pb+E= k8s.io/client-go v0.27.4 h1:vj2YTtSJ6J4KxaC88P4pMPEQECWMY8gqPqsTgUKzvjk= diff --git a/lib/auth/credentials.go b/lib/auth/credentials.go index b8b2299429e286ee2e2929c69d96f6cc8de02888..5bb6b7fa89950b9ba71a5e104956a4e007a3e659 100644 --- a/lib/auth/credentials.go +++ b/lib/auth/credentials.go @@ -1,20 +1,30 @@ package auth type Credentials interface { - GetUsername() string + Credentials() } -type Cleartext interface { +type CleartextClient interface { Credentials EncodeCleartext() string +} + +type CleartextServer interface { + Credentials + VerifyCleartext(value string) error } -type MD5 interface { +type MD5Client interface { Credentials EncodeMD5(salt [4]byte) string +} + +type MD5Server interface { + Credentials + VerifyMD5(salt [4]byte, value string) error } @@ -32,11 +42,16 @@ type SASLVerifier interface { Write(bytes []byte) ([]byte, error) } -type SASL interface { +type SASLClient interface { + Credentials + + EncodeSASL(mechanisms []SASLMechanism) (SASLMechanism, SASLEncoder, error) +} + +type SASLServer interface { Credentials SupportedSASLMechanisms() []SASLMechanism - EncodeSASL(mechanisms []SASLMechanism) (SASLMechanism, SASLEncoder, error) VerifySASL(mechanism SASLMechanism) (SASLVerifier, error) } diff --git a/lib/auth/credentials/cleartext.go b/lib/auth/credentials/cleartext.go index 45821e07f726f6e2c4042040b0843e5feed5e356..57a52feb1f09b6e2f16b7db999cce21f064f9a12 100644 --- a/lib/auth/credentials/cleartext.go +++ b/lib/auth/credentials/cleartext.go @@ -2,10 +2,12 @@ package credentials import ( "crypto/md5" + "crypto/rand" + "crypto/sha256" "encoding/hex" "strings" - "github.com/xdg-go/scram" + "gfx.cafe/ghalliday1/scram" "pggat/lib/auth" "pggat/lib/util/slices" @@ -16,9 +18,7 @@ type Cleartext struct { Password string } -func (T Cleartext) GetUsername() string { - return T.Username -} +func (Cleartext) Credentials() {} func (T Cleartext) EncodeCleartext() string { return T.Password @@ -67,107 +67,55 @@ func (T Cleartext) SupportedSASLMechanisms() []auth.SASLMechanism { } } -type CleartextScramEncoder struct { - conversation *scram.ClientConversation -} - -func MakeCleartextScramEncoder(username, password string, hashGenerator scram.HashGeneratorFcn) (CleartextScramEncoder, error) { - client, err := hashGenerator.NewClient(username, password, "") - if err != nil { - return CleartextScramEncoder{}, err - } - - return CleartextScramEncoder{ - conversation: client.NewConversation(), - }, nil -} - -func (T CleartextScramEncoder) Write(bytes []byte) ([]byte, error) { - msg, err := T.conversation.Step(string(bytes)) - if err != nil { - return nil, err - } - return []byte(msg), nil -} - -var _ auth.SASLEncoder = CleartextScramEncoder{} - func (T Cleartext) EncodeSASL(mechanisms []auth.SASLMechanism) (auth.SASLMechanism, auth.SASLEncoder, error) { for _, mechanism := range mechanisms { switch mechanism { case auth.ScramSHA256: - encoder, err := MakeCleartextScramEncoder(T.Username, T.Password, scram.SHA256) - if err != nil { - return "", nil, err - } - - return auth.ScramSHA256, encoder, nil + return auth.ScramSHA256, &scram.ClientConversation{ + Lookup: scram.ClientPasswordLookup(T.Password, sha256.New), + }, nil } } return "", nil, auth.ErrSASLMechanismNotSupported } -type CleartextScramVerifier struct { - conversation *scram.ServerConversation -} - -func MakeCleartextScramVerifier(username, password string, hashGenerator scram.HashGeneratorFcn) (CleartextScramVerifier, error) { - client, err := hashGenerator.NewClient(username, password, "") - if err != nil { - return CleartextScramVerifier{}, err - } - - kf := scram.KeyFactors{ - Iters: 4096, - } - stored := client.GetStoredCredentials(kf) - - server, err := hashGenerator.NewServer( - func(string) (scram.StoredCredentials, error) { - return stored, nil - }, - ) - if err != nil { - return CleartextScramVerifier{}, err - } - - return CleartextScramVerifier{ - conversation: server.NewConversation(), - }, nil -} - -func (T CleartextScramVerifier) Write(bytes []byte) ([]byte, error) { - msg, err := T.conversation.Step(string(bytes)) - if err != nil { - return nil, err - } - - if T.conversation.Done() { - // check if conversation params are valid - if !T.conversation.Valid() { - return nil, auth.ErrFailed - } - - // done - return []byte(msg), auth.ErrSASLComplete - } - - // there is more - return []byte(msg), nil -} - -var _ auth.SASLVerifier = CleartextScramVerifier{} - func (T Cleartext) VerifySASL(mechanism auth.SASLMechanism) (auth.SASLVerifier, error) { switch mechanism { case auth.ScramSHA256: - return MakeCleartextScramVerifier(T.Username, T.Password, scram.SHA256) + return &scram.ServerConversation{ + Lookup: func(string) (scram.ServerKeys, bool) { + var salt [32]byte + _, err := rand.Read(salt[:]) + if err != nil { + return scram.ServerKeys{}, false + } + hasher := scram.Hasher(sha256.New) + keyInfo := scram.KeyInfo{ + Salt: salt[:], + Iters: 2048, + Hasher: hasher, + } + saltedPassword := hasher.SaltedPassword([]byte(T.Password), keyInfo.Salt, keyInfo.Iters) + serverKey := hasher.ServerKey(saltedPassword) + clientKey := hasher.ClientKey(saltedPassword) + storedKey := hasher.StoredKey(clientKey) + + return scram.ServerKeys{ + ServerKey: serverKey, + StoredKey: storedKey, + KeyInfo: keyInfo, + }, true + }, + }, nil default: return nil, auth.ErrSASLMechanismNotSupported } } var _ auth.Credentials = Cleartext{} -var _ auth.Cleartext = Cleartext{} -var _ auth.MD5 = Cleartext{} -var _ auth.SASL = Cleartext{} +var _ auth.CleartextClient = Cleartext{} +var _ auth.CleartextServer = Cleartext{} +var _ auth.MD5Client = Cleartext{} +var _ auth.MD5Server = Cleartext{} +var _ auth.SASLClient = Cleartext{} +var _ auth.SASLServer = Cleartext{} diff --git a/lib/auth/credentials/credentials_test.go b/lib/auth/credentials/credentials_test.go index 2d045954233f0e0eb3b7e0be0790d32be360132e..f6781ab8601d8c0945fd827e38bc2c38beb5a94b 100644 --- a/lib/auth/credentials/credentials_test.go +++ b/lib/auth/credentials/credentials_test.go @@ -11,8 +11,8 @@ func TestMD5(t *testing.T) { pw := FromString("bob", "jNKuKKlBDO48qbLiVw7IuoaamZ1SmHAUdQ9PKH7qRzsyJVF0BNPSFMbHTQwxe0HJ") md5 := FromString("bob", "md5e20510fd38e1c0fd99db13da5c29bd95") - pwMD5 := pw.(auth.MD5) - md5MD5 := md5.(auth.MD5) + pwMD5 := pw.(auth.MD5Client) + md5MD5 := md5.(auth.MD5Server) var salt [4]byte _, err := rand.Read(salt[:]) diff --git a/lib/auth/credentials/errors.go b/lib/auth/credentials/errors.go new file mode 100644 index 0000000000000000000000000000000000000000..e41c6f2379aa814c2020d6406d25e7fe96d7b9ca --- /dev/null +++ b/lib/auth/credentials/errors.go @@ -0,0 +1,7 @@ +package credentials + +import "errors" + +var ( + ErrInvalidSecretFormat = errors.New("invalid secret format") +) diff --git a/lib/auth/credentials/md5.go b/lib/auth/credentials/md5.go index c96c48487a40b355e03dcba04ef97d4c0b66f7cc..bb891bf54505018e5b8c72ee37326b03e46ba034 100644 --- a/lib/auth/credentials/md5.go +++ b/lib/auth/credentials/md5.go @@ -10,14 +10,27 @@ import ( ) type MD5 struct { - Username string - Hash []byte + Hash []byte } -func (T MD5) GetUsername() string { - return T.Username +func MD5FromString(value string) (MD5, error) { + if !strings.HasPrefix(value, "md5") { + return MD5{}, ErrInvalidSecretFormat + } + + var res MD5 + var err error + hexString := strings.TrimPrefix(value, "md5") + res.Hash, err = hex.DecodeString(hexString) + if err != nil { + return MD5{}, err + } + + return res, nil } +func (MD5) Credentials() {} + func (T MD5) EncodeMD5(salt [4]byte) string { hexEncoded := make([]byte, hex.EncodedLen(len(T.Hash))) hex.Encode(hexEncoded, T.Hash) @@ -44,4 +57,6 @@ func (T MD5) VerifyMD5(salt [4]byte, value string) error { return nil } -var _ auth.MD5 = MD5{} +var _ auth.Credentials = MD5{} +var _ auth.MD5Client = MD5{} +var _ auth.MD5Server = MD5{} diff --git a/lib/auth/credentials/scram.go b/lib/auth/credentials/scram.go new file mode 100644 index 0000000000000000000000000000000000000000..898d3e1634f118129ce3c2a7df66b253afef2935 --- /dev/null +++ b/lib/auth/credentials/scram.go @@ -0,0 +1,144 @@ +package credentials + +import ( + "crypto/sha256" + "encoding/base64" + "errors" + "io" + "strconv" + "strings" + "sync" + + "gfx.cafe/ghalliday1/scram" + + "pggat/lib/auth" +) + +type Scram struct { + Keys scram.ServerKeys + + clientKey []byte + mu sync.RWMutex +} + +func ScramFromString(password string) (*Scram, error) { + alg, iterKeys, ok := strings.Cut(password, "$") + if !ok { + return nil, ErrInvalidSecretFormat + } + var hasher scram.Hasher + switch alg { + case "SCRAM-SHA-256": + hasher = sha256.New + default: + // invalid algorithm + return nil, ErrInvalidSecretFormat + } + + iterSalt, keys, ok := strings.Cut(iterKeys, "$") + if !ok { + return nil, ErrInvalidSecretFormat + } + iter, salt, ok := strings.Cut(iterSalt, ":") + if !ok { + return nil, ErrInvalidSecretFormat + } + storedKey, serverKey, ok := strings.Cut(keys, ":") + + var res Scram + res.Keys.Hasher = hasher + + var err error + res.Keys.Iters, err = strconv.Atoi(iter) + if err != nil { + return nil, err + } + + var saltBytes []byte + saltBytes, err = base64.StdEncoding.DecodeString(salt) + if err != nil { + return nil, err + } + res.Keys.Salt = saltBytes + + res.Keys.StoredKey, err = base64.StdEncoding.DecodeString(storedKey) + if err != nil { + return nil, err + } + + res.Keys.ServerKey, err = base64.StdEncoding.DecodeString(serverKey) + if err != nil { + return nil, err + } + + return &res, nil +} + +func (T *Scram) SupportedSASLMechanisms() []auth.SASLMechanism { + return []auth.SASLMechanism{ + auth.ScramSHA256, + } +} + +func (T *Scram) EncodeSASL(mechanisms []auth.SASLMechanism) (auth.SASLMechanism, auth.SASLEncoder, error) { + T.mu.RLock() + clientKey := T.clientKey + T.mu.RUnlock() + if clientKey == nil { + return "", nil, errors.New("you must log in with SASL first") + } + + for _, mechanism := range mechanisms { + switch mechanism { + case auth.ScramSHA256: + return auth.ScramSHA256, &scram.ClientConversation{ + Lookup: scram.ClientKeysLookup(scram.ClientKeys{ + ClientKey: clientKey, + ServerKey: T.Keys.ServerKey, + KeyInfo: T.Keys.KeyInfo, + }), + }, nil + } + } + return "", nil, auth.ErrSASLMechanismNotSupported +} + +type ScramInterceptorVerifier struct { + Scram *Scram + Conversation *scram.ServerConversation +} + +func (T ScramInterceptorVerifier) Write(bytes []byte) ([]byte, error) { + resp, err := T.Conversation.Write(bytes) + if err == io.EOF { + T.Scram.mu.Lock() + defer T.Scram.mu.Unlock() + + T.Scram.clientKey = T.Conversation.RecoveredClientKey + } + return resp, err +} + +var _ auth.SASLVerifier = ScramInterceptorVerifier{} + +func (T *Scram) VerifySASL(mechanism auth.SASLMechanism) (auth.SASLVerifier, error) { + switch mechanism { + case auth.ScramSHA256: + return ScramInterceptorVerifier{ + Scram: T, + Conversation: &scram.ServerConversation{ + Lookup: func(string) (scram.ServerKeys, bool) { + return T.Keys, true + }, + }, + }, nil + default: + return nil, auth.ErrSASLMechanismNotSupported + } +} + +func (*Scram) Credentials() {} + +var _ auth.Credentials = (*Scram)(nil) +var _ auth.SASLServer = (*Scram)(nil) +var _ auth.SASLClient = (*Scram)(nil) diff --git a/lib/auth/credentials/string.go b/lib/auth/credentials/string.go index 4f04779d0e9164cc9873c7ce66c63005461a3bcb..2f57b16d4ffb8ca26f3c8d95fa078133504d22b8 100644 --- a/lib/auth/credentials/string.go +++ b/lib/auth/credentials/string.go @@ -1,32 +1,20 @@ package credentials import ( - "encoding/hex" - "strings" - "pggat/lib/auth" ) func FromString(user, password string) auth.Credentials { if password == "" { return nil - } else if strings.HasPrefix(password, "md5") { - hexHash := strings.TrimPrefix(password, "md5") - hash, err := hex.DecodeString(hexHash) - if err != nil { - return Cleartext{ - Username: user, - Password: password, - } - } - return MD5{ - Username: user, - Hash: hash, - } + } else if v, err := ScramFromString(password); err == nil { + return v + } else if v, err := MD5FromString(password); err == nil { + return v } else { return Cleartext{ Username: user, - Password: password, // TODO(garet) sasl + Password: password, } } } diff --git a/lib/auth/errors.go b/lib/auth/errors.go index b08882e8b17d55f384cff244df185203b72c27d1..5bdd70e8b28933bfa922590e72b4cb97bd7aed68 100644 --- a/lib/auth/errors.go +++ b/lib/auth/errors.go @@ -6,5 +6,4 @@ var ( ErrMethodNotSupported = errors.New("auth method not supported") ErrFailed = errors.New("auth failed") ErrSASLMechanismNotSupported = errors.New("SASL mechanism not supported") - ErrSASLComplete = errors.New("SASL Complete") ) diff --git a/lib/bouncer/backends/v0/accept.go b/lib/bouncer/backends/v0/accept.go index cc9fb2a7449e7af8702efe948bbb5ae84aa54756..5632d0edca65fd6d3bc1410a71b64027db9b0c69 100644 --- a/lib/bouncer/backends/v0/accept.go +++ b/lib/bouncer/backends/v0/accept.go @@ -1,7 +1,6 @@ package backends import ( - "crypto/tls" "errors" "io" @@ -11,20 +10,19 @@ import ( "pggat/lib/util/strutil" ) -func authenticationSASLChallenge(server fed.Conn, encoder auth.SASLEncoder) (done bool, err error) { - var packet fed.Packet - packet, err = server.ReadPacket(true) +func authenticationSASLChallenge(ctx *AcceptContext, encoder auth.SASLEncoder) (done bool, err error) { + ctx.Packet, err = ctx.Conn.ReadPacket(true, ctx.Packet) if err != nil { return } - if packet.Type() != packets.TypeAuthentication { + if ctx.Packet.Type() != packets.TypeAuthentication { err = ErrUnexpectedPacket return } var method int32 - p := packet.ReadInt32(&method) + p := ctx.Packet.ReadInt32(&method) switch method { case 11: @@ -36,12 +34,16 @@ func authenticationSASLChallenge(server fed.Conn, encoder auth.SASLEncoder) (don } resp := packets.AuthenticationResponse(response) - err = server.WritePacket(resp.IntoPacket()) + ctx.Packet = resp.IntoPacket(ctx.Packet) + err = ctx.Conn.WritePacket(ctx.Packet) return case 12: // finish _, err = encoder.Write(p) - if err != nil { + if err != io.EOF { + if err == nil { + err = errors.New("expected EOF") + } return } @@ -52,7 +54,7 @@ func authenticationSASLChallenge(server fed.Conn, encoder auth.SASLEncoder) (don } } -func authenticationSASL(server fed.Conn, mechanisms []string, creds auth.SASL) error { +func authenticationSASL(ctx *AcceptContext, mechanisms []string, creds auth.SASLClient) error { mechanism, encoder, err := creds.EncodeSASL(mechanisms) if err != nil { return err @@ -66,7 +68,8 @@ func authenticationSASL(server fed.Conn, mechanisms []string, creds auth.SASL) e Mechanism: mechanism, InitialResponse: initialResponse, } - err = server.WritePacket(saslInitialResponse.IntoPacket()) + ctx.Packet = saslInitialResponse.IntoPacket(ctx.Packet) + err = ctx.Conn.WritePacket(ctx.Packet) if err != nil { return err } @@ -74,7 +77,7 @@ func authenticationSASL(server fed.Conn, mechanisms []string, creds auth.SASL) e // challenge loop for { var done bool - done, err = authenticationSASLChallenge(server, encoder) + done, err = authenticationSASLChallenge(ctx, encoder) if err != nil { return err } @@ -86,31 +89,33 @@ func authenticationSASL(server fed.Conn, mechanisms []string, creds auth.SASL) e return nil } -func authenticationMD5(server fed.Conn, salt [4]byte, creds auth.MD5) error { +func authenticationMD5(ctx *AcceptContext, salt [4]byte, creds auth.MD5Client) error { pw := packets.PasswordMessage{ Password: creds.EncodeMD5(salt), } - err := server.WritePacket(pw.IntoPacket()) + ctx.Packet = pw.IntoPacket(ctx.Packet) + err := ctx.Conn.WritePacket(ctx.Packet) if err != nil { return err } return nil } -func authenticationCleartext(server fed.Conn, creds auth.Cleartext) error { +func authenticationCleartext(ctx *AcceptContext, creds auth.CleartextClient) error { pw := packets.PasswordMessage{ Password: creds.EncodeCleartext(), } - err := server.WritePacket(pw.IntoPacket()) + ctx.Packet = pw.IntoPacket(ctx.Packet) + err := ctx.Conn.WritePacket(ctx.Packet) if err != nil { return err } return nil } -func authentication(server fed.Conn, creds auth.Credentials, packet fed.Packet) (done bool, err error) { +func authentication(ctx *AcceptContext) (done bool, err error) { var method int32 - packet.ReadInt32(&method) + ctx.Packet.ReadInt32(&method) // they have more authentication methods than there are pokemon switch method { case 0: @@ -120,23 +125,23 @@ func authentication(server fed.Conn, creds auth.Credentials, packet fed.Packet) err = errors.New("kerberos v5 is not supported") return case 3: - c, ok := creds.(auth.Cleartext) + c, ok := ctx.Options.Credentials.(auth.CleartextClient) if !ok { return false, auth.ErrMethodNotSupported } - return false, authenticationCleartext(server, c) + return false, authenticationCleartext(ctx, c) case 5: var md5 packets.AuthenticationMD5 - if !md5.ReadFromPacket(packet) { + if !md5.ReadFromPacket(ctx.Packet) { err = ErrBadFormat return } - c, ok := creds.(auth.MD5) + c, ok := ctx.Options.Credentials.(auth.MD5Client) if !ok { return false, auth.ErrMethodNotSupported } - return false, authenticationMD5(server, md5.Salt, c) + return false, authenticationMD5(ctx, md5.Salt, c) case 6: err = errors.New("scm credential is not supported") return @@ -149,40 +154,39 @@ func authentication(server fed.Conn, creds auth.Credentials, packet fed.Packet) case 10: // read list of mechanisms var sasl packets.AuthenticationSASL - if !sasl.ReadFromPacket(packet) { + if !sasl.ReadFromPacket(ctx.Packet) { err = ErrBadFormat return } - c, ok := creds.(auth.SASL) + c, ok := ctx.Options.Credentials.(auth.SASLClient) if !ok { return false, auth.ErrMethodNotSupported } - return false, authenticationSASL(server, sasl.Mechanisms, c) + return false, authenticationSASL(ctx, sasl.Mechanisms, c) default: err = errors.New("unknown authentication method") return } } -func startup0(server fed.Conn, creds auth.Credentials) (done bool, err error) { - var packet fed.Packet - packet, err = server.ReadPacket(true) +func startup0(ctx *AcceptContext) (done bool, err error) { + ctx.Packet, err = ctx.Conn.ReadPacket(true, ctx.Packet) if err != nil { return } - switch packet.Type() { + switch ctx.Packet.Type() { case packets.TypeErrorResponse: var err2 packets.ErrorResponse - if !err2.ReadFromPacket(packet) { + if !err2.ReadFromPacket(ctx.Packet) { err = ErrBadFormat } else { err = errors.New(err2.Error.String()) } return case packets.TypeAuthentication: - return authentication(server, creds, packet) + return authentication(ctx) case packets.TypeNegotiateProtocolVersion: // we only support protocol 3.0 for now err = errors.New("server wanted to negotiate protocol version") @@ -193,20 +197,19 @@ func startup0(server fed.Conn, creds auth.Credentials) (done bool, err error) { } } -func startup1(conn fed.Conn, params *AcceptParams) (done bool, err error) { - var packet fed.Packet - packet, err = conn.ReadPacket(true) +func startup1(ctx *AcceptContext, params *AcceptParams) (done bool, err error) { + ctx.Packet, err = ctx.Conn.ReadPacket(true, ctx.Packet) if err != nil { return } - switch packet.Type() { + switch ctx.Packet.Type() { case packets.TypeBackendKeyData: - packet.ReadBytes(params.BackendKey[:]) + ctx.Packet.ReadBytes(params.BackendKey[:]) return false, nil case packets.TypeParameterStatus: var ps packets.ParameterStatus - if !ps.ReadFromPacket(packet) { + if !ps.ReadFromPacket(ctx.Packet) { err = ErrBadFormat return } @@ -220,7 +223,7 @@ func startup1(conn fed.Conn, params *AcceptParams) (done bool, err error) { return true, nil case packets.TypeErrorResponse: var err2 packets.ErrorResponse - if !err2.ReadFromPacket(packet) { + if !err2.ReadFromPacket(ctx.Packet) { err = ErrBadFormat } else { err = errors.New(err2.Error.String()) @@ -235,15 +238,15 @@ func startup1(conn fed.Conn, params *AcceptParams) (done bool, err error) { } } -func enableSSL(server fed.Conn, config *tls.Config) (bool, error) { - packet := fed.NewPacket(0, 4) - packet = packet.AppendUint16(1234) - packet = packet.AppendUint16(5679) - if err := server.WritePacket(packet); err != nil { +func enableSSL(ctx *AcceptContext) (bool, error) { + ctx.Packet = ctx.Packet.Reset(0, 4) + ctx.Packet = ctx.Packet.AppendUint16(1234) + ctx.Packet = ctx.Packet.AppendUint16(5679) + if err := ctx.Conn.WritePacket(ctx.Packet); err != nil { return false, err } - byteReader, ok := server.(io.ByteReader) + byteReader, ok := ctx.Conn.(io.ByteReader) if !ok { return false, errors.New("server must be io.ByteReader to enable ssl") } @@ -259,65 +262,65 @@ func enableSSL(server fed.Conn, config *tls.Config) (bool, error) { return false, nil } - sslClient, ok := server.(fed.SSLClient) + sslClient, ok := ctx.Conn.(fed.SSLClient) if !ok { return false, errors.New("server must be fed.SSLClient to enable ssl") } - if err = sslClient.EnableSSLClient(config); err != nil { + if err = sslClient.EnableSSLClient(ctx.Options.SSLConfig); err != nil { return false, err } return true, nil } -func Accept(server fed.Conn, options AcceptOptions) (AcceptParams, error) { - username := options.Credentials.GetUsername() +func Accept(ctx *AcceptContext) (AcceptParams, error) { + username := ctx.Options.Username - if options.Database == "" { - options.Database = username + if ctx.Options.Database == "" { + ctx.Options.Database = username } var params AcceptParams - if options.SSLMode.ShouldAttempt() { + if ctx.Options.SSLMode.ShouldAttempt() { var err error - params.SSLEnabled, err = enableSSL(server, options.SSLConfig) + params.SSLEnabled, err = enableSSL(ctx) if err != nil { return AcceptParams{}, err } - if !params.SSLEnabled && options.SSLMode.IsRequired() { + if !params.SSLEnabled && ctx.Options.SSLMode.IsRequired() { return AcceptParams{}, errors.New("server rejected SSL encryption") } } - size := 4 + len("user") + 1 + len(username) + 1 + len("database") + 1 + len(options.Database) + 1 - for key, value := range options.StartupParameters { + size := 4 + len("user") + 1 + len(username) + 1 + len("database") + 1 + len(ctx.Options.Database) + 1 + for key, value := range ctx.Options.StartupParameters { size += len(key.String()) + len(value) + 2 } size += 1 - packet := fed.NewPacket(0, size) - packet = packet.AppendUint16(3) - packet = packet.AppendUint16(0) - packet = packet.AppendString("user") - packet = packet.AppendString(username) - packet = packet.AppendString("database") - packet = packet.AppendString(options.Database) - for key, value := range options.StartupParameters { - packet = packet.AppendString(key.String()) - packet = packet.AppendString(value) + ctx.Packet = ctx.Packet.Reset(0, size) + ctx.Packet = ctx.Packet.AppendUint16(3) + ctx.Packet = ctx.Packet.AppendUint16(0) + ctx.Packet = ctx.Packet.AppendString("user") + ctx.Packet = ctx.Packet.AppendString(username) + ctx.Packet = ctx.Packet.AppendString("database") + ctx.Packet = ctx.Packet.AppendString(ctx.Options.Database) + for key, value := range ctx.Options.StartupParameters { + ctx.Packet = ctx.Packet.AppendString(key.String()) + ctx.Packet = ctx.Packet.AppendString(value) } - packet = packet.AppendUint8(0) + ctx.Packet = ctx.Packet.AppendUint8(0) - err := server.WritePacket(packet) + err := ctx.Conn.WritePacket(ctx.Packet) if err != nil { return AcceptParams{}, err } for { var done bool - done, err = startup0(server, options.Credentials) + done, err = startup0(ctx) if err != nil { return AcceptParams{}, err } @@ -328,7 +331,7 @@ func Accept(server fed.Conn, options AcceptOptions) (AcceptParams, error) { for { var done bool - done, err = startup1(server, ¶ms) + done, err = startup1(ctx, ¶ms) if err != nil { return AcceptParams{}, err } diff --git a/lib/bouncer/backends/v0/context.go b/lib/bouncer/backends/v0/context.go index 3a6c95eb64b9a99ac7a6796e7edb5d55f76e7131..94477cac686d53f8b1c86db3da659fb47328ac28 100644 --- a/lib/bouncer/backends/v0/context.go +++ b/lib/bouncer/backends/v0/context.go @@ -1,13 +1,33 @@ package backends -import "pggat/lib/fed" +import ( + "pggat/lib/fed" +) + +type AcceptContext struct { + Packet fed.Packet + Conn fed.Conn + Options AcceptOptions +} type Context struct { + Server fed.ReadWriter + Packet fed.Packet Peer fed.ReadWriter PeerError error TxState byte } +func (T *Context) ServerRead() error { + var err error + T.Packet, err = T.Server.ReadPacket(true, T.Packet) + return err +} + +func (T *Context) ServerWrite() error { + return T.Server.WritePacket(T.Packet) +} + func (T *Context) PeerOK() bool { if T == nil { return false @@ -23,29 +43,30 @@ func (T *Context) PeerFail(err error) { T.PeerError = err } -func (T *Context) PeerRead() fed.Packet { +func (T *Context) PeerRead() bool { if T == nil { - return nil + return false } if !T.PeerOK() { - return nil + return false } - packet, err := T.Peer.ReadPacket(true) + var err error + T.Packet, err = T.Peer.ReadPacket(true, T.Packet) if err != nil { T.PeerFail(err) - return nil + return false } - return packet + return true } -func (T *Context) PeerWrite(packet fed.Packet) { +func (T *Context) PeerWrite() { if T == nil { return } if !T.PeerOK() { return } - err := T.Peer.WritePacket(packet) + err := T.Peer.WritePacket(T.Packet) if err != nil { T.PeerFail(err) } diff --git a/lib/bouncer/backends/v0/options.go b/lib/bouncer/backends/v0/options.go index b75821792c0eb02ee912423a2a954d8e14fe22bf..b05b8ba5d00a4d1024bd569be561d41fb5134850 100644 --- a/lib/bouncer/backends/v0/options.go +++ b/lib/bouncer/backends/v0/options.go @@ -11,6 +11,7 @@ import ( type AcceptOptions struct { SSLMode bouncer.SSLMode SSLConfig *tls.Config + Username string Credentials auth.Credentials Database string StartupParameters map[strutil.CIString]string diff --git a/lib/bouncer/backends/v0/query.go b/lib/bouncer/backends/v0/query.go index b2ac45107defac7e5d731c5bad971ba1ad44e80c..cbfe89bb4c4eabfdfc8009eec05f792f7f4a421b 100644 --- a/lib/bouncer/backends/v0/query.go +++ b/lib/bouncer/backends/v0/query.go @@ -3,54 +3,52 @@ package backends import ( "fmt" - "pggat/lib/fed" packets "pggat/lib/fed/packets/v3.0" "pggat/lib/util/strutil" ) -func CopyIn(ctx *Context, server fed.ReadWriter, packet fed.Packet) error { - ctx.PeerWrite(packet) +func CopyIn(ctx *Context) error { + ctx.PeerWrite() for { - packet = ctx.PeerRead() - if packet == nil { + if !ctx.PeerRead() { copyFail := packets.CopyFail{ Reason: "peer failed", } - return server.WritePacket(copyFail.IntoPacket()) + ctx.Packet = copyFail.IntoPacket(ctx.Packet) + return ctx.ServerWrite() } - switch packet.Type() { + switch ctx.Packet.Type() { case packets.TypeCopyData: - if err := server.WritePacket(packet); err != nil { + if err := ctx.ServerWrite(); err != nil { return err } case packets.TypeCopyDone, packets.TypeCopyFail: - return server.WritePacket(packet) + return ctx.ServerWrite() default: ctx.PeerFail(ErrUnexpectedPacket) } } } -func CopyOut(ctx *Context, server fed.ReadWriter, packet fed.Packet) error { - ctx.PeerWrite(packet) +func CopyOut(ctx *Context) error { + ctx.PeerWrite() for { - var err error - packet, err = server.ReadPacket(true) + err := ctx.ServerRead() if err != nil { return err } - switch packet.Type() { + switch ctx.Packet.Type() { case packets.TypeCopyData, packets.TypeNoticeResponse, packets.TypeParameterStatus, packets.TypeNotificationResponse: - ctx.PeerWrite(packet) + ctx.PeerWrite() case packets.TypeCopyDone, packets.TypeErrorResponse: - ctx.PeerWrite(packet) + ctx.PeerWrite() return nil default: return ErrUnexpectedPacket @@ -58,19 +56,18 @@ func CopyOut(ctx *Context, server fed.ReadWriter, packet fed.Packet) error { } } -func Query(ctx *Context, server fed.ReadWriter, packet fed.Packet) error { - if err := server.WritePacket(packet); err != nil { +func Query(ctx *Context) error { + if err := ctx.ServerWrite(); err != nil { return err } for { - var err error - packet, err = server.ReadPacket(true) + err := ctx.ServerRead() if err != nil { return err } - switch packet.Type() { + switch ctx.Packet.Type() { case packets.TypeCommandComplete, packets.TypeRowDescription, packets.TypeDataRow, @@ -79,22 +76,22 @@ func Query(ctx *Context, server fed.ReadWriter, packet fed.Packet) error { packets.TypeNoticeResponse, packets.TypeParameterStatus, packets.TypeNotificationResponse: - ctx.PeerWrite(packet) + ctx.PeerWrite() case packets.TypeCopyInResponse: - if err = CopyIn(ctx, server, packet); err != nil { + if err = CopyIn(ctx); err != nil { return err } case packets.TypeCopyOutResponse: - if err = CopyOut(ctx, server, packet); err != nil { + if err = CopyOut(ctx); err != nil { return err } case packets.TypeReadyForQuery: var txState packets.ReadyForQuery - if !txState.ReadFromPacket(packet) { + if !txState.ReadFromPacket(ctx.Packet) { return ErrBadFormat } ctx.TxState = byte(txState) - ctx.PeerWrite(packet) + ctx.PeerWrite() return nil default: return ErrUnexpectedPacket @@ -102,45 +99,44 @@ func Query(ctx *Context, server fed.ReadWriter, packet fed.Packet) error { } } -func QueryString(ctx *Context, server fed.ReadWriter, query string) error { +func QueryString(ctx *Context, query string) error { q := packets.Query(query) - return Query(ctx, server, q.IntoPacket()) + ctx.Packet = q.IntoPacket(ctx.Packet) + return Query(ctx) } -func SetParameter(ctx *Context, server fed.ReadWriter, name strutil.CIString, value string) error { +func SetParameter(ctx *Context, name strutil.CIString, value string) error { return QueryString( ctx, - server, fmt.Sprintf(`SET "%s" = '%s'`, strutil.Escape(name.String(), '"'), strutil.Escape(value, '\'')), ) } -func FunctionCall(ctx *Context, server fed.ReadWriter, packet fed.Packet) error { - if err := server.WritePacket(packet); err != nil { +func FunctionCall(ctx *Context) error { + if err := ctx.ServerWrite(); err != nil { return err } for { - var err error - packet, err = server.ReadPacket(true) + err := ctx.ServerRead() if err != nil { return err } - switch packet.Type() { + switch ctx.Packet.Type() { case packets.TypeErrorResponse, packets.TypeFunctionCallResponse, packets.TypeNoticeResponse, packets.TypeParameterStatus, packets.TypeNotificationResponse: - ctx.PeerWrite(packet) + ctx.PeerWrite() case packets.TypeReadyForQuery: var txState packets.ReadyForQuery - if !txState.ReadFromPacket(packet) { + if !txState.ReadFromPacket(ctx.Packet) { return ErrBadFormat } ctx.TxState = byte(txState) - ctx.PeerWrite(packet) + ctx.PeerWrite() return nil default: return ErrUnexpectedPacket @@ -148,18 +144,19 @@ func FunctionCall(ctx *Context, server fed.ReadWriter, packet fed.Packet) error } } -func Sync(ctx *Context, server fed.ReadWriter) (bool, error) { - if err := server.WritePacket(fed.NewPacket(packets.TypeSync)); err != nil { +func Sync(ctx *Context) (bool, error) { + ctx.Packet = ctx.Packet.Reset(packets.TypeSync) + if err := ctx.ServerWrite(); err != nil { return false, err } for { - packet, err := server.ReadPacket(true) + err := ctx.ServerRead() if err != nil { return false, err } - switch packet.Type() { + switch ctx.Packet.Type() { case packets.TypeParseComplete, packets.TypeBindComplete, packets.TypeCloseComplete, @@ -176,24 +173,24 @@ func Sync(ctx *Context, server fed.ReadWriter) (bool, error) { packets.TypeNoticeResponse, packets.TypeParameterStatus, packets.TypeNotificationResponse: - ctx.PeerWrite(packet) + ctx.PeerWrite() case packets.TypeCopyInResponse: - if err = CopyIn(ctx, server, packet); err != nil { + if err = CopyIn(ctx); err != nil { return false, err } // why return false, nil case packets.TypeCopyOutResponse: - if err = CopyOut(ctx, server, packet); err != nil { + if err = CopyOut(ctx); err != nil { return false, err } case packets.TypeReadyForQuery: var txState packets.ReadyForQuery - if !txState.ReadFromPacket(packet) { + if !txState.ReadFromPacket(ctx.Packet) { return false, ErrBadFormat } ctx.TxState = byte(txState) - ctx.PeerWrite(packet) + ctx.PeerWrite() return true, nil default: return false, ErrUnexpectedPacket @@ -201,16 +198,15 @@ func Sync(ctx *Context, server fed.ReadWriter) (bool, error) { } } -func EQP(ctx *Context, server fed.ReadWriter, packet fed.Packet) error { - if err := server.WritePacket(packet); err != nil { +func EQP(ctx *Context) error { + if err := ctx.ServerWrite(); err != nil { return err } for { - packet = ctx.PeerRead() - if packet == nil { + if !ctx.PeerRead() { for { - ok, err := Sync(ctx, server) + ok, err := Sync(ctx) if err != nil { return err } @@ -220,9 +216,9 @@ func EQP(ctx *Context, server fed.ReadWriter, packet fed.Packet) error { } } - switch packet.Type() { + switch ctx.Packet.Type() { case packets.TypeSync: - ok, err := Sync(ctx, server) + ok, err := Sync(ctx) if err != nil { return err } @@ -230,7 +226,7 @@ func EQP(ctx *Context, server fed.ReadWriter, packet fed.Packet) error { return nil } case packets.TypeParse, packets.TypeBind, packets.TypeClose, packets.TypeDescribe, packets.TypeExecute, packets.TypeFlush: - if err := server.WritePacket(packet); err != nil { + if err := ctx.ServerWrite(); err != nil { return err } default: @@ -239,26 +235,27 @@ func EQP(ctx *Context, server fed.ReadWriter, packet fed.Packet) error { } } -func Transaction(ctx *Context, server fed.ReadWriter, packet fed.Packet) error { +func Transaction(ctx *Context) error { if ctx.TxState == '\x00' { ctx.TxState = 'I' } for { - switch packet.Type() { + switch ctx.Packet.Type() { case packets.TypeQuery: - if err := Query(ctx, server, packet); err != nil { + if err := Query(ctx); err != nil { return err } case packets.TypeFunctionCall: - if err := FunctionCall(ctx, server, packet); err != nil { + if err := FunctionCall(ctx); err != nil { return err } case packets.TypeSync: // phony sync call, we can just reply with a fake ReadyForQuery(TxState) rfq := packets.ReadyForQuery(ctx.TxState) - ctx.PeerWrite(rfq.IntoPacket()) + ctx.Packet = rfq.IntoPacket(ctx.Packet) + ctx.PeerWrite() case packets.TypeParse, packets.TypeBind, packets.TypeClose, packets.TypeDescribe, packets.TypeExecute, packets.TypeFlush: - if err := EQP(ctx, server, packet); err != nil { + if err := EQP(ctx); err != nil { return err } default: @@ -269,10 +266,9 @@ func Transaction(ctx *Context, server fed.ReadWriter, packet fed.Packet) error { return nil } - packet = ctx.PeerRead() - if packet == nil { + if !ctx.PeerRead() { // abort tx - err := QueryString(ctx, server, "ABORT;") + err := QueryString(ctx, "ABORT;") if err != nil { return err } diff --git a/lib/bouncer/bouncers/v2/bouncer.go b/lib/bouncer/bouncers/v2/bouncer.go index c13b05e27516d9ed4632522849e557e967428c79..89189164b4fa87822ec1bc61df9dbf0cd143dc3f 100644 --- a/lib/bouncer/bouncers/v2/bouncer.go +++ b/lib/bouncer/bouncers/v2/bouncer.go @@ -7,26 +7,31 @@ import ( "pggat/lib/perror" ) -func clientFail(client fed.ReadWriter, err perror.Error) { +func clientFail(ctx *backends.Context, client fed.ReadWriter, err perror.Error) { // send fatal error to client resp := packets.ErrorResponse{ Error: err, } - _ = client.WritePacket(resp.IntoPacket()) + ctx.Packet = resp.IntoPacket(ctx.Packet) + _ = client.WritePacket(ctx.Packet) } -func Bounce(client, server fed.ReadWriter, initialPacket fed.Packet) (clientError error, serverError error) { +func Bounce(client, server fed.ReadWriter, initialPacket fed.Packet) (packet fed.Packet, clientError error, serverError error) { ctx := backends.Context{ - Peer: client, + Server: server, + Packet: initialPacket, + Peer: client, } - serverError = backends.Transaction(&ctx, server, initialPacket) + serverError = backends.Transaction(&ctx) clientError = ctx.PeerError if clientError != nil { - clientFail(client, perror.Wrap(clientError)) + clientFail(&ctx, client, perror.Wrap(clientError)) } else if serverError != nil { - clientFail(client, perror.Wrap(serverError)) + clientFail(&ctx, client, perror.Wrap(serverError)) } + packet = ctx.Packet + return } diff --git a/lib/bouncer/frontends/v0/accept.go b/lib/bouncer/frontends/v0/accept.go index ef47bd9d0667a94e8d0b3c65c55af9922ff2a3ca..6609a92394a0bb1a8170caba62e1dbeab84f053c 100644 --- a/lib/bouncer/frontends/v0/accept.go +++ b/lib/bouncer/frontends/v0/accept.go @@ -13,11 +13,11 @@ import ( ) func startup0( - conn fed.Conn, + ctx *AcceptContext, params *AcceptParams, - options AcceptOptions, ) (done bool, err perror.Error) { - packet, err2 := conn.ReadPacket(false) + var err2 error + ctx.Packet, err2 = ctx.Conn.ReadPacket(false, ctx.Packet) if err2 != nil { err = perror.Wrap(err2) return @@ -25,7 +25,7 @@ func startup0( var majorVersion uint16 var minorVersion uint16 - p := packet.ReadUint16(&majorVersion) + p := ctx.Packet.ReadUint16(&majorVersion) p = p.ReadUint16(&minorVersion) if majorVersion == 1234 { @@ -49,7 +49,7 @@ func startup0( done = true return case 5679: - byteWriter, ok := conn.(io.ByteWriter) + byteWriter, ok := ctx.Conn.(io.ByteWriter) if !ok { err = perror.New( perror.FATAL, @@ -60,12 +60,12 @@ func startup0( } // ssl is not enabled - if options.SSLConfig == nil { + if ctx.Options.SSLConfig == nil { err = perror.Wrap(byteWriter.WriteByte('N')) return } - sslServer, ok := conn.(fed.SSLServer) + sslServer, ok := ctx.Conn.(fed.SSLServer) if !ok { err = perror.New( perror.FATAL, @@ -79,13 +79,13 @@ func startup0( if err = perror.Wrap(byteWriter.WriteByte('S')); err != nil { return } - if err = perror.Wrap(sslServer.EnableSSLServer(options.SSLConfig)); err != nil { + if err = perror.Wrap(sslServer.EnableSSLServer(ctx.Options.SSLConfig)); err != nil { return } params.SSLEnabled = true return case 5680: - byteWriter, ok := conn.(io.ByteWriter) + byteWriter, ok := ctx.Conn.(io.ByteWriter) if !ok { err = perror.New( perror.FATAL, @@ -154,7 +154,7 @@ func startup0( ikey := strutil.MakeCIString(key) - if !slices.Contains(options.AllowedStartupOptions, ikey) { + if !slices.Contains(ctx.Options.AllowedStartupOptions, ikey) { err = perror.New( perror.FATAL, perror.FeatureNotSupported, @@ -190,7 +190,7 @@ func startup0( } else { ikey := strutil.MakeCIString(key) - if !slices.Contains(options.AllowedStartupOptions, ikey) { + if !slices.Contains(ctx.Options.AllowedStartupOptions, ikey) { err = perror.New( perror.FATAL, perror.FeatureNotSupported, @@ -213,8 +213,8 @@ func startup0( MinorProtocolVersion: 0, UnrecognizedOptions: unsupportedOptions, } - - err = perror.Wrap(conn.WritePacket(uopts.IntoPacket())) + ctx.Packet = uopts.IntoPacket(ctx.Packet) + err = perror.Wrap(ctx.Conn.WritePacket(ctx.Packet)) if err != nil { return } @@ -237,12 +237,11 @@ func startup0( } func accept( - client fed.Conn, - options AcceptOptions, + ctx *AcceptContext, ) (params AcceptParams, err perror.Error) { for { var done bool - done, err = startup0(client, ¶ms, options) + done, err = startup0(ctx, ¶ms) if err != nil { return } @@ -255,7 +254,7 @@ func accept( return } - if options.SSLRequired && !params.SSLEnabled { + if ctx.Options.SSLRequired && !params.SSLEnabled { err = perror.New( perror.FATAL, perror.InvalidPassword, @@ -267,17 +266,18 @@ func accept( return } -func fail(client fed.Conn, err perror.Error) { +func fail(packet fed.Packet, client fed.Conn, err perror.Error) { resp := packets.ErrorResponse{ Error: err, } - _ = client.WritePacket(resp.IntoPacket()) + packet = resp.IntoPacket(packet) + _ = client.WritePacket(packet) } -func Accept(client fed.Conn, options AcceptOptions) (AcceptParams, perror.Error) { - params, err := accept(client, options) +func Accept(ctx *AcceptContext) (AcceptParams, perror.Error) { + params, err := accept(ctx) if err != nil { - fail(client, err) + fail(ctx.Packet, ctx.Conn, err) return AcceptParams{}, err } return params, nil diff --git a/lib/bouncer/frontends/v0/authenticate.go b/lib/bouncer/frontends/v0/authenticate.go index aeabbad83b16e944c49df68194090cec17d62150..f2b2398ec99f896e013a3ed1d4863ff3741f0414 100644 --- a/lib/bouncer/frontends/v0/authenticate.go +++ b/lib/bouncer/frontends/v0/authenticate.go @@ -3,22 +3,23 @@ package frontends import ( "crypto/rand" "errors" + "io" "pggat/lib/auth" - "pggat/lib/fed" packets "pggat/lib/fed/packets/v3.0" "pggat/lib/perror" ) -func authenticationSASLInitial(client fed.Conn, creds auth.SASL) (tool auth.SASLVerifier, resp []byte, done bool, err perror.Error) { +func authenticationSASLInitial(ctx *AuthenticateContext, creds auth.SASLServer) (tool auth.SASLVerifier, resp []byte, done bool, err perror.Error) { // check which authentication method the client wants - packet, err2 := client.ReadPacket(true) + var err2 error + ctx.Packet, err2 = ctx.Conn.ReadPacket(true, ctx.Packet) if err2 != nil { err = perror.Wrap(err2) return } var initialResponse packets.SASLInitialResponse - if !initialResponse.ReadFromPacket(packet) { + if !initialResponse.ReadFromPacket(ctx.Packet) { err = packets.ErrBadFormat return } @@ -31,7 +32,7 @@ func authenticationSASLInitial(client fed.Conn, creds auth.SASL) (tool auth.SASL resp, err2 = tool.Write(initialResponse.InitialResponse) if err2 != nil { - if errors.Is(err2, auth.ErrSASLComplete) { + if errors.Is(err2, io.EOF) { done = true return } @@ -41,21 +42,22 @@ func authenticationSASLInitial(client fed.Conn, creds auth.SASL) (tool auth.SASL return } -func authenticationSASLContinue(client fed.Conn, tool auth.SASLVerifier) (resp []byte, done bool, err perror.Error) { - packet, err2 := client.ReadPacket(true) +func authenticationSASLContinue(ctx *AuthenticateContext, tool auth.SASLVerifier) (resp []byte, done bool, err perror.Error) { + var err2 error + ctx.Packet, err2 = ctx.Conn.ReadPacket(true, ctx.Packet) if err2 != nil { err = perror.Wrap(err2) return } var authResp packets.AuthenticationResponse - if !authResp.ReadFromPacket(packet) { + if !authResp.ReadFromPacket(ctx.Packet) { err = packets.ErrBadFormat return } resp, err2 = tool.Write(authResp) if err2 != nil { - if errors.Is(err2, auth.ErrSASLComplete) { + if errors.Is(err2, io.EOF) { done = true return } @@ -65,16 +67,17 @@ func authenticationSASLContinue(client fed.Conn, tool auth.SASLVerifier) (resp [ return } -func authenticationSASL(client fed.Conn, creds auth.SASL) perror.Error { +func authenticationSASL(ctx *AuthenticateContext, creds auth.SASLServer) perror.Error { saslInitial := packets.AuthenticationSASL{ Mechanisms: creds.SupportedSASLMechanisms(), } - err := perror.Wrap(client.WritePacket(saslInitial.IntoPacket())) + ctx.Packet = saslInitial.IntoPacket(ctx.Packet) + err := perror.Wrap(ctx.Conn.WritePacket(ctx.Packet)) if err != nil { return err } - tool, resp, done, err := authenticationSASLInitial(client, creds) + tool, resp, done, err := authenticationSASLInitial(ctx, creds) if err != nil { return err } @@ -82,20 +85,22 @@ func authenticationSASL(client fed.Conn, creds auth.SASL) perror.Error { for { if done { final := packets.AuthenticationSASLFinal(resp) - err = perror.Wrap(client.WritePacket(final.IntoPacket())) + ctx.Packet = final.IntoPacket(ctx.Packet) + err = perror.Wrap(ctx.Conn.WritePacket(ctx.Packet)) if err != nil { return err } break } else { cont := packets.AuthenticationSASLContinue(resp) - err = perror.Wrap(client.WritePacket(cont.IntoPacket())) + ctx.Packet = cont.IntoPacket(ctx.Packet) + err = perror.Wrap(ctx.Conn.WritePacket(ctx.Packet)) if err != nil { return err } } - resp, done, err = authenticationSASLContinue(client, tool) + resp, done, err = authenticationSASLContinue(ctx, tool) if err != nil { return err } @@ -104,7 +109,7 @@ func authenticationSASL(client fed.Conn, creds auth.SASL) perror.Error { return nil } -func authenticationMD5(client fed.Conn, creds auth.MD5) perror.Error { +func authenticationMD5(ctx *AuthenticateContext, creds auth.MD5Server) perror.Error { var salt [4]byte _, err := rand.Read(salt[:]) if err != nil { @@ -113,19 +118,19 @@ func authenticationMD5(client fed.Conn, creds auth.MD5) perror.Error { md5Initial := packets.AuthenticationMD5{ Salt: salt, } - err = client.WritePacket(md5Initial.IntoPacket()) + ctx.Packet = md5Initial.IntoPacket(ctx.Packet) + err = ctx.Conn.WritePacket(ctx.Packet) if err != nil { return perror.Wrap(err) } - var packet fed.Packet - packet, err = client.ReadPacket(true) + ctx.Packet, err = ctx.Conn.ReadPacket(true, ctx.Packet) if err != nil { return perror.Wrap(err) } var pw packets.PasswordMessage - if !pw.ReadFromPacket(packet) { + if !pw.ReadFromPacket(ctx.Packet) { return packets.ErrUnexpectedPacket } @@ -136,33 +141,28 @@ func authenticationMD5(client fed.Conn, creds auth.MD5) perror.Error { return nil } -func authenticate(client fed.Conn, options AuthenticateOptions) (params AuthenticateParams, err perror.Error) { - if options.Credentials == nil { - err = perror.New( - perror.FATAL, - perror.InvalidPassword, - "User or database not found", - ) - return - } - if credsSASL, ok := options.Credentials.(auth.SASL); ok { - err = authenticationSASL(client, credsSASL) - } else if credsMD5, ok := options.Credentials.(auth.MD5); ok { - err = authenticationMD5(client, credsMD5) - } else { - err = perror.New( - perror.FATAL, - perror.InternalError, - "Auth method not supported", - ) - } - if err != nil { - return +func authenticate(ctx *AuthenticateContext) (params AuthenticateParams, err perror.Error) { + if ctx.Options.Credentials != nil { + if credsSASL, ok := ctx.Options.Credentials.(auth.SASLServer); ok { + err = authenticationSASL(ctx, credsSASL) + } else if credsMD5, ok := ctx.Options.Credentials.(auth.MD5Server); ok { + err = authenticationMD5(ctx, credsMD5) + } else { + err = perror.New( + perror.FATAL, + perror.InternalError, + "Auth method not supported", + ) + } + if err != nil { + return + } } // send auth Ok authOk := packets.AuthenticationOk{} - if err = perror.Wrap(client.WritePacket(authOk.IntoPacket())); err != nil { + ctx.Packet = authOk.IntoPacket(ctx.Packet) + if err = perror.Wrap(ctx.Conn.WritePacket(ctx.Packet)); err != nil { return } @@ -176,17 +176,18 @@ func authenticate(client fed.Conn, options AuthenticateOptions) (params Authenti keyData := packets.BackendKeyData{ CancellationKey: params.BackendKey, } - if err = perror.Wrap(client.WritePacket(keyData.IntoPacket())); err != nil { + ctx.Packet = keyData.IntoPacket(ctx.Packet) + if err = perror.Wrap(ctx.Conn.WritePacket(ctx.Packet)); err != nil { return } return } -func Authenticate(client fed.Conn, options AuthenticateOptions) (AuthenticateParams, perror.Error) { - params, err := authenticate(client, options) +func Authenticate(ctx *AuthenticateContext) (AuthenticateParams, perror.Error) { + params, err := authenticate(ctx) if err != nil { - fail(client, err) + fail(ctx.Packet, ctx.Conn, err) return AuthenticateParams{}, err } return params, nil diff --git a/lib/bouncer/frontends/v0/context.go b/lib/bouncer/frontends/v0/context.go new file mode 100644 index 0000000000000000000000000000000000000000..f77573172b8709f338a554107232910e8f3c2aa1 --- /dev/null +++ b/lib/bouncer/frontends/v0/context.go @@ -0,0 +1,15 @@ +package frontends + +import "pggat/lib/fed" + +type AcceptContext struct { + Packet fed.Packet + Conn fed.Conn + Options AcceptOptions +} + +type AuthenticateContext struct { + Packet fed.Packet + Conn fed.Conn + Options AuthenticateOptions +} diff --git a/lib/fed/conn.go b/lib/fed/conn.go index a51de715ff3b14b6933b7e84b529d7d75f344852..c45752adc7bc4a048917dd5ff5a9321ccc74b146 100644 --- a/lib/fed/conn.go +++ b/lib/fed/conn.go @@ -7,6 +7,8 @@ import ( "errors" "io" "net" + + "pggat/lib/util/slices" ) type Conn interface { @@ -67,33 +69,35 @@ func (T *netConn) ReadByte() (byte, error) { return T.reader.ReadByte() } -func (T *netConn) ReadPacket(typed bool) (Packet, error) { - if err := T.writer.Flush(); err != nil { - return nil, err +func (T *netConn) ReadPacket(typed bool, buffer Packet) (packet Packet, err error) { + packet = buffer + + if err = T.writer.Flush(); err != nil { + return } + if typed { - _, err := io.ReadFull(&T.reader, T.headerBuf[:]) + _, err = io.ReadFull(&T.reader, T.headerBuf[:]) if err != nil { - return nil, err + return } } else { - _, err := io.ReadFull(&T.reader, T.headerBuf[1:]) + _, err = io.ReadFull(&T.reader, T.headerBuf[1:]) if err != nil { - return nil, err + return } } length := binary.BigEndian.Uint32(T.headerBuf[1:]) - p := make([]byte, length+1) - copy(p, T.headerBuf[:]) + packet = slices.Resize(buffer, int(length)+1) + copy(packet, T.headerBuf[:]) - packet := Packet(p) - _, err := io.ReadFull(&T.reader, packet.Payload()) + _, err = io.ReadFull(&T.reader, packet.Payload()) if err != nil { - return nil, err + return } - return packet, nil + return } func (T *netConn) WriteByte(b byte) error { diff --git a/lib/fed/packet.go b/lib/fed/packet.go index 1fc2f71476989fb3d6e0a14a3fecf0c8f227fd8d..f50496535871c5cfab50fcf8bd85ae2da2a242ad 100644 --- a/lib/fed/packet.go +++ b/lib/fed/packet.go @@ -3,18 +3,34 @@ package fed import ( "encoding/binary" "math" + + "pggat/lib/util/slices" ) type Packet []byte func NewPacket(typ Type, size ...int) Packet { + return Packet(nil).Reset(typ, size...) +} + +func (T Packet) Reset(typ Type, size ...int) Packet { + packet := T + c := 5 if len(size) > 0 { - packet := make([]byte, 5, 5+size[0]) - packet[0] = byte(typ) - return packet + c += size[0] + } + + if cap(packet) < c { + packet = make([]byte, 5, c) } else { - return []byte{byte(typ), 0, 0, 0, 0} + packet = slices.Resize(packet, 5) } + packet[0] = byte(typ) + packet[1] = 0 + packet[2] = 0 + packet[3] = 0 + packet[4] = 0 + return packet } func (T Packet) Payload() PacketFragment { @@ -130,6 +146,10 @@ func (T Packet) ReadBytes(v []byte) PacketFragment { return T.Payload().ReadBytes(v) } +func (T Packet) Done() { + // TODO(garet) +} + type PacketFragment []byte func (T PacketFragment) ReadUint8(v *uint8) PacketFragment { diff --git a/lib/fed/packets/v3.0/authenticationcleartext.go b/lib/fed/packets/v3.0/authenticationcleartext.go index 42af01dc6d7e5e335639b8261aea1ceb24e0861d..2726645f63f80f38da65946a73d7438d552908f8 100644 --- a/lib/fed/packets/v3.0/authenticationcleartext.go +++ b/lib/fed/packets/v3.0/authenticationcleartext.go @@ -16,8 +16,8 @@ func (T *AuthenticationCleartext) ReadFromPacket(packet fed.Packet) bool { return true } -func (T *AuthenticationCleartext) IntoPacket() fed.Packet { - packet := fed.NewPacket(TypeAuthentication, 4) +func (T *AuthenticationCleartext) IntoPacket(packet fed.Packet) fed.Packet { + packet = packet.Reset(TypeAuthentication, 4) packet = packet.AppendUint32(3) return packet } diff --git a/lib/fed/packets/v3.0/authenticationmd5.go b/lib/fed/packets/v3.0/authenticationmd5.go index c8748de25bc6bf5fd8283a5933499a7223e9230a..d7a264c0672f9f3f981dd95fd6de735816d9f2ba 100644 --- a/lib/fed/packets/v3.0/authenticationmd5.go +++ b/lib/fed/packets/v3.0/authenticationmd5.go @@ -19,8 +19,8 @@ func (T *AuthenticationMD5) ReadFromPacket(packet fed.Packet) bool { return true } -func (T *AuthenticationMD5) IntoPacket() fed.Packet { - packet := fed.NewPacket(TypeAuthentication, 8) +func (T *AuthenticationMD5) IntoPacket(packet fed.Packet) fed.Packet { + packet = packet.Reset(TypeAuthentication, 8) packet = packet.AppendUint32(5) packet = packet.AppendBytes(T.Salt[:]) return packet diff --git a/lib/fed/packets/v3.0/authenticationok.go b/lib/fed/packets/v3.0/authenticationok.go index f2a089083917ae5640bf1a843b29007673472dc4..a4d00cbed2be87aaf8fdad7cc04da6d1434a36b8 100644 --- a/lib/fed/packets/v3.0/authenticationok.go +++ b/lib/fed/packets/v3.0/authenticationok.go @@ -16,8 +16,8 @@ func (T *AuthenticationOk) ReadFromPacket(packet fed.Packet) bool { return true } -func (T *AuthenticationOk) IntoPacket() fed.Packet { - packet := fed.NewPacket(TypeAuthentication, 4) +func (T *AuthenticationOk) IntoPacket(packet fed.Packet) fed.Packet { + packet = packet.Reset(TypeAuthentication, 4) packet = packet.AppendUint32(0) return packet } diff --git a/lib/fed/packets/v3.0/authenticationresponse.go b/lib/fed/packets/v3.0/authenticationresponse.go index 03728f95295f41111f52681dffaa8a52f530f355..96ca746a0feae7bead0da4720545243252fe0f41 100644 --- a/lib/fed/packets/v3.0/authenticationresponse.go +++ b/lib/fed/packets/v3.0/authenticationresponse.go @@ -16,7 +16,8 @@ func (T *AuthenticationResponse) ReadFromPacket(packet fed.Packet) bool { return true } -func (T *AuthenticationResponse) IntoPacket() fed.Packet { - packet := fed.NewPacket(TypeAuthenticationResponse, len(*T)) - return packet.AppendBytes(*T) +func (T *AuthenticationResponse) IntoPacket(packet fed.Packet) fed.Packet { + packet = fed.NewPacket(TypeAuthenticationResponse, len(*T)) + packet = packet.AppendBytes(*T) + return packet } diff --git a/lib/fed/packets/v3.0/authenticationsasl.go b/lib/fed/packets/v3.0/authenticationsasl.go index cdf34d78676faafbb03678cf697baaa38d210ad1..8fb2833f66bad87b835596bae18757725624065f 100644 --- a/lib/fed/packets/v3.0/authenticationsasl.go +++ b/lib/fed/packets/v3.0/authenticationsasl.go @@ -27,13 +27,13 @@ func (T *AuthenticationSASL) ReadFromPacket(packet fed.Packet) bool { return true } -func (T *AuthenticationSASL) IntoPacket() fed.Packet { +func (T *AuthenticationSASL) IntoPacket(packet fed.Packet) fed.Packet { size := 5 for _, mechanism := range T.Mechanisms { size += len(mechanism) + 1 } - packet := fed.NewPacket(TypeAuthentication, size) + packet = packet.Reset(TypeAuthentication, size) packet = packet.AppendInt32(10) for _, mechanism := range T.Mechanisms { diff --git a/lib/fed/packets/v3.0/authenticationsaslcontinue.go b/lib/fed/packets/v3.0/authenticationsaslcontinue.go index 7175e759964ba69a75ba6985a18c2574c7d6564e..8dd145a94294e6461b5097daf37129fda60d4f97 100644 --- a/lib/fed/packets/v3.0/authenticationsaslcontinue.go +++ b/lib/fed/packets/v3.0/authenticationsaslcontinue.go @@ -21,8 +21,8 @@ func (T *AuthenticationSASLContinue) ReadFromPacket(packet fed.Packet) bool { return true } -func (T *AuthenticationSASLContinue) IntoPacket() fed.Packet { - packet := fed.NewPacket(TypeAuthentication, 4+len(*T)) +func (T *AuthenticationSASLContinue) IntoPacket(packet fed.Packet) fed.Packet { + packet = packet.Reset(TypeAuthentication, 4+len(*T)) packet = packet.AppendUint32(11) packet = packet.AppendBytes(*T) return packet diff --git a/lib/fed/packets/v3.0/authenticationsaslfinal.go b/lib/fed/packets/v3.0/authenticationsaslfinal.go index 368f0f53fd19224fec7b3fe3f08eaf4fbed84920..f529d5d81b0a30e4ec67b837e2e88b2c7ef049d1 100644 --- a/lib/fed/packets/v3.0/authenticationsaslfinal.go +++ b/lib/fed/packets/v3.0/authenticationsaslfinal.go @@ -21,8 +21,8 @@ func (T *AuthenticationSASLFinal) ReadFromPacket(packet fed.Packet) bool { return true } -func (T *AuthenticationSASLFinal) IntoPacket() fed.Packet { - packet := fed.NewPacket(TypeAuthentication, 4+len(*T)) +func (T *AuthenticationSASLFinal) IntoPacket(packet fed.Packet) fed.Packet { + packet = packet.Reset(TypeAuthentication, 4+len(*T)) packet = packet.AppendUint32(12) packet = packet.AppendBytes(*T) return packet diff --git a/lib/fed/packets/v3.0/backendkeydata.go b/lib/fed/packets/v3.0/backendkeydata.go index b3ee5959aea949d7c560aa4a3ac65405d72aefe1..5efa8c921fe186b9a72674bd2810feb0e6a73402 100644 --- a/lib/fed/packets/v3.0/backendkeydata.go +++ b/lib/fed/packets/v3.0/backendkeydata.go @@ -14,8 +14,8 @@ func (T *BackendKeyData) ReadFromPacket(packet fed.Packet) bool { return true } -func (T *BackendKeyData) IntoPacket() fed.Packet { - packet := fed.NewPacket(TypeBackendKeyData, 8) +func (T *BackendKeyData) IntoPacket(packet fed.Packet) fed.Packet { + packet = fed.NewPacket(TypeBackendKeyData, 8) packet = packet.AppendBytes(T.CancellationKey[:]) return packet } diff --git a/lib/fed/packets/v3.0/bind.go b/lib/fed/packets/v3.0/bind.go index 1dc4dee31379cac65c56629bae199988e4b25061..d316ae055320b7cd984ced8623faebe62b50c1a0 100644 --- a/lib/fed/packets/v3.0/bind.go +++ b/lib/fed/packets/v3.0/bind.go @@ -51,7 +51,7 @@ func (T *Bind) ReadFromPacket(packet fed.Packet) bool { return true } -func (T *Bind) IntoPacket() fed.Packet { +func (T *Bind) IntoPacket(packet fed.Packet) fed.Packet { size := 0 size += len(T.Destination) + 1 size += len(T.Source) + 1 @@ -64,7 +64,7 @@ func (T *Bind) IntoPacket() fed.Packet { size += 2 size += len(T.ResultFormatCodes) * 2 - packet := fed.NewPacket(TypeBind, size) + packet = packet.Reset(TypeBind, size) packet = packet.AppendString(T.Destination) packet = packet.AppendString(T.Source) packet = packet.AppendUint16(uint16(len(T.ParameterFormatCodes))) diff --git a/lib/fed/packets/v3.0/close.go b/lib/fed/packets/v3.0/close.go index d4ec7fb1a8c72afc050cbcef4da6f10be821fed2..895bd545d8634b3b174575ea38ad1f18f743aa7e 100644 --- a/lib/fed/packets/v3.0/close.go +++ b/lib/fed/packets/v3.0/close.go @@ -16,8 +16,8 @@ func (T *Close) ReadFromPacket(packet fed.Packet) bool { return true } -func (T *Close) IntoPacket() fed.Packet { - packet := fed.NewPacket(TypeClose, 2+len(T.Target)) +func (T *Close) IntoPacket(packet fed.Packet) fed.Packet { + packet = packet.Reset(TypeClose, 2+len(T.Target)) packet = packet.AppendUint8(T.Which) packet = packet.AppendString(T.Target) return packet diff --git a/lib/fed/packets/v3.0/commandcomplete.go b/lib/fed/packets/v3.0/commandcomplete.go index d2f2cb863759ff8c0aa4d2755ca61b60d62578b9..96c0a4574cfa546905687109223672afd7554d7d 100644 --- a/lib/fed/packets/v3.0/commandcomplete.go +++ b/lib/fed/packets/v3.0/commandcomplete.go @@ -12,8 +12,8 @@ func (T *CommandComplete) ReadFromPacket(packet fed.Packet) bool { return true } -func (T *CommandComplete) IntoPacket() fed.Packet { - packet := fed.NewPacket(TypeCommandComplete, len(*T)+1) +func (T *CommandComplete) IntoPacket(packet fed.Packet) fed.Packet { + packet = packet.Reset(TypeCommandComplete, len(*T)+1) packet = packet.AppendString(string(*T)) return packet } diff --git a/lib/fed/packets/v3.0/copydata.go b/lib/fed/packets/v3.0/copydata.go index 9f4078708e956ce1c1a4a0dba7e89aff66061a34..68e056d63a1d05eeaa8c2d01ec1b228ee1a110dc 100644 --- a/lib/fed/packets/v3.0/copydata.go +++ b/lib/fed/packets/v3.0/copydata.go @@ -17,8 +17,8 @@ func (T *CopyData) ReadFromPacket(packet fed.Packet) bool { return true } -func (T *CopyData) IntoPacket() fed.Packet { - packet := fed.NewPacket(TypeCopyData, len(*T)) +func (T *CopyData) IntoPacket(packet fed.Packet) fed.Packet { + packet = fed.NewPacket(TypeCopyData, len(*T)) packet = packet.AppendBytes(*T) return packet } diff --git a/lib/fed/packets/v3.0/copyfail.go b/lib/fed/packets/v3.0/copyfail.go index e1c08fd371bace7d9793fec4e02abad8253633f9..2762bda7cedf82a673024bc08ad982f4651d61e0 100644 --- a/lib/fed/packets/v3.0/copyfail.go +++ b/lib/fed/packets/v3.0/copyfail.go @@ -14,8 +14,8 @@ func (T *CopyFail) ReadFromPacket(packet fed.Packet) bool { return true } -func (T *CopyFail) IntoPacket() fed.Packet { - packet := fed.NewPacket(TypeCopyFail, len(T.Reason)+1) +func (T *CopyFail) IntoPacket(packet fed.Packet) fed.Packet { + packet = packet.Reset(TypeCopyFail, len(T.Reason)+1) packet = packet.AppendString(T.Reason) return packet } diff --git a/lib/fed/packets/v3.0/datarow.go b/lib/fed/packets/v3.0/datarow.go index 62d02a6a3948db30a6ade5f75df2542b001e677b..c4e6a9889900ebb54bd857d05a31efc391dc67a0 100644 --- a/lib/fed/packets/v3.0/datarow.go +++ b/lib/fed/packets/v3.0/datarow.go @@ -30,13 +30,13 @@ func (T *DataRow) ReadFromPacket(packet fed.Packet) bool { return true } -func (T *DataRow) IntoPacket() fed.Packet { +func (T *DataRow) IntoPacket(packet fed.Packet) fed.Packet { size := 2 for _, v := range T.Columns { size += len(v) + 4 } - packet := fed.NewPacket(TypeDataRow, size) + packet = packet.Reset(TypeDataRow, size) packet = packet.AppendUint16(uint16(len(T.Columns))) for _, v := range T.Columns { if v == nil { diff --git a/lib/fed/packets/v3.0/describe.go b/lib/fed/packets/v3.0/describe.go index a5346ea0e1b933b1888cc6c67211edfadfab3ee4..e9c6c75a8db297907c7de9473b6c51097462cf7f 100644 --- a/lib/fed/packets/v3.0/describe.go +++ b/lib/fed/packets/v3.0/describe.go @@ -16,8 +16,8 @@ func (T *Describe) ReadFromPacket(packet fed.Packet) bool { return true } -func (T *Describe) IntoPacket() fed.Packet { - packet := fed.NewPacket(TypeDescribe, len(T.Target)+2) +func (T *Describe) IntoPacket(packet fed.Packet) fed.Packet { + packet = packet.Reset(TypeDescribe, len(T.Target)+2) packet = packet.AppendUint8(T.Which) packet = packet.AppendString(T.Target) return packet diff --git a/lib/fed/packets/v3.0/errorresponse.go b/lib/fed/packets/v3.0/errorresponse.go index 3bf60fb3bddfa053b84abf37e6b1ff07ff52e7ea..8d380ba50d2ccee2a26f503dbab82f514c6c0c06 100644 --- a/lib/fed/packets/v3.0/errorresponse.go +++ b/lib/fed/packets/v3.0/errorresponse.go @@ -56,7 +56,7 @@ func (T *ErrorResponse) ReadFromPacket(packet fed.Packet) bool { return true } -func (T *ErrorResponse) IntoPacket() fed.Packet { +func (T *ErrorResponse) IntoPacket(packet fed.Packet) fed.Packet { size := 1 size += len(T.Error.Severity()) + 2 size += len(T.Error.Code()) + 2 @@ -65,7 +65,7 @@ func (T *ErrorResponse) IntoPacket() fed.Packet { size += len(field.Value) + 2 } - packet := fed.NewPacket(TypeErrorResponse, size) + packet = packet.Reset(TypeErrorResponse, size) packet = packet.AppendUint8('S') packet = packet.AppendString(string(T.Error.Severity())) diff --git a/lib/fed/packets/v3.0/execute.go b/lib/fed/packets/v3.0/execute.go index 7b0fd858bab5be5d53b3bdfe218b8776c9ba4dce..c6db8ac96283c27a289656731473d44510c2d541 100644 --- a/lib/fed/packets/v3.0/execute.go +++ b/lib/fed/packets/v3.0/execute.go @@ -16,8 +16,8 @@ func (T *Execute) ReadFromPacket(packet fed.Packet) bool { return true } -func (T *Execute) IntoPacket() fed.Packet { - packet := fed.NewPacket(TypeExecute, len(T.Target)+5) +func (T *Execute) IntoPacket(packet fed.Packet) fed.Packet { + packet = packet.Reset(TypeExecute, len(T.Target)+5) packet = packet.AppendString(T.Target) packet = packet.AppendInt32(T.MaxRows) return packet diff --git a/lib/fed/packets/v3.0/negotiateprotocolversion.go b/lib/fed/packets/v3.0/negotiateprotocolversion.go index a6810e3f9bd7d551cc3a7c840f2ac42a6dac7b5d..e0c556ecc52f2a624b7f66eaf0aef31d18bd3463 100644 --- a/lib/fed/packets/v3.0/negotiateprotocolversion.go +++ b/lib/fed/packets/v3.0/negotiateprotocolversion.go @@ -27,13 +27,13 @@ func (T *NegotiateProtocolVersion) ReadFromPacket(packet fed.Packet) bool { return true } -func (T *NegotiateProtocolVersion) IntoPacket() fed.Packet { +func (T *NegotiateProtocolVersion) IntoPacket(packet fed.Packet) fed.Packet { size := 8 for _, v := range T.UnrecognizedOptions { size += len(v) + 1 } - packet := fed.NewPacket(TypeNegotiateProtocolVersion, size) + packet = packet.Reset(TypeNegotiateProtocolVersion, size) packet = packet.AppendInt32(T.MinorProtocolVersion) packet = packet.AppendInt32(int32(len(T.UnrecognizedOptions))) for _, v := range T.UnrecognizedOptions { diff --git a/lib/fed/packets/v3.0/parameterstatus.go b/lib/fed/packets/v3.0/parameterstatus.go index f4b296e1afa6df0e93814f45eab84226d34ceea0..34b1ade85df3731f8049a0c8c6b2237a05987bb7 100644 --- a/lib/fed/packets/v3.0/parameterstatus.go +++ b/lib/fed/packets/v3.0/parameterstatus.go @@ -16,8 +16,8 @@ func (T *ParameterStatus) ReadFromPacket(packet fed.Packet) bool { return true } -func (T *ParameterStatus) IntoPacket() fed.Packet { - packet := fed.NewPacket(TypeParameterStatus, len(T.Key)+len(T.Value)+2) +func (T *ParameterStatus) IntoPacket(packet fed.Packet) fed.Packet { + packet = packet.Reset(TypeParameterStatus, len(T.Key)+len(T.Value)+2) packet = packet.AppendString(T.Key) packet = packet.AppendString(T.Value) return packet diff --git a/lib/fed/packets/v3.0/parse.go b/lib/fed/packets/v3.0/parse.go index 55a93188ae1045f9705dc5888238c3c253c5038c..d4fedc633e887ee5f7fef1d8e91477c7e88b55d8 100644 --- a/lib/fed/packets/v3.0/parse.go +++ b/lib/fed/packets/v3.0/parse.go @@ -26,8 +26,8 @@ func (T *Parse) ReadFromPacket(packet fed.Packet) bool { return true } -func (T *Parse) IntoPacket() fed.Packet { - packet := fed.NewPacket(TypeParse, len(T.Destination)+len(T.Query)+4+len(T.ParameterDataTypes)*4) +func (T *Parse) IntoPacket(packet fed.Packet) fed.Packet { + packet = packet.Reset(TypeParse, len(T.Destination)+len(T.Query)+4+len(T.ParameterDataTypes)*4) packet = packet.AppendString(T.Destination) packet = packet.AppendString(T.Query) packet = packet.AppendInt16(int16(len(T.ParameterDataTypes))) diff --git a/lib/fed/packets/v3.0/passwordmessage.go b/lib/fed/packets/v3.0/passwordmessage.go index a568fe5646897743b3f1dcbb3b1e123eed31878d..6bdb123656859a5d01777706f9483e8fbf9ed0ac 100644 --- a/lib/fed/packets/v3.0/passwordmessage.go +++ b/lib/fed/packets/v3.0/passwordmessage.go @@ -16,8 +16,8 @@ func (T *PasswordMessage) ReadFromPacket(packet fed.Packet) bool { return true } -func (T *PasswordMessage) IntoPacket() fed.Packet { - packet := fed.NewPacket(TypeAuthenticationResponse, len(T.Password)+1) +func (T *PasswordMessage) IntoPacket(packet fed.Packet) fed.Packet { + packet = packet.Reset(TypeAuthenticationResponse, len(T.Password)+1) packet = packet.AppendString(T.Password) return packet } diff --git a/lib/fed/packets/v3.0/query.go b/lib/fed/packets/v3.0/query.go index ea7085b112d9225e0bab07e73bce06dd3633a128..bd47b747a276ae2451192e4a79b87df7d110679a 100644 --- a/lib/fed/packets/v3.0/query.go +++ b/lib/fed/packets/v3.0/query.go @@ -12,8 +12,8 @@ func (T *Query) ReadFromPacket(packet fed.Packet) bool { return true } -func (T *Query) IntoPacket() fed.Packet { - packet := fed.NewPacket(TypeQuery, len(*T)+1) +func (T *Query) IntoPacket(packet fed.Packet) fed.Packet { + packet = packet.Reset(TypeQuery, len(*T)+1) packet = packet.AppendString(string(*T)) return packet } diff --git a/lib/fed/packets/v3.0/readyforquery.go b/lib/fed/packets/v3.0/readyforquery.go index e306477a506aa072ea993b2b87a0dccf2f7d3d07..7ba5ea98a2fe7e7d63b5fc98fd77ab513e1fad25 100644 --- a/lib/fed/packets/v3.0/readyforquery.go +++ b/lib/fed/packets/v3.0/readyforquery.go @@ -14,8 +14,8 @@ func (T *ReadyForQuery) ReadFromPacket(packet fed.Packet) bool { return true } -func (T *ReadyForQuery) IntoPacket() fed.Packet { - packet := fed.NewPacket(TypeReadyForQuery, 1) +func (T *ReadyForQuery) IntoPacket(packet fed.Packet) fed.Packet { + packet = fed.NewPacket(TypeReadyForQuery, 1) packet = packet.AppendUint8(byte(*T)) return packet } diff --git a/lib/fed/packets/v3.0/rowdescription.go b/lib/fed/packets/v3.0/rowdescription.go index 5df7b58762de7f882ae9b087b00b08955325441d..91014c23b8014b27c75511c1cc245368f92cd0ec 100644 --- a/lib/fed/packets/v3.0/rowdescription.go +++ b/lib/fed/packets/v3.0/rowdescription.go @@ -40,14 +40,14 @@ func (T *RowDescription) ReadFromPacket(packet fed.Packet) bool { return true } -func (T *RowDescription) IntoPacket() fed.Packet { +func (T *RowDescription) IntoPacket(packet fed.Packet) fed.Packet { size := 2 for _, v := range T.Fields { size += len(v.Name) + 1 size += 4 + 2 + 4 + 2 + 4 + 2 } - packet := fed.NewPacket(TypeRowDescription, size) + packet = packet.Reset(TypeRowDescription, size) packet = packet.AppendUint16(uint16(len(T.Fields))) for _, v := range T.Fields { packet = packet.AppendString(v.Name) diff --git a/lib/fed/packets/v3.0/saslinitialresponse.go b/lib/fed/packets/v3.0/saslinitialresponse.go index 944e17778f14207928a876aa9319f7fdf14c0260..1ff65ef58e99951ffaa66dd9245820d0c0036524 100644 --- a/lib/fed/packets/v3.0/saslinitialresponse.go +++ b/lib/fed/packets/v3.0/saslinitialresponse.go @@ -26,8 +26,8 @@ func (T *SASLInitialResponse) ReadFromPacket(packet fed.Packet) bool { return true } -func (T *SASLInitialResponse) IntoPacket() fed.Packet { - packet := fed.NewPacket(TypeAuthenticationResponse, len(T.Mechanism)+5+len(T.InitialResponse)) +func (T *SASLInitialResponse) IntoPacket(packet fed.Packet) fed.Packet { + packet = packet.Reset(TypeAuthenticationResponse, len(T.Mechanism)+5+len(T.InitialResponse)) packet = packet.AppendString(T.Mechanism) packet = packet.AppendInt32(int32(len(T.InitialResponse))) packet = packet.AppendBytes(T.InitialResponse) diff --git a/lib/fed/readwriter.go b/lib/fed/readwriter.go index 1f8b23c2d02fdace07310098f7a1c465fae8058c..e26d66f933db5e2e3368ace6f4ab398330c98f2e 100644 --- a/lib/fed/readwriter.go +++ b/lib/fed/readwriter.go @@ -1,7 +1,7 @@ package fed type Reader interface { - ReadPacket(typed bool) (Packet, error) + ReadPacket(typed bool, buffer Packet) (Packet, error) } type Writer interface { diff --git a/lib/gat/acceptor.go b/lib/gat/acceptor.go index c394ecf60153744b526ac204ea48e533e302e234..69ab7b0cd13648b2f8d27e84dd672485d958a86e 100644 --- a/lib/gat/acceptor.go +++ b/lib/gat/acceptor.go @@ -17,20 +17,6 @@ type Acceptor struct { Options frontends.AcceptOptions } -func (T Acceptor) Accept() (fed.Conn, frontends.AcceptParams, error) { - netConn, err := T.Listener.Accept() - if err != nil { - return nil, frontends.AcceptParams{}, err - } - conn := fed.WrapNetConn(netConn) - params, err := frontends.Accept(conn, T.Options) - if err != nil { - _ = conn.Close() - return nil, frontends.AcceptParams{}, err - } - return conn, params, nil -} - func Listen(network, address string, options frontends.AcceptOptions) (Acceptor, error) { listener, err := net.Listen(network, address) if err != nil { @@ -48,7 +34,7 @@ func Listen(network, address string, options frontends.AcceptOptions) (Acceptor, }, nil } -func serve(client fed.Conn, acceptParams frontends.AcceptParams, pools Pools) error { +func serve(client fed.Conn, acceptParams frontends.AcceptParams, pools *KeyedPools) error { defer func() { _ = client.Close() }() @@ -69,9 +55,13 @@ func serve(client fed.Conn, acceptParams frontends.AcceptParams, pools Pools) er return nil } - authParams, err := frontends.Authenticate(client, frontends.AuthenticateOptions{ - Credentials: p.GetCredentials(), - }) + ctx := frontends.AuthenticateContext{ + Conn: client, + Options: frontends.AuthenticateOptions{ + Credentials: p.GetCredentials(), + }, + } + authParams, err := frontends.Authenticate(&ctx) if err != nil { return err } @@ -82,26 +72,43 @@ func serve(client fed.Conn, acceptParams frontends.AcceptParams, pools Pools) er return p.Serve(client, acceptParams.InitialParameters, authParams.BackendKey) } -func Serve(acceptor Acceptor, pools Pools) error { +func Serve(acceptor Acceptor, pools *KeyedPools) error { for { - conn, acceptParams, err := acceptor.Accept() + netConn, err := acceptor.Listener.Accept() if err != nil { if errors.Is(err, net.ErrClosed) { return nil } - log.Print("error accepting client: ", err) + log.Print("error accepting connection: ", err) continue } + conn := fed.WrapNetConn(netConn) + go func() { - err := serve(conn, acceptParams, pools) + defer func() { + _ = conn.Close() + }() + + ctx := frontends.AcceptContext{ + Conn: conn, + Options: acceptor.Options, + } + acceptParams, acceptErr := frontends.Accept(&ctx) + if acceptErr != nil { + log.Print("error accepting client: ", acceptErr) + return + } + + err = serve(conn, acceptParams, pools) if err != nil && !errors.Is(err, io.EOF) { log.Print("error serving client: ", err) + return } }() } } -func ListenAndServe(network, address string, options frontends.AcceptOptions, pools Pools) error { +func ListenAndServe(network, address string, options frontends.AcceptOptions, pools *KeyedPools) error { listener, err := Listen(network, address, options) if err != nil { return err diff --git a/lib/gat/keyedpool.go b/lib/gat/keyedpool.go new file mode 100644 index 0000000000000000000000000000000000000000..ef5864eb9b7997ef5a1f43f99f2fe181e2884501 --- /dev/null +++ b/lib/gat/keyedpool.go @@ -0,0 +1,35 @@ +package gat + +import ( + "pggat/lib/gat/pool" + "pggat/lib/util/maps" +) + +type KeyedPools struct { + Pools + + keys maps.RWLocked[[8]byte, *pool.Pool] +} + +func NewKeyedPools(pools Pools) *KeyedPools { + return &KeyedPools{ + Pools: pools, + } +} + +func (T *KeyedPools) RegisterKey(key [8]byte, user, database string) { + p := T.Lookup(user, database) + if p == nil { + return + } + T.keys.Store(key, p) +} + +func (T *KeyedPools) UnregisterKey(key [8]byte) { + T.keys.Delete(key) +} + +func (T *KeyedPools) LookupKey(key [8]byte) *pool.Pool { + p, _ := T.keys.Load(key) + return p +} diff --git a/lib/gat/modes/cloud_sql_discovery/config.go b/lib/gat/modes/cloud_sql_discovery/config.go new file mode 100644 index 0000000000000000000000000000000000000000..da7df4fdf5b529927be8fae763c13cd797023929 --- /dev/null +++ b/lib/gat/modes/cloud_sql_discovery/config.go @@ -0,0 +1,68 @@ +package cloud_sql_discovery + +import ( + "errors" + "time" + + "gfx.cafe/util/go/gun" + "tuxpa.in/a/zlog/log" + + "pggat/lib/bouncer/frontends/v0" + "pggat/lib/gat" + "pggat/lib/gat/metrics" + "pggat/lib/util/flip" + "pggat/lib/util/strutil" +) + +type Config struct { + Project string `env:"PGGAT_GC_PROJECT"` + IpAddressType string `env:"PGGAT_GC_IP_ADDR_TYPE" default:"PRIMARY"` + AuthUser string `env:"PGGAT_GC_AUTH_USER" default:"pggat"` + AuthPassword string `env:"PGGAT_GC_AUTH_PASSWORD"` +} + +func Load() (Config, error) { + var conf Config + gun.Load(&conf) + if conf.Project == "" { + return Config{}, errors.New("expected google cloud project id") + } + return conf, nil +} + +func (T *Config) ListenAndServe() error { + pools, err := NewPools(T) + if err != nil { + return err + } + + go func() { + var m metrics.Pools + for { + m.Clear() + time.Sleep(1 * time.Minute) + pools.ReadMetrics(&m) + log.Print(m.String()) + } + }() + + var b flip.Bank + + b.Queue(func() error { + log.Print("listening on :5432") + return gat.ListenAndServe("tcp", ":5432", frontends.AcceptOptions{ + // TODO(garet) ssl config + AllowedStartupOptions: []strutil.CIString{ + strutil.MakeCIString("client_encoding"), + strutil.MakeCIString("datestyle"), + strutil.MakeCIString("timezone"), + strutil.MakeCIString("standard_conforming_strings"), + strutil.MakeCIString("application_name"), + strutil.MakeCIString("extra_float_digits"), + strutil.MakeCIString("options"), + }, + }, gat.NewKeyedPools(pools)) + }) + + return b.Wait() +} diff --git a/lib/gat/modes/cloud_sql_discovery/pools.go b/lib/gat/modes/cloud_sql_discovery/pools.go new file mode 100644 index 0000000000000000000000000000000000000000..e170fa517d5b129b19c41c739a9ed02575901721 --- /dev/null +++ b/lib/gat/modes/cloud_sql_discovery/pools.go @@ -0,0 +1,216 @@ +package cloud_sql_discovery + +import ( + "context" + "crypto/tls" + "errors" + "io" + "net" + "strings" + "time" + + sqladmin "google.golang.org/api/sqladmin/v1beta4" + "tuxpa.in/a/zlog/log" + + "pggat/lib/auth" + "pggat/lib/auth/credentials" + "pggat/lib/bouncer" + "pggat/lib/bouncer/backends/v0" + "pggat/lib/gat" + "pggat/lib/gat/metrics" + "pggat/lib/gat/pool" + "pggat/lib/gat/pool/dialer" + "pggat/lib/gat/pool/pools/transaction" + "pggat/lib/gat/pool/recipe" + "pggat/lib/gsql" + "pggat/lib/util/maps" + "pggat/lib/util/strutil" +) + +type authQueryResult struct { + Username string `sql:"0"` + Password string `sql:"1"` +} + +type poolKey struct { + User string + Database string +} + +type poolTemplate struct { + Address string +} + +type Pools struct { + Config *Config + + templates maps.RWLocked[poolKey, poolTemplate] + pools maps.RWLocked[poolKey, *pool.Pool] +} + +func NewPools(config *Config) (*Pools, error) { + p := &Pools{ + Config: config, + } + + if err := p.init(); err != nil { + return nil, err + } + + return p, nil +} + +func (T *Pools) init() error { + service, err := sqladmin.NewService(context.Background()) + if err != nil { + return err + } + + instances, err := service.Instances.List(T.Config.Project).Do() + if err != nil { + return err + } + + for _, instance := range instances.Items { + if !strings.HasPrefix(instance.DatabaseVersion, "POSTGRES_") { + continue + } + + var address string + for _, ip := range instance.IpAddresses { + if ip.Type != T.Config.IpAddressType { + continue + } + address = net.JoinHostPort(ip.IpAddress, "5432") + } + if address == "" { + continue + } + + users, err := service.Users.List(T.Config.Project, instance.Name).Do() + if err != nil { + return err + } + databases, err := service.Databases.List(T.Config.Project, instance.Name).Do() + if err != nil { + return err + } + for _, user := range users.Items { + for _, database := range databases.Items { + T.templates.Store(poolKey{ + User: user.Name, + Database: database.Name, + }, poolTemplate{ + Address: address, + }) + log.Printf("registered database user=%s database=%s", user.Name, database.Name) + } + } + } + + return nil +} + +func (T *Pools) Lookup(user, database string) *pool.Pool { + p, ok := T.pools.Load(poolKey{ + User: user, + Database: database, + }) + if ok { + return p + } + template, ok := T.templates.Load(poolKey{ + User: user, + Database: database, + }) + if !ok { + return nil + } + + var creds auth.Credentials + if user == T.Config.AuthUser { + creds = credentials.Cleartext{ + Username: user, + Password: T.Config.AuthPassword, + } + } else { + // query for password + authPool := T.Lookup(T.Config.AuthUser, database) + if authPool == nil { + return nil + } + + var result authQueryResult + client := new(gsql.Client) + err := gsql.ExtendedQuery(client, &result, "SELECT usename, passwd FROM pg_shadow WHERE usename=$1", user) + if err != nil { + log.Println("auth query failed:", err) + return nil + } + err = client.Close() + if err != nil { + log.Println("auth query failed:", err) + return nil + } + err = authPool.ServeBot(client) + if err != nil && !errors.Is(err, io.EOF) { + log.Println("auth query failed:", err) + return nil + } + + if result.Username != user { + // user not found + return nil + } + + creds = credentials.FromString(result.Username, result.Password) + } + + d := dialer.Net{ + Network: "tcp", + Address: template.Address, + AcceptOptions: backends.AcceptOptions{ + SSLMode: bouncer.SSLModePrefer, + SSLConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + Username: user, + Credentials: creds, + Database: database, + }, + } + + options := transaction.Apply(pool.Options{ + Credentials: creds, + ServerReconnectInitialTime: 5 * time.Second, + ServerReconnectMaxTime: 5 * time.Second, + ServerIdleTimeout: 5 * time.Minute, + TrackedParameters: []strutil.CIString{ + strutil.MakeCIString("client_encoding"), + strutil.MakeCIString("datestyle"), + strutil.MakeCIString("timezone"), + strutil.MakeCIString("standard_conforming_strings"), + strutil.MakeCIString("application_name"), + }, + }) + + p = pool.NewPool(options) + p.AddRecipe("gc", recipe.NewRecipe(recipe.Options{ + Dialer: d, + })) + + T.pools.Store(poolKey{ + User: user, + Database: database, + }, p) + return p +} + +func (T *Pools) ReadMetrics(metrics *metrics.Pools) { + T.pools.Range(func(_ poolKey, p *pool.Pool) bool { + p.ReadMetrics(&metrics.Pool) + return true + }) +} + +var _ gat.Pools = (*Pools)(nil) diff --git a/lib/gat/modes/digitalocean_discovery/config.go b/lib/gat/modes/digitalocean_discovery/config.go index f09a0472cd59ab211f0374433cfe71587ef8a987..ddf416577b32c468ce5b694ed1a488dceff8f9a4 100644 --- a/lib/gat/modes/digitalocean_discovery/config.go +++ b/lib/gat/modes/digitalocean_discovery/config.go @@ -122,6 +122,7 @@ func (T *Config) ListenAndServe() error { SSLConfig: &tls.Config{ InsecureSkipVerify: true, }, + Username: user.Name, Credentials: creds, Database: dbname, } @@ -163,7 +164,7 @@ func (T *Config) ListenAndServe() error { replicaAddr = net.JoinHostPort(replica.Connection.Host, strconv.Itoa(replica.Connection.Port)) } - p2.AddRecipe(replica.Name, recipe.NewRecipe(recipe.Options{ + p2.AddRecipe(replica.ID, recipe.NewRecipe(recipe.Options{ Dialer: dialer.Net{ Network: "tcp", Address: replicaAddr, @@ -194,7 +195,7 @@ func (T *Config) ListenAndServe() error { strutil.MakeCIString("extra_float_digits"), strutil.MakeCIString("options"), }, - }, &pools) + }, gat.NewKeyedPools(&pools)) }) return b.Wait() diff --git a/lib/gat/modes/pgbouncer/config.go b/lib/gat/modes/pgbouncer/config.go index d1444602d8ddd93036b07b5af1e8aef423afea56..191ba398a2b860416f042c4223e317d48d22944b 100644 --- a/lib/gat/modes/pgbouncer/config.go +++ b/lib/gat/modes/pgbouncer/config.go @@ -299,6 +299,8 @@ func (T *Config) ListenAndServe() error { var bank flip.Bank + keyedPools := gat.NewKeyedPools(pools) + if T.PgBouncer.ListenAddr != "" { bank.Queue(func() error { listenAddr := T.PgBouncer.ListenAddr @@ -310,7 +312,7 @@ func (T *Config) ListenAndServe() error { log.Printf("listening on %s", listen) - return gat.ListenAndServe("tcp", listen, acceptOptions, pools) + return gat.ListenAndServe("tcp", listen, acceptOptions, keyedPools) }) } @@ -326,7 +328,7 @@ func (T *Config) ListenAndServe() error { log.Printf("listening on unix:%s", dir) - return gat.ListenAndServe("unix", dir, acceptOptions, pools) + return gat.ListenAndServe("unix", dir, acceptOptions, keyedPools) }) return bank.Wait() diff --git a/lib/gat/modes/pgbouncer/pools.go b/lib/gat/modes/pgbouncer/pools.go index da870275974aa4307e1b552a9b8e4755a5f25882..9d8b29ae2f959590f61d44e2db3f27e95620deae 100644 --- a/lib/gat/modes/pgbouncer/pools.go +++ b/lib/gat/modes/pgbouncer/pools.go @@ -3,6 +3,7 @@ package pgbouncer import ( "crypto/tls" "errors" + "io" "net" "strconv" "strings" @@ -25,8 +26,8 @@ import ( ) type authQueryResult struct { - Username string `sql:"0"` - Password *string `sql:"1"` + Username string `sql:"0"` + Password string `sql:"1"` } type poolKey struct { @@ -38,7 +39,6 @@ type Pools struct { Config *Config pools maps.RWLocked[poolKey, *pool.Pool] - keys maps.RWLocked[[8]byte, *pool.Pool] } func NewPools(config *Config) (*Pools, error) { @@ -110,7 +110,7 @@ func (T *Pools) Lookup(user, database string) *pool.Pool { return nil } err = authPool.ServeBot(client) - if err != nil && !errors.Is(err, net.ErrClosed) { + if err != nil && !errors.Is(err, io.EOF) { log.Println("auth query failed:", err) return nil } @@ -120,9 +120,7 @@ func (T *Pools) Lookup(user, database string) *pool.Pool { return nil } - if result.Password != nil { - password = *result.Password - } + password = result.Password } creds := credentials.FromString(user, password) @@ -191,6 +189,7 @@ func (T *Pools) Lookup(user, database string) *pool.Pool { SSLConfig: &tls.Config{ InsecureSkipVerify: true, // TODO(garet) }, + Username: user, Credentials: dbCreds, Database: backendDatabase, StartupParameters: db.StartupParameters, @@ -256,21 +255,4 @@ func (T *Pools) ReadMetrics(metrics *metrics.Pools) { }) } -func (T *Pools) RegisterKey(key [8]byte, user, database string) { - p := T.Lookup(user, database) - if p == nil { - return - } - T.keys.Store(key, p) -} - -func (T *Pools) UnregisterKey(key [8]byte) { - T.keys.Delete(key) -} - -func (T *Pools) LookupKey(key [8]byte) *pool.Pool { - p, _ := T.keys.Load(key) - return p -} - var _ gat.Pools = (*Pools)(nil) diff --git a/lib/gat/modes/zalando_operator_discovery/server.go b/lib/gat/modes/zalando_operator_discovery/server.go index a705590ee5e8549f451aa64b97e8e1800e267266..fdf2fef01c40374871bfaab1683a7b2e29380627 100644 --- a/lib/gat/modes/zalando_operator_discovery/server.go +++ b/lib/gat/modes/zalando_operator_discovery/server.go @@ -153,7 +153,7 @@ func (T *Server) addPostgresql(psql *acidv1.Postgresql) { T.updatePostgresql(nil, psql) } -func (T *Server) addPool(name string, userCreds, serverCreds auth.Credentials, database string) { +func (T *Server) addPool(name string, userCreds, serverCreds auth.Credentials, userUser, serverUser, database string) { d := dialer.Net{ Network: "tcp", Address: fmt.Sprintf("%s.%s.svc.%s:5432", name, T.config.Namespace, T.opConfig.ClusterDomain), @@ -162,6 +162,7 @@ func (T *Server) addPool(name string, userCreds, serverCreds auth.Credentials, d SSLConfig: &tls.Config{ InsecureSkipVerify: true, }, + Username: serverUser, Credentials: serverCreds, Database: database, }, @@ -205,7 +206,7 @@ func (T *Server) addPool(name string, userCreds, serverCreds auth.Credentials, d p.AddRecipe("service", r) - T.pools.Add(userCreds.GetUsername(), database, p) + T.pools.Add(userUser, database, p) } func (T *Server) updatePostgresql(oldPsql *acidv1.Postgresql, newPsql *acidv1.Postgresql) { @@ -305,7 +306,7 @@ func (T *Server) updatePostgresql(oldPsql *acidv1.Postgresql, newPsql *acidv1.Po Username: pair.User, Password: creds.Password, } - T.addPool(details.Name, userCreds, creds, pair.Database) + T.addPool(details.Name, userCreds, creds, pair.User, details.SecretUser, pair.Database) log.Print("added pool username=", pair.User, " database=", pair.Database) } } @@ -361,7 +362,7 @@ func (T *Server) ListenAndServe() error { strutil.MakeCIString("options"), }, SSLConfig: sslConfig, - }, &T.pools) + }, gat.NewKeyedPools(&T.pools)) }) return bank.Wait() diff --git a/lib/gat/pool/dialer/net.go b/lib/gat/pool/dialer/net.go index 490554b3b9be3dc15e6557c25d6e97482a93e0a5..b9949e6eb1f0803dc9dc1318417876df386e5fc0 100644 --- a/lib/gat/pool/dialer/net.go +++ b/lib/gat/pool/dialer/net.go @@ -22,7 +22,11 @@ func (T Net) Dial() (fed.Conn, backends.AcceptParams, error) { return nil, backends.AcceptParams{}, err } conn := fed.WrapNetConn(c) - params, err := backends.Accept(conn, T.AcceptOptions) + ctx := backends.AcceptContext{ + Conn: conn, + Options: T.AcceptOptions, + } + params, err := backends.Accept(&ctx) if err != nil { return nil, backends.AcceptParams{}, err } @@ -43,7 +47,7 @@ func (T Net) Cancel(key [8]byte) error { } // wait for server to close the connection, this means that the server received it ok - _, err = conn.ReadPacket(true) + _, err = conn.ReadPacket(true, nil) if err != nil && !errors.Is(err, io.EOF) { return err } diff --git a/lib/gat/pool/flow.go b/lib/gat/pool/flow.go index 0dce80416b81122000c6e32f8f2cc6fbc56a8c97..3e47c04d84818bab801435dc38f1b073fd5015b1 100644 --- a/lib/gat/pool/flow.go +++ b/lib/gat/pool/flow.go @@ -2,6 +2,7 @@ package pool import ( "pggat/lib/bouncer/backends/v0" + "pggat/lib/fed" packets "pggat/lib/fed/packets/v3.0" "pggat/lib/gat/metrics" "pggat/lib/middleware/middlewares/eqp" @@ -42,6 +43,8 @@ func SyncInitialParameters(options Options, client *Client, server *Server) (cli clientParams := client.GetInitialParameters() serverParams := server.GetInitialParameters() + var packet fed.Packet + for key, value := range clientParams { // skip already set params if serverParams[key] == value { @@ -49,7 +52,8 @@ func SyncInitialParameters(options Options, client *Client, server *Server) (cli Key: key.String(), Value: serverParams[key], } - clientErr = client.GetConn().WritePacket(p.IntoPacket()) + packet = p.IntoPacket(packet) + clientErr = client.GetConn().WritePacket(packet) if clientErr != nil { return } @@ -66,7 +70,8 @@ func SyncInitialParameters(options Options, client *Client, server *Server) (cli Key: key.String(), Value: value, } - clientErr = client.GetConn().WritePacket(p.IntoPacket()) + packet = p.IntoPacket(packet) + clientErr = client.GetConn().WritePacket(packet) if clientErr != nil { return } @@ -75,10 +80,15 @@ func SyncInitialParameters(options Options, client *Client, server *Server) (cli continue } - serverErr = backends.SetParameter(new(backends.Context), server.GetReadWriter(), key, value) + ctx := backends.Context{ + Packet: packet, + Server: server.GetReadWriter(), + } + serverErr = backends.SetParameter(&ctx, key, value) if serverErr != nil { return } + packet = ctx.Packet } for key, value := range serverParams { @@ -93,7 +103,8 @@ func SyncInitialParameters(options Options, client *Client, server *Server) (cli Key: key.String(), Value: value, } - clientErr = client.GetConn().WritePacket(p.IntoPacket()) + packet = p.IntoPacket(packet) + clientErr = client.GetConn().WritePacket(packet) if clientErr != nil { return } diff --git a/lib/gat/pool/pool.go b/lib/gat/pool/pool.go index 4f53b811b8a741d212f92e47984870ea93373cf0..1d668135eee24d51a5e1f048335c5af4b79225ab 100644 --- a/lib/gat/pool/pool.go +++ b/lib/gat/pool/pool.go @@ -29,12 +29,13 @@ type Pool struct { pendingCount atomic.Int64 pending chan struct{} - recipes map[string]*recipe.Recipe - clients map[uuid.UUID]*Client - clientsByKey map[[8]byte]*Client - servers map[uuid.UUID]*Server - serversByRecipe map[string][]*Server - mu sync.RWMutex + recipes map[string]*recipe.Recipe + recipeScaleOrder slices.Sorted[string] + clients map[uuid.UUID]*Client + clientsByKey map[[8]byte]*Client + servers map[uuid.UUID]*Server + serversByRecipe map[string][]*Server + mu sync.RWMutex } func NewPool(options Options) *Pool { @@ -94,6 +95,11 @@ func (T *Pool) AddRecipe(name string, r *recipe.Recipe) { T.recipes = make(map[string]*recipe.Recipe) } T.recipes[name] = r + + // add to front of scale order + T.recipeScaleOrder = T.recipeScaleOrder.Insert(name, func(n string) int { + return len(T.serversByRecipe[n]) + }) }() count := r.AllocateInitial() @@ -124,6 +130,8 @@ func (T *Pool) removeRecipe(name string) { servers := T.serversByRecipe[name] delete(T.serversByRecipe, name) + // remove from recipeScaleOrder + T.recipeScaleOrder = slices.Delete(T.recipeScaleOrder, name) for _, server := range servers { r.Free() @@ -135,15 +143,13 @@ func (T *Pool) scaleUpL0() (string, *recipe.Recipe) { T.mu.RLock() defer T.mu.RUnlock() - for name, r := range T.recipes { + for _, name := range T.recipeScaleOrder { + r := T.recipes[name] if r.Allocate() { return name, r } } - if len(T.servers) > 0 { - return "", nil - } return "", nil } @@ -181,6 +187,10 @@ func (T *Pool) scaleUpL1(name string, r *recipe.Recipe) error { T.serversByRecipe = make(map[string][]*Server) } T.serversByRecipe[name] = append(T.serversByRecipe[name], server) + // update order + T.recipeScaleOrder.Update(slices.Index(T.recipeScaleOrder, name), func(n string) int { + return len(T.serversByRecipe[n]) + }) return server, nil }() @@ -219,7 +229,12 @@ func (T *Pool) removeServerL1(server *Server) { T.pooler.DeleteServer(server.GetID()) _ = server.GetConn().Close() if T.serversByRecipe != nil { - T.serversByRecipe[server.GetRecipe()] = slices.Delete(T.serversByRecipe[server.GetRecipe()], server) + name := server.GetRecipe() + T.serversByRecipe[name] = slices.Delete(T.serversByRecipe[name], server) + // update order + T.recipeScaleOrder.Update(slices.Index(T.recipeScaleOrder, name), func(n string) int { + return len(T.serversByRecipe[n]) + }) } } @@ -253,7 +268,10 @@ func (T *Pool) releaseServer(server *Server) { if T.options.ServerResetQuery != "" { server.SetState(metrics.ConnStateRunningResetQuery, uuid.Nil) - err := backends.QueryString(new(backends.Context), server.GetReadWriter(), T.options.ServerResetQuery) + ctx := backends.Context{ + Server: server.GetReadWriter(), + } + err := backends.QueryString(&ctx, T.options.ServerResetQuery) if err != nil { T.removeServer(server) return @@ -322,6 +340,8 @@ func (T *Pool) serve(client *Client, initialized bool) error { } }() + var packet fed.Packet + if !initialized { server = T.acquireServer(client) @@ -334,7 +354,8 @@ func (T *Pool) serve(client *Client, initialized bool) error { } p := packets.ReadyForQuery('I') - err = client.GetConn().WritePacket(p.IntoPacket()) + packet = p.IntoPacket(packet) + err = client.GetConn().WritePacket(packet) if err != nil { return err } @@ -347,8 +368,7 @@ func (T *Pool) serve(client *Client, initialized bool) error { server = nil } - var packet fed.Packet - packet, err = client.GetConn().ReadPacket(true) + packet, err = client.GetConn().ReadPacket(true, packet) if err != nil { return err } @@ -359,7 +379,7 @@ func (T *Pool) serve(client *Client, initialized bool) error { err, serverErr = Pair(T.options, client, server) } if err == nil && serverErr == nil { - err, serverErr = bouncers.Bounce(client.GetReadWriter(), server.GetReadWriter(), packet) + packet, err, serverErr = bouncers.Bounce(client.GetReadWriter(), server.GetReadWriter(), packet) } if serverErr != nil { return serverErr diff --git a/lib/gat/pool/scaler.go b/lib/gat/pool/scaler.go index 224c37fbcedfed63b74be006e0fd864a444f58a5..ef61951b7f1a83558a883db71606185a8be007fb 100644 --- a/lib/gat/pool/scaler.go +++ b/lib/gat/pool/scaler.go @@ -2,6 +2,7 @@ package pool import ( "time" + "tuxpa.in/a/zlog/log" ) @@ -19,7 +20,7 @@ type Scaler struct { func NewScaler(pool *Pool) *Scaler { s := &Scaler{ pool: pool, - backoff: pool.options.ServerIdleTimeout, + backoff: pool.options.ServerReconnectInitialTime, } if pool.options.ServerIdleTimeout != 0 { diff --git a/lib/gat/pools.go b/lib/gat/pools.go index 850cc4b4c54a4c78e7334e3c35fa09c95e170862..3096dee1f437d55e45a3086b27a1de3bb94f9fb2 100644 --- a/lib/gat/pools.go +++ b/lib/gat/pools.go @@ -3,84 +3,10 @@ package gat import ( "pggat/lib/gat/metrics" "pggat/lib/gat/pool" - "pggat/lib/util/maps" ) type Pools interface { Lookup(user, database string) *pool.Pool ReadMetrics(metrics *metrics.Pools) - - // Key based lookup functions (for cancellation) - - RegisterKey(key [8]byte, user, database string) - UnregisterKey(key [8]byte) - - LookupKey(key [8]byte) *pool.Pool -} - -type mapKey struct { - User string - Database string -} - -type PoolsMap struct { - pools maps.RWLocked[mapKey, *pool.Pool] - keys maps.RWLocked[[8]byte, mapKey] -} - -func (T *PoolsMap) Add(user, database string, pool *pool.Pool) { - T.pools.Store(mapKey{ - User: user, - Database: database, - }, pool) -} - -func (T *PoolsMap) Remove(user, database string) *pool.Pool { - p, _ := T.pools.LoadAndDelete(mapKey{ - User: user, - Database: database, - }) - return p -} - -func (T *PoolsMap) Lookup(user, database string) *pool.Pool { - p, _ := T.pools.Load(mapKey{ - User: user, - Database: database, - }) - return p -} - -func (T *PoolsMap) ReadMetrics(metrics *metrics.Pools) { - T.pools.Range(func(_ mapKey, p *pool.Pool) bool { - p.ReadMetrics(&metrics.Pool) - return true - }) -} - -// key based lookup funcs - -func (T *PoolsMap) RegisterKey(key [8]byte, user, database string) { - T.keys.Store(key, mapKey{ - User: user, - Database: database, - }) -} - -func (T *PoolsMap) UnregisterKey(key [8]byte) { - T.keys.Delete(key) -} - -func (T *PoolsMap) LookupKey(key [8]byte) *pool.Pool { - m, ok := T.keys.Load(key) - if !ok { - return nil - } - p, ok := T.pools.Load(m) - if !ok { - T.keys.Delete(key) - return nil - } - return p } diff --git a/lib/gat/poolsmap.go b/lib/gat/poolsmap.go new file mode 100644 index 0000000000000000000000000000000000000000..d3bf50cd24bc90dcbf27f1f4ba5f6c282805d9b4 --- /dev/null +++ b/lib/gat/poolsmap.go @@ -0,0 +1,46 @@ +package gat + +import ( + "pggat/lib/gat/metrics" + "pggat/lib/gat/pool" + "pggat/lib/util/maps" +) + +type mapKey struct { + User string + Database string +} + +type PoolsMap struct { + pools maps.RWLocked[mapKey, *pool.Pool] +} + +func (T *PoolsMap) Add(user, database string, pool *pool.Pool) { + T.pools.Store(mapKey{ + User: user, + Database: database, + }, pool) +} + +func (T *PoolsMap) Remove(user, database string) *pool.Pool { + p, _ := T.pools.LoadAndDelete(mapKey{ + User: user, + Database: database, + }) + return p +} + +func (T *PoolsMap) Lookup(user, database string) *pool.Pool { + p, _ := T.pools.Load(mapKey{ + User: user, + Database: database, + }) + return p +} + +func (T *PoolsMap) ReadMetrics(metrics *metrics.Pools) { + T.pools.Range(func(_ mapKey, p *pool.Pool) bool { + p.ReadMetrics(&metrics.Pool) + return true + }) +} diff --git a/lib/gsql/client.go b/lib/gsql/client.go index 436acd50d24dacc93c8a5a4e0aa84e29c8b9596f..d119c9e461f90141b32a0e411361dad57b47d5cb 100644 --- a/lib/gsql/client.go +++ b/lib/gsql/client.go @@ -7,6 +7,7 @@ import ( "pggat/lib/fed" "pggat/lib/util/ring" + "pggat/lib/util/slices" ) type batch struct { @@ -57,7 +58,9 @@ func (T *Client) queueNext() bool { return false } -func (T *Client) ReadPacket(typed bool) (fed.Packet, error) { +func (T *Client) ReadPacket(typed bool, buffer fed.Packet) (packet fed.Packet, err error) { + packet = buffer + T.mu.Lock() defer T.mu.Unlock() @@ -75,7 +78,8 @@ func (T *Client) ReadPacket(typed bool) (fed.Packet, error) { } if T.closed { - return nil, io.EOF + err = io.EOF + return } if T.readC == nil { @@ -85,10 +89,13 @@ func (T *Client) ReadPacket(typed bool) (fed.Packet, error) { } if (p.Type() == 0 && typed) || (p.Type() != 0 && !typed) { - return nil, ErrTypedMismatch + err = ErrTypedMismatch + return } - return p, nil + packet = slices.Resize(packet, len(p)) + copy(packet, p) + return } func (T *Client) WritePacket(packet fed.Packet) error { diff --git a/lib/gsql/eq.go b/lib/gsql/eq.go index f03cfe1d92d5c10f821d0f5457e16c38d4b5f1c7..9097b86d93aec0cfd41d9a5820362900dc7490dd 100644 --- a/lib/gsql/eq.go +++ b/lib/gsql/eq.go @@ -20,7 +20,7 @@ func ExtendedQuery(client *Client, result any, query string, args ...any) error parse := packets.Parse{ Query: query, } - pkts = append(pkts, parse.IntoPacket()) + pkts = append(pkts, parse.IntoPacket(nil)) // bind params := make([][]byte, 0, len(args)) @@ -60,17 +60,17 @@ outer: bind := packets.Bind{ ParameterValues: params, } - pkts = append(pkts, bind.IntoPacket()) + pkts = append(pkts, bind.IntoPacket(nil)) // describe describe := packets.Describe{ Which: 'P', } - pkts = append(pkts, describe.IntoPacket()) + pkts = append(pkts, describe.IntoPacket(nil)) // execute execute := packets.Execute{} - pkts = append(pkts, execute.IntoPacket()) + pkts = append(pkts, execute.IntoPacket(nil)) // sync sync := fed.NewPacket(packets.TypeSync) diff --git a/lib/gsql/query.go b/lib/gsql/query.go index 518b55869ae67508d3b72dd60160cf3dc0e996c6..28171fb85151250ff80a37b4c145a53feeed4825 100644 --- a/lib/gsql/query.go +++ b/lib/gsql/query.go @@ -8,7 +8,7 @@ import ( func Query(client *Client, results []any, query string) { var q = packets.Query(query) - client.Do(NewQueryWriter(results...), q.IntoPacket()) + client.Do(NewQueryWriter(results...), q.IntoPacket(nil)) } type QueryWriter struct { diff --git a/lib/gsql/query_test.go b/lib/gsql/query_test.go index 9de3ed4d0e795c572edd089905d14e29fdd7752a..36c1662b437d04793c516526887efb878cb8de7c 100644 --- a/lib/gsql/query_test.go +++ b/lib/gsql/query_test.go @@ -13,8 +13,8 @@ import ( ) type Result struct { - Username string `sql:"0"` - Password *string `sql:"1"` + Username string `sql:"0"` + Password string `sql:"1"` } func TestQuery(t *testing.T) { @@ -25,13 +25,18 @@ func TestQuery(t *testing.T) { return } server := fed.WrapNetConn(s) - _, err = backends.Accept(server, backends.AcceptOptions{ - Credentials: credentials.Cleartext{ + ctx := backends.AcceptContext{ + Conn: server, + Options: backends.AcceptOptions{ Username: "postgres", - Password: "password", + Credentials: credentials.Cleartext{ + Username: "postgres", + Password: "password", + }, + Database: "postgres", }, - Database: "postgres", - }) + } + _, err = backends.Accept(&ctx) if err != nil { t.Error(err) return @@ -39,7 +44,7 @@ func TestQuery(t *testing.T) { var res Result client := new(Client) - err = ExtendedQuery(client, &res, "SELECT usename, passwd FROM pg_shadow WHERE usename=$1", "bob") + err = ExtendedQuery(client, &res, "SELECT usename, passwd FROM pg_shadow WHERE usename=$1", "postgres") if err != nil { t.Error(err) return @@ -49,11 +54,12 @@ func TestQuery(t *testing.T) { t.Error(err) } - initial, err := client.ReadPacket(true) + var initial fed.Packet + initial, err = client.ReadPacket(true, initial) if err != nil { t.Error(err) } - clientErr, serverErr := bouncers.Bounce(client, server, initial) + _, clientErr, serverErr := bouncers.Bounce(client, server, initial) if clientErr != nil { t.Error(clientErr) } diff --git a/lib/middleware/interceptor/interceptor.go b/lib/middleware/interceptor/interceptor.go index b87616f32c5ac94139bdd8a2dbd379aa32ee2c70..338e953ef09320d563f6912f57ccfb741752f11c 100644 --- a/lib/middleware/interceptor/interceptor.go +++ b/lib/middleware/interceptor/interceptor.go @@ -23,19 +23,20 @@ func NewInterceptor(rw fed.Conn, middlewares ...middleware.Middleware) *Intercep } } -func (T *Interceptor) ReadPacket(typed bool) (fed.Packet, error) { +func (T *Interceptor) ReadPacket(typed bool, packet fed.Packet) (fed.Packet, error) { outer: for { - packet, err := T.rw.ReadPacket(typed) + var err error + packet, err = T.rw.ReadPacket(typed, packet) if err != nil { - return nil, err + return packet, err } for _, mw := range T.middlewares { T.context.reset() err = mw.Read(&T.context, packet) if err != nil { - return nil, err + return packet, err } if T.context.cancelled { continue outer diff --git a/lib/middleware/middlewares/eqp/state.go b/lib/middleware/middlewares/eqp/state.go index 7079c6b122126dc285384921071033c1dfc98a48..dfa0607e2503945328548af35e68235872845845 100644 --- a/lib/middleware/middlewares/eqp/state.go +++ b/lib/middleware/middlewares/eqp/state.go @@ -1,6 +1,7 @@ package eqp import ( + "bytes" "hash/maphash" "pggat/lib/fed" @@ -24,7 +25,7 @@ func MakePreparedStatement(packet fed.Packet) PreparedStatement { var res PreparedStatement packet.ReadString(&res.Target) - res.Packet = packet + res.Packet = bytes.Clone(packet) res.Hash = maphash.Bytes(seed, packet.Payload()) return res @@ -42,7 +43,7 @@ func MakePortal(packet fed.Packet) Portal { var res Portal packet.ReadString(&res.Target) - res.Packet = packet + res.Packet = bytes.Clone(packet) return res } diff --git a/lib/middleware/middlewares/eqp/sync.go b/lib/middleware/middlewares/eqp/sync.go index 818b859979c3dedfcfe83e58975dbf33c06c3384..2685b65b7e7dd9df884e3891056992de746decb2 100644 --- a/lib/middleware/middlewares/eqp/sync.go +++ b/lib/middleware/middlewares/eqp/sync.go @@ -16,12 +16,15 @@ func Sync(c *Client, server fed.ReadWriter, s *Server) error { needsBackendSync = true } + var packet fed.Packet + for name := range s.state.portals { p := packets.Close{ Which: 'P', Target: name, } - if err := server.WritePacket(p.IntoPacket()); err != nil { + packet = p.IntoPacket(packet) + if err := server.WritePacket(packet); err != nil { return err } } @@ -44,7 +47,8 @@ func Sync(c *Client, server fed.ReadWriter, s *Server) error { Which: 'S', Target: name, } - if err := server.WritePacket(p.IntoPacket()); err != nil { + packet = p.IntoPacket(packet) + if err := server.WritePacket(packet); err != nil { return err } @@ -79,7 +83,11 @@ func Sync(c *Client, server fed.ReadWriter, s *Server) error { } if needsBackendSync { - _, err := backends.Sync(new(backends.Context), server) + ctx := backends.Context{ + Packet: packet, + Server: server, + } + _, err := backends.Sync(&ctx) return err } diff --git a/lib/middleware/middlewares/ps/sync.go b/lib/middleware/middlewares/ps/sync.go index 3dc6ca50318a76506d5138b3d673e6fff4bdc773..006df33e11808640b7bd5d265da281dcc377044e 100644 --- a/lib/middleware/middlewares/ps/sync.go +++ b/lib/middleware/middlewares/ps/sync.go @@ -12,13 +12,16 @@ func sync(tracking []strutil.CIString, client fed.ReadWriter, c *Client, server value, hasValue := c.parameters[name] expected, hasExpected := s.parameters[name] + var packet fed.Packet + if value == expected { if !c.synced { ps := packets.ParameterStatus{ Key: name.String(), Value: expected, } - if err := client.WritePacket(ps.IntoPacket()); err != nil { + packet = ps.IntoPacket(packet) + if err := client.WritePacket(packet); err != nil { return err } } @@ -28,9 +31,14 @@ func sync(tracking []strutil.CIString, client fed.ReadWriter, c *Client, server var doSet bool if hasValue && slices.Contains(tracking, name) { - if err := backends.SetParameter(&backends.Context{}, server, name, value); err != nil { + ctx := backends.Context{ + Packet: packet, + Server: server, + } + if err := backends.SetParameter(&ctx, name, value); err != nil { return err } + packet = ctx.Packet if s.parameters == nil { s.parameters = make(map[strutil.CIString]string) } @@ -46,7 +54,8 @@ func sync(tracking []strutil.CIString, client fed.ReadWriter, c *Client, server Key: name.String(), Value: expected, } - if err := client.WritePacket(ps.IntoPacket()); err != nil { + packet = ps.IntoPacket(packet) + if err := client.WritePacket(packet); err != nil { return err } } diff --git a/lib/util/slices/index.go b/lib/util/slices/index.go new file mode 100644 index 0000000000000000000000000000000000000000..dc2af77934c636bc1679d9e7c0665f7afd6c914f --- /dev/null +++ b/lib/util/slices/index.go @@ -0,0 +1,10 @@ +package slices + +func Index[T comparable](haystack []T, needle T) int { + for i, v := range haystack { + if needle == v { + return i + } + } + return -1 +} diff --git a/lib/util/slices/remove.go b/lib/util/slices/remove.go index e0c239a653295225d72f58619285524a12bd6131..1848711f5065da4f78b972da1e0675d17ecd24bc 100644 --- a/lib/util/slices/remove.go +++ b/lib/util/slices/remove.go @@ -4,26 +4,22 @@ package slices // with length-1. The original slice will contain all items (though in a different order), and the new slice will contain all // but item. func Remove[T comparable](slice []T, item T) []T { - for i, s := range slice { - if s == item { - copy(slice[i:], slice[i+1:]) - slice[len(slice)-1] = item - return slice[:len(slice)-1] - } + i := Index(slice, item) + if i == -1 { + return slice } - - return slice + copy(slice[i:], slice[i+1:]) + slice[len(slice)-1] = item + return slice[:len(slice)-1] } // Delete is similar to Remove but leaves a *new(T) in the old slice, allowing the value to be GC'd func Delete[T comparable](slice []T, item T) []T { - for i, s := range slice { - if s == item { - copy(slice[i:], slice[i+1:]) - slice[len(slice)-1] = *new(T) - return slice[:len(slice)-1] - } + i := Index(slice, item) + if i == -1 { + return slice } - - return slice + copy(slice[i:], slice[i+1:]) + slice[len(slice)-1] = *new(T) + return slice[:len(slice)-1] } diff --git a/lib/util/slices/sorted.go b/lib/util/slices/sorted.go new file mode 100644 index 0000000000000000000000000000000000000000..2c8a4eb87414d89701e7aa517d0cd210d0336672 --- /dev/null +++ b/lib/util/slices/sorted.go @@ -0,0 +1,55 @@ +package slices + +// Sorted is a sorted slice. As long as all items are inserted by Insert, updated by Update, and removed by Delete, +// this slice will stay sorted +type Sorted[V any] []V + +func (T Sorted[V]) Insert(value V, sorter func(V) int) Sorted[V] { + key := sorter(value) + for i, v := range T { + if sorter(v) < key { + continue + } + + res := append(T, *new(V)) + copy(res[i+1:], res[i:]) + res[i] = value + return res + } + + return append(T, value) +} + +func (T Sorted[V]) Update(index int, sorter func(V) int) { + value := T[index] + key := sorter(value) + + for i, v := range T { + switch { + case i < index: + if sorter(v) < key { + continue + } + + // move all up by one, move from index to i + copy(T[i+1:], T[i:index]) + T[i] = value + return + case i > index: + if sorter(v) < key { + continue + } + + // move all down by one, move from index to i + copy(T[index:], T[index+1:i]) + T[i-1] = value + return + default: + continue + } + } + + // move all down by one, move from index to i + copy(T[index:], T[index+1:]) + T[len(T)-1] = value +} diff --git a/lib/util/slices/sorted_test.go b/lib/util/slices/sorted_test.go new file mode 100644 index 0000000000000000000000000000000000000000..49e19e073bf939cd94dfaaed2f793cfde6936b31 --- /dev/null +++ b/lib/util/slices/sorted_test.go @@ -0,0 +1,79 @@ +package slices + +import ( + "sort" + "testing" + + "tuxpa.in/a/zlog/log" +) + +func TestSorted_Insert(t *testing.T) { + sorter := func(v string) int { + return len(v) + } + + expected := []string{ + "test", + "abc", + "this is a long string", + "gjkdfjgksg", + "retre", + "abd", + "def", + "ttierotiretiiret34t43t34534", + } + + var x Sorted[string] + for _, v := range expected { + x = x.Insert(v, sorter) + } + + if !sort.SliceIsSorted(x, func(i, j int) bool { + return sorter(x[i]) < sorter(x[j]) + }) { + t.Errorf("slice isn't sorted: %#v", x) + } +} + +func TestSorted_Update(t *testing.T) { + values := map[string]int{ + "abc": 43, + "def": 32, + "cool": 594390069, + "amazing": -432, + "i hope this works": 32, + } + + sorter := func(v string) int { + return values[v] + } + + var x Sorted[string] + for v := range values { + x = x.Insert(v, sorter) + } + + if !sort.SliceIsSorted(x, func(i, j int) bool { + return sorter(x[i]) < sorter(x[j]) + }) { + t.Errorf("slice isn't sorted: %#v", x) + } + + log.Printf("%#v", x) + + values["cool"] = -10 + x.Update(Index(x, "cool"), sorter) + values["amazing"] = 543543 + x.Update(Index(x, "amazing"), sorter) + x.Update(Index(x, "abc"), sorter) + values["i hope this works"] = 44 + x.Update(Index(x, "i hope this works"), sorter) + values["abc"] = 31 + x.Update(Index(x, "abc"), sorter) + + if !sort.SliceIsSorted(x, func(i, j int) bool { + return sorter(x[i]) < sorter(x[j]) + }) { + t.Errorf("slice isn't sorted: %#v", x) + } +} diff --git a/test/capturer.go b/test/capturer.go index cc301226b82b5513ddce95a696e9be0f62c0f839..c319ed2728ce8c8febb412ad7270b9a2724cfed2 100644 --- a/test/capturer.go +++ b/test/capturer.go @@ -13,7 +13,7 @@ type Capturer struct { } func (T *Capturer) WritePacket(packet fed.Packet) error { - T.Packets = append(T.Packets, packet) + T.Packets = append(T.Packets, bytes.Clone(packet)) return nil } diff --git a/test/runner.go b/test/runner.go index 4cd6ce0f8f1b44c49193f0fada6cd6672b0fdb2a..4267435782bc965ac263dfb749b1777fdebf1ff3 100644 --- a/test/runner.go +++ b/test/runner.go @@ -34,7 +34,7 @@ func (T *Runner) prepare(client *gsql.Client, until int) []Capturer { switch v := x.(type) { case inst.SimpleQuery: q := packets.Query(v) - client.Do(&results[i], q.IntoPacket()) + client.Do(&results[i], q.IntoPacket(nil)) case inst.Sync: client.Do(&results[i], fed.NewPacket(packets.TypeSync)) case inst.Parse: @@ -42,45 +42,45 @@ func (T *Runner) prepare(client *gsql.Client, until int) []Capturer { Destination: v.Destination, Query: v.Query, } - client.Do(&results[i], p.IntoPacket()) + client.Do(&results[i], p.IntoPacket(nil)) case inst.Bind: p := packets.Bind{ Destination: v.Destination, Source: v.Source, } - client.Do(&results[i], p.IntoPacket()) + client.Do(&results[i], p.IntoPacket(nil)) case inst.DescribePortal: p := packets.Describe{ Which: 'P', Target: string(v), } - client.Do(&results[i], p.IntoPacket()) + client.Do(&results[i], p.IntoPacket(nil)) case inst.DescribePreparedStatement: p := packets.Describe{ Which: 'S', Target: string(v), } - client.Do(&results[i], p.IntoPacket()) + client.Do(&results[i], p.IntoPacket(nil)) case inst.Execute: p := packets.Execute{ Target: string(v), } - client.Do(&results[i], p.IntoPacket()) + client.Do(&results[i], p.IntoPacket(nil)) case inst.ClosePortal: p := packets.Close{ Which: 'P', Target: string(v), } - client.Do(&results[i], p.IntoPacket()) + client.Do(&results[i], p.IntoPacket(nil)) case inst.ClosePreparedStatement: p := packets.Close{ Which: 'S', Target: string(v), } - client.Do(&results[i], p.IntoPacket()) + client.Do(&results[i], p.IntoPacket(nil)) case inst.CopyData: p := packets.CopyData(v) - client.Do(&results[i], p.IntoPacket()) + client.Do(&results[i], p.IntoPacket(nil)) case inst.CopyDone: client.Do(&results[i], fed.NewPacket(packets.TypeCopyDone)) } @@ -100,7 +100,7 @@ func (T *Runner) runModeL1(dialer dialer.Dialer, client *gsql.Client) error { for { var p fed.Packet - p, err = client.ReadPacket(true) + p, err = client.ReadPacket(true, p) if err != nil { if errors.Is(err, io.EOF) { break @@ -108,7 +108,7 @@ func (T *Runner) runModeL1(dialer dialer.Dialer, client *gsql.Client) error { return err } - clientErr, serverErr := bouncers.Bounce(client, server, p) + _, clientErr, serverErr := bouncers.Bounce(client, server, p) if clientErr != nil { return clientErr } diff --git a/test/tester_test.go b/test/tester_test.go index 3555fa5c9066c487205dea478054a536a5bbf3b5..7aacba13b78d48e7897a33cb2c68983c7590e4ff 100644 --- a/test/tester_test.go +++ b/test/tester_test.go @@ -6,6 +6,9 @@ import ( "fmt" "net" _ "net/http/pprof" + "strconv" + "testing" + "pggat/lib/auth" "pggat/lib/auth/credentials" "pggat/lib/bouncer/backends/v0" @@ -18,8 +21,6 @@ import ( "pggat/lib/gat/pool/recipe" "pggat/test" "pggat/test/tests" - "strconv" - "testing" ) func daisyChain(creds auth.Credentials, control dialer.Net, n int) (dialer.Net, error) { @@ -49,7 +50,7 @@ func daisyChain(creds auth.Credentials, control dialer.Net, n int) (dialer.Net, port := listener.Listener.Addr().(*net.TCPAddr).Port go func() { - err := gat.Serve(listener, &g) + err := gat.Serve(listener, gat.NewKeyedPools(&g)) if err != nil { panic(err) } @@ -59,6 +60,7 @@ func daisyChain(creds auth.Credentials, control dialer.Net, n int) (dialer.Net, Network: "tcp", Address: ":" + strconv.Itoa(port), AcceptOptions: backends.AcceptOptions{ + Username: "runner", Credentials: creds, Database: "pool", }, @@ -73,6 +75,7 @@ func TestTester(t *testing.T) { Network: "tcp", Address: "localhost:5432", AcceptOptions: backends.AcceptOptions{ + Username: "postgres", Credentials: credentials.Cleartext{ Username: "postgres", Password: "password", @@ -127,7 +130,7 @@ func TestTester(t *testing.T) { port := listener.Listener.Addr().(*net.TCPAddr).Port go func() { - err := gat.Serve(listener, &g) + err := gat.Serve(listener, gat.NewKeyedPools(&g)) if err != nil { t.Error(err) } @@ -137,6 +140,7 @@ func TestTester(t *testing.T) { Network: "tcp", Address: ":" + strconv.Itoa(port), AcceptOptions: backends.AcceptOptions{ + Username: "runner", Credentials: creds, Database: "transaction", }, @@ -145,6 +149,7 @@ func TestTester(t *testing.T) { Network: "tcp", Address: ":" + strconv.Itoa(port), AcceptOptions: backends.AcceptOptions{ + Username: "runner", Credentials: creds, Database: "session", },