From eac88f67479a5fdfaeda97109687cf343f2e999f Mon Sep 17 00:00:00 2001 From: or-else <or.else@gmail.com> Date: Wed, 15 Mar 2023 10:03:28 -0700 Subject: [PATCH] fix CORS handling, #840 --- server/hdl_files.go | 67 ++++++++++++++++++++++++++++++++++--------- server/media/media.go | 39 +++++++++++++------------ 2 files changed, 73 insertions(+), 33 deletions(-) diff --git a/server/hdl_files.go b/server/hdl_files.go index 8a6e74aa..9f9bda3b 100644 --- a/server/hdl_files.go +++ b/server/hdl_files.go @@ -39,6 +39,32 @@ func largeFileServe(wrt http.ResponseWriter, req *http.Request) { } } + // Preflight request: process before any security checks. + if req.Method == http.MethodOptions { + headers, statusCode, err := mh.Headers(req, true) + if err != nil { + writeHttpResponse(decodeStoreError(err, "", now, nil), err) + return + } + for name, values := range headers { + for _, value := range values { + wrt.Header().Add(name, value) + } + } + if statusCode <= 0 { + statusCode = http.StatusNoContent + } + wrt.WriteHeader(statusCode) + logs.Info.Println("media serve: preflight completed") + return + } + + // Check if this is a GET/HEAD request. + if req.Method != http.MethodGet && req.Method != http.MethodHead { + writeHttpResponse(ErrOperationNotAllowed("", "", now), errors.New("method '"+req.Method+"' not allowed")) + return + } + // Check for API key presence if isValid, _ := checkAPIKey(getAPIKey(req)); !isValid { writeHttpResponse(ErrAPIKeyRequired(now), errors.New("invalid or missing API key")) @@ -63,12 +89,6 @@ func largeFileServe(wrt http.ResponseWriter, req *http.Request) { return } - // Check if this is a GET/OPTIONS/HEAD request. - if req.Method != http.MethodGet && req.Method != http.MethodHead && req.Method != http.MethodOptions { - writeHttpResponse(ErrOperationNotAllowed("", "", now), errors.New("method '"+req.Method+"' not allowed")) - return - } - // Check if media handler redirects or adds headers. headers, statusCode, err := mh.Headers(req, true) if err != nil { @@ -98,7 +118,7 @@ func largeFileServe(wrt http.ResponseWriter, req *http.Request) { return } - if req.Method == http.MethodHead || req.Method == http.MethodOptions { + if req.Method == http.MethodHead { wrt.WriteHeader(http.StatusOK) logs.Info.Println("media serve: completed", req.Method, "uid=", uid) return @@ -137,9 +157,28 @@ func largeFileReceive(wrt http.ResponseWriter, req *http.Request) { } } - // Check if this is a POST/PUT/OPTIONS/HEAD request. - if req.Method != http.MethodPost && req.Method != http.MethodPut && - req.Method != http.MethodHead && req.Method != http.MethodOptions { + // Preflight request: process before any security checks. + if req.Method == http.MethodOptions { + headers, statusCode, err := mh.Headers(req, false) + if err != nil { + writeHttpResponse(decodeStoreError(err, "", now, nil), err) + return + } + for name, values := range headers { + for _, value := range values { + wrt.Header().Add(name, value) + } + } + if statusCode <= 0 { + statusCode = http.StatusNoContent + } + wrt.WriteHeader(statusCode) + logs.Info.Println("media upload: preflight completed") + return + } + + // Check if this is a POST/PUT/HEAD request. + if req.Method != http.MethodPost && req.Method != http.MethodPut && req.Method != http.MethodHead { writeHttpResponse(ErrOperationNotAllowed("", "", now), errors.New("method '"+req.Method+"' not allowed")) return } @@ -175,7 +214,7 @@ func largeFileReceive(wrt http.ResponseWriter, req *http.Request) { // Check if uploads are handled elsewhere. headers, statusCode, err := mh.Headers(req, false) if err != nil { - logs.Info.Println("Headers check failed", err) + logs.Info.Println("media upload: headers check failed", err) writeHttpResponse(decodeStoreError(err, "", now, nil), err) return } @@ -210,7 +249,7 @@ func largeFileReceive(wrt http.ResponseWriter, req *http.Request) { file, _, err := req.FormFile("file") if err != nil { - logs.Info.Println("Invalid multipart form", err) + logs.Info.Println("media upload: invalid multipart form", err) if strings.Contains(err.Error(), "request body too large") { writeHttpResponse(ErrTooLarge(msgID, "", now), err) } else { @@ -240,7 +279,7 @@ func largeFileReceive(wrt http.ResponseWriter, req *http.Request) { url, size, err := mh.Upload(fdef, file) if err != nil { - logs.Info.Println("Upload failed", file, "key", fdef.Location, err) + logs.Info.Println("media upload: failed", file, "key", fdef.Location, err) store.Files.FinishUpload(fdef, false, 0) writeHttpResponse(decodeStoreError(err, msgID, now, nil), err) return @@ -248,7 +287,7 @@ func largeFileReceive(wrt http.ResponseWriter, req *http.Request) { fdef, err = store.Files.FinishUpload(fdef, true, size) if err != nil { - logs.Info.Println("Failed to finalize upload", file, "key", fdef.Location, err) + logs.Info.Println("media upload: failed to finalize", file, "key", fdef.Location, err) // Best effort cleanup. mh.Delete([]string{fdef.Location}) writeHttpResponse(decodeStoreError(err, msgID, now, nil), err) diff --git a/server/media/media.go b/server/media/media.go index a3f69a96..06892646 100644 --- a/server/media/media.go +++ b/server/media/media.go @@ -70,8 +70,9 @@ func matchCORSOrigin(allowed []string, origin string) string { return "*" } + origin = strings.ToLower(origin) for _, val := range allowed { - if val == origin { + if strings.ToLower(val) == origin { return origin } } @@ -87,7 +88,7 @@ func matchCORSMethod(allowMethods []string, method string) bool { method = strings.ToUpper(method) for _, mm := range allowMethods { - if mm == method { + if strings.ToUpper(mm) == method { return true } } @@ -102,17 +103,6 @@ func CORSHandler(req *http.Request, allowedOrigins []string, serve bool) (http.H return nil, 0 } - headers := map[string][]string{ - // Always add Vary because of possible intermediate caches. - "Vary": {"Origin", "Access-Control-Request-Method"}, - } - - allowedOrigin := matchCORSOrigin(allowedOrigins, req.Header.Get("Origin")) - if allowedOrigin == "" { - // CORS policy does not match the origin. - return headers, http.StatusOK - } - var allowMethods []string if serve { allowMethods = []string{http.MethodGet, http.MethodHead, http.MethodOptions} @@ -120,16 +110,27 @@ func CORSHandler(req *http.Request, allowedOrigins []string, serve bool) (http.H allowMethods = []string{http.MethodPost, http.MethodPut, http.MethodHead, http.MethodOptions} } + headers := map[string][]string{ + // Always add Vary because of possible intermediate caches. + "Vary": {"Origin", "Access-Control-Request-Method"}, + "Access-Control-Allow-Headers": {"*"}, + "Access-Control-Max-Age": {"86400"}, + "Access-Control-Allow-Credentials": {"true"}, + "Access-Control-Allow-Methods": {strings.Join(allowMethods, ", ")}, + } + if !matchCORSMethod(allowMethods, req.Header.Get("Access-Control-Request-Method")) { // CORS policy does not allow this method. - return headers, http.StatusOK + return headers, http.StatusNoContent + } + + allowedOrigin := matchCORSOrigin(allowedOrigins, req.Header.Get("Origin")) + if allowedOrigin == "" { + // CORS policy does not match the origin. + return headers, http.StatusNoContent } headers["Access-Control-Allow-Origin"] = []string{allowedOrigin} - headers["Access-Control-Allow-Headers"] = []string{"*"} - headers["Access-Control-Allow-Methods"] = []string{strings.Join(allowMethods, ",")} - headers["Access-Control-Max-Age"] = []string{"86400"} - headers["Access-Control-Allow-Credentials"] = []string{"true"} - return headers, http.StatusOK + return headers, http.StatusNoContent } -- GitLab