1
0
mirror of https://github.com/gofiber/fiber.git synced 2025-02-06 18:31:55 +00:00

🐛 bug: Fix square bracket notation in Multipart FormData (#3235)

* 🐛 bug: add square bracket notation support to BindMultipart

* Fix golangci-lint issues

* Fixing undef variable

* Fix more lint issues

* test

* update1

* improve coverage

* fix linter

* reduce code duplication

* reduce code duplications in bindMultipart

---------

Co-authored-by: Juan Calderon-Perez <835733+gaby@users.noreply.github.com>
Co-authored-by: René <rene@gofiber.io>
This commit is contained in:
M. Efe Çetin 2024-12-31 18:34:28 +03:00 committed by GitHub
parent d0e767fc47
commit ef04a8a99e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 186 additions and 75 deletions

View File

@ -7,6 +7,7 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"mime/multipart"
"net/http/httptest" "net/http/httptest"
"reflect" "reflect"
"testing" "testing"
@ -887,6 +888,7 @@ func Test_Bind_Body(t *testing.T) {
type Demo struct { type Demo struct {
Name string `json:"name" xml:"name" form:"name" query:"name"` Name string `json:"name" xml:"name" form:"name" query:"name"`
Names []string `json:"names" xml:"names" form:"names" query:"names"`
} }
// Helper function to test compressed bodies // Helper function to test compressed bodies
@ -996,6 +998,48 @@ func Test_Bind_Body(t *testing.T) {
Data []Demo `query:"data"` Data []Demo `query:"data"`
} }
t.Run("MultipartCollectionQueryDotNotation", func(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Reset()
buf := &bytes.Buffer{}
writer := multipart.NewWriter(buf)
require.NoError(t, writer.WriteField("data.0.name", "john"))
require.NoError(t, writer.WriteField("data.1.name", "doe"))
require.NoError(t, writer.Close())
c.Request().Header.SetContentType(writer.FormDataContentType())
c.Request().SetBody(buf.Bytes())
c.Request().Header.SetContentLength(len(c.Body()))
cq := new(CollectionQuery)
require.NoError(t, c.Bind().Body(cq))
require.Len(t, cq.Data, 2)
require.Equal(t, "john", cq.Data[0].Name)
require.Equal(t, "doe", cq.Data[1].Name)
})
t.Run("MultipartCollectionQuerySquareBrackets", func(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Reset()
buf := &bytes.Buffer{}
writer := multipart.NewWriter(buf)
require.NoError(t, writer.WriteField("data[0][name]", "john"))
require.NoError(t, writer.WriteField("data[1][name]", "doe"))
require.NoError(t, writer.Close())
c.Request().Header.SetContentType(writer.FormDataContentType())
c.Request().SetBody(buf.Bytes())
c.Request().Header.SetContentLength(len(c.Body()))
cq := new(CollectionQuery)
require.NoError(t, c.Bind().Body(cq))
require.Len(t, cq.Data, 2)
require.Equal(t, "john", cq.Data[0].Name)
require.Equal(t, "doe", cq.Data[1].Name)
})
t.Run("CollectionQuerySquareBrackets", func(t *testing.T) { t.Run("CollectionQuerySquareBrackets", func(t *testing.T) {
c := app.AcquireCtx(&fasthttp.RequestCtx{}) c := app.AcquireCtx(&fasthttp.RequestCtx{})
c.Request().Reset() c.Request().Reset()
@ -1192,9 +1236,14 @@ func Benchmark_Bind_Body_MultipartForm(b *testing.B) {
Name string `form:"name"` Name string `form:"name"`
} }
body := []byte("--b\r\nContent-Disposition: form-data; name=\"name\"\r\n\r\njohn\r\n--b--") buf := &bytes.Buffer{}
writer := multipart.NewWriter(buf)
require.NoError(b, writer.WriteField("name", "john"))
require.NoError(b, writer.Close())
body := buf.Bytes()
c.Request().SetBody(body) c.Request().SetBody(body)
c.Request().Header.SetContentType(MIMEMultipartForm + `;boundary="b"`) c.Request().Header.SetContentType(MIMEMultipartForm + `;boundary=` + writer.Boundary())
c.Request().Header.SetContentLength(len(body)) c.Request().Header.SetContentLength(len(body))
d := new(Demo) d := new(Demo)
@ -1204,10 +1253,58 @@ func Benchmark_Bind_Body_MultipartForm(b *testing.B) {
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
err = c.Bind().Body(d) err = c.Bind().Body(d)
} }
require.NoError(b, err) require.NoError(b, err)
require.Equal(b, "john", d.Name) require.Equal(b, "john", d.Name)
} }
// go test -v -run=^$ -bench=Benchmark_Bind_Body_MultipartForm_Nested -benchmem -count=4
func Benchmark_Bind_Body_MultipartForm_Nested(b *testing.B) {
var err error
app := New()
c := app.AcquireCtx(&fasthttp.RequestCtx{})
type Person struct {
Name string `form:"name"`
Age int `form:"age"`
}
type Demo struct {
Name string `form:"name"`
Persons []Person `form:"persons"`
}
buf := &bytes.Buffer{}
writer := multipart.NewWriter(buf)
require.NoError(b, writer.WriteField("name", "john"))
require.NoError(b, writer.WriteField("persons.0.name", "john"))
require.NoError(b, writer.WriteField("persons[0][age]", "10"))
require.NoError(b, writer.WriteField("persons[1][name]", "doe"))
require.NoError(b, writer.WriteField("persons.1.age", "20"))
require.NoError(b, writer.Close())
body := buf.Bytes()
c.Request().SetBody(body)
c.Request().Header.SetContentType(MIMEMultipartForm + `;boundary=` + writer.Boundary())
c.Request().Header.SetContentLength(len(body))
d := new(Demo)
b.ReportAllocs()
b.ResetTimer()
for n := 0; n < b.N; n++ {
err = c.Bind().Body(d)
}
require.NoError(b, err)
require.Equal(b, "john", d.Name)
require.Equal(b, "john", d.Persons[0].Name)
require.Equal(b, 10, d.Persons[0].Age)
require.Equal(b, "doe", d.Persons[1].Name)
require.Equal(b, 20, d.Persons[1].Age)
}
// go test -v -run=^$ -bench=Benchmark_Bind_Body_Form_Map -benchmem -count=4 // go test -v -run=^$ -bench=Benchmark_Bind_Body_Form_Map -benchmem -count=4
func Benchmark_Bind_Body_Form_Map(b *testing.B) { func Benchmark_Bind_Body_Form_Map(b *testing.B) {
var err error var err error

View File

@ -1,9 +1,6 @@
package binder package binder
import ( import (
"reflect"
"strings"
"github.com/gofiber/utils/v2" "github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@ -30,15 +27,7 @@ func (b *CookieBinding) Bind(req *fasthttp.Request, out any) error {
k := utils.UnsafeString(key) k := utils.UnsafeString(key)
v := utils.UnsafeString(val) v := utils.UnsafeString(val)
err = formatBindData(out, data, k, v, b.EnableSplitting, false)
if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
values := strings.Split(v, ",")
for i := 0; i < len(values); i++ {
data[k] = append(data[k], values[i])
}
} else {
data[k] = append(data[k], v)
}
}) })
if err != nil { if err != nil {

View File

@ -1,9 +1,6 @@
package binder package binder
import ( import (
"reflect"
"strings"
"github.com/gofiber/utils/v2" "github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@ -37,19 +34,7 @@ func (b *FormBinding) Bind(req *fasthttp.Request, out any) error {
k := utils.UnsafeString(key) k := utils.UnsafeString(key)
v := utils.UnsafeString(val) v := utils.UnsafeString(val)
err = formatBindData(out, data, k, v, b.EnableSplitting, true)
if strings.Contains(k, "[") {
k, err = parseParamSquareBrackets(k)
}
if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
values := strings.Split(v, ",")
for i := 0; i < len(values); i++ {
data[k] = append(data[k], values[i])
}
} else {
data[k] = append(data[k], v)
}
}) })
if err != nil { if err != nil {
@ -61,12 +46,20 @@ func (b *FormBinding) Bind(req *fasthttp.Request, out any) error {
// bindMultipart parses the request body and returns the result. // bindMultipart parses the request body and returns the result.
func (b *FormBinding) bindMultipart(req *fasthttp.Request, out any) error { func (b *FormBinding) bindMultipart(req *fasthttp.Request, out any) error {
data, err := req.MultipartForm() multipartForm, err := req.MultipartForm()
if err != nil { if err != nil {
return err return err
} }
return parse(b.Name(), out, data.Value) data := make(map[string][]string)
for key, values := range multipartForm.Value {
err = formatBindData(out, data, key, values, b.EnableSplitting, true)
if err != nil {
return err
}
}
return parse(b.Name(), out, data)
} }
// Reset resets the FormBinding binder. // Reset resets the FormBinding binder.

View File

@ -93,9 +93,14 @@ func Test_FormBinder_BindMultipart(t *testing.T) {
} }
require.Equal(t, "form", b.Name()) require.Equal(t, "form", b.Name())
type Post struct {
Title string `form:"title"`
}
type User struct { type User struct {
Name string `form:"name"` Name string `form:"name"`
Names []string `form:"names"` Names []string `form:"names"`
Posts []Post `form:"posts"`
Age int `form:"age"` Age int `form:"age"`
} }
var user User var user User
@ -106,9 +111,13 @@ func Test_FormBinder_BindMultipart(t *testing.T) {
mw := multipart.NewWriter(buf) mw := multipart.NewWriter(buf)
require.NoError(t, mw.WriteField("name", "john")) require.NoError(t, mw.WriteField("name", "john"))
require.NoError(t, mw.WriteField("names", "john")) require.NoError(t, mw.WriteField("names", "john,eric"))
require.NoError(t, mw.WriteField("names", "doe")) require.NoError(t, mw.WriteField("names", "doe"))
require.NoError(t, mw.WriteField("age", "42")) require.NoError(t, mw.WriteField("age", "42"))
require.NoError(t, mw.WriteField("posts[0][title]", "post1"))
require.NoError(t, mw.WriteField("posts[1][title]", "post2"))
require.NoError(t, mw.WriteField("posts[2][title]", "post3"))
require.NoError(t, mw.Close()) require.NoError(t, mw.Close())
req.Header.SetContentType(mw.FormDataContentType()) req.Header.SetContentType(mw.FormDataContentType())
@ -125,6 +134,11 @@ func Test_FormBinder_BindMultipart(t *testing.T) {
require.Equal(t, 42, user.Age) require.Equal(t, 42, user.Age)
require.Contains(t, user.Names, "john") require.Contains(t, user.Names, "john")
require.Contains(t, user.Names, "doe") require.Contains(t, user.Names, "doe")
require.Contains(t, user.Names, "eric")
require.Len(t, user.Posts, 3)
require.Equal(t, "post1", user.Posts[0].Title)
require.Equal(t, "post2", user.Posts[1].Title)
require.Equal(t, "post3", user.Posts[2].Title)
} }
func Benchmark_FormBinder_BindMultipart(b *testing.B) { func Benchmark_FormBinder_BindMultipart(b *testing.B) {

View File

@ -1,9 +1,6 @@
package binder package binder
import ( import (
"reflect"
"strings"
"github.com/gofiber/utils/v2" "github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@ -21,20 +18,21 @@ func (*HeaderBinding) Name() string {
// Bind parses the request header and returns the result. // Bind parses the request header and returns the result.
func (b *HeaderBinding) Bind(req *fasthttp.Request, out any) error { func (b *HeaderBinding) Bind(req *fasthttp.Request, out any) error {
data := make(map[string][]string) data := make(map[string][]string)
var err error
req.Header.VisitAll(func(key, val []byte) { req.Header.VisitAll(func(key, val []byte) {
if err != nil {
return
}
k := utils.UnsafeString(key) k := utils.UnsafeString(key)
v := utils.UnsafeString(val) v := utils.UnsafeString(val)
err = formatBindData(out, data, k, v, b.EnableSplitting, false)
if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
values := strings.Split(v, ",")
for i := 0; i < len(values); i++ {
data[k] = append(data[k], values[i])
}
} else {
data[k] = append(data[k], v)
}
}) })
if err != nil {
return err
}
return parse(b.Name(), out, data) return parse(b.Name(), out, data)
} }

View File

@ -107,7 +107,7 @@ func parseToStruct(aliasTag string, out any, data map[string][]string) error {
func parseToMap(ptr any, data map[string][]string) error { func parseToMap(ptr any, data map[string][]string) error {
elem := reflect.TypeOf(ptr).Elem() elem := reflect.TypeOf(ptr).Elem()
switch elem.Kind() { //nolint:exhaustive // it's not necessary to check all types switch elem.Kind() {
case reflect.Slice: case reflect.Slice:
newMap, ok := ptr.(map[string][]string) newMap, ok := ptr.(map[string][]string)
if !ok { if !ok {
@ -130,6 +130,8 @@ func parseToMap(ptr any, data map[string][]string) error {
} }
newMap[k] = v[len(v)-1] newMap[k] = v[len(v)-1]
} }
default:
return nil // it's not necessary to check all types
} }
return nil return nil
@ -247,3 +249,37 @@ func FilterFlags(content string) string {
} }
return content return content
} }
func formatBindData[T any](out any, data map[string][]string, key string, value T, enableSplitting, supportBracketNotation bool) error { //nolint:revive // it's okay
var err error
if supportBracketNotation && strings.Contains(key, "[") {
key, err = parseParamSquareBrackets(key)
if err != nil {
return err
}
}
switch v := any(value).(type) {
case string:
assignBindData(out, data, key, v, enableSplitting)
case []string:
for _, val := range v {
assignBindData(out, data, key, val, enableSplitting)
}
default:
return fmt.Errorf("unsupported value type: %T", value)
}
return err
}
func assignBindData(out any, data map[string][]string, key, value string, enableSplitting bool) { //nolint:revive // it's okay
if enableSplitting && strings.Contains(value, ",") && equalFieldType(out, reflect.Slice, key) {
values := strings.Split(value, ",")
for i := 0; i < len(values); i++ {
data[key] = append(data[key], values[i])
}
} else {
data[key] = append(data[key], value)
}
}

View File

@ -1,9 +1,6 @@
package binder package binder
import ( import (
"reflect"
"strings"
"github.com/gofiber/utils/v2" "github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@ -30,19 +27,7 @@ func (b *QueryBinding) Bind(reqCtx *fasthttp.Request, out any) error {
k := utils.UnsafeString(key) k := utils.UnsafeString(key)
v := utils.UnsafeString(val) v := utils.UnsafeString(val)
err = formatBindData(out, data, k, v, b.EnableSplitting, true)
if strings.Contains(k, "[") {
k, err = parseParamSquareBrackets(k)
}
if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
values := strings.Split(v, ",")
for i := 0; i < len(values); i++ {
data[k] = append(data[k], values[i])
}
} else {
data[k] = append(data[k], v)
}
}) })
if err != nil { if err != nil {

View File

@ -1,9 +1,6 @@
package binder package binder
import ( import (
"reflect"
"strings"
"github.com/gofiber/utils/v2" "github.com/gofiber/utils/v2"
"github.com/valyala/fasthttp" "github.com/valyala/fasthttp"
) )
@ -21,20 +18,22 @@ func (*RespHeaderBinding) Name() string {
// Bind parses the response header and returns the result. // Bind parses the response header and returns the result.
func (b *RespHeaderBinding) Bind(resp *fasthttp.Response, out any) error { func (b *RespHeaderBinding) Bind(resp *fasthttp.Response, out any) error {
data := make(map[string][]string) data := make(map[string][]string)
var err error
resp.Header.VisitAll(func(key, val []byte) { resp.Header.VisitAll(func(key, val []byte) {
if err != nil {
return
}
k := utils.UnsafeString(key) k := utils.UnsafeString(key)
v := utils.UnsafeString(val) v := utils.UnsafeString(val)
err = formatBindData(out, data, k, v, b.EnableSplitting, false)
if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
values := strings.Split(v, ",")
for i := 0; i < len(values); i++ {
data[k] = append(data[k], values[i])
}
} else {
data[k] = append(data[k], v)
}
}) })
if err != nil {
return err
}
return parse(b.Name(), out, data) return parse(b.Name(), out, data)
} }