1
0
mirror of https://github.com/gofiber/fiber.git synced 2025-02-20 23:13:06 +00:00

Merge pull request #789 from Fenny/master

🩹 fix limiter racce condition
This commit is contained in:
Fenny 2020-09-15 20:57:16 +02:00 committed by GitHub
commit 9eba12c334
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 26 additions and 17 deletions

View File

@ -3,6 +3,7 @@ package limiter
import (
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/gofiber/fiber/v2"
@ -90,9 +91,9 @@ func New(config ...Config) fiber.Handler {
// Limiter settings
var max = strconv.Itoa(cfg.Max)
var hits = make(map[string]int)
var reset = make(map[string]int)
var timestamp = int(time.Now().Unix())
var duration = int(cfg.Duration.Seconds())
var reset = make(map[string]uint64)
var timestamp = uint64(time.Now().Unix())
var duration = uint64(cfg.Duration.Seconds())
// mutex for parallel read and write access
mux := &sync.Mutex{}
@ -100,8 +101,8 @@ func New(config ...Config) fiber.Handler {
// Update timestamp every second
go func() {
for {
timestamp = int(time.Now().Unix())
time.Sleep(1 * time.Second)
atomic.StoreUint64(&timestamp, uint64(time.Now().Unix()))
time.Sleep(time.Second)
}
}()
@ -119,11 +120,12 @@ func New(config ...Config) fiber.Handler {
mux.Lock()
// Set unix timestamp if not exist
ts := atomic.LoadUint64(&timestamp)
if reset[key] == 0 {
reset[key] = timestamp + duration
} else if timestamp >= reset[key] {
reset[key] = ts + duration
} else if ts >= reset[key] {
hits[key] = 0
reset[key] = timestamp + duration
reset[key] = ts + duration
}
// Increment key hits
@ -133,7 +135,7 @@ func New(config ...Config) fiber.Handler {
hitCount := hits[key]
// Calculate when it resets in seconds
resetTime := reset[key] - timestamp
resetTime := reset[key] - ts
// Unlock map
mux.Unlock()
@ -145,7 +147,7 @@ func New(config ...Config) fiber.Handler {
if remaining < 0 {
// Return response with Retry-After header
// https://tools.ietf.org/html/rfc6584
c.Set(fiber.HeaderRetryAfter, strconv.Itoa(resetTime))
c.Set(fiber.HeaderRetryAfter, strconv.FormatUint(resetTime, 10))
// Call LimitReached handler
return cfg.LimitReached(c)
@ -154,7 +156,7 @@ func New(config ...Config) fiber.Handler {
// We can continue, update RateLimit headers
c.Set(xRateLimitLimit, max)
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
c.Set(xRateLimitReset, strconv.Itoa(resetTime))
c.Set(xRateLimitReset, strconv.FormatUint(resetTime, 10))
// Continue stack
return c.Next()

View File

@ -2,7 +2,6 @@ package limiter
import (
"io/ioutil"
"math/rand"
"net/http"
"net/http/httptest"
"sync"
@ -20,13 +19,11 @@ func Test_Limiter_Concurrency(t *testing.T) {
app := fiber.New()
app.Use(New(Config{
Max: 100,
Duration: 1 * time.Minute,
Max: 50,
Duration: 2 * time.Second,
}))
app.Get("/", func(c *fiber.Ctx) error {
// random delay between the requests
time.Sleep(time.Duration(rand.Intn(10000)) * time.Microsecond)
return c.SendString("Hello tester!")
})
@ -46,12 +43,22 @@ func Test_Limiter_Concurrency(t *testing.T) {
}
}
for i := 0; i <= 50; i++ {
for i := 0; i <= 49; i++ {
wg.Add(1)
go singleRequest(&wg)
}
wg.Wait()
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 429, resp.StatusCode)
time.Sleep(3 * time.Second)
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, 200, resp.StatusCode)
}
// go test -v -run=^$ -bench=Benchmark_Limiter -benchmem -count=4