From 90739068fe2af6280cfb02de4e55274ddee6df32 Mon Sep 17 00:00:00 2001
From: Garet Halliday <me@garet.holiday>
Date: Tue, 23 Apr 2024 16:56:01 -0500
Subject: [PATCH] max open connections

---
 lib/gat/gatcaddyfile/gattype.go | 21 ++++++++++++++++++++-
 lib/gat/listen.go               |  7 +++++--
 lib/gat/server.go               | 14 ++++++++++++++
 3 files changed, 39 insertions(+), 3 deletions(-)

diff --git a/lib/gat/gatcaddyfile/gattype.go b/lib/gat/gatcaddyfile/gattype.go
index 9a38074c..70567c99 100644
--- a/lib/gat/gatcaddyfile/gattype.go
+++ b/lib/gat/gatcaddyfile/gattype.go
@@ -2,6 +2,7 @@ package gatcaddyfile
 
 import (
 	"encoding/json"
+	"strconv"
 	"strings"
 	"time"
 
@@ -69,7 +70,8 @@ func (ServerType) Setup(blocks []caddyfile.ServerBlock, m map[string]any) (*cadd
 		server.Listen = make([]gat.ListenerConfig, 0, len(block.Keys))
 		for _, key := range block.Keys {
 			listen := gat.ListenerConfig{
-				Address: key,
+				Address:        key,
+				MaxConnections: 1000,
 			}
 			server.Listen = append(server.Listen, listen)
 		}
@@ -114,6 +116,23 @@ func (ServerType) Setup(blocks []caddyfile.ServerBlock, m map[string]any) (*cadd
 					server.Listen[i].SSL = val
 				}
 
+				if d.CountRemainingArgs() > 0 {
+					return nil, nil, d.ArgErr()
+				}
+			case directive == "max_connections":
+				if !d.NextArg() {
+					return nil, nil, d.ArgErr()
+				}
+
+				maxConnections, err := strconv.Atoi(d.Val())
+				if err != nil {
+					return nil, nil, d.Err(err.Error())
+				}
+
+				for i := range server.Listen {
+					server.Listen[i].MaxConnections = maxConnections
+				}
+
 				if d.CountRemainingArgs() > 0 {
 					return nil, nil, d.ArgErr()
 				}
diff --git a/lib/gat/listen.go b/lib/gat/listen.go
index b68fcd80..9a4e22b8 100644
--- a/lib/gat/listen.go
+++ b/lib/gat/listen.go
@@ -9,6 +9,7 @@ import (
 	"path/filepath"
 	"strconv"
 	"strings"
+	"sync/atomic"
 
 	"github.com/caddyserver/caddy/v2"
 	"go.uber.org/zap"
@@ -17,8 +18,9 @@ import (
 )
 
 type ListenerConfig struct {
-	Address string          `json:"address"`
-	SSL     json.RawMessage `json:"ssl,omitempty" caddy:"namespace=pggat.ssl.servers inline_key=provider"`
+	Address        string          `json:"address"`
+	SSL            json.RawMessage `json:"ssl,omitempty" caddy:"namespace=pggat.ssl.servers inline_key=provider"`
+	MaxConnections int             `json:"max_connections,omitempty"`
 }
 
 type Listener struct {
@@ -28,6 +30,7 @@ type Listener struct {
 	ssl            SSLServer
 
 	listener net.Listener
+	open     atomic.Int64
 
 	log *zap.Logger
 }
diff --git a/lib/gat/server.go b/lib/gat/server.go
index 81a1cd58..1a79750e 100644
--- a/lib/gat/server.go
+++ b/lib/gat/server.go
@@ -164,6 +164,20 @@ func (T *Server) accept(listener *Listener, conn *fed.Conn) {
 		return
 	}
 
+	count := listener.open.Add(1)
+	defer listener.open.Add(-1)
+
+	if listener.MaxConnections != 0 && int(count) > listener.MaxConnections {
+		_ = conn.WritePacket(
+			perror.ToPacket(perror.New(
+				perror.FATAL,
+				perror.TooManyConnections,
+				"Too many connections, sorry",
+			)),
+		)
+		return
+	}
+
 	T.Serve(conn)
 }
 
-- 
GitLab