mirror of
https://github.com/gofiber/fiber.git
synced 2025-02-21 06:33:11 +00:00
Fix csrf middleware behavior with header key lookup (#2063)
* 🐛 [Bug]: Strange CSRF middleware behavior with header KeyLookup configuration #2045
This commit is contained in:
parent
6026560c93
commit
ec96d161a0
@ -102,15 +102,17 @@ type Config struct {
|
||||
Extractor func(c *fiber.Ctx) (string, error)
|
||||
}
|
||||
|
||||
const HeaderName = "X-Csrf-Token"
|
||||
|
||||
// ConfigDefault is the default config
|
||||
var ConfigDefault = Config{
|
||||
KeyLookup: "header:X-Csrf-Token",
|
||||
KeyLookup: "header:" + HeaderName,
|
||||
CookieName: "csrf_",
|
||||
CookieSameSite: "Lax",
|
||||
Expiration: 1 * time.Hour,
|
||||
KeyGenerator: utils.UUID,
|
||||
ErrorHandler: defaultErrorHandler,
|
||||
Extractor: CsrfFromHeader("X-Csrf-Token"),
|
||||
Extractor: CsrfFromHeader(HeaderName),
|
||||
}
|
||||
|
||||
// default ErrorHandler that process return error from fiber.Handler
|
||||
|
@ -40,7 +40,7 @@ func Test_CSRF(t *testing.T) {
|
||||
ctx.Request.Reset()
|
||||
ctx.Response.Reset()
|
||||
ctx.Request.Header.SetMethod("POST")
|
||||
ctx.Request.Header.Set("X-CSRF-Token", "johndoe")
|
||||
ctx.Request.Header.Set(HeaderName, "johndoe")
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 403, ctx.Response.StatusCode())
|
||||
|
||||
@ -55,7 +55,7 @@ func Test_CSRF(t *testing.T) {
|
||||
ctx.Request.Reset()
|
||||
ctx.Response.Reset()
|
||||
ctx.Request.Header.SetMethod("POST")
|
||||
ctx.Request.Header.Set("X-CSRF-Token", token)
|
||||
ctx.Request.Header.Set(HeaderName, token)
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 200, ctx.Response.StatusCode())
|
||||
}
|
||||
@ -305,7 +305,7 @@ func Test_CSRF_ErrorHandler_InvalidToken(t *testing.T) {
|
||||
ctx.Request.Reset()
|
||||
ctx.Response.Reset()
|
||||
ctx.Request.Header.SetMethod("POST")
|
||||
ctx.Request.Header.Set("X-CSRF-Token", "johndoe")
|
||||
ctx.Request.Header.Set(HeaderName, "johndoe")
|
||||
h(ctx)
|
||||
utils.AssertEqual(t, 419, ctx.Response.StatusCode())
|
||||
utils.AssertEqual(t, "invalid CSRF token", string(ctx.Response.Body()))
|
||||
@ -340,3 +340,111 @@ func Test_CSRF_ErrorHandler_EmptyToken(t *testing.T) {
|
||||
utils.AssertEqual(t, 419, ctx.Response.StatusCode())
|
||||
utils.AssertEqual(t, "empty CSRF token", string(ctx.Response.Body()))
|
||||
}
|
||||
|
||||
// TODO: use this test case and make the unsafe header value bug from https://github.com/gofiber/fiber/issues/2045 reproducible and permanently fixed/tested by this testcase
|
||||
//func Test_CSRF_UnsafeHeaderValue(t *testing.T) {
|
||||
// app := fiber.New()
|
||||
//
|
||||
// app.Use(New())
|
||||
// app.Get("/", func(c *fiber.Ctx) error {
|
||||
// return c.SendStatus(fiber.StatusOK)
|
||||
// })
|
||||
// app.Get("/test", func(c *fiber.Ctx) error {
|
||||
// return c.SendStatus(fiber.StatusOK)
|
||||
// })
|
||||
// app.Post("/", func(c *fiber.Ctx) error {
|
||||
// return c.SendStatus(fiber.StatusOK)
|
||||
// })
|
||||
//
|
||||
// resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
|
||||
// utils.AssertEqual(t, nil, err)
|
||||
// utils.AssertEqual(t, fiber.StatusOK, resp.StatusCode)
|
||||
//
|
||||
// var token string
|
||||
// for _, c := range resp.Cookies() {
|
||||
// if c.Name != ConfigDefault.CookieName {
|
||||
// continue
|
||||
// }
|
||||
// token = c.Value
|
||||
// break
|
||||
// }
|
||||
//
|
||||
// fmt.Println("token", token)
|
||||
//
|
||||
// getReq := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
// getReq.Header.Set(HeaderName, token)
|
||||
// resp, err = app.Test(getReq)
|
||||
//
|
||||
// getReq = httptest.NewRequest(http.MethodGet, "/test", nil)
|
||||
// getReq.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||
// getReq.Header.Set(fiber.HeaderCacheControl, "no")
|
||||
// getReq.Header.Set(HeaderName, token)
|
||||
//
|
||||
// resp, err = app.Test(getReq)
|
||||
//
|
||||
// getReq.Header.Set(fiber.HeaderAccept, "*/*")
|
||||
// getReq.Header.Del(HeaderName)
|
||||
// resp, err = app.Test(getReq)
|
||||
//
|
||||
// postReq := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
// postReq.Header.Set("X-Requested-With", "XMLHttpRequest")
|
||||
// postReq.Header.Set(HeaderName, token)
|
||||
// resp, err = app.Test(postReq)
|
||||
//}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Middleware_CSRF_Check -benchmem -count=4
|
||||
func Benchmark_Middleware_CSRF_Check(b *testing.B) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New())
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusTeapot)
|
||||
})
|
||||
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
h := app.Handler()
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
|
||||
// Generate CSRF token
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
h(ctx)
|
||||
token := string(ctx.Response.Header.Peek(fiber.HeaderSetCookie))
|
||||
token = strings.Split(strings.Split(token, ";")[0], "=")[1]
|
||||
|
||||
ctx.Request.Header.SetMethod("POST")
|
||||
ctx.Request.Header.Set(HeaderName, token)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
h(fctx)
|
||||
}
|
||||
|
||||
utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Middleware_CSRF_GenerateToken -benchmem -count=4
|
||||
func Benchmark_Middleware_CSRF_GenerateToken(b *testing.B) {
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New())
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
return c.SendStatus(fiber.StatusTeapot)
|
||||
})
|
||||
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
h := app.Handler()
|
||||
ctx := &fasthttp.RequestCtx{}
|
||||
|
||||
// Generate CSRF token
|
||||
ctx.Request.Header.SetMethod("GET")
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
h(fctx)
|
||||
}
|
||||
|
||||
utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
|
||||
}
|
||||
|
@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/gofiber/fiber/v2"
|
||||
"github.com/gofiber/fiber/v2/internal/memory"
|
||||
"github.com/gofiber/fiber/v2/utils"
|
||||
)
|
||||
|
||||
// go:generate msgp
|
||||
@ -88,7 +89,8 @@ func (m *manager) set(key string, it *item, exp time.Duration) {
|
||||
_ = m.storage.Set(key, raw, exp)
|
||||
}
|
||||
} else {
|
||||
m.memory.Set(key, it, exp)
|
||||
// the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here
|
||||
m.memory.Set(utils.CopyString(key), it, exp)
|
||||
}
|
||||
}
|
||||
|
||||
@ -97,7 +99,8 @@ func (m *manager) setRaw(key string, raw []byte, exp time.Duration) {
|
||||
if m.storage != nil {
|
||||
_ = m.storage.Set(key, raw, exp)
|
||||
} else {
|
||||
m.memory.Set(key, raw, exp)
|
||||
// the key is crucial in crsf and sometimes a reference to another value which can be reused later(pool/unsafe values concept), so a copy is made here
|
||||
m.memory.Set(utils.CopyString(key), raw, exp)
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user