1
0
mirror of https://github.com/gofiber/fiber.git synced 2025-02-06 23:51:40 +00:00

fix: Decompress request body when multi Content-Encoding sent on request headers (#2555)

* 🔧 feat: Decode body in order when sent a list on content-encoding

* 🚀 perf: Change `getSplicedStrList` to have 0 allocations

* 🍵 test: Add tests for the new features

* 🍵 test: Ensure session test will not raise an error unexpectedly

* 🐗 feat: Replace strings.TrimLeft by utils.TrimLeft

Add docs to functions to inform correctly what the change is

* 🌷 refactor: Apply linter rules

* 🍵 test: Add test cases to the new body method change

* 🔧 feat: Remove return problems to be able to reach original body

* 🌷 refactor: Split Body method into two to make it more maintainable

Also, with the previous fix to problems detected by tests, it becomes really hard to make the linter happy, so this change also helps in it

* 🚀 perf: Came back with Header.VisitAll, to improve speed

* 📃 docs: Update Context docs
This commit is contained in:
João Victor Oliveira Couto 2023-08-06 12:23:37 -03:00 committed by GitHub
parent e91b02b345
commit f29f39b1b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 398 additions and 70 deletions

89
ctx.go
View File

@ -260,31 +260,92 @@ func (c *Ctx) BaseURL() string {
return c.baseURI
}
// Body contains the raw body submitted in a POST request.
// BodyRaw contains the raw body submitted in a POST request.
// Returned value is only valid within the handler. Do not store any references.
// Make copies or use the Immutable setting instead.
func (c *Ctx) BodyRaw() []byte {
return c.fasthttp.Request.Body()
}
func (c *Ctx) tryDecodeBodyInOrder(
originalBody *[]byte,
encodings []string,
) ([]byte, uint8, error) {
var (
err error
body []byte
decodesRealized uint8
)
for index, encoding := range encodings {
decodesRealized++
switch encoding {
case StrGzip:
body, err = c.fasthttp.Request.BodyGunzip()
case StrBr, StrBrotli:
body, err = c.fasthttp.Request.BodyUnbrotli()
case StrDeflate:
body, err = c.fasthttp.Request.BodyInflate()
default:
decodesRealized--
if len(encodings) == 1 {
body = c.fasthttp.Request.Body()
}
return body, decodesRealized, nil
}
if err != nil {
return nil, decodesRealized, err
}
// Only execute body raw update if it has a next iteration to try to decode
if index < len(encodings)-1 && decodesRealized > 0 {
if index == 0 {
tempBody := c.fasthttp.Request.Body()
*originalBody = make([]byte, len(tempBody))
copy(*originalBody, tempBody)
}
c.fasthttp.Request.SetBodyRaw(body)
}
}
return body, decodesRealized, nil
}
// Body contains the raw body submitted in a POST request.
// This method will decompress the body if the 'Content-Encoding' header is provided.
// It returns the original (or decompressed) body data which is valid only within the handler.
// Don't store direct references to the returned data.
// If you need to keep the body's data later, make a copy or use the Immutable option.
func (c *Ctx) Body() []byte {
var err error
var encoding string
var body []byte
var (
err error
body, originalBody []byte
headerEncoding string
encodingOrder = []string{"", "", ""}
)
// faster than peek
c.Request().Header.VisitAll(func(key, value []byte) {
if c.app.getString(key) == HeaderContentEncoding {
encoding = c.app.getString(value)
headerEncoding = c.app.getString(value)
}
})
switch encoding {
case StrGzip:
body, err = c.fasthttp.Request.BodyGunzip()
case StrBr, StrBrotli:
body, err = c.fasthttp.Request.BodyUnbrotli()
case StrDeflate:
body, err = c.fasthttp.Request.BodyInflate()
default:
body = c.fasthttp.Request.Body()
// Split and get the encodings list, in order to attend the
// rule defined at: https://www.rfc-editor.org/rfc/rfc9110#section-8.4-5
encodingOrder = getSplicedStrList(headerEncoding, encodingOrder)
if len(encodingOrder) == 0 {
return c.fasthttp.Request.Body()
}
var decodesRealized uint8
body, decodesRealized, err = c.tryDecodeBodyInOrder(&originalBody, encodingOrder)
// Ensure that the body will be the original
if originalBody != nil && decodesRealized > 0 {
c.fasthttp.Request.SetBodyRaw(originalBody)
}
if err != nil {
return []byte(err.Error())
}

View File

@ -9,6 +9,7 @@ import (
"bufio"
"bytes"
"compress/gzip"
"compress/zlib"
"context"
"crypto/tls"
"encoding/xml"
@ -323,47 +324,211 @@ func Test_Ctx_Body(t *testing.T) {
utils.AssertEqual(t, []byte("john=doe"), c.Body())
}
// go test -run Test_Ctx_Body_With_Compression
func Test_Ctx_Body_With_Compression(t *testing.T) {
t.Parallel()
func Benchmark_Ctx_Body(b *testing.B) {
const input = "john=doe"
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.Request().Header.Set("Content-Encoding", "gzip")
var b bytes.Buffer
gz := gzip.NewWriter(&b)
_, err := gz.Write([]byte("john=doe"))
utils.AssertEqual(t, nil, err)
err = gz.Flush()
utils.AssertEqual(t, nil, err)
err = gz.Close()
utils.AssertEqual(t, nil, err)
c.Request().SetBody(b.Bytes())
utils.AssertEqual(t, []byte("john=doe"), c.Body())
}
// go test -v -run=^$ -bench=Benchmark_Ctx_Body_With_Compression -benchmem -count=4
func Benchmark_Ctx_Body_With_Compression(b *testing.B) {
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.Request().Header.Set("Content-Encoding", "gzip")
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)
_, err := gz.Write([]byte("john=doe"))
utils.AssertEqual(b, nil, err)
err = gz.Flush()
utils.AssertEqual(b, nil, err)
err = gz.Close()
utils.AssertEqual(b, nil, err)
c.Request().SetBody(buf.Bytes())
c.Request().SetBody([]byte(input))
for i := 0; i < b.N; i++ {
_ = c.Body()
}
utils.AssertEqual(b, []byte("john=doe"), c.Body())
utils.AssertEqual(b, []byte(input), c.Body())
}
// go test -run Test_Ctx_Body_With_Compression
func Test_Ctx_Body_With_Compression(t *testing.T) {
t.Parallel()
tests := []struct {
name string
contentEncoding string
body []byte
expectedBody []byte
}{
{
name: "gzip",
contentEncoding: "gzip",
body: []byte("john=doe"),
expectedBody: []byte("john=doe"),
},
{
name: "unsupported_encoding",
contentEncoding: "undefined",
body: []byte("keeps_ORIGINAL"),
expectedBody: []byte("keeps_ORIGINAL"),
},
{
name: "gzip then unsupported",
contentEncoding: "gzip, undefined",
body: []byte("Go, be gzipped"),
expectedBody: []byte("Go, be gzipped"),
},
{
name: "invalid_deflate",
contentEncoding: "gzip,deflate",
body: []byte("I'm not correctly compressed"),
expectedBody: []byte(zlib.ErrHeader.Error()),
},
}
for _, testObject := range tests {
tCase := testObject // Duplicate object to ensure it will be unique across all runs
t.Run(tCase.name, func(t *testing.T) {
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.Request().Header.Set("Content-Encoding", tCase.contentEncoding)
if strings.Contains(tCase.contentEncoding, "gzip") {
var b bytes.Buffer
gz := gzip.NewWriter(&b)
_, err := gz.Write(tCase.body)
if err != nil {
t.Fatal(err)
}
if err = gz.Flush(); err != nil {
t.Fatal(err)
}
if err = gz.Close(); err != nil {
t.Fatal(err)
}
tCase.body = b.Bytes()
}
c.Request().SetBody(tCase.body)
body := c.Body()
utils.AssertEqual(t, tCase.expectedBody, body)
// Check if body raw is the same as previous before decompression
utils.AssertEqual(
t, tCase.body, c.Request().Body(),
"Body raw must be the same as set before",
)
})
}
}
// go test -v -run=^$ -bench=Benchmark_Ctx_Body_With_Compression -benchmem -count=4
func Benchmark_Ctx_Body_With_Compression(b *testing.B) {
encodingErr := errors.New("failed to encoding data")
var (
compressGzip = func(data []byte) ([]byte, error) {
var buf bytes.Buffer
writer := gzip.NewWriter(&buf)
if _, err := writer.Write(data); err != nil {
return nil, encodingErr
}
if err := writer.Flush(); err != nil {
return nil, encodingErr
}
if err := writer.Close(); err != nil {
return nil, encodingErr
}
return buf.Bytes(), nil
}
compressDeflate = func(data []byte) ([]byte, error) {
var buf bytes.Buffer
writer := zlib.NewWriter(&buf)
if _, err := writer.Write(data); err != nil {
return nil, encodingErr
}
if err := writer.Flush(); err != nil {
return nil, encodingErr
}
if err := writer.Close(); err != nil {
return nil, encodingErr
}
return buf.Bytes(), nil
}
)
compressionTests := []struct {
contentEncoding string
compressWriter func([]byte) ([]byte, error)
}{
{
contentEncoding: "gzip",
compressWriter: compressGzip,
},
{
contentEncoding: "gzip,invalid",
compressWriter: compressGzip,
},
{
contentEncoding: "deflate",
compressWriter: compressDeflate,
},
{
contentEncoding: "gzip,deflate",
compressWriter: func(data []byte) ([]byte, error) {
var (
buf bytes.Buffer
writer interface {
io.WriteCloser
Flush() error
}
err error
)
// deflate
{
writer = zlib.NewWriter(&buf)
if _, err = writer.Write(data); err != nil {
return nil, encodingErr
}
if err = writer.Flush(); err != nil {
return nil, encodingErr
}
if err = writer.Close(); err != nil {
return nil, encodingErr
}
}
data = make([]byte, buf.Len())
copy(data, buf.Bytes())
buf.Reset()
// gzip
{
writer = gzip.NewWriter(&buf)
if _, err = writer.Write(data); err != nil {
return nil, encodingErr
}
if err = writer.Flush(); err != nil {
return nil, encodingErr
}
if err = writer.Close(); err != nil {
return nil, encodingErr
}
}
return buf.Bytes(), nil
},
},
}
for _, ct := range compressionTests {
b.Run(ct.contentEncoding, func(b *testing.B) {
app := New()
const input = "john=doe"
c := app.AcquireCtx(&fasthttp.RequestCtx{})
defer app.ReleaseCtx(c)
c.Request().Header.Set("Content-Encoding", ct.contentEncoding)
compressedBody, err := ct.compressWriter([]byte(input))
utils.AssertEqual(b, nil, err)
c.Request().SetBody(compressedBody)
for i := 0; i < b.N; i++ {
_ = c.Body()
}
utils.AssertEqual(b, []byte(input), c.Body())
})
}
}
// go test -run Test_Ctx_BodyParser

View File

@ -57,13 +57,13 @@ Fiber provides similar functions for the other accept headers.
// Accept-Language: en;q=0.8, nl, ru
app.Get("/", func(c *fiber.Ctx) error {
c.AcceptsCharsets("utf-16", "iso-8859-1")
c.AcceptsCharsets("utf-16", "iso-8859-1")
// "iso-8859-1"
c.AcceptsEncodings("compress", "br")
c.AcceptsEncodings("compress", "br")
// "compress"
c.AcceptsLanguages("pt", "nl", "ru")
c.AcceptsLanguages("pt", "nl", "ru")
// "nl"
// ...
})
@ -171,6 +171,7 @@ app.Get("/", func(c *fiber.Ctx) error {
```
## Bind
Add vars to default view var map binding to template engine.
Variables are read by the Render method and may be overwritten.
@ -190,12 +191,12 @@ app.Get("/", func(c *fiber.Ctx) error {
})
```
## Body
## BodyRaw
Returns the raw request **body**.
```go title="Signature"
func (c *Ctx) Body() []byte
func (c *Ctx) BodyRaw() []byte
```
```go title="Example"
@ -203,6 +204,26 @@ func (c *Ctx) Body() []byte
app.Post("/", func(c *fiber.Ctx) error {
// Get raw body from POST request:
return c.Send(c.BodyRaw()) // []byte("user=john")
})
```
> _Returned value is only valid within the handler. Do not store any references.
> Make copies or use the_ [_**`Immutable`**_](ctx.md) _setting instead._ [_Read more..._](../#zero-allocation)
## Body
As per the header `Content-Encoding`, this method will try to perform a file decompression from the **body** bytes. In case no `Content-Encoding` header is sent, it will perform as [BodyRaw](#bodyraw).
```go title="Signature"
func (c *Ctx) Body() []byte
```
```go title="Example"
// echo 'user=john' | gzip | curl -v -i --data-binary @- -H "Content-Encoding: gzip" http://localhost:8080
app.Post("/", func(c *fiber.Ctx) error {
// Decompress body from POST request based on the Content-Encoding and return the raw content:
return c.Send(c.Body()) // []byte("user=john")
})
```
@ -216,13 +237,13 @@ Binds the request body to a struct.
It is important to specify the correct struct tag based on the content type to be parsed. For example, if you want to parse a JSON body with a field called Pass, you would use a struct field of `json:"pass"`.
| content-type | struct tag |
|---|---|
| `application/x-www-form-urlencoded` | form |
| `multipart/form-data` | form |
| `application/json` | json |
| `application/xml` | xml |
| `text/xml` | xml |
| content-type | struct tag |
| ----------------------------------- | ---------- |
| `application/x-www-form-urlencoded` | form |
| `multipart/form-data` | form |
| `application/json` | json |
| `application/xml` | xml |
| `text/xml` | xml |
```go title="Signature"
func (c *Ctx) BodyParser(out interface{}) error
@ -693,6 +714,7 @@ app.Get("/", func(c *fiber.Ctx) error {
## IsFromLocal
Returns true if request came from localhost
```go title="Signature"
func (c *Ctx) IsFromLocal() bool {
```
@ -837,7 +859,7 @@ app.Post("/", func(c *fiber.Ctx) error {
c.Location("http://example.com")
c.Location("/foo/bar")
return nil
})
```
@ -1024,6 +1046,7 @@ app.Get("/user/:id", func(c *fiber.Ctx) error {
This method is equivalent of using `atoi` with ctx.Params
## ParamsParser
This method is similar to BodyParser, but for path parameters. It is important to use the struct tag "params". For example, if you want to parse a path parameter with a field called Pass, you would use a struct field of params:"pass"
```go title="Signature"
@ -1034,7 +1057,7 @@ func (c *Ctx) ParamsParser(out interface{}) error
// GET http://example.com/user/111
app.Get("/user/:id", func(c *fiber.Ctx) error {
param := struct {ID uint `params:"id"`}{}
c.ParamsParser(&param) // "{"id": 111}"
// ...
@ -1176,7 +1199,6 @@ app.Get("/", func(c *fiber.Ctx) error {
This property is an object containing a property for each query boolean parameter in the route, you could pass an optional default value that will be returned if the query key does not exist.
:::caution
Please note if that parameter is not in the request, false will be returned.
If the parameter is not a boolean, it is still tried to be converted and usually returned as false.
@ -1232,12 +1254,10 @@ app.Get("/", func(c *fiber.Ctx) error {
})
```
## QueryInt
This property is an object containing a property for each query integer parameter in the route, you could pass an optional default value that will be returned if the query key does not exist.
:::caution
Please note if that parameter is not in the request, zero will be returned.
If the parameter is not a number, it is still tried to be converted and usually returned as 1.
@ -1522,7 +1542,7 @@ func (c *Ctx) Route() *Route
app.Get("/hello/:name", func(c *fiber.Ctx) error {
r := c.Route()
fmt.Println(r.Method, r.Path, r.Params, r.Handlers)
// GET /hello/:name handler [name]
// GET /hello/:name handler [name]
// ...
})
@ -1768,7 +1788,7 @@ var timeConverter = func(value string) reflect.Value {
customTime := fiber.ParserType{
Customtype: CustomTime{},
Converter: timeConverter,
}
}
// Add setting to the Decoder
fiber.SetParserDecoder(fiber.ParserConfig{
@ -1804,7 +1824,6 @@ app.Get("/query", func(c *fiber.Ctx) error {
```
## SetUserContext
Sets the user specified implementation for context interface.
@ -2020,7 +2039,7 @@ XML also sets the content header to **application/xml**.
:::
```go title="Signature"
func (c *Ctx) XML(data interface{}) error
func (c *Ctx) XML(data interface{}) error
```
```go title="Example"

View File

@ -269,6 +269,41 @@ func acceptsOfferType(spec, offerType string) bool {
return false
}
// getSplicedStrList function takes a string and a string slice as an argument, divides the string into different
// elements divided by ',' and stores these elements in the string slice.
// It returns the populated string slice as an output.
//
// If the given slice hasn't enough space, it will allocate more and return.
func getSplicedStrList(headerValue string, dst []string) []string {
if headerValue == "" {
return nil
}
var (
index int
character rune
lastElementEndsAt uint8
insertIndex int
)
for index, character = range headerValue + "$" {
if character == ',' || index == len(headerValue) {
if insertIndex >= len(dst) {
oldSlice := dst
dst = make([]string, len(dst)+(len(dst)>>1)+2)
copy(dst, oldSlice)
}
dst[insertIndex] = utils.TrimLeft(headerValue[lastElementEndsAt:index], ' ')
lastElementEndsAt = uint8(index + 1)
insertIndex++
}
}
if len(dst) > insertIndex {
dst = dst[:insertIndex]
}
return dst
}
// getOffer return valid offer for header negotiation
func getOffer(header string, isAccepted func(spec, offer string) bool, offers ...string) string {
if len(offers) == 0 {

View File

@ -107,6 +107,53 @@ func Benchmark_Utils_GetOffer(b *testing.B) {
}
}
func Test_Utils_GetSplicedStrList(t *testing.T) {
testCases := []struct {
description string
headerValue string
expectedList []string
}{
{
description: "normal case",
headerValue: "gzip, deflate,br",
expectedList: []string{"gzip", "deflate", "br"},
},
{
description: "no matter the value",
headerValue: " gzip,deflate, br, zip",
expectedList: []string{"gzip", "deflate", "br", "zip"},
},
{
description: "headerValue is empty",
headerValue: "",
expectedList: nil,
},
{
description: "has a comma without element",
headerValue: "gzip,",
expectedList: []string{"gzip", ""},
},
}
for _, tc := range testCases {
t.Run(tc.description, func(t *testing.T) {
dst := make([]string, 10)
result := getSplicedStrList(tc.headerValue, dst)
utils.AssertEqual(t, tc.expectedList, result)
})
}
}
func Benchmark_Utils_GetSplicedStrList(b *testing.B) {
destination := make([]string, 5)
result := destination
const input = "deflate, gzip,br,brotli"
for n := 0; n < b.N; n++ {
result = getSplicedStrList(input, destination)
}
utils.AssertEqual(b, []string{"deflate", "gzip", "br", "brotli"}, result)
}
func Test_Utils_SortAcceptedTypes(t *testing.T) {
t.Parallel()
acceptedTypes := []acceptedType{

View File

@ -287,6 +287,7 @@ func Test_Session_Save_Expiration(t *testing.T) {
t.Parallel()
t.Run("save to cookie", func(t *testing.T) {
const sessionDuration = 5 * time.Second
t.Parallel()
// session store
store := New()
@ -302,7 +303,7 @@ func Test_Session_Save_Expiration(t *testing.T) {
sess.Set("name", "john")
// expire this session in 5 seconds
sess.SetExpiry(time.Second * 5)
sess.SetExpiry(sessionDuration)
// save session
err = sess.Save()
@ -314,7 +315,7 @@ func Test_Session_Save_Expiration(t *testing.T) {
utils.AssertEqual(t, "john", sess.Get("name"))
// just to make sure the session has been expired
time.Sleep(time.Second * 5)
time.Sleep(sessionDuration + (10 * time.Millisecond))
// here you should get a new session
sess, err = store.Get(ctx)