1
0
mirror of https://github.com/gofiber/fiber.git synced 2025-02-21 23:53:22 +00:00
fiber/middleware/limiter/limiter_sliding.go
Muhammed Efe Çetin 1188144d78
🎉 v3: init
2022-05-31 17:35:49 +03:00

139 lines
3.8 KiB
Go

package limiter
import (
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/gofiber/fiber/v3"
)
type SlidingWindow struct{}
// New creates a new sliding window middleware handler
func (SlidingWindow) New(cfg Config) fiber.Handler {
var (
// Limiter variables
mux = &sync.RWMutex{}
max = strconv.Itoa(cfg.Max)
timestamp = uint64(time.Now().Unix())
expiration = uint64(cfg.Expiration.Seconds())
)
// Create manager to simplify storage operations ( see manager.go )
manager := newManager(cfg.Storage)
// Update timestamp every second
go func() {
for {
atomic.StoreUint64(&timestamp, uint64(time.Now().Unix()))
time.Sleep(1 * time.Second)
}
}()
// Return new handler
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Get key from request
key := cfg.KeyGenerator(c)
// Lock entry
mux.Lock()
// Get entry from pool and release when finished
e := manager.get(key)
// Get timestamp
ts := atomic.LoadUint64(&timestamp)
// Set expiration if entry does not exist
if e.exp == 0 {
e.exp = ts + expiration
} else if ts >= e.exp {
// The entry has expired, handle the expiration.
// Set the prevHits to the current hits and reset the hits to 0.
e.prevHits = e.currHits
// Reset the current hits to 0.
e.currHits = 0
// Check how much into the current window it currently is and sets the
// expiry based on that, otherwise this would only reset on
// the next request and not show the correct expiry.
elapsed := ts - e.exp
if elapsed >= expiration {
e.exp = ts + expiration
} else {
e.exp = ts + expiration - elapsed
}
}
// Increment hits
e.currHits++
// Calculate when it resets in seconds
resetInSec := e.exp - ts
// weight = time until current window reset / total window length
weight := float64(resetInSec) / float64(expiration)
// rate = request count in previous window - weight + request count in current window
rate := int(float64(e.prevHits)*weight) + e.currHits
// Calculate how many hits can be made based on the current rate
remaining := cfg.Max - rate
// Update storage. Garbage collect when the next window ends.
// |--------------------------|--------------------------|
// ^ ^ ^ ^
// ts e.exp End sample window End next window
// <------------>
// resetInSec
// resetInSec = e.exp - ts - time until end of current window.
// duration + expiration = end of next window.
// Because we don't want to garbage collect in the middle of a window
// we add the expiration to the duration.
// Otherwise after the end of "sample window", attackers could launch
// a new request with the full window length.
manager.set(key, e, time.Duration(resetInSec+expiration)*time.Second)
// Unlock entry
mux.Unlock()
// Check if hits exceed the cfg.Max
if remaining < 0 {
// Return response with Retry-After header
// https://tools.ietf.org/html/rfc6584
c.Set(fiber.HeaderRetryAfter, strconv.FormatUint(resetInSec, 10))
// Call LimitReached handler
return cfg.LimitReached(c)
}
// Continue stack for reaching c.Response().StatusCode()
// Store err for returning
err := c.Next()
// Check for SkipFailedRequests and SkipSuccessfulRequests
if (cfg.SkipSuccessfulRequests && c.Response().StatusCode() < fiber.StatusBadRequest) ||
(cfg.SkipFailedRequests && c.Response().StatusCode() >= fiber.StatusBadRequest) {
mux.Lock()
e.currHits--
remaining++
mux.Unlock()
}
// We can continue, update RateLimit headers
c.Set(xRateLimitLimit, max)
c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
c.Set(xRateLimitReset, strconv.FormatUint(resetInSec, 10))
return err
}
}