1
0
mirror of https://github.com/gofiber/fiber.git synced 2025-02-11 23:21:20 +00:00

fix(middleware/cors): Validation of multiple Origins (#2883)

* fix: allow origins check

Refactor CORS origin validation and normalization to trim leading or trailing whitespace in the cfg.AllowOrigins string [list]. URLs with whitespace inside the URL are invalid, so the normalizeOrigin will return false because url.Parse will fail, and the middleware will panic.

fixes #2882

* test: AllowOrigins with whitespace

* test(middleware/cors): add benchmarks

* chore: fix linter errors

* test(middleware/cors): use h() instead of app.Test()

* test(middleware/cors): add miltiple origins in Test_CORS_AllowOriginScheme

* chore: refactor validate and normalize

* test(cors/middleware): add more benchmarks

(cherry picked from commit d456e7d82ee087d5b2a5df5d4ab35f19c6397aae)
This commit is contained in:
Jason McNeil 2024-03-01 10:31:11 +01:00 committed by René
parent 3c08c1b637
commit 4ab8629706
2 changed files with 513 additions and 15 deletions

View File

@ -113,23 +113,31 @@ func New(config ...Config) fiber.Handler {
log.Panic("[CORS] Insecure setup, 'AllowCredentials' is set to true, and 'AllowOrigins' is set to a wildcard.") //nolint:revive // we want to exit the program
}
// Validate and normalize static AllowOrigins if not using AllowOriginsFunc
if cfg.AllowOriginsFunc == nil && cfg.AllowOrigins != "" && cfg.AllowOrigins != "*" {
validatedOrigins := []string{}
for _, origin := range strings.Split(cfg.AllowOrigins, ",") {
isValid, normalizedOrigin := normalizeOrigin(origin)
// allowOrigins is a slice of strings that contains the allowed origins
// defined in the 'AllowOrigins' configuration.
var allowOrigins []string
// Validate and normalize static AllowOrigins
if cfg.AllowOrigins != "" && cfg.AllowOrigins != "*" {
origins := strings.Split(cfg.AllowOrigins, ",")
allowOrigins = make([]string, len(origins))
for i, origin := range origins {
trimmedOrigin := strings.TrimSpace(origin)
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
if isValid {
validatedOrigins = append(validatedOrigins, normalizedOrigin)
allowOrigins[i] = normalizedOrigin
} else {
log.Panicf("[CORS] Invalid origin format in configuration: %s", origin) //nolint:revive // we want to exit the program
log.Panicf("[CORS] Invalid origin format in configuration: %s", trimmedOrigin) //nolint:revive // we want to exit the program
}
}
cfg.AllowOrigins = strings.Join(validatedOrigins, ",")
} else {
// If AllowOrigins is set to a wildcard or not set,
// set allowOrigins to a slice with a single element
allowOrigins = []string{cfg.AllowOrigins}
}
// Convert string to slice
allowOrigins := strings.Split(strings.ReplaceAll(cfg.AllowOrigins, " ", ""), ",")
// Strip white spaces
allowMethods := strings.ReplaceAll(cfg.AllowMethods, " ", "")
allowHeaders := strings.ReplaceAll(cfg.AllowHeaders, " ", "")
@ -164,10 +172,8 @@ func New(config ...Config) fiber.Handler {
// Run AllowOriginsFunc if the logic for
// handling the value in 'AllowOrigins' does
// not result in allowOrigin being set.
if allowOrigin == "" && cfg.AllowOriginsFunc != nil {
if cfg.AllowOriginsFunc(originHeader) {
allowOrigin = originHeader
}
if allowOrigin == "" && cfg.AllowOriginsFunc != nil && cfg.AllowOriginsFunc(originHeader) {
allowOrigin = originHeader
}
// Simple request

View File

@ -306,6 +306,21 @@ func Test_CORS_AllowOriginScheme(t *testing.T) {
reqOrigin: "http://ccc.bbb.example.com",
shouldAllowOrigin: false,
},
{
pattern: "http://domain-1.com, http://example.com",
reqOrigin: "http://example.com",
shouldAllowOrigin: true,
},
{
pattern: "http://domain-1.com, http://example.com",
reqOrigin: "http://domain-2.com",
shouldAllowOrigin: false,
},
{
pattern: "http://domain-1.com,http://example.com",
reqOrigin: "http://domain-1.com",
shouldAllowOrigin: true,
},
}
for _, tt := range tests {
@ -451,6 +466,33 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) {
RequestOrigin: "http://aaa.com",
ResponseOrigin: "http://aaa.com",
},
{
Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/MultipleOrigins/NoWhitespace/OriginAllowed",
Config: Config{
AllowOrigins: "http://aaa.com,http://bbb.com",
AllowOriginsFunc: nil,
},
RequestOrigin: "http://bbb.com",
ResponseOrigin: "http://bbb.com",
},
{
Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/MultipleOrigins/NoWhitespace/OriginNotAllowed",
Config: Config{
AllowOrigins: "http://aaa.com,http://bbb.com",
AllowOriginsFunc: nil,
},
RequestOrigin: "http://ccc.com",
ResponseOrigin: "",
},
{
Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/MultipleOrigins/Whitespace/OriginAllowed",
Config: Config{
AllowOrigins: "http://aaa.com, http://bbb.com",
AllowOriginsFunc: nil,
},
RequestOrigin: "http://aaa.com",
ResponseOrigin: "http://aaa.com",
},
{
Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/OriginNotAllowed",
Config: Config{
@ -646,3 +688,453 @@ func Test_CORS_AllowCredentials(t *testing.T) {
})
}
}
// go test -v -run=^$ -bench=Benchmark_CORS_NewHandler -benchmem -count=4
func Benchmark_CORS_NewHandler(b *testing.B) {
app := fiber.New()
c := New(Config{
AllowOrigins: "http://localhost,http://example.com",
AllowMethods: "GET,POST,PUT,DELETE",
AllowHeaders: "Origin,Content-Type,Accept",
AllowCredentials: true,
MaxAge: 600,
})
app.Use(c)
app.Use(func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
h := app.Handler()
ctx := &fasthttp.RequestCtx{}
req := &fasthttp.Request{}
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://localhost")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
ctx.Init(req, nil, nil)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
h(ctx)
}
}
// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerParallel -benchmem -count=4
func Benchmark_CORS_NewHandlerParallel(b *testing.B) {
app := fiber.New()
c := New(Config{
AllowOrigins: "http://localhost,http://example.com",
AllowMethods: "GET,POST,PUT,DELETE",
AllowHeaders: "Origin,Content-Type,Accept",
AllowCredentials: true,
MaxAge: 600,
})
app.Use(c)
app.Use(func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
h := app.Handler()
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
ctx := &fasthttp.RequestCtx{}
req := &fasthttp.Request{}
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://localhost")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
ctx.Init(req, nil, nil)
for pb.Next() {
h(ctx)
}
})
}
// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerSingleOrigin -benchmem -count=4
func Benchmark_CORS_NewHandlerSingleOrigin(b *testing.B) {
app := fiber.New()
c := New(Config{
AllowOrigins: "http://example.com",
AllowMethods: "GET,POST,PUT,DELETE",
AllowHeaders: "Origin,Content-Type,Accept",
AllowCredentials: true,
MaxAge: 600,
})
app.Use(c)
app.Use(func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
h := app.Handler()
ctx := &fasthttp.RequestCtx{}
req := &fasthttp.Request{}
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
ctx.Init(req, nil, nil)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
h(ctx)
}
}
// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerSingleOriginParallel -benchmem -count=4
func Benchmark_CORS_NewHandlerSingleOriginParallel(b *testing.B) {
app := fiber.New()
c := New(Config{
AllowOrigins: "http://example.com",
AllowMethods: "GET,POST,PUT,DELETE",
AllowHeaders: "Origin,Content-Type,Accept",
AllowCredentials: true,
MaxAge: 600,
})
app.Use(c)
app.Use(func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
h := app.Handler()
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
ctx := &fasthttp.RequestCtx{}
req := &fasthttp.Request{}
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
ctx.Init(req, nil, nil)
for pb.Next() {
h(ctx)
}
})
}
// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerWildcard -benchmem -count=4
func Benchmark_CORS_NewHandlerWildcard(b *testing.B) {
app := fiber.New()
c := New(Config{
AllowOrigins: "*",
AllowMethods: "GET,POST,PUT,DELETE",
AllowHeaders: "Origin,Content-Type,Accept",
AllowCredentials: false,
MaxAge: 600,
})
app.Use(c)
app.Use(func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
h := app.Handler()
ctx := &fasthttp.RequestCtx{}
req := &fasthttp.Request{}
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
ctx.Init(req, nil, nil)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
h(ctx)
}
}
// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerWildcardParallel -benchmem -count=4
func Benchmark_CORS_NewHandlerWildcardParallel(b *testing.B) {
app := fiber.New()
c := New(Config{
AllowOrigins: "*",
AllowMethods: "GET,POST,PUT,DELETE",
AllowHeaders: "Origin,Content-Type,Accept",
AllowCredentials: false,
MaxAge: 600,
})
app.Use(c)
app.Use(func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
h := app.Handler()
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
ctx := &fasthttp.RequestCtx{}
req := &fasthttp.Request{}
req.Header.SetMethod(fiber.MethodGet)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
ctx.Init(req, nil, nil)
for pb.Next() {
h(ctx)
}
})
}
// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerPreflight -benchmem -count=4
func Benchmark_CORS_NewHandlerPreflight(b *testing.B) {
app := fiber.New()
c := New(Config{
AllowOrigins: "http://localhost,http://example.com",
AllowMethods: "GET,POST,PUT,DELETE",
AllowHeaders: "Origin,Content-Type,Accept",
AllowCredentials: true,
MaxAge: 600,
})
app.Use(c)
app.Use(func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
h := app.Handler()
ctx := &fasthttp.RequestCtx{}
req := &fasthttp.Request{}
req.Header.SetMethod(fiber.MethodOptions)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
ctx.Init(req, nil, nil)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
h(ctx)
}
}
// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerPreflightParallel -benchmem -count=4
func Benchmark_CORS_NewHandlerPreflightParallel(b *testing.B) {
app := fiber.New()
c := New(Config{
AllowOrigins: "http://localhost,http://example.com",
AllowMethods: "GET,POST,PUT,DELETE",
AllowHeaders: "Origin,Content-Type,Accept",
AllowCredentials: true,
MaxAge: 600,
})
app.Use(c)
app.Use(func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
h := app.Handler()
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
ctx := &fasthttp.RequestCtx{}
req := &fasthttp.Request{}
req.Header.SetMethod(fiber.MethodOptions)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
ctx.Init(req, nil, nil)
for pb.Next() {
h(ctx)
}
})
}
// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerPreflightSingleOrigin -benchmem -count=4
func Benchmark_CORS_NewHandlerPreflightSingleOrigin(b *testing.B) {
app := fiber.New()
c := New(Config{
AllowOrigins: "http://example.com",
AllowMethods: "GET,POST,PUT,DELETE",
AllowHeaders: "Origin,Content-Type,Accept",
AllowCredentials: true,
MaxAge: 600,
})
app.Use(c)
app.Use(func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
h := app.Handler()
ctx := &fasthttp.RequestCtx{}
req := &fasthttp.Request{}
req.Header.SetMethod(fiber.MethodOptions)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
ctx.Init(req, nil, nil)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
h(ctx)
}
}
// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerPreflightSingleOriginParallel -benchmem -count=4
func Benchmark_CORS_NewHandlerPreflightSingleOriginParallel(b *testing.B) {
app := fiber.New()
c := New(Config{
AllowOrigins: "http://example.com",
AllowMethods: "GET,POST,PUT,DELETE",
AllowHeaders: "Origin,Content-Type,Accept",
AllowCredentials: true,
MaxAge: 600,
})
app.Use(c)
app.Use(func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
h := app.Handler()
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
ctx := &fasthttp.RequestCtx{}
req := &fasthttp.Request{}
req.Header.SetMethod(fiber.MethodOptions)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
ctx.Init(req, nil, nil)
for pb.Next() {
h(ctx)
}
})
}
// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerPreflightWildcard -benchmem -count=4
func Benchmark_CORS_NewHandlerPreflightWildcard(b *testing.B) {
app := fiber.New()
c := New(Config{
AllowOrigins: "*",
AllowMethods: "GET,POST,PUT,DELETE",
AllowHeaders: "Origin,Content-Type,Accept",
AllowCredentials: false,
MaxAge: 600,
})
app.Use(c)
app.Use(func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
h := app.Handler()
ctx := &fasthttp.RequestCtx{}
req := &fasthttp.Request{}
req.Header.SetMethod(fiber.MethodOptions)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
ctx.Init(req, nil, nil)
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
h(ctx)
}
}
// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerPreflightWildcardParallel -benchmem -count=4
func Benchmark_CORS_NewHandlerPreflightWildcardParallel(b *testing.B) {
app := fiber.New()
c := New(Config{
AllowOrigins: "*",
AllowMethods: "GET,POST,PUT,DELETE",
AllowHeaders: "Origin,Content-Type,Accept",
AllowCredentials: false,
MaxAge: 600,
})
app.Use(c)
app.Use(func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusOK)
})
h := app.Handler()
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
ctx := &fasthttp.RequestCtx{}
req := &fasthttp.Request{}
req.Header.SetMethod(fiber.MethodOptions)
req.SetRequestURI("/")
req.Header.Set(fiber.HeaderOrigin, "http://example.com")
req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost)
req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept")
ctx.Init(req, nil, nil)
for pb.Next() {
h(ctx)
}
})
}