mirror of
https://github.com/gofiber/fiber.git
synced 2025-02-21 20:33:08 +00:00
Merge pull request #789 from Fenny/master
🩹 fix limiter racce condition
This commit is contained in:
commit
9eba12c334
@ -3,6 +3,7 @@ package limiter
|
|||||||
import (
|
import (
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gofiber/fiber/v2"
|
"github.com/gofiber/fiber/v2"
|
||||||
@ -90,9 +91,9 @@ func New(config ...Config) fiber.Handler {
|
|||||||
// Limiter settings
|
// Limiter settings
|
||||||
var max = strconv.Itoa(cfg.Max)
|
var max = strconv.Itoa(cfg.Max)
|
||||||
var hits = make(map[string]int)
|
var hits = make(map[string]int)
|
||||||
var reset = make(map[string]int)
|
var reset = make(map[string]uint64)
|
||||||
var timestamp = int(time.Now().Unix())
|
var timestamp = uint64(time.Now().Unix())
|
||||||
var duration = int(cfg.Duration.Seconds())
|
var duration = uint64(cfg.Duration.Seconds())
|
||||||
|
|
||||||
// mutex for parallel read and write access
|
// mutex for parallel read and write access
|
||||||
mux := &sync.Mutex{}
|
mux := &sync.Mutex{}
|
||||||
@ -100,8 +101,8 @@ func New(config ...Config) fiber.Handler {
|
|||||||
// Update timestamp every second
|
// Update timestamp every second
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
timestamp = int(time.Now().Unix())
|
atomic.StoreUint64(×tamp, uint64(time.Now().Unix()))
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(time.Second)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@ -119,11 +120,12 @@ func New(config ...Config) fiber.Handler {
|
|||||||
mux.Lock()
|
mux.Lock()
|
||||||
|
|
||||||
// Set unix timestamp if not exist
|
// Set unix timestamp if not exist
|
||||||
|
ts := atomic.LoadUint64(×tamp)
|
||||||
if reset[key] == 0 {
|
if reset[key] == 0 {
|
||||||
reset[key] = timestamp + duration
|
reset[key] = ts + duration
|
||||||
} else if timestamp >= reset[key] {
|
} else if ts >= reset[key] {
|
||||||
hits[key] = 0
|
hits[key] = 0
|
||||||
reset[key] = timestamp + duration
|
reset[key] = ts + duration
|
||||||
}
|
}
|
||||||
|
|
||||||
// Increment key hits
|
// Increment key hits
|
||||||
@ -133,7 +135,7 @@ func New(config ...Config) fiber.Handler {
|
|||||||
hitCount := hits[key]
|
hitCount := hits[key]
|
||||||
|
|
||||||
// Calculate when it resets in seconds
|
// Calculate when it resets in seconds
|
||||||
resetTime := reset[key] - timestamp
|
resetTime := reset[key] - ts
|
||||||
|
|
||||||
// Unlock map
|
// Unlock map
|
||||||
mux.Unlock()
|
mux.Unlock()
|
||||||
@ -145,7 +147,7 @@ func New(config ...Config) fiber.Handler {
|
|||||||
if remaining < 0 {
|
if remaining < 0 {
|
||||||
// Return response with Retry-After header
|
// Return response with Retry-After header
|
||||||
// https://tools.ietf.org/html/rfc6584
|
// https://tools.ietf.org/html/rfc6584
|
||||||
c.Set(fiber.HeaderRetryAfter, strconv.Itoa(resetTime))
|
c.Set(fiber.HeaderRetryAfter, strconv.FormatUint(resetTime, 10))
|
||||||
|
|
||||||
// Call LimitReached handler
|
// Call LimitReached handler
|
||||||
return cfg.LimitReached(c)
|
return cfg.LimitReached(c)
|
||||||
@ -154,7 +156,7 @@ func New(config ...Config) fiber.Handler {
|
|||||||
// We can continue, update RateLimit headers
|
// We can continue, update RateLimit headers
|
||||||
c.Set(xRateLimitLimit, max)
|
c.Set(xRateLimitLimit, max)
|
||||||
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
|
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
|
||||||
c.Set(xRateLimitReset, strconv.Itoa(resetTime))
|
c.Set(xRateLimitReset, strconv.FormatUint(resetTime, 10))
|
||||||
|
|
||||||
// Continue stack
|
// Continue stack
|
||||||
return c.Next()
|
return c.Next()
|
||||||
|
@ -2,7 +2,6 @@ package limiter
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"math/rand"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"sync"
|
"sync"
|
||||||
@ -20,13 +19,11 @@ func Test_Limiter_Concurrency(t *testing.T) {
|
|||||||
app := fiber.New()
|
app := fiber.New()
|
||||||
|
|
||||||
app.Use(New(Config{
|
app.Use(New(Config{
|
||||||
Max: 100,
|
Max: 50,
|
||||||
Duration: 1 * time.Minute,
|
Duration: 2 * time.Second,
|
||||||
}))
|
}))
|
||||||
|
|
||||||
app.Get("/", func(c *fiber.Ctx) error {
|
app.Get("/", func(c *fiber.Ctx) error {
|
||||||
// random delay between the requests
|
|
||||||
time.Sleep(time.Duration(rand.Intn(10000)) * time.Microsecond)
|
|
||||||
return c.SendString("Hello tester!")
|
return c.SendString("Hello tester!")
|
||||||
})
|
})
|
||||||
|
|
||||||
@ -46,12 +43,22 @@ func Test_Limiter_Concurrency(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i <= 50; i++ {
|
for i := 0; i <= 49; i++ {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go singleRequest(&wg)
|
go singleRequest(&wg)
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
|
resp, err := app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
utils.AssertEqual(t, 429, resp.StatusCode)
|
||||||
|
|
||||||
|
time.Sleep(3 * time.Second)
|
||||||
|
|
||||||
|
resp, err = app.Test(httptest.NewRequest(http.MethodGet, "/", nil))
|
||||||
|
utils.AssertEqual(t, nil, err)
|
||||||
|
utils.AssertEqual(t, 200, resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
// go test -v -run=^$ -bench=Benchmark_Limiter -benchmem -count=4
|
// go test -v -run=^$ -bench=Benchmark_Limiter -benchmem -count=4
|
||||||
|
Loading…
x
Reference in New Issue
Block a user