1
0
mirror of https://github.com/gofiber/fiber.git synced 2025-02-23 22:23:47 +00:00
fiber/middleware/limiter/limiter.go
2020-11-20 11:43:07 +01:00

138 lines
2.9 KiB
Go

package limiter
import (
	"strconv"
	"sync"
	"sync/atomic"
	"time"

	"github.com/gofiber/fiber/v2"
)
// errNotExist is the message of the Storage ErrNotExist error, i.e. the text
// reported when a key is not present in the store.
const errNotExist = "key does not exist"

// Response header names used to report the client's rate-limit state.
const (
	// xRateLimitLimit carries the configured maximum number of hits.
	xRateLimitLimit = "X-RateLimit-Limit"
	// xRateLimitRemaining carries how many hits are left in the window.
	xRateLimitRemaining = "X-RateLimit-Remaining"
	// xRateLimitReset carries the seconds until the window resets.
	xRateLimitReset = "X-RateLimit-Reset"
)
// New creates a new middleware handler
// New creates a new rate-limiting middleware handler.
//
// Every request is mapped to a key by cfg.KeyGenerator. Hits are counted per
// key inside a window of cfg.Expiration. When a key exceeds cfg.Max hits within
// the current window, a Retry-After header is set and cfg.LimitReached is
// returned. Otherwise the X-RateLimit-* headers are set and the next handler
// in the stack runs. Requests for which cfg.Next returns true bypass the
// limiter entirely.
func New(config ...Config) fiber.Handler {
	// Set default config
	cfg := configDefault(config...)

	var (
		// Limiter settings, computed once per middleware instance.
		maxStr     = strconv.Itoa(cfg.Max)
		timestamp  = uint64(time.Now().Unix())
		expiration = uint64(cfg.Expiration.Seconds())

		// mux serialises the get -> update -> set sequence on an entry so
		// that concurrent requests for the same key cannot overwrite each
		// other's hit counts.
		mux sync.Mutex
	)

	store := newStorage(&cfg)

	// Refresh the cached Unix timestamp once per second so the request path
	// only does an atomic load instead of calling time.Now().
	// NOTE(review): this goroutine has no stop signal and runs for the rest of
	// the process lifetime, one per call to New.
	go func() {
		for {
			atomic.StoreUint64(&timestamp, uint64(time.Now().Unix()))
			time.Sleep(1 * time.Second)
		}
	}()

	// Return new handler
	return func(c *fiber.Ctx) error {
		// Don't execute middleware if Next returns true
		if cfg.Next != nil && cfg.Next(c) {
			return c.Next()
		}

		// Get key from request
		key := cfg.KeyGenerator(c)

		// Lock only around the read-modify-write of the entry. The lock is
		// released explicitly (not deferred) so it is not held while the
		// rest of the handler stack runs in c.Next().
		mux.Lock()

		e := store.get(key)
		ts := atomic.LoadUint64(&timestamp)

		if e.exp == 0 {
			// New entry: start a fresh window.
			e.exp = ts + expiration
		} else if ts >= e.exp {
			// Window elapsed: reset the counter and start a new window.
			e.hits = 0
			e.exp = ts + expiration
		}

		// Count this request and persist the entry.
		e.hits++
		store.set(key, e)

		mux.Unlock()

		// Seconds until the current window resets
		expire := e.exp - ts

		// Hits left in the current window
		remaining := cfg.Max - e.hits

		// Check if hits exceed cfg.Max
		if remaining < 0 {
			// Tell the client when it may retry.
			// https://tools.ietf.org/html/rfc6585#section-4
			c.Set(fiber.HeaderRetryAfter, strconv.FormatUint(expire, 10))

			// Call LimitReached handler
			return cfg.LimitReached(c)
		}

		// We can continue, update RateLimit headers
		c.Set(xRateLimitLimit, maxStr)
		c.Set(xRateLimitRemaining, strconv.Itoa(remaining))
		c.Set(xRateLimitReset, strconv.FormatUint(expire, 10))

		// Continue stack
		return c.Next()
	}
}