1
0
mirror of https://github.com/gofiber/fiber.git synced 2025-02-06 22:31:54 +00:00

feat(middleware/csrf): TrustedOrigins using https://*.example.com style subdomains (#2925)

* feat(middleware/csrf): TrustedOrigins using https://*.example.com style subdomains

* Update middleware/csrf/csrf_test.go

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* test(middleware/csrf): parallel test

* test(middleware/csrf): parallel fix

* chmore(middleware/csrf): no pkg/log

* feat(middleware/csrf): Add tests for Trusted Origin deeply nested subdomain

* test(middleware/csrf): fix loop variable tt being captured

* docs(middleware/csrf): TrustedOrigins validates and normalizes note

* test(middleware/csrf): fix Benchmark_Middleware_CSRF_Check

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: Juan Calderon-Perez <835733+gaby@users.noreply.github.com>
This commit is contained in:
Jason McNeil 2024-03-25 11:29:37 -03:00 committed by GitHub
parent 95c181469d
commit 643b4b3f53
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 381 additions and 86 deletions

View File

@ -116,7 +116,7 @@ func (h *Handler) DeleteToken(c fiber.Ctx) error
| Storage | `fiber.Storage` | Store is used to store the state of the middleware. | `nil` |
| Session | `*session.Store` | Session is used to store the state of the middleware. Overrides Storage if set. | `nil` |
| SessionKey | `string` | SessionKey is the key used to store the token in the session. | "csrfToken" |
| TrustedOrigins | `[]string` | TrustedOrigins is a list of trusted origins for unsafe requests. This supports subdomain matching, so you can use a value like "https://.example.com" to allow any subdomain of example.com to submit requests. | `[]` |
| TrustedOrigins | `[]string` | TrustedOrigins is a list of trusted origins for unsafe requests. This supports subdomain matching, so you can use a value like "https://*.example.com" to allow any subdomain of example.com to submit requests. | `[]` |
### Default Config
@ -154,6 +154,36 @@ var ConfigDefault = Config{
}
```
### Trusted Origins
The `TrustedOrigins` option is used to specify a list of trusted origins for unsafe requests. This is useful when you want to allow requests from other origins. This supports matching subdomains at any level. This means you can use a value like `"https://*.example.com"` to allow any subdomain of `example.com` to submit requests, including multiple subdomain levels such as `"https://sub.sub.example.com"`.
To ensure that the provided `TrustedOrigins` origins are correctly formatted, this middleware validates and normalizes them. It checks for valid schemes, i.e., HTTP or HTTPS, and it will automatically remove trailing slashes. If the provided origin is invalid, the middleware will panic.
#### Example with Explicit Origins
In the following example, the CSRF middleware will allow requests from `trusted.example.com`, in addition to the current host.
```go
app.Use(csrf.New(csrf.Config{
TrustedOrigins: []string{"https://trusted.example.com"},
}))
```
#### Example with Subdomain Matching
In the following example, the CSRF middleware will allow requests from any subdomain of `example.com`, in addition to the current host.
```go
app.Use(csrf.New(csrf.Config{
TrustedOrigins: []string{"https://*.example.com"},
}))
```
::caution
When using `TrustedOrigins` with subdomain matching, make sure you control and trust all the subdomains, including all subdomain levels. If not, an attacker could create a subdomain under a trusted origin and use it to send harmful requests.
:::
## Constants
```go
@ -273,7 +303,6 @@ When HTTPS requests are protected by CSRF, referer checking is always carried ou
The Referer header is automatically included in requests by all modern browsers, including those made using the JS Fetch API. However, if you're making use of this middleware with a custom client, it's important to ensure that the client sends a valid Referer header.
:::
### Token Lifecycle
Tokens are valid until they expire or until they are deleted. By default, tokens are valid for 1 hour, and each subsequent request extends the expiration by 1 hour. The token only expires if the user doesn't make a request for the duration of the expiration time.

View File

@ -24,7 +24,7 @@ var (
// Handler for CSRF middleware
type Handler struct {
config *Config
config Config
sessionManager *sessionManager
storageManager *storageManager
}
@ -56,6 +56,36 @@ func New(config ...Config) fiber.Handler {
storageManager = newStorageManager(cfg.Storage)
}
// Pre-parse trusted origins
trustedOrigins := []string{}
trustedSubOrigins := []subdomain{}
for _, origin := range cfg.TrustedOrigins {
if i := strings.Index(origin, "://*."); i != -1 {
trimmedOrigin := strings.TrimSpace(origin[:i+3] + origin[i+4:])
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
if !isValid {
panic("[CSRF] Invalid origin format in configuration:" + origin)
}
sd := subdomain{prefix: normalizedOrigin[:i+3], suffix: normalizedOrigin[i+3:]}
trustedSubOrigins = append(trustedSubOrigins, sd)
} else {
trimmedOrigin := strings.TrimSpace(origin)
isValid, normalizedOrigin := normalizeOrigin(trimmedOrigin)
if !isValid {
panic("[CSRF] Invalid origin format in configuration:" + origin)
}
trustedOrigins = append(trustedOrigins, normalizedOrigin)
}
}
// Create the handler outside of the returned function
handler := &Handler{
config: cfg,
sessionManager: sessionManager,
storageManager: storageManager,
}
// Return new handler
return func(c fiber.Ctx) error {
// Don't execute middleware if Next returns true
@ -64,11 +94,7 @@ func New(config ...Config) fiber.Handler {
}
// Store the CSRF handler in the context
c.Locals(handlerKey, &Handler{
config: &cfg,
sessionManager: sessionManager,
storageManager: storageManager,
})
c.Locals(handlerKey, handler)
var token string
@ -88,12 +114,12 @@ func New(config ...Config) fiber.Handler {
// Assume that anything not defined as 'safe' by RFC7231 needs protection
// Enforce an origin check for unsafe requests.
err := originMatchesHost(c, cfg.TrustedOrigins)
err := originMatchesHost(c, trustedOrigins, trustedSubOrigins)
// If there's no origin, enforce a referer check for HTTPS connections.
if errors.Is(err, errOriginNotFound) {
if c.Scheme() == "https" {
err = refererMatchesHost(c, cfg.TrustedOrigins)
err = refererMatchesHost(c, trustedOrigins, trustedSubOrigins)
} else {
// If it's not HTTPS, clear the error to allow the request to proceed.
err = nil
@ -237,20 +263,15 @@ func setCSRFCookie(c fiber.Ctx, cfg Config, token string, expiry time.Duration)
// DeleteToken removes the token found in the context from the storage
// and expires the CSRF cookie
func (handler *Handler) DeleteToken(c fiber.Ctx) error {
// Get the config from the context
config := handler.config
if config == nil {
panic("CSRF Handler config not found in context")
}
// Extract token from the client request cookie
cookieToken := c.Cookies(config.CookieName)
cookieToken := c.Cookies(handler.config.CookieName)
if cookieToken == "" {
return config.ErrorHandler(c, ErrTokenNotFound)
return handler.config.ErrorHandler(c, ErrTokenNotFound)
}
// Remove the token from storage
deleteTokenFromStorage(c, cookieToken, *config, handler.sessionManager, handler.storageManager)
deleteTokenFromStorage(c, cookieToken, handler.config, handler.sessionManager, handler.storageManager)
// Expire the cookie
expireCSRFCookie(c, *config)
expireCSRFCookie(c, handler.config)
return nil
}
@ -262,8 +283,8 @@ func isFromCookie(extractor any) bool {
// originMatchesHost checks that the origin header matches the host header
// returns an error if the origin header is not present or is invalid
// returns nil if the origin header is valid
func originMatchesHost(c fiber.Ctx, trustedOrigins []string) error {
origin := c.Get(fiber.HeaderOrigin)
func originMatchesHost(c fiber.Ctx, trustedOrigins []string, trustedSubOrigins []subdomain) error {
origin := strings.ToLower(c.Get(fiber.HeaderOrigin))
if origin == "" || origin == "null" { // "null" is set by some browsers when the origin is a secure context https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Origin#description
return errOriginNotFound
}
@ -273,23 +294,31 @@ func originMatchesHost(c fiber.Ctx, trustedOrigins []string) error {
return ErrOriginInvalid
}
if originURL.Host != c.Host() {
for _, trustedOrigin := range trustedOrigins {
if isTrustedSchemeAndDomain(trustedOrigin, origin) {
if originURL.Scheme == c.Scheme() && originURL.Host == c.Host() {
return nil
}
}
return ErrOriginNoMatch
}
for _, trustedOrigin := range trustedOrigins {
if origin == trustedOrigin {
return nil
}
}
for _, trustedSubOrigin := range trustedSubOrigins {
if trustedSubOrigin.match(origin) {
return nil
}
}
return ErrOriginNoMatch
}
// refererMatchesHost checks that the referer header matches the host header
// returns an error if the referer header is not present or is invalid
// returns nil if the referer header is valid
func refererMatchesHost(c fiber.Ctx, trustedOrigins []string) error {
referer := c.Get(fiber.HeaderReferer)
func refererMatchesHost(c fiber.Ctx, trustedOrigins []string, trustedSubOrigins []subdomain) error {
referer := strings.ToLower(c.Get(fiber.HeaderReferer))
if referer == "" {
return ErrRefererNotFound
}
@ -299,41 +328,23 @@ func refererMatchesHost(c fiber.Ctx, trustedOrigins []string) error {
return ErrRefererInvalid
}
if refererURL.Host != c.Host() {
if refererURL.Scheme == c.Scheme() && refererURL.Host == c.Host() {
return nil
}
referer = refererURL.String()
for _, trustedOrigin := range trustedOrigins {
if isTrustedSchemeAndDomain(trustedOrigin, referer) {
if referer == trustedOrigin {
return nil
}
}
for _, trustedSubOrigin := range trustedSubOrigins {
if trustedSubOrigin.match(referer) {
return nil
}
}
return ErrRefererNoMatch
}
return nil
}
// isTrustedSchemeAndDomain checks if the trustedProtoDomain is the same as the protoDomain
// or if the protoDomain is a subdomain of the trustedProtoDomain where trustedProtoDomain
// is prefixed with "https://." or "http://."
func isTrustedSchemeAndDomain(trustedProtoDomain, protoDomain string) bool {
if trustedProtoDomain == protoDomain {
return true
}
// Use constant prefixes for better readability and avoid magic numbers.
const httpsPrefix = "https://."
const httpPrefix = "http://."
if strings.HasPrefix(trustedProtoDomain, httpsPrefix) {
trustedProtoDomain = trustedProtoDomain[len(httpsPrefix):]
protoDomain = strings.TrimPrefix(protoDomain, "https://")
return strings.HasSuffix(protoDomain, "."+trustedProtoDomain)
}
if strings.HasPrefix(trustedProtoDomain, httpPrefix) {
trustedProtoDomain = trustedProtoDomain[len(httpPrefix):]
protoDomain = strings.TrimPrefix(protoDomain, "http://")
return strings.HasSuffix(protoDomain, "."+trustedProtoDomain)
}
return false
}

View File

@ -733,18 +733,6 @@ func Test_CSRF_Origin(t *testing.T) {
h(ctx)
require.Equal(t, 403, ctx.Response.StatusCode())
// Test Correct Origin with path
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "http")
ctx.Request.Header.Set(fiber.HeaderXForwardedHost, "example.com")
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example.com/action/items?gogogo=true")
ctx.Request.Header.Set(HeaderName, token)
ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
h(ctx)
require.Equal(t, 200, ctx.Response.StatusCode())
// Test Wrong Origin
ctx.Request.Reset()
ctx.Response.Reset()
@ -767,8 +755,8 @@ func Test_CSRF_TrustedOrigins(t *testing.T) {
TrustedOrigins: []string{
"http://safe.example.com",
"https://safe.example.com",
"http://.domain-1.com",
"https://.domain-1.com",
"http://*.domain-1.com",
"https://*.domain-1.com",
},
}))
@ -812,6 +800,20 @@ func Test_CSRF_TrustedOrigins(t *testing.T) {
h(ctx)
require.Equal(t, 200, ctx.Response.StatusCode())
// Test Trusted Origin deeply nested subdomain
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.URI().SetScheme("https")
ctx.Request.URI().SetHost("a.b.c.domain-1.com")
ctx.Request.Header.SetProtocol("https")
ctx.Request.Header.SetHost("a.b.c.domain-1.com")
ctx.Request.Header.Set(fiber.HeaderOrigin, "https://a.b.c.domain-1.com")
ctx.Request.Header.Set(HeaderName, token)
ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
h(ctx)
require.Equal(t, 200, ctx.Response.StatusCode())
// Test Trusted Origin Invalid
ctx.Request.Reset()
ctx.Response.Reset()
@ -856,6 +858,21 @@ func Test_CSRF_TrustedOrigins(t *testing.T) {
h(ctx)
require.Equal(t, 200, ctx.Response.StatusCode())
// Test Trusted Referer deeply nested subdomain
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
ctx.Request.URI().SetScheme("https")
ctx.Request.URI().SetHost("a.b.c.domain-1.com")
ctx.Request.Header.SetProtocol("https")
ctx.Request.Header.SetHost("a.b.c.domain-1.com")
ctx.Request.Header.Set(fiber.HeaderReferer, "https://a.b.c.domain-1.com")
ctx.Request.Header.Set(HeaderName, token)
ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
h(ctx)
require.Equal(t, 200, ctx.Response.StatusCode())
// Test Trusted Referer Invalid
ctx.Request.Reset()
ctx.Response.Reset()
@ -872,6 +889,37 @@ func Test_CSRF_TrustedOrigins(t *testing.T) {
require.Equal(t, 403, ctx.Response.StatusCode())
}
func Test_CSRF_TrustedOrigins_InvalidOrigins(t *testing.T) {
t.Parallel()
tests := []struct {
name string
origin string
}{
{"No Scheme", "localhost"},
{"Wildcard", "https://*"},
{"Wildcard domain", "https://*example.com"},
{"File Scheme", "file://example.com"},
{"FTP Scheme", "ftp://example.com"},
{"Port Wildcard", "http://example.com:*"},
{"Multiple Wildcards", "https://*.*.com"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
origin := tt.origin
t.Parallel()
require.Panics(t, func() {
app := fiber.New()
app.Use(New(Config{
CookieSecure: true,
TrustedOrigins: []string{origin},
}))
}, "Expected panic")
})
}
}
func Test_CSRF_Referer(t *testing.T) {
t.Parallel()
app := fiber.New()
@ -979,6 +1027,18 @@ func Test_CSRF_DeleteToken(t *testing.T) {
h := app.Handler()
ctx := &fasthttp.RequestCtx{}
// DeleteToken after token generation and remove the cookie
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.Set(HeaderName, "")
handler := HandlerFromContext(app.AcquireCtx(ctx))
if handler != nil {
ctx.Request.Header.DelAllCookies()
err := handler.DeleteToken(app.AcquireCtx(ctx))
require.ErrorIs(t, err, ErrTokenNotFound)
}
h(ctx)
// Generate CSRF token
ctx.Request.Header.SetMethod(fiber.MethodGet)
h(ctx)
@ -991,7 +1051,7 @@ func Test_CSRF_DeleteToken(t *testing.T) {
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(HeaderName, token)
ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
handler := HandlerFromContext(app.AcquireCtx(ctx))
handler = HandlerFromContext(app.AcquireCtx(ctx))
if handler != nil {
if err := handler.DeleteToken(app.AcquireCtx(ctx)); err != nil {
t.Fatal(err)
@ -1270,7 +1330,10 @@ func Benchmark_Middleware_CSRF_Check(b *testing.B) {
return c.SendStatus(fiber.StatusTeapot)
})
fctx := &fasthttp.RequestCtx{}
app.Post("/", func(c fiber.Ctx) error {
return c.SendStatus(fiber.StatusTeapot)
})
h := app.Handler()
ctx := &fasthttp.RequestCtx{}
@ -1280,17 +1343,27 @@ func Benchmark_Middleware_CSRF_Check(b *testing.B) {
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
// Test Correct Referer POST
ctx.Request.Reset()
ctx.Response.Reset()
ctx.Request.Header.SetMethod(fiber.MethodPost)
ctx.Request.Header.Set(fiber.HeaderXForwardedProto, "https")
ctx.Request.URI().SetScheme("https")
ctx.Request.URI().SetHost("example.com")
ctx.Request.Header.SetProtocol("https")
ctx.Request.Header.SetHost("example.com")
ctx.Request.Header.Set(fiber.HeaderReferer, "https://example.com")
ctx.Request.Header.Set(HeaderName, token)
ctx.Request.Header.SetCookie(ConfigDefault.CookieName, token)
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
h(fctx)
h(ctx)
}
require.Equal(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
require.Equal(b, fiber.StatusTeapot, ctx.Response.Header.StatusCode())
}
// go test -v -run=^$ -bench=Benchmark_Middleware_CSRF_GenerateToken -benchmem -count=4
@ -1302,7 +1375,6 @@ func Benchmark_Middleware_CSRF_GenerateToken(b *testing.B) {
return c.SendStatus(fiber.StatusTeapot)
})
fctx := &fasthttp.RequestCtx{}
h := app.Handler()
ctx := &fasthttp.RequestCtx{}
@ -1312,10 +1384,11 @@ func Benchmark_Middleware_CSRF_GenerateToken(b *testing.B) {
b.ResetTimer()
for n := 0; n < b.N; n++ {
h(fctx)
h(ctx)
}
require.Equal(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
// Ensure the GET request returns a 418 status code
require.Equal(b, fiber.StatusTeapot, ctx.Response.Header.StatusCode())
}
func Test_CSRF_InvalidURLHeaders(t *testing.T) {

View File

@ -2,6 +2,8 @@ package csrf
import (
"crypto/subtle"
"net/url"
"strings"
)
func compareTokens(a, b []byte) bool {
@ -11,3 +13,45 @@ func compareTokens(a, b []byte) bool {
func compareStrings(a, b string) bool {
return subtle.ConstantTimeCompare([]byte(a), []byte(b)) == 1
}
// normalizeOrigin checks if the provided origin is in a correct format
// and normalizes it by removing any path or trailing slash.
// It returns a boolean indicating whether the origin is valid
// and the normalized origin.
func normalizeOrigin(origin string) (bool, string) {
parsedOrigin, err := url.Parse(origin)
if err != nil {
return false, ""
}
// Validate the scheme is either http or https
if parsedOrigin.Scheme != "http" && parsedOrigin.Scheme != "https" {
return false, ""
}
// Don't allow a wildcard with a protocol
// wildcards cannot be used within any other value. For example, the following header is not valid:
// Access-Control-Allow-Origin: https://*
if strings.Contains(parsedOrigin.Host, "*") {
return false, ""
}
// Validate there is a host present. The presence of a path, query, or fragment components
// is checked, but a trailing "/" (indicative of the root) is allowed for the path and will be normalized
if parsedOrigin.Host == "" || (parsedOrigin.Path != "" && parsedOrigin.Path != "/") || parsedOrigin.RawQuery != "" || parsedOrigin.Fragment != "" {
return false, ""
}
// Normalize the origin by constructing it from the scheme and host.
// The path or trailing slash is not included in the normalized origin.
return true, strings.ToLower(parsedOrigin.Scheme) + "://" + strings.ToLower(parsedOrigin.Host)
}
type subdomain struct {
prefix string
suffix string
}
func (s subdomain) match(o string) bool {
return len(o) >= len(s.prefix)+len(s.suffix) && strings.HasPrefix(o, s.prefix) && strings.HasSuffix(o, s.suffix)
}

View File

@ -0,0 +1,138 @@
package csrf
import (
"testing"
"github.com/stretchr/testify/assert"
)
// go test -run -v Test_normalizeOrigin
func Test_normalizeOrigin(t *testing.T) {
testCases := []struct {
origin string
expectedValid bool
expectedOrigin string
}{
{"http://example.com", true, "http://example.com"}, // Simple case should work.
{"HTTP://EXAMPLE.COM", true, "http://example.com"}, // Case should be normalized.
{"http://example.com/", true, "http://example.com"}, // Trailing slash should be removed.
{"http://example.com:3000", true, "http://example.com:3000"}, // Port should be preserved.
{"http://example.com:3000/", true, "http://example.com:3000"}, // Trailing slash should be removed.
{"http://", false, ""}, // Invalid origin should not be accepted.
{"file:///etc/passwd", false, ""}, // File scheme should not be accepted.
{"https://*example.com", false, ""}, // Wildcard domain should not be accepted.
{"http://*.example.com", false, ""}, // Wildcard subdomain should not be accepted.
{"http://example.com/path", false, ""}, // Path should not be accepted.
{"http://example.com?query=123", false, ""}, // Query should not be accepted.
{"http://example.com#fragment", false, ""}, // Fragment should not be accepted.
{"http://localhost", true, "http://localhost"}, // Localhost should be accepted.
{"http://127.0.0.1", true, "http://127.0.0.1"}, // IPv4 address should be accepted.
{"http://[::1]", true, "http://[::1]"}, // IPv6 address should be accepted.
{"http://[::1]:8080", true, "http://[::1]:8080"}, // IPv6 address with port should be accepted.
{"http://[::1]:8080/", true, "http://[::1]:8080"}, // IPv6 address with port and trailing slash should be accepted.
{"http://[::1]:8080/path", false, ""}, // IPv6 address with port and path should not be accepted.
{"http://[::1]:8080?query=123", false, ""}, // IPv6 address with port and query should not be accepted.
{"http://[::1]:8080#fragment", false, ""}, // IPv6 address with port and fragment should not be accepted.
{"http://[::1]:8080/path?query=123#fragment", false, ""}, // IPv6 address with port, path, query, and fragment should not be accepted.
{"http://[::1]:8080/path?query=123#fragment/", false, ""}, // IPv6 address with port, path, query, fragment, and trailing slash should not be accepted.
{"http://[::1]:8080/path?query=123#fragment/invalid", false, ""}, // IPv6 address with port, path, query, fragment, trailing slash, and invalid segment should not be accepted.
{"http://[::1]:8080/path?query=123#fragment/invalid/", false, ""}, // IPv6 address with port, path, query, fragment, trailing slash, and invalid segment with trailing slash should not be accepted.
{"http://[::1]:8080/path?query=123#fragment/invalid/segment", false, ""}, // IPv6 address with port, path, query, fragment, trailing slash, and invalid segment with additional segment should not be accepted.
}
for _, tc := range testCases {
valid, normalizedOrigin := normalizeOrigin(tc.origin)
if valid != tc.expectedValid {
t.Errorf("Expected origin '%s' to be valid: %v, but got: %v", tc.origin, tc.expectedValid, valid)
}
if normalizedOrigin != tc.expectedOrigin {
t.Errorf("Expected normalized origin '%s' for origin '%s', but got: '%s'", tc.expectedOrigin, tc.origin, normalizedOrigin)
}
}
}
// go test -run -v TestSubdomainMatch
func TestSubdomainMatch(t *testing.T) {
tests := []struct {
name string
sub subdomain
origin string
expected bool
}{
{
name: "match with different scheme",
sub: subdomain{prefix: "http://api.", suffix: ".example.com"},
origin: "https://api.service.example.com",
expected: false,
},
{
name: "match with different scheme",
sub: subdomain{prefix: "https://", suffix: ".example.com"},
origin: "http://api.service.example.com",
expected: false,
},
{
name: "match with valid subdomain",
sub: subdomain{prefix: "https://", suffix: ".example.com"},
origin: "https://api.service.example.com",
expected: true,
},
{
name: "match with valid nested subdomain",
sub: subdomain{prefix: "https://", suffix: ".example.com"},
origin: "https://1.2.api.service.example.com",
expected: true,
},
{
name: "no match with invalid prefix",
sub: subdomain{prefix: "https://abc.", suffix: ".example.com"},
origin: "https://service.example.com",
expected: false,
},
{
name: "no match with invalid suffix",
sub: subdomain{prefix: "https://", suffix: ".example.com"},
origin: "https://api.example.org",
expected: false,
},
{
name: "no match with empty origin",
sub: subdomain{prefix: "https://", suffix: ".example.com"},
origin: "",
expected: false,
},
{
name: "partial match not considered a match",
sub: subdomain{prefix: "https://service.", suffix: ".example.com"},
origin: "https://api.example.com",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.sub.match(tt.origin)
assert.Equal(t, tt.expected, got, "subdomain.match()")
})
}
}
// go test -v -run=^$ -bench=Benchmark_CSRF_SubdomainMatch -benchmem -count=4
func Benchmark_CSRF_SubdomainMatch(b *testing.B) {
s := subdomain{
prefix: "www",
suffix: ".example.com",
}
o := "www.example.com"
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
s.match(o)
}
}