1
0
mirror of https://github.com/gofiber/fiber.git synced 2025-02-21 19:53:19 +00:00

Fix limit middleware skip options (#1568)

* fix limit middleware skip options

* fix limiter middleware remaining count

* used constant StatusBadRequest instead of int 400
This commit is contained in:
Ali Eren Öztürk 2021-10-11 19:31:37 +03:00 committed by GitHub
parent 9eaa8b0c73
commit 9c37b4c1c5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 48 additions and 10 deletions

View File

@ -50,10 +50,6 @@ func New(config ...Config) fiber.Handler {
return c.Next()
}
// Continue stack for reaching c.Response().StatusCode()
// Store err for returning
err := c.Next()
// Get key from request
key := cfg.KeyGenerator(c)
@ -76,12 +72,8 @@ func New(config ...Config) fiber.Handler {
e.exp = ts + expiration
}
// Check for SkipFailedRequests and SkipSuccessfulRequests
if (!cfg.SkipSuccessfulRequests || c.Response().StatusCode() >= 400) &&
(!cfg.SkipFailedRequests || c.Response().StatusCode() < 400) {
// Increment hits
e.hits++
}
// Increment hits
e.hits++
// Calculate when it resets in seconds
expire := e.exp - ts
@ -105,6 +97,19 @@ func New(config ...Config) fiber.Handler {
return cfg.LimitReached(c)
}
// Continue stack for reaching c.Response().StatusCode()
// Store err for returning
err := c.Next()
// Check for SkipFailedRequests and SkipSuccessfulRequests
if (cfg.SkipSuccessfulRequests && c.Response().StatusCode() < fiber.StatusBadRequest) ||
(cfg.SkipFailedRequests && c.Response().StatusCode() >= fiber.StatusBadRequest) {
mux.Lock()
e.hits--
remaining++
mux.Unlock()
}
// We can continue, update RateLimit headers
c.Set(xRateLimitLimit, max)
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))

View File

@ -107,6 +107,39 @@ func Test_Limiter_Concurrency(t *testing.T) {
}
// go test -run Test_Limiter_No_Skip_Choices -v
func Test_Limiter_No_Skip_Choices(t *testing.T) {
app := fiber.New()
app.Use(New(Config{
Max: 2,
Expiration: 2 * time.Second,
SkipFailedRequests: false,
SkipSuccessfulRequests: false,
}))
app.Get("/:status", func(c *fiber.Ctx) error {
if c.Params("status") == "fail" {
return c.SendStatus(400)
}
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
}
// go test -run Test_Limiter_Skip_Failed_Requests -v
func Test_Limiter_Skip_Failed_Requests(t *testing.T) {