package gat import ( "crypto/tls" "errors" "fmt" "io" "net" "time" "github.com/caddyserver/caddy/v2" "go.uber.org/zap" "gfx.cafe/gfx/pggat/lib/bouncer/frontends/v0" "gfx.cafe/gfx/pggat/lib/fed" packets "gfx.cafe/gfx/pggat/lib/fed/packets/v3.0" "gfx.cafe/gfx/pggat/lib/gat/metrics" "gfx.cafe/gfx/pggat/lib/middleware/interceptor" "gfx.cafe/gfx/pggat/lib/middleware/middlewares/unterminate" "gfx.cafe/gfx/pggat/lib/perror" "gfx.cafe/gfx/pggat/lib/util/dur" "gfx.cafe/gfx/pggat/lib/util/maps" "gfx.cafe/gfx/pggat/lib/util/slices" ) type Config struct { StatLogPeriod dur.Duration `json:"stat_log_period"` Listen []ListenerConfig `json:"listen"` Servers []ServerConfig `json:"servers"` } func init() { caddy.RegisterModule((*App)(nil)) } type App struct { Config listen []*Listener servers []*Server keys maps.RWLocked[[8]byte, *Pool] closed chan struct{} log *zap.Logger } func (T *App) CaddyModule() caddy.ModuleInfo { return caddy.ModuleInfo{ ID: "pggat", New: func() caddy.Module { return new(App) }, } } func (T *App) Provision(ctx caddy.Context) error { T.log = ctx.Logger() T.listen = make([]*Listener, 0, len(T.Listen)) for _, config := range T.Listen { listener := &Listener{ ListenerConfig: config, } if err := listener.Provision(ctx); err != nil { return err } T.listen = append(T.listen, listener) } T.servers = make([]*Server, 0, len(T.Servers)) for _, config := range T.Servers { server := &Server{ ServerConfig: config, } if err := server.Provision(ctx); err != nil { return err } T.servers = append(T.servers, server) } return nil } func (T *App) cancel(key [8]byte) { p, _ := T.keys.Load(key) if p == nil { return } _ = p.Cancel(key) } func (T *App) serve(server *Server, conn fed.Conn) { initialParameters := conn.InitialParameters() for key := range initialParameters { if !slices.Contains(server.AllowedStartupParameters, key) { errResp := packets.ErrorResponse{ Error: perror.New( perror.FATAL, perror.FeatureNotSupported, fmt.Sprintf(`Startup parameter "%s" is not allowed`, key), ), } _ = conn.WritePacket(errResp.IntoPacket(nil)) return } } p := server.lookup(conn) if p == nil { T.log.Warn("database not found", zap.String("user", conn.User()), zap.String("database", conn.Database())) return } backendKey, err := frontends.Authenticate(conn, p.Credentials()) if err != nil { T.log.Warn("error authenticating client", zap.Error(err)) return } T.keys.Store(backendKey, p) defer T.keys.Delete(backendKey) if err2 := p.Serve(conn, backendKey); err2 != nil && !errors.Is(err2, io.EOF) { T.log.Warn("error serving client", zap.Error(err2)) return } } func (T *App) accept(listener *Listener, conn *fed.NetConn) { defer func() { _ = conn.Close() }() var tlsConfig *tls.Config if listener.ssl != nil { tlsConfig = listener.ssl.ServerTLSConfig() } cancelKey, isCanceling, _, user, database, initialParameters, err := frontends.Accept(conn, tlsConfig) if err != nil { T.log.Warn("error accepting client", zap.Error(err)) return } if isCanceling { T.cancel(cancelKey) return } conn.SetUser(user) conn.SetDatabase(database) conn.SetInitialParameters(initialParameters) for _, server := range T.servers { if server.match == nil || server.match.Matches(conn) { T.serve(server, interceptor.NewInterceptor(conn, unterminate.Unterminate)) return } } T.log.Warn("server not found", zap.String("user", conn.User()), zap.String("database", conn.Database())) errResp := packets.ErrorResponse{ Error: perror.New( perror.FATAL, perror.InternalError, "No server is available to handle your request", ), } _ = conn.WritePacket(errResp.IntoPacket(nil)) } func (T *App) acceptFrom(listener *Listener) bool { conn, err := listener.accept() if err != nil { if errors.Is(err, net.ErrClosed) { return false } T.log.Warn("error accepting client", zap.Error(err)) return true } go T.accept(listener, conn) return true } func (T *App) statLogLoop() { t := time.NewTicker(T.StatLogPeriod.Duration()) defer t.Stop() var stats metrics.Server for { select { case <-t.C: for _, server := range T.servers { for _, route := range server.routes { route.provide.ReadMetrics(&stats.Pools) } } T.log.Info(stats.String()) stats.Clear() case <-T.closed: return } } } func (T *App) Start() error { T.closed = make(chan struct{}) if T.StatLogPeriod != 0 { go T.statLogLoop() } // start listeners for _, listener := range T.listen { if err := listener.Start(); err != nil { return err } go func(listener *Listener) { for { if !T.acceptFrom(listener) { break } } }(listener) } return nil } func (T *App) Stop() error { close(T.closed) // stop listeners for _, listener := range T.listen { if err := listener.Stop(); err != nil { return err } } return nil } var _ caddy.Module = (*App)(nil) var _ caddy.Provisioner = (*App)(nil) var _ caddy.App = (*App)(nil)