mirror of
https://github.com/gofiber/fiber.git
synced 2025-02-22 10:13:11 +00:00
limiter middleware new options: SkipFailedRequests, SkipSuccessfulRequests (#1542)
This commit is contained in:
parent
d2f5e3a430
commit
2aef5f8e01
@ -109,6 +109,16 @@ type Config struct {
|
|||||||
// }
|
// }
|
||||||
LimitReached fiber.Handler
|
LimitReached fiber.Handler
|
||||||
|
|
||||||
|
// When set to true, requests with StatusCode >= 400 won't be counted.
|
||||||
|
//
|
||||||
|
// Default: false
|
||||||
|
SkipFailedRequests bool
|
||||||
|
|
||||||
|
// When set to true, requests with StatusCode < 400 won't be counted.
|
||||||
|
//
|
||||||
|
// Default: false
|
||||||
|
SkipSuccessfulRequests bool
|
||||||
|
|
||||||
// Store is used to store the state of the middleware
|
// Store is used to store the state of the middleware
|
||||||
//
|
//
|
||||||
// Default: an in memory store for this process only
|
// Default: an in memory store for this process only
|
||||||
@ -130,5 +140,7 @@ var ConfigDefault = Config{
|
|||||||
LimitReached: func(c *fiber.Ctx) error {
|
LimitReached: func(c *fiber.Ctx) error {
|
||||||
return c.SendStatus(fiber.StatusTooManyRequests)
|
return c.SendStatus(fiber.StatusTooManyRequests)
|
||||||
},
|
},
|
||||||
|
SkipFailedRequests: false,
|
||||||
|
SkipSuccessfulRequests: false,
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
@ -38,6 +38,16 @@ type Config struct {
|
|||||||
// }
|
// }
|
||||||
LimitReached fiber.Handler
|
LimitReached fiber.Handler
|
||||||
|
|
||||||
|
// When set to true, requests with StatusCode >= 400 won't be counted.
|
||||||
|
//
|
||||||
|
// Default: false
|
||||||
|
SkipFailedRequests bool
|
||||||
|
|
||||||
|
// When set to true, requests with StatusCode < 400 won't be counted.
|
||||||
|
//
|
||||||
|
// Default: false
|
||||||
|
SkipSuccessfulRequests bool
|
||||||
|
|
||||||
// Store is used to store the state of the middleware
|
// Store is used to store the state of the middleware
|
||||||
//
|
//
|
||||||
// Default: an in memory store for this process only
|
// Default: an in memory store for this process only
|
||||||
@ -63,6 +73,8 @@ var ConfigDefault = Config{
|
|||||||
LimitReached: func(c *fiber.Ctx) error {
|
LimitReached: func(c *fiber.Ctx) error {
|
||||||
return c.SendStatus(fiber.StatusTooManyRequests)
|
return c.SendStatus(fiber.StatusTooManyRequests)
|
||||||
},
|
},
|
||||||
|
SkipFailedRequests: false,
|
||||||
|
SkipSuccessfulRequests: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Helper function to set default values
|
// Helper function to set default values
|
||||||
|
@ -50,6 +50,10 @@ func New(config ...Config) fiber.Handler {
|
|||||||
return c.Next()
|
return c.Next()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Continue stack for reaching c.Response().StatusCode()
|
||||||
|
// Store err for returning
|
||||||
|
err := c.Next()
|
||||||
|
|
||||||
// Get key from request
|
// Get key from request
|
||||||
key := cfg.KeyGenerator(c)
|
key := cfg.KeyGenerator(c)
|
||||||
|
|
||||||
@ -72,8 +76,12 @@ func New(config ...Config) fiber.Handler {
|
|||||||
e.exp = ts + expiration
|
e.exp = ts + expiration
|
||||||
}
|
}
|
||||||
|
|
||||||
// Increment hits
|
// Check for SkipFailedRequests and SkipSuccessfulRequests
|
||||||
e.hits++
|
if (!cfg.SkipSuccessfulRequests || c.Response().StatusCode() >= 400) &&
|
||||||
|
(!cfg.SkipFailedRequests || c.Response().StatusCode() < 400) {
|
||||||
|
// Increment hits
|
||||||
|
e.hits++
|
||||||
|
}
|
||||||
|
|
||||||
// Calculate when it resets in seconds
|
// Calculate when it resets in seconds
|
||||||
expire := e.exp - ts
|
expire := e.exp - ts
|
||||||
@ -102,7 +110,6 @@ func New(config ...Config) fiber.Handler {
|
|||||||
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
|
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
|
||||||
c.Set(xRateLimitReset, strconv.FormatUint(expire, 10))
|
c.Set(xRateLimitReset, strconv.FormatUint(expire, 10))
|
||||||
|
|
||||||
// Continue stack
|
return err
|
||||||
return c.Next()
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -14,7 +14,7 @@ import (
|
|||||||
"github.com/valyala/fasthttp"
|
"github.com/valyala/fasthttp"
|
||||||
)
|
)
|
||||||
|
|
||||||
// go test -run Test_Limiter_Concurrency -race -v
|
// go test -run Test_Limiter_Concurrency_Store -race -v
|
||||||
func Test_Limiter_Concurrency_Store(t *testing.T) {
|
func Test_Limiter_Concurrency_Store(t *testing.T) {
|
||||||
// Test concurrency using a custom store
|
// Test concurrency using a custom store
|
||||||
|
|
||||||
@ -107,6 +107,84 @@ func Test_Limiter_Concurrency(t *testing.T) {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// go test -run Test_Limiter_Skip_Failed_Requests -v
|
||||||
|
func Test_Limiter_Skip_Failed_Requests(t *testing.T) {
|
||||||
|
|
||||||
|
app := fiber.New()
|
||||||
|
|
||||||
|
app.Use(New(Config{
|
||||||
|
Max: 1,
|
||||||
|
Expiration: 2 * time.Second,
|
||||||
|
SkipFailedRequests: true,
|
||||||
|
}))
|
||||||
|
|
||||||
|
app.Get("/:status", func(c *fiber.Ctx) error {
|
||||||
|
if c.Params("status") == "fail" {
|
||||||
|
return c.SendStatus(400)
|
||||||
|
}
|
||||||
|
return c.SendStatus(200)
|
||||||
|
})
|
||||||
|
|
||||||
|
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil))
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||||
|
|
||||||
|
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||||
|
|
||||||
|
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||||
|
|
||||||
|
time.Sleep(3 * time.Second)
|
||||||
|
|
||||||
|
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// go test -run Test_Limiter_Skip_Successful_Requests -v
|
||||||
|
func Test_Limiter_Skip_Successful_Requests(t *testing.T) {
|
||||||
|
|
||||||
|
// Test concurrency using a default store
|
||||||
|
|
||||||
|
app := fiber.New()
|
||||||
|
|
||||||
|
app.Use(New(Config{
|
||||||
|
Max: 1,
|
||||||
|
Expiration: 2 * time.Second,
|
||||||
|
SkipSuccessfulRequests: true,
|
||||||
|
}))
|
||||||
|
|
||||||
|
app.Get("/:status", func(c *fiber.Ctx) error {
|
||||||
|
if c.Params("status") == "fail" {
|
||||||
|
return c.SendStatus(400)
|
||||||
|
}
|
||||||
|
return c.SendStatus(200)
|
||||||
|
})
|
||||||
|
|
||||||
|
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||||
|
|
||||||
|
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil))
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||||
|
|
||||||
|
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil))
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||||
|
|
||||||
|
time.Sleep(3 * time.Second)
|
||||||
|
|
||||||
|
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil))
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
utils.AssertEqual(t, 400, resp.StatusCode)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
// go test -v -run=^$ -bench=Benchmark_Limiter_Custom_Store -benchmem -count=4
|
// go test -v -run=^$ -bench=Benchmark_Limiter_Custom_Store -benchmem -count=4
|
||||||
func Benchmark_Limiter_Custom_Store(b *testing.B) {
|
func Benchmark_Limiter_Custom_Store(b *testing.B) {
|
||||||
app := fiber.New()
|
app := fiber.New()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user