1
0
mirror of https://github.com/gofiber/fiber.git synced 2025-02-22 10:13:11 +00:00

limiter middleware new options: SkipFailedRequests, SkipSuccessfulRequests (#1542)

This commit is contained in:
Ali Eren Öztürk 2021-09-28 11:10:29 +03:00 committed by GitHub
parent d2f5e3a430
commit 2aef5f8e01
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 114 additions and 5 deletions

View File

@ -109,6 +109,16 @@ type Config struct {
// } // }
LimitReached fiber.Handler LimitReached fiber.Handler
// When set to true, requests with StatusCode >= 400 won't be counted.
//
// Default: false
SkipFailedRequests bool
// When set to true, requests with StatusCode < 400 won't be counted.
//
// Default: false
SkipSuccessfulRequests bool
// Store is used to store the state of the middleware // Store is used to store the state of the middleware
// //
// Default: an in memory store for this process only // Default: an in memory store for this process only
@ -130,5 +140,7 @@ var ConfigDefault = Config{
LimitReached: func(c *fiber.Ctx) error { LimitReached: func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusTooManyRequests) return c.SendStatus(fiber.StatusTooManyRequests)
}, },
SkipFailedRequests: false,
SkipSuccessfulRequests: false,
} }
``` ```

View File

@ -38,6 +38,16 @@ type Config struct {
// } // }
LimitReached fiber.Handler LimitReached fiber.Handler
// When set to true, requests with StatusCode >= 400 won't be counted.
//
// Default: false
SkipFailedRequests bool
// When set to true, requests with StatusCode < 400 won't be counted.
//
// Default: false
SkipSuccessfulRequests bool
// Store is used to store the state of the middleware // Store is used to store the state of the middleware
// //
// Default: an in memory store for this process only // Default: an in memory store for this process only
@ -63,6 +73,8 @@ var ConfigDefault = Config{
LimitReached: func(c *fiber.Ctx) error { LimitReached: func(c *fiber.Ctx) error {
return c.SendStatus(fiber.StatusTooManyRequests) return c.SendStatus(fiber.StatusTooManyRequests)
}, },
SkipFailedRequests: false,
SkipSuccessfulRequests: false,
} }
// Helper function to set default values // Helper function to set default values

View File

@ -50,6 +50,10 @@ func New(config ...Config) fiber.Handler {
return c.Next() return c.Next()
} }
// Continue stack for reaching c.Response().StatusCode()
// Store err for returning
err := c.Next()
// Get key from request // Get key from request
key := cfg.KeyGenerator(c) key := cfg.KeyGenerator(c)
@ -72,8 +76,12 @@ func New(config ...Config) fiber.Handler {
e.exp = ts + expiration e.exp = ts + expiration
} }
// Increment hits // Check for SkipFailedRequests and SkipSuccessfulRequests
e.hits++ if (!cfg.SkipSuccessfulRequests || c.Response().StatusCode() >= 400) &&
(!cfg.SkipFailedRequests || c.Response().StatusCode() < 400) {
// Increment hits
e.hits++
}
// Calculate when it resets in seconds // Calculate when it resets in seconds
expire := e.exp - ts expire := e.exp - ts
@ -102,7 +110,6 @@ func New(config ...Config) fiber.Handler {
c.Set(xRateLimitRemaining, strconv.Itoa(remaining)) c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
c.Set(xRateLimitReset, strconv.FormatUint(expire, 10)) c.Set(xRateLimitReset, strconv.FormatUint(expire, 10))
// Continue stack return err
return c.Next()
} }
} }

View File

@ -14,7 +14,7 @@ import (
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
// go test -run Test_Limiter_Concurrency -race -v // go test -run Test_Limiter_Concurrency_Store -race -v
func Test_Limiter_Concurrency_Store(t *testing.T) { func Test_Limiter_Concurrency_Store(t *testing.T) {
// Test concurrency using a custom store // Test concurrency using a custom store
@ -107,6 +107,84 @@ func Test_Limiter_Concurrency(t *testing.T) {
} }
// go test -run Test_Limiter_Skip_Failed_Requests -v
func Test_Limiter_Skip_Failed_Requests(t *testing.T) {
app := fiber.New()
app.Use(New(Config{
Max: 1,
Expiration: 2 * time.Second,
SkipFailedRequests: true,
}))
app.Get("/:status", func(c *fiber.Ctx) error {
if c.Params("status") == "fail" {
return c.SendStatus(400)
}
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(3 * time.Second)
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
}
// go test -run Test_Limiter_Skip_Successful_Requests -v
func Test_Limiter_Skip_Successful_Requests(t *testing.T) {
// Test concurrency using a default store
app := fiber.New()
app.Use(New(Config{
Max: 1,
Expiration: 2 * time.Second,
SkipSuccessfulRequests: true,
}))
app.Get("/:status", func(c *fiber.Ctx) error {
if c.Params("status") == "fail" {
return c.SendStatus(400)
}
return c.SendStatus(200)
})
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/success", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(3 * time.Second)
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/fail", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 400, resp.StatusCode)
}
// go test -v -run=^$ -bench=Benchmark_Limiter_Custom_Store -benchmem -count=4 // go test -v -run=^$ -bench=Benchmark_Limiter_Custom_Store -benchmem -count=4
func Benchmark_Limiter_Custom_Store(b *testing.B) { func Benchmark_Limiter_Custom_Store(b *testing.B) {
app := fiber.New() app := fiber.New()