mirror of
https://github.com/gofiber/fiber.git
synced 2025-02-06 13:49:31 +00:00
🐛 bug: Fix square bracket notation in Multipart FormData (#3235)
* 🐛 bug: add square bracket notation support to BindMultipart
* Fix golangci-lint issues
* Fixing undef variable
* Fix more lint issues
* test
* update1
* improve coverage
* fix linter
* reduce code duplication
* reduce code duplications in bindMultipart
---------
Co-authored-by: Juan Calderon-Perez <835733+gaby@users.noreply.github.com>
Co-authored-by: René <rene@gofiber.io>
This commit is contained in:
parent
d0e767fc47
commit
ef04a8a99e
101
bind_test.go
101
bind_test.go
@ -7,6 +7,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"mime/multipart"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
@ -887,6 +888,7 @@ func Test_Bind_Body(t *testing.T) {
|
||||
|
||||
type Demo struct {
|
||||
Name string `json:"name" xml:"name" form:"name" query:"name"`
|
||||
Names []string `json:"names" xml:"names" form:"names" query:"names"`
|
||||
}
|
||||
|
||||
// Helper function to test compressed bodies
|
||||
@ -996,6 +998,48 @@ func Test_Bind_Body(t *testing.T) {
|
||||
Data []Demo `query:"data"`
|
||||
}
|
||||
|
||||
t.Run("MultipartCollectionQueryDotNotation", func(t *testing.T) {
|
||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
c.Request().Reset()
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(buf)
|
||||
require.NoError(t, writer.WriteField("data.0.name", "john"))
|
||||
require.NoError(t, writer.WriteField("data.1.name", "doe"))
|
||||
require.NoError(t, writer.Close())
|
||||
|
||||
c.Request().Header.SetContentType(writer.FormDataContentType())
|
||||
c.Request().SetBody(buf.Bytes())
|
||||
c.Request().Header.SetContentLength(len(c.Body()))
|
||||
|
||||
cq := new(CollectionQuery)
|
||||
require.NoError(t, c.Bind().Body(cq))
|
||||
require.Len(t, cq.Data, 2)
|
||||
require.Equal(t, "john", cq.Data[0].Name)
|
||||
require.Equal(t, "doe", cq.Data[1].Name)
|
||||
})
|
||||
|
||||
t.Run("MultipartCollectionQuerySquareBrackets", func(t *testing.T) {
|
||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
c.Request().Reset()
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(buf)
|
||||
require.NoError(t, writer.WriteField("data[0][name]", "john"))
|
||||
require.NoError(t, writer.WriteField("data[1][name]", "doe"))
|
||||
require.NoError(t, writer.Close())
|
||||
|
||||
c.Request().Header.SetContentType(writer.FormDataContentType())
|
||||
c.Request().SetBody(buf.Bytes())
|
||||
c.Request().Header.SetContentLength(len(c.Body()))
|
||||
|
||||
cq := new(CollectionQuery)
|
||||
require.NoError(t, c.Bind().Body(cq))
|
||||
require.Len(t, cq.Data, 2)
|
||||
require.Equal(t, "john", cq.Data[0].Name)
|
||||
require.Equal(t, "doe", cq.Data[1].Name)
|
||||
})
|
||||
|
||||
t.Run("CollectionQuerySquareBrackets", func(t *testing.T) {
|
||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
c.Request().Reset()
|
||||
@ -1192,9 +1236,14 @@ func Benchmark_Bind_Body_MultipartForm(b *testing.B) {
|
||||
Name string `form:"name"`
|
||||
}
|
||||
|
||||
body := []byte("--b\r\nContent-Disposition: form-data; name=\"name\"\r\n\r\njohn\r\n--b--")
|
||||
buf := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(buf)
|
||||
require.NoError(b, writer.WriteField("name", "john"))
|
||||
require.NoError(b, writer.Close())
|
||||
body := buf.Bytes()
|
||||
|
||||
c.Request().SetBody(body)
|
||||
c.Request().Header.SetContentType(MIMEMultipartForm + `;boundary="b"`)
|
||||
c.Request().Header.SetContentType(MIMEMultipartForm + `;boundary=` + writer.Boundary())
|
||||
c.Request().Header.SetContentLength(len(body))
|
||||
d := new(Demo)
|
||||
|
||||
@ -1204,10 +1253,58 @@ func Benchmark_Bind_Body_MultipartForm(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
err = c.Bind().Body(d)
|
||||
}
|
||||
|
||||
require.NoError(b, err)
|
||||
require.Equal(b, "john", d.Name)
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Bind_Body_MultipartForm_Nested -benchmem -count=4
|
||||
func Benchmark_Bind_Body_MultipartForm_Nested(b *testing.B) {
|
||||
var err error
|
||||
|
||||
app := New()
|
||||
c := app.AcquireCtx(&fasthttp.RequestCtx{})
|
||||
|
||||
type Person struct {
|
||||
Name string `form:"name"`
|
||||
Age int `form:"age"`
|
||||
}
|
||||
|
||||
type Demo struct {
|
||||
Name string `form:"name"`
|
||||
Persons []Person `form:"persons"`
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(buf)
|
||||
require.NoError(b, writer.WriteField("name", "john"))
|
||||
require.NoError(b, writer.WriteField("persons.0.name", "john"))
|
||||
require.NoError(b, writer.WriteField("persons[0][age]", "10"))
|
||||
require.NoError(b, writer.WriteField("persons[1][name]", "doe"))
|
||||
require.NoError(b, writer.WriteField("persons.1.age", "20"))
|
||||
require.NoError(b, writer.Close())
|
||||
body := buf.Bytes()
|
||||
|
||||
c.Request().SetBody(body)
|
||||
c.Request().Header.SetContentType(MIMEMultipartForm + `;boundary=` + writer.Boundary())
|
||||
c.Request().Header.SetContentLength(len(body))
|
||||
d := new(Demo)
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
err = c.Bind().Body(d)
|
||||
}
|
||||
|
||||
require.NoError(b, err)
|
||||
require.Equal(b, "john", d.Name)
|
||||
require.Equal(b, "john", d.Persons[0].Name)
|
||||
require.Equal(b, 10, d.Persons[0].Age)
|
||||
require.Equal(b, "doe", d.Persons[1].Name)
|
||||
require.Equal(b, 20, d.Persons[1].Age)
|
||||
}
|
||||
|
||||
// go test -v -run=^$ -bench=Benchmark_Bind_Body_Form_Map -benchmem -count=4
|
||||
func Benchmark_Bind_Body_Form_Map(b *testing.B) {
|
||||
var err error
|
||||
|
@ -1,9 +1,6 @@
|
||||
package binder
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/utils/v2"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
@ -30,15 +27,7 @@ func (b *CookieBinding) Bind(req *fasthttp.Request, out any) error {
|
||||
|
||||
k := utils.UnsafeString(key)
|
||||
v := utils.UnsafeString(val)
|
||||
|
||||
if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
|
||||
values := strings.Split(v, ",")
|
||||
for i := 0; i < len(values); i++ {
|
||||
data[k] = append(data[k], values[i])
|
||||
}
|
||||
} else {
|
||||
data[k] = append(data[k], v)
|
||||
}
|
||||
err = formatBindData(out, data, k, v, b.EnableSplitting, false)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
|
@ -1,9 +1,6 @@
|
||||
package binder
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/utils/v2"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
@ -37,19 +34,7 @@ func (b *FormBinding) Bind(req *fasthttp.Request, out any) error {
|
||||
|
||||
k := utils.UnsafeString(key)
|
||||
v := utils.UnsafeString(val)
|
||||
|
||||
if strings.Contains(k, "[") {
|
||||
k, err = parseParamSquareBrackets(k)
|
||||
}
|
||||
|
||||
if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
|
||||
values := strings.Split(v, ",")
|
||||
for i := 0; i < len(values); i++ {
|
||||
data[k] = append(data[k], values[i])
|
||||
}
|
||||
} else {
|
||||
data[k] = append(data[k], v)
|
||||
}
|
||||
err = formatBindData(out, data, k, v, b.EnableSplitting, true)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
@ -61,12 +46,20 @@ func (b *FormBinding) Bind(req *fasthttp.Request, out any) error {
|
||||
|
||||
// bindMultipart parses the request body and returns the result.
|
||||
func (b *FormBinding) bindMultipart(req *fasthttp.Request, out any) error {
|
||||
data, err := req.MultipartForm()
|
||||
multipartForm, err := req.MultipartForm()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return parse(b.Name(), out, data.Value)
|
||||
data := make(map[string][]string)
|
||||
for key, values := range multipartForm.Value {
|
||||
err = formatBindData(out, data, key, values, b.EnableSplitting, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return parse(b.Name(), out, data)
|
||||
}
|
||||
|
||||
// Reset resets the FormBinding binder.
|
||||
|
@ -93,9 +93,14 @@ func Test_FormBinder_BindMultipart(t *testing.T) {
|
||||
}
|
||||
require.Equal(t, "form", b.Name())
|
||||
|
||||
type Post struct {
|
||||
Title string `form:"title"`
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Name string `form:"name"`
|
||||
Names []string `form:"names"`
|
||||
Posts []Post `form:"posts"`
|
||||
Age int `form:"age"`
|
||||
}
|
||||
var user User
|
||||
@ -106,9 +111,13 @@ func Test_FormBinder_BindMultipart(t *testing.T) {
|
||||
mw := multipart.NewWriter(buf)
|
||||
|
||||
require.NoError(t, mw.WriteField("name", "john"))
|
||||
require.NoError(t, mw.WriteField("names", "john"))
|
||||
require.NoError(t, mw.WriteField("names", "john,eric"))
|
||||
require.NoError(t, mw.WriteField("names", "doe"))
|
||||
require.NoError(t, mw.WriteField("age", "42"))
|
||||
require.NoError(t, mw.WriteField("posts[0][title]", "post1"))
|
||||
require.NoError(t, mw.WriteField("posts[1][title]", "post2"))
|
||||
require.NoError(t, mw.WriteField("posts[2][title]", "post3"))
|
||||
|
||||
require.NoError(t, mw.Close())
|
||||
|
||||
req.Header.SetContentType(mw.FormDataContentType())
|
||||
@ -125,6 +134,11 @@ func Test_FormBinder_BindMultipart(t *testing.T) {
|
||||
require.Equal(t, 42, user.Age)
|
||||
require.Contains(t, user.Names, "john")
|
||||
require.Contains(t, user.Names, "doe")
|
||||
require.Contains(t, user.Names, "eric")
|
||||
require.Len(t, user.Posts, 3)
|
||||
require.Equal(t, "post1", user.Posts[0].Title)
|
||||
require.Equal(t, "post2", user.Posts[1].Title)
|
||||
require.Equal(t, "post3", user.Posts[2].Title)
|
||||
}
|
||||
|
||||
func Benchmark_FormBinder_BindMultipart(b *testing.B) {
|
||||
|
@ -1,9 +1,6 @@
|
||||
package binder
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/utils/v2"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
@ -21,20 +18,21 @@ func (*HeaderBinding) Name() string {
|
||||
// Bind parses the request header and returns the result.
|
||||
func (b *HeaderBinding) Bind(req *fasthttp.Request, out any) error {
|
||||
data := make(map[string][]string)
|
||||
var err error
|
||||
req.Header.VisitAll(func(key, val []byte) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
k := utils.UnsafeString(key)
|
||||
v := utils.UnsafeString(val)
|
||||
|
||||
if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
|
||||
values := strings.Split(v, ",")
|
||||
for i := 0; i < len(values); i++ {
|
||||
data[k] = append(data[k], values[i])
|
||||
}
|
||||
} else {
|
||||
data[k] = append(data[k], v)
|
||||
}
|
||||
err = formatBindData(out, data, k, v, b.EnableSplitting, false)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return parse(b.Name(), out, data)
|
||||
}
|
||||
|
||||
|
@ -107,7 +107,7 @@ func parseToStruct(aliasTag string, out any, data map[string][]string) error {
|
||||
func parseToMap(ptr any, data map[string][]string) error {
|
||||
elem := reflect.TypeOf(ptr).Elem()
|
||||
|
||||
switch elem.Kind() { //nolint:exhaustive // it's not necessary to check all types
|
||||
switch elem.Kind() {
|
||||
case reflect.Slice:
|
||||
newMap, ok := ptr.(map[string][]string)
|
||||
if !ok {
|
||||
@ -130,6 +130,8 @@ func parseToMap(ptr any, data map[string][]string) error {
|
||||
}
|
||||
newMap[k] = v[len(v)-1]
|
||||
}
|
||||
default:
|
||||
return nil // it's not necessary to check all types
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -247,3 +249,37 @@ func FilterFlags(content string) string {
|
||||
}
|
||||
return content
|
||||
}
|
||||
|
||||
func formatBindData[T any](out any, data map[string][]string, key string, value T, enableSplitting, supportBracketNotation bool) error { //nolint:revive // it's okay
|
||||
var err error
|
||||
if supportBracketNotation && strings.Contains(key, "[") {
|
||||
key, err = parseParamSquareBrackets(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
switch v := any(value).(type) {
|
||||
case string:
|
||||
assignBindData(out, data, key, v, enableSplitting)
|
||||
case []string:
|
||||
for _, val := range v {
|
||||
assignBindData(out, data, key, val, enableSplitting)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported value type: %T", value)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func assignBindData(out any, data map[string][]string, key, value string, enableSplitting bool) { //nolint:revive // it's okay
|
||||
if enableSplitting && strings.Contains(value, ",") && equalFieldType(out, reflect.Slice, key) {
|
||||
values := strings.Split(value, ",")
|
||||
for i := 0; i < len(values); i++ {
|
||||
data[key] = append(data[key], values[i])
|
||||
}
|
||||
} else {
|
||||
data[key] = append(data[key], value)
|
||||
}
|
||||
}
|
||||
|
@ -1,9 +1,6 @@
|
||||
package binder
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/utils/v2"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
@ -30,19 +27,7 @@ func (b *QueryBinding) Bind(reqCtx *fasthttp.Request, out any) error {
|
||||
|
||||
k := utils.UnsafeString(key)
|
||||
v := utils.UnsafeString(val)
|
||||
|
||||
if strings.Contains(k, "[") {
|
||||
k, err = parseParamSquareBrackets(k)
|
||||
}
|
||||
|
||||
if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
|
||||
values := strings.Split(v, ",")
|
||||
for i := 0; i < len(values); i++ {
|
||||
data[k] = append(data[k], values[i])
|
||||
}
|
||||
} else {
|
||||
data[k] = append(data[k], v)
|
||||
}
|
||||
err = formatBindData(out, data, k, v, b.EnableSplitting, true)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
|
@ -1,9 +1,6 @@
|
||||
package binder
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/gofiber/utils/v2"
|
||||
"github.com/valyala/fasthttp"
|
||||
)
|
||||
@ -21,20 +18,22 @@ func (*RespHeaderBinding) Name() string {
|
||||
// Bind parses the response header and returns the result.
|
||||
func (b *RespHeaderBinding) Bind(resp *fasthttp.Response, out any) error {
|
||||
data := make(map[string][]string)
|
||||
var err error
|
||||
|
||||
resp.Header.VisitAll(func(key, val []byte) {
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
k := utils.UnsafeString(key)
|
||||
v := utils.UnsafeString(val)
|
||||
|
||||
if b.EnableSplitting && strings.Contains(v, ",") && equalFieldType(out, reflect.Slice, k) {
|
||||
values := strings.Split(v, ",")
|
||||
for i := 0; i < len(values); i++ {
|
||||
data[k] = append(data[k], values[i])
|
||||
}
|
||||
} else {
|
||||
data[k] = append(data[k], v)
|
||||
}
|
||||
err = formatBindData(out, data, k, v, b.EnableSplitting, false)
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return parse(b.Name(), out, data)
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user