diff --git a/cmd/caddygat/main.go b/cmd/caddygat/main.go index 89d5bbb2e206008d301adb56e6c1571ca0bb4c4f..775c9e1504af14dd82da6544fb16555581d3ef83 100644 --- a/cmd/caddygat/main.go +++ b/cmd/caddygat/main.go @@ -4,6 +4,12 @@ import ( caddycmd "github.com/caddyserver/caddy/v2/cmd" _ "gfx.cafe/gfx/pggat/contrib/caddy" + _ "gfx.cafe/gfx/pggat/lib/gat/modules/cloud_sql_discovery" + _ "gfx.cafe/gfx/pggat/lib/gat/modules/digitalocean_discovery" + _ "gfx.cafe/gfx/pggat/lib/gat/modules/pgbouncer" + _ "gfx.cafe/gfx/pggat/lib/gat/modules/ssl_endpoint" + _ "gfx.cafe/gfx/pggat/lib/gat/modules/zalando" + _ "gfx.cafe/gfx/pggat/lib/gat/modules/zalando_operator_discovery" ) func main() { diff --git a/cmd/cgat/main.go b/cmd/cgat/main.go index 3963519cda736d7e97a9417e97d71e339dc6b256..4bc38292811216020c39cb94f5985ecc34b72cbf 100644 --- a/cmd/cgat/main.go +++ b/cmd/cgat/main.go @@ -5,6 +5,8 @@ import ( "net/http" _ "net/http/pprof" "os" + "os/signal" + "syscall" "time" "tuxpa.in/a/zlog/log" @@ -26,37 +28,49 @@ func loadModule(mode string) (gat.Module, error) { if err != nil { return nil, err } - return pgbouncer.NewModule(conf) + return &pgbouncer.Module{ + Config: conf, + }, nil case "pgbouncer": conf, err := pgbouncer.Load(os.Args[1]) if err != nil { return nil, err } - return pgbouncer.NewModule(conf) + return &pgbouncer.Module{ + Config: conf, + }, nil case "pgbouncer_spilo": conf, err := zalando.Load() if err != nil { return nil, err } - return zalando.NewModule(conf) + return &zalando.Module{ + Config: conf, + }, nil case "zalando_kubernetes_operator": conf, err := zalando_operator_discovery.Load() if err != nil { return nil, err } - return zalando_operator_discovery.NewModule(conf) + return &zalando_operator_discovery.Module{ + Config: conf, + }, nil case "google_cloud_sql": conf, err := cloud_sql_discovery.Load() if err != nil { return nil, err } - return cloud_sql_discovery.NewModule(conf) + return &cloud_sql_discovery.Module{ + Config: conf, + }, nil case "digitalocean_databases": conf, err := digitalocean_discovery.Load() if err != nil { return nil, err } - return digitalocean_discovery.NewModule(conf) + return &digitalocean_discovery.Module{ + Config: conf, + }, nil default: return nil, errors.New("Unknown PGGAT_RUN_MODE: " + mode) } @@ -75,6 +89,22 @@ func main() { log.Printf("Starting pggat (%s)...", runMode) var server gat.Server + defer func() { + if err := server.Stop(); err != nil { + log.Printf("error stopping: %v", err) + } + }() + + c := make(chan os.Signal, 2) + signal.Notify(c, os.Interrupt, syscall.SIGTERM) + + go func() { + <-c + + if err := server.Stop(); err != nil { + log.Printf("error stopping: %v", err) + } + }() // load and add main module module, err := loadModule(runMode) @@ -84,11 +114,7 @@ func main() { server.AddModule(module) // back up ssl endpoint (for modules that don't have endpoints by default such as discovery) - ep, err := ssl_endpoint.NewModule() - if err != nil { - panic(err) - } - server.AddModule(ep) + server.AddModule(&ssl_endpoint.Module{}) go func() { var m metrics.Server @@ -100,7 +126,7 @@ func main() { } }() - err = server.ListenAndServe() + err = server.Start() if err != nil { panic(err) } diff --git a/lib/gat/listener.go b/lib/gat/exposed.go similarity index 64% rename from lib/gat/listener.go rename to lib/gat/exposed.go index 71893dbab6362b65b1fce73af27ccd8384fd7370..fd1556207a408b1cf933602253e450d62b8f46b4 100644 --- a/lib/gat/listener.go +++ b/lib/gat/exposed.go @@ -1,6 +1,6 @@ package gat -type Listener interface { +type Exposed interface { Module Endpoints() []Endpoint diff --git a/lib/gat/modules/cloud_sql_discovery/module.go b/lib/gat/modules/cloud_sql_discovery/module.go index 3cee645f095d1a240925b7da024fa3767904a6b5..d904096f2c59cdeb69996bd1090f1814c4b94443 100644 --- a/lib/gat/modules/cloud_sql_discovery/module.go +++ b/lib/gat/modules/cloud_sql_discovery/module.go @@ -9,29 +9,38 @@ import ( "gfx.cafe/gfx/pggat/lib/util/strutil" ) -func NewModule(config Config) (*discovery.Module, error) { - d, err := NewDiscoverer(config) +type Module struct { + Config + + discovery.Module `json:"-"` +} + +func (T *Module) Start() error { + d, err := NewDiscoverer(T.Config) if err != nil { - return nil, err + return err } - return discovery.NewModule(discovery.Config{ - ReconcilePeriod: 5 * time.Minute, - Discoverer: d, - ServerSSLMode: bouncer.SSLModePrefer, - ServerSSLConfig: &tls.Config{ - InsecureSkipVerify: true, + T.Module = discovery.Module{ + Config: discovery.Config{ + ReconcilePeriod: 5 * time.Minute, + Discoverer: d, + ServerSSLMode: bouncer.SSLModePrefer, + ServerSSLConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + ServerReconnectInitialTime: 5 * time.Second, + ServerReconnectMaxTime: 5 * time.Second, + ServerIdleTimeout: 5 * time.Minute, + TrackedParameters: []strutil.CIString{ + strutil.MakeCIString("client_encoding"), + strutil.MakeCIString("datestyle"), + strutil.MakeCIString("timezone"), + strutil.MakeCIString("standard_conforming_strings"), + strutil.MakeCIString("application_name"), + }, + PoolMode: "transaction", // TODO(garet) }, - ServerReconnectInitialTime: 5 * time.Second, - ServerReconnectMaxTime: 5 * time.Second, - ServerIdleTimeout: 5 * time.Minute, - TrackedParameters: []strutil.CIString{ - strutil.MakeCIString("client_encoding"), - strutil.MakeCIString("datestyle"), - strutil.MakeCIString("timezone"), - strutil.MakeCIString("standard_conforming_strings"), - strutil.MakeCIString("application_name"), - }, - PoolMode: "transaction", // TODO(garet) - }) + } + return T.Module.Start() } diff --git a/lib/gat/modules/digitalocean_discovery/module.go b/lib/gat/modules/digitalocean_discovery/module.go index 24dc545c2337b33f84642704dba93514af4cedff..d004dd40a4e28735f56e423cc39d9062a9428eec 100644 --- a/lib/gat/modules/digitalocean_discovery/module.go +++ b/lib/gat/modules/digitalocean_discovery/module.go @@ -9,29 +9,38 @@ import ( "gfx.cafe/gfx/pggat/lib/util/strutil" ) -func NewModule(config Config) (*discovery.Module, error) { - d, err := NewDiscoverer(config) +type Module struct { + Config + + discovery.Module `json:"-"` +} + +func (T *Module) Start() error { + d, err := NewDiscoverer(T.Config) if err != nil { - return nil, err + return err } - return discovery.NewModule(discovery.Config{ - ReconcilePeriod: 5 * time.Minute, - Discoverer: d, - ServerSSLMode: bouncer.SSLModeRequire, - ServerSSLConfig: &tls.Config{ - InsecureSkipVerify: true, + T.Module = discovery.Module{ + Config: discovery.Config{ + ReconcilePeriod: 5 * time.Minute, + Discoverer: d, + ServerSSLMode: bouncer.SSLModeRequire, + ServerSSLConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + ServerReconnectInitialTime: 5 * time.Second, + ServerReconnectMaxTime: 5 * time.Second, + ServerIdleTimeout: 5 * time.Minute, + TrackedParameters: []strutil.CIString{ + strutil.MakeCIString("client_encoding"), + strutil.MakeCIString("datestyle"), + strutil.MakeCIString("timezone"), + strutil.MakeCIString("standard_conforming_strings"), + strutil.MakeCIString("application_name"), + }, + PoolMode: "transaction", // TODO(garet) }, - ServerReconnectInitialTime: 5 * time.Second, - ServerReconnectMaxTime: 5 * time.Second, - ServerIdleTimeout: 5 * time.Minute, - TrackedParameters: []strutil.CIString{ - strutil.MakeCIString("client_encoding"), - strutil.MakeCIString("datestyle"), - strutil.MakeCIString("timezone"), - strutil.MakeCIString("standard_conforming_strings"), - strutil.MakeCIString("application_name"), - }, - PoolMode: "transaction", // TODO(garet) - }) + } + return T.Module.Start() } diff --git a/lib/gat/modules/discovery/module.go b/lib/gat/modules/discovery/module.go index 8962fa8896fa9bd629411d1ae5e0cc7252cece95..5a9068752ef3279d1e92188d4cd49300a80e60a2 100644 --- a/lib/gat/modules/discovery/module.go +++ b/lib/gat/modules/discovery/module.go @@ -1,6 +1,7 @@ package discovery import ( + "errors" "sync" "time" @@ -19,7 +20,9 @@ import ( ) type Module struct { - config Config + Config + + closed chan struct{} // this is fine to have no locking because it is only accessed by discoverLoop clusters map[string]Cluster @@ -28,15 +31,25 @@ type Module struct { mu sync.RWMutex } -func NewModule(config Config) (*Module, error) { - m := &Module{ - config: config, +func (T *Module) Start() error { + if T.closed != nil { + return errors.New("start called multiple times") + } + T.closed = make(chan struct{}) + + if err := T.reconcile(); err != nil { + return err } - if err := m.reconcile(); err != nil { - return nil, err + go T.discoverLoop() + return nil +} + +func (T *Module) Stop() error { + if T.closed == nil { + return errors.New("discoverer not running") } - go m.discoverLoop() - return m, nil + close(T.closed) + return nil } func (T *Module) replicaUsername(username string) string { @@ -51,32 +64,32 @@ func (T *Module) creds(user User) (primary, replica auth.Credentials) { func (T *Module) backendAcceptOptions(username string, creds auth.Credentials, database string) recipe.BackendAcceptOptions { return recipe.BackendAcceptOptions{ - SSLMode: T.config.ServerSSLMode, - SSLConfig: T.config.ServerSSLConfig, + SSLMode: T.ServerSSLMode, + SSLConfig: T.ServerSSLConfig, Username: username, Credentials: creds, Database: database, - StartupParameters: T.config.ServerStartupParameters, + StartupParameters: T.ServerStartupParameters, } } func (T *Module) poolOptions(creds auth.Credentials) pool.Options { options := pool.Options{ Credentials: creds, - ServerReconnectInitialTime: T.config.ServerReconnectInitialTime, - ServerReconnectMaxTime: T.config.ServerReconnectMaxTime, - ServerIdleTimeout: T.config.ServerIdleTimeout, - TrackedParameters: T.config.TrackedParameters, - ServerResetQuery: T.config.ServerResetQuery, + ServerReconnectInitialTime: T.ServerReconnectInitialTime, + ServerReconnectMaxTime: T.ServerReconnectMaxTime, + ServerIdleTimeout: T.ServerIdleTimeout, + TrackedParameters: T.TrackedParameters, + ServerResetQuery: T.ServerResetQuery, } - switch T.config.PoolMode { + switch T.PoolMode { case "session": options = session.Apply(options) case "transaction": options = transaction.Apply(options) default: - log.Printf("unknown pool mode: %s", T.config.PoolMode) + log.Printf("unknown pool mode: %s", T.PoolMode) } return options @@ -385,7 +398,7 @@ func (T *Module) removed(id string) { } func (T *Module) reconcile() error { - clusters, err := T.config.Discoverer.Clusters() + clusters, err := T.Discoverer.Clusters() if err != nil { return err } @@ -415,19 +428,19 @@ outer: func (T *Module) discoverLoop() { var reconcile <-chan time.Time - if T.config.ReconcilePeriod != 0 { - r := time.NewTicker(T.config.ReconcilePeriod) + if T.ReconcilePeriod != 0 { + r := time.NewTicker(T.ReconcilePeriod) defer r.Stop() reconcile = r.C } for { select { - case cluster := <-T.config.Discoverer.Added(): + case cluster := <-T.Discoverer.Added(): T.added(cluster) - case id := <-T.config.Discoverer.Removed(): + case id := <-T.Discoverer.Removed(): T.removed(id) - case next := <-T.config.Discoverer.Updated(): + case next := <-T.Discoverer.Updated(): T.updated(T.clusters[next.ID], next) case <-reconcile: err := T.reconcile() diff --git a/lib/gat/modules/pgbouncer/module.go b/lib/gat/modules/pgbouncer/module.go index 1acb9c836fc2a8bf801f773fc3165cf39572faf3..53ab543ed2ac364a41888e92d0e6153baeed8885 100644 --- a/lib/gat/modules/pgbouncer/module.go +++ b/lib/gat/modules/pgbouncer/module.go @@ -31,29 +31,23 @@ type authQueryResult struct { } type Module struct { - config Config + Config pools maps.TwoKey[string, string, *gat.Pool] } -func NewModule(config Config) (*Module, error) { - return &Module{ - config: config, - }, nil -} - func (T *Module) getPassword(user, database string) (string, bool) { // try to get password - password, ok := T.config.PgBouncer.AuthFile[user] + password, ok := T.PgBouncer.AuthFile[user] if !ok { // try to run auth query - if T.config.PgBouncer.AuthQuery == "" { + if T.PgBouncer.AuthQuery == "" { return "", false } - authUser := T.config.Databases[database].AuthUser + authUser := T.Databases[database].AuthUser if authUser == "" { - authUser = T.config.PgBouncer.AuthUser + authUser = T.PgBouncer.AuthUser if authUser == "" { return "", false } @@ -66,7 +60,7 @@ func (T *Module) getPassword(user, database string) (string, bool) { var result authQueryResult client := new(gsql.Client) - err := gsql.ExtendedQuery(client, &result, T.config.PgBouncer.AuthQuery, user) + err := gsql.ExtendedQuery(client, &result, T.PgBouncer.AuthQuery, user) if err != nil { log.Println("auth query failed:", err) return "", false @@ -94,10 +88,10 @@ func (T *Module) getPassword(user, database string) (string, bool) { } func (T *Module) tryCreate(user, database string) *gat.Pool { - db, ok := T.config.Databases[database] + db, ok := T.Databases[database] if !ok { // try wildcard - db, ok = T.config.Databases["*"] + db, ok = T.Databases["*"] if !ok { return nil } @@ -116,13 +110,13 @@ func (T *Module) tryCreate(user, database string) *gat.Pool { serverDatabase = database } - configUser := T.config.Users[user] + configUser := T.Users[user] poolMode := db.PoolMode if poolMode == "" { poolMode = configUser.PoolMode if poolMode == "" { - poolMode = T.config.PgBouncer.PoolMode + poolMode = T.PgBouncer.PoolMode } } @@ -132,15 +126,15 @@ func (T *Module) tryCreate(user, database string) *gat.Pool { strutil.MakeCIString("timezone"), strutil.MakeCIString("standard_conforming_strings"), strutil.MakeCIString("application_name"), - }, T.config.PgBouncer.TrackExtraParameters...) + }, T.PgBouncer.TrackExtraParameters...) - serverLoginRetry := time.Duration(T.config.PgBouncer.ServerLoginRetry * float64(time.Second)) + serverLoginRetry := time.Duration(T.PgBouncer.ServerLoginRetry * float64(time.Second)) poolOptions := pool.Options{ Credentials: creds, TrackedParameters: trackedParameters, - ServerResetQuery: T.config.PgBouncer.ServerResetQuery, - ServerIdleTimeout: time.Duration(T.config.PgBouncer.ServerIdleTimeout * float64(time.Second)), + ServerResetQuery: T.PgBouncer.ServerResetQuery, + ServerIdleTimeout: time.Duration(T.PgBouncer.ServerIdleTimeout * float64(time.Second)), ServerReconnectInitialTime: serverLoginRetry, } @@ -148,7 +142,7 @@ func (T *Module) tryCreate(user, database string) *gat.Pool { case PoolModeSession: poolOptions = session.Apply(poolOptions) case PoolModeTransaction: - if T.config.PgBouncer.ServerResetQueryAlways == 0 { + if T.PgBouncer.ServerResetQueryAlways == 0 { poolOptions.ServerResetQuery = "" } poolOptions = transaction.Apply(poolOptions) @@ -168,7 +162,7 @@ func (T *Module) tryCreate(user, database string) *gat.Pool { } acceptOptions := backends.AcceptOptions{ - SSLMode: T.config.PgBouncer.ServerTLSSSLMode, + SSLMode: T.PgBouncer.ServerTLSSSLMode, SSLConfig: &tls.Config{ InsecureSkipVerify: true, // TODO(garet) }, @@ -219,10 +213,10 @@ func (T *Module) tryCreate(user, database string) *gat.Pool { MaxConnections: db.MaxDBConnections, } if recipeOptions.MinConnections == 0 { - recipeOptions.MinConnections = T.config.PgBouncer.MinPoolSize + recipeOptions.MinConnections = T.PgBouncer.MinPoolSize } if recipeOptions.MaxConnections == 0 { - recipeOptions.MaxConnections = T.config.PgBouncer.MaxDBConnections + recipeOptions.MaxConnections = T.PgBouncer.MaxDBConnections } r := recipe.NewRecipe(recipeOptions) @@ -255,12 +249,12 @@ func (T *Module) Endpoints() []gat.Endpoint { strutil.MakeCIString("timezone"), strutil.MakeCIString("standard_conforming_strings"), strutil.MakeCIString("application_name"), - }, T.config.PgBouncer.TrackExtraParameters...) + }, T.PgBouncer.TrackExtraParameters...) - allowedStartupParameters := append(trackedParameters, T.config.PgBouncer.IgnoreStartupParameters...) + allowedStartupParameters := append(trackedParameters, T.PgBouncer.IgnoreStartupParameters...) var sslConfig *tls.Config - if T.config.PgBouncer.ClientTLSCertFile != "" && T.config.PgBouncer.ClientTLSKeyFile != "" { - certificate, err := tls.LoadX509KeyPair(T.config.PgBouncer.ClientTLSCertFile, T.config.PgBouncer.ClientTLSKeyFile) + if T.PgBouncer.ClientTLSCertFile != "" && T.PgBouncer.ClientTLSKeyFile != "" { + certificate, err := tls.LoadX509KeyPair(T.PgBouncer.ClientTLSCertFile, T.PgBouncer.ClientTLSKeyFile) if err != nil { log.Printf("error loading X509 keypair: %v", err) } else { @@ -273,20 +267,20 @@ func (T *Module) Endpoints() []gat.Endpoint { } acceptOptions := frontends.AcceptOptions{ - SSLRequired: T.config.PgBouncer.ClientTLSSSLMode.IsRequired(), + SSLRequired: T.PgBouncer.ClientTLSSSLMode.IsRequired(), SSLConfig: sslConfig, AllowedStartupOptions: allowedStartupParameters, } var endpoints []gat.Endpoint - if T.config.PgBouncer.ListenAddr != "" { - listenAddr := T.config.PgBouncer.ListenAddr + if T.PgBouncer.ListenAddr != "" { + listenAddr := T.PgBouncer.ListenAddr if listenAddr == "*" { listenAddr = "" } - listen := net.JoinHostPort(listenAddr, strconv.Itoa(T.config.PgBouncer.ListenPort)) + listen := net.JoinHostPort(listenAddr, strconv.Itoa(T.PgBouncer.ListenPort)) endpoints = append(endpoints, gat.Endpoint{ Network: "tcp", @@ -296,8 +290,8 @@ func (T *Module) Endpoints() []gat.Endpoint { } // listen on unix socket - dir := T.config.PgBouncer.UnixSocketDir - port := T.config.PgBouncer.ListenPort + dir := T.PgBouncer.UnixSocketDir + port := T.PgBouncer.ListenPort if !strings.HasSuffix(dir, "/") { dir = dir + "/" @@ -317,4 +311,4 @@ func (T *Module) GatModule() {} var _ gat.Module = (*Module)(nil) var _ gat.Provider = (*Module)(nil) -var _ gat.Listener = (*Module)(nil) +var _ gat.Exposed = (*Module)(nil) diff --git a/lib/gat/modules/raw_pools/module.go b/lib/gat/modules/raw_pools/module.go index 14ab7262f5316a602d0f5aa602491f324b1d8309..5272591deaa413e7e040af0fcb901dc663646b58 100644 --- a/lib/gat/modules/raw_pools/module.go +++ b/lib/gat/modules/raw_pools/module.go @@ -14,10 +14,6 @@ type Module struct { mu sync.RWMutex } -func NewModule() (*Module, error) { - return &Module{}, nil -} - func (T *Module) GatModule() {} func (T *Module) Add(user, database string, p *pool.Pool) { diff --git a/lib/gat/modules/ssl_endpoint/module.go b/lib/gat/modules/ssl_endpoint/module.go index 4589cd98db87ad8eac08d75f3fc88c4034ac2c20..5880e684bcb1320ae249c58fd2aab993d17cf90c 100644 --- a/lib/gat/modules/ssl_endpoint/module.go +++ b/lib/gat/modules/ssl_endpoint/module.go @@ -20,10 +20,6 @@ type Module struct { config *tls.Config } -func NewModule() (*Module, error) { - return &Module{}, nil -} - func (T *Module) generateKeys() error { // generate private key priv, err := rsa.GenerateKey(rand.Reader, 2048) @@ -106,4 +102,4 @@ func (T *Module) Endpoints() []gat.Endpoint { } var _ gat.Module = (*Module)(nil) -var _ gat.Listener = (*Module)(nil) +var _ gat.Exposed = (*Module)(nil) diff --git a/lib/gat/modules/zalando/module.go b/lib/gat/modules/zalando/module.go index 061e3ea25ec0e36b9afb3437c3dc4df3e53e7219..0115e362973f42b0097d52c6f9585770c6be0d5a 100644 --- a/lib/gat/modules/zalando/module.go +++ b/lib/gat/modules/zalando/module.go @@ -8,25 +8,31 @@ import ( "gfx.cafe/gfx/pggat/lib/util/strutil" ) -func NewModule(config Config) (*pgbouncer.Module, error) { +type Module struct { + Config + + pgbouncer.Module `json:"-"` +} + +func (T *Module) Start() error { pgb := pgbouncer.Default if pgb.Databases == nil { pgb.Databases = make(map[string]pgbouncer.Database) } pgb.Databases["*"] = pgbouncer.Database{ - Host: config.PGHost, - Port: config.PGPort, - AuthUser: config.PGUser, + Host: T.PGHost, + Port: T.PGPort, + AuthUser: T.PGUser, } - pgb.PgBouncer.PoolMode = pgbouncer.PoolMode(config.PoolerMode) - pgb.PgBouncer.ListenPort = config.PoolerPort + pgb.PgBouncer.PoolMode = pgbouncer.PoolMode(T.PoolerMode) + pgb.PgBouncer.ListenPort = T.PoolerPort pgb.PgBouncer.ListenAddr = "*" pgb.PgBouncer.AuthType = "md5" pgb.PgBouncer.AuthFile = pgbouncer.AuthFile{ - config.PGUser: config.PGPassword, + T.PGUser: T.PGPassword, } - pgb.PgBouncer.AdminUsers = []string{config.PGUser} - pgb.PgBouncer.AuthQuery = fmt.Sprintf("SELECT * FROM %s.user_lookup($1)", config.PGSchema) + pgb.PgBouncer.AdminUsers = []string{T.PGUser} + pgb.PgBouncer.AuthQuery = fmt.Sprintf("SELECT * FROM %s.user_lookup($1)", T.PGSchema) pgb.PgBouncer.LogFile = "/var/log/pgbouncer/pgbouncer.log" pgb.PgBouncer.PidFile = "/var/run/pgbouncer/pgbouncer.pid" @@ -42,10 +48,10 @@ func NewModule(config Config) (*pgbouncer.Module, error) { pgb.PgBouncer.LogConnections = 0 pgb.PgBouncer.LogDisconnections = 0 - pgb.PgBouncer.DefaultPoolSize = config.PoolerDefaultSize - pgb.PgBouncer.ReservePoolSize = config.PoolerReserveSize - pgb.PgBouncer.MaxClientConn = config.PoolerMaxClientConn - pgb.PgBouncer.MaxDBConnections = config.PoolerMaxDBConn + pgb.PgBouncer.DefaultPoolSize = T.PoolerDefaultSize + pgb.PgBouncer.ReservePoolSize = T.PoolerReserveSize + pgb.PgBouncer.MaxClientConn = T.PoolerMaxClientConn + pgb.PgBouncer.MaxDBConnections = T.PoolerMaxDBConn pgb.PgBouncer.IdleTransactionTimeout = 600 pgb.PgBouncer.ServerLoginRetry = 5 @@ -54,5 +60,9 @@ func NewModule(config Config) (*pgbouncer.Module, error) { strutil.MakeCIString("options"), } - return pgbouncer.NewModule(pgb) + T.Module = pgbouncer.Module{ + Config: pgb, + } + + return nil } diff --git a/lib/gat/modules/zalando_operator_discovery/module.go b/lib/gat/modules/zalando_operator_discovery/module.go index e3702a15ba0422a879135cbee8d7e115edc97f23..68d2cef6a406dcbc7da91b6aa05b693c54ca4aff 100644 --- a/lib/gat/modules/zalando_operator_discovery/module.go +++ b/lib/gat/modules/zalando_operator_discovery/module.go @@ -9,32 +9,39 @@ import ( "gfx.cafe/gfx/pggat/lib/util/strutil" ) -func NewModule(config Config) (*discovery.Module, error) { - d, err := NewDiscoverer(config) +type Module struct { + Config + + discovery.Module `json:"-"` +} + +func (T *Module) Start() error { + d, err := NewDiscoverer(T.Config) if err != nil { - return nil, err + return err } - m, err := discovery.NewModule(discovery.Config{ - Discoverer: d, - ServerSSLMode: bouncer.SSLModePrefer, - ServerSSLConfig: &tls.Config{ - InsecureSkipVerify: true, - }, - ServerReconnectInitialTime: 5 * time.Second, - ServerReconnectMaxTime: 5 * time.Second, - ServerIdleTimeout: 5 * time.Minute, - // ServerResetQuery: "discard all", - TrackedParameters: []strutil.CIString{ - strutil.MakeCIString("client_encoding"), - strutil.MakeCIString("datestyle"), - strutil.MakeCIString("timezone"), - strutil.MakeCIString("standard_conforming_strings"), - strutil.MakeCIString("application_name"), + + T.Module = discovery.Module{ + Config: discovery.Config{ + Discoverer: d, + ServerSSLMode: bouncer.SSLModePrefer, + ServerSSLConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + ServerReconnectInitialTime: 5 * time.Second, + ServerReconnectMaxTime: 5 * time.Second, + ServerIdleTimeout: 5 * time.Minute, + // ServerResetQuery: "discard all", + TrackedParameters: []strutil.CIString{ + strutil.MakeCIString("client_encoding"), + strutil.MakeCIString("datestyle"), + strutil.MakeCIString("timezone"), + strutil.MakeCIString("standard_conforming_strings"), + strutil.MakeCIString("application_name"), + }, + PoolMode: "transaction", // TODO(garet) pool mode from operator config }, - PoolMode: "transaction", // TODO(garet) pool mode from operator config - }) - if err != nil { - return nil, err } - return m, nil + + return T.Module.Start() } diff --git a/lib/gat/server.go b/lib/gat/server.go index 8c4d587e03af8969691f4bb673a3a760f82cfb54..aebac874e542c491ff06c48b2c87eb5ea8569cc1 100644 --- a/lib/gat/server.go +++ b/lib/gat/server.go @@ -10,7 +10,6 @@ import ( "gfx.cafe/gfx/pggat/lib/bouncer/frontends/v0" "gfx.cafe/gfx/pggat/lib/fed" "gfx.cafe/gfx/pggat/lib/gat/metrics" - "gfx.cafe/gfx/pggat/lib/util/beforeexit" "gfx.cafe/gfx/pggat/lib/util/flip" "gfx.cafe/gfx/pggat/lib/util/maps" ) @@ -18,7 +17,11 @@ import ( type Server struct { modules []Module providers []Provider - listeners []Listener + exposed []Exposed + starters []Starter + stoppers []Stopper + + listeners []net.Listener keys maps.RWLocked[[8]byte, *Pool] } @@ -28,8 +31,14 @@ func (T *Server) AddModule(module Module) { if provider, ok := module.(Provider); ok { T.providers = append(T.providers, provider) } - if listener, ok := module.(Listener); ok { - T.listeners = append(T.listeners, listener) + if listener, ok := module.(Exposed); ok { + T.exposed = append(T.exposed, listener) + } + if starter, ok := module.(Starter); ok { + T.starters = append(T.starters, starter) + } + if stopper, ok := module.(Stopper); ok { + T.stoppers = append(T.stoppers, stopper) } } @@ -115,23 +124,19 @@ func (T *Server) accept(raw net.Conn, acceptOptions FrontendAcceptOptions) { } } -func (T *Server) Listen(network, address string) (net.Listener, error) { +func (T *Server) startListening(network, address string) (net.Listener, error) { listener, err := net.Listen(network, address) if err != nil { return nil, err } - if network == "unix" { - beforeexit.Run(func() { - _ = listener.Close() - }) - } + T.listeners = append(T.listeners, listener) log.Printf("listening on %s(%s)", network, address) return listener, nil } -func (T *Server) Serve(listener net.Listener, acceptOptions FrontendAcceptOptions) error { +func (T *Server) listen(listener net.Listener, acceptOptions FrontendAcceptOptions) error { for { raw, err := listener.Accept() if err != nil { @@ -146,20 +151,20 @@ func (T *Server) Serve(listener net.Listener, acceptOptions FrontendAcceptOption return nil } -func (T *Server) ListenAndServe() error { +func (T *Server) listenAndServe() error { var b flip.Bank - if len(T.listeners) > 0 { - l := T.listeners[0] + if len(T.exposed) > 0 { + l := T.exposed[0] endpoints := l.Endpoints() for _, endpoint := range endpoints { e := endpoint b.Queue(func() error { - listener, err := T.Listen(e.Network, e.Address) + listener, err := T.startListening(e.Network, e.Address) if err != nil { return err } - return T.Serve(listener, e.AcceptOptions) + return T.listen(listener, e.AcceptOptions) }) } } @@ -172,3 +177,30 @@ func (T *Server) ReadMetrics(m *metrics.Server) { provider.ReadMetrics(&m.Pools) } } + +func (T *Server) Start() error { + for _, starter := range T.starters { + if err := starter.Start(); err != nil { + return err + } + } + + return T.listenAndServe() +} + +func (T *Server) Stop() error { + var err error + for _, listener := range T.listeners { + if err2 := listener.Close(); err2 != nil { + err = err2 + } + } + + for _, stopper := range T.stoppers { + if err2 := stopper.Stop(); err2 != nil { + err = err2 + } + } + + return err +} diff --git a/lib/gat/starter.go b/lib/gat/starter.go new file mode 100644 index 0000000000000000000000000000000000000000..82f9914eaeea8d7924598df5dbe4dc3d352b4a73 --- /dev/null +++ b/lib/gat/starter.go @@ -0,0 +1,7 @@ +package gat + +type Starter interface { + Module + + Start() error +} diff --git a/lib/gat/stopper.go b/lib/gat/stopper.go new file mode 100644 index 0000000000000000000000000000000000000000..defd8646b4501ed5df61024ca9aab374ea71aebb --- /dev/null +++ b/lib/gat/stopper.go @@ -0,0 +1,7 @@ +package gat + +type Stopper interface { + Module + + Stop() error +} diff --git a/lib/util/beforeexit/run.go b/lib/util/beforeexit/run.go deleted file mode 100644 index 7dfc25c29aa889ca1dfe448721d05693c66b11ed..0000000000000000000000000000000000000000 --- a/lib/util/beforeexit/run.go +++ /dev/null @@ -1,50 +0,0 @@ -package beforeexit - -import ( - "os" - "os/signal" - "sync" - "syscall" -) - -var ( - q []func() - active bool - mu sync.Mutex -) - -func registerHandler() { - c := make(chan os.Signal, 2) - signal.Notify(c, os.Interrupt, syscall.SIGTERM) - - go func() { - <-c - - mu.Lock() - defer mu.Unlock() - for _, fn := range q { - // ignore any panics in funcs - func() { - defer func() { - recover() - }() - fn() - }() - } - - os.Exit(1) - }() -} - -// Run will register a func to run before exit on receiving an interrupt -// Tasks will run in the order that they are added -func Run(fn func()) { - mu.Lock() - defer mu.Unlock() - - q = append(q, fn) - if !active { - active = true - registerHandler() - } -} diff --git a/test/tester_test.go b/test/tester_test.go index 99e985f37af5a9f681cf693079cda262582903fa..dc53e46b6768b02c95e2c89f5a684525d16cfff3 100644 --- a/test/tester_test.go +++ b/test/tester_test.go @@ -42,21 +42,18 @@ func daisyChain(creds auth.Credentials, control recipe.Dialer, n int) (recipe.Di Dialer: control, })) - m, err := raw_pools.NewModule() - if err != nil { - return recipe.Dialer{}, err - } + m := new(raw_pools.Module) m.Add("runner", "pool", p) server.AddModule(m) - listener, err := server.Listen("tcp", ":0") + listener, err := server.listen("tcp", ":0") if err != nil { return recipe.Dialer{}, err } port := listener.Addr().(*net.TCPAddr).Port go func() { - err := server.Serve(listener, frontends.AcceptOptions{}) + err := server.serve(listener, frontends.AcceptOptions{}) if err != nil { panic(err) } @@ -111,11 +108,7 @@ func TestTester(t *testing.T) { var server gat.Server - m, err := raw_pools.NewModule() - if err != nil { - t.Error(err) - return - } + m := new(raw_pools.Module) transactionPool := pool.NewPool(transaction.Apply(pool.Options{ Credentials: creds, })) @@ -135,7 +128,7 @@ func TestTester(t *testing.T) { server.AddModule(m) - listener, err := server.Listen("tcp", ":0") + listener, err := server.listen("tcp", ":0") if err != nil { t.Error(err) return @@ -143,7 +136,7 @@ func TestTester(t *testing.T) { port := listener.Addr().(*net.TCPAddr).Port go func() { - err := server.Serve(listener, frontends.AcceptOptions{}) + err := server.serve(listener, frontends.AcceptOptions{}) if err != nil { t.Error(err) }