From eac88f67479a5fdfaeda97109687cf343f2e999f Mon Sep 17 00:00:00 2001
From: or-else <or.else@gmail.com>
Date: Wed, 15 Mar 2023 10:03:28 -0700
Subject: [PATCH] fix CORS handling, #840

---
 server/hdl_files.go   | 67 ++++++++++++++++++++++++++++++++++---------
 server/media/media.go | 39 +++++++++++++------------
 2 files changed, 73 insertions(+), 33 deletions(-)

diff --git a/server/hdl_files.go b/server/hdl_files.go
index 8a6e74aa..9f9bda3b 100644
--- a/server/hdl_files.go
+++ b/server/hdl_files.go
@@ -39,6 +39,32 @@ func largeFileServe(wrt http.ResponseWriter, req *http.Request) {
 		}
 	}
 
+	// Preflight request: process before any security checks.
+	if req.Method == http.MethodOptions {
+		headers, statusCode, err := mh.Headers(req, true)
+		if err != nil {
+			writeHttpResponse(decodeStoreError(err, "", now, nil), err)
+			return
+		}
+		for name, values := range headers {
+			for _, value := range values {
+				wrt.Header().Add(name, value)
+			}
+		}
+		if statusCode <= 0 {
+			statusCode = http.StatusNoContent
+		}
+		wrt.WriteHeader(statusCode)
+		logs.Info.Println("media serve: preflight completed")
+		return
+	}
+
+	// Check if this is a GET/HEAD request.
+	if req.Method != http.MethodGet && req.Method != http.MethodHead {
+		writeHttpResponse(ErrOperationNotAllowed("", "", now), errors.New("method '"+req.Method+"' not allowed"))
+		return
+	}
+
 	// Check for API key presence
 	if isValid, _ := checkAPIKey(getAPIKey(req)); !isValid {
 		writeHttpResponse(ErrAPIKeyRequired(now), errors.New("invalid or missing API key"))
@@ -63,12 +89,6 @@ func largeFileServe(wrt http.ResponseWriter, req *http.Request) {
 		return
 	}
 
-	// Check if this is a GET/OPTIONS/HEAD request.
-	if req.Method != http.MethodGet && req.Method != http.MethodHead && req.Method != http.MethodOptions {
-		writeHttpResponse(ErrOperationNotAllowed("", "", now), errors.New("method '"+req.Method+"' not allowed"))
-		return
-	}
-
 	// Check if media handler redirects or adds headers.
 	headers, statusCode, err := mh.Headers(req, true)
 	if err != nil {
@@ -98,7 +118,7 @@ func largeFileServe(wrt http.ResponseWriter, req *http.Request) {
 		return
 	}
 
-	if req.Method == http.MethodHead || req.Method == http.MethodOptions {
+	if req.Method == http.MethodHead {
 		wrt.WriteHeader(http.StatusOK)
 		logs.Info.Println("media serve: completed", req.Method, "uid=", uid)
 		return
@@ -137,9 +157,28 @@ func largeFileReceive(wrt http.ResponseWriter, req *http.Request) {
 		}
 	}
 
-	// Check if this is a POST/PUT/OPTIONS/HEAD request.
-	if req.Method != http.MethodPost && req.Method != http.MethodPut &&
-		req.Method != http.MethodHead && req.Method != http.MethodOptions {
+	// Preflight request: process before any security checks.
+	if req.Method == http.MethodOptions {
+		headers, statusCode, err := mh.Headers(req, false)
+		if err != nil {
+			writeHttpResponse(decodeStoreError(err, "", now, nil), err)
+			return
+		}
+		for name, values := range headers {
+			for _, value := range values {
+				wrt.Header().Add(name, value)
+			}
+		}
+		if statusCode <= 0 {
+			statusCode = http.StatusNoContent
+		}
+		wrt.WriteHeader(statusCode)
+		logs.Info.Println("media upload: preflight completed")
+		return
+	}
+
+	// Check if this is a POST/PUT/HEAD request.
+	if req.Method != http.MethodPost && req.Method != http.MethodPut && req.Method != http.MethodHead {
 		writeHttpResponse(ErrOperationNotAllowed("", "", now), errors.New("method '"+req.Method+"' not allowed"))
 		return
 	}
@@ -175,7 +214,7 @@ func largeFileReceive(wrt http.ResponseWriter, req *http.Request) {
 	// Check if uploads are handled elsewhere.
 	headers, statusCode, err := mh.Headers(req, false)
 	if err != nil {
-		logs.Info.Println("Headers check failed", err)
+		logs.Info.Println("media upload: headers check failed", err)
 		writeHttpResponse(decodeStoreError(err, "", now, nil), err)
 		return
 	}
@@ -210,7 +249,7 @@ func largeFileReceive(wrt http.ResponseWriter, req *http.Request) {
 
 	file, _, err := req.FormFile("file")
 	if err != nil {
-		logs.Info.Println("Invalid multipart form", err)
+		logs.Info.Println("media upload: invalid multipart form", err)
 		if strings.Contains(err.Error(), "request body too large") {
 			writeHttpResponse(ErrTooLarge(msgID, "", now), err)
 		} else {
@@ -240,7 +279,7 @@ func largeFileReceive(wrt http.ResponseWriter, req *http.Request) {
 
 	url, size, err := mh.Upload(fdef, file)
 	if err != nil {
-		logs.Info.Println("Upload failed", file, "key", fdef.Location, err)
+		logs.Info.Println("media upload: failed", file, "key", fdef.Location, err)
 		store.Files.FinishUpload(fdef, false, 0)
 		writeHttpResponse(decodeStoreError(err, msgID, now, nil), err)
 		return
@@ -248,7 +287,7 @@ func largeFileReceive(wrt http.ResponseWriter, req *http.Request) {
 
 	fdef, err = store.Files.FinishUpload(fdef, true, size)
 	if err != nil {
-		logs.Info.Println("Failed to finalize upload", file, "key", fdef.Location, err)
+		logs.Info.Println("media upload: failed to finalize", file, "key", fdef.Location, err)
 		// Best effort cleanup.
 		mh.Delete([]string{fdef.Location})
 		writeHttpResponse(decodeStoreError(err, msgID, now, nil), err)
diff --git a/server/media/media.go b/server/media/media.go
index a3f69a96..06892646 100644
--- a/server/media/media.go
+++ b/server/media/media.go
@@ -70,8 +70,9 @@ func matchCORSOrigin(allowed []string, origin string) string {
 		return "*"
 	}
 
+	origin = strings.ToLower(origin)
 	for _, val := range allowed {
-		if val == origin {
+		if strings.ToLower(val) == origin {
 			return origin
 		}
 	}
@@ -87,7 +88,7 @@ func matchCORSMethod(allowMethods []string, method string) bool {
 
 	method = strings.ToUpper(method)
 	for _, mm := range allowMethods {
-		if mm == method {
+		if strings.ToUpper(mm) == method {
 			return true
 		}
 	}
@@ -102,17 +103,6 @@ func CORSHandler(req *http.Request, allowedOrigins []string, serve bool) (http.H
 		return nil, 0
 	}
 
-	headers := map[string][]string{
-		// Always add Vary because of possible intermediate caches.
-		"Vary": {"Origin", "Access-Control-Request-Method"},
-	}
-
-	allowedOrigin := matchCORSOrigin(allowedOrigins, req.Header.Get("Origin"))
-	if allowedOrigin == "" {
-		// CORS policy does not match the origin.
-		return headers, http.StatusOK
-	}
-
 	var allowMethods []string
 	if serve {
 		allowMethods = []string{http.MethodGet, http.MethodHead, http.MethodOptions}
@@ -120,16 +110,27 @@ func CORSHandler(req *http.Request, allowedOrigins []string, serve bool) (http.H
 		allowMethods = []string{http.MethodPost, http.MethodPut, http.MethodHead, http.MethodOptions}
 	}
 
+	headers := map[string][]string{
+		// Always add Vary because of possible intermediate caches.
+		"Vary":                             {"Origin", "Access-Control-Request-Method"},
+		"Access-Control-Allow-Headers":     {"*"},
+		"Access-Control-Max-Age":           {"86400"},
+		"Access-Control-Allow-Credentials": {"true"},
+		"Access-Control-Allow-Methods":     {strings.Join(allowMethods, ", ")},
+	}
+
 	if !matchCORSMethod(allowMethods, req.Header.Get("Access-Control-Request-Method")) {
 		// CORS policy does not allow this method.
-		return headers, http.StatusOK
+		return headers, http.StatusNoContent
+	}
+
+	allowedOrigin := matchCORSOrigin(allowedOrigins, req.Header.Get("Origin"))
+	if allowedOrigin == "" {
+		// CORS policy does not match the origin.
+		return headers, http.StatusNoContent
 	}
 
 	headers["Access-Control-Allow-Origin"] = []string{allowedOrigin}
-	headers["Access-Control-Allow-Headers"] = []string{"*"}
-	headers["Access-Control-Allow-Methods"] = []string{strings.Join(allowMethods, ",")}
-	headers["Access-Control-Max-Age"] = []string{"86400"}
-	headers["Access-Control-Allow-Credentials"] = []string{"true"}
 
-	return headers, http.StatusOK
+	return headers, http.StatusNoContent
 }
-- 
GitLab