1
0
mirror of https://github.com/gofiber/fiber.git synced 2025-02-19 04:27:53 +00:00
fiber/middleware/cors/cors_test.go
Muhammed Efe Çetin 1188144d78
🎉 v3: init
2022-05-31 17:35:49 +03:00

238 lines
7.2 KiB
Go

package cors
import (
"net/http/httptest"
"testing"
"github.com/gofiber/fiber/v3"
"github.com/gofiber/fiber/v3/utils"
"github.com/valyala/fasthttp"
)
func Test_CORS_Defaults(t *testing.T) {
app := fiber.New()
app.Use(New())
testDefaultOrEmptyConfig(t, app)
}
func Test_CORS_Empty_Config(t *testing.T) {
app := fiber.New()
app.Use(New(Config{}))
testDefaultOrEmptyConfig(t, app)
}
func testDefaultOrEmptyConfig(t *testing.T, app *fiber.App) {
t.Helper()
h := app.Handler()
// Test default GET response headers
ctx := &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx)
utils.AssertEqual(t, "*", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlExposeHeaders)))
// Test default OPTIONS (preflight) response headers
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodOptions)
h(ctx)
utils.AssertEqual(t, "GET,POST,HEAD,PUT,DELETE,PATCH", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowMethods)))
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)))
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge)))
}
// go test -run -v Test_CORS_Wildcard
func Test_CORS_Wildcard(t *testing.T) {
// New fiber instance
app := fiber.New()
// OPTIONS (preflight) response headers when AllowOrigins is *
app.Use(New(Config{
AllowOrigins: "*",
AllowCredentials: true,
MaxAge: 3600,
ExposeHeaders: "X-Request-ID",
AllowHeaders: "Authentication",
}))
// Get handler pointer
handler := app.Handler()
// Make request
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.Set(fiber.HeaderOrigin, "localhost")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
// Perform request
handler(ctx)
// Check result
utils.AssertEqual(t, "localhost", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
utils.AssertEqual(t, "true", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
utils.AssertEqual(t, "3600", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlMaxAge)))
utils.AssertEqual(t, "Authentication", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowHeaders)))
// Test non OPTIONS (preflight) response headers
ctx = &fasthttp.RequestCtx{}
ctx.Request.Header.SetMethod(fiber.MethodGet)
handler(ctx)
utils.AssertEqual(t, "true", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowCredentials)))
utils.AssertEqual(t, "X-Request-ID", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlExposeHeaders)))
}
// go test -run -v Test_CORS_Subdomain
func Test_CORS_Subdomain(t *testing.T) {
// New fiber instance
app := fiber.New()
// OPTIONS (preflight) response headers when AllowOrigins is set to a subdomain
app.Use("/", New(Config{AllowOrigins: "http://*.example.com"}))
// Get handler pointer
handler := app.Handler()
// Make request with disallowed origin
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com")
// Perform request
handler(ctx)
// Allow-Origin header should be "" because http://google.com does not satisfy http://*.example.com
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
ctx.Request.Reset()
ctx.Response.Reset()
// Make request with allowed origin
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://test.example.com")
handler(ctx)
utils.AssertEqual(t, "http://test.example.com", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
}
func Test_CORS_AllowOriginScheme(t *testing.T) {
tests := []struct {
reqOrigin, pattern string
shouldAllowOrigin bool
}{
{
pattern: "http://example.com",
reqOrigin: "http://example.com",
shouldAllowOrigin: true,
},
{
pattern: "https://example.com",
reqOrigin: "https://example.com",
shouldAllowOrigin: true,
},
{
pattern: "http://example.com",
reqOrigin: "https://example.com",
shouldAllowOrigin: false,
},
{
pattern: "http://*.example.com",
reqOrigin: "http://aaa.example.com",
shouldAllowOrigin: true,
},
{
pattern: "http://*.example.com",
reqOrigin: "http://bbb.aaa.example.com",
shouldAllowOrigin: true,
},
{
pattern: "http://*.aaa.example.com",
reqOrigin: "http://bbb.aaa.example.com",
shouldAllowOrigin: true,
},
{
pattern: "http://*.example.com:8080",
reqOrigin: "http://aaa.example.com:8080",
shouldAllowOrigin: true,
},
{
pattern: "http://example.com",
reqOrigin: "http://gofiber.com",
shouldAllowOrigin: false,
},
{
pattern: "http://*.aaa.example.com",
reqOrigin: "http://ccc.bbb.example.com",
shouldAllowOrigin: false,
},
{
pattern: "http://*.example.com",
reqOrigin: `http://1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\
.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\
.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890\
.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.1234567890.example.com`,
shouldAllowOrigin: false,
},
{
pattern: "http://example.com",
reqOrigin: "http://ccc.bbb.example.com",
shouldAllowOrigin: false,
},
{
pattern: "https://*--aaa.bbb.com",
reqOrigin: "https://prod-preview--aaa.bbb.com",
shouldAllowOrigin: false,
},
{
pattern: "http://*.example.com",
reqOrigin: "http://ccc.bbb.example.com",
shouldAllowOrigin: true,
},
{
pattern: "http://foo.[a-z]*.example.com",
reqOrigin: "http://ccc.bbb.example.com",
shouldAllowOrigin: false,
},
}
for _, tt := range tests {
app := fiber.New()
app.Use("/", New(Config{AllowOrigins: tt.pattern}))
handler := app.Handler()
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderOrigin, tt.reqOrigin)
handler(ctx)
if tt.shouldAllowOrigin {
utils.AssertEqual(t, tt.reqOrigin, string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
} else {
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
}
}
}
// go test -run Test_CORS_Next
func Test_CORS_Next(t *testing.T) {
app := fiber.New()
app.Use(New(Config{
Next: func(_ *fiber.Ctx) bool {
return true
},
}))
resp, err := app.Test(httptest.NewRequest("GET", "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}