mirror of
https://github.com/gofiber/fiber.git
synced 2025-02-21 19:53:19 +00:00
Fix limit middleware skip options (#1568)
* fix limit middleware skip options * fix limiter middleware remaining count * used constant StatusBadRequest instead of int 400
This commit is contained in:
parent
9eaa8b0c73
commit
9c37b4c1c5
@ -50,10 +50,6 @@ func New(config ...Config) fiber.Handler {
|
||||
return c.Next()
|
||||
}
|
||||
|
||||
// Continue stack for reaching c.Response().StatusCode()
|
||||
// Store err for returning
|
||||
err := c.Next()
|
||||
|
||||
// Get key from request
|
||||
key := cfg.KeyGenerator(c)
|
||||
|
||||
@ -76,12 +72,8 @@ func New(config ...Config) fiber.Handler {
|
||||
e.exp = ts + expiration
|
||||
}
|
||||
|
||||
// Check for SkipFailedRequests and SkipSuccessfulRequests
|
||||
if (!cfg.SkipSuccessfulRequests || c.Response().StatusCode() >= 400) &&
|
||||
(!cfg.SkipFailedRequests || c.Response().StatusCode() < 400) {
|
||||
// Increment hits
|
||||
e.hits++
|
||||
}
|
||||
// Increment hits
|
||||
e.hits++
|
||||
|
||||
// Calculate when it resets in seconds
|
||||
expire := e.exp - ts
|
||||
@ -105,6 +97,19 @@ func New(config ...Config) fiber.Handler {
|
||||
return cfg.LimitReached(c)
|
||||
}
|
||||
|
||||
// Continue stack for reaching c.Response().StatusCode()
|
||||
// Store err for returning
|
||||
err := c.Next()
|
||||
|
||||
// Check for SkipFailedRequests and SkipSuccessfulRequests
|
||||
if (cfg.SkipSuccessfulRequests && c.Response().StatusCode() < fiber.StatusBadRequest) ||
|
||||
(cfg.SkipFailedRequests && c.Response().StatusCode() >= fiber.StatusBadRequest) {
|
||||
mux.Lock()
|
||||
e.hits--
|
||||
remaining++
|
||||
mux.Unlock()
|
||||
}
|
||||
|
||||
// We can continue, update RateLimit headers
|
||||
c.Set(xRateLimitLimit, max)
|
||||
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
|
||||
|
@ -107,6 +107,39 @@ func Test_Limiter_Concurrency(t *testing.T) {
|
||||
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_No_Skip_Choices -v
|
||||
func Test_Limiter_No_Skip_Choices(t *testing.T) {
|
||||
|
||||
app := fiber.New()
|
||||
|
||||
app.Use(New(Config{
|
||||
Max: 2,
|
||||
Expiration: 2 * time.Second,
|
||||
SkipFailedRequests: false,
|
||||
SkipSuccessfulRequests: false,
|
||||
}))
|
||||
|
||||
app.Get("/:status", func(c *fiber.Ctx) error {
|
||||
if c.Params("status") == "fail" {
|
||||
return c.SendStatus(400)
|
||||
}
|
||||
return c.SendStatus(200)
|
||||
})
|
||||
|
||||
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||
|
||||
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||
|
||||
}
|
||||
|
||||
// go test -run Test_Limiter_Skip_Failed_Requests -v
|
||||
func Test_Limiter_Skip_Failed_Requests(t *testing.T) {
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user