1
0
mirror of https://github.com/gofiber/fiber.git synced 2025-02-07 00:12:00 +00:00

fix(middleware/cors): CORS handling (#2938)

* fix(middleware/cors): CORS handling

* fix(middleware/cors): Vary header handling

* fix(middleware/cors): Add Vary header for non-CORS OPTIONS requests
This commit is contained in:
Jason McNeil 2024-03-28 04:52:10 -03:00 committed by GitHub
parent 7ba02c14cf
commit 0248e58b58
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 62 additions and 46 deletions

View File

@ -169,9 +169,24 @@ func New(config ...Config) fiber.Handler {
// Get originHeader header
originHeader := strings.ToLower(c.Get(fiber.HeaderOrigin))
// If the request does not have Origin and Access-Control-Request-Method
// headers, the request is outside the scope of CORS
if originHeader == "" || c.Get(fiber.HeaderAccessControlRequestMethod) == "" {
// If the request does not have Origin header, the request is outside the scope of CORS
if originHeader == "" {
// See https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches
// Unless all origins are allowed, we include the Vary header to cache the response correctly
if !allowAllOrigins {
c.Vary(fiber.HeaderOrigin)
}
return c.Next()
}
// If it's a preflight request and doesn't have Access-Control-Request-Method header, it's outside the scope of CORS
if c.Method() == fiber.MethodOptions && c.Get(fiber.HeaderAccessControlRequestMethod) == "" {
// Response to OPTIONS request should not be cached but,
// some caching can be configured to cache such responses.
// To Avoid poisoning the cache, we include the Vary header
// for non-CORS OPTIONS requests:
c.Vary(fiber.HeaderOrigin)
return c.Next()
}
@ -211,17 +226,28 @@ func New(config ...Config) fiber.Handler {
// Simple request
// Ommit allowMethods and allowHeaders, only used for pre-flight requests
if c.Method() != fiber.MethodOptions {
if !allowAllOrigins {
// See https://fetch.spec.whatwg.org/#cors-protocol-and-http-caches
c.Vary(fiber.HeaderOrigin)
}
setCORSHeaders(c, allowOrigin, "", "", exposeHeaders, maxAge, cfg)
return c.Next()
}
// Preflight request
// Pre-flight request
// Response to OPTIONS request should not be cached but,
// some caching can be configured to cache such responses.
// To Avoid poisoning the cache, we include the Vary header
// of preflight responses:
c.Vary(fiber.HeaderAccessControlRequestMethod)
c.Vary(fiber.HeaderAccessControlRequestHeaders)
if cfg.AllowPrivateNetwork && c.Get(fiber.HeaderAccessControlRequestPrivateNetwork) == "true" {
c.Vary(fiber.HeaderAccessControlRequestPrivateNetwork)
c.Set(fiber.HeaderAccessControlAllowPrivateNetwork, "true")
}
c.Vary(fiber.HeaderOrigin)
setCORSHeaders(c, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge, cfg)
// Send 204 No Content
@ -231,8 +257,6 @@ func New(config ...Config) fiber.Handler {
// Function to set CORS headers
func setCORSHeaders(c fiber.Ctx, allowOrigin, allowMethods, allowHeaders, exposeHeaders, maxAge string, cfg Config) {
c.Vary(fiber.HeaderOrigin)
if cfg.AllowCredentials {
// When AllowCredentials is true, set the Access-Control-Allow-Origin to the specific origin instead of '*'
if allowOrigin == "*" {

View File

@ -49,7 +49,6 @@ func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) {
// Test default GET response headers
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
h(ctx)
@ -69,6 +68,33 @@ func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) {
require.Equal(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge)))
}
func Test_CORS_AllowOrigins_Vary(t *testing.T) {
t.Parallel()
app := fiber.New()
app.Use(New(
Config{
AllowOrigins: "http://localhost",
},
))
h := app.Handler()
// Test Vary header non-Cors request
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx)
require.Contains(t, string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin, "Vary header should be set")
// Test Vary header Cors request
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
h(ctx)
require.Contains(t, string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin, "Vary header should be set")
}
// go test -run -v Test_CORS_Wildcard
func Test_CORS_Wildcard(t *testing.T) {
t.Parallel()
@ -96,6 +122,7 @@ func Test_CORS_Wildcard(t *testing.T) {
// Check result
require.Equal(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin))) // Validates request is not reflecting origin in the response
require.Contains(t, string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin, "Vary header should be set")
require.Equal(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
require.Equal(t, "3600", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge)))
require.Equal(t, "Authentication", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)))
@ -104,9 +131,9 @@ func Test_CORS_Wildcard(t *testing.T) {
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
handler(ctx)
require.NotContains(t, string(ctx.Response.Header.Peek(fiber.HeaderVary)), fiber.HeaderOrigin, "Vary header should not be set")
require.Equal(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
require.Equal(t, "X-Request-ID", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlExposeHeaders)))
}
@ -146,7 +173,6 @@ func Test_CORS_Origin_AllowCredentials(t *testing.T) {
// Test non OPTIONS (preflight) response headers
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://localhost")
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
ctx.Request.Header.SetMethod(fiber.MethodGet)
handler(ctx)
@ -465,7 +491,7 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) {
// Get handler pointer
handler := app.Handler()
t.Run("Without origin and Access-Control-Request-Method headers", func(t *testing.T) {
t.Run("Without origin", func(t *testing.T) {
t.Parallel()
// Make request without origin header, and without Access-Control-Request-Method
for _, method := range methods {
@ -478,34 +504,6 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) {
}
})
t.Run("With origin but without Access-Control-Request-Method header", func(t *testing.T) {
t.Parallel()
// Make request with origin header, but without Access-Control-Request-Method
for _, method := range methods {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(method)
ctx.Request.SetRequestURI("https://example.com/")
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")
handler(ctx)
require.Equal(t, 200, ctx.Response.StatusCode(), "Status code should be 200")
require.Equal(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set")
}
})
t.Run("Without origin but with Access-Control-Request-Method header", func(t *testing.T) {
t.Parallel()
// Make request without origin header, but with Access-Control-Request-Method
for _, method := range methods {
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(method)
ctx.Request.SetRequestURI("https://example.com/")
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
handler(ctx)
require.Equal(t, 200, ctx.Response.StatusCode(), "Status code should be 200")
require.Equal(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should not be set")
}
})
t.Run("Preflight request with origin and Access-Control-Request-Method headers", func(t *testing.T) {
t.Parallel()
// Make preflight request with origin header and with Access-Control-Request-Method
@ -523,7 +521,7 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) {
}
})
t.Run("Non-preflight request with origin and Access-Control-Request-Method headers", func(t *testing.T) {
t.Run("Non-preflight request with origin", func(t *testing.T) {
t.Parallel()
// Make non-preflight request with origin header and with Access-Control-Request-Method
for _, method := range methods {
@ -531,7 +529,6 @@ func Test_CORS_Headers_BasedOnRequestType(t *testing.T) {
ctx.Request.Header.SetMethod(method)
ctx.Request.SetRequestURI("https://example.com/api/action")
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com")
ctx.Request.Header.Set(fiber.HeaderAccessControlRequestMethod, method)
handler(ctx)
require.Equal(t, 200, ctx.Response.StatusCode(), "Status code should be 200")
require.Equal(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)), "Access-Control-Allow-Origin header should be set")
@ -1008,7 +1005,6 @@ func Benchmark_CORS_NewHandler(b *testing.B) {
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://localhost")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
ctx.Init(req, nil, nil)
@ -1049,7 +1045,6 @@ func Benchmark_CORS_NewHandlerParallel(b *testing.B) {
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://localhost")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
ctx.Init(req, nil, nil)
@ -1083,7 +1078,6 @@ func Benchmark_CORS_NewHandlerSingleOrigin(b *testing.B) {
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
ctx.Init(req, nil, nil)
@ -1124,7 +1118,6 @@ func Benchmark_CORS_NewHandlerSingleOriginParallel(b *testing.B) {
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
ctx.Init(req, nil, nil)
@ -1158,7 +1151,6 @@ func Benchmark_CORS_NewHandlerWildcard(b *testing.B) {
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
ctx.Init(req, nil, nil)
@ -1199,7 +1191,6 @@ func Benchmark_CORS_NewHandlerWildcardParallel(b *testing.B) {
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
ctx.Init(req, nil, nil)
@ -1229,6 +1220,7 @@ func Benchmark_CORS_NewHandlerPreflight(b *testing.B) {
h := app.Handler()
ctx := &fasthttp.RequestCtx{}
// Preflight request
req := &fasthttp.Request{}
req.Header.SetMethod(fiber.MethodOptions)
req.SetRequestURI("/")