mirror of
https://github.com/gofiber/fiber.git
synced 2025-02-21 19:53:19 +00:00
✨ Cache middleware: Store e2e headers. (#1807)
* ✨ Cache middleware: Store e2e headers. As defined in RFC2616 - section-13.5.1, shared caches MUST store end-to-end headers from backend response and MUST be transmitted in any response formed from a cache entry. This commit ensures a stronger consistency between responses served from the handlers & from the cache middleware. * ✨ Cache middleware: Add flag for e2e headers. Set flag to prevent e2e headers caching to be the default behavior of the cache middleware. This would otherwise change quite a lot the experience for cache middleware current users. * ✨ Cache middleware: Add Benchmark for additionalHeaders feature. * ✨ Cache middleware: Rename E2Eheaders into StoreResponseHeaders. E2E is an acronym commonly associated with test. While in the present case it refers to end-to-end HTTP headers (by opposition to hop-by-hop), this still remains confusing. This commits renames it to a more generic name. * ✨ Cache middleware: Update README * ✨ Cache middleware: Move map instanciation. This will prevent an extra memory allocation for users not interested in this feature. * ✨ Cache middleware: Prevent memory allocation when StoreResponseHeaders is disabled. * ✨ Cache middleware: Store e2e headers. #1807 - use set instead of add for the headers - copy value from the headers -> prevent problems with mutable values Co-authored-by: wernerr <rene@gofiber.io>
This commit is contained in:
parent
0120531fcc
commit
1cddc56f13
7
middleware/cache/README.md
vendored
7
middleware/cache/README.md
vendored
@ -10,6 +10,7 @@ Cache middleware for [Fiber](https://github.com/gofiber/fiber) designed to inter
|
||||
- [Examples](#examples)
|
||||
- [Default Config](#default-config)
|
||||
- [Custom Config](#custom-config)
|
||||
- [Custom Cache Key Or Expiration](#custom-cache-key-or-expiration)
|
||||
- [Config](#config)
|
||||
- [Default Config](#default-config-1)
|
||||
|
||||
@ -112,6 +113,11 @@ type Config struct {
|
||||
//
|
||||
// Default: an in memory store for this process only
|
||||
Storage fiber.Storage
|
||||
|
||||
// allows you to store additional headers generated by next middlewares & handler
|
||||
//
|
||||
// Default: false
|
||||
StoreResponseHeaders bool
|
||||
}
|
||||
```
|
||||
|
||||
@ -128,6 +134,7 @@ var ConfigDefault = Config{
|
||||
return utils.CopyString(c.Path())
|
||||
},
|
||||
ExpirationGenerator : nil,
|
||||
StoreResponseHeaders: false,
|
||||
Storage: nil,
|
||||
}
|
||||
```
|
||||
|
33
middleware/cache/cache.go
vendored
33
middleware/cache/cache.go
vendored
@ -27,6 +27,19 @@ const (
|
||||
cacheMiss = "miss"
|
||||
)
|
||||
|
||||
var ignoreHeaders = map[string]interface{}{
|
||||
"Connection": nil,
|
||||
"Keep-Alive": nil,
|
||||
"Proxy-Authenticate": nil,
|
||||
"Proxy-Authorization": nil,
|
||||
"TE": nil,
|
||||
"Trailers": nil,
|
||||
"Transfer-Encoding": nil,
|
||||
"Upgrade": nil,
|
||||
"Content-Type": nil, // already stored explicitely by the cache manager
|
||||
"Content-Encoding": nil, // already stored explicitely by the cache manager
|
||||
}
|
||||
|
||||
// New creates a new middleware handler
|
||||
func New(config ...Config) fiber.Handler {
|
||||
// Set default config
|
||||
@ -96,6 +109,11 @@ func New(config ...Config) fiber.Handler {
|
||||
if len(e.cencoding) > 0 {
|
||||
c.Response().Header.SetBytesV(fiber.HeaderContentEncoding, e.cencoding)
|
||||
}
|
||||
if e.headers != nil {
|
||||
for k, v := range e.headers {
|
||||
c.Response().Header.SetBytesV(k, v)
|
||||
}
|
||||
}
|
||||
// Set Cache-Control header if enabled
|
||||
if cfg.CacheControl {
|
||||
maxAge := strconv.FormatUint(e.exp-ts, 10)
|
||||
@ -134,6 +152,21 @@ func New(config ...Config) fiber.Handler {
|
||||
e.ctype = utils.CopyBytes(c.Response().Header.ContentType())
|
||||
e.cencoding = utils.CopyBytes(c.Response().Header.Peek(fiber.HeaderContentEncoding))
|
||||
|
||||
// Store all response headers
|
||||
// (more: https://datatracker.ietf.org/doc/html/rfc2616#section-13.5.1)
|
||||
if cfg.StoreResponseHeaders {
|
||||
e.headers = make(map[string][]byte)
|
||||
c.Response().Header.VisitAll(
|
||||
func(key []byte, value []byte) {
|
||||
// create real copy
|
||||
keyS := string(key)
|
||||
if _, ok := ignoreHeaders[keyS]; !ok {
|
||||
e.headers[keyS] = utils.CopyBytes(value)
|
||||
}
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// default cache expiration
|
||||
expiration := uint64(cfg.Expiration.Seconds())
|
||||
// Calculate expiration by response header or other setting
|
||||
|
50
middleware/cache/cache_test.go
vendored
50
middleware/cache/cache_test.go
vendored
@ -302,6 +302,28 @@ func Test_CustomExpiration(t *testing.T) {
|
||||
utils.AssertEqual(t, 6000, newCacheTime)
|
||||
}
|
||||
|
||||
func Test_AdditionalE2EResponseHeaders(t *testing.T) {
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
StoreResponseHeaders: true,
|
||||
}))
|
||||
|
||||
app.Get("/", func(c *fiber.Ctx) error {
|
||||
c.Response().Header.Add("X-Foobar", "foobar")
|
||||
return c.SendString("hi")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest("GET", "/", nil)
|
||||
resp, err := app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "foobar", resp.Header.Get("X-Foobar"))
|
||||
|
||||
req = httptest.NewRequest("GET", "/", nil)
|
||||
resp, err = app.Test(req)
|
||||
utils.AssertEqual(t, nil, err)
|
||||
utils.AssertEqual(t, "foobar", resp.Header.Get("X-Foobar"))
|
||||
}
|
||||
|
||||
func Test_CacheHeader(t *testing.T) {
|
||||
app := fiber.New()
|
||||
|
||||
@ -475,3 +497,31 @@ func Benchmark_Cache_Storage(b *testing.B) {
|
||||
utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
|
||||
utils.AssertEqual(b, true, len(fctx.Response.Body()) > 30000)
|
||||
}
|
||||
|
||||
func Benchmark_Cache_AdditionalHeaders(b *testing.B) {
|
||||
app := fiber.New()
|
||||
app.Use(New(Config{
|
||||
StoreResponseHeaders: true,
|
||||
}))
|
||||
|
||||
app.Get("/demo", func(c *fiber.Ctx) error {
|
||||
c.Response().Header.Add("X-Foobar", "foobar")
|
||||
return c.SendStatus(418)
|
||||
})
|
||||
|
||||
h := app.Handler()
|
||||
|
||||
fctx := &fasthttp.RequestCtx{}
|
||||
fctx.Request.Header.SetMethod("GET")
|
||||
fctx.Request.SetRequestURI("/demo")
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
h(fctx)
|
||||
}
|
||||
|
||||
utils.AssertEqual(b, fiber.StatusTeapot, fctx.Response.Header.StatusCode())
|
||||
utils.AssertEqual(b, []byte("foobar"), fctx.Response.Header.Peek("X-Foobar"))
|
||||
}
|
||||
|
10
middleware/cache/config.go
vendored
10
middleware/cache/config.go
vendored
@ -54,6 +54,11 @@ type Config struct {
|
||||
|
||||
// Deprecated, use KeyGenerator instead
|
||||
Key func(*fiber.Ctx) string
|
||||
|
||||
// allows you to store additional headers generated by next middlewares & handler
|
||||
//
|
||||
// Default: false
|
||||
StoreResponseHeaders bool
|
||||
}
|
||||
|
||||
// ConfigDefault is the default config
|
||||
@ -65,8 +70,9 @@ var ConfigDefault = Config{
|
||||
KeyGenerator: func(c *fiber.Ctx) string {
|
||||
return utils.CopyString(c.Path())
|
||||
},
|
||||
ExpirationGenerator: nil,
|
||||
Storage: nil,
|
||||
ExpirationGenerator: nil,
|
||||
StoreResponseHeaders: false,
|
||||
Storage: nil,
|
||||
}
|
||||
|
||||
// Helper function to set default values
|
||||
|
2
middleware/cache/manager.go
vendored
2
middleware/cache/manager.go
vendored
@ -18,6 +18,7 @@ type item struct {
|
||||
cencoding []byte
|
||||
status int
|
||||
exp uint64
|
||||
headers map[string][]byte
|
||||
}
|
||||
|
||||
//msgp:ignore manager
|
||||
@ -61,6 +62,7 @@ func (m *manager) release(e *item) {
|
||||
e.ctype = nil
|
||||
e.status = 0
|
||||
e.exp = 0
|
||||
e.headers = nil
|
||||
m.pool.Put(e)
|
||||
}
|
||||
|
||||
|
105
middleware/cache/manager_msgp.go
vendored
105
middleware/cache/manager_msgp.go
vendored
@ -54,6 +54,36 @@ func (z *item) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||
err = msgp.WrapError(err, "exp")
|
||||
return
|
||||
}
|
||||
case "headers":
|
||||
var zb0002 uint32
|
||||
zb0002, err = dc.ReadMapHeader()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "headers")
|
||||
return
|
||||
}
|
||||
if z.headers == nil {
|
||||
z.headers = make(map[string][]byte, zb0002)
|
||||
} else if len(z.headers) > 0 {
|
||||
for key := range z.headers {
|
||||
delete(z.headers, key)
|
||||
}
|
||||
}
|
||||
for zb0002 > 0 {
|
||||
zb0002--
|
||||
var za0001 string
|
||||
var za0002 []byte
|
||||
za0001, err = dc.ReadString()
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "headers")
|
||||
return
|
||||
}
|
||||
za0002, err = dc.ReadBytes(za0002)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "headers", za0001)
|
||||
return
|
||||
}
|
||||
z.headers[za0001] = za0002
|
||||
}
|
||||
default:
|
||||
err = dc.Skip()
|
||||
if err != nil {
|
||||
@ -67,9 +97,9 @@ func (z *item) DecodeMsg(dc *msgp.Reader) (err error) {
|
||||
|
||||
// EncodeMsg implements msgp.Encodable
|
||||
func (z *item) EncodeMsg(en *msgp.Writer) (err error) {
|
||||
// map header, size 5
|
||||
// map header, size 6
|
||||
// write "body"
|
||||
err = en.Append(0x85, 0xa4, 0x62, 0x6f, 0x64, 0x79)
|
||||
err = en.Append(0x86, 0xa4, 0x62, 0x6f, 0x64, 0x79)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -118,15 +148,37 @@ func (z *item) EncodeMsg(en *msgp.Writer) (err error) {
|
||||
err = msgp.WrapError(err, "exp")
|
||||
return
|
||||
}
|
||||
// write "headers"
|
||||
err = en.Append(0xaa, 0x65, 0x32, 0x65, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = en.WriteMapHeader(uint32(len(z.headers)))
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "headers")
|
||||
return
|
||||
}
|
||||
for za0001, za0002 := range z.headers {
|
||||
err = en.WriteString(za0001)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "headers")
|
||||
return
|
||||
}
|
||||
err = en.WriteBytes(za0002)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "headers", za0001)
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// MarshalMsg implements msgp.Marshaler
|
||||
func (z *item) MarshalMsg(b []byte) (o []byte, err error) {
|
||||
o = msgp.Require(b, z.Msgsize())
|
||||
// map header, size 5
|
||||
// map header, size 6
|
||||
// string "body"
|
||||
o = append(o, 0x85, 0xa4, 0x62, 0x6f, 0x64, 0x79)
|
||||
o = append(o, 0x86, 0xa4, 0x62, 0x6f, 0x64, 0x79)
|
||||
o = msgp.AppendBytes(o, z.body)
|
||||
// string "ctype"
|
||||
o = append(o, 0xa5, 0x63, 0x74, 0x79, 0x70, 0x65)
|
||||
@ -140,6 +192,13 @@ func (z *item) MarshalMsg(b []byte) (o []byte, err error) {
|
||||
// string "exp"
|
||||
o = append(o, 0xa3, 0x65, 0x78, 0x70)
|
||||
o = msgp.AppendUint64(o, z.exp)
|
||||
// string "headers"
|
||||
o = append(o, 0xaa, 0x65, 0x32, 0x65, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, 0x73)
|
||||
o = msgp.AppendMapHeader(o, uint32(len(z.headers)))
|
||||
for za0001, za0002 := range z.headers {
|
||||
o = msgp.AppendString(o, za0001)
|
||||
o = msgp.AppendBytes(o, za0002)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@ -191,6 +250,36 @@ func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||
err = msgp.WrapError(err, "exp")
|
||||
return
|
||||
}
|
||||
case "headers":
|
||||
var zb0002 uint32
|
||||
zb0002, bts, err = msgp.ReadMapHeaderBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "headers")
|
||||
return
|
||||
}
|
||||
if z.headers == nil {
|
||||
z.headers = make(map[string][]byte, zb0002)
|
||||
} else if len(z.headers) > 0 {
|
||||
for key := range z.headers {
|
||||
delete(z.headers, key)
|
||||
}
|
||||
}
|
||||
for zb0002 > 0 {
|
||||
var za0001 string
|
||||
var za0002 []byte
|
||||
zb0002--
|
||||
za0001, bts, err = msgp.ReadStringBytes(bts)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "headers")
|
||||
return
|
||||
}
|
||||
za0002, bts, err = msgp.ReadBytesBytes(bts, za0002)
|
||||
if err != nil {
|
||||
err = msgp.WrapError(err, "headers", za0001)
|
||||
return
|
||||
}
|
||||
z.headers[za0001] = za0002
|
||||
}
|
||||
default:
|
||||
bts, err = msgp.Skip(bts)
|
||||
if err != nil {
|
||||
@ -205,6 +294,12 @@ func (z *item) UnmarshalMsg(bts []byte) (o []byte, err error) {
|
||||
|
||||
// Msgsize returns an upper bound estimate of the number of bytes occupied by the serialized message
|
||||
func (z *item) Msgsize() (s int) {
|
||||
s = 1 + 5 + msgp.BytesPrefixSize + len(z.body) + 6 + msgp.BytesPrefixSize + len(z.ctype) + 10 + msgp.BytesPrefixSize + len(z.cencoding) + 7 + msgp.IntSize + 4 + msgp.Uint64Size
|
||||
s = 1 + 5 + msgp.BytesPrefixSize + len(z.body) + 6 + msgp.BytesPrefixSize + len(z.ctype) + 10 + msgp.BytesPrefixSize + len(z.cencoding) + 7 + msgp.IntSize + 4 + msgp.Uint64Size + 11 + msgp.MapHeaderSize
|
||||
if z.headers != nil {
|
||||
for za0001, za0002 := range z.headers {
|
||||
_ = za0002
|
||||
s += msgp.StringPrefixSize + len(za0001) + msgp.BytesPrefixSize + len(za0002)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user