From 4ab862970609feeae08b33d17983b076f92e7254 Mon Sep 17 00:00:00 2001 From: Jason McNeil Date: Fri, 1 Mar 2024 10:31:11 +0100 Subject: [PATCH] fix(middleware/cors): Validation of multiple Origins (#2883) * fix: allow origins check Refactor CORS origin validation and normalization to trim leading or trailing whitespace in the cfg.AllowOrigins string [list]. URLs with whitespace inside the URL are invalid, so the normalizeOrigin will return false because url.Parse will fail, and the middleware will panic. fixes #2882 * test: AllowOrigins with whitespace * test(middleware/cors): add benchmarks * chore: fix linter errors * test(middleware/cors): use h() instead of app.Test() * test(middleware/cors): add miltiple origins in Test_CORS_AllowOriginScheme * chore: refactor validate and normalize * test(cors/middleware): add more benchmarks (cherry picked from commit d456e7d82ee087d5b2a5df5d4ab35f19c6397aae) --- middleware/cors/cors.go | 36 +-- middleware/cors/cors_test.go | 492 +++++++++++++++++++++++++++++++++++ 2 files changed, 513 insertions(+), 15 deletions(-) diff --git a/middleware/cors/cors.go b/middleware/cors/cors.go index de96c1a6..0869c4d9 100644 --- a/middleware/cors/cors.go +++ b/middleware/cors/cors.go @@ -113,23 +113,31 @@ func New(config ...Config) fiber.Handler { log.Panic("[CORS] Insecure setup, 'AllowCredentials' is set to true, and 'AllowOrigins' is set to a wildcard.") //nolint:revive // we want to exit the program } - // Validate and normalize static AllowOrigins if not using AllowOriginsFunc - if cfg.AllowOriginsFunc == nil && cfg.AllowOrigins != "" && cfg.AllowOrigins != "*" { - validatedOrigins := []string{} - for _, origin := range strings.Split(cfg.AllowOrigins, ",") { - isValid, normalizedOrigin := normalizeOrigin(origin) + // allowOrigins is a slice of strings that contains the allowed origins + // defined in the 'AllowOrigins' configuration. + var allowOrigins []string + + // Validate and normalize static AllowOrigins + if cfg.AllowOrigins != "" && cfg.AllowOrigins != "*" { + origins := strings.Split(cfg.AllowOrigins, ",") + allowOrigins = make([]string, len(origins)) + + for i, origin := range origins { + trimmedOrigin := strings.TrimSpace(origin) + isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin) + if isValid { - validatedOrigins = append(validatedOrigins, normalizedOrigin) + allowOrigins[i] = normalizedOrigin } else { - log.Panicf("[CORS] Invalid origin format in configuration: %s", origin) //nolint:revive // we want to exit the program + log.Panicf("[CORS] Invalid origin format in configuration: %s", trimmedOrigin) //nolint:revive // we want to exit the program } } - cfg.AllowOrigins = strings.Join(validatedOrigins, ",") + } else { + // If AllowOrigins is set to a wildcard or not set, + // set allowOrigins to a slice with a single element + allowOrigins = []string{cfg.AllowOrigins} } - // Convert string to slice - allowOrigins := strings.Split(strings.ReplaceAll(cfg.AllowOrigins, " ", ""), ",") - // Strip white spaces allowMethods := strings.ReplaceAll(cfg.AllowMethods, " ", "") allowHeaders := strings.ReplaceAll(cfg.AllowHeaders, " ", "") @@ -164,10 +172,8 @@ func New(config ...Config) fiber.Handler { // Run AllowOriginsFunc if the logic for // handling the value in 'AllowOrigins' does // not result in allowOrigin being set. - if allowOrigin == "" && cfg.AllowOriginsFunc != nil { - if cfg.AllowOriginsFunc(originHeader) { - allowOrigin = originHeader - } + if allowOrigin == "" && cfg.AllowOriginsFunc != nil && cfg.AllowOriginsFunc(originHeader) { + allowOrigin = originHeader } // Simple request diff --git a/middleware/cors/cors_test.go b/middleware/cors/cors_test.go index 3d3ae8c6..ae85a00d 100644 --- a/middleware/cors/cors_test.go +++ b/middleware/cors/cors_test.go @@ -306,6 +306,21 @@ func Test_CORS_AllowOriginScheme(t *testing.T) { reqOrigin: "http://ccc.bbb.example.com", shouldAllowOrigin: false, }, + { + pattern: "http://domain-1.com, http://example.com", + reqOrigin: "http://example.com", + shouldAllowOrigin: true, + }, + { + pattern: "http://domain-1.com, http://example.com", + reqOrigin: "http://domain-2.com", + shouldAllowOrigin: false, + }, + { + pattern: "http://domain-1.com,http://example.com", + reqOrigin: "http://domain-1.com", + shouldAllowOrigin: true, + }, } for _, tt := range tests { @@ -451,6 +466,33 @@ func Test_CORS_AllowOriginsAndAllowOriginsFunc_AllUseCases(t *testing.T) { RequestOrigin: "http://aaa.com", ResponseOrigin: "http://aaa.com", }, + { + Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/MultipleOrigins/NoWhitespace/OriginAllowed", + Config: Config{ + AllowOrigins: "http://aaa.com,http://bbb.com", + AllowOriginsFunc: nil, + }, + RequestOrigin: "http://bbb.com", + ResponseOrigin: "http://bbb.com", + }, + { + Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/MultipleOrigins/NoWhitespace/OriginNotAllowed", + Config: Config{ + AllowOrigins: "http://aaa.com,http://bbb.com", + AllowOriginsFunc: nil, + }, + RequestOrigin: "http://ccc.com", + ResponseOrigin: "", + }, + { + Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/MultipleOrigins/Whitespace/OriginAllowed", + Config: Config{ + AllowOrigins: "http://aaa.com, http://bbb.com", + AllowOriginsFunc: nil, + }, + RequestOrigin: "http://aaa.com", + ResponseOrigin: "http://aaa.com", + }, { Name: "AllowOriginsDefined/AllowOriginsFuncUndefined/OriginNotAllowed", Config: Config{ @@ -646,3 +688,453 @@ func Test_CORS_AllowCredentials(t *testing.T) { }) } } + +// go test -v -run=^$ -bench=Benchmark_CORS_NewHandler -benchmem -count=4 +func Benchmark_CORS_NewHandler(b *testing.B) { + app := fiber.New() + c := New(Config{ + AllowOrigins: "http://localhost,http://example.com", + AllowMethods: "GET,POST,PUT,DELETE", + AllowHeaders: "Origin,Content-Type,Accept", + AllowCredentials: true, + MaxAge: 600, + }) + + app.Use(c) + app.Use(func(c fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + ctx := &fasthttp.RequestCtx{} + + req := &fasthttp.Request{} + req.Header.SetMethod(fiber.MethodGet) + req.SetRequestURI("/") + req.Header.Set(fiber.HeaderOrigin, "http://localhost") + req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) + req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") + + ctx.Init(req, nil, nil) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + h(ctx) + } +} + +// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerParallel -benchmem -count=4 +func Benchmark_CORS_NewHandlerParallel(b *testing.B) { + app := fiber.New() + c := New(Config{ + AllowOrigins: "http://localhost,http://example.com", + AllowMethods: "GET,POST,PUT,DELETE", + AllowHeaders: "Origin,Content-Type,Accept", + AllowCredentials: true, + MaxAge: 600, + }) + + app.Use(c) + app.Use(func(c fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + + b.ReportAllocs() + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + ctx := &fasthttp.RequestCtx{} + + req := &fasthttp.Request{} + req.Header.SetMethod(fiber.MethodGet) + req.SetRequestURI("/") + req.Header.Set(fiber.HeaderOrigin, "http://localhost") + req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) + req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") + + ctx.Init(req, nil, nil) + + for pb.Next() { + h(ctx) + } + }) +} + +// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerSingleOrigin -benchmem -count=4 +func Benchmark_CORS_NewHandlerSingleOrigin(b *testing.B) { + app := fiber.New() + c := New(Config{ + AllowOrigins: "http://example.com", + AllowMethods: "GET,POST,PUT,DELETE", + AllowHeaders: "Origin,Content-Type,Accept", + AllowCredentials: true, + MaxAge: 600, + }) + + app.Use(c) + app.Use(func(c fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + ctx := &fasthttp.RequestCtx{} + + req := &fasthttp.Request{} + req.Header.SetMethod(fiber.MethodGet) + req.SetRequestURI("/") + req.Header.Set(fiber.HeaderOrigin, "http://example.com") + req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) + req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") + + ctx.Init(req, nil, nil) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + h(ctx) + } +} + +// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerSingleOriginParallel -benchmem -count=4 +func Benchmark_CORS_NewHandlerSingleOriginParallel(b *testing.B) { + app := fiber.New() + c := New(Config{ + AllowOrigins: "http://example.com", + AllowMethods: "GET,POST,PUT,DELETE", + AllowHeaders: "Origin,Content-Type,Accept", + AllowCredentials: true, + MaxAge: 600, + }) + + app.Use(c) + app.Use(func(c fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + + b.ReportAllocs() + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + ctx := &fasthttp.RequestCtx{} + + req := &fasthttp.Request{} + req.Header.SetMethod(fiber.MethodGet) + req.SetRequestURI("/") + req.Header.Set(fiber.HeaderOrigin, "http://example.com") + req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) + req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") + + ctx.Init(req, nil, nil) + + for pb.Next() { + h(ctx) + } + }) +} + +// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerWildcard -benchmem -count=4 +func Benchmark_CORS_NewHandlerWildcard(b *testing.B) { + app := fiber.New() + c := New(Config{ + AllowOrigins: "*", + AllowMethods: "GET,POST,PUT,DELETE", + AllowHeaders: "Origin,Content-Type,Accept", + AllowCredentials: false, + MaxAge: 600, + }) + + app.Use(c) + app.Use(func(c fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + ctx := &fasthttp.RequestCtx{} + + req := &fasthttp.Request{} + req.Header.SetMethod(fiber.MethodGet) + req.SetRequestURI("/") + req.Header.Set(fiber.HeaderOrigin, "http://example.com") + req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) + req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") + + ctx.Init(req, nil, nil) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + h(ctx) + } +} + +// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerWildcardParallel -benchmem -count=4 +func Benchmark_CORS_NewHandlerWildcardParallel(b *testing.B) { + app := fiber.New() + c := New(Config{ + AllowOrigins: "*", + AllowMethods: "GET,POST,PUT,DELETE", + AllowHeaders: "Origin,Content-Type,Accept", + AllowCredentials: false, + MaxAge: 600, + }) + + app.Use(c) + app.Use(func(c fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + + b.ReportAllocs() + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + ctx := &fasthttp.RequestCtx{} + + req := &fasthttp.Request{} + req.Header.SetMethod(fiber.MethodGet) + req.SetRequestURI("/") + req.Header.Set(fiber.HeaderOrigin, "http://example.com") + req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodGet) + req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") + + ctx.Init(req, nil, nil) + + for pb.Next() { + h(ctx) + } + }) +} + +// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerPreflight -benchmem -count=4 +func Benchmark_CORS_NewHandlerPreflight(b *testing.B) { + app := fiber.New() + c := New(Config{ + AllowOrigins: "http://localhost,http://example.com", + AllowMethods: "GET,POST,PUT,DELETE", + AllowHeaders: "Origin,Content-Type,Accept", + AllowCredentials: true, + MaxAge: 600, + }) + + app.Use(c) + app.Use(func(c fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + ctx := &fasthttp.RequestCtx{} + + req := &fasthttp.Request{} + req.Header.SetMethod(fiber.MethodOptions) + req.SetRequestURI("/") + req.Header.Set(fiber.HeaderOrigin, "http://example.com") + req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost) + req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") + + ctx.Init(req, nil, nil) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + h(ctx) + } +} + +// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerPreflightParallel -benchmem -count=4 +func Benchmark_CORS_NewHandlerPreflightParallel(b *testing.B) { + app := fiber.New() + c := New(Config{ + AllowOrigins: "http://localhost,http://example.com", + AllowMethods: "GET,POST,PUT,DELETE", + AllowHeaders: "Origin,Content-Type,Accept", + AllowCredentials: true, + MaxAge: 600, + }) + + app.Use(c) + app.Use(func(c fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + + b.ReportAllocs() + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + ctx := &fasthttp.RequestCtx{} + + req := &fasthttp.Request{} + req.Header.SetMethod(fiber.MethodOptions) + req.SetRequestURI("/") + req.Header.Set(fiber.HeaderOrigin, "http://example.com") + req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost) + req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") + + ctx.Init(req, nil, nil) + + for pb.Next() { + h(ctx) + } + }) +} + +// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerPreflightSingleOrigin -benchmem -count=4 +func Benchmark_CORS_NewHandlerPreflightSingleOrigin(b *testing.B) { + app := fiber.New() + c := New(Config{ + AllowOrigins: "http://example.com", + AllowMethods: "GET,POST,PUT,DELETE", + AllowHeaders: "Origin,Content-Type,Accept", + AllowCredentials: true, + MaxAge: 600, + }) + + app.Use(c) + app.Use(func(c fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + ctx := &fasthttp.RequestCtx{} + + req := &fasthttp.Request{} + req.Header.SetMethod(fiber.MethodOptions) + req.SetRequestURI("/") + req.Header.Set(fiber.HeaderOrigin, "http://example.com") + req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost) + req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") + + ctx.Init(req, nil, nil) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + h(ctx) + } +} + +// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerPreflightSingleOriginParallel -benchmem -count=4 +func Benchmark_CORS_NewHandlerPreflightSingleOriginParallel(b *testing.B) { + app := fiber.New() + c := New(Config{ + AllowOrigins: "http://example.com", + AllowMethods: "GET,POST,PUT,DELETE", + AllowHeaders: "Origin,Content-Type,Accept", + AllowCredentials: true, + MaxAge: 600, + }) + + app.Use(c) + app.Use(func(c fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + + b.ReportAllocs() + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + ctx := &fasthttp.RequestCtx{} + + req := &fasthttp.Request{} + req.Header.SetMethod(fiber.MethodOptions) + req.SetRequestURI("/") + req.Header.Set(fiber.HeaderOrigin, "http://example.com") + req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost) + req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") + + ctx.Init(req, nil, nil) + + for pb.Next() { + h(ctx) + } + }) +} + +// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerPreflightWildcard -benchmem -count=4 +func Benchmark_CORS_NewHandlerPreflightWildcard(b *testing.B) { + app := fiber.New() + c := New(Config{ + AllowOrigins: "*", + AllowMethods: "GET,POST,PUT,DELETE", + AllowHeaders: "Origin,Content-Type,Accept", + AllowCredentials: false, + MaxAge: 600, + }) + + app.Use(c) + app.Use(func(c fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + ctx := &fasthttp.RequestCtx{} + + req := &fasthttp.Request{} + req.Header.SetMethod(fiber.MethodOptions) + req.SetRequestURI("/") + req.Header.Set(fiber.HeaderOrigin, "http://example.com") + req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost) + req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") + + ctx.Init(req, nil, nil) + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + h(ctx) + } +} + +// go test -v -run=^$ -bench=Benchmark_CORS_NewHandlerPreflightWildcardParallel -benchmem -count=4 +func Benchmark_CORS_NewHandlerPreflightWildcardParallel(b *testing.B) { + app := fiber.New() + c := New(Config{ + AllowOrigins: "*", + AllowMethods: "GET,POST,PUT,DELETE", + AllowHeaders: "Origin,Content-Type,Accept", + AllowCredentials: false, + MaxAge: 600, + }) + + app.Use(c) + app.Use(func(c fiber.Ctx) error { + return c.SendStatus(fiber.StatusOK) + }) + + h := app.Handler() + + b.ReportAllocs() + b.ResetTimer() + + b.RunParallel(func(pb *testing.PB) { + ctx := &fasthttp.RequestCtx{} + + req := &fasthttp.Request{} + req.Header.SetMethod(fiber.MethodOptions) + req.SetRequestURI("/") + req.Header.Set(fiber.HeaderOrigin, "http://example.com") + req.Header.Set(fiber.HeaderAccessControlRequestMethod, fiber.MethodPost) + req.Header.Set(fiber.HeaderAccessControlRequestHeaders, "Origin,Content-Type,Accept") + + ctx.Init(req, nil, nil) + + for pb.Next() { + h(ctx) + } + }) +}