1
0
mirror of https://github.com/gofiber/fiber.git synced 2025-02-23 16:43:41 +00:00
2020-10-08 18:11:26 +02:00

156 lines
3.3 KiB
Go

// Special thanks to @codemicro for moving this to fiber core
// Original middleware: github.com/codemicro/fiber-cache
package cache
import (
"strconv"
"sync"
"time"
"github.com/gofiber/fiber/v2"
)
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool
// Expiration is the time that an cached response will live
//
// Optional. Default: 1 * time.Minute
Expiration time.Duration
// CacheControl enables client side caching if set to true
//
// Optional. Default: false
CacheControl bool
}
// ConfigDefault is the default config
var ConfigDefault = Config{
Next: nil,
Expiration: 1 * time.Minute,
CacheControl: false,
}
// cache is the manager to store the cached responses
type cache struct {
sync.RWMutex
entries map[string]entry
expiration int64
}
// entry defines the cached response
type entry struct {
body []byte
contentType []byte
statusCode int
expiration int64
}
// New creates a new middleware handler
func New(config ...Config) fiber.Handler {
// Set default config
cfg := ConfigDefault
// Override config if provided
if len(config) > 0 {
cfg = config[0]
// Set default values
if cfg.Next == nil {
cfg.Next = ConfigDefault.Next
}
if int(cfg.Expiration.Seconds()) == 0 {
cfg.Expiration = ConfigDefault.Expiration
}
}
// Nothing to cache
if int(cfg.Expiration.Seconds()) < 0 {
return func(c *fiber.Ctx) error {
return c.Next()
}
}
// Initialize db
db := &cache{
entries: make(map[string]entry),
expiration: int64(cfg.Expiration.Seconds()),
}
// Remove expired entries
go func() {
for {
// GC the entries every 10 seconds to avoid
time.Sleep(10 * time.Second)
db.Lock()
for k := range db.entries {
if time.Now().Unix() >= db.entries[k].expiration {
delete(db.entries, k)
}
}
db.Unlock()
}
}()
// Return new handler
return func(c *fiber.Ctx) error {
// Don't execute middleware if Next returns true
if cfg.Next != nil && cfg.Next(c) {
return c.Next()
}
// Only cache GET methods
if c.Method() != fiber.MethodGet {
return c.Next()
}
// Get key from request
key := c.Path()
// Find cached entry
db.RLock()
resp, ok := db.entries[key]
db.RUnlock()
if ok {
// Check if entry is expired
if time.Now().Unix() >= resp.expiration {
db.Lock()
delete(db.entries, key)
db.Unlock()
} else {
// Set response headers from cache
c.Response().SetBodyRaw(resp.body)
c.Response().SetStatusCode(resp.statusCode)
c.Response().Header.SetContentTypeBytes(resp.contentType)
// Set Cache-Control header if enabled
if cfg.CacheControl {
maxAge := strconv.FormatInt(resp.expiration-time.Now().Unix(), 10)
c.Set(fiber.HeaderCacheControl, "public, max-age="+maxAge)
}
return nil
}
}
// Continue stack, return err to Fiber if exist
if err := c.Next(); err != nil {
return err
}
// Cache response
db.Lock()
db.entries[key] = entry{
body: c.Response().Body(),
statusCode: c.Response().StatusCode(),
contentType: c.Response().Header.ContentType(),
expiration: time.Now().Unix() + db.expiration,
}
db.Unlock()
// Finish response
return nil
}
}